diff options
author | Son Ho | 2023-12-07 12:07:39 +0100 |
---|---|---|
committer | Son Ho | 2023-12-07 12:07:39 +0100 |
commit | 0209fee47a11b371d258fe02b8cc59b325de21d6 (patch) | |
tree | 9e23c2618c7138a02be28310eb8deaac2b4b3c5c | |
parent | eb05c2e3b63377c323c33c1296495baa9357596a (diff) |
Use a better syntax when extracting tuple types (structures with unnamed fields)
Diffstat (limited to '')
-rw-r--r-- | compiler/Config.ml | 18 | ||||
-rw-r--r-- | compiler/Extract.ml | 74 | ||||
-rw-r--r-- | compiler/ExtractBase.ml | 17 | ||||
-rw-r--r-- | compiler/ExtractTypes.ml | 230 | ||||
-rw-r--r-- | compiler/InterpreterBorrows.ml | 19 | ||||
-rw-r--r-- | compiler/PureMicroPasses.ml | 82 | ||||
-rw-r--r-- | compiler/PureUtils.ml | 11 | ||||
-rw-r--r-- | compiler/SymbolicToPure.ml | 30 | ||||
-rw-r--r-- | compiler/TypesAnalysis.ml | 47 | ||||
-rw-r--r-- | compiler/TypesUtils.ml | 18 |
10 files changed, 347 insertions, 199 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml index 364ef748..b09544ba 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -338,7 +338,7 @@ let type_check_pure_code = ref false as far as possible while leaving "holes" in the generated code? *) let fail_hard = ref true -(** if true, add the type name as a prefix +(** If true, add the type name as a prefix to the variant names. Ex.: In Rust: @@ -364,3 +364,19 @@ let fail_hard = ref true ]} *) let variant_concatenate_type_name = ref true + +(** If true, extract the structures with unnamed fields as tuples. + + ex.: + {[ + // Rust + struct Foo(u32) + + // OCaml + type Foo = (u32) + ]} + *) +let use_tuple_structs = ref true + +let backend_has_tuple_projectors () = + match !backend with Lean -> true | Coq | FStar | HOL4 -> false diff --git a/compiler/Extract.ml b/compiler/Extract.ml index e48e6ae6..85bdd929 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -111,7 +111,7 @@ let extract_global_decl_register_names (ctx : extraction_ctx) context updated with new bindings. [is_single_pat]: are we extracting a single pattern (a pattern for a let-binding - or a lambda). + or a lambda)? TODO: we don't need something very generic anymore (some definitions used to be polymorphic). @@ -121,38 +121,53 @@ let extract_adt_g_value (fmt : F.formatter) (ctx : extraction_ctx) (is_single_pat : bool) (inside : bool) (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) : extraction_ctx = + let extract_as_tuple () = + (* This is very annoying: in Coq, we can't write [()] for the value of + type [unit], we have to write [tt]. *) + if !backend = Coq && field_values = [] then ( + F.pp_print_string fmt "tt"; + ctx) + else + (* If there is exactly one value, we don't print the parentheses *) + let lb, rb = + if List.length field_values = 1 then ("", "") else ("(", ")") + in + F.pp_print_string fmt lb; + let ctx = + Collections.List.fold_left_link + (fun () -> + F.pp_print_string fmt ","; + F.pp_print_space fmt ()) + (fun ctx v -> extract_value ctx false v) + ctx field_values + in + F.pp_print_string fmt rb; + ctx + in match ty with | TAdt (TTuple, generics) -> (* Tuple *) (* For now, we only support fully applied tuple constructors *) assert (List.length generics.types = List.length field_values); assert (generics.const_generics = [] && generics.trait_refs = []); - (* This is very annoying: in Coq, we can't write [()] for the value of - type [unit], we have to write [tt]. *) - if !backend = Coq && field_values = [] then ( - F.pp_print_string fmt "tt"; - ctx) - else ( - F.pp_print_string fmt "("; - let ctx = - Collections.List.fold_left_link - (fun () -> - F.pp_print_string fmt ","; - F.pp_print_space fmt ()) - (fun ctx v -> extract_value ctx false v) - ctx field_values - in - F.pp_print_string fmt ")"; - ctx) + extract_as_tuple () | TAdt (adt_id, _) -> (* "Regular" ADT *) - - (* If we are generating a pattern for a let-binding and we target Lean, - the syntax is: `let ⟨ x0, ..., xn ⟩ := ...`. - - Otherwise, it is: `let Cons x0 ... xn = ...` - *) - if is_single_pat && !Config.backend = Lean then ( + (* We may still extract the ADT as a tuple, if none of the fields are + named *) + if + PureUtils.type_decl_from_type_id_is_tuple_struct + ctx.trans_ctx.type_ctx.type_infos adt_id + then (* Extract as a tuple *) + extract_as_tuple () + else if + (* If we are generating a pattern for a let-binding and we target Lean, + the syntax is: `let ⟨ x0, ..., xn ⟩ := ...`. + + Otherwise, it is: `let Cons x0 ... xn = ...` + *) + is_single_pat && !Config.backend = Lean + then ( F.pp_print_string fmt "⟨"; F.pp_print_space fmt (); let ctx = @@ -517,7 +532,14 @@ and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter) match args with | [ arg ] -> (* Exactly one argument: pretty-print *) - let field_name = ctx_get_field proj.adt_id proj.field_id ctx in + let field_name = + (* Check if we need to extract the type as a structure *) + if + PureUtils.type_decl_from_type_id_is_tuple_struct + ctx.trans_ctx.type_ctx.type_infos proj.adt_id + then FieldId.to_string proj.field_id + else ctx_get_field proj.adt_id proj.field_id ctx + in (* Open a box *) F.pp_open_hovbox fmt ctx.indent_incr; (* Extract the expression *) diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 43658b6e..93204515 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -109,7 +109,7 @@ let decl_is_first_from_group (kind : decl_kind) : bool = let decl_is_not_last_from_group (kind : decl_kind) : bool = not (decl_is_last_from_group kind) -type type_decl_kind = Enum | Struct [@@deriving show] +type type_decl_kind = Enum | Struct | Tuple [@@deriving show] (** We use identifiers to look for name clashes *) type id = @@ -1194,12 +1194,13 @@ let type_decl_kind_to_qualif (kind : decl_kind) | Declared -> Some "val") | Coq -> ( match (kind, type_kind) with + | SingleNonRec, Some Tuple -> Some "Definition" | SingleNonRec, Some Enum -> Some "Inductive" | SingleNonRec, Some Struct -> Some "Record" | (SingleRec | MutRecFirst), Some _ -> Some "Inductive" | (MutRecInner | MutRecLast), Some _ -> (* Coq doesn't support groups of mutually recursive definitions which mix - * records and inducties: we convert everything to records if this happens + * records and inductives: we convert everything to records if this happens *) Some "with" | (Assumed | Declared), None -> Some "Axiom" @@ -1214,12 +1215,12 @@ let type_decl_kind_to_qualif (kind : decl_kind) ^ ")"))) | Lean -> ( match kind with - | SingleNonRec -> - if type_kind = Some Struct then Some "structure" else Some "inductive" - | SingleRec -> Some "inductive" - | MutRecFirst -> Some "inductive" - | MutRecInner -> Some "inductive" - | MutRecLast -> Some "inductive" + | SingleNonRec -> ( + match type_kind with + | Some Tuple -> Some "def" + | Some Struct -> Some "structure" + | _ -> Some "inductive") + | SingleRec | MutRecFirst | MutRecInner | MutRecLast -> Some "inductive" | Assumed -> Some "axiom" | Declared -> Some "axiom") | HOL4 -> None diff --git a/compiler/ExtractTypes.ml b/compiler/ExtractTypes.ml index 3657627b..22243a4a 100644 --- a/compiler/ExtractTypes.ml +++ b/compiler/ExtractTypes.ml @@ -1,7 +1,4 @@ (** The generic extraction *) -(* Turn the whole module into a functor: it is very annoying to carry the - the formatter everywhere... -*) open Pure open PureUtils @@ -696,92 +693,101 @@ let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) : * - the field names, if this is a structure *) let ctx = - match def.kind with - | Struct fields -> - (* Compute the names *) - let field_names, cons_name = - match info with - | None | Some { body_info = None; _ } -> - let field_names = - FieldId.mapi - (fun fid (field : field) -> - ( fid, - ctx_compute_field_name ctx def.llbc_name fid - field.field_name )) - fields - in - let cons_name = - ctx_compute_struct_constructor ctx def.llbc_name - in - (field_names, cons_name) - | Some { body_info = Some (Struct (cons_name, field_names)); _ } -> - let field_names = - FieldId.mapi - (fun fid (field : field) -> - let rust_name = Option.get field.field_name in + (* Ignore this if the type is to be extracted as a tuple *) + if + TypesUtils.type_decl_from_decl_id_is_tuple_struct + ctx.trans_ctx.type_ctx.type_infos def.def_id + then ctx + else + match def.kind with + | Struct fields -> + (* Compute the names *) + let field_names, cons_name = + match info with + | None | Some { body_info = None; _ } -> + let field_names = + FieldId.mapi + (fun fid (field : field) -> + ( fid, + ctx_compute_field_name ctx def.llbc_name fid + field.field_name )) + fields + in + let cons_name = + ctx_compute_struct_constructor ctx def.llbc_name + in + (field_names, cons_name) + | Some { body_info = Some (Struct (cons_name, field_names)); _ } -> + let field_names = + FieldId.mapi + (fun fid (field : field) -> + let rust_name = Option.get field.field_name in + let name = + snd + (List.find (fun (n, _) -> n = rust_name) field_names) + in + (fid, name)) + fields + in + (field_names, cons_name) + | Some info -> + raise + (Failure + ("Invalid builtin information: " + ^ show_builtin_type_info info)) + in + (* Add the fields *) + let ctx = + List.fold_left + (fun ctx (fid, name) -> + ctx_add (FieldId (TAdtId def.def_id, fid)) name ctx) + ctx field_names + in + (* Add the constructor name *) + ctx_add (StructId (TAdtId def.def_id)) cons_name ctx + | Enum variants -> + let variant_names = + match info with + | None -> + VariantId.mapi + (fun variant_id (variant : variant) -> let name = - snd (List.find (fun (n, _) -> n = rust_name) field_names) + ctx_compute_variant_name ctx def.llbc_name + variant.variant_name in - (fid, name)) - fields - in - (field_names, cons_name) - | Some info -> - raise - (Failure - ("Invalid builtin information: " - ^ show_builtin_type_info info)) - in - (* Add the fields *) - let ctx = + (* Add the type name prefix for Lean *) + let name = + if !Config.backend = Lean then + let type_name = + ctx_compute_type_name ctx def.llbc_name + in + type_name ^ "." ^ name + else name + in + (variant_id, name)) + variants + | Some { body_info = Some (Enum variant_infos); _ } -> + (* We need to compute the map from variant to variant *) + let variant_map = + StringMap.of_list + (List.map + (fun (info : builtin_enum_variant_info) -> + (info.rust_variant_name, info.extract_variant_name)) + variant_infos) + in + VariantId.mapi + (fun variant_id (variant : variant) -> + (variant_id, StringMap.find variant.variant_name variant_map)) + variants + | _ -> raise (Failure "Invalid builtin information") + in List.fold_left - (fun ctx (fid, name) -> - ctx_add (FieldId (TAdtId def.def_id, fid)) name ctx) - ctx field_names - in - (* Add the constructor name *) - ctx_add (StructId (TAdtId def.def_id)) cons_name ctx - | Enum variants -> - let variant_names = - match info with - | None -> - VariantId.mapi - (fun variant_id (variant : variant) -> - let name = - ctx_compute_variant_name ctx def.llbc_name - variant.variant_name - in - (* Add the type name prefix for Lean *) - let name = - if !Config.backend = Lean then - let type_name = ctx_compute_type_name ctx def.llbc_name in - type_name ^ "." ^ name - else name - in - (variant_id, name)) - variants - | Some { body_info = Some (Enum variant_infos); _ } -> - (* We need to compute the map from variant to variant *) - let variant_map = - StringMap.of_list - (List.map - (fun (info : builtin_enum_variant_info) -> - (info.rust_variant_name, info.extract_variant_name)) - variant_infos) - in - VariantId.mapi - (fun variant_id (variant : variant) -> - (variant_id, StringMap.find variant.variant_name variant_map)) - variants - | _ -> raise (Failure "Invalid builtin information") - in - List.fold_left - (fun ctx (vid, vname) -> - ctx_add (VariantId (TAdtId def.def_id, vid)) vname ctx) - ctx variant_names - | Opaque -> - (* Nothing to do *) - ctx + (fun ctx (vid, vname) -> + ctx_add (VariantId (TAdtId def.def_id, vid)) vname ctx) + ctx variant_names + | Opaque -> + (* Nothing to do *) + ctx in (* Return *) ctx @@ -906,6 +912,19 @@ let extract_type_decl_enum_body (ctx : extraction_ctx) (fmt : F.formatter) let variants = VariantId.mapi (fun vid v -> (vid, v)) variants in List.iter (fun (vid, v) -> print_variant vid v) variants +(** Extract a struct as a tuple *) +let extract_type_decl_tuple_struct_body (ctx : extraction_ctx) + (fmt : F.formatter) (fields : field list) : unit = + let sep = match !backend with Coq | FStar | HOL4 -> "*" | Lean -> "×" in + Collections.List.iter_link + (fun () -> + F.pp_print_space fmt (); + F.pp_print_string fmt sep) + (fun (f : field) -> + F.pp_print_space fmt (); + extract_ty ctx fmt TypeDeclId.Set.empty true f.field_ty) + fields + let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) (type_decl_group : TypeDeclId.Set.t) (kind : decl_kind) (def : type_decl) (type_params : string list) (cg_params : string list) (fields : field list) @@ -1264,12 +1283,18 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (extract_body : bool) : unit = (* Sanity check *) assert (extract_body || !backend <> HOL4); + let is_tuple_struct = + TypesUtils.type_decl_from_decl_id_is_tuple_struct + ctx.trans_ctx.type_ctx.type_infos def.def_id + in let type_kind = if extract_body then - match def.kind with - | Struct _ -> Some Struct - | Enum _ -> Some Enum - | Opaque -> None + if is_tuple_struct then Some Tuple + else + match def.kind with + | Struct _ -> Some Struct + | Enum _ -> Some Enum + | Opaque -> None else None in (* If in Coq and the declaration is opaque, it must have the shape: @@ -1300,7 +1325,8 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) * for parsing: we thus use a hovbox. *) (match !backend with | Coq | FStar | HOL4 -> F.pp_open_hvbox fmt 0 - | Lean -> F.pp_open_vbox fmt 0); + | Lean -> + if is_tuple_struct then F.pp_open_hvbox fmt 0 else F.pp_open_vbox fmt 0); (* Open a box for "type TYPE_NAME (TYPE_PARAMS CONST_GEN_PARAMS) =" *) F.pp_open_hovbox fmt ctx.indent_incr; (* > "type TYPE_NAME" *) @@ -1320,7 +1346,11 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) let eq = match !backend with | FStar -> "=" - | Coq -> ":=" + | Coq -> + (* For Coq, the `*` is overloaded. If we want to extract a product + type (and not a product between, say, integers) we need to help Coq + a bit *) + if is_tuple_struct then ": Type :=" else ":=" | Lean -> if type_kind = Some Struct && kind = SingleNonRec then "where" else ":=" @@ -1341,8 +1371,11 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (if extract_body then match def.kind with | Struct fields -> - extract_type_decl_struct_body ctx_body fmt type_decl_group kind def - type_params cg_params fields + if is_tuple_struct then + extract_type_decl_tuple_struct_body ctx_body fmt fields + else + extract_type_decl_struct_body ctx_body fmt type_decl_group kind def + type_params cg_params fields | Enum variants -> extract_type_decl_enum_body ctx_body fmt type_decl_group def def_name type_params cg_params variants @@ -1670,8 +1703,13 @@ let extract_type_decl_extra_info (ctx : extraction_ctx) (fmt : F.formatter) match !backend with | FStar | Lean | HOL4 -> () | Coq -> - extract_type_decl_coq_arguments ctx fmt kind decl; - extract_type_decl_record_field_projectors ctx fmt kind decl + if + not + (TypesUtils.type_decl_from_decl_id_is_tuple_struct + ctx.trans_ctx.type_ctx.type_infos decl.def_id) + then ( + extract_type_decl_coq_arguments ctx fmt kind decl; + extract_type_decl_record_field_projectors ctx fmt kind decl) (** Extract the state type declaration. *) let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx) diff --git a/compiler/InterpreterBorrows.ml b/compiler/InterpreterBorrows.ml index 19b9fd3b..e56919fa 100644 --- a/compiler/InterpreterBorrows.ml +++ b/compiler/InterpreterBorrows.ml @@ -706,7 +706,7 @@ let reborrow_shared (original_bid : BorrowId.id) (new_bid : BorrowId.id) (** Convert an {!type:avalue} to a {!type:value}. This function is used when ending abstractions: whenever we end a borrow - in an abstraction, we converted the borrowed {!avalue} to a fresh symbolic + in an abstraction, we convert the borrowed {!avalue} to a fresh symbolic {!type:value}, then give back this {!type:value} to the context. Note that some regions may have ended in the symbolic value we generate. @@ -719,8 +719,7 @@ let reborrow_shared (original_bid : BorrowId.id) (new_bid : BorrowId.id) be expanded (because expanding this symbolic value would require expanding a reference whose region has already ended). *) -let convert_avalue_to_given_back_value (abs_kind : abs_kind) (av : typed_avalue) - : symbolic_value = +let convert_avalue_to_given_back_value (av : typed_avalue) : symbolic_value = mk_fresh_symbolic_value av.ty (** Auxiliary function: see {!end_borrow_aux}. @@ -739,8 +738,8 @@ let convert_avalue_to_given_back_value (abs_kind : abs_kind) (av : typed_avalue) borrows. This kind of internal reshuffling. should be similar to ending abstractions (it is tantamount to ending *sub*-abstractions). *) -let give_back (config : config) (abs_id_opt : AbstractionId.id option) - (l : BorrowId.id) (bc : g_borrow_content) (ctx : eval_ctx) : eval_ctx = +let give_back (config : config) (l : BorrowId.id) (bc : g_borrow_content) + (ctx : eval_ctx) : eval_ctx = (* Debug *) log#ldebug (lazy @@ -781,9 +780,7 @@ let give_back (config : config) (abs_id_opt : AbstractionId.id option) Rem.: we shouldn't do this here. We should do this in a function which takes care of ending *sub*-abstractions. *) - let abs_id = Option.get abs_id_opt in - let abs = ctx_lookup_abs ctx abs_id in - let sv = convert_avalue_to_given_back_value abs.kind av in + let sv = convert_avalue_to_given_back_value av in (* Update the context *) give_back_avalue_to_same_abstraction config l av (mk_typed_value_from_symbolic_value sv) @@ -929,14 +926,14 @@ let rec end_borrow_aux (config : config) (chain : borrow_or_abs_ids) cf_check cf ctx (* We found a borrow and replaced it with [Bottom]: give it back (i.e., update the corresponding loan) *) - | Ok (ctx, Some (abs_id_opt, bc)) -> + | Ok (ctx, Some (_, bc)) -> (* Sanity check: the borrowed value shouldn't contain loans *) (match bc with | Concrete (VMutBorrow (_, bv)) -> assert (Option.is_none (get_first_loan_in_value bv)) | _ -> ()); (* Give back the value *) - let ctx = give_back config abs_id_opt l bc ctx in + let ctx = give_back config l bc ctx in (* Do a sanity check and continue *) cf_check cf ctx @@ -1161,7 +1158,7 @@ and end_abstraction_borrows (config : config) (chain : borrow_or_abs_ids) match bc with | AMutBorrow (bid, av) -> (* First, convert the avalue to a (fresh symbolic) value *) - let sv = convert_avalue_to_given_back_value abs.kind av in + let sv = convert_avalue_to_given_back_value av in (* Replace the mut borrow to register the fact that we ended * it and store with it the freshly generated given back value *) let ended_borrow = ABorrow (AEndedMutBorrow (sv, av)) in diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index d0741b29..68f8943a 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -563,12 +563,13 @@ let remove_meta (def : fun_decl) : fun_decl = This micro-pass turns those into expressions which use structure syntax: {[ - { - f0 := x0; - ... - fn := xn; - } + type struct = { f0 : nat; f1 : nat; f2 : nat } + + Mkstruct x.f0 x.f1 y ~~> { x with f2 = y } ]} + + Note however that we do not apply this transformation if the + structure is to be extracted as a tuple. *) let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = let obj = @@ -592,37 +593,44 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = } -> (* Lookup the def *) let decl = TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls in - (* Check that there are as many arguments as there are fields - note - that the def should have a body (otherwise we couldn't use the - constructor) *) - let fields = TypesUtils.type_decl_get_fields decl None in - if List.length fields = List.length args then - (* Check if the definition is recursive *) - let is_rec = - match - TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls_groups - with - | NonRecGroup _ -> false - | RecGroup _ -> true - in - (* Convert, if possible - note that for now for Lean and Coq - we don't support the structure syntax on recursive structures *) - if - (!Config.backend <> Lean && !Config.backend <> Coq) - || not is_rec - then - let struct_id = TAdtId adt_id in - let init = None in - let updates = - FieldId.mapi - (fun fid fe -> (fid, self#visit_texpression env fe)) - args + (* Check if the def will be extracted as a tuple *) + if + TypesUtils.type_decl_from_decl_id_is_tuple_struct + ctx.type_ctx.type_infos adt_id + then ignore () + else + (* Check that there are as many arguments as there are fields - note + that the def should have a body (otherwise we couldn't use the + constructor) *) + let fields = TypesUtils.type_decl_get_fields decl None in + if List.length fields = List.length args then + (* Check if the definition is recursive *) + let is_rec = + match + TypeDeclId.Map.find adt_id + ctx.type_ctx.type_decls_groups + with + | NonRecGroup _ -> false + | RecGroup _ -> true in - let ne = { struct_id; init; updates } in - let nty = e.ty in - { e = StructUpdate ne; ty = nty } + (* Convert, if possible - note that for now for Lean and Coq + we don't support the structure syntax on recursive structures *) + if + (!Config.backend <> Lean && !Config.backend <> Coq) + || not is_rec + then + let struct_id = TAdtId adt_id in + let init = None in + let updates = + FieldId.mapi + (fun fid fe -> (fid, self#visit_texpression env fe)) + args + in + let ne = { struct_id; init; updates } in + let nty = e.ty in + { e = StructUpdate ne; ty = nty } + else ignore () else ignore () - else ignore () | _ -> ignore ()) | _ -> super#visit_texpression env e end @@ -1069,12 +1077,10 @@ let simplify_let_then_return _ctx def = (** Simplify the aggregated ADTs. Ex.: {[ - type struct = { f0 : nat; f1 : nat } + type struct = { f0 : nat; f1 : nat; f2 : nat } - Mkstruct x.f0 x.f1 ~~> x + Mkstruct x.f0 x.f1 x.f2 ~~> x ]} - - TODO: introduce a notation for [{ x with field = ... }], and use it. *) let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = let expr_visitor = diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index a5143f3c..39dcd52d 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -687,3 +687,14 @@ let trait_impl_is_empty (trait_impl : trait_impl) : bool = in parent_trait_refs = [] && consts = [] && types = [] && required_methods = [] && provided_methods = [] + +(** Return true if a type declaration should be extracted as a tuple, because + it is a non-recursive structure with unnamed fields. *) +let type_decl_from_type_id_is_tuple_struct (ctx : TypesAnalysis.type_infos) + (id : type_id) : bool = + match id with + | TTuple -> true + | TAdtId id -> + let info = TypeDeclId.Map.find id ctx in + info.is_tuple_struct + | TAssumed _ -> false diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 3b30549c..bf4d26f2 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -2299,11 +2299,11 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) match sexp with | V.SeLiteral _ -> (* We do not *register* symbolic expansions to literal - * values in the symbolic ADT *) + values in the symbolic ADT *) raise (Failure "Unreachable") | SeMutRef (_, nsv) | SeSharedRef (_, nsv) -> (* The (mut/shared) borrow type is extracted to identity: we thus simply - * introduce an reassignment *) + introduce an reassignment *) let ctx, var = fresh_var_for_symbolic_value nsv ctx in let next_e = translate_expression e ctx in let monadic = false in @@ -2324,10 +2324,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) && !Config.always_deconstruct_adts_with_matches) -> (* There is exactly one branch: no branching. - We can decompose the ADT value with a let-binding, unless - the backend doesn't support this (see {!Config.always_deconstruct_adts_with_matches}): - we *ignore* this branch (and go to the next one) if the ADT is a custom - adt, and [always_deconstruct_adts_with_matches] is true. + We can decompose the ADT value with a let-binding, unless + the backend doesn't support this (see {!Config.always_deconstruct_adts_with_matches}): + we *ignore* this branch (and go to the next one) if the ADT is a custom + adt, and [always_deconstruct_adts_with_matches] is true. *) translate_ExpandAdt_one_branch sv scrutinee scrutinee_mplace variant_id svl branch ctx @@ -2361,7 +2361,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) { e; ty }) | ExpandBool (true_e, false_e) -> (* We don't need to update the context: we don't introduce any - * new values/variables *) + new values/variables *) let true_e = translate_expression true_e ctx in let false_e = translate_expression false_e ctx in let e = @@ -2376,7 +2376,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) let translate_branch ((v, branch_e) : V.scalar_value * S.expression) : match_branch = (* We don't need to update the context: we don't introduce any - * new values/variables *) + new values/variables *) let branch = translate_expression branch_e ctx in let pat = mk_typed_pattern_from_literal (VScalar v) in { pat; branch } @@ -2436,20 +2436,28 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value) (* Detect if this is an enumeration or not *) let tdef = bs_ctx_lookup_llbc_type_decl adt_id ctx in let is_enum = TypesUtils.type_decl_is_enum tdef in - (* We deconstruct the ADT with a let-binding in two situations: + (* We deconstruct the ADT with a single let-binding in two situations: - if the ADT is an enumeration (which must have exactly one branch) - if we forbid using field projectors. *) let is_rec_def = T.TypeDeclId.Set.mem adt_id ctx.type_context.recursive_decls in - let use_let = + let use_let_with_cons = is_enum || !Config.dont_use_field_projectors (* TODO: for now, we don't have field projectors over recursive ADTs in Lean. *) || (!Config.backend = Lean && is_rec_def) + (* Also, there is a special case when the ADT is to be extracted as + a tuple, because it is a structure with unnamed fields. Some backends + like Lean have projectors for tuples (like so: `x.3`), but others + like Coq don't, in which case we have to deconstruct the whole ADT + at once (`let (a, b, c) = x in`) *) + || TypesUtils.type_decl_from_type_id_is_tuple_struct + ctx.type_context.type_infos type_id + && not (Config.backend_has_tuple_projectors ()) in - if use_let then + if use_let_with_cons then (* Introduce a let binding which expands the ADT *) let lvars = List.map (fun v -> mk_typed_pattern_from_var v None) vars in let lv = mk_adt_pattern scrutinee.ty variant_id lvars in diff --git a/compiler/TypesAnalysis.ml b/compiler/TypesAnalysis.ml index 589c380c..12c20262 100644 --- a/compiler/TypesAnalysis.ml +++ b/compiler/TypesAnalysis.ml @@ -27,6 +27,10 @@ type 'p g_type_info = { borrows_info : type_borrows_info; (** Various informations about the borrows *) param_infos : 'p; (** Gives information about the type parameters *) + is_tuple_struct : bool; + (** If true, it means the type is a record that we should extract as a tuple. + This field is only valid for type declarations. + *) } [@@deriving show] @@ -55,22 +59,43 @@ let type_borrows_info_init : type_borrows_info = contains_borrow_under_mut = false; } -let initialize_g_type_info (param_infos : 'p) : 'p g_type_info = - { borrows_info = type_borrows_info_init; param_infos } +(** Return true if a type declaration is a structure with unnamed fields. -let initialize_type_decl_info (def : type_decl) : type_decl_info = + Note that there are two possibilities: + - either all the fields are named + - or none of the fields are named + *) +let type_decl_is_tuple_struct (x : type_decl) : bool = + match x.kind with + | Struct fields -> List.for_all (fun f -> f.field_name = None) fields + | _ -> false + +let initialize_g_type_info (is_tuple_struct : bool) (param_infos : 'p) : + 'p g_type_info = + { borrows_info = type_borrows_info_init; is_tuple_struct; param_infos } + +let initialize_type_decl_info (is_rec : bool) (def : type_decl) : type_decl_info + = let param_info = { under_borrow = false; under_mut_borrow = false } in let param_infos = List.map (fun _ -> param_info) def.generics.types in - initialize_g_type_info param_infos + let is_tuple_struct = + !Config.use_tuple_structs && (not is_rec) && type_decl_is_tuple_struct def + in + initialize_g_type_info is_tuple_struct param_infos let type_decl_info_to_partial_type_info (info : type_decl_info) : partial_type_info = - { borrows_info = info.borrows_info; param_infos = Some info.param_infos } + { + borrows_info = info.borrows_info; + is_tuple_struct = info.is_tuple_struct; + param_infos = Some info.param_infos; + } let partial_type_info_to_type_decl_info (info : partial_type_info) : type_decl_info = { borrows_info = info.borrows_info; + is_tuple_struct = info.is_tuple_struct; param_infos = Option.get info.param_infos; } @@ -283,14 +308,20 @@ let analyze_type_decl (updated : bool ref) (infos : type_infos) let analyze_type_declaration_group (type_decls : type_decl TypeDeclId.Map.t) (infos : type_infos) (decl : type_declaration_group) : type_infos = (* Collect the identifiers used in the declaration group *) - let ids = match decl with NonRecGroup id -> [ id ] | RecGroup ids -> ids in + let is_rec, ids = + match decl with + | NonRecGroup id -> (false, [ id ]) + | RecGroup ids -> (true, ids) + in (* Retrieve the type definitions *) let decl_defs = List.map (fun id -> TypeDeclId.Map.find id type_decls) ids in (* Initialize the type information for the current definitions *) let infos = List.fold_left (fun infos (def : type_decl) -> - TypeDeclId.Map.add def.def_id (initialize_type_decl_info def) infos) + TypeDeclId.Map.add def.def_id + (initialize_type_decl_info is_rec def) + infos) infos decl_defs in (* Analyze the types - this function simply computes a fixed-point *) @@ -327,7 +358,7 @@ let analyze_ty (infos : type_infos) (ty : ty) : ty_info = (* We don't use [updated] but need to give it as parameter *) let updated = ref false in (* We don't need to compute whether the type contains 'static or not *) - let ty_info = initialize_g_type_info None in + let ty_info = initialize_g_type_info false None in let ty_info = analyze_full_ty updated infos ty_info ty in (* Convert the ty_info *) partial_type_info_to_ty_info ty_info diff --git a/compiler/TypesUtils.ml b/compiler/TypesUtils.ml index c8418ba0..28db59ec 100644 --- a/compiler/TypesUtils.ml +++ b/compiler/TypesUtils.ml @@ -111,3 +111,21 @@ let trait_type_constraint_no_regions (x : trait_type_constraint) : bool = raise_if_region_ty_visitor#visit_ty () ty; true with Found -> false + +(** Return true if a type declaration should be extracted as a tuple, because + it is a non-recursive structure with unnamed fields. *) +let type_decl_from_decl_id_is_tuple_struct (ctx : TypesAnalysis.type_infos) + (id : TypeDeclId.id) : bool = + let info = TypeDeclId.Map.find id ctx in + info.is_tuple_struct + +(** Return true if a type declaration should be extracted as a tuple, because + it is a non-recursive structure with unnamed fields. *) +let type_decl_from_type_id_is_tuple_struct (ctx : TypesAnalysis.type_infos) + (id : type_id) : bool = + match id with + | TTuple -> true + | TAdtId id -> + let info = TypeDeclId.Map.find id ctx in + info.is_tuple_struct + | TAssumed _ -> false |