summaryrefslogtreecommitdiff
path: root/compiler/SymbolicToPure.ml
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/SymbolicToPure.ml')
-rw-r--r--compiler/SymbolicToPure.ml30
1 files changed, 14 insertions, 16 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 37f621e4..7eb75584 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -2782,7 +2782,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
^ pure_ty_to_string ctx true_e.ty
^ "\n\nfalse_e.ty: "
^ pure_ty_to_string ctx false_e.ty));
- assert (ty = false_e.ty);
+ if !Config.fail_hard then assert (ty = false_e.ty);
{ e; ty }
| ExpandInt (int_ty, branches, otherwise) ->
let translate_branch ((v, branch_e) : V.scalar_value * S.expression) :
@@ -3005,7 +3005,7 @@ and translate_forward_end (ectx : C.eval_ctx)
fresh_vars back_sg.inputs_no_state ctx
in
let ctx, backward_inputs_with_state =
- if (ctx_get_effect_info ctx).stateful then
+ if back_sg.effect_info.stateful then
let ctx, var, _ = bs_ctx_fresh_state_var ctx in
(ctx, backward_inputs_no_state @ [ var ])
else (ctx, backward_inputs_no_state)
@@ -3061,18 +3061,7 @@ and translate_forward_end (ectx : C.eval_ctx)
if !Config.return_back_funs then
(* Compute the output of the forward function *)
let fwd_effect_info = ctx.sg.fwd_info.effect_info in
- let output_ty =
- let ty = ctx.sg.fwd_output in
- if fwd_effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ]
- else ty
- in
- let ctx, fwd_var = fresh_var None output_ty ctx in
- let ctx, state_var, state_pat =
- if fwd_effect_info.stateful then
- let ctx, var, pat = bs_ctx_fresh_state_var ctx in
- (ctx, [ var ], [ pat ])
- else (ctx, [], [])
- in
+ let ctx, pure_fwd_var = fresh_var None ctx.sg.fwd_output ctx in
let fwd_e = translate_one_end ctx None in
(* Introduce the backward functions. *)
@@ -3105,10 +3094,19 @@ and translate_forward_end (ectx : C.eval_ctx)
let vars =
let back_vars = List.filter_map (fun x -> x) back_vars in
if ctx.sg.fwd_info.ignore_output then back_vars
- else fwd_var :: back_vars
+ else pure_fwd_var :: back_vars
in
let vars = List.map mk_texpression_from_var vars in
let ret = mk_simpl_tuple_texpression vars in
+
+ (* Introduce a fresh input state variable for the forward expression *)
+ let _ctx, state_var, state_pat =
+ if fwd_effect_info.stateful then
+ let ctx, var, pat = bs_ctx_fresh_state_var ctx in
+ (ctx, [ var ], [ pat ])
+ else (ctx, [], [])
+ in
+
let state_var = List.map mk_texpression_from_var state_var in
let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in
let ret = mk_result_return_texpression ret in
@@ -3135,7 +3133,7 @@ and translate_forward_end (ectx : C.eval_ctx)
back_vars_els ret
in
(* Bind the expression for the forward output *)
- let fwd_var = mk_typed_pattern_from_var fwd_var None in
+ let fwd_var = mk_typed_pattern_from_var pure_fwd_var None in
let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in
mk_let fwd_effect_info.can_fail pat fwd_e e
else translate_one_end ctx ctx.bid