summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-07-04 17:30:35 +0200
committerSon Ho2023-07-04 17:30:35 +0200
commitbd873499f9a8d517cc948c6336a5c6ce856d846d (patch)
tree0e4fc5eda91c9d34c27790286a6098dc937e79b9
parent87d6f6c7c90bf7b427397d6bd2e2c70d610678e3 (diff)
Fix some issues with the extraction to Lean
Diffstat (limited to '')
-rw-r--r--backends/lean/Base/Diverge/Elab.lean63
-rw-r--r--compiler/Extract.ml134
2 files changed, 133 insertions, 64 deletions
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean
index 41209021..4b08fe44 100644
--- a/backends/lean/Base/Diverge/Elab.lean
+++ b/backends/lean/Base/Diverge/Elab.lean
@@ -255,10 +255,11 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
preDefs.mapM fun preDef => do
-- Replace the recursive calls
let body ← mapVisit visit_e preDef.value
+ trace[Diverge.def.genBody] "Body after replacement of the recursive calls: {body}"
-- Currify the function by grouping the arguments into a dependent tuple
-- (over which we match to retrieve the individual arguments).
- lambdaLetTelescope body fun args body => do
+ lambdaTelescope body fun args body => do
let body ← mkSigmasMatch args.toList body 0
-- Add the declaration
@@ -376,15 +377,18 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
| .sort _ => throwError "Unreachable"
| .lam .. => throwError "Unimplemented"
| .forallE .. => throwError "Unreachable" -- Shouldn't get there
- | .letE dName dTy dValue body _nonDep => do
- -- Introduce a local declaration for the let-binding
- withLetDecl dName dTy dValue fun decl => do
+ | .letE .. => do
+ -- Telescope all the let-bindings (remark: this also telescopes the lambdas)
+ lambdaLetTelescope e fun xs body => do
+ -- Note that we don't visit the bound values: there shouldn't be
+ -- recursive calls, lambda expressions, etc. inside
+ -- Prove that the body is valid
let isValid ← proveExprIsValid k_var kk_var body
- -- Add the let-binding around.
+ -- Add the let-bindings around.
-- Rem.: the let-binding should be *inside* the `is_valid_p`, not outside,
-- but because it reduces in the end it doesn't matter. More precisely:
-- `P (let x := v in y)` and `let x := v in P y` reduce to the same expression.
- mkLetFVars #[decl] isValid
+ mkLambdaFVars xs isValid (usedLetOnly := false)
| .mdata _ b => proveExprIsValid k_var kk_var b
| .proj _ _ _ =>
-- The projection shouldn't use the continuation
@@ -410,7 +414,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
if isIte then proveExprIsValid k_var kk_var br
else do
-- There is a lambda -- TODO: how do we remove exacly *one* lambda?
- lambdaLetTelescope br fun xs br => do
+ lambdaTelescope br fun xs br => do
let x := xs.get! 0
let xs := xs.extract 1 xs.size
let br ← mkLambdaFVars xs br
@@ -518,7 +522,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}"
let yValid ← do
-- This is a lambda expression -- TODO: how do we remove exacly *one* lambda?
- lambdaLetTelescope y fun xs y => do
+ lambdaTelescope y fun xs y => do
let x := xs.get! 0
let xs := xs.extract 1 xs.size
let y ← mkLambdaFVars xs y
@@ -555,7 +559,7 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp
-- binders might come from the match, and some of the binders might come
-- from the fact that the expression in the match is a lambda expression:
-- we use the branchesNumParams field for this reason
- lambdaLetTelescope br fun xs br => do
+ lambdaTelescope br fun xs br => do
let numParams := me.branchesNumParams.get! idx
let xs_beg := xs.extract 0 numParams
let xs_end := xs.extract numParams xs.size
@@ -622,7 +626,7 @@ partial def proveSingleBodyIsValid
let env ← getEnv
let body := (env.constants.find! name).value!
trace[Diverge.def.valid] "body: {body}"
- lambdaLetTelescope body fun xs body => do
+ lambdaTelescope body fun xs body => do
assert! xs.size = 2
let kk_var := xs.get! 0
let x_var := xs.get! 1
@@ -695,8 +699,10 @@ def proveMutRecIsValid
let bodiesValid ←
bodies.mapIdxM fun idx body => do
let preDef := preDefs.get! idx
+ trace[Diverge.def.valid] "## Proving that the body {body} is valid"
proveSingleBodyIsValid k_var preDef body
-- Then prove that the mut rec body is valid
+ trace[Diverge.def.valid] "## Proving that the 'Funs' body is valid"
let isValid ← proveFunsBodyIsValid inOutTys bodyFuns k_var bodiesValid
-- Save the theorem
let thmTy ← mkAppM ``FixI.is_valid #[mutRecBodyConst]
@@ -724,7 +730,7 @@ def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) :
TermElabM (Array Name) := do
let grSize := preDefs.size
let defs ← preDefs.mapIdxM fun idx preDef => do
- lambdaLetTelescope preDef.value fun xs _ => do
+ lambdaTelescope preDef.value fun xs _ => do
-- Create the index
let idx ← mkFinVal grSize idx.val
-- Group the inputs into a dependent tuple
@@ -755,7 +761,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinitio
let preDef := preDefs.get! i
let defName := decls.get! i
-- Retrieve the arguments
- lambdaLetTelescope preDef.value fun xs body => do
+ lambdaTelescope preDef.value fun xs body => do
trace[Diverge.def.unfold] "proveUnfoldingThms: xs: {xs}"
trace[Diverge.def.unfold] "proveUnfoldingThms: body: {body}"
-- The theorem statement
@@ -799,7 +805,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinitio
def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value)
- trace[Diverge.def] ("divRecursion: defs: " ++ msg)
+ trace[Diverge.def] ("divRecursion: defs:\n" ++ msg)
-- TODO: what is this?
for preDef in preDefs do
@@ -880,8 +886,11 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
-- Replace the recursive calls in all the function bodies by calls to the
-- continuation `k` and and generate for those bodies declarations
+ trace[Diverge.def] "# Generating the unary bodies"
let bodies ← mkDeclareUnaryBodies grLvlParams kk_var preDefs
+ trace[Diverge.def] "Unary bodies (after decl): {bodies}"
-- Generate the mutually recursive body
+ trace[Diverge.def] "# Generating the mut rec body"
let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var in_ty out_ty inOutTys.toList bodies
trace[Diverge.def] "mut rec body (after decl): {mutRecBody}"
@@ -889,15 +898,18 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
-- our fixed-point
let k_var_ty ← mkAppM ``FixI.k_ty #[i_var_ty, in_ty, out_ty]
withLocalDeclD (mkAnonymous "k" 3) k_var_ty fun k_var => do
+ trace[Diverge.def] "# Proving that the mut rec body is valid"
let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies
-- Generate the final definitions
+ trace[Diverge.def] "# Generating the final definitions"
let decls ← mkDeclareFixDefs mutRecBody preDefs
-- Prove the unfolding theorems
+ trace[Diverge.def] "# Proving the unfolding theorems"
proveUnfoldingThms isValidThm preDefs decls
- -- Process the definitions - TODO
+ -- Generating code -- TODO
addAndCompilePartialRec preDefs
-- The following function is copy&pasted from Lean.Elab.PreDefinition.Main
@@ -1064,13 +1076,32 @@ namespace Tests
-- Testing dependent branching and let-bindings
-- TODO: why the linter warning?
- divergent def is_non_zero (i : Int) : Result Bool :=
+ divergent def isNonZero (i : Int) : Result Bool :=
if _h:i = 0 then return false
else
let b := true
return b
- #check is_non_zero.unfold
+ #check isNonZero.unfold
+
+ -- Testing let-bindings
+ divergent def iInBounds {a : Type} (ls : List a) (i : Int) : Result Bool :=
+ let i0 := ls.length
+ if i < i0
+ then Result.ret True
+ else Result.ret False
+
+ #check iInBounds.unfold
+
+ divergent def isCons
+ {a : Type} (ls : List a) : Result Bool :=
+ let ls1 := ls
+ match ls1 with
+ | [] => Result.ret False
+ | x :: tl => Result.ret True
+
+ #check isCons.unfold
+
end Tests
end Diverge
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index a54a2299..b18d4743 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -618,9 +618,12 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
let struct_constructor (basename : name) : string =
let tname = type_name basename in
let prefix =
- match !backend with FStar -> "Mk" | Lean | Coq | HOL4 -> "mk"
+ match !backend with FStar -> "Mk" | Coq | HOL4 -> "mk" | Lean -> ""
in
- prefix ^ tname
+ let suffix =
+ match !backend with FStar | Coq | HOL4 -> "" | Lean -> ".mk"
+ in
+ prefix ^ tname ^ suffix
in
let get_fun_name = get_name in
let fun_name_to_snake_case (fname : fun_name) : string =
@@ -1326,7 +1329,8 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt (unit_name ()))
else if !backend = Lean && fields = [] then ()
(* If the definition is recursive, we may need to extract it as an inductive
- (instead of a record) *)
+ (instead of a record). We start with the "normal" case: we extract it
+ as a record. *)
else if (not is_rec) || (!backend <> Coq && !backend <> Lean) then (
if !backend <> Lean then F.pp_print_space fmt ();
(* If Coq: print the constructor name *)
@@ -1379,7 +1383,14 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter)
a group of mutually recursive types: we extract it as an inductive type *)
assert (is_rec && (!backend = Coq || !backend = Lean));
let with_opaque_pre = false in
- let cons_name = ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx in
+ (* Small trick: in Lean we use namespaces, meaning we don't need to prefix
+ the constructor name with the name of the type at definition site,
+ i.e., instead of generating `inductive Foo := | MkFoo ...` like in Coq
+ we generate `inductive Foo := | mk ... *)
+ let cons_name =
+ if !backend = Lean then "mk"
+ else ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx
+ in
let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in
extract_type_decl_variant ctx fmt type_decl_group def_name type_params
cons_name fields)
@@ -1950,14 +1961,17 @@ let extract_global_decl_register_names (ctx : extraction_ctx)
Note that patterns can introduce new variables: we thus return an extraction
context updated with new bindings.
+ [is_single_pat]: are we extracting a single pattern (a pattern for a let-binding
+ or a lambda).
+
TODO: we don't need something very generic anymore (some definitions used
to be polymorphic).
*)
let extract_adt_g_value
(extract_value : extraction_ctx -> bool -> 'v -> extraction_ctx)
- (fmt : F.formatter) (ctx : extraction_ctx) (inside : bool)
- (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) :
- extraction_ctx =
+ (fmt : F.formatter) (ctx : extraction_ctx) (is_single_pat : bool)
+ (inside : bool) (variant_id : VariantId.id option) (field_values : 'v list)
+ (ty : ty) : extraction_ctx =
match ty with
| Adt (Tuple, type_args) ->
(* Tuple *)
@@ -1982,36 +1996,57 @@ let extract_adt_g_value
ctx)
| Adt (adt_id, _) ->
(* "Regular" ADT *)
- (* We print something of the form: [Cons field0 ... fieldn].
- * We could update the code to print something of the form:
- * [{ field0=...; ...; fieldn=...; }] in case of structures.
- *)
- let cons =
- (* The ADT shouldn't be opaque *)
- let with_opaque_pre = false in
- match variant_id with
- | Some vid -> (
- (* In the case of Lean, we might have to add the type name as a prefix *)
- match (!backend, adt_id) with
- | Lean, Assumed _ ->
- ctx_get_type with_opaque_pre adt_id ctx
- ^ "."
- ^ ctx_get_variant adt_id vid ctx
- | _ -> ctx_get_variant adt_id vid ctx)
- | None -> ctx_get_struct with_opaque_pre adt_id ctx
- in
- let use_parentheses = inside && field_values <> [] in
- if use_parentheses then F.pp_print_string fmt "(";
- F.pp_print_string fmt cons;
- let ctx =
- Collections.List.fold_left
- (fun ctx v ->
- F.pp_print_space fmt ();
- extract_value ctx true v)
- ctx field_values
- in
- if use_parentheses then F.pp_print_string fmt ")";
- ctx
+
+ (* If we are generating a pattern for a let-binding and we target Lean,
+ the syntax is: `let ⟨ x0, ..., xn ⟩ := ...`.
+
+ Otherwise, it is: `let Cons x0 ... xn = ...`
+ *)
+ if is_single_pat && !Config.backend = Lean then (
+ F.pp_print_string fmt "⟨";
+ F.pp_print_space fmt ();
+ let ctx =
+ Collections.List.fold_left_link
+ (fun _ ->
+ F.pp_print_string fmt ",";
+ F.pp_print_space fmt ())
+ (fun ctx v -> extract_value ctx true v)
+ ctx field_values
+ in
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "⟩";
+ ctx)
+ else
+ (* We print something of the form: [Cons field0 ... fieldn].
+ * We could update the code to print something of the form:
+ * [{ field0=...; ...; fieldn=...; }] in case of structures.
+ *)
+ let cons =
+ (* The ADT shouldn't be opaque *)
+ let with_opaque_pre = false in
+ match variant_id with
+ | Some vid -> (
+ (* In the case of Lean, we might have to add the type name as a prefix *)
+ match (!backend, adt_id) with
+ | Lean, Assumed _ ->
+ ctx_get_type with_opaque_pre adt_id ctx
+ ^ "."
+ ^ ctx_get_variant adt_id vid ctx
+ | _ -> ctx_get_variant adt_id vid ctx)
+ | None -> ctx_get_struct with_opaque_pre adt_id ctx
+ in
+ let use_parentheses = inside && field_values <> [] in
+ if use_parentheses then F.pp_print_string fmt "(";
+ F.pp_print_string fmt cons;
+ let ctx =
+ Collections.List.fold_left
+ (fun ctx v ->
+ F.pp_print_space fmt ();
+ extract_value ctx true v)
+ ctx field_values
+ in
+ if use_parentheses then F.pp_print_string fmt ")";
+ ctx
| _ -> raise (Failure "Inconsistent typed value")
(* Extract globals in the same way as variables *)
@@ -2026,7 +2061,7 @@ let extract_global (ctx : extraction_ctx) (fmt : F.formatter)
updated with new bindings.
*)
let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter)
- (inside : bool) (v : typed_pattern) : extraction_ctx =
+ (is_let : bool) (inside : bool) (v : typed_pattern) : extraction_ctx =
match v.value with
| PatConstant cv ->
ctx.fmt.extract_primitive_value fmt inside cv;
@@ -2042,8 +2077,10 @@ let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt "_";
ctx
| PatAdt av ->
- let extract_value ctx inside v = extract_typed_pattern ctx fmt inside v in
- extract_adt_g_value extract_value fmt ctx inside av.variant_id
+ let extract_value ctx inside v =
+ extract_typed_pattern ctx fmt is_let inside v
+ in
+ extract_adt_g_value extract_value fmt ctx is_let inside av.variant_id
av.field_values v.ty
(** [inside]: controls the introduction of parentheses. See [extract_ty]
@@ -2173,12 +2210,13 @@ and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(adt_cons : adt_cons_id) (type_args : ty list) (args : texpression list) :
unit =
let e_ty = Adt (adt_cons.adt_id, type_args) in
+ let is_single_pat = false in
let _ =
extract_adt_g_value
(fun ctx inside e ->
extract_texpression ctx fmt inside e;
ctx)
- fmt ctx inside adt_cons.variant_id args e_ty
+ fmt ctx is_single_pat inside adt_cons.variant_id args e_ty
in
()
@@ -2226,7 +2264,7 @@ and extract_Abs (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
List.fold_left
(fun ctx x ->
F.pp_print_space fmt ();
- extract_typed_pattern ctx fmt true x)
+ extract_typed_pattern ctx fmt true true x)
ctx xl
in
F.pp_print_space fmt ();
@@ -2295,7 +2333,7 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
* TODO: cleanup
* *)
if monadic && (!backend = Coq || !backend = HOL4) then (
- let ctx = extract_typed_pattern ctx fmt true lv in
+ let ctx = extract_typed_pattern ctx fmt true true lv in
F.pp_print_space fmt ();
let arrow =
match !backend with
@@ -2321,7 +2359,7 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
else (
F.pp_print_string fmt "let";
F.pp_print_space fmt ());
- let ctx = extract_typed_pattern ctx fmt true lv in
+ let ctx = extract_typed_pattern ctx fmt true true lv in
F.pp_print_space fmt ();
let eq =
match !backend with
@@ -2468,7 +2506,7 @@ and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (_inside : bool)
match !backend with
| FStar -> "begin match"
| Coq -> "match"
- | Lean -> "match h:"
+ | Lean -> if ctx.use_dep_ite then "match h:" else "match"
| HOL4 ->
(* We're being extra safe in the case of HOL4 *)
"(case"
@@ -2495,7 +2533,7 @@ and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (_inside : bool)
(* Print the pattern *)
F.pp_print_string fmt "|";
F.pp_print_space fmt ();
- let ctx = extract_typed_pattern ctx fmt false br.pat in
+ let ctx = extract_typed_pattern ctx fmt false false br.pat in
F.pp_print_space fmt ();
let arrow =
match !backend with FStar -> "->" | Coq | Lean | HOL4 -> "=>"
@@ -2687,7 +2725,7 @@ let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx)
(* Open a box for the input parameter *)
F.pp_open_hovbox fmt 0;
F.pp_print_string fmt "(";
- let ctx = extract_typed_pattern ctx fmt false lv in
+ let ctx = extract_typed_pattern ctx fmt true false lv in
F.pp_print_space fmt ();
F.pp_print_string fmt ":";
F.pp_print_space fmt ();
@@ -3032,7 +3070,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
List.fold_left
(fun ctx (lv : typed_pattern) ->
F.pp_print_space fmt ();
- let ctx = extract_typed_pattern ctx fmt false lv in
+ let ctx = extract_typed_pattern ctx fmt true false lv in
ctx)
ctx inputs_lvs
in