1package main
2
import (
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/miekg/dns"
)
17
// Types mirroring the JSON DID documents returned by the PLC directory.
// Members of the JSON that are not declared here are ignored by encoding/json
// when decoding.
type (
	// DIDDocument is a decoded did:plc document.
	//
	// NOTE(review): Context is a slice, so decoding fails if a directory ever
	// returns "@context" as a single string rather than an array.
	DIDDocument struct {
		Context            []interface{}        `json:"@context"`
		ID                 string               `json:"id"`
		AlsoKnownAs        []string             `json:"alsoKnownAs"`
		VerificationMethod []VerificationMethod `json:"verificationMethod"`
		Service            []Service            `json:"service"`
	}

	// VerificationMethod is one entry of a document's "verificationMethod" list.
	VerificationMethod struct {
		ID                 string `json:"id"`
		Type               string `json:"type"`
		Controller         string `json:"controller"`
		PublicKeyMultibase string `json:"publicKeyMultibase"`
	}

	// Service is one entry of a document's "service" list.
	Service struct {
		ID              string `json:"id"`
		Type            string `json:"type"`
		ServiceEndpoint string `json:"serviceEndpoint"`
	}
)
39
40type PLCHandler struct {
41 plcDirectory string
42 cache map[string]*CachedDID
43}
44
45type CachedDID struct {
46 Document *DIDDocument
47 Timestamp time.Time
48}
49
50type QueryType int
51
52const (
53 QueryHandle QueryType = iota
54 QueryPDS
55 QueryPubKey
56 QueryLabeler
57 QueryInvalid
58)
59
60func NewPLCHandler(plcDirectory string) *PLCHandler {
61 return &PLCHandler{
62 plcDirectory: plcDirectory,
63 cache: make(map[string]*CachedDID),
64 }
65}
66
67func (h *PLCHandler) fetchDIDDocument(did string) (*DIDDocument, error) {
68 // Check cache first (5 minute TTL)
69 if cached, exists := h.cache[did]; exists {
70 if time.Since(cached.Timestamp) < 5*time.Minute {
71 log.Printf("Cache hit for %s", did)
72 return cached.Document, nil
73 }
74 }
75
76 url := fmt.Sprintf("%s/%s", h.plcDirectory, did)
77 log.Printf("Fetching DID document from: %s", url)
78
79 client := &http.Client{Timeout: 10 * time.Second}
80 resp, err := client.Get(url)
81 if err != nil {
82 return nil, fmt.Errorf("failed to fetch DID document: %w", err)
83 }
84 defer resp.Body.Close()
85
86 if resp.StatusCode != http.StatusOK {
87 return nil, fmt.Errorf("DID not found: status %d", resp.StatusCode)
88 }
89
90 body, err := io.ReadAll(resp.Body)
91 if err != nil {
92 return nil, fmt.Errorf("failed to read response: %w", err)
93 }
94
95 var doc DIDDocument
96 if err := json.Unmarshal(body, &doc); err != nil {
97 return nil, fmt.Errorf("failed to parse DID document: %w", err)
98 }
99
100 // Cache the result
101 h.cache[did] = &CachedDID{
102 Document: &doc,
103 Timestamp: time.Now(),
104 }
105
106 return &doc, nil
107}
108
109// Parse domain name to extract DID and query type
110// Expected formats:
111// _handle.<did>.plc.atscan.net
112// _pds.<did>.plc.atscan.net
113// _pubkey.<did>.plc.atscan.net
114// _labeler.<did>.plc.atscan.net
115func (h *PLCHandler) parseDomain(domain string) (string, QueryType, bool) {
116 domain = strings.TrimSuffix(domain, ".")
117 parts := strings.Split(domain, ".")
118
119 // Should be at least: [_prefix, <did>, plc, atscan, net]
120 if len(parts) < 5 {
121 return "", QueryInvalid, false
122 }
123
124 // Determine query type based on prefix
125 var queryType QueryType
126 switch parts[0] {
127 case "_handle":
128 queryType = QueryHandle
129 case "_pds":
130 queryType = QueryPDS
131 case "_pubkey":
132 queryType = QueryPubKey
133 case "_labeler":
134 queryType = QueryLabeler
135 default:
136 return "", QueryInvalid, false
137 }
138
139 // Extract DID identifier
140 didIdentifier := parts[1]
141
142 // Construct full DID
143 did := fmt.Sprintf("did:plc:%s", didIdentifier)
144
145 return did, queryType, true
146}
147
148// Get handle from DID document
149func (h *PLCHandler) getHandle(doc *DIDDocument) string {
150 for _, aka := range doc.AlsoKnownAs {
151 if strings.HasPrefix(aka, "at://") {
152 return strings.TrimPrefix(aka, "at://")
153 }
154 }
155 return ""
156}
157
158// Get PDS endpoint from DID document
159func (h *PLCHandler) getPDS(doc *DIDDocument) string {
160 for _, service := range doc.Service {
161 if service.Type == "AtprotoPersonalDataServer" {
162 return service.ServiceEndpoint
163 }
164 }
165 return ""
166}
167
168// Get labeler endpoint from DID document
169func (h *PLCHandler) getLabeler(doc *DIDDocument) string {
170 for _, service := range doc.Service {
171 if service.ID == "#atproto_labeler" {
172 return service.ServiceEndpoint
173 }
174 }
175 return ""
176}
177
178// Get public key from DID document
179func (h *PLCHandler) getPubKey(doc *DIDDocument) string {
180 if len(doc.VerificationMethod) > 0 {
181 return doc.VerificationMethod[0].PublicKeyMultibase
182 }
183 return ""
184}
185
186// Create TXT record based on query type
187func (h *PLCHandler) createTXTRecord(doc *DIDDocument, qname string, queryType QueryType) []dns.RR {
188 var records []dns.RR
189 ttl := uint32(300) // 5 minutes
190
191 var value string
192 switch queryType {
193 case QueryHandle:
194 value = h.getHandle(doc)
195 if value == "" {
196 return records
197 }
198 case QueryPDS:
199 value = h.getPDS(doc)
200 if value == "" {
201 return records
202 }
203 case QueryLabeler:
204 value = h.getLabeler(doc)
205 if value == "" {
206 return records
207 }
208 case QueryPubKey:
209 value = h.getPubKey(doc)
210 if value == "" {
211 return records
212 }
213 default:
214 return records
215 }
216
217 records = append(records, &dns.TXT{
218 Hdr: dns.RR_Header{
219 Name: qname,
220 Rrtype: dns.TypeTXT,
221 Class: dns.ClassINET,
222 Ttl: ttl,
223 },
224 Txt: []string{value},
225 })
226
227 return records
228}
229
230func (h *PLCHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
231 msg := dns.Msg{}
232 msg.SetReply(r)
233 msg.Authoritative = true
234
235 for _, question := range r.Question {
236 log.Printf("Query: %s %s", question.Name, dns.TypeToString[question.Qtype])
237
238 // Only handle TXT queries
239 if question.Qtype != dns.TypeTXT {
240 continue
241 }
242
243 // Parse the domain to extract DID and query type
244 did, queryType, ok := h.parseDomain(question.Name)
245 if !ok {
246 log.Printf("Invalid domain format: %s", question.Name)
247 msg.SetRcode(r, dns.RcodeNameError)
248 w.WriteMsg(&msg)
249 return
250 }
251
252 log.Printf("Extracted DID: %s, Query Type: %v", did, queryType)
253
254 // Fetch DID document
255 doc, err := h.fetchDIDDocument(did)
256 if err != nil {
257 log.Printf("Error fetching DID document: %s", err)
258 msg.SetRcode(r, dns.RcodeServerFailure)
259 w.WriteMsg(&msg)
260 return
261 }
262
263 // Create TXT record based on query type
264 records := h.createTXTRecord(doc, question.Name, queryType)
265 if len(records) == 0 {
266 log.Printf("No data found for query type")
267 msg.SetRcode(r, dns.RcodeNameError)
268 w.WriteMsg(&msg)
269 return
270 }
271
272 msg.Answer = append(msg.Answer, records...)
273 }
274
275 if len(msg.Answer) == 0 {
276 msg.SetRcode(r, dns.RcodeNameError)
277 }
278
279 w.WriteMsg(&msg)
280}
281
282func main() {
283 // Command line flags
284 port := flag.String("port", "", "DNS server port (default: 8053 or DNS_PORT env var)")
285 plcDir := flag.String("plc", "https://plc.directory", "PLC directory URL")
286 flag.Parse()
287
288 // Determine port: flag -> env var -> default
289 finalPort := *port
290 if finalPort == "" {
291 if envPort := os.Getenv("DNS_PORT"); envPort != "" {
292 finalPort = envPort
293 } else {
294 finalPort = "8053"
295 }
296 }
297
298 // Validate port number
299 if portNum, err := strconv.Atoi(finalPort); err != nil || portNum < 1 || portNum > 65535 {
300 log.Fatalf("Invalid port number: %s", finalPort)
301 }
302
303 addr := ":" + finalPort
304 handler := NewPLCHandler(*plcDir)
305
306 log.Printf("PLC Directory: %s", *plcDir)
307 log.Printf("Starting DNS servers on port %s", finalPort)
308 log.Printf("Supported query types:")
309 log.Printf(" _handle.<did>.plc.atscan.net - Returns handle")
310 log.Printf(" _pds.<did>.plc.atscan.net - Returns PDS endpoint")
311 log.Printf(" _labeler.<did>.plc.atscan.net - Returns labeler endpoint")
312 log.Printf(" _pubkey.<did>.plc.atscan.net - Returns public key")
313
314 // UDP server
315 udpServer := &dns.Server{
316 Addr: addr,
317 Net: "udp",
318 Handler: handler,
319 }
320
321 // TCP server
322 tcpServer := &dns.Server{
323 Addr: addr,
324 Net: "tcp",
325 Handler: handler,
326 }
327
328 // Start UDP server in goroutine
329 go func() {
330 log.Printf("UDP DNS server listening on %s", addr)
331 if err := udpServer.ListenAndServe(); err != nil {
332 log.Fatalf("Failed to start UDP server: %s", err)
333 }
334 }()
335
336 // Start TCP server
337 log.Printf("TCP DNS server listening on %s", addr)
338 if err := tcpServer.ListenAndServe(); err != nil {
339 log.Fatalf("Failed to start TCP server: %s", err)
340 }
341}