Live video on the AT Protocol
at eli/handle-changes 148 lines 3.7 kB view raw
1package statedb 2 3import ( 4 "context" 5 "crypto/sha256" 6 "encoding/binary" 7 "errors" 8 "fmt" 9 "sync" 10 11 "github.com/cenkalti/backoff" 12 "gorm.io/gorm" 13 "stream.place/streamplace/pkg/log" 14) 15 16func (state *StatefulDB) GetNamedLock(name string) (func(), error) { 17 switch state.Type { 18 case DBTypeSQLite: 19 return state.getNamedLockSQLite(name) 20 case DBTypePostgres: 21 return state.getNamedLockPostgres(name) 22 } 23 panic("unsupported database type") 24} 25 26var ErrNoLock = fmt.Errorf("pg_try_advisory_lock returned false") 27 28func (state *StatefulDB) getNamedLockPostgres(name string) (func(), error) { 29 // we also use a local lock here - whoever is locking wants exclusive access even within the node 30 lock := state.locks.GetLock(name) 31 lock.Lock() 32 // Convert string to sha256 hash and use decimal value for advisory lock 33 h := sha256.Sum256([]byte(name)) 34 nameInt := int64(binary.BigEndian.Uint64(h[:8])) 35 36 log.Debug(context.Background(), fmt.Sprintf("starting SELECT pg_advisory_lock(%d)", nameInt)) 37 err := state.pgLockBackoff(nameInt) 38 if err != nil { 39 lock.Unlock() 40 return nil, err 41 } 42 return func() { 43 log.Debug(context.Background(), fmt.Sprintf("starting SELECT pg_advisory_unlock(%d)", nameInt)) 44 err := state.pgUnlock(nameInt) 45 if err != nil { 46 // unfortunate, but the risk is that we're holding on to the lock forever, 47 // so it's responsible to crash in this case 48 panic(fmt.Errorf("error unlocking named lock: %w", err)) 49 } 50 lock.Unlock() 51 }, nil 52} 53 54func (state *StatefulDB) pgLockBackoff(key int64) error { 55 ticker := backoff.NewTicker(backoff.NewExponentialBackOff()) 56 defer ticker.Stop() 57 var err error 58 for i := 0; i < 10; i++ { 59 err = state.pgLock(key) 60 if err == nil { 61 return nil 62 } 63 if !errors.Is(err, ErrNoLock) { 64 return err 65 } 66 if i < 9 { 67 <-ticker.C 68 } 69 } 70 return fmt.Errorf("failed to lock after 10 attempts: %w", err) 71} 72 73func (state *StatefulDB) pgLock(key int64) error { 74 state.pgLockConnMu.Lock() 75 defer state.pgLockConnMu.Unlock() 76 var locked bool 77 err := state.pgLockConn.Raw("SELECT pg_try_advisory_lock($1)", key).Scan(&locked).Error 78 if err == nil && !locked { 79 log.Error(context.Background(), fmt.Sprintf("pg_try_advisory_lock returned false for key %d", key)) 80 err = ErrNoLock 81 } 82 return err 83} 84 85func (state *StatefulDB) pgUnlock(key int64) error { 86 state.pgLockConnMu.Lock() 87 defer state.pgLockConnMu.Unlock() 88 var unlocked bool 89 err := state.pgLockConn.Raw("SELECT pg_advisory_unlock($1)", key).Scan(&unlocked).Error 90 if err == nil && !unlocked { 91 err = fmt.Errorf("pg_advisory_unlock returned false") 92 } 93 return err 94} 95 96// startLockerConn starts a dedicated connection to the database for locking 97func (state *StatefulDB) startPostgresLockerConn(ctx context.Context) error { 98 done := make(chan struct{}) 99 var err error 100 go func() { 101 err = state.DB.Connection(func(tx *gorm.DB) error { 102 state.pgLockConn = tx 103 close(done) 104 // hold this open until the context is done 105 <-ctx.Done() 106 return nil 107 }) 108 if err != nil { 109 close(done) 110 } 111 }() 112 <-done 113 return err 114} 115 116func (state *StatefulDB) getNamedLockSQLite(name string) (func(), error) { 117 lock := state.locks.GetLock(name) 118 lock.Lock() 119 return func() { 120 lock.Unlock() 121 }, nil 122} 123 124// Local mutex implementation for sqlite 125type NamedLocks struct { 126 mu sync.Mutex 127 locks map[string]*sync.Mutex 128} 129 130// NewNamedLocks creates a new NamedLocks instance 131func NewNamedLocks() *NamedLocks { 132 return &NamedLocks{ 133 locks: make(map[string]*sync.Mutex), 134 } 135} 136 137// GetLock returns the mutex for the given name, creating it if it doesn't exist 138func (n *NamedLocks) GetLock(name string) *sync.Mutex { 139 n.mu.Lock() 140 defer n.mu.Unlock() 141 142 lock, exists := n.locks[name] 143 if !exists { 144 lock = &sync.Mutex{} 145 n.locks[name] = lock 146 } 147 return lock 148}