1package server
2
3import (
4 "bytes"
5 "crypto/sha256"
6 "encoding/base64"
7 "errors"
8 "fmt"
9 "slices"
10 "time"
11
12 "github.com/Azure/go-autorest/autorest/to"
13 "github.com/golang-jwt/jwt/v4"
14 "github.com/haileyok/cocoon/internal/helpers"
15 "github.com/haileyok/cocoon/oauth"
16 "github.com/haileyok/cocoon/oauth/constants"
17 "github.com/haileyok/cocoon/oauth/dpop"
18 "github.com/haileyok/cocoon/oauth/provider"
19 "github.com/labstack/echo/v4"
20)
21
22type OauthTokenRequest struct {
23 provider.AuthenticateClientRequestBase
24 GrantType string `form:"grant_type" json:"grant_type"`
25 Code *string `form:"code" json:"code,omitempty"`
26 CodeVerifier *string `form:"code_verifier" json:"code_verifier,omitempty"`
27 RedirectURI *string `form:"redirect_uri" json:"redirect_uri,omitempty"`
28 RefreshToken *string `form:"refresh_token" json:"refresh_token,omitempty"`
29}
30
// OauthTokenResponse is the JSON body returned by the token endpoint on
// success. Field order here determines key order in the encoded JSON.
type OauthTokenResponse struct {
	// AccessToken is the signed JWT access token.
	AccessToken string `json:"access_token"`
	// TokenType is "Bearer", or "DPoP" when the token is DPoP-bound.
	TokenType string `json:"token_type"`
	// RefreshToken is the (new) refresh token for this session.
	RefreshToken string `json:"refresh_token"`
	// Scope echoes the scope granted to the token.
	Scope string `json:"scope"`
	// ExpiresIn is the remaining access token lifetime, in seconds.
	ExpiresIn int64 `json:"expires_in"`
	// Sub is the DID of the account the token was issued for.
	Sub string `json:"sub"`
}
39
40func (s *Server) handleOauthToken(e echo.Context) error {
41 ctx := e.Request().Context()
42 logger := s.logger.With("name", "handleOauthToken")
43
44 var req OauthTokenRequest
45 if err := e.Bind(&req); err != nil {
46 logger.Error("error binding token request", "error", err)
47 return helpers.ServerError(e, nil)
48 }
49
50 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, e.Request().URL.String(), e.Request().Header, nil)
51 if err != nil {
52 if errors.Is(err, dpop.ErrUseDpopNonce) {
53 nonce := s.oauthProvider.NextNonce()
54 if nonce != "" {
55 e.Response().Header().Set("DPoP-Nonce", nonce)
56 e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce")
57 }
58 return e.JSON(400, map[string]string{
59 "error": "use_dpop_nonce",
60 })
61 }
62 logger.Error("error getting dpop proof", "error", err)
63 return helpers.InputError(e, nil)
64 }
65
66 client, clientAuth, err := s.oauthProvider.AuthenticateClient(e.Request().Context(), req.AuthenticateClientRequestBase, proof, &provider.AuthenticateClientOptions{
67 AllowMissingDpopProof: true,
68 })
69 if err != nil {
70 logger.Error("error authenticating client", "client_id", req.ClientID, "error", err)
71 return helpers.InputError(e, to.StringPtr(err.Error()))
72 }
73
74 // TODO: this should come from an oauth provier config
75 if !slices.Contains([]string{"authorization_code", "refresh_token"}, req.GrantType) {
76 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the server`, req.GrantType)))
77 }
78
79 if !slices.Contains(client.Metadata.GrantTypes, req.GrantType) {
80 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the client`, req.GrantType)))
81 }
82
83 if req.GrantType == "authorization_code" {
84 if req.Code == nil {
85 return helpers.InputError(e, to.StringPtr(`"code" is required"`))
86 }
87
88 var authReq provider.OauthAuthorizationRequest
89 // get the lil guy and delete him
90 if err := s.db.Raw(ctx, "DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil {
91 logger.Error("error finding authorization request", "error", err)
92 return helpers.ServerError(e, nil)
93 }
94
95 if req.RedirectURI == nil || *req.RedirectURI != authReq.Parameters.RedirectURI {
96 return helpers.InputError(e, to.StringPtr(`"redirect_uri" mismatch`))
97 }
98
99 if authReq.Parameters.CodeChallenge != nil {
100 if req.CodeVerifier == nil {
101 return helpers.InputError(e, to.StringPtr(`"code_verifier" is required`))
102 }
103
104 if len(*req.CodeVerifier) < 43 {
105 return helpers.InputError(e, to.StringPtr(`"code_verifier" is too short`))
106 }
107
108 switch *&authReq.Parameters.CodeChallengeMethod {
109 case "", "plain":
110 if authReq.Parameters.CodeChallenge != req.CodeVerifier {
111 return helpers.InputError(e, to.StringPtr("invalid code_verifier"))
112 }
113 case "S256":
114 inputChal, err := base64.RawURLEncoding.DecodeString(*authReq.Parameters.CodeChallenge)
115 if err != nil {
116 logger.Error("error decoding code challenge", "error", err)
117 return helpers.ServerError(e, nil)
118 }
119
120 h := sha256.New()
121 h.Write([]byte(*req.CodeVerifier))
122 compdChal := h.Sum(nil)
123
124 if !bytes.Equal(inputChal, compdChal) {
125 return helpers.InputError(e, to.StringPtr("invalid code_verifier"))
126 }
127 default:
128 return helpers.InputError(e, to.StringPtr("unsupported code_challenge_method "+*&authReq.Parameters.CodeChallengeMethod))
129 }
130 } else if req.CodeVerifier != nil {
131 return helpers.InputError(e, to.StringPtr("code_challenge parameter wasn't provided"))
132 }
133
134 repo, err := s.getRepoActorByDid(ctx, *authReq.Sub)
135 if err != nil {
136 helpers.InputError(e, to.StringPtr("unable to find actor"))
137 }
138
139 now := time.Now()
140 eat := now.Add(constants.TokenMaxAge)
141 id := oauth.GenerateTokenId()
142
143 refreshToken := oauth.GenerateRefreshToken()
144
145 accessClaims := jwt.MapClaims{
146 "scope": authReq.Parameters.Scope,
147 "aud": s.config.Did,
148 "sub": repo.Repo.Did,
149 "iat": now.Unix(),
150 "exp": eat.Unix(),
151 "jti": id,
152 "client_id": authReq.ClientId,
153 }
154
155 if authReq.Parameters.DpopJkt != nil {
156 accessClaims["cnf"] = *authReq.Parameters.DpopJkt
157 }
158
159 accessToken := jwt.NewWithClaims(jwt.SigningMethodES256, accessClaims)
160 accessString, err := accessToken.SignedString(s.privateKey)
161 if err != nil {
162 return err
163 }
164
165 if err := s.db.Create(ctx, &provider.OauthToken{
166 ClientId: authReq.ClientId,
167 ClientAuth: *clientAuth,
168 Parameters: authReq.Parameters,
169 ExpiresAt: eat,
170 DeviceId: "",
171 Sub: repo.Repo.Did,
172 Code: *authReq.Code,
173 Token: accessString,
174 RefreshToken: refreshToken,
175 Ip: authReq.Ip,
176 }, nil).Error; err != nil {
177 logger.Error("error creating token in db", "error", err)
178 return helpers.ServerError(e, nil)
179 }
180
181 // prob not needed
182 tokenType := "Bearer"
183 if authReq.Parameters.DpopJkt != nil {
184 tokenType = "DPoP"
185 }
186
187 e.Response().Header().Set("content-type", "application/json")
188
189 return e.JSON(200, OauthTokenResponse{
190 AccessToken: accessString,
191 RefreshToken: refreshToken,
192 TokenType: tokenType,
193 Scope: authReq.Parameters.Scope,
194 ExpiresIn: int64(eat.Sub(time.Now()).Seconds()),
195 Sub: repo.Repo.Did,
196 })
197 }
198
199 if req.GrantType == "refresh_token" {
200 if req.RefreshToken == nil {
201 return helpers.InputError(e, to.StringPtr(`"refresh_token" is required`))
202 }
203
204 var oauthToken provider.OauthToken
205 if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil {
206 logger.Error("error finding oauth token by refresh token", "error", err, "refresh_token", req.RefreshToken)
207 return helpers.ServerError(e, nil)
208 }
209
210 if client.Metadata.ClientID != oauthToken.ClientId {
211 return helpers.InputError(e, to.StringPtr(`"client_id" mismatch`))
212 }
213
214 if clientAuth.Method != oauthToken.ClientAuth.Method {
215 return helpers.InputError(e, to.StringPtr(`"client authentication method mismatch`))
216 }
217
218 if *oauthToken.Parameters.DpopJkt != proof.JKT {
219 return helpers.InputError(e, to.StringPtr("dpop proof does not match expected jkt"))
220 }
221
222 ageRes := oauth.GetSessionAgeFromToken(oauthToken)
223
224 if ageRes.SessionExpired {
225 return helpers.InputError(e, to.StringPtr("Session expired"))
226 }
227
228 if ageRes.RefreshExpired {
229 return helpers.InputError(e, to.StringPtr("Refresh token expired"))
230 }
231
232 if client.Metadata.DpopBoundAccessTokens && oauthToken.Parameters.DpopJkt == nil {
233 // why? ref impl
234 return helpers.InputError(e, to.StringPtr("dpop jkt is required for dpop bound access tokens"))
235 }
236
237 nextTokenId := oauth.GenerateTokenId()
238 nextRefreshToken := oauth.GenerateRefreshToken()
239
240 now := time.Now()
241 eat := now.Add(constants.TokenMaxAge)
242
243 accessClaims := jwt.MapClaims{
244 "scope": oauthToken.Parameters.Scope,
245 "aud": s.config.Did,
246 "sub": oauthToken.Sub,
247 "iat": now.Unix(),
248 "exp": eat.Unix(),
249 "jti": nextTokenId,
250 "client_id": oauthToken.ClientId,
251 }
252
253 if oauthToken.Parameters.DpopJkt != nil {
254 accessClaims["cnf"] = *&oauthToken.Parameters.DpopJkt
255 }
256
257 accessToken := jwt.NewWithClaims(jwt.SigningMethodES256, accessClaims)
258 accessString, err := accessToken.SignedString(s.privateKey)
259 if err != nil {
260 return err
261 }
262
263 if err := s.db.Exec(ctx, "UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil {
264 logger.Error("error updating token", "error", err)
265 return helpers.ServerError(e, nil)
266 }
267
268 // prob not needed
269 tokenType := "Bearer"
270 if oauthToken.Parameters.DpopJkt != nil {
271 tokenType = "DPoP"
272 }
273
274 return e.JSON(200, OauthTokenResponse{
275 AccessToken: accessString,
276 RefreshToken: nextRefreshToken,
277 TokenType: tokenType,
278 Scope: oauthToken.Parameters.Scope,
279 ExpiresIn: int64(eat.Sub(time.Now()).Seconds()),
280 Sub: oauthToken.Sub,
281 })
282 }
283
284 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`grant type "%s" is not supported`, req.GrantType)))
285}