summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--dhall/src/typecheck.rs96
1 files changed, 50 insertions, 46 deletions
diff --git a/dhall/src/typecheck.rs b/dhall/src/typecheck.rs
index 08e5928..e63b032 100644
--- a/dhall/src/typecheck.rs
+++ b/dhall/src/typecheck.rs
@@ -176,6 +176,24 @@ fn type_of_builtin<S>(b: Builtin) -> Expr<S, X> {
}
}
+fn ensure_equal<'a, S, F1, F2>(
+ x: &'a Expr<S, X>,
+ y: &'a Expr<S, X>,
+ mkerr: F1,
+ mkmsg: F2,
+) -> Result<(), TypeError<S>>
+where
+ S: std::fmt::Debug,
+ F1: FnOnce(TypeMessage<S>) -> TypeError<S>,
+ F2: FnOnce() -> TypeMessage<S>,
+{
+ if prop_equal(x, y) {
+ Ok(())
+ } else {
+ Err(mkerr(mkmsg()))
+ }
+}
+
/// Type-check an expression and return the expression's type if type-checking
/// succeeds or an error if type-checking fails
///
@@ -291,33 +309,29 @@ where
)));
}
};
- if !prop_equal(tx.as_ref(), ta.as_ref()) {
- return Err(mkerr(TypeMismatch(
- rc(App(f.clone(), seen_args)),
+ ensure_equal(tx.as_ref(), ta.as_ref(), mkerr, || {
+ TypeMismatch(
+ rc(App(f.clone(), seen_args.clone())),
tx.clone(),
a.clone(),
- ta,
- )));
- }
+ ta.clone(),
+ )
+ })?;
tf = normalize(subst_shift(&V(x.clone(), 0), &a, &tb));
}
Ok(tf.unroll())
}
Annot((x, tx), (t, _)) => {
let t = normalize(t.clone());
- if prop_equal(t.as_ref(), tx.as_ref()) {
- Ok(t.unroll())
- } else {
- Err(mkerr(AnnotMismatch(x.clone(), t, tx)))
- }
+ ensure_equal(t.as_ref(), tx.as_ref(), mkerr, || {
+ AnnotMismatch(x.clone(), t.clone(), tx.clone())
+ })?;
+ Ok(t.unroll())
}
BoolIf((x, tx), (y, ty), (z, tz)) => {
- match tx.as_ref() {
- Builtin(Bool) => {}
- _ => {
- return Err(mkerr(InvalidPredicate(x.clone(), tx)));
- }
- }
+ ensure_equal(tx.as_ref(), &Builtin(Bool), mkerr, || {
+ InvalidPredicate(x.clone(), tx.clone())
+ })?;
let tty = type_with(ctx, ty.clone())?;
ensure_is_type(
tty.clone(),
@@ -340,14 +354,14 @@ where
),
)?;
- if !prop_equal(ty.as_ref(), tz.as_ref()) {
- return Err(mkerr(IfBranchMismatch(
+ ensure_equal(ty.as_ref(), tz.as_ref(), mkerr, || {
+ IfBranchMismatch(
y.clone(),
z.clone(),
- ty,
- tz,
- )));
- }
+ ty.clone(),
+ tz.clone(),
+ )
+ })?;
Ok(ty.unroll())
}
EmptyListLit((t, tt)) => {
@@ -361,14 +375,9 @@ where
let s = type_with(ctx, t.clone())?;
ensure_is_type(s, InvalidListType(t.clone()))?;
for (i, (y, ty)) in iter {
- if !prop_equal(t.as_ref(), ty.as_ref()) {
- return Err(mkerr(InvalidListElement(
- i,
- t,
- y.clone(),
- ty,
- )));
- }
+ ensure_equal(t.as_ref(), ty.as_ref(), mkerr, || {
+ InvalidListElement(i, t.clone(), y.clone(), ty.clone())
+ })?;
}
Ok(dhall::expr!(List t))
}
@@ -414,7 +423,7 @@ where
// TODO: check type of interpolations
TextLit(_) => Ok(Builtin(Text)),
BinOp(o, (l, tl), (r, tr)) => {
- let t = match o {
+ let t = Builtin(match o {
BoolAnd => Bool,
BoolOr => Bool,
BoolEQ => Bool,
@@ -423,22 +432,17 @@ where
NaturalTimes => Natural,
TextAppend => Text,
_ => panic!("Unimplemented typecheck case: {:?}", e),
- };
- match tl.as_ref() {
- Builtin(lt) if *lt == t => {}
- _ => {
- return Err(mkerr(BinOpTypeMismatch(o, l.clone(), tl)))
- }
- }
+ });
- match tr.as_ref() {
- Builtin(rt) if *rt == t => {}
- _ => {
- return Err(mkerr(BinOpTypeMismatch(o, r.clone(), tr)))
- }
- }
+ ensure_equal(tl.as_ref(), &t, mkerr, || {
+ BinOpTypeMismatch(o, l.clone(), tl.clone())
+ })?;
+
+ ensure_equal(tr.as_ref(), &t, mkerr, || {
+ BinOpTypeMismatch(o, r.clone(), tr.clone())
+ })?;
- Ok(Builtin(t))
+ Ok(t)
}
Embed(p) => match p {},
_ => panic!("Unimplemented typecheck case: {:?}", e),