summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorSon Ho2023-09-03 20:41:42 +0200
committerSon Ho2023-09-03 20:41:42 +0200
commitfd17736cbdb312578b2ea6de9a58febf83bd96c8 (patch)
tree474e611743797ed50a510e1e6c3b36d189ffd4d7 /compiler
parentfcd1fbe048b55a89bd8ed34afa8ed2295798d3ec (diff)
Extract the trait decl methods
Diffstat (limited to 'compiler')
-rw-r--r--compiler/Extract.ml66
-rw-r--r--compiler/ExtractBase.ml22
2 files changed, 54 insertions, 34 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 2a678a27..138619c4 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -3213,6 +3213,11 @@ let extract_fun_input_parameters_types (ctx : extraction_ctx)
in
List.iter extract_param def.signature.inputs
+let extract_fun_inputs_output_parameters_types (ctx : extraction_ctx)
+ (fmt : F.formatter) (def : fun_decl) : unit =
+ extract_fun_input_parameters_types ctx fmt def;
+ extract_ty ctx fmt TypeDeclId.Set.empty false def.signature.output
+
let assert_backend_supports_decreases_clauses () =
match !backend with
| FStar | Lean -> ()
@@ -3931,19 +3936,10 @@ let extract_trait_decl_method_register_names (ctx : extraction_ctx)
(* We add one field per required forward/backward function *)
let trans = A.FunDeclId.Map.find id ctx.trans_funs in
- let register_fun ctx f = ctx_add_trait_method trait_decl name f ctx in
- let register_funs ctx fl = List.fold_left register_fun ctx fl in
- (* Register the names of the forward functions *)
- let ctx =
- if trans.keep_fwd then register_funs ctx (trans.fwd.f :: trans.fwd.loops)
- else ctx
- in
- (* Register the names of the backward functions *)
- List.fold_left
- (fun ctx back ->
- let ctx = register_fun ctx back.f in
- register_funs ctx back.loops)
- ctx trans.backs
+ let register_fun ctx f = ctx_add_trait_method trait_decl name f.f ctx in
+ (* Register the names *)
+ let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in
+ List.fold_left register_fun ctx funs
(** Similar to {!extract_type_decl_register_names} *)
let extract_trait_decl_register_names (ctx : extraction_ctx)
@@ -4016,18 +4012,50 @@ let extract_trait_decl_item (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt ";";
F.pp_close_box fmt ()
+(** Small helper - TODO: move *)
+let generic_params_drop_prefix (g1 : generic_params) (g2 : generic_params) :
+ generic_params =
+ let open Collections.List in
+ let types = drop (length g1.types) g2.types in
+ let const_generics = drop (length g1.const_generics) g2.const_generics in
+ let trait_clauses = drop (length g1.trait_clauses) g2.trait_clauses in
+ { types; const_generics; trait_clauses }
+
(** Small helper.
Extract the items for a method in a trait decl.
*)
let extract_trait_decl_method_items (ctx : extraction_ctx) (fmt : F.formatter)
- (decl : trait_decl) (name : string) (id : fun_decl_id) : unit =
- let item_name = ctx_get_trait_const decl.def_id name ctx in
+ (decl : trait_decl) (item_name : string) (id : fun_decl_id) : unit =
(* Lookup the definition *)
- (* let def =
- FunDeclId.Map.find ctx.
- in *)
- raise (Failure "TODO")
+ let trans = A.FunDeclId.Map.find id ctx.trans_funs in
+ (* Extract the items *)
+ let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in
+ let extract_method (f : fun_and_loops) =
+ let f = f.f in
+ let fun_name = ctx_get_trait_method decl.def_id item_name f.back_id ctx in
+ let ty () =
+ (* Extract the generics *)
+ (* We need to add the generics specific to the method, by removing those
+ which actually apply to the trait decl *)
+ let generics =
+ generic_params_drop_prefix decl.generics f.signature.generics
+ in
+ let ctx, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params generics ctx
+ in
+ let use_forall = generics <> empty_generic_params in
+ let use_implicits = false in
+ extract_generic_params ctx fmt TypeDeclId.Set.empty use_forall
+ use_implicits None None generics type_params cg_params trait_clauses;
+ if use_forall then F.pp_print_string fmt ",";
+ (* Extract the inputs and output *)
+ F.pp_print_space fmt ();
+ extract_fun_inputs_output_parameters_types ctx fmt f
+ in
+ extract_trait_decl_item ctx fmt fun_name ty
+ in
+ List.iter extract_method funs
(** Extract a trait declaration *)
let extract_trait_decl (ctx : extraction_ctx) (fmt : F.formatter)
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index e4d1fb7b..435aa10c 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -417,8 +417,7 @@ type id =
| TraitDeclId of TraitDeclId.id
| TraitImplId of TraitImplId.id
| LocalTraitClauseId of TraitClauseId.id
- | TraitMethodId of
- TraitDeclId.id * string * LoopId.id option * T.RegionGroupId.id option
+ | TraitMethodId of TraitDeclId.id * string * T.RegionGroupId.id option
(** Something peculiar with trait methods: because we have to take into
account forward/backward functions, we may need to generate fields
items per method.
@@ -820,23 +819,17 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
| TraitItemId (id, name) ->
"trait_item_id: decl_id:" ^ TraitDeclId.to_string id ^ ", type name: "
^ name
- | TraitMethodId (trait_decl_id, fun_name, lp_id, rg_id) ->
+ | TraitMethodId (trait_decl_id, fun_name, rg_id) ->
let trait_name =
Print.fun_name_to_string
(A.TraitDeclId.Map.find trait_decl_id trait_decls).name
in
- let lp_kind =
- match lp_id with
- | None -> ""
- | Some lp_id -> "loop " ^ LoopId.to_string lp_id ^ ", "
- in
-
let fwd_back_kind =
match rg_id with
| None -> "forward"
| Some rg_id -> "backward " ^ RegionGroupId.to_string rg_id
in
- "trait " ^ trait_name ^ " method name (" ^ lp_kind ^ fwd_back_kind ^ "): "
+ "trait " ^ trait_name ^ " method name (" ^ fwd_back_kind ^ "): "
^ fun_name
| TraitSelfClauseId -> "trait_self_clause"
@@ -960,8 +953,9 @@ let ctx_get_trait_type (id : trait_decl_id) (item_name : string)
ctx_get_trait_item id item_name ctx
let ctx_get_trait_method (id : trait_decl_id) (item_name : string)
- (ctx : extraction_ctx) : string =
- ctx_get_trait_item id item_name ctx
+ (rg_id : T.RegionGroupId.id option) (ctx : extraction_ctx) : string =
+ let with_opaque_pre = false in
+ ctx_get with_opaque_pre (TraitMethodId (id, item_name, rg_id)) ctx
let ctx_get_trait_parent_clause (id : trait_decl_id) (clause : trait_clause_id)
(ctx : extraction_ctx) : string =
@@ -1299,9 +1293,7 @@ let ctx_add_trait_method (d : trait_decl) (item_name : string) (f : fun_decl)
let trans = A.FunDeclId.Map.find f.def_id ctx.trans_funs in
let name = ctx_compute_fun_name trans f ctx in
let is_opaque = false in
- ctx_add is_opaque
- (TraitMethodId (d.def_id, item_name, f.loop_id, f.back_id))
- name ctx
+ ctx_add is_opaque (TraitMethodId (d.def_id, item_name, f.back_id)) name ctx
let ctx_add_trait_parent_clause (d : trait_decl) (clause : trait_clause)
(ctx : extraction_ctx) : extraction_ctx =