summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/Extract.ml10
-rw-r--r--compiler/ExtractBase.ml4
-rw-r--r--compiler/Translate.ml21
3 files changed, 30 insertions, 5 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index d707dc81..e1b2b23f 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -476,8 +476,10 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
in
let fun_name (fname : fun_name) (num_loops : int) (loop_id : LoopId.id option)
(num_rgs : int) (rg : region_group_info option) (filter_info : bool * int)
+ (missing_body: bool)
: string =
let fname = fun_name_to_snake_case fname in
+ let fname = if !backend = Lean && missing_body then "opaque_defs." ^ fname else fname in
(* Compute the suffix *)
let suffix = default_fun_suffix num_loops loop_id num_rgs rg filter_info in
(* Concatenate *)
@@ -1371,7 +1373,6 @@ let extract_fun_decl_register_names (ctx : extraction_ctx) (keep_fwd : bool)
ctx
in
let ctx = List.fold_left register_decreases ctx (fwd :: loop_fwds) in
- (* Register the function names *)
let register_fun ctx f = ctx_add_fun_decl (keep_fwd, def) f ctx in
let register_funs ctx fl = List.fold_left register_fun ctx fl in
(* Register the forward functions' names *)
@@ -2227,7 +2228,12 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter)
let use_forall = is_opaque_coq && def.signature.type_params <> [] in
(* *)
let qualif = ctx.fmt.fun_decl_kind_to_qualif kind in
- F.pp_print_string fmt (qualif ^ " " ^ def_name);
+ (* For Lean: we generate a record of assumed functions *)
+ if not (!backend = Lean && (kind = Assumed || kind = Declared)) then begin
+ F.pp_print_string fmt qualif;
+ F.pp_print_space fmt ()
+ end;
+ F.pp_print_string fmt def_name;
F.pp_print_space fmt ();
if use_forall then (
F.pp_print_string fmt ":";
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index 77170b5b..98a29daf 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -168,6 +168,7 @@ type formatter = {
int ->
region_group_info option ->
bool * int ->
+ bool ->
string;
(** Compute the name of a regular (non-assumed) function.
@@ -187,6 +188,7 @@ type formatter = {
The number of extracted backward functions if not necessarily
equal to the number of region groups, because we may have
filtered some of them.
+ - whether there is a body or not (indicates assumed function)
TODO: use the fun id for the assumed functions.
*)
decreases_clause_name :
@@ -774,7 +776,7 @@ let ctx_add_fun_decl (trans_group : bool * pure_fun_translation)
in
let name =
ctx.fmt.fun_name def.basename def.num_loops def.loop_id num_rgs rg_info
- (keep_fwd, num_backs)
+ (keep_fwd, num_backs) (def.body = None)
in
ctx_add
(FunId (FromLlbc (A.Regular def_id, def.loop_id, def.back_id)))
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 0a1c8f9a..df7a750d 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -661,13 +661,24 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
*)
if config.extract_state_type && config.extract_fun_decls then
export_state_type ();
+ if config.extract_opaque && config.extract_fun_decls && !Config.backend = Lean then begin
+ Format.pp_print_break fmt 0 0;
+ Format.pp_open_vbox fmt ctx.extract_ctx.indent_incr;
+ Format.pp_print_string fmt "structure OpaqueDefs where";
+ Format.pp_print_break fmt 0 0
+ end;
List.iter export_decl_group ctx.crate.declarations;
+ if config.extract_opaque && !Config.backend = Lean then begin
+ Format.pp_close_box fmt ()
+ end;
if config.extract_state_type && not config.extract_fun_decls then
export_state_type ()
let extract_file (config : gen_config) (ctx : gen_ctx) (filename : string)
(rust_module_name : string) (module_name : string) (custom_msg : string)
- (custom_imports : string list) (custom_includes : string list) : unit =
+ ?(custom_variables: string list = [])
+ (custom_imports : string list) (custom_includes : string list)
+ : unit =
(* Open the file and create the formatter *)
let out = open_out filename in
let fmt = Format.formatter_of_out_channel out in
@@ -720,7 +731,11 @@ let extract_file (config : gen_config) (ctx : gen_ctx) (filename : string)
(* Add the custom imports *)
List.iter (fun m -> Printf.fprintf out "import %s\n" m) custom_imports;
(* Add the custom includes *)
- List.iter (fun m -> Printf.fprintf out "import %s\n" m) custom_includes
+ List.iter (fun m -> Printf.fprintf out "import %s\n" m) custom_includes;
+ if custom_variables <> [] then begin
+ Printf.fprintf out "\n";
+ List.iter (fun m -> Printf.fprintf out "%s\n" m) custom_variables
+ end
);
(* From now onwards, we use the formatter *)
(* Set the margin *)
@@ -1016,8 +1031,10 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) :
[ module_name ^ module_delimiter ^ "Clauses" ^ module_delimiter ^ "Template"]
else []
in
+ let custom_variables = if has_opaque_funs then [ "section variable (opaque_defs: OpaqueDefs)" ] else [] in
extract_file fun_config gen_ctx fun_filename crate.A.name fun_module
": function definitions" []
+ ~custom_variables
([ types_module ] @ opaque_funs_module @ clauses_module))
else
let gen_config =