Vibe-guided bskyoauth and custom repo example code in Golang — probably not safe to use in prod
at main 14 kB view raw
1package bskyoauth 2 3import ( 4 "context" 5 "encoding/json" 6 "os" 7 "path/filepath" 8 "strings" 9 "sync" 10 "testing" 11 "time" 12) 13 14// TestNoOpAuditLogger verifies the default no-op logger doesn't cause errors. 15func TestNoOpAuditLogger(t *testing.T) { 16 logger := &NoOpAuditLogger{} 17 ctx := context.Background() 18 19 event := AuditEvent{ 20 Timestamp: time.Now().UTC(), 21 EventType: AuditEventAuthStart, 22 Action: "test_action", 23 Result: AuditResultSuccess, 24 } 25 26 err := logger.Log(ctx, event) 27 if err != nil { 28 t.Errorf("NoOpAuditLogger.Log() returned error: %v", err) 29 } 30} 31 32// TestSetAuditLogger verifies setting and getting the global audit logger. 33func TestSetAuditLogger(t *testing.T) { 34 // Save original logger 35 originalLogger := GetAuditLogger() 36 defer SetAuditLogger(originalLogger) 37 38 // Test setting a custom logger 39 customLogger := &NoOpAuditLogger{} 40 SetAuditLogger(customLogger) 41 42 retrieved := GetAuditLogger() 43 if retrieved != customLogger { 44 t.Error("GetAuditLogger() did not return the set logger") 45 } 46 47 // Test setting nil (should revert to no-op) 48 SetAuditLogger(nil) 49 retrieved = GetAuditLogger() 50 if _, ok := retrieved.(*NoOpAuditLogger); !ok { 51 t.Error("Setting nil should revert to NoOpAuditLogger") 52 } 53} 54 55// TestLogAuditEvent verifies the convenience function enriches events. 
56func TestLogAuditEvent(t *testing.T) { 57 // Save original logger 58 originalLogger := GetAuditLogger() 59 defer SetAuditLogger(originalLogger) 60 61 // Create mock logger 62 var capturedEvent AuditEvent 63 mockLogger := &mockAuditLogger{ 64 logFunc: func(ctx context.Context, event AuditEvent) error { 65 capturedEvent = event 66 return nil 67 }, 68 } 69 SetAuditLogger(mockLogger) 70 71 // Create context with request ID and session ID 72 ctx := context.Background() 73 ctx = context.WithValue(ctx, ContextKeyRequestID, "test-request-123") 74 ctx = context.WithValue(ctx, ContextKeySessionID, "test-session-456") 75 76 // Log event without timestamp 77 event := AuditEvent{ 78 EventType: AuditEventPostCreate, 79 Actor: "did:plc:test123", 80 Action: "create_post", 81 Result: AuditResultSuccess, 82 } 83 84 err := LogAuditEvent(ctx, event) 85 if err != nil { 86 t.Fatalf("LogAuditEvent() returned error: %v", err) 87 } 88 89 // Verify enrichment 90 if capturedEvent.Timestamp.IsZero() { 91 t.Error("LogAuditEvent should set Timestamp if not provided") 92 } 93 if capturedEvent.RequestID != "test-request-123" { 94 t.Errorf("Expected RequestID 'test-request-123', got '%s'", capturedEvent.RequestID) 95 } 96 if capturedEvent.SessionID != "test-session-456" { 97 t.Errorf("Expected SessionID 'test-session-456', got '%s'", capturedEvent.SessionID) 98 } 99 if capturedEvent.EventType != AuditEventPostCreate { 100 t.Errorf("Event type should be preserved") 101 } 102} 103 104// TestFileAuditLogger tests basic file logging functionality. 
105func TestFileAuditLogger(t *testing.T) { 106 // Create temp directory 107 tmpDir := t.TempDir() 108 logFile := filepath.Join(tmpDir, "audit.log") 109 110 // Create logger 111 logger, err := NewFileAuditLogger(logFile) 112 if err != nil { 113 t.Fatalf("NewFileAuditLogger() failed: %v", err) 114 } 115 defer logger.Close() 116 117 ctx := context.Background() 118 119 // Log some events 120 events := []AuditEvent{ 121 { 122 Timestamp: time.Now().UTC(), 123 EventType: AuditEventAuthStart, 124 Action: "start_oauth_flow", 125 Resource: "user.bsky.social", 126 Result: AuditResultSuccess, 127 }, 128 { 129 Timestamp: time.Now().UTC(), 130 EventType: AuditEventAuthSuccess, 131 Actor: "did:plc:test123", 132 Action: "complete_oauth_flow", 133 Result: AuditResultSuccess, 134 }, 135 { 136 Timestamp: time.Now().UTC(), 137 EventType: AuditEventPostCreate, 138 Actor: "did:plc:test123", 139 Action: "create_post", 140 Resource: "at://did:plc:test123/app.bsky.feed.post/abc123", 141 Result: AuditResultSuccess, 142 }, 143 } 144 145 for _, event := range events { 146 if err := logger.Log(ctx, event); err != nil { 147 t.Errorf("Log() failed: %v", err) 148 } 149 } 150 151 // Close to flush 152 if err := logger.Close(); err != nil { 153 t.Errorf("Close() failed: %v", err) 154 } 155 156 // Read and verify log file 157 data, err := os.ReadFile(logFile) 158 if err != nil { 159 t.Fatalf("Failed to read log file: %v", err) 160 } 161 162 lines := strings.Split(strings.TrimSpace(string(data)), "\n") 163 if len(lines) != len(events) { 164 t.Errorf("Expected %d log lines, got %d", len(events), len(lines)) 165 } 166 167 // Verify each line is valid JSON and contains expected data 168 for i, line := range lines { 169 var loggedEvent AuditEvent 170 if err := json.Unmarshal([]byte(line), &loggedEvent); err != nil { 171 t.Errorf("Line %d is not valid JSON: %v", i+1, err) 172 continue 173 } 174 175 if loggedEvent.EventType != events[i].EventType { 176 t.Errorf("Event %d: expected type %s, got %s", i, 
events[i].EventType, loggedEvent.EventType) 177 } 178 if loggedEvent.Action != events[i].Action { 179 t.Errorf("Event %d: expected action %s, got %s", i, events[i].Action, loggedEvent.Action) 180 } 181 } 182} 183 184// TestFileAuditLoggerConcurrent tests thread-safety of FileAuditLogger. 185func TestFileAuditLoggerConcurrent(t *testing.T) { 186 tmpDir := t.TempDir() 187 logFile := filepath.Join(tmpDir, "audit.log") 188 189 logger, err := NewFileAuditLogger(logFile) 190 if err != nil { 191 t.Fatalf("NewFileAuditLogger() failed: %v", err) 192 } 193 defer logger.Close() 194 195 ctx := context.Background() 196 const numGoroutines = 10 197 const eventsPerGoroutine = 10 198 199 var wg sync.WaitGroup 200 wg.Add(numGoroutines) 201 202 for i := 0; i < numGoroutines; i++ { 203 go func(id int) { 204 defer wg.Done() 205 for j := 0; j < eventsPerGoroutine; j++ { 206 event := AuditEvent{ 207 Timestamp: time.Now().UTC(), 208 EventType: AuditEventPostCreate, 209 Actor: "did:plc:test", 210 Action: "concurrent_test", 211 Metadata: map[string]interface{}{ 212 "goroutine": id, 213 "event": j, 214 }, 215 Result: AuditResultSuccess, 216 } 217 if err := logger.Log(ctx, event); err != nil { 218 t.Errorf("Log() failed: %v", err) 219 } 220 } 221 }(i) 222 } 223 224 wg.Wait() 225 226 // Close to flush 227 if err := logger.Close(); err != nil { 228 t.Errorf("Close() failed: %v", err) 229 } 230 231 // Verify we got all events 232 data, err := os.ReadFile(logFile) 233 if err != nil { 234 t.Fatalf("Failed to read log file: %v", err) 235 } 236 237 lines := strings.Split(strings.TrimSpace(string(data)), "\n") 238 expectedLines := numGoroutines * eventsPerGoroutine 239 if len(lines) != expectedLines { 240 t.Errorf("Expected %d log lines, got %d", expectedLines, len(lines)) 241 } 242 243 // Verify all lines are valid JSON 244 for i, line := range lines { 245 var event AuditEvent 246 if err := json.Unmarshal([]byte(line), &event); err != nil { 247 t.Errorf("Line %d is not valid JSON: %v", i+1, err) 
248 } 249 } 250} 251 252// TestFileAuditLoggerDirectoryCreation verifies directory creation. 253func TestFileAuditLoggerDirectoryCreation(t *testing.T) { 254 tmpDir := t.TempDir() 255 logFile := filepath.Join(tmpDir, "nested", "dir", "audit.log") 256 257 logger, err := NewFileAuditLogger(logFile) 258 if err != nil { 259 t.Fatalf("NewFileAuditLogger() failed: %v", err) 260 } 261 defer logger.Close() 262 263 // Verify directory was created 264 dir := filepath.Dir(logFile) 265 if _, err := os.Stat(dir); os.IsNotExist(err) { 266 t.Error("Directory should have been created") 267 } 268 269 // Verify file was created 270 if _, err := os.Stat(logFile); os.IsNotExist(err) { 271 t.Error("Log file should have been created") 272 } 273} 274 275// TestRotatingFileAuditLogger tests daily log rotation. 276func TestRotatingFileAuditLogger(t *testing.T) { 277 tmpDir := t.TempDir() 278 279 logger, err := NewRotatingFileAuditLogger(tmpDir) 280 if err != nil { 281 t.Fatalf("NewRotatingFileAuditLogger() failed: %v", err) 282 } 283 defer logger.Close() 284 285 ctx := context.Background() 286 287 // Log an event 288 event := AuditEvent{ 289 Timestamp: time.Now().UTC(), 290 EventType: AuditEventAuthStart, 291 Action: "test_action", 292 Result: AuditResultSuccess, 293 } 294 295 if err := logger.Log(ctx, event); err != nil { 296 t.Errorf("Log() failed: %v", err) 297 } 298 299 // Verify file was created with today's date 300 expectedDate := time.Now().UTC().Format("2006-01-02") 301 expectedFile := filepath.Join(tmpDir, "audit-"+expectedDate+".log") 302 303 if _, err := os.Stat(expectedFile); os.IsNotExist(err) { 304 t.Errorf("Expected log file %s does not exist", expectedFile) 305 } 306 307 // Read and verify content 308 data, err := os.ReadFile(expectedFile) 309 if err != nil { 310 t.Fatalf("Failed to read log file: %v", err) 311 } 312 313 var loggedEvent AuditEvent 314 if err := json.Unmarshal(data, &loggedEvent); err != nil { 315 t.Errorf("Log file does not contain valid JSON: %v", err) 
316 } 317 318 if loggedEvent.EventType != AuditEventAuthStart { 319 t.Errorf("Expected event type %s, got %s", AuditEventAuthStart, loggedEvent.EventType) 320 } 321} 322 323// TestRotatingFileAuditLoggerRotation verifies rotation behavior. 324func TestRotatingFileAuditLoggerRotation(t *testing.T) { 325 tmpDir := t.TempDir() 326 327 logger, err := NewRotatingFileAuditLogger(tmpDir) 328 if err != nil { 329 t.Fatalf("NewRotatingFileAuditLogger() failed: %v", err) 330 } 331 defer logger.Close() 332 333 ctx := context.Background() 334 335 // Get initial file name 336 initialDate := time.Now().UTC().Format("2006-01-02") 337 initialFile := filepath.Join(tmpDir, "audit-"+initialDate+".log") 338 339 // Log event with current date 340 event1 := AuditEvent{ 341 Timestamp: time.Now().UTC(), 342 EventType: AuditEventAuthStart, 343 Action: "before_rotation", 344 Result: AuditResultSuccess, 345 } 346 if err := logger.Log(ctx, event1); err != nil { 347 t.Errorf("Log() failed: %v", err) 348 } 349 350 // Manually change the current date to trigger rotation on next log 351 logger.mu.Lock() 352 logger.currentDate = "2023-01-01" // Old date to force rotation 353 logger.mu.Unlock() 354 355 // Log event (should trigger rotation to current date) 356 event2 := AuditEvent{ 357 Timestamp: time.Now().UTC(), 358 EventType: AuditEventAuthSuccess, 359 Action: "after_rotation", 360 Result: AuditResultSuccess, 361 } 362 if err := logger.Log(ctx, event2); err != nil { 363 t.Errorf("Log() after rotation failed: %v", err) 364 } 365 366 // Close to ensure all writes are flushed 367 if err := logger.Close(); err != nil { 368 t.Errorf("Close() failed: %v", err) 369 } 370 371 // Verify initial file still exists with first event 372 data, err := os.ReadFile(initialFile) 373 if err != nil { 374 t.Fatalf("Failed to read initial file %s: %v", initialFile, err) 375 } 376 377 if !strings.Contains(string(data), "before_rotation") { 378 t.Error("Initial file should contain first event") 379 } 380 if 
!strings.Contains(string(data), "after_rotation") { 381 t.Error("Initial file should contain second event after rotation back to same date") 382 } 383} 384 385// TestRotatingFileAuditLoggerConcurrent tests thread-safety during rotation. 386func TestRotatingFileAuditLoggerConcurrent(t *testing.T) { 387 tmpDir := t.TempDir() 388 389 logger, err := NewRotatingFileAuditLogger(tmpDir) 390 if err != nil { 391 t.Fatalf("NewRotatingFileAuditLogger() failed: %v", err) 392 } 393 defer logger.Close() 394 395 ctx := context.Background() 396 const numGoroutines = 10 397 const eventsPerGoroutine = 10 398 399 var wg sync.WaitGroup 400 wg.Add(numGoroutines) 401 402 for i := 0; i < numGoroutines; i++ { 403 go func(id int) { 404 defer wg.Done() 405 for j := 0; j < eventsPerGoroutine; j++ { 406 event := AuditEvent{ 407 Timestamp: time.Now().UTC(), 408 EventType: AuditEventPostCreate, 409 Actor: "did:plc:test", 410 Action: "concurrent_rotation_test", 411 Metadata: map[string]interface{}{ 412 "goroutine": id, 413 "event": j, 414 }, 415 Result: AuditResultSuccess, 416 } 417 if err := logger.Log(ctx, event); err != nil { 418 t.Errorf("Log() failed: %v", err) 419 } 420 } 421 }(i) 422 } 423 424 wg.Wait() 425 426 // Close to flush 427 if err := logger.Close(); err != nil { 428 t.Errorf("Close() failed: %v", err) 429 } 430 431 // Verify log file exists and has all events 432 expectedDate := time.Now().UTC().Format("2006-01-02") 433 expectedFile := filepath.Join(tmpDir, "audit-"+expectedDate+".log") 434 435 data, err := os.ReadFile(expectedFile) 436 if err != nil { 437 t.Fatalf("Failed to read log file: %v", err) 438 } 439 440 lines := strings.Split(strings.TrimSpace(string(data)), "\n") 441 expectedLines := numGoroutines * eventsPerGoroutine 442 if len(lines) != expectedLines { 443 t.Errorf("Expected %d log lines, got %d", expectedLines, len(lines)) 444 } 445} 446 447// TestAuditEventStructure verifies AuditEvent fields are properly serialized. 
448func TestAuditEventStructure(t *testing.T) { 449 event := AuditEvent{ 450 Timestamp: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC), 451 EventType: AuditEventPostCreate, 452 Actor: "did:plc:test123", 453 Action: "create_post", 454 Resource: "at://did:plc:test123/app.bsky.feed.post/abc123", 455 Result: AuditResultSuccess, 456 Error: "", 457 Metadata: map[string]interface{}{ 458 "ip_address": "192.168.1.1", 459 "user_agent": "MyApp/1.0", 460 }, 461 RequestID: "req-123", 462 SessionID: "sess-456", 463 } 464 465 // Marshal to JSON 466 data, err := json.Marshal(event) 467 if err != nil { 468 t.Fatalf("Failed to marshal event: %v", err) 469 } 470 471 // Unmarshal and verify 472 var decoded AuditEvent 473 if err := json.Unmarshal(data, &decoded); err != nil { 474 t.Fatalf("Failed to unmarshal event: %v", err) 475 } 476 477 if decoded.EventType != event.EventType { 478 t.Errorf("EventType mismatch: expected %s, got %s", event.EventType, decoded.EventType) 479 } 480 if decoded.Actor != event.Actor { 481 t.Errorf("Actor mismatch: expected %s, got %s", event.Actor, decoded.Actor) 482 } 483 if decoded.Action != event.Action { 484 t.Errorf("Action mismatch: expected %s, got %s", event.Action, decoded.Action) 485 } 486 if decoded.Resource != event.Resource { 487 t.Errorf("Resource mismatch: expected %s, got %s", event.Resource, decoded.Resource) 488 } 489 if decoded.Result != event.Result { 490 t.Errorf("Result mismatch: expected %s, got %s", event.Result, decoded.Result) 491 } 492 if decoded.RequestID != event.RequestID { 493 t.Errorf("RequestID mismatch: expected %s, got %s", event.RequestID, decoded.RequestID) 494 } 495 if decoded.SessionID != event.SessionID { 496 t.Errorf("SessionID mismatch: expected %s, got %s", event.SessionID, decoded.SessionID) 497 } 498 499 // Verify metadata 500 if decoded.Metadata["ip_address"] != "192.168.1.1" { 501 t.Error("Metadata ip_address mismatch") 502 } 503 if decoded.Metadata["user_agent"] != "MyApp/1.0" { 504 t.Error("Metadata 
user_agent mismatch") 505 } 506} 507 508// mockAuditLogger is a test helper for capturing logged events. 509type mockAuditLogger struct { 510 logFunc func(ctx context.Context, event AuditEvent) error 511} 512 513func (m *mockAuditLogger) Log(ctx context.Context, event AuditEvent) error { 514 if m.logFunc != nil { 515 return m.logFunc(ctx, event) 516 } 517 return nil 518}