diff options
Diffstat (limited to 'compiler/SymbolicToPure.ml')
-rw-r--r-- | compiler/SymbolicToPure.ml | 148 |
1 files changed, 132 insertions, 16 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index b024f40e..120689e5 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -163,6 +163,13 @@ type bs_ctx = { (** The variables that the backward functions will output, corresponding to the borrows they give back (don't include the backward state) *) + loop_backward_outputs : var list T.RegionGroupId.Map.t option; + (** Same as {!backward_outputs}, but for loops (if we entered a loop). + + [None] if we are not inside a loop, [Some] otherwise (and whatever + the kind of function we are translating: it will be [Some] even + though we are synthesizing a forward function). + *) calls : call_info V.FunCallId.Map.t; (** The function calls we encountered so far *) abstractions : (V.abs * texpression list) V.AbstractionId.Map.t; @@ -255,6 +262,11 @@ let ty_to_string (ctx : bs_ctx) (ty : ty) : string = let fmt = PrintPure.ast_to_type_formatter fmt in PrintPure.ty_to_string fmt false ty +let rty_to_string (ctx : bs_ctx) (ty : T.rty) : string = + let fmt = bs_ctx_to_ctx_formatter ctx in + let fmt = Print.PC.ctx_to_rtype_formatter fmt in + Print.PT.rty_to_string fmt ty + let type_decl_to_string (ctx : bs_ctx) (def : type_decl) : string = let type_params = def.type_params in let type_decls = ctx.type_context.llbc_type_decls in @@ -829,7 +841,7 @@ let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern = (* Return *) (ctx, state_pat) -let fresh_var (basename : string option) (ty : 'r T.ty) (ctx : bs_ctx) : +let fresh_var_llbc_ty (basename : string option) (ty : 'r T.ty) (ctx : bs_ctx) : bs_ctx * var = (* Generate the fresh variable *) let id, var_counter = VarId.fresh ctx.var_counter in @@ -843,7 +855,7 @@ let fresh_var (basename : string option) (ty : 'r T.ty) (ctx : bs_ctx) : let fresh_named_var_for_symbolic_value (basename : string option) (sv : V.symbolic_value) (ctx : bs_ctx) : bs_ctx * var = (* Generate the fresh variable *) - let ctx, var = fresh_var basename sv.sv_ty ctx in + let ctx, var = fresh_var_llbc_ty basename sv.sv_ty ctx in (* Insert in the map *) let sv_to_var = V.SymbolicValueId.Map.add_strict sv.sv_id var ctx.sv_to_var in (* Update the context *) @@ -981,8 +993,9 @@ let rec typed_value_to_texpression (ctx : bs_ctx) (ectx : C.eval_ctx) log#ldebug (lazy ("typed_value_to_texpression: result:" ^ "\n- input value:\n" - ^ V.show_typed_value v ^ "\n- translated expression:\n" - ^ show_texpression value)); + ^ typed_value_to_string ctx v + ^ "\n- translated expression:\n" + ^ texpression_to_string ctx value)); (* Sanity check *) type_check_texpression ctx value; (* Return *) @@ -1296,7 +1309,21 @@ and translate_panic (ctx : bs_ctx) : texpression = * but it won't be true anymore once we translate individual blocks *) (* If we use a state monad, we need to add a lambda for the state variable *) (* Note that only forward functions return a state *) - let output_ty = mk_simpl_tuple_ty ctx.sg.doutputs in + let output_ty = + if ctx.inside_loop && Option.is_some ctx.bid then + (* We are synthesizing the backward function of a loop body *) + let bid = Option.get ctx.bid in + let back_vars = + T.RegionGroupId.Map.find bid (Option.get ctx.loop_backward_outputs) + in + let tys = List.map (fun (v : var) -> v.ty) back_vars in + mk_simpl_tuple_ty tys + else + (* Regular function, or forward function (the forward translation for + a loop has the same return type as the parent function) + *) + mk_simpl_tuple_ty ctx.sg.doutputs + in (* TODO: we should use a [Fail] function *) if ctx.sg.info.effect_info.stateful then (* Create the [Fail] value *) @@ -1373,7 +1400,14 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) (* Group the variables in which we stored the values we need to give back. * See the explanations for the [SynthInput] case in [translate_end_abstraction] *) let backward_outputs = - T.RegionGroupId.Map.find bid ctx.backward_outputs + let map = + if ctx.inside_loop then + (* We are synthesizing a loop body *) + Option.get ctx.loop_backward_outputs + else (* Regular function *) + ctx.backward_outputs + in + T.RegionGroupId.Map.find bid map in let field_values = List.map mk_texpression_from_var backward_outputs in mk_simpl_tuple_texpression field_values @@ -1535,9 +1569,15 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs) = log#ldebug (lazy - ("translate_end_abstraction_synth_input:" ^ "\n- eval_ctx:\n" - ^ IU.eval_ctx_to_string ectx ^ "\n- abs:\n" ^ IU.abs_to_string ectx abs - ^ "\n")); + ("translate_end_abstraction_synth_input:" ^ "\n- function: " + ^ Print.name_to_string ctx.fun_decl.name + ^ "\n- rg_id: " + ^ T.RegionGroupId.to_string rg_id + ^ "\n- loop_id: " + ^ Print.option_to_string Pure.LoopId.to_string ctx.loop_id + ^ "\n- eval_ctx:\n" ^ IU.eval_ctx_to_string ectx ^ "\n- abs:\n" + ^ IU.abs_to_string ectx abs ^ "\n")); + (* When we end an input abstraction, this input abstraction gets back * the borrows which it introduced in the context through the input * values: by listing those values, we get the values which are given @@ -1564,12 +1604,38 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs) * (v_i) * ]} * *) - (* First, get the given back variables *) + (* First, get the given back variables. + + We don't use the same given back variables if we translate a loop or + the standard body of a function. + *) let given_back_variables = - T.RegionGroupId.Map.find bid ctx.backward_outputs + let map = + if ctx.inside_loop then + (* We are synthesizing a loop body *) + Option.get ctx.loop_backward_outputs + else (* Regular function body *) + ctx.backward_outputs + in + T.RegionGroupId.Map.find bid map in + (* Get the list of values consumed by the abstraction upon ending *) let consumed_values = abs_to_consumed ctx ectx abs in + + log#ldebug + (lazy + ("translate_end_abstraction_synth_input:" + ^ "\n\n- given back variables types:\n" + ^ Print.list_to_string + (fun (v : var) -> ty_to_string ctx v.ty) + given_back_variables + ^ "\n\n- consumed values:\n" + ^ Print.list_to_string + (fun e -> texpression_to_string ctx e ^ " : " ^ ty_to_string ctx e.ty) + consumed_values + ^ "\n")); + (* Group the two lists *) let variables_values = List.combine given_back_variables consumed_values in (* Sanity check: the two lists match (same types) *) @@ -1655,11 +1721,11 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs (" ^ string_of_int (List.length inputs) ^ "): " - ^ String.concat ", " (List.map show_texpression inputs) + ^ String.concat ", " (List.map (texpression_to_string ctx) inputs) ^ "\n- inst_sg.inputs (" ^ string_of_int (List.length inst_sg.inputs) ^ "): " - ^ String.concat ", " (List.map show_ty inst_sg.inputs))); + ^ String.concat ", " (List.map (ty_to_string ctx) inst_sg.inputs))); List.iter (fun (x, ty) -> assert ((x : texpression).ty = ty)) (List.combine inputs inst_sg.inputs); @@ -2272,6 +2338,13 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = loop.input_svalues ^ "\n- filtered svl: " ^ (Print.list_to_string (symbolic_value_to_string ctx)) svl + ^ "\n- rg_to_abs\n:" + ^ T.RegionGroupId.Map.show + (fun (rids, tys) -> + "(" ^ T.RegionId.Set.show rids ^ ", " + ^ Print.list_to_string (rty_to_string ctx) tys + ^ ")") + loop.rg_to_given_back_tys ^ "\n")); let ctx, _ = fresh_vars_for_symbolic_values svl ctx in ctx @@ -2294,6 +2367,39 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = List.map (fun var -> mk_typed_pattern_from_var var None) inputs in + (* Compute the backward outputs *) + let ctx = ref ctx in + let loop_backward_outputs = + T.RegionGroupId.Map.map + (fun (_, tys) -> + (* The types shouldn't contain borrows - we can translate them as forward types *) + let vars = + List.map + (fun ty -> + assert ( + not (TypesUtils.ty_has_borrows !ctx.type_context.types_infos ty)); + (None, ctx_translate_fwd_ty !ctx ty)) + tys + in + (* Introduce fresh variables *) + let ctx', vars = fresh_vars vars !ctx in + ctx := ctx'; + vars) + loop.rg_to_given_back_tys + in + let ctx = !ctx in + + let back_output_tys = + match ctx.bid with + | None -> None + | Some rg_id -> + let back_outputs = + T.RegionGroupId.Map.find rg_id loop_backward_outputs + in + let back_output_tys = List.map (fun (v : var) -> v.ty) back_outputs in + Some back_output_tys + in + (* Add the loop information in the context *) let ctx = assert (not (LoopId.Map.mem loop_id ctx.loops)); @@ -2319,7 +2425,13 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = in (* Update the context to translate the function end *) - let ctx_end = { ctx with loop_id = Some loop_id } in + let ctx_end = + { + ctx with + loop_id = Some loop_id; + loop_backward_outputs = Some loop_backward_outputs; + } + in let fun_end = translate_expression loop.end_expr ctx_end in (* Update the context for the loop body *) @@ -2339,10 +2451,13 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = input_state = (if !Config.use_state then Some ctx.state_var else None); inputs; inputs_lvs; + back_output_tys; loop_body; } in - assert (fun_end.ty = loop_body.ty); + (* If we translate forward functions: the return type of a loop body is the + same as the parent function *) + assert (Option.is_some ctx.bid || fun_end.ty = loop_body.ty); let ty = fun_end.ty in { e = loop; ty } @@ -2524,7 +2639,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = ^ "\n- back_state: " ^ String.concat ", " (List.map show_var back_state) ^ "\n- signature.inputs: " - ^ String.concat ", " (List.map show_ty signature.inputs))); + ^ String.concat ", " (List.map (ty_to_string ctx) signature.inputs) + )); assert ( List.for_all (fun (var, ty) -> (var : var).ty = ty) |