1package main
2
3import (
4 "encoding/json"
5 "flag"
6 "fmt"
7 "io"
8 "log"
9 "net/http"
10 "os"
11 "strconv"
12 "strings"
13 "time"
14
15 "github.com/miekg/dns"
16)
17
18// DID Document structure
19type DIDDocument struct {
20 Context []interface{} `json:"@context"`
21 ID string `json:"id"`
22 AlsoKnownAs []string `json:"alsoKnownAs"`
23 VerificationMethod []VerificationMethod `json:"verificationMethod"`
24 Service []Service `json:"service"`
25}
26
27type VerificationMethod struct {
28 ID string `json:"id"`
29 Type string `json:"type"`
30 Controller string `json:"controller"`
31 PublicKeyMultibase string `json:"publicKeyMultibase"`
32}
33
34type Service struct {
35 ID string `json:"id"`
36 Type string `json:"type"`
37 ServiceEndpoint string `json:"serviceEndpoint"`
38}
39
40type PLCHandler struct {
41 plcDirectory string
42 cache map[string]*CachedDID
43}
44
45type CachedDID struct {
46 Document *DIDDocument
47 Timestamp time.Time
48}
49
50type QueryType int
51
52const (
53 QueryHandle QueryType = iota
54 QueryPDS
55 QueryPubKey
56 QueryLabeler
57 QueryInvalid
58)
59
60func NewPLCHandler(plcDirectory string) *PLCHandler {
61 return &PLCHandler{
62 plcDirectory: plcDirectory,
63 cache: make(map[string]*CachedDID),
64 }
65}
66
67func (h *PLCHandler) fetchDIDDocument(did string) (*DIDDocument, error) {
68 // Check cache first (5 minute TTL)
69 if cached, exists := h.cache[did]; exists {
70 if time.Since(cached.Timestamp) < 5*time.Minute {
71 log.Printf("Cache hit for %s", did)
72 return cached.Document, nil
73 }
74 }
75
76 url := fmt.Sprintf("%s/%s", h.plcDirectory, did)
77 log.Printf("Fetching DID document from: %s", url)
78
79 client := &http.Client{Timeout: 10 * time.Second}
80 resp, err := client.Get(url)
81 if err != nil {
82 return nil, fmt.Errorf("failed to fetch DID document: %w", err)
83 }
84 defer resp.Body.Close()
85
86 if resp.StatusCode != http.StatusOK {
87 return nil, fmt.Errorf("DID not found: status %d", resp.StatusCode)
88 }
89
90 body, err := io.ReadAll(resp.Body)
91 if err != nil {
92 return nil, fmt.Errorf("failed to read response: %w", err)
93 }
94
95 var doc DIDDocument
96 if err := json.Unmarshal(body, &doc); err != nil {
97 return nil, fmt.Errorf("failed to parse DID document: %w", err)
98 }
99
100 // Cache the result
101 h.cache[did] = &CachedDID{
102 Document: &doc,
103 Timestamp: time.Now(),
104 }
105
106 return &doc, nil
107}
108
109// Parse domain name to extract DID and query type
110// Expected formats:
111// _handle.<did>.plc.atscan.net
112// _pds.<did>.plc.atscan.net
113// _pubkey.<did>.plc.atscan.net
114// _labeler.<did>.plc.atscan.net
115func (h *PLCHandler) parseDomain(domain string) (string, QueryType, bool) {
116 domain = strings.TrimSuffix(domain, ".")
117
118 // DNS is case-insensitive, normalize to lowercase
119 domain = strings.ToLower(domain)
120
121 parts := strings.Split(domain, ".")
122
123 // Should be at least: [_prefix, <did>, plc, atscan, net]
124 if len(parts) < 5 {
125 return "", QueryInvalid, false
126 }
127
128 // Determine query type based on prefix
129 var queryType QueryType
130 switch parts[0] {
131 case "_handle":
132 queryType = QueryHandle
133 case "_pds":
134 queryType = QueryPDS
135 case "_pubkey":
136 queryType = QueryPubKey
137 case "_labeler":
138 queryType = QueryLabeler
139 default:
140 return "", QueryInvalid, false
141 }
142
143 // Extract DID identifier (already lowercase due to normalization above)
144 didIdentifier := parts[1]
145
146 // Construct full DID
147 did := fmt.Sprintf("did:plc:%s", didIdentifier)
148
149 return did, queryType, true
150}
151
152// Get handle from DID document
153func (h *PLCHandler) getHandle(doc *DIDDocument) string {
154 for _, aka := range doc.AlsoKnownAs {
155 if strings.HasPrefix(aka, "at://") {
156 return strings.TrimPrefix(aka, "at://")
157 }
158 }
159 return ""
160}
161
162// Get PDS endpoint from DID document
163func (h *PLCHandler) getPDS(doc *DIDDocument) string {
164 for _, service := range doc.Service {
165 if service.Type == "AtprotoPersonalDataServer" {
166 return service.ServiceEndpoint
167 }
168 }
169 return ""
170}
171
172// Get labeler endpoint from DID document
173func (h *PLCHandler) getLabeler(doc *DIDDocument) string {
174 for _, service := range doc.Service {
175 if service.ID == "#atproto_labeler" {
176 return service.ServiceEndpoint
177 }
178 }
179 return ""
180}
181
182// Get public key from DID document
183func (h *PLCHandler) getPubKey(doc *DIDDocument) string {
184 if len(doc.VerificationMethod) > 0 {
185 return doc.VerificationMethod[0].PublicKeyMultibase
186 }
187 return ""
188}
189
190// Create TXT record based on query type
191func (h *PLCHandler) createTXTRecord(doc *DIDDocument, qname string, queryType QueryType) []dns.RR {
192 var records []dns.RR
193 ttl := uint32(300) // 5 minutes
194
195 var value string
196 switch queryType {
197 case QueryHandle:
198 value = h.getHandle(doc)
199 if value == "" {
200 return records
201 }
202 case QueryPDS:
203 value = h.getPDS(doc)
204 if value == "" {
205 return records
206 }
207 case QueryLabeler:
208 value = h.getLabeler(doc)
209 if value == "" {
210 return records
211 }
212 case QueryPubKey:
213 value = h.getPubKey(doc)
214 if value == "" {
215 return records
216 }
217 default:
218 return records
219 }
220
221 records = append(records, &dns.TXT{
222 Hdr: dns.RR_Header{
223 Name: qname,
224 Rrtype: dns.TypeTXT,
225 Class: dns.ClassINET,
226 Ttl: ttl,
227 },
228 Txt: []string{value},
229 })
230
231 return records
232}
233
234func (h *PLCHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
235 msg := dns.Msg{}
236 msg.SetReply(r)
237 msg.Authoritative = true
238
239 for _, question := range r.Question {
240 log.Printf("Query: %s %s", question.Name, dns.TypeToString[question.Qtype])
241
242 // Only handle TXT queries
243 if question.Qtype != dns.TypeTXT {
244 continue
245 }
246
247 // Parse the domain to extract DID and query type
248 did, queryType, ok := h.parseDomain(question.Name)
249 if !ok {
250 log.Printf("Invalid domain format: %s", question.Name)
251 msg.SetRcode(r, dns.RcodeNameError)
252 w.WriteMsg(&msg)
253 return
254 }
255
256 log.Printf("Extracted DID: %s, Query Type: %v", did, queryType)
257
258 // Fetch DID document
259 doc, err := h.fetchDIDDocument(did)
260 if err != nil {
261 log.Printf("Error fetching DID document: %s", err)
262 msg.SetRcode(r, dns.RcodeServerFailure)
263 w.WriteMsg(&msg)
264 return
265 }
266
267 // Create TXT record based on query type
268 records := h.createTXTRecord(doc, question.Name, queryType)
269 if len(records) == 0 {
270 log.Printf("No data found for query type")
271 msg.SetRcode(r, dns.RcodeNameError)
272 w.WriteMsg(&msg)
273 return
274 }
275
276 msg.Answer = append(msg.Answer, records...)
277 }
278
279 if len(msg.Answer) == 0 {
280 msg.SetRcode(r, dns.RcodeNameError)
281 }
282
283 w.WriteMsg(&msg)
284}
285
286func main() {
287 // Command line flags
288 port := flag.String("port", "", "DNS server port (default: 8053 or DNS_PORT env var)")
289 plcDir := flag.String("plc", "https://plc.directory", "PLC directory URL")
290 flag.Parse()
291
292 // Determine port: flag -> env var -> default
293 finalPort := *port
294 if finalPort == "" {
295 if envPort := os.Getenv("DNS_PORT"); envPort != "" {
296 finalPort = envPort
297 } else {
298 finalPort = "8053"
299 }
300 }
301
302 // Validate port number
303 if portNum, err := strconv.Atoi(finalPort); err != nil || portNum < 1 || portNum > 65535 {
304 log.Fatalf("Invalid port number: %s", finalPort)
305 }
306
307 addr := ":" + finalPort
308 handler := NewPLCHandler(*plcDir)
309
310 log.Printf("PLC Directory: %s", *plcDir)
311 log.Printf("Starting DNS servers on port %s", finalPort)
312 log.Printf("Supported query types:")
313 log.Printf(" _handle.<did>.plc.atscan.net - Returns handle")
314 log.Printf(" _pds.<did>.plc.atscan.net - Returns PDS endpoint")
315 log.Printf(" _labeler.<did>.plc.atscan.net - Returns labeler endpoint")
316 log.Printf(" _pubkey.<did>.plc.atscan.net - Returns public key")
317
318 // UDP server
319 udpServer := &dns.Server{
320 Addr: addr,
321 Net: "udp",
322 Handler: handler,
323 }
324
325 // TCP server
326 tcpServer := &dns.Server{
327 Addr: addr,
328 Net: "tcp",
329 Handler: handler,
330 }
331
332 // Start UDP server in goroutine
333 go func() {
334 log.Printf("UDP DNS server listening on %s", addr)
335 if err := udpServer.ListenAndServe(); err != nil {
336 log.Fatalf("Failed to start UDP server: %s", err)
337 }
338 }()
339
340 // Start TCP server
341 log.Printf("TCP DNS server listening on %s", addr)
342 if err := tcpServer.ListenAndServe(); err != nil {
343 log.Fatalf("Failed to start TCP server: %s", err)
344 }
345}