Live video on the AT Protocol
1package oproxy
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "net/http"
8 "net/url"
9 "time"
10
11 oauth "github.com/haileyok/atproto-oauth-golang"
12 "github.com/haileyok/atproto-oauth-golang/helpers"
13 "github.com/labstack/echo/v4"
14 "go.opentelemetry.io/otel"
15)
16
17func (o *OProxy) HandleOAuthAuthorize(c echo.Context) error {
18 ctx, span := otel.Tracer("server").Start(c.Request().Context(), "HandleOAuthAuthorize")
19 defer span.End()
20 c.Response().Header().Set("Access-Control-Allow-Origin", "*")
21 requestURI := c.QueryParam("request_uri")
22 if requestURI == "" {
23 return echo.NewHTTPError(http.StatusBadRequest, "request_uri is required")
24 }
25 clientID := c.QueryParam("client_id")
26 if clientID == "" {
27 return echo.NewHTTPError(http.StatusBadRequest, "client_id is required")
28 }
29 redirectURL, err := o.Authorize(ctx, requestURI, clientID)
30 if err != nil {
31 // we're a redirect; if we fail we need to send the user back
32 jkt, _, err := parseURN(requestURI)
33 if err != nil {
34 return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to parse URN: %s", err))
35 }
36
37 session, err := o.loadOAuthSession(jkt)
38 if err != nil {
39 return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to load OAuth session jkt=%s: %s", jkt, err))
40 }
41
42 u, err := url.Parse(session.DownstreamRedirectURI)
43 if err != nil {
44 return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to parse downstream redirect URI: %s", err))
45 }
46 q := u.Query()
47 q.Set("error", "authorize_failed")
48 q.Set("error_description", err.Error())
49 u.RawQuery = q.Encode()
50 return c.Redirect(http.StatusTemporaryRedirect, u.String())
51 }
52 return c.Redirect(http.StatusTemporaryRedirect, redirectURL)
53}
54
55// downstream --> upstream transition; attempt to send user to the upstream auth server
56func (o *OProxy) Authorize(ctx context.Context, requestURI, clientID string) (string, *echo.HTTPError) {
57 downstreamMeta, err := o.GetDownstreamMetadata("")
58 if err != nil {
59 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to get downstream metadata: %s", err))
60 }
61 if downstreamMeta.ClientID != clientID {
62 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("client ID mismatch: %s != %s", downstreamMeta.ClientID, clientID))
63 }
64
65 jkt, _, err := parseURN(requestURI)
66 if err != nil {
67 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to parse URN: %s", err))
68 }
69
70 session, err := o.loadOAuthSession(jkt)
71 if err != nil {
72 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to load OAuth session jkt=%s: %s", jkt, err))
73 }
74
75 if session == nil {
76 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("no session found for jkt=%s", jkt))
77 }
78
79 if session.Status() != OAuthSessionStatePARCreated {
80 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("session is not in par-created state: %s", session.Status()))
81 }
82
83 if session.DownstreamPARRequestURI != requestURI {
84 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("request URI mismatch: %s != %s", session.DownstreamPARRequestURI, requestURI))
85 }
86
87 now := time.Now()
88 session.DownstreamPARUsedAt = &now
89 err = o.updateOAuthSession(jkt, session)
90 if err != nil {
91 return "", echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to update OAuth session: %s", err))
92 }
93
94 upstreamMeta := o.GetUpstreamMetadata()
95 oclient, err := oauth.NewClient(oauth.ClientArgs{
96 ClientJwk: o.upstreamJWK,
97 ClientId: upstreamMeta.ClientID,
98 RedirectUri: upstreamMeta.RedirectURIs[0],
99 })
100 if err != nil {
101 return "", echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to create OAuth client: %s", err))
102 }
103
104 did, err := ResolveHandle(ctx, session.Handle)
105 if err != nil {
106 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to resolve handle '%s': %s", session.DID, err))
107 }
108
109 service, err := ResolveService(ctx, did)
110 if err != nil {
111 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to resolve service for DID '%s': %s", did, err))
112 }
113
114 authserver, err := oclient.ResolvePdsAuthServer(ctx, service)
115 if err != nil {
116 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to resolve PDS auth server for service '%s': %s", service, err))
117 }
118
119 authmeta, err := oclient.FetchAuthServerMetadata(ctx, authserver)
120 if err != nil {
121 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to fetch auth server metadata from '%s': %s", authserver, err))
122 }
123
124 k, err := helpers.GenerateKey(nil)
125 if err != nil {
126 return "", echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate DPoP key: %s", err))
127 }
128
129 state := makeState(jkt)
130
131 opts := oauth.ParAuthRequestOpts{
132 State: state,
133 }
134 parResp, err := oclient.SendParAuthRequest(ctx, authserver, authmeta, session.Handle, upstreamMeta.Scope, k, opts)
135 if err != nil {
136 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to send PAR auth request to '%s': %s", authserver, err))
137 }
138
139 jwkJSON, err := json.Marshal(k)
140 if err != nil {
141 return "", echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to marshal DPoP key to JSON: %s", err))
142 }
143
144 u, err := url.Parse(authmeta.AuthorizationEndpoint)
145 if err != nil {
146 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to parse auth server metadata: %s", err))
147 }
148 u.RawQuery = fmt.Sprintf("client_id=%s&request_uri=%s", url.QueryEscape(upstreamMeta.ClientID), parResp.RequestUri)
149 str := u.String()
150
151 session.DID = did
152 session.PDSUrl = service
153 session.UpstreamState = parResp.State
154 session.UpstreamAuthServerIssuer = authserver
155 session.UpstreamPKCEVerifier = parResp.PkceVerifier
156 session.UpstreamDPoPNonce = parResp.DpopAuthserverNonce
157 session.UpstreamDPoPPrivateJWK = string(jwkJSON)
158
159 err = o.updateOAuthSession(jkt, session)
160 if err != nil {
161 return "", echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to update OAuth session: %s", err))
162 }
163
164 return str, nil
165}