From 0f5b69ac599a4b6979de465cf520ff5c283cfdbd Mon Sep 17 00:00:00 2001 From: Nadrieril Date: Sun, 7 Apr 2019 14:17:38 +0200 Subject: Use macros in typecheck; much cleaner --- dhall/src/typecheck.rs | 257 ++++++++++++++++++++++++------------------------- 1 file changed, 125 insertions(+), 132 deletions(-) (limited to 'dhall') diff --git a/dhall/src/typecheck.rs b/dhall/src/typecheck.rs index 9eead93..a0782f8 100644 --- a/dhall/src/typecheck.rs +++ b/dhall/src/typecheck.rs @@ -272,22 +272,33 @@ fn type_of_builtin(b: Builtin) -> Expr { } } -fn ensure_equal( - x: &Type, - y: &Type, - mkerr: F1, - mkmsg: F2, -) -> Result<(), TypeError> -where - S: std::fmt::Debug, - F1: FnOnce(TypeMessage) -> TypeError, - F2: FnOnce() -> TypeMessage, -{ - if prop_equal(x, y) { - Ok(()) - } else { - Err(mkerr(mkmsg())) - } +macro_rules! ensure_equal { + ($x:expr, $y:expr, $err:expr $(,)*) => { + if !prop_equal(&$x, &$y) { + return Err($err); + } + }; +} + +macro_rules! ensure_matches { + ($x:expr, $pat:pat => $branch:expr, $err:expr $(,)*) => { + match $x.unroll_ref()? { + $pat => $branch, + _ => return Err($err), + } + }; +} + +macro_rules! ensure_is_type { + ($x:expr, $err:expr $(,)*) => { + ensure_matches!($x, Const(Type) => {}, $err) + }; +} + +macro_rules! ensure_is_const { + ($x:expr, $err:expr $(,)*) => { + ensure_matches!($x, Const(k) => *k, $err) + }; } /// Type-check an expression and return the expression alongside its type @@ -301,16 +312,6 @@ pub fn type_with( use dhall_core::Const::*; use dhall_core::ExprF::*; let mkerr = |msg: TypeMessage<_>| TypeError::new(ctx, e.clone(), msg); - let ensure_const = - |x: &crate::expr::Type, msg: TypeMessage<_>| match x.unroll_ref()? { - Const(k) => Ok(*k), - _ => Err(mkerr(msg)), - }; - let ensure_is_type = - |x: &crate::expr::Type, msg: TypeMessage<_>| match x.unroll_ref()? { - Const(Type) => Ok(()), - _ => Err(mkerr(msg)), - }; let mktype = |ctx, x: SubExpr| Ok(type_with(ctx, x)?.normalize().into_type()); @@ -338,27 +339,23 @@ pub fn type_with( } Pi(x, tA, tB) => { let tA = mktype(ctx, tA.clone())?; - let kA = ensure_const( + let kA = ensure_is_const!( &tA.get_type(), - InvalidInputType(tA.clone().into_normalized()?), - )?; + mkerr(InvalidInputType(tA.into_normalized()?)), + ); let ctx2 = ctx .insert(x.clone(), tA.clone()) .map(|e| e.shift(1, &V(x.clone(), 0))); let tB = type_with(&ctx2, tB.clone())?; - let kB = match tB.get_type().unroll_ref()? { - Const(k) => *k, - _ => { - return Err(TypeError::new( - &ctx2, - e.clone(), - InvalidOutputType( - tB.get_type_move().into_normalized()?, - ), - )); - } - }; + let kB = ensure_is_const!( + &tB.get_type(), + TypeError::new( + &ctx2, + e.clone(), + InvalidOutputType(tB.get_type_move().into_normalized()?), + ), + ); match rule(kA, kB) { Err(()) => Err(mkerr(NoDependentTypes( @@ -380,19 +377,19 @@ pub fn type_with( let r = type_with(ctx, r)?; // Don't bother to provide a `let`-specific version of this error // message because this should never happen anyway - let kR = ensure_const( + let kR = ensure_is_const!( &r.get_type().get_type(), - InvalidInputType(r.get_type().clone().into_normalized()?), - )?; + mkerr(InvalidInputType(r.get_type_move().into_normalized()?)), + ); let ctx2 = ctx.insert(f.clone(), r.get_type().clone()); let b = type_with(&ctx2, b.clone())?; // Don't bother to provide a `let`-specific version of this error // message because this should never happen anyway - let kB = ensure_const( + let kB = ensure_is_const!( &b.get_type().get_type(), - InvalidOutputType(b.get_type().clone().into_normalized()?), - )?; + mkerr(InvalidOutputType(b.get_type_move().into_normalized()?)), + ); if let Err(()) = rule(kR, kB) { return Err(mkerr(NoDependentLet( @@ -423,26 +420,23 @@ pub fn type_with( let mut tf = f.get_type().clone(); while let Some(a) = iter.next() { seen_args.push(a.as_expr().clone()); - let (x, tx, tb) = match tf.unroll_ref()? { + let (x, tx, tb) = ensure_matches!(tf, Pi(x, tx, tb) => (x, tx, tb), - _ => { - return Err(mkerr(NotAFunction(Typed( - rc(App(f.into_expr(), seen_args)), - tf, - )))); - } - }; + mkerr(NotAFunction(Typed( + rc(App(f.into_expr(), seen_args)), + tf, + ))) + ); let tx = mktype(ctx, tx.clone())?; - ensure_equal(&tx, &a.get_type(), mkerr, || { - TypeMismatch( - Typed( - rc(App(f.as_expr().clone(), seen_args.clone())), - tf.clone(), - ), - tx.clone().into_normalized().unwrap(), - a.clone(), - ) - })?; + ensure_equal!( + tx, + a.get_type(), + mkerr(TypeMismatch( + Typed(rc(App(f.into_expr(), seen_args)), tf), + tx.into_normalized()?, + a, + )) + ); tf = mktype( ctx, subst_shift(&V(x.clone(), 0), &a.as_expr(), &tb), @@ -452,90 +446,93 @@ pub fn type_with( } Annot(x, t) => { let t = t.normalize().into_type(); - ensure_equal(&t, &x.get_type(), mkerr, || { - AnnotMismatch( - x.clone(), - t.clone().into_normalized().unwrap(), - ) - })?; + ensure_equal!( + t, + x.get_type(), + mkerr(AnnotMismatch(x, t.into_normalized()?)) + ); Ok(RetType(x.get_type_move())) } BoolIf(x, y, z) => { - ensure_equal( - &x.get_type(), - &mktype(ctx, rc(Builtin(Bool)))?, - mkerr, - || InvalidPredicate(x.clone()), - )?; - ensure_is_type( - &y.get_type().get_type(), - IfBranchMustBeTerm(true, y.clone()), - )?; + ensure_equal!( + x.get_type(), + mktype(ctx, rc(Builtin(Bool)))?, + mkerr(InvalidPredicate(x)), + ); + ensure_is_type!( + y.get_type().get_type(), + mkerr(IfBranchMustBeTerm(true, y)), + ); - ensure_is_type( - &z.get_type().get_type(), - IfBranchMustBeTerm(false, z.clone()), - )?; + ensure_is_type!( + z.get_type().get_type(), + mkerr(IfBranchMustBeTerm(false, z)), + ); - ensure_equal(&y.get_type(), &z.get_type(), mkerr, || { - IfBranchMismatch(y.clone(), z.clone()) - })?; + ensure_equal!( + y.get_type(), + z.get_type(), + mkerr(IfBranchMismatch(y, z)) + ); Ok(RetType(y.get_type_move())) } EmptyListLit(t) => { let t = t.normalize().into_type(); - ensure_is_type( - &t.get_type(), - InvalidListType(t.clone().into_normalized()?), - )?; + ensure_is_type!( + t.get_type(), + mkerr(InvalidListType(t.into_normalized()?)), + ); let t = t.into_normalized()?.into_expr(); Ok(RetExpr(dhall::expr!(List t))) } NEListLit(xs) => { let mut iter = xs.into_iter().enumerate(); let (_, x) = iter.next().unwrap(); - ensure_is_type( - &x.get_type().get_type(), - InvalidListType(x.get_type().clone().into_normalized()?), - )?; + ensure_is_type!( + x.get_type().get_type(), + mkerr(InvalidListType( + x.get_type_move().into_normalized()? + )), + ); for (i, y) in iter { - ensure_equal(&x.get_type(), &y.get_type(), mkerr, || { - InvalidListElement( + ensure_equal!( + x.get_type(), + y.get_type(), + mkerr(InvalidListElement( i, - x.get_type().clone().into_normalized().unwrap(), - y.clone(), - ) - })?; + x.get_type_move().into_normalized()?, + y + )) + ); } let t = x.get_type_move().into_normalized()?.into_expr(); Ok(RetExpr(dhall::expr!(List t))) } EmptyOptionalLit(t) => { let t = t.normalize().into_type(); - ensure_is_type( - &t.get_type(), - InvalidOptionalType(t.clone().into_normalized()?), - )?; + ensure_is_type!( + t.get_type(), + mkerr(InvalidOptionalType(t.into_normalized()?)), + ); let t = t.into_normalized()?.into_expr(); Ok(RetExpr(dhall::expr!(Optional t))) } NEOptionalLit(x) => { - ensure_is_type( - &x.get_type().get_type(), - InvalidOptionalType( - x.get_type().clone().into_normalized()?, - ), - )?; - let t = x.get_type_move().into_normalized()?.into_expr(); + let tx = x.get_type_move(); + ensure_is_type!( + tx.get_type(), + mkerr(InvalidOptionalType(tx.into_normalized()?,)), + ); + let t = tx.into_normalized()?.into_expr(); Ok(RetExpr(dhall::expr!(Optional t))) } RecordType(kts) => { for (k, t) in kts { - ensure_is_type( - &t.get_type(), - InvalidFieldType(k.clone(), t.clone()), - )?; + ensure_is_type!( + t.get_type(), + mkerr(InvalidFieldType(k, t)), + ); } Ok(RetExpr(dhall::expr!(Type))) } @@ -543,25 +540,25 @@ pub fn type_with( let kts = kvs .into_iter() .map(|(k, v)| { - ensure_is_type( - &v.get_type().get_type(), - InvalidField(k.clone(), v.clone()), - )?; + ensure_is_type!( + v.get_type().get_type(), + mkerr(InvalidField(k, v)), + ); Ok(( - k.clone(), + k, v.get_type_move().into_normalized()?.into_expr(), )) }) .collect::>()?; Ok(RetExpr(RecordType(kts))) } - Field(r, x) => match r.get_type().unroll_ref()? { + Field(r, x) => ensure_matches!(r.get_type(), RecordType(kts) => match kts.get(&x) { Some(e) => Ok(RetExpr(e.unroll())), - None => Err(mkerr(MissingField(x.clone(), r.clone()))), + None => Err(mkerr(MissingField(x, r))), }, - _ => Err(mkerr(NotARecord(x.clone(), r.clone()))), - }, + mkerr(NotARecord(x, r)) + ), Builtin(b) => Ok(RetExpr(type_of_builtin(b))), BoolLit(_) => Ok(RetExpr(dhall::expr!(Bool))), NaturalLit(_) => Ok(RetExpr(dhall::expr!(Natural))), @@ -584,13 +581,9 @@ pub fn type_with( }, )?; - ensure_equal(&l.get_type(), &t, mkerr, || { - BinOpTypeMismatch(o, l.clone()) - })?; + ensure_equal!(l.get_type(), t, mkerr(BinOpTypeMismatch(o, l))); - ensure_equal(&r.get_type(), &t, mkerr, || { - BinOpTypeMismatch(o, r.clone()) - })?; + ensure_equal!(r.get_type(), t, mkerr(BinOpTypeMismatch(o, r))); Ok(RetType(t)) } -- cgit v1.2.3