diff options
author | Son HO | 2023-12-23 01:46:58 +0100 |
---|---|---|
committer | GitHub | 2023-12-23 01:46:58 +0100 |
commit | 15a7d7b7322a1cd0ebeb328fde214060e23fa8b4 (patch) | |
tree | 6cce7d76969870f5bc18c5a7cd585e8873a1c0dc /compiler | |
parent | c3e0b90e422cbd902ee6d2b47073940c0017b7fb (diff) | |
parent | 63ccbd914d5d44aa30dee38a6fcc019310ab640b (diff) |
Merge pull request #64 from AeneasVerif/son/merge_back
Merge the forward/backward functions
Diffstat (limited to 'compiler')
28 files changed, 2404 insertions, 1248 deletions
diff --git a/compiler/AssociatedTypes.ml b/compiler/AssociatedTypes.ml index e2f687e8..054c8169 100644 --- a/compiler/AssociatedTypes.ml +++ b/compiler/AssociatedTypes.ml @@ -493,11 +493,11 @@ let norm_ctx_normalize_trait_type_constraint (ctx : norm_ctx) let mk_norm_ctx (ctx : eval_ctx) : norm_ctx = { norm_trait_types = ctx.norm_trait_types; - type_decls = ctx.type_context.type_decls; - fun_decls = ctx.fun_context.fun_decls; - global_decls = ctx.global_context.global_decls; - trait_decls = ctx.trait_decls_context.trait_decls; - trait_impls = ctx.trait_impls_context.trait_impls; + type_decls = ctx.type_ctx.type_decls; + fun_decls = ctx.fun_ctx.fun_decls; + global_decls = ctx.global_ctx.global_decls; + trait_decls = ctx.trait_decls_ctx.trait_decls; + trait_impls = ctx.trait_impls_ctx.trait_impls; type_vars = ctx.type_vars; const_generic_vars = ctx.const_generic_vars; } diff --git a/compiler/Config.ml b/compiler/Config.ml index b09544ba..2bb1ca34 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -92,6 +92,69 @@ let loop_fixed_point_max_num_iters = 2 (** {1 Translation} *) +(** If true, do not define separate forward/backward functions, but make the + forward functions return the backward function. + + Example: + {[ + (* Rust *) + pub fn list_nth<'a, T>(l: &'a mut List<T>, i: u32) -> &'a mut T { + match l { + List::Nil => { + panic!() + } + List::Cons(x, tl) => { + if i == 0 { + x + } else { + list_nth(tl, i - 1) + } + } + } + } + + (* Translation, if return_back_funs = false *) + def list_nth (T : Type) (l : List T) (i : U32) : Result T := + match l with + | List.Cons x tl => + if i = 0#u32 + then Result.ret x + else do + let i0 ← i - 1#u32 + list_nth T tl i0 + | List.Nil => Result.fail .panic + + def list_nth_back + (T : Type) (l : List T) (i : U32) (ret : T) : Result (List T) := + match l with + | List.Cons x tl => + if i = 0#u32 + then Result.ret (List.Cons ret tl) + else + do + let i0 ← i - 1#u32 + let tl0 ← list_nth_back T tl i0 ret + Result.ret (List.Cons x tl0) + | List.Nil => Result.fail .panic + + (* Translation, if return_back_funs = true *) + def list_nth (T: Type) (ls : List T) (i : U32) : + Result (T × (T → Result (List T))) := + match ls with + | List.Cons x tl => + if i = 0#u32 + then Result.ret (x, (λ ret => return (ret :: ls))) + else do + let i0 ← i - 1#u32 + let (x, back) ← list_nth ls i0 + Return.ret (x, + (λ ret => do + let ls ← back ret + return (x :: ls))) + ]} + *) +let return_back_funs = ref true + (** Forbids using field projectors for structures. If we don't use field projectors, whenever we symbolically expand a structure @@ -307,6 +370,32 @@ let filter_useless_monadic_calls = ref true *) let filter_useless_functions = ref true +(** Simplify the forward/backward functions, in case we merge them + (i.e., the forward functions return the backward functions). + + The simplification occurs as follows: + - if a forward function returns the unit type and has non-trivial backward + functions, then we remove the returned output. + - if a backward function doesn't have inputs, we evaluate it inside the + forward function and don't wrap it in a result. + + Example: + {[ + // LLBC: + fn incr(x: &mut u32) { *x += 1 } + + // Translation without simplification: + let incr (x : u32) : result (unit * result u32) = ... + ^^^^ ^^^^^^ + | remove this result + remove the unit + + // Translation with simplification: + let incr (x : u32) : result u32 = ... + ]} + *) +let simplify_merged_fwd_backs = ref true + (** Use short names for the record fields. Some backends can't disambiguate records when their field names have collisions. diff --git a/compiler/Contexts.ml b/compiler/Contexts.ml index a30ed0f1..5d646a61 100644 --- a/compiler/Contexts.ml +++ b/compiler/Contexts.ml @@ -180,35 +180,35 @@ type config = { let mk_config (mode : interpreter_mode) : config = { mode } -type type_context = { +type type_ctx = { type_decls_groups : type_declaration_group TypeDeclId.Map.t; type_decls : type_decl TypeDeclId.Map.t; type_infos : TypesAnalysis.type_infos; } [@@deriving show] -type fun_context = { +type fun_ctx = { fun_decls : fun_decl FunDeclId.Map.t; fun_infos : FunsAnalysis.fun_info FunDeclId.Map.t; regions_hierarchies : region_var_groups FunIdMap.t; } [@@deriving show] -type global_context = { global_decls : global_decl GlobalDeclId.Map.t } +type global_ctx = { global_decls : global_decl GlobalDeclId.Map.t } [@@deriving show] -type trait_decls_context = { trait_decls : trait_decl TraitDeclId.Map.t } +type trait_decls_ctx = { trait_decls : trait_decl TraitDeclId.Map.t } [@@deriving show] -type trait_impls_context = { trait_impls : trait_impl TraitImplId.Map.t } +type trait_impls_ctx = { trait_impls : trait_impl TraitImplId.Map.t } [@@deriving show] type decls_ctx = { - type_ctx : type_context; - fun_ctx : fun_context; - global_ctx : global_context; - trait_decls_ctx : trait_decls_context; - trait_impls_ctx : trait_impls_context; + type_ctx : type_ctx; + fun_ctx : fun_ctx; + global_ctx : global_ctx; + trait_decls_ctx : trait_decls_ctx; + trait_impls_ctx : trait_impls_ctx; } [@@deriving show] @@ -230,11 +230,11 @@ module TraitTypeRefMap = Collections.MakeMap (TraitTypeRefOrd) (** Evaluation context *) type eval_ctx = { - type_context : type_context; - fun_context : fun_context; - global_context : global_context; - trait_decls_context : trait_decls_context; - trait_impls_context : trait_impls_context; + type_ctx : type_ctx; + fun_ctx : fun_ctx; + global_ctx : global_ctx; + trait_decls_ctx : trait_decls_ctx; + trait_impls_ctx : trait_impls_ctx; region_groups : RegionGroupId.id list; type_vars : type_var list; const_generic_vars : const_generic_var list; @@ -290,20 +290,20 @@ let ctx_lookup_var_binder (ctx : eval_ctx) (vid : VarId.id) : var_binder = fst (env_lookup_var ctx.env vid) let ctx_lookup_type_decl (ctx : eval_ctx) (tid : TypeDeclId.id) : type_decl = - TypeDeclId.Map.find tid ctx.type_context.type_decls + TypeDeclId.Map.find tid ctx.type_ctx.type_decls let ctx_lookup_fun_decl (ctx : eval_ctx) (fid : FunDeclId.id) : fun_decl = - FunDeclId.Map.find fid ctx.fun_context.fun_decls + FunDeclId.Map.find fid ctx.fun_ctx.fun_decls let ctx_lookup_global_decl (ctx : eval_ctx) (gid : GlobalDeclId.id) : global_decl = - GlobalDeclId.Map.find gid ctx.global_context.global_decls + GlobalDeclId.Map.find gid ctx.global_ctx.global_decls let ctx_lookup_trait_decl (ctx : eval_ctx) (id : TraitDeclId.id) : trait_decl = - TraitDeclId.Map.find id ctx.trait_decls_context.trait_decls + TraitDeclId.Map.find id ctx.trait_decls_ctx.trait_decls let ctx_lookup_trait_impl (ctx : eval_ctx) (id : TraitImplId.id) : trait_impl = - TraitImplId.Map.find id ctx.trait_impls_context.trait_impls + TraitImplId.Map.find id ctx.trait_impls_ctx.trait_impls (** Retrieve a variable's value in the current frame *) let env_lookup_var_value (env : env) (vid : VarId.id) : typed_value = @@ -528,7 +528,7 @@ let ctx_set_abs_can_end (ctx : eval_ctx) (abs_id : AbstractionId.id) fst (ctx_subst_abs ctx abs_id abs) let ctx_type_decl_is_rec (ctx : eval_ctx) (id : TypeDeclId.id) : bool = - let decl_group = TypeDeclId.Map.find id ctx.type_context.type_decls_groups in + let decl_group = TypeDeclId.Map.find id ctx.type_ctx.type_decls_groups in match decl_group with RecGroup _ -> true | NonRecGroup _ -> false (** Visitor to iterate over the values in the *current* frame *) diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 20cdb20b..87dcb1fd 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -43,8 +43,12 @@ let extract_fun_decl_register_names (ctx : extraction_ctx) } | _ -> ctx in - let backs = List.map (fun f -> f.f) def.backs in - let funs = if def.keep_fwd then def.fwd.f :: backs else backs in + let funs = + if !Config.return_back_funs then [ def.fwd.f ] + else + let backs = List.map (fun f -> f.f) def.backs in + if def.keep_fwd then def.fwd.f :: backs else backs + in List.fold_left (fun ctx (f : fun_decl) -> let open ExtractBuiltin in @@ -128,9 +132,15 @@ let extract_adt_g_value F.pp_print_string fmt "tt"; ctx) else - (* If there is exactly one value, we don't print the parentheses *) + (* If there is exactly one value, we don't print the parentheses. + Also, for Coq, we need the special syntax ['(...)] if we destruct + a tuple pattern in a let-binding and the tuple has > 2 values. + *) let lb, rb = - if List.length field_values = 1 then ("", "") else ("(", ")") + if List.length field_values = 1 then ("", "") + else if !backend = Coq && is_single_pat && List.length field_values > 2 + then ("'(", ")") + else ("(", ")") in F.pp_print_string fmt lb; let ctx = @@ -237,30 +247,60 @@ let fun_builtin_filter_types (id : FunDeclId.id) (types : 'a list) Result.Ok types (** [inside]: see {!extract_ty}. + [with_type]: do we also generate a type annotation? This is necessary for + backends like Coq when we write lambdas (Coq is not powerful enough to + infer the type). As a pattern can introduce new variables, we return an extraction context updated with new bindings. *) let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter) - (is_let : bool) (inside : bool) (v : typed_pattern) : extraction_ctx = - match v.value with - | PatConstant cv -> - extract_literal fmt inside cv; - ctx - | PatVar (v, _) -> - let vname = ctx_compute_var_basename ctx v.basename v.ty in - let ctx, vname = ctx_add_var vname v.id ctx in - F.pp_print_string fmt vname; - ctx - | PatDummy -> - F.pp_print_string fmt "_"; - ctx - | PatAdt av -> - let extract_value ctx inside v = - extract_typed_pattern ctx fmt is_let inside v - in - extract_adt_g_value extract_value fmt ctx is_let inside av.variant_id - av.field_values v.ty + (is_let : bool) (inside : bool) ?(with_type = false) (v : typed_pattern) : + extraction_ctx = + if with_type then F.pp_print_string fmt "("; + let inside = inside && not with_type in + let ctx = + match v.value with + | PatConstant cv -> + extract_literal fmt inside cv; + ctx + | PatVar (v, _) -> + let vname = ctx_compute_var_basename ctx v.basename v.ty in + let ctx, vname = ctx_add_var vname v.id ctx in + F.pp_print_string fmt vname; + ctx + | PatDummy -> + F.pp_print_string fmt "_"; + ctx + | PatAdt av -> + let extract_value ctx inside v = + extract_typed_pattern ctx fmt is_let inside v + in + extract_adt_g_value extract_value fmt ctx is_let inside av.variant_id + av.field_values v.ty + in + if with_type then ( + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_ty ctx fmt TypeDeclId.Set.empty false v.ty; + F.pp_print_string fmt ")"); + ctx + +(** Return true if we need to wrap a succession of let-bindings in a [do ...] + block (because some of them are monadic) *) +let lets_require_wrap_in_do (lets : (bool * typed_pattern * texpression) list) : + bool = + match !backend with + | Lean -> + (* For Lean, we wrap in a block iff at least one of the let-bindings is monadic *) + List.exists (fun (m, _, _) -> m) lets + | HOL4 -> + (* HOL4 is similar to HOL4, but we add a sanity check *) + let wrap_in_do = List.exists (fun (m, _, _) -> m) lets in + if wrap_in_do then assert (List.for_all (fun (m, _, _) -> m) lets); + wrap_in_do + | FStar | Coq -> false (** [inside]: controls the introduction of parentheses. See [extract_ty] @@ -285,9 +325,9 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) | App _ -> let app, args = destruct_apps e in extract_App ctx fmt inside app args - | Abs _ -> - let xl, e = destruct_abs_list e in - extract_Abs ctx fmt inside xl e + | Lambda _ -> + let xl, e = destruct_lambdas e in + extract_Lambda ctx fmt inside xl e | Qualif _ -> (* We use the app case *) extract_App ctx fmt inside e [] @@ -574,7 +614,7 @@ and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter) (* No argument: shouldn't happen *) raise (Failure "Unreachable") -and extract_Abs (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) +and extract_Lambda (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (xl : typed_pattern list) (e : texpression) : unit = (* Open a box for the abs expression *) F.pp_open_hovbox fmt ctx.indent_incr; @@ -583,15 +623,16 @@ and extract_Abs (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (* Print the lambda - note that there should always be at least one variable *) assert (xl <> []); F.pp_print_string fmt "fun"; + let with_type = !backend = Coq in let ctx = List.fold_left (fun ctx x -> F.pp_print_space fmt (); - extract_typed_pattern ctx fmt true true x) + extract_typed_pattern ctx fmt true true ~with_type x) ctx xl in F.pp_print_space fmt (); - if !backend = Lean then F.pp_print_string fmt "=>" + if !backend = Lean || !backend = Coq then F.pp_print_string fmt "=>" else F.pp_print_string fmt "->"; F.pp_print_space fmt (); (* Print the body *) @@ -630,15 +671,6 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) | HOL4 -> destruct_lets_no_interleave e | FStar | Coq | Lean -> destruct_lets e in - (* Open a box for the whole expression. - - In the case of Lean, we use a vbox so that line breaks are inserted - at the end of every let-binding: let-bindings are indeed not ended - with an "in" keyword. - *) - if !Config.backend = Lean then F.pp_open_vbox fmt 0 else F.pp_open_hvbox fmt 0; - (* Open parentheses *) - if inside && !backend <> Lean then F.pp_print_string fmt "("; (* Extract the let-bindings *) let extract_let (ctx : extraction_ctx) (monadic : bool) (lv : typed_pattern) (re : texpression) : extraction_ctx = @@ -711,22 +743,19 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (* Return *) ctx in + (* Open a box for the whole expression. + + In the case of Lean, we use a vbox so that line breaks are inserted + at the end of every let-binding: let-bindings are indeed not ended + with an "in" keyword. + *) + if !Config.backend = Lean then F.pp_open_vbox fmt 0 else F.pp_open_hvbox fmt 0; + (* Open parentheses *) + if inside && !backend <> Lean then F.pp_print_string fmt "("; (* If Lean and HOL4, we rely on monadic blocks, so we insert a do and open a new box immediately *) - let wrap_in_do_od = - match !backend with - | Lean -> - (* For Lean, we wrap in a block iff at least one of the let-bindings is monadic *) - List.exists (fun (m, _, _) -> m) lets - | HOL4 -> - (* HOL4 is similar to HOL4, but we add a sanity check *) - let wrap_in_do = List.exists (fun (m, _, _) -> m) lets in - if wrap_in_do then assert (List.for_all (fun (m, _, _) -> m) lets); - wrap_in_do - | FStar | Coq -> false - in + let wrap_in_do_od = lets_require_wrap_in_do lets in if wrap_in_do_od then ( - F.pp_open_vbox fmt (if !backend = Lean then ctx.indent_incr else 0); F.pp_print_string fmt "do"; F.pp_print_space fmt ()); let ctx = @@ -742,11 +771,10 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) F.pp_close_box fmt (); (* do-box (Lean and HOL4 only) *) - if wrap_in_do_od then ( + if wrap_in_do_od then if !backend = HOL4 then ( F.pp_print_space fmt (); F.pp_print_string fmt "od"); - F.pp_close_box fmt ()); (* Close parentheses *) if inside && !backend <> Lean then F.pp_print_string fmt ")"; (* Close the box for the whole expression *) @@ -1319,16 +1347,16 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter) (Pure.FunId (FRegular def.def_id), def.loop_id, def.back_id) ctx.fun_name_info in - let comment_pre = "[" ^ name_to_string ctx def.llbc_name ^ "]: " in + let comment_pre = "[" ^ name_to_string ctx def.llbc_name ^ "]:" in let comment = let loop_comment = match def.loop_id with | None -> "" - | Some id -> "loop " ^ LoopId.to_string id ^ ": " + | Some id -> " loop " ^ LoopId.to_string id ^ ":" in let fwd_back_comment = match def.back_id with - | None -> [ "forward function" ] + | None -> if !Config.return_back_funs then [] else [ "forward function" ] | Some id -> (* Check if there is only one backward function, and no forward function *) if (not keep_fwd) && num_backs = 1 then @@ -1340,9 +1368,9 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter) else [ "backward function " ^ T.RegionGroupId.to_string id ] in match fwd_back_comment with - | [] -> raise (Failure "Unreachable") - | [ s ] -> [ comment_pre ^ loop_comment ^ s ] - | s :: sl -> (comment_pre ^ loop_comment ^ s) :: sl + | [] -> [ comment_pre ^ loop_comment ] + | [ s ] -> [ comment_pre ^ loop_comment ^ " " ^ s ] + | s :: sl -> (comment_pre ^ loop_comment ^ " " ^ s) :: sl in extract_comment_with_span fmt comment def.meta.span @@ -1470,7 +1498,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) let inputs_lvs = let all_inputs = (Option.get def.body).inputs_lvs in let num_fwd_inputs = - def.signature.info.num_fwd_inputs_with_fuel_with_state + def.signature.fwd_info.fwd_info.num_inputs_with_fuel_with_state in Collections.List.prefix num_fwd_inputs all_inputs in @@ -1516,7 +1544,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) let def_body = Option.get def.body in let all_vars = List.map (fun (v : var) -> v.id) def_body.inputs in let num_fwd_inputs = - def.signature.info.num_fwd_inputs_with_fuel_with_state + def.signature.fwd_info.fwd_info.num_inputs_with_fuel_with_state in let vars = Collections.List.prefix num_fwd_inputs all_vars in @@ -1794,7 +1822,6 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) assert body.is_global_decl_body; assert (Option.is_none body.back_id); assert (body.signature.inputs = []); - assert (List.length body.signature.doutputs = 1); assert (body.signature.generics = empty_generic_params); (* Add a break then the name of the corresponding LLBC declaration *) @@ -1813,7 +1840,8 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) let decl_ty, body_ty = let ty = body.signature.output in - if body.signature.info.effect_info.can_fail then (unwrap_result_ty ty, ty) + if body.signature.fwd_info.effect_info.can_fail then + (unwrap_result_ty ty, ty) else (ty, mk_result_ty ty) in match body.body with @@ -1984,7 +2012,8 @@ let extract_trait_decl_method_names (ctx : extraction_ctx) (* We add one field per required forward/backward function *) let get_funs_for_id (id : fun_decl_id) : fun_decl list = let trans : pure_fun_translation = FunDeclId.Map.find id ctx.trans_funs in - List.map (fun f -> f.f) (trans.fwd :: trans.backs) + if !Config.return_back_funs then [ trans.fwd.f ] + else List.map (fun f -> f.f) (trans.fwd :: trans.backs) in match builtin_info with | None -> diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index eb2a2ec9..db887539 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -351,7 +351,7 @@ let basename_to_unique (names_set : StringSet.t) let s = append basename i in if StringSet.mem s names_set then gen (i + 1) else s in - if StringSet.mem basename names_set then gen 0 else basename + if StringSet.mem basename names_set then gen 1 else basename type fun_name_info = { keep_fwd : bool; num_backs : int } @@ -1051,33 +1051,60 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list = let assumed_llbc_functions () : (A.assumed_fun_id * T.RegionGroupId.id option * string) list = let rg0 = Some T.RegionGroupId.zero in - match !backend with - | FStar | Coq | HOL4 -> - [ - (ArrayIndexShared, None, "array_index_usize"); - (ArrayIndexMut, None, "array_index_usize"); - (ArrayIndexMut, rg0, "array_update_usize"); - (ArrayToSliceShared, None, "array_to_slice"); - (ArrayToSliceMut, None, "array_to_slice"); - (ArrayToSliceMut, rg0, "array_from_slice"); - (ArrayRepeat, None, "array_repeat"); - (SliceIndexShared, None, "slice_index_usize"); - (SliceIndexMut, None, "slice_index_usize"); - (SliceIndexMut, rg0, "slice_update_usize"); - ] - | Lean -> - [ - (ArrayIndexShared, None, "Array.index_usize"); - (ArrayIndexMut, None, "Array.index_usize"); - (ArrayIndexMut, rg0, "Array.update_usize"); - (ArrayToSliceShared, None, "Array.to_slice"); - (ArrayToSliceMut, None, "Array.to_slice"); - (ArrayToSliceMut, rg0, "Array.from_slice"); - (ArrayRepeat, None, "Array.repeat"); - (SliceIndexShared, None, "Slice.index_usize"); - (SliceIndexMut, None, "Slice.index_usize"); - (SliceIndexMut, rg0, "Slice.update_usize"); - ] + let regular : (A.assumed_fun_id * T.RegionGroupId.id option * string) list = + match !backend with + | FStar | Coq | HOL4 -> + [ + (ArrayIndexShared, None, "array_index_usize"); + (ArrayToSliceShared, None, "array_to_slice"); + (ArrayRepeat, None, "array_repeat"); + (SliceIndexShared, None, "slice_index_usize"); + ] + | Lean -> + [ + (ArrayIndexShared, None, "Array.index_usize"); + (ArrayToSliceShared, None, "Array.to_slice"); + (ArrayRepeat, None, "Array.repeat"); + (SliceIndexShared, None, "Slice.index_usize"); + ] + in + let mut_funs : (A.assumed_fun_id * T.RegionGroupId.id option * string) list = + if !Config.return_back_funs then + match !backend with + | FStar | Coq | HOL4 -> + [ + (ArrayIndexMut, None, "array_index_mut_usize"); + (ArrayToSliceMut, None, "array_to_slice_mut"); + (SliceIndexMut, None, "slice_index_mut_usize"); + ] + | Lean -> + [ + (ArrayIndexMut, None, "Array.index_mut_usize"); + (ArrayToSliceMut, None, "Array.to_slice_mut"); + (SliceIndexMut, None, "Slice.index_mut_usize"); + ] + else + match !backend with + | FStar | Coq | HOL4 -> + [ + (ArrayIndexMut, None, "array_index_usize"); + (ArrayIndexMut, rg0, "array_update_usize"); + (ArrayToSliceMut, None, "array_to_slice"); + (ArrayToSliceMut, rg0, "array_from_slice"); + (SliceIndexMut, None, "slice_index_usize"); + (SliceIndexMut, rg0, "slice_update_usize"); + ] + | Lean -> + [ + (ArrayIndexMut, None, "Array.index_usize"); + (ArrayIndexMut, rg0, "Array.update_usize"); + (ArrayToSliceMut, None, "Array.to_slice"); + (ArrayToSliceMut, rg0, "Array.from_slice"); + (SliceIndexMut, None, "Slice.index_usize"); + (SliceIndexMut, rg0, "Slice.update_usize"); + ] + in + regular @ mut_funs let assumed_pure_functions () : (pure_assumed_fun_id * string) list = match !backend with diff --git a/compiler/ExtractBuiltin.ml b/compiler/ExtractBuiltin.ml index 24d16dca..ee8d4831 100644 --- a/compiler/ExtractBuiltin.ml +++ b/compiler/ExtractBuiltin.ml @@ -232,6 +232,14 @@ let builtin_funs () : (pattern * bool list option * builtin_fun_info list) list let mk_fun (rust_name : string) (extract_name : string option) (filter : bool list option) (with_back : bool) (back_no_suffix : bool) : pattern * bool list option * builtin_fun_info list = + (* [back_no_suffix] is used to control whether the backward function should + have the suffix "_back" or not (if not, then the forward function has the + prefix "_fwd", and is filtered anyway). This is pertinent only if we split + the fwd/back functions. *) + let back_no_suffix = back_no_suffix && not !Config.return_back_funs in + (* Same for the [with_back] option: this is pertinent only if we split + the fwd/back functions *) + let with_back = with_back && not !Config.return_back_funs in let rust_name = try parse_pattern rust_name with Failure _ -> diff --git a/compiler/Interpreter.ml b/compiler/Interpreter.ml index 76432faa..22d176c9 100644 --- a/compiler/Interpreter.ml +++ b/compiler/Interpreter.ml @@ -195,7 +195,7 @@ let initialize_symbolic_context_for_fun (ctx : decls_ctx) (fdef : fun_decl) : List.map (fun (g : region_var_group) -> g.id) regions_hierarchy in let ctx = - initialize_eval_context ctx region_groups sg.generics.types + initialize_eval_ctx ctx region_groups sg.generics.types sg.generics.const_generics in (* Instantiate the signature. This updates the context because we compute @@ -277,7 +277,7 @@ let evaluate_function_symbolic_synthesize_backward_from_return (config : config) * an instantiation of the signature, so that we use fresh * region ids for the return abstractions. *) let regions_hierarchy = - FunIdMap.find (FRegular fdef.def_id) ctx.fun_context.regions_hierarchies + FunIdMap.find (FRegular fdef.def_id) ctx.fun_ctx.regions_hierarchies in let _, ret_inst_sg = symbolic_instantiate_fun_sig ctx fdef.signature regions_hierarchy fdef.kind @@ -466,7 +466,7 @@ let evaluate_function_symbolic (synthesize : bool) (ctx : decls_ctx) let ctx, input_svs, inst_sg = initialize_symbolic_context_for_fun ctx fdef in let regions_hierarchy = - FunIdMap.find (FRegular fdef.def_id) ctx.fun_context.regions_hierarchies + FunIdMap.find (FRegular fdef.def_id) ctx.fun_ctx.regions_hierarchies in (* Create the continuation to finish the evaluation *) @@ -615,7 +615,7 @@ module Test = struct assert (body.arg_count = 0); (* Create the evaluation context *) - let ctx = initialize_eval_context decls_ctx [] [] [] in + let ctx = initialize_eval_ctx decls_ctx [] [] [] in (* Insert the (uninitialized) local variables *) let ctx = ctx_push_uninitialized_vars ctx body.locals in diff --git a/compiler/InterpreterBorrows.ml b/compiler/InterpreterBorrows.ml index e56919fa..a2eb2545 100644 --- a/compiler/InterpreterBorrows.ml +++ b/compiler/InterpreterBorrows.ml @@ -1628,7 +1628,7 @@ let destructure_abs (abs_kind : abs_kind) (can_end : bool) push { value; ty } | AIgnoredMutLoan (opt_bid, child_av) -> (* We don't support nested borrows for now *) - assert (not (ty_has_borrows ctx.type_context.type_infos child_av.ty)); + assert (not (ty_has_borrows ctx.type_ctx.type_infos child_av.ty)); assert (opt_bid = None); (* Simply explore the child *) list_avalues false push_fail child_av @@ -1639,7 +1639,7 @@ let destructure_abs (abs_kind : abs_kind) (can_end : bool) { child = child_av; given_back = _; given_back_meta = _ } | AIgnoredSharedLoan child_av -> (* We don't support nested borrows for now *) - assert (not (ty_has_borrows ctx.type_context.type_infos child_av.ty)); + assert (not (ty_has_borrows ctx.type_ctx.type_infos child_av.ty)); (* Simply explore the child *) list_avalues false push_fail child_av) | ABorrow bc -> ( @@ -1659,14 +1659,14 @@ let destructure_abs (abs_kind : abs_kind) (can_end : bool) push av | AIgnoredMutBorrow (opt_bid, child_av) -> (* We don't support nested borrows for now *) - assert (not (ty_has_borrows ctx.type_context.type_infos child_av.ty)); + assert (not (ty_has_borrows ctx.type_ctx.type_infos child_av.ty)); assert (opt_bid = None); (* Just explore the child *) list_avalues false push_fail child_av | AEndedIgnoredMutBorrow { child = child_av; given_back = _; given_back_meta = _ } -> (* We don't support nested borrows for now *) - assert (not (ty_has_borrows ctx.type_context.type_infos child_av.ty)); + assert (not (ty_has_borrows ctx.type_ctx.type_infos child_av.ty)); (* Just explore the child *) list_avalues false push_fail child_av | AProjSharedBorrow asb -> @@ -1683,7 +1683,7 @@ let destructure_abs (abs_kind : abs_kind) (can_end : bool) | ASymbolic _ -> (* For now, we fore all symbolic values containing borrows to be eagerly expanded *) - assert (not (ty_has_borrows ctx.type_context.type_infos ty)) + assert (not (ty_has_borrows ctx.type_ctx.type_infos ty)) and list_values (v : typed_value) : typed_avalue list * typed_value = let ty = v.ty in match v.value with @@ -1732,7 +1732,7 @@ let destructure_abs (abs_kind : abs_kind) (can_end : bool) | VSymbolic _ -> (* For now, we fore all symbolic values containing borrows to be eagerly expanded *) - assert (not (ty_has_borrows ctx.type_context.type_infos ty)); + assert (not (ty_has_borrows ctx.type_ctx.type_infos ty)); ([], v) in diff --git a/compiler/InterpreterExpansion.ml b/compiler/InterpreterExpansion.ml index bbf4d9d5..e489ddc3 100644 --- a/compiler/InterpreterExpansion.ml +++ b/compiler/InterpreterExpansion.ml @@ -627,7 +627,7 @@ let greedy_expand_symbolics_with_borrows (config : config) : cm_fun = inherit [_] iter_eval_ctx method! visit_VSymbolic _ sv = - if ty_has_borrows ctx.type_context.type_infos sv.sv_ty then + if ty_has_borrows ctx.type_ctx.type_infos sv.sv_ty then raise (FoundSymbolicValue sv) else () diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml index 1b5b79dd..8536b4ab 100644 --- a/compiler/InterpreterExpressions.ml +++ b/compiler/InterpreterExpressions.ml @@ -32,8 +32,7 @@ let expand_primitively_copyable_at_place (config : config) fun cf ctx -> let v = read_place access p ctx in match - find_first_primitively_copyable_sv_with_borrows - ctx.type_context.type_infos v + find_first_primitively_copyable_sv_with_borrows ctx.type_ctx.type_infos v with | None -> cf ctx | Some sv -> @@ -351,7 +350,7 @@ let eval_operand_no_reorganize (config : config) (op : operand) assert ( Option.is_none (find_first_primitively_copyable_sv_with_borrows - ctx.type_context.type_infos v)); + ctx.type_ctx.type_infos v)); (* Actually perform the copy *) let allow_adt_copy = false in let ctx, v = copy_value allow_adt_copy config ctx v in diff --git a/compiler/InterpreterExpressions.mli b/compiler/InterpreterExpressions.mli index f8d979f4..b975371c 100644 --- a/compiler/InterpreterExpressions.mli +++ b/compiler/InterpreterExpressions.mli @@ -52,7 +52,7 @@ val eval_operands : Transmits the computed rvalue to the received continuation. - Note that this function fails on {!constructor:Aeneas.Expressions.rvalue.Discriminant}: discriminant + Note that this function fails on {!Aeneas.Expressions.rvalue.Discriminant}: discriminant reads should have been eliminated from the AST. *) val eval_rvalue_not_global : diff --git a/compiler/InterpreterLoopsFixedPoint.ml b/compiler/InterpreterLoopsFixedPoint.ml index c4e180fa..4dabe974 100644 --- a/compiler/InterpreterLoopsFixedPoint.ml +++ b/compiler/InterpreterLoopsFixedPoint.ml @@ -300,7 +300,7 @@ let prepare_ashared_loans (loop_id : LoopId.id option) : cm_fun = let env = List.append fresh_absl env in let ctx = { ctx with env } in - let _, new_ctx_ids_map = compute_context_ids ctx in + let _, new_ctx_ids_map = compute_ctx_ids ctx in (* Synthesize *) match cf ctx with @@ -385,8 +385,8 @@ let compute_loop_entry_fixed_point (config : config) (loop_id : LoopId.id) match !fixed_ids with | Some _ -> ctx1 | None -> - let old_ids, _ = compute_context_ids ctx1 in - let new_ids, _ = compute_contexts_ids !ctxs in + let old_ids, _ = compute_ctx_ids ctx1 in + let new_ids, _ = compute_ctxs_ids !ctxs in let blids = BorrowId.Set.diff old_ids.blids new_ids.blids in let aids = AbstractionId.Set.diff old_ids.aids new_ids.aids in (* End those borrows and abstractions *) @@ -409,7 +409,7 @@ let compute_loop_entry_fixed_point (config : config) (loop_id : LoopId.id) ctxs := List.map (end_borrows_abs blids aids) !ctxs; (* Note that the fixed ids are given by the original context, from *before* we introduce fresh abstractions/reborrows for the shared values *) - fixed_ids := Some (fst (compute_context_ids ctx0)); + fixed_ids := Some (fst (compute_ctx_ids ctx0)); ctx1 in @@ -424,12 +424,12 @@ let compute_loop_entry_fixed_point (config : config) (loop_id : LoopId.id) intersection of ids between the original environment and the list of new environments *) let compute_fixed_ids (ctxl : eval_ctx list) : ids_sets = - let fixed_ids, _ = compute_context_ids ctx0 in + let fixed_ids, _ = compute_ctx_ids ctx0 in let { aids; blids; borrow_ids; loan_ids; dids; rids; sids } = fixed_ids in let sids = ref sids in List.iter (fun ctx -> - let fixed_ids, _ = compute_context_ids ctx in + let fixed_ids, _ = compute_ctx_ids ctx in sids := SymbolicValueId.Set.inter !sids fixed_ids.sids) ctxl; let sids = !sids in @@ -568,7 +568,7 @@ let compute_loop_entry_fixed_point (config : config) (loop_id : LoopId.id) InterpreterBorrows.end_abstraction_no_synth config abs_id ctx in (* Explore the context, and check which abstractions are not there anymore *) - let ids, _ = compute_context_ids ctx in + let ids, _ = compute_ctx_ids ctx in let ended_ids = AbstractionId.Set.diff !fp_aids ids.aids in add_ended_aids rg_id ended_ids) ctx.region_groups @@ -840,8 +840,8 @@ let compute_fixed_point_id_correspondance (fixed_ids : ids_sets) let compute_fp_ctx_symbolic_values (ctx : eval_ctx) (fp_ctx : eval_ctx) : SymbolicValueId.Set.t * symbolic_value list = - let old_ids, _ = compute_context_ids ctx in - let fp_ids, fp_ids_maps = compute_context_ids fp_ctx in + let old_ids, _ = compute_ctx_ids ctx in + let fp_ids, fp_ids_maps = compute_ctx_ids fp_ctx in let fresh_sids = SymbolicValueId.Set.diff fp_ids.sids old_ids.sids in (* Compute the set of symbolic values which appear in shared values inside diff --git a/compiler/InterpreterLoopsJoinCtxs.ml b/compiler/InterpreterLoopsJoinCtxs.ml index 8d485483..445e5abf 100644 --- a/compiler/InterpreterLoopsJoinCtxs.ml +++ b/compiler/InterpreterLoopsJoinCtxs.ml @@ -326,8 +326,8 @@ let mk_collapse_ctx_merge_duplicate_funs (loop_id : LoopId.id) (ctx : eval_ctx) let _ = let _, ty0, _ = ty_as_ref ty0 in let _, ty1, _ = ty_as_ref ty1 in - assert (not (ty_has_borrows ctx.type_context.type_infos ty0)); - assert (not (ty_has_borrows ctx.type_context.type_infos ty1)) + assert (not (ty_has_borrows ctx.type_ctx.type_infos ty0)); + assert (not (ty_has_borrows ctx.type_ctx.type_infos ty1)) in (* Same remarks as for [merge_amut_borrows] *) @@ -543,11 +543,11 @@ let join_ctxs (loop_id : LoopId.id) (fixed_ids : ids_sets) (ctx0 : eval_ctx) (* Construct the joined context - of course, the type, fun, etc. contexts * should be the same in the two contexts *) let { - type_context; - fun_context; - global_context; - trait_decls_context; - trait_impls_context; + type_ctx; + fun_ctx; + global_ctx; + trait_decls_ctx; + trait_impls_ctx; region_groups; type_vars; const_generic_vars; @@ -559,11 +559,11 @@ let join_ctxs (loop_id : LoopId.id) (fixed_ids : ids_sets) (ctx0 : eval_ctx) ctx0 in let { - type_context = _; - fun_context = _; - global_context = _; - trait_decls_context = _; - trait_impls_context = _; + type_ctx = _; + fun_ctx = _; + global_ctx = _; + trait_decls_ctx = _; + trait_impls_ctx = _; region_groups = _; type_vars = _; const_generic_vars = _; @@ -577,11 +577,11 @@ let join_ctxs (loop_id : LoopId.id) (fixed_ids : ids_sets) (ctx0 : eval_ctx) let ended_regions = RegionId.Set.union ended_regions0 ended_regions1 in Ok { - type_context; - fun_context; - global_context; - trait_decls_context; - trait_impls_context; + type_ctx; + fun_ctx; + global_ctx; + trait_decls_ctx; + trait_impls_ctx; region_groups; type_vars; const_generic_vars; @@ -621,7 +621,7 @@ let destructure_new_abs (loop_id : LoopId.id) contexts we join don't have non-fixed abstractions with the same ids. *) let refresh_abs (old_abs : AbstractionId.Set.t) (ctx : eval_ctx) : eval_ctx = - let ids, _ = compute_context_ids ctx in + let ids, _ = compute_ctx_ids ctx in let abs_to_refresh = AbstractionId.Set.diff ids.aids old_abs in let aids_subst = List.map diff --git a/compiler/InterpreterLoopsMatchCtxs.ml b/compiler/InterpreterLoopsMatchCtxs.ml index 90559c29..2a688fa7 100644 --- a/compiler/InterpreterLoopsMatchCtxs.ml +++ b/compiler/InterpreterLoopsMatchCtxs.ml @@ -658,7 +658,7 @@ module MakeJoinMatcher (S : MatchJoinState) : PrimMatcher = struct else ( (* The caller should have checked that the symbolic values don't contain borrows *) - assert (not (ty_has_borrows S.ctx.type_context.type_infos sv0.sv_ty)); + assert (not (ty_has_borrows S.ctx.type_ctx.type_infos sv0.sv_ty)); (* We simply introduce a fresh symbolic value *) mk_fresh_symbolic_value sv0.sv_ty) @@ -669,7 +669,7 @@ module MakeJoinMatcher (S : MatchJoinState) : PrimMatcher = struct - there are no borrows in the "regular" value If there are loans in the regular value, raise an exception. *) - assert (not (ty_has_borrows S.ctx.type_context.type_infos sv.sv_ty)); + assert (not (ty_has_borrows S.ctx.type_ctx.type_infos sv.sv_ty)); assert (not (value_has_borrows S.ctx v.value)); let value_is_left = not left in (match InterpreterBorrowsCore.get_first_loan_in_value v with diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml index 30b7b333..97c8bcd6 100644 --- a/compiler/InterpreterStatements.ml +++ b/compiler/InterpreterStatements.ml @@ -728,7 +728,13 @@ let create_push_abstractions_from_abs_region_groups to a trait clause but directly to the method provided in the trait declaration. *) let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) - : fun_id_or_trait_method_ref * generic_args * fun_decl * inst_fun_sig = + : + fun_id_or_trait_method_ref + * generic_args + * (generic_args * trait_instance_id) option + * fun_decl + * region_var_groups + * inst_fun_sig = match call.func with | FnOpMove _ -> (* Closure case: TODO *) @@ -747,13 +753,13 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) let tr_self = UnknownTrait __FUNCTION__ in let regions_hierarchy = LlbcAstUtils.FunIdMap.find (FRegular fid) - ctx.fun_context.regions_hierarchies + ctx.fun_ctx.regions_hierarchies in let inst_sg = instantiate_fun_sig ctx func.generics tr_self def.signature regions_hierarchy in - (func.func, func.generics, def, inst_sg) + (func.func, func.generics, None, def, regions_hierarchy, inst_sg) | FunId (FAssumed _) -> (* Unreachable: must be a transparent function *) raise (Failure "Unreachable") @@ -793,7 +799,7 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) let fid : fun_id = FRegular id in let regions_hierarchy = LlbcAstUtils.FunIdMap.find fid - ctx.fun_context.regions_hierarchies + ctx.fun_ctx.regions_hierarchies in let inst_sg = instantiate_fun_sig ctx generics tr_self @@ -806,7 +812,12 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) we also need to update the generics. *) let func = FunId fid in - (func, generics, method_def, inst_sg) + ( func, + generics, + Some (generics, tr_self), + method_def, + regions_hierarchy, + inst_sg ) | None -> (* If not found, lookup the methods provided by the trait *declaration* (remember: for now, we forbid overriding provided methods) *) @@ -853,14 +864,19 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) method_def.signature.parent_params_info)); let regions_hierarchy = LlbcAstUtils.FunIdMap.find (FRegular method_id) - ctx.fun_context.regions_hierarchies + ctx.fun_ctx.regions_hierarchies in let tr_self = TraitRef trait_ref in let inst_sg = instantiate_fun_sig ctx all_generics tr_self method_def.signature regions_hierarchy in - (func.func, func.generics, method_def, inst_sg)) + ( func.func, + func.generics, + Some (all_generics, tr_self), + method_def, + regions_hierarchy, + inst_sg )) | _ -> (* We are using a local clause - we lookup the trait decl *) let trait_decl = @@ -884,14 +900,19 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) (* Instantiate *) let regions_hierarchy = LlbcAstUtils.FunIdMap.find (FRegular method_id) - ctx.fun_context.regions_hierarchies + ctx.fun_ctx.regions_hierarchies in let tr_self = TraitRef trait_ref in let inst_sg = instantiate_fun_sig ctx generics tr_self method_def.signature regions_hierarchy in - (func.func, func.generics, method_def, inst_sg))) + ( func.func, + func.generics, + Some (generics, tr_self), + method_def, + regions_hierarchy, + inst_sg ))) (** Evaluate a statement *) let rec eval_statement (config : config) (st : statement) : st_cm_fun = @@ -1277,14 +1298,15 @@ and eval_transparent_function_call_concrete (config : config) and eval_transparent_function_call_symbolic (config : config) (call : call) : st_cm_fun = fun cf ctx -> - let func, generics, def, inst_sg = + let func, generics, trait_method_generics, def, regions_hierarchy, inst_sg = eval_transparent_function_call_symbolic_inst call ctx in (* Sanity check *) assert (List.length call.args = List.length def.signature.inputs); (* Evaluate the function call *) - eval_function_call_symbolic_from_inst_sig config func inst_sg generics - call.args call.dest cf ctx + eval_function_call_symbolic_from_inst_sig config func def.signature + regions_hierarchy inst_sg generics trait_method_generics call.args call.dest + cf ctx (** Evaluate a function call in symbolic mode by using the function signature. @@ -1298,8 +1320,11 @@ and eval_transparent_function_call_symbolic (config : config) (call : call) : trait ref as input. *) and eval_function_call_symbolic_from_inst_sig (config : config) - (fid : fun_id_or_trait_method_ref) (inst_sg : inst_fun_sig) - (generics : generic_args) (args : operand list) (dest : place) : st_cm_fun = + (fid : fun_id_or_trait_method_ref) (sg : fun_sig) + (regions_hierarchy : region_var_groups) (inst_sg : inst_fun_sig) + (generics : generic_args) + (trait_method_generics : (generic_args * trait_instance_id) option) + (args : operand list) (dest : place) : st_cm_fun = fun cf ctx -> log#ldebug (lazy @@ -1378,8 +1403,9 @@ and eval_function_call_symbolic_from_inst_sig (config : config) let expr = cf ctx in (* Synthesize the symbolic AST *) - S.synthesize_regular_function_call fid call_id ctx abs_ids generics args - args_places ret_spc dest_place expr + S.synthesize_regular_function_call fid call_id ctx sg regions_hierarchy + abs_ids generics trait_method_generics args args_places ret_spc dest_place + expr in let cc = comp cc cf_call in @@ -1450,7 +1476,7 @@ and eval_assumed_function_call_symbolic (config : config) (fid : assumed_fun_id) * this is a current limitation of our synthesis *) assert ( List.for_all - (fun ty -> not (ty_has_borrows ctx.type_context.type_infos ty)) + (fun ty -> not (ty_has_borrows ctx.type_ctx.type_infos ty)) generics.types); (* There are two cases (and this is extremely annoying): @@ -1468,7 +1494,7 @@ and eval_assumed_function_call_symbolic (config : config) (fid : assumed_fun_id) (* In symbolic mode, the behaviour of a function call is completely defined * by the signature of the function: we thus simply generate correctly * instantiated signatures, and delegate the work to an auxiliary function *) - let inst_sig = + let sg, regions_hierarchy, inst_sig = match fid with | BoxFree -> (* Should have been treated above *) @@ -1476,18 +1502,20 @@ and eval_assumed_function_call_symbolic (config : config) (fid : assumed_fun_id) | _ -> let regions_hierarchy = LlbcAstUtils.FunIdMap.find (FAssumed fid) - ctx.fun_context.regions_hierarchies + ctx.fun_ctx.regions_hierarchies in (* There shouldn't be any reference to Self *) let tr_self = UnknownTrait __FUNCTION__ in - instantiate_fun_sig ctx generics tr_self - (Assumed.get_assumed_fun_sig fid) - regions_hierarchy + let sg = Assumed.get_assumed_fun_sig fid in + let inst_sg = + instantiate_fun_sig ctx generics tr_self sg regions_hierarchy + in + (sg, regions_hierarchy, inst_sg) in (* Evaluate the function call *) - eval_function_call_symbolic_from_inst_sig config (FunId (FAssumed fid)) - inst_sig generics args dest cf ctx + eval_function_call_symbolic_from_inst_sig config (FunId (FAssumed fid)) sg + regions_hierarchy inst_sig generics None args dest cf ctx (** Evaluate a statement seen as a function body *) and eval_function_body (config : config) (body : statement) : st_cm_fun = diff --git a/compiler/InterpreterUtils.ml b/compiler/InterpreterUtils.ml index e04a6b90..a1a06ee5 100644 --- a/compiler/InterpreterUtils.ml +++ b/compiler/InterpreterUtils.ml @@ -265,7 +265,7 @@ let value_has_ret_symbolic_value_with_borrow_under_mut (ctx : eval_ctx) inherit [_] iter_typed_value method! visit_symbolic_value _ s = - if ty_has_borrow_under_mut ctx.type_context.type_infos s.sv_ty then + if ty_has_borrow_under_mut ctx.type_ctx.type_infos s.sv_ty then raise Found else () end @@ -288,15 +288,15 @@ let rvalue_get_place (rv : rvalue) : place option = (** See {!ValuesUtils.symbolic_value_has_borrows} *) let symbolic_value_has_borrows (ctx : eval_ctx) (sv : symbolic_value) : bool = - ValuesUtils.symbolic_value_has_borrows ctx.type_context.type_infos sv + ValuesUtils.symbolic_value_has_borrows ctx.type_ctx.type_infos sv (** See {!ValuesUtils.value_has_borrows}. *) let value_has_borrows (ctx : eval_ctx) (v : value) : bool = - ValuesUtils.value_has_borrows ctx.type_context.type_infos v + ValuesUtils.value_has_borrows ctx.type_ctx.type_infos v (** See {!ValuesUtils.value_has_loans_or_borrows}. *) let value_has_loans_or_borrows (ctx : eval_ctx) (v : value) : bool = - ValuesUtils.value_has_loans_or_borrows ctx.type_context.type_infos v + ValuesUtils.value_has_loans_or_borrows ctx.type_ctx.type_infos v (** See {!ValuesUtils.value_has_loans}. *) let value_has_loans (v : value) : bool = ValuesUtils.value_has_loans v @@ -401,19 +401,19 @@ let compute_env_elem_ids (x : env_elem) : ids_sets * ids_to_values = compute_env_ids [ x ] (** Compute the sets of ids found in a list of contexts. *) -let compute_contexts_ids (ctxl : eval_ctx list) : ids_sets * ids_to_values = +let compute_ctxs_ids (ctxl : eval_ctx list) : ids_sets * ids_to_values = let compute, get_ids, get_ids_to_values = compute_ids () in List.iter (compute#visit_eval_ctx ()) ctxl; (get_ids (), get_ids_to_values ()) (** Compute the sets of ids found in a context. *) -let compute_context_ids (ctx : eval_ctx) : ids_sets * ids_to_values = - compute_contexts_ids [ ctx ] +let compute_ctx_ids (ctx : eval_ctx) : ids_sets * ids_to_values = + compute_ctxs_ids [ ctx ] (** **WARNING**: this function doesn't compute the normalized types (for the trait type aliases). This should be computed afterwards. *) -let initialize_eval_context (ctx : decls_ctx) +let initialize_eval_ctx (ctx : decls_ctx) (region_groups : RegionGroupId.id list) (type_vars : type_var list) (const_generic_vars : const_generic_var list) : eval_ctx = reset_global_counters (); @@ -427,11 +427,11 @@ let initialize_eval_context (ctx : decls_ctx) const_generic_vars) in { - type_context = ctx.type_ctx; - fun_context = ctx.fun_ctx; - global_context = ctx.global_ctx; - trait_decls_context = ctx.trait_decls_ctx; - trait_impls_context = ctx.trait_impls_ctx; + type_ctx = ctx.type_ctx; + fun_ctx = ctx.fun_ctx; + global_ctx = ctx.global_ctx; + trait_decls_ctx = ctx.trait_decls_ctx; + trait_impls_ctx = ctx.trait_impls_ctx; region_groups; type_vars; const_generic_vars; diff --git a/compiler/Invariants.ml b/compiler/Invariants.ml index fa0d7436..b87cdff7 100644 --- a/compiler/Invariants.ml +++ b/compiler/Invariants.ml @@ -768,7 +768,7 @@ let check_symbolic_values (ctx : eval_ctx) : unit = assert (info.env_count = 0 || info.aproj_borrows = []); (* A symbolic value containing borrows can't be duplicated (i.e., copied): * it must be expanded first *) - if ty_has_borrows ctx.type_context.type_infos info.ty then + if ty_has_borrows ctx.type_ctx.type_infos info.ty then assert (info.env_count <= 1); (* A duplicated symbolic value is necessarily primitively copyable *) assert (info.env_count <= 1 || ty_is_primitively_copyable info.ty); diff --git a/compiler/Main.ml b/compiler/Main.ml index 835b9088..0b8ec439 100644 --- a/compiler/Main.ml +++ b/compiler/Main.ml @@ -120,6 +120,9 @@ let () = " Generate a default lakefile.lean (Lean only)" ); ("-print-llbc", Arg.Set print_llbc, " Print the imported LLBC"); ("-k", Arg.Clear fail_hard, " Do not fail hard in case of error"); + ( "-split-fwd-back", + Arg.Clear return_back_funs, + " Split the forward and backward functions." ); ] in @@ -193,9 +196,6 @@ let () = let _ = match !backend with | FStar -> - (* Some patterns are not supported *) - decompose_monadic_let_bindings := false; - decompose_nested_let_patterns := false; (* F* can disambiguate the field names *) record_fields_short_names := true | Coq -> diff --git a/compiler/Print.ml b/compiler/Print.ml index 0e2ec1fc..8999c77d 100644 --- a/compiler/Print.ml +++ b/compiler/Print.ml @@ -409,11 +409,11 @@ module Contexts = struct } let eval_ctx_to_fmt_env (ctx : eval_ctx) : fmt_env = - let type_decls = ctx.type_context.type_decls in - let fun_decls = ctx.fun_context.fun_decls in - let global_decls = ctx.global_context.global_decls in - let trait_decls = ctx.trait_decls_context.trait_decls in - let trait_impls = ctx.trait_impls_context.trait_impls in + let type_decls = ctx.type_ctx.type_decls in + let fun_decls = ctx.fun_ctx.fun_decls in + let global_decls = ctx.global_ctx.global_decls in + let trait_decls = ctx.trait_decls_ctx.trait_decls in + let trait_impls = ctx.trait_impls_ctx.trait_impls in (* Below: it is always safe to omit fields - if an id can't be found at printing time, we print the id (in raw form) instead of the name it designates. *) diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 2fe5843e..66475d02 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -103,10 +103,13 @@ let adt_field_names (env : fmt_env) = Print.Types.adt_field_names (fmt_env_to_llbc_fmt_env env) let option_to_string = Print.option_to_string -let type_var_to_string = Print.Types.type_var_to_string -let const_generic_var_to_string = Print.Types.const_generic_var_to_string -let integer_type_to_string = Print.Values.integer_type_to_string let literal_type_to_string = Print.Values.literal_type_to_string +let type_var_to_string (v : type_var) = "(" ^ v.name ^ ": Type)" + +let const_generic_var_to_string (v : const_generic_var) = + "(" ^ v.name ^ " : " ^ literal_type_to_string v.ty ^ ")" + +let integer_type_to_string = Print.Values.integer_type_to_string let scalar_value_to_string = Print.Values.scalar_value_to_string let literal_to_string = Print.Values.literal_to_string @@ -203,13 +206,12 @@ and trait_instance_id_to_string (env : fmt_env) (inside : bool) | UnknownTrait msg -> "UNKNOWN(" ^ msg ^ ")" let trait_clause_to_string (env : fmt_env) (clause : trait_clause) : string = - let clause_id = trait_clause_id_to_string env clause.clause_id in let trait_id = trait_decl_id_to_string env clause.trait_id in let generics = generic_args_to_strings env true clause.generics in let generics = if generics = [] then "" else " " ^ String.concat " " generics in - "[" ^ clause_id ^ "]: " ^ trait_id ^ generics + trait_id ^ generics let generic_params_to_strings (env : fmt_env) (generics : generic_params) : string list = @@ -543,9 +545,9 @@ let rec texpression_to_string (env : fmt_env) (inside : bool) (indent : string) let app, args = destruct_apps e in (* Convert to string *) app_to_string env inside indent indent_incr app args - | Abs _ -> - let xl, e = destruct_abs_list e in - let e = abs_to_string env indent indent_incr xl e in + | Lambda _ -> + let xl, e = destruct_lambdas e in + let e = lambda_to_string env indent indent_incr xl e in if inside then "(" ^ e ^ ")" else e | Qualif _ -> (* Qualifier without arguments *) @@ -609,35 +611,36 @@ and app_to_string (env : fmt_env) (inside : bool) (indent : string) * expression *) let app, generics = match app.e with - | Qualif qualif -> + | Qualif qualif -> ( (* Qualifier case *) - (* Convert the qualifier identifier *) - let qualif_s = - match qualif.id with - | FunOrOp fun_id -> fun_or_op_id_to_string env fun_id - | Global global_id -> global_decl_id_to_string env global_id - | AdtCons adt_cons_id -> - let variant_s = - adt_variant_to_string env adt_cons_id.adt_id - adt_cons_id.variant_id - in - ConstStrings.constructor_prefix ^ variant_s - | Proj { adt_id; field_id } -> - let adt_s = adt_variant_to_string env adt_id None in - let field_s = adt_field_to_string env adt_id field_id in - (* Adopting an F*-like syntax *) - ConstStrings.constructor_prefix ^ adt_s ^ "?." ^ field_s - | TraitConst (trait_ref, generics, const_name) -> - let trait_ref = trait_ref_to_string env true trait_ref in - let generics_s = generic_args_to_string env generics in + match qualif.id with + | FunOrOp fun_id -> + let generics = generic_args_to_strings env true qualif.generics in + let qualif_s = fun_or_op_id_to_string env fun_id in + (qualif_s, generics) + | Global global_id -> + let generics = generic_args_to_strings env true qualif.generics in + (global_decl_id_to_string env global_id, generics) + | AdtCons adt_cons_id -> + let variant_s = + adt_variant_to_string env adt_cons_id.adt_id + adt_cons_id.variant_id + in + (ConstStrings.constructor_prefix ^ variant_s, []) + | Proj { adt_id; field_id } -> + let adt_s = adt_variant_to_string env adt_id None in + let field_s = adt_field_to_string env adt_id field_id in + (* Adopting an F*-like syntax *) + (ConstStrings.constructor_prefix ^ adt_s ^ "?." ^ field_s, []) + | TraitConst (trait_ref, generics, const_name) -> + let trait_ref = trait_ref_to_string env true trait_ref in + let generics_s = generic_args_to_string env generics in + let qualif = if generics <> empty_generic_args then "(" ^ trait_ref ^ generics_s ^ ")." ^ const_name else trait_ref ^ "." ^ const_name - in - (* Convert the type instantiation *) - let generics = generic_args_to_strings env true qualif.generics in - (* *) - (qualif_s, generics) + in + (qualif, [])) | _ -> (* "Regular" expression case *) let inside = args <> [] || (args = [] && inside) in @@ -660,7 +663,7 @@ and app_to_string (env : fmt_env) (inside : bool) (indent : string) (* Add parentheses *) if all_args <> [] && inside then "(" ^ e ^ ")" else e -and abs_to_string (env : fmt_env) (indent : string) (indent_incr : string) +and lambda_to_string (env : fmt_env) (indent : string) (indent_incr : string) (xl : typed_pattern list) (e : texpression) : string = let xl = List.map (typed_pattern_to_string env) xl in let e = texpression_to_string env false indent indent_incr e in @@ -708,21 +711,14 @@ and loop_to_string (env : fmt_env) (indent : string) (indent_incr : string) ^ String.concat "; " (List.map (var_to_string env) loop.inputs) ^ "]" in - let back_output_tys = - let tys = - match loop.back_output_tys with - | None -> "" - | Some tys -> String.concat "; " (List.map (ty_to_string env false) tys) - in - "back_output_tys: [" ^ tys ^ "]" - in + let output_ty = "output_ty: " ^ ty_to_string env false loop.output_ty in let fun_end = texpression_to_string env false indent2 indent_incr loop.fun_end in let loop_body = texpression_to_string env false indent2 indent_incr loop.loop_body in - "loop {\n" ^ indent1 ^ loop_inputs ^ "\n" ^ indent1 ^ back_output_tys ^ "\n" + "loop {\n" ^ indent1 ^ loop_inputs ^ "\n" ^ indent1 ^ output_ty ^ "\n" ^ indent1 ^ "fun_end: {\n" ^ indent2 ^ fun_end ^ "\n" ^ indent1 ^ "}\n" ^ indent1 ^ "loop_body: {\n" ^ indent2 ^ loop_body ^ "\n" ^ indent1 ^ "}\n" ^ indent ^ "}" diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 8d39cc69..a879ba37 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -561,7 +561,7 @@ type fun_id_or_trait_method_ref = (** A function id for a non-assumed function *) type regular_fun_id = - fun_id_or_trait_method_ref * LoopId.id option * T.RegionGroupId.id option + fun_id_or_trait_method_ref * LoopId.id option * RegionGroupId.id option [@@deriving show, ord] (** A function identifier *) @@ -684,7 +684,7 @@ type expression = field accesses with calls to projectors over fields (when there are clashes of field names, some provers like F* get pretty bad...) *) - | Abs of typed_pattern * texpression (** Lambda abstraction: [fun x -> e] *) + | Lambda of typed_pattern * texpression (** Lambda abstraction: [λ x => e] *) | Qualif of qualif (** A top-level qualifier *) | Let of bool * typed_pattern * texpression * texpression (** Let binding. @@ -754,9 +754,7 @@ and loop = { inputs : var list; inputs_lvs : typed_pattern list; (** The inputs seen as patterns. See {!fun_body}. *) - back_output_tys : ty list option; - (** The types of the given back values, if we ar esynthesizing a backward - function *) + output_ty : ty; (** The output type of the loop *) loop_body : texpression; } @@ -860,8 +858,8 @@ type fun_effect_info = { the set [{ forward function } U { backward functions }]. We need this because of the option {!val:Config.backward_no_state_update}: - if it is [true], then in case of a backward function {!stateful} is [false], - but we might need to know whether the corresponding forward function + if it is [true], then in case of a backward function {!stateful} might be + [false], but we might need to know whether the corresponding forward function is stateful or not. *) stateful : bool; (** [true] if the function is stateful (updates a state) *) @@ -873,22 +871,108 @@ type fun_effect_info = { } [@@deriving show] -(** Meta information about a function signature *) -type fun_sig_info = { +type inputs_info = { has_fuel : bool; - (* TODO: add [num_fwd_inputs_no_fuel_no_state] *) - num_fwd_inputs_with_fuel_no_state : int; - (** The number of input types for forward computation, with the fuel (if used) + num_inputs_no_fuel_no_state : int; + (** The number of input types ignoring the fuel (if used) + and ignoring the state (if used) *) + num_inputs_with_fuel_no_state : int; + (** The number of input types, with the fuel (if used) and ignoring the state (if used) *) - num_fwd_inputs_with_fuel_with_state : int; - (** The number of input types for forward computation, with fuel and state (if used) *) - num_back_inputs_no_state : int option; - (** The number of additional inputs for the backward computation (if pertinent), - ignoring the state (if there is one) *) - num_back_inputs_with_state : int option; - (** The number of additional inputs for the backward computation (if pertinent), - with the state (if there is one) *) + num_inputs_with_fuel_with_state : int; + (** The number of input types, with fuel and state (if used) *) +} +[@@deriving show] + +(** Meta information about a function signature *) +type fun_sig_info = { + fwd_info : inputs_info; + (** Information about the inputs of the forward function *) effect_info : fun_effect_info; + ignore_output : bool; + (** In case we merge the forward/backward functions: should we ignore + the output (happens for forward functions if the output type is + [unit] and there are non-filtered backward functions)? + *) +} +[@@deriving show] + +type back_sg_info = { + inputs : (string option * ty) list; + (** The additional inputs of the backward function *) + inputs_no_state : (string option * ty) list; + outputs : ty list; + (** The "decomposed" list of outputs. + + The list contains all the types of + all the given back values (there is at most one type per forward + input argument). + + Ex.: + {[ + fn choose<'a, T>(b : bool, x : &'a mut T, y : &'a mut T) -> &'a mut T; + ]} + Decomposed outputs: + - forward function: [[T]] + - backward function: [[T; T]] (for "x" and "y") + + Non-decomposed ouputs (if the function can fail, but is not stateful): + - [result T] + - [[result (T * T)]] + *) + output_names : string option list; + (** The optional names for the backward outputs. + We derive those from the names of the inputs of the original LLBC + function. *) + effect_info : fun_effect_info; + filter : bool; (** Should we filter this backward function? *) +} +[@@deriving show] + +(** A *decomposed* function signature. *) +type decomposed_fun_sig = { + generics : generic_params; + (** TODO: we should analyse the signature to make the type parameters implicit whenever possible *) + llbc_generics : Types.generic_params; + (** We use the LLBC generics to generate "pretty" names, for instance + for the variables we introduce for the trait clauses: we derive + those names from the types, and when doing so it is more meaningful + to derive them from the original LLBC types from before the + simplification of types like boxes and references. *) + preds : predicates; + fwd_inputs : ty list; + (** The types of the inputs of the forward function. + + Note that those input types take include the [fuel] parameter, + if the function uses fuel for termination, and the [state] parameter, + if the function is stateful. + + For instance, if we have the following Rust function: + {[ + fn f(x : int); + ]} + + If we translate it to a stateful function which uses fuel we get: + {[ + val f : nat -> int -> state -> result (state * unit); + ]} + + In particular, the list of input types is: [[nat; int; state]]. + *) + fwd_output : ty; + (** The "pure" output type of the forward function. + + Note that this type doesn't contain the "effect" of the function (i.e., + we haven't added the [state] if it is a stateful function and haven't + wrapped the type in a [result]). Also, this output type is only about + the forward function (it doesn't contain the type of the closures we + return for the backward functions, in case we merge the forward and + backward functions). + *) + back_sg : back_sg_info RegionGroupId.Map.t; + (** Information about the backward functions *) + fwd_info : fun_sig_info; + (** Additional information about the forward function *) } [@@deriving show] @@ -906,15 +990,15 @@ type fun_sig_info = { [in_ty0 -> ... -> in_tyn -> state -> back_in0 -> ... back_inm -> state -> result (state & (back_out0 & ... & back_outp))] (* state-error *) - Note that a stateful backward function may take two states as inputs: the - state received by the associated forward function, and the state at which - the backward is called. This leads to code of the following shape: + Note that a stateful backward function may take two states as inputs: the + state received by the associated forward function, and the state at which + the backward is called. This leads to code of the following shape: - {[ - (st1, y) <-- f_fwd x st0; // st0 is the state upon calling f_fwd - ... // the state may be updated - (st3, x') <-- f_back x st0 y' st2; // st2 is the state upon calling f_back - ]} + {[ + (st1, y) <-- f_fwd x st0; // st0 is the state upon calling f_fwd + ... // the state may be updated + (st3, x') <-- f_back x st0 y' st2; // st2 is the state upon calling f_back + ]} The function's type should be given by [mk_arrows sig.inputs sig.output]. We provide additional meta-information with {!fun_sig.info}: @@ -962,40 +1046,14 @@ type fun_sig = { be a tuple with a [state] if the function is stateful, and will be wrapped in a [result] if the function can fail. *) - doutputs : ty list; - (** The "decomposed" list of outputs. - - In case of a forward function, the list has length = 1, for the - type of the returned value. - - In case of backward function, the list contains all the types of - all the given back values (there is at most one type per forward - input argument). - - Ex.: - {[ - fn choose<'a, T>(b : bool, x : &'a mut T, y : &'a mut T) -> &'a mut T; - ]} - Decomposed outputs: - - forward function: [[T]] - - backward function: [[T; T]] (for "x" and "y") - - Non-decomposed ouputs (if the function can fail, but is not stateful): - - [result T] - - [[result (T * T)]] - *) - info : fun_sig_info; (** Additional information *) + fwd_info : fun_sig_info; + (** Additional information about the forward function. *) + back_effect_info : fun_effect_info RegionGroupId.Map.t; } [@@deriving show] (** An instantiated function signature. See {!fun_sig} *) -type inst_fun_sig = { - inputs : ty list; - output : ty; - doutputs : ty list; - info : fun_sig_info; -} -[@@deriving show] +type inst_fun_sig = { inputs : ty list; output : ty } [@@deriving show] type fun_body = { inputs : var list; @@ -1020,7 +1078,7 @@ type fun_decl = { *) loop_id : LoopId.id option; (** [Some] if this definition was generated for a loop *) - back_id : T.RegionGroupId.id option; + back_id : RegionGroupId.id option; llbc_name : llbc_name; (** The original LLBC name. *) name : string; (** We use the name only for printing purposes (for debugging): diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 959ec1c8..ec64df21 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -11,6 +11,10 @@ let fun_decl_to_string (ctx : trans_ctx) (def : Pure.fun_decl) : string = let fmt = trans_ctx_to_pure_fmt_env ctx in PrintPure.fun_decl_to_string fmt def +let fun_sig_to_string (ctx : trans_ctx) (sg : Pure.fun_sig) : string = + let fmt = trans_ctx_to_pure_fmt_env ctx in + PrintPure.fun_sig_to_string fmt sg + (** Small utility. We sometimes have to insert new fresh variables in a function body, in which @@ -385,17 +389,17 @@ let compute_pretty_names (def : fun_decl) : fun_decl = let ctx, arg = update_texpression arg ctx in let e = App (app, arg) in (ctx, e) - | Abs (x, e) -> update_abs x e ctx | Qualif _ -> (* nothing to do *) (ctx, e.e) | Let (monadic, lb, re, e) -> update_let monadic lb re e ctx | Switch (scrut, body) -> update_switch_body scrut body ctx | Loop loop -> update_loop loop ctx | StructUpdate supd -> update_struct_update supd ctx + | Lambda (lb, e) -> update_lambda lb e ctx | Meta (meta, e) -> update_emeta meta e ctx in (ctx, { e; ty }) (* *) - and update_abs (x : typed_pattern) (e : texpression) (ctx : pn_ctx) : + and update_lambda (x : typed_pattern) (e : texpression) (ctx : pn_ctx) : pn_ctx * expression = (* We first add the left-constraint *) let ctx = add_left_constraint x ctx in @@ -404,7 +408,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl = (* Update the abstracted value *) let x = update_typed_pattern ctx x in (* Put together *) - (ctx, Abs (x, e)) + (ctx, Lambda (x, e)) (* *) and update_let (monadic : bool) (lv : typed_pattern) (re : texpression) (e : texpression) (ctx : pn_ctx) : pn_ctx * expression = @@ -455,7 +459,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl = input_state; inputs; inputs_lvs; - back_output_tys; + output_ty; loop_body; } = loop @@ -474,7 +478,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl = input_state; inputs; inputs_lvs; - back_output_tys; + output_ty; loop_body; } in @@ -641,6 +645,135 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = let body = { body with body = obj#visit_texpression () body.body } in { def with body = Some body } +(** Simplify the let-bindings by performing the following rewritings: + + Move inner let-bindings outside. This is especially useful to simplify + the backward expressions, when we merge the forward/backward functions. + Note that the rule is also applied with monadic let-bindings. + {[ + let x := + let y := ... in + e + + ~~> + + let y := ... in + let x := e + ]} + + Simplify panics and returns: + {[ + let x ← fail + ... + ~~> + fail + + let x ← return y + ... + ~~> + let x := y + ... + ]} + + Simplify tuples: + {[ + let (y0, y1) := (x0, x1) in + ... + ~~> + let y0 = x0 in + let y1 = x1 in + ... + ]} + + Simplify arrows: + {[ + let f := fun x => g x in + ... + ~~> + let f := g in + ... + ]} + *) +let simplify_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = + let obj = + object (self) + inherit [_] map_expression as super + + method! visit_Let env monadic lv rv next = + match rv.e with + | Let (rmonadic, rlv, rrv, rnext) -> + (* Case 1: move the inner let outside then re-visit *) + let rnext1 = Let (monadic, lv, rnext, next) in + let rnext1 = { ty = next.ty; e = rnext1 } in + self#visit_Let env rmonadic rlv rrv rnext1 + | App + ( { + e = + Qualif + { + id = + AdtCons + { + adt_id = TAssumed TResult; + variant_id = Some variant_id; + }; + generics = _; + }; + ty = _; + }, + x ) -> + (* return/fail case *) + if variant_id = result_return_id then + (* Return case - note that the simplification we just perform + might have unlocked the tuple simplification below *) + self#visit_Let env false lv x next + else if variant_id = result_fail_id then + (* Fail case *) + self#visit_expression env rv.e + else raise (Failure "Unexpected") + | App _ -> + (* This might be the tuple case *) + if not monadic then + match + (opt_dest_struct_pattern lv, opt_dest_tuple_texpression rv) + with + | Some pats, Some vals -> + (* Tuple case *) + let pat_vals = List.combine pats vals in + let e = + List.fold_right + (fun (pat, v) next -> mk_let false pat v next) + pat_vals next + in + super#visit_expression env e.e + | _ -> super#visit_Let env monadic lv rv next + else super#visit_Let env monadic lv rv next + | Lambda _ -> + if not monadic then + (* Arrow case *) + let pats, e = destruct_lambdas rv in + let g, args = destruct_apps e in + if List.length pats = List.length args then + (* Check if the arguments are exactly the lambdas *) + let check_pat_arg ((pat, arg) : typed_pattern * texpression) = + match (pat.value, arg.e) with + | PatVar (v, _), Var vid -> v.id = vid + | _ -> false + in + if List.for_all check_pat_arg (List.combine pats args) then + self#visit_Let env monadic lv g next + else super#visit_Let env monadic lv rv next + else super#visit_Let env monadic lv rv next + else super#visit_Let env monadic lv rv next + | _ -> super#visit_Let env monadic lv rv next + end + in + match def.body with + | None -> def + | Some body -> + let body = { body with body = obj#visit_texpression () body.body } in + { def with body = Some body } + (** Inline the useless variable (re-)assignments: A lot of intermediate variable assignments are introduced through the @@ -667,8 +800,8 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = leave the let-bindings where they are, and eliminated them in a subsequent pass (if they are useless). *) -let inline_useless_var_reassignments (ctx : trans_ctx) (inline_named : bool) - (inline_pure : bool) (def : fun_decl) : fun_decl = +let inline_useless_var_reassignments (ctx : trans_ctx) ~(inline_named : bool) + ~(inline_const : bool) ~(inline_pure : bool) (def : fun_decl) : fun_decl = let obj = object (self) inherit [_] map_expression as super @@ -693,15 +826,31 @@ let inline_useless_var_reassignments (ctx : trans_ctx) (inline_named : bool) | _ -> false in (* And either: - * 2.1 the right-expression is a variable, a global or a const generic var *) + 2.1 the right-expression is a variable, a global or a const generic var *) let var_or_global = is_var re || is_cvar re || is_global re in (* Or: - * 2.2 the right-expression is a constant value, an ADT value, - * a projection or a primitive function call *and* the flag - * [inline_pure] is set *) + 2.2 the right-expression is a constant-value and we inline constant values, + *or* it is a qualif with no arguments (we consider this as a const) *) + let const_re = + inline_const + && + let is_const_adt = + let app, args = destruct_apps re in + if args = [] then + match app.e with + | Qualif _ -> true + | StructUpdate upd -> upd.updates = [] + | _ -> false + else false + in + is_const re || is_const_adt + in + (* Or: + 2.3 the right-expression is an ADT value, a projection or a + primitive function call *and* the flag [inline_pure] is set *) let pure_re = - is_const re - || + inline_pure + && let app, _ = destruct_apps re in match app.e with | Qualif qualif -> ( @@ -716,7 +865,7 @@ let inline_useless_var_reassignments (ctx : trans_ctx) (inline_named : bool) | _ -> false in let filter = - filter_left && (var_or_global || (inline_pure && pure_re)) + filter_left && (var_or_global || const_re || pure_re) in (* Update the rhs (we may perform substitutions inside, and it is @@ -776,9 +925,11 @@ let inline_useless_var_reassignments (ctx : trans_ctx) (inline_named : bool) in { def with body = Some body } -(** Given a forward or backward function call, is there, for every execution +(** For the cases where we split the forward/backward functions. + + Given a forward or backward function call, is there, for every execution path, a child backward function called later with exactly the same input - list prefix? We use this to filter useless function calls: if there are + list prefix. We use this to filter useless function calls: if there are such child calls, we can remove this one (in case its outputs are not used). We do this check because we can't simply remove function calls whose @@ -890,12 +1041,12 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) let call_is_child = check_call func1 generics1 args1 in if call_is_child then fun () -> true else fun () -> self#visit_texpression env e ()) + | Lambda (_, e) -> self#visit_texpression env e | App _ -> ( fun () -> match opt_destruct_function_call e with | Some (func1, tys1, args1) -> check_call func1 tys1 args1 | None -> false) - | Abs (_, e) -> self#visit_texpression env e | Qualif _ -> (* Note that this case includes functions without arguments *) fun () -> false @@ -973,10 +1124,23 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) method! visit_expression env e = match e with | Var _ | CVar _ | Const _ | App _ | Qualif _ - | Switch (_, _) | Meta (_, _) - | StructUpdate _ | Abs _ -> + | StructUpdate _ | Lambda _ -> super#visit_expression env e + | Switch (scrut, switch) -> ( + match switch with + | If (_, _) -> super#visit_expression env e + | Match branches -> + (* Simplify the branches *) + let simplify_branch (br : match_branch) = + (* Compute the set of values used inside the branch *) + let branch, used = self#visit_texpression env br.branch in + (* Simplify the pattern *) + let pat, _ = filter_typed_pattern (used ()) br.pat in + { pat; branch } + in + super#visit_expression env + (Switch (scrut, Match (List.map simplify_branch branches)))) | Let (monadic, lv, re, e) -> (* Compute the set of values used in the next expression *) let e, used = self#visit_texpression env e in @@ -1008,17 +1172,21 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) * under some conditions. *) match (filter_monadic_calls, opt_destruct_function_call re) with | true, Some (Fun (FromLlbc (fid, lp_id, rg_id)), tys, args) -> - (* We need to check if there is a child call - see - * the comments for: - * [expression_contains_child_call_in_all_paths] *) - let has_child_call = - expression_contains_child_call_in_all_paths ctx fid lp_id - rg_id tys args e - in - if has_child_call then (* Filter *) - (e.e, fun _ -> used) - else (* No child call: don't filter *) - dont_filter () + (* If we split the forward/backward functions. + + We need to check if there is a child call - see + the comments for: + [expression_contains_child_call_in_all_paths] *) + if not !Config.return_back_funs then + let has_child_call = + expression_contains_child_call_in_all_paths ctx fid + lp_id rg_id tys args e + in + if has_child_call then (* Filter *) + (e.e, fun _ -> used) + else (* No child call: don't filter *) + dont_filter () + else dont_filter () | _ -> (* Not an LLBC function call or not allowed to filter: we can't filter *) dont_filter () @@ -1297,7 +1465,8 @@ let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option = those function bodies into independent definitions while removing occurrences of the {!Pure.Loop} node. *) -let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = +let decompose_loops (_ctx : trans_ctx) (def : fun_decl) : + fun_decl * fun_decl list = match def.body with | None -> (def, []) | Some body -> @@ -1323,77 +1492,55 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = method! visit_Loop env loop = let fun_sig = def.signature in - let fun_sig_info = fun_sig.info in - let fun_effect_info = fun_sig_info.effect_info in + let fwd_info = fun_sig.fwd_info in + let fwd_effect_info = fwd_info.effect_info in + let ignore_output = fwd_info.ignore_output in (* Generate the loop definition *) - let loop_effect_info = - { - stateful_group = fun_effect_info.stateful_group; - stateful = fun_effect_info.stateful; - can_fail = fun_effect_info.can_fail; - can_diverge = fun_effect_info.can_diverge; - is_rec = fun_effect_info.is_rec; - } - in + let loop_fwd_effect_info = fwd_effect_info in - let loop_sig_info = + let loop_fwd_sig_info : fun_sig_info = let fuel = if !Config.use_fuel then 1 else 0 in let num_inputs = List.length loop.inputs in - let num_fwd_inputs_with_fuel_no_state = fuel + num_inputs in - let fwd_state = - fun_sig_info.num_fwd_inputs_with_fuel_with_state - - fun_sig_info.num_fwd_inputs_with_fuel_no_state - in - let num_fwd_inputs_with_fuel_with_state = - num_fwd_inputs_with_fuel_no_state + fwd_state + let fwd_info : inputs_info = + let info = fwd_info.fwd_info in + let fwd_state = + info.num_inputs_with_fuel_with_state + - info.num_inputs_with_fuel_no_state + in + { + has_fuel = !Config.use_fuel; + num_inputs_no_fuel_no_state = num_inputs; + num_inputs_with_fuel_no_state = num_inputs + fuel; + num_inputs_with_fuel_with_state = + num_inputs + fuel + fwd_state; + } in - { - has_fuel = !Config.use_fuel; - num_fwd_inputs_with_fuel_no_state; - num_fwd_inputs_with_fuel_with_state; - num_back_inputs_no_state = fun_sig_info.num_back_inputs_no_state; - num_back_inputs_with_state = - fun_sig_info.num_back_inputs_with_state; - effect_info = loop_effect_info; - } + + { fwd_info; effect_info = loop_fwd_effect_info; ignore_output } in + assert (fun_sig_info_is_wf loop_fwd_sig_info); let inputs_tys = let fuel = if !Config.use_fuel then [ mk_fuel_ty ] else [] in let fwd_inputs = List.map (fun (v : var) -> v.ty) loop.inputs in - let state = + let info = fwd_info.fwd_info in + let fwd_state = Collections.List.subslice fun_sig.inputs - fun_sig_info.num_fwd_inputs_with_fuel_no_state - fun_sig_info.num_fwd_inputs_with_fuel_with_state + info.num_inputs_with_fuel_no_state + info.num_inputs_with_fuel_with_state in - let _, back_inputs = - Collections.List.split_at fun_sig.inputs - fun_sig_info.num_fwd_inputs_with_fuel_with_state + let back_inputs = + if !Config.return_back_funs then [] + else + snd + (Collections.List.split_at fun_sig.inputs + info.num_inputs_with_fuel_with_state) in - List.concat [ fuel; fwd_inputs; state; back_inputs ] + List.concat [ fuel; fwd_inputs; fwd_state; back_inputs ] in - let output, doutputs = - match loop.back_output_tys with - | None -> - (* Forward function: the return type is the same as the - parent function *) - (fun_sig.output, fun_sig.doutputs) - | Some doutputs -> - (* Backward function: custom return type *) - let output = mk_simpl_tuple_ty doutputs in - let output = - if loop_effect_info.stateful then - mk_simpl_tuple_ty [ mk_state_ty; output ] - else output - in - let output = - if loop_effect_info.can_fail then mk_result_ty output - else output - in - (output, doutputs) - in + let output = loop.output_ty in let loop_sig = { @@ -1402,8 +1549,8 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = preds = fun_sig.preds; inputs = inputs_tys; output; - doutputs; - info = loop_sig_info; + fwd_info = loop_fwd_sig_info; + back_effect_info = fun_sig.back_effect_info; } in @@ -1420,7 +1567,8 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = (* Introduce the forward input state *) let fwd_state_var, fwd_state_lvs = assert ( - loop_effect_info.stateful = Option.is_some loop.input_state); + loop_fwd_effect_info.stateful + = Option.is_some loop.input_state); match loop.input_state with | None -> ([], []) | Some input_state -> @@ -1429,15 +1577,16 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = ([ state_var ], [ state_lvs ]) in - (* Introduce the additional backward inputs *) + (* Introduce the additional backward inputs, if necessary *) let fun_body = Option.get def.body in + let info = fwd_info.fwd_info in let _, back_inputs = Collections.List.split_at fun_body.inputs - fun_sig_info.num_fwd_inputs_with_fuel_with_state + info.num_inputs_with_fuel_with_state in let _, back_inputs_lvs = Collections.List.split_at fun_body.inputs_lvs - fun_sig_info.num_fwd_inputs_with_fuel_with_state + info.num_inputs_with_fuel_with_state in let inputs = @@ -1508,9 +1657,12 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = altogether. *) let keep_forward (fwd : fun_and_loops) (backs : fun_and_loops list) : bool = - (* Note that at this point, the output types are no longer seen as tuples: - * they should be lists of length 1. *) - if + (* The question of filtering the forward functions arises only if we split + the forward/backward functions *) + if !Config.return_back_funs then true + else if + (* Note that at this point, the output types are no longer seen as tuples: + * they should be lists of length 1. *) !Config.filter_useless_functions && fwd.f.signature.output = mk_result_ty mk_unit_ty && backs <> [] @@ -1816,11 +1968,15 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = log#ldebug (lazy ("intro_struct_updates:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + (* Simplify the let-bindings *) + let def = simplify_let_bindings ctx def in + log#ldebug + (lazy ("simplify_let_bindings:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + (* Inline the useless variable reassignments *) - let inline_named_vars = true in - let inline_pure = true in let def = - inline_useless_var_reassignments ctx inline_named_vars inline_pure def + inline_useless_var_reassignments ctx ~inline_named:true ~inline_const:true + ~inline_pure:true def in log#ldebug (lazy @@ -1867,6 +2023,23 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = log#ldebug (lazy ("simplify_aggregates:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + (* Simplify the let-bindings - some simplifications may have been unlocked by + the pass above (for instance, the lambda simplification) *) + let def = simplify_let_bindings ctx def in + log#ldebug + (lazy + ("simplify_let_bindings (pass 2):\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + + (* Inline the useless vars again *) + let def = + inline_useless_var_reassignments ctx ~inline_named:true ~inline_const:true + ~inline_pure:false def + in + log#ldebug + (lazy + ("inline_useless_var_assignments (pass 2):\n\n" + ^ fun_decl_to_string ctx def ^ "\n")); + (* Decompose the monadic let-bindings - used by Coq *) let def = if !Config.decompose_monadic_let_bindings then ( @@ -1917,67 +2090,6 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = (* We are done *) def -(** Apply all the micro-passes to a function. - - As loops are initially directly integrated into the function definition, - {!apply_passes_to_def} extracts those loops definitions from the body; - it thus returns the pair: (function def, loop defs). See {!decompose_loops} - for more information. - - Will return [None] if the function is a backward function with no outputs. - - [ctx]: used only for printing. - *) -let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : - fun_and_loops option = - (* Debug *) - log#ldebug - (lazy - ("PureMicroPasses.apply_passes_to_def: " ^ def.name ^ " (" - ^ Print.option_to_string T.RegionGroupId.to_string def.back_id - ^ ")")); - - log#ldebug (lazy ("original decl:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); - - (* First, find names for the variables which are unnamed *) - let def = compute_pretty_names def in - log#ldebug - (lazy ("compute_pretty_name:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); - - (* TODO: we might want to leverage more the assignment meta-data, for - * aggregates for instance. *) - - (* TODO: reorder the branches of the matches/switches *) - - (* The meta-information is now useless: remove it. - * Rk.: some passes below use the fact that we removed the meta-data - * (otherwise we would have to "unmeta" expressions before matching) *) - let def = remove_meta def in - log#ldebug (lazy ("remove_meta:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); - - (* Remove the backward functions with no outputs. - * Note that the calls to those functions should already have been removed, - * when translating from symbolic to pure. Here, we remove the definitions - * altogether, because they are now useless *) - let name = def.name ^ PrintPure.fun_suffix def.loop_id def.back_id in - let opt_def = filter_if_backward_with_no_outputs def in - - match opt_def with - | None -> - log#ldebug (lazy ("filtered (backward with no outputs): " ^ name ^ "\n")); - None - | Some def -> - log#ldebug - (lazy ("not filtered (not backward with no outputs): " ^ name ^ "\n")); - - (* Extract the loop definitions by removing the {!Loop} node *) - let def, loops = decompose_loops def in - - (* Apply the remaining passes *) - let f = apply_end_passes_to_def ctx def in - let loops = List.map (apply_end_passes_to_def ctx) loops in - Some { f; loops } - (** Small utility for {!filter_loop_inputs} *) let filter_prefix (keep : bool list) (ls : 'a list) : 'a list = let ls0, ls1 = Collections.List.split_at ls (List.length keep) in @@ -2058,7 +2170,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : (* We only look at the forward inputs, without the state *) let inputs_prefix, _ = Collections.List.split_at body.inputs - decl.signature.info.num_fwd_inputs_with_fuel_no_state + decl.signature.fwd_info.fwd_info.num_inputs_with_fuel_no_state in let used = ref (List.map (fun v -> (var_get_id v, false)) inputs_prefix) in let inputs_prefix_length = List.length inputs_prefix in @@ -2077,8 +2189,9 @@ let filter_loop_inputs (transl : pure_fun_translation list) : in (* Set the fuel as used *) - let sg_info = decl.signature.info in - if sg_info.has_fuel then set_used (fst (Collections.List.nth inputs 0)); + let sg_info = decl.signature.fwd_info in + if sg_info.fwd_info.has_fuel then + set_used (fst (Collections.List.nth inputs 0)); let visitor = object (self : 'self) @@ -2162,37 +2275,54 @@ let filter_loop_inputs (transl : pure_fun_translation list) : let num_filtered = List.length (List.filter (fun b -> not b) used_info) in - let { generics; llbc_generics; preds; inputs; output; doutputs; info } - = + let { + generics; + llbc_generics; + preds; + inputs; + output; + fwd_info; + back_effect_info; + } = decl.signature in + let { fwd_info; effect_info; ignore_output } = fwd_info in + let { has_fuel; - num_fwd_inputs_with_fuel_no_state; - num_fwd_inputs_with_fuel_with_state; - num_back_inputs_no_state; - num_back_inputs_with_state; - effect_info; + num_inputs_no_fuel_no_state; + num_inputs_with_fuel_no_state; + num_inputs_with_fuel_with_state; } = - info + fwd_info in let inputs = filter_prefix used_info inputs in - let info = + let fwd_info = { has_fuel; - num_fwd_inputs_with_fuel_no_state = - num_fwd_inputs_with_fuel_no_state - num_filtered; - num_fwd_inputs_with_fuel_with_state = - num_fwd_inputs_with_fuel_with_state - num_filtered; - num_back_inputs_no_state; - num_back_inputs_with_state; - effect_info; + num_inputs_no_fuel_no_state = + num_inputs_no_fuel_no_state - num_filtered; + num_inputs_with_fuel_no_state = + num_inputs_with_fuel_no_state - num_filtered; + num_inputs_with_fuel_with_state = + num_inputs_with_fuel_with_state - num_filtered; } in + + let fwd_info = { fwd_info; effect_info; ignore_output } in + assert (fun_sig_info_is_wf fwd_info); let signature = - { generics; llbc_generics; preds; inputs; output; doutputs; info } + { + generics; + llbc_generics; + preds; + inputs; + output; + fwd_info; + back_effect_info; + } in { decl with signature } @@ -2283,6 +2413,68 @@ let filter_loop_inputs (transl : pure_fun_translation list) : (* Return *) transl +(** Apply all the micro-passes to a function. + + As loops are initially directly integrated into the function definition, + {!apply_passes_to_def} extracts those loops definitions from the body; + it thus returns the pair: (function def, loop defs). See {!decompose_loops} + for more information. + + Will return [None] if the function is a backward function with no outputs. + + [ctx]: used only for printing. + *) +let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : + fun_and_loops option = + (* Debug *) + log#ldebug + (lazy + ("PureMicroPasses.apply_passes_to_def: " ^ def.name ^ " (" + ^ Print.option_to_string T.RegionGroupId.to_string def.back_id + ^ ")")); + + log#ldebug (lazy ("original decl:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + + (* First, find names for the variables which are unnamed *) + let def = compute_pretty_names def in + log#ldebug + (lazy ("compute_pretty_name:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + + (* TODO: we might want to leverage more the assignment meta-data, for + * aggregates for instance. *) + + (* TODO: reorder the branches of the matches/switches *) + + (* The meta-information is now useless: remove it. + * Rk.: some passes below use the fact that we removed the meta-data + * (otherwise we would have to "unmeta" expressions before matching) *) + let def = remove_meta def in + log#ldebug (lazy ("remove_meta:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + + (* Remove the backward functions with no outputs. + + Note that the *calls* to those functions should already have been removed, + when translating from symbolic to pure. Here, we remove the definitions + altogether, because they are now useless *) + let name = def.name ^ PrintPure.fun_suffix def.loop_id def.back_id in + let opt_def = filter_if_backward_with_no_outputs def in + + match opt_def with + | None -> + log#ldebug (lazy ("filtered (backward with no outputs): " ^ name ^ "\n")); + None + | Some def -> + log#ldebug + (lazy ("not filtered (not backward with no outputs): " ^ name ^ "\n")); + + (* Extract the loop definitions by removing the {!Loop} node *) + let def, loops = decompose_loops ctx def in + + (* Apply the remaining passes *) + let f = apply_end_passes_to_def ctx def in + let loops = List.map (apply_end_passes_to_def ctx) loops in + Some { f; loops } + (** Apply the micro-passes to a list of forward/backward translations. This function also extracts the loop definitions from the function body diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index a62a2361..a989fd3b 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -120,7 +120,7 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = assert (output_ty = e.ty); check_texpression ctx app; check_texpression ctx arg - | Abs (pat, body) -> + | Lambda (pat, body) -> let pat_ty, body_ty = destruct_arrow e.ty in assert (pat.ty = pat_ty); assert (body.ty = body_ty); @@ -188,12 +188,6 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = List.iter check_branch branches) | Loop loop -> assert (loop.fun_end.ty = e.ty); - (* If we translate forward functions, the type of the loop is the same - as the type of the parent expression - in case of backward functions, - the loop doesn't necessarily give back the same values as the parent - function - *) - assert (Option.is_some loop.back_output_tys || loop.loop_body.ty = e.ty); check_texpression ctx loop.fun_end; check_texpression ctx loop.loop_body | StructUpdate supd -> ( diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 39dcd52d..80bf3c42 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -57,6 +57,23 @@ end module FunLoopIdMap = Collections.MakeMap (FunLoopIdOrderedType) module FunLoopIdSet = Collections.MakeSet (FunLoopIdOrderedType) +let inputs_info_is_wf (info : inputs_info) : bool = + let { + has_fuel; + num_inputs_no_fuel_no_state; + num_inputs_with_fuel_no_state; + num_inputs_with_fuel_with_state; + } = + info + in + let fuel = if has_fuel then 1 else 0 in + num_inputs_no_fuel_no_state >= 0 + && num_inputs_with_fuel_no_state = num_inputs_no_fuel_no_state + fuel + && num_inputs_with_fuel_with_state >= num_inputs_with_fuel_no_state + +let fun_sig_info_is_wf (info : fun_sig_info) : bool = + inputs_info_is_wf info.fwd_info + let dest_arrow_ty (ty : ty) : ty * ty = match ty with | TArrow (arg_ty, ret_ty) -> (arg_ty, ret_ty) @@ -187,9 +204,7 @@ let fun_sig_substitute (subst : subst) (sg : fun_sig) : inst_fun_sig = let subst = ty_substitute subst in let inputs = List.map subst sg.inputs in let output = subst sg.output in - let doutputs = List.map subst sg.doutputs in - let info = sg.info in - { inputs; output; doutputs; info } + { inputs; output } (** We use this to check whether we need to add parentheses around expressions. We only look for outer monadic let-bindings. @@ -200,12 +215,14 @@ let fun_sig_substitute (subst : subst) (sg : fun_sig) : inst_fun_sig = *) let rec let_group_requires_parentheses (e : texpression) : bool = match e.e with - | Var _ | CVar _ | Const _ | App _ | Abs _ | Qualif _ | StructUpdate _ -> - false + | Var _ | CVar _ | Const _ | App _ | Qualif _ | StructUpdate _ -> false | Let (monadic, _, _, next_e) -> if monadic then true else let_group_requires_parentheses next_e | Switch (_, _) -> false | Meta (_, next_e) -> let_group_requires_parentheses next_e + | Lambda (_, _) -> + (* Being conservative here *) + true | Loop _ -> (* Should have been eliminated *) raise (Failure "Unreachable") @@ -304,14 +321,26 @@ let destruct_apps (e : texpression) : texpression * texpression list = (** Make an [App (app, arg)] expression *) let mk_app (app : texpression) (arg : texpression) : texpression = + let raise_or_return msg = + if !Config.fail_hard then raise (Failure msg) + else + let e = App (app, arg) in + (* Dummy type - TODO: introduce an error type *) + let ty = app.ty in + { e; ty } + in match app.ty with | TArrow (ty0, ty1) -> (* Sanity check *) - assert (ty0 = arg.ty); - let e = App (app, arg) in - let ty = ty1 in - { e; ty } - | _ -> raise (Failure "Expected an arrow type") + if + (* TODO: we need to normalize the types *) + !Config.type_check_pure_code && ty0 <> arg.ty + then raise_or_return "App: wrong input type" + else + let e = App (app, arg) in + let ty = ty1 in + { e; ty } + | _ -> raise_or_return "Expected an arrow type" (** The reverse of {!destruct_apps} *) let mk_apps (app : texpression) (args : texpression list) : texpression = @@ -356,18 +385,6 @@ let opt_destruct_tuple (ty : ty) : ty list option = Some generics.types | _ -> None -let mk_abs (x : typed_pattern) (e : texpression) : texpression = - let ty = TArrow (x.ty, e.ty) in - let e = Abs (x, e) in - { e; ty } - -let rec destruct_abs_list (e : texpression) : typed_pattern list * texpression = - match e.e with - | Abs (x, e') -> - let xl, e'' = destruct_abs_list e' in - (x :: xl, e'') - | _ -> ([], e) - let destruct_arrow (ty : ty) : ty * ty = match ty with | TArrow (ty0, ty1) -> (ty0, ty1) @@ -431,6 +448,7 @@ let mk_simpl_tuple_ty (tys : ty list) : ty = let mk_bool_ty : ty = TLiteral TBool let mk_unit_ty : ty = TAdt (TTuple, empty_generic_args) +let ty_is_unit ty : bool = ty = mk_unit_ty let mk_unit_rvalue : texpression = let id = AdtCons { adt_id = TTuple; variant_id = None } in @@ -698,3 +716,39 @@ let type_decl_from_type_id_is_tuple_struct (ctx : TypesAnalysis.type_infos) let info = TypeDeclId.Map.find id ctx in info.is_tuple_struct | TAssumed _ -> false + +let mk_lambda (x : typed_pattern) (e : texpression) : texpression = + let ty = TArrow (x.ty, e.ty) in + let e = Lambda (x, e) in + { e; ty } + +let mk_lambda_from_var (var : var) (mp : mplace option) (e : texpression) : + texpression = + let pat = PatVar (var, mp) in + let pat = { value = pat; ty = var.ty } in + mk_lambda pat e + +let mk_lambdas_from_vars (vars : var list) (mps : mplace option list) + (e : texpression) : texpression = + let vars = List.combine vars mps in + List.fold_right (fun (v, mp) e -> mk_lambda_from_var v mp e) vars e + +let rec destruct_lambdas (e : texpression) : typed_pattern list * texpression = + match e.e with + | Lambda (pat, e) -> + let pats, e = destruct_lambdas e in + (pat :: pats, e) + | _ -> ([], e) + +let opt_dest_tuple_texpression (e : texpression) : texpression list option = + let app, args = destruct_apps e in + match app.e with + | Qualif { id = AdtCons { adt_id = TTuple; variant_id = None }; generics = _ } + -> + Some args + | _ -> None + +let opt_dest_struct_pattern (pat : typed_pattern) : typed_pattern list option = + match pat.value with + | PatAdt { variant_id = None; field_values } -> Some field_values + | _ -> None diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml index 53f99b7f..8e8cdec3 100644 --- a/compiler/SymbolicAst.ml +++ b/compiler/SymbolicAst.ml @@ -42,8 +42,16 @@ type call = { evaluated). We need it to compute the translated values for shared borrows (we need to perform lookups). *) + sg : fun_sig option; + (** The uninstantiated function signature, if this is not a unop/binop *) + regions_hierarchy : region_var_groups; abstractions : AbstractionId.id list; + (** The region abstractions introduced upon calling the function *) generics : generic_args; + trait_method_generics : (generic_args * trait_instance_id) option; + (** In case the call is to a trait method, we may need an additional type + parameter ([Self]) and the self trait clause to instantiate the + function signature. *) args : typed_value list; args_places : mplace option list; (** Meta information *) dest : symbolic_value; diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 84f09280..3a50e495 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -15,7 +15,7 @@ module PP = PrintPure (** The local logger *) let log = Logging.symbolic_to_pure_log -type type_context = { +type type_ctx = { llbc_type_decls : T.type_decl TypeDeclId.Map.t; type_decls : type_decl TypeDeclId.Map.t; (** We use this for type-checking (for sanity checks) when translating @@ -43,19 +43,18 @@ type fun_sig_named_outputs = { } [@@deriving show] -type fun_context = { +type fun_ctx = { llbc_fun_decls : A.fun_decl A.FunDeclId.Map.t; - fun_sigs : fun_sig_named_outputs RegularFunIdNotLoopMap.t; (** *) fun_infos : fun_info A.FunDeclId.Map.t; regions_hierarchies : T.region_var_groups FunIdMap.t; } [@@deriving show] -type global_context = { llbc_global_decls : A.global_decl A.GlobalDeclId.Map.t } +type global_ctx = { llbc_global_decls : A.global_decl A.GlobalDeclId.Map.t } [@@deriving show] -type trait_decls_context = A.trait_decl A.TraitDeclId.Map.t [@@deriving show] -type trait_impls_context = A.trait_impl A.TraitImplId.Map.t [@@deriving show] +type trait_decls_ctx = A.trait_decl A.TraitDeclId.Map.t [@@deriving show] +type trait_impls_ctx = A.trait_impl A.TraitImplId.Map.t [@@deriving show] (** Whenever we translate a function call or an ended abstraction, we store the related information (this is useful when translating ended @@ -68,11 +67,21 @@ type call_info = { Those inputs include the fuel and the state, if pertinent. *) - backwards : (V.abs * texpression list) T.RegionGroupId.Map.t; - (** A map from region group id (i.e., backward function id) to - pairs (abstraction, additional arguments received by the backward function) - - TODO: remove? it is also in the bs_ctx ("abstractions" field) + back_funs : texpression option RegionGroupId.Map.t option; + (** If we do not split between the forward/backward functions: the + variables we introduced for the backward functions. + + Example: + {[ + let x, back = Vec.index_mut n v in + ^^^^ + here + ... + ]} + + The expression might be [None] in case the backward function + has to be filtered (because it does nothing - the backward + functions for shared borrows for instance). *) } [@@deriving show] @@ -114,33 +123,61 @@ type loop_info = { generics : generic_args; forward_inputs : texpression list option; (** The forward inputs are initialized at [None] *) - forward_output_no_state_no_result : var option; + forward_output_no_state_no_result : texpression option; (** The forward outputs are initialized at [None] *) + back_outputs : ty list RegionGroupId.Map.t; + (** The map from region group ids to the types of the values given back + by the corresponding loop abstractions. + *) + back_funs : texpression option RegionGroupId.Map.t option; + (** Same as {!call_info.back_funs}. + Initialized with [None], gets updated to [Some] only if we merge + the fwd/back functions. + *) + fwd_effect_info : fun_effect_info; + back_effect_infos : fun_effect_info RegionGroupId.Map.t; } [@@deriving show] (** Body synthesis context *) type bs_ctx = { - type_context : type_context; - fun_context : fun_context; - global_context : global_context; - trait_decls_ctx : trait_decls_context; - trait_impls_ctx : trait_impls_context; + (* TODO: there are a lot of duplications with the various decls ctx *) + decls_ctx : C.decls_ctx; + type_ctx : type_ctx; + fun_ctx : fun_ctx; + global_ctx : global_ctx; + trait_decls_ctx : trait_decls_ctx; + trait_impls_ctx : trait_impls_ctx; + fun_dsigs : decomposed_fun_sig FunDeclId.Map.t; fun_decl : A.fun_decl; - bid : T.RegionGroupId.id option; (** TODO: rename *) - sg : fun_sig; - (** The function signature - useful in particular to translate [Panic] *) - fwd_sg : fun_sig; (** The signature of the forward function *) + bid : RegionGroupId.id option; + (** TODO: rename + + The id of the group region we are currently translating. + If we split the forward/backward functions, we set this id at the + very beginning of the translation. + If we don't split, we set it to `None`, then update it when we enter + an expression which is specific to a backward function. + *) + sg : decomposed_fun_sig; + (** Information about the function signature - useful in particular to + translate [Panic] *) sv_to_var : var V.SymbolicValueId.Map.t; (** Whenever we encounter a new symbolic value (introduced because of a symbolic expansion or upon ending an abstraction, for instance) we introduce a new variable (with a let-binding). *) - var_counter : VarId.generator; + var_counter : VarId.generator ref; + (** Using a ref to make sure all the variables identifiers are unique. + TODO: this is not very clean, and the code was initially written without + a reference (and it's shape hasn't changed). We should use DeBruijn indices. + *) state_var : VarId.id; (** The current state variable, in case the function is stateful *) - back_state_var : VarId.id; - (** The additional input state variable received by a stateful backward function. + back_state_vars : VarId.id RegionGroupId.Map.t; + (** The additional input state variable received by a stateful backward + function, **in case we are splitting the forward/backward functions**. + When generating stateful functions, we generate code of the following form: @@ -153,7 +190,9 @@ type bs_ctx = { When translating a backward function, we need at some point to update [state_var] with [back_state_var], to account for the fact that the state may have been updated by the caller between the call to the - forward function and the call to the backward function. + forward function and the call to the backward function. We also need + to make sure we use the same variable in all the branches (because + this variable is quantified at the definition level). *) fuel0 : VarId.id; (** The original fuel taken as input by the function (if we use fuel) *) @@ -163,28 +202,50 @@ type bs_ctx = { (** The input parameters for the forward function corresponding to the translated Rust inputs (no fuel, no state). *) - backward_inputs : var list T.RegionGroupId.Map.t; + backward_inputs_no_state : var list RegionGroupId.Map.t; (** The additional input parameters for the backward functions coming from the borrows consumed upon ending the lifetime (as a consequence those don't include the backward state, if there is one). - *) - backward_outputs : var list T.RegionGroupId.Map.t; - (** The variables that the backward functions will output, corresponding - to the borrows they give back (don't include the backward state) - *) - loop_backward_outputs : var list T.RegionGroupId.Map.t option; - (** Same as {!backward_outputs}, but for loops (if we entered a loop). - [None] if we are not inside a loop, [Some] otherwise (and whatever - the kind of function we are translating: it will be [Some] even - though we are synthesizing a forward function). + If we split the forward/backward functions: we initialize this map + when initializing the bs_ctx, because those variables are quantified + at the definition level. Otherwise, we initialize it upon diving + into the expressions which are specific to the backward functions. + *) + backward_inputs_with_state : var list RegionGroupId.Map.t; + (** All the additional input parameters for the backward functions. - TODO: move to {!loop_info} + Same remarks as for {!backward_inputs_no_state}. + *) + backward_outputs : var list option; + (** The variables that the backward functions will output, corresponding + to the borrows they give back (don't include the backward state). + + The translation is done as follows: + - when we detect the ended input abstraction which corresponds + to the backward function of the LLBC function we are translating, + and which consumed the values [consumed_i] (that we need to give + back to the caller), we introduce: + {[ + let v_i = consumed_i in + ... + ]} + where the [v_i] are fresh, and are stored in the [backward_output]. + - Then, upon reaching the [Return] node, we introduce: + {[ + return (v_i) + ]} + + The option is [None] before we detect the ended input abstraction, + and [Some] afterwards. *) calls : call_info V.FunCallId.Map.t; (** The function calls we encountered so far *) abstractions : (V.abs * texpression list) V.AbstractionId.Map.t; - (** The ended abstractions we encountered so far, with their additional input arguments *) + (** The ended abstractions we encountered so far, with their additional + input arguments. We store it here and not in {!call_info} because + we need a map from abstraction id to abstraction (and not + from call id + region group id to abstraction). *) loop_ids_map : LoopId.id V.LoopId.Map.t; (** Ids to use for the loops *) loops : loop_info LoopId.Map.t; (** The loops we encountered so far. @@ -209,9 +270,9 @@ type bs_ctx = { (* TODO: move *) let bs_ctx_to_fmt_env (ctx : bs_ctx) : Print.fmt_env = - let type_decls = ctx.type_context.llbc_type_decls in - let fun_decls = ctx.fun_context.llbc_fun_decls in - let global_decls = ctx.global_context.llbc_global_decls in + let type_decls = ctx.type_ctx.llbc_type_decls in + let fun_decls = ctx.fun_ctx.llbc_fun_decls in + let global_decls = ctx.global_ctx.llbc_global_decls in let trait_decls = ctx.trait_decls_ctx in let trait_impls = ctx.trait_impls_ctx in let { regions; types; const_generics; trait_clauses } : T.generic_params = @@ -233,9 +294,9 @@ let bs_ctx_to_fmt_env (ctx : bs_ctx) : Print.fmt_env = } let bs_ctx_to_pure_fmt_env (ctx : bs_ctx) : PrintPure.fmt_env = - let type_decls = ctx.type_context.llbc_type_decls in - let fun_decls = ctx.fun_context.llbc_fun_decls in - let global_decls = ctx.global_context.llbc_global_decls in + let type_decls = ctx.type_ctx.llbc_type_decls in + let fun_decls = ctx.fun_ctx.llbc_fun_decls in + let global_decls = ctx.global_ctx.llbc_global_decls in let trait_decls = ctx.trait_decls_ctx in let trait_impls = ctx.trait_impls_ctx in let generics = ctx.sg.generics in @@ -300,6 +361,17 @@ let typed_pattern_to_string (ctx : bs_ctx) (p : Pure.typed_pattern) : string = let env = bs_ctx_to_pure_fmt_env ctx in PrintPure.typed_pattern_to_string env p +let ctx_get_effect_info_for_bid (ctx : bs_ctx) (bid : RegionGroupId.id option) : + fun_effect_info = + match bid with + | None -> ctx.sg.fwd_info.effect_info + | Some bid -> + let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in + back_sg.effect_info + +let ctx_get_effect_info (ctx : bs_ctx) : fun_effect_info = + ctx_get_effect_info_for_bid ctx ctx.bid + (* TODO: move *) let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string = let env = bs_ctx_to_fmt_env ctx in @@ -308,33 +380,13 @@ let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string = let indent_incr = " " in Print.Values.abs_to_string env verbose indent indent_incr abs -let get_instantiated_fun_sig (fun_id : A.fun_id) - (back_id : T.RegionGroupId.id option) (generics : generic_args) - (ctx : bs_ctx) : inst_fun_sig = - (* Lookup the non-instantiated function signature *) - let sg = - (RegularFunIdNotLoopMap.find (fun_id, back_id) ctx.fun_context.fun_sigs).sg - in - (* Create the substitution *) - (* There shouldn't be any reference to Self *) - let tr_self = UnknownTrait __FUNCTION__ in - let subst = make_subst_from_generics sg.generics generics tr_self in - (* Apply *) - fun_sig_substitute subst sg - let bs_ctx_lookup_llbc_type_decl (id : TypeDeclId.id) (ctx : bs_ctx) : T.type_decl = - TypeDeclId.Map.find id ctx.type_context.llbc_type_decls + TypeDeclId.Map.find id ctx.type_ctx.llbc_type_decls let bs_ctx_lookup_llbc_fun_decl (id : A.FunDeclId.id) (ctx : bs_ctx) : A.fun_decl = - A.FunDeclId.Map.find id ctx.fun_context.llbc_fun_decls - -(* TODO: move *) -let bs_ctx_lookup_local_function_sig (def_id : A.FunDeclId.id) - (back_id : T.RegionGroupId.id option) (ctx : bs_ctx) : fun_sig = - let id = (E.FRegular def_id, back_id) in - (RegularFunIdNotLoopMap.find id ctx.fun_context.fun_sigs).sg + A.FunDeclId.Map.find id ctx.fun_ctx.llbc_fun_decls (* Some generic translation functions (we need to translate different "flavours" of types: forward types, backward types, etc.) *) @@ -601,13 +653,13 @@ and translate_fwd_trait_instance_id (type_infos : type_infos) (** Simply calls [translate_fwd_ty] *) let ctx_translate_fwd_ty (ctx : bs_ctx) (ty : T.ty) : ty = - let type_infos = ctx.type_context.type_infos in + let type_infos = ctx.type_ctx.type_infos in translate_fwd_ty type_infos ty (** Simply calls [translate_fwd_generic_args] *) let ctx_translate_fwd_generic_args (ctx : bs_ctx) (generics : T.generic_args) : generic_args = - let type_infos = ctx.type_context.type_infos in + let type_infos = ctx.type_ctx.type_infos in translate_fwd_generic_args type_infos generics (** Translate a type, when some regions may have ended. @@ -682,17 +734,21 @@ let rec translate_back_ty (type_infos : type_infos) None | TTraitType (trait_ref, generics, type_name) -> assert (generics.regions = []); - (* Translate the trait ref and the generics as "forward" generics - - we do not want to filter any type *) - let trait_ref = translate_fwd_trait_ref type_infos trait_ref in - let generics = translate_fwd_generic_args type_infos generics in - Some (TTraitType (trait_ref, generics, type_name)) + assert ( + AssociatedTypes.trait_instance_id_is_local_clause trait_ref.trait_id); + if inside_mut then + (* Translate the trait ref and the generics as "forward" generics - + we do not want to filter any type *) + let trait_ref = translate_fwd_trait_ref type_infos trait_ref in + let generics = translate_fwd_generic_args type_infos generics in + Some (TTraitType (trait_ref, generics, type_name)) + else None | TArrow _ -> raise (Failure "TODO") (** Simply calls [translate_back_ty] *) let ctx_translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool) (inside_mut : bool) (ty : T.ty) : ty option = - let type_infos = ctx.type_context.type_infos in + let type_infos = ctx.type_ctx.type_infos in translate_back_ty type_infos keep_region inside_mut ty let mk_type_check_ctx (ctx : bs_ctx) : PureTypeCheck.tc_ctx = @@ -705,8 +761,8 @@ let mk_type_check_ctx (ctx : bs_ctx) : PureTypeCheck.tc_ctx = in let env = VarId.Map.empty in { - PureTypeCheck.type_decls = ctx.type_context.type_decls; - global_decls = ctx.global_context.llbc_global_decls; + PureTypeCheck.type_decls = ctx.type_ctx.type_decls; + global_decls = ctx.global_ctx.llbc_global_decls; env; const_generics; } @@ -726,31 +782,36 @@ let translate_fun_id_or_trait_method_ref (ctx : bs_ctx) match id with | FunId fun_id -> FunId fun_id | TraitMethod (trait_ref, method_name, fun_decl_id) -> - let type_infos = ctx.type_context.type_infos in + let type_infos = ctx.type_ctx.type_infos in let trait_ref = translate_fwd_trait_ref type_infos trait_ref in TraitMethod (trait_ref, method_name, fun_decl_id) let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) - (args : texpression list) (ctx : bs_ctx) : bs_ctx = + (args : texpression list) + (back_funs : texpression option RegionGroupId.Map.t option) (ctx : bs_ctx) : + bs_ctx = let calls = ctx.calls in assert (not (V.FunCallId.Map.mem call_id calls)); - let info = - { forward; forward_inputs = args; backwards = T.RegionGroupId.Map.empty } - in + let info = { forward; forward_inputs = args; back_funs } in let calls = V.FunCallId.Map.add call_id info calls in { ctx with calls } -(** [back_args]: the *additional* list of inputs received by the backward function *) -let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) - (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx) - : bs_ctx * fun_or_op_id = +(** [inherit_args]: the list of inputs inherited from the forward function and + the ancestors backward functions, if pertinent. + [back_args]: the *additional* list of inputs received by the backward function, + including the state. + + Returns the updated context and the expression corresponding to the function + that we need to call. This function may be [None] if it has to be ignored + (because it does nothing). + *) +let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info) + (call_id : V.FunCallId.id) (back_id : T.RegionGroupId.id) + (inherited_args : texpression list) (back_args : texpression list) + (generics : generic_args) (output_ty : ty) (ctx : bs_ctx) : + bs_ctx * texpression option = (* Insert the abstraction in the call informations *) let info = V.FunCallId.Map.find call_id ctx.calls in - assert (not (T.RegionGroupId.Map.mem back_id info.backwards)); - let backwards = - T.RegionGroupId.Map.add back_id (abs, back_args) info.backwards - in - let info = { info with backwards } in let calls = V.FunCallId.Map.add call_id info ctx.calls in (* Insert the abstraction in the abstractions map *) let abstractions = ctx.abstractions in @@ -758,16 +819,31 @@ let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) let abstractions = V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions in - (* Retrieve the fun_id *) - let fun_id = - match info.forward.call_id with - | S.Fun (fid, _) -> - let fid = translate_fun_id_or_trait_method_ref ctx fid in - Fun (FromLlbc (fid, None, Some back_id)) - | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable") + (* Compute the expression corresponding to the function *) + let func = + if !Config.return_back_funs then + (* Lookup the variable introduced for the backward function *) + RegionGroupId.Map.find back_id (Option.get info.back_funs) + else + (* Retrieve the fun_id *) + let fun_id = + match info.forward.call_id with + | S.Fun (fid, _) -> + let fid = translate_fun_id_or_trait_method_ref ctx fid in + Fun (FromLlbc (fid, None, Some back_id)) + | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable") + in + let args = List.append inherited_args back_args in + let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in + let ret_ty = + if effect_info.can_fail then mk_result_ty output_ty else output_ty + in + let func_ty = mk_arrows input_tys ret_ty in + let func = { id = FunOrOp fun_id; generics } in + Some { e = Qualif func; ty = func_ty } in (* Update the context and return *) - ({ ctx with calls; abstractions }, fun_id) + ({ ctx with calls; abstractions }, func) (** List the ancestors of an abstraction *) let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs) @@ -821,7 +897,7 @@ let mk_fuel_input_as_list (ctx : bs_ctx) (info : fun_effect_info) : if function_uses_fuel info then [ mk_fuel_texpression ctx.fuel ] else [] (** Small utility. *) -let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) +let compute_raw_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) (fun_id : A.fun_id_or_trait_method_ref) (lid : V.LoopId.id option) (gid : T.RegionGroupId.id option) : fun_effect_info = match fun_id with @@ -848,33 +924,53 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) is_rec = false; } -(** Translate a function signature. +(** TODO: not very clean. *) +let get_fun_effect_info (ctx : bs_ctx) (fun_id : A.fun_id_or_trait_method_ref) + (lid : V.LoopId.id option) (gid : T.RegionGroupId.id option) : + fun_effect_info = + match lid with + | None -> ( + match fun_id with + | TraitMethod (_, _, fid) | FunId (FRegular fid) -> + let dsg = A.FunDeclId.Map.find fid ctx.fun_dsigs in + let info = + match gid with + | None -> dsg.fwd_info.effect_info + | Some gid -> (RegionGroupId.Map.find gid dsg.back_sg).effect_info + in + { info with is_rec = info.is_rec || Option.is_some lid } + | FunId (FAssumed _) -> + compute_raw_fun_effect_info ctx.fun_ctx.fun_infos fun_id lid gid) + | Some lid -> ( + (* This is necessarily for the current function *) + match fun_id with + | FunId (FRegular fid) -> ( + assert (fid = ctx.fun_decl.def_id); + (* Lookup the loop *) + let lid = V.LoopId.Map.find lid ctx.loop_ids_map in + let loop_info = LoopId.Map.find lid ctx.loops in + match gid with + | None -> loop_info.fwd_effect_info + | Some gid -> RegionGroupId.Map.find gid loop_info.back_effect_infos) + | _ -> raise (Failure "Unreachable")) + +(** Translate a function signature to a decomposed function signature. Note that the function also takes a list of names for the inputs, and computes, for every output for the backward functions, a corresponding name (outputs for backward functions come from borrows in the inputs of the forward function) which we use as hints to generate pretty names in the extracted code. + + We use [bid] ("backward function id") only if we split the forward + and the backward functions. *) -let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) - (sg : A.fun_sig) (input_names : string option list) - (bid : T.RegionGroupId.id option) : fun_sig_named_outputs = +let translate_fun_sig_with_regions_hierarchy_to_decomposed + (decls_ctx : C.decls_ctx) (fun_id : A.fun_id_or_trait_method_ref) + (regions_hierarchy : T.region_var_groups) (sg : A.fun_sig) + (input_names : string option list) : decomposed_fun_sig = let fun_infos = decls_ctx.fun_ctx.fun_infos in let type_infos = decls_ctx.type_ctx.type_infos in - (* Retrieve the list of parent backward functions *) - let regions_hierarchy = - FunIdMap.find fun_id decls_ctx.fun_ctx.regions_hierarchies - in - let gid, parents = - match bid with - | None -> (None, T.RegionGroupId.Set.empty) - | Some bid -> - let parents = list_ancestor_region_groups regions_hierarchy bid in - (Some bid, parents) - in - (* Is the function stateful, and can it fail? *) - let lid = None in - let effect_info = get_fun_effect_info fun_infos (FunId fun_id) lid bid in (* We need an evaluation context to normalize the types (to normalize the associated types, etc. - for instance it may happen that the types refer to the types associated to a trait ref, but where the trait ref @@ -885,7 +981,7 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) List.map (fun (g : T.region_var_group) -> g.id) regions_hierarchy in let ctx = - InterpreterUtils.initialize_eval_context decls_ctx region_groups + InterpreterUtils.initialize_eval_ctx decls_ctx region_groups sg.generics.types sg.generics.const_generics in (* Compute the normalization map for the *sty* types and add it to the context *) @@ -902,17 +998,28 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) { sg with A.inputs; output } in - (* List the inputs for: - * - the fuel - * - the forward function - * - the parent backward functions, in proper order - * - the current backward function (if it is a backward function) - *) - let fuel = mk_fuel_input_ty_as_list effect_info in - let fwd_inputs = List.map (translate_fwd_ty type_infos) sg.inputs in - (* For the backward functions: for now we don't supported nested borrows, - * so just check that there aren't parent regions *) - assert (T.RegionGroupId.Set.is_empty parents); + (* Is the forward function stateful, and can it fail? *) + let fwd_effect_info = + compute_raw_fun_effect_info fun_infos fun_id None None + in + (* Compute the forward inputs *) + let fwd_fuel = mk_fuel_input_ty_as_list fwd_effect_info in + let fwd_inputs_no_fuel_no_state = + List.map (translate_fwd_ty type_infos) sg.inputs + in + (* State input for the forward function *) + let fwd_state_ty = + (* For the forward state, we check if the *whole group* is stateful. + See {!effect_info}. *) + if fwd_effect_info.stateful_group then [ mk_state_ty ] else [] + in + let fwd_inputs = + List.concat [ fwd_fuel; fwd_inputs_no_fuel_no_state; fwd_state_ty ] + in + (* Compute the backward output, without the effect information *) + let fwd_output = translate_fwd_ty type_infos sg.output in + + (* Compute the type information for the backward function *) (* Small helper to translate types for backward functions *) let translate_back_ty_for_gid (gid : T.RegionGroupId.id) (ty : T.ty) : ty option = @@ -939,163 +1046,336 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) let inside_mut = false in translate_back_ty type_infos keep_region inside_mut ty in - (* Compute the additinal inputs for the current function, if it is a backward - * function *) - let back_inputs = - match gid with - | None -> [] - | Some gid -> - (* For now, we don't allow nested borrows, so the additional inputs to the - backward function can only come from borrows that were returned like - in (for the backward function we introduce for 'a): - {[ - fn f<'a>(...) -> &'a mut u32; - ]} - Upon ending the abstraction for 'a, we need to get back the borrow - the function returned. - *) - List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] - in - (* If the function is stateful, the inputs are: - - forward: [fwd_ty0, ..., fwd_tyn, state] - - backward: - - if {!Config.backward_no_state_update}: [fwd_ty0, ..., fwd_tyn, state, back_ty, state] - - otherwise: [fwd_ty0, ..., fwd_tyn, state, back_ty] - - The backward takes the same state as input as the forward function, - together with the state at the point where it gets called, if it is - stateful. - - See the comments for {!Config.backward_no_state_update} - *) - let fwd_state_ty = - (* For the forward state, we check if the *whole group* is stateful. - See {!effect_info}. *) - if effect_info.stateful_group then [ mk_state_ty ] else [] + let translate_back_inputs_for_gid (gid : T.RegionGroupId.id) : ty list = + (* For now we don't supported nested borrows, so we check that there + aren't parent regions *) + let parents = list_ancestor_region_groups regions_hierarchy gid in + assert (T.RegionGroupId.Set.is_empty parents); + (* For now, we don't allow nested borrows, so the additional inputs to the + backward function can only come from borrows that were returned like + in (for the backward function we introduce for 'a): + {[ + fn f<'a>(...) -> &'a mut u32; + ]} + Upon ending the abstraction for 'a, we need to get back the borrow + the function returned. + *) + let inputs = + List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] + in + log#ldebug + (lazy + (let ctx = Print.Contexts.decls_ctx_to_fmt_env decls_ctx in + let pctx = PrintPure.decls_ctx_to_fmt_env decls_ctx in + let output = Print.Types.ty_to_string ctx sg.output in + let inputs = + Print.list_to_string (PrintPure.ty_to_string pctx false) inputs + in + "translate_back_inputs_for_gid:" ^ "\n- gid: " + ^ RegionGroupId.to_string gid + ^ "\n- output: " ^ output ^ "\n- back inputs: " ^ inputs ^ "\n")); + inputs + in + let compute_back_outputs_for_gid (gid : RegionGroupId.id) : + string option list * ty list = + (* The outputs are the borrows inside the regions of the abstractions + and which are present in the input values. For instance, see: + {[ + fn f<'a>(x : &'a mut u32) -> ...; + ]} + Upon ending the abstraction for 'a, we give back the borrow which + was consumed through the [x] parameter. + *) + let outputs = + List.map + (fun (name, input_ty) -> (name, translate_back_ty_for_gid gid input_ty)) + (List.combine input_names sg.inputs) + in + (* Filter *) + let outputs = + List.filter (fun (_, opt_ty) -> Option.is_some opt_ty) outputs + in + let outputs = + List.map (fun (name, opt_ty) -> (name, Option.get opt_ty)) outputs + in + let names, outputs = List.split outputs in + log#ldebug + (lazy + (let ctx = Print.Contexts.decls_ctx_to_fmt_env decls_ctx in + let pctx = PrintPure.decls_ctx_to_fmt_env decls_ctx in + let inputs = + Print.list_to_string (Print.Types.ty_to_string ctx) sg.inputs + in + let outputs = + Print.list_to_string (PrintPure.ty_to_string pctx false) outputs + in + "compute_back_outputs_for_gid:" ^ "\n- gid: " + ^ RegionGroupId.to_string gid + ^ "\n- inputs: " ^ inputs ^ "\n- back outputs: " ^ outputs ^ "\n")); + (names, outputs) + in + let compute_back_info_for_group (rg : T.region_var_group) : + RegionGroupId.id * back_sg_info = + let gid = rg.id in + let back_effect_info = + compute_raw_fun_effect_info fun_infos fun_id None (Some gid) + in + let inputs_no_state = translate_back_inputs_for_gid gid in + let inputs_no_state = + List.map (fun ty -> (Some "ret", ty)) inputs_no_state + in + (* In case we merge the forward/backward functions: + we consider the backward function as stateful and potentially failing + **only if it has inputs** (for the "potentially failing": if it has + not inputs, we directly evaluate it in the body of the forward function). + *) + let back_effect_info = + if !Config.return_back_funs then + let b = inputs_no_state <> [] in + { + back_effect_info with + stateful = back_effect_info.stateful && b; + can_fail = back_effect_info.can_fail && b; + } + else back_effect_info + in + let state = + if back_effect_info.stateful then [ (None, mk_state_ty) ] else [] + in + let inputs = inputs_no_state @ state in + let output_names, outputs = compute_back_outputs_for_gid gid in + let filter = + !Config.simplify_merged_fwd_backs + && !Config.return_back_funs && inputs = [] && outputs = [] + in + let info = + { + inputs; + inputs_no_state; + outputs; + output_names; + effect_info = back_effect_info; + filter; + } + in + (gid, info) in - let back_state_ty = - (* For the backward state, we check if the function is a backward function, - and it is stateful *) - if effect_info.stateful && Option.is_some gid then [ mk_state_ty ] else [] + let back_sg = + RegionGroupId.Map.of_list + (List.map compute_back_info_for_group regions_hierarchy) in - (* Concatenate the inputs, in the following order: - * - forward inputs - * - forward state input - * - backward inputs - * - backward state input - *) - let inputs = - List.concat [ fuel; fwd_inputs; fwd_state_ty; back_inputs; back_state_ty ] - in - (* Outputs *) - let output_names, doutputs = - match gid with - | None -> - (* This is a forward function: there is one (unnamed) output *) - ([ None ], [ translate_fwd_ty type_infos sg.output ]) - | Some gid -> - (* This is a backward function: there might be several outputs. - The outputs are the borrows inside the regions of the abstractions - and which are present in the input values. For instance, see: - {[ - fn f<'a>(x : &'a mut u32) -> ...; - ]} - Upon ending the abstraction for 'a, we give back the borrow which - was consumed through the [x] parameter. - *) - let outputs = - List.map - (fun (name, input_ty) -> - (name, translate_back_ty_for_gid gid input_ty)) - (List.combine input_names sg.inputs) - in - (* Filter *) - let outputs = - List.filter (fun (_, opt_ty) -> Option.is_some opt_ty) outputs - in - let outputs = - List.map (fun (name, opt_ty) -> (name, Option.get opt_ty)) outputs - in - List.split outputs - in - (* Create the return type *) - let output = - (* Group the outputs together *) - let output = mk_simpl_tuple_ty doutputs in - (* Add the output state *) - let output = - if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] - else output + (* The additional information about the forward function *) + let fwd_info = + (* *) + let has_fuel = fwd_fuel <> [] in + let num_inputs_no_fuel_no_state = List.length fwd_inputs_no_fuel_no_state in + let num_inputs_with_fuel_no_state = + (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *) + List.length fwd_fuel + num_inputs_no_fuel_no_state + in + let fwd_info : inputs_info = + { + has_fuel; + num_inputs_no_fuel_no_state; + num_inputs_with_fuel_no_state; + num_inputs_with_fuel_with_state = + (* We use the fact that [fwd_state_ty] has length 1 if there is a state, + and 0 otherwise *) + num_inputs_with_fuel_no_state + List.length fwd_state_ty; + } + in + let ignore_output = + if !Config.return_back_funs && !Config.simplify_merged_fwd_backs then + ty_is_unit fwd_output + && List.exists + (fun (info : back_sg_info) -> not info.filter) + (RegionGroupId.Map.values back_sg) + else false in - (* Wrap in a result type *) - if effect_info.can_fail then mk_result_ty output else output + let info = { fwd_info; effect_info = fwd_effect_info; ignore_output } in + assert (fun_sig_info_is_wf info); + info in + (* Generic parameters *) let generics = translate_generic_params sg.generics in + (* Return *) - let has_fuel = fuel <> [] in - let num_fwd_inputs_no_state = List.length fwd_inputs in - let num_fwd_inputs_with_fuel_no_state = - (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *) - List.length fuel + num_fwd_inputs_no_state - in - let num_back_inputs_no_state = - if bid = None then None else Some (List.length back_inputs) + let preds = translate_predicates sg.preds in + { + generics; + llbc_generics = sg.generics; + preds; + fwd_inputs; + fwd_output; + back_sg; + fwd_info; + } + +let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) + (fun_id : FunDeclId.id) (sg : A.fun_sig) (input_names : string option list) + : decomposed_fun_sig = + (* Retrieve the list of parent backward functions *) + let regions_hierarchy = + FunIdMap.find (FRegular fun_id) decls_ctx.fun_ctx.regions_hierarchies in - let info = - { - has_fuel; - num_fwd_inputs_with_fuel_no_state; - num_fwd_inputs_with_fuel_with_state = - (* We use the fact that [fwd_state_ty] has length 1 if there is a state, - and 0 otherwise *) - num_fwd_inputs_with_fuel_no_state + List.length fwd_state_ty; - num_back_inputs_no_state; - num_back_inputs_with_state = - (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) - Option.map - (fun n -> n + List.length back_state_ty) - num_back_inputs_no_state; - effect_info; - } + translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx + (FunId (FRegular fun_id)) regions_hierarchy sg input_names + +let translate_fun_sig_from_decl_to_decomposed (decls_ctx : C.decls_ctx) + (fdef : LlbcAst.fun_decl) : decomposed_fun_sig = + let input_names = + match fdef.body with + | None -> List.map (fun _ -> None) fdef.signature.inputs + | Some body -> + List.map + (fun (v : LlbcAst.var) -> v.name) + (LlbcAstUtils.fun_body_get_input_vars body) in - let preds = translate_predicates sg.preds in let sg = - { - generics; - llbc_generics = sg.generics; - preds; - inputs; - output; - doutputs; - info; - } + translate_fun_sig_to_decomposed decls_ctx fdef.def_id fdef.signature + input_names + in + log#ldebug + (lazy + ("translate_fun_sig_from_decl_to_decomposed:" ^ "\n- name: " + ^ T.show_name fdef.name ^ "\n- sg:\n" ^ show_decomposed_fun_sig sg ^ "\n")); + sg + +let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty + = + let output = + if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ] else ty in - { sg; output_names } + if effect_info.can_fail then mk_result_ty output else output -let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern = +let mk_back_output_ty_from_effect_info (effect_info : fun_effect_info) + (inputs : ty list) (ty : ty) : ty = + let output = + if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ] else ty + in + if effect_info.can_fail && inputs <> [] then mk_result_ty output else output + +(** Compute the arrow types for all the backward functions. + + If a backward function has no inputs/outputs we filter it. + *) +let compute_back_tys_with_info (dsg : Pure.decomposed_fun_sig) + (subst : (generic_args * trait_instance_id) option) : + (back_sg_info * ty) option list = + List.map + (fun (back_sg : back_sg_info) -> + let effect_info = back_sg.effect_info in + (* Compute the input/output types *) + let inputs = List.map snd back_sg.inputs in + let outputs = back_sg.outputs in + (* Filter if necessary *) + if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] then + None + else + let output = mk_simpl_tuple_ty outputs in + let output = + mk_back_output_ty_from_effect_info effect_info inputs output + in + let ty = mk_arrows inputs output in + (* Substitute - TODO: normalize *) + let ty = + match subst with + | None -> ty + | Some (generics, tr_self) -> + let subst = + make_subst_from_generics dsg.generics generics tr_self + in + ty_substitute subst ty + in + Some (back_sg, ty)) + (RegionGroupId.Map.values dsg.back_sg) + +let compute_back_tys (dsg : Pure.decomposed_fun_sig) + (subst : (generic_args * trait_instance_id) option) : ty option list = + List.map (Option.map snd) (compute_back_tys_with_info dsg subst) + +(** In case we merge the fwd/back functions: compute the output type of + a function, from a decomposed signature. *) +let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty = + assert !Config.return_back_funs; + (* Compute the arrow types for all the backward functions *) + let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in + (* Group the forward output and the types of the backward functions *) + let effect_info = dsg.fwd_info.effect_info in + let output = + (* We might need to ignore the output of the forward function + (if it is unit for instance) *) + let tys = + if dsg.fwd_info.ignore_output then back_tys + else dsg.fwd_output :: back_tys + in + mk_simpl_tuple_ty tys + in + mk_output_ty_from_effect_info effect_info output + +let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) + (gid : RegionGroupId.id option) : fun_sig = + let generics = dsg.generics in + let llbc_generics = dsg.llbc_generics in + let preds = dsg.preds in + (* Compute the effects info *) + let fwd_info = dsg.fwd_info in + let back_effect_info = + RegionGroupId.Map.of_list + (List.map + (fun ((gid, info) : RegionGroupId.id * back_sg_info) -> + (gid, info.effect_info)) + (RegionGroupId.Map.bindings dsg.back_sg)) + in + let mk_output_ty = mk_output_ty_from_effect_info in + let inputs, output = + (* Two cases depending on whether we split the forward/backward functions or not *) + if !Config.return_back_funs then ( + assert (gid = None); + let output = compute_output_ty_from_decomposed dsg in + let inputs = dsg.fwd_inputs in + (inputs, output)) + else + match gid with + | None -> + let effect_info = dsg.fwd_info.effect_info in + let output = mk_output_ty effect_info dsg.fwd_output in + (dsg.fwd_inputs, output) + | Some gid -> + let back_sg = RegionGroupId.Map.find gid dsg.back_sg in + let effect_info = back_sg.effect_info in + let inputs = dsg.fwd_inputs @ List.map snd back_sg.inputs in + let output = mk_simpl_tuple_ty back_sg.outputs in + let output = mk_output_ty effect_info output in + (inputs, output) + in + { generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info } + +let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * var * typed_pattern = (* Generate the fresh variable *) - let id, var_counter = VarId.fresh ctx.var_counter in + let id, var_counter = VarId.fresh !(ctx.var_counter) in let state_var = { id; basename = Some ConstStrings.state_basename; ty = mk_state_ty } in let state_pat = mk_typed_pattern_from_var state_var None in (* Update the context *) - let ctx = { ctx with var_counter; state_var = id } in + ctx.var_counter := var_counter; + let ctx = { ctx with state_var = id } in (* Return *) - (ctx, state_pat) + (ctx, state_var, state_pat) (** WARNING: do not call this function directly. Call [fresh_named_var_for_symbolic_value] instead. *) let fresh_var_llbc_ty (basename : string option) (ty : T.ty) (ctx : bs_ctx) : bs_ctx * var = (* Generate the fresh variable *) - let id, var_counter = VarId.fresh ctx.var_counter in + let id, var_counter = VarId.fresh !(ctx.var_counter) in let ty = ctx_translate_fwd_ty ctx ty in let var = { id; basename; ty } in (* Update the context *) - let ctx = { ctx with var_counter } in + ctx.var_counter := var_counter; (* Return *) (ctx, var) @@ -1129,10 +1409,10 @@ let fresh_named_vars_for_symbolic_values let fresh_var (basename : string option) (ty : ty) (ctx : bs_ctx) : bs_ctx * var = (* Generate the fresh variable *) - let id, var_counter = VarId.fresh ctx.var_counter in + let id, var_counter = VarId.fresh !(ctx.var_counter) in let var = { id; basename; ty } in (* Update the context *) - let ctx = { ctx with var_counter } in + ctx.var_counter := var_counter; (* Return *) (ctx, var) @@ -1140,6 +1420,58 @@ let fresh_vars (vars : (string option * ty) list) (ctx : bs_ctx) : bs_ctx * var list = List.fold_left_map (fun ctx (name, ty) -> fresh_var name ty ctx) ctx vars +let fresh_opt_vars (vars : (string option * ty) option list) (ctx : bs_ctx) : + bs_ctx * var option list = + List.fold_left_map + (fun ctx var -> + match var with + | None -> (ctx, None) + | Some (name, ty) -> + let ctx, var = fresh_var name ty ctx in + (ctx, Some var)) + ctx vars + +(* Introduce variables for the backward functions *) +let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var option list = + (* We lookup the LLBC definition in an attempt to derive pretty names + for the backward functions. *) + let back_var_names = + let def_id = ctx.fun_decl.def_id in + let sg = ctx.fun_decl.signature in + let regions_hierarchy = + LlbcAstUtils.FunIdMap.find (FRegular def_id) + ctx.fun_ctx.regions_hierarchies + in + List.map + (fun (gid, _) -> + let rg = RegionGroupId.nth regions_hierarchy gid in + let region_names = + List.map + (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name) + rg.regions + in + let name = + match region_names with + | [] -> "back" + | [ Some r ] -> "back" ^ r + | _ -> + (* Concatenate all the region names *) + "back" + ^ String.concat "" (List.filter_map (fun x -> x) region_names) + in + Some name) + (RegionGroupId.Map.bindings ctx.sg.back_sg) + in + let back_vars = List.combine back_var_names (compute_back_tys ctx.sg None) in + let back_vars = + List.map + (fun (name, ty) -> + match ty with None -> None | Some ty -> Some (name, ty)) + back_vars + in + fresh_opt_vars back_vars ctx + +(** IMPORTANT: do not use this one directly, but rather {!symbolic_value_to_texpression} *) let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var = match V.SymbolicValueId.Map.find_opt sv.sv_id ctx.sv_to_var with | Some v -> v @@ -1158,12 +1490,22 @@ let rec unbox_typed_value (v : V.typed_value) : V.typed_value = | _ -> raise (Failure "Unreachable")) | _ -> v -(** Translate a symbolic value *) +(** Translate a symbolic value. + + Because we do not necessarily introduce variables for the symbolic values + of (translated) type unit, it is important that we do not lookup variables + in case the symbolic value has type unit. + *) let symbolic_value_to_texpression (ctx : bs_ctx) (sv : V.symbolic_value) : texpression = (* Translate the type *) - let var = lookup_var_for_symbolic_value sv ctx in - mk_texpression_from_var var + let ty = ctx_translate_fwd_ty ctx sv.sv_ty in + (* If the type is unit, directly return unit *) + if ty_is_unit ty then mk_unit_rvalue + else + (* Otherwise lookup the variable *) + let var = lookup_var_for_symbolic_value sv ctx in + mk_texpression_from_var var (** Translate a typed value. @@ -1342,13 +1684,11 @@ and aproj_to_consumed (ctx : bs_ctx) (aproj : V.aproj) : texpression option = match aproj with | V.AEndedProjLoans (msv, []) -> (* The symbolic value was left unchanged *) - let var = lookup_var_for_symbolic_value msv ctx in - Some (mk_texpression_from_var var) + Some (symbolic_value_to_texpression ctx msv) | V.AEndedProjLoans (_, [ (mnv, child_aproj) ]) -> assert (child_aproj = AIgnoredProjBorrows); (* The symbolic value was updated *) - let var = lookup_var_for_symbolic_value mnv ctx in - Some (mk_texpression_from_var var) + Some (symbolic_value_to_texpression ctx mnv) | V.AEndedProjLoans (_, _) -> (* The symbolic value was updated, and the given back values come from sevearl * abstractions *) @@ -1543,7 +1883,9 @@ let mk_emeta_symbolic_assignments (vars : var list) (values : texpression list) let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression = match e with - | S.Return (ectx, opt_v) -> translate_return ectx opt_v ctx + | S.Return (ectx, opt_v) -> + (* Remark: we can't get there if we are inside a loop *) + translate_return ectx opt_v ctx | ReturnWithLoop (loop_id, is_continue) -> translate_return_with_loop loop_id is_continue ctx | Panic -> translate_panic ctx @@ -1565,32 +1907,56 @@ and translate_panic (ctx : bs_ctx) : texpression = * but it won't be true anymore once we translate individual blocks *) (* If we use a state monad, we need to add a lambda for the state variable *) (* Note that only forward functions return a state *) - let output_ty = - if ctx.inside_loop && Option.is_some ctx.bid then - (* We are synthesizing the backward function of a loop body *) - let bid = Option.get ctx.bid in - let back_vars = - T.RegionGroupId.Map.find bid (Option.get ctx.loop_backward_outputs) - in - let tys = List.map (fun (v : var) -> v.ty) back_vars in - mk_simpl_tuple_ty tys - else - (* Regular function, or forward function (the forward translation for - a loop has the same return type as the parent function) - *) - mk_simpl_tuple_ty ctx.sg.doutputs - in + let effect_info = ctx_get_effect_info ctx in (* TODO: we should use a [Fail] function *) - if ctx.sg.info.effect_info.stateful then - (* Create the [Fail] value *) - let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in - let ret_v = - mk_result_fail_texpression_with_error_id error_failure_id ret_ty - in - ret_v - else mk_result_fail_texpression_with_error_id error_failure_id output_ty + let mk_output output_ty = + if effect_info.stateful then + (* Create the [Fail] value *) + let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in + let ret_v = + mk_result_fail_texpression_with_error_id error_failure_id ret_ty + in + ret_v + else mk_result_fail_texpression_with_error_id error_failure_id output_ty + in + if ctx.inside_loop && Option.is_some ctx.bid then + (* We are synthesizing the backward function of a loop body *) + let bid = Option.get ctx.bid in + let loop_id = Option.get ctx.loop_id in + let loop = LoopId.Map.find loop_id ctx.loops in + let tys = RegionGroupId.Map.find bid loop.back_outputs in + let output = mk_simpl_tuple_ty tys in + mk_output output + else + (* Regular function, or forward function (the forward translation for + a loop has the same return type as the parent function) + *) + match ctx.bid with + | None -> + if !Config.return_back_funs then + let back_tys = compute_back_tys ctx.sg None in + let back_tys = List.filter_map (fun x -> x) back_tys in + let tys = + if ctx.sg.fwd_info.ignore_output then back_tys + else ctx.sg.fwd_output :: back_tys + in + let output = mk_simpl_tuple_ty tys in + mk_output output + else mk_output ctx.sg.fwd_output + | Some bid -> + let output = + mk_simpl_tuple_ty (RegionGroupId.Map.find bid ctx.sg.back_sg).outputs + in + mk_output output + +(** [opt_v]: the value to return, in case we translate a forward body. -(** [opt_v]: the value to return, in case we translate a forward body *) + Remark: for now, we can't get there if we are inside a loop. + If inside a loop, we use {!translate_return_with_loop}. + + Remark: in case we merge the forward/backward functions, we introduce + those in [translate_forward_end]. +*) and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = (* There are two cases: @@ -1599,22 +1965,20 @@ and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) - or we are translating a backward function, in which case it should be [None] *) (* Compute the values that we should return *without the state and the result - * wrapper* *) + wrapper* *) let output = match ctx.bid with | None -> (* Forward function *) let v = Option.get opt_v in typed_value_to_texpression ctx ectx v - | Some bid -> + | Some _ -> (* Backward function *) (* Sanity check *) assert (opt_v = None); (* Group the variables in which we stored the values we need to give back. - * See the explanations for the [SynthInput] case in [translate_end_abstraction] *) - let backward_outputs = - T.RegionGroupId.Map.find bid ctx.backward_outputs - in + See the explanations for the [SynthInput] case in [translate_end_abstraction] *) + let backward_outputs = Option.get ctx.backward_outputs in let field_values = List.map mk_texpression_from_var backward_outputs in mk_simpl_tuple_texpression field_values in @@ -1622,7 +1986,7 @@ and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) * - error-monad: Return x * - state-error: Return (state, x) * *) - let effect_info = ctx.sg.info.effect_info in + let effect_info = ctx_get_effect_info ctx in let output = if effect_info.stateful then let state_rvalue = mk_state_texpression ctx.state_var in @@ -1647,24 +2011,12 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) *) let output = match ctx.bid with - | None -> - (* Forward *) - mk_texpression_from_var - (Option.get loop_info.forward_output_no_state_no_result) - | Some bid -> + | None -> Option.get loop_info.forward_output_no_state_no_result + | Some _ -> (* Backward *) (* Group the variables in which we stored the values we need to give back. * See the explanations for the [SynthInput] case in [translate_end_abstraction] *) - let backward_outputs = - let map = - if ctx.inside_loop then - (* We are synthesizing a loop body *) - Option.get ctx.loop_backward_outputs - else (* Regular function *) - ctx.backward_outputs - in - T.RegionGroupId.Map.find bid map - in + let backward_outputs = Option.get ctx.backward_outputs in let field_values = List.map mk_texpression_from_var backward_outputs in mk_simpl_tuple_texpression field_values in @@ -1676,7 +2028,7 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) * effect - in particular, one manipulates a state iff the other does * the same. * *) - let effect_info = ctx.sg.info.effect_info in + let effect_info = ctx_get_effect_info ctx in let output = if effect_info.stateful then let state_rvalue = mk_state_texpression ctx.state_var in @@ -1684,13 +2036,15 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) else output in (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *) - mk_result_return_texpression output + mk_emeta (Tag "return_with_loop") (mk_result_return_texpression output) and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : texpression = log#ldebug (lazy - ("translate_function_call:\n" + ("translate_function_call:\n" ^ "\n- call.call_id:" + ^ S.show_call_id call.call_id + ^ "\n\n- call.generics:\n" ^ ctx_generic_args_to_string ctx call.generics)); (* Translate the function call *) let generics = ctx_translate_fwd_generic_args ctx call.generics in @@ -1702,10 +2056,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (List.combine args args_mplaces) in let dest_mplace = translate_opt_mplace call.dest_place in - let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in (* Retrieve the function id, and register the function call in the context * if necessary. *) - let ctx, fun_id, effect_info, args, out_state = + let ctx, fun_id, effect_info, args, dest_v = match call.call_id with | S.Fun (fid, call_id) -> (* Regular function call *) @@ -1713,25 +2066,150 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let func = Fun (FromLlbc (fid_t, None, None)) in (* Retrieve the effect information about this function (can fail, * takes a state as input, etc.) *) - let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos fid None None - in + let effect_info = get_fun_effect_info ctx fid None None in (* Depending on the function effects: - * - add the fuel - * - add the state input argument - * - generate a fresh state variable for the returned state + - add the fuel + - add the state input argument + - generate a fresh state variable for the returned state *) let args, ctx, out_state = let fuel = mk_fuel_input_as_list ctx effect_info in if effect_info.stateful then let state_var = mk_state_texpression ctx.state_var in - let ctx, nstate_var = bs_ctx_fresh_state_var ctx in + let ctx, _, nstate_var = bs_ctx_fresh_state_var ctx in (List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var) else (List.concat [ fuel; args ], ctx, None) in + (* If we do not split the forward/backward functions: generate the + variables for the backward functions returned by the forward + function. *) + let ctx, ignore_fwd_output, back_funs_map, back_funs = + if !Config.return_back_funs then ( + (* We need to compute the signatures of the backward functions. *) + let sg = Option.get call.sg in + let decls_ctx = ctx.decls_ctx in + let dsg = + translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx + fid call.regions_hierarchy sg + (List.map (fun _ -> None) sg.inputs) + in + log#ldebug + (lazy ("dsg.generics:\n" ^ show_generic_params dsg.generics)); + let tr_self, all_generics = + match call.trait_method_generics with + | None -> (UnknownTrait __FUNCTION__, generics) + | Some (all_generics, tr_self) -> + let all_generics = + ctx_translate_fwd_generic_args ctx all_generics + in + let tr_self = + translate_fwd_trait_instance_id ctx.type_ctx.type_infos + tr_self + in + (tr_self, all_generics) + in + let back_tys = + compute_back_tys_with_info dsg (Some (all_generics, tr_self)) + in + (* Introduce variables for the backward functions *) + (* Compute a proper basename for the variables *) + let back_fun_name = + let name = + match fid with + | FunId (FAssumed fid) -> ( + match fid with + | BoxNew -> "box_new" + | BoxFree -> "box_free" + | ArrayRepeat -> "array_repeat" + | ArrayIndexShared -> "index_shared" + | ArrayIndexMut -> "index_mut" + | ArrayToSliceShared -> "to_slice_shared" + | ArrayToSliceMut -> "to_slice_mut" + | SliceIndexShared -> "index_shared" + | SliceIndexMut -> "index_mut") + | FunId (FRegular fid) | TraitMethod (_, _, fid) -> ( + let decl = + FunDeclId.Map.find fid ctx.fun_ctx.llbc_fun_decls + in + match Collections.List.last decl.name with + | PeIdent (s, _) -> s + | PeImpl _ -> + (* We shouldn't get there *) + raise (Failure "Unexpected")) + in + name ^ "_back" + in + let ctx, back_vars = + fresh_opt_vars + (List.map + (fun ty -> + match ty with + | None -> None + | Some (back_sg, ty) -> + (* We insert a name for the variable only if the function + can fail: if it can fail, it means the call returns a backward + function. Otherwise, we it directly returns the value given + back by the backward function, which means we shouldn't + give it a name like "back..." (it doesn't make sense) *) + let name = + if back_sg.effect_info.can_fail then + Some back_fun_name + else None + in + Some (name, ty)) + back_tys) + ctx + in + let back_funs = + List.filter_map + (fun v -> + match v with + | None -> None + | Some v -> Some (mk_typed_pattern_from_var v None)) + back_vars + in + let gids = + List.map + (fun (g : T.region_var_group) -> g.id) + call.regions_hierarchy + in + let back_vars = + List.map (Option.map mk_texpression_from_var) back_vars + in + let back_funs_map = + RegionGroupId.Map.of_list (List.combine gids back_vars) + in + (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs)) + else (ctx, false, None, []) + in + (* Compute the pattern for the destination *) + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + let dest = mk_typed_pattern_from_var dest dest_mplace in + let dest = + (* Here there is something subtle: as we might ignore the output + of the forward function (because it translates to unit) we doNOT + necessarily introduce in the let-binding the variable to which we + map the symbolic value which was introduced for the output of the + function call. This would be problematic if later we need to + translate this symbolic value, but we implemented + {!symbolic_value_to_texpression} so that it doesn't perform any + lookups if the symbolic value has type unit. + *) + let vars = + if ignore_fwd_output then back_funs else dest :: back_funs + in + mk_simpl_tuple_pattern vars + in + let dest = + match out_state with + | None -> dest + | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ] + in (* Register the function call *) - let ctx = bs_ctx_register_forward_call call_id call args ctx in - (ctx, func, effect_info, args, out_state) + let ctx = + bs_ctx_register_forward_call call_id call args back_funs_map ctx + in + (ctx, func, effect_info, args, dest) | S.Unop E.Not -> let effect_info = { @@ -1742,7 +2220,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop Not, effect_info, args, None) + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + let dest = mk_typed_pattern_from_var dest dest_mplace in + (ctx, Unop Not, effect_info, args, dest) | S.Unop E.Neg -> ( match args with | [ arg ] -> @@ -1758,7 +2238,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop (Neg int_ty), effect_info, args, None) + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + let dest = mk_typed_pattern_from_var dest dest_mplace in + (ctx, Unop (Neg int_ty), effect_info, args, dest) | _ -> raise (Failure "Unreachable")) | S.Unop (E.Cast cast_kind) -> ( match cast_kind with @@ -1773,7 +2255,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None) + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + let dest = mk_typed_pattern_from_var dest dest_mplace in + (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, dest) | CastFnPtr _ -> raise (Failure "TODO: function casts")) | S.Binop binop -> ( match args with @@ -1793,15 +2277,11 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Binop (binop, int_ty0), effect_info, args, None) + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + let dest = mk_typed_pattern_from_var dest dest_mplace in + (ctx, Binop (binop, int_ty0), effect_info, args, dest) | _ -> raise (Failure "Unreachable")) in - let dest_v = - let dest = mk_typed_pattern_from_var dest dest_mplace in - match out_state with - | None -> dest - | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ] - in let func = { id = FunOrOp fun_id; generics } in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in let ret_ty = @@ -1846,45 +2326,38 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs) ^ abs_to_string ctx abs ^ "\n")); (* When we end an input abstraction, this input abstraction gets back - * the borrows which it introduced in the context through the input - * values: by listing those values, we get the values which are given - * back by one of the backward functions we are synthesizing. *) - (* Note that we don't support nested borrows for now: if we find - * an ended synthesized input abstraction, it must be the one corresponding - * to the backward function wer are synthesizing, it can't be the one - * for a parent backward function. - *) + the borrows which it introduced in the context through the input + values: by listing those values, we get the values which are given + back by one of the backward functions we are synthesizing. + + Note that we don't support nested borrows for now: if we find + an ended synthesized input abstraction, it must be the one corresponding + to the backward function wer are synthesizing, it can't be the one + for a parent backward function. + *) let bid = Option.get ctx.bid in assert (rg_id = bid); - (* The translation is done as follows: - * - for a given backward function, we choose a set of variables [v_i] - * - when we detect the ended input abstraction which corresponds - * to the backward function, and which consumed the values [consumed_i], - * we introduce: - * {[ - * let v_i = consumed_i in - * ... - * ]} - * Then, when we reach the [Return] node, we introduce: - * {[ - * (v_i) - * ]} - * *) - (* First, get the given back variables. + (* First, introduce the given back variables. We don't use the same given back variables if we translate a loop or the standard body of a function. *) - let given_back_variables = - let map = + let ctx, given_back_variables = + let vars = if ctx.inside_loop then (* We are synthesizing a loop body *) - Option.get ctx.loop_backward_outputs - else (* Regular function body *) - ctx.backward_outputs + let loop_id = Option.get ctx.loop_id in + let loop = LoopId.Map.find loop_id ctx.loops in + let tys = RegionGroupId.Map.find bid loop.back_outputs in + List.map (fun ty -> (None, ty)) tys + else + (* Regular function body *) + let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in + List.combine back_sg.output_names back_sg.outputs in - T.RegionGroupId.Map.find bid map + let ctx, vars = fresh_vars vars ctx in + ({ ctx with backward_outputs = Some vars }, vars) in (* Get the list of values consumed by the abstraction upon ending *) @@ -1933,9 +2406,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) (* Those don't have backward functions *) raise (Failure "Unreachable") in - let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos fun_id None (Some rg_id) - in + let effect_info = get_fun_effect_info ctx fun_id None (Some rg_id) in let generics = ctx_translate_fwd_generic_args ctx call.generics in (* Retrieve the original call and the parent abstractions *) let _forward, backwards = get_abs_ancestors ctx abs call_id in @@ -1959,14 +2430,16 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) let back_state, ctx, nstate = if effect_info.stateful then let back_state = mk_state_texpression ctx.state_var in - let ctx, nstate = bs_ctx_fresh_state_var ctx in + let ctx, _, nstate = bs_ctx_fresh_state_var ctx in ([ back_state ], ctx, Some nstate) else ([], ctx, None) in (* Concatenate all the inpus *) - let inputs = - List.concat [ fwd_inputs; back_ancestors_inputs; back_inputs; back_state ] + let inherited_inputs = + if !Config.return_back_funs then [] + else List.concat [ fwd_inputs; back_ancestors_inputs ] in + let back_inputs = List.append back_inputs back_state in (* Retrieve the values given back by this function: those are the output * values. We rely on the fact that there are no nested borrows to use the * meta-place information from the input values given to the forward function @@ -1983,78 +2456,61 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) | None -> output | Some nstate -> mk_simpl_tuple_pattern [ nstate; output ] in - (* Sanity check: there is the proper number of inputs and outputs, and they have the proper type *) - (if (* TODO: normalize the types *) !Config.type_check_pure_code then - match fun_id with - | FunId fun_id -> - let inst_sg = - get_instantiated_fun_sig fun_id (Some rg_id) generics ctx - in - log#ldebug - (lazy - ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs (" - ^ string_of_int (List.length inputs) - ^ "): " - ^ String.concat ", " (List.map (texpression_to_string ctx) inputs) - ^ "\n- inst_sg.inputs (" - ^ string_of_int (List.length inst_sg.inputs) - ^ "): " - ^ String.concat ", " - (List.map (pure_ty_to_string ctx) inst_sg.inputs))); - List.iter - (fun (x, ty) -> assert ((x : texpression).ty = ty)) - (List.combine inputs inst_sg.inputs); - log#ldebug - (lazy - ("\n- outputs: " - ^ string_of_int (List.length outputs) - ^ "\n- expected outputs: " - ^ string_of_int (List.length inst_sg.doutputs))); - List.iter - (fun (x, ty) -> assert ((x : typed_pattern).ty = ty)) - (List.combine outputs inst_sg.doutputs) - | _ -> (* TODO: trait methods *) ()); (* Retrieve the function id, and register the function call in the context - * if necessary *) + if necessary.Arith_status *) let ctx, func = - bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx + bs_ctx_register_backward_call abs effect_info call_id rg_id inherited_inputs + back_inputs generics output.ty ctx in (* Translate the next expression *) let next_e = translate_expression e ctx in (* Put everything together *) + let inputs = List.append inherited_inputs back_inputs in let args_mplaces = List.map (fun _ -> None) inputs in let args = List.map (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) (List.combine inputs args_mplaces) in - let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in - let ret_ty = - if effect_info.can_fail then mk_result_ty output.ty else output.ty - in - let func_ty = mk_arrows input_tys ret_ty in - let func = { id = FunOrOp func; generics } in - let func = { e = Qualif func; ty = func_ty } in - let call = mk_apps func args in (* **Optimization**: - * ================= - * We do a small optimization here: if the backward function doesn't - * have any output, we don't introduce any function call. - * See the comment in {!Config.filter_useless_monadic_calls}. - * - * TODO: use an option to disallow backward functions from updating the state. - * TODO: a backward function which only gives back shared borrows shouldn't - * update the state (state updates should only be used for mutable borrows, - * with objects like Rc for instance). - *) - if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None then ( + ================= + We do a small optimization here if we split the forward/backward functions. + If the backward function doesn't have any output, we don't introduce any function + call. + See the comment in {!Config.filter_useless_monadic_calls}. + + TODO: use an option to disallow backward functions from updating the state. + TODO: a backward function which only gives back shared borrows shouldn't + update the state (state updates should only be used for mutable borrows, + with objects like Rc for instance). + *) + if + (not !Config.return_back_funs) + && !Config.filter_useless_monadic_calls + && outputs = [] && nstate = None + then ( (* No outputs - we do a small sanity check: the backward function - * should have exactly the same number of inputs as the forward: - * this number can be different only if the forward function returned - * a value containing mutable borrows, which can't be the case... *) + should have exactly the same number of inputs as the forward: + this number can be different only if the forward function returned + a value containing mutable borrows, which can't be the case... *) assert (List.length inputs = List.length fwd_inputs); next_e) - else mk_let effect_info.can_fail output call next_e + else + (* The backward function might also have been filtered if we do not + split the forward/backward functions *) + match func with + | None -> next_e + | Some func -> + log#ldebug + (lazy + (let args = List.map (texpression_to_string ctx) args in + "func: " + ^ texpression_to_string ctx func + ^ "\nfunc type: " + ^ pure_ty_to_string ctx func.ty + ^ "\n\nargs:\n" ^ String.concat "\n" args)); + let call = mk_apps func args in + mk_let effect_info.can_fail output call next_e and translate_end_abstraction_identity (ectx : C.eval_ctx) (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -2095,15 +2551,15 @@ and translate_end_abstraction_synth_ret (ectx : C.eval_ctx) (abs : V.abs) let-binding: {[ let id_back x nx = - let s = nx in // the name [s] is not important (only collision matters) - ... + let s = nx in // the name [s] is not important (only collision matters) + ... ]} This let-binding later gets inlined, during a micro-pass. *) (* First, retrieve the list of variables used for the inputs for the * backward function *) - let inputs = T.RegionGroupId.Map.find rg_id ctx.backward_inputs in + let inputs = T.RegionGroupId.Map.find rg_id ctx.backward_inputs_no_state in (* Retrieve the values consumed upon ending the loans inside this * abstraction: as there are no nested borrows, there should be none. *) let consumed = abs_to_consumed ctx ectx abs in @@ -2150,11 +2606,12 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) | V.LoopSynthInput -> (* Actually the same case as [SynthInput] *) translate_end_abstraction_synth_input ectx abs e ctx rg_id - | V.LoopCall -> + | V.LoopCall -> ( + (* We need to introduce a call to the backward function corresponding + to a forward call which happened earlier *) let fun_id = E.FRegular ctx.fun_decl.def_id in let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos (FunId fun_id) - (Some vloop_id) (Some rg_id) + get_fun_effect_info ctx (FunId fun_id) (Some vloop_id) (Some rg_id) in let loop_info = LoopId.Map.find loop_id ctx.loops in let generics = loop_info.generics in @@ -2165,7 +2622,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) values consumed upon ending the abstraction (i.e., we don't use [abs_to_consumed]) *) let back_inputs_vars = - T.RegionGroupId.Map.find rg_id ctx.backward_inputs + T.RegionGroupId.Map.find rg_id ctx.backward_inputs_no_state in let back_inputs = List.map mk_texpression_from_var back_inputs_vars in (* If the function is stateful: @@ -2175,12 +2632,15 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) let back_state, ctx, nstate = if effect_info.stateful then let back_state = mk_state_texpression ctx.state_var in - let ctx, nstate = bs_ctx_fresh_state_var ctx in + let ctx, _, nstate = bs_ctx_fresh_state_var ctx in ([ back_state ], ctx, Some nstate) else ([], ctx, None) in (* Concatenate all the inputs *) - let inputs = List.concat [ fwd_inputs; back_inputs; back_state ] in + let inputs = + if !Config.return_back_funs then List.concat [ back_inputs; back_state ] + else List.concat [ fwd_inputs; back_inputs; back_state ] + in (* Retrieve the values given back by this function *) let ctx, outputs = abs_to_given_back None abs ctx in (* Group the output values together: first the updated inputs *) @@ -2204,68 +2664,88 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) let ret_ty = if effect_info.can_fail then mk_result_ty output.ty else output.ty in - let func_ty = mk_arrows input_tys ret_ty in - let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in - let func = { id = FunOrOp func; generics } in - let func = { e = Qualif func; ty = func_ty } in - let call = mk_apps func args in + (* Create the expression for the function: + - it is either a call to a top-level function, if we split the + forward/backward functions + - or a call to the variable we introduced for the backward function, + if we merge the forward/backward functions *) + let func = + if !Config.return_back_funs then + RegionGroupId.Map.find rg_id (Option.get loop_info.back_funs) + else + let func_ty = mk_arrows input_tys ret_ty in + let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in + let func = { id = FunOrOp func; generics } in + Some { e = Qualif func; ty = func_ty } + in (* **Optimization**: - * ================= - * We do a small optimization here: if the backward function doesn't - * have any output, we don't introduce any function call. - * See the comment in {!Config.filter_useless_monadic_calls}. - * - * TODO: use an option to disallow backward functions from updating the state. - * TODO: a backward function which only gives back shared borrows shouldn't - * update the state (state updates should only be used for mutable borrows, - * with objects like Rc for instance). - *) - if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None + ================= + We do a small optimization here in case we split the forward/backward + functions. + If the backward function doesn't have any output, we don't introduce + any function call. + See the comment in {!Config.filter_useless_monadic_calls}. + + TODO: use an option to disallow backward functions from updating the state. + TODO: a backward function which only gives back shared borrows shouldn't + update the state (state updates should only be used for mutable borrows, + with objects like Rc for instance). + *) + if + (not !Config.return_back_funs) + && !Config.filter_useless_monadic_calls + && outputs = [] && nstate = None then ( (* No outputs - we do a small sanity check: the backward function - * should have exactly the same number of inputs as the forward: - * this number can be different only if the forward function returned - * a value containing mutable borrows, which can't be the case... *) + should have exactly the same number of inputs as the forward: + this number can be different only if the forward function returned + a value containing mutable borrows, which can't be the case... *) assert (List.length inputs = List.length fwd_inputs); next_e) else - (* Add meta-information - this is slightly hacky: we look at the - values consumed by the abstraction (note that those come from - *before* we applied the fixed-point context) and use them to - guide the naming of the output vars. - - Also, we need to convert the backward outputs from patterns to - variables. - - Finally, in practice, this works well only for loop bodies: - we do this only in this case. - TODO: improve the heuristics, to give weight to the hints for - instance. - *) - let next_e = - if ctx.inside_loop then - let consumed_values = abs_to_consumed ctx ectx abs in - let var_values = List.combine outputs consumed_values in - let var_values = - List.filter_map - (fun (var, v) -> - match var.Pure.value with - | PatVar (var, _) -> Some (var, v) - | _ -> None) - var_values + (* In case we merge the fwd/back functions we filter the backward + functions elsewhere *) + match func with + | None -> next_e + | Some func -> + let call = mk_apps func args in + (* Add meta-information - this is slightly hacky: we look at the + values consumed by the abstraction (note that those come from + *before* we applied the fixed-point context) and use them to + guide the naming of the output vars. + + Also, we need to convert the backward outputs from patterns to + variables. + + Finally, in practice, this works well only for loop bodies: + we do this only in this case. + TODO: improve the heuristics, to give weight to the hints for + instance. + *) + let next_e = + if ctx.inside_loop then + let consumed_values = abs_to_consumed ctx ectx abs in + let var_values = List.combine outputs consumed_values in + let var_values = + List.filter_map + (fun (var, v) -> + match var.Pure.value with + | PatVar (var, _) -> Some (var, v) + | _ -> None) + var_values + in + let vars, values = List.split var_values in + mk_emeta_symbolic_assignments vars values next_e + else next_e in - let vars, values = List.split var_values in - mk_emeta_symbolic_assignments vars values next_e - else next_e - in - (* Create the let-binding *) - mk_let effect_info.can_fail output call next_e + (* Create the let-binding *) + mk_let effect_info.can_fail output call next_e) and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value) (e : S.expression) (ctx : bs_ctx) : texpression = let ctx, var = fresh_var_for_symbolic_value sval ctx in - let decl = A.GlobalDeclId.Map.find gid ctx.global_context.llbc_global_decls in + let decl = A.GlobalDeclId.Map.find gid ctx.global_ctx.llbc_global_decls in let global_expr = { id = Global gid; generics = empty_generic_args } in (* We use translate_fwd_ty to translate the global type *) let ty = ctx_translate_fwd_ty ctx decl.ty in @@ -2290,8 +2770,7 @@ and translate_assertion (ectx : C.eval_ctx) (v : V.typed_value) and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) (exp : S.expansion) (ctx : bs_ctx) : texpression = (* Translate the scrutinee *) - let scrutinee_var = lookup_var_for_symbolic_value sv ctx in - let scrutinee = mk_texpression_from_var scrutinee_var in + let scrutinee = symbolic_value_to_texpression ctx sv in let scrutinee_mplace = translate_opt_mplace p in (* Translate the branches *) match exp with @@ -2370,7 +2849,13 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) If (true_e, false_e) ) in let ty = true_e.ty in - assert (ty = false_e.ty); + log#ldebug + (lazy + ("true_e.ty: " + ^ pure_ty_to_string ctx true_e.ty + ^ "\n\nfalse_e.ty: " + ^ pure_ty_to_string ctx false_e.ty)); + if !Config.fail_hard then assert (ty = false_e.ty); { e; ty } | ExpandInt (int_ty, branches, otherwise) -> let translate_branch ((v, branch_e) : V.scalar_value * S.expression) : @@ -2441,7 +2926,7 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value) - if we forbid using field projectors. *) let is_rec_def = - T.TypeDeclId.Set.mem adt_id ctx.type_context.recursive_decls + T.TypeDeclId.Set.mem adt_id ctx.type_ctx.recursive_decls in let use_let_with_cons = is_enum @@ -2454,7 +2939,7 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value) like Coq don't, in which case we have to deconstruct the whole ADT at once (`let (a, b, c) = x in`) *) || TypesUtils.type_decl_from_type_id_is_tuple_struct - ctx.type_context.type_infos type_id + ctx.type_ctx.type_infos type_id && not (Config.backend_has_tuple_projectors ()) in if use_let_with_cons then @@ -2547,7 +3032,7 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option) { e = StructUpdate su; ty = var.ty } | VaCgValue cg_id -> { e = CVar cg_id; ty = var.ty } | VaTraitConstValue (trait_ref, generics, const_name) -> - let type_infos = ctx.type_context.type_infos in + let type_infos = ctx.type_ctx.type_infos in let trait_ref = translate_fwd_trait_ref type_infos trait_ref in let generics = translate_fwd_generic_args type_infos generics in let qualif_id = TraitConst (trait_ref, generics, const_name) in @@ -2562,22 +3047,169 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option) and translate_forward_end (ectx : C.eval_ctx) (loop_input_values : V.typed_value S.symbolic_value_id_map option) - (e : S.expression) (back_e : S.expression S.region_group_id_map) + (fwd_e : S.expression) (back_e : S.expression S.region_group_id_map) (ctx : bs_ctx) : texpression = - (* Update the current state with the additional state received by the backward - function, if needs be, and lookup the proper expression *) - let translate_end ctx = + let translate_one_end ctx (bid : RegionGroupId.id option) = + let ctx = { ctx with bid } in (* Update the current state with the additional state received by the backward function, if needs be, and lookup the proper expression *) - let ctx, e = - match ctx.bid with - | None -> (ctx, e) + let ctx, e, finish = + match bid with + | None -> + (* We are translating the forward function - nothing to do *) + (ctx, fwd_e, fun e -> e) | Some bid -> - let ctx = { ctx with state_var = ctx.back_state_var } in + (* There are two cases here: + - if we split the fwd/backward functions, we simply need to update + the state. + - if we don't split, we also need to wrap the expression in a + lambda, which introduces the additional inputs of the backward + function + *) + let ctx = + (* Introduce variables for the inputs and the state variable + and update the context. *) + if !Config.return_back_funs then + (* If the forward/backward functions are not split, we need + to introduce fresh variables for the additional inputs, + because they are locally introduced in a lambda *) + let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in + let ctx, backward_inputs_no_state = + fresh_vars back_sg.inputs_no_state ctx + in + let ctx, backward_inputs_with_state = + if back_sg.effect_info.stateful then + let ctx, var, _ = bs_ctx_fresh_state_var ctx in + (ctx, backward_inputs_no_state @ [ var ]) + else (ctx, backward_inputs_no_state) + in + { + ctx with + backward_inputs_no_state = + RegionGroupId.Map.add bid backward_inputs_no_state + ctx.backward_inputs_no_state; + backward_inputs_with_state = + RegionGroupId.Map.add bid backward_inputs_with_state + ctx.backward_inputs_with_state; + } + else + (* Update the state variable *) + let back_state_var = + RegionGroupId.Map.find bid ctx.back_state_vars + in + { ctx with state_var = back_state_var } + in + let e = T.RegionGroupId.Map.find bid back_e in - (ctx, e) + let finish e = + (* Wrap in lambdas if necessary *) + if !Config.return_back_funs then + let inputs = + RegionGroupId.Map.find bid ctx.backward_inputs_with_state + in + let places = List.map (fun _ -> None) inputs in + mk_lambdas_from_vars inputs places e + else e + in + (ctx, e, finish) in - translate_expression e ctx + let e = translate_expression e ctx in + finish e + in + + (* There are two cases, depending on whether we are splitting the forward/backward + functions or not. + + - if we split, then we simply need to translate the proper "end" expression, + that is the end of the forward function, or of the backward function we + are currently translating. + - if we don't split, then we need to translate the end of the forward + function (this is the value we will return) and generate the bodies + of the backward functions (which we will also return). + + Update the current state with the additional state received by the backward + function, if needs be, and lookup the proper expression. + *) + let translate_end ctx = + if !Config.return_back_funs then + (* Compute the output of the forward function *) + let fwd_effect_info = ctx.sg.fwd_info.effect_info in + let ctx, pure_fwd_var = fresh_var None ctx.sg.fwd_output ctx in + let fwd_e = translate_one_end ctx None in + + (* Introduce the backward functions. *) + let back_el = + List.map + (fun ((gid, _) : RegionGroupId.id * back_sg_info) -> + translate_one_end ctx (Some gid)) + (RegionGroupId.Map.bindings ctx.sg.back_sg) + in + + (* Compute whether the backward expressions should be evaluated straight + away or not (i.e., if we should bind them with monadic let-bindings + or not). We evaluate them straight away if they can fail and have no + inputs. *) + let evaluate_backs = + List.map + (fun (sg : back_sg_info) -> + if !Config.simplify_merged_fwd_backs then + sg.inputs = [] && sg.effect_info.can_fail + else false) + (RegionGroupId.Map.values ctx.sg.back_sg) + in + + (* Introduce variables for the backward functions. + We lookup the LLBC definition in an attempt to derive pretty names + for those functions. *) + let _, back_vars = fresh_back_vars_for_current_fun ctx in + + (* Create the return expressions *) + let vars = + let back_vars = List.filter_map (fun x -> x) back_vars in + if ctx.sg.fwd_info.ignore_output then back_vars + else pure_fwd_var :: back_vars + in + let vars = List.map mk_texpression_from_var vars in + let ret = mk_simpl_tuple_texpression vars in + + (* Introduce a fresh input state variable for the forward expression *) + let _ctx, state_var, state_pat = + if fwd_effect_info.stateful then + let ctx, var, pat = bs_ctx_fresh_state_var ctx in + (ctx, [ var ], [ pat ]) + else (ctx, [], []) + in + + let state_var = List.map mk_texpression_from_var state_var in + let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in + let ret = mk_result_return_texpression ret in + + (* Introduce all the let-bindings *) + + (* Combine: + - the backward variables + - whether we should evaluate the expression for the backward function + (i.e., should we use a monadic let-binding or not - we do if the + backward functions don't have inputs and can fail) + - the expressions for the backward functions + *) + let back_vars_els = + List.filter_map + (fun (v, (eval, el)) -> + match v with None -> None | Some v -> Some (v, eval, el)) + (List.combine back_vars (List.combine evaluate_backs back_el)) + in + let e = + List.fold_right + (fun (var, evaluate, back_e) e -> + mk_let evaluate (mk_typed_pattern_from_var var None) back_e e) + back_vars_els ret + in + (* Bind the expression for the forward output *) + let fwd_var = mk_typed_pattern_from_var pure_fwd_var None in + let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in + mk_let fwd_effect_info.can_fail pat fwd_e e + else translate_one_end ctx ctx.bid in (* If we are (re-)entering a loop, we need to introduce a call to the @@ -2624,17 +3256,53 @@ and translate_forward_end (ectx : C.eval_ctx) (* Lookup the effect info for the loop function *) let fid = E.FRegular ctx.fun_decl.def_id in - let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos (FunId fid) None ctx.bid - in + let effect_info = get_fun_effect_info ctx (FunId fid) None ctx.bid in (* Introduce a fresh output value for the forward function *) - let ctx, output_var = - let output_ty = mk_simpl_tuple_ty ctx.fwd_sg.doutputs in - fresh_var None output_ty ctx + let ctx, fwd_output, output_pat = + if ctx.sg.fwd_info.ignore_output then + (* Note that we still need the forward output (which is unit), + because even though the loop function will ignore the forward output, + the forward expression will still compute an output (which + will have type unit - otherwise we can't ignore it). *) + (ctx, mk_unit_rvalue, []) + else + let ctx, output_var = fresh_var None ctx.sg.fwd_output ctx in + ( ctx, + mk_texpression_from_var output_var, + [ mk_typed_pattern_from_var output_var None ] ) + in + + (* Introduce fresh variables for the backward functions of the loop. + + For now, the backward functions of the loop are the same as the + backward functions of the outer function. + *) + let ctx, back_funs_map, back_funs = + if !Config.return_back_funs then + let ctx, back_vars = fresh_back_vars_for_current_fun ctx in + let back_funs = + List.filter_map + (fun v -> + match v with + | None -> None + | Some v -> Some (mk_typed_pattern_from_var v None)) + back_vars + in + let gids = RegionGroupId.Map.keys ctx.sg.back_sg in + let back_funs_map = + RegionGroupId.Map.of_list + (List.combine gids + (List.map (Option.map mk_texpression_from_var) back_vars)) + in + (ctx, Some back_funs_map, back_funs) + else (ctx, None, []) in + + (* Introduce patterns *) let args, ctx, out_pats = - let output_pat = mk_typed_pattern_from_var output_var None in + (* Add the returned backward functions (they might be empty) *) + let output_pat = mk_simpl_tuple_pattern (output_pat @ back_funs) in (* Depending on the function effects: * - add the fuel @@ -2644,7 +3312,7 @@ and translate_forward_end (ectx : C.eval_ctx) let fuel = mk_fuel_input_as_list ctx effect_info in if effect_info.stateful then let state_var = mk_state_texpression ctx.state_var in - let ctx, nstate_pat = bs_ctx_fresh_state_var ctx in + let ctx, _, nstate_pat = bs_ctx_fresh_state_var ctx in ( List.concat [ fuel; args; [ state_var ] ], ctx, [ nstate_pat; output_pat ] ) @@ -2656,7 +3324,8 @@ and translate_forward_end (ectx : C.eval_ctx) { loop_info with forward_inputs = Some args; - forward_output_no_state_no_result = Some output_var; + forward_output_no_state_no_result = Some fwd_output; + back_funs = back_funs_map; } in let ctx = @@ -2753,35 +3422,107 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = (* Compute the backward outputs *) let ctx = ref ctx in - let loop_backward_outputs = - T.RegionGroupId.Map.map + let rg_to_given_back_tys = + RegionGroupId.Map.map (fun (_, tys) -> (* The types shouldn't contain borrows - we can translate them as forward types *) - let vars = - List.map - (fun ty -> - assert ( - not (TypesUtils.ty_has_borrows !ctx.type_context.type_infos ty)); - (None, ctx_translate_fwd_ty !ctx ty)) - tys - in - (* Introduce fresh variables *) - let ctx', vars = fresh_vars vars !ctx in - ctx := ctx'; - vars) + List.map + (fun ty -> + assert (not (TypesUtils.ty_has_borrows !ctx.type_ctx.type_infos ty)); + ctx_translate_fwd_ty !ctx ty) + tys) loop.rg_to_given_back_tys in let ctx = !ctx in - let back_output_tys = - match ctx.bid with - | None -> None - | Some rg_id -> - let back_outputs = - T.RegionGroupId.Map.find rg_id loop_backward_outputs - in - let back_output_tys = List.map (fun (v : var) -> v.ty) back_outputs in - Some back_output_tys + (* The output type of the loop function *) + let fwd_effect_info = { ctx.sg.fwd_info.effect_info with is_rec = true } in + let back_effect_infos, output_ty = + if !Config.return_back_funs then + (* The loop backward functions consume the same additional inputs as the parent + function, but have custom outputs *) + let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in + let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in + let back_info_tys = + List.map + (fun (((id, back_sg), given_back) : (_ * back_sg_info) * ty list) -> + (* Remark: the effect info of the backward function for the loop + is almost the same as for the backward function of the parent function. + Quite importantly, the fact that the function is stateful and/or can fail + mostly depends on whether it has inputs or not, and the backward functions + for the loops have the same inputs as the backward functions for the parent + function. + *) + let effect_info = back_sg.effect_info in + let effect_info = { effect_info with is_rec = true } in + (* Compute the input/output types *) + let inputs = List.map snd back_sg.inputs in + let outputs = given_back in + (* Filter if necessary *) + let ty = + if + !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] + then None + else + let output = mk_simpl_tuple_ty outputs in + let output = + mk_back_output_ty_from_effect_info effect_info inputs output + in + let ty = mk_arrows inputs output in + Some ty + in + ((id, effect_info), ty)) + (List.combine back_sgs given_back_tys) + in + let back_info = List.map fst back_info_tys in + let back_info = RegionGroupId.Map.of_list back_info in + let back_tys = List.filter_map snd back_info_tys in + let output = + if ctx.sg.fwd_info.ignore_output then back_tys + else ctx.sg.fwd_output :: back_tys + in + let output = mk_simpl_tuple_ty output in + let effect_info = ctx.sg.fwd_info.effect_info in + let output = + if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] + else output + in + let output = + if effect_info.can_fail && inputs <> [] then mk_result_ty output + else output + in + (back_info, output) + else + let back_info = + RegionGroupId.Map.of_list + (List.map + (fun ((id, back_sg) : _ * back_sg_info) -> + (id, { back_sg.effect_info with is_rec = true })) + (RegionGroupId.Map.bindings ctx.sg.back_sg)) + in + let output = + match ctx.bid with + | None -> + (* Forward function: same type as the parent function *) + (translate_fun_sig_from_decomposed ctx.sg None).output + | Some rg_id -> + (* Backward function: custom return type *) + let doutputs = + T.RegionGroupId.Map.find rg_id rg_to_given_back_tys + in + let output = mk_simpl_tuple_ty doutputs in + let fwd_effect_info = ctx.sg.fwd_info.effect_info in + let output = + if fwd_effect_info.stateful then + mk_simpl_tuple_ty [ mk_state_ty; output ] + else output + in + let output = + if fwd_effect_info.can_fail then mk_result_ty output else output + in + output + in + (back_info, output) in (* Add the loop information in the context *) @@ -2823,6 +3564,10 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = generics; forward_inputs = None; forward_output_no_state_no_result = None; + back_outputs = rg_to_given_back_tys; + back_funs = None; + fwd_effect_info; + back_effect_infos; } in let loops = LoopId.Map.add loop_id loop_info ctx.loops in @@ -2830,13 +3575,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = in (* Update the context to translate the function end *) - let ctx_end = - { - ctx with - loop_id = Some loop_id; - loop_backward_outputs = Some loop_backward_outputs; - } - in + let ctx_end = { ctx with loop_id = Some loop_id } in let fun_end = translate_expression loop.end_expr ctx_end in (* Update the context for the loop body *) @@ -2844,7 +3583,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = (* Add the input state *) let input_state = - if ctx.sg.info.effect_info.stateful then Some ctx.state_var else None + if (ctx_get_effect_info ctx).stateful then Some ctx.state_var else None in (* Translate the loop body *) @@ -2862,7 +3601,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = input_state; inputs; inputs_lvs; - back_output_tys; + output_ty; loop_body; } in @@ -2982,10 +3721,10 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = let def_id = def.def_id in let llbc_name = def.name in let name = name_to_string ctx llbc_name in - (* Retrieve the signature *) - let signature = ctx.sg in + (* Translate the signature *) + let signature = translate_fun_sig_from_decomposed ctx.sg ctx.bid in let regions_hierarchy = - FunIdMap.find (FRegular def_id) ctx.fun_context.regions_hierarchies + FunIdMap.find (FRegular def_id) ctx.fun_ctx.regions_hierarchies in (* Translate the body, if there is *) let body = @@ -2993,8 +3732,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = | None -> None | Some body -> let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos - (FunId (FRegular def_id)) None bid + get_fun_effect_info ctx (FunId (FRegular def_id)) None bid in let body = translate_expression body ctx in (* Add a match over the fuel, if necessary *) @@ -3027,20 +3765,25 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = match bid with | None -> [] | Some back_id -> + assert (not !Config.return_back_funs); let parents_ids = list_ordered_ancestor_region_groups regions_hierarchy back_id in let backward_ids = List.append parents_ids [ back_id ] in List.concat (List.map - (fun id -> T.RegionGroupId.Map.find id ctx.backward_inputs) + (fun id -> + T.RegionGroupId.Map.find id ctx.backward_inputs_no_state) backward_ids) in (* Introduce the backward input state (the state at call site of the * *backward* function), if necessary *) let back_state = if effect_info.stateful && Option.is_some bid then - [ mk_state_var ctx.back_state_var ] + let state_var = + RegionGroupId.Map.find (Option.get bid) ctx.back_state_vars + in + [ mk_state_var state_var ] else [] in (* Group the inputs together *) @@ -3114,63 +3857,6 @@ let translate_type_decls (ctx : Contexts.decls_ctx) : type_decl list = List.map (translate_type_decl ctx) (TypeDeclId.Map.values ctx.type_ctx.type_decls) -(** Translates function signatures. - - Takes as input a list of function information containing: - - the function id - - a list of optional names for the inputs - - the function signature - - Returns a map from forward/backward functions identifiers to: - - translated function signatures - - optional names for the outputs values (we derive them for the backward - functions) - *) -let translate_fun_signatures (decls_ctx : C.decls_ctx) - (functions : (A.fun_id * string option list * A.fun_sig) list) : - fun_sig_named_outputs RegularFunIdNotLoopMap.t = - (* For every function, translate the signatures of: - - the forward function - - the backward functions - *) - let translate_one (fun_id : A.fun_id) (input_names : string option list) - (sg : A.fun_sig) : (regular_fun_id_not_loop * fun_sig_named_outputs) list - = - log#ldebug - (lazy - ("Translating signature of function: " - ^ Print.Expressions.fun_id_to_string - (Print.Contexts.decls_ctx_to_fmt_env decls_ctx) - fun_id)); - (* Retrieve the regions hierarchy *) - let regions_hierarchy = - FunIdMap.find fun_id decls_ctx.fun_ctx.regions_hierarchies - in - (* The forward function *) - let fwd_sg = translate_fun_sig decls_ctx fun_id sg input_names None in - let fwd_id = (fun_id, None) in - (* The backward functions *) - let back_sgs = - List.map - (fun (rg : T.region_var_group) -> - let tsg = - translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id) - in - let id = (fun_id, Some rg.id) in - (id, tsg)) - regions_hierarchy - in - (* Return *) - (fwd_id, fwd_sg) :: back_sgs - in - let translated = - List.concat - (List.map (fun (id, names, sg) -> translate_one id names sg) functions) - in - List.fold_left - (fun m (id, sg) -> RegularFunIdNotLoopMap.add id sg m) - RegularFunIdNotLoopMap.empty translated - let translate_trait_decl (ctx : Contexts.decls_ctx) (trait_decl : A.trait_decl) : trait_decl = let { diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml index efcf001a..865185a8 100644 --- a/compiler/SynthesizeSymbolic.ml +++ b/compiler/SynthesizeSymbolic.ml @@ -2,6 +2,7 @@ open Types open TypesUtils open Expressions open Values +open LlbcAst open SymbolicAst let mk_mplace (p : place) (ctx : Contexts.eval_ctx) : mplace = @@ -92,7 +93,9 @@ let synthesize_symbolic_expansion_no_branching (sv : symbolic_value) synthesize_symbolic_expansion sv place [ Some see ] el let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx) + (sg : fun_sig option) (regions_hierarchy : region_var_groups) (abstractions : AbstractionId.id list) (generics : generic_args) + (trait_method_generics : (generic_args * trait_instance_id) option) (args : typed_value list) (args_places : mplace option list) (dest : symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = @@ -102,8 +105,11 @@ let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx) { call_id; ctx; + sg; + regions_hierarchy; abstractions; generics; + trait_method_generics; args; dest; args_places; @@ -118,29 +124,32 @@ let synthesize_global_eval (gid : GlobalDeclId.id) (dest : symbolic_value) Option.map (fun e -> EvalGlobal (gid, dest, e)) e let synthesize_regular_function_call (fun_id : fun_id_or_trait_method_ref) - (call_id : FunCallId.id) (ctx : Contexts.eval_ctx) + (call_id : FunCallId.id) (ctx : Contexts.eval_ctx) (sg : fun_sig) + (regions_hierarchy : region_var_groups) (abstractions : AbstractionId.id list) (generics : generic_args) + (trait_method_generics : (generic_args * trait_instance_id) option) (args : typed_value list) (args_places : mplace option list) (dest : symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = synthesize_function_call (Fun (fun_id, call_id)) - ctx abstractions generics args args_places dest dest_place e + ctx (Some sg) regions_hierarchy abstractions generics trait_method_generics + args args_places dest dest_place e let synthesize_unary_op (ctx : Contexts.eval_ctx) (unop : unop) (arg : typed_value) (arg_place : mplace option) (dest : symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = let generics = empty_generic_args in - synthesize_function_call (Unop unop) ctx [] generics [ arg ] [ arg_place ] - dest dest_place e + synthesize_function_call (Unop unop) ctx None [] [] generics None [ arg ] + [ arg_place ] dest dest_place e let synthesize_binary_op (ctx : Contexts.eval_ctx) (binop : binop) (arg0 : typed_value) (arg0_place : mplace option) (arg1 : typed_value) (arg1_place : mplace option) (dest : symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = let generics = empty_generic_args in - synthesize_function_call (Binop binop) ctx [] generics [ arg0; arg1 ] - [ arg0_place; arg1_place ] dest dest_place e + synthesize_function_call (Binop binop) ctx None [] [] generics None + [ arg0; arg1 ] [ arg0_place; arg1_place ] dest dest_place e let synthesize_end_abstraction (ctx : Contexts.eval_ctx) (abs : abs) (e : expression option) : expression option = diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 221d4e73..55a94302 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -6,7 +6,6 @@ open LlbcAst open Contexts module SA = SymbolicAst module Micro = PureMicroPasses -open PureUtils open TranslateCore (** The local logger *) @@ -43,8 +42,8 @@ let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : fun_decl) : TODO: maybe we should introduce a record for this. *) let translate_function_to_pure (trans_ctx : trans_ctx) - (fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdNotLoopMap.t) - (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : fun_decl) : + (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) + (fun_dsigs : Pure.decomposed_fun_sig FunDeclId.Map.t) (fdef : fun_decl) : pure_fun_translation_no_loops = (* Debug *) log#ldebug @@ -58,13 +57,9 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (* Convert the symbolic ASTs to pure ASTs: *) (* Initialize the context *) - let forward_sig = - RegularFunIdNotLoopMap.find (FRegular def_id, None) fun_sigs - in let sv_to_var = SymbolicValueId.Map.empty in let var_counter = Pure.VarId.generator_zero in let state_var, var_counter = Pure.VarId.fresh var_counter in - let back_state_var, var_counter = Pure.VarId.fresh var_counter in let fuel0, var_counter = Pure.VarId.fresh var_counter in let fuel, var_counter = Pure.VarId.fresh var_counter in let calls = FunCallId.Map.empty in @@ -78,7 +73,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) | RecGroup _ -> Some tid) (TypeDeclId.Map.bindings trans_ctx.type_ctx.type_decls_groups)) in - let type_context = + let type_ctx = { SymbolicToPure.type_infos = trans_ctx.type_ctx.type_infos; llbc_type_decls = trans_ctx.type_ctx.type_decls; @@ -86,15 +81,14 @@ let translate_function_to_pure (trans_ctx : trans_ctx) recursive_decls = recursive_type_decls; } in - let fun_context = + let fun_ctx = { SymbolicToPure.llbc_fun_decls = trans_ctx.fun_ctx.fun_decls; - fun_sigs; fun_infos = trans_ctx.fun_ctx.fun_infos; regions_hierarchies = trans_ctx.fun_ctx.regions_hierarchies; } in - let global_context = + let global_ctx = { SymbolicToPure.llbc_global_decls = trans_ctx.global_ctx.global_decls } in @@ -126,31 +120,51 @@ let translate_function_to_pure (trans_ctx : trans_ctx) !m in + let sg = + SymbolicToPure.translate_fun_sig_from_decl_to_decomposed trans_ctx fdef + in + + let regions_hierarchy = + LlbcAstUtils.FunIdMap.find (FRegular def_id) fun_ctx.regions_hierarchies + in + + let var_counter, back_state_vars = + if !Config.return_back_funs then (var_counter, []) + else + List.fold_left_map + (fun var_counter (region_vars : region_var_group) -> + let gid = region_vars.id in + let var, var_counter = Pure.VarId.fresh var_counter in + (var_counter, (gid, var))) + var_counter regions_hierarchy + in + let back_state_vars = RegionGroupId.Map.of_list back_state_vars in + let ctx = { + decls_ctx = trans_ctx; SymbolicToPure.bid = None; - (* Dummy for now *) - sg = forward_sig.sg; - fwd_sg = forward_sig.sg; + sg; + fun_dsigs; (* Will need to be updated for the backward functions *) sv_to_var; - var_counter; + var_counter = ref var_counter; state_var; - back_state_var; + back_state_vars; fuel0; fuel; - type_context; - fun_context; - global_context; + type_ctx; + fun_ctx; + global_ctx; trait_decls_ctx = trans_ctx.trait_decls_ctx.trait_decls; trait_impls_ctx = trans_ctx.trait_impls_ctx.trait_impls; fun_decl = fdef; forward_inputs = []; - (* Empty for now *) - backward_inputs = RegionGroupId.Map.empty; - (* Empty for now *) - backward_outputs = RegionGroupId.Map.empty; - loop_backward_outputs = None; + (* Initialized just below *) + backward_inputs_no_state = RegionGroupId.Map.empty; + (* Initialized just below *) + backward_inputs_with_state = RegionGroupId.Map.empty; + backward_outputs = None; (* Empty for now *) calls; abstractions; @@ -180,6 +194,37 @@ let translate_function_to_pure (trans_ctx : trans_ctx) | _ -> raise (Failure "Unreachable") in + (* Add the backward inputs *) + let ctx, backward_inputs_no_state, backward_inputs_with_state = + if !Config.return_back_funs then (ctx, [], []) + else + let ctx, inputs_no_with_state = + List.fold_left_map + (fun ctx (region_vars : region_var_group) -> + let gid = region_vars.id in + let back_sg = RegionGroupId.Map.find gid sg.back_sg in + let ctx, no_state = + SymbolicToPure.fresh_vars back_sg.inputs_no_state ctx + in + let ctx, with_state = + SymbolicToPure.fresh_vars back_sg.inputs ctx + in + (ctx, ((gid, no_state), (gid, with_state)))) + ctx regions_hierarchy + in + let inputs_no_state, inputs_with_state = + List.split inputs_no_with_state + in + (ctx, inputs_no_state, inputs_with_state) + in + let backward_inputs_no_state = + RegionGroupId.Map.of_list backward_inputs_no_state + in + let backward_inputs_with_state = + RegionGroupId.Map.of_list backward_inputs_with_state + in + let ctx = { ctx with backward_inputs_no_state; backward_inputs_with_state } in + (* Translate the forward function *) let pure_forward = match symbolic_trans with @@ -187,7 +232,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) | Some (_, ast) -> SymbolicToPure.translate_fun_decl ctx (Some ast) in - (* Translate the backward functions *) + (* Translate the backward functions, if we split the forward and backward functions *) let translate_backward (rg : region_var_group) : Pure.fun_decl = (* For the backward inputs/outputs initialization: we use the fact that * there are no nested borrows for now, and so that the region groups @@ -197,77 +242,20 @@ let translate_function_to_pure (trans_ctx : trans_ctx) match symbolic_trans with | None -> - (* Initialize the context - note that the ret_ty is not really - * useful as we don't translate a body *) - let backward_sg = - RegularFunIdNotLoopMap.find (FRegular def_id, Some back_id) fun_sigs - in - let ctx = { ctx with bid = Some back_id; sg = backward_sg.sg } in - + (* Initialize the context *) + let ctx = { ctx with bid = Some back_id } in (* Translate *) SymbolicToPure.translate_fun_decl ctx None | Some (_, symbolic) -> - (* Finish initializing the context by adding the additional input - variables required by the backward function. - *) - let backward_sg = - RegularFunIdNotLoopMap.find (FRegular def_id, Some back_id) fun_sigs - in - (* We need to ignore the forward inputs, and the state input (if there is) *) - let backward_inputs = - let sg = backward_sg.sg in - (* We need to ignore the forward state and the backward state *) - let num_forward_inputs = - sg.info.num_fwd_inputs_with_fuel_with_state - in - let num_back_inputs = Option.get sg.info.num_back_inputs_no_state in - Collections.List.subslice sg.inputs num_forward_inputs - (num_forward_inputs + num_back_inputs) - in - (* As we forbid nested borrows, the additional inputs for the backward - * functions come from the borrows in the return value of the rust function: - * we thus use the name "ret" for those inputs *) - let backward_inputs = - List.map (fun ty -> (Some "ret", ty)) backward_inputs - in - let ctx, backward_inputs = - SymbolicToPure.fresh_vars backward_inputs ctx - in - (* The outputs for the backward functions, however, come from borrows - * present in the input values of the rust function: for those we reuse - * the names of the input values. *) - let backward_outputs = - List.combine backward_sg.output_names backward_sg.sg.doutputs - in - let ctx, backward_outputs = - SymbolicToPure.fresh_vars backward_outputs ctx - in - let backward_inputs = - RegionGroupId.Map.singleton back_id backward_inputs - in - let backward_outputs = - RegionGroupId.Map.singleton back_id backward_outputs - in - - (* Put everything in the context *) - let ctx = - { - ctx with - bid = Some back_id; - sg = backward_sg.sg; - backward_inputs; - backward_outputs; - } - in - + (* Initialize the context *) + let ctx = { ctx with bid = Some back_id } in (* Translate *) SymbolicToPure.translate_fun_decl ctx (Some symbolic) in - let regions_hierarchy = - LlbcAstUtils.FunIdMap.find (FRegular fdef.def_id) - fun_context.regions_hierarchies + let pure_backwards = + if !Config.return_back_funs then [] + else List.map translate_backward regions_hierarchy in - let pure_backwards = List.map translate_backward regions_hierarchy in (* Return *) (pure_forward, pure_backwards) @@ -294,36 +282,21 @@ let translate_crate_to_pure (crate : crate) : (List.map (fun (def : Pure.type_decl) -> (def.def_id, def)) type_decls) in - (* Translate all the function *signatures* *) - let assumed_sigs = - List.map - (fun (info : Assumed.assumed_fun_info) -> - ( FAssumed info.fun_id, - List.map (fun _ -> None) info.fun_sig.inputs, - info.fun_sig )) - Assumed.assumed_fun_infos - in - let local_sigs = - List.map - (fun (fdef : fun_decl) -> - let input_names = - match fdef.body with - | None -> List.map (fun _ -> None) fdef.signature.inputs - | Some body -> - List.map - (fun (v : var) -> v.name) - (LlbcAstUtils.fun_body_get_input_vars body) - in - (FRegular fdef.def_id, input_names, fdef.signature)) - (FunDeclId.Map.values crate.fun_decls) + (* Compute the decomposed fun sigs for the whole crate *) + let fun_dsigs = + FunDeclId.Map.of_list + (List.map + (fun (fdef : LlbcAst.fun_decl) -> + ( fdef.def_id, + SymbolicToPure.translate_fun_sig_from_decl_to_decomposed trans_ctx + fdef )) + (FunDeclId.Map.values crate.fun_decls)) in - let sigs = List.append assumed_sigs local_sigs in - let fun_sigs = SymbolicToPure.translate_fun_signatures trans_ctx sigs in (* Translate all the *transparent* functions *) let pure_translations = List.map - (translate_function_to_pure trans_ctx fun_sigs type_decls_map) + (translate_function_to_pure trans_ctx type_decls_map fun_dsigs) (FunDeclId.Map.values crate.fun_decls) in @@ -1030,7 +1003,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) : List.map (fun { fwd; _ } -> let fwd_f = - if fwd.f.Pure.signature.info.effect_info.is_rec then + if fwd.f.Pure.signature.fwd_info.effect_info.is_rec then [ (fwd.f.def_id, None) ] else [] in @@ -1198,7 +1171,13 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) : let exe_dir = Filename.dirname Sys.argv.(0) in let primitives_src_dest = match !Config.backend with - | FStar -> Some ("/backends/fstar/Primitives.fst", "Primitives.fst") + | FStar -> + let src = + if !Config.return_back_funs then + "/backends/fstar/merge/Primitives.fst" + else "/backends/fstar/split/Primitives.fst" + in + Some (src, "Primitives.fst") | Coq -> Some ("/backends/coq/Primitives.v", "Primitives.v") | Lean -> None | HOL4 -> None |