summaryrefslogtreecommitdiff
path: root/src/PureUtils.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/PureUtils.ml')
-rw-r--r--src/PureUtils.ml34
1 files changed, 34 insertions, 0 deletions
diff --git a/src/PureUtils.ml b/src/PureUtils.ml
index b87a6346..873931be 100644
--- a/src/PureUtils.ml
+++ b/src/PureUtils.ml
@@ -478,3 +478,37 @@ let destruct_arrow (ty : ty) : ty * ty =
| _ -> raise (Failure "Unreachable")
let mk_arrow (ty0 : ty) (ty1 : ty) : ty = Arrow (ty0, ty1)
+
+let get_switch_body_ty (sb : switch_body) : ty =
+ match sb with
+ | If (e_then, _) -> e_then.ty
+ | Match branches ->
+ (* There should be at least one branch *)
+ (List.hd branches).branch.ty
+
+let map_switch_body_branches (f : texpression -> texpression) (sb : switch_body)
+ : switch_body =
+ match sb with
+ | If (e_then, e_else) -> If (f e_then, f e_else)
+ | Match branches ->
+ Match
+ (List.map
+ (fun (b : match_branch) -> { b with branch = f b.branch })
+ branches)
+
+let iter_switch_body_branches (f : texpression -> unit) (sb : switch_body) :
+ unit =
+ match sb with
+ | If (e_then, e_else) ->
+ f e_then;
+ f e_else
+ | Match branches -> List.iter (fun (b : match_branch) -> f b.branch) branches
+
+let mk_switch (scrut : texpression) (sb : switch_body) : texpression =
+ (* TODO: check the type of the scrutinee *)
+ let ty = get_switch_body_ty sb in
+ (* Sanity check: all the branches have the same type *)
+ iter_switch_body_branches (fun e -> assert (e.ty = ty)) sb;
+ (* Put together *)
+ let e = Switch (scrut, sb) in
+ { e; ty }