From 774eb319e514a0ba02473f9c82ee9d3355de8a3d Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 22 Dec 2023 11:09:10 +0100 Subject: Fix an issue when merging the fwd/back functions of trait methods --- compiler/InterpreterStatements.ml | 33 ++++++++++++++++++++++++--------- compiler/SymbolicAst.ml | 4 ++++ compiler/SymbolicToPure.ml | 26 +++++++++++++++++++++----- compiler/SynthesizeSymbolic.ml | 13 ++++++++----- 4 files changed, 57 insertions(+), 19 deletions(-) diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml index 94c65b5c..97c8bcd6 100644 --- a/compiler/InterpreterStatements.ml +++ b/compiler/InterpreterStatements.ml @@ -731,6 +731,7 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) : fun_id_or_trait_method_ref * generic_args + * (generic_args * trait_instance_id) option * fun_decl * region_var_groups * inst_fun_sig = @@ -758,7 +759,7 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) instantiate_fun_sig ctx func.generics tr_self def.signature regions_hierarchy in - (func.func, func.generics, def, regions_hierarchy, inst_sg) + (func.func, func.generics, None, def, regions_hierarchy, inst_sg) | FunId (FAssumed _) -> (* Unreachable: must be a transparent function *) raise (Failure "Unreachable") @@ -811,7 +812,12 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) we also need to update the generics. *) let func = FunId fid in - (func, generics, method_def, regions_hierarchy, inst_sg) + ( func, + generics, + Some (generics, tr_self), + method_def, + regions_hierarchy, + inst_sg ) | None -> (* If not found, lookup the methods provided by the trait *declaration* (remember: for now, we forbid overriding provided methods) *) @@ -867,6 +873,7 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) in ( func.func, func.generics, + Some (all_generics, tr_self), method_def, regions_hierarchy, inst_sg )) @@ -900,8 +907,12 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) instantiate_fun_sig ctx generics tr_self method_def.signature regions_hierarchy in - (func.func, func.generics, method_def, regions_hierarchy, inst_sg) - )) + ( func.func, + func.generics, + Some (generics, tr_self), + method_def, + regions_hierarchy, + inst_sg ))) (** Evaluate a statement *) let rec eval_statement (config : config) (st : statement) : st_cm_fun = @@ -1287,14 +1298,15 @@ and eval_transparent_function_call_concrete (config : config) and eval_transparent_function_call_symbolic (config : config) (call : call) : st_cm_fun = fun cf ctx -> - let func, generics, def, regions_hierarchy, inst_sg = + let func, generics, trait_method_generics, def, regions_hierarchy, inst_sg = eval_transparent_function_call_symbolic_inst call ctx in (* Sanity check *) assert (List.length call.args = List.length def.signature.inputs); (* Evaluate the function call *) eval_function_call_symbolic_from_inst_sig config func def.signature - regions_hierarchy inst_sg generics call.args call.dest cf ctx + regions_hierarchy inst_sg generics trait_method_generics call.args call.dest + cf ctx (** Evaluate a function call in symbolic mode by using the function signature. @@ -1310,7 +1322,9 @@ and eval_transparent_function_call_symbolic (config : config) (call : call) : and eval_function_call_symbolic_from_inst_sig (config : config) (fid : fun_id_or_trait_method_ref) (sg : fun_sig) (regions_hierarchy : region_var_groups) (inst_sg : inst_fun_sig) - (generics : generic_args) (args : operand list) (dest : place) : st_cm_fun = + (generics : generic_args) + (trait_method_generics : (generic_args * trait_instance_id) option) + (args : operand list) (dest : place) : st_cm_fun = fun cf ctx -> log#ldebug (lazy @@ -1390,7 +1404,8 @@ and eval_function_call_symbolic_from_inst_sig (config : config) (* Synthesize the symbolic AST *) S.synthesize_regular_function_call fid call_id ctx sg regions_hierarchy - abs_ids generics args args_places ret_spc dest_place expr + abs_ids generics trait_method_generics args args_places ret_spc dest_place + expr in let cc = comp cc cf_call in @@ -1500,7 +1515,7 @@ and eval_assumed_function_call_symbolic (config : config) (fid : assumed_fun_id) (* Evaluate the function call *) eval_function_call_symbolic_from_inst_sig config (FunId (FAssumed fid)) sg - regions_hierarchy inst_sig generics args dest cf ctx + regions_hierarchy inst_sig generics None args dest cf ctx (** Evaluate a statement seen as a function body *) and eval_function_body (config : config) (body : statement) : st_cm_fun = diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml index 54d207d9..8e8cdec3 100644 --- a/compiler/SymbolicAst.ml +++ b/compiler/SymbolicAst.ml @@ -48,6 +48,10 @@ type call = { abstractions : AbstractionId.id list; (** The region abstractions introduced upon calling the function *) generics : generic_args; + trait_method_generics : (generic_args * trait_instance_id) option; + (** In case the call is to a trait method, we may need an additional type + parameter ([Self]) and the self trait clause to instantiate the + function signature. *) args : typed_value list; args_places : mplace option list; (** Meta information *) dest : symbolic_value; diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 7eb75584..41922cb5 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -1985,7 +1985,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : texpression = log#ldebug (lazy - ("translate_function_call:\n" + ("translate_function_call:\n" ^ "\n- call.call_id:" + ^ S.show_call_id call.call_id + ^ "\n\n- call.generics:\n" ^ ctx_generic_args_to_string ctx call.generics)); (* Translate the function call *) let generics = ctx_translate_fwd_generic_args ctx call.generics in @@ -2025,7 +2027,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : variables for the backward functions returned by the forward function. *) let ctx, ignore_fwd_output, back_funs_map, back_funs = - if !Config.return_back_funs then + if !Config.return_back_funs then ( (* We need to compute the signatures of the backward functions. *) let sg = Option.get call.sg in let decls_ctx = ctx.decls_ctx in @@ -2034,9 +2036,23 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : fid call.regions_hierarchy sg (List.map (fun _ -> None) sg.inputs) in - let tr_self = UnknownTrait __FUNCTION__ in + log#ldebug + (lazy ("dsg.generics:\n" ^ show_generic_params dsg.generics)); + let tr_self, all_generics = + match call.trait_method_generics with + | None -> (UnknownTrait __FUNCTION__, generics) + | Some (all_generics, tr_self) -> + let all_generics = + ctx_translate_fwd_generic_args ctx all_generics + in + let tr_self = + translate_fwd_trait_instance_id ctx.type_ctx.type_infos + tr_self + in + (tr_self, all_generics) + in let back_tys = - compute_back_tys_with_info dsg (Some (generics, tr_self)) + compute_back_tys_with_info dsg (Some (all_generics, tr_self)) in (* Introduce variables for the backward functions *) (* Compute a proper basename for the variables *) @@ -2106,7 +2122,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let back_funs_map = RegionGroupId.Map.of_list (List.combine gids back_vars) in - (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs) + (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs)) else (ctx, false, None, []) in (* Compute the pattern for the destination *) diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml index 4ec7524b..865185a8 100644 --- a/compiler/SynthesizeSymbolic.ml +++ b/compiler/SynthesizeSymbolic.ml @@ -95,6 +95,7 @@ let synthesize_symbolic_expansion_no_branching (sv : symbolic_value) let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx) (sg : fun_sig option) (regions_hierarchy : region_var_groups) (abstractions : AbstractionId.id list) (generics : generic_args) + (trait_method_generics : (generic_args * trait_instance_id) option) (args : typed_value list) (args_places : mplace option list) (dest : symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = @@ -108,6 +109,7 @@ let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx) regions_hierarchy; abstractions; generics; + trait_method_generics; args; dest; args_places; @@ -125,19 +127,20 @@ let synthesize_regular_function_call (fun_id : fun_id_or_trait_method_ref) (call_id : FunCallId.id) (ctx : Contexts.eval_ctx) (sg : fun_sig) (regions_hierarchy : region_var_groups) (abstractions : AbstractionId.id list) (generics : generic_args) + (trait_method_generics : (generic_args * trait_instance_id) option) (args : typed_value list) (args_places : mplace option list) (dest : symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = synthesize_function_call (Fun (fun_id, call_id)) - ctx (Some sg) regions_hierarchy abstractions generics args args_places dest - dest_place e + ctx (Some sg) regions_hierarchy abstractions generics trait_method_generics + args args_places dest dest_place e let synthesize_unary_op (ctx : Contexts.eval_ctx) (unop : unop) (arg : typed_value) (arg_place : mplace option) (dest : symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = let generics = empty_generic_args in - synthesize_function_call (Unop unop) ctx None [] [] generics [ arg ] + synthesize_function_call (Unop unop) ctx None [] [] generics None [ arg ] [ arg_place ] dest dest_place e let synthesize_binary_op (ctx : Contexts.eval_ctx) (binop : binop) @@ -145,8 +148,8 @@ let synthesize_binary_op (ctx : Contexts.eval_ctx) (binop : binop) (arg1_place : mplace option) (dest : symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = let generics = empty_generic_args in - synthesize_function_call (Binop binop) ctx None [] [] generics [ arg0; arg1 ] - [ arg0_place; arg1_place ] dest dest_place e + synthesize_function_call (Binop binop) ctx None [] [] generics None + [ arg0; arg1 ] [ arg0_place; arg1_place ] dest dest_place e let synthesize_end_abstraction (ctx : Contexts.eval_ctx) (abs : abs) (e : expression option) : expression option = -- cgit v1.2.3