// Rust AppView - highly experimental!
// Source listing: `experiments` (429 lines, 19 kB).
1use proc_macro::TokenStream; 2use quote::quote; 3use syn::{parse_macro_input, Data, DeriveInput, Fields, Ident, Type}; 4 5#[proc_macro_derive(PostgresComposite, attributes(composite))] 6pub fn derive_postgres_composite(input: TokenStream) -> TokenStream { 7 let input = parse_macro_input!(input as DeriveInput); 8 9 let name = &input.ident; 10 let (sql_type, field_count) = extract_metadata(&input); 11 let fields = extract_fields(&input); 12 13 let from_reader_impl = generate_from_reader(&fields); 14 let write_to_impl = generate_write_to(&fields); 15 let write_pg_impl = generate_write_pg(&fields); 16 17 let field_count_i32 = field_count as i32; 18 19 let expanded = quote! { 20 impl FromSql<#sql_type, Pg> for #name { 21 fn from_sql(bytes: <Pg as diesel::backend::Backend>::RawValue<'_>) -> deserialize::Result<Self> { 22 let bytes = bytes.as_bytes(); 23 let mut reader = CompositeReader::new(bytes); 24 25 let field_count = reader.read_field_count()?; 26 if field_count != #field_count_i32 { 27 return Err(format!("{}: expected {} fields, got {}", 28 stringify!(#name), #field_count_i32, field_count).into()); 29 } 30 31 Ok(Self { 32 #from_reader_impl 33 }) 34 } 35 } 36 37 impl ToSql<#sql_type, Pg> for #name { 38 fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result { 39 CompositeWriter::write_header(out, #field_count_i32)?; 40 let mut writer = CompositeWriter::with_output(out); 41 #write_to_impl 42 Ok(serialize::IsNull::No) 43 } 44 } 45 46 impl PgToSql for #name { 47 fn to_sql(&self, _ty: &Type, out: &mut BytesMut) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> { 48 out.extend_from_slice(&(#field_count_i32).to_be_bytes()); 49 #write_pg_impl 50 Ok(IsNull::No) 51 } 52 53 fn accepts(_ty: &Type) -> bool { 54 true 55 } 56 57 postgres_types::to_sql_checked!(); 58 } 59 }; 60 61 TokenStream::from(expanded) 62} 63 64fn extract_metadata(input: &DeriveInput) -> (Ident, usize) { 65 // Extract SQL type from #[composite(sql_type = "...")] attribute 66 let 
sql_type = input 67 .attrs 68 .iter() 69 .find(|attr| attr.path().is_ident("composite")) 70 .and_then(|attr| { 71 attr.parse_args::<syn::Meta>().ok().and_then(|meta| { 72 if let syn::Meta::NameValue(nv) = meta { 73 if nv.path.is_ident("sql_type") { 74 if let syn::Expr::Lit(lit) = nv.value { 75 if let syn::Lit::Str(s) = lit.lit { 76 return Some(syn::Ident::new(&s.value(), s.span())); 77 } 78 } 79 } 80 } 81 None 82 }) 83 }) 84 .expect("Missing #[composite(sql_type = \"...\")] attribute"); 85 86 // Count fields 87 let field_count = match &input.data { 88 Data::Struct(data) => match &data.fields { 89 Fields::Named(fields) => fields.named.len(), 90 _ => panic!("Only named fields are supported"), 91 }, 92 _ => panic!("Only structs are supported"), 93 }; 94 95 (sql_type, field_count) 96} 97 98fn extract_fields(input: &DeriveInput) -> Vec<(Ident, Type)> { 99 match &input.data { 100 Data::Struct(data) => match &data.fields { 101 Fields::Named(fields) => fields 102 .named 103 .iter() 104 .map(|f| (f.ident.clone().unwrap(), f.ty.clone())) 105 .collect(), 106 _ => panic!("Only named fields are supported"), 107 }, 108 _ => panic!("Only structs are supported"), 109 } 110} 111 112fn generate_from_reader(fields: &[(Ident, Type)]) -> proc_macro2::TokenStream { 113 let field_reads = fields.iter().map(|(name, ty)| { 114 let read_expr = generate_read_expr(name, ty); 115 quote! { #name: #read_expr, } 116 }); 117 118 quote! { #(#field_reads)* } 119} 120 121fn generate_write_to(fields: &[(Ident, Type)]) -> proc_macro2::TokenStream { 122 let field_writes = fields.iter().map(|(name, ty)| { 123 generate_write_expr(name, ty) 124 }); 125 126 quote! { #(#field_writes)* } 127} 128 129fn generate_write_pg(fields: &[(Ident, Type)]) -> proc_macro2::TokenStream { 130 let field_writes = fields.iter().map(|(name, ty)| { 131 generate_pg_write_expr(name, ty) 132 }); 133 134 quote! 
{ #(#field_writes)* } 135} 136 137fn generate_read_expr(name: &Ident, ty: &Type) -> proc_macro2::TokenStream { 138 let type_str = quote!(#ty).to_string(); 139 140 // Handle different types 141 match type_str.as_str() { 142 "String" => quote! { 143 reader.read_text_field()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))? 144 }, 145 "Option < String >" | "Option<String>" => quote! { 146 reader.read_text_field()? 147 }, 148 "Vec < u8 >" | "Vec<u8>" => quote! { 149 reader.read_bytea_field()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))? 150 }, 151 "Option < Vec < u8 > >" | "Option<Vec<u8>>" => quote! { 152 reader.read_bytea_field()? 153 }, 154 "i32" => quote! { 155 reader.read_int_field()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))? 156 }, 157 "Option < i32 >" | "Option<i32>" => quote! { 158 reader.read_int_field()? 159 }, 160 "i64" => quote! { 161 reader.read_bigint_field()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))? 162 }, 163 "Option < i64 >" | "Option<i64>" => quote! { 164 reader.read_bigint_field()? 165 }, 166 "bool" => quote! { 167 reader.read_bool_field()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))? 168 }, 169 "Option < bool >" | "Option<bool>" => quote! { 170 reader.read_bool_field()? 171 }, 172 "DateTime < Utc >" | "DateTime<Utc>" => quote! { 173 reader.read_timestamptz_field()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))? 174 }, 175 "Option < DateTime < Utc > >" | "Option<DateTime<Utc>>" => quote! { 176 reader.read_timestamptz_field()? 177 }, 178 "Option < serde_json :: Value >" | "Option<serde_json::Value>" => quote! { 179 reader.read_jsonb_field()? 
180 }, 181 _ => { 182 // Check if it's an Option type 183 if type_str.starts_with("Option <") || type_str.starts_with("Option<") { 184 // Extract inner type 185 if let Type::Path(path) = ty { 186 if let Some(segment) = path.path.segments.last() { 187 if segment.ident == "Option" { 188 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { 189 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() { 190 // Check if inner type is VideoCaption (composite type) 191 let inner_type_str = quote!(#inner_ty).to_string(); 192 if inner_type_str == "VideoCaption" { 193 return quote! { reader.read_composite_field::<#inner_ty, PostVideoCaption>()? }; 194 } else { 195 return quote! { reader.read_enum_field::<#inner_ty>()? }; 196 } 197 } 198 } 199 } 200 } 201 } 202 quote! { reader.read_enum_field::<#ty>()? } 203 } else { 204 // Check if it's VideoCaption or another composite type 205 if type_str == "VideoCaption" { 206 quote! { 207 reader.read_composite_field::<#ty, PostVideoCaption>()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))? 208 } 209 } else { 210 // Assume it's a required enum type 211 quote! { 212 reader.read_enum_field::<#ty>()?.ok_or(concat!(stringify!(#name), " cannot be NULL"))? 213 } 214 } 215 } 216 } 217 } 218} 219 220fn generate_write_expr(name: &Ident, ty: &Type) -> proc_macro2::TokenStream { 221 let type_str = quote!(#ty).to_string(); 222 223 match type_str.as_str() { 224 "String" => quote! { writer.write_text_field(Some(&self.#name))?; }, 225 "Option < String >" | "Option<String>" => quote! { writer.write_text_field(self.#name.as_deref())?; }, 226 "Vec < u8 >" | "Vec<u8>" => quote! { writer.write_bytea_field(Some(&self.#name))?; }, 227 "Option < Vec < u8 > >" | "Option<Vec<u8>>" => quote! { writer.write_bytea_field(self.#name.as_deref())?; }, 228 "i32" => quote! { writer.write_int_field(Some(self.#name))?; }, 229 "Option < i32 >" | "Option<i32>" => quote! { writer.write_int_field(self.#name)?; }, 230 "i64" => quote! 
{ writer.write_bigint_field(Some(self.#name))?; }, 231 "Option < i64 >" | "Option<i64>" => quote! { writer.write_bigint_field(self.#name)?; }, 232 "bool" => quote! { writer.write_bool_field(Some(self.#name))?; }, 233 "Option < bool >" | "Option<bool>" => quote! { writer.write_bool_field(self.#name)?; }, 234 "DateTime < Utc >" | "DateTime<Utc>" => quote! { writer.write_timestamptz_field(Some(&self.#name))?; }, 235 "Option < DateTime < Utc > >" | "Option<DateTime<Utc>>" => quote! { writer.write_timestamptz_field(self.#name.as_ref())?; }, 236 "Option < serde_json :: Value >" | "Option<serde_json::Value>" => quote! { writer.write_jsonb_field(self.#name.as_ref())?; }, 237 _ => { 238 // Check if it's an Option type 239 if type_str.starts_with("Option <") || type_str.starts_with("Option<") { 240 // Check if it's Option<VideoCaption> 241 if type_str.contains("VideoCaption") { 242 quote! { writer.write_composite_field::<VideoCaption, PostVideoCaption>(self.#name.as_ref())?; } 243 } else { 244 quote! { writer.write_enum_field(self.#name.as_ref())?; } 245 } 246 } else if type_str == "VideoCaption" { 247 // Required VideoCaption 248 quote! { writer.write_composite_field::<VideoCaption, PostVideoCaption>(Some(&self.#name))?; } 249 } else { 250 // Required enum 251 quote! { writer.write_enum_field(Some(&self.#name))?; } 252 } 253 } 254 } 255} 256 257fn generate_pg_write_expr(name: &Ident, ty: &Type) -> proc_macro2::TokenStream { 258 let type_str = quote!(#ty).to_string(); 259 260 match type_str.as_str() { 261 "String" => quote! { 262 out.extend_from_slice(&25i32.to_be_bytes()); 263 let bytes = self.#name.as_bytes(); 264 out.extend_from_slice(&(bytes.len() as i32).to_be_bytes()); 265 out.extend_from_slice(bytes); 266 }, 267 "Option < String >" | "Option<String>" => quote! 
{ 268 out.extend_from_slice(&25i32.to_be_bytes()); 269 if let Some(ref val) = self.#name { 270 let bytes = val.as_bytes(); 271 out.extend_from_slice(&(bytes.len() as i32).to_be_bytes()); 272 out.extend_from_slice(bytes); 273 } else { 274 out.extend_from_slice(&(-1i32).to_be_bytes()); 275 } 276 }, 277 "Vec < u8 >" | "Vec<u8>" => quote! { 278 out.extend_from_slice(&17i32.to_be_bytes()); 279 out.extend_from_slice(&(self.#name.len() as i32).to_be_bytes()); 280 out.extend_from_slice(&self.#name); 281 }, 282 "Option < Vec < u8 > >" | "Option<Vec<u8>>" => quote! { 283 out.extend_from_slice(&17i32.to_be_bytes()); 284 if let Some(ref val) = self.#name { 285 out.extend_from_slice(&(val.len() as i32).to_be_bytes()); 286 out.extend_from_slice(val); 287 } else { 288 out.extend_from_slice(&(-1i32).to_be_bytes()); 289 } 290 }, 291 "i32" => quote! { 292 out.extend_from_slice(&23i32.to_be_bytes()); 293 out.extend_from_slice(&4i32.to_be_bytes()); 294 out.extend_from_slice(&self.#name.to_be_bytes()); 295 }, 296 "Option < i32 >" | "Option<i32>" => quote! { 297 out.extend_from_slice(&23i32.to_be_bytes()); 298 if let Some(val) = self.#name { 299 out.extend_from_slice(&4i32.to_be_bytes()); 300 out.extend_from_slice(&val.to_be_bytes()); 301 } else { 302 out.extend_from_slice(&(-1i32).to_be_bytes()); 303 } 304 }, 305 "i64" => quote! { 306 out.extend_from_slice(&20i32.to_be_bytes()); 307 out.extend_from_slice(&8i32.to_be_bytes()); 308 out.extend_from_slice(&self.#name.to_be_bytes()); 309 }, 310 "Option < i64 >" | "Option<i64>" => quote! { 311 out.extend_from_slice(&20i32.to_be_bytes()); 312 if let Some(val) = self.#name { 313 out.extend_from_slice(&8i32.to_be_bytes()); 314 out.extend_from_slice(&val.to_be_bytes()); 315 } else { 316 out.extend_from_slice(&(-1i32).to_be_bytes()); 317 } 318 }, 319 "bool" => quote! 
{ 320 out.extend_from_slice(&16i32.to_be_bytes()); 321 out.extend_from_slice(&1i32.to_be_bytes()); 322 out.extend_from_slice(&[if self.#name { 1 } else { 0 }]); 323 }, 324 "Option < bool >" | "Option<bool>" => quote! { 325 out.extend_from_slice(&16i32.to_be_bytes()); 326 if let Some(val) = self.#name { 327 out.extend_from_slice(&1i32.to_be_bytes()); 328 out.extend_from_slice(&[if val { 1 } else { 0 }]); 329 } else { 330 out.extend_from_slice(&(-1i32).to_be_bytes()); 331 } 332 }, 333 "DateTime < Utc >" | "DateTime<Utc>" => quote! { 334 const PG_EPOCH_SECS: i64 = 946684800; 335 let timestamp_secs = self.#name.timestamp(); 336 let timestamp_micros = (timestamp_secs - PG_EPOCH_SECS) * 1_000_000 + (self.#name.timestamp_subsec_micros() as i64); 337 out.extend_from_slice(&1184i32.to_be_bytes()); 338 out.extend_from_slice(&8i32.to_be_bytes()); 339 out.extend_from_slice(&timestamp_micros.to_be_bytes()); 340 }, 341 "Option < DateTime < Utc > >" | "Option<DateTime<Utc>>" => quote! { 342 out.extend_from_slice(&1184i32.to_be_bytes()); 343 if let Some(ref dt) = self.#name { 344 const PG_EPOCH_SECS: i64 = 946684800; 345 let timestamp_secs = dt.timestamp(); 346 let timestamp_micros = (timestamp_secs - PG_EPOCH_SECS) * 1_000_000 + (dt.timestamp_subsec_micros() as i64); 347 out.extend_from_slice(&8i32.to_be_bytes()); 348 out.extend_from_slice(&timestamp_micros.to_be_bytes()); 349 } else { 350 out.extend_from_slice(&(-1i32).to_be_bytes()); 351 } 352 }, 353 "Option < serde_json :: Value >" | "Option<serde_json::Value>" => quote! 
{ 354 out.extend_from_slice(&3802i32.to_be_bytes()); 355 if let Some(ref val) = self.#name { 356 let json_bytes = serde_json::to_vec(val)?; 357 let total_len = 1 + json_bytes.len(); 358 out.extend_from_slice(&(total_len as i32).to_be_bytes()); 359 out.extend_from_slice(&[1u8]); 360 out.extend_from_slice(&json_bytes); 361 } else { 362 out.extend_from_slice(&(-1i32).to_be_bytes()); 363 } 364 }, 365 _ => { 366 // Check for VideoCaption composite type 367 if type_str.contains("VideoCaption") { 368 // Handle nested composite types 369 if type_str.starts_with("Option <") || type_str.starts_with("Option<") { 370 // Optional nested composite 371 quote! { 372 out.extend_from_slice(&0i32.to_be_bytes()); // Composite OID (0 for embedded composites) 373 if let Some(ref caption) = self.#name { 374 // Serialize nested composite: field count + fields 375 let mut nested_buf = BytesMut::new(); 376 nested_buf.extend_from_slice(&3i32.to_be_bytes()); // VideoCaption has 3 fields 377 378 // lang field (enum as text) 379 nested_buf.extend_from_slice(&25i32.to_be_bytes()); 380 let lang_bytes = caption.lang.to_string().into_bytes(); 381 nested_buf.extend_from_slice(&(lang_bytes.len() as i32).to_be_bytes()); 382 nested_buf.extend_from_slice(&lang_bytes); 383 384 // mime_type field (enum as text) 385 nested_buf.extend_from_slice(&25i32.to_be_bytes()); 386 let mime_bytes = caption.mime_type.to_string().into_bytes(); 387 nested_buf.extend_from_slice(&(mime_bytes.len() as i32).to_be_bytes()); 388 nested_buf.extend_from_slice(&mime_bytes); 389 390 // cid field (bytea) 391 nested_buf.extend_from_slice(&17i32.to_be_bytes()); 392 nested_buf.extend_from_slice(&(caption.cid.len() as i32).to_be_bytes()); 393 nested_buf.extend_from_slice(&caption.cid); 394 395 out.extend_from_slice(&(nested_buf.len() as i32).to_be_bytes()); 396 out.extend_from_slice(&nested_buf); 397 } else { 398 out.extend_from_slice(&(-1i32).to_be_bytes()); // NULL 399 } 400 } 401 } else { 402 // Required nested composite - 
shouldn't happen for VideoCaption in our schema 403 quote! { 404 compile_error!("Required nested composite types not yet supported in PgToSql") 405 } 406 } 407 } else if type_str.starts_with("Option <") || type_str.starts_with("Option<") { 408 // For optional enums, serialize as text 409 quote! { 410 out.extend_from_slice(&25i32.to_be_bytes()); 411 if let Some(ref val) = self.#name { 412 let bytes = val.to_string().into_bytes(); 413 out.extend_from_slice(&(bytes.len() as i32).to_be_bytes()); 414 out.extend_from_slice(&bytes); 415 } else { 416 out.extend_from_slice(&(-1i32).to_be_bytes()); 417 } 418 } 419 } else { 420 quote! { 421 out.extend_from_slice(&25i32.to_be_bytes()); 422 let bytes = self.#name.to_string().into_bytes(); 423 out.extend_from_slice(&(bytes.len() as i32).to_be_bytes()); 424 out.extend_from_slice(&bytes); 425 } 426 } 427 } 428 } 429}