diff options
Diffstat (limited to '')
-rw-r--r-- | backends/lean/Base/Progress/Base.lean | 102 |
1 files changed, 80 insertions, 22 deletions
diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean index 72438d40..3599d866 100644 --- a/backends/lean/Base/Progress/Base.lean +++ b/backends/lean/Base/Progress/Base.lean @@ -62,24 +62,6 @@ section Methods -- Dive into the quantified variables and the assumptions forallTelescope th fun fvars th => do trace[Progress] "Universally quantified arguments and assumptions: {fvars}" - /- -- Filter the argumens which are not propositions - let rec getFirstPropIdx (i : Nat) : MetaM Nat := do - if i ≥ fargs.size then pure i - else do - let x := fargs.get! i - if ← Meta.isProp (← inferType x) then pure i - else getFirstPropIdx (i + 1) - let i ← getFirstPropIdx 0 - let fvars := fargs.extract 0 i - let hyps := fargs.extract i fargs.size - trace[Progress] "Quantified variables: {fvars}" - trace[Progress] "Assumptions: {hyps}" - -- Sanity check: all hypotheses are propositions (in particular, all the - -- quantified variables are at the beginning) - let hypsAreProp ← hyps.allM fun x => do Meta.isProp (← inferType x) - if ¬ hypsAreProp then - throwError "The theorem doesn't have the proper shape: all the quantified arguments should be at the beginning" - -/ -- Dive into the existentials existsTelescope th fun evars th => do trace[Progress] "Existentials: {evars}" @@ -98,6 +80,8 @@ section Methods let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty -- Sanity check if sanityChecks then + -- All the variables which appear in the inputs given to the function are + -- universally quantified (in particular, they are not *existentially* quantified) let fvarsSet : HashSet FVarId := HashSet.ofArray (fvars.map (fun x => x.fvarId!)) let filtArgsFVars := allArgsFVars.toArray.filter (fun fvar => ¬ fvarsSet.contains fvar) if ¬ filtArgsFVars.isEmpty then @@ -132,6 +116,12 @@ def getPSpecClassFunNames (th : Expr) : MetaM (Name × Name) := pure (d.fName, f.constName) ) true +def getPSpecClassFunNameArg (th : Expr) : MetaM (Name × Expr) := + withPSpec th (fun d => do + let arg0 := d.args.get! 0 + pure (d.fName, arg0) + ) true + -- "Regular" pspec attribute structure PSpecAttr where attr : AttributeImpl @@ -145,14 +135,26 @@ structure PSpecAttr where Example: ======== We use type classes for addition. For instance, the addition between two - U32 is written (without syntactic sugar) as `HAdd.add (Scalar ) x y`. As a consequence, + U32 is written (without syntactic sugar) as `HAdd.add (Scalar ty) x y`. As a consequence, we store the theorem through the bindings: HAdd.add → Scalar → ... + + SH: TODO: this (and `PSpecClassExprAttr`) is a bit ad-hoc. For now it works for the + specs of the scalar operations, which is what I really need, but I'm not sure it + applies well to other situations. A better way would probably to use type classes, but + I couldn't get them to work on those cases. It is worth retrying. -/ structure PSpecClassAttr where attr : AttributeImpl ext : MapDeclarationExtension (NameMap Name) deriving Inhabited +/- Same as `PSpecClassAttr` but we use the full first argument (it works when it + is a constant). -/ +structure PSpecClassExprAttr where + attr : AttributeImpl + ext : MapDeclarationExtension (HashMap Expr Name) + deriving Inhabited + -- TODO: the original function doesn't define correctly the `addImportedFn`. Do a PR? def mkMapDeclarationExtension [Inhabited α] (name : Name := by exact decl_name%) : IO (MapDeclarationExtension α) := registerSimplePersistentEnvExtension { @@ -216,21 +218,69 @@ initialize pspecClassAttr : PSpecClassAttr ← do registerBuiltinAttribute attrImpl pure { attr := attrImpl, ext := ext } +/- The 2nd persistent map from type classes to pspec theorems -/ +initialize pspecClassExprAttr : PSpecClassExprAttr ← do + let ext : MapDeclarationExtension (HashMap Expr Name) ← mkMapDeclarationExtension `pspecClassExprMap + let attrImpl : AttributeImpl := { + name := `cepspec + descr := "Marks theorems to use for type classes with the `progress` tactic" + add := fun thName stx attrKind => do + Attribute.Builtin.ensureNoArgs stx + -- TODO: use the attribute kind + unless attrKind == AttributeKind.global do + throwError "invalid attribute 'cpspec', must be global" + -- Lookup the theorem + let env ← getEnv + let thDecl := env.constants.find! thName + let (fName, arg) ← MetaM.run' (getPSpecClassFunNameArg thDecl.type) + -- Sanity check: no variables appear in the argument + MetaM.run' do + let fvars ← getFVarIds arg + if ¬ fvars.isEmpty then throwError "The first argument ({arg}) contains variables" + -- We store two bindings: + -- - arg to theorem name + -- - reduced arg to theorem name + let rarg ← MetaM.run' (reduce arg) + trace[Progress] "Registering class spec theorem for ({fName}, {arg}) and ({fName}, {rarg})" + -- Update the entry if there is one, add an entry if there is none + let env := + match (ext.getState (← getEnv)).find? fName with + | none => + let m := HashMap.ofList [(arg, thName), (rarg, thName)] + ext.addEntry env (fName, m) + | some m => + let m := m.insert arg thName + let m := m.insert rarg thName + ext.addEntry env (fName, m) + setEnv env + pure () + } + registerBuiltinAttribute attrImpl + pure { attr := attrImpl, ext := ext } + def PSpecAttr.find? (s : PSpecAttr) (name : Name) : MetaM (Option Name) := do return (s.ext.getState (← getEnv)).find? name -def PSpecAttr.getState (s : PSpecAttr) : MetaM (NameMap Name) := do - pure (s.ext.getState (← getEnv)) - def PSpecClassAttr.find? (s : PSpecClassAttr) (className argName : Name) : MetaM (Option Name) := do match (s.ext.getState (← getEnv)).find? className with | none => return none | some map => return map.find? argName +def PSpecClassExprAttr.find? (s : PSpecClassExprAttr) (className : Name) (arg : Expr) : MetaM (Option Name) := do + match (s.ext.getState (← getEnv)).find? className with + | none => return none + | some map => return map.find? arg + +def PSpecAttr.getState (s : PSpecAttr) : MetaM (NameMap Name) := do + pure (s.ext.getState (← getEnv)) + def PSpecClassAttr.getState (s : PSpecClassAttr) : MetaM (NameMap (NameMap Name)) := do pure (s.ext.getState (← getEnv)) +def PSpecClassExprAttr.getState (s : PSpecClassExprAttr) : MetaM (NameMap (HashMap Expr Name)) := do + pure (s.ext.getState (← getEnv)) + def showStoredPSpec : MetaM Unit := do let st ← pspecAttr.getState let s := st.toList.foldl (fun s (f, th) => f!"{s}\n{f} → {th}") f!"" @@ -244,4 +294,12 @@ def showStoredPSpecClass : MetaM Unit := do f!"{s}\n{f} → [{ms}]") f!"" IO.println s +def showStoredPSpecExprClass : MetaM Unit := do + let st ← pspecClassExprAttr.getState + let s := st.toList.foldl (fun s (f, m) => + let ms := m.toList.foldl (fun s (f, th) => + f!"{s}\n {f} → {th}") f!"" + f!"{s}\n{f} → [{ms}]") f!"" + IO.println s + end Progress |