diff options
author | Son Ho | 2023-12-05 17:34:13 +0100 |
---|---|---|
committer | Son Ho | 2023-12-05 17:34:13 +0100 |
commit | 726db4911add81a853aafcec3936b457aaeff5b4 (patch) | |
tree | 2663915767c3558203990ed14f8d5604b7fd21d1 /compiler/PureUtils.ml | |
parent | 92887b89e35607e99bae2f19e4c5b2f162683d02 (diff) | |
parent | 4795e5f823bc89504855d8eb946b111d9314f4d5 (diff) |
Merge branch 'main' into son_fixes2
Diffstat (limited to '')
-rw-r--r-- | compiler/PureUtils.ml | 250 |
1 files changed, 180 insertions, 70 deletions
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 1c8d8921..a5143f3c 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -15,11 +15,11 @@ end module RegularFunIdMap = Collections.MakeMap (RegularFunIdOrderedType) (** We use this type as a key for lookups *) -type regular_fun_id_not_loop = A.fun_id * T.RegionGroupId.id option +type regular_fun_id_not_loop = LlbcAst.fun_id * RegionGroupId.id option [@@deriving show, ord] (** We use this type as a key for lookups *) -type fun_loop_id = A.FunDeclId.id * LoopId.id option [@@deriving show, ord] +type fun_loop_id = FunDeclId.id * LoopId.id option [@@deriving show, ord] module RegularFunIdNotLoopOrderedType = struct type t = regular_fun_id_not_loop @@ -59,19 +59,19 @@ module FunLoopIdSet = Collections.MakeSet (FunLoopIdOrderedType) let dest_arrow_ty (ty : ty) : ty * ty = match ty with - | Arrow (arg_ty, ret_ty) -> (arg_ty, ret_ty) + | TArrow (arg_ty, ret_ty) -> (arg_ty, ret_ty) | _ -> raise (Failure "Unreachable") let compute_literal_type (cv : literal) : literal_type = match cv with - | PV.Scalar sv -> Integer sv.PV.int_ty - | Bool _ -> Bool - | Char _ -> Char + | VScalar sv -> TInteger sv.int_ty + | VBool _ -> TBool + | VChar _ -> TChar let var_get_id (v : var) : VarId.id = v.id let mk_typed_pattern_from_literal (cv : literal) : typed_pattern = - let ty = Literal (compute_literal_type cv) in + let ty = TLiteral (compute_literal_type cv) in { value = PatConstant cv; ty } let mk_let (monadic : bool) (lv : typed_pattern) (re : texpression) @@ -89,14 +89,31 @@ let mk_mplace (var_id : E.VarId.id) (name : string option) (projection : mprojection) : mplace = { var_id; name; projection } +let empty_generic_params : generic_params = + { types = []; const_generics = []; trait_clauses = [] } + +let empty_generic_args : generic_args = + { types = []; const_generics = []; trait_refs = [] } + +let mk_generic_args_from_types (types : ty list) : generic_args = + { types; const_generics = []; trait_refs = [] } + +type subst = { + ty_subst : TypeVarId.id -> ty; + cg_subst : ConstGenericVarId.id -> const_generic; + tr_subst : TraitClauseId.id -> trait_instance_id; + tr_self : trait_instance_id; +} + (** Type substitution *) -let ty_substitute (tsubst : TypeVarId.id -> ty) - (cgsubst : ConstGenericVarId.id -> const_generic) (ty : ty) : ty = +let ty_substitute (subst : subst) (ty : ty) : ty = let obj = object inherit [_] map_ty - method! visit_TypeVar _ var_id = tsubst var_id - method! visit_ConstGenericVar _ var_id = cgsubst var_id + method! visit_TVar _ var_id = subst.ty_subst var_id + method! visit_CgVar _ var_id = subst.cg_subst var_id + method! visit_Clause _ id = subst.tr_subst id + method! visit_Self _ = subst.tr_self end in obj#visit_ty () ty @@ -115,6 +132,18 @@ let make_const_generic_subst (vars : const_generic_var list) (cgs : const_generic list) : ConstGenericVarId.id -> const_generic = Substitute.make_const_generic_subst_from_vars vars cgs +let make_trait_subst (clauses : trait_clause list) (refs : trait_ref list) : + TraitClauseId.id -> trait_instance_id = + let clauses = List.map (fun x -> x.clause_id) clauses in + let refs = List.map (fun x -> TraitRef x) refs in + let ls = List.combine clauses refs in + let mp = + List.fold_left + (fun mp (k, v) -> TraitClauseId.Map.add k v mp) + TraitClauseId.Map.empty ls + in + fun id -> TraitClauseId.Map.find id mp + (** Retrieve the list of fields for the given variant of a {!type:Aeneas.Pure.type_decl}. Raises [Invalid_argument] if the arguments are incorrect. @@ -135,20 +164,27 @@ let type_decl_get_fields (def : type_decl) - def: " ^ show_type_decl def ^ "\n- opt_variant_id: " ^ opt_variant_id)) +let make_subst_from_generics (params : generic_params) (args : generic_args) + (tr_self : trait_instance_id) : subst = + let ty_subst = make_type_subst params.types args.types in + let cg_subst = + make_const_generic_subst params.const_generics args.const_generics + in + let tr_subst = make_trait_subst params.trait_clauses args.trait_refs in + { ty_subst; cg_subst; tr_subst; tr_self } + (** Instantiate the type variables for the chosen variant in an ADT definition, and return the list of the types of its fields *) let type_decl_get_instantiated_fields_types (def : type_decl) - (opt_variant_id : VariantId.id option) (types : ty list) - (cgs : const_generic list) : ty list = - let ty_subst = make_type_subst def.type_params types in - let cg_subst = make_const_generic_subst def.const_generic_params cgs in + (opt_variant_id : VariantId.id option) (generics : generic_args) : ty list = + (* There shouldn't be any reference to Self *) + let tr_self = UnknownTrait __FUNCTION__ in + let subst = make_subst_from_generics def.generics generics tr_self in let fields = type_decl_get_fields def opt_variant_id in - List.map (fun f -> ty_substitute ty_subst cg_subst f.field_ty) fields + List.map (fun f -> ty_substitute subst f.field_ty) fields -let fun_sig_substitute (tsubst : TypeVarId.id -> ty) - (cgsubst : ConstGenericVarId.id -> const_generic) (sg : fun_sig) : - inst_fun_sig = - let subst = ty_substitute tsubst cgsubst in +let fun_sig_substitute (subst : subst) (sg : fun_sig) : inst_fun_sig = + let subst = ty_substitute subst in let inputs = List.map subst sg.inputs in let output = subst sg.output in let doutputs = List.map subst sg.doutputs in @@ -159,12 +195,13 @@ let fun_sig_substitute (tsubst : TypeVarId.id -> ty) We only look for outer monadic let-bindings. This is used when printing the branches of [if ... then ... else ...]. - Rem.: this function will *fail* if there are {!constructor:Aeneas.Pure.expression.Loop} + Rem.: this function will *fail* if there are {!Pure.Loop} nodes (you should call it on an expression where those nodes have been eliminated). *) let rec let_group_requires_parentheses (e : texpression) : bool = match e.e with - | Var _ | Const _ | App _ | Abs _ | Qualif _ | StructUpdate _ -> false + | Var _ | CVar _ | Const _ | App _ | Abs _ | Qualif _ | StructUpdate _ -> + false | Let (monadic, _, _, next_e) -> if monadic then true else let_group_requires_parentheses next_e | Switch (_, _) -> false @@ -184,15 +221,18 @@ let is_var (e : texpression) : bool = let as_var (e : texpression) : VarId.id = match e.e with Var v -> v | _ -> raise (Failure "Unreachable") +let is_cvar (e : texpression) : bool = + match e.e with CVar _ -> true | _ -> false + let is_global (e : texpression) : bool = match e.e with Qualif { id = Global _; _ } -> true | _ -> false let is_const (e : texpression) : bool = match e.e with Const _ -> true | _ -> false -let ty_as_adt (ty : ty) : type_id * ty list * const_generic list = +let ty_as_adt (ty : ty) : type_id * generic_args = match ty with - | Adt (id, tys, cgs) -> (id, tys, cgs) + | TAdt (id, generics) -> (id, generics) | _ -> raise (Failure "Unreachable") (** Remove the external occurrences of {!Meta} *) @@ -209,12 +249,12 @@ let remove_meta (e : texpression) : texpression = in obj#visit_texpression () e -let mk_arrow (ty0 : ty) (ty1 : ty) : ty = Arrow (ty0, ty1) +let mk_arrow (ty0 : ty) (ty1 : ty) : ty = TArrow (ty0, ty1) (** Construct a type as a list of arrows: ty1 -> ... tyn *) let mk_arrows (inputs : ty list) (output : ty) = let rec aux (tys : ty list) : ty = - match tys with [] -> output | ty :: tys' -> Arrow (ty, aux tys') + match tys with [] -> output | ty :: tys' -> TArrow (ty, aux tys') in aux inputs @@ -265,7 +305,7 @@ let destruct_apps (e : texpression) : texpression * texpression list = (** Make an [App (app, arg)] expression *) let mk_app (app : texpression) (arg : texpression) : texpression = match app.ty with - | Arrow (ty0, ty1) -> + | TArrow (ty0, ty1) -> (* Sanity check *) assert (ty0 = arg.ty); let e = App (app, arg) in @@ -290,32 +330,34 @@ let destruct_qualif_app (e : texpression) : qualif * texpression list = (** Destruct an expression into a function call, if possible *) let opt_destruct_function_call (e : texpression) : - (fun_or_op_id * ty list * texpression list) option = + (fun_or_op_id * generic_args * texpression list) option = match opt_destruct_qualif_app e with | None -> None | Some (qualif, args) -> ( match qualif.id with - | FunOrOp fun_id -> Some (fun_id, qualif.type_args, args) + | FunOrOp fun_id -> Some (fun_id, qualif.generics, args) | _ -> None) let opt_destruct_result (ty : ty) : ty option = match ty with - | Adt (Assumed Result, tys, cgs) -> - assert (cgs = []); - Some (Collections.List.to_cons_nil tys) + | TAdt (TAssumed TResult, generics) -> + assert (generics.const_generics = []); + assert (generics.trait_refs = []); + Some (Collections.List.to_cons_nil generics.types) | _ -> None let destruct_result (ty : ty) : ty = Option.get (opt_destruct_result ty) let opt_destruct_tuple (ty : ty) : ty list option = match ty with - | Adt (Tuple, tys, cgs) -> - assert (cgs = []); - Some tys + | TAdt (TTuple, generics) -> + assert (generics.const_generics = []); + assert (generics.trait_refs = []); + Some generics.types | _ -> None let mk_abs (x : typed_pattern) (e : texpression) : texpression = - let ty = Arrow (x.ty, e.ty) in + let ty = TArrow (x.ty, e.ty) in let e = Abs (x, e) in { e; ty } @@ -328,12 +370,12 @@ let rec destruct_abs_list (e : texpression) : typed_pattern list * texpression = let destruct_arrow (ty : ty) : ty * ty = match ty with - | Arrow (ty0, ty1) -> (ty0, ty1) + | TArrow (ty0, ty1) -> (ty0, ty1) | _ -> raise (Failure "Not an arrow type") let rec destruct_arrows (ty : ty) : ty list * ty = match ty with - | Arrow (ty0, ty1) -> + | TArrow (ty0, ty1) -> let tys, out_ty = destruct_arrows ty1 in (ty0 :: tys, out_ty) | _ -> ([], ty) @@ -366,7 +408,7 @@ let iter_switch_body_branches (f : texpression -> unit) (sb : switch_body) : let mk_switch (scrut : texpression) (sb : switch_body) : texpression = (* Sanity check: the scrutinee has the proper type *) (match sb with - | If (_, _) -> assert (scrut.ty = Literal Bool) + | If (_, _) -> assert (scrut.ty = TLiteral TBool) | Match branches -> List.iter (fun (b : match_branch) -> assert (b.pat.ty = scrut.ty)) @@ -383,14 +425,16 @@ let mk_switch (scrut : texpression) (sb : switch_body) : texpression = - if there is > one type: wrap them in a tuple *) let mk_simpl_tuple_ty (tys : ty list) : ty = - match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys, []) + match tys with + | [ ty ] -> ty + | _ -> TAdt (TTuple, mk_generic_args_from_types tys) -let mk_bool_ty : ty = Literal Bool -let mk_unit_ty : ty = Adt (Tuple, [], []) +let mk_bool_ty : ty = TLiteral TBool +let mk_unit_ty : ty = TAdt (TTuple, empty_generic_args) let mk_unit_rvalue : texpression = - let id = AdtCons { adt_id = Tuple; variant_id = None } in - let qualif = { id; type_args = []; const_generic_args = [] } in + let id = AdtCons { adt_id = TTuple; variant_id = None } in + let qualif = { id; generics = empty_generic_args } in let e = Qualif qualif in let ty = mk_unit_ty in { e; ty } @@ -409,13 +453,13 @@ let mk_dummy_pattern (ty : ty) : typed_pattern = let value = PatDummy in { value; ty } -let mk_meta (m : meta) (e : texpression) : texpression = +let mk_emeta (m : emeta) (e : texpression) : texpression = let ty = e.ty in let e = Meta (m, e) in { e; ty } let mk_mplace_texpression (mp : mplace) (e : texpression) : texpression = - mk_meta (MPlace mp) e + mk_emeta (MPlace mp) e let mk_opt_mplace_texpression (mp : mplace option) (e : texpression) : texpression = @@ -430,7 +474,7 @@ let mk_simpl_tuple_pattern (vl : typed_pattern list) : typed_pattern = | [ v ] -> v | _ -> let tys = List.map (fun (v : typed_pattern) -> v.ty) vl in - let ty = Adt (Tuple, tys, []) in + let ty = TAdt (TTuple, mk_generic_args_from_types tys) in let value = PatAdt { variant_id = None; field_values = vl } in { value; ty } @@ -441,11 +485,11 @@ let mk_simpl_tuple_texpression (vl : texpression list) : texpression = | _ -> (* Compute the types of the fields, and the type of the tuple constructor *) let tys = List.map (fun (v : texpression) -> v.ty) vl in - let ty = Adt (Tuple, tys, []) in + let ty = TAdt (TTuple, mk_generic_args_from_types tys) in let ty = mk_arrows tys ty in (* Construct the tuple constructor qualifier *) - let id = AdtCons { adt_id = Tuple; variant_id = None } in - let qualif = { id; type_args = tys; const_generic_args = [] } in + let id = AdtCons { adt_id = TTuple; variant_id = None } in + let qualif = { id; generics = mk_generic_args_from_types tys } in (* Put everything together *) let cons = { e = Qualif qualif; ty } in mk_apps cons vl @@ -457,38 +501,42 @@ let mk_adt_pattern (adt_ty : ty) (variant_id : VariantId.id option) let ty_as_integer (t : ty) : T.integer_type = match t with - | Literal (Integer int_ty) -> int_ty + | TLiteral (TInteger int_ty) -> int_ty | _ -> raise (Failure "Unreachable") let ty_as_literal (t : ty) : T.literal_type = - match t with Literal ty -> ty | _ -> raise (Failure "Unreachable") + match t with TLiteral ty -> ty | _ -> raise (Failure "Unreachable") + +let mk_state_ty : ty = TAdt (TAssumed TState, empty_generic_args) + +let mk_result_ty (ty : ty) : ty = + TAdt (TAssumed TResult, mk_generic_args_from_types [ ty ]) -let mk_state_ty : ty = Adt (Assumed State, [], []) -let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ], []) -let mk_error_ty : ty = Adt (Assumed Error, [], []) -let mk_fuel_ty : ty = Adt (Assumed Fuel, [], []) +let mk_error_ty : ty = TAdt (TAssumed TError, empty_generic_args) +let mk_fuel_ty : ty = TAdt (TAssumed TFuel, empty_generic_args) let mk_error (error : VariantId.id) : texpression = let ty = mk_error_ty in - let id = AdtCons { adt_id = Assumed Error; variant_id = Some error } in - let qualif = { id; type_args = []; const_generic_args = [] } in + let id = AdtCons { adt_id = TAssumed TError; variant_id = Some error } in + let qualif = { id; generics = empty_generic_args } in let e = Qualif qualif in { e; ty } let unwrap_result_ty (ty : ty) : ty = match ty with - | Adt (Assumed Result, [ ty ], cgs) -> - assert (cgs = []); + | TAdt + ( TAssumed TResult, + { types = [ ty ]; const_generics = []; trait_refs = [] } ) -> ty | _ -> raise (Failure "not a result type") let mk_result_fail_texpression (error : texpression) (ty : ty) : texpression = let type_args = [ ty ] in - let ty = Adt (Assumed Result, type_args, []) in + let ty = TAdt (TAssumed TResult, mk_generic_args_from_types type_args) in let id = - AdtCons { adt_id = Assumed Result; variant_id = Some result_fail_id } + AdtCons { adt_id = TAssumed TResult; variant_id = Some result_fail_id } in - let qualif = { id; type_args; const_generic_args = [] } in + let qualif = { id; generics = mk_generic_args_from_types type_args } in let cons_e = Qualif qualif in let cons_ty = mk_arrow error.ty ty in let cons = { e = cons_e; ty = cons_ty } in @@ -501,11 +549,11 @@ let mk_result_fail_texpression_with_error_id (error : VariantId.id) (ty : ty) : let mk_result_return_texpression (v : texpression) : texpression = let type_args = [ v.ty ] in - let ty = Adt (Assumed Result, type_args, []) in + let ty = TAdt (TAssumed TResult, mk_generic_args_from_types type_args) in let id = - AdtCons { adt_id = Assumed Result; variant_id = Some result_return_id } + AdtCons { adt_id = TAssumed TResult; variant_id = Some result_return_id } in - let qualif = { id; type_args; const_generic_args = [] } in + let qualif = { id; generics = mk_generic_args_from_types type_args } in let cons_e = Qualif qualif in let cons_ty = mk_arrow v.ty ty in let cons = { e = cons_e; ty = cons_ty } in @@ -514,7 +562,7 @@ let mk_result_return_texpression (v : texpression) : texpression = (** Create a [Fail err] pattern which captures the error *) let mk_result_fail_pattern (error_pat : pattern) (ty : ty) : typed_pattern = let error_pat : typed_pattern = { value = error_pat; ty = mk_error_ty } in - let ty = Adt (Assumed Result, [ ty ], []) in + let ty = TAdt (TAssumed TResult, mk_generic_args_from_types [ ty ]) in let value = PatAdt { variant_id = Some result_fail_id; field_values = [ error_pat ] } in @@ -526,7 +574,7 @@ let mk_result_fail_pattern_ignore_error (ty : ty) : typed_pattern = mk_result_fail_pattern error_pat ty let mk_result_return_pattern (v : typed_pattern) : typed_pattern = - let ty = Adt (Assumed Result, [ v.ty ], []) in + let ty = TAdt (TAssumed TResult, mk_generic_args_from_types [ v.ty ]) in let value = PatAdt { variant_id = Some result_return_id; field_values = [ v ] } in @@ -561,11 +609,11 @@ let rec typed_pattern_to_texpression (pat : typed_pattern) : texpression option let fields_values = List.map (fun e -> Option.get e) fields in (* Retrieve the type id and the type args from the pat type (simpler this way *) - let adt_id, type_args, const_generic_args = ty_as_adt pat.ty in + let adt_id, generics = ty_as_adt pat.ty in (* Create the constructor *) let qualif_id = AdtCons { adt_id; variant_id = av.variant_id } in - let qualif = { id = qualif_id; type_args; const_generic_args } in + let qualif = { id = qualif_id; generics } in let cons_e = Qualif qualif in let field_tys = List.map (fun (v : texpression) -> v.ty) fields_values @@ -577,3 +625,65 @@ let rec typed_pattern_to_texpression (pat : typed_pattern) : texpression option Some (mk_apps cons fields_values).e in match e_opt with None -> None | Some e -> Some { e; ty = pat.ty } + +type trait_decl_method_decl_id = { is_provided : bool; id : fun_decl_id } + +let trait_decl_get_method (trait_decl : trait_decl) (method_name : string) : + trait_decl_method_decl_id = + (* First look in the required methods *) + let method_id = + List.find_opt (fun (s, _) -> s = method_name) trait_decl.required_methods + in + match method_id with + | Some (_, id) -> { is_provided = false; id } + | None -> + (* Must be a provided method *) + let _, id = + List.find (fun (s, _) -> s = method_name) trait_decl.provided_methods + in + { is_provided = true; id = Option.get id } + +let trait_decl_is_empty (trait_decl : trait_decl) : bool = + let { + def_id = _; + is_local = _; + name = _; + llbc_name = _; + meta = _; + generics = _; + llbc_generics = _; + preds = _; + parent_clauses; + llbc_parent_clauses = _; + consts; + types; + required_methods; + provided_methods; + } = + trait_decl + in + parent_clauses = [] && consts = [] && types = [] && required_methods = [] + && provided_methods = [] + +let trait_impl_is_empty (trait_impl : trait_impl) : bool = + let { + def_id = _; + is_local = _; + name = _; + llbc_name = _; + meta = _; + impl_trait = _; + llbc_impl_trait = _; + generics = _; + llbc_generics = _; + preds = _; + parent_trait_refs; + consts; + types; + required_methods; + provided_methods; + } = + trait_impl + in + parent_trait_refs = [] && consts = [] && types = [] && required_methods = [] + && provided_methods = [] |