// at main · 208 lines · 6.6 kB · view raw
1use core::fmt; 2use std::str::FromStr; 3 4use jacquard::{ 5 CowStr, IntoStatic, 6 oauth::{ 7 atproto::{AtprotoClientMetadata, GrantType}, 8 scopes::Scope, 9 }, 10 smol_str::{SmolStr, ToSmolStr}, 11 url::Url, 12}; 13 14use crate::env; 15 16#[derive(Debug, Clone)] 17pub struct Config { 18 pub oauth: AtprotoClientMetadata<'static>, 19} 20 21#[derive(Debug, Clone)] 22pub struct OAuthConfig { 23 pub client_id: jacquard::url::Url, 24 pub redirect_uri: jacquard::url::Url, 25 pub scopes: Vec<Scope<'static>>, 26 pub client_name: SmolStr, 27 pub client_uri: Option<jacquard::url::Url>, 28 pub logo_uri: Option<jacquard::url::Url>, 29 pub tos_uri: Option<jacquard::url::Url>, 30 pub privacy_policy_uri: Option<jacquard::url::Url>, 31} 32 33impl OAuthConfig { 34 /// This will panic if something is incorrect. You kind of can't proceed if these aren't a certain way, so... 35 pub fn new( 36 client_id: jacquard::url::Url, 37 redirect_uri: jacquard::url::Url, 38 scopes: Vec<Scope<'static>>, 39 client_name: SmolStr, 40 client_uri: Option<jacquard::url::Url>, 41 logo_uri: Option<jacquard::url::Url>, 42 tos_uri: Option<jacquard::url::Url>, 43 privacy_policy_uri: Option<jacquard::url::Url>, 44 ) -> Self { 45 let scopes = if scopes.is_empty() { 46 vec![ 47 Scope::Atproto, 48 Scope::Transition(jacquard::oauth::scopes::TransitionScope::Generic), 49 ] 50 } else { 51 scopes 52 }; 53 if let Some(client_uri) = &client_uri { 54 if let Some(client_uri_host) = client_uri.host_str() { 55 if client_uri_host != client_id.host_str().expect("client_id must have a host") { 56 panic!("client_uri host must match client_id host"); 57 } 58 } 59 } 60 if let Some(logo_uri) = &logo_uri { 61 if logo_uri.scheme() != "https" { 62 panic!("logo_uri scheme must be https"); 63 } 64 } 65 if let Some(tos_uri) = &tos_uri { 66 if tos_uri.scheme() != "https" { 67 panic!("tos_uri scheme must be https"); 68 } 69 } 70 if let Some(privacy_policy_uri) = &privacy_policy_uri { 71 if privacy_policy_uri.scheme() != "https" { 72 
panic!("privacy_policy_uri scheme must be https"); 73 } 74 } 75 Self { 76 client_id, 77 redirect_uri, 78 scopes, 79 client_name, 80 client_uri, 81 logo_uri, 82 tos_uri, 83 privacy_policy_uri, 84 } 85 } 86 87 pub fn new_dev(port: u32, scopes: Vec<Scope<'static>>, client_name: SmolStr) -> Self { 88 // determine client_id 89 #[derive(serde::Serialize)] 90 struct Parameters<'a> { 91 #[serde(skip_serializing_if = "Option::is_none")] 92 redirect_uri: Option<Vec<Url>>, 93 #[serde(skip_serializing_if = "Option::is_none")] 94 scope: Option<CowStr<'a>>, 95 } 96 let redirect_uri: Url = format!("http://127.0.0.1:{port}/callback").parse().unwrap(); 97 let query = serde_html_form::to_string(Parameters { 98 redirect_uri: Some(vec![redirect_uri.clone()]), 99 scope: Some(Scope::serialize_multiple(scopes.as_slice())), 100 }) 101 .ok(); 102 let mut client_id = String::from("http://localhost"); 103 if let Some(query) = query 104 && !query.is_empty() 105 { 106 client_id.push_str(&format!("?{query}")); 107 }; 108 Self::new( 109 client_id.parse().unwrap(), 110 redirect_uri, 111 scopes, 112 client_name, 113 None, 114 None, 115 None, 116 None, 117 ) 118 } 119 120 pub fn from_env() -> Self { 121 let app_env = AppEnv::from_str(env::WEAVER_APP_ENV).unwrap_or(AppEnv::Dev); 122 123 if app_env == AppEnv::Dev { 124 Self::new_dev( 125 env::WEAVER_PORT.parse().unwrap_or(8080), 126 Scope::parse_multiple(env::WEAVER_APP_SCOPES) 127 .unwrap_or(vec![]) 128 .into_static(), 129 env::WEAVER_CLIENT_NAME.to_smolstr(), 130 ) 131 } else { 132 let host = env::WEAVER_APP_HOST; 133 let client_id = format!("{host}/oauth-client-metadata.json"); 134 let redirect_uri = format!("{host}/callback"); 135 let logo_uri = if env::WEAVER_LOGO_URI.is_empty() { 136 None 137 } else { 138 Url::parse(env::WEAVER_LOGO_URI).ok() 139 }; 140 let tos_uri = if env::WEAVER_TOS_URI.is_empty() { 141 None 142 } else { 143 Url::parse(env::WEAVER_TOS_URI).ok() 144 }; 145 let privacy_policy_uri = if env::WEAVER_PRIVACY_POLICY_URI.is_empty() 
{ 146 None 147 } else { 148 Url::parse(env::WEAVER_PRIVACY_POLICY_URI).ok() 149 }; 150 Self::new( 151 Url::parse(&client_id).expect("Failed to parse client ID as valid URL"), 152 Url::parse(&redirect_uri).expect("Failed to parse redirect URI as valid URL"), 153 Scope::parse_multiple(env::WEAVER_APP_SCOPES) 154 .unwrap_or(vec![]) 155 .into_static(), 156 env::WEAVER_CLIENT_NAME.to_smolstr(), 157 Some(Url::parse(&host).expect("Failed to parse host as valid URL")), 158 logo_uri, 159 tos_uri, 160 privacy_policy_uri, 161 ) 162 } 163 } 164 165 pub fn as_metadata(self) -> AtprotoClientMetadata<'static> { 166 AtprotoClientMetadata::new( 167 self.client_id, 168 self.client_uri, 169 vec![self.redirect_uri], 170 vec![GrantType::AuthorizationCode, GrantType::RefreshToken], 171 self.scopes, 172 None, 173 ) 174 .with_prod_info( 175 self.client_name.as_str(), 176 self.logo_uri, 177 self.tos_uri, 178 self.privacy_policy_uri, 179 ) 180 } 181} 182 183#[derive(PartialEq)] 184enum AppEnv { 185 Dev, 186 Prod, 187} 188 189impl std::str::FromStr for AppEnv { 190 type Err = String; 191 192 fn from_str(s: &str) -> Result<Self, Self::Err> { 193 match s { 194 "dev" => Ok(Self::Dev), 195 "prod" => Ok(Self::Prod), 196 s => Err(format!("Invalid AppEnv: {s}")), 197 } 198 } 199} 200 201impl fmt::Display for AppEnv { 202 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 203 match self { 204 AppEnv::Dev => write!(f, "dev"), 205 AppEnv::Prod => write!(f, "prod"), 206 } 207 } 208}