1package server
2
3import (
4 "net/url"
5 "strings"
6 "time"
7
8 "github.com/Azure/go-autorest/autorest/to"
9 "github.com/haileyok/cocoon/internal/helpers"
10 "github.com/haileyok/cocoon/oauth"
11 "github.com/haileyok/cocoon/oauth/provider"
12 "github.com/labstack/echo/v4"
13)
14
15func (s *Server) handleOauthAuthorizeGet(e echo.Context) error {
16 ctx := e.Request().Context()
17
18 reqUri := e.QueryParam("request_uri")
19 if reqUri == "" {
20 // render page for logged out dev
21 if s.config.Version == "dev" {
22 return e.Render(200, "authorize.html", map[string]any{
23 "Scopes": []string{"atproto", "transition:generic"},
24 "AppName": "DEV MODE AUTHORIZATION PAGE",
25 "Handle": "paula.cocoon.social",
26 "RequestUri": "",
27 })
28 }
29 return helpers.InputError(e, to.StringPtr("no request uri"))
30 }
31
32 repo, _, err := s.getSessionRepoOrErr(e)
33 if err != nil {
34 return e.Redirect(303, "/account/signin?"+e.QueryParams().Encode())
35 }
36
37 reqId, err := oauth.DecodeRequestUri(reqUri)
38 if err != nil {
39 return helpers.InputError(e, to.StringPtr(err.Error()))
40 }
41
42 var req provider.OauthAuthorizationRequest
43 if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&req).Error; err != nil {
44 return helpers.ServerError(e, to.StringPtr(err.Error()))
45 }
46
47 clientId := e.QueryParam("client_id")
48 if clientId != req.ClientId {
49 return helpers.InputError(e, to.StringPtr("client id does not match the client id for the supplied request"))
50 }
51
52 client, err := s.oauthProvider.ClientManager.GetClient(e.Request().Context(), req.ClientId)
53 if err != nil {
54 return helpers.ServerError(e, to.StringPtr(err.Error()))
55 }
56
57 scopes := strings.Split(req.Parameters.Scope, " ")
58 appName := client.Metadata.ClientName
59
60 data := map[string]any{
61 "Scopes": scopes,
62 "AppName": appName,
63 "RequestUri": reqUri,
64 "QueryParams": e.QueryParams().Encode(),
65 "Handle": repo.Actor.Handle,
66 }
67
68 return e.Render(200, "authorize.html", data)
69}
70
// OauthAuthorizePostRequest is the form body submitted by the authorize
// (consent) page and bound by handleOauthAuthorizePost.
type OauthAuthorizePostRequest struct {
	// RequestUri is the request_uri of the pending authorization request the
	// user is answering.
	RequestUri string `form:"request_uri"`
	// AcceptOrRejct carries the user's decision. handleOauthAuthorizePost
	// compares it against "reject"; any other value, including an empty one,
	// is handled as acceptance.
	// NOTE(review): the field name is misspelled ("Rejct"); renaming it is left
	// for a change that can update every reference in the package.
	AcceptOrRejct string `form:"accept_or_reject"`
}
75
76func (s *Server) handleOauthAuthorizePost(e echo.Context) error {
77 ctx := e.Request().Context()
78 logger := s.logger.With("name", "handleOauthAuthorizePost")
79
80 repo, _, err := s.getSessionRepoOrErr(e)
81 if err != nil {
82 return e.Redirect(303, "/account/signin")
83 }
84
85 var req OauthAuthorizePostRequest
86 if err := e.Bind(&req); err != nil {
87 logger.Error("error binding authorize post request", "error", err)
88 return helpers.InputError(e, nil)
89 }
90
91 reqId, err := oauth.DecodeRequestUri(req.RequestUri)
92 if err != nil {
93 return helpers.InputError(e, to.StringPtr(err.Error()))
94 }
95
96 var authReq provider.OauthAuthorizationRequest
97 if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&authReq).Error; err != nil {
98 return helpers.ServerError(e, to.StringPtr(err.Error()))
99 }
100
101 client, err := s.oauthProvider.ClientManager.GetClient(e.Request().Context(), authReq.ClientId)
102 if err != nil {
103 return helpers.ServerError(e, to.StringPtr(err.Error()))
104 }
105
106 // TODO: figure out how im supposed to actually redirect
107 if req.AcceptOrRejct == "reject" {
108 return e.Redirect(303, client.Metadata.ClientURI)
109 }
110
111 if time.Now().After(authReq.ExpiresAt) {
112 return helpers.InputError(e, to.StringPtr("the request has expired"))
113 }
114
115 if authReq.Sub != nil || authReq.Code != nil {
116 return helpers.InputError(e, to.StringPtr("this request was already authorized"))
117 }
118
119 code := oauth.GenerateCode()
120
121 if err := s.db.Exec(ctx, "UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, e.RealIP(), reqId).Error; err != nil {
122 logger.Error("error updating authorization request", "error", err)
123 return helpers.ServerError(e, nil)
124 }
125
126 q := url.Values{}
127 q.Set("state", authReq.Parameters.State)
128 q.Set("iss", "https://"+s.config.Hostname)
129 q.Set("code", code)
130
131 hashOrQuestion := "?"
132 if authReq.ClientAuth.Method != "private_key_jwt" {
133 hashOrQuestion = "#"
134 }
135
136 return e.Redirect(303, authReq.Parameters.RedirectURI+hashOrQuestion+q.Encode())
137}