diff options
author | Son Ho | 2023-01-06 16:51:27 +0100 |
---|---|---|
committer | Son HO | 2023-02-03 11:21:46 +0100 |
commit | 46381652adbece2d7ccfd57fae8b5ee2365fb374 (patch) | |
tree | 80e1d1e2cf5728c76736e213c9bedab5191b8376 /compiler/PureMicroPasses.ml | |
parent | 2935706e2670a6aad0a01f4ffa29803574a687ed (diff) |
Fix some issues with the values given back by loop backward translations
Diffstat (limited to '')
-rw-r--r-- | compiler/PureMicroPasses.ml | 29 |
1 files changed, 26 insertions, 3 deletions
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index aed5b02d..25d760fe 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -440,13 +440,14 @@ let compute_pretty_names (def : fun_decl) : fun_decl = input_state; inputs; inputs_lvs; + back_output_tys; loop_body; } = loop in let ctx, fun_end = update_texpression fun_end ctx in let ctx, loop_body = update_texpression loop_body ctx in - let inputs = List.map (fun input -> update_var ctx input None) inputs in + let inputs = List.map (fun v -> update_var ctx v None) inputs in let inputs_lvs = List.map (update_typed_pattern ctx) inputs_lvs in let loop = { @@ -457,6 +458,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl = input_state; inputs; inputs_lvs; + back_output_tys; loop_body; } in @@ -1126,12 +1128,33 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = List.concat [ fuel; fwd_inputs; state; back_inputs ] in + let output, doutputs = + match loop.back_output_tys with + | None -> + (* Forward function: the return type is the same as the + parent function *) + (fun_sig.output, fun_sig.doutputs) + | Some doutputs -> + (* Backward function: custom return type *) + let output = mk_simpl_tuple_ty doutputs in + let output = + if loop_effect_info.stateful then + mk_simpl_tuple_ty [ mk_state_ty; output ] + else output + in + let output = + if loop_effect_info.can_fail then mk_result_ty output + else output + in + (output, doutputs) + in + let loop_sig = { type_params = fun_sig.type_params; inputs = inputs_tys; - output = fun_sig.output; - doutputs = fun_sig.doutputs; + output; + doutputs; info = loop_sig_info; } in |