this repo has no description
1mod config;
2mod database;
3use std::{env, net::SocketAddr};
4
5use axum::{
6 Json, Router,
7 http::HeaderValue,
8 response::IntoResponse,
9 routing::{get, post},
10};
11use hyper::{Method, StatusCode, header};
12use serde::{Deserialize, Serialize};
13use serde_json::json;
14use tower_http::cors::CorsLayer;
15use tower_sessions::{
16 Expiry, Session, SessionManagerLayer,
17 cookie::{SameSite, time::Duration},
18};
19use tower_sessions_redis_store::{
20 RedisStore,
21 fred::prelude::{ClientLike, Config, Pool},
22};
23use url::Url;
24
25#[tokio::main]
26async fn main() -> Result<(), Box<dyn std::error::Error>> {
27 dotenv::dotenv().ok();
28 database::init().unwrap();
29 let config = Config::from_url_centralized(env::var("REDIS_URL").unwrap().as_str()).unwrap();
30 let pool = Pool::new(config, None, None, None, 6)?;
31 let redis_conn = pool.connect();
32 pool.wait_for_connect().await?;
33
34 let session_store = RedisStore::new(pool);
35 let session_layer = SessionManagerLayer::new(session_store)
36 .with_secure(true)
37 .with_same_site(SameSite::None)
38 .with_expiry(Expiry::OnInactivity(Duration::minutes(30)));
39
40 let cors = CorsLayer::new()
41 .allow_origin("http://localhost:3000".parse::<HeaderValue>().unwrap())
42 .allow_methods(vec![Method::GET, Method::POST])
43 .allow_credentials(true);
44 let app = Router::new()
45 .route("/auth", get(auth_handler))
46 .route("/auth_url", get(auth_url_handler))
47 .route("/callback", post(auth_callback))
48 .route("/channel", get(user_handler))
49 .layer(session_layer)
50 .layer(cors);
51 let addr = SocketAddr::from(([127, 0, 0, 1], 5173));
52 println!("Server listening on {}", addr);
53
54 let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
55 axum::serve(listener, app.into_make_service()).await?;
56
57 redis_conn.await??;
58 Ok(())
59}
60
61async fn auth_url_handler(session: Session) -> impl IntoResponse {
62 let state = nanoid::nanoid!();
63 match session.insert("state", &state).await {
64 Ok(_) => {
65 let twitch_client_id = env::var("TWITCH_CLIENT_ID").unwrap();
66 let twitch_redirect_uri = env::var("TWITCH_REDIRECT_URI").unwrap();
67 let scopes = vec!["user:read:email"];
68
69 let mut auth_url = Url::parse("https://id.twitch.tv/oauth2/authorize").unwrap();
70 auth_url
71 .query_pairs_mut()
72 .append_pair("client_id", &twitch_client_id)
73 .append_pair("redirect_uri", &twitch_redirect_uri)
74 .append_pair("response_type", "code")
75 .append_pair("scope", &scopes.join(" "))
76 .append_pair("state", &state);
77
78 let mut headers = axum::http::HeaderMap::new();
79 headers.insert(
80 header::LOCATION,
81 HeaderValue::from_str(auth_url.as_str()).unwrap(),
82 );
83
84 (StatusCode::FOUND, headers, ()).into_response()
85 }
86 Err(_) => (
87 StatusCode::INTERNAL_SERVER_ERROR,
88 "Unable to create session".to_string(),
89 )
90 .into_response(),
91 }
92}
93
94#[derive(Deserialize, Serialize)]
95struct UserSession {
96 access_token: String,
97 refresh_token: String,
98 state: String,
99 channel: String,
100}
101
102async fn user_handler(session: Session) -> impl IntoResponse {
103 match session.get_value("user_session").await {
104 Ok(Some(user_session_json)) => {
105 let Ok(user_session) = serde_json::from_value::<UserSession>(user_session_json) else {
106 return (
107 StatusCode::INTERNAL_SERVER_ERROR,
108 "Error deserializing user session".to_string(),
109 );
110 };
111
112 (StatusCode::OK, user_session.channel)
113 }
114 Ok(None) => (
115 StatusCode::UNAUTHORIZED,
116 "User session not found".to_string(),
117 ), // TODO: maybe have the frontend redirect to /auth_url upon receiving this to re-authenticate
118 Err(e) => (
119 StatusCode::INTERNAL_SERVER_ERROR,
120 format!("Error retrieving session: {}", e),
121 ),
122 }
123}
124
125async fn auth_handler() -> impl IntoResponse {
126 let twitch_client_id = env::var("TWITCH_CLIENT_ID").unwrap();
127 let twitch_redirect_uri = env::var("TWITCH_REDIRECT_URI").unwrap();
128 let scopes = vec!["user:read:email"];
129 let state = nanoid::nanoid!();
130
131 let mut auth_url = Url::parse("https://id.twitch.tv/oauth2/authorize").unwrap();
132 auth_url
133 .query_pairs_mut()
134 .append_pair("client_id", &twitch_client_id)
135 .append_pair("redirect_uri", &twitch_redirect_uri)
136 .append_pair("response_type", "code")
137 .append_pair("scope", &scopes.join(" "))
138 .append_pair("state", &state);
139
140 let mut headers = axum::http::HeaderMap::new();
141 headers.insert(header::LOCATION, auth_url.as_str().parse().unwrap());
142 headers.insert(
143 header::SET_COOKIE,
144 format!("state={}; HttpOnly; Secure", state)
145 .parse()
146 .unwrap(),
147 );
148
149 (StatusCode::FOUND, headers, ())
150}
151
152#[derive(Deserialize, Serialize)]
153struct AuthTokenRequestResponse {
154 access_token: String,
155 refresh_token: String,
156 expires_in: u64,
157 scope: Vec<String>,
158 token_type: String,
159}
160
161#[derive(Deserialize, Serialize)]
162struct AuthCallbackRequest {
163 code: String,
164 state: String,
165 scope: String,
166}
167async fn auth_callback(
168 session: Session,
169 Json(payload): Json<AuthCallbackRequest>,
170) -> impl IntoResponse {
171 let code = payload.code;
172 let request_state = payload.state;
173 let session_state: String = session
174 .get("state")
175 .await
176 .unwrap()
177 .unwrap_or("".to_string());
178 if request_state != session_state {
179 println!("State mismatch");
180 return (StatusCode::INTERNAL_SERVER_ERROR, "State mismatch").into_response();
181 }
182
183 let request_client = reqwest::Client::builder().build().unwrap();
184 match request_client
185 .post("https://id.twitch.tv/oauth2/token")
186 .form(&[
187 ("client_id", env::var("TWITCH_CLIENT_ID").unwrap()),
188 ("client_secret", env::var("TWITCH_CLIENT_SECRET").unwrap()),
189 ("code", code.to_string()),
190 ("grant_type", "authorization_code".to_string()),
191 ("redirect_uri", env::var("TWITCH_REDIRECT_URI").unwrap()),
192 ("state", session_state.to_string()),
193 ])
194 .send()
195 .await
196 {
197 Ok(response) => {
198 let token_response: AuthTokenRequestResponse =
199 serde_json::from_str(&response.text().await.unwrap()).unwrap();
200 match validate_token(&token_response.access_token).await {
201 Some(validated_token) => {
202 session
203 .insert_value(
204 "user_session",
205 json!({
206 "access_token": token_response.access_token,
207 "refresh_token": token_response.refresh_token,
208 "state": session_state,
209 "channel": validated_token.login,
210 }),
211 )
212 .await
213 .unwrap();
214 println!("Session created successfully.");
215 (StatusCode::OK, ()).into_response()
216 }
217 None => (
218 StatusCode::INTERNAL_SERVER_ERROR,
219 "Token is invalid".to_string(),
220 )
221 .into_response(),
222 }
223 }
224 Err(e) => {
225 println!("Error: {}", e);
226 (StatusCode::INTERNAL_SERVER_ERROR, format!("Error: {}", e)).into_response()
227 }
228 }
229}
230
231// TODO: implement this for invalid tokens
232async fn logout(token: &str, session: Session) {
233 let request_client = reqwest::Client::builder().build().unwrap();
234 match request_client
235 .post("https://id.twitch.tv/oauth2/revoke")
236 .form(&[
237 ("client_id", env::var("TWITCH_CLIENT_ID").unwrap()),
238 ("token", token.to_string()),
239 ])
240 .send()
241 .await
242 {
243 Ok(response) => {
244 let status = response.status();
245 if status.is_success() {
246 println!("Token revoked successfully.");
247 let _: Option<serde_json::Value> =
248 session.remove_value("user_session").await.unwrap();
249 println!("Session removed successfully.");
250 } else {
251 let json_response: serde_json::Value =
252 serde_json::from_str(&response.text().await.unwrap()).unwrap();
253 println!(
254 "Failed to revoke token. Status: {} for error: {}",
255 status,
256 json_response["message"].as_str().unwrap()
257 );
258 }
259 }
260 Err(e) => {
261 println!("Error revoking token: {}", e);
262 }
263 }
264}
265
266#[derive(Deserialize, Serialize)]
267struct ValidatedTokenResponse {
268 client_id: String,
269 login: String,
270 user_id: String,
271 expires_in: u64,
272 scopes: Vec<String>,
273}
274async fn validate_token(token: &str) -> Option<ValidatedTokenResponse> {
275 let request_client = reqwest::Client::builder().build().unwrap();
276 match request_client
277 .get("https://id.twitch.tv/oauth2/validate")
278 .header("Authorization", format!("OAuth {}", token))
279 .send()
280 .await
281 {
282 Ok(response) => Some(serde_json::from_str(&response.text().await.unwrap()).unwrap()),
283 Err(_) => None,
284 }
285}