diff options
author | Son Ho | 2024-04-11 20:31:16 +0200 |
---|---|---|
committer | Son Ho | 2024-04-11 20:31:16 +0200 |
commit | b6421bc01df278f625a8c95b4ea36ad2e4355718 (patch) | |
tree | 6246ef2b038560e3deae41e4fa700f14390cd14f /compiler/SymbolicToPure.ml | |
parent | 44065f447dc3a2f4b1441b97b9687d1c1b85afbf (diff) | |
parent | 2f8aa9b47acb5c98aed91c29b04f71099452e781 (diff) |
Merge branch 'son/clean' into checked-ops
Diffstat (limited to '')
-rw-r--r-- | compiler/SymbolicToPure.ml | 295 |
1 files changed, 192 insertions, 103 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 0c30f44c..15b52237 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -268,6 +268,22 @@ type bs_ctx = { Note that when a function contains a loop, we group the function symbolic AST and the loop symbolic AST in a single function. *) + mk_return : (bs_ctx -> texpression option -> texpression) option; + (** Small helper: translate a [return] expression, given a value to "return". + The translation of [return] depends on the context, and in particular depends on + whether we are inside a subexpression like a loop or not. + + Note that the function consumes an optional expression, which is: + - [Some] for a forward computation + - [None] for a backward computation + + We initialize this at [None]. + *) + mk_panic : texpression option; + (** Small helper: translate a [fail] expression. + + We initialize this at [None]. + *) } [@@deriving show] @@ -1499,9 +1515,27 @@ let fresh_back_vars_for_current_fun (ctx : bs_ctx) let back_vars = List.map (fun (name, ty) -> - match ty with None -> None | Some ty -> Some (name, ty)) + match ty with + | None -> None + | Some ty -> + (* If the type is not an arrow type, don't use the name "back" + (it is a backward function with no inputs, that is to say a + value) *) + let name = if is_arrow_ty ty then name else None in + Some (name, ty)) back_vars in + (* If there is one backward function or less, we use the name "back" + (there is no point in using the lifetime name, and it makes the + code generation more stable) *) + let num_back_vars = List.length (List.filter_map (fun x -> x) back_vars) in + let back_vars = + if num_back_vars = 1 then + List.map + (Option.map (fun (name, ty) -> (Option.map (fun _ -> "back") name, ty))) + back_vars + else back_vars + in fresh_opt_vars back_vars ctx (** IMPORTANT: do not use this one directly, but rather {!symbolic_value_to_texpression} *) @@ -1963,6 +1997,9 @@ let eval_ctx_to_symbolic_assignments_info (ctx : bs_ctx) (* Return the computed information *) !info +let translate_error (meta : Meta.meta option) (msg : string) : texpression = + { e = EError (meta, msg); ty = Error } + let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression = match e with | S.Return (ectx, opt_v) -> @@ -1989,55 +2026,9 @@ let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression = *) translate_forward_end ectx loop_input_values e back_e ctx | Loop loop -> translate_loop loop ctx + | Error (meta, msg) -> translate_error meta msg -and translate_panic (ctx : bs_ctx) : texpression = - (* Here we use the function return type - note that it is ok because - * we don't match on panics which happen inside the function body - - * but it won't be true anymore once we translate individual blocks *) - (* If we use a state monad, we need to add a lambda for the state variable *) - (* Note that only forward functions return a state *) - let effect_info = ctx_get_effect_info ctx in - (* TODO: we should use a [Fail] function *) - let mk_output output_ty = - if effect_info.stateful then - (* Create the [Fail] value *) - let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in - let ret_v = - mk_result_fail_texpression_with_error_id ctx.meta error_failure_id - ret_ty - in - ret_v - else - mk_result_fail_texpression_with_error_id ctx.meta error_failure_id - output_ty - in - if ctx.inside_loop && Option.is_some ctx.bid then - (* We are synthesizing the backward function of a loop body *) - let bid = Option.get ctx.bid in - let loop_id = Option.get ctx.loop_id in - let loop = LoopId.Map.find loop_id ctx.loops in - let tys = RegionGroupId.Map.find bid loop.back_outputs in - let output = mk_simpl_tuple_ty tys in - mk_output output - else - (* Regular function, or forward function (the forward translation for - a loop has the same return type as the parent function) - *) - match ctx.bid with - | None -> - let back_tys = compute_back_tys ctx.sg None in - let back_tys = List.filter_map (fun x -> x) back_tys in - let tys = - if ctx.sg.fwd_info.ignore_output then back_tys - else ctx.sg.fwd_output :: back_tys - in - let output = mk_simpl_tuple_ty tys in - mk_output output - | Some bid -> - let output = - mk_simpl_tuple_ty (RegionGroupId.Map.find bid ctx.sg.back_sg).outputs - in - mk_output output +and translate_panic (ctx : bs_ctx) : texpression = Option.get ctx.mk_panic (** [opt_v]: the value to return, in case we translate a forward body. @@ -2049,42 +2040,8 @@ and translate_panic (ctx : bs_ctx) : texpression = *) and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = - (* There are two cases: - - either we reach the return of a forward function or a forward loop body, - in which case the optional value should be [Some] (it is the returned value) - - or we are translating a backward function, in which case it should be [None] - *) - (* Compute the values that we should return *without the state and the result - wrapper* *) - let output = - match ctx.bid with - | None -> - (* Forward function *) - let v = Option.get opt_v in - typed_value_to_texpression ctx ectx v - | Some _ -> - (* Backward function *) - (* Sanity check *) - sanity_check __FILE__ __LINE__ (opt_v = None) ctx.meta; - (* Group the variables in which we stored the values we need to give back. - See the explanations for the [SynthInput] case in [translate_end_abstraction] *) - let backward_outputs = Option.get ctx.backward_outputs in - let field_values = List.map mk_texpression_from_var backward_outputs in - mk_simpl_tuple_texpression ctx.meta field_values - in - (* We may need to return a state - * - error-monad: Return x - * - state-error: Return (state, x) - * *) - let effect_info = ctx_get_effect_info ctx in - let output = - if effect_info.stateful then - let state_rvalue = mk_state_texpression ctx.state_var in - mk_simpl_tuple_texpression ctx.meta [ state_rvalue; output ] - else output - in - (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *) - mk_result_return_texpression ctx.meta output + let opt_v = Option.map (typed_value_to_texpression ctx ectx) opt_v in + (Option.get ctx.mk_return) ctx opt_v and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) (ctx : bs_ctx) : texpression = @@ -2132,8 +2089,7 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) else output in (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *) - mk_emeta (Tag "return_with_loop") - (mk_result_return_texpression ctx.meta output) + mk_emeta (Tag "return_with_loop") (mk_result_ok_texpression ctx.meta output) and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -2240,15 +2196,14 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (fun ty -> match ty with | None -> None - | Some (back_sg, ty) -> - (* We insert a name for the variable only if the function - can fail: if it can fail, it means the call returns a backward - function. Otherwise, it directly returns the value given - back by the backward function, which means we shouldn't - give it a name like "back..." (it doesn't make sense) *) + | Some (_back_sg, ty) -> + (* We insert a name for the variable only if the type + is an arrow type. If it is not, it means the backward + function is degenerate (it takes no inputs) so it is + not a function anymore but a value: it doesn't make + sense to use a name like "back...". *) let name = - if back_sg.effect_info.can_fail then Some back_fun_name - else None + if is_arrow_ty ty then Some back_fun_name else None in Some (name, ty)) back_tys) @@ -3102,6 +3057,49 @@ and translate_forward_end (ectx : C.eval_ctx) (ctx, backward_inputs_no_state @ [ var ]) else (ctx, backward_inputs_no_state) in + (* Update the functions mk_return and mk_panic *) + let effect_info = back_sg.effect_info in + let mk_return ctx v = + assert (v = None); + let output = + (* Group the variables in which we stored the values we need to give back. + See the explanations for the [SynthInput] case in [translate_end_abstraction] *) + let backward_outputs = Option.get ctx.backward_outputs in + let field_values = + List.map mk_texpression_from_var backward_outputs + in + mk_simpl_tuple_texpression ctx.meta field_values + in + let output = + if effect_info.stateful then + let state_rvalue = mk_state_texpression ctx.state_var in + mk_simpl_tuple_texpression ctx.meta [ state_rvalue; output ] + else output + in + (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *) + mk_result_ok_texpression ctx.meta output + in + let mk_panic = + (* TODO: we should use a [Fail] function *) + let mk_output output_ty = + if effect_info.stateful then + (* Create the [Fail] value *) + let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in + let ret_v = + mk_result_fail_texpression_with_error_id ctx.meta + error_failure_id ret_ty + in + ret_v + else + mk_result_fail_texpression_with_error_id ctx.meta + error_failure_id output_ty + in + let output = + mk_simpl_tuple_ty + (RegionGroupId.Map.find bid ctx.sg.back_sg).outputs + in + mk_output output + in { ctx with backward_inputs_no_state = @@ -3110,6 +3108,8 @@ and translate_forward_end (ectx : C.eval_ctx) backward_inputs_with_state = RegionGroupId.Map.add bid backward_inputs_with_state ctx.backward_inputs_with_state; + mk_return = Some mk_return; + mk_panic = Some mk_panic; } in @@ -3209,7 +3209,7 @@ and translate_forward_end (ectx : C.eval_ctx) let state_var = List.map mk_texpression_from_var state_var in let ret = mk_simpl_tuple_texpression ctx.meta (state_var @ [ ret ]) in - let ret = mk_result_return_texpression ctx.meta ret in + let ret = mk_result_ok_texpression ctx.meta ret in (* Introduce all the let-bindings *) @@ -3454,7 +3454,6 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = in (* Compute the backward outputs *) - let ctx = ref ctx in let rg_to_given_back_tys = RegionGroupId.Map.map (fun tys -> @@ -3462,13 +3461,12 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = List.map (fun ty -> cassert __FILE__ __LINE__ - (not (TypesUtils.ty_has_borrows !ctx.type_ctx.type_infos ty)) - !ctx.meta "The types shouldn't contain borrows"; - ctx_translate_fwd_ty !ctx ty) + (not (TypesUtils.ty_has_borrows ctx.type_ctx.type_infos ty)) + ctx.meta "The types shouldn't contain borrows"; + ctx_translate_fwd_ty ctx ty) tys) loop.rg_to_given_back_tys in - let ctx = !ctx in (* The output type of the loop function *) let fwd_effect_info = { ctx.sg.fwd_info.effect_info with is_rec = true } in @@ -3573,6 +3571,44 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = { types; const_generics; trait_refs } in + (* Update the helpers to translate the fail and return expressions *) + let mk_panic = + (* Note that we reuse the effect information from the parent function *) + let effect_info = ctx_get_effect_info ctx in + let back_tys = compute_back_tys ctx.sg None in + let back_tys = List.filter_map (fun x -> x) back_tys in + let tys = + if ctx.sg.fwd_info.ignore_output then back_tys + else ctx.sg.fwd_output :: back_tys + in + let output_ty = mk_simpl_tuple_ty tys in + if effect_info.stateful then + (* Create the [Fail] value *) + let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in + let ret_v = + mk_result_fail_texpression_with_error_id ctx.meta error_failure_id + ret_ty + in + ret_v + else + mk_result_fail_texpression_with_error_id ctx.meta error_failure_id + output_ty + in + let mk_return ctx v = + match v with + | None -> raise (Failure "Unexpected") + | Some output -> + let effect_info = ctx_get_effect_info ctx in + let output = + if effect_info.stateful then + let state_rvalue = mk_state_texpression ctx.state_var in + mk_simpl_tuple_texpression ctx.meta [ state_rvalue; output ] + else output + in + (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *) + mk_result_ok_texpression ctx.meta output + in + let loop_info = { loop_id; @@ -3588,7 +3624,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = } in let loops = LoopId.Map.add loop_id loop_info ctx.loops in - { ctx with loops } + { ctx with loops; mk_return = Some mk_return; mk_panic = Some mk_panic } in (* Update the context to translate the function end *) @@ -3755,6 +3791,50 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = let effect_info = get_fun_effect_info ctx (FunId (FRegular def_id)) None None in + let mk_return ctx v = + match v with + | None -> + raise + (Failure + "Unexpected: reached a return expression without value in a \ + function forward expression") + | Some output -> + let output = + if effect_info.stateful then + let state_rvalue = mk_state_texpression ctx.state_var in + mk_simpl_tuple_texpression ctx.meta [ state_rvalue; output ] + else output + in + (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *) + mk_result_ok_texpression ctx.meta output + in + let mk_panic = + (* TODO: we should use a [Fail] function *) + let mk_output output_ty = + if effect_info.stateful then + (* Create the [Fail] value *) + let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in + let ret_v = + mk_result_fail_texpression_with_error_id ctx.meta + error_failure_id ret_ty + in + ret_v + else + mk_result_fail_texpression_with_error_id ctx.meta error_failure_id + output_ty + in + let back_tys = compute_back_tys ctx.sg None in + let back_tys = List.filter_map (fun x -> x) back_tys in + let tys = + if ctx.sg.fwd_info.ignore_output then back_tys + else ctx.sg.fwd_output :: back_tys + in + let output = mk_simpl_tuple_ty tys in + mk_output output + in + let ctx = + { ctx with mk_return = Some mk_return; mk_panic = Some mk_panic } + in let body = translate_expression body ctx in (* Add a match over the fuel, if necessary *) let body = @@ -3840,7 +3920,16 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = def let translate_type_decls (ctx : Contexts.decls_ctx) : type_decl list = - List.map (translate_type_decl ctx) + List.filter_map + (fun a -> + try Some (translate_type_decl ctx a) + with CFailure (meta, _) -> + let env = PrintPure.decls_ctx_to_fmt_env ctx in + let name = PrintPure.name_to_string env a.name in + save_error __FILE__ __LINE__ meta + ("Could not translate type decl '" ^ name + ^ "' because of previous error"); + None) (TypeDeclId.Map.values ctx.type_ctx.type_decls) let translate_trait_decl (ctx : Contexts.decls_ctx) (trait_decl : A.trait_decl) |