A better Rust ATProto crate
Branch `main` · 189 lines · 6.6 kB (view raw)
1use crate::codegen::nsid_utils::{NsidPath, RefPath}; 2use crate::corpus::LexiconCorpus; 3use crate::error::Result; 4use heck::ToPascalCase; 5use jacquard_common::CowStr; 6use proc_macro2::TokenStream; 7use quote::quote; 8use std::cell::RefCell; 9use std::collections::{HashMap, HashSet}; 10 11/// Information about a union variant 12#[derive(Debug, Clone)] 13pub struct UnionVariant { 14 /// The original ref string (normalized) 15 pub ref_str: String, 16 /// The variant name (may be disambiguated) 17 pub variant_name: String, 18 /// The Rust type for this variant 19 pub rust_type: TokenStream, 20} 21 22/// Context for tracking namespace dependencies during union generation 23pub struct UnionGenContext<'a> { 24 pub corpus: &'a LexiconCorpus, 25 pub namespace_deps: &'a RefCell<HashMap<String, HashSet<String>>>, 26 pub current_nsid: &'a str, 27} 28 29impl<'a> UnionGenContext<'a> { 30 /// Build variants for a union with collision detection and disambiguation 31 pub fn build_union_variants( 32 &self, 33 refs: &[CowStr<'static>], 34 ref_to_rust_type: impl Fn(&str) -> Result<TokenStream>, 35 ) -> Result<Vec<UnionVariant>> { 36 let current_nsid_path = NsidPath::parse(self.current_nsid); 37 let current_namespace = current_nsid_path.namespace(); 38 39 // First pass: collect all variant names and detect collisions 40 #[derive(Debug)] 41 struct VariantInfo { 42 ref_str: String, 43 ref_nsid: String, 44 simple_name: String, 45 is_current_namespace: bool, 46 } 47 48 let mut seen_refs = HashSet::new(); 49 let mut variant_infos = Vec::new(); 50 for ref_str in refs { 51 let normalized_ref = RefPath::normalize(ref_str, self.current_nsid); 52 if !seen_refs.insert(normalized_ref.clone()) { 53 continue; 54 } 55 let ref_path = RefPath::parse(&normalized_ref, None); 56 let ref_nsid_str = ref_path.nsid(); 57 let ref_def = ref_path.def(); 58 59 // Skip unknown refs 60 if !self.corpus.ref_exists(&normalized_ref) { 61 continue; 62 } 63 64 let is_current_namespace = 
ref_nsid_str.starts_with(&current_namespace); 65 let is_same_module = ref_nsid_str == self.current_nsid; 66 67 // Generate simple variant name 68 let last_segment = ref_nsid_str.split('.').last().unwrap(); 69 let simple_name = if ref_def == "main" { 70 last_segment.to_pascal_case() 71 } else if last_segment == "defs" { 72 ref_def.to_pascal_case() 73 } else if is_same_module { 74 ref_def.to_pascal_case() 75 } else { 76 format!( 77 "{}{}", 78 last_segment.to_pascal_case(), 79 ref_def.to_pascal_case() 80 ) 81 }; 82 83 variant_infos.push(VariantInfo { 84 ref_str: normalized_ref.clone(), 85 ref_nsid: ref_nsid_str.to_string(), 86 simple_name, 87 is_current_namespace, 88 }); 89 } 90 91 // Second pass: detect collisions and disambiguate 92 let mut name_counts: HashMap<String, usize> = HashMap::new(); 93 for info in &variant_infos { 94 *name_counts.entry(info.simple_name.clone()).or_insert(0) += 1; 95 } 96 97 let mut variants = Vec::new(); 98 for info in variant_infos { 99 let has_collision = name_counts.get(&info.simple_name).copied().unwrap_or(0) > 1; 100 101 // Track namespace dependency for foreign refs 102 if !info.is_current_namespace { 103 let ref_nsid_path = NsidPath::parse(&info.ref_nsid); 104 let foreign_namespace = ref_nsid_path.namespace(); 105 self.namespace_deps 106 .borrow_mut() 107 .entry(current_namespace.clone()) 108 .or_default() 109 .insert(foreign_namespace); 110 } 111 112 // Disambiguate: add second NSID segment prefix only to foreign refs when there's a collision 113 let variant_name = if has_collision && !info.is_current_namespace { 114 let ref_nsid_path = NsidPath::parse(&info.ref_nsid); 115 let segments = ref_nsid_path.segments(); 116 let prefix = if segments.len() >= 2 { 117 segments[1].to_pascal_case() 118 } else { 119 segments[0].to_pascal_case() 120 }; 121 format!("{}{}", prefix, info.simple_name) 122 } else { 123 info.simple_name.clone() 124 }; 125 126 let rust_type = ref_to_rust_type(&info.ref_str)?; 127 128 variants.push(UnionVariant { 129 
ref_str: info.ref_str, 130 variant_name, 131 rust_type, 132 }); 133 } 134 135 Ok(variants) 136 } 137 138 /// Build variants for a union without collision detection (simple mode) 139 pub fn build_simple_union_variants( 140 &self, 141 refs: &[CowStr<'static>], 142 ref_to_rust_type: impl Fn(&str) -> Result<TokenStream>, 143 ) -> Result<Vec<UnionVariant>> { 144 let mut variants = Vec::new(); 145 146 for ref_str in refs { 147 let ref_str_s = ref_str.as_ref(); 148 let normalized_ref = RefPath::normalize(ref_str, self.current_nsid); 149 let ref_path = RefPath::parse(&normalized_ref, None); 150 let ref_nsid = ref_path.nsid(); 151 let ref_def = ref_path.def(); 152 153 let variant_name = if ref_def == "main" { 154 let ref_nsid_path = NsidPath::parse(ref_nsid); 155 ref_nsid_path.last_segment().to_pascal_case() 156 } else { 157 ref_def.to_pascal_case() 158 }; 159 160 let rust_type = ref_to_rust_type(&normalized_ref)?; 161 162 variants.push(UnionVariant { 163 ref_str: ref_str_s.to_string(), 164 variant_name, 165 rust_type, 166 }); 167 } 168 169 Ok(variants) 170 } 171} 172 173/// Generate variant tokens for a union enum 174pub fn generate_variant_tokens(variants: &[UnionVariant]) -> Vec<TokenStream> { 175 variants 176 .iter() 177 .map(|variant| { 178 let variant_ident = 179 syn::Ident::new(&variant.variant_name, proc_macro2::Span::call_site()); 180 let ref_str_literal = &variant.ref_str; 181 let rust_type = &variant.rust_type; 182 183 quote! { 184 #[serde(rename = #ref_str_literal)] 185 #variant_ident(Box<#rust_type>) 186 } 187 }) 188 .collect() 189}