1package server
2
3import (
4 "fmt"
5 "net/url"
6 "strings"
7 "time"
8
9 "github.com/Azure/go-autorest/autorest/to"
10 "github.com/haileyok/cocoon/internal/helpers"
11 "github.com/haileyok/cocoon/oauth"
12 "github.com/haileyok/cocoon/oauth/constants"
13 "github.com/haileyok/cocoon/oauth/provider"
14 "github.com/labstack/echo/v4"
15)
16
// HandleOauthAuthorizeGetInput holds the query parameters that
// handleOauthAuthorizeGet binds first when serving the authorize page.
type HandleOauthAuthorizeGetInput struct {
	// RequestUri is the request_uri query parameter. When it is empty the
	// handler falls back to binding the query as a provider.ParRequest.
	RequestUri string `query:"request_uri"`
}
20
21func (s *Server) handleOauthAuthorizeGet(e echo.Context) error {
22 ctx := e.Request().Context()
23
24 logger := s.logger.With("name", "handleOauthAuthorizeGet")
25
26 var input HandleOauthAuthorizeGetInput
27 if err := e.Bind(&input); err != nil {
28 logger.Error("error binding request", "err", err)
29 return fmt.Errorf("error binding request")
30 }
31
32 var reqId string
33 if input.RequestUri != "" {
34 id, err := oauth.DecodeRequestUri(input.RequestUri)
35 if err != nil {
36 logger.Error("no request uri found in input", "url", e.Request().URL.String())
37 return helpers.InputError(e, to.StringPtr("no request uri"))
38 }
39 reqId = id
40 } else {
41 var parRequest provider.ParRequest
42 if err := e.Bind(&parRequest); err != nil {
43 s.logger.Error("error binding for standard auth request", "error", err)
44 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
45 }
46
47 if err := e.Validate(parRequest); err != nil {
48 // render page for logged out dev
49 if s.config.Version == "dev" && parRequest.ClientID == "" {
50 return e.Render(200, "authorize.html", map[string]any{
51 "Scopes": []string{"atproto", "transition:generic"},
52 "AppName": "DEV MODE AUTHORIZATION PAGE",
53 "Handle": "paula.cocoon.social",
54 "RequestUri": "",
55 })
56 }
57 return helpers.InputError(e, to.StringPtr("no request uri and invalid parameters"))
58 }
59
60 client, clientAuth, err := s.oauthProvider.AuthenticateClient(ctx, parRequest.AuthenticateClientRequestBase, nil, &provider.AuthenticateClientOptions{
61 AllowMissingDpopProof: true,
62 })
63 if err != nil {
64 s.logger.Error("error authenticating client in standard request", "client_id", parRequest.ClientID, "error", err)
65 return helpers.ServerError(e, to.StringPtr(err.Error()))
66 }
67
68 if parRequest.DpopJkt == nil {
69 if client.Metadata.DpopBoundAccessTokens {
70 }
71 } else {
72 if !client.Metadata.DpopBoundAccessTokens {
73 msg := "dpop bound access tokens are not enabled for this client"
74 return helpers.InputError(e, &msg)
75 }
76 }
77
78 eat := time.Now().Add(constants.ParExpiresIn)
79 id := oauth.GenerateRequestId()
80
81 authRequest := &provider.OauthAuthorizationRequest{
82 RequestId: id,
83 ClientId: client.Metadata.ClientID,
84 ClientAuth: *clientAuth,
85 Parameters: parRequest,
86 ExpiresAt: eat,
87 }
88
89 if err := s.db.Create(ctx, authRequest, nil).Error; err != nil {
90 s.logger.Error("error creating auth request in db", "error", err)
91 return helpers.ServerError(e, nil)
92 }
93
94 input.RequestUri = oauth.EncodeRequestUri(id)
95 reqId = id
96
97 }
98
99 repo, _, err := s.getSessionRepoOrErr(e)
100 if err != nil {
101 return e.Redirect(303, "/account/signin?"+e.QueryParams().Encode())
102 }
103
104 var req provider.OauthAuthorizationRequest
105 if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&req).Error; err != nil {
106 return helpers.ServerError(e, to.StringPtr(err.Error()))
107 }
108
109 clientId := e.QueryParam("client_id")
110 if clientId != req.ClientId {
111 return helpers.InputError(e, to.StringPtr("client id does not match the client id for the supplied request"))
112 }
113
114 client, err := s.oauthProvider.ClientManager.GetClient(e.Request().Context(), req.ClientId)
115 if err != nil {
116 return helpers.ServerError(e, to.StringPtr(err.Error()))
117 }
118
119 scopes := strings.Split(req.Parameters.Scope, " ")
120 appName := client.Metadata.ClientName
121
122 data := map[string]any{
123 "Scopes": scopes,
124 "AppName": appName,
125 "RequestUri": input.RequestUri,
126 "QueryParams": e.QueryParams().Encode(),
127 "Handle": repo.Actor.Handle,
128 }
129
130 return e.Render(200, "authorize.html", data)
131}
132
// OauthAuthorizePostRequest is the form payload bound by
// handleOauthAuthorizePost when the consent form is submitted.
type OauthAuthorizePostRequest struct {
	// RequestUri is the request_uri form value; the handler decodes it into
	// the id of the stored authorization request.
	RequestUri string `form:"request_uri"`

	// AcceptOrRejct is the accept_or_reject form value. The handler compares
	// it against "reject"; every other value proceeds down the accept path.
	// NOTE(review): the field name is misspelled ("Rejct") but it is exported,
	// so it is left unchanged here.
	AcceptOrRejct string `form:"accept_or_reject"`
}
137
138func (s *Server) handleOauthAuthorizePost(e echo.Context) error {
139 ctx := e.Request().Context()
140 logger := s.logger.With("name", "handleOauthAuthorizePost")
141
142 repo, _, err := s.getSessionRepoOrErr(e)
143 if err != nil {
144 return e.Redirect(303, "/account/signin")
145 }
146
147 var req OauthAuthorizePostRequest
148 if err := e.Bind(&req); err != nil {
149 logger.Error("error binding authorize post request", "error", err)
150 return helpers.InputError(e, nil)
151 }
152
153 reqId, err := oauth.DecodeRequestUri(req.RequestUri)
154 if err != nil {
155 return helpers.InputError(e, to.StringPtr(err.Error()))
156 }
157
158 var authReq provider.OauthAuthorizationRequest
159 if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&authReq).Error; err != nil {
160 return helpers.ServerError(e, to.StringPtr(err.Error()))
161 }
162
163 client, err := s.oauthProvider.ClientManager.GetClient(e.Request().Context(), authReq.ClientId)
164 if err != nil {
165 return helpers.ServerError(e, to.StringPtr(err.Error()))
166 }
167
168 // TODO: figure out how im supposed to actually redirect
169 if req.AcceptOrRejct == "reject" {
170 return e.Redirect(303, client.Metadata.ClientURI)
171 }
172
173 if time.Now().After(authReq.ExpiresAt) {
174 return helpers.InputError(e, to.StringPtr("the request has expired"))
175 }
176
177 if authReq.Sub != nil || authReq.Code != nil {
178 return helpers.InputError(e, to.StringPtr("this request was already authorized"))
179 }
180
181 code := oauth.GenerateCode()
182
183 if err := s.db.Exec(ctx, "UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, e.RealIP(), reqId).Error; err != nil {
184 logger.Error("error updating authorization request", "error", err)
185 return helpers.ServerError(e, nil)
186 }
187
188 q := url.Values{}
189 q.Set("state", authReq.Parameters.State)
190 q.Set("iss", "https://"+s.config.Hostname)
191 q.Set("code", code)
192
193 hashOrQuestion := "?"
194 if authReq.Parameters.ResponseMode != nil {
195 switch *authReq.Parameters.ResponseMode {
196 case "fragment":
197 hashOrQuestion = "#"
198 case "query":
199 // do nothing
200 break
201 default:
202 if authReq.Parameters.ResponseType != "code" {
203 hashOrQuestion = "#"
204 }
205 }
206 } else {
207 if authReq.Parameters.ResponseType != "code" {
208 hashOrQuestion = "#"
209 }
210 }
211
212 return e.Redirect(303, authReq.Parameters.RedirectURI+hashOrQuestion+q.Encode())
213}