From 40a08afb20dd9bb36069407e3db37a03a8be0981 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 7 Mar 2023 12:00:40 +0100 Subject: Handle the "opaque_defs." prefix in a cleaner manner --- compiler/Extract.ml | 101 +++++++++++++++------- compiler/ExtractBase.ml | 222 ++++++++++++++++++++++++++++++++++-------------- compiler/Translate.ml | 25 +++++- 3 files changed, 253 insertions(+), 95 deletions(-) diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 3ea3a862..0e9a53df 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -172,8 +172,8 @@ let keywords () = "macro"; "match"; "namespace"; + "opaque"; "open"; - "return"; "run_cmd"; "set_option"; "simp"; @@ -569,6 +569,10 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) fname ^ lp_suffix ^ suffix in + let opaque_pre () = + match !Config.backend with FStar | Coq -> "" | Lean -> "opaque_defs." + in + let var_basename (_varset : StringSet.t) (basename : string option) (ty : ty) : string = (* If there is a basename, we use it *) @@ -699,6 +703,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) fun_name; termination_measure_name; decreases_proof_name; + opaque_pre; var_basename; type_var_basename; append_index; @@ -726,6 +731,7 @@ let unit_name () = match !backend with Lean -> "Unit" | Coq | FStar -> "unit" *) let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (ty : ty) : unit = + let extract_rec = extract_ty ctx fmt in match ty with | Adt (type_id, tys) -> ( match type_id with @@ -743,15 +749,18 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) in F.pp_print_string fmt product; F.pp_print_space fmt ()) - (extract_ty ctx fmt true) tys; + (extract_rec true) tys; F.pp_print_string fmt ")") | AdtId _ | Assumed _ -> let print_paren = inside && tys <> [] in if print_paren then F.pp_print_string fmt "("; - F.pp_print_string fmt (ctx_get_type type_id ctx); + (* TODO: for now, only the opaque *functions* are extracted in the + opaque module. The opaque *types* are assumed. *) + let with_opaque_pre = false in + F.pp_print_string fmt (ctx_get_type with_opaque_pre type_id ctx); if tys <> [] then F.pp_print_space fmt (); - Collections.List.iter_link (F.pp_print_space fmt) - (extract_ty ctx fmt true) tys; + Collections.List.iter_link (F.pp_print_space fmt) (extract_rec true) + tys; if print_paren then F.pp_print_string fmt ")") | TypeVar vid -> F.pp_print_string fmt (ctx_get_type_var vid ctx) | Bool -> F.pp_print_string fmt ctx.fmt.bool_name @@ -760,11 +769,11 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) | Str -> F.pp_print_string fmt ctx.fmt.str_name | Arrow (arg_ty, ret_ty) -> if inside then F.pp_print_string fmt "("; - extract_ty ctx fmt false arg_ty; + extract_rec false arg_ty; F.pp_print_space fmt (); F.pp_print_string fmt "->"; F.pp_print_space fmt (); - extract_ty ctx fmt false ret_ty; + extract_rec false ret_ty; if inside then F.pp_print_string fmt ")" | Array _ | Slice _ -> raise Unimplemented @@ -969,7 +978,9 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) (* If Coq: print the constructor name *) (* TODO: remove superfluous test not is_rec below *) if !backend = Coq && not is_rec then ( - F.pp_print_string fmt (ctx_get_struct (AdtId def.def_id) ctx); + let with_opaque_pre = false in + F.pp_print_string fmt + (ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx); F.pp_print_string fmt " "); if !backend <> Lean then F.pp_print_string fmt "{"; F.pp_print_break fmt 1 ctx.indent_incr; @@ -1000,8 +1011,9 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) a group of mutually recursive types: we extract it as an inductive type *) assert (is_rec && !backend = Coq); - let cons_name = ctx_get_struct (AdtId def.def_id) ctx in - let def_name = ctx_get_local_type def.def_id ctx in + let with_opaque_pre = false in + let cons_name = ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx in + let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in extract_type_decl_variant ctx fmt def_name type_params cons_name fields) in () @@ -1043,10 +1055,12 @@ let extract_type_decl (ctx : extraction_ctx) (fmt : F.formatter) The boolean [is_opaque_coq] is used to detect this case. *) - let is_opaque_coq = !backend = Coq && type_kind = None in + let is_opaque = type_kind = None in + let is_opaque_coq = !backend = Coq && is_opaque in let use_forall = is_opaque_coq && def.type_params <> [] in (* Retrieve the definition name *) - let def_name = ctx_get_local_type def.def_id ctx in + let with_opaque_pre = false in + let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in (* Add the type params - note that we need those bindings only for the * body translation (they are not top-level) *) let ctx_body, type_params = ctx_add_type_params def.type_params ctx in @@ -1173,7 +1187,8 @@ let extract_type_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter) | Struct fields -> let adt_id = AdtId decl.def_id in (* Generate the instruction for the record constructor *) - let cons_name = ctx_get_struct adt_id ctx in + let with_opaque_pre = false in + let cons_name = ctx_get_struct with_opaque_pre adt_id ctx in extract_arguments_info cons_name fields; (* Generate the instruction for the record projectors, if there are *) let is_rec = decl_is_from_rec_group kind in @@ -1215,8 +1230,11 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx) let ctx, type_params = ctx_add_type_params decl.type_params ctx in let ctx, record_var = ctx_add_var "x" (VarId.of_int 0) ctx in let ctx, field_var = ctx_add_var "x" (VarId.of_int 1) ctx in - let def_name = ctx_get_local_type decl.def_id ctx in - let cons_name = ctx_get_struct (AdtId decl.def_id) ctx in + let with_opaque_pre = false in + let def_name = ctx_get_local_type with_opaque_pre decl.def_id ctx in + let cons_name = + ctx_get_struct with_opaque_pre (AdtId decl.def_id) ctx + in let extract_field_proj (field_id : FieldId.id) (_ : field) : unit = F.pp_print_space fmt (); (* Outer box for the projector definition *) @@ -1500,12 +1518,16 @@ let extract_adt_g_value * [{ field0=...; ...; fieldn=...; }] in case of structures. *) let cons = + (* The ADT shouldn't be opaque *) + let with_opaque_pre = false in match variant_id with | Some vid -> if !backend = Lean then - ctx_get_type adt_id ctx ^ "." ^ ctx_get_variant adt_id vid ctx + ctx_get_type with_opaque_pre adt_id ctx + ^ "." + ^ ctx_get_variant adt_id vid ctx else ctx_get_variant adt_id vid ctx - | None -> ctx_get_struct adt_id ctx + | None -> ctx_get_struct with_opaque_pre adt_id ctx in if inside && field_values <> [] then F.pp_print_string fmt "("; F.pp_print_string fmt cons; @@ -1523,7 +1545,8 @@ let extract_adt_g_value (* Extract globals in the same way as variables *) let extract_global (ctx : extraction_ctx) (fmt : F.formatter) (id : A.GlobalDeclId.id) : unit = - F.pp_print_string fmt (ctx_get_global id ctx) + let with_opaque_pre = ctx.use_opaque_pre in + F.pp_print_string fmt (ctx_get_global with_opaque_pre id ctx) (** [inside]: see {!extract_ty}. @@ -1643,11 +1666,8 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) (* Open a box for the function call *) F.pp_open_hovbox fmt ctx.indent_incr; (* Print the function name *) - let fun_name = - Option.value - ~default:(ctx_get_function fun_id ctx) - (ctx_maybe_get (DeclaredId fun_id) ctx) - in + let with_opaque_pre = ctx.use_opaque_pre in + let fun_name = ctx_get_function with_opaque_pre fun_id ctx in F.pp_print_string fmt fun_name; (* Print the type parameters *) List.iter @@ -1703,14 +1723,16 @@ and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) * applied structure constructors. *) let cons = + (* The ADT shouldn't be opaque *) + let with_opaque_pre = false in match adt_cons.variant_id with | Some vid -> if !backend = Lean then - ctx_get_type adt_cons.adt_id ctx + ctx_get_type with_opaque_pre adt_cons.adt_id ctx ^ "." ^ ctx_get_variant adt_cons.adt_id vid ctx else ctx_get_variant adt_cons.adt_id vid ctx - | None -> ctx_get_struct adt_cons.adt_id ctx + | None -> ctx_get_struct with_opaque_pre adt_cons.adt_id ctx in let is_lean_struct = !backend = Lean && adt_cons.variant_id = None in if is_lean_struct then ( @@ -2309,8 +2331,10 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) (kind : decl_kind) (has_decreases_clause : bool) (def : fun_decl) : unit = assert (not def.is_global_decl_body); (* Retrieve the function name *) + let with_opaque_pre = false in let def_name = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx + ctx_get_local_function with_opaque_pre def.def_id def.loop_id def.back_id + ctx in (* Add a break before *) F.pp_print_break fmt 0 0; @@ -2612,9 +2636,12 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) extract_comment fmt ("[" ^ Print.global_name_to_string global.name ^ "]"); F.pp_print_space fmt (); - let decl_name = ctx_get_global global.def_id ctx in + let with_opaque_pre = false in + let decl_name = ctx_get_global with_opaque_pre global.def_id ctx in let body_name = - ctx_get_function (FromLlbc (Regular global.body_id, None, None)) ctx + ctx_get_function with_opaque_pre + (FromLlbc (Regular global.body_id, None, None)) + ctx in let decl_ty, body_ty = @@ -2685,8 +2712,12 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "assert_norm"; F.pp_print_space fmt (); F.pp_print_string fmt "("; + (* Note that if the function is opaque, the unit test will fail + because the normalizer will get stuck *) + let with_opaque_pre = ctx.use_opaque_pre in let fun_name = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx + ctx_get_local_function with_opaque_pre def.def_id def.loop_id + def.back_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( @@ -2701,8 +2732,12 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "Check"; F.pp_print_space fmt (); F.pp_print_string fmt "("; + (* Note that if the function is opaque, the unit test will fail + because the normalizer will get stuck *) + let with_opaque_pre = ctx.use_opaque_pre in let fun_name = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx + ctx_get_local_function with_opaque_pre def.def_id def.loop_id + def.back_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( @@ -2714,8 +2749,12 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "#assert"; F.pp_print_space fmt (); F.pp_print_string fmt "("; + (* Note that if the function is opaque, the unit test will fail + because the normalizer will get stuck *) + let with_opaque_pre = ctx.use_opaque_pre in let fun_name = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx + ctx_get_local_function with_opaque_pre def.def_id def.loop_id + def.back_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 55963289..86bb0cff 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -221,6 +221,35 @@ type formatter = { the same purpose as in {!field:fun_name}. - loop identifier, if this is for a loop *) + opaque_pre : unit -> string; + (** The prefix to use for opaque definitions. + + We need this because for some backends like Lean and Coq, we group + opaque definitions in module signatures, meaning that using those + definitions requires to prefix them with a module parameter name (such + as "opaque_defs."). + + For instance, if we have an opaque function [f : int -> int], which + is used by the non-opaque function [g], we would generate (in Coq): + {[ + (* The module signature declaring the opaque definitions *) + module type OpaqueDefs = { + f_fwd : int -> int + ... (* Other definitions *) + } + + (* The definitions generated for the non-opaque definitions *) + module Funs (opaque: OpaqueDefs) = { + let g ... = + ... + opaque_defs.f_fwd + ... + } + ]} + + Upon using [f] in [g], we don't directly use the the name "f_fwd", + but prefix it with the "opaque_defs." identifier. + *) var_basename : StringSet.t -> string option -> ty -> string; (** Generates a variable basename. @@ -297,7 +326,6 @@ type formatter = { type id = | GlobalId of A.GlobalDeclId.id | FunId of fun_id - | DeclaredId of fun_id | TerminationMeasureId of (A.fun_id * LoopId.id option) (** The definition which provides the decreases/termination measure. We insert calls to this clause to prove/reason about termination: @@ -362,6 +390,7 @@ module IdOrderedType = struct end module IdMap = Collections.MakeMap (IdOrderedType) +module IdSet = Collections.MakeSet (IdOrderedType) (** The names map stores the mappings from names to identifiers and vice-versa. @@ -377,10 +406,22 @@ type names_map = { precisely which identifiers are mapped to the same name... *) names_set : StringSet.t; + opaque_ids : IdSet.t; + (** The set of opaque definitions. + + See {!formatter.opaque_pre} for detailed explanations about why + we need to know which definitions are opaque to compute names. + + Also note that the opaque ids don't contain the ids of the assumed + definitions. In practice, assumed definitions are opaque_defs. However, they + are not grouped in the opaque module, meaning we never need to + prefix them (with, say, "opaque_defs."): we thus consider them as non-opaque + with regards to the names map. + *) } -let names_map_add (id_to_string : id -> string) (id : id) (name : string) - (nm : names_map) : names_map = +let names_map_add (id_to_string : id -> string) (is_opaque : bool) (id : id) + (name : string) (nm : names_map) : names_map = (* Check if there is a clash *) (match StringMap.find_opt name nm.name_to_id with | None -> () (* Ok *) @@ -400,24 +441,32 @@ let names_map_add (id_to_string : id -> string) (id : id) (name : string) let id_to_name = IdMap.add id name nm.id_to_name in let name_to_id = StringMap.add name id nm.name_to_id in let names_set = StringSet.add name nm.names_set in - { id_to_name; name_to_id; names_set } + let opaque_ids = + if is_opaque then IdSet.add id nm.opaque_ids else nm.opaque_ids + in + { id_to_name; name_to_id; names_set; opaque_ids } let names_map_add_assumed_type (id_to_string : id -> string) (id : assumed_ty) (name : string) (nm : names_map) : names_map = - names_map_add id_to_string (TypeId (Assumed id)) name nm + let is_opaque = false in + names_map_add id_to_string is_opaque (TypeId (Assumed id)) name nm let names_map_add_assumed_struct (id_to_string : id -> string) (id : assumed_ty) (name : string) (nm : names_map) : names_map = - names_map_add id_to_string (StructId (Assumed id)) name nm + let is_opaque = false in + names_map_add id_to_string is_opaque (StructId (Assumed id)) name nm let names_map_add_assumed_variant (id_to_string : id -> string) (id : assumed_ty) (variant_id : VariantId.id) (name : string) (nm : names_map) : names_map = - names_map_add id_to_string (VariantId (Assumed id, variant_id)) name nm + let is_opaque = false in + names_map_add id_to_string is_opaque + (VariantId (Assumed id, variant_id)) + name nm -let names_map_add_function (id_to_string : id -> string) (fid : fun_id) - (name : string) (nm : names_map) : names_map = - names_map_add id_to_string (FunId fid) name nm +let names_map_add_function (id_to_string : id -> string) (is_opaque : bool) + (fid : fun_id) (name : string) (nm : names_map) : names_map = + names_map_add id_to_string is_opaque (FunId fid) name nm (** Make a (variable) basename unique (by adding an index). @@ -464,6 +513,14 @@ type extraction_ctx = { fmt : formatter; indent_incr : int; (** The indent increment we insert whenever we need to indent more *) + use_opaque_pre : bool; + (** Do we use the "opaque_defs." prefix for the opaque definitions? + + Opaque function definitions might refer opaque types: if we are in the + opaque module, we musn't use the "opaque_defs." prefix, otherwise we + use it. + Also see {!names_map.opaque_ids}. + *) } (** Debugging function, used when communicating name collisions to the user, @@ -487,7 +544,7 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = | GlobalId gid -> let name = (A.GlobalDeclId.Map.find gid global_decls).name in "global name: " ^ Print.global_name_to_string name - | DeclaredId fid | FunId fid -> ( + | FunId fid -> ( match fid with | FromLlbc (fid, lp_id, rg_id) -> let fun_name = @@ -592,76 +649,96 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = | TypeVarId id -> "type_var_id: " ^ TypeVarId.to_string id | VarId id -> "var_id: " ^ VarId.to_string id -let ctx_add (id : id) (name : string) (ctx : extraction_ctx) : extraction_ctx = +let ctx_add (is_opaque : bool) (id : id) (name : string) (ctx : extraction_ctx) + : extraction_ctx = (* The id_to_string function to print nice debugging messages if there are * collisions *) let id_to_string (id : id) : string = id_to_string id ctx in - let names_map = names_map_add id_to_string id name ctx.names_map in + let names_map = names_map_add id_to_string is_opaque id name ctx.names_map in { ctx with names_map } -let ctx_maybe_get (id : id) (ctx : extraction_ctx) : string option = - IdMap.find_opt id ctx.names_map.id_to_name - -let ctx_get (id : id) (ctx : extraction_ctx) : string = - match ctx_maybe_get id ctx with - | Some s -> s +(** [with_opaque_pre]: if [true] and the definition is opaque, add the opaque prefix *) +let ctx_get (with_opaque_pre : bool) (id : id) (ctx : extraction_ctx) : string = + match IdMap.find_opt id ctx.names_map.id_to_name with + | Some s -> + let is_opaque = IdSet.mem id ctx.names_map.opaque_ids in + if with_opaque_pre && is_opaque then ctx.fmt.opaque_pre () ^ s else s | None -> log#serror ("Could not find: " ^ id_to_string id ctx); raise Not_found -let ctx_get_global (id : A.GlobalDeclId.id) (ctx : extraction_ctx) : string = - ctx_get (GlobalId id) ctx +let ctx_get_global (with_opaque_pre : bool) (id : A.GlobalDeclId.id) + (ctx : extraction_ctx) : string = + ctx_get with_opaque_pre (GlobalId id) ctx -let ctx_get_function (id : fun_id) (ctx : extraction_ctx) : string = - ctx_get (FunId id) ctx +let ctx_get_function (with_opaque_pre : bool) (id : fun_id) + (ctx : extraction_ctx) : string = + ctx_get with_opaque_pre (FunId id) ctx -let ctx_get_local_function (id : A.FunDeclId.id) (lp : LoopId.id option) - (rg : RegionGroupId.id option) (ctx : extraction_ctx) : string = - ctx_get_function (FromLlbc (Regular id, lp, rg)) ctx +let ctx_get_local_function (with_opaque_pre : bool) (id : A.FunDeclId.id) + (lp : LoopId.id option) (rg : RegionGroupId.id option) + (ctx : extraction_ctx) : string = + ctx_get_function with_opaque_pre (FromLlbc (Regular id, lp, rg)) ctx -let ctx_get_type (id : type_id) (ctx : extraction_ctx) : string = +let ctx_get_type (with_opaque_pre : bool) (id : type_id) (ctx : extraction_ctx) + : string = assert (id <> Tuple); - ctx_get (TypeId id) ctx + ctx_get with_opaque_pre (TypeId id) ctx -let ctx_get_local_type (id : TypeDeclId.id) (ctx : extraction_ctx) : string = - ctx_get_type (AdtId id) ctx +let ctx_get_local_type (with_opaque_pre : bool) (id : TypeDeclId.id) + (ctx : extraction_ctx) : string = + ctx_get_type with_opaque_pre (AdtId id) ctx let ctx_get_assumed_type (id : assumed_ty) (ctx : extraction_ctx) : string = - ctx_get_type (Assumed id) ctx + (* In practice, the assumed types are opaque. However, assumed types + are never grouped in the opaque module, meaning we never need to + prefix them: we thus consider them as non-opaque with regards to the + names map. + *) + let is_opaque = false in + ctx_get_type is_opaque (Assumed id) ctx let ctx_get_var (id : VarId.id) (ctx : extraction_ctx) : string = - ctx_get (VarId id) ctx + let is_opaque = false in + ctx_get is_opaque (VarId id) ctx let ctx_get_type_var (id : TypeVarId.id) (ctx : extraction_ctx) : string = - ctx_get (TypeVarId id) ctx + let is_opaque = false in + ctx_get is_opaque (TypeVarId id) ctx let ctx_get_field (type_id : type_id) (field_id : FieldId.id) (ctx : extraction_ctx) : string = - ctx_get (FieldId (type_id, field_id)) ctx + let is_opaque = false in + ctx_get is_opaque (FieldId (type_id, field_id)) ctx -let ctx_get_struct (def_id : type_id) (ctx : extraction_ctx) : string = - ctx_get (StructId def_id) ctx +let ctx_get_struct (with_opaque_pre : bool) (def_id : type_id) + (ctx : extraction_ctx) : string = + ctx_get with_opaque_pre (StructId def_id) ctx let ctx_get_variant (def_id : type_id) (variant_id : VariantId.id) (ctx : extraction_ctx) : string = - ctx_get (VariantId (def_id, variant_id)) ctx + let is_opaque = false in + ctx_get is_opaque (VariantId (def_id, variant_id)) ctx let ctx_get_decreases_proof (def_id : A.FunDeclId.id) (loop_id : LoopId.id option) (ctx : extraction_ctx) : string = - ctx_get (DecreasesProofId (Regular def_id, loop_id)) ctx + let is_opaque = false in + ctx_get is_opaque (DecreasesProofId (Regular def_id, loop_id)) ctx let ctx_get_termination_measure (def_id : A.FunDeclId.id) (loop_id : LoopId.id option) (ctx : extraction_ctx) : string = - ctx_get (TerminationMeasureId (Regular def_id, loop_id)) ctx + let is_opaque = false in + ctx_get is_opaque (TerminationMeasureId (Regular def_id, loop_id)) ctx (** Generate a unique type variable name and add it to the context *) let ctx_add_type_var (basename : string) (id : TypeVarId.id) (ctx : extraction_ctx) : extraction_ctx * string = + let is_opaque = false in let name = ctx.fmt.type_var_basename ctx.names_map.names_set basename in let name = basename_to_unique ctx.names_map.names_set ctx.fmt.append_index name in - let ctx = ctx_add (TypeVarId id) name ctx in + let ctx = ctx_add is_opaque (TypeVarId id) name ctx in (ctx, name) (** See {!ctx_add_type_var} *) @@ -674,10 +751,11 @@ let ctx_add_type_vars (vars : (string * TypeVarId.id) list) (** Generate a unique variable name and add it to the context *) let ctx_add_var (basename : string) (id : VarId.id) (ctx : extraction_ctx) : extraction_ctx * string = + let is_opaque = false in let name = basename_to_unique ctx.names_map.names_set ctx.fmt.append_index basename in - let ctx = ctx_add (VarId id) name ctx in + let ctx = ctx_add is_opaque (VarId id) name ctx in (ctx, name) (** See {!ctx_add_var} *) @@ -697,20 +775,24 @@ let ctx_add_type_params (vars : type_var list) (ctx : extraction_ctx) : let ctx_add_type_decl_struct (def : type_decl) (ctx : extraction_ctx) : extraction_ctx * string = + assert (match def.kind with Struct _ -> true | _ -> false); + let is_opaque = false in let cons_name = ctx.fmt.struct_constructor def.name in - let ctx = ctx_add (StructId (AdtId def.def_id)) cons_name ctx in + let ctx = ctx_add is_opaque (StructId (AdtId def.def_id)) cons_name ctx in (ctx, cons_name) let ctx_add_type_decl (def : type_decl) (ctx : extraction_ctx) : extraction_ctx = + let is_opaque = def.kind = Opaque in let def_name = ctx.fmt.type_name def.name in - let ctx = ctx_add (TypeId (AdtId def.def_id)) def_name ctx in + let ctx = ctx_add is_opaque (TypeId (AdtId def.def_id)) def_name ctx in ctx let ctx_add_field (def : type_decl) (field_id : FieldId.id) (field : field) (ctx : extraction_ctx) : extraction_ctx * string = + let is_opaque = false in let name = ctx.fmt.field_name def.name field_id field.field_name in - let ctx = ctx_add (FieldId (AdtId def.def_id, field_id)) name ctx in + let ctx = ctx_add is_opaque (FieldId (AdtId def.def_id, field_id)) name ctx in (ctx, name) let ctx_add_fields (def : type_decl) (fields : (FieldId.id * field) list) @@ -721,8 +803,11 @@ let ctx_add_fields (def : type_decl) (fields : (FieldId.id * field) list) let ctx_add_variant (def : type_decl) (variant_id : VariantId.id) (variant : variant) (ctx : extraction_ctx) : extraction_ctx * string = + let is_opaque = false in let name = ctx.fmt.variant_name def.name variant.variant_name in - let ctx = ctx_add (VariantId (AdtId def.def_id, variant_id)) name ctx in + let ctx = + ctx_add is_opaque (VariantId (AdtId def.def_id, variant_id)) name ctx + in (ctx, name) let ctx_add_variants (def : type_decl) @@ -734,33 +819,43 @@ let ctx_add_variants (def : type_decl) let ctx_add_struct (def : type_decl) (ctx : extraction_ctx) : extraction_ctx * string = + assert (match def.kind with Struct _ -> true | _ -> false); + let is_opaque = false in let name = ctx.fmt.struct_constructor def.name in - let ctx = ctx_add (StructId (AdtId def.def_id)) name ctx in + let ctx = ctx_add is_opaque (StructId (AdtId def.def_id)) name ctx in (ctx, name) let ctx_add_decreases_proof (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx = + let is_opaque = false in let name = ctx.fmt.decreases_proof_name def.def_id def.basename def.num_loops def.loop_id in - ctx_add (DecreasesProofId (Regular def.def_id, def.loop_id)) name ctx + ctx_add is_opaque + (DecreasesProofId (Regular def.def_id, def.loop_id)) + name ctx let ctx_add_termination_measure (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx = + let is_opaque = false in let name = ctx.fmt.termination_measure_name def.def_id def.basename def.num_loops def.loop_id in - ctx_add (TerminationMeasureId (Regular def.def_id, def.loop_id)) name ctx + ctx_add is_opaque + (TerminationMeasureId (Regular def.def_id, def.loop_id)) + name ctx let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) : extraction_ctx = + (* TODO: update once the body id can be an option *) + let is_opaque = false in let name = ctx.fmt.global_name def.name in let decl = GlobalId def.def_id in let body = FunId (FromLlbc (Regular def.body_id, None, None)) in - let ctx = ctx_add decl (name ^ "_c") ctx in - let ctx = ctx_add body (name ^ "_body") ctx in + let ctx = ctx_add is_opaque decl (name ^ "_c") ctx in + let ctx = ctx_add is_opaque body (name ^ "_body") ctx in ctx let ctx_add_fun_decl (trans_group : bool * pure_fun_translation) @@ -792,20 +887,15 @@ let ctx_add_fun_decl (trans_group : bool * pure_fun_translation) in Some { id = rg_id; region_names } in - let name = + let is_opaque = def.body = None in + (* Add the function name *) + let def_name = ctx.fmt.fun_name def.basename def.num_loops def.loop_id num_rgs rg_info (keep_fwd, num_backs) in - let ctx = - if def.body = None && !Config.backend = Lean then - ctx_add - (DeclaredId (FromLlbc (A.Regular def_id, def.loop_id, def.back_id))) - ("opaque_defs." ^ name) ctx - else ctx - in - ctx_add + ctx_add is_opaque (FunId (FromLlbc (A.Regular def_id, def.loop_id, def.back_id))) - name ctx + def_name ctx type names_map_init = { keywords : string list; @@ -831,11 +921,12 @@ let initialize_names_map (fmt : formatter) (init : names_map_init) : names_map = let name_to_id = StringMap.of_list (List.map (fun x -> (x, UnknownId)) keywords) in + let opaque_ids = IdSet.empty in (* We fist initialize [id_to_name] as empty, because the id of a keyword is [UnknownId]. * Also note that we don't need this mapping for keywords: we insert keywords only * to check collisions. *) let id_to_name = IdMap.empty in - let nm = { id_to_name; name_to_id; names_set } in + let nm = { id_to_name; name_to_id; names_set; opaque_ids } in (* For debugging - we are creating bindings for assumed types and functions, so * it is ok if we simply use the "show" function (those aren't simply identified * by numbers) *) @@ -871,8 +962,15 @@ let initialize_names_map (fmt : formatter) (init : names_map_init) : names_map = @ List.map (fun (fid, name) -> (Pure fid, name)) init.assumed_pure_functions in let nm = + (* In practice, the assumed function are opaque. However, assumed functions + are never grouped in the opaque module, meaning we never need to + prefix them: we thus consider them as non-opaque with regards to the + names map. + *) + let is_opaque = false in List.fold_left - (fun nm (fid, name) -> names_map_add_function id_to_string fid name nm) + (fun nm (fid, name) -> + names_map_add_function id_to_string is_opaque fid name nm) nm assumed_functions in (* Return *) diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 139f8891..7ee86b28 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -350,7 +350,14 @@ type gen_config = { extract_transparent : bool; (** If [true], extract the transparent declarations, otherwise ignore. *) extract_opaque : bool; - (** If [true], extract the opaque declarations, otherwise ignore. *) + (** If [true], extract the opaque declarations, otherwise ignore. + + For now, this controls only the opaque *functions*, not the opaque + globals or types. + TODO: update this. This is not trivial if we want to extract the opaque + types in an opaque module, because some non-opaque types may refer + to opaque types and vice-versa. + *) extract_state_type : bool; (** If [true], generate a definition/declaration for the state type *) extract_globals : bool; @@ -787,7 +794,15 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : mk_formatter_and_names_map trans_ctx crate.name variant_concatenate_type_name in - let ctx = { ExtractBase.trans_ctx; names_map; fmt; indent_incr = 2 } in + let ctx = + { + ExtractBase.trans_ctx; + names_map; + fmt; + indent_incr = 2; + use_opaque_pre = !Config.split_files; + } + in (* We need to compute which functions are recursive, in order to know * whether we should generate a decrease clause or not. *) @@ -1035,6 +1050,12 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : interface = true; } in + let gen_ctx = + { + gen_ctx with + extract_ctx = { gen_ctx.extract_ctx with use_opaque_pre = false }; + } + in extract_file opaque_config gen_ctx opaque_filename crate.A.name opaque_module ": opaque function definitions" [] [ types_module ]; [ opaque_module ]) -- cgit v1.2.3