summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorSon Ho2023-12-07 12:07:39 +0100
committerSon Ho2023-12-07 12:07:39 +0100
commit0209fee47a11b371d258fe02b8cc59b325de21d6 (patch)
tree9e23c2618c7138a02be28310eb8deaac2b4b3c5c /compiler
parenteb05c2e3b63377c323c33c1296495baa9357596a (diff)
Use a better syntax when extracting tuple types (structures with unnamed fields)
Diffstat (limited to '')
-rw-r--r--compiler/Config.ml18
-rw-r--r--compiler/Extract.ml74
-rw-r--r--compiler/ExtractBase.ml17
-rw-r--r--compiler/ExtractTypes.ml230
-rw-r--r--compiler/InterpreterBorrows.ml19
-rw-r--r--compiler/PureMicroPasses.ml82
-rw-r--r--compiler/PureUtils.ml11
-rw-r--r--compiler/SymbolicToPure.ml30
-rw-r--r--compiler/TypesAnalysis.ml47
-rw-r--r--compiler/TypesUtils.ml18
10 files changed, 347 insertions, 199 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml
index 364ef748..b09544ba 100644
--- a/compiler/Config.ml
+++ b/compiler/Config.ml
@@ -338,7 +338,7 @@ let type_check_pure_code = ref false
as far as possible while leaving "holes" in the generated code? *)
let fail_hard = ref true
-(** if true, add the type name as a prefix
+(** If true, add the type name as a prefix
to the variant names.
Ex.:
In Rust:
@@ -364,3 +364,19 @@ let fail_hard = ref true
]}
*)
let variant_concatenate_type_name = ref true
+
+(** If true, extract the structures with unnamed fields as tuples.
+
+ ex.:
+ {[
+ // Rust
+ struct Foo(u32)
+
+ // OCaml
+ type Foo = (u32)
+ ]}
+ *)
+let use_tuple_structs = ref true
+
+let backend_has_tuple_projectors () =
+ match !backend with Lean -> true | Coq | FStar | HOL4 -> false
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index e48e6ae6..85bdd929 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -111,7 +111,7 @@ let extract_global_decl_register_names (ctx : extraction_ctx)
context updated with new bindings.
[is_single_pat]: are we extracting a single pattern (a pattern for a let-binding
- or a lambda).
+ or a lambda)?
TODO: we don't need something very generic anymore (some definitions used
to be polymorphic).
@@ -121,38 +121,53 @@ let extract_adt_g_value
(fmt : F.formatter) (ctx : extraction_ctx) (is_single_pat : bool)
(inside : bool) (variant_id : VariantId.id option) (field_values : 'v list)
(ty : ty) : extraction_ctx =
+ let extract_as_tuple () =
+ (* This is very annoying: in Coq, we can't write [()] for the value of
+ type [unit], we have to write [tt]. *)
+ if !backend = Coq && field_values = [] then (
+ F.pp_print_string fmt "tt";
+ ctx)
+ else
+ (* If there is exactly one value, we don't print the parentheses *)
+ let lb, rb =
+ if List.length field_values = 1 then ("", "") else ("(", ")")
+ in
+ F.pp_print_string fmt lb;
+ let ctx =
+ Collections.List.fold_left_link
+ (fun () ->
+ F.pp_print_string fmt ",";
+ F.pp_print_space fmt ())
+ (fun ctx v -> extract_value ctx false v)
+ ctx field_values
+ in
+ F.pp_print_string fmt rb;
+ ctx
+ in
match ty with
| TAdt (TTuple, generics) ->
(* Tuple *)
(* For now, we only support fully applied tuple constructors *)
assert (List.length generics.types = List.length field_values);
assert (generics.const_generics = [] && generics.trait_refs = []);
- (* This is very annoying: in Coq, we can't write [()] for the value of
- type [unit], we have to write [tt]. *)
- if !backend = Coq && field_values = [] then (
- F.pp_print_string fmt "tt";
- ctx)
- else (
- F.pp_print_string fmt "(";
- let ctx =
- Collections.List.fold_left_link
- (fun () ->
- F.pp_print_string fmt ",";
- F.pp_print_space fmt ())
- (fun ctx v -> extract_value ctx false v)
- ctx field_values
- in
- F.pp_print_string fmt ")";
- ctx)
+ extract_as_tuple ()
| TAdt (adt_id, _) ->
(* "Regular" ADT *)
-
- (* If we are generating a pattern for a let-binding and we target Lean,
- the syntax is: `let ⟨ x0, ..., xn ⟩ := ...`.
-
- Otherwise, it is: `let Cons x0 ... xn = ...`
- *)
- if is_single_pat && !Config.backend = Lean then (
+ (* We may still extract the ADT as a tuple, if none of the fields are
+ named *)
+ if
+ PureUtils.type_decl_from_type_id_is_tuple_struct
+ ctx.trans_ctx.type_ctx.type_infos adt_id
+ then (* Extract as a tuple *)
+ extract_as_tuple ()
+ else if
+ (* If we are generating a pattern for a let-binding and we target Lean,
+ the syntax is: `let ⟨ x0, ..., xn ⟩ := ...`.
+
+ Otherwise, it is: `let Cons x0 ... xn = ...`
+ *)
+ is_single_pat && !Config.backend = Lean
+ then (
F.pp_print_string fmt "⟨";
F.pp_print_space fmt ();
let ctx =
@@ -517,7 +532,14 @@ and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter)
match args with
| [ arg ] ->
(* Exactly one argument: pretty-print *)
- let field_name = ctx_get_field proj.adt_id proj.field_id ctx in
+ let field_name =
+ (* Check if we need to extract the type as a structure *)
+ if
+ PureUtils.type_decl_from_type_id_is_tuple_struct
+ ctx.trans_ctx.type_ctx.type_infos proj.adt_id
+ then FieldId.to_string proj.field_id
+ else ctx_get_field proj.adt_id proj.field_id ctx
+ in
(* Open a box *)
F.pp_open_hovbox fmt ctx.indent_incr;
(* Extract the expression *)
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index 43658b6e..93204515 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -109,7 +109,7 @@ let decl_is_first_from_group (kind : decl_kind) : bool =
let decl_is_not_last_from_group (kind : decl_kind) : bool =
not (decl_is_last_from_group kind)
-type type_decl_kind = Enum | Struct [@@deriving show]
+type type_decl_kind = Enum | Struct | Tuple [@@deriving show]
(** We use identifiers to look for name clashes *)
type id =
@@ -1194,12 +1194,13 @@ let type_decl_kind_to_qualif (kind : decl_kind)
| Declared -> Some "val")
| Coq -> (
match (kind, type_kind) with
+ | SingleNonRec, Some Tuple -> Some "Definition"
| SingleNonRec, Some Enum -> Some "Inductive"
| SingleNonRec, Some Struct -> Some "Record"
| (SingleRec | MutRecFirst), Some _ -> Some "Inductive"
| (MutRecInner | MutRecLast), Some _ ->
(* Coq doesn't support groups of mutually recursive definitions which mix
- * records and inducties: we convert everything to records if this happens
+ * records and inductives: we convert everything to records if this happens
*)
Some "with"
| (Assumed | Declared), None -> Some "Axiom"
@@ -1214,12 +1215,12 @@ let type_decl_kind_to_qualif (kind : decl_kind)
^ ")")))
| Lean -> (
match kind with
- | SingleNonRec ->
- if type_kind = Some Struct then Some "structure" else Some "inductive"
- | SingleRec -> Some "inductive"
- | MutRecFirst -> Some "inductive"
- | MutRecInner -> Some "inductive"
- | MutRecLast -> Some "inductive"
+ | SingleNonRec -> (
+ match type_kind with
+ | Some Tuple -> Some "def"
+ | Some Struct -> Some "structure"
+ | _ -> Some "inductive")
+ | SingleRec | MutRecFirst | MutRecInner | MutRecLast -> Some "inductive"
| Assumed -> Some "axiom"
| Declared -> Some "axiom")
| HOL4 -> None
diff --git a/compiler/ExtractTypes.ml b/compiler/ExtractTypes.ml
index 3657627b..22243a4a 100644
--- a/compiler/ExtractTypes.ml
+++ b/compiler/ExtractTypes.ml
@@ -1,7 +1,4 @@
(** The generic extraction *)
-(* Turn the whole module into a functor: it is very annoying to carry the
- the formatter everywhere...
-*)
open Pure
open PureUtils
@@ -696,92 +693,101 @@ let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) :
* - the field names, if this is a structure
*)
let ctx =
- match def.kind with
- | Struct fields ->
- (* Compute the names *)
- let field_names, cons_name =
- match info with
- | None | Some { body_info = None; _ } ->
- let field_names =
- FieldId.mapi
- (fun fid (field : field) ->
- ( fid,
- ctx_compute_field_name ctx def.llbc_name fid
- field.field_name ))
- fields
- in
- let cons_name =
- ctx_compute_struct_constructor ctx def.llbc_name
- in
- (field_names, cons_name)
- | Some { body_info = Some (Struct (cons_name, field_names)); _ } ->
- let field_names =
- FieldId.mapi
- (fun fid (field : field) ->
- let rust_name = Option.get field.field_name in
+ (* Ignore this if the type is to be extracted as a tuple *)
+ if
+ TypesUtils.type_decl_from_decl_id_is_tuple_struct
+ ctx.trans_ctx.type_ctx.type_infos def.def_id
+ then ctx
+ else
+ match def.kind with
+ | Struct fields ->
+ (* Compute the names *)
+ let field_names, cons_name =
+ match info with
+ | None | Some { body_info = None; _ } ->
+ let field_names =
+ FieldId.mapi
+ (fun fid (field : field) ->
+ ( fid,
+ ctx_compute_field_name ctx def.llbc_name fid
+ field.field_name ))
+ fields
+ in
+ let cons_name =
+ ctx_compute_struct_constructor ctx def.llbc_name
+ in
+ (field_names, cons_name)
+ | Some { body_info = Some (Struct (cons_name, field_names)); _ } ->
+ let field_names =
+ FieldId.mapi
+ (fun fid (field : field) ->
+ let rust_name = Option.get field.field_name in
+ let name =
+ snd
+ (List.find (fun (n, _) -> n = rust_name) field_names)
+ in
+ (fid, name))
+ fields
+ in
+ (field_names, cons_name)
+ | Some info ->
+ raise
+ (Failure
+ ("Invalid builtin information: "
+ ^ show_builtin_type_info info))
+ in
+ (* Add the fields *)
+ let ctx =
+ List.fold_left
+ (fun ctx (fid, name) ->
+ ctx_add (FieldId (TAdtId def.def_id, fid)) name ctx)
+ ctx field_names
+ in
+ (* Add the constructor name *)
+ ctx_add (StructId (TAdtId def.def_id)) cons_name ctx
+ | Enum variants ->
+ let variant_names =
+ match info with
+ | None ->
+ VariantId.mapi
+ (fun variant_id (variant : variant) ->
let name =
- snd (List.find (fun (n, _) -> n = rust_name) field_names)
+ ctx_compute_variant_name ctx def.llbc_name
+ variant.variant_name
in
- (fid, name))
- fields
- in
- (field_names, cons_name)
- | Some info ->
- raise
- (Failure
- ("Invalid builtin information: "
- ^ show_builtin_type_info info))
- in
- (* Add the fields *)
- let ctx =
+ (* Add the type name prefix for Lean *)
+ let name =
+ if !Config.backend = Lean then
+ let type_name =
+ ctx_compute_type_name ctx def.llbc_name
+ in
+ type_name ^ "." ^ name
+ else name
+ in
+ (variant_id, name))
+ variants
+ | Some { body_info = Some (Enum variant_infos); _ } ->
+ (* We need to compute the map from variant to variant *)
+ let variant_map =
+ StringMap.of_list
+ (List.map
+ (fun (info : builtin_enum_variant_info) ->
+ (info.rust_variant_name, info.extract_variant_name))
+ variant_infos)
+ in
+ VariantId.mapi
+ (fun variant_id (variant : variant) ->
+ (variant_id, StringMap.find variant.variant_name variant_map))
+ variants
+ | _ -> raise (Failure "Invalid builtin information")
+ in
List.fold_left
- (fun ctx (fid, name) ->
- ctx_add (FieldId (TAdtId def.def_id, fid)) name ctx)
- ctx field_names
- in
- (* Add the constructor name *)
- ctx_add (StructId (TAdtId def.def_id)) cons_name ctx
- | Enum variants ->
- let variant_names =
- match info with
- | None ->
- VariantId.mapi
- (fun variant_id (variant : variant) ->
- let name =
- ctx_compute_variant_name ctx def.llbc_name
- variant.variant_name
- in
- (* Add the type name prefix for Lean *)
- let name =
- if !Config.backend = Lean then
- let type_name = ctx_compute_type_name ctx def.llbc_name in
- type_name ^ "." ^ name
- else name
- in
- (variant_id, name))
- variants
- | Some { body_info = Some (Enum variant_infos); _ } ->
- (* We need to compute the map from variant to variant *)
- let variant_map =
- StringMap.of_list
- (List.map
- (fun (info : builtin_enum_variant_info) ->
- (info.rust_variant_name, info.extract_variant_name))
- variant_infos)
- in
- VariantId.mapi
- (fun variant_id (variant : variant) ->
- (variant_id, StringMap.find variant.variant_name variant_map))
- variants
- | _ -> raise (Failure "Invalid builtin information")
- in
- List.fold_left
- (fun ctx (vid, vname) ->
- ctx_add (VariantId (TAdtId def.def_id, vid)) vname ctx)
- ctx variant_names
- | Opaque ->
- (* Nothing to do *)
- ctx
+ (fun ctx (vid, vname) ->
+ ctx_add (VariantId (TAdtId def.def_id, vid)) vname ctx)
+ ctx variant_names
+ | Opaque ->
+ (* Nothing to do *)
+ ctx
in
(* Return *)
ctx
@@ -906,6 +912,19 @@ let extract_type_decl_enum_body (ctx : extraction_ctx) (fmt : F.formatter)
let variants = VariantId.mapi (fun vid v -> (vid, v)) variants in
List.iter (fun (vid, v) -> print_variant vid v) variants
+(** Extract a struct as a tuple *)
+let extract_type_decl_tuple_struct_body (ctx : extraction_ctx)
+ (fmt : F.formatter) (fields : field list) : unit =
+ let sep = match !backend with Coq | FStar | HOL4 -> "*" | Lean -> "×" in
+ Collections.List.iter_link
+ (fun () ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt sep)
+ (fun (f : field) ->
+ F.pp_print_space fmt ();
+ extract_ty ctx fmt TypeDeclId.Set.empty true f.field_ty)
+ fields
+
let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter)
(type_decl_group : TypeDeclId.Set.t) (kind : decl_kind) (def : type_decl)
(type_params : string list) (cg_params : string list) (fields : field list)
@@ -1264,12 +1283,18 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
(extract_body : bool) : unit =
(* Sanity check *)
assert (extract_body || !backend <> HOL4);
+ let is_tuple_struct =
+ TypesUtils.type_decl_from_decl_id_is_tuple_struct
+ ctx.trans_ctx.type_ctx.type_infos def.def_id
+ in
let type_kind =
if extract_body then
- match def.kind with
- | Struct _ -> Some Struct
- | Enum _ -> Some Enum
- | Opaque -> None
+ if is_tuple_struct then Some Tuple
+ else
+ match def.kind with
+ | Struct _ -> Some Struct
+ | Enum _ -> Some Enum
+ | Opaque -> None
else None
in
(* If in Coq and the declaration is opaque, it must have the shape:
@@ -1300,7 +1325,8 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
* for parsing: we thus use a hovbox. *)
(match !backend with
| Coq | FStar | HOL4 -> F.pp_open_hvbox fmt 0
- | Lean -> F.pp_open_vbox fmt 0);
+ | Lean ->
+ if is_tuple_struct then F.pp_open_hvbox fmt 0 else F.pp_open_vbox fmt 0);
(* Open a box for "type TYPE_NAME (TYPE_PARAMS CONST_GEN_PARAMS) =" *)
F.pp_open_hovbox fmt ctx.indent_incr;
(* > "type TYPE_NAME" *)
@@ -1320,7 +1346,11 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
let eq =
match !backend with
| FStar -> "="
- | Coq -> ":="
+ | Coq ->
+ (* For Coq, the `*` is overloaded. If we want to extract a product
+ type (and not a product between, say, integers) we need to help Coq
+ a bit *)
+ if is_tuple_struct then ": Type :=" else ":="
| Lean ->
if type_kind = Some Struct && kind = SingleNonRec then "where"
else ":="
@@ -1341,8 +1371,11 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
(if extract_body then
match def.kind with
| Struct fields ->
- extract_type_decl_struct_body ctx_body fmt type_decl_group kind def
- type_params cg_params fields
+ if is_tuple_struct then
+ extract_type_decl_tuple_struct_body ctx_body fmt fields
+ else
+ extract_type_decl_struct_body ctx_body fmt type_decl_group kind def
+ type_params cg_params fields
| Enum variants ->
extract_type_decl_enum_body ctx_body fmt type_decl_group def def_name
type_params cg_params variants
@@ -1670,8 +1703,13 @@ let extract_type_decl_extra_info (ctx : extraction_ctx) (fmt : F.formatter)
match !backend with
| FStar | Lean | HOL4 -> ()
| Coq ->
- extract_type_decl_coq_arguments ctx fmt kind decl;
- extract_type_decl_record_field_projectors ctx fmt kind decl
+ if
+ not
+ (TypesUtils.type_decl_from_decl_id_is_tuple_struct
+ ctx.trans_ctx.type_ctx.type_infos decl.def_id)
+ then (
+ extract_type_decl_coq_arguments ctx fmt kind decl;
+ extract_type_decl_record_field_projectors ctx fmt kind decl)
(** Extract the state type declaration. *)
let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx)
diff --git a/compiler/InterpreterBorrows.ml b/compiler/InterpreterBorrows.ml
index 19b9fd3b..e56919fa 100644
--- a/compiler/InterpreterBorrows.ml
+++ b/compiler/InterpreterBorrows.ml
@@ -706,7 +706,7 @@ let reborrow_shared (original_bid : BorrowId.id) (new_bid : BorrowId.id)
(** Convert an {!type:avalue} to a {!type:value}.
This function is used when ending abstractions: whenever we end a borrow
- in an abstraction, we converted the borrowed {!avalue} to a fresh symbolic
+ in an abstraction, we convert the borrowed {!avalue} to a fresh symbolic
{!type:value}, then give back this {!type:value} to the context.
Note that some regions may have ended in the symbolic value we generate.
@@ -719,8 +719,7 @@ let reborrow_shared (original_bid : BorrowId.id) (new_bid : BorrowId.id)
be expanded (because expanding this symbolic value would require expanding
a reference whose region has already ended).
*)
-let convert_avalue_to_given_back_value (abs_kind : abs_kind) (av : typed_avalue)
- : symbolic_value =
+let convert_avalue_to_given_back_value (av : typed_avalue) : symbolic_value =
mk_fresh_symbolic_value av.ty
(** Auxiliary function: see {!end_borrow_aux}.
@@ -739,8 +738,8 @@ let convert_avalue_to_given_back_value (abs_kind : abs_kind) (av : typed_avalue)
borrows. This kind of internal reshuffling. should be similar to ending
abstractions (it is tantamount to ending *sub*-abstractions).
*)
-let give_back (config : config) (abs_id_opt : AbstractionId.id option)
- (l : BorrowId.id) (bc : g_borrow_content) (ctx : eval_ctx) : eval_ctx =
+let give_back (config : config) (l : BorrowId.id) (bc : g_borrow_content)
+ (ctx : eval_ctx) : eval_ctx =
(* Debug *)
log#ldebug
(lazy
@@ -781,9 +780,7 @@ let give_back (config : config) (abs_id_opt : AbstractionId.id option)
Rem.: we shouldn't do this here. We should do this in a function
which takes care of ending *sub*-abstractions.
*)
- let abs_id = Option.get abs_id_opt in
- let abs = ctx_lookup_abs ctx abs_id in
- let sv = convert_avalue_to_given_back_value abs.kind av in
+ let sv = convert_avalue_to_given_back_value av in
(* Update the context *)
give_back_avalue_to_same_abstraction config l av
(mk_typed_value_from_symbolic_value sv)
@@ -929,14 +926,14 @@ let rec end_borrow_aux (config : config) (chain : borrow_or_abs_ids)
cf_check cf ctx
(* We found a borrow and replaced it with [Bottom]: give it back (i.e., update
the corresponding loan) *)
- | Ok (ctx, Some (abs_id_opt, bc)) ->
+ | Ok (ctx, Some (_, bc)) ->
(* Sanity check: the borrowed value shouldn't contain loans *)
(match bc with
| Concrete (VMutBorrow (_, bv)) ->
assert (Option.is_none (get_first_loan_in_value bv))
| _ -> ());
(* Give back the value *)
- let ctx = give_back config abs_id_opt l bc ctx in
+ let ctx = give_back config l bc ctx in
(* Do a sanity check and continue *)
cf_check cf ctx
@@ -1161,7 +1158,7 @@ and end_abstraction_borrows (config : config) (chain : borrow_or_abs_ids)
match bc with
| AMutBorrow (bid, av) ->
(* First, convert the avalue to a (fresh symbolic) value *)
- let sv = convert_avalue_to_given_back_value abs.kind av in
+ let sv = convert_avalue_to_given_back_value av in
(* Replace the mut borrow to register the fact that we ended
* it and store with it the freshly generated given back value *)
let ended_borrow = ABorrow (AEndedMutBorrow (sv, av)) in
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index d0741b29..68f8943a 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -563,12 +563,13 @@ let remove_meta (def : fun_decl) : fun_decl =
This micro-pass turns those into expressions which use structure syntax:
{[
- {
- f0 := x0;
- ...
- fn := xn;
- }
+ type struct = { f0 : nat; f1 : nat; f2 : nat }
+
+ Mkstruct x.f0 x.f1 y ~~> { x with f2 = y }
]}
+
+ Note however that we do not apply this transformation if the
+ structure is to be extracted as a tuple.
*)
let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
let obj =
@@ -592,37 +593,44 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
} ->
(* Lookup the def *)
let decl = TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls in
- (* Check that there are as many arguments as there are fields - note
- that the def should have a body (otherwise we couldn't use the
- constructor) *)
- let fields = TypesUtils.type_decl_get_fields decl None in
- if List.length fields = List.length args then
- (* Check if the definition is recursive *)
- let is_rec =
- match
- TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls_groups
- with
- | NonRecGroup _ -> false
- | RecGroup _ -> true
- in
- (* Convert, if possible - note that for now for Lean and Coq
- we don't support the structure syntax on recursive structures *)
- if
- (!Config.backend <> Lean && !Config.backend <> Coq)
- || not is_rec
- then
- let struct_id = TAdtId adt_id in
- let init = None in
- let updates =
- FieldId.mapi
- (fun fid fe -> (fid, self#visit_texpression env fe))
- args
+ (* Check if the def will be extracted as a tuple *)
+ if
+ TypesUtils.type_decl_from_decl_id_is_tuple_struct
+ ctx.type_ctx.type_infos adt_id
+ then ignore ()
+ else
+ (* Check that there are as many arguments as there are fields - note
+ that the def should have a body (otherwise we couldn't use the
+ constructor) *)
+ let fields = TypesUtils.type_decl_get_fields decl None in
+ if List.length fields = List.length args then
+ (* Check if the definition is recursive *)
+ let is_rec =
+ match
+ TypeDeclId.Map.find adt_id
+ ctx.type_ctx.type_decls_groups
+ with
+ | NonRecGroup _ -> false
+ | RecGroup _ -> true
in
- let ne = { struct_id; init; updates } in
- let nty = e.ty in
- { e = StructUpdate ne; ty = nty }
+ (* Convert, if possible - note that for now for Lean and Coq
+ we don't support the structure syntax on recursive structures *)
+ if
+ (!Config.backend <> Lean && !Config.backend <> Coq)
+ || not is_rec
+ then
+ let struct_id = TAdtId adt_id in
+ let init = None in
+ let updates =
+ FieldId.mapi
+ (fun fid fe -> (fid, self#visit_texpression env fe))
+ args
+ in
+ let ne = { struct_id; init; updates } in
+ let nty = e.ty in
+ { e = StructUpdate ne; ty = nty }
+ else ignore ()
else ignore ()
- else ignore ()
| _ -> ignore ())
| _ -> super#visit_texpression env e
end
@@ -1069,12 +1077,10 @@ let simplify_let_then_return _ctx def =
(** Simplify the aggregated ADTs.
Ex.:
{[
- type struct = { f0 : nat; f1 : nat }
+ type struct = { f0 : nat; f1 : nat; f2 : nat }
- Mkstruct x.f0 x.f1 ~~> x
+ Mkstruct x.f0 x.f1 x.f2 ~~> x
]}
-
- TODO: introduce a notation for [{ x with field = ... }], and use it.
*)
let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
let expr_visitor =
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index a5143f3c..39dcd52d 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -687,3 +687,14 @@ let trait_impl_is_empty (trait_impl : trait_impl) : bool =
in
parent_trait_refs = [] && consts = [] && types = [] && required_methods = []
&& provided_methods = []
+
+(** Return true if a type declaration should be extracted as a tuple, because
+ it is a non-recursive structure with unnamed fields. *)
+let type_decl_from_type_id_is_tuple_struct (ctx : TypesAnalysis.type_infos)
+ (id : type_id) : bool =
+ match id with
+ | TTuple -> true
+ | TAdtId id ->
+ let info = TypeDeclId.Map.find id ctx in
+ info.is_tuple_struct
+ | TAssumed _ -> false
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 3b30549c..bf4d26f2 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -2299,11 +2299,11 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
match sexp with
| V.SeLiteral _ ->
(* We do not *register* symbolic expansions to literal
- * values in the symbolic ADT *)
+ values in the symbolic ADT *)
raise (Failure "Unreachable")
| SeMutRef (_, nsv) | SeSharedRef (_, nsv) ->
(* The (mut/shared) borrow type is extracted to identity: we thus simply
- * introduce an reassignment *)
+ introduce an reassignment *)
let ctx, var = fresh_var_for_symbolic_value nsv ctx in
let next_e = translate_expression e ctx in
let monadic = false in
@@ -2324,10 +2324,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
&& !Config.always_deconstruct_adts_with_matches) ->
(* There is exactly one branch: no branching.
- We can decompose the ADT value with a let-binding, unless
- the backend doesn't support this (see {!Config.always_deconstruct_adts_with_matches}):
- we *ignore* this branch (and go to the next one) if the ADT is a custom
- adt, and [always_deconstruct_adts_with_matches] is true.
+ We can decompose the ADT value with a let-binding, unless
+ the backend doesn't support this (see {!Config.always_deconstruct_adts_with_matches}):
+ we *ignore* this branch (and go to the next one) if the ADT is a custom
+ adt, and [always_deconstruct_adts_with_matches] is true.
*)
translate_ExpandAdt_one_branch sv scrutinee scrutinee_mplace
variant_id svl branch ctx
@@ -2361,7 +2361,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
{ e; ty })
| ExpandBool (true_e, false_e) ->
(* We don't need to update the context: we don't introduce any
- * new values/variables *)
+ new values/variables *)
let true_e = translate_expression true_e ctx in
let false_e = translate_expression false_e ctx in
let e =
@@ -2376,7 +2376,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
let translate_branch ((v, branch_e) : V.scalar_value * S.expression) :
match_branch =
(* We don't need to update the context: we don't introduce any
- * new values/variables *)
+ new values/variables *)
let branch = translate_expression branch_e ctx in
let pat = mk_typed_pattern_from_literal (VScalar v) in
{ pat; branch }
@@ -2436,20 +2436,28 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value)
(* Detect if this is an enumeration or not *)
let tdef = bs_ctx_lookup_llbc_type_decl adt_id ctx in
let is_enum = TypesUtils.type_decl_is_enum tdef in
- (* We deconstruct the ADT with a let-binding in two situations:
+ (* We deconstruct the ADT with a single let-binding in two situations:
- if the ADT is an enumeration (which must have exactly one branch)
- if we forbid using field projectors.
*)
let is_rec_def =
T.TypeDeclId.Set.mem adt_id ctx.type_context.recursive_decls
in
- let use_let =
+ let use_let_with_cons =
is_enum
|| !Config.dont_use_field_projectors
(* TODO: for now, we don't have field projectors over recursive ADTs in Lean. *)
|| (!Config.backend = Lean && is_rec_def)
+ (* Also, there is a special case when the ADT is to be extracted as
+ a tuple, because it is a structure with unnamed fields. Some backends
+ like Lean have projectors for tuples (like so: `x.3`), but others
+ like Coq don't, in which case we have to deconstruct the whole ADT
+ at once (`let (a, b, c) = x in`) *)
+ || TypesUtils.type_decl_from_type_id_is_tuple_struct
+ ctx.type_context.type_infos type_id
+ && not (Config.backend_has_tuple_projectors ())
in
- if use_let then
+ if use_let_with_cons then
(* Introduce a let binding which expands the ADT *)
let lvars = List.map (fun v -> mk_typed_pattern_from_var v None) vars in
let lv = mk_adt_pattern scrutinee.ty variant_id lvars in
diff --git a/compiler/TypesAnalysis.ml b/compiler/TypesAnalysis.ml
index 589c380c..12c20262 100644
--- a/compiler/TypesAnalysis.ml
+++ b/compiler/TypesAnalysis.ml
@@ -27,6 +27,10 @@ type 'p g_type_info = {
borrows_info : type_borrows_info;
(** Various informations about the borrows *)
param_infos : 'p; (** Gives information about the type parameters *)
+ is_tuple_struct : bool;
+ (** If true, it means the type is a record that we should extract as a tuple.
+ This field is only valid for type declarations.
+ *)
}
[@@deriving show]
@@ -55,22 +59,43 @@ let type_borrows_info_init : type_borrows_info =
contains_borrow_under_mut = false;
}
-let initialize_g_type_info (param_infos : 'p) : 'p g_type_info =
- { borrows_info = type_borrows_info_init; param_infos }
+(** Return true if a type declaration is a structure with unnamed fields.
-let initialize_type_decl_info (def : type_decl) : type_decl_info =
+ Note that there are two possibilities:
+ - either all the fields are named
+ - or none of the fields are named
+ *)
+let type_decl_is_tuple_struct (x : type_decl) : bool =
+ match x.kind with
+ | Struct fields -> List.for_all (fun f -> f.field_name = None) fields
+ | _ -> false
+
+let initialize_g_type_info (is_tuple_struct : bool) (param_infos : 'p) :
+ 'p g_type_info =
+ { borrows_info = type_borrows_info_init; is_tuple_struct; param_infos }
+
+let initialize_type_decl_info (is_rec : bool) (def : type_decl) : type_decl_info
+ =
let param_info = { under_borrow = false; under_mut_borrow = false } in
let param_infos = List.map (fun _ -> param_info) def.generics.types in
- initialize_g_type_info param_infos
+ let is_tuple_struct =
+ !Config.use_tuple_structs && (not is_rec) && type_decl_is_tuple_struct def
+ in
+ initialize_g_type_info is_tuple_struct param_infos
let type_decl_info_to_partial_type_info (info : type_decl_info) :
partial_type_info =
- { borrows_info = info.borrows_info; param_infos = Some info.param_infos }
+ {
+ borrows_info = info.borrows_info;
+ is_tuple_struct = info.is_tuple_struct;
+ param_infos = Some info.param_infos;
+ }
let partial_type_info_to_type_decl_info (info : partial_type_info) :
type_decl_info =
{
borrows_info = info.borrows_info;
+ is_tuple_struct = info.is_tuple_struct;
param_infos = Option.get info.param_infos;
}
@@ -283,14 +308,20 @@ let analyze_type_decl (updated : bool ref) (infos : type_infos)
let analyze_type_declaration_group (type_decls : type_decl TypeDeclId.Map.t)
(infos : type_infos) (decl : type_declaration_group) : type_infos =
(* Collect the identifiers used in the declaration group *)
- let ids = match decl with NonRecGroup id -> [ id ] | RecGroup ids -> ids in
+ let is_rec, ids =
+ match decl with
+ | NonRecGroup id -> (false, [ id ])
+ | RecGroup ids -> (true, ids)
+ in
(* Retrieve the type definitions *)
let decl_defs = List.map (fun id -> TypeDeclId.Map.find id type_decls) ids in
(* Initialize the type information for the current definitions *)
let infos =
List.fold_left
(fun infos (def : type_decl) ->
- TypeDeclId.Map.add def.def_id (initialize_type_decl_info def) infos)
+ TypeDeclId.Map.add def.def_id
+ (initialize_type_decl_info is_rec def)
+ infos)
infos decl_defs
in
(* Analyze the types - this function simply computes a fixed-point *)
@@ -327,7 +358,7 @@ let analyze_ty (infos : type_infos) (ty : ty) : ty_info =
(* We don't use [updated] but need to give it as parameter *)
let updated = ref false in
(* We don't need to compute whether the type contains 'static or not *)
- let ty_info = initialize_g_type_info None in
+ let ty_info = initialize_g_type_info false None in
let ty_info = analyze_full_ty updated infos ty_info ty in
(* Convert the ty_info *)
partial_type_info_to_ty_info ty_info
diff --git a/compiler/TypesUtils.ml b/compiler/TypesUtils.ml
index c8418ba0..28db59ec 100644
--- a/compiler/TypesUtils.ml
+++ b/compiler/TypesUtils.ml
@@ -111,3 +111,21 @@ let trait_type_constraint_no_regions (x : trait_type_constraint) : bool =
raise_if_region_ty_visitor#visit_ty () ty;
true
with Found -> false
+
+(** Return true if a type declaration should be extracted as a tuple, because
+ it is a non-recursive structure with unnamed fields. *)
+let type_decl_from_decl_id_is_tuple_struct (ctx : TypesAnalysis.type_infos)
+ (id : TypeDeclId.id) : bool =
+ let info = TypeDeclId.Map.find id ctx in
+ info.is_tuple_struct
+
+(** Return true if a type declaration should be extracted as a tuple, because
+ it is a non-recursive structure with unnamed fields. *)
+let type_decl_from_type_id_is_tuple_struct (ctx : TypesAnalysis.type_infos)
+ (id : type_id) : bool =
+ match id with
+ | TTuple -> true
+ | TAdtId id ->
+ let info = TypeDeclId.Map.find id ctx in
+ info.is_tuple_struct
+ | TAssumed _ -> false