use proc_macro::TokenStream; use quote::quote; use syn::{parse_macro_input, Data, DeriveInput, Fields, Ident, Type}; #[proc_macro_derive(PostgresComposite, attributes(composite))] pub fn derive_postgres_composite(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); let name = &input.ident; let (sql_type, field_count) = extract_metadata(&input); let fields = extract_fields(&input); let from_reader_impl = generate_from_reader(&fields); let write_to_impl = generate_write_to(&fields); let write_pg_impl = generate_write_pg(&fields); let field_count_i32 = field_count as i32; let expanded = quote! { impl FromSql<#sql_type, Pg> for #name { fn from_sql(bytes: ::RawValue<'_>) -> deserialize::Result { let bytes = bytes.as_bytes(); let mut reader = CompositeReader::new(bytes); let field_count = reader.read_field_count()?; if field_count != #field_count_i32 { return Err(format!("{}: expected {} fields, got {}", stringify!(#name), #field_count_i32, field_count).into()); } Ok(Self { #from_reader_impl }) } } impl ToSql<#sql_type, Pg> for #name { fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result { CompositeWriter::write_header(out, #field_count_i32)?; let mut writer = CompositeWriter::with_output(out); #write_to_impl Ok(serialize::IsNull::No) } } impl PgToSql for #name { fn to_sql(&self, _ty: &Type, out: &mut BytesMut) -> Result> { out.extend_from_slice(&(#field_count_i32).to_be_bytes()); #write_pg_impl Ok(IsNull::No) } fn accepts(_ty: &Type) -> bool { true } postgres_types::to_sql_checked!(); } }; TokenStream::from(expanded) } fn extract_metadata(input: &DeriveInput) -> (Ident, usize) { // Extract SQL type from #[composite(sql_type = "...")] attribute let sql_type = input .attrs .iter() .find(|attr| attr.path().is_ident("composite")) .and_then(|attr| { attr.parse_args::().ok().and_then(|meta| { if let syn::Meta::NameValue(nv) = meta { if nv.path.is_ident("sql_type") { if let syn::Expr::Lit(lit) = nv.value { if let syn::Lit::Str(s) = lit.lit { return Some(syn::Ident::new(&s.value(), s.span())); } } } } None }) }) .expect("Missing #[composite(sql_type = \"...\")] attribute"); // Count fields let field_count = match &input.data { Data::Struct(data) => match &data.fields { Fields::Named(fields) => fields.named.len(), _ => panic!("Only named fields are supported"), }, _ => panic!("Only structs are supported"), }; (sql_type, field_count) } fn extract_fields(input: &DeriveInput) -> Vec<(Ident, Type)> { match &input.data { Data::Struct(data) => match &data.fields { Fields::Named(fields) => fields .named .iter() .map(|f| (f.ident.clone().unwrap(), f.ty.clone())) .collect(), _ => panic!("Only named fields are supported"), }, _ => panic!("Only structs are supported"), } } fn generate_from_reader(fields: &[(Ident, Type)]) -> proc_macro2::TokenStream { let field_reads = fields.iter().map(|(name, ty)| { let read_expr = generate_read_expr(name, ty); quote! { #name: #read_expr, } }); quote! { #(#field_reads)* } } fn generate_write_to(fields: &[(Ident, Type)]) -> proc_macro2::TokenStream { let field_writes = fields.iter().map(|(name, ty)| { generate_write_expr(name, ty) }); quote! { #(#field_writes)* } } fn generate_write_pg(fields: &[(Ident, Type)]) -> proc_macro2::TokenStream { let field_writes = fields.iter().map(|(name, ty)| { generate_pg_write_expr(name, ty) }); quote! { #(#field_writes)* } } fn generate_read_expr(name: &Ident, ty: &Type) -> proc_macro2::TokenStream { let type_str = quote!(#ty).to_string(); // Handle different types match type_str.as_str() { "String" => quote! { reader.read_text_field()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))? }, "Option < String >" | "Option" => quote! { reader.read_text_field()? }, "Vec < u8 >" | "Vec" => quote! { reader.read_bytea_field()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))? }, "Option < Vec < u8 > >" | "Option>" => quote! { reader.read_bytea_field()? }, "i32" => quote! { reader.read_int_field()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))? }, "Option < i32 >" | "Option" => quote! { reader.read_int_field()? }, "i64" => quote! { reader.read_bigint_field()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))? }, "Option < i64 >" | "Option" => quote! { reader.read_bigint_field()? }, "bool" => quote! { reader.read_bool_field()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))? }, "Option < bool >" | "Option" => quote! { reader.read_bool_field()? }, "DateTime < Utc >" | "DateTime" => quote! { reader.read_timestamptz_field()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))? }, "Option < DateTime < Utc > >" | "Option>" => quote! { reader.read_timestamptz_field()? }, "Option < serde_json :: Value >" | "Option" => quote! { reader.read_jsonb_field()? }, _ => { // Check if it's an Option type if type_str.starts_with("Option <") || type_str.starts_with("Option<") { // Extract inner type if let Type::Path(path) = ty { if let Some(segment) = path.path.segments.last() { if segment.ident == "Option" { if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() { // Check if inner type is VideoCaption (composite type) let inner_type_str = quote!(#inner_ty).to_string(); if inner_type_str == "VideoCaption" { return quote! { reader.read_composite_field::<#inner_ty, PostVideoCaption>()? }; } else { return quote! { reader.read_enum_field::<#inner_ty>()? }; } } } } } } quote! { reader.read_enum_field::<#ty>()? } } else { // Check if it's VideoCaption or another composite type if type_str == "VideoCaption" { quote! { reader.read_composite_field::<#ty, PostVideoCaption>()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))? } } else { // Assume it's a required enum type quote! { reader.read_enum_field::<#ty>()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))? } } } } } } fn generate_write_expr(name: &Ident, ty: &Type) -> proc_macro2::TokenStream { let type_str = quote!(#ty).to_string(); match type_str.as_str() { "String" => quote! { writer.write_text_field(Some(&self.#name))?; }, "Option < String >" | "Option" => quote! { writer.write_text_field(self.#name.as_deref())?; }, "Vec < u8 >" | "Vec" => quote! { writer.write_bytea_field(Some(&self.#name))?; }, "Option < Vec < u8 > >" | "Option>" => quote! { writer.write_bytea_field(self.#name.as_deref())?; }, "i32" => quote! { writer.write_int_field(Some(self.#name))?; }, "Option < i32 >" | "Option" => quote! { writer.write_int_field(self.#name)?; }, "i64" => quote! { writer.write_bigint_field(Some(self.#name))?; }, "Option < i64 >" | "Option" => quote! { writer.write_bigint_field(self.#name)?; }, "bool" => quote! { writer.write_bool_field(Some(self.#name))?; }, "Option < bool >" | "Option" => quote! { writer.write_bool_field(self.#name)?; }, "DateTime < Utc >" | "DateTime" => quote! { writer.write_timestamptz_field(Some(&self.#name))?; }, "Option < DateTime < Utc > >" | "Option>" => quote! { writer.write_timestamptz_field(self.#name.as_ref())?; }, "Option < serde_json :: Value >" | "Option" => quote! { writer.write_jsonb_field(self.#name.as_ref())?; }, _ => { // Check if it's an Option type if type_str.starts_with("Option <") || type_str.starts_with("Option<") { // Check if it's Option if type_str.contains("VideoCaption") { quote! { writer.write_composite_field::(self.#name.as_ref())?; } } else { quote! { writer.write_enum_field(self.#name.as_ref())?; } } } else if type_str == "VideoCaption" { // Required VideoCaption quote! { writer.write_composite_field::(Some(&self.#name))?; } } else { // Required enum quote! { writer.write_enum_field(Some(&self.#name))?; } } } } } fn generate_pg_write_expr(name: &Ident, ty: &Type) -> proc_macro2::TokenStream { let type_str = quote!(#ty).to_string(); match type_str.as_str() { "String" => quote! { out.extend_from_slice(&25i32.to_be_bytes()); let bytes = self.#name.as_bytes(); out.extend_from_slice(&(bytes.len() as i32).to_be_bytes()); out.extend_from_slice(bytes); }, "Option < String >" | "Option" => quote! { out.extend_from_slice(&25i32.to_be_bytes()); if let Some(ref val) = self.#name { let bytes = val.as_bytes(); out.extend_from_slice(&(bytes.len() as i32).to_be_bytes()); out.extend_from_slice(bytes); } else { out.extend_from_slice(&(-1i32).to_be_bytes()); } }, "Vec < u8 >" | "Vec" => quote! { out.extend_from_slice(&17i32.to_be_bytes()); out.extend_from_slice(&(self.#name.len() as i32).to_be_bytes()); out.extend_from_slice(&self.#name); }, "Option < Vec < u8 > >" | "Option>" => quote! { out.extend_from_slice(&17i32.to_be_bytes()); if let Some(ref val) = self.#name { out.extend_from_slice(&(val.len() as i32).to_be_bytes()); out.extend_from_slice(val); } else { out.extend_from_slice(&(-1i32).to_be_bytes()); } }, "i32" => quote! { out.extend_from_slice(&23i32.to_be_bytes()); out.extend_from_slice(&4i32.to_be_bytes()); out.extend_from_slice(&self.#name.to_be_bytes()); }, "Option < i32 >" | "Option" => quote! { out.extend_from_slice(&23i32.to_be_bytes()); if let Some(val) = self.#name { out.extend_from_slice(&4i32.to_be_bytes()); out.extend_from_slice(&val.to_be_bytes()); } else { out.extend_from_slice(&(-1i32).to_be_bytes()); } }, "i64" => quote! { out.extend_from_slice(&20i32.to_be_bytes()); out.extend_from_slice(&8i32.to_be_bytes()); out.extend_from_slice(&self.#name.to_be_bytes()); }, "Option < i64 >" | "Option" => quote! { out.extend_from_slice(&20i32.to_be_bytes()); if let Some(val) = self.#name { out.extend_from_slice(&8i32.to_be_bytes()); out.extend_from_slice(&val.to_be_bytes()); } else { out.extend_from_slice(&(-1i32).to_be_bytes()); } }, "bool" => quote! { out.extend_from_slice(&16i32.to_be_bytes()); out.extend_from_slice(&1i32.to_be_bytes()); out.extend_from_slice(&[if self.#name { 1 } else { 0 }]); }, "Option < bool >" | "Option" => quote! { out.extend_from_slice(&16i32.to_be_bytes()); if let Some(val) = self.#name { out.extend_from_slice(&1i32.to_be_bytes()); out.extend_from_slice(&[if val { 1 } else { 0 }]); } else { out.extend_from_slice(&(-1i32).to_be_bytes()); } }, "DateTime < Utc >" | "DateTime" => quote! { const PG_EPOCH_SECS: i64 = 946684800; let timestamp_secs = self.#name.timestamp(); let timestamp_micros = (timestamp_secs - PG_EPOCH_SECS) * 1_000_000 + (self.#name.timestamp_subsec_micros() as i64); out.extend_from_slice(&1184i32.to_be_bytes()); out.extend_from_slice(&8i32.to_be_bytes()); out.extend_from_slice(×tamp_micros.to_be_bytes()); }, "Option < DateTime < Utc > >" | "Option>" => quote! { out.extend_from_slice(&1184i32.to_be_bytes()); if let Some(ref dt) = self.#name { const PG_EPOCH_SECS: i64 = 946684800; let timestamp_secs = dt.timestamp(); let timestamp_micros = (timestamp_secs - PG_EPOCH_SECS) * 1_000_000 + (dt.timestamp_subsec_micros() as i64); out.extend_from_slice(&8i32.to_be_bytes()); out.extend_from_slice(×tamp_micros.to_be_bytes()); } else { out.extend_from_slice(&(-1i32).to_be_bytes()); } }, "Option < serde_json :: Value >" | "Option" => quote! { out.extend_from_slice(&3802i32.to_be_bytes()); if let Some(ref val) = self.#name { let json_bytes = serde_json::to_vec(val)?; let total_len = 1 + json_bytes.len(); out.extend_from_slice(&(total_len as i32).to_be_bytes()); out.extend_from_slice(&[1u8]); out.extend_from_slice(&json_bytes); } else { out.extend_from_slice(&(-1i32).to_be_bytes()); } }, _ => { // Check for VideoCaption composite type if type_str.contains("VideoCaption") { // Handle nested composite types if type_str.starts_with("Option <") || type_str.starts_with("Option<") { // Optional nested composite quote! { out.extend_from_slice(&0i32.to_be_bytes()); // Composite OID (0 for embedded composites) if let Some(ref caption) = self.#name { // Serialize nested composite: field count + fields let mut nested_buf = BytesMut::new(); nested_buf.extend_from_slice(&3i32.to_be_bytes()); // VideoCaption has 3 fields // lang field (enum as text) nested_buf.extend_from_slice(&25i32.to_be_bytes()); let lang_bytes = caption.lang.to_string().into_bytes(); nested_buf.extend_from_slice(&(lang_bytes.len() as i32).to_be_bytes()); nested_buf.extend_from_slice(&lang_bytes); // mime_type field (enum as text) nested_buf.extend_from_slice(&25i32.to_be_bytes()); let mime_bytes = caption.mime_type.to_string().into_bytes(); nested_buf.extend_from_slice(&(mime_bytes.len() as i32).to_be_bytes()); nested_buf.extend_from_slice(&mime_bytes); // cid field (bytea) nested_buf.extend_from_slice(&17i32.to_be_bytes()); nested_buf.extend_from_slice(&(caption.cid.len() as i32).to_be_bytes()); nested_buf.extend_from_slice(&caption.cid); out.extend_from_slice(&(nested_buf.len() as i32).to_be_bytes()); out.extend_from_slice(&nested_buf); } else { out.extend_from_slice(&(-1i32).to_be_bytes()); // NULL } } } else { // Required nested composite - shouldn't happen for VideoCaption in our schema quote! { compile_error!("Required nested composite types not yet supported in PgToSql") } } } else if type_str.starts_with("Option <") || type_str.starts_with("Option<") { // For optional enums, serialize as text quote! { out.extend_from_slice(&25i32.to_be_bytes()); if let Some(ref val) = self.#name { let bytes = val.to_string().into_bytes(); out.extend_from_slice(&(bytes.len() as i32).to_be_bytes()); out.extend_from_slice(&bytes); } else { out.extend_from_slice(&(-1i32).to_be_bytes()); } } } else { quote! { out.extend_from_slice(&25i32.to_be_bytes()); let bytes = self.#name.to_string().into_bytes(); out.extend_from_slice(&(bytes.len() as i32).to_be_bytes()); out.extend_from_slice(&bytes); } } } } }