this repo has no description
1use std::collections::HashSet;
2use std::fmt;
3use std::marker::PhantomData;
4use std::str::FromStr;
5
6use anyhow::{anyhow, Result};
7use chrono::Duration;
8use serde::de::{self, MapAccess, Visitor};
9use serde::{Deserialize, Deserializer};
10
/// Root of the feed configuration file: the set of feeds this instance serves.
/// Loaded from the YAML file named by the `FEEDS` environment variable.
#[derive(Clone, Deserialize)]
pub struct Feeds {
    pub feeds: Vec<Feed>,
}
15
/// Maximum number of results a feed query may return; defaults to 500
/// (see the `Default` impl below).
#[derive(Clone, Debug, Deserialize)]
pub struct FeedQueryLimit(pub u32);
18
19impl Default for FeedQueryLimit {
20 fn default() -> Self {
21 FeedQueryLimit(500)
22 }
23}
24
/// Strategy used to select posts for a feed, tagged by `type` in the YAML
/// config (`"simple"` or `"popular"`).
#[derive(Clone, Debug, Deserialize)]
#[serde(tag = "type")]
pub enum FeedQuery {
    /// Plain query capped at `limit` results.
    /// NOTE(review): ordering/selection semantics live in the query layer —
    /// confirm there.
    #[serde(rename = "simple")]
    Simple {
        #[serde(default)]
        limit: FeedQueryLimit,
    },

    /// Popularity-weighted query; `gravity` presumably controls decay —
    /// confirm against the scoring code.
    #[serde(rename = "popular")]
    Popular {
        // NOTE(review): serde's default for f64 is 0.0 here, but the FromStr
        // path ("popular" as a bare string) uses gravity 1.8 — a map-form
        // config omitting `gravity` gets a different value. Verify intended.
        #[serde(default)]
        gravity: f64,

        #[serde(default)]
        limit: FeedQueryLimit,
    },
}
43
/// A single feed definition from the configuration file.
#[derive(Clone, Deserialize)]
pub struct Feed {
    /// Public URI identifying the feed.
    pub uri: String,
    pub name: String,
    pub description: String,

    /// Optional AT-URI for the feed record.
    /// NOTE(review): consumed elsewhere — confirm semantics against callers.
    #[serde(default)]
    pub aturi: Option<String>,

    /// Identifiers allowed for this feed (presumably DIDs — verify against
    /// the matcher code); empty set means no allow-list.
    #[serde(default)]
    pub allow: HashSet<String>,

    #[serde(default)]
    pub deny: Option<String>,

    /// Query strategy; accepts either a bare string ("simple"/"popular") or
    /// a full map, via the `string_or_struct` deserializer below.
    #[serde(default, deserialize_with = "string_or_struct")]
    pub query: FeedQuery,

    /// Matchers that decide which events belong to this feed.
    pub matchers: Vec<Matcher>,
}
64
/// A rule that decides whether an event belongs to a feed, tagged by `type`
/// in the YAML config. Matching itself is implemented elsewhere.
#[derive(Clone, Deserialize)]
#[serde(tag = "type")]
pub enum Matcher {
    /// Value at `path` must equal `value` exactly.
    #[serde(rename = "equal")]
    Equal {
        path: String,
        value: String,
        aturi: Option<String>,
    },

    /// Value at `path` must start with `value`.
    #[serde(rename = "prefix")]
    Prefix {
        path: String,
        value: String,
        aturi: Option<String>,
    },

    /// Value at `path` is checked against a sequence of `values`.
    /// NOTE(review): exact sequence semantics (all? any? ordered?) are
    /// defined by the matcher implementation — confirm there.
    #[serde(rename = "sequence")]
    Sequence {
        path: String,
        values: Vec<String>,
        aturi: Option<String>,
    },

    /// Arbitrary Rhai script evaluated against the event.
    #[serde(rename = "rhai")]
    Rhai { script: String },
}
92
/// HTTP listen port, parsed from `HTTP_PORT` (empty string falls back to 80).
#[derive(Clone)]
pub struct HttpPort(u16);

/// Paths to extra CA certificate bundles, split from the `;`-separated
/// `CERTIFICATE_BUNDLES` variable.
#[derive(Clone)]
pub struct CertificateBundles(Vec<String>);

/// On/off switch for a background task, parsed from "true"/"false".
#[derive(Clone)]
pub struct TaskEnable(bool);

