summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--dhall_syntax/src/core.rs15
-rw-r--r--dhall_syntax/src/printer.rs51
2 files changed, 47 insertions, 19 deletions
diff --git a/dhall_syntax/src/core.rs b/dhall_syntax/src/core.rs
index c8a2425..389f037 100644
--- a/dhall_syntax/src/core.rs
+++ b/dhall_syntax/src/core.rs
@@ -60,7 +60,7 @@ pub enum Const {
/// The `Int` field is a DeBruijn index.
/// See dhall-lang/standard/semantics.md for details
#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct Var<Label>(pub Label, pub usize);
+pub struct Var<VarLabel>(pub VarLabel, pub usize);
// Definition order must match precedence order for
// pretty-printing to work correctly
@@ -529,6 +529,19 @@ impl<'a> From<&'a Label> for Var<Label> {
}
}
+/// Trait for things that capture a label used for variables.
+/// Allows normalization to be generic in whether to alpha-normalize or not.
+pub trait VarLabel: std::fmt::Display + Clone {
+ /// Is `self` the default variable (i.e. "_") ?
+ fn is_underscore_var(&self) -> bool;
+}
+
+impl VarLabel for Label {
+ fn is_underscore_var(&self) -> bool {
+ &String::from(self) == "_"
+ }
+}
+
/// `shift` is used by both normalization and type-checking to avoid variable
/// capture by shifting variable indices
/// See https://github.com/dhall-lang/dhall-lang/blob/master/standard/semantics.md#shift
diff --git a/dhall_syntax/src/printer.rs b/dhall_syntax/src/printer.rs
index 9cc1b46..6ebd537 100644
--- a/dhall_syntax/src/printer.rs
+++ b/dhall_syntax/src/printer.rs
@@ -3,8 +3,11 @@ use itertools::Itertools;
use std::fmt::{self, Display};
/// Generic instance that delegates to subexpressions
-impl<SE: Display + Clone, L: Display + Clone, E: Display> Display
- for ExprF<SE, L, E>
+impl<SE, L, E> Display for ExprF<SE, L, E>
+where
+ SE: Display + Clone,
+ L: VarLabel,
+ E: Display,
{
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
use crate::ExprF::*;
@@ -15,10 +18,9 @@ impl<SE: Display + Clone, L: Display + Clone, E: Display> Display
BoolIf(a, b, c) => {
write!(f, "if {} then {} else {}", a, b, c)?;
}
- // TODO: arrow type
- // Pi(a, b, c) if &String::from(a) == "_" => {
- // write!(f, "{} → {}", b, c)?;
- // }
+ Pi(a, b, c) if a.is_underscore_var() => {
+ write!(f, "{} → {}", b, c)?;
+ }
Pi(a, b, c) => {
write!(f, "∀({} : {}) → {}", a, b, c)?;
}
@@ -129,23 +131,34 @@ enum PrintPhase {
#[derive(Clone)]
struct PhasedExpr<'a, L, S, A>(&'a SubExpr<L, S, A>, PrintPhase);
-impl<'a, L: Display + Clone, S: Clone, A: Display + Clone> Display
- for PhasedExpr<'a, L, S, A>
+impl<'a, L, S, A> Display for PhasedExpr<'a, L, S, A>
+where
+ L: VarLabel,
+ S: Clone,
+ A: Display + Clone,
{
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
self.0.as_ref().fmt_phase(f, self.1)
}
}
-impl<'a, L: Display + Clone, S: Clone, A: Display + Clone>
- PhasedExpr<'a, L, S, A>
+impl<'a, L, S, A> PhasedExpr<'a, L, S, A>
+where
+ L: VarLabel,
+ S: Clone,
+ A: Display + Clone,
{
fn phase(self, phase: PrintPhase) -> PhasedExpr<'a, L, S, A> {
PhasedExpr(self.0, phase)
}
}
-impl<L: Display + Clone, S: Clone, A: Display + Clone> Expr<L, S, A> {
+impl<L, S, A> Expr<L, S, A>
+where
+ L: VarLabel,
+ S: Clone,
+ A: Display + Clone,
+{
fn fmt_phase(
&self,
f: &mut fmt::Formatter,
@@ -179,12 +192,11 @@ impl<L: Display + Clone, S: Clone, A: Display + Clone> Expr<L, S, A> {
// Annotate subexpressions with the appropriate phase, defaulting to Base
let phased_self = match self.map_ref_simple(|e| PhasedExpr(e, Base)) {
Pi(a, b, c) => {
- // TODO: arrow type
- // if &String::from(&a) == "_" {
- // Pi(a, b.phase(Operator), c)
- // } else {
+ if a.is_underscore_var() {
+ Pi(a, b.phase(Operator), c)
+ } else {
Pi(a, b, c)
- // }
+ }
}
Merge(a, b, c) => Merge(
a.phase(Import),
@@ -220,8 +232,11 @@ impl<L: Display + Clone, S: Clone, A: Display + Clone> Expr<L, S, A> {
}
}
-impl<L: Display + Clone, S: Clone, A: Display + Clone> Display
- for SubExpr<L, S, A>
+impl<L, S, A> Display for SubExpr<L, S, A>
+where
+ L: VarLabel,
+ S: Clone,
+ A: Display + Clone,
{
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
self.as_ref().fmt_phase(f, PrintPhase::Base)