diff options
author | Son Ho | 2023-12-21 14:37:43 +0100 |
---|---|---|
committer | Son Ho | 2023-12-21 14:37:43 +0100 |
commit | 8835d87df111d09122267fadc9a32f16b52d234a (patch) | |
tree | 43a1fd0e3ec0e8b94834744396b86bbd3084c888 /compiler | |
parent | e90b23a0d42e2ea6805c88d6eaa4f9e5370a1dc1 (diff) |
Make good progress on merging the fwd/back functions
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/Config.ml | 2 | ||||
-rw-r--r-- | compiler/Extract.ml | 4 | ||||
-rw-r--r-- | compiler/InterpreterStatements.ml | 47 | ||||
-rw-r--r-- | compiler/PureUtils.ml | 19 | ||||
-rw-r--r-- | compiler/SymbolicAst.ml | 4 | ||||
-rw-r--r-- | compiler/SymbolicToPure.ml | 266 | ||||
-rw-r--r-- | compiler/SynthesizeSymbolic.ml | 16 | ||||
-rw-r--r-- | compiler/Translate.ml | 3 |
8 files changed, 274 insertions, 87 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml index c8f3ed58..b8af6c6d 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -153,7 +153,7 @@ let loop_fixed_point_max_num_iters = 2 return (x :: ls))) ]} *) -let return_back_funs = ref false +let return_back_funs = ref true (** Forbids using field projectors for structures. diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 3429cd11..46cf8c4a 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -1332,7 +1332,9 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter) in let fwd_back_comment = match def.back_id with - | None -> [ "forward function" ] + | None -> + if !Config.return_back_funs then [ "function definition" ] + else [ "forward function" ] | Some id -> (* Check if there is only one backward function, and no forward function *) if (not keep_fwd) && num_backs = 1 then diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml index da617c64..94c65b5c 100644 --- a/compiler/InterpreterStatements.ml +++ b/compiler/InterpreterStatements.ml @@ -728,7 +728,12 @@ let create_push_abstractions_from_abs_region_groups to a trait clause but directly to the method provided in the trait declaration. *) let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) - : fun_id_or_trait_method_ref * generic_args * fun_decl * inst_fun_sig = + : + fun_id_or_trait_method_ref + * generic_args + * fun_decl + * region_var_groups + * inst_fun_sig = match call.func with | FnOpMove _ -> (* Closure case: TODO *) @@ -753,7 +758,7 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) instantiate_fun_sig ctx func.generics tr_self def.signature regions_hierarchy in - (func.func, func.generics, def, inst_sg) + (func.func, func.generics, def, regions_hierarchy, inst_sg) | FunId (FAssumed _) -> (* Unreachable: must be a transparent function *) raise (Failure "Unreachable") @@ -806,7 +811,7 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) we also need to update the generics. *) let func = FunId fid in - (func, generics, method_def, inst_sg) + (func, generics, method_def, regions_hierarchy, inst_sg) | None -> (* If not found, lookup the methods provided by the trait *declaration* (remember: for now, we forbid overriding provided methods) *) @@ -860,7 +865,11 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) instantiate_fun_sig ctx all_generics tr_self method_def.signature regions_hierarchy in - (func.func, func.generics, method_def, inst_sg)) + ( func.func, + func.generics, + method_def, + regions_hierarchy, + inst_sg )) | _ -> (* We are using a local clause - we lookup the trait decl *) let trait_decl = @@ -891,7 +900,8 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) instantiate_fun_sig ctx generics tr_self method_def.signature regions_hierarchy in - (func.func, func.generics, method_def, inst_sg))) + (func.func, func.generics, method_def, regions_hierarchy, inst_sg) + )) (** Evaluate a statement *) let rec eval_statement (config : config) (st : statement) : st_cm_fun = @@ -1277,14 +1287,14 @@ and eval_transparent_function_call_concrete (config : config) and eval_transparent_function_call_symbolic (config : config) (call : call) : st_cm_fun = fun cf ctx -> - let func, generics, def, inst_sg = + let func, generics, def, regions_hierarchy, inst_sg = eval_transparent_function_call_symbolic_inst call ctx in (* Sanity check *) assert (List.length call.args = List.length def.signature.inputs); (* Evaluate the function call *) - eval_function_call_symbolic_from_inst_sig config func inst_sg generics - call.args call.dest cf ctx + eval_function_call_symbolic_from_inst_sig config func def.signature + regions_hierarchy inst_sg generics call.args call.dest cf ctx (** Evaluate a function call in symbolic mode by using the function signature. @@ -1298,7 +1308,8 @@ and eval_transparent_function_call_symbolic (config : config) (call : call) : trait ref as input. *) and eval_function_call_symbolic_from_inst_sig (config : config) - (fid : fun_id_or_trait_method_ref) (inst_sg : inst_fun_sig) + (fid : fun_id_or_trait_method_ref) (sg : fun_sig) + (regions_hierarchy : region_var_groups) (inst_sg : inst_fun_sig) (generics : generic_args) (args : operand list) (dest : place) : st_cm_fun = fun cf ctx -> log#ldebug @@ -1378,8 +1389,8 @@ and eval_function_call_symbolic_from_inst_sig (config : config) let expr = cf ctx in (* Synthesize the symbolic AST *) - S.synthesize_regular_function_call fid call_id ctx abs_ids generics args - args_places ret_spc dest_place expr + S.synthesize_regular_function_call fid call_id ctx sg regions_hierarchy + abs_ids generics args args_places ret_spc dest_place expr in let cc = comp cc cf_call in @@ -1468,7 +1479,7 @@ and eval_assumed_function_call_symbolic (config : config) (fid : assumed_fun_id) (* In symbolic mode, the behaviour of a function call is completely defined * by the signature of the function: we thus simply generate correctly * instantiated signatures, and delegate the work to an auxiliary function *) - let inst_sig = + let sg, regions_hierarchy, inst_sig = match fid with | BoxFree -> (* Should have been treated above *) @@ -1480,14 +1491,16 @@ and eval_assumed_function_call_symbolic (config : config) (fid : assumed_fun_id) in (* There shouldn't be any reference to Self *) let tr_self = UnknownTrait __FUNCTION__ in - instantiate_fun_sig ctx generics tr_self - (Assumed.get_assumed_fun_sig fid) - regions_hierarchy + let sg = Assumed.get_assumed_fun_sig fid in + let inst_sg = + instantiate_fun_sig ctx generics tr_self sg regions_hierarchy + in + (sg, regions_hierarchy, inst_sg) in (* Evaluate the function call *) - eval_function_call_symbolic_from_inst_sig config (FunId (FAssumed fid)) - inst_sig generics args dest cf ctx + eval_function_call_symbolic_from_inst_sig config (FunId (FAssumed fid)) sg + regions_hierarchy inst_sig generics args dest cf ctx (** Evaluate a statement seen as a function body *) and eval_function_body (config : config) (body : statement) : st_cm_fun = diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 6e86578c..6579e84c 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -321,14 +321,23 @@ let destruct_apps (e : texpression) : texpression * texpression list = (** Make an [App (app, arg)] expression *) let mk_app (app : texpression) (arg : texpression) : texpression = + let raise_or_return msg = + if !Config.fail_hard then raise (Failure msg) + else + let e = App (app, arg) in + (* Dummy type - TODO: introduce an error type *) + let ty = app.ty in + { e; ty } + in match app.ty with | TArrow (ty0, ty1) -> (* Sanity check *) - assert (ty0 = arg.ty); - let e = App (app, arg) in - let ty = ty1 in - { e; ty } - | _ -> raise (Failure "Expected an arrow type") + if ty0 <> arg.ty then raise_or_return "App: wrong input type" + else + let e = App (app, arg) in + let ty = ty1 in + { e; ty } + | _ -> raise_or_return "Expected an arrow type" (** The reverse of {!destruct_apps} *) let mk_apps (app : texpression) (args : texpression list) : texpression = diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml index 53f99b7f..54d207d9 100644 --- a/compiler/SymbolicAst.ml +++ b/compiler/SymbolicAst.ml @@ -42,7 +42,11 @@ type call = { evaluated). We need it to compute the translated values for shared borrows (we need to perform lookups). *) + sg : fun_sig option; + (** The uninstantiated function signature, if this is not a unop/binop *) + regions_hierarchy : region_var_groups; abstractions : AbstractionId.id list; + (** The region abstractions introduced upon calling the function *) generics : generic_args; args : typed_value list; args_places : mplace option list; (** Meta information *) diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index e2787271..1ce6c698 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -67,6 +67,18 @@ type call_info = { Those inputs include the fuel and the state, if pertinent. *) + back_funs : texpression RegionGroupId.Map.t option; + (** If we do not split between the forward/backward functions: the + variables we introduced for the backward functions. + + Example: + {[ + let x, back = Vec.index_mut n v in + ^^^^ + here + ... + ]} + *) } [@@deriving show] @@ -118,6 +130,8 @@ type loop_info = { (** Body synthesis context *) type bs_ctx = { + (* TODO: there are a lot of duplications with the various decls ctx *) + decls_ctx : C.decls_ctx; type_ctx : type_ctx; fun_ctx : fun_ctx; global_ctx : global_ctx; @@ -757,17 +771,27 @@ let translate_fun_id_or_trait_method_ref (ctx : bs_ctx) TraitMethod (trait_ref, method_name, fun_decl_id) let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) - (args : texpression list) (ctx : bs_ctx) : bs_ctx = + (args : texpression list) + (back_funs : texpression RegionGroupId.Map.t option) (ctx : bs_ctx) : bs_ctx + = let calls = ctx.calls in assert (not (V.FunCallId.Map.mem call_id calls)); - let info = { forward; forward_inputs = args } in + let info = { forward; forward_inputs = args; back_funs } in let calls = V.FunCallId.Map.add call_id info calls in { ctx with calls } -(** [back_args]: the *additional* list of inputs received by the backward function *) -let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) - (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx) - : bs_ctx * fun_or_op_id = +(** [inherit_args]: the list of inputs inherited from the forward function and + the ancestors backward functions, if pertinent. + [back_args]: the *additional* list of inputs received by the backward function, + including the state. + + Returns the updated context and the expression corresponding to the function. + *) +let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info) + (call_id : V.FunCallId.id) (back_id : T.RegionGroupId.id) + (inherited_args : texpression list) (back_args : texpression list) + (generics : generic_args) (output_ty : ty) (ctx : bs_ctx) : + bs_ctx * texpression = (* Insert the abstraction in the call informations *) let info = V.FunCallId.Map.find call_id ctx.calls in let calls = V.FunCallId.Map.add call_id info ctx.calls in @@ -777,16 +801,31 @@ let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) let abstractions = V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions in - (* Retrieve the fun_id *) - let fun_id = - match info.forward.call_id with - | S.Fun (fid, _) -> - let fid = translate_fun_id_or_trait_method_ref ctx fid in - Fun (FromLlbc (fid, None, Some back_id)) - | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable") + (* Compute the expression corresponding to the function *) + let func = + if !Config.return_back_funs then + (* Lookup the variable introduced for the backward function *) + RegionGroupId.Map.find back_id (Option.get info.back_funs) + else + (* Retrieve the fun_id *) + let fun_id = + match info.forward.call_id with + | S.Fun (fid, _) -> + let fid = translate_fun_id_or_trait_method_ref ctx fid in + Fun (FromLlbc (fid, None, Some back_id)) + | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable") + in + let args = List.append inherited_args back_args in + let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in + let ret_ty = + if effect_info.can_fail then mk_result_ty output_ty else output_ty + in + let func_ty = mk_arrows input_tys ret_ty in + let func = { id = FunOrOp fun_id; generics } in + { e = Qualif func; ty = func_ty } in (* Update the context and return *) - ({ ctx with calls; abstractions }, fun_id) + ({ ctx with calls; abstractions }, func) (** List the ancestors of an abstraction *) let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs) @@ -878,15 +917,12 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) We use [bid] ("backward function id") only if we split the forward and the backward functions. *) -let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) - (fun_id : A.fun_id) (sg : A.fun_sig) (input_names : string option list) : - decomposed_fun_sig = +let translate_fun_sig_with_regions_hierarchy_to_decomposed + (decls_ctx : C.decls_ctx) (fun_id : A.fun_id_or_trait_method_ref) + (regions_hierarchy : T.region_var_groups) (sg : A.fun_sig) + (input_names : string option list) : decomposed_fun_sig = let fun_infos = decls_ctx.fun_ctx.fun_infos in let type_infos = decls_ctx.type_ctx.type_infos in - (* Retrieve the list of parent backward functions *) - let regions_hierarchy = - FunIdMap.find fun_id decls_ctx.fun_ctx.regions_hierarchies - in (* We need an evaluation context to normalize the types (to normalize the associated types, etc. - for instance it may happen that the types refer to the types associated to a trait ref, but where the trait ref @@ -915,9 +951,7 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) in (* Is the forward function stateful, and can it fail? *) - let fwd_effect_info = - get_fun_effect_info fun_infos (FunId fun_id) None None - in + let fwd_effect_info = get_fun_effect_info fun_infos fun_id None None in (* Compute the forward inputs *) let fwd_fuel = mk_fuel_input_ty_as_list fwd_effect_info in let fwd_inputs_no_fuel_no_state = @@ -1030,7 +1064,7 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) RegionGroupId.id * back_sg_info = let gid = rg.id in let back_effect_info = - get_fun_effect_info fun_infos (FunId fun_id) None (Some gid) + get_fun_effect_info fun_infos fun_id None (Some gid) in let inputs_no_state = translate_back_inputs_for_gid gid in let inputs_no_state = @@ -1072,6 +1106,16 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) fwd_info; } +let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) + (fun_id : FunDeclId.id) (sg : A.fun_sig) (input_names : string option list) + : decomposed_fun_sig = + (* Retrieve the list of parent backward functions *) + let regions_hierarchy = + FunIdMap.find (FRegular fun_id) decls_ctx.fun_ctx.regions_hierarchies + in + translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx + (FunId (FRegular fun_id)) regions_hierarchy sg input_names + let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty = let output = @@ -1090,6 +1134,40 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list = mk_arrows inputs output) (RegionGroupId.Map.values dsg.back_sg) +(** Return the pure signature of a backward function, in the case the + forward/backward functions are merged (i.e., the forward functions + return the backward functions). + + TODO: merge with {!translate_fun_sig_from_decomposed} + *) +let translate_ret_back_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) + (gid : RegionGroupId.id) : fun_sig = + assert !Config.return_back_funs; + + let generics = dsg.generics in + let llbc_generics = dsg.llbc_generics in + let preds = dsg.preds in + (* Compute the effects info *) + let fwd_info = dsg.fwd_info in + let back_effect_info = + RegionGroupId.Map.of_list + (List.map + (fun ((gid, info) : RegionGroupId.id * back_sg_info) -> + (gid, info.effect_info)) + (RegionGroupId.Map.bindings dsg.back_sg)) + in + (* Two cases depending on whether we split the forward/backward functions + or not *) + let mk_output_ty = mk_output_ty_from_effect_info in + + let back_sg = RegionGroupId.Map.find gid dsg.back_sg in + let effect_info = back_sg.effect_info in + (* Do not prepend the forward inputs *) + let inputs = List.map snd back_sg.inputs in + let output = mk_simpl_tuple_ty back_sg.outputs in + let output = mk_output_ty effect_info output in + { generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info } + let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) (gid : RegionGroupId.id option) : fun_sig = let generics = dsg.generics in @@ -1774,7 +1852,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in (* Retrieve the function id, and register the function call in the context * if necessary. *) - let ctx, fun_id, effect_info, args, out_state = + let ctx, fun_id, effect_info, args, back_funs, out_state = match call.call_id with | S.Fun (fid, call_id) -> (* Regular function call *) @@ -1798,9 +1876,80 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var) else (List.concat [ fuel; args ], ctx, None) in + (* If we do not split the forward/backward functions: generate the + variables for the backward functions returned by the forward + function. *) + let ctx, back_funs_map, back_funs = + if !Config.return_back_funs then + (* We need to compute the signatures of the backward functions. *) + let sg = Option.get call.sg in + let decls_ctx = ctx.decls_ctx in + let dsg = + translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx + fid call.regions_hierarchy sg + (List.map (fun _ -> None) sg.inputs) + in + let gids = + List.map + (fun (g : T.region_var_group) -> g.id) + call.regions_hierarchy + in + let back_sgs = + List.map (translate_ret_back_fun_sig_from_decomposed dsg) gids + in + (* Introduce variables for the backward functions *) + let back_tys = + List.map + (fun (sg : fun_sig) -> mk_arrows sg.inputs sg.output) + back_sgs + in + (* Compute a proper basename for the variables *) + let back_fun_name = + let name = + match fid with + | FunId (FAssumed fid) -> ( + match fid with + | BoxNew -> "box_new" + | BoxFree -> "box_free" + | ArrayRepeat -> "array_repeat" + | ArrayIndexShared -> "index_shared" + | ArrayIndexMut -> "index_mut" + | ArrayToSliceShared -> "to_slice_shared" + | ArrayToSliceMut -> "to_slice_mut" + | SliceIndexShared -> "index_shared" + | SliceIndexMut -> "index_mut") + | FunId (FRegular fid) | TraitMethod (_, _, fid) -> ( + let decl = + FunDeclId.Map.find fid ctx.fun_ctx.llbc_fun_decls + in + match Collections.List.last decl.name with + | PeIdent (s, _) -> s + | PeImpl _ -> + (* We shouldn't get there *) + raise (Failure "Unexpected")) + in + name ^ "_back" + in + let ctx, back_vars = + fresh_vars + (List.map (fun ty -> (Some back_fun_name, ty)) back_tys) + ctx + in + let back_funs = + List.map (fun v -> mk_typed_pattern_from_var v None) back_vars + in + let back_funs_map = + RegionGroupId.Map.of_list + (List.combine gids (List.map mk_texpression_from_var back_vars)) + in + (ctx, Some back_funs_map, back_funs) + else (ctx, None, []) + in (* Register the function call *) - let ctx = bs_ctx_register_forward_call call_id call args ctx in - (ctx, func, effect_info, args, out_state) + let ctx = + bs_ctx_register_forward_call call_id call args back_funs_map ctx + in + (ctx, func, effect_info, args, back_funs, out_state) | S.Unop E.Not -> let effect_info = { @@ -1811,7 +1960,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop Not, effect_info, args, None) + (ctx, Unop Not, effect_info, args, [], None) | S.Unop E.Neg -> ( match args with | [ arg ] -> @@ -1827,7 +1976,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop (Neg int_ty), effect_info, args, None) + (ctx, Unop (Neg int_ty), effect_info, args, [], None) | _ -> raise (Failure "Unreachable")) | S.Unop (E.Cast cast_kind) -> ( match cast_kind with @@ -1842,7 +1991,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None) + (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, [], None) | CastFnPtr _ -> raise (Failure "TODO: function casts")) | S.Binop binop -> ( match args with @@ -1862,11 +2011,12 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Binop (binop, int_ty0), effect_info, args, None) + (ctx, Binop (binop, int_ty0), effect_info, args, [], None) | _ -> raise (Failure "Unreachable")) in let dest_v = let dest = mk_typed_pattern_from_var dest dest_mplace in + let dest = mk_simpl_tuple_pattern (dest :: back_funs) in match out_state with | None -> dest | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ] @@ -2026,9 +2176,11 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) else ([], ctx, None) in (* Concatenate all the inpus *) - let inputs = - List.concat [ fwd_inputs; back_ancestors_inputs; back_inputs; back_state ] + let inherited_inputs = + if !Config.return_back_funs then [] + else List.concat [ fwd_inputs; back_ancestors_inputs ] in + let back_inputs = List.append back_inputs back_state in (* Retrieve the values given back by this function: those are the output * values. We rely on the fact that there are no nested borrows to use the * meta-place information from the input values given to the forward function @@ -2046,43 +2198,43 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) | Some nstate -> mk_simpl_tuple_pattern [ nstate; output ] in (* Retrieve the function id, and register the function call in the context - * if necessary *) + if necessary.Arith_status *) let ctx, func = - bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx + bs_ctx_register_backward_call abs effect_info call_id rg_id inherited_inputs + back_inputs generics output.ty ctx in (* Translate the next expression *) let next_e = translate_expression e ctx in (* Put everything together *) + let inputs = List.append inherited_inputs back_inputs in let args_mplaces = List.map (fun _ -> None) inputs in let args = List.map (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) (List.combine inputs args_mplaces) in - let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in - let ret_ty = - if effect_info.can_fail then mk_result_ty output.ty else output.ty - in - let func_ty = mk_arrows input_tys ret_ty in - let func = { id = FunOrOp func; generics } in - let func = { e = Qualif func; ty = func_ty } in let call = mk_apps func args in (* **Optimization**: - * ================= - * We do a small optimization here: if the backward function doesn't - * have any output, we don't introduce any function call. - * See the comment in {!Config.filter_useless_monadic_calls}. - * - * TODO: use an option to disallow backward functions from updating the state. - * TODO: a backward function which only gives back shared borrows shouldn't - * update the state (state updates should only be used for mutable borrows, - * with objects like Rc for instance). - *) - if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None then ( + ================= + We do a small optimization here if we split the forward/backward functions. + If the backward function doesn't have any output, we don't introduce any function + call. + See the comment in {!Config.filter_useless_monadic_calls}. + + TODO: use an option to disallow backward functions from updating the state. + TODO: a backward function which only gives back shared borrows shouldn't + update the state (state updates should only be used for mutable borrows, + with objects like Rc for instance). + *) + if + (not !Config.return_back_funs) + && !Config.filter_useless_monadic_calls + && outputs = [] && nstate = None + then ( (* No outputs - we do a small sanity check: the backward function - * should have exactly the same number of inputs as the forward: - * this number can be different only if the forward function returned - * a value containing mutable borrows, which can't be the case... *) + should have exactly the same number of inputs as the forward: + this number can be different only if the forward function returned + a value containing mutable borrows, which can't be the case... *) assert (List.length inputs = List.length fwd_inputs); next_e) else mk_let effect_info.can_fail output call next_e diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml index efcf001a..4ec7524b 100644 --- a/compiler/SynthesizeSymbolic.ml +++ b/compiler/SynthesizeSymbolic.ml @@ -2,6 +2,7 @@ open Types open TypesUtils open Expressions open Values +open LlbcAst open SymbolicAst let mk_mplace (p : place) (ctx : Contexts.eval_ctx) : mplace = @@ -92,6 +93,7 @@ let synthesize_symbolic_expansion_no_branching (sv : symbolic_value) synthesize_symbolic_expansion sv place [ Some see ] el let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx) + (sg : fun_sig option) (regions_hierarchy : region_var_groups) (abstractions : AbstractionId.id list) (generics : generic_args) (args : typed_value list) (args_places : mplace option list) (dest : symbolic_value) (dest_place : mplace option) (e : expression option) @@ -102,6 +104,8 @@ let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx) { call_id; ctx; + sg; + regions_hierarchy; abstractions; generics; args; @@ -118,28 +122,30 @@ let synthesize_global_eval (gid : GlobalDeclId.id) (dest : symbolic_value) Option.map (fun e -> EvalGlobal (gid, dest, e)) e let synthesize_regular_function_call (fun_id : fun_id_or_trait_method_ref) - (call_id : FunCallId.id) (ctx : Contexts.eval_ctx) + (call_id : FunCallId.id) (ctx : Contexts.eval_ctx) (sg : fun_sig) + (regions_hierarchy : region_var_groups) (abstractions : AbstractionId.id list) (generics : generic_args) (args : typed_value list) (args_places : mplace option list) (dest : symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = synthesize_function_call (Fun (fun_id, call_id)) - ctx abstractions generics args args_places dest dest_place e + ctx (Some sg) regions_hierarchy abstractions generics args args_places dest + dest_place e let synthesize_unary_op (ctx : Contexts.eval_ctx) (unop : unop) (arg : typed_value) (arg_place : mplace option) (dest : symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = let generics = empty_generic_args in - synthesize_function_call (Unop unop) ctx [] generics [ arg ] [ arg_place ] - dest dest_place e + synthesize_function_call (Unop unop) ctx None [] [] generics [ arg ] + [ arg_place ] dest dest_place e let synthesize_binary_op (ctx : Contexts.eval_ctx) (binop : binop) (arg0 : typed_value) (arg0_place : mplace option) (arg1 : typed_value) (arg1_place : mplace option) (dest : symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = let generics = empty_generic_args in - synthesize_function_call (Binop binop) ctx [] generics [ arg0; arg1 ] + synthesize_function_call (Binop binop) ctx None [] [] generics [ arg0; arg1 ] [ arg0_place; arg1_place ] dest dest_place e let synthesize_end_abstraction (ctx : Contexts.eval_ctx) (abs : abs) diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 631a5af9..5584fb9a 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -129,7 +129,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) in let sg = - SymbolicToPure.translate_fun_sig_to_decomposed trans_ctx (FRegular def_id) + SymbolicToPure.translate_fun_sig_to_decomposed trans_ctx def_id fdef.signature input_names in @@ -151,6 +151,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) let ctx = { + decls_ctx = trans_ctx; SymbolicToPure.bid = None; sg; (* Will need to be updated for the backward functions *) |