summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/Extract.ml4
-rw-r--r--compiler/ExtractBase.ml16
-rw-r--r--compiler/InterpreterStatements.ml126
-rw-r--r--compiler/InterpreterUtils.ml3
-rw-r--r--compiler/Print.ml8
-rw-r--r--compiler/PrintPure.ml7
-rw-r--r--compiler/Pure.ml14
-rw-r--r--compiler/PureMicroPasses.ml29
-rw-r--r--compiler/ReorderDecls.ml4
-rw-r--r--compiler/SymbolicAst.ml2
-rw-r--r--compiler/SymbolicToPure.ml73
-rw-r--r--compiler/SynthesizeSymbolic.ml2
12 files changed, 221 insertions, 67 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index d000c447..fe007d31 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -3411,7 +3411,7 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter)
(def : fun_decl) : unit =
let { keep_fwd; num_backs } =
PureUtils.RegularFunIdMap.find
- (A.Regular def.def_id, def.loop_id, def.back_id)
+ (Pure.FunId (A.Regular def.def_id), def.loop_id, def.back_id)
ctx.fun_name_info
in
let comment_pre = "[" ^ Print.fun_name_to_string def.basename ^ "]: " in
@@ -3908,7 +3908,7 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter)
let decl_name = ctx_get_global with_opaque_pre global.def_id ctx in
let body_name =
ctx_get_function with_opaque_pre
- (FromLlbc (Regular global.body_id, None, None))
+ (FromLlbc (Pure.FunId (Regular global.body_id), None, None))
ctx
in
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index 9c9e08a5..438a3475 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -703,10 +703,13 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
| FromLlbc (fid, lp_id, rg_id) ->
let fun_name =
match fid with
- | Regular fid ->
+ | FunId (Regular fid) ->
Print.fun_name_to_string
(A.FunDeclId.Map.find fid fun_decls).name
- | Assumed aid -> A.show_assumed_fun_id aid
+ | FunId (Assumed aid) -> A.show_assumed_fun_id aid
+ | TraitMethod _ ->
+ (* Shouldn't happen *)
+ raise (Failure "Unexpected")
in
let lp_kind =
@@ -908,7 +911,7 @@ let ctx_get_function (with_opaque_pre : bool) (id : fun_id)
let ctx_get_local_function (with_opaque_pre : bool) (id : A.FunDeclId.id)
(lp : LoopId.id option) (rg : RegionGroupId.id option)
(ctx : extraction_ctx) : string =
- ctx_get_function with_opaque_pre (FromLlbc (Regular id, lp, rg)) ctx
+ ctx_get_function with_opaque_pre (FromLlbc (FunId (Regular id), lp, rg)) ctx
let ctx_get_type (with_opaque_pre : bool) (id : type_id) (ctx : extraction_ctx)
: string =
@@ -1212,7 +1215,7 @@ let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) :
| None ->
(* Not the case: "standard" registration *)
let name = ctx.fmt.global_name def.name in
- let body = FunId (FromLlbc (Regular def.body_id, None, None)) in
+ let body = FunId (FromLlbc (FunId (Regular def.body_id), None, None)) in
let ctx = ctx_add is_opaque decl (name ^ "_c") ctx in
let ctx = ctx_add is_opaque body (name ^ "_body") ctx in
ctx
@@ -1256,7 +1259,7 @@ let ctx_add_fun_decl (trans_group : pure_fun_translation) (def : fun_decl)
let is_opaque = def.body = None in
(* Add the function name *)
let def_name = ctx_compute_fun_name trans_group def ctx in
- let fun_id = (A.Regular def_id, def.loop_id, def.back_id) in
+ let fun_id = (Pure.FunId (Regular def_id), def.loop_id, def.back_id) in
let ctx = ctx_add is_opaque (FunId (FromLlbc fun_id)) def_name ctx in
(* Add the name info *)
{
@@ -1381,7 +1384,8 @@ let initialize_names_map (fmt : formatter) (init : names_map_init) : names_map =
in
let assumed_functions =
List.map
- (fun (fid, rg, name) -> (FromLlbc (A.Assumed fid, None, rg), name))
+ (fun (fid, rg, name) ->
+ (FromLlbc (Pure.FunId (A.Assumed fid), None, rg), name))
init.assumed_llbc_functions
@ List.map (fun (fid, name) -> (Pure fid, name)) init.assumed_pure_functions
in
diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml
index 5791a359..1ea16e58 100644
--- a/compiler/InterpreterStatements.ml
+++ b/compiler/InterpreterStatements.ml
@@ -1082,12 +1082,12 @@ and eval_function_call_concrete (config : C.config) (call : A.call) : st_cm_fun
* where we haven't panicked. Of course, the translation needs to take the
* panic case into account... *)
eval_assumed_function_call_concrete config fid call (cf Unit) ctx
- | A.TraitMethod (_, _) -> raise (Failure "Unimplemented")
+ | A.TraitMethod _ -> raise (Failure "Unimplemented")
and eval_function_call_symbolic (config : C.config) (call : A.call) : st_cm_fun
=
match call.func with
- | A.FunId (A.Regular _) | A.TraitMethod (_, _) ->
+ | A.FunId (A.Regular _) | A.TraitMethod _ ->
eval_transparent_function_call_symbolic config call
| A.FunId (A.Assumed fid) ->
eval_assumed_function_call_symbolic config fid call
@@ -1182,7 +1182,7 @@ and eval_transparent_function_call_symbolic (config : C.config) (call : A.call)
: st_cm_fun =
fun cf ctx ->
(* Instantiate the signature and introduce fresh abstractions and region ids while doing so *)
- let def, inst_sg =
+ let func, def, self_trait_ref, inst_sg =
match call.func with
| A.FunId (A.Regular fid) ->
let def = C.ctx_lookup_fun_decl ctx fid in
@@ -1190,11 +1190,19 @@ and eval_transparent_function_call_symbolic (config : C.config) (call : A.call)
let inst_sg =
instantiate_fun_sig ctx call.generics tr_self def.A.signature
in
- (def, inst_sg)
+ (call.func, def, None, inst_sg)
| A.FunId (A.Assumed _) ->
(* Unreachable: must be a transparent function *)
raise (Failure "Unreachable")
- | A.TraitMethod (trait_ref, method_name) -> (
+ | A.TraitMethod (trait_ref, method_name, _) -> (
+ log#ldebug
+ (lazy
+ ("trait method call:\n- call: " ^ call_to_string ctx call
+ ^ "\n- method name: " ^ method_name ^ "\n- generics:\n"
+ ^ egeneric_args_to_string ctx call.generics
+ ^ "\n- trait and method generics:\n"
+ ^ egeneric_args_to_string ctx
+ (Option.get call.trait_and_method_generic_args)));
(* When instantiating, we need to group the generics for the trait ref
and the method *)
let generics = Option.get call.trait_and_method_generic_args in
@@ -1202,31 +1210,97 @@ and eval_transparent_function_call_symbolic (config : C.config) (call : A.call)
depending on whethere we call a top-level trait method impl or the
method from a local clause *)
match trait_ref.trait_id with
- | TraitImpl impl_id ->
+ | TraitImpl impl_id -> (
(* Lookup the trait impl *)
let trait_impl = C.ctx_lookup_trait_impl ctx impl_id in
- let _, method_id =
- List.find
+ log#ldebug
+ (lazy ("trait impl: " ^ trait_impl_to_string ctx trait_impl));
+ (* First look in the required methods *)
+ let method_id =
+ List.find_opt
(fun (s, _) -> s = method_name)
trait_impl.required_methods
in
- let method_def = C.ctx_lookup_fun_decl ctx method_id in
- (* Instantiate *)
- let tr_self = T.UnknownTrait __FUNCTION__ in
- let inst_sg =
- instantiate_fun_sig ctx generics tr_self method_def.A.signature
- in
- (method_def, inst_sg)
+ match method_id with
+ | Some (_, id) ->
+ let method_def = C.ctx_lookup_fun_decl ctx id in
+ (* Instantiate *)
+ let tr_self =
+ T.TraitRef (etrait_ref_no_regions_to_gr_trait_ref trait_ref)
+ in
+ let inst_sg =
+ instantiate_fun_sig ctx generics tr_self
+ method_def.A.signature
+ in
+ (call.func, method_def, None, inst_sg)
+ | None ->
+ (* If not found, lookup the methods provided by the trait *declaration*
+ (remember: for now, we forbid overriding provided methods) *)
+ assert (trait_impl.provided_methods = []);
+ let trait_decl =
+ C.ctx_lookup_trait_decl ctx
+ trait_ref.trait_decl_ref.trait_decl_id
+ in
+ let _, method_id =
+ List.find
+ (fun (s, _) -> s = method_name)
+ trait_decl.provided_methods
+ in
+ let method_id = Option.get method_id in
+ let method_def = C.ctx_lookup_fun_decl ctx method_id in
+ (* For the instantiation we have to do something perculiar
+ because the method was defined for the trait declaration.
+ We have to group:
+ - the parameters given to the trait decl reference
+ - the parameters given to the method itself
+ For instance:
+ {[
+ trait Foo<T> {
+ fn f<U>(...) { ... }
+ }
+
+ fn g<G>(x : G) where Clause0: Foo<G, bool>
+ {
+ x.f::<u32>(...) // The arguments to f are: <G, bool, u32>
+ }
+ ]}
+ *)
+ let generics =
+ TypesUtils.merge_generic_args
+ trait_ref.trait_decl_ref.decl_generics call.generics
+ in
+ log#ldebug
+ (lazy
+ ("provided method call:" ^ "\n- method name: " ^ method_name
+ ^ "\n- generics:\n"
+ ^ egeneric_args_to_string ctx generics));
+ let tr_self =
+ T.TraitRef (etrait_ref_no_regions_to_gr_trait_ref trait_ref)
+ in
+ let inst_sg =
+ instantiate_fun_sig ctx generics tr_self
+ method_def.A.signature
+ in
+ (* We directly call the function, pretending it is not a trait method call *)
+ (* TODO: we need to add the self trait ref *)
+ let func = A.FunId (A.Regular method_def.def_id) in
+ (func, method_def, Some trait_ref, inst_sg))
| _ ->
(* We are using a local clause - we lookup the trait decl *)
let trait_decl =
C.ctx_lookup_trait_decl ctx trait_ref.trait_decl_ref.trait_decl_id
in
- (* Lookup the method decl *)
+ (* Lookup the method decl in the required *and* the provided methods *)
let _, method_id =
+ let provided =
+ List.filter_map
+ (fun (id, f) ->
+ match f with None -> None | Some f -> Some (id, f))
+ trait_decl.provided_methods
+ in
List.find
(fun (s, _) -> s = method_name)
- trait_decl.required_methods
+ (List.append trait_decl.required_methods provided)
in
let method_def = C.ctx_lookup_fun_decl ctx method_id in
(* Instantiate *)
@@ -1238,26 +1312,30 @@ and eval_transparent_function_call_symbolic (config : C.config) (call : A.call)
let inst_sg =
instantiate_fun_sig ctx generics tr_self method_def.A.signature
in
- (method_def, inst_sg))
+ (call.func, method_def, None, inst_sg))
in
(* Sanity check *)
assert (List.length call.args = List.length def.A.signature.inputs);
(* Evaluate the function call *)
- eval_function_call_symbolic_from_inst_sig config call.func inst_sg
+ eval_function_call_symbolic_from_inst_sig config func inst_sg self_trait_ref
call.generics call.args call.dest cf ctx
(** Evaluate a function call in symbolic mode by using the function signature.
This allows us to factorize the evaluation of local and non-local function
calls in symbolic mode: only their signatures matter.
+
+ The [self_trait_ref] trait ref refers to [Self]. We use it when calling
+ a provided trait method, because those methods have a special treatment:
+ we dot not group them with the required trait methods, and forbid (for now)
+ overriding them. We treat them as regular method, which take an additional
+ trait ref as input.
*)
and eval_function_call_symbolic_from_inst_sig (config : C.config)
(fid : A.fun_id_or_trait_method_ref) (inst_sg : A.inst_fun_sig)
- (generics : T.egeneric_args) (args : E.operand list) (dest : E.place) :
- st_cm_fun =
+ (self_trait_ref : T.etrait_ref option) (generics : T.egeneric_args)
+ (args : E.operand list) (dest : E.place) : st_cm_fun =
fun cf ctx ->
- (* TODO: trait methods are not supported yet *)
- let fid = match fid with A.FunId fid -> fid | _ -> raise (Failure "TODO") in
(* Generate a fresh symbolic value for the return value *)
let ret_sv_ty = inst_sg.A.output in
let ret_spc = mk_fresh_symbolic_value V.FunCallRet ret_sv_ty in
@@ -1427,7 +1505,7 @@ and eval_assumed_function_call_symbolic (config : C.config)
(* Evaluate the function call *)
eval_function_call_symbolic_from_inst_sig config (A.FunId (A.Assumed fid))
- inst_sig generics args dest cf ctx
+ inst_sig None generics args dest cf ctx
(** Evaluate a statement seen as a function body *)
and eval_function_body (config : C.config) (body : A.statement) : st_cm_fun =
diff --git a/compiler/InterpreterUtils.ml b/compiler/InterpreterUtils.ml
index 0986c53b..8a7e8c52 100644
--- a/compiler/InterpreterUtils.ml
+++ b/compiler/InterpreterUtils.ml
@@ -38,6 +38,9 @@ let typed_value_to_string = PA.typed_value_to_string
let typed_avalue_to_string = PA.typed_avalue_to_string
let place_to_string = PA.place_to_string
let operand_to_string = PA.operand_to_string
+let egeneric_args_to_string = PA.egeneric_args_to_string
+let call_to_string = PA.call_to_string
+let trait_impl_to_string = PA.trait_impl_to_string
let statement_to_string ctx = PA.statement_to_string ctx "" " "
let statement_to_string_with_tab ctx = PA.statement_to_string ctx " " " "
let env_elem_to_string ctx = PA.env_elem_to_string ctx "" " "
diff --git a/compiler/Print.ml b/compiler/Print.ml
index aebfd09c..55aa0c53 100644
--- a/compiler/Print.ml
+++ b/compiler/Print.ml
@@ -698,11 +698,19 @@ module EvalCtxLlbcAst = struct
let fmt = PC.eval_ctx_to_ast_formatter ctx in
PE.operand_to_string fmt op
+ let call_to_string (ctx : C.eval_ctx) (call : A.call) : string =
+ let fmt = PC.eval_ctx_to_ast_formatter ctx in
+ PA.call_to_string fmt "" call
+
let statement_to_string (ctx : C.eval_ctx) (indent : string)
(indent_incr : string) (e : A.statement) : string =
let fmt = PC.eval_ctx_to_ast_formatter ctx in
PA.statement_to_string fmt indent indent_incr e
+ let trait_impl_to_string (ctx : C.eval_ctx) (timpl : A.trait_impl) : string =
+ let fmt = PC.eval_ctx_to_ast_formatter ctx in
+ PA.trait_impl_to_string fmt " " " " timpl
+
let env_elem_to_string (ctx : C.eval_ctx) (indent : string)
(indent_incr : string) (ev : C.env_elem) : string =
let fmt = PC.eval_ctx_to_ctx_formatter ctx in
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index c7f59ec9..d539dcf6 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -629,8 +629,11 @@ let regular_fun_id_to_string (fmt : ast_formatter) (fun_id : fun_id) : string =
| FromLlbc (fid, lp_id, rg_id) ->
let f =
match fid with
- | Regular fid -> fmt.fun_decl_id_to_string fid
- | Assumed fid -> llbc_assumed_fun_id_to_string fid
+ | FunId (Regular fid) -> fmt.fun_decl_id_to_string fid
+ | FunId (Assumed fid) -> llbc_assumed_fun_id_to_string fid
+ | TraitMethod (trait_ref, method_name, _) ->
+ let fmt = ast_to_type_formatter fmt in
+ trait_ref_to_string fmt true trait_ref ^ "." ^ method_name
in
f ^ fun_suffix lp_id rg_id
| Pure fid -> pure_assumed_fun_id_to_string fid
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index 81060c43..47c7beb4 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -306,6 +306,7 @@ and trait_instance_id =
| UnknownTrait of string
[@@deriving
show,
+ ord,
visitors
{
name = "iter_ty";
@@ -369,7 +370,7 @@ type trait_type_constraint = {
type_name : trait_item_name;
ty : ty;
}
-[@@deriving show]
+[@@deriving show, ord]
type predicates = { trait_type_constraints : trait_type_constraint list }
[@@deriving show]
@@ -530,8 +531,15 @@ type pure_assumed_fun_id =
| FuelEqZero (** Test if some fuel is equal to 0 - TODO: ugly *)
[@@deriving show, ord]
+type fun_id_or_trait_method_ref =
+ | FunId of A.fun_id
+ | TraitMethod of trait_ref * string * fun_decl_id
+ (** The fun decl id is not really needed and here for convenience purposes *)
+[@@deriving show, ord]
+
(** A function id for a non-assumed function *)
-type regular_fun_id = A.fun_id * LoopId.id option * T.RegionGroupId.id option
+type regular_fun_id =
+ fun_id_or_trait_method_ref * LoopId.id option * T.RegionGroupId.id option
[@@deriving show, ord]
(** A function identifier *)
@@ -1003,7 +1011,7 @@ type trait_decl = {
consts : (trait_item_name * (ty * global_decl_id option)) list;
types : (trait_item_name * (trait_clause list * ty option)) list;
required_methods : (trait_item_name * fun_decl_id) list;
- provided_methods : trait_item_name list;
+ provided_methods : (trait_item_name * fun_decl_id option) list;
}
[@@deriving show]
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 6c9c3a91..53148dbb 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -765,7 +765,7 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool)
In this situation, we can remove the call [f@fwd x].
*)
let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
- (id0 : A.fun_id) (lp_id0 : LoopId.id option)
+ (id0 : fun_id_or_trait_method_ref) (lp_id0 : LoopId.id option)
(rg_id0 : T.RegionGroupId.id option) (generics0 : generic_args)
(args0 : texpression list) (e : texpression) : bool =
let check_call (fun_id1 : fun_or_op_id) (generics1 : generic_args)
@@ -791,6 +791,11 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
(* We need to use the regions hierarchy *)
(* First, lookup the signature of the LLBC function *)
let sg =
+ let id0 =
+ match id0 with
+ | FunId fun_id -> fun_id
+ | TraitMethod (_, _, fun_decl_id) -> A.Regular fun_decl_id
+ in
LlbcAstUtils.lookup_fun_sig id0 ctx.fun_context.fun_decls
in
(* Compute the set of ancestors of the function in call1 *)
@@ -1521,7 +1526,7 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
match opt_destruct_function_call e with
| Some (fun_id, _tys, args) -> (
match fun_id with
- | Fun (FromLlbc (A.Assumed aid, _lp_id, rg_id)) -> (
+ | Fun (FromLlbc (FunId (A.Assumed aid), _lp_id, rg_id)) -> (
(* Below, when dealing with the arguments: we consider the very
* general case, where functions could be boxed (meaning we
* could have: [box_new f x])
@@ -2024,7 +2029,6 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
additional parameters.
*)
let used_map = ref FunLoopIdMap.empty in
- let fun_id_to_fun_loop_id (fid, loop_id, _) = (fid, loop_id) in
(* We start by computing the filtering information, for each function *)
let compute_one_filter_info (decl : fun_decl) =
@@ -2069,8 +2073,8 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
match e_app.e with
| Qualif qualif -> (
match qualif.id with
- | FunOrOp (Fun (FromLlbc fun_id')) ->
- if fun_id_to_fun_loop_id fun_id' = fun_id then (
+ | FunOrOp (Fun (FromLlbc (FunId fun_id', loop_id', _))) ->
+ if (fun_id', loop_id') = fun_id then (
(* For each argument, check if it is exactly the original
input parameter. Note that there shouldn't be partial
applications of loop functions: the number of arguments
@@ -2129,9 +2133,9 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
(* We then apply the filtering to all the function definitions at once *)
let filter_in_one (decl : fun_decl) : fun_decl =
(* Filter the function signature *)
- let fun_id = (A.Regular decl.def_id, decl.loop_id, decl.back_id) in
+ let fun_id = (A.Regular decl.def_id, decl.loop_id) in
let decl =
- match FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map with
+ match FunLoopIdMap.find_opt fun_id !used_map with
| None -> (* Nothing to filter *) decl
| Some used_info ->
let num_filtered =
@@ -2179,9 +2183,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
let { inputs; inputs_lvs; body } = body in
let inputs, inputs_lvs =
- match
- FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map
- with
+ match FunLoopIdMap.find_opt fun_id !used_map with
| None -> (* Nothing to filter *) (inputs, inputs_lvs)
| Some used_info ->
let inputs = filter_prefix used_info inputs in
@@ -2201,11 +2203,10 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
match e_app.e with
| Qualif qualif -> (
match qualif.id with
- | FunOrOp (Fun (FromLlbc fun_id)) -> (
+ | FunOrOp (Fun (FromLlbc (FunId fun_id, loop_id, _)))
+ -> (
match
- FunLoopIdMap.find_opt
- (fun_id_to_fun_loop_id fun_id)
- !used_map
+ FunLoopIdMap.find_opt (fun_id, loop_id) !used_map
with
| None -> super#visit_texpression env e
| Some used_info ->
diff --git a/compiler/ReorderDecls.ml b/compiler/ReorderDecls.ml
index db646a87..10b68da3 100644
--- a/compiler/ReorderDecls.ml
+++ b/compiler/ReorderDecls.ml
@@ -46,8 +46,8 @@ let compute_body_fun_deps (e : texpression) : FunIdSet.t =
| Pure _ -> ()
| FromLlbc (fid, lp_id, rg_id) -> (
match fid with
- | Assumed _ -> ()
- | Regular fid ->
+ | FunId (Assumed _) -> ()
+ | TraitMethod (_, _, fid) | FunId (Regular fid) ->
let id = { def_id = fid; lp_id; rg_id } in
ids := FunIdSet.add id !ids))
end
diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml
index b170ebe5..4df8fec7 100644
--- a/compiler/SymbolicAst.ml
+++ b/compiler/SymbolicAst.ml
@@ -29,7 +29,7 @@ type mplace = {
[@@deriving show]
type call_id =
- | Fun of A.fun_id * V.FunCallId.id
+ | Fun of A.fun_id_or_trait_method_ref * V.FunCallId.id
(** A "regular" function (i.e., a function which is not a primitive operation) *)
| Unop of E.unop
| Binop of E.binop
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 3312e22d..1da0521d 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -727,6 +727,54 @@ let type_check_texpression (ctx : bs_ctx) (e : texpression) : unit =
let ctx = mk_type_check_ctx ctx in
PureTypeCheck.check_texpression ctx e
+let translate_fun_id_or_trait_method_ref (ctx : bs_ctx)
+ (id : A.fun_id_or_trait_method_ref) : fun_id_or_trait_method_ref =
+ match id with
+ | A.FunId fun_id -> FunId fun_id
+ | TraitMethod (trait_ref, method_name, fun_decl_id) ->
+ let type_infos = ctx.type_context.type_infos in
+ let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
+ TraitMethod (trait_ref, method_name, fun_decl_id)
+
+let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
+ (args : texpression list) (ctx : bs_ctx) : bs_ctx =
+ let calls = ctx.calls in
+ assert (not (V.FunCallId.Map.mem call_id calls));
+ let info =
+ { forward; forward_inputs = args; backwards = T.RegionGroupId.Map.empty }
+ in
+ let calls = V.FunCallId.Map.add call_id info calls in
+ { ctx with calls }
+
+(** [back_args]: the *additional* list of inputs received by the backward function *)
+let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
+ (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx)
+ : bs_ctx * fun_or_op_id =
+ (* Insert the abstraction in the call informations *)
+ let info = V.FunCallId.Map.find call_id ctx.calls in
+ assert (not (T.RegionGroupId.Map.mem back_id info.backwards));
+ let backwards =
+ T.RegionGroupId.Map.add back_id (abs, back_args) info.backwards
+ in
+ let info = { info with backwards } in
+ let calls = V.FunCallId.Map.add call_id info ctx.calls in
+ (* Insert the abstraction in the abstractions map *)
+ let abstractions = ctx.abstractions in
+ assert (not (V.AbstractionId.Map.mem abs.abs_id abstractions));
+ let abstractions =
+ V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions
+ in
+ (* Retrieve the fun_id *)
+ let fun_id =
+ match info.forward.call_id with
+ | S.Fun (fid, _) ->
+ let fid = translate_fun_id_or_trait_method_ref ctx fid in
+ Fun (FromLlbc (fid, None, Some back_id))
+ | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
+ in
+ (* Update the context and return *)
+ ({ ctx with calls; abstractions }, fun_id)
+
(** List the ancestors of an abstraction *)
let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs)
(call_id : V.FunCallId.id) : V.AbstractionId.id list =
@@ -780,10 +828,10 @@ let mk_fuel_input_as_list (ctx : bs_ctx) (info : fun_effect_info) :
(** Small utility. *)
let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
- (fun_id : A.fun_id) (lid : V.LoopId.id option)
+ (fun_id : A.fun_id_or_trait_method_ref) (lid : V.LoopId.id option)
(gid : T.RegionGroupId.id option) : fun_effect_info =
match fun_id with
- | A.Regular fid ->
+ | A.TraitMethod (_, _, fid) | A.FunId (A.Regular fid) ->
let info = A.FunDeclId.Map.find fid fun_infos in
let stateful_group = info.stateful in
let stateful =
@@ -796,7 +844,7 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
can_diverge = info.can_diverge;
is_rec = info.is_rec || Option.is_some lid;
}
- | A.Assumed aid ->
+ | A.FunId (A.Assumed aid) ->
assert (lid = None);
{
can_fail = Assumed.assumed_can_fail aid;
@@ -828,7 +876,7 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
in
(* Is the function stateful, and can it fail? *)
let lid = None in
- let effect_info = get_fun_effect_info fun_infos fun_id lid bid in
+ let effect_info = get_fun_effect_info fun_infos (A.FunId fun_id) lid bid in
(* List the inputs for:
* - the fuel
* - the forward function
@@ -1615,7 +1663,8 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
match call.call_id with
| S.Fun (fid, call_id) ->
(* Regular function call *)
- let func = Fun (FromLlbc (fid, None, None)) in
+ let fid_t = translate_fun_id_or_trait_method_ref ctx fid in
+ let func = Fun (FromLlbc (fid_t, None, None)) in
(* Retrieve the effect information about this function (can fail,
* takes a state as input, etc.) *)
let effect_info =
@@ -2043,8 +2092,8 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
| V.LoopCall ->
let fun_id = A.Regular ctx.fun_decl.A.def_id in
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos fun_id (Some vloop_id)
- (Some rg_id)
+ get_fun_effect_info ctx.fun_context.fun_infos (A.FunId fun_id)
+ (Some vloop_id) (Some rg_id)
in
let loop_info = LoopId.Map.find loop_id ctx.loops in
let generics = loop_info.generics in
@@ -2095,7 +2144,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
if effect_info.can_fail then mk_result_ty output.ty else output.ty
in
let func_ty = mk_arrows input_tys ret_ty in
- let func = Fun (FromLlbc (fun_id, Some loop_id, Some rg_id)) in
+ let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in
let func = { id = FunOrOp func; generics } in
let func = { e = Qualif func; ty = func_ty } in
let call = mk_apps func args in
@@ -2508,7 +2557,7 @@ and translate_forward_end (ectx : C.eval_ctx)
(* Lookup the effect info for the loop function *)
let fid = A.Regular ctx.fun_decl.A.def_id in
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos fid None ctx.bid
+ get_fun_effect_info ctx.fun_context.fun_infos (A.FunId fid) None ctx.bid
in
(* Introduce a fresh output value for the forward function *)
@@ -2553,7 +2602,7 @@ and translate_forward_end (ectx : C.eval_ctx)
let out_pat = mk_simpl_tuple_pattern out_pats in
let loop_call =
- let fun_id = Fun (FromLlbc (fid, Some loop_id, None)) in
+ let fun_id = Fun (FromLlbc (FunId fid, Some loop_id, None)) in
let func = { id = FunOrOp fun_id; generics = loop_info.generics } in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
let ret_ty =
@@ -2873,8 +2922,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
| None -> None
| Some body ->
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) None
- bid
+ get_fun_effect_info ctx.fun_context.fun_infos (FunId (Regular def_id))
+ None bid
in
let body = translate_expression body ctx in
(* Add a match over the fuel, if necessary *)
diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml
index cac56487..aeb6899f 100644
--- a/compiler/SynthesizeSymbolic.ml
+++ b/compiler/SynthesizeSymbolic.ml
@@ -122,7 +122,7 @@ let synthesize_global_eval (gid : A.GlobalDeclId.id) (dest : V.symbolic_value)
(e : expression option) : expression option =
Option.map (fun e -> EvalGlobal (gid, dest, e)) e
-let synthesize_regular_function_call (fun_id : A.fun_id)
+let synthesize_regular_function_call (fun_id : A.fun_id_or_trait_method_ref)
(call_id : V.FunCallId.id) (ctx : Contexts.eval_ctx)
(abstractions : V.AbstractionId.id list) (generics : T.egeneric_args)
(args : V.typed_value list) (args_places : mplace option list)