diff options
-rw-r--r-- | compiler/Config.ml | 63 | ||||
-rw-r--r-- | compiler/Extract.ml | 8 | ||||
-rw-r--r-- | compiler/InterpreterExpressions.mli | 2 | ||||
-rw-r--r-- | compiler/Pure.ml | 54 | ||||
-rw-r--r-- | compiler/PureMicroPasses.ml | 91 | ||||
-rw-r--r-- | compiler/PureUtils.ml | 23 | ||||
-rw-r--r-- | compiler/SymbolicToPure.ml | 73 | ||||
-rw-r--r-- | compiler/Translate.ml | 8 |
8 files changed, 246 insertions, 76 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml index b09544ba..b8af6c6d 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -92,6 +92,69 @@ let loop_fixed_point_max_num_iters = 2 (** {1 Translation} *) +(** If true, do not define separate forward/backward functions, but make the + forward functions return the backward function. + + Example: + {[ + (* Rust *) + pub fn list_nth<'a, T>(l: &'a mut List<T>, i: u32) -> &'a mut T { + match l { + List::Nil => { + panic!() + } + List::Cons(x, tl) => { + if i == 0 { + x + } else { + list_nth(tl, i - 1) + } + } + } + } + + (* Translation, if return_back_funs = false *) + def list_nth (T : Type) (l : List T) (i : U32) : Result T := + match l with + | List.Cons x tl => + if i = 0#u32 + then Result.ret x + else do + let i0 ← i - 1#u32 + list_nth T tl i0 + | List.Nil => Result.fail .panic + + def list_nth_back + (T : Type) (l : List T) (i : U32) (ret : T) : Result (List T) := + match l with + | List.Cons x tl => + if i = 0#u32 + then Result.ret (List.Cons ret tl) + else + do + let i0 ← i - 1#u32 + let tl0 ← list_nth_back T tl i0 ret + Result.ret (List.Cons x tl0) + | List.Nil => Result.fail .panic + + (* Translation, if return_back_funs = true *) + def list_nth (T: Type) (ls : List T) (i : U32) : + Result (T × (T → Result (List T))) := + match ls with + | List.Cons x tl => + if i = 0#u32 + then Result.ret (x, (λ ret => return (ret :: ls))) + else do + let i0 ← i - 1#u32 + let (x, back) ← list_nth ls i0 + Return.ret (x, + (λ ret => do + let ls ← back ret + return (x :: ls))) + ]} + *) +let return_back_funs = ref true + (** Forbids using field projectors for structures. If we don't use field projectors, whenever we symbolically expand a structure diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 20cdb20b..93fcf416 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -1469,8 +1469,10 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) *) let inputs_lvs = let all_inputs = (Option.get def.body).inputs_lvs in + (* TODO: *) + assert (not !Config.return_back_funs); let num_fwd_inputs = - def.signature.info.num_fwd_inputs_with_fuel_with_state + def.signature.info.fwd_info.num_inputs_with_fuel_with_state in Collections.List.prefix num_fwd_inputs all_inputs in @@ -1515,8 +1517,10 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) if has_decreases_clause && !backend = Lean then ( let def_body = Option.get def.body in let all_vars = List.map (fun (v : var) -> v.id) def_body.inputs in + (* TODO: *) + assert (not !Config.return_back_funs); let num_fwd_inputs = - def.signature.info.num_fwd_inputs_with_fuel_with_state + def.signature.info.fwd_info.num_inputs_with_fuel_with_state in let vars = Collections.List.prefix num_fwd_inputs all_vars in diff --git a/compiler/InterpreterExpressions.mli b/compiler/InterpreterExpressions.mli index f8d979f4..b975371c 100644 --- a/compiler/InterpreterExpressions.mli +++ b/compiler/InterpreterExpressions.mli @@ -52,7 +52,7 @@ val eval_operands : Transmits the computed rvalue to the received continuation. - Note that this function fails on {!constructor:Aeneas.Expressions.rvalue.Discriminant}: discriminant + Note that this function fails on {!Aeneas.Expressions.rvalue.Discriminant}: discriminant reads should have been eliminated from the AST. *) val eval_rvalue_not_global : diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 8d39cc69..c3716001 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -561,7 +561,7 @@ type fun_id_or_trait_method_ref = (** A function id for a non-assumed function *) type regular_fun_id = - fun_id_or_trait_method_ref * LoopId.id option * T.RegionGroupId.id option + fun_id_or_trait_method_ref * LoopId.id option * RegionGroupId.id option [@@deriving show, ord] (** A function identifier *) @@ -860,8 +860,8 @@ type fun_effect_info = { the set [{ forward function } U { backward functions }]. We need this because of the option {!val:Config.backward_no_state_update}: - if it is [true], then in case of a backward function {!stateful} is [false], - but we might need to know whether the corresponding forward function + if it is [true], then in case of a backward function {!stateful} might be + [false], but we might need to know whether the corresponding forward function is stateful or not. *) stateful : bool; (** [true] if the function is stateful (updates a state) *) @@ -873,21 +873,41 @@ type fun_effect_info = { } [@@deriving show] -(** Meta information about a function signature *) -type fun_sig_info = { +type inputs_info = { has_fuel : bool; - (* TODO: add [num_fwd_inputs_no_fuel_no_state] *) - num_fwd_inputs_with_fuel_no_state : int; - (** The number of input types for forward computation, with the fuel (if used) + num_inputs_no_fuel_no_state : int; + (** The number of input types ignoring the fuel (if used) and ignoring the state (if used) *) - num_fwd_inputs_with_fuel_with_state : int; - (** The number of input types for forward computation, with fuel and state (if used) *) - num_back_inputs_no_state : int option; - (** The number of additional inputs for the backward computation (if pertinent), - ignoring the state (if there is one) *) - num_back_inputs_with_state : int option; - (** The number of additional inputs for the backward computation (if pertinent), - with the state (if there is one) *) + num_inputs_with_fuel_no_state : int; + (** The number of input types, with the fuel (if used) + and ignoring the state (if used) *) + num_inputs_with_fuel_with_state : int; + (** The number of input types, with fuel and state (if used) *) +} +[@@deriving show] + +type 'a back_info = + | SingleBack of 'a option + (** Information about a single backward function, if pertinent. + + We use this variant if we split the forward and the backward functions. + *) + | AllBacks of 'a RegionGroupId.Map.t + (** Information about the various backward functions. + + We use this if we *do not* split the forward and the backward functions. + All the information is then carried by the forward function. + *) +[@@deriving show] + +type back_inputs_info = inputs_info back_info [@@deriving show] + +(** Meta information about a function signature *) +type fun_sig_info = { + fwd_info : inputs_info; + (** Information about the inputs of the forward function *) + back_info : back_inputs_info; + (** Information about the inputs of the backward functions. *) effect_info : fun_effect_info; } [@@deriving show] @@ -1020,7 +1040,7 @@ type fun_decl = { *) loop_id : LoopId.id option; (** [Some] if this definition was generated for a loop *) - back_id : T.RegionGroupId.id option; + back_id : RegionGroupId.id option; llbc_name : llbc_name; (** The original LLBC name. *) name : string; (** We use the name only for printing purposes (for debugging): diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 959ec1c8..7f122f15 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -1326,6 +1326,9 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = let fun_sig_info = fun_sig.info in let fun_effect_info = fun_sig_info.effect_info in + (* TODO: *) + assert (not !Config.return_back_funs); + (* Generate the loop definition *) let loop_effect_info = { @@ -1340,36 +1343,44 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = let loop_sig_info = let fuel = if !Config.use_fuel then 1 else 0 in let num_inputs = List.length loop.inputs in - let num_fwd_inputs_with_fuel_no_state = fuel + num_inputs in - let fwd_state = - fun_sig_info.num_fwd_inputs_with_fuel_with_state - - fun_sig_info.num_fwd_inputs_with_fuel_no_state - in - let num_fwd_inputs_with_fuel_with_state = - num_fwd_inputs_with_fuel_no_state + fwd_state + let fwd_info : inputs_info = + let info = fun_sig_info.fwd_info in + let fwd_state = + info.num_inputs_with_fuel_with_state + - info.num_inputs_with_fuel_no_state + in + { + has_fuel = !Config.use_fuel; + num_inputs_no_fuel_no_state = num_inputs; + num_inputs_with_fuel_no_state = num_inputs + fuel; + num_inputs_with_fuel_with_state = + num_inputs + fuel + fwd_state; + } in + { - has_fuel = !Config.use_fuel; - num_fwd_inputs_with_fuel_no_state; - num_fwd_inputs_with_fuel_with_state; - num_back_inputs_no_state = fun_sig_info.num_back_inputs_no_state; - num_back_inputs_with_state = - fun_sig_info.num_back_inputs_with_state; + fwd_info; + back_info = fun_sig_info.back_info; effect_info = loop_effect_info; } in + assert (fun_sig_info_is_wf loop_sig_info); let inputs_tys = + (* TODO: *) + assert (not !Config.return_back_funs); + let fuel = if !Config.use_fuel then [ mk_fuel_ty ] else [] in let fwd_inputs = List.map (fun (v : var) -> v.ty) loop.inputs in + let info = fun_sig_info.fwd_info in let state = Collections.List.subslice fun_sig.inputs - fun_sig_info.num_fwd_inputs_with_fuel_no_state - fun_sig_info.num_fwd_inputs_with_fuel_with_state + info.num_inputs_with_fuel_no_state + info.num_inputs_with_fuel_with_state in let _, back_inputs = Collections.List.split_at fun_sig.inputs - fun_sig_info.num_fwd_inputs_with_fuel_with_state + info.num_inputs_with_fuel_with_state in List.concat [ fuel; fwd_inputs; state; back_inputs ] in @@ -1430,14 +1441,17 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = in (* Introduce the additional backward inputs *) + (* TODO: *) + assert (not !Config.return_back_funs); let fun_body = Option.get def.body in + let info = fun_sig_info.fwd_info in let _, back_inputs = Collections.List.split_at fun_body.inputs - fun_sig_info.num_fwd_inputs_with_fuel_with_state + info.num_inputs_with_fuel_with_state in let _, back_inputs_lvs = Collections.List.split_at fun_body.inputs_lvs - fun_sig_info.num_fwd_inputs_with_fuel_with_state + info.num_inputs_with_fuel_with_state in let inputs = @@ -2053,12 +2067,14 @@ let filter_loop_inputs (transl : pure_fun_translation list) : (* We start by computing the filtering information, for each function *) let compute_one_filter_info (decl : fun_decl) = + (* TODO: *) + assert (not !Config.return_back_funs); (* There should be a body *) let body = Option.get decl.body in (* We only look at the forward inputs, without the state *) let inputs_prefix, _ = Collections.List.split_at body.inputs - decl.signature.info.num_fwd_inputs_with_fuel_no_state + decl.signature.info.fwd_info.num_inputs_with_fuel_no_state in let used = ref (List.map (fun v -> (var_get_id v, false)) inputs_prefix) in let inputs_prefix_length = List.length inputs_prefix in @@ -2078,7 +2094,10 @@ let filter_loop_inputs (transl : pure_fun_translation list) : (* Set the fuel as used *) let sg_info = decl.signature.info in - if sg_info.has_fuel then set_used (fst (Collections.List.nth inputs 0)); + (* TODO: *) + assert (not !Config.return_back_funs); + if sg_info.fwd_info.has_fuel then + set_used (fst (Collections.List.nth inputs 0)); let visitor = object (self : 'self) @@ -2166,31 +2185,35 @@ let filter_loop_inputs (transl : pure_fun_translation list) : = decl.signature in + (* TODO: *) + assert (not !Config.return_back_funs); + let { fwd_info; back_info; effect_info } = info in + let { has_fuel; - num_fwd_inputs_with_fuel_no_state; - num_fwd_inputs_with_fuel_with_state; - num_back_inputs_no_state; - num_back_inputs_with_state; - effect_info; + num_inputs_no_fuel_no_state; + num_inputs_with_fuel_no_state; + num_inputs_with_fuel_with_state; } = - info + fwd_info in let inputs = filter_prefix used_info inputs in - let info = + let fwd_info = { has_fuel; - num_fwd_inputs_with_fuel_no_state = - num_fwd_inputs_with_fuel_no_state - num_filtered; - num_fwd_inputs_with_fuel_with_state = - num_fwd_inputs_with_fuel_with_state - num_filtered; - num_back_inputs_no_state; - num_back_inputs_with_state; - effect_info; + num_inputs_no_fuel_no_state = + num_inputs_no_fuel_no_state - num_filtered; + num_inputs_with_fuel_no_state = + num_inputs_with_fuel_no_state - num_filtered; + num_inputs_with_fuel_with_state = + num_inputs_with_fuel_with_state - num_filtered; } in + + let info = { fwd_info; back_info; effect_info } in + assert (fun_sig_info_is_wf info); let signature = { generics; llbc_generics; preds; inputs; output; doutputs; info } in diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 39dcd52d..3c038149 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -57,6 +57,29 @@ end module FunLoopIdMap = Collections.MakeMap (FunLoopIdOrderedType) module FunLoopIdSet = Collections.MakeSet (FunLoopIdOrderedType) +let inputs_info_is_wf (info : inputs_info) : bool = + let { + has_fuel; + num_inputs_no_fuel_no_state; + num_inputs_with_fuel_no_state; + num_inputs_with_fuel_with_state; + } = + info + in + let fuel = if has_fuel then 1 else 0 in + num_inputs_no_fuel_no_state >= 0 + && num_inputs_with_fuel_no_state = num_inputs_no_fuel_no_state + fuel + && num_inputs_with_fuel_with_state >= num_inputs_with_fuel_no_state + +let fun_sig_info_is_wf (info : fun_sig_info) : bool = + inputs_info_is_wf info.fwd_info + && + match info.back_info with + | SingleBack None -> true + | SingleBack (Some info) -> inputs_info_is_wf info + | AllBacks infos -> + List.for_all inputs_info_is_wf (RegionGroupId.Map.values infos) + let dest_arrow_ty (ty : ty) : ty * ty = match ty with | TArrow (arg_ty, ret_ty) -> (arg_ty, ret_ty) diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 84f09280..1fd4896e 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -855,10 +855,14 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) name (outputs for backward functions come from borrows in the inputs of the forward function) which we use as hints to generate pretty names in the extracted code. + + We use [bid] ("backward function id") only if we split the forward + and the backward functions. *) let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) (sg : A.fun_sig) (input_names : string option list) (bid : T.RegionGroupId.id option) : fun_sig_named_outputs = + assert (Option.is_none bid || not !Config.return_back_funs); let fun_infos = decls_ctx.fun_ctx.fun_infos in let type_infos = decls_ctx.type_ctx.type_infos in (* Retrieve the list of parent backward functions *) @@ -939,6 +943,18 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) let inside_mut = false in translate_back_ty type_infos keep_region inside_mut ty in + let translate_back_inputs_for_gid gid : ty list = + (* For now, we don't allow nested borrows, so the additional inputs to the + backward function can only come from borrows that were returned like + in (for the backward function we introduce for 'a): + {[ + fn f<'a>(...) -> &'a mut u32; + ]} + Upon ending the abstraction for 'a, we need to get back the borrow + the function returned. + *) + List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] + in (* Compute the additinal inputs for the current function, if it is a backward * function *) let back_inputs = @@ -1034,32 +1050,47 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) (* Generic parameters *) let generics = translate_generic_params sg.generics in (* Return *) + (* TODO: *) + assert (not !Config.return_back_funs); let has_fuel = fuel <> [] in - let num_fwd_inputs_no_state = List.length fwd_inputs in + let num_fwd_inputs_no_fuel_no_state = List.length fwd_inputs in let num_fwd_inputs_with_fuel_no_state = (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *) - List.length fuel + num_fwd_inputs_no_state + List.length fuel + num_fwd_inputs_no_fuel_no_state in let num_back_inputs_no_state = if bid = None then None else Some (List.length back_inputs) in - let info = + let fwd_info : inputs_info = { has_fuel; - num_fwd_inputs_with_fuel_no_state; - num_fwd_inputs_with_fuel_with_state = + num_inputs_no_fuel_no_state = num_fwd_inputs_no_fuel_no_state; + num_inputs_with_fuel_no_state = num_fwd_inputs_with_fuel_no_state; + num_inputs_with_fuel_with_state = (* We use the fact that [fwd_state_ty] has length 1 if there is a state, and 0 otherwise *) num_fwd_inputs_with_fuel_no_state + List.length fwd_state_ty; - num_back_inputs_no_state; - num_back_inputs_with_state = - (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) - Option.map - (fun n -> n + List.length back_state_ty) - num_back_inputs_no_state; - effect_info; } in + let back_info : back_inputs_info = + if !Config.return_back_funs then + SingleBack + (Option.map + (fun n -> + (* Note that backward functions never use fuel *) + { + has_fuel = false; + num_inputs_no_fuel_no_state = n; + num_inputs_with_fuel_no_state = n; + (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) + num_inputs_with_fuel_with_state = n + List.length back_state_ty; + }) + num_back_inputs_no_state) + else (* Create the map *) + failwith "TODO" + in + let info = { fwd_info; back_info; effect_info } in + assert (fun_sig_info_is_wf info); let preds = translate_predicates sg.preds in let sg = { @@ -3151,14 +3182,16 @@ let translate_fun_signatures (decls_ctx : C.decls_ctx) let fwd_id = (fun_id, None) in (* The backward functions *) let back_sgs = - List.map - (fun (rg : T.region_var_group) -> - let tsg = - translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id) - in - let id = (fun_id, Some rg.id) in - (id, tsg)) - regions_hierarchy + if !Config.return_back_funs then [] + else + List.map + (fun (rg : T.region_var_group) -> + let tsg = + translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id) + in + let id = (fun_id, Some rg.id) in + (id, tsg)) + regions_hierarchy in (* Return *) (fwd_id, fwd_sg) :: back_sgs diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 221d4e73..54e24066 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -216,11 +216,15 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (* We need to ignore the forward inputs, and the state input (if there is) *) let backward_inputs = let sg = backward_sg.sg in + (* TODO: *) + assert (not !Config.return_back_funs); (* We need to ignore the forward state and the backward state *) let num_forward_inputs = - sg.info.num_fwd_inputs_with_fuel_with_state + sg.info.fwd_info.num_inputs_with_fuel_with_state + in + let num_back_inputs = + (Option.get sg.info.back_info).num_inputs_no_fuel_no_state in - let num_back_inputs = Option.get sg.info.num_back_inputs_no_state in Collections.List.subslice sg.inputs num_forward_inputs (num_forward_inputs + num_back_inputs) in |