summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorSon Ho2023-09-01 18:22:15 +0200
committerSon Ho2023-09-01 18:22:15 +0200
commit0023f814ce638cd9b04ead9ec2e0c194d5efaefd (patch)
treef23fc1cf37bbb9b126119a7ec8978c2c667358e9 /compiler
parentaed317881d3083bba6a2cf154289486ef47ddf85 (diff)
Make good progress on Extract.ml
Diffstat (limited to 'compiler')
-rw-r--r--compiler/Extract.ml414
-rw-r--r--compiler/ExtractBase.ml47
2 files changed, 237 insertions, 224 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 4238a152..3c4feca5 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -836,6 +836,11 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
to_snake_case basename
| Coq | Lean -> basename
in
+ let trait_clause_basename (_varset : StringSet.t) (_clause : trait_clause) :
+ string =
+ (* TODO: actually use the clause to derive the name *)
+ "cl"
+ in
let append_index (basename : string) (i : int) : string =
basename ^ string_of_int i
in
@@ -931,6 +936,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
var_basename;
type_var_basename;
const_generic_var_basename;
+ trait_clause_basename;
append_index;
extract_literal;
extract_unop;
@@ -1246,16 +1252,18 @@ and extract_trait_ref (ctx : extraction_ctx) (fmt : F.formatter)
and extract_generic_args (ctx : extraction_ctx) (fmt : F.formatter)
(no_params_tys : TypeDeclId.Set.t) (generics : generic_args) : unit =
let { types; const_generics; trait_refs } = generics in
- if types <> [] then (
- F.pp_print_space fmt ();
- Collections.List.iter_link (F.pp_print_space fmt)
- (extract_ty ctx fmt no_params_tys true)
- types);
- if const_generics <> [] then (
- F.pp_print_space fmt ();
- Collections.List.iter_link (F.pp_print_space fmt)
- (extract_const_generic ctx fmt true)
- const_generics);
+ if !backend <> HOL4 then (
+ if types <> [] then (
+ F.pp_print_space fmt ();
+ Collections.List.iter_link (F.pp_print_space fmt)
+ (extract_ty ctx fmt no_params_tys true)
+ types);
+ if const_generics <> [] then (
+ assert (!backend <> HOL4);
+ F.pp_print_space fmt ();
+ Collections.List.iter_link (F.pp_print_space fmt)
+ (extract_const_generic ctx fmt true)
+ const_generics));
if trait_refs <> [] then (
F.pp_print_space fmt ();
Collections.List.iter_link (F.pp_print_space fmt)
@@ -1588,6 +1596,93 @@ let extract_comment (fmt : F.formatter) (sl : string list) : unit =
F.pp_print_string fmt rd;
F.pp_close_box fmt ()
+let extract_trait_clause_type (ctx : extraction_ctx) (fmt : F.formatter)
+ (no_params_tys : TypeDeclId.Set.t) (clause : trait_clause) : unit =
+ let with_opaque_pre = false in
+ let trait_name = ctx_get_trait_decl with_opaque_pre clause.trait_id ctx in
+ F.pp_print_string fmt trait_name;
+ extract_generic_args ctx fmt no_params_tys clause.generics
+
+(** Insert a space, if necessary *)
+let insert_req_space (fmt : F.formatter) (space : bool ref) : unit =
+ if !space then space := false else F.pp_print_space fmt ()
+
+let extract_generic_params (ctx : extraction_ctx) (fmt : F.formatter)
+ (no_params_tys : TypeDeclId.Set.t) (use_forall : bool) (as_implicits : bool)
+ (space : bool ref option) (generics : generic_params)
+ (type_params : string list) (cg_params : string list)
+ (trait_clauses : string list) : unit =
+ let all_params = List.concat [ type_params; cg_params; trait_clauses ] in
+ (* HOL4 doesn't support const generics *)
+ assert (cg_params = [] || !backend <> HOL4);
+ let left_bracket () =
+ if as_implicits then F.pp_print_string fmt "{"
+ else F.pp_print_string fmt "("
+ in
+ let right_bracket () =
+ if as_implicits then F.pp_print_string fmt "}"
+ else F.pp_print_string fmt ")"
+ in
+ let insert_req_space () =
+ match space with
+ | None -> F.pp_print_space fmt ()
+ | Some space -> insert_req_space fmt space
+ in
+ (* Print the type/const generic parameters *)
+ if all_params <> [] then (
+ if use_forall then (
+ insert_req_space ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "forall");
+ (* Note that in HOL4 we don't print the type parameters. *)
+ if !backend <> HOL4 then (
+ (* Print the type parameters *)
+ if type_params <> [] then (
+ insert_req_space ();
+ (* ( *)
+ left_bracket ();
+ List.iter
+ (fun s ->
+ F.pp_print_string fmt s;
+ F.pp_print_space fmt ())
+ type_params;
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt (type_keyword ());
+ (* ) *)
+ right_bracket ());
+ (* Print the const generic parameters *)
+ List.iter
+ (fun (var : const_generic_var) ->
+ insert_req_space ();
+ (* ( *)
+ left_bracket ();
+ let n = ctx_get_const_generic_var var.index ctx in
+ F.pp_print_string fmt n;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ extract_literal_type ctx fmt var.ty;
+ (* ) *)
+ right_bracket ())
+ generics.const_generics);
+ (* Print the trait clauses *)
+ List.iter
+ (fun (clause : trait_clause) ->
+ insert_req_space ();
+ (* ( *)
+ left_bracket ();
+ let n = ctx_get_trait_clause_var clause.clause_id ctx in
+ F.pp_print_string fmt n;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ extract_trait_clause_type ctx fmt no_params_tys clause;
+ (* ) *)
+ right_bracket ())
+ generics.trait_clauses)
+
(** Extract a type declaration.
This function is for all type declarations and all backends **at the exception**
@@ -1624,7 +1719,6 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
let ctx_body, type_params, cg_params, trait_clauses =
ctx_add_generic_params def.generics ctx
in
- let all_params = List.concat [ type_params; cg_params; trait_clauses ] in
(* Add a break before *)
if !backend <> HOL4 || not (decl_is_first_from_group kind) then
F.pp_print_break fmt 0 0;
@@ -1644,40 +1738,13 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
(match qualif with
| Some qualif -> F.pp_print_string fmt (qualif ^ " " ^ def_name)
| None -> F.pp_print_string fmt def_name);
- (* HOL4 doesn't support const generics *)
- assert (cg_params = [] || !backend <> HOL4);
- (* Print the type/const generic parameters *)
- if all_params <> [] && !backend <> HOL4 then (
- if use_forall then (
- F.pp_print_space fmt ();
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- F.pp_print_string fmt "forall");
- (* Print the type parameters *)
- if type_params <> [] then (
- F.pp_print_space fmt ();
- F.pp_print_string fmt "(";
- List.iter
- (fun s ->
- F.pp_print_string fmt s;
- F.pp_print_space fmt ())
- type_params;
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- F.pp_print_string fmt (type_keyword () ^ ")"));
- (* Print the const generic parameters *)
- List.iter
- (fun (var : const_generic_var) ->
- F.pp_print_space fmt ();
- F.pp_print_string fmt "(";
- let n = ctx_get_const_generic_var var.index ctx in
- F.pp_print_string fmt n;
- F.pp_print_space fmt ();
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- extract_literal_type ctx fmt var.ty;
- F.pp_print_string fmt ")")
- def.const_generic_params);
+ (* HOL4 doesn't support const generics, and type definitions in HOL4 don't
+ support trait clauses *)
+ assert ((cg_params = [] && trait_clauses = []) || !backend <> HOL4);
+ (* Print the generic parameters *)
+ let as_implicits = false in
+ extract_generic_params ctx_body fmt type_decl_group use_forall as_implicits
+ None def.generics type_params cg_params trait_clauses;
(* Print the "=" if we extract the body*)
if extract_body then (
F.pp_print_space fmt ();
@@ -1737,9 +1804,12 @@ let extract_type_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter)
let with_opaque_pre = false in
let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in
(* Generic parameters are unsupported *)
- assert (def.const_generic_params = []);
+ assert (def.generics.const_generics = []);
+ (* Trait clauses on type definitions are unsupported *)
+ assert (def.generics.trait_clauses = []);
+ (* Types *)
(* Count the number of parameters *)
- let num_params = List.length def.type_params in
+ let num_params = List.length def.generics.types in
(* Generate the declaration *)
F.pp_print_space fmt ();
F.pp_print_string fmt
@@ -1760,8 +1830,7 @@ let extract_type_decl_hol4_empty_record (ctx : extraction_ctx)
let with_opaque_pre = false in
let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in
(* Sanity check *)
- assert (def.type_params = []);
- assert (def.const_generic_params = []);
+ assert (def.generics = empty_generic_params);
(* Generate the declaration *)
F.pp_print_space fmt ();
F.pp_print_string fmt ("Type " ^ def_name ^ " = “: unit”");
@@ -1801,13 +1870,12 @@ let extract_type_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter)
(kind : decl_kind) (decl : type_decl) : unit =
assert (!backend = Coq);
(* Generating the [Arguments] instructions is useful only if there are type parameters *)
- if decl.type_params = [] && decl.const_generic_params = [] then ()
+ if decl.generics.types = [] && decl.generics.const_generics = [] then ()
else
(* Add the type params - note that we need those bindings only for the
* body translation (they are not top-level) *)
- let _ctx_body, type_params, cg_params =
- ctx_add_type_const_generic_params decl.type_params
- decl.const_generic_params ctx
+ let _ctx_body, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params decl.generics ctx
in
(* Auxiliary function to extract an [Arguments Cons {T} _ _.] instruction *)
let extract_arguments_info (cons_name : string) (fields : 'a list) : unit =
@@ -1815,27 +1883,22 @@ let extract_type_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_break fmt 0 0;
(* Open a box *)
F.pp_open_hovbox fmt ctx.indent_incr;
- (* Small utility *)
- let print_vars () =
- List.iter
- (fun (var : string) ->
- F.pp_print_space fmt ();
- F.pp_print_string fmt ("{" ^ var ^ "}"))
- (List.append type_params cg_params)
- in
- let print_fields () =
- List.iter
- (fun _ ->
- F.pp_print_space fmt ();
- F.pp_print_string fmt "_")
- fields
- in
F.pp_print_break fmt 0 0;
F.pp_print_string fmt "Arguments";
F.pp_print_space fmt ();
F.pp_print_string fmt cons_name;
- print_vars ();
- print_fields ();
+ (* Print the type/const params and the trait clauses (`{T}`) *)
+ List.iter
+ (fun (var : string) ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ("{" ^ var ^ "}"))
+ (List.concat [ type_params; cg_params; trait_clauses ]);
+ (* Print the fields (`_`) *)
+ List.iter
+ (fun _ ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "_")
+ fields;
F.pp_print_string fmt ".";
(* Close the box *)
@@ -1888,9 +1951,8 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
let is_rec = decl_is_from_rec_group kind in
if is_rec then
(* Add the type params *)
- let ctx, type_params, cg_params =
- ctx_add_type_const_generic_params decl.type_params
- decl.const_generic_params ctx
+ let ctx, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params decl.generics ctx
in
let ctx, record_var = ctx_add_var "x" (VarId.of_int 0) ctx in
let ctx, field_var = ctx_add_var "x" (VarId.of_int 1) ctx in
@@ -1911,33 +1973,13 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
F.pp_print_space fmt ();
let field_name = ctx_get_field (AdtId decl.def_id) field_id ctx in
F.pp_print_string fmt field_name;
- F.pp_print_space fmt ();
- (* Print the type parameters *)
- if type_params <> [] then (
- F.pp_print_string fmt "{";
- List.iter
- (fun p ->
- F.pp_print_string fmt p;
- F.pp_print_space fmt ())
- type_params;
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- F.pp_print_string fmt "Type}";
- F.pp_print_space fmt ());
- (* Print the const generic parameters *)
- if cg_params <> [] then
- List.iter
- (fun (v : const_generic_var) ->
- F.pp_print_string fmt "{";
- let n = ctx_get_const_generic_var v.index ctx in
- F.pp_print_string fmt n;
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- extract_literal_type ctx fmt v.ty;
- F.pp_print_string fmt "}";
- F.pp_print_space fmt ())
- decl.const_generic_params;
+ (* Print the generics *)
+ let use_forall = false in
+ let as_implicits = true in
+ extract_generic_params ctx fmt TypeDeclId.Set.empty use_forall
+ as_implicits None decl.generics type_params cg_params trait_clauses;
(* Print the record parameter *)
+ F.pp_print_space fmt ();
F.pp_print_string fmt "(";
F.pp_print_string fmt record_var;
F.pp_print_space fmt ();
@@ -2183,11 +2225,11 @@ let extract_adt_g_value
(inside : bool) (variant_id : VariantId.id option) (field_values : 'v list)
(ty : ty) : extraction_ctx =
match ty with
- | Adt (Tuple, type_args, cg_args) ->
+ | Adt (Tuple, generics) ->
(* Tuple *)
(* For now, we only support fully applied tuple constructors *)
- assert (List.length type_args = List.length field_values);
- assert (cg_args = []);
+ assert (List.length generics.types = List.length field_values);
+ assert (generics.const_generics = [] && generics.trait_refs = []);
(* This is very annoying: in Coq, we can't write [()] for the value of
type [unit], we have to write [tt]. *)
if !backend = Coq && field_values = [] then (
@@ -2205,7 +2247,7 @@ let extract_adt_g_value
in
F.pp_print_string fmt ")";
ctx)
- | Adt (adt_id, _, _) ->
+ | Adt (adt_id, _) ->
(* "Regular" ADT *)
(* If we are generating a pattern for a let-binding and we target Lean,
@@ -2343,14 +2385,12 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(* Top-level qualifier *)
match qualif.id with
| FunOrOp fun_id ->
- extract_function_call ctx fmt inside fun_id qualif.type_args
- qualif.const_generic_args args
+ extract_function_call ctx fmt inside fun_id qualif.generics args
| Global global_id -> extract_global ctx fmt global_id
| AdtCons adt_cons_id ->
- extract_adt_cons ctx fmt inside adt_cons_id qualif.type_args
- qualif.const_generic_args args
+ extract_adt_cons ctx fmt inside adt_cons_id qualif.generics args
| Proj proj ->
- extract_field_projector ctx fmt inside app proj qualif.type_args args)
+ extract_field_projector ctx fmt inside app proj qualif.generics args)
| _ ->
(* "Regular" expression *)
(* Open parentheses *)
@@ -2373,8 +2413,8 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(** Subcase of the app case: function call *)
and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
- (inside : bool) (fid : fun_or_op_id) (type_args : ty list)
- (cg_args : const_generic list) (args : texpression list) : unit =
+ (inside : bool) (fid : fun_or_op_id) (generics : generic_args)
+ (args : texpression list) : unit =
match (fid, args) with
| Unop unop, [ arg ] ->
(* A unop can have *at most* one argument (the result can't be a function!).
@@ -2396,19 +2436,9 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
let fun_name = ctx_get_function with_opaque_pre fun_id ctx in
F.pp_print_string fmt fun_name;
(* Sanity check: HOL4 doesn't support const generics *)
- assert (cg_args = [] || !backend <> HOL4);
- (* Print the type parameters, if the backend is not HOL4 *)
- if !backend <> HOL4 then (
- List.iter
- (fun ty ->
- F.pp_print_space fmt ();
- extract_ty ctx fmt TypeDeclId.Set.empty true ty)
- type_args;
- List.iter
- (fun cg ->
- F.pp_print_space fmt ();
- extract_const_generic ctx fmt true cg)
- cg_args);
+ assert (generics.const_generics = [] || !backend <> HOL4);
+ (* Print the generics *)
+ extract_generic_args ctx fmt TypeDeclId.Set.empty generics;
(* Print the arguments *)
List.iter
(fun ve ->
@@ -2430,9 +2460,9 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
(** Subcase of the app case: ADT constructor *)
and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
- (adt_cons : adt_cons_id) (type_args : ty list)
- (cg_args : const_generic list) (args : texpression list) : unit =
- let e_ty = Adt (adt_cons.adt_id, type_args, cg_args) in
+ (adt_cons : adt_cons_id) (generics : generic_args) (args : texpression list)
+ : unit =
+ let e_ty = Adt (adt_cons.adt_id, generics) in
let is_single_pat = false in
let _ =
extract_adt_g_value
@@ -2446,7 +2476,7 @@ and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(** Subcase of the app case: ADT field projector. *)
and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter)
(inside : bool) (original_app : texpression) (proj : projection)
- (_proj_type_params : ty list) (args : texpression list) : unit =
+ (_generics : generic_args) (args : texpression list) : unit =
(* We isolate the first argument (if there is), in order to pretty print the
* projection ([x.field] instead of [MkAdt?.field x] *)
match args with
@@ -2905,11 +2935,11 @@ and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter)
let cs = ctx_get_struct false (Assumed Array) ctx in
F.pp_print_string fmt cs;
(* Print the parameters *)
- let _, tys, cgs = ty_as_adt e_ty in
- let ty = Collections.List.to_cons_nil tys in
+ let _, generics = ty_as_adt e_ty in
+ let ty = Collections.List.to_cons_nil generics.types in
F.pp_print_space fmt ();
extract_ty ctx fmt TypeDeclId.Set.empty true ty;
- let cg = Collections.List.to_cons_nil cgs in
+ let cg = Collections.List.to_cons_nil generics.const_generics in
F.pp_print_space fmt ();
extract_const_generic ctx fmt true cg;
F.pp_print_space fmt ();
@@ -2936,10 +2966,6 @@ and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_close_box fmt ()
| _ -> raise (Failure "Unreachable")
-(** Insert a space, if necessary *)
-let insert_req_space (fmt : F.formatter) (space : bool ref) : unit =
- if !space then space := false else F.pp_print_space fmt ()
-
(** A small utility to print the parameters of a function signature.
We return two contexts:
@@ -2947,6 +2973,8 @@ let insert_req_space (fmt : F.formatter) (space : bool ref) : unit =
- the context augmented with bindings for the type parameters *and*
bindings for the input values
+ We also return names for the type parameters, const generics, etc.
+
TODO: do we really need the first one? We should probably always use
the second one.
It comes from the fact that when we print the input values for the
@@ -2954,57 +2982,22 @@ let insert_req_space (fmt : F.formatter) (space : bool ref) : unit =
patterns, not the variables). We should figure a cleaner way.
*)
let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx)
- (fmt : F.formatter) (def : fun_decl) : extraction_ctx * extraction_ctx =
+ (fmt : F.formatter) (def : fun_decl) :
+ extraction_ctx * extraction_ctx * string list =
(* Add the type parameters - note that we need those bindings only for the
* body translation (they are not top-level) *)
- let ctx, type_params, cg_params =
- ctx_add_type_const_generic_params def.signature.type_params
- def.signature.const_generic_params ctx
+ let ctx, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params def.signature.generics ctx
in
- (* Print the parameters - rem.: we should have filtered the functions
- * with no input parameters *)
- (* The type parameters.
-
- Note that in HOL4 we don't print the type parameters.
- *)
- if (type_params <> [] || cg_params <> []) && !backend <> HOL4 then (
- (* Open a box for the type and const generic parameters *)
- F.pp_open_hovbox fmt 0;
- (* The type parameters *)
- if type_params <> [] then (
- insert_req_space fmt space;
- F.pp_print_string fmt "(";
- List.iter
- (fun (p : type_var) ->
- let pname = ctx_get_type_var p.index ctx in
- F.pp_print_string fmt pname;
- F.pp_print_space fmt ())
- def.signature.type_params;
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- let type_keyword =
- match !backend with
- | FStar -> "Type0"
- | Coq | Lean -> "Type"
- | HOL4 -> raise (Failure "Unreachable")
- in
- F.pp_print_string fmt (type_keyword ^ ")"));
- (* The const generic parameters *)
- if cg_params <> [] then
- List.iter
- (fun (p : const_generic_var) ->
- let pname = ctx_get_const_generic_var p.index ctx in
- insert_req_space fmt space;
- F.pp_print_string fmt "(";
- F.pp_print_string fmt pname;
- F.pp_print_space fmt ();
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- extract_literal_type ctx fmt p.ty;
- F.pp_print_string fmt ")")
- def.signature.const_generic_params;
- (* Close the box for the type parameters *)
- F.pp_close_box fmt ());
+ (* Print the generics *)
+ (* Open a box for the generics *)
+ F.pp_open_hovbox fmt 0;
+ let use_forall = false in
+ let as_implicits = false in
+ extract_generic_params ctx fmt TypeDeclId.Set.empty use_forall as_implicits
+ (Some space) def.signature.generics type_params cg_params trait_clauses;
+ (* Close the box for the generics *)
+ F.pp_close_box fmt ();
(* The input parameters - note that doing this adds bindings to the context *)
let ctx_body =
match def.body with
@@ -3027,7 +3020,7 @@ let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx)
ctx)
ctx body.inputs_lvs
in
- (ctx, ctx_body)
+ (ctx, ctx_body, List.concat [ type_params; cg_params; trait_clauses ])
(** A small utility to print the types of the input parameters in the form:
[u32 -> list u32 -> ...]
@@ -3096,7 +3089,7 @@ let extract_template_fstar_decreases_clause (ctx : extraction_ctx)
F.pp_print_space fmt ();
(* Extract the parameters *)
let space = ref true in
- let _, _ = extract_fun_parameters space ctx fmt def in
+ let _, _, _ = extract_fun_parameters space ctx fmt def in
insert_req_space fmt space;
F.pp_print_string fmt ":";
(* Print the signature *)
@@ -3158,7 +3151,7 @@ let extract_template_lean_termination_and_decreasing (ctx : extraction_ctx)
F.pp_print_space fmt ();
(* Extract the parameters *)
let space = ref true in
- let _, ctx_body = extract_fun_parameters space ctx fmt def in
+ let _, ctx_body, _ = extract_fun_parameters space ctx fmt def in
(* Print the ":=" *)
F.pp_print_space fmt ();
F.pp_print_string fmt ":=";
@@ -3298,9 +3291,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
*)
let is_opaque_coq = !backend = Coq && is_opaque in
let use_forall =
- is_opaque_coq
- && (def.signature.type_params <> []
- || def.signature.const_generic_params <> [])
+ is_opaque_coq && def.signature.generics <> empty_generic_params
in
(* Print the qualifier ("assume", etc.).
@@ -3326,7 +3317,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
(* Open a box for "(PARAMS) :" *)
F.pp_open_hovbox fmt 0;
let space = ref true in
- let ctx, ctx_body = extract_fun_parameters space ctx fmt def in
+ let ctx, ctx_body, all_params = extract_fun_parameters space ctx fmt def in
(* Print the return type - note that we have to be careful when
* printing the input values for the decrease clause, because
* it introduces bindings in the context... We thus "forget"
@@ -3374,20 +3365,13 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
(* The name of the decrease clause *)
let decr_name = ctx_get_termination_measure def.def_id def.loop_id ctx in
F.pp_print_string fmt decr_name;
- (* Print the type/const generic parameters - TODO: we do this many
+ (* Print the generic parameters - TODO: we do this many
times, we should have a helper to factor it out *)
List.iter
- (fun (p : type_var) ->
- let pname = ctx_get_type_var p.index ctx in
- F.pp_print_space fmt ();
- F.pp_print_string fmt pname)
- def.signature.type_params;
- List.iter
- (fun (p : const_generic_var) ->
- let pname = ctx_get_const_generic_var p.index ctx in
+ (fun (name : string) ->
F.pp_print_space fmt ();
- F.pp_print_string fmt pname)
- def.signature.const_generic_params;
+ F.pp_print_string fmt name)
+ all_params;
(* Print the input values: we have to be careful here to print
* only the input values which are in common with the *forward*
* function (the additional input values "given back" to the
@@ -3474,19 +3458,12 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
(* Open the box for [DECREASES] *)
F.pp_open_hovbox fmt ctx.indent_incr;
F.pp_print_string fmt terminates_name;
- (* Print the type/const generic params - TODO: factor out *)
+ (* Print the generic params - TODO: factor out *)
List.iter
- (fun (p : type_var) ->
- let pname = ctx_get_type_var p.index ctx in
+ (fun (name : string) ->
F.pp_print_space fmt ();
- F.pp_print_string fmt pname)
- def.signature.type_params;
- List.iter
- (fun (p : const_generic_var) ->
- let pname = ctx_get_const_generic_var p.index ctx in
- F.pp_print_space fmt ();
- F.pp_print_string fmt pname)
- def.signature.const_generic_params;
+ F.pp_print_string fmt name)
+ all_params;
(* Print the variables *)
List.iter
(fun v ->
@@ -3544,13 +3521,10 @@ let extract_fun_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter)
ctx_get_local_function with_opaque_pre def.def_id def.loop_id def.back_id
ctx
in
- assert (def.signature.const_generic_params = []);
+ assert (def.signature.generics.const_generics = []);
(* Add the type/const gen parameters - note that we need those bindings
only for the generation of the type (they are not top-level) *)
- let ctx, _, _ =
- ctx_add_type_const_generic_params def.signature.type_params
- def.signature.const_generic_params ctx
- in
+ let ctx, _, _, _ = ctx_add_generic_params def.signature.generics ctx in
(* Add breaks to insert new lines between definitions *)
F.pp_print_break fmt 0 0;
(* Open a box for the whole definition *)
@@ -3726,10 +3700,9 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter)
(global : A.global_decl) (body : fun_decl) (interface : bool) : unit =
assert body.is_global_decl_body;
assert (Option.is_none body.back_id);
- assert (List.length body.signature.inputs = 0);
+ assert (body.signature.inputs = []);
assert (List.length body.signature.doutputs = 1);
- assert (List.length body.signature.type_params = 0);
- assert (List.length body.signature.const_generic_params = 0);
+ assert (body.signature.generics = empty_generic_params);
(* Add a break then the name of the corresponding LLBC declaration *)
F.pp_print_break fmt 0 0;
@@ -3799,8 +3772,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
(* Check if this is a unit function *)
let sg = def.signature in
if
- sg.type_params = []
- && sg.const_generic_params = []
+ sg.generics = empty_generic_params
&& (sg.inputs = [ mk_unit_ty ] || sg.inputs = [])
&& sg.output = mk_result_ty mk_unit_ty
then (
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index 96ecfd42..2c704d3e 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -111,10 +111,10 @@ let decl_is_first_from_group (kind : decl_kind) : bool =
let decl_is_not_last_from_group (kind : decl_kind) : bool =
not (decl_is_last_from_group kind)
-(* TODO: this should a module we give to a functor! *)
-
type type_decl_kind = Enum | Struct
+(* TODO: this should be a module we give to a functor! *)
+
(** A formatter's role is twofold:
1. Come up with name suggestions.
For instance, provided some information about a function (its basename,
@@ -125,6 +125,9 @@ type type_decl_kind = Enum | Struct
snake case, adding prefixes/suffixes, etc.
2. Format some specific terms, like constants.
+
+ TODO: unclear that this is useful now that all the backends are so much
+ entangled in Extract.ml
*)
type formatter = {
bool_name : string;
@@ -288,6 +291,13 @@ type formatter = {
(** Generates a type variable basename. *)
const_generic_var_basename : StringSet.t -> string -> string;
(** Generates a const generic variable basename. *)
+ trait_clause_basename : StringSet.t -> trait_clause -> string;
+ (** Return a base name for a trait clause. We might add a suffix to prevent
+ collisions.
+
+ In the traduction we explicitely manipulate the trait clause instances,
+ that is we introduce one input variable for each trait clause.
+ *)
append_index : string -> int -> string;
(** Appends an index to a name - we use this to generate unique
names: when doing so, the role of the formatter is just to concatenate
@@ -396,6 +406,9 @@ type id =
| TypeVarId of TypeVarId.id
| ConstGenericVarId of ConstGenericVarId.id
| VarId of VarId.id
+ | TraitDeclId of TraitDeclId.id
+ | TraitImplId of TraitImplId.id
+ | TraitClauseId of TraitClauseId.id
| UnknownId
(** Used for stored various strings like keywords, definitions which
should always be in context, etc. and which can't be linked to one
@@ -718,6 +731,9 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
| ConstGenericVarId id ->
"const_generic_var_id: " ^ ConstGenericVarId.to_string id
| VarId id -> "var_id: " ^ VarId.to_string id
+ | TraitDeclId id -> "trait_decl_id: " ^ TraitDeclId.to_string id
+ | TraitImplId id -> "trait_impl_id: " ^ TraitImplId.to_string id
+ | TraitClauseId id -> "trait_clause_id: " ^ TraitClauseId.to_string id
(** We might not check for collisions for some specific ids (ex.: field names) *)
let allow_collisions (id : id) : bool =
@@ -787,6 +803,14 @@ let ctx_get_assumed_type (id : assumed_ty) (ctx : extraction_ctx) : string =
let is_opaque = false in
ctx_get_type is_opaque (Assumed id) ctx
+let ctx_get_trait_decl (with_opaque_pre : bool) (id : trait_decl_id)
+ (ctx : extraction_ctx) : string =
+ ctx_get with_opaque_pre (TraitDeclId id) ctx
+
+let ctx_get_trait_impl (with_opaque_pre : bool) (id : trait_impl_id)
+ (ctx : extraction_ctx) : string =
+ ctx_get with_opaque_pre (TraitImplId id) ctx
+
let ctx_get_var (id : VarId.id) (ctx : extraction_ctx) : string =
let is_opaque = false in
ctx_get is_opaque (VarId id) ctx
@@ -800,6 +824,11 @@ let ctx_get_const_generic_var (id : ConstGenericVarId.id) (ctx : extraction_ctx)
let is_opaque = false in
ctx_get is_opaque (ConstGenericVarId id) ctx
+let ctx_get_trait_clause_var (id : TraitClauseId.id) (ctx : extraction_ctx) :
+ string =
+ let is_opaque = false in
+ ctx_get is_opaque (TraitClauseId id) ctx
+
let ctx_get_field (type_id : type_id) (field_id : FieldId.id)
(ctx : extraction_ctx) : string =
let is_opaque = false in
@@ -865,6 +894,16 @@ let ctx_add_var (basename : string) (id : VarId.id) (ctx : extraction_ctx) :
let ctx = ctx_add is_opaque (VarId id) name ctx in
(ctx, name)
+(** Generate a unique trait clause name and add it to the context *)
+let ctx_add_trait_clause (basename : string) (id : TraitClauseId.id)
+ (ctx : extraction_ctx) : extraction_ctx * string =
+ let is_opaque = false in
+ let name =
+ basename_to_unique ctx.names_map.names_set ctx.fmt.append_index basename
+ in
+ let ctx = ctx_add is_opaque (TraitClauseId id) name ctx in
+ (ctx, name)
+
(** See {!ctx_add_var} *)
let ctx_add_vars (vars : var list) (ctx : extraction_ctx) :
extraction_ctx * string list =
@@ -890,7 +929,9 @@ let ctx_add_const_generic_params (vars : const_generic_var list)
let ctx_add_trait_clauses (clauses : trait_clause list) (ctx : extraction_ctx) :
extraction_ctx * string list =
List.fold_left_map
- (fun ctx (c : trait_clause) -> ctx_add_trait_clause c ctx)
+ (fun ctx (c : trait_clause) ->
+ let basename = ctx.fmt.trait_clause_basename ctx.names_map.names_set c in
+ ctx_add_trait_clause basename c.clause_id ctx)
ctx clauses
(** Returns the lists of names for: