summaryrefslogtreecommitdiff
path: root/compiler/Extract.ml
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/Extract.ml')
-rw-r--r--compiler/Extract.ml66
1 files changed, 47 insertions, 19 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 2a678a27..138619c4 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -3213,6 +3213,11 @@ let extract_fun_input_parameters_types (ctx : extraction_ctx)
in
List.iter extract_param def.signature.inputs
+let extract_fun_inputs_output_parameters_types (ctx : extraction_ctx)
+ (fmt : F.formatter) (def : fun_decl) : unit =
+ extract_fun_input_parameters_types ctx fmt def;
+ extract_ty ctx fmt TypeDeclId.Set.empty false def.signature.output
+
let assert_backend_supports_decreases_clauses () =
match !backend with
| FStar | Lean -> ()
@@ -3931,19 +3936,10 @@ let extract_trait_decl_method_register_names (ctx : extraction_ctx)
(* We add one field per required forward/backward function *)
let trans = A.FunDeclId.Map.find id ctx.trans_funs in
- let register_fun ctx f = ctx_add_trait_method trait_decl name f ctx in
- let register_funs ctx fl = List.fold_left register_fun ctx fl in
- (* Register the names of the forward functions *)
- let ctx =
- if trans.keep_fwd then register_funs ctx (trans.fwd.f :: trans.fwd.loops)
- else ctx
- in
- (* Register the names of the backward functions *)
- List.fold_left
- (fun ctx back ->
- let ctx = register_fun ctx back.f in
- register_funs ctx back.loops)
- ctx trans.backs
+ let register_fun ctx f = ctx_add_trait_method trait_decl name f.f ctx in
+ (* Register the names *)
+ let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in
+ List.fold_left register_fun ctx funs
(** Similar to {!extract_type_decl_register_names} *)
let extract_trait_decl_register_names (ctx : extraction_ctx)
@@ -4016,18 +4012,50 @@ let extract_trait_decl_item (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt ";";
F.pp_close_box fmt ()
+(** Small helper - TODO: move *)
+let generic_params_drop_prefix (g1 : generic_params) (g2 : generic_params) :
+ generic_params =
+ let open Collections.List in
+ let types = drop (length g1.types) g2.types in
+ let const_generics = drop (length g1.const_generics) g2.const_generics in
+ let trait_clauses = drop (length g1.trait_clauses) g2.trait_clauses in
+ { types; const_generics; trait_clauses }
+
(** Small helper.
Extract the items for a method in a trait decl.
*)
let extract_trait_decl_method_items (ctx : extraction_ctx) (fmt : F.formatter)
- (decl : trait_decl) (name : string) (id : fun_decl_id) : unit =
- let item_name = ctx_get_trait_const decl.def_id name ctx in
+ (decl : trait_decl) (item_name : string) (id : fun_decl_id) : unit =
(* Lookup the definition *)
- (* let def =
- FunDeclId.Map.find ctx.
- in *)
- raise (Failure "TODO")
+ let trans = A.FunDeclId.Map.find id ctx.trans_funs in
+ (* Extract the items *)
+ let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in
+ let extract_method (f : fun_and_loops) =
+ let f = f.f in
+ let fun_name = ctx_get_trait_method decl.def_id item_name f.back_id ctx in
+ let ty () =
+ (* Extract the generics *)
+ (* We need to add the generics specific to the method, by removing those
+ which actually apply to the trait decl *)
+ let generics =
+ generic_params_drop_prefix decl.generics f.signature.generics
+ in
+ let ctx, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params generics ctx
+ in
+ let use_forall = generics <> empty_generic_params in
+ let use_implicits = false in
+ extract_generic_params ctx fmt TypeDeclId.Set.empty use_forall
+ use_implicits None None generics type_params cg_params trait_clauses;
+ if use_forall then F.pp_print_string fmt ",";
+ (* Extract the inputs and output *)
+ F.pp_print_space fmt ();
+ extract_fun_inputs_output_parameters_types ctx fmt f
+ in
+ extract_trait_decl_item ctx fmt fun_name ty
+ in
+ List.iter extract_method funs
(** Extract a trait declaration *)
let extract_trait_decl (ctx : extraction_ctx) (fmt : F.formatter)