1package oauth
2
3import (
4 "context"
5 "crypto/rand"
6 "crypto/sha256"
7 "encoding/base64"
8 "errors"
9 "fmt"
10 "log"
11 "net/http"
12 "strings"
13
14 "github.com/teal-fm/piper/session"
15 "golang.org/x/oauth2"
16 "golang.org/x/oauth2/spotify"
17)
18
19type OAuth2Service struct {
20 config oauth2.Config
21 state string
22 codeVerifier string
23 codeChallenge string
24 tokenReceiver TokenReceiver
25}
26
// GenerateRandomState returns a fresh, unguessable value for the OAuth2
// "state" parameter: 16 bytes from crypto/rand encoded with padded base64url
// (24 characters, ending in "==").
//
// It panics if the system random source fails. Returning a value built from
// an unfilled buffer would give every caller the same predictable state,
// silently defeating the CSRF protection the state exists to provide.
func GenerateRandomState() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Errorf("oauth: generating state: %w", err))
	}
	return base64.URLEncoding.EncodeToString(b)
}
32
33func NewOAuth2Service(clientID, clientSecret, redirectURI string, scopes []string, provider string, tokenReceiver TokenReceiver) *OAuth2Service {
34 var endpoint oauth2.Endpoint
35
36 switch strings.ToLower(provider) {
37 case "spotify":
38 endpoint = spotify.Endpoint
39 default:
40 // placeholder
41 log.Printf("Warning: OAuth2 provider '%s' not explicitly configured. Using placeholder endpoints.", provider)
42 endpoint = oauth2.Endpoint{
43 AuthURL: "https://example.com/auth",
44 TokenURL: "https://example.com/token",
45 }
46 }
47
48 codeVerifier := GenerateCodeVerifier()
49 codeChallenge := GenerateCodeChallenge(codeVerifier)
50
51 return &OAuth2Service{
52 config: oauth2.Config{
53 ClientID: clientID,
54 ClientSecret: clientSecret,
55 RedirectURL: redirectURI,
56 Scopes: scopes,
57 Endpoint: endpoint,
58 },
59 state: GenerateRandomState(),
60 codeVerifier: codeVerifier,
61 codeChallenge: codeChallenge,
62 tokenReceiver: tokenReceiver,
63 }
64}
65
// GenerateCodeVerifier returns a fresh PKCE code verifier (RFC 7636 §4.1).
// The verifier is 64 bytes from crypto/rand encoded as unpadded base64url,
// i.e. 86 characters, inside the 43–128 character range the RFC requires.
//
// It panics if the system random source fails. A verifier built from an
// unfilled buffer would be predictable, which defeats the purpose of PKCE.
func GenerateCodeVerifier() string {
	b := make([]byte, 64)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Errorf("oauth: generating PKCE code verifier: %w", err))
	}
	return base64.RawURLEncoding.EncodeToString(b)
}
72
73// generate a code challenge for verification later
74func GenerateCodeChallenge(verifier string) string {
75 h := sha256.New()
76 h.Write([]byte(verifier))
77 return base64.RawURLEncoding.EncodeToString(h.Sum(nil))
78}
79
80func (o *OAuth2Service) HandleLogin(w http.ResponseWriter, r *http.Request) {
81 opts := []oauth2.AuthCodeOption{
82 oauth2.SetAuthURLParam("code_challenge", o.codeChallenge),
83 oauth2.SetAuthURLParam("code_challenge_method", "S256"),
84 }
85 authURL := o.config.AuthCodeURL(o.state, opts...)
86 http.Redirect(w, r, authURL, http.StatusSeeOther)
87}
88
89func (o *OAuth2Service) HandleLogout(w http.ResponseWriter, r *http.Request) {
90 //TODO not implemented yet. not sure what the api call is for this package
91 http.Redirect(w, r, "/", http.StatusSeeOther)
92}
93
94func (o *OAuth2Service) HandleCallback(w http.ResponseWriter, r *http.Request) (int64, error) {
95 state := r.URL.Query().Get("state")
96 if state != o.state {
97 log.Printf("OAuth2 Callback Error: State mismatch. Expected '%s', got '%s'", o.state, state)
98 http.Error(w, "State mismatch", http.StatusBadRequest)
99 return 0, errors.New("state mismatch")
100 }
101
102 code := r.URL.Query().Get("code")
103 if code == "" {
104 errMsg := r.URL.Query().Get("error")
105 errDesc := r.URL.Query().Get("error_description")
106 log.Printf("OAuth2 Callback Error: No code provided. Error: '%s', Description: '%s'", errMsg, errDesc)
107 http.Error(w, fmt.Sprintf("Authorization failed: %s (%s)", errMsg, errDesc), http.StatusBadRequest)
108 return 0, errors.New("no code provided")
109 }
110
111 if o.tokenReceiver == nil {
112 log.Printf("OAuth2 Callback Error: TokenReceiver is not configured for this service.")
113 http.Error(w, "Internal server configuration error", http.StatusInternalServerError)
114 return 0, errors.New("token receiver not configured")
115 }
116
117 opts := []oauth2.AuthCodeOption{
118 oauth2.SetAuthURLParam("code_verifier", o.codeVerifier),
119 }
120
121 log.Println(code)
122
123 token, err := o.config.Exchange(context.Background(), code, opts...)
124 if err != nil {
125 log.Printf("OAuth2 Callback Error: Failed to exchange code for token: %v", err)
126 http.Error(w, fmt.Sprintf("Error exchanging code for token: %v", err), http.StatusInternalServerError)
127 return 0, errors.New("failed to exchange code for token")
128 }
129
130 userId, hasSession := session.GetUserID(r.Context())
131 // store token and get uid
132 userID, err := o.tokenReceiver.SetAccessToken(token.AccessToken, token.RefreshToken, userId, hasSession)
133 if err != nil {
134 log.Printf("OAuth2 Callback Info: TokenReceiver did not return a valid user ID for token: %s...", token.AccessToken[:min(10, len(token.AccessToken))])
135 }
136
137 log.Printf("OAuth2 Callback Success: Exchanged code for token, UserID: %d", userID)
138 return userID, nil
139}
140
141func (o *OAuth2Service) GetToken(code string) (*oauth2.Token, error) {
142 opts := []oauth2.AuthCodeOption{
143 oauth2.SetAuthURLParam("code_verifier", o.codeVerifier),
144 }
145 return o.config.Exchange(context.Background(), code, opts...)
146}
147
148func (o *OAuth2Service) GetClient(token *oauth2.Token) *http.Client {
149 return o.config.Client(context.Background(), token)
150}
151
152func (o *OAuth2Service) RefreshToken(token *oauth2.Token) (*oauth2.Token, error) {
153 source := o.config.TokenSource(context.Background(), token)
154 return oauth2.ReuseTokenSource(token, source).Token()
155}
156
// min returns the smaller of a and b.
//
// NOTE(review): this package-level function shadows the Go 1.21+ built-in
// min inside package oauth; it can likely be removed once the module targets
// Go 1.21 or newer — check other files in the package before deleting.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}