summaryrefslogtreecommitdiff
path: root/backends/lean/Base/Diverge/Elab.lean
blob: 22e0039f8dba48749ac2034e248137ed9d918957 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
import Lean
import Lean.Meta.Tactic.Simp
import Init.Data.List.Basic
import Mathlib.Tactic.RunCmd
import Base.Diverge.Base
import Base.Diverge.ElabBase

namespace Diverge

/- Automating the generation of the encoding and the proofs so as to use nice
   syntactic sugar. -/

/-- The `divergent def` command: the same shape as a regular `def`
(declaration modifiers, name, optional signature, value) prefixed with the
`divergent` keyword. Its syntax kind (`divergentDef`) is matched by the
`macro_rules`/`elab_rules` defined later in this file. -/
syntax (name := divergentDef)
  declModifiers "divergent" "def" declId ppIndent(optDeclSig) declVal : command

open Lean Elab Term Meta Primitives

-- Enable the `Diverge.def` trace class for the rest of this namespace.
set_option trace.Diverge.def true

/- The following was copied from the `wfRecursion` function. -/

-- TODO: replace the recursive calls by a call to the continuation
-- def replace_rec_calls


-- TODO: is there already such a utility somewhere?
-- TODO: change to mkSigmas
/-- Fold a list of expressions into right-nested `Prod.mk` applications:
`[]` gives `Unit.unit`, `[e]` gives `e`, and `[e1, e2, ..., en]` gives
`(e1, (e2, (..., en)))`.

NOTE(review): despite the parameter name `tys`, this builds a tuple *value*
(`Prod.mk`), which matches how it is currently called (on parameter fvars,
not on their types). If the product *type* is intended, use `Prod` and `Unit`
instead - confirm. -/
def mkProds (tys : List Expr) : MetaM Expr :=
  match tys with
  -- `Unit.unit` has no universe parameters, unlike `PUnit.unit.{u}`, so the
  -- constant is well-formed without supplying levels.
  | [] => return mkConst ``Unit.unit
  | [ty] => return ty
  | ty :: tys => do
    let pty ← mkProds tys
    mkAppM ``Prod.mk #[ty, pty]

/-- Elaborate a group (clique) of mutually recursive `divergent` definitions.

Work in progress. Currently this:
- adds every definition as an axiom, then applies its attributes scheduled for
  `AttributeApplicationTime.afterCompilation`;
- computes (and traces) the leading type parameters, the remaining parameters
  and the output type of every definition, failing with
  "Non-uniform polymorphism" if a definition does not have as many leading
  type parameters as the first one;
- traverses every body with `explore_term`;
- calls `addAndCompilePartialRec` on the group.

Throws an error if `preDefs` is empty. -/
def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
  let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value)
  trace[Diverge.def] ("divRecursion: defs: " ++ msg)

  -- CHANGE HERE This function should add definitions with these names/types/values ^^
  -- Temporarily add the predefinitions as axioms
  for preDef in preDefs do
    addAsAxiom preDef

  -- Apply the attributes of each definition whose application time is
  -- `AttributeApplicationTime.afterCompilation`.
  for preDef in preDefs do
    applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation

  -- Retrieve the first definition: its name is used as the namespace for the
  -- definitions common to the group. Report an empty group as an error
  -- instead of panicking on `preDefs[0]!`.
  let some def0 := preDefs[0]?
    | throwError "divRecursion: empty group of definitions"
  let grName := def0.declName
  trace[Diverge.def] "group name: {grName}"

  /- Compute the type of the continuation.

     We do the following
     - we make sure all the definitions have the same universe parameters
       (we can make this more general later)
     - we group all the type parameters together, make sure all the
       definitions have the same type parameters, and enforce
       a uniform polymorphism (we can also lift this later).
       This would require generalizing a bit our indexed fixed point to
       make the output type parametric in the input.
     - we group all the non-type parameters: we parameterize the continuation
       by those
  -/
  -- TODO: not yet compared with the level parameters of the other definitions
  let grLvlParams := def0.levelParams
  trace[Diverge.def] "def0 type: {def0.type}"

  -- Small utility: open the telescope of `ty` and split its binders into the
  -- leading type parameters (binders whose type is a `Sort`), the remaining
  -- parameters, and the output type.
  -- NOTE(review): the returned fvars belong to the `forallTelescope` local
  -- context and escape it here.
  let getTypeParams (ty: Expr) : MetaM (List Expr × List Expr × Expr) :=
    Lean.Meta.forallTelescope ty fun tys out_ty => do
      trace[Diverge.def] "types: {tys}"
      let rec get_params (ls : List Expr) : MetaM (List Expr × List Expr) :=
        match ls with
        | x :: tl => do
          let ty ← Lean.Meta.inferType x
          match ty with
          | .sort _ => do
            let (ty_params, params) ← get_params tl
            return (x :: ty_params, params)
          | _ => do return ([], ls)
        | _ => do return ([], [])
      let (ty_params, params) ← get_params tys.toList
      trace[Diverge.def] "  parameters: {ty_params}; {params}"
      return (ty_params, params, out_ty)
  let (grTyParams, _, _) ← getTypeParams def0.type

  -- Compute the input types and the output types
  let all_tys ← preDefs.mapM fun preDef => do
    let (tyParams, params, ret_ty) ← getTypeParams preDef.type
    -- TODO: this is not complete, there are more checks to perform
    if tyParams.length ≠ grTyParams.length then
      throwError "Non-uniform polymorphism"
    return (params, ret_ty)

  -- TODO: I think there are issues with the free variables
  let (input_tys, output_tys) := List.unzip all_tys.toList
  -- NOTE(review): `input_tys` holds the parameter fvars (values), not their
  -- types, so `mkProds` builds tuples of these fvars.
  let input_tys : List Expr ← liftM (List.mapM mkProds input_tys)

  trace[Diverge.def] "  in/out tys: {input_tys}; {output_tys}"

  -- Compute the names set (not used yet)
  let names := preDefs.map PreDefinition.declName
  let names := HashSet.empty.insertMany names

  -- Traverse the bodies
  for preDef in preDefs do
    trace[Diverge.def] "about to explore: {preDef.declName}"
    explore_term "" preDef.value

  -- TODO: compute the bodies

  -- Process the definitions
  addAndCompilePartialRec preDefs

