summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorSon Ho2023-12-22 21:03:17 +0100
committerSon Ho2023-12-22 21:03:17 +0100
commit70d506d148e5ae1a3e4115034161f449aff666ed (patch)
tree43faecd146f5d792d398512097b3afdb503ae11c /compiler
parentb230ddacd44a1ca1804940bf89253bde8de7ffe1 (diff)
Fix the output type of the loops backward functions
Diffstat (limited to 'compiler')
-rw-r--r--compiler/PrintPure.ml11
-rw-r--r--compiler/Pure.ml4
-rw-r--r--compiler/PureMicroPasses.ml25
-rw-r--r--compiler/PureTypeCheck.ml6
-rw-r--r--compiler/SymbolicToPure.ml65
5 files changed, 65 insertions, 46 deletions
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index 315dd512..66475d02 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -711,21 +711,14 @@ and loop_to_string (env : fmt_env) (indent : string) (indent_incr : string)
^ String.concat "; " (List.map (var_to_string env) loop.inputs)
^ "]"
in
- let back_output_tys =
- let tys =
- match loop.back_output_tys with
- | None -> ""
- | Some tys -> String.concat "; " (List.map (ty_to_string env false) tys)
- in
- "back_output_tys: [" ^ tys ^ "]"
- in
+ let output_ty = "output_ty: " ^ ty_to_string env false loop.output_ty in
let fun_end =
texpression_to_string env false indent2 indent_incr loop.fun_end
in
let loop_body =
texpression_to_string env false indent2 indent_incr loop.loop_body
in
- "loop {\n" ^ indent1 ^ loop_inputs ^ "\n" ^ indent1 ^ back_output_tys ^ "\n"
+ "loop {\n" ^ indent1 ^ loop_inputs ^ "\n" ^ indent1 ^ output_ty ^ "\n"
^ indent1 ^ "fun_end: {\n" ^ indent2 ^ fun_end ^ "\n" ^ indent1 ^ "}\n"
^ indent1 ^ "loop_body: {\n" ^ indent2 ^ loop_body ^ "\n" ^ indent1 ^ "}\n"
^ indent ^ "}"
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index 71531688..a879ba37 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -754,9 +754,7 @@ and loop = {
inputs : var list;
inputs_lvs : typed_pattern list;
(** The inputs seen as patterns. See {!fun_body}. *)
- back_output_tys : ty list option;
- (** The types of the given back values, if we ar esynthesizing a backward
- function *)
+ output_ty : ty; (** The output type of the loop *)
loop_body : texpression;
}
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 67495ab5..e7e9d5e1 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -459,7 +459,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
input_state;
inputs;
inputs_lvs;
- back_output_tys;
+ output_ty;
loop_body;
} =
loop
@@ -478,7 +478,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
input_state;
inputs;
inputs_lvs;
- back_output_tys;
+ output_ty;
loop_body;
}
in
@@ -1498,26 +1498,7 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) :
List.concat [ fuel; fwd_inputs; fwd_state; back_inputs ]
in
- let output =
- match loop.back_output_tys with
- | None ->
- (* Forward function: the return type is the same as the
- parent function *)
- fun_sig.output
- | Some doutputs ->
- (* Backward function: custom return type *)
- let output = mk_simpl_tuple_ty doutputs in
- let output =
- if loop_fwd_effect_info.stateful then
- mk_simpl_tuple_ty [ mk_state_ty; output ]
- else output
- in
- let output =
- if loop_fwd_effect_info.can_fail then mk_result_ty output
- else output
- in
- output
- in
+ let output = loop.output_ty in
let loop_sig =
{
diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml
index d60d6a05..a989fd3b 100644
--- a/compiler/PureTypeCheck.ml
+++ b/compiler/PureTypeCheck.ml
@@ -188,12 +188,6 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
List.iter check_branch branches)
| Loop loop ->
assert (loop.fun_end.ty = e.ty);
- (* If we translate forward functions, the type of the loop is the same
- as the type of the parent expression - in case of backward functions,
- the loop doesn't necessarily give back the same values as the parent
- function
- *)
- assert (Option.is_some loop.back_output_tys || loop.loop_body.ty = e.ty);
check_texpression ctx loop.fun_end;
check_texpression ctx loop.loop_body
| StructUpdate supd -> (
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index cd367d83..bf92482a 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -3368,7 +3368,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
(* Compute the backward outputs *)
let ctx = ref ctx in
let rg_to_given_back_tys =
- T.RegionGroupId.Map.map
+ RegionGroupId.Map.map
(fun (_, tys) ->
(* The types shouldn't contain borrows - we can translate them as forward types *)
List.map
@@ -3380,10 +3380,63 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
in
let ctx = !ctx in
- let back_output_tys =
- match ctx.bid with
- | None -> None
- | Some rg_id -> Some (T.RegionGroupId.Map.find rg_id rg_to_given_back_tys)
+ (* The output type of the loop function *)
+ let output_ty =
+ if !Config.return_back_funs then
+ (* The loop backward functions consume the same additional inputs as the parent
+ function, but have custom outputs *)
+ let back_sgs = RegionGroupId.Map.values ctx.sg.back_sg in
+ let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in
+ let back_tys =
+ List.filter_map
+ (fun ((back_sg, given_back) : back_sg_info * ty list) ->
+ let effect_info = back_sg.effect_info in
+ (* Compute the input/output types *)
+ let inputs = List.map snd back_sg.inputs in
+ let outputs = given_back in
+ (* Filter if necessary *)
+ if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = []
+ then None
+ else
+ let output = mk_simpl_tuple_ty outputs in
+ let output =
+ mk_back_output_ty_from_effect_info effect_info inputs output
+ in
+ let ty = mk_arrows inputs output in
+ Some ty)
+ (List.combine back_sgs given_back_tys)
+ in
+ let output =
+ if ctx.sg.fwd_info.ignore_output then back_tys
+ else ctx.sg.fwd_output :: back_tys
+ in
+ let output = mk_simpl_tuple_ty output in
+ let effect_info = ctx.sg.fwd_info.effect_info in
+ let output =
+ if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ]
+ else output
+ in
+ if effect_info.can_fail && inputs <> [] then mk_result_ty output
+ else output
+ else
+ match ctx.bid with
+ | None ->
+ (* Forward function: same type as the parent function *)
+ (translate_fun_sig_from_decomposed ctx.sg None).output
+ | Some rg_id ->
+ (* Backward function: custom return type *)
+ let doutputs = T.RegionGroupId.Map.find rg_id rg_to_given_back_tys in
+ let output = mk_simpl_tuple_ty doutputs in
+ let fwd_effect_info = ctx.sg.fwd_info.effect_info in
+ let output =
+ if fwd_effect_info.stateful then
+ mk_simpl_tuple_ty [ mk_state_ty; output ]
+ else output
+ in
+ let output =
+ if fwd_effect_info.can_fail then mk_result_ty output else output
+ in
+ output
in
(* Add the loop information in the context *)
@@ -3460,7 +3513,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
input_state;
inputs;
inputs_lvs;
- back_output_tys;
+ output_ty;
loop_body;
}
in