Live video on the AT Protocol
at next 141 lines 2.9 kB view raw
1package statedb 2 3import ( 4 "context" 5 "fmt" 6 "net/url" 7 "os" 8 "os/exec" 9 "strings" 10 "sync/atomic" 11 "testing" 12 "time" 13 14 "github.com/google/uuid" 15 "github.com/stretchr/testify/require" 16 "golang.org/x/sync/errgroup" 17 "gorm.io/driver/postgres" 18 "stream.place/streamplace/pkg/config" 19 "stream.place/streamplace/pkg/model" 20) 21 22var postgresURL string 23 24func TestMain(m *testing.M) { 25 postgresCommand := os.Getenv("STREAMPLACE_TEST_POSTGRES_COMMAND") 26 postgresURL = os.Getenv("STREAMPLACE_TEST_POSTGRES_URL") 27 if postgresCommand != "" { 28 // Start postgres process 29 fmt.Printf("Starting postgres process with command: %s\n", postgresCommand) 30 cmd := exec.Command("bash", "-c", postgresCommand) 31 err := cmd.Start() 32 if err != nil { 33 fmt.Printf("Failed to start postgres: %v\n", err) 34 os.Exit(1) 35 } 36 37 // Give postgres time to start up 38 time.Sleep(2 * time.Second) 39 40 // Run tests 41 exitCode := m.Run() 42 43 // Clean up postgres process 44 if cmd.Process != nil { 45 cmd2 := exec.Command("pkill", "postgres") 46 err := cmd2.Run() 47 if err != nil { 48 fmt.Printf("Failed to kill postgres: %v\n", err) 49 } 50 } 51 52 os.Exit(exitCode) 53 return 54 } 55 os.Exit(m.Run()) 56} 57 58func makePostgresURL(t *testing.T) string { 59 u, err := url.Parse(postgresURL) 60 if err != nil { 61 panic(err) 62 } 63 uu, err := uuid.NewV7() 64 if err != nil { 65 panic(err) 66 } 67 dbName := fmt.Sprintf("test_%s", strings.ReplaceAll(uu.String(), "-", "_")) 68 u.Path = fmt.Sprintf("/%s", dbName) 69 t.Cleanup(func() { 70 u, err := url.Parse(postgresURL) 71 if err != nil { 72 panic(err) 73 } 74 u.Path = "/postgres" 75 rootDial := postgres.Open(u.String()) 76 77 db, err := openDB(rootDial) 78 if err != nil { 79 t.Logf("Failed to open database: %v", err) 80 return 81 } 82 83 // Drop the test database 84 err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error 85 if err != nil { 86 t.Logf("Failed to drop test database: %v", err) 87 } 88 }) 89 return u.String() 90} 91 92var lockRuns = 100 93var nodeCount = 25 94 95func TestPostgresLocks(t *testing.T) { 96 if postgresURL == "" { 97 t.Skip("no postgres url, skipping postgres tests") 98 return 99 } 100 dburl := makePostgresURL(t) 101 cli := config.CLI{ 102 DBURL: dburl, 103 } 104 ctx, cancel := context.WithCancel(context.Background()) 105 defer cancel() 106 var g errgroup.Group 107 var count atomic.Uint64 108 start := make(chan struct{}) 109 for i := 0; i < nodeCount; i++ { 110 mod, err := model.MakeDB(":memory:") 111 require.NoError(t, err) 112 state, err := MakeDB(ctx, &cli, nil, mod) 113 require.NoError(t, err) 114 115 defer func() { 116 sqlDB, err := state.DB.DB() 117 require.NoError(t, err) 118 err = sqlDB.Close() 119 require.NoError(t, err) 120 }() 121 122 doLock := func() error { 123 <-start 124 unlock, err := state.GetNamedLock("test") 125 require.NoError(t, err) 126 defer unlock() 127 count.Add(1) 128 return nil 129 } 130 131 for i := 0; i < lockRuns; i++ { 132 g.Go(doLock) 133 } 134 } 135 close(start) 136 137 err := g.Wait() 138 require.NoError(t, err) 139 require.Equal(t, int(count.Load()), int(uint64(lockRuns*nodeCount))) 140 141}