summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/AssociatedTypes.ml10
-rw-r--r--compiler/Config.ml89
-rw-r--r--compiler/Contexts.ml42
-rw-r--r--compiler/Extract.ml157
-rw-r--r--compiler/ExtractBase.ml83
-rw-r--r--compiler/ExtractBuiltin.ml8
-rw-r--r--compiler/Interpreter.ml8
-rw-r--r--compiler/InterpreterBorrows.ml12
-rw-r--r--compiler/InterpreterExpansion.ml2
-rw-r--r--compiler/InterpreterExpressions.ml5
-rw-r--r--compiler/InterpreterExpressions.mli2
-rw-r--r--compiler/InterpreterLoopsFixedPoint.ml18
-rw-r--r--compiler/InterpreterLoopsJoinCtxs.ml36
-rw-r--r--compiler/InterpreterLoopsMatchCtxs.ml4
-rw-r--r--compiler/InterpreterStatements.ml76
-rw-r--r--compiler/InterpreterUtils.ml26
-rw-r--r--compiler/Invariants.ml2
-rw-r--r--compiler/Main.ml6
-rw-r--r--compiler/Print.ml10
-rw-r--r--compiler/PrintPure.ml82
-rw-r--r--compiler/Pure.ml176
-rw-r--r--compiler/PureMicroPasses.ml550
-rw-r--r--compiler/PureTypeCheck.ml8
-rw-r--r--compiler/PureUtils.ml98
-rw-r--r--compiler/SymbolicAst.ml8
-rw-r--r--compiler/SymbolicToPure.ml1902
-rw-r--r--compiler/SynthesizeSymbolic.ml21
-rw-r--r--compiler/Translate.ml211
28 files changed, 2404 insertions, 1248 deletions
diff --git a/compiler/AssociatedTypes.ml b/compiler/AssociatedTypes.ml
index e2f687e8..054c8169 100644
--- a/compiler/AssociatedTypes.ml
+++ b/compiler/AssociatedTypes.ml
@@ -493,11 +493,11 @@ let norm_ctx_normalize_trait_type_constraint (ctx : norm_ctx)
let mk_norm_ctx (ctx : eval_ctx) : norm_ctx =
{
norm_trait_types = ctx.norm_trait_types;
- type_decls = ctx.type_context.type_decls;
- fun_decls = ctx.fun_context.fun_decls;
- global_decls = ctx.global_context.global_decls;
- trait_decls = ctx.trait_decls_context.trait_decls;
- trait_impls = ctx.trait_impls_context.trait_impls;
+ type_decls = ctx.type_ctx.type_decls;
+ fun_decls = ctx.fun_ctx.fun_decls;
+ global_decls = ctx.global_ctx.global_decls;
+ trait_decls = ctx.trait_decls_ctx.trait_decls;
+ trait_impls = ctx.trait_impls_ctx.trait_impls;
type_vars = ctx.type_vars;
const_generic_vars = ctx.const_generic_vars;
}
diff --git a/compiler/Config.ml b/compiler/Config.ml
index b09544ba..2bb1ca34 100644
--- a/compiler/Config.ml
+++ b/compiler/Config.ml
@@ -92,6 +92,69 @@ let loop_fixed_point_max_num_iters = 2
(** {1 Translation} *)
+(** If true, do not define separate forward/backward functions, but make the
+ forward functions return the backward function.
+
+ Example:
+ {[
+ (* Rust *)
+ pub fn list_nth<'a, T>(l: &'a mut List<T>, i: u32) -> &'a mut T {
+ match l {
+ List::Nil => {
+ panic!()
+ }
+ List::Cons(x, tl) => {
+ if i == 0 {
+ x
+ } else {
+ list_nth(tl, i - 1)
+ }
+ }
+ }
+ }
+
+ (* Translation, if return_back_funs = false *)
+ def list_nth (T : Type) (l : List T) (i : U32) : Result T :=
+ match l with
+ | List.Cons x tl =>
+ if i = 0#u32
+ then Result.ret x
+ else do
+ let i0 ← i - 1#u32
+ list_nth T tl i0
+ | List.Nil => Result.fail .panic
+
+ def list_nth_back
+ (T : Type) (l : List T) (i : U32) (ret : T) : Result (List T) :=
+ match l with
+ | List.Cons x tl =>
+ if i = 0#u32
+ then Result.ret (List.Cons ret tl)
+ else
+ do
+ let i0 ← i - 1#u32
+ let tl0 ← list_nth_back T tl i0 ret
+ Result.ret (List.Cons x tl0)
+ | List.Nil => Result.fail .panic
+
+ (* Translation, if return_back_funs = true *)
+ def list_nth (T: Type) (ls : List T) (i : U32) :
+ Result (T × (T → Result (List T))) :=
+ match ls with
+ | List.Cons x tl =>
+ if i = 0#u32
+ then Result.ret (x, (λ ret => return (ret :: ls)))
+ else do
+ let i0 ← i - 1#u32
+ let (x, back) ← list_nth ls i0
+ Return.ret (x,
+ (λ ret => do
+ let ls ← back ret
+ return (x :: ls)))
+ ]}
+ *)
+let return_back_funs = ref true
+
(** Forbids using field projectors for structures.
If we don't use field projectors, whenever we symbolically expand a structure
@@ -307,6 +370,32 @@ let filter_useless_monadic_calls = ref true
*)
let filter_useless_functions = ref true
+(** Simplify the forward/backward functions, in case we merge them
+ (i.e., the forward functions return the backward functions).
+
+ The simplification occurs as follows:
+ - if a forward function returns the unit type and has non-trivial backward
+ functions, then we remove the returned output.
+ - if a backward function doesn't have inputs, we evaluate it inside the
+ forward function and don't wrap it in a result.
+
+ Example:
+ {[
+ // LLBC:
+ fn incr(x: &mut u32) { *x += 1 }
+
+ // Translation without simplification:
+ let incr (x : u32) : result (unit * result u32) = ...
+ ^^^^ ^^^^^^
+ | remove this result
+ remove the unit
+
+ // Translation with simplification:
+ let incr (x : u32) : result u32 = ...
+ ]}
+ *)
+let simplify_merged_fwd_backs = ref true
+
(** Use short names for the record fields.
Some backends can't disambiguate records when their field names have collisions.
diff --git a/compiler/Contexts.ml b/compiler/Contexts.ml
index a30ed0f1..5d646a61 100644
--- a/compiler/Contexts.ml
+++ b/compiler/Contexts.ml
@@ -180,35 +180,35 @@ type config = {
let mk_config (mode : interpreter_mode) : config = { mode }
-type type_context = {
+type type_ctx = {
type_decls_groups : type_declaration_group TypeDeclId.Map.t;
type_decls : type_decl TypeDeclId.Map.t;
type_infos : TypesAnalysis.type_infos;
}
[@@deriving show]
-type fun_context = {
+type fun_ctx = {
fun_decls : fun_decl FunDeclId.Map.t;
fun_infos : FunsAnalysis.fun_info FunDeclId.Map.t;
regions_hierarchies : region_var_groups FunIdMap.t;
}
[@@deriving show]
-type global_context = { global_decls : global_decl GlobalDeclId.Map.t }
+type global_ctx = { global_decls : global_decl GlobalDeclId.Map.t }
[@@deriving show]
-type trait_decls_context = { trait_decls : trait_decl TraitDeclId.Map.t }
+type trait_decls_ctx = { trait_decls : trait_decl TraitDeclId.Map.t }
[@@deriving show]
-type trait_impls_context = { trait_impls : trait_impl TraitImplId.Map.t }
+type trait_impls_ctx = { trait_impls : trait_impl TraitImplId.Map.t }
[@@deriving show]
type decls_ctx = {
- type_ctx : type_context;
- fun_ctx : fun_context;
- global_ctx : global_context;
- trait_decls_ctx : trait_decls_context;
- trait_impls_ctx : trait_impls_context;
+ type_ctx : type_ctx;
+ fun_ctx : fun_ctx;
+ global_ctx : global_ctx;
+ trait_decls_ctx : trait_decls_ctx;
+ trait_impls_ctx : trait_impls_ctx;
}
[@@deriving show]
@@ -230,11 +230,11 @@ module TraitTypeRefMap = Collections.MakeMap (TraitTypeRefOrd)
(** Evaluation context *)
type eval_ctx = {
- type_context : type_context;
- fun_context : fun_context;
- global_context : global_context;
- trait_decls_context : trait_decls_context;
- trait_impls_context : trait_impls_context;
+ type_ctx : type_ctx;
+ fun_ctx : fun_ctx;
+ global_ctx : global_ctx;
+ trait_decls_ctx : trait_decls_ctx;
+ trait_impls_ctx : trait_impls_ctx;
region_groups : RegionGroupId.id list;
type_vars : type_var list;
const_generic_vars : const_generic_var list;
@@ -290,20 +290,20 @@ let ctx_lookup_var_binder (ctx : eval_ctx) (vid : VarId.id) : var_binder =
fst (env_lookup_var ctx.env vid)
let ctx_lookup_type_decl (ctx : eval_ctx) (tid : TypeDeclId.id) : type_decl =
- TypeDeclId.Map.find tid ctx.type_context.type_decls
+ TypeDeclId.Map.find tid ctx.type_ctx.type_decls
let ctx_lookup_fun_decl (ctx : eval_ctx) (fid : FunDeclId.id) : fun_decl =
- FunDeclId.Map.find fid ctx.fun_context.fun_decls
+ FunDeclId.Map.find fid ctx.fun_ctx.fun_decls
let ctx_lookup_global_decl (ctx : eval_ctx) (gid : GlobalDeclId.id) :
global_decl =
- GlobalDeclId.Map.find gid ctx.global_context.global_decls
+ GlobalDeclId.Map.find gid ctx.global_ctx.global_decls
let ctx_lookup_trait_decl (ctx : eval_ctx) (id : TraitDeclId.id) : trait_decl =
- TraitDeclId.Map.find id ctx.trait_decls_context.trait_decls
+ TraitDeclId.Map.find id ctx.trait_decls_ctx.trait_decls
let ctx_lookup_trait_impl (ctx : eval_ctx) (id : TraitImplId.id) : trait_impl =
- TraitImplId.Map.find id ctx.trait_impls_context.trait_impls
+ TraitImplId.Map.find id ctx.trait_impls_ctx.trait_impls
(** Retrieve a variable's value in the current frame *)
let env_lookup_var_value (env : env) (vid : VarId.id) : typed_value =
@@ -528,7 +528,7 @@ let ctx_set_abs_can_end (ctx : eval_ctx) (abs_id : AbstractionId.id)
fst (ctx_subst_abs ctx abs_id abs)
let ctx_type_decl_is_rec (ctx : eval_ctx) (id : TypeDeclId.id) : bool =
- let decl_group = TypeDeclId.Map.find id ctx.type_context.type_decls_groups in
+ let decl_group = TypeDeclId.Map.find id ctx.type_ctx.type_decls_groups in
match decl_group with RecGroup _ -> true | NonRecGroup _ -> false
(** Visitor to iterate over the values in the *current* frame *)
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 20cdb20b..87dcb1fd 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -43,8 +43,12 @@ let extract_fun_decl_register_names (ctx : extraction_ctx)
}
| _ -> ctx
in
- let backs = List.map (fun f -> f.f) def.backs in
- let funs = if def.keep_fwd then def.fwd.f :: backs else backs in
+ let funs =
+ if !Config.return_back_funs then [ def.fwd.f ]
+ else
+ let backs = List.map (fun f -> f.f) def.backs in
+ if def.keep_fwd then def.fwd.f :: backs else backs
+ in
List.fold_left
(fun ctx (f : fun_decl) ->
let open ExtractBuiltin in
@@ -128,9 +132,15 @@ let extract_adt_g_value
F.pp_print_string fmt "tt";
ctx)
else
- (* If there is exactly one value, we don't print the parentheses *)
+ (* If there is exactly one value, we don't print the parentheses.
+ Also, for Coq, we need the special syntax ['(...)] if we destruct
+ a tuple pattern in a let-binding and the tuple has > 2 values.
+ *)
let lb, rb =
- if List.length field_values = 1 then ("", "") else ("(", ")")
+ if List.length field_values = 1 then ("", "")
+ else if !backend = Coq && is_single_pat && List.length field_values > 2
+ then ("'(", ")")
+ else ("(", ")")
in
F.pp_print_string fmt lb;
let ctx =
@@ -237,30 +247,60 @@ let fun_builtin_filter_types (id : FunDeclId.id) (types : 'a list)
Result.Ok types
(** [inside]: see {!extract_ty}.
+ [with_type]: do we also generate a type annotation? This is necessary for
+ backends like Coq when we write lambdas (Coq is not powerful enough to
+ infer the type).
As a pattern can introduce new variables, we return an extraction context
updated with new bindings.
*)
let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter)
- (is_let : bool) (inside : bool) (v : typed_pattern) : extraction_ctx =
- match v.value with
- | PatConstant cv ->
- extract_literal fmt inside cv;
- ctx
- | PatVar (v, _) ->
- let vname = ctx_compute_var_basename ctx v.basename v.ty in
- let ctx, vname = ctx_add_var vname v.id ctx in
- F.pp_print_string fmt vname;
- ctx
- | PatDummy ->
- F.pp_print_string fmt "_";
- ctx
- | PatAdt av ->
- let extract_value ctx inside v =
- extract_typed_pattern ctx fmt is_let inside v
- in
- extract_adt_g_value extract_value fmt ctx is_let inside av.variant_id
- av.field_values v.ty
+ (is_let : bool) (inside : bool) ?(with_type = false) (v : typed_pattern) :
+ extraction_ctx =
+ if with_type then F.pp_print_string fmt "(";
+ let inside = inside && not with_type in
+ let ctx =
+ match v.value with
+ | PatConstant cv ->
+ extract_literal fmt inside cv;
+ ctx
+ | PatVar (v, _) ->
+ let vname = ctx_compute_var_basename ctx v.basename v.ty in
+ let ctx, vname = ctx_add_var vname v.id ctx in
+ F.pp_print_string fmt vname;
+ ctx
+ | PatDummy ->
+ F.pp_print_string fmt "_";
+ ctx
+ | PatAdt av ->
+ let extract_value ctx inside v =
+ extract_typed_pattern ctx fmt is_let inside v
+ in
+ extract_adt_g_value extract_value fmt ctx is_let inside av.variant_id
+ av.field_values v.ty
+ in
+ if with_type then (
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ extract_ty ctx fmt TypeDeclId.Set.empty false v.ty;
+ F.pp_print_string fmt ")");
+ ctx
+
+(** Return true if we need to wrap a succession of let-bindings in a [do ...]
+ block (because some of them are monadic) *)
+let lets_require_wrap_in_do (lets : (bool * typed_pattern * texpression) list) :
+ bool =
+ match !backend with
+ | Lean ->
+ (* For Lean, we wrap in a block iff at least one of the let-bindings is monadic *)
+ List.exists (fun (m, _, _) -> m) lets
+ | HOL4 ->
+ (* HOL4 is similar to HOL4, but we add a sanity check *)
+ let wrap_in_do = List.exists (fun (m, _, _) -> m) lets in
+ if wrap_in_do then assert (List.for_all (fun (m, _, _) -> m) lets);
+ wrap_in_do
+ | FStar | Coq -> false
(** [inside]: controls the introduction of parentheses. See [extract_ty]
@@ -285,9 +325,9 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter)
| App _ ->
let app, args = destruct_apps e in
extract_App ctx fmt inside app args
- | Abs _ ->
- let xl, e = destruct_abs_list e in
- extract_Abs ctx fmt inside xl e
+ | Lambda _ ->
+ let xl, e = destruct_lambdas e in
+ extract_Lambda ctx fmt inside xl e
| Qualif _ ->
(* We use the app case *)
extract_App ctx fmt inside e []
@@ -574,7 +614,7 @@ and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter)
(* No argument: shouldn't happen *)
raise (Failure "Unreachable")
-and extract_Abs (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
+and extract_Lambda (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(xl : typed_pattern list) (e : texpression) : unit =
(* Open a box for the abs expression *)
F.pp_open_hovbox fmt ctx.indent_incr;
@@ -583,15 +623,16 @@ and extract_Abs (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(* Print the lambda - note that there should always be at least one variable *)
assert (xl <> []);
F.pp_print_string fmt "fun";
+ let with_type = !backend = Coq in
let ctx =
List.fold_left
(fun ctx x ->
F.pp_print_space fmt ();
- extract_typed_pattern ctx fmt true true x)
+ extract_typed_pattern ctx fmt true true ~with_type x)
ctx xl
in
F.pp_print_space fmt ();
- if !backend = Lean then F.pp_print_string fmt "=>"
+ if !backend = Lean || !backend = Coq then F.pp_print_string fmt "=>"
else F.pp_print_string fmt "->";
F.pp_print_space fmt ();
(* Print the body *)
@@ -630,15 +671,6 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
| HOL4 -> destruct_lets_no_interleave e
| FStar | Coq | Lean -> destruct_lets e
in
- (* Open a box for the whole expression.
-
- In the case of Lean, we use a vbox so that line breaks are inserted
- at the end of every let-binding: let-bindings are indeed not ended
- with an "in" keyword.
- *)
- if !Config.backend = Lean then F.pp_open_vbox fmt 0 else F.pp_open_hvbox fmt 0;
- (* Open parentheses *)
- if inside && !backend <> Lean then F.pp_print_string fmt "(";
(* Extract the let-bindings *)
let extract_let (ctx : extraction_ctx) (monadic : bool) (lv : typed_pattern)
(re : texpression) : extraction_ctx =
@@ -711,22 +743,19 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(* Return *)
ctx
in
+ (* Open a box for the whole expression.
+
+ In the case of Lean, we use a vbox so that line breaks are inserted
+ at the end of every let-binding: let-bindings are indeed not ended
+ with an "in" keyword.
+ *)
+ if !Config.backend = Lean then F.pp_open_vbox fmt 0 else F.pp_open_hvbox fmt 0;
+ (* Open parentheses *)
+ if inside && !backend <> Lean then F.pp_print_string fmt "(";
(* If Lean and HOL4, we rely on monadic blocks, so we insert a do and open a new box
immediately *)
- let wrap_in_do_od =
- match !backend with
- | Lean ->
- (* For Lean, we wrap in a block iff at least one of the let-bindings is monadic *)
- List.exists (fun (m, _, _) -> m) lets
- | HOL4 ->
- (* HOL4 is similar to HOL4, but we add a sanity check *)
- let wrap_in_do = List.exists (fun (m, _, _) -> m) lets in
- if wrap_in_do then assert (List.for_all (fun (m, _, _) -> m) lets);
- wrap_in_do
- | FStar | Coq -> false
- in
+ let wrap_in_do_od = lets_require_wrap_in_do lets in
if wrap_in_do_od then (
- F.pp_open_vbox fmt (if !backend = Lean then ctx.indent_incr else 0);
F.pp_print_string fmt "do";
F.pp_print_space fmt ());
let ctx =
@@ -742,11 +771,10 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
F.pp_close_box fmt ();
(* do-box (Lean and HOL4 only) *)
- if wrap_in_do_od then (
+ if wrap_in_do_od then
if !backend = HOL4 then (
F.pp_print_space fmt ();
F.pp_print_string fmt "od");
- F.pp_close_box fmt ());
(* Close parentheses *)
if inside && !backend <> Lean then F.pp_print_string fmt ")";
(* Close the box for the whole expression *)
@@ -1319,16 +1347,16 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter)
(Pure.FunId (FRegular def.def_id), def.loop_id, def.back_id)
ctx.fun_name_info
in
- let comment_pre = "[" ^ name_to_string ctx def.llbc_name ^ "]: " in
+ let comment_pre = "[" ^ name_to_string ctx def.llbc_name ^ "]:" in
let comment =
let loop_comment =
match def.loop_id with
| None -> ""
- | Some id -> "loop " ^ LoopId.to_string id ^ ": "
+ | Some id -> " loop " ^ LoopId.to_string id ^ ":"
in
let fwd_back_comment =
match def.back_id with
- | None -> [ "forward function" ]
+ | None -> if !Config.return_back_funs then [] else [ "forward function" ]
| Some id ->
(* Check if there is only one backward function, and no forward function *)
if (not keep_fwd) && num_backs = 1 then
@@ -1340,9 +1368,9 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter)
else [ "backward function " ^ T.RegionGroupId.to_string id ]
in
match fwd_back_comment with
- | [] -> raise (Failure "Unreachable")
- | [ s ] -> [ comment_pre ^ loop_comment ^ s ]
- | s :: sl -> (comment_pre ^ loop_comment ^ s) :: sl
+ | [] -> [ comment_pre ^ loop_comment ]
+ | [ s ] -> [ comment_pre ^ loop_comment ^ " " ^ s ]
+ | s :: sl -> (comment_pre ^ loop_comment ^ " " ^ s) :: sl
in
extract_comment_with_span fmt comment def.meta.span
@@ -1470,7 +1498,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
let inputs_lvs =
let all_inputs = (Option.get def.body).inputs_lvs in
let num_fwd_inputs =
- def.signature.info.num_fwd_inputs_with_fuel_with_state
+ def.signature.fwd_info.fwd_info.num_inputs_with_fuel_with_state
in
Collections.List.prefix num_fwd_inputs all_inputs
in
@@ -1516,7 +1544,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
let def_body = Option.get def.body in
let all_vars = List.map (fun (v : var) -> v.id) def_body.inputs in
let num_fwd_inputs =
- def.signature.info.num_fwd_inputs_with_fuel_with_state
+ def.signature.fwd_info.fwd_info.num_inputs_with_fuel_with_state
in
let vars = Collections.List.prefix num_fwd_inputs all_vars in
@@ -1794,7 +1822,6 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter)
assert body.is_global_decl_body;
assert (Option.is_none body.back_id);
assert (body.signature.inputs = []);
- assert (List.length body.signature.doutputs = 1);
assert (body.signature.generics = empty_generic_params);
(* Add a break then the name of the corresponding LLBC declaration *)
@@ -1813,7 +1840,8 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter)
let decl_ty, body_ty =
let ty = body.signature.output in
- if body.signature.info.effect_info.can_fail then (unwrap_result_ty ty, ty)
+ if body.signature.fwd_info.effect_info.can_fail then
+ (unwrap_result_ty ty, ty)
else (ty, mk_result_ty ty)
in
match body.body with
@@ -1984,7 +2012,8 @@ let extract_trait_decl_method_names (ctx : extraction_ctx)
(* We add one field per required forward/backward function *)
let get_funs_for_id (id : fun_decl_id) : fun_decl list =
let trans : pure_fun_translation = FunDeclId.Map.find id ctx.trans_funs in
- List.map (fun f -> f.f) (trans.fwd :: trans.backs)
+ if !Config.return_back_funs then [ trans.fwd.f ]
+ else List.map (fun f -> f.f) (trans.fwd :: trans.backs)
in
match builtin_info with
| None ->
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index eb2a2ec9..db887539 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -351,7 +351,7 @@ let basename_to_unique (names_set : StringSet.t)
let s = append basename i in
if StringSet.mem s names_set then gen (i + 1) else s
in
- if StringSet.mem basename names_set then gen 0 else basename
+ if StringSet.mem basename names_set then gen 1 else basename
type fun_name_info = { keep_fwd : bool; num_backs : int }
@@ -1051,33 +1051,60 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list =
let assumed_llbc_functions () :
(A.assumed_fun_id * T.RegionGroupId.id option * string) list =
let rg0 = Some T.RegionGroupId.zero in
- match !backend with
- | FStar | Coq | HOL4 ->
- [
- (ArrayIndexShared, None, "array_index_usize");
- (ArrayIndexMut, None, "array_index_usize");
- (ArrayIndexMut, rg0, "array_update_usize");
- (ArrayToSliceShared, None, "array_to_slice");
- (ArrayToSliceMut, None, "array_to_slice");
- (ArrayToSliceMut, rg0, "array_from_slice");
- (ArrayRepeat, None, "array_repeat");
- (SliceIndexShared, None, "slice_index_usize");
- (SliceIndexMut, None, "slice_index_usize");
- (SliceIndexMut, rg0, "slice_update_usize");
- ]
- | Lean ->
- [
- (ArrayIndexShared, None, "Array.index_usize");
- (ArrayIndexMut, None, "Array.index_usize");
- (ArrayIndexMut, rg0, "Array.update_usize");
- (ArrayToSliceShared, None, "Array.to_slice");
- (ArrayToSliceMut, None, "Array.to_slice");
- (ArrayToSliceMut, rg0, "Array.from_slice");
- (ArrayRepeat, None, "Array.repeat");
- (SliceIndexShared, None, "Slice.index_usize");
- (SliceIndexMut, None, "Slice.index_usize");
- (SliceIndexMut, rg0, "Slice.update_usize");
- ]
+ let regular : (A.assumed_fun_id * T.RegionGroupId.id option * string) list =
+ match !backend with
+ | FStar | Coq | HOL4 ->
+ [
+ (ArrayIndexShared, None, "array_index_usize");
+ (ArrayToSliceShared, None, "array_to_slice");
+ (ArrayRepeat, None, "array_repeat");
+ (SliceIndexShared, None, "slice_index_usize");
+ ]
+ | Lean ->
+ [
+ (ArrayIndexShared, None, "Array.index_usize");
+ (ArrayToSliceShared, None, "Array.to_slice");
+ (ArrayRepeat, None, "Array.repeat");
+ (SliceIndexShared, None, "Slice.index_usize");
+ ]
+ in
+ let mut_funs : (A.assumed_fun_id * T.RegionGroupId.id option * string) list =
+ if !Config.return_back_funs then
+ match !backend with
+ | FStar | Coq | HOL4 ->
+ [
+ (ArrayIndexMut, None, "array_index_mut_usize");
+ (ArrayToSliceMut, None, "array_to_slice_mut");
+ (SliceIndexMut, None, "slice_index_mut_usize");
+ ]
+ | Lean ->
+ [
+ (ArrayIndexMut, None, "Array.index_mut_usize");
+ (ArrayToSliceMut, None, "Array.to_slice_mut");
+ (SliceIndexMut, None, "Slice.index_mut_usize");
+ ]
+ else
+ match !backend with
+ | FStar | Coq | HOL4 ->
+ [
+ (ArrayIndexMut, None, "array_index_usize");
+ (ArrayIndexMut, rg0, "array_update_usize");
+ (ArrayToSliceMut, None, "array_to_slice");
+ (ArrayToSliceMut, rg0, "array_from_slice");
+ (SliceIndexMut, None, "slice_index_usize");
+ (SliceIndexMut, rg0, "slice_update_usize");
+ ]
+ | Lean ->
+ [
+ (ArrayIndexMut, None, "Array.index_usize");
+ (ArrayIndexMut, rg0, "Array.update_usize");
+ (ArrayToSliceMut, None, "Array.to_slice");
+ (ArrayToSliceMut, rg0, "Array.from_slice");
+ (SliceIndexMut, None, "Slice.index_usize");
+ (SliceIndexMut, rg0, "Slice.update_usize");
+ ]
+ in
+ regular @ mut_funs
let assumed_pure_functions () : (pure_assumed_fun_id * string) list =
match !backend with
diff --git a/compiler/ExtractBuiltin.ml b/compiler/ExtractBuiltin.ml
index 24d16dca..ee8d4831 100644
--- a/compiler/ExtractBuiltin.ml
+++ b/compiler/ExtractBuiltin.ml
@@ -232,6 +232,14 @@ let builtin_funs () : (pattern * bool list option * builtin_fun_info list) list
let mk_fun (rust_name : string) (extract_name : string option)
(filter : bool list option) (with_back : bool) (back_no_suffix : bool) :
pattern * bool list option * builtin_fun_info list =
+ (* [back_no_suffix] is used to control whether the backward function should
+ have the suffix "_back" or not (if not, then the forward function has the
+ prefix "_fwd", and is filtered anyway). This is pertinent only if we split
+ the fwd/back functions. *)
+ let back_no_suffix = back_no_suffix && not !Config.return_back_funs in
+ (* Same for the [with_back] option: this is pertinent only if we split
+ the fwd/back functions *)
+ let with_back = with_back && not !Config.return_back_funs in
let rust_name =
try parse_pattern rust_name
with Failure _ ->
diff --git a/compiler/Interpreter.ml b/compiler/Interpreter.ml
index 76432faa..22d176c9 100644
--- a/compiler/Interpreter.ml
+++ b/compiler/Interpreter.ml
@@ -195,7 +195,7 @@ let initialize_symbolic_context_for_fun (ctx : decls_ctx) (fdef : fun_decl) :
List.map (fun (g : region_var_group) -> g.id) regions_hierarchy
in
let ctx =
- initialize_eval_context ctx region_groups sg.generics.types
+ initialize_eval_ctx ctx region_groups sg.generics.types
sg.generics.const_generics
in
(* Instantiate the signature. This updates the context because we compute
@@ -277,7 +277,7 @@ let evaluate_function_symbolic_synthesize_backward_from_return (config : config)
* an instantiation of the signature, so that we use fresh
* region ids for the return abstractions. *)
let regions_hierarchy =
- FunIdMap.find (FRegular fdef.def_id) ctx.fun_context.regions_hierarchies
+ FunIdMap.find (FRegular fdef.def_id) ctx.fun_ctx.regions_hierarchies
in
let _, ret_inst_sg =
symbolic_instantiate_fun_sig ctx fdef.signature regions_hierarchy fdef.kind
@@ -466,7 +466,7 @@ let evaluate_function_symbolic (synthesize : bool) (ctx : decls_ctx)
let ctx, input_svs, inst_sg = initialize_symbolic_context_for_fun ctx fdef in
let regions_hierarchy =
- FunIdMap.find (FRegular fdef.def_id) ctx.fun_context.regions_hierarchies
+ FunIdMap.find (FRegular fdef.def_id) ctx.fun_ctx.regions_hierarchies
in
(* Create the continuation to finish the evaluation *)
@@ -615,7 +615,7 @@ module Test = struct
assert (body.arg_count = 0);
(* Create the evaluation context *)
- let ctx = initialize_eval_context decls_ctx [] [] [] in
+ let ctx = initialize_eval_ctx decls_ctx [] [] [] in
(* Insert the (uninitialized) local variables *)
let ctx = ctx_push_uninitialized_vars ctx body.locals in
diff --git a/compiler/InterpreterBorrows.ml b/compiler/InterpreterBorrows.ml
index e56919fa..a2eb2545 100644
--- a/compiler/InterpreterBorrows.ml
+++ b/compiler/InterpreterBorrows.ml
@@ -1628,7 +1628,7 @@ let destructure_abs (abs_kind : abs_kind) (can_end : bool)
push { value; ty }
| AIgnoredMutLoan (opt_bid, child_av) ->
(* We don't support nested borrows for now *)
- assert (not (ty_has_borrows ctx.type_context.type_infos child_av.ty));
+ assert (not (ty_has_borrows ctx.type_ctx.type_infos child_av.ty));
assert (opt_bid = None);
(* Simply explore the child *)
list_avalues false push_fail child_av
@@ -1639,7 +1639,7 @@ let destructure_abs (abs_kind : abs_kind) (can_end : bool)
{ child = child_av; given_back = _; given_back_meta = _ }
| AIgnoredSharedLoan child_av ->
(* We don't support nested borrows for now *)
- assert (not (ty_has_borrows ctx.type_context.type_infos child_av.ty));
+ assert (not (ty_has_borrows ctx.type_ctx.type_infos child_av.ty));
(* Simply explore the child *)
list_avalues false push_fail child_av)
| ABorrow bc -> (
@@ -1659,14 +1659,14 @@ let destructure_abs (abs_kind : abs_kind) (can_end : bool)
push av
| AIgnoredMutBorrow (opt_bid, child_av) ->
(* We don't support nested borrows for now *)
- assert (not (ty_has_borrows ctx.type_context.type_infos child_av.ty));
+ assert (not (ty_has_borrows ctx.type_ctx.type_infos child_av.ty));
assert (opt_bid = None);
(* Just explore the child *)
list_avalues false push_fail child_av
| AEndedIgnoredMutBorrow
{ child = child_av; given_back = _; given_back_meta = _ } ->
(* We don't support nested borrows for now *)
- assert (not (ty_has_borrows ctx.type_context.type_infos child_av.ty));
+ assert (not (ty_has_borrows ctx.type_ctx.type_infos child_av.ty));
(* Just explore the child *)
list_avalues false push_fail child_av
| AProjSharedBorrow asb ->
@@ -1683,7 +1683,7 @@ let destructure_abs (abs_kind : abs_kind) (can_end : bool)
| ASymbolic _ ->
(* For now, we fore all symbolic values containing borrows to be eagerly
expanded *)
- assert (not (ty_has_borrows ctx.type_context.type_infos ty))
+ assert (not (ty_has_borrows ctx.type_ctx.type_infos ty))
and list_values (v : typed_value) : typed_avalue list * typed_value =
let ty = v.ty in
match v.value with
@@ -1732,7 +1732,7 @@ let destructure_abs (abs_kind : abs_kind) (can_end : bool)
| VSymbolic _ ->
(* For now, we fore all symbolic values containing borrows to be eagerly
expanded *)
- assert (not (ty_has_borrows ctx.type_context.type_infos ty));
+ assert (not (ty_has_borrows ctx.type_ctx.type_infos ty));
([], v)
in
diff --git a/compiler/InterpreterExpansion.ml b/compiler/InterpreterExpansion.ml
index bbf4d9d5..e489ddc3 100644
--- a/compiler/InterpreterExpansion.ml
+++ b/compiler/InterpreterExpansion.ml
@@ -627,7 +627,7 @@ let greedy_expand_symbolics_with_borrows (config : config) : cm_fun =
inherit [_] iter_eval_ctx
method! visit_VSymbolic _ sv =
- if ty_has_borrows ctx.type_context.type_infos sv.sv_ty then
+ if ty_has_borrows ctx.type_ctx.type_infos sv.sv_ty then
raise (FoundSymbolicValue sv)
else ()
diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml
index 1b5b79dd..8536b4ab 100644
--- a/compiler/InterpreterExpressions.ml
+++ b/compiler/InterpreterExpressions.ml
@@ -32,8 +32,7 @@ let expand_primitively_copyable_at_place (config : config)
fun cf ctx ->
let v = read_place access p ctx in
match
- find_first_primitively_copyable_sv_with_borrows
- ctx.type_context.type_infos v
+ find_first_primitively_copyable_sv_with_borrows ctx.type_ctx.type_infos v
with
| None -> cf ctx
| Some sv ->
@@ -351,7 +350,7 @@ let eval_operand_no_reorganize (config : config) (op : operand)
assert (
Option.is_none
(find_first_primitively_copyable_sv_with_borrows
- ctx.type_context.type_infos v));
+ ctx.type_ctx.type_infos v));
(* Actually perform the copy *)
let allow_adt_copy = false in
let ctx, v = copy_value allow_adt_copy config ctx v in
diff --git a/compiler/InterpreterExpressions.mli b/compiler/InterpreterExpressions.mli
index f8d979f4..b975371c 100644
--- a/compiler/InterpreterExpressions.mli
+++ b/compiler/InterpreterExpressions.mli
@@ -52,7 +52,7 @@ val eval_operands :
Transmits the computed rvalue to the received continuation.
- Note that this function fails on {!constructor:Aeneas.Expressions.rvalue.Discriminant}: discriminant
+ Note that this function fails on {!Aeneas.Expressions.rvalue.Discriminant}: discriminant
reads should have been eliminated from the AST.
*)
val eval_rvalue_not_global :
diff --git a/compiler/InterpreterLoopsFixedPoint.ml b/compiler/InterpreterLoopsFixedPoint.ml
index c4e180fa..4dabe974 100644
--- a/compiler/InterpreterLoopsFixedPoint.ml
+++ b/compiler/InterpreterLoopsFixedPoint.ml
@@ -300,7 +300,7 @@ let prepare_ashared_loans (loop_id : LoopId.id option) : cm_fun =
let env = List.append fresh_absl env in
let ctx = { ctx with env } in
- let _, new_ctx_ids_map = compute_context_ids ctx in
+ let _, new_ctx_ids_map = compute_ctx_ids ctx in
(* Synthesize *)
match cf ctx with
@@ -385,8 +385,8 @@ let compute_loop_entry_fixed_point (config : config) (loop_id : LoopId.id)
match !fixed_ids with
| Some _ -> ctx1
| None ->
- let old_ids, _ = compute_context_ids ctx1 in
- let new_ids, _ = compute_contexts_ids !ctxs in
+ let old_ids, _ = compute_ctx_ids ctx1 in
+ let new_ids, _ = compute_ctxs_ids !ctxs in
let blids = BorrowId.Set.diff old_ids.blids new_ids.blids in
let aids = AbstractionId.Set.diff old_ids.aids new_ids.aids in
(* End those borrows and abstractions *)
@@ -409,7 +409,7 @@ let compute_loop_entry_fixed_point (config : config) (loop_id : LoopId.id)
ctxs := List.map (end_borrows_abs blids aids) !ctxs;
(* Note that the fixed ids are given by the original context, from *before*
we introduce fresh abstractions/reborrows for the shared values *)
- fixed_ids := Some (fst (compute_context_ids ctx0));
+ fixed_ids := Some (fst (compute_ctx_ids ctx0));
ctx1
in
@@ -424,12 +424,12 @@ let compute_loop_entry_fixed_point (config : config) (loop_id : LoopId.id)
intersection of ids between the original environment and the list
of new environments *)
let compute_fixed_ids (ctxl : eval_ctx list) : ids_sets =
- let fixed_ids, _ = compute_context_ids ctx0 in
+ let fixed_ids, _ = compute_ctx_ids ctx0 in
let { aids; blids; borrow_ids; loan_ids; dids; rids; sids } = fixed_ids in
let sids = ref sids in
List.iter
(fun ctx ->
- let fixed_ids, _ = compute_context_ids ctx in
+ let fixed_ids, _ = compute_ctx_ids ctx in
sids := SymbolicValueId.Set.inter !sids fixed_ids.sids)
ctxl;
let sids = !sids in
@@ -568,7 +568,7 @@ let compute_loop_entry_fixed_point (config : config) (loop_id : LoopId.id)
InterpreterBorrows.end_abstraction_no_synth config abs_id ctx
in
(* Explore the context, and check which abstractions are not there anymore *)
- let ids, _ = compute_context_ids ctx in
+ let ids, _ = compute_ctx_ids ctx in
let ended_ids = AbstractionId.Set.diff !fp_aids ids.aids in
add_ended_aids rg_id ended_ids)
ctx.region_groups
@@ -840,8 +840,8 @@ let compute_fixed_point_id_correspondance (fixed_ids : ids_sets)
let compute_fp_ctx_symbolic_values (ctx : eval_ctx) (fp_ctx : eval_ctx) :
SymbolicValueId.Set.t * symbolic_value list =
- let old_ids, _ = compute_context_ids ctx in
- let fp_ids, fp_ids_maps = compute_context_ids fp_ctx in
+ let old_ids, _ = compute_ctx_ids ctx in
+ let fp_ids, fp_ids_maps = compute_ctx_ids fp_ctx in
let fresh_sids = SymbolicValueId.Set.diff fp_ids.sids old_ids.sids in
(* Compute the set of symbolic values which appear in shared values inside
diff --git a/compiler/InterpreterLoopsJoinCtxs.ml b/compiler/InterpreterLoopsJoinCtxs.ml
index 8d485483..445e5abf 100644
--- a/compiler/InterpreterLoopsJoinCtxs.ml
+++ b/compiler/InterpreterLoopsJoinCtxs.ml
@@ -326,8 +326,8 @@ let mk_collapse_ctx_merge_duplicate_funs (loop_id : LoopId.id) (ctx : eval_ctx)
let _ =
let _, ty0, _ = ty_as_ref ty0 in
let _, ty1, _ = ty_as_ref ty1 in
- assert (not (ty_has_borrows ctx.type_context.type_infos ty0));
- assert (not (ty_has_borrows ctx.type_context.type_infos ty1))
+ assert (not (ty_has_borrows ctx.type_ctx.type_infos ty0));
+ assert (not (ty_has_borrows ctx.type_ctx.type_infos ty1))
in
(* Same remarks as for [merge_amut_borrows] *)
@@ -543,11 +543,11 @@ let join_ctxs (loop_id : LoopId.id) (fixed_ids : ids_sets) (ctx0 : eval_ctx)
(* Construct the joined context - of course, the type, fun, etc. contexts
* should be the same in the two contexts *)
let {
- type_context;
- fun_context;
- global_context;
- trait_decls_context;
- trait_impls_context;
+ type_ctx;
+ fun_ctx;
+ global_ctx;
+ trait_decls_ctx;
+ trait_impls_ctx;
region_groups;
type_vars;
const_generic_vars;
@@ -559,11 +559,11 @@ let join_ctxs (loop_id : LoopId.id) (fixed_ids : ids_sets) (ctx0 : eval_ctx)
ctx0
in
let {
- type_context = _;
- fun_context = _;
- global_context = _;
- trait_decls_context = _;
- trait_impls_context = _;
+ type_ctx = _;
+ fun_ctx = _;
+ global_ctx = _;
+ trait_decls_ctx = _;
+ trait_impls_ctx = _;
region_groups = _;
type_vars = _;
const_generic_vars = _;
@@ -577,11 +577,11 @@ let join_ctxs (loop_id : LoopId.id) (fixed_ids : ids_sets) (ctx0 : eval_ctx)
let ended_regions = RegionId.Set.union ended_regions0 ended_regions1 in
Ok
{
- type_context;
- fun_context;
- global_context;
- trait_decls_context;
- trait_impls_context;
+ type_ctx;
+ fun_ctx;
+ global_ctx;
+ trait_decls_ctx;
+ trait_impls_ctx;
region_groups;
type_vars;
const_generic_vars;
@@ -621,7 +621,7 @@ let destructure_new_abs (loop_id : LoopId.id)
contexts we join don't have non-fixed abstractions with the same ids.
*)
let refresh_abs (old_abs : AbstractionId.Set.t) (ctx : eval_ctx) : eval_ctx =
- let ids, _ = compute_context_ids ctx in
+ let ids, _ = compute_ctx_ids ctx in
let abs_to_refresh = AbstractionId.Set.diff ids.aids old_abs in
let aids_subst =
List.map
diff --git a/compiler/InterpreterLoopsMatchCtxs.ml b/compiler/InterpreterLoopsMatchCtxs.ml
index 90559c29..2a688fa7 100644
--- a/compiler/InterpreterLoopsMatchCtxs.ml
+++ b/compiler/InterpreterLoopsMatchCtxs.ml
@@ -658,7 +658,7 @@ module MakeJoinMatcher (S : MatchJoinState) : PrimMatcher = struct
else (
(* The caller should have checked that the symbolic values don't contain
borrows *)
- assert (not (ty_has_borrows S.ctx.type_context.type_infos sv0.sv_ty));
+ assert (not (ty_has_borrows S.ctx.type_ctx.type_infos sv0.sv_ty));
(* We simply introduce a fresh symbolic value *)
mk_fresh_symbolic_value sv0.sv_ty)
@@ -669,7 +669,7 @@ module MakeJoinMatcher (S : MatchJoinState) : PrimMatcher = struct
- there are no borrows in the "regular" value
If there are loans in the regular value, raise an exception.
*)
- assert (not (ty_has_borrows S.ctx.type_context.type_infos sv.sv_ty));
+ assert (not (ty_has_borrows S.ctx.type_ctx.type_infos sv.sv_ty));
assert (not (value_has_borrows S.ctx v.value));
let value_is_left = not left in
(match InterpreterBorrowsCore.get_first_loan_in_value v with
diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml
index 30b7b333..97c8bcd6 100644
--- a/compiler/InterpreterStatements.ml
+++ b/compiler/InterpreterStatements.ml
@@ -728,7 +728,13 @@ let create_push_abstractions_from_abs_region_groups
to a trait clause but directly to the method provided in the trait declaration.
*)
let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx)
- : fun_id_or_trait_method_ref * generic_args * fun_decl * inst_fun_sig =
+ :
+ fun_id_or_trait_method_ref
+ * generic_args
+ * (generic_args * trait_instance_id) option
+ * fun_decl
+ * region_var_groups
+ * inst_fun_sig =
match call.func with
| FnOpMove _ ->
(* Closure case: TODO *)
@@ -747,13 +753,13 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx)
let tr_self = UnknownTrait __FUNCTION__ in
let regions_hierarchy =
LlbcAstUtils.FunIdMap.find (FRegular fid)
- ctx.fun_context.regions_hierarchies
+ ctx.fun_ctx.regions_hierarchies
in
let inst_sg =
instantiate_fun_sig ctx func.generics tr_self def.signature
regions_hierarchy
in
- (func.func, func.generics, def, inst_sg)
+ (func.func, func.generics, None, def, regions_hierarchy, inst_sg)
| FunId (FAssumed _) ->
(* Unreachable: must be a transparent function *)
raise (Failure "Unreachable")
@@ -793,7 +799,7 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx)
let fid : fun_id = FRegular id in
let regions_hierarchy =
LlbcAstUtils.FunIdMap.find fid
- ctx.fun_context.regions_hierarchies
+ ctx.fun_ctx.regions_hierarchies
in
let inst_sg =
instantiate_fun_sig ctx generics tr_self
@@ -806,7 +812,12 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx)
we also need to update the generics.
*)
let func = FunId fid in
- (func, generics, method_def, inst_sg)
+ ( func,
+ generics,
+ Some (generics, tr_self),
+ method_def,
+ regions_hierarchy,
+ inst_sg )
| None ->
(* If not found, lookup the methods provided by the trait *declaration*
(remember: for now, we forbid overriding provided methods) *)
@@ -853,14 +864,19 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx)
method_def.signature.parent_params_info));
let regions_hierarchy =
LlbcAstUtils.FunIdMap.find (FRegular method_id)
- ctx.fun_context.regions_hierarchies
+ ctx.fun_ctx.regions_hierarchies
in
let tr_self = TraitRef trait_ref in
let inst_sg =
instantiate_fun_sig ctx all_generics tr_self
method_def.signature regions_hierarchy
in
- (func.func, func.generics, method_def, inst_sg))
+ ( func.func,
+ func.generics,
+ Some (all_generics, tr_self),
+ method_def,
+ regions_hierarchy,
+ inst_sg ))
| _ ->
(* We are using a local clause - we lookup the trait decl *)
let trait_decl =
@@ -884,14 +900,19 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx)
(* Instantiate *)
let regions_hierarchy =
LlbcAstUtils.FunIdMap.find (FRegular method_id)
- ctx.fun_context.regions_hierarchies
+ ctx.fun_ctx.regions_hierarchies
in
let tr_self = TraitRef trait_ref in
let inst_sg =
instantiate_fun_sig ctx generics tr_self method_def.signature
regions_hierarchy
in
- (func.func, func.generics, method_def, inst_sg)))
+ ( func.func,
+ func.generics,
+ Some (generics, tr_self),
+ method_def,
+ regions_hierarchy,
+ inst_sg )))
(** Evaluate a statement *)
let rec eval_statement (config : config) (st : statement) : st_cm_fun =
@@ -1277,14 +1298,15 @@ and eval_transparent_function_call_concrete (config : config)
and eval_transparent_function_call_symbolic (config : config) (call : call) :
st_cm_fun =
fun cf ctx ->
- let func, generics, def, inst_sg =
+ let func, generics, trait_method_generics, def, regions_hierarchy, inst_sg =
eval_transparent_function_call_symbolic_inst call ctx
in
(* Sanity check *)
assert (List.length call.args = List.length def.signature.inputs);
(* Evaluate the function call *)
- eval_function_call_symbolic_from_inst_sig config func inst_sg generics
- call.args call.dest cf ctx
+ eval_function_call_symbolic_from_inst_sig config func def.signature
+ regions_hierarchy inst_sg generics trait_method_generics call.args call.dest
+ cf ctx
(** Evaluate a function call in symbolic mode by using the function signature.
@@ -1298,8 +1320,11 @@ and eval_transparent_function_call_symbolic (config : config) (call : call) :
trait ref as input.
*)
and eval_function_call_symbolic_from_inst_sig (config : config)
- (fid : fun_id_or_trait_method_ref) (inst_sg : inst_fun_sig)
- (generics : generic_args) (args : operand list) (dest : place) : st_cm_fun =
+ (fid : fun_id_or_trait_method_ref) (sg : fun_sig)
+ (regions_hierarchy : region_var_groups) (inst_sg : inst_fun_sig)
+ (generics : generic_args)
+ (trait_method_generics : (generic_args * trait_instance_id) option)
+ (args : operand list) (dest : place) : st_cm_fun =
fun cf ctx ->
log#ldebug
(lazy
@@ -1378,8 +1403,9 @@ and eval_function_call_symbolic_from_inst_sig (config : config)
let expr = cf ctx in
(* Synthesize the symbolic AST *)
- S.synthesize_regular_function_call fid call_id ctx abs_ids generics args
- args_places ret_spc dest_place expr
+ S.synthesize_regular_function_call fid call_id ctx sg regions_hierarchy
+ abs_ids generics trait_method_generics args args_places ret_spc dest_place
+ expr
in
let cc = comp cc cf_call in
@@ -1450,7 +1476,7 @@ and eval_assumed_function_call_symbolic (config : config) (fid : assumed_fun_id)
* this is a current limitation of our synthesis *)
assert (
List.for_all
- (fun ty -> not (ty_has_borrows ctx.type_context.type_infos ty))
+ (fun ty -> not (ty_has_borrows ctx.type_ctx.type_infos ty))
generics.types);
(* There are two cases (and this is extremely annoying):
@@ -1468,7 +1494,7 @@ and eval_assumed_function_call_symbolic (config : config) (fid : assumed_fun_id)
(* In symbolic mode, the behaviour of a function call is completely defined
* by the signature of the function: we thus simply generate correctly
* instantiated signatures, and delegate the work to an auxiliary function *)
- let inst_sig =
+ let sg, regions_hierarchy, inst_sig =
match fid with
| BoxFree ->
(* Should have been treated above *)
@@ -1476,18 +1502,20 @@ and eval_assumed_function_call_symbolic (config : config) (fid : assumed_fun_id)
| _ ->
let regions_hierarchy =
LlbcAstUtils.FunIdMap.find (FAssumed fid)
- ctx.fun_context.regions_hierarchies
+ ctx.fun_ctx.regions_hierarchies
in
(* There shouldn't be any reference to Self *)
let tr_self = UnknownTrait __FUNCTION__ in
- instantiate_fun_sig ctx generics tr_self
- (Assumed.get_assumed_fun_sig fid)
- regions_hierarchy
+ let sg = Assumed.get_assumed_fun_sig fid in
+ let inst_sg =
+ instantiate_fun_sig ctx generics tr_self sg regions_hierarchy
+ in
+ (sg, regions_hierarchy, inst_sg)
in
(* Evaluate the function call *)
- eval_function_call_symbolic_from_inst_sig config (FunId (FAssumed fid))
- inst_sig generics args dest cf ctx
+ eval_function_call_symbolic_from_inst_sig config (FunId (FAssumed fid)) sg
+ regions_hierarchy inst_sig generics None args dest cf ctx
(** Evaluate a statement seen as a function body *)
and eval_function_body (config : config) (body : statement) : st_cm_fun =
diff --git a/compiler/InterpreterUtils.ml b/compiler/InterpreterUtils.ml
index e04a6b90..a1a06ee5 100644
--- a/compiler/InterpreterUtils.ml
+++ b/compiler/InterpreterUtils.ml
@@ -265,7 +265,7 @@ let value_has_ret_symbolic_value_with_borrow_under_mut (ctx : eval_ctx)
inherit [_] iter_typed_value
method! visit_symbolic_value _ s =
- if ty_has_borrow_under_mut ctx.type_context.type_infos s.sv_ty then
+ if ty_has_borrow_under_mut ctx.type_ctx.type_infos s.sv_ty then
raise Found
else ()
end
@@ -288,15 +288,15 @@ let rvalue_get_place (rv : rvalue) : place option =
(** See {!ValuesUtils.symbolic_value_has_borrows} *)
let symbolic_value_has_borrows (ctx : eval_ctx) (sv : symbolic_value) : bool =
- ValuesUtils.symbolic_value_has_borrows ctx.type_context.type_infos sv
+ ValuesUtils.symbolic_value_has_borrows ctx.type_ctx.type_infos sv
(** See {!ValuesUtils.value_has_borrows}. *)
let value_has_borrows (ctx : eval_ctx) (v : value) : bool =
- ValuesUtils.value_has_borrows ctx.type_context.type_infos v
+ ValuesUtils.value_has_borrows ctx.type_ctx.type_infos v
(** See {!ValuesUtils.value_has_loans_or_borrows}. *)
let value_has_loans_or_borrows (ctx : eval_ctx) (v : value) : bool =
- ValuesUtils.value_has_loans_or_borrows ctx.type_context.type_infos v
+ ValuesUtils.value_has_loans_or_borrows ctx.type_ctx.type_infos v
(** See {!ValuesUtils.value_has_loans}. *)
let value_has_loans (v : value) : bool = ValuesUtils.value_has_loans v
@@ -401,19 +401,19 @@ let compute_env_elem_ids (x : env_elem) : ids_sets * ids_to_values =
compute_env_ids [ x ]
(** Compute the sets of ids found in a list of contexts. *)
-let compute_contexts_ids (ctxl : eval_ctx list) : ids_sets * ids_to_values =
+let compute_ctxs_ids (ctxl : eval_ctx list) : ids_sets * ids_to_values =
let compute, get_ids, get_ids_to_values = compute_ids () in
List.iter (compute#visit_eval_ctx ()) ctxl;
(get_ids (), get_ids_to_values ())
(** Compute the sets of ids found in a context. *)
-let compute_context_ids (ctx : eval_ctx) : ids_sets * ids_to_values =
- compute_contexts_ids [ ctx ]
+let compute_ctx_ids (ctx : eval_ctx) : ids_sets * ids_to_values =
+ compute_ctxs_ids [ ctx ]
(** **WARNING**: this function doesn't compute the normalized types
(for the trait type aliases). This should be computed afterwards.
*)
-let initialize_eval_context (ctx : decls_ctx)
+let initialize_eval_ctx (ctx : decls_ctx)
(region_groups : RegionGroupId.id list) (type_vars : type_var list)
(const_generic_vars : const_generic_var list) : eval_ctx =
reset_global_counters ();
@@ -427,11 +427,11 @@ let initialize_eval_context (ctx : decls_ctx)
const_generic_vars)
in
{
- type_context = ctx.type_ctx;
- fun_context = ctx.fun_ctx;
- global_context = ctx.global_ctx;
- trait_decls_context = ctx.trait_decls_ctx;
- trait_impls_context = ctx.trait_impls_ctx;
+ type_ctx = ctx.type_ctx;
+ fun_ctx = ctx.fun_ctx;
+ global_ctx = ctx.global_ctx;
+ trait_decls_ctx = ctx.trait_decls_ctx;
+ trait_impls_ctx = ctx.trait_impls_ctx;
region_groups;
type_vars;
const_generic_vars;
diff --git a/compiler/Invariants.ml b/compiler/Invariants.ml
index fa0d7436..b87cdff7 100644
--- a/compiler/Invariants.ml
+++ b/compiler/Invariants.ml
@@ -768,7 +768,7 @@ let check_symbolic_values (ctx : eval_ctx) : unit =
assert (info.env_count = 0 || info.aproj_borrows = []);
(* A symbolic value containing borrows can't be duplicated (i.e., copied):
* it must be expanded first *)
- if ty_has_borrows ctx.type_context.type_infos info.ty then
+ if ty_has_borrows ctx.type_ctx.type_infos info.ty then
assert (info.env_count <= 1);
(* A duplicated symbolic value is necessarily primitively copyable *)
assert (info.env_count <= 1 || ty_is_primitively_copyable info.ty);
diff --git a/compiler/Main.ml b/compiler/Main.ml
index 835b9088..0b8ec439 100644
--- a/compiler/Main.ml
+++ b/compiler/Main.ml
@@ -120,6 +120,9 @@ let () =
" Generate a default lakefile.lean (Lean only)" );
("-print-llbc", Arg.Set print_llbc, " Print the imported LLBC");
("-k", Arg.Clear fail_hard, " Do not fail hard in case of error");
+ ( "-split-fwd-back",
+ Arg.Clear return_back_funs,
+ " Split the forward and backward functions." );
]
in
@@ -193,9 +196,6 @@ let () =
let _ =
match !backend with
| FStar ->
- (* Some patterns are not supported *)
- decompose_monadic_let_bindings := false;
- decompose_nested_let_patterns := false;
(* F* can disambiguate the field names *)
record_fields_short_names := true
| Coq ->
diff --git a/compiler/Print.ml b/compiler/Print.ml
index 0e2ec1fc..8999c77d 100644
--- a/compiler/Print.ml
+++ b/compiler/Print.ml
@@ -409,11 +409,11 @@ module Contexts = struct
}
let eval_ctx_to_fmt_env (ctx : eval_ctx) : fmt_env =
- let type_decls = ctx.type_context.type_decls in
- let fun_decls = ctx.fun_context.fun_decls in
- let global_decls = ctx.global_context.global_decls in
- let trait_decls = ctx.trait_decls_context.trait_decls in
- let trait_impls = ctx.trait_impls_context.trait_impls in
+ let type_decls = ctx.type_ctx.type_decls in
+ let fun_decls = ctx.fun_ctx.fun_decls in
+ let global_decls = ctx.global_ctx.global_decls in
+ let trait_decls = ctx.trait_decls_ctx.trait_decls in
+ let trait_impls = ctx.trait_impls_ctx.trait_impls in
(* Below: it is always safe to omit fields - if an id can't be found at
printing time, we print the id (in raw form) instead of the name it
designates. *)
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index 2fe5843e..66475d02 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -103,10 +103,13 @@ let adt_field_names (env : fmt_env) =
Print.Types.adt_field_names (fmt_env_to_llbc_fmt_env env)
let option_to_string = Print.option_to_string
-let type_var_to_string = Print.Types.type_var_to_string
-let const_generic_var_to_string = Print.Types.const_generic_var_to_string
-let integer_type_to_string = Print.Values.integer_type_to_string
let literal_type_to_string = Print.Values.literal_type_to_string
+let type_var_to_string (v : type_var) = "(" ^ v.name ^ ": Type)"
+
+let const_generic_var_to_string (v : const_generic_var) =
+ "(" ^ v.name ^ " : " ^ literal_type_to_string v.ty ^ ")"
+
+let integer_type_to_string = Print.Values.integer_type_to_string
let scalar_value_to_string = Print.Values.scalar_value_to_string
let literal_to_string = Print.Values.literal_to_string
@@ -203,13 +206,12 @@ and trait_instance_id_to_string (env : fmt_env) (inside : bool)
| UnknownTrait msg -> "UNKNOWN(" ^ msg ^ ")"
let trait_clause_to_string (env : fmt_env) (clause : trait_clause) : string =
- let clause_id = trait_clause_id_to_string env clause.clause_id in
let trait_id = trait_decl_id_to_string env clause.trait_id in
let generics = generic_args_to_strings env true clause.generics in
let generics =
if generics = [] then "" else " " ^ String.concat " " generics
in
- "[" ^ clause_id ^ "]: " ^ trait_id ^ generics
+ trait_id ^ generics
let generic_params_to_strings (env : fmt_env) (generics : generic_params) :
string list =
@@ -543,9 +545,9 @@ let rec texpression_to_string (env : fmt_env) (inside : bool) (indent : string)
let app, args = destruct_apps e in
(* Convert to string *)
app_to_string env inside indent indent_incr app args
- | Abs _ ->
- let xl, e = destruct_abs_list e in
- let e = abs_to_string env indent indent_incr xl e in
+ | Lambda _ ->
+ let xl, e = destruct_lambdas e in
+ let e = lambda_to_string env indent indent_incr xl e in
if inside then "(" ^ e ^ ")" else e
| Qualif _ ->
(* Qualifier without arguments *)
@@ -609,35 +611,36 @@ and app_to_string (env : fmt_env) (inside : bool) (indent : string)
* expression *)
let app, generics =
match app.e with
- | Qualif qualif ->
+ | Qualif qualif -> (
(* Qualifier case *)
- (* Convert the qualifier identifier *)
- let qualif_s =
- match qualif.id with
- | FunOrOp fun_id -> fun_or_op_id_to_string env fun_id
- | Global global_id -> global_decl_id_to_string env global_id
- | AdtCons adt_cons_id ->
- let variant_s =
- adt_variant_to_string env adt_cons_id.adt_id
- adt_cons_id.variant_id
- in
- ConstStrings.constructor_prefix ^ variant_s
- | Proj { adt_id; field_id } ->
- let adt_s = adt_variant_to_string env adt_id None in
- let field_s = adt_field_to_string env adt_id field_id in
- (* Adopting an F*-like syntax *)
- ConstStrings.constructor_prefix ^ adt_s ^ "?." ^ field_s
- | TraitConst (trait_ref, generics, const_name) ->
- let trait_ref = trait_ref_to_string env true trait_ref in
- let generics_s = generic_args_to_string env generics in
+ match qualif.id with
+ | FunOrOp fun_id ->
+ let generics = generic_args_to_strings env true qualif.generics in
+ let qualif_s = fun_or_op_id_to_string env fun_id in
+ (qualif_s, generics)
+ | Global global_id ->
+ let generics = generic_args_to_strings env true qualif.generics in
+ (global_decl_id_to_string env global_id, generics)
+ | AdtCons adt_cons_id ->
+ let variant_s =
+ adt_variant_to_string env adt_cons_id.adt_id
+ adt_cons_id.variant_id
+ in
+ (ConstStrings.constructor_prefix ^ variant_s, [])
+ | Proj { adt_id; field_id } ->
+ let adt_s = adt_variant_to_string env adt_id None in
+ let field_s = adt_field_to_string env adt_id field_id in
+ (* Adopting an F*-like syntax *)
+ (ConstStrings.constructor_prefix ^ adt_s ^ "?." ^ field_s, [])
+ | TraitConst (trait_ref, generics, const_name) ->
+ let trait_ref = trait_ref_to_string env true trait_ref in
+ let generics_s = generic_args_to_string env generics in
+ let qualif =
if generics <> empty_generic_args then
"(" ^ trait_ref ^ generics_s ^ ")." ^ const_name
else trait_ref ^ "." ^ const_name
- in
- (* Convert the type instantiation *)
- let generics = generic_args_to_strings env true qualif.generics in
- (* *)
- (qualif_s, generics)
+ in
+ (qualif, []))
| _ ->
(* "Regular" expression case *)
let inside = args <> [] || (args = [] && inside) in
@@ -660,7 +663,7 @@ and app_to_string (env : fmt_env) (inside : bool) (indent : string)
(* Add parentheses *)
if all_args <> [] && inside then "(" ^ e ^ ")" else e
-and abs_to_string (env : fmt_env) (indent : string) (indent_incr : string)
+and lambda_to_string (env : fmt_env) (indent : string) (indent_incr : string)
(xl : typed_pattern list) (e : texpression) : string =
let xl = List.map (typed_pattern_to_string env) xl in
let e = texpression_to_string env false indent indent_incr e in
@@ -708,21 +711,14 @@ and loop_to_string (env : fmt_env) (indent : string) (indent_incr : string)
^ String.concat "; " (List.map (var_to_string env) loop.inputs)
^ "]"
in
- let back_output_tys =
- let tys =
- match loop.back_output_tys with
- | None -> ""
- | Some tys -> String.concat "; " (List.map (ty_to_string env false) tys)
- in
- "back_output_tys: [" ^ tys ^ "]"
- in
+ let output_ty = "output_ty: " ^ ty_to_string env false loop.output_ty in
let fun_end =
texpression_to_string env false indent2 indent_incr loop.fun_end
in
let loop_body =
texpression_to_string env false indent2 indent_incr loop.loop_body
in
- "loop {\n" ^ indent1 ^ loop_inputs ^ "\n" ^ indent1 ^ back_output_tys ^ "\n"
+ "loop {\n" ^ indent1 ^ loop_inputs ^ "\n" ^ indent1 ^ output_ty ^ "\n"
^ indent1 ^ "fun_end: {\n" ^ indent2 ^ fun_end ^ "\n" ^ indent1 ^ "}\n"
^ indent1 ^ "loop_body: {\n" ^ indent2 ^ loop_body ^ "\n" ^ indent1 ^ "}\n"
^ indent ^ "}"
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index 8d39cc69..a879ba37 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -561,7 +561,7 @@ type fun_id_or_trait_method_ref =
(** A function id for a non-assumed function *)
type regular_fun_id =
- fun_id_or_trait_method_ref * LoopId.id option * T.RegionGroupId.id option
+ fun_id_or_trait_method_ref * LoopId.id option * RegionGroupId.id option
[@@deriving show, ord]
(** A function identifier *)
@@ -684,7 +684,7 @@ type expression =
field accesses with calls to projectors over fields (when there
are clashes of field names, some provers like F* get pretty bad...)
*)
- | Abs of typed_pattern * texpression (** Lambda abstraction: [fun x -> e] *)
+ | Lambda of typed_pattern * texpression (** Lambda abstraction: [λ x => e] *)
| Qualif of qualif (** A top-level qualifier *)
| Let of bool * typed_pattern * texpression * texpression
(** Let binding.
@@ -754,9 +754,7 @@ and loop = {
inputs : var list;
inputs_lvs : typed_pattern list;
(** The inputs seen as patterns. See {!fun_body}. *)
- back_output_tys : ty list option;
- (** The types of the given back values, if we ar esynthesizing a backward
- function *)
+ output_ty : ty; (** The output type of the loop *)
loop_body : texpression;
}
@@ -860,8 +858,8 @@ type fun_effect_info = {
the set [{ forward function } U { backward functions }].
We need this because of the option {!val:Config.backward_no_state_update}:
- if it is [true], then in case of a backward function {!stateful} is [false],
- but we might need to know whether the corresponding forward function
+ if it is [true], then in case of a backward function {!stateful} might be
+ [false], but we might need to know whether the corresponding forward function
is stateful or not.
*)
stateful : bool; (** [true] if the function is stateful (updates a state) *)
@@ -873,22 +871,108 @@ type fun_effect_info = {
}
[@@deriving show]
-(** Meta information about a function signature *)
-type fun_sig_info = {
+type inputs_info = {
has_fuel : bool;
- (* TODO: add [num_fwd_inputs_no_fuel_no_state] *)
- num_fwd_inputs_with_fuel_no_state : int;
- (** The number of input types for forward computation, with the fuel (if used)
+ num_inputs_no_fuel_no_state : int;
+ (** The number of input types ignoring the fuel (if used)
+ and ignoring the state (if used) *)
+ num_inputs_with_fuel_no_state : int;
+ (** The number of input types, with the fuel (if used)
and ignoring the state (if used) *)
- num_fwd_inputs_with_fuel_with_state : int;
- (** The number of input types for forward computation, with fuel and state (if used) *)
- num_back_inputs_no_state : int option;
- (** The number of additional inputs for the backward computation (if pertinent),
- ignoring the state (if there is one) *)
- num_back_inputs_with_state : int option;
- (** The number of additional inputs for the backward computation (if pertinent),
- with the state (if there is one) *)
+ num_inputs_with_fuel_with_state : int;
+ (** The number of input types, with fuel and state (if used) *)
+}
+[@@deriving show]
+
+(** Meta information about a function signature *)
+type fun_sig_info = {
+ fwd_info : inputs_info;
+ (** Information about the inputs of the forward function *)
effect_info : fun_effect_info;
+ ignore_output : bool;
+ (** In case we merge the forward/backward functions: should we ignore
+ the output (happens for forward functions if the output type is
+ [unit] and there are non-filtered backward functions)?
+ *)
+}
+[@@deriving show]
+
+type back_sg_info = {
+ inputs : (string option * ty) list;
+ (** The additional inputs of the backward function *)
+ inputs_no_state : (string option * ty) list;
+ outputs : ty list;
+ (** The "decomposed" list of outputs.
+
+ The list contains all the types of
+ all the given back values (there is at most one type per forward
+ input argument).
+
+ Ex.:
+ {[
+ fn choose<'a, T>(b : bool, x : &'a mut T, y : &'a mut T) -> &'a mut T;
+ ]}
+ Decomposed outputs:
+ - forward function: [[T]]
+ - backward function: [[T; T]] (for "x" and "y")
+
+ Non-decomposed ouputs (if the function can fail, but is not stateful):
+ - [result T]
+ - [[result (T * T)]]
+ *)
+ output_names : string option list;
+ (** The optional names for the backward outputs.
+ We derive those from the names of the inputs of the original LLBC
+ function. *)
+ effect_info : fun_effect_info;
+ filter : bool; (** Should we filter this backward function? *)
+}
+[@@deriving show]
+
+(** A *decomposed* function signature. *)
+type decomposed_fun_sig = {
+ generics : generic_params;
+ (** TODO: we should analyse the signature to make the type parameters implicit whenever possible *)
+ llbc_generics : Types.generic_params;
+ (** We use the LLBC generics to generate "pretty" names, for instance
+ for the variables we introduce for the trait clauses: we derive
+ those names from the types, and when doing so it is more meaningful
+ to derive them from the original LLBC types from before the
+ simplification of types like boxes and references. *)
+ preds : predicates;
+ fwd_inputs : ty list;
+ (** The types of the inputs of the forward function.
+
+ Note that those input types take include the [fuel] parameter,
+ if the function uses fuel for termination, and the [state] parameter,
+ if the function is stateful.
+
+ For instance, if we have the following Rust function:
+ {[
+ fn f(x : int);
+ ]}
+
+ If we translate it to a stateful function which uses fuel we get:
+ {[
+ val f : nat -> int -> state -> result (state * unit);
+ ]}
+
+ In particular, the list of input types is: [[nat; int; state]].
+ *)
+ fwd_output : ty;
+ (** The "pure" output type of the forward function.
+
+ Note that this type doesn't contain the "effect" of the function (i.e.,
+ we haven't added the [state] if it is a stateful function and haven't
+ wrapped the type in a [result]). Also, this output type is only about
+ the forward function (it doesn't contain the type of the closures we
+ return for the backward functions, in case we merge the forward and
+ backward functions).
+ *)
+ back_sg : back_sg_info RegionGroupId.Map.t;
+ (** Information about the backward functions *)
+ fwd_info : fun_sig_info;
+ (** Additional information about the forward function *)
}
[@@deriving show]
@@ -906,15 +990,15 @@ type fun_sig_info = {
[in_ty0 -> ... -> in_tyn -> state -> back_in0 -> ... back_inm -> state ->
result (state & (back_out0 & ... & back_outp))] (* state-error *)
- Note that a stateful backward function may take two states as inputs: the
- state received by the associated forward function, and the state at which
- the backward is called. This leads to code of the following shape:
+ Note that a stateful backward function may take two states as inputs: the
+ state received by the associated forward function, and the state at which
+ the backward is called. This leads to code of the following shape:
- {[
- (st1, y) <-- f_fwd x st0; // st0 is the state upon calling f_fwd
- ... // the state may be updated
- (st3, x') <-- f_back x st0 y' st2; // st2 is the state upon calling f_back
- ]}
+ {[
+ (st1, y) <-- f_fwd x st0; // st0 is the state upon calling f_fwd
+ ... // the state may be updated
+ (st3, x') <-- f_back x st0 y' st2; // st2 is the state upon calling f_back
+ ]}
The function's type should be given by [mk_arrows sig.inputs sig.output].
We provide additional meta-information with {!fun_sig.info}:
@@ -962,40 +1046,14 @@ type fun_sig = {
be a tuple with a [state] if the function is stateful, and will be wrapped
in a [result] if the function can fail.
*)
- doutputs : ty list;
- (** The "decomposed" list of outputs.
-
- In case of a forward function, the list has length = 1, for the
- type of the returned value.
-
- In case of backward function, the list contains all the types of
- all the given back values (there is at most one type per forward
- input argument).
-
- Ex.:
- {[
- fn choose<'a, T>(b : bool, x : &'a mut T, y : &'a mut T) -> &'a mut T;
- ]}
- Decomposed outputs:
- - forward function: [[T]]
- - backward function: [[T; T]] (for "x" and "y")
-
- Non-decomposed ouputs (if the function can fail, but is not stateful):
- - [result T]
- - [[result (T * T)]]
- *)
- info : fun_sig_info; (** Additional information *)
+ fwd_info : fun_sig_info;
+ (** Additional information about the forward function. *)
+ back_effect_info : fun_effect_info RegionGroupId.Map.t;
}
[@@deriving show]
(** An instantiated function signature. See {!fun_sig} *)
-type inst_fun_sig = {
- inputs : ty list;
- output : ty;
- doutputs : ty list;
- info : fun_sig_info;
-}
-[@@deriving show]
+type inst_fun_sig = { inputs : ty list; output : ty } [@@deriving show]
type fun_body = {
inputs : var list;
@@ -1020,7 +1078,7 @@ type fun_decl = {
*)
loop_id : LoopId.id option;
(** [Some] if this definition was generated for a loop *)
- back_id : T.RegionGroupId.id option;
+ back_id : RegionGroupId.id option;
llbc_name : llbc_name; (** The original LLBC name. *)
name : string;
(** We use the name only for printing purposes (for debugging):
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 959ec1c8..ec64df21 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -11,6 +11,10 @@ let fun_decl_to_string (ctx : trans_ctx) (def : Pure.fun_decl) : string =
let fmt = trans_ctx_to_pure_fmt_env ctx in
PrintPure.fun_decl_to_string fmt def
+let fun_sig_to_string (ctx : trans_ctx) (sg : Pure.fun_sig) : string =
+ let fmt = trans_ctx_to_pure_fmt_env ctx in
+ PrintPure.fun_sig_to_string fmt sg
+
(** Small utility.
We sometimes have to insert new fresh variables in a function body, in which
@@ -385,17 +389,17 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
let ctx, arg = update_texpression arg ctx in
let e = App (app, arg) in
(ctx, e)
- | Abs (x, e) -> update_abs x e ctx
| Qualif _ -> (* nothing to do *) (ctx, e.e)
| Let (monadic, lb, re, e) -> update_let monadic lb re e ctx
| Switch (scrut, body) -> update_switch_body scrut body ctx
| Loop loop -> update_loop loop ctx
| StructUpdate supd -> update_struct_update supd ctx
+ | Lambda (lb, e) -> update_lambda lb e ctx
| Meta (meta, e) -> update_emeta meta e ctx
in
(ctx, { e; ty })
(* *)
- and update_abs (x : typed_pattern) (e : texpression) (ctx : pn_ctx) :
+ and update_lambda (x : typed_pattern) (e : texpression) (ctx : pn_ctx) :
pn_ctx * expression =
(* We first add the left-constraint *)
let ctx = add_left_constraint x ctx in
@@ -404,7 +408,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
(* Update the abstracted value *)
let x = update_typed_pattern ctx x in
(* Put together *)
- (ctx, Abs (x, e))
+ (ctx, Lambda (x, e))
(* *)
and update_let (monadic : bool) (lv : typed_pattern) (re : texpression)
(e : texpression) (ctx : pn_ctx) : pn_ctx * expression =
@@ -455,7 +459,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
input_state;
inputs;
inputs_lvs;
- back_output_tys;
+ output_ty;
loop_body;
} =
loop
@@ -474,7 +478,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
input_state;
inputs;
inputs_lvs;
- back_output_tys;
+ output_ty;
loop_body;
}
in
@@ -641,6 +645,135 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
let body = { body with body = obj#visit_texpression () body.body } in
{ def with body = Some body }
+(** Simplify the let-bindings by performing the following rewritings:
+
+ Move inner let-bindings outside. This is especially useful to simplify
+ the backward expressions, when we merge the forward/backward functions.
+ Note that the rule is also applied with monadic let-bindings.
+ {[
+ let x :=
+ let y := ... in
+ e
+
+ ~~>
+
+ let y := ... in
+ let x := e
+ ]}
+
+ Simplify panics and returns:
+ {[
+ let x ← fail
+ ...
+ ~~>
+ fail
+
+ let x ← return y
+ ...
+ ~~>
+ let x := y
+ ...
+ ]}
+
+ Simplify tuples:
+ {[
+ let (y0, y1) := (x0, x1) in
+ ...
+ ~~>
+ let y0 = x0 in
+ let y1 = x1 in
+ ...
+ ]}
+
+ Simplify arrows:
+ {[
+ let f := fun x => g x in
+ ...
+ ~~>
+ let f := g in
+ ...
+ ]}
+ *)
+let simplify_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
+ let obj =
+ object (self)
+ inherit [_] map_expression as super
+
+ method! visit_Let env monadic lv rv next =
+ match rv.e with
+ | Let (rmonadic, rlv, rrv, rnext) ->
+ (* Case 1: move the inner let outside then re-visit *)
+ let rnext1 = Let (monadic, lv, rnext, next) in
+ let rnext1 = { ty = next.ty; e = rnext1 } in
+ self#visit_Let env rmonadic rlv rrv rnext1
+ | App
+ ( {
+ e =
+ Qualif
+ {
+ id =
+ AdtCons
+ {
+ adt_id = TAssumed TResult;
+ variant_id = Some variant_id;
+ };
+ generics = _;
+ };
+ ty = _;
+ },
+ x ) ->
+ (* return/fail case *)
+ if variant_id = result_return_id then
+ (* Return case - note that the simplification we just perform
+ might have unlocked the tuple simplification below *)
+ self#visit_Let env false lv x next
+ else if variant_id = result_fail_id then
+ (* Fail case *)
+ self#visit_expression env rv.e
+ else raise (Failure "Unexpected")
+ | App _ ->
+ (* This might be the tuple case *)
+ if not monadic then
+ match
+ (opt_dest_struct_pattern lv, opt_dest_tuple_texpression rv)
+ with
+ | Some pats, Some vals ->
+ (* Tuple case *)
+ let pat_vals = List.combine pats vals in
+ let e =
+ List.fold_right
+ (fun (pat, v) next -> mk_let false pat v next)
+ pat_vals next
+ in
+ super#visit_expression env e.e
+ | _ -> super#visit_Let env monadic lv rv next
+ else super#visit_Let env monadic lv rv next
+ | Lambda _ ->
+ if not monadic then
+ (* Arrow case *)
+ let pats, e = destruct_lambdas rv in
+ let g, args = destruct_apps e in
+ if List.length pats = List.length args then
+ (* Check if the arguments are exactly the lambdas *)
+ let check_pat_arg ((pat, arg) : typed_pattern * texpression) =
+ match (pat.value, arg.e) with
+ | PatVar (v, _), Var vid -> v.id = vid
+ | _ -> false
+ in
+ if List.for_all check_pat_arg (List.combine pats args) then
+ self#visit_Let env monadic lv g next
+ else super#visit_Let env monadic lv rv next
+ else super#visit_Let env monadic lv rv next
+ else super#visit_Let env monadic lv rv next
+ | _ -> super#visit_Let env monadic lv rv next
+ end
+ in
+ match def.body with
+ | None -> def
+ | Some body ->
+ let body = { body with body = obj#visit_texpression () body.body } in
+ { def with body = Some body }
+
(** Inline the useless variable (re-)assignments:
A lot of intermediate variable assignments are introduced through the
@@ -667,8 +800,8 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
leave the let-bindings where they are, and eliminated them in a subsequent
pass (if they are useless).
*)
-let inline_useless_var_reassignments (ctx : trans_ctx) (inline_named : bool)
- (inline_pure : bool) (def : fun_decl) : fun_decl =
+let inline_useless_var_reassignments (ctx : trans_ctx) ~(inline_named : bool)
+ ~(inline_const : bool) ~(inline_pure : bool) (def : fun_decl) : fun_decl =
let obj =
object (self)
inherit [_] map_expression as super
@@ -693,15 +826,31 @@ let inline_useless_var_reassignments (ctx : trans_ctx) (inline_named : bool)
| _ -> false
in
(* And either:
- * 2.1 the right-expression is a variable, a global or a const generic var *)
+ 2.1 the right-expression is a variable, a global or a const generic var *)
let var_or_global = is_var re || is_cvar re || is_global re in
(* Or:
- * 2.2 the right-expression is a constant value, an ADT value,
- * a projection or a primitive function call *and* the flag
- * [inline_pure] is set *)
+ 2.2 the right-expression is a constant-value and we inline constant values,
+ *or* it is a qualif with no arguments (we consider this as a const) *)
+ let const_re =
+ inline_const
+ &&
+ let is_const_adt =
+ let app, args = destruct_apps re in
+ if args = [] then
+ match app.e with
+ | Qualif _ -> true
+ | StructUpdate upd -> upd.updates = []
+ | _ -> false
+ else false
+ in
+ is_const re || is_const_adt
+ in
+ (* Or:
+ 2.3 the right-expression is an ADT value, a projection or a
+ primitive function call *and* the flag [inline_pure] is set *)
let pure_re =
- is_const re
- ||
+ inline_pure
+ &&
let app, _ = destruct_apps re in
match app.e with
| Qualif qualif -> (
@@ -716,7 +865,7 @@ let inline_useless_var_reassignments (ctx : trans_ctx) (inline_named : bool)
| _ -> false
in
let filter =
- filter_left && (var_or_global || (inline_pure && pure_re))
+ filter_left && (var_or_global || const_re || pure_re)
in
(* Update the rhs (we may perform substitutions inside, and it is
@@ -776,9 +925,11 @@ let inline_useless_var_reassignments (ctx : trans_ctx) (inline_named : bool)
in
{ def with body = Some body }
-(** Given a forward or backward function call, is there, for every execution
+(** For the cases where we split the forward/backward functions.
+
+ Given a forward or backward function call, is there, for every execution
path, a child backward function called later with exactly the same input
- list prefix? We use this to filter useless function calls: if there are
+ list prefix. We use this to filter useless function calls: if there are
such child calls, we can remove this one (in case its outputs are not
used).
We do this check because we can't simply remove function calls whose
@@ -890,12 +1041,12 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
let call_is_child = check_call func1 generics1 args1 in
if call_is_child then fun () -> true
else fun () -> self#visit_texpression env e ())
+ | Lambda (_, e) -> self#visit_texpression env e
| App _ -> (
fun () ->
match opt_destruct_function_call e with
| Some (func1, tys1, args1) -> check_call func1 tys1 args1
| None -> false)
- | Abs (_, e) -> self#visit_texpression env e
| Qualif _ ->
(* Note that this case includes functions without arguments *)
fun () -> false
@@ -973,10 +1124,23 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
method! visit_expression env e =
match e with
| Var _ | CVar _ | Const _ | App _ | Qualif _
- | Switch (_, _)
| Meta (_, _)
- | StructUpdate _ | Abs _ ->
+ | StructUpdate _ | Lambda _ ->
super#visit_expression env e
+ | Switch (scrut, switch) -> (
+ match switch with
+ | If (_, _) -> super#visit_expression env e
+ | Match branches ->
+ (* Simplify the branches *)
+ let simplify_branch (br : match_branch) =
+ (* Compute the set of values used inside the branch *)
+ let branch, used = self#visit_texpression env br.branch in
+ (* Simplify the pattern *)
+ let pat, _ = filter_typed_pattern (used ()) br.pat in
+ { pat; branch }
+ in
+ super#visit_expression env
+ (Switch (scrut, Match (List.map simplify_branch branches))))
| Let (monadic, lv, re, e) ->
(* Compute the set of values used in the next expression *)
let e, used = self#visit_texpression env e in
@@ -1008,17 +1172,21 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
* under some conditions. *)
match (filter_monadic_calls, opt_destruct_function_call re) with
| true, Some (Fun (FromLlbc (fid, lp_id, rg_id)), tys, args) ->
- (* We need to check if there is a child call - see
- * the comments for:
- * [expression_contains_child_call_in_all_paths] *)
- let has_child_call =
- expression_contains_child_call_in_all_paths ctx fid lp_id
- rg_id tys args e
- in
- if has_child_call then (* Filter *)
- (e.e, fun _ -> used)
- else (* No child call: don't filter *)
- dont_filter ()
+ (* If we split the forward/backward functions.
+
+ We need to check if there is a child call - see
+ the comments for:
+ [expression_contains_child_call_in_all_paths] *)
+ if not !Config.return_back_funs then
+ let has_child_call =
+ expression_contains_child_call_in_all_paths ctx fid
+ lp_id rg_id tys args e
+ in
+ if has_child_call then (* Filter *)
+ (e.e, fun _ -> used)
+ else (* No child call: don't filter *)
+ dont_filter ()
+ else dont_filter ()
| _ ->
(* Not an LLBC function call or not allowed to filter: we can't filter *)
dont_filter ()
@@ -1297,7 +1465,8 @@ let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option =
those function bodies into independent definitions while removing
occurrences of the {!Pure.Loop} node.
*)
-let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
+let decompose_loops (_ctx : trans_ctx) (def : fun_decl) :
+ fun_decl * fun_decl list =
match def.body with
| None -> (def, [])
| Some body ->
@@ -1323,77 +1492,55 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
method! visit_Loop env loop =
let fun_sig = def.signature in
- let fun_sig_info = fun_sig.info in
- let fun_effect_info = fun_sig_info.effect_info in
+ let fwd_info = fun_sig.fwd_info in
+ let fwd_effect_info = fwd_info.effect_info in
+ let ignore_output = fwd_info.ignore_output in
(* Generate the loop definition *)
- let loop_effect_info =
- {
- stateful_group = fun_effect_info.stateful_group;
- stateful = fun_effect_info.stateful;
- can_fail = fun_effect_info.can_fail;
- can_diverge = fun_effect_info.can_diverge;
- is_rec = fun_effect_info.is_rec;
- }
- in
+ let loop_fwd_effect_info = fwd_effect_info in
- let loop_sig_info =
+ let loop_fwd_sig_info : fun_sig_info =
let fuel = if !Config.use_fuel then 1 else 0 in
let num_inputs = List.length loop.inputs in
- let num_fwd_inputs_with_fuel_no_state = fuel + num_inputs in
- let fwd_state =
- fun_sig_info.num_fwd_inputs_with_fuel_with_state
- - fun_sig_info.num_fwd_inputs_with_fuel_no_state
- in
- let num_fwd_inputs_with_fuel_with_state =
- num_fwd_inputs_with_fuel_no_state + fwd_state
+ let fwd_info : inputs_info =
+ let info = fwd_info.fwd_info in
+ let fwd_state =
+ info.num_inputs_with_fuel_with_state
+ - info.num_inputs_with_fuel_no_state
+ in
+ {
+ has_fuel = !Config.use_fuel;
+ num_inputs_no_fuel_no_state = num_inputs;
+ num_inputs_with_fuel_no_state = num_inputs + fuel;
+ num_inputs_with_fuel_with_state =
+ num_inputs + fuel + fwd_state;
+ }
in
- {
- has_fuel = !Config.use_fuel;
- num_fwd_inputs_with_fuel_no_state;
- num_fwd_inputs_with_fuel_with_state;
- num_back_inputs_no_state = fun_sig_info.num_back_inputs_no_state;
- num_back_inputs_with_state =
- fun_sig_info.num_back_inputs_with_state;
- effect_info = loop_effect_info;
- }
+
+ { fwd_info; effect_info = loop_fwd_effect_info; ignore_output }
in
+ assert (fun_sig_info_is_wf loop_fwd_sig_info);
let inputs_tys =
let fuel = if !Config.use_fuel then [ mk_fuel_ty ] else [] in
let fwd_inputs = List.map (fun (v : var) -> v.ty) loop.inputs in
- let state =
+ let info = fwd_info.fwd_info in
+ let fwd_state =
Collections.List.subslice fun_sig.inputs
- fun_sig_info.num_fwd_inputs_with_fuel_no_state
- fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ info.num_inputs_with_fuel_no_state
+ info.num_inputs_with_fuel_with_state
in
- let _, back_inputs =
- Collections.List.split_at fun_sig.inputs
- fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ let back_inputs =
+ if !Config.return_back_funs then []
+ else
+ snd
+ (Collections.List.split_at fun_sig.inputs
+ info.num_inputs_with_fuel_with_state)
in
- List.concat [ fuel; fwd_inputs; state; back_inputs ]
+ List.concat [ fuel; fwd_inputs; fwd_state; back_inputs ]
in
- let output, doutputs =
- match loop.back_output_tys with
- | None ->
- (* Forward function: the return type is the same as the
- parent function *)
- (fun_sig.output, fun_sig.doutputs)
- | Some doutputs ->
- (* Backward function: custom return type *)
- let output = mk_simpl_tuple_ty doutputs in
- let output =
- if loop_effect_info.stateful then
- mk_simpl_tuple_ty [ mk_state_ty; output ]
- else output
- in
- let output =
- if loop_effect_info.can_fail then mk_result_ty output
- else output
- in
- (output, doutputs)
- in
+ let output = loop.output_ty in
let loop_sig =
{
@@ -1402,8 +1549,8 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
preds = fun_sig.preds;
inputs = inputs_tys;
output;
- doutputs;
- info = loop_sig_info;
+ fwd_info = loop_fwd_sig_info;
+ back_effect_info = fun_sig.back_effect_info;
}
in
@@ -1420,7 +1567,8 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
(* Introduce the forward input state *)
let fwd_state_var, fwd_state_lvs =
assert (
- loop_effect_info.stateful = Option.is_some loop.input_state);
+ loop_fwd_effect_info.stateful
+ = Option.is_some loop.input_state);
match loop.input_state with
| None -> ([], [])
| Some input_state ->
@@ -1429,15 +1577,16 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
([ state_var ], [ state_lvs ])
in
- (* Introduce the additional backward inputs *)
+ (* Introduce the additional backward inputs, if necessary *)
let fun_body = Option.get def.body in
+ let info = fwd_info.fwd_info in
let _, back_inputs =
Collections.List.split_at fun_body.inputs
- fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ info.num_inputs_with_fuel_with_state
in
let _, back_inputs_lvs =
Collections.List.split_at fun_body.inputs_lvs
- fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ info.num_inputs_with_fuel_with_state
in
let inputs =
@@ -1508,9 +1657,12 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
altogether.
*)
let keep_forward (fwd : fun_and_loops) (backs : fun_and_loops list) : bool =
- (* Note that at this point, the output types are no longer seen as tuples:
- * they should be lists of length 1. *)
- if
+ (* The question of filtering the forward functions arises only if we split
+ the forward/backward functions *)
+ if !Config.return_back_funs then true
+ else if
+ (* Note that at this point, the output types are no longer seen as tuples:
+ * they should be lists of length 1. *)
!Config.filter_useless_functions
&& fwd.f.signature.output = mk_result_ty mk_unit_ty
&& backs <> []
@@ -1816,11 +1968,15 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl =
log#ldebug
(lazy ("intro_struct_updates:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+ (* Simplify the let-bindings *)
+ let def = simplify_let_bindings ctx def in
+ log#ldebug
+ (lazy ("simplify_let_bindings:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+
(* Inline the useless variable reassignments *)
- let inline_named_vars = true in
- let inline_pure = true in
let def =
- inline_useless_var_reassignments ctx inline_named_vars inline_pure def
+ inline_useless_var_reassignments ctx ~inline_named:true ~inline_const:true
+ ~inline_pure:true def
in
log#ldebug
(lazy
@@ -1867,6 +2023,23 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl =
log#ldebug
(lazy ("simplify_aggregates:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+ (* Simplify the let-bindings - some simplifications may have been unlocked by
+ the pass above (for instance, the lambda simplification) *)
+ let def = simplify_let_bindings ctx def in
+ log#ldebug
+ (lazy
+ ("simplify_let_bindings (pass 2):\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+
+ (* Inline the useless vars again *)
+ let def =
+ inline_useless_var_reassignments ctx ~inline_named:true ~inline_const:true
+ ~inline_pure:false def
+ in
+ log#ldebug
+ (lazy
+ ("inline_useless_var_assignments (pass 2):\n\n"
+ ^ fun_decl_to_string ctx def ^ "\n"));
+
(* Decompose the monadic let-bindings - used by Coq *)
let def =
if !Config.decompose_monadic_let_bindings then (
@@ -1917,67 +2090,6 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl =
(* We are done *)
def
-(** Apply all the micro-passes to a function.
-
- As loops are initially directly integrated into the function definition,
- {!apply_passes_to_def} extracts those loops definitions from the body;
- it thus returns the pair: (function def, loop defs). See {!decompose_loops}
- for more information.
-
- Will return [None] if the function is a backward function with no outputs.
-
- [ctx]: used only for printing.
- *)
-let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
- fun_and_loops option =
- (* Debug *)
- log#ldebug
- (lazy
- ("PureMicroPasses.apply_passes_to_def: " ^ def.name ^ " ("
- ^ Print.option_to_string T.RegionGroupId.to_string def.back_id
- ^ ")"));
-
- log#ldebug (lazy ("original decl:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
-
- (* First, find names for the variables which are unnamed *)
- let def = compute_pretty_names def in
- log#ldebug
- (lazy ("compute_pretty_name:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
-
- (* TODO: we might want to leverage more the assignment meta-data, for
- * aggregates for instance. *)
-
- (* TODO: reorder the branches of the matches/switches *)
-
- (* The meta-information is now useless: remove it.
- * Rk.: some passes below use the fact that we removed the meta-data
- * (otherwise we would have to "unmeta" expressions before matching) *)
- let def = remove_meta def in
- log#ldebug (lazy ("remove_meta:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
-
- (* Remove the backward functions with no outputs.
- * Note that the calls to those functions should already have been removed,
- * when translating from symbolic to pure. Here, we remove the definitions
- * altogether, because they are now useless *)
- let name = def.name ^ PrintPure.fun_suffix def.loop_id def.back_id in
- let opt_def = filter_if_backward_with_no_outputs def in
-
- match opt_def with
- | None ->
- log#ldebug (lazy ("filtered (backward with no outputs): " ^ name ^ "\n"));
- None
- | Some def ->
- log#ldebug
- (lazy ("not filtered (not backward with no outputs): " ^ name ^ "\n"));
-
- (* Extract the loop definitions by removing the {!Loop} node *)
- let def, loops = decompose_loops def in
-
- (* Apply the remaining passes *)
- let f = apply_end_passes_to_def ctx def in
- let loops = List.map (apply_end_passes_to_def ctx) loops in
- Some { f; loops }
-
(** Small utility for {!filter_loop_inputs} *)
let filter_prefix (keep : bool list) (ls : 'a list) : 'a list =
let ls0, ls1 = Collections.List.split_at ls (List.length keep) in
@@ -2058,7 +2170,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
(* We only look at the forward inputs, without the state *)
let inputs_prefix, _ =
Collections.List.split_at body.inputs
- decl.signature.info.num_fwd_inputs_with_fuel_no_state
+ decl.signature.fwd_info.fwd_info.num_inputs_with_fuel_no_state
in
let used = ref (List.map (fun v -> (var_get_id v, false)) inputs_prefix) in
let inputs_prefix_length = List.length inputs_prefix in
@@ -2077,8 +2189,9 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
in
(* Set the fuel as used *)
- let sg_info = decl.signature.info in
- if sg_info.has_fuel then set_used (fst (Collections.List.nth inputs 0));
+ let sg_info = decl.signature.fwd_info in
+ if sg_info.fwd_info.has_fuel then
+ set_used (fst (Collections.List.nth inputs 0));
let visitor =
object (self : 'self)
@@ -2162,37 +2275,54 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
let num_filtered =
List.length (List.filter (fun b -> not b) used_info)
in
- let { generics; llbc_generics; preds; inputs; output; doutputs; info }
- =
+ let {
+ generics;
+ llbc_generics;
+ preds;
+ inputs;
+ output;
+ fwd_info;
+ back_effect_info;
+ } =
decl.signature
in
+ let { fwd_info; effect_info; ignore_output } = fwd_info in
+
let {
has_fuel;
- num_fwd_inputs_with_fuel_no_state;
- num_fwd_inputs_with_fuel_with_state;
- num_back_inputs_no_state;
- num_back_inputs_with_state;
- effect_info;
+ num_inputs_no_fuel_no_state;
+ num_inputs_with_fuel_no_state;
+ num_inputs_with_fuel_with_state;
} =
- info
+ fwd_info
in
let inputs = filter_prefix used_info inputs in
- let info =
+ let fwd_info =
{
has_fuel;
- num_fwd_inputs_with_fuel_no_state =
- num_fwd_inputs_with_fuel_no_state - num_filtered;
- num_fwd_inputs_with_fuel_with_state =
- num_fwd_inputs_with_fuel_with_state - num_filtered;
- num_back_inputs_no_state;
- num_back_inputs_with_state;
- effect_info;
+ num_inputs_no_fuel_no_state =
+ num_inputs_no_fuel_no_state - num_filtered;
+ num_inputs_with_fuel_no_state =
+ num_inputs_with_fuel_no_state - num_filtered;
+ num_inputs_with_fuel_with_state =
+ num_inputs_with_fuel_with_state - num_filtered;
}
in
+
+ let fwd_info = { fwd_info; effect_info; ignore_output } in
+ assert (fun_sig_info_is_wf fwd_info);
let signature =
- { generics; llbc_generics; preds; inputs; output; doutputs; info }
+ {
+ generics;
+ llbc_generics;
+ preds;
+ inputs;
+ output;
+ fwd_info;
+ back_effect_info;
+ }
in
{ decl with signature }
@@ -2283,6 +2413,68 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
(* Return *)
transl
+(** Apply all the micro-passes to a function.
+
+ As loops are initially directly integrated into the function definition,
+ {!apply_passes_to_def} extracts those loops definitions from the body;
+ it thus returns the pair: (function def, loop defs). See {!decompose_loops}
+ for more information.
+
+ Will return [None] if the function is a backward function with no outputs.
+
+ [ctx]: used only for printing.
+ *)
+let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
+ fun_and_loops option =
+ (* Debug *)
+ log#ldebug
+ (lazy
+ ("PureMicroPasses.apply_passes_to_def: " ^ def.name ^ " ("
+ ^ Print.option_to_string T.RegionGroupId.to_string def.back_id
+ ^ ")"));
+
+ log#ldebug (lazy ("original decl:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+
+ (* First, find names for the variables which are unnamed *)
+ let def = compute_pretty_names def in
+ log#ldebug
+ (lazy ("compute_pretty_name:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+
+ (* TODO: we might want to leverage more the assignment meta-data, for
+ * aggregates for instance. *)
+
+ (* TODO: reorder the branches of the matches/switches *)
+
+ (* The meta-information is now useless: remove it.
+ * Rk.: some passes below use the fact that we removed the meta-data
+ * (otherwise we would have to "unmeta" expressions before matching) *)
+ let def = remove_meta def in
+ log#ldebug (lazy ("remove_meta:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+
+ (* Remove the backward functions with no outputs.
+
+ Note that the *calls* to those functions should already have been removed,
+ when translating from symbolic to pure. Here, we remove the definitions
+ altogether, because they are now useless *)
+ let name = def.name ^ PrintPure.fun_suffix def.loop_id def.back_id in
+ let opt_def = filter_if_backward_with_no_outputs def in
+
+ match opt_def with
+ | None ->
+ log#ldebug (lazy ("filtered (backward with no outputs): " ^ name ^ "\n"));
+ None
+ | Some def ->
+ log#ldebug
+ (lazy ("not filtered (not backward with no outputs): " ^ name ^ "\n"));
+
+ (* Extract the loop definitions by removing the {!Loop} node *)
+ let def, loops = decompose_loops ctx def in
+
+ (* Apply the remaining passes *)
+ let f = apply_end_passes_to_def ctx def in
+ let loops = List.map (apply_end_passes_to_def ctx) loops in
+ Some { f; loops }
+
(** Apply the micro-passes to a list of forward/backward translations.
This function also extracts the loop definitions from the function body
diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml
index a62a2361..a989fd3b 100644
--- a/compiler/PureTypeCheck.ml
+++ b/compiler/PureTypeCheck.ml
@@ -120,7 +120,7 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
assert (output_ty = e.ty);
check_texpression ctx app;
check_texpression ctx arg
- | Abs (pat, body) ->
+ | Lambda (pat, body) ->
let pat_ty, body_ty = destruct_arrow e.ty in
assert (pat.ty = pat_ty);
assert (body.ty = body_ty);
@@ -188,12 +188,6 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
List.iter check_branch branches)
| Loop loop ->
assert (loop.fun_end.ty = e.ty);
- (* If we translate forward functions, the type of the loop is the same
- as the type of the parent expression - in case of backward functions,
- the loop doesn't necessarily give back the same values as the parent
- function
- *)
- assert (Option.is_some loop.back_output_tys || loop.loop_body.ty = e.ty);
check_texpression ctx loop.fun_end;
check_texpression ctx loop.loop_body
| StructUpdate supd -> (
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index 39dcd52d..80bf3c42 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -57,6 +57,23 @@ end
module FunLoopIdMap = Collections.MakeMap (FunLoopIdOrderedType)
module FunLoopIdSet = Collections.MakeSet (FunLoopIdOrderedType)
+let inputs_info_is_wf (info : inputs_info) : bool =
+ let {
+ has_fuel;
+ num_inputs_no_fuel_no_state;
+ num_inputs_with_fuel_no_state;
+ num_inputs_with_fuel_with_state;
+ } =
+ info
+ in
+ let fuel = if has_fuel then 1 else 0 in
+ num_inputs_no_fuel_no_state >= 0
+ && num_inputs_with_fuel_no_state = num_inputs_no_fuel_no_state + fuel
+ && num_inputs_with_fuel_with_state >= num_inputs_with_fuel_no_state
+
+let fun_sig_info_is_wf (info : fun_sig_info) : bool =
+ inputs_info_is_wf info.fwd_info
+
let dest_arrow_ty (ty : ty) : ty * ty =
match ty with
| TArrow (arg_ty, ret_ty) -> (arg_ty, ret_ty)
@@ -187,9 +204,7 @@ let fun_sig_substitute (subst : subst) (sg : fun_sig) : inst_fun_sig =
let subst = ty_substitute subst in
let inputs = List.map subst sg.inputs in
let output = subst sg.output in
- let doutputs = List.map subst sg.doutputs in
- let info = sg.info in
- { inputs; output; doutputs; info }
+ { inputs; output }
(** We use this to check whether we need to add parentheses around expressions.
We only look for outer monadic let-bindings.
@@ -200,12 +215,14 @@ let fun_sig_substitute (subst : subst) (sg : fun_sig) : inst_fun_sig =
*)
let rec let_group_requires_parentheses (e : texpression) : bool =
match e.e with
- | Var _ | CVar _ | Const _ | App _ | Abs _ | Qualif _ | StructUpdate _ ->
- false
+ | Var _ | CVar _ | Const _ | App _ | Qualif _ | StructUpdate _ -> false
| Let (monadic, _, _, next_e) ->
if monadic then true else let_group_requires_parentheses next_e
| Switch (_, _) -> false
| Meta (_, next_e) -> let_group_requires_parentheses next_e
+ | Lambda (_, _) ->
+ (* Being conservative here *)
+ true
| Loop _ ->
(* Should have been eliminated *)
raise (Failure "Unreachable")
@@ -304,14 +321,26 @@ let destruct_apps (e : texpression) : texpression * texpression list =
(** Make an [App (app, arg)] expression *)
let mk_app (app : texpression) (arg : texpression) : texpression =
+ let raise_or_return msg =
+ if !Config.fail_hard then raise (Failure msg)
+ else
+ let e = App (app, arg) in
+ (* Dummy type - TODO: introduce an error type *)
+ let ty = app.ty in
+ { e; ty }
+ in
match app.ty with
| TArrow (ty0, ty1) ->
(* Sanity check *)
- assert (ty0 = arg.ty);
- let e = App (app, arg) in
- let ty = ty1 in
- { e; ty }
- | _ -> raise (Failure "Expected an arrow type")
+ if
+ (* TODO: we need to normalize the types *)
+ !Config.type_check_pure_code && ty0 <> arg.ty
+ then raise_or_return "App: wrong input type"
+ else
+ let e = App (app, arg) in
+ let ty = ty1 in
+ { e; ty }
+ | _ -> raise_or_return "Expected an arrow type"
(** The reverse of {!destruct_apps} *)
let mk_apps (app : texpression) (args : texpression list) : texpression =
@@ -356,18 +385,6 @@ let opt_destruct_tuple (ty : ty) : ty list option =
Some generics.types
| _ -> None
-let mk_abs (x : typed_pattern) (e : texpression) : texpression =
- let ty = TArrow (x.ty, e.ty) in
- let e = Abs (x, e) in
- { e; ty }
-
-let rec destruct_abs_list (e : texpression) : typed_pattern list * texpression =
- match e.e with
- | Abs (x, e') ->
- let xl, e'' = destruct_abs_list e' in
- (x :: xl, e'')
- | _ -> ([], e)
-
let destruct_arrow (ty : ty) : ty * ty =
match ty with
| TArrow (ty0, ty1) -> (ty0, ty1)
@@ -431,6 +448,7 @@ let mk_simpl_tuple_ty (tys : ty list) : ty =
let mk_bool_ty : ty = TLiteral TBool
let mk_unit_ty : ty = TAdt (TTuple, empty_generic_args)
+let ty_is_unit ty : bool = ty = mk_unit_ty
let mk_unit_rvalue : texpression =
let id = AdtCons { adt_id = TTuple; variant_id = None } in
@@ -698,3 +716,39 @@ let type_decl_from_type_id_is_tuple_struct (ctx : TypesAnalysis.type_infos)
let info = TypeDeclId.Map.find id ctx in
info.is_tuple_struct
| TAssumed _ -> false
+
+let mk_lambda (x : typed_pattern) (e : texpression) : texpression =
+ let ty = TArrow (x.ty, e.ty) in
+ let e = Lambda (x, e) in
+ { e; ty }
+
+let mk_lambda_from_var (var : var) (mp : mplace option) (e : texpression) :
+ texpression =
+ let pat = PatVar (var, mp) in
+ let pat = { value = pat; ty = var.ty } in
+ mk_lambda pat e
+
+let mk_lambdas_from_vars (vars : var list) (mps : mplace option list)
+ (e : texpression) : texpression =
+ let vars = List.combine vars mps in
+ List.fold_right (fun (v, mp) e -> mk_lambda_from_var v mp e) vars e
+
+let rec destruct_lambdas (e : texpression) : typed_pattern list * texpression =
+ match e.e with
+ | Lambda (pat, e) ->
+ let pats, e = destruct_lambdas e in
+ (pat :: pats, e)
+ | _ -> ([], e)
+
+let opt_dest_tuple_texpression (e : texpression) : texpression list option =
+ let app, args = destruct_apps e in
+ match app.e with
+ | Qualif { id = AdtCons { adt_id = TTuple; variant_id = None }; generics = _ }
+ ->
+ Some args
+ | _ -> None
+
+let opt_dest_struct_pattern (pat : typed_pattern) : typed_pattern list option =
+ match pat.value with
+ | PatAdt { variant_id = None; field_values } -> Some field_values
+ | _ -> None
diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml
index 53f99b7f..8e8cdec3 100644
--- a/compiler/SymbolicAst.ml
+++ b/compiler/SymbolicAst.ml
@@ -42,8 +42,16 @@ type call = {
evaluated). We need it to compute the translated values for shared
borrows (we need to perform lookups).
*)
+ sg : fun_sig option;
+ (** The uninstantiated function signature, if this is not a unop/binop *)
+ regions_hierarchy : region_var_groups;
abstractions : AbstractionId.id list;
+ (** The region abstractions introduced upon calling the function *)
generics : generic_args;
+ trait_method_generics : (generic_args * trait_instance_id) option;
+ (** In case the call is to a trait method, we may need an additional type
+ parameter ([Self]) and the self trait clause to instantiate the
+ function signature. *)
args : typed_value list;
args_places : mplace option list; (** Meta information *)
dest : symbolic_value;
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 84f09280..3a50e495 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -15,7 +15,7 @@ module PP = PrintPure
(** The local logger *)
let log = Logging.symbolic_to_pure_log
-type type_context = {
+type type_ctx = {
llbc_type_decls : T.type_decl TypeDeclId.Map.t;
type_decls : type_decl TypeDeclId.Map.t;
(** We use this for type-checking (for sanity checks) when translating
@@ -43,19 +43,18 @@ type fun_sig_named_outputs = {
}
[@@deriving show]
-type fun_context = {
+type fun_ctx = {
llbc_fun_decls : A.fun_decl A.FunDeclId.Map.t;
- fun_sigs : fun_sig_named_outputs RegularFunIdNotLoopMap.t; (** *)
fun_infos : fun_info A.FunDeclId.Map.t;
regions_hierarchies : T.region_var_groups FunIdMap.t;
}
[@@deriving show]
-type global_context = { llbc_global_decls : A.global_decl A.GlobalDeclId.Map.t }
+type global_ctx = { llbc_global_decls : A.global_decl A.GlobalDeclId.Map.t }
[@@deriving show]
-type trait_decls_context = A.trait_decl A.TraitDeclId.Map.t [@@deriving show]
-type trait_impls_context = A.trait_impl A.TraitImplId.Map.t [@@deriving show]
+type trait_decls_ctx = A.trait_decl A.TraitDeclId.Map.t [@@deriving show]
+type trait_impls_ctx = A.trait_impl A.TraitImplId.Map.t [@@deriving show]
(** Whenever we translate a function call or an ended abstraction, we
store the related information (this is useful when translating ended
@@ -68,11 +67,21 @@ type call_info = {
Those inputs include the fuel and the state, if pertinent.
*)
- backwards : (V.abs * texpression list) T.RegionGroupId.Map.t;
- (** A map from region group id (i.e., backward function id) to
- pairs (abstraction, additional arguments received by the backward function)
-
- TODO: remove? it is also in the bs_ctx ("abstractions" field)
+ back_funs : texpression option RegionGroupId.Map.t option;
+ (** If we do not split between the forward/backward functions: the
+ variables we introduced for the backward functions.
+
+ Example:
+ {[
+ let x, back = Vec.index_mut n v in
+ ^^^^
+ here
+ ...
+ ]}
+
+ The expression might be [None] in case the backward function
+ has to be filtered (because it does nothing - the backward
+ functions for shared borrows for instance).
*)
}
[@@deriving show]
@@ -114,33 +123,61 @@ type loop_info = {
generics : generic_args;
forward_inputs : texpression list option;
(** The forward inputs are initialized at [None] *)
- forward_output_no_state_no_result : var option;
+ forward_output_no_state_no_result : texpression option;
(** The forward outputs are initialized at [None] *)
+ back_outputs : ty list RegionGroupId.Map.t;
+ (** The map from region group ids to the types of the values given back
+ by the corresponding loop abstractions.
+ *)
+ back_funs : texpression option RegionGroupId.Map.t option;
+ (** Same as {!call_info.back_funs}.
+ Initialized with [None], gets updated to [Some] only if we merge
+ the fwd/back functions.
+ *)
+ fwd_effect_info : fun_effect_info;
+ back_effect_infos : fun_effect_info RegionGroupId.Map.t;
}
[@@deriving show]
(** Body synthesis context *)
type bs_ctx = {
- type_context : type_context;
- fun_context : fun_context;
- global_context : global_context;
- trait_decls_ctx : trait_decls_context;
- trait_impls_ctx : trait_impls_context;
+ (* TODO: there are a lot of duplications with the various decls ctx *)
+ decls_ctx : C.decls_ctx;
+ type_ctx : type_ctx;
+ fun_ctx : fun_ctx;
+ global_ctx : global_ctx;
+ trait_decls_ctx : trait_decls_ctx;
+ trait_impls_ctx : trait_impls_ctx;
+ fun_dsigs : decomposed_fun_sig FunDeclId.Map.t;
fun_decl : A.fun_decl;
- bid : T.RegionGroupId.id option; (** TODO: rename *)
- sg : fun_sig;
- (** The function signature - useful in particular to translate [Panic] *)
- fwd_sg : fun_sig; (** The signature of the forward function *)
+ bid : RegionGroupId.id option;
+ (** TODO: rename
+
+ The id of the group region we are currently translating.
+ If we split the forward/backward functions, we set this id at the
+ very beginning of the translation.
+ If we don't split, we set it to `None`, then update it when we enter
+ an expression which is specific to a backward function.
+ *)
+ sg : decomposed_fun_sig;
+ (** Information about the function signature - useful in particular to
+ translate [Panic] *)
sv_to_var : var V.SymbolicValueId.Map.t;
(** Whenever we encounter a new symbolic value (introduced because of
a symbolic expansion or upon ending an abstraction, for instance)
we introduce a new variable (with a let-binding).
*)
- var_counter : VarId.generator;
+ var_counter : VarId.generator ref;
+ (** Using a ref to make sure all the variables identifiers are unique.
+ TODO: this is not very clean, and the code was initially written without
+ a reference (and it's shape hasn't changed). We should use DeBruijn indices.
+ *)
state_var : VarId.id;
(** The current state variable, in case the function is stateful *)
- back_state_var : VarId.id;
- (** The additional input state variable received by a stateful backward function.
+ back_state_vars : VarId.id RegionGroupId.Map.t;
+ (** The additional input state variable received by a stateful backward
+ function, **in case we are splitting the forward/backward functions**.
+
When generating stateful functions, we generate code of the following
form:
@@ -153,7 +190,9 @@ type bs_ctx = {
When translating a backward function, we need at some point to update
[state_var] with [back_state_var], to account for the fact that the
state may have been updated by the caller between the call to the
- forward function and the call to the backward function.
+ forward function and the call to the backward function. We also need
+ to make sure we use the same variable in all the branches (because
+ this variable is quantified at the definition level).
*)
fuel0 : VarId.id;
(** The original fuel taken as input by the function (if we use fuel) *)
@@ -163,28 +202,50 @@ type bs_ctx = {
(** The input parameters for the forward function corresponding to the
translated Rust inputs (no fuel, no state).
*)
- backward_inputs : var list T.RegionGroupId.Map.t;
+ backward_inputs_no_state : var list RegionGroupId.Map.t;
(** The additional input parameters for the backward functions coming
from the borrows consumed upon ending the lifetime (as a consequence
those don't include the backward state, if there is one).
- *)
- backward_outputs : var list T.RegionGroupId.Map.t;
- (** The variables that the backward functions will output, corresponding
- to the borrows they give back (don't include the backward state)
- *)
- loop_backward_outputs : var list T.RegionGroupId.Map.t option;
- (** Same as {!backward_outputs}, but for loops (if we entered a loop).
- [None] if we are not inside a loop, [Some] otherwise (and whatever
- the kind of function we are translating: it will be [Some] even
- though we are synthesizing a forward function).
+ If we split the forward/backward functions: we initialize this map
+ when initializing the bs_ctx, because those variables are quantified
+ at the definition level. Otherwise, we initialize it upon diving
+ into the expressions which are specific to the backward functions.
+ *)
+ backward_inputs_with_state : var list RegionGroupId.Map.t;
+ (** All the additional input parameters for the backward functions.
- TODO: move to {!loop_info}
+ Same remarks as for {!backward_inputs_no_state}.
+ *)
+ backward_outputs : var list option;
+ (** The variables that the backward functions will output, corresponding
+ to the borrows they give back (don't include the backward state).
+
+ The translation is done as follows:
+ - when we detect the ended input abstraction which corresponds
+ to the backward function of the LLBC function we are translating,
+ and which consumed the values [consumed_i] (that we need to give
+ back to the caller), we introduce:
+ {[
+ let v_i = consumed_i in
+ ...
+ ]}
+ where the [v_i] are fresh, and are stored in the [backward_output].
+ - Then, upon reaching the [Return] node, we introduce:
+ {[
+ return (v_i)
+ ]}
+
+ The option is [None] before we detect the ended input abstraction,
+ and [Some] afterwards.
*)
calls : call_info V.FunCallId.Map.t;
(** The function calls we encountered so far *)
abstractions : (V.abs * texpression list) V.AbstractionId.Map.t;
- (** The ended abstractions we encountered so far, with their additional input arguments *)
+ (** The ended abstractions we encountered so far, with their additional
+ input arguments. We store it here and not in {!call_info} because
+ we need a map from abstraction id to abstraction (and not
+ from call id + region group id to abstraction). *)
loop_ids_map : LoopId.id V.LoopId.Map.t; (** Ids to use for the loops *)
loops : loop_info LoopId.Map.t;
(** The loops we encountered so far.
@@ -209,9 +270,9 @@ type bs_ctx = {
(* TODO: move *)
let bs_ctx_to_fmt_env (ctx : bs_ctx) : Print.fmt_env =
- let type_decls = ctx.type_context.llbc_type_decls in
- let fun_decls = ctx.fun_context.llbc_fun_decls in
- let global_decls = ctx.global_context.llbc_global_decls in
+ let type_decls = ctx.type_ctx.llbc_type_decls in
+ let fun_decls = ctx.fun_ctx.llbc_fun_decls in
+ let global_decls = ctx.global_ctx.llbc_global_decls in
let trait_decls = ctx.trait_decls_ctx in
let trait_impls = ctx.trait_impls_ctx in
let { regions; types; const_generics; trait_clauses } : T.generic_params =
@@ -233,9 +294,9 @@ let bs_ctx_to_fmt_env (ctx : bs_ctx) : Print.fmt_env =
}
let bs_ctx_to_pure_fmt_env (ctx : bs_ctx) : PrintPure.fmt_env =
- let type_decls = ctx.type_context.llbc_type_decls in
- let fun_decls = ctx.fun_context.llbc_fun_decls in
- let global_decls = ctx.global_context.llbc_global_decls in
+ let type_decls = ctx.type_ctx.llbc_type_decls in
+ let fun_decls = ctx.fun_ctx.llbc_fun_decls in
+ let global_decls = ctx.global_ctx.llbc_global_decls in
let trait_decls = ctx.trait_decls_ctx in
let trait_impls = ctx.trait_impls_ctx in
let generics = ctx.sg.generics in
@@ -300,6 +361,17 @@ let typed_pattern_to_string (ctx : bs_ctx) (p : Pure.typed_pattern) : string =
let env = bs_ctx_to_pure_fmt_env ctx in
PrintPure.typed_pattern_to_string env p
+let ctx_get_effect_info_for_bid (ctx : bs_ctx) (bid : RegionGroupId.id option) :
+ fun_effect_info =
+ match bid with
+ | None -> ctx.sg.fwd_info.effect_info
+ | Some bid ->
+ let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in
+ back_sg.effect_info
+
+let ctx_get_effect_info (ctx : bs_ctx) : fun_effect_info =
+ ctx_get_effect_info_for_bid ctx ctx.bid
+
(* TODO: move *)
let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string =
let env = bs_ctx_to_fmt_env ctx in
@@ -308,33 +380,13 @@ let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string =
let indent_incr = " " in
Print.Values.abs_to_string env verbose indent indent_incr abs
-let get_instantiated_fun_sig (fun_id : A.fun_id)
- (back_id : T.RegionGroupId.id option) (generics : generic_args)
- (ctx : bs_ctx) : inst_fun_sig =
- (* Lookup the non-instantiated function signature *)
- let sg =
- (RegularFunIdNotLoopMap.find (fun_id, back_id) ctx.fun_context.fun_sigs).sg
- in
- (* Create the substitution *)
- (* There shouldn't be any reference to Self *)
- let tr_self = UnknownTrait __FUNCTION__ in
- let subst = make_subst_from_generics sg.generics generics tr_self in
- (* Apply *)
- fun_sig_substitute subst sg
-
let bs_ctx_lookup_llbc_type_decl (id : TypeDeclId.id) (ctx : bs_ctx) :
T.type_decl =
- TypeDeclId.Map.find id ctx.type_context.llbc_type_decls
+ TypeDeclId.Map.find id ctx.type_ctx.llbc_type_decls
let bs_ctx_lookup_llbc_fun_decl (id : A.FunDeclId.id) (ctx : bs_ctx) :
A.fun_decl =
- A.FunDeclId.Map.find id ctx.fun_context.llbc_fun_decls
-
-(* TODO: move *)
-let bs_ctx_lookup_local_function_sig (def_id : A.FunDeclId.id)
- (back_id : T.RegionGroupId.id option) (ctx : bs_ctx) : fun_sig =
- let id = (E.FRegular def_id, back_id) in
- (RegularFunIdNotLoopMap.find id ctx.fun_context.fun_sigs).sg
+ A.FunDeclId.Map.find id ctx.fun_ctx.llbc_fun_decls
(* Some generic translation functions (we need to translate different "flavours"
of types: forward types, backward types, etc.) *)
@@ -601,13 +653,13 @@ and translate_fwd_trait_instance_id (type_infos : type_infos)
(** Simply calls [translate_fwd_ty] *)
let ctx_translate_fwd_ty (ctx : bs_ctx) (ty : T.ty) : ty =
- let type_infos = ctx.type_context.type_infos in
+ let type_infos = ctx.type_ctx.type_infos in
translate_fwd_ty type_infos ty
(** Simply calls [translate_fwd_generic_args] *)
let ctx_translate_fwd_generic_args (ctx : bs_ctx) (generics : T.generic_args) :
generic_args =
- let type_infos = ctx.type_context.type_infos in
+ let type_infos = ctx.type_ctx.type_infos in
translate_fwd_generic_args type_infos generics
(** Translate a type, when some regions may have ended.
@@ -682,17 +734,21 @@ let rec translate_back_ty (type_infos : type_infos)
None
| TTraitType (trait_ref, generics, type_name) ->
assert (generics.regions = []);
- (* Translate the trait ref and the generics as "forward" generics -
- we do not want to filter any type *)
- let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
- let generics = translate_fwd_generic_args type_infos generics in
- Some (TTraitType (trait_ref, generics, type_name))
+ assert (
+ AssociatedTypes.trait_instance_id_is_local_clause trait_ref.trait_id);
+ if inside_mut then
+ (* Translate the trait ref and the generics as "forward" generics -
+ we do not want to filter any type *)
+ let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
+ let generics = translate_fwd_generic_args type_infos generics in
+ Some (TTraitType (trait_ref, generics, type_name))
+ else None
| TArrow _ -> raise (Failure "TODO")
(** Simply calls [translate_back_ty] *)
let ctx_translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool)
(inside_mut : bool) (ty : T.ty) : ty option =
- let type_infos = ctx.type_context.type_infos in
+ let type_infos = ctx.type_ctx.type_infos in
translate_back_ty type_infos keep_region inside_mut ty
let mk_type_check_ctx (ctx : bs_ctx) : PureTypeCheck.tc_ctx =
@@ -705,8 +761,8 @@ let mk_type_check_ctx (ctx : bs_ctx) : PureTypeCheck.tc_ctx =
in
let env = VarId.Map.empty in
{
- PureTypeCheck.type_decls = ctx.type_context.type_decls;
- global_decls = ctx.global_context.llbc_global_decls;
+ PureTypeCheck.type_decls = ctx.type_ctx.type_decls;
+ global_decls = ctx.global_ctx.llbc_global_decls;
env;
const_generics;
}
@@ -726,31 +782,36 @@ let translate_fun_id_or_trait_method_ref (ctx : bs_ctx)
match id with
| FunId fun_id -> FunId fun_id
| TraitMethod (trait_ref, method_name, fun_decl_id) ->
- let type_infos = ctx.type_context.type_infos in
+ let type_infos = ctx.type_ctx.type_infos in
let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
TraitMethod (trait_ref, method_name, fun_decl_id)
let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
- (args : texpression list) (ctx : bs_ctx) : bs_ctx =
+ (args : texpression list)
+ (back_funs : texpression option RegionGroupId.Map.t option) (ctx : bs_ctx) :
+ bs_ctx =
let calls = ctx.calls in
assert (not (V.FunCallId.Map.mem call_id calls));
- let info =
- { forward; forward_inputs = args; backwards = T.RegionGroupId.Map.empty }
- in
+ let info = { forward; forward_inputs = args; back_funs } in
let calls = V.FunCallId.Map.add call_id info calls in
{ ctx with calls }
-(** [back_args]: the *additional* list of inputs received by the backward function *)
-let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
- (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx)
- : bs_ctx * fun_or_op_id =
+(** [inherit_args]: the list of inputs inherited from the forward function and
+ the ancestors backward functions, if pertinent.
+ [back_args]: the *additional* list of inputs received by the backward function,
+ including the state.
+
+ Returns the updated context and the expression corresponding to the function
+ that we need to call. This function may be [None] if it has to be ignored
+ (because it does nothing).
+ *)
+let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info)
+ (call_id : V.FunCallId.id) (back_id : T.RegionGroupId.id)
+ (inherited_args : texpression list) (back_args : texpression list)
+ (generics : generic_args) (output_ty : ty) (ctx : bs_ctx) :
+ bs_ctx * texpression option =
(* Insert the abstraction in the call informations *)
let info = V.FunCallId.Map.find call_id ctx.calls in
- assert (not (T.RegionGroupId.Map.mem back_id info.backwards));
- let backwards =
- T.RegionGroupId.Map.add back_id (abs, back_args) info.backwards
- in
- let info = { info with backwards } in
let calls = V.FunCallId.Map.add call_id info ctx.calls in
(* Insert the abstraction in the abstractions map *)
let abstractions = ctx.abstractions in
@@ -758,16 +819,31 @@ let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
let abstractions =
V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions
in
- (* Retrieve the fun_id *)
- let fun_id =
- match info.forward.call_id with
- | S.Fun (fid, _) ->
- let fid = translate_fun_id_or_trait_method_ref ctx fid in
- Fun (FromLlbc (fid, None, Some back_id))
- | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
+ (* Compute the expression corresponding to the function *)
+ let func =
+ if !Config.return_back_funs then
+ (* Lookup the variable introduced for the backward function *)
+ RegionGroupId.Map.find back_id (Option.get info.back_funs)
+ else
+ (* Retrieve the fun_id *)
+ let fun_id =
+ match info.forward.call_id with
+ | S.Fun (fid, _) ->
+ let fid = translate_fun_id_or_trait_method_ref ctx fid in
+ Fun (FromLlbc (fid, None, Some back_id))
+ | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
+ in
+ let args = List.append inherited_args back_args in
+ let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
+ let ret_ty =
+ if effect_info.can_fail then mk_result_ty output_ty else output_ty
+ in
+ let func_ty = mk_arrows input_tys ret_ty in
+ let func = { id = FunOrOp fun_id; generics } in
+ Some { e = Qualif func; ty = func_ty }
in
(* Update the context and return *)
- ({ ctx with calls; abstractions }, fun_id)
+ ({ ctx with calls; abstractions }, func)
(** List the ancestors of an abstraction *)
let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs)
@@ -821,7 +897,7 @@ let mk_fuel_input_as_list (ctx : bs_ctx) (info : fun_effect_info) :
if function_uses_fuel info then [ mk_fuel_texpression ctx.fuel ] else []
(** Small utility. *)
-let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t)
+let compute_raw_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t)
(fun_id : A.fun_id_or_trait_method_ref) (lid : V.LoopId.id option)
(gid : T.RegionGroupId.id option) : fun_effect_info =
match fun_id with
@@ -848,33 +924,53 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t)
is_rec = false;
}
-(** Translate a function signature.
+(** TODO: not very clean. *)
+let get_fun_effect_info (ctx : bs_ctx) (fun_id : A.fun_id_or_trait_method_ref)
+ (lid : V.LoopId.id option) (gid : T.RegionGroupId.id option) :
+ fun_effect_info =
+ match lid with
+ | None -> (
+ match fun_id with
+ | TraitMethod (_, _, fid) | FunId (FRegular fid) ->
+ let dsg = A.FunDeclId.Map.find fid ctx.fun_dsigs in
+ let info =
+ match gid with
+ | None -> dsg.fwd_info.effect_info
+ | Some gid -> (RegionGroupId.Map.find gid dsg.back_sg).effect_info
+ in
+ { info with is_rec = info.is_rec || Option.is_some lid }
+ | FunId (FAssumed _) ->
+ compute_raw_fun_effect_info ctx.fun_ctx.fun_infos fun_id lid gid)
+ | Some lid -> (
+ (* This is necessarily for the current function *)
+ match fun_id with
+ | FunId (FRegular fid) -> (
+ assert (fid = ctx.fun_decl.def_id);
+ (* Lookup the loop *)
+ let lid = V.LoopId.Map.find lid ctx.loop_ids_map in
+ let loop_info = LoopId.Map.find lid ctx.loops in
+ match gid with
+ | None -> loop_info.fwd_effect_info
+ | Some gid -> RegionGroupId.Map.find gid loop_info.back_effect_infos)
+ | _ -> raise (Failure "Unreachable"))
+
+(** Translate a function signature to a decomposed function signature.
Note that the function also takes a list of names for the inputs, and
computes, for every output for the backward functions, a corresponding
name (outputs for backward functions come from borrows in the inputs
of the forward function) which we use as hints to generate pretty names
in the extracted code.
+
+ We use [bid] ("backward function id") only if we split the forward
+ and the backward functions.
*)
-let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id)
- (sg : A.fun_sig) (input_names : string option list)
- (bid : T.RegionGroupId.id option) : fun_sig_named_outputs =
+let translate_fun_sig_with_regions_hierarchy_to_decomposed
+ (decls_ctx : C.decls_ctx) (fun_id : A.fun_id_or_trait_method_ref)
+ (regions_hierarchy : T.region_var_groups) (sg : A.fun_sig)
+ (input_names : string option list) : decomposed_fun_sig =
let fun_infos = decls_ctx.fun_ctx.fun_infos in
let type_infos = decls_ctx.type_ctx.type_infos in
- (* Retrieve the list of parent backward functions *)
- let regions_hierarchy =
- FunIdMap.find fun_id decls_ctx.fun_ctx.regions_hierarchies
- in
- let gid, parents =
- match bid with
- | None -> (None, T.RegionGroupId.Set.empty)
- | Some bid ->
- let parents = list_ancestor_region_groups regions_hierarchy bid in
- (Some bid, parents)
- in
- (* Is the function stateful, and can it fail? *)
- let lid = None in
- let effect_info = get_fun_effect_info fun_infos (FunId fun_id) lid bid in
(* We need an evaluation context to normalize the types (to normalize the
associated types, etc. - for instance it may happen that the types
refer to the types associated to a trait ref, but where the trait ref
@@ -885,7 +981,7 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id)
List.map (fun (g : T.region_var_group) -> g.id) regions_hierarchy
in
let ctx =
- InterpreterUtils.initialize_eval_context decls_ctx region_groups
+ InterpreterUtils.initialize_eval_ctx decls_ctx region_groups
sg.generics.types sg.generics.const_generics
in
(* Compute the normalization map for the *sty* types and add it to the context *)
@@ -902,17 +998,28 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id)
{ sg with A.inputs; output }
in
- (* List the inputs for:
- * - the fuel
- * - the forward function
- * - the parent backward functions, in proper order
- * - the current backward function (if it is a backward function)
- *)
- let fuel = mk_fuel_input_ty_as_list effect_info in
- let fwd_inputs = List.map (translate_fwd_ty type_infos) sg.inputs in
- (* For the backward functions: for now we don't supported nested borrows,
- * so just check that there aren't parent regions *)
- assert (T.RegionGroupId.Set.is_empty parents);
+ (* Is the forward function stateful, and can it fail? *)
+ let fwd_effect_info =
+ compute_raw_fun_effect_info fun_infos fun_id None None
+ in
+ (* Compute the forward inputs *)
+ let fwd_fuel = mk_fuel_input_ty_as_list fwd_effect_info in
+ let fwd_inputs_no_fuel_no_state =
+ List.map (translate_fwd_ty type_infos) sg.inputs
+ in
+ (* State input for the forward function *)
+ let fwd_state_ty =
+ (* For the forward state, we check if the *whole group* is stateful.
+ See {!effect_info}. *)
+ if fwd_effect_info.stateful_group then [ mk_state_ty ] else []
+ in
+ let fwd_inputs =
+ List.concat [ fwd_fuel; fwd_inputs_no_fuel_no_state; fwd_state_ty ]
+ in
+ (* Compute the backward output, without the effect information *)
+ let fwd_output = translate_fwd_ty type_infos sg.output in
+
+ (* Compute the type information for the backward function *)
(* Small helper to translate types for backward functions *)
let translate_back_ty_for_gid (gid : T.RegionGroupId.id) (ty : T.ty) :
ty option =
@@ -939,163 +1046,336 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id)
let inside_mut = false in
translate_back_ty type_infos keep_region inside_mut ty
in
- (* Compute the additinal inputs for the current function, if it is a backward
- * function *)
- let back_inputs =
- match gid with
- | None -> []
- | Some gid ->
- (* For now, we don't allow nested borrows, so the additional inputs to the
- backward function can only come from borrows that were returned like
- in (for the backward function we introduce for 'a):
- {[
- fn f<'a>(...) -> &'a mut u32;
- ]}
- Upon ending the abstraction for 'a, we need to get back the borrow
- the function returned.
- *)
- List.filter_map (translate_back_ty_for_gid gid) [ sg.output ]
- in
- (* If the function is stateful, the inputs are:
- - forward: [fwd_ty0, ..., fwd_tyn, state]
- - backward:
- - if {!Config.backward_no_state_update}: [fwd_ty0, ..., fwd_tyn, state, back_ty, state]
- - otherwise: [fwd_ty0, ..., fwd_tyn, state, back_ty]
-
- The backward takes the same state as input as the forward function,
- together with the state at the point where it gets called, if it is
- stateful.
-
- See the comments for {!Config.backward_no_state_update}
- *)
- let fwd_state_ty =
- (* For the forward state, we check if the *whole group* is stateful.
- See {!effect_info}. *)
- if effect_info.stateful_group then [ mk_state_ty ] else []
+ let translate_back_inputs_for_gid (gid : T.RegionGroupId.id) : ty list =
+ (* For now we don't supported nested borrows, so we check that there
+ aren't parent regions *)
+ let parents = list_ancestor_region_groups regions_hierarchy gid in
+ assert (T.RegionGroupId.Set.is_empty parents);
+ (* For now, we don't allow nested borrows, so the additional inputs to the
+ backward function can only come from borrows that were returned like
+ in (for the backward function we introduce for 'a):
+ {[
+ fn f<'a>(...) -> &'a mut u32;
+ ]}
+ Upon ending the abstraction for 'a, we need to get back the borrow
+ the function returned.
+ *)
+ let inputs =
+ List.filter_map (translate_back_ty_for_gid gid) [ sg.output ]
+ in
+ log#ldebug
+ (lazy
+ (let ctx = Print.Contexts.decls_ctx_to_fmt_env decls_ctx in
+ let pctx = PrintPure.decls_ctx_to_fmt_env decls_ctx in
+ let output = Print.Types.ty_to_string ctx sg.output in
+ let inputs =
+ Print.list_to_string (PrintPure.ty_to_string pctx false) inputs
+ in
+ "translate_back_inputs_for_gid:" ^ "\n- gid: "
+ ^ RegionGroupId.to_string gid
+ ^ "\n- output: " ^ output ^ "\n- back inputs: " ^ inputs ^ "\n"));
+ inputs
+ in
+ let compute_back_outputs_for_gid (gid : RegionGroupId.id) :
+ string option list * ty list =
+ (* The outputs are the borrows inside the regions of the abstractions
+ and which are present in the input values. For instance, see:
+ {[
+ fn f<'a>(x : &'a mut u32) -> ...;
+ ]}
+ Upon ending the abstraction for 'a, we give back the borrow which
+ was consumed through the [x] parameter.
+ *)
+ let outputs =
+ List.map
+ (fun (name, input_ty) -> (name, translate_back_ty_for_gid gid input_ty))
+ (List.combine input_names sg.inputs)
+ in
+ (* Filter *)
+ let outputs =
+ List.filter (fun (_, opt_ty) -> Option.is_some opt_ty) outputs
+ in
+ let outputs =
+ List.map (fun (name, opt_ty) -> (name, Option.get opt_ty)) outputs
+ in
+ let names, outputs = List.split outputs in
+ log#ldebug
+ (lazy
+ (let ctx = Print.Contexts.decls_ctx_to_fmt_env decls_ctx in
+ let pctx = PrintPure.decls_ctx_to_fmt_env decls_ctx in
+ let inputs =
+ Print.list_to_string (Print.Types.ty_to_string ctx) sg.inputs
+ in
+ let outputs =
+ Print.list_to_string (PrintPure.ty_to_string pctx false) outputs
+ in
+ "compute_back_outputs_for_gid:" ^ "\n- gid: "
+ ^ RegionGroupId.to_string gid
+ ^ "\n- inputs: " ^ inputs ^ "\n- back outputs: " ^ outputs ^ "\n"));
+ (names, outputs)
+ in
+ let compute_back_info_for_group (rg : T.region_var_group) :
+ RegionGroupId.id * back_sg_info =
+ let gid = rg.id in
+ let back_effect_info =
+ compute_raw_fun_effect_info fun_infos fun_id None (Some gid)
+ in
+ let inputs_no_state = translate_back_inputs_for_gid gid in
+ let inputs_no_state =
+ List.map (fun ty -> (Some "ret", ty)) inputs_no_state
+ in
+ (* In case we merge the forward/backward functions:
+ we consider the backward function as stateful and potentially failing
+ **only if it has inputs** (for the "potentially failing": if it has
+ not inputs, we directly evaluate it in the body of the forward function).
+ *)
+ let back_effect_info =
+ if !Config.return_back_funs then
+ let b = inputs_no_state <> [] in
+ {
+ back_effect_info with
+ stateful = back_effect_info.stateful && b;
+ can_fail = back_effect_info.can_fail && b;
+ }
+ else back_effect_info
+ in
+ let state =
+ if back_effect_info.stateful then [ (None, mk_state_ty) ] else []
+ in
+ let inputs = inputs_no_state @ state in
+ let output_names, outputs = compute_back_outputs_for_gid gid in
+ let filter =
+ !Config.simplify_merged_fwd_backs
+ && !Config.return_back_funs && inputs = [] && outputs = []
+ in
+ let info =
+ {
+ inputs;
+ inputs_no_state;
+ outputs;
+ output_names;
+ effect_info = back_effect_info;
+ filter;
+ }
+ in
+ (gid, info)
in
- let back_state_ty =
- (* For the backward state, we check if the function is a backward function,
- and it is stateful *)
- if effect_info.stateful && Option.is_some gid then [ mk_state_ty ] else []
+ let back_sg =
+ RegionGroupId.Map.of_list
+ (List.map compute_back_info_for_group regions_hierarchy)
in
- (* Concatenate the inputs, in the following order:
- * - forward inputs
- * - forward state input
- * - backward inputs
- * - backward state input
- *)
- let inputs =
- List.concat [ fuel; fwd_inputs; fwd_state_ty; back_inputs; back_state_ty ]
- in
- (* Outputs *)
- let output_names, doutputs =
- match gid with
- | None ->
- (* This is a forward function: there is one (unnamed) output *)
- ([ None ], [ translate_fwd_ty type_infos sg.output ])
- | Some gid ->
- (* This is a backward function: there might be several outputs.
- The outputs are the borrows inside the regions of the abstractions
- and which are present in the input values. For instance, see:
- {[
- fn f<'a>(x : &'a mut u32) -> ...;
- ]}
- Upon ending the abstraction for 'a, we give back the borrow which
- was consumed through the [x] parameter.
- *)
- let outputs =
- List.map
- (fun (name, input_ty) ->
- (name, translate_back_ty_for_gid gid input_ty))
- (List.combine input_names sg.inputs)
- in
- (* Filter *)
- let outputs =
- List.filter (fun (_, opt_ty) -> Option.is_some opt_ty) outputs
- in
- let outputs =
- List.map (fun (name, opt_ty) -> (name, Option.get opt_ty)) outputs
- in
- List.split outputs
- in
- (* Create the return type *)
- let output =
- (* Group the outputs together *)
- let output = mk_simpl_tuple_ty doutputs in
- (* Add the output state *)
- let output =
- if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ]
- else output
+ (* The additional information about the forward function *)
+ let fwd_info =
+ (* *)
+ let has_fuel = fwd_fuel <> [] in
+ let num_inputs_no_fuel_no_state = List.length fwd_inputs_no_fuel_no_state in
+ let num_inputs_with_fuel_no_state =
+ (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *)
+ List.length fwd_fuel + num_inputs_no_fuel_no_state
+ in
+ let fwd_info : inputs_info =
+ {
+ has_fuel;
+ num_inputs_no_fuel_no_state;
+ num_inputs_with_fuel_no_state;
+ num_inputs_with_fuel_with_state =
+ (* We use the fact that [fwd_state_ty] has length 1 if there is a state,
+ and 0 otherwise *)
+ num_inputs_with_fuel_no_state + List.length fwd_state_ty;
+ }
+ in
+ let ignore_output =
+ if !Config.return_back_funs && !Config.simplify_merged_fwd_backs then
+ ty_is_unit fwd_output
+ && List.exists
+ (fun (info : back_sg_info) -> not info.filter)
+ (RegionGroupId.Map.values back_sg)
+ else false
in
- (* Wrap in a result type *)
- if effect_info.can_fail then mk_result_ty output else output
+ let info = { fwd_info; effect_info = fwd_effect_info; ignore_output } in
+ assert (fun_sig_info_is_wf info);
+ info
in
+
(* Generic parameters *)
let generics = translate_generic_params sg.generics in
+
(* Return *)
- let has_fuel = fuel <> [] in
- let num_fwd_inputs_no_state = List.length fwd_inputs in
- let num_fwd_inputs_with_fuel_no_state =
- (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *)
- List.length fuel + num_fwd_inputs_no_state
- in
- let num_back_inputs_no_state =
- if bid = None then None else Some (List.length back_inputs)
+ let preds = translate_predicates sg.preds in
+ {
+ generics;
+ llbc_generics = sg.generics;
+ preds;
+ fwd_inputs;
+ fwd_output;
+ back_sg;
+ fwd_info;
+ }
+
+let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx)
+ (fun_id : FunDeclId.id) (sg : A.fun_sig) (input_names : string option list)
+ : decomposed_fun_sig =
+ (* Retrieve the list of parent backward functions *)
+ let regions_hierarchy =
+ FunIdMap.find (FRegular fun_id) decls_ctx.fun_ctx.regions_hierarchies
in
- let info =
- {
- has_fuel;
- num_fwd_inputs_with_fuel_no_state;
- num_fwd_inputs_with_fuel_with_state =
- (* We use the fact that [fwd_state_ty] has length 1 if there is a state,
- and 0 otherwise *)
- num_fwd_inputs_with_fuel_no_state + List.length fwd_state_ty;
- num_back_inputs_no_state;
- num_back_inputs_with_state =
- (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *)
- Option.map
- (fun n -> n + List.length back_state_ty)
- num_back_inputs_no_state;
- effect_info;
- }
+ translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx
+ (FunId (FRegular fun_id)) regions_hierarchy sg input_names
+
+let translate_fun_sig_from_decl_to_decomposed (decls_ctx : C.decls_ctx)
+ (fdef : LlbcAst.fun_decl) : decomposed_fun_sig =
+ let input_names =
+ match fdef.body with
+ | None -> List.map (fun _ -> None) fdef.signature.inputs
+ | Some body ->
+ List.map
+ (fun (v : LlbcAst.var) -> v.name)
+ (LlbcAstUtils.fun_body_get_input_vars body)
in
- let preds = translate_predicates sg.preds in
let sg =
- {
- generics;
- llbc_generics = sg.generics;
- preds;
- inputs;
- output;
- doutputs;
- info;
- }
+ translate_fun_sig_to_decomposed decls_ctx fdef.def_id fdef.signature
+ input_names
+ in
+ log#ldebug
+ (lazy
+ ("translate_fun_sig_from_decl_to_decomposed:" ^ "\n- name: "
+ ^ T.show_name fdef.name ^ "\n- sg:\n" ^ show_decomposed_fun_sig sg ^ "\n"));
+ sg
+
+let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty
+ =
+ let output =
+ if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ] else ty
in
- { sg; output_names }
+ if effect_info.can_fail then mk_result_ty output else output
-let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern =
+let mk_back_output_ty_from_effect_info (effect_info : fun_effect_info)
+ (inputs : ty list) (ty : ty) : ty =
+ let output =
+ if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ] else ty
+ in
+ if effect_info.can_fail && inputs <> [] then mk_result_ty output else output
+
+(** Compute the arrow types for all the backward functions.
+
+ If a backward function has no inputs/outputs we filter it.
+ *)
+let compute_back_tys_with_info (dsg : Pure.decomposed_fun_sig)
+ (subst : (generic_args * trait_instance_id) option) :
+ (back_sg_info * ty) option list =
+ List.map
+ (fun (back_sg : back_sg_info) ->
+ let effect_info = back_sg.effect_info in
+ (* Compute the input/output types *)
+ let inputs = List.map snd back_sg.inputs in
+ let outputs = back_sg.outputs in
+ (* Filter if necessary *)
+ if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] then
+ None
+ else
+ let output = mk_simpl_tuple_ty outputs in
+ let output =
+ mk_back_output_ty_from_effect_info effect_info inputs output
+ in
+ let ty = mk_arrows inputs output in
+ (* Substitute - TODO: normalize *)
+ let ty =
+ match subst with
+ | None -> ty
+ | Some (generics, tr_self) ->
+ let subst =
+ make_subst_from_generics dsg.generics generics tr_self
+ in
+ ty_substitute subst ty
+ in
+ Some (back_sg, ty))
+ (RegionGroupId.Map.values dsg.back_sg)
+
+let compute_back_tys (dsg : Pure.decomposed_fun_sig)
+ (subst : (generic_args * trait_instance_id) option) : ty option list =
+ List.map (Option.map snd) (compute_back_tys_with_info dsg subst)
+
+(** In case we merge the fwd/back functions: compute the output type of
+ a function, from a decomposed signature. *)
+let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty =
+ assert !Config.return_back_funs;
+ (* Compute the arrow types for all the backward functions *)
+ let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in
+ (* Group the forward output and the types of the backward functions *)
+ let effect_info = dsg.fwd_info.effect_info in
+ let output =
+ (* We might need to ignore the output of the forward function
+ (if it is unit for instance) *)
+ let tys =
+ if dsg.fwd_info.ignore_output then back_tys
+ else dsg.fwd_output :: back_tys
+ in
+ mk_simpl_tuple_ty tys
+ in
+ mk_output_ty_from_effect_info effect_info output
+
+let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
+ (gid : RegionGroupId.id option) : fun_sig =
+ let generics = dsg.generics in
+ let llbc_generics = dsg.llbc_generics in
+ let preds = dsg.preds in
+ (* Compute the effects info *)
+ let fwd_info = dsg.fwd_info in
+ let back_effect_info =
+ RegionGroupId.Map.of_list
+ (List.map
+ (fun ((gid, info) : RegionGroupId.id * back_sg_info) ->
+ (gid, info.effect_info))
+ (RegionGroupId.Map.bindings dsg.back_sg))
+ in
+ let mk_output_ty = mk_output_ty_from_effect_info in
+ let inputs, output =
+ (* Two cases depending on whether we split the forward/backward functions or not *)
+ if !Config.return_back_funs then (
+ assert (gid = None);
+ let output = compute_output_ty_from_decomposed dsg in
+ let inputs = dsg.fwd_inputs in
+ (inputs, output))
+ else
+ match gid with
+ | None ->
+ let effect_info = dsg.fwd_info.effect_info in
+ let output = mk_output_ty effect_info dsg.fwd_output in
+ (dsg.fwd_inputs, output)
+ | Some gid ->
+ let back_sg = RegionGroupId.Map.find gid dsg.back_sg in
+ let effect_info = back_sg.effect_info in
+ let inputs = dsg.fwd_inputs @ List.map snd back_sg.inputs in
+ let output = mk_simpl_tuple_ty back_sg.outputs in
+ let output = mk_output_ty effect_info output in
+ (inputs, output)
+ in
+ { generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info }
+
+let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * var * typed_pattern =
(* Generate the fresh variable *)
- let id, var_counter = VarId.fresh ctx.var_counter in
+ let id, var_counter = VarId.fresh !(ctx.var_counter) in
let state_var =
{ id; basename = Some ConstStrings.state_basename; ty = mk_state_ty }
in
let state_pat = mk_typed_pattern_from_var state_var None in
(* Update the context *)
- let ctx = { ctx with var_counter; state_var = id } in
+ ctx.var_counter := var_counter;
+ let ctx = { ctx with state_var = id } in
(* Return *)
- (ctx, state_pat)
+ (ctx, state_var, state_pat)
(** WARNING: do not call this function directly.
Call [fresh_named_var_for_symbolic_value] instead. *)
let fresh_var_llbc_ty (basename : string option) (ty : T.ty) (ctx : bs_ctx) :
bs_ctx * var =
(* Generate the fresh variable *)
- let id, var_counter = VarId.fresh ctx.var_counter in
+ let id, var_counter = VarId.fresh !(ctx.var_counter) in
let ty = ctx_translate_fwd_ty ctx ty in
let var = { id; basename; ty } in
(* Update the context *)
- let ctx = { ctx with var_counter } in
+ ctx.var_counter := var_counter;
(* Return *)
(ctx, var)
@@ -1129,10 +1409,10 @@ let fresh_named_vars_for_symbolic_values
let fresh_var (basename : string option) (ty : ty) (ctx : bs_ctx) : bs_ctx * var
=
(* Generate the fresh variable *)
- let id, var_counter = VarId.fresh ctx.var_counter in
+ let id, var_counter = VarId.fresh !(ctx.var_counter) in
let var = { id; basename; ty } in
(* Update the context *)
- let ctx = { ctx with var_counter } in
+ ctx.var_counter := var_counter;
(* Return *)
(ctx, var)
@@ -1140,6 +1420,58 @@ let fresh_vars (vars : (string option * ty) list) (ctx : bs_ctx) :
bs_ctx * var list =
List.fold_left_map (fun ctx (name, ty) -> fresh_var name ty ctx) ctx vars
+let fresh_opt_vars (vars : (string option * ty) option list) (ctx : bs_ctx) :
+ bs_ctx * var option list =
+ List.fold_left_map
+ (fun ctx var ->
+ match var with
+ | None -> (ctx, None)
+ | Some (name, ty) ->
+ let ctx, var = fresh_var name ty ctx in
+ (ctx, Some var))
+ ctx vars
+
+(* Introduce variables for the backward functions *)
+let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var option list =
+ (* We lookup the LLBC definition in an attempt to derive pretty names
+ for the backward functions. *)
+ let back_var_names =
+ let def_id = ctx.fun_decl.def_id in
+ let sg = ctx.fun_decl.signature in
+ let regions_hierarchy =
+ LlbcAstUtils.FunIdMap.find (FRegular def_id)
+ ctx.fun_ctx.regions_hierarchies
+ in
+ List.map
+ (fun (gid, _) ->
+ let rg = RegionGroupId.nth regions_hierarchy gid in
+ let region_names =
+ List.map
+ (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name)
+ rg.regions
+ in
+ let name =
+ match region_names with
+ | [] -> "back"
+ | [ Some r ] -> "back" ^ r
+ | _ ->
+ (* Concatenate all the region names *)
+ "back"
+ ^ String.concat "" (List.filter_map (fun x -> x) region_names)
+ in
+ Some name)
+ (RegionGroupId.Map.bindings ctx.sg.back_sg)
+ in
+ let back_vars = List.combine back_var_names (compute_back_tys ctx.sg None) in
+ let back_vars =
+ List.map
+ (fun (name, ty) ->
+ match ty with None -> None | Some ty -> Some (name, ty))
+ back_vars
+ in
+ fresh_opt_vars back_vars ctx
+
+(** IMPORTANT: do not use this one directly, but rather {!symbolic_value_to_texpression} *)
let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var =
match V.SymbolicValueId.Map.find_opt sv.sv_id ctx.sv_to_var with
| Some v -> v
@@ -1158,12 +1490,22 @@ let rec unbox_typed_value (v : V.typed_value) : V.typed_value =
| _ -> raise (Failure "Unreachable"))
| _ -> v
-(** Translate a symbolic value *)
+(** Translate a symbolic value.
+
+ Because we do not necessarily introduce variables for the symbolic values
+ of (translated) type unit, it is important that we do not lookup variables
+ in case the symbolic value has type unit.
+ *)
let symbolic_value_to_texpression (ctx : bs_ctx) (sv : V.symbolic_value) :
texpression =
(* Translate the type *)
- let var = lookup_var_for_symbolic_value sv ctx in
- mk_texpression_from_var var
+ let ty = ctx_translate_fwd_ty ctx sv.sv_ty in
+ (* If the type is unit, directly return unit *)
+ if ty_is_unit ty then mk_unit_rvalue
+ else
+ (* Otherwise lookup the variable *)
+ let var = lookup_var_for_symbolic_value sv ctx in
+ mk_texpression_from_var var
(** Translate a typed value.
@@ -1342,13 +1684,11 @@ and aproj_to_consumed (ctx : bs_ctx) (aproj : V.aproj) : texpression option =
match aproj with
| V.AEndedProjLoans (msv, []) ->
(* The symbolic value was left unchanged *)
- let var = lookup_var_for_symbolic_value msv ctx in
- Some (mk_texpression_from_var var)
+ Some (symbolic_value_to_texpression ctx msv)
| V.AEndedProjLoans (_, [ (mnv, child_aproj) ]) ->
assert (child_aproj = AIgnoredProjBorrows);
(* The symbolic value was updated *)
- let var = lookup_var_for_symbolic_value mnv ctx in
- Some (mk_texpression_from_var var)
+ Some (symbolic_value_to_texpression ctx mnv)
| V.AEndedProjLoans (_, _) ->
(* The symbolic value was updated, and the given back values come from sevearl
* abstractions *)
@@ -1543,7 +1883,9 @@ let mk_emeta_symbolic_assignments (vars : var list) (values : texpression list)
let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =
match e with
- | S.Return (ectx, opt_v) -> translate_return ectx opt_v ctx
+ | S.Return (ectx, opt_v) ->
+ (* Remark: we can't get there if we are inside a loop *)
+ translate_return ectx opt_v ctx
| ReturnWithLoop (loop_id, is_continue) ->
translate_return_with_loop loop_id is_continue ctx
| Panic -> translate_panic ctx
@@ -1565,32 +1907,56 @@ and translate_panic (ctx : bs_ctx) : texpression =
* but it won't be true anymore once we translate individual blocks *)
(* If we use a state monad, we need to add a lambda for the state variable *)
(* Note that only forward functions return a state *)
- let output_ty =
- if ctx.inside_loop && Option.is_some ctx.bid then
- (* We are synthesizing the backward function of a loop body *)
- let bid = Option.get ctx.bid in
- let back_vars =
- T.RegionGroupId.Map.find bid (Option.get ctx.loop_backward_outputs)
- in
- let tys = List.map (fun (v : var) -> v.ty) back_vars in
- mk_simpl_tuple_ty tys
- else
- (* Regular function, or forward function (the forward translation for
- a loop has the same return type as the parent function)
- *)
- mk_simpl_tuple_ty ctx.sg.doutputs
- in
+ let effect_info = ctx_get_effect_info ctx in
(* TODO: we should use a [Fail] function *)
- if ctx.sg.info.effect_info.stateful then
- (* Create the [Fail] value *)
- let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in
- let ret_v =
- mk_result_fail_texpression_with_error_id error_failure_id ret_ty
- in
- ret_v
- else mk_result_fail_texpression_with_error_id error_failure_id output_ty
+ let mk_output output_ty =
+ if effect_info.stateful then
+ (* Create the [Fail] value *)
+ let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in
+ let ret_v =
+ mk_result_fail_texpression_with_error_id error_failure_id ret_ty
+ in
+ ret_v
+ else mk_result_fail_texpression_with_error_id error_failure_id output_ty
+ in
+ if ctx.inside_loop && Option.is_some ctx.bid then
+ (* We are synthesizing the backward function of a loop body *)
+ let bid = Option.get ctx.bid in
+ let loop_id = Option.get ctx.loop_id in
+ let loop = LoopId.Map.find loop_id ctx.loops in
+ let tys = RegionGroupId.Map.find bid loop.back_outputs in
+ let output = mk_simpl_tuple_ty tys in
+ mk_output output
+ else
+ (* Regular function, or forward function (the forward translation for
+ a loop has the same return type as the parent function)
+ *)
+ match ctx.bid with
+ | None ->
+ if !Config.return_back_funs then
+ let back_tys = compute_back_tys ctx.sg None in
+ let back_tys = List.filter_map (fun x -> x) back_tys in
+ let tys =
+ if ctx.sg.fwd_info.ignore_output then back_tys
+ else ctx.sg.fwd_output :: back_tys
+ in
+ let output = mk_simpl_tuple_ty tys in
+ mk_output output
+ else mk_output ctx.sg.fwd_output
+ | Some bid ->
+ let output =
+ mk_simpl_tuple_ty (RegionGroupId.Map.find bid ctx.sg.back_sg).outputs
+ in
+ mk_output output
+
+(** [opt_v]: the value to return, in case we translate a forward body.
-(** [opt_v]: the value to return, in case we translate a forward body *)
+ Remark: for now, we can't get there if we are inside a loop.
+ If inside a loop, we use {!translate_return_with_loop}.
+
+ Remark: in case we merge the forward/backward functions, we introduce
+ those in [translate_forward_end].
+*)
and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option)
(ctx : bs_ctx) : texpression =
(* There are two cases:
@@ -1599,22 +1965,20 @@ and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option)
- or we are translating a backward function, in which case it should be [None]
*)
(* Compute the values that we should return *without the state and the result
- * wrapper* *)
+ wrapper* *)
let output =
match ctx.bid with
| None ->
(* Forward function *)
let v = Option.get opt_v in
typed_value_to_texpression ctx ectx v
- | Some bid ->
+ | Some _ ->
(* Backward function *)
(* Sanity check *)
assert (opt_v = None);
(* Group the variables in which we stored the values we need to give back.
- * See the explanations for the [SynthInput] case in [translate_end_abstraction] *)
- let backward_outputs =
- T.RegionGroupId.Map.find bid ctx.backward_outputs
- in
+ See the explanations for the [SynthInput] case in [translate_end_abstraction] *)
+ let backward_outputs = Option.get ctx.backward_outputs in
let field_values = List.map mk_texpression_from_var backward_outputs in
mk_simpl_tuple_texpression field_values
in
@@ -1622,7 +1986,7 @@ and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option)
* - error-monad: Return x
* - state-error: Return (state, x)
* *)
- let effect_info = ctx.sg.info.effect_info in
+ let effect_info = ctx_get_effect_info ctx in
let output =
if effect_info.stateful then
let state_rvalue = mk_state_texpression ctx.state_var in
@@ -1647,24 +2011,12 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool)
*)
let output =
match ctx.bid with
- | None ->
- (* Forward *)
- mk_texpression_from_var
- (Option.get loop_info.forward_output_no_state_no_result)
- | Some bid ->
+ | None -> Option.get loop_info.forward_output_no_state_no_result
+ | Some _ ->
(* Backward *)
(* Group the variables in which we stored the values we need to give back.
* See the explanations for the [SynthInput] case in [translate_end_abstraction] *)
- let backward_outputs =
- let map =
- if ctx.inside_loop then
- (* We are synthesizing a loop body *)
- Option.get ctx.loop_backward_outputs
- else (* Regular function *)
- ctx.backward_outputs
- in
- T.RegionGroupId.Map.find bid map
- in
+ let backward_outputs = Option.get ctx.backward_outputs in
let field_values = List.map mk_texpression_from_var backward_outputs in
mk_simpl_tuple_texpression field_values
in
@@ -1676,7 +2028,7 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool)
* effect - in particular, one manipulates a state iff the other does
* the same.
* *)
- let effect_info = ctx.sg.info.effect_info in
+ let effect_info = ctx_get_effect_info ctx in
let output =
if effect_info.stateful then
let state_rvalue = mk_state_texpression ctx.state_var in
@@ -1684,13 +2036,15 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool)
else output
in
(* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *)
- mk_result_return_texpression output
+ mk_emeta (Tag "return_with_loop") (mk_result_return_texpression output)
and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
texpression =
log#ldebug
(lazy
- ("translate_function_call:\n"
+ ("translate_function_call:\n" ^ "\n- call.call_id:"
+ ^ S.show_call_id call.call_id
+ ^ "\n\n- call.generics:\n"
^ ctx_generic_args_to_string ctx call.generics));
(* Translate the function call *)
let generics = ctx_translate_fwd_generic_args ctx call.generics in
@@ -1702,10 +2056,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
(List.combine args args_mplaces)
in
let dest_mplace = translate_opt_mplace call.dest_place in
- let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
(* Retrieve the function id, and register the function call in the context
* if necessary. *)
- let ctx, fun_id, effect_info, args, out_state =
+ let ctx, fun_id, effect_info, args, dest_v =
match call.call_id with
| S.Fun (fid, call_id) ->
(* Regular function call *)
@@ -1713,25 +2066,150 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
let func = Fun (FromLlbc (fid_t, None, None)) in
(* Retrieve the effect information about this function (can fail,
* takes a state as input, etc.) *)
- let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos fid None None
- in
+ let effect_info = get_fun_effect_info ctx fid None None in
(* Depending on the function effects:
- * - add the fuel
- * - add the state input argument
- * - generate a fresh state variable for the returned state
+ - add the fuel
+ - add the state input argument
+ - generate a fresh state variable for the returned state
*)
let args, ctx, out_state =
let fuel = mk_fuel_input_as_list ctx effect_info in
if effect_info.stateful then
let state_var = mk_state_texpression ctx.state_var in
- let ctx, nstate_var = bs_ctx_fresh_state_var ctx in
+ let ctx, _, nstate_var = bs_ctx_fresh_state_var ctx in
(List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var)
else (List.concat [ fuel; args ], ctx, None)
in
+ (* If we do not split the forward/backward functions: generate the
+ variables for the backward functions returned by the forward
+ function. *)
+ let ctx, ignore_fwd_output, back_funs_map, back_funs =
+ if !Config.return_back_funs then (
+ (* We need to compute the signatures of the backward functions. *)
+ let sg = Option.get call.sg in
+ let decls_ctx = ctx.decls_ctx in
+ let dsg =
+ translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx
+ fid call.regions_hierarchy sg
+ (List.map (fun _ -> None) sg.inputs)
+ in
+ log#ldebug
+ (lazy ("dsg.generics:\n" ^ show_generic_params dsg.generics));
+ let tr_self, all_generics =
+ match call.trait_method_generics with
+ | None -> (UnknownTrait __FUNCTION__, generics)
+ | Some (all_generics, tr_self) ->
+ let all_generics =
+ ctx_translate_fwd_generic_args ctx all_generics
+ in
+ let tr_self =
+ translate_fwd_trait_instance_id ctx.type_ctx.type_infos
+ tr_self
+ in
+ (tr_self, all_generics)
+ in
+ let back_tys =
+ compute_back_tys_with_info dsg (Some (all_generics, tr_self))
+ in
+ (* Introduce variables for the backward functions *)
+ (* Compute a proper basename for the variables *)
+ let back_fun_name =
+ let name =
+ match fid with
+ | FunId (FAssumed fid) -> (
+ match fid with
+ | BoxNew -> "box_new"
+ | BoxFree -> "box_free"
+ | ArrayRepeat -> "array_repeat"
+ | ArrayIndexShared -> "index_shared"
+ | ArrayIndexMut -> "index_mut"
+ | ArrayToSliceShared -> "to_slice_shared"
+ | ArrayToSliceMut -> "to_slice_mut"
+ | SliceIndexShared -> "index_shared"
+ | SliceIndexMut -> "index_mut")
+ | FunId (FRegular fid) | TraitMethod (_, _, fid) -> (
+ let decl =
+ FunDeclId.Map.find fid ctx.fun_ctx.llbc_fun_decls
+ in
+ match Collections.List.last decl.name with
+ | PeIdent (s, _) -> s
+ | PeImpl _ ->
+ (* We shouldn't get there *)
+ raise (Failure "Unexpected"))
+ in
+ name ^ "_back"
+ in
+ let ctx, back_vars =
+ fresh_opt_vars
+ (List.map
+ (fun ty ->
+ match ty with
+ | None -> None
+ | Some (back_sg, ty) ->
+ (* We insert a name for the variable only if the function
+ can fail: if it can fail, it means the call returns a backward
+ function. Otherwise, we it directly returns the value given
+ back by the backward function, which means we shouldn't
+ give it a name like "back..." (it doesn't make sense) *)
+ let name =
+ if back_sg.effect_info.can_fail then
+ Some back_fun_name
+ else None
+ in
+ Some (name, ty))
+ back_tys)
+ ctx
+ in
+ let back_funs =
+ List.filter_map
+ (fun v ->
+ match v with
+ | None -> None
+ | Some v -> Some (mk_typed_pattern_from_var v None))
+ back_vars
+ in
+ let gids =
+ List.map
+ (fun (g : T.region_var_group) -> g.id)
+ call.regions_hierarchy
+ in
+ let back_vars =
+ List.map (Option.map mk_texpression_from_var) back_vars
+ in
+ let back_funs_map =
+ RegionGroupId.Map.of_list (List.combine gids back_vars)
+ in
+ (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs))
+ else (ctx, false, None, [])
+ in
+ (* Compute the pattern for the destination *)
+ let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
+ let dest = mk_typed_pattern_from_var dest dest_mplace in
+ let dest =
+ (* Here there is something subtle: as we might ignore the output
+ of the forward function (because it translates to unit) we doNOT
+ necessarily introduce in the let-binding the variable to which we
+ map the symbolic value which was introduced for the output of the
+ function call. This would be problematic if later we need to
+ translate this symbolic value, but we implemented
+ {!symbolic_value_to_texpression} so that it doesn't perform any
+ lookups if the symbolic value has type unit.
+ *)
+ let vars =
+ if ignore_fwd_output then back_funs else dest :: back_funs
+ in
+ mk_simpl_tuple_pattern vars
+ in
+ let dest =
+ match out_state with
+ | None -> dest
+ | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ]
+ in
(* Register the function call *)
- let ctx = bs_ctx_register_forward_call call_id call args ctx in
- (ctx, func, effect_info, args, out_state)
+ let ctx =
+ bs_ctx_register_forward_call call_id call args back_funs_map ctx
+ in
+ (ctx, func, effect_info, args, dest)
| S.Unop E.Not ->
let effect_info =
{
@@ -1742,7 +2220,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
is_rec = false;
}
in
- (ctx, Unop Not, effect_info, args, None)
+ let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
+ let dest = mk_typed_pattern_from_var dest dest_mplace in
+ (ctx, Unop Not, effect_info, args, dest)
| S.Unop E.Neg -> (
match args with
| [ arg ] ->
@@ -1758,7 +2238,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
is_rec = false;
}
in
- (ctx, Unop (Neg int_ty), effect_info, args, None)
+ let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
+ let dest = mk_typed_pattern_from_var dest dest_mplace in
+ (ctx, Unop (Neg int_ty), effect_info, args, dest)
| _ -> raise (Failure "Unreachable"))
| S.Unop (E.Cast cast_kind) -> (
match cast_kind with
@@ -1773,7 +2255,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
is_rec = false;
}
in
- (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None)
+ let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
+ let dest = mk_typed_pattern_from_var dest dest_mplace in
+ (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, dest)
| CastFnPtr _ -> raise (Failure "TODO: function casts"))
| S.Binop binop -> (
match args with
@@ -1793,15 +2277,11 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
is_rec = false;
}
in
- (ctx, Binop (binop, int_ty0), effect_info, args, None)
+ let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
+ let dest = mk_typed_pattern_from_var dest dest_mplace in
+ (ctx, Binop (binop, int_ty0), effect_info, args, dest)
| _ -> raise (Failure "Unreachable"))
in
- let dest_v =
- let dest = mk_typed_pattern_from_var dest dest_mplace in
- match out_state with
- | None -> dest
- | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ]
- in
let func = { id = FunOrOp fun_id; generics } in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
let ret_ty =
@@ -1846,45 +2326,38 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs)
^ abs_to_string ctx abs ^ "\n"));
(* When we end an input abstraction, this input abstraction gets back
- * the borrows which it introduced in the context through the input
- * values: by listing those values, we get the values which are given
- * back by one of the backward functions we are synthesizing. *)
- (* Note that we don't support nested borrows for now: if we find
- * an ended synthesized input abstraction, it must be the one corresponding
- * to the backward function wer are synthesizing, it can't be the one
- * for a parent backward function.
- *)
+ the borrows which it introduced in the context through the input
+ values: by listing those values, we get the values which are given
+ back by one of the backward functions we are synthesizing.
+
+ Note that we don't support nested borrows for now: if we find
+ an ended synthesized input abstraction, it must be the one corresponding
+ to the backward function wer are synthesizing, it can't be the one
+ for a parent backward function.
+ *)
let bid = Option.get ctx.bid in
assert (rg_id = bid);
- (* The translation is done as follows:
- * - for a given backward function, we choose a set of variables [v_i]
- * - when we detect the ended input abstraction which corresponds
- * to the backward function, and which consumed the values [consumed_i],
- * we introduce:
- * {[
- * let v_i = consumed_i in
- * ...
- * ]}
- * Then, when we reach the [Return] node, we introduce:
- * {[
- * (v_i)
- * ]}
- * *)
- (* First, get the given back variables.
+ (* First, introduce the given back variables.
We don't use the same given back variables if we translate a loop or
the standard body of a function.
*)
- let given_back_variables =
- let map =
+ let ctx, given_back_variables =
+ let vars =
if ctx.inside_loop then
(* We are synthesizing a loop body *)
- Option.get ctx.loop_backward_outputs
- else (* Regular function body *)
- ctx.backward_outputs
+ let loop_id = Option.get ctx.loop_id in
+ let loop = LoopId.Map.find loop_id ctx.loops in
+ let tys = RegionGroupId.Map.find bid loop.back_outputs in
+ List.map (fun ty -> (None, ty)) tys
+ else
+ (* Regular function body *)
+ let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in
+ List.combine back_sg.output_names back_sg.outputs
in
- T.RegionGroupId.Map.find bid map
+ let ctx, vars = fresh_vars vars ctx in
+ ({ ctx with backward_outputs = Some vars }, vars)
in
(* Get the list of values consumed by the abstraction upon ending *)
@@ -1933,9 +2406,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
(* Those don't have backward functions *)
raise (Failure "Unreachable")
in
- let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos fun_id None (Some rg_id)
- in
+ let effect_info = get_fun_effect_info ctx fun_id None (Some rg_id) in
let generics = ctx_translate_fwd_generic_args ctx call.generics in
(* Retrieve the original call and the parent abstractions *)
let _forward, backwards = get_abs_ancestors ctx abs call_id in
@@ -1959,14 +2430,16 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
let back_state, ctx, nstate =
if effect_info.stateful then
let back_state = mk_state_texpression ctx.state_var in
- let ctx, nstate = bs_ctx_fresh_state_var ctx in
+ let ctx, _, nstate = bs_ctx_fresh_state_var ctx in
([ back_state ], ctx, Some nstate)
else ([], ctx, None)
in
(* Concatenate all the inpus *)
- let inputs =
- List.concat [ fwd_inputs; back_ancestors_inputs; back_inputs; back_state ]
+ let inherited_inputs =
+ if !Config.return_back_funs then []
+ else List.concat [ fwd_inputs; back_ancestors_inputs ]
in
+ let back_inputs = List.append back_inputs back_state in
(* Retrieve the values given back by this function: those are the output
* values. We rely on the fact that there are no nested borrows to use the
* meta-place information from the input values given to the forward function
@@ -1983,78 +2456,61 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
| None -> output
| Some nstate -> mk_simpl_tuple_pattern [ nstate; output ]
in
- (* Sanity check: there is the proper number of inputs and outputs, and they have the proper type *)
- (if (* TODO: normalize the types *) !Config.type_check_pure_code then
- match fun_id with
- | FunId fun_id ->
- let inst_sg =
- get_instantiated_fun_sig fun_id (Some rg_id) generics ctx
- in
- log#ldebug
- (lazy
- ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs ("
- ^ string_of_int (List.length inputs)
- ^ "): "
- ^ String.concat ", " (List.map (texpression_to_string ctx) inputs)
- ^ "\n- inst_sg.inputs ("
- ^ string_of_int (List.length inst_sg.inputs)
- ^ "): "
- ^ String.concat ", "
- (List.map (pure_ty_to_string ctx) inst_sg.inputs)));
- List.iter
- (fun (x, ty) -> assert ((x : texpression).ty = ty))
- (List.combine inputs inst_sg.inputs);
- log#ldebug
- (lazy
- ("\n- outputs: "
- ^ string_of_int (List.length outputs)
- ^ "\n- expected outputs: "
- ^ string_of_int (List.length inst_sg.doutputs)));
- List.iter
- (fun (x, ty) -> assert ((x : typed_pattern).ty = ty))
- (List.combine outputs inst_sg.doutputs)
- | _ -> (* TODO: trait methods *) ());
(* Retrieve the function id, and register the function call in the context
- * if necessary *)
+ if necessary.Arith_status *)
let ctx, func =
- bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx
+ bs_ctx_register_backward_call abs effect_info call_id rg_id inherited_inputs
+ back_inputs generics output.ty ctx
in
(* Translate the next expression *)
let next_e = translate_expression e ctx in
(* Put everything together *)
+ let inputs = List.append inherited_inputs back_inputs in
let args_mplaces = List.map (fun _ -> None) inputs in
let args =
List.map
(fun (arg, mp) -> mk_opt_mplace_texpression mp arg)
(List.combine inputs args_mplaces)
in
- let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
- let ret_ty =
- if effect_info.can_fail then mk_result_ty output.ty else output.ty
- in
- let func_ty = mk_arrows input_tys ret_ty in
- let func = { id = FunOrOp func; generics } in
- let func = { e = Qualif func; ty = func_ty } in
- let call = mk_apps func args in
(* **Optimization**:
- * =================
- * We do a small optimization here: if the backward function doesn't
- * have any output, we don't introduce any function call.
- * See the comment in {!Config.filter_useless_monadic_calls}.
- *
- * TODO: use an option to disallow backward functions from updating the state.
- * TODO: a backward function which only gives back shared borrows shouldn't
- * update the state (state updates should only be used for mutable borrows,
- * with objects like Rc for instance).
- *)
- if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None then (
+ =================
+ We do a small optimization here if we split the forward/backward functions.
+ If the backward function doesn't have any output, we don't introduce any function
+ call.
+ See the comment in {!Config.filter_useless_monadic_calls}.
+
+ TODO: use an option to disallow backward functions from updating the state.
+ TODO: a backward function which only gives back shared borrows shouldn't
+ update the state (state updates should only be used for mutable borrows,
+ with objects like Rc for instance).
+ *)
+ if
+ (not !Config.return_back_funs)
+ && !Config.filter_useless_monadic_calls
+ && outputs = [] && nstate = None
+ then (
(* No outputs - we do a small sanity check: the backward function
- * should have exactly the same number of inputs as the forward:
- * this number can be different only if the forward function returned
- * a value containing mutable borrows, which can't be the case... *)
+ should have exactly the same number of inputs as the forward:
+ this number can be different only if the forward function returned
+ a value containing mutable borrows, which can't be the case... *)
assert (List.length inputs = List.length fwd_inputs);
next_e)
- else mk_let effect_info.can_fail output call next_e
+ else
+ (* The backward function might also have been filtered if we do not
+ split the forward/backward functions *)
+ match func with
+ | None -> next_e
+ | Some func ->
+ log#ldebug
+ (lazy
+ (let args = List.map (texpression_to_string ctx) args in
+ "func: "
+ ^ texpression_to_string ctx func
+ ^ "\nfunc type: "
+ ^ pure_ty_to_string ctx func.ty
+ ^ "\n\nargs:\n" ^ String.concat "\n" args));
+ let call = mk_apps func args in
+ mk_let effect_info.can_fail output call next_e
and translate_end_abstraction_identity (ectx : C.eval_ctx) (abs : V.abs)
(e : S.expression) (ctx : bs_ctx) : texpression =
@@ -2095,15 +2551,15 @@ and translate_end_abstraction_synth_ret (ectx : C.eval_ctx) (abs : V.abs)
let-binding:
{[
let id_back x nx =
- let s = nx in // the name [s] is not important (only collision matters)
- ...
+ let s = nx in // the name [s] is not important (only collision matters)
+ ...
]}
This let-binding later gets inlined, during a micro-pass.
*)
(* First, retrieve the list of variables used for the inputs for the
* backward function *)
- let inputs = T.RegionGroupId.Map.find rg_id ctx.backward_inputs in
+ let inputs = T.RegionGroupId.Map.find rg_id ctx.backward_inputs_no_state in
(* Retrieve the values consumed upon ending the loans inside this
* abstraction: as there are no nested borrows, there should be none. *)
let consumed = abs_to_consumed ctx ectx abs in
@@ -2150,11 +2606,12 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
| V.LoopSynthInput ->
(* Actually the same case as [SynthInput] *)
translate_end_abstraction_synth_input ectx abs e ctx rg_id
- | V.LoopCall ->
+ | V.LoopCall -> (
+ (* We need to introduce a call to the backward function corresponding
+ to a forward call which happened earlier *)
let fun_id = E.FRegular ctx.fun_decl.def_id in
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos (FunId fun_id)
- (Some vloop_id) (Some rg_id)
+ get_fun_effect_info ctx (FunId fun_id) (Some vloop_id) (Some rg_id)
in
let loop_info = LoopId.Map.find loop_id ctx.loops in
let generics = loop_info.generics in
@@ -2165,7 +2622,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
values consumed upon ending the abstraction (i.e., we don't use
[abs_to_consumed]) *)
let back_inputs_vars =
- T.RegionGroupId.Map.find rg_id ctx.backward_inputs
+ T.RegionGroupId.Map.find rg_id ctx.backward_inputs_no_state
in
let back_inputs = List.map mk_texpression_from_var back_inputs_vars in
(* If the function is stateful:
@@ -2175,12 +2632,15 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
let back_state, ctx, nstate =
if effect_info.stateful then
let back_state = mk_state_texpression ctx.state_var in
- let ctx, nstate = bs_ctx_fresh_state_var ctx in
+ let ctx, _, nstate = bs_ctx_fresh_state_var ctx in
([ back_state ], ctx, Some nstate)
else ([], ctx, None)
in
(* Concatenate all the inputs *)
- let inputs = List.concat [ fwd_inputs; back_inputs; back_state ] in
+ let inputs =
+ if !Config.return_back_funs then List.concat [ back_inputs; back_state ]
+ else List.concat [ fwd_inputs; back_inputs; back_state ]
+ in
(* Retrieve the values given back by this function *)
let ctx, outputs = abs_to_given_back None abs ctx in
(* Group the output values together: first the updated inputs *)
@@ -2204,68 +2664,88 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
let ret_ty =
if effect_info.can_fail then mk_result_ty output.ty else output.ty
in
- let func_ty = mk_arrows input_tys ret_ty in
- let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in
- let func = { id = FunOrOp func; generics } in
- let func = { e = Qualif func; ty = func_ty } in
- let call = mk_apps func args in
+ (* Create the expression for the function:
+ - it is either a call to a top-level function, if we split the
+ forward/backward functions
+ - or a call to the variable we introduced for the backward function,
+ if we merge the forward/backward functions *)
+ let func =
+ if !Config.return_back_funs then
+ RegionGroupId.Map.find rg_id (Option.get loop_info.back_funs)
+ else
+ let func_ty = mk_arrows input_tys ret_ty in
+ let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in
+ let func = { id = FunOrOp func; generics } in
+ Some { e = Qualif func; ty = func_ty }
+ in
(* **Optimization**:
- * =================
- * We do a small optimization here: if the backward function doesn't
- * have any output, we don't introduce any function call.
- * See the comment in {!Config.filter_useless_monadic_calls}.
- *
- * TODO: use an option to disallow backward functions from updating the state.
- * TODO: a backward function which only gives back shared borrows shouldn't
- * update the state (state updates should only be used for mutable borrows,
- * with objects like Rc for instance).
- *)
- if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None
+ =================
+ We do a small optimization here in case we split the forward/backward
+ functions.
+ If the backward function doesn't have any output, we don't introduce
+ any function call.
+ See the comment in {!Config.filter_useless_monadic_calls}.
+
+ TODO: use an option to disallow backward functions from updating the state.
+ TODO: a backward function which only gives back shared borrows shouldn't
+ update the state (state updates should only be used for mutable borrows,
+ with objects like Rc for instance).
+ *)
+ if
+ (not !Config.return_back_funs)
+ && !Config.filter_useless_monadic_calls
+ && outputs = [] && nstate = None
then (
(* No outputs - we do a small sanity check: the backward function
- * should have exactly the same number of inputs as the forward:
- * this number can be different only if the forward function returned
- * a value containing mutable borrows, which can't be the case... *)
+ should have exactly the same number of inputs as the forward:
+ this number can be different only if the forward function returned
+ a value containing mutable borrows, which can't be the case... *)
assert (List.length inputs = List.length fwd_inputs);
next_e)
else
- (* Add meta-information - this is slightly hacky: we look at the
- values consumed by the abstraction (note that those come from
- *before* we applied the fixed-point context) and use them to
- guide the naming of the output vars.
-
- Also, we need to convert the backward outputs from patterns to
- variables.
-
- Finally, in practice, this works well only for loop bodies:
- we do this only in this case.
- TODO: improve the heuristics, to give weight to the hints for
- instance.
- *)
- let next_e =
- if ctx.inside_loop then
- let consumed_values = abs_to_consumed ctx ectx abs in
- let var_values = List.combine outputs consumed_values in
- let var_values =
- List.filter_map
- (fun (var, v) ->
- match var.Pure.value with
- | PatVar (var, _) -> Some (var, v)
- | _ -> None)
- var_values
+ (* In case we merge the fwd/back functions we filter the backward
+ functions elsewhere *)
+ match func with
+ | None -> next_e
+ | Some func ->
+ let call = mk_apps func args in
+ (* Add meta-information - this is slightly hacky: we look at the
+ values consumed by the abstraction (note that those come from
+ *before* we applied the fixed-point context) and use them to
+ guide the naming of the output vars.
+
+ Also, we need to convert the backward outputs from patterns to
+ variables.
+
+ Finally, in practice, this works well only for loop bodies:
+ we do this only in this case.
+ TODO: improve the heuristics, to give weight to the hints for
+ instance.
+ *)
+ let next_e =
+ if ctx.inside_loop then
+ let consumed_values = abs_to_consumed ctx ectx abs in
+ let var_values = List.combine outputs consumed_values in
+ let var_values =
+ List.filter_map
+ (fun (var, v) ->
+ match var.Pure.value with
+ | PatVar (var, _) -> Some (var, v)
+ | _ -> None)
+ var_values
+ in
+ let vars, values = List.split var_values in
+ mk_emeta_symbolic_assignments vars values next_e
+ else next_e
in
- let vars, values = List.split var_values in
- mk_emeta_symbolic_assignments vars values next_e
- else next_e
- in
- (* Create the let-binding *)
- mk_let effect_info.can_fail output call next_e
+ (* Create the let-binding *)
+ mk_let effect_info.can_fail output call next_e)
and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value)
(e : S.expression) (ctx : bs_ctx) : texpression =
let ctx, var = fresh_var_for_symbolic_value sval ctx in
- let decl = A.GlobalDeclId.Map.find gid ctx.global_context.llbc_global_decls in
+ let decl = A.GlobalDeclId.Map.find gid ctx.global_ctx.llbc_global_decls in
let global_expr = { id = Global gid; generics = empty_generic_args } in
(* We use translate_fwd_ty to translate the global type *)
let ty = ctx_translate_fwd_ty ctx decl.ty in
@@ -2290,8 +2770,7 @@ and translate_assertion (ectx : C.eval_ctx) (v : V.typed_value)
and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
(exp : S.expansion) (ctx : bs_ctx) : texpression =
(* Translate the scrutinee *)
- let scrutinee_var = lookup_var_for_symbolic_value sv ctx in
- let scrutinee = mk_texpression_from_var scrutinee_var in
+ let scrutinee = symbolic_value_to_texpression ctx sv in
let scrutinee_mplace = translate_opt_mplace p in
(* Translate the branches *)
match exp with
@@ -2370,7 +2849,13 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
If (true_e, false_e) )
in
let ty = true_e.ty in
- assert (ty = false_e.ty);
+ log#ldebug
+ (lazy
+ ("true_e.ty: "
+ ^ pure_ty_to_string ctx true_e.ty
+ ^ "\n\nfalse_e.ty: "
+ ^ pure_ty_to_string ctx false_e.ty));
+ if !Config.fail_hard then assert (ty = false_e.ty);
{ e; ty }
| ExpandInt (int_ty, branches, otherwise) ->
let translate_branch ((v, branch_e) : V.scalar_value * S.expression) :
@@ -2441,7 +2926,7 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value)
- if we forbid using field projectors.
*)
let is_rec_def =
- T.TypeDeclId.Set.mem adt_id ctx.type_context.recursive_decls
+ T.TypeDeclId.Set.mem adt_id ctx.type_ctx.recursive_decls
in
let use_let_with_cons =
is_enum
@@ -2454,7 +2939,7 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value)
like Coq don't, in which case we have to deconstruct the whole ADT
at once (`let (a, b, c) = x in`) *)
|| TypesUtils.type_decl_from_type_id_is_tuple_struct
- ctx.type_context.type_infos type_id
+ ctx.type_ctx.type_infos type_id
&& not (Config.backend_has_tuple_projectors ())
in
if use_let_with_cons then
@@ -2547,7 +3032,7 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option)
{ e = StructUpdate su; ty = var.ty }
| VaCgValue cg_id -> { e = CVar cg_id; ty = var.ty }
| VaTraitConstValue (trait_ref, generics, const_name) ->
- let type_infos = ctx.type_context.type_infos in
+ let type_infos = ctx.type_ctx.type_infos in
let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
let generics = translate_fwd_generic_args type_infos generics in
let qualif_id = TraitConst (trait_ref, generics, const_name) in
@@ -2562,22 +3047,169 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option)
and translate_forward_end (ectx : C.eval_ctx)
(loop_input_values : V.typed_value S.symbolic_value_id_map option)
- (e : S.expression) (back_e : S.expression S.region_group_id_map)
+ (fwd_e : S.expression) (back_e : S.expression S.region_group_id_map)
(ctx : bs_ctx) : texpression =
- (* Update the current state with the additional state received by the backward
- function, if needs be, and lookup the proper expression *)
- let translate_end ctx =
+ let translate_one_end ctx (bid : RegionGroupId.id option) =
+ let ctx = { ctx with bid } in
(* Update the current state with the additional state received by the backward
function, if needs be, and lookup the proper expression *)
- let ctx, e =
- match ctx.bid with
- | None -> (ctx, e)
+ let ctx, e, finish =
+ match bid with
+ | None ->
+ (* We are translating the forward function - nothing to do *)
+ (ctx, fwd_e, fun e -> e)
| Some bid ->
- let ctx = { ctx with state_var = ctx.back_state_var } in
+ (* There are two cases here:
+ - if we split the fwd/backward functions, we simply need to update
+ the state.
+ - if we don't split, we also need to wrap the expression in a
+ lambda, which introduces the additional inputs of the backward
+ function
+ *)
+ let ctx =
+ (* Introduce variables for the inputs and the state variable
+ and update the context. *)
+ if !Config.return_back_funs then
+ (* If the forward/backward functions are not split, we need
+ to introduce fresh variables for the additional inputs,
+ because they are locally introduced in a lambda *)
+ let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in
+ let ctx, backward_inputs_no_state =
+ fresh_vars back_sg.inputs_no_state ctx
+ in
+ let ctx, backward_inputs_with_state =
+ if back_sg.effect_info.stateful then
+ let ctx, var, _ = bs_ctx_fresh_state_var ctx in
+ (ctx, backward_inputs_no_state @ [ var ])
+ else (ctx, backward_inputs_no_state)
+ in
+ {
+ ctx with
+ backward_inputs_no_state =
+ RegionGroupId.Map.add bid backward_inputs_no_state
+ ctx.backward_inputs_no_state;
+ backward_inputs_with_state =
+ RegionGroupId.Map.add bid backward_inputs_with_state
+ ctx.backward_inputs_with_state;
+ }
+ else
+ (* Update the state variable *)
+ let back_state_var =
+ RegionGroupId.Map.find bid ctx.back_state_vars
+ in
+ { ctx with state_var = back_state_var }
+ in
+
let e = T.RegionGroupId.Map.find bid back_e in
- (ctx, e)
+ let finish e =
+ (* Wrap in lambdas if necessary *)
+ if !Config.return_back_funs then
+ let inputs =
+ RegionGroupId.Map.find bid ctx.backward_inputs_with_state
+ in
+ let places = List.map (fun _ -> None) inputs in
+ mk_lambdas_from_vars inputs places e
+ else e
+ in
+ (ctx, e, finish)
in
- translate_expression e ctx
+ let e = translate_expression e ctx in
+ finish e
+ in
+
+ (* There are two cases, depending on whether we are splitting the forward/backward
+ functions or not.
+
+ - if we split, then we simply need to translate the proper "end" expression,
+ that is the end of the forward function, or of the backward function we
+ are currently translating.
+ - if we don't split, then we need to translate the end of the forward
+ function (this is the value we will return) and generate the bodies
+ of the backward functions (which we will also return).
+
+ Update the current state with the additional state received by the backward
+ function, if needs be, and lookup the proper expression.
+ *)
+ let translate_end ctx =
+ if !Config.return_back_funs then
+ (* Compute the output of the forward function *)
+ let fwd_effect_info = ctx.sg.fwd_info.effect_info in
+ let ctx, pure_fwd_var = fresh_var None ctx.sg.fwd_output ctx in
+ let fwd_e = translate_one_end ctx None in
+
+ (* Introduce the backward functions. *)
+ let back_el =
+ List.map
+ (fun ((gid, _) : RegionGroupId.id * back_sg_info) ->
+ translate_one_end ctx (Some gid))
+ (RegionGroupId.Map.bindings ctx.sg.back_sg)
+ in
+
+ (* Compute whether the backward expressions should be evaluated straight
+ away or not (i.e., if we should bind them with monadic let-bindings
+ or not). We evaluate them straight away if they can fail and have no
+ inputs. *)
+ let evaluate_backs =
+ List.map
+ (fun (sg : back_sg_info) ->
+ if !Config.simplify_merged_fwd_backs then
+ sg.inputs = [] && sg.effect_info.can_fail
+ else false)
+ (RegionGroupId.Map.values ctx.sg.back_sg)
+ in
+
+ (* Introduce variables for the backward functions.
+ We lookup the LLBC definition in an attempt to derive pretty names
+ for those functions. *)
+ let _, back_vars = fresh_back_vars_for_current_fun ctx in
+
+ (* Create the return expressions *)
+ let vars =
+ let back_vars = List.filter_map (fun x -> x) back_vars in
+ if ctx.sg.fwd_info.ignore_output then back_vars
+ else pure_fwd_var :: back_vars
+ in
+ let vars = List.map mk_texpression_from_var vars in
+ let ret = mk_simpl_tuple_texpression vars in
+
+ (* Introduce a fresh input state variable for the forward expression *)
+ let _ctx, state_var, state_pat =
+ if fwd_effect_info.stateful then
+ let ctx, var, pat = bs_ctx_fresh_state_var ctx in
+ (ctx, [ var ], [ pat ])
+ else (ctx, [], [])
+ in
+
+ let state_var = List.map mk_texpression_from_var state_var in
+ let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in
+ let ret = mk_result_return_texpression ret in
+
+ (* Introduce all the let-bindings *)
+
+ (* Combine:
+ - the backward variables
+ - whether we should evaluate the expression for the backward function
+ (i.e., should we use a monadic let-binding or not - we do if the
+ backward functions don't have inputs and can fail)
+ - the expressions for the backward functions
+ *)
+ let back_vars_els =
+ List.filter_map
+ (fun (v, (eval, el)) ->
+ match v with None -> None | Some v -> Some (v, eval, el))
+ (List.combine back_vars (List.combine evaluate_backs back_el))
+ in
+ let e =
+ List.fold_right
+ (fun (var, evaluate, back_e) e ->
+ mk_let evaluate (mk_typed_pattern_from_var var None) back_e e)
+ back_vars_els ret
+ in
+ (* Bind the expression for the forward output *)
+ let fwd_var = mk_typed_pattern_from_var pure_fwd_var None in
+ let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in
+ mk_let fwd_effect_info.can_fail pat fwd_e e
+ else translate_one_end ctx ctx.bid
in
(* If we are (re-)entering a loop, we need to introduce a call to the
@@ -2624,17 +3256,53 @@ and translate_forward_end (ectx : C.eval_ctx)
(* Lookup the effect info for the loop function *)
let fid = E.FRegular ctx.fun_decl.def_id in
- let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos (FunId fid) None ctx.bid
- in
+ let effect_info = get_fun_effect_info ctx (FunId fid) None ctx.bid in
(* Introduce a fresh output value for the forward function *)
- let ctx, output_var =
- let output_ty = mk_simpl_tuple_ty ctx.fwd_sg.doutputs in
- fresh_var None output_ty ctx
+ let ctx, fwd_output, output_pat =
+ if ctx.sg.fwd_info.ignore_output then
+ (* Note that we still need the forward output (which is unit),
+ because even though the loop function will ignore the forward output,
+ the forward expression will still compute an output (which
+ will have type unit - otherwise we can't ignore it). *)
+ (ctx, mk_unit_rvalue, [])
+ else
+ let ctx, output_var = fresh_var None ctx.sg.fwd_output ctx in
+ ( ctx,
+ mk_texpression_from_var output_var,
+ [ mk_typed_pattern_from_var output_var None ] )
+ in
+
+ (* Introduce fresh variables for the backward functions of the loop.
+
+ For now, the backward functions of the loop are the same as the
+ backward functions of the outer function.
+ *)
+ let ctx, back_funs_map, back_funs =
+ if !Config.return_back_funs then
+ let ctx, back_vars = fresh_back_vars_for_current_fun ctx in
+ let back_funs =
+ List.filter_map
+ (fun v ->
+ match v with
+ | None -> None
+ | Some v -> Some (mk_typed_pattern_from_var v None))
+ back_vars
+ in
+ let gids = RegionGroupId.Map.keys ctx.sg.back_sg in
+ let back_funs_map =
+ RegionGroupId.Map.of_list
+ (List.combine gids
+ (List.map (Option.map mk_texpression_from_var) back_vars))
+ in
+ (ctx, Some back_funs_map, back_funs)
+ else (ctx, None, [])
in
+
+ (* Introduce patterns *)
let args, ctx, out_pats =
- let output_pat = mk_typed_pattern_from_var output_var None in
+ (* Add the returned backward functions (they might be empty) *)
+ let output_pat = mk_simpl_tuple_pattern (output_pat @ back_funs) in
(* Depending on the function effects:
* - add the fuel
@@ -2644,7 +3312,7 @@ and translate_forward_end (ectx : C.eval_ctx)
let fuel = mk_fuel_input_as_list ctx effect_info in
if effect_info.stateful then
let state_var = mk_state_texpression ctx.state_var in
- let ctx, nstate_pat = bs_ctx_fresh_state_var ctx in
+ let ctx, _, nstate_pat = bs_ctx_fresh_state_var ctx in
( List.concat [ fuel; args; [ state_var ] ],
ctx,
[ nstate_pat; output_pat ] )
@@ -2656,7 +3324,8 @@ and translate_forward_end (ectx : C.eval_ctx)
{
loop_info with
forward_inputs = Some args;
- forward_output_no_state_no_result = Some output_var;
+ forward_output_no_state_no_result = Some fwd_output;
+ back_funs = back_funs_map;
}
in
let ctx =
@@ -2753,35 +3422,107 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
(* Compute the backward outputs *)
let ctx = ref ctx in
- let loop_backward_outputs =
- T.RegionGroupId.Map.map
+ let rg_to_given_back_tys =
+ RegionGroupId.Map.map
(fun (_, tys) ->
(* The types shouldn't contain borrows - we can translate them as forward types *)
- let vars =
- List.map
- (fun ty ->
- assert (
- not (TypesUtils.ty_has_borrows !ctx.type_context.type_infos ty));
- (None, ctx_translate_fwd_ty !ctx ty))
- tys
- in
- (* Introduce fresh variables *)
- let ctx', vars = fresh_vars vars !ctx in
- ctx := ctx';
- vars)
+ List.map
+ (fun ty ->
+ assert (not (TypesUtils.ty_has_borrows !ctx.type_ctx.type_infos ty));
+ ctx_translate_fwd_ty !ctx ty)
+ tys)
loop.rg_to_given_back_tys
in
let ctx = !ctx in
- let back_output_tys =
- match ctx.bid with
- | None -> None
- | Some rg_id ->
- let back_outputs =
- T.RegionGroupId.Map.find rg_id loop_backward_outputs
- in
- let back_output_tys = List.map (fun (v : var) -> v.ty) back_outputs in
- Some back_output_tys
+ (* The output type of the loop function *)
+ let fwd_effect_info = { ctx.sg.fwd_info.effect_info with is_rec = true } in
+ let back_effect_infos, output_ty =
+ if !Config.return_back_funs then
+ (* The loop backward functions consume the same additional inputs as the parent
+ function, but have custom outputs *)
+ let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in
+ let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in
+ let back_info_tys =
+ List.map
+ (fun (((id, back_sg), given_back) : (_ * back_sg_info) * ty list) ->
+ (* Remark: the effect info of the backward function for the loop
+ is almost the same as for the backward function of the parent function.
+ Quite importantly, the fact that the function is stateful and/or can fail
+ mostly depends on whether it has inputs or not, and the backward functions
+ for the loops have the same inputs as the backward functions for the parent
+ function.
+ *)
+ let effect_info = back_sg.effect_info in
+ let effect_info = { effect_info with is_rec = true } in
+ (* Compute the input/output types *)
+ let inputs = List.map snd back_sg.inputs in
+ let outputs = given_back in
+ (* Filter if necessary *)
+ let ty =
+ if
+ !Config.simplify_merged_fwd_backs && inputs = [] && outputs = []
+ then None
+ else
+ let output = mk_simpl_tuple_ty outputs in
+ let output =
+ mk_back_output_ty_from_effect_info effect_info inputs output
+ in
+ let ty = mk_arrows inputs output in
+ Some ty
+ in
+ ((id, effect_info), ty))
+ (List.combine back_sgs given_back_tys)
+ in
+ let back_info = List.map fst back_info_tys in
+ let back_info = RegionGroupId.Map.of_list back_info in
+ let back_tys = List.filter_map snd back_info_tys in
+ let output =
+ if ctx.sg.fwd_info.ignore_output then back_tys
+ else ctx.sg.fwd_output :: back_tys
+ in
+ let output = mk_simpl_tuple_ty output in
+ let effect_info = ctx.sg.fwd_info.effect_info in
+ let output =
+ if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ]
+ else output
+ in
+ let output =
+ if effect_info.can_fail && inputs <> [] then mk_result_ty output
+ else output
+ in
+ (back_info, output)
+ else
+ let back_info =
+ RegionGroupId.Map.of_list
+ (List.map
+ (fun ((id, back_sg) : _ * back_sg_info) ->
+ (id, { back_sg.effect_info with is_rec = true }))
+ (RegionGroupId.Map.bindings ctx.sg.back_sg))
+ in
+ let output =
+ match ctx.bid with
+ | None ->
+ (* Forward function: same type as the parent function *)
+ (translate_fun_sig_from_decomposed ctx.sg None).output
+ | Some rg_id ->
+ (* Backward function: custom return type *)
+ let doutputs =
+ T.RegionGroupId.Map.find rg_id rg_to_given_back_tys
+ in
+ let output = mk_simpl_tuple_ty doutputs in
+ let fwd_effect_info = ctx.sg.fwd_info.effect_info in
+ let output =
+ if fwd_effect_info.stateful then
+ mk_simpl_tuple_ty [ mk_state_ty; output ]
+ else output
+ in
+ let output =
+ if fwd_effect_info.can_fail then mk_result_ty output else output
+ in
+ output
+ in
+ (back_info, output)
in
(* Add the loop information in the context *)
@@ -2823,6 +3564,10 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
generics;
forward_inputs = None;
forward_output_no_state_no_result = None;
+ back_outputs = rg_to_given_back_tys;
+ back_funs = None;
+ fwd_effect_info;
+ back_effect_infos;
}
in
let loops = LoopId.Map.add loop_id loop_info ctx.loops in
@@ -2830,13 +3575,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
in
(* Update the context to translate the function end *)
- let ctx_end =
- {
- ctx with
- loop_id = Some loop_id;
- loop_backward_outputs = Some loop_backward_outputs;
- }
- in
+ let ctx_end = { ctx with loop_id = Some loop_id } in
let fun_end = translate_expression loop.end_expr ctx_end in
(* Update the context for the loop body *)
@@ -2844,7 +3583,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
(* Add the input state *)
let input_state =
- if ctx.sg.info.effect_info.stateful then Some ctx.state_var else None
+ if (ctx_get_effect_info ctx).stateful then Some ctx.state_var else None
in
(* Translate the loop body *)
@@ -2862,7 +3601,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
input_state;
inputs;
inputs_lvs;
- back_output_tys;
+ output_ty;
loop_body;
}
in
@@ -2982,10 +3721,10 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
let def_id = def.def_id in
let llbc_name = def.name in
let name = name_to_string ctx llbc_name in
- (* Retrieve the signature *)
- let signature = ctx.sg in
+ (* Translate the signature *)
+ let signature = translate_fun_sig_from_decomposed ctx.sg ctx.bid in
let regions_hierarchy =
- FunIdMap.find (FRegular def_id) ctx.fun_context.regions_hierarchies
+ FunIdMap.find (FRegular def_id) ctx.fun_ctx.regions_hierarchies
in
(* Translate the body, if there is *)
let body =
@@ -2993,8 +3732,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
| None -> None
| Some body ->
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos
- (FunId (FRegular def_id)) None bid
+ get_fun_effect_info ctx (FunId (FRegular def_id)) None bid
in
let body = translate_expression body ctx in
(* Add a match over the fuel, if necessary *)
@@ -3027,20 +3765,25 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
match bid with
| None -> []
| Some back_id ->
+ assert (not !Config.return_back_funs);
let parents_ids =
list_ordered_ancestor_region_groups regions_hierarchy back_id
in
let backward_ids = List.append parents_ids [ back_id ] in
List.concat
(List.map
- (fun id -> T.RegionGroupId.Map.find id ctx.backward_inputs)
+ (fun id ->
+ T.RegionGroupId.Map.find id ctx.backward_inputs_no_state)
backward_ids)
in
(* Introduce the backward input state (the state at call site of the
* *backward* function), if necessary *)
let back_state =
if effect_info.stateful && Option.is_some bid then
- [ mk_state_var ctx.back_state_var ]
+ let state_var =
+ RegionGroupId.Map.find (Option.get bid) ctx.back_state_vars
+ in
+ [ mk_state_var state_var ]
else []
in
(* Group the inputs together *)
@@ -3114,63 +3857,6 @@ let translate_type_decls (ctx : Contexts.decls_ctx) : type_decl list =
List.map (translate_type_decl ctx)
(TypeDeclId.Map.values ctx.type_ctx.type_decls)
-(** Translates function signatures.
-
- Takes as input a list of function information containing:
- - the function id
- - a list of optional names for the inputs
- - the function signature
-
- Returns a map from forward/backward functions identifiers to:
- - translated function signatures
- - optional names for the outputs values (we derive them for the backward
- functions)
- *)
-let translate_fun_signatures (decls_ctx : C.decls_ctx)
- (functions : (A.fun_id * string option list * A.fun_sig) list) :
- fun_sig_named_outputs RegularFunIdNotLoopMap.t =
- (* For every function, translate the signatures of:
- - the forward function
- - the backward functions
- *)
- let translate_one (fun_id : A.fun_id) (input_names : string option list)
- (sg : A.fun_sig) : (regular_fun_id_not_loop * fun_sig_named_outputs) list
- =
- log#ldebug
- (lazy
- ("Translating signature of function: "
- ^ Print.Expressions.fun_id_to_string
- (Print.Contexts.decls_ctx_to_fmt_env decls_ctx)
- fun_id));
- (* Retrieve the regions hierarchy *)
- let regions_hierarchy =
- FunIdMap.find fun_id decls_ctx.fun_ctx.regions_hierarchies
- in
- (* The forward function *)
- let fwd_sg = translate_fun_sig decls_ctx fun_id sg input_names None in
- let fwd_id = (fun_id, None) in
- (* The backward functions *)
- let back_sgs =
- List.map
- (fun (rg : T.region_var_group) ->
- let tsg =
- translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id)
- in
- let id = (fun_id, Some rg.id) in
- (id, tsg))
- regions_hierarchy
- in
- (* Return *)
- (fwd_id, fwd_sg) :: back_sgs
- in
- let translated =
- List.concat
- (List.map (fun (id, names, sg) -> translate_one id names sg) functions)
- in
- List.fold_left
- (fun m (id, sg) -> RegularFunIdNotLoopMap.add id sg m)
- RegularFunIdNotLoopMap.empty translated
-
let translate_trait_decl (ctx : Contexts.decls_ctx) (trait_decl : A.trait_decl)
: trait_decl =
let {
diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml
index efcf001a..865185a8 100644
--- a/compiler/SynthesizeSymbolic.ml
+++ b/compiler/SynthesizeSymbolic.ml
@@ -2,6 +2,7 @@ open Types
open TypesUtils
open Expressions
open Values
+open LlbcAst
open SymbolicAst
let mk_mplace (p : place) (ctx : Contexts.eval_ctx) : mplace =
@@ -92,7 +93,9 @@ let synthesize_symbolic_expansion_no_branching (sv : symbolic_value)
synthesize_symbolic_expansion sv place [ Some see ] el
let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx)
+ (sg : fun_sig option) (regions_hierarchy : region_var_groups)
(abstractions : AbstractionId.id list) (generics : generic_args)
+ (trait_method_generics : (generic_args * trait_instance_id) option)
(args : typed_value list) (args_places : mplace option list)
(dest : symbolic_value) (dest_place : mplace option) (e : expression option)
: expression option =
@@ -102,8 +105,11 @@ let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx)
{
call_id;
ctx;
+ sg;
+ regions_hierarchy;
abstractions;
generics;
+ trait_method_generics;
args;
dest;
args_places;
@@ -118,29 +124,32 @@ let synthesize_global_eval (gid : GlobalDeclId.id) (dest : symbolic_value)
Option.map (fun e -> EvalGlobal (gid, dest, e)) e
let synthesize_regular_function_call (fun_id : fun_id_or_trait_method_ref)
- (call_id : FunCallId.id) (ctx : Contexts.eval_ctx)
+ (call_id : FunCallId.id) (ctx : Contexts.eval_ctx) (sg : fun_sig)
+ (regions_hierarchy : region_var_groups)
(abstractions : AbstractionId.id list) (generics : generic_args)
+ (trait_method_generics : (generic_args * trait_instance_id) option)
(args : typed_value list) (args_places : mplace option list)
(dest : symbolic_value) (dest_place : mplace option) (e : expression option)
: expression option =
synthesize_function_call
(Fun (fun_id, call_id))
- ctx abstractions generics args args_places dest dest_place e
+ ctx (Some sg) regions_hierarchy abstractions generics trait_method_generics
+ args args_places dest dest_place e
let synthesize_unary_op (ctx : Contexts.eval_ctx) (unop : unop)
(arg : typed_value) (arg_place : mplace option) (dest : symbolic_value)
(dest_place : mplace option) (e : expression option) : expression option =
let generics = empty_generic_args in
- synthesize_function_call (Unop unop) ctx [] generics [ arg ] [ arg_place ]
- dest dest_place e
+ synthesize_function_call (Unop unop) ctx None [] [] generics None [ arg ]
+ [ arg_place ] dest dest_place e
let synthesize_binary_op (ctx : Contexts.eval_ctx) (binop : binop)
(arg0 : typed_value) (arg0_place : mplace option) (arg1 : typed_value)
(arg1_place : mplace option) (dest : symbolic_value)
(dest_place : mplace option) (e : expression option) : expression option =
let generics = empty_generic_args in
- synthesize_function_call (Binop binop) ctx [] generics [ arg0; arg1 ]
- [ arg0_place; arg1_place ] dest dest_place e
+ synthesize_function_call (Binop binop) ctx None [] [] generics None
+ [ arg0; arg1 ] [ arg0_place; arg1_place ] dest dest_place e
let synthesize_end_abstraction (ctx : Contexts.eval_ctx) (abs : abs)
(e : expression option) : expression option =
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 221d4e73..55a94302 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -6,7 +6,6 @@ open LlbcAst
open Contexts
module SA = SymbolicAst
module Micro = PureMicroPasses
-open PureUtils
open TranslateCore
(** The local logger *)
@@ -43,8 +42,8 @@ let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : fun_decl) :
TODO: maybe we should introduce a record for this.
*)
let translate_function_to_pure (trans_ctx : trans_ctx)
- (fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdNotLoopMap.t)
- (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : fun_decl) :
+ (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t)
+ (fun_dsigs : Pure.decomposed_fun_sig FunDeclId.Map.t) (fdef : fun_decl) :
pure_fun_translation_no_loops =
(* Debug *)
log#ldebug
@@ -58,13 +57,9 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
(* Convert the symbolic ASTs to pure ASTs: *)
(* Initialize the context *)
- let forward_sig =
- RegularFunIdNotLoopMap.find (FRegular def_id, None) fun_sigs
- in
let sv_to_var = SymbolicValueId.Map.empty in
let var_counter = Pure.VarId.generator_zero in
let state_var, var_counter = Pure.VarId.fresh var_counter in
- let back_state_var, var_counter = Pure.VarId.fresh var_counter in
let fuel0, var_counter = Pure.VarId.fresh var_counter in
let fuel, var_counter = Pure.VarId.fresh var_counter in
let calls = FunCallId.Map.empty in
@@ -78,7 +73,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
| RecGroup _ -> Some tid)
(TypeDeclId.Map.bindings trans_ctx.type_ctx.type_decls_groups))
in
- let type_context =
+ let type_ctx =
{
SymbolicToPure.type_infos = trans_ctx.type_ctx.type_infos;
llbc_type_decls = trans_ctx.type_ctx.type_decls;
@@ -86,15 +81,14 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
recursive_decls = recursive_type_decls;
}
in
- let fun_context =
+ let fun_ctx =
{
SymbolicToPure.llbc_fun_decls = trans_ctx.fun_ctx.fun_decls;
- fun_sigs;
fun_infos = trans_ctx.fun_ctx.fun_infos;
regions_hierarchies = trans_ctx.fun_ctx.regions_hierarchies;
}
in
- let global_context =
+ let global_ctx =
{ SymbolicToPure.llbc_global_decls = trans_ctx.global_ctx.global_decls }
in
@@ -126,31 +120,51 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
!m
in
+ let sg =
+ SymbolicToPure.translate_fun_sig_from_decl_to_decomposed trans_ctx fdef
+ in
+
+ let regions_hierarchy =
+ LlbcAstUtils.FunIdMap.find (FRegular def_id) fun_ctx.regions_hierarchies
+ in
+
+ let var_counter, back_state_vars =
+ if !Config.return_back_funs then (var_counter, [])
+ else
+ List.fold_left_map
+ (fun var_counter (region_vars : region_var_group) ->
+ let gid = region_vars.id in
+ let var, var_counter = Pure.VarId.fresh var_counter in
+ (var_counter, (gid, var)))
+ var_counter regions_hierarchy
+ in
+ let back_state_vars = RegionGroupId.Map.of_list back_state_vars in
+
let ctx =
{
+ decls_ctx = trans_ctx;
SymbolicToPure.bid = None;
- (* Dummy for now *)
- sg = forward_sig.sg;
- fwd_sg = forward_sig.sg;
+ sg;
+ fun_dsigs;
(* Will need to be updated for the backward functions *)
sv_to_var;
- var_counter;
+ var_counter = ref var_counter;
state_var;
- back_state_var;
+ back_state_vars;
fuel0;
fuel;
- type_context;
- fun_context;
- global_context;
+ type_ctx;
+ fun_ctx;
+ global_ctx;
trait_decls_ctx = trans_ctx.trait_decls_ctx.trait_decls;
trait_impls_ctx = trans_ctx.trait_impls_ctx.trait_impls;
fun_decl = fdef;
forward_inputs = [];
- (* Empty for now *)
- backward_inputs = RegionGroupId.Map.empty;
- (* Empty for now *)
- backward_outputs = RegionGroupId.Map.empty;
- loop_backward_outputs = None;
+ (* Initialized just below *)
+ backward_inputs_no_state = RegionGroupId.Map.empty;
+ (* Initialized just below *)
+ backward_inputs_with_state = RegionGroupId.Map.empty;
+ backward_outputs = None;
(* Empty for now *)
calls;
abstractions;
@@ -180,6 +194,37 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
| _ -> raise (Failure "Unreachable")
in
+ (* Add the backward inputs *)
+ let ctx, backward_inputs_no_state, backward_inputs_with_state =
+ if !Config.return_back_funs then (ctx, [], [])
+ else
+ let ctx, inputs_no_with_state =
+ List.fold_left_map
+ (fun ctx (region_vars : region_var_group) ->
+ let gid = region_vars.id in
+ let back_sg = RegionGroupId.Map.find gid sg.back_sg in
+ let ctx, no_state =
+ SymbolicToPure.fresh_vars back_sg.inputs_no_state ctx
+ in
+ let ctx, with_state =
+ SymbolicToPure.fresh_vars back_sg.inputs ctx
+ in
+ (ctx, ((gid, no_state), (gid, with_state))))
+ ctx regions_hierarchy
+ in
+ let inputs_no_state, inputs_with_state =
+ List.split inputs_no_with_state
+ in
+ (ctx, inputs_no_state, inputs_with_state)
+ in
+ let backward_inputs_no_state =
+ RegionGroupId.Map.of_list backward_inputs_no_state
+ in
+ let backward_inputs_with_state =
+ RegionGroupId.Map.of_list backward_inputs_with_state
+ in
+ let ctx = { ctx with backward_inputs_no_state; backward_inputs_with_state } in
+
(* Translate the forward function *)
let pure_forward =
match symbolic_trans with
@@ -187,7 +232,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
| Some (_, ast) -> SymbolicToPure.translate_fun_decl ctx (Some ast)
in
- (* Translate the backward functions *)
+ (* Translate the backward functions, if we split the forward and backward functions *)
let translate_backward (rg : region_var_group) : Pure.fun_decl =
(* For the backward inputs/outputs initialization: we use the fact that
* there are no nested borrows for now, and so that the region groups
@@ -197,77 +242,20 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
match symbolic_trans with
| None ->
- (* Initialize the context - note that the ret_ty is not really
- * useful as we don't translate a body *)
- let backward_sg =
- RegularFunIdNotLoopMap.find (FRegular def_id, Some back_id) fun_sigs
- in
- let ctx = { ctx with bid = Some back_id; sg = backward_sg.sg } in
-
+ (* Initialize the context *)
+ let ctx = { ctx with bid = Some back_id } in
(* Translate *)
SymbolicToPure.translate_fun_decl ctx None
| Some (_, symbolic) ->
- (* Finish initializing the context by adding the additional input
- variables required by the backward function.
- *)
- let backward_sg =
- RegularFunIdNotLoopMap.find (FRegular def_id, Some back_id) fun_sigs
- in
- (* We need to ignore the forward inputs, and the state input (if there is) *)
- let backward_inputs =
- let sg = backward_sg.sg in
- (* We need to ignore the forward state and the backward state *)
- let num_forward_inputs =
- sg.info.num_fwd_inputs_with_fuel_with_state
- in
- let num_back_inputs = Option.get sg.info.num_back_inputs_no_state in
- Collections.List.subslice sg.inputs num_forward_inputs
- (num_forward_inputs + num_back_inputs)
- in
- (* As we forbid nested borrows, the additional inputs for the backward
- * functions come from the borrows in the return value of the rust function:
- * we thus use the name "ret" for those inputs *)
- let backward_inputs =
- List.map (fun ty -> (Some "ret", ty)) backward_inputs
- in
- let ctx, backward_inputs =
- SymbolicToPure.fresh_vars backward_inputs ctx
- in
- (* The outputs for the backward functions, however, come from borrows
- * present in the input values of the rust function: for those we reuse
- * the names of the input values. *)
- let backward_outputs =
- List.combine backward_sg.output_names backward_sg.sg.doutputs
- in
- let ctx, backward_outputs =
- SymbolicToPure.fresh_vars backward_outputs ctx
- in
- let backward_inputs =
- RegionGroupId.Map.singleton back_id backward_inputs
- in
- let backward_outputs =
- RegionGroupId.Map.singleton back_id backward_outputs
- in
-
- (* Put everything in the context *)
- let ctx =
- {
- ctx with
- bid = Some back_id;
- sg = backward_sg.sg;
- backward_inputs;
- backward_outputs;
- }
- in
-
+ (* Initialize the context *)
+ let ctx = { ctx with bid = Some back_id } in
(* Translate *)
SymbolicToPure.translate_fun_decl ctx (Some symbolic)
in
- let regions_hierarchy =
- LlbcAstUtils.FunIdMap.find (FRegular fdef.def_id)
- fun_context.regions_hierarchies
+ let pure_backwards =
+ if !Config.return_back_funs then []
+ else List.map translate_backward regions_hierarchy
in
- let pure_backwards = List.map translate_backward regions_hierarchy in
(* Return *)
(pure_forward, pure_backwards)
@@ -294,36 +282,21 @@ let translate_crate_to_pure (crate : crate) :
(List.map (fun (def : Pure.type_decl) -> (def.def_id, def)) type_decls)
in
- (* Translate all the function *signatures* *)
- let assumed_sigs =
- List.map
- (fun (info : Assumed.assumed_fun_info) ->
- ( FAssumed info.fun_id,
- List.map (fun _ -> None) info.fun_sig.inputs,
- info.fun_sig ))
- Assumed.assumed_fun_infos
- in
- let local_sigs =
- List.map
- (fun (fdef : fun_decl) ->
- let input_names =
- match fdef.body with
- | None -> List.map (fun _ -> None) fdef.signature.inputs
- | Some body ->
- List.map
- (fun (v : var) -> v.name)
- (LlbcAstUtils.fun_body_get_input_vars body)
- in
- (FRegular fdef.def_id, input_names, fdef.signature))
- (FunDeclId.Map.values crate.fun_decls)
+ (* Compute the decomposed fun sigs for the whole crate *)
+ let fun_dsigs =
+ FunDeclId.Map.of_list
+ (List.map
+ (fun (fdef : LlbcAst.fun_decl) ->
+ ( fdef.def_id,
+ SymbolicToPure.translate_fun_sig_from_decl_to_decomposed trans_ctx
+ fdef ))
+ (FunDeclId.Map.values crate.fun_decls))
in
- let sigs = List.append assumed_sigs local_sigs in
- let fun_sigs = SymbolicToPure.translate_fun_signatures trans_ctx sigs in
(* Translate all the *transparent* functions *)
let pure_translations =
List.map
- (translate_function_to_pure trans_ctx fun_sigs type_decls_map)
+ (translate_function_to_pure trans_ctx type_decls_map fun_dsigs)
(FunDeclId.Map.values crate.fun_decls)
in
@@ -1030,7 +1003,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) :
List.map
(fun { fwd; _ } ->
let fwd_f =
- if fwd.f.Pure.signature.info.effect_info.is_rec then
+ if fwd.f.Pure.signature.fwd_info.effect_info.is_rec then
[ (fwd.f.def_id, None) ]
else []
in
@@ -1198,7 +1171,13 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) :
let exe_dir = Filename.dirname Sys.argv.(0) in
let primitives_src_dest =
match !Config.backend with
- | FStar -> Some ("/backends/fstar/Primitives.fst", "Primitives.fst")
+ | FStar ->
+ let src =
+ if !Config.return_back_funs then
+ "/backends/fstar/merge/Primitives.fst"
+ else "/backends/fstar/split/Primitives.fst"
+ in
+ Some (src, "Primitives.fst")
| Coq -> Some ("/backends/coq/Primitives.v", "Primitives.v")
| Lean -> None
| HOL4 -> None