diff options
author | Son Ho | 2022-12-17 10:27:12 +0100 |
---|---|---|
committer | Son HO | 2023-02-03 11:21:46 +0100 |
commit | 66638a2a96c7639553a340917b87e26d94265c5e (patch) | |
tree | a0219df7582ca17784135345924790dc26a7e315 /compiler | |
parent | 07621dcf488eef1c4a4ab797c21cc34ab474d225 (diff) |
Fix various issues with the generation of code for the loops
Diffstat (limited to '')
-rw-r--r-- | compiler/Extract.ml | 29 | ||||
-rw-r--r-- | compiler/ExtractBase.ml | 21 | ||||
-rw-r--r-- | compiler/Pure.ml | 5 | ||||
-rw-r--r-- | compiler/PureMicroPasses.ml | 383 | ||||
-rw-r--r-- | compiler/PureUtils.ml | 15 | ||||
-rw-r--r-- | compiler/SymbolicToPure.ml | 23 | ||||
-rw-r--r-- | compiler/Translate.ml | 85 | ||||
-rw-r--r-- | compiler/TranslateCore.ml | 4 |
8 files changed, 416 insertions, 149 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml index fa384de6..b3d7b49e 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -1254,21 +1254,28 @@ let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx) (forward function and backward functions). *) let extract_fun_decl_register_names (ctx : extraction_ctx) (keep_fwd : bool) - (has_decreases_clause : bool) (def : pure_fun_translation) : extraction_ctx - = - let fwd, back_ls = def in - (* Register the decrease clause, if necessary *) - let ctx = - if has_decreases_clause then ctx_add_decrases_clause fwd ctx else ctx + (has_decreases_clause : fun_decl -> bool) (def : pure_fun_translation) : + extraction_ctx = + let (fwd, loop_fwds), back_ls = def in + (* Register the decrease clauses, if necessary *) + let register_decreases ctx def = + if has_decreases_clause def then ctx_add_decreases_clause def ctx else ctx in - (* Register the forward function name *) - let ctx = ctx_add_fun_decl (keep_fwd, def) fwd ctx in + let ctx = List.fold_left register_decreases ctx (fwd :: loop_fwds) in + (* Register the function names *) + let register_fun ctx f = ctx_add_fun_decl (keep_fwd, def) f ctx in + let register_funs ctx fl = List.fold_left register_fun ctx fl in + (* Register the forward functions' names *) + let ctx = register_funs ctx (fwd :: loop_fwds) in (* Register the backward functions' names *) let ctx = List.fold_left - (fun ctx back -> ctx_add_fun_decl (keep_fwd, def) back ctx) + (fun ctx (back, loop_backs) -> + let ctx = register_fun ctx back in + register_funs ctx loop_backs) ctx back_ls in + (* Return *) ctx @@ -1855,7 +1862,7 @@ let extract_template_decreases_clause (ctx : extraction_ctx) (fmt : F.formatter) (def : fun_decl) : unit = assert (!backend = FStar); (* Retrieve the function name *) - let def_name = ctx_get_decreases_clause def.def_id ctx in + let def_name = ctx_get_decreases_clause def.def_id def.loop_id ctx in (* Add a break before *) F.pp_print_break fmt 0 0; (* Print a comment to link the extracted type to its original rust definition *) @@ -1992,7 +1999,7 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) (* Open a box for the decreases term *) F.pp_open_hovbox fmt ctx.indent_incr; (* The name of the decrease clause *) - let decr_name = ctx_get_decreases_clause def.def_id ctx in + let decr_name = ctx_get_decreases_clause def.def_id def.loop_id ctx in F.pp_print_string fmt decr_name; (* Print the type parameters *) List.iter diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index c1ea536a..b952d555 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -273,7 +273,7 @@ type formatter = { type id = | GlobalId of A.GlobalDeclId.id | FunId of fun_id - | DecreasesClauseId of A.fun_id + | DecreasesClauseId of (A.fun_id * LoopId.id option) (** The definition which provides the decreases/termination clause. We insert calls to this clause to prove/reason about termination: the body of those clauses must be defined by the user, in the @@ -467,14 +467,19 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = in "fun name (" ^ lp_kind ^ fwd_back_kind ^ "): " ^ fun_name | Pure fid -> PrintPure.pure_assumed_fun_id_to_string fid) - | DecreasesClauseId fid -> + | DecreasesClauseId (fid, lid) -> let fun_name = match fid with | Regular fid -> Print.fun_name_to_string (A.FunDeclId.Map.find fid fun_decls).name | Assumed aid -> A.show_assumed_fun_id aid in - "decreases clause for function: " ^ fun_name + let loop = + match lid with + | None -> "" + | Some lid -> ", loop: " ^ LoopId.to_string lid + in + "decreases clause for function: " ^ fun_name ^ loop | TypeId id -> "type name: " ^ get_type_name id | StructId id -> "struct constructor of: " ^ get_type_name id | VariantId (id, variant_id) -> @@ -581,9 +586,9 @@ let ctx_get_variant (def_id : type_id) (variant_id : VariantId.id) (ctx : extraction_ctx) : string = ctx_get (VariantId (def_id, variant_id)) ctx -let ctx_get_decreases_clause (def_id : A.FunDeclId.id) (ctx : extraction_ctx) : - string = - ctx_get (DecreasesClauseId (Regular def_id)) ctx +let ctx_get_decreases_clause (def_id : A.FunDeclId.id) + (loop_id : LoopId.id option) (ctx : extraction_ctx) : string = + ctx_get (DecreasesClauseId (Regular def_id, loop_id)) ctx (** Generate a unique type variable name and add it to the context *) let ctx_add_type_var (basename : string) (id : TypeVarId.id) @@ -669,10 +674,10 @@ let ctx_add_struct (def : type_decl) (ctx : extraction_ctx) : let ctx = ctx_add (StructId (AdtId def.def_id)) name ctx in (ctx, name) -let ctx_add_decrases_clause (def : fun_decl) (ctx : extraction_ctx) : +let ctx_add_decreases_clause (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx = let name = ctx.fmt.decreases_clause_name def.def_id def.basename in - ctx_add (DecreasesClauseId (Regular def.def_id)) name ctx + ctx_add (DecreasesClauseId (Regular def.def_id, def.loop_id)) name ctx let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) : extraction_ctx = diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 6fb20b22..97eced1d 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -14,7 +14,7 @@ module SymbolicValueId = V.SymbolicValueId module FunDeclId = A.FunDeclId module GlobalDeclId = A.GlobalDeclId -(** We redefine identifiers for loop: in {Values}, the identifiers are global +(** We redefine identifiers for loop: in {!Values}, the identifiers are global (they monotonically increase across functions) while in {!module:Pure} we want the indices to start at 0 for every function. *) @@ -492,6 +492,9 @@ and match_branch = { pat : typed_pattern; branch : texpression } and loop = { fun_end : texpression; loop_id : loop_id; + fuel0 : var_id; + fuel : var_id; + input_state : var_id option; inputs : var list; inputs_lvs : typed_pattern list; (** The inputs seen as patterns. See {!fun_body}. *) diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 87ab4609..335336be 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -432,12 +432,34 @@ let compute_pretty_names (def : fun_decl) : fun_decl = (ctx, Switch (scrut, body)) (* *) and update_loop (loop : loop) (ctx : pn_ctx) : pn_ctx * expression = - let { fun_end; loop_id; inputs; inputs_lvs; loop_body } = loop in + let { + fun_end; + loop_id; + fuel0; + fuel; + input_state; + inputs; + inputs_lvs; + loop_body; + } = + loop + in let ctx, fun_end = update_texpression fun_end ctx in let ctx, loop_body = update_texpression loop_body ctx in let inputs = List.map (fun input -> update_var ctx input None) inputs in let inputs_lvs = List.map (update_typed_pattern ctx) inputs_lvs in - let loop = { fun_end; loop_id; inputs; inputs_lvs; loop_body } in + let loop = + { + fun_end; + loop_id; + fuel0; + fuel; + input_state; + inputs; + inputs_lvs; + loop_body; + } + in (ctx, Loop loop) (* *) and update_meta (meta : meta) (e : texpression) (ctx : pn_ctx) : @@ -972,6 +994,160 @@ let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option = then None else Some def +(** Retrieve the loop definitions from the function definition. + + {!SymbolicToPure} generates an AST in which the loop bodies are part of + the function body (see the {!Pure.Loop} node). This function extracts + those function bodies into independent definitions while removing + occurrences of the {!Pure.Loop} node. + *) +let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = + (* Store the loops here *) + let loops = ref LoopId.Map.empty in + let expr_visitor = + object (self) + inherit [_] map_expression + + method! visit_Loop env loop = + let fun_sig = def.signature in + let fun_sig_info = fun_sig.info in + let fun_effect_info = fun_sig_info.effect_info in + + (* Generate the loop definition *) + let loop_effect_info = + { + stateful_group = fun_effect_info.stateful_group; + stateful = fun_effect_info.stateful; + can_fail = fun_effect_info.can_fail; + can_diverge = fun_effect_info.can_diverge; + is_rec = fun_effect_info.is_rec; + } + in + + let loop_sig_info = + let fuel = if !Config.use_fuel then 1 else 0 in + let num_inputs = List.length loop.inputs in + let num_fwd_inputs_with_fuel_no_state = fuel + num_inputs in + let num_fwd_inputs_with_fuel_with_state = + fun_sig_info.num_fwd_inputs_with_fuel_with_state + - fun_sig_info.num_fwd_inputs_with_fuel_no_state + in + { + has_fuel = !Config.use_fuel; + num_fwd_inputs_with_fuel_no_state; + num_fwd_inputs_with_fuel_with_state; + num_back_inputs_no_state = fun_sig_info.num_back_inputs_no_state; + num_back_inputs_with_state = fun_sig_info.num_back_inputs_with_state; + effect_info = loop_effect_info; + } + in + + let inputs_tys = + let fuel = if !Config.use_fuel then [ mk_fuel_ty ] else [] in + let fwd_inputs = List.map (fun (v : var) -> v.ty) loop.inputs in + let state = + Collections.List.subslice fun_sig.inputs + fun_sig_info.num_fwd_inputs_with_fuel_no_state + fun_sig_info.num_fwd_inputs_with_fuel_with_state + in + let _, back_inputs = + Collections.List.split_at fun_sig.inputs + fun_sig_info.num_fwd_inputs_with_fuel_with_state + in + List.concat [ fuel; fwd_inputs; state; back_inputs ] + in + + let loop_sig = + { + type_params = fun_sig.type_params; + inputs = inputs_tys; + output = fun_sig.output; + doutputs = fun_sig.doutputs; + info = loop_sig_info; + } + in + + let fuel_vars, inputs, inputs_lvs = + (* Introduce the fuel input *) + let fuel_vars, fuel0_var, fuel_lvs = + if !Config.use_fuel then + let fuel0_var = mk_fuel_var loop.fuel0 in + let fuel_lvs = mk_typed_pattern_from_var fuel0_var None in + (Some (loop.fuel0, loop.fuel), [ fuel0_var ], [ fuel_lvs ]) + else (None, [], []) + in + + (* Introduce the forward input state *) + let fwd_state_var, fwd_state_lvs = + assert (loop_effect_info.stateful = Option.is_some loop.input_state); + match loop.input_state with + | None -> ([], []) + | Some input_state -> + let state_var = mk_state_var input_state in + let state_lvs = mk_typed_pattern_from_var state_var None in + ([ state_var ], [ state_lvs ]) + in + + (* Introduce the additional backward inputs *) + let fun_body = Option.get def.body in + let _, back_inputs = + Collections.List.split_at fun_body.inputs + fun_sig_info.num_fwd_inputs_with_fuel_with_state + in + let _, back_inputs_lvs = + Collections.List.split_at fun_body.inputs_lvs + fun_sig_info.num_fwd_inputs_with_fuel_with_state + in + + let inputs = + List.concat [ fuel0_var; fwd_state_var; loop.inputs; back_inputs ] + in + let inputs_lvs = + List.concat + [ fuel_lvs; fwd_state_lvs; loop.inputs_lvs; back_inputs_lvs ] + in + (fuel_vars, inputs, inputs_lvs) + in + + (* Wrap the loop body in a match over the fuel *) + let loop_body = + match fuel_vars with + | None -> loop.loop_body + | Some (fuel0, fuel) -> + SymbolicToPure.wrap_in_match_fuel fuel0 fuel loop.loop_body + in + + let loop_body = { inputs; inputs_lvs; body = loop_body } in + + let loop_def = + { + def_id = def.def_id; + num_loops = 0; + loop_id = Some loop.loop_id; + back_id = def.back_id; + basename = def.basename; + signature = loop_sig; + is_global_decl_body = def.is_global_decl_body; + body = Some loop_body; + } + in + (* Store the loop definition *) + loops := LoopId.Map.add_strict loop.loop_id loop_def !loops; + + (* Update the current expression to remove the [Loop] node, and continue *) + (self#visit_texpression env loop.fun_end).e + end + in + + match def.body with + | None -> (def, []) + | Some body -> + let body_expr = expr_visitor#visit_texpression () body.body in + let body = { body with body = body_expr } in + let def = { def with body = Some body } in + let loops = List.map snd (LoopId.Map.bindings !loops) in + (def, loops) + (** Return [false] if the forward function is useless and should be filtered. - a forward function with no output (comes from a Rust function with @@ -989,7 +1165,7 @@ let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option = altogether. *) let keep_forward (trans : pure_fun_translation) : bool = - let fwd, backs = trans in + let (fwd, _), backs = trans in (* Note that at this point, the output types are no longer seen as tuples: * they should be lists of length 1. *) if @@ -1306,13 +1482,110 @@ let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = (* Return *) { def with body = Some body } +(** Auxiliary function for {!apply_passes_to_def} *) +let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = + (* Convert the unit variables to [()] if they are used as right-values or + * [_] if they are used as left values. *) + let def = unit_vars_to_unit def in + log#ldebug + (lazy ("unit_vars_to_unit:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + + (* Inline the useless variable reassignments *) + let inline_named_vars = true in + let inline_pure = true in + let def = + inline_useless_var_reassignments inline_named_vars inline_pure def + in + log#ldebug + (lazy + ("inline_useless_var_assignments:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + + (* Eliminate the box functions - note that the "box" types were eliminated + * during the symbolic to pure phase: see the comments for [eliminate_box_functions] *) + let def = eliminate_box_functions ctx def in + log#ldebug + (lazy ("eliminate_box_functions:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + + (* Filter the useless variables, assignments, function calls, etc. *) + let def = filter_useless !Config.filter_useless_monadic_calls ctx def in + log#ldebug (lazy ("filter_useless:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + + (* Simplify the aggregated ADTs. + + Ex.: + {[ + (* type struct = { f0 : nat; f1 : nat } *) + + Mkstruct x.f0 x.f1 ~~> x + ]} + *) + let def = simplify_aggregates ctx def in + log#ldebug + (lazy ("simplify_aggregates:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + + (* Decompose the monadic let-bindings - used by Coq *) + let def = + if !Config.decompose_monadic_let_bindings then ( + let def = decompose_monadic_let_bindings ctx def in + log#ldebug + (lazy + ("decompose_monadic_let_bindings:\n\n" ^ fun_decl_to_string ctx def + ^ "\n")); + def) + else ( + log#ldebug + (lazy + "ignoring decompose_monadic_let_bindings due to the configuration\n"); + def) + in + + (* Decompose nested let-patterns *) + let def = + if !Config.decompose_nested_let_patterns then ( + let def = decompose_nested_let_patterns ctx def in + log#ldebug + (lazy + ("decompose_nested_let_patterns:\n\n" ^ fun_decl_to_string ctx def + ^ "\n")); + def) + else ( + log#ldebug + (lazy + "ignoring decompose_nested_let_patterns due to the configuration\n"); + def) + in + + (* Unfold the monadic let-bindings *) + let def = + if !Config.unfold_monadic_let_bindings then ( + let def = unfold_monadic_let_bindings ctx def in + log#ldebug + (lazy + ("unfold_monadic_let_bindings:\n\n" ^ fun_decl_to_string ctx def + ^ "\n")); + def) + else ( + log#ldebug + (lazy "ignoring unfold_monadic_let_bindings due to the configuration\n"); + def) + in + + (* We are done *) + def + (** Apply all the micro-passes to a function. + As loops are initially directly integrated into the function definition, + {!apply_passes_to_def} extracts those loops definitions from the body; + it thus returns the pair: (function def, loop defs). See {!decompose_loops} + for more information. + Will return [None] if the function is a backward function with no outputs. [ctx]: used only for printing. *) -let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option = +let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : + (fun_decl * fun_decl list) option = (* Debug *) log#ldebug (lazy @@ -1347,101 +1620,19 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option = match def with | None -> None | Some def -> - (* Convert the unit variables to [()] if they are used as right-values or - * [_] if they are used as left values. *) - let def = unit_vars_to_unit def in - log#ldebug - (lazy ("unit_vars_to_unit:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); - - (* Inline the useless variable reassignments *) - let inline_named_vars = true in - let inline_pure = true in - let def = - inline_useless_var_reassignments inline_named_vars inline_pure def - in - log#ldebug - (lazy - ("inline_useless_var_assignments:\n\n" ^ fun_decl_to_string ctx def - ^ "\n")); - - (* Eliminate the box functions - note that the "box" types were eliminated - * during the symbolic to pure phase: see the comments for [eliminate_box_functions] *) - let def = eliminate_box_functions ctx def in - log#ldebug - (lazy - ("eliminate_box_functions:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); - - (* Filter the useless variables, assignments, function calls, etc. *) - let def = filter_useless !Config.filter_useless_monadic_calls ctx def in - log#ldebug - (lazy ("filter_useless:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + (* Extract the loop definitions by removing the {!Loop} node *) + let def, loops = decompose_loops def in - (* Simplify the aggregated ADTs. - - Ex.: - {[ - (* type struct = { f0 : nat; f1 : nat } *) - - Mkstruct x.f0 x.f1 ~~> x - ]} - *) - let def = simplify_aggregates ctx def in - log#ldebug - (lazy ("simplify_aggregates:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); - - (* Decompose the monadic let-bindings - used by Coq *) - let def = - if !Config.decompose_monadic_let_bindings then ( - let def = decompose_monadic_let_bindings ctx def in - log#ldebug - (lazy - ("decompose_monadic_let_bindings:\n\n" - ^ fun_decl_to_string ctx def ^ "\n")); - def) - else ( - log#ldebug - (lazy - "ignoring decompose_monadic_let_bindings due to the configuration\n"); - def) - in - - (* Decompose nested let-patterns *) - let def = - if !Config.decompose_nested_let_patterns then ( - let def = decompose_nested_let_patterns ctx def in - log#ldebug - (lazy - ("decompose_nested_let_patterns:\n\n" ^ fun_decl_to_string ctx def - ^ "\n")); - def) - else ( - log#ldebug - (lazy - "ignoring decompose_nested_let_patterns due to the configuration\n"); - def) - in - - (* Unfold the monadic let-bindings *) - let def = - if !Config.unfold_monadic_let_bindings then ( - let def = unfold_monadic_let_bindings ctx def in - log#ldebug - (lazy - ("unfold_monadic_let_bindings:\n\n" ^ fun_decl_to_string ctx def - ^ "\n")); - def) - else ( - log#ldebug - (lazy - "ignoring unfold_monadic_let_bindings due to the configuration\n"); - def) - in - - (* We are done *) - Some def + (* Apply the remaining passes *) + let def = apply_end_passes_to_def ctx def in + let loops = List.map (apply_end_passes_to_def ctx) loops in + Some (def, loops) (** Return the forward/backward translations on which we applied the micro-passes. + This function also extracts the loop definitions from the function body + (see {!decompose_loops}). + Also returns a boolean indicating whether the forward function should be kept or not (because useful/useless - [true] means we need to keep the forward function). @@ -1450,7 +1641,7 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option = functions: keeping it is not necessary but more convenient. *) let apply_passes_to_pure_fun_translation (ctx : trans_ctx) - (trans : pure_fun_translation) : bool * pure_fun_translation = + (trans : fun_decl * fun_decl list) : bool * pure_fun_translation = (* Apply the passes to the individual functions *) let forward, backwards = trans in let forward = Option.get (apply_passes_to_def ctx forward) in diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index b5c9b686..e1421f5a 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -7,6 +7,9 @@ let log = Logging.pure_utils_log type regular_fun_id = A.fun_id * T.RegionGroupId.id option [@@deriving show, ord] +(** We use this type as a key for lookups *) +type fun_loop_id = A.FunDeclId.id * LoopId.id option [@@deriving show, ord] + module RegularFunIdOrderedType = struct type t = regular_fun_id @@ -30,6 +33,18 @@ end module FunOrOpIdMap = Collections.MakeMap (FunOrOpIdOrderedType) module FunOrOpIdSet = Collections.MakeSet (FunOrOpIdOrderedType) +module FunLoopIdOrderedType = struct + type t = fun_loop_id + + let compare = compare_fun_loop_id + let to_string = show_fun_loop_id + let pp_t = pp_fun_loop_id + let show_t = show_fun_loop_id +end + +module FunLoopIdMap = Collections.MakeMap (FunLoopIdOrderedType) +module FunLoopIdSet = Collections.MakeSet (FunLoopIdOrderedType) + let dest_arrow_ty (ty : ty) : ty * ty = match ty with | Arrow (arg_ty, ret_ty) -> (arg_ty, ret_ty) diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index a2b41165..ad603bd5 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -2261,7 +2261,19 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = let loop_body = translate_expression loop.loop_expr ctx_loop in (* Create the loop node and return *) - let loop = Loop { fun_end; loop_id; inputs; inputs_lvs; loop_body } in + let loop = + Loop + { + fun_end; + loop_id; + fuel0 = ctx.fuel0; + fuel = ctx.fuel; + input_state = (if !Config.use_state then Some ctx.state_var else None); + inputs; + inputs_lvs; + loop_body; + } + in assert (fun_end.ty = loop_body.ty); let ty = fun_end.ty in { e = loop; ty } @@ -2282,10 +2294,11 @@ and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) : { e; ty } (** Wrap a function body in a match over the fuel to control termination. *) -let wrap_in_match_fuel (body : texpression) (ctx : bs_ctx) : texpression = - let fuel0_var : var = mk_fuel_var ctx.fuel0 in +let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression) + : texpression = + let fuel0_var : var = mk_fuel_var fuel0 in let fuel0 = mk_texpression_from_var fuel0_var in - let nfuel_var : var = mk_fuel_var ctx.fuel in + let nfuel_var : var = mk_fuel_var fuel in let nfuel_pat = mk_typed_pattern_from_var nfuel_var None in let fail_branch = mk_result_fail_texpression_with_error_id error_out_of_fuel_id body.ty @@ -2376,7 +2389,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = (* Add a match over the fuel, if necessary *) let body = if function_decreases_fuel effect_info then - wrap_in_match_fuel body ctx + wrap_in_match_fuel ctx.fuel0 ctx.fuel body else body in (* Sanity check *) diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 32c32ac4..10a37770 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -51,7 +51,7 @@ let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : A.fun_decl) let translate_function_to_pure (trans_ctx : trans_ctx) (fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdMap.t) (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : A.fun_decl) - : pure_fun_translation = + : pure_fun_translation_no_loops = (* Debug *) log#ldebug (lazy @@ -213,7 +213,8 @@ let translate_function_to_pure (trans_ctx : trans_ctx) sg.info.num_fwd_inputs_with_fuel_with_state in let num_back_inputs = Option.get sg.info.num_back_inputs_no_state in - Collections.List.subslice sg.inputs num_forward_inputs num_back_inputs + Collections.List.subslice sg.inputs num_forward_inputs + (num_forward_inputs + num_back_inputs) in (* As we forbid nested borrows, the additional inputs for the backward * functions come from the borrows in the return value of the rust function: @@ -336,7 +337,7 @@ type gen_ctx = { extract_ctx : ExtractBase.extraction_ctx; trans_types : Pure.type_decl Pure.TypeDeclId.Map.t; trans_funs : (bool * pure_fun_translation) A.FunDeclId.Map.t; - functions_with_decreases_clause : A.FunDeclId.Set.t; + functions_with_decreases_clause : PureUtils.FunLoopIdSet.t; } type gen_config = { @@ -370,7 +371,7 @@ let module_has_opaque_decls (ctx : gen_ctx) : bool * bool = in let has_opaque_funs = A.FunDeclId.Map.exists - (fun _ ((_, (t_fwd, _)) : bool * pure_fun_translation) -> + (fun _ ((_, ((t_fwd, _), _)) : bool * pure_fun_translation) -> Option.is_none t_fwd.body) ctx.trans_funs in @@ -452,10 +453,11 @@ let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) (id : A.GlobalDeclId.id) : unit = let global_decls = ctx.extract_ctx.trans_ctx.global_context.global_decls in let global = A.GlobalDeclId.Map.find id global_decls in - let _, (body, body_backs) = + let _, ((body, loop_fwds), body_backs) = A.FunDeclId.Map.find global.body_id ctx.trans_funs in - assert (List.length body_backs = 0); + assert (body_backs = []); + assert (loop_fwds = []); let is_opaque = Option.is_none body.Pure.body in if @@ -487,7 +489,8 @@ let export_functions_declarations (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) (is_rec : bool) (decls : Pure.fun_decl list) : unit = (* Utility to check a function has a decrease clause *) let has_decreases_clause (def : Pure.fun_decl) : bool = - A.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause + PureUtils.FunLoopIdSet.mem (def.def_id, def.loop_id) + ctx.functions_with_decreases_clause in (* Extract the function declarations *) @@ -532,16 +535,21 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) (pure_ls : (bool * pure_fun_translation) list) : unit = (* Utility to check a function has a decrease clause *) let has_decreases_clause (def : Pure.fun_decl) : bool = - A.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause + PureUtils.FunLoopIdSet.mem (def.def_id, def.loop_id) + ctx.functions_with_decreases_clause in (* Extract the decrease clauses template bodies *) if config.extract_template_decreases_clauses then List.iter - (fun (_, (fwd, _)) -> - let has_decr_clause = has_decreases_clause fwd in - if has_decr_clause then - Extract.extract_template_decreases_clause ctx.extract_ctx fmt fwd) + (fun (_, ((fwd, loop_fwds), _)) -> + let extract_decrease decl = + let has_decr_clause = has_decreases_clause decl in + if has_decr_clause then + Extract.extract_template_decreases_clause ctx.extract_ctx fmt decl + in + extract_decrease fwd; + List.iter extract_decrease loop_fwds) pure_ls; (* Concatenate the function definitions, filtering the useless forward @@ -549,8 +557,15 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) let decls = List.concat (List.map - (fun (keep_fwd, (fwd, back_ls)) -> - if keep_fwd then fwd :: back_ls else back_ls) + (fun (keep_fwd, ((fwd, fwd_loops), (back_ls : fun_and_loops list))) -> + let fwd = if keep_fwd then List.append fwd_loops [ fwd ] else [] in + let back : Pure.fun_decl list = + List.concat + (List.map + (fun (back, loop_backs) -> List.append loop_backs [ back ]) + back_ls) + in + List.append fwd back) pure_ls) in @@ -568,7 +583,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) (* Insert unit tests if necessary *) if config.test_trans_unit_functions then List.iter - (fun (keep_fwd, (fwd, _)) -> + (fun (keep_fwd, ((fwd, _), _)) -> if keep_fwd then Extract.extract_unit_test_if_unit_fun ctx.extract_ctx fmt fwd) pure_ls @@ -721,12 +736,25 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : (* We need to compute which functions are recursive, in order to know * whether we should generate a decrease clause or not. *) let rec_functions = - A.FunDeclId.Set.of_list - (List.concat - (List.map - (fun decl -> match decl with A.Fun (Rec ids) -> ids | _ -> []) - crate.declarations)) + List.map + (fun (_, ((fwd, loop_fwds), _)) -> + let fwd = + if fwd.Pure.signature.info.effect_info.is_rec then + [ (fwd.def_id, None) ] + else [] + in + let loop_fwds = + List.map + (fun (def : Pure.fun_decl) -> [ (def.def_id, def.loop_id) ]) + loop_fwds + in + fwd :: loop_fwds) + trans_funs + in + let rec_functions : PureUtils.fun_loop_id list = + List.concat (List.concat rec_functions) in + let rec_functions = PureUtils.FunLoopIdSet.of_list rec_functions in (* Register unique names for all the top-level types, globals and functions. * Note that the order in which we generate the names doesn't matter: @@ -740,18 +768,21 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : let ctx = List.fold_left - (fun ctx (keep_fwd, def) -> + (fun ctx (keep_fwd, defs) -> (* We generate a decrease clause for all the recursive functions *) - let gen_decr_clause = - A.FunDeclId.Set.mem (fst def).Pure.def_id rec_functions + let fwd_def = fst (fst defs) in + let gen_decr_clause (def : Pure.fun_decl) = + PureUtils.FunLoopIdSet.mem + (def.Pure.def_id, def.Pure.loop_id) + rec_functions in (* Register the names, only if the function is not a global body - * those are handled later *) - let is_global = (fst def).Pure.is_global_decl_body in + let is_global = fwd_def.Pure.is_global_decl_body in if is_global then ctx else Extract.extract_fun_decl_register_names ctx keep_fwd gen_decr_clause - def) + defs) ctx trans_funs in @@ -785,7 +816,7 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : A.FunDeclId.Map.of_list (List.map (fun ((keep_fwd, (fd, bdl)) : bool * pure_fun_translation) -> - (fd.def_id, (keep_fwd, (fd, bdl)))) + ((fst fd).def_id, (keep_fwd, (fd, bdl)))) trans_funs) in @@ -883,7 +914,7 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : (* Extract the template clauses *) let needs_clauses_module = !Config.extract_decreases_clauses - && not (A.FunDeclId.Set.is_empty rec_functions) + && not (PureUtils.FunLoopIdSet.is_empty rec_functions) in (if needs_clauses_module && !Config.extract_template_decreases_clauses then let ext = match !Config.backend with FStar -> ".fst" | Coq -> ".v" in diff --git a/compiler/TranslateCore.ml b/compiler/TranslateCore.ml index a658147d..9ba73c7e 100644 --- a/compiler/TranslateCore.ml +++ b/compiler/TranslateCore.ml @@ -26,7 +26,9 @@ type trans_ctx = { global_context : global_context; } -type pure_fun_translation = Pure.fun_decl * Pure.fun_decl list +type fun_and_loops = Pure.fun_decl * Pure.fun_decl list +type pure_fun_translation_no_loops = Pure.fun_decl * Pure.fun_decl list +type pure_fun_translation = fun_and_loops * fun_and_loops list let type_decl_to_string (ctx : trans_ctx) (def : Pure.type_decl) : string = let type_params = def.type_params in |