diff options
Diffstat (limited to 'dhall/src')
-rw-r--r-- | dhall/src/lib.rs | 2 | ||||
-rw-r--r-- | dhall/src/typecheck.rs | 406 |
2 files changed, 203 insertions, 205 deletions
diff --git a/dhall/src/lib.rs b/dhall/src/lib.rs index 0270103..103fd29 100644 --- a/dhall/src/lib.rs +++ b/dhall/src/lib.rs @@ -14,6 +14,8 @@ mod dhall_type; pub mod imports; pub mod typecheck; pub use crate::dhall_type::*; +pub use dhall_generator::expr; +pub use dhall_generator::subexpr; pub use dhall_generator::Type; pub use crate::imports::*; diff --git a/dhall/src/typecheck.rs b/dhall/src/typecheck.rs index 145de63..e63b032 100644 --- a/dhall/src/typecheck.rs +++ b/dhall/src/typecheck.rs @@ -6,7 +6,7 @@ use crate::normalize::normalize; use dhall_core; use dhall_core::context::Context; use dhall_core::*; -use dhall_generator::dhall_expr; +use dhall_generator as dhall; use self::TypeMessage::*; @@ -112,31 +112,31 @@ where go::<S, T>(&mut ctx, eL0, eR0) } -fn type_of_builtin<S>(b: Builtin) -> SubExpr<S, X> { +fn type_of_builtin<S>(b: Builtin) -> Expr<S, X> { use dhall_core::Builtin::*; match b { - Bool | Natural | Integer | Double | Text => dhall_expr!(Type), - List | Optional => dhall_expr!( + Bool | Natural | Integer | Double | Text => dhall::expr!(Type), + List | Optional => dhall::expr!( Type -> Type ), - NaturalFold => dhall_expr!( + NaturalFold => dhall::expr!( Natural -> forall (natural: Type) -> forall (succ: natural -> natural) -> forall (zero: natural) -> natural ), - NaturalBuild => dhall_expr!( + NaturalBuild => dhall::expr!( (forall (natural: Type) -> forall (succ: natural -> natural) -> forall (zero: natural) -> natural) -> Natural ), - NaturalIsZero | NaturalEven | NaturalOdd => dhall_expr!( + NaturalIsZero | NaturalEven | NaturalOdd => dhall::expr!( Natural -> Bool ), - ListBuild => dhall_expr!( + ListBuild => dhall::expr!( forall (a: Type) -> (forall (list: Type) -> forall (cons: a -> list -> list) -> @@ -144,7 +144,7 @@ fn type_of_builtin<S>(b: Builtin) -> SubExpr<S, X> { list) -> List a ), - ListFold => dhall_expr!( + ListFold => dhall::expr!( forall (a: Type) -> List a -> forall (list: Type) -> @@ -152,19 +152,19 @@ fn type_of_builtin<S>(b: Builtin) -> SubExpr<S, X> { forall (nil: list) -> list ), - ListLength => dhall_expr!(forall (a: Type) -> List a -> Natural), + ListLength => dhall::expr!(forall (a: Type) -> List a -> Natural), ListHead | ListLast => { - dhall_expr!(forall (a: Type) -> List a -> Optional a) + dhall::expr!(forall (a: Type) -> List a -> Optional a) } - ListIndexed => dhall_expr!( + ListIndexed => dhall::expr!( forall (a: Type) -> List a -> List { index: Natural, value: a } ), - ListReverse => dhall_expr!( + ListReverse => dhall::expr!( forall (a: Type) -> List a -> List a ), - OptionalFold => dhall_expr!( + OptionalFold => dhall::expr!( forall (a: Type) -> Optional a -> forall (optional: Type) -> @@ -176,12 +176,29 @@ fn type_of_builtin<S>(b: Builtin) -> SubExpr<S, X> { } } -/// Type-check an expression and return the expression'i type if type-checking -/// suceeds or an error if type-checking fails +fn ensure_equal<'a, S, F1, F2>( + x: &'a Expr<S, X>, + y: &'a Expr<S, X>, + mkerr: F1, + mkmsg: F2, +) -> Result<(), TypeError<S>> +where + S: std::fmt::Debug, + F1: FnOnce(TypeMessage<S>) -> TypeError<S>, + F2: FnOnce() -> TypeMessage<S>, +{ + if prop_equal(x, y) { + Ok(()) + } else { + Err(mkerr(mkmsg())) + } +} + +/// Type-check an expression and return the expression's type if type-checking +/// succeeds or an error if type-checking fails /// -/// `type_with` does not necessarily normalize the type since full normalization -/// is not necessary for just type-checking. If you actually care about the -/// returned type then you may want to `normalize` it afterwards. +/// `type_with` normalizes the type since while type-checking. It expects the +/// context to contain only normalized expressions. pub fn type_with<S>( ctx: &Context<Label, SubExpr<S, X>>, e: SubExpr<S, X>, @@ -205,31 +222,25 @@ where _ => Err(mkerr(msg)), }; - match e.as_ref() { - Const(c) => axiom(*c).map(Const), - Var(V(x, n)) => { - return ctx - .lookup(x, *n) - .cloned() - .ok_or_else(|| mkerr(UnboundVariable)) - } - Lam(x, tA, b) => { + let ret = match e.as_ref() { + Lam(x, t, b) => { + let t2 = normalize(SubExpr::clone(t)); let ctx2 = ctx - .insert(x.clone(), tA.clone()) + .insert(x.clone(), t2.clone()) .map(|e| shift(1, &V(x.clone(), 0), e)); let tB = type_with(&ctx2, b.clone())?; - let p = rc(Pi(x.clone(), tA.clone(), tB)); - let _ = type_with(ctx, p.clone())?; - return Ok(p); + let _ = type_with(ctx, rc(Pi(x.clone(), t.clone(), tB.clone())))?; + Ok(Pi(x.clone(), t2, tB)) } Pi(x, tA, tB) => { - let tA2 = normalized_type_with(ctx, tA.clone())?; - let kA = ensure_const(&tA2, InvalidInputType(tA.clone()))?; + let ttA = type_with(ctx, tA.clone())?; + let tA = normalize(SubExpr::clone(tA)); + let kA = ensure_const(&ttA, InvalidInputType(tA.clone()))?; let ctx2 = ctx .insert(x.clone(), tA.clone()) .map(|e| shift(1, &V(x.clone(), 0), e)); - let tB = normalized_type_with(&ctx2, tB.clone())?; + let tB = type_with(&ctx2, tB.clone())?; let kB = match tB.as_ref() { Const(k) => *k, _ => { @@ -246,33 +257,6 @@ where Ok(_) => Ok(Const(kB)), } } - App(f, args) => { - // Recurse on args - let (a, tf) = match args.split_last() { - None => return type_with(ctx, f.clone()), - Some((a, args)) => ( - a, - normalized_type_with( - ctx, - rc(App(f.clone(), args.to_vec())), - )?, - ), - }; - let (x, tA, tB) = match tf.as_ref() { - Pi(x, tA, tB) => (x, tA, tB), - _ => { - return Err(mkerr(NotAFunction(f.clone(), tf))); - } - }; - let tA = normalize(SubExpr::clone(tA)); - let tA2 = normalized_type_with(ctx, a.clone())?; - if prop_equal(tA.as_ref(), tA2.as_ref()) { - let vx0 = &V(x.clone(), 0); - return Ok(subst_shift(vx0, &a, &tB)); - } else { - Err(mkerr(TypeMismatch(f.clone(), tA, a.clone(), tA2))) - } - } Let(f, mt, r, b) => { let r = if let Some(t) = mt { rc(Annot(SubExpr::clone(r), SubExpr::clone(t))) @@ -281,14 +265,14 @@ where }; let tR = type_with(ctx, r)?; - let ttR = normalized_type_with(ctx, tR.clone())?; + let ttR = type_with(ctx, tR.clone())?; // Don't bother to provide a `let`-specific version of this error // message because this should never happen anyway let kR = ensure_const(&ttR, InvalidInputType(tR.clone()))?; let ctx2 = ctx.insert(f.clone(), tR.clone()); let tB = type_with(&ctx2, b.clone())?; - let ttB = normalized_type_with(ctx, tB.clone())?; + let ttB = type_with(ctx, tB.clone())?; // Don't bother to provide a `let`-specific version of this error // message because this should never happen anyway let kB = ensure_const(&ttB, InvalidOutputType(tB.clone()))?; @@ -297,162 +281,174 @@ where return Err(mkerr(NoDependentLet(tR, tB))); } - return Ok(tB); - } - Annot(x, t) => { - // This is mainly just to check that `t` is not `Kind` - let _ = type_with(ctx, t.clone())?; - - let t2 = normalized_type_with(ctx, x.clone())?; - let t = normalize(t.clone()); - if prop_equal(t.as_ref(), t2.as_ref()) { - return Ok(t.clone()); - } else { - Err(mkerr(AnnotMismatch(x.clone(), t, t2))) - } + Ok(tB.unroll()) } - BoolIf(x, y, z) => { - let tx = normalized_type_with(ctx, x.clone())?; - match tx.as_ref() { - Builtin(Bool) => {} - _ => { - return Err(mkerr(InvalidPredicate(x.clone(), tx))); + _ => match e + .as_ref() + .traverse_ref_simple(|e| Ok((e, type_with(ctx, e.clone())?)))? + { + Lam(_, _, _) => unreachable!(), + Pi(_, _, _) => unreachable!(), + Let(_, _, _, _) => unreachable!(), + Const(c) => axiom(c).map(Const), + Var(V(x, n)) => match ctx.lookup(&x, n) { + Some(e) => Ok(e.unroll()), + None => Err(mkerr(UnboundVariable)), + }, + App((f, mut tf), args) => { + let mut iter = args.into_iter(); + let mut seen_args: Vec<SubExpr<_, _>> = vec![]; + while let Some((a, ta)) = iter.next() { + seen_args.push(a.clone()); + let (x, tx, tb) = match tf.as_ref() { + Pi(x, tx, tb) => (x, tx, tb), + _ => { + return Err(mkerr(NotAFunction( + rc(App(f.clone(), seen_args)), + tf, + ))); + } + }; + ensure_equal(tx.as_ref(), ta.as_ref(), mkerr, || { + TypeMismatch( + rc(App(f.clone(), seen_args.clone())), + tx.clone(), + a.clone(), + ta.clone(), + ) + })?; + tf = normalize(subst_shift(&V(x.clone(), 0), &a, &tb)); } + Ok(tf.unroll()) } - let ty = normalized_type_with(ctx, y.clone())?; - let tty = normalized_type_with(ctx, ty.clone())?; - ensure_is_type( - tty.clone(), - IfBranchMustBeTerm(true, y.clone(), ty.clone(), tty.clone()), - )?; + Annot((x, tx), (t, _)) => { + let t = normalize(t.clone()); + ensure_equal(t.as_ref(), tx.as_ref(), mkerr, || { + AnnotMismatch(x.clone(), t.clone(), tx.clone()) + })?; + Ok(t.unroll()) + } + BoolIf((x, tx), (y, ty), (z, tz)) => { + ensure_equal(tx.as_ref(), &Builtin(Bool), mkerr, || { + InvalidPredicate(x.clone(), tx.clone()) + })?; + let tty = type_with(ctx, ty.clone())?; + ensure_is_type( + tty.clone(), + IfBranchMustBeTerm( + true, + y.clone(), + ty.clone(), + tty.clone(), + ), + )?; - let tz = normalized_type_with(ctx, z.clone())?; - let ttz = normalized_type_with(ctx, tz.clone())?; - ensure_is_type( - ttz.clone(), - IfBranchMustBeTerm(false, z.clone(), tz.clone(), ttz.clone()), - )?; + let ttz = type_with(ctx, tz.clone())?; + ensure_is_type( + ttz.clone(), + IfBranchMustBeTerm( + false, + z.clone(), + tz.clone(), + ttz.clone(), + ), + )?; - if !prop_equal(ty.as_ref(), tz.as_ref()) { - return Err(mkerr(IfBranchMismatch( - y.clone(), - z.clone(), - ty, - tz, - ))); + ensure_equal(ty.as_ref(), tz.as_ref(), mkerr, || { + IfBranchMismatch( + y.clone(), + z.clone(), + ty.clone(), + tz.clone(), + ) + })?; + Ok(ty.unroll()) } - return Ok(ty); - } - EmptyListLit(t) => { - let s = normalized_type_with(ctx, t.clone())?; - ensure_is_type(s, InvalidListType(t.clone()))?; - let t = normalize(SubExpr::clone(t)); - return Ok(dhall_expr!(List t)); - } - NEListLit(xs) => { - let mut iter = xs.iter().enumerate(); - let (_, first_x) = iter.next().unwrap(); - let t = type_with(ctx, first_x.clone())?; - - let s = normalized_type_with(ctx, t.clone())?; - ensure_is_type(s, InvalidListType(t.clone()))?; - let t = normalize(t); - for (i, x) in iter { - let t2 = normalized_type_with(ctx, x.clone())?; - if !prop_equal(t.as_ref(), t2.as_ref()) { - return Err(mkerr(InvalidListElement(i, t, x.clone(), t2))); + EmptyListLit((t, tt)) => { + ensure_is_type(tt, InvalidListType(t.clone()))?; + let t = normalize(SubExpr::clone(t)); + Ok(dhall::expr!(List t)) + } + NEListLit(xs) => { + let mut iter = xs.into_iter().enumerate(); + let (_, (_, t)) = iter.next().unwrap(); + let s = type_with(ctx, t.clone())?; + ensure_is_type(s, InvalidListType(t.clone()))?; + for (i, (y, ty)) in iter { + ensure_equal(t.as_ref(), ty.as_ref(), mkerr, || { + InvalidListElement(i, t.clone(), y.clone(), ty.clone()) + })?; } + Ok(dhall::expr!(List t)) } - return Ok(dhall_expr!(List t)); - } - EmptyOptionalLit(t) => { - let s = normalized_type_with(ctx, t.clone())?; - ensure_is_type(s, InvalidOptionalType(t.clone()))?; - let t = normalize(t.clone()); - return Ok(dhall_expr!(Optional t)); - } - NEOptionalLit(x) => { - let t: SubExpr<_, _> = type_with(ctx, x.clone())?; - let s = normalized_type_with(ctx, t.clone())?; - ensure_is_type(s, InvalidOptionalType(t.clone()))?; - let t = normalize(t); - return Ok(dhall_expr!(Optional t)); - } - RecordType(kts) => { - for (k, t) in kts { - let s = normalized_type_with(ctx, t.clone())?; - ensure_is_type(s, InvalidFieldType(k.clone(), t.clone()))?; + EmptyOptionalLit((t, tt)) => { + ensure_is_type(tt, InvalidOptionalType(t.clone()))?; + let t = normalize(t.clone()); + Ok(dhall::expr!(Optional t)) } - Ok(Const(Type)) - } - RecordLit(kvs) => { - let kts = kvs - .iter() - .map(|(k, v)| { - let t = type_with(ctx, v.clone())?; - let s = normalized_type_with(ctx, t.clone())?; - ensure_is_type(s, InvalidField(k.clone(), v.clone()))?; - Ok((k.clone(), t)) - }) - .collect::<Result<_, _>>()?; - Ok(RecordType(kts)) - } - Field(r, x) => { - let t = normalized_type_with(ctx, r.clone())?; - match t.as_ref() { - RecordType(kts) => { - return kts.get(x).cloned().ok_or_else(|| { - mkerr(MissingField(x.clone(), t.clone())) - }) + NEOptionalLit((_, t)) => { + let s = type_with(ctx, t.clone())?; + ensure_is_type(s, InvalidOptionalType(t.clone()))?; + Ok(dhall::expr!(Optional t)) + } + RecordType(kts) => { + for (k, (t, tt)) in kts { + ensure_is_type(tt, InvalidFieldType(k.clone(), t.clone()))?; } - _ => Err(mkerr(NotARecord(x.clone(), r.clone(), t.clone()))), + Ok(Const(Type)) } - } - Builtin(b) => return Ok(type_of_builtin(*b)), - BoolLit(_) => Ok(Builtin(Bool)), - NaturalLit(_) => Ok(Builtin(Natural)), - IntegerLit(_) => Ok(Builtin(Integer)), - DoubleLit(_) => Ok(Builtin(Double)), - TextLit(_) => Ok(Builtin(Text)), - BinOp(o, l, r) => { - let t = match o { - BoolAnd => Bool, - BoolOr => Bool, - BoolEQ => Bool, - BoolNE => Bool, - NaturalPlus => Natural, - NaturalTimes => Natural, - TextAppend => Text, - _ => panic!("Unimplemented typecheck case: {:?}", e), - }; - let tl = normalized_type_with(ctx, l.clone())?; - match tl.as_ref() { - Builtin(lt) if *lt == t => {} - _ => return Err(mkerr(BinOpTypeMismatch(*o, l.clone(), tl))), + RecordLit(kvs) => { + let kts = kvs + .into_iter() + .map(|(k, (v, t))| { + let s = type_with(ctx, t.clone())?; + ensure_is_type(s, InvalidField(k.clone(), v.clone()))?; + Ok((k.clone(), t)) + }) + .collect::<Result<_, _>>()?; + Ok(RecordType(kts)) } + Field((r, tr), x) => match tr.as_ref() { + RecordType(kts) => match kts.get(&x) { + Some(e) => Ok(e.unroll()), + None => Err(mkerr(MissingField(x.clone(), tr.clone()))), + }, + _ => Err(mkerr(NotARecord(x.clone(), r.clone(), tr.clone()))), + }, + Builtin(b) => Ok(type_of_builtin(b)), + BoolLit(_) => Ok(Builtin(Bool)), + NaturalLit(_) => Ok(Builtin(Natural)), + IntegerLit(_) => Ok(Builtin(Integer)), + DoubleLit(_) => Ok(Builtin(Double)), + // TODO: check type of interpolations + TextLit(_) => Ok(Builtin(Text)), + BinOp(o, (l, tl), (r, tr)) => { + let t = Builtin(match o { + BoolAnd => Bool, + BoolOr => Bool, + BoolEQ => Bool, + BoolNE => Bool, + NaturalPlus => Natural, + NaturalTimes => Natural, + TextAppend => Text, + _ => panic!("Unimplemented typecheck case: {:?}", e), + }); - let tr = normalized_type_with(ctx, r.clone())?; - match tr.as_ref() { - Builtin(rt) if *rt == t => {} - _ => return Err(mkerr(BinOpTypeMismatch(*o, r.clone(), tr))), - } + ensure_equal(tl.as_ref(), &t, mkerr, || { + BinOpTypeMismatch(o, l.clone(), tl.clone()) + })?; - Ok(Builtin(t)) - } - Embed(p) => match *p {}, - _ => panic!("Unimplemented typecheck case: {:?}", e), - } - .map(rc) -} + ensure_equal(tr.as_ref(), &t, mkerr, || { + BinOpTypeMismatch(o, r.clone(), tr.clone()) + })?; -pub fn normalized_type_with<S>( - ctx: &Context<Label, SubExpr<S, X>>, - e: SubExpr<S, X>, -) -> Result<SubExpr<S, X>, TypeError<S>> -where - S: ::std::fmt::Debug + Clone, -{ - Ok(normalize(type_with(ctx, e)?)) + Ok(t) + } + Embed(p) => match p {}, + _ => panic!("Unimplemented typecheck case: {:?}", e), + }, + }?; + Ok(rc(ret)) } /// `typeOf` is the same as `type_with` with an empty context, meaning that the |