summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compiler/Pure.ml17
-rw-r--r--compiler/SymbolicToPure.ml186
-rw-r--r--compiler/Translate.ml29
3 files changed, 151 insertions, 81 deletions
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index 05cdbd70..71531688 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -886,23 +886,6 @@ type inputs_info = {
}
[@@deriving show]
-type ('a, 'b) back_info =
- | SingleBack of 'a
- (** Information about a single backward function, if pertinent.
-
- We use this variant if we split the forward and the backward functions.
- *)
- | AllBacks of 'b RegionGroupId.Map.t
- (** Information about the various backward functions.
-
- We use this if we *do not* split the forward and the backward functions.
- All the information is then carried by the forward function.
- *)
-[@@deriving show]
-
-type back_inputs_info = (inputs_info option, inputs_info) back_info
-[@@deriving show]
-
(** Meta information about a function signature *)
type fun_sig_info = {
fwd_info : inputs_info;
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 70a4e18d..37f621e4 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -146,6 +146,7 @@ type bs_ctx = {
global_ctx : global_ctx;
trait_decls_ctx : trait_decls_ctx;
trait_impls_ctx : trait_impls_ctx;
+ fun_dsigs : decomposed_fun_sig FunDeclId.Map.t;
fun_decl : A.fun_decl;
bid : RegionGroupId.id option;
(** TODO: rename
@@ -890,7 +891,7 @@ let mk_fuel_input_as_list (ctx : bs_ctx) (info : fun_effect_info) :
if function_uses_fuel info then [ mk_fuel_texpression ctx.fuel ] else []
(** Small utility. *)
-let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t)
+let compute_raw_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t)
(fun_id : A.fun_id_or_trait_method_ref) (lid : V.LoopId.id option)
(gid : T.RegionGroupId.id option) : fun_effect_info =
match fun_id with
@@ -917,6 +918,22 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t)
is_rec = false;
}
+(** TODO: not very clean. *)
+let get_fun_effect_info (ctx : bs_ctx) (fun_id : A.fun_id_or_trait_method_ref)
+ (lid : V.LoopId.id option) (gid : T.RegionGroupId.id option) :
+ fun_effect_info =
+ match fun_id with
+ | TraitMethod (_, _, fid) | FunId (FRegular fid) ->
+ let dsg = A.FunDeclId.Map.find fid ctx.fun_dsigs in
+ let info =
+ match gid with
+ | None -> dsg.fwd_info.effect_info
+ | Some gid -> (RegionGroupId.Map.find gid dsg.back_sg).effect_info
+ in
+ { info with is_rec = info.is_rec || Option.is_some lid }
+ | FunId (FAssumed _) ->
+ compute_raw_fun_effect_info ctx.fun_ctx.fun_infos fun_id lid gid
+
(** Translate a function signature to a decomposed function signature.
Note that the function also takes a list of names for the inputs, and
@@ -962,7 +979,9 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed
in
(* Is the forward function stateful, and can it fail? *)
- let fwd_effect_info = get_fun_effect_info fun_infos fun_id None None in
+ let fwd_effect_info =
+ compute_raw_fun_effect_info fun_infos fun_id None None
+ in
(* Compute the forward inputs *)
let fwd_fuel = mk_fuel_input_ty_as_list fwd_effect_info in
let fwd_inputs_no_fuel_no_state =
@@ -1051,12 +1070,23 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed
RegionGroupId.id * back_sg_info =
let gid = rg.id in
let back_effect_info =
- get_fun_effect_info fun_infos fun_id None (Some gid)
+ compute_raw_fun_effect_info fun_infos fun_id None (Some gid)
in
let inputs_no_state = translate_back_inputs_for_gid gid in
let inputs_no_state =
List.map (fun ty -> (Some "ret", ty)) inputs_no_state
in
+ (* We consider the backward function as stateful and potentially failing
+ **only if it has inputs** (for the "potentially failing": if it has
+ not inputs, we directly evaluate it in the body of the forward function).
+ *)
+ let back_effect_info =
+ {
+ back_effect_info with
+ stateful = back_effect_info.stateful && inputs_no_state <> [];
+ can_fail = back_effect_info.can_fail && inputs_no_state <> [];
+ }
+ in
let state =
if back_effect_info.stateful then [ (None, mk_state_ty) ] else []
in
@@ -1140,6 +1170,19 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx)
translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx
(FunId (FRegular fun_id)) regions_hierarchy sg input_names
+let translate_fun_sig_from_decl_to_decomposed (decls_ctx : C.decls_ctx)
+ (fdef : LlbcAst.fun_decl) : decomposed_fun_sig =
+ let input_names =
+ match fdef.body with
+ | None -> List.map (fun _ -> None) fdef.signature.inputs
+ | Some body ->
+ List.map
+ (fun (v : LlbcAst.var) -> v.name)
+ (LlbcAstUtils.fun_body_get_input_vars body)
+ in
+ translate_fun_sig_to_decomposed decls_ctx fdef.def_id fdef.signature
+ input_names
+
let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty
=
let output =
@@ -1158,8 +1201,9 @@ let mk_back_output_ty_from_effect_info (effect_info : fun_effect_info)
If a backward function has no inputs/outputs we filter it.
*)
-let compute_back_tys (dsg : Pure.decomposed_fun_sig)
- (subst : (generic_args * trait_instance_id) option) : ty option list =
+let compute_back_tys_with_info (dsg : Pure.decomposed_fun_sig)
+ (subst : (generic_args * trait_instance_id) option) :
+ (back_sg_info * ty) option list =
List.map
(fun (back_sg : back_sg_info) ->
let effect_info = back_sg.effect_info in
@@ -1185,9 +1229,13 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig)
in
ty_substitute subst ty
in
- Some ty)
+ Some (back_sg, ty))
(RegionGroupId.Map.values dsg.back_sg)
+let compute_back_tys (dsg : Pure.decomposed_fun_sig)
+ (subst : (generic_args * trait_instance_id) option) : ty option list =
+ List.map (Option.map snd) (compute_back_tys_with_info dsg subst)
+
(** In case we merge the fwd/back functions: compute the output type of
a function, from a decomposed signature. *)
let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty =
@@ -1363,6 +1411,7 @@ let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var option list =
in
fresh_opt_vars back_vars ctx
+(** IMPORTANT: do not use this one directly, but rather {!symbolic_value_to_texpression} *)
let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var =
match V.SymbolicValueId.Map.find_opt sv.sv_id ctx.sv_to_var with
| Some v -> v
@@ -1381,12 +1430,22 @@ let rec unbox_typed_value (v : V.typed_value) : V.typed_value =
| _ -> raise (Failure "Unreachable"))
| _ -> v
-(** Translate a symbolic value *)
+(** Translate a symbolic value.
+
+ Because we do not necessarily introduce variables for the symbolic values
+ of (translated) type unit, it is important that we do not lookup variables
+ in case the symbolic value has type unit.
+ *)
let symbolic_value_to_texpression (ctx : bs_ctx) (sv : V.symbolic_value) :
texpression =
(* Translate the type *)
- let var = lookup_var_for_symbolic_value sv ctx in
- mk_texpression_from_var var
+ let ty = ctx_translate_fwd_ty ctx sv.sv_ty in
+ (* If the type is unit, directly return unit *)
+ if ty_is_unit ty then mk_unit_rvalue
+ else
+ (* Otherwise lookup the variable *)
+ let var = lookup_var_for_symbolic_value sv ctx in
+ mk_texpression_from_var var
(** Translate a typed value.
@@ -1565,13 +1624,11 @@ and aproj_to_consumed (ctx : bs_ctx) (aproj : V.aproj) : texpression option =
match aproj with
| V.AEndedProjLoans (msv, []) ->
(* The symbolic value was left unchanged *)
- let var = lookup_var_for_symbolic_value msv ctx in
- Some (mk_texpression_from_var var)
+ Some (symbolic_value_to_texpression ctx msv)
| V.AEndedProjLoans (_, [ (mnv, child_aproj) ]) ->
assert (child_aproj = AIgnoredProjBorrows);
(* The symbolic value was updated *)
- let var = lookup_var_for_symbolic_value mnv ctx in
- Some (mk_texpression_from_var var)
+ Some (symbolic_value_to_texpression ctx mnv)
| V.AEndedProjLoans (_, _) ->
(* The symbolic value was updated, and the given back values come from sevearl
* abstractions *)
@@ -1940,10 +1997,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
(List.combine args args_mplaces)
in
let dest_mplace = translate_opt_mplace call.dest_place in
- let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
(* Retrieve the function id, and register the function call in the context
* if necessary. *)
- let ctx, fun_id, effect_info, args, back_funs, out_state =
+ let ctx, fun_id, effect_info, args, dest_v =
match call.call_id with
| S.Fun (fid, call_id) ->
(* Regular function call *)
@@ -1951,13 +2007,11 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
let func = Fun (FromLlbc (fid_t, None, None)) in
(* Retrieve the effect information about this function (can fail,
* takes a state as input, etc.) *)
- let effect_info =
- get_fun_effect_info ctx.fun_ctx.fun_infos fid None None
- in
+ let effect_info = get_fun_effect_info ctx fid None None in
(* Depending on the function effects:
- * - add the fuel
- * - add the state input argument
- * - generate a fresh state variable for the returned state
+ - add the fuel
+ - add the state input argument
+ - generate a fresh state variable for the returned state
*)
let args, ctx, out_state =
let fuel = mk_fuel_input_as_list ctx effect_info in
@@ -1970,7 +2024,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
(* If we do not split the forward/backward functions: generate the
variables for the backward functions returned by the forward
function. *)
- let ctx, back_funs_map, back_funs =
+ let ctx, ignore_fwd_output, back_funs_map, back_funs =
if !Config.return_back_funs then
(* We need to compute the signatures of the backward functions. *)
let sg = Option.get call.sg in
@@ -1981,7 +2035,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
(List.map (fun _ -> None) sg.inputs)
in
let tr_self = UnknownTrait __FUNCTION__ in
- let back_tys = compute_back_tys dsg (Some (generics, tr_self)) in
+ let back_tys =
+ compute_back_tys_with_info dsg (Some (generics, tr_self))
+ in
(* Introduce variables for the backward functions *)
(* Compute a proper basename for the variables *)
let back_fun_name =
@@ -2016,7 +2072,18 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
(fun ty ->
match ty with
| None -> None
- | Some ty -> Some (Some back_fun_name, ty))
+ | Some (back_sg, ty) ->
+ (* We insert a name for the variable only if the function
+ can fail: if it can fail, it means the call returns a backward
+ function. Otherwise, we it directly returns the value given
+ back by the backward function, which means we shouldn't
+ give it a name like "back..." (it doesn't make sense) *)
+ let name =
+ if back_sg.effect_info.can_fail then
+ Some back_fun_name
+ else None
+ in
+ Some (name, ty))
back_tys)
ctx
in
@@ -2039,14 +2106,37 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
let back_funs_map =
RegionGroupId.Map.of_list (List.combine gids back_vars)
in
- (ctx, Some back_funs_map, back_funs)
- else (ctx, None, [])
+ (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs)
+ else (ctx, false, None, [])
+ in
+ (* Compute the pattern for the destination *)
+ let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
+ let dest = mk_typed_pattern_from_var dest dest_mplace in
+ let dest =
+ (* Here there is something subtle: as we might ignore the output
+ of the forward function (because it translates to unit) we doNOT
+ necessarily introduce in the let-binding the variable to which we
+ map the symbolic value which was introduced for the output of the
+ function call. This would be problematic if later we need to
+ translate this symbolic value, but we implemented
+ {!symbolic_value_to_texpression} so that it doesn't perform any
+ lookups if the symbolic value has type unit.
+ *)
+ let vars =
+ if ignore_fwd_output then back_funs else dest :: back_funs
+ in
+ mk_simpl_tuple_pattern vars
+ in
+ let dest =
+ match out_state with
+ | None -> dest
+ | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ]
in
(* Register the function call *)
let ctx =
bs_ctx_register_forward_call call_id call args back_funs_map ctx
in
- (ctx, func, effect_info, args, back_funs, out_state)
+ (ctx, func, effect_info, args, dest)
| S.Unop E.Not ->
let effect_info =
{
@@ -2057,7 +2147,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
is_rec = false;
}
in
- (ctx, Unop Not, effect_info, args, [], None)
+ let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
+ let dest = mk_typed_pattern_from_var dest dest_mplace in
+ (ctx, Unop Not, effect_info, args, dest)
| S.Unop E.Neg -> (
match args with
| [ arg ] ->
@@ -2073,7 +2165,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
is_rec = false;
}
in
- (ctx, Unop (Neg int_ty), effect_info, args, [], None)
+ let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
+ let dest = mk_typed_pattern_from_var dest dest_mplace in
+ (ctx, Unop (Neg int_ty), effect_info, args, dest)
| _ -> raise (Failure "Unreachable"))
| S.Unop (E.Cast cast_kind) -> (
match cast_kind with
@@ -2088,7 +2182,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
is_rec = false;
}
in
- (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, [], None)
+ let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
+ let dest = mk_typed_pattern_from_var dest dest_mplace in
+ (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, dest)
| CastFnPtr _ -> raise (Failure "TODO: function casts"))
| S.Binop binop -> (
match args with
@@ -2108,16 +2204,11 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
is_rec = false;
}
in
- (ctx, Binop (binop, int_ty0), effect_info, args, [], None)
+ let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
+ let dest = mk_typed_pattern_from_var dest dest_mplace in
+ (ctx, Binop (binop, int_ty0), effect_info, args, dest)
| _ -> raise (Failure "Unreachable"))
in
- let dest_v =
- let dest = mk_typed_pattern_from_var dest dest_mplace in
- let dest = mk_simpl_tuple_pattern (dest :: back_funs) in
- match out_state with
- | None -> dest
- | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ]
- in
let func = { id = FunOrOp fun_id; generics } in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
let ret_ty =
@@ -2242,9 +2333,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
(* Those don't have backward functions *)
raise (Failure "Unreachable")
in
- let effect_info =
- get_fun_effect_info ctx.fun_ctx.fun_infos fun_id None (Some rg_id)
- in
+ let effect_info = get_fun_effect_info ctx fun_id None (Some rg_id) in
let generics = ctx_translate_fwd_generic_args ctx call.generics in
(* Retrieve the original call and the parent abstractions *)
let _forward, backwards = get_abs_ancestors ctx abs call_id in
@@ -2449,8 +2538,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
to a forward call which happened earlier *)
let fun_id = E.FRegular ctx.fun_decl.def_id in
let effect_info =
- get_fun_effect_info ctx.fun_ctx.fun_infos (FunId fun_id) (Some vloop_id)
- (Some rg_id)
+ get_fun_effect_info ctx (FunId fun_id) (Some vloop_id) (Some rg_id)
in
let loop_info = LoopId.Map.find loop_id ctx.loops in
let generics = loop_info.generics in
@@ -2609,8 +2697,7 @@ and translate_assertion (ectx : C.eval_ctx) (v : V.typed_value)
and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
(exp : S.expansion) (ctx : bs_ctx) : texpression =
(* Translate the scrutinee *)
- let scrutinee_var = lookup_var_for_symbolic_value sv ctx in
- let scrutinee = mk_texpression_from_var scrutinee_var in
+ let scrutinee = symbolic_value_to_texpression ctx sv in
let scrutinee_mplace = translate_opt_mplace p in
(* Translate the branches *)
match exp with
@@ -2999,7 +3086,7 @@ and translate_forward_end (ectx : C.eval_ctx)
(* Compute whether the backward expressions should be evaluated straight
away or not (i.e., if we should bind them with monadic let-bindings
or not). We evaluate them straight away if they can fail and have no
- inputs *)
+ inputs. *)
let evaluate_backs =
List.map
(fun (sg : back_sg_info) ->
@@ -3098,9 +3185,7 @@ and translate_forward_end (ectx : C.eval_ctx)
(* Lookup the effect info for the loop function *)
let fid = E.FRegular ctx.fun_decl.def_id in
- let effect_info =
- get_fun_effect_info ctx.fun_ctx.fun_infos (FunId fid) None ctx.bid
- in
+ let effect_info = get_fun_effect_info ctx (FunId fid) None ctx.bid in
(* Introduce a fresh output value for the forward function *)
let ctx, output_var = fresh_var None ctx.sg.fwd_output ctx in
@@ -3479,8 +3564,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
| None -> None
| Some body ->
let effect_info =
- get_fun_effect_info ctx.fun_ctx.fun_infos (FunId (FRegular def_id))
- None bid
+ get_fun_effect_info ctx (FunId (FRegular def_id)) None bid
in
let body = translate_expression body ctx in
(* Add a match over the fuel, if necessary *)
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 5584fb9a..ccc46420 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -42,7 +42,8 @@ let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : fun_decl) :
TODO: maybe we should introduce a record for this.
*)
let translate_function_to_pure (trans_ctx : trans_ctx)
- (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : fun_decl) :
+ (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t)
+ (fun_dsigs : Pure.decomposed_fun_sig FunDeclId.Map.t) (fdef : fun_decl) :
pure_fun_translation_no_loops =
(* Debug *)
log#ldebug
@@ -119,18 +120,8 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
!m
in
- let input_names =
- match fdef.body with
- | None -> List.map (fun _ -> None) fdef.signature.inputs
- | Some body ->
- List.map
- (fun (v : var) -> v.name)
- (LlbcAstUtils.fun_body_get_input_vars body)
- in
-
let sg =
- SymbolicToPure.translate_fun_sig_to_decomposed trans_ctx def_id
- fdef.signature input_names
+ SymbolicToPure.translate_fun_sig_from_decl_to_decomposed trans_ctx fdef
in
let regions_hierarchy =
@@ -154,6 +145,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
decls_ctx = trans_ctx;
SymbolicToPure.bid = None;
sg;
+ fun_dsigs;
(* Will need to be updated for the backward functions *)
sv_to_var;
var_counter = ref var_counter;
@@ -290,10 +282,21 @@ let translate_crate_to_pure (crate : crate) :
(List.map (fun (def : Pure.type_decl) -> (def.def_id, def)) type_decls)
in
+ (* Compute the decomposed fun sigs for the whole crate *)
+ let fun_dsigs =
+ FunDeclId.Map.of_list
+ (List.map
+ (fun (fdef : LlbcAst.fun_decl) ->
+ ( fdef.def_id,
+ SymbolicToPure.translate_fun_sig_from_decl_to_decomposed trans_ctx
+ fdef ))
+ (FunDeclId.Map.values crate.fun_decls))
+ in
+
(* Translate all the *transparent* functions *)
let pure_translations =
List.map
- (translate_function_to_pure trans_ctx type_decls_map)
+ (translate_function_to_pure trans_ctx type_decls_map fun_dsigs)
(FunDeclId.Map.values crate.fun_decls)
in