An atproto PDS written in Go
103
fork

Configure Feed

Select the types of activity you want to include in your feed.

at 9fbd74c10e490237e8beb34e8e9732843b6bada3 447 lines 11 kB view raw
1package server 2 3import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "io" 9 "time" 10 11 "github.com/Azure/go-autorest/autorest/to" 12 "github.com/bluesky-social/indigo/api/atproto" 13 "github.com/bluesky-social/indigo/atproto/data" 14 "github.com/bluesky-social/indigo/atproto/syntax" 15 "github.com/bluesky-social/indigo/carstore" 16 "github.com/bluesky-social/indigo/events" 17 lexutil "github.com/bluesky-social/indigo/lex/util" 18 "github.com/bluesky-social/indigo/repo" 19 "github.com/haileyok/cocoon/internal/db" 20 "github.com/haileyok/cocoon/models" 21 "github.com/haileyok/cocoon/recording_blockstore" 22 blocks "github.com/ipfs/go-block-format" 23 "github.com/ipfs/go-cid" 24 cbor "github.com/ipfs/go-ipld-cbor" 25 "github.com/ipld/go-car" 26 "gorm.io/gorm/clause" 27) 28 29type RepoMan struct { 30 db *db.DB 31 s *Server 32 clock *syntax.TIDClock 33} 34 35func NewRepoMan(s *Server) *RepoMan { 36 clock := syntax.NewTIDClock(0) 37 38 return &RepoMan{ 39 s: s, 40 db: s.db, 41 clock: &clock, 42 } 43} 44 45type OpType string 46 47var ( 48 OpTypeCreate = OpType("com.atproto.repo.applyWrites#create") 49 OpTypeUpdate = OpType("com.atproto.repo.applyWrites#update") 50 OpTypeDelete = OpType("com.atproto.repo.applyWrites#delete") 51) 52 53func (ot OpType) String() string { 54 return string(ot) 55} 56 57type Op struct { 58 Type OpType `json:"$type"` 59 Collection string `json:"collection"` 60 Rkey *string `json:"rkey,omitempty"` 61 Validate *bool `json:"validate,omitempty"` 62 SwapRecord *string `json:"swapRecord,omitempty"` 63 Record *MarshalableMap `json:"record,omitempty"` 64} 65 66type MarshalableMap map[string]any 67 68type FirehoseOp struct { 69 Cid cid.Cid 70 Path string 71 Action string 72} 73 74func (mm *MarshalableMap) MarshalCBOR(w io.Writer) error { 75 data, err := data.MarshalCBOR(*mm) 76 if err != nil { 77 return err 78 } 79 80 w.Write(data) 81 82 return nil 83} 84 85type ApplyWriteResult struct { 86 Type *string `json:"$type,omitempty"` 87 Uri *string `json:"uri,omitempty"` 88 Cid *string `json:"cid,omitempty"` 89 Commit *RepoCommit `json:"commit,omitempty"` 90 ValidationStatus *string `json:"validationStatus,omitempty"` 91} 92 93type RepoCommit struct { 94 Cid string `json:"cid"` 95 Rev string `json:"rev"` 96} 97 98// TODO make use of swap commit 99func (rm *RepoMan) applyWrites(urepo models.Repo, writes []Op, swapCommit *string) ([]ApplyWriteResult, error) { 100 rootcid, err := cid.Cast(urepo.Root) 101 if err != nil { 102 return nil, err 103 } 104 105 dbs := rm.s.getBlockstore(urepo.Did) 106 bs := recording_blockstore.New(dbs) 107 r, err := repo.OpenRepo(context.TODO(), dbs, rootcid) 108 109 entries := []models.Record{} 110 var results []ApplyWriteResult 111 112 for i, op := range writes { 113 if op.Type != OpTypeCreate && op.Rkey == nil { 114 return nil, fmt.Errorf("invalid rkey") 115 } else if op.Type == OpTypeCreate && op.Rkey != nil { 116 _, _, err := r.GetRecord(context.TODO(), op.Collection+"/"+*op.Rkey) 117 if err == nil { 118 op.Type = OpTypeUpdate 119 } 120 } else if op.Rkey == nil { 121 op.Rkey = to.StringPtr(rm.clock.Next().String()) 122 writes[i].Rkey = op.Rkey 123 } 124 125 _, err := syntax.ParseRecordKey(*op.Rkey) 126 if err != nil { 127 return nil, err 128 } 129 130 switch op.Type { 131 case OpTypeCreate: 132 j, err := json.Marshal(*op.Record) 133 if err != nil { 134 return nil, err 135 } 136 out, err := data.UnmarshalJSON(j) 137 if err != nil { 138 return nil, err 139 } 140 mm := MarshalableMap(out) 141 nc, err := r.PutRecord(context.TODO(), op.Collection+"/"+*op.Rkey, &mm) 142 if err != nil { 143 return nil, err 144 } 145 d, err := data.MarshalCBOR(mm) 146 if err != nil { 147 return nil, err 148 } 149 entries = append(entries, models.Record{ 150 Did: urepo.Did, 151 CreatedAt: rm.clock.Next().String(), 152 Nsid: op.Collection, 153 Rkey: *op.Rkey, 154 Cid: nc.String(), 155 Value: d, 156 }) 157 results = append(results, ApplyWriteResult{ 158 Type: to.StringPtr(OpTypeCreate.String()), 159 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey), 160 Cid: to.StringPtr(nc.String()), 161 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol 162 }) 163 case OpTypeDelete: 164 var old models.Record 165 if err := rm.db.Raw("SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", nil, urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil { 166 return nil, err 167 } 168 entries = append(entries, models.Record{ 169 Did: urepo.Did, 170 Nsid: op.Collection, 171 Rkey: *op.Rkey, 172 Value: old.Value, 173 }) 174 err := r.DeleteRecord(context.TODO(), op.Collection+"/"+*op.Rkey) 175 if err != nil { 176 return nil, err 177 } 178 results = append(results, ApplyWriteResult{ 179 Type: to.StringPtr(OpTypeDelete.String()), 180 }) 181 case OpTypeUpdate: 182 j, err := json.Marshal(*op.Record) 183 if err != nil { 184 return nil, err 185 } 186 out, err := data.UnmarshalJSON(j) 187 if err != nil { 188 return nil, err 189 } 190 mm := MarshalableMap(out) 191 nc, err := r.UpdateRecord(context.TODO(), op.Collection+"/"+*op.Rkey, &mm) 192 if err != nil { 193 return nil, err 194 } 195 d, err := data.MarshalCBOR(mm) 196 if err != nil { 197 return nil, err 198 } 199 entries = append(entries, models.Record{ 200 Did: urepo.Did, 201 CreatedAt: rm.clock.Next().String(), 202 Nsid: op.Collection, 203 Rkey: *op.Rkey, 204 Cid: nc.String(), 205 Value: d, 206 }) 207 results = append(results, ApplyWriteResult{ 208 Type: to.StringPtr(OpTypeUpdate.String()), 209 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey), 210 Cid: to.StringPtr(nc.String()), 211 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol 212 }) 213 } 214 } 215 216 newroot, rev, err := r.Commit(context.TODO(), urepo.SignFor) 217 if err != nil { 218 return nil, err 219 } 220 221 buf := new(bytes.Buffer) 222 223 hb, err := cbor.DumpObject(&car.CarHeader{ 224 Roots: []cid.Cid{newroot}, 225 Version: 1, 226 }) 227 228 if _, err := carstore.LdWrite(buf, hb); err != nil { 229 return nil, err 230 } 231 232 diffops, err := r.DiffSince(context.TODO(), rootcid) 233 if err != nil { 234 return nil, err 235 } 236 237 ops := make([]*atproto.SyncSubscribeRepos_RepoOp, 0, len(diffops)) 238 239 for _, op := range diffops { 240 var c cid.Cid 241 switch op.Op { 242 case "add", "mut": 243 kind := "create" 244 if op.Op == "mut" { 245 kind = "update" 246 } 247 248 c = op.NewCid 249 ll := lexutil.LexLink(op.NewCid) 250 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{ 251 Action: kind, 252 Path: op.Rpath, 253 Cid: &ll, 254 }) 255 256 case "del": 257 c = op.OldCid 258 ll := lexutil.LexLink(op.OldCid) 259 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{ 260 Action: "delete", 261 Path: op.Rpath, 262 Cid: nil, 263 Prev: &ll, 264 }) 265 } 266 267 blk, err := dbs.Get(context.TODO(), c) 268 if err != nil { 269 return nil, err 270 } 271 272 if _, err := carstore.LdWrite(buf, blk.Cid().Bytes(), blk.RawData()); err != nil { 273 return nil, err 274 } 275 } 276 277 for _, op := range bs.GetLogMap() { 278 if _, err := carstore.LdWrite(buf, op.Cid().Bytes(), op.RawData()); err != nil { 279 return nil, err 280 } 281 } 282 283 var blobs []lexutil.LexLink 284 for _, entry := range entries { 285 var cids []cid.Cid 286 if entry.Cid != "" { 287 if err := rm.s.db.Create(&entry, []clause.Expression{clause.OnConflict{ 288 Columns: []clause.Column{{Name: "did"}, {Name: "nsid"}, {Name: "rkey"}}, 289 UpdateAll: true, 290 }}).Error; err != nil { 291 return nil, err 292 } 293 294 cids, err = rm.incrementBlobRefs(urepo, entry.Value) 295 if err != nil { 296 return nil, err 297 } 298 } else { 299 if err := rm.s.db.Delete(&entry, nil).Error; err != nil { 300 return nil, err 301 } 302 cids, err = rm.decrementBlobRefs(urepo, entry.Value) 303 if err != nil { 304 return nil, err 305 } 306 } 307 308 for _, c := range cids { 309 blobs = append(blobs, lexutil.LexLink(c)) 310 } 311 } 312 313 rm.s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{ 314 RepoCommit: &atproto.SyncSubscribeRepos_Commit{ 315 Repo: urepo.Did, 316 Blocks: buf.Bytes(), 317 Blobs: blobs, 318 Rev: rev, 319 Since: &urepo.Rev, 320 Commit: lexutil.LexLink(newroot), 321 Time: time.Now().Format(time.RFC3339Nano), 322 Ops: ops, 323 TooBig: false, 324 }, 325 }) 326 327 if err := rm.s.UpdateRepo(context.TODO(), urepo.Did, newroot, rev); err != nil { 328 return nil, err 329 } 330 331 for i := range results { 332 results[i].Type = to.StringPtr(*results[i].Type + "Result") 333 results[i].Commit = &RepoCommit{ 334 Cid: newroot.String(), 335 Rev: rev, 336 } 337 } 338 339 return results, nil 340} 341 342func (rm *RepoMan) getRecordProof(urepo models.Repo, collection, rkey string) (cid.Cid, []blocks.Block, error) { 343 c, err := cid.Cast(urepo.Root) 344 if err != nil { 345 return cid.Undef, nil, err 346 } 347 348 dbs := rm.s.getBlockstore(urepo.Did) 349 bs := recording_blockstore.New(dbs) 350 351 r, err := repo.OpenRepo(context.TODO(), bs, c) 352 if err != nil { 353 return cid.Undef, nil, err 354 } 355 356 _, _, err = r.GetRecordBytes(context.TODO(), collection+"/"+rkey) 357 if err != nil { 358 return cid.Undef, nil, err 359 } 360 361 return c, bs.GetLogArray(), nil 362} 363 364func (rm *RepoMan) incrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 365 cids, err := getBlobCidsFromCbor(cbor) 366 if err != nil { 367 return nil, err 368 } 369 370 for _, c := range cids { 371 if err := rm.db.Exec("UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", nil, urepo.Did, c.Bytes()).Error; err != nil { 372 return nil, err 373 } 374 } 375 376 return cids, nil 377} 378 379func (rm *RepoMan) decrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 380 cids, err := getBlobCidsFromCbor(cbor) 381 if err != nil { 382 return nil, err 383 } 384 385 for _, c := range cids { 386 var res struct { 387 ID uint 388 Count int 389 } 390 if err := rm.db.Raw("UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", nil, urepo.Did, c.Bytes()).Scan(&res).Error; err != nil { 391 return nil, err 392 } 393 394 if res.Count == 0 { 395 if err := rm.db.Exec("DELETE FROM blobs WHERE id = ?", nil, res.ID).Error; err != nil { 396 return nil, err 397 } 398 if err := rm.db.Exec("DELETE FROM blob_parts WHERE blob_id = ?", nil, res.ID).Error; err != nil { 399 return nil, err 400 } 401 } 402 } 403 404 return cids, nil 405} 406 407// to be honest, we could just store both the cbor and non-cbor in []entries above to avoid an additional 408// unmarshal here. this will work for now though 409func getBlobCidsFromCbor(cbor []byte) ([]cid.Cid, error) { 410 var cids []cid.Cid 411 412 decoded, err := data.UnmarshalCBOR(cbor) 413 if err != nil { 414 return nil, fmt.Errorf("error unmarshaling cbor: %w", err) 415 } 416 417 var deepiter func(any) error 418 deepiter = func(item any) error { 419 switch val := item.(type) { 420 case map[string]any: 421 if val["$type"] == "blob" { 422 if ref, ok := val["ref"].(string); ok { 423 c, err := cid.Parse(ref) 424 if err != nil { 425 return err 426 } 427 cids = append(cids, c) 428 } 429 for _, v := range val { 430 return deepiter(v) 431 } 432 } 433 case []any: 434 for _, v := range val { 435 deepiter(v) 436 } 437 } 438 439 return nil 440 } 441 442 if err := deepiter(decoded); err != nil { 443 return nil, err 444 } 445 446 return cids, nil 447}