From 153cf8dab3b80aba30ac3adfd44e4be251494ea2 Mon Sep 17 00:00:00 2001
From: Nadrieril
Date: Sat, 4 May 2019 18:37:15 +0200
Subject: Recover arrow type detection

---
 dhall_syntax/src/core.rs    | 15 ++++++++++++-
 dhall_syntax/src/printer.rs | 51 +++++++++++++++++++++++++++++----------------
 2 files changed, 47 insertions(+), 19 deletions(-)

(limited to 'dhall_syntax')

diff --git a/dhall_syntax/src/core.rs b/dhall_syntax/src/core.rs
index c8a2425..389f037 100644
--- a/dhall_syntax/src/core.rs
+++ b/dhall_syntax/src/core.rs
@@ -60,7 +60,7 @@ pub enum Const {
 /// The `Int` field is a DeBruijn index.
 /// See dhall-lang/standard/semantics.md for details
 #[derive(Debug, Clone, PartialEq, Eq)]
-pub struct Var<Label>(pub Label, pub usize);
+pub struct Var<VarLabel>(pub VarLabel, pub usize);
 
 // Definition order must match precedence order for
 // pretty-printing to work correctly
@@ -529,6 +529,19 @@ impl<'a> From<&'a Label> for Var<Label> {
     }
 }
 
+/// Trait for things that capture a label used for variables.
+/// Allows normalization to be generic in whether to alpha-normalize or not.
+pub trait VarLabel: std::fmt::Display + Clone {
+    /// Is `self` the default variable (i.e. "_") ?
+    fn is_underscore_var(&self) -> bool;
+}
+
+impl VarLabel for Label {
+    fn is_underscore_var(&self) -> bool {
+        &String::from(self) == "_"
+    }
+}
+
 /// `shift` is used by both normalization and type-checking to avoid variable
 /// capture by shifting variable indices
 /// See https://github.com/dhall-lang/dhall-lang/blob/master/standard/semantics.md#shift
diff --git a/dhall_syntax/src/printer.rs b/dhall_syntax/src/printer.rs
index 9cc1b46..6ebd537 100644
--- a/dhall_syntax/src/printer.rs
+++ b/dhall_syntax/src/printer.rs
@@ -3,8 +3,11 @@ use itertools::Itertools;
 use std::fmt::{self, Display};
 
 /// Generic instance that delegates to subexpressions
-impl<SE: Display + Clone, L: Display + Clone, E: Display> Display
-    for ExprF<SE, L, E>
+impl<SE, L, E> Display for ExprF<SE, L, E>
+where
+    SE: Display + Clone,
+    L: VarLabel,
+    E: Display,
 {
     fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
         use crate::ExprF::*;
@@ -15,10 +18,9 @@ impl<SE: Display + Clone, L: Display + Clone, E: Display> Display
             BoolIf(a, b, c) => {
                 write!(f, "if {} then {} else {}", a, b, c)?;
             }
-            // TODO: arrow type
-            // Pi(a, b, c) if &String::from(a) == "_" => {
-            //     write!(f, "{} → {}", b, c)?;
-            // }
+            Pi(a, b, c) if a.is_underscore_var() => {
+                write!(f, "{} → {}", b, c)?;
+            }
             Pi(a, b, c) => {
                 write!(f, "∀({} : {}) → {}", a, b, c)?;
             }
@@ -129,23 +131,34 @@ enum PrintPhase {
 #[derive(Clone)]
 struct PhasedExpr<'a, L, S, A>(&'a SubExpr<L, S, A>, PrintPhase);
 
-impl<'a, L: Display + Clone, S: Clone, A: Display + Clone> Display
-    for PhasedExpr<'a, L, S, A>
+impl<'a, L, S, A> Display for PhasedExpr<'a, L, S, A>
+where
+    L: VarLabel,
+    S: Clone,
+    A: Display + Clone,
 {
     fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
         self.0.as_ref().fmt_phase(f, self.1)
     }
 }
 
-impl<'a, L: Display + Clone, S: Clone, A: Display + Clone>
-    PhasedExpr<'a, L, S, A>
+impl<'a, L, S, A> PhasedExpr<'a, L, S, A>
+where
+    L: VarLabel,
+    S: Clone,
+    A: Display + Clone,
 {
     fn phase(self, phase: PrintPhase) -> PhasedExpr<'a, L, S, A> {
         PhasedExpr(self.0, phase)
     }
 }
 
-impl<L: Display + Clone, S: Clone, A: Display + Clone> Expr<L, S, A> {
+impl<L, S, A> Expr<L, S, A>
+where
+    L: VarLabel,
+    S: Clone,
+    A: Display + Clone,
+{
     fn fmt_phase(
         &self,
         f: &mut fmt::Formatter,
@@ -179,12 +192,11 @@ impl<L: Display + Clone, S: Clone, A: Display + Clone> Expr<L, S, A> {
         // Annotate subexpressions with the appropriate phase, defaulting to Base
         let phased_self = match self.map_ref_simple(|e| PhasedExpr(e, Base)) {
             Pi(a, b, c) => {
-                // TODO: arrow type
-                // if &String::from(&a) == "_" {
-                //     Pi(a, b.phase(Operator), c)
-                // } else {
+                if a.is_underscore_var() {
+                    Pi(a, b.phase(Operator), c)
+                } else {
                 Pi(a, b, c)
-                // }
+                }
             }
             Merge(a, b, c) => Merge(
                 a.phase(Import),
@@ -220,8 +232,11 @@ impl<L: Display + Clone, S: Clone, A: Display + Clone> Expr<L, S, A> {
     }
 }
 
-impl<L: Display + Clone, S: Clone, A: Display + Clone> Display
-    for SubExpr<L, S, A>
+impl<L, S, A> Display for SubExpr<L, S, A>
+where
+    L: VarLabel,
+    S: Clone,
+    A: Display + Clone,
 {
     fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
         self.as_ref().fmt_phase(f, PrintPhase::Base)
-- 
cgit v1.2.3