From 496a3849d1d6ba880bbd1e86c8ef5e2257bb702a Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 13 Dec 2023 10:55:57 +0100 Subject: Add the num_fwd_inputs_no_fuel_no_state field in Pure.fun_sig --- compiler/Pure.ml | 8 +++++--- compiler/PureMicroPasses.ml | 5 +++++ compiler/SymbolicToPure.ml | 5 +++-- 3 files changed, 13 insertions(+), 5 deletions(-) (limited to 'compiler') diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 0ae83007..d7aea0f7 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -860,8 +860,8 @@ type fun_effect_info = { the set [{ forward function } U { backward functions }]. We need this because of the option {!val:Config.backward_no_state_update}: - if it is [true], then in case of a backward function {!stateful} is [false], - but we might need to know whether the corresponding forward function + if it is [true], then in case of a backward function {!stateful} might be + [false], but we might need to know whether the corresponding forward function is stateful or not. *) stateful : bool; (** [true] if the function is stateful (updates a state) *) @@ -876,7 +876,9 @@ type fun_effect_info = { (** Meta information about a function signature *) type fun_sig_info = { has_fuel : bool; - (* TODO: add [num_fwd_inputs_no_fuel_no_state] *) + num_fwd_inputs_no_fuel_no_state : int; + (** The number of input types for forward computation, ignoring the fuel (if used) + and ignoring the state (if used) *) num_fwd_inputs_with_fuel_no_state : int; (** The number of input types for forward computation, with the fuel (if used) and ignoring the state (if used) *) diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 959ec1c8..34578750 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -1340,6 +1340,7 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = let loop_sig_info = let fuel = if !Config.use_fuel then 1 else 0 in let num_inputs = List.length loop.inputs in + let num_fwd_inputs_no_fuel_no_state = num_inputs in let num_fwd_inputs_with_fuel_no_state = fuel + num_inputs in let fwd_state = fun_sig_info.num_fwd_inputs_with_fuel_with_state @@ -1350,6 +1351,7 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = in { has_fuel = !Config.use_fuel; + num_fwd_inputs_no_fuel_no_state; num_fwd_inputs_with_fuel_no_state; num_fwd_inputs_with_fuel_with_state; num_back_inputs_no_state = fun_sig_info.num_back_inputs_no_state; @@ -2168,6 +2170,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : in let { has_fuel; + num_fwd_inputs_no_fuel_no_state; num_fwd_inputs_with_fuel_no_state; num_fwd_inputs_with_fuel_with_state; num_back_inputs_no_state; @@ -2182,6 +2185,8 @@ let filter_loop_inputs (transl : pure_fun_translation list) : let info = { has_fuel; + num_fwd_inputs_no_fuel_no_state = + num_fwd_inputs_no_fuel_no_state - num_filtered; num_fwd_inputs_with_fuel_no_state = num_fwd_inputs_with_fuel_no_state - num_filtered; num_fwd_inputs_with_fuel_with_state = diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index bf4d26f2..2ef313e6 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -1035,10 +1035,10 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) let generics = translate_generic_params sg.generics in (* Return *) let has_fuel = fuel <> [] in - let num_fwd_inputs_no_state = List.length fwd_inputs in + let num_fwd_inputs_no_fuel_no_state = List.length fwd_inputs in let num_fwd_inputs_with_fuel_no_state = (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *) - List.length fuel + num_fwd_inputs_no_state + List.length fuel + num_fwd_inputs_no_fuel_no_state in let num_back_inputs_no_state = if bid = None then None else Some (List.length back_inputs) @@ -1046,6 +1046,7 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) let info = { has_fuel; + num_fwd_inputs_no_fuel_no_state; num_fwd_inputs_with_fuel_no_state; num_fwd_inputs_with_fuel_with_state = (* We use the fact that [fwd_state_ty] has length 1 if there is a state, -- cgit v1.2.3 From 0c814c97dd8e5167f24b0dbb14186d674e4d097b Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 13 Dec 2023 11:44:58 +0100 Subject: Update Pure.fun_sig_info --- compiler/Config.ml | 4 ++ compiler/Extract.ml | 8 +++- compiler/Pure.ml | 30 +++++++------- compiler/PureMicroPasses.ml | 96 +++++++++++++++++++++++++++------------------ compiler/PureUtils.ml | 19 +++++++++ compiler/SymbolicToPure.ml | 32 +++++++++------ compiler/Translate.ml | 8 +++- 7 files changed, 129 insertions(+), 68 deletions(-) (limited to 'compiler') diff --git a/compiler/Config.ml b/compiler/Config.ml index b09544ba..9cd1ebc2 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -92,6 +92,10 @@ let loop_fixed_point_max_num_iters = 2 (** {1 Translation} *) +(** If true, do not define separate forward/backward functions, but make the + forward functions return the backward function. *) +let return_back_funs = ref false + (** Forbids using field projectors for structures. If we don't use field projectors, whenever we symbolically expand a structure diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 20cdb20b..93fcf416 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -1469,8 +1469,10 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) *) let inputs_lvs = let all_inputs = (Option.get def.body).inputs_lvs in + (* TODO: *) + assert (not !Config.return_back_funs); let num_fwd_inputs = - def.signature.info.num_fwd_inputs_with_fuel_with_state + def.signature.info.fwd_info.num_inputs_with_fuel_with_state in Collections.List.prefix num_fwd_inputs all_inputs in @@ -1515,8 +1517,10 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) if has_decreases_clause && !backend = Lean then ( let def_body = Option.get def.body in let all_vars = List.map (fun (v : var) -> v.id) def_body.inputs in + (* TODO: *) + assert (not !Config.return_back_funs); let num_fwd_inputs = - def.signature.info.num_fwd_inputs_with_fuel_with_state + def.signature.info.fwd_info.num_inputs_with_fuel_with_state in let vars = Collections.List.prefix num_fwd_inputs all_vars in diff --git a/compiler/Pure.ml b/compiler/Pure.ml index d7aea0f7..80d8782b 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -873,23 +873,25 @@ type fun_effect_info = { } [@@deriving show] -(** Meta information about a function signature *) -type fun_sig_info = { +type inputs_info = { has_fuel : bool; - num_fwd_inputs_no_fuel_no_state : int; - (** The number of input types for forward computation, ignoring the fuel (if used) + num_inputs_no_fuel_no_state : int; + (** The number of input types ignoring the fuel (if used) and ignoring the state (if used) *) - num_fwd_inputs_with_fuel_no_state : int; - (** The number of input types for forward computation, with the fuel (if used) + num_inputs_with_fuel_no_state : int; + (** The number of input types, with the fuel (if used) and ignoring the state (if used) *) - num_fwd_inputs_with_fuel_with_state : int; - (** The number of input types for forward computation, with fuel and state (if used) *) - num_back_inputs_no_state : int option; - (** The number of additional inputs for the backward computation (if pertinent), - ignoring the state (if there is one) *) - num_back_inputs_with_state : int option; - (** The number of additional inputs for the backward computation (if pertinent), - with the state (if there is one) *) + num_inputs_with_fuel_with_state : int; + (** The number of input types, with fuel and state (if used) *) +} +[@@deriving show] + +(** Meta information about a function signature *) +type fun_sig_info = { + fwd_info : inputs_info; + (** Information about the inputs of the forward function *) + back_info : inputs_info option; + (** Information about the inputs of the backward function, if pertinent *) effect_info : fun_effect_info; } [@@deriving show] diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 34578750..7f122f15 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -1326,6 +1326,9 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = let fun_sig_info = fun_sig.info in let fun_effect_info = fun_sig_info.effect_info in + (* TODO: *) + assert (not !Config.return_back_funs); + (* Generate the loop definition *) let loop_effect_info = { @@ -1340,38 +1343,44 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = let loop_sig_info = let fuel = if !Config.use_fuel then 1 else 0 in let num_inputs = List.length loop.inputs in - let num_fwd_inputs_no_fuel_no_state = num_inputs in - let num_fwd_inputs_with_fuel_no_state = fuel + num_inputs in - let fwd_state = - fun_sig_info.num_fwd_inputs_with_fuel_with_state - - fun_sig_info.num_fwd_inputs_with_fuel_no_state - in - let num_fwd_inputs_with_fuel_with_state = - num_fwd_inputs_with_fuel_no_state + fwd_state + let fwd_info : inputs_info = + let info = fun_sig_info.fwd_info in + let fwd_state = + info.num_inputs_with_fuel_with_state + - info.num_inputs_with_fuel_no_state + in + { + has_fuel = !Config.use_fuel; + num_inputs_no_fuel_no_state = num_inputs; + num_inputs_with_fuel_no_state = num_inputs + fuel; + num_inputs_with_fuel_with_state = + num_inputs + fuel + fwd_state; + } in + { - has_fuel = !Config.use_fuel; - num_fwd_inputs_no_fuel_no_state; - num_fwd_inputs_with_fuel_no_state; - num_fwd_inputs_with_fuel_with_state; - num_back_inputs_no_state = fun_sig_info.num_back_inputs_no_state; - num_back_inputs_with_state = - fun_sig_info.num_back_inputs_with_state; + fwd_info; + back_info = fun_sig_info.back_info; effect_info = loop_effect_info; } in + assert (fun_sig_info_is_wf loop_sig_info); let inputs_tys = + (* TODO: *) + assert (not !Config.return_back_funs); + let fuel = if !Config.use_fuel then [ mk_fuel_ty ] else [] in let fwd_inputs = List.map (fun (v : var) -> v.ty) loop.inputs in + let info = fun_sig_info.fwd_info in let state = Collections.List.subslice fun_sig.inputs - fun_sig_info.num_fwd_inputs_with_fuel_no_state - fun_sig_info.num_fwd_inputs_with_fuel_with_state + info.num_inputs_with_fuel_no_state + info.num_inputs_with_fuel_with_state in let _, back_inputs = Collections.List.split_at fun_sig.inputs - fun_sig_info.num_fwd_inputs_with_fuel_with_state + info.num_inputs_with_fuel_with_state in List.concat [ fuel; fwd_inputs; state; back_inputs ] in @@ -1432,14 +1441,17 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = in (* Introduce the additional backward inputs *) + (* TODO: *) + assert (not !Config.return_back_funs); let fun_body = Option.get def.body in + let info = fun_sig_info.fwd_info in let _, back_inputs = Collections.List.split_at fun_body.inputs - fun_sig_info.num_fwd_inputs_with_fuel_with_state + info.num_inputs_with_fuel_with_state in let _, back_inputs_lvs = Collections.List.split_at fun_body.inputs_lvs - fun_sig_info.num_fwd_inputs_with_fuel_with_state + info.num_inputs_with_fuel_with_state in let inputs = @@ -2055,12 +2067,14 @@ let filter_loop_inputs (transl : pure_fun_translation list) : (* We start by computing the filtering information, for each function *) let compute_one_filter_info (decl : fun_decl) = + (* TODO: *) + assert (not !Config.return_back_funs); (* There should be a body *) let body = Option.get decl.body in (* We only look at the forward inputs, without the state *) let inputs_prefix, _ = Collections.List.split_at body.inputs - decl.signature.info.num_fwd_inputs_with_fuel_no_state + decl.signature.info.fwd_info.num_inputs_with_fuel_no_state in let used = ref (List.map (fun v -> (var_get_id v, false)) inputs_prefix) in let inputs_prefix_length = List.length inputs_prefix in @@ -2080,7 +2094,10 @@ let filter_loop_inputs (transl : pure_fun_translation list) : (* Set the fuel as used *) let sg_info = decl.signature.info in - if sg_info.has_fuel then set_used (fst (Collections.List.nth inputs 0)); + (* TODO: *) + assert (not !Config.return_back_funs); + if sg_info.fwd_info.has_fuel then + set_used (fst (Collections.List.nth inputs 0)); let visitor = object (self : 'self) @@ -2168,34 +2185,35 @@ let filter_loop_inputs (transl : pure_fun_translation list) : = decl.signature in + (* TODO: *) + assert (not !Config.return_back_funs); + let { fwd_info; back_info; effect_info } = info in + let { has_fuel; - num_fwd_inputs_no_fuel_no_state; - num_fwd_inputs_with_fuel_no_state; - num_fwd_inputs_with_fuel_with_state; - num_back_inputs_no_state; - num_back_inputs_with_state; - effect_info; + num_inputs_no_fuel_no_state; + num_inputs_with_fuel_no_state; + num_inputs_with_fuel_with_state; } = - info + fwd_info in let inputs = filter_prefix used_info inputs in - let info = + let fwd_info = { has_fuel; - num_fwd_inputs_no_fuel_no_state = - num_fwd_inputs_no_fuel_no_state - num_filtered; - num_fwd_inputs_with_fuel_no_state = - num_fwd_inputs_with_fuel_no_state - num_filtered; - num_fwd_inputs_with_fuel_with_state = - num_fwd_inputs_with_fuel_with_state - num_filtered; - num_back_inputs_no_state; - num_back_inputs_with_state; - effect_info; + num_inputs_no_fuel_no_state = + num_inputs_no_fuel_no_state - num_filtered; + num_inputs_with_fuel_no_state = + num_inputs_with_fuel_no_state - num_filtered; + num_inputs_with_fuel_with_state = + num_inputs_with_fuel_with_state - num_filtered; } in + + let info = { fwd_info; back_info; effect_info } in + assert (fun_sig_info_is_wf info); let signature = { generics; llbc_generics; preds; inputs; output; doutputs; info } in diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 39dcd52d..23a41f0e 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -57,6 +57,25 @@ end module FunLoopIdMap = Collections.MakeMap (FunLoopIdOrderedType) module FunLoopIdSet = Collections.MakeSet (FunLoopIdOrderedType) +let inputs_info_is_wf (info : inputs_info) : bool = + let { + has_fuel; + num_inputs_no_fuel_no_state; + num_inputs_with_fuel_no_state; + num_inputs_with_fuel_with_state; + } = + info + in + let fuel = if has_fuel then 1 else 0 in + num_inputs_no_fuel_no_state >= 0 + && num_inputs_with_fuel_no_state = num_inputs_no_fuel_no_state + fuel + && num_inputs_with_fuel_with_state >= num_inputs_with_fuel_no_state + +let fun_sig_info_is_wf (info : fun_sig_info) : bool = + inputs_info_is_wf info.fwd_info + && + match info.back_info with None -> true | Some info -> inputs_info_is_wf info + let dest_arrow_ty (ty : ty) : ty * ty = match ty with | TArrow (arg_ty, ret_ty) -> (arg_ty, ret_ty) diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 2ef313e6..971a8cbd 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -1034,6 +1034,8 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) (* Generic parameters *) let generics = translate_generic_params sg.generics in (* Return *) + (* TODO: *) + assert (not !Config.return_back_funs); let has_fuel = fuel <> [] in let num_fwd_inputs_no_fuel_no_state = List.length fwd_inputs in let num_fwd_inputs_with_fuel_no_state = @@ -1043,24 +1045,32 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) let num_back_inputs_no_state = if bid = None then None else Some (List.length back_inputs) in - let info = + let fwd_info : inputs_info = { has_fuel; - num_fwd_inputs_no_fuel_no_state; - num_fwd_inputs_with_fuel_no_state; - num_fwd_inputs_with_fuel_with_state = + num_inputs_no_fuel_no_state = num_fwd_inputs_no_fuel_no_state; + num_inputs_with_fuel_no_state = num_fwd_inputs_with_fuel_no_state; + num_inputs_with_fuel_with_state = (* We use the fact that [fwd_state_ty] has length 1 if there is a state, and 0 otherwise *) num_fwd_inputs_with_fuel_no_state + List.length fwd_state_ty; - num_back_inputs_no_state; - num_back_inputs_with_state = - (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) - Option.map - (fun n -> n + List.length back_state_ty) - num_back_inputs_no_state; - effect_info; } in + let back_info : inputs_info option = + Option.map + (fun n -> + (* Note that backward functions never use fuel *) + { + has_fuel = false; + num_inputs_no_fuel_no_state = n; + num_inputs_with_fuel_no_state = n; + (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) + num_inputs_with_fuel_with_state = n + List.length back_state_ty; + }) + num_back_inputs_no_state + in + let info = { fwd_info; back_info; effect_info } in + assert (fun_sig_info_is_wf info); let preds = translate_predicates sg.preds in let sg = { diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 221d4e73..54e24066 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -216,11 +216,15 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (* We need to ignore the forward inputs, and the state input (if there is) *) let backward_inputs = let sg = backward_sg.sg in + (* TODO: *) + assert (not !Config.return_back_funs); (* We need to ignore the forward state and the backward state *) let num_forward_inputs = - sg.info.num_fwd_inputs_with_fuel_with_state + sg.info.fwd_info.num_inputs_with_fuel_with_state + in + let num_back_inputs = + (Option.get sg.info.back_info).num_inputs_no_fuel_no_state in - let num_back_inputs = Option.get sg.info.num_back_inputs_no_state in Collections.List.subslice sg.inputs num_forward_inputs (num_forward_inputs + num_back_inputs) in -- cgit v1.2.3 From 7630c45b7990d0df1db022f827e7de676ad4499a Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 13 Dec 2023 11:48:53 +0100 Subject: Make a minor modification in a comment --- compiler/InterpreterExpressions.mli | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'compiler') diff --git a/compiler/InterpreterExpressions.mli b/compiler/InterpreterExpressions.mli index f8d979f4..b975371c 100644 --- a/compiler/InterpreterExpressions.mli +++ b/compiler/InterpreterExpressions.mli @@ -52,7 +52,7 @@ val eval_operands : Transmits the computed rvalue to the received continuation. - Note that this function fails on {!constructor:Aeneas.Expressions.rvalue.Discriminant}: discriminant + Note that this function fails on {!Aeneas.Expressions.rvalue.Discriminant}: discriminant reads should have been eliminated from the AST. *) val eval_rvalue_not_global : -- cgit v1.2.3 From f69ac6a4a244c99a41a90ed57f74ea83b3835882 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 14 Dec 2023 17:11:01 +0100 Subject: Start updating Pure.fun_sig_info to handle merged forward and backward functions --- compiler/Config.ml | 63 ++++++++++++++++++++++++++++++++++++++++++++-- compiler/Pure.ml | 24 +++++++++++++++--- compiler/PureUtils.ml | 6 ++++- compiler/SymbolicToPure.ml | 62 ++++++++++++++++++++++++++++++--------------- 4 files changed, 128 insertions(+), 27 deletions(-) (limited to 'compiler') diff --git a/compiler/Config.ml b/compiler/Config.ml index 9cd1ebc2..b8af6c6d 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -93,8 +93,67 @@ let loop_fixed_point_max_num_iters = 2 (** {1 Translation} *) (** If true, do not define separate forward/backward functions, but make the - forward functions return the backward function. *) -let return_back_funs = ref false + forward functions return the backward function. + + Example: + {[ + (* Rust *) + pub fn list_nth<'a, T>(l: &'a mut List, i: u32) -> &'a mut T { + match l { + List::Nil => { + panic!() + } + List::Cons(x, tl) => { + if i == 0 { + x + } else { + list_nth(tl, i - 1) + } + } + } + } + + (* Translation, if return_back_funs = false *) + def list_nth (T : Type) (l : List T) (i : U32) : Result T := + match l with + | List.Cons x tl => + if i = 0#u32 + then Result.ret x + else do + let i0 ← i - 1#u32 + list_nth T tl i0 + | List.Nil => Result.fail .panic + + def list_nth_back + (T : Type) (l : List T) (i : U32) (ret : T) : Result (List T) := + match l with + | List.Cons x tl => + if i = 0#u32 + then Result.ret (List.Cons ret tl) + else + do + let i0 ← i - 1#u32 + let tl0 ← list_nth_back T tl i0 ret + Result.ret (List.Cons x tl0) + | List.Nil => Result.fail .panic + + (* Translation, if return_back_funs = true *) + def list_nth (T: Type) (ls : List T) (i : U32) : + Result (T × (T → Result (List T))) := + match ls with + | List.Cons x tl => + if i = 0#u32 + then Result.ret (x, (λ ret => return (ret :: ls))) + else do + let i0 ← i - 1#u32 + let (x, back) ← list_nth ls i0 + Return.ret (x, + (λ ret => do + let ls ← back ret + return (x :: ls))) + ]} + *) +let return_back_funs = ref true (** Forbids using field projectors for structures. diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 80d8782b..bb522623 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -561,7 +561,7 @@ type fun_id_or_trait_method_ref = (** A function id for a non-assumed function *) type regular_fun_id = - fun_id_or_trait_method_ref * LoopId.id option * T.RegionGroupId.id option + fun_id_or_trait_method_ref * LoopId.id option * RegionGroupId.id option [@@deriving show, ord] (** A function identifier *) @@ -886,12 +886,28 @@ type inputs_info = { } [@@deriving show] +type 'a back_info = + | SingleBack of 'a option + (** Information about a single backward function, if pertinent. + + We use this variant if we split the forward and the backward functions. + *) + | AllBacks of 'a RegionGroupId.Map.t + (** Information about the various backward functions. + + We use this if we *do not* split the forward and the backward functions. + All the information is then carried by the forward function. + *) +[@@deriving show] + +type back_inputs_info = inputs_info back_info [@@deriving show] + (** Meta information about a function signature *) type fun_sig_info = { fwd_info : inputs_info; (** Information about the inputs of the forward function *) - back_info : inputs_info option; - (** Information about the inputs of the backward function, if pertinent *) + back_info : back_inputs_info; + (** Information about the inputs of the backward functions. *) effect_info : fun_effect_info; } [@@deriving show] @@ -1024,7 +1040,7 @@ type fun_decl = { *) loop_id : LoopId.id option; (** [Some] if this definition was generated for a loop *) - back_id : T.RegionGroupId.id option; + back_id : RegionGroupId.id option; llbc_name : llbc_name; (** The original LLBC name. *) name : string; (** We use the name only for printing purposes (for debugging): diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 23a41f0e..3c038149 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -74,7 +74,11 @@ let inputs_info_is_wf (info : inputs_info) : bool = let fun_sig_info_is_wf (info : fun_sig_info) : bool = inputs_info_is_wf info.fwd_info && - match info.back_info with None -> true | Some info -> inputs_info_is_wf info + match info.back_info with + | SingleBack None -> true + | SingleBack (Some info) -> inputs_info_is_wf info + | AllBacks infos -> + List.for_all inputs_info_is_wf (RegionGroupId.Map.values infos) let dest_arrow_ty (ty : ty) : ty * ty = match ty with diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 971a8cbd..59205f08 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -855,10 +855,14 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) name (outputs for backward functions come from borrows in the inputs of the forward function) which we use as hints to generate pretty names in the extracted code. + + We use [bid] ("backward function id") only if we split the forward + and the backward functions. *) let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) (sg : A.fun_sig) (input_names : string option list) (bid : T.RegionGroupId.id option) : fun_sig_named_outputs = + assert (Option.is_none bid || not !Config.return_back_funs); let fun_infos = decls_ctx.fun_ctx.fun_infos in let type_infos = decls_ctx.type_ctx.type_infos in (* Retrieve the list of parent backward functions *) @@ -939,6 +943,18 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) let inside_mut = false in translate_back_ty type_infos keep_region inside_mut ty in + let translate_back_inputs_for_gid gid : ty list = + (* For now, we don't allow nested borrows, so the additional inputs to the + backward function can only come from borrows that were returned like + in (for the backward function we introduce for 'a): + {[ + fn f<'a>(...) -> &'a mut u32; + ]} + Upon ending the abstraction for 'a, we need to get back the borrow + the function returned. + *) + List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] + in (* Compute the additinal inputs for the current function, if it is a backward * function *) let back_inputs = @@ -1056,18 +1072,22 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) num_fwd_inputs_with_fuel_no_state + List.length fwd_state_ty; } in - let back_info : inputs_info option = - Option.map - (fun n -> - (* Note that backward functions never use fuel *) - { - has_fuel = false; - num_inputs_no_fuel_no_state = n; - num_inputs_with_fuel_no_state = n; - (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) - num_inputs_with_fuel_with_state = n + List.length back_state_ty; - }) - num_back_inputs_no_state + let back_info : back_inputs_info = + if !Config.return_back_funs then + SingleBack + (Option.map + (fun n -> + (* Note that backward functions never use fuel *) + { + has_fuel = false; + num_inputs_no_fuel_no_state = n; + num_inputs_with_fuel_no_state = n; + (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) + num_inputs_with_fuel_with_state = n + List.length back_state_ty; + }) + num_back_inputs_no_state) + else (* Create the map *) + failwith "TODO" in let info = { fwd_info; back_info; effect_info } in assert (fun_sig_info_is_wf info); @@ -3162,14 +3182,16 @@ let translate_fun_signatures (decls_ctx : C.decls_ctx) let fwd_id = (fun_id, None) in (* The backward functions *) let back_sgs = - List.map - (fun (rg : T.region_var_group) -> - let tsg = - translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id) - in - let id = (fun_id, Some rg.id) in - (id, tsg)) - regions_hierarchy + if !Config.return_back_funs then [] + else + List.map + (fun (rg : T.region_var_group) -> + let tsg = + translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id) + in + let id = (fun_id, Some rg.id) in + (id, tsg)) + regions_hierarchy in (* Return *) (fwd_id, fwd_sg) :: back_sgs -- cgit v1.2.3 From f1f41818fb14a6c46442ca42a49a3aab0a5b1aaf Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 14 Dec 2023 17:48:44 +0100 Subject: Make progress on generated merged fwd/back functions --- compiler/SymbolicToPure.ml | 56 ++++++++++++++++++++++++---------------------- compiler/Translate.ml | 4 +++- 2 files changed, 32 insertions(+), 28 deletions(-) (limited to 'compiler') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 1fd4896e..86c80f87 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -958,19 +958,7 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) (* Compute the additinal inputs for the current function, if it is a backward * function *) let back_inputs = - match gid with - | None -> [] - | Some gid -> - (* For now, we don't allow nested borrows, so the additional inputs to the - backward function can only come from borrows that were returned like - in (for the backward function we introduce for 'a): - {[ - fn f<'a>(...) -> &'a mut u32; - ]} - Upon ending the abstraction for 'a, we need to get back the borrow - the function returned. - *) - List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] + match gid with None -> [] | Some gid -> translate_back_inputs_for_gid gid in (* If the function is stateful, the inputs are: - forward: [fwd_ty0, ..., fwd_tyn, state] @@ -989,11 +977,12 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) See {!effect_info}. *) if effect_info.stateful_group then [ mk_state_ty ] else [] in - let back_state_ty = + let mk_back_state_ty_for_gid (gid : RegionGroupId.id option) : ty list = (* For the backward state, we check if the function is a backward function, and it is stateful *) if effect_info.stateful && Option.is_some gid then [ mk_state_ty ] else [] in + let back_state_ty = mk_back_state_ty_for_gid gid in (* Concatenate the inputs, in the following order: * - forward inputs @@ -1072,22 +1061,35 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) num_fwd_inputs_with_fuel_no_state + List.length fwd_state_ty; } in + let compute_back_info (back_state_ty : ty list) + (num_back_inputs_no_state : int) : inputs_info = + let n = num_back_inputs_no_state in + (* Note that backward functions never use fuel *) + { + has_fuel = false; + num_inputs_no_fuel_no_state = n; + num_inputs_with_fuel_no_state = n; + (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) + num_inputs_with_fuel_with_state = n + List.length back_state_ty; + } + in let back_info : back_inputs_info = if !Config.return_back_funs then + (* Create the map *) + AllBacks + (RegionGroupId.Map.of_list + (List.map + (fun (rg : T.region_var_group) -> + ( rg.id, + let back_inputs = translate_back_inputs_for_gid rg.id in + let num_back_inputs = List.length back_inputs in + (* TODO: slightly overkill *) + let back_state_ty = mk_back_state_ty_for_gid (Some rg.id) in + compute_back_info back_state_ty num_back_inputs )) + regions_hierarchy)) + else SingleBack - (Option.map - (fun n -> - (* Note that backward functions never use fuel *) - { - has_fuel = false; - num_inputs_no_fuel_no_state = n; - num_inputs_with_fuel_no_state = n; - (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) - num_inputs_with_fuel_with_state = n + List.length back_state_ty; - }) - num_back_inputs_no_state) - else (* Create the map *) - failwith "TODO" + (Option.map (compute_back_info back_state_ty) num_back_inputs_no_state) in let info = { fwd_info; back_info; effect_info } in assert (fun_sig_info_is_wf info); diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 54e24066..06d4bd6d 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -223,7 +223,9 @@ let translate_function_to_pure (trans_ctx : trans_ctx) sg.info.fwd_info.num_inputs_with_fuel_with_state in let num_back_inputs = - (Option.get sg.info.back_info).num_inputs_no_fuel_no_state + match sg.info.back_info with + | SingleBack (Some info) -> info.num_inputs_no_fuel_no_state + | _ -> raise (Failure "Unexpected") in Collections.List.subslice sg.inputs num_forward_inputs (num_forward_inputs + num_back_inputs) -- cgit v1.2.3 From cf984f958da94154d0550060eb290a276ab52f23 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 15 Dec 2023 10:17:06 +0100 Subject: Make minor modifications --- compiler/Pure.ml | 9 ++-- compiler/SymbolicToPure.ml | 108 ++++++++++++++------------------------------- 2 files changed, 38 insertions(+), 79 deletions(-) (limited to 'compiler') diff --git a/compiler/Pure.ml b/compiler/Pure.ml index c3716001..34f3ef72 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -886,13 +886,13 @@ type inputs_info = { } [@@deriving show] -type 'a back_info = - | SingleBack of 'a option +type ('a, 'b) back_info = + | SingleBack of 'a (** Information about a single backward function, if pertinent. We use this variant if we split the forward and the backward functions. *) - | AllBacks of 'a RegionGroupId.Map.t + | AllBacks of 'b RegionGroupId.Map.t (** Information about the various backward functions. We use this if we *do not* split the forward and the backward functions. @@ -900,7 +900,8 @@ type 'a back_info = *) [@@deriving show] -type back_inputs_info = inputs_info back_info [@@deriving show] +type back_inputs_info = (inputs_info option, inputs_info) back_info +[@@deriving show] (** Meta information about a function signature *) type fun_sig_info = { diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 86c80f87..eba44e3e 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -308,20 +308,6 @@ let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string = let indent_incr = " " in Print.Values.abs_to_string env verbose indent indent_incr abs -let get_instantiated_fun_sig (fun_id : A.fun_id) - (back_id : T.RegionGroupId.id option) (generics : generic_args) - (ctx : bs_ctx) : inst_fun_sig = - (* Lookup the non-instantiated function signature *) - let sg = - (RegularFunIdNotLoopMap.find (fun_id, back_id) ctx.fun_context.fun_sigs).sg - in - (* Create the substitution *) - (* There shouldn't be any reference to Self *) - let tr_self = UnknownTrait __FUNCTION__ in - let subst = make_subst_from_generics sg.generics generics tr_self in - (* Apply *) - fun_sig_substitute subst sg - let bs_ctx_lookup_llbc_type_decl (id : TypeDeclId.id) (ctx : bs_ctx) : T.type_decl = TypeDeclId.Map.find id ctx.type_context.llbc_type_decls @@ -330,12 +316,6 @@ let bs_ctx_lookup_llbc_fun_decl (id : A.FunDeclId.id) (ctx : bs_ctx) : A.fun_decl = A.FunDeclId.Map.find id ctx.fun_context.llbc_fun_decls -(* TODO: move *) -let bs_ctx_lookup_local_function_sig (def_id : A.FunDeclId.id) - (back_id : T.RegionGroupId.id option) (ctx : bs_ctx) : fun_sig = - let id = (E.FRegular def_id, back_id) in - (RegularFunIdNotLoopMap.find id ctx.fun_context.fun_sigs).sg - (* Some generic translation functions (we need to translate different "flavours" of types: forward types, backward types, etc.) *) let rec translate_generic_args (translate_ty : T.ty -> ty) @@ -994,35 +974,44 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) List.concat [ fuel; fwd_inputs; fwd_state_ty; back_inputs; back_state_ty ] in (* Outputs *) + let compute_back_outputs_for_gid (gid : RegionGroupId.id) : + string option list * ty list = + (* The outputs are the borrows inside the regions of the abstractions + and which are present in the input values. For instance, see: + {[ + fn f<'a>(x : &'a mut u32) -> ...; + ]} + Upon ending the abstraction for 'a, we give back the borrow which + was consumed through the [x] parameter. + *) + let outputs = + List.map + (fun (name, input_ty) -> (name, translate_back_ty_for_gid gid input_ty)) + (List.combine input_names sg.inputs) + in + (* Filter *) + let outputs = + List.filter (fun (_, opt_ty) -> Option.is_some opt_ty) outputs + in + let outputs = + List.map (fun (name, opt_ty) -> (name, Option.get opt_ty)) outputs + in + List.split outputs + in let output_names, doutputs = match gid with | None -> - (* This is a forward function: there is one (unnamed) output *) + (* This is a forward function: there is one (unnamed) output. + + If we merge the fwd/back functions we might need to compute + the information about the back outputs. + *) + (* TODO: *) + assert (not !Config.return_back_funs); ([ None ], [ translate_fwd_ty type_infos sg.output ]) | Some gid -> - (* This is a backward function: there might be several outputs. - The outputs are the borrows inside the regions of the abstractions - and which are present in the input values. For instance, see: - {[ - fn f<'a>(x : &'a mut u32) -> ...; - ]} - Upon ending the abstraction for 'a, we give back the borrow which - was consumed through the [x] parameter. - *) - let outputs = - List.map - (fun (name, input_ty) -> - (name, translate_back_ty_for_gid gid input_ty)) - (List.combine input_names sg.inputs) - in - (* Filter *) - let outputs = - List.filter (fun (_, opt_ty) -> Option.is_some opt_ty) outputs - in - let outputs = - List.map (fun (name, opt_ty) -> (name, Option.get opt_ty)) outputs - in - List.split outputs + (* This is a backward function: there might be several outputs. *) + compute_back_outputs_for_gid gid in (* Create the return type *) let output = @@ -2016,37 +2005,6 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) | None -> output | Some nstate -> mk_simpl_tuple_pattern [ nstate; output ] in - (* Sanity check: there is the proper number of inputs and outputs, and they have the proper type *) - (if (* TODO: normalize the types *) !Config.type_check_pure_code then - match fun_id with - | FunId fun_id -> - let inst_sg = - get_instantiated_fun_sig fun_id (Some rg_id) generics ctx - in - log#ldebug - (lazy - ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs (" - ^ string_of_int (List.length inputs) - ^ "): " - ^ String.concat ", " (List.map (texpression_to_string ctx) inputs) - ^ "\n- inst_sg.inputs (" - ^ string_of_int (List.length inst_sg.inputs) - ^ "): " - ^ String.concat ", " - (List.map (pure_ty_to_string ctx) inst_sg.inputs))); - List.iter - (fun (x, ty) -> assert ((x : texpression).ty = ty)) - (List.combine inputs inst_sg.inputs); - log#ldebug - (lazy - ("\n- outputs: " - ^ string_of_int (List.length outputs) - ^ "\n- expected outputs: " - ^ string_of_int (List.length inst_sg.doutputs))); - List.iter - (fun (x, ty) -> assert ((x : typed_pattern).ty = ty)) - (List.combine outputs inst_sg.doutputs) - | _ -> (* TODO: trait methods *) ()); (* Retrieve the function id, and register the function call in the context * if necessary *) let ctx, func = -- cgit v1.2.3 From 83c5be42e1750d329ad31bc9151d7b0446af5a0f Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 15 Dec 2023 12:14:01 +0100 Subject: Make progress on generalizing the signature information --- compiler/Extract.ml | 8 +- compiler/Pure.ml | 130 ++++++++++++------ compiler/PureMicroPasses.ml | 10 +- compiler/PureUtils.ml | 10 +- compiler/SymbolicToPure.ml | 313 ++++++++++++++++++++++---------------------- 5 files changed, 253 insertions(+), 218 deletions(-) (limited to 'compiler') diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 93fcf416..1ea26d79 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -1472,7 +1472,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (* TODO: *) assert (not !Config.return_back_funs); let num_fwd_inputs = - def.signature.info.fwd_info.num_inputs_with_fuel_with_state + def.signature.fwd_info.fwd_info.num_inputs_with_fuel_with_state in Collections.List.prefix num_fwd_inputs all_inputs in @@ -1520,7 +1520,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (* TODO: *) assert (not !Config.return_back_funs); let num_fwd_inputs = - def.signature.info.fwd_info.num_inputs_with_fuel_with_state + def.signature.fwd_info.fwd_info.num_inputs_with_fuel_with_state in let vars = Collections.List.prefix num_fwd_inputs all_vars in @@ -1798,7 +1798,6 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) assert body.is_global_decl_body; assert (Option.is_none body.back_id); assert (body.signature.inputs = []); - assert (List.length body.signature.doutputs = 1); assert (body.signature.generics = empty_generic_params); (* Add a break then the name of the corresponding LLBC declaration *) @@ -1817,7 +1816,8 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) let decl_ty, body_ty = let ty = body.signature.output in - if body.signature.info.effect_info.can_fail then (unwrap_result_ty ty, ty) + if body.signature.fwd_info.effect_info.can_fail then + (unwrap_result_ty ty, ty) else (ty, mk_result_ty ty) in match body.body with diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 34f3ef72..fb0509f4 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -907,12 +907,88 @@ type back_inputs_info = (inputs_info option, inputs_info) back_info type fun_sig_info = { fwd_info : inputs_info; (** Information about the inputs of the forward function *) - back_info : back_inputs_info; - (** Information about the inputs of the backward functions. *) effect_info : fun_effect_info; } [@@deriving show] +type back_sg_info = { + inputs : ty list; (** The additional inputs of the backward function *) + input_names : string option list; + (** The optional names for the additional inputs *) + outputs : ty list; + (** The "decomposed" list of outputs. + + The list contains all the types of + all the given back values (there is at most one type per forward + input argument). + + Ex.: + {[ + fn choose<'a, T>(b : bool, x : &'a mut T, y : &'a mut T) -> &'a mut T; + ]} + Decomposed outputs: + - forward function: [[T]] + - backward function: [[T; T]] (for "x" and "y") + + Non-decomposed ouputs (if the function can fail, but is not stateful): + - [result T] + - [[result (T * T)]] + *) + output_names : string option list; + (** The optional names for the backward outputs. + We derive those from the names of the inputs of the original LLBC + function. *) + effect_info : fun_effect_info; +} +[@@deriving show] + +(** A *decomposed* function signature. *) +type decomposed_fun_sig = { + generics : generic_params; + (** TODO: we should analyse the signature to make the type parameters implicit whenever possible *) + llbc_generics : Types.generic_params; + (** We use the LLBC generics to generate "pretty" names, for instance + for the variables we introduce for the trait clauses: we derive + those names from the types, and when doing so it is more meaningful + to derive them from the original LLBC types from before the + simplification of types like boxes and references. *) + preds : predicates; + fwd_inputs : ty list; + (** The types of the inputs of the forward function. + + Note that those input types take include the [fuel] parameter, + if the function uses fuel for termination, and the [state] parameter, + if the function is stateful. + + For instance, if we have the following Rust function: + {[ + fn f(x : int); + ]} + + If we translate it to a stateful function which uses fuel we get: + {[ + val f : nat -> int -> state -> result (state * unit); + ]} + + In particular, the list of input types is: [[nat; int; state]]. + *) + fwd_output : ty; + (** The "pure" output type of the forward function. + + Note that this type doesn't contain the "effect" of the function (i.e., + we haven't added the [state] if it is a stateful function and haven't + wrapped the type in a [result]). Also, this output type is only about + the forward function (it doesn't contain the type of the closures we + return for the backward functions, in case we merge the forward and + backward functions). + *) + back_sg : back_sg_info RegionGroupId.Map.t; + (** Information about the backward functions *) + fwd_info : fun_sig_info; + (** Additional information about the forward function *) +} +[@@deriving show] + (** A function signature. We have the following cases: @@ -927,15 +1003,15 @@ type fun_sig_info = { [in_ty0 -> ... -> in_tyn -> state -> back_in0 -> ... back_inm -> state -> result (state & (back_out0 & ... & back_outp))] (* state-error *) - Note that a stateful backward function may take two states as inputs: the - state received by the associated forward function, and the state at which - the backward is called. This leads to code of the following shape: + Note that a stateful backward function may take two states as inputs: the + state received by the associated forward function, and the state at which + the backward is called. This leads to code of the following shape: - {[ - (st1, y) <-- f_fwd x st0; // st0 is the state upon calling f_fwd - ... // the state may be updated - (st3, x') <-- f_back x st0 y' st2; // st2 is the state upon calling f_back - ]} + {[ + (st1, y) <-- f_fwd x st0; // st0 is the state upon calling f_fwd + ... // the state may be updated + (st3, x') <-- f_back x st0 y' st2; // st2 is the state upon calling f_back + ]} The function's type should be given by [mk_arrows sig.inputs sig.output]. We provide additional meta-information with {!fun_sig.info}: @@ -983,40 +1059,14 @@ type fun_sig = { be a tuple with a [state] if the function is stateful, and will be wrapped in a [result] if the function can fail. *) - doutputs : ty list; - (** The "decomposed" list of outputs. - - In case of a forward function, the list has length = 1, for the - type of the returned value. - - In case of backward function, the list contains all the types of - all the given back values (there is at most one type per forward - input argument). - - Ex.: - {[ - fn choose<'a, T>(b : bool, x : &'a mut T, y : &'a mut T) -> &'a mut T; - ]} - Decomposed outputs: - - forward function: [[T]] - - backward function: [[T; T]] (for "x" and "y") - - Non-decomposed ouputs (if the function can fail, but is not stateful): - - [result T] - - [[result (T * T)]] - *) - info : fun_sig_info; (** Additional information *) + fwd_info : fun_sig_info; + (** Additional information about the forward function. *) + back_effect_info : fun_effect_info RegionGroupId.Map.t; } [@@deriving show] (** An instantiated function signature. See {!fun_sig} *) -type inst_fun_sig = { - inputs : ty list; - output : ty; - doutputs : ty list; - info : fun_sig_info; -} -[@@deriving show] +type inst_fun_sig = { inputs : ty list; output : ty } [@@deriving show] type fun_body = { inputs : var list; diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 7f122f15..d92b3de0 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -1358,11 +1358,7 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = } in - { - fwd_info; - back_info = fun_sig_info.back_info; - effect_info = loop_effect_info; - } + { fwd_info; effect_info = loop_effect_info } in assert (fun_sig_info_is_wf loop_sig_info); @@ -2187,7 +2183,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : in (* TODO: *) assert (not !Config.return_back_funs); - let { fwd_info; back_info; effect_info } = info in + let { fwd_info; effect_info } = info in let { has_fuel; @@ -2212,7 +2208,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : } in - let info = { fwd_info; back_info; effect_info } in + let info = { fwd_info; effect_info } in assert (fun_sig_info_is_wf info); let signature = { generics; llbc_generics; preds; inputs; output; doutputs; info } diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 3c038149..dfea255a 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -73,12 +73,6 @@ let inputs_info_is_wf (info : inputs_info) : bool = let fun_sig_info_is_wf (info : fun_sig_info) : bool = inputs_info_is_wf info.fwd_info - && - match info.back_info with - | SingleBack None -> true - | SingleBack (Some info) -> inputs_info_is_wf info - | AllBacks infos -> - List.for_all inputs_info_is_wf (RegionGroupId.Map.values infos) let dest_arrow_ty (ty : ty) : ty * ty = match ty with @@ -210,9 +204,7 @@ let fun_sig_substitute (subst : subst) (sg : fun_sig) : inst_fun_sig = let subst = ty_substitute subst in let inputs = List.map subst sg.inputs in let output = subst sg.output in - let doutputs = List.map subst sg.doutputs in - let info = sg.info in - { inputs; output; doutputs; info } + { inputs; output } (** We use this to check whether we need to add parentheses around expressions. We only look for outer monadic let-bindings. diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index eba44e3e..456ec0f6 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -128,9 +128,9 @@ type bs_ctx = { trait_impls_ctx : trait_impls_context; fun_decl : A.fun_decl; bid : T.RegionGroupId.id option; (** TODO: rename *) - sg : fun_sig; - (** The function signature - useful in particular to translate [Panic] *) - fwd_sg : fun_sig; (** The signature of the forward function *) + sg : decomposed_fun_sig; + (** Information about the function signature - useful in particular to + translate [Panic] *) sv_to_var : var V.SymbolicValueId.Map.t; (** Whenever we encounter a new symbolic value (introduced because of a symbolic expansion or upon ending an abstraction, for instance) @@ -828,7 +828,7 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) is_rec = false; } -(** Translate a function signature. +(** Translate a function signature to a decomposed function signature. Note that the function also takes a list of names for the inputs, and computes, for every output for the backward functions, a corresponding @@ -839,26 +839,15 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) We use [bid] ("backward function id") only if we split the forward and the backward functions. *) -let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) - (sg : A.fun_sig) (input_names : string option list) - (bid : T.RegionGroupId.id option) : fun_sig_named_outputs = - assert (Option.is_none bid || not !Config.return_back_funs); +let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) + (fun_id : A.fun_id) (sg : A.fun_sig) (input_names : string option list) : + decomposed_fun_sig = let fun_infos = decls_ctx.fun_ctx.fun_infos in let type_infos = decls_ctx.type_ctx.type_infos in (* Retrieve the list of parent backward functions *) let regions_hierarchy = FunIdMap.find fun_id decls_ctx.fun_ctx.regions_hierarchies in - let gid, parents = - match bid with - | None -> (None, T.RegionGroupId.Set.empty) - | Some bid -> - let parents = list_ancestor_region_groups regions_hierarchy bid in - (Some bid, parents) - in - (* Is the function stateful, and can it fail? *) - let lid = None in - let effect_info = get_fun_effect_info fun_infos (FunId fun_id) lid bid in (* We need an evaluation context to normalize the types (to normalize the associated types, etc. - for instance it may happen that the types refer to the types associated to a trait ref, but where the trait ref @@ -886,17 +875,52 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) { sg with A.inputs; output } in - (* List the inputs for: - * - the fuel - * - the forward function - * - the parent backward functions, in proper order - * - the current backward function (if it is a backward function) - *) - let fuel = mk_fuel_input_ty_as_list effect_info in - let fwd_inputs = List.map (translate_fwd_ty type_infos) sg.inputs in - (* For the backward functions: for now we don't supported nested borrows, - * so just check that there aren't parent regions *) - assert (T.RegionGroupId.Set.is_empty parents); + (* Is the forward function stateful, and can it fail? *) + let fwd_effect_info = + get_fun_effect_info fun_infos (FunId fun_id) None None + in + (* Compute the forward inputs *) + let fwd_fuel = mk_fuel_input_ty_as_list fwd_effect_info in + let fwd_inputs_no_fuel_no_state = + List.map (translate_fwd_ty type_infos) sg.inputs + in + (* State input for the forward function *) + let fwd_state_ty = + (* For the forward state, we check if the *whole group* is stateful. + See {!effect_info}. *) + if fwd_effect_info.stateful_group then [ mk_state_ty ] else [] + in + let fwd_inputs = + List.concat [ fwd_fuel; fwd_inputs_no_fuel_no_state; fwd_state_ty ] + in + (* Compute the backward output, without the effect information *) + let fwd_output = translate_fwd_ty type_infos sg.output in + (* The additinoal information *) + let fwd_info = + (* *) + let has_fuel = fwd_fuel <> [] in + let num_inputs_no_fuel_no_state = List.length fwd_inputs in + let num_inputs_with_fuel_no_state = + (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *) + List.length fwd_fuel + num_inputs_no_fuel_no_state + in + let fwd_info : inputs_info = + { + has_fuel; + num_inputs_no_fuel_no_state; + num_inputs_with_fuel_no_state; + num_inputs_with_fuel_with_state = + (* We use the fact that [fwd_state_ty] has length 1 if there is a state, + and 0 otherwise *) + num_inputs_with_fuel_no_state + List.length fwd_state_ty; + } + in + let info = { fwd_info; effect_info = fwd_effect_info } in + assert (fun_sig_info_is_wf info); + info + in + + (* Compute the type information for the backward function *) (* Small helper to translate types for backward functions *) let translate_back_ty_for_gid (gid : T.RegionGroupId.id) (ty : T.ty) : ty option = @@ -923,7 +947,11 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) let inside_mut = false in translate_back_ty type_infos keep_region inside_mut ty in - let translate_back_inputs_for_gid gid : ty list = + let translate_back_inputs_for_gid (gid : T.RegionGroupId.id) : ty list = + (* For now we don't supported nested borrows, so we check that there + aren't parent regions *) + let parents = list_ancestor_region_groups regions_hierarchy gid in + assert (T.RegionGroupId.Set.is_empty parents); (* For now, we don't allow nested borrows, so the additional inputs to the backward function can only come from borrows that were returned like in (for the backward function we introduce for 'a): @@ -935,45 +963,6 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) *) List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] in - (* Compute the additinal inputs for the current function, if it is a backward - * function *) - let back_inputs = - match gid with None -> [] | Some gid -> translate_back_inputs_for_gid gid - in - (* If the function is stateful, the inputs are: - - forward: [fwd_ty0, ..., fwd_tyn, state] - - backward: - - if {!Config.backward_no_state_update}: [fwd_ty0, ..., fwd_tyn, state, back_ty, state] - - otherwise: [fwd_ty0, ..., fwd_tyn, state, back_ty] - - The backward takes the same state as input as the forward function, - together with the state at the point where it gets called, if it is - stateful. - - See the comments for {!Config.backward_no_state_update} - *) - let fwd_state_ty = - (* For the forward state, we check if the *whole group* is stateful. - See {!effect_info}. *) - if effect_info.stateful_group then [ mk_state_ty ] else [] - in - let mk_back_state_ty_for_gid (gid : RegionGroupId.id option) : ty list = - (* For the backward state, we check if the function is a backward function, - and it is stateful *) - if effect_info.stateful && Option.is_some gid then [ mk_state_ty ] else [] - in - let back_state_ty = mk_back_state_ty_for_gid gid in - - (* Concatenate the inputs, in the following order: - * - forward inputs - * - forward state input - * - backward inputs - * - backward state input - *) - let inputs = - List.concat [ fuel; fwd_inputs; fwd_state_ty; back_inputs; back_state_ty ] - in - (* Outputs *) let compute_back_outputs_for_gid (gid : RegionGroupId.id) : string option list * ty list = (* The outputs are the borrows inside the regions of the abstractions @@ -998,103 +987,111 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) in List.split outputs in - let output_names, doutputs = - match gid with - | None -> - (* This is a forward function: there is one (unnamed) output. - - If we merge the fwd/back functions we might need to compute - the information about the back outputs. - *) - (* TODO: *) - assert (not !Config.return_back_funs); - ([ None ], [ translate_fwd_ty type_infos sg.output ]) - | Some gid -> - (* This is a backward function: there might be several outputs. *) - compute_back_outputs_for_gid gid - in - (* Create the return type *) - let output = - (* Group the outputs together *) - let output = mk_simpl_tuple_ty doutputs in - (* Add the output state *) - let output = - if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] - else output + let compute_back_info_for_group (rg : T.region_var_group) : + RegionGroupId.id * back_sg_info = + let gid = rg.id in + let back_effect_info = + get_fun_effect_info fun_infos (FunId fun_id) None (Some gid) in - (* Wrap in a result type *) - if effect_info.can_fail then mk_result_ty output else output + let inputs_no_state = translate_back_inputs_for_gid gid in + let inputs_no_state_names = + List.map (fun _ -> Some "ret") inputs_no_state + in + let state_ty, state_name = + if back_effect_info.stateful then ([ mk_state_ty ], [ None ]) else ([], []) + in + let inputs = inputs_no_state @ state_ty in + let input_names = inputs_no_state_names @ state_name in + let output_names, outputs = compute_back_outputs_for_gid gid in + let info = + { + inputs; + input_names; + outputs; + output_names; + effect_info = back_effect_info; + } + in + (gid, info) in + let back_sg = + RegionGroupId.Map.of_list + (List.map compute_back_info_for_group regions_hierarchy) + in + (* Generic parameters *) let generics = translate_generic_params sg.generics in + (* Return *) - (* TODO: *) - assert (not !Config.return_back_funs); - let has_fuel = fuel <> [] in - let num_fwd_inputs_no_fuel_no_state = List.length fwd_inputs in - let num_fwd_inputs_with_fuel_no_state = - (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *) - List.length fuel + num_fwd_inputs_no_fuel_no_state - in - let num_back_inputs_no_state = - if bid = None then None else Some (List.length back_inputs) - in - let fwd_info : inputs_info = - { - has_fuel; - num_inputs_no_fuel_no_state = num_fwd_inputs_no_fuel_no_state; - num_inputs_with_fuel_no_state = num_fwd_inputs_with_fuel_no_state; - num_inputs_with_fuel_with_state = - (* We use the fact that [fwd_state_ty] has length 1 if there is a state, - and 0 otherwise *) - num_fwd_inputs_with_fuel_no_state + List.length fwd_state_ty; - } + let preds = translate_predicates sg.preds in + { + generics; + llbc_generics = sg.generics; + preds; + fwd_inputs; + fwd_output; + back_sg; + fwd_info; + } + +let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) + (gid : RegionGroupId.id option) : fun_sig = + let generics = dsg.generics in + let llbc_generics = dsg.llbc_generics in + let preds = dsg.preds in + (* Compute the effects info *) + let fwd_info = dsg.fwd_info in + let back_effect_info = + RegionGroupId.Map.of_list + (List.map + (fun ((gid, info) : RegionGroupId.id * back_sg_info) -> + (gid, info.effect_info)) + (RegionGroupId.Map.bindings dsg.back_sg)) in - let compute_back_info (back_state_ty : ty list) - (num_back_inputs_no_state : int) : inputs_info = - let n = num_back_inputs_no_state in - (* Note that backward functions never use fuel *) - { - has_fuel = false; - num_inputs_no_fuel_no_state = n; - num_inputs_with_fuel_no_state = n; - (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) - num_inputs_with_fuel_with_state = n + List.length back_state_ty; - } + (* Two cases depending on whether we split the forward/backward functions + or not *) + let mk_output_ty (effect_info : fun_effect_info) output = + let output = + if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] + else output + in + if effect_info.can_fail then mk_result_ty output else output in - let back_info : back_inputs_info = - if !Config.return_back_funs then - (* Create the map *) - AllBacks - (RegionGroupId.Map.of_list - (List.map - (fun (rg : T.region_var_group) -> - ( rg.id, - let back_inputs = translate_back_inputs_for_gid rg.id in - let num_back_inputs = List.length back_inputs in - (* TODO: slightly overkill *) - let back_state_ty = mk_back_state_ty_for_gid (Some rg.id) in - compute_back_info back_state_ty num_back_inputs )) - regions_hierarchy)) + let inputs, output = + if !Config.return_back_funs then ( + assert (gid = None); + (* Compute the arrow types for all the backward functions *) + let back_tys = + List.map + (fun (back_sg : back_sg_info) -> + let effect_info = back_sg.effect_info in + let inputs = dsg.fwd_inputs @ back_sg.inputs in + let output = mk_simpl_tuple_ty back_sg.outputs in + let output = mk_output_ty effect_info output in + mk_arrows inputs output) + (RegionGroupId.Map.values dsg.back_sg) + in + (* Group the forward output and the types of the backward functions *) + let effect_info = dsg.fwd_info.effect_info in + let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in + let output = mk_output_ty effect_info output in + let inputs = dsg.fwd_inputs in + (inputs, output)) else - SingleBack - (Option.map (compute_back_info back_state_ty) num_back_inputs_no_state) - in - let info = { fwd_info; back_info; effect_info } in - assert (fun_sig_info_is_wf info); - let preds = translate_predicates sg.preds in - let sg = - { - generics; - llbc_generics = sg.generics; - preds; - inputs; - output; - doutputs; - info; - } - in - { sg; output_names } + match gid with + | None -> + let effect_info = dsg.fwd_info.effect_info in + let output = mk_output_ty effect_info dsg.fwd_output in + (dsg.fwd_inputs, output) + | Some gid -> + let back_sg = RegionGroupId.Map.find gid dsg.back_sg in + let effect_info = back_sg.effect_info in + let inputs = dsg.fwd_inputs @ back_sg.inputs in + let output = mk_simpl_tuple_ty back_sg.outputs in + let output = mk_output_ty effect_info output in + (inputs, output) + in + { generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info } let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern = (* Generate the fresh variable *) -- cgit v1.2.3 From 62cb926e76ef0c9fb048b0e340bdae5b9dd76a84 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 15 Dec 2023 14:06:16 +0100 Subject: Make progress on updating SymbolicToPure --- compiler/SymbolicToPure.ml | 169 ++++++++++++++++++++++++++++++--------------- 1 file changed, 112 insertions(+), 57 deletions(-) (limited to 'compiler') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 456ec0f6..d62cc829 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -127,7 +127,15 @@ type bs_ctx = { trait_decls_ctx : trait_decls_context; trait_impls_ctx : trait_impls_context; fun_decl : A.fun_decl; - bid : T.RegionGroupId.id option; (** TODO: rename *) + bid : RegionGroupId.id option; + (** TODO: rename + + The id of the group region we are currently translating. + If we split the forward/backward functions, we set this id at the + very beginning of the translation. + If we don't split, we set it to `None`, then update it when we enter + an expression which is specific to a backward function. + *) sg : decomposed_fun_sig; (** Information about the function signature - useful in particular to translate [Panic] *) @@ -139,7 +147,7 @@ type bs_ctx = { var_counter : VarId.generator; state_var : VarId.id; (** The current state variable, in case the function is stateful *) - back_state_var : VarId.id; + back_state_vars : VarId.id RegionGroupId.Map.t; (** The additional input state variable received by a stateful backward function. When generating stateful functions, we generate code of the following form: @@ -163,16 +171,16 @@ type bs_ctx = { (** The input parameters for the forward function corresponding to the translated Rust inputs (no fuel, no state). *) - backward_inputs : var list T.RegionGroupId.Map.t; + backward_inputs : var list RegionGroupId.Map.t; (** The additional input parameters for the backward functions coming from the borrows consumed upon ending the lifetime (as a consequence those don't include the backward state, if there is one). *) - backward_outputs : var list T.RegionGroupId.Map.t; + backward_outputs : var list RegionGroupId.Map.t; (** The variables that the backward functions will output, corresponding to the borrows they give back (don't include the backward state) *) - loop_backward_outputs : var list T.RegionGroupId.Map.t option; + loop_backward_outputs : var list RegionGroupId.Map.t option; (** Same as {!backward_outputs}, but for loops (if we entered a loop). [None] if we are not inside a loop, [Some] otherwise (and whatever @@ -300,6 +308,13 @@ let typed_pattern_to_string (ctx : bs_ctx) (p : Pure.typed_pattern) : string = let env = bs_ctx_to_pure_fmt_env ctx in PrintPure.typed_pattern_to_string env p +let ctx_get_effect_info (ctx : bs_ctx) : fun_effect_info = + match ctx.bid with + | None -> ctx.sg.fwd_info.effect_info + | Some bid -> + let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in + back_sg.effect_info + (* TODO: move *) let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string = let env = bs_ctx_to_fmt_env ctx in @@ -1034,6 +1049,24 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) fwd_info; } +let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty + = + let output = + if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ] else ty + in + if effect_info.can_fail then mk_result_ty output else output + +(** Compute the arrow types for all the backward functions *) +let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list = + List.map + (fun (back_sg : back_sg_info) -> + let effect_info = back_sg.effect_info in + let inputs = dsg.fwd_inputs @ back_sg.inputs in + let output = mk_simpl_tuple_ty back_sg.outputs in + let output = mk_output_ty_from_effect_info effect_info output in + mk_arrows inputs output) + (RegionGroupId.Map.values dsg.back_sg) + let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) (gid : RegionGroupId.id option) : fun_sig = let generics = dsg.generics in @@ -1050,27 +1083,13 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) in (* Two cases depending on whether we split the forward/backward functions or not *) - let mk_output_ty (effect_info : fun_effect_info) output = - let output = - if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] - else output - in - if effect_info.can_fail then mk_result_ty output else output - in + let mk_output_ty = mk_output_ty_from_effect_info in + let inputs, output = if !Config.return_back_funs then ( assert (gid = None); (* Compute the arrow types for all the backward functions *) - let back_tys = - List.map - (fun (back_sg : back_sg_info) -> - let effect_info = back_sg.effect_info in - let inputs = dsg.fwd_inputs @ back_sg.inputs in - let output = mk_simpl_tuple_ty back_sg.outputs in - let output = mk_output_ty effect_info output in - mk_arrows inputs output) - (RegionGroupId.Map.values dsg.back_sg) - in + let back_tys = compute_back_tys dsg in (* Group the forward output and the types of the backward functions *) let effect_info = dsg.fwd_info.effect_info in let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in @@ -1584,30 +1603,43 @@ and translate_panic (ctx : bs_ctx) : texpression = * but it won't be true anymore once we translate individual blocks *) (* If we use a state monad, we need to add a lambda for the state variable *) (* Note that only forward functions return a state *) - let output_ty = - if ctx.inside_loop && Option.is_some ctx.bid then - (* We are synthesizing the backward function of a loop body *) - let bid = Option.get ctx.bid in - let back_vars = - T.RegionGroupId.Map.find bid (Option.get ctx.loop_backward_outputs) - in - let tys = List.map (fun (v : var) -> v.ty) back_vars in - mk_simpl_tuple_ty tys - else - (* Regular function, or forward function (the forward translation for - a loop has the same return type as the parent function) - *) - mk_simpl_tuple_ty ctx.sg.doutputs - in + let effect_info = ctx_get_effect_info ctx in (* TODO: we should use a [Fail] function *) - if ctx.sg.info.effect_info.stateful then - (* Create the [Fail] value *) - let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in - let ret_v = - mk_result_fail_texpression_with_error_id error_failure_id ret_ty + let mk_output output_ty = + if effect_info.stateful then + (* Create the [Fail] value *) + let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in + let ret_v = + mk_result_fail_texpression_with_error_id error_failure_id ret_ty + in + ret_v + else mk_result_fail_texpression_with_error_id error_failure_id output_ty + in + if ctx.inside_loop && Option.is_some ctx.bid then + (* We are synthesizing the backward function of a loop body *) + let bid = Option.get ctx.bid in + let back_vars = + T.RegionGroupId.Map.find bid (Option.get ctx.loop_backward_outputs) in - ret_v - else mk_result_fail_texpression_with_error_id error_failure_id output_ty + let tys = List.map (fun (v : var) -> v.ty) back_vars in + let output = mk_simpl_tuple_ty tys in + mk_output output + else + (* Regular function, or forward function (the forward translation for + a loop has the same return type as the parent function) + *) + match ctx.bid with + | None -> + if !Config.return_back_funs then + let back_tys = compute_back_tys ctx.sg in + let output = mk_simpl_tuple_ty (ctx.sg.fwd_output :: back_tys) in + mk_output output + else mk_output ctx.sg.fwd_output + | Some bid -> + let output = + mk_simpl_tuple_ty (RegionGroupId.Map.find bid ctx.sg.back_sg).outputs + in + mk_output output (** [opt_v]: the value to return, in case we translate a forward body *) and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) @@ -1641,7 +1673,7 @@ and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) * - error-monad: Return x * - state-error: Return (state, x) * *) - let effect_info = ctx.sg.info.effect_info in + let effect_info = ctx_get_effect_info ctx in let output = if effect_info.stateful then let state_rvalue = mk_state_texpression ctx.state_var in @@ -1695,7 +1727,7 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) * effect - in particular, one manipulates a state iff the other does * the same. * *) - let effect_info = ctx.sg.info.effect_info in + let effect_info = ctx_get_effect_info ctx in let output = if effect_info.stateful then let state_rvalue = mk_state_texpression ctx.state_var in @@ -2550,24 +2582,50 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option) and translate_forward_end (ectx : C.eval_ctx) (loop_input_values : V.typed_value S.symbolic_value_id_map option) - (e : S.expression) (back_e : S.expression S.region_group_id_map) + (fwd_e : S.expression) (back_e : S.expression S.region_group_id_map) (ctx : bs_ctx) : texpression = - (* Update the current state with the additional state received by the backward - function, if needs be, and lookup the proper expression *) - let translate_end ctx = + (* TODO: *) + assert (not !Config.return_back_funs); + + let translate_one_end ctx (bid : RegionGroupId.id option) = (* Update the current state with the additional state received by the backward function, if needs be, and lookup the proper expression *) let ctx, e = match ctx.bid with - | None -> (ctx, e) + | None -> + (* We are translating the forward function - nothing to do *) + (ctx, fwd_e) | Some bid -> - let ctx = { ctx with state_var = ctx.back_state_var } in + (* There are two cases here: + - if we split the fwd/backward functions, we simply need to update + the state + - if we don't split, we also need to wrap the expression in a + lambda, which introduces the additional inputs of the backward + function + *) + let back_state_var = RegionGroupId.Map.find bid ctx.back_state_vars in + let ctx = { ctx with state_var = back_state_var } in let e = T.RegionGroupId.Map.find bid back_e in (ctx, e) in translate_expression e ctx in + (* There are two cases, depending on whether we are splitting the forward/backward + functions or not. + + - if we split, then we simply need to translate the proper "end" expression, + that is the end of the forward function, or of the backward function we + are currently translating. + - if we don't split, then we need to translate the end of the forward + function (this is the value we will return) and generate the bodies + of the backward functions (which we will also return). + + Update the current state with the additional state received by the backward + function, if needs be, and lookup the proper expression. + *) + let translate_end ctx = failwith "TODO" in + (* If we are (re-)entering a loop, we need to introduce a call to the forward translation of the loop. *) match loop_input_values with @@ -2617,10 +2675,7 @@ and translate_forward_end (ectx : C.eval_ctx) in (* Introduce a fresh output value for the forward function *) - let ctx, output_var = - let output_ty = mk_simpl_tuple_ty ctx.fwd_sg.doutputs in - fresh_var None output_ty ctx - in + let ctx, output_var = fresh_var None ctx.sg.fwd_output ctx in let args, ctx, out_pats = let output_pat = mk_typed_pattern_from_var output_var None in @@ -2832,7 +2887,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = (* Add the input state *) let input_state = - if ctx.sg.info.effect_info.stateful then Some ctx.state_var else None + if (ctx_get_effect_info ctx).stateful then Some ctx.state_var else None in (* Translate the loop body *) -- cgit v1.2.3 From ea583d9f0f5e4a1a687b70f0e04e875969462157 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 15 Dec 2023 17:20:30 +0100 Subject: Make good progress on updating SymbolicToPure --- compiler/PrintPure.ml | 8 ++ compiler/Pure.ml | 7 +- compiler/PureTypeCheck.ml | 6 ++ compiler/PureUtils.ml | 23 +++++ compiler/SymbolicToPure.ml | 224 +++++++++++++++++++++++++++++++++++++-------- 5 files changed, 226 insertions(+), 42 deletions(-) (limited to 'compiler') diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 2fe5843e..3a5ce513 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -592,6 +592,14 @@ let rec texpression_to_string (env : fmt_env) (inside : bool) (indent : string) in "[ " ^ String.concat ", " fields ^ " ]" | _ -> raise (Failure "Unexpected")) + | Lambda _ -> + let pats, e = destruct_lambdas e in + let vars = + String.concat " " (List.map (typed_pattern_to_string env) pats) + in + let e = texpression_to_string env false indent indent_incr e in + let s = "λ " ^ vars ^ " => " ^ e in + if inside then "(" ^ s ^ ")" else s | Meta (meta, e) -> ( let meta_s = emeta_to_string env meta in let e = texpression_to_string env inside indent indent_incr e in diff --git a/compiler/Pure.ml b/compiler/Pure.ml index fb0509f4..eb6b00c8 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -728,6 +728,7 @@ type expression = | Switch of texpression * switch_body | Loop of loop (** See the comments for {!loop} *) | StructUpdate of struct_update (** See the comments for {!struct_update} *) + | Lambda of typed_pattern * texpression (** [λ x => e] *) | Meta of (emeta[@opaque]) * texpression (** Meta-information *) and switch_body = If of texpression * texpression | Match of match_branch list @@ -912,9 +913,9 @@ type fun_sig_info = { [@@deriving show] type back_sg_info = { - inputs : ty list; (** The additional inputs of the backward function *) - input_names : string option list; - (** The optional names for the additional inputs *) + inputs : (string option * ty) list; + (** The additional inputs of the backward function *) + inputs_no_state : (string option * ty) list; outputs : ty list; (** The "decomposed" list of outputs. diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index a62a2361..3c1800a8 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -229,6 +229,12 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = check_texpression ctx fe) supd.updates | _ -> raise (Failure "Unexpected")) + | Lambda (pat, e_next) -> + assert (e.ty = e_next.ty); + (* Check the pattern and register the introduced variables at the same time *) + let ctx = check_typed_pattern ctx pat in + (* Check the next expression *) + check_texpression ctx e_next | Meta (_, e_next) -> assert (e_next.ty = e.ty); check_texpression ctx e_next diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index dfea255a..80b25641 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -221,6 +221,9 @@ let rec let_group_requires_parentheses (e : texpression) : bool = if monadic then true else let_group_requires_parentheses next_e | Switch (_, _) -> false | Meta (_, next_e) -> let_group_requires_parentheses next_e + | Lambda (_, _) -> + (* Being conservative here *) + true | Loop _ -> (* Should have been eliminated *) raise (Failure "Unreachable") @@ -713,3 +716,23 @@ let type_decl_from_type_id_is_tuple_struct (ctx : TypesAnalysis.type_infos) let info = TypeDeclId.Map.find id ctx in info.is_tuple_struct | TAssumed _ -> false + +let mk_lambda_from_var (var : var) (mp : mplace option) (e : texpression) : + texpression = + let ty = TArrow (var.ty, e.ty) in + let pat = PatVar (var, mp) in + let pat = { value = pat; ty = var.ty } in + let e = Lambda (pat, e) in + { e; ty } + +let mk_lambdas_from_vars (vars : var list) (mps : mplace option list) + (e : texpression) : texpression = + let vars = List.combine vars mps in + List.fold_left (fun e (v, mp) -> mk_lambda_from_var v mp e) e vars + +let rec destruct_lambdas (e : texpression) : typed_pattern list * texpression = + match e.e with + | Lambda (pat, e) -> + let pats, e = destruct_lambdas e in + (pat :: pats, e) + | _ -> ([], e) diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index d62cc829..8e06db7c 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -121,9 +121,9 @@ type loop_info = { (** Body synthesis context *) type bs_ctx = { - type_context : type_context; - fun_context : fun_context; - global_context : global_context; + type_context : type_context; (* TODO: rename *) + fun_context : fun_context; (* TODO: rename *) + global_context : global_context; (* TODO: rename *) trait_decls_ctx : trait_decls_context; trait_impls_ctx : trait_impls_context; fun_decl : A.fun_decl; @@ -148,7 +148,9 @@ type bs_ctx = { state_var : VarId.id; (** The current state variable, in case the function is stateful *) back_state_vars : VarId.id RegionGroupId.Map.t; - (** The additional input state variable received by a stateful backward function. + (** The additional input state variable received by a stateful backward function, + **in case we are splitting the forward/backward functions**. + When generating stateful functions, we generate code of the following form: @@ -161,7 +163,9 @@ type bs_ctx = { When translating a backward function, we need at some point to update [state_var] with [back_state_var], to account for the fact that the state may have been updated by the caller between the call to the - forward function and the call to the backward function. + forward function and the call to the backward function. We also need + to make sure we use the same variable in all the branches (because + this variable is quantified at the definition level). *) fuel0 : VarId.id; (** The original fuel taken as input by the function (if we use fuel) *) @@ -171,10 +175,20 @@ type bs_ctx = { (** The input parameters for the forward function corresponding to the translated Rust inputs (no fuel, no state). *) - backward_inputs : var list RegionGroupId.Map.t; + backward_inputs_no_state : var list RegionGroupId.Map.t; (** The additional input parameters for the backward functions coming from the borrows consumed upon ending the lifetime (as a consequence those don't include the backward state, if there is one). + + If we split the forward/backward functions: we initialize this map + when initializing the bs_ctx, because those variables are quantified + at the definition level. Otherwise, we initialize it upon diving + into the expressions which are specific to the backward functions. + *) + backward_inputs_with_state : var list RegionGroupId.Map.t; + (** All the additional input parameters for the backward functions. + + Same remarks as for {!backward_inputs_no_state}. *) backward_outputs : var list RegionGroupId.Map.t; (** The variables that the backward functions will output, corresponding @@ -308,13 +322,17 @@ let typed_pattern_to_string (ctx : bs_ctx) (p : Pure.typed_pattern) : string = let env = bs_ctx_to_pure_fmt_env ctx in PrintPure.typed_pattern_to_string env p -let ctx_get_effect_info (ctx : bs_ctx) : fun_effect_info = - match ctx.bid with +let ctx_get_effect_info_for_bid (ctx : bs_ctx) (bid : RegionGroupId.id option) : + fun_effect_info = + match bid with | None -> ctx.sg.fwd_info.effect_info | Some bid -> let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in back_sg.effect_info +let ctx_get_effect_info (ctx : bs_ctx) : fun_effect_info = + ctx_get_effect_info_for_bid ctx ctx.bid + (* TODO: move *) let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string = let env = bs_ctx_to_fmt_env ctx in @@ -1009,19 +1027,18 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) get_fun_effect_info fun_infos (FunId fun_id) None (Some gid) in let inputs_no_state = translate_back_inputs_for_gid gid in - let inputs_no_state_names = - List.map (fun _ -> Some "ret") inputs_no_state + let inputs_no_state = + List.map (fun ty -> (Some "ret", ty)) inputs_no_state in - let state_ty, state_name = - if back_effect_info.stateful then ([ mk_state_ty ], [ None ]) else ([], []) + let state = + if back_effect_info.stateful then [ (None, mk_state_ty) ] else [] in - let inputs = inputs_no_state @ state_ty in - let input_names = inputs_no_state_names @ state_name in + let inputs = inputs_no_state @ state in let output_names, outputs = compute_back_outputs_for_gid gid in let info = { inputs; - input_names; + inputs_no_state; outputs; output_names; effect_info = back_effect_info; @@ -1061,7 +1078,7 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list = List.map (fun (back_sg : back_sg_info) -> let effect_info = back_sg.effect_info in - let inputs = dsg.fwd_inputs @ back_sg.inputs in + let inputs = dsg.fwd_inputs @ List.map snd back_sg.inputs in let output = mk_simpl_tuple_ty back_sg.outputs in let output = mk_output_ty_from_effect_info effect_info output in mk_arrows inputs output) @@ -1105,14 +1122,14 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) | Some gid -> let back_sg = RegionGroupId.Map.find gid dsg.back_sg in let effect_info = back_sg.effect_info in - let inputs = dsg.fwd_inputs @ back_sg.inputs in + let inputs = dsg.fwd_inputs @ List.map snd back_sg.inputs in let output = mk_simpl_tuple_ty back_sg.outputs in let output = mk_output_ty effect_info output in (inputs, output) in { generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info } -let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern = +let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * var * typed_pattern = (* Generate the fresh variable *) let id, var_counter = VarId.fresh ctx.var_counter in let state_var = @@ -1122,7 +1139,7 @@ let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern = (* Update the context *) let ctx = { ctx with var_counter; state_var = id } in (* Return *) - (ctx, state_pat) + (ctx, state_var, state_pat) (** WARNING: do not call this function directly. Call [fresh_named_var_for_symbolic_value] instead. *) @@ -1776,7 +1793,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let fuel = mk_fuel_input_as_list ctx effect_info in if effect_info.stateful then let state_var = mk_state_texpression ctx.state_var in - let ctx, nstate_var = bs_ctx_fresh_state_var ctx in + let ctx, _, nstate_var = bs_ctx_fresh_state_var ctx in (List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var) else (List.concat [ fuel; args ], ctx, None) in @@ -2010,7 +2027,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) let back_state, ctx, nstate = if effect_info.stateful then let back_state = mk_state_texpression ctx.state_var in - let ctx, nstate = bs_ctx_fresh_state_var ctx in + let ctx, _, nstate = bs_ctx_fresh_state_var ctx in ([ back_state ], ctx, Some nstate) else ([], ctx, None) in @@ -2115,15 +2132,15 @@ and translate_end_abstraction_synth_ret (ectx : C.eval_ctx) (abs : V.abs) let-binding: {[ let id_back x nx = - let s = nx in // the name [s] is not important (only collision matters) - ... + let s = nx in // the name [s] is not important (only collision matters) + ... ]} This let-binding later gets inlined, during a micro-pass. *) (* First, retrieve the list of variables used for the inputs for the * backward function *) - let inputs = T.RegionGroupId.Map.find rg_id ctx.backward_inputs in + let inputs = T.RegionGroupId.Map.find rg_id ctx.backward_inputs_no_state in (* Retrieve the values consumed upon ending the loans inside this * abstraction: as there are no nested borrows, there should be none. *) let consumed = abs_to_consumed ctx ectx abs in @@ -2185,7 +2202,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) values consumed upon ending the abstraction (i.e., we don't use [abs_to_consumed]) *) let back_inputs_vars = - T.RegionGroupId.Map.find rg_id ctx.backward_inputs + T.RegionGroupId.Map.find rg_id ctx.backward_inputs_no_state in let back_inputs = List.map mk_texpression_from_var back_inputs_vars in (* If the function is stateful: @@ -2195,7 +2212,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) let back_state, ctx, nstate = if effect_info.stateful then let back_state = mk_state_texpression ctx.state_var in - let ctx, nstate = bs_ctx_fresh_state_var ctx in + let ctx, _, nstate = bs_ctx_fresh_state_var ctx in ([ back_state ], ctx, Some nstate) else ([], ctx, None) in @@ -2590,25 +2607,69 @@ and translate_forward_end (ectx : C.eval_ctx) let translate_one_end ctx (bid : RegionGroupId.id option) = (* Update the current state with the additional state received by the backward function, if needs be, and lookup the proper expression *) - let ctx, e = + let ctx, e, finish = match ctx.bid with | None -> (* We are translating the forward function - nothing to do *) - (ctx, fwd_e) + (ctx, fwd_e, fun e -> e) | Some bid -> (* There are two cases here: - if we split the fwd/backward functions, we simply need to update - the state + the state. - if we don't split, we also need to wrap the expression in a lambda, which introduces the additional inputs of the backward function *) - let back_state_var = RegionGroupId.Map.find bid ctx.back_state_vars in - let ctx = { ctx with state_var = back_state_var } in + let ctx = + (* Introduce variables for the inputs and the state variable + and update the context. *) + if !Config.return_back_funs then + (* If the forward/backward functions are not split, we need + to introduce fresh variables for the additional inputs, + because they are locally introduced in a lambda *) + let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in + let ctx = { ctx with bid = Some bid } in + let ctx, backward_inputs_no_state = + fresh_vars back_sg.inputs_no_state ctx + in + let ctx, backward_inputs_with_state = + if (ctx_get_effect_info ctx).stateful then + let ctx, var, _ = bs_ctx_fresh_state_var ctx in + (ctx, backward_inputs_no_state @ [ var ]) + else (ctx, backward_inputs_no_state) + in + { + ctx with + backward_inputs_no_state = + RegionGroupId.Map.add bid backward_inputs_no_state + ctx.backward_inputs_no_state; + backward_inputs_with_state = + RegionGroupId.Map.add bid backward_inputs_with_state + ctx.backward_inputs_with_state; + } + else + (* Update the state variable *) + let back_state_var = + RegionGroupId.Map.find bid ctx.back_state_vars + in + { ctx with state_var = back_state_var } + in + let e = T.RegionGroupId.Map.find bid back_e in - (ctx, e) + let finish e = + (* Wrap in lambdas if necessary *) + if !Config.return_back_funs then + let inputs = + RegionGroupId.Map.find bid ctx.backward_inputs_with_state + in + let places = List.map (fun _ -> None) inputs in + mk_lambdas_from_vars inputs places e + else e + in + (ctx, e, finish) in - translate_expression e ctx + let e = translate_expression e ctx in + finish e in (* There are two cases, depending on whether we are splitting the forward/backward @@ -2624,7 +2685,87 @@ and translate_forward_end (ectx : C.eval_ctx) Update the current state with the additional state received by the backward function, if needs be, and lookup the proper expression. *) - let translate_end ctx = failwith "TODO" in + let translate_end ctx = + if !Config.return_back_funs then + (* Compute the output of the forward function *) + let fwd_effect_info = ctx.sg.fwd_info.effect_info in + let output_ty = + let ty = ctx.sg.fwd_output in + if fwd_effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ] + else ty + in + let ctx, fwd_var = fresh_var None output_ty ctx in + let ctx, state_var, state_pat = + if fwd_effect_info.stateful then + let ctx, var, pat = bs_ctx_fresh_state_var ctx in + (ctx, [ var ], [ pat ]) + else (ctx, [], []) + in + let fwd_e = translate_one_end ctx None in + + (* Introduce the backward functions *) + let back_el = + List.map + (fun ((gid, _) : RegionGroupId.id * back_sg_info) -> + translate_one_end ctx (Some gid)) + (RegionGroupId.Map.bindings ctx.sg.back_sg) + in + (* Introduce variables for the backward functions. + We lookup the LLBC definition in an attempt to derive pretty names + for those functions. *) + let back_var_names = + let def_id = ctx.fun_decl.def_id in + let sg = ctx.fun_decl.signature in + let regions_hierarchy = + LlbcAstUtils.FunIdMap.find (FRegular def_id) + ctx.fun_context.regions_hierarchies + in + List.map + (fun (gid, _) -> + let rg = RegionGroupId.nth regions_hierarchy gid in + let region_names = + List.map + (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name) + rg.regions + in + let name = + match region_names with + | [] -> "back" + | [ Some r ] -> "back" ^ r + | _ -> + (* Concatenate all the region names *) + "back" + ^ String.concat "" (List.filter_map (fun x -> x) region_names) + in + Some name) + (RegionGroupId.Map.bindings ctx.sg.back_sg) + in + let back_vars = List.combine back_var_names (compute_back_tys ctx.sg) in + let _, back_vars = fresh_vars back_vars ctx in + + (* Create the return expressions *) + let vars = fwd_var :: back_vars in + let vars = List.map mk_texpression_from_var vars in + let ret = mk_simpl_tuple_texpression vars in + let state_var = List.map mk_texpression_from_var state_var in + let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in + let ret = mk_result_return_texpression ret in + + (* Bind the expressions for the backward function and the expression + for the computation of the forward output *) + let e = + List.fold_right + (fun (var, back_e) e -> + mk_let false (mk_typed_pattern_from_var var None) back_e e) + (List.combine back_vars back_el) + ret + in + (* Bind the expression for the forward output *) + let fwd_var = mk_typed_pattern_from_var fwd_var None in + let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in + mk_let fwd_effect_info.can_fail pat fwd_e e + else translate_one_end ctx ctx.bid + in (* If we are (re-)entering a loop, we need to introduce a call to the forward translation of the loop. *) @@ -2687,7 +2828,7 @@ and translate_forward_end (ectx : C.eval_ctx) let fuel = mk_fuel_input_as_list ctx effect_info in if effect_info.stateful then let state_var = mk_state_texpression ctx.state_var in - let ctx, nstate_pat = bs_ctx_fresh_state_var ctx in + let ctx, _, nstate_pat = bs_ctx_fresh_state_var ctx in ( List.concat [ fuel; args; [ state_var ] ], ctx, [ nstate_pat; output_pat ] ) @@ -3025,8 +3166,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = let def_id = def.def_id in let llbc_name = def.name in let name = name_to_string ctx llbc_name in - (* Retrieve the signature *) - let signature = ctx.sg in + (* Translate the signature *) + let signature = translate_fun_sig_from_decomposed ctx.sg ctx.bid in let regions_hierarchy = FunIdMap.find (FRegular def_id) ctx.fun_context.regions_hierarchies in @@ -3070,20 +3211,25 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = match bid with | None -> [] | Some back_id -> + assert (not !Config.return_back_funs); let parents_ids = list_ordered_ancestor_region_groups regions_hierarchy back_id in let backward_ids = List.append parents_ids [ back_id ] in List.concat (List.map - (fun id -> T.RegionGroupId.Map.find id ctx.backward_inputs) + (fun id -> + T.RegionGroupId.Map.find id ctx.backward_inputs_no_state) backward_ids) in (* Introduce the backward input state (the state at call site of the * *backward* function), if necessary *) let back_state = if effect_info.stateful && Option.is_some bid then - [ mk_state_var ctx.back_state_var ] + let state_var = + RegionGroupId.Map.find (Option.get bid) ctx.back_state_vars + in + [ mk_state_var state_var ] else [] in (* Group the inputs together *) -- cgit v1.2.3 From 5fa83883b4d573cfd252478f7937c8bde0ec01f6 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 15 Dec 2023 17:22:01 +0100 Subject: Minor fix --- compiler/SymbolicToPure.ml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'compiler') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 8e06db7c..08f9e950 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -2605,10 +2605,11 @@ and translate_forward_end (ectx : C.eval_ctx) assert (not !Config.return_back_funs); let translate_one_end ctx (bid : RegionGroupId.id option) = + let ctx = { ctx with bid } in (* Update the current state with the additional state received by the backward function, if needs be, and lookup the proper expression *) let ctx, e, finish = - match ctx.bid with + match bid with | None -> (* We are translating the forward function - nothing to do *) (ctx, fwd_e, fun e -> e) @@ -2628,7 +2629,6 @@ and translate_forward_end (ectx : C.eval_ctx) to introduce fresh variables for the additional inputs, because they are locally introduced in a lambda *) let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in - let ctx = { ctx with bid = Some bid } in let ctx, backward_inputs_no_state = fresh_vars back_sg.inputs_no_state ctx in -- cgit v1.2.3 From 955fdab55304979ba2d61432ea654241f20abaa4 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 15 Dec 2023 18:14:12 +0100 Subject: Make progress on propagating the changes --- compiler/Extract.ml | 8 ++-- compiler/PrintPure.ml | 16 ++----- compiler/Pure.ml | 3 +- compiler/PureMicroPasses.ml | 110 ++++++++++++++++++++++---------------------- compiler/PureTypeCheck.ml | 8 +--- compiler/PureUtils.ml | 24 +++------- 6 files changed, 72 insertions(+), 97 deletions(-) (limited to 'compiler') diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 1ea26d79..7e2efd8a 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -285,9 +285,9 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) | App _ -> let app, args = destruct_apps e in extract_App ctx fmt inside app args - | Abs _ -> - let xl, e = destruct_abs_list e in - extract_Abs ctx fmt inside xl e + | Lambda _ -> + let xl, e = destruct_lambdas e in + extract_Lambda ctx fmt inside xl e | Qualif _ -> (* We use the app case *) extract_App ctx fmt inside e [] @@ -574,7 +574,7 @@ and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter) (* No argument: shouldn't happen *) raise (Failure "Unreachable") -and extract_Abs (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) +and extract_Lambda (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (xl : typed_pattern list) (e : texpression) : unit = (* Open a box for the abs expression *) F.pp_open_hovbox fmt ctx.indent_incr; diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 3a5ce513..79506c04 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -543,9 +543,9 @@ let rec texpression_to_string (env : fmt_env) (inside : bool) (indent : string) let app, args = destruct_apps e in (* Convert to string *) app_to_string env inside indent indent_incr app args - | Abs _ -> - let xl, e = destruct_abs_list e in - let e = abs_to_string env indent indent_incr xl e in + | Lambda _ -> + let xl, e = destruct_lambdas e in + let e = lambda_to_string env indent indent_incr xl e in if inside then "(" ^ e ^ ")" else e | Qualif _ -> (* Qualifier without arguments *) @@ -592,14 +592,6 @@ let rec texpression_to_string (env : fmt_env) (inside : bool) (indent : string) in "[ " ^ String.concat ", " fields ^ " ]" | _ -> raise (Failure "Unexpected")) - | Lambda _ -> - let pats, e = destruct_lambdas e in - let vars = - String.concat " " (List.map (typed_pattern_to_string env) pats) - in - let e = texpression_to_string env false indent indent_incr e in - let s = "λ " ^ vars ^ " => " ^ e in - if inside then "(" ^ s ^ ")" else s | Meta (meta, e) -> ( let meta_s = emeta_to_string env meta in let e = texpression_to_string env inside indent indent_incr e in @@ -668,7 +660,7 @@ and app_to_string (env : fmt_env) (inside : bool) (indent : string) (* Add parentheses *) if all_args <> [] && inside then "(" ^ e ^ ")" else e -and abs_to_string (env : fmt_env) (indent : string) (indent_incr : string) +and lambda_to_string (env : fmt_env) (indent : string) (indent_incr : string) (xl : typed_pattern list) (e : texpression) : string = let xl = List.map (typed_pattern_to_string env) xl in let e = texpression_to_string env false indent indent_incr e in diff --git a/compiler/Pure.ml b/compiler/Pure.ml index eb6b00c8..ddacf0c4 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -684,7 +684,7 @@ type expression = field accesses with calls to projectors over fields (when there are clashes of field names, some provers like F* get pretty bad...) *) - | Abs of typed_pattern * texpression (** Lambda abstraction: [fun x -> e] *) + | Lambda of typed_pattern * texpression (** Lambda abstraction: [λ x => e] *) | Qualif of qualif (** A top-level qualifier *) | Let of bool * typed_pattern * texpression * texpression (** Let binding. @@ -728,7 +728,6 @@ type expression = | Switch of texpression * switch_body | Loop of loop (** See the comments for {!loop} *) | StructUpdate of struct_update (** See the comments for {!struct_update} *) - | Lambda of typed_pattern * texpression (** [λ x => e] *) | Meta of (emeta[@opaque]) * texpression (** Meta-information *) and switch_body = If of texpression * texpression | Match of match_branch list diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index d92b3de0..0102b13e 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -385,17 +385,17 @@ let compute_pretty_names (def : fun_decl) : fun_decl = let ctx, arg = update_texpression arg ctx in let e = App (app, arg) in (ctx, e) - | Abs (x, e) -> update_abs x e ctx | Qualif _ -> (* nothing to do *) (ctx, e.e) | Let (monadic, lb, re, e) -> update_let monadic lb re e ctx | Switch (scrut, body) -> update_switch_body scrut body ctx | Loop loop -> update_loop loop ctx | StructUpdate supd -> update_struct_update supd ctx + | Lambda (lb, e) -> update_lambda lb e ctx | Meta (meta, e) -> update_emeta meta e ctx in (ctx, { e; ty }) (* *) - and update_abs (x : typed_pattern) (e : texpression) (ctx : pn_ctx) : + and update_lambda (x : typed_pattern) (e : texpression) (ctx : pn_ctx) : pn_ctx * expression = (* We first add the left-constraint *) let ctx = add_left_constraint x ctx in @@ -404,7 +404,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl = (* Update the abstracted value *) let x = update_typed_pattern ctx x in (* Put together *) - (ctx, Abs (x, e)) + (ctx, Lambda (x, e)) (* *) and update_let (monadic : bool) (lv : typed_pattern) (re : texpression) (e : texpression) (ctx : pn_ctx) : pn_ctx * expression = @@ -890,12 +890,12 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) let call_is_child = check_call func1 generics1 args1 in if call_is_child then fun () -> true else fun () -> self#visit_texpression env e ()) + | Lambda (_, e) -> self#visit_texpression env e | App _ -> ( fun () -> match opt_destruct_function_call e with | Some (func1, tys1, args1) -> check_call func1 tys1 args1 | None -> false) - | Abs (_, e) -> self#visit_texpression env e | Qualif _ -> (* Note that this case includes functions without arguments *) fun () -> false @@ -975,7 +975,7 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) | Var _ | CVar _ | Const _ | App _ | Qualif _ | Switch (_, _) | Meta (_, _) - | StructUpdate _ | Abs _ -> + | StructUpdate _ | Lambda _ -> super#visit_expression env e | Let (monadic, lv, re, e) -> (* Compute the set of values used in the next expression *) @@ -1323,28 +1323,20 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = method! visit_Loop env loop = let fun_sig = def.signature in - let fun_sig_info = fun_sig.info in - let fun_effect_info = fun_sig_info.effect_info in + let fwd_info = fun_sig.fwd_info in + let fwd_effect_info = fwd_info.effect_info in (* TODO: *) assert (not !Config.return_back_funs); (* Generate the loop definition *) - let loop_effect_info = - { - stateful_group = fun_effect_info.stateful_group; - stateful = fun_effect_info.stateful; - can_fail = fun_effect_info.can_fail; - can_diverge = fun_effect_info.can_diverge; - is_rec = fun_effect_info.is_rec; - } - in + let loop_fwd_effect_info = fwd_effect_info in - let loop_sig_info = + let loop_fwd_sig_info : fun_sig_info = let fuel = if !Config.use_fuel then 1 else 0 in let num_inputs = List.length loop.inputs in let fwd_info : inputs_info = - let info = fun_sig_info.fwd_info in + let info = fwd_info.fwd_info in let fwd_state = info.num_inputs_with_fuel_with_state - info.num_inputs_with_fuel_no_state @@ -1358,48 +1350,48 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = } in - { fwd_info; effect_info = loop_effect_info } + { fwd_info; effect_info = loop_fwd_effect_info } in - assert (fun_sig_info_is_wf loop_sig_info); + assert (fun_sig_info_is_wf loop_fwd_sig_info); let inputs_tys = - (* TODO: *) - assert (not !Config.return_back_funs); - let fuel = if !Config.use_fuel then [ mk_fuel_ty ] else [] in let fwd_inputs = List.map (fun (v : var) -> v.ty) loop.inputs in - let info = fun_sig_info.fwd_info in - let state = + let info = fwd_info.fwd_info in + let fwd_state = Collections.List.subslice fun_sig.inputs info.num_inputs_with_fuel_no_state info.num_inputs_with_fuel_with_state in - let _, back_inputs = - Collections.List.split_at fun_sig.inputs - info.num_inputs_with_fuel_with_state + let back_inputs = + if !Config.return_back_funs then [] + else + snd + (Collections.List.split_at fun_sig.inputs + info.num_inputs_with_fuel_with_state) in - List.concat [ fuel; fwd_inputs; state; back_inputs ] + List.concat [ fuel; fwd_inputs; fwd_state; back_inputs ] in - let output, doutputs = + let output = match loop.back_output_tys with | None -> (* Forward function: the return type is the same as the parent function *) - (fun_sig.output, fun_sig.doutputs) + fun_sig.output | Some doutputs -> (* Backward function: custom return type *) let output = mk_simpl_tuple_ty doutputs in let output = - if loop_effect_info.stateful then + if loop_fwd_effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] else output in let output = - if loop_effect_info.can_fail then mk_result_ty output + if loop_fwd_effect_info.can_fail then mk_result_ty output else output in - (output, doutputs) + output in let loop_sig = @@ -1409,8 +1401,8 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = preds = fun_sig.preds; inputs = inputs_tys; output; - doutputs; - info = loop_sig_info; + fwd_info = loop_fwd_sig_info; + back_effect_info = fun_sig.back_effect_info; } in @@ -1427,7 +1419,8 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = (* Introduce the forward input state *) let fwd_state_var, fwd_state_lvs = assert ( - loop_effect_info.stateful = Option.is_some loop.input_state); + loop_fwd_effect_info.stateful + = Option.is_some loop.input_state); match loop.input_state with | None -> ([], []) | Some input_state -> @@ -1436,11 +1429,9 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = ([ state_var ], [ state_lvs ]) in - (* Introduce the additional backward inputs *) - (* TODO: *) - assert (not !Config.return_back_funs); + (* Introduce the additional backward inputs, if necessary *) let fun_body = Option.get def.body in - let info = fun_sig_info.fwd_info in + let info = fwd_info.fwd_info in let _, back_inputs = Collections.List.split_at fun_body.inputs info.num_inputs_with_fuel_with_state @@ -2063,14 +2054,12 @@ let filter_loop_inputs (transl : pure_fun_translation list) : (* We start by computing the filtering information, for each function *) let compute_one_filter_info (decl : fun_decl) = - (* TODO: *) - assert (not !Config.return_back_funs); (* There should be a body *) let body = Option.get decl.body in (* We only look at the forward inputs, without the state *) let inputs_prefix, _ = Collections.List.split_at body.inputs - decl.signature.info.fwd_info.num_inputs_with_fuel_no_state + decl.signature.fwd_info.fwd_info.num_inputs_with_fuel_no_state in let used = ref (List.map (fun v -> (var_get_id v, false)) inputs_prefix) in let inputs_prefix_length = List.length inputs_prefix in @@ -2089,9 +2078,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : in (* Set the fuel as used *) - let sg_info = decl.signature.info in - (* TODO: *) - assert (not !Config.return_back_funs); + let sg_info = decl.signature.fwd_info in if sg_info.fwd_info.has_fuel then set_used (fst (Collections.List.nth inputs 0)); @@ -2177,13 +2164,18 @@ let filter_loop_inputs (transl : pure_fun_translation list) : let num_filtered = List.length (List.filter (fun b -> not b) used_info) in - let { generics; llbc_generics; preds; inputs; output; doutputs; info } - = + let { + generics; + llbc_generics; + preds; + inputs; + output; + fwd_info; + back_effect_info; + } = decl.signature in - (* TODO: *) - assert (not !Config.return_back_funs); - let { fwd_info; effect_info } = info in + let { fwd_info; effect_info } = fwd_info in let { has_fuel; @@ -2208,10 +2200,18 @@ let filter_loop_inputs (transl : pure_fun_translation list) : } in - let info = { fwd_info; effect_info } in - assert (fun_sig_info_is_wf info); + let fwd_info = { fwd_info; effect_info } in + assert (fun_sig_info_is_wf fwd_info); let signature = - { generics; llbc_generics; preds; inputs; output; doutputs; info } + { + generics; + llbc_generics; + preds; + inputs; + output; + fwd_info; + back_effect_info; + } in { decl with signature } diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index 3c1800a8..d60d6a05 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -120,7 +120,7 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = assert (output_ty = e.ty); check_texpression ctx app; check_texpression ctx arg - | Abs (pat, body) -> + | Lambda (pat, body) -> let pat_ty, body_ty = destruct_arrow e.ty in assert (pat.ty = pat_ty); assert (body.ty = body_ty); @@ -229,12 +229,6 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = check_texpression ctx fe) supd.updates | _ -> raise (Failure "Unexpected")) - | Lambda (pat, e_next) -> - assert (e.ty = e_next.ty); - (* Check the pattern and register the introduced variables at the same time *) - let ctx = check_typed_pattern ctx pat in - (* Check the next expression *) - check_texpression ctx e_next | Meta (_, e_next) -> assert (e_next.ty = e.ty); check_texpression ctx e_next diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 80b25641..6e86578c 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -215,8 +215,7 @@ let fun_sig_substitute (subst : subst) (sg : fun_sig) : inst_fun_sig = *) let rec let_group_requires_parentheses (e : texpression) : bool = match e.e with - | Var _ | CVar _ | Const _ | App _ | Abs _ | Qualif _ | StructUpdate _ -> - false + | Var _ | CVar _ | Const _ | App _ | Qualif _ | StructUpdate _ -> false | Let (monadic, _, _, next_e) -> if monadic then true else let_group_requires_parentheses next_e | Switch (_, _) -> false @@ -374,18 +373,6 @@ let opt_destruct_tuple (ty : ty) : ty list option = Some generics.types | _ -> None -let mk_abs (x : typed_pattern) (e : texpression) : texpression = - let ty = TArrow (x.ty, e.ty) in - let e = Abs (x, e) in - { e; ty } - -let rec destruct_abs_list (e : texpression) : typed_pattern list * texpression = - match e.e with - | Abs (x, e') -> - let xl, e'' = destruct_abs_list e' in - (x :: xl, e'') - | _ -> ([], e) - let destruct_arrow (ty : ty) : ty * ty = match ty with | TArrow (ty0, ty1) -> (ty0, ty1) @@ -717,13 +704,16 @@ let type_decl_from_type_id_is_tuple_struct (ctx : TypesAnalysis.type_infos) info.is_tuple_struct | TAssumed _ -> false +let mk_lambda (x : typed_pattern) (e : texpression) : texpression = + let ty = TArrow (x.ty, e.ty) in + let e = Lambda (x, e) in + { e; ty } + let mk_lambda_from_var (var : var) (mp : mplace option) (e : texpression) : texpression = - let ty = TArrow (var.ty, e.ty) in let pat = PatVar (var, mp) in let pat = { value = pat; ty = var.ty } in - let e = Lambda (pat, e) in - { e; ty } + mk_lambda pat e let mk_lambdas_from_vars (vars : var list) (mps : mplace option list) (e : texpression) : texpression = -- cgit v1.2.3 From 884edaa3ee975626f184249d491f343fc02a66e2 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 15 Dec 2023 18:54:06 +0100 Subject: Make progress on updating the code --- compiler/PureMicroPasses.ml | 48 ++++++---- compiler/SymbolicToPure.ml | 79 +++-------------- compiler/Translate.ml | 207 ++++++++++++++++++++------------------------ 3 files changed, 134 insertions(+), 200 deletions(-) (limited to 'compiler') diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 0102b13e..a7c2f154 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -776,9 +776,11 @@ let inline_useless_var_reassignments (ctx : trans_ctx) (inline_named : bool) in { def with body = Some body } -(** Given a forward or backward function call, is there, for every execution +(** For the cases where we split the forward/backward functions. + + Given a forward or backward function call, is there, for every execution path, a child backward function called later with exactly the same input - list prefix? We use this to filter useless function calls: if there are + list prefix. We use this to filter useless function calls: if there are such child calls, we can remove this one (in case its outputs are not used). We do this check because we can't simply remove function calls whose @@ -1008,17 +1010,21 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) * under some conditions. *) match (filter_monadic_calls, opt_destruct_function_call re) with | true, Some (Fun (FromLlbc (fid, lp_id, rg_id)), tys, args) -> - (* We need to check if there is a child call - see - * the comments for: - * [expression_contains_child_call_in_all_paths] *) - let has_child_call = - expression_contains_child_call_in_all_paths ctx fid lp_id - rg_id tys args e - in - if has_child_call then (* Filter *) - (e.e, fun _ -> used) - else (* No child call: don't filter *) - dont_filter () + (* If we split the forward/backward functions. + + We need to check if there is a child call - see + the comments for: + [expression_contains_child_call_in_all_paths] *) + if not !Config.return_back_funs then + let has_child_call = + expression_contains_child_call_in_all_paths ctx fid + lp_id rg_id tys args e + in + if has_child_call then (* Filter *) + (e.e, fun _ -> used) + else (* No child call: don't filter *) + dont_filter () + else dont_filter () | _ -> (* Not an LLBC function call or not allowed to filter: we can't filter *) dont_filter () @@ -1509,9 +1515,12 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = altogether. *) let keep_forward (fwd : fun_and_loops) (backs : fun_and_loops list) : bool = - (* Note that at this point, the output types are no longer seen as tuples: - * they should be lists of length 1. *) - if + (* The question of filtering the forward functions arises only if we split + the forward/backward functions *) + if !Config.return_back_funs then true + else if + (* Note that at this point, the output types are no longer seen as tuples: + * they should be lists of length 1. *) !Config.filter_useless_functions && fwd.f.signature.output = mk_result_ty mk_unit_ty && backs <> [] @@ -1957,9 +1966,10 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : log#ldebug (lazy ("remove_meta:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); (* Remove the backward functions with no outputs. - * Note that the calls to those functions should already have been removed, - * when translating from symbolic to pure. Here, we remove the definitions - * altogether, because they are now useless *) + + Note that the *calls* to those functions should already have been removed, + when translating from symbolic to pure. Here, we remove the definitions + altogether, because they are now useless *) let name = def.name ^ PrintPure.fun_suffix def.loop_id def.back_id in let opt_def = filter_if_backward_with_no_outputs def in diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 08f9e950..204fc399 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -45,7 +45,6 @@ type fun_sig_named_outputs = { type fun_context = { llbc_fun_decls : A.fun_decl A.FunDeclId.Map.t; - fun_sigs : fun_sig_named_outputs RegularFunIdNotLoopMap.t; (** *) fun_infos : fun_info A.FunDeclId.Map.t; regions_hierarchies : T.region_var_groups FunIdMap.t; } @@ -144,7 +143,11 @@ type bs_ctx = { a symbolic expansion or upon ending an abstraction, for instance) we introduce a new variable (with a let-binding). *) - var_counter : VarId.generator; + var_counter : VarId.generator ref; + (** Using a ref to make sure all the variables identifiers are unique. + TODO: this is not very clean, and the code was initially written without + a reference (and it's shape hasn't changed). We should use DeBruijn indices. + *) state_var : VarId.id; (** The current state variable, in case the function is stateful *) back_state_vars : VarId.id RegionGroupId.Map.t; @@ -1131,13 +1134,14 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * var * typed_pattern = (* Generate the fresh variable *) - let id, var_counter = VarId.fresh ctx.var_counter in + let id, var_counter = VarId.fresh !(ctx.var_counter) in let state_var = { id; basename = Some ConstStrings.state_basename; ty = mk_state_ty } in let state_pat = mk_typed_pattern_from_var state_var None in (* Update the context *) - let ctx = { ctx with var_counter; state_var = id } in + ctx.var_counter := var_counter; + let ctx = { ctx with state_var = id } in (* Return *) (ctx, state_var, state_pat) @@ -1146,11 +1150,11 @@ let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * var * typed_pattern = let fresh_var_llbc_ty (basename : string option) (ty : T.ty) (ctx : bs_ctx) : bs_ctx * var = (* Generate the fresh variable *) - let id, var_counter = VarId.fresh ctx.var_counter in + let id, var_counter = VarId.fresh !(ctx.var_counter) in let ty = ctx_translate_fwd_ty ctx ty in let var = { id; basename; ty } in (* Update the context *) - let ctx = { ctx with var_counter } in + ctx.var_counter := var_counter; (* Return *) (ctx, var) @@ -1184,10 +1188,10 @@ let fresh_named_vars_for_symbolic_values let fresh_var (basename : string option) (ty : ty) (ctx : bs_ctx) : bs_ctx * var = (* Generate the fresh variable *) - let id, var_counter = VarId.fresh ctx.var_counter in + let id, var_counter = VarId.fresh !(ctx.var_counter) in let var = { id; basename; ty } in (* Update the context *) - let ctx = { ctx with var_counter } in + ctx.var_counter := var_counter; (* Return *) (ctx, var) @@ -3303,65 +3307,6 @@ let translate_type_decls (ctx : Contexts.decls_ctx) : type_decl list = List.map (translate_type_decl ctx) (TypeDeclId.Map.values ctx.type_ctx.type_decls) -(** Translates function signatures. - - Takes as input a list of function information containing: - - the function id - - a list of optional names for the inputs - - the function signature - - Returns a map from forward/backward functions identifiers to: - - translated function signatures - - optional names for the outputs values (we derive them for the backward - functions) - *) -let translate_fun_signatures (decls_ctx : C.decls_ctx) - (functions : (A.fun_id * string option list * A.fun_sig) list) : - fun_sig_named_outputs RegularFunIdNotLoopMap.t = - (* For every function, translate the signatures of: - - the forward function - - the backward functions - *) - let translate_one (fun_id : A.fun_id) (input_names : string option list) - (sg : A.fun_sig) : (regular_fun_id_not_loop * fun_sig_named_outputs) list - = - log#ldebug - (lazy - ("Translating signature of function: " - ^ Print.Expressions.fun_id_to_string - (Print.Contexts.decls_ctx_to_fmt_env decls_ctx) - fun_id)); - (* Retrieve the regions hierarchy *) - let regions_hierarchy = - FunIdMap.find fun_id decls_ctx.fun_ctx.regions_hierarchies - in - (* The forward function *) - let fwd_sg = translate_fun_sig decls_ctx fun_id sg input_names None in - let fwd_id = (fun_id, None) in - (* The backward functions *) - let back_sgs = - if !Config.return_back_funs then [] - else - List.map - (fun (rg : T.region_var_group) -> - let tsg = - translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id) - in - let id = (fun_id, Some rg.id) in - (id, tsg)) - regions_hierarchy - in - (* Return *) - (fwd_id, fwd_sg) :: back_sgs - in - let translated = - List.concat - (List.map (fun (id, names, sg) -> translate_one id names sg) functions) - in - List.fold_left - (fun m (id, sg) -> RegularFunIdNotLoopMap.add id sg m) - RegularFunIdNotLoopMap.empty translated - let translate_trait_decl (ctx : Contexts.decls_ctx) (trait_decl : A.trait_decl) : trait_decl = let { diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 06d4bd6d..8b221c93 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -6,7 +6,6 @@ open LlbcAst open Contexts module SA = SymbolicAst module Micro = PureMicroPasses -open PureUtils open TranslateCore (** The local logger *) @@ -43,7 +42,6 @@ let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : fun_decl) : TODO: maybe we should introduce a record for this. *) let translate_function_to_pure (trans_ctx : trans_ctx) - (fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdNotLoopMap.t) (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : fun_decl) : pure_fun_translation_no_loops = (* Debug *) @@ -58,13 +56,9 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (* Convert the symbolic ASTs to pure ASTs: *) (* Initialize the context *) - let forward_sig = - RegularFunIdNotLoopMap.find (FRegular def_id, None) fun_sigs - in let sv_to_var = SymbolicValueId.Map.empty in let var_counter = Pure.VarId.generator_zero in let state_var, var_counter = Pure.VarId.fresh var_counter in - let back_state_var, var_counter = Pure.VarId.fresh var_counter in let fuel0, var_counter = Pure.VarId.fresh var_counter in let fuel, var_counter = Pure.VarId.fresh var_counter in let calls = FunCallId.Map.empty in @@ -89,7 +83,6 @@ let translate_function_to_pure (trans_ctx : trans_ctx) let fun_context = { SymbolicToPure.llbc_fun_decls = trans_ctx.fun_ctx.fun_decls; - fun_sigs; fun_infos = trans_ctx.fun_ctx.fun_infos; regions_hierarchies = trans_ctx.fun_ctx.regions_hierarchies; } @@ -126,17 +119,45 @@ let translate_function_to_pure (trans_ctx : trans_ctx) !m in + let input_names = + match fdef.body with + | None -> List.map (fun _ -> None) fdef.signature.inputs + | Some body -> + List.map + (fun (v : var) -> v.name) + (LlbcAstUtils.fun_body_get_input_vars body) + in + + let sg = + SymbolicToPure.translate_fun_sig_to_decomposed trans_ctx (FRegular def_id) + fdef.signature input_names + in + + let regions_hierarchy = + LlbcAstUtils.FunIdMap.find (FRegular def_id) fun_context.regions_hierarchies + in + + let var_counter, back_state_vars = + if !Config.return_back_funs then (var_counter, []) + else + List.fold_left_map + (fun var_counter (region_vars : region_var_group) -> + let gid = region_vars.id in + let var, var_counter = Pure.VarId.fresh var_counter in + (var_counter, (gid, var))) + var_counter regions_hierarchy + in + let back_state_vars = RegionGroupId.Map.of_list back_state_vars in + let ctx = { SymbolicToPure.bid = None; - (* Dummy for now *) - sg = forward_sig.sg; - fwd_sg = forward_sig.sg; + sg; (* Will need to be updated for the backward functions *) sv_to_var; - var_counter; + var_counter = ref var_counter; state_var; - back_state_var; + back_state_vars; fuel0; fuel; type_context; @@ -146,9 +167,11 @@ let translate_function_to_pure (trans_ctx : trans_ctx) trait_impls_ctx = trans_ctx.trait_impls_ctx.trait_impls; fun_decl = fdef; forward_inputs = []; - (* Empty for now *) - backward_inputs = RegionGroupId.Map.empty; - (* Empty for now *) + (* Initialized just below *) + backward_inputs_no_state = RegionGroupId.Map.empty; + (* Initialized just below *) + backward_inputs_with_state = RegionGroupId.Map.empty; + (* Initialized just below *) backward_outputs = RegionGroupId.Map.empty; loop_backward_outputs = None; (* Empty for now *) @@ -180,6 +203,51 @@ let translate_function_to_pure (trans_ctx : trans_ctx) | _ -> raise (Failure "Unreachable") in + (* Add the backward inputs *) + let ctx, backward_inputs_no_state, backward_inputs_with_state = + if !Config.return_back_funs then (ctx, [], []) + else + let ctx, inputs_no_with_state = + List.fold_left_map + (fun ctx (region_vars : region_var_group) -> + let gid = region_vars.id in + let back_sg = RegionGroupId.Map.find gid sg.back_sg in + let ctx, no_state = + SymbolicToPure.fresh_vars back_sg.inputs_no_state ctx + in + let ctx, with_state = + SymbolicToPure.fresh_vars back_sg.inputs ctx + in + (ctx, ((gid, no_state), (gid, with_state)))) + ctx regions_hierarchy + in + let inputs_no_state, inputs_with_state = + List.split inputs_no_with_state + in + (ctx, inputs_no_state, inputs_with_state) + in + let backward_inputs_no_state = + RegionGroupId.Map.of_list backward_inputs_no_state + in + let backward_inputs_with_state = + RegionGroupId.Map.of_list backward_inputs_with_state + in + let ctx = { ctx with backward_inputs_no_state; backward_inputs_with_state } in + + (* Add the backward outputs *) + let ctx, backward_outputs = + List.fold_left_map + (fun ctx (region_vars : region_var_group) -> + let gid = region_vars.id in + let back_sg = RegionGroupId.Map.find gid sg.back_sg in + let outputs = List.combine back_sg.output_names back_sg.outputs in + let ctx, vars = SymbolicToPure.fresh_vars outputs ctx in + (ctx, (gid, vars))) + ctx regions_hierarchy + in + let backward_outputs = RegionGroupId.Map.of_list backward_outputs in + let ctx = { ctx with backward_outputs } in + (* Translate the forward function *) let pure_forward = match symbolic_trans with @@ -187,7 +255,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) | Some (_, ast) -> SymbolicToPure.translate_fun_decl ctx (Some ast) in - (* Translate the backward functions *) + (* Translate the backward functions, if we split the forward and backward functions *) let translate_backward (rg : region_var_group) : Pure.fun_decl = (* For the backward inputs/outputs initialization: we use the fact that * there are no nested borrows for now, and so that the region groups @@ -197,83 +265,20 @@ let translate_function_to_pure (trans_ctx : trans_ctx) match symbolic_trans with | None -> - (* Initialize the context - note that the ret_ty is not really - * useful as we don't translate a body *) - let backward_sg = - RegularFunIdNotLoopMap.find (FRegular def_id, Some back_id) fun_sigs - in - let ctx = { ctx with bid = Some back_id; sg = backward_sg.sg } in - + (* Initialize the context *) + let ctx = { ctx with bid = Some back_id } in (* Translate *) SymbolicToPure.translate_fun_decl ctx None | Some (_, symbolic) -> - (* Finish initializing the context by adding the additional input - variables required by the backward function. - *) - let backward_sg = - RegularFunIdNotLoopMap.find (FRegular def_id, Some back_id) fun_sigs - in - (* We need to ignore the forward inputs, and the state input (if there is) *) - let backward_inputs = - let sg = backward_sg.sg in - (* TODO: *) - assert (not !Config.return_back_funs); - (* We need to ignore the forward state and the backward state *) - let num_forward_inputs = - sg.info.fwd_info.num_inputs_with_fuel_with_state - in - let num_back_inputs = - match sg.info.back_info with - | SingleBack (Some info) -> info.num_inputs_no_fuel_no_state - | _ -> raise (Failure "Unexpected") - in - Collections.List.subslice sg.inputs num_forward_inputs - (num_forward_inputs + num_back_inputs) - in - (* As we forbid nested borrows, the additional inputs for the backward - * functions come from the borrows in the return value of the rust function: - * we thus use the name "ret" for those inputs *) - let backward_inputs = - List.map (fun ty -> (Some "ret", ty)) backward_inputs - in - let ctx, backward_inputs = - SymbolicToPure.fresh_vars backward_inputs ctx - in - (* The outputs for the backward functions, however, come from borrows - * present in the input values of the rust function: for those we reuse - * the names of the input values. *) - let backward_outputs = - List.combine backward_sg.output_names backward_sg.sg.doutputs - in - let ctx, backward_outputs = - SymbolicToPure.fresh_vars backward_outputs ctx - in - let backward_inputs = - RegionGroupId.Map.singleton back_id backward_inputs - in - let backward_outputs = - RegionGroupId.Map.singleton back_id backward_outputs - in - - (* Put everything in the context *) - let ctx = - { - ctx with - bid = Some back_id; - sg = backward_sg.sg; - backward_inputs; - backward_outputs; - } - in - + (* Initialize the context *) + let ctx = { ctx with bid = Some back_id } in (* Translate *) SymbolicToPure.translate_fun_decl ctx (Some symbolic) in - let regions_hierarchy = - LlbcAstUtils.FunIdMap.find (FRegular fdef.def_id) - fun_context.regions_hierarchies + let pure_backwards = + if !Config.return_back_funs then [] + else List.map translate_backward regions_hierarchy in - let pure_backwards = List.map translate_backward regions_hierarchy in (* Return *) (pure_forward, pure_backwards) @@ -300,36 +305,10 @@ let translate_crate_to_pure (crate : crate) : (List.map (fun (def : Pure.type_decl) -> (def.def_id, def)) type_decls) in - (* Translate all the function *signatures* *) - let assumed_sigs = - List.map - (fun (info : Assumed.assumed_fun_info) -> - ( FAssumed info.fun_id, - List.map (fun _ -> None) info.fun_sig.inputs, - info.fun_sig )) - Assumed.assumed_fun_infos - in - let local_sigs = - List.map - (fun (fdef : fun_decl) -> - let input_names = - match fdef.body with - | None -> List.map (fun _ -> None) fdef.signature.inputs - | Some body -> - List.map - (fun (v : var) -> v.name) - (LlbcAstUtils.fun_body_get_input_vars body) - in - (FRegular fdef.def_id, input_names, fdef.signature)) - (FunDeclId.Map.values crate.fun_decls) - in - let sigs = List.append assumed_sigs local_sigs in - let fun_sigs = SymbolicToPure.translate_fun_signatures trans_ctx sigs in - (* Translate all the *transparent* functions *) let pure_translations = List.map - (translate_function_to_pure trans_ctx fun_sigs type_decls_map) + (translate_function_to_pure trans_ctx type_decls_map) (FunDeclId.Map.values crate.fun_decls) in @@ -1036,7 +1015,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) : List.map (fun { fwd; _ } -> let fwd_f = - if fwd.f.Pure.signature.info.effect_info.is_rec then + if fwd.f.Pure.signature.fwd_info.effect_info.is_rec then [ (fwd.f.def_id, None) ] else [] in -- cgit v1.2.3 From 2fb4ca72b112f6181d74d1ca37ed6d54c65f43cd Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 18 Dec 2023 10:11:44 +0100 Subject: Do not register the names of the back funs if they are merged with the fwd funs --- compiler/Extract.ml | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'compiler') diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 7e2efd8a..3429cd11 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -43,8 +43,12 @@ let extract_fun_decl_register_names (ctx : extraction_ctx) } | _ -> ctx in - let backs = List.map (fun f -> f.f) def.backs in - let funs = if def.keep_fwd then def.fwd.f :: backs else backs in + let funs = + if !Config.return_back_funs then [ def.fwd.f ] + else + let backs = List.map (fun f -> f.f) def.backs in + if def.keep_fwd then def.fwd.f :: backs else backs + in List.fold_left (fun ctx (f : fun_decl) -> let open ExtractBuiltin in @@ -1988,7 +1992,8 @@ let extract_trait_decl_method_names (ctx : extraction_ctx) (* We add one field per required forward/backward function *) let get_funs_for_id (id : fun_decl_id) : fun_decl list = let trans : pure_fun_translation = FunDeclId.Map.find id ctx.trans_funs in - List.map (fun f -> f.f) (trans.fwd :: trans.backs) + if !Config.return_back_funs then [ trans.fwd.f ] + else List.map (fun f -> f.f) (trans.fwd :: trans.backs) in match builtin_info with | None -> -- cgit v1.2.3 From a49754a5b11e4de8793dc7e13c2962d139eb03b1 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 18 Dec 2023 10:21:08 +0100 Subject: Rename some definitions --- compiler/AssociatedTypes.ml | 10 ++--- compiler/Contexts.ml | 42 +++++++++--------- compiler/Interpreter.ml | 8 ++-- compiler/InterpreterBorrows.ml | 12 +++--- compiler/InterpreterExpansion.ml | 2 +- compiler/InterpreterExpressions.ml | 5 +-- compiler/InterpreterLoopsFixedPoint.ml | 18 ++++---- compiler/InterpreterLoopsJoinCtxs.ml | 36 ++++++++-------- compiler/InterpreterLoopsMatchCtxs.ml | 4 +- compiler/InterpreterStatements.ml | 12 +++--- compiler/InterpreterUtils.ml | 26 ++++++------ compiler/Invariants.ml | 2 +- compiler/Print.ml | 10 ++--- compiler/SymbolicToPure.ml | 78 +++++++++++++++++----------------- compiler/Translate.ml | 14 +++--- 15 files changed, 139 insertions(+), 140 deletions(-) (limited to 'compiler') diff --git a/compiler/AssociatedTypes.ml b/compiler/AssociatedTypes.ml index e2f687e8..054c8169 100644 --- a/compiler/AssociatedTypes.ml +++ b/compiler/AssociatedTypes.ml @@ -493,11 +493,11 @@ let norm_ctx_normalize_trait_type_constraint (ctx : norm_ctx) let mk_norm_ctx (ctx : eval_ctx) : norm_ctx = { norm_trait_types = ctx.norm_trait_types; - type_decls = ctx.type_context.type_decls; - fun_decls = ctx.fun_context.fun_decls; - global_decls = ctx.global_context.global_decls; - trait_decls = ctx.trait_decls_context.trait_decls; - trait_impls = ctx.trait_impls_context.trait_impls; + type_decls = ctx.type_ctx.type_decls; + fun_decls = ctx.fun_ctx.fun_decls; + global_decls = ctx.global_ctx.global_decls; + trait_decls = ctx.trait_decls_ctx.trait_decls; + trait_impls = ctx.trait_impls_ctx.trait_impls; type_vars = ctx.type_vars; const_generic_vars = ctx.const_generic_vars; } diff --git a/compiler/Contexts.ml b/compiler/Contexts.ml index a30ed0f1..5d646a61 100644 --- a/compiler/Contexts.ml +++ b/compiler/Contexts.ml @@ -180,35 +180,35 @@ type config = { let mk_config (mode : interpreter_mode) : config = { mode } -type type_context = { +type type_ctx = { type_decls_groups : type_declaration_group TypeDeclId.Map.t; type_decls : type_decl TypeDeclId.Map.t; type_infos : TypesAnalysis.type_infos; } [@@deriving show] -type fun_context = { +type fun_ctx = { fun_decls : fun_decl FunDeclId.Map.t; fun_infos : FunsAnalysis.fun_info FunDeclId.Map.t; regions_hierarchies : region_var_groups FunIdMap.t; } [@@deriving show] -type global_context = { global_decls : global_decl GlobalDeclId.Map.t } +type global_ctx = { global_decls : global_decl GlobalDeclId.Map.t } [@@deriving show] -type trait_decls_context = { trait_decls : trait_decl TraitDeclId.Map.t } +type trait_decls_ctx = { trait_decls : trait_decl TraitDeclId.Map.t } [@@deriving show] -type trait_impls_context = { trait_impls : trait_impl TraitImplId.Map.t } +type trait_impls_ctx = { trait_impls : trait_impl TraitImplId.Map.t } [@@deriving show] type decls_ctx = { - type_ctx : type_context; - fun_ctx : fun_context; - global_ctx : global_context; - trait_decls_ctx : trait_decls_context; - trait_impls_ctx : trait_impls_context; + type_ctx : type_ctx; + fun_ctx : fun_ctx; + global_ctx : global_ctx; + trait_decls_ctx : trait_decls_ctx; + trait_impls_ctx : trait_impls_ctx; } [@@deriving show] @@ -230,11 +230,11 @@ module TraitTypeRefMap = Collections.MakeMap (TraitTypeRefOrd) (** Evaluation context *) type eval_ctx = { - type_context : type_context; - fun_context : fun_context; - global_context : global_context; - trait_decls_context : trait_decls_context; - trait_impls_context : trait_impls_context; + type_ctx : type_ctx; + fun_ctx : fun_ctx; + global_ctx : global_ctx; + trait_decls_ctx : trait_decls_ctx; + trait_impls_ctx : trait_impls_ctx; region_groups : RegionGroupId.id list; type_vars : type_var list; const_generic_vars : const_generic_var list; @@ -290,20 +290,20 @@ let ctx_lookup_var_binder (ctx : eval_ctx) (vid : VarId.id) : var_binder = fst (env_lookup_var ctx.env vid) let ctx_lookup_type_decl (ctx : eval_ctx) (tid : TypeDeclId.id) : type_decl = - TypeDeclId.Map.find tid ctx.type_context.type_decls + TypeDeclId.Map.find tid ctx.type_ctx.type_decls let ctx_lookup_fun_decl (ctx : eval_ctx) (fid : FunDeclId.id) : fun_decl = - FunDeclId.Map.find fid ctx.fun_context.fun_decls + FunDeclId.Map.find fid ctx.fun_ctx.fun_decls let ctx_lookup_global_decl (ctx : eval_ctx) (gid : GlobalDeclId.id) : global_decl = - GlobalDeclId.Map.find gid ctx.global_context.global_decls + GlobalDeclId.Map.find gid ctx.global_ctx.global_decls let ctx_lookup_trait_decl (ctx : eval_ctx) (id : TraitDeclId.id) : trait_decl = - TraitDeclId.Map.find id ctx.trait_decls_context.trait_decls + TraitDeclId.Map.find id ctx.trait_decls_ctx.trait_decls let ctx_lookup_trait_impl (ctx : eval_ctx) (id : TraitImplId.id) : trait_impl = - TraitImplId.Map.find id ctx.trait_impls_context.trait_impls + TraitImplId.Map.find id ctx.trait_impls_ctx.trait_impls (** Retrieve a variable's value in the current frame *) let env_lookup_var_value (env : env) (vid : VarId.id) : typed_value = @@ -528,7 +528,7 @@ let ctx_set_abs_can_end (ctx : eval_ctx) (abs_id : AbstractionId.id) fst (ctx_subst_abs ctx abs_id abs) let ctx_type_decl_is_rec (ctx : eval_ctx) (id : TypeDeclId.id) : bool = - let decl_group = TypeDeclId.Map.find id ctx.type_context.type_decls_groups in + let decl_group = TypeDeclId.Map.find id ctx.type_ctx.type_decls_groups in match decl_group with RecGroup _ -> true | NonRecGroup _ -> false (** Visitor to iterate over the values in the *current* frame *) diff --git a/compiler/Interpreter.ml b/compiler/Interpreter.ml index 76432faa..22d176c9 100644 --- a/compiler/Interpreter.ml +++ b/compiler/Interpreter.ml @@ -195,7 +195,7 @@ let initialize_symbolic_context_for_fun (ctx : decls_ctx) (fdef : fun_decl) : List.map (fun (g : region_var_group) -> g.id) regions_hierarchy in let ctx = - initialize_eval_context ctx region_groups sg.generics.types + initialize_eval_ctx ctx region_groups sg.generics.types sg.generics.const_generics in (* Instantiate the signature. This updates the context because we compute @@ -277,7 +277,7 @@ let evaluate_function_symbolic_synthesize_backward_from_return (config : config) * an instantiation of the signature, so that we use fresh * region ids for the return abstractions. *) let regions_hierarchy = - FunIdMap.find (FRegular fdef.def_id) ctx.fun_context.regions_hierarchies + FunIdMap.find (FRegular fdef.def_id) ctx.fun_ctx.regions_hierarchies in let _, ret_inst_sg = symbolic_instantiate_fun_sig ctx fdef.signature regions_hierarchy fdef.kind @@ -466,7 +466,7 @@ let evaluate_function_symbolic (synthesize : bool) (ctx : decls_ctx) let ctx, input_svs, inst_sg = initialize_symbolic_context_for_fun ctx fdef in let regions_hierarchy = - FunIdMap.find (FRegular fdef.def_id) ctx.fun_context.regions_hierarchies + FunIdMap.find (FRegular fdef.def_id) ctx.fun_ctx.regions_hierarchies in (* Create the continuation to finish the evaluation *) @@ -615,7 +615,7 @@ module Test = struct assert (body.arg_count = 0); (* Create the evaluation context *) - let ctx = initialize_eval_context decls_ctx [] [] [] in + let ctx = initialize_eval_ctx decls_ctx [] [] [] in (* Insert the (uninitialized) local variables *) let ctx = ctx_push_uninitialized_vars ctx body.locals in diff --git a/compiler/InterpreterBorrows.ml b/compiler/InterpreterBorrows.ml index e56919fa..a2eb2545 100644 --- a/compiler/InterpreterBorrows.ml +++ b/compiler/InterpreterBorrows.ml @@ -1628,7 +1628,7 @@ let destructure_abs (abs_kind : abs_kind) (can_end : bool) push { value; ty } | AIgnoredMutLoan (opt_bid, child_av) -> (* We don't support nested borrows for now *) - assert (not (ty_has_borrows ctx.type_context.type_infos child_av.ty)); + assert (not (ty_has_borrows ctx.type_ctx.type_infos child_av.ty)); assert (opt_bid = None); (* Simply explore the child *) list_avalues false push_fail child_av @@ -1639,7 +1639,7 @@ let destructure_abs (abs_kind : abs_kind) (can_end : bool) { child = child_av; given_back = _; given_back_meta = _ } | AIgnoredSharedLoan child_av -> (* We don't support nested borrows for now *) - assert (not (ty_has_borrows ctx.type_context.type_infos child_av.ty)); + assert (not (ty_has_borrows ctx.type_ctx.type_infos child_av.ty)); (* Simply explore the child *) list_avalues false push_fail child_av) | ABorrow bc -> ( @@ -1659,14 +1659,14 @@ let destructure_abs (abs_kind : abs_kind) (can_end : bool) push av | AIgnoredMutBorrow (opt_bid, child_av) -> (* We don't support nested borrows for now *) - assert (not (ty_has_borrows ctx.type_context.type_infos child_av.ty)); + assert (not (ty_has_borrows ctx.type_ctx.type_infos child_av.ty)); assert (opt_bid = None); (* Just explore the child *) list_avalues false push_fail child_av | AEndedIgnoredMutBorrow { child = child_av; given_back = _; given_back_meta = _ } -> (* We don't support nested borrows for now *) - assert (not (ty_has_borrows ctx.type_context.type_infos child_av.ty)); + assert (not (ty_has_borrows ctx.type_ctx.type_infos child_av.ty)); (* Just explore the child *) list_avalues false push_fail child_av | AProjSharedBorrow asb -> @@ -1683,7 +1683,7 @@ let destructure_abs (abs_kind : abs_kind) (can_end : bool) | ASymbolic _ -> (* For now, we fore all symbolic values containing borrows to be eagerly expanded *) - assert (not (ty_has_borrows ctx.type_context.type_infos ty)) + assert (not (ty_has_borrows ctx.type_ctx.type_infos ty)) and list_values (v : typed_value) : typed_avalue list * typed_value = let ty = v.ty in match v.value with @@ -1732,7 +1732,7 @@ let destructure_abs (abs_kind : abs_kind) (can_end : bool) | VSymbolic _ -> (* For now, we fore all symbolic values containing borrows to be eagerly expanded *) - assert (not (ty_has_borrows ctx.type_context.type_infos ty)); + assert (not (ty_has_borrows ctx.type_ctx.type_infos ty)); ([], v) in diff --git a/compiler/InterpreterExpansion.ml b/compiler/InterpreterExpansion.ml index bbf4d9d5..e489ddc3 100644 --- a/compiler/InterpreterExpansion.ml +++ b/compiler/InterpreterExpansion.ml @@ -627,7 +627,7 @@ let greedy_expand_symbolics_with_borrows (config : config) : cm_fun = inherit [_] iter_eval_ctx method! visit_VSymbolic _ sv = - if ty_has_borrows ctx.type_context.type_infos sv.sv_ty then + if ty_has_borrows ctx.type_ctx.type_infos sv.sv_ty then raise (FoundSymbolicValue sv) else () diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml index 1b5b79dd..8536b4ab 100644 --- a/compiler/InterpreterExpressions.ml +++ b/compiler/InterpreterExpressions.ml @@ -32,8 +32,7 @@ let expand_primitively_copyable_at_place (config : config) fun cf ctx -> let v = read_place access p ctx in match - find_first_primitively_copyable_sv_with_borrows - ctx.type_context.type_infos v + find_first_primitively_copyable_sv_with_borrows ctx.type_ctx.type_infos v with | None -> cf ctx | Some sv -> @@ -351,7 +350,7 @@ let eval_operand_no_reorganize (config : config) (op : operand) assert ( Option.is_none (find_first_primitively_copyable_sv_with_borrows - ctx.type_context.type_infos v)); + ctx.type_ctx.type_infos v)); (* Actually perform the copy *) let allow_adt_copy = false in let ctx, v = copy_value allow_adt_copy config ctx v in diff --git a/compiler/InterpreterLoopsFixedPoint.ml b/compiler/InterpreterLoopsFixedPoint.ml index c4e180fa..4dabe974 100644 --- a/compiler/InterpreterLoopsFixedPoint.ml +++ b/compiler/InterpreterLoopsFixedPoint.ml @@ -300,7 +300,7 @@ let prepare_ashared_loans (loop_id : LoopId.id option) : cm_fun = let env = List.append fresh_absl env in let ctx = { ctx with env } in - let _, new_ctx_ids_map = compute_context_ids ctx in + let _, new_ctx_ids_map = compute_ctx_ids ctx in (* Synthesize *) match cf ctx with @@ -385,8 +385,8 @@ let compute_loop_entry_fixed_point (config : config) (loop_id : LoopId.id) match !fixed_ids with | Some _ -> ctx1 | None -> - let old_ids, _ = compute_context_ids ctx1 in - let new_ids, _ = compute_contexts_ids !ctxs in + let old_ids, _ = compute_ctx_ids ctx1 in + let new_ids, _ = compute_ctxs_ids !ctxs in let blids = BorrowId.Set.diff old_ids.blids new_ids.blids in let aids = AbstractionId.Set.diff old_ids.aids new_ids.aids in (* End those borrows and abstractions *) @@ -409,7 +409,7 @@ let compute_loop_entry_fixed_point (config : config) (loop_id : LoopId.id) ctxs := List.map (end_borrows_abs blids aids) !ctxs; (* Note that the fixed ids are given by the original context, from *before* we introduce fresh abstractions/reborrows for the shared values *) - fixed_ids := Some (fst (compute_context_ids ctx0)); + fixed_ids := Some (fst (compute_ctx_ids ctx0)); ctx1 in @@ -424,12 +424,12 @@ let compute_loop_entry_fixed_point (config : config) (loop_id : LoopId.id) intersection of ids between the original environment and the list of new environments *) let compute_fixed_ids (ctxl : eval_ctx list) : ids_sets = - let fixed_ids, _ = compute_context_ids ctx0 in + let fixed_ids, _ = compute_ctx_ids ctx0 in let { aids; blids; borrow_ids; loan_ids; dids; rids; sids } = fixed_ids in let sids = ref sids in List.iter (fun ctx -> - let fixed_ids, _ = compute_context_ids ctx in + let fixed_ids, _ = compute_ctx_ids ctx in sids := SymbolicValueId.Set.inter !sids fixed_ids.sids) ctxl; let sids = !sids in @@ -568,7 +568,7 @@ let compute_loop_entry_fixed_point (config : config) (loop_id : LoopId.id) InterpreterBorrows.end_abstraction_no_synth config abs_id ctx in (* Explore the context, and check which abstractions are not there anymore *) - let ids, _ = compute_context_ids ctx in + let ids, _ = compute_ctx_ids ctx in let ended_ids = AbstractionId.Set.diff !fp_aids ids.aids in add_ended_aids rg_id ended_ids) ctx.region_groups @@ -840,8 +840,8 @@ let compute_fixed_point_id_correspondance (fixed_ids : ids_sets) let compute_fp_ctx_symbolic_values (ctx : eval_ctx) (fp_ctx : eval_ctx) : SymbolicValueId.Set.t * symbolic_value list = - let old_ids, _ = compute_context_ids ctx in - let fp_ids, fp_ids_maps = compute_context_ids fp_ctx in + let old_ids, _ = compute_ctx_ids ctx in + let fp_ids, fp_ids_maps = compute_ctx_ids fp_ctx in let fresh_sids = SymbolicValueId.Set.diff fp_ids.sids old_ids.sids in (* Compute the set of symbolic values which appear in shared values inside diff --git a/compiler/InterpreterLoopsJoinCtxs.ml b/compiler/InterpreterLoopsJoinCtxs.ml index 8d485483..445e5abf 100644 --- a/compiler/InterpreterLoopsJoinCtxs.ml +++ b/compiler/InterpreterLoopsJoinCtxs.ml @@ -326,8 +326,8 @@ let mk_collapse_ctx_merge_duplicate_funs (loop_id : LoopId.id) (ctx : eval_ctx) let _ = let _, ty0, _ = ty_as_ref ty0 in let _, ty1, _ = ty_as_ref ty1 in - assert (not (ty_has_borrows ctx.type_context.type_infos ty0)); - assert (not (ty_has_borrows ctx.type_context.type_infos ty1)) + assert (not (ty_has_borrows ctx.type_ctx.type_infos ty0)); + assert (not (ty_has_borrows ctx.type_ctx.type_infos ty1)) in (* Same remarks as for [merge_amut_borrows] *) @@ -543,11 +543,11 @@ let join_ctxs (loop_id : LoopId.id) (fixed_ids : ids_sets) (ctx0 : eval_ctx) (* Construct the joined context - of course, the type, fun, etc. contexts * should be the same in the two contexts *) let { - type_context; - fun_context; - global_context; - trait_decls_context; - trait_impls_context; + type_ctx; + fun_ctx; + global_ctx; + trait_decls_ctx; + trait_impls_ctx; region_groups; type_vars; const_generic_vars; @@ -559,11 +559,11 @@ let join_ctxs (loop_id : LoopId.id) (fixed_ids : ids_sets) (ctx0 : eval_ctx) ctx0 in let { - type_context = _; - fun_context = _; - global_context = _; - trait_decls_context = _; - trait_impls_context = _; + type_ctx = _; + fun_ctx = _; + global_ctx = _; + trait_decls_ctx = _; + trait_impls_ctx = _; region_groups = _; type_vars = _; const_generic_vars = _; @@ -577,11 +577,11 @@ let join_ctxs (loop_id : LoopId.id) (fixed_ids : ids_sets) (ctx0 : eval_ctx) let ended_regions = RegionId.Set.union ended_regions0 ended_regions1 in Ok { - type_context; - fun_context; - global_context; - trait_decls_context; - trait_impls_context; + type_ctx; + fun_ctx; + global_ctx; + trait_decls_ctx; + trait_impls_ctx; region_groups; type_vars; const_generic_vars; @@ -621,7 +621,7 @@ let destructure_new_abs (loop_id : LoopId.id) contexts we join don't have non-fixed abstractions with the same ids. *) let refresh_abs (old_abs : AbstractionId.Set.t) (ctx : eval_ctx) : eval_ctx = - let ids, _ = compute_context_ids ctx in + let ids, _ = compute_ctx_ids ctx in let abs_to_refresh = AbstractionId.Set.diff ids.aids old_abs in let aids_subst = List.map diff --git a/compiler/InterpreterLoopsMatchCtxs.ml b/compiler/InterpreterLoopsMatchCtxs.ml index 90559c29..2a688fa7 100644 --- a/compiler/InterpreterLoopsMatchCtxs.ml +++ b/compiler/InterpreterLoopsMatchCtxs.ml @@ -658,7 +658,7 @@ module MakeJoinMatcher (S : MatchJoinState) : PrimMatcher = struct else ( (* The caller should have checked that the symbolic values don't contain borrows *) - assert (not (ty_has_borrows S.ctx.type_context.type_infos sv0.sv_ty)); + assert (not (ty_has_borrows S.ctx.type_ctx.type_infos sv0.sv_ty)); (* We simply introduce a fresh symbolic value *) mk_fresh_symbolic_value sv0.sv_ty) @@ -669,7 +669,7 @@ module MakeJoinMatcher (S : MatchJoinState) : PrimMatcher = struct - there are no borrows in the "regular" value If there are loans in the regular value, raise an exception. *) - assert (not (ty_has_borrows S.ctx.type_context.type_infos sv.sv_ty)); + assert (not (ty_has_borrows S.ctx.type_ctx.type_infos sv.sv_ty)); assert (not (value_has_borrows S.ctx v.value)); let value_is_left = not left in (match InterpreterBorrowsCore.get_first_loan_in_value v with diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml index 30b7b333..da617c64 100644 --- a/compiler/InterpreterStatements.ml +++ b/compiler/InterpreterStatements.ml @@ -747,7 +747,7 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) let tr_self = UnknownTrait __FUNCTION__ in let regions_hierarchy = LlbcAstUtils.FunIdMap.find (FRegular fid) - ctx.fun_context.regions_hierarchies + ctx.fun_ctx.regions_hierarchies in let inst_sg = instantiate_fun_sig ctx func.generics tr_self def.signature @@ -793,7 +793,7 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) let fid : fun_id = FRegular id in let regions_hierarchy = LlbcAstUtils.FunIdMap.find fid - ctx.fun_context.regions_hierarchies + ctx.fun_ctx.regions_hierarchies in let inst_sg = instantiate_fun_sig ctx generics tr_self @@ -853,7 +853,7 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) method_def.signature.parent_params_info)); let regions_hierarchy = LlbcAstUtils.FunIdMap.find (FRegular method_id) - ctx.fun_context.regions_hierarchies + ctx.fun_ctx.regions_hierarchies in let tr_self = TraitRef trait_ref in let inst_sg = @@ -884,7 +884,7 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) (* Instantiate *) let regions_hierarchy = LlbcAstUtils.FunIdMap.find (FRegular method_id) - ctx.fun_context.regions_hierarchies + ctx.fun_ctx.regions_hierarchies in let tr_self = TraitRef trait_ref in let inst_sg = @@ -1450,7 +1450,7 @@ and eval_assumed_function_call_symbolic (config : config) (fid : assumed_fun_id) * this is a current limitation of our synthesis *) assert ( List.for_all - (fun ty -> not (ty_has_borrows ctx.type_context.type_infos ty)) + (fun ty -> not (ty_has_borrows ctx.type_ctx.type_infos ty)) generics.types); (* There are two cases (and this is extremely annoying): @@ -1476,7 +1476,7 @@ and eval_assumed_function_call_symbolic (config : config) (fid : assumed_fun_id) | _ -> let regions_hierarchy = LlbcAstUtils.FunIdMap.find (FAssumed fid) - ctx.fun_context.regions_hierarchies + ctx.fun_ctx.regions_hierarchies in (* There shouldn't be any reference to Self *) let tr_self = UnknownTrait __FUNCTION__ in diff --git a/compiler/InterpreterUtils.ml b/compiler/InterpreterUtils.ml index e04a6b90..a1a06ee5 100644 --- a/compiler/InterpreterUtils.ml +++ b/compiler/InterpreterUtils.ml @@ -265,7 +265,7 @@ let value_has_ret_symbolic_value_with_borrow_under_mut (ctx : eval_ctx) inherit [_] iter_typed_value method! visit_symbolic_value _ s = - if ty_has_borrow_under_mut ctx.type_context.type_infos s.sv_ty then + if ty_has_borrow_under_mut ctx.type_ctx.type_infos s.sv_ty then raise Found else () end @@ -288,15 +288,15 @@ let rvalue_get_place (rv : rvalue) : place option = (** See {!ValuesUtils.symbolic_value_has_borrows} *) let symbolic_value_has_borrows (ctx : eval_ctx) (sv : symbolic_value) : bool = - ValuesUtils.symbolic_value_has_borrows ctx.type_context.type_infos sv + ValuesUtils.symbolic_value_has_borrows ctx.type_ctx.type_infos sv (** See {!ValuesUtils.value_has_borrows}. *) let value_has_borrows (ctx : eval_ctx) (v : value) : bool = - ValuesUtils.value_has_borrows ctx.type_context.type_infos v + ValuesUtils.value_has_borrows ctx.type_ctx.type_infos v (** See {!ValuesUtils.value_has_loans_or_borrows}. *) let value_has_loans_or_borrows (ctx : eval_ctx) (v : value) : bool = - ValuesUtils.value_has_loans_or_borrows ctx.type_context.type_infos v + ValuesUtils.value_has_loans_or_borrows ctx.type_ctx.type_infos v (** See {!ValuesUtils.value_has_loans}. *) let value_has_loans (v : value) : bool = ValuesUtils.value_has_loans v @@ -401,19 +401,19 @@ let compute_env_elem_ids (x : env_elem) : ids_sets * ids_to_values = compute_env_ids [ x ] (** Compute the sets of ids found in a list of contexts. *) -let compute_contexts_ids (ctxl : eval_ctx list) : ids_sets * ids_to_values = +let compute_ctxs_ids (ctxl : eval_ctx list) : ids_sets * ids_to_values = let compute, get_ids, get_ids_to_values = compute_ids () in List.iter (compute#visit_eval_ctx ()) ctxl; (get_ids (), get_ids_to_values ()) (** Compute the sets of ids found in a context. *) -let compute_context_ids (ctx : eval_ctx) : ids_sets * ids_to_values = - compute_contexts_ids [ ctx ] +let compute_ctx_ids (ctx : eval_ctx) : ids_sets * ids_to_values = + compute_ctxs_ids [ ctx ] (** **WARNING**: this function doesn't compute the normalized types (for the trait type aliases). This should be computed afterwards. *) -let initialize_eval_context (ctx : decls_ctx) +let initialize_eval_ctx (ctx : decls_ctx) (region_groups : RegionGroupId.id list) (type_vars : type_var list) (const_generic_vars : const_generic_var list) : eval_ctx = reset_global_counters (); @@ -427,11 +427,11 @@ let initialize_eval_context (ctx : decls_ctx) const_generic_vars) in { - type_context = ctx.type_ctx; - fun_context = ctx.fun_ctx; - global_context = ctx.global_ctx; - trait_decls_context = ctx.trait_decls_ctx; - trait_impls_context = ctx.trait_impls_ctx; + type_ctx = ctx.type_ctx; + fun_ctx = ctx.fun_ctx; + global_ctx = ctx.global_ctx; + trait_decls_ctx = ctx.trait_decls_ctx; + trait_impls_ctx = ctx.trait_impls_ctx; region_groups; type_vars; const_generic_vars; diff --git a/compiler/Invariants.ml b/compiler/Invariants.ml index fa0d7436..b87cdff7 100644 --- a/compiler/Invariants.ml +++ b/compiler/Invariants.ml @@ -768,7 +768,7 @@ let check_symbolic_values (ctx : eval_ctx) : unit = assert (info.env_count = 0 || info.aproj_borrows = []); (* A symbolic value containing borrows can't be duplicated (i.e., copied): * it must be expanded first *) - if ty_has_borrows ctx.type_context.type_infos info.ty then + if ty_has_borrows ctx.type_ctx.type_infos info.ty then assert (info.env_count <= 1); (* A duplicated symbolic value is necessarily primitively copyable *) assert (info.env_count <= 1 || ty_is_primitively_copyable info.ty); diff --git a/compiler/Print.ml b/compiler/Print.ml index 0e2ec1fc..8999c77d 100644 --- a/compiler/Print.ml +++ b/compiler/Print.ml @@ -409,11 +409,11 @@ module Contexts = struct } let eval_ctx_to_fmt_env (ctx : eval_ctx) : fmt_env = - let type_decls = ctx.type_context.type_decls in - let fun_decls = ctx.fun_context.fun_decls in - let global_decls = ctx.global_context.global_decls in - let trait_decls = ctx.trait_decls_context.trait_decls in - let trait_impls = ctx.trait_impls_context.trait_impls in + let type_decls = ctx.type_ctx.type_decls in + let fun_decls = ctx.fun_ctx.fun_decls in + let global_decls = ctx.global_ctx.global_decls in + let trait_decls = ctx.trait_decls_ctx.trait_decls in + let trait_impls = ctx.trait_impls_ctx.trait_impls in (* Below: it is always safe to omit fields - if an id can't be found at printing time, we print the id (in raw form) instead of the name it designates. *) diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 204fc399..d8213317 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -15,7 +15,7 @@ module PP = PrintPure (** The local logger *) let log = Logging.symbolic_to_pure_log -type type_context = { +type type_ctx = { llbc_type_decls : T.type_decl TypeDeclId.Map.t; type_decls : type_decl TypeDeclId.Map.t; (** We use this for type-checking (for sanity checks) when translating @@ -43,18 +43,18 @@ type fun_sig_named_outputs = { } [@@deriving show] -type fun_context = { +type fun_ctx = { llbc_fun_decls : A.fun_decl A.FunDeclId.Map.t; fun_infos : fun_info A.FunDeclId.Map.t; regions_hierarchies : T.region_var_groups FunIdMap.t; } [@@deriving show] -type global_context = { llbc_global_decls : A.global_decl A.GlobalDeclId.Map.t } +type global_ctx = { llbc_global_decls : A.global_decl A.GlobalDeclId.Map.t } [@@deriving show] -type trait_decls_context = A.trait_decl A.TraitDeclId.Map.t [@@deriving show] -type trait_impls_context = A.trait_impl A.TraitImplId.Map.t [@@deriving show] +type trait_decls_ctx = A.trait_decl A.TraitDeclId.Map.t [@@deriving show] +type trait_impls_ctx = A.trait_impl A.TraitImplId.Map.t [@@deriving show] (** Whenever we translate a function call or an ended abstraction, we store the related information (this is useful when translating ended @@ -120,11 +120,11 @@ type loop_info = { (** Body synthesis context *) type bs_ctx = { - type_context : type_context; (* TODO: rename *) - fun_context : fun_context; (* TODO: rename *) - global_context : global_context; (* TODO: rename *) - trait_decls_ctx : trait_decls_context; - trait_impls_ctx : trait_impls_context; + type_ctx : type_ctx; + fun_ctx : fun_ctx; + global_ctx : global_ctx; + trait_decls_ctx : trait_decls_ctx; + trait_impls_ctx : trait_impls_ctx; fun_decl : A.fun_decl; bid : RegionGroupId.id option; (** TODO: rename @@ -234,9 +234,9 @@ type bs_ctx = { (* TODO: move *) let bs_ctx_to_fmt_env (ctx : bs_ctx) : Print.fmt_env = - let type_decls = ctx.type_context.llbc_type_decls in - let fun_decls = ctx.fun_context.llbc_fun_decls in - let global_decls = ctx.global_context.llbc_global_decls in + let type_decls = ctx.type_ctx.llbc_type_decls in + let fun_decls = ctx.fun_ctx.llbc_fun_decls in + let global_decls = ctx.global_ctx.llbc_global_decls in let trait_decls = ctx.trait_decls_ctx in let trait_impls = ctx.trait_impls_ctx in let { regions; types; const_generics; trait_clauses } : T.generic_params = @@ -258,9 +258,9 @@ let bs_ctx_to_fmt_env (ctx : bs_ctx) : Print.fmt_env = } let bs_ctx_to_pure_fmt_env (ctx : bs_ctx) : PrintPure.fmt_env = - let type_decls = ctx.type_context.llbc_type_decls in - let fun_decls = ctx.fun_context.llbc_fun_decls in - let global_decls = ctx.global_context.llbc_global_decls in + let type_decls = ctx.type_ctx.llbc_type_decls in + let fun_decls = ctx.fun_ctx.llbc_fun_decls in + let global_decls = ctx.global_ctx.llbc_global_decls in let trait_decls = ctx.trait_decls_ctx in let trait_impls = ctx.trait_impls_ctx in let generics = ctx.sg.generics in @@ -346,11 +346,11 @@ let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string = let bs_ctx_lookup_llbc_type_decl (id : TypeDeclId.id) (ctx : bs_ctx) : T.type_decl = - TypeDeclId.Map.find id ctx.type_context.llbc_type_decls + TypeDeclId.Map.find id ctx.type_ctx.llbc_type_decls let bs_ctx_lookup_llbc_fun_decl (id : A.FunDeclId.id) (ctx : bs_ctx) : A.fun_decl = - A.FunDeclId.Map.find id ctx.fun_context.llbc_fun_decls + A.FunDeclId.Map.find id ctx.fun_ctx.llbc_fun_decls (* Some generic translation functions (we need to translate different "flavours" of types: forward types, backward types, etc.) *) @@ -617,13 +617,13 @@ and translate_fwd_trait_instance_id (type_infos : type_infos) (** Simply calls [translate_fwd_ty] *) let ctx_translate_fwd_ty (ctx : bs_ctx) (ty : T.ty) : ty = - let type_infos = ctx.type_context.type_infos in + let type_infos = ctx.type_ctx.type_infos in translate_fwd_ty type_infos ty (** Simply calls [translate_fwd_generic_args] *) let ctx_translate_fwd_generic_args (ctx : bs_ctx) (generics : T.generic_args) : generic_args = - let type_infos = ctx.type_context.type_infos in + let type_infos = ctx.type_ctx.type_infos in translate_fwd_generic_args type_infos generics (** Translate a type, when some regions may have ended. @@ -708,7 +708,7 @@ let rec translate_back_ty (type_infos : type_infos) (** Simply calls [translate_back_ty] *) let ctx_translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool) (inside_mut : bool) (ty : T.ty) : ty option = - let type_infos = ctx.type_context.type_infos in + let type_infos = ctx.type_ctx.type_infos in translate_back_ty type_infos keep_region inside_mut ty let mk_type_check_ctx (ctx : bs_ctx) : PureTypeCheck.tc_ctx = @@ -721,8 +721,8 @@ let mk_type_check_ctx (ctx : bs_ctx) : PureTypeCheck.tc_ctx = in let env = VarId.Map.empty in { - PureTypeCheck.type_decls = ctx.type_context.type_decls; - global_decls = ctx.global_context.llbc_global_decls; + PureTypeCheck.type_decls = ctx.type_ctx.type_decls; + global_decls = ctx.global_ctx.llbc_global_decls; env; const_generics; } @@ -742,7 +742,7 @@ let translate_fun_id_or_trait_method_ref (ctx : bs_ctx) match id with | FunId fun_id -> FunId fun_id | TraitMethod (trait_ref, method_name, fun_decl_id) -> - let type_infos = ctx.type_context.type_infos in + let type_infos = ctx.type_ctx.type_infos in let trait_ref = translate_fwd_trait_ref type_infos trait_ref in TraitMethod (trait_ref, method_name, fun_decl_id) @@ -894,7 +894,7 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) List.map (fun (g : T.region_var_group) -> g.id) regions_hierarchy in let ctx = - InterpreterUtils.initialize_eval_context decls_ctx region_groups + InterpreterUtils.initialize_eval_ctx decls_ctx region_groups sg.generics.types sg.generics.const_generics in (* Compute the normalization map for the *sty* types and add it to the context *) @@ -1786,7 +1786,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (* Retrieve the effect information about this function (can fail, * takes a state as input, etc.) *) let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos fid None None + get_fun_effect_info ctx.fun_ctx.fun_infos fid None None in (* Depending on the function effects: * - add the fuel @@ -2006,7 +2006,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) raise (Failure "Unreachable") in let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos fun_id None (Some rg_id) + get_fun_effect_info ctx.fun_ctx.fun_infos fun_id None (Some rg_id) in let generics = ctx_translate_fwd_generic_args ctx call.generics in (* Retrieve the original call and the parent abstractions *) @@ -2194,8 +2194,8 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) | V.LoopCall -> let fun_id = E.FRegular ctx.fun_decl.def_id in let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos (FunId fun_id) - (Some vloop_id) (Some rg_id) + get_fun_effect_info ctx.fun_ctx.fun_infos (FunId fun_id) (Some vloop_id) + (Some rg_id) in let loop_info = LoopId.Map.find loop_id ctx.loops in let generics = loop_info.generics in @@ -2306,7 +2306,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value) (e : S.expression) (ctx : bs_ctx) : texpression = let ctx, var = fresh_var_for_symbolic_value sval ctx in - let decl = A.GlobalDeclId.Map.find gid ctx.global_context.llbc_global_decls in + let decl = A.GlobalDeclId.Map.find gid ctx.global_ctx.llbc_global_decls in let global_expr = { id = Global gid; generics = empty_generic_args } in (* We use translate_fwd_ty to translate the global type *) let ty = ctx_translate_fwd_ty ctx decl.ty in @@ -2482,7 +2482,7 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value) - if we forbid using field projectors. *) let is_rec_def = - T.TypeDeclId.Set.mem adt_id ctx.type_context.recursive_decls + T.TypeDeclId.Set.mem adt_id ctx.type_ctx.recursive_decls in let use_let_with_cons = is_enum @@ -2495,7 +2495,7 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value) like Coq don't, in which case we have to deconstruct the whole ADT at once (`let (a, b, c) = x in`) *) || TypesUtils.type_decl_from_type_id_is_tuple_struct - ctx.type_context.type_infos type_id + ctx.type_ctx.type_infos type_id && not (Config.backend_has_tuple_projectors ()) in if use_let_with_cons then @@ -2588,7 +2588,7 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option) { e = StructUpdate su; ty = var.ty } | VaCgValue cg_id -> { e = CVar cg_id; ty = var.ty } | VaTraitConstValue (trait_ref, generics, const_name) -> - let type_infos = ctx.type_context.type_infos in + let type_infos = ctx.type_ctx.type_infos in let trait_ref = translate_fwd_trait_ref type_infos trait_ref in let generics = translate_fwd_generic_args type_infos generics in let qualif_id = TraitConst (trait_ref, generics, const_name) in @@ -2722,7 +2722,7 @@ and translate_forward_end (ectx : C.eval_ctx) let sg = ctx.fun_decl.signature in let regions_hierarchy = LlbcAstUtils.FunIdMap.find (FRegular def_id) - ctx.fun_context.regions_hierarchies + ctx.fun_ctx.regions_hierarchies in List.map (fun (gid, _) -> @@ -2816,7 +2816,7 @@ and translate_forward_end (ectx : C.eval_ctx) (* Lookup the effect info for the loop function *) let fid = E.FRegular ctx.fun_decl.def_id in let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos (FunId fid) None ctx.bid + get_fun_effect_info ctx.fun_ctx.fun_infos (FunId fid) None ctx.bid in (* Introduce a fresh output value for the forward function *) @@ -2949,7 +2949,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = List.map (fun ty -> assert ( - not (TypesUtils.ty_has_borrows !ctx.type_context.type_infos ty)); + not (TypesUtils.ty_has_borrows !ctx.type_ctx.type_infos ty)); (None, ctx_translate_fwd_ty !ctx ty)) tys in @@ -3173,7 +3173,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = (* Translate the signature *) let signature = translate_fun_sig_from_decomposed ctx.sg ctx.bid in let regions_hierarchy = - FunIdMap.find (FRegular def_id) ctx.fun_context.regions_hierarchies + FunIdMap.find (FRegular def_id) ctx.fun_ctx.regions_hierarchies in (* Translate the body, if there is *) let body = @@ -3181,8 +3181,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = | None -> None | Some body -> let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos - (FunId (FRegular def_id)) None bid + get_fun_effect_info ctx.fun_ctx.fun_infos (FunId (FRegular def_id)) + None bid in let body = translate_expression body ctx in (* Add a match over the fuel, if necessary *) diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 8b221c93..e153f4f4 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -72,7 +72,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) | RecGroup _ -> Some tid) (TypeDeclId.Map.bindings trans_ctx.type_ctx.type_decls_groups)) in - let type_context = + let type_ctx = { SymbolicToPure.type_infos = trans_ctx.type_ctx.type_infos; llbc_type_decls = trans_ctx.type_ctx.type_decls; @@ -80,14 +80,14 @@ let translate_function_to_pure (trans_ctx : trans_ctx) recursive_decls = recursive_type_decls; } in - let fun_context = + let fun_ctx = { SymbolicToPure.llbc_fun_decls = trans_ctx.fun_ctx.fun_decls; fun_infos = trans_ctx.fun_ctx.fun_infos; regions_hierarchies = trans_ctx.fun_ctx.regions_hierarchies; } in - let global_context = + let global_ctx = { SymbolicToPure.llbc_global_decls = trans_ctx.global_ctx.global_decls } in @@ -134,7 +134,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) in let regions_hierarchy = - LlbcAstUtils.FunIdMap.find (FRegular def_id) fun_context.regions_hierarchies + LlbcAstUtils.FunIdMap.find (FRegular def_id) fun_ctx.regions_hierarchies in let var_counter, back_state_vars = @@ -160,9 +160,9 @@ let translate_function_to_pure (trans_ctx : trans_ctx) back_state_vars; fuel0; fuel; - type_context; - fun_context; - global_context; + type_ctx; + fun_ctx; + global_ctx; trait_decls_ctx = trans_ctx.trait_decls_ctx.trait_decls; trait_impls_ctx = trans_ctx.trait_impls_ctx.trait_impls; fun_decl = fdef; -- cgit v1.2.3 From 17973e99e4784ff5e31565622d183ad89e3d9cd7 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 18 Dec 2023 11:40:44 +0100 Subject: Add some comments --- compiler/SymbolicToPure.ml | 47 ++++++++++++++++++++++++++++++---------------- 1 file changed, 31 insertions(+), 16 deletions(-) (limited to 'compiler') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index d8213317..a79340b6 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -151,8 +151,8 @@ type bs_ctx = { state_var : VarId.id; (** The current state variable, in case the function is stateful *) back_state_vars : VarId.id RegionGroupId.Map.t; - (** The additional input state variable received by a stateful backward function, - **in case we are splitting the forward/backward functions**. + (** The additional input state variable received by a stateful backward + function, **in case we are splitting the forward/backward functions**. When generating stateful functions, we generate code of the following form: @@ -195,7 +195,22 @@ type bs_ctx = { *) backward_outputs : var list RegionGroupId.Map.t; (** The variables that the backward functions will output, corresponding - to the borrows they give back (don't include the backward state) + to the borrows they give back (don't include the backward state). + + The translation is done as follows: + - for a given backward function, we choose a set of variables [v_i] + - when we detect the ended input abstraction which corresponds + to the backward function of the LLBC function we are translating, + and which consumed the values [consumed_i] (that we need to give + back to the caller), we introduce: + {[ + let v_i = consumed_i in + ... + ]} + Then, upon reaching the [Return] node, we introduce: + {[ + (v_i) + ]} *) loop_backward_outputs : var list RegionGroupId.Map.t option; (** Same as {!backward_outputs}, but for loops (if we entered a loop). @@ -1930,19 +1945,19 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs) assert (rg_id = bid); (* The translation is done as follows: - * - for a given backward function, we choose a set of variables [v_i] - * - when we detect the ended input abstraction which corresponds - * to the backward function, and which consumed the values [consumed_i], - * we introduce: - * {[ - * let v_i = consumed_i in - * ... - * ]} - * Then, when we reach the [Return] node, we introduce: - * {[ - * (v_i) - * ]} - * *) + - for a given backward function, we choose a set of variables [v_i] + - when we detect the ended input abstraction which corresponds + to the backward function, and which consumed the values [consumed_i], + we introduce: + {[ + let v_i = consumed_i in + ... + ]} + Then, when we reach the [Return] node, we introduce: + {[ + (v_i) + ]} + *) (* First, get the given back variables. We don't use the same given back variables if we translate a loop or -- cgit v1.2.3 From 999f48d032107722aa6ca714da828ab2788ca412 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 18 Dec 2023 12:06:07 +0100 Subject: Fix a minor mistake in SymbolicToPure --- compiler/PrintPure.ml | 12 +++++++----- compiler/PureMicroPasses.ml | 9 +++++++-- compiler/SymbolicToPure.ml | 5 +---- 3 files changed, 15 insertions(+), 11 deletions(-) (limited to 'compiler') diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 79506c04..1ce146a4 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -103,10 +103,13 @@ let adt_field_names (env : fmt_env) = Print.Types.adt_field_names (fmt_env_to_llbc_fmt_env env) let option_to_string = Print.option_to_string -let type_var_to_string = Print.Types.type_var_to_string -let const_generic_var_to_string = Print.Types.const_generic_var_to_string -let integer_type_to_string = Print.Values.integer_type_to_string let literal_type_to_string = Print.Values.literal_type_to_string +let type_var_to_string (v : type_var) = "(" ^ v.name ^ ": Type)" + +let const_generic_var_to_string (v : const_generic_var) = + "(" ^ v.name ^ " : " ^ literal_type_to_string v.ty ^ ")" + +let integer_type_to_string = Print.Values.integer_type_to_string let scalar_value_to_string = Print.Values.scalar_value_to_string let literal_to_string = Print.Values.literal_to_string @@ -203,13 +206,12 @@ and trait_instance_id_to_string (env : fmt_env) (inside : bool) | UnknownTrait msg -> "UNKNOWN(" ^ msg ^ ")" let trait_clause_to_string (env : fmt_env) (clause : trait_clause) : string = - let clause_id = trait_clause_id_to_string env clause.clause_id in let trait_id = trait_decl_id_to_string env clause.trait_id in let generics = generic_args_to_strings env true clause.generics in let generics = if generics = [] then "" else " " ^ String.concat " " generics in - "[" ^ clause_id ^ "]: " ^ trait_id ^ generics + trait_id ^ generics let generic_params_to_strings (env : fmt_env) (generics : generic_params) : string list = diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index a7c2f154..34597d32 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -11,6 +11,10 @@ let fun_decl_to_string (ctx : trans_ctx) (def : Pure.fun_decl) : string = let fmt = trans_ctx_to_pure_fmt_env ctx in PrintPure.fun_decl_to_string fmt def +let fun_sig_to_string (ctx : trans_ctx) (sg : Pure.fun_sig) : string = + let fmt = trans_ctx_to_pure_fmt_env ctx in + PrintPure.fun_sig_to_string fmt sg + (** Small utility. We sometimes have to insert new fresh variables in a function body, in which @@ -1303,7 +1307,8 @@ let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option = those function bodies into independent definitions while removing occurrences of the {!Pure.Loop} node. *) -let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = +let decompose_loops (_ctx : trans_ctx) (def : fun_decl) : + fun_decl * fun_decl list = match def.body with | None -> (def, []) | Some body -> @@ -1982,7 +1987,7 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : (lazy ("not filtered (not backward with no outputs): " ^ name ^ "\n")); (* Extract the loop definitions by removing the {!Loop} node *) - let def, loops = decompose_loops def in + let def, loops = decompose_loops ctx def in (* Apply the remaining passes *) let f = apply_end_passes_to_def ctx def in diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index a79340b6..7359f68a 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -950,7 +950,7 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) let fwd_info = (* *) let has_fuel = fwd_fuel <> [] in - let num_inputs_no_fuel_no_state = List.length fwd_inputs in + let num_inputs_no_fuel_no_state = List.length fwd_inputs_no_fuel_no_state in let num_inputs_with_fuel_no_state = (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *) List.length fwd_fuel + num_inputs_no_fuel_no_state @@ -2620,9 +2620,6 @@ and translate_forward_end (ectx : C.eval_ctx) (loop_input_values : V.typed_value S.symbolic_value_id_map option) (fwd_e : S.expression) (back_e : S.expression S.region_group_id_map) (ctx : bs_ctx) : texpression = - (* TODO: *) - assert (not !Config.return_back_funs); - let translate_one_end ctx (bid : RegionGroupId.id option) = let ctx = { ctx with bid } in (* Update the current state with the additional state received by the backward -- cgit v1.2.3 From 116b569d1b08a57c3ad66071979a1c966fdad3a2 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 18 Dec 2023 12:18:06 +0100 Subject: Remove the backwards field from SymbolicToPure.call_info --- compiler/SymbolicToPure.ml | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) (limited to 'compiler') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 7359f68a..ea2082c7 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -67,12 +67,6 @@ type call_info = { Those inputs include the fuel and the state, if pertinent. *) - backwards : (V.abs * texpression list) T.RegionGroupId.Map.t; - (** A map from region group id (i.e., backward function id) to - pairs (abstraction, additional arguments received by the backward function) - - TODO: remove? it is also in the bs_ctx ("abstractions" field) - *) } [@@deriving show] @@ -224,7 +218,10 @@ type bs_ctx = { calls : call_info V.FunCallId.Map.t; (** The function calls we encountered so far *) abstractions : (V.abs * texpression list) V.AbstractionId.Map.t; - (** The ended abstractions we encountered so far, with their additional input arguments *) + (** The ended abstractions we encountered so far, with their additional + input arguments. We store it here and not in {!call_info} because + we need a map from abstraction id to abstraction (and not + from call id + region group id to abstraction). *) loop_ids_map : LoopId.id V.LoopId.Map.t; (** Ids to use for the loops *) loops : loop_info LoopId.Map.t; (** The loops we encountered so far. @@ -765,9 +762,7 @@ let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) (args : texpression list) (ctx : bs_ctx) : bs_ctx = let calls = ctx.calls in assert (not (V.FunCallId.Map.mem call_id calls)); - let info = - { forward; forward_inputs = args; backwards = T.RegionGroupId.Map.empty } - in + let info = { forward; forward_inputs = args } in let calls = V.FunCallId.Map.add call_id info calls in { ctx with calls } @@ -777,11 +772,6 @@ let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) : bs_ctx * fun_or_op_id = (* Insert the abstraction in the call informations *) let info = V.FunCallId.Map.find call_id ctx.calls in - assert (not (T.RegionGroupId.Map.mem back_id info.backwards)); - let backwards = - T.RegionGroupId.Map.add back_id (abs, back_args) info.backwards - in - let info = { info with backwards } in let calls = V.FunCallId.Map.add call_id info ctx.calls in (* Insert the abstraction in the abstractions map *) let abstractions = ctx.abstractions in -- cgit v1.2.3 From 4f7bc41dcbc6187512111a81f968726452024d25 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 19 Dec 2023 12:54:40 +0100 Subject: Simplify SymbolicToPure.bs_ctx.{backward_outputs, loop_backward_outputs} --- compiler/SymbolicToPure.ml | 153 ++++++++++++++++++++------------------------- compiler/Translate.ml | 17 +---- 2 files changed, 70 insertions(+), 100 deletions(-) (limited to 'compiler') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index ea2082c7..93e6cb4e 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -109,6 +109,10 @@ type loop_info = { (** The forward inputs are initialized at [None] *) forward_output_no_state_no_result : var option; (** The forward outputs are initialized at [None] *) + back_outputs : ty list RegionGroupId.Map.t; + (** The map from region group ids to the types of the values given back + by the corresponding loop abstractions. + *) } [@@deriving show] @@ -187,12 +191,11 @@ type bs_ctx = { Same remarks as for {!backward_inputs_no_state}. *) - backward_outputs : var list RegionGroupId.Map.t; + backward_outputs : var list option; (** The variables that the backward functions will output, corresponding to the borrows they give back (don't include the backward state). The translation is done as follows: - - for a given backward function, we choose a set of variables [v_i] - when we detect the ended input abstraction which corresponds to the backward function of the LLBC function we are translating, and which consumed the values [consumed_i] (that we need to give @@ -201,14 +204,20 @@ type bs_ctx = { let v_i = consumed_i in ... ]} - Then, upon reaching the [Return] node, we introduce: + where the [v_i] are fresh, and are stored in the [backward_output]. + - Then, upon reaching the [Return] node, we introduce: {[ - (v_i) + return (v_i) ]} + + The option is [None] before we detect the ended input abstraction, + and [Some] afterwards. *) - loop_backward_outputs : var list RegionGroupId.Map.t option; + loop_backward_outputs : var list option; (** Same as {!backward_outputs}, but for loops (if we entered a loop). + TODO: merge with [backward_outputs]? + [None] if we are not inside a loop, [Some] otherwise (and whatever the kind of function we are translating: it will be [Some] even though we are synthesizing a forward function). @@ -1607,7 +1616,9 @@ let mk_emeta_symbolic_assignments (vars : var list) (values : texpression list) let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression = match e with - | S.Return (ectx, opt_v) -> translate_return ectx opt_v ctx + | S.Return (ectx, opt_v) -> + (* Remark: we can't get there if we are inside a loop *) + translate_return ectx opt_v ctx | ReturnWithLoop (loop_id, is_continue) -> translate_return_with_loop loop_id is_continue ctx | Panic -> translate_panic ctx @@ -1644,10 +1655,9 @@ and translate_panic (ctx : bs_ctx) : texpression = if ctx.inside_loop && Option.is_some ctx.bid then (* We are synthesizing the backward function of a loop body *) let bid = Option.get ctx.bid in - let back_vars = - T.RegionGroupId.Map.find bid (Option.get ctx.loop_backward_outputs) - in - let tys = List.map (fun (v : var) -> v.ty) back_vars in + let loop_id = Option.get ctx.loop_id in + let loop = LoopId.Map.find loop_id ctx.loops in + let tys = RegionGroupId.Map.find bid loop.back_outputs in let output = mk_simpl_tuple_ty tys in mk_output output else @@ -1667,7 +1677,11 @@ and translate_panic (ctx : bs_ctx) : texpression = in mk_output output -(** [opt_v]: the value to return, in case we translate a forward body *) +(** [opt_v]: the value to return, in case we translate a forward body. + + Remark: for now, we can't get there if we are inside a loop. + If inside a loop, we use {!translate_return_with_loop}. + *) and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = (* There are two cases: @@ -1676,22 +1690,20 @@ and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) - or we are translating a backward function, in which case it should be [None] *) (* Compute the values that we should return *without the state and the result - * wrapper* *) + wrapper* *) let output = match ctx.bid with | None -> (* Forward function *) let v = Option.get opt_v in typed_value_to_texpression ctx ectx v - | Some bid -> + | Some _ -> (* Backward function *) (* Sanity check *) assert (opt_v = None); (* Group the variables in which we stored the values we need to give back. - * See the explanations for the [SynthInput] case in [translate_end_abstraction] *) - let backward_outputs = - T.RegionGroupId.Map.find bid ctx.backward_outputs - in + See the explanations for the [SynthInput] case in [translate_end_abstraction] *) + let backward_outputs = Option.get ctx.backward_outputs in let field_values = List.map mk_texpression_from_var backward_outputs in mk_simpl_tuple_texpression field_values in @@ -1728,19 +1740,16 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) (* Forward *) mk_texpression_from_var (Option.get loop_info.forward_output_no_state_no_result) - | Some bid -> + | Some _ -> (* Backward *) (* Group the variables in which we stored the values we need to give back. * See the explanations for the [SynthInput] case in [translate_end_abstraction] *) let backward_outputs = - let map = - if ctx.inside_loop then - (* We are synthesizing a loop body *) - Option.get ctx.loop_backward_outputs - else (* Regular function *) - ctx.backward_outputs - in - T.RegionGroupId.Map.find bid map + if ctx.inside_loop then + (* We are synthesizing a loop body *) + Option.get ctx.loop_backward_outputs + else (* Regular function *) + Option.get ctx.backward_outputs in let field_values = List.map mk_texpression_from_var backward_outputs in mk_simpl_tuple_texpression field_values @@ -1923,45 +1932,38 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs) ^ abs_to_string ctx abs ^ "\n")); (* When we end an input abstraction, this input abstraction gets back - * the borrows which it introduced in the context through the input - * values: by listing those values, we get the values which are given - * back by one of the backward functions we are synthesizing. *) - (* Note that we don't support nested borrows for now: if we find - * an ended synthesized input abstraction, it must be the one corresponding - * to the backward function wer are synthesizing, it can't be the one - * for a parent backward function. - *) + the borrows which it introduced in the context through the input + values: by listing those values, we get the values which are given + back by one of the backward functions we are synthesizing. + + Note that we don't support nested borrows for now: if we find + an ended synthesized input abstraction, it must be the one corresponding + to the backward function wer are synthesizing, it can't be the one + for a parent backward function. + *) let bid = Option.get ctx.bid in assert (rg_id = bid); - (* The translation is done as follows: - - for a given backward function, we choose a set of variables [v_i] - - when we detect the ended input abstraction which corresponds - to the backward function, and which consumed the values [consumed_i], - we introduce: - {[ - let v_i = consumed_i in - ... - ]} - Then, when we reach the [Return] node, we introduce: - {[ - (v_i) - ]} - *) - (* First, get the given back variables. + (* First, introduce the given back variables. We don't use the same given back variables if we translate a loop or the standard body of a function. *) - let given_back_variables = - let map = - if ctx.inside_loop then - (* We are synthesizing a loop body *) - Option.get ctx.loop_backward_outputs - else (* Regular function body *) - ctx.backward_outputs - in - T.RegionGroupId.Map.find bid map + let ctx, given_back_variables = + if ctx.inside_loop then + (* We are synthesizing a loop body *) + let loop_id = Option.get ctx.loop_id in + let loop = LoopId.Map.find loop_id ctx.loops in + let tys = RegionGroupId.Map.find bid loop.back_outputs in + let vars = List.map (fun ty -> (None, ty)) tys in + let ctx, vars = fresh_vars vars ctx in + ({ ctx with loop_backward_outputs = Some vars }, vars) + else + (* Regular function body *) + let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in + let vars = List.combine back_sg.output_names back_sg.outputs in + let ctx, vars = fresh_vars vars ctx in + ({ ctx with backward_outputs = Some vars }, vars) in (* Get the list of values consumed by the abstraction upon ending *) @@ -2943,22 +2945,15 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = (* Compute the backward outputs *) let ctx = ref ctx in - let loop_backward_outputs = + let rg_to_given_back_tys = T.RegionGroupId.Map.map (fun (_, tys) -> (* The types shouldn't contain borrows - we can translate them as forward types *) - let vars = - List.map - (fun ty -> - assert ( - not (TypesUtils.ty_has_borrows !ctx.type_ctx.type_infos ty)); - (None, ctx_translate_fwd_ty !ctx ty)) - tys - in - (* Introduce fresh variables *) - let ctx', vars = fresh_vars vars !ctx in - ctx := ctx'; - vars) + List.map + (fun ty -> + assert (not (TypesUtils.ty_has_borrows !ctx.type_ctx.type_infos ty)); + ctx_translate_fwd_ty !ctx ty) + tys) loop.rg_to_given_back_tys in let ctx = !ctx in @@ -2966,12 +2961,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = let back_output_tys = match ctx.bid with | None -> None - | Some rg_id -> - let back_outputs = - T.RegionGroupId.Map.find rg_id loop_backward_outputs - in - let back_output_tys = List.map (fun (v : var) -> v.ty) back_outputs in - Some back_output_tys + | Some rg_id -> Some (T.RegionGroupId.Map.find rg_id rg_to_given_back_tys) in (* Add the loop information in the context *) @@ -3013,6 +3003,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = generics; forward_inputs = None; forward_output_no_state_no_result = None; + back_outputs = rg_to_given_back_tys; } in let loops = LoopId.Map.add loop_id loop_info ctx.loops in @@ -3020,13 +3011,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = in (* Update the context to translate the function end *) - let ctx_end = - { - ctx with - loop_id = Some loop_id; - loop_backward_outputs = Some loop_backward_outputs; - } - in + let ctx_end = { ctx with loop_id = Some loop_id } in let fun_end = translate_expression loop.end_expr ctx_end in (* Update the context for the loop body *) diff --git a/compiler/Translate.ml b/compiler/Translate.ml index e153f4f4..0fa0202b 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -171,8 +171,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) backward_inputs_no_state = RegionGroupId.Map.empty; (* Initialized just below *) backward_inputs_with_state = RegionGroupId.Map.empty; - (* Initialized just below *) - backward_outputs = RegionGroupId.Map.empty; + backward_outputs = None; loop_backward_outputs = None; (* Empty for now *) calls; @@ -234,20 +233,6 @@ let translate_function_to_pure (trans_ctx : trans_ctx) in let ctx = { ctx with backward_inputs_no_state; backward_inputs_with_state } in - (* Add the backward outputs *) - let ctx, backward_outputs = - List.fold_left_map - (fun ctx (region_vars : region_var_group) -> - let gid = region_vars.id in - let back_sg = RegionGroupId.Map.find gid sg.back_sg in - let outputs = List.combine back_sg.output_names back_sg.outputs in - let ctx, vars = SymbolicToPure.fresh_vars outputs ctx in - (ctx, (gid, vars))) - ctx regions_hierarchy - in - let backward_outputs = RegionGroupId.Map.of_list backward_outputs in - let ctx = { ctx with backward_outputs } in - (* Translate the forward function *) let pure_forward = match symbolic_trans with -- cgit v1.2.3 From 014c0668abf0834342b2b7076cf2f0634460e519 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 19 Dec 2023 13:24:53 +0100 Subject: Remove SymbolicToPure.bs_ctx.loop_backward_outputs --- compiler/SymbolicToPure.ml | 47 +++++++++++++++------------------------------- compiler/Translate.ml | 1 - 2 files changed, 15 insertions(+), 33 deletions(-) (limited to 'compiler') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 93e6cb4e..e2787271 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -213,17 +213,6 @@ type bs_ctx = { The option is [None] before we detect the ended input abstraction, and [Some] afterwards. *) - loop_backward_outputs : var list option; - (** Same as {!backward_outputs}, but for loops (if we entered a loop). - - TODO: merge with [backward_outputs]? - - [None] if we are not inside a loop, [Some] otherwise (and whatever - the kind of function we are translating: it will be [Some] even - though we are synthesizing a forward function). - - TODO: move to {!loop_info} - *) calls : call_info V.FunCallId.Map.t; (** The function calls we encountered so far *) abstractions : (V.abs * texpression list) V.AbstractionId.Map.t; @@ -1744,13 +1733,7 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) (* Backward *) (* Group the variables in which we stored the values we need to give back. * See the explanations for the [SynthInput] case in [translate_end_abstraction] *) - let backward_outputs = - if ctx.inside_loop then - (* We are synthesizing a loop body *) - Option.get ctx.loop_backward_outputs - else (* Regular function *) - Option.get ctx.backward_outputs - in + let backward_outputs = Option.get ctx.backward_outputs in let field_values = List.map mk_texpression_from_var backward_outputs in mk_simpl_tuple_texpression field_values in @@ -1950,20 +1933,20 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs) the standard body of a function. *) let ctx, given_back_variables = - if ctx.inside_loop then - (* We are synthesizing a loop body *) - let loop_id = Option.get ctx.loop_id in - let loop = LoopId.Map.find loop_id ctx.loops in - let tys = RegionGroupId.Map.find bid loop.back_outputs in - let vars = List.map (fun ty -> (None, ty)) tys in - let ctx, vars = fresh_vars vars ctx in - ({ ctx with loop_backward_outputs = Some vars }, vars) - else - (* Regular function body *) - let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in - let vars = List.combine back_sg.output_names back_sg.outputs in - let ctx, vars = fresh_vars vars ctx in - ({ ctx with backward_outputs = Some vars }, vars) + let vars = + if ctx.inside_loop then + (* We are synthesizing a loop body *) + let loop_id = Option.get ctx.loop_id in + let loop = LoopId.Map.find loop_id ctx.loops in + let tys = RegionGroupId.Map.find bid loop.back_outputs in + List.map (fun ty -> (None, ty)) tys + else + (* Regular function body *) + let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in + List.combine back_sg.output_names back_sg.outputs + in + let ctx, vars = fresh_vars vars ctx in + ({ ctx with backward_outputs = Some vars }, vars) in (* Get the list of values consumed by the abstraction upon ending *) diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 0fa0202b..631a5af9 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -172,7 +172,6 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (* Initialized just below *) backward_inputs_with_state = RegionGroupId.Map.empty; backward_outputs = None; - loop_backward_outputs = None; (* Empty for now *) calls; abstractions; -- cgit v1.2.3 From e90b23a0d42e2ea6805c88d6eaa4f9e5370a1dc1 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 19 Dec 2023 13:28:17 +0100 Subject: Reset Config.return_back_funs to false --- compiler/Config.ml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'compiler') diff --git a/compiler/Config.ml b/compiler/Config.ml index b8af6c6d..c8f3ed58 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -153,7 +153,7 @@ let loop_fixed_point_max_num_iters = 2 return (x :: ls))) ]} *) -let return_back_funs = ref true +let return_back_funs = ref false (** Forbids using field projectors for structures. -- cgit v1.2.3 From 8835d87df111d09122267fadc9a32f16b52d234a Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 14:37:43 +0100 Subject: Make good progress on merging the fwd/back functions --- compiler/Config.ml | 2 +- compiler/Extract.ml | 4 +- compiler/InterpreterStatements.ml | 47 ++++--- compiler/PureUtils.ml | 19 ++- compiler/SymbolicAst.ml | 4 + compiler/SymbolicToPure.ml | 266 ++++++++++++++++++++++++++++++-------- compiler/SynthesizeSymbolic.ml | 16 ++- compiler/Translate.ml | 3 +- 8 files changed, 274 insertions(+), 87 deletions(-) (limited to 'compiler') diff --git a/compiler/Config.ml b/compiler/Config.ml index c8f3ed58..b8af6c6d 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -153,7 +153,7 @@ let loop_fixed_point_max_num_iters = 2 return (x :: ls))) ]} *) -let return_back_funs = ref false +let return_back_funs = ref true (** Forbids using field projectors for structures. diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 3429cd11..46cf8c4a 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -1332,7 +1332,9 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter) in let fwd_back_comment = match def.back_id with - | None -> [ "forward function" ] + | None -> + if !Config.return_back_funs then [ "function definition" ] + else [ "forward function" ] | Some id -> (* Check if there is only one backward function, and no forward function *) if (not keep_fwd) && num_backs = 1 then diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml index da617c64..94c65b5c 100644 --- a/compiler/InterpreterStatements.ml +++ b/compiler/InterpreterStatements.ml @@ -728,7 +728,12 @@ let create_push_abstractions_from_abs_region_groups to a trait clause but directly to the method provided in the trait declaration. *) let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) - : fun_id_or_trait_method_ref * generic_args * fun_decl * inst_fun_sig = + : + fun_id_or_trait_method_ref + * generic_args + * fun_decl + * region_var_groups + * inst_fun_sig = match call.func with | FnOpMove _ -> (* Closure case: TODO *) @@ -753,7 +758,7 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) instantiate_fun_sig ctx func.generics tr_self def.signature regions_hierarchy in - (func.func, func.generics, def, inst_sg) + (func.func, func.generics, def, regions_hierarchy, inst_sg) | FunId (FAssumed _) -> (* Unreachable: must be a transparent function *) raise (Failure "Unreachable") @@ -806,7 +811,7 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) we also need to update the generics. *) let func = FunId fid in - (func, generics, method_def, inst_sg) + (func, generics, method_def, regions_hierarchy, inst_sg) | None -> (* If not found, lookup the methods provided by the trait *declaration* (remember: for now, we forbid overriding provided methods) *) @@ -860,7 +865,11 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) instantiate_fun_sig ctx all_generics tr_self method_def.signature regions_hierarchy in - (func.func, func.generics, method_def, inst_sg)) + ( func.func, + func.generics, + method_def, + regions_hierarchy, + inst_sg )) | _ -> (* We are using a local clause - we lookup the trait decl *) let trait_decl = @@ -891,7 +900,8 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) instantiate_fun_sig ctx generics tr_self method_def.signature regions_hierarchy in - (func.func, func.generics, method_def, inst_sg))) + (func.func, func.generics, method_def, regions_hierarchy, inst_sg) + )) (** Evaluate a statement *) let rec eval_statement (config : config) (st : statement) : st_cm_fun = @@ -1277,14 +1287,14 @@ and eval_transparent_function_call_concrete (config : config) and eval_transparent_function_call_symbolic (config : config) (call : call) : st_cm_fun = fun cf ctx -> - let func, generics, def, inst_sg = + let func, generics, def, regions_hierarchy, inst_sg = eval_transparent_function_call_symbolic_inst call ctx in (* Sanity check *) assert (List.length call.args = List.length def.signature.inputs); (* Evaluate the function call *) - eval_function_call_symbolic_from_inst_sig config func inst_sg generics - call.args call.dest cf ctx + eval_function_call_symbolic_from_inst_sig config func def.signature + regions_hierarchy inst_sg generics call.args call.dest cf ctx (** Evaluate a function call in symbolic mode by using the function signature. @@ -1298,7 +1308,8 @@ and eval_transparent_function_call_symbolic (config : config) (call : call) : trait ref as input. *) and eval_function_call_symbolic_from_inst_sig (config : config) - (fid : fun_id_or_trait_method_ref) (inst_sg : inst_fun_sig) + (fid : fun_id_or_trait_method_ref) (sg : fun_sig) + (regions_hierarchy : region_var_groups) (inst_sg : inst_fun_sig) (generics : generic_args) (args : operand list) (dest : place) : st_cm_fun = fun cf ctx -> log#ldebug @@ -1378,8 +1389,8 @@ and eval_function_call_symbolic_from_inst_sig (config : config) let expr = cf ctx in (* Synthesize the symbolic AST *) - S.synthesize_regular_function_call fid call_id ctx abs_ids generics args - args_places ret_spc dest_place expr + S.synthesize_regular_function_call fid call_id ctx sg regions_hierarchy + abs_ids generics args args_places ret_spc dest_place expr in let cc = comp cc cf_call in @@ -1468,7 +1479,7 @@ and eval_assumed_function_call_symbolic (config : config) (fid : assumed_fun_id) (* In symbolic mode, the behaviour of a function call is completely defined * by the signature of the function: we thus simply generate correctly * instantiated signatures, and delegate the work to an auxiliary function *) - let inst_sig = + let sg, regions_hierarchy, inst_sig = match fid with | BoxFree -> (* Should have been treated above *) @@ -1480,14 +1491,16 @@ and eval_assumed_function_call_symbolic (config : config) (fid : assumed_fun_id) in (* There shouldn't be any reference to Self *) let tr_self = UnknownTrait __FUNCTION__ in - instantiate_fun_sig ctx generics tr_self - (Assumed.get_assumed_fun_sig fid) - regions_hierarchy + let sg = Assumed.get_assumed_fun_sig fid in + let inst_sg = + instantiate_fun_sig ctx generics tr_self sg regions_hierarchy + in + (sg, regions_hierarchy, inst_sg) in (* Evaluate the function call *) - eval_function_call_symbolic_from_inst_sig config (FunId (FAssumed fid)) - inst_sig generics args dest cf ctx + eval_function_call_symbolic_from_inst_sig config (FunId (FAssumed fid)) sg + regions_hierarchy inst_sig generics args dest cf ctx (** Evaluate a statement seen as a function body *) and eval_function_body (config : config) (body : statement) : st_cm_fun = diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 6e86578c..6579e84c 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -321,14 +321,23 @@ let destruct_apps (e : texpression) : texpression * texpression list = (** Make an [App (app, arg)] expression *) let mk_app (app : texpression) (arg : texpression) : texpression = + let raise_or_return msg = + if !Config.fail_hard then raise (Failure msg) + else + let e = App (app, arg) in + (* Dummy type - TODO: introduce an error type *) + let ty = app.ty in + { e; ty } + in match app.ty with | TArrow (ty0, ty1) -> (* Sanity check *) - assert (ty0 = arg.ty); - let e = App (app, arg) in - let ty = ty1 in - { e; ty } - | _ -> raise (Failure "Expected an arrow type") + if ty0 <> arg.ty then raise_or_return "App: wrong input type" + else + let e = App (app, arg) in + let ty = ty1 in + { e; ty } + | _ -> raise_or_return "Expected an arrow type" (** The reverse of {!destruct_apps} *) let mk_apps (app : texpression) (args : texpression list) : texpression = diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml index 53f99b7f..54d207d9 100644 --- a/compiler/SymbolicAst.ml +++ b/compiler/SymbolicAst.ml @@ -42,7 +42,11 @@ type call = { evaluated). We need it to compute the translated values for shared borrows (we need to perform lookups). *) + sg : fun_sig option; + (** The uninstantiated function signature, if this is not a unop/binop *) + regions_hierarchy : region_var_groups; abstractions : AbstractionId.id list; + (** The region abstractions introduced upon calling the function *) generics : generic_args; args : typed_value list; args_places : mplace option list; (** Meta information *) diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index e2787271..1ce6c698 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -67,6 +67,18 @@ type call_info = { Those inputs include the fuel and the state, if pertinent. *) + back_funs : texpression RegionGroupId.Map.t option; + (** If we do not split between the forward/backward functions: the + variables we introduced for the backward functions. + + Example: + {[ + let x, back = Vec.index_mut n v in + ^^^^ + here + ... + ]} + *) } [@@deriving show] @@ -118,6 +130,8 @@ type loop_info = { (** Body synthesis context *) type bs_ctx = { + (* TODO: there are a lot of duplications with the various decls ctx *) + decls_ctx : C.decls_ctx; type_ctx : type_ctx; fun_ctx : fun_ctx; global_ctx : global_ctx; @@ -757,17 +771,27 @@ let translate_fun_id_or_trait_method_ref (ctx : bs_ctx) TraitMethod (trait_ref, method_name, fun_decl_id) let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) - (args : texpression list) (ctx : bs_ctx) : bs_ctx = + (args : texpression list) + (back_funs : texpression RegionGroupId.Map.t option) (ctx : bs_ctx) : bs_ctx + = let calls = ctx.calls in assert (not (V.FunCallId.Map.mem call_id calls)); - let info = { forward; forward_inputs = args } in + let info = { forward; forward_inputs = args; back_funs } in let calls = V.FunCallId.Map.add call_id info calls in { ctx with calls } -(** [back_args]: the *additional* list of inputs received by the backward function *) -let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) - (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx) - : bs_ctx * fun_or_op_id = +(** [inherit_args]: the list of inputs inherited from the forward function and + the ancestors backward functions, if pertinent. + [back_args]: the *additional* list of inputs received by the backward function, + including the state. + + Returns the updated context and the expression corresponding to the function. + *) +let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info) + (call_id : V.FunCallId.id) (back_id : T.RegionGroupId.id) + (inherited_args : texpression list) (back_args : texpression list) + (generics : generic_args) (output_ty : ty) (ctx : bs_ctx) : + bs_ctx * texpression = (* Insert the abstraction in the call informations *) let info = V.FunCallId.Map.find call_id ctx.calls in let calls = V.FunCallId.Map.add call_id info ctx.calls in @@ -777,16 +801,31 @@ let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) let abstractions = V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions in - (* Retrieve the fun_id *) - let fun_id = - match info.forward.call_id with - | S.Fun (fid, _) -> - let fid = translate_fun_id_or_trait_method_ref ctx fid in - Fun (FromLlbc (fid, None, Some back_id)) - | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable") + (* Compute the expression corresponding to the function *) + let func = + if !Config.return_back_funs then + (* Lookup the variable introduced for the backward function *) + RegionGroupId.Map.find back_id (Option.get info.back_funs) + else + (* Retrieve the fun_id *) + let fun_id = + match info.forward.call_id with + | S.Fun (fid, _) -> + let fid = translate_fun_id_or_trait_method_ref ctx fid in + Fun (FromLlbc (fid, None, Some back_id)) + | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable") + in + let args = List.append inherited_args back_args in + let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in + let ret_ty = + if effect_info.can_fail then mk_result_ty output_ty else output_ty + in + let func_ty = mk_arrows input_tys ret_ty in + let func = { id = FunOrOp fun_id; generics } in + { e = Qualif func; ty = func_ty } in (* Update the context and return *) - ({ ctx with calls; abstractions }, fun_id) + ({ ctx with calls; abstractions }, func) (** List the ancestors of an abstraction *) let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs) @@ -878,15 +917,12 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) We use [bid] ("backward function id") only if we split the forward and the backward functions. *) -let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) - (fun_id : A.fun_id) (sg : A.fun_sig) (input_names : string option list) : - decomposed_fun_sig = +let translate_fun_sig_with_regions_hierarchy_to_decomposed + (decls_ctx : C.decls_ctx) (fun_id : A.fun_id_or_trait_method_ref) + (regions_hierarchy : T.region_var_groups) (sg : A.fun_sig) + (input_names : string option list) : decomposed_fun_sig = let fun_infos = decls_ctx.fun_ctx.fun_infos in let type_infos = decls_ctx.type_ctx.type_infos in - (* Retrieve the list of parent backward functions *) - let regions_hierarchy = - FunIdMap.find fun_id decls_ctx.fun_ctx.regions_hierarchies - in (* We need an evaluation context to normalize the types (to normalize the associated types, etc. - for instance it may happen that the types refer to the types associated to a trait ref, but where the trait ref @@ -915,9 +951,7 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) in (* Is the forward function stateful, and can it fail? *) - let fwd_effect_info = - get_fun_effect_info fun_infos (FunId fun_id) None None - in + let fwd_effect_info = get_fun_effect_info fun_infos fun_id None None in (* Compute the forward inputs *) let fwd_fuel = mk_fuel_input_ty_as_list fwd_effect_info in let fwd_inputs_no_fuel_no_state = @@ -1030,7 +1064,7 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) RegionGroupId.id * back_sg_info = let gid = rg.id in let back_effect_info = - get_fun_effect_info fun_infos (FunId fun_id) None (Some gid) + get_fun_effect_info fun_infos fun_id None (Some gid) in let inputs_no_state = translate_back_inputs_for_gid gid in let inputs_no_state = @@ -1072,6 +1106,16 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) fwd_info; } +let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) + (fun_id : FunDeclId.id) (sg : A.fun_sig) (input_names : string option list) + : decomposed_fun_sig = + (* Retrieve the list of parent backward functions *) + let regions_hierarchy = + FunIdMap.find (FRegular fun_id) decls_ctx.fun_ctx.regions_hierarchies + in + translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx + (FunId (FRegular fun_id)) regions_hierarchy sg input_names + let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty = let output = @@ -1090,6 +1134,40 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list = mk_arrows inputs output) (RegionGroupId.Map.values dsg.back_sg) +(** Return the pure signature of a backward function, in the case the + forward/backward functions are merged (i.e., the forward functions + return the backward functions). + + TODO: merge with {!translate_fun_sig_from_decomposed} + *) +let translate_ret_back_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) + (gid : RegionGroupId.id) : fun_sig = + assert !Config.return_back_funs; + + let generics = dsg.generics in + let llbc_generics = dsg.llbc_generics in + let preds = dsg.preds in + (* Compute the effects info *) + let fwd_info = dsg.fwd_info in + let back_effect_info = + RegionGroupId.Map.of_list + (List.map + (fun ((gid, info) : RegionGroupId.id * back_sg_info) -> + (gid, info.effect_info)) + (RegionGroupId.Map.bindings dsg.back_sg)) + in + (* Two cases depending on whether we split the forward/backward functions + or not *) + let mk_output_ty = mk_output_ty_from_effect_info in + + let back_sg = RegionGroupId.Map.find gid dsg.back_sg in + let effect_info = back_sg.effect_info in + (* Do not prepend the forward inputs *) + let inputs = List.map snd back_sg.inputs in + let output = mk_simpl_tuple_ty back_sg.outputs in + let output = mk_output_ty effect_info output in + { generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info } + let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) (gid : RegionGroupId.id option) : fun_sig = let generics = dsg.generics in @@ -1774,7 +1852,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in (* Retrieve the function id, and register the function call in the context * if necessary. *) - let ctx, fun_id, effect_info, args, out_state = + let ctx, fun_id, effect_info, args, back_funs, out_state = match call.call_id with | S.Fun (fid, call_id) -> (* Regular function call *) @@ -1798,9 +1876,80 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var) else (List.concat [ fuel; args ], ctx, None) in + (* If we do not split the forward/backward functions: generate the + variables for the backward functions returned by the forward + function. *) + let ctx, back_funs_map, back_funs = + if !Config.return_back_funs then + (* We need to compute the signatures of the backward functions. *) + let sg = Option.get call.sg in + let decls_ctx = ctx.decls_ctx in + let dsg = + translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx + fid call.regions_hierarchy sg + (List.map (fun _ -> None) sg.inputs) + in + let gids = + List.map + (fun (g : T.region_var_group) -> g.id) + call.regions_hierarchy + in + let back_sgs = + List.map (translate_ret_back_fun_sig_from_decomposed dsg) gids + in + (* Introduce variables for the backward functions *) + let back_tys = + List.map + (fun (sg : fun_sig) -> mk_arrows sg.inputs sg.output) + back_sgs + in + (* Compute a proper basename for the variables *) + let back_fun_name = + let name = + match fid with + | FunId (FAssumed fid) -> ( + match fid with + | BoxNew -> "box_new" + | BoxFree -> "box_free" + | ArrayRepeat -> "array_repeat" + | ArrayIndexShared -> "index_shared" + | ArrayIndexMut -> "index_mut" + | ArrayToSliceShared -> "to_slice_shared" + | ArrayToSliceMut -> "to_slice_mut" + | SliceIndexShared -> "index_shared" + | SliceIndexMut -> "index_mut") + | FunId (FRegular fid) | TraitMethod (_, _, fid) -> ( + let decl = + FunDeclId.Map.find fid ctx.fun_ctx.llbc_fun_decls + in + match Collections.List.last decl.name with + | PeIdent (s, _) -> s + | PeImpl _ -> + (* We shouldn't get there *) + raise (Failure "Unexpected")) + in + name ^ "_back" + in + let ctx, back_vars = + fresh_vars + (List.map (fun ty -> (Some back_fun_name, ty)) back_tys) + ctx + in + let back_funs = + List.map (fun v -> mk_typed_pattern_from_var v None) back_vars + in + let back_funs_map = + RegionGroupId.Map.of_list + (List.combine gids (List.map mk_texpression_from_var back_vars)) + in + (ctx, Some back_funs_map, back_funs) + else (ctx, None, []) + in (* Register the function call *) - let ctx = bs_ctx_register_forward_call call_id call args ctx in - (ctx, func, effect_info, args, out_state) + let ctx = + bs_ctx_register_forward_call call_id call args back_funs_map ctx + in + (ctx, func, effect_info, args, back_funs, out_state) | S.Unop E.Not -> let effect_info = { @@ -1811,7 +1960,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop Not, effect_info, args, None) + (ctx, Unop Not, effect_info, args, [], None) | S.Unop E.Neg -> ( match args with | [ arg ] -> @@ -1827,7 +1976,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop (Neg int_ty), effect_info, args, None) + (ctx, Unop (Neg int_ty), effect_info, args, [], None) | _ -> raise (Failure "Unreachable")) | S.Unop (E.Cast cast_kind) -> ( match cast_kind with @@ -1842,7 +1991,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None) + (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, [], None) | CastFnPtr _ -> raise (Failure "TODO: function casts")) | S.Binop binop -> ( match args with @@ -1862,11 +2011,12 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Binop (binop, int_ty0), effect_info, args, None) + (ctx, Binop (binop, int_ty0), effect_info, args, [], None) | _ -> raise (Failure "Unreachable")) in let dest_v = let dest = mk_typed_pattern_from_var dest dest_mplace in + let dest = mk_simpl_tuple_pattern (dest :: back_funs) in match out_state with | None -> dest | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ] @@ -2026,9 +2176,11 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) else ([], ctx, None) in (* Concatenate all the inpus *) - let inputs = - List.concat [ fwd_inputs; back_ancestors_inputs; back_inputs; back_state ] + let inherited_inputs = + if !Config.return_back_funs then [] + else List.concat [ fwd_inputs; back_ancestors_inputs ] in + let back_inputs = List.append back_inputs back_state in (* Retrieve the values given back by this function: those are the output * values. We rely on the fact that there are no nested borrows to use the * meta-place information from the input values given to the forward function @@ -2046,43 +2198,43 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) | Some nstate -> mk_simpl_tuple_pattern [ nstate; output ] in (* Retrieve the function id, and register the function call in the context - * if necessary *) + if necessary.Arith_status *) let ctx, func = - bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx + bs_ctx_register_backward_call abs effect_info call_id rg_id inherited_inputs + back_inputs generics output.ty ctx in (* Translate the next expression *) let next_e = translate_expression e ctx in (* Put everything together *) + let inputs = List.append inherited_inputs back_inputs in let args_mplaces = List.map (fun _ -> None) inputs in let args = List.map (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) (List.combine inputs args_mplaces) in - let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in - let ret_ty = - if effect_info.can_fail then mk_result_ty output.ty else output.ty - in - let func_ty = mk_arrows input_tys ret_ty in - let func = { id = FunOrOp func; generics } in - let func = { e = Qualif func; ty = func_ty } in let call = mk_apps func args in (* **Optimization**: - * ================= - * We do a small optimization here: if the backward function doesn't - * have any output, we don't introduce any function call. - * See the comment in {!Config.filter_useless_monadic_calls}. - * - * TODO: use an option to disallow backward functions from updating the state. - * TODO: a backward function which only gives back shared borrows shouldn't - * update the state (state updates should only be used for mutable borrows, - * with objects like Rc for instance). - *) - if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None then ( + ================= + We do a small optimization here if we split the forward/backward functions. + If the backward function doesn't have any output, we don't introduce any function + call. + See the comment in {!Config.filter_useless_monadic_calls}. + + TODO: use an option to disallow backward functions from updating the state. + TODO: a backward function which only gives back shared borrows shouldn't + update the state (state updates should only be used for mutable borrows, + with objects like Rc for instance). + *) + if + (not !Config.return_back_funs) + && !Config.filter_useless_monadic_calls + && outputs = [] && nstate = None + then ( (* No outputs - we do a small sanity check: the backward function - * should have exactly the same number of inputs as the forward: - * this number can be different only if the forward function returned - * a value containing mutable borrows, which can't be the case... *) + should have exactly the same number of inputs as the forward: + this number can be different only if the forward function returned + a value containing mutable borrows, which can't be the case... *) assert (List.length inputs = List.length fwd_inputs); next_e) else mk_let effect_info.can_fail output call next_e diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml index efcf001a..4ec7524b 100644 --- a/compiler/SynthesizeSymbolic.ml +++ b/compiler/SynthesizeSymbolic.ml @@ -2,6 +2,7 @@ open Types open TypesUtils open Expressions open Values +open LlbcAst open SymbolicAst let mk_mplace (p : place) (ctx : Contexts.eval_ctx) : mplace = @@ -92,6 +93,7 @@ let synthesize_symbolic_expansion_no_branching (sv : symbolic_value) synthesize_symbolic_expansion sv place [ Some see ] el let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx) + (sg : fun_sig option) (regions_hierarchy : region_var_groups) (abstractions : AbstractionId.id list) (generics : generic_args) (args : typed_value list) (args_places : mplace option list) (dest : symbolic_value) (dest_place : mplace option) (e : expression option) @@ -102,6 +104,8 @@ let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx) { call_id; ctx; + sg; + regions_hierarchy; abstractions; generics; args; @@ -118,28 +122,30 @@ let synthesize_global_eval (gid : GlobalDeclId.id) (dest : symbolic_value) Option.map (fun e -> EvalGlobal (gid, dest, e)) e let synthesize_regular_function_call (fun_id : fun_id_or_trait_method_ref) - (call_id : FunCallId.id) (ctx : Contexts.eval_ctx) + (call_id : FunCallId.id) (ctx : Contexts.eval_ctx) (sg : fun_sig) + (regions_hierarchy : region_var_groups) (abstractions : AbstractionId.id list) (generics : generic_args) (args : typed_value list) (args_places : mplace option list) (dest : symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = synthesize_function_call (Fun (fun_id, call_id)) - ctx abstractions generics args args_places dest dest_place e + ctx (Some sg) regions_hierarchy abstractions generics args args_places dest + dest_place e let synthesize_unary_op (ctx : Contexts.eval_ctx) (unop : unop) (arg : typed_value) (arg_place : mplace option) (dest : symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = let generics = empty_generic_args in - synthesize_function_call (Unop unop) ctx [] generics [ arg ] [ arg_place ] - dest dest_place e + synthesize_function_call (Unop unop) ctx None [] [] generics [ arg ] + [ arg_place ] dest dest_place e let synthesize_binary_op (ctx : Contexts.eval_ctx) (binop : binop) (arg0 : typed_value) (arg0_place : mplace option) (arg1 : typed_value) (arg1_place : mplace option) (dest : symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = let generics = empty_generic_args in - synthesize_function_call (Binop binop) ctx [] generics [ arg0; arg1 ] + synthesize_function_call (Binop binop) ctx None [] [] generics [ arg0; arg1 ] [ arg0_place; arg1_place ] dest dest_place e let synthesize_end_abstraction (ctx : Contexts.eval_ctx) (abs : abs) diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 631a5af9..5584fb9a 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -129,7 +129,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) in let sg = - SymbolicToPure.translate_fun_sig_to_decomposed trans_ctx (FRegular def_id) + SymbolicToPure.translate_fun_sig_to_decomposed trans_ctx def_id fdef.signature input_names in @@ -151,6 +151,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) let ctx = { + decls_ctx = trans_ctx; SymbolicToPure.bid = None; sg; (* Will need to be updated for the backward functions *) -- cgit v1.2.3 From a630b8a703d8761746f7258b6db54080aa974f53 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 14:49:37 +0100 Subject: Fix a minor issue --- compiler/SymbolicToPure.ml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'compiler') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 1ce6c698..3d955061 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -1123,12 +1123,15 @@ let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty in if effect_info.can_fail then mk_result_ty output else output -(** Compute the arrow types for all the backward functions *) +(** Compute the arrow types for all the backward functions. + + TODO: merge with below? + *) let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list = List.map (fun (back_sg : back_sg_info) -> let effect_info = back_sg.effect_info in - let inputs = dsg.fwd_inputs @ List.map snd back_sg.inputs in + let inputs = List.map snd back_sg.inputs in let output = mk_simpl_tuple_ty back_sg.outputs in let output = mk_output_ty_from_effect_info effect_info output in mk_arrows inputs output) -- cgit v1.2.3 From 435fe4cf63869448e2b25486b564ede9efa9a34b Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 15:17:28 +0100 Subject: Fix some issues in SymbolicToPure --- compiler/SymbolicToPure.ml | 51 +++++++++++++++++++++++----------------------- 1 file changed, 26 insertions(+), 25 deletions(-) (limited to 'compiler') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 3d955061..ef0a0bde 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -1137,39 +1137,30 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list = mk_arrows inputs output) (RegionGroupId.Map.values dsg.back_sg) -(** Return the pure signature of a backward function, in the case the - forward/backward functions are merged (i.e., the forward functions +(** Return the instantiated pure signature of a backward function, in the + case the forward/backward functions are merged (i.e., the forward functions return the backward functions). - - TODO: merge with {!translate_fun_sig_from_decomposed} *) -let translate_ret_back_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) - (gid : RegionGroupId.id) : fun_sig = +let translate_ret_back_inst_fun_sig_from_decomposed + (dsg : Pure.decomposed_fun_sig) (generics : generic_args) + (gid : RegionGroupId.id) : inst_fun_sig = assert !Config.return_back_funs; - - let generics = dsg.generics in - let llbc_generics = dsg.llbc_generics in - let preds = dsg.preds in - (* Compute the effects info *) - let fwd_info = dsg.fwd_info in - let back_effect_info = - RegionGroupId.Map.of_list - (List.map - (fun ((gid, info) : RegionGroupId.id * back_sg_info) -> - (gid, info.effect_info)) - (RegionGroupId.Map.bindings dsg.back_sg)) - in - (* Two cases depending on whether we split the forward/backward functions - or not *) let mk_output_ty = mk_output_ty_from_effect_info in - + (* Lookup the signature information *) let back_sg = RegionGroupId.Map.find gid dsg.back_sg in let effect_info = back_sg.effect_info in (* Do not prepend the forward inputs *) let inputs = List.map snd back_sg.inputs in let output = mk_simpl_tuple_ty back_sg.outputs in let output = mk_output_ty effect_info output in - { generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info } + (* Substitute the types *) + let tr_self = UnknownTrait __FUNCTION__ in + let subst = make_subst_from_generics dsg.generics generics tr_self in + let subst = ty_substitute subst in + let inputs = List.map subst inputs in + let output = subst output in + (* Return *) + { inputs; output } let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) (gid : RegionGroupId.id option) : fun_sig = @@ -1898,12 +1889,14 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : call.regions_hierarchy in let back_sgs = - List.map (translate_ret_back_fun_sig_from_decomposed dsg) gids + List.map + (translate_ret_back_inst_fun_sig_from_decomposed dsg generics) + gids in (* Introduce variables for the backward functions *) let back_tys = List.map - (fun (sg : fun_sig) -> mk_arrows sg.inputs sg.output) + (fun (sg : inst_fun_sig) -> mk_arrows sg.inputs sg.output) back_sgs in (* Compute a proper basename for the variables *) @@ -2216,6 +2209,14 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) (List.combine inputs args_mplaces) in + log#ldebug + (lazy + (let args = List.map (texpression_to_string ctx) args in + "func: " + ^ texpression_to_string ctx func + ^ "\nfunc type: " + ^ pure_ty_to_string ctx func.ty + ^ "\n\nargs:\n" ^ String.concat "\n" args)); let call = mk_apps func args in (* **Optimization**: ================= -- cgit v1.2.3 From d9f91cfcd538525f024c6019d7c8250dda8d76fd Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 15:25:06 +0100 Subject: Remove some asserts which are now useless --- compiler/Extract.ml | 4 ---- compiler/PureMicroPasses.ml | 3 --- compiler/PureUtils.ml | 5 ++++- 3 files changed, 4 insertions(+), 8 deletions(-) (limited to 'compiler') diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 46cf8c4a..8d35f039 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -1475,8 +1475,6 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) *) let inputs_lvs = let all_inputs = (Option.get def.body).inputs_lvs in - (* TODO: *) - assert (not !Config.return_back_funs); let num_fwd_inputs = def.signature.fwd_info.fwd_info.num_inputs_with_fuel_with_state in @@ -1523,8 +1521,6 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) if has_decreases_clause && !backend = Lean then ( let def_body = Option.get def.body in let all_vars = List.map (fun (v : var) -> v.id) def_body.inputs in - (* TODO: *) - assert (not !Config.return_back_funs); let num_fwd_inputs = def.signature.fwd_info.fwd_info.num_inputs_with_fuel_with_state in diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 34597d32..63436e7d 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -1337,9 +1337,6 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) : let fwd_info = fun_sig.fwd_info in let fwd_effect_info = fwd_info.effect_info in - (* TODO: *) - assert (not !Config.return_back_funs); - (* Generate the loop definition *) let loop_fwd_effect_info = fwd_effect_info in diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 6579e84c..d4aaba16 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -332,7 +332,10 @@ let mk_app (app : texpression) (arg : texpression) : texpression = match app.ty with | TArrow (ty0, ty1) -> (* Sanity check *) - if ty0 <> arg.ty then raise_or_return "App: wrong input type" + if + (* TODO: we need to normalize the types *) + !Config.type_check_pure_code && ty0 <> arg.ty + then raise_or_return "App: wrong input type" else let e = App (app, arg) in let ty = ty1 in -- cgit v1.2.3 From cf3eea59ee61f2341daf7248664b8be878f128af Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 16:35:27 +0100 Subject: Update SymbolicToPure.ml for the loops --- compiler/SymbolicToPure.ml | 221 +++++++++++++++++++++++++-------------------- 1 file changed, 125 insertions(+), 96 deletions(-) (limited to 'compiler') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index ef0a0bde..d3b0933c 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -125,6 +125,11 @@ type loop_info = { (** The map from region group ids to the types of the values given back by the corresponding loop abstractions. *) + back_funs : texpression RegionGroupId.Map.t option; + (** Same as {!call_info.back_funs}. + Initialized with [None], gets updated to [Some] only if we merge + the fwd/back functions. + *) } [@@deriving show] @@ -1123,45 +1128,25 @@ let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty in if effect_info.can_fail then mk_result_ty output else output -(** Compute the arrow types for all the backward functions. - - TODO: merge with below? - *) -let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list = +(** Compute the arrow types for all the backward functions. *) +let compute_back_tys (dsg : Pure.decomposed_fun_sig) + (subst : (generic_args * trait_instance_id) option) : ty list = List.map (fun (back_sg : back_sg_info) -> let effect_info = back_sg.effect_info in + (* Compute *) let inputs = List.map snd back_sg.inputs in let output = mk_simpl_tuple_ty back_sg.outputs in let output = mk_output_ty_from_effect_info effect_info output in - mk_arrows inputs output) + let ty = mk_arrows inputs output in + (* Substitute - TODO: normalize *) + match subst with + | None -> ty + | Some (generics, tr_self) -> + let subst = make_subst_from_generics dsg.generics generics tr_self in + ty_substitute subst ty) (RegionGroupId.Map.values dsg.back_sg) -(** Return the instantiated pure signature of a backward function, in the - case the forward/backward functions are merged (i.e., the forward functions - return the backward functions). - *) -let translate_ret_back_inst_fun_sig_from_decomposed - (dsg : Pure.decomposed_fun_sig) (generics : generic_args) - (gid : RegionGroupId.id) : inst_fun_sig = - assert !Config.return_back_funs; - let mk_output_ty = mk_output_ty_from_effect_info in - (* Lookup the signature information *) - let back_sg = RegionGroupId.Map.find gid dsg.back_sg in - let effect_info = back_sg.effect_info in - (* Do not prepend the forward inputs *) - let inputs = List.map snd back_sg.inputs in - let output = mk_simpl_tuple_ty back_sg.outputs in - let output = mk_output_ty effect_info output in - (* Substitute the types *) - let tr_self = UnknownTrait __FUNCTION__ in - let subst = make_subst_from_generics dsg.generics generics tr_self in - let subst = ty_substitute subst in - let inputs = List.map subst inputs in - let output = subst output in - (* Return *) - { inputs; output } - let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) (gid : RegionGroupId.id option) : fun_sig = let generics = dsg.generics in @@ -1184,7 +1169,7 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) if !Config.return_back_funs then ( assert (gid = None); (* Compute the arrow types for all the backward functions *) - let back_tys = compute_back_tys dsg in + let back_tys = compute_back_tys dsg None in (* Group the forward output and the types of the backward functions *) let effect_info = dsg.fwd_info.effect_info in let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in @@ -1274,6 +1259,40 @@ let fresh_vars (vars : (string option * ty) list) (ctx : bs_ctx) : bs_ctx * var list = List.fold_left_map (fun ctx (name, ty) -> fresh_var name ty ctx) ctx vars +(* Introduce variables for the backward functions *) +let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var list = + (* We lookup the LLBC definition in an attempt to derive pretty names + for the backward functions. *) + let back_var_names = + let def_id = ctx.fun_decl.def_id in + let sg = ctx.fun_decl.signature in + let regions_hierarchy = + LlbcAstUtils.FunIdMap.find (FRegular def_id) + ctx.fun_ctx.regions_hierarchies + in + List.map + (fun (gid, _) -> + let rg = RegionGroupId.nth regions_hierarchy gid in + let region_names = + List.map + (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name) + rg.regions + in + let name = + match region_names with + | [] -> "back" + | [ Some r ] -> "back" ^ r + | _ -> + (* Concatenate all the region names *) + "back" + ^ String.concat "" (List.filter_map (fun x -> x) region_names) + in + Some name) + (RegionGroupId.Map.bindings ctx.sg.back_sg) + in + let back_vars = List.combine back_var_names (compute_back_tys ctx.sg None) in + fresh_vars back_vars ctx + let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var = match V.SymbolicValueId.Map.find_opt sv.sv_id ctx.sv_to_var with | Some v -> v @@ -1728,7 +1747,7 @@ and translate_panic (ctx : bs_ctx) : texpression = match ctx.bid with | None -> if !Config.return_back_funs then - let back_tys = compute_back_tys ctx.sg in + let back_tys = compute_back_tys ctx.sg None in let output = mk_simpl_tuple_ty (ctx.sg.fwd_output :: back_tys) in mk_output output else mk_output ctx.sg.fwd_output @@ -1883,22 +1902,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : fid call.regions_hierarchy sg (List.map (fun _ -> None) sg.inputs) in - let gids = - List.map - (fun (g : T.region_var_group) -> g.id) - call.regions_hierarchy - in - let back_sgs = - List.map - (translate_ret_back_inst_fun_sig_from_decomposed dsg generics) - gids - in + let tr_self = UnknownTrait __FUNCTION__ in + let back_tys = compute_back_tys dsg (Some (generics, tr_self)) in (* Introduce variables for the backward functions *) - let back_tys = - List.map - (fun (sg : inst_fun_sig) -> mk_arrows sg.inputs sg.output) - back_sgs - in (* Compute a proper basename for the variables *) let back_fun_name = let name = @@ -1934,6 +1940,11 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let back_funs = List.map (fun v -> mk_typed_pattern_from_var v None) back_vars in + let gids = + List.map + (fun (g : T.region_var_group) -> g.id) + call.regions_hierarchy + in let back_funs_map = RegionGroupId.Map.of_list (List.combine gids (List.map mk_texpression_from_var back_vars)) @@ -2338,6 +2349,8 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) (* Actually the same case as [SynthInput] *) translate_end_abstraction_synth_input ectx abs e ctx rg_id | V.LoopCall -> + (* We need to introduce a call to the backward function corresponding + to a forward call which happened earlier *) let fun_id = E.FRegular ctx.fun_decl.def_id in let effect_info = get_fun_effect_info ctx.fun_ctx.fun_infos (FunId fun_id) (Some vloop_id) @@ -2367,7 +2380,10 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) else ([], ctx, None) in (* Concatenate all the inputs *) - let inputs = List.concat [ fwd_inputs; back_inputs; back_state ] in + let inputs = + if !Config.return_back_funs then List.concat [ back_inputs; back_state ] + else List.concat [ fwd_inputs; back_inputs; back_state ] + in (* Retrieve the values given back by this function *) let ctx, outputs = abs_to_given_back None abs ctx in (* Group the output values together: first the updated inputs *) @@ -2391,28 +2407,43 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) let ret_ty = if effect_info.can_fail then mk_result_ty output.ty else output.ty in - let func_ty = mk_arrows input_tys ret_ty in - let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in - let func = { id = FunOrOp func; generics } in - let func = { e = Qualif func; ty = func_ty } in + (* Create the expression for the function: + - it is either a call to a top-level function, if we split the + forward/backward functions + - or a call to the variable we introduced for the backward function, + if we merge the forward/backward functions *) + let func = + if !Config.return_back_funs then + RegionGroupId.Map.find rg_id (Option.get loop_info.back_funs) + else + let func_ty = mk_arrows input_tys ret_ty in + let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in + let func = { id = FunOrOp func; generics } in + { e = Qualif func; ty = func_ty } + in let call = mk_apps func args in (* **Optimization**: - * ================= - * We do a small optimization here: if the backward function doesn't - * have any output, we don't introduce any function call. - * See the comment in {!Config.filter_useless_monadic_calls}. - * - * TODO: use an option to disallow backward functions from updating the state. - * TODO: a backward function which only gives back shared borrows shouldn't - * update the state (state updates should only be used for mutable borrows, - * with objects like Rc for instance). - *) - if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None + ================= + We do a small optimization here in case we split the forward/backward + functions. + If the backward function doesn't have any output, we don't introduce + any function call. + See the comment in {!Config.filter_useless_monadic_calls}. + + TODO: use an option to disallow backward functions from updating the state. + TODO: a backward function which only gives back shared borrows shouldn't + update the state (state updates should only be used for mutable borrows, + with objects like Rc for instance). + *) + if + (not !Config.return_back_funs) + && !Config.filter_useless_monadic_calls + && outputs = [] && nstate = None then ( (* No outputs - we do a small sanity check: the backward function - * should have exactly the same number of inputs as the forward: - * this number can be different only if the forward function returned - * a value containing mutable borrows, which can't be the case... *) + should have exactly the same number of inputs as the forward: + this number can be different only if the forward function returned + a value containing mutable borrows, which can't be the case... *) assert (List.length inputs = List.length fwd_inputs); next_e) else @@ -2860,35 +2891,7 @@ and translate_forward_end (ectx : C.eval_ctx) (* Introduce variables for the backward functions. We lookup the LLBC definition in an attempt to derive pretty names for those functions. *) - let back_var_names = - let def_id = ctx.fun_decl.def_id in - let sg = ctx.fun_decl.signature in - let regions_hierarchy = - LlbcAstUtils.FunIdMap.find (FRegular def_id) - ctx.fun_ctx.regions_hierarchies - in - List.map - (fun (gid, _) -> - let rg = RegionGroupId.nth regions_hierarchy gid in - let region_names = - List.map - (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name) - rg.regions - in - let name = - match region_names with - | [] -> "back" - | [ Some r ] -> "back" ^ r - | _ -> - (* Concatenate all the region names *) - "back" - ^ String.concat "" (List.filter_map (fun x -> x) region_names) - in - Some name) - (RegionGroupId.Map.bindings ctx.sg.back_sg) - in - let back_vars = List.combine back_var_names (compute_back_tys ctx.sg) in - let _, back_vars = fresh_vars back_vars ctx in + let _, back_vars = fresh_back_vars_for_current_fun ctx in (* Create the return expressions *) let vars = fwd_var :: back_vars in @@ -2964,8 +2967,32 @@ and translate_forward_end (ectx : C.eval_ctx) (* Introduce a fresh output value for the forward function *) let ctx, output_var = fresh_var None ctx.sg.fwd_output ctx in + (* Introduce fresh variables for the backward functions of the loop. + + For now, the backward functions of the loop are the same as the + backward functions of the outer function. + *) + let ctx, back_funs_map, back_funs = + if !Config.return_back_funs then + let ctx, back_vars = fresh_back_vars_for_current_fun ctx in + let back_funs = + List.map (fun v -> mk_typed_pattern_from_var v None) back_vars + in + let gids = RegionGroupId.Map.keys ctx.sg.back_sg in + let back_funs_map = + RegionGroupId.Map.of_list + (List.combine gids (List.map mk_texpression_from_var back_vars)) + in + (ctx, Some back_funs_map, back_funs) + else (ctx, None, []) + in + + (* Introduce patterns *) let args, ctx, out_pats = + (* Create the pattern for the output value *) let output_pat = mk_typed_pattern_from_var output_var None in + (* Add the returned backward functions (they might be empty) *) + let output_pat = mk_simpl_tuple_pattern (output_pat :: back_funs) in (* Depending on the function effects: * - add the fuel @@ -2988,6 +3015,7 @@ and translate_forward_end (ectx : C.eval_ctx) loop_info with forward_inputs = Some args; forward_output_no_state_no_result = Some output_var; + back_funs = back_funs_map; } in let ctx = @@ -3143,6 +3171,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = forward_inputs = None; forward_output_no_state_no_result = None; back_outputs = rg_to_given_back_tys; + back_funs = None; } in let loops = LoopId.Map.add loop_id loop_info ctx.loops in -- cgit v1.2.3 From d4b3d0e6adae5bb9a2f62872dbcedc29aaa9fa30 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 17:00:52 +0100 Subject: Filter the useless backward functions --- compiler/SymbolicToPure.ml | 220 +++++++++++++++++++++++++++++---------------- 1 file changed, 145 insertions(+), 75 deletions(-) (limited to 'compiler') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index d3b0933c..f37ea201 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -67,7 +67,7 @@ type call_info = { Those inputs include the fuel and the state, if pertinent. *) - back_funs : texpression RegionGroupId.Map.t option; + back_funs : texpression option RegionGroupId.Map.t option; (** If we do not split between the forward/backward functions: the variables we introduced for the backward functions. @@ -78,6 +78,10 @@ type call_info = { here ... ]} + + The expression might be [None] in case the backward function + has to be filtered (because it does nothing - the backward + functions for shared borrows for instance). *) } [@@deriving show] @@ -125,7 +129,7 @@ type loop_info = { (** The map from region group ids to the types of the values given back by the corresponding loop abstractions. *) - back_funs : texpression RegionGroupId.Map.t option; + back_funs : texpression option RegionGroupId.Map.t option; (** Same as {!call_info.back_funs}. Initialized with [None], gets updated to [Some] only if we merge the fwd/back functions. @@ -777,8 +781,8 @@ let translate_fun_id_or_trait_method_ref (ctx : bs_ctx) let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) (args : texpression list) - (back_funs : texpression RegionGroupId.Map.t option) (ctx : bs_ctx) : bs_ctx - = + (back_funs : texpression option RegionGroupId.Map.t option) (ctx : bs_ctx) : + bs_ctx = let calls = ctx.calls in assert (not (V.FunCallId.Map.mem call_id calls)); let info = { forward; forward_inputs = args; back_funs } in @@ -790,13 +794,15 @@ let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) [back_args]: the *additional* list of inputs received by the backward function, including the state. - Returns the updated context and the expression corresponding to the function. + Returns the updated context and the expression corresponding to the function + that we need to call. This function may be [None] if it has to be ignored + (because it does nothing). *) let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info) (call_id : V.FunCallId.id) (back_id : T.RegionGroupId.id) (inherited_args : texpression list) (back_args : texpression list) (generics : generic_args) (output_ty : ty) (ctx : bs_ctx) : - bs_ctx * texpression = + bs_ctx * texpression option = (* Insert the abstraction in the call informations *) let info = V.FunCallId.Map.find call_id ctx.calls in let calls = V.FunCallId.Map.add call_id info ctx.calls in @@ -827,7 +833,7 @@ let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info) in let func_ty = mk_arrows input_tys ret_ty in let func = { id = FunOrOp fun_id; generics } in - { e = Qualif func; ty = func_ty } + Some { e = Qualif func; ty = func_ty } in (* Update the context and return *) ({ ctx with calls; abstractions }, func) @@ -1128,23 +1134,36 @@ let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty in if effect_info.can_fail then mk_result_ty output else output -(** Compute the arrow types for all the backward functions. *) +(** Compute the arrow types for all the backward functions. + + If a backward function has no inputs/outputs we filter it. + *) let compute_back_tys (dsg : Pure.decomposed_fun_sig) - (subst : (generic_args * trait_instance_id) option) : ty list = + (subst : (generic_args * trait_instance_id) option) : ty option list = List.map (fun (back_sg : back_sg_info) -> let effect_info = back_sg.effect_info in - (* Compute *) + (* Compute the input/output types *) let inputs = List.map snd back_sg.inputs in - let output = mk_simpl_tuple_ty back_sg.outputs in - let output = mk_output_ty_from_effect_info effect_info output in - let ty = mk_arrows inputs output in - (* Substitute - TODO: normalize *) - match subst with - | None -> ty - | Some (generics, tr_self) -> - let subst = make_subst_from_generics dsg.generics generics tr_self in - ty_substitute subst ty) + let outputs = back_sg.outputs in + (* Filter if necessary *) + if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] then + None + else + let output = mk_simpl_tuple_ty outputs in + let output = mk_output_ty_from_effect_info effect_info output in + let ty = mk_arrows inputs output in + (* Substitute - TODO: normalize *) + let ty = + match subst with + | None -> ty + | Some (generics, tr_self) -> + let subst = + make_subst_from_generics dsg.generics generics tr_self + in + ty_substitute subst ty + in + Some ty) (RegionGroupId.Map.values dsg.back_sg) let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) @@ -1169,7 +1188,7 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) if !Config.return_back_funs then ( assert (gid = None); (* Compute the arrow types for all the backward functions *) - let back_tys = compute_back_tys dsg None in + let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in (* Group the forward output and the types of the backward functions *) let effect_info = dsg.fwd_info.effect_info in let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in @@ -1259,8 +1278,19 @@ let fresh_vars (vars : (string option * ty) list) (ctx : bs_ctx) : bs_ctx * var list = List.fold_left_map (fun ctx (name, ty) -> fresh_var name ty ctx) ctx vars +let fresh_opt_vars (vars : (string option * ty) option list) (ctx : bs_ctx) : + bs_ctx * var option list = + List.fold_left_map + (fun ctx var -> + match var with + | None -> (ctx, None) + | Some (name, ty) -> + let ctx, var = fresh_var name ty ctx in + (ctx, Some var)) + ctx vars + (* Introduce variables for the backward functions *) -let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var list = +let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var option list = (* We lookup the LLBC definition in an attempt to derive pretty names for the backward functions. *) let back_var_names = @@ -1291,7 +1321,13 @@ let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var list = (RegionGroupId.Map.bindings ctx.sg.back_sg) in let back_vars = List.combine back_var_names (compute_back_tys ctx.sg None) in - fresh_vars back_vars ctx + let back_vars = + List.map + (fun (name, ty) -> + match ty with None -> None | Some ty -> Some (name, ty)) + back_vars + in + fresh_opt_vars back_vars ctx let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var = match V.SymbolicValueId.Map.find_opt sv.sv_id ctx.sv_to_var with @@ -1748,6 +1784,7 @@ and translate_panic (ctx : bs_ctx) : texpression = | None -> if !Config.return_back_funs then let back_tys = compute_back_tys ctx.sg None in + let back_tys = List.filter_map (fun x -> x) back_tys in let output = mk_simpl_tuple_ty (ctx.sg.fwd_output :: back_tys) in mk_output output else mk_output ctx.sg.fwd_output @@ -1933,21 +1970,33 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : name ^ "_back" in let ctx, back_vars = - fresh_vars - (List.map (fun ty -> (Some back_fun_name, ty)) back_tys) + fresh_opt_vars + (List.map + (fun ty -> + match ty with + | None -> None + | Some ty -> Some (Some back_fun_name, ty)) + back_tys) ctx in let back_funs = - List.map (fun v -> mk_typed_pattern_from_var v None) back_vars + List.filter_map + (fun v -> + match v with + | None -> None + | Some v -> Some (mk_typed_pattern_from_var v None)) + back_vars in let gids = List.map (fun (g : T.region_var_group) -> g.id) call.regions_hierarchy in + let back_vars = + List.map (Option.map mk_texpression_from_var) back_vars + in let back_funs_map = - RegionGroupId.Map.of_list - (List.combine gids (List.map mk_texpression_from_var back_vars)) + RegionGroupId.Map.of_list (List.combine gids back_vars) in (ctx, Some back_funs_map, back_funs) else (ctx, None, []) @@ -2220,15 +2269,6 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) (List.combine inputs args_mplaces) in - log#ldebug - (lazy - (let args = List.map (texpression_to_string ctx) args in - "func: " - ^ texpression_to_string ctx func - ^ "\nfunc type: " - ^ pure_ty_to_string ctx func.ty - ^ "\n\nargs:\n" ^ String.concat "\n" args)); - let call = mk_apps func args in (* **Optimization**: ================= We do a small optimization here if we split the forward/backward functions. @@ -2252,7 +2292,22 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) a value containing mutable borrows, which can't be the case... *) assert (List.length inputs = List.length fwd_inputs); next_e) - else mk_let effect_info.can_fail output call next_e + else + (* The backward function might also have been filtered if we do not + split the forward/backward functions *) + match func with + | None -> next_e + | Some func -> + log#ldebug + (lazy + (let args = List.map (texpression_to_string ctx) args in + "func: " + ^ texpression_to_string ctx func + ^ "\nfunc type: " + ^ pure_ty_to_string ctx func.ty + ^ "\n\nargs:\n" ^ String.concat "\n" args)); + let call = mk_apps func args in + mk_let effect_info.can_fail output call next_e and translate_end_abstraction_identity (ectx : C.eval_ctx) (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -2348,7 +2403,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) | V.LoopSynthInput -> (* Actually the same case as [SynthInput] *) translate_end_abstraction_synth_input ectx abs e ctx rg_id - | V.LoopCall -> + | V.LoopCall -> ( (* We need to introduce a call to the backward function corresponding to a forward call which happened earlier *) let fun_id = E.FRegular ctx.fun_decl.def_id in @@ -2419,9 +2474,8 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) let func_ty = mk_arrows input_tys ret_ty in let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in let func = { id = FunOrOp func; generics } in - { e = Qualif func; ty = func_ty } + Some { e = Qualif func; ty = func_ty } in - let call = mk_apps func args in (* **Optimization**: ================= We do a small optimization here in case we split the forward/backward @@ -2447,38 +2501,44 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) assert (List.length inputs = List.length fwd_inputs); next_e) else - (* Add meta-information - this is slightly hacky: we look at the - values consumed by the abstraction (note that those come from - *before* we applied the fixed-point context) and use them to - guide the naming of the output vars. - - Also, we need to convert the backward outputs from patterns to - variables. - - Finally, in practice, this works well only for loop bodies: - we do this only in this case. - TODO: improve the heuristics, to give weight to the hints for - instance. - *) - let next_e = - if ctx.inside_loop then - let consumed_values = abs_to_consumed ctx ectx abs in - let var_values = List.combine outputs consumed_values in - let var_values = - List.filter_map - (fun (var, v) -> - match var.Pure.value with - | PatVar (var, _) -> Some (var, v) - | _ -> None) - var_values + (* In case we merge the fwd/back functions we filter the backward + functions elsewhere *) + match func with + | None -> next_e + | Some func -> + let call = mk_apps func args in + (* Add meta-information - this is slightly hacky: we look at the + values consumed by the abstraction (note that those come from + *before* we applied the fixed-point context) and use them to + guide the naming of the output vars. + + Also, we need to convert the backward outputs from patterns to + variables. + + Finally, in practice, this works well only for loop bodies: + we do this only in this case. + TODO: improve the heuristics, to give weight to the hints for + instance. + *) + let next_e = + if ctx.inside_loop then + let consumed_values = abs_to_consumed ctx ectx abs in + let var_values = List.combine outputs consumed_values in + let var_values = + List.filter_map + (fun (var, v) -> + match var.Pure.value with + | PatVar (var, _) -> Some (var, v) + | _ -> None) + var_values + in + let vars, values = List.split var_values in + mk_emeta_symbolic_assignments vars values next_e + else next_e in - let vars, values = List.split var_values in - mk_emeta_symbolic_assignments vars values next_e - else next_e - in - (* Create the let-binding *) - mk_let effect_info.can_fail output call next_e + (* Create the let-binding *) + mk_let effect_info.can_fail output call next_e) and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -2894,7 +2954,7 @@ and translate_forward_end (ectx : C.eval_ctx) let _, back_vars = fresh_back_vars_for_current_fun ctx in (* Create the return expressions *) - let vars = fwd_var :: back_vars in + let vars = fwd_var :: List.filter_map (fun x -> x) back_vars in let vars = List.map mk_texpression_from_var vars in let ret = mk_simpl_tuple_texpression vars in let state_var = List.map mk_texpression_from_var state_var in @@ -2903,12 +2963,16 @@ and translate_forward_end (ectx : C.eval_ctx) (* Bind the expressions for the backward function and the expression for the computation of the forward output *) + let back_vars_els = + List.filter_map + (fun (v, el) -> match v with None -> None | Some v -> Some (v, el)) + (List.combine back_vars back_el) + in let e = List.fold_right (fun (var, back_e) e -> mk_let false (mk_typed_pattern_from_var var None) back_e e) - (List.combine back_vars back_el) - ret + back_vars_els ret in (* Bind the expression for the forward output *) let fwd_var = mk_typed_pattern_from_var fwd_var None in @@ -2976,12 +3040,18 @@ and translate_forward_end (ectx : C.eval_ctx) if !Config.return_back_funs then let ctx, back_vars = fresh_back_vars_for_current_fun ctx in let back_funs = - List.map (fun v -> mk_typed_pattern_from_var v None) back_vars + List.filter_map + (fun v -> + match v with + | None -> None + | Some v -> Some (mk_typed_pattern_from_var v None)) + back_vars in let gids = RegionGroupId.Map.keys ctx.sg.back_sg in let back_funs_map = RegionGroupId.Map.of_list - (List.combine gids (List.map mk_texpression_from_var back_vars)) + (List.combine gids + (List.map (Option.map mk_texpression_from_var) back_vars)) in (ctx, Some back_funs_map, back_funs) else (ctx, None, []) -- cgit v1.2.3 From ccfcadc3686e69c1b8a8c826ec14f3c0e1dfbd7b Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 17:08:37 +0100 Subject: Update the formatting of comments --- compiler/Extract.ml | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) (limited to 'compiler') diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 8d35f039..57360536 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -1323,18 +1323,16 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter) (Pure.FunId (FRegular def.def_id), def.loop_id, def.back_id) ctx.fun_name_info in - let comment_pre = "[" ^ name_to_string ctx def.llbc_name ^ "]: " in + let comment_pre = "[" ^ name_to_string ctx def.llbc_name ^ "]:" in let comment = let loop_comment = match def.loop_id with | None -> "" - | Some id -> "loop " ^ LoopId.to_string id ^ ": " + | Some id -> " loop " ^ LoopId.to_string id ^ ":" in let fwd_back_comment = match def.back_id with - | None -> - if !Config.return_back_funs then [ "function definition" ] - else [ "forward function" ] + | None -> if !Config.return_back_funs then [] else [ "forward function" ] | Some id -> (* Check if there is only one backward function, and no forward function *) if (not keep_fwd) && num_backs = 1 then @@ -1346,9 +1344,9 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter) else [ "backward function " ^ T.RegionGroupId.to_string id ] in match fwd_back_comment with - | [] -> raise (Failure "Unreachable") - | [ s ] -> [ comment_pre ^ loop_comment ^ s ] - | s :: sl -> (comment_pre ^ loop_comment ^ s) :: sl + | [] -> [ comment_pre ^ loop_comment ] + | [ s ] -> [ comment_pre ^ loop_comment ^ " " ^ s ] + | s :: sl -> (comment_pre ^ loop_comment ^ " " ^ s) :: sl in extract_comment_with_span fmt comment def.meta.span -- cgit v1.2.3 From 2f681446b11739e650b1d6050b717da872be9022 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 19:23:29 +0100 Subject: Simplify the type of the merged fwd/back functions --- compiler/Config.ml | 26 ++++++++ compiler/Pure.ml | 6 ++ compiler/PureMicroPasses.ml | 7 +- compiler/PureUtils.ml | 1 + compiler/SymbolicToPure.ml | 159 ++++++++++++++++++++++++++++++++------------ 5 files changed, 153 insertions(+), 46 deletions(-) (limited to 'compiler') diff --git a/compiler/Config.ml b/compiler/Config.ml index b8af6c6d..2bb1ca34 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -370,6 +370,32 @@ let filter_useless_monadic_calls = ref true *) let filter_useless_functions = ref true +(** Simplify the forward/backward functions, in case we merge them + (i.e., the forward functions return the backward functions). + + The simplification occurs as follows: + - if a forward function returns the unit type and has non-trivial backward + functions, then we remove the returned output. + - if a backward function doesn't have inputs, we evaluate it inside the + forward function and don't wrap it in a result. + + Example: + {[ + // LLBC: + fn incr(x: &mut u32) { *x += 1 } + + // Translation without simplification: + let incr (x : u32) : result (unit * result u32) = ... + ^^^^ ^^^^^^ + | remove this result + remove the unit + + // Translation with simplification: + let incr (x : u32) : result u32 = ... + ]} + *) +let simplify_merged_fwd_backs = ref true + (** Use short names for the record fields. Some backends can't disambiguate records when their field names have collisions. diff --git a/compiler/Pure.ml b/compiler/Pure.ml index ddacf0c4..05cdbd70 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -908,6 +908,11 @@ type fun_sig_info = { fwd_info : inputs_info; (** Information about the inputs of the forward function *) effect_info : fun_effect_info; + ignore_output : bool; + (** In case we merge the forward/backward functions: should we ignore + the output (happens for forward functions if the output type is + [unit] and there are non-filtered backward functions)? + *) } [@@deriving show] @@ -939,6 +944,7 @@ type back_sg_info = { We derive those from the names of the inputs of the original LLBC function. *) effect_info : fun_effect_info; + filter : bool; (** Should we filter this backward function? *) } [@@deriving show] diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 63436e7d..16bf1c08 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -1336,6 +1336,7 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) : let fun_sig = def.signature in let fwd_info = fun_sig.fwd_info in let fwd_effect_info = fwd_info.effect_info in + let ignore_output = fwd_info.ignore_output in (* Generate the loop definition *) let loop_fwd_effect_info = fwd_effect_info in @@ -1358,7 +1359,7 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) : } in - { fwd_info; effect_info = loop_fwd_effect_info } + { fwd_info; effect_info = loop_fwd_effect_info; ignore_output } in assert (fun_sig_info_is_wf loop_fwd_sig_info); @@ -2187,7 +2188,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : } = decl.signature in - let { fwd_info; effect_info } = fwd_info in + let { fwd_info; effect_info; ignore_output } = fwd_info in let { has_fuel; @@ -2212,7 +2213,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : } in - let fwd_info = { fwd_info; effect_info } in + let fwd_info = { fwd_info; effect_info; ignore_output } in assert (fun_sig_info_is_wf fwd_info); let signature = { diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index d4aaba16..78d0b120 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -448,6 +448,7 @@ let mk_simpl_tuple_ty (tys : ty list) : ty = let mk_bool_ty : ty = TLiteral TBool let mk_unit_ty : ty = TAdt (TTuple, empty_generic_args) +let ty_is_unit ty : bool = ty = mk_unit_ty let mk_unit_rvalue : texpression = let id = AdtCons { adt_id = TTuple; variant_id = None } in diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index f37ea201..70a4e18d 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -979,30 +979,6 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed in (* Compute the backward output, without the effect information *) let fwd_output = translate_fwd_ty type_infos sg.output in - (* The additinoal information *) - let fwd_info = - (* *) - let has_fuel = fwd_fuel <> [] in - let num_inputs_no_fuel_no_state = List.length fwd_inputs_no_fuel_no_state in - let num_inputs_with_fuel_no_state = - (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *) - List.length fwd_fuel + num_inputs_no_fuel_no_state - in - let fwd_info : inputs_info = - { - has_fuel; - num_inputs_no_fuel_no_state; - num_inputs_with_fuel_no_state; - num_inputs_with_fuel_with_state = - (* We use the fact that [fwd_state_ty] has length 1 if there is a state, - and 0 otherwise *) - num_inputs_with_fuel_no_state + List.length fwd_state_ty; - } - in - let info = { fwd_info; effect_info = fwd_effect_info } in - assert (fun_sig_info_is_wf info); - info - in (* Compute the type information for the backward function *) (* Small helper to translate types for backward functions *) @@ -1086,6 +1062,9 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed in let inputs = inputs_no_state @ state in let output_names, outputs = compute_back_outputs_for_gid gid in + let filter = + !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] + in let info = { inputs; @@ -1093,6 +1072,7 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed outputs; output_names; effect_info = back_effect_info; + filter; } in (gid, info) @@ -1102,6 +1082,39 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed (List.map compute_back_info_for_group regions_hierarchy) in + (* The additional information about the forward function *) + let fwd_info = + (* *) + let has_fuel = fwd_fuel <> [] in + let num_inputs_no_fuel_no_state = List.length fwd_inputs_no_fuel_no_state in + let num_inputs_with_fuel_no_state = + (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *) + List.length fwd_fuel + num_inputs_no_fuel_no_state + in + let fwd_info : inputs_info = + { + has_fuel; + num_inputs_no_fuel_no_state; + num_inputs_with_fuel_no_state; + num_inputs_with_fuel_with_state = + (* We use the fact that [fwd_state_ty] has length 1 if there is a state, + and 0 otherwise *) + num_inputs_with_fuel_no_state + List.length fwd_state_ty; + } + in + let ignore_output = + if !Config.return_back_funs && !Config.simplify_merged_fwd_backs then + ty_is_unit fwd_output + && List.exists + (fun (info : back_sg_info) -> not info.filter) + (RegionGroupId.Map.values back_sg) + else false + in + let info = { fwd_info; effect_info = fwd_effect_info; ignore_output } in + assert (fun_sig_info_is_wf info); + info + in + (* Generic parameters *) let generics = translate_generic_params sg.generics in @@ -1134,6 +1147,13 @@ let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty in if effect_info.can_fail then mk_result_ty output else output +let mk_back_output_ty_from_effect_info (effect_info : fun_effect_info) + (inputs : ty list) (ty : ty) : ty = + let output = + if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ] else ty + in + if effect_info.can_fail && inputs <> [] then mk_result_ty output else output + (** Compute the arrow types for all the backward functions. If a backward function has no inputs/outputs we filter it. @@ -1151,7 +1171,9 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) None else let output = mk_simpl_tuple_ty outputs in - let output = mk_output_ty_from_effect_info effect_info output in + let output = + mk_back_output_ty_from_effect_info effect_info inputs output + in let ty = mk_arrows inputs output in (* Substitute - TODO: normalize *) let ty = @@ -1166,6 +1188,25 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) Some ty) (RegionGroupId.Map.values dsg.back_sg) +(** In case we merge the fwd/back functions: compute the output type of + a function, from a decomposed signature. *) +let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty = + assert !Config.return_back_funs; + (* Compute the arrow types for all the backward functions *) + let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in + (* Group the forward output and the types of the backward functions *) + let effect_info = dsg.fwd_info.effect_info in + let output = + (* We might need to ignore the output of the forward function + (if it is unit for instance) *) + let tys = + if dsg.fwd_info.ignore_output then back_tys + else dsg.fwd_output :: back_tys + in + mk_simpl_tuple_ty tys + in + mk_output_ty_from_effect_info effect_info output + let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) (gid : RegionGroupId.id option) : fun_sig = let generics = dsg.generics in @@ -1180,19 +1221,12 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) (gid, info.effect_info)) (RegionGroupId.Map.bindings dsg.back_sg)) in - (* Two cases depending on whether we split the forward/backward functions - or not *) let mk_output_ty = mk_output_ty_from_effect_info in - let inputs, output = + (* Two cases depending on whether we split the forward/backward functions or not *) if !Config.return_back_funs then ( assert (gid = None); - (* Compute the arrow types for all the backward functions *) - let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in - (* Group the forward output and the types of the backward functions *) - let effect_info = dsg.fwd_info.effect_info in - let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in - let output = mk_output_ty effect_info output in + let output = compute_output_ty_from_decomposed dsg in let inputs = dsg.fwd_inputs in (inputs, output)) else @@ -1785,7 +1819,11 @@ and translate_panic (ctx : bs_ctx) : texpression = if !Config.return_back_funs then let back_tys = compute_back_tys ctx.sg None in let back_tys = List.filter_map (fun x -> x) back_tys in - let output = mk_simpl_tuple_ty (ctx.sg.fwd_output :: back_tys) in + let tys = + if ctx.sg.fwd_info.ignore_output then back_tys + else ctx.sg.fwd_output :: back_tys + in + let output = mk_simpl_tuple_ty tys in mk_output output else mk_output ctx.sg.fwd_output | Some bid -> @@ -1798,6 +1836,9 @@ and translate_panic (ctx : bs_ctx) : texpression = Remark: for now, we can't get there if we are inside a loop. If inside a loop, we use {!translate_return_with_loop}. + + Remark: in case we merge the forward/backward functions, we introduce + those in [translate_forward_end]. *) and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = @@ -2648,6 +2689,12 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) If (true_e, false_e) ) in let ty = true_e.ty in + log#ldebug + (lazy + ("true_e.ty: " + ^ pure_ty_to_string ctx true_e.ty + ^ "\n\nfalse_e.ty: " + ^ pure_ty_to_string ctx false_e.ty)); assert (ty = false_e.ty); { e; ty } | ExpandInt (int_ty, branches, otherwise) -> @@ -2941,37 +2988,63 @@ and translate_forward_end (ectx : C.eval_ctx) in let fwd_e = translate_one_end ctx None in - (* Introduce the backward functions *) + (* Introduce the backward functions. *) let back_el = List.map (fun ((gid, _) : RegionGroupId.id * back_sg_info) -> translate_one_end ctx (Some gid)) (RegionGroupId.Map.bindings ctx.sg.back_sg) in + + (* Compute whether the backward expressions should be evaluated straight + away or not (i.e., if we should bind them with monadic let-bindings + or not). We evaluate them straight away if they can fail and have no + inputs *) + let evaluate_backs = + List.map + (fun (sg : back_sg_info) -> + if !Config.simplify_merged_fwd_backs then + sg.inputs = [] && sg.effect_info.can_fail + else false) + (RegionGroupId.Map.values ctx.sg.back_sg) + in + (* Introduce variables for the backward functions. We lookup the LLBC definition in an attempt to derive pretty names for those functions. *) let _, back_vars = fresh_back_vars_for_current_fun ctx in (* Create the return expressions *) - let vars = fwd_var :: List.filter_map (fun x -> x) back_vars in + let vars = + let back_vars = List.filter_map (fun x -> x) back_vars in + if ctx.sg.fwd_info.ignore_output then back_vars + else fwd_var :: back_vars + in let vars = List.map mk_texpression_from_var vars in let ret = mk_simpl_tuple_texpression vars in let state_var = List.map mk_texpression_from_var state_var in let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in let ret = mk_result_return_texpression ret in - (* Bind the expressions for the backward function and the expression - for the computation of the forward output *) + (* Introduce all the let-bindings *) + + (* Combine: + - the backward variables + - whether we should evaluate the expression for the backward function + (i.e., should we use a monadic let-binding or not - we do if the + backward functions don't have inputs and can fail) + - the expressions for the backward functions + *) let back_vars_els = List.filter_map - (fun (v, el) -> match v with None -> None | Some v -> Some (v, el)) - (List.combine back_vars back_el) + (fun (v, (eval, el)) -> + match v with None -> None | Some v -> Some (v, eval, el)) + (List.combine back_vars (List.combine evaluate_backs back_el)) in let e = List.fold_right - (fun (var, back_e) e -> - mk_let false (mk_typed_pattern_from_var var None) back_e e) + (fun (var, evaluate, back_e) e -> + mk_let evaluate (mk_typed_pattern_from_var var None) back_e e) back_vars_els ret in (* Bind the expression for the forward output *) -- cgit v1.2.3 From 781638d204f90660caabe23946653437e9480374 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 19:45:36 +0100 Subject: Implement a micro-pass to simplify the let-bindings --- compiler/PureMicroPasses.ml | 78 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) (limited to 'compiler') diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 16bf1c08..7babe95b 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -645,6 +645,79 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = let body = { body with body = obj#visit_texpression () body.body } in { def with body = Some body } +(** Simplify the let-bindings by performing the following rewritings: + + Move inner let-bindings outside. This is especially useful to simplify + the backward expressions, when we merge the forward/backward functions. + Note that the rule is also applied with monadic let-bindings. + {[ + let x := + let y := ... in + e + + ~~> + + let y := ... in + let x := e + ]} + + Simplify panics and returns: + {[ + let x ← fail + ... + ~~> + fail + + let x ← return y + ... + ~~> + let x := y + ... + ]} + *) +let simplify_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = + let obj = + object (self) + inherit [_] map_expression as super + + method! visit_Let env monadic lv rv next = + match rv.e with + | Let (rmonadic, rlv, rrv, rnext) -> + (* Case 1: move the inner let outside then re-visit *) + let rnext1 = Let (monadic, lv, rnext, next) in + let rnext1 = { ty = next.ty; e = rnext1 } in + self#visit_Let env rmonadic rlv rrv rnext1 + | App + ( { + e = + Qualif + { + id = + AdtCons + { + adt_id = TAssumed TResult; + variant_id = Some variant_id; + }; + generics = _; + }; + ty = _; + }, + x ) -> + (* return/fail case *) + if variant_id = result_return_id then + (* Return case *) + super#visit_Let env false lv x next + else if variant_id = result_fail_id then (* Fail case *) rv.e + else raise (Failure "Unexpected") + | _ -> super#visit_Let env monadic lv rv next + end + in + match def.body with + | None -> def + | Some body -> + let body = { body with body = obj#visit_texpression () body.body } in + { def with body = Some body } + (** Inline the useless variable (re-)assignments: A lot of intermediate variable assignments are introduced through the @@ -1829,6 +1902,11 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = log#ldebug (lazy ("intro_struct_updates:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + (* Simplify the let-bindings *) + let def = simplify_let_bindings ctx def in + log#ldebug + (lazy ("simplify_let_bindings:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + (* Inline the useless variable reassignments *) let inline_named_vars = true in let inline_pure = true in -- cgit v1.2.3 From 0fb89f21a302210aa284e54a10129c46dbe8b4b5 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 19:51:10 +0100 Subject: Use indices starting at 1 to make variable names unique at code gen --- compiler/ExtractBase.ml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'compiler') diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index eb2a2ec9..0af7a9b4 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -351,7 +351,7 @@ let basename_to_unique (names_set : StringSet.t) let s = append basename i in if StringSet.mem s names_set then gen (i + 1) else s in - if StringSet.mem basename names_set then gen 0 else basename + if StringSet.mem basename names_set then gen 1 else basename type fun_name_info = { keep_fwd : bool; num_backs : int } -- cgit v1.2.3 From 6ee1063d98d82f6a3c0cf017834ec81cf012f0a1 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 20:00:35 +0100 Subject: Improve PureMicroPasses.filter_useless to simplify the matches --- compiler/PureMicroPasses.ml | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) (limited to 'compiler') diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 7babe95b..156fba29 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -1052,10 +1052,23 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) method! visit_expression env e = match e with | Var _ | CVar _ | Const _ | App _ | Qualif _ - | Switch (_, _) | Meta (_, _) | StructUpdate _ | Lambda _ -> super#visit_expression env e + | Switch (scrut, switch) -> ( + match switch with + | If (_, _) -> super#visit_expression env e + | Match branches -> + (* Simplify the branches *) + let simplify_branch (br : match_branch) = + (* Compute the set of values used inside the branch *) + let branch, used = self#visit_texpression env br.branch in + (* Simplify the pattern *) + let pat, _ = filter_typed_pattern (used ()) br.pat in + { pat; branch } + in + super#visit_expression env + (Switch (scrut, Match (List.map simplify_branch branches)))) | Let (monadic, lv, re, e) -> (* Compute the set of values used in the next expression *) let e, used = self#visit_texpression env e in -- cgit v1.2.3 From 266db04e97778911c93cfd1aac251de04bb25f53 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 22:17:11 +0100 Subject: Fix several issues --- compiler/Pure.ml | 17 ----- compiler/SymbolicToPure.ml | 186 ++++++++++++++++++++++++++++++++------------- compiler/Translate.ml | 29 +++---- 3 files changed, 151 insertions(+), 81 deletions(-) (limited to 'compiler') diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 05cdbd70..71531688 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -886,23 +886,6 @@ type inputs_info = { } [@@deriving show] -type ('a, 'b) back_info = - | SingleBack of 'a - (** Information about a single backward function, if pertinent. - - We use this variant if we split the forward and the backward functions. - *) - | AllBacks of 'b RegionGroupId.Map.t - (** Information about the various backward functions. - - We use this if we *do not* split the forward and the backward functions. - All the information is then carried by the forward function. - *) -[@@deriving show] - -type back_inputs_info = (inputs_info option, inputs_info) back_info -[@@deriving show] - (** Meta information about a function signature *) type fun_sig_info = { fwd_info : inputs_info; diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 70a4e18d..37f621e4 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -146,6 +146,7 @@ type bs_ctx = { global_ctx : global_ctx; trait_decls_ctx : trait_decls_ctx; trait_impls_ctx : trait_impls_ctx; + fun_dsigs : decomposed_fun_sig FunDeclId.Map.t; fun_decl : A.fun_decl; bid : RegionGroupId.id option; (** TODO: rename @@ -890,7 +891,7 @@ let mk_fuel_input_as_list (ctx : bs_ctx) (info : fun_effect_info) : if function_uses_fuel info then [ mk_fuel_texpression ctx.fuel ] else [] (** Small utility. *) -let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) +let compute_raw_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) (fun_id : A.fun_id_or_trait_method_ref) (lid : V.LoopId.id option) (gid : T.RegionGroupId.id option) : fun_effect_info = match fun_id with @@ -917,6 +918,22 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) is_rec = false; } +(** TODO: not very clean. *) +let get_fun_effect_info (ctx : bs_ctx) (fun_id : A.fun_id_or_trait_method_ref) + (lid : V.LoopId.id option) (gid : T.RegionGroupId.id option) : + fun_effect_info = + match fun_id with + | TraitMethod (_, _, fid) | FunId (FRegular fid) -> + let dsg = A.FunDeclId.Map.find fid ctx.fun_dsigs in + let info = + match gid with + | None -> dsg.fwd_info.effect_info + | Some gid -> (RegionGroupId.Map.find gid dsg.back_sg).effect_info + in + { info with is_rec = info.is_rec || Option.is_some lid } + | FunId (FAssumed _) -> + compute_raw_fun_effect_info ctx.fun_ctx.fun_infos fun_id lid gid + (** Translate a function signature to a decomposed function signature. Note that the function also takes a list of names for the inputs, and @@ -962,7 +979,9 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed in (* Is the forward function stateful, and can it fail? *) - let fwd_effect_info = get_fun_effect_info fun_infos fun_id None None in + let fwd_effect_info = + compute_raw_fun_effect_info fun_infos fun_id None None + in (* Compute the forward inputs *) let fwd_fuel = mk_fuel_input_ty_as_list fwd_effect_info in let fwd_inputs_no_fuel_no_state = @@ -1051,12 +1070,23 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed RegionGroupId.id * back_sg_info = let gid = rg.id in let back_effect_info = - get_fun_effect_info fun_infos fun_id None (Some gid) + compute_raw_fun_effect_info fun_infos fun_id None (Some gid) in let inputs_no_state = translate_back_inputs_for_gid gid in let inputs_no_state = List.map (fun ty -> (Some "ret", ty)) inputs_no_state in + (* We consider the backward function as stateful and potentially failing + **only if it has inputs** (for the "potentially failing": if it has + not inputs, we directly evaluate it in the body of the forward function). + *) + let back_effect_info = + { + back_effect_info with + stateful = back_effect_info.stateful && inputs_no_state <> []; + can_fail = back_effect_info.can_fail && inputs_no_state <> []; + } + in let state = if back_effect_info.stateful then [ (None, mk_state_ty) ] else [] in @@ -1140,6 +1170,19 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx (FunId (FRegular fun_id)) regions_hierarchy sg input_names +let translate_fun_sig_from_decl_to_decomposed (decls_ctx : C.decls_ctx) + (fdef : LlbcAst.fun_decl) : decomposed_fun_sig = + let input_names = + match fdef.body with + | None -> List.map (fun _ -> None) fdef.signature.inputs + | Some body -> + List.map + (fun (v : LlbcAst.var) -> v.name) + (LlbcAstUtils.fun_body_get_input_vars body) + in + translate_fun_sig_to_decomposed decls_ctx fdef.def_id fdef.signature + input_names + let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty = let output = @@ -1158,8 +1201,9 @@ let mk_back_output_ty_from_effect_info (effect_info : fun_effect_info) If a backward function has no inputs/outputs we filter it. *) -let compute_back_tys (dsg : Pure.decomposed_fun_sig) - (subst : (generic_args * trait_instance_id) option) : ty option list = +let compute_back_tys_with_info (dsg : Pure.decomposed_fun_sig) + (subst : (generic_args * trait_instance_id) option) : + (back_sg_info * ty) option list = List.map (fun (back_sg : back_sg_info) -> let effect_info = back_sg.effect_info in @@ -1185,9 +1229,13 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) in ty_substitute subst ty in - Some ty) + Some (back_sg, ty)) (RegionGroupId.Map.values dsg.back_sg) +let compute_back_tys (dsg : Pure.decomposed_fun_sig) + (subst : (generic_args * trait_instance_id) option) : ty option list = + List.map (Option.map snd) (compute_back_tys_with_info dsg subst) + (** In case we merge the fwd/back functions: compute the output type of a function, from a decomposed signature. *) let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty = @@ -1363,6 +1411,7 @@ let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var option list = in fresh_opt_vars back_vars ctx +(** IMPORTANT: do not use this one directly, but rather {!symbolic_value_to_texpression} *) let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var = match V.SymbolicValueId.Map.find_opt sv.sv_id ctx.sv_to_var with | Some v -> v @@ -1381,12 +1430,22 @@ let rec unbox_typed_value (v : V.typed_value) : V.typed_value = | _ -> raise (Failure "Unreachable")) | _ -> v -(** Translate a symbolic value *) +(** Translate a symbolic value. + + Because we do not necessarily introduce variables for the symbolic values + of (translated) type unit, it is important that we do not lookup variables + in case the symbolic value has type unit. + *) let symbolic_value_to_texpression (ctx : bs_ctx) (sv : V.symbolic_value) : texpression = (* Translate the type *) - let var = lookup_var_for_symbolic_value sv ctx in - mk_texpression_from_var var + let ty = ctx_translate_fwd_ty ctx sv.sv_ty in + (* If the type is unit, directly return unit *) + if ty_is_unit ty then mk_unit_rvalue + else + (* Otherwise lookup the variable *) + let var = lookup_var_for_symbolic_value sv ctx in + mk_texpression_from_var var (** Translate a typed value. @@ -1565,13 +1624,11 @@ and aproj_to_consumed (ctx : bs_ctx) (aproj : V.aproj) : texpression option = match aproj with | V.AEndedProjLoans (msv, []) -> (* The symbolic value was left unchanged *) - let var = lookup_var_for_symbolic_value msv ctx in - Some (mk_texpression_from_var var) + Some (symbolic_value_to_texpression ctx msv) | V.AEndedProjLoans (_, [ (mnv, child_aproj) ]) -> assert (child_aproj = AIgnoredProjBorrows); (* The symbolic value was updated *) - let var = lookup_var_for_symbolic_value mnv ctx in - Some (mk_texpression_from_var var) + Some (symbolic_value_to_texpression ctx mnv) | V.AEndedProjLoans (_, _) -> (* The symbolic value was updated, and the given back values come from sevearl * abstractions *) @@ -1940,10 +1997,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (List.combine args args_mplaces) in let dest_mplace = translate_opt_mplace call.dest_place in - let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in (* Retrieve the function id, and register the function call in the context * if necessary. *) - let ctx, fun_id, effect_info, args, back_funs, out_state = + let ctx, fun_id, effect_info, args, dest_v = match call.call_id with | S.Fun (fid, call_id) -> (* Regular function call *) @@ -1951,13 +2007,11 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let func = Fun (FromLlbc (fid_t, None, None)) in (* Retrieve the effect information about this function (can fail, * takes a state as input, etc.) *) - let effect_info = - get_fun_effect_info ctx.fun_ctx.fun_infos fid None None - in + let effect_info = get_fun_effect_info ctx fid None None in (* Depending on the function effects: - * - add the fuel - * - add the state input argument - * - generate a fresh state variable for the returned state + - add the fuel + - add the state input argument + - generate a fresh state variable for the returned state *) let args, ctx, out_state = let fuel = mk_fuel_input_as_list ctx effect_info in @@ -1970,7 +2024,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (* If we do not split the forward/backward functions: generate the variables for the backward functions returned by the forward function. *) - let ctx, back_funs_map, back_funs = + let ctx, ignore_fwd_output, back_funs_map, back_funs = if !Config.return_back_funs then (* We need to compute the signatures of the backward functions. *) let sg = Option.get call.sg in @@ -1981,7 +2035,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (List.map (fun _ -> None) sg.inputs) in let tr_self = UnknownTrait __FUNCTION__ in - let back_tys = compute_back_tys dsg (Some (generics, tr_self)) in + let back_tys = + compute_back_tys_with_info dsg (Some (generics, tr_self)) + in (* Introduce variables for the backward functions *) (* Compute a proper basename for the variables *) let back_fun_name = @@ -2016,7 +2072,18 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (fun ty -> match ty with | None -> None - | Some ty -> Some (Some back_fun_name, ty)) + | Some (back_sg, ty) -> + (* We insert a name for the variable only if the function + can fail: if it can fail, it means the call returns a backward + function. Otherwise, we it directly returns the value given + back by the backward function, which means we shouldn't + give it a name like "back..." (it doesn't make sense) *) + let name = + if back_sg.effect_info.can_fail then + Some back_fun_name + else None + in + Some (name, ty)) back_tys) ctx in @@ -2039,14 +2106,37 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let back_funs_map = RegionGroupId.Map.of_list (List.combine gids back_vars) in - (ctx, Some back_funs_map, back_funs) - else (ctx, None, []) + (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs) + else (ctx, false, None, []) + in + (* Compute the pattern for the destination *) + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + let dest = mk_typed_pattern_from_var dest dest_mplace in + let dest = + (* Here there is something subtle: as we might ignore the output + of the forward function (because it translates to unit) we doNOT + necessarily introduce in the let-binding the variable to which we + map the symbolic value which was introduced for the output of the + function call. This would be problematic if later we need to + translate this symbolic value, but we implemented + {!symbolic_value_to_texpression} so that it doesn't perform any + lookups if the symbolic value has type unit. + *) + let vars = + if ignore_fwd_output then back_funs else dest :: back_funs + in + mk_simpl_tuple_pattern vars + in + let dest = + match out_state with + | None -> dest + | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ] in (* Register the function call *) let ctx = bs_ctx_register_forward_call call_id call args back_funs_map ctx in - (ctx, func, effect_info, args, back_funs, out_state) + (ctx, func, effect_info, args, dest) | S.Unop E.Not -> let effect_info = { @@ -2057,7 +2147,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop Not, effect_info, args, [], None) + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + let dest = mk_typed_pattern_from_var dest dest_mplace in + (ctx, Unop Not, effect_info, args, dest) | S.Unop E.Neg -> ( match args with | [ arg ] -> @@ -2073,7 +2165,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop (Neg int_ty), effect_info, args, [], None) + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + let dest = mk_typed_pattern_from_var dest dest_mplace in + (ctx, Unop (Neg int_ty), effect_info, args, dest) | _ -> raise (Failure "Unreachable")) | S.Unop (E.Cast cast_kind) -> ( match cast_kind with @@ -2088,7 +2182,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, [], None) + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + let dest = mk_typed_pattern_from_var dest dest_mplace in + (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, dest) | CastFnPtr _ -> raise (Failure "TODO: function casts")) | S.Binop binop -> ( match args with @@ -2108,16 +2204,11 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Binop (binop, int_ty0), effect_info, args, [], None) + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + let dest = mk_typed_pattern_from_var dest dest_mplace in + (ctx, Binop (binop, int_ty0), effect_info, args, dest) | _ -> raise (Failure "Unreachable")) in - let dest_v = - let dest = mk_typed_pattern_from_var dest dest_mplace in - let dest = mk_simpl_tuple_pattern (dest :: back_funs) in - match out_state with - | None -> dest - | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ] - in let func = { id = FunOrOp fun_id; generics } in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in let ret_ty = @@ -2242,9 +2333,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) (* Those don't have backward functions *) raise (Failure "Unreachable") in - let effect_info = - get_fun_effect_info ctx.fun_ctx.fun_infos fun_id None (Some rg_id) - in + let effect_info = get_fun_effect_info ctx fun_id None (Some rg_id) in let generics = ctx_translate_fwd_generic_args ctx call.generics in (* Retrieve the original call and the parent abstractions *) let _forward, backwards = get_abs_ancestors ctx abs call_id in @@ -2449,8 +2538,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) to a forward call which happened earlier *) let fun_id = E.FRegular ctx.fun_decl.def_id in let effect_info = - get_fun_effect_info ctx.fun_ctx.fun_infos (FunId fun_id) (Some vloop_id) - (Some rg_id) + get_fun_effect_info ctx (FunId fun_id) (Some vloop_id) (Some rg_id) in let loop_info = LoopId.Map.find loop_id ctx.loops in let generics = loop_info.generics in @@ -2609,8 +2697,7 @@ and translate_assertion (ectx : C.eval_ctx) (v : V.typed_value) and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) (exp : S.expansion) (ctx : bs_ctx) : texpression = (* Translate the scrutinee *) - let scrutinee_var = lookup_var_for_symbolic_value sv ctx in - let scrutinee = mk_texpression_from_var scrutinee_var in + let scrutinee = symbolic_value_to_texpression ctx sv in let scrutinee_mplace = translate_opt_mplace p in (* Translate the branches *) match exp with @@ -2999,7 +3086,7 @@ and translate_forward_end (ectx : C.eval_ctx) (* Compute whether the backward expressions should be evaluated straight away or not (i.e., if we should bind them with monadic let-bindings or not). We evaluate them straight away if they can fail and have no - inputs *) + inputs. *) let evaluate_backs = List.map (fun (sg : back_sg_info) -> @@ -3098,9 +3185,7 @@ and translate_forward_end (ectx : C.eval_ctx) (* Lookup the effect info for the loop function *) let fid = E.FRegular ctx.fun_decl.def_id in - let effect_info = - get_fun_effect_info ctx.fun_ctx.fun_infos (FunId fid) None ctx.bid - in + let effect_info = get_fun_effect_info ctx (FunId fid) None ctx.bid in (* Introduce a fresh output value for the forward function *) let ctx, output_var = fresh_var None ctx.sg.fwd_output ctx in @@ -3479,8 +3564,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = | None -> None | Some body -> let effect_info = - get_fun_effect_info ctx.fun_ctx.fun_infos (FunId (FRegular def_id)) - None bid + get_fun_effect_info ctx (FunId (FRegular def_id)) None bid in let body = translate_expression body ctx in (* Add a match over the fuel, if necessary *) diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 5584fb9a..ccc46420 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -42,7 +42,8 @@ let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : fun_decl) : TODO: maybe we should introduce a record for this. *) let translate_function_to_pure (trans_ctx : trans_ctx) - (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : fun_decl) : + (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) + (fun_dsigs : Pure.decomposed_fun_sig FunDeclId.Map.t) (fdef : fun_decl) : pure_fun_translation_no_loops = (* Debug *) log#ldebug @@ -119,18 +120,8 @@ let translate_function_to_pure (trans_ctx : trans_ctx) !m in - let input_names = - match fdef.body with - | None -> List.map (fun _ -> None) fdef.signature.inputs - | Some body -> - List.map - (fun (v : var) -> v.name) - (LlbcAstUtils.fun_body_get_input_vars body) - in - let sg = - SymbolicToPure.translate_fun_sig_to_decomposed trans_ctx def_id - fdef.signature input_names + SymbolicToPure.translate_fun_sig_from_decl_to_decomposed trans_ctx fdef in let regions_hierarchy = @@ -154,6 +145,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) decls_ctx = trans_ctx; SymbolicToPure.bid = None; sg; + fun_dsigs; (* Will need to be updated for the backward functions *) sv_to_var; var_counter = ref var_counter; @@ -290,10 +282,21 @@ let translate_crate_to_pure (crate : crate) : (List.map (fun (def : Pure.type_decl) -> (def.def_id, def)) type_decls) in + (* Compute the decomposed fun sigs for the whole crate *) + let fun_dsigs = + FunDeclId.Map.of_list + (List.map + (fun (fdef : LlbcAst.fun_decl) -> + ( fdef.def_id, + SymbolicToPure.translate_fun_sig_from_decl_to_decomposed trans_ctx + fdef )) + (FunDeclId.Map.values crate.fun_decls)) + in + (* Translate all the *transparent* functions *) let pure_translations = List.map - (translate_function_to_pure trans_ctx type_decls_map) + (translate_function_to_pure trans_ctx type_decls_map fun_dsigs) (FunDeclId.Map.values crate.fun_decls) in -- cgit v1.2.3 From eae740d644f5ccd1ad2a7e853a9cdf303c8df61e Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 22:45:47 +0100 Subject: Fix issues when extracting stateful functions --- compiler/PrintPure.ml | 51 +++++++++++++++++++++++----------------------- compiler/SymbolicToPure.ml | 30 +++++++++++++-------------- 2 files changed, 40 insertions(+), 41 deletions(-) (limited to 'compiler') diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 1ce146a4..315dd512 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -611,35 +611,36 @@ and app_to_string (env : fmt_env) (inside : bool) (indent : string) * expression *) let app, generics = match app.e with - | Qualif qualif -> + | Qualif qualif -> ( (* Qualifier case *) - (* Convert the qualifier identifier *) - let qualif_s = - match qualif.id with - | FunOrOp fun_id -> fun_or_op_id_to_string env fun_id - | Global global_id -> global_decl_id_to_string env global_id - | AdtCons adt_cons_id -> - let variant_s = - adt_variant_to_string env adt_cons_id.adt_id - adt_cons_id.variant_id - in - ConstStrings.constructor_prefix ^ variant_s - | Proj { adt_id; field_id } -> - let adt_s = adt_variant_to_string env adt_id None in - let field_s = adt_field_to_string env adt_id field_id in - (* Adopting an F*-like syntax *) - ConstStrings.constructor_prefix ^ adt_s ^ "?." ^ field_s - | TraitConst (trait_ref, generics, const_name) -> - let trait_ref = trait_ref_to_string env true trait_ref in - let generics_s = generic_args_to_string env generics in + match qualif.id with + | FunOrOp fun_id -> + let generics = generic_args_to_strings env true qualif.generics in + let qualif_s = fun_or_op_id_to_string env fun_id in + (qualif_s, generics) + | Global global_id -> + let generics = generic_args_to_strings env true qualif.generics in + (global_decl_id_to_string env global_id, generics) + | AdtCons adt_cons_id -> + let variant_s = + adt_variant_to_string env adt_cons_id.adt_id + adt_cons_id.variant_id + in + (ConstStrings.constructor_prefix ^ variant_s, []) + | Proj { adt_id; field_id } -> + let adt_s = adt_variant_to_string env adt_id None in + let field_s = adt_field_to_string env adt_id field_id in + (* Adopting an F*-like syntax *) + (ConstStrings.constructor_prefix ^ adt_s ^ "?." ^ field_s, []) + | TraitConst (trait_ref, generics, const_name) -> + let trait_ref = trait_ref_to_string env true trait_ref in + let generics_s = generic_args_to_string env generics in + let qualif = if generics <> empty_generic_args then "(" ^ trait_ref ^ generics_s ^ ")." ^ const_name else trait_ref ^ "." ^ const_name - in - (* Convert the type instantiation *) - let generics = generic_args_to_strings env true qualif.generics in - (* *) - (qualif_s, generics) + in + (qualif, [])) | _ -> (* "Regular" expression case *) let inside = args <> [] || (args = [] && inside) in diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 37f621e4..7eb75584 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -2782,7 +2782,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) ^ pure_ty_to_string ctx true_e.ty ^ "\n\nfalse_e.ty: " ^ pure_ty_to_string ctx false_e.ty)); - assert (ty = false_e.ty); + if !Config.fail_hard then assert (ty = false_e.ty); { e; ty } | ExpandInt (int_ty, branches, otherwise) -> let translate_branch ((v, branch_e) : V.scalar_value * S.expression) : @@ -3005,7 +3005,7 @@ and translate_forward_end (ectx : C.eval_ctx) fresh_vars back_sg.inputs_no_state ctx in let ctx, backward_inputs_with_state = - if (ctx_get_effect_info ctx).stateful then + if back_sg.effect_info.stateful then let ctx, var, _ = bs_ctx_fresh_state_var ctx in (ctx, backward_inputs_no_state @ [ var ]) else (ctx, backward_inputs_no_state) @@ -3061,18 +3061,7 @@ and translate_forward_end (ectx : C.eval_ctx) if !Config.return_back_funs then (* Compute the output of the forward function *) let fwd_effect_info = ctx.sg.fwd_info.effect_info in - let output_ty = - let ty = ctx.sg.fwd_output in - if fwd_effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ] - else ty - in - let ctx, fwd_var = fresh_var None output_ty ctx in - let ctx, state_var, state_pat = - if fwd_effect_info.stateful then - let ctx, var, pat = bs_ctx_fresh_state_var ctx in - (ctx, [ var ], [ pat ]) - else (ctx, [], []) - in + let ctx, pure_fwd_var = fresh_var None ctx.sg.fwd_output ctx in let fwd_e = translate_one_end ctx None in (* Introduce the backward functions. *) @@ -3105,10 +3094,19 @@ and translate_forward_end (ectx : C.eval_ctx) let vars = let back_vars = List.filter_map (fun x -> x) back_vars in if ctx.sg.fwd_info.ignore_output then back_vars - else fwd_var :: back_vars + else pure_fwd_var :: back_vars in let vars = List.map mk_texpression_from_var vars in let ret = mk_simpl_tuple_texpression vars in + + (* Introduce a fresh input state variable for the forward expression *) + let _ctx, state_var, state_pat = + if fwd_effect_info.stateful then + let ctx, var, pat = bs_ctx_fresh_state_var ctx in + (ctx, [ var ], [ pat ]) + else (ctx, [], []) + in + let state_var = List.map mk_texpression_from_var state_var in let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in let ret = mk_result_return_texpression ret in @@ -3135,7 +3133,7 @@ and translate_forward_end (ectx : C.eval_ctx) back_vars_els ret in (* Bind the expression for the forward output *) - let fwd_var = mk_typed_pattern_from_var fwd_var None in + let fwd_var = mk_typed_pattern_from_var pure_fwd_var None in let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in mk_let fwd_effect_info.can_fail pat fwd_e e else translate_one_end ctx ctx.bid -- cgit v1.2.3 From 6dc2b0f0906adc5d6f8f2f48404cf21d3595c957 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 23:02:51 +0100 Subject: Improve the pure micro passes --- compiler/PureMicroPasses.ml | 36 +++++++++++++++++++++++++++++++++--- compiler/PureUtils.ml | 13 +++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) (limited to 'compiler') diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 156fba29..67495ab5 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -674,6 +674,16 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = let x := y ... ]} + + Simplify tuples: + {[ + let (y0, y1) := (x0, x1) in + ... + ~~> + let y0 = x0 in + let y1 = x1 in + ... + ]} *) let simplify_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = let obj = @@ -705,10 +715,30 @@ let simplify_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = x ) -> (* return/fail case *) if variant_id = result_return_id then - (* Return case *) - super#visit_Let env false lv x next - else if variant_id = result_fail_id then (* Fail case *) rv.e + (* Return case - note that the simplification we just perform + might have unlocked the tuple simplification below *) + self#visit_Let env false lv x next + else if variant_id = result_fail_id then + (* Fail case *) + self#visit_expression env rv.e else raise (Failure "Unexpected") + | App _ -> + (* This might be the tuple case *) + if not monadic then + match + (opt_dest_struct_pattern lv, opt_dest_tuple_texpression rv) + with + | Some pats, Some vals -> + (* Tuple case *) + let pat_vals = List.combine pats vals in + let e = + List.fold_right + (fun (pat, v) next -> mk_let false pat v next) + pat_vals next + in + super#visit_expression env e.e + | _ -> super#visit_Let env monadic lv rv next + else super#visit_Let env monadic lv rv next | _ -> super#visit_Let env monadic lv rv next end in diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 78d0b120..cc439e64 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -739,3 +739,16 @@ let rec destruct_lambdas (e : texpression) : typed_pattern list * texpression = let pats, e = destruct_lambdas e in (pat :: pats, e) | _ -> ([], e) + +let opt_dest_tuple_texpression (e : texpression) : texpression list option = + let app, args = destruct_apps e in + match app.e with + | Qualif { id = AdtCons { adt_id = TTuple; variant_id = None }; generics = _ } + -> + Some args + | _ -> None + +let opt_dest_struct_pattern (pat : typed_pattern) : typed_pattern list option = + match pat.value with + | PatAdt { variant_id = None; field_values } -> Some field_values + | _ -> None -- cgit v1.2.3 From 774eb319e514a0ba02473f9c82ee9d3355de8a3d Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 22 Dec 2023 11:09:10 +0100 Subject: Fix an issue when merging the fwd/back functions of trait methods --- compiler/InterpreterStatements.ml | 33 ++++++++++++++++++++++++--------- compiler/SymbolicAst.ml | 4 ++++ compiler/SymbolicToPure.ml | 26 +++++++++++++++++++++----- compiler/SynthesizeSymbolic.ml | 13 ++++++++----- 4 files changed, 57 insertions(+), 19 deletions(-) (limited to 'compiler') diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml index 94c65b5c..97c8bcd6 100644 --- a/compiler/InterpreterStatements.ml +++ b/compiler/InterpreterStatements.ml @@ -731,6 +731,7 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) : fun_id_or_trait_method_ref * generic_args + * (generic_args * trait_instance_id) option * fun_decl * region_var_groups * inst_fun_sig = @@ -758,7 +759,7 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) instantiate_fun_sig ctx func.generics tr_self def.signature regions_hierarchy in - (func.func, func.generics, def, regions_hierarchy, inst_sg) + (func.func, func.generics, None, def, regions_hierarchy, inst_sg) | FunId (FAssumed _) -> (* Unreachable: must be a transparent function *) raise (Failure "Unreachable") @@ -811,7 +812,12 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) we also need to update the generics. *) let func = FunId fid in - (func, generics, method_def, regions_hierarchy, inst_sg) + ( func, + generics, + Some (generics, tr_self), + method_def, + regions_hierarchy, + inst_sg ) | None -> (* If not found, lookup the methods provided by the trait *declaration* (remember: for now, we forbid overriding provided methods) *) @@ -867,6 +873,7 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) in ( func.func, func.generics, + Some (all_generics, tr_self), method_def, regions_hierarchy, inst_sg )) @@ -900,8 +907,12 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx) instantiate_fun_sig ctx generics tr_self method_def.signature regions_hierarchy in - (func.func, func.generics, method_def, regions_hierarchy, inst_sg) - )) + ( func.func, + func.generics, + Some (generics, tr_self), + method_def, + regions_hierarchy, + inst_sg ))) (** Evaluate a statement *) let rec eval_statement (config : config) (st : statement) : st_cm_fun = @@ -1287,14 +1298,15 @@ and eval_transparent_function_call_concrete (config : config) and eval_transparent_function_call_symbolic (config : config) (call : call) : st_cm_fun = fun cf ctx -> - let func, generics, def, regions_hierarchy, inst_sg = + let func, generics, trait_method_generics, def, regions_hierarchy, inst_sg = eval_transparent_function_call_symbolic_inst call ctx in (* Sanity check *) assert (List.length call.args = List.length def.signature.inputs); (* Evaluate the function call *) eval_function_call_symbolic_from_inst_sig config func def.signature - regions_hierarchy inst_sg generics call.args call.dest cf ctx + regions_hierarchy inst_sg generics trait_method_generics call.args call.dest + cf ctx (** Evaluate a function call in symbolic mode by using the function signature. @@ -1310,7 +1322,9 @@ and eval_transparent_function_call_symbolic (config : config) (call : call) : and eval_function_call_symbolic_from_inst_sig (config : config) (fid : fun_id_or_trait_method_ref) (sg : fun_sig) (regions_hierarchy : region_var_groups) (inst_sg : inst_fun_sig) - (generics : generic_args) (args : operand list) (dest : place) : st_cm_fun = + (generics : generic_args) + (trait_method_generics : (generic_args * trait_instance_id) option) + (args : operand list) (dest : place) : st_cm_fun = fun cf ctx -> log#ldebug (lazy @@ -1390,7 +1404,8 @@ and eval_function_call_symbolic_from_inst_sig (config : config) (* Synthesize the symbolic AST *) S.synthesize_regular_function_call fid call_id ctx sg regions_hierarchy - abs_ids generics args args_places ret_spc dest_place expr + abs_ids generics trait_method_generics args args_places ret_spc dest_place + expr in let cc = comp cc cf_call in @@ -1500,7 +1515,7 @@ and eval_assumed_function_call_symbolic (config : config) (fid : assumed_fun_id) (* Evaluate the function call *) eval_function_call_symbolic_from_inst_sig config (FunId (FAssumed fid)) sg - regions_hierarchy inst_sig generics args dest cf ctx + regions_hierarchy inst_sig generics None args dest cf ctx (** Evaluate a statement seen as a function body *) and eval_function_body (config : config) (body : statement) : st_cm_fun = diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml index 54d207d9..8e8cdec3 100644 --- a/compiler/SymbolicAst.ml +++ b/compiler/SymbolicAst.ml @@ -48,6 +48,10 @@ type call = { abstractions : AbstractionId.id list; (** The region abstractions introduced upon calling the function *) generics : generic_args; + trait_method_generics : (generic_args * trait_instance_id) option; + (** In case the call is to a trait method, we may need an additional type + parameter ([Self]) and the self trait clause to instantiate the + function signature. *) args : typed_value list; args_places : mplace option list; (** Meta information *) dest : symbolic_value; diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 7eb75584..41922cb5 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -1985,7 +1985,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : texpression = log#ldebug (lazy - ("translate_function_call:\n" + ("translate_function_call:\n" ^ "\n- call.call_id:" + ^ S.show_call_id call.call_id + ^ "\n\n- call.generics:\n" ^ ctx_generic_args_to_string ctx call.generics)); (* Translate the function call *) let generics = ctx_translate_fwd_generic_args ctx call.generics in @@ -2025,7 +2027,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : variables for the backward functions returned by the forward function. *) let ctx, ignore_fwd_output, back_funs_map, back_funs = - if !Config.return_back_funs then + if !Config.return_back_funs then ( (* We need to compute the signatures of the backward functions. *) let sg = Option.get call.sg in let decls_ctx = ctx.decls_ctx in @@ -2034,9 +2036,23 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : fid call.regions_hierarchy sg (List.map (fun _ -> None) sg.inputs) in - let tr_self = UnknownTrait __FUNCTION__ in + log#ldebug + (lazy ("dsg.generics:\n" ^ show_generic_params dsg.generics)); + let tr_self, all_generics = + match call.trait_method_generics with + | None -> (UnknownTrait __FUNCTION__, generics) + | Some (all_generics, tr_self) -> + let all_generics = + ctx_translate_fwd_generic_args ctx all_generics + in + let tr_self = + translate_fwd_trait_instance_id ctx.type_ctx.type_infos + tr_self + in + (tr_self, all_generics) + in let back_tys = - compute_back_tys_with_info dsg (Some (generics, tr_self)) + compute_back_tys_with_info dsg (Some (all_generics, tr_self)) in (* Introduce variables for the backward functions *) (* Compute a proper basename for the variables *) @@ -2106,7 +2122,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let back_funs_map = RegionGroupId.Map.of_list (List.combine gids back_vars) in - (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs) + (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs)) else (ctx, false, None, []) in (* Compute the pattern for the destination *) diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml index 4ec7524b..865185a8 100644 --- a/compiler/SynthesizeSymbolic.ml +++ b/compiler/SynthesizeSymbolic.ml @@ -95,6 +95,7 @@ let synthesize_symbolic_expansion_no_branching (sv : symbolic_value) let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx) (sg : fun_sig option) (regions_hierarchy : region_var_groups) (abstractions : AbstractionId.id list) (generics : generic_args) + (trait_method_generics : (generic_args * trait_instance_id) option) (args : typed_value list) (args_places : mplace option list) (dest : symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = @@ -108,6 +109,7 @@ let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx) regions_hierarchy; abstractions; generics; + trait_method_generics; args; dest; args_places; @@ -125,19 +127,20 @@ let synthesize_regular_function_call (fun_id : fun_id_or_trait_method_ref) (call_id : FunCallId.id) (ctx : Contexts.eval_ctx) (sg : fun_sig) (regions_hierarchy : region_var_groups) (abstractions : AbstractionId.id list) (generics : generic_args) + (trait_method_generics : (generic_args * trait_instance_id) option) (args : typed_value list) (args_places : mplace option list) (dest : symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = synthesize_function_call (Fun (fun_id, call_id)) - ctx (Some sg) regions_hierarchy abstractions generics args args_places dest - dest_place e + ctx (Some sg) regions_hierarchy abstractions generics trait_method_generics + args args_places dest dest_place e let synthesize_unary_op (ctx : Contexts.eval_ctx) (unop : unop) (arg : typed_value) (arg_place : mplace option) (dest : symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = let generics = empty_generic_args in - synthesize_function_call (Unop unop) ctx None [] [] generics [ arg ] + synthesize_function_call (Unop unop) ctx None [] [] generics None [ arg ] [ arg_place ] dest dest_place e let synthesize_binary_op (ctx : Contexts.eval_ctx) (binop : binop) @@ -145,8 +148,8 @@ let synthesize_binary_op (ctx : Contexts.eval_ctx) (binop : binop) (arg1_place : mplace option) (dest : symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = let generics = empty_generic_args in - synthesize_function_call (Binop binop) ctx None [] [] generics [ arg0; arg1 ] - [ arg0_place; arg1_place ] dest dest_place e + synthesize_function_call (Binop binop) ctx None [] [] generics None + [ arg0; arg1 ] [ arg0_place; arg1_place ] dest dest_place e let synthesize_end_abstraction (ctx : Contexts.eval_ctx) (abs : abs) (e : expression option) : expression option = -- cgit v1.2.3 From a504199331e1b406d24067837a725085fb8f09e9 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 22 Dec 2023 18:08:26 +0100 Subject: Slightly update the formatting of the do blocks --- compiler/Extract.ml | 50 ++++++++++++++++++++++++++------------------------ 1 file changed, 26 insertions(+), 24 deletions(-) (limited to 'compiler') diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 57360536..04ad3b75 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -266,6 +266,21 @@ let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter) extract_adt_g_value extract_value fmt ctx is_let inside av.variant_id av.field_values v.ty +(** Return true if we need to wrap a succession of let-bindings in a [do ...] + block (because some of them are monadic) *) +let lets_require_wrap_in_do (lets : (bool * typed_pattern * texpression) list) : + bool = + match !backend with + | Lean -> + (* For Lean, we wrap in a block iff at least one of the let-bindings is monadic *) + List.exists (fun (m, _, _) -> m) lets + | HOL4 -> + (* HOL4 is similar to HOL4, but we add a sanity check *) + let wrap_in_do = List.exists (fun (m, _, _) -> m) lets in + if wrap_in_do then assert (List.for_all (fun (m, _, _) -> m) lets); + wrap_in_do + | FStar | Coq -> false + (** [inside]: controls the introduction of parentheses. See [extract_ty] TODO: replace the formatting boolean [inside] with something more general? @@ -634,15 +649,6 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) | HOL4 -> destruct_lets_no_interleave e | FStar | Coq | Lean -> destruct_lets e in - (* Open a box for the whole expression. - - In the case of Lean, we use a vbox so that line breaks are inserted - at the end of every let-binding: let-bindings are indeed not ended - with an "in" keyword. - *) - if !Config.backend = Lean then F.pp_open_vbox fmt 0 else F.pp_open_hvbox fmt 0; - (* Open parentheses *) - if inside && !backend <> Lean then F.pp_print_string fmt "("; (* Extract the let-bindings *) let extract_let (ctx : extraction_ctx) (monadic : bool) (lv : typed_pattern) (re : texpression) : extraction_ctx = @@ -715,22 +721,19 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (* Return *) ctx in + (* Open a box for the whole expression. + + In the case of Lean, we use a vbox so that line breaks are inserted + at the end of every let-binding: let-bindings are indeed not ended + with an "in" keyword. + *) + if !Config.backend = Lean then F.pp_open_vbox fmt 0 else F.pp_open_hvbox fmt 0; + (* Open parentheses *) + if inside && !backend <> Lean then F.pp_print_string fmt "("; (* If Lean and HOL4, we rely on monadic blocks, so we insert a do and open a new box immediately *) - let wrap_in_do_od = - match !backend with - | Lean -> - (* For Lean, we wrap in a block iff at least one of the let-bindings is monadic *) - List.exists (fun (m, _, _) -> m) lets - | HOL4 -> - (* HOL4 is similar to HOL4, but we add a sanity check *) - let wrap_in_do = List.exists (fun (m, _, _) -> m) lets in - if wrap_in_do then assert (List.for_all (fun (m, _, _) -> m) lets); - wrap_in_do - | FStar | Coq -> false - in + let wrap_in_do_od = lets_require_wrap_in_do lets in if wrap_in_do_od then ( - F.pp_open_vbox fmt (if !backend = Lean then ctx.indent_incr else 0); F.pp_print_string fmt "do"; F.pp_print_space fmt ()); let ctx = @@ -746,11 +749,10 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) F.pp_close_box fmt (); (* do-box (Lean and HOL4 only) *) - if wrap_in_do_od then ( + if wrap_in_do_od then if !backend = HOL4 then ( F.pp_print_space fmt (); F.pp_print_string fmt "od"); - F.pp_close_box fmt ()); (* Close parentheses *) if inside && !backend <> Lean then F.pp_print_string fmt ")"; (* Close the box for the whole expression *) -- cgit v1.2.3 From 455ba366f9c8d07a1f1848ec0960b1f2d161e7cf Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 22 Dec 2023 18:47:50 +0100 Subject: Update the library for F* --- compiler/ExtractBuiltin.ml | 8 ++++++++ compiler/Translate.ml | 8 +++++++- 2 files changed, 15 insertions(+), 1 deletion(-) (limited to 'compiler') diff --git a/compiler/ExtractBuiltin.ml b/compiler/ExtractBuiltin.ml index 24d16dca..ee8d4831 100644 --- a/compiler/ExtractBuiltin.ml +++ b/compiler/ExtractBuiltin.ml @@ -232,6 +232,14 @@ let builtin_funs () : (pattern * bool list option * builtin_fun_info list) list let mk_fun (rust_name : string) (extract_name : string option) (filter : bool list option) (with_back : bool) (back_no_suffix : bool) : pattern * bool list option * builtin_fun_info list = + (* [back_no_suffix] is used to control whether the backward function should + have the suffix "_back" or not (if not, then the forward function has the + prefix "_fwd", and is filtered anyway). This is pertinent only if we split + the fwd/back functions. *) + let back_no_suffix = back_no_suffix && not !Config.return_back_funs in + (* Same for the [with_back] option: this is pertinent only if we split + the fwd/back functions *) + let with_back = with_back && not !Config.return_back_funs in let rust_name = try parse_pattern rust_name with Failure _ -> diff --git a/compiler/Translate.ml b/compiler/Translate.ml index ccc46420..55a94302 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -1171,7 +1171,13 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) : let exe_dir = Filename.dirname Sys.argv.(0) in let primitives_src_dest = match !Config.backend with - | FStar -> Some ("/backends/fstar/Primitives.fst", "Primitives.fst") + | FStar -> + let src = + if !Config.return_back_funs then + "/backends/fstar/merge/Primitives.fst" + else "/backends/fstar/split/Primitives.fst" + in + Some (src, "Primitives.fst") | Coq -> Some ("/backends/coq/Primitives.v", "Primitives.v") | Lean -> None | HOL4 -> None -- cgit v1.2.3 From 3688596f27a1ba461f48e88446b8812ec73f1a2f Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 22 Dec 2023 19:09:16 +0100 Subject: Add an option to split the fwd/back functions and fix a minor issue --- compiler/Main.ml | 3 +++ compiler/SymbolicToPure.ml | 19 ++++++++++++------- 2 files changed, 15 insertions(+), 7 deletions(-) (limited to 'compiler') diff --git a/compiler/Main.ml b/compiler/Main.ml index 835b9088..abc27b46 100644 --- a/compiler/Main.ml +++ b/compiler/Main.ml @@ -120,6 +120,9 @@ let () = " Generate a default lakefile.lean (Lean only)" ); ("-print-llbc", Arg.Set print_llbc, " Print the imported LLBC"); ("-k", Arg.Clear fail_hard, " Do not fail hard in case of error"); + ( "-split-fwd-back", + Arg.Clear return_back_funs, + " Split the forward and backward functions." ); ] in diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 41922cb5..4674b61c 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -1076,16 +1076,20 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed let inputs_no_state = List.map (fun ty -> (Some "ret", ty)) inputs_no_state in - (* We consider the backward function as stateful and potentially failing + (* In case we merge the forward/backward functions: + we consider the backward function as stateful and potentially failing **only if it has inputs** (for the "potentially failing": if it has not inputs, we directly evaluate it in the body of the forward function). *) let back_effect_info = - { - back_effect_info with - stateful = back_effect_info.stateful && inputs_no_state <> []; - can_fail = back_effect_info.can_fail && inputs_no_state <> []; - } + if !Config.return_back_funs then + let b = inputs_no_state <> [] in + { + back_effect_info with + stateful = back_effect_info.stateful && b; + can_fail = back_effect_info.can_fail && b; + } + else back_effect_info in let state = if back_effect_info.stateful then [ (None, mk_state_ty) ] else [] @@ -1093,7 +1097,8 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed let inputs = inputs_no_state @ state in let output_names, outputs = compute_back_outputs_for_gid gid in let filter = - !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] + !Config.simplify_merged_fwd_backs + && !Config.return_back_funs && inputs = [] && outputs = [] in let info = { -- cgit v1.2.3 From 29f358f4072ee4c6530b4c523a1754d4c0723893 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 22 Dec 2023 19:17:09 +0100 Subject: Fix a minor extraction issue --- compiler/Extract.ml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'compiler') diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 04ad3b75..3d9f0c22 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -610,7 +610,7 @@ and extract_Lambda (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) ctx xl in F.pp_print_space fmt (); - if !backend = Lean then F.pp_print_string fmt "=>" + if !backend = Lean || !backend = Coq then F.pp_print_string fmt "=>" else F.pp_print_string fmt "->"; F.pp_print_space fmt (); (* Print the body *) -- cgit v1.2.3 From 719263b7bb727bdb432f66709b8c1eadc47ba922 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 22 Dec 2023 19:23:39 +0100 Subject: Annotate the bound vars in the lambdas for Coq --- compiler/Extract.ml | 56 ++++++++++++++++++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 20 deletions(-) (limited to 'compiler') diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 3d9f0c22..30b76ceb 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -241,30 +241,45 @@ let fun_builtin_filter_types (id : FunDeclId.id) (types : 'a list) Result.Ok types (** [inside]: see {!extract_ty}. + [with_type]: do we also generate a type annotation? This is necessary for + backends like Coq when we write lambdas (Coq is not powerful enough to + infer the type). As a pattern can introduce new variables, we return an extraction context updated with new bindings. *) let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter) - (is_let : bool) (inside : bool) (v : typed_pattern) : extraction_ctx = - match v.value with - | PatConstant cv -> - extract_literal fmt inside cv; - ctx - | PatVar (v, _) -> - let vname = ctx_compute_var_basename ctx v.basename v.ty in - let ctx, vname = ctx_add_var vname v.id ctx in - F.pp_print_string fmt vname; - ctx - | PatDummy -> - F.pp_print_string fmt "_"; - ctx - | PatAdt av -> - let extract_value ctx inside v = - extract_typed_pattern ctx fmt is_let inside v - in - extract_adt_g_value extract_value fmt ctx is_let inside av.variant_id - av.field_values v.ty + (is_let : bool) (inside : bool) ?(with_type = false) (v : typed_pattern) : + extraction_ctx = + if with_type then F.pp_print_string fmt "("; + let inside = inside && not with_type in + let ctx = + match v.value with + | PatConstant cv -> + extract_literal fmt inside cv; + ctx + | PatVar (v, _) -> + let vname = ctx_compute_var_basename ctx v.basename v.ty in + let ctx, vname = ctx_add_var vname v.id ctx in + F.pp_print_string fmt vname; + ctx + | PatDummy -> + F.pp_print_string fmt "_"; + ctx + | PatAdt av -> + let extract_value ctx inside v = + extract_typed_pattern ctx fmt is_let inside v + in + extract_adt_g_value extract_value fmt ctx is_let inside av.variant_id + av.field_values v.ty + in + if with_type then ( + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_ty ctx fmt TypeDeclId.Set.empty false v.ty; + F.pp_print_string fmt ")"); + ctx (** Return true if we need to wrap a succession of let-bindings in a [do ...] block (because some of them are monadic) *) @@ -602,11 +617,12 @@ and extract_Lambda (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (* Print the lambda - note that there should always be at least one variable *) assert (xl <> []); F.pp_print_string fmt "fun"; + let with_type = !backend = Coq in let ctx = List.fold_left (fun ctx x -> F.pp_print_space fmt (); - extract_typed_pattern ctx fmt true true x) + extract_typed_pattern ctx fmt true true ~with_type x) ctx xl in F.pp_print_space fmt (); -- cgit v1.2.3 From b230ddacd44a1ca1804940bf89253bde8de7ffe1 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 22 Dec 2023 20:12:00 +0100 Subject: Fix a minor issue with the extraction of loops when merging the fwd/back functions --- compiler/SymbolicToPure.ml | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) (limited to 'compiler') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 4674b61c..cd367d83 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -123,7 +123,7 @@ type loop_info = { generics : generic_args; forward_inputs : texpression list option; (** The forward inputs are initialized at [None] *) - forward_output_no_state_no_result : var option; + forward_output_no_state_no_result : texpression option; (** The forward outputs are initialized at [None] *) back_outputs : ty list RegionGroupId.Map.t; (** The map from region group ids to the types of the values given back @@ -1956,10 +1956,7 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) *) let output = match ctx.bid with - | None -> - (* Forward *) - mk_texpression_from_var - (Option.get loop_info.forward_output_no_state_no_result) + | None -> Option.get loop_info.forward_output_no_state_no_result | Some _ -> (* Backward *) (* Group the variables in which we stored the values we need to give back. @@ -1984,7 +1981,7 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) else output in (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *) - mk_result_return_texpression output + mk_emeta (Tag "return_with_loop") (mk_result_return_texpression output) and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -3207,7 +3204,20 @@ and translate_forward_end (ectx : C.eval_ctx) let effect_info = get_fun_effect_info ctx (FunId fid) None ctx.bid in (* Introduce a fresh output value for the forward function *) - let ctx, output_var = fresh_var None ctx.sg.fwd_output ctx in + let ctx, fwd_output, output_pat = + if ctx.sg.fwd_info.ignore_output then + (* Note that we still need the forward output (which is unit), + because even though the loop function will ignore the forward output, + the forward expression will still compute an output (which + will have type unit - otherwise we can't ignore it). *) + (ctx, mk_unit_rvalue, []) + else + let ctx, output_var = fresh_var None ctx.sg.fwd_output ctx in + ( ctx, + mk_texpression_from_var output_var, + [ mk_typed_pattern_from_var output_var None ] ) + in + (* Introduce fresh variables for the backward functions of the loop. For now, the backward functions of the loop are the same as the @@ -3236,10 +3246,8 @@ and translate_forward_end (ectx : C.eval_ctx) (* Introduce patterns *) let args, ctx, out_pats = - (* Create the pattern for the output value *) - let output_pat = mk_typed_pattern_from_var output_var None in (* Add the returned backward functions (they might be empty) *) - let output_pat = mk_simpl_tuple_pattern (output_pat :: back_funs) in + let output_pat = mk_simpl_tuple_pattern (output_pat @ back_funs) in (* Depending on the function effects: * - add the fuel @@ -3261,7 +3269,7 @@ and translate_forward_end (ectx : C.eval_ctx) { loop_info with forward_inputs = Some args; - forward_output_no_state_no_result = Some output_var; + forward_output_no_state_no_result = Some fwd_output; back_funs = back_funs_map; } in -- cgit v1.2.3 From 70d506d148e5ae1a3e4115034161f449aff666ed Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 22 Dec 2023 21:03:17 +0100 Subject: Fix the output type of the loops backward functions --- compiler/PrintPure.ml | 11 ++------ compiler/Pure.ml | 4 +-- compiler/PureMicroPasses.ml | 25 +++-------------- compiler/PureTypeCheck.ml | 6 ----- compiler/SymbolicToPure.ml | 65 ++++++++++++++++++++++++++++++++++++++++----- 5 files changed, 65 insertions(+), 46 deletions(-) (limited to 'compiler') diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 315dd512..66475d02 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -711,21 +711,14 @@ and loop_to_string (env : fmt_env) (indent : string) (indent_incr : string) ^ String.concat "; " (List.map (var_to_string env) loop.inputs) ^ "]" in - let back_output_tys = - let tys = - match loop.back_output_tys with - | None -> "" - | Some tys -> String.concat "; " (List.map (ty_to_string env false) tys) - in - "back_output_tys: [" ^ tys ^ "]" - in + let output_ty = "output_ty: " ^ ty_to_string env false loop.output_ty in let fun_end = texpression_to_string env false indent2 indent_incr loop.fun_end in let loop_body = texpression_to_string env false indent2 indent_incr loop.loop_body in - "loop {\n" ^ indent1 ^ loop_inputs ^ "\n" ^ indent1 ^ back_output_tys ^ "\n" + "loop {\n" ^ indent1 ^ loop_inputs ^ "\n" ^ indent1 ^ output_ty ^ "\n" ^ indent1 ^ "fun_end: {\n" ^ indent2 ^ fun_end ^ "\n" ^ indent1 ^ "}\n" ^ indent1 ^ "loop_body: {\n" ^ indent2 ^ loop_body ^ "\n" ^ indent1 ^ "}\n" ^ indent ^ "}" diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 71531688..a879ba37 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -754,9 +754,7 @@ and loop = { inputs : var list; inputs_lvs : typed_pattern list; (** The inputs seen as patterns. See {!fun_body}. *) - back_output_tys : ty list option; - (** The types of the given back values, if we ar esynthesizing a backward - function *) + output_ty : ty; (** The output type of the loop *) loop_body : texpression; } diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 67495ab5..e7e9d5e1 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -459,7 +459,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl = input_state; inputs; inputs_lvs; - back_output_tys; + output_ty; loop_body; } = loop @@ -478,7 +478,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl = input_state; inputs; inputs_lvs; - back_output_tys; + output_ty; loop_body; } in @@ -1498,26 +1498,7 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) : List.concat [ fuel; fwd_inputs; fwd_state; back_inputs ] in - let output = - match loop.back_output_tys with - | None -> - (* Forward function: the return type is the same as the - parent function *) - fun_sig.output - | Some doutputs -> - (* Backward function: custom return type *) - let output = mk_simpl_tuple_ty doutputs in - let output = - if loop_fwd_effect_info.stateful then - mk_simpl_tuple_ty [ mk_state_ty; output ] - else output - in - let output = - if loop_fwd_effect_info.can_fail then mk_result_ty output - else output - in - output - in + let output = loop.output_ty in let loop_sig = { diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index d60d6a05..a989fd3b 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -188,12 +188,6 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = List.iter check_branch branches) | Loop loop -> assert (loop.fun_end.ty = e.ty); - (* If we translate forward functions, the type of the loop is the same - as the type of the parent expression - in case of backward functions, - the loop doesn't necessarily give back the same values as the parent - function - *) - assert (Option.is_some loop.back_output_tys || loop.loop_body.ty = e.ty); check_texpression ctx loop.fun_end; check_texpression ctx loop.loop_body | StructUpdate supd -> ( diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index cd367d83..bf92482a 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -3368,7 +3368,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = (* Compute the backward outputs *) let ctx = ref ctx in let rg_to_given_back_tys = - T.RegionGroupId.Map.map + RegionGroupId.Map.map (fun (_, tys) -> (* The types shouldn't contain borrows - we can translate them as forward types *) List.map @@ -3380,10 +3380,63 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = in let ctx = !ctx in - let back_output_tys = - match ctx.bid with - | None -> None - | Some rg_id -> Some (T.RegionGroupId.Map.find rg_id rg_to_given_back_tys) + (* The output type of the loop function *) + let output_ty = + if !Config.return_back_funs then + (* The loop backward functions consume the same additional inputs as the parent + function, but have custom outputs *) + let back_sgs = RegionGroupId.Map.values ctx.sg.back_sg in + let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in + let back_tys = + List.filter_map + (fun ((back_sg, given_back) : back_sg_info * ty list) -> + let effect_info = back_sg.effect_info in + (* Compute the input/output types *) + let inputs = List.map snd back_sg.inputs in + let outputs = given_back in + (* Filter if necessary *) + if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] + then None + else + let output = mk_simpl_tuple_ty outputs in + let output = + mk_back_output_ty_from_effect_info effect_info inputs output + in + let ty = mk_arrows inputs output in + Some ty) + (List.combine back_sgs given_back_tys) + in + let output = + if ctx.sg.fwd_info.ignore_output then back_tys + else ctx.sg.fwd_output :: back_tys + in + let output = mk_simpl_tuple_ty output in + let effect_info = ctx.sg.fwd_info.effect_info in + let output = + if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] + else output + in + if effect_info.can_fail && inputs <> [] then mk_result_ty output + else output + else + match ctx.bid with + | None -> + (* Forward function: same type as the parent function *) + (translate_fun_sig_from_decomposed ctx.sg None).output + | Some rg_id -> + (* Backward function: custom return type *) + let doutputs = T.RegionGroupId.Map.find rg_id rg_to_given_back_tys in + let output = mk_simpl_tuple_ty doutputs in + let fwd_effect_info = ctx.sg.fwd_info.effect_info in + let output = + if fwd_effect_info.stateful then + mk_simpl_tuple_ty [ mk_state_ty; output ] + else output + in + let output = + if fwd_effect_info.can_fail then mk_result_ty output else output + in + output in (* Add the loop information in the context *) @@ -3460,7 +3513,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = input_state; inputs; inputs_lvs; - back_output_tys; + output_ty; loop_body; } in -- cgit v1.2.3 From dd7552bec1be1695682801fca6ba6dfcfa990fbb Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 22 Dec 2023 21:03:59 +0100 Subject: Update the computation of the effect info for the loops --- compiler/SymbolicToPure.ml | 141 ++++++++++++++++++++++++++++++--------------- 1 file changed, 95 insertions(+), 46 deletions(-) (limited to 'compiler') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index bf92482a..f0d1ca62 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -134,6 +134,8 @@ type loop_info = { Initialized with [None], gets updated to [Some] only if we merge the fwd/back functions. *) + fwd_effect_info : fun_effect_info; + back_effect_infos : fun_effect_info RegionGroupId.Map.t; } [@@deriving show] @@ -922,17 +924,31 @@ let compute_raw_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) let get_fun_effect_info (ctx : bs_ctx) (fun_id : A.fun_id_or_trait_method_ref) (lid : V.LoopId.id option) (gid : T.RegionGroupId.id option) : fun_effect_info = - match fun_id with - | TraitMethod (_, _, fid) | FunId (FRegular fid) -> - let dsg = A.FunDeclId.Map.find fid ctx.fun_dsigs in - let info = - match gid with - | None -> dsg.fwd_info.effect_info - | Some gid -> (RegionGroupId.Map.find gid dsg.back_sg).effect_info - in - { info with is_rec = info.is_rec || Option.is_some lid } - | FunId (FAssumed _) -> - compute_raw_fun_effect_info ctx.fun_ctx.fun_infos fun_id lid gid + match lid with + | None -> ( + match fun_id with + | TraitMethod (_, _, fid) | FunId (FRegular fid) -> + let dsg = A.FunDeclId.Map.find fid ctx.fun_dsigs in + let info = + match gid with + | None -> dsg.fwd_info.effect_info + | Some gid -> (RegionGroupId.Map.find gid dsg.back_sg).effect_info + in + { info with is_rec = info.is_rec || Option.is_some lid } + | FunId (FAssumed _) -> + compute_raw_fun_effect_info ctx.fun_ctx.fun_infos fun_id lid gid) + | Some lid -> ( + (* This is necessarily for the current function *) + match fun_id with + | FunId (FRegular fid) -> ( + assert (fid = ctx.fun_decl.def_id); + (* Lookup the loop *) + let lid = V.LoopId.Map.find lid ctx.loop_ids_map in + let loop_info = LoopId.Map.find lid ctx.loops in + match gid with + | None -> loop_info.fwd_effect_info + | Some gid -> RegionGroupId.Map.find gid loop_info.back_effect_infos) + | _ -> raise (Failure "Unreachable")) (** Translate a function signature to a decomposed function signature. @@ -1901,7 +1917,7 @@ and translate_panic (ctx : bs_ctx) : texpression = Remark: in case we merge the forward/backward functions, we introduce those in [translate_forward_end]. - *) +*) and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = (* There are two cases: @@ -3381,31 +3397,47 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = let ctx = !ctx in (* The output type of the loop function *) - let output_ty = + let fwd_effect_info = { ctx.sg.fwd_info.effect_info with is_rec = true } in + let back_effect_infos, output_ty = if !Config.return_back_funs then (* The loop backward functions consume the same additional inputs as the parent function, but have custom outputs *) - let back_sgs = RegionGroupId.Map.values ctx.sg.back_sg in + let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in - let back_tys = - List.filter_map - (fun ((back_sg, given_back) : back_sg_info * ty list) -> + let back_info_tys = + List.map + (fun (((id, back_sg), given_back) : (_ * back_sg_info) * ty list) -> + (* Remark: the effect info of the backward function for the loop + is almost the same as for the backward function of the parent function. + Quite importantly, the fact that the function is stateful and/or can fail + mostly depends on whether it has inputs or not, and the backward functions + for the loops have the same inputs as the backward functions for the parent + function. + *) let effect_info = back_sg.effect_info in + let effect_info = { effect_info with is_rec = true } in (* Compute the input/output types *) let inputs = List.map snd back_sg.inputs in let outputs = given_back in (* Filter if necessary *) - if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] - then None - else - let output = mk_simpl_tuple_ty outputs in - let output = - mk_back_output_ty_from_effect_info effect_info inputs output - in - let ty = mk_arrows inputs output in - Some ty) + let ty = + if + !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] + then None + else + let output = mk_simpl_tuple_ty outputs in + let output = + mk_back_output_ty_from_effect_info effect_info inputs output + in + let ty = mk_arrows inputs output in + Some ty + in + ((id, effect_info), ty)) (List.combine back_sgs given_back_tys) in + let back_info = List.map fst back_info_tys in + let back_info = RegionGroupId.Map.of_list back_info in + let back_tys = List.filter_map snd back_info_tys in let output = if ctx.sg.fwd_info.ignore_output then back_tys else ctx.sg.fwd_output :: back_tys @@ -3416,27 +3448,42 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] else output in - if effect_info.can_fail && inputs <> [] then mk_result_ty output - else output + let output = + if effect_info.can_fail && inputs <> [] then mk_result_ty output + else output + in + (back_info, output) else - match ctx.bid with - | None -> - (* Forward function: same type as the parent function *) - (translate_fun_sig_from_decomposed ctx.sg None).output - | Some rg_id -> - (* Backward function: custom return type *) - let doutputs = T.RegionGroupId.Map.find rg_id rg_to_given_back_tys in - let output = mk_simpl_tuple_ty doutputs in - let fwd_effect_info = ctx.sg.fwd_info.effect_info in - let output = - if fwd_effect_info.stateful then - mk_simpl_tuple_ty [ mk_state_ty; output ] - else output - in - let output = - if fwd_effect_info.can_fail then mk_result_ty output else output - in - output + let back_info = + RegionGroupId.Map.of_list + (List.map + (fun ((id, back_sg) : _ * back_sg_info) -> + (id, { back_sg.effect_info with is_rec = true })) + (RegionGroupId.Map.bindings ctx.sg.back_sg)) + in + let output = + match ctx.bid with + | None -> + (* Forward function: same type as the parent function *) + (translate_fun_sig_from_decomposed ctx.sg None).output + | Some rg_id -> + (* Backward function: custom return type *) + let doutputs = + T.RegionGroupId.Map.find rg_id rg_to_given_back_tys + in + let output = mk_simpl_tuple_ty doutputs in + let fwd_effect_info = ctx.sg.fwd_info.effect_info in + let output = + if fwd_effect_info.stateful then + mk_simpl_tuple_ty [ mk_state_ty; output ] + else output + in + let output = + if fwd_effect_info.can_fail then mk_result_ty output else output + in + output + in + (back_info, output) in (* Add the loop information in the context *) @@ -3480,6 +3527,8 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = forward_output_no_state_no_result = None; back_outputs = rg_to_given_back_tys; back_funs = None; + fwd_effect_info; + back_effect_infos; } in let loops = LoopId.Map.add loop_id loop_info ctx.loops in -- cgit v1.2.3 From 9a8e43df626400aacdfcb9d2cf2eec38d71d2d73 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 22 Dec 2023 23:04:31 +0100 Subject: Fix minor issues --- compiler/ExtractBase.ml | 81 ++++++++++++++++++++++++++++++---------------- compiler/PureUtils.ml | 2 +- compiler/SymbolicToPure.ml | 57 ++++++++++++++++++++++++++------ 3 files changed, 103 insertions(+), 37 deletions(-) (limited to 'compiler') diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 0af7a9b4..db887539 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -1051,33 +1051,60 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list = let assumed_llbc_functions () : (A.assumed_fun_id * T.RegionGroupId.id option * string) list = let rg0 = Some T.RegionGroupId.zero in - match !backend with - | FStar | Coq | HOL4 -> - [ - (ArrayIndexShared, None, "array_index_usize"); - (ArrayIndexMut, None, "array_index_usize"); - (ArrayIndexMut, rg0, "array_update_usize"); - (ArrayToSliceShared, None, "array_to_slice"); - (ArrayToSliceMut, None, "array_to_slice"); - (ArrayToSliceMut, rg0, "array_from_slice"); - (ArrayRepeat, None, "array_repeat"); - (SliceIndexShared, None, "slice_index_usize"); - (SliceIndexMut, None, "slice_index_usize"); - (SliceIndexMut, rg0, "slice_update_usize"); - ] - | Lean -> - [ - (ArrayIndexShared, None, "Array.index_usize"); - (ArrayIndexMut, None, "Array.index_usize"); - (ArrayIndexMut, rg0, "Array.update_usize"); - (ArrayToSliceShared, None, "Array.to_slice"); - (ArrayToSliceMut, None, "Array.to_slice"); - (ArrayToSliceMut, rg0, "Array.from_slice"); - (ArrayRepeat, None, "Array.repeat"); - (SliceIndexShared, None, "Slice.index_usize"); - (SliceIndexMut, None, "Slice.index_usize"); - (SliceIndexMut, rg0, "Slice.update_usize"); - ] + let regular : (A.assumed_fun_id * T.RegionGroupId.id option * string) list = + match !backend with + | FStar | Coq | HOL4 -> + [ + (ArrayIndexShared, None, "array_index_usize"); + (ArrayToSliceShared, None, "array_to_slice"); + (ArrayRepeat, None, "array_repeat"); + (SliceIndexShared, None, "slice_index_usize"); + ] + | Lean -> + [ + (ArrayIndexShared, None, "Array.index_usize"); + (ArrayToSliceShared, None, "Array.to_slice"); + (ArrayRepeat, None, "Array.repeat"); + (SliceIndexShared, None, "Slice.index_usize"); + ] + in + let mut_funs : (A.assumed_fun_id * T.RegionGroupId.id option * string) list = + if !Config.return_back_funs then + match !backend with + | FStar | Coq | HOL4 -> + [ + (ArrayIndexMut, None, "array_index_mut_usize"); + (ArrayToSliceMut, None, "array_to_slice_mut"); + (SliceIndexMut, None, "slice_index_mut_usize"); + ] + | Lean -> + [ + (ArrayIndexMut, None, "Array.index_mut_usize"); + (ArrayToSliceMut, None, "Array.to_slice_mut"); + (SliceIndexMut, None, "Slice.index_mut_usize"); + ] + else + match !backend with + | FStar | Coq | HOL4 -> + [ + (ArrayIndexMut, None, "array_index_usize"); + (ArrayIndexMut, rg0, "array_update_usize"); + (ArrayToSliceMut, None, "array_to_slice"); + (ArrayToSliceMut, rg0, "array_from_slice"); + (SliceIndexMut, None, "slice_index_usize"); + (SliceIndexMut, rg0, "slice_update_usize"); + ] + | Lean -> + [ + (ArrayIndexMut, None, "Array.index_usize"); + (ArrayIndexMut, rg0, "Array.update_usize"); + (ArrayToSliceMut, None, "Array.to_slice"); + (ArrayToSliceMut, rg0, "Array.from_slice"); + (SliceIndexMut, None, "Slice.index_usize"); + (SliceIndexMut, rg0, "Slice.update_usize"); + ] + in + regular @ mut_funs let assumed_pure_functions () : (pure_assumed_fun_id * string) list = match !backend with diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index cc439e64..80bf3c42 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -731,7 +731,7 @@ let mk_lambda_from_var (var : var) (mp : mplace option) (e : texpression) : let mk_lambdas_from_vars (vars : var list) (mps : mplace option list) (e : texpression) : texpression = let vars = List.combine vars mps in - List.fold_left (fun e (v, mp) -> mk_lambda_from_var v mp e) e vars + List.fold_right (fun (v, mp) e -> mk_lambda_from_var v mp e) vars e let rec destruct_lambdas (e : texpression) : typed_pattern list * texpression = match e.e with diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index f0d1ca62..3a50e495 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -734,11 +734,15 @@ let rec translate_back_ty (type_infos : type_infos) None | TTraitType (trait_ref, generics, type_name) -> assert (generics.regions = []); - (* Translate the trait ref and the generics as "forward" generics - - we do not want to filter any type *) - let trait_ref = translate_fwd_trait_ref type_infos trait_ref in - let generics = translate_fwd_generic_args type_infos generics in - Some (TTraitType (trait_ref, generics, type_name)) + assert ( + AssociatedTypes.trait_instance_id_is_local_clause trait_ref.trait_id); + if inside_mut then + (* Translate the trait ref and the generics as "forward" generics - + we do not want to filter any type *) + let trait_ref = translate_fwd_trait_ref type_infos trait_ref in + let generics = translate_fwd_generic_args type_infos generics in + Some (TTraitType (trait_ref, generics, type_name)) + else None | TArrow _ -> raise (Failure "TODO") (** Simply calls [translate_back_ty] *) @@ -1056,7 +1060,21 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed Upon ending the abstraction for 'a, we need to get back the borrow the function returned. *) - List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] + let inputs = + List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] + in + log#ldebug + (lazy + (let ctx = Print.Contexts.decls_ctx_to_fmt_env decls_ctx in + let pctx = PrintPure.decls_ctx_to_fmt_env decls_ctx in + let output = Print.Types.ty_to_string ctx sg.output in + let inputs = + Print.list_to_string (PrintPure.ty_to_string pctx false) inputs + in + "translate_back_inputs_for_gid:" ^ "\n- gid: " + ^ RegionGroupId.to_string gid + ^ "\n- output: " ^ output ^ "\n- back inputs: " ^ inputs ^ "\n")); + inputs in let compute_back_outputs_for_gid (gid : RegionGroupId.id) : string option list * ty list = @@ -1080,7 +1098,21 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed let outputs = List.map (fun (name, opt_ty) -> (name, Option.get opt_ty)) outputs in - List.split outputs + let names, outputs = List.split outputs in + log#ldebug + (lazy + (let ctx = Print.Contexts.decls_ctx_to_fmt_env decls_ctx in + let pctx = PrintPure.decls_ctx_to_fmt_env decls_ctx in + let inputs = + Print.list_to_string (Print.Types.ty_to_string ctx) sg.inputs + in + let outputs = + Print.list_to_string (PrintPure.ty_to_string pctx false) outputs + in + "compute_back_outputs_for_gid:" ^ "\n- gid: " + ^ RegionGroupId.to_string gid + ^ "\n- inputs: " ^ inputs ^ "\n- back outputs: " ^ outputs ^ "\n")); + (names, outputs) in let compute_back_info_for_group (rg : T.region_var_group) : RegionGroupId.id * back_sg_info = @@ -1201,8 +1233,15 @@ let translate_fun_sig_from_decl_to_decomposed (decls_ctx : C.decls_ctx) (fun (v : LlbcAst.var) -> v.name) (LlbcAstUtils.fun_body_get_input_vars body) in - translate_fun_sig_to_decomposed decls_ctx fdef.def_id fdef.signature - input_names + let sg = + translate_fun_sig_to_decomposed decls_ctx fdef.def_id fdef.signature + input_names + in + log#ldebug + (lazy + ("translate_fun_sig_from_decl_to_decomposed:" ^ "\n- name: " + ^ T.show_name fdef.name ^ "\n- sg:\n" ^ show_decomposed_fun_sig sg ^ "\n")); + sg let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty = -- cgit v1.2.3 From b6ef8ee33802e75409c3bd2b82e7b5ad22f1d053 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Sat, 23 Dec 2023 00:41:25 +0100 Subject: Improve the micro passes to eliminate pattern `let f := fun x => g x` --- compiler/PureMicroPasses.ml | 45 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) (limited to 'compiler') diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index e7e9d5e1..fa025d93 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -684,6 +684,15 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = let y1 = x1 in ... ]} + + Simplify arrows: + {[ + let f := fun x => g x in + ... + ~~> + let f := g in + ... + ]} *) let simplify_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = let obj = @@ -739,6 +748,23 @@ let simplify_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = super#visit_expression env e.e | _ -> super#visit_Let env monadic lv rv next else super#visit_Let env monadic lv rv next + | Lambda _ -> + if not monadic then + (* Arrow case *) + let pats, e = destruct_lambdas rv in + let g, args = destruct_apps e in + if List.length pats = List.length args then + (* Check if the arguments are exactly the lambdas *) + let check_pat_arg ((pat, arg) : typed_pattern * texpression) = + match (pat.value, arg.e) with + | PatVar (v, _), Var vid -> v.id = vid + | _ -> false + in + if List.for_all check_pat_arg (List.combine pats args) then + self#visit_Let env monadic lv g next + else super#visit_Let env monadic lv rv next + else super#visit_Let env monadic lv rv next + else super#visit_Let env monadic lv rv next | _ -> super#visit_Let env monadic lv rv next end in @@ -1934,9 +1960,10 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = (* Inline the useless variable reassignments *) let inline_named_vars = true in let inline_pure = true in - let def = - inline_useless_var_reassignments ctx inline_named_vars inline_pure def + let inline_useless_var_reassignments ctx = + inline_useless_var_reassignments ctx inline_named_vars inline_pure in + let def = inline_useless_var_reassignments ctx def in log#ldebug (lazy ("inline_useless_var_assignments:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); @@ -1982,6 +2009,20 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = log#ldebug (lazy ("simplify_aggregates:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + (* Simplify the let-bindings - some simplifications may have been unlocked by + the pass above (for instance, the lambda simplification) *) + let def = simplify_let_bindings ctx def in + log#ldebug + (lazy + ("simplify_let_bindings (pass 2):\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + + (* Inline the useless vars again *) + let def = inline_useless_var_reassignments ctx def in + log#ldebug + (lazy + ("inline_useless_var_assignments (pass 2):\n\n" + ^ fun_decl_to_string ctx def ^ "\n")); + (* Decompose the monadic let-bindings - used by Coq *) let def = if !Config.decompose_monadic_let_bindings then ( -- cgit v1.2.3 From ff9fe8aa1e13a7297f7c4f2c2554235361db038f Mon Sep 17 00:00:00 2001 From: Son Ho Date: Sat, 23 Dec 2023 00:58:35 +0100 Subject: Update the micro-passes --- compiler/PureMicroPasses.ml | 171 ++++++++++++++++++++++++-------------------- 1 file changed, 94 insertions(+), 77 deletions(-) (limited to 'compiler') diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index fa025d93..ec64df21 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -800,8 +800,8 @@ let simplify_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = leave the let-bindings where they are, and eliminated them in a subsequent pass (if they are useless). *) -let inline_useless_var_reassignments (ctx : trans_ctx) (inline_named : bool) - (inline_pure : bool) (def : fun_decl) : fun_decl = +let inline_useless_var_reassignments (ctx : trans_ctx) ~(inline_named : bool) + ~(inline_const : bool) ~(inline_pure : bool) (def : fun_decl) : fun_decl = let obj = object (self) inherit [_] map_expression as super @@ -826,15 +826,31 @@ let inline_useless_var_reassignments (ctx : trans_ctx) (inline_named : bool) | _ -> false in (* And either: - * 2.1 the right-expression is a variable, a global or a const generic var *) + 2.1 the right-expression is a variable, a global or a const generic var *) let var_or_global = is_var re || is_cvar re || is_global re in (* Or: - * 2.2 the right-expression is a constant value, an ADT value, - * a projection or a primitive function call *and* the flag - * [inline_pure] is set *) + 2.2 the right-expression is a constant-value and we inline constant values, + *or* it is a qualif with no arguments (we consider this as a const) *) + let const_re = + inline_const + && + let is_const_adt = + let app, args = destruct_apps re in + if args = [] then + match app.e with + | Qualif _ -> true + | StructUpdate upd -> upd.updates = [] + | _ -> false + else false + in + is_const re || is_const_adt + in + (* Or: + 2.3 the right-expression is an ADT value, a projection or a + primitive function call *and* the flag [inline_pure] is set *) let pure_re = - is_const re - || + inline_pure + && let app, _ = destruct_apps re in match app.e with | Qualif qualif -> ( @@ -849,7 +865,7 @@ let inline_useless_var_reassignments (ctx : trans_ctx) (inline_named : bool) | _ -> false in let filter = - filter_left && (var_or_global || (inline_pure && pure_re)) + filter_left && (var_or_global || const_re || pure_re) in (* Update the rhs (we may perform substitutions inside, and it is @@ -1958,12 +1974,10 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = (lazy ("simplify_let_bindings:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); (* Inline the useless variable reassignments *) - let inline_named_vars = true in - let inline_pure = true in - let inline_useless_var_reassignments ctx = - inline_useless_var_reassignments ctx inline_named_vars inline_pure + let def = + inline_useless_var_reassignments ctx ~inline_named:true ~inline_const:true + ~inline_pure:true def in - let def = inline_useless_var_reassignments ctx def in log#ldebug (lazy ("inline_useless_var_assignments:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); @@ -2017,7 +2031,10 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = ("simplify_let_bindings (pass 2):\n\n" ^ fun_decl_to_string ctx def ^ "\n")); (* Inline the useless vars again *) - let def = inline_useless_var_reassignments ctx def in + let def = + inline_useless_var_reassignments ctx ~inline_named:true ~inline_const:true + ~inline_pure:false def + in log#ldebug (lazy ("inline_useless_var_assignments (pass 2):\n\n" @@ -2073,68 +2090,6 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = (* We are done *) def -(** Apply all the micro-passes to a function. - - As loops are initially directly integrated into the function definition, - {!apply_passes_to_def} extracts those loops definitions from the body; - it thus returns the pair: (function def, loop defs). See {!decompose_loops} - for more information. - - Will return [None] if the function is a backward function with no outputs. - - [ctx]: used only for printing. - *) -let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : - fun_and_loops option = - (* Debug *) - log#ldebug - (lazy - ("PureMicroPasses.apply_passes_to_def: " ^ def.name ^ " (" - ^ Print.option_to_string T.RegionGroupId.to_string def.back_id - ^ ")")); - - log#ldebug (lazy ("original decl:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); - - (* First, find names for the variables which are unnamed *) - let def = compute_pretty_names def in - log#ldebug - (lazy ("compute_pretty_name:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); - - (* TODO: we might want to leverage more the assignment meta-data, for - * aggregates for instance. *) - - (* TODO: reorder the branches of the matches/switches *) - - (* The meta-information is now useless: remove it. - * Rk.: some passes below use the fact that we removed the meta-data - * (otherwise we would have to "unmeta" expressions before matching) *) - let def = remove_meta def in - log#ldebug (lazy ("remove_meta:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); - - (* Remove the backward functions with no outputs. - - Note that the *calls* to those functions should already have been removed, - when translating from symbolic to pure. Here, we remove the definitions - altogether, because they are now useless *) - let name = def.name ^ PrintPure.fun_suffix def.loop_id def.back_id in - let opt_def = filter_if_backward_with_no_outputs def in - - match opt_def with - | None -> - log#ldebug (lazy ("filtered (backward with no outputs): " ^ name ^ "\n")); - None - | Some def -> - log#ldebug - (lazy ("not filtered (not backward with no outputs): " ^ name ^ "\n")); - - (* Extract the loop definitions by removing the {!Loop} node *) - let def, loops = decompose_loops ctx def in - - (* Apply the remaining passes *) - let f = apply_end_passes_to_def ctx def in - let loops = List.map (apply_end_passes_to_def ctx) loops in - Some { f; loops } - (** Small utility for {!filter_loop_inputs} *) let filter_prefix (keep : bool list) (ls : 'a list) : 'a list = let ls0, ls1 = Collections.List.split_at ls (List.length keep) in @@ -2458,6 +2413,68 @@ let filter_loop_inputs (transl : pure_fun_translation list) : (* Return *) transl +(** Apply all the micro-passes to a function. + + As loops are initially directly integrated into the function definition, + {!apply_passes_to_def} extracts those loops definitions from the body; + it thus returns the pair: (function def, loop defs). See {!decompose_loops} + for more information. + + Will return [None] if the function is a backward function with no outputs. + + [ctx]: used only for printing. + *) +let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : + fun_and_loops option = + (* Debug *) + log#ldebug + (lazy + ("PureMicroPasses.apply_passes_to_def: " ^ def.name ^ " (" + ^ Print.option_to_string T.RegionGroupId.to_string def.back_id + ^ ")")); + + log#ldebug (lazy ("original decl:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + + (* First, find names for the variables which are unnamed *) + let def = compute_pretty_names def in + log#ldebug + (lazy ("compute_pretty_name:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + + (* TODO: we might want to leverage more the assignment meta-data, for + * aggregates for instance. *) + + (* TODO: reorder the branches of the matches/switches *) + + (* The meta-information is now useless: remove it. + * Rk.: some passes below use the fact that we removed the meta-data + * (otherwise we would have to "unmeta" expressions before matching) *) + let def = remove_meta def in + log#ldebug (lazy ("remove_meta:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + + (* Remove the backward functions with no outputs. + + Note that the *calls* to those functions should already have been removed, + when translating from symbolic to pure. Here, we remove the definitions + altogether, because they are now useless *) + let name = def.name ^ PrintPure.fun_suffix def.loop_id def.back_id in + let opt_def = filter_if_backward_with_no_outputs def in + + match opt_def with + | None -> + log#ldebug (lazy ("filtered (backward with no outputs): " ^ name ^ "\n")); + None + | Some def -> + log#ldebug + (lazy ("not filtered (not backward with no outputs): " ^ name ^ "\n")); + + (* Extract the loop definitions by removing the {!Loop} node *) + let def, loops = decompose_loops ctx def in + + (* Apply the remaining passes *) + let f = apply_end_passes_to_def ctx def in + let loops = List.map (apply_end_passes_to_def ctx) loops in + Some { f; loops } + (** Apply the micro-passes to a list of forward/backward translations. This function also extracts the loop definitions from the function body -- cgit v1.2.3 From a52939b5119e2751570582533bf27828724c2e9f Mon Sep 17 00:00:00 2001 From: Son Ho Date: Sat, 23 Dec 2023 01:18:37 +0100 Subject: Fix an issue when deconstructing tuples in Coq --- compiler/Extract.ml | 10 ++++++++-- compiler/Main.ml | 3 --- 2 files changed, 8 insertions(+), 5 deletions(-) (limited to 'compiler') diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 30b76ceb..87dcb1fd 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -132,9 +132,15 @@ let extract_adt_g_value F.pp_print_string fmt "tt"; ctx) else - (* If there is exactly one value, we don't print the parentheses *) + (* If there is exactly one value, we don't print the parentheses. + Also, for Coq, we need the special syntax ['(...)] if we destruct + a tuple pattern in a let-binding and the tuple has > 2 values. + *) let lb, rb = - if List.length field_values = 1 then ("", "") else ("(", ")") + if List.length field_values = 1 then ("", "") + else if !backend = Coq && is_single_pat && List.length field_values > 2 + then ("'(", ")") + else ("(", ")") in F.pp_print_string fmt lb; let ctx = diff --git a/compiler/Main.ml b/compiler/Main.ml index abc27b46..0b8ec439 100644 --- a/compiler/Main.ml +++ b/compiler/Main.ml @@ -196,9 +196,6 @@ let () = let _ = match !backend with | FStar -> - (* Some patterns are not supported *) - decompose_monadic_let_bindings := false; - decompose_nested_let_patterns := false; (* F* can disambiguate the field names *) record_fields_short_names := true | Coq -> -- cgit v1.2.3