diff options
author | Son Ho | 2023-07-17 23:37:31 +0200 |
---|---|---|
committer | Son Ho | 2023-07-17 23:40:38 +0200 |
commit | 3e8060b5501ec83940a4309389a68898df26ebd0 (patch) | |
tree | b02399a2137e8bbe54c181def92b5b43d5e42cf5 | |
parent | 510e409b551f876a28f93f869e108f3f9e761212 (diff) |
Reorganize the Lean backend
Diffstat (limited to '')
-rw-r--r-- | backends/lean/Base/Arith.lean | 3 | ||||
-rw-r--r-- | backends/lean/Base/Arith/Arith.lean | 329 | ||||
-rw-r--r-- | backends/lean/Base/Arith/Int.lean | 236 | ||||
-rw-r--r-- | backends/lean/Base/Arith/Scalar.lean | 48 | ||||
-rw-r--r-- | backends/lean/Base/IList.lean | 127 | ||||
-rw-r--r-- | backends/lean/Base/IList/IList.lean | 142 | ||||
-rw-r--r-- | backends/lean/Base/Primitives.lean | 718 | ||||
-rw-r--r-- | backends/lean/Base/Primitives/Base.lean | 130 | ||||
-rw-r--r-- | backends/lean/Base/Primitives/Scalar.lean | 507 | ||||
-rw-r--r-- | backends/lean/Base/Primitives/Vec.lean | 113 | ||||
-rw-r--r-- | backends/lean/Base/Progress/Progress.lean | 2 |
11 files changed, 1184 insertions, 1171 deletions
diff --git a/backends/lean/Base/Arith.lean b/backends/lean/Base/Arith.lean index fd5698c5..c0d09fd2 100644 --- a/backends/lean/Base/Arith.lean +++ b/backends/lean/Base/Arith.lean @@ -1 +1,2 @@ -import Base.Arith.Arith +import Base.Arith.Int +import Base.Arith.Scalar diff --git a/backends/lean/Base/Arith/Arith.lean b/backends/lean/Base/Arith/Arith.lean index da263e86..e69de29b 100644 --- a/backends/lean/Base/Arith/Arith.lean +++ b/backends/lean/Base/Arith/Arith.lean @@ -1,329 +0,0 @@ -/- This file contains tactics to solve arithmetic goals -/ - -import Lean -import Lean.Meta.Tactic.Simp -import Init.Data.List.Basic -import Mathlib.Tactic.RunCmd -import Mathlib.Tactic.Linarith --- TODO: there is no Omega tactic for now - it seems it hasn't been ported yet ---import Mathlib.Tactic.Omega -import Base.Primitives -import Base.Utils -import Base.Arith.Base - -namespace Arith - -open Primitives Utils - --- TODO: move -/- Remark: we can't write the following instance because of restrictions about - the type class parameters (`ty` doesn't appear in the return type, which is - forbidden): - - ``` - instance Scalar.cast (ty : ScalarTy) : Coe (Scalar ty) Int where coe := λ v => v.val - ``` - -/ -def Scalar.toInt {ty : ScalarTy} (x : Scalar ty) : Int := x.val - --- Remark: I tried a version of the shape `HasScalarProp {a : Type} (x : a)` --- but the lookup didn't work -class HasScalarProp (a : Sort u) where - prop_ty : a → Prop - prop : ∀ x:a, prop_ty x - -class HasIntProp (a : Sort u) where - prop_ty : a → Prop - prop : ∀ x:a, prop_ty x - -instance (ty : ScalarTy) : HasScalarProp (Scalar ty) where - -- prop_ty is inferred - prop := λ x => And.intro x.hmin x.hmax - -instance (a : Type) : HasScalarProp (Vec a) where - prop_ty := λ v => v.val.length ≤ Scalar.max ScalarTy.Usize - prop := λ ⟨ _, l ⟩ => l - -class PropHasImp (x : Prop) where - concl : Prop - prop : x → concl - --- This also works for `x ≠ y` because this expression reduces to `¬ x = y` --- and `Ne` is marked as `reducible` -instance (x y : Int) : PropHasImp (¬ x = y) where - concl := x < y ∨ x > y - prop := λ (h:x ≠ y) => ne_is_lt_or_gt h - -open Lean Lean.Elab Command Term Lean.Meta - --- Small utility: print all the declarations in the context -elab "print_all_decls" : tactic => do - let ctx ← Lean.MonadLCtx.getLCtx - for decl in ← ctx.getDecls do - let ty ← Lean.Meta.inferType decl.toExpr - logInfo m!"{decl.toExpr} : {ty}" - pure () - --- Explore a term by decomposing the applications (we explore the applied --- functions and their arguments, but ignore lambdas, forall, etc. - --- should we go inside?). -partial def foldTermApps (k : α → Expr → MetaM α) (s : α) (e : Expr) : MetaM α := do - -- We do it in a very simpler manner: we deconstruct applications, - -- and recursively explore the sub-expressions. Note that we do - -- not go inside foralls and abstractions (should we?). - e.withApp fun f args => do - let s ← k s f - args.foldlM (foldTermApps k) s - --- Provided a function `k` which lookups type class instances on an expression, --- collect all the instances lookuped by applying `k` on the sub-expressions of `e`. -def collectInstances - (k : Expr → MetaM (Option Expr)) (s : HashSet Expr) (e : Expr) : MetaM (HashSet Expr) := do - let k s e := do - match ← k e with - | none => pure s - | some i => pure (s.insert i) - foldTermApps k s e - --- Similar to `collectInstances`, but explores all the local declarations in the --- main context. -def collectInstancesFromMainCtx (k : Expr → MetaM (Option Expr)) : Tactic.TacticM (HashSet Expr) := do - Tactic.withMainContext do - -- Get the local context - let ctx ← Lean.MonadLCtx.getLCtx - -- Just a matter of precaution - let ctx ← instantiateLCtxMVars ctx - -- Initialize the hashset - let hs := HashSet.empty - -- Explore the declarations - let decls ← ctx.getDecls - decls.foldlM (fun hs d => collectInstances k hs d.toExpr) hs - --- Helper -def lookupProp (fName : String) (className : Name) (e : Expr) : MetaM (Option Expr) := do - trace[Arith] fName - -- TODO: do we need Lean.observing? - -- This actually eliminates the error messages - Lean.observing? do - trace[Arith] m!"{fName}: observing" - let ty ← Lean.Meta.inferType e - let hasProp ← mkAppM className #[ty] - let hasPropInst ← trySynthInstance hasProp - match hasPropInst with - | LOption.some i => - trace[Arith] "Found HasScalarProp instance" - let i_prop ← mkProjection i (Name.mkSimple "prop") - some (← mkAppM' i_prop #[e]) - | _ => none - --- Return an instance of `HasIntProp` for `e` if it has some -def lookupHasIntProp (e : Expr) : MetaM (Option Expr) := - lookupProp "lookupHasScalarProp" ``HasIntProp e - --- Return an instance of `HasScalarProp` for `e` if it has some -def lookupHasScalarProp (e : Expr) : MetaM (Option Expr) := - lookupProp "lookupHasScalarProp" ``HasScalarProp e - --- Collect the instances of `HasIntProp` for the subexpressions in the context -def collectHasIntPropInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do - collectInstancesFromMainCtx lookupHasIntProp - --- Collect the instances of `HasScalarProp` for the subexpressions in the context -def collectHasScalarPropInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do - collectInstancesFromMainCtx lookupHasScalarProp - -elab "display_has_prop_instances" : tactic => do - trace[Arith] "Displaying the HasScalarProp instances" - let hs ← collectHasScalarPropInstancesFromMainCtx - hs.forM fun e => do - trace[Arith] "+ HasScalarProp instance: {e}" - -example (x : U32) : True := by - let i : HasScalarProp U32 := inferInstance - have p := @HasScalarProp.prop _ i x - simp only [HasScalarProp.prop_ty] at p - display_has_prop_instances - simp - --- Return an instance of `PropHasImp` for `e` if it has some -def lookupPropHasImp (e : Expr) : MetaM (Option Expr) := do - trace[Arith] "lookupPropHasImp" - -- TODO: do we need Lean.observing? - -- This actually eliminates the error messages - Lean.observing? do - trace[Arith] "lookupPropHasImp: observing" - let ty ← Lean.Meta.inferType e - trace[Arith] "lookupPropHasImp: ty: {ty}" - let cl ← mkAppM ``PropHasImp #[ty] - let inst ← trySynthInstance cl - match inst with - | LOption.some i => - trace[Arith] "Found PropHasImp instance" - let i_prop ← mkProjection i (Name.mkSimple "prop") - some (← mkAppM' i_prop #[e]) - | _ => none - --- Collect the instances of `PropHasImp` for the subexpressions in the context -def collectPropHasImpInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do - collectInstancesFromMainCtx lookupPropHasImp - -elab "display_prop_has_imp_instances" : tactic => do - trace[Arith] "Displaying the PropHasImp instances" - let hs ← collectPropHasImpInstancesFromMainCtx - hs.forM fun e => do - trace[Arith] "+ PropHasImp instance: {e}" - -example (x y : Int) (_ : x ≠ y) (_ : ¬ x = y) : True := by - display_prop_has_imp_instances - simp - --- Lookup instances in a context and introduce them with additional declarations. -def introInstances (declToUnfold : Name) (lookup : Expr → MetaM (Option Expr)) : Tactic.TacticM (Array Expr) := do - let hs ← collectInstancesFromMainCtx lookup - hs.toArray.mapM fun e => do - let type ← inferType e - let name ← mkFreshUserName `h - -- Add a declaration - let nval ← Utils.addDeclTac name e type (asLet := false) - -- Simplify to unfold the declaration to unfold (i.e., the projector) - Utils.simpAt [declToUnfold] [] [] (Tactic.Location.targets #[mkIdent name] false) - -- Return the new value - pure nval - -def introHasIntPropInstances : Tactic.TacticM (Array Expr) := do - trace[Arith] "Introducing the HasIntProp instances" - introInstances ``HasIntProp.prop_ty lookupHasIntProp - -def introHasScalarPropInstances : Tactic.TacticM (Array Expr) := do - trace[Arith] "Introducing the HasScalarProp instances" - introInstances ``HasScalarProp.prop_ty lookupHasScalarProp - --- Lookup the instances of `HasScalarProp for all the sub-expressions in the context, --- and introduce the corresponding assumptions -elab "intro_has_prop_instances" : tactic => do - let _ ← introHasScalarPropInstances - -example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by - intro_has_prop_instances - simp [*] - -example {a: Type} (v : Vec a) : v.val.length ≤ Scalar.max ScalarTy.Usize := by - intro_has_prop_instances - simp_all [Scalar.max, Scalar.min] - --- Lookup the instances of `PropHasImp for all the sub-expressions in the context, --- and introduce the corresponding assumptions -elab "intro_prop_has_imp_instances" : tactic => do - trace[Arith] "Introducing the PropHasImp instances" - let _ ← introInstances ``PropHasImp.concl lookupPropHasImp - -example (x y : Int) (h0 : x ≤ y) (h1 : x ≠ y) : x < y := by - intro_prop_has_imp_instances - rename_i h - split_disj h - . linarith - . linarith - -/- Boosting a bit the linarith tac. - - We do the following: - - for all the assumptions of the shape `(x : Int) ≠ y` or `¬ (x = y), we - introduce two goals with the assumptions `x < y` and `x > y` - TODO: we could create a PR for mathlib. - -/ -def intTacPreprocess : Tactic.TacticM Unit := do - Tactic.withMainContext do - -- Lookup the instances of PropHasImp (this is how we detect assumptions - -- of the proper shape), introduce assumptions in the context and split - -- on those - -- TODO: get rid of the assumptions that we split - let rec splitOnAsms (asms : List Expr) : Tactic.TacticM Unit := - match asms with - | [] => pure () - | asm :: asms => - let k := splitOnAsms asms - Utils.splitDisjTac asm k k - -- Introduce - let _ ← introHasIntPropInstances - let asms ← introInstances ``PropHasImp.concl lookupPropHasImp - -- Split - splitOnAsms asms.toList - -elab "int_tac_preprocess" : tactic => - intTacPreprocess - -def intTac : Tactic.TacticM Unit := do - Tactic.withMainContext do - Tactic.focus do - -- Preprocess - wondering if we should do this before or after splitting - -- the goal. I think before leads to a smaller proof term? - Tactic.allGoals intTacPreprocess - -- Split the conjunctions in the goal - Utils.repeatTac Utils.splitConjTarget - -- Call linarith - let linarith := - let cfg : Linarith.LinarithConfig := { - -- We do this with our custom preprocessing - splitNe := false - } - Tactic.liftMetaFinishingTactic <| Linarith.linarith false [] cfg - Tactic.allGoals linarith - -elab "int_tac" : tactic => - intTac - -example (x : Int) (h0: 0 ≤ x) (h1: x ≠ 0) : 0 < x := by - int_tac_preprocess - linarith - linarith - -example (x : Int) (h0: 0 ≤ x) (h1: x ≠ 0) : 0 < x := by - int_tac - --- Checking that things append correctly when there are several disjunctions -example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y := by - int_tac - --- Checking that things append correctly when there are several disjunctions -example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y ∧ x + y ≥ 2 := by - int_tac - -def scalarTacPreprocess (tac : Tactic.TacticM Unit) : Tactic.TacticM Unit := do - Tactic.withMainContext do - -- Introduce the scalar bounds - let _ ← introHasScalarPropInstances - Tactic.allGoals do - -- Inroduce the bounds for the isize/usize types - let add (e : Expr) : Tactic.TacticM Unit := do - let ty ← inferType e - let _ ← Utils.addDeclTac (← mkFreshUserName `h) e ty (asLet := false) - add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []]) - add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []]) - add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []]) - -- Reveal the concrete bounds - Utils.simpAt [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax, - ``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min, - ``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max, - ``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min, - ``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max - ] [] [] .wildcard - -- Finish the proof - tac - -elab "scalar_tac_preprocess" : tactic => - scalarTacPreprocess intTacPreprocess - --- A tactic to solve linear arithmetic goals in the presence of scalars -def scalarTac : Tactic.TacticM Unit := do - scalarTacPreprocess intTac - -elab "scalar_tac" : tactic => - scalarTac - -example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by - scalar_tac - -example {a: Type} (v : Vec a) : v.val.length ≤ Scalar.max ScalarTy.Usize := by - scalar_tac - -end Arith diff --git a/backends/lean/Base/Arith/Int.lean b/backends/lean/Base/Arith/Int.lean new file mode 100644 index 00000000..5f00ab52 --- /dev/null +++ b/backends/lean/Base/Arith/Int.lean @@ -0,0 +1,236 @@ +/- This file contains tactics to solve arithmetic goals -/ + +import Lean +import Lean.Meta.Tactic.Simp +import Init.Data.List.Basic +import Mathlib.Tactic.RunCmd +import Mathlib.Tactic.Linarith +-- TODO: there is no Omega tactic for now - it seems it hasn't been ported yet +--import Mathlib.Tactic.Omega +import Base.Utils +import Base.Arith.Base + +namespace Arith + +open Utils + +-- Remark: I tried a version of the shape `HasScalarProp {a : Type} (x : a)` +-- but the lookup didn't work +class HasIntProp (a : Sort u) where + prop_ty : a → Prop + prop : ∀ x:a, prop_ty x + +class PropHasImp (x : Prop) where + concl : Prop + prop : x → concl + +-- This also works for `x ≠ y` because this expression reduces to `¬ x = y` +-- and `Ne` is marked as `reducible` +instance (x y : Int) : PropHasImp (¬ x = y) where + concl := x < y ∨ x > y + prop := λ (h:x ≠ y) => ne_is_lt_or_gt h + +open Lean Lean.Elab Lean.Meta + +-- Small utility: print all the declarations in the context +elab "print_all_decls" : tactic => do + let ctx ← Lean.MonadLCtx.getLCtx + for decl in ← ctx.getDecls do + let ty ← Lean.Meta.inferType decl.toExpr + logInfo m!"{decl.toExpr} : {ty}" + pure () + +-- Explore a term by decomposing the applications (we explore the applied +-- functions and their arguments, but ignore lambdas, forall, etc. - +-- should we go inside?). +partial def foldTermApps (k : α → Expr → MetaM α) (s : α) (e : Expr) : MetaM α := do + -- We do it in a very simpler manner: we deconstruct applications, + -- and recursively explore the sub-expressions. Note that we do + -- not go inside foralls and abstractions (should we?). + e.withApp fun f args => do + let s ← k s f + args.foldlM (foldTermApps k) s + +-- Provided a function `k` which lookups type class instances on an expression, +-- collect all the instances lookuped by applying `k` on the sub-expressions of `e`. +def collectInstances + (k : Expr → MetaM (Option Expr)) (s : HashSet Expr) (e : Expr) : MetaM (HashSet Expr) := do + let k s e := do + match ← k e with + | none => pure s + | some i => pure (s.insert i) + foldTermApps k s e + +-- Similar to `collectInstances`, but explores all the local declarations in the +-- main context. +def collectInstancesFromMainCtx (k : Expr → MetaM (Option Expr)) : Tactic.TacticM (HashSet Expr) := do + Tactic.withMainContext do + -- Get the local context + let ctx ← Lean.MonadLCtx.getLCtx + -- Just a matter of precaution + let ctx ← instantiateLCtxMVars ctx + -- Initialize the hashset + let hs := HashSet.empty + -- Explore the declarations + let decls ← ctx.getDecls + decls.foldlM (fun hs d => collectInstances k hs d.toExpr) hs + +-- Helper +def lookupProp (fName : String) (className : Name) (e : Expr) : MetaM (Option Expr) := do + trace[Arith] fName + -- TODO: do we need Lean.observing? + -- This actually eliminates the error messages + Lean.observing? do + trace[Arith] m!"{fName}: observing" + let ty ← Lean.Meta.inferType e + let hasProp ← mkAppM className #[ty] + let hasPropInst ← trySynthInstance hasProp + match hasPropInst with + | LOption.some i => + trace[Arith] "Found {fName} instance" + let i_prop ← mkProjection i (Name.mkSimple "prop") + some (← mkAppM' i_prop #[e]) + | _ => none + +-- Return an instance of `HasIntProp` for `e` if it has some +def lookupHasIntProp (e : Expr) : MetaM (Option Expr) := + lookupProp "lookupHasIntProp" ``HasIntProp e + +-- Collect the instances of `HasIntProp` for the subexpressions in the context +def collectHasIntPropInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do + collectInstancesFromMainCtx lookupHasIntProp + +-- Return an instance of `PropHasImp` for `e` if it has some +def lookupPropHasImp (e : Expr) : MetaM (Option Expr) := do + trace[Arith] "lookupPropHasImp" + -- TODO: do we need Lean.observing? + -- This actually eliminates the error messages + Lean.observing? do + trace[Arith] "lookupPropHasImp: observing" + let ty ← Lean.Meta.inferType e + trace[Arith] "lookupPropHasImp: ty: {ty}" + let cl ← mkAppM ``PropHasImp #[ty] + let inst ← trySynthInstance cl + match inst with + | LOption.some i => + trace[Arith] "Found PropHasImp instance" + let i_prop ← mkProjection i (Name.mkSimple "prop") + some (← mkAppM' i_prop #[e]) + | _ => none + +-- Collect the instances of `PropHasImp` for the subexpressions in the context +def collectPropHasImpInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do + collectInstancesFromMainCtx lookupPropHasImp + +elab "display_prop_has_imp_instances" : tactic => do + trace[Arith] "Displaying the PropHasImp instances" + let hs ← collectPropHasImpInstancesFromMainCtx + hs.forM fun e => do + trace[Arith] "+ PropHasImp instance: {e}" + +example (x y : Int) (_ : x ≠ y) (_ : ¬ x = y) : True := by + display_prop_has_imp_instances + simp + +-- Lookup instances in a context and introduce them with additional declarations. +def introInstances (declToUnfold : Name) (lookup : Expr → MetaM (Option Expr)) : Tactic.TacticM (Array Expr) := do + let hs ← collectInstancesFromMainCtx lookup + hs.toArray.mapM fun e => do + let type ← inferType e + let name ← mkFreshUserName `h + -- Add a declaration + let nval ← Utils.addDeclTac name e type (asLet := false) + -- Simplify to unfold the declaration to unfold (i.e., the projector) + Utils.simpAt [declToUnfold] [] [] (Tactic.Location.targets #[mkIdent name] false) + -- Return the new value + pure nval + +def introHasIntPropInstances : Tactic.TacticM (Array Expr) := do + trace[Arith] "Introducing the HasIntProp instances" + introInstances ``HasIntProp.prop_ty lookupHasIntProp + +-- Lookup the instances of `HasIntProp for all the sub-expressions in the context, +-- and introduce the corresponding assumptions +elab "intro_has_int_prop_instances" : tactic => do + let _ ← introHasIntPropInstances + +-- Lookup the instances of `PropHasImp for all the sub-expressions in the context, +-- and introduce the corresponding assumptions +elab "intro_prop_has_imp_instances" : tactic => do + trace[Arith] "Introducing the PropHasImp instances" + let _ ← introInstances ``PropHasImp.concl lookupPropHasImp + +example (x y : Int) (h0 : x ≤ y) (h1 : x ≠ y) : x < y := by + intro_prop_has_imp_instances + rename_i h + split_disj h + . linarith + . linarith + +/- Boosting a bit the linarith tac. + + We do the following: + - for all the assumptions of the shape `(x : Int) ≠ y` or `¬ (x = y), we + introduce two goals with the assumptions `x < y` and `x > y` + TODO: we could create a PR for mathlib. + -/ +def intTacPreprocess (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM Unit := do + Tactic.withMainContext do + -- Lookup the instances of PropHasImp (this is how we detect assumptions + -- of the proper shape), introduce assumptions in the context and split + -- on those + -- TODO: get rid of the assumptions that we split + let rec splitOnAsms (asms : List Expr) : Tactic.TacticM Unit := + match asms with + | [] => pure () + | asm :: asms => + let k := splitOnAsms asms + Utils.splitDisjTac asm k k + -- Introduce the scalar bounds + let _ ← introHasIntPropInstances + -- Extra preprocessing, before we split on the disjunctions + extraPreprocess + -- Split + let asms ← introInstances ``PropHasImp.concl lookupPropHasImp + splitOnAsms asms.toList + +elab "int_tac_preprocess" : tactic => + intTacPreprocess (do pure ()) + +def intTac (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM Unit := do + Tactic.withMainContext do + Tactic.focus do + -- Preprocess - wondering if we should do this before or after splitting + -- the goal. I think before leads to a smaller proof term? + Tactic.allGoals (intTacPreprocess extraPreprocess) + -- Split the conjunctions in the goal + Tactic.allGoals (Utils.repeatTac Utils.splitConjTarget) + -- Call linarith + let linarith := + let cfg : Linarith.LinarithConfig := { + -- We do this with our custom preprocessing + splitNe := false + } + Tactic.liftMetaFinishingTactic <| Linarith.linarith false [] cfg + Tactic.allGoals linarith + +elab "int_tac" : tactic => + intTac (do pure ()) + +example (x : Int) (h0: 0 ≤ x) (h1: x ≠ 0) : 0 < x := by + int_tac_preprocess + linarith + linarith + +example (x : Int) (h0: 0 ≤ x) (h1: x ≠ 0) : 0 < x := by + int_tac + +-- Checking that things append correctly when there are several disjunctions +example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y := by + int_tac + +-- Checking that things append correctly when there are several disjunctions +example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y ∧ x + y ≥ 2 := by + int_tac + +end Arith diff --git a/backends/lean/Base/Arith/Scalar.lean b/backends/lean/Base/Arith/Scalar.lean new file mode 100644 index 00000000..f8903ecf --- /dev/null +++ b/backends/lean/Base/Arith/Scalar.lean @@ -0,0 +1,48 @@ +import Base.Arith.Int +import Base.Primitives.Scalar + +/- Automation for scalars - TODO: not sure it is worth having two files (Int.lean and Scalar.lean) -/ +namespace Arith + +open Lean Lean.Elab Lean.Meta +open Primitives + +def scalarTacExtraPreprocess : Tactic.TacticM Unit := do + Tactic.withMainContext do + -- Inroduce the bounds for the isize/usize types + let add (e : Expr) : Tactic.TacticM Unit := do + let ty ← inferType e + let _ ← Utils.addDeclTac (← mkFreshUserName `h) e ty (asLet := false) + add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []]) + add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []]) + add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []]) + -- Reveal the concrete bounds + Utils.simpAt [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax, + ``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min, + ``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max, + ``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min, + ``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max + ] [] [] .wildcard + +elab "scalar_tac_preprocess" : tactic => + intTacPreprocess scalarTacExtraPreprocess + +-- A tactic to solve linear arithmetic goals in the presence of scalars +def scalarTac : Tactic.TacticM Unit := do + intTac scalarTacExtraPreprocess + +elab "scalar_tac" : tactic => + scalarTac + +instance (ty : ScalarTy) : HasIntProp (Scalar ty) where + -- prop_ty is inferred + prop := λ x => And.intro x.hmin x.hmax + +example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by + intro_has_int_prop_instances + simp [*] + +example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by + scalar_tac + +end Arith diff --git a/backends/lean/Base/IList.lean b/backends/lean/Base/IList.lean index 3db00cbb..31b66ffa 100644 --- a/backends/lean/Base/IList.lean +++ b/backends/lean/Base/IList.lean @@ -1,126 +1 @@ -/- Complementary list functions and lemmas which operate on integers rather - than natural numbers. -/ - -import Std.Data.Int.Lemmas -import Mathlib.Tactic.Linarith -import Base.Arith - -namespace List - -def len (ls : List α) : Int := - match ls with - | [] => 0 - | _ :: tl => 1 + len tl - --- Remark: if i < 0, then the result is none -def optIndex (i : Int) (ls : List α) : Option α := - match ls with - | [] => none - | hd :: tl => if i = 0 then some hd else optIndex (i - 1) tl - --- Remark: if i < 0, then the result is the defaul element -def index [Inhabited α] (i : Int) (ls : List α) : α := - match ls with - | [] => Inhabited.default - | x :: tl => - if i = 0 then x else index (i - 1) tl - --- Remark: the list is unchanged if the index is not in bounds (in particular --- if it is < 0) -def update (ls : List α) (i : Int) (y : α) : List α := - match ls with - | [] => [] - | x :: tl => if i = 0 then y :: tl else x :: update tl (i - 1) y - --- Remark: the whole list is dropped if the index is not in bounds (in particular --- if it is < 0) -def idrop (i : Int) (ls : List α) : List α := - match ls with - | [] => [] - | x :: tl => if i = 0 then x :: tl else idrop (i - 1) tl - -@[simp] theorem len_nil : len ([] : List α) = 0 := by simp [len] -@[simp] theorem len_cons : len ((x :: tl) : List α) = 1 + len tl := by simp [len] - -@[simp] theorem index_zero_cons [Inhabited α] : index 0 ((x :: tl) : List α) = x := by simp [index] -@[simp] theorem index_nzero_cons [Inhabited α] (hne : i ≠ 0) : index i ((x :: tl) : List α) = index (i - 1) tl := by simp [*, index] - -@[simp] theorem update_nil : update ([] : List α) i y = [] := by simp [update] -@[simp] theorem update_zero_cons : update ((x :: tl) : List α) 0 y = y :: tl := by simp [update] -@[simp] theorem update_nzero_cons (hne : i ≠ 0) : update ((x :: tl) : List α) i y = x :: update tl (i - 1) y := by simp [*, update] - -@[simp] theorem idrop_nil : idrop i ([] : List α) = [] := by simp [idrop] -@[simp] theorem idrop_zero : idrop 0 (ls : List α) = ls := by cases ls <;> simp [idrop] -@[simp] theorem idrop_nzero_cons (hne : i ≠ 0) : idrop i ((x :: tl) : List α) = idrop (i - 1) tl := by simp [*, idrop] - -theorem len_eq_length (ls : List α) : ls.len = ls.length := by - induction ls - . rfl - . simp [*, Int.ofNat_succ, Int.add_comm] - -theorem len_pos : 0 ≤ (ls : List α).len := by - induction ls <;> simp [*] - linarith - -instance (a : Type u) : Arith.HasIntProp (List a) where - prop_ty := λ ls => 0 ≤ ls.len - prop := λ ls => ls.len_pos - -@[simp] theorem len_append (l1 l2 : List α) : (l1 ++ l2).len = l1.len + l2.len := by - -- Remark: simp loops here because of the following rewritings: - -- @Nat.cast_add: ↑(List.length l1 + List.length l2) ==> ↑(List.length l1) + ↑(List.length l2) - -- Int.ofNat_add_ofNat: ↑(List.length l1) + ↑(List.length l2) ==> ↑(List.length l1 + List.length l2) - -- TODO: post an issue? - simp only [len_eq_length] - simp only [length_append] - simp only [Int.ofNat_add] - -theorem left_length_eq_append_eq (l1 l2 l1' l2' : List α) (heq : l1.length = l1'.length) : - l1 ++ l2 = l1' ++ l2' ↔ l1 = l1' ∧ l2 = l2' := by - revert l1' - induction l1 - . intro l1'; cases l1' <;> simp [*] - . intro l1'; cases l1' <;> simp_all; tauto - -theorem right_length_eq_append_eq (l1 l2 l1' l2' : List α) (heq : l2.length = l2'.length) : - l1 ++ l2 = l1' ++ l2' ↔ l1 = l1' ∧ l2 = l2' := by - have := left_length_eq_append_eq l1 l2 l1' l2' - constructor <;> intro heq2 <;> - have : l1.length + l2.length = l1'.length + l2'.length := by - have : (l1 ++ l2).length = (l1' ++ l2').length := by simp [*] - simp only [length_append] at this - apply this - . simp [heq] at this - tauto - . tauto - -theorem left_len_eq_append_eq (l1 l2 l1' l2' : List α) (heq : l1.len = l1'.len) : - l1 ++ l2 = l1' ++ l2' ↔ l1 = l1' ∧ l2 = l2' := by - simp [len_eq_length] at heq - apply left_length_eq_append_eq - assumption - -theorem right_len_eq_append_eq (l1 l2 l1' l2' : List α) (heq : l2.len = l2'.len) : - l1 ++ l2 = l1' ++ l2' ↔ l1 = l1' ∧ l2 = l2' := by - simp [len_eq_length] at heq - apply right_length_eq_append_eq - assumption - -open Arith in -theorem idrop_eq_nil_of_le (hineq : ls.len ≤ i) : idrop i ls = [] := by - revert i - induction ls <;> simp [*] - rename_i hd tl hi - intro i hineq - if heq: i = 0 then - simp [*] at * - have := tl.len_pos - linarith - else - simp at hineq - have : 0 < i := by int_tac - simp [*] - apply hi - linarith - -end List +import Base.IList.IList diff --git a/backends/lean/Base/IList/IList.lean b/backends/lean/Base/IList/IList.lean new file mode 100644 index 00000000..2a335cac --- /dev/null +++ b/backends/lean/Base/IList/IList.lean @@ -0,0 +1,142 @@ +/- Complementary list functions and lemmas which operate on integers rather + than natural numbers. -/ + +import Std.Data.Int.Lemmas +import Base.Arith + +namespace List + +def len (ls : List α) : Int := + match ls with + | [] => 0 + | _ :: tl => 1 + len tl + +-- Remark: if i < 0, then the result is none +def indexOpt (ls : List α) (i : Int) : Option α := + match ls with + | [] => none + | hd :: tl => if i = 0 then some hd else indexOpt tl (i - 1) + +-- Remark: if i < 0, then the result is the defaul element +def index [Inhabited α] (ls : List α) (i : Int) : α := + match ls with + | [] => Inhabited.default + | x :: tl => + if i = 0 then x else index tl (i - 1) + +-- Remark: the list is unchanged if the index is not in bounds (in particular +-- if it is < 0) +def update (ls : List α) (i : Int) (y : α) : List α := + match ls with + | [] => [] + | x :: tl => if i = 0 then y :: tl else x :: update tl (i - 1) y + +-- Remark: the whole list is dropped if the index is not in bounds (in particular +-- if it is < 0) +def idrop (i : Int) (ls : List α) : List α := + match ls with + | [] => [] + | x :: tl => if i = 0 then x :: tl else idrop (i - 1) tl + +section Lemmas + +variable {α : Type u} + +@[simp] theorem len_nil : len ([] : List α) = 0 := by simp [len] +@[simp] theorem len_cons : len ((x :: tl) : List α) = 1 + len tl := by simp [len] + +@[simp] theorem index_zero_cons [Inhabited α] : index ((x :: tl) : List α) 0 = x := by simp [index] +@[simp] theorem index_nzero_cons [Inhabited α] (hne : i ≠ 0) : index ((x :: tl) : List α) i = index tl (i - 1) := by simp [*, index] + +@[simp] theorem update_nil : update ([] : List α) i y = [] := by simp [update] +@[simp] theorem update_zero_cons : update ((x :: tl) : List α) 0 y = y :: tl := by simp [update] +@[simp] theorem update_nzero_cons (hne : i ≠ 0) : update ((x :: tl) : List α) i y = x :: update tl (i - 1) y := by simp [*, update] + +@[simp] theorem idrop_nil : idrop i ([] : List α) = [] := by simp [idrop] +@[simp] theorem idrop_zero : idrop 0 (ls : List α) = ls := by cases ls <;> simp [idrop] +@[simp] theorem idrop_nzero_cons (hne : i ≠ 0) : idrop i ((x :: tl) : List α) = idrop (i - 1) tl := by simp [*, idrop] + +theorem len_eq_length (ls : List α) : ls.len = ls.length := by + induction ls + . rfl + . simp [*, Int.ofNat_succ, Int.add_comm] + +@[simp] theorem len_append (l1 l2 : List α) : (l1 ++ l2).len = l1.len + l2.len := by + -- Remark: simp loops here because of the following rewritings: + -- @Nat.cast_add: ↑(List.length l1 + List.length l2) ==> ↑(List.length l1) + ↑(List.length l2) + -- Int.ofNat_add_ofNat: ↑(List.length l1) + ↑(List.length l2) ==> ↑(List.length l1 + List.length l2) + -- TODO: post an issue? + simp only [len_eq_length] + simp only [length_append] + simp only [Int.ofNat_add] + +@[simp] +theorem length_update (ls : List α) (i : Int) (x : α) : (ls.update i x).length = ls.length := by + revert i + induction ls <;> simp_all [length, update] + intro; split <;> simp [*] + +@[simp] +theorem len_update (ls : List α) (i : Int) (x : α) : (ls.update i x).len = ls.len := by + simp [len_eq_length] + + +theorem len_pos : 0 ≤ (ls : List α).len := by + induction ls <;> simp [*] + linarith + +instance (a : Type u) : Arith.HasIntProp (List a) where + prop_ty := λ ls => 0 ≤ ls.len + prop := λ ls => ls.len_pos + +theorem left_length_eq_append_eq (l1 l2 l1' l2' : List α) (heq : l1.length = l1'.length) : + l1 ++ l2 = l1' ++ l2' ↔ l1 = l1' ∧ l2 = l2' := by + revert l1' + induction l1 + . intro l1'; cases l1' <;> simp [*] + . intro l1'; cases l1' <;> simp_all; tauto + +theorem right_length_eq_append_eq (l1 l2 l1' l2' : List α) (heq : l2.length = l2'.length) : + l1 ++ l2 = l1' ++ l2' ↔ l1 = l1' ∧ l2 = l2' := by + have := left_length_eq_append_eq l1 l2 l1' l2' + constructor <;> intro heq2 <;> + have : l1.length + l2.length = l1'.length + l2'.length := by + have : (l1 ++ l2).length = (l1' ++ l2').length := by simp [*] + simp only [length_append] at this + apply this + . simp [heq] at this + tauto + . tauto + +theorem left_len_eq_append_eq (l1 l2 l1' l2' : List α) (heq : l1.len = l1'.len) : + l1 ++ l2 = l1' ++ l2' ↔ l1 = l1' ∧ l2 = l2' := by + simp [len_eq_length] at heq + apply left_length_eq_append_eq + assumption + +theorem right_len_eq_append_eq (l1 l2 l1' l2' : List α) (heq : l2.len = l2'.len) : + l1 ++ l2 = l1' ++ l2' ↔ l1 = l1' ∧ l2 = l2' := by + simp [len_eq_length] at heq + apply right_length_eq_append_eq + assumption + +open Arith in +theorem idrop_eq_nil_of_le (hineq : ls.len ≤ i) : idrop i ls = [] := by + revert i + induction ls <;> simp [*] + rename_i hd tl hi + intro i hineq + if heq: i = 0 then + simp [*] at * + have := tl.len_pos + linarith + else + simp at hineq + have : 0 < i := by int_tac + simp [*] + apply hi + linarith + +end Lemmas + +end List diff --git a/backends/lean/Base/Primitives.lean b/backends/lean/Base/Primitives.lean index 1a0c665d..91823cb6 100644 --- a/backends/lean/Base/Primitives.lean +++ b/backends/lean/Base/Primitives.lean @@ -1,715 +1,3 @@ -import Lean -import Lean.Meta.Tactic.Simp -import Init.Data.List.Basic -import Mathlib.Tactic.RunCmd -import Mathlib.Tactic.Linarith - -namespace Primitives - --------------------- --- ASSERT COMMAND --Std. --------------------- - -open Lean Elab Command Term Meta - -syntax (name := assert) "#assert" term: command - -@[command_elab assert] -unsafe -def assertImpl : CommandElab := fun (_stx: Syntax) => do - runTermElabM (fun _ => do - let r ← evalTerm Bool (mkConst ``Bool) _stx[1] - if not r then - logInfo ("Assertion failed for:\n" ++ _stx[1]) - throwError ("Expression reduced to false:\n" ++ _stx[1]) - pure ()) - -#eval 2 == 2 -#assert (2 == 2) - -------------- --- PRELUDE -- -------------- - --- Results & monadic combinators - -inductive Error where - | assertionFailure: Error - | integerOverflow: Error - | divisionByZero: Error - | arrayOutOfBounds: Error - | maximumSizeExceeded: Error - | panic: Error -deriving Repr, BEq - -open Error - -inductive Result (α : Type u) where - | ret (v: α): Result α - | fail (e: Error): Result α - | div -deriving Repr, BEq - -open Result - -instance Result_Inhabited (α : Type u) : Inhabited (Result α) := - Inhabited.mk (fail panic) - -instance Result_Nonempty (α : Type u) : Nonempty (Result α) := - Nonempty.intro div - -/- HELPERS -/ - -def ret? {α: Type u} (r: Result α): Bool := - match r with - | ret _ => true - | fail _ | div => false - -def div? {α: Type u} (r: Result α): Bool := - match r with - | div => true - | ret _ | fail _ => false - -def massert (b:Bool) : Result Unit := - if b then ret () else fail assertionFailure - -def eval_global {α: Type u} (x: Result α) (_: ret? x): α := - match x with - | fail _ | div => by contradiction - | ret x => x - -/- DO-DSL SUPPORT -/ - -def bind {α : Type u} {β : Type v} (x: Result α) (f: α -> Result β) : Result β := - match x with - | ret v => f v - | fail v => fail v - | div => div - --- Allows using Result in do-blocks -instance : Bind Result where - bind := bind - --- Allows using return x in do-blocks -instance : Pure Result where - pure := fun x => ret x - -@[simp] theorem bind_ret (x : α) (f : α → Result β) : bind (.ret x) f = f x := by simp [bind] -@[simp] theorem bind_fail (x : Error) (f : α → Result β) : bind (.fail x) f = .fail x := by simp [bind] -@[simp] theorem bind_div (f : α → Result β) : bind .div f = .div := by simp [bind] - -/- CUSTOM-DSL SUPPORT -/ - --- Let-binding the Result of a monadic operation is oftentimes not sufficient, --- because we may need a hypothesis for equational reasoning in the scope. We --- rely on subtype, and a custom let-binding operator, in effect recreating our --- own variant of the do-dsl - -def Result.attach {α: Type} (o : Result α): Result { x : α // o = ret x } := - match o with - | ret x => ret ⟨x, rfl⟩ - | fail e => fail e - | div => div - -@[simp] theorem bind_tc_ret (x : α) (f : α → Result β) : - (do let y ← .ret x; f y) = f x := by simp [Bind.bind, bind] - -@[simp] theorem bind_tc_fail (x : Error) (f : α → Result β) : - (do let y ← fail x; f y) = fail x := by simp [Bind.bind, bind] - -@[simp] theorem bind_tc_div (f : α → Result β) : - (do let y ← div; f y) = div := by simp [Bind.bind, bind] - ----------------------- --- MACHINE INTEGERS -- ----------------------- - --- We redefine our machine integers types. - --- For Isize/Usize, we reuse `getNumBits` from `USize`. You cannot reduce `getNumBits` --- using the simplifier, meaning that proofs do not depend on the compile-time value of --- USize.size. (Lean assumes 32 or 64-bit platforms, and Rust doesn't really support, at --- least officially, 16-bit microcontrollers, so this seems like a fine design decision --- for now.) - --- Note from Chris Bailey: "If there's more than one salient property of your --- definition then the subtyping strategy might get messy, and the property part --- of a subtype is less discoverable by the simplifier or tactics like --- library_search." So, we will not add refinements on the return values of the --- operations defined on Primitives, but will rather rely on custom lemmas to --- invert on possible return values of the primitive operations. - --- Machine integer constants, done via `ofNatCore`, which requires a proof that --- the `Nat` fits within the desired integer type. We provide a custom tactic. - -open System.Platform.getNumBits - --- TODO: is there a way of only importing System.Platform.getNumBits? --- -@[simp] def size_num_bits : Nat := (System.Platform.getNumBits ()).val - --- Remark: Lean seems to use < for the comparisons with the upper bounds by convention. - --- The "structured" bounds -def Isize.smin : Int := - (HPow.hPow 2 (size_num_bits - 1)) -def Isize.smax : Int := (HPow.hPow 2 (size_num_bits - 1)) - 1 -def I8.smin : Int := - (HPow.hPow 2 7) -def I8.smax : Int := HPow.hPow 2 7 - 1 -def I16.smin : Int := - (HPow.hPow 2 15) -def I16.smax : Int := HPow.hPow 2 15 - 1 -def I32.smin : Int := -(HPow.hPow 2 31) -def I32.smax : Int := HPow.hPow 2 31 - 1 -def I64.smin : Int := -(HPow.hPow 2 63) -def I64.smax : Int := HPow.hPow 2 63 - 1 -def I128.smin : Int := -(HPow.hPow 2 127) -def I128.smax : Int := HPow.hPow 2 127 - 1 -def Usize.smin : Int := 0 -def Usize.smax : Int := HPow.hPow 2 size_num_bits - 1 -def U8.smin : Int := 0 -def U8.smax : Int := HPow.hPow 2 8 - 1 -def U16.smin : Int := 0 -def U16.smax : Int := HPow.hPow 2 16 - 1 -def U32.smin : Int := 0 -def U32.smax : Int := HPow.hPow 2 32 - 1 -def U64.smin : Int := 0 -def U64.smax : Int := HPow.hPow 2 64 - 1 -def U128.smin : Int := 0 -def U128.smax : Int := HPow.hPow 2 128 - 1 - --- The "normalized" bounds, that we use in practice -def I8.min := -128 -def I8.max := 127 -def I16.min := -32768 -def I16.max := 32767 -def I32.min := -2147483648 -def I32.max := 2147483647 -def I64.min := -9223372036854775808 -def I64.max := 9223372036854775807 -def I128.min := -170141183460469231731687303715884105728 -def I128.max := 170141183460469231731687303715884105727 -@[simp] def U8.min := 0 -def U8.max := 255 -@[simp] def U16.min := 0 -def U16.max := 65535 -@[simp] def U32.min := 0 -def U32.max := 4294967295 -@[simp] def U64.min := 0 -def U64.max := 18446744073709551615 -@[simp] def U128.min := 0 -def U128.max := 340282366920938463463374607431768211455 -@[simp] def Usize.min := 0 - -def Isize.refined_min : { n:Int // n = I32.min ∨ n = I64.min } := - ⟨ Isize.smin, by - simp [Isize.smin] - cases System.Platform.numBits_eq <;> - unfold System.Platform.numBits at * <;> simp [*] ⟩ - -def Isize.refined_max : { n:Int // n = I32.max ∨ n = I64.max } := - ⟨ Isize.smax, by - simp [Isize.smax] - cases System.Platform.numBits_eq <;> - unfold System.Platform.numBits at * <;> simp [*] ⟩ - -def Usize.refined_max : { n:Int // n = U32.max ∨ n = U64.max } := - ⟨ Usize.smax, by - simp [Usize.smax] - cases System.Platform.numBits_eq <;> - unfold System.Platform.numBits at * <;> simp [*] ⟩ - -def Isize.min := Isize.refined_min.val -def Isize.max := Isize.refined_max.val -def Usize.max := Usize.refined_max.val - -inductive ScalarTy := -| Isize -| I8 -| I16 -| I32 -| I64 -| I128 -| Usize -| U8 -| U16 -| U32 -| U64 -| U128 - -def Scalar.smin (ty : ScalarTy) : Int := - match ty with - | .Isize => Isize.smin - | .I8 => I8.smin - | .I16 => I16.smin - | .I32 => I32.smin - | .I64 => I64.smin - | .I128 => I128.smin - | .Usize => Usize.smin - | .U8 => U8.smin - | .U16 => U16.smin - | .U32 => U32.smin - | .U64 => U64.smin - | .U128 => U128.smin - -def Scalar.smax (ty : ScalarTy) : Int := - match ty with - | .Isize => Isize.smax - | .I8 => I8.smax - | .I16 => I16.smax - | .I32 => I32.smax - | .I64 => I64.smax - | .I128 => I128.smax - | .Usize => Usize.smax - | .U8 => U8.smax - | .U16 => U16.smax - | .U32 => U32.smax - | .U64 => U64.smax - | .U128 => U128.smax - -def Scalar.min (ty : ScalarTy) : Int := - match ty with - | .Isize => Isize.min - | .I8 => I8.min - | .I16 => I16.min - | .I32 => I32.min - | .I64 => I64.min - | .I128 => I128.min - | .Usize => Usize.min - | .U8 => U8.min - | .U16 => U16.min - | .U32 => U32.min - | .U64 => U64.min - | .U128 => U128.min - -def Scalar.max (ty : ScalarTy) : Int := - match ty with - | .Isize => Isize.max - | .I8 => I8.max - | .I16 => I16.max - | .I32 => I32.max - | .I64 => I64.max - | .I128 => I128.max - | .Usize => Usize.max - | .U8 => U8.max - | .U16 => U16.max - | .U32 => U32.max - | .U64 => U64.max - | .U128 => U128.max - -def Scalar.smin_eq (ty : ScalarTy) : Scalar.min ty = Scalar.smin ty := by - cases ty <;> rfl - -def Scalar.smax_eq (ty : ScalarTy) : Scalar.max ty = Scalar.smax ty := by - cases ty <;> rfl - --- "Conservative" bounds --- We use those because we can't compare to the isize bounds (which can't --- reduce at compile-time). Whenever we perform an arithmetic operation like --- addition we need to check that the result is in bounds: we first compare --- to the conservative bounds, which reduce, then compare to the real bounds. --- This is useful for the various #asserts that we want to reduce at --- type-checking time. -def Scalar.cMin (ty : ScalarTy) : Int := - match ty with - | .Isize => Scalar.min .I32 - | _ => Scalar.min ty - -def Scalar.cMax (ty : ScalarTy) : Int := - match ty with - | .Isize => Scalar.max .I32 - | .Usize => Scalar.max .U32 - | _ => Scalar.max ty - -theorem Scalar.cMin_bound ty : Scalar.min ty ≤ Scalar.cMin ty := by - cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at * - have h := Isize.refined_min.property - cases h <;> simp [*, Isize.min] - -theorem Scalar.cMax_bound ty : Scalar.cMax ty ≤ Scalar.max ty := by - cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at * - . have h := Isize.refined_max.property - cases h <;> simp [*, Isize.max] - . have h := Usize.refined_max.property - cases h <;> simp [*, Usize.max] - -theorem Scalar.cMin_suffices ty (h : Scalar.cMin ty ≤ x) : Scalar.min ty ≤ x := by - have := Scalar.cMin_bound ty - linarith - -theorem Scalar.cMax_suffices ty (h : x ≤ Scalar.cMax ty) : x ≤ Scalar.max ty := by - have := Scalar.cMax_bound ty - linarith - -structure Scalar (ty : ScalarTy) where - val : Int - hmin : Scalar.min ty ≤ val - hmax : val ≤ Scalar.max ty -deriving Repr - -theorem Scalar.bound_suffices (ty : ScalarTy) (x : Int) : - Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty -> - Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty - := - λ h => by - apply And.intro <;> have hmin := Scalar.cMin_bound ty <;> have hmax := Scalar.cMax_bound ty <;> linarith - -def Scalar.ofIntCore {ty : ScalarTy} (x : Int) - (hmin : Scalar.min ty ≤ x) (hmax : x ≤ Scalar.max ty) : Scalar ty := - { val := x, hmin := hmin, hmax := hmax } - --- Tactic to prove that integers are in bounds --- TODO: use this: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/instance.20with.20tactic.20autoparam -syntax "intlit" : tactic -macro_rules - | `(tactic| intlit) => `(tactic| apply Scalar.bound_suffices; decide) - -def Scalar.ofInt {ty : ScalarTy} (x : Int) - (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty := by intlit) : Scalar ty := - -- Remark: we initially wrote: - -- let ⟨ hmin, hmax ⟩ := h - -- Scalar.ofIntCore x hmin hmax - -- We updated to the line below because a similar pattern in `Scalar.tryMk` - -- made reduction block. Both versions seem to work for `Scalar.ofInt`, though. - -- TODO: investigate - Scalar.ofIntCore x h.left h.right - -@[simp] def Scalar.check_bounds (ty : ScalarTy) (x : Int) : Bool := - (Scalar.cMin ty ≤ x || Scalar.min ty ≤ x) ∧ (x ≤ Scalar.cMax ty || x ≤ Scalar.max ty) - -theorem Scalar.check_bounds_prop {ty : ScalarTy} {x : Int} (h: Scalar.check_bounds ty x) : - Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty := by - simp at * - have ⟨ hmin, hmax ⟩ := h - have hbmin := Scalar.cMin_bound ty - have hbmax := Scalar.cMax_bound ty - cases hmin <;> cases hmax <;> apply And.intro <;> linarith - --- Further thoughts: look at what has been done here: --- https://github.com/leanprover-community/mathlib4/blob/master/Mathlib/Data/Fin/Basic.lean --- and --- https://github.com/leanprover-community/mathlib4/blob/master/Mathlib/Data/UInt.lean --- which both contain a fair amount of reasoning already! -def Scalar.tryMk (ty : ScalarTy) (x : Int) : Result (Scalar ty) := - if h:Scalar.check_bounds ty x then - -- If we do: - -- ``` - -- let ⟨ hmin, hmax ⟩ := (Scalar.check_bounds_prop h) - -- Scalar.ofIntCore x hmin hmax - -- ``` - -- then normalization blocks (for instance, some proofs which use reflexivity fail). - -- However, the version below doesn't block reduction (TODO: investigate): - return Scalar.ofInt x (Scalar.check_bounds_prop h) - else fail integerOverflow - -def Scalar.neg {ty : ScalarTy} (x : Scalar ty) : Result (Scalar ty) := Scalar.tryMk ty (- x.val) - -def Scalar.div {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) := - if y.val != 0 then Scalar.tryMk ty (x.val / y.val) else fail divisionByZero - --- Our custom remainder operation, which satisfies the semantics of Rust --- TODO: is there a better way? -def scalar_rem (x y : Int) : Int := - if 0 ≤ x then |x| % |y| - else - (|x| % |y|) - --- Our custom division operation, which satisfies the semantics of Rust --- TODO: is there a better way? -def scalar_div (x y : Int) : Int := - if 0 ≤ x && 0 ≤ y then |x| / |y| - else if 0 ≤ x && y < 0 then - (|x| / |y|) - else if x < 0 && 0 ≤ y then - (|x| / |y|) - else |x| / |y| - --- Checking that the remainder operation is correct -#assert scalar_rem 1 2 = 1 -#assert scalar_rem (-1) 2 = -1 -#assert scalar_rem 1 (-2) = 1 -#assert scalar_rem (-1) (-2) = -1 -#assert scalar_rem 7 3 = (1:Int) -#assert scalar_rem (-7) 3 = -1 -#assert scalar_rem 7 (-3) = 1 -#assert scalar_rem (-7) (-3) = -1 - --- Checking that the division operation is correct -#assert scalar_div 3 2 = 1 -#assert scalar_div (-3) 2 = -1 -#assert scalar_div 3 (-2) = -1 -#assert scalar_div (-3) (-2) = 1 -#assert scalar_div 7 3 = 2 -#assert scalar_div (-7) 3 = -2 -#assert scalar_div 7 (-3) = -2 -#assert scalar_div (-7) (-3) = 2 - -def Scalar.rem {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) := - if y.val != 0 then Scalar.tryMk ty (x.val % y.val) else fail divisionByZero - -def Scalar.add {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) := - Scalar.tryMk ty (x.val + y.val) - -def Scalar.sub {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) := - Scalar.tryMk ty (x.val - y.val) - -def Scalar.mul {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) := - Scalar.tryMk ty (x.val * y.val) - --- TODO: instances of +, -, * etc. for scalars - --- Cast an integer from a [src_ty] to a [tgt_ty] --- TODO: check the semantics of casts in Rust -def Scalar.cast {src_ty : ScalarTy} (tgt_ty : ScalarTy) (x : Scalar src_ty) : Result (Scalar tgt_ty) := - Scalar.tryMk tgt_ty x.val - --- The scalar types --- We declare the definitions as reducible so that Lean can unfold them (useful --- for type class resolution for instance). -@[reducible] def Isize := Scalar .Isize -@[reducible] def I8 := Scalar .I8 -@[reducible] def I16 := Scalar .I16 -@[reducible] def I32 := Scalar .I32 -@[reducible] def I64 := Scalar .I64 -@[reducible] def I128 := Scalar .I128 -@[reducible] def Usize := Scalar .Usize -@[reducible] def U8 := Scalar .U8 -@[reducible] def U16 := Scalar .U16 -@[reducible] def U32 := Scalar .U32 -@[reducible] def U64 := Scalar .U64 -@[reducible] def U128 := Scalar .U128 - --- TODO: below: not sure this is the best way. --- Should we rather overload operations like +, -, etc.? --- Also, it is possible to automate the generation of those definitions --- with macros (but would it be a good idea? It would be less easy to --- read the file, which is not supposed to change a lot) - --- Negation - -/-- -Remark: there is no heterogeneous negation in the Lean prelude: we thus introduce -one here. - -The notation typeclass for heterogeneous addition. -This enables the notation `- a : β` where `a : α`. --/ -class HNeg (α : Type u) (β : outParam (Type v)) where - /-- `- a` computes the negation of `a`. - The meaning of this notation is type-dependent. -/ - hNeg : α → β - -prefix:75 "-" => HNeg.hNeg - -instance : HNeg Isize (Result Isize) where hNeg x := Scalar.neg x -instance : HNeg I8 (Result I8) where hNeg x := Scalar.neg x -instance : HNeg I16 (Result I16) where hNeg x := Scalar.neg x -instance : HNeg I32 (Result I32) where hNeg x := Scalar.neg x -instance : HNeg I64 (Result I64) where hNeg x := Scalar.neg x -instance : HNeg I128 (Result I128) where hNeg x := Scalar.neg x - --- Addition -instance {ty} : HAdd (Scalar ty) (Scalar ty) (Result (Scalar ty)) where - hAdd x y := Scalar.add x y - --- Substraction -instance {ty} : HSub (Scalar ty) (Scalar ty) (Result (Scalar ty)) where - hSub x y := Scalar.sub x y - --- Multiplication -instance {ty} : HMul (Scalar ty) (Scalar ty) (Result (Scalar ty)) where - hMul x y := Scalar.mul x y - --- Division -instance {ty} : HDiv (Scalar ty) (Scalar ty) (Result (Scalar ty)) where - hDiv x y := Scalar.div x y - --- Remainder -instance {ty} : HMod (Scalar ty) (Scalar ty) (Result (Scalar ty)) where - hMod x y := Scalar.rem x y - --- ofIntCore --- TODO: typeclass? -def Isize.ofIntCore := @Scalar.ofIntCore .Isize -def I8.ofIntCore := @Scalar.ofIntCore .I8 -def I16.ofIntCore := @Scalar.ofIntCore .I16 -def I32.ofIntCore := @Scalar.ofIntCore .I32 -def I64.ofIntCore := @Scalar.ofIntCore .I64 -def I128.ofIntCore := @Scalar.ofIntCore .I128 -def Usize.ofIntCore := @Scalar.ofIntCore .Usize -def U8.ofIntCore := @Scalar.ofIntCore .U8 -def U16.ofIntCore := @Scalar.ofIntCore .U16 -def U32.ofIntCore := @Scalar.ofIntCore .U32 -def U64.ofIntCore := @Scalar.ofIntCore .U64 -def U128.ofIntCore := @Scalar.ofIntCore .U128 - --- ofInt --- TODO: typeclass? -def Isize.ofInt := @Scalar.ofInt .Isize -def I8.ofInt := @Scalar.ofInt .I8 -def I16.ofInt := @Scalar.ofInt .I16 -def I32.ofInt := @Scalar.ofInt .I32 -def I64.ofInt := @Scalar.ofInt .I64 -def I128.ofInt := @Scalar.ofInt .I128 -def Usize.ofInt := @Scalar.ofInt .Usize -def U8.ofInt := @Scalar.ofInt .U8 -def U16.ofInt := @Scalar.ofInt .U16 -def U32.ofInt := @Scalar.ofInt .U32 -def U64.ofInt := @Scalar.ofInt .U64 -def U128.ofInt := @Scalar.ofInt .U128 - --- Comparisons -instance {ty} : LT (Scalar ty) where - lt a b := LT.lt a.val b.val - -instance {ty} : LE (Scalar ty) where le a b := LE.le a.val b.val - -instance Scalar.decLt {ty} (a b : Scalar ty) : Decidable (LT.lt a b) := Int.decLt .. -instance Scalar.decLe {ty} (a b : Scalar ty) : Decidable (LE.le a b) := Int.decLe .. - -theorem Scalar.eq_of_val_eq {ty} : ∀ {i j : Scalar ty}, Eq i.val j.val → Eq i j - | ⟨_, _, _⟩, ⟨_, _, _⟩, rfl => rfl - -theorem Scalar.val_eq_of_eq {ty} {i j : Scalar ty} (h : Eq i j) : Eq i.val j.val := - h ▸ rfl - -theorem Scalar.ne_of_val_ne {ty} {i j : Scalar ty} (h : Not (Eq i.val j.val)) : Not (Eq i j) := - fun h' => absurd (val_eq_of_eq h') h - -instance (ty : ScalarTy) : DecidableEq (Scalar ty) := - fun i j => - match decEq i.val j.val with - | isTrue h => isTrue (Scalar.eq_of_val_eq h) - | isFalse h => isFalse (Scalar.ne_of_val_ne h) - -def Scalar.toInt {ty} (n : Scalar ty) : Int := n.val - --- -- We now define a type class that subsumes the various machine integer types, so --- -- as to write a concise definition for scalar_cast, rather than exhaustively --- -- enumerating all of the possible pairs. We remark that Rust has sane semantics --- -- and fails if a cast operation would involve a truncation or modulo. - --- class MachineInteger (t: Type) where --- size: Nat --- val: t -> Fin size --- ofNatCore: (n:Nat) -> LT.lt n size -> t - --- set_option hygiene false in --- run_cmd --- for typeName in [`UInt8, `UInt16, `UInt32, `UInt64, `USize].map Lean.mkIdent do --- Lean.Elab.Command.elabCommand (← `( --- namespace $typeName --- instance: MachineInteger $typeName where --- size := size --- val := val --- ofNatCore := ofNatCore --- end $typeName --- )) - --- -- Aeneas only instantiates the destination type (`src` is implicit). We rely on --- -- Lean to infer `src`. - --- def scalar_cast { src: Type } (dst: Type) [ MachineInteger src ] [ MachineInteger dst ] (x: src): Result dst := --- if h: MachineInteger.val x < MachineInteger.size dst then --- .ret (MachineInteger.ofNatCore (MachineInteger.val x).val h) --- else --- .fail integerOverflow - -------------- --- VECTORS -- -------------- - -def Vec (α : Type u) := { l : List α // List.length l ≤ Usize.max } - --- TODO: do we really need it? It should be with Subtype by default -instance Vec.cast (a : Type): Coe (Vec a) (List a) where coe := λ v => v.val - -def Vec.new (α : Type u): Vec α := ⟨ [], by apply Scalar.cMax_suffices .Usize; simp ⟩ - -def Vec.len (α : Type u) (v : Vec α) : Usize := - let ⟨ v, l ⟩ := v - Usize.ofIntCore (List.length v) (by simp [Scalar.min, Usize.min]) l - --- This shouldn't be used -def Vec.push_fwd (α : Type u) (_ : Vec α) (_ : α) : Unit := () - --- This is actually the backward function -def Vec.push (α : Type u) (v : Vec α) (x : α) : Result (Vec α) - := - let nlen := List.length v.val + 1 - if h : nlen ≤ U32.max || nlen ≤ Usize.max then - have h : nlen ≤ Usize.max := by - simp [Usize.max] at * - have hm := Usize.refined_max.property - cases h <;> cases hm <;> simp [U32.max, U64.max] at * <;> try linarith - return ⟨ List.concat v.val x, by simp at *; assumption ⟩ - else - fail maximumSizeExceeded - --- This shouldn't be used -def Vec.insert_fwd (α : Type u) (v: Vec α) (i: Usize) (_: α): Result Unit := - if i.val < List.length v.val then - .ret () - else - .fail arrayOutOfBounds - --- This is actually the backward function -def Vec.insert (α : Type u) (v: Vec α) (i: Usize) (x: α): Result (Vec α) := - if i.val < List.length v.val then - -- TODO: maybe we should redefine a list library which uses integers - -- (instead of natural numbers) - let i := i.val.toNat - .ret ⟨ List.set v.val i x, by - have h: List.length v.val ≤ Usize.max := v.property - simp [*] at * - ⟩ - else - .fail arrayOutOfBounds - -def Vec.index_to_fin {α : Type u} {v: Vec α} {i: Usize} (h : i.val < List.length v.val) : - Fin (List.length v.val) := - let j := i.val.toNat - let h: j < List.length v.val := by - have heq := @Int.toNat_lt (List.length v.val) i.val i.hmin - apply heq.mpr - assumption - ⟨j, h⟩ - -def Vec.index (α : Type u) (v: Vec α) (i: Usize): Result α := - if h: i.val < List.length v.val then - let i := Vec.index_to_fin h - .ret (List.get v.val i) - else - .fail arrayOutOfBounds - --- This shouldn't be used -def Vec.index_back (α : Type u) (v: Vec α) (i: Usize) (_: α): Result Unit := - if i.val < List.length v.val then - .ret () - else - .fail arrayOutOfBounds - -def Vec.index_mut (α : Type u) (v: Vec α) (i: Usize): Result α := - if h: i.val < List.length v.val then - let i := Vec.index_to_fin h - .ret (List.get v.val i) - else - .fail arrayOutOfBounds - -def Vec.index_mut_back (α : Type u) (v: Vec α) (i: Usize) (x: α): Result (Vec α) := - if h: i.val < List.length v.val then - let i := Vec.index_to_fin h - .ret ⟨ List.set v.val i x, by - have h: List.length v.val ≤ Usize.max := v.property - simp [*] at * - ⟩ - else - .fail arrayOutOfBounds - ----------- --- MISC -- ----------- - -@[simp] def mem.replace (a : Type) (x : a) (_ : a) : a := x -@[simp] def mem.replace_back (a : Type) (_ : a) (y : a) : a := y - -/-- Aeneas-translated function -- useful to reduce non-recursive definitions. - Use with `simp [ aeneas ]` -/ -register_simp_attr aeneas - -end Primitives +import Base.Primitives.Base +import Base.Primitives.Scalar +import Base.Primitives.Vec diff --git a/backends/lean/Base/Primitives/Base.lean b/backends/lean/Base/Primitives/Base.lean new file mode 100644 index 00000000..db462c38 --- /dev/null +++ b/backends/lean/Base/Primitives/Base.lean @@ -0,0 +1,130 @@ +import Lean + +namespace Primitives + +-------------------- +-- ASSERT COMMAND --Std. +-------------------- + +open Lean Elab Command Term Meta + +syntax (name := assert) "#assert" term: command + +@[command_elab assert] +unsafe +def assertImpl : CommandElab := fun (_stx: Syntax) => do + runTermElabM (fun _ => do + let r ← evalTerm Bool (mkConst ``Bool) _stx[1] + if not r then + logInfo ("Assertion failed for:\n" ++ _stx[1]) + throwError ("Expression reduced to false:\n" ++ _stx[1]) + pure ()) + +#eval 2 == 2 +#assert (2 == 2) + +------------- +-- PRELUDE -- +------------- + +-- Results & monadic combinators + +inductive Error where + | assertionFailure: Error + | integerOverflow: Error + | divisionByZero: Error + | arrayOutOfBounds: Error + | maximumSizeExceeded: Error + | panic: Error +deriving Repr, BEq + +open Error + +inductive Result (α : Type u) where + | ret (v: α): Result α + | fail (e: Error): Result α + | div +deriving Repr, BEq + +open Result + +instance Result_Inhabited (α : Type u) : Inhabited (Result α) := + Inhabited.mk (fail panic) + +instance Result_Nonempty (α : Type u) : Nonempty (Result α) := + Nonempty.intro div + +/- HELPERS -/ + +def ret? {α: Type u} (r: Result α): Bool := + match r with + | ret _ => true + | fail _ | div => false + +def div? {α: Type u} (r: Result α): Bool := + match r with + | div => true + | ret _ | fail _ => false + +def massert (b:Bool) : Result Unit := + if b then ret () else fail assertionFailure + +def eval_global {α: Type u} (x: Result α) (_: ret? x): α := + match x with + | fail _ | div => by contradiction + | ret x => x + +/- DO-DSL SUPPORT -/ + +def bind {α : Type u} {β : Type v} (x: Result α) (f: α -> Result β) : Result β := + match x with + | ret v => f v + | fail v => fail v + | div => div + +-- Allows using Result in do-blocks +instance : Bind Result where + bind := bind + +-- Allows using return x in do-blocks +instance : Pure Result where + pure := fun x => ret x + +@[simp] theorem bind_ret (x : α) (f : α → Result β) : bind (.ret x) f = f x := by simp [bind] +@[simp] theorem bind_fail (x : Error) (f : α → Result β) : bind (.fail x) f = .fail x := by simp [bind] +@[simp] theorem bind_div (f : α → Result β) : bind .div f = .div := by simp [bind] + +/- CUSTOM-DSL SUPPORT -/ + +-- Let-binding the Result of a monadic operation is oftentimes not sufficient, +-- because we may need a hypothesis for equational reasoning in the scope. We +-- rely on subtype, and a custom let-binding operator, in effect recreating our +-- own variant of the do-dsl + +def Result.attach {α: Type} (o : Result α): Result { x : α // o = ret x } := + match o with + | ret x => ret ⟨x, rfl⟩ + | fail e => fail e + | div => div + +@[simp] theorem bind_tc_ret (x : α) (f : α → Result β) : + (do let y ← .ret x; f y) = f x := by simp [Bind.bind, bind] + +@[simp] theorem bind_tc_fail (x : Error) (f : α → Result β) : + (do let y ← fail x; f y) = fail x := by simp [Bind.bind, bind] + +@[simp] theorem bind_tc_div (f : α → Result β) : + (do let y ← div; f y) = div := by simp [Bind.bind, bind] + +---------- +-- MISC -- +---------- + +@[simp] def mem.replace (a : Type) (x : a) (_ : a) : a := x +@[simp] def mem.replace_back (a : Type) (_ : a) (y : a) : a := y + +/-- Aeneas-translated function -- useful to reduce non-recursive definitions. + Use with `simp [ aeneas ]` -/ +register_simp_attr aeneas + +end Primitives diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean new file mode 100644 index 00000000..241dfa07 --- /dev/null +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -0,0 +1,507 @@ +import Lean +import Lean.Meta.Tactic.Simp +import Mathlib.Tactic.Linarith +import Base.Primitives.Base + +namespace Primitives + +---------------------- +-- MACHINE INTEGERS -- +---------------------- + +-- We redefine our machine integers types. + +-- For Isize/Usize, we reuse `getNumBits` from `USize`. You cannot reduce `getNumBits` +-- using the simplifier, meaning that proofs do not depend on the compile-time value of +-- USize.size. (Lean assumes 32 or 64-bit platforms, and Rust doesn't really support, at +-- least officially, 16-bit microcontrollers, so this seems like a fine design decision +-- for now.) + +-- Note from Chris Bailey: "If there's more than one salient property of your +-- definition then the subtyping strategy might get messy, and the property part +-- of a subtype is less discoverable by the simplifier or tactics like +-- library_search." So, we will not add refinements on the return values of the +-- operations defined on Primitives, but will rather rely on custom lemmas to +-- invert on possible return values of the primitive operations. + +-- Machine integer constants, done via `ofNatCore`, which requires a proof that +-- the `Nat` fits within the desired integer type. We provide a custom tactic. + +open Result Error +open System.Platform.getNumBits + +-- TODO: is there a way of only importing System.Platform.getNumBits? +-- +@[simp] def size_num_bits : Nat := (System.Platform.getNumBits ()).val + +-- Remark: Lean seems to use < for the comparisons with the upper bounds by convention. + +-- The "structured" bounds +def Isize.smin : Int := - (HPow.hPow 2 (size_num_bits - 1)) +def Isize.smax : Int := (HPow.hPow 2 (size_num_bits - 1)) - 1 +def I8.smin : Int := - (HPow.hPow 2 7) +def I8.smax : Int := HPow.hPow 2 7 - 1 +def I16.smin : Int := - (HPow.hPow 2 15) +def I16.smax : Int := HPow.hPow 2 15 - 1 +def I32.smin : Int := -(HPow.hPow 2 31) +def I32.smax : Int := HPow.hPow 2 31 - 1 +def I64.smin : Int := -(HPow.hPow 2 63) +def I64.smax : Int := HPow.hPow 2 63 - 1 +def I128.smin : Int := -(HPow.hPow 2 127) +def I128.smax : Int := HPow.hPow 2 127 - 1 +def Usize.smin : Int := 0 +def Usize.smax : Int := HPow.hPow 2 size_num_bits - 1 +def U8.smin : Int := 0 +def U8.smax : Int := HPow.hPow 2 8 - 1 +def U16.smin : Int := 0 +def U16.smax : Int := HPow.hPow 2 16 - 1 +def U32.smin : Int := 0 +def U32.smax : Int := HPow.hPow 2 32 - 1 +def U64.smin : Int := 0 +def U64.smax : Int := HPow.hPow 2 64 - 1 +def U128.smin : Int := 0 +def U128.smax : Int := HPow.hPow 2 128 - 1 + +-- The "normalized" bounds, that we use in practice +def I8.min := -128 +def I8.max := 127 +def I16.min := -32768 +def I16.max := 32767 +def I32.min := -2147483648 +def I32.max := 2147483647 +def I64.min := -9223372036854775808 +def I64.max := 9223372036854775807 +def I128.min := -170141183460469231731687303715884105728 +def I128.max := 170141183460469231731687303715884105727 +@[simp] def U8.min := 0 +def U8.max := 255 +@[simp] def U16.min := 0 +def U16.max := 65535 +@[simp] def U32.min := 0 +def U32.max := 4294967295 +@[simp] def U64.min := 0 +def U64.max := 18446744073709551615 +@[simp] def U128.min := 0 +def U128.max := 340282366920938463463374607431768211455 +@[simp] def Usize.min := 0 + +def Isize.refined_min : { n:Int // n = I32.min ∨ n = I64.min } := + ⟨ Isize.smin, by + simp [Isize.smin] + cases System.Platform.numBits_eq <;> + unfold System.Platform.numBits at * <;> simp [*] ⟩ + +def Isize.refined_max : { n:Int // n = I32.max ∨ n = I64.max } := + ⟨ Isize.smax, by + simp [Isize.smax] + cases System.Platform.numBits_eq <;> + unfold System.Platform.numBits at * <;> simp [*] ⟩ + +def Usize.refined_max : { n:Int // n = U32.max ∨ n = U64.max } := + ⟨ Usize.smax, by + simp [Usize.smax] + cases System.Platform.numBits_eq <;> + unfold System.Platform.numBits at * <;> simp [*] ⟩ + +def Isize.min := Isize.refined_min.val +def Isize.max := Isize.refined_max.val +def Usize.max := Usize.refined_max.val + +inductive ScalarTy := +| Isize +| I8 +| I16 +| I32 +| I64 +| I128 +| Usize +| U8 +| U16 +| U32 +| U64 +| U128 + +def Scalar.smin (ty : ScalarTy) : Int := + match ty with + | .Isize => Isize.smin + | .I8 => I8.smin + | .I16 => I16.smin + | .I32 => I32.smin + | .I64 => I64.smin + | .I128 => I128.smin + | .Usize => Usize.smin + | .U8 => U8.smin + | .U16 => U16.smin + | .U32 => U32.smin + | .U64 => U64.smin + | .U128 => U128.smin + +def Scalar.smax (ty : ScalarTy) : Int := + match ty with + | .Isize => Isize.smax + | .I8 => I8.smax + | .I16 => I16.smax + | .I32 => I32.smax + | .I64 => I64.smax + | .I128 => I128.smax + | .Usize => Usize.smax + | .U8 => U8.smax + | .U16 => U16.smax + | .U32 => U32.smax + | .U64 => U64.smax + | .U128 => U128.smax + +def Scalar.min (ty : ScalarTy) : Int := + match ty with + | .Isize => Isize.min + | .I8 => I8.min + | .I16 => I16.min + | .I32 => I32.min + | .I64 => I64.min + | .I128 => I128.min + | .Usize => Usize.min + | .U8 => U8.min + | .U16 => U16.min + | .U32 => U32.min + | .U64 => U64.min + | .U128 => U128.min + +def Scalar.max (ty : ScalarTy) : Int := + match ty with + | .Isize => Isize.max + | .I8 => I8.max + | .I16 => I16.max + | .I32 => I32.max + | .I64 => I64.max + | .I128 => I128.max + | .Usize => Usize.max + | .U8 => U8.max + | .U16 => U16.max + | .U32 => U32.max + | .U64 => U64.max + | .U128 => U128.max + +def Scalar.smin_eq (ty : ScalarTy) : Scalar.min ty = Scalar.smin ty := by + cases ty <;> rfl + +def Scalar.smax_eq (ty : ScalarTy) : Scalar.max ty = Scalar.smax ty := by + cases ty <;> rfl + +-- "Conservative" bounds +-- We use those because we can't compare to the isize bounds (which can't +-- reduce at compile-time). Whenever we perform an arithmetic operation like +-- addition we need to check that the result is in bounds: we first compare +-- to the conservative bounds, which reduce, then compare to the real bounds. +-- This is useful for the various #asserts that we want to reduce at +-- type-checking time. +def Scalar.cMin (ty : ScalarTy) : Int := + match ty with + | .Isize => Scalar.min .I32 + | _ => Scalar.min ty + +def Scalar.cMax (ty : ScalarTy) : Int := + match ty with + | .Isize => Scalar.max .I32 + | .Usize => Scalar.max .U32 + | _ => Scalar.max ty + +theorem Scalar.cMin_bound ty : Scalar.min ty ≤ Scalar.cMin ty := by + cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at * + have h := Isize.refined_min.property + cases h <;> simp [*, Isize.min] + +theorem Scalar.cMax_bound ty : Scalar.cMax ty ≤ Scalar.max ty := by + cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at * + . have h := Isize.refined_max.property + cases h <;> simp [*, Isize.max] + . have h := Usize.refined_max.property + cases h <;> simp [*, Usize.max] + +theorem Scalar.cMin_suffices ty (h : Scalar.cMin ty ≤ x) : Scalar.min ty ≤ x := by + have := Scalar.cMin_bound ty + linarith + +theorem Scalar.cMax_suffices ty (h : x ≤ Scalar.cMax ty) : x ≤ Scalar.max ty := by + have := Scalar.cMax_bound ty + linarith + +structure Scalar (ty : ScalarTy) where + val : Int + hmin : Scalar.min ty ≤ val + hmax : val ≤ Scalar.max ty +deriving Repr + +theorem Scalar.bound_suffices (ty : ScalarTy) (x : Int) : + Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty -> + Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty + := + λ h => by + apply And.intro <;> have hmin := Scalar.cMin_bound ty <;> have hmax := Scalar.cMax_bound ty <;> linarith + +def Scalar.ofIntCore {ty : ScalarTy} (x : Int) + (hmin : Scalar.min ty ≤ x) (hmax : x ≤ Scalar.max ty) : Scalar ty := + { val := x, hmin := hmin, hmax := hmax } + +-- Tactic to prove that integers are in bounds +-- TODO: use this: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/instance.20with.20tactic.20autoparam +syntax "intlit" : tactic +macro_rules + | `(tactic| intlit) => `(tactic| apply Scalar.bound_suffices; decide) + +def Scalar.ofInt {ty : ScalarTy} (x : Int) + (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty := by intlit) : Scalar ty := + -- Remark: we initially wrote: + -- let ⟨ hmin, hmax ⟩ := h + -- Scalar.ofIntCore x hmin hmax + -- We updated to the line below because a similar pattern in `Scalar.tryMk` + -- made reduction block. Both versions seem to work for `Scalar.ofInt`, though. + -- TODO: investigate + Scalar.ofIntCore x h.left h.right + +@[simp] def Scalar.check_bounds (ty : ScalarTy) (x : Int) : Bool := + (Scalar.cMin ty ≤ x || Scalar.min ty ≤ x) ∧ (x ≤ Scalar.cMax ty || x ≤ Scalar.max ty) + +theorem Scalar.check_bounds_prop {ty : ScalarTy} {x : Int} (h: Scalar.check_bounds ty x) : + Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty := by + simp at * + have ⟨ hmin, hmax ⟩ := h + have hbmin := Scalar.cMin_bound ty + have hbmax := Scalar.cMax_bound ty + cases hmin <;> cases hmax <;> apply And.intro <;> linarith + +-- Further thoughts: look at what has been done here: +-- https://github.com/leanprover-community/mathlib4/blob/master/Mathlib/Data/Fin/Basic.lean +-- and +-- https://github.com/leanprover-community/mathlib4/blob/master/Mathlib/Data/UInt.lean +-- which both contain a fair amount of reasoning already! +def Scalar.tryMk (ty : ScalarTy) (x : Int) : Result (Scalar ty) := + if h:Scalar.check_bounds ty x then + -- If we do: + -- ``` + -- let ⟨ hmin, hmax ⟩ := (Scalar.check_bounds_prop h) + -- Scalar.ofIntCore x hmin hmax + -- ``` + -- then normalization blocks (for instance, some proofs which use reflexivity fail). + -- However, the version below doesn't block reduction (TODO: investigate): + return Scalar.ofInt x (Scalar.check_bounds_prop h) + else fail integerOverflow + +def Scalar.neg {ty : ScalarTy} (x : Scalar ty) : Result (Scalar ty) := Scalar.tryMk ty (- x.val) + +def Scalar.div {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) := + if y.val != 0 then Scalar.tryMk ty (x.val / y.val) else fail divisionByZero + +-- Our custom remainder operation, which satisfies the semantics of Rust +-- TODO: is there a better way? +def scalar_rem (x y : Int) : Int := + if 0 ≤ x then |x| % |y| + else - (|x| % |y|) + +-- Our custom division operation, which satisfies the semantics of Rust +-- TODO: is there a better way? +def scalar_div (x y : Int) : Int := + if 0 ≤ x && 0 ≤ y then |x| / |y| + else if 0 ≤ x && y < 0 then - (|x| / |y|) + else if x < 0 && 0 ≤ y then - (|x| / |y|) + else |x| / |y| + +-- Checking that the remainder operation is correct +#assert scalar_rem 1 2 = 1 +#assert scalar_rem (-1) 2 = -1 +#assert scalar_rem 1 (-2) = 1 +#assert scalar_rem (-1) (-2) = -1 +#assert scalar_rem 7 3 = (1:Int) +#assert scalar_rem (-7) 3 = -1 +#assert scalar_rem 7 (-3) = 1 +#assert scalar_rem (-7) (-3) = -1 + +-- Checking that the division operation is correct +#assert scalar_div 3 2 = 1 +#assert scalar_div (-3) 2 = -1 +#assert scalar_div 3 (-2) = -1 +#assert scalar_div (-3) (-2) = 1 +#assert scalar_div 7 3 = 2 +#assert scalar_div (-7) 3 = -2 +#assert scalar_div 7 (-3) = -2 +#assert scalar_div (-7) (-3) = 2 + +def Scalar.rem {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) := + if y.val != 0 then Scalar.tryMk ty (x.val % y.val) else fail divisionByZero + +def Scalar.add {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) := + Scalar.tryMk ty (x.val + y.val) + +def Scalar.sub {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) := + Scalar.tryMk ty (x.val - y.val) + +def Scalar.mul {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) := + Scalar.tryMk ty (x.val * y.val) + +-- TODO: instances of +, -, * etc. for scalars + +-- Cast an integer from a [src_ty] to a [tgt_ty] +-- TODO: check the semantics of casts in Rust +def Scalar.cast {src_ty : ScalarTy} (tgt_ty : ScalarTy) (x : Scalar src_ty) : Result (Scalar tgt_ty) := + Scalar.tryMk tgt_ty x.val + +-- The scalar types +-- We declare the definitions as reducible so that Lean can unfold them (useful +-- for type class resolution for instance). +@[reducible] def Isize := Scalar .Isize +@[reducible] def I8 := Scalar .I8 +@[reducible] def I16 := Scalar .I16 +@[reducible] def I32 := Scalar .I32 +@[reducible] def I64 := Scalar .I64 +@[reducible] def I128 := Scalar .I128 +@[reducible] def Usize := Scalar .Usize +@[reducible] def U8 := Scalar .U8 +@[reducible] def U16 := Scalar .U16 +@[reducible] def U32 := Scalar .U32 +@[reducible] def U64 := Scalar .U64 +@[reducible] def U128 := Scalar .U128 + +-- TODO: below: not sure this is the best way. +-- Should we rather overload operations like +, -, etc.? +-- Also, it is possible to automate the generation of those definitions +-- with macros (but would it be a good idea? It would be less easy to +-- read the file, which is not supposed to change a lot) + +-- Negation + +/-- +Remark: there is no heterogeneous negation in the Lean prelude: we thus introduce +one here. + +The notation typeclass for heterogeneous addition. +This enables the notation `- a : β` where `a : α`. +-/ +class HNeg (α : Type u) (β : outParam (Type v)) where + /-- `- a` computes the negation of `a`. + The meaning of this notation is type-dependent. -/ + hNeg : α → β + +prefix:75 "-" => HNeg.hNeg + +instance : HNeg Isize (Result Isize) where hNeg x := Scalar.neg x +instance : HNeg I8 (Result I8) where hNeg x := Scalar.neg x +instance : HNeg I16 (Result I16) where hNeg x := Scalar.neg x +instance : HNeg I32 (Result I32) where hNeg x := Scalar.neg x +instance : HNeg I64 (Result I64) where hNeg x := Scalar.neg x +instance : HNeg I128 (Result I128) where hNeg x := Scalar.neg x + +-- Addition +instance {ty} : HAdd (Scalar ty) (Scalar ty) (Result (Scalar ty)) where + hAdd x y := Scalar.add x y + +-- Substraction +instance {ty} : HSub (Scalar ty) (Scalar ty) (Result (Scalar ty)) where + hSub x y := Scalar.sub x y + +-- Multiplication +instance {ty} : HMul (Scalar ty) (Scalar ty) (Result (Scalar ty)) where + hMul x y := Scalar.mul x y + +-- Division +instance {ty} : HDiv (Scalar ty) (Scalar ty) (Result (Scalar ty)) where + hDiv x y := Scalar.div x y + +-- Remainder +instance {ty} : HMod (Scalar ty) (Scalar ty) (Result (Scalar ty)) where + hMod x y := Scalar.rem x y + +-- ofIntCore +-- TODO: typeclass? +def Isize.ofIntCore := @Scalar.ofIntCore .Isize +def I8.ofIntCore := @Scalar.ofIntCore .I8 +def I16.ofIntCore := @Scalar.ofIntCore .I16 +def I32.ofIntCore := @Scalar.ofIntCore .I32 +def I64.ofIntCore := @Scalar.ofIntCore .I64 +def I128.ofIntCore := @Scalar.ofIntCore .I128 +def Usize.ofIntCore := @Scalar.ofIntCore .Usize +def U8.ofIntCore := @Scalar.ofIntCore .U8 +def U16.ofIntCore := @Scalar.ofIntCore .U16 +def U32.ofIntCore := @Scalar.ofIntCore .U32 +def U64.ofIntCore := @Scalar.ofIntCore .U64 +def U128.ofIntCore := @Scalar.ofIntCore .U128 + +-- ofInt +-- TODO: typeclass? +def Isize.ofInt := @Scalar.ofInt .Isize +def I8.ofInt := @Scalar.ofInt .I8 +def I16.ofInt := @Scalar.ofInt .I16 +def I32.ofInt := @Scalar.ofInt .I32 +def I64.ofInt := @Scalar.ofInt .I64 +def I128.ofInt := @Scalar.ofInt .I128 +def Usize.ofInt := @Scalar.ofInt .Usize +def U8.ofInt := @Scalar.ofInt .U8 +def U16.ofInt := @Scalar.ofInt .U16 +def U32.ofInt := @Scalar.ofInt .U32 +def U64.ofInt := @Scalar.ofInt .U64 +def U128.ofInt := @Scalar.ofInt .U128 + +-- Comparisons +instance {ty} : LT (Scalar ty) where + lt a b := LT.lt a.val b.val + +instance {ty} : LE (Scalar ty) where le a b := LE.le a.val b.val + +instance Scalar.decLt {ty} (a b : Scalar ty) : Decidable (LT.lt a b) := Int.decLt .. +instance Scalar.decLe {ty} (a b : Scalar ty) : Decidable (LE.le a b) := Int.decLe .. + +theorem Scalar.eq_of_val_eq {ty} : ∀ {i j : Scalar ty}, Eq i.val j.val → Eq i j + | ⟨_, _, _⟩, ⟨_, _, _⟩, rfl => rfl + +theorem Scalar.val_eq_of_eq {ty} {i j : Scalar ty} (h : Eq i j) : Eq i.val j.val := + h ▸ rfl + +theorem Scalar.ne_of_val_ne {ty} {i j : Scalar ty} (h : Not (Eq i.val j.val)) : Not (Eq i j) := + fun h' => absurd (val_eq_of_eq h') h + +instance (ty : ScalarTy) : DecidableEq (Scalar ty) := + fun i j => + match decEq i.val j.val with + | isTrue h => isTrue (Scalar.eq_of_val_eq h) + | isFalse h => isFalse (Scalar.ne_of_val_ne h) + +/- Remark: we can't write the following instance because of restrictions about + the type class parameters (`ty` doesn't appear in the return type, which is + forbidden): + + ``` + instance Scalar.cast (ty : ScalarTy) : Coe (Scalar ty) Int where coe := λ v => v.val + ``` + -/ +def Scalar.toInt {ty} (n : Scalar ty) : Int := n.val + +-- -- We now define a type class that subsumes the various machine integer types, so +-- -- as to write a concise definition for scalar_cast, rather than exhaustively +-- -- enumerating all of the possible pairs. We remark that Rust has sane semantics +-- -- and fails if a cast operation would involve a truncation or modulo. + +-- class MachineInteger (t: Type) where +-- size: Nat +-- val: t -> Fin size +-- ofNatCore: (n:Nat) -> LT.lt n size -> t + +-- set_option hygiene false in +-- run_cmd +-- for typeName in [`UInt8, `UInt16, `UInt32, `UInt64, `USize].map Lean.mkIdent do +-- Lean.Elab.Command.elabCommand (← `( +-- namespace $typeName +-- instance: MachineInteger $typeName where +-- size := size +-- val := val +-- ofNatCore := ofNatCore +-- end $typeName +-- )) + +-- -- Aeneas only instantiates the destination type (`src` is implicit). We rely on +-- -- Lean to infer `src`. + +-- def scalar_cast { src: Type } (dst: Type) [ MachineInteger src ] [ MachineInteger dst ] (x: src): Result dst := +-- if h: MachineInteger.val x < MachineInteger.size dst then +-- .ret (MachineInteger.ofNatCore (MachineInteger.val x).val h) +-- else +-- .fail integerOverflow + +end Primitives diff --git a/backends/lean/Base/Primitives/Vec.lean b/backends/lean/Base/Primitives/Vec.lean new file mode 100644 index 00000000..7851a232 --- /dev/null +++ b/backends/lean/Base/Primitives/Vec.lean @@ -0,0 +1,113 @@ +import Lean +import Lean.Meta.Tactic.Simp +import Init.Data.List.Basic +import Mathlib.Tactic.RunCmd +import Mathlib.Tactic.Linarith +import Base.IList +import Base.Primitives.Scalar +import Base.Arith + +namespace Primitives + +open Result Error + +------------- +-- VECTORS -- +------------- + +def Vec (α : Type u) := { l : List α // List.length l ≤ Usize.max } + +-- TODO: do we really need it? It should be with Subtype by default +instance Vec.cast (a : Type): Coe (Vec a) (List a) where coe := λ v => v.val + +instance (a : Type) : Arith.HasIntProp (Vec a) where + prop_ty := λ v => v.val.length ≤ Scalar.max ScalarTy.Usize + prop := λ ⟨ _, l ⟩ => l + +example {a: Type} (v : Vec a) : v.val.length ≤ Scalar.max ScalarTy.Usize := by + intro_has_int_prop_instances + simp_all [Scalar.max, Scalar.min] + +example {a: Type} (v : Vec a) : v.val.length ≤ Scalar.max ScalarTy.Usize := by + scalar_tac + +def Vec.new (α : Type u): Vec α := ⟨ [], by apply Scalar.cMax_suffices .Usize; simp ⟩ + +def Vec.len (α : Type u) (v : Vec α) : Usize := + let ⟨ v, l ⟩ := v + Usize.ofIntCore (List.length v) (by simp [Scalar.min, Usize.min]) l + +def Vec.length {α : Type u} (v : Vec α) : Int := v.val.len + +-- This shouldn't be used +def Vec.push_fwd (α : Type u) (_ : Vec α) (_ : α) : Unit := () + +-- This is actually the backward function +def Vec.push (α : Type u) (v : Vec α) (x : α) : Result (Vec α) + := + let nlen := List.length v.val + 1 + if h : nlen ≤ U32.max || nlen ≤ Usize.max then + have h : nlen ≤ Usize.max := by + simp [Usize.max] at * + have hm := Usize.refined_max.property + cases h <;> cases hm <;> simp [U32.max, U64.max] at * <;> try linarith + return ⟨ List.concat v.val x, by simp at *; assumption ⟩ + else + fail maximumSizeExceeded + +-- This shouldn't be used +def Vec.insert_fwd (α : Type u) (v: Vec α) (i: Usize) (_: α): Result Unit := + if i.val < List.length v.val then + .ret () + else + .fail arrayOutOfBounds + +-- This is actually the backward function +def Vec.insert (α : Type u) (v: Vec α) (i: Usize) (x: α): Result (Vec α) := + if i.val < List.length v.val then + -- TODO: maybe we should redefine a list library which uses integers + -- (instead of natural numbers) + .ret ⟨ v.val.update i.val x, by have := v.property; simp [*] ⟩ + else + .fail arrayOutOfBounds + +-- TODO: remove +def Vec.index_to_fin {α : Type u} {v: Vec α} {i: Usize} (h : i.val < List.length v.val) : + Fin (List.length v.val) := + let j := i.val.toNat + let h: j < List.length v.val := by + have heq := @Int.toNat_lt (List.length v.val) i.val i.hmin + apply heq.mpr + assumption + ⟨j, h⟩ + +def Vec.index (α : Type u) (v: Vec α) (i: Usize): Result α := + match v.val.indexOpt i.val with + | none => fail .arrayOutOfBounds + | some x => ret x + +-- This shouldn't be used +def Vec.index_back (α : Type u) (v: Vec α) (i: Usize) (_: α): Result Unit := + if i.val < List.length v.val then + .ret () + else + .fail arrayOutOfBounds + +def Vec.index_mut (α : Type u) (v: Vec α) (i: Usize): Result α := + if h: i.val < List.length v.val then + let i := Vec.index_to_fin h + .ret (List.get v.val i) + else + .fail arrayOutOfBounds + +def Vec.index_mut_back (α : Type u) (v: Vec α) (i: Usize) (x: α): Result (Vec α) := + if h: i.val < List.length v.val then + let i := Vec.index_to_fin h + .ret ⟨ List.set v.val i x, by + have h: List.length v.val ≤ Usize.max := v.property + simp [*] at * + ⟩ + else + .fail arrayOutOfBounds + +end Primitives diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean index 35a3c25a..af7b426a 100644 --- a/backends/lean/Base/Progress/Progress.lean +++ b/backends/lean/Base/Progress/Progress.lean @@ -7,6 +7,7 @@ namespace Progress open Lean Elab Term Meta Tactic open Utils +-- TODO: remove namespace Test open Primitives @@ -199,6 +200,7 @@ def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do elab "progress" args:progressArgs : tactic => evalProgress args +-- TODO: remove namespace Test open Primitives |