A very experimental PLC implementation which uses BFT consensus for decentralization

Compare changes

Choose any two refs to compare.

+33 -56
abciapp/app.go
··· 8 8 "sync" 9 9 "time" 10 10 11 + dbm "github.com/cometbft/cometbft-db" 11 12 abcitypes "github.com/cometbft/cometbft/abci/types" 12 13 "github.com/cosmos/iavl" 13 - "github.com/dgraph-io/badger/v4" 14 14 "github.com/palantir/stacktrace" 15 15 "github.com/samber/lo" 16 - "tangled.org/gbl08ma.com/didplcbft/badgeradapter" 16 + "tangled.org/gbl08ma.com/didplcbft/dbadapter" 17 17 "tangled.org/gbl08ma.com/didplcbft/plc" 18 18 "tangled.org/gbl08ma.com/didplcbft/store" 19 + "tangled.org/gbl08ma.com/didplcbft/transaction" 19 20 ) 20 21 21 22 type DIDPLCApplication struct { 22 - runnerContext context.Context 23 - plc plc.PLC 24 - tree *iavl.MutableTree 25 - fullyClearTree func() error 23 + runnerContext context.Context 24 + plc plc.PLC 25 + txFactory *transaction.Factory 26 + tree *iavl.MutableTree 27 + fullyClearApplicationData func() error 28 + 29 + ongoingRead transaction.Read 30 + ongoingWrite transaction.Write 26 31 27 32 snapshotDirectory string 28 33 snapshotApplier *snapshotApplier ··· 34 39 } 35 40 36 41 // store and plc must be able to share transaction objects 37 - func NewDIDPLCApplication(badgerDB *badger.DB, snapshotDirectory string) (*DIDPLCApplication, plc.PLC, func(), error) { 38 - treePrefix := []byte{} 42 + func NewDIDPLCApplication(treeDB dbm.DB, indexDB dbm.DB, clearData func() (dbm.DB, dbm.DB), snapshotDirectory string) (*DIDPLCApplication, *transaction.Factory, plc.PLC, func(), error) { 39 43 mkTree := func() *iavl.MutableTree { 40 - return iavl.NewMutableTree(badgeradapter.AdaptBadger(badgerDB, treePrefix), 2048, false, iavl.NewNopLogger(), iavl.AsyncPruningOption(true)) 44 + return iavl.NewMutableTree(dbadapter.Adapt(treeDB), 2048, false, iavl.NewNopLogger(), iavl.AsyncPruningOption(false)) 41 45 } 42 46 43 47 tree := mkTree() 44 48 45 49 _, err := tree.Load() 46 50 if err != nil { 47 - return nil, nil, func() {}, stacktrace.Propagate(err, "error loading latest version of the tree from storage") 51 + return nil, nil, nil, func() {}, stacktrace.Propagate(err, "error loading latest version of the tree from storage") 48 52 } 49 53 50 - err = os.MkdirAll(snapshotDirectory, os.FileMode(0755)) 51 - if err != nil { 52 - return nil, nil, func() {}, stacktrace.Propagate(err, "") 54 + if snapshotDirectory != "" { 55 + err = os.MkdirAll(snapshotDirectory, os.FileMode(0755)) 56 + if err != nil { 57 + return nil, nil, nil, func() {}, stacktrace.Propagate(err, "") 58 + } 53 59 } 54 60 55 61 d := &DIDPLCApplication{ 56 62 runnerContext: context.Background(), 57 63 tree: tree, 64 + txFactory: transaction.NewFactory(tree, indexDB, store.Tree.NextOperationSequence), 58 65 snapshotDirectory: snapshotDirectory, 59 66 aocsByPLC: make(map[string]*authoritativeOperationsCache), 60 67 } 61 - d.fullyClearTree = func() error { 68 + 69 + d.fullyClearApplicationData = func() error { 62 70 // we assume this is called in a single-threaded context, which should be a safe assumption since we'll only call this during snapshot import 63 71 // and CometBFT only calls one ABCI method at a time 64 72 err := d.tree.Close() ··· 66 74 return stacktrace.Propagate(err, "") 67 75 } 68 76 69 - if len(treePrefix) == 0 { 70 - // this is probably slightly more efficient when we don't actually need to clear a prefix 71 - err = badgerDB.DropAll() 72 - if err != nil { 73 - return stacktrace.Propagate(err, "") 74 - } 75 - } else { 76 - err = badgerDB.DropPrefix(treePrefix) 77 - if err != nil { 78 - return stacktrace.Propagate(err, "") 79 - } 80 - } 77 + treeDB, indexDB = clearData() 81 78 82 79 *d.tree = *mkTree() 80 + 81 + d.txFactory = transaction.NewFactory(tree, indexDB, store.Tree.NextOperationSequence) 83 82 return nil 84 83 } 85 84 86 - d.plc = plc.NewPLC(d) 85 + d.plc = plc.NewPLC() 87 86 88 87 lastSnapshotVersion := tree.Version() 89 88 ··· 97 96 case <-time.After(5 * time.Minute): 98 97 } 99 98 treeVersion := tree.Version() 100 - if treeVersion > int64(lastSnapshotVersion+1000) { 99 + if treeVersion > int64(lastSnapshotVersion+10000) { 101 100 err = d.createSnapshot(treeVersion, filepath.Join(snapshotDirectory, "snapshot.tmp")) 102 101 if err != nil { 103 102 fmt.Println("FAILED TO TAKE SNAPSHOT", stacktrace.Propagate(err, "")) ··· 138 137 fmt.Println("Imported tree hash", hex.EncodeToString(tree2.Hash()), "and version", tree2.Version()) 139 138 */ 140 139 141 - return d, d.plc, func() { 140 + return d, d.txFactory, d.plc, func() { 142 141 closeCh <- struct{}{} 143 142 wg.Wait() 144 143 lo.Must0(tree.Close()) ··· 146 145 } 147 146 148 147 var _ abcitypes.Application = (*DIDPLCApplication)(nil) 149 - var _ plc.TreeProvider = (*DIDPLCApplication)(nil) 150 148 151 - // ImmutableTree implements [plc.TreeProvider]. 152 - func (d *DIDPLCApplication) ImmutableTree(version plc.TreeVersion) (store.ReadOnlyTree, error) { 153 - if version.IsMutable() { 154 - return store.AdaptMutableTree(d.tree), nil 155 - } 156 - var v int64 157 - if version.IsCommitted() { 158 - var err error 159 - v, err = d.tree.GetLatestVersion() 160 - if err != nil { 161 - return nil, stacktrace.Propagate(err, "") 162 - } 163 - } else { 164 - var ok bool 165 - v, ok = version.SpecificVersion() 166 - if !ok { 167 - return nil, stacktrace.NewError("unsupported TreeVersion") 168 - } 149 + func (d *DIDPLCApplication) DiscardChanges() { 150 + if d.ongoingWrite != nil { 151 + d.ongoingWrite.Rollback() 169 152 } 170 - 171 - it, err := d.tree.GetImmutable(v) 172 - return store.AdaptImmutableTree(it), stacktrace.Propagate(err, "") 173 - } 174 - 175 - // MutableTree implements [plc.TreeProvider]. 176 - func (d *DIDPLCApplication) MutableTree() (*iavl.MutableTree, error) { 177 - return d.tree, nil 153 + d.ongoingWrite = nil 154 + d.ongoingRead = nil 178 155 }
+2 -9
abciapp/app_test.go
··· 4 4 "encoding/json" 5 5 "testing" 6 6 7 + dbm "github.com/cometbft/cometbft-db" 7 8 "github.com/cometbft/cometbft/abci/types" 8 - "github.com/dgraph-io/badger/v4" 9 9 cbornode "github.com/ipfs/go-ipld-cbor" 10 10 "github.com/stretchr/testify/require" 11 11 "tangled.org/gbl08ma.com/didplcbft/abciapp" ··· 21 21 } 22 22 23 23 func TestCheckTx(t *testing.T) { 24 - badgerDB, err := badger.Open(badger.DefaultOptions("").WithInMemory(true)) 25 - require.NoError(t, err) 26 - t.Cleanup(func() { 27 - err := badgerDB.Close() 28 - require.NoError(t, err) 29 - }) 30 - 31 - app, _, cleanup, err := abciapp.NewDIDPLCApplication(badgerDB, "") 24 + app, _, _, cleanup, err := abciapp.NewDIDPLCApplication(dbm.NewMemDB(), dbm.NewMemDB(), nil, "") 32 25 require.NoError(t, err) 33 26 t.Cleanup(cleanup) 34 27
+47 -19
abciapp/execution.go
··· 11 11 cbornode "github.com/ipfs/go-ipld-cbor" 12 12 "github.com/palantir/stacktrace" 13 13 "github.com/samber/lo" 14 + "github.com/samber/mo" 15 + "tangled.org/gbl08ma.com/didplcbft/transaction" 14 16 ) 15 17 16 18 // InitChain implements [types.Application]. ··· 21 23 22 24 // PrepareProposal implements [types.Application]. 23 25 func (d *DIDPLCApplication) PrepareProposal(ctx context.Context, req *abcitypes.RequestPrepareProposal) (*abcitypes.ResponsePrepareProposal, error) { 24 - defer d.tree.Rollback() 26 + defer d.DiscardChanges() 25 27 26 28 if req.Height == 2 { 27 29 tx := Transaction[SetAuthoritativePlcArguments]{ ··· 46 48 for { 47 49 toTryNext := [][]byte{} 48 50 for _, tx := range toProcess { 49 - result, err := processTx(ctx, d.transactionProcessorDependencies(), tx, req.Time, true) 51 + result, err := processTx(ctx, d.transactionProcessorDependenciesForOngoingProcessing(true, req.Time), tx) 50 52 if err != nil { 51 53 return nil, stacktrace.Propagate(err, "") 52 54 } ··· 86 88 87 89 // set execute to false to save a lot of time 88 90 // (we trust that running the import will succeed, so just do bare minimum checks here) 89 - result, err := processTx(ctx, d.transactionProcessorDependencies(), maybeTx, req.Time, false) 91 + result, err := processTx(ctx, d.transactionProcessorDependenciesForOngoingProcessing(false, req.Time), maybeTx) 90 92 if err != nil { 91 93 return nil, stacktrace.Propagate(err, "") 92 94 } ··· 114 116 if d.lastProcessedProposalHash == nil { 115 117 // we didn't vote ACCEPT 116 118 // we could rollback only eventually on FinalizeBlock, but why wait - rollback now for safety 117 - d.tree.Rollback() 119 + d.DiscardChanges() 118 120 } 119 121 }() 120 122 ··· 131 133 } 132 134 133 135 st := time.Now() 134 - result, err = finishProcessTx(ctx, d.transactionProcessorDependencies(), processor, tx, req.Time, true) 136 + result, err = finishProcessTx(ctx, d.transactionProcessorDependenciesForOngoingProcessing(true, req.Time), processor, tx) 135 137 if err != nil { 136 138 return nil, stacktrace.Propagate(err, "") 137 139 } ··· 179 181 } 180 182 // a block other than the one we processed in ProcessProposal was decided 181 183 // discard the current modified state, and process the decided block 182 - d.tree.Rollback() 184 + d.DiscardChanges() 183 185 184 186 txResults := make([]*processResult, len(req.Txs)) 185 187 for i, tx := range req.Txs { 186 188 var err error 187 - txResults[i], err = processTx(ctx, d.transactionProcessorDependencies(), tx, req.Time, true) 189 + txResults[i], err = processTx(ctx, d.transactionProcessorDependenciesForOngoingProcessing(true, req.Time), tx) 188 190 if err != nil { 189 191 return nil, stacktrace.Propagate(err, "") 190 192 } ··· 203 205 204 206 // Commit implements [types.Application]. 205 207 func (d *DIDPLCApplication) Commit(context.Context, *abcitypes.RequestCommit) (*abcitypes.ResponseCommit, error) { 206 - _, newVersion, err := d.tree.SaveVersion() 208 + // ensure we always advance tree version by creating ongoingWrite if it hasn't been created already 209 + d.createOngoingTxIfNeeded(time.Now()) 210 + 211 + err := d.ongoingWrite.Commit() 207 212 if err != nil { 208 213 return nil, stacktrace.Propagate(err, "") 209 214 } ··· 214 219 } 215 220 } 216 221 217 - minHeightToKeep := max(newVersion-100, 0) 218 - minVerToKeep := max(minHeightToKeep-5, 0) 219 - if minVerToKeep > 0 { 220 - err = d.tree.DeleteVersionsTo(minVerToKeep) 221 - if err != nil { 222 - return nil, stacktrace.Propagate(err, "") 223 - } 224 - } 225 - 226 222 return &abcitypes.ResponseCommit{ 227 223 // TODO only discard actual blockchain history based on settings 228 224 //RetainHeight: minHeightToKeep, 229 225 }, nil 230 226 } 231 227 232 - func (d *DIDPLCApplication) transactionProcessorDependencies() TransactionProcessorDependencies { 228 + func (d *DIDPLCApplication) transactionProcessorDependenciesForCommittedRead() TransactionProcessorDependencies { 233 229 return TransactionProcessorDependencies{ 234 230 runnerContext: d.runnerContext, 235 231 plc: d.plc, 236 - tree: d, 232 + readTx: d.txFactory.ReadCommitted(), 233 + writeTx: mo.None[transaction.Write](), 237 234 aocsByPLC: d.aocsByPLC, 238 235 } 239 236 } 237 + 238 + func (d *DIDPLCApplication) transactionProcessorDependenciesForOngoingProcessing(write bool, timestamp time.Time) TransactionProcessorDependencies { 239 + d.createOngoingTxIfNeeded(timestamp) 240 + 241 + writeTx := mo.None[transaction.Write]() 242 + 243 + if write { 244 + writeTx = mo.Some(d.ongoingWrite) 245 + } 246 + 247 + return TransactionProcessorDependencies{ 248 + runnerContext: d.runnerContext, 249 + plc: d.plc, 250 + readTx: d.ongoingRead, 251 + writeTx: writeTx, 252 + aocsByPLC: d.aocsByPLC, 253 + } 254 + } 255 + 256 + func (d *DIDPLCApplication) createOngoingTxIfNeeded(timestamp time.Time) { 257 + if d.ongoingWrite != nil && d.ongoingRead == nil { 258 + panic("inconsistent internal state") 259 + } 260 + if d.ongoingRead == nil { 261 + d.ongoingRead = d.txFactory.ReadWorking(timestamp) 262 + 263 + if d.ongoingWrite == nil { 264 + d.ongoingWrite, _ = d.ongoingRead.Upgrade() 265 + } 266 + } 267 + }
+3 -7
abciapp/import.go
··· 18 18 "github.com/ipfs/go-cid" 19 19 cbornode "github.com/ipfs/go-ipld-cbor" 20 20 "github.com/palantir/stacktrace" 21 - "tangled.org/gbl08ma.com/didplcbft/plc" 22 21 "tangled.org/gbl08ma.com/didplcbft/store" 23 22 ) 24 23 ··· 250 249 251 250 func (d *DIDPLCApplication) maybeCreateAuthoritativeImportTx(ctx context.Context) ([]byte, error) { 252 251 // use WorkingTreeVersion so we take into account any import operation that may have been processed in this block 253 - roTree, err := d.ImmutableTree(plc.WorkingTreeVersion) 254 - if err != nil { 255 - return nil, stacktrace.Propagate(err, "") 256 - } 252 + readTx := d.txFactory.ReadWorking(time.Now()) 257 253 258 - plcURL, err := store.Tree.AuthoritativePLC(roTree) 254 + plcURL, err := store.Tree.AuthoritativePLC(readTx) 259 255 if err != nil { 260 256 return nil, stacktrace.Propagate(err, "") 261 257 } ··· 265 261 return nil, nil 266 262 } 267 263 268 - cursor, err := store.Tree.AuthoritativeImportProgress(roTree) 264 + cursor, err := store.Tree.AuthoritativeImportProgress(readTx) 269 265 if err != nil { 270 266 return nil, stacktrace.Propagate(err, "") 271 267 }
+15 -4
abciapp/info.go
··· 6 6 "errors" 7 7 "net/http" 8 8 "net/url" 9 + "time" 9 10 10 11 abcitypes "github.com/cometbft/cometbft/abci/types" 11 12 "github.com/palantir/stacktrace" 12 13 "github.com/ucarion/urlpath" 13 14 "tangled.org/gbl08ma.com/didplcbft/plc" 15 + "tangled.org/gbl08ma.com/didplcbft/transaction" 14 16 ) 15 17 16 18 // Info implements [types.Application]. ··· 34 36 }, nil 35 37 } 36 38 37 - treeVersion := plc.CommittedTreeVersion 39 + var readTx transaction.Read 40 + height := d.tree.Version() 38 41 if req.Height != 0 { 39 - treeVersion = plc.SpecificTreeVersion(req.Height) 42 + height = req.Height 43 + } 44 + 45 + readTx, err = d.txFactory.ReadHeight(time.Now(), height) 46 + if err != nil { 47 + return &abcitypes.ResponseQuery{ 48 + Code: 6001, 49 + Info: "Unavailable height", 50 + }, nil 40 51 } 41 52 42 53 handlers := []struct { ··· 47 58 matcher: urlpath.New("/plc/:did"), 48 59 handler: func(match urlpath.Match) (*abcitypes.ResponseQuery, error) { 49 60 did := match.Params["did"] 50 - doc, err := d.plc.Resolve(ctx, treeVersion, did) 61 + doc, err := d.plc.Resolve(ctx, readTx, did) 51 62 if err != nil { 52 63 switch { 53 64 case errors.Is(err, plc.ErrDIDNotFound): ··· 110 121 matcher: urlpath.New("/plc/:did/data"), 111 122 handler: func(match urlpath.Match) (*abcitypes.ResponseQuery, error) { 112 123 did := match.Params["did"] 113 - data, err := d.plc.Data(ctx, treeVersion, did) 124 + data, err := d.plc.Data(ctx, readTx, did) 114 125 if err != nil { 115 126 switch { 116 127 case errors.Is(err, plc.ErrDIDNotFound):
+1 -2
abciapp/mempool.go
··· 2 2 3 3 import ( 4 4 "context" 5 - "time" 6 5 7 6 abcitypes "github.com/cometbft/cometbft/abci/types" 8 7 "github.com/palantir/stacktrace" ··· 23 22 }, nil 24 23 } 25 24 26 - result, err = finishProcessTx(ctx, d.transactionProcessorDependencies(), processor, req.Tx, time.Now(), false) 25 + result, err = finishProcessTx(ctx, d.transactionProcessorDependenciesForCommittedRead(), processor, req.Tx) 27 26 if err != nil { 28 27 return nil, stacktrace.Propagate(err, "") 29 28 }
+1 -2
abciapp/snapshots.go
··· 171 171 172 172 err := d.snapshotApplier.Apply(int(req.Index), req.Chunk) 173 173 if err != nil { 174 - fmt.Println("SNAPSHOT APPLY FAILED:", err.Error()) 175 174 if errors.Is(err, errMalformedChunk) { 176 175 return &abcitypes.ResponseApplySnapshotChunk{ 177 176 Result: abcitypes.ResponseApplySnapshotChunk_RETRY, ··· 500 499 } 501 500 502 501 if !d.tree.IsEmpty() { 503 - err := d.fullyClearTree() 502 + err := d.fullyClearApplicationData() 504 503 if err != nil { 505 504 return nil, stacktrace.Propagate(err, "") 506 505 }
+9 -7
abciapp/tx.go
··· 3 3 import ( 4 4 "bytes" 5 5 "context" 6 - "time" 7 6 8 7 abcitypes "github.com/cometbft/cometbft/abci/types" 9 8 cbornode "github.com/ipfs/go-ipld-cbor" 10 9 "github.com/palantir/stacktrace" 10 + "github.com/samber/mo" 11 11 "tangled.org/gbl08ma.com/didplcbft/plc" 12 + "tangled.org/gbl08ma.com/didplcbft/transaction" 12 13 ) 13 14 14 15 type ArgumentType interface { ··· 19 20 20 21 type TransactionProcessorDependencies struct { 21 22 runnerContext context.Context 23 + readTx transaction.Read 24 + writeTx mo.Option[transaction.Write] 22 25 plc plc.PLC 23 - tree plc.TreeProvider // TODO maybe we should move the TreeProvider definition out of the plc package then? 24 26 aocsByPLC map[string]*authoritativeOperationsCache 25 27 } 26 28 27 - type TransactionProcessor func(ctx context.Context, deps TransactionProcessorDependencies, txBytes []byte, atTime time.Time, execute bool) (*processResult, error) 29 + type TransactionProcessor func(ctx context.Context, deps TransactionProcessorDependencies, txBytes []byte) (*processResult, error) 28 30 29 31 var knownActions = map[TransactionAction]TransactionProcessor{} 30 32 ··· 152 154 return &processResult{}, action, processor, nil 153 155 } 154 156 155 - func finishProcessTx(ctx context.Context, deps TransactionProcessorDependencies, processor TransactionProcessor, txBytes []byte, atTime time.Time, execute bool) (*processResult, error) { 156 - result, err := processor(ctx, deps, txBytes, atTime, execute) 157 + func finishProcessTx(ctx context.Context, deps TransactionProcessorDependencies, processor TransactionProcessor, txBytes []byte) (*processResult, error) { 158 + result, err := processor(ctx, deps, txBytes) 157 159 return result, stacktrace.Propagate(err, "") 158 160 } 159 161 160 - func processTx(ctx context.Context, deps TransactionProcessorDependencies, txBytes []byte, atTime time.Time, execute bool) (*processResult, error) { 162 + func processTx(ctx context.Context, deps TransactionProcessorDependencies, txBytes []byte) (*processResult, error) { 161 163 result, _, processor, err := beginProcessTx(txBytes) 162 164 if err != nil { 163 165 return nil, stacktrace.Propagate(err, "") ··· 166 168 return result, nil 167 169 } 168 170 169 - result, err = finishProcessTx(ctx, deps, processor, txBytes, atTime, execute) 171 + result, err = finishProcessTx(ctx, deps, processor, txBytes) 170 172 return result, stacktrace.Propagate(err, "") 171 173 }
+4 -5
abciapp/tx_create_plc_op.go
··· 3 3 import ( 4 4 "context" 5 5 "encoding/json" 6 - "time" 7 6 8 7 "github.com/did-method-plc/go-didplc" 9 8 cbornode "github.com/ipfs/go-ipld-cbor" ··· 27 26 cbornode.RegisterCborType(Transaction[CreatePlcOpArguments]{}) 28 27 } 29 28 30 - func processCreatePlcOpTx(ctx context.Context, deps TransactionProcessorDependencies, txBytes []byte, atTime time.Time, execute bool) (*processResult, error) { 29 + func processCreatePlcOpTx(ctx context.Context, deps TransactionProcessorDependencies, txBytes []byte) (*processResult, error) { 31 30 tx, err := UnmarshalTransaction[CreatePlcOpArguments](txBytes) 32 31 if err != nil { 33 32 return &processResult{ ··· 53 52 return nil, stacktrace.Propagate(err, "internal error") 54 53 } 55 54 56 - if execute { 57 - err = deps.plc.ExecuteOperation(ctx, atTime, tx.Arguments.DID, opBytes) 55 + if writeTx, ok := deps.writeTx.Get(); ok { 56 + err = deps.plc.ExecuteOperation(ctx, writeTx, tx.Arguments.DID, opBytes) 58 57 } else { 59 - err = deps.plc.ValidateOperation(ctx, plc.CommittedTreeVersion, atTime, tx.Arguments.DID, opBytes) 58 + err = deps.plc.ValidateOperation(ctx, deps.readTx, tx.Arguments.DID, opBytes) 60 59 } 61 60 if err != nil { 62 61 if code, ok := plc.InvalidOperationErrorCode(err); ok {
+14 -32
abciapp/tx_import.go
··· 4 4 "context" 5 5 "encoding/hex" 6 6 "net/url" 7 - "time" 8 7 9 - "github.com/did-method-plc/go-didplc" 10 8 cbornode "github.com/ipfs/go-ipld-cbor" 11 9 "github.com/palantir/stacktrace" 12 - "github.com/samber/lo" 13 - "tangled.org/gbl08ma.com/didplcbft/plc" 14 10 "tangled.org/gbl08ma.com/didplcbft/store" 15 11 ) 16 12 ··· 30 26 cbornode.RegisterCborType(Transaction[SetAuthoritativePlcArguments]{}) 31 27 } 32 28 33 - func processSetAuthoritativePlcTx(ctx context.Context, deps TransactionProcessorDependencies, txBytes []byte, atTime time.Time, execute bool) (*processResult, error) { 29 + func processSetAuthoritativePlcTx(ctx context.Context, deps TransactionProcessorDependencies, txBytes []byte) (*processResult, error) { 34 30 tx, err := UnmarshalTransaction[SetAuthoritativePlcArguments](txBytes) 35 31 if err != nil { 36 32 return &processResult{ ··· 57 53 } 58 54 } 59 55 60 - if execute { 61 - tree, err := deps.tree.MutableTree() 62 - if err != nil { 63 - return nil, stacktrace.Propagate(err, "") 64 - } 65 - err = store.Tree.SetAuthoritativePLC(tree, tx.Arguments.PLCURL) 56 + if writeTx, ok := deps.writeTx.Get(); ok { 57 + err = store.Tree.SetAuthoritativePLC(writeTx, tx.Arguments.PLCURL) 66 58 if err != nil { 67 59 return nil, stacktrace.Propagate(err, "") 68 60 } 69 61 70 62 if tx.Arguments.RestartImport { 71 - err = store.Tree.SetAuthoritativeImportProgress(tree, 0) 63 + err = store.Tree.SetAuthoritativeImportProgress(writeTx, 0) 72 64 if err != nil { 73 65 return nil, stacktrace.Propagate(err, "") 74 66 } ··· 98 90 cbornode.RegisterCborType(Transaction[AuthoritativeImportArguments]{}) 99 91 } 100 92 101 - func processAuthoritativeImportTx(ctx context.Context, deps TransactionProcessorDependencies, txBytes []byte, atTime time.Time, execute bool) (*processResult, error) { 93 + func processAuthoritativeImportTx(ctx context.Context, deps TransactionProcessorDependencies, txBytes []byte) (*processResult, error) { 102 94 tx, err := UnmarshalTransaction[AuthoritativeImportArguments](txBytes) 103 95 if err != nil { 104 96 return &processResult{ ··· 107 99 }, nil 108 100 } 109 101 110 - roTree, err := deps.tree.ImmutableTree(plc.CommittedTreeVersion) 111 - if err != nil { 112 - return nil, stacktrace.Propagate(err, "") 113 - } 114 - 115 - expectedPlcUrl, err := store.Tree.AuthoritativePLC(roTree) 102 + expectedPlcUrl, err := store.Tree.AuthoritativePLC(deps.readTx) 116 103 if err != nil { 117 104 return nil, stacktrace.Propagate(err, "") 118 105 } ··· 126 113 127 114 aoc := getOrCreateAuthoritativeOperationsCache(deps.runnerContext, deps.aocsByPLC, expectedPlcUrl) 128 115 129 - expectedCursor, err := store.Tree.AuthoritativeImportProgress(roTree) 116 + expectedCursor, err := store.Tree.AuthoritativeImportProgress(deps.readTx) 130 117 if err != nil { 131 118 return nil, stacktrace.Propagate(err, "") 132 119 } ··· 170 157 newCursor = operations[len(operations)-1].Seq 171 158 } 172 159 173 - if execute { 174 - err := deps.plc.ImportOperationsFromAuthoritativeSource(ctx, lo.Map(operations, func(l logEntryWithSeq, _ int) didplc.LogEntry { 175 - return l.LogEntry 176 - })) 177 - if err != nil { 178 - return nil, stacktrace.Propagate(err, "") 160 + if writeTx, ok := deps.writeTx.Get(); ok { 161 + for _, op := range operations { 162 + err := deps.plc.ImportOperationFromAuthoritativeSource(ctx, writeTx, op.LogEntry) 163 + if err != nil { 164 + return nil, stacktrace.Propagate(err, "") 165 + } 179 166 } 180 167 181 - tree, err := deps.tree.MutableTree() 182 - if err != nil { 183 - return nil, stacktrace.Propagate(err, "") 184 - } 185 - 186 - err = store.Tree.SetAuthoritativeImportProgress(tree, newCursor) 168 + err = store.Tree.SetAuthoritativeImportProgress(writeTx, newCursor) 187 169 if err != nil { 188 170 return nil, stacktrace.Propagate(err, "") 189 171 }
-407
badgeradapter/adapter.go
··· 1 - package badgeradapter 2 - 3 - import ( 4 - "bytes" 5 - "slices" 6 - 7 - "cosmossdk.io/core/store" 8 - "github.com/cosmos/iavl/db" 9 - "github.com/palantir/stacktrace" 10 - 11 - badger "github.com/dgraph-io/badger/v4" 12 - ) 13 - 14 - type BadgerAdapter struct { 15 - badgerDB *badger.DB 16 - keyPrefix []byte 17 - } 18 - 19 - func AdaptBadger(badgerDB *badger.DB, keyPrefix []byte) *BadgerAdapter { 20 - return &BadgerAdapter{ 21 - badgerDB: badgerDB, 22 - keyPrefix: keyPrefix, 23 - } 24 - } 25 - 26 - var _ db.DB = (*BadgerAdapter)(nil) 27 - 28 - // prefixKey adds the keyPrefix to the given key 29 - func (b *BadgerAdapter) prefixKey(key []byte) []byte { 30 - result := make([]byte, 0, len(b.keyPrefix)+len(key)) 31 - result = append(result, b.keyPrefix...) 32 - result = append(result, key...) 33 - return result 34 - } 35 - 36 - // Close implements [db.DB]. 37 - func (b *BadgerAdapter) Close() error { 38 - return b.badgerDB.Close() 39 - } 40 - 41 - // Get implements [db.DB]. 42 - func (b *BadgerAdapter) Get(key []byte) ([]byte, error) { 43 - prefixedKey := b.prefixKey(key) 44 - 45 - var value []byte 46 - err := b.badgerDB.View(func(txn *badger.Txn) error { 47 - item, err := txn.Get(prefixedKey) 48 - if err != nil { 49 - return err 50 - } 51 - value, err = item.ValueCopy(nil) 52 - return err 53 - }) 54 - 55 - if err == badger.ErrKeyNotFound { 56 - return nil, nil 57 - } 58 - if err != nil { 59 - return nil, stacktrace.Propagate(err, "failed to get key from badger") 60 - } 61 - 62 - return value, nil 63 - } 64 - 65 - // Has implements [db.DB]. 66 - func (b *BadgerAdapter) Has(key []byte) (bool, error) { 67 - prefixedKey := b.prefixKey(key) 68 - 69 - var has bool 70 - err := b.badgerDB.View(func(txn *badger.Txn) error { 71 - _, err := txn.Get(prefixedKey) 72 - if err == badger.ErrKeyNotFound { 73 - has = false 74 - return nil 75 - } 76 - if err != nil { 77 - return err 78 - } 79 - has = true 80 - return nil 81 - }) 82 - 83 - if err != nil { 84 - return false, stacktrace.Propagate(err, "failed to check key existence in badger") 85 - } 86 - 87 - return has, nil 88 - } 89 - 90 - // BadgerIterator adapts badger.Iterator to store.Iterator 91 - type BadgerIterator struct { 92 - badgerIter *badger.Iterator 93 - txn *badger.Txn 94 - start []byte 95 - end []byte 96 - reverse bool // true if this is a reverse iterator 97 - valid bool 98 - keyPrefix []byte 99 - } 100 - 101 - // hasPrefix checks if a prefixed key actually has the expected keyPrefix 102 - func (i *BadgerIterator) hasPrefix(prefixedKey []byte) bool { 103 - return len(prefixedKey) >= len(i.keyPrefix) && bytes.Equal(prefixedKey[:len(i.keyPrefix)], i.keyPrefix) 104 - } 105 - 106 - // stripPrefix removes the keyPrefix from a prefixed key 107 - func (i *BadgerIterator) stripPrefix(prefixedKey []byte) []byte { 108 - if len(prefixedKey) < len(i.keyPrefix) { 109 - return prefixedKey // Shouldn't happen, but defensive programming 110 - } 111 - stripped := make([]byte, len(prefixedKey)-len(i.keyPrefix)) 112 - copy(stripped, prefixedKey[len(i.keyPrefix):]) 113 - return stripped 114 - } 115 - 116 - func (i *BadgerIterator) Domain() (start, end []byte) { 117 - // Return copies to ensure they're safe for modification 118 - startCopy := make([]byte, len(i.start)) 119 - endCopy := make([]byte, len(i.end)) 120 - copy(startCopy, i.start) 121 - copy(endCopy, i.end) 122 - return startCopy, endCopy 123 - } 124 - 125 - func (i *BadgerIterator) Valid() bool { 126 - if !i.valid || !i.badgerIter.Valid() { 127 - return false 128 - } 129 - 130 - // Ensure the current key has the correct keyPrefix 131 - // If not, skip to the next valid key 132 - item := i.badgerIter.Item() 133 - prefixedKey := item.Key() 134 - if !i.hasPrefix(prefixedKey) { 135 - // We've gone out of the bounds of "our" prefixes 136 - return false 137 - } 138 - 139 - // For forward iteration, check if we've reached the end (end is exclusive) 140 - if i.end != nil && !i.reverse { 141 - currentKey := i.stripPrefix(prefixedKey) 142 - // If current key >= end key, we're done 143 - if bytes.Compare(currentKey, i.end) >= 0 { 144 - return false 145 - } 146 - } 147 - 148 - // For reverse iteration, check if we've gone below the start (start is inclusive) 149 - if i.start != nil && i.reverse { 150 - currentKey := i.stripPrefix(prefixedKey) 151 - // If current key < start key, we're done 152 - if bytes.Compare(currentKey, i.start) < 0 { 153 - return false 154 - } 155 - } 156 - 157 - return true 158 - } 159 - 160 - func (i *BadgerIterator) Next() { 161 - if !i.valid { 162 - panic("iterator is not valid") 163 - } 164 - i.badgerIter.Next() 165 - 166 - // Check if the badger iterator is still valid 167 - if !i.badgerIter.Valid() { 168 - i.valid = false 169 - return 170 - } 171 - 172 - item := i.badgerIter.Item() 173 - prefixedKey := item.Key() 174 - if !i.hasPrefix(prefixedKey) { 175 - // We've gone out of the bounds of "our" prefixes 176 - i.valid = false 177 - return 178 - } 179 - 180 - // For forward iteration, check if we've reached the end (end is exclusive) 181 - if i.end != nil && !i.reverse { 182 - currentKey := i.stripPrefix(prefixedKey) 183 - // If current key >= end key, we're done 184 - if bytes.Compare(currentKey, i.end) >= 0 { 185 - i.valid = false 186 - return 187 - } 188 - } 189 - 190 - // For reverse iteration, check if we've gone below the start (start is inclusive) 191 - if i.start != nil && i.reverse { 192 - currentKey := i.stripPrefix(prefixedKey) 193 - // If current key < start key, we're done 194 - if bytes.Compare(currentKey, i.start) < 0 { 195 - i.valid = false 196 - return 197 - } 198 - } 199 - 200 - i.valid = true 201 - } 202 - 203 - func (i *BadgerIterator) Key() []byte { 204 - if !i.valid { 205 - panic("iterator is not valid") 206 - } 207 - item := i.badgerIter.Item() 208 - return i.stripPrefix(item.Key()) 209 - } 210 - 211 - func (i *BadgerIterator) Value() []byte { 212 - if !i.valid { 213 - panic("iterator is not valid") 214 - } 215 - item := i.badgerIter.Item() 216 - value, err := item.ValueCopy(nil) 217 - if err != nil { 218 - panic("failed to copy value: " + err.Error()) 219 - } 220 - return value 221 - } 222 - 223 - func (i *BadgerIterator) Error() error { 224 - // Badger iterator doesn't have a separate error method 225 - // Errors are typically caught during iteration setup 226 - return nil 227 - } 228 - 229 - func (i *BadgerIterator) Close() error { 230 - // Close the badger iterator first - this is critical to avoid panics 231 - if i.badgerIter != nil { 232 - i.badgerIter.Close() 233 - } 234 - 235 - // Mark as invalid 236 - i.valid = false 237 - 238 - // Discard the transaction to release resources 239 - if i.txn != nil { 240 - i.txn.Discard() 241 - i.txn = nil 242 - } 243 - 244 - return nil 245 - } 246 - 247 - // Iterator implements [db.DB]. 248 - func (b *BadgerAdapter) Iterator(start []byte, end []byte) (store.Iterator, error) { 249 - // Create a read-only transaction to hold the iterator 250 - txn := b.badgerDB.NewTransaction(false) 251 - 252 - // Create prefixed version of start 253 - prefixedStart := b.prefixKey(start) 254 - 255 - opts := badger.IteratorOptions{ 256 - PrefetchValues: true, 257 - Reverse: false, 258 - AllVersions: false, 259 - } 260 - badgerIter := txn.NewIterator(opts) 261 - 262 - badgerIter.Seek(prefixedStart) 263 - 264 - iterator := &BadgerIterator{ 265 - badgerIter: badgerIter, 266 - txn: txn, 267 - start: start, // Store original start/end for Domain() method 268 - end: end, 269 - reverse: false, // This is a forward iterator 270 - valid: badgerIter.Valid(), 271 - keyPrefix: b.keyPrefix, 272 - } 273 - 274 - return iterator, nil 275 - } 276 - 277 - // incrementSlice assumes that the first byte of b is not 0xff 278 - func incrementSlice(b []byte) { 279 - for i := len(b) - 1; i >= 0; i-- { 280 - b[i] += 1 281 - if b[i] != 0 { 282 - break 283 - } 284 - } 285 - } 286 - 287 - // ReverseIterator implements [db.DB]. 288 - func (b *BadgerAdapter) ReverseIterator(start []byte, end []byte) (store.Iterator, error) { 289 - // Create a read-only transaction to hold the iterator 290 - txn := b.badgerDB.NewTransaction(false) 291 - 292 - opts := badger.IteratorOptions{ 293 - PrefetchValues: true, 294 - Reverse: true, // This enables reverse iteration 295 - AllVersions: false, 296 - } 297 - badgerIter := txn.NewIterator(opts) 298 - 299 - prefixedEnd := b.prefixKey(end) 300 - incrementedEnd := slices.Clone(prefixedEnd) 301 - incrementSlice(incrementedEnd) // Badger's Seek is inclusive but in these iterators end is exclusive (except if nil) 302 - 303 - badgerIter.Seek(incrementedEnd) 304 - // if end is nil, then Badger might be (depending on whether end matches an existing key) 305 - // already giving us the key we want and there's no need to skip 306 - if end != nil && badgerIter.Valid() && bytes.Equal(badgerIter.Item().Key(), prefixedEnd) { 307 - badgerIter.Next() 308 - } 309 - 310 - iterator := &BadgerIterator{ 311 - badgerIter: badgerIter, 312 - txn: txn, 313 - start: start, 314 - end: end, 315 - reverse: true, // This is a reverse iterator 316 - valid: badgerIter.Valid(), 317 - keyPrefix: b.keyPrefix, 318 - } 319 - 320 - return iterator, nil 321 - } 322 - 323 - // BadgerBatch implements store.Batch 324 - // BadgerBatch writes are atomic up until the point where they'd exceed the badger max transaction size, 325 - // at which point they are split into multiple non-atomic writes 326 - type BadgerBatch struct { 327 - wb *badger.WriteBatch 328 - closed bool 329 - keyPrefix []byte 330 - } 331 - 332 - func (b *BadgerBatch) Set(key, value []byte) error { 333 - if b.closed { 334 - return stacktrace.NewError("batch has been written or closed") 335 - } 336 - if len(key) == 0 { 337 - return stacktrace.NewError("key cannot be empty") 338 - } 339 - if value == nil { 340 - return stacktrace.NewError("value cannot be nil") 341 - } 342 - 343 - prefixedKey := make([]byte, 0, len(b.keyPrefix)+len(key)) 344 - prefixedKey = append(prefixedKey, b.keyPrefix...) 345 - prefixedKey = append(prefixedKey, key...) 346 - 347 - err := b.wb.Set(prefixedKey, value) 348 - return stacktrace.Propagate(err, "failed to set key in batch") 349 - } 350 - 351 - func (b *BadgerBatch) Delete(key []byte) error { 352 - if b.closed { 353 - return stacktrace.NewError("batch has been written or closed") 354 - } 355 - if len(key) == 0 { 356 - return stacktrace.NewError("key cannot be empty") 357 - } 358 - 359 - prefixedKey := make([]byte, 0, len(b.keyPrefix)+len(key)) 360 - prefixedKey = append(prefixedKey, b.keyPrefix...) 361 - prefixedKey = append(prefixedKey, key...) 362 - 363 - err := b.wb.Delete(prefixedKey) 364 - return stacktrace.Propagate(err, "failed to delete key in batch") 365 - } 366 - 367 - func (b *BadgerBatch) Write() error { 368 - if b.closed { 369 - return stacktrace.NewError("batch has been written or closed") 370 - } 371 - b.closed = true 372 - err := b.wb.Flush() 373 - return stacktrace.Propagate(err, "failed to write batch") 374 - } 375 - 376 - func (b *BadgerBatch) WriteSync() error { 377 - // Badger doesn't have separate WriteSync, so we just use Write 378 - return b.Write() 379 - } 380 - 381 - func (b *BadgerBatch) Close() error { 382 - if !b.closed { 383 - b.wb.Cancel() 384 - b.closed = true 385 - } 386 - return nil 387 - } 388 - 389 - func (b *BadgerBatch) GetByteSize() (int, error) { 390 - // Badger doesn't provide byte size tracking for batches 391 - // Return 0 as a placeholder 392 - return 0, nil 393 - } 394 - 395 - // NewBatch implements [db.DB]. 396 - func (b *BadgerAdapter) NewBatch() store.Batch { 397 - return &BadgerBatch{ 398 - wb: b.badgerDB.NewWriteBatch(), 399 - keyPrefix: b.keyPrefix, 400 - } 401 - } 402 - 403 - // NewBatchWithSize implements [db.DB]. 404 - func (b *BadgerAdapter) NewBatchWithSize(size int) store.Batch { 405 - // Badger doesn't support pre-allocated batch sizes, so we just create a regular batch 406 - return b.NewBatch() 407 - }
-454
badgeradapter/adapter_test.go
··· 1 - package badgeradapter 2 - 3 - import ( 4 - "testing" 5 - 6 - badger "github.com/dgraph-io/badger/v4" 7 - "github.com/stretchr/testify/require" 8 - ) 9 - 10 - func TestBadgerAdapter_KeyPrefixStripping(t *testing.T) { 11 - // Create a temporary badger database 12 - opts := badger.DefaultOptions("").WithInMemory(true) 13 - db, err := badger.Open(opts) 14 - require.NoError(t, err) 15 - defer db.Close() 16 - 17 - // Create adapter with a specific key prefix 18 - keyPrefix := []byte("test:") 19 - adapter := AdaptBadger(db, keyPrefix) 20 - 21 - // Write some test data 22 - batch := adapter.NewBatch() 23 - err = batch.Set([]byte("key1"), []byte("value1")) 24 - require.NoError(t, err) 25 - err = batch.Write() 26 - require.NoError(t, err) 27 - 28 - // Verify that the underlying badger database stores the key WITH the prefix 29 - var foundPrefixedKey bool 30 - err = db.View(func(txn *badger.Txn) error { 31 - opts := badger.IteratorOptions{ 32 - PrefetchValues: true, 33 - Reverse: false, 34 - AllVersions: false, 35 - } 36 - iter := txn.NewIterator(opts) 37 - defer iter.Close() 38 - 39 - for iter.Seek([]byte("test:")); iter.Valid(); iter.Next() { 40 - item := iter.Item() 41 - key := item.KeyCopy(nil) 42 - if string(key) == "test:key1" { 43 - foundPrefixedKey = true 44 - value, err := item.ValueCopy(nil) 45 - require.NoError(t, err) 46 - require.Equal(t, []byte("value1"), value) 47 - break 48 - } 49 - } 50 - return nil 51 - }) 52 - require.NoError(t, err) 53 - require.True(t, foundPrefixedKey, "Expected to find prefixed key 'test:key1' in underlying badger database") 54 - 55 - // Test Get operation - should work with unprefixed key 56 - value, err := adapter.Get([]byte("key1")) 57 - require.NoError(t, err) 58 - require.Equal(t, []byte("value1"), value) 59 - 60 - // Test iterator - should iterate over prefixed keys but return unprefixed keys 61 - iter, err := adapter.Iterator([]byte("key1"), []byte("key2")) 62 - require.NoError(t, err) 63 - defer iter.Close() 64 - 65 - require.True(t, iter.Valid()) 66 - returnedKey := iter.Key() 67 - returnedValue := iter.Value() 68 - 69 - // The returned key should NOT have the prefix 70 - require.Equal(t, []byte("key1"), returnedKey) 71 - require.Equal(t, []byte("value1"), returnedValue) 72 - 73 - iter.Next() 74 - require.False(t, iter.Valid()) 75 - } 76 - 77 - func TestBadgerAdapter_ReverseIteratorPrefixStripping(t *testing.T) { 78 - // Create a temporary badger database 79 - opts := badger.DefaultOptions("").WithInMemory(true) 80 - db, err := badger.Open(opts) 81 - require.NoError(t, err) 82 - defer db.Close() 83 - 84 - // Create adapter with a specific key prefix 85 - keyPrefix := []byte("prefix:") 86 - adapter := AdaptBadger(db, keyPrefix) 87 - 88 - // Write multiple test data entries 89 - batch := adapter.NewBatch() 90 - for i := 1; i <= 3; i++ { 91 - key := []byte("key" + string(rune('0'+i))) 92 - value := []byte("value" + string(rune('0'+i))) 93 - err = batch.Set(key, value) 94 - require.NoError(t, err) 95 - } 96 - err = batch.Write() 97 - require.NoError(t, err) 98 - 99 - // Verify that the underlying badger database stores the keys WITH the prefix 100 - var foundPrefixedKeys []string 101 - err = db.View(func(txn *badger.Txn) error { 102 - opts := badger.IteratorOptions{ 103 - PrefetchValues: true, 104 - Reverse: false, 105 - AllVersions: false, 106 - } 107 - iter := txn.NewIterator(opts) 108 - defer iter.Close() 109 - 110 - for iter.Seek([]byte("prefix:")); iter.Valid(); iter.Next() { 111 - item := iter.Item() 112 - key := item.KeyCopy(nil) 113 - keyStr := string(key) 114 - if len(keyStr) > len("prefix:") && keyStr[:len("prefix:")] == "prefix:" { 115 - foundPrefixedKeys = append(foundPrefixedKeys, keyStr) 116 - } 117 - } 118 - return nil 119 - }) 120 - require.NoError(t, err) 121 - require.Len(t, foundPrefixedKeys, 3, "Expected to find 3 prefixed keys in underlying badger database") 122 - require.Contains(t, foundPrefixedKeys, "prefix:key1") 123 - require.Contains(t, foundPrefixedKeys, "prefix:key2") 124 - require.Contains(t, foundPrefixedKeys, "prefix:key3") 125 - 126 - // Test reverse iterator - should iterate over prefixed keys but return unprefixed keys 127 - iter, err := adapter.ReverseIterator([]byte("key1"), []byte("key4")) 128 - require.NoError(t, err) 129 - defer iter.Close() 130 - 131 - // Should start with the last key in range 132 - require.True(t, iter.Valid()) 133 - returnedKey := iter.Key() 134 - returnedValue := iter.Value() 135 - 136 - // The returned key should NOT have the prefix 137 - require.Equal(t, []byte("key3"), returnedKey) 138 - require.Equal(t, []byte("value3"), returnedValue) 139 - 140 - // Move to previous key 141 - iter.Next() 142 - require.True(t, iter.Valid()) 143 - returnedKey = iter.Key() 144 - returnedValue = iter.Value() 145 - require.Equal(t, []byte("key2"), returnedKey) 146 - require.Equal(t, []byte("value2"), returnedValue) 147 - 148 - // Move to previous key again 149 - iter.Next() 150 - require.True(t, iter.Valid()) 151 - returnedKey = iter.Key() 152 - returnedValue = iter.Value() 153 - require.Equal(t, []byte("key1"), returnedKey) 154 - require.Equal(t, []byte("value1"), returnedValue) 155 - 156 - // Should be at the beginning of range 157 - iter.Next() 158 - require.False(t, iter.Valid()) 159 - } 160 - 161 - func TestBadgerAdapter_IteratorRespectsEnd(t *testing.T) { 162 - // Create a temporary badger database 163 - opts := badger.DefaultOptions("").WithInMemory(true) 164 - db, err := badger.Open(opts) 165 - require.NoError(t, err) 166 - defer db.Close() 167 - 168 - // Create adapter with a specific key prefix 169 - keyPrefix := []byte("test:") 170 - adapter := AdaptBadger(db, keyPrefix) 171 - 172 - // Write test data 173 - batch := adapter.NewBatch() 174 - data := map[string]string{ 175 - "apple": "fruit1", 176 - "banana": "fruit2", 177 - "cherry": "fruit3", 178 - "date": "fruit4", 179 - "elderberry": "fruit5", 180 - } 181 - for key, value := range data { 182 - err = batch.Set([]byte(key), []byte(value)) 183 - require.NoError(t, err) 184 - } 185 - err = batch.Write() 186 - require.NoError(t, err) 187 - 188 - // Test forward iteration with end boundary 189 - iter, err := adapter.Iterator([]byte("apple"), []byte("cherry")) 190 - require.NoError(t, err) 191 - defer iter.Close() 192 - 193 - // Should include "apple" and "banana" but stop before "cherry" 194 - require.True(t, iter.Valid()) 195 - require.Equal(t, []byte("apple"), iter.Key()) 196 - require.Equal(t, []byte("fruit1"), iter.Value()) 197 - 198 - iter.Next() 199 - require.True(t, iter.Valid()) 200 - require.Equal(t, []byte("banana"), iter.Key()) 201 - require.Equal(t, []byte("fruit2"), iter.Value()) 202 - 203 - // Next should stop before "cherry" since end is exclusive 204 - iter.Next() 205 - require.False(t, iter.Valid(), "Iterator should be invalid after reaching end boundary") 206 - 207 - // Test forward iteration with nil end (should iterate to the end) 208 - iter, err = adapter.Iterator([]byte("apple"), nil) 209 - require.NoError(t, err) 210 - defer iter.Close() 211 - 212 - count := 0 213 - for iter.Valid() { 214 - count++ 215 - iter.Next() 216 - } 217 - require.Equal(t, 5, count, "Should iterate over all 5 keys when end is nil") 218 - 219 - // Test forward iteration with start = nil (should start from first key) 220 - iter, err = adapter.Iterator(nil, []byte("cherry")) 221 - require.NoError(t, err) 222 - defer iter.Close() 223 - 224 - count = 0 225 - for iter.Valid() { 226 - count++ 227 - iter.Next() 228 - } 229 - require.Equal(t, 2, count, "Should iterate over 2 keys (apple, banana) before cherry") 230 - } 231 - 232 - func TestBadgerAdapter_ReverseIteratorRespectsStart(t *testing.T) { 233 - // Create a temporary badger database 234 - opts := badger.DefaultOptions("").WithInMemory(true) 235 - db, err := badger.Open(opts) 236 - require.NoError(t, err) 237 - defer db.Close() 238 - 239 - // Create adapter with a specific key prefix 240 - keyPrefix := []byte("test:") 241 - adapter := AdaptBadger(db, keyPrefix) 242 - 243 - // Write test data 244 - batch := adapter.NewBatch() 245 - data := map[string]string{ 246 - "apple": "fruit1", 247 - "banana": "fruit2", 248 - "cherry": "fruit3", 249 - "date": "fruit4", 250 - "elderberry": "fruit5", 251 - } 252 - for key, value := range data { 253 - err = batch.Set([]byte(key), []byte(value)) 254 - require.NoError(t, err) 255 - } 256 - err = batch.Write() 257 - require.NoError(t, err) 258 - 259 - // Test reverse iteration with start boundary 260 - iter, err := adapter.ReverseIterator([]byte("banana"), []byte("elderberry")) 261 - require.NoError(t, err) 262 - defer iter.Close() 263 - 264 - // Should start from "date" and go backwards to "banana" (inclusive) 265 - require.True(t, iter.Valid()) 266 - require.Equal(t, []byte("date"), iter.Key()) 267 - require.Equal(t, []byte("fruit4"), iter.Value()) 268 - 269 - iter.Next() 270 - require.True(t, iter.Valid()) 271 - require.Equal(t, []byte("cherry"), iter.Key()) 272 - require.Equal(t, []byte("fruit3"), iter.Value()) 273 - 274 - iter.Next() 275 - require.True(t, iter.Valid()) 276 - require.Equal(t, []byte("banana"), iter.Key()) 277 - require.Equal(t, []byte("fruit2"), iter.Value()) 278 - 279 - // Next should stop since we've reached the start boundary (inclusive) 280 - iter.Next() 281 - require.False(t, iter.Valid(), "Iterator should be invalid after reaching start boundary") 282 - 283 - // Test reverse iteration with nil start (should go to the beginning) 284 - iter, err = adapter.ReverseIterator(nil, []byte("cherry")) 285 - require.NoError(t, err) 286 - defer iter.Close() 287 - 288 - count := 0 289 - for iter.Valid() { 290 - count++ 291 - iter.Next() 292 - } 293 - require.Equal(t, 2, count, "Should iterate over 2 keys (banana, apple) before cherry") 294 - 295 - // Test reverse iteration with nil end (should start from the last key) 296 - iter, err = adapter.ReverseIterator([]byte("banana"), nil) 297 - require.NoError(t, err) 298 - defer iter.Close() 299 - 300 - count = 0 301 - for iter.Valid() { 302 - count++ 303 - iter.Next() 304 - } 305 - require.Equal(t, 4, count, "Should iterate over 4 keys (elderberry, date, cherry, banana) when end is nil") 306 - } 307 - 308 - func TestBadgerAdapter_IteratorRespectsKeyPrefix(t *testing.T) { 309 - // Create a temporary badger database 310 - opts := badger.DefaultOptions("").WithInMemory(true) 311 - db, err := badger.Open(opts) 312 - require.NoError(t, err) 313 - defer db.Close() 314 - 315 - // Create adapter with a specific key prefix 316 - keyPrefix := []byte("table1:") 317 - adapter := AdaptBadger(db, keyPrefix) 318 - 319 - // Write test data directly to badger with different prefixes to simulate multiple "tables" 320 - err = db.Update(func(txn *badger.Txn) error { 321 - // Write keys with the correct prefix (what the adapter should see) 322 - err := txn.Set([]byte("table1:apple"), []byte("fruit1")) 323 - require.NoError(t, err) 324 - err = txn.Set([]byte("table1:banana"), []byte("fruit2")) 325 - require.NoError(t, err) 326 - err = txn.Set([]byte("table1:cherry"), []byte("fruit3")) 327 - require.NoError(t, err) 328 - 329 - // Write keys with a different prefix (what the adapter should NOT see) 330 - err = txn.Set([]byte("table2:apple"), []byte("other1")) 331 - require.NoError(t, err) 332 - err = txn.Set([]byte("table2:date"), []byte("other2")) 333 - require.NoError(t, err) 334 - 335 - // Write keys with no prefix (what the adapter should NOT see) 336 - err = txn.Set([]byte("apple"), []byte("raw1")) 337 - require.NoError(t, err) 338 - err = txn.Set([]byte("zebra"), []byte("raw2")) 339 - require.NoError(t, err) 340 - 341 - return nil 342 - }) 343 - require.NoError(t, err) 344 - 345 - // Test forward iteration - should only see keys with "table1:" prefix 346 - iter, err := adapter.Iterator(nil, nil) 347 - require.NoError(t, err) 348 - defer iter.Close() 349 - 350 - var keys []string 351 - for iter.Valid() { 352 - keys = append(keys, string(iter.Key())) 353 - iter.Next() 354 - } 355 - 356 - // Should only see the 3 keys with the correct prefix, stripped of the prefix 357 - require.Equal(t, []string{"apple", "banana", "cherry"}, keys) 358 - 359 - // Test forward iteration with range - should only see keys with "table1:" prefix in range 360 - iter, err = adapter.Iterator([]byte("banana"), []byte("cherry")) 361 - require.NoError(t, err) 362 - defer iter.Close() 363 - 364 - keys = nil 365 - for iter.Valid() { 366 - keys = append(keys, string(iter.Key())) 367 - iter.Next() 368 - } 369 - 370 - // Should only see "banana" (inclusive) but not "cherry" (exclusive) 371 - require.Equal(t, []string{"banana"}, keys) 372 - 373 - // Test reverse iteration - should only see keys with "table1:" prefix 374 - iter, err = adapter.ReverseIterator(nil, nil) 375 - require.NoError(t, err) 376 - defer iter.Close() 377 - 378 - keys = nil 379 - for iter.Valid() { 380 - keys = append(keys, string(iter.Key())) 381 - iter.Next() 382 - } 383 - 384 - // Should see the 3 keys in reverse order, stripped of the prefix 385 - require.Equal(t, []string{"cherry", "banana", "apple"}, keys) 386 - 387 - // Test reverse iteration with range - should only see keys with "table1:" prefix in range 388 - iter, err = adapter.ReverseIterator([]byte("apple"), []byte("cherry")) 389 - require.NoError(t, err) 390 - defer iter.Close() 391 - 392 - keys = nil 393 - for iter.Valid() { 394 - keys = append(keys, string(iter.Key())) 395 - iter.Next() 396 - } 397 - 398 - // Should see keys from cherry (exclusive) down to apple (inclusive) 399 - require.Equal(t, []string{"banana", "apple"}, keys) 400 - 401 - // Test reverse iteration with wider range - should only see keys with "table1:" prefix in range 402 - iter, err = adapter.ReverseIterator([]byte("apple"), []byte("zzz")) 403 - require.NoError(t, err) 404 - defer iter.Close() 405 - 406 - keys = nil 407 - for iter.Valid() { 408 - keys = append(keys, string(iter.Key())) 409 - iter.Next() 410 - } 411 - 412 - // Should see keys from cherry (exclusive) down to apple (inclusive) 413 - require.Equal(t, []string{"cherry", "banana", "apple"}, keys) 414 - 415 - // An adapter without key prefix should be able to iterate over all keys 416 - adapter = AdaptBadger(db, []byte{}) 417 - 418 - iter, err = adapter.ReverseIterator(nil, nil) 419 - require.NoError(t, err) 420 - defer iter.Close() 421 - 422 - keys = nil 423 - for iter.Valid() { 424 - keys = append(keys, string(iter.Key())) 425 - iter.Next() 426 - } 427 - 428 - // Should see all keys in reverse order, regardless of prefix 429 - require.Len(t, keys, 7) 430 - 431 - iter, err = adapter.ReverseIterator([]byte("table2:date"), []byte("zebra")) 432 - require.NoError(t, err) 433 - defer iter.Close() 434 - 435 - keys = nil 436 - for iter.Valid() { 437 - keys = append(keys, string(iter.Key())) 438 - iter.Next() 439 - } 440 - 441 - require.Equal(t, []string{"table2:date"}, keys) 442 - 443 - iter, err = adapter.ReverseIterator([]byte("table2:date"), []byte("zzz")) 444 - require.NoError(t, err) 445 - defer iter.Close() 446 - 447 - keys = nil 448 - for iter.Valid() { 449 - keys = append(keys, string(iter.Key())) 450 - iter.Next() 451 - } 452 - 453 - require.Equal(t, []string{"zebra", "table2:date"}, keys) 454 - }
+109
dbadapter/adapter.go
··· 1 + package dbadapter 2 + 3 + import ( 4 + "cosmossdk.io/core/store" 5 + dbm "github.com/cometbft/cometbft-db" 6 + iavldbm "github.com/cosmos/iavl/db" 7 + ) 8 + 9 + type AdaptedDB struct { 10 + underlying dbm.DB 11 + } 12 + 13 + func Adapt(underlying dbm.DB) *AdaptedDB { 14 + return &AdaptedDB{ 15 + underlying: underlying, 16 + } 17 + } 18 + 19 + var _ iavldbm.DB = (*AdaptedDB)(nil) 20 + 21 + // Close implements [iavldbm.DB]. 22 + func (b *AdaptedDB) Close() error { 23 + return b.underlying.Close() 24 + } 25 + 26 + // Get implements [iavldbm.DB]. 27 + func (b *AdaptedDB) Get(key []byte) ([]byte, error) { 28 + return b.underlying.Get(key) 29 + } 30 + 31 + // Has implements [iavldbm.DB]. 32 + func (b *AdaptedDB) Has(key []byte) (bool, error) { 33 + return b.underlying.Has(key) 34 + } 35 + 36 + // AdaptedIterator adapts badger.Iterator to store.Iterator 37 + type AdaptedIterator struct { 38 + underlying dbm.Iterator 39 + calledNextOnce bool 40 + } 41 + 42 + func (i *AdaptedIterator) Domain() (start, end []byte) { 43 + return i.underlying.Domain() 44 + } 45 + 46 + func (i *AdaptedIterator) Valid() bool { 47 + return i.underlying.Valid() 48 + } 49 + 50 + func (i *AdaptedIterator) Next() { 51 + if !i.calledNextOnce { 52 + i.calledNextOnce = true 53 + return 54 + } 55 + i.underlying.Next() 56 + } 57 + 58 + func (i *AdaptedIterator) Key() []byte { 59 + return i.underlying.Key() 60 + } 61 + 62 + func (i *AdaptedIterator) Value() []byte { 63 + return i.underlying.Value() 64 + } 65 + 66 + func (i *AdaptedIterator) Error() error { 67 + return i.underlying.Error() 68 + } 69 + 70 + func (i *AdaptedIterator) Close() error { 71 + return i.underlying.Close() 72 + } 73 + 74 + // Iterator implements [iavldbm.DB]. 75 + func (b *AdaptedDB) Iterator(start []byte, end []byte) (store.Iterator, error) { 76 + i, err := b.underlying.Iterator(start, end) 77 + if err != nil { 78 + return nil, err 79 + } 80 + return &AdaptedIterator{underlying: i}, nil 81 + } 82 + 83 + // ReverseIterator implements [iavldbm.DB]. 84 + func (b *AdaptedDB) ReverseIterator(start []byte, end []byte) (store.Iterator, error) { 85 + i, err := b.underlying.ReverseIterator(start, end) 86 + if err != nil { 87 + return nil, err 88 + } 89 + return &AdaptedIterator{underlying: i}, nil 90 + } 91 + 92 + // NewBatch implements [db.DB]. 93 + func (b *AdaptedDB) NewBatch() store.Batch { 94 + return &AdaptedBatch{b.underlying.NewBatch()} 95 + } 96 + 97 + // NewBatchWithSize implements [db.DB]. 98 + func (b *AdaptedDB) NewBatchWithSize(int) store.Batch { 99 + return &AdaptedBatch{b.underlying.NewBatch()} 100 + } 101 + 102 + type AdaptedBatch struct { 103 + dbm.Batch 104 + } 105 + 106 + // GetByteSize implements [store.Batch]. 107 + func (a *AdaptedBatch) GetByteSize() (int, error) { 108 + return 0, nil 109 + }
+2 -2
go.mod
··· 6 6 cosmossdk.io/core v0.12.1-0.20240725072823-6a2d039e1212 7 7 github.com/bluesky-social/indigo v0.0.0-20251009212240-20524de167fe 8 8 github.com/cometbft/cometbft v0.38.19 9 + github.com/cometbft/cometbft-db v0.14.1 9 10 github.com/cosmos/iavl v1.3.5 10 11 github.com/cosmos/ics23/go v0.10.0 11 - github.com/dgraph-io/badger/v4 v4.9.0 12 12 github.com/did-method-plc/go-didplc v0.0.0-20251125183445-342320c327e2 13 13 github.com/google/uuid v1.6.0 14 14 github.com/ipfs/go-cid v0.4.1 ··· 35 35 github.com/cockroachdb/pebble v1.1.5 // indirect 36 36 github.com/cockroachdb/redact v1.1.5 // indirect 37 37 github.com/cockroachdb/tokenbucket v0.0.0-20230807174530-cc333fc44b06 // indirect 38 - github.com/cometbft/cometbft-db v0.14.1 // indirect 39 38 github.com/cosmos/gogoproto v1.7.0 // indirect 40 39 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect 41 40 github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect 41 + github.com/dgraph-io/badger/v4 v4.9.0 // indirect 42 42 github.com/dgraph-io/ristretto/v2 v2.2.0 // indirect 43 43 github.com/dustin/go-humanize v1.0.1 // indirect 44 44 github.com/emicklei/dot v1.6.2 // indirect
+11 -8
httpapi/server.go
··· 27 27 28 28 "tangled.org/gbl08ma.com/didplcbft/abciapp" 29 29 "tangled.org/gbl08ma.com/didplcbft/plc" 30 + "tangled.org/gbl08ma.com/didplcbft/transaction" 30 31 ) 31 32 32 33 // Server represents the HTTP server for the PLC directory. 33 34 type Server struct { 35 + txFactory *transaction.Factory 34 36 plc plc.ReadPLC 35 37 router *http.ServeMux 36 38 node *node.Node ··· 44 46 } 45 47 46 48 // NewServer creates a new instance of the Server. 47 - func NewServer(plc plc.ReadPLC, node *node.Node, listenAddr string, handlerTimeout time.Duration) (*Server, error) { 49 + func NewServer(txFactory *transaction.Factory, plc plc.ReadPLC, node *node.Node, listenAddr string, handlerTimeout time.Duration) (*Server, error) { 48 50 s := &Server{ 51 + txFactory: txFactory, 49 52 plc: plc, 50 53 router: http.NewServeMux(), 51 54 node: node, ··· 142 145 // handleResolveDID handles the GET /{did} endpoint. 143 146 func (s *Server) handleResolveDID(w http.ResponseWriter, r *http.Request, did string) { 144 147 ctx := context.Background() 145 - doc, err := s.plc.Resolve(ctx, plc.CommittedTreeVersion, did) 148 + doc, err := s.plc.Resolve(ctx, s.txFactory.ReadCommitted(), did) 146 149 if handlePLCError(w, err, did) { 147 150 return 148 151 } ··· 204 207 return 205 208 } 206 209 207 - if err := s.plc.ValidateOperation(r.Context(), plc.CommittedTreeVersion, time.Now(), did, opBytes); err != nil { 210 + if err := s.plc.ValidateOperation(r.Context(), s.txFactory.ReadCommitted(), did, opBytes); err != nil { 208 211 sendErrorResponse(w, http.StatusBadRequest, "Invalid operation") 209 212 return 210 213 } ··· 253 256 254 257 // handleGetPLCLog handles the GET /{did}/log endpoint. 255 258 func (s *Server) handleGetPLCLog(w http.ResponseWriter, r *http.Request, did string) { 256 - ops, err := s.plc.OperationLog(r.Context(), plc.CommittedTreeVersion, did) 259 + ops, err := s.plc.OperationLog(r.Context(), s.txFactory.ReadCommitted(), did) 257 260 if handlePLCError(w, err, did) { 258 261 return 259 262 } ··· 264 267 265 268 // handleGetPLCAuditLog handles the GET /{did}/log/audit endpoint. 266 269 func (s *Server) handleGetPLCAuditLog(w http.ResponseWriter, r *http.Request, did string) { 267 - entries, err := s.plc.AuditLog(r.Context(), plc.CommittedTreeVersion, did) 270 + entries, err := s.plc.AuditLog(r.Context(), s.txFactory.ReadCommitted(), did) 268 271 if handlePLCError(w, err, did) { 269 272 return 270 273 } ··· 275 278 276 279 // handleGetLastOp handles the GET /{did}/log/last endpoint. 277 280 func (s *Server) handleGetLastOp(w http.ResponseWriter, r *http.Request, did string) { 278 - op, err := s.plc.LastOperation(r.Context(), plc.CommittedTreeVersion, did) 281 + op, err := s.plc.LastOperation(r.Context(), s.txFactory.ReadCommitted(), did) 279 282 if handlePLCError(w, err, did) { 280 283 return 281 284 } ··· 286 289 287 290 // handleGetPLCData handles the GET /{did}/data endpoint. 288 291 func (s *Server) handleGetPLCData(w http.ResponseWriter, r *http.Request, did string) { 289 - data, err := s.plc.Data(r.Context(), plc.CommittedTreeVersion, did) 292 + data, err := s.plc.Data(r.Context(), s.txFactory.ReadCommitted(), did) 290 293 if handlePLCError(w, err, did) { 291 294 return 292 295 } ··· 336 339 } 337 340 } 338 341 339 - entries, err := s.plc.Export(r.Context(), plc.CommittedTreeVersion, after, count) 342 + entries, err := s.plc.Export(r.Context(), s.txFactory.ReadCommitted(), after, count) 340 343 if handlePLCError(w, err, "") { 341 344 return 342 345 }
+30 -21
httpapi/server_test.go
··· 10 10 "testing" 11 11 "time" 12 12 13 + "github.com/cosmos/iavl" 14 + dbm "github.com/cosmos/iavl/db" 13 15 "github.com/did-method-plc/go-didplc" 14 16 "github.com/stretchr/testify/require" 15 17 "tangled.org/gbl08ma.com/didplcbft/plc" 18 + "tangled.org/gbl08ma.com/didplcbft/store" 19 + "tangled.org/gbl08ma.com/didplcbft/transaction" 16 20 "tangled.org/gbl08ma.com/didplcbft/types" 17 21 ) 18 22 ··· 22 26 errorType string 23 27 } 24 28 25 - func (m *MockReadPLC) ValidateOperation(ctx context.Context, atHeight plc.TreeVersion, at time.Time, did string, opBytes []byte) error { 29 + func (m *MockReadPLC) ValidateOperation(ctx context.Context, readTx transaction.Read, did string, opBytes []byte) error { 26 30 if m.shouldReturnError { 27 31 switch m.errorType { 28 32 case "notfound": ··· 35 39 return nil 36 40 } 37 41 38 - func (m *MockReadPLC) Resolve(ctx context.Context, atHeight plc.TreeVersion, did string) (didplc.Doc, error) { 42 + func (m *MockReadPLC) Resolve(ctx context.Context, readTx transaction.Read, did string) (didplc.Doc, error) { 39 43 if m.shouldReturnError { 40 44 switch m.errorType { 41 45 case "notfound": ··· 50 54 }, nil 51 55 } 52 56 53 - func (m *MockReadPLC) OperationLog(ctx context.Context, atHeight plc.TreeVersion, did string) ([]didplc.OpEnum, error) { 57 + func (m *MockReadPLC) OperationLog(ctx context.Context, readTx transaction.Read, did string) ([]didplc.OpEnum, error) { 54 58 if m.shouldReturnError { 55 59 if m.errorType == "notfound" { 56 60 return []didplc.OpEnum{}, plc.ErrDIDNotFound ··· 60 64 return []didplc.OpEnum{}, nil 61 65 } 62 66 63 - func (m *MockReadPLC) AuditLog(ctx context.Context, atHeight plc.TreeVersion, did string) ([]didplc.LogEntry, error) { 67 + func (m *MockReadPLC) AuditLog(ctx context.Context, readTx transaction.Read, did string) ([]didplc.LogEntry, error) { 64 68 if m.shouldReturnError { 65 69 if m.errorType == "notfound" { 66 70 return []didplc.LogEntry{}, plc.ErrDIDNotFound ··· 70 74 return []didplc.LogEntry{}, nil 71 75 } 72 76 73 - func (m *MockReadPLC) LastOperation(ctx context.Context, atHeight plc.TreeVersion, did string) (didplc.OpEnum, error) { 77 + func (m *MockReadPLC) LastOperation(ctx context.Context, readTx transaction.Read, did string) (didplc.OpEnum, error) { 74 78 if m.shouldReturnError { 75 79 if m.errorType == "notfound" { 76 80 return didplc.OpEnum{}, plc.ErrDIDNotFound ··· 80 84 return didplc.OpEnum{}, nil 81 85 } 82 86 83 - func (m *MockReadPLC) Data(ctx context.Context, atHeight plc.TreeVersion, did string) (didplc.RegularOp, error) { 87 + func (m *MockReadPLC) Data(ctx context.Context, readTx transaction.Read, did string) (didplc.RegularOp, error) { 84 88 if m.shouldReturnError { 85 89 switch m.errorType { 86 90 case "notfound": ··· 93 97 return didplc.RegularOp{}, nil 94 98 } 95 99 96 - func (m *MockReadPLC) Export(ctx context.Context, atHeight plc.TreeVersion, after uint64, count int) ([]types.SequencedLogEntry, error) { 100 + func (m *MockReadPLC) Export(ctx context.Context, readTx transaction.Read, after uint64, count int) ([]types.SequencedLogEntry, error) { 97 101 if m.shouldReturnError { 98 102 return []types.SequencedLogEntry{}, fmt.Errorf("internal error") 99 103 } ··· 103 107 func TestServer(t *testing.T) { 104 108 mockPLC := &MockReadPLC{} 105 109 110 + // this tree is just to avoid a nil pointer when creating a transaction with the factory 111 + // the transactions don't actually get used 112 + tree := iavl.NewMutableTree(dbm.NewMemDB(), 128, false, iavl.NewNopLogger()) 113 + txFactory := transaction.NewFactory(tree, nil, store.Tree.NextOperationSequence) 114 + 106 115 t.Run("Test Resolve DID", func(t *testing.T) { 107 - server, err := NewServer(mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 116 + server, err := NewServer(txFactory, mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 108 117 require.NoError(t, err) 109 118 110 119 req, err := http.NewRequest("GET", "/did:plc:test", nil) ··· 119 128 120 129 t.Run("Test Resolve DID Not Found", func(t *testing.T) { 121 130 mockPLC := &MockReadPLC{shouldReturnError: true, errorType: "notfound"} 122 - server, err := NewServer(mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 131 + server, err := NewServer(txFactory, mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 123 132 require.NoError(t, err) 124 133 125 134 req, err := http.NewRequest("GET", "/did:plc:test", nil) ··· 134 143 135 144 t.Run("Test Resolve DID Gone", func(t *testing.T) { 136 145 mockPLC := &MockReadPLC{shouldReturnError: true, errorType: "gone"} 137 - server, err := NewServer(mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 146 + server, err := NewServer(txFactory, mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 138 147 require.NoError(t, err) 139 148 140 149 req, err := http.NewRequest("GET", "/did:plc:test", nil) ··· 149 158 150 159 t.Run("Test Resolve DID Internal Error", func(t *testing.T) { 151 160 mockPLC := &MockReadPLC{shouldReturnError: true, errorType: "internal"} 152 - server, err := NewServer(mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 161 + server, err := NewServer(txFactory, mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 153 162 require.NoError(t, err) 154 163 155 164 req, err := http.NewRequest("GET", "/did:plc:test", nil) ··· 163 172 }) 164 173 165 174 t.Run("Test Create PLC Operation", func(t *testing.T) { 166 - server, err := NewServer(mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 175 + server, err := NewServer(txFactory, mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 167 176 require.NoError(t, err) 168 177 169 178 op := map[string]interface{}{ ··· 187 196 }) 188 197 189 198 t.Run("Test Get PLC Log", func(t *testing.T) { 190 - server, err := NewServer(mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 199 + server, err := NewServer(txFactory, mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 191 200 require.NoError(t, err) 192 201 193 202 req, err := http.NewRequest("GET", "/did:plc:test/log", nil) ··· 201 210 202 211 t.Run("Test Get PLC Log Not Found", func(t *testing.T) { 203 212 mockPLC := &MockReadPLC{shouldReturnError: true, errorType: "notfound"} 204 - server, err := NewServer(mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 213 + server, err := NewServer(txFactory, mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 205 214 require.NoError(t, err) 206 215 207 216 req, err := http.NewRequest("GET", "/did:plc:test/log", nil) ··· 215 224 }) 216 225 217 226 t.Run("Test Get PLC Audit Log", func(t *testing.T) { 218 - server, err := NewServer(mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 227 + server, err := NewServer(txFactory, mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 219 228 require.NoError(t, err) 220 229 221 230 req, err := http.NewRequest("GET", "/did:plc:test/log/audit", nil) ··· 228 237 }) 229 238 230 239 t.Run("Test Get Last Operation", func(t *testing.T) { 231 - server, err := NewServer(mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 240 + server, err := NewServer(txFactory, mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 232 241 require.NoError(t, err) 233 242 234 243 req, err := http.NewRequest("GET", "/did:plc:test/log/last", nil) ··· 242 251 243 252 t.Run("Test Get Last Operation Internal Error", func(t *testing.T) { 244 253 mockPLC := &MockReadPLC{shouldReturnError: true, errorType: "internal"} 245 - server, err := NewServer(mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 254 + server, err := NewServer(txFactory, mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 246 255 require.NoError(t, err) 247 256 248 257 req, err := http.NewRequest("GET", "/did:plc:test/log/last", nil) ··· 256 265 }) 257 266 258 267 t.Run("Test Get PLC Data", func(t *testing.T) { 259 - server, err := NewServer(mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 268 + server, err := NewServer(txFactory, mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 260 269 require.NoError(t, err) 261 270 262 271 req, err := http.NewRequest("GET", "/did:plc:test/data", nil) ··· 270 279 271 280 t.Run("Test Get PLC Data Not Found", func(t *testing.T) { 272 281 mockPLC := &MockReadPLC{shouldReturnError: true, errorType: "notfound"} 273 - server, err := NewServer(mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 282 + server, err := NewServer(txFactory, mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 274 283 require.NoError(t, err) 275 284 276 285 req, err := http.NewRequest("GET", "/did:plc:test/data", nil) ··· 284 293 }) 285 294 286 295 t.Run("Test Export", func(t *testing.T) { 287 - server, err := NewServer(mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 296 + server, err := NewServer(txFactory, mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 288 297 require.NoError(t, err) 289 298 290 299 req, err := http.NewRequest("GET", "/export?count=10", nil) ··· 298 307 299 308 t.Run("Test Export Internal Error", func(t *testing.T) { 300 309 mockPLC := &MockReadPLC{shouldReturnError: true, errorType: "internal"} 301 - server, err := NewServer(mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 310 + server, err := NewServer(txFactory, mockPLC, nil, "tcp://127.0.0.1:8080", 15*time.Second) 302 311 require.NoError(t, err) 303 312 304 313 req, err := http.NewRequest("GET", "/export?count=10", nil)
+46 -34
main.go
··· 6 6 "log" 7 7 "os" 8 8 "os/signal" 9 + "path" 9 10 "path/filepath" 10 11 "sync" 11 12 "syscall" 12 13 "time" 13 14 15 + db "github.com/cometbft/cometbft-db" 14 16 "github.com/cometbft/cometbft/p2p" 15 17 "github.com/cometbft/cometbft/privval" 16 18 "github.com/cometbft/cometbft/proxy" ··· 22 24 cmtflags "github.com/cometbft/cometbft/libs/cli/flags" 23 25 cmtlog "github.com/cometbft/cometbft/libs/log" 24 26 nm "github.com/cometbft/cometbft/node" 25 - "github.com/dgraph-io/badger/v4" 26 - "github.com/dgraph-io/badger/v4/options" 27 27 "github.com/spf13/viper" 28 28 ) 29 29 ··· 52 52 if err := config.ValidateBasic(); err != nil { 53 53 log.Fatalf("Invalid configuration data: %v", err) 54 54 } 55 - badgerDBPath := filepath.Join(homeDir, "badger") 56 - badgerDB, err := badger.Open(badger. 57 - DefaultOptions(badgerDBPath). 58 - WithBlockSize(8 * 1024). 59 - WithNumMemtables(3). 60 - WithNumLevelZeroTables(3). 61 - WithCompression(options.ZSTD)) 55 + 56 + var wg sync.WaitGroup 57 + closeGoroutinesCh := make(chan struct{}) 58 + 59 + treeDBContext := &bftconfig.DBContext{ID: "apptree", Config: config.Config} 60 + treeDB, err := bftconfig.DefaultDBProvider(treeDBContext) 62 61 if err != nil { 63 - log.Fatalf("Opening badger database: %v", err) 62 + log.Fatalf("failed to create application database: %v", err) 64 63 } 65 64 66 - for err == nil { 67 - err = badgerDB.RunValueLogGC(0.5) 65 + indexDBContext := &bftconfig.DBContext{ID: "appindex", Config: config.Config} 66 + indexDB, err := bftconfig.DefaultDBProvider(indexDBContext) 67 + if err != nil { 68 + log.Fatalf("failed to create application database: %v", err) 68 69 } 69 70 70 - var wg sync.WaitGroup 71 - closeGoroutinesCh := make(chan struct{}) 72 - wg.Go(func() { 73 - ticker := time.NewTicker(5 * time.Minute) 74 - defer ticker.Stop() 75 - for { 76 - select { 77 - case <-ticker.C: 78 - var err error 79 - for err == nil { 80 - err = badgerDB.RunValueLogGC(0.5) 81 - } 82 - case <-closeGoroutinesCh: 83 - return 84 - } 71 + defer func() { 72 + if err := treeDB.Close(); err != nil { 73 + log.Printf("Closing application tree database: %v", err) 85 74 } 86 - }) 75 + if err := indexDB.Close(); err != nil { 76 + log.Printf("Closing application index database: %v", err) 77 + } 78 + }() 87 79 88 - defer func() { 89 - if err := badgerDB.Close(); err != nil { 90 - log.Printf("Closing badger database: %v", err) 80 + recreateDatabases := func() (db.DB, db.DB) { 81 + if err := treeDB.Close(); err != nil { 82 + log.Printf("Closing application tree database for clearing: %v", err) 83 + } 84 + if err := indexDB.Close(); err != nil { 85 + log.Printf("Closing application index database for clearing: %v", err) 91 86 } 92 - }() 93 87 94 - app, plc, cleanup, err := abciapp.NewDIDPLCApplication(badgerDB, filepath.Join(homeDir, "snapshots")) 88 + // we're depending on an implementation detail of cometbft, but I'm yet to find a more elegant way to do this 89 + _ = os.RemoveAll(path.Join(homeDir, "data/appindex.db")) 90 + _ = os.RemoveAll(path.Join(homeDir, "data/apptree.db")) 91 + 92 + var err error 93 + treeDB, err = bftconfig.DefaultDBProvider(treeDBContext) 94 + if err != nil { 95 + log.Fatalf("failed to create application database: %v", err) 96 + } 97 + 98 + indexDB, err = bftconfig.DefaultDBProvider(indexDBContext) 99 + if err != nil { 100 + log.Fatalf("failed to create application database: %v", err) 101 + } 102 + 103 + return treeDB, indexDB 104 + } 105 + 106 + app, txFactory, plc, cleanup, err := abciapp.NewDIDPLCApplication(treeDB, indexDB, recreateDatabases, filepath.Join(homeDir, "snapshots")) 95 107 if err != nil { 96 108 log.Fatalf("failed to create DIDPLC application: %v", err) 97 109 } ··· 139 151 }() 140 152 141 153 if config.PLC.ListenAddress != "" { 142 - plcAPIServer, err := httpapi.NewServer(plc, node, config.PLC.ListenAddress, 30*time.Second) 154 + plcAPIServer, err := httpapi.NewServer(txFactory, plc, node, config.PLC.ListenAddress, 30*time.Second) 143 155 if err != nil { 144 156 log.Fatalf("Creating PLC API server: %v", err) 145 157 }
+46 -134
plc/impl.go
··· 2 2 3 3 import ( 4 4 "context" 5 - "iter" 6 5 "sync" 7 - "time" 8 6 9 7 "github.com/bluesky-social/indigo/atproto/syntax" 10 - "github.com/cosmos/iavl" 11 8 "github.com/did-method-plc/go-didplc" 12 9 "github.com/palantir/stacktrace" 13 10 "github.com/samber/lo" 14 11 "github.com/samber/mo" 15 12 "tangled.org/gbl08ma.com/didplcbft/store" 13 + "tangled.org/gbl08ma.com/didplcbft/transaction" 16 14 "tangled.org/gbl08ma.com/didplcbft/types" 17 15 ) 18 16 19 - type TreeProvider interface { 20 - MutableTree() (*iavl.MutableTree, error) 21 - ImmutableTree(version TreeVersion) (store.ReadOnlyTree, error) 22 - } 17 + type plcImpl struct { 18 + mu sync.Mutex // probably redundant, but let's keep for now 19 + validator OperationValidator 23 20 24 - type plcImpl struct { 25 - mu sync.Mutex // probably redundant, but let's keep for now 26 - treeProvider TreeProvider 27 - validator OperationValidator 21 + nextSeq uint64 28 22 } 29 23 30 24 var _ PLC = (*plcImpl)(nil) 31 25 32 - func NewPLC(treeProvider TreeProvider) *plcImpl { 33 - p := &plcImpl{ 34 - treeProvider: treeProvider, 35 - } 26 + func NewPLC() *plcImpl { 27 + p := &plcImpl{} 36 28 37 - p.validator = NewV0OperationValidator(&inMemoryAuditLogFetcher{ 38 - plc: p, 39 - }) 29 + p.validator = NewV0OperationValidator() 40 30 return p 41 31 } 42 32 43 - func (plc *plcImpl) ValidateOperation(ctx context.Context, atHeight TreeVersion, at time.Time, did string, opBytes []byte) error { 33 + func (plc *plcImpl) ValidateOperation(ctx context.Context, readTx transaction.Read, did string, opBytes []byte) error { 44 34 plc.mu.Lock() 45 35 defer plc.mu.Unlock() 46 36 47 - timestamp := syntax.Datetime(at.Format(types.ActualAtprotoDatetimeLayout)) 37 + timestamp := syntax.Datetime(readTx.Timestamp().Format(types.ActualAtprotoDatetimeLayout)) 48 38 49 39 // TODO set true to false only while importing old ops 50 - _, err := plc.validator.Validate(ctx, atHeight, timestamp, did, opBytes, true) 40 + _, err := plc.validator.Validate(ctx, readTx, timestamp, did, opBytes, true) 51 41 if err != nil { 52 42 return stacktrace.Propagate(err, "operation failed validation") 53 43 } ··· 55 45 return nil 56 46 } 57 47 58 - func (plc *plcImpl) ExecuteOperation(ctx context.Context, t time.Time, did string, opBytes []byte) error { 48 + func (plc *plcImpl) ExecuteOperation(ctx context.Context, tx transaction.Write, did string, opBytes []byte) error { 59 49 plc.mu.Lock() 60 50 defer plc.mu.Unlock() 61 51 62 - timestamp := syntax.Datetime(t.Format(types.ActualAtprotoDatetimeLayout)) 52 + timestamp := syntax.Datetime(tx.Timestamp().Format(types.ActualAtprotoDatetimeLayout)) 63 53 64 54 // TODO set true to false only while importing old ops 65 - effects, err := plc.validator.Validate(ctx, WorkingTreeVersion, timestamp, did, opBytes, true) 55 + effects, err := plc.validator.Validate(ctx, tx.Downgrade(), timestamp, did, opBytes, true) 66 56 if err != nil { 67 57 return stacktrace.Propagate(err, "operation failed validation") 68 58 } 69 59 70 - tree, err := plc.treeProvider.MutableTree() 71 - if err != nil { 72 - return stacktrace.Propagate(err, "failed to obtain mutable tree") 73 - } 74 - 75 - err = store.Tree.StoreOperation(tree, effects.NewLogEntry, effects.NullifiedEntriesStartingIndex) 60 + err = store.Tree.StoreOperation(ctx, tx, effects.NewLogEntry, effects.NullifiedEntriesStartingSeq) 76 61 if err != nil { 77 62 return stacktrace.Propagate(err, "failed to commit operation") 78 63 } 79 64 80 65 return nil 81 66 } 82 - 83 - func (plc *plcImpl) ImportOperationsFromAuthoritativeSource(ctx context.Context, newEntries []didplc.LogEntry) error { 67 + func (plc *plcImpl) ImportOperationFromAuthoritativeSource(ctx context.Context, tx transaction.Write, newEntry didplc.LogEntry) error { 84 68 plc.mu.Lock() 85 69 defer plc.mu.Unlock() 86 70 87 - tree, err := plc.treeProvider.MutableTree() 88 - if err != nil { 89 - return stacktrace.Propagate(err, "failed to obtain mutable tree") 90 - } 91 - 92 - for _, entry := range newEntries { 93 - err := plc.importOp(ctx, tree, entry) 94 - if err != nil { 95 - return stacktrace.Propagate(err, "") 96 - } 97 - } 98 - 99 - return nil 100 - } 101 - 102 - func (plc *plcImpl) ImportOperationFromAuthoritativeSource(ctx context.Context, newEntry didplc.LogEntry) error { 103 - plc.mu.Lock() 104 - defer plc.mu.Unlock() 105 - 106 - tree, err := plc.treeProvider.MutableTree() 107 - if err != nil { 108 - return stacktrace.Propagate(err, "failed to obtain mutable tree") 109 - } 110 - 111 - return stacktrace.Propagate(plc.importOp(ctx, tree, newEntry), "") 112 - } 113 - 114 - func (plc *plcImpl) importOp(ctx context.Context, tree *iavl.MutableTree, newEntry didplc.LogEntry) error { 115 71 newCID := newEntry.CID 116 72 newPrev := newEntry.Operation.AsOperation().PrevCIDStr() 117 73 118 - mostRecentOpIndex := -1 119 - indexOfPrev := -1 74 + hasExistingOps := false 75 + var seqOfPrev mo.Option[uint64] 76 + nullifiedEntriesStartingSeq := mo.None[uint64]() 77 + 120 78 var iteratorErr error 121 - for entryIdx, entry := range store.Tree.AuditLogReverseIterator(ctx, tree, newEntry.DID, &iteratorErr) { 79 + for entry := range store.Tree.AuditLogReverseIterator(ctx, tx.Downgrade(), newEntry.DID, &iteratorErr) { 122 80 entryCID := entry.CID.String() 123 - if mostRecentOpIndex == -1 { 124 - mostRecentOpIndex = entryIdx 81 + if !hasExistingOps { 82 + hasExistingOps = true 125 83 126 84 if newPrev == "" && entryCID != newCID { 127 85 // this should never happen unless the authoritative source doesn't compute DIDs from genesis ops the way we do ··· 139 97 } 140 98 141 99 return stacktrace.Propagate( 142 - store.Tree.SetOperationCreatedAt(tree, entry.Seq, newCreatedAtDT.Time()), 100 + store.Tree.SetOperationCreatedAt(tx, entry.Seq, newCreatedAtDT.Time()), 143 101 "") 144 102 } 145 103 146 104 if entryCID == newPrev { 147 - indexOfPrev = entryIdx 105 + seqOfPrev = mo.Some(entry.Seq) 148 106 break 107 + } else { 108 + // we only get here if there's an operation between the new latest and prev 109 + // this will keep decreasing until we find prev, at which point it'll have the lowest seq for this DID before that of prev 110 + nullifiedEntriesStartingSeq = mo.Some(entry.Seq) 149 111 } 150 112 } 151 113 if iteratorErr != nil { 152 114 return stacktrace.Propagate(iteratorErr, "") 153 115 } 154 116 155 - nullifiedEntriesStartingIndex := mo.None[int]() 156 - 157 - if mostRecentOpIndex < 0 { 117 + if !hasExistingOps { 158 118 // we have nothing for this DID - this should be a creation op, if not, then we're not importing things in order 159 119 if newPrev != "" { 160 120 return stacktrace.NewError("invalid internal state reached") ··· 163 123 // there's nothing to do but store the operation, no nullification involved 164 124 newEntry.Nullified = false 165 125 166 - err := store.Tree.StoreOperation(tree, newEntry, nullifiedEntriesStartingIndex) 126 + err := store.Tree.StoreOperation(ctx, tx, newEntry, nullifiedEntriesStartingSeq) 167 127 return stacktrace.Propagate(err, "failed to commit operation") 168 128 } 169 129 170 - if indexOfPrev < 0 { 130 + if !seqOfPrev.IsPresent() { 171 131 // there are entries in the audit log but none of them has a CID matching prev 172 132 // if this isn't a creation op, then this shouldn't happen 173 133 // (even when history forks between us and the authoritative source, at least the initial op should be the same, otherwise the DIDs wouldn't match) 174 134 // if this is a creation op, then this case should have been caught above 175 - return stacktrace.NewError("invalid internal state reached") 176 - } 177 - 178 - if indexOfPrev+1 <= mostRecentOpIndex { 179 - nullifiedEntriesStartingIndex = mo.Some(indexOfPrev + 1) 135 + return stacktrace.NewError("invalid internal state reached, %+v", newEntry) 180 136 } 181 137 182 138 newEntry.Nullified = false 183 - err := store.Tree.StoreOperation(tree, newEntry, nullifiedEntriesStartingIndex) 139 + err := store.Tree.StoreOperation(ctx, tx, newEntry, nullifiedEntriesStartingSeq) 184 140 return stacktrace.Propagate(err, "failed to commit operation") 185 141 } 186 142 187 - func (plc *plcImpl) Resolve(ctx context.Context, atHeight TreeVersion, did string) (didplc.Doc, error) { 143 + func (plc *plcImpl) Resolve(ctx context.Context, tx transaction.Read, did string) (didplc.Doc, error) { 188 144 plc.mu.Lock() 189 145 defer plc.mu.Unlock() 190 146 191 - tree, err := plc.treeProvider.ImmutableTree(atHeight) 192 - if err != nil { 193 - return didplc.Doc{}, stacktrace.Propagate(err, "failed to obtain immutable tree") 194 - } 195 - 196 - l, _, err := store.Tree.AuditLog(ctx, tree, did, false) 147 + l, _, err := store.Tree.AuditLog(ctx, tx, did, false) 197 148 if err != nil { 198 149 return didplc.Doc{}, stacktrace.Propagate(err, "") 199 150 } ··· 209 160 return opEnum.AsOperation().Doc(did) 210 161 } 211 162 212 - func (plc *plcImpl) OperationLog(ctx context.Context, atHeight TreeVersion, did string) ([]didplc.OpEnum, error) { 163 + func (plc *plcImpl) OperationLog(ctx context.Context, tx transaction.Read, did string) ([]didplc.OpEnum, error) { 213 164 // GetPlcOpLog - /:did/log - same data as audit log but excludes nullified. just the inner operations 214 165 // if missing -> returns ErrDIDNotFound 215 166 // if tombstone -> returns log as normal ··· 217 168 plc.mu.Lock() 218 169 defer plc.mu.Unlock() 219 170 220 - tree, err := plc.treeProvider.ImmutableTree(atHeight) 221 - if err != nil { 222 - return nil, stacktrace.Propagate(err, "failed to obtain immutable tree") 223 - } 224 - 225 - l, _, err := store.Tree.AuditLog(ctx, tree, did, false) 171 + l, _, err := store.Tree.AuditLog(ctx, tx, did, false) 226 172 if err != nil { 227 173 return nil, stacktrace.Propagate(err, "") 228 174 } ··· 240 186 }), nil 241 187 } 242 188 243 - func (plc *plcImpl) AuditLog(ctx context.Context, atHeight TreeVersion, did string) ([]didplc.LogEntry, error) { 189 + func (plc *plcImpl) AuditLog(ctx context.Context, tx transaction.Read, did string) ([]didplc.LogEntry, error) { 244 190 // GetPlcAuditLog - /:did/log/audit - full audit log, with nullified 245 191 // if missing -> returns ErrDIDNotFound 246 192 // if tombstone -> returns log as normal 247 193 plc.mu.Lock() 248 194 defer plc.mu.Unlock() 249 195 250 - tree, err := plc.treeProvider.ImmutableTree(atHeight) 251 - if err != nil { 252 - return nil, stacktrace.Propagate(err, "failed to obtain immutable tree") 253 - } 254 - 255 - l, _, err := store.Tree.AuditLog(ctx, tree, did, false) 196 + l, _, err := store.Tree.AuditLog(ctx, tx, did, false) 256 197 if err != nil { 257 198 return nil, stacktrace.Propagate(err, "") 258 199 } ··· 266 207 }), nil 267 208 } 268 209 269 - func (plc *plcImpl) LastOperation(ctx context.Context, atHeight TreeVersion, did string) (didplc.OpEnum, error) { 210 + func (plc *plcImpl) LastOperation(ctx context.Context, tx transaction.Read, did string) (didplc.OpEnum, error) { 270 211 // GetLastOp - /:did/log/last - latest op from audit log which isn't nullified (the latest op is guaranteed not to be nullified) 271 212 // if missing -> returns ErrDIDNotFound 272 213 // if tombstone -> returns tombstone op 273 214 plc.mu.Lock() 274 215 defer plc.mu.Unlock() 275 216 276 - tree, err := plc.treeProvider.ImmutableTree(atHeight) 277 - if err != nil { 278 - return didplc.OpEnum{}, stacktrace.Propagate(err, "failed to obtain immutable tree") 279 - } 280 - 281 - l, _, err := store.Tree.AuditLog(ctx, tree, did, false) 217 + l, _, err := store.Tree.AuditLog(ctx, tx, did, false) 282 218 if err != nil { 283 219 return didplc.OpEnum{}, stacktrace.Propagate(err, "") 284 220 } ··· 290 226 return l[len(l)-1].Operation, nil 291 227 } 292 228 293 - func (plc *plcImpl) Data(ctx context.Context, atHeight TreeVersion, did string) (didplc.RegularOp, error) { 229 + func (plc *plcImpl) Data(ctx context.Context, tx transaction.Read, did string) (didplc.RegularOp, error) { 294 230 // GetPlcData - /:did/data - similar to GetLastOp but applies a transformation on the op which normalizes it into a modern op 295 231 // if missing -> returns ErrDIDNotFound 296 232 // if tombstone -> returns ErrDIDGone 297 233 plc.mu.Lock() 298 234 defer plc.mu.Unlock() 299 235 300 - tree, err := plc.treeProvider.ImmutableTree(atHeight) 301 - if err != nil { 302 - return didplc.RegularOp{}, stacktrace.Propagate(err, "failed to obtain immutable tree") 303 - } 304 - 305 - l, _, err := store.Tree.AuditLog(ctx, tree, did, false) 236 + l, _, err := store.Tree.AuditLog(ctx, tx, did, false) 306 237 if err != nil { 307 238 return didplc.RegularOp{}, stacktrace.Propagate(err, "") 308 239 } ··· 322 253 323 254 } 324 255 325 - func (plc *plcImpl) Export(ctx context.Context, atHeight TreeVersion, after uint64, count int) ([]types.SequencedLogEntry, error) { 256 + func (plc *plcImpl) Export(ctx context.Context, tx transaction.Read, after uint64, count int) ([]types.SequencedLogEntry, error) { 326 257 plc.mu.Lock() 327 258 defer plc.mu.Unlock() 328 259 329 - tree, err := plc.treeProvider.ImmutableTree(atHeight) 330 - if err != nil { 331 - return nil, stacktrace.Propagate(err, "failed to obtain immutable tree") 332 - } 333 - 334 - entries, err := store.Tree.ExportOperations(ctx, tree, after, count) 260 + entries, err := store.Tree.ExportOperations(ctx, tx, after, count) 335 261 return entries, stacktrace.Propagate(err, "") 336 262 } 337 - 338 - type inMemoryAuditLogFetcher struct { 339 - plc *plcImpl 340 - } 341 - 342 - func (a *inMemoryAuditLogFetcher) AuditLogReverseIterator(ctx context.Context, atHeight TreeVersion, did string, retErr *error) iter.Seq2[int, types.SequencedLogEntry] { 343 - tree, err := a.plc.treeProvider.ImmutableTree(atHeight) 344 - if err != nil { 345 - *retErr = stacktrace.Propagate(err, "") 346 - return func(yield func(int, types.SequencedLogEntry) bool) {} 347 - } 348 - 349 - return store.Tree.AuditLogReverseIterator(ctx, tree, did, retErr) 350 - }
+42 -88
plc/operation_validator.go
··· 3 3 import ( 4 4 "context" 5 5 "errors" 6 - "iter" 6 + "slices" 7 7 "strings" 8 8 "time" 9 9 ··· 11 11 "github.com/bluesky-social/indigo/atproto/syntax" 12 12 "github.com/did-method-plc/go-didplc" 13 13 "github.com/palantir/stacktrace" 14 + "github.com/samber/lo" 14 15 "github.com/samber/mo" 16 + "tangled.org/gbl08ma.com/didplcbft/store" 17 + "tangled.org/gbl08ma.com/didplcbft/transaction" 15 18 "tangled.org/gbl08ma.com/didplcbft/types" 16 19 ) 17 20 18 - type AuditLogFetcher interface { 19 - // AuditLogReverseIterator should return an iterator over the list of log entries for the specified DID, in reverse 20 - AuditLogReverseIterator(ctx context.Context, atHeight TreeVersion, did string, err *error) iter.Seq2[int, types.SequencedLogEntry] 21 - } 22 - 23 - type V0OperationValidator struct { 24 - auditLogFetcher AuditLogFetcher 25 - } 21 + type V0OperationValidator struct{} 26 22 27 - func NewV0OperationValidator(logFetcher AuditLogFetcher) *V0OperationValidator { 28 - return &V0OperationValidator{ 29 - auditLogFetcher: logFetcher, 30 - } 23 + func NewV0OperationValidator() *V0OperationValidator { 24 + return &V0OperationValidator{} 31 25 } 32 26 33 27 type OperationEffects struct { 34 - NullifiedEntriesStartingIndex mo.Option[int] 35 - NewLogEntry didplc.LogEntry 28 + NullifiedEntriesStartingSeq mo.Option[uint64] 29 + NewLogEntry didplc.LogEntry 36 30 } 37 31 38 32 // Validate returns the new complete AuditLog that the DID history would assume if validation passes, and an error if it doesn't pass 39 - func (v *V0OperationValidator) Validate(ctx context.Context, atHeight TreeVersion, timestamp syntax.Datetime, expectedDid string, opBytes []byte, laxChecking bool) (OperationEffects, error) { 33 + func (v *V0OperationValidator) Validate(ctx context.Context, readTx transaction.Read, timestamp syntax.Datetime, expectedDid string, opBytes []byte, laxChecking bool) (OperationEffects, error) { 40 34 opEnum, op, err := unmarshalOp(opBytes) 41 35 if err != nil { 42 36 return OperationEffects{}, stacktrace.Propagate(errors.Join(ErrMalformedOperation, err), "") ··· 76 70 77 71 proposedPrev := op.PrevCIDStr() 78 72 79 - partialLog := make(map[int]types.SequencedLogEntry) 80 - mostRecentOpIndex := -1 81 - indexOfPrev := -1 73 + hasExistingOps := false 74 + var proposedPrevOp mo.Option[types.SequencedLogEntry] 75 + 76 + relevantExistingEntries := []types.SequencedLogEntry{} 77 + nullifiedEntries := []types.SequencedLogEntry{} 78 + nullifiedEntriesStartingSeq := mo.None[uint64]() 79 + 82 80 var iteratorErr error 83 - for entryIdx, entry := range v.auditLogFetcher.AuditLogReverseIterator(ctx, atHeight, expectedDid, &iteratorErr) { 84 - partialLog[entryIdx] = entry 85 - if mostRecentOpIndex == -1 { 86 - mostRecentOpIndex = entryIdx 87 - 81 + for entry := range store.Tree.AuditLogReverseIterator(ctx, readTx, expectedDid, &iteratorErr) { 82 + relevantExistingEntries = append(relevantExistingEntries, entry) 83 + if !hasExistingOps { 84 + hasExistingOps = true 88 85 if proposedPrev == "" { 89 86 return OperationEffects{}, stacktrace.Propagate(ErrInvalidOperationSequence, "creation operation not allowed as DID already exists") 90 87 } 91 88 } 92 89 93 90 if entry.CID.String() == proposedPrev { 94 - indexOfPrev = entryIdx 91 + // TODO confirm what should happen if proposedPrev points to a nullified operation. we should probably be ignoring nullified ops here, but confirm with the reference impl 92 + proposedPrevOp = mo.Some(entry) 95 93 break 94 + } else { 95 + // we only get here if there's an operation between the new latest and prev 96 + // this will keep decreasing until we find prev, at which point it'll have the lowest seq for this DID before that of prev 97 + nullifiedEntries = append(nullifiedEntries, entry) 98 + nullifiedEntriesStartingSeq = mo.Some(entry.Seq) 96 99 } 97 100 } 98 101 ··· 100 103 return OperationEffects{}, stacktrace.Propagate(iteratorErr, "") 101 104 } 102 105 103 - nullifiedEntries := []types.SequencedLogEntry{} 104 - nullifiedEntriesStartingIndex := mo.None[int]() 106 + // reverse entries to be in the order they were actually added 107 + slices.Reverse(relevantExistingEntries) 108 + slices.Reverse(nullifiedEntries) 105 109 106 - if mostRecentOpIndex < 0 { 110 + newOperationCID := op.CID() 111 + 112 + if !hasExistingOps { 107 113 // we are expecting a creation op, validate it like so 108 - newOperationCID := op.CID() 109 114 newEntry := didplc.LogEntry{ 110 115 DID: expectedDid, 111 116 Operation: opEnum, ··· 119 124 err = NewInvalidOperationError(4023, err) 120 125 return OperationEffects{}, stacktrace.Propagate(err, "invalid operation") 121 126 } 122 - } else if indexOfPrev < 0 { 127 + } else if !proposedPrevOp.IsPresent() { 123 128 // there are entries in the audit log but none of them has a CID matching prev 124 129 return OperationEffects{}, stacktrace.Propagate(ErrInvalidPrev, "") 125 130 } else { 126 131 // we've found the targeted prev operation 127 132 128 133 // timestamps must increase monotonically 129 - mostRecentOp := partialLog[mostRecentOpIndex] 134 + mostRecentOp := relevantExistingEntries[len(relevantExistingEntries)-1] 130 135 if !timestamp.Time().After(mostRecentOp.CreatedAt) { 131 136 return OperationEffects{}, stacktrace.Propagate(ErrInvalidOperationSequence, "") 132 137 } 133 138 134 139 // if we are forking history, these are the ops still in the proposed canonical history 135 - 136 - lastOpEntry := partialLog[indexOfPrev] 137 - lastOp := lastOpEntry.Operation.AsOperation() 140 + lastOp := lo.ToPtr(proposedPrevOp.MustGet().Operation).AsOperation() 138 141 lastOpRotationKeys := lastOp.EquivalentRotationKeys() 139 - for i := indexOfPrev + 1; i <= mostRecentOpIndex; i++ { 140 - nullifiedEntries = append(nullifiedEntries, partialLog[i]) 141 - } 142 142 if len(nullifiedEntries) > 0 { 143 - nullifiedEntriesStartingIndex = mo.Some(indexOfPrev + 1) 144 - 145 143 disputedSignerIdx, err := didplc.VerifySignatureAny(nullifiedEntries[0].Operation.AsOperation(), lastOpRotationKeys) 146 144 if err != nil { 147 145 return OperationEffects{}, stacktrace.Propagate(err, "reached invalid internal state") ··· 166 164 } 167 165 } 168 166 169 - newOperationCID := op.CID() 170 167 newEntry := didplc.LogEntry{ 171 168 DID: expectedDid, 172 169 Operation: opEnum, ··· 179 176 if len(nullifiedEntries) == 0 { 180 177 // (see prior note on september27Of2023) 181 178 if !laxChecking && timestamp.Time().After(september29Of2023) { 182 - err = v.EnforceOpsRateLimit(ctx, atHeight, expectedDid, timestamp.Time()) 179 + err = v.EnforceOpsRateLimit(ctx, readTx, expectedDid, timestamp.Time()) 183 180 if err != nil { 184 181 return OperationEffects{}, stacktrace.Propagate(err, "") 185 182 } ··· 187 184 } 188 185 189 186 return OperationEffects{ 190 - NullifiedEntriesStartingIndex: nullifiedEntriesStartingIndex, 191 - NewLogEntry: newEntry, 187 + NullifiedEntriesStartingSeq: nullifiedEntriesStartingSeq, 188 + NewLogEntry: newEntry, 192 189 }, nil 193 190 } 194 191 ··· 214 211 ) 215 212 216 213 // EnforceOpsRateLimit is ported from the TypeScript enforceOpsRateLimit function, adapted to not require fetching the entire log 217 - func (v *V0OperationValidator) EnforceOpsRateLimit(ctx context.Context, atHeight TreeVersion, did string, newOperationTimestamp time.Time) error { 214 + func (v *V0OperationValidator) EnforceOpsRateLimit(ctx context.Context, tx transaction.Read, did string, newOperationTimestamp time.Time) error { 218 215 hourAgo := newOperationTimestamp.Add(-time.Hour) 219 216 dayAgo := newOperationTimestamp.Add(-24 * time.Hour) 220 217 weekAgo := newOperationTimestamp.Add(-7 * 24 * time.Hour) 221 218 222 219 var withinHour, withinDay, withinWeek int 223 220 var err error 224 - for _, entry := range v.auditLogFetcher.AuditLogReverseIterator(ctx, atHeight, did, &err) { 221 + for entry := range store.Tree.AuditLogReverseIterator(ctx, tx, did, &err) { 225 222 if entry.Nullified { 226 223 // The typescript implementation operates over a `ops` array which doesn't include nullified ops 227 224 // (With recovery ops also skipping rate limits, doesn't this leave the PLC vulnerable to the spam of constant recovery operations? TODO investigate) ··· 256 253 } 257 254 } 258 255 return stacktrace.Propagate(err, "") 259 - } 260 - 261 - // EnforceOpsRateLimit checks whether a slice of log entries exceeds rate limits 262 - // This method is ported from the TypeScript enforceOpsRateLimit function 263 - func EnforceOpsRateLimit(ops []didplc.LogEntry) error { 264 - now := time.Now() 265 - hourAgo := now.Add(-time.Hour) 266 - dayAgo := now.Add(-24 * time.Hour) 267 - weekAgo := now.Add(-7 * 24 * time.Hour) 268 - 269 - var withinHour, withinDay, withinWeek int 270 - 271 - for _, op := range ops { 272 - // Parse the CreatedAt timestamp string 273 - // The CreatedAt field is stored as a string in ISO 8601 format 274 - opDatetime, err := syntax.ParseDatetime(op.CreatedAt) 275 - if err != nil { 276 - // If parsing fails, skip this operation for rate limiting 277 - continue 278 - } 279 - opTime := opDatetime.Time() 280 - 281 - if opTime.After(weekAgo) { 282 - withinWeek++ 283 - if withinWeek >= WeekLimit { 284 - return stacktrace.Propagate(ErrRateLimitExceeded, "too many operations within last week (max %d)", WeekLimit) 285 - } 286 - } 287 - if opTime.After(dayAgo) { 288 - withinDay++ 289 - if withinDay >= DayLimit { 290 - return stacktrace.Propagate(ErrRateLimitExceeded, "too many operations within last day (max %d)", DayLimit) 291 - } 292 - } 293 - if opTime.After(hourAgo) { 294 - withinHour++ 295 - if withinHour >= HourLimit { 296 - return stacktrace.Propagate(ErrRateLimitExceeded, "too many operations within last hour (max %d)", HourLimit) 297 - } 298 - } 299 - } 300 - 301 - return nil 302 256 } 303 257 304 258 func (v *V0OperationValidator) validateOperationConstraints(createdAt time.Time, op didplc.Operation) error {
+11 -44
plc/plc.go
··· 3 3 import ( 4 4 "context" 5 5 "errors" 6 - "time" 7 6 8 7 "github.com/bluesky-social/indigo/atproto/syntax" 9 8 "github.com/did-method-plc/go-didplc" 9 + "tangled.org/gbl08ma.com/didplcbft/transaction" 10 10 "tangled.org/gbl08ma.com/didplcbft/types" 11 11 ) 12 12 13 13 var ErrDIDNotFound = errors.New("DID not found") 14 14 var ErrDIDGone = errors.New("DID deactivated") 15 15 16 - type TreeVersion struct { 17 - workingHeight bool 18 - committedHeight bool 19 - specificHeight int64 20 - } 21 - 22 - func (v TreeVersion) IsMutable() bool { 23 - return v.workingHeight 24 - } 25 - 26 - func (v TreeVersion) IsCommitted() bool { 27 - return v.committedHeight 28 - } 29 - 30 - func (v TreeVersion) SpecificVersion() (int64, bool) { 31 - return v.specificHeight, !v.workingHeight && !v.committedHeight 32 - } 33 - 34 - var WorkingTreeVersion = TreeVersion{ 35 - workingHeight: true, 36 - } 37 - 38 - var CommittedTreeVersion = TreeVersion{ 39 - committedHeight: true, 40 - } 41 - 42 - func SpecificTreeVersion(height int64) TreeVersion { 43 - return TreeVersion{ 44 - specificHeight: height, 45 - } 46 - } 47 - 48 16 type OperationValidator interface { 49 - Validate(ctx context.Context, atHeight TreeVersion, timestamp syntax.Datetime, expectedDid string, opBytes []byte, allowLegacy bool) (OperationEffects, error) 17 + Validate(ctx context.Context, tx transaction.Read, timestamp syntax.Datetime, expectedDid string, opBytes []byte, allowLegacy bool) (OperationEffects, error) 50 18 } 51 19 52 20 type PLC interface { ··· 55 23 } 56 24 57 25 type ReadPLC interface { 58 - ValidateOperation(ctx context.Context, atHeight TreeVersion, at time.Time, did string, opBytes []byte) error 59 - Resolve(ctx context.Context, atHeight TreeVersion, did string) (didplc.Doc, error) 60 - OperationLog(ctx context.Context, atHeight TreeVersion, did string) ([]didplc.OpEnum, error) 61 - AuditLog(ctx context.Context, atHeight TreeVersion, did string) ([]didplc.LogEntry, error) 62 - LastOperation(ctx context.Context, atHeight TreeVersion, did string) (didplc.OpEnum, error) 63 - Data(ctx context.Context, atHeight TreeVersion, did string) (didplc.RegularOp, error) 64 - Export(ctx context.Context, atHeight TreeVersion, after uint64, count int) ([]types.SequencedLogEntry, error) 26 + ValidateOperation(ctx context.Context, tx transaction.Read, did string, opBytes []byte) error 27 + Resolve(ctx context.Context, tx transaction.Read, did string) (didplc.Doc, error) 28 + OperationLog(ctx context.Context, tx transaction.Read, did string) ([]didplc.OpEnum, error) 29 + AuditLog(ctx context.Context, tx transaction.Read, did string) ([]didplc.LogEntry, error) 30 + LastOperation(ctx context.Context, tx transaction.Read, did string) (didplc.OpEnum, error) 31 + Data(ctx context.Context, tx transaction.Read, did string) (didplc.RegularOp, error) 32 + Export(ctx context.Context, tx transaction.Read, after uint64, count int) ([]types.SequencedLogEntry, error) 65 33 } 66 34 67 35 type WritePLC interface { 68 - ExecuteOperation(ctx context.Context, timestamp time.Time, did string, opBytes []byte) error 69 - ImportOperationFromAuthoritativeSource(ctx context.Context, entry didplc.LogEntry) error 70 - ImportOperationsFromAuthoritativeSource(ctx context.Context, entries []didplc.LogEntry) error 36 + ExecuteOperation(ctx context.Context, tx transaction.Write, did string, opBytes []byte) error 37 + ImportOperationFromAuthoritativeSource(ctx context.Context, tx transaction.Write, entry didplc.LogEntry) error 71 38 }
+68 -139
plc/plc_test.go
··· 12 12 13 13 "github.com/bluesky-social/indigo/atproto/syntax" 14 14 "github.com/did-method-plc/go-didplc" 15 - "github.com/samber/lo" 16 15 "github.com/stretchr/testify/require" 17 16 "tangled.org/gbl08ma.com/didplcbft/plc" 18 17 "tangled.org/gbl08ma.com/didplcbft/types" ··· 141 140 142 141 ctx := t.Context() 143 142 144 - treeProvider := NewTestTreeProvider() 145 - testPLC := plc.NewPLC(treeProvider) 143 + txFactory, tree, _ := NewTestTxFactory() 144 + testPLC := plc.NewPLC() 146 145 147 - tree, err := treeProvider.MutableTree() 148 - require.NoError(t, err) 149 - _, origVersion, err := tree.SaveVersion() 150 - require.NoError(t, err) 146 + origVersion := tree.Version() 151 147 152 148 // resolving a unknown DID should return an error 153 - _, err = testPLC.Resolve(ctx, plc.WorkingTreeVersion, "did:plc:y5gazb6lrsk3j4riiro62zjn") 149 + _, err := testPLC.Resolve(ctx, txFactory.ReadCommitted(), "did:plc:y5gazb6lrsk3j4riiro62zjn") 154 150 require.ErrorIs(t, err, plc.ErrDIDNotFound) 155 151 156 152 for _, c := range operations { 157 - err := testPLC.ExecuteOperation(ctx, c.ApplyAt.Time(), c.DID, []byte(c.Operation)) 153 + tx, err := txFactory.ReadWorking(c.ApplyAt.Time()).Upgrade() 154 + require.NoError(t, err) 155 + 156 + err = testPLC.ExecuteOperation(ctx, tx, c.DID, []byte(c.Operation)) 158 157 if c.ExpectFailure { 159 158 require.Error(t, err) 160 159 } else { 161 160 require.NoError(t, err) 162 161 } 163 - _, _, err = tree.SaveVersion() 162 + err = tx.Commit() 164 163 require.NoError(t, err) 165 164 } 165 + 166 + readTx := txFactory.ReadCommitted() 166 167 167 168 // now try resolving the DID, should return the document with the latest state 168 - doc, err := testPLC.Resolve(ctx, plc.WorkingTreeVersion, testDID) 169 + doc, err := testPLC.Resolve(ctx, readTx, testDID) 169 170 require.NoError(t, err) 170 171 require.Equal(t, testDID, doc.ID) 171 172 require.Len(t, doc.Service, 2) 172 173 require.Equal(t, []string{"at://pds.labeler.tny.im"}, doc.AlsoKnownAs) 173 174 174 - log, err := testPLC.OperationLog(ctx, plc.WorkingTreeVersion, testDID) 175 + log, err := testPLC.OperationLog(ctx, readTx, testDID) 175 176 require.NoError(t, err) 176 177 require.Len(t, log, 3) 177 178 require.Equal(t, "bafyreifgafcel2okxszhgbugieyvtmfig2gtf3dgqoh5fvdh3nlh6ncv6q", log[0].AsOperation().CID().String()) 178 179 require.Equal(t, "bafyreia6ewwkwjgly6dijfepaq2ey6zximodbtqqi5f6fyugli3cxohn5m", log[1].AsOperation().CID().String()) 179 180 require.Equal(t, "bafyreigyzl2esgnk7nvav5myvgywbshdmatzthc73iiar7tyeq3xjt47m4", log[2].AsOperation().CID().String()) 180 181 181 - log, err = testPLC.OperationLog(ctx, plc.SpecificTreeVersion(origVersion+2), testDID) 182 + readTx, err = txFactory.ReadHeight(time.Now(), origVersion+2) 183 + require.NoError(t, err) 184 + log, err = testPLC.OperationLog(ctx, readTx, testDID) 182 185 require.NoError(t, err) 183 186 require.Len(t, log, 1) 184 187 require.Equal(t, "bafyreifgafcel2okxszhgbugieyvtmfig2gtf3dgqoh5fvdh3nlh6ncv6q", log[0].AsOperation().CID().String()) 185 188 186 189 // the DID should still be not found in older versions of the tree 187 - _, err = testPLC.Resolve(ctx, plc.SpecificTreeVersion(origVersion), testDID) 190 + readTx, err = txFactory.ReadHeight(time.Now(), origVersion) 191 + require.NoError(t, err) 192 + _, err = testPLC.Resolve(ctx, readTx, testDID) 188 193 require.ErrorIs(t, err, plc.ErrDIDNotFound) 189 194 190 - doc, err = testPLC.Resolve(ctx, plc.SpecificTreeVersion(origVersion+4), testDID) 195 + readTx, err = txFactory.ReadHeight(time.Now(), origVersion+4) 196 + require.NoError(t, err) 197 + doc, err = testPLC.Resolve(ctx, readTx, testDID) 191 198 require.NoError(t, err) 192 199 193 - export, err := testPLC.Export(ctx, plc.CommittedTreeVersion, 0, 1000) 200 + export, err := testPLC.Export(ctx, txFactory.ReadCommitted(), 0, 1000) 194 201 require.NoError(t, err) 195 202 require.Len(t, export, 3) 196 203 ··· 202 209 require.Equal(t, "bafyreigyzl2esgnk7nvav5myvgywbshdmatzthc73iiar7tyeq3xjt47m4", export[2].CID.String()) 203 210 204 211 // the after parameter is exclusive, with a limit of 1, we should just get the second successful operation 205 - export, err = testPLC.Export(ctx, plc.CommittedTreeVersion, export[0].Seq, 1) 212 + export, err = testPLC.Export(ctx, txFactory.ReadCommitted(), export[0].Seq, 1) 206 213 require.NoError(t, err) 207 214 require.Len(t, export, 1) 208 215 require.Equal(t, "bafyreia6ewwkwjgly6dijfepaq2ey6zximodbtqqi5f6fyugli3cxohn5m", export[0].CID.String()) ··· 234 241 235 242 ctx := t.Context() 236 243 237 - treeProvider := NewTestTreeProvider() 238 - testPLC := plc.NewPLC(treeProvider) 244 + txFactory, _, _ := NewTestTxFactory() 245 + testPLC := plc.NewPLC() 239 246 240 247 for _, auditLog := range remoteLogs { 241 248 for _, logEntry := range auditLog { ··· 244 251 245 252 at := syntax.Datetime(logEntry.CreatedAt).Time() 246 253 247 - err = testPLC.ValidateOperation(ctx, plc.WorkingTreeVersion, at, logEntry.DID, b) 254 + readTx := txFactory.ReadWorking(at) 255 + writeTx, err := readTx.Upgrade() 248 256 require.NoError(t, err) 249 257 250 - err = testPLC.ExecuteOperation(ctx, at, logEntry.DID, b) 258 + err = testPLC.ValidateOperation(ctx, readTx, logEntry.DID, b) 251 259 require.NoError(t, err) 252 260 253 - err = testPLC.ExecuteOperation(ctx, at, logEntry.DID, b) 261 + err = testPLC.ExecuteOperation(ctx, writeTx, logEntry.DID, b) 262 + require.NoError(t, err) 263 + 264 + err = testPLC.ExecuteOperation(ctx, writeTx, logEntry.DID, b) 254 265 // committing the same operation twice should never work, 255 266 // as though even in non-genesis ops the referenced prev will exist, 256 267 // (and thus could seem like a recovery operation at first glance) 257 268 // valid recovery operations must be signed by a key with a lower index in the rotationKeys array 258 269 // than the one which signed the operation to be invalidated 259 270 require.Error(t, err) 271 + 272 + err = writeTx.Commit() 273 + require.NoError(t, err) 260 274 } 261 275 } 262 276 263 - _, newVersion, err := lo.Must(treeProvider.MutableTree()).SaveVersion() 264 - require.NoError(t, err) 265 - 266 277 for i, testDID := range testDIDs { 267 - doc, err := testPLC.Resolve(ctx, plc.SpecificTreeVersion(newVersion), testDID) 278 + doc, err := testPLC.Resolve(ctx, txFactory.ReadCommitted(), testDID) 268 279 if testDID == "did:plc:pkmfz5soq2swsvbhvjekb36g" { 269 280 require.ErrorContains(t, err, "deactivated") 270 281 } else { ··· 272 283 require.Equal(t, testDID, doc.ID) 273 284 } 274 285 275 - actualLog, err := testPLC.AuditLog(ctx, plc.SpecificTreeVersion(newVersion), testDID) 286 + actualLog, err := testPLC.AuditLog(ctx, txFactory.ReadCommitted(), testDID) 276 287 require.NoError(t, err) 277 288 require.Len(t, actualLog, len(remoteLogs[i])) 289 + 290 + err = didplc.VerifyOpLog(actualLog) 291 + require.NoError(t, err) 292 + 278 293 for j := range actualLog { 279 294 require.Equal(t, remoteLogs[i][j].DID, actualLog[j].DID) 280 295 require.Equal(t, remoteLogs[i][j].CID, actualLog[j].CID) ··· 291 306 } 292 307 } 293 308 294 - export, err := testPLC.Export(ctx, plc.CommittedTreeVersion, 0, 0) 309 + export, err := testPLC.Export(ctx, txFactory.ReadCommitted(), 0, 0) 295 310 require.NoError(t, err) 296 311 require.Len(t, export, 100) 297 312 ··· 303 318 } 304 319 } 305 320 306 - func TestEnforceOpsRateLimit(t *testing.T) { 307 - // Test case 1: Operations within rate limits should pass 308 - t.Run("WithinLimits", func(t *testing.T) { 309 - now := time.Now() 310 - ops := []didplc.LogEntry{ 311 - { 312 - DID: "did:plc:test1", 313 - CreatedAt: now.Add(-30 * time.Minute).Format(time.RFC3339), 314 - }, 315 - { 316 - DID: "did:plc:test1", 317 - CreatedAt: now.Add(-45 * time.Minute).Format(time.RFC3339), 318 - }, 319 - } 320 - 321 - err := plc.EnforceOpsRateLimit(ops) 322 - require.NoError(t, err) 323 - }) 324 - 325 - // Test case 2: Exceeding hourly limit should fail 326 - t.Run("ExceedHourlyLimit", func(t *testing.T) { 327 - now := time.Now() 328 - ops := make([]didplc.LogEntry, plc.HourLimit+1) 329 - for i := 0; i < len(ops); i++ { 330 - ops[i] = didplc.LogEntry{ 331 - DID: "did:plc:test2", 332 - CreatedAt: now.Add(-time.Duration(i) * time.Minute).Format(time.RFC3339), 333 - } 334 - } 335 - 336 - err := plc.EnforceOpsRateLimit(ops) 337 - require.ErrorContains(t, err, "too many operations within last hour") 338 - }) 339 - 340 - // Test case 3: Exceeding daily limit should fail 341 - t.Run("ExceedDailyLimit", func(t *testing.T) { 342 - now := time.Now() 343 - ops := make([]didplc.LogEntry, plc.DayLimit+1) 344 - for i := 0; i < len(ops); i++ { 345 - // Create operations within the last day but over the daily limit 346 - ops[i] = didplc.LogEntry{ 347 - DID: "did:plc:test3", 348 - CreatedAt: now.Add(-time.Duration(i%24) * time.Hour).Format(time.RFC3339), 349 - } 350 - } 351 - 352 - err := plc.EnforceOpsRateLimit(ops) 353 - require.ErrorContains(t, err, "too many operations within last day") 354 - }) 355 - 356 - // Test case 4: Exceeding weekly limit should fail 357 - t.Run("ExceedWeeklyLimit", func(t *testing.T) { 358 - now := time.Now() 359 - ops := make([]didplc.LogEntry, plc.WeekLimit+1) 360 - for i := 0; i < len(ops); i++ { 361 - // Create operations within the last week but over the weekly limit 362 - ops[i] = didplc.LogEntry{ 363 - DID: "did:plc:test4", 364 - CreatedAt: now.Add(-time.Duration(i%168) * time.Hour).Format(time.RFC3339), 365 - } 366 - } 367 - 368 - err := plc.EnforceOpsRateLimit(ops) 369 - require.ErrorContains(t, err, "too many operations within last week") 370 - }) 371 - 372 - // Test case 5: Many operations within day but spread across hours should not exceed hourly limit 373 - t.Run("OperationsSpreadAcrossHours", func(t *testing.T) { 374 - now := time.Now() 375 - // Create 15 operations within the last day (within daily limit of 30) 376 - // but spread across 2+ hours so no single hour exceeds the limit of 10 377 - ops := make([]didplc.LogEntry, 15) 378 - for i := 0; i < 15; i++ { 379 - // First 8 operations in the last 30 minutes 380 - if i < 8 { 381 - ops[i] = didplc.LogEntry{ 382 - DID: "did:plc:test6", 383 - CreatedAt: now.Add(-time.Duration(i) * time.Minute).Format(time.RFC3339), 384 - } 385 - } else { 386 - // Next 7 operations spread across 2+ hours ago 387 - ops[i] = didplc.LogEntry{ 388 - DID: "did:plc:test6", 389 - CreatedAt: now.Add(-(2*time.Hour + time.Duration(i-8)*time.Minute)).Format(time.RFC3339), 390 - } 391 - } 392 - } 393 - 394 - err := plc.EnforceOpsRateLimit(ops) 395 - require.NoError(t, err) 396 - }) 397 - } 398 - 399 321 func TestImportOperationFromAuthoritativeSource(t *testing.T) { 400 322 ctx := t.Context() 401 323 402 - treeProvider := NewTestTreeProvider() 403 - testPLC := plc.NewPLC(treeProvider) 324 + txFactory, _, _ := NewTestTxFactory() 325 + testPLC := plc.NewPLC() 404 326 405 - tree, err := treeProvider.MutableTree() 406 - require.NoError(t, err) 407 - _, _, err = tree.SaveVersion() 327 + readTx := txFactory.ReadWorking(time.Now()) 328 + writeTx, err := readTx.Upgrade() 408 329 require.NoError(t, err) 409 330 410 331 seenCIDs := map[string]struct{}{} 411 332 seenDIDs := map[string]struct{}{} 412 333 for entry := range iterateOverExport(ctx, 0) { 413 - err := testPLC.ImportOperationFromAuthoritativeSource(ctx, entry) 334 + err := testPLC.ImportOperationFromAuthoritativeSource(ctx, writeTx, entry) 414 335 require.NoError(t, err) 415 336 337 + if len(seenCIDs)%1000 == 0 { 338 + err = writeTx.Commit() 339 + require.NoError(t, err) 340 + 341 + readTx = txFactory.ReadWorking(time.Now()) 342 + writeTx, err = readTx.Upgrade() 343 + require.NoError(t, err) 344 + } 345 + 416 346 seenCIDs[entry.CID] = struct{}{} 417 347 seenDIDs[entry.DID] = struct{}{} 418 348 if len(seenCIDs) == 10000 { ··· 420 350 } 421 351 } 422 352 423 - _, _, err = tree.SaveVersion() 353 + err = writeTx.Commit() 424 354 require.NoError(t, err) 425 355 426 - exportedEntries, err := testPLC.Export(ctx, plc.CommittedTreeVersion, 0, len(seenCIDs)+1) 356 + exportedEntries, err := testPLC.Export(ctx, txFactory.ReadCommitted(), 0, len(seenCIDs)+1) 427 357 require.NoError(t, err) 428 358 429 359 require.Len(t, exportedEntries, len(seenCIDs)) ··· 434 364 require.Empty(t, seenCIDs) 435 365 436 366 for did := range seenDIDs { 437 - auditLog, err := testPLC.AuditLog(ctx, plc.CommittedTreeVersion, did) 367 + auditLog, err := testPLC.AuditLog(ctx, txFactory.ReadCommitted(), did) 438 368 require.NoError(t, err) 439 369 440 370 err = didplc.VerifyOpLog(auditLog) ··· 446 376 ctx := t.Context() 447 377 448 378 testFn := func(toImport []didplc.LogEntry, mutate func(didplc.LogEntry) didplc.LogEntry) ([]types.SequencedLogEntry, []didplc.LogEntry) { 449 - treeProvider := NewTestTreeProvider() 450 - testPLC := plc.NewPLC(treeProvider) 379 + txFactory, _, _ := NewTestTxFactory() 380 + testPLC := plc.NewPLC() 451 381 452 - tree, err := treeProvider.MutableTree() 453 - require.NoError(t, err) 454 - _, _, err = tree.SaveVersion() 382 + readTx := txFactory.ReadWorking(time.Now()) 383 + writeTx, err := readTx.Upgrade() 455 384 require.NoError(t, err) 456 385 457 386 for _, entry := range toImport { 458 387 entry = mutate(entry) 459 - err := testPLC.ImportOperationFromAuthoritativeSource(ctx, entry) 388 + err := testPLC.ImportOperationFromAuthoritativeSource(ctx, writeTx, entry) 460 389 require.NoError(t, err) 461 390 } 462 391 463 - _, _, err = tree.SaveVersion() 392 + err = writeTx.Commit() 464 393 require.NoError(t, err) 465 394 466 - exportedEntries, err := testPLC.Export(ctx, plc.CommittedTreeVersion, 0, len(toImport)+1) 395 + exportedEntries, err := testPLC.Export(ctx, txFactory.ReadCommitted(), 0, len(toImport)+1) 467 396 require.NoError(t, err) 468 397 469 398 require.Len(t, exportedEntries, len(toImport)) 470 399 471 - auditLog, err := testPLC.AuditLog(ctx, plc.CommittedTreeVersion, "did:plc:pkmfz5soq2swsvbhvjekb36g") 400 + auditLog, err := testPLC.AuditLog(ctx, txFactory.ReadCommitted(), "did:plc:pkmfz5soq2swsvbhvjekb36g") 472 401 require.NoError(t, err) 473 402 474 403 return exportedEntries, auditLog
+10 -35
plc/testutil_test.go
··· 1 1 package plc_test 2 2 3 3 import ( 4 + dbm "github.com/cometbft/cometbft-db" 4 5 "github.com/cosmos/iavl" 5 - dbm "github.com/cosmos/iavl/db" 6 + iavldb "github.com/cosmos/iavl/db" 6 7 "github.com/palantir/stacktrace" 7 - "tangled.org/gbl08ma.com/didplcbft/plc" 8 8 "tangled.org/gbl08ma.com/didplcbft/store" 9 + "tangled.org/gbl08ma.com/didplcbft/transaction" 9 10 ) 10 11 11 - type testTreeProvider struct { 12 - tree *iavl.MutableTree 13 - } 14 - 15 - func NewTestTreeProvider() *testTreeProvider { 16 - return &testTreeProvider{ 17 - tree: iavl.NewMutableTree(dbm.NewMemDB(), 128, false, iavl.NewNopLogger()), 12 + func NewTestTxFactory() (*transaction.Factory, *iavl.MutableTree, dbm.DB) { 13 + tree := iavl.NewMutableTree(iavldb.NewMemDB(), 128, false, iavl.NewNopLogger()) 14 + _, _, err := tree.SaveVersion() 15 + if err != nil { 16 + panic(stacktrace.Propagate(err, "")) 18 17 } 19 - } 20 18 21 - func (t *testTreeProvider) ImmutableTree(version plc.TreeVersion) (store.ReadOnlyTree, error) { 22 - if version.IsMutable() { 23 - return store.AdaptMutableTree(t.tree), nil 24 - } 25 - var v int64 26 - if version.IsCommitted() { 27 - var err error 28 - v, err = t.tree.GetLatestVersion() 29 - if err != nil { 30 - return nil, stacktrace.Propagate(err, "") 31 - } 32 - } else { 33 - var ok bool 34 - v, ok = version.SpecificVersion() 35 - if !ok { 36 - return nil, stacktrace.NewError("unsupported TreeVersion") 37 - } 38 - } 39 - 40 - it, err := t.tree.GetImmutable(v) 41 - return store.AdaptImmutableTree(it), stacktrace.Propagate(err, "") 42 - } 43 - 44 - func (t *testTreeProvider) MutableTree() (*iavl.MutableTree, error) { 45 - return t.tree, nil 19 + indexDB := dbm.NewMemDB() 20 + return transaction.NewFactory(tree, indexDB, store.Tree.NextOperationSequence), tree, indexDB 46 21 }
-93
store/iavl_adapter.go
··· 1 - package store 2 - 3 - import ( 4 - "github.com/cosmos/iavl" 5 - ics23 "github.com/cosmos/ics23/go" 6 - "github.com/palantir/stacktrace" 7 - ) 8 - 9 - type ReadOnlyTree interface { 10 - Has(key []byte) (bool, error) 11 - Get(key []byte) ([]byte, error) 12 - GetProof(key []byte) (*ics23.CommitmentProof, error) // won't actually work on mutable trees, but we don't need it to 13 - IterateRange(start, end []byte, ascending bool, fn func(key []byte, value []byte) bool) (stopped bool) 14 - } 15 - 16 - type mutableToUnifiedTree struct { 17 - tree *iavl.MutableTree 18 - } 19 - 20 - var _ ReadOnlyTree = (*mutableToUnifiedTree)(nil) 21 - 22 - func AdaptMutableTree(tree *iavl.MutableTree) ReadOnlyTree { 23 - return &mutableToUnifiedTree{ 24 - tree: tree, 25 - } 26 - } 27 - 28 - // Has implements [ReadOnlyTree]. 29 - func (m *mutableToUnifiedTree) Has(key []byte) (bool, error) { 30 - return m.tree.Has(key) 31 - } 32 - 33 - // Get implements [ReadOnlyTree]. 34 - func (m *mutableToUnifiedTree) Get(key []byte) ([]byte, error) { 35 - return m.tree.Get(key) 36 - } 37 - 38 - // GetProof implements [ReadOnlyTree]. 39 - func (m *mutableToUnifiedTree) GetProof(key []byte) (*ics23.CommitmentProof, error) { 40 - return nil, stacktrace.NewError("proof calculation not possible over mutable tree") 41 - } 42 - 43 - // IterateRange implements [ReadOnlyTree]. 44 - func (m *mutableToUnifiedTree) IterateRange(start []byte, end []byte, ascending bool, fn func(key []byte, value []byte) bool) (stopped bool) { 45 - // it might look like MutableTree implements IterateRange but it doesn't, 46 - // most iteration methods actually come from the embedded ImmutableTree we're not meant to use 47 - // (terrible API) 48 - itr, err := m.tree.Iterator(start, end, ascending) 49 - if err != nil { 50 - return false 51 - } 52 - 53 - defer itr.Close() 54 - 55 - for ; itr.Valid(); itr.Next() { 56 - if fn(itr.Key(), itr.Value()) { 57 - return true 58 - } 59 - } 60 - return false 61 - } 62 - 63 - type immutableToUnifiedTree struct { 64 - tree *iavl.ImmutableTree 65 - } 66 - 67 - var _ ReadOnlyTree = (*immutableToUnifiedTree)(nil) 68 - 69 - func AdaptImmutableTree(tree *iavl.ImmutableTree) ReadOnlyTree { 70 - return &immutableToUnifiedTree{ 71 - tree: tree, 72 - } 73 - } 74 - 75 - // Has implements [ReadOnlyTree]. 76 - func (i *immutableToUnifiedTree) Has(key []byte) (bool, error) { 77 - return i.tree.Has(key) 78 - } 79 - 80 - // Get implements [ReadOnlyTree]. 81 - func (i *immutableToUnifiedTree) Get(key []byte) ([]byte, error) { 82 - return i.tree.Get(key) 83 - } 84 - 85 - // GetProof implements [ReadOnlyTree]. 86 - func (i *immutableToUnifiedTree) GetProof(key []byte) (*ics23.CommitmentProof, error) { 87 - return i.tree.GetProof(key) 88 - } 89 - 90 - // IterateRange implements [ReadOnlyTree]. 91 - func (i *immutableToUnifiedTree) IterateRange(start []byte, end []byte, ascending bool, fn func(key []byte, value []byte) bool) (stopped bool) { 92 - return i.tree.IterateRange(start, end, ascending, fn) 93 - }
+181 -105
store/tree.go
··· 11 11 "time" 12 12 13 13 "github.com/bluesky-social/indigo/atproto/syntax" 14 - "github.com/cosmos/iavl" 15 14 ics23 "github.com/cosmos/ics23/go" 16 15 "github.com/did-method-plc/go-didplc" 17 16 cbornode "github.com/ipfs/go-ipld-cbor" ··· 19 18 "github.com/polydawn/refmt/obj/atlas" 20 19 "github.com/samber/lo" 21 20 "github.com/samber/mo" 21 + "tangled.org/gbl08ma.com/didplcbft/transaction" 22 22 "tangled.org/gbl08ma.com/didplcbft/types" 23 23 ) 24 24 25 + // TODO rename to something more appropriate, now that this touches both the tree and the index 25 26 var Tree PLCTreeStore = &TreeStore{} 26 27 27 28 type PLCTreeStore interface { 28 - AuditLog(ctx context.Context, tree ReadOnlyTree, did string, withProof bool) ([]types.SequencedLogEntry, *ics23.CommitmentProof, error) 29 - AuditLogReverseIterator(ctx context.Context, tree ReadOnlyTree, did string, err *error) iter.Seq2[int, types.SequencedLogEntry] 30 - ExportOperations(ctx context.Context, tree ReadOnlyTree, after uint64, count int) ([]types.SequencedLogEntry, error) // passing a count of zero means unlimited 31 - StoreOperation(tree *iavl.MutableTree, entry didplc.LogEntry, nullifyWithIndexEqualOrGreaterThan mo.Option[int]) error 32 - SetOperationCreatedAt(tree *iavl.MutableTree, seqID uint64, createdAt time.Time) error 29 + AuditLog(ctx context.Context, tx transaction.Read, did string, withProof bool) ([]types.SequencedLogEntry, *ics23.CommitmentProof, error) 30 + AuditLogReverseIterator(ctx context.Context, tx transaction.Read, did string, err *error) iter.Seq[types.SequencedLogEntry] 31 + ExportOperations(ctx context.Context, tx transaction.Read, after uint64, count int) ([]types.SequencedLogEntry, error) // passing a count of zero means unlimited 32 + StoreOperation(ctx context.Context, tx transaction.Write, entry didplc.LogEntry, nullifyWithSequenceEqualOrGreaterThan mo.Option[uint64]) error 33 + SetOperationCreatedAt(tx transaction.Write, seqID uint64, createdAt time.Time) error 33 34 34 - AuthoritativePLC(tree ReadOnlyTree) (string, error) 35 - SetAuthoritativePLC(tree *iavl.MutableTree, url string) error 35 + NextOperationSequence(tx transaction.Read) (uint64, error) 36 36 37 - AuthoritativeImportProgress(tree ReadOnlyTree) (uint64, error) 38 - SetAuthoritativeImportProgress(tree *iavl.MutableTree, nextCursor uint64) error 37 + AuthoritativePLC(tx transaction.Read) (string, error) 38 + SetAuthoritativePLC(tx transaction.Write, url string) error 39 + 40 + AuthoritativeImportProgress(tx transaction.Read) (uint64, error) 41 + SetAuthoritativeImportProgress(tx transaction.Write, nextCursor uint64) error 39 42 } 40 43 41 44 var _ PLCTreeStore = (*TreeStore)(nil) ··· 43 46 // TreeStore exists just to groups methods nicely 44 47 type TreeStore struct{} 45 48 46 - func (t *TreeStore) AuditLog(ctx context.Context, tree ReadOnlyTree, did string, withProof bool) ([]types.SequencedLogEntry, *ics23.CommitmentProof, error) { 47 - proofs := []*ics23.CommitmentProof{} 48 - 49 + func (t *TreeStore) AuditLog(ctx context.Context, tx transaction.Read, did string, withProof bool) ([]types.SequencedLogEntry, *ics23.CommitmentProof, error) { 49 50 didBytes, err := DIDToBytes(did) 50 51 if err != nil { 51 52 return nil, nil, stacktrace.Propagate(err, "") 52 53 } 53 54 54 - logKey := marshalDIDLogKey(didBytes) 55 + didRangeStart := marshalDIDLogKey(didBytes, 0) 56 + didRangeEnd := marshalDIDLogKey(didBytes, math.MaxUint64) 55 57 56 - logOperations, err := tree.Get(logKey) 58 + didLogIterator, err := tx.IndexDB().Iterator(didRangeStart, didRangeEnd) 57 59 if err != nil { 58 60 return nil, nil, stacktrace.Propagate(err, "") 59 - } 60 - operationKeys := make([][]byte, 0, len(logOperations)/8) 61 - for seqBytes := range slices.Chunk(logOperations, 8) { 62 - operationKeys = append(operationKeys, sequenceBytesToOperationKey(seqBytes)) 63 61 } 64 62 65 - if withProof { 66 - proof, err := tree.GetProof(logKey) 67 - if err != nil { 68 - return nil, nil, stacktrace.Propagate(err, "") 69 - } 70 - proofs = append(proofs, proof) 71 - } 63 + defer didLogIterator.Close() 72 64 73 - logEntries := make([]types.SequencedLogEntry, 0, len(operationKeys)) 74 - for _, opKey := range operationKeys { 65 + logEntries := make([]types.SequencedLogEntry, 0, 1) 66 + proofs := []*ics23.CommitmentProof{} 67 + txHeight := uint64(tx.Height()) 68 + for didLogIterator.Valid() { 75 69 select { 76 70 case <-ctx.Done(): 77 71 return nil, nil, stacktrace.Propagate(ctx.Err(), "") 78 72 default: 79 73 } 80 74 81 - operationValue, err := tree.Get(opKey) 82 - if err != nil { 83 - return nil, nil, stacktrace.Propagate(err, "") 84 - } 75 + sequence := unmarshalDIDLogSequence(didLogIterator.Key()) 76 + validFromHeight, validToHeight := unmarshalDIDLogValue(didLogIterator.Value()) 77 + 78 + opKey := marshalOperationKey(sequence) 79 + 80 + if txHeight >= validFromHeight && txHeight <= validToHeight { 81 + operationValue, err := tx.Tree().Get(opKey) 82 + if err != nil { 83 + return nil, nil, stacktrace.Propagate(err, "") 84 + } 85 + 86 + if withProof { 87 + proof, err := tx.Tree().GetProof(opKey) 88 + if err != nil { 89 + return nil, nil, stacktrace.Propagate(err, "") 90 + } 91 + proofs = append(proofs, proof) 92 + } 85 93 86 - if withProof { 87 - proof, err := tree.GetProof(opKey) 94 + logEntry, err := unmarshalLogEntry(opKey, operationValue) 88 95 if err != nil { 89 96 return nil, nil, stacktrace.Propagate(err, "") 90 97 } 91 - proofs = append(proofs, proof) 92 - } 93 98 94 - logEntry, err := unmarshalLogEntry(opKey, operationValue) 95 - if err != nil { 96 - return nil, nil, stacktrace.Propagate(err, "") 99 + logEntries = append(logEntries, logEntry) 97 100 } 101 + didLogIterator.Next() 102 + } 98 103 99 - logEntries = append(logEntries, logEntry) 104 + err = didLogIterator.Error() 105 + if err != nil { 106 + return nil, nil, stacktrace.Propagate(err, "") 100 107 } 101 108 102 109 var combinedProof *ics23.CommitmentProof ··· 109 116 return logEntries, combinedProof, nil 110 117 } 111 118 112 - func (t *TreeStore) AuditLogReverseIterator(ctx context.Context, tree ReadOnlyTree, did string, retErr *error) iter.Seq2[int, types.SequencedLogEntry] { 113 - return func(yield func(int, types.SequencedLogEntry) bool) { 119 + func (t *TreeStore) AuditLogReverseIterator(ctx context.Context, tx transaction.Read, did string, retErr *error) iter.Seq[types.SequencedLogEntry] { 120 + return func(yield func(types.SequencedLogEntry) bool) { 114 121 didBytes, err := DIDToBytes(did) 115 122 if err != nil { 116 123 *retErr = stacktrace.Propagate(err, "") 117 124 return 118 125 } 119 126 120 - logKey := marshalDIDLogKey(didBytes) 127 + didRangeStart := marshalDIDLogKey(didBytes, 0) 128 + didRangeEnd := marshalDIDLogKey(didBytes, math.MaxUint64) 121 129 122 - logOperations, err := tree.Get(logKey) 130 + didLogIterator, err := tx.IndexDB().ReverseIterator(didRangeStart, didRangeEnd) 123 131 if err != nil { 124 132 *retErr = stacktrace.Propagate(err, "") 125 133 return 126 134 } 127 - operationKeys := make([][]byte, 0, len(logOperations)/8) 128 - for seqBytes := range slices.Chunk(logOperations, 8) { 129 - operationKeys = append(operationKeys, sequenceBytesToOperationKey(seqBytes)) 130 - } 135 + 136 + defer didLogIterator.Close() 137 + 138 + txHeight := uint64(tx.Height()) 131 139 132 - for i := len(operationKeys) - 1; i >= 0; i-- { 140 + for didLogIterator.Valid() { 133 141 select { 134 142 case <-ctx.Done(): 135 143 *retErr = stacktrace.Propagate(ctx.Err(), "") ··· 137 145 default: 138 146 } 139 147 140 - opKey := operationKeys[i] 141 - operationValue, err := tree.Get(opKey) 142 - if err != nil { 143 - *retErr = stacktrace.Propagate(err, "") 144 - return 145 - } 148 + sequence := unmarshalDIDLogSequence(didLogIterator.Key()) 149 + validFromHeight, validToHeight := unmarshalDIDLogValue(didLogIterator.Value()) 150 + 151 + opKey := marshalOperationKey(sequence) 146 152 147 - logEntry, err := unmarshalLogEntry(opKey, operationValue) 148 - if err != nil { 149 - *retErr = stacktrace.Propagate(err, "") 150 - return 151 - } 153 + if txHeight >= validFromHeight && txHeight <= validToHeight { 154 + operationValue, err := tx.Tree().Get(opKey) 155 + if err != nil { 156 + *retErr = stacktrace.Propagate(err, "") 157 + return 158 + } 159 + 160 + logEntry, err := unmarshalLogEntry(opKey, operationValue) 161 + if err != nil { 162 + *retErr = stacktrace.Propagate(err, "") 163 + return 164 + } 152 165 153 - if !yield(i, logEntry) { 154 - return 166 + if !yield(logEntry) { 167 + return 168 + } 155 169 } 170 + didLogIterator.Next() 171 + } 172 + 173 + err = didLogIterator.Error() 174 + if err != nil { 175 + *retErr = stacktrace.Propagate(err, "") 156 176 } 157 177 } 158 178 } 159 179 160 - func (t *TreeStore) ExportOperations(ctx context.Context, tree ReadOnlyTree, after uint64, count int) ([]types.SequencedLogEntry, error) { 180 + func (t *TreeStore) ExportOperations(ctx context.Context, tx transaction.Read, after uint64, count int) ([]types.SequencedLogEntry, error) { 161 181 // as the name suggests, after is an exclusive lower bound, but our iterators use inclusive lower bounds 162 182 start := after + 1 163 183 startKey := marshalOperationKey(start) ··· 165 185 166 186 entries := make([]types.SequencedLogEntry, 0, count) 167 187 var iterErr error 168 - tree.IterateRange(startKey, endKey, true, func(operationKey, operationValue []byte) bool { 188 + tx.Tree().IterateRange(startKey, endKey, true, func(operationKey, operationValue []byte) bool { 169 189 select { 170 190 case <-ctx.Done(): 171 191 iterErr = stacktrace.Propagate(ctx.Err(), "") ··· 188 208 return entries, nil 189 209 } 190 210 191 - func (t *TreeStore) StoreOperation(tree *iavl.MutableTree, entry didplc.LogEntry, nullifyWithIndexEqualOrGreaterThan mo.Option[int]) error { 211 + // StoreOperation stores an operation in the tree, nullifying existing operations whose index within the DID's history 212 + // is lower or equal to the specified optional integer. 213 + // 214 + // Even though this function is not meant to overwrite operations (it will error) we ask the caller to provide the sequence 215 + // The caller is responsible for managing the sequence and invalidating it when needed (e.g. after a rollback) using 216 + // [[TreeStore.NextOperationSequence]]. 217 + // Pushing the responsibility to the caller is preferable in terms of performance, even if it leads to less safe code, 218 + // because getting the sequence from the tree within this function every time has a significant performance hit 219 + func (t *TreeStore) StoreOperation(ctx context.Context, tx transaction.Write, entry didplc.LogEntry, nullifyWithSequenceEqualOrGreaterThan mo.Option[uint64]) error { 192 220 didBytes, err := DIDToBytes(entry.DID) 193 221 if err != nil { 194 222 return stacktrace.Propagate(err, "") 195 223 } 196 224 197 - logKey := marshalDIDLogKey(didBytes) 225 + txHeight := uint64(tx.Height()) 198 226 199 - logOperations, err := tree.Get(logKey) 200 - logOperations = slices.Clone(logOperations) 227 + if nullifyEGt, ok := nullifyWithSequenceEqualOrGreaterThan.Get(); ok { 228 + didRangeStart := marshalDIDLogKey(didBytes, nullifyEGt) 229 + didRangeEnd := marshalDIDLogKey(didBytes, math.MaxUint64) 201 230 202 - if nullifyEGt, ok := nullifyWithIndexEqualOrGreaterThan.Get(); ok { 203 - var operationKeys [][]byte 231 + didLogIterator, err := tx.IndexDB().ReverseIterator(didRangeStart, didRangeEnd) 204 232 if err != nil { 205 - operationKeys = [][]byte{} 206 - } else { 207 - operationKeys = make([][]byte, 0, len(logOperations)/8) 208 - for seqBytes := range slices.Chunk(logOperations, 8) { 209 - operationKeys = append(operationKeys, sequenceBytesToOperationKey(seqBytes)) 210 - } 233 + return stacktrace.Propagate(err, "") 211 234 } 212 235 213 - for _, opKey := range operationKeys[nullifyEGt:] { 214 - operationValue, err := tree.Get(opKey) 236 + defer didLogIterator.Close() 237 + 238 + txHeight := uint64(tx.Height()) 239 + 240 + for didLogIterator.Valid() { 241 + select { 242 + case <-ctx.Done(): 243 + return stacktrace.Propagate(ctx.Err(), "") 244 + default: 245 + } 246 + 247 + sequence := unmarshalDIDLogSequence(didLogIterator.Key()) 248 + validFromHeight, validToHeight := unmarshalDIDLogValue(didLogIterator.Value()) 249 + 250 + opKey := marshalOperationKey(sequence) 251 + 252 + if txHeight < validFromHeight || txHeight > validToHeight { 253 + // ignore ops that are invisible at this height 254 + didLogIterator.Next() 255 + continue 256 + } 257 + 258 + operationValue, err := tx.Tree().Get(opKey) 215 259 if err != nil { 216 260 return stacktrace.Propagate(err, "") 217 261 } 218 262 operationValue = slices.Clone(operationValue) 219 263 operationValue[0] = 1 220 264 221 - _, err = tree.Set(opKey, operationValue) 265 + updated, err := tx.Tree().Set(opKey, operationValue) 222 266 if err != nil { 223 267 return stacktrace.Propagate(err, "") 224 268 } 269 + if !updated { 270 + // if we get to this point we have a mistake in our program, and the data is now inconsistent 271 + // we are not supposed to be able to recover from this error without rolling back the tree 272 + return stacktrace.NewError("expected to be updating an existing operation key but wrote new one instead") 273 + } 274 + 275 + didLogIterator.Next() 276 + } 277 + err = didLogIterator.Error() 278 + if err != nil { 279 + return stacktrace.Propagate(err, "") 225 280 } 226 281 } 227 282 ··· 230 285 return stacktrace.Propagate(err, "invalid CreatedAt") 231 286 } 232 287 233 - seq, err := getNextSeqID(tree) 288 + sequence, err := tx.NextSequence() 234 289 if err != nil { 235 290 return stacktrace.Propagate(err, "") 236 291 } 237 292 238 293 operation := entry.Operation.AsOperation() 239 - opKey := marshalOperationKey(seq) 294 + opKey := marshalOperationKey(sequence) 240 295 opValue := marshalOperationValue(entry.Nullified, didBytes, opDatetime.Time(), operation) 241 296 242 - updated, err := tree.Set(opKey, opValue) 297 + updated, err := tx.Tree().Set(opKey, opValue) 243 298 if err != nil { 244 299 return stacktrace.Propagate(err, "") 245 300 } 246 301 if updated { 302 + // if we get to this point we have a mistake in our program, and the data is now inconsistent 303 + // we are not supposed to be able to recover from this error without rolling back the tree 247 304 return stacktrace.NewError("expected to be writing to a new operation key but updated instead") 248 305 } 249 306 250 - logOperations = append(logOperations, opKey[1:9]...) 251 - _, err = tree.Set(logKey, logOperations) 307 + logKey := marshalDIDLogKey(didBytes, sequence) 308 + logValue := marshalDIDLogValue(txHeight, math.MaxUint64) 309 + 310 + err = tx.IndexDB().Set(logKey, logValue) 252 311 if err != nil { 253 312 return stacktrace.Propagate(err, "") 254 313 } ··· 256 315 return nil 257 316 } 258 317 259 - func (t *TreeStore) SetOperationCreatedAt(tree *iavl.MutableTree, seqID uint64, createdAt time.Time) error { 318 + func (t *TreeStore) SetOperationCreatedAt(tx transaction.Write, seqID uint64, createdAt time.Time) error { 260 319 opKey := marshalOperationKey(seqID) 261 320 262 - opValue, err := tree.Get(opKey) 321 + opValue, err := tx.Tree().Get(opKey) 263 322 if err != nil { 264 323 return stacktrace.Propagate(err, "") 265 324 } ··· 272 331 ts := uint64(createdAt.Truncate(1 * time.Millisecond).UTC().UnixNano()) 273 332 binary.BigEndian.PutUint64(opValue[16:24], ts) 274 333 275 - _, err = tree.Set(opKey, opValue) 334 + updated, err := tx.Tree().Set(opKey, opValue) 335 + if !updated { 336 + // if we get to this point we have a mistake in our program, and the data is now inconsistent 337 + // we are not supposed to be able to recover from this error without rolling back the tree 338 + return stacktrace.NewError("expected to be updating an existing operation key but wrote new one instead") 339 + } 276 340 return stacktrace.Propagate(err, "") 277 341 } 278 342 279 343 var minOperationKey = marshalOperationKey(0) 280 344 var maxOperationKey = marshalOperationKey(math.MaxInt64) 281 345 282 - func getNextSeqID(tree *iavl.MutableTree) (uint64, error) { 346 + func (t *TreeStore) NextOperationSequence(tx transaction.Read) (uint64, error) { 283 347 seq := uint64(0) 284 348 285 - itr, err := tree.Iterator(minOperationKey, maxOperationKey, false) 349 + itr, err := tx.Tree().Iterator(minOperationKey, maxOperationKey, false) 286 350 if err != nil { 287 351 return 0, stacktrace.Propagate(err, "") 288 352 } ··· 326 390 return did, nil 327 391 } 328 392 329 - func marshalDIDLogKey(didBytes []byte) []byte { 330 - key := make([]byte, 1+15) 393 + func marshalDIDLogKey(didBytes []byte, sequence uint64) []byte { 394 + key := make([]byte, 1+15+8) 331 395 key[0] = 'l' 332 - copy(key[1:], didBytes) 396 + copy(key[1:16], didBytes) 397 + binary.BigEndian.PutUint64(key[16:], sequence) 333 398 return key 334 399 } 335 400 336 - func sequenceBytesToOperationKey(sequenceBytes []byte) []byte { 337 - key := make([]byte, 1+8) 338 - key[0] = 'o' 339 - copy(key[1:9], sequenceBytes) 340 - return key 401 + func unmarshalDIDLogSequence(logKey []byte) uint64 { 402 + return binary.BigEndian.Uint64(logKey[16:24]) 403 + } 404 + 405 + // validFromHeight, validToHeight are inclusive (i.e. if the former is 5 and the latter is 10, the value was valid at height 5 and 10, but not at 4 or 11) 406 + func marshalDIDLogValue(validFromHeight, validToHeight uint64) []byte { 407 + value := make([]byte, 8+8) 408 + binary.BigEndian.PutUint64(value, validFromHeight) 409 + binary.BigEndian.PutUint64(value[8:], validToHeight) 410 + return value 411 + } 412 + 413 + func unmarshalDIDLogValue(value []byte) (validFromHeight, validToHeight uint64) { 414 + validFromHeight = binary.BigEndian.Uint64(value[0:8]) 415 + validToHeight = binary.BigEndian.Uint64(value[8:16]) 416 + return 341 417 } 342 418 343 419 func marshalOperationKey(sequence uint64) []byte { ··· 461 537 Complete()) 462 538 } 463 539 464 - func (t *TreeStore) AuthoritativePLC(tree ReadOnlyTree) (string, error) { 465 - url, err := tree.Get([]byte("aPLCURL")) 540 + func (t *TreeStore) AuthoritativePLC(tx transaction.Read) (string, error) { 541 + url, err := tx.Tree().Get([]byte("aPLCURL")) 466 542 if err != nil { 467 543 return "", stacktrace.Propagate(err, "") 468 544 } ··· 472 548 return string(url), nil 473 549 } 474 550 475 - func (t *TreeStore) SetAuthoritativePLC(tree *iavl.MutableTree, url string) error { 476 - _, err := tree.Set([]byte("aPLCURL"), []byte(url)) 551 + func (t *TreeStore) SetAuthoritativePLC(tx transaction.Write, url string) error { 552 + _, err := tx.Tree().Set([]byte("aPLCURL"), []byte(url)) 477 553 return stacktrace.Propagate(err, "") 478 554 } 479 555 480 - func (t *TreeStore) AuthoritativeImportProgress(tree ReadOnlyTree) (uint64, error) { 481 - progBytes, err := tree.Get([]byte("aImportProgress")) 556 + func (t *TreeStore) AuthoritativeImportProgress(tx transaction.Read) (uint64, error) { 557 + progBytes, err := tx.Tree().Get([]byte("aImportProgress")) 482 558 if err != nil { 483 559 return 0, stacktrace.Propagate(err, "") 484 560 } ··· 488 564 return binary.BigEndian.Uint64(progBytes), nil 489 565 } 490 566 491 - func (t *TreeStore) SetAuthoritativeImportProgress(tree *iavl.MutableTree, nextCursor uint64) error { 567 + func (t *TreeStore) SetAuthoritativeImportProgress(tx transaction.Write, nextCursor uint64) error { 492 568 value := make([]byte, 8) 493 569 binary.BigEndian.PutUint64(value, nextCursor) 494 570 495 - _, err := tree.Set([]byte("aImportProgress"), value) 571 + _, err := tx.Tree().Set([]byte("aImportProgress"), value) 496 572 return stacktrace.Propagate(err, "") 497 573 }
+33
transaction/height.go
··· 1 + package transaction 2 + 3 + type Height struct { 4 + workingHeight bool 5 + committedHeight bool 6 + specificHeight int64 7 + } 8 + 9 + func (v Height) IsMutable() bool { 10 + return v.workingHeight 11 + } 12 + 13 + func (v Height) IsCommitted() bool { 14 + return v.committedHeight 15 + } 16 + 17 + func (v Height) SpecificVersion() (int64, bool) { 18 + return v.specificHeight, !v.workingHeight && !v.committedHeight 19 + } 20 + 21 + var WorkingHeight = Height{ 22 + workingHeight: true, 23 + } 24 + 25 + var CommittedHeight = Height{ 26 + committedHeight: true, 27 + } 28 + 29 + func SpecificHeight(height int64) Height { 30 + return Height{ 31 + specificHeight: height, 32 + } 33 + }
+57
transaction/iavl_adapter.go
··· 1 + package transaction 2 + 3 + import ( 4 + corestore "cosmossdk.io/core/store" 5 + "github.com/cosmos/iavl" 6 + ics23 "github.com/cosmos/ics23/go" 7 + ) 8 + 9 + type ReadTree interface { 10 + Has(key []byte) (bool, error) 11 + Get(key []byte) ([]byte, error) 12 + GetProof(key []byte) (*ics23.CommitmentProof, error) // won't actually work on mutable trees, but we don't need it to 13 + Iterator(start, end []byte, ascending bool) (corestore.Iterator, error) 14 + IterateRange(start, end []byte, ascending bool, fn func(key []byte, value []byte) bool) (stopped bool) 15 + } 16 + 17 + type UnifiedTree interface { 18 + ReadTree 19 + Set(key, value []byte) (bool, error) 20 + } 21 + 22 + type immutableToReadOnlyTree struct { 23 + tree *iavl.ImmutableTree 24 + } 25 + 26 + var _ ReadTree = (*immutableToReadOnlyTree)(nil) 27 + 28 + func AdaptImmutableTree(tree *iavl.ImmutableTree) ReadTree { 29 + return &immutableToReadOnlyTree{ 30 + tree: tree, 31 + } 32 + } 33 + 34 + // Has implements [ReadTree]. 35 + func (i *immutableToReadOnlyTree) Has(key []byte) (bool, error) { 36 + return i.tree.Has(key) 37 + } 38 + 39 + // Get implements [ReadTree]. 40 + func (i *immutableToReadOnlyTree) Get(key []byte) ([]byte, error) { 41 + return i.tree.Get(key) 42 + } 43 + 44 + // GetProof implements [ReadTree]. 45 + func (i *immutableToReadOnlyTree) GetProof(key []byte) (*ics23.CommitmentProof, error) { 46 + return i.tree.GetProof(key) 47 + } 48 + 49 + // IterateRange implements [ReadTree]. 50 + func (i *immutableToReadOnlyTree) IterateRange(start []byte, end []byte, ascending bool, fn func(key []byte, value []byte) bool) (stopped bool) { 51 + return i.tree.IterateRange(start, end, ascending, fn) 52 + } 53 + 54 + // Iterator implements [ReadTree]. 55 + func (m *immutableToReadOnlyTree) Iterator(start, end []byte, ascending bool) (corestore.Iterator, error) { 56 + return m.tree.Iterator(start, end, ascending) 57 + }
+46
transaction/interface.go
··· 1 + package transaction 2 + 3 + import ( 4 + "time" 5 + 6 + dbm "github.com/cometbft/cometbft-db" 7 + ) 8 + 9 + type Read interface { 10 + Height() int64 11 + Timestamp() time.Time 12 + 13 + Tree() ReadTree 14 + IndexDB() ReadIndex 15 + 16 + Upgrade() (Write, error) 17 + } 18 + 19 + type Write interface { 20 + Height() int64 21 + Timestamp() time.Time 22 + 23 + NextSequence() (uint64, error) 24 + Tree() UnifiedTree 25 + IndexDB() WriteIndex 26 + 27 + Commit() error 28 + Rollback() error 29 + 30 + Downgrade() Read 31 + } 32 + 33 + type ReadIndex interface { 34 + Get([]byte) ([]byte, error) 35 + Has(key []byte) (bool, error) 36 + 37 + Iterator(start, end []byte) (dbm.Iterator, error) 38 + ReverseIterator(start, end []byte) (dbm.Iterator, error) 39 + } 40 + 41 + type WriteIndex interface { 42 + ReadIndex 43 + 44 + Delete([]byte) error 45 + Set([]byte, []byte) error 46 + }
+39
transaction/read_on_write_tx.go
··· 1 + package transaction 2 + 3 + import ( 4 + "time" 5 + ) 6 + 7 + // readOnWriteTx is created from a write tx to allow a write tx to be passed to functions that accept a read-only transaction 8 + // we can't outright use the original read transaction because we want to read uncommitted values from IndexDB 9 + type readOnWriteTx struct { 10 + w *writeTx 11 + } 12 + 13 + // Height implements [Read]. 14 + func (d *readOnWriteTx) Height() int64 { 15 + return d.w.readTx.height 16 + } 17 + 18 + // IndexDB implements [Read]. 19 + func (d *readOnWriteTx) IndexDB() ReadIndex { 20 + // by returning the write index we get the read uncommitted behavior we want 21 + d.w.createWriteIndexIfNeeded() 22 + return d.w.writeIndex 23 + } 24 + 25 + // Timestamp implements [Read]. 26 + func (d *readOnWriteTx) Timestamp() time.Time { 27 + return d.w.readTx.ts 28 + } 29 + 30 + // Tree implements [Read]. 31 + func (d *readOnWriteTx) Tree() ReadTree { 32 + // we can return the mutable tree as-is because it already presents read unsaved behavior 33 + return d.w.readTx.mutableTree 34 + } 35 + 36 + // Upgrade implements [Read]. 37 + func (d *readOnWriteTx) Upgrade() (Write, error) { 38 + return d.w, nil 39 + }
+101
transaction/read_tx.go
··· 1 + package transaction 2 + 3 + import ( 4 + "time" 5 + 6 + dbm "github.com/cometbft/cometbft-db" 7 + "github.com/cosmos/iavl" 8 + "github.com/palantir/stacktrace" 9 + ) 10 + 11 + type Factory struct { 12 + db dbm.DB 13 + tree *iavl.MutableTree 14 + sequenceGetter func(tx Read) (uint64, error) 15 + } 16 + 17 + func NewFactory(tree *iavl.MutableTree, indexDB dbm.DB, sequenceGetter func(tx Read) (uint64, error)) *Factory { 18 + return &Factory{ 19 + db: indexDB, 20 + tree: tree, 21 + sequenceGetter: sequenceGetter, 22 + } 23 + } 24 + 25 + type readTx struct { 26 + ts time.Time 27 + height int64 28 + 29 + mutableTree *iavl.MutableTree // only present if upgradable 30 + tree ReadTree 31 + db dbm.DB 32 + 33 + sequenceGetter func(tx Read) (uint64, error) 34 + } 35 + 36 + func (f *Factory) ReadWorking(ts time.Time) Read { 37 + return &readTx{ 38 + ts: ts, 39 + height: f.tree.WorkingVersion(), 40 + mutableTree: f.tree, 41 + db: f.db, 42 + sequenceGetter: f.sequenceGetter, 43 + } 44 + } 45 + 46 + func (f *Factory) ReadCommitted() Read { 47 + tx, err := f.ReadHeight(time.Now(), f.tree.Version()) 48 + if err != nil { 49 + // this should never happen, it's not worth making the signature of this function more 50 + // complex for an error we'll never return unless the ABCI application is yet to be initialized 51 + panic(stacktrace.Propagate(err, "")) 52 + } 53 + return tx 54 + } 55 + 56 + func (f *Factory) ReadHeight(ts time.Time, height int64) (Read, error) { 57 + immutable, err := f.tree.GetImmutable(height) 58 + if err != nil { 59 + return nil, stacktrace.Propagate(err, "") 60 + } 61 + return &readTx{ 62 + ts: ts, 63 + height: height, 64 + tree: AdaptImmutableTree(immutable), 65 + db: f.db, 66 + sequenceGetter: f.sequenceGetter, 67 + }, nil 68 + } 69 + 70 + // Height implements [Read]. 71 + func (t *readTx) Height() int64 { 72 + return t.height 73 + } 74 + 75 + // Timestamp implements [Read]. 76 + func (t *readTx) Timestamp() time.Time { 77 + return t.ts 78 + } 79 + 80 + // Tree implements [Read]. 81 + func (t *readTx) Tree() ReadTree { 82 + if t.mutableTree != nil { 83 + return t.mutableTree 84 + } 85 + return t.tree 86 + } 87 + 88 + // IndexDB implements [Read]. 89 + func (t *readTx) IndexDB() ReadIndex { 90 + return t.db 91 + } 92 + 93 + // Upgrade implements [Read]. 94 + func (t *readTx) Upgrade() (Write, error) { 95 + if t.mutableTree == nil { 96 + return nil, stacktrace.NewError("historical transaction is not upgradable to a write transaction") 97 + } 98 + return &writeTx{ 99 + readTx: t, 100 + }, nil 101 + }
+315
transaction/write_index.go
··· 1 + package transaction 2 + 3 + import ( 4 + "bytes" 5 + "slices" 6 + "sort" 7 + "unsafe" 8 + 9 + "cosmossdk.io/core/store" 10 + dbm "github.com/cometbft/cometbft-db" 11 + "github.com/palantir/stacktrace" 12 + ) 13 + 14 + // writeIndex provides write transactions with read uncommitted behavior 15 + type writeIndex struct { 16 + batch dbm.Batch 17 + db dbm.DB 18 + 19 + unsavedAdditions map[string][]byte 20 + unsavedRemovals map[string]struct{} 21 + } 22 + 23 + // Delete implements [WriteIndex]. 24 + func (w *writeIndex) Delete(key []byte) error { 25 + err := w.batch.Delete(key) 26 + if err != nil { 27 + return stacktrace.Propagate(err, "") 28 + } 29 + 30 + kstr := unsafeBytesToStr(key) 31 + w.unsavedRemovals[kstr] = struct{}{} 32 + delete(w.unsavedAdditions, kstr) 33 + 34 + return nil 35 + } 36 + 37 + // Set implements [WriteIndex]. 38 + func (w *writeIndex) Set(key []byte, value []byte) error { 39 + err := w.batch.Set(key, value) 40 + if err != nil { 41 + return stacktrace.Propagate(err, "") 42 + } 43 + 44 + kstr := unsafeBytesToStr(key) 45 + w.unsavedAdditions[kstr] = value 46 + delete(w.unsavedRemovals, kstr) 47 + 48 + return nil 49 + } 50 + 51 + // Get implements [WriteIndex]. 52 + func (w *writeIndex) Get(key []byte) ([]byte, error) { 53 + kstr := unsafeBytesToStr(key) 54 + 55 + if _, ok := w.unsavedRemovals[kstr]; ok { 56 + return nil, nil 57 + } 58 + if v, ok := w.unsavedAdditions[kstr]; ok { 59 + return v, nil 60 + } 61 + 62 + v, err := w.db.Get(key) 63 + return v, stacktrace.Propagate(err, "") 64 + } 65 + 66 + // Has implements [WriteIndex]. 67 + func (w *writeIndex) Has(key []byte) (bool, error) { 68 + kstr := unsafeBytesToStr(key) 69 + 70 + if _, ok := w.unsavedRemovals[kstr]; ok { 71 + return false, nil 72 + } 73 + if _, ok := w.unsavedAdditions[kstr]; ok { 74 + return true, nil 75 + } 76 + 77 + v, err := w.db.Has(key) 78 + return v, stacktrace.Propagate(err, "") 79 + } 80 + 81 + // Iterator implements [WriteIndex]. 82 + func (w *writeIndex) Iterator(start []byte, end []byte) (dbm.Iterator, error) { 83 + v, err := newUnsavedIterator(start, end, true, w.db, w.unsavedAdditions, w.unsavedRemovals) 84 + return v, stacktrace.Propagate(err, "") 85 + } 86 + 87 + // ReverseIterator implements [WriteIndex]. 88 + func (w *writeIndex) ReverseIterator(start []byte, end []byte) (dbm.Iterator, error) { 89 + v, err := newUnsavedIterator(start, end, false, w.db, w.unsavedAdditions, w.unsavedRemovals) 90 + return v, stacktrace.Propagate(err, "") 91 + } 92 + 93 + func (w *writeIndex) Commit() error { 94 + err := w.batch.Write() 95 + if err != nil { 96 + return stacktrace.Propagate(err, "") 97 + } 98 + 99 + err = w.batch.Close() 100 + if err != nil { 101 + return stacktrace.Propagate(err, "") 102 + } 103 + return nil 104 + } 105 + 106 + func (w *writeIndex) Rollback() error { 107 + err := w.batch.Close() 108 + if err != nil { 109 + return stacktrace.Propagate(err, "") 110 + } 111 + return nil 112 + } 113 + 114 + type unsavedIterator struct { 115 + start, end []byte 116 + ascending bool 117 + err error 118 + nextKey []byte 119 + nextVal []byte 120 + underlyingIterator dbm.Iterator 121 + firstUnderlingNextDone bool 122 + 123 + nextUnsavedNodeIdx int 124 + unsavedKeyAdditions map[string][]byte 125 + unsavedKeyRemovals map[string]struct{} 126 + unsavedKeysToSort []string 127 + } 128 + 129 + var _ store.Iterator = (*unsavedIterator)(nil) 130 + 131 + func newUnsavedIterator(start, end []byte, ascending bool, db dbm.DB, unsavedNodeAdditions map[string][]byte, unsavedNodeRemovals map[string]struct{}) (*unsavedIterator, error) { 132 + iter := &unsavedIterator{ 133 + start: start, 134 + end: end, 135 + ascending: ascending, 136 + unsavedKeyAdditions: unsavedNodeAdditions, 137 + unsavedKeyRemovals: unsavedNodeRemovals, 138 + nextKey: nil, 139 + nextVal: nil, 140 + nextUnsavedNodeIdx: 0, 141 + } 142 + 143 + var err error 144 + if ascending { 145 + iter.underlyingIterator, err = db.Iterator(start, end) 146 + } else { 147 + iter.underlyingIterator, err = db.ReverseIterator(start, end) 148 + } 149 + if err != nil { 150 + return nil, stacktrace.Propagate(err, "") 151 + } 152 + 153 + // We need to ensure that we iterate over saved and unsaved state in order. 154 + // The strategy is to sort unsaved keys, the keys on dbm.DB are already sorted. 155 + // Then, we keep a pointer to both the unsaved and saved keys, and iterate over them in order efficiently. 156 + for k := range unsavedNodeAdditions { 157 + kbytes := unsafeStrToBytes(k) 158 + if start != nil && bytes.Compare(kbytes, start) < 0 { 159 + continue 160 + } 161 + 162 + if end != nil && bytes.Compare(kbytes, end) >= 0 { 163 + continue 164 + } 165 + 166 + iter.unsavedKeysToSort = append(iter.unsavedKeysToSort, k) 167 + } 168 + 169 + sort.Slice(iter.unsavedKeysToSort, func(i, j int) bool { 170 + if ascending { 171 + return iter.unsavedKeysToSort[i] < iter.unsavedKeysToSort[j] 172 + } 173 + return iter.unsavedKeysToSort[i] > iter.unsavedKeysToSort[j] 174 + }) 175 + 176 + // Move to the first element 177 + iter.Next() 178 + 179 + return iter, nil 180 + } 181 + 182 + // Domain implements [[dbm.Iterator]]. 183 + func (iter *unsavedIterator) Domain() ([]byte, []byte) { 184 + return iter.start, iter.end 185 + } 186 + 187 + // Valid implements [[dbm.Iterator]]. 188 + func (iter *unsavedIterator) Valid() bool { 189 + if iter.start != nil && iter.end != nil { 190 + if bytes.Compare(iter.end, iter.start) != 1 { 191 + return false 192 + } 193 + } 194 + 195 + return iter.underlyingIterator.Valid() || iter.nextUnsavedNodeIdx < len(iter.unsavedKeysToSort) || (iter.nextKey != nil && iter.nextVal != nil) 196 + } 197 + 198 + // Key implements [[dbm.Iterator]] 199 + func (iter *unsavedIterator) Key() []byte { 200 + return iter.nextKey 201 + } 202 + 203 + // Value implements [[dbm.Iterator]] 204 + func (iter *unsavedIterator) Value() []byte { 205 + return iter.nextVal 206 + } 207 + 208 + // Next implements [[dbm.Iterator]] 209 + // It's effectively running the constant space overhead algorithm for streaming through sorted lists: 210 + // the sorted lists being underlying keys & unsavedKeyAdditions / unsavedKeyRemovals 211 + func (iter *unsavedIterator) Next() { 212 + if iter.underlyingIterator.Valid() && iter.nextUnsavedNodeIdx < len(iter.unsavedKeysToSort) { 213 + diskKey := iter.underlyingIterator.Key() 214 + diskKeyStr := unsafeBytesToStr(diskKey) 215 + if _, ok := iter.unsavedKeyRemovals[diskKeyStr]; ok { 216 + // If next underlying key is to be removed, skip it. 217 + iter.underlyingIterator.Next() 218 + iter.Next() 219 + return 220 + } 221 + 222 + nextUnsavedKey := iter.unsavedKeysToSort[iter.nextUnsavedNodeIdx] 223 + nextUnsavedVal, _ := iter.unsavedKeyAdditions[nextUnsavedKey] 224 + 225 + var isUnsavedNext bool 226 + if iter.ascending { 227 + isUnsavedNext = diskKeyStr >= nextUnsavedKey 228 + } else { 229 + isUnsavedNext = diskKeyStr <= nextUnsavedKey 230 + } 231 + 232 + if isUnsavedNext { 233 + // Unsaved key is next 234 + if diskKeyStr == nextUnsavedKey { 235 + // Unsaved update prevails over saved copy so we skip the copy from the underlying iterator 236 + iter.underlyingIterator.Next() 237 + } 238 + 239 + iter.nextKey = unsafeStrToBytes(nextUnsavedKey) 240 + iter.nextVal = nextUnsavedVal 241 + 242 + iter.nextUnsavedNodeIdx++ 243 + return 244 + } 245 + // Underlying key is next 246 + iter.nextKey = slices.Clone(diskKey) 247 + iter.nextVal = slices.Clone(iter.underlyingIterator.Value()) 248 + 249 + iter.underlyingIterator.Next() 250 + return 251 + } 252 + 253 + // if only nodes on disk are left, we return them 254 + if iter.underlyingIterator.Valid() { 255 + diskKey := iter.underlyingIterator.Key() 256 + diskKeyStr := unsafeBytesToStr(diskKey) 257 + if _, ok := iter.unsavedKeyRemovals[diskKeyStr]; ok { 258 + // If next underlying key is to be removed, skip it. 259 + iter.underlyingIterator.Next() 260 + iter.Next() 261 + return 262 + } 263 + 264 + iter.nextKey = slices.Clone(diskKey) 265 + iter.nextVal = slices.Clone(iter.underlyingIterator.Value()) 266 + 267 + iter.underlyingIterator.Next() 268 + return 269 + } 270 + 271 + // if only unsaved nodes are left, we can just iterate 272 + if iter.nextUnsavedNodeIdx < len(iter.unsavedKeysToSort) { 273 + nextUnsavedKey := iter.unsavedKeysToSort[iter.nextUnsavedNodeIdx] 274 + nextUnsavedNodeVal, _ := iter.unsavedKeyAdditions[nextUnsavedKey] 275 + 276 + iter.nextKey = unsafeStrToBytes(nextUnsavedKey) 277 + iter.nextVal = nextUnsavedNodeVal 278 + 279 + iter.nextUnsavedNodeIdx++ 280 + return 281 + } 282 + 283 + iter.nextKey = nil 284 + iter.nextVal = nil 285 + } 286 + 287 + // Close implements [[dbm.Iterator]] 288 + func (iter *unsavedIterator) Close() error { 289 + return stacktrace.Propagate(iter.underlyingIterator.Close(), "") 290 + } 291 + 292 + // Error implements [[dbm.Iterator]] 293 + func (iter *unsavedIterator) Error() error { 294 + return iter.err 295 + } 296 + 297 + // unsafeStrToBytes uses unsafe to convert string into byte array. Returned bytes 298 + // must not be altered after this function is called as it will cause a segmentation fault. 299 + func unsafeStrToBytes(s string) []byte { 300 + if len(s) == 0 { 301 + return nil 302 + } 303 + return unsafe.Slice(unsafe.StringData(s), len(s)) 304 + } 305 + 306 + // unsafeBytesToStr is meant to make a zero allocation conversion 307 + // from []byte -> string to speed up operations, it is not meant 308 + // to be used generally, but for a specific pattern to delete keys 309 + // from a map. 310 + func unsafeBytesToStr(b []byte) string { 311 + if len(b) == 0 { 312 + return "" 313 + } 314 + return unsafe.String(&b[0], len(b)) 315 + }
+99
transaction/write_tx.go
··· 1 + package transaction 2 + 3 + import ( 4 + "time" 5 + 6 + "github.com/palantir/stacktrace" 7 + ) 8 + 9 + type writeTx struct { 10 + readTx *readTx 11 + 12 + writeIndex *writeIndex 13 + 14 + hasSeq bool 15 + seq uint64 16 + } 17 + 18 + // Downgrade implements [Write]. 19 + func (w *writeTx) Downgrade() Read { 20 + return &readOnWriteTx{w: w} 21 + } 22 + 23 + // Commit implements [Write]. 24 + func (w *writeTx) Commit() error { 25 + _, _, err := w.readTx.mutableTree.SaveVersion() 26 + if err != nil { 27 + return stacktrace.Propagate(err, "") 28 + } 29 + 30 + if w.writeIndex != nil { 31 + err := w.writeIndex.Commit() 32 + if err != nil { 33 + return stacktrace.Propagate(err, "") 34 + } 35 + } 36 + return nil 37 + } 38 + 39 + // Rollback implements [Write]. 40 + func (w *writeTx) Rollback() error { 41 + w.readTx.mutableTree.Rollback() 42 + 43 + if w.writeIndex != nil { 44 + err := w.writeIndex.Rollback() 45 + if err != nil { 46 + return stacktrace.Propagate(err, "") 47 + } 48 + } 49 + return nil 50 + } 51 + 52 + // Height implements [Write]. 53 + func (w *writeTx) Height() int64 { 54 + return w.readTx.height 55 + } 56 + 57 + // IndexDB implements [Write]. 58 + func (w *writeTx) IndexDB() WriteIndex { 59 + w.createWriteIndexIfNeeded() 60 + 61 + return w.writeIndex 62 + } 63 + 64 + // Tree implements [Write]. 65 + func (w *writeTx) Tree() UnifiedTree { 66 + return w.readTx.mutableTree 67 + } 68 + 69 + // NextSequence implements [Write]. 70 + func (w *writeTx) NextSequence() (uint64, error) { 71 + if !w.hasSeq { 72 + var err error 73 + w.seq, err = w.readTx.sequenceGetter(w.readTx) 74 + if err != nil { 75 + return 0, stacktrace.Propagate(err, "") 76 + } 77 + w.hasSeq = true 78 + return w.seq, nil 79 + } 80 + 81 + w.seq++ 82 + return w.seq, nil 83 + } 84 + 85 + // Timestamp implements [Write]. 86 + func (w *writeTx) Timestamp() time.Time { 87 + return w.readTx.ts 88 + } 89 + 90 + func (w *writeTx) createWriteIndexIfNeeded() { 91 + if w.writeIndex == nil { 92 + w.writeIndex = &writeIndex{ 93 + batch: w.readTx.db.NewBatch(), 94 + db: w.readTx.db, 95 + unsavedAdditions: make(map[string][]byte), 96 + unsavedRemovals: make(map[string]struct{}), 97 + } 98 + } 99 + }