summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compiler/Extract.ml115
-rw-r--r--compiler/ExtractBase.ml1
-rw-r--r--compiler/ExtractBuiltin.ml82
-rw-r--r--compiler/Translate.ml158
4 files changed, 234 insertions, 122 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 6a306592..ddc02fa7 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -1277,7 +1277,10 @@ and extract_trait_decl_ref (ctx : extraction_ctx) (fmt : F.formatter)
let name = ctx_get_trait_decl is_opaque tr.trait_decl_id ctx in
if use_brackets then F.pp_print_string fmt "(";
F.pp_print_string fmt name;
- extract_generic_args ctx fmt no_params_tys tr.decl_generics;
+ (* There is something subtle here: the trait obligations for the implemented
+ trait are put inside the parent clauses, so we must ignore them here *)
+ let generics = { tr.decl_generics with trait_refs = [] } in
+ extract_generic_args ctx fmt no_params_tys generics;
if use_brackets then F.pp_print_string fmt ")"
and extract_generic_args (ctx : extraction_ctx) (fmt : F.formatter)
@@ -1349,7 +1352,7 @@ let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) :
let def_name =
match info with
| None -> ctx.fmt.type_name def.name
- | Some info -> info.rust_name
+ | Some info -> String.concat "." info.rust_name
in
let is_opaque = def.kind = Opaque in
let ctx = ctx_add is_opaque (TypeId (AdtId def.def_id)) def_name ctx in
@@ -1363,7 +1366,7 @@ let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) :
(* Compute the names *)
let field_names, cons_name =
match info with
- | None ->
+ | None | Some { body_info = None; _ } ->
let field_names =
FieldId.mapi
(fun fid (field : field) ->
@@ -1379,7 +1382,11 @@ let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) :
(List.combine fields field_names)
in
(field_names, cons_name)
- | _ -> raise (Failure "Invalid builtin information")
+ | Some info ->
+ raise
+ (Failure
+ ("Invalid builtin information: "
+ ^ show_builtin_type_info info))
in
(* Add the fields *)
let ctx =
@@ -2365,33 +2372,70 @@ let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx)
let extract_fun_decl_register_names (ctx : extraction_ctx)
(has_decreases_clause : fun_decl -> bool) (def : pure_fun_translation) :
extraction_ctx =
- let fwd = def.fwd in
- let backs = def.backs in
- (* Register the decrease clauses, if necessary *)
- let register_decreases ctx def =
- if has_decreases_clause def then
- (* Add the termination measure *)
- let ctx = ctx_add_termination_measure def ctx in
- (* Add the decreases proof for Lean only *)
- match !Config.backend with
- | Coq | FStar -> ctx
- | HOL4 -> raise (Failure "Unexpected")
- | Lean -> ctx_add_decreases_proof def ctx
- else ctx
- in
- let ctx = List.fold_left register_decreases ctx (fwd.f :: fwd.loops) in
- let register_fun ctx f = ctx_add_fun_decl def f ctx in
- let register_funs ctx fl = List.fold_left register_fun ctx fl in
- (* Register the names of the forward functions *)
- let ctx =
- if def.keep_fwd then register_funs ctx (fwd.f :: fwd.loops) else ctx
- in
- (* Register the names of the backward functions *)
- List.fold_left
- (fun ctx { f = back; loops = loop_backs } ->
- let ctx = register_fun ctx back in
- register_funs ctx loop_backs)
- ctx backs
+ (* Ignore the trait methods **declarations** (rem.: we do not ignore the trait
+ method implementations): we do not need to refer to them directly. We will
+ only use their type for the fields of the records we generate for the trait
+ declarations *)
+ match def.fwd.f.kind with
+ | TraitMethodDecl _ -> ctx
+ | _ -> (
+ (* Check if the function is builtin *)
+ let builtin =
+ let open ExtractBuiltin in
+ let funs_map = builtin_funs_map () in
+ let sname = name_to_simple_name def.fwd.f.basename in
+ SimpleNameMap.find_opt sname funs_map
+ in
+ (* Use the builtin names if necessary *)
+ match builtin with
+ | Some (_filter, info) ->
+ let backs = List.map (fun f -> f.f) def.backs in
+ let funs = if def.keep_fwd then def.fwd.f :: backs else backs in
+ let is_opaque = false in
+ List.fold_left
+ (fun ctx (f : fun_decl) ->
+ let open ExtractBuiltin in
+ let fun_id =
+ (Pure.FunId (Regular f.def_id), f.loop_id, f.back_id)
+ in
+ let fun_name =
+ (List.find
+ (fun (x : builtin_fun_info) -> x.rg = f.back_id)
+ info)
+ .extract_name
+ in
+ ctx_add is_opaque (FunId (FromLlbc fun_id)) fun_name ctx)
+ ctx funs
+ | None ->
+ let fwd = def.fwd in
+ let backs = def.backs in
+ (* Register the decrease clauses, if necessary *)
+ let register_decreases ctx def =
+ if has_decreases_clause def then
+ (* Add the termination measure *)
+ let ctx = ctx_add_termination_measure def ctx in
+ (* Add the decreases proof for Lean only *)
+ match !Config.backend with
+ | Coq | FStar -> ctx
+ | HOL4 -> raise (Failure "Unexpected")
+ | Lean -> ctx_add_decreases_proof def ctx
+ else ctx
+ in
+ let ctx =
+ List.fold_left register_decreases ctx (fwd.f :: fwd.loops)
+ in
+ let register_fun ctx f = ctx_add_fun_decl def f ctx in
+ let register_funs ctx fl = List.fold_left register_fun ctx fl in
+ (* Register the names of the forward functions *)
+ let ctx =
+ if def.keep_fwd then register_funs ctx (fwd.f :: fwd.loops) else ctx
+ in
+ (* Register the names of the backward functions *)
+ List.fold_left
+ (fun ctx { f = back; loops = loop_backs } ->
+ let ctx = register_fun ctx back in
+ register_funs ctx loop_backs)
+ ctx backs)
(** Simply add the global name to the context. *)
let extract_global_decl_register_names (ctx : extraction_ctx)
@@ -4539,6 +4583,7 @@ let extract_trait_impl_method_items (ctx : extraction_ctx) (fmt : F.formatter)
(** Extract a trait implementation *)
let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter)
(impl : trait_impl) : unit =
+ log#ldebug (lazy ("extract_trait_impl: " ^ Names.name_to_string impl.name));
(* Retrieve the impl name *)
let with_opaque_pre = false in
let impl_name = ctx_get_trait_impl with_opaque_pre impl.def_id ctx in
@@ -4565,9 +4610,11 @@ let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter)
(* `let (....) : Trait ... =` *)
(* Open the box for the name + generics *)
F.pp_open_hovbox fmt ctx.indent_incr;
- let qualif = Option.get (ctx.fmt.fun_decl_kind_to_qualif SingleNonRec) in
- F.pp_print_string fmt qualif;
- F.pp_print_space fmt ();
+ (match ctx.fmt.fun_decl_kind_to_qualif SingleNonRec with
+ | Some qualif ->
+ F.pp_print_string fmt qualif;
+ F.pp_print_space fmt ()
+ | None -> ());
F.pp_print_string fmt impl_name;
(* Print the generics *)
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index ea5fe8d3..22b017e5 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -1249,6 +1249,7 @@ let ctx_compute_fun_name (trans_group : pure_fun_translation) (def : fun_decl)
ctx.fmt.fun_name def.basename def.num_loops def.loop_id num_rgs rg_info
(keep_fwd, num_backs)
+(* TODO: move to Extract *)
let ctx_add_fun_decl (trans_group : pure_fun_translation) (def : fun_decl)
(ctx : extraction_ctx) : extraction_ctx =
(* Sanity check: the function should not be a global body - those are handled
diff --git a/compiler/ExtractBuiltin.ml b/compiler/ExtractBuiltin.ml
index 3b4afff6..0d591028 100644
--- a/compiler/ExtractBuiltin.ml
+++ b/compiler/ExtractBuiltin.ml
@@ -78,21 +78,24 @@ let builtin_globals_map : string SimpleNameMap.t =
(List.map (fun (x, y) -> (string_to_simple_name x, y)) builtin_globals)
type builtin_variant_info = { fields : (string * string) list }
+[@@deriving show]
type builtin_enum_variant_info = {
rust_variant_name : string;
extract_variant_name : string;
fields : string list option;
}
+[@@deriving show]
type builtin_type_body_info =
| Struct of string * string list
(* The constructor name and the map for the field names *)
| Enum of builtin_enum_variant_info list
(* For every variant, a map for the field names *)
+[@@deriving show]
type builtin_type_info = {
- rust_name : string;
+ rust_name : string list;
extract_name : string;
keep_params : bool list option;
(** We might want to filter some of the type parameters.
@@ -102,6 +105,7 @@ type builtin_type_info = {
*)
body_info : builtin_type_body_info option;
}
+[@@deriving show]
(** The assumed types.
@@ -113,7 +117,7 @@ let builtin_types () : builtin_type_info list =
[
(* Alloc *)
{
- rust_name = "alloc::alloc::Global";
+ rust_name = [ "alloc"; "alloc"; "Global" ];
extract_name =
(match !backend with
| Lean -> "AllocGlobal"
@@ -123,7 +127,7 @@ let builtin_types () : builtin_type_info list =
};
(* Vec *)
{
- rust_name = "alloc::vec::Vec";
+ rust_name = [ "alloc"; "vec"; "Vec" ];
extract_name =
(match !backend with Lean -> "Vec" | Coq | FStar | HOL4 -> "vec");
keep_params = Some [ true; false ];
@@ -131,7 +135,7 @@ let builtin_types () : builtin_type_info list =
};
(* Option *)
{
- rust_name = "core::option::Option";
+ rust_name = [ "core"; "option"; "Option" ];
extract_name =
(match !backend with
| Lean -> "Option"
@@ -163,7 +167,7 @@ let builtin_types () : builtin_type_info list =
};
(* Range *)
{
- rust_name = "core::ops::range::Range";
+ rust_name = [ "core"; "ops"; "range"; "Range" ];
extract_name =
(match !backend with Lean -> "Range" | Coq | FStar | HOL4 -> "range");
keep_params = None;
@@ -180,9 +184,7 @@ let builtin_types () : builtin_type_info list =
let mk_builtin_types_map () =
SimpleNameMap.of_list
- (List.map
- (fun info -> (string_to_simple_name info.rust_name, info))
- (builtin_types ()))
+ (List.map (fun info -> (info.rust_name, info)) (builtin_types ()))
let builtin_types_map = mk_memoized mk_builtin_types_map
@@ -190,6 +192,7 @@ type builtin_fun_info = {
rg : Types.RegionGroupId.id option;
extract_name : string;
}
+[@@deriving show]
(** The assumed functions.
@@ -197,10 +200,12 @@ type builtin_fun_info = {
parameters. For instance, in the case of the `Vec` functions, there is
a type parameter for the allocator to use, which we want to filter.
*)
-let builtin_funs () : (string * bool list option * builtin_fun_info list) list =
+let builtin_funs () :
+ (string list * bool list option * builtin_fun_info list) list =
let rg0 = Some Types.RegionGroupId.zero in
+ (* TODO: fix the names below *)
[
- ( "core::mem::replace",
+ ( [ "core::mem::replace" ],
None,
[
{
@@ -218,7 +223,7 @@ let builtin_funs () : (string * bool list option * builtin_fun_info list) list =
| Lean -> "mem.replace_back");
};
] );
- ( "alloc::vec::Vec::new",
+ ( [ "alloc::vec::Vec::new" ],
Some [ true; false ],
[
{
@@ -236,7 +241,7 @@ let builtin_funs () : (string * bool list option * builtin_fun_info list) list =
| Lean -> "Vec.new_back");
};
] );
- ( "alloc::vec::Vec::push",
+ ( [ "alloc::vec::Vec::push" ],
Some [ true; false ],
[
(* The forward function shouldn't be used *)
@@ -255,7 +260,7 @@ let builtin_funs () : (string * bool list option * builtin_fun_info list) list =
| Lean -> "Vec.push");
};
] );
- ( "alloc::vec::Vec::insert",
+ ( [ "alloc::vec::Vec::insert" ],
Some [ true; false ],
[
(* The forward function shouldn't be used *)
@@ -274,7 +279,7 @@ let builtin_funs () : (string * bool list option * builtin_fun_info list) list =
| Lean -> "Vec.insert");
};
] );
- ( "alloc::vec::Vec::len",
+ ( [ "alloc::vec::Vec::len" ],
Some [ true; false ],
[
{
@@ -285,7 +290,7 @@ let builtin_funs () : (string * bool list option * builtin_fun_info list) list =
| Lean -> "Vec.len");
};
] );
- ( "alloc::vec::Vec::index",
+ ( [ "alloc::vec::Vec::index" ],
Some [ true; false ],
[
{
@@ -304,7 +309,7 @@ let builtin_funs () : (string * bool list option * builtin_fun_info list) list =
| Lean -> "Vec.index_shared_back");
};
] );
- ( "alloc::vec::Vec::index_mut",
+ ( [ "alloc::vec::Vec::index_mut" ],
Some [ true; false ],
[
{
@@ -323,16 +328,52 @@ let builtin_funs () : (string * bool list option * builtin_fun_info list) list =
| Lean -> "Vec.index_mut_back");
};
] );
+ ( [ "alloc"; "boxed"; "Box"; "deref" ],
+ Some [ true; false ],
+ [
+ {
+ rg = None;
+ extract_name =
+ (match !backend with
+ | FStar | Coq | HOL4 -> "alloc_boxed_box_deref"
+ | Lean -> "alloc.boxed.Box.deref");
+ };
+ (* The backward function shouldn't be used *)
+ {
+ rg = rg0;
+ extract_name =
+ (match !backend with
+ | FStar | Coq | HOL4 -> "alloc_boxed_box_deref_back"
+ | Lean -> "alloc.boxed.Box.deref_back");
+ };
+ ] );
+ ( [ "alloc"; "boxed"; "Box"; "deref_mut" ],
+ Some [ true; false ],
+ [
+ {
+ rg = None;
+ extract_name =
+ (match !backend with
+ | FStar | Coq | HOL4 -> "alloc_boxed_box_deref_mut"
+ | Lean -> "alloc.boxed.Box.deref_mut");
+ };
+ {
+ rg = rg0;
+ extract_name =
+ (match !backend with
+ | FStar | Coq | HOL4 -> "alloc_boxed_box_deref_mut_back"
+ | Lean -> "alloc.boxed.Box.deref_mut_back");
+ };
+ ] );
]
let mk_builtin_funs_map () =
SimpleNameMap.of_list
(List.map
- (fun (name, filter, info) ->
- (string_to_simple_name name, (filter, info)))
+ (fun (name, filter, info) -> (name, (filter, info)))
(builtin_funs ()))
-let builtin_funs_map () = mk_memoized mk_builtin_funs_map
+let builtin_funs_map = mk_memoized mk_builtin_funs_map
type builtin_trait_decl_info = {
rust_name : string;
@@ -346,6 +387,7 @@ type builtin_trait_decl_info = {
- a list of clauses *)
funs : (string * (Types.RegionGroupId.id option * string) list) list;
}
+[@@deriving show]
let builtin_trait_decls_info () =
let rg0 = Some Types.RegionGroupId.zero in
@@ -389,7 +431,7 @@ let builtin_trait_decls_info () =
[
(match !backend with
| Coq | FStar | HOL4 -> "deref_inst"
- | Lean -> "DerefInst");
+ | Lean -> "derefInst");
];
consts = [];
types = [];
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 0871a305..95252b61 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -654,78 +654,100 @@ let export_functions_group_scc (fmt : Format.formatter) (config : gen_config)
*)
let export_functions_group (fmt : Format.formatter) (config : gen_config)
(ctx : gen_ctx) (pure_ls : pure_fun_translation list) : unit =
- (* Utility to check a function has a decrease clause *)
- let has_decreases_clause (def : Pure.fun_decl) : bool =
- PureUtils.FunLoopIdSet.mem (def.def_id, def.loop_id)
- ctx.functions_with_decreases_clause
+ (* Check if the definition are builtin - if yes they must be ignored.
+ Note that if one definition in the group is builtin, then all the
+ definitions must be builtin *)
+ let builtin =
+ let open ExtractBuiltin in
+ let funs_map = builtin_funs_map () in
+ List.map
+ (fun (trans : pure_fun_translation) ->
+ let sname = name_to_simple_name trans.fwd.f.basename in
+ SimpleNameMap.find_opt sname funs_map <> None)
+ pure_ls
in
- (* Extract the decrease clauses template bodies *)
- if config.extract_template_decreases_clauses then
- List.iter
- (fun { fwd; _ } ->
- (* We only generate decreases clauses for the forward functions, because
- the termination argument should only depend on the forward inputs.
- The backward functions thus use the same decreases clauses as the
- forward function.
-
- Rem.: we might filter backward functions in {!PureMicroPasses}, but
- we don't remove forward functions. Instead, we remember if we should
- filter those functions at extraction time with a boolean (see the
- type of the [pure_ls] input parameter).
- *)
- let extract_decrease decl =
- let has_decr_clause = has_decreases_clause decl in
- if has_decr_clause then
- match !Config.backend with
- | Lean ->
- Extract.extract_template_lean_termination_and_decreasing ctx fmt
- decl
- | FStar ->
- Extract.extract_template_fstar_decreases_clause ctx fmt decl
- | Coq ->
- raise (Failure "Coq doesn't have decreases/termination clauses")
- | HOL4 ->
- raise
- (Failure "HOL4 doesn't have decreases/termination clauses")
- in
- extract_decrease fwd.f;
- List.iter extract_decrease fwd.loops)
- pure_ls;
-
- (* Concatenate the function definitions, filtering the useless forward
- * functions. *)
- let decls =
- List.concat
- (List.map
- (fun { keep_fwd; fwd; backs } ->
- let fwd = if keep_fwd then List.append fwd.loops [ fwd.f ] else [] in
- let backs : Pure.fun_decl list =
- List.concat
- (List.map (fun back -> List.append back.loops [ back.f ]) backs)
- in
- List.append fwd backs)
- pure_ls)
- in
+ if List.exists (fun b -> b) builtin then
+ (* Sanity check *)
+ assert (List.for_all (fun b -> b) builtin)
+ else
+ (* Utility to check a function has a decrease clause *)
+ let has_decreases_clause (def : Pure.fun_decl) : bool =
+ PureUtils.FunLoopIdSet.mem (def.def_id, def.loop_id)
+ ctx.functions_with_decreases_clause
+ in
+
+ (* Extract the decrease clauses template bodies *)
+ if config.extract_template_decreases_clauses then
+ List.iter
+ (fun { fwd; _ } ->
+ (* We only generate decreases clauses for the forward functions, because
+ the termination argument should only depend on the forward inputs.
+ The backward functions thus use the same decreases clauses as the
+ forward function.
+
+ Rem.: we might filter backward functions in {!PureMicroPasses}, but
+ we don't remove forward functions. Instead, we remember if we should
+ filter those functions at extraction time with a boolean (see the
+ type of the [pure_ls] input parameter).
+ *)
+ let extract_decrease decl =
+ let has_decr_clause = has_decreases_clause decl in
+ if has_decr_clause then
+ match !Config.backend with
+ | Lean ->
+ Extract.extract_template_lean_termination_and_decreasing ctx
+ fmt decl
+ | FStar ->
+ Extract.extract_template_fstar_decreases_clause ctx fmt decl
+ | Coq ->
+ raise
+ (Failure "Coq doesn't have decreases/termination clauses")
+ | HOL4 ->
+ raise
+ (Failure "HOL4 doesn't have decreases/termination clauses")
+ in
+ extract_decrease fwd.f;
+ List.iter extract_decrease fwd.loops)
+ pure_ls;
+
+ (* Concatenate the function definitions, filtering the useless forward
+ * functions. *)
+ let decls =
+ List.concat
+ (List.map
+ (fun { keep_fwd; fwd; backs } ->
+ let fwd =
+ if keep_fwd then List.append fwd.loops [ fwd.f ] else []
+ in
+ let backs : Pure.fun_decl list =
+ List.concat
+ (List.map
+ (fun back -> List.append back.loops [ back.f ])
+ backs)
+ in
+ List.append fwd backs)
+ pure_ls)
+ in
- (* Extract the function definitions *)
- (if config.extract_fun_decls then
- (* Group the mutually recursive definitions *)
- let subgroups = ReorderDecls.group_reorder_fun_decls decls in
+ (* Extract the function definitions *)
+ (if config.extract_fun_decls then
+ (* Group the mutually recursive definitions *)
+ let subgroups = ReorderDecls.group_reorder_fun_decls decls in
- (* Extract the subgroups *)
- let export_subgroup (is_rec : bool) (decls : Pure.fun_decl list) : unit =
- export_functions_group_scc fmt config ctx is_rec decls
- in
- List.iter (fun (is_rec, decls) -> export_subgroup is_rec decls) subgroups);
-
- (* Insert unit tests if necessary *)
- if config.test_trans_unit_functions then
- List.iter
- (fun trans ->
- if trans.keep_fwd then
- Extract.extract_unit_test_if_unit_fun ctx fmt trans.fwd.f)
- pure_ls
+ (* Extract the subgroups *)
+ let export_subgroup (is_rec : bool) (decls : Pure.fun_decl list) : unit =
+ export_functions_group_scc fmt config ctx is_rec decls
+ in
+ List.iter (fun (is_rec, decls) -> export_subgroup is_rec decls) subgroups);
+
+ (* Insert unit tests if necessary *)
+ if config.test_trans_unit_functions then
+ List.iter
+ (fun trans ->
+ if trans.keep_fwd then
+ Extract.extract_unit_test_if_unit_fun ctx fmt trans.fwd.f)
+ pure_ls
(** Export a trait declaration. *)
let export_trait_decl (fmt : Format.formatter) (_config : gen_config)