diff options
author | Son HO | 2024-04-11 20:32:15 +0200 |
---|---|---|
committer | GitHub | 2024-04-11 20:32:15 +0200 |
commit | 77d74452489f85f558efe07d72d0200c80b16444 (patch) | |
tree | 810c6504b8e5b2fcde58841e25079d5e8c8e92ae /compiler | |
parent | 4fb9c9f655a9ffc3b4a1a717988311c057c9c599 (diff) | |
parent | 2f8aa9b47acb5c98aed91c29b04f71099452e781 (diff) |
Merge pull request #123 from AeneasVerif/son/clean
Cleanup the code in preparation of the nested loops
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/Extract.ml | 6 | ||||
-rw-r--r-- | compiler/ExtractBase.ml | 8 | ||||
-rw-r--r-- | compiler/PrintPure.ml | 4 | ||||
-rw-r--r-- | compiler/Pure.ml | 2 | ||||
-rw-r--r-- | compiler/PureMicroPasses.ml | 21 | ||||
-rw-r--r-- | compiler/PureTypeCheck.ml | 2 | ||||
-rw-r--r-- | compiler/PureUtils.ml | 14 | ||||
-rw-r--r-- | compiler/SymbolicToPure.ml | 245 | ||||
-rw-r--r-- | compiler/Translate.ml | 2 |
9 files changed, 181 insertions, 123 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 985fb470..6eeef772 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -2867,7 +2867,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "="; F.pp_print_space fmt (); let success = - ctx_get_variant def.meta (TAssumed TResult) result_return_id ctx + ctx_get_variant def.meta (TAssumed TResult) result_ok_id ctx in F.pp_print_string fmt (success ^ " ())") | Coq -> @@ -2898,11 +2898,11 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "=="; F.pp_print_space fmt (); let success = - ctx_get_variant def.meta (TAssumed TResult) result_return_id ctx + ctx_get_variant def.meta (TAssumed TResult) result_ok_id ctx in F.pp_print_string fmt (success ^ " ())") | HOL4 -> - F.pp_print_string fmt "val _ = assert_return ("; + F.pp_print_string fmt "val _ = assert_ok ("; F.pp_print_string fmt "“"; let fun_name = ctx_get_local_function def.meta def.def_id def.loop_id ctx diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 47b613c2..656d2f27 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -1020,7 +1020,7 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list = match !backend with | FStar -> [ - (TResult, result_return_id, "Return"); + (TResult, result_ok_id, "Ok"); (TResult, result_fail_id, "Fail"); (TError, error_failure_id, "Failure"); (TError, error_out_of_fuel_id, "OutOfFuel"); @@ -1029,7 +1029,7 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list = ] | Coq -> [ - (TResult, result_return_id, "Return"); + (TResult, result_ok_id, "Ok"); (TResult, result_fail_id, "Fail_"); (TError, error_failure_id, "Failure"); (TError, error_out_of_fuel_id, "OutOfFuel"); @@ -1038,7 +1038,7 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list = ] | Lean -> [ - (TResult, result_return_id, "Result.ret"); + (TResult, result_ok_id, "Result.ok"); (TResult, result_fail_id, "Result.fail"); (* For panic: we omit the prefix "Error." because the type is always clear from the context. Also, "Error" is often used by user-defined @@ -1049,7 +1049,7 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list = ] | HOL4 -> [ - (TResult, result_return_id, "Return"); + (TResult, result_ok_id, "Ok"); (TResult, result_fail_id, "Fail"); (TError, error_failure_id, "Failure"); (* No Fuel::Zero on purpose *) diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 97ea6048..db9c583d 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -312,7 +312,7 @@ let adt_variant_to_string ?(meta = None) (env : fmt_env) (adt_id : type_id) craise_opt_meta __FILE__ __LINE__ meta "Unreachable" | TResult -> let variant_id = Option.get variant_id in - if variant_id = result_return_id then "@Result::Return" + if variant_id = result_ok_id then "@Result::Return" else if variant_id = result_fail_id then "@Result::Fail" else craise_opt_meta __FILE__ __LINE__ meta @@ -395,7 +395,7 @@ let adt_g_value_to_string ?(meta : Meta.meta option = None) (env : fmt_env) craise_opt_meta __FILE__ __LINE__ meta "Unreachable" | TResult -> let variant_id = Option.get variant_id in - if variant_id = result_return_id then + if variant_id = result_ok_id then match field_values with | [ v ] -> "@Result::Return " ^ v | _ -> diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 7366783c..451767f8 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -92,7 +92,7 @@ type assumed_ty = (* TODO: we should never directly manipulate [Return] and [Fail], but rather * the monadic functions [return] and [fail] (makes treatment of error and * state-error monads more uniform) *) -let result_return_id = VariantId.of_int 0 +let result_ok_id = VariantId.of_int 0 let result_fail_id = VariantId.of_int 1 let option_some_id = T.option_some_id let option_none_id = T.option_none_id diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index ebc5c65f..004ecfef 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -752,7 +752,7 @@ let simplify_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = }, x ) -> (* return/fail case *) - if variant_id = result_return_id then + if variant_id = result_ok_id then (* Return case - note that the simplification we just perform might have unlocked the tuple simplification below *) self#visit_Let env false lv x next @@ -1084,19 +1084,19 @@ let filter_useless (_ctx : trans_ctx) (def : fun_decl) : fun_decl = let body = { body with body = body_exp; inputs_lvs } in { def with body = Some body } -(** Simplify the lets immediately followed by a return. +(** Simplify the lets immediately followed by an ok. Ex.: {[ x <-- f y; - Return x + Ok x ~~> f y ]} *) -let simplify_let_then_return _ctx (def : fun_decl) = +let simplify_let_then_ok _ctx (def : fun_decl) = (* Match a pattern and an expression: evaluates to [true] if the expression is actually exactly the pattern *) let rec match_pattern_and_expr (pat : typed_pattern) (e : texpression) : bool @@ -1152,7 +1152,7 @@ let simplify_let_then_return _ctx (def : fun_decl) = | Some e -> if match_pattern_and_expr lv e then (* We need to wrap the right-value in a ret *) - (mk_result_return_texpression def.meta rv).e + (mk_result_ok_texpression def.meta rv).e else not_simpl_e | None -> if match_pattern_and_expr lv next_e then rv.e else not_simpl_e @@ -1791,7 +1791,7 @@ let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = let err_v = mk_texpression_from_var err_var in let fail_value = mk_result_fail_texpression def.meta err_v e.ty in let fail_branch = { pat = fail_pat; branch = fail_value } in - let success_pat = mk_result_return_pattern lv in + let success_pat = mk_result_ok_pattern lv in let success_branch = { pat = success_pat; branch = e } in let switch_body = Match [ fail_branch; success_branch ] in let e = Switch (re, switch_body) in @@ -1854,9 +1854,9 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = f y ]} *) - let def = simplify_let_then_return ctx def in + let def = simplify_let_then_ok ctx def in log#ldebug - (lazy ("simplify_let_then_return:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + (lazy ("simplify_let_then_ok:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); (* Simplify the aggregated ADTs. @@ -1892,11 +1892,10 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = (* Simplify the let-then return again (the lambda simplification may have unlocked more simplifications here) *) - let def = simplify_let_then_return ctx def in + let def = simplify_let_then_ok ctx def in log#ldebug (lazy - ("simplify_let_then_return (pass 2):\n\n" ^ fun_decl_to_string ctx def - ^ "\n")); + ("simplify_let_then_ok (pass 2):\n\n" ^ fun_decl_to_string ctx def ^ "\n")); (* Decompose the monadic let-bindings - used by Coq *) let def = diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index 27044c27..c1da4019 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -32,7 +32,7 @@ let get_adt_field_types (meta : Meta.meta) | TResult -> let ty = Collections.List.to_cons_nil generics.types in let variant_id = Option.get variant_id in - if variant_id = result_return_id then [ ty ] + if variant_id = result_ok_id then [ ty ] else if variant_id = result_fail_id then [ mk_error_ty ] else craise __FILE__ __LINE__ meta diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 1bef43b4..fdd14eba 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -586,12 +586,12 @@ let mk_result_fail_texpression_with_error_id (meta : Meta.meta) let error = mk_error error in mk_result_fail_texpression meta error ty -let mk_result_return_texpression (meta : Meta.meta) (v : texpression) : - texpression = +let mk_result_ok_texpression (meta : Meta.meta) (v : texpression) : texpression + = let type_args = [ v.ty ] in let ty = TAdt (TAssumed TResult, mk_generic_args_from_types type_args) in let id = - AdtCons { adt_id = TAssumed TResult; variant_id = Some result_return_id } + AdtCons { adt_id = TAssumed TResult; variant_id = Some result_ok_id } in let qualif = { id; generics = mk_generic_args_from_types type_args } in let cons_e = Qualif qualif in @@ -613,11 +613,9 @@ let mk_result_fail_pattern_ignore_error (ty : ty) : typed_pattern = let error_pat : pattern = PatDummy in mk_result_fail_pattern error_pat ty -let mk_result_return_pattern (v : typed_pattern) : typed_pattern = +let mk_result_ok_pattern (v : typed_pattern) : typed_pattern = let ty = TAdt (TAssumed TResult, mk_generic_args_from_types [ v.ty ]) in - let value = - PatAdt { variant_id = Some result_return_id; field_values = [ v ] } - in + let value = PatAdt { variant_id = Some result_ok_id; field_values = [ v ] } in { value; ty } let opt_unmeta_mplace (e : texpression) : mplace option * texpression = @@ -791,6 +789,6 @@ let opt_destruct_ret (e : texpression) : texpression option = ty = _; }, arg ) - when variant_id = Some result_return_id -> + when variant_id = Some result_ok_id -> Some arg | _ -> None diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 93f9ef75..15b52237 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -268,6 +268,22 @@ type bs_ctx = { Note that when a function contains a loop, we group the function symbolic AST and the loop symbolic AST in a single function. *) + mk_return : (bs_ctx -> texpression option -> texpression) option; + (** Small helper: translate a [return] expression, given a value to "return". + The translation of [return] depends on the context, and in particular depends on + whether we are inside a subexpression like a loop or not. + + Note that the function consumes an optional expression, which is: + - [Some] for a forward computation + - [None] for a backward computation + + We initialize this at [None]. + *) + mk_panic : texpression option; + (** Small helper: translate a [fail] expression. + + We initialize this at [None]. + *) } [@@deriving show] @@ -2012,54 +2028,7 @@ let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression = | Loop loop -> translate_loop loop ctx | Error (meta, msg) -> translate_error meta msg -and translate_panic (ctx : bs_ctx) : texpression = - (* Here we use the function return type - note that it is ok because - * we don't match on panics which happen inside the function body - - * but it won't be true anymore once we translate individual blocks *) - (* If we use a state monad, we need to add a lambda for the state variable *) - (* Note that only forward functions return a state *) - let effect_info = ctx_get_effect_info ctx in - (* TODO: we should use a [Fail] function *) - let mk_output output_ty = - if effect_info.stateful then - (* Create the [Fail] value *) - let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in - let ret_v = - mk_result_fail_texpression_with_error_id ctx.meta error_failure_id - ret_ty - in - ret_v - else - mk_result_fail_texpression_with_error_id ctx.meta error_failure_id - output_ty - in - if ctx.inside_loop && Option.is_some ctx.bid then - (* We are synthesizing the backward function of a loop body *) - let bid = Option.get ctx.bid in - let loop_id = Option.get ctx.loop_id in - let loop = LoopId.Map.find loop_id ctx.loops in - let tys = RegionGroupId.Map.find bid loop.back_outputs in - let output = mk_simpl_tuple_ty tys in - mk_output output - else - (* Regular function, or forward function (the forward translation for - a loop has the same return type as the parent function) - *) - match ctx.bid with - | None -> - let back_tys = compute_back_tys ctx.sg None in - let back_tys = List.filter_map (fun x -> x) back_tys in - let tys = - if ctx.sg.fwd_info.ignore_output then back_tys - else ctx.sg.fwd_output :: back_tys - in - let output = mk_simpl_tuple_ty tys in - mk_output output - | Some bid -> - let output = - mk_simpl_tuple_ty (RegionGroupId.Map.find bid ctx.sg.back_sg).outputs - in - mk_output output +and translate_panic (ctx : bs_ctx) : texpression = Option.get ctx.mk_panic (** [opt_v]: the value to return, in case we translate a forward body. @@ -2071,42 +2040,8 @@ and translate_panic (ctx : bs_ctx) : texpression = *) and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = - (* There are two cases: - - either we reach the return of a forward function or a forward loop body, - in which case the optional value should be [Some] (it is the returned value) - - or we are translating a backward function, in which case it should be [None] - *) - (* Compute the values that we should return *without the state and the result - wrapper* *) - let output = - match ctx.bid with - | None -> - (* Forward function *) - let v = Option.get opt_v in - typed_value_to_texpression ctx ectx v - | Some _ -> - (* Backward function *) - (* Sanity check *) - sanity_check __FILE__ __LINE__ (opt_v = None) ctx.meta; - (* Group the variables in which we stored the values we need to give back. - See the explanations for the [SynthInput] case in [translate_end_abstraction] *) - let backward_outputs = Option.get ctx.backward_outputs in - let field_values = List.map mk_texpression_from_var backward_outputs in - mk_simpl_tuple_texpression ctx.meta field_values - in - (* We may need to return a state - * - error-monad: Return x - * - state-error: Return (state, x) - * *) - let effect_info = ctx_get_effect_info ctx in - let output = - if effect_info.stateful then - let state_rvalue = mk_state_texpression ctx.state_var in - mk_simpl_tuple_texpression ctx.meta [ state_rvalue; output ] - else output - in - (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *) - mk_result_return_texpression ctx.meta output + let opt_v = Option.map (typed_value_to_texpression ctx ectx) opt_v in + (Option.get ctx.mk_return) ctx opt_v and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) (ctx : bs_ctx) : texpression = @@ -2154,8 +2089,7 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) else output in (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *) - mk_emeta (Tag "return_with_loop") - (mk_result_return_texpression ctx.meta output) + mk_emeta (Tag "return_with_loop") (mk_result_ok_texpression ctx.meta output) and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -3123,6 +3057,49 @@ and translate_forward_end (ectx : C.eval_ctx) (ctx, backward_inputs_no_state @ [ var ]) else (ctx, backward_inputs_no_state) in + (* Update the functions mk_return and mk_panic *) + let effect_info = back_sg.effect_info in + let mk_return ctx v = + assert (v = None); + let output = + (* Group the variables in which we stored the values we need to give back. + See the explanations for the [SynthInput] case in [translate_end_abstraction] *) + let backward_outputs = Option.get ctx.backward_outputs in + let field_values = + List.map mk_texpression_from_var backward_outputs + in + mk_simpl_tuple_texpression ctx.meta field_values + in + let output = + if effect_info.stateful then + let state_rvalue = mk_state_texpression ctx.state_var in + mk_simpl_tuple_texpression ctx.meta [ state_rvalue; output ] + else output + in + (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *) + mk_result_ok_texpression ctx.meta output + in + let mk_panic = + (* TODO: we should use a [Fail] function *) + let mk_output output_ty = + if effect_info.stateful then + (* Create the [Fail] value *) + let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in + let ret_v = + mk_result_fail_texpression_with_error_id ctx.meta + error_failure_id ret_ty + in + ret_v + else + mk_result_fail_texpression_with_error_id ctx.meta + error_failure_id output_ty + in + let output = + mk_simpl_tuple_ty + (RegionGroupId.Map.find bid ctx.sg.back_sg).outputs + in + mk_output output + in { ctx with backward_inputs_no_state = @@ -3131,6 +3108,8 @@ and translate_forward_end (ectx : C.eval_ctx) backward_inputs_with_state = RegionGroupId.Map.add bid backward_inputs_with_state ctx.backward_inputs_with_state; + mk_return = Some mk_return; + mk_panic = Some mk_panic; } in @@ -3230,7 +3209,7 @@ and translate_forward_end (ectx : C.eval_ctx) let state_var = List.map mk_texpression_from_var state_var in let ret = mk_simpl_tuple_texpression ctx.meta (state_var @ [ ret ]) in - let ret = mk_result_return_texpression ctx.meta ret in + let ret = mk_result_ok_texpression ctx.meta ret in (* Introduce all the let-bindings *) @@ -3475,7 +3454,6 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = in (* Compute the backward outputs *) - let ctx = ref ctx in let rg_to_given_back_tys = RegionGroupId.Map.map (fun tys -> @@ -3483,13 +3461,12 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = List.map (fun ty -> cassert __FILE__ __LINE__ - (not (TypesUtils.ty_has_borrows !ctx.type_ctx.type_infos ty)) - !ctx.meta "The types shouldn't contain borrows"; - ctx_translate_fwd_ty !ctx ty) + (not (TypesUtils.ty_has_borrows ctx.type_ctx.type_infos ty)) + ctx.meta "The types shouldn't contain borrows"; + ctx_translate_fwd_ty ctx ty) tys) loop.rg_to_given_back_tys in - let ctx = !ctx in (* The output type of the loop function *) let fwd_effect_info = { ctx.sg.fwd_info.effect_info with is_rec = true } in @@ -3594,6 +3571,44 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = { types; const_generics; trait_refs } in + (* Update the helpers to translate the fail and return expressions *) + let mk_panic = + (* Note that we reuse the effect information from the parent function *) + let effect_info = ctx_get_effect_info ctx in + let back_tys = compute_back_tys ctx.sg None in + let back_tys = List.filter_map (fun x -> x) back_tys in + let tys = + if ctx.sg.fwd_info.ignore_output then back_tys + else ctx.sg.fwd_output :: back_tys + in + let output_ty = mk_simpl_tuple_ty tys in + if effect_info.stateful then + (* Create the [Fail] value *) + let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in + let ret_v = + mk_result_fail_texpression_with_error_id ctx.meta error_failure_id + ret_ty + in + ret_v + else + mk_result_fail_texpression_with_error_id ctx.meta error_failure_id + output_ty + in + let mk_return ctx v = + match v with + | None -> raise (Failure "Unexpected") + | Some output -> + let effect_info = ctx_get_effect_info ctx in + let output = + if effect_info.stateful then + let state_rvalue = mk_state_texpression ctx.state_var in + mk_simpl_tuple_texpression ctx.meta [ state_rvalue; output ] + else output + in + (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *) + mk_result_ok_texpression ctx.meta output + in + let loop_info = { loop_id; @@ -3609,7 +3624,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = } in let loops = LoopId.Map.add loop_id loop_info ctx.loops in - { ctx with loops } + { ctx with loops; mk_return = Some mk_return; mk_panic = Some mk_panic } in (* Update the context to translate the function end *) @@ -3776,6 +3791,50 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = let effect_info = get_fun_effect_info ctx (FunId (FRegular def_id)) None None in + let mk_return ctx v = + match v with + | None -> + raise + (Failure + "Unexpected: reached a return expression without value in a \ + function forward expression") + | Some output -> + let output = + if effect_info.stateful then + let state_rvalue = mk_state_texpression ctx.state_var in + mk_simpl_tuple_texpression ctx.meta [ state_rvalue; output ] + else output + in + (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *) + mk_result_ok_texpression ctx.meta output + in + let mk_panic = + (* TODO: we should use a [Fail] function *) + let mk_output output_ty = + if effect_info.stateful then + (* Create the [Fail] value *) + let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in + let ret_v = + mk_result_fail_texpression_with_error_id ctx.meta + error_failure_id ret_ty + in + ret_v + else + mk_result_fail_texpression_with_error_id ctx.meta error_failure_id + output_ty + in + let back_tys = compute_back_tys ctx.sg None in + let back_tys = List.filter_map (fun x -> x) back_tys in + let tys = + if ctx.sg.fwd_info.ignore_output then back_tys + else ctx.sg.fwd_output :: back_tys + in + let output = mk_simpl_tuple_ty tys in + mk_output output + in + let ctx = + { ctx with mk_return = Some mk_return; mk_panic = Some mk_panic } + in let body = translate_expression body ctx in (* Add a match over the fuel, if necessary *) let body = diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 870f8a22..9460c5f4 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -158,6 +158,8 @@ let translate_function_to_pure_aux (trans_ctx : trans_ctx) inside_loop = false; loop_ids_map; loops = Pure.LoopId.Map.empty; + mk_return = None; + mk_panic = None; } in |