1package mst
2
3import (
4 "context"
5 "encoding/hex"
6 "fmt"
7 "io"
8 "maps"
9 "math/rand"
10 "os"
11 "regexp"
12 "sort"
13 "testing"
14
15 "github.com/bluesky-social/indigo/util"
16
17 "github.com/ipfs/go-cid"
18 "github.com/ipfs/go-datastore"
19 blockstore "github.com/ipfs/go-ipfs-blockstore"
20 "github.com/ipld/go-car/v2"
21 "github.com/multiformats/go-multihash"
22 mh "github.com/multiformats/go-multihash"
23)
24
25func randCid() cid.Cid {
26 buf := make([]byte, 32)
27 rand.Read(buf)
28 c, err := cid.NewPrefixV1(cid.Raw, mh.SHA2_256).Sum(buf)
29 if err != nil {
30 panic(err)
31 }
32 return c
33}
34
35func TestBasicMst(t *testing.T) {
36
37 ctx := context.Background()
38 cst := util.CborStore(blockstore.NewBlockstore(datastore.NewMapDatastore()))
39 mst := createMST(cst, cid.Undef, []nodeEntry{}, -1)
40
41 // NOTE: these were previously generated randomly, but the random seed behavior changed
42 vals := map[string]cid.Cid{
43 "cats/cats": mustCid(t, "bafkreicwamkg77pijyudfbdmskelsnuztr6gp62lqfjv3e3urbs3gxnv2m"),
44 "dogs/dogs": mustCid(t, "bafkreihwoet2mghoduxannw3uqsq44a3if37i5omnlqmjcfuhcdegzpyn4"),
45 "cats/bears": mustCid(t, "bafkreiealwcgkpqxmr75a2vubzdertnbrip65nclv6gbsss4w7ef7fh6oy"),
46 }
47
48 for k, v := range vals {
49 nmst, err := mst.Add(ctx, k, v, -1)
50 if err != nil {
51 t.Fatal(err)
52 }
53 mst = nmst
54 }
55
56 ncid, err := mst.GetPointer(ctx)
57 if err != nil {
58 t.Fatal(err)
59 }
60
61 if ncid.String() != "bafyreidap7hdugsxisef7esd2eh26423j23r65mvlvpsdv7vbbsl5qfgxq" {
62 t.Fatal("mst generation changed", ncid.String())
63 }
64
65 // delete a key
66 nmst, err := mst.Delete(ctx, "dogs/dogs")
67 if err != nil {
68 t.Fatal(err)
69 }
70 delete(vals, "dogs/dogs")
71 assertValues(t, nmst, vals)
72
73 // update a key
74 newCid := randCid()
75 vals["cats/cats"] = newCid
76 nmst, err = nmst.Update(ctx, "cats/cats", newCid)
77 if err != nil {
78 t.Fatal(err)
79 }
80 assertValues(t, nmst, vals)
81
82 // update deleted key should fail
83 _, err = nmst.Delete(ctx, "dogs/dogs")
84 if err == nil {
85 t.Fatal("can't delete a removed key")
86 }
87 _, err = nmst.Update(ctx, "dogs/dogs", newCid)
88 if err == nil {
89 t.Fatal("can't update a removed key")
90 }
91}
92
93func assertValues(t *testing.T, mst *MerkleSearchTree, vals map[string]cid.Cid) {
94 out := make(map[string]cid.Cid)
95 if err := mst.WalkLeavesFrom(context.TODO(), "", func(key string, val cid.Cid) error {
96 out[key] = val
97 return nil
98 }); err != nil {
99 t.Fatal(err)
100 }
101
102 if len(vals) == len(out) {
103 for k, v := range vals {
104 ov, ok := out[k]
105 if !ok {
106 t.Fatalf("expected key %s to be present", k)
107 }
108 if ov != v {
109 t.Fatalf("value mismatch on %s", k)
110 }
111 }
112 } else {
113 t.Fatalf("different number of values than expected: %d != %d", len(vals), len(out))
114 }
115}
116
117func mustCid(t *testing.T, s string) cid.Cid {
118 t.Helper()
119 c, err := cid.Decode(s)
120 if err != nil {
121 t.Fatal(err)
122 }
123 return c
124
125}
126
127func loadCar(bs blockstore.Blockstore, fname string) error {
128 fi, err := os.Open(fname)
129 if err != nil {
130 return err
131 }
132 defer fi.Close()
133 br, err := car.NewBlockReader(fi)
134 if err != nil {
135 return err
136 }
137
138 for {
139 blk, err := br.Next()
140 if err != nil {
141 if err == io.EOF {
142 break
143 }
144 return err
145 }
146
147 if err := bs.Put(context.TODO(), blk); err != nil {
148 return err
149 }
150 }
151
152 return nil
153}
154
155/*
156func TestDiff(t *testing.T) {
157 to := mustCid(t, "bafyreie5cvv4h45feadgeuwhbcutmh6t2ceseocckahdoe6uat64zmz454")
158 from := mustCid(t, "bafyreigv5er7vcxlbikkwedmtd7b3kp7wrcyffep5ogcuxosloxfox5reu")
159
160 bs := blockstore.NewBlockstore(datastore.NewMapDatastore())
161
162 if err := loadCar(bs, "paul.car"); err != nil {
163 t.Fatal(err)
164 }
165
166 ctx := context.TODO()
167 ops, err := DiffTrees(ctx, bs, from, to)
168 if err != nil {
169 t.Fatal(err)
170 }
171 _ = ops
172}
173*/
174
// randStr returns a 12-character hex string: the encoding of 6 bytes read
// from a math/rand generator seeded with s. The same seed always yields the
// same string.
func randStr(s int64) string {
	var raw [6]byte
	rng := rand.New(rand.NewSource(s))
	rng.Read(raw[:])
	return hex.EncodeToString(raw[:])
}
181
182func TestDiffInsertionsBasic(t *testing.T) {
183 a := map[string]string{
184 "cats/asdf": randStr(1),
185 "cats/foosesdf": randStr(2),
186 }
187
188 b := maps.Clone(a)
189 b["cats/bawda"] = randStr(3)
190 b["cats/crosasd"] = randStr(4)
191
192 testMapDiffs(t, a, b)
193 testMapDiffs(t, b, a)
194}
195
196func randKey(s int64) string {
197 r := rand.New(rand.NewSource(s))
198
199 top := r.Int63n(6)
200 mid := r.Int63n(3)
201
202 end := randStr(r.Int63n(10000000))
203
204 return randStr(125125+top) + "." + randStr(858392+mid) + "/" + end
205}
206
207func TestDiffInsertionsLarge(t *testing.T) {
208 a := map[string]string{}
209 for i := int64(0); i < 1000; i++ {
210 a[randKey(i)] = randStr(72385739 - i)
211 }
212
213 b := maps.Clone(a)
214 for i := int64(0); i < 30; i++ {
215 b[randKey(5000+i)] = randStr(2293825 - i)
216 }
217
218 testMapDiffs(t, a, b)
219 testMapDiffs(t, b, a)
220}
221
222func TestDiffNoOverlap(t *testing.T) {
223 a := map[string]string{}
224 for i := int64(0); i < 10; i++ {
225 a[randKey(i)] = randStr(72385739 - i)
226 }
227
228 b := map[string]string{}
229 for i := int64(0); i < 10; i++ {
230 b[randKey(5000+i)] = randStr(2293825 - i)
231 }
232
233 testMapDiffs(t, a, b)
234 testMapDiffs(t, b, a)
235}
236
237func TestDiffSmallOverlap(t *testing.T) {
238 a := map[string]string{}
239 for i := int64(0); i < 10; i++ {
240 a[randKey(i)] = randStr(72385739 - i)
241 }
242
243 b := maps.Clone(a)
244
245 for i := int64(0); i < 1000; i++ {
246 a[randKey(i)] = randStr(682823 - i)
247 }
248
249 for i := int64(0); i < 1000; i++ {
250 b[randKey(5000+i)] = randStr(2293825 - i)
251 }
252
253 testMapDiffs(t, a, b)
254 //testMapDiffs(t, b, a)
255}
256
257func TestDiffSmallOverlapSmall(t *testing.T) {
258 a := map[string]string{}
259 for i := int64(0); i < 4; i++ {
260 a[randKey(i)] = randStr(72385739 - i)
261 }
262
263 b := maps.Clone(a)
264
265 for i := int64(0); i < 20; i++ {
266 a[randKey(i)] = randStr(682823 - i)
267 }
268
269 for i := int64(0); i < 20; i++ {
270 b[randKey(5000+i)] = randStr(2293825 - i)
271 }
272
273 testMapDiffs(t, a, b)
274 //testMapDiffs(t, b, a)
275}
276
277func TestDiffMutationsBasic(t *testing.T) {
278 a := map[string]string{
279 "cats/asdf": randStr(1),
280 "cats/foosesdf": randStr(2),
281 }
282
283 b := maps.Clone(a)
284 b["cats/asdf"] = randStr(3)
285
286 testMapDiffs(t, a, b)
287}
288
289func diffMaps(a, b map[string]cid.Cid) []*DiffOp {
290 var akeys, bkeys []string
291
292 for k := range a {
293 akeys = append(akeys, k)
294 }
295
296 for k := range b {
297 bkeys = append(bkeys, k)
298 }
299
300 sort.Strings(akeys)
301 sort.Strings(bkeys)
302
303 var out []*DiffOp
304 for _, k := range akeys {
305 av := a[k]
306 bv, ok := b[k]
307 if !ok {
308 out = append(out, &DiffOp{
309 Op: "del",
310 Rpath: k,
311 OldCid: av,
312 })
313 } else {
314 if av != bv {
315 out = append(out, &DiffOp{
316 Op: "mut",
317 Rpath: k,
318 OldCid: av,
319 NewCid: bv,
320 })
321 }
322 }
323 }
324
325 for _, k := range bkeys {
326 _, ok := a[k]
327 if !ok {
328 out = append(out, &DiffOp{
329 Op: "add",
330 Rpath: k,
331 NewCid: b[k],
332 })
333 }
334 }
335
336 sort.Slice(out, func(i, j int) bool {
337 return out[i].Rpath < out[j].Rpath
338 })
339
340 return out
341}
342
343// NOTE(bnewbold): this does *not* just call cid.Decode(), which is the simple
344// "parse a CID in string form into a cid.Cid object". This method (strToCid())
345// can sometimes result in "identity" CIDs
346func strToCid(s string) cid.Cid {
347 h, err := multihash.Sum([]byte(s), multihash.ID, -1)
348 if err != nil {
349 panic(err)
350 }
351
352 return cid.NewCidV1(cid.Raw, h)
353
354}
355
356func mapToCidMap(a map[string]string) map[string]cid.Cid {
357 out := make(map[string]cid.Cid)
358 for k, v := range a {
359 out[k] = strToCid(v)
360 }
361
362 return out
363}
364
365func cidMapToMst(t testing.TB, bs blockstore.Blockstore, m map[string]cid.Cid) *MerkleSearchTree {
366 cst := util.CborStore(bs)
367 mt := createMST(cst, cid.Undef, []nodeEntry{}, -1)
368
369 for k, v := range m {
370 nmst, err := mt.Add(context.TODO(), k, v, -1)
371 if err != nil {
372 t.Fatal(err)
373 }
374
375 mt = nmst
376 }
377
378 return mt
379}
380
381func mustCidTree(t testing.TB, tree *MerkleSearchTree) cid.Cid {
382 c, err := tree.GetPointer(context.TODO())
383 if err != nil {
384 t.Fatal(err)
385 }
386 return c
387}
388
389func memBs() blockstore.Blockstore {
390 return blockstore.NewBlockstore(datastore.NewMapDatastore())
391}
392
393func testMapDiffs(t testing.TB, a, b map[string]string) {
394 amc := mapToCidMap(a)
395 bmc := mapToCidMap(b)
396
397 exp := diffMaps(amc, bmc)
398
399 bs := memBs()
400
401 msta := cidMapToMst(t, bs, amc)
402 mstb := cidMapToMst(t, bs, bmc)
403
404 cida := mustCidTree(t, msta)
405 cidb := mustCidTree(t, mstb)
406
407 diffs, err := DiffTrees(context.TODO(), bs, cida, cidb)
408 if err != nil {
409 t.Fatal(err)
410 }
411
412 if !sort.SliceIsSorted(diffs, func(i, j int) bool {
413 return diffs[i].Rpath < diffs[j].Rpath
414 }) {
415 t.Log("diff algo did not produce properly sorted diff")
416 }
417 if !compareDiffs(diffs, exp) {
418 fmt.Println("Expected Diff:")
419 for _, do := range exp {
420 fmt.Println(do)
421 }
422 fmt.Println("Actual Diff:")
423 for _, do := range diffs {
424 fmt.Println(do)
425 }
426 t.Logf("diff lens: %d %d", len(diffs), len(exp))
427 diffDiff(diffs, exp)
428 t.Fatal("diffs not equal")
429 }
430}
431
432func diffDiff(a, b []*DiffOp) {
433 var i, j int
434
435 for i < len(a) || j < len(b) {
436 if i >= len(a) {
437 fmt.Println("+: ", b[j])
438 j++
439 continue
440 }
441
442 if j >= len(b) {
443 fmt.Println("-: ", a[i])
444 i++
445 continue
446 }
447
448 aa := a[i]
449 bb := b[j]
450
451 if diffOpEq(aa, bb) {
452 fmt.Println("eq: ", i, j, aa.Rpath)
453 i++
454 j++
455 continue
456 }
457
458 if aa.Rpath == bb.Rpath {
459 fmt.Println("~: ", aa, bb)
460 i++
461 j++
462 continue
463 }
464
465 if aa.Rpath < bb.Rpath {
466 fmt.Println("-: ", aa)
467 i++
468 continue
469 } else {
470 fmt.Println("+: ", bb)
471 j++
472 continue
473 }
474 }
475}
476
477func compareDiffs(a, b []*DiffOp) bool {
478 if len(a) != len(b) {
479 return false
480 }
481
482 for i := 0; i < len(a); i++ {
483 aa := a[i]
484 bb := b[i]
485
486 if aa.Op != bb.Op || aa.Rpath != bb.Rpath || aa.NewCid != bb.NewCid || aa.OldCid != bb.OldCid {
487 return false
488 }
489 }
490
491 return true
492}
493
494func diffOpEq(aa, bb *DiffOp) bool {
495 if aa.Op != bb.Op || aa.Rpath != bb.Rpath || aa.NewCid != bb.NewCid || aa.OldCid != bb.OldCid {
496 return false
497 }
498 return true
499}
500
501func BenchmarkIsValidMstKey(b *testing.B) {
502 b.ReportAllocs()
503 for i := 0; i < b.N; i++ {
504 if !isValidMstKey("foo/foo.bar123") {
505 b.Fatal()
506 }
507 }
508}
509
510func TestLeadingZerosOnHashAllocs(t *testing.T) {
511 var sink int
512 const in = "some.key.prefix/key.bar123456789012334556"
513 var inb = []byte(in)
514 if n := int(testing.AllocsPerRun(1000, func() {
515 sink = leadingZerosOnHash(in)
516 })); n != 0 {
517 t.Errorf("allocs (string) = %d; want 0", n)
518 }
519 if n := int(testing.AllocsPerRun(1000, func() {
520 sink = leadingZerosOnHashBytes(inb)
521 })); n != 0 {
522 t.Errorf("allocs (bytes) = %d; want 0", n)
523 }
524 _ = sink
525}
526
527// Verify that keyHasAllValidChars matches its documented regexp.
528func FuzzKeyHasAllValidChars(f *testing.F) {
529 for _, seed := range [][]byte{{}} {
530 f.Add(seed)
531 }
532 for i := 0; i < 256; i++ {
533 f.Add([]byte{byte(i)})
534 }
535 rx := regexp.MustCompile("^[a-zA-Z0-9_:.~-]+$")
536 f.Fuzz(func(t *testing.T, in []byte) {
537 s := string(in)
538 if a, b := rx.MatchString(s), keyHasAllValidChars(s); a != b {
539 t.Fatalf("for %q, rx=%v, keyHasAllValidChars=%v", s, a, b)
540 }
541 })
542}
543
544func BenchmarkLeadingZerosOnHash(b *testing.B) {
545 b.ReportAllocs()
546 for i := 0; i < b.N; i++ {
547 _ = leadingZerosOnHash("some.key.prefix/key.bar123456789012334556")
548 }
549}
550
551func BenchmarkDiffTrees(b *testing.B) {
552 b.ReportAllocs()
553 const size = 10000
554 ma := map[string]string{}
555 for i := 0; i < size; i++ {
556 ma[fmt.Sprintf("num/%02d", i)] = fmt.Sprint(i)
557 }
558 // And then mess with half of the items of the first half of it.
559 mb := maps.Clone(ma)
560 for i := 0; i < size/2; i++ {
561 switch i % 4 {
562 case 0, 1:
563 case 2:
564 delete(mb, fmt.Sprintf("num/%02d", i))
565 case 3:
566 ma[fmt.Sprintf("num/%02d", i)] = fmt.Sprint(i + 1)
567 }
568 }
569
570 amc := mapToCidMap(ma)
571 bmc := mapToCidMap(mb)
572
573 want := diffMaps(amc, bmc)
574
575 bs := memBs()
576
577 msta := cidMapToMst(b, bs, amc)
578 mstb := cidMapToMst(b, bs, bmc)
579
580 cida := mustCidTree(b, msta)
581 cidb := mustCidTree(b, mstb)
582
583 b.ResetTimer()
584
585 var diffs []*DiffOp
586 var err error
587 for i := 0; i < b.N; i++ {
588 diffs, err = DiffTrees(context.TODO(), bs, cida, cidb)
589 if err != nil {
590 b.Fatal(err)
591 }
592 }
593
594 if !sort.SliceIsSorted(diffs, func(i, j int) bool {
595 return diffs[i].Rpath < diffs[j].Rpath
596 }) {
597 b.Log("diff algo did not produce properly sorted diff")
598 }
599 if !compareDiffs(diffs, want) {
600 b.Fatal("diffs not equal")
601 }
602}
603
// countPrefixLenTests lists input pairs for countPrefixLen together with the
// expected result for each pair. It is shared by the test and the benchmark.
var countPrefixLenTests = []struct {
	a, b string
	want int
}{
	{a: "", b: "", want: 0},
	{a: "a", b: "", want: 0},
	{a: "", b: "a", want: 0},
	{a: "a", b: "b", want: 0},
	{a: "a", b: "a", want: 1},
	{a: "ab", b: "a", want: 1},
	{a: "a", b: "ab", want: 1},
	{a: "ab", b: "ab", want: 2},
	{a: "abcdefghijklmnop", b: "abcdefghijklmnoq", want: 15},
}
618
619func TestCountPrefixLen(t *testing.T) {
620 for _, tt := range countPrefixLenTests {
621 if got := countPrefixLen(tt.a, tt.b); got != tt.want {
622 t.Errorf("countPrefixLenTests(%q, %q) = %d, want %d", tt.a, tt.b, got, tt.want)
623 }
624 }
625}
626
627func BenchmarkCountPrefixLen(b *testing.B) {
628 b.ReportAllocs()
629 for i := 0; i < b.N; i++ {
630 for _, tt := range countPrefixLenTests {
631 if got := countPrefixLen(tt.a, tt.b); got != tt.want {
632 b.Fatalf("countPrefixLenTests(%q, %q) = %d, want %d", tt.a, tt.b, got, tt.want)
633 }
634 }
635 }
636}