summaryrefslogtreecommitdiff
path: root/compiler/Translate.ml
diff options
context:
space:
mode:
authorSon HO2023-12-23 01:46:58 +0100
committerGitHub2023-12-23 01:46:58 +0100
commit15a7d7b7322a1cd0ebeb328fde214060e23fa8b4 (patch)
tree6cce7d76969870f5bc18c5a7cd585e8873a1c0dc /compiler/Translate.ml
parentc3e0b90e422cbd902ee6d2b47073940c0017b7fb (diff)
parent63ccbd914d5d44aa30dee38a6fcc019310ab640b (diff)
Merge pull request #64 from AeneasVerif/son/merge_back
Merge the forward/backward functions
Diffstat (limited to 'compiler/Translate.ml')
-rw-r--r--compiler/Translate.ml211
1 files changed, 95 insertions, 116 deletions
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 221d4e73..55a94302 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -6,7 +6,6 @@ open LlbcAst
open Contexts
module SA = SymbolicAst
module Micro = PureMicroPasses
-open PureUtils
open TranslateCore
(** The local logger *)
@@ -43,8 +42,8 @@ let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : fun_decl) :
TODO: maybe we should introduce a record for this.
*)
let translate_function_to_pure (trans_ctx : trans_ctx)
- (fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdNotLoopMap.t)
- (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : fun_decl) :
+ (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t)
+ (fun_dsigs : Pure.decomposed_fun_sig FunDeclId.Map.t) (fdef : fun_decl) :
pure_fun_translation_no_loops =
(* Debug *)
log#ldebug
@@ -58,13 +57,9 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
(* Convert the symbolic ASTs to pure ASTs: *)
(* Initialize the context *)
- let forward_sig =
- RegularFunIdNotLoopMap.find (FRegular def_id, None) fun_sigs
- in
let sv_to_var = SymbolicValueId.Map.empty in
let var_counter = Pure.VarId.generator_zero in
let state_var, var_counter = Pure.VarId.fresh var_counter in
- let back_state_var, var_counter = Pure.VarId.fresh var_counter in
let fuel0, var_counter = Pure.VarId.fresh var_counter in
let fuel, var_counter = Pure.VarId.fresh var_counter in
let calls = FunCallId.Map.empty in
@@ -78,7 +73,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
| RecGroup _ -> Some tid)
(TypeDeclId.Map.bindings trans_ctx.type_ctx.type_decls_groups))
in
- let type_context =
+ let type_ctx =
{
SymbolicToPure.type_infos = trans_ctx.type_ctx.type_infos;
llbc_type_decls = trans_ctx.type_ctx.type_decls;
@@ -86,15 +81,14 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
recursive_decls = recursive_type_decls;
}
in
- let fun_context =
+ let fun_ctx =
{
SymbolicToPure.llbc_fun_decls = trans_ctx.fun_ctx.fun_decls;
- fun_sigs;
fun_infos = trans_ctx.fun_ctx.fun_infos;
regions_hierarchies = trans_ctx.fun_ctx.regions_hierarchies;
}
in
- let global_context =
+ let global_ctx =
{ SymbolicToPure.llbc_global_decls = trans_ctx.global_ctx.global_decls }
in
@@ -126,31 +120,51 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
!m
in
+ let sg =
+ SymbolicToPure.translate_fun_sig_from_decl_to_decomposed trans_ctx fdef
+ in
+
+ let regions_hierarchy =
+ LlbcAstUtils.FunIdMap.find (FRegular def_id) fun_ctx.regions_hierarchies
+ in
+
+ let var_counter, back_state_vars =
+ if !Config.return_back_funs then (var_counter, [])
+ else
+ List.fold_left_map
+ (fun var_counter (region_vars : region_var_group) ->
+ let gid = region_vars.id in
+ let var, var_counter = Pure.VarId.fresh var_counter in
+ (var_counter, (gid, var)))
+ var_counter regions_hierarchy
+ in
+ let back_state_vars = RegionGroupId.Map.of_list back_state_vars in
+
let ctx =
{
+ decls_ctx = trans_ctx;
SymbolicToPure.bid = None;
- (* Dummy for now *)
- sg = forward_sig.sg;
- fwd_sg = forward_sig.sg;
+ sg;
+ fun_dsigs;
(* Will need to be updated for the backward functions *)
sv_to_var;
- var_counter;
+ var_counter = ref var_counter;
state_var;
- back_state_var;
+ back_state_vars;
fuel0;
fuel;
- type_context;
- fun_context;
- global_context;
+ type_ctx;
+ fun_ctx;
+ global_ctx;
trait_decls_ctx = trans_ctx.trait_decls_ctx.trait_decls;
trait_impls_ctx = trans_ctx.trait_impls_ctx.trait_impls;
fun_decl = fdef;
forward_inputs = [];
- (* Empty for now *)
- backward_inputs = RegionGroupId.Map.empty;
- (* Empty for now *)
- backward_outputs = RegionGroupId.Map.empty;
- loop_backward_outputs = None;
+ (* Initialized just below *)
+ backward_inputs_no_state = RegionGroupId.Map.empty;
+ (* Initialized just below *)
+ backward_inputs_with_state = RegionGroupId.Map.empty;
+ backward_outputs = None;
(* Empty for now *)
calls;
abstractions;
@@ -180,6 +194,37 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
| _ -> raise (Failure "Unreachable")
in
+ (* Add the backward inputs *)
+ let ctx, backward_inputs_no_state, backward_inputs_with_state =
+ if !Config.return_back_funs then (ctx, [], [])
+ else
+ let ctx, inputs_no_with_state =
+ List.fold_left_map
+ (fun ctx (region_vars : region_var_group) ->
+ let gid = region_vars.id in
+ let back_sg = RegionGroupId.Map.find gid sg.back_sg in
+ let ctx, no_state =
+ SymbolicToPure.fresh_vars back_sg.inputs_no_state ctx
+ in
+ let ctx, with_state =
+ SymbolicToPure.fresh_vars back_sg.inputs ctx
+ in
+ (ctx, ((gid, no_state), (gid, with_state))))
+ ctx regions_hierarchy
+ in
+ let inputs_no_state, inputs_with_state =
+ List.split inputs_no_with_state
+ in
+ (ctx, inputs_no_state, inputs_with_state)
+ in
+ let backward_inputs_no_state =
+ RegionGroupId.Map.of_list backward_inputs_no_state
+ in
+ let backward_inputs_with_state =
+ RegionGroupId.Map.of_list backward_inputs_with_state
+ in
+ let ctx = { ctx with backward_inputs_no_state; backward_inputs_with_state } in
+
(* Translate the forward function *)
let pure_forward =
match symbolic_trans with
@@ -187,7 +232,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
| Some (_, ast) -> SymbolicToPure.translate_fun_decl ctx (Some ast)
in
- (* Translate the backward functions *)
+ (* Translate the backward functions, if we split the forward and backward functions *)
let translate_backward (rg : region_var_group) : Pure.fun_decl =
(* For the backward inputs/outputs initialization: we use the fact that
* there are no nested borrows for now, and so that the region groups
@@ -197,77 +242,20 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
match symbolic_trans with
| None ->
- (* Initialize the context - note that the ret_ty is not really
- * useful as we don't translate a body *)
- let backward_sg =
- RegularFunIdNotLoopMap.find (FRegular def_id, Some back_id) fun_sigs
- in
- let ctx = { ctx with bid = Some back_id; sg = backward_sg.sg } in
-
+ (* Initialize the context *)
+ let ctx = { ctx with bid = Some back_id } in
(* Translate *)
SymbolicToPure.translate_fun_decl ctx None
| Some (_, symbolic) ->
- (* Finish initializing the context by adding the additional input
- variables required by the backward function.
- *)
- let backward_sg =
- RegularFunIdNotLoopMap.find (FRegular def_id, Some back_id) fun_sigs
- in
- (* We need to ignore the forward inputs, and the state input (if there is) *)
- let backward_inputs =
- let sg = backward_sg.sg in
- (* We need to ignore the forward state and the backward state *)
- let num_forward_inputs =
- sg.info.num_fwd_inputs_with_fuel_with_state
- in
- let num_back_inputs = Option.get sg.info.num_back_inputs_no_state in
- Collections.List.subslice sg.inputs num_forward_inputs
- (num_forward_inputs + num_back_inputs)
- in
- (* As we forbid nested borrows, the additional inputs for the backward
- * functions come from the borrows in the return value of the rust function:
- * we thus use the name "ret" for those inputs *)
- let backward_inputs =
- List.map (fun ty -> (Some "ret", ty)) backward_inputs
- in
- let ctx, backward_inputs =
- SymbolicToPure.fresh_vars backward_inputs ctx
- in
- (* The outputs for the backward functions, however, come from borrows
- * present in the input values of the rust function: for those we reuse
- * the names of the input values. *)
- let backward_outputs =
- List.combine backward_sg.output_names backward_sg.sg.doutputs
- in
- let ctx, backward_outputs =
- SymbolicToPure.fresh_vars backward_outputs ctx
- in
- let backward_inputs =
- RegionGroupId.Map.singleton back_id backward_inputs
- in
- let backward_outputs =
- RegionGroupId.Map.singleton back_id backward_outputs
- in
-
- (* Put everything in the context *)
- let ctx =
- {
- ctx with
- bid = Some back_id;
- sg = backward_sg.sg;
- backward_inputs;
- backward_outputs;
- }
- in
-
+ (* Initialize the context *)
+ let ctx = { ctx with bid = Some back_id } in
(* Translate *)
SymbolicToPure.translate_fun_decl ctx (Some symbolic)
in
- let regions_hierarchy =
- LlbcAstUtils.FunIdMap.find (FRegular fdef.def_id)
- fun_context.regions_hierarchies
+ let pure_backwards =
+ if !Config.return_back_funs then []
+ else List.map translate_backward regions_hierarchy
in
- let pure_backwards = List.map translate_backward regions_hierarchy in
(* Return *)
(pure_forward, pure_backwards)
@@ -294,36 +282,21 @@ let translate_crate_to_pure (crate : crate) :
(List.map (fun (def : Pure.type_decl) -> (def.def_id, def)) type_decls)
in
- (* Translate all the function *signatures* *)
- let assumed_sigs =
- List.map
- (fun (info : Assumed.assumed_fun_info) ->
- ( FAssumed info.fun_id,
- List.map (fun _ -> None) info.fun_sig.inputs,
- info.fun_sig ))
- Assumed.assumed_fun_infos
- in
- let local_sigs =
- List.map
- (fun (fdef : fun_decl) ->
- let input_names =
- match fdef.body with
- | None -> List.map (fun _ -> None) fdef.signature.inputs
- | Some body ->
- List.map
- (fun (v : var) -> v.name)
- (LlbcAstUtils.fun_body_get_input_vars body)
- in
- (FRegular fdef.def_id, input_names, fdef.signature))
- (FunDeclId.Map.values crate.fun_decls)
+ (* Compute the decomposed fun sigs for the whole crate *)
+ let fun_dsigs =
+ FunDeclId.Map.of_list
+ (List.map
+ (fun (fdef : LlbcAst.fun_decl) ->
+ ( fdef.def_id,
+ SymbolicToPure.translate_fun_sig_from_decl_to_decomposed trans_ctx
+ fdef ))
+ (FunDeclId.Map.values crate.fun_decls))
in
- let sigs = List.append assumed_sigs local_sigs in
- let fun_sigs = SymbolicToPure.translate_fun_signatures trans_ctx sigs in
(* Translate all the *transparent* functions *)
let pure_translations =
List.map
- (translate_function_to_pure trans_ctx fun_sigs type_decls_map)
+ (translate_function_to_pure trans_ctx type_decls_map fun_dsigs)
(FunDeclId.Map.values crate.fun_decls)
in
@@ -1030,7 +1003,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) :
List.map
(fun { fwd; _ } ->
let fwd_f =
- if fwd.f.Pure.signature.info.effect_info.is_rec then
+ if fwd.f.Pure.signature.fwd_info.effect_info.is_rec then
[ (fwd.f.def_id, None) ]
else []
in
@@ -1198,7 +1171,13 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) :
let exe_dir = Filename.dirname Sys.argv.(0) in
let primitives_src_dest =
match !Config.backend with
- | FStar -> Some ("/backends/fstar/Primitives.fst", "Primitives.fst")
+ | FStar ->
+ let src =
+ if !Config.return_back_funs then
+ "/backends/fstar/merge/Primitives.fst"
+ else "/backends/fstar/split/Primitives.fst"
+ in
+ Some (src, "Primitives.fst")
| Coq -> Some ("/backends/coq/Primitives.v", "Primitives.v")
| Lean -> None
| HOL4 -> None