diff options
author | Son Ho | 2022-05-04 15:39:05 +0200 |
---|---|---|
committer | Son Ho | 2022-05-04 15:39:05 +0200 |
commit | 15d90db02086f8ecae9a93ebf39c3c0ae8caa50f (patch) | |
tree | 3eb303b96c9233fd7745f06a0d4e2d211373ca03 | |
parent | fb6fdfd0c57de1ce16fb6bc373d5593c9446b0bb (diff) |
Fix some issues when using states
Diffstat (limited to '')
-rw-r--r-- | src/SymbolicToPure.ml | 149 | ||||
-rw-r--r-- | src/Translate.ml | 4 | ||||
-rw-r--r-- | src/main.ml | 3 |
3 files changed, 84 insertions, 72 deletions
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index fa482b8e..66f4d608 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -478,6 +478,30 @@ let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) : let abs_ids = list_ancestor_abstractions_ids ctx abs in List.map (fun id -> V.AbstractionId.Map.find id ctx.abstractions) abs_ids +type fun_effect_info = { + can_fail : bool; + input_state : bool; + output_state : bool; +} +(** TODO: factorize with fun_sig_info? + TODO: use an enumeration + *) + +(** Small utility. *) +let get_fun_effect_info (config : config) (fun_id : A.fun_id) + (gid : T.RegionGroupId.id option) : fun_effect_info = + match fun_id with + | A.Regular _ -> + let input_state = config.use_state_monad in + let output_state = input_state && gid = None in + { can_fail = true; input_state; output_state } + | A.Assumed aid -> + { + can_fail = Assumed.assumed_is_monadic aid; + input_state = false; + output_state = false; + } + (** Translate a function signature. Note that the function also takes a list of names for the inputs, and @@ -485,9 +509,10 @@ let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) : name (outputs for backward functions come from borrows in the inputs of the forward function). *) -let translate_fun_sig (config : config) (types_infos : TA.type_infos) - (sg : A.fun_sig) (input_names : string option list) - (bid : T.RegionGroupId.id option) : fun_sig_named_outputs = +let translate_fun_sig (config : config) (fun_id : A.fun_id) + (types_infos : TA.type_infos) (sg : A.fun_sig) + (input_names : string option list) (bid : T.RegionGroupId.id option) : + fun_sig_named_outputs = (* Retrieve the list of parent backward functions *) let gid, parents = match bid with @@ -537,15 +562,9 @@ let translate_fun_sig (config : config) (types_infos : TA.type_infos) in (* Does the function take a state as input, does it return a state and can * it fail? *) - (* For now, all the translated functions can fail *) - let can_fail = true in - (* For now, all translated functions have an input state if we setup - * the translation to use states *) - let input_state = config.use_state_monad in - (* Only the forward functions return a state *) - let output_state = input_state && bid = None in + let effect_info = get_fun_effect_info config fun_id bid in (* *) - let state_ty = if input_state then [ mk_state_ty ] else [] in + let state_ty = if effect_info.input_state then [ mk_state_ty ] else [] in (* Concatenate the inputs, in the following order: * - forward inputs * - state input @@ -589,10 +608,11 @@ let translate_fun_sig (config : config) (types_infos : TA.type_infos) let output = mk_simpl_tuple_ty doutputs in (* Add the output state *) let output = - if output_state then mk_simpl_tuple_ty [ mk_state_ty; output ] else output + if effect_info.output_state then mk_simpl_tuple_ty [ mk_state_ty; output ] + else output in (* Wrap in a result type *) - if can_fail then mk_result_ty output else output + if effect_info.can_fail then mk_result_ty output else output in (* Type parameters *) let type_params = sg.type_params in @@ -602,9 +622,9 @@ let translate_fun_sig (config : config) (types_infos : TA.type_infos) num_fwd_inputs = List.length fwd_inputs; num_back_inputs = (if bid = None then None else Some (List.length back_inputs)); - input_state; - output_state; - can_fail; + input_state = effect_info.input_state; + output_state = effect_info.output_state; + can_fail = effect_info.can_fail; } in let sg = { type_params; inputs; output; doutputs; info } in @@ -1046,30 +1066,6 @@ let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) : let abs_ancestors = list_ancestor_abstractions ctx abs in (call_info.forward, abs_ancestors) -type fun_effect_info = { - can_fail : bool; - input_state : bool; - output_state : bool; -} -(** TODO: factorize with fun_sig_info? - TODO: use an enumeration - *) - -(** Small utility. *) -let get_fun_effect_info (config : config) (fun_id : A.fun_id) - (gid : T.RegionGroupId.id option) : fun_effect_info = - match fun_id with - | A.Regular _ -> - let input_state = config.use_state_monad in - let output_state = input_state && gid = None in - { can_fail = true; input_state; output_state } - | A.Assumed aid -> - { - can_fail = Assumed.assumed_is_monadic aid; - input_state = false; - output_state = false; - } - let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx) : texpression = match e with @@ -1085,7 +1081,8 @@ and translate_panic (config : config) (ctx : bs_ctx) : texpression = * we don't match on panics which happen inside the function body - * but it won't be true anymore once we translate individual blocks *) (* If we use a state monad, we need to add a lambda for the state variable *) - if config.use_state_monad then + (* Note that only forward functions return a state *) + if config.use_state_monad && ctx.bid <> None then (* Create the `Fail` value *) let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; ctx.output_ty ] in let ret_v = mk_result_fail_texpression ret_ty in @@ -1109,22 +1106,21 @@ and translate_return (config : config) (opt_v : V.typed_value option) (* Forward function *) let v = Option.get opt_v in let v = typed_value_to_texpression ctx v in - (* We don't synthesize the same expression depending on the monad we use: + (* We may need to return a state * - error-monad: Return x - * - state-error monad: fun state -> Return (state, x) + * - state-error: Return (state, x) * *) - (* TODO: we should use a `return` function, it would be cleaner *) if config.use_state_monad then - let _, state_var = - fresh_var (Some ConstStrings.state_basename) mk_state_ty ctx + let state_var = + { + id = ctx.state_var; + basename = Some ConstStrings.state_basename; + ty = mk_state_ty; + } in let state_rvalue = mk_texpression_from_var state_var in - let ret_v = - mk_result_return_texpression - (mk_simpl_tuple_texpression [ state_rvalue; v ]) - in - let state_var = mk_typed_pattern_from_var state_var None in - mk_abs state_var ret_v + mk_result_return_texpression + (mk_simpl_tuple_texpression [ state_rvalue; v ]) else mk_result_return_texpression v | Some bid -> (* Backward function *) @@ -1137,22 +1133,11 @@ and translate_return (config : config) (opt_v : V.typed_value option) T.RegionGroupId.Map.find bid ctx.backward_outputs in let field_values = List.map mk_texpression_from_var backward_outputs in - (* See the comment about the monads, for the forward function case *) + (* Backward functions never return a state *) (* TODO: we should use a `fail` function, it would be cleaner *) - if config.use_state_monad then - let _, state_var = fresh_var (Some "st") mk_state_ty ctx in - let state_rvalue = mk_texpression_from_var state_var in - let ret_value = mk_simpl_tuple_texpression field_values in - let ret_value = - mk_result_return_texpression - (mk_simpl_tuple_texpression [ state_rvalue; ret_value ]) - in - let state_var = mk_typed_pattern_from_var state_var None in - mk_abs state_var ret_value - else - let ret_value = mk_simpl_tuple_texpression field_values in - let ret_value = mk_result_return_texpression ret_value in - ret_value + let ret_value = mk_simpl_tuple_texpression field_values in + let ret_value = mk_result_return_texpression ret_value in + ret_value and translate_function_call (config : config) (call : S.call) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -1311,7 +1296,8 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) (* Retrieve the values consumed when we called the forward function and * ended the parent backward functions: those give us part of the input * values (rmk: for now, as we disallow nested lifetimes, there can't be - * parent backward functions) *) + * parent backward functions). + * Note that the forward inputs include the input state (if there is one). *) let fwd_inputs = call_info.forward_inputs in let back_ancestors_inputs = List.concat (List.map (fun (_abs, args) -> args) backwards) @@ -1345,6 +1331,16 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) let inst_sg = get_instantiated_fun_sig fun_id (Some abs.back_id) type_args ctx in + log#ldebug + (lazy + ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs (" + ^ string_of_int (List.length inputs) + ^ "): " + ^ String.concat ", " (List.map show_texpression inputs) + ^ "\n- inst_sg.inputs (" + ^ string_of_int (List.length inst_sg.inputs) + ^ "): " + ^ String.concat ", " (List.map show_ty inst_sg.inputs))); List.iter (fun (x, ty) -> assert ((x : texpression).ty = ty)) (List.combine inputs inst_sg.inputs); @@ -1715,6 +1711,16 @@ let translate_fun_decl (config : config) (ctx : bs_ctx) List.map (fun v -> mk_typed_pattern_from_var v None) inputs in (* Sanity check *) + log#ldebug + (lazy + ("SymbolicToPure.translate_fun_decl:" ^ "\n- forward_inputs: " + ^ String.concat ", " (List.map show_var ctx.forward_inputs) + ^ "\n- input_state: " + ^ String.concat ", " (List.map show_var input_state) + ^ "\n- backward_inputs: " + ^ String.concat ", " (List.map show_var backward_inputs) + ^ "\n- signature.inputs: " + ^ String.concat ", " (List.map show_ty signature.inputs))); assert ( List.for_all (fun (var, ty) -> (var : var).ty = ty) @@ -1756,14 +1762,17 @@ let translate_fun_signatures (config : config) (types_infos : TA.type_infos) let translate_one (fun_id : A.fun_id) (input_names : string option list) (sg : A.fun_sig) : (regular_fun_id * fun_sig_named_outputs) list = (* The forward function *) - let fwd_sg = translate_fun_sig config types_infos sg input_names None in + let fwd_sg = + translate_fun_sig config fun_id types_infos sg input_names None + in let fwd_id = (fun_id, None) in (* The backward functions *) let back_sgs = List.map (fun (rg : T.region_var_group) -> let tsg = - translate_fun_sig config types_infos sg input_names (Some rg.id) + translate_fun_sig config fun_id types_infos sg input_names + (Some rg.id) in let id = (fun_id, Some rg.id) in (id, tsg)) diff --git a/src/Translate.ml b/src/Translate.ml index d69f1379..92261dba 100644 --- a/src/Translate.ml +++ b/src/Translate.ml @@ -232,8 +232,10 @@ let translate_function_to_pure (config : C.partial_config) let backward_sg = RegularFunIdMap.find (A.Regular def_id, Some back_id) fun_sigs in + (* We need to ignore the forward inputs, and the state input (if there is) *) let _, backward_inputs = - Collections.List.split_at backward_sg.sg.inputs num_forward_inputs + Collections.List.split_at backward_sg.sg.inputs + (num_forward_inputs + if use_state_monad then 1 else 0) in (* As we forbid nested borrows, the additional inputs for the backward * functions come from the borrows in the return value of the rust function: diff --git a/src/main.ml b/src/main.ml index e635d910..8afbc0cd 100644 --- a/src/main.ml +++ b/src/main.ml @@ -122,7 +122,8 @@ let () = (* Set up the logging - for now we use default values - TODO: use the * command-line arguments *) - Easy_logging.Handlers.set_level main_logger_handler EL.Info; + (* By setting a level for the main_logger_handler, we filter everything *) + Easy_logging.Handlers.set_level main_logger_handler EL.Debug; main_log#set_level EL.Info; llbc_of_json_logger#set_level EL.Info; pre_passes_log#set_level EL.Info; |