1package xrpc
2
3import (
4 "bytes"
5 "context"
6 "encoding/base64"
7 "encoding/json"
8 "fmt"
9 "io"
10 "net/http"
11 "net/url"
12 "strconv"
13 "strings"
14 "time"
15
16 "github.com/bluesky-social/indigo/util"
17 "github.com/carlmjohnson/versioninfo"
18)
19
// Client is an XRPC client bound to a single host. Its fields configure the
// HTTP transport, credentials, and headers that Do applies to each request.
type Client struct {
	// Client is an HTTP client to use. If not set, defaults to util.RobustHTTPClient().
	Client *http.Client
	// Auth, when set, supplies the access JWT that Do sends as a Bearer token.
	Auth *AuthInfo
	// AdminToken, when set, is sent as HTTP Basic auth (user "admin") instead of
	// the Bearer token for methods prefixed "com.atproto.admin." or "tools.ozone."
	// and for com.atproto.server.createInviteCode(s).
	AdminToken *string
	// Host is the base URL; Do appends "/xrpc/<method>" to it.
	Host string
	// UserAgent overrides the default "indigo/<version>" User-Agent header when set.
	UserAgent *string
	// Headers are extra headers set on every request, applied after Content-Type
	// and User-Agent and before Authorization.
	Headers map[string]string
}
29
30func (c *Client) getClient() *http.Client {
31 if c.Client == nil {
32 return util.RobustHTTPClient()
33 }
34 return c.Client
35}
36
// Query is the request kind that Client.Do sends with HTTP GET.
var Query = http.MethodGet

// Procedure is the request kind that Client.Do sends with HTTP POST.
var Procedure = http.MethodPost
41
// AuthInfo holds the credential and identity fields attached to a Client.
// Of these, only AccessJwt is read by Client.Do.
type AuthInfo struct {
	// AccessJwt is sent by Client.Do in the "Authorization: Bearer" header.
	AccessJwt string `json:"accessJwt"`
	// RefreshJwt is not read by Do; presumably used to renew the session — confirm against callers.
	RefreshJwt string `json:"refreshJwt"`
	// Handle is serialized under the JSON key "handle".
	Handle string `json:"handle"`
	// Did is serialized under the JSON key "did".
	Did string `json:"did"`
}
48
// XRPCError is the JSON error body decoded from a failed XRPC response:
// the "error" member lands in ErrStr and the "message" member in Message.
type XRPCError struct {
	ErrStr  string `json:"error"`
	Message string `json:"message"`
}

// Error implements the error interface, rendering "<ErrStr>: <Message>".
func (xe *XRPCError) Error() string {
	return xe.ErrStr + ": " + xe.Message
}
57
// Error describes a failed XRPC call: the HTTP status code, the underlying
// cause, and, when the response carried a "ratelimit-limit" header, the
// parsed rate-limit state.
type Error struct {
	StatusCode int
	Wrapped    error
	Ratelimit  *RatelimitInfo
}

// Error implements the error interface. The message always starts with
// "XRPC ERROR <status>".
func (e *Error) Error() string {
	// Preserving "XRPC ERROR %d" prefix for compatibility - previously matching this string was the only way
	// to obtain the status code.
	if e.Wrapped == nil {
		return fmt.Sprintf("XRPC ERROR %d", e.StatusCode)
	}
	// Only mention the reset time when one was actually parsed; a zero Reset
	// would otherwise be printed as a meaningless year-1 timestamp.
	if e.IsThrottled() && e.Ratelimit != nil && !e.Ratelimit.Reset.IsZero() {
		return fmt.Sprintf("XRPC ERROR %d: %s (throttled until %s)", e.StatusCode, e.Wrapped, e.Ratelimit.Reset.Local())
	}
	return fmt.Sprintf("XRPC ERROR %d: %s", e.StatusCode, e.Wrapped)
}

// Unwrap returns the wrapped cause (nil when there is none) so that
// errors.Is and errors.As can see through *Error.
func (e *Error) Unwrap() error {
	return e.Wrapped
}

// IsThrottled reports whether the status code is 429 Too Many Requests.
func (e *Error) IsThrottled() bool {
	return e.StatusCode == http.StatusTooManyRequests
}

// errorFromHTTPResponse builds an *Error from resp's status code and the
// given cause. Rate-limit details are attached only when the response has a
// non-empty "ratelimit-limit" header; each numeric header that is missing or
// not an integer leaves its field at the zero value.
func errorFromHTTPResponse(resp *http.Response, err error) error {
	r := &Error{
		StatusCode: resp.StatusCode,
		Wrapped:    err,
	}
	if resp.Header.Get("ratelimit-limit") == "" {
		return r
	}

	info := &RatelimitInfo{Policy: resp.Header.Get("ratelimit-policy")}
	if secs, perr := strconv.ParseInt(resp.Header.Get("ratelimit-reset"), 10, 64); perr == nil {
		info.Reset = time.Unix(secs, 0)
	}
	// Atoi rejects values that do not fit in int instead of silently truncating.
	if n, perr := strconv.Atoi(resp.Header.Get("ratelimit-limit")); perr == nil {
		info.Limit = n
	}
	if n, perr := strconv.Atoi(resp.Header.Get("ratelimit-remaining")); perr == nil {
		info.Remaining = n
	}
	r.Ratelimit = info
	return r
}

