//! pydantic model generator for atproto lexicons
//! python code generation from lexicon documents
2
3use std::collections::{HashMap, HashSet};
4use std::fs;
5use std::io;
6use std::path::Path;
7
8use atrium_lex::lexicon::{LexObject, LexRecord, LexUserType};
9use atrium_lex::LexiconDoc;
10use heck::ToSnakeCase;
11
12use crate::builtin::builtin_lexicons;
13use crate::types::{collect_external_refs, property_to_python, to_class_name, RefContext};
14
/// file header prepended to every generated python module (marks the file
/// as generated and pulls in the names the emitted classes rely on)
const HEADER: &str = r#"# auto-generated by pmgfal - do not edit

from __future__ import annotations

from typing import Any

from pydantic import BaseModel, Field
"#;
23
/// python keywords that need escaping as field names
///
/// compared against the snake_cased field name in `to_field_name`, so the
/// entries that can actually match are all lowercase. "type" is a builtin
/// rather than a keyword but is escaped to avoid shadowing. NOTE(review):
/// the capitalized "None"/"True"/"False" entries can never equal
/// snake_case output; kept to document intent.
const PYTHON_KEYWORDS: &[&str] = &[
    "type", "class", "import", "from", "global", "lambda", "def", "return", "yield", "raise",
    "try", "except", "finally", "with", "as", "if", "elif", "else", "for", "while", "break",
    "continue", "pass", "and", "or", "not", "in", "is", "del", "assert", "nonlocal", "None",
    "True", "False", "async", "await",
];
30
31/// generate pydantic models for all documents
32pub fn generate_models(
33 docs: &[LexiconDoc],
34 output_dir: &Path,
35 namespace_prefix: Option<&str>,
36) -> Result<Vec<String>, io::Error> {
37 let filtered: Vec<_> = docs
38 .iter()
39 .filter(|doc| {
40 namespace_prefix
41 .map(|p| doc.id.starts_with(p))
42 .unwrap_or(true)
43 })
44 .collect();
45
46 if filtered.is_empty() {
47 return Ok(vec![]);
48 }
49
50 // build lookup of all available lexicons (user + builtin)
51 let mut all_docs: HashMap<&str, &LexiconDoc> = HashMap::new();
52 for doc in docs {
53 all_docs.insert(&doc.id, doc);
54 }
55 for doc in builtin_lexicons() {
56 all_docs.entry(&doc.id).or_insert(doc);
57 }
58
59 // collect external refs from user documents
60 let mut external_refs: HashSet<String> = HashSet::new();
61 for doc in &filtered {
62 external_refs.extend(collect_external_refs(doc));
63 }
64
65 // find which external refs we can resolve from builtins
66 let mut resolved_externals: Vec<&LexiconDoc> = Vec::new();
67 for ref_nsid in &external_refs {
68 if let Some(doc) = all_docs.get(ref_nsid.as_str()) {
69 // only include if not already in user docs
70 if !filtered.iter().any(|d| d.id == *ref_nsid) {
71 resolved_externals.push(doc);
72 }
73 }
74 }
75 resolved_externals.sort_by(|a, b| a.id.cmp(&b.id));
76
77 fs::create_dir_all(output_dir)?;
78
79 let mut output = String::from(HEADER);
80 output.push('\n');
81
82 // generate external deps first (so they're defined before use)
83 for doc in &resolved_externals {
84 output.push_str(&format!("\n# {} (builtin)\n", doc.id));
85 output.push_str(&generate_document(doc));
86 }
87
88 // generate user documents
89 for doc in &filtered {
90 output.push_str(&format!("\n# {}\n", doc.id));
91 output.push_str(&generate_document(doc));
92 }
93
94 let output_file = match namespace_prefix {
95 Some(prefix) => output_dir.join(format!("{}.py", prefix.replace('.', "_"))),
96 None => output_dir.join("models.py"),
97 };
98
99 fs::write(&output_file, &output)?;
100
101 Ok(vec![output_file.to_string_lossy().to_string()])
102}
103
104/// generate python code for a single lexicon document
105fn generate_document(doc: &LexiconDoc) -> String {
106 let ctx = RefContext::new(&doc.id);
107 let mut output = String::new();
108
109 for (def_name, def) in &doc.defs {
110 let class_name = to_class_name(&doc.id, def_name);
111
112 match def {
113 LexUserType::Record(LexRecord {
114 record,
115 description,
116 ..
117 }) => {
118 let atrium_lex::lexicon::LexRecordRecord::Object(obj) = record;
119 let desc = description.as_deref().unwrap_or(&doc.id);
120 output.push_str(&generate_class(&class_name, obj, Some(desc), &ctx));
121 output.push_str("\n\n");
122 }
123 LexUserType::Object(obj) => {
124 output.push_str(&generate_class(
125 &class_name,
126 obj,
127 obj.description.as_deref(),
128 &ctx,
129 ));
130 output.push_str("\n\n");
131 }
132 LexUserType::Token(_) => {
133 output.push_str(&format!(
134 "# token: {}\n{} = \"{}#{}\"\n\n",
135 class_name,
136 class_name.to_uppercase(),
137 doc.id,
138 def_name
139 ));
140 }
141 _ => {}
142 }
143 }
144
145 output
146}
147
148/// generate a pydantic model class
149fn generate_class(
150 class_name: &str,
151 obj: &LexObject,
152 description: Option<&str>,
153 ctx: &RefContext,
154) -> String {
155 let mut lines = vec![format!("class {class_name}(BaseModel):")];
156
157 if let Some(desc) = description {
158 lines.push(format!(" \"\"\"{desc}\"\"\""));
159 }
160
161 if obj.properties.is_empty() {
162 lines.push(" pass".into());
163 return lines.join("\n");
164 }
165
166 let required: HashSet<_> = obj
167 .required
168 .as_ref()
169 .map(|r| r.iter().map(String::as_str).collect())
170 .unwrap_or_default();
171
172 // generate required fields first, then optional
173 let mut fields: Vec<_> = obj.properties.iter().collect();
174 fields.sort_by_key(|(name, _)| !required.contains(name.as_str()));
175
176 for (name, prop) in fields {
177 let field_name = to_field_name(name);
178 let is_required = required.contains(name.as_str());
179
180 let mut py_type = property_to_python(prop, ctx);
181 if !is_required {
182 py_type = format!("{py_type} | None");
183 }
184
185 let needs_alias = field_name != *name;
186 let needs_default = !is_required;
187
188 let field_def = match (needs_alias, needs_default) {
189 (false, false) => format!(" {field_name}: {py_type}"),
190 (true, false) => format!(" {field_name}: {py_type} = Field(alias=\"{name}\")"),
191 (false, true) => format!(" {field_name}: {py_type} = Field(default=None)"),
192 (true, true) => {
193 format!(" {field_name}: {py_type} = Field(default=None, alias=\"{name}\")")
194 }
195 };
196
197 lines.push(field_def);
198 }
199
200 lines.join("\n")
201}
202
203/// convert property name to valid python field name
204fn to_field_name(name: &str) -> String {
205 let snake = name.to_snake_case();
206 if PYTHON_KEYWORDS.contains(&snake.as_str()) {
207 format!("{snake}_")
208 } else {
209 snake
210 }
211}