summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorSon Ho2023-12-21 22:45:47 +0100
committerSon Ho2023-12-21 22:45:47 +0100
commiteae740d644f5ccd1ad2a7e853a9cdf303c8df61e (patch)
treec059ea5c7cd3b657cf853f9816d7c038b33151fd /compiler
parent266db04e97778911c93cfd1aac251de04bb25f53 (diff)
Fix issues when extracting stateful functions
Diffstat (limited to 'compiler')
-rw-r--r--compiler/PrintPure.ml51
-rw-r--r--compiler/SymbolicToPure.ml30
2 files changed, 40 insertions, 41 deletions
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index 1ce146a4..315dd512 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -611,35 +611,36 @@ and app_to_string (env : fmt_env) (inside : bool) (indent : string)
* expression *)
let app, generics =
match app.e with
- | Qualif qualif ->
+ | Qualif qualif -> (
(* Qualifier case *)
- (* Convert the qualifier identifier *)
- let qualif_s =
- match qualif.id with
- | FunOrOp fun_id -> fun_or_op_id_to_string env fun_id
- | Global global_id -> global_decl_id_to_string env global_id
- | AdtCons adt_cons_id ->
- let variant_s =
- adt_variant_to_string env adt_cons_id.adt_id
- adt_cons_id.variant_id
- in
- ConstStrings.constructor_prefix ^ variant_s
- | Proj { adt_id; field_id } ->
- let adt_s = adt_variant_to_string env adt_id None in
- let field_s = adt_field_to_string env adt_id field_id in
- (* Adopting an F*-like syntax *)
- ConstStrings.constructor_prefix ^ adt_s ^ "?." ^ field_s
- | TraitConst (trait_ref, generics, const_name) ->
- let trait_ref = trait_ref_to_string env true trait_ref in
- let generics_s = generic_args_to_string env generics in
+ match qualif.id with
+ | FunOrOp fun_id ->
+ let generics = generic_args_to_strings env true qualif.generics in
+ let qualif_s = fun_or_op_id_to_string env fun_id in
+ (qualif_s, generics)
+ | Global global_id ->
+ let generics = generic_args_to_strings env true qualif.generics in
+ (global_decl_id_to_string env global_id, generics)
+ | AdtCons adt_cons_id ->
+ let variant_s =
+ adt_variant_to_string env adt_cons_id.adt_id
+ adt_cons_id.variant_id
+ in
+ (ConstStrings.constructor_prefix ^ variant_s, [])
+ | Proj { adt_id; field_id } ->
+ let adt_s = adt_variant_to_string env adt_id None in
+ let field_s = adt_field_to_string env adt_id field_id in
+ (* Adopting an F*-like syntax *)
+ (ConstStrings.constructor_prefix ^ adt_s ^ "?." ^ field_s, [])
+ | TraitConst (trait_ref, generics, const_name) ->
+ let trait_ref = trait_ref_to_string env true trait_ref in
+ let generics_s = generic_args_to_string env generics in
+ let qualif =
if generics <> empty_generic_args then
"(" ^ trait_ref ^ generics_s ^ ")." ^ const_name
else trait_ref ^ "." ^ const_name
- in
- (* Convert the type instantiation *)
- let generics = generic_args_to_strings env true qualif.generics in
- (* *)
- (qualif_s, generics)
+ in
+ (qualif, []))
| _ ->
(* "Regular" expression case *)
let inside = args <> [] || (args = [] && inside) in
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 37f621e4..7eb75584 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -2782,7 +2782,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
^ pure_ty_to_string ctx true_e.ty
^ "\n\nfalse_e.ty: "
^ pure_ty_to_string ctx false_e.ty));
- assert (ty = false_e.ty);
+ if !Config.fail_hard then assert (ty = false_e.ty);
{ e; ty }
| ExpandInt (int_ty, branches, otherwise) ->
let translate_branch ((v, branch_e) : V.scalar_value * S.expression) :
@@ -3005,7 +3005,7 @@ and translate_forward_end (ectx : C.eval_ctx)
fresh_vars back_sg.inputs_no_state ctx
in
let ctx, backward_inputs_with_state =
- if (ctx_get_effect_info ctx).stateful then
+ if back_sg.effect_info.stateful then
let ctx, var, _ = bs_ctx_fresh_state_var ctx in
(ctx, backward_inputs_no_state @ [ var ])
else (ctx, backward_inputs_no_state)
@@ -3061,18 +3061,7 @@ and translate_forward_end (ectx : C.eval_ctx)
if !Config.return_back_funs then
(* Compute the output of the forward function *)
let fwd_effect_info = ctx.sg.fwd_info.effect_info in
- let output_ty =
- let ty = ctx.sg.fwd_output in
- if fwd_effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ]
- else ty
- in
- let ctx, fwd_var = fresh_var None output_ty ctx in
- let ctx, state_var, state_pat =
- if fwd_effect_info.stateful then
- let ctx, var, pat = bs_ctx_fresh_state_var ctx in
- (ctx, [ var ], [ pat ])
- else (ctx, [], [])
- in
+ let ctx, pure_fwd_var = fresh_var None ctx.sg.fwd_output ctx in
let fwd_e = translate_one_end ctx None in
(* Introduce the backward functions. *)
@@ -3105,10 +3094,19 @@ and translate_forward_end (ectx : C.eval_ctx)
let vars =
let back_vars = List.filter_map (fun x -> x) back_vars in
if ctx.sg.fwd_info.ignore_output then back_vars
- else fwd_var :: back_vars
+ else pure_fwd_var :: back_vars
in
let vars = List.map mk_texpression_from_var vars in
let ret = mk_simpl_tuple_texpression vars in
+
+ (* Introduce a fresh input state variable for the forward expression *)
+ let _ctx, state_var, state_pat =
+ if fwd_effect_info.stateful then
+ let ctx, var, pat = bs_ctx_fresh_state_var ctx in
+ (ctx, [ var ], [ pat ])
+ else (ctx, [], [])
+ in
+
let state_var = List.map mk_texpression_from_var state_var in
let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in
let ret = mk_result_return_texpression ret in
@@ -3135,7 +3133,7 @@ and translate_forward_end (ectx : C.eval_ctx)
back_vars_els ret
in
(* Bind the expression for the forward output *)
- let fwd_var = mk_typed_pattern_from_var fwd_var None in
+ let fwd_var = mk_typed_pattern_from_var pure_fwd_var None in
let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in
mk_let fwd_effect_info.can_fail pat fwd_e e
else translate_one_end ctx ctx.bid