diff options
-rw-r--r-- | compiler/Pure.ml | 17 | ||||
-rw-r--r-- | compiler/SymbolicToPure.ml | 186 | ||||
-rw-r--r-- | compiler/Translate.ml | 29 |
3 files changed, 151 insertions, 81 deletions
diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 05cdbd70..71531688 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -886,23 +886,6 @@ type inputs_info = { } [@@deriving show] -type ('a, 'b) back_info = - | SingleBack of 'a - (** Information about a single backward function, if pertinent. - - We use this variant if we split the forward and the backward functions. - *) - | AllBacks of 'b RegionGroupId.Map.t - (** Information about the various backward functions. - - We use this if we *do not* split the forward and the backward functions. - All the information is then carried by the forward function. - *) -[@@deriving show] - -type back_inputs_info = (inputs_info option, inputs_info) back_info -[@@deriving show] - (** Meta information about a function signature *) type fun_sig_info = { fwd_info : inputs_info; diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 70a4e18d..37f621e4 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -146,6 +146,7 @@ type bs_ctx = { global_ctx : global_ctx; trait_decls_ctx : trait_decls_ctx; trait_impls_ctx : trait_impls_ctx; + fun_dsigs : decomposed_fun_sig FunDeclId.Map.t; fun_decl : A.fun_decl; bid : RegionGroupId.id option; (** TODO: rename @@ -890,7 +891,7 @@ let mk_fuel_input_as_list (ctx : bs_ctx) (info : fun_effect_info) : if function_uses_fuel info then [ mk_fuel_texpression ctx.fuel ] else [] (** Small utility. *) -let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) +let compute_raw_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) (fun_id : A.fun_id_or_trait_method_ref) (lid : V.LoopId.id option) (gid : T.RegionGroupId.id option) : fun_effect_info = match fun_id with @@ -917,6 +918,22 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) is_rec = false; } +(** TODO: not very clean. *) +let get_fun_effect_info (ctx : bs_ctx) (fun_id : A.fun_id_or_trait_method_ref) + (lid : V.LoopId.id option) (gid : T.RegionGroupId.id option) : + fun_effect_info = + match fun_id with + | TraitMethod (_, _, fid) | FunId (FRegular fid) -> + let dsg = A.FunDeclId.Map.find fid ctx.fun_dsigs in + let info = + match gid with + | None -> dsg.fwd_info.effect_info + | Some gid -> (RegionGroupId.Map.find gid dsg.back_sg).effect_info + in + { info with is_rec = info.is_rec || Option.is_some lid } + | FunId (FAssumed _) -> + compute_raw_fun_effect_info ctx.fun_ctx.fun_infos fun_id lid gid + (** Translate a function signature to a decomposed function signature. Note that the function also takes a list of names for the inputs, and @@ -962,7 +979,9 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed in (* Is the forward function stateful, and can it fail? *) - let fwd_effect_info = get_fun_effect_info fun_infos fun_id None None in + let fwd_effect_info = + compute_raw_fun_effect_info fun_infos fun_id None None + in (* Compute the forward inputs *) let fwd_fuel = mk_fuel_input_ty_as_list fwd_effect_info in let fwd_inputs_no_fuel_no_state = @@ -1051,12 +1070,23 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed RegionGroupId.id * back_sg_info = let gid = rg.id in let back_effect_info = - get_fun_effect_info fun_infos fun_id None (Some gid) + compute_raw_fun_effect_info fun_infos fun_id None (Some gid) in let inputs_no_state = translate_back_inputs_for_gid gid in let inputs_no_state = List.map (fun ty -> (Some "ret", ty)) inputs_no_state in + (* We consider the backward function as stateful and potentially failing + **only if it has inputs** (for the "potentially failing": if it has + not inputs, we directly evaluate it in the body of the forward function). + *) + let back_effect_info = + { + back_effect_info with + stateful = back_effect_info.stateful && inputs_no_state <> []; + can_fail = back_effect_info.can_fail && inputs_no_state <> []; + } + in let state = if back_effect_info.stateful then [ (None, mk_state_ty) ] else [] in @@ -1140,6 +1170,19 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx (FunId (FRegular fun_id)) regions_hierarchy sg input_names +let translate_fun_sig_from_decl_to_decomposed (decls_ctx : C.decls_ctx) + (fdef : LlbcAst.fun_decl) : decomposed_fun_sig = + let input_names = + match fdef.body with + | None -> List.map (fun _ -> None) fdef.signature.inputs + | Some body -> + List.map + (fun (v : LlbcAst.var) -> v.name) + (LlbcAstUtils.fun_body_get_input_vars body) + in + translate_fun_sig_to_decomposed decls_ctx fdef.def_id fdef.signature + input_names + let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty = let output = @@ -1158,8 +1201,9 @@ let mk_back_output_ty_from_effect_info (effect_info : fun_effect_info) If a backward function has no inputs/outputs we filter it. *) -let compute_back_tys (dsg : Pure.decomposed_fun_sig) - (subst : (generic_args * trait_instance_id) option) : ty option list = +let compute_back_tys_with_info (dsg : Pure.decomposed_fun_sig) + (subst : (generic_args * trait_instance_id) option) : + (back_sg_info * ty) option list = List.map (fun (back_sg : back_sg_info) -> let effect_info = back_sg.effect_info in @@ -1185,9 +1229,13 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) in ty_substitute subst ty in - Some ty) + Some (back_sg, ty)) (RegionGroupId.Map.values dsg.back_sg) +let compute_back_tys (dsg : Pure.decomposed_fun_sig) + (subst : (generic_args * trait_instance_id) option) : ty option list = + List.map (Option.map snd) (compute_back_tys_with_info dsg subst) + (** In case we merge the fwd/back functions: compute the output type of a function, from a decomposed signature. *) let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty = @@ -1363,6 +1411,7 @@ let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var option list = in fresh_opt_vars back_vars ctx +(** IMPORTANT: do not use this one directly, but rather {!symbolic_value_to_texpression} *) let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var = match V.SymbolicValueId.Map.find_opt sv.sv_id ctx.sv_to_var with | Some v -> v @@ -1381,12 +1430,22 @@ let rec unbox_typed_value (v : V.typed_value) : V.typed_value = | _ -> raise (Failure "Unreachable")) | _ -> v -(** Translate a symbolic value *) +(** Translate a symbolic value. + + Because we do not necessarily introduce variables for the symbolic values + of (translated) type unit, it is important that we do not lookup variables + in case the symbolic value has type unit. + *) let symbolic_value_to_texpression (ctx : bs_ctx) (sv : V.symbolic_value) : texpression = (* Translate the type *) - let var = lookup_var_for_symbolic_value sv ctx in - mk_texpression_from_var var + let ty = ctx_translate_fwd_ty ctx sv.sv_ty in + (* If the type is unit, directly return unit *) + if ty_is_unit ty then mk_unit_rvalue + else + (* Otherwise lookup the variable *) + let var = lookup_var_for_symbolic_value sv ctx in + mk_texpression_from_var var (** Translate a typed value. @@ -1565,13 +1624,11 @@ and aproj_to_consumed (ctx : bs_ctx) (aproj : V.aproj) : texpression option = match aproj with | V.AEndedProjLoans (msv, []) -> (* The symbolic value was left unchanged *) - let var = lookup_var_for_symbolic_value msv ctx in - Some (mk_texpression_from_var var) + Some (symbolic_value_to_texpression ctx msv) | V.AEndedProjLoans (_, [ (mnv, child_aproj) ]) -> assert (child_aproj = AIgnoredProjBorrows); (* The symbolic value was updated *) - let var = lookup_var_for_symbolic_value mnv ctx in - Some (mk_texpression_from_var var) + Some (symbolic_value_to_texpression ctx mnv) | V.AEndedProjLoans (_, _) -> (* The symbolic value was updated, and the given back values come from sevearl * abstractions *) @@ -1940,10 +1997,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (List.combine args args_mplaces) in let dest_mplace = translate_opt_mplace call.dest_place in - let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in (* Retrieve the function id, and register the function call in the context * if necessary. *) - let ctx, fun_id, effect_info, args, back_funs, out_state = + let ctx, fun_id, effect_info, args, dest_v = match call.call_id with | S.Fun (fid, call_id) -> (* Regular function call *) @@ -1951,13 +2007,11 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let func = Fun (FromLlbc (fid_t, None, None)) in (* Retrieve the effect information about this function (can fail, * takes a state as input, etc.) *) - let effect_info = - get_fun_effect_info ctx.fun_ctx.fun_infos fid None None - in + let effect_info = get_fun_effect_info ctx fid None None in (* Depending on the function effects: - * - add the fuel - * - add the state input argument - * - generate a fresh state variable for the returned state + - add the fuel + - add the state input argument + - generate a fresh state variable for the returned state *) let args, ctx, out_state = let fuel = mk_fuel_input_as_list ctx effect_info in @@ -1970,7 +2024,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (* If we do not split the forward/backward functions: generate the variables for the backward functions returned by the forward function. *) - let ctx, back_funs_map, back_funs = + let ctx, ignore_fwd_output, back_funs_map, back_funs = if !Config.return_back_funs then (* We need to compute the signatures of the backward functions. *) let sg = Option.get call.sg in @@ -1981,7 +2035,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (List.map (fun _ -> None) sg.inputs) in let tr_self = UnknownTrait __FUNCTION__ in - let back_tys = compute_back_tys dsg (Some (generics, tr_self)) in + let back_tys = + compute_back_tys_with_info dsg (Some (generics, tr_self)) + in (* Introduce variables for the backward functions *) (* Compute a proper basename for the variables *) let back_fun_name = @@ -2016,7 +2072,18 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (fun ty -> match ty with | None -> None - | Some ty -> Some (Some back_fun_name, ty)) + | Some (back_sg, ty) -> + (* We insert a name for the variable only if the function + can fail: if it can fail, it means the call returns a backward + function. Otherwise, we it directly returns the value given + back by the backward function, which means we shouldn't + give it a name like "back..." (it doesn't make sense) *) + let name = + if back_sg.effect_info.can_fail then + Some back_fun_name + else None + in + Some (name, ty)) back_tys) ctx in @@ -2039,14 +2106,37 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let back_funs_map = RegionGroupId.Map.of_list (List.combine gids back_vars) in - (ctx, Some back_funs_map, back_funs) - else (ctx, None, []) + (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs) + else (ctx, false, None, []) + in + (* Compute the pattern for the destination *) + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + let dest = mk_typed_pattern_from_var dest dest_mplace in + let dest = + (* Here there is something subtle: as we might ignore the output + of the forward function (because it translates to unit) we doNOT + necessarily introduce in the let-binding the variable to which we + map the symbolic value which was introduced for the output of the + function call. This would be problematic if later we need to + translate this symbolic value, but we implemented + {!symbolic_value_to_texpression} so that it doesn't perform any + lookups if the symbolic value has type unit. + *) + let vars = + if ignore_fwd_output then back_funs else dest :: back_funs + in + mk_simpl_tuple_pattern vars + in + let dest = + match out_state with + | None -> dest + | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ] in (* Register the function call *) let ctx = bs_ctx_register_forward_call call_id call args back_funs_map ctx in - (ctx, func, effect_info, args, back_funs, out_state) + (ctx, func, effect_info, args, dest) | S.Unop E.Not -> let effect_info = { @@ -2057,7 +2147,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop Not, effect_info, args, [], None) + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + let dest = mk_typed_pattern_from_var dest dest_mplace in + (ctx, Unop Not, effect_info, args, dest) | S.Unop E.Neg -> ( match args with | [ arg ] -> @@ -2073,7 +2165,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop (Neg int_ty), effect_info, args, [], None) + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + let dest = mk_typed_pattern_from_var dest dest_mplace in + (ctx, Unop (Neg int_ty), effect_info, args, dest) | _ -> raise (Failure "Unreachable")) | S.Unop (E.Cast cast_kind) -> ( match cast_kind with @@ -2088,7 +2182,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, [], None) + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + let dest = mk_typed_pattern_from_var dest dest_mplace in + (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, dest) | CastFnPtr _ -> raise (Failure "TODO: function casts")) | S.Binop binop -> ( match args with @@ -2108,16 +2204,11 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Binop (binop, int_ty0), effect_info, args, [], None) + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + let dest = mk_typed_pattern_from_var dest dest_mplace in + (ctx, Binop (binop, int_ty0), effect_info, args, dest) | _ -> raise (Failure "Unreachable")) in - let dest_v = - let dest = mk_typed_pattern_from_var dest dest_mplace in - let dest = mk_simpl_tuple_pattern (dest :: back_funs) in - match out_state with - | None -> dest - | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ] - in let func = { id = FunOrOp fun_id; generics } in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in let ret_ty = @@ -2242,9 +2333,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) (* Those don't have backward functions *) raise (Failure "Unreachable") in - let effect_info = - get_fun_effect_info ctx.fun_ctx.fun_infos fun_id None (Some rg_id) - in + let effect_info = get_fun_effect_info ctx fun_id None (Some rg_id) in let generics = ctx_translate_fwd_generic_args ctx call.generics in (* Retrieve the original call and the parent abstractions *) let _forward, backwards = get_abs_ancestors ctx abs call_id in @@ -2449,8 +2538,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) to a forward call which happened earlier *) let fun_id = E.FRegular ctx.fun_decl.def_id in let effect_info = - get_fun_effect_info ctx.fun_ctx.fun_infos (FunId fun_id) (Some vloop_id) - (Some rg_id) + get_fun_effect_info ctx (FunId fun_id) (Some vloop_id) (Some rg_id) in let loop_info = LoopId.Map.find loop_id ctx.loops in let generics = loop_info.generics in @@ -2609,8 +2697,7 @@ and translate_assertion (ectx : C.eval_ctx) (v : V.typed_value) and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) (exp : S.expansion) (ctx : bs_ctx) : texpression = (* Translate the scrutinee *) - let scrutinee_var = lookup_var_for_symbolic_value sv ctx in - let scrutinee = mk_texpression_from_var scrutinee_var in + let scrutinee = symbolic_value_to_texpression ctx sv in let scrutinee_mplace = translate_opt_mplace p in (* Translate the branches *) match exp with @@ -2999,7 +3086,7 @@ and translate_forward_end (ectx : C.eval_ctx) (* Compute whether the backward expressions should be evaluated straight away or not (i.e., if we should bind them with monadic let-bindings or not). We evaluate them straight away if they can fail and have no - inputs *) + inputs. *) let evaluate_backs = List.map (fun (sg : back_sg_info) -> @@ -3098,9 +3185,7 @@ and translate_forward_end (ectx : C.eval_ctx) (* Lookup the effect info for the loop function *) let fid = E.FRegular ctx.fun_decl.def_id in - let effect_info = - get_fun_effect_info ctx.fun_ctx.fun_infos (FunId fid) None ctx.bid - in + let effect_info = get_fun_effect_info ctx (FunId fid) None ctx.bid in (* Introduce a fresh output value for the forward function *) let ctx, output_var = fresh_var None ctx.sg.fwd_output ctx in @@ -3479,8 +3564,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = | None -> None | Some body -> let effect_info = - get_fun_effect_info ctx.fun_ctx.fun_infos (FunId (FRegular def_id)) - None bid + get_fun_effect_info ctx (FunId (FRegular def_id)) None bid in let body = translate_expression body ctx in (* Add a match over the fuel, if necessary *) diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 5584fb9a..ccc46420 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -42,7 +42,8 @@ let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : fun_decl) : TODO: maybe we should introduce a record for this. *) let translate_function_to_pure (trans_ctx : trans_ctx) - (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : fun_decl) : + (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) + (fun_dsigs : Pure.decomposed_fun_sig FunDeclId.Map.t) (fdef : fun_decl) : pure_fun_translation_no_loops = (* Debug *) log#ldebug @@ -119,18 +120,8 @@ let translate_function_to_pure (trans_ctx : trans_ctx) !m in - let input_names = - match fdef.body with - | None -> List.map (fun _ -> None) fdef.signature.inputs - | Some body -> - List.map - (fun (v : var) -> v.name) - (LlbcAstUtils.fun_body_get_input_vars body) - in - let sg = - SymbolicToPure.translate_fun_sig_to_decomposed trans_ctx def_id - fdef.signature input_names + SymbolicToPure.translate_fun_sig_from_decl_to_decomposed trans_ctx fdef in let regions_hierarchy = @@ -154,6 +145,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) decls_ctx = trans_ctx; SymbolicToPure.bid = None; sg; + fun_dsigs; (* Will need to be updated for the backward functions *) sv_to_var; var_counter = ref var_counter; @@ -290,10 +282,21 @@ let translate_crate_to_pure (crate : crate) : (List.map (fun (def : Pure.type_decl) -> (def.def_id, def)) type_decls) in + (* Compute the decomposed fun sigs for the whole crate *) + let fun_dsigs = + FunDeclId.Map.of_list + (List.map + (fun (fdef : LlbcAst.fun_decl) -> + ( fdef.def_id, + SymbolicToPure.translate_fun_sig_from_decl_to_decomposed trans_ctx + fdef )) + (FunDeclId.Map.values crate.fun_decls)) + in + (* Translate all the *transparent* functions *) let pure_translations = List.map - (translate_function_to_pure trans_ctx type_decls_map) + (translate_function_to_pure trans_ctx type_decls_map fun_dsigs) (FunDeclId.Map.values crate.fun_decls) in |