An atproto PDS written in Go
at main 6.5 kB view raw
1package server 2 3import ( 4 "fmt" 5 "net/url" 6 "strings" 7 "time" 8 9 "github.com/Azure/go-autorest/autorest/to" 10 "github.com/haileyok/cocoon/internal/helpers" 11 "github.com/haileyok/cocoon/oauth" 12 "github.com/haileyok/cocoon/oauth/constants" 13 "github.com/haileyok/cocoon/oauth/provider" 14 "github.com/labstack/echo/v4" 15) 16 17type HandleOauthAuthorizeGetInput struct { 18 RequestUri string `query:"request_uri"` 19} 20 21func (s *Server) handleOauthAuthorizeGet(e echo.Context) error { 22 ctx := e.Request().Context() 23 24 logger := s.logger.With("name", "handleOauthAuthorizeGet") 25 26 var input HandleOauthAuthorizeGetInput 27 if err := e.Bind(&input); err != nil { 28 logger.Error("error binding request", "err", err) 29 return fmt.Errorf("error binding request") 30 } 31 32 var reqId string 33 if input.RequestUri != "" { 34 id, err := oauth.DecodeRequestUri(input.RequestUri) 35 if err != nil { 36 logger.Error("no request uri found in input", "url", e.Request().URL.String()) 37 return helpers.InputError(e, to.StringPtr("no request uri")) 38 } 39 reqId = id 40 } else { 41 var parRequest provider.ParRequest 42 if err := e.Bind(&parRequest); err != nil { 43 s.logger.Error("error binding for standard auth request", "error", err) 44 return helpers.InputError(e, to.StringPtr("InvalidRequest")) 45 } 46 47 if err := e.Validate(parRequest); err != nil { 48 // render page for logged out dev 49 if s.config.Version == "dev" && parRequest.ClientID == "" { 50 return e.Render(200, "authorize.html", map[string]any{ 51 "Scopes": []string{"atproto", "transition:generic"}, 52 "AppName": "DEV MODE AUTHORIZATION PAGE", 53 "Handle": "paula.cocoon.social", 54 "RequestUri": "", 55 }) 56 } 57 return helpers.InputError(e, to.StringPtr("no request uri and invalid parameters")) 58 } 59 60 client, clientAuth, err := s.oauthProvider.AuthenticateClient(ctx, parRequest.AuthenticateClientRequestBase, nil, &provider.AuthenticateClientOptions{ 61 AllowMissingDpopProof: true, 62 }) 63 if err != nil { 64 s.logger.Error("error authenticating client in standard request", "client_id", parRequest.ClientID, "error", err) 65 return helpers.ServerError(e, to.StringPtr(err.Error())) 66 } 67 68 if parRequest.DpopJkt == nil { 69 if client.Metadata.DpopBoundAccessTokens { 70 } 71 } else { 72 if !client.Metadata.DpopBoundAccessTokens { 73 msg := "dpop bound access tokens are not enabled for this client" 74 return helpers.InputError(e, &msg) 75 } 76 } 77 78 eat := time.Now().Add(constants.ParExpiresIn) 79 id := oauth.GenerateRequestId() 80 81 authRequest := &provider.OauthAuthorizationRequest{ 82 RequestId: id, 83 ClientId: client.Metadata.ClientID, 84 ClientAuth: *clientAuth, 85 Parameters: parRequest, 86 ExpiresAt: eat, 87 } 88 89 if err := s.db.Create(ctx, authRequest, nil).Error; err != nil { 90 s.logger.Error("error creating auth request in db", "error", err) 91 return helpers.ServerError(e, nil) 92 } 93 94 input.RequestUri = oauth.EncodeRequestUri(id) 95 reqId = id 96 97 } 98 99 repo, _, err := s.getSessionRepoOrErr(e) 100 if err != nil { 101 return e.Redirect(303, "/account/signin?"+e.QueryParams().Encode()) 102 } 103 104 var req provider.OauthAuthorizationRequest 105 if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&req).Error; err != nil { 106 return helpers.ServerError(e, to.StringPtr(err.Error())) 107 } 108 109 clientId := e.QueryParam("client_id") 110 if clientId != req.ClientId { 111 return helpers.InputError(e, to.StringPtr("client id does not match the client id for the supplied request")) 112 } 113 114 client, err := s.oauthProvider.ClientManager.GetClient(e.Request().Context(), req.ClientId) 115 if err != nil { 116 return helpers.ServerError(e, to.StringPtr(err.Error())) 117 } 118 119 scopes := strings.Split(req.Parameters.Scope, " ") 120 appName := client.Metadata.ClientName 121 122 data := map[string]any{ 123 "Scopes": scopes, 124 "AppName": appName, 125 "RequestUri": input.RequestUri, 126 "QueryParams": e.QueryParams().Encode(), 127 "Handle": repo.Actor.Handle, 128 } 129 130 return e.Render(200, "authorize.html", data) 131} 132 133type OauthAuthorizePostRequest struct { 134 RequestUri string `form:"request_uri"` 135 AcceptOrRejct string `form:"accept_or_reject"` 136} 137 138func (s *Server) handleOauthAuthorizePost(e echo.Context) error { 139 ctx := e.Request().Context() 140 logger := s.logger.With("name", "handleOauthAuthorizePost") 141 142 repo, _, err := s.getSessionRepoOrErr(e) 143 if err != nil { 144 return e.Redirect(303, "/account/signin") 145 } 146 147 var req OauthAuthorizePostRequest 148 if err := e.Bind(&req); err != nil { 149 logger.Error("error binding authorize post request", "error", err) 150 return helpers.InputError(e, nil) 151 } 152 153 reqId, err := oauth.DecodeRequestUri(req.RequestUri) 154 if err != nil { 155 return helpers.InputError(e, to.StringPtr(err.Error())) 156 } 157 158 var authReq provider.OauthAuthorizationRequest 159 if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&authReq).Error; err != nil { 160 return helpers.ServerError(e, to.StringPtr(err.Error())) 161 } 162 163 client, err := s.oauthProvider.ClientManager.GetClient(e.Request().Context(), authReq.ClientId) 164 if err != nil { 165 return helpers.ServerError(e, to.StringPtr(err.Error())) 166 } 167 168 // TODO: figure out how im supposed to actually redirect 169 if req.AcceptOrRejct == "reject" { 170 return e.Redirect(303, client.Metadata.ClientURI) 171 } 172 173 if time.Now().After(authReq.ExpiresAt) { 174 return helpers.InputError(e, to.StringPtr("the request has expired")) 175 } 176 177 if authReq.Sub != nil || authReq.Code != nil { 178 return helpers.InputError(e, to.StringPtr("this request was already authorized")) 179 } 180 181 code := oauth.GenerateCode() 182 183 if err := s.db.Exec(ctx, "UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, e.RealIP(), reqId).Error; err != nil { 184 logger.Error("error updating authorization request", "error", err) 185 return helpers.ServerError(e, nil) 186 } 187 188 q := url.Values{} 189 q.Set("state", authReq.Parameters.State) 190 q.Set("iss", "https://"+s.config.Hostname) 191 q.Set("code", code) 192 193 hashOrQuestion := "?" 194 if authReq.Parameters.ResponseMode != nil { 195 switch *authReq.Parameters.ResponseMode { 196 case "fragment": 197 hashOrQuestion = "#" 198 case "query": 199 // do nothing 200 break 201 default: 202 if authReq.Parameters.ResponseType != "code" { 203 hashOrQuestion = "#" 204 } 205 } 206 } else { 207 if authReq.Parameters.ResponseType != "code" { 208 hashOrQuestion = "#" 209 } 210 } 211 212 return e.Redirect(303, authReq.Parameters.RedirectURI+hashOrQuestion+q.Encode()) 213}