diff options
-rw-r--r-- | dhall_syntax/src/core.rs | 15 | ||||
-rw-r--r-- | dhall_syntax/src/printer.rs | 51 |
2 files changed, 47 insertions, 19 deletions
diff --git a/dhall_syntax/src/core.rs b/dhall_syntax/src/core.rs index c8a2425..389f037 100644 --- a/dhall_syntax/src/core.rs +++ b/dhall_syntax/src/core.rs @@ -60,7 +60,7 @@ pub enum Const { /// The `Int` field is a DeBruijn index. /// See dhall-lang/standard/semantics.md for details #[derive(Debug, Clone, PartialEq, Eq)] -pub struct Var<Label>(pub Label, pub usize); +pub struct Var<VarLabel>(pub VarLabel, pub usize); // Definition order must match precedence order for // pretty-printing to work correctly @@ -529,6 +529,19 @@ impl<'a> From<&'a Label> for Var<Label> { } } +/// Trait for things that capture a label used for variables. +/// Allows normalization to be generic in whether to alpha-normalize or not. +pub trait VarLabel: std::fmt::Display + Clone { + /// Is `self` the default variable (i.e. "_") ? + fn is_underscore_var(&self) -> bool; +} + +impl VarLabel for Label { + fn is_underscore_var(&self) -> bool { + &String::from(self) == "_" + } +} + /// `shift` is used by both normalization and type-checking to avoid variable /// capture by shifting variable indices /// See https://github.com/dhall-lang/dhall-lang/blob/master/standard/semantics.md#shift diff --git a/dhall_syntax/src/printer.rs b/dhall_syntax/src/printer.rs index 9cc1b46..6ebd537 100644 --- a/dhall_syntax/src/printer.rs +++ b/dhall_syntax/src/printer.rs @@ -3,8 +3,11 @@ use itertools::Itertools; use std::fmt::{self, Display}; /// Generic instance that delegates to subexpressions -impl<SE: Display + Clone, L: Display + Clone, E: Display> Display - for ExprF<SE, L, E> +impl<SE, L, E> Display for ExprF<SE, L, E> +where + SE: Display + Clone, + L: VarLabel, + E: Display, { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { use crate::ExprF::*; @@ -15,10 +18,9 @@ impl<SE: Display + Clone, L: Display + Clone, E: Display> Display BoolIf(a, b, c) => { write!(f, "if {} then {} else {}", a, b, c)?; } - // TODO: arrow type - // Pi(a, b, c) if &String::from(a) == "_" => { - // write!(f, "{} → {}", b, c)?; - // } + Pi(a, b, c) if a.is_underscore_var() => { + write!(f, "{} → {}", b, c)?; + } Pi(a, b, c) => { write!(f, "∀({} : {}) → {}", a, b, c)?; } @@ -129,23 +131,34 @@ enum PrintPhase { #[derive(Clone)] struct PhasedExpr<'a, L, S, A>(&'a SubExpr<L, S, A>, PrintPhase); -impl<'a, L: Display + Clone, S: Clone, A: Display + Clone> Display - for PhasedExpr<'a, L, S, A> +impl<'a, L, S, A> Display for PhasedExpr<'a, L, S, A> +where + L: VarLabel, + S: Clone, + A: Display + Clone, { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { self.0.as_ref().fmt_phase(f, self.1) } } -impl<'a, L: Display + Clone, S: Clone, A: Display + Clone> - PhasedExpr<'a, L, S, A> +impl<'a, L, S, A> PhasedExpr<'a, L, S, A> +where + L: VarLabel, + S: Clone, + A: Display + Clone, { fn phase(self, phase: PrintPhase) -> PhasedExpr<'a, L, S, A> { PhasedExpr(self.0, phase) } } -impl<L: Display + Clone, S: Clone, A: Display + Clone> Expr<L, S, A> { +impl<L, S, A> Expr<L, S, A> +where + L: VarLabel, + S: Clone, + A: Display + Clone, +{ fn fmt_phase( &self, f: &mut fmt::Formatter, @@ -179,12 +192,11 @@ impl<L: Display + Clone, S: Clone, A: Display + Clone> Expr<L, S, A> { // Annotate subexpressions with the appropriate phase, defaulting to Base let phased_self = match self.map_ref_simple(|e| PhasedExpr(e, Base)) { Pi(a, b, c) => { - // TODO: arrow type - // if &String::from(&a) == "_" { - // Pi(a, b.phase(Operator), c) - // } else { + if a.is_underscore_var() { + Pi(a, b.phase(Operator), c) + } else { Pi(a, b, c) - // } + } } Merge(a, b, c) => Merge( a.phase(Import), @@ -220,8 +232,11 @@ impl<L: Display + Clone, S: Clone, A: Display + Clone> Expr<L, S, A> { } } -impl<L: Display + Clone, S: Clone, A: Display + Clone> Display - for SubExpr<L, S, A> +impl<L, S, A> Display for SubExpr<L, S, A> +where + L: VarLabel, + S: Clone, + A: Display + Clone, { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { self.as_ref().fmt_phase(f, PrintPhase::Base) |