/// Interval or max-age for a background task, parsed from strings like
/// "3m" or "48h" via `duration_str`.
#[derive(Clone)]
pub struct TaskInterval(Duration);

/// Whether zstd compression is enabled (`COMPRESSION`).
#[derive(Clone)]
pub struct Compression(bool);

/// Collection names to consume, split from the `,`-separated `COLLECTIONS`
/// variable.
#[derive(Clone)]
pub struct Collections(Vec<String>);
110
/// Fully-resolved application configuration, built from environment
/// variables by `Config::new`. Comments note the source variable and the
/// default applied when it is unset.
#[derive(Clone)]
pub struct Config {
    // Build version from GIT_HASH or CARGO_PKG_VERSION (compile-time).
    pub version: String,
    // HTTP_PORT, default "4050".
    pub http_port: HttpPort,
    // EXTERNAL_BASE, required.
    pub external_base: String,
    // DATABASE_URL, default "sqlite://development.db".
    pub database_url: String,
    // CERTIFICATE_BUNDLES, optional `;`-separated list.
    pub certificate_bundles: CertificateBundles,
    // CONSUMER_TASK_ENABLE, default true.
    pub consumer_task_enable: TaskEnable,
    // CACHE_TASK_ENABLE, default true.
    pub cache_task_enable: TaskEnable,
    // CACHE_TASK_INTERVAL, default "3m".
    pub cache_task_interval: TaskInterval,
    // CLEANUP_TASK_ENABLE, default true.
    pub cleanup_task_enable: TaskEnable,
    // CLEANUP_TASK_INTERVAL, default "1h".
    pub cleanup_task_interval: TaskInterval,
    // CLEANUP_TASK_MAX_AGE, default "48h".
    pub cleanup_task_max_age: TaskInterval,
    // VMC_TASK_ENABLE, default true.
    pub vmc_task_enable: TaskEnable,
    // PLC_HOSTNAME, default "plc.directory".
    pub plc_hostname: String,
    // USER_AGENT, defaults to a supercell UA string with the version.
    pub user_agent: String,
    // ZSTD_DICTIONARY, required only when compression is enabled.
    pub zstd_dictionary: String,
    // JETSTREAM_HOSTNAME, required.
    pub jetstream_hostname: String,
    // FEEDS, required: path to a YAML file parsed into `Feeds`.
    pub feeds: Feeds,
    // COMPRESSION, default false.
    pub compression: Compression,
    // COLLECTIONS, default "app.bsky.feed.post".
    pub collections: Collections,
    // FEED_CACHE_DIR, optional (empty string when unset).
    pub feed_cache_dir: String,
}
134
135impl Config {
136 pub fn new() -> Result<Self> {
137 let http_port: HttpPort = default_env("HTTP_PORT", "4050").try_into()?;
138 let external_base = require_env("EXTERNAL_BASE")?;
139
140 let database_url = default_env("DATABASE_URL", "sqlite://development.db");
141
142 let certificate_bundles: CertificateBundles =
143 optional_env("CERTIFICATE_BUNDLES").try_into()?;
144
145 let jetstream_hostname = require_env("JETSTREAM_HOSTNAME")?;
146
147 let compression: Compression = default_env("COMPRESSION", "false").try_into()?;
148
149 let zstd_dictionary = if compression.0 {
150 require_env("ZSTD_DICTIONARY")?
151 } else {
152 "".to_string()
153 };
154
155 let consumer_task_enable: TaskEnable =
156 default_env("CONSUMER_TASK_ENABLE", "true").try_into()?;
157
158 let cache_task_enable: TaskEnable = default_env("CACHE_TASK_ENABLE", "true").try_into()?;
159
160 let cache_task_interval: TaskInterval =
161 default_env("CACHE_TASK_INTERVAL", "3m").try_into()?;
162
163 let cleanup_task_enable: TaskEnable =
164 default_env("CLEANUP_TASK_ENABLE", "true").try_into()?;
165
166 let cleanup_task_interval: TaskInterval =
167 default_env("CLEANUP_TASK_INTERVAL", "1h").try_into()?;
168
169 let cleanup_task_max_age: TaskInterval =
170 default_env("CLEANUP_TASK_MAX_AGE", "48h").try_into()?;
171
172 let vmc_task_enable: TaskEnable = default_env("VMC_TASK_ENABLE", "true").try_into()?;
173
174 let plc_hostname = default_env("PLC_HOSTNAME", "plc.directory");
175
176 let default_user_agent = format!(
177 "supercell ({}; +https://github.com/astrenoxcoop/supercell)",
178 version()?
179 );
180
181 let user_agent = default_env("USER_AGENT", &default_user_agent);
182
183 let feeds: Feeds = require_env("FEEDS")?.try_into()?;
184
185 let collections: Collections =
186 default_env("COLLECTIONS", "app.bsky.feed.post").try_into()?;
187
188 let feed_cache_dir = optional_env("FEED_CACHE_DIR");
189
190 Ok(Self {
191 version: version()?,
192 http_port,
193 external_base,
194 database_url,
195 certificate_bundles,
196 consumer_task_enable,
197 cache_task_enable,
198 cache_task_interval,
199 cleanup_task_enable,
200 cleanup_task_interval,
201 cleanup_task_max_age,
202 vmc_task_enable,
203 plc_hostname,
204 user_agent,
205 jetstream_hostname,
206 zstd_dictionary,
207 feeds,
208 compression,
209 collections,
210 feed_cache_dir,
211 })
212 }
213}
214
215fn require_env(name: &str) -> Result<String> {
216 std::env::var(name)
217 .map_err(|err| anyhow::Error::new(err).context(anyhow!("{} must be set", name)))
218}
219
/// Reads the environment variable `name`, returning an empty string when it
/// is unset (or not valid unicode).
fn optional_env(name: &str) -> String {
    // unwrap_or_default avoids the eagerly-built String of unwrap_or("".to_string()).
    std::env::var(name).unwrap_or_default()
}
223
/// Reads the environment variable `name`, falling back to `default_value`
/// when it is unset (or not valid unicode).
fn default_env(name: &str, default_value: &str) -> String {
    // unwrap_or_else defers the allocation of the default to the unset case;
    // unwrap_or(default_value.to_string()) allocated even when the var exists.
    std::env::var(name).unwrap_or_else(|_| default_value.to_string())
}
227
228pub fn version() -> Result<String> {
229 option_env!("GIT_HASH")
230 .or(option_env!("CARGO_PKG_VERSION"))
231 .map(|val| val.to_string())
232 .ok_or(anyhow!("one of GIT_HASH or CARGO_PKG_VERSION must be set"))
233}
234
235impl TryFrom<String> for HttpPort {
236 type Error = anyhow::Error;
237 fn try_from(value: String) -> Result<Self, Self::Error> {
238 if value.is_empty() {
239 Ok(Self(80))
240 } else {
241 value.parse::<u16>().map(Self).map_err(|err| {
242 anyhow::Error::new(err).context(anyhow!("parsing PORT into u16 failed"))
243 })
244 }
245 }
246}
247
248impl AsRef<u16> for HttpPort {
249 fn as_ref(&self) -> &u16 {
250 &self.0
251 }
252}
253
254impl TryFrom<String> for CertificateBundles {
255 type Error = anyhow::Error;
256 fn try_from(value: String) -> Result<Self, Self::Error> {
257 Ok(Self(
258 value
259 .split(';')
260 .filter_map(|s| {
261 if s.is_empty() {
262 None
263 } else {
264 Some(s.to_string())
265 }
266 })
267 .collect::<Vec<String>>(),
268 ))
269 }
270}
271
272impl AsRef<Vec<String>> for CertificateBundles {
273 fn as_ref(&self) -> &Vec<String> {
274 &self.0
275 }
276}
277
278impl AsRef<bool> for TaskEnable {
279 fn as_ref(&self) -> &bool {
280 &self.0
281 }
282}
283
284impl TryFrom<String> for TaskEnable {
285 type Error = anyhow::Error;
286 fn try_from(value: String) -> Result<Self, Self::Error> {
287 let value = value.parse::<bool>().map_err(|err| {
288 anyhow::Error::new(err).context(anyhow!("parsing task enable into bool failed"))
289 })?;
290 Ok(Self(value))
291 }
292}
293
294impl AsRef<Duration> for TaskInterval {
295 fn as_ref(&self) -> &Duration {
296 &self.0
297 }
298}
299
300impl TryFrom<String> for TaskInterval {
301 type Error = anyhow::Error;
302 fn try_from(value: String) -> Result<Self, Self::Error> {
303 let duration = duration_str::parse_chrono(&value)
304 .map_err(|err| anyhow!(err).context("parsing task interval into duration failed"))?;
305 Ok(Self(duration))
306 }
307}
308
309impl AsRef<bool> for Compression {
310 fn as_ref(&self) -> &bool {
311 &self.0
312 }
313}
314
315impl TryFrom<String> for Compression {
316 type Error = anyhow::Error;
317 fn try_from(value: String) -> Result<Self, Self::Error> {
318 let value = value.parse::<bool>().map_err(|err| {
319 anyhow::Error::new(err).context(anyhow!("parsing compression into bool failed"))
320 })?;
321 Ok(Self(value))
322 }
323}
324
325impl TryFrom<String> for Feeds {
326 type Error = anyhow::Error;
327 fn try_from(value: String) -> Result<Self, Self::Error> {
328 let content = std::fs::read(value).map_err(|err| {
329 anyhow::Error::new(err).context(anyhow!("reading feed config file failed"))
330 })?;
331
332 serde_yaml::from_slice(&content).map_err(|err| {
333 anyhow::Error::new(err).context(anyhow!("parsing feeds into Feeds failed"))
334 })
335 }
336}
337
338impl TryFrom<String> for Collections {
339 type Error = anyhow::Error;
340 fn try_from(value: String) -> Result<Self, Self::Error> {
341 Ok(Self(
342 value
343 .split(',')
344 .filter_map(|s| {
345 if s.is_empty() {
346 None
347 } else {
348 Some(s.to_string())
349 }
350 })
351 .collect::<Vec<String>>(),
352 ))
353 }
354}
355
356impl AsRef<Vec<String>> for Collections {
357 fn as_ref(&self) -> &Vec<String> {
358 &self.0
359 }
360}
361
362impl Default for FeedQuery {
363 fn default() -> Self {
364 FeedQuery::Simple {
365 limit: FeedQueryLimit::default(),
366 }
367 }
368}
369
370impl FromStr for FeedQuery {
371 type Err = anyhow::Error;
372
373 fn from_str(value: &str) -> Result<Self, Self::Err> {
374 match value {
375 "simple" => Ok(FeedQuery::Simple {
376 limit: FeedQueryLimit::default(),
377 }),
378 "popular" => Ok(FeedQuery::Popular {
379 gravity: 1.8,
380 limit: FeedQueryLimit::default(),
381 }),
382 _ => Err(anyhow!("unsupported query")),
383 }
384 }
385}
386
/// Serde helper that accepts EITHER a bare string or a full map for a field.
///
/// A string is routed through `T::from_str` (so `query: popular` works),
/// while a map is deserialized normally via `T`'s `Deserialize` impl
/// (so `query: { type: popular, gravity: 2.0 }` also works). This is the
/// standard serde "string or struct" pattern.
fn string_or_struct<'de, T, D>(deserializer: D) -> Result<T, D::Error>
where
    T: Deserialize<'de> + FromStr<Err = anyhow::Error>,
    D: Deserializer<'de>,
{
    // PhantomData carries the target type T without storing a value.
    struct StringOrStruct<T>(PhantomData<fn() -> T>);

    impl<'de, T> Visitor<'de> for StringOrStruct<T>
    where
        T: Deserialize<'de> + FromStr<Err = anyhow::Error>,
    {
        type Value = T;

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            // NOTE(review): message hard-codes "FeedQuery" although the helper
            // is generic; only used for FeedQuery today, so it reads fine.
            formatter.write_str("string or FeedQuery")
        }

        // A bare string: delegate to FromStr. The original anyhow error is
        // discarded and replaced with a generic serde error message.
        fn visit_str<E>(self, value: &str) -> Result<T, E>
        where
            E: de::Error,
        {
            FromStr::from_str(value).map_err(|_| de::Error::custom("cannot deserialize field"))
        }

        // A map: re-enter serde and deserialize T normally.
        fn visit_map<M>(self, map: M) -> Result<T, M::Error>
        where
            M: MapAccess<'de>,
        {
            Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))
        }
    }

    deserializer.deserialize_any(StringOrStruct(PhantomData))
}
421
422impl AsRef<u32> for FeedQueryLimit {
423 fn as_ref(&self) -> &u32 {
424 &self.0
425 }
426}