summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-12-12 17:59:12 +0100
committerSon Ho2023-12-12 17:59:12 +0100
commit91f5cd49660b5f012a2faeaf00c49455c548734a (patch)
treee20d07c395f75abd01bdeb7366f4cb79d15d8781
parent698f631e7addb92eb270a75607f1f6ffd8b2414f (diff)
Fix a minor issue with the divergent encoding
Diffstat (limited to '')
-rw-r--r--backends/lean/Base/Diverge/Elab.lean133
1 files changed, 80 insertions, 53 deletions
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean
index e9544f38..ff680c07 100644
--- a/backends/lean/Base/Diverge/Elab.lean
+++ b/backends/lean/Base/Diverge/Elab.lean
@@ -496,9 +496,24 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
-- Not a recursive call: do nothing
pure e
| .const name _ =>
- -- Sanity check: we eliminated all the recursive calls
- if (nameToInfo.find? name).isSome then
- throwError "mkUnaryBodies: a recursive call was not eliminated"
+ /- This might refer to the one of the top-level functions if we use
+ it without arguments (if we give it to a higher-order
+ function for instance) and there are actually no type parameters.
+ -/
+ if (nameToInfo.find? name).isSome then do
+ -- Checking the type information
+ match nameToInfo.find? name with
+ | none => pure e
+ | some (id, type_info) =>
+ trace[Diverge.def.genBody.visit] "this is a recursive call"
+ -- This is a recursive call: replace it
+ -- Compute the index
+ let i ← mkFinVal grSize id
+ -- Check that there are no type parameters
+ if type_info.num_params ≠ 0 then throwError "mkUnaryBodies: a recursive call was not eliminated"
+ -- Introduce the call to the continuation
+ let param_args ← mkSigmasVal type_info.params_ty []
+ mkAppM' kk_var #[i, param_args]
else pure e
| _ => pure e
trace[Diverge.def.genBody.visit] "done with expression (depth: {i}): {e}"
@@ -1436,57 +1451,11 @@ elab_rules : command
namespace Tests
/- Some examples of partial functions -/
- section HigherOrder
- open Ex8
-
- inductive Tree (a : Type u) :=
- | leaf (x : a)
- | node (tl : List (Tree a))
-
- divergent def id {a : Type u} (t : Tree a) : Result (Tree a) :=
- match t with
- | .leaf x => .ret (.leaf x)
- | .node tl =>
- do
- let tl ← map id tl
- .ret (.node tl)
-
- #check id.unfold
-
- divergent def id1 {a : Type u} (t : Tree a) : Result (Tree a) :=
- match t with
- | .leaf x => .ret (.leaf x)
- | .node tl =>
- do
- let tl ← map (fun x => id1 x) tl
- .ret (.node tl)
-
- #check id1.unfold
-
- divergent def id2 {a : Type u} (t : Tree a) : Result (Tree a) :=
- match t with
- | .leaf x => .ret (.leaf x)
- | .node tl =>
- do
- let tl ← map (fun x => do let _ ← id2 x; id2 x) tl
- .ret (.node tl)
-
- #check id2.unfold
-
- /-set_option trace.Diverge.def true
- -- set_option trace.Diverge.def.genBody true
- set_option trace.Diverge.def.valid true
- divergent def incr (t : Tree Nat) : Result (Tree Nat) :=
- match t with
- | .leaf x => .ret (.leaf (x + 1))
- | .node tl =>
- do
- let tl ← map incr tl
- .ret (.node tl)
-
- set_option trace.Diverge.def false-/
- end HigherOrder
+ --set_option trace.Diverge.def true
+ --set_option trace.Diverge.def.genBody true
+ --set_option trace.Diverge.def.valid true
+ --set_option trace.Diverge.def.genBody.visit true
divergent def list_nth {a: Type u} (ls : List a) (i : Int) : Result a :=
match ls with
@@ -1495,6 +1464,8 @@ namespace Tests
if i = 0 then return x
else return (← list_nth ls (i - 1))
+ --set_option trace.Diverge false
+
#check list_nth.unfold
example {a: Type} (ls : List a) :
@@ -1590,6 +1561,62 @@ namespace Tests
#check test1.unfold
+ /- Tests with higher-order functions -/
+ section HigherOrder
+ open Ex8
+
+ inductive Tree (a : Type u) :=
+ | leaf (x : a)
+ | node (tl : List (Tree a))
+
+ divergent def id {a : Type u} (t : Tree a) : Result (Tree a) :=
+ match t with
+ | .leaf x => .ret (.leaf x)
+ | .node tl =>
+ do
+ let tl ← map id tl
+ .ret (.node tl)
+
+ #check id.unfold
+
+ divergent def id1 {a : Type u} (t : Tree a) : Result (Tree a) :=
+ match t with
+ | .leaf x => .ret (.leaf x)
+ | .node tl =>
+ do
+ let tl ← map (fun x => id1 x) tl
+ .ret (.node tl)
+
+ #check id1.unfold
+
+ divergent def id2 {a : Type u} (t : Tree a) : Result (Tree a) :=
+ match t with
+ | .leaf x => .ret (.leaf x)
+ | .node tl =>
+ do
+ let tl ← map (fun x => do let _ ← id2 x; id2 x) tl
+ .ret (.node tl)
+
+ #check id2.unfold
+
+ --set_option trace.Diverge.def true
+ --set_option trace.Diverge.def.genBody true
+ --set_option trace.Diverge.def.valid true
+ --set_option trace.Diverge.def.genBody.visit true
+ divergent def incr (t : Tree Nat) : Result (Tree Nat) :=
+ match t with
+ | .leaf x => .ret (.leaf (x + 1))
+ | .node tl =>
+ do
+ let tl ← map incr tl
+ .ret (.node tl)
+
+ --set_option trace.Diverge false
+
+ #check incr.unfold
+
+ end HigherOrder
+
end Tests
end Diverge