diff options
Diffstat (limited to 'src/PureMicroPasses.ml')
-rw-r--r-- | src/PureMicroPasses.ml | 109 |
1 files changed, 60 insertions, 49 deletions
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index 82e5ecd2..2a7293c8 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -195,14 +195,18 @@ let compute_pretty_names (def : fun_def) : fun_def = in (* *) - let rec update_expression (e : expression) (ctx : pn_ctx) : - pn_ctx * expression = - match e with - | Value (v, mp) -> update_value v mp ctx - | Call call -> update_call call ctx - | Let (monadic, lb, re, e) -> update_let monadic lb re e ctx - | Switch (scrut, body) -> update_switch_body scrut body ctx - | Meta (meta, e) -> update_meta meta e ctx + let rec update_texpression (e : texpression) (ctx : pn_ctx) : + pn_ctx * texpression = + let ty = e.ty in + let ctx, e = + match e.e with + | Value (v, mp) -> update_value v mp ctx + | Call call -> update_call call ctx + | Let (monadic, lb, re, e) -> update_let monadic lb re e ctx + | Switch (scrut, body) -> update_switch_body scrut body ctx + | Meta (meta, e) -> update_meta meta e ctx + in + (ctx, { e; ty }) (* *) and update_value (v : typed_rvalue) (mp : mplace option) (ctx : pn_ctx) : pn_ctx * expression = @@ -212,40 +216,40 @@ let compute_pretty_names (def : fun_def) : fun_def = and update_call (call : call) (ctx : pn_ctx) : pn_ctx * expression = let ctx, args = List.fold_left_map - (fun ctx arg -> update_expression arg ctx) + (fun ctx arg -> update_texpression arg ctx) ctx call.args in let call = { call with args } in (ctx, Call call) (* *) - and update_let (monadic : bool) (lv : typed_lvalue) (re : expression) - (e : expression) (ctx : pn_ctx) : pn_ctx * expression = + and update_let (monadic : bool) (lv : typed_lvalue) (re : texpression) + (e : texpression) (ctx : pn_ctx) : pn_ctx * expression = let ctx = add_left_constraint lv ctx in - let ctx, re = update_expression re ctx in - let ctx, e = update_expression e ctx in + let ctx, re = update_texpression re ctx in + let ctx, e = update_texpression e ctx in let lv = update_typed_lvalue ctx lv in (ctx, Let (monadic, lv, re, e)) (* *) - and update_switch_body (scrut : expression) (body : switch_body) + and update_switch_body (scrut : texpression) (body : switch_body) (ctx : pn_ctx) : pn_ctx * expression = - let ctx, scrut = update_expression scrut ctx in + let ctx, scrut = update_texpression scrut ctx in let ctx, body = match body with | If (e_true, e_false) -> - let ctx1, e_true = update_expression e_true ctx in - let ctx2, e_false = update_expression e_false ctx in + let ctx1, e_true = update_texpression e_true ctx in + let ctx2, e_false = update_texpression e_false ctx in let ctx = merge_ctxs ctx1 ctx2 in (ctx, If (e_true, e_false)) | SwitchInt (int_ty, branches, otherwise) -> let ctx_branches_ls = List.map (fun (v, br) -> - let ctx, br = update_expression br ctx in + let ctx, br = update_texpression br ctx in (ctx, (v, br))) branches in - let ctx, otherwise = update_expression otherwise ctx in + let ctx, otherwise = update_texpression otherwise ctx in let ctxs, branches = List.split ctx_branches_ls in let ctxs = merge_ctxs_ls ctxs in let ctx = merge_ctxs ctx ctxs in @@ -255,7 +259,7 @@ let compute_pretty_names (def : fun_def) : fun_def = List.map (fun br -> let ctx = add_left_constraint br.pat ctx in - let ctx, branch = update_expression br.branch ctx in + let ctx, branch = update_texpression br.branch ctx in let pat = update_typed_lvalue ctx br.pat in (ctx, { pat; branch })) branches @@ -266,12 +270,13 @@ let compute_pretty_names (def : fun_def) : fun_def = in (ctx, Switch (scrut, body)) (* *) - and update_meta (meta : meta) (e : expression) (ctx : pn_ctx) : + and update_meta (meta : meta) (e : texpression) (ctx : pn_ctx) : pn_ctx * expression = match meta with | Assignment (mp, rvalue) -> let ctx = add_right_constraint mp rvalue ctx in - update_expression e ctx + let ctx, e = update_texpression e ctx in + (ctx, e.e) in let input_names = @@ -281,7 +286,7 @@ let compute_pretty_names (def : fun_def) : fun_def = def.inputs in let ctx = VarId.Map.of_list input_names in - let _, body = update_expression def.body ctx in + let _, body = update_texpression def.body ctx in { def with body } (** Remove the meta-information *) @@ -290,10 +295,10 @@ let remove_meta (def : fun_def) : fun_def = object inherit [_] map_expression as super - method! visit_Meta env _ e = super#visit_expression env e + method! visit_Meta env _ e = super#visit_expression env e.e end in - let body = obj#visit_expression () def.body in + let body = obj#visit_texpression () def.body in { def with body } (** Inline the useless variable reassignments (a lot of variable assignments @@ -326,7 +331,7 @@ let inline_useless_var_reassignments (inline_named : bool) (def : fun_def) : * - the let-binding is not monadic * - the left-value is a variable * - the assigned expression is a value *) - match (monadic, lv.value, re) with + match (monadic, lv.value, re.e) with | false, LvVar (Var (lv_var, _)), Value (rv, _) -> ( (* Check that: * - the left variable is unnamed or that [inline_named] is true @@ -337,7 +342,7 @@ let inline_useless_var_reassignments (inline_named : bool) (def : fun_def) : (* Update the environment and explore the next expression * (dropping the currrent let) *) let env = add_subst lv_var.id var env in - super#visit_expression env e + super#visit_expression env e.e | _ -> super#visit_Let env monadic lv re e) | _ -> super#visit_Let env monadic lv re e (** Visit the let-bindings to filter the useless ones (and update @@ -353,7 +358,7 @@ let inline_useless_var_reassignments (inline_named : bool) (def : fun_def) : (** Visit the places used as rvalues, to substitute them if necessary *) end in - let body = obj#visit_expression VarId.Map.empty def.body in + let body = obj#visit_texpression VarId.Map.empty def.body in { def with body } (** Given a forward or backward function call, is there, for every execution @@ -382,7 +387,7 @@ let inline_useless_var_reassignments (inline_named : bool) (def : fun_def) : In this situation, we can remove the call `f x`. *) let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (call0 : call) - (e : expression) : bool = + (e : texpression) : bool = let check_call call1 : bool = (* Check the func_ids, to see if call1's function is a child of call0's function *) match (call0.func, call1.func) with @@ -428,7 +433,7 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (call0 : call) * meta-values* (which we need to ignore). We only consider the * case where both expressions are actually values. *) let input_eq (v0, v1) = - match (v0, v1) with + match (v0.e, v1.e) with | Value (v0, _), Value (v1, _) -> v0 = v1 | _ -> false in @@ -451,37 +456,43 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (call0 : call) method! visit_expression env e = match e with | Value (_, _) -> fun _ -> false - | Let (_, _, Call call1, e) -> + | Let (_, _, { e = Call call1; ty = _ }, e) -> let call_is_child = check_call call1 in if call_is_child then fun () -> true - else self#visit_expression env e + else self#visit_texpression env e | Let (_, _, re, e) -> fun () -> - self#visit_expression env re () && self#visit_expression env e () + self#visit_texpression env re () + && self#visit_texpression env e () | Call call1 -> fun () -> check_call call1 - | Meta (_, e) -> self#visit_expression env e + | Meta (_, e) -> self#visit_texpression env e | Switch (_, body) -> self#visit_switch_body env body (** We need to reimplement the way we compose the booleans *) + method! visit_texpression env e = + (* We take care not to visit the type *) + self#visit_expression env e.e + method! visit_switch_body env body = match body with | If (e1, e2) -> fun () -> - self#visit_expression env e1 () && self#visit_expression env e2 () + self#visit_texpression env e1 () + && self#visit_texpression env e2 () | SwitchInt (_, branches, otherwise) -> fun () -> List.for_all - (fun (_, br) -> self#visit_expression env br ()) + (fun (_, br) -> self#visit_texpression env br ()) branches - && self#visit_expression env otherwise () + && self#visit_texpression env otherwise () | Match branches -> fun () -> List.for_all - (fun br -> self#visit_expression env br.branch ()) + (fun br -> self#visit_texpression env br.branch ()) branches end in - visitor#visit_expression () e () + visitor#visit_texpression () e () (** Filter the unused assignments (removes the unused variables, filters the function calls) *) @@ -546,13 +557,13 @@ let filter_unused (filter_monadic_calls : bool) (ctx : trans_ctx) super#visit_expression env e | Let (monadic, lv, re, e) -> (* Compute the set of values used in the next expression *) - let e, used = self#visit_expression env e in + let e, used = self#visit_texpression env e in let used = used () in (* Filter the left values *) let lv, all_dummies = filter_typed_lvalue used lv in (* Small utility - called if we can't filter the let-binding *) let dont_filter () = - let re, used_re = self#visit_expression env re in + let re, used_re = self#visit_texpression env re in let used = VarId.Set.union used (used_re ()) in (Let (monadic, lv, re, e), fun _ -> used) in @@ -560,12 +571,12 @@ let filter_unused (filter_monadic_calls : bool) (ctx : trans_ctx) if all_dummies then if not monadic then (* Not a monadic let-binding: simple case *) - (e, fun _ -> used) + (e.e, fun _ -> used) else (* Monadic let-binding: trickier. * We can filter if the right-expression is a function call, * under some conditions. *) - match (filter_monadic_calls, re) with + match (filter_monadic_calls, re.e) with | true, Call call -> (* We need to check if there is a child call - see * the comments for: @@ -574,7 +585,7 @@ let filter_unused (filter_monadic_calls : bool) (ctx : trans_ctx) expression_contains_child_call_in_all_paths ctx call e in if has_child_call then (* Filter *) - (e, fun _ -> used) + (e.e, fun _ -> used) else (* No child call: don't filter *) dont_filter () | _ -> @@ -585,7 +596,7 @@ let filter_unused (filter_monadic_calls : bool) (ctx : trans_ctx) end in (* Visit the body *) - let body, used_vars = expr_visitor#visit_expression () def.body in + let body, used_vars = expr_visitor#visit_texpression () def.body in (* Visit the parameters *) let used_vars = used_vars () in let inputs_lvs = @@ -604,12 +615,12 @@ let to_monadic (def : fun_def) : fun_def = method! visit_call env call = (* If no arguments, introduce unit *) if call.args = [] then - let args = [ Value (unit_rvalue, None) ] in + let args = [ mk_value_expression unit_rvalue None ] in { call with args } (* Otherwise: nothing to do *) else super#visit_call env call end in - let body = obj#visit_expression () def.body in + let body = obj#visit_texpression () def.body in let def = { def with body } in (* Update the signature: first the input types *) @@ -617,7 +628,7 @@ let to_monadic (def : fun_def) : fun_def = if def.inputs = [] then ( assert (def.signature.inputs = []); let signature = { def.signature with inputs = [ unit_ty ] } in - let var_cnt = get_expression_min_var_counter def.body in + let var_cnt = get_expression_min_var_counter def.body.e in let id, _ = VarId.fresh var_cnt in let var = { id; basename = None; ty = unit_ty } in let inputs = [ var ] in @@ -659,7 +670,7 @@ let unit_vars_to_unit (def : fun_def) : fun_def = end in (* Update the body *) - let body = obj#visit_expression () def.body in + let body = obj#visit_texpression () def.body in (* Update the input parameters *) let inputs_lvs = List.map (obj#visit_typed_lvalue ()) def.inputs_lvs in (* Return *) |