-- The following function is copy&pasted from Lean.Elab.PreDefinition.Main
-- This is the only part where we make actual changes and hook into the equation compiler.
-- (I've removed all the well-founded stuff to make it easier to read though.)

open private ensureNoUnassignedMVarsAtPreDef betaReduceLetRecApps partitionPreDefs
  addAndCompilePartial addAsAxioms from Lean.Elab.PreDefinition.Main

/-- Add elaborated pre-definitions to the environment.

The pre-definitions go through `ensureNoUnassignedMVarsAtPreDef` and
`betaReduceLetRecApps`, are split into cliques by `partitionPreDefs`, and each
clique is handed to `divRecursion`. If `divRecursion` throws, the exception is
logged and the clique is added as a `partial` definition (with `sorry`) when
its members are all `def`s or all `abbrev`s. If that also fails, the state is
restored and the clique is added as axioms. Any error during this fallback
restores the saved state. -/
def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLCtx {} {} do
  for preDef in preDefs do
    trace[Diverge.elab] "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
  let preDefs ← preDefs.mapM ensureNoUnassignedMVarsAtPreDef
  let preDefs ← betaReduceLetRecApps preDefs
  let cliques := partitionPreDefs preDefs
  for preDefs in cliques do
    trace[Diverge.elab] "{preDefs.map (·.declName)}"
    try
      trace[Diverge.elab] "calling divRecursion"
      withRef (preDefs[0]!.ref) do
        divRecursion preDefs
      trace[Diverge.elab] "divRecursion succeeded"
    catch ex =>
      -- If it failed, we report the error and fall back to a `partial`
      -- definition, or to axioms.
      trace[Diverge.elab] "divRecursion failed"
      logException ex
      let s ← saveState
      try
        -- Written out explicitly: the original one-liner parsed as a nested
        -- lambda, which is equivalent but hard to read.
        let allDefs := preDefs.all fun preDef => preDef.kind == DefKind.def
        let allAbbrevs := preDefs.all fun preDef => preDef.kind == DefKind.abbrev
        if allDefs || allAbbrevs then
          -- try to add as partial definition
          try
            addAndCompilePartial preDefs (useSorry := true)
          catch _ =>
            -- Compilation failed try again just as axiom
            s.restore
            addAsAxioms preDefs
        else return ()
      catch _ => s.restore

-- The following two functions are copy&pasted from Lean.Elab.MutualDef

open private elabHeaders levelMVarToParamHeaders getAllUserLevelNames withFunLocalDecls elabFunValues
  instantiateMVarsAtHeader instantiateMVarsAtLetRecToLift checkLetRecsToLiftTypes withUsed from Lean.Elab.MutualDef

/-- Elaborate a group of definition views: elaborate the headers, then the
bodies with the functions in scope as local declarations, build the
pre-definitions with `MutualClosure.main`, fix their universe levels, and pass
them to `addPreDefinitions`.

If elaborating the bodies throws, the exception is logged and every body is
replaced by a synthetic `sorry`. `vars` is forwarded to `withUsed`. -/
def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do
  let scopeLevelNames ← getLevelNames
  let headers ← elabHeaders views
  let headers ← levelMVarToParamHeaders views headers
  let allUserLevelNames := getAllUserLevelNames headers
  withFunLocalDecls headers fun funFVars => do
    for view in views, funFVar in funFVars do
      addLocalVarInfo view.declId funFVar
    -- Elaborate the bodies; on failure, log and use `sorry` for every body.
    let values ←
      try
        let values ← elabFunValues headers
        Term.synthesizeSyntheticMVarsNoPostponing
        values.mapM (instantiateMVars ·)
      catch ex =>
        logException ex
        headers.mapM fun header => mkSorry header.type (synthetic := true)
    let headers ← headers.mapM instantiateMVarsAtHeader
    let letRecsToLift ← getLetRecsToLift
    let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift
    checkLetRecsToLiftTypes funFVars letRecsToLift
    withUsed vars headers values letRecsToLift fun vars => do
      let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift
      for preDef in preDefs do
        trace[Diverge.elab] "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
      let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamPreDecls preDefs
      let preDefs ← instantiateMVarsAtPreDecls preDefs
      let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames
      -- The trace names the last transformation actually applied above.
      for preDef in preDefs do
        trace[Diverge.elab] "after fixLevelParams, {preDef.declName} : {preDef.type} :=\n{preDef.value}"
      checkForHiddenUnivLevels allUserLevelNames preDefs
      addPreDefinitions preDefs

open Command in
/-- Elaborate the commands of a `mutual ... end` block made only of
`divergent def`s. Each command becomes a `DefView` of kind `DefKind.def`
with no `deriving` clause, and the views are elaborated together by
`Term.elabMutualDef`. Throws `throwUnsupportedSyntax` on any command that is
not a `divergent def`. -/
def Command.elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
  let views ← ds.mapM fun d => do
    let `($mods:declModifiers divergent def $id:declId $sig:optDeclSig $val:declVal) := d
      | throwUnsupportedSyntax
    let modifiers ← elabModifiers mods
    let (binders, type) := expandOptDeclSig sig
    let deriving? := none
    pure { ref := d, kind := DefKind.def, modifiers,
           declId := id, binders, type? := type, value := val, deriving? }
  runTermElabM fun vars => Term.elabMutualDef vars views

-- `_divergent mutual ... end` is an internal command that is elaborated
-- directly by `Command.elabMutualDef`. Routing divergent blocks through it
-- means an error does not make Lean fall back to the built-in `mutual`.
local syntax "_divergent" Parser.Command.mutual : command
elab_rules : command
  | `(_divergent mutual $decls* end) => Command.elabMutualDef decls

-- Redirect a `mutual ... end` block to `_divergent mutual ... end` only when
-- it is non-empty and every command in it is a `divergent def`; otherwise
-- this macro declines with `Macro.throwUnsupported`.
macro_rules
  | `(mutual $decls* end) => do
    let onlyDivergent := decls.all fun decl => decl.raw.getKind == ``divergentDef
    if decls.isEmpty || !onlyDivergent then
      Macro.throwUnsupported
    `(command| _divergent mutual $decls* end)

open private setDeclIdName from Lean.Elab.Declaration

-- A single `divergent def` is elaborated by wrapping it in a one-element
-- `mutual ... end` block. A namespaced name `A.B.f` is split: the definition
-- is renamed to `f` and the block is wrapped in `namespace A.B ... end A.B`.
-- Names starting with `_root_` are rejected (`throwUnsupportedSyntax`).
elab_rules : command
  | `($mods:declModifiers divergent%$tk def $id:declId $sig:optDeclSig $val:declVal) => do
    let (name, _) := expandDeclIdCore id
    if (`_root_).isPrefixOf name then throwUnsupportedSyntax
    let view := extractMacroScopes name
    let .str ns shortName := view.name | throwUnsupportedSyntax
    let shortName' := { view with name := shortName }.review
    let cmd ← `(mutual $mods:declModifiers divergent%$tk def $(⟨setDeclIdName id shortName'⟩):declId $sig:optDeclSig $val:declVal end)
    if ns matches .anonymous then
      Command.elabCommand cmd
    else
      Command.elabCommand <| ← `(namespace $(mkIdentFrom id ns) $cmd end $(mkIdentFrom id ns))

mutual
  divergent def is_even (i : Int) : Result Bool :=
    if i = 0 then return true else return (← is_odd (i - 1))

  divergent def is_odd (i : Int) : Result Bool :=
    if i = 0 then return false else return (← is_even (i - 1))
end

example (i : Int) : is_even i = .ret (i % 2 = 0) ∧ is_odd i = .ret (i % 2 ≠ 0) := by
  induction i
  unfold is_even
  -- `induction` on `Int` leaves more than one goal: close all of them
  -- (the proof itself is still TODO).
  all_goals sorry

divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a :=
  match ls with
  | [] => .fail .panic
  | x :: ls =>
    if i = 0 then return x
    else return (← list_nth ls (i - 1))

mutual
  divergent def foo (i : Int) : Result Nat :=
    if i > 10 then return (← foo (i / 10)) + (← bar i) else bar 10

  divergent def bar (i : Int) : Result Nat :=
    if i > 20 then foo (i / 20) else .ret 42
end

end Diverge