// Forked from parakeet.at/parakeet.
// Rust AppView - highly experimental!
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Fields, Ident, Type};

5#[proc_macro_derive(PostgresComposite, attributes(composite))]
6pub fn derive_postgres_composite(input: TokenStream) -> TokenStream {
7 let input = parse_macro_input!(input as DeriveInput);
8
9 let name = &input.ident;
10 let (sql_type, field_count) = extract_metadata(&input);
11 let fields = extract_fields(&input);
12
13 let from_reader_impl = generate_from_reader(&fields);
14 let write_to_impl = generate_write_to(&fields);
15 let write_pg_impl = generate_write_pg(&fields);
16
17 let field_count_i32 = field_count as i32;
18
19 let expanded = quote! {
20 impl FromSql<#sql_type, Pg> for #name {
21 fn from_sql(bytes: <Pg as diesel::backend::Backend>::RawValue<'_>) -> deserialize::Result<Self> {
22 let bytes = bytes.as_bytes();
23 let mut reader = CompositeReader::new(bytes);
24
25 let field_count = reader.read_field_count()?;
26 if field_count != #field_count_i32 {
27 return Err(format!("{}: expected {} fields, got {}",
28 stringify!(#name), #field_count_i32, field_count).into());
29 }
30
31 Ok(Self {
32 #from_reader_impl
33 })
34 }
35 }
36
37 impl ToSql<#sql_type, Pg> for #name {
38 fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
39 CompositeWriter::write_header(out, #field_count_i32)?;
40 let mut writer = CompositeWriter::with_output(out);
41 #write_to_impl
42 Ok(serialize::IsNull::No)
43 }
44 }
45
46 impl PgToSql for #name {
47 fn to_sql(&self, _ty: &Type, out: &mut BytesMut) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
48 out.extend_from_slice(&(#field_count_i32).to_be_bytes());
49 #write_pg_impl
50 Ok(IsNull::No)
51 }
52
53 fn accepts(_ty: &Type) -> bool {
54 true
55 }
56
57 postgres_types::to_sql_checked!();
58 }
59 };
60
61 TokenStream::from(expanded)
62}
63
64fn extract_metadata(input: &DeriveInput) -> (Ident, usize) {
65 // Extract SQL type from #[composite(sql_type = "...")] attribute
66 let sql_type = input
67 .attrs
68 .iter()
69 .find(|attr| attr.path().is_ident("composite"))
70 .and_then(|attr| {
71 attr.parse_args::<syn::Meta>().ok().and_then(|meta| {
72 if let syn::Meta::NameValue(nv) = meta {
73 if nv.path.is_ident("sql_type") {
74 if let syn::Expr::Lit(lit) = nv.value {
75 if let syn::Lit::Str(s) = lit.lit {
76 return Some(syn::Ident::new(&s.value(), s.span()));
77 }
78 }
79 }
80 }
81 None
82 })
83 })
84 .expect("Missing #[composite(sql_type = \"...\")] attribute");
85
86 // Count fields
87 let field_count = match &input.data {
88 Data::Struct(data) => match &data.fields {
89 Fields::Named(fields) => fields.named.len(),
90 _ => panic!("Only named fields are supported"),
91 },
92 _ => panic!("Only structs are supported"),
93 };
94
95 (sql_type, field_count)
96}
97
98fn extract_fields(input: &DeriveInput) -> Vec<(Ident, Type)> {
99 match &input.data {
100 Data::Struct(data) => match &data.fields {
101 Fields::Named(fields) => fields
102 .named
103 .iter()
104 .map(|f| (f.ident.clone().unwrap(), f.ty.clone()))
105 .collect(),
106 _ => panic!("Only named fields are supported"),
107 },
108 _ => panic!("Only structs are supported"),
109 }
110}
111
112fn generate_from_reader(fields: &[(Ident, Type)]) -> proc_macro2::TokenStream {
113 let field_reads = fields.iter().map(|(name, ty)| {
114 let read_expr = generate_read_expr(name, ty);
115 quote! { #name: #read_expr, }
116 });
117
118 quote! { #(#field_reads)* }
119}
120
121fn generate_write_to(fields: &[(Ident, Type)]) -> proc_macro2::TokenStream {
122 let field_writes = fields.iter().map(|(name, ty)| {
123 generate_write_expr(name, ty)
124 });
125
126 quote! { #(#field_writes)* }
127}
128
129fn generate_write_pg(fields: &[(Ident, Type)]) -> proc_macro2::TokenStream {
130 let field_writes = fields.iter().map(|(name, ty)| {
131 generate_pg_write_expr(name, ty)
132 });
133
134 quote! { #(#field_writes)* }
135}
136
137fn generate_read_expr(name: &Ident, ty: &Type) -> proc_macro2::TokenStream {
138 let type_str = quote!(#ty).to_string();
139
140 // Handle different types
141 match type_str.as_str() {
142 "String" => quote! {
143 reader.read_text_field()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))?
144 },
145 "Option < String >" | "Option<String>" => quote! {
146 reader.read_text_field()?
147 },
148 "Vec < u8 >" | "Vec<u8>" => quote! {
149 reader.read_bytea_field()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))?
150 },
151 "Option < Vec < u8 > >" | "Option<Vec<u8>>" => quote! {
152 reader.read_bytea_field()?
153 },
154 "i32" => quote! {
155 reader.read_int_field()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))?
156 },
157 "Option < i32 >" | "Option<i32>" => quote! {
158 reader.read_int_field()?
159 },
160 "i64" => quote! {
161 reader.read_bigint_field()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))?
162 },
163 "Option < i64 >" | "Option<i64>" => quote! {
164 reader.read_bigint_field()?
165 },
166 "bool" => quote! {
167 reader.read_bool_field()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))?
168 },
169 "Option < bool >" | "Option<bool>" => quote! {
170 reader.read_bool_field()?
171 },
172 "DateTime < Utc >" | "DateTime<Utc>" => quote! {
173 reader.read_timestamptz_field()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))?
174 },
175 "Option < DateTime < Utc > >" | "Option<DateTime<Utc>>" => quote! {
176 reader.read_timestamptz_field()?
177 },
178 "Option < serde_json :: Value >" | "Option<serde_json::Value>" => quote! {
179 reader.read_jsonb_field()?
180 },
181 _ => {
182 // Check if it's an Option type
183 if type_str.starts_with("Option <") || type_str.starts_with("Option<") {
184 // Extract inner type
185 if let Type::Path(path) = ty {
186 if let Some(segment) = path.path.segments.last() {
187 if segment.ident == "Option" {
188 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
189 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
190 // Check if inner type is VideoCaption (composite type)
191 let inner_type_str = quote!(#inner_ty).to_string();
192 if inner_type_str == "VideoCaption" {
193 return quote! { reader.read_composite_field::<#inner_ty, PostVideoCaption>()? };
194 } else {
195 return quote! { reader.read_enum_field::<#inner_ty>()? };
196 }
197 }
198 }
199 }
200 }
201 }
202 quote! { reader.read_enum_field::<#ty>()? }
203 } else {
204 // Check if it's VideoCaption or another composite type
205 if type_str == "VideoCaption" {
206 quote! {
207 reader.read_composite_field::<#ty, PostVideoCaption>()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))?
208 }
209 } else {
210 // Assume it's a required enum type
211 quote! {
212 reader.read_enum_field::<#ty>()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))?
213 }
214 }
215 }
216 }
217 }
218}
219
220fn generate_write_expr(name: &Ident, ty: &Type) -> proc_macro2::TokenStream {
221 let type_str = quote!(#ty).to_string();
222
223 match type_str.as_str() {
224 "String" => quote! { writer.write_text_field(Some(&self.#name))?; },
225 "Option < String >" | "Option<String>" => quote! { writer.write_text_field(self.#name.as_deref())?; },
226 "Vec < u8 >" | "Vec<u8>" => quote! { writer.write_bytea_field(Some(&self.#name))?; },
227 "Option < Vec < u8 > >" | "Option<Vec<u8>>" => quote! { writer.write_bytea_field(self.#name.as_deref())?; },
228 "i32" => quote! { writer.write_int_field(Some(self.#name))?; },
229 "Option < i32 >" | "Option<i32>" => quote! { writer.write_int_field(self.#name)?; },
230 "i64" => quote! { writer.write_bigint_field(Some(self.#name))?; },
231 "Option < i64 >" | "Option<i64>" => quote! { writer.write_bigint_field(self.#name)?; },
232 "bool" => quote! { writer.write_bool_field(Some(self.#name))?; },
233 "Option < bool >" | "Option<bool>" => quote! { writer.write_bool_field(self.#name)?; },
234 "DateTime < Utc >" | "DateTime<Utc>" => quote! { writer.write_timestamptz_field(Some(&self.#name))?; },
235 "Option < DateTime < Utc > >" | "Option<DateTime<Utc>>" => quote! { writer.write_timestamptz_field(self.#name.as_ref())?; },
236 "Option < serde_json :: Value >" | "Option<serde_json::Value>" => quote! { writer.write_jsonb_field(self.#name.as_ref())?; },
237 _ => {
238 // Check if it's an Option type
239 if type_str.starts_with("Option <") || type_str.starts_with("Option<") {
240 // Check if it's Option<VideoCaption>
241 if type_str.contains("VideoCaption") {
242 quote! { writer.write_composite_field::<VideoCaption, PostVideoCaption>(self.#name.as_ref())?; }
243 } else {
244 quote! { writer.write_enum_field(self.#name.as_ref())?; }
245 }
246 } else if type_str == "VideoCaption" {
247 // Required VideoCaption
248 quote! { writer.write_composite_field::<VideoCaption, PostVideoCaption>(Some(&self.#name))?; }
249 } else {
250 // Required enum
251 quote! { writer.write_enum_field(Some(&self.#name))?; }
252 }
253 }
254 }
255}
256
257fn generate_pg_write_expr(name: &Ident, ty: &Type) -> proc_macro2::TokenStream {
258 let type_str = quote!(#ty).to_string();
259
260 match type_str.as_str() {
261 "String" => quote! {
262 out.extend_from_slice(&25i32.to_be_bytes());
263 let bytes = self.#name.as_bytes();
264 out.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
265 out.extend_from_slice(bytes);
266 },
267 "Option < String >" | "Option<String>" => quote! {
268 out.extend_from_slice(&25i32.to_be_bytes());
269 if let Some(ref val) = self.#name {
270 let bytes = val.as_bytes();
271 out.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
272 out.extend_from_slice(bytes);
273 } else {
274 out.extend_from_slice(&(-1i32).to_be_bytes());
275 }
276 },
277 "Vec < u8 >" | "Vec<u8>" => quote! {
278 out.extend_from_slice(&17i32.to_be_bytes());
279 out.extend_from_slice(&(self.#name.len() as i32).to_be_bytes());
280 out.extend_from_slice(&self.#name);
281 },
282 "Option < Vec < u8 > >" | "Option<Vec<u8>>" => quote! {
283 out.extend_from_slice(&17i32.to_be_bytes());
284 if let Some(ref val) = self.#name {
285 out.extend_from_slice(&(val.len() as i32).to_be_bytes());
286 out.extend_from_slice(val);
287 } else {
288 out.extend_from_slice(&(-1i32).to_be_bytes());
289 }
290 },
291 "i32" => quote! {
292 out.extend_from_slice(&23i32.to_be_bytes());
293 out.extend_from_slice(&4i32.to_be_bytes());
294 out.extend_from_slice(&self.#name.to_be_bytes());
295 },
296 "Option < i32 >" | "Option<i32>" => quote! {
297 out.extend_from_slice(&23i32.to_be_bytes());
298 if let Some(val) = self.#name {
299 out.extend_from_slice(&4i32.to_be_bytes());
300 out.extend_from_slice(&val.to_be_bytes());
301 } else {
302 out.extend_from_slice(&(-1i32).to_be_bytes());
303 }
304 },
305 "i64" => quote! {
306 out.extend_from_slice(&20i32.to_be_bytes());
307 out.extend_from_slice(&8i32.to_be_bytes());
308 out.extend_from_slice(&self.#name.to_be_bytes());
309 },
310 "Option < i64 >" | "Option<i64>" => quote! {
311 out.extend_from_slice(&20i32.to_be_bytes());
312 if let Some(val) = self.#name {
313 out.extend_from_slice(&8i32.to_be_bytes());
314 out.extend_from_slice(&val.to_be_bytes());
315 } else {
316 out.extend_from_slice(&(-1i32).to_be_bytes());
317 }
318 },
319 "bool" => quote! {
320 out.extend_from_slice(&16i32.to_be_bytes());
321 out.extend_from_slice(&1i32.to_be_bytes());
322 out.extend_from_slice(&[if self.#name { 1 } else { 0 }]);
323 },
324 "Option < bool >" | "Option<bool>" => quote! {
325 out.extend_from_slice(&16i32.to_be_bytes());
326 if let Some(val) = self.#name {
327 out.extend_from_slice(&1i32.to_be_bytes());
328 out.extend_from_slice(&[if val { 1 } else { 0 }]);
329 } else {
330 out.extend_from_slice(&(-1i32).to_be_bytes());
331 }
332 },
333 "DateTime < Utc >" | "DateTime<Utc>" => quote! {
334 const PG_EPOCH_SECS: i64 = 946684800;
335 let timestamp_secs = self.#name.timestamp();
336 let timestamp_micros = (timestamp_secs - PG_EPOCH_SECS) * 1_000_000 + (self.#name.timestamp_subsec_micros() as i64);
337 out.extend_from_slice(&1184i32.to_be_bytes());
338 out.extend_from_slice(&8i32.to_be_bytes());
339 out.extend_from_slice(×tamp_micros.to_be_bytes());
340 },
341 "Option < DateTime < Utc > >" | "Option<DateTime<Utc>>" => quote! {
342 out.extend_from_slice(&1184i32.to_be_bytes());
343 if let Some(ref dt) = self.#name {
344 const PG_EPOCH_SECS: i64 = 946684800;
345 let timestamp_secs = dt.timestamp();
346 let timestamp_micros = (timestamp_secs - PG_EPOCH_SECS) * 1_000_000 + (dt.timestamp_subsec_micros() as i64);
347 out.extend_from_slice(&8i32.to_be_bytes());
348 out.extend_from_slice(×tamp_micros.to_be_bytes());
349 } else {
350 out.extend_from_slice(&(-1i32).to_be_bytes());
351 }
352 },
353 "Option < serde_json :: Value >" | "Option<serde_json::Value>" => quote! {
354 out.extend_from_slice(&3802i32.to_be_bytes());
355 if let Some(ref val) = self.#name {
356 let json_bytes = serde_json::to_vec(val)?;
357 let total_len = 1 + json_bytes.len();
358 out.extend_from_slice(&(total_len as i32).to_be_bytes());
359 out.extend_from_slice(&[1u8]);
360 out.extend_from_slice(&json_bytes);
361 } else {
362 out.extend_from_slice(&(-1i32).to_be_bytes());
363 }
364 },
365 _ => {
366 // Check for VideoCaption composite type
367 if type_str.contains("VideoCaption") {
368 // Handle nested composite types
369 if type_str.starts_with("Option <") || type_str.starts_with("Option<") {
370 // Optional nested composite
371 quote! {
372 out.extend_from_slice(&0i32.to_be_bytes()); // Composite OID (0 for embedded composites)
373 if let Some(ref caption) = self.#name {
374 // Serialize nested composite: field count + fields
375 let mut nested_buf = BytesMut::new();
376 nested_buf.extend_from_slice(&3i32.to_be_bytes()); // VideoCaption has 3 fields
377
378 // lang field (enum as text)
379 nested_buf.extend_from_slice(&25i32.to_be_bytes());
380 let lang_bytes = caption.lang.to_string().into_bytes();
381 nested_buf.extend_from_slice(&(lang_bytes.len() as i32).to_be_bytes());
382 nested_buf.extend_from_slice(&lang_bytes);
383
384 // mime_type field (enum as text)
385 nested_buf.extend_from_slice(&25i32.to_be_bytes());
386 let mime_bytes = caption.mime_type.to_string().into_bytes();
387 nested_buf.extend_from_slice(&(mime_bytes.len() as i32).to_be_bytes());
388 nested_buf.extend_from_slice(&mime_bytes);
389
390 // cid field (bytea)
391 nested_buf.extend_from_slice(&17i32.to_be_bytes());
392 nested_buf.extend_from_slice(&(caption.cid.len() as i32).to_be_bytes());
393 nested_buf.extend_from_slice(&caption.cid);
394
395 out.extend_from_slice(&(nested_buf.len() as i32).to_be_bytes());
396 out.extend_from_slice(&nested_buf);
397 } else {
398 out.extend_from_slice(&(-1i32).to_be_bytes()); // NULL
399 }
400 }
401 } else {
402 // Required nested composite - shouldn't happen for VideoCaption in our schema
403 quote! {
404 compile_error!("Required nested composite types not yet supported in PgToSql")
405 }
406 }
407 } else if type_str.starts_with("Option <") || type_str.starts_with("Option<") {
408 // For optional enums, serialize as text
409 quote! {
410 out.extend_from_slice(&25i32.to_be_bytes());
411 if let Some(ref val) = self.#name {
412 let bytes = val.to_string().into_bytes();
413 out.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
414 out.extend_from_slice(&bytes);
415 } else {
416 out.extend_from_slice(&(-1i32).to_be_bytes());
417 }
418 }
419 } else {
420 quote! {
421 out.extend_from_slice(&25i32.to_be_bytes());
422 let bytes = self.#name.to_string().into_bytes();
423 out.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
424 out.extend_from_slice(&bytes);
425 }
426 }
427 }
428 }
429}