summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorSon Ho2023-07-04 17:30:35 +0200
committerSon Ho2023-07-04 17:30:35 +0200
commitbd873499f9a8d517cc948c6336a5c6ce856d846d (patch)
tree0e4fc5eda91c9d34c27790286a6098dc937e79b9 /compiler
parent87d6f6c7c90bf7b427397d6bd2e2c70d610678e3 (diff)
Fix some issues with the extraction to Lean
Diffstat (limited to 'compiler')
-rw-r--r--compiler/Extract.ml134
1 files changed, 86 insertions, 48 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index a54a2299..b18d4743 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -618,9 +618,12 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
let struct_constructor (basename : name) : string =
let tname = type_name basename in
let prefix =
- match !backend with FStar -> "Mk" | Lean | Coq | HOL4 -> "mk"
+ match !backend with FStar -> "Mk" | Coq | HOL4 -> "mk" | Lean -> ""
in
- prefix ^ tname
+ let suffix =
+ match !backend with FStar | Coq | HOL4 -> "" | Lean -> ".mk"
+ in
+ prefix ^ tname ^ suffix
in
let get_fun_name = get_name in
let fun_name_to_snake_case (fname : fun_name) : string =
@@ -1326,7 +1329,8 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt (unit_name ()))
else if !backend = Lean && fields = [] then ()
(* If the definition is recursive, we may need to extract it as an inductive
- (instead of a record) *)
+ (instead of a record). We start with the "normal" case: we extract it
+ as a record. *)
else if (not is_rec) || (!backend <> Coq && !backend <> Lean) then (
if !backend <> Lean then F.pp_print_space fmt ();
(* If Coq: print the constructor name *)
@@ -1379,7 +1383,14 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter)
a group of mutually recursive types: we extract it as an inductive type *)
assert (is_rec && (!backend = Coq || !backend = Lean));
let with_opaque_pre = false in
- let cons_name = ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx in
+ (* Small trick: in Lean we use namespaces, meaning we don't need to prefix
+ the constructor name with the name of the type at definition site,
+ i.e., instead of generating `inductive Foo := | MkFoo ...` like in Coq
+ we generate `inductive Foo := | mk ... *)
+ let cons_name =
+ if !backend = Lean then "mk"
+ else ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx
+ in
let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in
extract_type_decl_variant ctx fmt type_decl_group def_name type_params
cons_name fields)
@@ -1950,14 +1961,17 @@ let extract_global_decl_register_names (ctx : extraction_ctx)
Note that patterns can introduce new variables: we thus return an extraction
context updated with new bindings.
+ [is_single_pat]: are we extracting a single pattern (a pattern for a let-binding
+ or a lambda).
+
TODO: we don't need something very generic anymore (some definitions used
to be polymorphic).
*)
let extract_adt_g_value
(extract_value : extraction_ctx -> bool -> 'v -> extraction_ctx)
- (fmt : F.formatter) (ctx : extraction_ctx) (inside : bool)
- (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) :
- extraction_ctx =
+ (fmt : F.formatter) (ctx : extraction_ctx) (is_single_pat : bool)
+ (inside : bool) (variant_id : VariantId.id option) (field_values : 'v list)
+ (ty : ty) : extraction_ctx =
match ty with
| Adt (Tuple, type_args) ->
(* Tuple *)
@@ -1982,36 +1996,57 @@ let extract_adt_g_value
ctx)
| Adt (adt_id, _) ->
(* "Regular" ADT *)
- (* We print something of the form: [Cons field0 ... fieldn].
- * We could update the code to print something of the form:
- * [{ field0=...; ...; fieldn=...; }] in case of structures.
- *)
- let cons =
- (* The ADT shouldn't be opaque *)
- let with_opaque_pre = false in
- match variant_id with
- | Some vid -> (
- (* In the case of Lean, we might have to add the type name as a prefix *)
- match (!backend, adt_id) with
- | Lean, Assumed _ ->
- ctx_get_type with_opaque_pre adt_id ctx
- ^ "."
- ^ ctx_get_variant adt_id vid ctx
- | _ -> ctx_get_variant adt_id vid ctx)
- | None -> ctx_get_struct with_opaque_pre adt_id ctx
- in
- let use_parentheses = inside && field_values <> [] in
- if use_parentheses then F.pp_print_string fmt "(";
- F.pp_print_string fmt cons;
- let ctx =
- Collections.List.fold_left
- (fun ctx v ->
- F.pp_print_space fmt ();
- extract_value ctx true v)
- ctx field_values
- in
- if use_parentheses then F.pp_print_string fmt ")";
- ctx
+
+ (* If we are generating a pattern for a let-binding and we target Lean,
+ the syntax is: `let ⟨ x0, ..., xn ⟩ := ...`.
+
+ Otherwise, it is: `let Cons x0 ... xn = ...`
+ *)
+ if is_single_pat && !Config.backend = Lean then (
+ F.pp_print_string fmt "⟨";
+ F.pp_print_space fmt ();
+ let ctx =
+ Collections.List.fold_left_link
+ (fun _ ->
+ F.pp_print_string fmt ",";
+ F.pp_print_space fmt ())
+ (fun ctx v -> extract_value ctx true v)
+ ctx field_values
+ in
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "⟩";
+ ctx)
+ else
+ (* We print something of the form: [Cons field0 ... fieldn].
+ * We could update the code to print something of the form:
+ * [{ field0=...; ...; fieldn=...; }] in case of structures.
+ *)
+ let cons =
+ (* The ADT shouldn't be opaque *)
+ let with_opaque_pre = false in
+ match variant_id with
+ | Some vid -> (
+ (* In the case of Lean, we might have to add the type name as a prefix *)
+ match (!backend, adt_id) with
+ | Lean, Assumed _ ->
+ ctx_get_type with_opaque_pre adt_id ctx
+ ^ "."
+ ^ ctx_get_variant adt_id vid ctx
+ | _ -> ctx_get_variant adt_id vid ctx)
+ | None -> ctx_get_struct with_opaque_pre adt_id ctx
+ in
+ let use_parentheses = inside && field_values <> [] in
+ if use_parentheses then F.pp_print_string fmt "(";
+ F.pp_print_string fmt cons;
+ let ctx =
+ Collections.List.fold_left
+ (fun ctx v ->
+ F.pp_print_space fmt ();
+ extract_value ctx true v)
+ ctx field_values
+ in
+ if use_parentheses then F.pp_print_string fmt ")";
+ ctx
| _ -> raise (Failure "Inconsistent typed value")
(* Extract globals in the same way as variables *)
@@ -2026,7 +2061,7 @@ let extract_global (ctx : extraction_ctx) (fmt : F.formatter)
updated with new bindings.
*)
let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter)
- (inside : bool) (v : typed_pattern) : extraction_ctx =
+ (is_let : bool) (inside : bool) (v : typed_pattern) : extraction_ctx =
match v.value with
| PatConstant cv ->
ctx.fmt.extract_primitive_value fmt inside cv;
@@ -2042,8 +2077,10 @@ let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt "_";
ctx
| PatAdt av ->
- let extract_value ctx inside v = extract_typed_pattern ctx fmt inside v in
- extract_adt_g_value extract_value fmt ctx inside av.variant_id
+ let extract_value ctx inside v =
+ extract_typed_pattern ctx fmt is_let inside v
+ in
+ extract_adt_g_value extract_value fmt ctx is_let inside av.variant_id
av.field_values v.ty
(** [inside]: controls the introduction of parentheses. See [extract_ty]
@@ -2173,12 +2210,13 @@ and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(adt_cons : adt_cons_id) (type_args : ty list) (args : texpression list) :
unit =
let e_ty = Adt (adt_cons.adt_id, type_args) in
+ let is_single_pat = false in
let _ =
extract_adt_g_value
(fun ctx inside e ->
extract_texpression ctx fmt inside e;
ctx)
- fmt ctx inside adt_cons.variant_id args e_ty
+ fmt ctx is_single_pat inside adt_cons.variant_id args e_ty
in
()
@@ -2226,7 +2264,7 @@ and extract_Abs (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
List.fold_left
(fun ctx x ->
F.pp_print_space fmt ();
- extract_typed_pattern ctx fmt true x)
+ extract_typed_pattern ctx fmt true true x)
ctx xl
in
F.pp_print_space fmt ();
@@ -2295,7 +2333,7 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
* TODO: cleanup
* *)
if monadic && (!backend = Coq || !backend = HOL4) then (
- let ctx = extract_typed_pattern ctx fmt true lv in
+ let ctx = extract_typed_pattern ctx fmt true true lv in
F.pp_print_space fmt ();
let arrow =
match !backend with
@@ -2321,7 +2359,7 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
else (
F.pp_print_string fmt "let";
F.pp_print_space fmt ());
- let ctx = extract_typed_pattern ctx fmt true lv in
+ let ctx = extract_typed_pattern ctx fmt true true lv in
F.pp_print_space fmt ();
let eq =
match !backend with
@@ -2468,7 +2506,7 @@ and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (_inside : bool)
match !backend with
| FStar -> "begin match"
| Coq -> "match"
- | Lean -> "match h:"
+ | Lean -> if ctx.use_dep_ite then "match h:" else "match"
| HOL4 ->
(* We're being extra safe in the case of HOL4 *)
"(case"
@@ -2495,7 +2533,7 @@ and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (_inside : bool)
(* Print the pattern *)
F.pp_print_string fmt "|";
F.pp_print_space fmt ();
- let ctx = extract_typed_pattern ctx fmt false br.pat in
+ let ctx = extract_typed_pattern ctx fmt false false br.pat in
F.pp_print_space fmt ();
let arrow =
match !backend with FStar -> "->" | Coq | Lean | HOL4 -> "=>"
@@ -2687,7 +2725,7 @@ let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx)
(* Open a box for the input parameter *)
F.pp_open_hovbox fmt 0;
F.pp_print_string fmt "(";
- let ctx = extract_typed_pattern ctx fmt false lv in
+ let ctx = extract_typed_pattern ctx fmt true false lv in
F.pp_print_space fmt ();
F.pp_print_string fmt ":";
F.pp_print_space fmt ();
@@ -3032,7 +3070,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
List.fold_left
(fun ctx (lv : typed_pattern) ->
F.pp_print_space fmt ();
- let ctx = extract_typed_pattern ctx fmt false lv in
+ let ctx = extract_typed_pattern ctx fmt true false lv in
ctx)
ctx inputs_lvs
in