The server for Open Course World
1package api
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "io/ioutil"
8 "net/http"
9 "net/url"
10 "smm2_gameserver/config"
11 "smm2_gameserver/orm"
12 "strings"
13
14 "github.com/gorilla/mux"
15 "golang.org/x/oauth2"
16)
17
// originOf returns the origin of urlStr as "scheme://host[:port]", where the
// host part is formatted by hostOfURL (port omitted when it is the scheme's
// default).
//
// It panics if urlStr cannot be parsed; callers handling untrusted input must
// validate it with url.Parse first.
func originOf(urlStr string) string {
	u, err := url.Parse(urlStr)
	if err != nil {
		panic(err)
	}
	return u.Scheme + "://" + hostOfURL(u)
}

// hostOf returns the host of urlStr in the form a browser sends in the Host
// header: the hostname (bracketed if it is an IPv6 literal) followed by
// ":port" unless the port is absent or is the default for the URL's scheme.
//
// It panics if urlStr cannot be parsed.
func hostOf(urlStr string) string {
	u, err := url.Parse(urlStr)
	if err != nil {
		panic(err)
	}
	return hostOfURL(u)
}

// hostOfURL formats the host of an already-parsed URL for hostOf and originOf.
func hostOfURL(u *url.URL) string {
	host := u.Hostname()
	// Hostname strips the brackets from IPv6 literals; restore them so the
	// result matches a Host header and can safely carry a ":port" suffix.
	if strings.Contains(host, ":") {
		host = "[" + host + "]"
	}

	port := u.Port()
	// A port is only redundant when it is the default for this scheme:
	// http://x:443 and http://x are different origins.
	switch {
	case port == "":
		return host
	case port == "80" && (u.Scheme == "http" || u.Scheme == "ws"):
		return host
	case port == "443" && (u.Scheme == "https" || u.Scheme == "wss"):
		return host
	}
	return host + ":" + port
}
41
42func getCallbackUrl(urlOrHost string, cfg *config.Config) (*url.URL, error) {
43 if strings.HasPrefix(urlOrHost, "http://") || strings.HasPrefix(urlOrHost, "https://") {
44 parsed, err := url.Parse(urlOrHost)
45 if err != nil {
46 return nil, fmt.Errorf("invalid URL: %s (%s)", urlOrHost, err)
47 }
48 return getCallbackUrl(parsed.Host, cfg)
49 }
50 for _, appUrl := range cfg.AppUrls {
51 parsed, err := url.Parse(appUrl)
52 if err != nil {
53 return nil, fmt.Errorf("invalid URL: %s (%s)", appUrl, err)
54 }
55 if parsed.Host == urlOrHost {
56 return parsed, nil
57 }
58 }
59 return nil, fmt.Errorf("no callback URL found for %s", urlOrHost)
60}
61
62func getRedirectUrl(urlOrHost string, cfg *config.Config) (*url.URL, error) {
63 if strings.HasPrefix(urlOrHost, "http://") || strings.HasPrefix(urlOrHost, "https://") {
64 parsed, err := url.Parse(urlOrHost)
65 if err != nil {
66 return nil, fmt.Errorf("invalid URL: %s (%s)", urlOrHost, err)
67 }
68 return getRedirectUrl(parsed.Host, cfg)
69 }
70 for _, apiUrl := range cfg.ApiUrls {
71 parsed, err := url.Parse(apiUrl)
72 if err != nil {
73 return nil, fmt.Errorf("invalid URL: %s (%s)", apiUrl, err)
74 }
75 if parsed.Host == urlOrHost {
76 return parsed, nil
77 }
78 }
79 return nil, fmt.Errorf("no redirect URL found for %s", urlOrHost)
80}
81
82func initAuth() {
83 hostProviders := map[string]map[string]*oauth2.Config{}
84
85 // Loop through cfg.RedirectUrls. For each url, get the origin and add a Twitch and Discord provider to the map.
86 for _, redirectUrl := range cfg.ApiUrls {
87 host := hostOf(redirectUrl)
88 hostProviders[host] = map[string]*oauth2.Config{
89 "twitch": {
90 ClientID: cfg.TwitchClientId,
91 ClientSecret: cfg.TwitchClientSecret,
92 RedirectURL: originOf(redirectUrl) + "/api/connect/twitch/callback",
93 Scopes: []string{},
94 Endpoint: oauth2.Endpoint{
95 AuthURL: "https://id.twitch.tv/oauth2/authorize",
96 TokenURL: "https://id.twitch.tv/oauth2/token",
97 },
98 },
99 "discord": {
100 ClientID: cfg.DiscordClientId,
101 ClientSecret: cfg.DiscordClientSecret,
102 RedirectURL: originOf(redirectUrl) + "/api/connect/discord/callback",
103 Scopes: []string{"identify"},
104 Endpoint: oauth2.Endpoint{
105 AuthURL: "https://discord.com/api/oauth2/authorize",
106 TokenURL: "https://discord.com/api/oauth2/token",
107 },
108 },
109 }
110 }
111
112 fmt.Println("[hostProviders]:")
113
114 // Log the URL and method of each request.
115
116 router.Use(func(next http.Handler) http.Handler {
117 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
118 fmt.Printf("%s %s\n", r.Method, r.URL)
119 next.ServeHTTP(w, r)
120 })
121 })
122
123 Insecure("/api/hello-world", func(w http.ResponseWriter, r *http.Request) {
124 fmt.Fprintf(w, "Hello World!")
125 }).Methods("GET")
126
127 Insecure("/api/connect/{provider}", func(w http.ResponseWriter, r *http.Request) {
128 _, err := getCallbackUrl(r.Header.Get("Referer"), cfg)
129 if err != nil {
130 reportError(w, r, fmt.Errorf("invalid referrer URL: %s", err))
131 }
132 provider := mux.Vars(r)["provider"]
133 providerCfg, ok := hostProviders[r.Host][provider]
134 if !ok {
135 fmt.Printf("no provider %s for host %s\n", provider, r.Host)
136 w.WriteHeader(http.StatusNotFound)
137 return
138 }
139 // Choose a callback URL from cfg that matches the origin of the referer.
140 // If there is no match, fail.
141 // The state will be the referer URL.
142 // Redirect to the provider's login URL.
143
144 http.Redirect(w, r, providerCfg.AuthCodeURL(r.Header.Get("Referer")), http.StatusFound)
145 }).Methods("GET")
146
147 Insecure("/api/connect/{provider}/callback", func(w http.ResponseWriter, r *http.Request) {
148 provider := mux.Vars(r)["provider"]
149 providerCfg, ok := hostProviders[r.Host][provider]
150 if !ok {
151 fmt.Printf("no provider %s for host %s\n", provider, r.Host)
152 w.WriteHeader(http.StatusNotFound)
153 return
154 }
155
156 if !ok {
157 fmt.Printf("no provider %s for host %s\n", provider, r.Host)
158 w.WriteHeader(http.StatusNotFound)
159 return
160 }
161 // Get the state from the query params.
162 // If the state's origin does not match a callback URL in cfg, fail.
163
164 state := r.URL.Query().Get("state")
165 stateOrigin := originOf(state)
166 var redirectUrl *url.URL
167 for _, cbUrl := range cfg.AppUrls {
168 if originOf(cbUrl) == stateOrigin {
169 parsed, err := url.Parse(cbUrl)
170 if err != nil {
171 reportError(w, r, fmt.Errorf("failed to parse callback url %s: %w", cbUrl, err))
172 return
173 }
174 redirectUrl = parsed
175 break
176 }
177 }
178 if redirectUrl == nil {
179 reportError(w, r, fmt.Errorf("no callback url for origin %s", stateOrigin))
180 return
181 }
182
183 // Verify redirectUrl is in cfg.AppUrls.
184 // If not, fail.
185 callbackUrl, err := getCallbackUrl(state, cfg)
186 if err != nil {
187 reportError(w, r, fmt.Errorf("invalid redirect URL: %s", err))
188 return
189 }
190
191 // Get the code from the query params.
192 // Exchange the code for an access token.
193 // Redirect to the state URL with the access token as a query param.
194
195 code := r.URL.Query().Get("code")
196
197 token, err := providerCfg.Exchange(context.Background(), code)
198 if err != nil {
199 reportError(w, r, fmt.Errorf("failed to exchange code for token: %w", err))
200 return
201 }
202
203 callbackUrl.Path = fmt.Sprintf("/auth/%s/callback", provider)
204 query := callbackUrl.Query()
205 query.Set("access_token", token.AccessToken)
206 query.Set("refresh_token", token.RefreshToken)
207 query.Set("redirect_uri", redirectUrl.RawPath)
208 callbackUrl.RawQuery = query.Encode()
209
210 http.Redirect(w, r, callbackUrl.String(), http.StatusFound)
211 }).Methods("GET")
212
213 Insecure("/api/auth/{provider}/callback", func(w http.ResponseWriter, r *http.Request) {
214 provider := mux.Vars(r)["provider"]
215 // We will use the access token to get the user's ID and profile
216 // information from the provider.
217 // We will return the user's ID and profile information to the client.
218
219 accessToken := r.URL.Query().Get("access_token")
220 if accessToken == "" {
221 // Return `400 Bad Request: Required parameter "access_token" is missing.`
222 w.WriteHeader(http.StatusBadRequest)
223 fmt.Fprintf(w, "Required parameter \"access_token\" is missing.")
224 return
225 }
226
227 // Get the user's ID and profile information from the provider.
228 // Return the user's ID and profile information to the client.
229
230 var user *orm.User
231
232 switch provider {
233 case "twitch":
234 // Get the user's ID and profile information from Twitch.
235 // Return the user's ID and profile information to the client.
236 httpClient := &http.Client{}
237 req, err := http.NewRequest("GET", "https://api.twitch.tv/helix/users", nil)
238 if err != nil {
239 reportError(w, r, fmt.Errorf("failed to create request: %w", err))
240 return
241 }
242 req.Header.Set("Authorization", "Bearer "+accessToken)
243 req.Header.Set("Client-ID", cfg.TwitchClientId)
244 resp, err := httpClient.Do(req)
245 if err != nil {
246 reportError(w, r, fmt.Errorf("failed to send request: %w", err))
247 return
248 }
249 defer resp.Body.Close()
250 body, err := ioutil.ReadAll(resp.Body)
251 if err != nil {
252 reportError(w, r, fmt.Errorf("failed to read response body: %w", err))
253 return
254 }
255 var twitchUser struct {
256 Data []struct {
257 ID string `json:"id"`
258 } `json:"data"`
259 }
260 if err := json.Unmarshal(body, &twitchUser); err != nil {
261 reportError(w, r, fmt.Errorf("failed to unmarshal response body: %w", err))
262 return
263 }
264 if len(twitchUser.Data) == 0 {
265 reportError(w, r, fmt.Errorf("no user found"))
266 return
267 }
268 userId := twitchUser.Data[0].ID
269 // Get the user's profile information from Twitch.
270 req, err = http.NewRequest("GET", "https://api.twitch.tv/helix/users?id="+userId, nil)
271 if err != nil {
272 reportError(w, r, fmt.Errorf("failed to create request: %w", err))
273 return
274 }
275 req.Header.Set("Authorization", "Bearer "+accessToken)
276 req.Header.Set("Client-ID", cfg.TwitchClientId)
277 resp, err = httpClient.Do(req)
278 if err != nil {
279 reportError(w, r, fmt.Errorf("failed to send request: %w", err))
280 return
281 }
282 defer resp.Body.Close()
283 body, err = ioutil.ReadAll(resp.Body)
284 if err != nil {
285 reportError(w, r, fmt.Errorf("failed to read response body: %w", err))
286 return
287 }
288 var twitchProfile struct {
289 Data []struct {
290 DisplayName string `json:"display_name"`
291 ProfileImageUrl string `json:"profile_image_url"`
292 } `json:"data"`
293 }
294 if err := json.Unmarshal(body, &twitchProfile); err != nil {
295 reportError(w, r, fmt.Errorf("failed to unmarshal response body: %w", err))
296 return
297 }
298 if len(twitchProfile.Data) == 0 {
299 reportError(w, r, fmt.Errorf("no user found"))
300 return
301 }
302 profile := twitchProfile.Data[0]
303 // Return the user's ID and profile information to the client.
304 w.Header().Set("Content-Type", "application/json")
305
306 user, err = CreateOrUpdateUser("twitch", userId, profile.DisplayName, profile.ProfileImageUrl)
307 if err != nil {
308 reportError(w, r, fmt.Errorf("failed to create user: %w", err))
309 return
310 }
311
312 case "discord":
313 // Get the user's ID and profile information from Discord.
314 // Return the user's ID and profile information to the client.
315 httpClient := &http.Client{}
316 req, err := http.NewRequest("GET", "https://discord.com/api/users/@me", nil)
317 if err != nil {
318 reportError(w, r, fmt.Errorf("failed to create request: %w", err))
319 return
320 }
321 req.Header.Set("Authorization", "Bearer "+accessToken)
322 resp, err := httpClient.Do(req)
323 if err != nil {
324 reportError(w, r, fmt.Errorf("failed to send request: %w", err))
325 return
326 }
327 defer resp.Body.Close()
328 body, err := ioutil.ReadAll(resp.Body)
329 if err != nil {
330 reportError(w, r, fmt.Errorf("failed to read response body: %w", err))
331 return
332 }
333 var discordUser struct {
334 ID string `json:"id"`
335 Username string `json:"username"`
336 Discriminator string `json:"discriminator"`
337 Avatar string `json:"avatar"`
338 }
339 if err := json.Unmarshal(body, &discordUser); err != nil {
340 reportError(w, r, fmt.Errorf("failed to unmarshal response body: %w", err))
341 return
342 }
343 // Return the user's ID and profile information to the client.
344 w.Header().Set("Content-Type", "application/json")
345
346 user, err = CreateOrUpdateUser("discord",
347 discordUser.ID,
348 discordUser.Username+"#"+discordUser.Discriminator,
349 "https://cdn.discordapp.com/avatars/"+discordUser.ID+"/"+discordUser.Avatar+".png",
350 )
351 if err != nil {
352 reportError(w, r, fmt.Errorf("failed to create user: %w", err))
353 return
354 }
355
356 default:
357 fmt.Printf("Unknown provider: %s\n", provider)
358 w.WriteHeader(http.StatusNotFound)
359 return
360 }
361
362 // The user has been created or updated. Create a JWT for them.
363 token, err := createToken(user.ID)
364
365 if err != nil {
366 reportError(w, r, fmt.Errorf("failed to create token: %w", err))
367 return
368 }
369
370 // Return { token, user } to the client.
371 w.Header().Set("Content-Type", "application/json")
372 json.NewEncoder(w).Encode(map[string]interface{}{
373 "jwt": token,
374 "user": user.View(),
375 })
376 }).Methods("GET")
377
378 Secure("/api/auth/refresh", func(w http.ResponseWriter, r *http.Request, user orm.User) {
379 // The user has been created or updated. Create a JWT for them.
380 token, err := createToken(user.ID)
381
382 if err != nil {
383 reportError(w, r, fmt.Errorf("failed to create token: %w", err))
384 return
385 }
386
387 // Return { token, user } to the client.
388 w.Header().Set("Content-Type", "application/json")
389 json.NewEncoder(w).Encode(map[string]interface{}{
390 "jwt": token,
391 "user": user.View(),
392 })
393 }).Methods("GET")
394}