summaryrefslogtreecommitdiff
path: root/compiler/SymbolicToPure.ml
diff options
context:
space:
mode:
authorSon Ho2023-09-11 06:35:07 +0200
committerSon Ho2023-09-11 06:35:07 +0200
commit5921be8e2e8955db5101354d8bf29ae6a3693f48 (patch)
treef17b4c4cfe0ba184a4831cae353530aea7ee354b /compiler/SymbolicToPure.ml
parentc6b88a2e54b7697262ad3677ad7500471c68e332 (diff)
Make progress on correctly handling trait method calls in the symbolic execution
Diffstat (limited to 'compiler/SymbolicToPure.ml')
-rw-r--r--compiler/SymbolicToPure.ml73
1 files changed, 61 insertions, 12 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 3312e22d..1da0521d 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -727,6 +727,54 @@ let type_check_texpression (ctx : bs_ctx) (e : texpression) : unit =
let ctx = mk_type_check_ctx ctx in
PureTypeCheck.check_texpression ctx e
+let translate_fun_id_or_trait_method_ref (ctx : bs_ctx)
+ (id : A.fun_id_or_trait_method_ref) : fun_id_or_trait_method_ref =
+ match id with
+ | A.FunId fun_id -> FunId fun_id
+ | TraitMethod (trait_ref, method_name, fun_decl_id) ->
+ let type_infos = ctx.type_context.type_infos in
+ let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
+ TraitMethod (trait_ref, method_name, fun_decl_id)
+
+let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
+ (args : texpression list) (ctx : bs_ctx) : bs_ctx =
+ let calls = ctx.calls in
+ assert (not (V.FunCallId.Map.mem call_id calls));
+ let info =
+ { forward; forward_inputs = args; backwards = T.RegionGroupId.Map.empty }
+ in
+ let calls = V.FunCallId.Map.add call_id info calls in
+ { ctx with calls }
+
+(** [back_args]: the *additional* list of inputs received by the backward function *)
+let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
+ (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx)
+ : bs_ctx * fun_or_op_id =
+ (* Insert the abstraction in the call informations *)
+ let info = V.FunCallId.Map.find call_id ctx.calls in
+ assert (not (T.RegionGroupId.Map.mem back_id info.backwards));
+ let backwards =
+ T.RegionGroupId.Map.add back_id (abs, back_args) info.backwards
+ in
+ let info = { info with backwards } in
+ let calls = V.FunCallId.Map.add call_id info ctx.calls in
+ (* Insert the abstraction in the abstractions map *)
+ let abstractions = ctx.abstractions in
+ assert (not (V.AbstractionId.Map.mem abs.abs_id abstractions));
+ let abstractions =
+ V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions
+ in
+ (* Retrieve the fun_id *)
+ let fun_id =
+ match info.forward.call_id with
+ | S.Fun (fid, _) ->
+ let fid = translate_fun_id_or_trait_method_ref ctx fid in
+ Fun (FromLlbc (fid, None, Some back_id))
+ | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
+ in
+ (* Update the context and return *)
+ ({ ctx with calls; abstractions }, fun_id)
+
(** List the ancestors of an abstraction *)
let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs)
(call_id : V.FunCallId.id) : V.AbstractionId.id list =
@@ -780,10 +828,10 @@ let mk_fuel_input_as_list (ctx : bs_ctx) (info : fun_effect_info) :
(** Small utility. *)
let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
- (fun_id : A.fun_id) (lid : V.LoopId.id option)
+ (fun_id : A.fun_id_or_trait_method_ref) (lid : V.LoopId.id option)
(gid : T.RegionGroupId.id option) : fun_effect_info =
match fun_id with
- | A.Regular fid ->
+ | A.TraitMethod (_, _, fid) | A.FunId (A.Regular fid) ->
let info = A.FunDeclId.Map.find fid fun_infos in
let stateful_group = info.stateful in
let stateful =
@@ -796,7 +844,7 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
can_diverge = info.can_diverge;
is_rec = info.is_rec || Option.is_some lid;
}
- | A.Assumed aid ->
+ | A.FunId (A.Assumed aid) ->
assert (lid = None);
{
can_fail = Assumed.assumed_can_fail aid;
@@ -828,7 +876,7 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
in
(* Is the function stateful, and can it fail? *)
let lid = None in
- let effect_info = get_fun_effect_info fun_infos fun_id lid bid in
+ let effect_info = get_fun_effect_info fun_infos (A.FunId fun_id) lid bid in
(* List the inputs for:
* - the fuel
* - the forward function
@@ -1615,7 +1663,8 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
match call.call_id with
| S.Fun (fid, call_id) ->
(* Regular function call *)
- let func = Fun (FromLlbc (fid, None, None)) in
+ let fid_t = translate_fun_id_or_trait_method_ref ctx fid in
+ let func = Fun (FromLlbc (fid_t, None, None)) in
(* Retrieve the effect information about this function (can fail,
* takes a state as input, etc.) *)
let effect_info =
@@ -2043,8 +2092,8 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
| V.LoopCall ->
let fun_id = A.Regular ctx.fun_decl.A.def_id in
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos fun_id (Some vloop_id)
- (Some rg_id)
+ get_fun_effect_info ctx.fun_context.fun_infos (A.FunId fun_id)
+ (Some vloop_id) (Some rg_id)
in
let loop_info = LoopId.Map.find loop_id ctx.loops in
let generics = loop_info.generics in
@@ -2095,7 +2144,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
if effect_info.can_fail then mk_result_ty output.ty else output.ty
in
let func_ty = mk_arrows input_tys ret_ty in
- let func = Fun (FromLlbc (fun_id, Some loop_id, Some rg_id)) in
+ let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in
let func = { id = FunOrOp func; generics } in
let func = { e = Qualif func; ty = func_ty } in
let call = mk_apps func args in
@@ -2508,7 +2557,7 @@ and translate_forward_end (ectx : C.eval_ctx)
(* Lookup the effect info for the loop function *)
let fid = A.Regular ctx.fun_decl.A.def_id in
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos fid None ctx.bid
+ get_fun_effect_info ctx.fun_context.fun_infos (A.FunId fid) None ctx.bid
in
(* Introduce a fresh output value for the forward function *)
@@ -2553,7 +2602,7 @@ and translate_forward_end (ectx : C.eval_ctx)
let out_pat = mk_simpl_tuple_pattern out_pats in
let loop_call =
- let fun_id = Fun (FromLlbc (fid, Some loop_id, None)) in
+ let fun_id = Fun (FromLlbc (FunId fid, Some loop_id, None)) in
let func = { id = FunOrOp fun_id; generics = loop_info.generics } in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
let ret_ty =
@@ -2873,8 +2922,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
| None -> None
| Some body ->
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) None
- bid
+ get_fun_effect_info ctx.fun_context.fun_infos (FunId (Regular def_id))
+ None bid
in
let body = translate_expression body ctx in
(* Add a match over the fuel, if necessary *)