diff options
Diffstat (limited to 'dhall/src')
-rw-r--r-- | dhall/src/binary.rs | 22 | ||||
-rw-r--r-- | dhall/src/typecheck.rs | 46 |
2 files changed, 58 insertions, 10 deletions
diff --git a/dhall/src/binary.rs b/dhall/src/binary.rs index 72704de..c12aa2a 100644 --- a/dhall/src/binary.rs +++ b/dhall/src/binary.rs @@ -143,11 +143,11 @@ fn cbor_value_to_dhall(data: &cbor::Value) -> Result<ParsedExpr, DecodeError> { Field(x, l) } [U64(11), Object(map)] => { - let map = cbor_map_to_dhall_map(map)?; + let map = cbor_map_to_dhall_opt_map(map)?; UnionType(map) } [U64(12), String(l), x, Object(map)] => { - let map = cbor_map_to_dhall_map(map)?; + let map = cbor_map_to_dhall_opt_map(map)?; let x = cbor_value_to_dhall(&x)?; let l = Label::from(l.as_str()); UnionLit(l, x, map) @@ -343,3 +343,21 @@ fn cbor_map_to_dhall_map( }) .collect::<Result<_, _>>() } + +fn cbor_map_to_dhall_opt_map( + map: &std::collections::BTreeMap<cbor::ObjectKey, cbor::Value>, +) -> Result<std::collections::BTreeMap<Label, Option<ParsedExpr>>, DecodeError> +{ + map.iter() + .map(|(k, v)| -> Result<(_, _), _> { + let k = k.as_string().ok_or_else(|| { + DecodeError::WrongFormatError("map/key".to_owned()) + })?; + let v = match v { + cbor::Value::Null => None, + _ => Some(cbor_value_to_dhall(v)?), + }; + Ok((Label::from(k.as_ref()), v)) + }) + .collect::<Result<_, _>>() +} diff --git a/dhall/src/typecheck.rs b/dhall/src/typecheck.rs index 0e1a10e..d9566d0 100644 --- a/dhall/src/typecheck.rs +++ b/dhall/src/typecheck.rs @@ -182,10 +182,16 @@ where } (UnionType(ktsL0), UnionType(ktsR0)) => { ktsL0.len() == ktsR0.len() - && ktsL0 - .iter() - .zip(ktsR0.iter()) - .all(|((kL, tL), (kR, tR))| kL == kR && go(ctx, tL, tR)) + && ktsL0.iter().zip(ktsR0.iter()).all( + |((kL, tL), (kR, tR))| { + kL == kR + && match (tL, tR) { + (None, None) => true, + (Some(tL), Some(tR)) => go(ctx, tL, tR), + _ => false, + } + }, + ) } (_, _) => false, } @@ -602,7 +608,7 @@ fn type_last_layer( let t = tx.embed()?; Ok(RetExpr(dhall::expr!(Optional t))) } - RecordType(kts) | UnionType(kts) => { + RecordType(kts) => { // Check that all types are the same const let mut k = None; for (x, t) in kts { @@ -622,6 +628,29 @@ fn type_last_layer( let k = k.unwrap_or(dhall_core::Const::Type); Ok(RetType(const_to_type(k))) } + UnionType(kts) => { + // Check that all types are the same const + let mut k = None; + for (x, t) in kts { + if let Some(t) = t { + let k2 = ensure_is_const!( + t.get_type()?, + mkerr(InvalidFieldType(x, t)) + ); + match k { + None => k = Some(k2), + Some(k1) if k1 != k2 => { + return Err(mkerr(InvalidFieldType(x, t))) + } + Some(_) => {} + } + } + } + // An empty union type has type Type + // An union type with only unary variants has type Type + let k = k.unwrap_or(dhall_core::Const::Type); + Ok(RetType(const_to_type(k))) + } RecordLit(kvs) => { let kts = kvs .into_iter() @@ -636,12 +665,12 @@ fn type_last_layer( let mut kts: std::collections::BTreeMap<_, _> = kvs .into_iter() .map(|(x, v)| { - let t = v.normalize().embed(); + let t = v.map(|x| x.normalize().embed()); Ok((x, t)) }) .collect::<Result<_, _>>()?; let t = v.get_type_move()?.embed()?; - kts.insert(x, t); + kts.insert(x, Some(t)); Ok(RetExpr(UnionType(kts))) } Field(r, x) => match r.get_type()?.unroll_ref()?.as_ref() { @@ -655,11 +684,12 @@ fn type_last_layer( UnionType(kts) => match kts.get(&x) { // Constructor has type T -> < x: T, ... > // TODO: use "_" instead of x - Some(t) => Ok(RetExpr(Pi( + Some(Some(t)) => Ok(RetExpr(Pi( x.clone(), t.embed_absurd(), r.embed(), ))), + Some(None) => Ok(RetType(r.into_type())), None => Err(mkerr(MissingUnionField(x, r))), }, _ => Err(mkerr(NotARecord(x, r))), |