From fb6fdfd0c57de1ce16fb6bc373d5593c9446b0bb Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 4 May 2022 14:13:20 +0200 Subject: Make progress updating the code --- src/ExtractToFStar.ml | 5 +- src/PureMicroPasses.ml | 191 +++++++------------------------------------------ src/SymbolicToPure.ml | 11 ++- src/Translate.ml | 56 ++++++++++----- src/main.ml | 7 +- 5 files changed, 76 insertions(+), 194 deletions(-) (limited to 'src') diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index 068448e9..4baa1fd6 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -1368,8 +1368,7 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) if has_decreases_clause then ( F.pp_print_string fmt "Tot"; F.pp_print_space fmt ()); - extract_ty ctx fmt has_decreases_clause - (Collections.List.to_cons_nil def.signature.outputs); + extract_ty ctx fmt has_decreases_clause def.signature.output; (* Close the box for the return type *) F.pp_close_box fmt (); (* Print the decrease clause - rk.: a function with a decreases clause @@ -1476,7 +1475,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) if sg.type_params = [] && (sg.inputs = [ mk_unit_ty ] || sg.inputs = []) - && sg.outputs = [ mk_result_ty mk_unit_ty ] + && sg.output = mk_result_ty mk_unit_ty then ( (* Add a break before *) F.pp_print_break fmt 0 0; diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index f76dd2f4..0c371420 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -63,7 +63,6 @@ type config = { borrows as inputs, it can't return mutable borrows; we actually dynamically check for that). *) - use_state_monad : bool; (** TODO: remove *) } (** A configuration to control the application of the passes *) @@ -920,15 +919,9 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) *) let filter_if_backward_with_no_outputs (config : config) (def : fun_decl) : fun_decl option = - let return_ty = - if config.use_state_monad then - mk_arrow mk_state_ty - (mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; mk_unit_ty ])) - else mk_result_ty (mk_simpl_tuple_ty [ mk_unit_ty ]) - in if config.filter_useless_functions && Option.is_some def.back_id - && def.signature.outputs = [ return_ty ] + && def.signature.output = mk_result_ty mk_unit_ty then None else Some def @@ -954,7 +947,7 @@ let keep_forward (config : config) (trans : pure_fun_translation) : bool = * they should be lists of length 1. *) if config.filter_useless_functions - && fwd.signature.outputs = [ mk_result_ty mk_unit_ty ] + && fwd.signature.output = mk_result_ty mk_unit_ty && backs <> [] then false else true @@ -1108,85 +1101,27 @@ let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : { def with body } (** Unfold the monadic let-bindings to explicit matches. *) -let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) - (def : fun_decl) : fun_decl = +let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = match def.body with | None -> def | Some body -> - (* We may need to introduce fresh variables for the state *) - let fresh_var_id = - let var_cnt = get_body_min_var_counter body in - let _, fresh_var_id = VarId.mk_stateful_generator var_cnt in - fresh_var_id - in - let fresh_state_var () = - let id = fresh_var_id () in - { id; basename = Some ConstStrings.state_basename; ty = mk_state_ty } - in (* It is a very simple map *) let obj = object (_self) inherit [_] map_expression as super - method! visit_Switch env scrut switch_body = - (* We transform the switches the following way (if their branches - * are stateful): - * ``` - * match x with - * | Pati -> branchi - * - * ~~> - * - * fun st -> - * match x with - * | Pati -> branchi st - * ``` - * - * The reason is that after unfolding the monadic lets, we often - * have this: `(match x with | ...) st`, and we want to "push" the - * `st` variable inside. - *) - let sb_ty = get_switch_body_ty switch_body in - if Option.is_some (opt_destruct_state_monad_result sb_ty) then - (* Generate a fresh state variable *) - let state_var = fresh_state_var () in - let state_value = mk_texpression_from_var state_var in - let state_lvar = mk_typed_pattern_from_var state_var None in - (* Apply in all the branches and reconstruct the switch *) - let mk_app e = mk_app e state_value in - let switch_body = map_switch_body_branches mk_app switch_body in - let e = mk_switch scrut switch_body in - let e = mk_abs state_lvar e in - (* Introduce the lambda and continue - * Rk.: we will revisit the switch, but won't loop because its - * type has now changed (the `state -> ...` disappeared) *) - super#visit_Abs env state_lvar e - else super#visit_Switch env scrut switch_body - method! visit_Let env monadic lv re e = - (* For now, we do the following transformation: + (* We simply do the following transformation: * ``` - * x <-- re; e + * pat <-- re; e * * ~~> * - * (fun st -> - * match re st with - * | Return (st', x) -> e st' - * | Fail err -> Fail err) + * match re with + * | Fail err -> Fail err + * | Return pat -> e * ``` - * - * We rely on the simplification pass which comes later to normalize - * away expressions like `(fun x -> e) y`. - * - * TODO: fix the use of state-error monads (with the bakward functions, - * we apply some updates twice... - * It would be better if symbolic to pure generated code of the - * following shape: - * `(st1, x) <-- e st0` - * Then, this micro-pass would only expand the monadic let-bindings - * (we wouldn't need to introduce state variables). - * *) + *) (* TODO: we should use a monad "kind" instead of a boolean *) if not monadic then super#visit_Let env monadic lv re e else @@ -1197,95 +1132,24 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) *) (* TODO: this information should be computed in SymbolicToPure and * store in an enum ("monadic" should be an enum, not a bool). *) - let re_uses_state = - Option.is_some (opt_destruct_state_monad_result re.ty) - in - if re_uses_state then ( - let e0 = e in - (* Create a fresh state variable *) - let state_var = fresh_state_var () in - (* The type of `e` is: `state -> e_no_arrow_ty` *) - let _, e_no_arrow_ty = destruct_arrow e.ty in - let e_no_monad_ty = destruct_result e_no_arrow_ty in - let _, re_no_arrow_ty = destruct_arrow re.ty in - let re_no_monad_ty = destruct_result re_no_arrow_ty in - (* Add the state argument on the right-expression *) - let re = - let state_value = mk_texpression_from_var state_var in - mk_app re state_value - in - (* Create the match *) - let fail_pat = mk_result_fail_pattern re_no_monad_ty in - let fail_value = mk_result_fail_texpression e_no_monad_ty in - let fail_branch = { pat = fail_pat; branch = fail_value } in - (* The `Success` branch introduces a fresh state variable *) - let pat_state_var = fresh_state_var () in - let pat_state_pattern = - mk_typed_pattern_from_var pat_state_var None - in - let success_pat = - mk_result_return_pattern - (mk_simpl_tuple_pattern [ pat_state_pattern; lv ]) - in - let pat_state_rvalue = mk_texpression_from_var pat_state_var in - (* TODO: write a utility to create matches (and perform - * type-checking, etc.) *) - let success_branch = - { pat = success_pat; branch = mk_app e pat_state_rvalue } - in - let switch_body = Match [ fail_branch; success_branch ] in - let e = Switch (re, switch_body) in - let e = { e; ty = e_no_arrow_ty } in - (* Add the lambda to introduce the state variable *) - let e = mk_abs (mk_typed_pattern_from_var state_var None) e in - (* Sanity check *) - assert (e0.ty = e.ty); - assert (fail_branch.branch.ty = success_branch.branch.ty); - (* Continue *) - super#visit_expression env e.e) - else - let re_ty = Option.get (opt_destruct_result re.ty) in - assert (lv.ty = re_ty); - let fail_pat = mk_result_fail_pattern lv.ty in - let fail_value = mk_result_fail_texpression e.ty in - let fail_branch = { pat = fail_pat; branch = fail_value } in - let success_pat = mk_result_return_pattern lv in - let success_branch = { pat = success_pat; branch = e } in - let switch_body = Match [ fail_branch; success_branch ] in - let e = Switch (re, switch_body) in - (* Continue *) - super#visit_expression env e + let re_ty = Option.get (opt_destruct_result re.ty) in + assert (lv.ty = re_ty); + let fail_pat = mk_result_fail_pattern lv.ty in + let fail_value = mk_result_fail_texpression e.ty in + let fail_branch = { pat = fail_pat; branch = fail_value } in + let success_pat = mk_result_return_pattern lv in + let success_branch = { pat = success_pat; branch = e } in + let switch_body = Match [ fail_branch; success_branch ] in + let e = Switch (re, switch_body) in + (* Continue *) + super#visit_expression env e end in - (* Update the body: add *) - let body, signature = - let state_var = fresh_state_var () in - (* First, unfold the expressions inside the body *) - let body_e = obj#visit_texpression () body.body in - (* Then, add a "state" input variable if necessary: *) - if config.use_state_monad then - (* - in the body *) - let state_rvalue = mk_texpression_from_var state_var in - let body_e = mk_app body_e state_rvalue in - (* - in the signature *) - let sg = def.signature in - (* Input types *) - let sg_inputs = sg.inputs @ [ mk_state_ty ] in - (* Output types *) - let sg_outputs = Collections.List.to_cons_nil sg.outputs in - let _, sg_outputs = dest_arrow_ty sg_outputs in - let sg_outputs = [ sg_outputs ] in - let sg = { sg with inputs = sg_inputs; outputs = sg_outputs } in - (* Input list *) - let inputs = body.inputs @ [ state_var ] in - let input_lv = mk_typed_pattern_from_var state_var None in - let inputs_lvs = body.inputs_lvs @ [ input_lv ] in - let body = { body = body_e; inputs; inputs_lvs } in - (body, sg) - else ({ body with body = body_e }, def.signature) - in + (* Update the body *) + let body_e = obj#visit_texpression () body.body in + let body = { body with body = body_e } in (* Return *) - { def with body = Some body; signature } + { def with body = Some body } (** Apply all the micro-passes to a function. @@ -1359,12 +1223,9 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_decl) : (lazy ("filter_useless:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); (* Decompose the monadic let-bindings - F* specific - * TODO: remove? With the state-error monad, it is becoming completely - * ad-hoc. *) + * TODO: remove? *) let def = if config.decompose_monadic_let_bindings then ( - (* TODO: we haven't updated the code to handle the state-error monad *) - assert (not config.use_state_monad); let def = decompose_monadic_let_bindings ctx def in log#ldebug (lazy @@ -1381,7 +1242,7 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_decl) : (* Unfold the monadic let-bindings *) let def = if config.unfold_monadic_let_bindings then ( - let def = unfold_monadic_let_bindings config ctx def in + let def = unfold_monadic_let_bindings ctx def in log#ldebug (lazy ("unfold_monadic_let_bindings:\n\n" ^ fun_decl_to_string ctx def diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index 466e5562..fa482b8e 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -101,7 +101,12 @@ type bs_ctx = { fun_context : fun_context; fun_decl : A.fun_decl; bid : T.RegionGroupId.id option; (** TODO: rename *) - ret_ty : ty; (** The return type - we use it to translate `Panic` *) + output_ty : ty; + (** The output type - we use it to translate `Panic`. + + This should be the directly translated output type (i.e., no state, + no result). + *) sv_to_var : var V.SymbolicValueId.Map.t; (** Whenever we encounter a new symbolic value (introduced because of a symbolic expansion or upon ending an abstraction, for instance) @@ -1082,7 +1087,7 @@ and translate_panic (config : config) (ctx : bs_ctx) : texpression = (* If we use a state monad, we need to add a lambda for the state variable *) if config.use_state_monad then (* Create the `Fail` value *) - let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; ctx.ret_ty ] in + let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; ctx.output_ty ] in let ret_v = mk_result_fail_texpression ret_ty in (* Add the lambda *) let _, state_var = @@ -1090,7 +1095,7 @@ and translate_panic (config : config) (ctx : bs_ctx) : texpression = in let state_pattern = mk_typed_pattern_from_var state_var None in mk_abs state_pattern ret_v - else mk_result_fail_texpression ctx.ret_ty + else mk_result_fail_texpression ctx.output_ty and translate_return (config : config) (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = diff --git a/src/Translate.ml b/src/Translate.ml index d9b42f6b..d69f1379 100644 --- a/src/Translate.ml +++ b/src/Translate.ml @@ -15,6 +15,12 @@ let log = TranslateCore.log type config = { eval_config : Contexts.partial_config; mp_config : Micro.config; + use_state_monad : bool; + (** If `true`, use a state-error monad. + If `false`, only use an error monad. + + Using a state-error monad is necessary when modelling I/O, for instance. + *) split_files : bool; (** Controls whether we split the generated definitions between different files for the types, clauses and functions, or if we group them in @@ -92,7 +98,7 @@ let translate_function_to_symbolics (config : C.partial_config) TODO: maybe we should introduce a record for this. *) let translate_function_to_pure (config : C.partial_config) - (mp_config : Micro.config) (trans_ctx : trans_ctx) + (mp_config : Micro.config) (use_state_monad : bool) (trans_ctx : trans_ctx) (fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdMap.t) (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : A.fun_decl) : pure_fun_translation = @@ -116,13 +122,14 @@ let translate_function_to_pure (config : C.partial_config) (* Initialize the context *) let forward_sig = RegularFunIdMap.find (A.Regular def_id, None) fun_sigs in - let forward_ret_ty = - match forward_sig.sg.outputs with + let forward_output_ty = + match forward_sig.sg.doutputs with | [ ty ] -> ty | _ -> failwith "Unreachable" in let sv_to_var = V.SymbolicValueId.Map.empty in let var_counter = Pure.VarId.generator_zero in + let state_var, var_counter = Pure.VarId.fresh var_counter in let calls = V.FunCallId.Map.empty in let abstractions = V.AbstractionId.Map.empty in let type_context = @@ -139,10 +146,11 @@ let translate_function_to_pure (config : C.partial_config) { SymbolicToPure.bid = None; (* Dummy for now *) - ret_ty = forward_ret_ty; + output_ty = forward_output_ty; (* Will need to be updated for the backward functions *) sv_to_var; var_counter; + state_var; type_context; fun_context; fun_decl = fdef; @@ -179,7 +187,7 @@ let translate_function_to_pure (config : C.partial_config) { SymbolicToPure.filter_useless_back_calls = mp_config.filter_useless_monadic_calls; - use_state_monad = mp_config.use_state_monad; + use_state_monad; } in @@ -208,8 +216,10 @@ let translate_function_to_pure (config : C.partial_config) let backward_sg = RegularFunIdMap.find (A.Regular def_id, Some back_id) fun_sigs in - let backward_ret_ty = mk_simpl_tuple_ty backward_sg.sg.outputs in - let ctx = { ctx with bid = Some back_id; ret_ty = backward_ret_ty } in + let backward_output_ty = mk_simpl_tuple_ty backward_sg.sg.doutputs in + let ctx = + { ctx with bid = Some back_id; output_ty = backward_output_ty } + in (* Translate *) SymbolicToPure.translate_fun_decl sp_config ctx None @@ -238,7 +248,7 @@ let translate_function_to_pure (config : C.partial_config) * present in the input values of the rust function: for those we reuse * the names of the input values. *) let backward_outputs = - List.combine backward_sg.output_names backward_sg.sg.outputs + List.combine backward_sg.output_names backward_sg.sg.doutputs in let ctx, backward_outputs = SymbolicToPure.fresh_vars backward_outputs ctx @@ -246,7 +256,7 @@ let translate_function_to_pure (config : C.partial_config) let backward_output_tys = List.map (fun (v : Pure.var) -> v.ty) backward_outputs in - let backward_ret_ty = mk_simpl_tuple_ty backward_output_tys in + let backward_output_ty = mk_simpl_tuple_ty backward_output_tys in let backward_inputs = T.RegionGroupId.Map.singleton back_id backward_inputs in @@ -259,7 +269,7 @@ let translate_function_to_pure (config : C.partial_config) { ctx with bid = Some back_id; - ret_ty = backward_ret_ty; + output_ty = backward_output_ty; backward_inputs; backward_outputs; } @@ -276,7 +286,7 @@ let translate_function_to_pure (config : C.partial_config) (pure_forward, pure_backwards) let translate_module_to_pure (config : C.partial_config) - (mp_config : Micro.config) (m : M.llbc_module) : + (mp_config : Micro.config) (use_state_monad : bool) (m : M.llbc_module) : trans_ctx * Pure.type_decl list * (bool * pure_fun_translation) list = (* Debug *) log#ldebug (lazy "translate_module_to_pure"); @@ -316,15 +326,23 @@ let translate_module_to_pure (config : C.partial_config) m.functions in let sigs = List.append assumed_sigs local_sigs in + let sp_config = + { + SymbolicToPure.filter_useless_back_calls = + mp_config.filter_useless_monadic_calls; + use_state_monad; + } + in let fun_sigs = - SymbolicToPure.translate_fun_signatures type_context.type_infos sigs + SymbolicToPure.translate_fun_signatures sp_config type_context.type_infos + sigs in (* Translate all the *transparent* functions *) let pure_translations = List.map - (translate_function_to_pure config mp_config trans_ctx fun_sigs - type_decls_map) + (translate_function_to_pure config mp_config use_state_monad trans_ctx + fun_sigs type_decls_map) m.functions in @@ -349,6 +367,7 @@ type gen_ctx = { type gen_config = { mp_config : Micro.config; + use_state_monad : bool; extract_types : bool; extract_decreases_clauses : bool; extract_template_decreases_clauses : bool; @@ -475,7 +494,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) (* Is there an input parameter "visible" for the state used in * the state error monad (if we use a state error monad)? *) let has_state_param = - config.mp_config.use_state_monad + config.use_state_monad && config.mp_config.unfold_monadic_let_bindings in (* Check if the definition needs to be filtered or not *) @@ -599,7 +618,8 @@ let translate_module (filename : string) (dest_dir : string) (config : config) (m : M.llbc_module) : unit = (* Translate the module to the pure AST *) let trans_ctx, trans_types, trans_funs = - translate_module_to_pure config.eval_config config.mp_config m + translate_module_to_pure config.eval_config config.mp_config + config.use_state_monad m in (* Initialize the extraction context - for now we extract only to F* *) @@ -711,13 +731,14 @@ let translate_module (filename : string) (dest_dir : string) (config : config) } in - let use_state = config.mp_config.use_state_monad in + let use_state = config.use_state_monad in (* Extract one or several files, depending on the configuration *) if config.split_files then ( let base_gen_config = { mp_config = config.mp_config; + use_state_monad = use_state; extract_types = false; extract_decreases_clauses = config.extract_decreases_clauses; extract_template_decreases_clauses = false; @@ -807,6 +828,7 @@ let translate_module (filename : string) (dest_dir : string) (config : config) let gen_config = { mp_config = config.mp_config; + use_state_monad = use_state; extract_types = true; extract_decreases_clauses = config.extract_decreases_clauses; extract_template_decreases_clauses = diff --git a/src/main.ml b/src/main.ml index 9d038f12..e635d910 100644 --- a/src/main.ml +++ b/src/main.ml @@ -177,14 +177,8 @@ let () = unfold_monadic_let_bindings = !unfold_monads; filter_useless_monadic_calls = !filter_useless_calls; filter_useless_functions = !filter_useless_functions; - use_state_monad = not !no_state; } in - (* Small issue: the monadic `bind` only works for the error-monad, not - * the state-error monad (there are definitions to write and piping to do) *) - assert ( - (not micro_passes_config.use_state_monad) - || micro_passes_config.unfold_monadic_let_bindings); let trans_config = { Translate.eval_config; @@ -193,6 +187,7 @@ let () = test_unit_functions; extract_decreases_clauses = not !no_decreases_clauses; extract_template_decreases_clauses = !template_decreases_clauses; + use_state_monad = not !no_state; } in Translate.translate_module filename dest_dir trans_config m -- cgit v1.2.3