summaryrefslogtreecommitdiff
path: root/compiler/PureMicroPasses.ml
diff options
context:
space:
mode:
authorSon Ho2023-01-06 16:51:27 +0100
committerSon HO2023-02-03 11:21:46 +0100
commit46381652adbece2d7ccfd57fae8b5ee2365fb374 (patch)
tree80e1d1e2cf5728c76736e213c9bedab5191b8376 /compiler/PureMicroPasses.ml
parent2935706e2670a6aad0a01f4ffa29803574a687ed (diff)
Fix some issues with the values given back by loop backward translations
Diffstat (limited to '')
-rw-r--r--compiler/PureMicroPasses.ml29
1 files changed, 26 insertions, 3 deletions
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index aed5b02d..25d760fe 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -440,13 +440,14 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
input_state;
inputs;
inputs_lvs;
+ back_output_tys;
loop_body;
} =
loop
in
let ctx, fun_end = update_texpression fun_end ctx in
let ctx, loop_body = update_texpression loop_body ctx in
- let inputs = List.map (fun input -> update_var ctx input None) inputs in
+ let inputs = List.map (fun v -> update_var ctx v None) inputs in
let inputs_lvs = List.map (update_typed_pattern ctx) inputs_lvs in
let loop =
{
@@ -457,6 +458,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
input_state;
inputs;
inputs_lvs;
+ back_output_tys;
loop_body;
}
in
@@ -1126,12 +1128,33 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
List.concat [ fuel; fwd_inputs; state; back_inputs ]
in
+ let output, doutputs =
+ match loop.back_output_tys with
+ | None ->
+ (* Forward function: the return type is the same as the
+ parent function *)
+ (fun_sig.output, fun_sig.doutputs)
+ | Some doutputs ->
+ (* Backward function: custom return type *)
+ let output = mk_simpl_tuple_ty doutputs in
+ let output =
+ if loop_effect_info.stateful then
+ mk_simpl_tuple_ty [ mk_state_ty; output ]
+ else output
+ in
+ let output =
+ if loop_effect_info.can_fail then mk_result_ty output
+ else output
+ in
+ (output, doutputs)
+ in
+
let loop_sig =
{
type_params = fun_sig.type_params;
inputs = inputs_tys;
- output = fun_sig.output;
- doutputs = fun_sig.doutputs;
+ output;
+ doutputs;
info = loop_sig_info;
}
in