From bd873499f9a8d517cc948c6336a5c6ce856d846d Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 4 Jul 2023 17:30:35 +0200 Subject: Fix some issues with the extraction to Lean --- compiler/Extract.ml | 134 +++++++++++++++++++++++++++++++++------------------- 1 file changed, 86 insertions(+), 48 deletions(-) (limited to 'compiler') diff --git a/compiler/Extract.ml b/compiler/Extract.ml index a54a2299..b18d4743 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -618,9 +618,12 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) let struct_constructor (basename : name) : string = let tname = type_name basename in let prefix = - match !backend with FStar -> "Mk" | Lean | Coq | HOL4 -> "mk" + match !backend with FStar -> "Mk" | Coq | HOL4 -> "mk" | Lean -> "" in - prefix ^ tname + let suffix = + match !backend with FStar | Coq | HOL4 -> "" | Lean -> ".mk" + in + prefix ^ tname ^ suffix in let get_fun_name = get_name in let fun_name_to_snake_case (fname : fun_name) : string = @@ -1326,7 +1329,8 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt (unit_name ())) else if !backend = Lean && fields = [] then () (* If the definition is recursive, we may need to extract it as an inductive - (instead of a record) *) + (instead of a record). We start with the "normal" case: we extract it + as a record. *) else if (not is_rec) || (!backend <> Coq && !backend <> Lean) then ( if !backend <> Lean then F.pp_print_space fmt (); (* If Coq: print the constructor name *) @@ -1379,7 +1383,14 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) a group of mutually recursive types: we extract it as an inductive type *) assert (is_rec && (!backend = Coq || !backend = Lean)); let with_opaque_pre = false in - let cons_name = ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx in + (* Small trick: in Lean we use namespaces, meaning we don't need to prefix + the constructor name with the name of the type at definition site, + i.e., instead of generating `inductive Foo := | MkFoo ...` like in Coq + we generate `inductive Foo := | mk ... *) + let cons_name = + if !backend = Lean then "mk" + else ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx + in let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in extract_type_decl_variant ctx fmt type_decl_group def_name type_params cons_name fields) @@ -1950,14 +1961,17 @@ let extract_global_decl_register_names (ctx : extraction_ctx) Note that patterns can introduce new variables: we thus return an extraction context updated with new bindings. + [is_single_pat]: are we extracting a single pattern (a pattern for a let-binding + or a lambda). + TODO: we don't need something very generic anymore (some definitions used to be polymorphic). *) let extract_adt_g_value (extract_value : extraction_ctx -> bool -> 'v -> extraction_ctx) - (fmt : F.formatter) (ctx : extraction_ctx) (inside : bool) - (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) : - extraction_ctx = + (fmt : F.formatter) (ctx : extraction_ctx) (is_single_pat : bool) + (inside : bool) (variant_id : VariantId.id option) (field_values : 'v list) + (ty : ty) : extraction_ctx = match ty with | Adt (Tuple, type_args) -> (* Tuple *) @@ -1982,36 +1996,57 @@ let extract_adt_g_value ctx) | Adt (adt_id, _) -> (* "Regular" ADT *) - (* We print something of the form: [Cons field0 ... fieldn]. - * We could update the code to print something of the form: - * [{ field0=...; ...; fieldn=...; }] in case of structures. - *) - let cons = - (* The ADT shouldn't be opaque *) - let with_opaque_pre = false in - match variant_id with - | Some vid -> ( - (* In the case of Lean, we might have to add the type name as a prefix *) - match (!backend, adt_id) with - | Lean, Assumed _ -> - ctx_get_type with_opaque_pre adt_id ctx - ^ "." - ^ ctx_get_variant adt_id vid ctx - | _ -> ctx_get_variant adt_id vid ctx) - | None -> ctx_get_struct with_opaque_pre adt_id ctx - in - let use_parentheses = inside && field_values <> [] in - if use_parentheses then F.pp_print_string fmt "("; - F.pp_print_string fmt cons; - let ctx = - Collections.List.fold_left - (fun ctx v -> - F.pp_print_space fmt (); - extract_value ctx true v) - ctx field_values - in - if use_parentheses then F.pp_print_string fmt ")"; - ctx + + (* If we are generating a pattern for a let-binding and we target Lean, + the syntax is: `let ⟨ x0, ..., xn ⟩ := ...`. + + Otherwise, it is: `let Cons x0 ... xn = ...` + *) + if is_single_pat && !Config.backend = Lean then ( + F.pp_print_string fmt "⟨"; + F.pp_print_space fmt (); + let ctx = + Collections.List.fold_left_link + (fun _ -> + F.pp_print_string fmt ","; + F.pp_print_space fmt ()) + (fun ctx v -> extract_value ctx true v) + ctx field_values + in + F.pp_print_space fmt (); + F.pp_print_string fmt "⟩"; + ctx) + else + (* We print something of the form: [Cons field0 ... fieldn]. + * We could update the code to print something of the form: + * [{ field0=...; ...; fieldn=...; }] in case of structures. + *) + let cons = + (* The ADT shouldn't be opaque *) + let with_opaque_pre = false in + match variant_id with + | Some vid -> ( + (* In the case of Lean, we might have to add the type name as a prefix *) + match (!backend, adt_id) with + | Lean, Assumed _ -> + ctx_get_type with_opaque_pre adt_id ctx + ^ "." + ^ ctx_get_variant adt_id vid ctx + | _ -> ctx_get_variant adt_id vid ctx) + | None -> ctx_get_struct with_opaque_pre adt_id ctx + in + let use_parentheses = inside && field_values <> [] in + if use_parentheses then F.pp_print_string fmt "("; + F.pp_print_string fmt cons; + let ctx = + Collections.List.fold_left + (fun ctx v -> + F.pp_print_space fmt (); + extract_value ctx true v) + ctx field_values + in + if use_parentheses then F.pp_print_string fmt ")"; + ctx | _ -> raise (Failure "Inconsistent typed value") (* Extract globals in the same way as variables *) @@ -2026,7 +2061,7 @@ let extract_global (ctx : extraction_ctx) (fmt : F.formatter) updated with new bindings. *) let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter) - (inside : bool) (v : typed_pattern) : extraction_ctx = + (is_let : bool) (inside : bool) (v : typed_pattern) : extraction_ctx = match v.value with | PatConstant cv -> ctx.fmt.extract_primitive_value fmt inside cv; @@ -2042,8 +2077,10 @@ let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "_"; ctx | PatAdt av -> - let extract_value ctx inside v = extract_typed_pattern ctx fmt inside v in - extract_adt_g_value extract_value fmt ctx inside av.variant_id + let extract_value ctx inside v = + extract_typed_pattern ctx fmt is_let inside v + in + extract_adt_g_value extract_value fmt ctx is_let inside av.variant_id av.field_values v.ty (** [inside]: controls the introduction of parentheses. See [extract_ty] @@ -2173,12 +2210,13 @@ and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (adt_cons : adt_cons_id) (type_args : ty list) (args : texpression list) : unit = let e_ty = Adt (adt_cons.adt_id, type_args) in + let is_single_pat = false in let _ = extract_adt_g_value (fun ctx inside e -> extract_texpression ctx fmt inside e; ctx) - fmt ctx inside adt_cons.variant_id args e_ty + fmt ctx is_single_pat inside adt_cons.variant_id args e_ty in () @@ -2226,7 +2264,7 @@ and extract_Abs (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) List.fold_left (fun ctx x -> F.pp_print_space fmt (); - extract_typed_pattern ctx fmt true x) + extract_typed_pattern ctx fmt true true x) ctx xl in F.pp_print_space fmt (); @@ -2295,7 +2333,7 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) * TODO: cleanup * *) if monadic && (!backend = Coq || !backend = HOL4) then ( - let ctx = extract_typed_pattern ctx fmt true lv in + let ctx = extract_typed_pattern ctx fmt true true lv in F.pp_print_space fmt (); let arrow = match !backend with @@ -2321,7 +2359,7 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) else ( F.pp_print_string fmt "let"; F.pp_print_space fmt ()); - let ctx = extract_typed_pattern ctx fmt true lv in + let ctx = extract_typed_pattern ctx fmt true true lv in F.pp_print_space fmt (); let eq = match !backend with @@ -2468,7 +2506,7 @@ and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (_inside : bool) match !backend with | FStar -> "begin match" | Coq -> "match" - | Lean -> "match h:" + | Lean -> if ctx.use_dep_ite then "match h:" else "match" | HOL4 -> (* We're being extra safe in the case of HOL4 *) "(case" @@ -2495,7 +2533,7 @@ and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (_inside : bool) (* Print the pattern *) F.pp_print_string fmt "|"; F.pp_print_space fmt (); - let ctx = extract_typed_pattern ctx fmt false br.pat in + let ctx = extract_typed_pattern ctx fmt false false br.pat in F.pp_print_space fmt (); let arrow = match !backend with FStar -> "->" | Coq | Lean | HOL4 -> "=>" @@ -2687,7 +2725,7 @@ let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx) (* Open a box for the input parameter *) F.pp_open_hovbox fmt 0; F.pp_print_string fmt "("; - let ctx = extract_typed_pattern ctx fmt false lv in + let ctx = extract_typed_pattern ctx fmt true false lv in F.pp_print_space fmt (); F.pp_print_string fmt ":"; F.pp_print_space fmt (); @@ -3032,7 +3070,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) List.fold_left (fun ctx (lv : typed_pattern) -> F.pp_print_space fmt (); - let ctx = extract_typed_pattern ctx fmt false lv in + let ctx = extract_typed_pattern ctx fmt true false lv in ctx) ctx inputs_lvs in -- cgit v1.2.3