Live video on the AT Protocol
1package statedb
2
3import (
4 "context"
5 "fmt"
6 "net/url"
7 "os"
8 "os/exec"
9 "strings"
10 "sync/atomic"
11 "testing"
12 "time"
13
14 "github.com/google/uuid"
15 "github.com/stretchr/testify/require"
16 "golang.org/x/sync/errgroup"
17 "gorm.io/driver/postgres"
18 "stream.place/streamplace/pkg/config"
19 "stream.place/streamplace/pkg/model"
20)
21
// postgresURL is the base postgres connection URL, read from
// STREAMPLACE_TEST_POSTGRES_URL in TestMain. Tests that need postgres skip
// when it is empty; makePostgresURL derives per-test database URLs from it.
var postgresURL string
23
24func TestMain(m *testing.M) {
25 postgresCommand := os.Getenv("STREAMPLACE_TEST_POSTGRES_COMMAND")
26 postgresURL = os.Getenv("STREAMPLACE_TEST_POSTGRES_URL")
27 if postgresCommand != "" {
28 // Start postgres process
29 fmt.Printf("Starting postgres process with command: %s\n", postgresCommand)
30 cmd := exec.Command("bash", "-c", postgresCommand)
31 err := cmd.Start()
32 if err != nil {
33 fmt.Printf("Failed to start postgres: %v\n", err)
34 os.Exit(1)
35 }
36
37 // Give postgres time to start up
38 time.Sleep(2 * time.Second)
39
40 // Run tests
41 exitCode := m.Run()
42
43 // Clean up postgres process
44 if cmd.Process != nil {
45 cmd2 := exec.Command("pkill", "postgres")
46 err := cmd2.Run()
47 if err != nil {
48 fmt.Printf("Failed to kill postgres: %v\n", err)
49 }
50 }
51
52 os.Exit(exitCode)
53 return
54 }
55 os.Exit(m.Run())
56}
57
58func makePostgresURL(t *testing.T) string {
59 u, err := url.Parse(postgresURL)
60 if err != nil {
61 panic(err)
62 }
63 uu, err := uuid.NewV7()
64 if err != nil {
65 panic(err)
66 }
67 dbName := fmt.Sprintf("test_%s", strings.ReplaceAll(uu.String(), "-", "_"))
68 u.Path = fmt.Sprintf("/%s", dbName)
69 t.Cleanup(func() {
70 u, err := url.Parse(postgresURL)
71 if err != nil {
72 panic(err)
73 }
74 u.Path = "/postgres"
75 rootDial := postgres.Open(u.String())
76
77 db, err := openDB(rootDial)
78 if err != nil {
79 t.Logf("Failed to open database: %v", err)
80 return
81 }
82
83 // Drop the test database
84 err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error
85 if err != nil {
86 t.Logf("Failed to drop test database: %v", err)
87 }
88 })
89 return u.String()
90}
91
// Sizing for TestPostgresLocks. Each of nodeCount independent state DBs
// launches lockRuns goroutines that each take and release the named lock once.
var (
	lockRuns  = 100
	nodeCount = 25
)
94
95func TestPostgresLocks(t *testing.T) {
96 if postgresURL == "" {
97 t.Skip("no postgres url, skipping postgres tests")
98 return
99 }
100 dburl := makePostgresURL(t)
101 cli := config.CLI{
102 DBURL: dburl,
103 }
104 ctx, cancel := context.WithCancel(context.Background())
105 defer cancel()
106 var g errgroup.Group
107 var count atomic.Uint64
108 start := make(chan struct{})
109 for i := 0; i < nodeCount; i++ {
110 mod, err := model.MakeDB(":memory:")
111 require.NoError(t, err)
112 state, err := MakeDB(ctx, &cli, nil, mod)
113 require.NoError(t, err)
114
115 defer func() {
116 sqlDB, err := state.DB.DB()
117 require.NoError(t, err)
118 err = sqlDB.Close()
119 require.NoError(t, err)
120 }()
121
122 doLock := func() error {
123 <-start
124 unlock, err := state.GetNamedLock("test")
125 require.NoError(t, err)
126 defer unlock()
127 count.Add(1)
128 return nil
129 }
130
131 for i := 0; i < lockRuns; i++ {
132 g.Go(doLock)
133 }
134 }
135 close(start)
136
137 err := g.Wait()
138 require.NoError(t, err)
139 require.Equal(t, int(count.Load()), int(uint64(lockRuns*nodeCount)))
140
141}