summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-09-17 04:43:01 +0200
committerSon Ho2023-09-17 04:43:01 +0200
commit296f97bb6a768ffd85f35db2762f2db4f7a357ad (patch)
tree2d83d49ae85deb48527fceda1d6ad8e8c6af4166
parentf2928eaa854688b679f7e504c036866ee7664fe5 (diff)
Make progress on correctly extracting trait method calls
Diffstat (limited to '')
-rw-r--r--compiler/Extract.ml110
-rw-r--r--compiler/InterpreterStatements.ml33
-rw-r--r--compiler/PureUtils.ml17
3 files changed, 121 insertions, 39 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 7da5610e..e841082b 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -1760,13 +1760,11 @@ let extract_generic_params (ctx : extraction_ctx) (fmt : F.formatter)
let all_params = List.concat [ type_params; cg_params; trait_clauses ] in
(* HOL4 doesn't support const generics *)
assert (cg_params = [] || !backend <> HOL4);
- let left_bracket () =
- if as_implicits then F.pp_print_string fmt "{"
- else F.pp_print_string fmt "("
+ let left_bracket (implicit : bool) =
+ if implicit then F.pp_print_string fmt "{" else F.pp_print_string fmt "("
in
- let right_bracket () =
- if as_implicits then F.pp_print_string fmt "}"
- else F.pp_print_string fmt ")"
+ let right_bracket (implicit : bool) =
+ if implicit then F.pp_print_string fmt "}" else F.pp_print_string fmt ")"
in
let insert_req_space () =
match space with
@@ -1782,7 +1780,7 @@ let extract_generic_params (ctx : extraction_ctx) (fmt : F.formatter)
insert_req_space ();
F.pp_print_string fmt "forall");
(* Small helper - we may need to split the parameters *)
- let print_generics (type_params : string list)
+ let print_generics (as_implicits : bool) (type_params : string list)
(const_generics : const_generic_var list)
(trait_clauses : trait_clause list) : unit =
(* Note that in HOL4 we don't print the type parameters. *)
@@ -1791,7 +1789,7 @@ let extract_generic_params (ctx : extraction_ctx) (fmt : F.formatter)
if type_params <> [] then (
insert_req_space ();
(* ( *)
- left_bracket ();
+ left_bracket as_implicits;
List.iter
(fun s ->
F.pp_print_string fmt s;
@@ -1801,13 +1799,13 @@ let extract_generic_params (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_space fmt ();
F.pp_print_string fmt (type_keyword ());
(* ) *)
- right_bracket ());
+ right_bracket as_implicits);
(* Print the const generic parameters *)
List.iter
(fun (var : const_generic_var) ->
insert_req_space ();
(* ( *)
- left_bracket ();
+ left_bracket as_implicits;
let n = ctx_get_const_generic_var var.index ctx in
F.pp_print_string fmt n;
F.pp_print_space fmt ();
@@ -1815,14 +1813,14 @@ let extract_generic_params (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_space fmt ();
extract_literal_type ctx fmt var.ty;
(* ) *)
- right_bracket ())
+ right_bracket as_implicits)
const_generics);
(* Print the trait clauses *)
List.iter
(fun (clause : trait_clause) ->
insert_req_space ();
(* ( *)
- left_bracket ();
+ left_bracket as_implicits;
let n = ctx_get_local_trait_clause clause.clause_id ctx in
F.pp_print_string fmt n;
F.pp_print_space fmt ();
@@ -1830,7 +1828,7 @@ let extract_generic_params (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_space fmt ();
extract_trait_clause_type ctx fmt no_params_tys clause;
(* ) *)
- right_bracket ())
+ right_bracket as_implicits)
trait_clauses
in
(* If we extract the generics for a provided method for a trait declaration
@@ -1841,7 +1839,7 @@ let extract_generic_params (ctx : extraction_ctx) (fmt : F.formatter)
*)
match trait_decl with
| None ->
- print_generics type_params generics.const_generics
+ print_generics as_implicits type_params generics.const_generics
generics.trait_clauses
| Some trait_decl ->
(* Split the generics between the generics specific to the trait decl
@@ -1858,8 +1856,10 @@ let extract_generic_params (ctx : extraction_ctx) (fmt : F.formatter)
split_at generics.trait_clauses
(length trait_decl.generics.trait_clauses)
in
- (* Extract the trait decl generics *)
- print_generics dtype_params dcgs dtrait_clauses;
+ (* Extract the trait decl generics - note that we can always deduce
+ those parameters from the trait self clause: for this reason
+ they are always implicit *)
+ print_generics true dtype_params dcgs dtrait_clauses;
(* Extract the trait self clause *)
let params =
concat
@@ -1876,7 +1876,7 @@ let extract_generic_params (ctx : extraction_ctx) (fmt : F.formatter)
in
extract_trait_self_clause insert_req_space ctx fmt trait_decl params;
(* Extract the method generics *)
- print_generics mtype_params mcgs mtrait_clauses)
+ print_generics as_implicits mtype_params mcgs mtrait_clauses)
(** Extract a type declaration.
@@ -2646,20 +2646,78 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
allow collisions between trait item names and some other names,
while we do not allow collisions between function names.
- Remark: calls to trait methods when the implementation is known
- (i.e., when we do not use a trait parameter) are desugared to regular
- function calls.
+ # Impl trait refs:
+ ==================
+ When the trait ref refers to an impl, in
+ [InterpreterStatement.eval_transparent_function_call_symbolic] we
+ replace the call to the trait impl method to a call to the function
+ which implements the trait method (that is, we "forget" that we
+ called a trait method, and treat it as a regular function call).
+
+ # Provided trait methods:
+ =========================
+ Calls to provided trait methods also have a special treatment.
+ For now, we do not allow overriding provided trait methods (methods
+ for which a default implementation is provided in the trait declaration).
+ Whenever we translate a provided trait method, we translate it once as
+ a function which takes a trait ref as input. We have to handle this
+ case below.
+
+ With an example, if in Rust we write:
+ {[
+ fn Foo {
+ fn f(&self) -> u32; // Required
+ fn ret_true(&self) -> bool { true } // Provided
+ }
+ ]}
+
+ We generate:
+ {[
+ structure Foo (Self : Type) = {
+ f : Self -> result u32
+ }
+
+ let ret_true (Self : Type) (self_clause : Foo Self) (self : Self) : result bool =
+ true
+ ]}
*)
(match fun_id with
| FromLlbc
(TraitMethod (trait_ref, method_name, _fun_decl_id), lp_id, rg_id) ->
- assert (lp_id = None);
- extract_trait_ref ctx fmt TypeDeclId.Set.empty true trait_ref;
- let fun_name =
- ctx_get_trait_method trait_ref.trait_decl_ref.trait_decl_id
- method_name rg_id ctx
+ (* We have to check whether the trait method is required or provided *)
+ let trait_decl_id = trait_ref.trait_decl_ref.trait_decl_id in
+ let trait_decl =
+ TraitDeclId.Map.find trait_decl_id ctx.trans_trait_decls
in
- F.pp_print_string fmt ("." ^ fun_name)
+ let method_id =
+ PureUtils.trait_decl_get_method trait_decl method_name
+ in
+
+ if not method_id.is_provided then (
+ (* Required method *)
+ assert (lp_id = None);
+ extract_trait_ref ctx fmt TypeDeclId.Set.empty true trait_ref;
+ let fun_name =
+ ctx_get_trait_method trait_ref.trait_decl_ref.trait_decl_id
+ method_name rg_id ctx
+ in
+ F.pp_print_string fmt ("." ^ fun_name))
+ else
+ (* Provided method: we see it as a regular function call, and use
+ the function name *)
+ let fun_id =
+ FromLlbc (FunId (A.Regular method_id.id), lp_id, rg_id)
+ in
+ let fun_name = ctx_get_function with_opaque_pre fun_id ctx in
+ F.pp_print_string fmt fun_name;
+
+ (* Note that we do not need to print the generics for the trait
+ declaration: they are always implicit as they can be deduced
+ from the trait self clause.
+
+ Print the trait ref (to instantate the self clause) *)
+ F.pp_print_space fmt ();
+ extract_trait_ref ctx fmt TypeDeclId.Set.empty true trait_ref
| _ ->
let fun_name = ctx_get_function with_opaque_pre fun_id ctx in
F.pp_print_string fmt fun_name);
diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml
index 2a5c8952..f54c5dbd 100644
--- a/compiler/InterpreterStatements.ml
+++ b/compiler/InterpreterStatements.ml
@@ -1191,6 +1191,9 @@ and eval_transparent_function_call_symbolic (config : C.config) (call : A.call)
(* Instantiate the signature and introduce fresh abstractions and region ids while doing so.
We perform some manipulations when instantiating the signature.
+
+ # Trait impl calls
+ ==================
In particular, we have a special treatment of trait method calls when
the trait ref is a known impl.
@@ -1216,11 +1219,11 @@ and eval_transparent_function_call_symbolic (config : C.config) (call : A.call)
The generated code looks like this:
{[
- structure HasValue (T : Type) = {
- has_value : T -> result bool
+ structure HasValue (Self : Type) = {
+ has_value : Self -> result bool
}
- let OptionHasValueImpl.has_value (T : Type) (self : T) : result bool =
+ let OptionHasValueImpl.has_value (Self : Type) (self : Self) : result bool =
match self with
| None => false
| Some _ => true
@@ -1244,6 +1247,13 @@ and eval_transparent_function_call_symbolic (config : C.config) (call : A.call)
let option_has_value (T : Type) (x : Option T) : result bool =
OptionHasValueImpl.has_value T x
]}
+
+ # Provided trait methods
+ ========================
+ Calls to provided trait methods also have a special treatment because
+ for now we forbid overriding provided trait methods in the trait implementations,
+ which means that whenever we call a provided trait method, we do not refer
+ to a trait clause but directly to the method provided in the trait declaration.
*)
let func, generics, def, self_trait_ref, inst_sg =
match call.func with
@@ -1319,7 +1329,7 @@ and eval_transparent_function_call_symbolic (config : C.config) (call : A.call)
in
let method_id = Option.get method_id in
let method_def = C.ctx_lookup_fun_decl ctx method_id in
- (* For the instantiation we have to do something perculiar
+ (* For the instantiation we have to do something peculiar
because the method was defined for the trait declaration.
We have to group:
- the parameters given to the trait decl reference
@@ -1336,15 +1346,15 @@ and eval_transparent_function_call_symbolic (config : C.config) (call : A.call)
}
]}
*)
- let generics =
+ let all_generics =
TypesUtils.merge_generic_args
trait_ref.trait_decl_ref.decl_generics call.generics
in
log#ldebug
(lazy
("provided method call:" ^ "\n- method name: " ^ method_name
- ^ "\n- generics:\n"
- ^ egeneric_args_to_string ctx generics
+ ^ "\n- all_generics:\n"
+ ^ egeneric_args_to_string ctx all_generics
^ "\n- parent params info: "
^ Print.option_to_string A.show_params_info
method_def.signature.parent_params_info));
@@ -1352,13 +1362,10 @@ and eval_transparent_function_call_symbolic (config : C.config) (call : A.call)
T.TraitRef (etrait_ref_no_regions_to_gr_trait_ref trait_ref)
in
let inst_sg =
- instantiate_fun_sig ctx generics tr_self
+ instantiate_fun_sig ctx all_generics tr_self
method_def.A.signature
in
- (* We directly call the function, pretending it is not a trait method call *)
- (* TODO: we need to add the self trait ref *)
- let func = A.FunId (A.Regular method_def.def_id) in
- (func, generics, method_def, Some trait_ref, inst_sg))
+ (call.func, call.generics, method_def, Some trait_ref, inst_sg))
| _ ->
(* We are using a local clause - we lookup the trait decl *)
let trait_decl =
@@ -1387,7 +1394,7 @@ and eval_transparent_function_call_symbolic (config : C.config) (call : A.call)
let inst_sg =
instantiate_fun_sig ctx generics tr_self method_def.A.signature
in
- (call.func, generics, method_def, None, inst_sg))
+ (call.func, call.generics, method_def, None, inst_sg))
in
(* Sanity check *)
assert (List.length call.args = List.length def.A.signature.inputs);
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index 1357793b..4e44f252 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -625,3 +625,20 @@ let rec typed_pattern_to_texpression (pat : typed_pattern) : texpression option
Some (mk_apps cons fields_values).e
in
match e_opt with None -> None | Some e -> Some { e; ty = pat.ty }
+
+type trait_decl_method_decl_id = { is_provided : bool; id : fun_decl_id }
+
+let trait_decl_get_method (trait_decl : trait_decl) (method_name : string) :
+ trait_decl_method_decl_id =
+ (* First look in the required methods *)
+ let method_id =
+ List.find_opt (fun (s, _) -> s = method_name) trait_decl.required_methods
+ in
+ match method_id with
+ | Some (_, id) -> { is_provided = false; id }
+ | None ->
+ (* Must be a provided method *)
+ let _, id =
+ List.find (fun (s, _) -> s = method_name) trait_decl.provided_methods
+ in
+ { is_provided = true; id = Option.get id }