diff options
author | Son Ho | 2023-12-22 21:03:17 +0100 |
---|---|---|
committer | Son Ho | 2023-12-22 21:03:17 +0100 |
commit | 70d506d148e5ae1a3e4115034161f449aff666ed (patch) | |
tree | 43faecd146f5d792d398512097b3afdb503ae11c | |
parent | b230ddacd44a1ca1804940bf89253bde8de7ffe1 (diff) |
Fix the output type of the loops backward functions
-rw-r--r-- | compiler/PrintPure.ml | 11 | ||||
-rw-r--r-- | compiler/Pure.ml | 4 | ||||
-rw-r--r-- | compiler/PureMicroPasses.ml | 25 | ||||
-rw-r--r-- | compiler/PureTypeCheck.ml | 6 | ||||
-rw-r--r-- | compiler/SymbolicToPure.ml | 65 |
5 files changed, 65 insertions, 46 deletions
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 315dd512..66475d02 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -711,21 +711,14 @@ and loop_to_string (env : fmt_env) (indent : string) (indent_incr : string) ^ String.concat "; " (List.map (var_to_string env) loop.inputs) ^ "]" in - let back_output_tys = - let tys = - match loop.back_output_tys with - | None -> "" - | Some tys -> String.concat "; " (List.map (ty_to_string env false) tys) - in - "back_output_tys: [" ^ tys ^ "]" - in + let output_ty = "output_ty: " ^ ty_to_string env false loop.output_ty in let fun_end = texpression_to_string env false indent2 indent_incr loop.fun_end in let loop_body = texpression_to_string env false indent2 indent_incr loop.loop_body in - "loop {\n" ^ indent1 ^ loop_inputs ^ "\n" ^ indent1 ^ back_output_tys ^ "\n" + "loop {\n" ^ indent1 ^ loop_inputs ^ "\n" ^ indent1 ^ output_ty ^ "\n" ^ indent1 ^ "fun_end: {\n" ^ indent2 ^ fun_end ^ "\n" ^ indent1 ^ "}\n" ^ indent1 ^ "loop_body: {\n" ^ indent2 ^ loop_body ^ "\n" ^ indent1 ^ "}\n" ^ indent ^ "}" diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 71531688..a879ba37 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -754,9 +754,7 @@ and loop = { inputs : var list; inputs_lvs : typed_pattern list; (** The inputs seen as patterns. See {!fun_body}. *) - back_output_tys : ty list option; - (** The types of the given back values, if we ar esynthesizing a backward - function *) + output_ty : ty; (** The output type of the loop *) loop_body : texpression; } diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 67495ab5..e7e9d5e1 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -459,7 +459,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl = input_state; inputs; inputs_lvs; - back_output_tys; + output_ty; loop_body; } = loop @@ -478,7 +478,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl = input_state; inputs; inputs_lvs; - back_output_tys; + output_ty; loop_body; } in @@ -1498,26 +1498,7 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) : List.concat [ fuel; fwd_inputs; fwd_state; back_inputs ] in - let output = - match loop.back_output_tys with - | None -> - (* Forward function: the return type is the same as the - parent function *) - fun_sig.output - | Some doutputs -> - (* Backward function: custom return type *) - let output = mk_simpl_tuple_ty doutputs in - let output = - if loop_fwd_effect_info.stateful then - mk_simpl_tuple_ty [ mk_state_ty; output ] - else output - in - let output = - if loop_fwd_effect_info.can_fail then mk_result_ty output - else output - in - output - in + let output = loop.output_ty in let loop_sig = { diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index d60d6a05..a989fd3b 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -188,12 +188,6 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = List.iter check_branch branches) | Loop loop -> assert (loop.fun_end.ty = e.ty); - (* If we translate forward functions, the type of the loop is the same - as the type of the parent expression - in case of backward functions, - the loop doesn't necessarily give back the same values as the parent - function - *) - assert (Option.is_some loop.back_output_tys || loop.loop_body.ty = e.ty); check_texpression ctx loop.fun_end; check_texpression ctx loop.loop_body | StructUpdate supd -> ( diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index cd367d83..bf92482a 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -3368,7 +3368,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = (* Compute the backward outputs *) let ctx = ref ctx in let rg_to_given_back_tys = - T.RegionGroupId.Map.map + RegionGroupId.Map.map (fun (_, tys) -> (* The types shouldn't contain borrows - we can translate them as forward types *) List.map @@ -3380,10 +3380,63 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = in let ctx = !ctx in - let back_output_tys = - match ctx.bid with - | None -> None - | Some rg_id -> Some (T.RegionGroupId.Map.find rg_id rg_to_given_back_tys) + (* The output type of the loop function *) + let output_ty = + if !Config.return_back_funs then + (* The loop backward functions consume the same additional inputs as the parent + function, but have custom outputs *) + let back_sgs = RegionGroupId.Map.values ctx.sg.back_sg in + let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in + let back_tys = + List.filter_map + (fun ((back_sg, given_back) : back_sg_info * ty list) -> + let effect_info = back_sg.effect_info in + (* Compute the input/output types *) + let inputs = List.map snd back_sg.inputs in + let outputs = given_back in + (* Filter if necessary *) + if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] + then None + else + let output = mk_simpl_tuple_ty outputs in + let output = + mk_back_output_ty_from_effect_info effect_info inputs output + in + let ty = mk_arrows inputs output in + Some ty) + (List.combine back_sgs given_back_tys) + in + let output = + if ctx.sg.fwd_info.ignore_output then back_tys + else ctx.sg.fwd_output :: back_tys + in + let output = mk_simpl_tuple_ty output in + let effect_info = ctx.sg.fwd_info.effect_info in + let output = + if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] + else output + in + if effect_info.can_fail && inputs <> [] then mk_result_ty output + else output + else + match ctx.bid with + | None -> + (* Forward function: same type as the parent function *) + (translate_fun_sig_from_decomposed ctx.sg None).output + | Some rg_id -> + (* Backward function: custom return type *) + let doutputs = T.RegionGroupId.Map.find rg_id rg_to_given_back_tys in + let output = mk_simpl_tuple_ty doutputs in + let fwd_effect_info = ctx.sg.fwd_info.effect_info in + let output = + if fwd_effect_info.stateful then + mk_simpl_tuple_ty [ mk_state_ty; output ] + else output + in + let output = + if fwd_effect_info.can_fail then mk_result_ty output else output + in + output in (* Add the loop information in the context *) @@ -3460,7 +3513,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = input_state; inputs; inputs_lvs; - back_output_tys; + output_ty; loop_body; } in |