summaryrefslogtreecommitdiff
path: root/compiler/Extract.ml
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/Extract.ml')
-rw-r--r--compiler/Extract.ml101
1 files changed, 70 insertions, 31 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 3ea3a862..0e9a53df 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -172,8 +172,8 @@ let keywords () =
"macro";
"match";
"namespace";
+ "opaque";
"open";
- "return";
"run_cmd";
"set_option";
"simp";
@@ -569,6 +569,10 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
fname ^ lp_suffix ^ suffix
in
+ let opaque_pre () =
+ match !Config.backend with FStar | Coq -> "" | Lean -> "opaque_defs."
+ in
+
let var_basename (_varset : StringSet.t) (basename : string option) (ty : ty)
: string =
(* If there is a basename, we use it *)
@@ -699,6 +703,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
fun_name;
termination_measure_name;
decreases_proof_name;
+ opaque_pre;
var_basename;
type_var_basename;
append_index;
@@ -726,6 +731,7 @@ let unit_name () = match !backend with Lean -> "Unit" | Coq | FStar -> "unit"
*)
let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(ty : ty) : unit =
+ let extract_rec = extract_ty ctx fmt in
match ty with
| Adt (type_id, tys) -> (
match type_id with
@@ -743,15 +749,18 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
in
F.pp_print_string fmt product;
F.pp_print_space fmt ())
- (extract_ty ctx fmt true) tys;
+ (extract_rec true) tys;
F.pp_print_string fmt ")")
| AdtId _ | Assumed _ ->
let print_paren = inside && tys <> [] in
if print_paren then F.pp_print_string fmt "(";
- F.pp_print_string fmt (ctx_get_type type_id ctx);
+ (* TODO: for now, only the opaque *functions* are extracted in the
+ opaque module. The opaque *types* are assumed. *)
+ let with_opaque_pre = false in
+ F.pp_print_string fmt (ctx_get_type with_opaque_pre type_id ctx);
if tys <> [] then F.pp_print_space fmt ();
- Collections.List.iter_link (F.pp_print_space fmt)
- (extract_ty ctx fmt true) tys;
+ Collections.List.iter_link (F.pp_print_space fmt) (extract_rec true)
+ tys;
if print_paren then F.pp_print_string fmt ")")
| TypeVar vid -> F.pp_print_string fmt (ctx_get_type_var vid ctx)
| Bool -> F.pp_print_string fmt ctx.fmt.bool_name
@@ -760,11 +769,11 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
| Str -> F.pp_print_string fmt ctx.fmt.str_name
| Arrow (arg_ty, ret_ty) ->
if inside then F.pp_print_string fmt "(";
- extract_ty ctx fmt false arg_ty;
+ extract_rec false arg_ty;
F.pp_print_space fmt ();
F.pp_print_string fmt "->";
F.pp_print_space fmt ();
- extract_ty ctx fmt false ret_ty;
+ extract_rec false ret_ty;
if inside then F.pp_print_string fmt ")"
| Array _ | Slice _ -> raise Unimplemented
@@ -969,7 +978,9 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter)
(* If Coq: print the constructor name *)
(* TODO: remove superfluous test not is_rec below *)
if !backend = Coq && not is_rec then (
- F.pp_print_string fmt (ctx_get_struct (AdtId def.def_id) ctx);
+ let with_opaque_pre = false in
+ F.pp_print_string fmt
+ (ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx);
F.pp_print_string fmt " ");
if !backend <> Lean then F.pp_print_string fmt "{";
F.pp_print_break fmt 1 ctx.indent_incr;
@@ -1000,8 +1011,9 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter)
a group of mutually recursive types: we extract it as an inductive type
*)
assert (is_rec && !backend = Coq);
- let cons_name = ctx_get_struct (AdtId def.def_id) ctx in
- let def_name = ctx_get_local_type def.def_id ctx in
+ let with_opaque_pre = false in
+ let cons_name = ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx in
+ let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in
extract_type_decl_variant ctx fmt def_name type_params cons_name fields)
in
()
@@ -1043,10 +1055,12 @@ let extract_type_decl (ctx : extraction_ctx) (fmt : F.formatter)
The boolean [is_opaque_coq] is used to detect this case.
*)
- let is_opaque_coq = !backend = Coq && type_kind = None in
+ let is_opaque = type_kind = None in
+ let is_opaque_coq = !backend = Coq && is_opaque in
let use_forall = is_opaque_coq && def.type_params <> [] in
(* Retrieve the definition name *)
- let def_name = ctx_get_local_type def.def_id ctx in
+ let with_opaque_pre = false in
+ let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in
(* Add the type params - note that we need those bindings only for the
* body translation (they are not top-level) *)
let ctx_body, type_params = ctx_add_type_params def.type_params ctx in
@@ -1173,7 +1187,8 @@ let extract_type_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter)
| Struct fields ->
let adt_id = AdtId decl.def_id in
(* Generate the instruction for the record constructor *)
- let cons_name = ctx_get_struct adt_id ctx in
+ let with_opaque_pre = false in
+ let cons_name = ctx_get_struct with_opaque_pre adt_id ctx in
extract_arguments_info cons_name fields;
(* Generate the instruction for the record projectors, if there are *)
let is_rec = decl_is_from_rec_group kind in
@@ -1215,8 +1230,11 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
let ctx, type_params = ctx_add_type_params decl.type_params ctx in
let ctx, record_var = ctx_add_var "x" (VarId.of_int 0) ctx in
let ctx, field_var = ctx_add_var "x" (VarId.of_int 1) ctx in
- let def_name = ctx_get_local_type decl.def_id ctx in
- let cons_name = ctx_get_struct (AdtId decl.def_id) ctx in
+ let with_opaque_pre = false in
+ let def_name = ctx_get_local_type with_opaque_pre decl.def_id ctx in
+ let cons_name =
+ ctx_get_struct with_opaque_pre (AdtId decl.def_id) ctx
+ in
let extract_field_proj (field_id : FieldId.id) (_ : field) : unit =
F.pp_print_space fmt ();
(* Outer box for the projector definition *)
@@ -1500,12 +1518,16 @@ let extract_adt_g_value
* [{ field0=...; ...; fieldn=...; }] in case of structures.
*)
let cons =
+ (* The ADT shouldn't be opaque *)
+ let with_opaque_pre = false in
match variant_id with
| Some vid ->
if !backend = Lean then
- ctx_get_type adt_id ctx ^ "." ^ ctx_get_variant adt_id vid ctx
+ ctx_get_type with_opaque_pre adt_id ctx
+ ^ "."
+ ^ ctx_get_variant adt_id vid ctx
else ctx_get_variant adt_id vid ctx
- | None -> ctx_get_struct adt_id ctx
+ | None -> ctx_get_struct with_opaque_pre adt_id ctx
in
if inside && field_values <> [] then F.pp_print_string fmt "(";
F.pp_print_string fmt cons;
@@ -1523,7 +1545,8 @@ let extract_adt_g_value
(* Extract globals in the same way as variables *)
let extract_global (ctx : extraction_ctx) (fmt : F.formatter)
(id : A.GlobalDeclId.id) : unit =
- F.pp_print_string fmt (ctx_get_global id ctx)
+ let with_opaque_pre = ctx.use_opaque_pre in
+ F.pp_print_string fmt (ctx_get_global with_opaque_pre id ctx)
(** [inside]: see {!extract_ty}.
@@ -1643,11 +1666,8 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
(* Open a box for the function call *)
F.pp_open_hovbox fmt ctx.indent_incr;
(* Print the function name *)
- let fun_name =
- Option.value
- ~default:(ctx_get_function fun_id ctx)
- (ctx_maybe_get (DeclaredId fun_id) ctx)
- in
+ let with_opaque_pre = ctx.use_opaque_pre in
+ let fun_name = ctx_get_function with_opaque_pre fun_id ctx in
F.pp_print_string fmt fun_name;
(* Print the type parameters *)
List.iter
@@ -1703,14 +1723,16 @@ and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
* applied structure constructors.
*)
let cons =
+ (* The ADT shouldn't be opaque *)
+ let with_opaque_pre = false in
match adt_cons.variant_id with
| Some vid ->
if !backend = Lean then
- ctx_get_type adt_cons.adt_id ctx
+ ctx_get_type with_opaque_pre adt_cons.adt_id ctx
^ "."
^ ctx_get_variant adt_cons.adt_id vid ctx
else ctx_get_variant adt_cons.adt_id vid ctx
- | None -> ctx_get_struct adt_cons.adt_id ctx
+ | None -> ctx_get_struct with_opaque_pre adt_cons.adt_id ctx
in
let is_lean_struct = !backend = Lean && adt_cons.variant_id = None in
if is_lean_struct then (
@@ -2309,8 +2331,10 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter)
(kind : decl_kind) (has_decreases_clause : bool) (def : fun_decl) : unit =
assert (not def.is_global_decl_body);
(* Retrieve the function name *)
+ let with_opaque_pre = false in
let def_name =
- ctx_get_local_function def.def_id def.loop_id def.back_id ctx
+ ctx_get_local_function with_opaque_pre def.def_id def.loop_id def.back_id
+ ctx
in
(* Add a break before *)
F.pp_print_break fmt 0 0;
@@ -2612,9 +2636,12 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter)
extract_comment fmt ("[" ^ Print.global_name_to_string global.name ^ "]");
F.pp_print_space fmt ();
- let decl_name = ctx_get_global global.def_id ctx in
+ let with_opaque_pre = false in
+ let decl_name = ctx_get_global with_opaque_pre global.def_id ctx in
let body_name =
- ctx_get_function (FromLlbc (Regular global.body_id, None, None)) ctx
+ ctx_get_function with_opaque_pre
+ (FromLlbc (Regular global.body_id, None, None))
+ ctx
in
let decl_ty, body_ty =
@@ -2685,8 +2712,12 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt "assert_norm";
F.pp_print_space fmt ();
F.pp_print_string fmt "(";
+ (* Note that if the function is opaque, the unit test will fail
+ because the normalizer will get stuck *)
+ let with_opaque_pre = ctx.use_opaque_pre in
let fun_name =
- ctx_get_local_function def.def_id def.loop_id def.back_id ctx
+ ctx_get_local_function with_opaque_pre def.def_id def.loop_id
+ def.back_id ctx
in
F.pp_print_string fmt fun_name;
if sg.inputs <> [] then (
@@ -2701,8 +2732,12 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt "Check";
F.pp_print_space fmt ();
F.pp_print_string fmt "(";
+ (* Note that if the function is opaque, the unit test will fail
+ because the normalizer will get stuck *)
+ let with_opaque_pre = ctx.use_opaque_pre in
let fun_name =
- ctx_get_local_function def.def_id def.loop_id def.back_id ctx
+ ctx_get_local_function with_opaque_pre def.def_id def.loop_id
+ def.back_id ctx
in
F.pp_print_string fmt fun_name;
if sg.inputs <> [] then (
@@ -2714,8 +2749,12 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt "#assert";
F.pp_print_space fmt ();
F.pp_print_string fmt "(";
+ (* Note that if the function is opaque, the unit test will fail
+ because the normalizer will get stuck *)
+ let with_opaque_pre = ctx.use_opaque_pre in
let fun_name =
- ctx_get_local_function def.def_id def.loop_id def.back_id ctx
+ ctx_get_local_function with_opaque_pre def.def_id def.loop_id
+ def.back_id ctx
in
F.pp_print_string fmt fun_name;
if sg.inputs <> [] then (