diff options
-rw-r--r-- | dhall/src/typecheck.rs | 96 |
1 files changed, 50 insertions, 46 deletions
diff --git a/dhall/src/typecheck.rs b/dhall/src/typecheck.rs index 08e5928..e63b032 100644 --- a/dhall/src/typecheck.rs +++ b/dhall/src/typecheck.rs @@ -176,6 +176,24 @@ fn type_of_builtin<S>(b: Builtin) -> Expr<S, X> { } } +fn ensure_equal<'a, S, F1, F2>( + x: &'a Expr<S, X>, + y: &'a Expr<S, X>, + mkerr: F1, + mkmsg: F2, +) -> Result<(), TypeError<S>> +where + S: std::fmt::Debug, + F1: FnOnce(TypeMessage<S>) -> TypeError<S>, + F2: FnOnce() -> TypeMessage<S>, +{ + if prop_equal(x, y) { + Ok(()) + } else { + Err(mkerr(mkmsg())) + } +} + /// Type-check an expression and return the expression's type if type-checking /// succeeds or an error if type-checking fails /// @@ -291,33 +309,29 @@ where ))); } }; - if !prop_equal(tx.as_ref(), ta.as_ref()) { - return Err(mkerr(TypeMismatch( - rc(App(f.clone(), seen_args)), + ensure_equal(tx.as_ref(), ta.as_ref(), mkerr, || { + TypeMismatch( + rc(App(f.clone(), seen_args.clone())), tx.clone(), a.clone(), - ta, - ))); - } + ta.clone(), + ) + })?; tf = normalize(subst_shift(&V(x.clone(), 0), &a, &tb)); } Ok(tf.unroll()) } Annot((x, tx), (t, _)) => { let t = normalize(t.clone()); - if prop_equal(t.as_ref(), tx.as_ref()) { - Ok(t.unroll()) - } else { - Err(mkerr(AnnotMismatch(x.clone(), t, tx))) - } + ensure_equal(t.as_ref(), tx.as_ref(), mkerr, || { + AnnotMismatch(x.clone(), t.clone(), tx.clone()) + })?; + Ok(t.unroll()) } BoolIf((x, tx), (y, ty), (z, tz)) => { - match tx.as_ref() { - Builtin(Bool) => {} - _ => { - return Err(mkerr(InvalidPredicate(x.clone(), tx))); - } - } + ensure_equal(tx.as_ref(), &Builtin(Bool), mkerr, || { + InvalidPredicate(x.clone(), tx.clone()) + })?; let tty = type_with(ctx, ty.clone())?; ensure_is_type( tty.clone(), @@ -340,14 +354,14 @@ where ), )?; - if !prop_equal(ty.as_ref(), tz.as_ref()) { - return Err(mkerr(IfBranchMismatch( + ensure_equal(ty.as_ref(), tz.as_ref(), mkerr, || { + IfBranchMismatch( y.clone(), z.clone(), - ty, - tz, - ))); - } + ty.clone(), + tz.clone(), + ) + })?; Ok(ty.unroll()) } EmptyListLit((t, tt)) => { @@ -361,14 +375,9 @@ where let s = type_with(ctx, t.clone())?; ensure_is_type(s, InvalidListType(t.clone()))?; for (i, (y, ty)) in iter { - if !prop_equal(t.as_ref(), ty.as_ref()) { - return Err(mkerr(InvalidListElement( - i, - t, - y.clone(), - ty, - ))); - } + ensure_equal(t.as_ref(), ty.as_ref(), mkerr, || { + InvalidListElement(i, t.clone(), y.clone(), ty.clone()) + })?; } Ok(dhall::expr!(List t)) } @@ -414,7 +423,7 @@ where // TODO: check type of interpolations TextLit(_) => Ok(Builtin(Text)), BinOp(o, (l, tl), (r, tr)) => { - let t = match o { + let t = Builtin(match o { BoolAnd => Bool, BoolOr => Bool, BoolEQ => Bool, @@ -423,22 +432,17 @@ where NaturalTimes => Natural, TextAppend => Text, _ => panic!("Unimplemented typecheck case: {:?}", e), - }; - match tl.as_ref() { - Builtin(lt) if *lt == t => {} - _ => { - return Err(mkerr(BinOpTypeMismatch(o, l.clone(), tl))) - } - } + }); - match tr.as_ref() { - Builtin(rt) if *rt == t => {} - _ => { - return Err(mkerr(BinOpTypeMismatch(o, r.clone(), tr))) - } - } + ensure_equal(tl.as_ref(), &t, mkerr, || { + BinOpTypeMismatch(o, l.clone(), tl.clone()) + })?; + + ensure_equal(tr.as_ref(), &t, mkerr, || { + BinOpTypeMismatch(o, r.clone(), tr.clone()) + })?; - Ok(Builtin(t)) + Ok(t) } Embed(p) => match p {}, _ => panic!("Unimplemented typecheck case: {:?}", e), |