summaryrefslogtreecommitdiff
path: root/src/PureMicroPasses.ml
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/PureMicroPasses.ml474
1 files changed, 267 insertions, 207 deletions
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index b110f829..5227b2ad 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -311,14 +311,22 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
(ctx, e.e)
in
- let input_names =
- List.filter_map
- (fun (v : var) ->
- match v.basename with None -> None | Some name -> Some (v.id, name))
- def.inputs
+ let body =
+ match def.body with
+ | None -> None
+ | Some body ->
+ let input_names =
+ List.filter_map
+ (fun (v : var) ->
+ match v.basename with
+ | None -> None
+ | Some name -> Some (v.id, name))
+ body.inputs
+ in
+ let ctx = VarId.Map.of_list input_names in
+ let _, body_exp = update_texpression body.body ctx in
+ Some { body with body = body_exp }
in
- let ctx = VarId.Map.of_list input_names in
- let _, body = update_texpression def.body ctx in
{ def with body }
(** Remove the meta-information *)
@@ -330,8 +338,11 @@ let remove_meta (def : fun_decl) : fun_decl =
method! visit_Meta env _ e = super#visit_expression env e.e
end
in
- let body = obj#visit_texpression () def.body in
- { def with body }
+ match def.body with
+ | None -> def
+ | Some body ->
+ let body = { body with body = obj#visit_texpression () body.body } in
+ { def with body = Some body }
(** Inline the useless variable (re-)assignments:
@@ -452,8 +463,13 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool)
(** Visit the places used as rvalues, to substitute them if possible *)
end
in
- let body = obj#visit_texpression VarId.Map.empty def.body in
- { def with body }
+ match def.body with
+ | None -> def
+ | Some body ->
+ let body =
+ { body with body = obj#visit_texpression VarId.Map.empty body.body }
+ in
+ { def with body = Some body }
(** Given a forward or backward function call, is there, for every execution
path, a child backward function called later with exactly the same input
@@ -683,22 +699,29 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
dont_filter ()
end
in
- (* Visit the body *)
- let body, used_vars = expr_visitor#visit_texpression () def.body in
- (* Visit the parameters - TODO: update: we can filter only if the definition
- * is not recursive (otherwise it might mess up with the decrease clauses:
- * the decrease clauses uses all the inputs given to the function, if some
- * inputs are replaced by '_' we can't give it to the function used in the
- * decreases clause).
- * For now we deactivate the filtering. *)
- let used_vars = used_vars () in
- let inputs_lvs =
- if false then
- List.map (fun lv -> fst (filter_typed_lvalue used_vars lv)) def.inputs_lvs
- else def.inputs_lvs
- in
- (* Return *)
- { def with body; inputs_lvs }
+ (* We filter only inside of transparent (i.e., non-opaque) definitions *)
+ match def.body with
+ | None -> def
+ | Some body ->
+ (* Visit the body *)
+ let body_exp, used_vars = expr_visitor#visit_texpression () body.body in
+ (* Visit the parameters - TODO: update: we can filter only if the definition
+ * is not recursive (otherwise it might mess up with the decrease clauses:
+ * the decrease clauses uses all the inputs given to the function, if some
+ * inputs are replaced by '_' we can't give it to the function used in the
+ * decreases clause).
+ * For now we deactivate the filtering. *)
+ let used_vars = used_vars () in
+ let inputs_lvs =
+ if false then
+ List.map
+ (fun lv -> fst (filter_typed_lvalue used_vars lv))
+ body.inputs_lvs
+ else body.inputs_lvs
+ in
+ (* Return *)
+ let body = { body with body = body_exp; inputs_lvs } in
+ { def with body = Some body }
(** Return `None` if the function is a backward function with no outputs (so
that we eliminate the definition which is useless).
@@ -763,21 +786,31 @@ let to_monadic (config : config) (def : fun_decl) : fun_decl =
super#visit_call env call
end
in
- let body = obj#visit_texpression () def.body in
- let def = { def with body } in
+ let def =
+ match def.body with
+ | None -> def
+ | Some body ->
+ let body = { body with body = obj#visit_texpression () body.body } in
+ { def with body = Some body }
+ in
(* Update the signature: first the input types *)
let def =
- if def.inputs = [] && config.add_unit_args then (
- assert (def.signature.inputs = []);
+ if def.signature.inputs = [] && config.add_unit_args then
let signature = { def.signature with inputs = [ unit_ty ] } in
- let var_cnt = get_expression_min_var_counter def.body.e in
- let id, _ = VarId.fresh var_cnt in
- let var = { id; basename = None; ty = unit_ty } in
- let inputs = [ var ] in
- let input_lv = mk_typed_lvalue_from_var var None in
- let inputs_lvs = [ input_lv ] in
- { def with signature; inputs; inputs_lvs })
+ let body =
+ match def.body with
+ | None -> None
+ | Some body ->
+ let var_cnt = get_expression_min_var_counter body.body.e in
+ let id, _ = VarId.fresh var_cnt in
+ let var = { id; basename = None; ty = unit_ty } in
+ let inputs = [ var ] in
+ let input_lv = mk_typed_lvalue_from_var var None in
+ let inputs_lvs = [ input_lv ] in
+ Some { body with inputs; inputs_lvs }
+ in
+ { def with signature; body }
else def
in
(* Then the output type *)
@@ -830,11 +863,15 @@ let unit_vars_to_unit (def : fun_decl) : fun_decl =
end
in
(* Update the body *)
- let body = obj#visit_texpression () def.body in
- (* Update the input parameters *)
- let inputs_lvs = List.map (obj#visit_typed_lvalue ()) def.inputs_lvs in
- (* Return *)
- { def with body; inputs_lvs }
+ match def.body with
+ | None -> def
+ | Some body ->
+ let body_exp = obj#visit_texpression () body.body in
+ (* Update the input parameters *)
+ let inputs_lvs = List.map (obj#visit_typed_lvalue ()) body.inputs_lvs in
+ (* Return *)
+ let body = Some { body with body = body_exp; inputs_lvs } in
+ { def with body }
(** Eliminate the box functions like `Box::new`, `Box::deref`, etc. Most of them
are translated to identity, and `Box::free` is translated to `()`.
@@ -887,178 +924,201 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
end
in
(* Update the body *)
- let body = obj#visit_texpression () def.body in
- { def with body }
+ match def.body with
+ | None -> def
+ | Some body ->
+ let body = Some { body with body = obj#visit_texpression () body.body } in
+ { def with body }
(** Decompose the monadic let-bindings.
See the explanations in [config].
*)
-let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl
- =
- (* Set up the var id generator *)
- let cnt = get_expression_min_var_counter def.body.e in
- let _, fresh_id = VarId.mk_stateful_generator cnt in
- (* It is a very simple map *)
- let obj =
- object (self)
- inherit [_] map_expression as super
-
- method! visit_Let env monadic lv re next_e =
- if not monadic then super#visit_Let env monadic lv re next_e
- else
- (* If monadic, we need to check if the left-value is a variable:
- * - if yes, don't decompose
- * - if not, make the decomposition in two steps
- *)
- match lv.value with
- | LvVar _ ->
- (* Variable: nothing to do *)
- super#visit_Let env monadic lv re next_e
- | _ ->
- (* Not a variable: decompose *)
- (* Introduce a temporary variable to receive the value of the
- * monadic binding *)
- let vid = fresh_id () in
- let tmp : var = { id = vid; basename = None; ty = lv.ty } in
- let ltmp = mk_typed_lvalue_from_var tmp None in
- let rtmp = mk_typed_rvalue_from_var tmp in
- let rtmp = mk_value_expression rtmp None in
- (* Visit the next expression *)
- let next_e = self#visit_texpression env next_e in
- (* Create the let-bindings *)
- (mk_let true ltmp re (mk_let false lv rtmp next_e)).e
- end
- in
- (* Update the body *)
- let body = obj#visit_texpression () def.body in
- (* Return *)
- { def with body }
+let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) :
+ fun_decl =
+ match def.body with
+ | None -> def
+ | Some body ->
+ (* Set up the var id generator *)
+ let cnt = get_expression_min_var_counter body.body.e in
+ let _, fresh_id = VarId.mk_stateful_generator cnt in
+ (* It is a very simple map *)
+ let obj =
+ object (self)
+ inherit [_] map_expression as super
+
+ method! visit_Let env monadic lv re next_e =
+ if not monadic then super#visit_Let env monadic lv re next_e
+ else
+ (* If monadic, we need to check if the left-value is a variable:
+ * - if yes, don't decompose
+ * - if not, make the decomposition in two steps
+ *)
+ match lv.value with
+ | LvVar _ ->
+ (* Variable: nothing to do *)
+ super#visit_Let env monadic lv re next_e
+ | _ ->
+ (* Not a variable: decompose *)
+ (* Introduce a temporary variable to receive the value of the
+ * monadic binding *)
+ let vid = fresh_id () in
+ let tmp : var = { id = vid; basename = None; ty = lv.ty } in
+ let ltmp = mk_typed_lvalue_from_var tmp None in
+ let rtmp = mk_typed_rvalue_from_var tmp in
+ let rtmp = mk_value_expression rtmp None in
+ (* Visit the next expression *)
+ let next_e = self#visit_texpression env next_e in
+ (* Create the let-bindings *)
+ (mk_let true ltmp re (mk_let false lv rtmp next_e)).e
+ end
+ in
+ (* Update the body *)
+ let body = Some { body with body = obj#visit_texpression () body.body } in
+ (* Return *)
+ { def with body }
(** Unfold the monadic let-bindings to explicit matches. *)
let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
(def : fun_decl) : fun_decl =
- (* We may need to introduce fresh variables for the state *)
- let var_cnt = get_expression_min_var_counter def.body.e in
- let _, fresh_var_id = VarId.mk_stateful_generator var_cnt in
- let fresh_state_var () =
- let id = fresh_var_id () in
- { id; basename = Some "st"; ty = mk_state_ty }
- in
- (* It is a very simple map *)
- let obj =
- object (self)
- inherit [_] map_expression as super
-
- method! visit_Let state_var monadic lv re e =
- if not monadic then super#visit_Let state_var monadic lv re e
- else
- (* We don't do the same thing if we use a state-error monad or simply
- * an error monad.
- * Note that some functions always live in the error monad (arithmetic
- * operations, for instance).
- *)
- let re_call =
- match re.e with
- | Call call -> call
- | _ -> raise (Failure "Unreachable: expected a function call")
- in
- (* TODO: this information should be computed in SymbolicToPure and
- * store in an enum ("monadic" should be an enum, not a bool).
- * Also: everything will be cleaner once we update the AST to make
- * it more idiomatic lambda calculus... *)
- let re_call_can_use_state =
- match re_call.func with
- | Regular (A.Local _, _) -> true
- | Regular (A.Assumed _, _) | Unop _ | Binop _ -> false
- in
- if config.use_state_monad && re_call_can_use_state then
- let re_call =
- let call = re_call in
- let state_value = mk_typed_rvalue_from_var state_var in
- let args = call.args @ [ mk_value_expression state_value None ] in
- Call { call with args }
- in
- let re = { re with e = re_call } in
- (* Create the match *)
- let fail_pat = mk_result_fail_lvalue lv.ty in
- let fail_value = mk_result_fail_rvalue e.ty in
- let fail_branch =
- { pat = fail_pat; branch = mk_value_expression fail_value None }
- in
- (* The `Success` branch introduces a fresh state variable *)
- let state_var = fresh_state_var () in
- let state_value = mk_typed_lvalue_from_var state_var None in
- let success_pat =
- mk_result_return_lvalue
- (mk_simpl_tuple_lvalue [ state_value; lv ])
- in
- let success_branch = { pat = success_pat; branch = e } in
- let switch_body = Match [ fail_branch; success_branch ] in
- let e = Switch (re, switch_body) in
- self#visit_expression state_var e
- else
- let fail_pat = mk_result_fail_lvalue lv.ty in
- let fail_value = mk_result_fail_rvalue e.ty in
- let fail_branch =
- { pat = fail_pat; branch = mk_value_expression fail_value None }
- in
- let success_pat = mk_result_return_lvalue lv in
- let success_branch = { pat = success_pat; branch = e } in
- let switch_body = Match [ fail_branch; success_branch ] in
- let e = Switch (re, switch_body) in
- self#visit_expression state_var e
-
- method! visit_Value state_var rv mp =
- if config.use_state_monad then
- match rv.ty with
- | Adt (Assumed Result, _) -> (
- match rv.value with
- | RvAdt av ->
- (* We only need to replace the content of `Return ...` *)
- (* TODO: type checking is completely broken at this point... *)
- let variant_id = Option.get av.variant_id in
- if variant_id = result_return_id then
- let res_v = Collections.List.to_cons_nil av.field_values in
- let state_value = mk_typed_rvalue_from_var state_var in
- let res = mk_simpl_tuple_rvalue [ state_value; res_v ] in
- let res = mk_result_return_rvalue res in
- (mk_value_expression res None).e
- else super#visit_Value state_var rv mp
- | _ -> raise (Failure "Unrechable"))
- | _ -> super#visit_Value state_var rv mp
- else super#visit_Value state_var rv mp
- (** We also need to update values, in case this value is `Return ...`.
-
+ match def.body with
+ | None -> def
+ | Some body ->
+ (* We may need to introduce fresh variables for the state *)
+ let var_cnt = get_expression_min_var_counter body.body.e in
+ let _, fresh_var_id = VarId.mk_stateful_generator var_cnt in
+ let fresh_state_var () =
+ let id = fresh_var_id () in
+ { id; basename = Some "st"; ty = mk_state_ty }
+ in
+ (* It is a very simple map *)
+ let obj =
+ object (self)
+ inherit [_] map_expression as super
+
+ method! visit_Let state_var monadic lv re e =
+ if not monadic then super#visit_Let state_var monadic lv re e
+ else
+ (* We don't do the same thing if we use a state-error monad or simply
+ * an error monad.
+ * Note that some functions always live in the error monad (arithmetic
+ * operations, for instance).
+ *)
+ let re_call =
+ match re.e with
+ | Call call -> call
+ | _ -> raise (Failure "Unreachable: expected a function call")
+ in
+ (* TODO: this information should be computed in SymbolicToPure and
+ * store in an enum ("monadic" should be an enum, not a bool).
+ * Also: everything will be cleaner once we update the AST to make
+ * it more idiomatic lambda calculus... *)
+ let re_call_can_use_state =
+ match re_call.func with
+ | Regular (A.Local _, _) -> true
+ | Regular (A.Assumed _, _) | Unop _ | Binop _ -> false
+ in
+ if config.use_state_monad && re_call_can_use_state then
+ let re_call =
+ let call = re_call in
+ let state_value = mk_typed_rvalue_from_var state_var in
+ let args =
+ call.args @ [ mk_value_expression state_value None ]
+ in
+ Call { call with args }
+ in
+ let re = { re with e = re_call } in
+ (* Create the match *)
+ let fail_pat = mk_result_fail_lvalue lv.ty in
+ let fail_value = mk_result_fail_rvalue e.ty in
+ let fail_branch =
+ {
+ pat = fail_pat;
+ branch = mk_value_expression fail_value None;
+ }
+ in
+ (* The `Success` branch introduces a fresh state variable *)
+ let state_var = fresh_state_var () in
+ let state_value = mk_typed_lvalue_from_var state_var None in
+ let success_pat =
+ mk_result_return_lvalue
+ (mk_simpl_tuple_lvalue [ state_value; lv ])
+ in
+ let success_branch = { pat = success_pat; branch = e } in
+ let switch_body = Match [ fail_branch; success_branch ] in
+ let e = Switch (re, switch_body) in
+ self#visit_expression state_var e
+ else
+ let fail_pat = mk_result_fail_lvalue lv.ty in
+ let fail_value = mk_result_fail_rvalue e.ty in
+ let fail_branch =
+ {
+ pat = fail_pat;
+ branch = mk_value_expression fail_value None;
+ }
+ in
+ let success_pat = mk_result_return_lvalue lv in
+ let success_branch = { pat = success_pat; branch = e } in
+ let switch_body = Match [ fail_branch; success_branch ] in
+ let e = Switch (re, switch_body) in
+ self#visit_expression state_var e
+
+ method! visit_Value state_var rv mp =
+ if config.use_state_monad then
+ match rv.ty with
+ | Adt (Assumed Result, _) -> (
+ match rv.value with
+ | RvAdt av ->
+ (* We only need to replace the content of `Return ...` *)
+ (* TODO: type checking is completely broken at this point... *)
+ let variant_id = Option.get av.variant_id in
+ if variant_id = result_return_id then
+ let res_v =
+ Collections.List.to_cons_nil av.field_values
+ in
+ let state_value = mk_typed_rvalue_from_var state_var in
+ let res =
+ mk_simpl_tuple_rvalue [ state_value; res_v ]
+ in
+ let res = mk_result_return_rvalue res in
+ (mk_value_expression res None).e
+ else super#visit_Value state_var rv mp
+ | _ -> raise (Failure "Unrechable"))
+ | _ -> super#visit_Value state_var rv mp
+ else super#visit_Value state_var rv mp
+ (** We also need to update values, in case this value is `Return ...`.
+
TODO: this is super ugly... We need to use the monadic functions
- `fail` and `return` instead.
- *)
- end
- in
- (* Update the body *)
- let input_state_var = fresh_state_var () in
- let body = obj#visit_texpression input_state_var def.body in
- let def = { def with body } in
- (* We need to update the type if we revealed the state monad *)
- let def =
- if config.use_state_monad then
- (* Update the signature *)
- let sg = def.signature in
- let sg_inputs = sg.inputs @ [ mk_state_ty ] in
- let sg_outputs = Collections.List.to_cons_nil sg.outputs in
- let _, sg_outputs = dest_arrow_ty sg_outputs in
- let sg_outputs = [ sg_outputs ] in
- let sg = { sg with inputs = sg_inputs; outputs = sg_outputs } in
- (* Update the inputs list *)
- let inputs = def.inputs @ [ input_state_var ] in
- let input_lv = mk_typed_lvalue_from_var input_state_var None in
- let inputs_lvs = def.inputs_lvs @ [ input_lv ] in
- (* Update the definition *)
- { def with signature = sg; inputs; inputs_lvs }
- else def
- in
- (* Return *)
- { def with body }
+ `fail` and `return` instead.
+ *)
+ end
+ in
+ (* Update the body *)
+ let input_state_var = fresh_state_var () in
+ let body =
+ { body with body = obj#visit_texpression input_state_var body.body }
+ in
+ (* We need to update the type if we revealed the state monad *)
+ let body, signature =
+ if config.use_state_monad then
+ (* Update the signature *)
+ let sg = def.signature in
+ let sg_inputs = sg.inputs @ [ mk_state_ty ] in
+ let sg_outputs = Collections.List.to_cons_nil sg.outputs in
+ let _, sg_outputs = dest_arrow_ty sg_outputs in
+ let sg_outputs = [ sg_outputs ] in
+ let sg = { sg with inputs = sg_inputs; outputs = sg_outputs } in
+ (* Update the inputs list *)
+ let inputs = body.inputs @ [ input_state_var ] in
+ let input_lv = mk_typed_lvalue_from_var input_state_var None in
+ let inputs_lvs = body.inputs_lvs @ [ input_lv ] in
+ (* Update the body *)
+ let body = { body with inputs; inputs_lvs } in
+ (body, sg)
+ else (body, def.signature)
+ in
+ (* Return *)
+ { def with body = Some body; signature }
(** Apply all the micro-passes to a function.
@@ -1149,8 +1209,8 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_decl) :
let def = decompose_monadic_let_bindings ctx def in
log#ldebug
(lazy
- ("decompose_monadic_let_bindings:\n\n" ^ fun_decl_to_string ctx def
- ^ "\n"));
+ ("decompose_monadic_let_bindings:\n\n"
+ ^ fun_decl_to_string ctx def ^ "\n"));
def)
else (
log#ldebug