+1
-9
config/config.go
+1
-9
config/config.go
···
10
10
11
11
// Load initializes the configuration with viper
12
12
func Load() {
13
-
// Load .env file if it exists
14
13
if err := godotenv.Load(); err != nil {
15
14
log.Println("No .env file found or error loading it. Using default values and environment variables.")
16
15
}
17
16
18
-
// Set default configurations
19
17
viper.SetDefault("server.port", "8080")
20
18
viper.SetDefault("server.host", "localhost")
21
19
viper.SetDefault("callback.spotify", "http://localhost:8080/callback/spotify")
···
30
28
viper.SetDefault("atproto.metadata_url", "http://localhost:8080/metadata")
31
29
viper.SetDefault("atproto.callback_url", "/metadata")
32
30
33
-
// Configure Viper to read environment variables
34
31
viper.AutomaticEnv()
35
32
36
-
// Replace dots with underscores for environment variables
37
33
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
38
34
39
-
// Set the config name and paths
40
35
viper.SetConfigName("config")
41
36
viper.SetConfigType("yaml")
42
37
viper.AddConfigPath("./config")
43
38
viper.AddConfigPath(".")
44
39
45
-
// Try to read the config file
46
40
if err := viper.ReadInConfig(); err != nil {
47
41
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
48
-
// It's not a "file not found" error, so it's a real error
49
42
log.Fatalf("Error reading config file: %v", err)
50
43
}
51
-
// Config file not found, using defaults and environment variables
52
44
log.Println("Config file not found, using default values and environment variables")
53
45
} else {
54
46
log.Println("Using config file:", viper.ConfigFileUsed())
55
47
}
56
48
57
-
// Check if required values are present
49
+
// check for required settings
58
50
requiredVars := []string{"spotify.client_id", "spotify.client_secret"}
59
51
missingVars := []string{}
60
52
+6
-10
db/atproto.go
+6
-10
db/atproto.go
···
39
39
40
40
func (db *DB) GetATprotoAuthData(state string) (*models.ATprotoAuthData, error) {
41
41
var data models.ATprotoAuthData
42
-
var dpopPrivateJWKString string // Temporary variable to hold the JSON string
42
+
var dpopPrivateJWKString string
43
43
44
44
err := db.QueryRow(`
45
45
SELECT state, did, pds_url, authserver_issuer, pkce_verifier, dpop_authserver_nonce, dpop_private_jwk
···
52
52
&data.AuthServerIssuer,
53
53
&data.PKCEVerifier,
54
54
&data.DPoPAuthServerNonce,
55
-
&dpopPrivateJWKString, // Scan into the temporary string
55
+
&dpopPrivateJWKString,
56
56
)
57
57
if err != nil {
58
-
// Return the original scan error if it occurred
59
58
if err == sql.ErrNoRows {
60
59
return nil, fmt.Errorf("no auth data found for state %s: %w", state, err)
61
60
}
···
64
63
65
64
key, err := helpers.ParseJWKFromBytes([]byte(dpopPrivateJWKString))
66
65
if err != nil {
67
-
// Return an error if parsing fails
68
66
return nil, fmt.Errorf("failed to parse DPoPPrivateJWK for state %s: %w", state, err)
69
67
}
70
68
data.DPoPPrivateJWK = key
71
69
72
-
return &data, nil // Return nil error on success
70
+
return &data, nil
73
71
}
74
72
75
73
func (db *DB) FindOrCreateUserByDID(did string) (*models.User, error) {
···
97
95
if idErr != nil {
98
96
return nil, fmt.Errorf("failed to get last insert id: %w", idErr)
99
97
}
100
-
// Populate the user struct with the newly created user's data
101
98
user.ID = lastID
102
99
user.ATProtoDID = &did
103
100
user.CreatedAt = now
104
101
user.UpdatedAt = now
105
-
return &user, nil // Return the created user and nil error
102
+
return &user, nil
106
103
} else if err != nil {
107
-
// Handle other potential errors from QueryRow
108
104
return nil, fmt.Errorf("failed to find user by DID: %w", err)
109
105
}
110
106
111
107
return &user, err
112
108
}
113
109
114
-
// Create or update the current user's ATproto session data.
110
+
// create or update the current user's ATproto session data.
115
111
func (db *DB) SaveATprotoSession(tokenResp *oauth.TokenResponse) error {
116
112
117
113
expiryTime := time.Now().Add(time.Second * time.Duration(tokenResp.ExpiresIn))
···
141
137
142
138
rowsAffected, err := result.RowsAffected()
143
139
if err != nil {
144
-
// Error checking RowsAffected, but the update might have succeeded
140
+
// it's possible the update succeeded here?
145
141
return fmt.Errorf("failed to check rows affected after updating atproto session for did %s: %w", tokenResp.Sub, err)
146
142
}
147
143
+5
-11
db/db.go
+5
-11
db/db.go
···
11
11
"github.com/teal-fm/piper/models"
12
12
)
13
13
14
-
// DB is a wrapper around sql.DB
15
14
type DB struct {
16
15
*sql.DB
17
16
}
···
36
35
}
37
36
38
37
func (db *DB) Initialize() error {
39
-
// Create users table
40
38
_, err := db.Exec(`
41
39
CREATE TABLE IF NOT EXISTS users (
42
40
id INTEGER PRIMARY KEY AUTOINCREMENT,
···
59
57
return err
60
58
}
61
59
62
-
// Create tracks table
63
60
_, err = db.Exec(`
64
61
CREATE TABLE IF NOT EXISTS tracks (
65
62
id INTEGER PRIMARY KEY AUTOINCREMENT,
···
115
112
return result.LastInsertId()
116
113
}
117
114
118
-
// Add spotify session to user, returning the updated user
115
+
// add spotify session to user, returning the updated user
119
116
func (db *DB) AddSpotifySession(userID int64, username, email, spotifyId, accessToken, refreshToken string, tokenExpiry time.Time) (*models.User, error) {
120
117
now := time.Now()
121
118
···
191
188
}
192
189
193
190
func (db *DB) SaveTrack(userID int64, track *models.Track) (int64, error) {
194
-
// Convert the Artist array to a string for storage
191
+
// marshal artist json
195
192
artistString := ""
196
193
if len(track.Artist) > 0 {
197
194
bytes, err := json.Marshal(track.Artist)
···
215
212
}
216
213
217
214
func (db *DB) UpdateTrack(trackID int64, track *models.Track) error {
218
-
// Convert the Artist array to a string for storage
219
-
// In a production environment, you'd want to use proper JSON serialization
215
+
// marshal artist json
220
216
artistString := ""
221
217
if len(track.Artist) > 0 {
222
218
bytes, err := json.Marshal(track.Artist)
···
248
244
}
249
245
250
246
func (db *DB) GetRecentTracks(userID int64, limit int) ([]*models.Track, error) {
251
-
// convert previous-format artist strings to current-format
252
-
253
247
rows, err := db.Query(`
254
248
SELECT id, name, artist, album, url, timestamp, duration_ms, progress_ms, service_base_url, isrc, has_stamped
255
249
FROM tracks
···
270
264
err := rows.Scan(
271
265
&track.PlayID,
272
266
&track.Name,
273
-
&artistString, // Scan into a string first
267
+
&artistString, // scan to be unmarshaled later
274
268
&track.Album,
275
269
&track.URL,
276
270
&track.Timestamp,
···
285
279
return nil, err
286
280
}
287
281
288
-
// Convert the artist string to the Artist array structure
282
+
// unmarshal artist json
289
283
var artists []models.Artist
290
284
err = json.Unmarshal([]byte(artistString), &artists)
291
285
if err != nil {
+3
-7
main.go
+3
-7
main.go
···
21
21
func home(w http.ResponseWriter, r *http.Request) {
22
22
w.Header().Set("Content-Type", "text/html")
23
23
24
-
// Check if user has an active session cookie
24
+
// check if user has an active session cookie
25
25
cookie, err := r.Cookie("session")
26
26
isLoggedIn := err == nil && cookie != nil
27
-
// TODO: Add logic here to fetch user details from DB using session ID
28
-
// to check if Spotify is already connected, if desired for finer control.
29
-
// For now, we'll just check if *any* session exists.
27
+
// TODO: add logic here to fetch user details from DB using session ID
28
+
// to check if Spotify is already connected
30
29
31
30
html := `
32
31
<html>
···
106
105
107
106
// JSON API handlers
108
107
109
-
// jsonResponse returns a JSON response
110
108
func jsonResponse(w http.ResponseWriter, statusCode int, data any) {
111
109
w.Header().Set("Content-Type", "application/json")
112
110
w.WriteHeader(statusCode)
···
115
113
}
116
114
}
117
115
118
-
// API endpoint for current track
119
116
func apiCurrentTrack(spotifyService *spotify.SpotifyService) http.HandlerFunc {
120
117
return func(w http.ResponseWriter, r *http.Request) {
121
118
userID, ok := session.GetUserID(r.Context())
···
134
131
}
135
132
}
136
133
137
-
// API endpoint for history
138
134
func apiTrackHistory(spotifyService *spotify.SpotifyService) http.HandlerFunc {
139
135
return func(w http.ResponseWriter, r *http.Request) {
140
136
userID, ok := session.GetUserID(r.Context())
-1
models/atproto.go
-1
models/atproto.go
+19
-14
models/user.go
+19
-14
models/user.go
···
2
2
3
3
import "time"
4
4
5
-
// User represents a user of the application
5
+
// an end user of piper
6
6
type User struct {
7
-
ID int64
8
-
Username string
9
-
Email *string // Use pointer for nullable fields
10
-
SpotifyID *string // Use pointer for nullable fields
11
-
AccessToken *string // Spotify Access Token
12
-
RefreshToken *string // Spotify Refresh Token
13
-
TokenExpiry *time.Time // Spotify Token Expiry
14
-
CreatedAt time.Time
15
-
UpdatedAt time.Time
16
-
ATProtoDID *string // ATProto DID
17
-
ATProtoAccessToken *string // ATProto Access Token
18
-
ATProtoRefreshToken *string // ATProto Refresh Token
19
-
ATProtoTokenExpiry *time.Time // ATProto Token Expiry
7
+
ID int64
8
+
Username string
9
+
Email *string
10
+
11
+
// spotify information
12
+
SpotifyID *string
13
+
AccessToken *string
14
+
RefreshToken *string
15
+
TokenExpiry *time.Time
16
+
17
+
// atp info
18
+
ATProtoDID *string
19
+
ATProtoAccessToken *string
20
+
ATProtoRefreshToken *string
21
+
ATProtoTokenExpiry *time.Time
22
+
23
+
CreatedAt time.Time
24
+
UpdatedAt time.Time
20
25
}
+2
-3
oauth/atproto/atproto.go
+2
-3
oauth/atproto/atproto.go
···
1
-
// Modify piper/oauth/atproto/atproto.go
2
1
package atproto
3
2
4
3
import (
···
88
87
return nil, fmt.Errorf("failed PAR request to %s: %w", ui.AuthServer, err)
89
88
}
90
89
91
-
// Save state including generated PKCE verifier and DPoP key
90
+
// Save state
92
91
data := &models.ATprotoAuthData{
93
92
State: parResp.State,
94
93
DID: ui.DID,
···
171
170
}
172
171
173
172
log.Printf("ATProto Callback Success: User %d (DID: %s) authenticated.", userID.ID, data.DID)
174
-
return userID.ID, nil // Return the piper user ID
173
+
return userID.ID, nil
175
174
}
+6
-14
oauth/oauth2.go
+6
-14
oauth/oauth2.go
···
1
-
// Modify piper/oauth/oauth2.go
2
1
package oauth
3
2
4
3
import (
···
22
21
state string
23
22
codeVerifier string
24
23
codeChallenge string
25
-
// Added TokenReceiver field to handle user lookup/creation based on token
26
24
tokenReceiver TokenReceiver
27
25
}
28
26
···
38
36
switch strings.ToLower(provider) {
39
37
case "spotify":
40
38
endpoint = spotify.Endpoint
41
-
// Add other providers like Last.fm here
42
39
default:
43
-
// Placeholder for unconfigured providers
40
+
// placeholder
44
41
log.Printf("Warning: OAuth2 provider '%s' not explicitly configured. Using placeholder endpoints.", provider)
45
42
endpoint = oauth2.Endpoint{
46
-
AuthURL: "https://example.com/auth", // Replace with actual endpoints if needed
43
+
AuthURL: "https://example.com/auth",
47
44
TokenURL: "https://example.com/token",
48
45
}
49
46
}
···
62
59
state: GenerateRandomState(),
63
60
codeVerifier: codeVerifier,
64
61
codeChallenge: codeChallenge,
65
-
tokenReceiver: tokenReceiver, // Store the token receiver
62
+
tokenReceiver: tokenReceiver,
66
63
}
67
64
}
68
65
69
-
// generateCodeVerifier creates a random code verifier for PKCE
66
+
// generate a random code verifier, for PKCE
70
67
func GenerateCodeVerifier() string {
71
68
b := make([]byte, 64)
72
69
rand.Read(b)
73
70
return base64.RawURLEncoding.EncodeToString(b)
74
71
}
75
72
76
-
// generateCodeChallenge creates a code challenge from the code verifier using S256 method
73
+
// generate a code challenge for verification later
77
74
func GenerateCodeChallenge(verifier string) string {
78
75
h := sha256.New()
79
76
h.Write([]byte(verifier))
80
77
return base64.RawURLEncoding.EncodeToString(h.Sum(nil))
81
78
}
82
79
83
-
// HandleLogin implements the AuthService interface method.
84
80
func (o *OAuth2Service) HandleLogin(w http.ResponseWriter, r *http.Request) {
85
81
opts := []oauth2.AuthCodeOption{
86
82
oauth2.SetAuthURLParam("code_challenge", o.codeChallenge),
···
128
124
129
125
userId, hasSession := session.GetUserID(r.Context())
130
126
131
-
// Use the token receiver to store the token and get the user ID
127
+
// store token and get uid
132
128
userID, err := o.tokenReceiver.SetAccessToken(token.AccessToken, userId, hasSession)
133
129
if err != nil {
134
130
log.Printf("OAuth2 Callback Info: TokenReceiver did not return a valid user ID for token: %s...", token.AccessToken[:min(10, len(token.AccessToken))])
···
138
134
return userID, nil
139
135
}
140
136
141
-
// GetToken remains unchanged
142
137
func (o *OAuth2Service) GetToken(code string) (*oauth2.Token, error) {
143
138
opts := []oauth2.AuthCodeOption{
144
139
oauth2.SetAuthURLParam("code_verifier", o.codeVerifier),
···
146
141
return o.config.Exchange(context.Background(), code, opts...)
147
142
}
148
143
149
-
// GetClient remains unchanged
150
144
func (o *OAuth2Service) GetClient(token *oauth2.Token) *http.Client {
151
145
return o.config.Client(context.Background(), token)
152
146
}
153
147
154
-
// RefreshToken remains unchanged
155
148
func (o *OAuth2Service) RefreshToken(token *oauth2.Token) (*oauth2.Token, error) {
156
149
source := o.config.TokenSource(context.Background(), token)
157
150
return oauth2.ReuseTokenSource(token, source).Token()
158
151
}
159
152
160
-
// Helper function
161
153
func min(a, b int) int {
162
154
if a < b {
163
155
return a
+9
-14
oauth/oauth_manager.go
+9
-14
oauth/oauth_manager.go
···
12
12
13
13
// manages multiple oauth client services
14
14
type OAuthServiceManager struct {
15
-
services map[string]AuthService // Changed from *OAuth2Service to AuthService interface
15
+
services map[string]AuthService
16
16
sessionManager *session.SessionManager
17
17
mu sync.RWMutex
18
18
}
19
19
20
20
func NewOAuthServiceManager() *OAuthServiceManager {
21
21
return &OAuthServiceManager{
22
-
services: make(map[string]AuthService), // Initialize the new map
22
+
services: make(map[string]AuthService),
23
23
sessionManager: session.NewSessionManager(),
24
24
}
25
25
}
26
26
27
-
// RegisterService registers any service that implements the AuthService interface.
27
+
// registers any service that impls AuthService
28
28
func (m *OAuthServiceManager) RegisterService(name string, service AuthService) {
29
29
m.mu.Lock()
30
30
defer m.mu.Unlock()
···
32
32
log.Printf("Registered auth service: %s", name)
33
33
}
34
34
35
-
// GetService retrieves a registered AuthService by name.
35
+
// get an AuthService by registered name
36
36
func (m *OAuthServiceManager) GetService(name string) (AuthService, bool) {
37
37
m.mu.RLock()
38
38
defer m.mu.RUnlock()
···
47
47
m.mu.RUnlock()
48
48
49
49
if exists {
50
-
service.HandleLogin(w, r) // Call interface method
50
+
service.HandleLogin(w, r)
51
51
return
52
52
}
53
53
···
70
70
return
71
71
}
72
72
73
-
// Call the service's HandleCallback, which now returns the user ID
74
-
userID, err := service.HandleCallback(w, r) // Call interface method
73
+
userID, err := service.HandleCallback(w, r)
75
74
76
75
if err != nil {
77
76
log.Printf("Error handling callback for service '%s': %v", serviceName, err)
···
80
79
}
81
80
82
81
if userID > 0 {
83
-
// Create session for the user
84
82
session := m.sessionManager.CreateSession(userID)
85
83
86
-
// Set session cookie
87
84
m.sessionManager.SetSessionCookie(w, session)
88
85
89
86
log.Printf("Created session for user %d via service %s", userID, serviceName)
90
87
91
-
// Redirect to homepage after successful login and session creation
92
88
http.Redirect(w, r, "/", http.StatusSeeOther)
93
89
} else {
94
90
log.Printf("Callback for service '%s' did not result in a valid user ID.", serviceName)
95
-
// Optionally redirect to an error page or show an error message
96
-
// For now, just redirecting home, but this might hide errors.
97
-
// Consider adding error handling based on why userID might be 0.
98
-
http.Redirect(w, r, "/", http.StatusSeeOther) // Or redirect to a login/error page
91
+
// todo: redirect to an error page
92
+
// right now this just redirects home but we don't want this behaviour ideally
93
+
http.Redirect(w, r, "/", http.StatusSeeOther)
99
94
}
100
95
}
101
96
}
+6
-10
oauth/service.go
+6
-10
oauth/service.go
···
1
-
// Create piper/oauth/auth_service.go
2
1
package oauth
3
2
4
3
import (
5
4
"net/http"
6
5
)
7
6
8
-
// AuthService defines the interface for different authentication services
9
-
// that can be managed by the OAuthServiceManager.
10
7
type AuthService interface {
11
-
// HandleLogin initiates the login flow for the specific service.
8
+
// inits the login flow for the service
12
9
HandleLogin(w http.ResponseWriter, r *http.Request)
13
-
// HandleCallback handles the callback from the authentication provider,
14
-
// processes the response (e.g., exchanges code for token), finds or creates
15
-
// the user in the local system, and returns the user ID.
16
-
// Returns 0 if authentication failed or user could not be determined.
10
+
// handles the callback for the provider. is responsible for inserting
11
+
// sessions in the db
17
12
HandleCallback(w http.ResponseWriter, r *http.Request) (int64, error)
18
13
}
19
14
15
+
// optional but recommended
20
16
type TokenReceiver interface {
21
-
// SetAccessToken stores the access token for the user and returns the user ID.
22
-
// If the user is already logged in, the current ID is provided.
17
+
// stores the access token in the db
18
+
// if there is a session, will associate the token with the session
23
19
SetAccessToken(token string, currentId int64, hasSession bool) (int64, error)
24
20
}
+24
-39
service/spotify/spotify.go
+24
-39
service/spotify/spotify.go
···
2
2
3
3
import (
4
4
"encoding/json"
5
+
"errors"
5
6
"fmt"
6
7
"io"
7
8
"log"
···
31
32
}
32
33
33
34
func (s *SpotifyService) SetAccessToken(token string, userId int64, hasSession bool) (int64, error) {
34
-
// Identify the user synchronously instead of in a goroutine
35
35
userID, err := s.identifyAndStoreUser(token, userId, hasSession)
36
36
if err != nil {
37
37
log.Printf("Error identifying and storing user: %v", err)
···
41
41
}
42
42
43
43
func (s *SpotifyService) identifyAndStoreUser(token string, userId int64, hasSession bool) (int64, error) {
44
-
// Get Spotify user profile
45
44
userProfile, err := s.fetchSpotifyProfile(token)
46
45
if err != nil {
47
46
log.Printf("Error fetching Spotify profile: %v", err)
···
50
49
51
50
fmt.Printf("uid: %d hasSession: %t", userId, hasSession)
52
51
53
-
// Check if user exists
54
52
user, err := s.DB.GetUserBySpotifyID(userProfile.ID)
55
53
if err != nil {
56
54
// This error might mean DB connection issue, not just user not found.
···
74
72
}
75
73
}
76
74
} else {
77
-
// Update existing user's token and expiry
78
75
err = s.DB.UpdateUserToken(user.ID, token, "", tokenExpiryTime)
79
76
if err != nil {
77
+
// for now log and continue
80
78
log.Printf("Error updating user token for user ID %d: %v", user.ID, err)
81
-
// Consider if we should return 0 or the user ID even if update fails
82
-
// Sticking to original behavior: log and continue
83
79
} else {
84
80
log.Printf("Updated token for existing user: %s (ID: %d)", user.Username, user.ID)
85
81
}
86
82
}
87
-
// Keep the local 'user' object consistent (optional but good practice)
88
83
user.AccessToken = &token
89
84
user.TokenExpiry = &tokenExpiryTime
90
85
91
-
// Store token in memory cache regardless of new/existing user
92
86
s.mu.Lock()
93
87
s.userTokens[user.ID] = token
94
88
s.mu.Unlock()
···
103
97
Email string `json:"email"`
104
98
}
105
99
106
-
// LoadAllUsers loads all active users from the database into memory
107
100
func (s *SpotifyService) LoadAllUsers() error {
108
101
users, err := s.DB.GetAllActiveUsers()
109
102
if err != nil {
···
115
108
116
109
count := 0
117
110
for _, user := range users {
118
-
// Only load users with valid tokens
111
+
// load users with valid tokens
119
112
if user.AccessToken != nil && user.TokenExpiry.After(time.Now()) {
120
113
s.userTokens[user.ID] = *user.AccessToken
121
114
count++
···
126
119
return nil
127
120
}
128
121
122
+
func (s *SpotifyService) refreshTokenInner(user models.User) error {
123
+
// implement token refresh logic here using Spotify's token refresh endpoint
124
+
// this would make a request to Spotify's token endpoint with grant_type=refresh_token
125
+
return errors.New("Not implemented yet")
126
+
// if successful, update the database and in-memory cache
127
+
}
128
+
129
129
func (s *SpotifyService) RefreshToken(userID string) error {
130
130
s.mu.Lock()
131
131
defer s.mu.Unlock()
···
139
139
return fmt.Errorf("no refresh token for user %s", userID)
140
140
}
141
141
142
-
// Implement token refresh logic here using Spotify's token refresh endpoint
143
-
// This would make a request to Spotify's token endpoint with grant_type=refresh_token
144
-
145
-
// If successful, update the database and in-memory cache
146
-
// we won't be now so just error out
147
-
return fmt.Errorf("token refresh not implemented")
148
-
//
149
-
//s.userTokens[user.ID] = newToken
150
-
//return nil
142
+
return s.refreshTokenInner(*user)
151
143
}
152
144
153
-
// RefreshExpiredTokens attempts to refresh expired tokens
145
+
// attempt to refresh expired tokens
154
146
func (s *SpotifyService) RefreshExpiredTokens() {
155
147
users, err := s.DB.GetUsersWithExpiredTokens()
156
148
if err != nil {
···
160
152
161
153
refreshed := 0
162
154
for _, user := range users {
163
-
// Skip users without refresh tokens
155
+
// skip users without refresh tokens
164
156
if user.RefreshToken == nil {
165
157
continue
166
158
}
167
159
168
-
// Implement token refresh logic here using Spotify's token refresh endpoint
169
-
// This would make a request to Spotify's token endpoint with grant_type=refresh_token
160
+
err := s.refreshTokenInner(*user)
161
+
162
+
if err != nil {
163
+
// just print out errors here for now
164
+
log.Printf("Error from service/spotify/spotify.go when refreshing tokens: %s", err.Error())
165
+
}
170
166
171
-
// If successful, update the database and in-memory cache
172
167
refreshed++
173
168
}
174
169
···
231
226
return
232
227
}
233
228
234
-
// Get recent tracks from database
235
229
tracks, err := s.DB.GetRecentTracks(userID, 20)
236
230
if err != nil {
237
231
http.Error(w, "Error retrieving track history", http.StatusInternalServerError)
···
252
246
return nil, fmt.Errorf("no access token for user %d", userID)
253
247
}
254
248
255
-
// Call Spotify API to get currently playing track
256
249
req, err := http.NewRequest("GET", "https://api.spotify.com/v1/me/player/currently-playing", nil)
257
250
if err != nil {
258
251
return nil, err
···
266
259
}
267
260
defer resp.Body.Close()
268
261
269
-
// No track playing
262
+
// nothing playing
270
263
if resp.StatusCode == 204 {
271
264
return nil, nil
272
265
}
273
266
274
-
// Token expired
267
+
// oops, token expired
275
268
if resp.StatusCode == 401 {
276
269
// attempt to refresh token
277
270
if err := s.RefreshToken(strconv.FormatInt(userID, 10)); err != nil {
···
282
275
}
283
276
}
284
277
285
-
// Error response
286
278
if resp.StatusCode != 200 {
287
279
body, _ := io.ReadAll(resp.Body)
288
280
return nil, fmt.Errorf("spotify API error: %s", body)
289
281
}
290
282
291
-
// Parse response
292
283
var response struct {
293
284
Item struct {
294
285
Name string `json:"name"`
···
320
311
return nil, err
321
312
}
322
313
323
-
// Extract artist names/ids
324
314
var artists []models.Artist
325
315
for _, artist := range response.Item.Artists {
326
316
artists = append(artists, models.Artist{
···
329
319
})
330
320
}
331
321
332
-
// Create Track model
322
+
// assemble Track
333
323
track := &models.Track{
334
324
Name: response.Item.Name,
335
325
Artist: artists,
···
351
341
defer ticker.Stop()
352
342
353
343
for range ticker.C {
354
-
// Copy userIDs to avoid holding the lock too long
344
+
// copy userIDs to avoid holding the lock too long
355
345
s.mu.RLock()
356
346
userIDs := make([]int64, 0, len(s.userTokens))
357
347
for userID := range s.userTokens {
···
359
349
}
360
350
s.mu.RUnlock()
361
351
362
-
// Check each user's currently playing track
363
352
for _, userID := range userIDs {
364
353
track, err := s.FetchCurrentTrack(userID)
365
354
if err != nil {
···
367
356
continue
368
357
}
369
358
370
-
// No change if no track is playing
371
359
if track == nil {
372
360
continue
373
361
}
374
362
375
-
// Check if this is a new track
376
363
s.mu.RLock()
377
364
currentTrack := s.userTracks[userID]
378
365
s.mu.RUnlock()
···
384
371
}
385
372
}
386
373
387
-
// If track is different or we've played more than either half of the track or 30 seconds since the start
388
-
// whichever is greater
374
+
// if flagged true, we have a new track
389
375
isNewTrack := currentTrack == nil ||
390
376
currentTrack.Name != track.Name ||
391
377
// just check the first one for now
···
426
412
}
427
413
428
414
if isNewTrack {
429
-
// Save to database
430
415
id, err := s.DB.SaveTrack(userID, track)
431
416
if err != nil {
432
417
log.Printf("Error saving track for user %d: %v", userID, err)
+10
-33
session/session.go
+10
-33
session/session.go
···
31
31
mu sync.RWMutex
32
32
}
33
33
34
-
// NewSessionManager creates a new session manager
35
34
func NewSessionManager() *SessionManager {
36
-
// Initialize session table if it doesn't exist
37
35
database, err := db.New("./data/piper.db")
38
36
if err != nil {
39
37
log.Printf("Error connecting to database for sessions, falling back to in memory only: %v", err)
···
56
54
log.Printf("Error creating sessions table: %v", err)
57
55
}
58
56
59
-
// Create API key manager
60
57
apiKeyMgr := apikey.NewApiKeyManager(database)
61
58
62
59
return &SessionManager{
···
120
117
return session, true
121
118
}
122
119
123
-
// If not in memory and we have a database, check there
120
+
// if not in memory and we have a database, check there
124
121
if sm.db != nil {
125
122
session = &Session{ID: sessionID}
126
123
···
189
186
http.SetCookie(w, cookie)
190
187
}
191
188
192
-
// HandleLogout handles user logout
193
189
func (sm *SessionManager) HandleLogout(w http.ResponseWriter, r *http.Request) {
194
190
cookie, err := r.Cookie("session")
195
191
if err == nil {
···
201
197
http.Redirect(w, r, "/", http.StatusSeeOther)
202
198
}
203
199
204
-
// GetAPIKeyManager returns the API key manager
205
200
func (sm *SessionManager) GetAPIKeyManager() *apikey.ApiKeyManager {
206
201
return sm.apiKeyMgr
207
202
}
208
203
209
-
// CreateAPIKey creates a new API key for a user
210
204
func (sm *SessionManager) CreateAPIKey(userID int64, name string, validityDays int) (*apikey.ApiKey, error) {
211
205
return sm.apiKeyMgr.CreateApiKey(userID, name, validityDays)
212
206
}
213
207
214
-
// WithAuth is a middleware that checks if a user is authenticated via cookies or API key
208
+
// middleware that checks if a user is authenticated via cookies or API key
215
209
func WithAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc {
216
210
return func(w http.ResponseWriter, r *http.Request) {
217
-
// First try API key authentication (for API requests)
211
+
// first: check API keys
218
212
apiKeyStr, apiKeyErr := apikey.ExtractApiKey(r)
219
213
if apiKeyErr == nil && apiKeyStr != "" {
220
-
// Validate API key
221
214
apiKey, valid := sm.apiKeyMgr.GetApiKey(apiKeyStr)
222
215
if valid {
223
-
// Add user ID to context
224
216
ctx := WithUserID(r.Context(), apiKey.UserID)
225
217
r = r.WithContext(ctx)
226
218
227
-
// Set a flag in the context that this is an API request
219
+
// set a flag for api requests
228
220
ctx = WithAPIRequest(r.Context(), true)
229
221
r = r.WithContext(ctx)
230
222
···
233
225
}
234
226
}
235
227
236
-
// Fall back to cookie authentication (for browser requests)
228
+
// if not found, check cookies for session value
237
229
cookie, err := r.Cookie("session")
238
230
if err != nil {
239
231
http.Redirect(w, r, "/login/spotify", http.StatusSeeOther)
240
232
return
241
233
}
242
234
243
-
// Verify cookie session
244
235
session, exists := sm.GetSession(cookie.Value)
245
236
if !exists {
246
237
http.Redirect(w, r, "/login/spotify", http.StatusSeeOther)
247
238
return
248
239
}
249
240
250
-
// Add session information to request context
251
241
ctx := WithUserID(r.Context(), session.UserID)
252
242
r = r.WithContext(ctx)
253
243
···
255
245
}
256
246
}
257
247
248
+
// middleware that checks if a user is authenticated but doesn't error out if not
258
249
func WithPossibleAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc {
259
250
return func(w http.ResponseWriter, r *http.Request) {
260
251
ctx := r.Context()
261
-
authenticated := false // Default to not authenticated
252
+
authenticated := false
262
253
263
-
// 1. Try API key authentication
264
254
apiKeyStr, apiKeyErr := apikey.ExtractApiKey(r)
265
255
if apiKeyErr == nil && apiKeyStr != "" {
266
256
apiKey, valid := sm.apiKeyMgr.GetApiKey(apiKeyStr)
267
257
if valid {
268
-
// API Key valid: Add UserID, API flag, and set auth status
269
258
ctx = WithUserID(ctx, apiKey.UserID)
270
259
ctx = WithAPIRequest(ctx, true)
271
260
authenticated = true
272
-
// Update request context and call handler
273
261
r = r.WithContext(WithAuthStatus(ctx, authenticated))
274
262
handler(w, r)
275
263
return
276
264
}
277
-
// If API key was provided but invalid, we still proceed without auth
278
265
}
279
266
280
-
// 2. If no valid API key, try cookie authentication
281
-
if !authenticated { // Only check cookies if API key didn't authenticate
267
+
if !authenticated {
282
268
cookie, err := r.Cookie("session")
283
-
if err == nil { // Cookie exists
269
+
if err == nil {
284
270
session, exists := sm.GetSession(cookie.Value)
285
271
if exists {
286
-
// Session valid: Add UserID and set auth status
287
272
ctx = WithUserID(ctx, session.UserID)
288
-
// ctx = WithAPIRequest(ctx, false) // Not strictly needed, default is false
289
273
authenticated = true
290
274
}
291
-
// If session cookie exists but is invalid/expired, we proceed without auth
292
275
}
293
276
}
294
277
295
-
// 3. Set final auth status (could be true or false) and call handler
296
278
r = r.WithContext(WithAuthStatus(ctx, authenticated))
297
279
handler(w, r)
298
280
}
299
281
}
300
282
301
-
// WithAPIAuth is a middleware specifically for API-only endpoints (no cookie fallback, returns 401 instead of redirect)
283
+
// middleware that only accepts API keys
302
284
func WithAPIAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc {
303
285
return func(w http.ResponseWriter, r *http.Request) {
304
-
// Try API key authentication
305
286
apiKeyStr, apiKeyErr := apikey.ExtractApiKey(r)
306
287
if apiKeyErr != nil || apiKeyStr == "" {
307
288
w.Header().Set("Content-Type", "application/json")
···
310
291
return
311
292
}
312
293
313
-
// Validate API key
314
294
apiKey, valid := sm.apiKeyMgr.GetApiKey(apiKeyStr)
315
295
if !valid {
316
296
w.Header().Set("Content-Type", "application/json")
···
319
299
return
320
300
}
321
301
322
-
// Add user ID to context
323
302
ctx := WithUserID(r.Context(), apiKey.UserID)
324
-
// Mark as API request
325
303
ctx = WithAPIRequest(ctx, true)
326
304
r = r.WithContext(ctx)
327
305
···
329
307
}
330
308
}
331
309
332
-
// Context keys
333
310
type contextKey int
334
311
335
312
const (