1// Modify piper/oauth/oauth_manager.go
2package oauth
3
4import (
5 "fmt"
6 "log"
7 "net/http"
8 "sync"
9)
10
11// manages multiple oauth client services
12type OAuthServiceManager struct {
13 services map[string]AuthService
14 mu sync.RWMutex
15 logger *log.Logger
16}
17
18func NewOAuthServiceManager() *OAuthServiceManager {
19 return &OAuthServiceManager{
20 services: make(map[string]AuthService),
21 logger: log.New(log.Writer(), "oauth: ", log.LstdFlags|log.Lmsgprefix),
22 }
23}
24
25// registers any service that impls AuthService
26func (m *OAuthServiceManager) RegisterService(name string, service AuthService) {
27 m.mu.Lock()
28 defer m.mu.Unlock()
29 m.services[name] = service
30 m.logger.Printf("Registered auth service: %s", name)
31}
32
33// get an AuthService by registered name
34func (m *OAuthServiceManager) GetService(name string) (AuthService, bool) {
35 m.mu.RLock()
36 defer m.mu.RUnlock()
37 service, exists := m.services[name]
38 return service, exists
39}
40
41func (m *OAuthServiceManager) HandleLogin(serviceName string) http.HandlerFunc {
42 return func(w http.ResponseWriter, r *http.Request) {
43 m.mu.RLock()
44 service, exists := m.services[serviceName]
45 m.mu.RUnlock()
46
47 if exists {
48 service.HandleLogin(w, r)
49 return
50 }
51
52 m.logger.Printf("Auth service '%s' not found for login request", serviceName)
53 http.Error(w, fmt.Sprintf("Auth service '%s' not found", serviceName), http.StatusNotFound)
54 }
55}
56
57func (m *OAuthServiceManager) HandleLogout(serviceName string) http.HandlerFunc {
58 return func(w http.ResponseWriter, r *http.Request) {
59 m.mu.RLock()
60 service, exists := m.services[serviceName]
61 m.mu.RUnlock()
62
63 if exists {
64 service.HandleLogout(w, r)
65 return
66 }
67
68 m.logger.Printf("Auth service '%s' not found for login request", serviceName)
69 http.Error(w, fmt.Sprintf("Auth service '%s' not found", serviceName), http.StatusNotFound)
70 }
71}
72
73func (m *OAuthServiceManager) HandleCallback(serviceName string) http.HandlerFunc {
74 return func(w http.ResponseWriter, r *http.Request) {
75 m.mu.RLock()
76 service, exists := m.services[serviceName]
77 m.mu.RUnlock()
78
79 m.logger.Printf("Logging in with service %s", serviceName)
80
81 if !exists {
82 m.logger.Printf("Auth service '%s' not found for callback request", serviceName)
83 http.Error(w, fmt.Sprintf("OAuth service '%s' not found", serviceName), http.StatusNotFound)
84 return
85 }
86
87 userID, err := service.HandleCallback(w, r)
88
89 if err != nil {
90 m.logger.Printf("Error handling callback for service '%s': %v", serviceName, err)
91 http.Error(w, fmt.Sprintf("Error handling callback for service '%s'", serviceName), http.StatusInternalServerError)
92 return
93 }
94
95 if userID > 0 {
96
97 http.Redirect(w, r, "/", http.StatusSeeOther)
98 } else {
99 m.logger.Printf("Callback for service '%s' did not result in a valid user ID.", serviceName)
100 // todo: redirect to an error page
101 // right now this just redirects home but we don't want this behaviour ideally
102 http.Redirect(w, r, "/", http.StatusSeeOther)
103 }
104 }
105}