A better Rust ATProto crate
1use crate::codegen::nsid_utils::{NsidPath, RefPath};
2use crate::corpus::LexiconCorpus;
3use crate::error::Result;
4use heck::ToPascalCase;
5use jacquard_common::CowStr;
6use proc_macro2::TokenStream;
7use quote::quote;
8use std::cell::RefCell;
9use std::collections::{HashMap, HashSet};
10
11/// Information about a union variant
12#[derive(Debug, Clone)]
13pub struct UnionVariant {
14 /// The original ref string (normalized)
15 pub ref_str: String,
16 /// The variant name (may be disambiguated)
17 pub variant_name: String,
18 /// The Rust type for this variant
19 pub rust_type: TokenStream,
20}
21
22/// Context for tracking namespace dependencies during union generation
23pub struct UnionGenContext<'a> {
24 pub corpus: &'a LexiconCorpus,
25 pub namespace_deps: &'a RefCell<HashMap<String, HashSet<String>>>,
26 pub current_nsid: &'a str,
27}
28
29impl<'a> UnionGenContext<'a> {
30 /// Build variants for a union with collision detection and disambiguation
31 pub fn build_union_variants(
32 &self,
33 refs: &[CowStr<'static>],
34 ref_to_rust_type: impl Fn(&str) -> Result<TokenStream>,
35 ) -> Result<Vec<UnionVariant>> {
36 let current_nsid_path = NsidPath::parse(self.current_nsid);
37 let current_namespace = current_nsid_path.namespace();
38
39 // First pass: collect all variant names and detect collisions
40 #[derive(Debug)]
41 struct VariantInfo {
42 ref_str: String,
43 ref_nsid: String,
44 simple_name: String,
45 is_current_namespace: bool,
46 }
47
48 let mut seen_refs = HashSet::new();
49 let mut variant_infos = Vec::new();
50 for ref_str in refs {
51 let normalized_ref = RefPath::normalize(ref_str, self.current_nsid);
52 if !seen_refs.insert(normalized_ref.clone()) {
53 continue;
54 }
55 let ref_path = RefPath::parse(&normalized_ref, None);
56 let ref_nsid_str = ref_path.nsid();
57 let ref_def = ref_path.def();
58
59 // Skip unknown refs
60 if !self.corpus.ref_exists(&normalized_ref) {
61 continue;
62 }
63
64 let is_current_namespace = ref_nsid_str.starts_with(¤t_namespace);
65 let is_same_module = ref_nsid_str == self.current_nsid;
66
67 // Generate simple variant name
68 let last_segment = ref_nsid_str.split('.').last().unwrap();
69 let simple_name = if ref_def == "main" {
70 last_segment.to_pascal_case()
71 } else if last_segment == "defs" {
72 ref_def.to_pascal_case()
73 } else if is_same_module {
74 ref_def.to_pascal_case()
75 } else {
76 format!(
77 "{}{}",
78 last_segment.to_pascal_case(),
79 ref_def.to_pascal_case()
80 )
81 };
82
83 variant_infos.push(VariantInfo {
84 ref_str: normalized_ref.clone(),
85 ref_nsid: ref_nsid_str.to_string(),
86 simple_name,
87 is_current_namespace,
88 });
89 }
90
91 // Second pass: detect collisions and disambiguate
92 let mut name_counts: HashMap<String, usize> = HashMap::new();
93 for info in &variant_infos {
94 *name_counts.entry(info.simple_name.clone()).or_insert(0) += 1;
95 }
96
97 let mut variants = Vec::new();
98 for info in variant_infos {
99 let has_collision = name_counts.get(&info.simple_name).copied().unwrap_or(0) > 1;
100
101 // Track namespace dependency for foreign refs
102 if !info.is_current_namespace {
103 let ref_nsid_path = NsidPath::parse(&info.ref_nsid);
104 let foreign_namespace = ref_nsid_path.namespace();
105 self.namespace_deps
106 .borrow_mut()
107 .entry(current_namespace.clone())
108 .or_default()
109 .insert(foreign_namespace);
110 }
111
112 // Disambiguate: add second NSID segment prefix only to foreign refs when there's a collision
113 let variant_name = if has_collision && !info.is_current_namespace {
114 let ref_nsid_path = NsidPath::parse(&info.ref_nsid);
115 let segments = ref_nsid_path.segments();
116 let prefix = if segments.len() >= 2 {
117 segments[1].to_pascal_case()
118 } else {
119 segments[0].to_pascal_case()
120 };
121 format!("{}{}", prefix, info.simple_name)
122 } else {
123 info.simple_name.clone()
124 };
125
126 let rust_type = ref_to_rust_type(&info.ref_str)?;
127
128 variants.push(UnionVariant {
129 ref_str: info.ref_str,
130 variant_name,
131 rust_type,
132 });
133 }
134
135 Ok(variants)
136 }
137
138 /// Build variants for a union without collision detection (simple mode)
139 pub fn build_simple_union_variants(
140 &self,
141 refs: &[CowStr<'static>],
142 ref_to_rust_type: impl Fn(&str) -> Result<TokenStream>,
143 ) -> Result<Vec<UnionVariant>> {
144 let mut variants = Vec::new();
145
146 for ref_str in refs {
147 let ref_str_s = ref_str.as_ref();
148 let normalized_ref = RefPath::normalize(ref_str, self.current_nsid);
149 let ref_path = RefPath::parse(&normalized_ref, None);
150 let ref_nsid = ref_path.nsid();
151 let ref_def = ref_path.def();
152
153 let variant_name = if ref_def == "main" {
154 let ref_nsid_path = NsidPath::parse(ref_nsid);
155 ref_nsid_path.last_segment().to_pascal_case()
156 } else {
157 ref_def.to_pascal_case()
158 };
159
160 let rust_type = ref_to_rust_type(&normalized_ref)?;
161
162 variants.push(UnionVariant {
163 ref_str: ref_str_s.to_string(),
164 variant_name,
165 rust_type,
166 });
167 }
168
169 Ok(variants)
170 }
171}
172
173/// Generate variant tokens for a union enum
174pub fn generate_variant_tokens(variants: &[UnionVariant]) -> Vec<TokenStream> {
175 variants
176 .iter()
177 .map(|variant| {
178 let variant_ident =
179 syn::Ident::new(&variant.variant_name, proc_macro2::Span::call_site());
180 let ref_str_literal = &variant.ref_str;
181 let rust_type = &variant.rust_type;
182
183 quote! {
184 #[serde(rename = #ref_str_literal)]
185 #variant_ident(Box<#rust_type>)
186 }
187 })
188 .collect()
189}