summaryrefslogtreecommitdiff
path: root/dhall
diff options
context:
space:
mode:
authorNadrieril2019-04-13 18:21:21 +0200
committerNadrieril2019-04-13 18:21:21 +0200
commit99f379f6bc319f1055a72521493caa554d515e65 (patch)
treef8bd0f8dcf0e39d5ddd01876ccf92f01724e48ff /dhall
parent57e2c14460a9080ede622127e57fe80ab9e0c879 (diff)
Various typecheck improvements
Diffstat (limited to 'dhall')
-rw-r--r--dhall/src/typecheck.rs176
1 files changed, 93 insertions, 83 deletions
diff --git a/dhall/src/typecheck.rs b/dhall/src/typecheck.rs
index 891c906..d241d8d 100644
--- a/dhall/src/typecheck.rs
+++ b/dhall/src/typecheck.rs
@@ -110,7 +110,7 @@ impl<'a> Type<'a> {
}
}
-fn rule(a: Const, b: Const) -> Result<Const, ()> {
+fn function_check(a: Const, b: Const) -> Result<Const, ()> {
use dhall_core::Const::*;
match (a, b) {
(_, Type) => Ok(Type),
@@ -121,23 +121,22 @@ fn rule(a: Const, b: Const) -> Result<Const, ()> {
}
}
-fn match_vars(vl: &V<Label>, vr: &V<Label>, ctx: &[(Label, Label)]) -> bool {
- let mut vl = vl.clone();
- let mut vr = vr.clone();
- let mut ctx = ctx.to_vec();
- ctx.reverse();
- while let Some((xL2, xR2)) = &ctx.pop() {
- match (&vl, &vr) {
- (V(xL, 0), V(xR, 0)) if xL == xL2 && xR == xR2 => return true,
- (V(xL, nL), V(xR, nR)) => {
- let nL2 = if xL == xL2 { nL - 1 } else { *nL };
- let nR2 = if xR == xR2 { nR - 1 } else { *nR };
- vl = V(xL.clone(), nL2);
- vr = V(xR.clone(), nR2);
+fn match_vars(vl: &V<Label>, vr: &V<Label>, ctx: &[(&Label, &Label)]) -> bool {
+ let (V(xL, mut nL), V(xR, mut nR)) = (vl, vr);
+ for &(xL2, xR2) in ctx {
+ match (nL, nR) {
+ (0, 0) if xL == xL2 && xR == xR2 => return true,
+ (_, _) => {
+ if xL == xL2 {
+ nL = nL - 1;
+ }
+ if xR == xR2 {
+ nR = nR - 1;
+ }
}
}
}
- vl == vr
+ xL == xR && nL == nR
}
// Equality up to alpha-equivalence (renaming of bound variables)
@@ -147,56 +146,45 @@ where
U: Borrow<Type<'static>>,
{
use dhall_core::ExprF::*;
- fn go<S, T>(
- ctx: &mut Vec<(Label, Label)>,
- el: &Expr<S, X>,
- er: &Expr<T, X>,
+ fn go<'a, S, T>(
+ ctx: &mut Vec<(&'a Label, &'a Label)>,
+ el: &'a SubExpr<S, X>,
+ er: &'a SubExpr<T, X>,
) -> bool
where
S: ::std::fmt::Debug,
T: ::std::fmt::Debug,
{
- match (el, er) {
- (&Const(a), &Const(b)) => a == b,
- (&Builtin(a), &Builtin(b)) => a == b,
- (&Var(ref vL), &Var(ref vR)) => match_vars(vL, vR, ctx),
- (&Pi(ref xL, ref tL, ref bL), &Pi(ref xR, ref tR, ref bR)) => {
- //ctx <- State.get
- let eq1 = go(ctx, tL.as_ref(), tR.as_ref());
- if eq1 {
- //State.put ((xL, xR):ctx)
- ctx.push((xL.clone(), xR.clone()));
- let eq2 = go(ctx, bL.as_ref(), bR.as_ref());
- //State.put ctx
- let _ = ctx.pop();
+ match (el.as_ref(), er.as_ref()) {
+ (Const(a), Const(b)) => a == b,
+ (Builtin(a), Builtin(b)) => a == b,
+ (Var(vL), Var(vR)) => match_vars(vL, vR, ctx),
+ (Pi(xL, tL, bL), Pi(xR, tR, bR)) => {
+ go(ctx, tL, tR) && {
+ ctx.push((xL, xR));
+ let eq2 = go(ctx, bL, bR);
+ ctx.pop();
eq2
- } else {
- false
}
}
- (&App(ref fL, ref aL), &App(ref fR, ref aR)) => {
- go(ctx, fL.as_ref(), fR.as_ref())
+ (App(fL, aL), App(fR, aR)) => {
+ go(ctx, fL, fR)
&& aL.len() == aR.len()
- && aL
- .iter()
- .zip(aR.iter())
- .all(|(aL, aR)| go(ctx, aL.as_ref(), aR.as_ref()))
+ && aL.iter().zip(aR.iter()).all(|(aL, aR)| go(ctx, aL, aR))
}
- (&RecordType(ref ktsL0), &RecordType(ref ktsR0)) => {
+ (RecordType(ktsL0), RecordType(ktsR0)) => {
ktsL0.len() == ktsR0.len()
- && ktsL0.iter().zip(ktsR0.iter()).all(
- |((kL, tL), (kR, tR))| {
- kL == kR && go(ctx, tL.as_ref(), tR.as_ref())
- },
- )
+ && ktsL0
+ .iter()
+ .zip(ktsR0.iter())
+ .all(|((kL, tL), (kR, tR))| kL == kR && go(ctx, tL, tR))
}
- (&UnionType(ref ktsL0), &UnionType(ref ktsR0)) => {
+ (UnionType(ktsL0), UnionType(ktsR0)) => {
ktsL0.len() == ktsR0.len()
- && ktsL0.iter().zip(ktsR0.iter()).all(
- |((kL, tL), (kR, tR))| {
- kL == kR && go(ctx, tL.as_ref(), tR.as_ref())
- },
- )
+ && ktsL0
+ .iter()
+ .zip(ktsR0.iter())
+ .all(|((kL, tL), (kR, tR))| kL == kR && go(ctx, tL, tR))
}
(_, _) => false,
}
@@ -205,12 +193,20 @@ where
(TypeInternal::SuperType, TypeInternal::SuperType) => true,
(TypeInternal::Expr(l), TypeInternal::Expr(r)) => {
let mut ctx = vec![];
- go(&mut ctx, l.unroll_ref(), r.unroll_ref())
+ go(&mut ctx, l.as_expr(), r.as_expr())
}
_ => false,
}
}
+fn type_of_const<'a>(c: Const) -> Type<'a> {
+ match c {
+ Const::Type => Type::const_kind(),
+ Const::Kind => Type::const_sort(),
+ Const::Sort => Type(TypeInternal::SuperType),
+ }
+}
+
fn type_of_builtin<S>(b: Builtin) -> Expr<S, Normalized<'static>> {
use dhall_core::Builtin::*;
match b {
@@ -275,6 +271,15 @@ fn type_of_builtin<S>(b: Builtin) -> Expr<S, Normalized<'static>> {
}
}
+macro_rules! function_check {
+ ($x:expr, $y:expr, $err:expr $(,)*) => {
+ match function_check($x, $y) {
+ Ok(k) => k,
+ Err(()) => return Err($err),
+ }
+ };
+}
+
macro_rules! ensure_equal {
($x:expr, $y:expr, $err:expr $(,)*) => {
if !prop_equal($x, $y) {
@@ -348,14 +353,11 @@ fn type_with(
.insert(x.clone(), t.clone())
.map(|e| e.shift(1, &V(x.clone(), 0)));
let b = type_with(&ctx2, b.clone())?;
- Ok(RetType(mktype(
- ctx,
- rc(Pi(
- x.clone(),
- t.into_normalized()?.into_expr(),
- b.get_type_move()?.into_normalized()?.into_expr(),
- )),
- )?))
+ Ok(RetExpr(Pi(
+ x.clone(),
+ t.into_normalized()?.into_expr(),
+ b.get_type_move()?.into_normalized()?.into_expr(),
+ )))
}
Pi(x, tA, tB) => {
let tA = mktype(ctx, tA.clone())?;
@@ -377,13 +379,14 @@ fn type_with(
),
);
- match rule(kA, kB) {
- Err(()) => Err(mkerr(NoDependentTypes(
+ let k = function_check!(kA, kB, {
+ mkerr(NoDependentTypes(
tA.clone().into_normalized()?,
tB.get_type_move()?.into_normalized()?,
- ))),
- Ok(k) => Ok(RetExpr(Const(k))),
- }
+ ))
+ });
+
+ Ok(RetExpr(Const(k)))
}
Let(f, mt, r, b) => {
let r = if let Some(t) = mt {
@@ -411,12 +414,12 @@ fn type_with(
mkerr(InvalidOutputType(b.get_type_move()?.into_normalized()?)),
);
- if let Err(()) = rule(kR, kB) {
- return Err(mkerr(NoDependentLet(
+ function_check!(kR, kB, {
+ mkerr(NoDependentLet(
r.get_type_move()?.into_normalized()?,
b.get_type_move()?.into_normalized()?,
- )));
- }
+ ))
+ });
Ok(RetType(b.get_type_move()?))
}
@@ -456,40 +459,46 @@ fn type_last_layer(
Pi(_, _, _) => unreachable!(),
Let(_, _, _, _) => unreachable!(),
Embed(_) => unreachable!(),
- Const(Type) => Ok(RetType(crate::expr::Type::const_kind())),
- Const(Kind) => Ok(RetType(crate::expr::Type::const_sort())),
- Const(Sort) => Ok(RetType(crate::expr::Type(TypeInternal::SuperType))),
Var(V(x, n)) => match ctx.lookup(&x, n) {
Some(e) => Ok(RetType(e.clone())),
None => Err(mkerr(UnboundVariable)),
},
App(f, args) => {
- let mut seen_args: Vec<SubExpr<_, _>> = vec![];
let mut tf = f.get_type()?.into_owned();
- for a in args {
- seen_args.push(a.as_expr().clone());
+ for (i, a) in args.iter().enumerate() {
let (x, tx, tb) = ensure_matches!(tf,
Pi(x, tx, tb) => (x, tx, tb),
mkerr(NotAFunction(Typed(
- rc(App(f.into_expr(), seen_args)),
+ rc(App(
+ f.into_expr(),
+ args.into_iter()
+ .take(i)
+ .map(|e| e.into_expr())
+ .collect()
+ )),
Some(tf),
PhantomData
)))
);
let tx = mktype(ctx, tx.absurd())?;
- ensure_equal!(
- &tx,
- a.get_type()?,
+ ensure_equal!(&tx, a.get_type()?, {
+ let a = a.clone();
mkerr(TypeMismatch(
Typed(
- rc(App(f.into_expr(), seen_args)),
+ rc(App(
+ f.into_expr(),
+ args.into_iter()
+ .take(i + 1)
+ .map(|e| e.into_expr())
+ .collect(),
+ )),
Some(tf),
- PhantomData
+ PhantomData,
),
tx.into_normalized()?,
a,
))
- );
+ });
tf = mktype(
ctx,
subst_shift(&V(x.clone(), 0), a.as_expr(), &tb.absurd()),
@@ -605,6 +614,7 @@ fn type_last_layer(
},
mkerr(NotARecord(x, r))
),
+ Const(c) => Ok(RetType(type_of_const(c))),
Builtin(b) => Ok(RetExpr(type_of_builtin(b))),
BoolLit(_) => Ok(RetType(simple_type_from_builtin(Bool))),
NaturalLit(_) => Ok(RetType(simple_type_from_builtin(Natural))),