From 6f22190cba92a44b6c74bfcce8f5ed142a68e195 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 31 Aug 2023 12:47:43 +0200 Subject: Start adding support for traits --- compiler/PureUtils.ml | 132 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 88 insertions(+), 44 deletions(-) (limited to 'compiler/PureUtils.ml') diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index f099ef9c..1357793b 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -89,14 +89,31 @@ let mk_mplace (var_id : E.VarId.id) (name : string option) (projection : mprojection) : mplace = { var_id; name; projection } +let empty_generic_params : generic_params = + { types = []; const_generics = []; trait_clauses = [] } + +let empty_generic_args : generic_args = + { types = []; const_generics = []; trait_refs = [] } + +let mk_generic_args_from_types (types : ty list) : generic_args = + { types; const_generics = []; trait_refs = [] } + +type subst = { + ty_subst : TypeVarId.id -> ty; + cg_subst : ConstGenericVarId.id -> const_generic; + tr_subst : TraitClauseId.id -> trait_instance_id; + tr_self : trait_instance_id; +} + (** Type substitution *) -let ty_substitute (tsubst : TypeVarId.id -> ty) - (cgsubst : ConstGenericVarId.id -> const_generic) (ty : ty) : ty = +let ty_substitute (subst : subst) (ty : ty) : ty = let obj = object inherit [_] map_ty - method! visit_TypeVar _ var_id = tsubst var_id - method! visit_ConstGenericVar _ var_id = cgsubst var_id + method! visit_TypeVar _ var_id = subst.ty_subst var_id + method! visit_ConstGenericVar _ var_id = subst.cg_subst var_id + method! visit_Clause _ id = subst.tr_subst id + method! visit_Self _ = subst.tr_self end in obj#visit_ty () ty @@ -115,6 +132,18 @@ let make_const_generic_subst (vars : const_generic_var list) (cgs : const_generic list) : ConstGenericVarId.id -> const_generic = Substitute.make_const_generic_subst_from_vars vars cgs +let make_trait_subst (clauses : trait_clause list) (refs : trait_ref list) : + TraitClauseId.id -> trait_instance_id = + let clauses = List.map (fun x -> x.clause_id) clauses in + let refs = List.map (fun x -> TraitRef x) refs in + let ls = List.combine clauses refs in + let mp = + List.fold_left + (fun mp (k, v) -> TraitClauseId.Map.add k v mp) + TraitClauseId.Map.empty ls + in + fun id -> TraitClauseId.Map.find id mp + (** Retrieve the list of fields for the given variant of a {!type:Aeneas.Pure.type_decl}. Raises [Invalid_argument] if the arguments are incorrect. @@ -135,20 +164,27 @@ let type_decl_get_fields (def : type_decl) - def: " ^ show_type_decl def ^ "\n- opt_variant_id: " ^ opt_variant_id)) +let make_subst_from_generics (params : generic_params) (args : generic_args) + (tr_self : trait_instance_id) : subst = + let ty_subst = make_type_subst params.types args.types in + let cg_subst = + make_const_generic_subst params.const_generics args.const_generics + in + let tr_subst = make_trait_subst params.trait_clauses args.trait_refs in + { ty_subst; cg_subst; tr_subst; tr_self } + (** Instantiate the type variables for the chosen variant in an ADT definition, and return the list of the types of its fields *) let type_decl_get_instantiated_fields_types (def : type_decl) - (opt_variant_id : VariantId.id option) (types : ty list) - (cgs : const_generic list) : ty list = - let ty_subst = make_type_subst def.type_params types in - let cg_subst = make_const_generic_subst def.const_generic_params cgs in + (opt_variant_id : VariantId.id option) (generics : generic_args) : ty list = + (* There shouldn't be any reference to Self *) + let tr_self = UnknownTrait __FUNCTION__ in + let subst = make_subst_from_generics def.generics generics tr_self in let fields = type_decl_get_fields def opt_variant_id in - List.map (fun f -> ty_substitute ty_subst cg_subst f.field_ty) fields + List.map (fun f -> ty_substitute subst f.field_ty) fields -let fun_sig_substitute (tsubst : TypeVarId.id -> ty) - (cgsubst : ConstGenericVarId.id -> const_generic) (sg : fun_sig) : - inst_fun_sig = - let subst = ty_substitute tsubst cgsubst in +let fun_sig_substitute (subst : subst) (sg : fun_sig) : inst_fun_sig = + let subst = ty_substitute subst in let inputs = List.map subst sg.inputs in let output = subst sg.output in let doutputs = List.map subst sg.doutputs in @@ -194,9 +230,9 @@ let is_global (e : texpression) : bool = let is_const (e : texpression) : bool = match e.e with Const _ -> true | _ -> false -let ty_as_adt (ty : ty) : type_id * ty list * const_generic list = +let ty_as_adt (ty : ty) : type_id * generic_args = match ty with - | Adt (id, tys, cgs) -> (id, tys, cgs) + | Adt (id, generics) -> (id, generics) | _ -> raise (Failure "Unreachable") (** Remove the external occurrences of {!Meta} *) @@ -294,28 +330,30 @@ let destruct_qualif_app (e : texpression) : qualif * texpression list = (** Destruct an expression into a function call, if possible *) let opt_destruct_function_call (e : texpression) : - (fun_or_op_id * ty list * texpression list) option = + (fun_or_op_id * generic_args * texpression list) option = match opt_destruct_qualif_app e with | None -> None | Some (qualif, args) -> ( match qualif.id with - | FunOrOp fun_id -> Some (fun_id, qualif.type_args, args) + | FunOrOp fun_id -> Some (fun_id, qualif.generics, args) | _ -> None) let opt_destruct_result (ty : ty) : ty option = match ty with - | Adt (Assumed Result, tys, cgs) -> - assert (cgs = []); - Some (Collections.List.to_cons_nil tys) + | Adt (Assumed Result, generics) -> + assert (generics.const_generics = []); + assert (generics.trait_refs = []); + Some (Collections.List.to_cons_nil generics.types) | _ -> None let destruct_result (ty : ty) : ty = Option.get (opt_destruct_result ty) let opt_destruct_tuple (ty : ty) : ty list option = match ty with - | Adt (Tuple, tys, cgs) -> - assert (cgs = []); - Some tys + | Adt (Tuple, generics) -> + assert (generics.const_generics = []); + assert (generics.trait_refs = []); + Some generics.types | _ -> None let mk_abs (x : typed_pattern) (e : texpression) : texpression = @@ -387,14 +425,16 @@ let mk_switch (scrut : texpression) (sb : switch_body) : texpression = - if there is > one type: wrap them in a tuple *) let mk_simpl_tuple_ty (tys : ty list) : ty = - match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys, []) + match tys with + | [ ty ] -> ty + | _ -> Adt (Tuple, mk_generic_args_from_types tys) let mk_bool_ty : ty = Literal Bool -let mk_unit_ty : ty = Adt (Tuple, [], []) +let mk_unit_ty : ty = Adt (Tuple, empty_generic_args) let mk_unit_rvalue : texpression = let id = AdtCons { adt_id = Tuple; variant_id = None } in - let qualif = { id; type_args = []; const_generic_args = [] } in + let qualif = { id; generics = empty_generic_args } in let e = Qualif qualif in let ty = mk_unit_ty in { e; ty } @@ -434,7 +474,7 @@ let mk_simpl_tuple_pattern (vl : typed_pattern list) : typed_pattern = | [ v ] -> v | _ -> let tys = List.map (fun (v : typed_pattern) -> v.ty) vl in - let ty = Adt (Tuple, tys, []) in + let ty = Adt (Tuple, mk_generic_args_from_types tys) in let value = PatAdt { variant_id = None; field_values = vl } in { value; ty } @@ -445,11 +485,11 @@ let mk_simpl_tuple_texpression (vl : texpression list) : texpression = | _ -> (* Compute the types of the fields, and the type of the tuple constructor *) let tys = List.map (fun (v : texpression) -> v.ty) vl in - let ty = Adt (Tuple, tys, []) in + let ty = Adt (Tuple, mk_generic_args_from_types tys) in let ty = mk_arrows tys ty in (* Construct the tuple constructor qualifier *) let id = AdtCons { adt_id = Tuple; variant_id = None } in - let qualif = { id; type_args = tys; const_generic_args = [] } in + let qualif = { id; generics = mk_generic_args_from_types tys } in (* Put everything together *) let cons = { e = Qualif qualif; ty } in mk_apps cons vl @@ -467,32 +507,36 @@ let ty_as_integer (t : ty) : T.integer_type = let ty_as_literal (t : ty) : T.literal_type = match t with Literal ty -> ty | _ -> raise (Failure "Unreachable") -let mk_state_ty : ty = Adt (Assumed State, [], []) -let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ], []) -let mk_error_ty : ty = Adt (Assumed Error, [], []) -let mk_fuel_ty : ty = Adt (Assumed Fuel, [], []) +let mk_state_ty : ty = Adt (Assumed State, empty_generic_args) + +let mk_result_ty (ty : ty) : ty = + Adt (Assumed Result, mk_generic_args_from_types [ ty ]) + +let mk_error_ty : ty = Adt (Assumed Error, empty_generic_args) +let mk_fuel_ty : ty = Adt (Assumed Fuel, empty_generic_args) let mk_error (error : VariantId.id) : texpression = let ty = mk_error_ty in let id = AdtCons { adt_id = Assumed Error; variant_id = Some error } in - let qualif = { id; type_args = []; const_generic_args = [] } in + let qualif = { id; generics = empty_generic_args } in let e = Qualif qualif in { e; ty } let unwrap_result_ty (ty : ty) : ty = match ty with - | Adt (Assumed Result, [ ty ], cgs) -> - assert (cgs = []); + | Adt + (Assumed Result, { types = [ ty ]; const_generics = []; trait_refs = [] }) + -> ty | _ -> raise (Failure "not a result type") let mk_result_fail_texpression (error : texpression) (ty : ty) : texpression = let type_args = [ ty ] in - let ty = Adt (Assumed Result, type_args, []) in + let ty = Adt (Assumed Result, mk_generic_args_from_types type_args) in let id = AdtCons { adt_id = Assumed Result; variant_id = Some result_fail_id } in - let qualif = { id; type_args; const_generic_args = [] } in + let qualif = { id; generics = mk_generic_args_from_types type_args } in let cons_e = Qualif qualif in let cons_ty = mk_arrow error.ty ty in let cons = { e = cons_e; ty = cons_ty } in @@ -505,11 +549,11 @@ let mk_result_fail_texpression_with_error_id (error : VariantId.id) (ty : ty) : let mk_result_return_texpression (v : texpression) : texpression = let type_args = [ v.ty ] in - let ty = Adt (Assumed Result, type_args, []) in + let ty = Adt (Assumed Result, mk_generic_args_from_types type_args) in let id = AdtCons { adt_id = Assumed Result; variant_id = Some result_return_id } in - let qualif = { id; type_args; const_generic_args = [] } in + let qualif = { id; generics = mk_generic_args_from_types type_args } in let cons_e = Qualif qualif in let cons_ty = mk_arrow v.ty ty in let cons = { e = cons_e; ty = cons_ty } in @@ -518,7 +562,7 @@ let mk_result_return_texpression (v : texpression) : texpression = (** Create a [Fail err] pattern which captures the error *) let mk_result_fail_pattern (error_pat : pattern) (ty : ty) : typed_pattern = let error_pat : typed_pattern = { value = error_pat; ty = mk_error_ty } in - let ty = Adt (Assumed Result, [ ty ], []) in + let ty = Adt (Assumed Result, mk_generic_args_from_types [ ty ]) in let value = PatAdt { variant_id = Some result_fail_id; field_values = [ error_pat ] } in @@ -530,7 +574,7 @@ let mk_result_fail_pattern_ignore_error (ty : ty) : typed_pattern = mk_result_fail_pattern error_pat ty let mk_result_return_pattern (v : typed_pattern) : typed_pattern = - let ty = Adt (Assumed Result, [ v.ty ], []) in + let ty = Adt (Assumed Result, mk_generic_args_from_types [ v.ty ]) in let value = PatAdt { variant_id = Some result_return_id; field_values = [ v ] } in @@ -565,11 +609,11 @@ let rec typed_pattern_to_texpression (pat : typed_pattern) : texpression option let fields_values = List.map (fun e -> Option.get e) fields in (* Retrieve the type id and the type args from the pat type (simpler this way *) - let adt_id, type_args, const_generic_args = ty_as_adt pat.ty in + let adt_id, generics = ty_as_adt pat.ty in (* Create the constructor *) let qualif_id = AdtCons { adt_id; variant_id = av.variant_id } in - let qualif = { id = qualif_id; type_args; const_generic_args } in + let qualif = { id = qualif_id; generics } in let cons_e = Qualif qualif in let field_tys = List.map (fun (v : texpression) -> v.ty) fields_values -- cgit v1.2.3