package api

import (
	"context"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"

	"smm2_gameserver/config"
	"smm2_gameserver/orm"

	"github.com/gorilla/mux"
	"golang.org/x/oauth2"
)

// originOf returns "scheme://host[:port]" for urlStr, with the host
// normalized as in hostOf. It panics if urlStr cannot be parsed, so it must
// only be used on trusted values such as configuration.
func originOf(urlStr string) string {
	u, err := url.Parse(urlStr)
	if err != nil {
		panic(err)
	}
	return originOfURL(u)
}

// hostOf returns the host of urlStr, including the port only when it is
// neither 80 nor 443. It panics if urlStr cannot be parsed.
func hostOf(urlStr string) string {
	u, err := url.Parse(urlStr)
	if err != nil {
		panic(err)
	}
	return normalizedHost(u)
}

// normalizedHost returns u's hostname, appending ":port" unless the port is
// empty, 80 or 443. Ports 80 and 443 are dropped regardless of scheme.
func normalizedHost(u *url.URL) string {
	host := u.Hostname()
	port := u.Port()
	if port == "" || port == "80" || port == "443" {
		return host
	}
	return fmt.Sprintf("%s:%s", host, port)
}

// originOfURL builds the "scheme://host[:port]" origin of an already-parsed
// URL without re-parsing (and therefore without any chance of panicking).
func originOfURL(u *url.URL) string {
	return u.Scheme + "://" + normalizedHost(u)
}

// getCallbackUrl returns the entry of cfg.AppUrls whose host equals the host
// of urlOrHost. urlOrHost may be an absolute http(s) URL or a bare host.
func getCallbackUrl(urlOrHost string, cfg *config.Config) (*url.URL, error) {
	return findConfiguredURL(urlOrHost, cfg.AppUrls, "callback")
}

// getRedirectUrl returns the entry of cfg.ApiUrls whose host equals the host
// of urlOrHost. urlOrHost may be an absolute http(s) URL or a bare host.
func getRedirectUrl(urlOrHost string, cfg *config.Config) (*url.URL, error) {
	return findConfiguredURL(urlOrHost, cfg.ApiUrls, "redirect")
}

// findConfiguredURL implements getCallbackUrl/getRedirectUrl: it reduces
// urlOrHost to a host (when it starts with http:// or https://) and returns
// the first candidate URL with exactly that host. kind is only used in the
// "not found" error message. An unparsable candidate is reported as an error
// as soon as it is reached.
func findConfiguredURL(urlOrHost string, candidates []string, kind string) (*url.URL, error) {
	host := urlOrHost
	if strings.HasPrefix(urlOrHost, "http://") || strings.HasPrefix(urlOrHost, "https://") {
		parsed, err := url.Parse(urlOrHost)
		if err != nil {
			return nil, fmt.Errorf("invalid URL: %s (%s)", urlOrHost, err)
		}
		host = parsed.Host
	}
	for _, candidate := range candidates {
		parsed, err := url.Parse(candidate)
		if err != nil {
			return nil, fmt.Errorf("invalid URL: %s (%s)", candidate, err)
		}
		if parsed.Host == host {
			return parsed, nil
		}
	}
	return nil, fmt.Errorf("no %s URL found for %s", kind, host)
}

