diff options
author | Son Ho | 2023-07-04 17:30:35 +0200 |
---|---|---|
committer | Son Ho | 2023-07-04 17:30:35 +0200 |
commit | bd873499f9a8d517cc948c6336a5c6ce856d846d (patch) | |
tree | 0e4fc5eda91c9d34c27790286a6098dc937e79b9 | |
parent | 87d6f6c7c90bf7b427397d6bd2e2c70d610678e3 (diff) |
Fix some issues with the extraction to Lean
Diffstat (limited to '')
-rw-r--r-- | backends/lean/Base/Diverge/Elab.lean | 63 | ||||
-rw-r--r-- | compiler/Extract.ml | 134 |
2 files changed, 133 insertions, 64 deletions
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 41209021..4b08fe44 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -255,10 +255,11 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) preDefs.mapM fun preDef => do -- Replace the recursive calls let body ← mapVisit visit_e preDef.value + trace[Diverge.def.genBody] "Body after replacement of the recursive calls: {body}" -- Currify the function by grouping the arguments into a dependent tuple -- (over which we match to retrieve the individual arguments). - lambdaLetTelescope body fun args body => do + lambdaTelescope body fun args body => do let body ← mkSigmasMatch args.toList body 0 -- Add the declaration @@ -376,15 +377,18 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do | .sort _ => throwError "Unreachable" | .lam .. => throwError "Unimplemented" | .forallE .. => throwError "Unreachable" -- Shouldn't get there - | .letE dName dTy dValue body _nonDep => do - -- Introduce a local declaration for the let-binding - withLetDecl dName dTy dValue fun decl => do + | .letE .. => do + -- Telescope all the let-bindings (remark: this also telescopes the lambdas) + lambdaLetTelescope e fun xs body => do + -- Note that we don't visit the bound values: there shouldn't be + -- recursive calls, lambda expressions, etc. inside + -- Prove that the body is valid let isValid ← proveExprIsValid k_var kk_var body - -- Add the let-binding around. + -- Add the let-bindings around. -- Rem.: the let-binding should be *inside* the `is_valid_p`, not outside, -- but because it reduces in the end it doesn't matter. More precisely: -- `P (let x := v in y)` and `let x := v in P y` reduce to the same expression. - mkLetFVars #[decl] isValid + mkLambdaFVars xs isValid (usedLetOnly := false) | .mdata _ b => proveExprIsValid k_var kk_var b | .proj _ _ _ => -- The projection shouldn't use the continuation @@ -410,7 +414,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do if isIte then proveExprIsValid k_var kk_var br else do -- There is a lambda -- TODO: how do we remove exacly *one* lambda? - lambdaLetTelescope br fun xs br => do + lambdaTelescope br fun xs br => do let x := xs.get! 0 let xs := xs.extract 1 xs.size let br ← mkLambdaFVars xs br @@ -518,7 +522,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}" let yValid ← do -- This is a lambda expression -- TODO: how do we remove exacly *one* lambda? - lambdaLetTelescope y fun xs y => do + lambdaTelescope y fun xs y => do let x := xs.get! 0 let xs := xs.extract 1 xs.size let y ← mkLambdaFVars xs y @@ -555,7 +559,7 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp -- binders might come from the match, and some of the binders might come -- from the fact that the expression in the match is a lambda expression: -- we use the branchesNumParams field for this reason - lambdaLetTelescope br fun xs br => do + lambdaTelescope br fun xs br => do let numParams := me.branchesNumParams.get! idx let xs_beg := xs.extract 0 numParams let xs_end := xs.extract numParams xs.size @@ -622,7 +626,7 @@ partial def proveSingleBodyIsValid let env ← getEnv let body := (env.constants.find! name).value! trace[Diverge.def.valid] "body: {body}" - lambdaLetTelescope body fun xs body => do + lambdaTelescope body fun xs body => do assert! xs.size = 2 let kk_var := xs.get! 0 let x_var := xs.get! 1 @@ -695,8 +699,10 @@ def proveMutRecIsValid let bodiesValid ← bodies.mapIdxM fun idx body => do let preDef := preDefs.get! idx + trace[Diverge.def.valid] "## Proving that the body {body} is valid" proveSingleBodyIsValid k_var preDef body -- Then prove that the mut rec body is valid + trace[Diverge.def.valid] "## Proving that the 'Funs' body is valid" let isValid ← proveFunsBodyIsValid inOutTys bodyFuns k_var bodiesValid -- Save the theorem let thmTy ← mkAppM ``FixI.is_valid #[mutRecBodyConst] @@ -724,7 +730,7 @@ def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : TermElabM (Array Name) := do let grSize := preDefs.size let defs ← preDefs.mapIdxM fun idx preDef => do - lambdaLetTelescope preDef.value fun xs _ => do + lambdaTelescope preDef.value fun xs _ => do -- Create the index let idx ← mkFinVal grSize idx.val -- Group the inputs into a dependent tuple @@ -755,7 +761,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinitio let preDef := preDefs.get! i let defName := decls.get! i -- Retrieve the arguments - lambdaLetTelescope preDef.value fun xs body => do + lambdaTelescope preDef.value fun xs body => do trace[Diverge.def.unfold] "proveUnfoldingThms: xs: {xs}" trace[Diverge.def.unfold] "proveUnfoldingThms: body: {body}" -- The theorem statement @@ -799,7 +805,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinitio def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) - trace[Diverge.def] ("divRecursion: defs: " ++ msg) + trace[Diverge.def] ("divRecursion: defs:\n" ++ msg) -- TODO: what is this? for preDef in preDefs do @@ -880,8 +886,11 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Replace the recursive calls in all the function bodies by calls to the -- continuation `k` and and generate for those bodies declarations + trace[Diverge.def] "# Generating the unary bodies" let bodies ← mkDeclareUnaryBodies grLvlParams kk_var preDefs + trace[Diverge.def] "Unary bodies (after decl): {bodies}" -- Generate the mutually recursive body + trace[Diverge.def] "# Generating the mut rec body" let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var in_ty out_ty inOutTys.toList bodies trace[Diverge.def] "mut rec body (after decl): {mutRecBody}" @@ -889,15 +898,18 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- our fixed-point let k_var_ty ← mkAppM ``FixI.k_ty #[i_var_ty, in_ty, out_ty] withLocalDeclD (mkAnonymous "k" 3) k_var_ty fun k_var => do + trace[Diverge.def] "# Proving that the mut rec body is valid" let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies -- Generate the final definitions + trace[Diverge.def] "# Generating the final definitions" let decls ← mkDeclareFixDefs mutRecBody preDefs -- Prove the unfolding theorems + trace[Diverge.def] "# Proving the unfolding theorems" proveUnfoldingThms isValidThm preDefs decls - -- Process the definitions - TODO + -- Generating code -- TODO addAndCompilePartialRec preDefs -- The following function is copy&pasted from Lean.Elab.PreDefinition.Main @@ -1064,13 +1076,32 @@ namespace Tests -- Testing dependent branching and let-bindings -- TODO: why the linter warning? - divergent def is_non_zero (i : Int) : Result Bool := + divergent def isNonZero (i : Int) : Result Bool := if _h:i = 0 then return false else let b := true return b - #check is_non_zero.unfold + #check isNonZero.unfold + + -- Testing let-bindings + divergent def iInBounds {a : Type} (ls : List a) (i : Int) : Result Bool := + let i0 := ls.length + if i < i0 + then Result.ret True + else Result.ret False + + #check iInBounds.unfold + + divergent def isCons + {a : Type} (ls : List a) : Result Bool := + let ls1 := ls + match ls1 with + | [] => Result.ret False + | x :: tl => Result.ret True + + #check isCons.unfold + end Tests end Diverge diff --git a/compiler/Extract.ml b/compiler/Extract.ml index a54a2299..b18d4743 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -618,9 +618,12 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) let struct_constructor (basename : name) : string = let tname = type_name basename in let prefix = - match !backend with FStar -> "Mk" | Lean | Coq | HOL4 -> "mk" + match !backend with FStar -> "Mk" | Coq | HOL4 -> "mk" | Lean -> "" in - prefix ^ tname + let suffix = + match !backend with FStar | Coq | HOL4 -> "" | Lean -> ".mk" + in + prefix ^ tname ^ suffix in let get_fun_name = get_name in let fun_name_to_snake_case (fname : fun_name) : string = @@ -1326,7 +1329,8 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt (unit_name ())) else if !backend = Lean && fields = [] then () (* If the definition is recursive, we may need to extract it as an inductive - (instead of a record) *) + (instead of a record). We start with the "normal" case: we extract it + as a record. *) else if (not is_rec) || (!backend <> Coq && !backend <> Lean) then ( if !backend <> Lean then F.pp_print_space fmt (); (* If Coq: print the constructor name *) @@ -1379,7 +1383,14 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) a group of mutually recursive types: we extract it as an inductive type *) assert (is_rec && (!backend = Coq || !backend = Lean)); let with_opaque_pre = false in - let cons_name = ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx in + (* Small trick: in Lean we use namespaces, meaning we don't need to prefix + the constructor name with the name of the type at definition site, + i.e., instead of generating `inductive Foo := | MkFoo ...` like in Coq + we generate `inductive Foo := | mk ... *) + let cons_name = + if !backend = Lean then "mk" + else ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx + in let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in extract_type_decl_variant ctx fmt type_decl_group def_name type_params cons_name fields) @@ -1950,14 +1961,17 @@ let extract_global_decl_register_names (ctx : extraction_ctx) Note that patterns can introduce new variables: we thus return an extraction context updated with new bindings. + [is_single_pat]: are we extracting a single pattern (a pattern for a let-binding + or a lambda). + TODO: we don't need something very generic anymore (some definitions used to be polymorphic). *) let extract_adt_g_value (extract_value : extraction_ctx -> bool -> 'v -> extraction_ctx) - (fmt : F.formatter) (ctx : extraction_ctx) (inside : bool) - (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) : - extraction_ctx = + (fmt : F.formatter) (ctx : extraction_ctx) (is_single_pat : bool) + (inside : bool) (variant_id : VariantId.id option) (field_values : 'v list) + (ty : ty) : extraction_ctx = match ty with | Adt (Tuple, type_args) -> (* Tuple *) @@ -1982,36 +1996,57 @@ let extract_adt_g_value ctx) | Adt (adt_id, _) -> (* "Regular" ADT *) - (* We print something of the form: [Cons field0 ... fieldn]. - * We could update the code to print something of the form: - * [{ field0=...; ...; fieldn=...; }] in case of structures. - *) - let cons = - (* The ADT shouldn't be opaque *) - let with_opaque_pre = false in - match variant_id with - | Some vid -> ( - (* In the case of Lean, we might have to add the type name as a prefix *) - match (!backend, adt_id) with - | Lean, Assumed _ -> - ctx_get_type with_opaque_pre adt_id ctx - ^ "." - ^ ctx_get_variant adt_id vid ctx - | _ -> ctx_get_variant adt_id vid ctx) - | None -> ctx_get_struct with_opaque_pre adt_id ctx - in - let use_parentheses = inside && field_values <> [] in - if use_parentheses then F.pp_print_string fmt "("; - F.pp_print_string fmt cons; - let ctx = - Collections.List.fold_left - (fun ctx v -> - F.pp_print_space fmt (); - extract_value ctx true v) - ctx field_values - in - if use_parentheses then F.pp_print_string fmt ")"; - ctx + + (* If we are generating a pattern for a let-binding and we target Lean, + the syntax is: `let ⟨ x0, ..., xn ⟩ := ...`. + + Otherwise, it is: `let Cons x0 ... xn = ...` + *) + if is_single_pat && !Config.backend = Lean then ( + F.pp_print_string fmt "⟨"; + F.pp_print_space fmt (); + let ctx = + Collections.List.fold_left_link + (fun _ -> + F.pp_print_string fmt ","; + F.pp_print_space fmt ()) + (fun ctx v -> extract_value ctx true v) + ctx field_values + in + F.pp_print_space fmt (); + F.pp_print_string fmt "⟩"; + ctx) + else + (* We print something of the form: [Cons field0 ... fieldn]. + * We could update the code to print something of the form: + * [{ field0=...; ...; fieldn=...; }] in case of structures. + *) + let cons = + (* The ADT shouldn't be opaque *) + let with_opaque_pre = false in + match variant_id with + | Some vid -> ( + (* In the case of Lean, we might have to add the type name as a prefix *) + match (!backend, adt_id) with + | Lean, Assumed _ -> + ctx_get_type with_opaque_pre adt_id ctx + ^ "." + ^ ctx_get_variant adt_id vid ctx + | _ -> ctx_get_variant adt_id vid ctx) + | None -> ctx_get_struct with_opaque_pre adt_id ctx + in + let use_parentheses = inside && field_values <> [] in + if use_parentheses then F.pp_print_string fmt "("; + F.pp_print_string fmt cons; + let ctx = + Collections.List.fold_left + (fun ctx v -> + F.pp_print_space fmt (); + extract_value ctx true v) + ctx field_values + in + if use_parentheses then F.pp_print_string fmt ")"; + ctx | _ -> raise (Failure "Inconsistent typed value") (* Extract globals in the same way as variables *) @@ -2026,7 +2061,7 @@ let extract_global (ctx : extraction_ctx) (fmt : F.formatter) updated with new bindings. *) let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter) - (inside : bool) (v : typed_pattern) : extraction_ctx = + (is_let : bool) (inside : bool) (v : typed_pattern) : extraction_ctx = match v.value with | PatConstant cv -> ctx.fmt.extract_primitive_value fmt inside cv; @@ -2042,8 +2077,10 @@ let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "_"; ctx | PatAdt av -> - let extract_value ctx inside v = extract_typed_pattern ctx fmt inside v in - extract_adt_g_value extract_value fmt ctx inside av.variant_id + let extract_value ctx inside v = + extract_typed_pattern ctx fmt is_let inside v + in + extract_adt_g_value extract_value fmt ctx is_let inside av.variant_id av.field_values v.ty (** [inside]: controls the introduction of parentheses. See [extract_ty] @@ -2173,12 +2210,13 @@ and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (adt_cons : adt_cons_id) (type_args : ty list) (args : texpression list) : unit = let e_ty = Adt (adt_cons.adt_id, type_args) in + let is_single_pat = false in let _ = extract_adt_g_value (fun ctx inside e -> extract_texpression ctx fmt inside e; ctx) - fmt ctx inside adt_cons.variant_id args e_ty + fmt ctx is_single_pat inside adt_cons.variant_id args e_ty in () @@ -2226,7 +2264,7 @@ and extract_Abs (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) List.fold_left (fun ctx x -> F.pp_print_space fmt (); - extract_typed_pattern ctx fmt true x) + extract_typed_pattern ctx fmt true true x) ctx xl in F.pp_print_space fmt (); @@ -2295,7 +2333,7 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) * TODO: cleanup * *) if monadic && (!backend = Coq || !backend = HOL4) then ( - let ctx = extract_typed_pattern ctx fmt true lv in + let ctx = extract_typed_pattern ctx fmt true true lv in F.pp_print_space fmt (); let arrow = match !backend with @@ -2321,7 +2359,7 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) else ( F.pp_print_string fmt "let"; F.pp_print_space fmt ()); - let ctx = extract_typed_pattern ctx fmt true lv in + let ctx = extract_typed_pattern ctx fmt true true lv in F.pp_print_space fmt (); let eq = match !backend with @@ -2468,7 +2506,7 @@ and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (_inside : bool) match !backend with | FStar -> "begin match" | Coq -> "match" - | Lean -> "match h:" + | Lean -> if ctx.use_dep_ite then "match h:" else "match" | HOL4 -> (* We're being extra safe in the case of HOL4 *) "(case" @@ -2495,7 +2533,7 @@ and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (_inside : bool) (* Print the pattern *) F.pp_print_string fmt "|"; F.pp_print_space fmt (); - let ctx = extract_typed_pattern ctx fmt false br.pat in + let ctx = extract_typed_pattern ctx fmt false false br.pat in F.pp_print_space fmt (); let arrow = match !backend with FStar -> "->" | Coq | Lean | HOL4 -> "=>" @@ -2687,7 +2725,7 @@ let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx) (* Open a box for the input parameter *) F.pp_open_hovbox fmt 0; F.pp_print_string fmt "("; - let ctx = extract_typed_pattern ctx fmt false lv in + let ctx = extract_typed_pattern ctx fmt true false lv in F.pp_print_space fmt (); F.pp_print_string fmt ":"; F.pp_print_space fmt (); @@ -3032,7 +3070,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) List.fold_left (fun ctx (lv : typed_pattern) -> F.pp_print_space fmt (); - let ctx = extract_typed_pattern ctx fmt false lv in + let ctx = extract_typed_pattern ctx fmt true false lv in ctx) ctx inputs_lvs in |