summaryrefslogtreecommitdiff
path: root/compiler/SymbolicToPure.ml
diff options
context:
space:
mode:
authorSon Ho2022-11-25 08:13:37 +0100
committerSon HO2023-02-03 11:21:46 +0100
commit59596561873162d632f3d3091485b32a76427ee9 (patch)
tree2bdeb89950981306bacff00a1e8e68b92ec0f9db /compiler/SymbolicToPure.ml
parentbbdd0da25b974b03d58489d3bbc2654f4f774644 (diff)
Start implementing support for loops
Diffstat (limited to 'compiler/SymbolicToPure.ml')
-rw-r--r--compiler/SymbolicToPure.ml56
1 files changed, 31 insertions, 25 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 3bd6c5b3..45e35742 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -236,17 +236,17 @@ let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
{ ctx with calls }
(** [back_args]: the *additional* list of inputs received by the backward function *)
-let bs_ctx_register_backward_call (abs : V.abs) (back_args : texpression list)
- (ctx : bs_ctx) : bs_ctx * fun_or_op_id =
+let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
+ (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx)
+ : bs_ctx * fun_or_op_id =
(* Insert the abstraction in the call informations *)
- let back_id = abs.back_id in
- let info = V.FunCallId.Map.find abs.call_id ctx.calls in
+ let info = V.FunCallId.Map.find call_id ctx.calls in
assert (not (T.RegionGroupId.Map.mem back_id info.backwards));
let backwards =
T.RegionGroupId.Map.add back_id (abs, back_args) info.backwards
in
let info = { info with backwards } in
- let calls = V.FunCallId.Map.add abs.call_id info ctx.calls in
+ let calls = V.FunCallId.Map.add call_id info ctx.calls in
(* Insert the abstraction in the abstractions map *)
let abstractions = ctx.abstractions in
assert (not (V.AbstractionId.Map.mem abs.abs_id abstractions));
@@ -256,7 +256,7 @@ let bs_ctx_register_backward_call (abs : V.abs) (back_args : texpression list)
(* Retrieve the fun_id *)
let fun_id =
match info.forward.call_id with
- | S.Fun (fid, _) -> Fun (FromLlbc (fid, Some abs.back_id))
+ | S.Fun (fid, _) -> Fun (FromLlbc (fid, Some back_id))
| S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
in
(* Update the context and return *)
@@ -470,8 +470,8 @@ let ctx_translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool)
translate_back_ty types_infos keep_region inside_mut ty
(** List the ancestors of an abstraction *)
-let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs) :
- V.AbstractionId.id list =
+let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs)
+ (call_id : V.FunCallId.id) : V.AbstractionId.id list =
(* We could do something more "elegant" without references, but it is
* so much simpler to use references... *)
let abs_set = ref V.AbstractionId.Set.empty in
@@ -485,14 +485,16 @@ let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs) :
List.iter gather abs.original_parents;
let ids = !abs_set in
(* List the ancestors, in the proper order *)
- let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in
+ let call_info = V.FunCallId.Map.find call_id ctx.calls in
List.filter
(fun id -> V.AbstractionId.Set.mem id ids)
call_info.forward.abstractions
-let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) :
- (V.abs * texpression list) list =
- let abs_ids = list_ancestor_abstractions_ids ctx abs in
+(** List the ancestor abstractions of an abstraction introduced because of
+ a function call *)
+let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs)
+ (call_id : V.FunCallId.id) : (V.abs * texpression list) list =
+ let abs_ids = list_ancestor_abstractions_ids ctx abs call_id in
List.map (fun id -> V.AbstractionId.Map.find id ctx.abstractions) abs_ids
(** Small utility.
@@ -1143,10 +1145,10 @@ let abs_to_given_back_no_mp (abs : V.abs) (ctx : bs_ctx) :
Is used for instance when collecting the input values given to all the
parent functions, in order to properly instantiate an
*)
-let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) :
+let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) (call_id : V.FunCallId.id) :
S.call * (V.abs * texpression list) list =
- let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in
- let abs_ancestors = list_ancestor_abstractions ctx abs in
+ let call_info = V.FunCallId.Map.find call_id ctx.calls in
+ let abs_ancestors = list_ancestor_abstractions ctx abs call_id in
(call_info.forward, abs_ancestors)
let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =
@@ -1171,6 +1173,7 @@ let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =
(ctx, e)
in
translate_expression e ctx
+ | Loop _loop -> raise (Failure "Unimplemented")
and translate_panic (ctx : bs_ctx) : texpression =
(* Here we use the function return type - note that it is ok because
@@ -1357,7 +1360,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
("translate_end_abstraction: abstraction kind: "
^ V.show_abs_kind abs.kind));
match abs.kind with
- | V.SynthInput ->
+ | V.SynthInput rg_id ->
(* When we end an input abstraction, this input abstraction gets back
* the borrows which it introduced in the context through the input
* values: by listing those values, we get the values which are given
@@ -1368,7 +1371,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
* for a parent backward function.
*)
let bid = Option.get ctx.bid in
- assert (abs.back_id = bid);
+ assert (rg_id = bid);
(* The translation is done as follows:
* - for a given backward function, we choose a set of variables [v_i]
@@ -1406,8 +1409,8 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
(fun (var, value) (e : texpression) ->
mk_let monadic (mk_typed_pattern_from_var var None) value e)
variables_values next_e
- | V.FunCall ->
- let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in
+ | V.FunCall (call_id, rg_id) ->
+ let call_info = V.FunCallId.Map.find call_id ctx.calls in
let call = call_info.forward in
let fun_id =
match call.call_id with
@@ -1417,11 +1420,11 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
raise (Failure "Unreachable")
in
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos fun_id (Some abs.back_id)
+ get_fun_effect_info ctx.fun_context.fun_infos fun_id (Some rg_id)
in
let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in
(* Retrieve the original call and the parent abstractions *)
- let _forward, backwards = get_abs_ancestors ctx abs in
+ let _forward, backwards = get_abs_ancestors ctx abs call_id in
(* Retrieve the values consumed when we called the forward function and
* ended the parent backward functions: those give us part of the input
* values (rem: for now, as we disallow nested lifetimes, there can't be
@@ -1470,7 +1473,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
(* Sanity check: the inputs and outputs have the proper number and the proper type *)
let _ =
let inst_sg =
- get_instantiated_fun_sig fun_id (Some abs.back_id) type_args ctx
+ get_instantiated_fun_sig fun_id (Some rg_id) type_args ctx
in
log#ldebug
(lazy
@@ -1497,7 +1500,9 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
in
(* Retrieve the function id, and register the function call in the context
* if necessary *)
- let ctx, func = bs_ctx_register_backward_call abs back_inputs ctx in
+ let ctx, func =
+ bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx
+ in
(* Translate the next expression *)
let next_e = translate_expression e ctx in
(* Put everything together *)
@@ -1535,7 +1540,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
assert (List.length inputs = List.length fwd_inputs);
next_e)
else mk_let effect_info.can_fail output call next_e
- | V.SynthRet ->
+ | V.SynthRet rg_id ->
(* If we end the abstraction which consumed the return value of the function
we are synthesizing, we get back the borrows which were inside. Those borrows
are actually input arguments of the backward function we are synthesizing.
@@ -1566,7 +1571,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
*)
(* First, retrieve the list of variables used for the inputs for the
* backward function *)
- let inputs = T.RegionGroupId.Map.find abs.back_id ctx.backward_inputs in
+ let inputs = T.RegionGroupId.Map.find rg_id ctx.backward_inputs in
(* Retrieve the values consumed upon ending the loans inside this
* abstraction: as there are no nested borrows, there should be none. *)
let consumed = abs_to_consumed ctx abs in
@@ -1597,6 +1602,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
(fun (given_back, input_var) e ->
mk_let monadic given_back (mk_texpression_from_var input_var) e)
given_back_inputs next_e
+ | Loop _ -> raise (Failure "Unimplemented")
and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value)
(e : S.expression) (ctx : bs_ctx) : texpression =