From fd17736cbdb312578b2ea6de9a58febf83bd96c8 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Sun, 3 Sep 2023 20:41:42 +0200 Subject: Extract the trait decl methods --- compiler/Extract.ml | 66 +++++++++++++++++++++++++++++++++++-------------- compiler/ExtractBase.ml | 22 ++++++----------- 2 files changed, 54 insertions(+), 34 deletions(-) diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 2a678a27..138619c4 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -3213,6 +3213,11 @@ let extract_fun_input_parameters_types (ctx : extraction_ctx) in List.iter extract_param def.signature.inputs +let extract_fun_inputs_output_parameters_types (ctx : extraction_ctx) + (fmt : F.formatter) (def : fun_decl) : unit = + extract_fun_input_parameters_types ctx fmt def; + extract_ty ctx fmt TypeDeclId.Set.empty false def.signature.output + let assert_backend_supports_decreases_clauses () = match !backend with | FStar | Lean -> () @@ -3931,19 +3936,10 @@ let extract_trait_decl_method_register_names (ctx : extraction_ctx) (* We add one field per required forward/backward function *) let trans = A.FunDeclId.Map.find id ctx.trans_funs in - let register_fun ctx f = ctx_add_trait_method trait_decl name f ctx in - let register_funs ctx fl = List.fold_left register_fun ctx fl in - (* Register the names of the forward functions *) - let ctx = - if trans.keep_fwd then register_funs ctx (trans.fwd.f :: trans.fwd.loops) - else ctx - in - (* Register the names of the backward functions *) - List.fold_left - (fun ctx back -> - let ctx = register_fun ctx back.f in - register_funs ctx back.loops) - ctx trans.backs + let register_fun ctx f = ctx_add_trait_method trait_decl name f.f ctx in + (* Register the names *) + let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in + List.fold_left register_fun ctx funs (** Similar to {!extract_type_decl_register_names} *) let extract_trait_decl_register_names (ctx : extraction_ctx) @@ -4016,18 +4012,50 @@ let extract_trait_decl_item (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt ";"; F.pp_close_box fmt () +(** Small helper - TODO: move *) +let generic_params_drop_prefix (g1 : generic_params) (g2 : generic_params) : + generic_params = + let open Collections.List in + let types = drop (length g1.types) g2.types in + let const_generics = drop (length g1.const_generics) g2.const_generics in + let trait_clauses = drop (length g1.trait_clauses) g2.trait_clauses in + { types; const_generics; trait_clauses } + (** Small helper. Extract the items for a method in a trait decl. *) let extract_trait_decl_method_items (ctx : extraction_ctx) (fmt : F.formatter) - (decl : trait_decl) (name : string) (id : fun_decl_id) : unit = - let item_name = ctx_get_trait_const decl.def_id name ctx in + (decl : trait_decl) (item_name : string) (id : fun_decl_id) : unit = (* Lookup the definition *) - (* let def = - FunDeclId.Map.find ctx. - in *) - raise (Failure "TODO") + let trans = A.FunDeclId.Map.find id ctx.trans_funs in + (* Extract the items *) + let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in + let extract_method (f : fun_and_loops) = + let f = f.f in + let fun_name = ctx_get_trait_method decl.def_id item_name f.back_id ctx in + let ty () = + (* Extract the generics *) + (* We need to add the generics specific to the method, by removing those + which actually apply to the trait decl *) + let generics = + generic_params_drop_prefix decl.generics f.signature.generics + in + let ctx, type_params, cg_params, trait_clauses = + ctx_add_generic_params generics ctx + in + let use_forall = generics <> empty_generic_params in + let use_implicits = false in + extract_generic_params ctx fmt TypeDeclId.Set.empty use_forall + use_implicits None None generics type_params cg_params trait_clauses; + if use_forall then F.pp_print_string fmt ","; + (* Extract the inputs and output *) + F.pp_print_space fmt (); + extract_fun_inputs_output_parameters_types ctx fmt f + in + extract_trait_decl_item ctx fmt fun_name ty + in + List.iter extract_method funs (** Extract a trait declaration *) let extract_trait_decl (ctx : extraction_ctx) (fmt : F.formatter) diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index e4d1fb7b..435aa10c 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -417,8 +417,7 @@ type id = | TraitDeclId of TraitDeclId.id | TraitImplId of TraitImplId.id | LocalTraitClauseId of TraitClauseId.id - | TraitMethodId of - TraitDeclId.id * string * LoopId.id option * T.RegionGroupId.id option + | TraitMethodId of TraitDeclId.id * string * T.RegionGroupId.id option (** Something peculiar with trait methods: because we have to take into account forward/backward functions, we may need to generate fields items per method. @@ -820,23 +819,17 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = | TraitItemId (id, name) -> "trait_item_id: decl_id:" ^ TraitDeclId.to_string id ^ ", type name: " ^ name - | TraitMethodId (trait_decl_id, fun_name, lp_id, rg_id) -> + | TraitMethodId (trait_decl_id, fun_name, rg_id) -> let trait_name = Print.fun_name_to_string (A.TraitDeclId.Map.find trait_decl_id trait_decls).name in - let lp_kind = - match lp_id with - | None -> "" - | Some lp_id -> "loop " ^ LoopId.to_string lp_id ^ ", " - in - let fwd_back_kind = match rg_id with | None -> "forward" | Some rg_id -> "backward " ^ RegionGroupId.to_string rg_id in - "trait " ^ trait_name ^ " method name (" ^ lp_kind ^ fwd_back_kind ^ "): " + "trait " ^ trait_name ^ " method name (" ^ fwd_back_kind ^ "): " ^ fun_name | TraitSelfClauseId -> "trait_self_clause" @@ -960,8 +953,9 @@ let ctx_get_trait_type (id : trait_decl_id) (item_name : string) ctx_get_trait_item id item_name ctx let ctx_get_trait_method (id : trait_decl_id) (item_name : string) - (ctx : extraction_ctx) : string = - ctx_get_trait_item id item_name ctx + (rg_id : T.RegionGroupId.id option) (ctx : extraction_ctx) : string = + let with_opaque_pre = false in + ctx_get with_opaque_pre (TraitMethodId (id, item_name, rg_id)) ctx let ctx_get_trait_parent_clause (id : trait_decl_id) (clause : trait_clause_id) (ctx : extraction_ctx) : string = @@ -1299,9 +1293,7 @@ let ctx_add_trait_method (d : trait_decl) (item_name : string) (f : fun_decl) let trans = A.FunDeclId.Map.find f.def_id ctx.trans_funs in let name = ctx_compute_fun_name trans f ctx in let is_opaque = false in - ctx_add is_opaque - (TraitMethodId (d.def_id, item_name, f.loop_id, f.back_id)) - name ctx + ctx_add is_opaque (TraitMethodId (d.def_id, item_name, f.back_id)) name ctx let ctx_add_trait_parent_clause (d : trait_decl) (clause : trait_clause) (ctx : extraction_ctx) : extraction_ctx = -- cgit v1.2.3