summaryrefslogtreecommitdiff
path: root/dhall_proc_macros/src
diff options
context:
space:
mode:
Diffstat (limited to 'dhall_proc_macros/src')
-rw-r--r--dhall_proc_macros/src/derive.rs179
-rw-r--r--dhall_proc_macros/src/lib.rs26
-rw-r--r--dhall_proc_macros/src/quote.rs227
3 files changed, 432 insertions, 0 deletions
diff --git a/dhall_proc_macros/src/derive.rs b/dhall_proc_macros/src/derive.rs
new file mode 100644
index 0000000..bcefb17
--- /dev/null
+++ b/dhall_proc_macros/src/derive.rs
@@ -0,0 +1,179 @@
+extern crate proc_macro;
+// use dhall_syntax::*;
+use proc_macro::TokenStream;
+use quote::{quote, quote_spanned};
+use syn::spanned::Spanned;
+use syn::Error;
+use syn::{parse_quote, DeriveInput};
+
+pub fn derive_simple_static_type(input: TokenStream) -> TokenStream {
+ TokenStream::from(match derive_simple_static_type_inner(input) {
+ Ok(tokens) => tokens,
+ Err(err) => err.to_compile_error(),
+ })
+}
+
+fn get_simple_static_type<T>(ty: T) -> proc_macro2::TokenStream
+where
+ T: quote::ToTokens,
+{
+ quote!(
+ <#ty as ::dhall::de::SimpleStaticType>::get_simple_static_type()
+ )
+}
+
+fn derive_for_struct(
+ data: &syn::DataStruct,
+ constraints: &mut Vec<syn::Type>,
+) -> Result<proc_macro2::TokenStream, Error> {
+ let fields = match &data.fields {
+ syn::Fields::Named(fields) => fields
+ .named
+ .iter()
+ .map(|f| {
+ let name = f.ident.as_ref().unwrap().to_string();
+ let ty = &f.ty;
+ (name, ty)
+ })
+ .collect(),
+ syn::Fields::Unnamed(fields) => fields
+ .unnamed
+ .iter()
+ .enumerate()
+ .map(|(i, f)| {
+ let name = format!("_{}", i + 1);
+ let ty = &f.ty;
+ (name, ty)
+ })
+ .collect(),
+ syn::Fields::Unit => vec![],
+ };
+ let fields = fields
+ .into_iter()
+ .map(|(name, ty)| {
+ let name = dhall_syntax::Label::from(name);
+ constraints.push(ty.clone());
+ let ty = get_simple_static_type(ty);
+ (name, quote!(#ty.into()))
+ })
+ .collect();
+ let record =
+ crate::quote::quote_exprf(dhall_syntax::ExprF::RecordType(fields));
+ Ok(quote! { dhall_syntax::rc(#record) })
+}
+
+fn derive_for_enum(
+ data: &syn::DataEnum,
+ constraints: &mut Vec<syn::Type>,
+) -> Result<proc_macro2::TokenStream, Error> {
+ let variants = data
+ .variants
+ .iter()
+ .map(|v| {
+ let name = dhall_syntax::Label::from(v.ident.to_string());
+ match &v.fields {
+ syn::Fields::Unit => Ok((name, None)),
+ syn::Fields::Unnamed(fields) if fields.unnamed.is_empty() => {
+ Ok((name, None))
+ }
+ syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
+ let ty = &fields.unnamed.iter().next().unwrap().ty;
+ constraints.push(ty.clone());
+ let ty = get_simple_static_type(ty);
+ Ok((name, Some(quote!(#ty.into()))))
+ }
+ syn::Fields::Unnamed(_) => Err(Error::new(
+ v.span(),
+ "Variants with more than one field are not supported",
+ )),
+ syn::Fields::Named(_) => Err(Error::new(
+ v.span(),
+ "Named variants are not supported",
+ )),
+ }
+ })
+ .collect::<Result<_, Error>>()?;
+
+ let union =
+ crate::quote::quote_exprf(dhall_syntax::ExprF::UnionType(variants));
+ Ok(quote! { dhall_syntax::rc(#union) })
+}
+
+pub fn derive_simple_static_type_inner(
+ input: TokenStream,
+) -> Result<proc_macro2::TokenStream, Error> {
+ let input: DeriveInput = syn::parse_macro_input::parse(input)?;
+
+ // List of types that must impl Type
+ let mut constraints = vec![];
+
+ let get_type = match &input.data {
+ syn::Data::Struct(data) => derive_for_struct(data, &mut constraints)?,
+ syn::Data::Enum(data) if data.variants.is_empty() => {
+ return Err(Error::new(
+ input.span(),
+ "Empty enums are not supported",
+ ))
+ }
+ syn::Data::Enum(data) => derive_for_enum(data, &mut constraints)?,
+ syn::Data::Union(x) => {
+ return Err(Error::new(
+ x.union_token.span(),
+ "Unions are not supported",
+ ))
+ }
+ };
+
+ let mut generics = input.generics.clone();
+ generics.make_where_clause();
+ let (impl_generics, ty_generics, orig_where_clause) =
+ generics.split_for_impl();
+ let orig_where_clause = orig_where_clause.unwrap();
+
+ // Hygienic errors
+ let assertions = constraints.iter().enumerate().map(|(i, ty)| {
+ // Ensure that ty: Type, with an appropriate span
+ let assert_name =
+ syn::Ident::new(&format!("_AssertType{}", i), ty.span());
+ let mut local_where_clause = orig_where_clause.clone();
+ local_where_clause
+ .predicates
+ .push(parse_quote!(#ty: ::dhall::de::SimpleStaticType));
+ let phantoms = generics.params.iter().map(|param| match param {
+ syn::GenericParam::Type(syn::TypeParam { ident, .. }) => {
+ quote!(#ident)
+ }
+ syn::GenericParam::Lifetime(syn::LifetimeDef {
+ lifetime, ..
+ }) => quote!(&#lifetime ()),
+ _ => unimplemented!(),
+ });
+ quote_spanned! {ty.span()=>
+ struct #assert_name #impl_generics #local_where_clause {
+ _phantom: std::marker::PhantomData<(#(#phantoms),*)>
+ };
+ }
+ });
+
+ // Ensure that all the fields have a Type impl
+ let mut where_clause = orig_where_clause.clone();
+ for ty in constraints.iter() {
+ where_clause
+ .predicates
+ .push(parse_quote!(#ty: ::dhall::de::SimpleStaticType));
+ }
+
+ let ident = &input.ident;
+ let tokens = quote! {
+ impl #impl_generics ::dhall::de::SimpleStaticType
+ for #ident #ty_generics
+ #where_clause {
+ fn get_simple_static_type<'get_simple_static_type>() ->
+ ::dhall::expr::SimpleType<'get_simple_static_type> {
+ #(#assertions)*
+ ::dhall::expr::SimpleType::from(#get_type)
+ }
+ }
+ };
+ Ok(tokens)
+}
diff --git a/dhall_proc_macros/src/lib.rs b/dhall_proc_macros/src/lib.rs
new file mode 100644
index 0000000..1124968
--- /dev/null
+++ b/dhall_proc_macros/src/lib.rs
@@ -0,0 +1,26 @@
+//! This crate contains the code-generation primitives for the [dhall-rust][dhall-rust] crate.
+//! This is highly unstable and breaks regularly; use at your own risk.
+//!
+//! [dhall-rust]: https://github.com/Nadrieril/dhall-rust
+
+extern crate proc_macro;
+
+mod derive;
+mod quote;
+
+use proc_macro::TokenStream;
+
+#[proc_macro]
+pub fn expr(input: TokenStream) -> TokenStream {
+ quote::expr(input)
+}
+
+#[proc_macro]
+pub fn subexpr(input: TokenStream) -> TokenStream {
+ quote::subexpr(input)
+}
+
+#[proc_macro_derive(SimpleStaticType)]
+pub fn derive_simple_static_type(input: TokenStream) -> TokenStream {
+ derive::derive_simple_static_type(input)
+}
diff --git a/dhall_proc_macros/src/quote.rs b/dhall_proc_macros/src/quote.rs
new file mode 100644
index 0000000..c2323fa
--- /dev/null
+++ b/dhall_proc_macros/src/quote.rs
@@ -0,0 +1,227 @@
+extern crate proc_macro;
+use dhall_syntax::context::Context;
+use dhall_syntax::*;
+use proc_macro2::TokenStream;
+use quote::quote;
+use std::collections::BTreeMap;
+
+pub fn expr(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+ let input_str = input.to_string();
+ let expr: SubExpr<_, Import> = parse_expr(&input_str).unwrap().unnote();
+ let no_import =
+ |_: &Import| -> X { panic!("Don't use import in dhall::expr!()") };
+ let expr = expr.map_embed(no_import);
+ let output = quote_expr(&expr.unroll(), &Context::new());
+ output.into()
+}
+
+pub fn subexpr(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+ let input_str = input.to_string();
+ let expr: SubExpr<_, Import> = parse_expr(&input_str).unwrap().unnote();
+ let no_import =
+ |_: &Import| -> X { panic!("Don't use import in dhall::subexpr!()") };
+ let expr = expr.map_embed(no_import);
+ let output = quote_subexpr(&expr, &Context::new());
+ output.into()
+}
+
+// Returns an expression of type ExprF<T, _, _>, where T is the
+// type of the subexpressions after interpolation.
+pub fn quote_exprf<TS>(expr: ExprF<TS, Label, X, X>) -> TokenStream
+where
+ TS: quote::ToTokens + std::fmt::Debug,
+{
+ use dhall_syntax::ExprF::*;
+ match expr {
+ Var(_) => unreachable!(),
+ Pi(x, t, b) => {
+ let x = quote_label(&x);
+ quote! { dhall_syntax::ExprF::Pi(#x, #t, #b) }
+ }
+ Lam(x, t, b) => {
+ let x = quote_label(&x);
+ quote! { dhall_syntax::ExprF::Lam(#x, #t, #b) }
+ }
+ App(f, a) => {
+ quote! { dhall_syntax::ExprF::App(#f, #a) }
+ }
+ Annot(x, t) => {
+ quote! { dhall_syntax::ExprF::Annot(#x, #t) }
+ }
+ Const(c) => {
+ let c = quote_const(c);
+ quote! { dhall_syntax::ExprF::Const(#c) }
+ }
+ Builtin(b) => {
+ let b = quote_builtin(b);
+ quote! { dhall_syntax::ExprF::Builtin(#b) }
+ }
+ BinOp(o, a, b) => {
+ let o = quote_binop(o);
+ quote! { dhall_syntax::ExprF::BinOp(#o, #a, #b) }
+ }
+ NaturalLit(n) => {
+ quote! { dhall_syntax::ExprF::NaturalLit(#n) }
+ }
+ BoolLit(b) => {
+ quote! { dhall_syntax::ExprF::BoolLit(#b) }
+ }
+ SomeLit(x) => {
+ quote! { dhall_syntax::ExprF::SomeLit(#x) }
+ }
+ EmptyListLit(t) => {
+ quote! { dhall_syntax::ExprF::EmptyListLit(#t) }
+ }
+ NEListLit(es) => {
+ let es = quote_vec(es);
+ quote! { dhall_syntax::ExprF::NEListLit(#es) }
+ }
+ RecordType(m) => {
+ let m = quote_map(m);
+ quote! { dhall_syntax::ExprF::RecordType(#m) }
+ }
+ RecordLit(m) => {
+ let m = quote_map(m);
+ quote! { dhall_syntax::ExprF::RecordLit(#m) }
+ }
+ UnionType(m) => {
+ let m = quote_opt_map(m);
+ quote! { dhall_syntax::ExprF::UnionType(#m) }
+ }
+ e => unimplemented!("{:?}", e),
+ }
+}
+
+// Returns an expression of type SubExpr<_, _>. Expects interpolated variables
+// to be of type SubExpr<_, _>.
+fn quote_subexpr(
+ expr: &SubExpr<X, X>,
+ ctx: &Context<Label, ()>,
+) -> TokenStream {
+ use dhall_syntax::ExprF::*;
+ match expr.as_ref().map_ref_with_special_handling_of_binders(
+ |e| quote_subexpr(e, ctx),
+ |l, e| quote_subexpr(e, &ctx.insert(l.clone(), ())),
+ |_| unreachable!(),
+ |_| unreachable!(),
+ Label::clone,
+ ) {
+ Var(V(ref s, n)) => {
+ match ctx.lookup(s, n) {
+ // Non-free variable; interpolates as itself
+ Some(()) => {
+ let s: String = s.into();
+ let var = quote! { dhall_syntax::V(#s.into(), #n) };
+ rc(quote! { dhall_syntax::ExprF::Var(#var) })
+ }
+ // Free variable; interpolates as a rust variable
+ None => {
+ let s: String = s.into();
+ // TODO: insert appropriate shifts ?
+ let v: TokenStream = s.parse().unwrap();
+ quote! { {
+ let x: dhall_syntax::SubExpr<_, _> = #v.clone();
+ x
+ } }
+ }
+ }
+ }
+ e => rc(quote_exprf(e)),
+ }
+}
+
+// Returns an expression of type Expr<_, _>. Expects interpolated variables
+// to be of type SubExpr<_, _>.
+fn quote_expr(expr: &Expr<X, X>, ctx: &Context<Label, ()>) -> TokenStream {
+ use dhall_syntax::ExprF::*;
+ match expr.map_ref_with_special_handling_of_binders(
+ |e| quote_subexpr(e, ctx),
+ |l, e| quote_subexpr(e, &ctx.insert(l.clone(), ())),
+ |_| unreachable!(),
+ |_| unreachable!(),
+ Label::clone,
+ ) {
+ Var(V(ref s, n)) => {
+ match ctx.lookup(s, n) {
+ // Non-free variable; interpolates as itself
+ Some(()) => {
+ let s: String = s.into();
+ let var = quote! { dhall_syntax::V(#s.into(), #n) };
+ quote! { dhall_syntax::ExprF::Var(#var) }
+ }
+ // Free variable; interpolates as a rust variable
+ None => {
+ let s: String = s.into();
+ // TODO: insert appropriate shifts ?
+ let v: TokenStream = s.parse().unwrap();
+ quote! { {
+ let x: dhall_syntax::SubExpr<_, _> = #v.clone();
+ x.unroll()
+ } }
+ }
+ }
+ }
+ e => quote_exprf(e),
+ }
+}
+
+fn quote_builtin(b: Builtin) -> TokenStream {
+ format!("dhall_syntax::Builtin::{:?}", b).parse().unwrap()
+}
+
+fn quote_const(c: Const) -> TokenStream {
+ format!("dhall_syntax::Const::{:?}", c).parse().unwrap()
+}
+
+fn quote_binop(b: BinOp) -> TokenStream {
+ format!("dhall_syntax::BinOp::{:?}", b).parse().unwrap()
+}
+
+fn quote_label(l: &Label) -> TokenStream {
+ let l = String::from(l);
+ quote! { dhall_syntax::Label::from(#l) }
+}
+
+fn rc(x: TokenStream) -> TokenStream {
+ quote! { dhall_syntax::rc(#x) }
+}
+
+fn quote_opt<TS>(x: Option<TS>) -> TokenStream
+where
+ TS: quote::ToTokens + std::fmt::Debug,
+{
+ match x {
+ Some(x) => quote!(Some(#x)),
+ None => quote!(None),
+ }
+}
+
+fn quote_vec<TS>(e: Vec<TS>) -> TokenStream
+where
+ TS: quote::ToTokens + std::fmt::Debug,
+{
+ quote! { vec![ #(#e),* ] }
+}
+
+fn quote_map<TS>(m: BTreeMap<Label, TS>) -> TokenStream
+where
+ TS: quote::ToTokens + std::fmt::Debug,
+{
+ let entries = m.into_iter().map(|(k, v)| {
+ let k = quote_label(&k);
+ quote!(m.insert(#k, #v);)
+ });
+ quote! { {
+ use std::collections::BTreeMap;
+ let mut m = BTreeMap::new();
+ #( #entries )*
+ m
+ } }
+}
+
+fn quote_opt_map<TS>(m: BTreeMap<Label, Option<TS>>) -> TokenStream
+where
+ TS: quote::ToTokens + std::fmt::Debug,
+{
+ quote_map(m.into_iter().map(|(k, v)| (k, quote_opt(v))).collect())
+}