summaryrefslogtreecommitdiff
path: root/backends/lean/Base/Progress/Base.lean
blob: 76a92795419988b464ebcee5f512367e2f758dce (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
import Lean
import Std.Lean.HashSet
import Base.Utils
import Base.Primitives.Base

namespace Progress

open Lean Elab Term Meta
open Utils

-- Register the `Progress` trace class, used by the `trace[Progress]` messages of this file.
-- We can't define and use trace classes in the same file
initialize registerTraceClass `Progress

/- # Progress tactic -/

/- Decomposition of a pspec theorem, as computed by `withPSpec`. -/
structure PSpecDesc where
  -- The universally quantified variables (this includes the assumptions)
  fvars : Array Expr
  -- The existentially quantified variables
  evars : Array Expr
  -- The function call (the function applied to its arguments)
  fExpr : Expr
  -- The name of the function (the head constant of `fExpr`)
  fName : Name
  -- The universe levels of the function constant
  fLevels : List Level
  -- The function arguments
  args : Array Expr
  -- The universally quantified variables which appear in the function arguments
  argsFVars : Array FVarId
  -- The returned value
  ret : Expr
  -- The postcondition (the conjuncts after the equality), if there is one
  post : Option Expr

section Methods
  variable [MonadLiftT MetaM m] [MonadControlT MetaM m] [Monad m] [MonadOptions m]
  variable [MonadTrace m] [MonadLiftT IO m] [MonadRef m] [AddMessageContext m]
  variable [MonadError m]
  variable {a : Type}

  /- Analyze a pspec theorem to decompose its arguments.

     PSpec theorems should be of the following shape:
     ```
     ∀ x1 ... xn, H1 → ... Hn → ∃ y1 ... ym. f x1 ... xn = .ret ... ∧ Post1 ∧ ... ∧ Postk
     ```

     The continuation `k` receives a `PSpecDesc` which gathers:
     - the universally quantified variables and the assumptions (`fvars`)
     - the existentially quantified variables (`evars`)
     - the function call, the function name and its universe levels
       (`fExpr`, `fName`, `fLevels`)
     - the function arguments (`args`) and the universally quantified variables
       which appear in them (`argsFVars`)
     - the returned value (`ret`)
     - the postcondition, if there is one (`post`)

     If the left-hand side of the equality is an application of `Bind.bind`, we
     analyze the bound expression (the argument at index 4 of `Bind.bind`) instead:
     this is for when `withPSpec` is used inside of the `progress` tactic.

     Throws an error if the head of the function call is not a constant. If
     `sanityChecks` is `true`, also throws an error if some of the free variables
     appearing in the function arguments are not among `fvars` (in particular,
     if they are existentially quantified).

     TODO: generalize for when we do inductive proofs
  -/
  partial
  def withPSpec [Inhabited (m a)] [Nonempty (m a)] (th : Expr) (k : PSpecDesc → m a)
    (sanityChecks : Bool := false) :
    m a := do
    trace[Progress] "Proposition: {th}"
    -- Dive into the quantified variables and the assumptions
    forallTelescope th.consumeMData fun fvars th => do
    trace[Progress] "Universally quantified arguments and assumptions: {fvars}"
    -- Dive into the existentials
    existsTelescope th.consumeMData fun evars th => do
    trace[Progress] "Existentials: {evars}"
    trace[Progress] "Proposition after stripping the quantifiers: {th}"
    -- Take the first conjunct
    let (th, post) ← optSplitConj th.consumeMData
    trace[Progress] "After splitting the conjunction:\n- eq: {th}\n- post: {post}"
    -- Destruct the equality
    let (mExpr, ret) ← destEq th.consumeMData
    trace[Progress] "After splitting the equality:\n- lhs: {mExpr}\n- rhs: {ret}"
    -- Destruct the monadic application to dive into the bind, if necessary (this
    -- is for when we use `withPSpec` inside of the `progress` tactic), and
    -- destruct the application to get the function name
    mExpr.consumeMData.withApp fun mf margs => do
    trace[Progress] "After stripping the arguments of the monad expression:\n- mf: {mf}\n- margs: {margs}"
    let (fExpr, f, args) ← do
      if mf.isConst ∧ mf.constName = ``Bind.bind then do
        -- Dive into the bind: the argument at index 4 is the bound expression
        let fExpr := (margs.get! 4).consumeMData
        fExpr.withApp fun f args => pure (fExpr, f, args)
      else pure (mExpr, mf, margs)
    trace[Progress] "After stripping the arguments of the function call:\n- f: {f}\n- args: {args}"
    if ¬ f.isConst then throwError "Not a constant: {f}"
    -- Compute the set of universally quantified variables which appear in the function arguments
    let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty
    -- Sanity check
    if sanityChecks then
      -- All the variables which appear in the inputs given to the function are
      -- universally quantified (in particular, they are not *existentially* quantified)
      let fvarsSet : HashSet FVarId := HashSet.ofArray (fvars.map (fun x => x.fvarId!))
      let filtArgsFVars := allArgsFVars.toArray.filter (fun fvar => ¬ fvarsSet.contains fvar)
      if ¬ filtArgsFVars.isEmpty then
        let filtArgsFVars := filtArgsFVars.map (fun fvarId => Expr.fvar fvarId)
        throwError "Some of the function inputs are not universally quantified: {filtArgsFVars}"
    -- Keep the universally quantified variables which appear in the arguments,
    -- in the order in which they are quantified
    let argsFVars := fvars.map (fun x => x.fvarId!)
    let argsFVars := argsFVars.filter (fun fvar => allArgsFVars.contains fvar)
    -- Return
    trace[Progress] "Function: {f.constName!}"
    let thDesc : PSpecDesc := {
      fvars := fvars
      evars := evars
      fExpr
      fName := f.constName!
      fLevels := f.constLevels!
      args := args
      argsFVars
      ret := ret
      post := post
    }
    k thDesc

end Methods

/- Retrieve the name of the function a pspec theorem is about (with the sanity checks enabled). -/
def getPSpecFunName (th : Expr) : MetaM Name :=
  withPSpec th (sanityChecks := true) fun desc =>
    pure desc.fName

/- Retrieve, from a class pspec theorem, the name of the function and the name of
   the head constant of its first argument (with the sanity checks enabled).

   Throws an error if the function has no arguments, or if the head of the first
   argument is not a constant. -/
def getPSpecClassFunNames (th : Expr) : MetaM (Name × Name) :=
  withPSpec th (fun d => do
    -- Fail with an explicit error (rather than panicking) if there is no argument
    let some arg0 := d.args[0]? | throwError "The function {d.fName} has no arguments"
    arg0.withApp fun f _ => do
    if ¬ f.isConst then throwError "Not a constant: {f}"
    pure (d.fName, f.constName)
    ) true

/- Retrieve, from a class pspec theorem, the name of the function and its (full)
   first argument (with the sanity checks enabled).

   Throws an error if the function has no arguments: otherwise a default
   expression would be returned and used as a lookup key. -/
def getPSpecClassFunNameArg (th : Expr) : MetaM (Name × Expr) :=
  withPSpec th (fun d => do
    let some arg0 := d.args[0]? | throwError "The function {d.fName} has no arguments"
    pure (d.fName, arg0)
    ) true

-- "Regular" pspec attribute
structure PSpecAttr where
  -- The attribute implementation
  attr : AttributeImpl
  -- The persistent map from function names to the names of their pspec theorems
  ext  : MapDeclarationExtension Name
  deriving Inhabited

/- pspec attribute for type classes: we use the name of the class function
   (e.g., `HAdd.add`) to look up another map. We use the head constant of the
   *first* argument of this function to look up into this second map.

   Example:
   ========
   We use type classes for addition. For instance, the addition between two
   U32 is written (without syntactic sugar) as `HAdd.add (Scalar ty) x y`. As a consequence,
   we store the theorem through the bindings: HAdd.add → Scalar → ...

   SH: TODO: this (and `PSpecClassExprAttr`) is a bit ad-hoc. For now it works for the
   specs of the scalar operations, which is what I really need, but I'm not sure it
   applies well to other situations. A better way would probably be to use type classes, but
   I couldn't get them to work on those cases. It is worth retrying.
-/
structure PSpecClassAttr where
  -- The attribute implementation
  attr : AttributeImpl
  -- Function name ↦ (head constant of the first argument ↦ theorem name)
  ext  : MapDeclarationExtension (NameMap Name)
  deriving Inhabited

/- Same as `PSpecClassAttr` but we use the full first argument (it works when it
   is a constant, i.e., when it contains no free variables: this is checked when
   registering a theorem). -/
structure PSpecClassExprAttr where
  -- The attribute implementation
  attr : AttributeImpl
  -- Function name ↦ (first argument ↦ theorem name)
  ext  : MapDeclarationExtension (HashMap Expr Name)
  deriving Inhabited

-- Variant of Lean's `mkMapDeclarationExtension` whose `addImportedFn` merges the
-- entries of all the imported modules into the state (an entry processed later
-- overrides an earlier one with the same key), so that the state also contains
-- the imported bindings.
-- TODO: the original function doesn't define correctly the `addImportedFn`. Do a PR?
def mkMapDeclarationExtension [Inhabited α] (name : Name := by exact decl_name%) :
  IO (MapDeclarationExtension α) :=
  registerSimplePersistentEnvExtension {
    name          := name
    addImportedFn := fun modules => Id.run do
      let mut acc : NameMap α := RBMap.empty
      for entries in modules do
        for (key, val) in entries do
          acc := acc.insert key val
      return acc
    addEntryFn    := fun st entry => st.insert entry.1 entry.2
    toArrayFn     := fun entries => entries.toArray.qsort (fun x y => Name.quickLt x.1 y.1)
  }

-- Declare an extension of maps of maps (using [RBMap]).
-- The important point is that we need to merge the bound values (which are maps):
-- when several imported entries share the same outer key, their inner maps are
-- merged, the bindings of the entry processed last taking precedence.
-- `ord` is the comparison function of the inner maps.
def mkMapMapDeclarationExtension [Inhabited β] (ord : α → α → Ordering)
  (name : Name := by exact decl_name%) :
  IO (MapDeclarationExtension (RBMap α β ord)) :=
  registerSimplePersistentEnvExtension {
    name          := name,
    addImportedFn := fun a =>
      a.foldl (fun s a => a.foldl (
        -- We need to merge the maps
        fun s (k0, k1_to_v) =>
        match s.find? k0 with
        | none =>
          -- No binding: insert one
          s.insert k0 k1_to_v
        | some m =>
          -- There is already a binding: merge (the new bindings override the old ones)
          let m := RBMap.fold (fun m k v => m.insert k v) m k1_to_v
          s.insert k0 m)
          s) RBMap.empty,
    addEntryFn    := fun s n => s.insert n.1 n.2 ,
    toArrayFn     := fun es => es.toArray.qsort (fun a b => Name.quickLt a.1 b.1)
  }

-- Declare an extension of maps of maps (using [HashMap]).
-- The important point is that we need to merge the bound values (which are maps):
-- when several imported entries share the same outer key, their inner maps are
-- merged, the bindings of the entry processed last taking precedence.
def mkMapHashMapDeclarationExtension [BEq α] [Hashable α] [Inhabited β]
  (name : Name := by exact decl_name%) :
  IO (MapDeclarationExtension (HashMap α β)) :=
  registerSimplePersistentEnvExtension {
    name          := name
    addImportedFn := fun modules => Id.run do
      let mut acc : NameMap (HashMap α β) := RBMap.empty
      for entries in modules do
        for (k0, inner) in entries do
          -- Merge `inner` into the map already bound to `k0`, if there is one
          let merged := match acc.find? k0 with
            | none => inner
            | some prev => HashMap.fold (fun m k v => m.insert k v) prev inner
          acc := acc.insert k0 merged
      return acc
    addEntryFn    := fun st entry => st.insert entry.1 entry.2
    toArrayFn     := fun entries => entries.toArray.qsort (fun x y => Name.quickLt x.1 y.1)
  }

/- The persistent map from function to pspec theorems.
   Registers the `pspec` attribute: tagging a theorem with it binds the function
   the theorem is about to the theorem name. -/
initialize pspecAttr : PSpecAttr ← do
  let ext ← mkMapDeclarationExtension `pspecMap
  let attrImpl : AttributeImpl := {
    name := `pspec
    descr := "Marks theorems to use with the `progress` tactic"
    add := fun thName stx attrKind => do
      Attribute.Builtin.ensureNoArgs stx
      -- TODO: use the attribute kind
      unless attrKind == AttributeKind.global do
        throwError "invalid attribute 'pspec', must be global"
      -- Lookup the theorem
      let env ← getEnv
      let thDecl := env.constants.find! thName
      -- Analyze the theorem statement to retrieve the function name
      let fName ← MetaM.run' (getPSpecFunName thDecl.type)
      trace[Progress] "Registering spec theorem for {fName}"
      let env := ext.addEntry env (fName, thName)
      setEnv env
      pure ()
  }
  registerBuiltinAttribute attrImpl
  pure { attr := attrImpl, ext := ext }

/- The persistent map from type classes to pspec theorems.
   Registers the `cpspec` attribute: tagging a theorem with it binds
   (function name, head constant of the first argument) to the theorem name. -/
initialize pspecClassAttr : PSpecClassAttr ← do
  let ext : MapDeclarationExtension (NameMap Name) ←
    mkMapMapDeclarationExtension Name.quickCmp `pspecClassMap
  let attrImpl : AttributeImpl := {
    name := `cpspec
    descr := "Marks theorems to use for type classes with the `progress` tactic"
    add := fun thName stx attrKind => do
      Attribute.Builtin.ensureNoArgs stx
      -- TODO: use the attribute kind
      unless attrKind == AttributeKind.global do
        throwError "invalid attribute 'cpspec', must be global"
      -- Lookup the theorem
      let env ← getEnv
      let thDecl := env.constants.find! thName
      let (fName, argName) ← MetaM.run' (getPSpecClassFunNames thDecl.type)
      trace[Progress] "Registering class spec theorem for ({fName}, {argName})"
      -- Update the entry if there is one, add an entry if there is none.
      -- The added entry is the whole inner map, so it must include the previous bindings.
      let env :=
        match (ext.getState env).find? fName with
        | none =>
          let m := RBMap.ofList [(argName, thName)]
          ext.addEntry env (fName, m)
        | some m =>
          let m := m.insert argName thName
          ext.addEntry env (fName, m)
      setEnv env
      pure ()
  }
  registerBuiltinAttribute attrImpl
  pure { attr := attrImpl, ext := ext }

/- The 2nd persistent map from type classes to pspec theorems.
   Registers the `cepspec` attribute: tagging a theorem with it binds
   (function name, first argument) to the theorem name, where the first argument
   must not contain free variables. -/
initialize pspecClassExprAttr : PSpecClassExprAttr ← do
  let ext : MapDeclarationExtension (HashMap Expr Name) ←
    mkMapHashMapDeclarationExtension `pspecClassExprMap
  let attrImpl : AttributeImpl := {
    name := `cepspec
    descr := "Marks theorems to use for type classes with the `progress` tactic"
    add := fun thName stx attrKind => do
      Attribute.Builtin.ensureNoArgs stx
      -- TODO: use the attribute kind
      unless attrKind == AttributeKind.global do
        throwError "invalid attribute 'cepspec', must be global"
      -- Lookup the theorem
      let env ← getEnv
      let thDecl := env.constants.find! thName
      let (fName, arg) ← MetaM.run' (getPSpecClassFunNameArg thDecl.type)
      -- Sanity check: no variables appear in the argument
      MetaM.run' do
        let fvars ← getFVarIds arg
        if ¬ fvars.isEmpty then throwError "The first argument ({arg}) contains variables"
      -- We store two bindings:
      -- - arg to theorem name
      -- - reduced arg to theorem name
      let rarg ← MetaM.run' (reduceAll arg)
      trace[Progress] "Registering class spec theorem for ({fName}, {arg}) and ({fName}, {rarg})"
      -- Update the entry if there is one, add an entry if there is none.
      -- The added entry is the whole inner map, so it must include the previous bindings.
      let env :=
        match (ext.getState env).find? fName with
        | none =>
          let m := HashMap.ofList [(arg, thName), (rarg, thName)]
          ext.addEntry env (fName, m)
        | some m =>
          let m := m.insert arg thName
          let m := m.insert rarg thName
          ext.addEntry env (fName, m)
      setEnv env
      pure ()
  }
  registerBuiltinAttribute attrImpl
  pure { attr := attrImpl, ext := ext }


/- Look up the pspec theorem registered for the function `name`, if any. -/
def PSpecAttr.find? (s : PSpecAttr) (name : Name) : MetaM (Option Name) := do
  return (s.ext.getState (← getEnv)).find? name

/- Look up the class pspec theorem registered for the function `className` when
   the head constant of its first argument is `argName`, if any. -/
def PSpecClassAttr.find? (s : PSpecClassAttr) (className argName : Name) : MetaM (Option Name) := do
  match (s.ext.getState (← getEnv)).find? className with
  | none => return none
  | some map => return map.find? argName

/- Look up the class pspec theorem registered for the function `className` when
   its first argument is `arg` (compared with the `BEq Expr` instance), if any. -/
def PSpecClassExprAttr.find? (s : PSpecClassExprAttr) (className : Name) (arg : Expr) : MetaM (Option Name) := do
  match (s.ext.getState (← getEnv)).find? className with
  | none => return none
  | some map => return map.find? arg

/- Return the whole map from function names to pspec theorem names. -/
def PSpecAttr.getState (s : PSpecAttr) : MetaM (NameMap Name) := do
  pure (s.ext.getState (← getEnv))

/- Return the whole map: function name ↦ (head constant of the first argument ↦ theorem name). -/
def PSpecClassAttr.getState (s : PSpecClassAttr) : MetaM (NameMap (NameMap Name)) := do
  pure (s.ext.getState (← getEnv))

/- Return the whole map: function name ↦ (first argument ↦ theorem name). -/
def PSpecClassExprAttr.getState (s : PSpecClassExprAttr) : MetaM (NameMap (HashMap Expr Name)) := do
  pure (s.ext.getState (← getEnv))

/- Print the bindings (function → theorem) stored by the `pspec` attribute. -/
def showStoredPSpec : MetaM Unit := do
  let st ← pspecAttr.getState
  let s := st.toList.foldl (fun s (f, th) => f!"{s}\n{f} → {th}") f!""
  IO.println s

/- Print the bindings (function → [argument head constant → theorem]) stored by
   the `cpspec` attribute. -/
def showStoredPSpecClass : MetaM Unit := do
  let st ← pspecClassAttr.getState
  let s := st.toList.foldl (fun s (f, m) =>
    let ms := m.toList.foldl (fun s (arg, th) =>
      f!"{s}\n  {arg} → {th}") f!""
    f!"{s}\n{f} → [{ms}]") f!""
  IO.println s

/- Print the bindings (function → [argument → theorem]) stored by the `cepspec`
   attribute. -/
def showStoredPSpecExprClass : MetaM Unit := do
  let st ← pspecClassExprAttr.getState
  let s := st.toList.foldl (fun s (f, m) =>
    -- `arg` is an expression (the first argument), not a function name
    let ms := m.toList.foldl (fun s (arg, th) =>
      f!"{s}\n  {arg} → {th}") f!""
    f!"{s}\n{f} → [{ms}]") f!""
  IO.println s

end Progress