// initAuth registers the request-logging middleware and the authentication
// routes on router:
//
//	GET /api/hello-world                  liveness check
//	GET /api/connect/{provider}           start an OAuth login (twitch|discord)
//	GET /api/connect/{provider}/callback  OAuth redirect target; forwards the
//	                                      provider tokens to the app
//	GET /api/auth/{provider}/callback     trade a provider access token for a
//	                                      JWT and the user record
//	GET /api/auth/refresh                 issue a fresh JWT (Secure route)
//
// One OAuth client per provider is built for every configured API URL and
// keyed by that URL's host, so each API host receives the provider redirect on
// its own callback URL.
func initAuth() {
	// Provider API calls go to third parties; bound them so a slow provider
	// cannot hold request goroutines indefinitely.
	providerClient := &http.Client{Timeout: 10 * time.Second}

	// For each configured API URL, register a Twitch and a Discord provider
	// whose redirect URL points back at that API origin.
	hostProviders := map[string]map[string]*oauth2.Config{}
	for _, apiUrl := range cfg.ApiUrls {
		origin := originOf(apiUrl)
		hostProviders[hostOf(apiUrl)] = map[string]*oauth2.Config{
			"twitch": {
				ClientID:     cfg.TwitchClientId,
				ClientSecret: cfg.TwitchClientSecret,
				RedirectURL:  origin + "/api/connect/twitch/callback",
				Scopes:       []string{},
				Endpoint: oauth2.Endpoint{
					AuthURL:  "https://id.twitch.tv/oauth2/authorize",
					TokenURL: "https://id.twitch.tv/oauth2/token",
				},
			},
			"discord": {
				ClientID:     cfg.DiscordClientId,
				ClientSecret: cfg.DiscordClientSecret,
				RedirectURL:  origin + "/api/connect/discord/callback",
				Scopes:       []string{"identify"},
				Endpoint: oauth2.Endpoint{
					AuthURL:  "https://discord.com/api/oauth2/authorize",
					TokenURL: "https://discord.com/api/oauth2/token",
				},
			},
		}
	}
	// Log the registered hosts and redirect URLs (never the client secrets).
	fmt.Println("[hostProviders]:")
	for host, providers := range hostProviders {
		for name, providerCfg := range providers {
			fmt.Printf("  %s %s -> %s\n", host, name, providerCfg.RedirectURL)
		}
	}

	// Log the method and path of each request. The query string is
	// deliberately omitted: on these routes it carries OAuth codes and
	// access/refresh tokens.
	router.Use(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			fmt.Printf("%s %s\n", r.Method, r.URL.Path)
			next.ServeHTTP(w, r)
		})
	})

	Insecure("/api/hello-world", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "Hello World!")
	}).Methods("GET")

	// Start the OAuth flow. The login must originate from a page of one of
	// the configured apps; that page (the Referer) is carried through the
	// provider as the OAuth state so the callback can send the user back.
	Insecure("/api/connect/{provider}", func(w http.ResponseWriter, r *http.Request) {
		referer := r.Header.Get("Referer")
		if _, err := getCallbackUrl(referer, cfg); err != nil {
			reportError(w, r, fmt.Errorf("invalid referrer URL: %s", err))
			return
		}
		provider := mux.Vars(r)["provider"]
		providerCfg, ok := hostProviders[r.Host][provider]
		if !ok {
			fmt.Printf("no provider %s for host %s\n", provider, r.Host)
			w.WriteHeader(http.StatusNotFound)
			return
		}
		http.Redirect(w, r, providerCfg.AuthCodeURL(referer), http.StatusFound)
	}).Methods("GET")

	// Provider redirect target: validate the state, exchange the code for
	// tokens, and forward the tokens to the app's /auth/{provider}/callback.
	Insecure("/api/connect/{provider}/callback", func(w http.ResponseWriter, r *http.Request) {
		provider := mux.Vars(r)["provider"]
		providerCfg, ok := hostProviders[r.Host][provider]
		if !ok {
			fmt.Printf("no provider %s for host %s\n", provider, r.Host)
			w.WriteHeader(http.StatusNotFound)
			return
		}

		query := r.URL.Query()

		// The state is the Referer captured by /api/connect/{provider}. It is
		// untrusted input: parse it without panicking and require an absolute
		// URL whose origin belongs to a configured app.
		// NOTE(review): the state is not a per-session nonce, so it gives this
		// flow no CSRF protection.
		state := query.Get("state")
		stateURL, err := url.Parse(state)
		if err != nil || stateURL.Scheme == "" || stateURL.Host == "" {
			reportError(w, r, fmt.Errorf("invalid state %q", state))
			return
		}
		stateOrigin := originOfURL(stateURL)

		appMatched := false
		for _, cbUrl := range cfg.AppUrls {
			parsed, err := url.Parse(cbUrl)
			if err != nil {
				reportError(w, r, fmt.Errorf("failed to parse callback url %s: %w", cbUrl, err))
				return
			}
			if originOfURL(parsed) == stateOrigin {
				appMatched = true
				break
			}
		}
		if !appMatched {
			reportError(w, r, fmt.Errorf("no callback url for origin %s", stateOrigin))
			return
		}

		// Resolve the configured app URL (matched by host) we will send the
		// tokens to.
		callbackUrl, err := getCallbackUrl(state, cfg)
		if err != nil {
			reportError(w, r, fmt.Errorf("invalid redirect URL: %s", err))
			return
		}

		code := query.Get("code")
		if code == "" {
			// Per RFC 6749 §4.1.2.1 the provider redirects back with an
			// "error" parameter instead of a code when access is not granted.
			reportError(w, r, fmt.Errorf("missing authorization code (error: %q)", query.Get("error")))
			return
		}
		token, err := providerCfg.Exchange(r.Context(), code)
		if err != nil {
			reportError(w, r, fmt.Errorf("failed to exchange code for token: %w", err))
			return
		}

		// NOTE(review): tokens travel in the query string, so they can end up
		// in browser history and intermediary logs.
		callbackUrl.Path = fmt.Sprintf("/auth/%s/callback", provider)
		params := callbackUrl.Query()
		params.Set("access_token", token.AccessToken)
		params.Set("refresh_token", token.RefreshToken)
		// Path of the page the login started from, so the app can return there.
		params.Set("redirect_uri", stateURL.EscapedPath())
		callbackUrl.RawQuery = params.Encode()
		http.Redirect(w, r, callbackUrl.String(), http.StatusFound)
	}).Methods("GET")

	// Look up the provider user behind access_token, create or update the
	// local user, and return { jwt, user }.
	Insecure("/api/auth/{provider}/callback", func(w http.ResponseWriter, r *http.Request) {
		provider := mux.Vars(r)["provider"]
		accessToken := r.URL.Query().Get("access_token")
		if accessToken == "" {
			// Return `400 Bad Request: Required parameter "access_token" is missing.`
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprintf(w, "Required parameter \"access_token\" is missing.")
			return
		}

		var identity providerIdentity
		var err error
		switch provider {
		case "twitch":
			identity, err = fetchTwitchIdentity(r.Context(), providerClient, accessToken, cfg.TwitchClientId)
		case "discord":
			identity, err = fetchDiscordIdentity(r.Context(), providerClient, accessToken)
		default:
			fmt.Printf("Unknown provider: %s\n", provider)
			w.WriteHeader(http.StatusNotFound)
			return
		}
		if err != nil {
			reportError(w, r, err)
			return
		}

		user, err := CreateOrUpdateUser(provider, identity.ID, identity.DisplayName, identity.AvatarURL)
		if err != nil {
			reportError(w, r, fmt.Errorf("failed to create user: %w", err))
			return
		}

		// The user has been created or updated. Create a JWT for them.
		token, err := createToken(user.ID)
		if err != nil {
			reportError(w, r, fmt.Errorf("failed to create token: %w", err))
			return
		}
		writeAuthResponse(w, token, user.View())
	}).Methods("GET")

	// Issue a fresh JWT for an already-authenticated user.
	Secure("/api/auth/refresh", func(w http.ResponseWriter, r *http.Request, user orm.User) {
		token, err := createToken(user.ID)
		if err != nil {
			reportError(w, r, fmt.Errorf("failed to create token: %w", err))
			return
		}
		writeAuthResponse(w, token, user.View())
	}).Methods("GET")
}

