diff options
Diffstat (limited to '')
-rw-r--r-- | dhall/src/typecheck.rs | 176 |
1 files changed, 93 insertions, 83 deletions
diff --git a/dhall/src/typecheck.rs b/dhall/src/typecheck.rs index 891c906..d241d8d 100644 --- a/dhall/src/typecheck.rs +++ b/dhall/src/typecheck.rs @@ -110,7 +110,7 @@ impl<'a> Type<'a> { } } -fn rule(a: Const, b: Const) -> Result<Const, ()> { +fn function_check(a: Const, b: Const) -> Result<Const, ()> { use dhall_core::Const::*; match (a, b) { (_, Type) => Ok(Type), @@ -121,23 +121,22 @@ fn rule(a: Const, b: Const) -> Result<Const, ()> { } } -fn match_vars(vl: &V<Label>, vr: &V<Label>, ctx: &[(Label, Label)]) -> bool { - let mut vl = vl.clone(); - let mut vr = vr.clone(); - let mut ctx = ctx.to_vec(); - ctx.reverse(); - while let Some((xL2, xR2)) = &ctx.pop() { - match (&vl, &vr) { - (V(xL, 0), V(xR, 0)) if xL == xL2 && xR == xR2 => return true, - (V(xL, nL), V(xR, nR)) => { - let nL2 = if xL == xL2 { nL - 1 } else { *nL }; - let nR2 = if xR == xR2 { nR - 1 } else { *nR }; - vl = V(xL.clone(), nL2); - vr = V(xR.clone(), nR2); +fn match_vars(vl: &V<Label>, vr: &V<Label>, ctx: &[(&Label, &Label)]) -> bool { + let (V(xL, mut nL), V(xR, mut nR)) = (vl, vr); + for &(xL2, xR2) in ctx { + match (nL, nR) { + (0, 0) if xL == xL2 && xR == xR2 => return true, + (_, _) => { + if xL == xL2 { + nL = nL - 1; + } + if xR == xR2 { + nR = nR - 1; + } } } } - vl == vr + xL == xR && nL == nR } // Equality up to alpha-equivalence (renaming of bound variables) @@ -147,56 +146,45 @@ where U: Borrow<Type<'static>>, { use dhall_core::ExprF::*; - fn go<S, T>( - ctx: &mut Vec<(Label, Label)>, - el: &Expr<S, X>, - er: &Expr<T, X>, + fn go<'a, S, T>( + ctx: &mut Vec<(&'a Label, &'a Label)>, + el: &'a SubExpr<S, X>, + er: &'a SubExpr<T, X>, ) -> bool where S: ::std::fmt::Debug, T: ::std::fmt::Debug, { - match (el, er) { - (&Const(a), &Const(b)) => a == b, - (&Builtin(a), &Builtin(b)) => a == b, - (&Var(ref vL), &Var(ref vR)) => match_vars(vL, vR, ctx), - (&Pi(ref xL, ref tL, ref bL), &Pi(ref xR, ref tR, ref bR)) => { - //ctx <- State.get - let eq1 = go(ctx, tL.as_ref(), tR.as_ref()); - if eq1 { - //State.put ((xL, xR):ctx) - ctx.push((xL.clone(), xR.clone())); - let eq2 = go(ctx, bL.as_ref(), bR.as_ref()); - //State.put ctx - let _ = ctx.pop(); + match (el.as_ref(), er.as_ref()) { + (Const(a), Const(b)) => a == b, + (Builtin(a), Builtin(b)) => a == b, + (Var(vL), Var(vR)) => match_vars(vL, vR, ctx), + (Pi(xL, tL, bL), Pi(xR, tR, bR)) => { + go(ctx, tL, tR) && { + ctx.push((xL, xR)); + let eq2 = go(ctx, bL, bR); + ctx.pop(); eq2 - } else { - false } } - (&App(ref fL, ref aL), &App(ref fR, ref aR)) => { - go(ctx, fL.as_ref(), fR.as_ref()) + (App(fL, aL), App(fR, aR)) => { + go(ctx, fL, fR) && aL.len() == aR.len() - && aL - .iter() - .zip(aR.iter()) - .all(|(aL, aR)| go(ctx, aL.as_ref(), aR.as_ref())) + && aL.iter().zip(aR.iter()).all(|(aL, aR)| go(ctx, aL, aR)) } - (&RecordType(ref ktsL0), &RecordType(ref ktsR0)) => { + (RecordType(ktsL0), RecordType(ktsR0)) => { ktsL0.len() == ktsR0.len() - && ktsL0.iter().zip(ktsR0.iter()).all( - |((kL, tL), (kR, tR))| { - kL == kR && go(ctx, tL.as_ref(), tR.as_ref()) - }, - ) + && ktsL0 + .iter() + .zip(ktsR0.iter()) + .all(|((kL, tL), (kR, tR))| kL == kR && go(ctx, tL, tR)) } - (&UnionType(ref ktsL0), &UnionType(ref ktsR0)) => { + (UnionType(ktsL0), UnionType(ktsR0)) => { ktsL0.len() == ktsR0.len() - && ktsL0.iter().zip(ktsR0.iter()).all( - |((kL, tL), (kR, tR))| { - kL == kR && go(ctx, tL.as_ref(), tR.as_ref()) - }, - ) + && ktsL0 + .iter() + .zip(ktsR0.iter()) + .all(|((kL, tL), (kR, tR))| kL == kR && go(ctx, tL, tR)) } (_, _) => false, } @@ -205,12 +193,20 @@ where (TypeInternal::SuperType, TypeInternal::SuperType) => true, (TypeInternal::Expr(l), TypeInternal::Expr(r)) => { let mut ctx = vec![]; - go(&mut ctx, l.unroll_ref(), r.unroll_ref()) + go(&mut ctx, l.as_expr(), r.as_expr()) } _ => false, } } +fn type_of_const<'a>(c: Const) -> Type<'a> { + match c { + Const::Type => Type::const_kind(), + Const::Kind => Type::const_sort(), + Const::Sort => Type(TypeInternal::SuperType), + } +} + fn type_of_builtin<S>(b: Builtin) -> Expr<S, Normalized<'static>> { use dhall_core::Builtin::*; match b { @@ -275,6 +271,15 @@ fn type_of_builtin<S>(b: Builtin) -> Expr<S, Normalized<'static>> { } } +macro_rules! function_check { + ($x:expr, $y:expr, $err:expr $(,)*) => { + match function_check($x, $y) { + Ok(k) => k, + Err(()) => return Err($err), + } + }; +} + macro_rules! ensure_equal { ($x:expr, $y:expr, $err:expr $(,)*) => { if !prop_equal($x, $y) { @@ -348,14 +353,11 @@ fn type_with( .insert(x.clone(), t.clone()) .map(|e| e.shift(1, &V(x.clone(), 0))); let b = type_with(&ctx2, b.clone())?; - Ok(RetType(mktype( - ctx, - rc(Pi( - x.clone(), - t.into_normalized()?.into_expr(), - b.get_type_move()?.into_normalized()?.into_expr(), - )), - )?)) + Ok(RetExpr(Pi( + x.clone(), + t.into_normalized()?.into_expr(), + b.get_type_move()?.into_normalized()?.into_expr(), + ))) } Pi(x, tA, tB) => { let tA = mktype(ctx, tA.clone())?; @@ -377,13 +379,14 @@ fn type_with( ), ); - match rule(kA, kB) { - Err(()) => Err(mkerr(NoDependentTypes( + let k = function_check!(kA, kB, { + mkerr(NoDependentTypes( tA.clone().into_normalized()?, tB.get_type_move()?.into_normalized()?, - ))), - Ok(k) => Ok(RetExpr(Const(k))), - } + )) + }); + + Ok(RetExpr(Const(k))) } Let(f, mt, r, b) => { let r = if let Some(t) = mt { @@ -411,12 +414,12 @@ fn type_with( mkerr(InvalidOutputType(b.get_type_move()?.into_normalized()?)), ); - if let Err(()) = rule(kR, kB) { - return Err(mkerr(NoDependentLet( + function_check!(kR, kB, { + mkerr(NoDependentLet( r.get_type_move()?.into_normalized()?, b.get_type_move()?.into_normalized()?, - ))); - } + )) + }); Ok(RetType(b.get_type_move()?)) } @@ -456,40 +459,46 @@ fn type_last_layer( Pi(_, _, _) => unreachable!(), Let(_, _, _, _) => unreachable!(), Embed(_) => unreachable!(), - Const(Type) => Ok(RetType(crate::expr::Type::const_kind())), - Const(Kind) => Ok(RetType(crate::expr::Type::const_sort())), - Const(Sort) => Ok(RetType(crate::expr::Type(TypeInternal::SuperType))), Var(V(x, n)) => match ctx.lookup(&x, n) { Some(e) => Ok(RetType(e.clone())), None => Err(mkerr(UnboundVariable)), }, App(f, args) => { - let mut seen_args: Vec<SubExpr<_, _>> = vec![]; let mut tf = f.get_type()?.into_owned(); - for a in args { - seen_args.push(a.as_expr().clone()); + for (i, a) in args.iter().enumerate() { let (x, tx, tb) = ensure_matches!(tf, Pi(x, tx, tb) => (x, tx, tb), mkerr(NotAFunction(Typed( - rc(App(f.into_expr(), seen_args)), + rc(App( + f.into_expr(), + args.into_iter() + .take(i) + .map(|e| e.into_expr()) + .collect() + )), Some(tf), PhantomData ))) ); let tx = mktype(ctx, tx.absurd())?; - ensure_equal!( - &tx, - a.get_type()?, + ensure_equal!(&tx, a.get_type()?, { + let a = a.clone(); mkerr(TypeMismatch( Typed( - rc(App(f.into_expr(), seen_args)), + rc(App( + f.into_expr(), + args.into_iter() + .take(i + 1) + .map(|e| e.into_expr()) + .collect(), + )), Some(tf), - PhantomData + PhantomData, ), tx.into_normalized()?, a, )) - ); + }); tf = mktype( ctx, subst_shift(&V(x.clone(), 0), a.as_expr(), &tb.absurd()), @@ -605,6 +614,7 @@ fn type_last_layer( }, mkerr(NotARecord(x, r)) ), + Const(c) => Ok(RetType(type_of_const(c))), Builtin(b) => Ok(RetExpr(type_of_builtin(b))), BoolLit(_) => Ok(RetType(simple_type_from_builtin(Bool))), NaturalLit(_) => Ok(RetType(simple_type_from_builtin(Natural))), |