Live video on the AT Protocol
1package statedb
2
3import (
4 "context"
5 "crypto/sha256"
6 "encoding/binary"
7 "errors"
8 "fmt"
9 "sync"
10
11 "github.com/cenkalti/backoff"
12 "gorm.io/gorm"
13 "stream.place/streamplace/pkg/log"
14)
15
16func (state *StatefulDB) GetNamedLock(name string) (func(), error) {
17 switch state.Type {
18 case DBTypeSQLite:
19 return state.getNamedLockSQLite(name)
20 case DBTypePostgres:
21 return state.getNamedLockPostgres(name)
22 }
23 panic("unsupported database type")
24}
25
// ErrNoLock is returned by pgLock when pg_try_advisory_lock reports that the
// advisory lock could not be acquired immediately. pgLockBackoff treats it as
// retryable; any other error aborts the retry loop.
var ErrNoLock = errors.New("pg_try_advisory_lock returned false")
27
28func (state *StatefulDB) getNamedLockPostgres(name string) (func(), error) {
29 // we also use a local lock here - whoever is locking wants exclusive access even within the node
30 lock := state.locks.GetLock(name)
31 lock.Lock()
32 // Convert string to sha256 hash and use decimal value for advisory lock
33 h := sha256.Sum256([]byte(name))
34 nameInt := int64(binary.BigEndian.Uint64(h[:8]))
35
36 log.Debug(context.Background(), fmt.Sprintf("starting SELECT pg_advisory_lock(%d)", nameInt))
37 err := state.pgLockBackoff(nameInt)
38 if err != nil {
39 lock.Unlock()
40 return nil, err
41 }
42 return func() {
43 log.Debug(context.Background(), fmt.Sprintf("starting SELECT pg_advisory_unlock(%d)", nameInt))
44 err := state.pgUnlock(nameInt)
45 if err != nil {
46 // unfortunate, but the risk is that we're holding on to the lock forever,
47 // so it's responsible to crash in this case
48 panic(fmt.Errorf("error unlocking named lock: %w", err))
49 }
50 lock.Unlock()
51 }, nil
52}
53
54func (state *StatefulDB) pgLockBackoff(key int64) error {
55 ticker := backoff.NewTicker(backoff.NewExponentialBackOff())
56 defer ticker.Stop()
57 var err error
58 for i := 0; i < 10; i++ {
59 err = state.pgLock(key)
60 if err == nil {
61 return nil
62 }
63 if !errors.Is(err, ErrNoLock) {
64 return err
65 }
66 if i < 9 {
67 <-ticker.C
68 }
69 }
70 return fmt.Errorf("failed to lock after 10 attempts: %w", err)
71}
72
73func (state *StatefulDB) pgLock(key int64) error {
74 state.pgLockConnMu.Lock()
75 defer state.pgLockConnMu.Unlock()
76 var locked bool
77 err := state.pgLockConn.Raw("SELECT pg_try_advisory_lock($1)", key).Scan(&locked).Error
78 if err == nil && !locked {
79 log.Error(context.Background(), fmt.Sprintf("pg_try_advisory_lock returned false for key %d", key))
80 err = ErrNoLock
81 }
82 return err
83}
84
85func (state *StatefulDB) pgUnlock(key int64) error {
86 state.pgLockConnMu.Lock()
87 defer state.pgLockConnMu.Unlock()
88 var unlocked bool
89 err := state.pgLockConn.Raw("SELECT pg_advisory_unlock($1)", key).Scan(&unlocked).Error
90 if err == nil && !unlocked {
91 err = fmt.Errorf("pg_advisory_unlock returned false")
92 }
93 return err
94}
95
96// startLockerConn starts a dedicated connection to the database for locking
97func (state *StatefulDB) startPostgresLockerConn(ctx context.Context) error {
98 done := make(chan struct{})
99 var err error
100 go func() {
101 err = state.DB.Connection(func(tx *gorm.DB) error {
102 state.pgLockConn = tx
103 close(done)
104 // hold this open until the context is done
105 <-ctx.Done()
106 return nil
107 })
108 if err != nil {
109 close(done)
110 }
111 }()
112 <-done
113 return err
114}
115
116func (state *StatefulDB) getNamedLockSQLite(name string) (func(), error) {
117 lock := state.locks.GetLock(name)
118 lock.Lock()
119 return func() {
120 lock.Unlock()
121 }, nil
122}
123
// NamedLocks hands out one *sync.Mutex per lock name. This gives in-process
// mutual exclusion keyed by a string.
//
// In this file, both the SQLite and the Postgres named-lock helpers take their
// local mutex via state.locks.GetLock (presumably a *NamedLocks — verify).
//
// The zero value is ready to use; NewNamedLocks remains for existing callers.
//
// NOTE(review): entries are never removed, so memory grows with the number of
// distinct names ever requested.
type NamedLocks struct {
	// mu guards locks.
	mu sync.Mutex
	// locks maps a lock name to its mutex; populated lazily by GetLock.
	locks map[string]*sync.Mutex
}

// NewNamedLocks creates an empty NamedLocks.
func NewNamedLocks() *NamedLocks {
	return &NamedLocks{
		locks: make(map[string]*sync.Mutex),
	}
}

// GetLock returns the mutex registered for name, creating and registering a
// new one on first use.
//
// Repeated calls with the same name return the same *sync.Mutex. GetLock
// itself is safe for concurrent use. The returned mutex is not locked; the
// caller chooses when to Lock and Unlock it.
func (n *NamedLocks) GetLock(name string) *sync.Mutex {
	n.mu.Lock()
	defer n.mu.Unlock()

	if lock, ok := n.locks[name]; ok {
		return lock
	}
	// Allocate lazily so a zero-value NamedLocks does not panic on a nil-map
	// write.
	if n.locks == nil {
		n.locks = make(map[string]*sync.Mutex)
	}
	lock := &sync.Mutex{}
	n.locks[name] = lock
	return lock
}