Porting all GitHub Actions from bluesky-social/indigo to Tangled CI
at main 6.7 kB view raw
1package handles 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "io" 8 "log/slog" 9 "net" 10 "net/http" 11 "net/url" 12 "strings" 13 "sync" 14 "time" 15 16 "github.com/bluesky-social/indigo/did" 17 arc "github.com/hashicorp/golang-lru/arc/v2" 18 "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" 19 otel "go.opentelemetry.io/otel" 20) 21 22func ResolveDidToHandle(ctx context.Context, res did.Resolver, hr HandleResolver, udid string) (string, string, error) { 23 ctx, span := otel.Tracer("gosky").Start(ctx, "resolveDidToHandle") 24 defer span.End() 25 26 doc, err := res.GetDocument(ctx, udid) 27 if err != nil { 28 return "", "", err 29 } 30 31 if len(doc.AlsoKnownAs) == 0 { 32 return "", "", fmt.Errorf("users did document does not specify a handle") 33 } 34 35 aka := doc.AlsoKnownAs[0] 36 37 u, err := url.Parse(aka) 38 if err != nil { 39 return "", "", fmt.Errorf("aka field in doc was not a valid url: %w", err) 40 } 41 42 handle := u.Host 43 44 var svc *did.Service 45 for _, s := range doc.Service { 46 if s.ID.String() == "#atproto_pds" && s.Type == "AtprotoPersonalDataServer" { 47 svc = &s 48 break 49 } 50 } 51 52 if svc == nil { 53 return "", "", fmt.Errorf("users did document has no pds service set") 54 } 55 56 verdid, err := hr.ResolveHandleToDid(ctx, handle) 57 if err != nil { 58 return "", "", err 59 } 60 61 if verdid != udid { 62 return "", "", fmt.Errorf("pds server reported different did for claimed handle") 63 } 64 65 return handle, svc.ServiceEndpoint, nil 66} 67 68type HandleResolver interface { 69 ResolveHandleToDid(ctx context.Context, handle string) (string, error) 70} 71 72type failCacheItem struct { 73 err error 74 count int 75 expiresAt time.Time 76} 77 78type ProdHandleResolver struct { 79 client *http.Client 80 resolver *net.Resolver 81 ReqMod func(*http.Request, string) error 82 FailCache *arc.ARCCache[string, *failCacheItem] 83} 84 85func NewProdHandleResolver(failureCacheSize int, resolveAddr string, forceUDP bool) 
(*ProdHandleResolver, error) { 86 failureCache, err := arc.NewARC[string, *failCacheItem](failureCacheSize) 87 if err != nil { 88 return nil, err 89 } 90 91 if resolveAddr == "" { 92 resolveAddr = "1.1.1.1:53" 93 } 94 95 c := http.Client{ 96 Transport: otelhttp.NewTransport(http.DefaultTransport), 97 Timeout: time.Second * 10, 98 } 99 100 r := &net.Resolver{ 101 PreferGo: true, 102 Dial: func(ctx context.Context, network, address string) (net.Conn, error) { 103 d := net.Dialer{ 104 Timeout: time.Second * 10, 105 } 106 if forceUDP { 107 network = "udp" 108 } 109 return d.DialContext(ctx, network, resolveAddr) 110 }, 111 } 112 113 return &ProdHandleResolver{ 114 FailCache: failureCache, 115 client: &c, 116 resolver: r, 117 }, nil 118} 119 120func (dr *ProdHandleResolver) ResolveHandleToDid(ctx context.Context, handle string) (string, error) { 121 ctx, cancel := context.WithTimeout(ctx, time.Second*20) 122 defer cancel() 123 124 ctx, span := otel.Tracer("resolver").Start(ctx, "ResolveHandleToDid") 125 defer span.End() 126 127 var cachedFailureCount int 128 129 if dr.FailCache != nil { 130 if item, ok := dr.FailCache.Get(handle); ok { 131 cachedFailureCount = item.count 132 if item.expiresAt.After(time.Now()) { 133 return "", item.err 134 } 135 dr.FailCache.Remove(handle) 136 } 137 } 138 139 var wkres, dnsres string 140 var wkerr, dnserr error 141 142 var wg sync.WaitGroup 143 wg.Add(2) 144 145 go func() { 146 defer wg.Done() 147 wkres, wkerr = dr.resolveWellKnown(ctx, handle) 148 if wkerr == nil { 149 cancel() 150 } 151 }() 152 go func() { 153 defer wg.Done() 154 dnsres, dnserr = dr.resolveDNS(ctx, handle) 155 if dnserr == nil { 156 cancel() 157 } 158 }() 159 160 wg.Wait() 161 162 if dnserr == nil { 163 return dnsres, nil 164 } 165 166 if wkerr == nil { 167 return wkres, nil 168 } 169 170 err := errors.Join(fmt.Errorf("no did record found for handle %q", handle), dnserr, wkerr) 171 172 if dr.FailCache != nil { 173 cachedFailureCount++ 174 expireAt := 
time.Now().Add(time.Millisecond * 100) 175 if cachedFailureCount > 1 { 176 // exponential backoff 177 expireAt = time.Now().Add(time.Millisecond * 100 * time.Duration(cachedFailureCount*cachedFailureCount)) 178 // Clamp to one hour 179 if expireAt.After(time.Now().Add(time.Hour)) { 180 expireAt = time.Now().Add(time.Hour) 181 } 182 } 183 184 dr.FailCache.Add(handle, &failCacheItem{ 185 err: err, 186 expiresAt: expireAt, 187 count: cachedFailureCount, 188 }) 189 } 190 191 return "", err 192} 193 194func (dr *ProdHandleResolver) resolveWellKnown(ctx context.Context, handle string) (string, error) { 195 req, err := http.NewRequest("GET", fmt.Sprintf("https://%s/.well-known/atproto-did", handle), nil) 196 if err != nil { 197 return "", err 198 } 199 200 if dr.ReqMod != nil { 201 if err := dr.ReqMod(req, handle); err != nil { 202 return "", err 203 } 204 } 205 206 req = req.WithContext(ctx) 207 208 resp, err := dr.client.Do(req) 209 if err != nil { 210 return "", fmt.Errorf("failed to resolve handle (%s) through HTTP well-known route: %s", handle, err) 211 } 212 if resp.StatusCode != 200 { 213 return "", fmt.Errorf("failed to resolve handle (%s) through HTTP well-known route: status=%d", handle, resp.StatusCode) 214 } 215 216 if resp.ContentLength > 2048 { 217 return "", fmt.Errorf("http well-known route returned too much data") 218 } 219 220 b, err := io.ReadAll(io.LimitReader(resp.Body, 2048)) 221 if err != nil { 222 return "", fmt.Errorf("failed to read resolved did: %w", err) 223 } 224 225 parsed, err := did.ParseDID(string(b)) 226 if err != nil { 227 return "", err 228 } 229 230 return parsed.String(), nil 231} 232 233func (dr *ProdHandleResolver) resolveDNS(ctx context.Context, handle string) (string, error) { 234 res, err := dr.resolver.LookupTXT(ctx, "_atproto."+handle) 235 if err != nil { 236 return "", fmt.Errorf("handle lookup failed: %w", err) 237 } 238 239 for _, s := range res { 240 if strings.HasPrefix(s, "did=") { 241 parts := strings.Split(s, "=") 242 
pdid, err := did.ParseDID(parts[1]) 243 if err != nil { 244 return "", fmt.Errorf("invalid did in record: %w", err) 245 } 246 247 return pdid.String(), nil 248 } 249 } 250 251 return "", fmt.Errorf("no did record found") 252} 253 254type TestHandleResolver struct { 255 TrialHosts []string 256} 257 258func (tr *TestHandleResolver) ResolveHandleToDid(ctx context.Context, handle string) (string, error) { 259 c := http.DefaultClient 260 261 for _, h := range tr.TrialHosts { 262 req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/.well-known/atproto-did", h), nil) 263 if err != nil { 264 return "", err 265 } 266 267 req.Host = handle 268 269 resp, err := c.Do(req) 270 if err != nil { 271 slog.Warn("failed to resolve handle to DID", "handle", handle, "err", err) 272 continue 273 } 274 275 if resp.StatusCode != 200 { 276 slog.Warn("got non-200 status code while resolving handle", "handle", handle, "statusCode", resp.StatusCode) 277 continue 278 } 279 280 b, err := io.ReadAll(resp.Body) 281 if err != nil { 282 return "", fmt.Errorf("failed to read resolved did: %w", err) 283 } 284 285 parsed, err := did.ParseDID(string(b)) 286 if err != nil { 287 return "", err 288 } 289 290 return parsed.String(), nil 291 } 292 293 return "", fmt.Errorf("no did record found for handle %q", handle) 294}