Vibe-guided bskyoauth and custom repo example code in Golang — probably not safe to use in production.
1package bskyoauth
2
3import (
4 "context"
5 "encoding/json"
6 "os"
7 "path/filepath"
8 "strings"
9 "sync"
10 "testing"
11 "time"
12)
13
14// TestNoOpAuditLogger verifies the default no-op logger doesn't cause errors.
15func TestNoOpAuditLogger(t *testing.T) {
16 logger := &NoOpAuditLogger{}
17 ctx := context.Background()
18
19 event := AuditEvent{
20 Timestamp: time.Now().UTC(),
21 EventType: AuditEventAuthStart,
22 Action: "test_action",
23 Result: AuditResultSuccess,
24 }
25
26 err := logger.Log(ctx, event)
27 if err != nil {
28 t.Errorf("NoOpAuditLogger.Log() returned error: %v", err)
29 }
30}
31
32// TestSetAuditLogger verifies setting and getting the global audit logger.
33func TestSetAuditLogger(t *testing.T) {
34 // Save original logger
35 originalLogger := GetAuditLogger()
36 defer SetAuditLogger(originalLogger)
37
38 // Test setting a custom logger
39 customLogger := &NoOpAuditLogger{}
40 SetAuditLogger(customLogger)
41
42 retrieved := GetAuditLogger()
43 if retrieved != customLogger {
44 t.Error("GetAuditLogger() did not return the set logger")
45 }
46
47 // Test setting nil (should revert to no-op)
48 SetAuditLogger(nil)
49 retrieved = GetAuditLogger()
50 if _, ok := retrieved.(*NoOpAuditLogger); !ok {
51 t.Error("Setting nil should revert to NoOpAuditLogger")
52 }
53}
54
55// TestLogAuditEvent verifies the convenience function enriches events.
56func TestLogAuditEvent(t *testing.T) {
57 // Save original logger
58 originalLogger := GetAuditLogger()
59 defer SetAuditLogger(originalLogger)
60
61 // Create mock logger
62 var capturedEvent AuditEvent
63 mockLogger := &mockAuditLogger{
64 logFunc: func(ctx context.Context, event AuditEvent) error {
65 capturedEvent = event
66 return nil
67 },
68 }
69 SetAuditLogger(mockLogger)
70
71 // Create context with request ID and session ID
72 ctx := context.Background()
73 ctx = context.WithValue(ctx, ContextKeyRequestID, "test-request-123")
74 ctx = context.WithValue(ctx, ContextKeySessionID, "test-session-456")
75
76 // Log event without timestamp
77 event := AuditEvent{
78 EventType: AuditEventPostCreate,
79 Actor: "did:plc:test123",
80 Action: "create_post",
81 Result: AuditResultSuccess,
82 }
83
84 err := LogAuditEvent(ctx, event)
85 if err != nil {
86 t.Fatalf("LogAuditEvent() returned error: %v", err)
87 }
88
89 // Verify enrichment
90 if capturedEvent.Timestamp.IsZero() {
91 t.Error("LogAuditEvent should set Timestamp if not provided")
92 }
93 if capturedEvent.RequestID != "test-request-123" {
94 t.Errorf("Expected RequestID 'test-request-123', got '%s'", capturedEvent.RequestID)
95 }
96 if capturedEvent.SessionID != "test-session-456" {
97 t.Errorf("Expected SessionID 'test-session-456', got '%s'", capturedEvent.SessionID)
98 }
99 if capturedEvent.EventType != AuditEventPostCreate {
100 t.Errorf("Event type should be preserved")
101 }
102}
103
104// TestFileAuditLogger tests basic file logging functionality.
105func TestFileAuditLogger(t *testing.T) {
106 // Create temp directory
107 tmpDir := t.TempDir()
108 logFile := filepath.Join(tmpDir, "audit.log")
109
110 // Create logger
111 logger, err := NewFileAuditLogger(logFile)
112 if err != nil {
113 t.Fatalf("NewFileAuditLogger() failed: %v", err)
114 }
115 defer logger.Close()
116
117 ctx := context.Background()
118
119 // Log some events
120 events := []AuditEvent{
121 {
122 Timestamp: time.Now().UTC(),
123 EventType: AuditEventAuthStart,
124 Action: "start_oauth_flow",
125 Resource: "user.bsky.social",
126 Result: AuditResultSuccess,
127 },
128 {
129 Timestamp: time.Now().UTC(),
130 EventType: AuditEventAuthSuccess,
131 Actor: "did:plc:test123",
132 Action: "complete_oauth_flow",
133 Result: AuditResultSuccess,
134 },
135 {
136 Timestamp: time.Now().UTC(),
137 EventType: AuditEventPostCreate,
138 Actor: "did:plc:test123",
139 Action: "create_post",
140 Resource: "at://did:plc:test123/app.bsky.feed.post/abc123",
141 Result: AuditResultSuccess,
142 },
143 }
144
145 for _, event := range events {
146 if err := logger.Log(ctx, event); err != nil {
147 t.Errorf("Log() failed: %v", err)
148 }
149 }
150
151 // Close to flush
152 if err := logger.Close(); err != nil {
153 t.Errorf("Close() failed: %v", err)
154 }
155
156 // Read and verify log file
157 data, err := os.ReadFile(logFile)
158 if err != nil {
159 t.Fatalf("Failed to read log file: %v", err)
160 }
161
162 lines := strings.Split(strings.TrimSpace(string(data)), "\n")
163 if len(lines) != len(events) {
164 t.Errorf("Expected %d log lines, got %d", len(events), len(lines))
165 }
166
167 // Verify each line is valid JSON and contains expected data
168 for i, line := range lines {
169 var loggedEvent AuditEvent
170 if err := json.Unmarshal([]byte(line), &loggedEvent); err != nil {
171 t.Errorf("Line %d is not valid JSON: %v", i+1, err)
172 continue
173 }
174
175 if loggedEvent.EventType != events[i].EventType {
176 t.Errorf("Event %d: expected type %s, got %s", i, events[i].EventType, loggedEvent.EventType)
177 }
178 if loggedEvent.Action != events[i].Action {
179 t.Errorf("Event %d: expected action %s, got %s", i, events[i].Action, loggedEvent.Action)
180 }
181 }
182}
183
184// TestFileAuditLoggerConcurrent tests thread-safety of FileAuditLogger.
185func TestFileAuditLoggerConcurrent(t *testing.T) {
186 tmpDir := t.TempDir()
187 logFile := filepath.Join(tmpDir, "audit.log")
188
189 logger, err := NewFileAuditLogger(logFile)
190 if err != nil {
191 t.Fatalf("NewFileAuditLogger() failed: %v", err)
192 }
193 defer logger.Close()
194
195 ctx := context.Background()
196 const numGoroutines = 10
197 const eventsPerGoroutine = 10
198
199 var wg sync.WaitGroup
200 wg.Add(numGoroutines)
201
202 for i := 0; i < numGoroutines; i++ {
203 go func(id int) {
204 defer wg.Done()
205 for j := 0; j < eventsPerGoroutine; j++ {
206 event := AuditEvent{
207 Timestamp: time.Now().UTC(),
208 EventType: AuditEventPostCreate,
209 Actor: "did:plc:test",
210 Action: "concurrent_test",
211 Metadata: map[string]interface{}{
212 "goroutine": id,
213 "event": j,
214 },
215 Result: AuditResultSuccess,
216 }
217 if err := logger.Log(ctx, event); err != nil {
218 t.Errorf("Log() failed: %v", err)
219 }
220 }
221 }(i)
222 }
223
224 wg.Wait()
225
226 // Close to flush
227 if err := logger.Close(); err != nil {
228 t.Errorf("Close() failed: %v", err)
229 }
230
231 // Verify we got all events
232 data, err := os.ReadFile(logFile)
233 if err != nil {
234 t.Fatalf("Failed to read log file: %v", err)
235 }
236
237 lines := strings.Split(strings.TrimSpace(string(data)), "\n")
238 expectedLines := numGoroutines * eventsPerGoroutine
239 if len(lines) != expectedLines {
240 t.Errorf("Expected %d log lines, got %d", expectedLines, len(lines))
241 }
242
243 // Verify all lines are valid JSON
244 for i, line := range lines {
245 var event AuditEvent
246 if err := json.Unmarshal([]byte(line), &event); err != nil {
247 t.Errorf("Line %d is not valid JSON: %v", i+1, err)
248 }
249 }
250}
251
252// TestFileAuditLoggerDirectoryCreation verifies directory creation.
253func TestFileAuditLoggerDirectoryCreation(t *testing.T) {
254 tmpDir := t.TempDir()
255 logFile := filepath.Join(tmpDir, "nested", "dir", "audit.log")
256
257 logger, err := NewFileAuditLogger(logFile)
258 if err != nil {
259 t.Fatalf("NewFileAuditLogger() failed: %v", err)
260 }
261 defer logger.Close()
262
263 // Verify directory was created
264 dir := filepath.Dir(logFile)
265 if _, err := os.Stat(dir); os.IsNotExist(err) {
266 t.Error("Directory should have been created")
267 }
268
269 // Verify file was created
270 if _, err := os.Stat(logFile); os.IsNotExist(err) {
271 t.Error("Log file should have been created")
272 }
273}
274
275// TestRotatingFileAuditLogger tests daily log rotation.
276func TestRotatingFileAuditLogger(t *testing.T) {
277 tmpDir := t.TempDir()
278
279 logger, err := NewRotatingFileAuditLogger(tmpDir)
280 if err != nil {
281 t.Fatalf("NewRotatingFileAuditLogger() failed: %v", err)
282 }
283 defer logger.Close()
284
285 ctx := context.Background()
286
287 // Log an event
288 event := AuditEvent{
289 Timestamp: time.Now().UTC(),
290 EventType: AuditEventAuthStart,
291 Action: "test_action",
292 Result: AuditResultSuccess,
293 }
294
295 if err := logger.Log(ctx, event); err != nil {
296 t.Errorf("Log() failed: %v", err)
297 }
298
299 // Verify file was created with today's date
300 expectedDate := time.Now().UTC().Format("2006-01-02")
301 expectedFile := filepath.Join(tmpDir, "audit-"+expectedDate+".log")
302
303 if _, err := os.Stat(expectedFile); os.IsNotExist(err) {
304 t.Errorf("Expected log file %s does not exist", expectedFile)
305 }
306
307 // Read and verify content
308 data, err := os.ReadFile(expectedFile)
309 if err != nil {
310 t.Fatalf("Failed to read log file: %v", err)
311 }
312
313 var loggedEvent AuditEvent
314 if err := json.Unmarshal(data, &loggedEvent); err != nil {
315 t.Errorf("Log file does not contain valid JSON: %v", err)
316 }
317
318 if loggedEvent.EventType != AuditEventAuthStart {
319 t.Errorf("Expected event type %s, got %s", AuditEventAuthStart, loggedEvent.EventType)
320 }
321}
322
323// TestRotatingFileAuditLoggerRotation verifies rotation behavior.
324func TestRotatingFileAuditLoggerRotation(t *testing.T) {
325 tmpDir := t.TempDir()
326
327 logger, err := NewRotatingFileAuditLogger(tmpDir)
328 if err != nil {
329 t.Fatalf("NewRotatingFileAuditLogger() failed: %v", err)
330 }
331 defer logger.Close()
332
333 ctx := context.Background()
334
335 // Get initial file name
336 initialDate := time.Now().UTC().Format("2006-01-02")
337 initialFile := filepath.Join(tmpDir, "audit-"+initialDate+".log")
338
339 // Log event with current date
340 event1 := AuditEvent{
341 Timestamp: time.Now().UTC(),
342 EventType: AuditEventAuthStart,
343 Action: "before_rotation",
344 Result: AuditResultSuccess,
345 }
346 if err := logger.Log(ctx, event1); err != nil {
347 t.Errorf("Log() failed: %v", err)
348 }
349
350 // Manually change the current date to trigger rotation on next log
351 logger.mu.Lock()
352 logger.currentDate = "2023-01-01" // Old date to force rotation
353 logger.mu.Unlock()
354
355 // Log event (should trigger rotation to current date)
356 event2 := AuditEvent{
357 Timestamp: time.Now().UTC(),
358 EventType: AuditEventAuthSuccess,
359 Action: "after_rotation",
360 Result: AuditResultSuccess,
361 }
362 if err := logger.Log(ctx, event2); err != nil {
363 t.Errorf("Log() after rotation failed: %v", err)
364 }
365
366 // Close to ensure all writes are flushed
367 if err := logger.Close(); err != nil {
368 t.Errorf("Close() failed: %v", err)
369 }
370
371 // Verify initial file still exists with first event
372 data, err := os.ReadFile(initialFile)
373 if err != nil {
374 t.Fatalf("Failed to read initial file %s: %v", initialFile, err)
375 }
376
377 if !strings.Contains(string(data), "before_rotation") {
378 t.Error("Initial file should contain first event")
379 }
380 if !strings.Contains(string(data), "after_rotation") {
381 t.Error("Initial file should contain second event after rotation back to same date")
382 }
383}
384
385// TestRotatingFileAuditLoggerConcurrent tests thread-safety during rotation.
386func TestRotatingFileAuditLoggerConcurrent(t *testing.T) {
387 tmpDir := t.TempDir()
388
389 logger, err := NewRotatingFileAuditLogger(tmpDir)
390 if err != nil {
391 t.Fatalf("NewRotatingFileAuditLogger() failed: %v", err)
392 }
393 defer logger.Close()
394
395 ctx := context.Background()
396 const numGoroutines = 10
397 const eventsPerGoroutine = 10
398
399 var wg sync.WaitGroup
400 wg.Add(numGoroutines)
401
402 for i := 0; i < numGoroutines; i++ {
403 go func(id int) {
404 defer wg.Done()
405 for j := 0; j < eventsPerGoroutine; j++ {
406 event := AuditEvent{
407 Timestamp: time.Now().UTC(),
408 EventType: AuditEventPostCreate,
409 Actor: "did:plc:test",
410 Action: "concurrent_rotation_test",
411 Metadata: map[string]interface{}{
412 "goroutine": id,
413 "event": j,
414 },
415 Result: AuditResultSuccess,
416 }
417 if err := logger.Log(ctx, event); err != nil {
418 t.Errorf("Log() failed: %v", err)
419 }
420 }
421 }(i)
422 }
423
424 wg.Wait()
425
426 // Close to flush
427 if err := logger.Close(); err != nil {
428 t.Errorf("Close() failed: %v", err)
429 }
430
431 // Verify log file exists and has all events
432 expectedDate := time.Now().UTC().Format("2006-01-02")
433 expectedFile := filepath.Join(tmpDir, "audit-"+expectedDate+".log")
434
435 data, err := os.ReadFile(expectedFile)
436 if err != nil {
437 t.Fatalf("Failed to read log file: %v", err)
438 }
439
440 lines := strings.Split(strings.TrimSpace(string(data)), "\n")
441 expectedLines := numGoroutines * eventsPerGoroutine
442 if len(lines) != expectedLines {
443 t.Errorf("Expected %d log lines, got %d", expectedLines, len(lines))
444 }
445}
446
447// TestAuditEventStructure verifies AuditEvent fields are properly serialized.
448func TestAuditEventStructure(t *testing.T) {
449 event := AuditEvent{
450 Timestamp: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
451 EventType: AuditEventPostCreate,
452 Actor: "did:plc:test123",
453 Action: "create_post",
454 Resource: "at://did:plc:test123/app.bsky.feed.post/abc123",
455 Result: AuditResultSuccess,
456 Error: "",
457 Metadata: map[string]interface{}{
458 "ip_address": "192.168.1.1",
459 "user_agent": "MyApp/1.0",
460 },
461 RequestID: "req-123",
462 SessionID: "sess-456",
463 }
464
465 // Marshal to JSON
466 data, err := json.Marshal(event)
467 if err != nil {
468 t.Fatalf("Failed to marshal event: %v", err)
469 }
470
471 // Unmarshal and verify
472 var decoded AuditEvent
473 if err := json.Unmarshal(data, &decoded); err != nil {
474 t.Fatalf("Failed to unmarshal event: %v", err)
475 }
476
477 if decoded.EventType != event.EventType {
478 t.Errorf("EventType mismatch: expected %s, got %s", event.EventType, decoded.EventType)
479 }
480 if decoded.Actor != event.Actor {
481 t.Errorf("Actor mismatch: expected %s, got %s", event.Actor, decoded.Actor)
482 }
483 if decoded.Action != event.Action {
484 t.Errorf("Action mismatch: expected %s, got %s", event.Action, decoded.Action)
485 }
486 if decoded.Resource != event.Resource {
487 t.Errorf("Resource mismatch: expected %s, got %s", event.Resource, decoded.Resource)
488 }
489 if decoded.Result != event.Result {
490 t.Errorf("Result mismatch: expected %s, got %s", event.Result, decoded.Result)
491 }
492 if decoded.RequestID != event.RequestID {
493 t.Errorf("RequestID mismatch: expected %s, got %s", event.RequestID, decoded.RequestID)
494 }
495 if decoded.SessionID != event.SessionID {
496 t.Errorf("SessionID mismatch: expected %s, got %s", event.SessionID, decoded.SessionID)
497 }
498
499 // Verify metadata
500 if decoded.Metadata["ip_address"] != "192.168.1.1" {
501 t.Error("Metadata ip_address mismatch")
502 }
503 if decoded.Metadata["user_agent"] != "MyApp/1.0" {
504 t.Error("Metadata user_agent mismatch")
505 }
506}
507
508// mockAuditLogger is a test helper for capturing logged events.
509type mockAuditLogger struct {
510 logFunc func(ctx context.Context, event AuditEvent) error
511}
512
513func (m *mockAuditLogger) Log(ctx context.Context, event AuditEvent) error {
514 if m.logFunc != nil {
515 return m.logFunc(ctx, event)
516 }
517 return nil
518}