forked from hailey.at/cocoon
An atproto Personal Data Server (PDS) written in Go
1package server 2 3import ( 4 "net/url" 5 "strings" 6 "time" 7 8 "github.com/Azure/go-autorest/autorest/to" 9 "github.com/haileyok/cocoon/internal/helpers" 10 "github.com/haileyok/cocoon/oauth" 11 "github.com/haileyok/cocoon/oauth/provider" 12 "github.com/labstack/echo/v4" 13) 14 15func (s *Server) handleOauthAuthorizeGet(e echo.Context) error { 16 ctx := e.Request().Context() 17 18 reqUri := e.QueryParam("request_uri") 19 if reqUri == "" { 20 // render page for logged out dev 21 if s.config.Version == "dev" { 22 return e.Render(200, "authorize.html", map[string]any{ 23 "Scopes": []string{"atproto", "transition:generic"}, 24 "AppName": "DEV MODE AUTHORIZATION PAGE", 25 "Handle": "paula.cocoon.social", 26 "RequestUri": "", 27 }) 28 } 29 return helpers.InputError(e, to.StringPtr("no request uri")) 30 } 31 32 repo, _, err := s.getSessionRepoOrErr(e) 33 if err != nil { 34 return e.Redirect(303, "/account/signin?"+e.QueryParams().Encode()) 35 } 36 37 reqId, err := oauth.DecodeRequestUri(reqUri) 38 if err != nil { 39 return helpers.InputError(e, to.StringPtr(err.Error())) 40 } 41 42 var req provider.OauthAuthorizationRequest 43 if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&req).Error; err != nil { 44 return helpers.ServerError(e, to.StringPtr(err.Error())) 45 } 46 47 clientId := e.QueryParam("client_id") 48 if clientId != req.ClientId { 49 return helpers.InputError(e, to.StringPtr("client id does not match the client id for the supplied request")) 50 } 51 52 client, err := s.oauthProvider.ClientManager.GetClient(e.Request().Context(), req.ClientId) 53 if err != nil { 54 return helpers.ServerError(e, to.StringPtr(err.Error())) 55 } 56 57 scopes := strings.Split(req.Parameters.Scope, " ") 58 appName := client.Metadata.ClientName 59 60 data := map[string]any{ 61 "Scopes": scopes, 62 "AppName": appName, 63 "RequestUri": reqUri, 64 "QueryParams": e.QueryParams().Encode(), 65 "Handle": repo.Actor.Handle, 66 } 67 68 
return e.Render(200, "authorize.html", data) 69} 70 71type OauthAuthorizePostRequest struct { 72 RequestUri string `form:"request_uri"` 73 AcceptOrRejct string `form:"accept_or_reject"` 74} 75 76func (s *Server) handleOauthAuthorizePost(e echo.Context) error { 77 ctx := e.Request().Context() 78 logger := s.logger.With("name", "handleOauthAuthorizePost") 79 80 repo, _, err := s.getSessionRepoOrErr(e) 81 if err != nil { 82 return e.Redirect(303, "/account/signin") 83 } 84 85 var req OauthAuthorizePostRequest 86 if err := e.Bind(&req); err != nil { 87 logger.Error("error binding authorize post request", "error", err) 88 return helpers.InputError(e, nil) 89 } 90 91 reqId, err := oauth.DecodeRequestUri(req.RequestUri) 92 if err != nil { 93 return helpers.InputError(e, to.StringPtr(err.Error())) 94 } 95 96 var authReq provider.OauthAuthorizationRequest 97 if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&authReq).Error; err != nil { 98 return helpers.ServerError(e, to.StringPtr(err.Error())) 99 } 100 101 client, err := s.oauthProvider.ClientManager.GetClient(e.Request().Context(), authReq.ClientId) 102 if err != nil { 103 return helpers.ServerError(e, to.StringPtr(err.Error())) 104 } 105 106 // TODO: figure out how im supposed to actually redirect 107 if req.AcceptOrRejct == "reject" { 108 return e.Redirect(303, client.Metadata.ClientURI) 109 } 110 111 if time.Now().After(authReq.ExpiresAt) { 112 return helpers.InputError(e, to.StringPtr("the request has expired")) 113 } 114 115 if authReq.Sub != nil || authReq.Code != nil { 116 return helpers.InputError(e, to.StringPtr("this request was already authorized")) 117 } 118 119 code := oauth.GenerateCode() 120 121 if err := s.db.Exec(ctx, "UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? 
WHERE request_id = ?", nil, repo.Repo.Did, code, true, e.RealIP(), reqId).Error; err != nil { 122 logger.Error("error updating authorization request", "error", err) 123 return helpers.ServerError(e, nil) 124 } 125 126 q := url.Values{} 127 q.Set("state", authReq.Parameters.State) 128 q.Set("iss", "https://"+s.config.Hostname) 129 q.Set("code", code) 130 131 hashOrQuestion := "?" 132 if authReq.ClientAuth.Method != "private_key_jwt" { 133 hashOrQuestion = "#" 134 } 135 136 return e.Redirect(303, authReq.Parameters.RedirectURI+hashOrQuestion+q.Encode()) 137}