diff options
-rw-r--r-- | src/PureMicroPasses.ml | 286 | ||||
-rw-r--r-- | src/PureUtils.ml | 26 |
2 files changed, 143 insertions, 169 deletions
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index 3c25e7b6..b4f4462b 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -95,7 +95,7 @@ let get_body_min_var_counter (body : fun_body) : VarId.generator = method! visit_var _ v _ = v.id (** For the lvalues *) - method! visit_place _ p _ = p.var + method! visit_Local _ vid _ = vid (** For the rvalues *) end in @@ -345,16 +345,12 @@ let compute_pretty_names (def : fun_decl) : fun_decl = | _ -> ctx in (* Specific case of constraint on rvalues *) - let add_right_constraint (mp : mplace) (rv : typed_rvalue) (ctx : pn_ctx) : + let add_right_constraint (mp : mplace) (rv : texpression) (ctx : pn_ctx) : pn_ctx = (* Register the place *) let ctx = register_mplace mp ctx in (* Add the constraint *) - match rv.value with RvPlace p -> add_constraint mp p.var ctx | _ -> ctx - in - let add_opt_right_constraint (mp : mplace option) (rv : typed_rvalue) - (ctx : pn_ctx) : pn_ctx = - match mp with None -> ctx | Some mp -> add_right_constraint mp rv ctx + match (unmeta rv).e with Local vid -> add_constraint mp vid ctx | _ -> ctx in (* Specific case of constraint on left values *) let add_left_constraint (lv : typed_lvalue) (ctx : pn_ctx) : pn_ctx = @@ -385,9 +381,9 @@ let compute_pretty_names (def : fun_decl) : fun_decl = (* We propagate constraints across variable reassignments: `^0 = x`, * if the destination doesn't have naming information *) match lv.value with - | LvVar (Var (({ id = _; basename = None; ty = _ } as lvar), lmp)) -> ( + | LvVar (Var (({ id = _; basename = None; ty = _ } as lvar), lmp)) -> if - (* Check that there is not already a name for teh variable *) + (* Check that there is not already a name for the variable *) VarId.Map.mem lvar.id ctx.pure_vars then ctx else @@ -402,31 +398,28 @@ let compute_pretty_names (def : fun_decl) : fun_decl = | None -> ctx | Some lmp -> add_llbc_var_constraint lmp.var_id name ctx in - match re.e with - | Value (rv, rmp) -> - (* We try to use the right-place information *) - let ctx = - match rmp with - | Some { var_id; name; projection = [] } -> ( - if Option.is_some name then add (Option.get name) ctx - else - match V.VarId.Map.find_opt var_id ctx.llbc_vars with - | None -> ctx - | Some name -> add name ctx) - | _ -> ctx - in - (* We try to use the rvalue information *) - let ctx = - match rv with - | { value = RvPlace { var = rvar_id; projection = [] }; ty = _ } - -> ( - match VarId.Map.find_opt rvar_id ctx.pure_vars with - | None -> ctx - | Some name -> add name ctx) - | _ -> ctx - in - ctx - | _ -> ctx) + (* We try to use the right-place information *) + let rmp, re = opt_unmeta_mplace re in + let ctx = + match rmp with + | Some { var_id; name; projection = [] } -> ( + if Option.is_some name then add (Option.get name) ctx + else + match V.VarId.Map.find_opt var_id ctx.llbc_vars with + | None -> ctx + | Some name -> add name ctx) + | _ -> ctx + in + (* We try to use the rvalue information, if it is a variable *) + let ctx = + match (unmeta re).e with + | Local rvar_id -> ( + match VarId.Map.find_opt rvar_id ctx.pure_vars with + | None -> ctx + | Some name -> add name ctx) + | _ -> ctx + in + ctx | _ -> ctx in @@ -436,25 +429,21 @@ let compute_pretty_names (def : fun_decl) : fun_decl = let ty = e.ty in let ctx, e = match e.e with - | Value (v, mp) -> update_value v mp ctx + | Local _ -> (* Nothing to do *) (ctx, e.e) + | Const _ -> (* Nothing to do *) (ctx, e.e) | App (app, arg) -> let ctx, app = update_texpression app ctx in let ctx, arg = update_texpression arg ctx in let e = App (app, arg) in (ctx, e) | Abs (x, e) -> update_abs x e ctx - | Func _ -> (* nothing to do *) (ctx, e.e) + | Qualif _ -> (* nothing to do *) (ctx, e.e) | Let (monadic, lb, re, e) -> update_let monadic lb re e ctx | Switch (scrut, body) -> update_switch_body scrut body ctx | Meta (meta, e) -> update_meta meta e ctx in (ctx, { e; ty }) (* *) - and update_value (v : typed_rvalue) (mp : mplace option) (ctx : pn_ctx) : - pn_ctx * expression = - let ctx = add_opt_right_constraint mp v ctx in - (ctx, Value (v, mp)) - (* *) and update_abs (x : typed_lvalue) (e : texpression) (ctx : pn_ctx) : pn_ctx * expression = (* We first add the left-constraint *) @@ -507,24 +496,29 @@ let compute_pretty_names (def : fun_decl) : fun_decl = (* *) and update_meta (meta : meta) (e : texpression) (ctx : pn_ctx) : pn_ctx * expression = - match meta with - | Assignment (mp, rvalue, rmp) -> - let ctx = add_right_constraint mp rvalue ctx in - let ctx = - match (mp.projection, rmp) with - | [], Some { var_id; name; projection = [] } -> ( - let name = + let ctx = + match meta with + | Assignment (mp, rvalue, rmp) -> + let ctx = add_right_constraint mp rvalue ctx in + let ctx = + match (mp.projection, rmp) with + | [], Some { var_id; name; projection = [] } -> ( + let name = + match name with + | Some name -> Some name + | None -> V.VarId.Map.find_opt var_id ctx.llbc_vars + in match name with - | Some name -> Some name - | None -> V.VarId.Map.find_opt var_id ctx.llbc_vars - in - match name with - | None -> ctx - | Some name -> add_llbc_var_constraint mp.var_id name ctx) - | _ -> ctx - in - let ctx, e = update_texpression e ctx in - (ctx, e.e) + | None -> ctx + | Some name -> add_llbc_var_constraint mp.var_id name ctx) + | _ -> ctx + in + ctx + | MPlace mp -> add_right_constraint mp e ctx + in + let ctx, e = update_texpression e ctx in + let e = mk_meta meta e in + (ctx, e.e) in let body = @@ -552,17 +546,10 @@ let compute_pretty_names (def : fun_decl) : fun_decl = (** Remove the meta-information *) let remove_meta (def : fun_decl) : fun_decl = - let obj = - object - inherit [_] map_expression as super - - method! visit_Meta env _ e = super#visit_expression env e.e - end - in match def.body with | None -> def | Some body -> - let body = { body with body = obj#visit_texpression () body.body } in + let body = { body with body = PureUtils.remove_meta body.body } in { def with body = Some body } (** Inline the useless variable (re-)assignments: @@ -597,7 +584,7 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) object (self) inherit [_] map_expression as super - method! visit_Let env monadic lv re e = + method! visit_Let (env : texpression VarId.Map.t) monadic lv re e = (* In order to filter, we need to check first that: * - the let-binding is not monadic * - the left-value is a variable @@ -621,16 +608,16 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) *) let filter = if inline_pure then - match re.e with - | Value _ -> true - | App _ -> ( - (* Application: decompose, and check that function call *) - match opt_destruct_function_call re with - | Some (func, _) -> ( - match func.func with - | Regular _ -> false - | Unop _ | Binop _ -> true) - | _ -> false) + let app, _ = destruct_apps re in + match app.e with + | Const _ | Local _ -> true (* constant or variable *) + | Qualif qualif -> ( + match qualif.id with + | AdtCons _ | Proj _ -> true (* ADT constructor *) + | Func (Unop _ | Binop _) -> + true (* primitive function call *) + | Func (Regular _) -> + false (* non-primitive function call *)) | _ -> filter else false in @@ -648,44 +635,17 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) (** Visit the let-bindings to filter the useless ones (and update the substitution map while doing so *) - method! visit_Value env v mp = - (* Check if we need to substitute *) - match v.value with - | RvPlace p -> ( - match VarId.Map.find_opt p.var env with - | None -> (* No substitution *) super#visit_Value env v mp - | Some ne -> - (* Substitute - note that we need to reexplore, because - * there may be stacked substitutions, if we have: - * var0 --> var1 - * var1 --> var2. - * - * Also: we can always substitute if we substitute with - * a variable. If we substitute with a value we need to - * check that the path is empty. - * TODO: actually do a projection *) - if is_var ne then - let var = as_var ne in - let p = { p with var } in - let nv = { v with value = RvPlace p } in - self#visit_Value env nv mp - else if p.projection = [] then self#visit_expression env ne.e - else super#visit_Value env v mp) - | _ -> (* No substitution *) super#visit_Value env v mp - (** Visit the values, to substitute them if possible *) - - method! visit_RvPlace env p = - if p.projection = [] then - match VarId.Map.find_opt p.var env with - | None -> (* No substitution *) super#visit_RvPlace env p - | Some ne -> ( - (* Substitute if the new expression is a value *) - match ne.e with - | Value (nv, _) -> self#visit_rvalue env nv.value - | _ -> (* Not a value *) super#visit_RvPlace env p) - else (* TODO: project *) - super#visit_RvPlace env p - (** Visit the places used as rvalues, to substitute them if possible *) + method! visit_Local (env : texpression VarId.Map.t) (vid : VarId.id) = + match VarId.Map.find_opt vid env with + | None -> (* No substitution *) super#visit_Local env vid + | Some ne -> + (* Substitute - note that we need to reexplore, because + * there may be stacked substitutions, if we have: + * var0 --> var1 + * var1 --> var2. + *) + self#visit_expression env ne.e + (** Substitute the variables *) end in match def.body with @@ -721,11 +681,13 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) In this situation, we can remove the call `f x`. *) -let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (func0 : func) - (args0 : texpression list) (e : texpression) : bool = - let check_call (func1 : func) (args1 : texpression list) : bool = - (* Check the func_ids, to see if call1's function is a child of call0's function *) - match (func0.func, func1.func) with +let expression_contains_child_call_in_all_paths (ctx : trans_ctx) + (fun_id0 : fun_id) (tys0 : ty list) (args0 : texpression list) + (e : texpression) : bool = + let check_call (fun_id1 : fun_id) (tys1 : ty list) (args1 : texpression list) + : bool = + (* Check the fun_ids, to see if call1's function is a child of call0's function *) + match (fun_id0, fun_id1) with | Regular (id0, rg_id0), Regular (id1, rg_id1) -> (* Both are "regular" calls: check if they come from the same rust function *) if id0 = id1 then @@ -765,15 +727,12 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (func0 : func) in let args = List.combine args0 call1_args in (* Note that the input values are expressions, *which may contain - * meta-values* (which we need to ignore). We only consider the - * case where both expressions are actually values. *) + * meta-values* (which we need to ignore). *) let input_eq (v0, v1) = - match (v0.e, v1.e) with - | Value (v0, _), Value (v1, _) -> v0 = v1 - | _ -> false + PureUtils.remove_meta v0 = PureUtils.remove_meta v1 in (* Compare the input types and the prefix of the input arguments *) - func0.type_params = func1.type_params && List.for_all input_eq args + tys0 = tys1 && List.for_all input_eq args else (* Not a child *) false else (* Not the same function *) @@ -791,21 +750,23 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (func0 : func) method! visit_texpression env e = match e.e with - | Value (_, _) -> fun _ -> false + | Local _ | Const _ -> fun _ -> false | Let (_, _, re, e) -> ( match opt_destruct_function_call re with | None -> fun () -> self#visit_texpression env e () - | Some (func1, args1) -> - let call_is_child = check_call func1 args1 in + | Some (func1, tys1, args1) -> + let call_is_child = check_call func1 tys1 args1 in if call_is_child then fun () -> true else fun () -> self#visit_texpression env e ()) | App _ -> ( fun () -> match opt_destruct_function_call e with - | Some (func1, args1) -> check_call func1 args1 + | Some (func1, tys1, args1) -> check_call func1 tys1 args1 | None -> false) | Abs (_, e) -> self#visit_texpression env e - | Func _ -> fun () -> false + | Qualif _ -> + (* Note that this case includes functions without arguments *) + fun () -> false | Meta (_, e) -> self#visit_texpression env e | Switch (_, body) -> self#visit_switch_body env body @@ -878,12 +839,15 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) method plus s0 s1 _ = VarId.Set.union (s0 ()) (s1 ()) - method! visit_place _ p = (p, fun _ -> VarId.Set.singleton p.var) - (** Whenever we visit a place, we need to register the used variable *) + method! visit_Local _ vid = (Local vid, fun _ -> VarId.Set.singleton vid) + (** Whenever we visit a variable, we need to register the used variable *) method! visit_expression env e = match e with - | Value (_, _) | App _ | Func _ | Switch (_, _) | Meta (_, _) | Abs _ -> + | Local _ | Const _ | App _ | Qualif _ + | Switch (_, _) + | Meta (_, _) + | Abs _ -> super#visit_expression env e | Let (monadic, lv, re, e) -> (* Compute the set of values used in the next expression *) @@ -907,13 +871,13 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) * We can filter if the right-expression is a function call, * under some conditions. *) match (filter_monadic_calls, opt_destruct_function_call re) with - | true, Some (func, args) -> + | true, Some (func, tys, args) -> (* We need to check if there is a child call - see * the comments for: * [expression_contains_child_call_in_all_paths] *) let has_child_call = - expression_contains_child_call_in_all_paths ctx func args - e + expression_contains_child_call_in_all_paths ctx func tys + args e in if has_child_call then (* Filter *) (e.e, fun _ -> used) @@ -1012,9 +976,11 @@ let unit_vars_to_unit (def : fun_decl) : fun_decl = | Var (v, mp) -> if v.ty = unit_ty then Dummy else Var (v, mp) (** Replace in lvalues *) - method! visit_typed_rvalue env rv = - if rv.ty = unit_ty then unit_rvalue else super#visit_typed_rvalue env rv - (** Replace in rvalues *) + method! visit_texpression env e = + if e.ty = unit_ty then unit_rvalue else super#visit_texpression env e + (** Replace in "regular" expressions - note that we could limit ourselves + to variables, but this is more powerful + *) end in (* Update the body *) @@ -1046,8 +1012,8 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = method! visit_texpression env e = match opt_destruct_function_call e with - | Some (func, args) -> ( - match func.func with + | Some (fun_id, _tys, args) -> ( + match fun_id with | Regular (A.Assumed aid, rg_id) -> ( (* Below, when dealing with the arguments: we consider the very * general case, where functions could be boxed (meaning we @@ -1065,7 +1031,7 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = | A.BoxDeref, Some _ -> (* `Box::deref` backward is `()` (doesn't give back anything) *) assert (args = []); - mk_value_expression unit_rvalue None + unit_rvalue | A.BoxDerefMut, None -> (* `Box::deref_mut` forward is the identity *) let arg, args = Collections.List.pop args in @@ -1082,7 +1048,7 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = mk_apps arg args | A.BoxFree, _ -> assert (args = []); - mk_value_expression unit_rvalue None + unit_rvalue | ( ( A.Replace | A.VecNew | A.VecPush | A.VecInsert | A.VecLen | A.VecIndex | A.VecIndexMut ), _ ) -> @@ -1133,8 +1099,7 @@ let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : let vid = fresh_id () in let tmp : var = { id = vid; basename = None; ty = lv.ty } in let ltmp = mk_typed_lvalue_from_var tmp None in - let rtmp = mk_typed_rvalue_from_var tmp in - let rtmp = mk_value_expression rtmp None in + let rtmp = mk_texpression_from_var tmp in (* Visit the next expression *) let next_e = self#visit_texpression env next_e in (* Create the let-bindings *) @@ -1189,8 +1154,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) if Option.is_some (opt_destruct_state_monad_result sb_ty) then (* Generate a fresh state variable *) let state_var = fresh_state_var () in - let state_value = mk_typed_rvalue_from_var state_var in - let state_value = mk_value_expression state_value None in + let state_value = mk_texpression_from_var state_var in let state_lvar = mk_typed_lvalue_from_var state_var None in (* Apply in all the branches and reconstruct the switch *) let mk_app e = mk_app e state_value in @@ -1251,19 +1215,13 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) let re_no_monad_ty = destruct_result re_no_arrow_ty in (* Add the state argument on the right-expression *) let re = - let state_value = mk_typed_rvalue_from_var state_var in - let state_value = mk_value_expression state_value None in + let state_value = mk_texpression_from_var state_var in mk_app re state_value in (* Create the match *) let fail_pat = mk_result_fail_lvalue re_no_monad_ty in let fail_value = mk_result_fail_rvalue e_no_monad_ty in - let fail_branch = - { - pat = fail_pat; - branch = mk_value_expression fail_value None; - } - in + let fail_branch = { pat = fail_pat; branch = fail_value } in (* The `Success` branch introduces a fresh state variable *) let pat_state_var = fresh_state_var () in let pat_state_lvalue = @@ -1273,10 +1231,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) mk_result_return_lvalue (mk_simpl_tuple_lvalue [ pat_state_lvalue; lv ]) in - let pat_state_rvalue = mk_typed_rvalue_from_var pat_state_var in - let pat_state_rvalue = - mk_value_expression pat_state_rvalue None - in + let pat_state_rvalue = mk_texpression_from_var pat_state_var in (* TODO: write a utility to create matches (and perform * type-checking, etc.) *) let success_branch = @@ -1297,12 +1252,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) assert (lv.ty = re_ty); let fail_pat = mk_result_fail_lvalue lv.ty in let fail_value = mk_result_fail_rvalue e.ty in - let fail_branch = - { - pat = fail_pat; - branch = mk_value_expression fail_value None; - } - in + let fail_branch = { pat = fail_pat; branch = fail_value } in let success_pat = mk_result_return_lvalue lv in let success_branch = { pat = success_pat; branch = e } in let switch_body = Match [ fail_branch; success_branch ] in @@ -1319,8 +1269,8 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) (* Then, add a "state" input variable if necessary: *) if config.use_state_monad then (* - in the body *) - let state_rvalue = mk_typed_rvalue_from_var state_var in - let body_e = mk_app body_e (mk_value_expression state_rvalue None) in + let state_rvalue = mk_texpression_from_var state_var in + let body_e = mk_app body_e state_rvalue in (* - in the signature *) let sg = def.signature in (* Input types *) diff --git a/src/PureUtils.ml b/src/PureUtils.ml index 21f39c0e..b6676db4 100644 --- a/src/PureUtils.ml +++ b/src/PureUtils.ml @@ -196,6 +196,17 @@ let as_var (e : texpression) : VarId.id = let rec unmeta (e : texpression) : texpression = match e.e with Meta (_, e) -> unmeta e | _ -> e +(** Remove *all* the meta information *) +let remove_meta (e : texpression) : texpression = + let obj = + object + inherit [_] map_expression as super + + method! visit_Meta env _ e = super#visit_expression env e.e + end + in + obj#visit_texpression () e + let mk_arrow (ty0 : ty) (ty1 : ty) : ty = Arrow (ty0, ty1) (** Construct a type as a list of arrows: ty1 -> ... tyn *) @@ -238,10 +249,20 @@ let opt_destruct_qualif_app (e : texpression) : let app, args = destruct_apps e in match app.e with Qualif qualif -> Some (qualif, args) | _ -> None -(** Destruct an expression into a function identifier and a list of arguments *) +(** Destruct an expression into a qualif identifier and a list of arguments *) let destruct_qualif_app (e : texpression) : qualif * texpression list = Option.get (opt_destruct_qualif_app e) +(** Destruct an expression into a function call, if possible *) +let opt_destruct_function_call (e : texpression) : + (fun_id * ty list * texpression list) option = + match opt_destruct_qualif_app e with + | None -> None + | Some (qualif, args) -> ( + match qualif.id with + | Func fun_id -> Some (fun_id, qualif.type_params, args) + | _ -> None) + let opt_destruct_result (ty : ty) : ty option = match ty with | Adt (Assumed Result, tys) -> Some (Collections.List.to_cons_nil tys) @@ -453,6 +474,9 @@ let opt_destruct_state_monad_result (ty : ty) : ty option = else None | _ -> None +let opt_unmeta_mplace (e : texpression) : mplace option * texpression = + match e.e with Meta (MPlace mp, e) -> (Some mp, e) | _ -> (None, e) + (** Utility function, used for type checking - TODO: move *) let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) (type_id : type_id) (variant_id : VariantId.id option) (tys : ty list) : |