forked from hailey.at/cocoon
An atproto PDS written in Go
at main 9.3 kB view raw
1package server 2 3import ( 4 "bytes" 5 "crypto/sha256" 6 "encoding/base64" 7 "errors" 8 "fmt" 9 "slices" 10 "time" 11 12 "github.com/Azure/go-autorest/autorest/to" 13 "github.com/golang-jwt/jwt/v4" 14 "github.com/haileyok/cocoon/internal/helpers" 15 "github.com/haileyok/cocoon/oauth" 16 "github.com/haileyok/cocoon/oauth/constants" 17 "github.com/haileyok/cocoon/oauth/dpop" 18 "github.com/haileyok/cocoon/oauth/provider" 19 "github.com/labstack/echo/v4" 20) 21 22type OauthTokenRequest struct { 23 provider.AuthenticateClientRequestBase 24 GrantType string `form:"grant_type" json:"grant_type"` 25 Code *string `form:"code" json:"code,omitempty"` 26 CodeVerifier *string `form:"code_verifier" json:"code_verifier,omitempty"` 27 RedirectURI *string `form:"redirect_uri" json:"redirect_uri,omitempty"` 28 RefreshToken *string `form:"refresh_token" json:"refresh_token,omitempty"` 29} 30 31type OauthTokenResponse struct { 32 AccessToken string `json:"access_token"` 33 TokenType string `json:"token_type"` 34 RefreshToken string `json:"refresh_token"` 35 Scope string `json:"scope"` 36 ExpiresIn int64 `json:"expires_in"` 37 Sub string `json:"sub"` 38} 39 40func (s *Server) handleOauthToken(e echo.Context) error { 41 ctx := e.Request().Context() 42 logger := s.logger.With("name", "handleOauthToken") 43 44 var req OauthTokenRequest 45 if err := e.Bind(&req); err != nil { 46 logger.Error("error binding token request", "error", err) 47 return helpers.ServerError(e, nil) 48 } 49 50 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, e.Request().URL.String(), e.Request().Header, nil) 51 if err != nil { 52 if errors.Is(err, dpop.ErrUseDpopNonce) { 53 nonce := s.oauthProvider.NextNonce() 54 if nonce != "" { 55 e.Response().Header().Set("DPoP-Nonce", nonce) 56 e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 57 } 58 return e.JSON(400, map[string]string{ 59 "error": "use_dpop_nonce", 60 }) 61 } 62 logger.Error("error getting dpop proof", "error", err) 63 return helpers.InputError(e, nil) 64 } 65 66 client, clientAuth, err := s.oauthProvider.AuthenticateClient(e.Request().Context(), req.AuthenticateClientRequestBase, proof, &provider.AuthenticateClientOptions{ 67 AllowMissingDpopProof: true, 68 }) 69 if err != nil { 70 logger.Error("error authenticating client", "client_id", req.ClientID, "error", err) 71 return helpers.InputError(e, to.StringPtr(err.Error())) 72 } 73 74 // TODO: this should come from an oauth provier config 75 if !slices.Contains([]string{"authorization_code", "refresh_token"}, req.GrantType) { 76 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the server`, req.GrantType))) 77 } 78 79 if !slices.Contains(client.Metadata.GrantTypes, req.GrantType) { 80 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the client`, req.GrantType))) 81 } 82 83 if req.GrantType == "authorization_code" { 84 if req.Code == nil { 85 return helpers.InputError(e, to.StringPtr(`"code" is required"`)) 86 } 87 88 var authReq provider.OauthAuthorizationRequest 89 // get the lil guy and delete him 90 if err := s.db.Raw(ctx, "DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil { 91 logger.Error("error finding authorization request", "error", err) 92 return helpers.ServerError(e, nil) 93 } 94 95 if req.RedirectURI == nil || *req.RedirectURI != authReq.Parameters.RedirectURI { 96 return helpers.InputError(e, to.StringPtr(`"redirect_uri" mismatch`)) 97 } 98 99 if authReq.Parameters.CodeChallenge != nil { 100 if req.CodeVerifier == nil { 101 return helpers.InputError(e, to.StringPtr(`"code_verifier" is required`)) 102 } 103 104 if len(*req.CodeVerifier) < 43 { 105 return helpers.InputError(e, to.StringPtr(`"code_verifier" is too short`)) 106 } 107 108 switch *&authReq.Parameters.CodeChallengeMethod { 109 case "", "plain": 110 if authReq.Parameters.CodeChallenge != req.CodeVerifier { 111 return helpers.InputError(e, to.StringPtr("invalid code_verifier")) 112 } 113 case "S256": 114 inputChal, err := base64.RawURLEncoding.DecodeString(*authReq.Parameters.CodeChallenge) 115 if err != nil { 116 logger.Error("error decoding code challenge", "error", err) 117 return helpers.ServerError(e, nil) 118 } 119 120 h := sha256.New() 121 h.Write([]byte(*req.CodeVerifier)) 122 compdChal := h.Sum(nil) 123 124 if !bytes.Equal(inputChal, compdChal) { 125 return helpers.InputError(e, to.StringPtr("invalid code_verifier")) 126 } 127 default: 128 return helpers.InputError(e, to.StringPtr("unsupported code_challenge_method "+*&authReq.Parameters.CodeChallengeMethod)) 129 } 130 } else if req.CodeVerifier != nil { 131 return helpers.InputError(e, to.StringPtr("code_challenge parameter wasn't provided")) 132 } 133 134 repo, err := s.getRepoActorByDid(ctx, *authReq.Sub) 135 if err != nil { 136 helpers.InputError(e, to.StringPtr("unable to find actor")) 137 } 138 139 now := time.Now() 140 eat := now.Add(constants.TokenMaxAge) 141 id := oauth.GenerateTokenId() 142 143 refreshToken := oauth.GenerateRefreshToken() 144 145 accessClaims := jwt.MapClaims{ 146 "scope": authReq.Parameters.Scope, 147 "aud": s.config.Did, 148 "sub": repo.Repo.Did, 149 "iat": now.Unix(), 150 "exp": eat.Unix(), 151 "jti": id, 152 "client_id": authReq.ClientId, 153 } 154 155 if authReq.Parameters.DpopJkt != nil { 156 accessClaims["cnf"] = *authReq.Parameters.DpopJkt 157 } 158 159 accessToken := jwt.NewWithClaims(jwt.SigningMethodES256, accessClaims) 160 accessString, err := accessToken.SignedString(s.privateKey) 161 if err != nil { 162 return err 163 } 164 165 if err := s.db.Create(ctx, &provider.OauthToken{ 166 ClientId: authReq.ClientId, 167 ClientAuth: *clientAuth, 168 Parameters: authReq.Parameters, 169 ExpiresAt: eat, 170 DeviceId: "", 171 Sub: repo.Repo.Did, 172 Code: *authReq.Code, 173 Token: accessString, 174 RefreshToken: refreshToken, 175 Ip: authReq.Ip, 176 }, nil).Error; err != nil { 177 logger.Error("error creating token in db", "error", err) 178 return helpers.ServerError(e, nil) 179 } 180 181 // prob not needed 182 tokenType := "Bearer" 183 if authReq.Parameters.DpopJkt != nil { 184 tokenType = "DPoP" 185 } 186 187 e.Response().Header().Set("content-type", "application/json") 188 189 return e.JSON(200, OauthTokenResponse{ 190 AccessToken: accessString, 191 RefreshToken: refreshToken, 192 TokenType: tokenType, 193 Scope: authReq.Parameters.Scope, 194 ExpiresIn: int64(eat.Sub(time.Now()).Seconds()), 195 Sub: repo.Repo.Did, 196 }) 197 } 198 199 if req.GrantType == "refresh_token" { 200 if req.RefreshToken == nil { 201 return helpers.InputError(e, to.StringPtr(`"refresh_token" is required`)) 202 } 203 204 var oauthToken provider.OauthToken 205 if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil { 206 logger.Error("error finding oauth token by refresh token", "error", err, "refresh_token", req.RefreshToken) 207 return helpers.ServerError(e, nil) 208 } 209 210 if client.Metadata.ClientID != oauthToken.ClientId { 211 return helpers.InputError(e, to.StringPtr(`"client_id" mismatch`)) 212 } 213 214 if clientAuth.Method != oauthToken.ClientAuth.Method { 215 return helpers.InputError(e, to.StringPtr(`"client authentication method mismatch`)) 216 } 217 218 if *oauthToken.Parameters.DpopJkt != proof.JKT { 219 return helpers.InputError(e, to.StringPtr("dpop proof does not match expected jkt")) 220 } 221 222 ageRes := oauth.GetSessionAgeFromToken(oauthToken) 223 224 if ageRes.SessionExpired { 225 return helpers.InputError(e, to.StringPtr("Session expired")) 226 } 227 228 if ageRes.RefreshExpired { 229 return helpers.InputError(e, to.StringPtr("Refresh token expired")) 230 } 231 232 if client.Metadata.DpopBoundAccessTokens && oauthToken.Parameters.DpopJkt == nil { 233 // why? ref impl 234 return helpers.InputError(e, to.StringPtr("dpop jkt is required for dpop bound access tokens")) 235 } 236 237 nextTokenId := oauth.GenerateTokenId() 238 nextRefreshToken := oauth.GenerateRefreshToken() 239 240 now := time.Now() 241 eat := now.Add(constants.TokenMaxAge) 242 243 accessClaims := jwt.MapClaims{ 244 "scope": oauthToken.Parameters.Scope, 245 "aud": s.config.Did, 246 "sub": oauthToken.Sub, 247 "iat": now.Unix(), 248 "exp": eat.Unix(), 249 "jti": nextTokenId, 250 "client_id": oauthToken.ClientId, 251 } 252 253 if oauthToken.Parameters.DpopJkt != nil { 254 accessClaims["cnf"] = *&oauthToken.Parameters.DpopJkt 255 } 256 257 accessToken := jwt.NewWithClaims(jwt.SigningMethodES256, accessClaims) 258 accessString, err := accessToken.SignedString(s.privateKey) 259 if err != nil { 260 return err 261 } 262 263 if err := s.db.Exec(ctx, "UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil { 264 logger.Error("error updating token", "error", err) 265 return helpers.ServerError(e, nil) 266 } 267 268 // prob not needed 269 tokenType := "Bearer" 270 if oauthToken.Parameters.DpopJkt != nil { 271 tokenType = "DPoP" 272 } 273 274 return e.JSON(200, OauthTokenResponse{ 275 AccessToken: accessString, 276 RefreshToken: nextRefreshToken, 277 TokenType: tokenType, 278 Scope: oauthToken.Parameters.Scope, 279 ExpiresIn: int64(eat.Sub(time.Now()).Seconds()), 280 Sub: oauthToken.Sub, 281 }) 282 } 283 284 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`grant type "%s" is not supported`, req.GrantType))) 285}