summaryrefslogtreecommitdiff
path: root/dhall/src/semantics/tck/typecheck.rs
diff options
context:
space:
mode:
Diffstat (limited to 'dhall/src/semantics/tck/typecheck.rs')
-rw-r--r--dhall/src/semantics/tck/typecheck.rs611
1 files changed, 611 insertions, 0 deletions
diff --git a/dhall/src/semantics/tck/typecheck.rs b/dhall/src/semantics/tck/typecheck.rs
new file mode 100644
index 0000000..20076cd
--- /dev/null
+++ b/dhall/src/semantics/tck/typecheck.rs
@@ -0,0 +1,611 @@
+use std::borrow::Cow;
+use std::cmp::max;
+use std::collections::HashMap;
+
+use crate::error::{TypeError, TypeMessage};
+use crate::semantics::merge_maps;
+use crate::semantics::{
+ type_of_builtin, Binder, BuiltinClosure, Closure, TyEnv, TyExpr,
+ TyExprKind, Type, Value, ValueKind,
+};
+use crate::syntax::{
+ BinOp, Builtin, Const, Expr, ExprKind, InterpolatedTextContents, Span,
+};
+use crate::Normalized;
+
+fn type_of_recordtype<'a>(
+ tys: impl Iterator<Item = Cow<'a, TyExpr>>,
+) -> Result<Type, TypeError> {
+ // An empty record type has type Type
+ let mut k = Const::Type;
+ for t in tys {
+ match t.get_type()?.as_const() {
+ Some(c) => k = max(k, c),
+ None => return mkerr("InvalidFieldType"),
+ }
+ }
+ Ok(Value::from_const(k))
+}
+
+fn function_check(a: Const, b: Const) -> Const {
+ if b == Const::Type {
+ Const::Type
+ } else {
+ max(a, b)
+ }
+}
+
+fn type_of_function(src: Type, tgt: Type) -> Result<Type, TypeError> {
+ let ks = match src.as_const() {
+ Some(k) => k,
+ _ => return Err(TypeError::new(TypeMessage::InvalidInputType(src))),
+ };
+ let kt = match tgt.as_const() {
+ Some(k) => k,
+ _ => return Err(TypeError::new(TypeMessage::InvalidOutputType(tgt))),
+ };
+
+ Ok(Value::from_const(function_check(ks, kt)))
+}
+
+fn mkerr<T, S: ToString>(x: S) -> Result<T, TypeError> {
+ Err(TypeError::new(TypeMessage::Custom(x.to_string())))
+}
+
+/// When all sub-expressions have been typed, check the remaining toplevel
+/// layer.
+fn type_one_layer(
+ env: &TyEnv,
+ kind: &ExprKind<TyExpr, Normalized>,
+) -> Result<Type, TypeError> {
+ Ok(match kind {
+ ExprKind::Import(..) => unreachable!(
+ "There should remain no imports in a resolved expression"
+ ),
+ ExprKind::Var(..)
+ | ExprKind::Lam(..)
+ | ExprKind::Pi(..)
+ | ExprKind::Let(..)
+ | ExprKind::Const(Const::Sort)
+ | ExprKind::Embed(..) => unreachable!(), // Handled in type_with
+
+ ExprKind::Const(Const::Type) => Value::from_const(Const::Kind),
+ ExprKind::Const(Const::Kind) => Value::from_const(Const::Sort),
+ ExprKind::Builtin(b) => {
+ let t_expr = type_of_builtin(*b);
+ let t_tyexpr = type_with(env, &t_expr)?;
+ t_tyexpr.eval(env.as_nzenv())
+ }
+ ExprKind::BoolLit(_) => Value::from_builtin(Builtin::Bool),
+ ExprKind::NaturalLit(_) => Value::from_builtin(Builtin::Natural),
+ ExprKind::IntegerLit(_) => Value::from_builtin(Builtin::Integer),
+ ExprKind::DoubleLit(_) => Value::from_builtin(Builtin::Double),
+ ExprKind::TextLit(interpolated) => {
+ let text_type = Value::from_builtin(Builtin::Text);
+ for contents in interpolated.iter() {
+ use InterpolatedTextContents::Expr;
+ if let Expr(x) = contents {
+ if x.get_type()? != text_type {
+ return mkerr("InvalidTextInterpolation");
+ }
+ }
+ }
+ text_type
+ }
+ ExprKind::EmptyListLit(t) => {
+ let t = t.eval(env.as_nzenv());
+ match &*t.kind() {
+ ValueKind::AppliedBuiltin(BuiltinClosure {
+ b: Builtin::List,
+ args,
+ ..
+ }) if args.len() == 1 => {}
+ _ => return mkerr("InvalidListType"),
+ };
+ t
+ }
+ ExprKind::NEListLit(xs) => {
+ let mut iter = xs.iter();
+ let x = iter.next().unwrap();
+ for y in iter {
+ if x.get_type()? != y.get_type()? {
+ return mkerr("InvalidListElement");
+ }
+ }
+ let t = x.get_type()?;
+ if t.get_type()?.as_const() != Some(Const::Type) {
+ return mkerr("InvalidListType");
+ }
+
+ Value::from_builtin(Builtin::List).app(t)
+ }
+ ExprKind::SomeLit(x) => {
+ let t = x.get_type()?;
+ if t.get_type()?.as_const() != Some(Const::Type) {
+ return mkerr("InvalidOptionalType");
+ }
+
+ Value::from_builtin(Builtin::Optional).app(t)
+ }
+ ExprKind::RecordLit(kvs) => {
+ use std::collections::hash_map::Entry;
+ let mut kts = HashMap::new();
+ for (x, v) in kvs {
+ // Check for duplicated entries
+ match kts.entry(x.clone()) {
+ Entry::Occupied(_) => {
+ return mkerr("RecordTypeDuplicateField")
+ }
+ Entry::Vacant(e) => e.insert(v.get_type()?),
+ };
+ }
+
+ let ty = type_of_recordtype(
+ kts.iter()
+ .map(|(_, t)| Cow::Owned(t.to_tyexpr(env.as_varenv()))),
+ )?;
+ Value::from_kind_and_type(ValueKind::RecordType(kts), ty)
+ }
+ ExprKind::RecordType(kts) => {
+ use std::collections::hash_map::Entry;
+ let mut seen_fields = HashMap::new();
+ for (x, _) in kts {
+ // Check for duplicated entries
+ match seen_fields.entry(x.clone()) {
+ Entry::Occupied(_) => {
+ return mkerr("RecordTypeDuplicateField")
+ }
+ Entry::Vacant(e) => e.insert(()),
+ };
+ }
+
+ type_of_recordtype(kts.iter().map(|(_, t)| Cow::Borrowed(t)))?
+ }
+ ExprKind::UnionType(kts) => {
+ use std::collections::hash_map::Entry;
+ let mut seen_fields = HashMap::new();
+ // Check that all types are the same const
+ let mut k = None;
+ for (x, t) in kts {
+ if let Some(t) = t {
+ match (k, t.get_type()?.as_const()) {
+ (None, Some(k2)) => k = Some(k2),
+ (Some(k1), Some(k2)) if k1 == k2 => {}
+ _ => return mkerr("InvalidFieldType"),
+ }
+ }
+ match seen_fields.entry(x) {
+ Entry::Occupied(_) => {
+ return mkerr("UnionTypeDuplicateField")
+ }
+ Entry::Vacant(e) => e.insert(()),
+ };
+ }
+
+ // An empty union type has type Type;
+ // an union type with only unary variants also has type Type
+ let k = k.unwrap_or(Const::Type);
+
+ Value::from_const(k)
+ }
+ ExprKind::Field(scrut, x) => {
+ match &*scrut.get_type()?.kind() {
+ ValueKind::RecordType(kts) => match kts.get(&x) {
+ Some(tth) => tth.clone(),
+ None => return mkerr("MissingRecordField"),
+ },
+ // TODO: branch here only when scrut.get_type() is a Const
+ _ => {
+ let scrut_nf = scrut.eval(env.as_nzenv());
+ match scrut_nf.kind() {
+ ValueKind::UnionType(kts) => match kts.get(x) {
+ // Constructor has type T -> < x: T, ... >
+ Some(Some(ty)) => {
+ let pi_ty = type_of_function(
+ ty.get_type()?,
+ scrut.get_type()?,
+ )?;
+ Value::from_kind_and_type(
+ ValueKind::PiClosure {
+ binder: Binder::new(x.clone()),
+ annot: ty.clone(),
+ closure: Closure::new_constant(
+ scrut_nf,
+ ),
+ },
+ pi_ty,
+ )
+ }
+ Some(None) => scrut_nf,
+ None => return mkerr("MissingUnionField"),
+ },
+ _ => return mkerr("NotARecord"),
+ }
+ } // _ => mkerr("NotARecord"),
+ }
+ }
+ ExprKind::Annot(x, t) => {
+ let t = t.eval(env.as_nzenv());
+ let x_ty = x.get_type()?;
+ if x_ty != t {
+ return mkerr(format!(
+ "annot mismatch: ({} : {}) : {}",
+ x.to_expr_tyenv(env),
+ x_ty.to_tyexpr(env.as_varenv()).to_expr_tyenv(env),
+ t.to_tyexpr(env.as_varenv()).to_expr_tyenv(env)
+ ));
+ // return mkerr(format!(
+ // "annot mismatch: {} != {}",
+ // x_ty.to_tyexpr(env.as_varenv()).to_expr_tyenv(env),
+ // t.to_tyexpr(env.as_varenv()).to_expr_tyenv(env)
+ // ));
+ // return mkerr(format!("annot mismatch: {:#?} : {:#?}", x, t,));
+ }
+ x_ty
+ }
+ ExprKind::Assert(t) => {
+ let t = t.eval(env.as_nzenv());
+ match &*t.kind() {
+ ValueKind::Equivalence(x, y) if x == y => {}
+ ValueKind::Equivalence(..) => return mkerr("AssertMismatch"),
+ _ => return mkerr("AssertMustTakeEquivalence"),
+ }
+ t
+ }
+ ExprKind::App(f, arg) => {
+ match f.get_type()?.kind() {
+ ValueKind::PiClosure { annot, closure, .. } => {
+ if arg.get_type()? != *annot {
+ // return mkerr(format!("function annot mismatch"));
+ return mkerr(format!(
+ "function annot mismatch: ({} : {}) : {}",
+ arg.to_expr_tyenv(env),
+ arg.get_type()?
+ .to_tyexpr(env.as_varenv())
+ .to_expr_tyenv(env),
+ annot.to_tyexpr(env.as_varenv()).to_expr_tyenv(env),
+ ));
+ }
+
+ let arg_nf = arg.eval(env.as_nzenv());
+ closure.apply(arg_nf)
+ }
+ _ => return mkerr(format!("apply to not Pi")),
+ }
+ }
+ ExprKind::BoolIf(x, y, z) => {
+ if *x.get_type()?.kind() != ValueKind::from_builtin(Builtin::Bool) {
+ return mkerr("InvalidPredicate");
+ }
+ if y.get_type()?.get_type()?.as_const() != Some(Const::Type) {
+ return mkerr("IfBranchMustBeTerm");
+ }
+ if z.get_type()?.get_type()?.as_const() != Some(Const::Type) {
+ return mkerr("IfBranchMustBeTerm");
+ }
+ if y.get_type()? != z.get_type()? {
+ return mkerr("IfBranchMismatch");
+ }
+
+ y.get_type()?
+ }
+ ExprKind::BinOp(BinOp::RightBiasedRecordMerge, x, y) => {
+ let x_type = x.get_type()?;
+ let y_type = y.get_type()?;
+
+ // Extract the LHS record type
+ let kts_x = match x_type.kind() {
+ ValueKind::RecordType(kts) => kts,
+ _ => return mkerr("MustCombineRecord"),
+ };
+ // Extract the RHS record type
+ let kts_y = match y_type.kind() {
+ ValueKind::RecordType(kts) => kts,
+ _ => return mkerr("MustCombineRecord"),
+ };
+
+ // Union the two records, prefering
+ // the values found in the RHS.
+ let kts = merge_maps::<_, _, _, !>(kts_x, kts_y, |_, _, r_t| {
+ Ok(r_t.clone())
+ })?;
+
+ // Construct the final record type
+ let ty = type_of_recordtype(
+ kts.iter()
+ .map(|(_, t)| Cow::Owned(t.to_tyexpr(env.as_varenv()))),
+ )?;
+ Value::from_kind_and_type(ValueKind::RecordType(kts), ty)
+ }
+ ExprKind::BinOp(BinOp::RecursiveRecordMerge, x, y) => {
+ let ekind = ExprKind::BinOp(
+ BinOp::RecursiveRecordTypeMerge,
+ x.get_type()?.to_tyexpr(env.as_varenv()),
+ y.get_type()?.to_tyexpr(env.as_varenv()),
+ );
+ let ty = type_one_layer(env, &ekind)?;
+ TyExpr::new(TyExprKind::Expr(ekind), Some(ty), Span::Artificial)
+ .eval(env.as_nzenv())
+ }
+ ExprKind::BinOp(BinOp::RecursiveRecordTypeMerge, x, y) => {
+ let x_val = x.eval(env.as_nzenv());
+ let y_val = y.eval(env.as_nzenv());
+ let kts_x = match x_val.kind() {
+ ValueKind::RecordType(kts) => kts,
+ _ => return mkerr("RecordTypeMergeRequiresRecordType"),
+ };
+ let kts_y = match y_val.kind() {
+ ValueKind::RecordType(kts) => kts,
+ _ => return mkerr("RecordTypeMergeRequiresRecordType"),
+ };
+ for (k, tx) in kts_x {
+ if let Some(ty) = kts_y.get(k) {
+ type_one_layer(
+ env,
+ &ExprKind::BinOp(
+ BinOp::RecursiveRecordTypeMerge,
+ tx.to_tyexpr(env.as_varenv()),
+ ty.to_tyexpr(env.as_varenv()),
+ ),
+ )?;
+ }
+ }
+
+ // A RecordType's type is always a const
+ let xk = x.get_type()?.as_const().unwrap();
+ let yk = y.get_type()?.as_const().unwrap();
+ Value::from_const(max(xk, yk))
+ }
+ ExprKind::BinOp(BinOp::ListAppend, l, r) => {
+ let l_ty = l.get_type()?;
+ match &*l_ty.kind() {
+ ValueKind::AppliedBuiltin(BuiltinClosure {
+ b: Builtin::List,
+ ..
+ }) => {}
+ _ => return mkerr("BinOpTypeMismatch"),
+ }
+
+ if l_ty != r.get_type()? {
+ return mkerr("BinOpTypeMismatch");
+ }
+
+ l_ty
+ }
+ ExprKind::BinOp(BinOp::Equivalence, l, r) => {
+ if l.get_type()? != r.get_type()? {
+ return mkerr("EquivalenceTypeMismatch");
+ }
+ if l.get_type()?.get_type()?.as_const() != Some(Const::Type) {
+ return mkerr("EquivalenceArgumentsMustBeTerms");
+ }
+
+ Value::from_const(Const::Type)
+ }
+ ExprKind::BinOp(o, l, r) => {
+ let t = Value::from_builtin(match o {
+ BinOp::BoolAnd
+ | BinOp::BoolOr
+ | BinOp::BoolEQ
+ | BinOp::BoolNE => Builtin::Bool,
+ BinOp::NaturalPlus | BinOp::NaturalTimes => Builtin::Natural,
+ BinOp::TextAppend => Builtin::Text,
+ BinOp::ListAppend
+ | BinOp::RightBiasedRecordMerge
+ | BinOp::RecursiveRecordMerge
+ | BinOp::RecursiveRecordTypeMerge
+ | BinOp::Equivalence => unreachable!(),
+ BinOp::ImportAlt => unreachable!("ImportAlt leftover in tck"),
+ });
+
+ if l.get_type()? != t {
+ return mkerr("BinOpTypeMismatch");
+ }
+
+ if r.get_type()? != t {
+ return mkerr("BinOpTypeMismatch");
+ }
+
+ t
+ }
+ ExprKind::Merge(record, union, type_annot) => {
+ let record_type = record.get_type()?;
+ let handlers = match record_type.kind() {
+ ValueKind::RecordType(kts) => kts,
+ _ => return mkerr("Merge1ArgMustBeRecord"),
+ };
+
+ let union_type = union.get_type()?;
+ let variants = match union_type.kind() {
+ ValueKind::UnionType(kts) => Cow::Borrowed(kts),
+ ValueKind::AppliedBuiltin(BuiltinClosure {
+ b: Builtin::Optional,
+ args,
+ ..
+ }) if args.len() == 1 => {
+ let ty = &args[0];
+ let mut kts = HashMap::new();
+ kts.insert("None".into(), None);
+ kts.insert("Some".into(), Some(ty.clone()));
+ Cow::Owned(kts)
+ }
+ _ => return mkerr("Merge2ArgMustBeUnionOrOptional"),
+ };
+
+ let mut inferred_type = None;
+ for (x, handler_type) in handlers {
+ let handler_return_type = match variants.get(x) {
+ // Union alternative with type
+ Some(Some(variant_type)) => match handler_type.kind() {
+ ValueKind::PiClosure { closure, annot, .. } => {
+ if variant_type != annot {
+ return mkerr("MergeHandlerTypeMismatch");
+ }
+
+ closure.remove_binder().or_else(|()| {
+ mkerr("MergeReturnTypeIsDependent")
+ })?
+ }
+ _ => return mkerr("NotAFunction"),
+ },
+ // Union alternative without type
+ Some(None) => handler_type.clone(),
+ None => return mkerr("MergeHandlerMissingVariant"),
+ };
+ match &inferred_type {
+ None => inferred_type = Some(handler_return_type),
+ Some(t) => {
+ if t != &handler_return_type {
+ return mkerr("MergeHandlerTypeMismatch");
+ }
+ }
+ }
+ }
+ for x in variants.keys() {
+ if !handlers.contains_key(x) {
+ return mkerr("MergeVariantMissingHandler");
+ }
+ }
+
+ let type_annot =
+ type_annot.as_ref().map(|t| t.eval(env.as_nzenv()));
+ match (inferred_type, type_annot) {
+ (Some(t1), Some(t2)) => {
+ if t1 != t2 {
+ return mkerr("MergeAnnotMismatch");
+ }
+ t1
+ }
+ (Some(t), None) => t,
+ (None, Some(t)) => t,
+ (None, None) => return mkerr("MergeEmptyNeedsAnnotation"),
+ }
+ }
+ ExprKind::ToMap(_, _) => unimplemented!("toMap"),
+ ExprKind::Projection(record, labels) => {
+ let record_type = record.get_type()?;
+ let kts = match record_type.kind() {
+ ValueKind::RecordType(kts) => kts,
+ _ => return mkerr("ProjectionMustBeRecord"),
+ };
+
+ let mut new_kts = HashMap::new();
+ for l in labels {
+ match kts.get(l) {
+ None => return mkerr("ProjectionMissingEntry"),
+ Some(t) => {
+ use std::collections::hash_map::Entry;
+ match new_kts.entry(l.clone()) {
+ Entry::Occupied(_) => {
+ return mkerr("ProjectionDuplicateField")
+ }
+ Entry::Vacant(e) => e.insert(t.clone()),
+ }
+ }
+ };
+ }
+
+ Value::from_kind_and_type(
+ ValueKind::RecordType(new_kts),
+ record_type.get_type()?,
+ )
+ }
+ ExprKind::ProjectionByExpr(_, _) => {
+ unimplemented!("selection by expression")
+ }
+ ExprKind::Completion(_, _) => unimplemented!("record completion"),
+ })
+}
+
+/// `type_with` typechecks an expressio in the provided environment.
+pub(crate) fn type_with(
+ env: &TyEnv,
+ expr: &Expr<Normalized>,
+) -> Result<TyExpr, TypeError> {
+ let (tyekind, ty) = match expr.as_ref() {
+ ExprKind::Var(var) => match env.lookup(&var) {
+ Some((k, ty)) => (k, Some(ty)),
+ None => return mkerr("unbound variable"),
+ },
+ ExprKind::Lam(binder, annot, body) => {
+ let annot = type_with(env, annot)?;
+ let annot_nf = annot.eval(env.as_nzenv());
+ let body_env = env.insert_type(&binder, annot_nf.clone());
+ let body = type_with(&body_env, body)?;
+ let body_ty = body.get_type()?;
+ let ty = TyExpr::new(
+ TyExprKind::Expr(ExprKind::Pi(
+ binder.clone(),
+ annot.clone(),
+ body_ty.to_tyexpr(body_env.as_varenv()),
+ )),
+ Some(type_of_function(annot.get_type()?, body_ty.get_type()?)?),
+ Span::Artificial,
+ );
+ let ty = ty.eval(env.as_nzenv());
+ (
+ TyExprKind::Expr(ExprKind::Lam(binder.clone(), annot, body)),
+ Some(ty),
+ )
+ }
+ ExprKind::Pi(binder, annot, body) => {
+ let annot = type_with(env, annot)?;
+ let annot_nf = annot.eval(env.as_nzenv());
+ let body =
+ type_with(&env.insert_type(binder, annot_nf.clone()), body)?;
+ let ty = type_of_function(annot.get_type()?, body.get_type()?)?;
+ (
+ TyExprKind::Expr(ExprKind::Pi(binder.clone(), annot, body)),
+ Some(ty),
+ )
+ }
+ ExprKind::Let(binder, annot, val, body) => {
+ let val = if let Some(t) = annot {
+ t.rewrap(ExprKind::Annot(val.clone(), t.clone()))
+ } else {
+ val.clone()
+ };
+
+ let val = type_with(env, &val)?;
+ let val_nf = val.eval(&env.as_nzenv());
+ let body = type_with(&env.insert_value(&binder, val_nf), body)?;
+ let body_ty = body.get_type().ok();
+ (
+ TyExprKind::Expr(ExprKind::Let(
+ binder.clone(),
+ None,
+ val,
+ body,
+ )),
+ body_ty,
+ )
+ }
+ ExprKind::Const(Const::Sort) => {
+ (TyExprKind::Expr(ExprKind::Const(Const::Sort)), None)
+ }
+ ExprKind::Embed(p) => {
+ return Ok(p.clone().into_value().to_tyexpr_noenv())
+ }
+ ekind => {
+ let ekind = ekind.traverse_ref(|e| type_with(env, e))?;
+ let ty = type_one_layer(env, &ekind)?;
+ (TyExprKind::Expr(ekind), Some(ty))
+ }
+ };
+
+ Ok(TyExpr::new(tyekind, ty, expr.span()))
+}
+
+/// Typecheck an expression and return the expression annotated with types if type-checking
+/// succeeded, or an error if type-checking failed.
+pub(crate) fn typecheck(e: &Expr<Normalized>) -> Result<TyExpr, TypeError> {
+ type_with(&TyEnv::new(), e)
+}
+
+/// Like `typecheck`, but additionally checks that the expression's type matches the provided type.
+pub(crate) fn typecheck_with(
+ expr: &Expr<Normalized>,
+ ty: Expr<Normalized>,
+) -> Result<TyExpr, TypeError> {
+ typecheck(&expr.rewrap(ExprKind::Annot(expr.clone(), ty)))
+}