Live video on the AT Protocol
1package oproxy
2
3import (
4 "context"
5 "fmt"
6 "net/http"
7 "net/url"
8 "time"
9
10 "github.com/bluesky-social/indigo/api/atproto"
11 "github.com/bluesky-social/indigo/xrpc"
12 oauth "github.com/haileyok/atproto-oauth-golang"
13 "github.com/labstack/echo/v4"
14 "github.com/lestrrat-go/jwx/v2/jwk"
15 "go.opentelemetry.io/otel"
16)
17
18func (o *OProxy) HandleOAuthReturn(c echo.Context) error {
19 ctx, span := otel.Tracer("server").Start(c.Request().Context(), "HandleOAuthReturn")
20 defer span.End()
21 code := c.QueryParam("code")
22 iss := c.QueryParam("iss")
23 state := c.QueryParam("state")
24 errorCode := c.QueryParam("error")
25 errorDescription := c.QueryParam("error_description")
26 var httpError *echo.HTTPError
27 var redirectURL string
28 if errorCode != "" {
29 httpError = echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("%s (%s)", errorDescription, errorCode))
30 } else {
31 redirectURL, httpError = o.Return(ctx, code, iss, state)
32 }
33 if httpError != nil {
34 // we're a redirect; if we fail we need to send the user back
35 jkt, _, err := parseState(state)
36 if err != nil {
37 return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to parse URN: %s", err))
38 }
39
40 session, err := o.loadOAuthSession(jkt)
41 if err != nil {
42 return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to load OAuth session jkt=%s: %s", jkt, err))
43 }
44
45 u, err := url.Parse(session.DownstreamRedirectURI)
46 if err != nil {
47 return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to parse downstream redirect URI: %s", err))
48 }
49 q := u.Query()
50 q.Set("error", "return_failed")
51 q.Set("error_description", httpError.Error())
52 u.RawQuery = q.Encode()
53 return c.Redirect(http.StatusTemporaryRedirect, u.String())
54 }
55 return c.Redirect(http.StatusTemporaryRedirect, redirectURL)
56}
57
58func (o *OProxy) Return(ctx context.Context, code string, iss string, state string) (string, *echo.HTTPError) {
59 upstreamMeta := o.GetUpstreamMetadata()
60 oclient, err := oauth.NewClient(oauth.ClientArgs{
61 ClientJwk: o.upstreamJWK,
62 ClientId: upstreamMeta.ClientID,
63 RedirectUri: upstreamMeta.RedirectURIs[0],
64 })
65
66 jkt, _, err := parseState(state)
67 if err != nil {
68 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to parse state: %s", err))
69 }
70
71 session, err := o.loadOAuthSession(jkt)
72 if err != nil {
73 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to get OAuth session: %s", err))
74 }
75 if session == nil {
76 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("no OAuth session found for state: %s", state))
77 }
78
79 if session.Status() != OAuthSessionStateUpstream {
80 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("session is not in upstream state: %s", session.Status()))
81 }
82
83 if session.UpstreamState != state {
84 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("state mismatch: %s != %s", session.UpstreamState, state))
85 }
86
87 if iss != session.UpstreamAuthServerIssuer {
88 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("issuer mismatch: %s != %s", iss, session.UpstreamAuthServerIssuer))
89 }
90
91 key, err := jwk.ParseKey([]byte(session.UpstreamDPoPPrivateJWK))
92 if err != nil {
93 return "", echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to parse DPoP private JWK: %s", err))
94 }
95
96 itResp, err := oclient.InitialTokenRequest(ctx, code, iss, session.UpstreamPKCEVerifier, session.UpstreamDPoPNonce, key)
97 if err != nil {
98 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to request initial token: %s", err))
99 }
100 now := time.Now()
101
102 if itResp.Sub != session.DID {
103 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("sub mismatch: %s != %s", itResp.Sub, session.DID))
104 }
105
106 if itResp.Scope != upstreamMeta.Scope {
107 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("scope mismatch: %s != %s", itResp.Scope, upstreamMeta.Scope))
108 }
109
110 downstreamCode, err := generateAuthorizationCode()
111 if err != nil {
112 return "", echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate downstream code: %s", err))
113 }
114
115 expiry := now.Add(time.Second * time.Duration(itResp.ExpiresIn)).UTC()
116 session.UpstreamAccessToken = itResp.AccessToken
117 session.UpstreamAccessTokenExp = &expiry
118 session.UpstreamRefreshToken = itResp.RefreshToken
119 session.DownstreamAuthorizationCode = downstreamCode
120
121 authArgs := &oauth.XrpcAuthedRequestArgs{
122 Did: session.DID,
123 AccessToken: session.UpstreamAccessToken,
124 PdsUrl: session.PDSUrl,
125 Issuer: session.UpstreamAuthServerIssuer,
126 DpopPdsNonce: session.UpstreamDPoPNonce,
127 DpopPrivateJwk: key,
128 }
129
130 xrpcClient := &oauth.XrpcClient{
131 OnDpopPdsNonceChanged: func(did, newNonce string) {},
132 }
133
134 // brief check to make sure we can actually do stuff
135 var out atproto.ServerCheckAccountStatus_Output
136 if err := xrpcClient.Do(ctx, authArgs, xrpc.Query, "application/json", "com.atproto.server.checkAccountStatus", nil, nil, &out); err != nil {
137 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to check account status: %s", err))
138 }
139
140 err = o.updateOAuthSession(session.DownstreamDPoPJKT, session)
141 if err != nil {
142 return "", echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to update OAuth session: %s", err))
143 }
144
145 u, err := url.Parse(session.DownstreamRedirectURI)
146 if err != nil {
147 return "", echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed to parse downstream redirect URI: %s", err))
148 }
149 q := u.Query()
150 q.Set("iss", fmt.Sprintf("https://%s", o.host))
151 q.Set("state", session.DownstreamState)
152 q.Set("code", session.DownstreamAuthorizationCode)
153 u.RawQuery = q.Encode()
154
155 return u.String(), nil
156}