// The server for Open Course World.
// Source view: branch main, 394 lines, 12 kB.
1package api 2 3import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "io/ioutil" 8 "net/http" 9 "net/url" 10 "smm2_gameserver/config" 11 "smm2_gameserver/orm" 12 "strings" 13 14 "github.com/gorilla/mux" 15 "golang.org/x/oauth2" 16) 17 18func originOf(urlStr string) string { 19 u, err := url.Parse(urlStr) 20 if err != nil { 21 panic(err) 22 } 23 return u.Scheme + "://" + hostOf(urlStr) 24} 25 26func hostOf(urlStr string) string { 27 u, err := url.Parse(urlStr) 28 if err != nil { 29 panic(err) 30 } 31 32 host := u.Hostname() 33 port := u.Port() 34 35 if (port == "" || port == "80" || port == "443") { 36 return host 37 } 38 39 return fmt.Sprintf("%s:%s", host, port) 40} 41 42func getCallbackUrl(urlOrHost string, cfg *config.Config) (*url.URL, error) { 43 if strings.HasPrefix(urlOrHost, "http://") || strings.HasPrefix(urlOrHost, "https://") { 44 parsed, err := url.Parse(urlOrHost) 45 if err != nil { 46 return nil, fmt.Errorf("invalid URL: %s (%s)", urlOrHost, err) 47 } 48 return getCallbackUrl(parsed.Host, cfg) 49 } 50 for _, appUrl := range cfg.AppUrls { 51 parsed, err := url.Parse(appUrl) 52 if err != nil { 53 return nil, fmt.Errorf("invalid URL: %s (%s)", appUrl, err) 54 } 55 if parsed.Host == urlOrHost { 56 return parsed, nil 57 } 58 } 59 return nil, fmt.Errorf("no callback URL found for %s", urlOrHost) 60} 61 62func getRedirectUrl(urlOrHost string, cfg *config.Config) (*url.URL, error) { 63 if strings.HasPrefix(urlOrHost, "http://") || strings.HasPrefix(urlOrHost, "https://") { 64 parsed, err := url.Parse(urlOrHost) 65 if err != nil { 66 return nil, fmt.Errorf("invalid URL: %s (%s)", urlOrHost, err) 67 } 68 return getRedirectUrl(parsed.Host, cfg) 69 } 70 for _, apiUrl := range cfg.ApiUrls { 71 parsed, err := url.Parse(apiUrl) 72 if err != nil { 73 return nil, fmt.Errorf("invalid URL: %s (%s)", apiUrl, err) 74 } 75 if parsed.Host == urlOrHost { 76 return parsed, nil 77 } 78 } 79 return nil, fmt.Errorf("no redirect URL found for %s", urlOrHost) 80} 81 82func initAuth() 
{ 83 hostProviders := map[string]map[string]*oauth2.Config{} 84 85 // Loop through cfg.RedirectUrls. For each url, get the origin and add a Twitch and Discord provider to the map. 86 for _, redirectUrl := range cfg.ApiUrls { 87 host := hostOf(redirectUrl) 88 hostProviders[host] = map[string]*oauth2.Config{ 89 "twitch": { 90 ClientID: cfg.TwitchClientId, 91 ClientSecret: cfg.TwitchClientSecret, 92 RedirectURL: originOf(redirectUrl) + "/api/connect/twitch/callback", 93 Scopes: []string{}, 94 Endpoint: oauth2.Endpoint{ 95 AuthURL: "https://id.twitch.tv/oauth2/authorize", 96 TokenURL: "https://id.twitch.tv/oauth2/token", 97 }, 98 }, 99 "discord": { 100 ClientID: cfg.DiscordClientId, 101 ClientSecret: cfg.DiscordClientSecret, 102 RedirectURL: originOf(redirectUrl) + "/api/connect/discord/callback", 103 Scopes: []string{"identify"}, 104 Endpoint: oauth2.Endpoint{ 105 AuthURL: "https://discord.com/api/oauth2/authorize", 106 TokenURL: "https://discord.com/api/oauth2/token", 107 }, 108 }, 109 } 110 } 111 112 fmt.Println("[hostProviders]:") 113 114 // Log the URL and method of each request. 
115 116 router.Use(func(next http.Handler) http.Handler { 117 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 118 fmt.Printf("%s %s\n", r.Method, r.URL) 119 next.ServeHTTP(w, r) 120 }) 121 }) 122 123 Insecure("/api/hello-world", func(w http.ResponseWriter, r *http.Request) { 124 fmt.Fprintf(w, "Hello World!") 125 }).Methods("GET") 126 127 Insecure("/api/connect/{provider}", func(w http.ResponseWriter, r *http.Request) { 128 _, err := getCallbackUrl(r.Header.Get("Referer"), cfg) 129 if err != nil { 130 reportError(w, r, fmt.Errorf("invalid referrer URL: %s", err)) 131 } 132 provider := mux.Vars(r)["provider"] 133 providerCfg, ok := hostProviders[r.Host][provider] 134 if !ok { 135 fmt.Printf("no provider %s for host %s\n", provider, r.Host) 136 w.WriteHeader(http.StatusNotFound) 137 return 138 } 139 // Choose a callback URL from cfg that matches the origin of the referer. 140 // If there is no match, fail. 141 // The state will be the referer URL. 142 // Redirect to the provider's login URL. 143 144 http.Redirect(w, r, providerCfg.AuthCodeURL(r.Header.Get("Referer")), http.StatusFound) 145 }).Methods("GET") 146 147 Insecure("/api/connect/{provider}/callback", func(w http.ResponseWriter, r *http.Request) { 148 provider := mux.Vars(r)["provider"] 149 providerCfg, ok := hostProviders[r.Host][provider] 150 if !ok { 151 fmt.Printf("no provider %s for host %s\n", provider, r.Host) 152 w.WriteHeader(http.StatusNotFound) 153 return 154 } 155 156 if !ok { 157 fmt.Printf("no provider %s for host %s\n", provider, r.Host) 158 w.WriteHeader(http.StatusNotFound) 159 return 160 } 161 // Get the state from the query params. 162 // If the state's origin does not match a callback URL in cfg, fail. 
163 164 state := r.URL.Query().Get("state") 165 stateOrigin := originOf(state) 166 var redirectUrl *url.URL 167 for _, cbUrl := range cfg.AppUrls { 168 if originOf(cbUrl) == stateOrigin { 169 parsed, err := url.Parse(cbUrl) 170 if err != nil { 171 reportError(w, r, fmt.Errorf("failed to parse callback url %s: %w", cbUrl, err)) 172 return 173 } 174 redirectUrl = parsed 175 break 176 } 177 } 178 if redirectUrl == nil { 179 reportError(w, r, fmt.Errorf("no callback url for origin %s", stateOrigin)) 180 return 181 } 182 183 // Verify redirectUrl is in cfg.AppUrls. 184 // If not, fail. 185 callbackUrl, err := getCallbackUrl(state, cfg) 186 if err != nil { 187 reportError(w, r, fmt.Errorf("invalid redirect URL: %s", err)) 188 return 189 } 190 191 // Get the code from the query params. 192 // Exchange the code for an access token. 193 // Redirect to the state URL with the access token as a query param. 194 195 code := r.URL.Query().Get("code") 196 197 token, err := providerCfg.Exchange(context.Background(), code) 198 if err != nil { 199 reportError(w, r, fmt.Errorf("failed to exchange code for token: %w", err)) 200 return 201 } 202 203 callbackUrl.Path = fmt.Sprintf("/auth/%s/callback", provider) 204 query := callbackUrl.Query() 205 query.Set("access_token", token.AccessToken) 206 query.Set("refresh_token", token.RefreshToken) 207 query.Set("redirect_uri", redirectUrl.RawPath) 208 callbackUrl.RawQuery = query.Encode() 209 210 http.Redirect(w, r, callbackUrl.String(), http.StatusFound) 211 }).Methods("GET") 212 213 Insecure("/api/auth/{provider}/callback", func(w http.ResponseWriter, r *http.Request) { 214 provider := mux.Vars(r)["provider"] 215 // We will use the access token to get the user's ID and profile 216 // information from the provider. 217 // We will return the user's ID and profile information to the client. 
218 219 accessToken := r.URL.Query().Get("access_token") 220 if accessToken == "" { 221 // Return `400 Bad Request: Required parameter "access_token" is missing.` 222 w.WriteHeader(http.StatusBadRequest) 223 fmt.Fprintf(w, "Required parameter \"access_token\" is missing.") 224 return 225 } 226 227 // Get the user's ID and profile information from the provider. 228 // Return the user's ID and profile information to the client. 229 230 var user *orm.User 231 232 switch provider { 233 case "twitch": 234 // Get the user's ID and profile information from Twitch. 235 // Return the user's ID and profile information to the client. 236 httpClient := &http.Client{} 237 req, err := http.NewRequest("GET", "https://api.twitch.tv/helix/users", nil) 238 if err != nil { 239 reportError(w, r, fmt.Errorf("failed to create request: %w", err)) 240 return 241 } 242 req.Header.Set("Authorization", "Bearer "+accessToken) 243 req.Header.Set("Client-ID", cfg.TwitchClientId) 244 resp, err := httpClient.Do(req) 245 if err != nil { 246 reportError(w, r, fmt.Errorf("failed to send request: %w", err)) 247 return 248 } 249 defer resp.Body.Close() 250 body, err := ioutil.ReadAll(resp.Body) 251 if err != nil { 252 reportError(w, r, fmt.Errorf("failed to read response body: %w", err)) 253 return 254 } 255 var twitchUser struct { 256 Data []struct { 257 ID string `json:"id"` 258 } `json:"data"` 259 } 260 if err := json.Unmarshal(body, &twitchUser); err != nil { 261 reportError(w, r, fmt.Errorf("failed to unmarshal response body: %w", err)) 262 return 263 } 264 if len(twitchUser.Data) == 0 { 265 reportError(w, r, fmt.Errorf("no user found")) 266 return 267 } 268 userId := twitchUser.Data[0].ID 269 // Get the user's profile information from Twitch. 
270 req, err = http.NewRequest("GET", "https://api.twitch.tv/helix/users?id="+userId, nil) 271 if err != nil { 272 reportError(w, r, fmt.Errorf("failed to create request: %w", err)) 273 return 274 } 275 req.Header.Set("Authorization", "Bearer "+accessToken) 276 req.Header.Set("Client-ID", cfg.TwitchClientId) 277 resp, err = httpClient.Do(req) 278 if err != nil { 279 reportError(w, r, fmt.Errorf("failed to send request: %w", err)) 280 return 281 } 282 defer resp.Body.Close() 283 body, err = ioutil.ReadAll(resp.Body) 284 if err != nil { 285 reportError(w, r, fmt.Errorf("failed to read response body: %w", err)) 286 return 287 } 288 var twitchProfile struct { 289 Data []struct { 290 DisplayName string `json:"display_name"` 291 ProfileImageUrl string `json:"profile_image_url"` 292 } `json:"data"` 293 } 294 if err := json.Unmarshal(body, &twitchProfile); err != nil { 295 reportError(w, r, fmt.Errorf("failed to unmarshal response body: %w", err)) 296 return 297 } 298 if len(twitchProfile.Data) == 0 { 299 reportError(w, r, fmt.Errorf("no user found")) 300 return 301 } 302 profile := twitchProfile.Data[0] 303 // Return the user's ID and profile information to the client. 304 w.Header().Set("Content-Type", "application/json") 305 306 user, err = CreateOrUpdateUser("twitch", userId, profile.DisplayName, profile.ProfileImageUrl) 307 if err != nil { 308 reportError(w, r, fmt.Errorf("failed to create user: %w", err)) 309 return 310 } 311 312 case "discord": 313 // Get the user's ID and profile information from Discord. 314 // Return the user's ID and profile information to the client. 
315 httpClient := &http.Client{} 316 req, err := http.NewRequest("GET", "https://discord.com/api/users/@me", nil) 317 if err != nil { 318 reportError(w, r, fmt.Errorf("failed to create request: %w", err)) 319 return 320 } 321 req.Header.Set("Authorization", "Bearer "+accessToken) 322 resp, err := httpClient.Do(req) 323 if err != nil { 324 reportError(w, r, fmt.Errorf("failed to send request: %w", err)) 325 return 326 } 327 defer resp.Body.Close() 328 body, err := ioutil.ReadAll(resp.Body) 329 if err != nil { 330 reportError(w, r, fmt.Errorf("failed to read response body: %w", err)) 331 return 332 } 333 var discordUser struct { 334 ID string `json:"id"` 335 Username string `json:"username"` 336 Discriminator string `json:"discriminator"` 337 Avatar string `json:"avatar"` 338 } 339 if err := json.Unmarshal(body, &discordUser); err != nil { 340 reportError(w, r, fmt.Errorf("failed to unmarshal response body: %w", err)) 341 return 342 } 343 // Return the user's ID and profile information to the client. 344 w.Header().Set("Content-Type", "application/json") 345 346 user, err = CreateOrUpdateUser("discord", 347 discordUser.ID, 348 discordUser.Username+"#"+discordUser.Discriminator, 349 "https://cdn.discordapp.com/avatars/"+discordUser.ID+"/"+discordUser.Avatar+".png", 350 ) 351 if err != nil { 352 reportError(w, r, fmt.Errorf("failed to create user: %w", err)) 353 return 354 } 355 356 default: 357 fmt.Printf("Unknown provider: %s\n", provider) 358 w.WriteHeader(http.StatusNotFound) 359 return 360 } 361 362 // The user has been created or updated. Create a JWT for them. 363 token, err := createToken(user.ID) 364 365 if err != nil { 366 reportError(w, r, fmt.Errorf("failed to create token: %w", err)) 367 return 368 } 369 370 // Return { token, user } to the client. 
371 w.Header().Set("Content-Type", "application/json") 372 json.NewEncoder(w).Encode(map[string]interface{}{ 373 "jwt": token, 374 "user": user.View(), 375 }) 376 }).Methods("GET") 377 378 Secure("/api/auth/refresh", func(w http.ResponseWriter, r *http.Request, user orm.User) { 379 // The user has been created or updated. Create a JWT for them. 380 token, err := createToken(user.ID) 381 382 if err != nil { 383 reportError(w, r, fmt.Errorf("failed to create token: %w", err)) 384 return 385 } 386 387 // Return { token, user } to the client. 388 w.Header().Set("Content-Type", "application/json") 389 json.NewEncoder(w).Encode(map[string]interface{}{ 390 "jwt": token, 391 "user": user.View(), 392 }) 393 }).Methods("GET") 394}