// writeAuthResponse writes { "jwt": jwt, "user": userView } as JSON. Headers
// are already sent when encoding fails, so the error can only be logged.
func writeAuthResponse(w http.ResponseWriter, jwt interface{}, userView interface{}) {
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(map[string]interface{}{
		"jwt":  jwt,
		"user": userView,
	}); err != nil {
		fmt.Printf("failed to write auth response: %s\n", err)
	}
}

// providerIdentity is the subset of a provider's user profile passed to
// CreateOrUpdateUser.
type providerIdentity struct {
	ID          string
	DisplayName string
	AvatarURL   string
}

// getProviderJSON performs a GET on endpoint with the given headers and
// decodes the JSON body into out. Non-2xx responses are returned as errors so
// an expired or invalid token is never mistaken for an empty profile.
func getProviderJSON(ctx context.Context, client *http.Client, endpoint string, headers map[string]string, out interface{}) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}
	for name, value := range headers {
		req.Header.Set(name, value)
	}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()
	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("failed to read response body: %w", err)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return fmt.Errorf("request to %s failed: %s", endpoint, resp.Status)
	}
	if err := json.Unmarshal(body, out); err != nil {
		return fmt.Errorf("failed to unmarshal response body: %w", err)
	}
	return nil
}

// fetchTwitchIdentity returns the Twitch user that owns accessToken. Helix
// "Get Users" called with a user access token and no query parameters returns
// that user's own record, which already includes display name and avatar.
func fetchTwitchIdentity(ctx context.Context, client *http.Client, accessToken, clientID string) (providerIdentity, error) {
	var twitchUsers struct {
		Data []struct {
			ID              string `json:"id"`
			DisplayName     string `json:"display_name"`
			ProfileImageUrl string `json:"profile_image_url"`
		} `json:"data"`
	}
	err := getProviderJSON(ctx, client, "https://api.twitch.tv/helix/users", map[string]string{
		"Authorization": "Bearer " + accessToken,
		"Client-ID":     clientID,
	}, &twitchUsers)
	if err != nil {
		return providerIdentity{}, err
	}
	if len(twitchUsers.Data) == 0 || twitchUsers.Data[0].ID == "" {
		return providerIdentity{}, fmt.Errorf("no user found")
	}
	profile := twitchUsers.Data[0]
	return providerIdentity{
		ID:          profile.ID,
		DisplayName: profile.DisplayName,
		AvatarURL:   profile.ProfileImageUrl,
	}, nil
}