// RatelimitInfo is the rate-limit state parsed from a response's
// "ratelimit-*" headers by errorFromHTTPResponse.
type RatelimitInfo struct {
	Limit     int       // parsed from "ratelimit-limit"
	Remaining int       // parsed from "ratelimit-remaining"
	Policy    string    // raw "ratelimit-policy" header value
	Reset     time.Time // "ratelimit-reset" interpreted as Unix seconds; zero if not parsed
}
115
// makeParams URL-encodes p as a query string, without a leading "?".
//
// A []string value contributes one key=value pair per element, so the key is
// repeated rather than the elements being joined; an empty slice contributes
// nothing. Every other value is rendered with fmt.Sprint, which suits the
// strings, numbers and booleans callers generally pass. Pairs are emitted
// sorted by key, as url.Values.Encode does.
func makeParams(p map[string]any) string {
	params := make(url.Values, len(p))
	for k, v := range p {
		switch val := v.(type) {
		case []string:
			for _, item := range val {
				params.Add(k, item)
			}
		case string:
			// Same result as fmt.Sprint for a string, without the formatting detour.
			params.Add(k, val)
		default:
			params.Add(k, fmt.Sprint(val))
		}
	}

	return params.Encode()
}
133
134func (c *Client) Do(ctx context.Context, kind string, inpenc string, method string, params map[string]interface{}, bodyobj interface{}, out interface{}) error {
135 var body io.Reader
136 if bodyobj != nil {
137 if rr, ok := bodyobj.(io.Reader); ok {
138 body = rr
139 } else {
140 b, err := json.Marshal(bodyobj)
141 if err != nil {
142 return err
143 }
144
145 body = bytes.NewReader(b)
146 }
147 }
148
149 var m string
150 switch kind {
151 case Query:
152 m = "GET"
153 case Procedure:
154 m = "POST"
155 default:
156 return fmt.Errorf("unsupported request kind: %s", kind)
157 }
158
159 var paramStr string
160 if len(params) > 0 {
161 paramStr = "?" + makeParams(params)
162 }
163
164 req, err := http.NewRequest(m, c.Host+"/xrpc/"+method+paramStr, body)
165 if err != nil {
166 return err
167 }
168
169 if bodyobj != nil && inpenc != "" {
170 req.Header.Set("Content-Type", inpenc)
171 }
172 if c.UserAgent != nil {
173 req.Header.Set("User-Agent", *c.UserAgent)
174 } else {
175 req.Header.Set("User-Agent", "indigo/"+versioninfo.Short())
176 }
177
178 if c.Headers != nil {
179 for k, v := range c.Headers {
180 req.Header.Set(k, v)
181 }
182 }
183
184 // use admin auth if we have it configured and are doing a request that requires it
185 if c.AdminToken != nil && (strings.HasPrefix(method, "com.atproto.admin.") || strings.HasPrefix(method, "tools.ozone.") || method == "com.atproto.server.createInviteCode" || method == "com.atproto.server.createInviteCodes") {
186 req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:"+*c.AdminToken)))
187 } else if c.Auth != nil {
188 req.Header.Set("Authorization", "Bearer "+c.Auth.AccessJwt)
189 }
190
191 resp, err := c.getClient().Do(req.WithContext(ctx))
192 if err != nil {
193 return fmt.Errorf("request failed: %w", err)
194 }
195
196 defer resp.Body.Close()
197
198 if resp.StatusCode != 200 {
199 var xe XRPCError
200 if err := json.NewDecoder(resp.Body).Decode(&xe); err != nil {
201 return errorFromHTTPResponse(resp, fmt.Errorf("failed to decode xrpc error message: %w", err))
202 }
203 return errorFromHTTPResponse(resp, &xe)
204 }
205
206 if out != nil {
207 if buf, ok := out.(*bytes.Buffer); ok {
208 if resp.ContentLength < 0 {
209 _, err := io.Copy(buf, resp.Body)
210 if err != nil {
211 return fmt.Errorf("reading response body: %w", err)
212 }
213 } else {
214 n, err := io.CopyN(buf, resp.Body, resp.ContentLength)
215 if err != nil {
216 return fmt.Errorf("reading length delimited response body (%d < %d): %w", n, resp.ContentLength, err)
217 }
218 }
219 } else {
220 if err := json.NewDecoder(resp.Body).Decode(out); err != nil {
221 return fmt.Errorf("decoding xrpc response: %w", err)
222 }
223 }
224 }
225
226 return nil
227}
228
// LexDo forwards its arguments unchanged to Do. Note that the parameter
// names differ from Do's: method here is passed as Do's kind (the request
// kind), and endpoint is passed as Do's method (the XRPC method name that
// ends up in the URL path).
func (c *Client) LexDo(ctx context.Context, method string, inputEncoding string, endpoint string, params map[string]any, bodyData any, out any) error {
	return c.Do(ctx, method, inputEncoding, endpoint, params, bodyData, out)
}