summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/core.rs95
-rw-r--r--src/typecheck.rs128
2 files changed, 122 insertions, 101 deletions
diff --git a/src/core.rs b/src/core.rs
index e9f185c..0faffe5 100644
--- a/src/core.rs
+++ b/src/core.rs
@@ -365,7 +365,7 @@ fn add_ui(u: usize, i: isize) -> usize {
pub fn shift<'i, S, T, A>(d: isize, v: V, e: &Expr<'i, S, A>) -> Expr<'i, T, A>
where S: ::std::fmt::Debug,
T: ::std::fmt::Debug,
- A: ::std::fmt::Debug,
+ A: Clone + ::std::fmt::Debug,
{
use Expr::*;
let V(x, n) = v;
@@ -388,20 +388,21 @@ pub fn shift<'i, S, T, A>(d: isize, v: V, e: &Expr<'i, S, A>) -> Expr<'i, T, A>
pi(x2, tA2, tB2)
}
&App(ref f, ref a) => app(shift(d, v, f), shift(d, v, a)),
+ &Let(f, ref mt, ref r, ref e) => {
+ let n2 = if x == f { n + 1 } else { n };
+ let e2 = shift(d, V(x, n2), e);
+ let mt2 = mt.as_ref().map(|t| bx(shift(d, V(x, n), t)));
+ let r2 = shift(d, V(x, n), r);
+ Let(f, mt2, bx(r2), bx(e2))
+ }
/*
-shift d (V x n) (Let f mt r e) = Let f mt' r' e'
- where
- e' = shift d (V x n') e
- where
- n' = if x == f then n + 1 else n
-
- mt' = fmap (shift d (V x n)) mt
- r' = shift d (V x n) r
shift d v (Annot a b) = Annot a' b'
where
a' = shift d v a
b' = shift d v b
*/
+ &BuiltinType(t) => BuiltinType(t),
+ &BuiltinValue(v) => BuiltinValue(v),
&BoolLit(a) => BoolLit(a),
&BoolAnd(ref a, ref b) => BoolAnd(bx(shift(d, v, a)), bx(shift(d, v, b))),
/*
@@ -422,26 +423,16 @@ shift d v (BoolIf a b c) = BoolIf a' b' c'
a' = shift d v a
b' = shift d v b
c' = shift d v c
-shift _ _ Natural = Natural
-shift _ _ (NaturalLit a) = NaturalLit a
-shift _ _ NaturalFold = NaturalFold
-shift _ _ NaturalBuild = NaturalBuild
-shift _ _ NaturalIsZero = NaturalIsZero
-shift _ _ NaturalEven = NaturalEven
-shift _ _ NaturalOdd = NaturalOdd
-shift d v (NaturalPlus a b) = NaturalPlus a' b'
- where
- a' = shift d v a
- b' = shift d v b
+*/
+ &NaturalLit(a) => NaturalLit(a),
+ &NaturalPlus(ref a, ref b) => NaturalPlus(bx(shift(d, v, a)), bx(shift(d, v, b))),
+/*
shift d v (NaturalTimes a b) = NaturalTimes a' b'
where
a' = shift d v a
b' = shift d v b
-shift _ _ Integer = Integer
shift _ _ (IntegerLit a) = IntegerLit a
-shift _ _ Double = Double
shift _ _ (DoubleLit a) = DoubleLit a
-shift _ _ Text = Text
shift _ _ (TextLit a) = TextLit a
shift d v (TextAppend a b) = TextAppend a' b'
where
@@ -451,28 +442,18 @@ shift d v (ListLit a b) = ListLit a' b'
where
a' = shift d v a
b' = fmap (shift d v) b
-shift _ _ ListBuild = ListBuild
-shift _ _ ListFold = ListFold
-shift _ _ ListLength = ListLength
-shift _ _ ListHead = ListHead
-shift _ _ ListLast = ListLast
-shift _ _ ListIndexed = ListIndexed
-shift _ _ ListReverse = ListReverse
-shift _ _ Optional = Optional
shift d v (OptionalLit a b) = OptionalLit a' b'
where
a' = shift d v a
b' = fmap (shift d v) b
-shift _ _ OptionalFold = OptionalFold
-shift d v (Record a) = Record a'
- where
- a' = fmap (shift d v) a
-shift d v (RecordLit a) = RecordLit a'
- where
- a' = fmap (shift d v) a
-shift d v (Union a) = Union a'
- where
- a' = fmap (shift d v) a
+*/
+ &Record(ref a) =>
+ Record(a.iter().map(|(&k, val)| (k, shift(d, v, val))).collect()),
+ &RecordLit(ref a) =>
+ RecordLit(a.iter().map(|(&k, val)| (k, shift(d, v, val))).collect()),
+ &Union(ref a) =>
+ Union(a.iter().map(|(&k, val)| (k, shift(d, v, val))).collect()),
+ /*
shift d v (UnionLit a b c) = UnionLit a b' c'
where
b' = shift d v b
@@ -486,18 +467,12 @@ shift d v (Merge a b c) = Merge a' b' c'
a' = shift d v a
b' = shift d v b
c' = shift d v c
-shift d v (Field a b) = Field a' b
- where
- a' = shift d v a
-shift d v (Note _ b) = b'
- where
- b' = shift d v b
--- The Dhall compiler enforces that all embedded values are closed expressions
--- and `shift` does nothing to a closed expression
-shift _ _ (Embed p) = Embed p
-*/
- &BuiltinType(t) => BuiltinType(t),
- &BuiltinValue(v) => BuiltinValue(v),
+ */
+ &Field(ref a, b) => Field(bx(shift(d, v, a)), b),
+ &Note(_, ref b) => shift(d, v, b),
+ // The Dhall compiler enforces that all embedded values are closed expressions
+ // and `shift` does nothing to a closed expression
+ &Embed(ref p) => Embed(p.clone()),
e => panic!("Unimplemented shift case: {:?}", (d, v, e)),
}
}
@@ -535,13 +510,22 @@ pub fn subst<'i, S, T, A>(v: V<'i>, e: &Expr<'i, S, A>, b: &Expr<'i, T, A>) -> E
app(f2, a2)
}
&Var(v2) => if v == v2 { e.clone() } else { Var(v2) },
+ &Let(f, ref mt, ref r, ref b) => {
+ let n2 = if x == f { n + 1 } else { n };
+ let b2 = subst(V(x, n2), &shift(1, V(f, 0), e), b);
+ let mt2 = mt.as_ref().map(|t| bx(subst(V(x, n), e, t)));
+ let r2 = subst(V(x, n), e, r);
+ Let(f, mt2, bx(r2), bx(b2))
+ }
+ &BuiltinType(t) => BuiltinType(t),
+ &BuiltinValue(v) => BuiltinValue(v),
&ListLit(ref a, ref b) => {
let a2 = subst(v, e, a);
let b2 = b.iter().map(|be| subst(v, e, be)).collect();
ListLit(bx(a2), b2)
}
- &BuiltinType(t) => BuiltinType(t),
- &BuiltinValue(v) => BuiltinValue(v),
+ &Record(ref kts) => Record(kts.iter().map(|(&k, t)| (k, subst(v, e, t))).collect()),
+ &RecordLit(ref kvs) => Record(kvs.iter().map(|(&k, val)| (k, subst(v, e, val))).collect()),
b => panic!("Unimplemented subst case: {:?}", (v, e, b)),
}
}
@@ -680,6 +664,7 @@ pub fn normalize<S, T, A>(e: Expr<S, A>) -> Expr<T, A>
let es2 = es.into_iter().map(normalize).collect();
ListLit(bx(t2), es2)
}
+ Record(kts) => Record(kts.into_iter().map(|(k, t)| (k, normalize(t))).collect()),
BuiltinType(t) => BuiltinType(t),
BuiltinValue(v) => BuiltinValue(v),
_ => panic!("Unimplemented normalize case: {:?}", e),
diff --git a/src/typecheck.rs b/src/typecheck.rs
index f6e6c7f..6fa9747 100644
--- a/src/typecheck.rs
+++ b/src/typecheck.rs
@@ -70,19 +70,40 @@ fn prop_equal<S, T>(eL0: &Expr<S, X>, eR0: &Expr<T, X>) -> bool
(&App(ref fL, ref aL), &App(ref fR, ref aR)) =>
if go(ctx, fL, fR) { go(ctx, aL, aR) } else { false },
(&BuiltinType(a), &BuiltinType(b)) => a == b,
- (&Record(ref _ktsL0), &Record(ref _ktsR0)) => unimplemented!(),
- /*
- let loop ((kL, tL):ktsL) ((kR, tR):ktsR)
+ (&Record(ref ktsL0), &Record(ref ktsR0)) => {
+ if ktsL0.len() != ktsR0.len() {
+ return false;
+ }
+ /*
+ let go ((kL, tL):ktsL) ((kR, tR):ktsR)
| kL == kR = do
b <- go tL tR
if b
- then loop ktsL ktsR
+ then go ktsL ktsR
else return False
- loop [] [] = return True
- loop _ _ = return False
- loop (Data.Map.toList ktsL0) (Data.Map.toList ktsR0)
- */
- (&Union(ref _ktsL0), &Union(ref _ktsR0)) => unimplemented!(),
+ go [] [] = return True
+ go _ _ = return False
+ */
+ /*
+ for ((kL, tL), (kR, tR)) in ktsL0.iter().zip(ktsR0.iter()) {
+ if kL == kR {
+ if !go(ctx, tL, tR) {
+ return false;
+ }
+ } else {
+ return false;
+ }
+ }
+ true
+ */
+ !ktsL0.iter().zip(ktsR0.iter()).any(|((kL, tL), (kR, tR))| {
+ kL != kR || !go(ctx, tL, tR)
+ })
+ }
+ (&Union(ref ktsL0), &Union(ref ktsR0)) => {
+ if ktsL0.len() != ktsR0.len() {
+ return false;
+ }
/*
let loop ((kL, tL):ktsL) ((kR, tR):ktsR)
| kL == kR = do
@@ -94,6 +115,10 @@ fn prop_equal<S, T>(eL0: &Expr<S, X>, eR0: &Expr<T, X>) -> bool
loop _ _ = return False
loop (Data.Map.toList ktsL0) (Data.Map.toList ktsR0)
*/
+ !ktsL0.iter().zip(ktsR0.iter()).any(|((kL, tL), (kR, tR))| {
+ kL != kR || !go(ctx, tL, tR)
+ })
+ }
(_, _) => false,
}
}
@@ -313,18 +338,21 @@ type_with _ NaturalBuild = do
&BuiltinValue(NaturalIsZero) => Ok(pi("_", Natural, Bool)),
&BuiltinValue(NaturalEven) => Ok(pi("_", Natural, Bool)),
&BuiltinValue(NaturalOdd) => Ok(pi("_", Natural, Bool)),
- /*
-type_with ctx e@(NaturalPlus l r) = do
- tl <- fmap Dhall.Core.normalize (type_with ctx l)
- case tl of
- Natural -> return ()
- _ -> Left (TypeError ctx e (CantAdd l tl))
+ &NaturalPlus(ref l, ref r) => {
+ let tl = normalize(type_with(ctx, l)?);
+ match tl {
+ BuiltinType(Natural) => {}
+ _ => return Err(TypeError::new(ctx, e, CantAdd((**l).clone(), tl))),
+ }
- tr <- fmap Dhall.Core.normalize (type_with ctx r)
- case tr of
- Natural -> return ()
- _ -> Left (TypeError ctx e (CantAdd r tr))
- return Natural
+ let tr = normalize(type_with(ctx, r)?);
+ match tr {
+ BuiltinType(Natural) => {}
+ _ => return Err(TypeError::new(ctx, e, CantAdd((**r).clone(), tr))),
+ }
+ Ok(BuiltinType(Natural))
+ }
+ /*
type_with ctx e@(NaturalTimes l r) = do
tl <- fmap Dhall.Core.normalize (type_with ctx l)
case tl of
@@ -424,24 +452,30 @@ type_with _ OptionalFold = do
(Pi "optional" (Const Type)
(Pi "just" (Pi "_" "a" "optional")
(Pi "nothing" "optional" "optional") ) ) ) )
-type_with ctx e@(Record kts ) = do
- let process (k, t) = do
- s <- fmap Dhall.Core.normalize (type_with ctx t)
- case s of
- Const Type -> return ()
- _ -> Left (TypeError ctx e (InvalidFieldType k t))
- mapM_ process (Data.Map.toList kts)
- return (Const Type)
-type_with ctx e@(RecordLit kvs ) = do
- let process (k, v) = do
- t <- type_with ctx v
- s <- fmap Dhall.Core.normalize (type_with ctx t)
- case s of
- Const Type -> return ()
- _ -> Left (TypeError ctx e (InvalidField k v))
- return (k, t)
- kts <- mapM process (Data.Map.toAscList kvs)
- return (Record (Data.Map.fromAscList kts))
+*/
+ &Record(ref kts) => {
+ for (k, t) in kts {
+ let s = normalize::<S, S, X>(type_with(ctx, t)?);
+ match s {
+ Const(Type) => {}
+ _ => return Err(TypeError::new(ctx, e, InvalidFieldType((*k).to_owned(), (*t).clone()))),
+ }
+ }
+ Ok(Const(Type))
+ }
+ &RecordLit(ref kvs) => {
+ let kts = kvs.iter().map(|(&k, v)| {
+ let t = type_with(ctx, v)?;
+ let s = normalize::<S, S, X>(type_with(ctx, &t)?);
+ match s {
+ Const(Type) => {}
+ _ => return Err(TypeError::new(ctx, e, InvalidField((*k).to_owned(), (*v).clone()))),
+ }
+ Ok((k, t))
+ }).collect::<Result<_, _>>()?;
+ Ok(Record(kts))
+ }
+/*
type_with ctx e@(Union kts ) = do
let process (k, t) = do
s <- fmap Dhall.Core.normalize (type_with ctx t)
@@ -521,14 +555,16 @@ type_with ctx e@(Merge kvsX kvsY t) = do
_ -> Left (TypeError ctx e (HandlerNotAFunction kY tX))
mapM_ process (Data.Map.toList ktsY)
return t
-type_with ctx e@(Field r x ) = do
- t <- fmap Dhall.Core.normalize (type_with ctx r)
- case t of
- Record kts ->
- case Data.Map.lookup x kts of
- Just t' -> return t'
- Nothing -> Left (TypeError ctx e (MissingField x t))
- _ -> Left (TypeError ctx e (NotARecord x r t))
+ */
+ &Field(ref r, x) => {
+ let t = normalize(type_with(ctx, r)?);
+ match &t {
+ &Record(ref kts) =>
+ kts.get(x).cloned().ok_or_else(|| TypeError::new(ctx, e, MissingField(x.to_owned(), t.clone()))),
+ _ => Err(TypeError::new(ctx, e, NotARecord(x.to_owned(), (**r).clone(), t.clone()))),
+ }
+ }
+ /*
type_with ctx (Note s e' ) = case type_with ctx e' of
Left (TypeError ctx2 (Note s' e'') m) -> Left (TypeError ctx2 (Note s' e'') m)
Left (TypeError ctx2 e'' m) -> Left (TypeError ctx2 (Note s e'') m)