diff options
Diffstat (limited to 'compiler/Extract.ml')
-rw-r--r-- | compiler/Extract.ml | 753 |
1 files changed, 501 insertions, 252 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml index e24cae16..91827a31 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -229,12 +229,9 @@ let assumed_adts () : (assumed_ty * string) list = (Result, "Result"); (Error, "Error"); (Fuel, "Nat"); - (Option, "Option"); - (Vec, "Vec"); (Array, "Array"); (Slice, "Slice"); (Str, "Str"); - (Range, "Range"); ] | Coq | FStar -> [ @@ -242,12 +239,9 @@ let assumed_adts () : (assumed_ty * string) list = (Result, "result"); (Error, "error"); (Fuel, "nat"); - (Option, "option"); - (Vec, "vec"); (Array, "array"); (Slice, "slice"); (Str, "str"); - (Range, "range"); ] | HOL4 -> [ @@ -255,20 +249,17 @@ let assumed_adts () : (assumed_ty * string) list = (Result, "result"); (Error, "error"); (Fuel, "num"); - (Option, "option"); - (Vec, "vec"); (Array, "array"); (Slice, "slice"); (Str, "str"); - (Range, "range"); ] let assumed_struct_constructors () : (assumed_ty * string) list = match !backend with - | Lean -> [ (Range, "Range.mk"); (Array, "Array.make") ] - | Coq -> [ (Range, "mk_range"); (Array, "mk_array") ] - | FStar -> [ (Range, "Mkrange"); (Array, "mk_array") ] - | HOL4 -> [ (Range, "mk_range"); (Array, "mk_array") ] + | Lean -> [ (Array, "Array.make") ] + | Coq -> [ (Array, "mk_array") ] + | FStar -> [ (Array, "mk_array") ] + | HOL4 -> [ (Array, "mk_array") ] let assumed_variants () : (assumed_ty * VariantId.id * string) list = match !backend with @@ -280,8 +271,6 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list = (Error, error_out_of_fuel_id, "OutOfFuel"); (* No Fuel::Zero on purpose *) (* No Fuel::Succ on purpose *) - (Option, option_some_id, "Some"); - (Option, option_none_id, "None"); ] | Coq -> [ @@ -291,8 +280,6 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list = (Error, error_out_of_fuel_id, "OutOfFuel"); (Fuel, fuel_zero_id, "O"); (Fuel, fuel_succ_id, "S"); - (Option, option_some_id, "Some"); - (Option, option_none_id, "None"); ] | Lean -> [ @@ -301,8 +288,6 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list = (Error, error_failure_id, "panic"); (* No Fuel::Zero on purpose *) (* No Fuel::Succ on purpose *) - (Option, option_some_id, "some"); - (Option, option_none_id, "none"); ] | HOL4 -> [ @@ -311,8 +296,6 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list = (Error, error_failure_id, "Failure"); (* No Fuel::Zero on purpose *) (* No Fuel::Succ on purpose *) - (Option, option_some_id, "SOME"); - (Option, option_none_id, "NONE"); ] let assumed_llbc_functions () : @@ -321,66 +304,30 @@ let assumed_llbc_functions () : match !backend with | FStar | Coq | HOL4 -> [ - (Replace, None, "mem_replace_fwd"); - (Replace, rg0, "mem_replace_back"); - (VecNew, None, "vec_new"); - (VecPush, None, "vec_push_fwd") (* Shouldn't be used *); - (VecPush, rg0, "vec_push_back"); - (VecInsert, None, "vec_insert_fwd") (* Shouldn't be used *); - (VecInsert, rg0, "vec_insert_back"); - (VecLen, None, "vec_len"); - (VecIndex, None, "vec_index_fwd"); - (VecIndex, rg0, "vec_index_back") (* shouldn't be used *); - (VecIndexMut, None, "vec_index_mut_fwd"); - (VecIndexMut, rg0, "vec_index_mut_back"); (ArrayIndexShared, None, "array_index_shared"); (ArrayIndexMut, None, "array_index_mut_fwd"); (ArrayIndexMut, rg0, "array_index_mut_back"); (ArrayToSliceShared, None, "array_to_slice_shared"); (ArrayToSliceMut, None, "array_to_slice_mut_fwd"); (ArrayToSliceMut, rg0, "array_to_slice_mut_back"); - (ArraySubsliceShared, None, "array_subslice_shared"); - (ArraySubsliceMut, None, "array_subslice_mut_fwd"); - (ArraySubsliceMut, rg0, "array_subslice_mut_back"); (ArrayRepeat, None, "array_repeat"); (SliceIndexShared, None, "slice_index_shared"); (SliceIndexMut, None, "slice_index_mut_fwd"); (SliceIndexMut, rg0, "slice_index_mut_back"); - (SliceSubsliceShared, None, "slice_subslice_shared"); - (SliceSubsliceMut, None, "slice_subslice_mut_fwd"); - (SliceSubsliceMut, rg0, "slice_subslice_mut_back"); (SliceLen, None, "slice_len"); ] | Lean -> [ - (Replace, None, "mem.replace"); - (Replace, rg0, "mem.replace_back"); - (VecNew, None, "Vec.new"); - (VecPush, None, "Vec.push_fwd") (* Shouldn't be used *); - (VecPush, rg0, "Vec.push"); - (VecInsert, None, "Vec.insert_fwd") (* Shouldn't be used *); - (VecInsert, rg0, "Vec.insert"); - (VecLen, None, "Vec.len"); - (VecIndex, None, "Vec.index_shared"); - (VecIndex, rg0, "Vec.index_shared_back") (* shouldn't be used *); - (VecIndexMut, None, "Vec.index_mut"); - (VecIndexMut, rg0, "Vec.index_mut_back"); (ArrayIndexShared, None, "Array.index_shared"); (ArrayIndexMut, None, "Array.index_mut"); (ArrayIndexMut, rg0, "Array.index_mut_back"); (ArrayToSliceShared, None, "Array.to_slice_shared"); (ArrayToSliceMut, None, "Array.to_slice_mut"); (ArrayToSliceMut, rg0, "Array.to_slice_mut_back"); - (ArraySubsliceShared, None, "Array.subslice_shared"); - (ArraySubsliceMut, None, "Array.subslice_mut"); - (ArraySubsliceMut, rg0, "Array.subslice_mut_back"); (ArrayRepeat, None, "Array.repeat"); (SliceIndexShared, None, "Slice.index_shared"); (SliceIndexMut, None, "Slice.index_mut"); (SliceIndexMut, rg0, "Slice.index_mut_back"); - (SliceSubsliceShared, None, "Slice.subslice_shared"); - (SliceSubsliceMut, None, "Slice.subslice_mut"); - (SliceSubsliceMut, rg0, "Slice.subslice_mut_back"); (SliceLen, None, "Slice.len"); ] @@ -814,12 +761,6 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) fname ^ lp_suffix ^ suffix in - let opaque_pre () = - match !Config.backend with - | FStar | Coq | HOL4 -> "" - | Lean -> if !Config.wrap_opaque_in_sig then "opaque_defs." else "" - in - let var_basename (_varset : StringSet.t) (basename : string option) (ty : ty) : string = (* Small helper to derive var names from ADT type names. @@ -853,12 +794,9 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) | Assumed Result -> "r" | Assumed Error -> ConstStrings.error_basename | Assumed Fuel -> ConstStrings.fuel_basename - | Assumed Option -> "opt" - | Assumed Vec -> "v" | Assumed Array -> "a" | Assumed Slice -> "s" | Assumed Str -> "s" - | Assumed Range -> "r" | Assumed State -> ConstStrings.state_basename | AdtId adt_id -> let def = TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls in @@ -927,10 +865,12 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) (* We need to add parentheses if the value is negative *) if sv.PV.value >= Z.of_int 0 then F.pp_print_string fmt (Z.to_string sv.PV.value) - else + else if !backend = Lean then + (* TODO: parsing issues with Lean because there are ambiguous + interpretations between int values and nat values *) F.pp_print_string fmt - ("(" ^ Z.to_string sv.PV.value - ^ if !backend = Lean then ":Int" else "" ^ ")"); + ("(-(" ^ Z.to_string (Z.neg sv.PV.value) ^ ":Int))") + else F.pp_print_string fmt ("(" ^ Z.to_string sv.PV.value ^ ")"); (match !backend with | Coq -> let iname = int_name sv.PV.int_ty in @@ -993,7 +933,6 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) trait_type_name; trait_method_name; trait_type_clause_name; - opaque_pre; var_basename; type_var_basename; const_generic_var_basename; @@ -1042,11 +981,8 @@ let start_fun_decl_group (ctx : extraction_ctx) (fmt : F.formatter) (* In HOL4, opaque functions have a special treatment *) if is_single_opaque_fun_decl_group dg then () else - let with_opaque_pre = false in let compute_fun_def_name (def : Pure.fun_decl) : string = - ctx_get_local_function with_opaque_pre def.def_id def.loop_id - def.back_id ctx - ^ "_def" + ctx_get_local_function def.def_id def.loop_id def.back_id ctx ^ "_def" in let names = List.map compute_fun_def_name dg in (* Add a break before *) @@ -1169,7 +1105,7 @@ let extract_const_generic (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (cg : const_generic) : unit = match cg with | ConstGenericGlobal id -> - let s = ctx_get_global ctx.use_opaque_pre id ctx in + let s = ctx_get_global id ctx in F.pp_print_string fmt s | ConstGenericValue v -> ctx.fmt.extract_literal fmt inside v | ConstGenericVar id -> @@ -1237,14 +1173,33 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) In HOL4 we would write: `('a, 'b) tree` *) - let with_opaque_pre = false in match !backend with | FStar | Coq | Lean -> let print_paren = inside && has_params in if print_paren then F.pp_print_string fmt "("; (* TODO: for now, only the opaque *functions* are extracted in the opaque module. The opaque *types* are assumed. *) - F.pp_print_string fmt (ctx_get_type with_opaque_pre type_id ctx); + F.pp_print_string fmt (ctx_get_type type_id ctx); + (* We might need to filter the type arguments, if the type + is builtin (for instance, we filter the global allocator type + argument for `Vec`). *) + let generics = + match type_id with + | AdtId id -> ( + match + TypeDeclId.Map.find_opt id ctx.types_filter_type_args_map + with + | None -> generics + | Some filter -> + let types = List.combine filter generics.types in + let types = + List.filter_map + (fun (b, ty) -> if b then Some ty else None) + types + in + { generics with types }) + | _ -> generics + in extract_generic_args ctx fmt no_params_tys generics; if print_paren then F.pp_print_string fmt ")" | HOL4 -> @@ -1267,7 +1222,7 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (extract_rec true) types; if print_paren then F.pp_print_string fmt ")"; F.pp_print_space fmt ()); - F.pp_print_string fmt (ctx_get_type with_opaque_pre type_id ctx); + F.pp_print_string fmt (ctx_get_type type_id ctx); if trait_refs <> [] then ( F.pp_print_space fmt (); Collections.List.iter_link (F.pp_print_space fmt) @@ -1332,11 +1287,13 @@ and extract_trait_decl_ref (ctx : extraction_ctx) (fmt : F.formatter) (no_params_tys : TypeDeclId.Set.t) (inside : bool) (tr : trait_decl_ref) : unit = let use_brackets = tr.decl_generics <> empty_generic_args && inside in - let is_opaque = false in - let name = ctx_get_trait_decl is_opaque tr.trait_decl_id ctx in + let name = ctx_get_trait_decl tr.trait_decl_id ctx in if use_brackets then F.pp_print_string fmt "("; F.pp_print_string fmt name; - extract_generic_args ctx fmt no_params_tys tr.decl_generics; + (* There is something subtle here: the trait obligations for the implemented + trait are put inside the parent clauses, so we must ignore them here *) + let generics = { tr.decl_generics with trait_refs = [] } in + extract_generic_args ctx fmt no_params_tys generics; if use_brackets then F.pp_print_string fmt ")" and extract_generic_args (ctx : extraction_ctx) (fmt : F.formatter) @@ -1363,14 +1320,13 @@ and extract_generic_args (ctx : extraction_ctx) (fmt : F.formatter) and extract_trait_instance_id (ctx : extraction_ctx) (fmt : F.formatter) (no_params_tys : TypeDeclId.Set.t) (inside : bool) (id : trait_instance_id) : unit = - let with_opaque_pre = false in match id with | Self -> (* This has specific treatment depending on the item we're extracting (associated type, etc.). We should have caught this elsewhere. *) raise (Failure "Unexpected") | TraitImpl id -> - let name = ctx_get_trait_impl with_opaque_pre id ctx in + let name = ctx_get_trait_impl id ctx in F.pp_print_string fmt name | Clause id -> let name = ctx_get_local_trait_clause id ctx in @@ -1400,8 +1356,28 @@ and extract_trait_instance_id (ctx : extraction_ctx) (fmt : F.formatter) *) let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) : extraction_ctx = + (* Lookup the builtin information, if there is *) + let open ExtractBuiltin in + let sname = name_to_simple_name def.name in + let info = SimpleNameMap.find_opt sname (builtin_types_map ()) in + (* Register the filtering information, if there is *) + let ctx = + match info with + | Some { keep_params = Some keep; _ } -> + { + ctx with + types_filter_type_args_map = + TypeDeclId.Map.add def.def_id keep ctx.types_filter_type_args_map; + } + | _ -> ctx + in (* Compute and register the type def name *) - let ctx = ctx_add_type_decl def ctx in + let def_name = + match info with + | None -> ctx.fmt.type_name def.name + | Some info -> String.concat "." info.rust_name + in + let ctx = ctx_add (TypeId (AdtId def.def_id)) def_name ctx in (* Compute and register: * - the variant names, if this is an enumeration * - the field names, if this is a structure @@ -1409,18 +1385,77 @@ let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) : let ctx = match def.kind with | Struct fields -> + (* Compute the names *) + let field_names, cons_name = + match info with + | None | Some { body_info = None; _ } -> + let field_names = + FieldId.mapi + (fun fid (field : field) -> + (fid, ctx.fmt.field_name def.name fid field.field_name)) + fields + in + let cons_name = ctx.fmt.struct_constructor def.name in + (field_names, cons_name) + | Some { body_info = Some (Struct (cons_name, field_names)); _ } -> + let field_names = + FieldId.mapi + (fun fid (_, name) -> (fid, name)) + (List.combine fields field_names) + in + (field_names, cons_name) + | Some info -> + raise + (Failure + ("Invalid builtin information: " + ^ show_builtin_type_info info)) + in (* Add the fields *) let ctx = - fst - (ctx_add_fields def (FieldId.mapi (fun id f -> (id, f)) fields) ctx) + List.fold_left + (fun ctx (fid, name) -> + ctx_add (FieldId (AdtId def.def_id, fid)) name ctx) + ctx field_names in (* Add the constructor name *) - fst (ctx_add_struct def ctx) + ctx_add (StructId (AdtId def.def_id)) cons_name ctx | Enum variants -> - fst - (ctx_add_variants def - (VariantId.mapi (fun id v -> (id, v)) variants) - ctx) + let variant_names = + match info with + | None -> + VariantId.mapi + (fun variant_id (variant : variant) -> + let name = + ctx.fmt.variant_name def.name variant.variant_name + in + (* Add the type name prefix for Lean *) + let name = + if !Config.backend = Lean then + let type_name = ctx.fmt.type_name def.name in + type_name ^ "." ^ name + else name + in + (variant_id, name)) + variants + | Some { body_info = Some (Enum variant_infos); _ } -> + (* We need to compute the map from variant to variant *) + let variant_map = + StringMap.of_list + (List.map + (fun (info : builtin_enum_variant_info) -> + (info.rust_variant_name, info.extract_variant_name)) + variant_infos) + in + VariantId.mapi + (fun variant_id (variant : variant) -> + (variant_id, StringMap.find variant.variant_name variant_map)) + variants + | _ -> raise (Failure "Invalid builtin information") + in + List.fold_left + (fun ctx (vid, vname) -> + ctx_add (VariantId (AdtId def.def_id, vid)) vname ctx) + ctx variant_names | Opaque -> (* Nothing to do *) ctx @@ -1622,9 +1657,7 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) (* If Coq: print the constructor name *) (* TODO: remove superfluous test not is_rec below *) if !backend = Coq && not is_rec then ( - let with_opaque_pre = false in - F.pp_print_string fmt - (ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx); + F.pp_print_string fmt (ctx_get_struct (AdtId def.def_id) ctx); F.pp_print_string fmt " "); (match !backend with | Lean -> () @@ -1668,16 +1701,14 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) (* We extract for Coq or Lean, and we have a recursive record, or a record in a group of mutually recursive types: we extract it as an inductive type *) assert (is_rec && (!backend = Coq || !backend = Lean)); - let with_opaque_pre = false in (* Small trick: in Lean we use namespaces, meaning we don't need to prefix the constructor name with the name of the type at definition site, i.e., instead of generating `inductive Foo := | MkFoo ...` like in Coq we generate `inductive Foo := | mk ... *) let cons_name = - if !backend = Lean then "mk" - else ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx + if !backend = Lean then "mk" else ctx_get_struct (AdtId def.def_id) ctx in - let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in + let def_name = ctx_get_local_type def.def_id ctx in extract_type_decl_variant ctx fmt type_decl_group def_name type_params cg_params cons_name fields) in @@ -1707,8 +1738,7 @@ let extract_comment (fmt : F.formatter) (sl : string list) : unit = let extract_trait_clause_type (ctx : extraction_ctx) (fmt : F.formatter) (no_params_tys : TypeDeclId.Set.t) (clause : trait_clause) : unit = - let with_opaque_pre = false in - let trait_name = ctx_get_trait_decl with_opaque_pre clause.trait_id ctx in + let trait_name = ctx_get_trait_decl clause.trait_id ctx in F.pp_print_string fmt trait_name; extract_generic_args ctx fmt no_params_tys clause.generics @@ -1730,8 +1760,7 @@ let extract_trait_self_clause (insert_req_space : unit -> unit) F.pp_print_space fmt (); F.pp_print_string fmt ":"; F.pp_print_space fmt (); - let with_opaque_pre = false in - let trait_id = ctx_get_trait_decl with_opaque_pre trait_decl.def_id ctx in + let trait_id = ctx_get_trait_decl trait_decl.def_id ctx in F.pp_print_string fmt trait_id; List.iter (fun p -> @@ -1900,8 +1929,7 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) let is_opaque_coq = !backend = Coq && is_opaque in let use_forall = is_opaque_coq && def.generics <> empty_generic_params in (* Retrieve the definition name *) - let with_opaque_pre = false in - let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in + let def_name = ctx_get_local_type def.def_id ctx in (* Add the type and const generic params - note that we need those bindings only for the * body translation (they are not top-level) *) let ctx_body, type_params, cg_params, trait_clauses = @@ -1988,8 +2016,7 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) let extract_type_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter) (def : type_decl) : unit = (* Retrieve the definition name *) - let with_opaque_pre = false in - let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in + let def_name = ctx_get_local_type def.def_id ctx in (* Generic parameters are unsupported *) assert (def.generics.const_generics = []); (* Trait clauses on type definitions are unsupported *) @@ -2014,8 +2041,7 @@ let extract_type_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter) let extract_type_decl_hol4_empty_record (ctx : extraction_ctx) (fmt : F.formatter) (def : type_decl) : unit = (* Retrieve the definition name *) - let with_opaque_pre = false in - let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in + let def_name = ctx_get_local_type def.def_id ctx in (* Sanity check *) assert (def.generics = empty_generic_params); (* Generate the declaration *) @@ -2098,8 +2124,7 @@ let extract_type_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter) | Struct fields -> let adt_id = AdtId decl.def_id in (* Generate the instruction for the record constructor *) - let with_opaque_pre = false in - let cons_name = ctx_get_struct with_opaque_pre adt_id ctx in + let cons_name = ctx_get_struct adt_id ctx in extract_arguments_info cons_name fields; (* Generate the instruction for the record projectors, if there are *) let is_rec = decl_is_from_rec_group kind in @@ -2143,11 +2168,8 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx) in let ctx, record_var = ctx_add_var "x" (VarId.of_int 0) ctx in let ctx, field_var = ctx_add_var "x" (VarId.of_int 1) ctx in - let with_opaque_pre = false in - let def_name = ctx_get_local_type with_opaque_pre decl.def_id ctx in - let cons_name = - ctx_get_struct with_opaque_pre (AdtId decl.def_id) ctx - in + let def_name = ctx_get_local_type decl.def_id ctx in + let cons_name = ctx_get_struct (AdtId decl.def_id) ctx in let extract_field_proj (field_id : FieldId.id) (_ : field) : unit = F.pp_print_space fmt (); (* Outer box for the projector definition *) @@ -2359,33 +2381,81 @@ let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx) let extract_fun_decl_register_names (ctx : extraction_ctx) (has_decreases_clause : fun_decl -> bool) (def : pure_fun_translation) : extraction_ctx = - let fwd = def.fwd in - let backs = def.backs in - (* Register the decrease clauses, if necessary *) - let register_decreases ctx def = - if has_decreases_clause def then - (* Add the termination measure *) - let ctx = ctx_add_termination_measure def ctx in - (* Add the decreases proof for Lean only *) - match !Config.backend with - | Coq | FStar -> ctx - | HOL4 -> raise (Failure "Unexpected") - | Lean -> ctx_add_decreases_proof def ctx - else ctx - in - let ctx = List.fold_left register_decreases ctx (fwd.f :: fwd.loops) in - let register_fun ctx f = ctx_add_fun_decl def f ctx in - let register_funs ctx fl = List.fold_left register_fun ctx fl in - (* Register the names of the forward functions *) - let ctx = - if def.keep_fwd then register_funs ctx (fwd.f :: fwd.loops) else ctx - in - (* Register the names of the backward functions *) - List.fold_left - (fun ctx { f = back; loops = loop_backs } -> - let ctx = register_fun ctx back in - register_funs ctx loop_backs) - ctx backs + (* Ignore the trait methods **declarations** (rem.: we do not ignore the trait + method implementations): we do not need to refer to them directly. We will + only use their type for the fields of the records we generate for the trait + declarations *) + match def.fwd.f.kind with + | TraitMethodDecl _ -> ctx + | _ -> ( + (* Check if the function is builtin *) + let builtin = + let open ExtractBuiltin in + let funs_map = builtin_funs_map () in + let sname = name_to_simple_name def.fwd.f.basename in + SimpleNameMap.find_opt sname funs_map + in + (* Use the builtin names if necessary *) + match builtin with + | Some (filter_info, info) -> + (* Register the filtering information, if there is *) + let ctx = + match filter_info with + | Some keep -> + { + ctx with + funs_filter_type_args_map = + FunDeclId.Map.add def.fwd.f.def_id keep + ctx.funs_filter_type_args_map; + } + | _ -> ctx + in + let backs = List.map (fun f -> f.f) def.backs in + let funs = if def.keep_fwd then def.fwd.f :: backs else backs in + List.fold_left + (fun ctx (f : fun_decl) -> + let open ExtractBuiltin in + let fun_id = + (Pure.FunId (Regular f.def_id), f.loop_id, f.back_id) + in + let fun_name = + (List.find + (fun (x : builtin_fun_info) -> x.rg = f.back_id) + info) + .extract_name + in + ctx_add (FunId (FromLlbc fun_id)) fun_name ctx) + ctx funs + | None -> + let fwd = def.fwd in + let backs = def.backs in + (* Register the decrease clauses, if necessary *) + let register_decreases ctx def = + if has_decreases_clause def then + (* Add the termination measure *) + let ctx = ctx_add_termination_measure def ctx in + (* Add the decreases proof for Lean only *) + match !Config.backend with + | Coq | FStar -> ctx + | HOL4 -> raise (Failure "Unexpected") + | Lean -> ctx_add_decreases_proof def ctx + else ctx + in + let ctx = + List.fold_left register_decreases ctx (fwd.f :: fwd.loops) + in + let register_fun ctx f = ctx_add_fun_decl def f ctx in + let register_funs ctx fl = List.fold_left register_fun ctx fl in + (* Register the names of the forward functions *) + let ctx = + if def.keep_fwd then register_funs ctx (fwd.f :: fwd.loops) else ctx + in + (* Register the names of the backward functions *) + List.fold_left + (fun ctx { f = back; loops = loop_backs } -> + let ctx = register_fun ctx back in + register_funs ctx loop_backs) + ctx backs) (** Simply add the global name to the context. *) let extract_global_decl_register_names (ctx : extraction_ctx) @@ -2459,18 +2529,14 @@ let extract_adt_g_value * [{ field0=...; ...; fieldn=...; }] in case of structures. *) let cons = - (* The ADT shouldn't be opaque *) - let with_opaque_pre = false in match variant_id with | Some vid -> ( (* In the case of Lean, we might have to add the type name as a prefix *) match (!backend, adt_id) with | Lean, Assumed _ -> - ctx_get_type with_opaque_pre adt_id ctx - ^ "." - ^ ctx_get_variant adt_id vid ctx + ctx_get_type adt_id ctx ^ "." ^ ctx_get_variant adt_id vid ctx | _ -> ctx_get_variant adt_id vid ctx) - | None -> ctx_get_struct with_opaque_pre adt_id ctx + | None -> ctx_get_struct adt_id ctx in let use_parentheses = inside && field_values <> [] in if use_parentheses then F.pp_print_string fmt "("; @@ -2489,8 +2555,7 @@ let extract_adt_g_value (* Extract globals in the same way as variables *) let extract_global (ctx : extraction_ctx) (fmt : F.formatter) (id : A.GlobalDeclId.id) : unit = - let with_opaque_pre = ctx.use_opaque_pre in - F.pp_print_string fmt (ctx_get_global with_opaque_pre id ctx) + F.pp_print_string fmt (ctx_get_global id ctx) (** [inside]: see {!extract_ty}. @@ -2626,9 +2691,9 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) if inside then F.pp_print_string fmt "("; (* Open a box for the function call *) F.pp_open_hovbox fmt ctx.indent_incr; - (* Print the function name *) - let with_opaque_pre = ctx.use_opaque_pre in - (* For the function name: the id is not the same depending on whether + (* Print the function name. + + For the function name: the id is not the same depending on whether we call a trait method and a "regular" function (remark: trait method *implementations* are considered as regular functions here; only calls to method of traits which are parameterized in a where @@ -2701,7 +2766,7 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) let fun_id = FromLlbc (FunId (Regular method_id.id), lp_id, rg_id) in - let fun_name = ctx_get_function with_opaque_pre fun_id ctx in + let fun_name = ctx_get_function fun_id ctx in F.pp_print_string fmt fun_name; (* Note that we do not need to print the generics for the trait @@ -2712,12 +2777,32 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_space fmt (); extract_trait_ref ctx fmt TypeDeclId.Set.empty true trait_ref | _ -> - let fun_name = ctx_get_function with_opaque_pre fun_id ctx in + let fun_name = ctx_get_function fun_id ctx in F.pp_print_string fmt fun_name); (* Sanity check: HOL4 doesn't support const generics *) assert (generics.const_generics = [] || !backend <> HOL4); - (* Print the generics *) + (* Print the generics. + + We might need to filter some of the type arguments, if the type + is builtin (for instance, we filter the global allocator type + argument for `Vec::new`). + *) + let generics = + match fun_id with + | FromLlbc (FunId (Regular id), _, _) -> ( + match FunDeclId.Map.find_opt id ctx.funs_filter_type_args_map with + | None -> generics + | Some filter -> + let types = List.combine filter generics.types in + let types = + List.filter_map + (fun (b, ty) -> if b then Some ty else None) + types + in + { generics with types }) + | _ -> generics + in extract_generic_args ctx fmt TypeDeclId.Set.empty generics; (* Print the arguments *) List.iter @@ -3210,7 +3295,7 @@ and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter) (* Open the box for `Array.replicate T N [` *) F.pp_open_hovbox fmt ctx.indent_incr; (* Print the array constructor *) - let cs = ctx_get_struct false (Assumed Array) ctx in + let cs = ctx_get_struct (Assumed Array) ctx in F.pp_print_string fmt cs; (* Print the parameters *) let _, generics = ty_as_adt e_ty in @@ -3563,10 +3648,8 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (kind : decl_kind) (has_decreases_clause : bool) (def : fun_decl) : unit = assert (not def.is_global_decl_body); (* Retrieve the function name *) - let with_opaque_pre = false in let def_name = - ctx_get_local_function with_opaque_pre def.def_id def.loop_id def.back_id - ctx + ctx_get_local_function def.def_id def.loop_id def.back_id ctx in (* Add a break before *) if !backend <> HOL4 || not (decl_is_first_from_group kind) then @@ -3594,19 +3677,13 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) let use_forall = is_opaque_coq && def.signature.generics <> empty_generic_params in - (* Print the qualifier ("assume", etc.). - - if `wrap_opaque_in_sig`: we generate a record of assumed funcions. - TODO: this is obsolete. - *) - (if not (!Config.wrap_opaque_in_sig && (kind = Assumed || kind = Declared)) - then - let qualif = ctx.fmt.fun_decl_kind_to_qualif kind in - match qualif with - | Some qualif -> - F.pp_print_string fmt qualif; - F.pp_print_space fmt () - | None -> ()); + (* Print the qualifier ("assume", etc.). *) + let qualif = ctx.fmt.fun_decl_kind_to_qualif kind in + (match qualif with + | Some qualif -> + F.pp_print_string fmt qualif; + F.pp_print_space fmt () + | None -> ()); F.pp_print_string fmt def_name; F.pp_print_space fmt (); if use_forall then ( @@ -3817,10 +3894,8 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) let extract_fun_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter) (def : fun_decl) : unit = (* Retrieve the definition name *) - let with_opaque_pre = false in let def_name = - ctx_get_local_function with_opaque_pre def.def_id def.loop_id def.back_id - ctx + ctx_get_local_function def.def_id def.loop_id def.back_id ctx in assert (def.signature.generics.const_generics = []); (* Add the type/const gen parameters - note that we need those bindings @@ -4015,10 +4090,9 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) extract_comment fmt [ "[" ^ Print.global_name_to_string global.name ^ "]" ]; F.pp_print_space fmt (); - let with_opaque_pre = false in - let decl_name = ctx_get_global with_opaque_pre global.def_id ctx in + let decl_name = ctx_get_global global.def_id ctx in let body_name = - ctx_get_function with_opaque_pre + ctx_get_function (FromLlbc (Pure.FunId (Regular global.body_id), None, None)) ctx in @@ -4056,73 +4130,263 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) (* Add a break to insert lines between declarations *) F.pp_print_break fmt 0 0 -(** Register the names for one trait method item *) -let extract_trait_decl_method_register_names (ctx : extraction_ctx) - (trait_decl : trait_decl) (name : string) (id : fun_decl_id) : +(** Similar to {!extract_trait_decl_register_names} *) +let extract_trait_decl_register_parent_clause_names (ctx : extraction_ctx) + (trait_decl : trait_decl) + (builtin_info : ExtractBuiltin.builtin_trait_decl_info option) : extraction_ctx = - (* We add one field per required forward/backward function *) - let trans = A.FunDeclId.Map.find id ctx.trans_funs in - - let register_fun ctx f = ctx_add_trait_method trait_decl name f.f ctx in + let generics = trait_decl.generics in + (* Compute the clause names *) + let clause_names = + match builtin_info with + | None -> + List.map + (fun (c : trait_clause) -> + let name = ctx.fmt.trait_parent_clause_name trait_decl c in + (* Add a prefix if necessary *) + let name = + if !Config.record_fields_short_names then name + else ctx.fmt.trait_decl_name trait_decl ^ name + in + (c.clause_id, name)) + generics.trait_clauses + | Some info -> + List.map + (fun (c, name) -> (c.clause_id, name)) + (List.combine generics.trait_clauses info.parent_clauses) + in + (* Register the names *) + List.fold_left + (fun ctx (cid, cname) -> + ctx_add (TraitParentClauseId (trait_decl.def_id, cid)) cname ctx) + ctx clause_names + +(** Similar to {!extract_trait_decl_register_names} *) +let extract_trait_decl_register_constant_names (ctx : extraction_ctx) + (trait_decl : trait_decl) + (builtin_info : ExtractBuiltin.builtin_trait_decl_info option) : + extraction_ctx = + let consts = trait_decl.consts in + (* Compute the names *) + let constant_names = + match builtin_info with + | None -> + List.map + (fun (item_name, _) -> + let name = ctx.fmt.trait_const_name trait_decl item_name in + (* Add a prefix if necessary *) + let name = + if !Config.record_fields_short_names then name + else ctx.fmt.trait_decl_name trait_decl ^ name + in + (item_name, name)) + consts + | Some info -> + let const_map = StringMap.of_list info.consts in + List.map + (fun (item_name, _) -> + (item_name, StringMap.find item_name const_map)) + consts + in + (* Register the names *) + List.fold_left + (fun ctx (item_name, name) -> + ctx_add (TraitItemId (trait_decl.def_id, item_name)) name ctx) + ctx constant_names + +(** Similar to {!extract_trait_decl_register_names} *) +let extract_trait_decl_type_names (ctx : extraction_ctx) + (trait_decl : trait_decl) + (builtin_info : ExtractBuiltin.builtin_trait_decl_info option) : + extraction_ctx = + let types = trait_decl.types in + (* Compute the names *) + let type_names = + match builtin_info with + | None -> + let compute_type_name (item_name : string) : string = + let type_name = ctx.fmt.trait_type_name trait_decl item_name in + if !Config.record_fields_short_names then type_name + else ctx.fmt.trait_decl_name trait_decl ^ type_name + in + let compute_clause_name (item_name : string) (clause : trait_clause) : + TraitClauseId.id * string = + let name = + ctx.fmt.trait_type_clause_name trait_decl item_name clause + in + (* Add a prefix if necessary *) + let name = + if !Config.record_fields_short_names then name + else ctx.fmt.trait_decl_name trait_decl ^ name + in + (clause.clause_id, name) + in + List.map + (fun (item_name, (item_clauses, _)) -> + (* Type name *) + let type_name = compute_type_name item_name in + (* Clause names *) + let clauses = + List.map (compute_clause_name item_name) item_clauses + in + (item_name, (type_name, clauses))) + types + | Some info -> + let type_map = StringMap.of_list info.types in + List.map + (fun (item_name, (item_clauses, _)) -> + let type_name, clauses_info = StringMap.find item_name type_map in + let clauses = + List.map + (fun (clause, clause_name) -> (clause.clause_id, clause_name)) + (List.combine item_clauses clauses_info) + in + (item_name, (type_name, clauses))) + types + in + (* Register the names *) + List.fold_left + (fun ctx (item_name, (type_name, clauses)) -> + let ctx = + ctx_add (TraitItemId (trait_decl.def_id, item_name)) type_name ctx + in + List.fold_left + (fun ctx (clause_id, clause_name) -> + ctx_add + (TraitItemClauseId (trait_decl.def_id, item_name, clause_id)) + clause_name ctx) + ctx clauses) + ctx type_names + +(** Similar to {!extract_trait_decl_register_names} *) +let extract_trait_decl_method_names (ctx : extraction_ctx) + (trait_decl : trait_decl) + (builtin_info : ExtractBuiltin.builtin_trait_decl_info option) : + extraction_ctx = + let required_methods = trait_decl.required_methods in + (* Compute the names *) + let method_names = + (* We add one field per required forward/backward function *) + let get_funs_for_id (id : fun_decl_id) : fun_decl list = + let trans : pure_fun_translation = FunDeclId.Map.find id ctx.trans_funs in + List.map (fun f -> f.f) (trans.fwd :: trans.backs) + in + match builtin_info with + | None -> + (* We add one field per required forward/backward function *) + let compute_item_names (item_name : string) (id : fun_decl_id) : + string * (RegionGroupId.id option * string) list = + let compute_fun_name (f : fun_decl) : RegionGroupId.id option * string + = + (* We do something special: we use the base name but remove everything + but the crate (because [get_name] removes it) and the last ident. + This allows us to reuse the [ctx_compute_fun_decl] function. + *) + let basename : name = + match (f.basename : name) with + | Ident crate :: name -> + Ident crate :: [ Collections.List.last name ] + | _ -> raise (Failure "Unexpected") + in + let f = { f with basename } in + let trans = A.FunDeclId.Map.find f.def_id ctx.trans_funs in + let name = ctx_compute_fun_name trans f ctx in + (* Add a prefix if necessary *) + let name = + if !Config.record_fields_short_names then name + else ctx.fmt.trait_decl_name trait_decl ^ "_" ^ name + in + (f.back_id, name) + in + let funs = get_funs_for_id id in + (item_name, List.map compute_fun_name funs) + in + List.map (fun (name, id) -> compute_item_names name id) required_methods + | Some info -> + let funs_map = StringMap.of_list info.funs in + List.map + (fun (item_name, fun_id) -> + let info = StringMap.find item_name funs_map in + let trans_funs = get_funs_for_id fun_id in + let rg_with_name_list = + List.map + (fun (trans_fun : fun_decl) -> + List.find (fun (rg, _) -> rg = trans_fun.back_id) info) + trans_funs + in + (item_name, rg_with_name_list)) + required_methods + in (* Register the names *) - let funs = trans.fwd :: trans.backs in - List.fold_left register_fun ctx funs + List.fold_left + (fun ctx (item_name, funs) -> + (* We add one field per required forward/backward function *) + List.fold_left + (fun ctx (rg, fun_name) -> + ctx_add + (TraitMethodId (trait_decl.def_id, item_name, rg)) + fun_name ctx) + ctx funs) + ctx method_names (** Similar to {!extract_type_decl_register_names} *) let extract_trait_decl_register_names (ctx : extraction_ctx) (trait_decl : trait_decl) : extraction_ctx = - let { - def_id = _; - name = _; - generics; - preds = _; - all_trait_clauses = _; - consts; - types; - required_methods; - provided_methods = _; - } = - trait_decl + (* Lookup the information if this is a builtin trait *) + let open ExtractBuiltin in + let sname = name_to_simple_name trait_decl.name in + let builtin_info = + SimpleNameMap.find_opt sname (builtin_trait_decls_map ()) + in + let ctx = + let trait_name = + match builtin_info with + | None -> ctx.fmt.trait_decl_name trait_decl + | Some info -> info.extract_name + in + ctx_add (TraitDeclId trait_decl.def_id) trait_name ctx in - let ctx = ctx_add_trait_decl trait_decl ctx in (* Parent clauses *) let ctx = - List.fold_left - (fun ctx clause -> ctx_add_trait_parent_clause trait_decl clause ctx) - ctx generics.trait_clauses + extract_trait_decl_register_parent_clause_names ctx trait_decl builtin_info in (* Constants *) let ctx = - List.fold_left - (fun ctx (name, (_, _)) -> ctx_add_trait_const trait_decl name ctx) - ctx consts + extract_trait_decl_register_constant_names ctx trait_decl builtin_info in (* Types *) - let ctx = - List.fold_left - (fun ctx (name, (clauses, _)) -> - let ctx = ctx_add_trait_type trait_decl name ctx in - List.fold_left - (fun ctx clause -> - ctx_add_trait_type_clause trait_decl name clause ctx) - ctx clauses) - ctx types - in + let ctx = extract_trait_decl_type_names ctx trait_decl builtin_info in (* Required methods *) - List.fold_left - (fun ctx (name, id) -> - (* We add one field per required forward/backward function *) - extract_trait_decl_method_register_names ctx trait_decl name id) - ctx required_methods + let ctx = extract_trait_decl_method_names ctx trait_decl builtin_info in + ctx (** Similar to {!extract_type_decl_register_names} *) let extract_trait_impl_register_names (ctx : extraction_ctx) (trait_impl : trait_impl) : extraction_ctx = + let trait_decl = + TraitDeclId.Map.find trait_impl.impl_trait.trait_decl_id + ctx.trans_trait_decls + in + (* Check if the trait implementation is builtin *) + let builtin_info = + let open ExtractBuiltin in + let type_sname = name_to_simple_name trait_impl.name in + let trait_sname = name_to_simple_name trait_decl.name in + SimpleNamePairMap.find_opt (type_sname, trait_sname) + (builtin_trait_impls_map ()) + in + (* For now we do not support overriding provided methods *) assert (trait_impl.provided_methods = []); (* Everything is taken care of by {!extract_trait_decl_register_names} *but* the name of the implementation itself *) - ctx_add_trait_impl trait_impl ctx + (* Compute the name *) + let name = + match builtin_info with + | None -> ctx.fmt.trait_impl_name trait_decl trait_impl + | Some name -> name + in + ctx_add (TraitImplId trait_impl.def_id) name ctx (** Small helper. @@ -4198,8 +4462,7 @@ let extract_trait_decl_method_items (ctx : extraction_ctx) (fmt : F.formatter) let extract_trait_decl (ctx : extraction_ctx) (fmt : F.formatter) (decl : trait_decl) : unit = (* Retrieve the trait name *) - let with_opaque_pre = false in - let decl_name = ctx_get_trait_decl with_opaque_pre decl.def_id ctx in + let decl_name = ctx_get_trait_decl decl.def_id ctx in (* Add a break before *) F.pp_print_break fmt 0 0; (* Print a comment to link the extracted type to its original rust definition *) @@ -4344,7 +4607,7 @@ let extract_trait_impl_method_items (ctx : extraction_ctx) (fmt : F.formatter) if use_forall then F.pp_print_string fmt ","; (* Extract the function call *) F.pp_print_space fmt (); - let id = ctx_get_local_function false f.def_id None f.back_id ctx in + let id = ctx_get_local_function f.def_id None f.back_id ctx in F.pp_print_string fmt id; let all_generics = let i_tys, i_cgs, i_tcs = impl_generics in @@ -4363,9 +4626,9 @@ let extract_trait_impl_method_items (ctx : extraction_ctx) (fmt : F.formatter) (** Extract a trait implementation *) let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter) (impl : trait_impl) : unit = + log#ldebug (lazy ("extract_trait_impl: " ^ Names.name_to_string impl.name)); (* Retrieve the impl name *) - let with_opaque_pre = false in - let impl_name = ctx_get_trait_impl with_opaque_pre impl.def_id ctx in + let impl_name = ctx_get_trait_impl impl.def_id ctx in (* Add a break before *) F.pp_print_break fmt 0 0; (* Print a comment to link the extracted type to its original rust definition *) @@ -4389,9 +4652,11 @@ let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter) (* `let (....) : Trait ... =` *) (* Open the box for the name + generics *) F.pp_open_hovbox fmt ctx.indent_incr; - let qualif = Option.get (ctx.fmt.fun_decl_kind_to_qualif SingleNonRec) in - F.pp_print_string fmt qualif; - F.pp_print_space fmt (); + (match ctx.fmt.fun_decl_kind_to_qualif SingleNonRec with + | Some qualif -> + F.pp_print_string fmt qualif; + F.pp_print_space fmt () + | None -> ()); F.pp_print_string fmt impl_name; (* Print the generics *) @@ -4439,7 +4704,7 @@ let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter) let item_name = ctx_get_trait_const trait_decl_id name ctx in let ty () = F.pp_print_space fmt (); - F.pp_print_string fmt (ctx_get_global false id ctx) + F.pp_print_string fmt (ctx_get_global id ctx) in extract_trait_impl_item ctx fmt item_name ty) @@ -4525,12 +4790,8 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "assert_norm"; F.pp_print_space fmt (); F.pp_print_string fmt "("; - (* Note that if the function is opaque, the unit test will fail - because the normalizer will get stuck *) - let with_opaque_pre = ctx.use_opaque_pre in let fun_name = - ctx_get_local_function with_opaque_pre def.def_id def.loop_id - def.back_id ctx + ctx_get_local_function def.def_id def.loop_id def.back_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( @@ -4545,12 +4806,8 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "Check"; F.pp_print_space fmt (); F.pp_print_string fmt "("; - (* Note that if the function is opaque, the unit test will fail - because the normalizer will get stuck *) - let with_opaque_pre = ctx.use_opaque_pre in let fun_name = - ctx_get_local_function with_opaque_pre def.def_id def.loop_id - def.back_id ctx + ctx_get_local_function def.def_id def.loop_id def.back_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( @@ -4562,12 +4819,8 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "#assert"; F.pp_print_space fmt (); F.pp_print_string fmt "("; - (* Note that if the function is opaque, the unit test will fail - because the normalizer will get stuck *) - let with_opaque_pre = ctx.use_opaque_pre in let fun_name = - ctx_get_local_function with_opaque_pre def.def_id def.loop_id - def.back_id ctx + ctx_get_local_function def.def_id def.loop_id def.back_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( @@ -4581,12 +4834,8 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) | HOL4 -> F.pp_print_string fmt "val _ = assert_return ("; F.pp_print_string fmt "“"; - (* Note that if the function is opaque, the unit test will fail - because the normalizer will get stuck *) - let with_opaque_pre = ctx.use_opaque_pre in let fun_name = - ctx_get_local_function with_opaque_pre def.def_id def.loop_id - def.back_id ctx + ctx_get_local_function def.def_id def.loop_id def.back_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( |