summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--compiler/PureMicroPasses.ml54
-rw-r--r--compiler/PureUtils.ml18
2 files changed, 64 insertions, 8 deletions
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 1820b86a..0ac0851e 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -1075,6 +1075,41 @@ let filter_useless (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
]}
*)
let simplify_let_then_return _ctx def =
+ (* Match a pattern and an expression: evaluates to [true] if the expression
+ is actually exactly the pattern *)
+ let rec match_pattern_and_expr (pat : typed_pattern) (e : texpression) : bool
+ =
+ match (pat.value, e.e) with
+ | PatConstant plit, Const lit -> plit = lit
+ | PatVar (pv, _), Var vid -> pv.id = vid
+ | PatDummy, _ ->
+ (* It is ok only if we ignore the unit value *)
+ pat.ty = mk_unit_ty && e = mk_unit_rvalue
+ | PatAdt padt, _ -> (
+ let qualif, args = destruct_apps e in
+ match qualif.e with
+ | Qualif { id = AdtCons cons_id; generics = _ } ->
+ if
+ pat.ty = e.ty
+ && padt.variant_id = cons_id.variant_id
+ && List.length padt.field_values = List.length args
+ then
+ List.for_all
+ (fun (p, e) -> match_pattern_and_expr p e)
+ (List.combine padt.field_values args)
+ else false
+ | _ -> false)
+ | _ -> false
+ in
+ let match_pattern_and_ret_expr (monadic : bool) (pat : typed_pattern)
+ (e : texpression) : bool =
+ if monadic then
+ match opt_destruct_ret e with
+ | Some e -> match_pattern_and_expr pat e
+ | None -> false
+ else match_pattern_and_expr pat e
+ in
+
let expr_visitor =
object (self)
inherit [_] map_expression
@@ -1089,14 +1124,9 @@ let simplify_let_then_return _ctx def =
| Switch _ | Loop _ | Let _ ->
(* Small shortcut to avoid doing the check on every let-binding *)
not_simpl_e
- | _ -> (
- match typed_pattern_to_texpression lv with
- | None -> not_simpl_e
- | Some lv_v ->
- let lv_v =
- if monadic then mk_result_return_texpression lv_v else lv_v
- in
- if lv_v = next_e then rv.e else not_simpl_e)
+ | _ ->
+ if match_pattern_and_ret_expr monadic lv next_e then rv.e
+ else not_simpl_e
end
in
@@ -1824,6 +1854,14 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl =
("inline_useless_var_assignments (pass 2):\n\n"
^ fun_decl_to_string ctx def ^ "\n"));
+ (* Simplify the let-then return again (the lambda simplification may have
+ unlocked more simplifications here) *)
+ let def = simplify_let_then_return ctx def in
+ log#ldebug
+ (lazy
+ ("simplify_let_then_return (pass 2):\n\n" ^ fun_decl_to_string ctx def
+ ^ "\n"));
+
(* Decompose the monadic let-bindings - used by Coq *)
let def =
if !Config.decompose_monadic_let_bindings then (
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index 80bf3c42..81e3fbe1 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -752,3 +752,21 @@ let opt_dest_struct_pattern (pat : typed_pattern) : typed_pattern list option =
match pat.value with
| PatAdt { variant_id = None; field_values } -> Some field_values
| _ -> None
+
+(** Destruct a [ret ...] expression *)
+let opt_destruct_ret (e : texpression) : texpression option =
+ match e.e with
+ | App
+ ( {
+ e =
+ Qualif
+ {
+ id = AdtCons { adt_id = TAssumed TResult; variant_id };
+ generics = _;
+ };
+ ty = _;
+ },
+ arg )
+ when variant_id = Some result_return_id ->
+ Some arg
+ | _ -> None