diff options
Diffstat (limited to 'dhall')
-rw-r--r-- | dhall/src/typecheck.rs | 247 |
1 files changed, 114 insertions, 133 deletions
diff --git a/dhall/src/typecheck.rs b/dhall/src/typecheck.rs index d0e1d44..0ebc67e 100644 --- a/dhall/src/typecheck.rs +++ b/dhall/src/typecheck.rs @@ -21,6 +21,9 @@ impl Resolved { } } impl Typed { + fn as_expr(&self) -> &SubExpr<X, X> { + &self.0 + } pub fn get_type(&self) -> &Type { &self.1 } @@ -29,15 +32,18 @@ impl Normalized { pub fn get_type(&self) -> &Type { &self.1 } + fn into_type(self) -> Type { + crate::expr::Type(TypeInternal::Expr(Box::new(self))) + } } impl Type { // pub fn as_expr(&self) -> &Normalized { // &*self.0 // } - pub fn as_expr(&self) -> &SubExpr<X, X> { + fn as_expr(&self) -> &SubExpr<X, X> { &self.as_normalized().0 } - pub fn as_normalized(&self) -> &Normalized { + fn as_normalized(&self) -> &Normalized { use TypeInternal::*; match &self.0 { Expr(e) => &e, @@ -49,6 +55,9 @@ impl Type { } } +const TYPE_OF_SORT: crate::expr::Type = + crate::expr::Type(TypeInternal::Universe(4)); + fn rule(a: Const, b: Const) -> Result<Const, ()> { use dhall_core::Const::*; match (a, b) { @@ -209,8 +218,8 @@ fn type_of_builtin<S>(b: Builtin) -> Expr<S, X> { } fn ensure_equal<'a, S, F1, F2>( - x: &'a Expr<S, X>, - y: &'a Expr<S, X>, + x: &'a Type, + y: &'a Type, mkerr: F1, mkmsg: F2, ) -> Result<(), TypeError<S>> @@ -219,7 +228,7 @@ where F1: FnOnce(TypeMessage<S>) -> TypeError<S>, F2: FnOnce() -> TypeMessage<S>, { - if prop_equal(x, y) { + if prop_equal(x.as_expr().as_ref(), y.as_expr().as_ref()) { Ok(()) } else { Err(mkerr(mkmsg())) @@ -253,33 +262,40 @@ pub fn type_with( _ => Err(mkerr(msg)), }; + let mktype = + |ctx, x: SubExpr<X, X>| Ok(type_with(ctx, x)?.normalize().into_type()); + enum Ret { - ErrRet(TypeError<X>), - OkNormalized(Normalized), - OkType(crate::expr::Type), - OkRet(Expr<X, X>), + RetType(crate::expr::Type), + RetExpr(Expr<X, X>), } use Ret::*; let ret = match e.as_ref() { Lam(x, t, b) => { - let t2 = type_with(ctx, t.clone())?.normalize(); + let t2 = mktype(ctx, t.clone())?; let ctx2 = ctx - .insert(x.clone(), t2.0.clone()) + .insert(x.clone(), t2.as_expr().clone()) .map(|e| shift(1, &V(x.clone(), 0), e)); let b = type_with(&ctx2, b.clone())?; let _ = type_with( ctx, rc(Pi(x.clone(), t.clone(), b.get_type().as_expr().clone())), )?; - OkRet(Pi(x.clone(), t2.0, b.get_type().as_expr().clone())) + Ok(RetExpr(Pi( + x.clone(), + t2.as_expr().clone(), + b.get_type().as_expr().clone(), + ))) } Pi(x, tA, tB) => { - let tA = type_with(ctx, tA.clone())?.normalize(); - let kA = - ensure_const(tA.get_type(), InvalidInputType(tA.0.clone()))?; + let tA = mktype(ctx, tA.clone())?; + let kA = ensure_const( + tA.get_type(), + InvalidInputType(tA.as_expr().clone()), + )?; let ctx2 = ctx - .insert(x.clone(), tA.0.clone()) + .insert(x.clone(), tA.as_expr().clone()) .map(|e| shift(1, &V(x.clone(), 0), e)); let tB = type_with(&ctx2, tB.clone())?; let kB = match tB.get_type().as_expr().as_ref() { @@ -294,11 +310,11 @@ pub fn type_with( }; match rule(kA, kB) { - Err(()) => ErrRet(mkerr(NoDependentTypes( - tA.0.clone(), + Err(()) => Err(mkerr(NoDependentTypes( + tA.as_expr().clone(), tB.get_type().clone(), ))), - Ok(_) => OkRet(Const(kB)), + Ok(k) => Ok(RetExpr(Const(k))), } } Let(f, mt, r, b) => { @@ -332,7 +348,7 @@ pub fn type_with( ))); } - OkType(b.get_type().clone()) + Ok(RetType(b.get_type().clone())) } _ => match e .as_ref() @@ -341,20 +357,20 @@ pub fn type_with( Lam(_, _, _) => unreachable!(), Pi(_, _, _) => unreachable!(), Let(_, _, _, _) => unreachable!(), - Const(Type) => OkRet(Const(Kind)), - Const(Kind) => OkRet(Const(Sort)), - Const(Sort) => ErrRet(mkerr(Untyped)), + Const(Type) => Ok(RetExpr(Const(Kind))), + Const(Kind) => Ok(RetExpr(Const(Sort))), + Const(Sort) => Ok(RetType(TYPE_OF_SORT)), Var(V(x, n)) => match ctx.lookup(&x, n) { - Some(e) => OkRet(e.unroll()), - None => ErrRet(mkerr(UnboundVariable)), + Some(e) => Ok(RetExpr(e.unroll())), + None => Err(mkerr(UnboundVariable)), }, App(f, args) => { let mut iter = args.into_iter(); let mut seen_args: Vec<SubExpr<_, _>> = vec![]; - let mut tf = f.get_type().as_normalized().clone(); + let mut tf = f.get_type().clone(); while let Some(a) = iter.next() { seen_args.push(a.0.clone()); - let (x, tx, tb) = match tf.0.as_ref() { + let (x, tx, tb) = match tf.as_expr().as_ref() { Pi(x, tx, tb) => (x, tx, tb), _ => { return Err(mkerr(NotAFunction( @@ -363,40 +379,29 @@ pub fn type_with( ))); } }; - ensure_equal( - tx.as_ref(), - a.get_type().as_expr().as_ref(), - mkerr, - || { - TypeMismatch( - rc(App(f.0.clone(), seen_args.clone())), - tx.clone(), - a.clone(), - ) - }, - )?; - tf = type_with( - ctx, - subst_shift(&V(x.clone(), 0), &a.0, &tb), - )? - .normalize(); + let tx = mktype(ctx, tx.clone())?; + ensure_equal(&tx, a.get_type(), mkerr, || { + TypeMismatch( + rc(App(f.0.clone(), seen_args.clone())), + tx.clone(), + a.clone(), + ) + })?; + tf = mktype(ctx, subst_shift(&V(x.clone(), 0), &a.0, &tb))?; } - OkNormalized(tf) + Ok(RetType(tf)) } Annot(x, t) => { - let t = t.normalize(); - ensure_equal( - t.0.as_ref(), - x.get_type().as_expr().as_ref(), - mkerr, - || AnnotMismatch(x.clone(), t.clone()), - )?; - OkType(x.get_type().clone()) + let t = t.normalize().into_type(); + ensure_equal(&t, x.get_type(), mkerr, || { + AnnotMismatch(x.clone(), t.clone()) + })?; + Ok(RetType(x.get_type().clone())) } BoolIf(x, y, z) => { ensure_equal( - x.get_type().as_expr().as_ref(), - &Builtin(Bool), + x.get_type(), + &mktype(ctx, rc(Builtin(Bool)))?, mkerr, || InvalidPredicate(x.clone()), )?; @@ -410,19 +415,16 @@ pub fn type_with( IfBranchMustBeTerm(false, z.clone()), )?; - ensure_equal( - y.get_type().as_expr().as_ref(), - z.get_type().as_expr().as_ref(), - mkerr, - || IfBranchMismatch(y.clone(), z.clone()), - )?; + ensure_equal(y.get_type(), z.get_type(), mkerr, || { + IfBranchMismatch(y.clone(), z.clone()) + })?; - OkType(y.get_type().clone()) + Ok(RetType(y.get_type().clone())) } EmptyListLit(t) => { ensure_is_type(t.get_type(), InvalidListType(t.0.clone()))?; let t = t.normalize().0; - OkRet(dhall::expr!(List t)) + Ok(RetExpr(dhall::expr!(List t))) } NEListLit(xs) => { let mut iter = xs.into_iter().enumerate(); @@ -432,26 +434,21 @@ pub fn type_with( InvalidListType(x.get_type().as_expr().clone()), )?; for (i, y) in iter { - ensure_equal( - x.get_type().as_expr().as_ref(), - y.get_type().as_expr().as_ref(), - mkerr, - || { - InvalidListElement( - i, - x.get_type().as_expr().clone(), - y.clone(), - ) - }, - )?; + ensure_equal(x.get_type(), y.get_type(), mkerr, || { + InvalidListElement( + i, + x.get_type().as_expr().clone(), + y.clone(), + ) + })?; } let t = x.get_type().as_expr().clone(); - OkRet(dhall::expr!(List t)) + Ok(RetExpr(dhall::expr!(List t))) } EmptyOptionalLit(t) => { ensure_is_type(t.get_type(), InvalidOptionalType(t.0.clone()))?; let t = t.normalize().0; - OkRet(dhall::expr!(Optional t)) + Ok(RetExpr(dhall::expr!(Optional t))) } NEOptionalLit(x) => { ensure_is_type( @@ -459,7 +456,7 @@ pub fn type_with( InvalidOptionalType(x.get_type().as_expr().clone()), )?; let t = x.get_type().as_expr().clone(); - OkRet(dhall::expr!(Optional t)) + Ok(RetExpr(dhall::expr!(Optional t))) } RecordType(kts) => { for (k, t) in kts { @@ -468,7 +465,7 @@ pub fn type_with( InvalidFieldType(k.clone(), t.clone()), )?; } - OkRet(Const(Type)) + Ok(RetExpr(Const(Type))) } RecordLit(kvs) => { let kts = kvs @@ -481,70 +478,54 @@ pub fn type_with( Ok((k.clone(), v.get_type().as_expr().clone())) }) .collect::<Result<_, _>>()?; - OkRet(RecordType(kts)) + Ok(RetExpr(RecordType(kts))) } Field(r, x) => match r.get_type().as_expr().as_ref() { RecordType(kts) => match kts.get(&x) { - Some(e) => OkRet(e.unroll()), - None => ErrRet(mkerr(MissingField(x.clone(), r.clone()))), + Some(e) => Ok(RetExpr(e.unroll())), + None => Err(mkerr(MissingField(x.clone(), r.clone()))), }, - _ => ErrRet(mkerr(NotARecord(x.clone(), r.clone()))), + _ => Err(mkerr(NotARecord(x.clone(), r.clone()))), }, - Builtin(b) => OkRet(type_of_builtin(b)), - BoolLit(_) => OkRet(Builtin(Bool)), - NaturalLit(_) => OkRet(Builtin(Natural)), - IntegerLit(_) => OkRet(Builtin(Integer)), - DoubleLit(_) => OkRet(Builtin(Double)), + Builtin(b) => Ok(RetExpr(type_of_builtin(b))), + BoolLit(_) => Ok(RetExpr(Builtin(Bool))), + NaturalLit(_) => Ok(RetExpr(Builtin(Natural))), + IntegerLit(_) => Ok(RetExpr(Builtin(Integer))), + DoubleLit(_) => Ok(RetExpr(Builtin(Double))), // TODO: check type of interpolations - TextLit(_) => OkRet(Builtin(Text)), + TextLit(_) => Ok(RetExpr(Builtin(Text))), BinOp(o, l, r) => { - let t = Builtin(match o { - BoolAnd => Bool, - BoolOr => Bool, - BoolEQ => Bool, - BoolNE => Bool, - NaturalPlus => Natural, - NaturalTimes => Natural, - TextAppend => Text, - _ => panic!("Unimplemented typecheck case: {:?}", e), - }); - - ensure_equal( - l.get_type().as_expr().as_ref(), - &t, - mkerr, - || BinOpTypeMismatch(o, l.clone()), + let t = mktype( + ctx, + rc(Builtin(match o { + BoolAnd => Bool, + BoolOr => Bool, + BoolEQ => Bool, + BoolNE => Bool, + NaturalPlus => Natural, + NaturalTimes => Natural, + TextAppend => Text, + _ => panic!("Unimplemented typecheck case: {:?}", e), + })), )?; - ensure_equal( - r.get_type().as_expr().as_ref(), - &t, - mkerr, - || BinOpTypeMismatch(o, r.clone()), - )?; + ensure_equal(l.get_type(), &t, mkerr, || { + BinOpTypeMismatch(o, l.clone()) + })?; - OkRet(t) + ensure_equal(r.get_type(), &t, mkerr, || { + BinOpTypeMismatch(o, r.clone()) + })?; + + Ok(RetType(t)) } Embed(p) => match p {}, _ => panic!("Unimplemented typecheck case: {:?}", e), }, - }; + }?; match ret { - OkRet(Const(Sort)) => { - Ok(Typed(e, crate::expr::Type(TypeInternal::Universe(3)))) - } - OkRet(ret) => Ok(Typed( - e, - crate::expr::Type(TypeInternal::Expr(Box::new( - type_with(ctx, rc(ret))?.normalize(), - ))), - )), - OkNormalized(ret) => Ok(Typed( - e, - crate::expr::Type(TypeInternal::Expr(Box::new(ret))), - )), - OkType(ret) => Ok(Typed(e, ret)), - ErrRet(e) => Err(e), + RetExpr(ret) => Ok(Typed(e, mktype(ctx, rc(ret))?)), + RetType(typ) => Ok(Typed(e, typ)), } } @@ -567,9 +548,9 @@ pub enum TypeMessage<S> { UnboundVariable, InvalidInputType(SubExpr<S, X>), InvalidOutputType(crate::expr::Type), - NotAFunction(SubExpr<S, X>, Normalized), - TypeMismatch(SubExpr<S, X>, SubExpr<S, X>, Typed), - AnnotMismatch(Typed, Normalized), + NotAFunction(SubExpr<S, X>, crate::expr::Type), + TypeMismatch(SubExpr<S, X>, crate::expr::Type, Typed), + AnnotMismatch(Typed, crate::expr::Type), Untyped, InvalidListElement(usize, SubExpr<S, X>, Typed), InvalidListType(SubExpr<S, X>), @@ -645,8 +626,8 @@ impl<S> fmt::Display for TypeMessage<S> { let template = include_str!("errors/TypeMismatch.txt"); let s = template .replace("$txt0", &format!("{}", e0)) - .replace("$txt1", &format!("{}", e1)) - .replace("$txt2", &format!("{}", e2.0)) + .replace("$txt1", &format!("{}", e1.as_expr())) + .replace("$txt2", &format!("{}", e2.as_expr())) .replace("$txt3", &format!("{}", e2.get_type().as_expr())); f.write_str(&s) } |