1// Copyright 2022 The Gitea Authors. All rights reserved.
2// Copyright 2024 The Forgejo Authors. All rights reserved.
3// SPDX-License-Identifier: MIT
4
5// TODO: Think about whether this should be moved to services/activitypub (compare to exosy/services/activitypub/client.go)
6package activitypub
7
8import (
9 "bytes"
10 "context"
11 "crypto/rsa"
12 "crypto/x509"
13 "encoding/pem"
14 "fmt"
15 "io"
16 "net/http"
17 "strings"
18 "time"
19
20 user_model "forgejo.org/models/user"
21 "forgejo.org/modules/log"
22 "forgejo.org/modules/proxy"
23 "forgejo.org/modules/setting"
24
25 "github.com/42wim/httpsig"
26)
27
// ActivityStreamsContentType is the media type sent as the Content-Type of
// outgoing federation requests and offered (after application/json) in their
// Accept header.
const ActivityStreamsContentType = "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\""

// httpsigExpirationTime is the expiresIn argument passed to httpsig.NewSigner
// for every signed request (presumably seconds — confirm against httpsig).
const httpsigExpirationTime = 60
33
// CurrentTime returns the present moment in UTC, formatted with
// http.TimeFormat as used for the Date header of outgoing requests.
func CurrentTime() string {
	now := time.Now().UTC()
	return now.Format(http.TimeFormat)
}
37
38func containsRequiredHTTPHeaders(method string, headers []string) error {
39 var hasRequestTarget, hasDate, hasDigest, hasHost bool
40 for _, header := range headers {
41 hasRequestTarget = hasRequestTarget || header == httpsig.RequestTarget
42 hasDate = hasDate || header == "Date"
43 hasDigest = hasDigest || header == "Digest"
44 hasHost = hasHost || header == "Host"
45 }
46 if !hasRequestTarget {
47 return fmt.Errorf("missing http header for %s: %s", method, httpsig.RequestTarget)
48 } else if !hasDate {
49 return fmt.Errorf("missing http header for %s: Date", method)
50 } else if !hasHost {
51 return fmt.Errorf("missing http header for %s: Host", method)
52 } else if !hasDigest && method != http.MethodGet {
53 return fmt.Errorf("missing http header for %s: Digest", method)
54 }
55 return nil
56}
57
58// Client struct
59type ClientFactory struct {
60 client *http.Client
61 algs []httpsig.Algorithm
62 digestAlg httpsig.DigestAlgorithm
63 getHeaders []string
64 postHeaders []string
65}
66
67// NewClient function
68func NewClientFactory() (c *ClientFactory, err error) {
69 if err = containsRequiredHTTPHeaders(http.MethodGet, setting.Federation.GetHeaders); err != nil {
70 return nil, err
71 } else if err = containsRequiredHTTPHeaders(http.MethodPost, setting.Federation.PostHeaders); err != nil {
72 return nil, err
73 }
74
75 c = &ClientFactory{
76 client: &http.Client{
77 Transport: &http.Transport{
78 Proxy: proxy.Proxy(),
79 },
80 Timeout: 5 * time.Second,
81 },
82 algs: setting.HttpsigAlgs,
83 digestAlg: httpsig.DigestAlgorithm(setting.Federation.DigestAlgorithm),
84 getHeaders: setting.Federation.GetHeaders,
85 postHeaders: setting.Federation.PostHeaders,
86 }
87 return c, err
88}
89
90type APClientFactory interface {
91 WithKeys(ctx context.Context, user *user_model.User, pubID string) (APClient, error)
92}
93
94// Client struct
95type Client struct {
96 client *http.Client
97 algs []httpsig.Algorithm
98 digestAlg httpsig.DigestAlgorithm
99 getHeaders []string
100 postHeaders []string
101 priv *rsa.PrivateKey
102 pubID string
103}
104
105// NewRequest function
106func (cf *ClientFactory) WithKeys(ctx context.Context, user *user_model.User, pubID string) (APClient, error) {
107 priv, err := GetPrivateKey(ctx, user)
108 if err != nil {
109 return nil, err
110 }
111 privPem, _ := pem.Decode([]byte(priv))
112 privParsed, err := x509.ParsePKCS1PrivateKey(privPem.Bytes)
113 if err != nil {
114 return nil, err
115 }
116
117 c := Client{
118 client: cf.client,
119 algs: cf.algs,
120 digestAlg: cf.digestAlg,
121 getHeaders: cf.getHeaders,
122 postHeaders: cf.postHeaders,
123 priv: privParsed,
124 pubID: pubID,
125 }
126 return &c, nil
127}
128
129// NewRequest function
130func (c *Client) newRequest(method string, b []byte, to string) (req *http.Request, err error) {
131 buf := bytes.NewBuffer(b)
132 req, err = http.NewRequest(method, to, buf)
133 if err != nil {
134 return nil, err
135 }
136 req.Header.Add("Accept", "application/json, "+ActivityStreamsContentType)
137 req.Header.Add("Date", CurrentTime())
138 req.Header.Add("Host", req.URL.Host)
139 req.Header.Add("User-Agent", "Gitea/"+setting.AppVer)
140 req.Header.Add("Content-Type", ActivityStreamsContentType)
141
142 return req, err
143}
144
145// Post function
146func (c *Client) Post(b []byte, to string) (resp *http.Response, err error) {
147 var req *http.Request
148 if req, err = c.newRequest(http.MethodPost, b, to); err != nil {
149 return nil, err
150 }
151
152 signer, _, err := httpsig.NewSigner(c.algs, c.digestAlg, c.postHeaders, httpsig.Signature, httpsigExpirationTime)
153 if err != nil {
154 return nil, err
155 }
156 if err := signer.SignRequest(c.priv, c.pubID, req, b); err != nil {
157 return nil, err
158 }
159
160 resp, err = c.client.Do(req)
161 return resp, err
162}
163
164// Create an http GET request with forgejo/gitea specific headers
165func (c *Client) Get(to string) (resp *http.Response, err error) {
166 var req *http.Request
167 if req, err = c.newRequest(http.MethodGet, nil, to); err != nil {
168 return nil, err
169 }
170 signer, _, err := httpsig.NewSigner(c.algs, c.digestAlg, c.getHeaders, httpsig.Signature, httpsigExpirationTime)
171 if err != nil {
172 return nil, err
173 }
174 if err := signer.SignRequest(c.priv, c.pubID, req, nil); err != nil {
175 return nil, err
176 }
177
178 resp, err = c.client.Do(req)
179 return resp, err
180}
181
182// Create an http GET request with forgejo/gitea specific headers
183func (c *Client) GetBody(uri string) ([]byte, error) {
184 response, err := c.Get(uri)
185 if err != nil {
186 return nil, err
187 }
188 log.Debug("Client: got status: %v", response.Status)
189 if response.StatusCode != 200 {
190 err = fmt.Errorf("got non 200 status code for id: %v", uri)
191 return nil, err
192 }
193 defer response.Body.Close()
194 if response.ContentLength > setting.Federation.MaxSize {
195 return nil, fmt.Errorf("Request returned %d bytes (max allowed incomming size: %d bytes)", response.ContentLength, setting.Federation.MaxSize)
196 } else if response.ContentLength == -1 {
197 log.Warn("Request to %v returned an unknown content length, response may be truncated to %d bytes", uri, setting.Federation.MaxSize)
198 }
199
200 body, err := io.ReadAll(io.LimitReader(response.Body, setting.Federation.MaxSize))
201 if err != nil {
202 return nil, err
203 }
204
205 log.Debug("Client: got body: %v", charLimiter(string(body), 120))
206 return body, nil
207}
208
// charLimiter shortens s to at most limit characters (Unicode code points),
// appending "..." only when something was cut off; it keeps log output short.
// Strings of limit characters or fewer are returned unchanged, and a limit
// of zero or less disables truncation.
//
// Counting code points rather than bytes keeps multi-byte UTF-8 sequences
// intact; each invalid UTF-8 byte counts as one character.
// NOTE(review): truncation alone does not neutralise control characters such
// as newlines in s.
// Originally adapted from https://www.socketloop.com/tutorials/golang-characters-limiter-example
func charLimiter(s string, limit int) string {
	if limit <= 0 {
		return s
	}
	r := strings.NewReader(s)
	for i := 0; i < limit; i++ {
		if _, _, err := r.ReadRune(); err != nil {
			// s has fewer than limit characters.
			return s
		}
	}
	if r.Len() == 0 {
		// s has exactly limit characters: nothing to cut.
		return s
	}
	// r.Len() is the number of unread bytes, so this is the prefix holding
	// exactly limit characters.
	return s[:len(s)-r.Len()] + "..."
}
220
// APClient builds, signs and sends ActivityPub requests; *Client is the
// implementation returned by ClientFactory.WithKeys. Because the interface
// includes the unexported newRequest method, only types in this package can
// implement it.
type APClient interface {
	// newRequest builds an unsigned request carrying the default headers.
	newRequest(method string, b []byte, to string) (*http.Request, error)
	// Post signs b and sends it to the URL to.
	Post(b []byte, to string) (*http.Response, error)
	// Get sends a signed GET request to the URL to.
	Get(to string) (*http.Response, error)
	// GetBody performs a signed GET of uri and returns the response body.
	GetBody(uri string) ([]byte, error)
}
227
// contextKey is the key type this package uses with context.WithValue.
// Because the type is unexported, its keys cannot collide with keys defined
// by other packages; name presumably only aids debugging.
type contextKey struct{ name string }
232
233// clientFactoryContextKey is a context key. It is used with context.Value() to get the current Food for the context
234var (
235 clientFactoryContextKey = &contextKey{"clientFactory"}
236 _ APClientFactory = &ClientFactory{}
237)
238
239// Context represents an activitypub client factory context
240type Context struct {
241 context.Context
242 e APClientFactory
243}
244
245func NewContext(ctx context.Context, e APClientFactory) *Context {
246 return &Context{
247 Context: ctx,
248 e: e,
249 }
250}
251
252// APClientFactory represents an activitypub client factory
253func (ctx *Context) APClientFactory() APClientFactory {
254 return ctx.e
255}
256
257// provides APClientFactory
258type GetAPClient interface {
259 GetClientFactory() APClientFactory
260}
261
262// GetClientFactory will get an APClientFactory from this context or returns the default implementation
263func GetClientFactory(ctx context.Context) (APClientFactory, error) {
264 if e := getClientFactory(ctx); e != nil {
265 return e, nil
266 }
267 return NewClientFactory()
268}
269
270// getClientFactory will get an APClientFactory from this context or return nil
271func getClientFactory(ctx context.Context) APClientFactory {
272 if clientFactory, ok := ctx.(APClientFactory); ok {
273 return clientFactory
274 }
275 clientFactoryInterface := ctx.Value(clientFactoryContextKey)
276 if clientFactoryInterface != nil {
277 return clientFactoryInterface.(GetAPClient).GetClientFactory()
278 }
279 return nil
280}