A community-based topic aggregation platform built on atproto
at main 158 lines 5.1 kB view raw
1package imageproxy 2 3import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "io" 8 "net/http" 9 "net/url" 10 "strings" 11 "time" 12) 13 14// Fetcher defines the interface for fetching blobs from a PDS. 15type Fetcher interface { 16 // Fetch retrieves a blob from the specified PDS. 17 // Returns the blob bytes or an error if the fetch fails. 18 Fetch(ctx context.Context, pdsURL, did, cid string) ([]byte, error) 19} 20 21// PDSFetcher implements the Fetcher interface for fetching blobs from atproto PDS servers. 22type PDSFetcher struct { 23 client *http.Client 24 timeout time.Duration 25 maxSizeBytes int64 26} 27 28// DefaultMaxSourceSizeMB is the default maximum source image size if not configured. 29const DefaultMaxSourceSizeMB = 10 30 31// NewPDSFetcher creates a new PDSFetcher with the specified timeout. 32// maxSizeMB specifies the maximum allowed image size in megabytes (0 uses default of 10MB). 33func NewPDSFetcher(timeout time.Duration, maxSizeMB int) *PDSFetcher { 34 if maxSizeMB <= 0 { 35 maxSizeMB = DefaultMaxSourceSizeMB 36 } 37 return &PDSFetcher{ 38 client: &http.Client{ 39 Timeout: timeout, 40 }, 41 timeout: timeout, 42 maxSizeBytes: int64(maxSizeMB) * 1024 * 1024, 43 } 44} 45 46// Fetch retrieves a blob from the specified PDS using the com.atproto.sync.getBlob endpoint. 
47// Returns: 48// - ErrPDSNotFound if the blob does not exist (404 response) 49// - ErrPDSTimeout if the request times out or context is cancelled 50// - ErrPDSFetchFailed for any other error 51func (f *PDSFetcher) Fetch(ctx context.Context, pdsURL, did, cid string) ([]byte, error) { 52 // Construct the request URL 53 endpoint, err := url.Parse(pdsURL) 54 if err != nil { 55 return nil, fmt.Errorf("%w: invalid PDS URL: %v", ErrPDSFetchFailed, err) 56 } 57 endpoint.Path = "/xrpc/com.atproto.sync.getBlob" 58 59 query := url.Values{} 60 query.Set("did", did) 61 query.Set("cid", cid) 62 endpoint.RawQuery = query.Encode() 63 64 // Create the request with context 65 req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil) 66 if err != nil { 67 return nil, fmt.Errorf("%w: failed to create request: %v", ErrPDSFetchFailed, err) 68 } 69 70 // Set User-Agent header for identification 71 req.Header.Set("User-Agent", "Coves-ImageProxy/1.0") 72 73 // Execute the request 74 resp, err := f.client.Do(req) 75 if err != nil { 76 // Check if the error is due to context cancellation or timeout 77 if ctx.Err() != nil { 78 return nil, fmt.Errorf("%w: %v", ErrPDSTimeout, ctx.Err()) 79 } 80 // Check if it's a timeout error from the http client 81 if isTimeoutError(err) { 82 return nil, fmt.Errorf("%w: request timed out", ErrPDSTimeout) 83 } 84 return nil, fmt.Errorf("%w: %v", ErrPDSFetchFailed, err) 85 } 86 defer resp.Body.Close() 87 88 // Handle response status codes 89 switch resp.StatusCode { 90 case http.StatusOK: 91 // Check Content-Length header if available 92 if resp.ContentLength > 0 && resp.ContentLength > f.maxSizeBytes { 93 return nil, fmt.Errorf("%w: content length %d exceeds maximum %d bytes", 94 ErrImageTooLarge, resp.ContentLength, f.maxSizeBytes) 95 } 96 97 // Use a limited reader to prevent memory exhaustion even if Content-Length is missing or wrong. 98 // We read maxSizeBytes + 1 to detect if the response exceeds the limit. 
99 limitedReader := io.LimitReader(resp.Body, f.maxSizeBytes+1) 100 data, err := io.ReadAll(limitedReader) 101 if err != nil { 102 return nil, fmt.Errorf("%w: failed to read response body: %v", ErrPDSFetchFailed, err) 103 } 104 105 // Check if we hit the limit (meaning there was more data) 106 if int64(len(data)) > f.maxSizeBytes { 107 return nil, fmt.Errorf("%w: response body exceeds maximum %d bytes", 108 ErrImageTooLarge, f.maxSizeBytes) 109 } 110 111 return data, nil 112 113 case http.StatusNotFound: 114 return nil, ErrPDSNotFound 115 116 case http.StatusBadRequest: 117 // AT Protocol PDS may return 400 with "Blob not found" for missing blobs 118 // We need to check the error message to distinguish from actual bad requests 119 body, readErr := io.ReadAll(io.LimitReader(resp.Body, 1024)) 120 if readErr == nil && isBlobNotFoundError(body) { 121 return nil, ErrPDSNotFound 122 } 123 return nil, fmt.Errorf("%w: bad request (status 400)", ErrPDSFetchFailed) 124 125 default: 126 return nil, fmt.Errorf("%w: unexpected status code %d", ErrPDSFetchFailed, resp.StatusCode) 127 } 128} 129 130// pdsErrorResponse represents the error response structure from AT Protocol PDS 131type pdsErrorResponse struct { 132 Error string `json:"error"` 133 Message string `json:"message"` 134} 135 136// isBlobNotFoundError checks if the error response indicates a blob was not found. 137// AT Protocol PDS returns 400 with {"error":"InvalidRequest","message":"Blob not found"} 138// for missing blobs instead of a proper 404. 139func isBlobNotFoundError(body []byte) bool { 140 var errResp pdsErrorResponse 141 if err := json.Unmarshal(body, &errResp); err != nil { 142 return false 143 } 144 // Check for "Blob not found" message (case-insensitive) 145 return strings.Contains(strings.ToLower(errResp.Message), "blob not found") 146} 147 148// isTimeoutError checks if the error is a timeout-related error. 
// It reports true only when err itself exposes a Timeout() bool method and that
// method returns true. A nil error, or an error without such a method, is not
// considered a timeout.
func isTimeoutError(err error) bool {
	// The type assertion fails for a nil interface value, so no separate nil
	// check is needed.
	te, ok := err.(interface{ Timeout() bool })
	return ok && te.Timeout()
}