diff options
-rw-r--r-- | src/context.rs | 52 | ||||
-rw-r--r-- | src/core.rs | 487 | ||||
-rw-r--r-- | src/grammar.lalrpop | 1 | ||||
-rw-r--r-- | src/grammar_util.rs | 8 | ||||
-rw-r--r-- | src/main.rs | 20 | ||||
-rw-r--r-- | src/typecheck.rs | 621 |
6 files changed, 1176 insertions, 13 deletions
diff --git a/src/context.rs b/src/context.rs new file mode 100644 index 0000000..4d6abf2 --- /dev/null +++ b/src/context.rs @@ -0,0 +1,52 @@ +use std::borrow::Cow; +use std::collections::HashMap; + +/// A `(Context a)` associates `Text` labels with values of type `a` +/// +/// The `Context` is used for type-checking when `(a = Expr X)` +/// +/// * You create a `Context` using `empty` and `insert` +/// * You transform a `Context` using `fmap` +/// * You consume a `Context` using `lookup` and `toList` +/// +/// The difference between a `Context` and a `Map` is that a `Context` lets you +/// have multiple ordered occurrences of the same key and you can query for the +/// `n`th occurrence of a given key. +/// +#[derive(Debug, Clone)] +pub struct Context<'i, T>(HashMap<Cow<'i, str>, Vec<T>>); + +impl<'i, T> Context<'i, T> { + /// An empty context with no key-value pairs + pub fn new() -> Self { + Context(HashMap::new()) + } + + /// Look up a key by name and index + /// + /// ```c + /// lookup _ _ empty = Nothing + /// lookup k 0 (insert k v c) = Just v + /// lookup k n (insert k v c) = lookup k (n - 1) c -- 1 <= n + /// lookup k n (insert j v c) = lookup k n c -- k /= j + /// ``` + pub fn lookup<'a>(&'a self, k: &str, n: usize) -> Option<&'a T> { + self.0.get(k).and_then(|v| v.get(n)) + } + + pub fn map<U, F: Fn(&T) -> U>(&self, f: F) -> Context<'i, U> { + Context(self.0.iter().map(|(k, v)| (k.clone(), v.iter().map(&f).collect())).collect()) + } +} + +impl<'i, T: Clone> Context<'i, T> { + /// Add a key-value pair to the `Context` + pub fn insert(&self, k: Cow<'i, str>, v: T) -> Self { + let mut ctx = (*self).clone(); + { + let m = ctx.0.entry(k).or_insert(vec![]); + m.push(v); + } + ctx + } +} diff --git a/src/core.rs b/src/core.rs index 85ebf1b..5bf90db 100644 --- a/src/core.rs +++ b/src/core.rs @@ -1,3 +1,4 @@ +#![allow(non_snake_case)] use std::borrow::Cow; use std::collections::HashMap; use std::path::PathBuf; @@ -216,8 +217,494 @@ pub enum Expr<'i, S, A> { Embed(A), } +impl<'i, S, A> Expr<'i, S, A> { + /// Clones the expression if it is a unit constructor + fn clone_unit<T, B>(&self) -> Option<Expr<'static, T, B>> { + use Expr::*; + match self { + &Bool => Some(Bool), + &Natural => Some(Natural), + &NaturalFold => Some(NaturalFold), + &NaturalBuild => Some(NaturalBuild), + &NaturalIsZero => Some(NaturalIsZero), + &NaturalEven => Some(NaturalEven), + &NaturalOdd => Some(NaturalOdd), + &Integer => Some(Integer), + &Double => Some(Double), + &Text => Some(Text), + &List => Some(List), + &ListBuild => Some(ListBuild), + &ListFold => Some(ListFold), + &ListLength => Some(ListLength), + &ListHead => Some(ListHead), + &ListLast => Some(ListLast), + &ListIndexed => Some(ListIndexed), + &ListReverse => Some(ListReverse), + &Optional => Some(Optional), + &OptionalFold => Some(OptionalFold), + _ => None, + } + } + + /// Returns true if the expression is a unit constructor + pub fn is_unit(&self) -> bool { + self.clone_unit::<S, A>().is_some() + } +} + +impl<'i> From<&'i str> for V<'i> { + fn from(s: &'i str) -> Self { + V(Cow::Borrowed(s), 0) + } +} + +impl<'i, S, A> From<&'i str> for Expr<'i, S, A> { + fn from(s: &'i str) -> Self { + Expr::Var(V(Cow::Borrowed(s), 0)) + } +} + +pub fn pi<'i, S, A, Name, Et, Ev>(var: Name, ty: Et, value: Ev) -> Expr<'i, S, A> + where Name: Into<Cow<'i, str>>, + Et: Into<Expr<'i, S, A>>, + Ev: Into<Expr<'i, S, A>> +{ + Expr::Pi(var.into(), bx(ty.into()), bx(value.into())) +} + +pub fn app<'i, S, A, Ef, Ex>(f: Ef, x: Ex) -> Expr<'i, S, A> + where Ef: Into<Expr<'i, S, A>>, + Ex: Into<Expr<'i, S, A>> +{ + Expr::App(bx(f.into()), bx(x.into())) +} + pub type Builder = String; pub type Double = f64; pub type Int = isize; pub type Integer = isize; pub type Natural = usize; + +/// A void type +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum X {} + +pub fn bx<T>(x: T) -> Box<T> { + Box::new(x) +} + +fn add_ui(u: usize, i: isize) -> usize { + if i < 0 { + u.checked_sub((i.checked_neg().unwrap() as usize)).unwrap() + } else { + u.checked_add(i as usize).unwrap() + } +} + +/// `shift` is used by both normalization and type-checking to avoid variable +/// capture by shifting variable indices +/// +/// For example, suppose that you were to normalize the following expression: +/// +/// ```c +/// λ(a : Type) → λ(x : a) → (λ(y : a) → λ(x : a) → y) x +/// ``` +/// +/// If you were to substitute `y` with `x` without shifting any variable +/// indices, then you would get the following incorrect result: +/// +/// ```c +/// λ(a : Type) → λ(x : a) → λ(x : a) → x -- Incorrect normalized form +/// ``` +/// +/// In order to substitute `x` in place of `y` we need to `shift` `x` by `1` in +/// order to avoid being misinterpreted as the `x` bound by the innermost +/// lambda. If we perform that `shift` then we get the correct result: +/// +/// ```c +/// λ(a : Type) → λ(x : a) → λ(x : a) → x@1 +/// ``` +/// +/// As a more worked example, suppose that you were to normalize the following +/// expression: +/// +/// ```c +/// λ(a : Type) +/// → λ(f : a → a → a) +/// → λ(x : a) +/// → λ(x : a) +/// → (λ(x : a) → f x x@1) x@1 +/// ``` +/// +/// The correct normalized result would be: +/// +/// ```c +/// λ(a : Type) +/// → λ(f : a → a → a) +/// → λ(x : a) +/// → λ(x : a) +/// → f x@1 x +/// ``` +/// +/// The above example illustrates how we need to both increase and decrease +/// variable indices as part of substitution: +/// +/// * We need to increase the index of the outer `x\@1` to `x\@2` before we +/// substitute it into the body of the innermost lambda expression in order +/// to avoid variable capture. This substitution changes the body of the +/// lambda expression to `(f x\@2 x\@1)` +/// +/// * We then remove the innermost lambda and therefore decrease the indices of +/// both `x`s in `(f x\@2 x\@1)` to `(f x\@1 x)` in order to reflect that one +/// less `x` variable is now bound within that scope +/// +/// Formally, `(shift d (V x n) e)` modifies the expression `e` by adding `d` to +/// the indices of all variables named `x` whose indices are greater than +/// `(n + m)`, where `m` is the number of bound variables of the same name +/// within that scope +/// +/// In practice, `d` is always `1` or `-1` because we either: +/// +/// * increment variables by `1` to avoid variable capture during substitution +/// * decrement variables by `1` when deleting lambdas after substitution +/// +/// `n` starts off at `0` when substitution begins and increments every time we +/// descend into a lambda or let expression that binds a variable of the same +/// name in order to avoid shifting the bound variables by mistake. +/// +pub fn shift<'i, S, T, A>(d: isize, v: V, e: Expr<'i, S, A>) -> Expr<'i, T, A> + where S: ::std::fmt::Debug, + T: ::std::fmt::Debug, + A: ::std::fmt::Debug, +{ + use Expr::*; + match e { + Const(a) => Const(a), + Var(V(x2, n2)) => { + let V(x, n) = v; + let n3 = if x == x2 && n <= n2 { add_ui(n2, d) } else { n2 }; + Var(V(x2, n3)) + } + Lam(x2, tA, b) => { + let V(x, n) = v; + let n2 = if x == x2 { n + 1 } else { n }; + let tA2 = shift(d, V(x.clone(), n ), *tA); + let b2 = shift(d, V(x, n2), *b); + Lam(x2, bx(tA2), bx(b2)) + } + Pi(x2, tA, tB) => { + let V(x, n) = v; + let n2 = if x == x2 { n + 1 } else { n }; + let tA2 = shift(d, V(x.clone(), n ), *tA); + let tB2 = shift(d, V(x, n2), *tB); + pi(x2, tA2, tB2) + } + App(f, a) => { + let f2 = shift(d, v.clone(), *f); + let a2 = shift(d, v, *a); + App(bx(f2), bx(a2)) + } +/* +shift d (V x n) (Let f mt r e) = Let f mt' r' e' + where + e' = shift d (V x n') e + where + n' = if x == f then n + 1 else n + + mt' = fmap (shift d (V x n)) mt + r' = shift d (V x n) r +shift d v (Annot a b) = Annot a' b' + where + a' = shift d v a + b' = shift d v b + */ + BoolLit(a) => BoolLit(a), + BoolAnd(a, b) => BoolAnd(bx(shift(d, v.clone(), *a)), bx(shift(d, v, *b))), +/* +shift d v (BoolOr a b) = BoolOr a' b' + where + a' = shift d v a + b' = shift d v b +shift d v (BoolEQ a b) = BoolEQ a' b' + where + a' = shift d v a + b' = shift d v b +shift d v (BoolNE a b) = BoolNE a' b' + where + a' = shift d v a + b' = shift d v b +shift d v (BoolIf a b c) = BoolIf a' b' c' + where + a' = shift d v a + b' = shift d v b + c' = shift d v c +shift _ _ Natural = Natural +shift _ _ (NaturalLit a) = NaturalLit a +shift _ _ NaturalFold = NaturalFold +shift _ _ NaturalBuild = NaturalBuild +shift _ _ NaturalIsZero = NaturalIsZero +shift _ _ NaturalEven = NaturalEven +shift _ _ NaturalOdd = NaturalOdd +shift d v (NaturalPlus a b) = NaturalPlus a' b' + where + a' = shift d v a + b' = shift d v b +shift d v (NaturalTimes a b) = NaturalTimes a' b' + where + a' = shift d v a + b' = shift d v b +shift _ _ Integer = Integer +shift _ _ (IntegerLit a) = IntegerLit a +shift _ _ Double = Double +shift _ _ (DoubleLit a) = DoubleLit a +shift _ _ Text = Text +shift _ _ (TextLit a) = TextLit a +shift d v (TextAppend a b) = TextAppend a' b' + where + a' = shift d v a + b' = shift d v b +shift d v (ListLit a b) = ListLit a' b' + where + a' = shift d v a + b' = fmap (shift d v) b +shift _ _ ListBuild = ListBuild +shift _ _ ListFold = ListFold +shift _ _ ListLength = ListLength +shift _ _ ListHead = ListHead +shift _ _ ListLast = ListLast +shift _ _ ListIndexed = ListIndexed +shift _ _ ListReverse = ListReverse +shift _ _ Optional = Optional +shift d v (OptionalLit a b) = OptionalLit a' b' + where + a' = shift d v a + b' = fmap (shift d v) b +shift _ _ OptionalFold = OptionalFold +shift d v (Record a) = Record a' + where + a' = fmap (shift d v) a +shift d v (RecordLit a) = RecordLit a' + where + a' = fmap (shift d v) a +shift d v (Union a) = Union a' + where + a' = fmap (shift d v) a +shift d v (UnionLit a b c) = UnionLit a b' c' + where + b' = shift d v b + c' = fmap (shift d v) c +shift d v (Combine a b) = Combine a' b' + where + a' = shift d v a + b' = shift d v b +shift d v (Merge a b c) = Merge a' b' c' + where + a' = shift d v a + b' = shift d v b + c' = shift d v c +shift d v (Field a b) = Field a' b + where + a' = shift d v a +shift d v (Note _ b) = b' + where + b' = shift d v b +-- The Dhall compiler enforces that all embedded values are closed expressions +-- and `shift` does nothing to a closed expression +shift _ _ (Embed p) = Embed p +*/ + e => if let Some(e2) = e.clone_unit() { + e2 + } else { + panic!("Unimplemented shift case: {:?}", (d, v, e)) + }, + } +} + + +/// Substitute all occurrences of a variable with an expression +/// +/// ```c +/// subst x C B ~ B[x := C] +/// ``` +/// +pub fn subst<'i, S, T, A>(v: V<'i>, a: Expr<'i, S, A>, b: Expr<'i, T, A>) -> Expr<'i, S, A> + where S: Clone + ::std::fmt::Debug, + T: Clone + ::std::fmt::Debug, + A: Clone + ::std::fmt::Debug, +{ + use Expr::*; + match (a, b) { + (_, Const(a)) => Const(a), + (e, Lam(y, tA, b)) => { + let V(x, n) = v; + let n2 = if x == y { n + 1 } else { n }; + let tA2 = subst(V(x.clone(), n), e.clone(), *tA); + let b2 = subst(V(x, n2), shift(1, V(y.clone(), 0), e), *b); + Lam(y, bx(tA2), bx(b2)) + } + (e, Pi(y, tA, tB)) => { + let V(x, n) = v; + let n2 = if x == y { n + 1 } else { n }; + let tA2 = subst(V(x.clone(), n), e.clone(), *tA); + let tB2 = subst(V(x, n2), shift(1, V(y.clone(), 0), e), *tB); + pi(y, tA2, tB2) + } + (e, App(f, a)) => { + let f2 = subst(v.clone(), e.clone(), *f); + let a2 = subst(v, e, *a); + app(f2, a2) + } + (e, Var(v2)) => if v == v2 { e } else { Var(v2) }, + (e, ListLit(a, b)) => { + let b2 = b.into_iter().map(|be| subst(v.clone(), e.clone(), be)).collect(); + let a2 = subst(v, e, *a); + ListLit(bx(a2), b2) + } + (a, b) => if let Some(e2) = b.clone_unit() { + e2 + } else { + panic!("Unimplemented subst case: {:?}", (v, a, b)) + } + } +} + +/// Reduce an expression to its normal form, performing beta reduction +/// +/// `normalize` does not type-check the expression. You may want to type-check +/// expressions before normalizing them since normalization can convert an +/// ill-typed expression into a well-typed expression. +/// +/// However, `normalize` will not fail if the expression is ill-typed and will +/// leave ill-typed sub-expressions unevaluated. +/// +pub fn normalize<S, T, A>(e: Expr<S, A>) -> Expr<T, A> + where S: Clone + ::std::fmt::Debug, + T: Clone + ::std::fmt::Debug, + A: Clone + ::std::fmt::Debug, +{ + use Expr::*; + match e { + Const(k) => Const(k), + Var(v) => Var(v), + Lam(x, tA, b) => { + let tA2 = normalize(*tA); + let b2 = normalize(*b); + Lam(x, bx(tA2), bx(b2)) + } + Pi(x, tA, tB) => { + let tA2 = normalize(*tA); + let tB2 = normalize(*tB); + pi(x, tA2, tB2) + } + App(f, a) => match normalize::<S, T, A>(*f) { + Lam(x, _A, b) => { // Beta reduce + let vx0 = V(x, 0); + let a2 = shift::<S, S, A>( 1, vx0.clone(), *a); + let b2 = subst::<S, T, A>(vx0.clone(), a2, *b); + let b3 = shift::<S, T, A>(-1, vx0, b2); + normalize(b3) + } + f2 => match (f2, normalize::<S, T, A>(*a)) { + /* + -- fold/build fusion for `List` + App (App ListBuild _) (App (App ListFold _) e') -> normalize e' + App (App ListFold _) (App (App ListBuild _) e') -> normalize e' + + -- fold/build fusion for `Natural` + App NaturalBuild (App NaturalFold e') -> normalize e' + App NaturalFold (App NaturalBuild e') -> normalize e' + + App (App (App (App NaturalFold (NaturalLit n0)) _) succ') zero -> + normalize (go n0) + where + go !0 = zero + go !n = App succ' (go (n - 1)) + App NaturalBuild k + | check -> NaturalLit n + | otherwise -> App f' a' + where + labeled = + normalize (App (App (App k Natural) "Succ") "Zero") + + n = go 0 labeled + where + go !m (App (Var "Succ") e') = go (m + 1) e' + go !m (Var "Zero") = m + go !_ _ = internalError text + check = go labeled + where + go (App (Var "Succ") e') = go e' + go (Var "Zero") = True + go _ = False + */ + (NaturalIsZero, NaturalLit(n)) => BoolLit(n == 0), + (NaturalEven, NaturalLit(n)) => BoolLit(n % 2 == 0), + (NaturalOdd, NaturalLit(n)) => BoolLit(n % 2 != 0), + /* + App (App ListBuild t) k + | check -> ListLit t (buildVector k') + | otherwise -> App f' a' + where + labeled = + normalize (App (App (App k (App List t)) "Cons") "Nil") + + k' cons nil = go labeled + where + go (App (App (Var "Cons") x) e') = cons x (go e') + go (Var "Nil") = nil + go _ = internalError text + check = go labeled + where + go (App (App (Var "Cons") _) e') = go e' + go (Var "Nil") = True + go _ = False + App (App (App (App (App ListFold _) (ListLit _ xs)) _) cons) nil -> + normalize (Data.Vector.foldr cons' nil xs) + where + cons' y ys = App (App cons y) ys + App (App ListLength _) (ListLit _ ys) -> + NaturalLit (fromIntegral (Data.Vector.length ys)) + App (App ListHead _) (ListLit t ys) -> + normalize (OptionalLit t (Data.Vector.take 1 ys)) + App (App ListLast _) (ListLit t ys) -> + normalize (OptionalLit t y) + where + y = if Data.Vector.null ys + then Data.Vector.empty + else Data.Vector.singleton (Data.Vector.last ys) + App (App ListIndexed _) (ListLit t xs) -> + normalize (ListLit t' (fmap adapt (Data.Vector.indexed xs))) + where + t' = Record (Data.Map.fromList kts) + where + kts = [ ("index", Natural) + , ("value", t) + ] + adapt (n, x) = RecordLit (Data.Map.fromList kvs) + where + kvs = [ ("index", NaturalLit (fromIntegral n)) + , ("value", x) + ] + App (App ListReverse _) (ListLit t xs) -> + normalize (ListLit t (Data.Vector.reverse xs)) + App (App (App (App (App OptionalFold _) (OptionalLit _ xs)) _) just) nothing -> + normalize (maybe nothing just' (toMaybe xs)) + where + just' y = App just y + toMaybe = Data.Maybe.listToMaybe . Data.Vector.toList + */ + (f2, a2) => app(f2, a2), + } + }, + ListLit(t, es) => { + let t2 = normalize(*t); + let es2 = es.into_iter().map(normalize).collect(); + ListLit(bx(t2), es2) + } + _ => if let Some(e2) = e.clone_unit() { + e2 + } else { + panic!("Unimplemented normalize case: {:?}", e) + } + } +} diff --git a/src/grammar.lalrpop b/src/grammar.lalrpop index 3e216ac..91f8b8d 100644 --- a/src/grammar.lalrpop +++ b/src/grammar.lalrpop @@ -1,4 +1,5 @@ use core; +use core::bx; use core::Expr::*; use grammar_util::*; use lexer::*; diff --git a/src/grammar_util.rs b/src/grammar_util.rs index cf0ee59..6927d33 100644 --- a/src/grammar_util.rs +++ b/src/grammar_util.rs @@ -1,15 +1,11 @@ -use core::Expr; +use core::{Expr, X}; use lexer::Builtin; -pub type ParsedExpr<'i> = Expr<'i, (), ()>; +pub type ParsedExpr<'i> = Expr<'i, X, X>; // FIXME Parse paths and replace the second X with Path pub type BoxExpr<'i> = Box<ParsedExpr<'i>>; pub type ExprOpFn<'i> = fn(BoxExpr<'i>, BoxExpr<'i>) -> ParsedExpr<'i>; pub type ExprListFn<'i> = fn(BoxExpr<'i>, Vec<ParsedExpr<'i>>) -> ParsedExpr<'i>; -pub fn bx<T>(x: T) -> Box<T> { - Box::new(x) -} - pub fn builtin_expr<'i, S, A>(b: Builtin) -> Expr<'i, S, A> { match b { Builtin::Natural => Expr::Natural, diff --git a/src/main.rs b/src/main.rs index deb6ac3..17c990d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,12 +4,14 @@ extern crate lalrpop_util; extern crate nom; extern crate term_painter; +pub mod context; mod core; pub use core::*; pub mod grammar; mod grammar_util; pub mod lexer; pub mod parser; +pub mod typecheck; use std::io::{self, Read}; @@ -62,29 +64,33 @@ fn print_error(message: &str, source: &str, start: usize, end: usize) { fn main() { let mut buffer = String::new(); io::stdin().read_to_string(&mut buffer).unwrap(); - match parser::parse_expr(&buffer) { - Ok(e) => println!("{:?}", e), + let expr = match parser::parse_expr(&buffer) { + Ok(e) => e, Err(lalrpop_util::ParseError::User { error: lexer::LexicalError::Error(pos, e) }) => { print_error(&format!("Unexpected token {:?}", e), &buffer, pos, pos); + return; } Err(lalrpop_util::ParseError::UnrecognizedToken { token: Some((start, t, end)), expected: e }) => { print_error(&format!("Unrecognized token {:?}", t), &buffer, start, end); if e.len() > 0 { println!("Expected {:?}", e); } + return; } Err(e) => { print_error(&format!("Parser error {:?}", e), &buffer, 0, 0); + return; } - } + }; - /* - expr <- case exprFromText (Directed "(stdin)" 0 0 0 0) inText of - Left err -> Control.Exception.throwIO err - Right expr -> return expr + println!("{:?}", expr); + /* expr' <- load expr + */ + println!("{:?}", typecheck::type_of(&expr)); + /* typeExpr <- case Dhall.TypeCheck.typeOf expr' of Left err -> Control.Exception.throwIO err Right typeExpr -> return typeExpr diff --git a/src/typecheck.rs b/src/typecheck.rs new file mode 100644 index 0000000..e552de5 --- /dev/null +++ b/src/typecheck.rs @@ -0,0 +1,621 @@ +#![allow(non_snake_case)] +use std::collections::HashSet; + +use context::Context; +use core; +use core::{Expr, V, X, bx, normalize, shift, subst}; +use core::{pi, app}; +use core::Expr::*; +use core::Const::*; + +use self::TypeMessage::*; + +fn axiom<'i, S: Clone>(c: core::Const) -> Result<core::Const, TypeError<'i, S>> { + match c { + Type => Ok(Kind), + Kind => Err(TypeError::new(&Context::new(), &Const(Kind), Untyped)), + } +} + +fn rule(a: core::Const, b: core::Const) -> Result<core::Const, ()> { + match (a, b) { + (Type, Kind) => Err(()), + (Type, Type) => Ok(Type), + (Kind, Kind) => Ok(Kind), + (Kind, Type) => Ok(Type), + } +} + +fn match_vars(vl: &V, vr: &V, ctx: &[(&str, &str)]) -> bool { + let xxs = ctx.get(0).map(|x| (x, ctx.split_at(1).1)); + match (vl, vr, xxs) { + (&V(ref xL, nL), &V(ref xR, nR), None) => xL == xR && nL == nR, + (&V(ref xL, 0), &V(ref xR, 0), Some((&(ref xL2, ref xR2), _))) if xL == xL2 && xR == xR2 => true, + (&V(ref xL, nL), &V(ref xR, nR), Some((&(ref xL2, ref xR2), xs))) => { + let nL2 = if *xL == xL2.as_ref() { nL - 1 } else { nL }; + let nR2 = if *xR == xR2.as_ref() { nR - 1 } else { nR }; + match_vars(&V(xL.clone(), nL2), &V(xR.clone(), nR2), xs) + } + } +} + +fn prop_equal<S, T>(eL0: &Expr<S, X>, eR0: &Expr<T, X>) -> bool + where S: Clone + ::std::fmt::Debug, + T: Clone + ::std::fmt::Debug +{ + fn go<'i, S, T>(ctx: &mut Vec<(&'i str, &'i str)>, el: &'i Expr<'i, S, X>, er: &'i Expr<'i, T, X>) -> bool + where S: Clone + ::std::fmt::Debug, + T: Clone + ::std::fmt::Debug + { + match (el, er) { + (&Const(Type), &Const(Type)) => true, + (&Const(Kind), &Const(Kind)) => true, + (&Var(ref vL), &Var(ref vR)) => match_vars(vL, vR, &*ctx), + (&Pi(ref xL, ref tL, ref bL), &Pi(ref xR, ref tR, ref bR)) => { + //ctx <- State.get + let eq1 = go(ctx, tL, tR); + if eq1 { + //State.put ((xL, xR):ctx) + ctx.push((xL, xR)); + let eq2 = go(ctx, bL, bR); + //State.put ctx + let _ = ctx.pop(); + eq2 + } else { + false + } + } + (&App(ref fL, ref aL), &App(ref fR, ref aR)) => + if go(ctx, fL, fR) { go(ctx, aL, aR) } else { false }, + (&Bool, &Bool) => true, + (&Natural, &Natural) => true, + (&Integer, &Integer) => true, + (&Double, &Double) => true, + (&Text, &Text) => true, + (&List, &List) => true, + (&Optional, &Optional) => true, + (&Record(ref _ktsL0), &Record(ref _ktsR0)) => unimplemented!(), + /* + let loop ((kL, tL):ktsL) ((kR, tR):ktsR) + | kL == kR = do + b <- go tL tR + if b + then loop ktsL ktsR + else return False + loop [] [] = return True + loop _ _ = return False + loop (Data.Map.toList ktsL0) (Data.Map.toList ktsR0) + */ + (&Union(ref _ktsL0), &Union(ref _ktsR0)) => unimplemented!(), + /* + let loop ((kL, tL):ktsL) ((kR, tR):ktsR) + | kL == kR = do + b <- go tL tR + if b + then loop ktsL ktsR + else return False + loop [] [] = return True + loop _ _ = return False + loop (Data.Map.toList ktsL0) (Data.Map.toList ktsR0) + */ + (_, _) => false, + } + } + let mut ctx = vec![]; + go::<S, T>(&mut ctx, &normalize(eL0.clone()), &normalize(eR0.clone())) +} + + +/// Type-check an expression and return the expression'i type if type-checking +/// suceeds or an error if type-checking fails +/// +/// `type_with` does not necessarily normalize the type since full normalization +/// is not necessary for just type-checking. If you actually care about the +/// returned type then you may want to `normalize` it afterwards. +pub fn type_with<'i, S>(ctx: &Context<'i, Expr<'i, S, X>>, + e: &Expr<'i, S, X>) + -> Result<Expr<'i, S, X>, TypeError<'i, S>> + where S: Clone + ::std::fmt::Debug + 'i +{ + match e { + &Const(c) => axiom(c).map(Const), //.map(Cow::Owned), + &Var(V(ref x, n)) => { + ctx.lookup(x, n) + .cloned() + //.map(Cow::Borrowed) + .ok_or_else(|| TypeError::new(ctx, &e, UnboundVariable)) + } + &Lam(ref x, ref tA, ref b) => { + let ctx2 = ctx.insert(x.clone(), (**tA).clone()).map(|e| core::shift(1, V(x.clone(), 0), e.clone())); + let tB = type_with(&ctx2, b)?; + let p = Pi(x.clone(), tA.clone(), bx(tB)); + let _ = type_with(ctx, &p)?; + //Ok(Cow::Owned(p)) + Ok(p) + } + &Pi(ref x, ref tA, ref tB) => { + let tA2 = normalize::<S, S, X>(type_with(ctx, tA)?); + let kA = match tA2 { + Const(k) => k, + _ => return Err(TypeError::new(ctx, e, InvalidInputType((**tA).clone()))), + }; + + let ctx2 = ctx.insert(x.clone(), (**tA).clone()).map(|e| core::shift(1, V(x.clone(), 0), e.clone())); + let tB = normalize(type_with(&ctx2, tB)?); + let kB = match tB { + Const(k) => k, + _ => return Err(TypeError::new(&ctx2, e, InvalidOutputType(tB))), + }; + + match rule(kA, kB) { + Err(()) => Err(TypeError::new(ctx, e, NoDependentTypes((**tA).clone(), tB))), + Ok(k) => Ok(Const(k)), + } + } + &App(ref f, ref a) => { + let tf = normalize(type_with(ctx, f)?); + let (x, tA, tB) = match tf { + Pi(x, tA, tB) => (x, tA, tB), + _ => return Err(TypeError::new(ctx, e, NotAFunction((**f).clone(), tf))), + }; + let tA2 = type_with(ctx, a)?; + if prop_equal(&tA, &tA2) { + let vx0 = V(x, 0); + let a2 = shift::<S, S, X>( 1, vx0.clone(), (**a).clone()); + let tB2 = subst(vx0.clone(), a2, (*tB).clone()); + let tB3 = shift::<S, S, X>(-1, vx0, tB2); + Ok(tB3) + } else { + let nf_A = normalize(*tA); + let nf_A2 = normalize(tA2); + Err(TypeError::new(ctx, e, TypeMismatch((**f).clone(), nf_A, (**a).clone(), nf_A2))) + } + } + &Let(ref f, ref mt, ref r, ref b) => { + let tR = type_with(ctx, r)?; + let ttR = normalize::<S, S, X>(type_with(ctx, &tR)?); + let kR = match ttR { + Const(k) => k, + // Don't bother to provide a `let`-specific version of this error + // message because this should never happen anyway + _ => return Err(TypeError::new(ctx, &e, InvalidInputType(tR))), + }; + + let ctx2 = ctx.insert(f.clone(), tR.clone()); + let tB = type_with(&ctx2, b)?; + let ttB = normalize::<S, S, X>(type_with(ctx, &tB)?); + let kB = match ttB { + Const(k) => k, + // Don't bother to provide a `let`-specific version of this error + // message because this should never happen anyway + _ => return Err(TypeError::new(ctx, &e, InvalidOutputType(tB))), + }; + + if let Err(()) = rule(kR, kB) { + return Err(TypeError::new(ctx, &e, NoDependentLet(tR, tB))); + } + + if let &Some(ref t) = mt { + let nf_t = normalize((**t).clone()); + let nf_tR = normalize(tR.clone()); + if !prop_equal(&nf_tR, &nf_t) { + return Err(TypeError::new(ctx, &e, AnnotMismatch((**r).clone(), nf_t, nf_tR))); + } + } + + Ok(tB) + } +/* +type_with ctx e@(Annot x t ) = do + -- This is mainly just to check that `t` is not `Kind` + _ <- type_with ctx t + + t' <- type_with ctx x + if prop_equal t t' + then do + return t + else do + let nf_t = Dhall.Core.normalize t + let nf_t' = Dhall.Core.normalize t' + Left (TypeError ctx e (AnnotMismatch x nf_t nf_t')) +*/ + &Bool => Ok(Const(Type)), + &BoolLit(_) => Ok(Bool), + &BoolAnd(ref l, ref r) => { + let tl = normalize(type_with(ctx, l)?); + match tl { + Bool => {} + _ => return Err(TypeError::new(ctx, e, CantAnd((**l).clone(), tl))), + } + + let tr = normalize(type_with(ctx, r)?); + match tr { + Bool => {} + _ => return Err(TypeError::new(ctx, e, CantAnd((**r).clone(), tr))), + } + + Ok(Bool) + } + /* +type_with ctx e@(BoolOr l r ) = do + tl <- fmap Dhall.Core.normalize (type_with ctx l) + case tl of + Bool -> return () + _ -> Left (TypeError ctx e (CantOr l tl)) + + tr <- fmap Dhall.Core.normalize (type_with ctx r) + case tr of + Bool -> return () + _ -> Left (TypeError ctx e (CantOr r tr)) + + return Bool +type_with ctx e@(BoolEQ l r ) = do + tl <- fmap Dhall.Core.normalize (type_with ctx l) + case tl of + Bool -> return () + _ -> Left (TypeError ctx e (CantEQ l tl)) + + tr <- fmap Dhall.Core.normalize (type_with ctx r) + case tr of + Bool -> return () + _ -> Left (TypeError ctx e (CantEQ r tr)) + + return Bool +type_with ctx e@(BoolNE l r ) = do + tl <- fmap Dhall.Core.normalize (type_with ctx l) + case tl of + Bool -> return () + _ -> Left (TypeError ctx e (CantNE l tl)) + + tr <- fmap Dhall.Core.normalize (type_with ctx r) + case tr of + Bool -> return () + _ -> Left (TypeError ctx e (CantNE r tr)) + + return Bool +type_with ctx e@(BoolIf x y z ) = do + tx <- fmap Dhall.Core.normalize (type_with ctx x) + case tx of + Bool -> return () + _ -> Left (TypeError ctx e (InvalidPredicate x tx)) + ty <- fmap Dhall.Core.normalize (type_with ctx y ) + tty <- fmap Dhall.Core.normalize (type_with ctx ty) + case tty of + Const Type -> return () + _ -> Left (TypeError ctx e (IfBranchMustBeTerm True y ty tty)) + + tz <- fmap Dhall.Core.normalize (type_with ctx z) + ttz <- fmap Dhall.Core.normalize (type_with ctx tz) + case ttz of + Const Type -> return () + _ -> Left (TypeError ctx e (IfBranchMustBeTerm False z tz ttz)) + + if prop_equal ty tz + then return () + else Left (TypeError ctx e (IfBranchMismatch y z ty tz)) + return ty + */ + &Natural => Ok(Const(Type)), + &NaturalLit(_) => Ok(Natural), + /* +type_with _ NaturalFold = do + return + (Pi "_" Natural + (Pi "natural" (Const Type) + (Pi "succ" (Pi "_" "natural" "natural") + (Pi "zero" "natural" "natural") ) ) ) +type_with _ NaturalBuild = do + return + (Pi "_" + (Pi "natural" (Const Type) + (Pi "succ" (Pi "_" "natural" "natural") + (Pi "zero" "natural" "natural") ) ) + Natural ) + */ + &NaturalIsZero => Ok(pi("_", Natural, Bool)), + &NaturalEven => Ok(pi("_", Natural, Bool)), + &NaturalOdd => Ok(pi("_", Natural, Bool)), + /* +type_with ctx e@(NaturalPlus l r) = do + tl <- fmap Dhall.Core.normalize (type_with ctx l) + case tl of + Natural -> return () + _ -> Left (TypeError ctx e (CantAdd l tl)) + + tr <- fmap Dhall.Core.normalize (type_with ctx r) + case tr of + Natural -> return () + _ -> Left (TypeError ctx e (CantAdd r tr)) + return Natural +type_with ctx e@(NaturalTimes l r) = do + tl <- fmap Dhall.Core.normalize (type_with ctx l) + case tl of + Natural -> return () + _ -> Left (TypeError ctx e (CantMultiply l tl)) + + tr <- fmap Dhall.Core.normalize (type_with ctx r) + case tr of + Natural -> return () + _ -> Left (TypeError ctx e (CantMultiply r tr)) + return Natural + */ + &Integer => Ok(Const(Type)), + &IntegerLit(_) => Ok(Integer), + &Double => Ok(Const(Type)), + &DoubleLit(_) => Ok(Double), + &Text => Ok(Const(Type)), + &TextLit(_) => Ok(Text), + /* +type_with ctx e@(TextAppend l r ) = do + tl <- fmap Dhall.Core.normalize (type_with ctx l) + case tl of + Text -> return () + _ -> Left (TypeError ctx e (CantTextAppend l tl)) + + tr <- fmap Dhall.Core.normalize (type_with ctx r) + case tr of + Text -> return () + _ -> Left (TypeError ctx e (CantTextAppend r tr)) + return Text + */ + &List => Ok(pi("_", Const(Type), Const(Type))), + /* +type_with ctx e@(ListLit t xs ) = do + s <- fmap Dhall.Core.normalize (type_with ctx t) + case s of + Const Type -> return () + _ -> Left (TypeError ctx e (InvalidListType t)) + flip Data.Vector.imapM_ xs (\i x -> do + t' <- type_with ctx x + if prop_equal t t' + then return () + else do + let nf_t = Dhall.Core.normalize t + let nf_t' = Dhall.Core.normalize t' + Left (TypeError ctx e (InvalidListElement i nf_t x nf_t')) ) + return (App List t) + */ + &ListBuild => + Ok(pi("a", Const(Type), + pi("_", + pi("list", Const(Type), + pi("cons", pi("_", "a", pi("_", "list", "list")), + pi("nil", "list", "list"))), + app(List, "a")))), + &ListFold => + Ok(pi("a", Const(Type), + pi("_", app(List, "a"), + pi("list", Const(Type), + pi("cons", pi("_", "a", pi("_", "list", "list")), + pi("nil", "list", "list")))))), + &ListLength => + Ok(pi("a", Const(Type), pi("_", app(List, "a"), Natural))), + &ListHead => + Ok(pi("a", Const(Type), pi("_", app(List, "a"), app(Optional, "a")))), + &ListLast => + Ok(pi("a", Const(Type), pi("_", app(List, "a"), app(Optional, "a")))), + /* +type_with _ ListIndexed = do + let kts = [("index", Natural), ("value", "a")] + return + (Pi "a" (Const Type) + (Pi "_" (App List "a") + (App List (Record (Data.Map.fromList kts))) ) ) +type_with _ ListReverse = do + return (Pi "a" (Const Type) (Pi "_" (App List "a") (App List "a"))) +type_with _ Optional = do + return (Pi "_" (Const Type) (Const Type)) +type_with ctx e@(OptionalLit t xs) = do + s <- fmap Dhall.Core.normalize (type_with ctx t) + case s of + Const Type -> return () + _ -> Left (TypeError ctx e (InvalidOptionalType t)) + let n = Data.Vector.length xs + if 2 <= n + then Left (TypeError ctx e (InvalidOptionalLiteral n)) + else return () + forM_ xs (\x -> do + t' <- type_with ctx x + if prop_equal t t' + then return () + else do + let nf_t = Dhall.Core.normalize t + let nf_t' = Dhall.Core.normalize t' + Left (TypeError ctx e (InvalidOptionalElement nf_t x nf_t')) ) + return (App Optional t) +type_with _ OptionalFold = do + return + (Pi "a" (Const Type) + (Pi "_" (App Optional "a") + (Pi "optional" (Const Type) + (Pi "just" (Pi "_" "a" "optional") + (Pi "nothing" "optional" "optional") ) ) ) ) +type_with ctx e@(Record kts ) = do + let process (k, t) = do + s <- fmap Dhall.Core.normalize (type_with ctx t) + case s of + Const Type -> return () + _ -> Left (TypeError ctx e (InvalidFieldType k t)) + mapM_ process (Data.Map.toList kts) + return (Const Type) +type_with ctx e@(RecordLit kvs ) = do + let process (k, v) = do + t <- type_with ctx v + s <- fmap Dhall.Core.normalize (type_with ctx t) + case s of + Const Type -> return () + _ -> Left (TypeError ctx e (InvalidField k v)) + return (k, t) + kts <- mapM process (Data.Map.toAscList kvs) + return (Record (Data.Map.fromAscList kts)) +type_with ctx e@(Union kts ) = do + let process (k, t) = do + s <- fmap Dhall.Core.normalize (type_with ctx t) + case s of + Const Type -> return () + _ -> Left (TypeError ctx e (InvalidAlternativeType k t)) + mapM_ process (Data.Map.toList kts) + return (Const Type) +type_with ctx e@(UnionLit k v kts) = do + case Data.Map.lookup k kts of + Just _ -> Left (TypeError ctx e (DuplicateAlternative k)) + Nothing -> return () + t <- type_with ctx v + let union = Union (Data.Map.insert k t kts) + _ <- type_with ctx union + return union +type_with ctx e@(Combine kvsX kvsY) = do + tKvsX <- fmap Dhall.Core.normalize (type_with ctx kvsX) + ktsX <- case tKvsX of + Record kts -> return kts + _ -> Left (TypeError ctx e (MustCombineARecord kvsX tKvsX)) + + tKvsY <- fmap Dhall.Core.normalize (type_with ctx kvsY) + ktsY <- case tKvsY of + Record kts -> return kts + _ -> Left (TypeError ctx e (MustCombineARecord kvsY tKvsY)) + + let combineTypes ktsL ktsR = do + let ks = + Data.Set.union (Data.Map.keysSet ktsL) (Data.Map.keysSet ktsR) + kts <- forM (toList ks) (\k -> do + case (Data.Map.lookup k ktsL, Data.Map.lookup k ktsR) of + (Just (Record ktsL'), Just (Record ktsR')) -> do + t <- combineTypes ktsL' ktsR' + return (k, t) + (Nothing, Just t) -> do + return (k, t) + (Just t, Nothing) -> do + return (k, t) + _ -> do + Left (TypeError ctx e (FieldCollision k)) ) + return (Record (Data.Map.fromList kts)) + + combineTypes ktsX ktsY +type_with ctx e@(Merge kvsX kvsY t) = do + tKvsX <- fmap Dhall.Core.normalize (type_with ctx kvsX) + ktsX <- case tKvsX of + Record kts -> return kts + _ -> Left (TypeError ctx e (MustMergeARecord kvsX tKvsX)) + let ksX = Data.Map.keysSet ktsX + + tKvsY <- fmap Dhall.Core.normalize (type_with ctx kvsY) + ktsY <- case tKvsY of + Union kts -> return kts + _ -> Left (TypeError ctx e (MustMergeUnion kvsY tKvsY)) + let ksY = Data.Map.keysSet ktsY + + let diffX = Data.Set.difference ksX ksY + let diffY = Data.Set.difference ksY ksX + + if Data.Set.null diffX + then return () + else Left (TypeError ctx e (UnusedHandler diffX)) + + let process (kY, tY) = do + case Data.Map.lookup kY ktsX of + Nothing -> Left (TypeError ctx e (MissingHandler diffY)) + Just tX -> + case tX of + Pi _ tY' t' -> do + if prop_equal tY tY' + then return () + else Left (TypeError ctx e (HandlerInputTypeMismatch kY tY tY')) + if prop_equal t t' + then return () + else Left (TypeError ctx e (HandlerOutputTypeMismatch kY t t')) + _ -> Left (TypeError ctx e (HandlerNotAFunction kY tX)) + mapM_ process (Data.Map.toList ktsY) + return t +type_with ctx e@(Field r x ) = do + t <- fmap Dhall.Core.normalize (type_with ctx r) + case t of + Record kts -> + case Data.Map.lookup x kts of + Just t' -> return t' + Nothing -> Left (TypeError ctx e (MissingField x t)) + _ -> Left (TypeError ctx e (NotARecord x r t)) +type_with ctx (Note s e' ) = case type_with ctx e' of + Left (TypeError ctx2 (Note s' e'') m) -> Left (TypeError ctx2 (Note s' e'') m) + Left (TypeError ctx2 e'' m) -> Left (TypeError ctx2 (Note s e'') m) + Right r -> Right r +type_with _ (Embed p ) = do + absurd p +*/ + _ => panic!("Unimplemented typecheck case: {:?}", e), + } +} + +/// `typeOf` is the same as `type_with` with an empty context, meaning that the +/// expression must be closed (i.e. no free variables), otherwise type-checking +/// will fail. +pub fn type_of<'i, S: Clone + ::std::fmt::Debug + 'i>(e: &Expr<'i, S, X>) -> Result<Expr<'i, S, X>, TypeError<'i, S>> { + let ctx = Context::new(); + type_with(&ctx, e) //.map(|e| e.into_owned()) +} + +/// The specific type error +#[derive(Debug)] +pub enum TypeMessage<'i, S> { + UnboundVariable, + InvalidInputType(Expr<'i, S, X>), + InvalidOutputType(Expr<'i, S, X>), + NotAFunction(Expr<'i, S, X>, Expr<'i, S, X>), + TypeMismatch(Expr<'i, S, X>, Expr<'i, S, X>, Expr<'i, S, X>, Expr<'i, S, X>), + AnnotMismatch(Expr<'i, S, X>, Expr<'i, S, X>, Expr<'i, S, X>), + Untyped, + InvalidListElement(isize, Expr<'i, S, X>, Expr<'i, S, X>, Expr<'i, S, X>), + InvalidListType(Expr<'i, S, X>), + InvalidOptionalElement(Expr<'i, S, X>, Expr<'i, S, X>, Expr<'i, S, X>), + InvalidOptionalLiteral(isize), + InvalidOptionalType(Expr<'i, S, X>), + InvalidPredicate(Expr<'i, S, X>, Expr<'i, S, X>), + IfBranchMismatch(Expr<'i, S, X>, Expr<'i, S, X>, Expr<'i, S, X>, Expr<'i, S, X>), + IfBranchMustBeTerm(bool, Expr<'i, S, X>, Expr<'i, S, X>, Expr<'i, S, X>), + InvalidField(String, Expr<'i, S, X>), + InvalidFieldType(String, Expr<'i, S, X>), + InvalidAlternative(String, Expr<'i, S, X>), + InvalidAlternativeType(String, Expr<'i, S, X>), + DuplicateAlternative(String), + MustCombineARecord(Expr<'i, S, X>, Expr<'i, S, X>), + FieldCollision(String), + MustMergeARecord(Expr<'i, S, X>, Expr<'i, S, X>), + MustMergeUnion(Expr<'i, S, X>, Expr<'i, S, X>), + UnusedHandler(HashSet<String>), + MissingHandler(HashSet<String>), + HandlerInputTypeMismatch(String, Expr<'i, S, X>, Expr<'i, S, X>), + HandlerOutputTypeMismatch(String, Expr<'i, S, X>, Expr<'i, S, X>), + HandlerNotAFunction(String, Expr<'i, S, X>), + NotARecord(String, Expr<'i, S, X>, Expr<'i, S, X>), + MissingField(String, Expr<'i, S, X>), + CantAnd(Expr<'i, S, X>, Expr<'i, S, X>), + CantOr(Expr<'i, S, X>, Expr<'i, S, X>), + CantEQ(Expr<'i, S, X>, Expr<'i, S, X>), + CantNE(Expr<'i, S, X>, Expr<'i, S, X>), + CantStringAppend(Expr<'i, S, X>, Expr<'i, S, X>), + CantAdd(Expr<'i, S, X>, Expr<'i, S, X>), + CantMultiply(Expr<'i, S, X>, Expr<'i, S, X>), + NoDependentLet(Expr<'i, S, X>, Expr<'i, S, X>), + NoDependentTypes(Expr<'i, S, X>, Expr<'i, S, X>), +} + +/// A structured type error that includes context +#[derive(Debug)] +pub struct TypeError<'i, S> { + context: Context<'i, Expr<'i, S, X>>, + current: Expr<'i, S, X>, + type_message: TypeMessage<'i, S>, +} + +impl<'i, S: Clone> TypeError<'i, S> { + pub fn new(context: &Context<'i, Expr<'i, S, X>>, + current: &Expr<'i, S, X>, + type_message: TypeMessage<'i, S>) + -> Self { + TypeError { + context: context.clone(), + current: current.clone(), + type_message: type_message, + } + } +} |