From f112145814ed8243904d97c92a15bbdb7053d1a0 Mon Sep 17 00:00:00 2001 From: Nadrieril Date: Mon, 18 Mar 2019 00:21:41 +0100 Subject: Considerably simplify subst, shift and typechecking --- dhall/src/main.rs | 2 +- dhall/src/normalize.rs | 25 ++-- dhall/src/typecheck.rs | 297 +++++++++++---------------------------- dhall/tests/common/mod.rs | 5 +- dhall_core/src/core.rs | 350 +++++++++++++--------------------------------- 5 files changed, 188 insertions(+), 491 deletions(-) diff --git a/dhall/src/main.rs b/dhall/src/main.rs index 41953e3..810d789 100644 --- a/dhall/src/main.rs +++ b/dhall/src/main.rs @@ -90,5 +90,5 @@ fn main() { println!("{}", type_expr); println!(); - println!("{}", normalize::<_, X, _>(expr)); + println!("{}", normalize(expr)); } diff --git a/dhall/src/normalize.rs b/dhall/src/normalize.rs index af30e3b..8438670 100644 --- a/dhall/src/normalize.rs +++ b/dhall/src/normalize.rs @@ -163,10 +163,9 @@ where let just = Rc::clone(&args[3]); return normalize_whnf(&dhall_expr!(just x)); } - ( - OptionalFold, - [_, OptionalLit(_, None), _, _, _], - ) => return Rc::clone(&args[4]), + (OptionalFold, [_, OptionalLit(_, None), _, _, _]) => { + return Rc::clone(&args[4]) + } // // fold/build fusion // (OptionalFold, [_, App(box Builtin(OptionalBuild), [_, x, rest..]), rest..]) => { // normalize_whnf(&App(bx(x.clone()), rest.to_vec())) @@ -237,9 +236,7 @@ where NaturalLit(x * y) } // TODO: interpolation - (TextAppend, TextLit(x), TextLit(y)) => { - TextLit(x + y) - } + (TextAppend, TextLit(x), TextLit(y)) => TextLit(x + y), (ListAppend, ListLit(t1, xs), ListLit(t2, ys)) => { let t1: Option> = t1.as_ref().map(Rc::clone); let t2: Option> = t2.as_ref().map(Rc::clone); @@ -279,16 +276,10 @@ where /// However, `normalize` will not fail if the expression is ill-typed and will /// leave ill-typed sub-expressions unevaluated. /// -pub fn normalize(e: SubExpr) -> SubExpr +pub fn normalize(e: SubExpr) -> SubExpr where - S: Clone + fmt::Debug, - T: Clone + fmt::Debug, - A: Clone + fmt::Debug, + S: fmt::Debug, + A: fmt::Debug, { - rc(normalize_whnf(&e).map_shallow_rc( - |x| normalize(Rc::clone(x)), - |_| unreachable!(), - |x| x.clone(), - |x| x.clone(), - )) + map_subexpr_rc(&normalize_whnf(&e), |x| normalize(Rc::clone(x))) } diff --git a/dhall/src/typecheck.rs b/dhall/src/typecheck.rs index b507e52..15629a9 100644 --- a/dhall/src/typecheck.rs +++ b/dhall/src/typecheck.rs @@ -15,7 +15,7 @@ use dhall_generator::dhall_expr; use self::TypeMessage::*; -fn axiom(c: core::Const) -> Result> { +fn axiom(c: core::Const) -> Result> { match c { Type => Ok(Kind), Kind => Err(TypeError::new(&Context::new(), rc(Const(Kind)), Untyped)), @@ -49,8 +49,8 @@ fn match_vars(vl: &V, vr: &V, ctx: &[(Label, Label)]) -> bool { fn prop_equal(eL0: Rc>, eR0: Rc>) -> bool where - S: Clone + ::std::fmt::Debug, - T: Clone + ::std::fmt::Debug, + S: ::std::fmt::Debug, + T: ::std::fmt::Debug, { fn go( ctx: &mut Vec<(Label, Label)>, @@ -58,8 +58,8 @@ where er: Rc>, ) -> bool where - S: Clone + ::std::fmt::Debug, - T: Clone + ::std::fmt::Debug, + S: ::std::fmt::Debug, + T: ::std::fmt::Debug, { match (el.as_ref(), er.as_ref()) { (&Const(Type), &Const(Type)) | (&Const(Kind), &Const(Kind)) => true, @@ -153,7 +153,7 @@ fn op2_type( r: &Rc>, ) -> Result, TypeError> where - S: Clone + ::std::fmt::Debug, + S: ::std::fmt::Debug, EF: FnOnce(Rc>, Rc>) -> TypeMessage, { let tl = normalize(type_with(ctx, l.clone())?); @@ -182,17 +182,18 @@ pub fn type_with( e: Rc>, ) -> Result>, TypeError> where - S: Clone + ::std::fmt::Debug, + S: ::std::fmt::Debug, { use dhall_core::BinOp::*; use dhall_core::Expr; + let mkerr = |msg: TypeMessage<_>| TypeError::new(ctx, e.clone(), msg); match *e { Const(c) => axiom(c).map(Const), Var(V(ref x, n)) => { return ctx .lookup(x, n) .cloned() - .ok_or_else(|| TypeError::new(ctx, e.clone(), UnboundVariable)) + .ok_or_else(|| mkerr(UnboundVariable)) } Lam(ref x, ref tA, ref b) => { let ctx2 = ctx @@ -204,15 +205,11 @@ where return Ok(p); } Pi(ref x, ref tA, ref tB) => { - let tA2 = normalize::(type_with(ctx, tA.clone())?); + let tA2 = normalize(type_with(ctx, tA.clone())?); let kA = match tA2.as_ref() { Const(k) => k, _ => { - return Err(TypeError::new( - ctx, - e.clone(), - InvalidInputType(tA.clone()), - )); + return Err(mkerr(InvalidInputType(tA.clone()))); } }; @@ -232,29 +229,26 @@ where }; match rule(kA, kB) { - Err(()) => Err(TypeError::new( - ctx, - e.clone(), - NoDependentTypes(tA.clone(), tB), - )), + Err(()) => Err(mkerr(NoDependentTypes(tA.clone(), tB))), Ok(k) => Ok(Const(k)), } } App(ref f, ref args) => { - let (a, args) = match args.split_last() { + // Recurse on args + let (a, tf) = match args.split_last() { None => return type_with(ctx, f.clone()), - Some(x) => x, + Some((a, args)) => ( + a, + normalize(type_with( + ctx, + rc(App(f.clone(), args.to_vec())), + )?), + ), }; - let tf = - normalize(type_with(ctx, rc(App(f.clone(), args.to_vec())))?); let (x, tA, tB) = match tf.as_ref() { Pi(x, tA, tB) => (x, tA, tB), _ => { - return Err(TypeError::new( - ctx, - e.clone(), - NotAFunction(f.clone(), tf), - )); + return Err(mkerr(NotAFunction(f.clone(), tf))); } }; let tA2 = type_with(ctx, a.clone())?; @@ -267,62 +261,38 @@ where } else { let nf_A = normalize(tA.clone()); let nf_A2 = normalize(tA2); - Err(TypeError::new( - ctx, - e.clone(), - TypeMismatch(f.clone(), nf_A, a.clone(), nf_A2), - )) + Err(mkerr(TypeMismatch(f.clone(), nf_A, a.clone(), nf_A2))) } } Let(ref f, ref mt, ref r, ref b) => { let tR = type_with(ctx, r.clone())?; - let ttR = normalize::(type_with(ctx, tR.clone())?); + let ttR = normalize(type_with(ctx, tR.clone())?); let kR = match ttR.as_ref() { Const(k) => k, // Don't bother to provide a `let`-specific version of this error // message because this should never happen anyway - _ => { - return Err(TypeError::new( - ctx, - e.clone(), - InvalidInputType(tR), - )) - } + _ => return Err(mkerr(InvalidInputType(tR))), }; let ctx2 = ctx.insert(f.clone(), tR.clone()); let tB = type_with(&ctx2, b.clone())?; - let ttB = normalize::(type_with(ctx, tB.clone())?); + let ttB = normalize(type_with(ctx, tB.clone())?); let kB = match ttB.as_ref() { Const(k) => k, // Don't bother to provide a `let`-specific version of this error // message because this should never happen anyway - _ => { - return Err(TypeError::new( - ctx, - e.clone(), - InvalidOutputType(tB), - )) - } + _ => return Err(mkerr(InvalidOutputType(tB))), }; if let Err(()) = rule(kR, kB) { - return Err(TypeError::new( - ctx, - e.clone(), - NoDependentLet(tR, tB), - )); + return Err(mkerr(NoDependentLet(tR, tB))); } if let Some(ref t) = *mt { let nf_t = normalize(t.clone()); let nf_tR = normalize(tR); if !prop_equal(nf_tR.clone(), nf_t.clone()) { - return Err(TypeError::new( - ctx, - e.clone(), - AnnotMismatch(r.clone(), nf_t, nf_tR), - )); + return Err(mkerr(AnnotMismatch(r.clone(), nf_t, nf_tR))); } } @@ -338,11 +308,7 @@ where } else { let nf_t = normalize(t.clone()); let nf_t2 = normalize(t2); - Err(TypeError::new( - ctx, - e.clone(), - AnnotMismatch(x.clone(), nf_t, nf_t2), - )) + Err(mkerr(AnnotMismatch(x.clone(), nf_t, nf_t2))) } } BoolLit(_) => Ok(Builtin(Bool)), @@ -363,11 +329,7 @@ where match tx.as_ref() { Builtin(Bool) => {} _ => { - return Err(TypeError::new( - ctx, - e.clone(), - InvalidPredicate(x.clone(), tx), - )); + return Err(mkerr(InvalidPredicate(x.clone(), tx))); } } let ty = normalize(type_with(ctx, y.clone())?); @@ -375,11 +337,12 @@ where match tty.as_ref() { Const(Type) => {} _ => { - return Err(TypeError::new( - ctx, - e.clone(), - IfBranchMustBeTerm(true, y.clone(), ty, tty), - )); + return Err(mkerr(IfBranchMustBeTerm( + true, + y.clone(), + ty, + tty, + ))); } } @@ -388,20 +351,22 @@ where match ttz.as_ref() { Const(Type) => {} _ => { - return Err(TypeError::new( - ctx, - e.clone(), - IfBranchMustBeTerm(false, z.clone(), tz, ttz), - )); + return Err(mkerr(IfBranchMustBeTerm( + false, + z.clone(), + tz, + ttz, + ))); } } if !prop_equal(ty.clone(), tz.clone()) { - return Err(TypeError::new( - ctx, - e.clone(), - IfBranchMismatch(y.clone(), z.clone(), ty, tz), - )); + return Err(mkerr(IfBranchMismatch( + y.clone(), + z.clone(), + ty, + tz, + ))); } return Ok(ty); } @@ -451,27 +416,22 @@ where } }; - let s = normalize::<_, S, _>(type_with(ctx, t.clone())?); + let s = normalize(type_with(ctx, t.clone())?); match s.as_ref() { Const(Type) => {} - _ => { - return Err(TypeError::new( - ctx, - e.clone(), - InvalidListType(t), - )) - } + _ => return Err(mkerr(InvalidListType(t))), } for (i, x) in iter { let t2 = type_with(ctx, x.clone())?; if !prop_equal(t.clone(), t2.clone()) { let nf_t = normalize(t); let nf_t2 = normalize(t2); - return Err(TypeError::new( - ctx, - e.clone(), - InvalidListElement(i, nf_t, x.clone(), nf_t2), - )); + return Err(mkerr(InvalidListElement( + i, + nf_t, + x.clone(), + nf_t2, + ))); } } return Ok(dhall_expr!(List t)); @@ -524,15 +484,11 @@ where } }; - let s = normalize::<_, S, _>(type_with(ctx, t.clone())?); + let s = normalize(type_with(ctx, t.clone())?); match s.as_ref() { Const(Type) => {} _ => { - return Err(TypeError::new( - ctx, - e.clone(), - InvalidOptionalType(t), - )); + return Err(mkerr(InvalidOptionalType(t))); } } for x in iter { @@ -540,11 +496,11 @@ where if !prop_equal(t.clone(), t2.clone()) { let nf_t = normalize(t); let nf_t2 = normalize(t2); - return Err(TypeError::new( - ctx, - e.clone(), - InvalidOptionalElement(nf_t, x.clone(), nf_t2), - )); + return Err(mkerr(InvalidOptionalElement( + nf_t, + x.clone(), + nf_t2, + ))); } } return Ok(dhall_expr!(Optional t)); @@ -568,15 +524,14 @@ where | Builtin(Double) | Builtin(Text) => Ok(Const(Type)), Record(ref kts) => { for (k, t) in kts { - let s = normalize::(type_with(ctx, t.clone())?); + let s = normalize(type_with(ctx, t.clone())?); match s.as_ref() { Const(Type) => {} _ => { - return Err(TypeError::new( - ctx, - e.clone(), - InvalidFieldType(k.clone(), t.clone()), - )); + return Err(mkerr(InvalidFieldType( + k.clone(), + t.clone(), + ))); } } } @@ -587,15 +542,14 @@ where .iter() .map(|(k, v)| { let t = type_with(ctx, v.clone())?; - let s = normalize::(type_with(ctx, t.clone())?); + let s = normalize(type_with(ctx, t.clone())?); match s.as_ref() { Const(Type) => {} _ => { - return Err(TypeError::new( - ctx, - e.clone(), - InvalidField(k.clone(), v.clone()), - )); + return Err(mkerr(InvalidField( + k.clone(), + v.clone(), + ))); } } Ok((k.clone(), t)) @@ -603,112 +557,17 @@ where .collect::>()?; Ok(Record(kts)) } - /* - type_with ctx e@(Union kts ) = do - let process (k, t) = do - s <- fmap Dhall.Core.normalize (type_with ctx t) - case s of - Const Type -> return () - _ -> Left (TypeError ctx e (InvalidAlternativeType k t)) - mapM_ process (Data.Map.toList kts) - return (Const Type) - type_with ctx e@(UnionLit k v kts) = do - case Data.Map.lookup k kts of - Just _ -> Left (TypeError ctx e (DuplicateAlternative k)) - Nothing -> return () - t <- type_with ctx v - let union = Union (Data.Map.insert k t kts) - _ <- type_with ctx union - return union - type_with ctx e@(Combine kvsX kvsY) = do - tKvsX <- fmap Dhall.Core.normalize (type_with ctx kvsX) - ktsX <- case tKvsX of - Record kts -> return kts - _ -> Left (TypeError ctx e (MustCombineARecord kvsX tKvsX)) - - tKvsY <- fmap Dhall.Core.normalize (type_with ctx kvsY) - ktsY <- case tKvsY of - Record kts -> return kts - _ -> Left (TypeError ctx e (MustCombineARecord kvsY tKvsY)) - - let combineTypes ktsL ktsR = do - let ks = - Data.Set.union (Data.Map.keysSet ktsL) (Data.Map.keysSet ktsR) - kts <- forM (toList ks) (\k -> do - case (Data.Map.lookup k ktsL, Data.Map.lookup k ktsR) of - (Just (Record ktsL'), Just (Record ktsR')) -> do - t <- combineTypes ktsL' ktsR' - return (k, t) - (Nothing, Just t) -> do - return (k, t) - (Just t, Nothing) -> do - return (k, t) - _ -> do - Left (TypeError ctx e (FieldCollision k)) ) - return (Record (Data.Map.fromList kts)) - - combineTypes ktsX ktsY - type_with ctx e@(Merge kvsX kvsY t) = do - tKvsX <- fmap Dhall.Core.normalize (type_with ctx kvsX) - ktsX <- case tKvsX of - Record kts -> return kts - _ -> Left (TypeError ctx e (MustMergeARecord kvsX tKvsX)) - let ksX = Data.Map.keysSet ktsX - - tKvsY <- fmap Dhall.Core.normalize (type_with ctx kvsY) - ktsY <- case tKvsY of - Union kts -> return kts - _ -> Left (TypeError ctx e (MustMergeUnion kvsY tKvsY)) - let ksY = Data.Map.keysSet ktsY - - let diffX = Data.Set.difference ksX ksY - let diffY = Data.Set.difference ksY ksX - - if Data.Set.null diffX - then return () - else Left (TypeError ctx e (UnusedHandler diffX)) - - let process (kY, tY) = do - case Data.Map.lookup kY ktsX of - Nothing -> Left (TypeError ctx e (MissingHandler diffY)) - Just tX -> - case tX of - Pi _ tY' t' -> do - if prop_equal tY tY' - then return () - else Left (TypeError ctx e (HandlerInputTypeMismatch kY tY tY')) - if prop_equal t t' - then return () - else Left (TypeError ctx e (HandlerOutputTypeMismatch kY t t')) - _ -> Left (TypeError ctx e (HandlerNotAFunction kY tX)) - mapM_ process (Data.Map.toList ktsY) - return t - */ Field(ref r, ref x) => { let t = normalize(type_with(ctx, r.clone())?); match t.as_ref() { Record(ref kts) => { return kts.get(x).cloned().ok_or_else(|| { - TypeError::new( - ctx, - e.clone(), - MissingField(x.clone(), t.clone()), - ) + mkerr(MissingField(x.clone(), t.clone())) }) } - _ => Err(TypeError::new( - ctx, - e.clone(), - NotARecord(x.clone(), r.clone(), t.clone()), - )), + _ => Err(mkerr(NotARecord(x.clone(), r.clone(), t.clone()))), } } - /* - type_with ctx (Note s e' ) = case type_with ctx e' of - Left (TypeError ctx2 (Note s' e'') m) -> Left (TypeError ctx2 (Note s' e'') m) - Left (TypeError ctx2 e'' m) -> Left (TypeError ctx2 (Note s e'') m) - Right r -> Right r - */ Embed(p) => match p {}, _ => panic!("Unimplemented typecheck case: {:?}", e), } @@ -718,7 +577,7 @@ where /// `typeOf` is the same as `type_with` with an empty context, meaning that the /// expression must be closed (i.e. no free variables), otherwise type-checking /// will fail. -pub fn type_of( +pub fn type_of( e: Rc>, ) -> Result>, TypeError> { let ctx = Context::new(); @@ -788,7 +647,7 @@ pub struct TypeError { pub type_message: TypeMessage, } -impl TypeError { +impl TypeError { pub fn new( context: &Context>>, current: Rc>, diff --git a/dhall/tests/common/mod.rs b/dhall/tests/common/mod.rs index 4d64fea..0589a4a 100644 --- a/dhall/tests/common/mod.rs +++ b/dhall/tests/common/mod.rs @@ -100,10 +100,7 @@ pub fn run_test(base_path: &str, feature: Feature) { let expr = rc(read_dhall_file(&expr_file_path).unwrap()); let expected = rc(read_dhall_file(&expected_file_path).unwrap()); - assert_eq_display!( - normalize::<_, X, _>(expr), - normalize::<_, X, _>(expected) - ); + assert_eq_display!(normalize(expr), normalize(expected)); } TypecheckFailure => { let file_path = base_path + ".dhall"; diff --git a/dhall_core/src/core.rs b/dhall_core/src/core.rs index 51356cf..a3236eb 100644 --- a/dhall_core/src/core.rs +++ b/dhall_core/src/core.rs @@ -215,9 +215,7 @@ impl InterpolatedText { } } - pub fn iter( - &self, - ) -> impl Iterator> { + pub fn iter(&self) -> impl Iterator> { use std::iter::once; once(InterpolatedTextContents::Text(self.head.as_ref())).chain( self.tail.iter().flat_map(|(e, s)| { @@ -228,8 +226,7 @@ impl InterpolatedText { } } -impl<'a, N: 'a, E: 'a> - FromIterator> +impl<'a, N: 'a, E: 'a> FromIterator> for InterpolatedText { fn from_iter(iter: T) -> Self @@ -259,9 +256,7 @@ impl<'a, N: 'a, E: 'a> impl Add for &InterpolatedText { type Output = InterpolatedText; fn add(self, rhs: &InterpolatedText) -> Self::Output { - self.iter() - .chain(rhs.iter()) - .collect() + self.iter().chain(rhs.iter()).collect() } } @@ -418,34 +413,16 @@ impl Expr { F3: FnOnce(&A) -> B, F4: Fn(&Label) -> Label, { - map_shallow( + map_subexpr( self, |x| rc(map_expr(x.as_ref())), map_note, map_embed, map_label, + |_, x| rc(map_expr(x.as_ref())), ) } - pub fn map_shallow_rc( - &self, - map_expr: F1, - map_note: F2, - map_embed: F3, - map_label: F4, - ) -> Expr - where - A: Clone, - T: Clone, - S: Clone, - F1: Fn(&SubExpr) -> SubExpr, - F2: FnOnce(&S) -> T, - F3: FnOnce(&A) -> B, - F4: Fn(&Label) -> Label, - { - map_shallow(self, map_expr, map_note, map_embed, map_label) - } - pub fn map_embed(&self, map_embed: &F) -> Expr where A: Clone, @@ -885,102 +862,92 @@ fn add_ui(u: usize, i: isize) -> usize { } /// Map over the immediate children of the passed Expr -pub fn map_shallow( +pub fn map_subexpr( e: &Expr, map: F1, map_note: F2, map_embed: F3, map_label: F4, + map_under_binder: F5, ) -> Expr where - A: Clone, - S: Clone, - T: Clone, F1: Fn(&SubExpr) -> SubExpr, F2: FnOnce(&S) -> T, F3: FnOnce(&A) -> B, F4: Fn(&Label) -> Label, + F5: FnOnce(&Label, &SubExpr) -> SubExpr, { use crate::Expr::*; let map = ↦ let opt = |x: &Option<_>| x.as_ref().map(&map); - match *e { - Const(k) => Const(k), - Var(V(ref x, n)) => Var(V(map_label(x), n)), - Lam(ref x, ref t, ref b) => Lam(map_label(x), map(t), map(b)), - Pi(ref x, ref t, ref b) => Pi(map_label(x), map(t), map(b)), - App(ref f, ref args) => { - let args = args.iter().map(map).collect(); - App(map(f), args) - } - Let(ref l, ref t, ref a, ref b) => { - Let(map_label(l), opt(t), map(a), map(b)) - } - Annot(ref x, ref t) => Annot(map(x), map(t)), - Builtin(v) => Builtin(v), - BoolLit(b) => BoolLit(b), - BoolIf(ref b, ref t, ref f) => BoolIf(map(b), map(t), map(f)), - NaturalLit(n) => NaturalLit(n), - IntegerLit(n) => IntegerLit(n), - DoubleLit(n) => DoubleLit(n), - TextLit(ref t) => TextLit(t.map(&map)), - BinOp(o, ref x, ref y) => BinOp(o, map(x), map(y)), - ListLit(ref t, ref es) => { - let es = es.iter().map(&map).collect(); - ListLit(opt(t), es) - } - OptionalLit(ref t, ref es) => OptionalLit(opt(t), opt(es)), - Record(ref kts) => { - Record(map_record_value_and_keys(kts, map, map_label)) - } - RecordLit(ref kvs) => { - RecordLit(map_record_value_and_keys(kvs, map, map_label)) + let vec = |x: &Vec<_>| x.iter().map(&map).collect(); + let btmap = |x: &BTreeMap<_, _>| { + x.into_iter().map(|(k, v)| (map_label(k), map(v))).collect() + }; + match e { + Const(k) => Const(*k), + Var(V(x, n)) => Var(V(map_label(x), *n)), + Lam(x, t, b) => Lam(map_label(x), map(t), map_under_binder(x, b)), + Pi(x, t, b) => Pi(map_label(x), map(t), map_under_binder(x, b)), + App(f, args) => App(map(f), vec(args)), + Let(l, t, a, b) => { + Let(map_label(l), opt(t), map(a), map_under_binder(l, b)) } - Union(ref kts) => Union(map_record_value_and_keys(kts, map, map_label)), - UnionLit(ref k, ref v, ref kvs) => UnionLit( - map_label(k), - map(v), - map_record_value_and_keys(kvs, map, map_label), - ), - Merge(ref x, ref y, ref t) => Merge(map(x), map(y), opt(t)), - Field(ref r, ref x) => Field(map(r), map_label(x)), - Note(ref n, ref e) => Note(map_note(n), map(e)), - Embed(ref a) => Embed(map_embed(a)), + Annot(x, t) => Annot(map(x), map(t)), + Builtin(v) => Builtin(*v), + BoolLit(b) => BoolLit(*b), + BoolIf(b, t, f) => BoolIf(map(b), map(t), map(f)), + NaturalLit(n) => NaturalLit(*n), + IntegerLit(n) => IntegerLit(*n), + DoubleLit(n) => DoubleLit(*n), + TextLit(t) => TextLit(t.map(&map)), + BinOp(o, x, y) => BinOp(*o, map(x), map(y)), + ListLit(t, es) => ListLit(opt(t), vec(es)), + OptionalLit(t, es) => OptionalLit(opt(t), opt(es)), + Record(kts) => Record(btmap(kts)), + RecordLit(kvs) => RecordLit(btmap(kvs)), + Union(kts) => Union(btmap(kts)), + UnionLit(k, v, kvs) => UnionLit(map_label(k), map(v), btmap(kvs)), + Merge(x, y, t) => Merge(map(x), map(y), opt(t)), + Field(r, x) => Field(map(r), map_label(x)), + Note(n, e) => Note(map_note(n), map(e)), + Embed(a) => Embed(map_embed(a)), } } -pub fn map_record_value<'a, I, K, V, U, F>(it: I, f: F) -> BTreeMap -where - I: IntoIterator, - K: Eq + Ord + Clone + 'a, - V: 'a, - F: FnMut(&V) -> U, -{ - map_record_value_and_keys(it, f, |x| x.clone()) -} - -pub fn map_record_value_and_keys<'a, I, K, L, V, U, F, G>( - it: I, - mut f: F, - mut g: G, -) -> BTreeMap +pub fn map_subexpr_rc_binder( + e: &SubExpr, + map_expr: F1, + map_under_binder: F2, +) -> SubExpr where - I: IntoIterator, - K: Eq + Ord + 'a, - L: Eq + Ord + 'a, - V: 'a, - F: FnMut(&V) -> U, - G: FnMut(&K) -> L, + F1: Fn(&SubExpr) -> SubExpr, + F2: FnOnce(&Label, &SubExpr) -> SubExpr, { - it.into_iter().map(|(k, v)| (g(k), f(v))).collect() + match e.as_ref() { + Expr::Embed(_) => Rc::clone(e), + Expr::Note(_, e) => { + map_subexpr_rc_binder(e, map_expr, map_under_binder) + } + _ => rc(map_subexpr( + e, + map_expr, + |_| unreachable!(), + |_| unreachable!(), + Label::clone, + map_under_binder, + )), + } } -fn map_op2(f: F, g: G, a: T, b: T) -> V +pub fn map_subexpr_rc( + e: &SubExpr, + map_expr: F1, +) -> SubExpr where - F: FnOnce(U, U) -> V, - G: Fn(T) -> U, + F1: Fn(&SubExpr) -> SubExpr, { - f(g(a), g(b)) + map_subexpr_rc_binder(e, &map_expr, |_, e| map_expr(e)) } /// `shift` is used by both normalization and type-checking to avoid variable @@ -1055,88 +1022,27 @@ where /// name in order to avoid shifting the bound variables by mistake. /// pub fn shift( - d: isize, - v: &V, - e: &Rc>, + delta: isize, + var: &V, + in_expr: &Rc>, ) -> Rc> { use crate::Expr::*; - let V(x, n) = v; - rc(match e.as_ref() { - Const(a) => Const(*a), - Var(V(x2, n2)) => { - let n3 = if x == x2 && n <= n2 { - add_ui(*n2, d) - } else { - *n2 - }; - Var(V(x2.clone(), n3)) - } - Lam(x2, tA, b) => { - let n2 = if x == x2 { n + 1 } else { *n }; - let tA2 = shift(d, v, tA); - let b2 = shift(d, &V(x.clone(), n2), b); - Lam(x2.clone(), tA2, b2) - } - Pi(x2, tA, tB) => { - let n2 = if x == x2 { n + 1 } else { *n }; - let tA2 = shift(d, v, tA); - let tB2 = shift(d, &V(x.clone(), n2), tB); - Pi(x2.clone(), tA2, tB2) - } - App(f, args) => { - let f = shift(d, v, f); - let args = args.iter().map(|a| shift(d, v, a)).collect(); - App(f, args) - } - Let(f, mt, r, e) => { - let n2 = if x == f { n + 1 } else { *n }; - let e2 = shift(d, &V(x.clone(), n2), e); - let mt2 = mt.as_ref().map(|t| shift(d, v, t)); - let r2 = shift(d, v, r); - Let(f.clone(), mt2, r2, e2) - } - Annot(a, b) => map_op2(Annot, |x| shift(d, v, x), a, b), - Builtin(v) => Builtin(*v), - BoolLit(a) => BoolLit(*a), - BinOp(o, a, b) => { - map_op2(|x, y| BinOp(*o, x, y), |x| shift(d, v, x), a, b) - } - BoolIf(a, b, c) => { - BoolIf(shift(d, v, a), shift(d, v, b), shift(d, v, c)) + let V(x, n) = var; + let under_binder = |y: &Label, e: &SubExpr<_, _>| { + let n = if x == y { n + 1 } else { *n }; + let newvar = &V(x.clone(), n); + shift(delta, newvar, e) + }; + match in_expr.as_ref() { + Var(V(y, m)) if x == y && n <= m => { + rc(Var(V(y.clone(), add_ui(*m, delta)))) } - NaturalLit(a) => NaturalLit(*a), - IntegerLit(a) => IntegerLit(*a), - DoubleLit(a) => DoubleLit(*a), - TextLit(a) => TextLit(a.map(|e| shift(d, v, &*e))), - ListLit(t, es) => ListLit( - t.as_ref().map(|t| shift(d, v, t)), - es.iter().map(|e| shift(d, v, e)).collect(), + _ => map_subexpr_rc_binder( + in_expr, + |e| shift(delta, var, e), + under_binder, ), - OptionalLit(t, e) => OptionalLit( - t.as_ref().map(|t| shift(d, v, t)), - e.as_ref().map(|t| shift(d, v, t)), - ), - Record(a) => Record(map_record_value(a, |val| shift(d, v, &*val))), - RecordLit(a) => { - RecordLit(map_record_value(a, |val| shift(d, v, &*val))) - } - Union(a) => Union(map_record_value(a, |val| shift(d, v, &*val))), - UnionLit(k, uv, a) => UnionLit( - k.clone(), - shift(d, v, uv), - map_record_value(a, |val| shift(d, v, &*val)), - ), - Merge(a, b, c) => Merge( - shift(d, v, a), - shift(d, v, b), - c.as_ref().map(|c| shift(d, v, c)), - ), - Field(a, b) => Field(shift(d, v, a), b.clone()), - Note(_, b) => return shift(d, v, b), - // The Dhall compiler enforces that all embedded values are closed expressions - // and `shift` does nothing to a closed expression - Embed(_) => return Rc::clone(e), - }) + } } /// Substitute all occurrences of a variable with an expression @@ -1146,79 +1052,23 @@ pub fn shift( /// ``` /// pub fn subst( - v: &V, - e: &Rc>, - b: &Rc>, -) -> Rc> -{ + var: &V, + value: &Rc>, + in_expr: &Rc>, +) -> Rc> { use crate::Expr::*; - let V(x, n) = v; - rc(match b.as_ref() { - Lam(y, tA, b) => { - let n2 = if x == y { n + 1 } else { *n }; - let b2 = - subst(&V(x.clone(), n2), &shift(1, &V(y.clone(), 0), e), b); - let tA2 = subst(v, e, tA); - Lam(y.clone(), tA2, b2) - } - Pi(y, tA, tB) => { - let n2 = if x == y { n + 1 } else { *n }; - let tB2 = - subst(&V(x.clone(), n2), &shift(1, &V(y.clone(), 0), e), tB); - let tA2 = subst(v, e, tA); - Pi(y.clone(), tA2, tB2) - } - App(f, args) => { - let f2 = subst(v, e, f); - let args = args.iter().map(|a| subst(v, e, a)).collect(); - App(f2, args) - } - Var(v2) if v == v2 => { - return e.clone(); - } - Let(f, mt, r, b) => { - let n2 = if x == f { n + 1 } else { *n }; - let b2 = - subst(&V(x.clone(), n2), &shift(1, &V(f.clone(), 0), e), b); - let mt2 = mt.as_ref().map(|t| subst(v, e, t)); - let r2 = subst(v, e, r); - Let(f.clone(), mt2, r2, b2) - } - Annot(a, b) => map_op2(Annot, |x| subst(v, e, x), a, b), - BinOp(o, a, b) => { - map_op2(|x, y| BinOp(*o, x, y), |x| subst(v, e, x), a, b) - } - BoolIf(a, b, c) => { - BoolIf(subst(v, e, a), subst(v, e, b), subst(v, e, c)) - } - TextLit(a) => TextLit(a.map(|b| subst(v, e, &*b))), - ListLit(a, b) => { - let a2 = a.as_ref().map(|a| subst(v, e, a)); - let b2 = b.iter().map(|be| subst(v, e, be)).collect(); - ListLit(a2, b2) - } - OptionalLit(a, b) => { - let a2 = a.as_ref().map(|a| subst(v, e, a)); - let b2 = b.as_ref().map(|a| subst(v, e, a)); - OptionalLit(a2, b2) - } - Record(kts) => Record(map_record_value(kts, |t| subst(v, e, &*t))), - RecordLit(kvs) => { - RecordLit(map_record_value(kvs, |val| subst(v, e, &*val))) - } - Union(kts) => Union(map_record_value(kts, |t| subst(v, e, &*t))), - UnionLit(k, uv, kvs) => UnionLit( - k.clone(), - subst(v, e, uv), - map_record_value(kvs, |val| subst(v, e, &*val)), - ), - Merge(a, b, c) => Merge( - subst(v, e, a), - subst(v, e, b), - c.as_ref().map(|c| subst(v, e, c)), + let under_binder = |y: &Label, e: &SubExpr<_, _>| { + let V(x, n) = var; + let n = if x == y { n + 1 } else { *n }; + let newvar = &V(x.clone(), n); + subst(newvar, &shift(1, &V(y.clone(), 0), value), e) + }; + match in_expr.as_ref() { + Var(v) if var == v => Rc::clone(value), + _ => map_subexpr_rc_binder( + in_expr, + |e| subst(var, value, e), + under_binder, ), - Field(a, b) => Field(subst(v, e, a), b.clone()), - Note(_, b) => return subst(v, e, b), - _ => return Rc::clone(b), - }) + } } -- cgit v1.2.3