// fetchDiscordIdentity returns the Discord user that owns accessToken
// (GET /users/@me, which requires the "identify" scope).
func fetchDiscordIdentity(ctx context.Context, client *http.Client, accessToken string) (providerIdentity, error) {
	var discordUser struct {
		ID            string `json:"id"`
		Username      string `json:"username"`
		Discriminator string `json:"discriminator"`
		Avatar        string `json:"avatar"`
	}
	err := getProviderJSON(ctx, client, "https://discord.com/api/users/@me", map[string]string{
		"Authorization": "Bearer " + accessToken,
	}, &discordUser)
	if err != nil {
		return providerIdentity{}, err
	}
	if discordUser.ID == "" {
		return providerIdentity{}, fmt.Errorf("no user found")
	}
	avatarURL := "https://cdn.discordapp.com/avatars/" + discordUser.ID + "/" + discordUser.Avatar + ".png"
	if discordUser.Avatar == "" {
		// Users without a custom avatar have a null avatar hash; the
		// hash-based URL would be ".../<id>/.png", which does not exist.
		avatarURL = discordDefaultAvatarURL(discordUser.ID, discordUser.Discriminator)
	}
	// NOTE(review): users on Discord's new username system report
	// discriminator "0", so this produces "name#0".
	return providerIdentity{
		ID:          discordUser.ID,
		DisplayName: discordUser.Username + "#" + discordUser.Discriminator,
		AvatarURL:   avatarURL,
	}, nil
}

// discordDefaultAvatarURL returns Discord's default avatar for a user. The
// index is discriminator % 5 for legacy users, and (id >> 22) % 6 for users
// on the new username system (discriminator "0").
func discordDefaultAvatarURL(id, discriminator string) string {
	var index uint64
	if d, err := strconv.ParseUint(discriminator, 10, 64); err == nil && d != 0 {
		index = d % 5
	} else if snowflake, err := strconv.ParseUint(id, 10, 64); err == nil {
		index = (snowflake >> 22) % 6
	}
	return "https://cdn.discordapp.com/embed/avatars/" + strconv.FormatUint(index, 10) + ".png"
}