1package main
2
3import (
4 "encoding/json"
5 "net"
6 "net/http"
7 "net/http/httptest"
8 "testing"
9 "time"
10
11 "github.com/miekg/dns"
12)
13
14// Mock DID document for testing
15var mockDIDDocument = DIDDocument{
16 ID: "did:plc:test123",
17 AlsoKnownAs: []string{
18 "at://test.bsky.social",
19 },
20 VerificationMethod: []VerificationMethod{
21 {
22 ID: "did:plc:test123#atproto",
23 Type: "Multikey",
24 Controller: "did:plc:test123",
25 PublicKeyMultibase: "zQ3shXjHeiBuRCKmM36cuYnm7YEMzhGnCmCyW92sRJ9przSMS",
26 },
27 },
28 Service: []Service{
29 {
30 ID: "#atproto_pds",
31 Type: "AtprotoPersonalDataServer",
32 ServiceEndpoint: "https://bsky.social",
33 },
34 {
35 ID: "#atproto_labeler",
36 Type: "AtprotoLabeler",
37 ServiceEndpoint: "https://mod.bsky.app",
38 },
39 },
40}
41
42// Test domain parsing
43func TestParseDomain(t *testing.T) {
44 handler := NewPLCHandler("https://plc.directory")
45
46 tests := []struct {
47 name string
48 domain string
49 expectedDID string
50 expectedType QueryType
51 expectedValid bool
52 }{
53 {
54 name: "Valid handle query",
55 domain: "_handle.test123.plc.atscan.net",
56 expectedDID: "did:plc:test123",
57 expectedType: QueryHandle,
58 expectedValid: true,
59 },
60 {
61 name: "Valid PDS query",
62 domain: "_pds.test123.plc.atscan.net",
63 expectedDID: "did:plc:test123",
64 expectedType: QueryPDS,
65 expectedValid: true,
66 },
67 {
68 name: "Valid labeler query",
69 domain: "_labeler.test123.plc.atscan.net",
70 expectedDID: "did:plc:test123",
71 expectedType: QueryLabeler,
72 expectedValid: true,
73 },
74 {
75 name: "Valid pubkey query",
76 domain: "_pubkey.test123.plc.atscan.net",
77 expectedDID: "did:plc:test123",
78 expectedType: QueryPubKey,
79 expectedValid: true,
80 },
81 {
82 name: "Invalid prefix",
83 domain: "_invalid.test123.plc.atscan.net",
84 expectedDID: "",
85 expectedType: QueryInvalid,
86 expectedValid: false,
87 },
88 {
89 name: "Too few parts",
90 domain: "_handle.test123.plc",
91 expectedDID: "",
92 expectedType: QueryInvalid,
93 expectedValid: false,
94 },
95 {
96 name: "Domain with trailing dot",
97 domain: "_handle.test123.plc.atscan.net.",
98 expectedDID: "did:plc:test123",
99 expectedType: QueryHandle,
100 expectedValid: true,
101 },
102 }
103
104 for _, tt := range tests {
105 t.Run(tt.name, func(t *testing.T) {
106 did, queryType, valid := handler.parseDomain(tt.domain)
107
108 if valid != tt.expectedValid {
109 t.Errorf("expected valid=%v, got %v", tt.expectedValid, valid)
110 }
111
112 if did != tt.expectedDID {
113 t.Errorf("expected DID=%s, got %s", tt.expectedDID, did)
114 }
115
116 if queryType != tt.expectedType {
117 t.Errorf("expected queryType=%v, got %v", tt.expectedType, queryType)
118 }
119 })
120 }
121}
122
123// Test getting handle from DID document
124func TestGetHandle(t *testing.T) {
125 handler := NewPLCHandler("https://plc.directory")
126
127 tests := []struct {
128 name string
129 doc *DIDDocument
130 expected string
131 }{
132 {
133 name: "Valid handle",
134 doc: &mockDIDDocument,
135 expected: "test.bsky.social",
136 },
137 {
138 name: "No handle",
139 doc: &DIDDocument{
140 AlsoKnownAs: []string{},
141 },
142 expected: "",
143 },
144 {
145 name: "Non-AT protocol handle",
146 doc: &DIDDocument{
147 AlsoKnownAs: []string{"https://example.com"},
148 },
149 expected: "",
150 },
151 }
152
153 for _, tt := range tests {
154 t.Run(tt.name, func(t *testing.T) {
155 result := handler.getHandle(tt.doc)
156 if result != tt.expected {
157 t.Errorf("expected %s, got %s", tt.expected, result)
158 }
159 })
160 }
161}
162
163// Test getting PDS from DID document
164func TestGetPDS(t *testing.T) {
165 handler := NewPLCHandler("https://plc.directory")
166
167 tests := []struct {
168 name string
169 doc *DIDDocument
170 expected string
171 }{
172 {
173 name: "Valid PDS",
174 doc: &mockDIDDocument,
175 expected: "https://bsky.social",
176 },
177 {
178 name: "No PDS",
179 doc: &DIDDocument{
180 Service: []Service{},
181 },
182 expected: "",
183 },
184 }
185
186 for _, tt := range tests {
187 t.Run(tt.name, func(t *testing.T) {
188 result := handler.getPDS(tt.doc)
189 if result != tt.expected {
190 t.Errorf("expected %s, got %s", tt.expected, result)
191 }
192 })
193 }
194}
195
196// Test getting labeler from DID document
197func TestGetLabeler(t *testing.T) {
198 handler := NewPLCHandler("https://plc.directory")
199
200 tests := []struct {
201 name string
202 doc *DIDDocument
203 expected string
204 }{
205 {
206 name: "Valid labeler",
207 doc: &mockDIDDocument,
208 expected: "https://mod.bsky.app",
209 },
210 {
211 name: "No labeler",
212 doc: &DIDDocument{
213 Service: []Service{},
214 },
215 expected: "",
216 },
217 }
218
219 for _, tt := range tests {
220 t.Run(tt.name, func(t *testing.T) {
221 result := handler.getLabeler(tt.doc)
222 if result != tt.expected {
223 t.Errorf("expected %s, got %s", tt.expected, result)
224 }
225 })
226 }
227}
228
229// Test getting pubkey from DID document
230func TestGetPubKey(t *testing.T) {
231 handler := NewPLCHandler("https://plc.directory")
232
233 tests := []struct {
234 name string
235 doc *DIDDocument
236 expected string
237 }{
238 {
239 name: "Valid pubkey",
240 doc: &mockDIDDocument,
241 expected: "zQ3shXjHeiBuRCKmM36cuYnm7YEMzhGnCmCyW92sRJ9przSMS",
242 },
243 {
244 name: "No pubkey",
245 doc: &DIDDocument{
246 VerificationMethod: []VerificationMethod{},
247 },
248 expected: "",
249 },
250 }
251
252 for _, tt := range tests {
253 t.Run(tt.name, func(t *testing.T) {
254 result := handler.getPubKey(tt.doc)
255 if result != tt.expected {
256 t.Errorf("expected %s, got %s", tt.expected, result)
257 }
258 })
259 }
260}
261
262// Test creating TXT records
263func TestCreateTXTRecord(t *testing.T) {
264 handler := NewPLCHandler("https://plc.directory")
265
266 tests := []struct {
267 name string
268 doc *DIDDocument
269 qname string
270 queryType QueryType
271 expectedValue string
272 shouldBeEmpty bool
273 }{
274 {
275 name: "Handle record",
276 doc: &mockDIDDocument,
277 qname: "_handle.test123.plc.atscan.net.",
278 queryType: QueryHandle,
279 expectedValue: "test.bsky.social",
280 shouldBeEmpty: false,
281 },
282 {
283 name: "PDS record",
284 doc: &mockDIDDocument,
285 qname: "_pds.test123.plc.atscan.net.",
286 queryType: QueryPDS,
287 expectedValue: "https://bsky.social",
288 shouldBeEmpty: false,
289 },
290 {
291 name: "Labeler record",
292 doc: &mockDIDDocument,
293 qname: "_labeler.test123.plc.atscan.net.",
294 queryType: QueryLabeler,
295 expectedValue: "https://mod.bsky.app",
296 shouldBeEmpty: false,
297 },
298 {
299 name: "Pubkey record",
300 doc: &mockDIDDocument,
301 qname: "_pubkey.test123.plc.atscan.net.",
302 queryType: QueryPubKey,
303 expectedValue: "zQ3shXjHeiBuRCKmM36cuYnm7YEMzhGnCmCyW92sRJ9przSMS",
304 shouldBeEmpty: false,
305 },
306 {
307 name: "Empty handle",
308 doc: &DIDDocument{
309 AlsoKnownAs: []string{},
310 },
311 qname: "_handle.test123.plc.atscan.net.",
312 queryType: QueryHandle,
313 expectedValue: "",
314 shouldBeEmpty: true,
315 },
316 }
317
318 for _, tt := range tests {
319 t.Run(tt.name, func(t *testing.T) {
320 records := handler.createTXTRecord(tt.doc, tt.qname, tt.queryType)
321
322 if tt.shouldBeEmpty {
323 if len(records) != 0 {
324 t.Errorf("expected empty records, got %d records", len(records))
325 }
326 return
327 }
328
329 if len(records) != 1 {
330 t.Fatalf("expected 1 record, got %d", len(records))
331 }
332
333 txtRecord, ok := records[0].(*dns.TXT)
334 if !ok {
335 t.Fatal("record is not TXT type")
336 }
337
338 if txtRecord.Hdr.Name != tt.qname {
339 t.Errorf("expected name %s, got %s", tt.qname, txtRecord.Hdr.Name)
340 }
341
342 if len(txtRecord.Txt) != 1 {
343 t.Fatalf("expected 1 TXT value, got %d", len(txtRecord.Txt))
344 }
345
346 if txtRecord.Txt[0] != tt.expectedValue {
347 t.Errorf("expected value %s, got %s", tt.expectedValue, txtRecord.Txt[0])
348 }
349
350 if txtRecord.Hdr.Ttl != 300 {
351 t.Errorf("expected TTL 300, got %d", txtRecord.Hdr.Ttl)
352 }
353 })
354 }
355}
356
357// Test fetching DID document with mock server
358func TestFetchDIDDocument(t *testing.T) {
359 // Create mock HTTP server
360 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
361 if r.URL.Path == "/did:plc:test123" {
362 w.WriteHeader(http.StatusOK)
363 json.NewEncoder(w).Encode(mockDIDDocument)
364 } else if r.URL.Path == "/did:plc:notfound" {
365 w.WriteHeader(http.StatusNotFound)
366 } else {
367 w.WriteHeader(http.StatusInternalServerError)
368 }
369 }))
370 defer server.Close()
371
372 handler := NewPLCHandler(server.URL)
373
374 tests := []struct {
375 name string
376 did string
377 expectError bool
378 }{
379 {
380 name: "Valid DID",
381 did: "did:plc:test123",
382 expectError: false,
383 },
384 {
385 name: "Not found DID",
386 did: "did:plc:notfound",
387 expectError: true,
388 },
389 }
390
391 for _, tt := range tests {
392 t.Run(tt.name, func(t *testing.T) {
393 doc, err := handler.fetchDIDDocument(tt.did)
394
395 if tt.expectError {
396 if err == nil {
397 t.Error("expected error, got nil")
398 }
399 return
400 }
401
402 if err != nil {
403 t.Fatalf("unexpected error: %v", err)
404 }
405
406 if doc.ID != mockDIDDocument.ID {
407 t.Errorf("expected ID %s, got %s", mockDIDDocument.ID, doc.ID)
408 }
409 })
410 }
411}
412
413// Test caching
414func TestCaching(t *testing.T) {
415 callCount := 0
416
417 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
418 callCount++
419 w.WriteHeader(http.StatusOK)
420 json.NewEncoder(w).Encode(mockDIDDocument)
421 }))
422 defer server.Close()
423
424 handler := NewPLCHandler(server.URL)
425
426 // First call
427 _, err := handler.fetchDIDDocument("did:plc:test123")
428 if err != nil {
429 t.Fatalf("unexpected error: %v", err)
430 }
431
432 if callCount != 1 {
433 t.Errorf("expected 1 HTTP call, got %d", callCount)
434 }
435
436 // Second call (should use cache)
437 _, err = handler.fetchDIDDocument("did:plc:test123")
438 if err != nil {
439 t.Fatalf("unexpected error: %v", err)
440 }
441
442 if callCount != 1 {
443 t.Errorf("expected 1 HTTP call (cached), got %d", callCount)
444 }
445
446 // Invalidate cache by modifying timestamp
447 handler.cache["did:plc:test123"].Timestamp = time.Now().Add(-10 * time.Minute)
448
449 // Third call (cache expired, should fetch again)
450 _, err = handler.fetchDIDDocument("did:plc:test123")
451 if err != nil {
452 t.Fatalf("unexpected error: %v", err)
453 }
454
455 if callCount != 2 {
456 t.Errorf("expected 2 HTTP calls (cache expired), got %d", callCount)
457 }
458}
459
460// Integration test for DNS server
461func TestDNSServerIntegration(t *testing.T) {
462 // Create mock HTTP server
463 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
464 w.WriteHeader(http.StatusOK)
465 json.NewEncoder(w).Encode(mockDIDDocument)
466 }))
467 defer server.Close()
468
469 handler := NewPLCHandler(server.URL)
470
471 tests := []struct {
472 name string
473 qname string
474 qtype uint16
475 expectedValue string
476 expectError bool
477 }{
478 {
479 name: "Query handle",
480 qname: "_handle.test123.plc.atscan.net.",
481 qtype: dns.TypeTXT,
482 expectedValue: "test.bsky.social",
483 expectError: false,
484 },
485 {
486 name: "Query PDS",
487 qname: "_pds.test123.plc.atscan.net.",
488 qtype: dns.TypeTXT,
489 expectedValue: "https://bsky.social",
490 expectError: false,
491 },
492 {
493 name: "Query labeler",
494 qname: "_labeler.test123.plc.atscan.net.",
495 qtype: dns.TypeTXT,
496 expectedValue: "https://mod.bsky.app",
497 expectError: false,
498 },
499 {
500 name: "Query pubkey",
501 qname: "_pubkey.test123.plc.atscan.net.",
502 qtype: dns.TypeTXT,
503 expectedValue: "zQ3shXjHeiBuRCKmM36cuYnm7YEMzhGnCmCyW92sRJ9przSMS",
504 expectError: false,
505 },
506 {
507 name: "Invalid query type",
508 qname: "_handle.test123.plc.atscan.net.",
509 qtype: dns.TypeA,
510 expectError: true,
511 },
512 {
513 name: "Invalid domain",
514 qname: "invalid.domain.com.",
515 qtype: dns.TypeTXT,
516 expectError: true,
517 },
518 }
519
520 for _, tt := range tests {
521 t.Run(tt.name, func(t *testing.T) {
522 // Create DNS request
523 m := new(dns.Msg)
524 m.SetQuestion(tt.qname, tt.qtype)
525
526 // Create response writer
527 rw := &testResponseWriter{msg: new(dns.Msg)}
528
529 // Handle request
530 handler.ServeDNS(rw, m)
531
532 if tt.expectError {
533 if rw.msg.Rcode == dns.RcodeSuccess && len(rw.msg.Answer) > 0 {
534 t.Error("expected error response, got success")
535 }
536 return
537 }
538
539 if rw.msg.Rcode != dns.RcodeSuccess {
540 t.Errorf("expected success, got rcode %d", rw.msg.Rcode)
541 }
542
543 if len(rw.msg.Answer) != 1 {
544 t.Fatalf("expected 1 answer, got %d", len(rw.msg.Answer))
545 }
546
547 txtRecord, ok := rw.msg.Answer[0].(*dns.TXT)
548 if !ok {
549 t.Fatal("answer is not TXT record")
550 }
551
552 if len(txtRecord.Txt) != 1 {
553 t.Fatalf("expected 1 TXT value, got %d", len(txtRecord.Txt))
554 }
555
556 if txtRecord.Txt[0] != tt.expectedValue {
557 t.Errorf("expected value %s, got %s", tt.expectedValue, txtRecord.Txt[0])
558 }
559 })
560 }
561}
562
563// Test response writer for DNS testing
564type testResponseWriter struct {
565 msg *dns.Msg
566}
567
568func (w *testResponseWriter) LocalAddr() net.Addr {
569 return nil
570}
571
572func (w *testResponseWriter) RemoteAddr() net.Addr {
573 return nil
574}
575
576func (w *testResponseWriter) WriteMsg(m *dns.Msg) error {
577 w.msg = m
578 return nil
579}
580
581func (w *testResponseWriter) Write([]byte) (int, error) {
582 return 0, nil
583}
584
585func (w *testResponseWriter) Close() error {
586 return nil
587}
588
589func (w *testResponseWriter) TsigStatus() error {
590 return nil
591}
592
593func (w *testResponseWriter) TsigTimersOnly(bool) {}
594
595func (w *testResponseWriter) Hijack() {}