extern crate proc_macro; use proc_macro::TokenStream; use quote::{quote, quote_spanned}; use syn::spanned::Spanned; use syn::Error; use syn::{parse_quote, DeriveInput}; pub fn derive_static_type(input: TokenStream) -> TokenStream { TokenStream::from(match derive_static_type_inner(input) { Ok(tokens) => tokens, Err(err) => err.to_compile_error(), }) } fn static_type(ty: T) -> proc_macro2::TokenStream where T: quote::ToTokens, { quote!( <#ty as ::serde_dhall::StaticType>::static_type() ) } fn derive_for_struct( data: &syn::DataStruct, constraints: &mut Vec, ) -> Result { let fields = match &data.fields { syn::Fields::Named(fields) => fields .named .iter() .map(|f| { let name = f.ident.as_ref().unwrap().to_string(); let ty = &f.ty; (name, ty) }) .collect(), syn::Fields::Unnamed(fields) => fields .unnamed .iter() .enumerate() .map(|(i, f)| { let name = format!("_{}", i + 1); let ty = &f.ty; (name, ty) }) .collect(), syn::Fields::Unit => vec![], }; let entries = fields.into_iter().map(|(name, ty)| { constraints.push(ty.clone()); let ty = static_type(ty); quote!( (#name.to_owned(), #ty) ) }); Ok(quote! { ::serde_dhall::SimpleType::Record( vec![ #(#entries),* ].into_iter().collect() ) }) } fn derive_for_enum( data: &syn::DataEnum, constraints: &mut Vec, ) -> Result { let entries: Vec<_> = data .variants .iter() .map(|v| { let name = v.ident.to_string(); match &v.fields { syn::Fields::Unit => Ok(quote!( (#name.to_owned(), None) )), syn::Fields::Unnamed(fields) if fields.unnamed.is_empty() => { Ok(quote!( (#name.to_owned(), None) )) } syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { let ty = &fields.unnamed.iter().next().unwrap().ty; constraints.push(ty.clone()); let ty = static_type(ty); Ok(quote!( (#name.to_owned(), Some(#ty)) )) }, syn::Fields::Unnamed(_) => Err(Error::new( v.span(), "Derive StaticType: Variants with more than one field are not supported", )), syn::Fields::Named(fields) => { let entries = fields .named .iter() .map(|field| { constraints.push(field.ty.clone()); let ty = static_type(&field.ty); let name = field.ident.as_ref().unwrap().to_string(); quote!( (#name.to_owned(), #ty) ) }); let record = quote! {::serde_dhall::SimpleType::Record( vec![ #(#entries),* ].into_iter().collect() )}; Ok(quote!( (#name.to_owned(), Some(#record)) )) } } }) .collect::>()?; Ok(quote! { ::serde_dhall::SimpleType::Union( vec![ #(#entries),* ].into_iter().collect() ) }) } pub fn derive_static_type_inner( input: TokenStream, ) -> Result { let input: DeriveInput = syn::parse_macro_input::parse(input)?; // List of types that must impl StaticType let mut constraints = vec![]; let get_type = match &input.data { syn::Data::Struct(data) => derive_for_struct(data, &mut constraints)?, syn::Data::Enum(data) if data.variants.is_empty() => { return Err(Error::new( input.span(), "Derive StaticType: Empty enums are not supported", )) } syn::Data::Enum(data) => derive_for_enum(data, &mut constraints)?, syn::Data::Union(x) => { return Err(Error::new( x.union_token.span(), "Derive StaticType: Unions are not supported", )) } }; let mut generics = input.generics.clone(); generics.make_where_clause(); let (impl_generics, ty_generics, orig_where_clause) = generics.split_for_impl(); let orig_where_clause = orig_where_clause.unwrap(); // Hygienic errors let assertions = constraints.iter().enumerate().map(|(i, ty)| { // Ensure that ty: StaticType, with an appropriate span let assert_name = syn::Ident::new(&format!("_AssertType{}", i), ty.span()); let mut local_where_clause = orig_where_clause.clone(); local_where_clause .predicates .push(parse_quote!(#ty: ::serde_dhall::StaticType)); let phantoms = generics.params.iter().map(|param| match param { syn::GenericParam::Type(syn::TypeParam { ident, .. }) => { quote!(#ident) } syn::GenericParam::Lifetime(syn::LifetimeDef { lifetime, .. }) => quote!(&#lifetime ()), _ => unimplemented!(), }); quote_spanned! {ty.span()=> struct #assert_name #impl_generics #local_where_clause { _phantom: std::marker::PhantomData<(#(#phantoms),*)> } } }); // Ensure that all the fields have a StaticType impl let mut where_clause = orig_where_clause.clone(); for ty in constraints.iter() { where_clause .predicates .push(parse_quote!(#ty: ::serde_dhall::StaticType)); } let ident = &input.ident; let tokens = quote! { #[allow(unused_parens)] impl #impl_generics ::serde_dhall::StaticType for #ident #ty_generics #where_clause { fn static_type() -> ::serde_dhall::SimpleType { #(#assertions)* #get_type } } }; Ok(tokens) }