this repo has no description
at main 11 kB view raw
1use std::collections::HashSet; 2use std::fmt; 3use std::marker::PhantomData; 4use std::str::FromStr; 5 6use anyhow::{anyhow, Result}; 7use chrono::Duration; 8use serde::de::{self, MapAccess, Visitor}; 9use serde::{Deserialize, Deserializer}; 10 11#[derive(Clone, Deserialize)] 12pub struct Feeds { 13 pub feeds: Vec<Feed>, 14} 15 16#[derive(Clone, Debug, Deserialize)] 17pub struct FeedQueryLimit(pub u32); 18 19impl Default for FeedQueryLimit { 20 fn default() -> Self { 21 FeedQueryLimit(500) 22 } 23} 24 25#[derive(Clone, Debug, Deserialize)] 26#[serde(tag = "type")] 27pub enum FeedQuery { 28 #[serde(rename = "simple")] 29 Simple { 30 #[serde(default)] 31 limit: FeedQueryLimit, 32 }, 33 34 #[serde(rename = "popular")] 35 Popular { 36 #[serde(default)] 37 gravity: f64, 38 39 #[serde(default)] 40 limit: FeedQueryLimit, 41 }, 42} 43 44#[derive(Clone, Deserialize)] 45pub struct Feed { 46 pub uri: String, 47 pub name: String, 48 pub description: String, 49 50 #[serde(default)] 51 pub aturi: Option<String>, 52 53 #[serde(default)] 54 pub allow: HashSet<String>, 55 56 #[serde(default)] 57 pub deny: Option<String>, 58 59 #[serde(default, deserialize_with = "string_or_struct")] 60 pub query: FeedQuery, 61 62 pub matchers: Vec<Matcher>, 63} 64 65#[derive(Clone, Deserialize)] 66#[serde(tag = "type")] 67pub enum Matcher { 68 #[serde(rename = "equal")] 69 Equal { 70 path: String, 71 value: String, 72 aturi: Option<String>, 73 }, 74 75 #[serde(rename = "prefix")] 76 Prefix { 77 path: String, 78 value: String, 79 aturi: Option<String>, 80 }, 81 82 #[serde(rename = "sequence")] 83 Sequence { 84 path: String, 85 values: Vec<String>, 86 aturi: Option<String>, 87 }, 88 89 #[serde(rename = "rhai")] 90 Rhai { script: String }, 91} 92 93#[derive(Clone)] 94pub struct HttpPort(u16); 95 96#[derive(Clone)] 97pub struct CertificateBundles(Vec<String>); 98 99#[derive(Clone)] 100pub struct TaskEnable(bool); 101 102#[derive(Clone)] 103pub struct TaskInterval(Duration); 104 105#[derive(Clone)] 106pub struct Compression(bool); 107 108#[derive(Clone)] 109pub struct Collections(Vec<String>); 110 111#[derive(Clone)] 112pub struct Config { 113 pub version: String, 114 pub http_port: HttpPort, 115 pub external_base: String, 116 pub database_url: String, 117 pub certificate_bundles: CertificateBundles, 118 pub consumer_task_enable: TaskEnable, 119 pub cache_task_enable: TaskEnable, 120 pub cache_task_interval: TaskInterval, 121 pub cleanup_task_enable: TaskEnable, 122 pub cleanup_task_interval: TaskInterval, 123 pub cleanup_task_max_age: TaskInterval, 124 pub vmc_task_enable: TaskEnable, 125 pub plc_hostname: String, 126 pub user_agent: String, 127 pub zstd_dictionary: String, 128 pub jetstream_hostname: String, 129 pub feeds: Feeds, 130 pub compression: Compression, 131 pub collections: Collections, 132 pub feed_cache_dir: String, 133} 134 135impl Config { 136 pub fn new() -> Result<Self> { 137 let http_port: HttpPort = default_env("HTTP_PORT", "4050").try_into()?; 138 let external_base = require_env("EXTERNAL_BASE")?; 139 140 let database_url = default_env("DATABASE_URL", "sqlite://development.db"); 141 142 let certificate_bundles: CertificateBundles = 143 optional_env("CERTIFICATE_BUNDLES").try_into()?; 144 145 let jetstream_hostname = require_env("JETSTREAM_HOSTNAME")?; 146 147 let compression: Compression = default_env("COMPRESSION", "false").try_into()?; 148 149 let zstd_dictionary = if compression.0 { 150 require_env("ZSTD_DICTIONARY")? 151 } else { 152 "".to_string() 153 }; 154 155 let consumer_task_enable: TaskEnable = 156 default_env("CONSUMER_TASK_ENABLE", "true").try_into()?; 157 158 let cache_task_enable: TaskEnable = default_env("CACHE_TASK_ENABLE", "true").try_into()?; 159 160 let cache_task_interval: TaskInterval = 161 default_env("CACHE_TASK_INTERVAL", "3m").try_into()?; 162 163 let cleanup_task_enable: TaskEnable = 164 default_env("CLEANUP_TASK_ENABLE", "true").try_into()?; 165 166 let cleanup_task_interval: TaskInterval = 167 default_env("CLEANUP_TASK_INTERVAL", "1h").try_into()?; 168 169 let cleanup_task_max_age: TaskInterval = 170 default_env("CLEANUP_TASK_MAX_AGE", "48h").try_into()?; 171 172 let vmc_task_enable: TaskEnable = default_env("VMC_TASK_ENABLE", "true").try_into()?; 173 174 let plc_hostname = default_env("PLC_HOSTNAME", "plc.directory"); 175 176 let default_user_agent = format!( 177 "supercell ({}; +https://github.com/astrenoxcoop/supercell)", 178 version()? 179 ); 180 181 let user_agent = default_env("USER_AGENT", &default_user_agent); 182 183 let feeds: Feeds = require_env("FEEDS")?.try_into()?; 184 185 let collections: Collections = 186 default_env("COLLECTIONS", "app.bsky.feed.post").try_into()?; 187 188 let feed_cache_dir = optional_env("FEED_CACHE_DIR"); 189 190 Ok(Self { 191 version: version()?, 192 http_port, 193 external_base, 194 database_url, 195 certificate_bundles, 196 consumer_task_enable, 197 cache_task_enable, 198 cache_task_interval, 199 cleanup_task_enable, 200 cleanup_task_interval, 201 cleanup_task_max_age, 202 vmc_task_enable, 203 plc_hostname, 204 user_agent, 205 jetstream_hostname, 206 zstd_dictionary, 207 feeds, 208 compression, 209 collections, 210 feed_cache_dir, 211 }) 212 } 213} 214 215fn require_env(name: &str) -> Result<String> { 216 std::env::var(name) 217 .map_err(|err| anyhow::Error::new(err).context(anyhow!("{} must be set", name))) 218} 219 220fn optional_env(name: &str) -> String { 221 std::env::var(name).unwrap_or("".to_string()) 222} 223 224fn default_env(name: &str, default_value: &str) -> String { 225 std::env::var(name).unwrap_or(default_value.to_string()) 226} 227 228pub fn version() -> Result<String> { 229 option_env!("GIT_HASH") 230 .or(option_env!("CARGO_PKG_VERSION")) 231 .map(|val| val.to_string()) 232 .ok_or(anyhow!("one of GIT_HASH or CARGO_PKG_VERSION must be set")) 233} 234 235impl TryFrom<String> for HttpPort { 236 type Error = anyhow::Error; 237 fn try_from(value: String) -> Result<Self, Self::Error> { 238 if value.is_empty() { 239 Ok(Self(80)) 240 } else { 241 value.parse::<u16>().map(Self).map_err(|err| { 242 anyhow::Error::new(err).context(anyhow!("parsing PORT into u16 failed")) 243 }) 244 } 245 } 246} 247 248impl AsRef<u16> for HttpPort { 249 fn as_ref(&self) -> &u16 { 250 &self.0 251 } 252} 253 254impl TryFrom<String> for CertificateBundles { 255 type Error = anyhow::Error; 256 fn try_from(value: String) -> Result<Self, Self::Error> { 257 Ok(Self( 258 value 259 .split(';') 260 .filter_map(|s| { 261 if s.is_empty() { 262 None 263 } else { 264 Some(s.to_string()) 265 } 266 }) 267 .collect::<Vec<String>>(), 268 )) 269 } 270} 271 272impl AsRef<Vec<String>> for CertificateBundles { 273 fn as_ref(&self) -> &Vec<String> { 274 &self.0 275 } 276} 277 278impl AsRef<bool> for TaskEnable { 279 fn as_ref(&self) -> &bool { 280 &self.0 281 } 282} 283 284impl TryFrom<String> for TaskEnable { 285 type Error = anyhow::Error; 286 fn try_from(value: String) -> Result<Self, Self::Error> { 287 let value = value.parse::<bool>().map_err(|err| { 288 anyhow::Error::new(err).context(anyhow!("parsing task enable into bool failed")) 289 })?; 290 Ok(Self(value)) 291 } 292} 293 294impl AsRef<Duration> for TaskInterval { 295 fn as_ref(&self) -> &Duration { 296 &self.0 297 } 298} 299 300impl TryFrom<String> for TaskInterval { 301 type Error = anyhow::Error; 302 fn try_from(value: String) -> Result<Self, Self::Error> { 303 let duration = duration_str::parse_chrono(&value) 304 .map_err(|err| anyhow!(err).context("parsing task interval into duration failed"))?; 305 Ok(Self(duration)) 306 } 307} 308 309impl AsRef<bool> for Compression { 310 fn as_ref(&self) -> &bool { 311 &self.0 312 } 313} 314 315impl TryFrom<String> for Compression { 316 type Error = anyhow::Error; 317 fn try_from(value: String) -> Result<Self, Self::Error> { 318 let value = value.parse::<bool>().map_err(|err| { 319 anyhow::Error::new(err).context(anyhow!("parsing compression into bool failed")) 320 })?; 321 Ok(Self(value)) 322 } 323} 324 325impl TryFrom<String> for Feeds { 326 type Error = anyhow::Error; 327 fn try_from(value: String) -> Result<Self, Self::Error> { 328 let content = std::fs::read(value).map_err(|err| { 329 anyhow::Error::new(err).context(anyhow!("reading feed config file failed")) 330 })?; 331 332 serde_yaml::from_slice(&content).map_err(|err| { 333 anyhow::Error::new(err).context(anyhow!("parsing feeds into Feeds failed")) 334 }) 335 } 336} 337 338impl TryFrom<String> for Collections { 339 type Error = anyhow::Error; 340 fn try_from(value: String) -> Result<Self, Self::Error> { 341 Ok(Self( 342 value 343 .split(',') 344 .filter_map(|s| { 345 if s.is_empty() { 346 None 347 } else { 348 Some(s.to_string()) 349 } 350 }) 351 .collect::<Vec<String>>(), 352 )) 353 } 354} 355 356impl AsRef<Vec<String>> for Collections { 357 fn as_ref(&self) -> &Vec<String> { 358 &self.0 359 } 360} 361 362impl Default for FeedQuery { 363 fn default() -> Self { 364 FeedQuery::Simple { 365 limit: FeedQueryLimit::default(), 366 } 367 } 368} 369 370impl FromStr for FeedQuery { 371 type Err = anyhow::Error; 372 373 fn from_str(value: &str) -> Result<Self, Self::Err> { 374 match value { 375 "simple" => Ok(FeedQuery::Simple { 376 limit: FeedQueryLimit::default(), 377 }), 378 "popular" => Ok(FeedQuery::Popular { 379 gravity: 1.8, 380 limit: FeedQueryLimit::default(), 381 }), 382 _ => Err(anyhow!("unsupported query")), 383 } 384 } 385} 386 387fn string_or_struct<'de, T, D>(deserializer: D) -> Result<T, D::Error> 388where 389 T: Deserialize<'de> + FromStr<Err = anyhow::Error>, 390 D: Deserializer<'de>, 391{ 392 struct StringOrStruct<T>(PhantomData<fn() -> T>); 393 394 impl<'de, T> Visitor<'de> for StringOrStruct<T> 395 where 396 T: Deserialize<'de> + FromStr<Err = anyhow::Error>, 397 { 398 type Value = T; 399 400 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { 401 formatter.write_str("string or FeedQuery") 402 } 403 404 fn visit_str<E>(self, value: &str) -> Result<T, E> 405 where 406 E: de::Error, 407 { 408 FromStr::from_str(value).map_err(|_| de::Error::custom("cannot deserialize field")) 409 } 410 411 fn visit_map<M>(self, map: M) -> Result<T, M::Error> 412 where 413 M: MapAccess<'de>, 414 { 415 Deserialize::deserialize(de::value::MapAccessDeserializer::new(map)) 416 } 417 } 418 419 deserializer.deserialize_any(StringOrStruct(PhantomData)) 420} 421 422impl AsRef<u32> for FeedQueryLimit { 423 fn as_ref(&self) -> &u32 { 424 &self.0 425 } 426}