diff options
Diffstat (limited to 'dhall/src')
-rw-r--r-- | dhall/src/typecheck.rs | 170 |
1 files changed, 108 insertions, 62 deletions
diff --git a/dhall/src/typecheck.rs b/dhall/src/typecheck.rs index 0ee481f..f006ec6 100644 --- a/dhall/src/typecheck.rs +++ b/dhall/src/typecheck.rs @@ -1,6 +1,7 @@ #![allow(non_snake_case)] use std::borrow::Borrow; use std::borrow::Cow; +use std::collections::BTreeMap; use std::fmt; use std::marker::PhantomData; @@ -32,12 +33,6 @@ impl<'a> Resolved<'a> { } } impl<'a> Typed<'a> { - fn normalize_to_type( - self, - ctx: &TypecheckContext, - ) -> Result<Type<'a>, TypeError> { - Ok(self.normalize().into_type_ctx(ctx)?) - } fn get_type_move(self) -> Result<Type<'static>, TypeError> { let (expr, ty) = (self.0, self.1); ty.ok_or_else(|| { @@ -71,14 +66,13 @@ impl<'a> Normalized<'a> { self, ctx: &TypecheckContext, ) -> Result<Type<'a>, TypeError> { - Ok(Type(match self.0.as_ref() { - ExprF::Const(c) => TypeInternal::Const(*c), + Ok(match self.0.as_ref() { + ExprF::Const(c) => Type(TypeInternal::Const(*c)), ExprF::Pi(_, _, _) => { - return Ok(type_with(ctx, self.0.embed_absurd())? - .normalize_to_type(ctx)?) + type_with(ctx, self.0.embed_absurd())?.normalize_to_type(ctx)? } - _ => TypeInternal::Expr(Box::new(self)), - })) + _ => Type(TypeInternal::Expr(Box::new(self))), + }) } fn get_type_move(self) -> Result<Type<'static>, TypeError> { let (expr, ty) = (self.0, self.1); @@ -161,6 +155,7 @@ pub(crate) enum TypeInternal<'a> { Box<Type<'static>>, Box<Type<'static>>, ), + RecordType(TypecheckContext, Const, BTreeMap<Label, Type<'static>>), /// The type of `Sort` SuperType, /// This must not contain a value captured by one of the variants above. @@ -172,11 +167,16 @@ impl<'a> TypeInternal<'a> { match self { TypeInternal::Expr(e) => Ok(*e), TypeInternal::Pi(ctx, c, x, t, e) => Ok(Typed( - rc(ExprF::Pi( - x, - t.into_normalized()?.embed(), - e.into_normalized()?.embed(), - )), + rc(ExprF::Pi(x, t, e) + .traverse_ref_simple(|e| e.clone().embed())?), + Some(const_to_type(c)), + ctx, + PhantomData, + ) + .normalize()), + TypeInternal::RecordType(ctx, c, kts) => Ok(Typed( + rc(ExprF::RecordType(kts) + .traverse_ref_simple(|e| e.clone().embed())?), Some(const_to_type(c)), ctx, PhantomData, @@ -204,6 +204,13 @@ impl<'a> TypeInternal<'a> { Box::new(t.shift(delta, var)), Box::new(e.shift(delta, &var.shift0(1, x))), ), + RecordType(ctx, c, kts) => RecordType( + ctx.clone(), + *c, + kts.iter() + .map(|(k, v)| (k.clone(), v.shift(delta, var))) + .collect(), + ), Const(c) => Const(*c), SuperType => SuperType, } @@ -228,7 +235,7 @@ impl TypedOrType { ctx: &TypecheckContext, ) -> Result<Type<'static>, TypeError> { match self { - TypedOrType::Typed(e) => Ok(e.normalize_to_type(ctx)?), + TypedOrType::Typed(e) => Ok(e.normalize().into_type_ctx(ctx)?), TypedOrType::Type(t) => Ok(t), } } @@ -531,43 +538,43 @@ macro_rules! ensure_is_const { #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) enum TypeIntermediate { Pi(TypecheckContext, Label, Type<'static>, Type<'static>), + RecordType(TypecheckContext, BTreeMap<Label, Type<'static>>), } impl TypeIntermediate { fn typecheck(self) -> Result<TypedOrType, TypeError> { + let mkerr = + |ctx, msg| Ok(TypeError::new(ctx, self.clone().into_expr()?, msg)); match &self { TypeIntermediate::Pi(ctx, x, ta, tb) => { let ctx2 = ctx.insert_type(x, ta.clone()); let kA = ensure_is_const!( &ta.get_type()?, - TypeError::new( + mkerr( ctx, - self.clone().into_expr()?, InvalidInputType(ta.clone().into_normalized()?), - ), + )?, ); let kB = ensure_is_const!( &tb.get_type()?, - TypeError::new( + mkerr( &ctx2, - self.clone().into_expr()?, InvalidOutputType( tb.clone() .into_normalized()? .get_type_move()? .into_normalized()? ), - ), + )?, ); let k = match function_check(kA, kB) { Ok(k) => k, Err(()) => { - return Err(TypeError::new( + return Err(mkerr( ctx, - self.clone().into_expr()?, NoDependentTypes( ta.clone().into_normalized()?, tb.clone() @@ -575,7 +582,7 @@ impl TypeIntermediate { .get_type_move()? .into_normalized()?, ), - )) + )?) } }; @@ -587,16 +594,51 @@ impl TypeIntermediate { Box::new(tb.clone()), )))) } + TypeIntermediate::RecordType(ctx, kts) => { + // Check that all types are the same const + let mut k = None; + for (x, t) in kts { + let k2 = ensure_is_const!( + t.get_type()?, + mkerr( + ctx, + InvalidFieldType( + x.clone(), + TypedOrType::Type(t.clone()) + ) + )? + ); + match k { + None => k = Some(k2), + Some(k1) if k1 != k2 => { + return Err(mkerr( + ctx, + InvalidFieldType( + x.clone(), + TypedOrType::Type(t.clone()), + ), + )?) + } + Some(_) => {} + } + } + // An empty record type has type Type + let k = k.unwrap_or(dhall_core::Const::Type); + + Ok(TypedOrType::Type(Type(TypeInternal::RecordType( + ctx.clone(), + k, + kts.clone(), + )))) + } } } fn into_expr(self) -> Result<SubExpr<X, Normalized<'static>>, TypeError> { - match self { - TypeIntermediate::Pi(_, x, t, e) => Ok(rc(ExprF::Pi( - x, - t.into_normalized()?.embed(), - e.into_normalized()?.embed(), - ))), + Ok(rc(match self { + TypeIntermediate::Pi(_, x, t, e) => ExprF::Pi(x, t, e), + TypeIntermediate::RecordType(_, kts) => ExprF::RecordType(kts), } + .traverse_ref_simple(|e| e.clone().embed())?)) } } @@ -619,7 +661,9 @@ fn simple_type_from_builtin<'a>(b: Builtin) -> Type<'a> { /// Intermediary return type enum Ret { - /// Returns the contained Type as is + /// Returns the contained value as is + RetTypedOrType(TypedOrType), + /// Use the contained Type as the type of the input expression RetType(Type<'static>), /// Returns an expression that must be typechecked and /// turned into a Type first. @@ -651,8 +695,10 @@ fn type_with( let ta = mktype(ctx, ta.clone())?; let ctx2 = ctx.insert_type(x, ta.clone()); let tb = mktype(&ctx2, tb.clone())?; - return TypeIntermediate::Pi(ctx.clone(), x.clone(), ta, tb) - .typecheck(); + Ok(RetTypedOrType( + TypeIntermediate::Pi(ctx.clone(), x.clone(), ta, tb) + .typecheck()?, + )) } Let(x, t, v, e) => { let v = if let Some(t) = t { @@ -667,7 +713,7 @@ fn type_with( Ok(RetType(e.get_type_move()?)) } - Embed(p) => return Ok(TypedOrType::Typed(p.clone().into())), + Embed(p) => Ok(RetTypedOrType(TypedOrType::Typed(p.clone().into()))), _ => type_last_layer( ctx, // Typecheck recursively all subexpressions @@ -689,6 +735,7 @@ fn type_with( ctx.clone(), PhantomData, ))), + RetTypedOrType(tt) => Ok(tt), } } @@ -732,6 +779,10 @@ fn type_last_layer( a.normalize()?.embed(), tb.into_normalized()?.into_expr().embed_absurd(), ))) + // Ok(RetType(mktype( + // &ctx.insert_value(&x, a), + // tb.into_normalized()?.into_expr().embed_absurd(), + // )?)) } Annot(x, t) => { let t = t.normalize_to_type(ctx)?; @@ -823,24 +874,13 @@ fn type_last_layer( Ok(RetExpr(dhall::expr!(Optional t))) } RecordType(kts) => { - // Check that all types are the same const - let mut k = None; - for (x, t) in kts { - let k2 = ensure_is_const!( - t.get_type()?, - mkerr(InvalidFieldType(x, t)) - ); - match k { - None => k = Some(k2), - Some(k1) if k1 != k2 => { - return Err(mkerr(InvalidFieldType(x, t))) - } - Some(_) => {} - } - } - // An empty record type has type Type - let k = k.unwrap_or(dhall_core::Const::Type); - Ok(RetType(const_to_type(k))) + let kts: BTreeMap<_, _> = kts + .into_iter() + .map(|(x, t)| Ok((x, t.normalize_to_type(ctx)?))) + .collect::<Result<_, _>>()?; + Ok(RetTypedOrType( + TypeIntermediate::RecordType(ctx.clone(), kts).typecheck()?, + )) } UnionType(kts) => { // Check that all types are the same const @@ -860,20 +900,21 @@ fn type_last_layer( } } } - // An empty union type has type Type - // An union type with only unary variants has type Type + // An empty union type has type Type; + // an union type with only unary variants also has type Type let k = k.unwrap_or(dhall_core::Const::Type); Ok(RetType(const_to_type(k))) } RecordLit(kvs) => { let kts = kvs .into_iter() - .map(|(x, v)| { - let t = v.get_type_move()?.embed()?; - Ok((x, t)) - }) + .map(|(x, v)| Ok((x, v.get_type_move()?))) .collect::<Result<_, _>>()?; - Ok(RetExpr(RecordType(kts))) + Ok(RetType( + TypeIntermediate::RecordType(ctx.clone(), kts) + .typecheck()? + .normalize_to_type(ctx)?, + )) } UnionLit(x, v, kvs) => { let mut kts: std::collections::BTreeMap<_, _> = kvs @@ -890,6 +931,11 @@ fn type_last_layer( kts.insert(x, Some(t)); Ok(RetExpr(UnionType(kts))) } + // Field(r, x) => match &r.get_type()?.0 { + // TypeInternal::RecordType(_, _, kts) => match kts.get(&x) { + // Some(t) => Ok(RetType(t.clone())), + // None => Err(mkerr(MissingRecordField(x, r))), + // }, Field(r, x) => match r.get_type()?.unroll_ref()?.as_ref() { RecordType(kts) => match kts.get(&x) { Some(t) => Ok(RetExpr(t.unroll().embed_absurd())), |