// Live video on the AT Protocol.
// Source view: ref natb/certmagic (325 lines, 7.6 kB).
1package api 2 3import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "io/fs" 8 "os" 9 "path/filepath" 10 "strings" 11 "sync" 12 "time" 13 14 "github.com/caddyserver/certmagic" 15) 16 17type StreamplaceCertStorage struct { 18 Path string 19 20 // locks keeps track of current locks for this storage instance 21 locks map[string]*fileLock 22 locksMu sync.Mutex 23} 24 25type fileLock struct { 26 name string 27 path string 28 lockFile *os.File 29 cancel context.CancelFunc 30 done chan struct{} 31} 32 33func NewStreamplaceCertStorage(storagePath string) *StreamplaceCertStorage { 34 return &StreamplaceCertStorage{ 35 Path: storagePath, 36 locks: make(map[string]*fileLock), 37 } 38} 39 40func (s *StreamplaceCertStorage) Store(ctx context.Context, key string, value []byte) error { 41 filePath := s.keyToPath(key) 42 43 // Ensure the directory exists 44 if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil { 45 return fmt.Errorf("failed to create directory: %w", err) 46 } 47 48 // Write the file atomically by writing to a temp file first 49 tempPath := filePath + ".tmp" 50 if err := os.WriteFile(tempPath, value, 0644); err != nil { 51 return fmt.Errorf("failed to write temp file: %w", err) 52 } 53 54 // Move temp file to final location 55 if err := os.Rename(tempPath, filePath); err != nil { 56 os.Remove(tempPath) // Clean up temp file on error 57 return fmt.Errorf("failed to move temp file to final location: %w", err) 58 } 59 60 return nil 61} 62 63func (s *StreamplaceCertStorage) Load(ctx context.Context, key string) ([]byte, error) { 64 filePath := s.keyToPath(key) 65 66 data, err := os.ReadFile(filePath) 67 if err != nil { 68 if os.IsNotExist(err) { 69 return nil, fs.ErrNotExist 70 } 71 return nil, fmt.Errorf("failed to read file: %w", err) 72 } 73 74 return data, nil 75} 76 77func (s *StreamplaceCertStorage) Delete(ctx context.Context, key string) error { 78 filePath := s.keyToPath(key) 79 80 // Check if it's a directory (prefix of other keys) 81 info, err := 
os.Stat(filePath) 82 if err != nil { 83 if os.IsNotExist(err) { 84 return fs.ErrNotExist 85 } 86 return fmt.Errorf("failed to stat file: %w", err) 87 } 88 89 if info.IsDir() { 90 err = os.RemoveAll(filePath) 91 } else { 92 err = os.Remove(filePath) 93 } 94 95 if err != nil && !os.IsNotExist(err) { 96 return fmt.Errorf("failed to delete: %w", err) 97 } 98 99 return nil 100} 101 102func (s *StreamplaceCertStorage) Exists(ctx context.Context, key string) bool { 103 filePath := s.keyToPath(key) 104 _, err := os.Stat(filePath) 105 return err == nil 106} 107 108func (s *StreamplaceCertStorage) List(ctx context.Context, prefix string, recursive bool) ([]string, error) { 109 dirPath := s.keyToPath(prefix) 110 111 var keys []string 112 113 if recursive { 114 err := filepath.WalkDir(dirPath, func(path string, d fs.DirEntry, err error) error { 115 if err != nil { 116 return err 117 } 118 119 // Convert back to key format 120 if relPath, err := filepath.Rel(s.Path, path); err == nil { 121 key := s.pathToKey(relPath) 122 if key != prefix && strings.HasPrefix(key, prefix) { 123 keys = append(keys, key) 124 } 125 } 126 127 return nil 128 }) 129 130 if err != nil && !os.IsNotExist(err) { 131 return nil, fmt.Errorf("failed to walk directory: %w", err) 132 } 133 } else { 134 entries, err := os.ReadDir(dirPath) 135 if err != nil { 136 if os.IsNotExist(err) { 137 return keys, nil 138 } 139 return nil, fmt.Errorf("failed to read directory: %w", err) 140 } 141 142 for _, entry := range entries { 143 key := prefix 144 if key != "" && !strings.HasSuffix(key, "/") { 145 key += "/" 146 } 147 key += entry.Name() 148 keys = append(keys, key) 149 } 150 } 151 152 return keys, nil 153} 154 155func (s *StreamplaceCertStorage) Stat(ctx context.Context, key string) (certmagic.KeyInfo, error) { 156 filePath := s.keyToPath(key) 157 158 info, err := os.Stat(filePath) 159 if err != nil { 160 if os.IsNotExist(err) { 161 return certmagic.KeyInfo{}, fs.ErrNotExist 162 } 163 return certmagic.KeyInfo{}, 
fmt.Errorf("failed to stat file: %w", err) 164 } 165 166 return certmagic.KeyInfo{ 167 Key: key, 168 Modified: info.ModTime(), 169 Size: info.Size(), 170 IsTerminal: !info.IsDir(), 171 }, nil 172} 173 174func (s *StreamplaceCertStorage) Lock(ctx context.Context, name string) error { 175 s.locksMu.Lock() 176 defer s.locksMu.Unlock() 177 178 // Check if we already have this lock 179 if _, exists := s.locks[name]; exists { 180 return fmt.Errorf("lock %s already held by this process", name) 181 } 182 183 lockPath := s.keyToPath("locks/" + name + ".lock") 184 185 // Ensure lock directory exists 186 if err := os.MkdirAll(filepath.Dir(lockPath), 0755); err != nil { 187 return fmt.Errorf("failed to create lock directory: %w", err) 188 } 189 190 // Try to create the lock file exclusively 191 lockFile, err := os.OpenFile(lockPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0644) 192 if err != nil { 193 if os.IsExist(err) { 194 // Lock file exists, check if it's stale 195 if s.isLockStale(lockPath) { 196 // Remove stale lock and try again 197 os.Remove(lockPath) 198 lockFile, err = os.OpenFile(lockPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0644) 199 if err != nil { 200 return fmt.Errorf("failed to acquire lock after removing stale lock: %w", err) 201 } 202 } else { 203 return fmt.Errorf("lock is held by another process") 204 } 205 } else { 206 return fmt.Errorf("failed to create lock file: %w", err) 207 } 208 } 209 210 // Write lock info with timestamp 211 lockInfo := map[string]any{ 212 "pid": os.Getpid(), 213 "timestamp": time.Now().Unix(), 214 } 215 216 lockData, err := json.Marshal(lockInfo) 217 if err != nil { 218 lockFile.Close() 219 os.Remove(lockPath) 220 return fmt.Errorf("failed to marshal lock info: %w", err) 221 } 222 _, err = lockFile.Write(lockData) 223 if err != nil { 224 lockFile.Close() 225 os.Remove(lockPath) 226 return fmt.Errorf("failed to write lock info: %w", err) 227 } 228 229 lockFile.Close() 230 231 // Create a file lock struct and start the renewal goroutine 
232 ctx, cancel := context.WithCancel(ctx) 233 lock := &fileLock{ 234 name: name, 235 path: lockPath, 236 cancel: cancel, 237 done: make(chan struct{}), 238 } 239 240 s.locks[name] = lock 241 242 // Start renewal goroutine to update timestamp periodically 243 go s.renewLock(ctx, lock) 244 245 return nil 246} 247 248func (s *StreamplaceCertStorage) Unlock(ctx context.Context, name string) error { 249 s.locksMu.Lock() 250 defer s.locksMu.Unlock() 251 252 lock, exists := s.locks[name] 253 if !exists { 254 return fmt.Errorf("lock %s not held by this process", name) 255 } 256 257 lock.cancel() 258 259 <-lock.done 260 261 err := os.Remove(lock.path) 262 if err != nil && !os.IsNotExist(err) { 263 return fmt.Errorf("failed to remove lock file: %w", err) 264 } 265 266 delete(s.locks, name) 267 268 return nil 269} 270 271func (s *StreamplaceCertStorage) keyToPath(key string) string { 272 return filepath.Join(s.Path, filepath.FromSlash(key)) 273} 274func (s *StreamplaceCertStorage) pathToKey(path string) string { 275 return filepath.ToSlash(path) 276} 277func (s *StreamplaceCertStorage) isLockStale(lockPath string) bool { 278 data, err := os.ReadFile(lockPath) 279 if err != nil { 280 return true 281 } 282 283 var lockInfo map[string]any 284 if err := json.Unmarshal(data, &lockInfo); err != nil { 285 return true 286 } 287 288 timestamp, ok := lockInfo["timestamp"].(float64) 289 if !ok { 290 return true 291 } 292 293 lockTime := time.Unix(int64(timestamp), 0) 294 return time.Since(lockTime) > 30*time.Second 295} 296func (s *StreamplaceCertStorage) renewLock(ctx context.Context, lock *fileLock) { 297 defer close(lock.done) 298 299 ticker := time.NewTicker(5 * time.Second) 300 defer ticker.Stop() 301 302 for { 303 select { 304 case <-ctx.Done(): 305 return 306 case <-ticker.C: 307 // update the lock file timestamp 308 lockInfo := map[string]any{ 309 "pid": os.Getpid(), 310 "timestamp": time.Now().Unix(), 311 } 312 313 if lockData, err := json.Marshal(lockInfo); err == nil { 314 
err = os.WriteFile(lock.path, lockData, 0644) 315 if err != nil { 316 // lock is probably stale, remove it 317 s.locksMu.Lock() 318 delete(s.locks, lock.name) 319 s.locksMu.Unlock() 320 return 321 } 322 } 323 } 324 } 325}