summaryrefslogtreecommitdiff
path: root/compiler/SymbolicToPure.ml
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/SymbolicToPure.ml')
-rw-r--r--compiler/SymbolicToPure.ml295
1 files changed, 192 insertions, 103 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 0c30f44c..15b52237 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -268,6 +268,22 @@ type bs_ctx = {
Note that when a function contains a loop, we group the function symbolic AST and the loop symbolic
AST in a single function.
*)
+ mk_return : (bs_ctx -> texpression option -> texpression) option;
+ (** Small helper: translate a [return] expression, given a value to "return".
+ The translation of [return] depends on the context, and in particular depends on
+ whether we are inside a subexpression like a loop or not.
+
+ Note that the function consumes an optional expression, which is:
+ - [Some] for a forward computation
+ - [None] for a backward computation
+
+ We initialize this at [None].
+ *)
+ mk_panic : texpression option;
+ (** Small helper: translate a [fail] expression.
+
+ We initialize this at [None].
+ *)
}
[@@deriving show]
@@ -1499,9 +1515,27 @@ let fresh_back_vars_for_current_fun (ctx : bs_ctx)
let back_vars =
List.map
(fun (name, ty) ->
- match ty with None -> None | Some ty -> Some (name, ty))
+ match ty with
+ | None -> None
+ | Some ty ->
+ (* If the type is not an arrow type, don't use the name "back"
+ (it is a backward function with no inputs, that is to say a
+ value) *)
+ let name = if is_arrow_ty ty then name else None in
+ Some (name, ty))
back_vars
in
+ (* If there is one backward function or less, we use the name "back"
+ (there is no point in using the lifetime name, and it makes the
+ code generation more stable) *)
+ let num_back_vars = List.length (List.filter_map (fun x -> x) back_vars) in
+ let back_vars =
+ if num_back_vars = 1 then
+ List.map
+ (Option.map (fun (name, ty) -> (Option.map (fun _ -> "back") name, ty)))
+ back_vars
+ else back_vars
+ in
fresh_opt_vars back_vars ctx
(** IMPORTANT: do not use this one directly, but rather {!symbolic_value_to_texpression} *)
@@ -1963,6 +1997,9 @@ let eval_ctx_to_symbolic_assignments_info (ctx : bs_ctx)
(* Return the computed information *)
!info
+let translate_error (meta : Meta.meta option) (msg : string) : texpression =
+ { e = EError (meta, msg); ty = Error }
+
let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =
match e with
| S.Return (ectx, opt_v) ->
@@ -1989,55 +2026,9 @@ let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =
*)
translate_forward_end ectx loop_input_values e back_e ctx
| Loop loop -> translate_loop loop ctx
+ | Error (meta, msg) -> translate_error meta msg
-and translate_panic (ctx : bs_ctx) : texpression =
- (* Here we use the function return type - note that it is ok because
- * we don't match on panics which happen inside the function body -
- * but it won't be true anymore once we translate individual blocks *)
- (* If we use a state monad, we need to add a lambda for the state variable *)
- (* Note that only forward functions return a state *)
- let effect_info = ctx_get_effect_info ctx in
- (* TODO: we should use a [Fail] function *)
- let mk_output output_ty =
- if effect_info.stateful then
- (* Create the [Fail] value *)
- let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in
- let ret_v =
- mk_result_fail_texpression_with_error_id ctx.meta error_failure_id
- ret_ty
- in
- ret_v
- else
- mk_result_fail_texpression_with_error_id ctx.meta error_failure_id
- output_ty
- in
- if ctx.inside_loop && Option.is_some ctx.bid then
- (* We are synthesizing the backward function of a loop body *)
- let bid = Option.get ctx.bid in
- let loop_id = Option.get ctx.loop_id in
- let loop = LoopId.Map.find loop_id ctx.loops in
- let tys = RegionGroupId.Map.find bid loop.back_outputs in
- let output = mk_simpl_tuple_ty tys in
- mk_output output
- else
- (* Regular function, or forward function (the forward translation for
- a loop has the same return type as the parent function)
- *)
- match ctx.bid with
- | None ->
- let back_tys = compute_back_tys ctx.sg None in
- let back_tys = List.filter_map (fun x -> x) back_tys in
- let tys =
- if ctx.sg.fwd_info.ignore_output then back_tys
- else ctx.sg.fwd_output :: back_tys
- in
- let output = mk_simpl_tuple_ty tys in
- mk_output output
- | Some bid ->
- let output =
- mk_simpl_tuple_ty (RegionGroupId.Map.find bid ctx.sg.back_sg).outputs
- in
- mk_output output
+and translate_panic (ctx : bs_ctx) : texpression = Option.get ctx.mk_panic
(** [opt_v]: the value to return, in case we translate a forward body.
@@ -2049,42 +2040,8 @@ and translate_panic (ctx : bs_ctx) : texpression =
*)
and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option)
(ctx : bs_ctx) : texpression =
- (* There are two cases:
- - either we reach the return of a forward function or a forward loop body,
- in which case the optional value should be [Some] (it is the returned value)
- - or we are translating a backward function, in which case it should be [None]
- *)
- (* Compute the values that we should return *without the state and the result
- wrapper* *)
- let output =
- match ctx.bid with
- | None ->
- (* Forward function *)
- let v = Option.get opt_v in
- typed_value_to_texpression ctx ectx v
- | Some _ ->
- (* Backward function *)
- (* Sanity check *)
- sanity_check __FILE__ __LINE__ (opt_v = None) ctx.meta;
- (* Group the variables in which we stored the values we need to give back.
- See the explanations for the [SynthInput] case in [translate_end_abstraction] *)
- let backward_outputs = Option.get ctx.backward_outputs in
- let field_values = List.map mk_texpression_from_var backward_outputs in
- mk_simpl_tuple_texpression ctx.meta field_values
- in
- (* We may need to return a state
- * - error-monad: Return x
- * - state-error: Return (state, x)
- * *)
- let effect_info = ctx_get_effect_info ctx in
- let output =
- if effect_info.stateful then
- let state_rvalue = mk_state_texpression ctx.state_var in
- mk_simpl_tuple_texpression ctx.meta [ state_rvalue; output ]
- else output
- in
- (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *)
- mk_result_return_texpression ctx.meta output
+ let opt_v = Option.map (typed_value_to_texpression ctx ectx) opt_v in
+ (Option.get ctx.mk_return) ctx opt_v
and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool)
(ctx : bs_ctx) : texpression =
@@ -2132,8 +2089,7 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool)
else output
in
(* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *)
- mk_emeta (Tag "return_with_loop")
- (mk_result_return_texpression ctx.meta output)
+ mk_emeta (Tag "return_with_loop") (mk_result_ok_texpression ctx.meta output)
and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
texpression =
@@ -2240,15 +2196,14 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
(fun ty ->
match ty with
| None -> None
- | Some (back_sg, ty) ->
- (* We insert a name for the variable only if the function
- can fail: if it can fail, it means the call returns a backward
- function. Otherwise, it directly returns the value given
- back by the backward function, which means we shouldn't
- give it a name like "back..." (it doesn't make sense) *)
+ | Some (_back_sg, ty) ->
+ (* We insert a name for the variable only if the type
+ is an arrow type. If it is not, it means the backward
+ function is degenerate (it takes no inputs) so it is
+ not a function anymore but a value: it doesn't make
+ sense to use a name like "back...". *)
let name =
- if back_sg.effect_info.can_fail then Some back_fun_name
- else None
+ if is_arrow_ty ty then Some back_fun_name else None
in
Some (name, ty))
back_tys)
@@ -3102,6 +3057,49 @@ and translate_forward_end (ectx : C.eval_ctx)
(ctx, backward_inputs_no_state @ [ var ])
else (ctx, backward_inputs_no_state)
in
+ (* Update the functions mk_return and mk_panic *)
+ let effect_info = back_sg.effect_info in
+ let mk_return ctx v =
+ assert (v = None);
+ let output =
+ (* Group the variables in which we stored the values we need to give back.
+ See the explanations for the [SynthInput] case in [translate_end_abstraction] *)
+ let backward_outputs = Option.get ctx.backward_outputs in
+ let field_values =
+ List.map mk_texpression_from_var backward_outputs
+ in
+ mk_simpl_tuple_texpression ctx.meta field_values
+ in
+ let output =
+ if effect_info.stateful then
+ let state_rvalue = mk_state_texpression ctx.state_var in
+ mk_simpl_tuple_texpression ctx.meta [ state_rvalue; output ]
+ else output
+ in
+ (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *)
+ mk_result_ok_texpression ctx.meta output
+ in
+ let mk_panic =
+ (* TODO: we should use a [Fail] function *)
+ let mk_output output_ty =
+ if effect_info.stateful then
+ (* Create the [Fail] value *)
+ let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in
+ let ret_v =
+ mk_result_fail_texpression_with_error_id ctx.meta
+ error_failure_id ret_ty
+ in
+ ret_v
+ else
+ mk_result_fail_texpression_with_error_id ctx.meta
+ error_failure_id output_ty
+ in
+ let output =
+ mk_simpl_tuple_ty
+ (RegionGroupId.Map.find bid ctx.sg.back_sg).outputs
+ in
+ mk_output output
+ in
{
ctx with
backward_inputs_no_state =
@@ -3110,6 +3108,8 @@ and translate_forward_end (ectx : C.eval_ctx)
backward_inputs_with_state =
RegionGroupId.Map.add bid backward_inputs_with_state
ctx.backward_inputs_with_state;
+ mk_return = Some mk_return;
+ mk_panic = Some mk_panic;
}
in
@@ -3209,7 +3209,7 @@ and translate_forward_end (ectx : C.eval_ctx)
let state_var = List.map mk_texpression_from_var state_var in
let ret = mk_simpl_tuple_texpression ctx.meta (state_var @ [ ret ]) in
- let ret = mk_result_return_texpression ctx.meta ret in
+ let ret = mk_result_ok_texpression ctx.meta ret in
(* Introduce all the let-bindings *)
@@ -3454,7 +3454,6 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
in
(* Compute the backward outputs *)
- let ctx = ref ctx in
let rg_to_given_back_tys =
RegionGroupId.Map.map
(fun tys ->
@@ -3462,13 +3461,12 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
List.map
(fun ty ->
cassert __FILE__ __LINE__
- (not (TypesUtils.ty_has_borrows !ctx.type_ctx.type_infos ty))
- !ctx.meta "The types shouldn't contain borrows";
- ctx_translate_fwd_ty !ctx ty)
+ (not (TypesUtils.ty_has_borrows ctx.type_ctx.type_infos ty))
+ ctx.meta "The types shouldn't contain borrows";
+ ctx_translate_fwd_ty ctx ty)
tys)
loop.rg_to_given_back_tys
in
- let ctx = !ctx in
(* The output type of the loop function *)
let fwd_effect_info = { ctx.sg.fwd_info.effect_info with is_rec = true } in
@@ -3573,6 +3571,44 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
{ types; const_generics; trait_refs }
in
+ (* Update the helpers to translate the fail and return expressions *)
+ let mk_panic =
+ (* Note that we reuse the effect information from the parent function *)
+ let effect_info = ctx_get_effect_info ctx in
+ let back_tys = compute_back_tys ctx.sg None in
+ let back_tys = List.filter_map (fun x -> x) back_tys in
+ let tys =
+ if ctx.sg.fwd_info.ignore_output then back_tys
+ else ctx.sg.fwd_output :: back_tys
+ in
+ let output_ty = mk_simpl_tuple_ty tys in
+ if effect_info.stateful then
+ (* Create the [Fail] value *)
+ let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in
+ let ret_v =
+ mk_result_fail_texpression_with_error_id ctx.meta error_failure_id
+ ret_ty
+ in
+ ret_v
+ else
+ mk_result_fail_texpression_with_error_id ctx.meta error_failure_id
+ output_ty
+ in
+ let mk_return ctx v =
+ match v with
+ | None -> raise (Failure "Unexpected")
+ | Some output ->
+ let effect_info = ctx_get_effect_info ctx in
+ let output =
+ if effect_info.stateful then
+ let state_rvalue = mk_state_texpression ctx.state_var in
+ mk_simpl_tuple_texpression ctx.meta [ state_rvalue; output ]
+ else output
+ in
+ (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *)
+ mk_result_ok_texpression ctx.meta output
+ in
+
let loop_info =
{
loop_id;
@@ -3588,7 +3624,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
}
in
let loops = LoopId.Map.add loop_id loop_info ctx.loops in
- { ctx with loops }
+ { ctx with loops; mk_return = Some mk_return; mk_panic = Some mk_panic }
in
(* Update the context to translate the function end *)
@@ -3755,6 +3791,50 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
let effect_info =
get_fun_effect_info ctx (FunId (FRegular def_id)) None None
in
+ let mk_return ctx v =
+ match v with
+ | None ->
+ raise
+ (Failure
+ "Unexpected: reached a return expression without value in a \
+ function forward expression")
+ | Some output ->
+ let output =
+ if effect_info.stateful then
+ let state_rvalue = mk_state_texpression ctx.state_var in
+ mk_simpl_tuple_texpression ctx.meta [ state_rvalue; output ]
+ else output
+ in
+ (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *)
+ mk_result_ok_texpression ctx.meta output
+ in
+ let mk_panic =
+ (* TODO: we should use a [Fail] function *)
+ let mk_output output_ty =
+ if effect_info.stateful then
+ (* Create the [Fail] value *)
+ let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in
+ let ret_v =
+ mk_result_fail_texpression_with_error_id ctx.meta
+ error_failure_id ret_ty
+ in
+ ret_v
+ else
+ mk_result_fail_texpression_with_error_id ctx.meta error_failure_id
+ output_ty
+ in
+ let back_tys = compute_back_tys ctx.sg None in
+ let back_tys = List.filter_map (fun x -> x) back_tys in
+ let tys =
+ if ctx.sg.fwd_info.ignore_output then back_tys
+ else ctx.sg.fwd_output :: back_tys
+ in
+ let output = mk_simpl_tuple_ty tys in
+ mk_output output
+ in
+ let ctx =
+ { ctx with mk_return = Some mk_return; mk_panic = Some mk_panic }
+ in
let body = translate_expression body ctx in
(* Add a match over the fuel, if necessary *)
let body =
@@ -3840,7 +3920,16 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
def
let translate_type_decls (ctx : Contexts.decls_ctx) : type_decl list =
- List.map (translate_type_decl ctx)
+ List.filter_map
+ (fun a ->
+ try Some (translate_type_decl ctx a)
+ with CFailure (meta, _) ->
+ let env = PrintPure.decls_ctx_to_fmt_env ctx in
+ let name = PrintPure.name_to_string env a.name in
+ save_error __FILE__ __LINE__ meta
+ ("Could not translate type decl '" ^ name
+ ^ "' because of previous error");
+ None)
(TypeDeclId.Map.values ctx.type_ctx.type_decls)
let translate_trait_decl (ctx : Contexts.decls_ctx) (trait_decl : A.trait_decl)