1package handles
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "io"
8 "log/slog"
9 "net"
10 "net/http"
11 "net/url"
12 "strings"
13 "sync"
14 "time"
15
16 "github.com/bluesky-social/indigo/did"
17 arc "github.com/hashicorp/golang-lru/arc/v2"
18 "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
19 otel "go.opentelemetry.io/otel"
20)
21
22func ResolveDidToHandle(ctx context.Context, res did.Resolver, hr HandleResolver, udid string) (string, string, error) {
23 ctx, span := otel.Tracer("gosky").Start(ctx, "resolveDidToHandle")
24 defer span.End()
25
26 doc, err := res.GetDocument(ctx, udid)
27 if err != nil {
28 return "", "", err
29 }
30
31 if len(doc.AlsoKnownAs) == 0 {
32 return "", "", fmt.Errorf("users did document does not specify a handle")
33 }
34
35 aka := doc.AlsoKnownAs[0]
36
37 u, err := url.Parse(aka)
38 if err != nil {
39 return "", "", fmt.Errorf("aka field in doc was not a valid url: %w", err)
40 }
41
42 handle := u.Host
43
44 var svc *did.Service
45 for _, s := range doc.Service {
46 if s.ID.String() == "#atproto_pds" && s.Type == "AtprotoPersonalDataServer" {
47 svc = &s
48 break
49 }
50 }
51
52 if svc == nil {
53 return "", "", fmt.Errorf("users did document has no pds service set")
54 }
55
56 verdid, err := hr.ResolveHandleToDid(ctx, handle)
57 if err != nil {
58 return "", "", err
59 }
60
61 if verdid != udid {
62 return "", "", fmt.Errorf("pds server reported different did for claimed handle")
63 }
64
65 return handle, svc.ServiceEndpoint, nil
66}
67
// HandleResolver maps a handle (a host name, as used for DNS and HTTPS
// lookups elsewhere in this package) to the DID it points at.
type HandleResolver interface {
	// ResolveHandleToDid returns the DID that handle resolves to, or a
	// non-nil error when it cannot be resolved.
	ResolveHandleToDid(ctx context.Context, handle string) (resolvedDID string, err error)
}
71
// failCacheItem is one entry in ProdHandleResolver's failure cache,
// describing the most recent failed resolution of a handle.
type failCacheItem struct {
	// err is returned to callers, without any network traffic, while the
	// entry has not yet expired.
	err error
	// count is how many failures in a row have been recorded for the handle;
	// it is carried into the next entry to lengthen the backoff.
	count int
	// expiresAt is when the entry stops short-circuiting resolution.
	expiresAt time.Time
}
77
78type ProdHandleResolver struct {
79 client *http.Client
80 resolver *net.Resolver
81 ReqMod func(*http.Request, string) error
82 FailCache *arc.ARCCache[string, *failCacheItem]
83}
84
85func NewProdHandleResolver(failureCacheSize int, resolveAddr string, forceUDP bool) (*ProdHandleResolver, error) {
86 failureCache, err := arc.NewARC[string, *failCacheItem](failureCacheSize)
87 if err != nil {
88 return nil, err
89 }
90
91 if resolveAddr == "" {
92 resolveAddr = "1.1.1.1:53"
93 }
94
95 c := http.Client{
96 Transport: otelhttp.NewTransport(http.DefaultTransport),
97 Timeout: time.Second * 10,
98 }
99
100 r := &net.Resolver{
101 PreferGo: true,
102 Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
103 d := net.Dialer{
104 Timeout: time.Second * 10,
105 }
106 if forceUDP {
107 network = "udp"
108 }
109 return d.DialContext(ctx, network, resolveAddr)
110 },
111 }
112
113 return &ProdHandleResolver{
114 FailCache: failureCache,
115 client: &c,
116 resolver: r,
117 }, nil
118}
119
120func (dr *ProdHandleResolver) ResolveHandleToDid(ctx context.Context, handle string) (string, error) {
121 ctx, cancel := context.WithTimeout(ctx, time.Second*20)
122 defer cancel()
123
124 ctx, span := otel.Tracer("resolver").Start(ctx, "ResolveHandleToDid")
125 defer span.End()
126
127 var cachedFailureCount int
128
129 if dr.FailCache != nil {
130 if item, ok := dr.FailCache.Get(handle); ok {
131 cachedFailureCount = item.count
132 if item.expiresAt.After(time.Now()) {
133 return "", item.err
134 }
135 dr.FailCache.Remove(handle)
136 }
137 }
138
139 var wkres, dnsres string
140 var wkerr, dnserr error
141
142 var wg sync.WaitGroup
143 wg.Add(2)
144
145 go func() {
146 defer wg.Done()
147 wkres, wkerr = dr.resolveWellKnown(ctx, handle)
148 if wkerr == nil {
149 cancel()
150 }
151 }()
152 go func() {
153 defer wg.Done()
154 dnsres, dnserr = dr.resolveDNS(ctx, handle)
155 if dnserr == nil {
156 cancel()
157 }
158 }()
159
160 wg.Wait()
161
162 if dnserr == nil {
163 return dnsres, nil
164 }
165
166 if wkerr == nil {
167 return wkres, nil
168 }
169
170 err := errors.Join(fmt.Errorf("no did record found for handle %q", handle), dnserr, wkerr)
171
172 if dr.FailCache != nil {
173 cachedFailureCount++
174 expireAt := time.Now().Add(time.Millisecond * 100)
175 if cachedFailureCount > 1 {
176 // exponential backoff
177 expireAt = time.Now().Add(time.Millisecond * 100 * time.Duration(cachedFailureCount*cachedFailureCount))
178 // Clamp to one hour
179 if expireAt.After(time.Now().Add(time.Hour)) {
180 expireAt = time.Now().Add(time.Hour)
181 }
182 }
183
184 dr.FailCache.Add(handle, &failCacheItem{
185 err: err,
186 expiresAt: expireAt,
187 count: cachedFailureCount,
188 })
189 }
190
191 return "", err
192}
193
194func (dr *ProdHandleResolver) resolveWellKnown(ctx context.Context, handle string) (string, error) {
195 req, err := http.NewRequest("GET", fmt.Sprintf("https://%s/.well-known/atproto-did", handle), nil)
196 if err != nil {
197 return "", err
198 }
199
200 if dr.ReqMod != nil {
201 if err := dr.ReqMod(req, handle); err != nil {
202 return "", err
203 }
204 }
205
206 req = req.WithContext(ctx)
207
208 resp, err := dr.client.Do(req)
209 if err != nil {
210 return "", fmt.Errorf("failed to resolve handle (%s) through HTTP well-known route: %s", handle, err)
211 }
212 if resp.StatusCode != 200 {
213 return "", fmt.Errorf("failed to resolve handle (%s) through HTTP well-known route: status=%d", handle, resp.StatusCode)
214 }
215
216 if resp.ContentLength > 2048 {
217 return "", fmt.Errorf("http well-known route returned too much data")
218 }
219
220 b, err := io.ReadAll(io.LimitReader(resp.Body, 2048))
221 if err != nil {
222 return "", fmt.Errorf("failed to read resolved did: %w", err)
223 }
224
225 parsed, err := did.ParseDID(string(b))
226 if err != nil {
227 return "", err
228 }
229
230 return parsed.String(), nil
231}
232
233func (dr *ProdHandleResolver) resolveDNS(ctx context.Context, handle string) (string, error) {
234 res, err := dr.resolver.LookupTXT(ctx, "_atproto."+handle)
235 if err != nil {
236 return "", fmt.Errorf("handle lookup failed: %w", err)
237 }
238
239 for _, s := range res {
240 if strings.HasPrefix(s, "did=") {
241 parts := strings.Split(s, "=")
242 pdid, err := did.ParseDID(parts[1])
243 if err != nil {
244 return "", fmt.Errorf("invalid did in record: %w", err)
245 }
246
247 return pdid.String(), nil
248 }
249 }
250
251 return "", fmt.Errorf("no did record found")
252}
253
// TestHandleResolver is a HandleResolver for tests. Instead of DNS or HTTPS
// lookups, it queries the plain-HTTP well-known route on a fixed list of hosts.
type TestHandleResolver struct {
	// TrialHosts are tried in order. Each is used as the authority of an
	// http:// URL, so it presumably has the form "host" or "host:port".
	TrialHosts []string
}
257
258func (tr *TestHandleResolver) ResolveHandleToDid(ctx context.Context, handle string) (string, error) {
259 c := http.DefaultClient
260
261 for _, h := range tr.TrialHosts {
262 req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/.well-known/atproto-did", h), nil)
263 if err != nil {
264 return "", err
265 }
266
267 req.Host = handle
268
269 resp, err := c.Do(req)
270 if err != nil {
271 slog.Warn("failed to resolve handle to DID", "handle", handle, "err", err)
272 continue
273 }
274
275 if resp.StatusCode != 200 {
276 slog.Warn("got non-200 status code while resolving handle", "handle", handle, "statusCode", resp.StatusCode)
277 continue
278 }
279
280 b, err := io.ReadAll(resp.Body)
281 if err != nil {
282 return "", fmt.Errorf("failed to read resolved did: %w", err)
283 }
284
285 parsed, err := did.ParseDID(string(b))
286 if err != nil {
287 return "", err
288 }
289
290 return parsed.String(), nil
291 }
292
293 return "", fmt.Errorf("no did record found for handle %q", handle)
294}