diff options
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/SymbolicToPure.ml | 240 | ||||
-rw-r--r-- | compiler/Translate.ml | 2 |
2 files changed, 152 insertions, 90 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 38ee5df1..482ebf3a 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -268,6 +268,22 @@ type bs_ctx = { Note that when a function contains a loop, we group the function symbolic AST and the loop symbolic AST in a single function. *) + mk_return : (bs_ctx -> texpression option -> texpression) option; + (** Small helper: translate a [return] expression, given a value to "return". + The translation of [return] depends on the context, and in particular depends on + whether we are inside a subexpression like a loop or not. + + Note that the function consumes an optional expression, which is: + - [Some] for a forward computation + - [None] for a backward computation + + We initialize this at [None]. + *) + mk_panic : texpression option; + (** Small helper: translate a [fail] expression. + + We initialize this at [None]. + *) } [@@deriving show] @@ -2008,54 +2024,7 @@ let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression = translate_forward_end ectx loop_input_values e back_e ctx | Loop loop -> translate_loop loop ctx -and translate_panic (ctx : bs_ctx) : texpression = - (* Here we use the function return type - note that it is ok because - * we don't match on panics which happen inside the function body - - * but it won't be true anymore once we translate individual blocks *) - (* If we use a state monad, we need to add a lambda for the state variable *) - (* Note that only forward functions return a state *) - let effect_info = ctx_get_effect_info ctx in - (* TODO: we should use a [Fail] function *) - let mk_output output_ty = - if effect_info.stateful then - (* Create the [Fail] value *) - let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in - let ret_v = - mk_result_fail_texpression_with_error_id ctx.meta error_failure_id - ret_ty - in - ret_v - else - mk_result_fail_texpression_with_error_id ctx.meta error_failure_id - output_ty - in - if ctx.inside_loop && Option.is_some ctx.bid then - (* We are synthesizing the backward function of a loop body *) - let bid = Option.get ctx.bid in - let loop_id = Option.get ctx.loop_id in - let loop = LoopId.Map.find loop_id ctx.loops in - let tys = RegionGroupId.Map.find bid loop.back_outputs in - let output = mk_simpl_tuple_ty tys in - mk_output output - else - (* Regular function, or forward function (the forward translation for - a loop has the same return type as the parent function) - *) - match ctx.bid with - | None -> - let back_tys = compute_back_tys ctx.sg None in - let back_tys = List.filter_map (fun x -> x) back_tys in - let tys = - if ctx.sg.fwd_info.ignore_output then back_tys - else ctx.sg.fwd_output :: back_tys - in - let output = mk_simpl_tuple_ty tys in - mk_output output - | Some bid -> - let output = - mk_simpl_tuple_ty (RegionGroupId.Map.find bid ctx.sg.back_sg).outputs - in - mk_output output +and translate_panic (ctx : bs_ctx) : texpression = Option.get ctx.mk_panic (** [opt_v]: the value to return, in case we translate a forward body. @@ -2067,42 +2036,8 @@ and translate_panic (ctx : bs_ctx) : texpression = *) and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = - (* There are two cases: - - either we reach the return of a forward function or a forward loop body, - in which case the optional value should be [Some] (it is the returned value) - - or we are translating a backward function, in which case it should be [None] - *) - (* Compute the values that we should return *without the state and the result - wrapper* *) - let output = - match ctx.bid with - | None -> - (* Forward function *) - let v = Option.get opt_v in - typed_value_to_texpression ctx ectx v - | Some _ -> - (* Backward function *) - (* Sanity check *) - sanity_check __FILE__ __LINE__ (opt_v = None) ctx.meta; - (* Group the variables in which we stored the values we need to give back. - See the explanations for the [SynthInput] case in [translate_end_abstraction] *) - let backward_outputs = Option.get ctx.backward_outputs in - let field_values = List.map mk_texpression_from_var backward_outputs in - mk_simpl_tuple_texpression ctx.meta field_values - in - (* We may need to return a state - * - error-monad: Return x - * - state-error: Return (state, x) - * *) - let effect_info = ctx_get_effect_info ctx in - let output = - if effect_info.stateful then - let state_rvalue = mk_state_texpression ctx.state_var in - mk_simpl_tuple_texpression ctx.meta [ state_rvalue; output ] - else output - in - (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *) - mk_result_ok_texpression ctx.meta output + let opt_v = Option.map (typed_value_to_texpression ctx ectx) opt_v in + (Option.get ctx.mk_return) ctx opt_v and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) (ctx : bs_ctx) : texpression = @@ -3118,6 +3053,49 @@ and translate_forward_end (ectx : C.eval_ctx) (ctx, backward_inputs_no_state @ [ var ]) else (ctx, backward_inputs_no_state) in + (* Update the functions mk_return and mk_panic *) + let effect_info = back_sg.effect_info in + let mk_return ctx v = + assert (v = None); + let output = + (* Group the variables in which we stored the values we need to give back. + See the explanations for the [SynthInput] case in [translate_end_abstraction] *) + let backward_outputs = Option.get ctx.backward_outputs in + let field_values = + List.map mk_texpression_from_var backward_outputs + in + mk_simpl_tuple_texpression ctx.meta field_values + in + let output = + if effect_info.stateful then + let state_rvalue = mk_state_texpression ctx.state_var in + mk_simpl_tuple_texpression ctx.meta [ state_rvalue; output ] + else output + in + (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *) + mk_result_ok_texpression ctx.meta output + in + let mk_panic = + (* TODO: we should use a [Fail] function *) + let mk_output output_ty = + if effect_info.stateful then + (* Create the [Fail] value *) + let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in + let ret_v = + mk_result_fail_texpression_with_error_id ctx.meta + error_failure_id ret_ty + in + ret_v + else + mk_result_fail_texpression_with_error_id ctx.meta + error_failure_id output_ty + in + let output = + mk_simpl_tuple_ty + (RegionGroupId.Map.find bid ctx.sg.back_sg).outputs + in + mk_output output + in { ctx with backward_inputs_no_state = @@ -3126,6 +3104,8 @@ and translate_forward_end (ectx : C.eval_ctx) backward_inputs_with_state = RegionGroupId.Map.add bid backward_inputs_with_state ctx.backward_inputs_with_state; + mk_return = Some mk_return; + mk_panic = Some mk_panic; } in @@ -3470,7 +3450,6 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = in (* Compute the backward outputs *) - let ctx = ref ctx in let rg_to_given_back_tys = RegionGroupId.Map.map (fun tys -> @@ -3478,13 +3457,12 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = List.map (fun ty -> cassert __FILE__ __LINE__ - (not (TypesUtils.ty_has_borrows !ctx.type_ctx.type_infos ty)) - !ctx.meta "The types shouldn't contain borrows"; - ctx_translate_fwd_ty !ctx ty) + (not (TypesUtils.ty_has_borrows ctx.type_ctx.type_infos ty)) + ctx.meta "The types shouldn't contain borrows"; + ctx_translate_fwd_ty ctx ty) tys) loop.rg_to_given_back_tys in - let ctx = !ctx in (* The output type of the loop function *) let fwd_effect_info = { ctx.sg.fwd_info.effect_info with is_rec = true } in @@ -3589,6 +3567,44 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = { types; const_generics; trait_refs } in + (* Update the helpers to translate the fail and return expressions *) + let mk_panic = + (* Note that we reuse the effect information from the parent function *) + let effect_info = ctx_get_effect_info ctx in + let back_tys = compute_back_tys ctx.sg None in + let back_tys = List.filter_map (fun x -> x) back_tys in + let tys = + if ctx.sg.fwd_info.ignore_output then back_tys + else ctx.sg.fwd_output :: back_tys + in + let output_ty = mk_simpl_tuple_ty tys in + if effect_info.stateful then + (* Create the [Fail] value *) + let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in + let ret_v = + mk_result_fail_texpression_with_error_id ctx.meta error_failure_id + ret_ty + in + ret_v + else + mk_result_fail_texpression_with_error_id ctx.meta error_failure_id + output_ty + in + let mk_return ctx v = + match v with + | None -> raise (Failure "Unexpected") + | Some output -> + let effect_info = ctx_get_effect_info ctx in + let output = + if effect_info.stateful then + let state_rvalue = mk_state_texpression ctx.state_var in + mk_simpl_tuple_texpression ctx.meta [ state_rvalue; output ] + else output + in + (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *) + mk_result_ok_texpression ctx.meta output + in + let loop_info = { loop_id; @@ -3604,7 +3620,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = } in let loops = LoopId.Map.add loop_id loop_info ctx.loops in - { ctx with loops } + { ctx with loops; mk_return = Some mk_return; mk_panic = Some mk_panic } in (* Update the context to translate the function end *) @@ -3771,6 +3787,50 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = let effect_info = get_fun_effect_info ctx (FunId (FRegular def_id)) None None in + let mk_return ctx v = + match v with + | None -> + raise + (Failure + "Unexpected: reached a return expression without value in a \ + function forward expression") + | Some output -> + let output = + if effect_info.stateful then + let state_rvalue = mk_state_texpression ctx.state_var in + mk_simpl_tuple_texpression ctx.meta [ state_rvalue; output ] + else output + in + (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *) + mk_result_ok_texpression ctx.meta output + in + let mk_panic = + (* TODO: we should use a [Fail] function *) + let mk_output output_ty = + if effect_info.stateful then + (* Create the [Fail] value *) + let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in + let ret_v = + mk_result_fail_texpression_with_error_id ctx.meta + error_failure_id ret_ty + in + ret_v + else + mk_result_fail_texpression_with_error_id ctx.meta error_failure_id + output_ty + in + let back_tys = compute_back_tys ctx.sg None in + let back_tys = List.filter_map (fun x -> x) back_tys in + let tys = + if ctx.sg.fwd_info.ignore_output then back_tys + else ctx.sg.fwd_output :: back_tys + in + let output = mk_simpl_tuple_ty tys in + mk_output output + in + let ctx = + { ctx with mk_return = Some mk_return; mk_panic = Some mk_panic } + in let body = translate_expression body ctx in (* Add a match over the fuel, if necessary *) let body = diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 348183c5..22288fe2 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -158,6 +158,8 @@ let translate_function_to_pure (trans_ctx : trans_ctx) inside_loop = false; loop_ids_map; loops = Pure.LoopId.Map.empty; + mk_return = None; + mk_panic = None; } in |