diff options
author | Nadrieril | 2019-05-12 18:44:28 +0200 |
---|---|---|
committer | Nadrieril | 2019-05-12 18:44:44 +0200 |
commit | c2b4a2d9b40efbe4f6cb6fd04f6cb90639f4985f (patch) | |
tree | 02e0f7b57b56f949240cedbdfabac234d9486834 /dhall/src | |
parent | 2d1a333d6c1e8571ca91d29333c284104153b0ef (diff) |
Implement binary encoding
Closes #39
Diffstat (limited to 'dhall/src')
-rw-r--r-- | dhall/src/error/mod.rs | 12 | ||||
-rw-r--r-- | dhall/src/phase/binary.rs | 311 | ||||
-rw-r--r-- | dhall/src/phase/mod.rs | 15 | ||||
-rw-r--r-- | dhall/src/phase/parse.rs | 6 | ||||
-rw-r--r-- | dhall/src/phase/resolve.rs | 1 | ||||
-rw-r--r-- | dhall/src/phase/typecheck.rs | 2 | ||||
-rw-r--r-- | dhall/src/tests.rs | 40 |
7 files changed, 355 insertions, 32 deletions
diff --git a/dhall/src/error/mod.rs b/dhall/src/error/mod.rs index 3f482f6..125d013 100644 --- a/dhall/src/error/mod.rs +++ b/dhall/src/error/mod.rs @@ -14,6 +14,7 @@ pub enum Error { IO(IOError), Parse(ParseError), Decode(DecodeError), + Encode(EncodeError), Resolve(ImportError), Typecheck(TypeError), Deserialize(String), @@ -32,6 +33,11 @@ pub enum DecodeError { WrongFormatError(String), } +#[derive(Debug)] +pub enum EncodeError { + CBORError(serde_cbor::error::Error), +} + /// A structured type error that includes context #[derive(Debug)] pub struct TypeError { @@ -140,6 +146,7 @@ impl std::fmt::Display for Error { Error::IO(err) => write!(f, "{}", err), Error::Parse(err) => write!(f, "{}", err), Error::Decode(err) => write!(f, "{:?}", err), + Error::Encode(err) => write!(f, "{:?}", err), Error::Resolve(err) => write!(f, "{:?}", err), Error::Typecheck(err) => write!(f, "{:?}", err), Error::Deserialize(err) => write!(f, "{}", err), @@ -163,6 +170,11 @@ impl From<DecodeError> for Error { Error::Decode(err) } } +impl From<EncodeError> for Error { + fn from(err: EncodeError) -> Error { + Error::Encode(err) + } +} impl From<ImportError> for Error { fn from(err: ImportError) -> Error { Error::Resolve(err) diff --git a/dhall/src/phase/binary.rs b/dhall/src/phase/binary.rs index 7f72e80..a3ab5de 100644 --- a/dhall/src/phase/binary.rs +++ b/dhall/src/phase/binary.rs @@ -2,24 +2,32 @@ use itertools::Itertools; use serde_cbor::value::value as cbor; use std::iter::FromIterator; +use dhall_syntax::map::DupTreeMap; use dhall_syntax::{ rc, ExprF, FilePrefix, Hash, Import, ImportHashed, ImportLocation, ImportMode, Integer, InterpolatedText, Label, Natural, Scheme, SubExpr, - URL, V, X, + URL, V, }; -use crate::error::DecodeError; +use crate::error::{DecodeError, EncodeError}; +use crate::phase::{DecodedSubExpr, ParsedSubExpr}; -type ParsedExpr = SubExpr<X, Import>; - -pub fn decode(data: &[u8]) -> Result<ParsedExpr, DecodeError> { +pub fn decode(data: &[u8]) -> Result<DecodedSubExpr, DecodeError> { match serde_cbor::de::from_slice(data) { Ok(v) => cbor_value_to_dhall(&v), Err(e) => Err(DecodeError::CBORError(e)), } } -fn cbor_value_to_dhall(data: &cbor::Value) -> Result<ParsedExpr, DecodeError> { +//TODO: encode normalized expression too +pub fn encode(expr: &ParsedSubExpr) -> Result<Vec<u8>, EncodeError> { + serde_cbor::ser::to_vec(&Serialize::Expr(expr)) + .map_err(|e| EncodeError::CBORError(e)) +} + +fn cbor_value_to_dhall( + data: &cbor::Value, +) -> Result<DecodedSubExpr, DecodeError> { use cbor::Value::*; use dhall_syntax::{BinOp, Builtin, Const}; use ExprF::*; @@ -238,12 +246,11 @@ fn cbor_value_to_dhall(data: &cbor::Value) -> Result<ParsedExpr, DecodeError> { ))?, }; let path = rest - .map(|s| { - s.as_string().ok_or_else(|| { - DecodeError::WrongFormatError( - "import/remote/path".to_owned(), - ) - }) + .map(|s| match s.as_string() { + Some(s) => Ok(s.clone()), + None => Err(DecodeError::WrongFormatError( + "import/remote/path".to_owned(), + )), }) .collect::<Result<_, _>>()?; ImportLocation::Remote(URL { @@ -265,12 +272,11 @@ fn cbor_value_to_dhall(data: &cbor::Value) -> Result<ParsedExpr, DecodeError> { ))?, }; let path = rest - .map(|s| { - s.as_string().ok_or_else(|| { - DecodeError::WrongFormatError( - "import/local/path".to_owned(), - ) - }) + .map(|s| match s.as_string() { + Some(s) => Ok(s.clone()), + None => Err(DecodeError::WrongFormatError( + "import/local/path".to_owned(), + )), }) .collect::<Result<_, _>>()?; ImportLocation::Local(prefix, path) @@ -336,7 +342,7 @@ fn cbor_map_to_dhall_map<'a, T>( map: impl IntoIterator<Item = (&'a cbor::ObjectKey, &'a cbor::Value)>, ) -> Result<T, DecodeError> where - T: FromIterator<(Label, ParsedExpr)>, + T: FromIterator<(Label, DecodedSubExpr)>, { map.into_iter() .map(|(k, v)| -> Result<(_, _), _> { @@ -353,7 +359,7 @@ fn cbor_map_to_dhall_opt_map<'a, T>( map: impl IntoIterator<Item = (&'a cbor::ObjectKey, &'a cbor::Value)>, ) -> Result<T, DecodeError> where - T: FromIterator<(Label, Option<ParsedExpr>)>, + T: FromIterator<(Label, Option<DecodedSubExpr>)>, { map.into_iter() .map(|(k, v)| -> Result<(_, _), _> { @@ -368,3 +374,268 @@ where }) .collect::<Result<_, _>>() } + +enum Serialize<'a> { + Expr(&'a ParsedSubExpr), + CBOR(cbor::Value), + RecordMap(&'a DupTreeMap<Label, ParsedSubExpr>), + UnionMap(&'a DupTreeMap<Label, Option<ParsedSubExpr>>), +} + +macro_rules! count { + (@replace_with $_t:tt $sub:expr) => { $sub }; + ($($tts:tt)*) => {0usize $(+ count!(@replace_with $tts 1usize))*}; +} + +macro_rules! ser_seq { + ($ser:expr; $($elt:expr),* $(,)?) => {{ + use serde::ser::SerializeSeq; + let count = count!($($elt)*); + let mut ser_seq = $ser.serialize_seq(Some(count))?; + $( + ser_seq.serialize_element(&$elt)?; + )* + ser_seq.end() + }}; +} + +fn serialize_subexpr<S>(ser: S, e: &ParsedSubExpr) -> Result<S::Ok, S::Error> +where + S: serde::ser::Serializer, +{ + use cbor::Value::{String, I64, U64}; + use dhall_syntax::ExprF::*; + use std::iter::once; + + use self::Serialize::{RecordMap, UnionMap}; + fn expr(x: &ParsedSubExpr) -> self::Serialize<'_> { + self::Serialize::Expr(x) + } + fn cbor<'a>(v: cbor::Value) -> self::Serialize<'a> { + self::Serialize::CBOR(v) + } + fn tag<'a>(x: u64) -> self::Serialize<'a> { + cbor(U64(x)) + } + fn null<'a>() -> self::Serialize<'a> { + cbor(cbor::Value::Null) + } + fn label<'a>(l: &Label) -> self::Serialize<'a> { + cbor(cbor::Value::String(l.into())) + } + + match e.as_ref() { + Const(c) => ser.serialize_str(&c.to_string()), + Builtin(b) => ser.serialize_str(&b.to_string()), + BoolLit(b) => ser.serialize_bool(*b), + NaturalLit(n) => ser_seq!(ser; tag(15), U64(*n as u64)), + IntegerLit(n) => ser_seq!(ser; tag(16), I64(*n as i64)), + DoubleLit(n) => { + let n: f64 = (*n).into(); + ser.serialize_f64(n) + } + BoolIf(x, y, z) => ser_seq!(ser; tag(14), expr(x), expr(y), expr(z)), + Var(V(l, n)) if l == &"_".into() => ser.serialize_u64(*n as u64), + Var(V(l, n)) => ser_seq!(ser; label(l), U64(*n as u64)), + Lam(l, x, y) if l == &"_".into() => { + ser_seq!(ser; tag(1), expr(x), expr(y)) + } + Lam(l, x, y) => ser_seq!(ser; tag(1), label(l), expr(x), expr(y)), + Pi(l, x, y) if l == &"_".into() => { + ser_seq!(ser; tag(2), expr(x), expr(y)) + } + Pi(l, x, y) => ser_seq!(ser; tag(2), label(l), expr(x), expr(y)), + // TODO: multilet + Let(l, None, x, y) => { + ser_seq!(ser; tag(25), label(l), null(), expr(x), expr(y)) + } + Let(l, Some(t), x, y) => { + ser_seq!(ser; tag(25), label(l), expr(t), expr(x), expr(y)) + } + App(_, _) => { + let (f, args) = collect_nested_applications(e); + ser.collect_seq( + once(tag(0)) + .chain(once(expr(f))) + .chain(args.into_iter().rev().map(expr)), + ) + } + Annot(x, y) => ser_seq!(ser; tag(26), expr(x), expr(y)), + OldOptionalLit(None, t) => ser_seq!(ser; tag(5), expr(t)), + OldOptionalLit(Some(x), t) => ser_seq!(ser; tag(5), expr(t), expr(x)), + SomeLit(x) => ser_seq!(ser; tag(5), null(), expr(x)), + EmptyListLit(x) => ser_seq!(ser; tag(4), expr(x)), + NEListLit(xs) => ser.collect_seq( + once(tag(4)).chain(once(null())).chain(xs.iter().map(expr)), + ), + TextLit(xs) => { + use dhall_syntax::InterpolatedTextContents::{Expr, Text}; + ser.collect_seq(once(tag(18)).chain(xs.iter().map(|x| match x { + Expr(x) => expr(x), + Text(x) => cbor(String(x.clone())), + }))) + } + RecordType(map) => ser_seq!(ser; tag(7), RecordMap(map)), + RecordLit(map) => ser_seq!(ser; tag(8), RecordMap(map)), + UnionType(map) => ser_seq!(ser; tag(11), UnionMap(map)), + UnionLit(l, x, map) => { + ser_seq!(ser; tag(12), label(l), expr(x), UnionMap(map)) + } + Field(x, l) => ser_seq!(ser; tag(9), expr(x), label(l)), + BinOp(op, x, y) => { + use dhall_syntax::BinOp::*; + let op = match op { + BoolOr => 0, + BoolAnd => 1, + BoolEQ => 2, + BoolNE => 3, + NaturalPlus => 4, + NaturalTimes => 5, + TextAppend => 6, + ListAppend => 7, + RecursiveRecordMerge => 8, + RightBiasedRecordMerge => 9, + RecursiveRecordTypeMerge => 10, + ImportAlt => 11, + }; + ser_seq!(ser; tag(3), U64(op), expr(x), expr(y)) + } + Merge(x, y, None) => ser_seq!(ser; tag(6), expr(x), expr(y)), + Merge(x, y, Some(z)) => { + ser_seq!(ser; tag(6), expr(x), expr(y), expr(z)) + } + Projection(x, ls) => ser.collect_seq( + once(tag(10)) + .chain(once(expr(x))) + .chain(ls.iter().map(label)), + ), + Embed(import) => serialize_import(ser, import), + } +} + +fn serialize_import<S>(ser: S, import: &Import) -> Result<S::Ok, S::Error> +where + S: serde::ser::Serializer, +{ + use cbor::Value::{Array, Null, String, U64}; + use serde::ser::SerializeSeq; + + let count = 4 + match &import.location_hashed.location { + ImportLocation::Remote(url) => 3 + url.path.len(), + ImportLocation::Local(_, path) => path.len(), + ImportLocation::Env(_) => 1, + ImportLocation::Missing => 0, + }; + let mut ser_seq = ser.serialize_seq(Some(count))?; + + ser_seq.serialize_element(&U64(24))?; + + let hash = match &import.location_hashed.hash { + None => Null, + Some(h) => { + Array(vec![String(h.protocol.clone()), String(h.hash.clone())]) + } + }; + ser_seq.serialize_element(&hash)?; + + let mode = match import.mode { + ImportMode::Code => 0, + ImportMode::RawText => 1, + }; + ser_seq.serialize_element(&U64(mode))?; + + let scheme = match &import.location_hashed.location { + ImportLocation::Remote(url) => match url.scheme { + Scheme::HTTP => 0, + Scheme::HTTPS => 1, + }, + ImportLocation::Local(prefix, _) => match prefix { + FilePrefix::Absolute => 2, + FilePrefix::Here => 3, + FilePrefix::Parent => 4, + FilePrefix::Home => 5, + }, + ImportLocation::Env(_) => 6, + ImportLocation::Missing => 7, + }; + ser_seq.serialize_element(&U64(scheme))?; + + match &import.location_hashed.location { + ImportLocation::Remote(url) => { + match &url.headers { + None => ser_seq.serialize_element(&Null)?, + Some(_x) => unimplemented!(), + // match cbor_value_to_dhall(&x)?.as_ref() { + // Embed(import) => Some(Box::new( + // import.location_hashed.clone(), + // )), + // } + }; + ser_seq.serialize_element(&url.authority)?; + for p in &url.path { + ser_seq.serialize_element(p)?; + } + match &url.query { + None => ser_seq.serialize_element(&Null)?, + Some(x) => ser_seq.serialize_element(x)?, + }; + } + ImportLocation::Local(_, path) => { + for p in path { + ser_seq.serialize_element(p)?; + } + } + ImportLocation::Env(env) => { + ser_seq.serialize_element(env)?; + } + ImportLocation::Missing => {} + } + + ser_seq.end() +} + +impl<'a> serde::ser::Serialize for Serialize<'a> { + fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error> + where + S: serde::ser::Serializer, + { + match self { + Serialize::Expr(e) => serialize_subexpr(ser, e), + Serialize::CBOR(v) => v.serialize(ser), + Serialize::RecordMap(map) => { + ser.collect_map(map.iter().map(|(k, v)| { + (cbor::Value::String(k.into()), Serialize::Expr(v)) + })) + } + Serialize::UnionMap(map) => { + ser.collect_map(map.iter().map(|(k, v)| { + let v = match v { + Some(x) => Serialize::Expr(x), + None => Serialize::CBOR(cbor::Value::Null), + }; + (cbor::Value::String(k.into()), v) + })) + } + } + } +} + +fn collect_nested_applications<'a, N, E>( + e: &'a SubExpr<N, E>, +) -> (&'a SubExpr<N, E>, Vec<&'a SubExpr<N, E>>) { + fn go<'a, N, E>( + e: &'a SubExpr<N, E>, + vec: &mut Vec<&'a SubExpr<N, E>>, + ) -> &'a SubExpr<N, E> { + match e.as_ref() { + ExprF::App(f, a) => { + vec.push(a); + go(f, vec) + } + _ => e, + } + } + let mut vec = vec![]; + let e = go(e, &mut vec); + (e, vec) +} diff --git a/dhall/src/phase/mod.rs b/dhall/src/phase/mod.rs index 63480c5..681b7fe 100644 --- a/dhall/src/phase/mod.rs +++ b/dhall/src/phase/mod.rs @@ -8,7 +8,7 @@ use crate::core::context::TypecheckContext; use crate::core::thunk::Thunk; use crate::core::value::Value; use crate::core::var::{AlphaVar, Shift, Subst}; -use crate::error::{Error, ImportError, TypeError, TypeMessage}; +use crate::error::{EncodeError, Error, ImportError, TypeError, TypeMessage}; use resolve::ImportRoot; use typecheck::type_of_const; @@ -20,6 +20,7 @@ pub(crate) mod resolve; pub(crate) mod typecheck; pub type ParsedSubExpr = SubExpr<Span, Import>; +pub type DecodedSubExpr = SubExpr<X, Import>; pub type ResolvedSubExpr = SubExpr<Span, Normalized>; pub type NormalizedSubExpr = SubExpr<X, X>; @@ -55,24 +56,30 @@ impl Parsed { pub fn parse_file(f: &Path) -> Result<Parsed, Error> { parse::parse_file(f) } - pub fn parse_str(s: &str) -> Result<Parsed, Error> { parse::parse_str(s) } - #[allow(dead_code)] pub fn parse_binary_file(f: &Path) -> Result<Parsed, Error> { parse::parse_binary_file(f) } + #[allow(dead_code)] + pub fn parse_binary(data: &[u8]) -> Result<Parsed, Error> { + parse::parse_binary(data) + } pub fn resolve(self) -> Result<Resolved, ImportError> { resolve::resolve(self) } - #[allow(dead_code)] pub fn skip_resolve(self) -> Result<Resolved, ImportError> { resolve::skip_resolve_expr(self) } + + #[allow(dead_code)] + pub fn encode(&self) -> Result<Vec<u8>, EncodeError> { + crate::phase::binary::encode(&self.0) + } } impl Resolved { diff --git a/dhall/src/phase/parse.rs b/dhall/src/phase/parse.rs index 765fc09..734f6e1 100644 --- a/dhall/src/phase/parse.rs +++ b/dhall/src/phase/parse.rs @@ -22,6 +22,12 @@ pub fn parse_str(s: &str) -> Result<Parsed, Error> { Ok(Parsed(expr, root)) } +pub fn parse_binary(data: &[u8]) -> Result<Parsed, Error> { + let expr = crate::phase::binary::decode(data)?; + let root = ImportRoot::LocalDir(std::env::current_dir()?); + Ok(Parsed(expr.note_absurd(), root)) +} + pub fn parse_binary_file(f: &Path) -> Result<Parsed, Error> { let mut buffer = Vec::new(); File::open(f)?.read_to_end(&mut buffer)?; diff --git a/dhall/src/phase/resolve.rs b/dhall/src/phase/resolve.rs index 7e446eb..fa5f32e 100644 --- a/dhall/src/phase/resolve.rs +++ b/dhall/src/phase/resolve.rs @@ -30,6 +30,7 @@ fn resolve_import( }; match &import.location_hashed.location { Local(prefix, path) => { + let path: PathBuf = path.iter().cloned().collect(); let path = match prefix { // TODO: fail gracefully Parent => cwd.parent().unwrap().join(path), diff --git a/dhall/src/phase/typecheck.rs b/dhall/src/phase/typecheck.rs index 5caf1d5..ac584cd 100644 --- a/dhall/src/phase/typecheck.rs +++ b/dhall/src/phase/typecheck.rs @@ -605,7 +605,7 @@ fn type_last_layer( ensure_equal!( x.get_type()?, &text_type, - mkerr(InvalidTextInterpolation(x)), + mkerr(InvalidTextInterpolation(x.clone())), ); } } diff --git a/dhall/src/tests.rs b/dhall/src/tests.rs index 76e2e26..f7802e8 100644 --- a/dhall/src/tests.rs +++ b/dhall/src/tests.rs @@ -33,9 +33,12 @@ macro_rules! make_spec_test { }; } +use std::fs::File; +use std::io::Read; +use std::path::PathBuf; + use crate::error::{Error, Result}; use crate::phase::Parsed; -use std::path::PathBuf; #[derive(Copy, Clone)] pub enum Feature { @@ -57,10 +60,6 @@ fn parse_file_str<'i>(file_path: &str) -> Result<Parsed> { Parsed::parse_file(&PathBuf::from(file_path)) } -fn parse_binary_file_str<'i>(file_path: &str) -> Result<Parsed> { - Parsed::parse_binary_file(&PathBuf::from(file_path)) -} - pub fn run_test_stringy_error( base_path: &str, feature: Feature, @@ -101,11 +100,38 @@ pub fn run_test( let expr = parse_file_str(&expr_file_path)?; if let Parser = feature { + // Compare parse/decoded let expected_file_path = base_path + "B.dhallb"; - let expected = parse_binary_file_str(&expected_file_path)?; - + let expected_file_path = PathBuf::from(&expected_file_path); + let mut expected_data = Vec::new(); + { + File::open(&expected_file_path)? + .read_to_end(&mut expected_data)?; + } + let expected = Parsed::parse_binary(&expected_data)?; assert_eq_pretty!(expr, expected); + // Compare encoded/expected + let expr_data = expr.encode()?; + // Compare bit-by-bit + if expr_data != expected_data { + // use std::io::Write; + // File::create(&expected_file_path)?.write_all(&expr_data)?; + // Pretty-print difference + assert_eq_pretty!( + serde_cbor::de::from_slice::<serde_cbor::value::Value>( + &expr_data + ) + .unwrap(), + serde_cbor::de::from_slice::<serde_cbor::value::Value>( + &expected_data + ) + .unwrap() + ); + // If difference was not visible in the cbor::Value + assert_eq!(expr_data, expected_data); + } + // Round-trip pretty-printer let expr_string = expr.to_string(); let expr: Parsed = Parsed::parse_str(&expr_string)?; |