diff options
author | Son HO | 2023-07-31 16:15:58 +0200 |
---|---|---|
committer | GitHub | 2023-07-31 16:15:58 +0200 |
commit | 887d0ef1efc8912c6273b5ebcf979384e9d7fa97 (patch) | |
tree | 92d6021eb549f7cc25501856edd58859786b7e90 /compiler | |
parent | 53adf30fe440eb8b6f58ba89f4a4c0acc7877498 (diff) | |
parent | 9b3a58e423333fc9a4a5a264c3beb0a3d951e86b (diff) |
Merge pull request #31 from AeneasVerif/son_lean_backend
Improve the Lean backend
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/Config.ml | 23 | ||||
-rw-r--r-- | compiler/Driver.ml | 10 | ||||
-rw-r--r-- | compiler/Extract.ml | 353 | ||||
-rw-r--r-- | compiler/ExtractBase.ml | 120 | ||||
-rw-r--r-- | compiler/Pure.ml | 10 | ||||
-rw-r--r-- | compiler/Translate.ml | 295 |
6 files changed, 569 insertions, 242 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml index ce9b0e0c..0475899c 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -162,6 +162,11 @@ let backward_no_state_update = ref false *) let split_files = ref true +(** For Lean, controls whether we generate a lakefile or not. + + *) +let lean_gen_lakefile = ref false + (** If true, treat the unit functions (function taking no inputs and returning no outputs) as unit tests: evaluate them with the interpreter and check that they don't panic. @@ -292,3 +297,21 @@ let filter_useless_monadic_calls = ref true dynamically check for that). *) let filter_useless_functions = ref true + +(** Obsolete. TODO: remove. + + For Lean we used to parameterize the entire development by a section variable + called opaque_defs, of type OpaqueDefs. + *) +let wrap_opaque_in_sig = ref false + +(** Use short names for the record fields. + + Some backends can't disambiguate records when their field names have collisions. + When this happens, we use long names, by which we concatenate the record + names with the field names, and check whether there are name collisions. + + For backends which can disambiguate records (typically by using the typing + information), we use short names (i.e., the original field names). + *) +let record_fields_short_names = ref false diff --git a/compiler/Driver.ml b/compiler/Driver.ml index 2ff9e295..166ef11b 100644 --- a/compiler/Driver.ml +++ b/compiler/Driver.ml @@ -107,6 +107,9 @@ let () = Arg.Clear check_invariants, " Deactivate the invariant sanity checks performed at every evaluation \ step. Dramatically increases speed." ); + ( "-lean-default-lakefile", + Arg.Clear lean_gen_lakefile, + " Generate a default lakefile.lean (Lean only)" ); ] in @@ -130,6 +133,9 @@ let () = (not !use_fuel) || (not !extract_decreases_clauses) && not !extract_template_decreases_clauses); + if !lean_gen_lakefile && not (!backend = Lean) then + log#error + "The -lean-default-lakefile option is valid only for the Lean backend"; (* Check that the user specified a backend *) let _ = @@ -157,7 +163,9 @@ let () = (* We don't support fuel for the Lean backend *) if !use_fuel then ( log#error "The Lean backend doesn't support the -use-fuel option"; - fail ()) + fail ()); + (* Lean can disambiguate the field names *) + record_fields_short_names := true | HOL4 -> (* We don't support fuel for the HOL4 backend *) if !use_fuel then ( diff --git a/compiler/Extract.ml b/compiler/Extract.ml index d624d9ca..b16f9639 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -300,23 +300,40 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list = (Option, option_none_id, "NONE"); ] -let assumed_llbc_functions : +let assumed_llbc_functions () : (A.assumed_fun_id * T.RegionGroupId.id option * string) list = let rg0 = Some T.RegionGroupId.zero in - [ - (Replace, None, "mem_replace_fwd"); - (Replace, rg0, "mem_replace_back"); - (VecNew, None, "vec_new"); - (VecPush, None, "vec_push_fwd") (* Shouldn't be used *); - (VecPush, rg0, "vec_push_back"); - (VecInsert, None, "vec_insert_fwd") (* Shouldn't be used *); - (VecInsert, rg0, "vec_insert_back"); - (VecLen, None, "vec_len"); - (VecIndex, None, "vec_index_fwd"); - (VecIndex, rg0, "vec_index_back") (* shouldn't be used *); - (VecIndexMut, None, "vec_index_mut_fwd"); - (VecIndexMut, rg0, "vec_index_mut_back"); - ] + match !backend with + | FStar | Coq | HOL4 -> + [ + (Replace, None, "mem_replace_fwd"); + (Replace, rg0, "mem_replace_back"); + (VecNew, None, "vec_new"); + (VecPush, None, "vec_push_fwd") (* Shouldn't be used *); + (VecPush, rg0, "vec_push_back"); + (VecInsert, None, "vec_insert_fwd") (* Shouldn't be used *); + (VecInsert, rg0, "vec_insert_back"); + (VecLen, None, "vec_len"); + (VecIndex, None, "vec_index_fwd"); + (VecIndex, rg0, "vec_index_back") (* shouldn't be used *); + (VecIndexMut, None, "vec_index_mut_fwd"); + (VecIndexMut, rg0, "vec_index_mut_back"); + ] + | Lean -> + [ + (Replace, None, "mem.replace"); + (Replace, rg0, "mem.replace_back"); + (VecNew, None, "Vec.new"); + (VecPush, None, "Vec.push_fwd") (* Shouldn't be used *); + (VecPush, rg0, "Vec.push"); + (VecInsert, None, "Vec.insert_fwd") (* Shouldn't be used *); + (VecInsert, rg0, "Vec.insert"); + (VecLen, None, "Vec.len"); + (VecIndex, None, "Vec.index"); + (VecIndex, rg0, "Vec.index_back") (* shouldn't be used *); + (VecIndexMut, None, "Vec.index_mut"); + (VecIndexMut, rg0, "Vec.index_mut_back"); + ] let assumed_pure_functions () : (pure_assumed_fun_id * string) list = match !backend with @@ -344,7 +361,7 @@ let names_map_init () : names_map_init = assumed_adts = assumed_adts (); assumed_structs; assumed_variants = assumed_variants (); - assumed_llbc_functions; + assumed_llbc_functions = assumed_llbc_functions (); assumed_pure_functions = assumed_pure_functions (); } @@ -505,10 +522,10 @@ let fun_decl_kind_to_qualif (kind : decl_kind) : string option = | Lean -> ( match kind with | SingleNonRec -> Some "def" - | SingleRec -> Some "def" - | MutRecFirst -> Some "mutual def" - | MutRecInner -> Some "def" - | MutRecLast -> Some "def" + | SingleRec -> Some "divergent def" + | MutRecFirst -> Some "mutual divergent def" + | MutRecInner -> Some "divergent def" + | MutRecLast -> Some "divergent def" | Assumed -> Some "axiom" | Declared -> Some "axiom") | HOL4 -> None @@ -601,34 +618,48 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) | FStar | Lean | HOL4 -> name | Coq -> capitalize_first_letter name in - let type_name name = type_name_to_snake_case name ^ "_t" in + let type_name name = + match !backend with + | FStar | Coq | HOL4 -> type_name_to_snake_case name ^ "_t" + | Lean -> String.concat "." (get_type_name name) + in let field_name (def_name : name) (field_id : FieldId.id) (field_name : string option) : string = - let def_name = type_name_to_snake_case def_name ^ "_" in - match field_name with - | Some field_name -> def_name ^ field_name - | None -> def_name ^ FieldId.to_string field_id + let field_name = + match field_name with + | Some field_name -> field_name + | None -> FieldId.to_string field_id + in + if !Config.record_fields_short_names then field_name + else + let def_name = type_name_to_snake_case def_name ^ "_" in + def_name ^ field_name in let variant_name (def_name : name) (variant : string) : string = - let variant = to_camel_case variant in - if variant_concatenate_type_name then - type_name_to_camel_case def_name ^ variant - else variant + match !backend with + | FStar | Coq | HOL4 -> + let variant = to_camel_case variant in + if variant_concatenate_type_name then + type_name_to_camel_case def_name ^ variant + else variant + | Lean -> variant in let struct_constructor (basename : name) : string = let tname = type_name basename in let prefix = - match !backend with FStar -> "Mk" | Lean | Coq | HOL4 -> "mk" + match !backend with FStar -> "Mk" | Coq | HOL4 -> "mk" | Lean -> "" + in + let suffix = + match !backend with FStar | Coq | HOL4 -> "" | Lean -> ".mk" in - prefix ^ tname + prefix ^ tname ^ suffix in - let get_fun_name = get_name in - let fun_name_to_snake_case (fname : fun_name) : string = - let fname = get_fun_name fname in - (* Converting to snake case should be a no-op, but it doesn't cost much *) - let fname = List.map to_snake_case fname in - (* Concatenate the elements *) - String.concat "_" fname + let get_fun_name fname = + let fname = get_name fname in + (* TODO: don't convert to snake case for Coq, HOL4, F* *) + match !backend with + | FStar | Coq | HOL4 -> String.concat "_" (List.map to_snake_case fname) + | Lean -> String.concat "." fname in let global_name (name : global_name) : string = (* Converting to snake case also lowercases the letters (in Rust, global @@ -639,7 +670,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) let fun_name (fname : fun_name) (num_loops : int) (loop_id : LoopId.id option) (num_rgs : int) (rg : region_group_info option) (filter_info : bool * int) : string = - let fname = fun_name_to_snake_case fname in + let fname = get_fun_name fname in (* Compute the suffix *) let suffix = default_fun_suffix num_loops loop_id num_rgs rg filter_info in (* Concatenate *) @@ -648,7 +679,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) let termination_measure_name (_fid : A.FunDeclId.id) (fname : fun_name) (num_loops : int) (loop_id : LoopId.id option) : string = - let fname = fun_name_to_snake_case fname in + let fname = get_fun_name fname in let lp_suffix = default_fun_loop_suffix num_loops loop_id in (* Compute the suffix *) let suffix = @@ -663,7 +694,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) let decreases_proof_name (_fid : A.FunDeclId.id) (fname : fun_name) (num_loops : int) (loop_id : LoopId.id option) : string = - let fname = fun_name_to_snake_case fname in + let fname = get_fun_name fname in let lp_suffix = default_fun_loop_suffix num_loops loop_id in (* Compute the suffix *) let suffix = @@ -678,7 +709,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) let opaque_pre () = match !Config.backend with | FStar | Coq | HOL4 -> "" - | Lean -> "opaque_defs." + | Lean -> if !Config.wrap_opaque_in_sig then "opaque_defs." else "" in let var_basename (_varset : StringSet.t) (basename : string option) (ty : ty) @@ -789,7 +820,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) F.pp_print_string fmt ")"; F.pp_print_string fmt ")") else Z.pp_print fmt sv.value; - F.pp_print_string fmt " (by intlit))") + F.pp_print_string fmt ")") | Bool b -> let b = match !backend with @@ -1007,6 +1038,11 @@ let end_type_decl_group (fmt : F.formatter) (is_rec : bool) let unit_name () = match !backend with Lean -> "Unit" | Coq | FStar | HOL4 -> "unit" +(** Small helper *) +let extract_arrow (fmt : F.formatter) () : unit = + if !Config.backend = Lean then F.pp_print_string fmt "→" + else F.pp_print_string fmt "->" + (** [inside] constrols whether we should add parentheses or not around type applications (if [true] we add parentheses). @@ -1100,7 +1136,7 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) if inside then F.pp_print_string fmt "("; extract_rec false arg_ty; F.pp_print_space fmt (); - F.pp_print_string fmt "->"; + extract_arrow fmt (); F.pp_print_space fmt (); extract_rec false ret_ty; if inside then F.pp_print_string fmt ")" @@ -1188,7 +1224,7 @@ let extract_type_decl_variant (ctx : extraction_ctx) (fmt : F.formatter) (* Print the arrow [->] *) if !backend <> HOL4 then ( F.pp_print_space fmt (); - F.pp_print_string fmt "->"); + extract_arrow fmt ()); (* Close the field box *) F.pp_close_box fmt (); (* Return *) @@ -1326,7 +1362,8 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt (unit_name ())) else if !backend = Lean && fields = [] then () (* If the definition is recursive, we may need to extract it as an inductive - (instead of a record) *) + (instead of a record). We start with the "normal" case: we extract it + as a record. *) else if (not is_rec) || (!backend <> Coq && !backend <> Lean) then ( if !backend <> Lean then F.pp_print_space fmt (); (* If Coq: print the constructor name *) @@ -1379,7 +1416,14 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) a group of mutually recursive types: we extract it as an inductive type *) assert (is_rec && (!backend = Coq || !backend = Lean)); let with_opaque_pre = false in - let cons_name = ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx in + (* Small trick: in Lean we use namespaces, meaning we don't need to prefix + the constructor name with the name of the type at definition site, + i.e., instead of generating `inductive Foo := | MkFoo ...` like in Coq + we generate `inductive Foo := | mk ... *) + let cons_name = + if !backend = Lean then "mk" + else ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx + in let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in extract_type_decl_variant ctx fmt type_decl_group def_name type_params cons_name fields) @@ -1387,16 +1431,26 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) () (** Extract a nestable, muti-line comment *) -let extract_comment (fmt : F.formatter) (s : string) : unit = - match !backend with - | Coq | FStar | HOL4 -> - F.pp_print_string fmt "(** "; - F.pp_print_string fmt s; - F.pp_print_string fmt " *)" - | Lean -> - F.pp_print_string fmt "/- "; +let extract_comment (fmt : F.formatter) (sl : string list) : unit = + (* Delimiters, space after we break a line *) + let ld, space, rd = + match !backend with + | Coq | FStar | HOL4 -> ("(** ", 4, " *)") + | Lean -> ("/- ", 3, " -/") + in + F.pp_open_vbox fmt space; + F.pp_print_string fmt ld; + (match sl with + | [] -> () + | s :: sl -> F.pp_print_string fmt s; - F.pp_print_string fmt " -/" + List.iter + (fun s -> + F.pp_print_space fmt (); + F.pp_print_string fmt s) + sl); + F.pp_print_string fmt rd; + F.pp_close_box fmt () (** Extract a type declaration. @@ -1436,7 +1490,7 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) if !backend <> HOL4 || not (decl_is_first_from_group kind) then F.pp_print_break fmt 0 0; (* Print a comment to link the extracted type to its original rust definition *) - extract_comment fmt ("[" ^ Print.name_to_string def.name ^ "]"); + extract_comment fmt [ "[" ^ Print.name_to_string def.name ^ "]" ]; F.pp_print_break fmt 0 0; (* Open a box for the definition, so that whenever possible it gets printed on * one line. Note however that in the case of Lean line breaks are important @@ -1833,7 +1887,7 @@ let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx) (* Add a break before *) F.pp_print_break fmt 0 0; (* Print a comment *) - extract_comment fmt "The state type used in the state-error monad"; + extract_comment fmt [ "The state type used in the state-error monad" ]; F.pp_print_break fmt 0 0; (* Open a box for the definition, so that whenever possible it gets printed on * one line *) @@ -1950,14 +2004,17 @@ let extract_global_decl_register_names (ctx : extraction_ctx) Note that patterns can introduce new variables: we thus return an extraction context updated with new bindings. + [is_single_pat]: are we extracting a single pattern (a pattern for a let-binding + or a lambda). + TODO: we don't need something very generic anymore (some definitions used to be polymorphic). *) let extract_adt_g_value (extract_value : extraction_ctx -> bool -> 'v -> extraction_ctx) - (fmt : F.formatter) (ctx : extraction_ctx) (inside : bool) - (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) : - extraction_ctx = + (fmt : F.formatter) (ctx : extraction_ctx) (is_single_pat : bool) + (inside : bool) (variant_id : VariantId.id option) (field_values : 'v list) + (ty : ty) : extraction_ctx = match ty with | Adt (Tuple, type_args) -> (* Tuple *) @@ -1982,36 +2039,57 @@ let extract_adt_g_value ctx) | Adt (adt_id, _) -> (* "Regular" ADT *) - (* We print something of the form: [Cons field0 ... fieldn]. - * We could update the code to print something of the form: - * [{ field0=...; ...; fieldn=...; }] in case of structures. - *) - let cons = - (* The ADT shouldn't be opaque *) - let with_opaque_pre = false in - match variant_id with - | Some vid -> ( - (* In the case of Lean, we might have to add the type name as a prefix *) - match (!backend, adt_id) with - | Lean, Assumed _ -> - ctx_get_type with_opaque_pre adt_id ctx - ^ "." - ^ ctx_get_variant adt_id vid ctx - | _ -> ctx_get_variant adt_id vid ctx) - | None -> ctx_get_struct with_opaque_pre adt_id ctx - in - let use_parentheses = inside && field_values <> [] in - if use_parentheses then F.pp_print_string fmt "("; - F.pp_print_string fmt cons; - let ctx = - Collections.List.fold_left - (fun ctx v -> - F.pp_print_space fmt (); - extract_value ctx true v) - ctx field_values - in - if use_parentheses then F.pp_print_string fmt ")"; - ctx + + (* If we are generating a pattern for a let-binding and we target Lean, + the syntax is: `let ⟨ x0, ..., xn ⟩ := ...`. + + Otherwise, it is: `let Cons x0 ... xn = ...` + *) + if is_single_pat && !Config.backend = Lean then ( + F.pp_print_string fmt "⟨"; + F.pp_print_space fmt (); + let ctx = + Collections.List.fold_left_link + (fun _ -> + F.pp_print_string fmt ","; + F.pp_print_space fmt ()) + (fun ctx v -> extract_value ctx true v) + ctx field_values + in + F.pp_print_space fmt (); + F.pp_print_string fmt "⟩"; + ctx) + else + (* We print something of the form: [Cons field0 ... fieldn]. + * We could update the code to print something of the form: + * [{ field0=...; ...; fieldn=...; }] in case of structures. + *) + let cons = + (* The ADT shouldn't be opaque *) + let with_opaque_pre = false in + match variant_id with + | Some vid -> ( + (* In the case of Lean, we might have to add the type name as a prefix *) + match (!backend, adt_id) with + | Lean, Assumed _ -> + ctx_get_type with_opaque_pre adt_id ctx + ^ "." + ^ ctx_get_variant adt_id vid ctx + | _ -> ctx_get_variant adt_id vid ctx) + | None -> ctx_get_struct with_opaque_pre adt_id ctx + in + let use_parentheses = inside && field_values <> [] in + if use_parentheses then F.pp_print_string fmt "("; + F.pp_print_string fmt cons; + let ctx = + Collections.List.fold_left + (fun ctx v -> + F.pp_print_space fmt (); + extract_value ctx true v) + ctx field_values + in + if use_parentheses then F.pp_print_string fmt ")"; + ctx | _ -> raise (Failure "Inconsistent typed value") (* Extract globals in the same way as variables *) @@ -2026,7 +2104,7 @@ let extract_global (ctx : extraction_ctx) (fmt : F.formatter) updated with new bindings. *) let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter) - (inside : bool) (v : typed_pattern) : extraction_ctx = + (is_let : bool) (inside : bool) (v : typed_pattern) : extraction_ctx = match v.value with | PatConstant cv -> ctx.fmt.extract_primitive_value fmt inside cv; @@ -2042,8 +2120,10 @@ let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "_"; ctx | PatAdt av -> - let extract_value ctx inside v = extract_typed_pattern ctx fmt inside v in - extract_adt_g_value extract_value fmt ctx inside av.variant_id + let extract_value ctx inside v = + extract_typed_pattern ctx fmt is_let inside v + in + extract_adt_g_value extract_value fmt ctx is_let inside av.variant_id av.field_values v.ty (** [inside]: controls the introduction of parentheses. See [extract_ty] @@ -2173,12 +2253,13 @@ and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (adt_cons : adt_cons_id) (type_args : ty list) (args : texpression list) : unit = let e_ty = Adt (adt_cons.adt_id, type_args) in + let is_single_pat = false in let _ = extract_adt_g_value (fun ctx inside e -> extract_texpression ctx fmt inside e; ctx) - fmt ctx inside adt_cons.variant_id args e_ty + fmt ctx is_single_pat inside adt_cons.variant_id args e_ty in () @@ -2226,11 +2307,12 @@ and extract_Abs (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) List.fold_left (fun ctx x -> F.pp_print_space fmt (); - extract_typed_pattern ctx fmt true x) + extract_typed_pattern ctx fmt true true x) ctx xl in F.pp_print_space fmt (); - F.pp_print_string fmt "->"; + if !backend = Lean then F.pp_print_string fmt "=>" + else F.pp_print_string fmt "->"; F.pp_print_space fmt (); (* Print the body *) extract_texpression ctx fmt false e; @@ -2295,7 +2377,7 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) * TODO: cleanup * *) if monadic && (!backend = Coq || !backend = HOL4) then ( - let ctx = extract_typed_pattern ctx fmt true lv in + let ctx = extract_typed_pattern ctx fmt true true lv in F.pp_print_space fmt (); let arrow = match !backend with @@ -2321,15 +2403,13 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) else ( F.pp_print_string fmt "let"; F.pp_print_space fmt ()); - let ctx = extract_typed_pattern ctx fmt true lv in + let ctx = extract_typed_pattern ctx fmt true true lv in F.pp_print_space fmt (); let eq = match !backend with | FStar -> "=" | Coq -> ":=" - | Lean -> - (* TODO: switch to ⟵ once issues are fixed *) - if monadic then "←" else ":=" + | Lean -> if monadic then "←" else ":=" | HOL4 -> if monadic then "<-" else "=" in F.pp_print_string fmt eq; @@ -2409,7 +2489,7 @@ and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (_inside : bool) (* Open a box for the [if e] *) F.pp_open_hovbox fmt ctx.indent_incr; F.pp_print_string fmt "if"; - if !backend = Lean then F.pp_print_string fmt " h:"; + if !backend = Lean && ctx.use_dep_ite then F.pp_print_string fmt " h:"; F.pp_print_space fmt (); let scrut_inside = PureUtils.texpression_requires_parentheses scrut in extract_texpression ctx fmt scrut_inside scrut; @@ -2470,7 +2550,7 @@ and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (_inside : bool) match !backend with | FStar -> "begin match" | Coq -> "match" - | Lean -> "match h:" + | Lean -> if ctx.use_dep_ite then "match h:" else "match" | HOL4 -> (* We're being extra safe in the case of HOL4 *) "(case" @@ -2497,7 +2577,7 @@ and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (_inside : bool) (* Print the pattern *) F.pp_print_string fmt "|"; F.pp_print_space fmt (); - let ctx = extract_typed_pattern ctx fmt false br.pat in + let ctx = extract_typed_pattern ctx fmt false false br.pat in F.pp_print_space fmt (); let arrow = match !backend with FStar -> "->" | Coq | Lean | HOL4 -> "=>" @@ -2689,7 +2769,7 @@ let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx) (* Open a box for the input parameter *) F.pp_open_hovbox fmt 0; F.pp_print_string fmt "("; - let ctx = extract_typed_pattern ctx fmt false lv in + let ctx = extract_typed_pattern ctx fmt true false lv in F.pp_print_space fmt (); F.pp_print_string fmt ":"; F.pp_print_space fmt (); @@ -2714,7 +2794,7 @@ let extract_fun_input_parameters_types (ctx : extraction_ctx) let inside = false in extract_ty ctx fmt TypeDeclId.Set.empty inside ty; F.pp_print_space fmt (); - F.pp_print_string fmt "->"; + extract_arrow fmt (); F.pp_print_space fmt () in List.iter extract_param def.signature.inputs @@ -2752,7 +2832,7 @@ let extract_template_fstar_decreases_clause (ctx : extraction_ctx) F.pp_print_break fmt 0 0; (* Print a comment to link the extracted type to its original rust definition *) extract_comment fmt - ("[" ^ Print.fun_name_to_string def.basename ^ "]: decreases clause"); + [ "[" ^ Print.fun_name_to_string def.basename ^ "]: decreases clause" ]; F.pp_print_space fmt (); (* Open a box for the definition, so that whenever possible it gets printed on * one line *) @@ -2814,7 +2894,7 @@ let extract_template_lean_termination_and_decreasing (ctx : extraction_ctx) F.pp_print_break fmt 0 0; (* Print a comment to link the extracted type to its original rust definition *) extract_comment fmt - ("[" ^ Print.fun_name_to_string def.basename ^ "]: termination measure"); + [ "[" ^ Print.fun_name_to_string def.basename ^ "]: termination measure" ]; F.pp_print_space fmt (); (* Open a box for the definition, so that whenever possible it gets printed on * one line *) @@ -2868,7 +2948,7 @@ let extract_template_lean_termination_and_decreasing (ctx : extraction_ctx) (* syntax <def_name> term ... term : tactic *) F.pp_print_break fmt 0 0; extract_comment fmt - ("[" ^ Print.fun_name_to_string def.basename ^ "]: decreases_by tactic"); + [ "[" ^ Print.fun_name_to_string def.basename ^ "]: decreases_by tactic" ]; F.pp_print_space fmt (); F.pp_open_hvbox fmt 0; F.pp_print_string fmt "syntax \""; @@ -2897,6 +2977,40 @@ let extract_template_lean_termination_and_decreasing (ctx : extraction_ctx) F.pp_close_box fmt (); F.pp_print_break fmt 0 0 +let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter) + (def : fun_decl) : unit = + let { keep_fwd; num_backs } = + PureUtils.RegularFunIdMap.find + (A.Regular def.def_id, def.loop_id, def.back_id) + ctx.fun_name_info + in + let comment_pre = "[" ^ Print.fun_name_to_string def.basename ^ "]: " in + let comment = + let loop_comment = + match def.loop_id with + | None -> "" + | Some id -> "loop " ^ LoopId.to_string id ^ ": " + in + let fwd_back_comment = + match def.back_id with + | None -> [ "forward function" ] + | Some id -> + (* Check if there is only one backward function, and no forward function *) + if (not keep_fwd) && num_backs = 1 then + [ + "merged forward/backward function"; + "(there is a single backward function, and the forward function \ + returns ())"; + ] + else [ "backward function " ^ T.RegionGroupId.to_string id ] + in + match fwd_back_comment with + | [] -> raise (Failure "Unreachable") + | [ s ] -> [ comment_pre ^ loop_comment ^ s ] + | s :: sl -> (comment_pre ^ loop_comment ^ s) :: sl + in + extract_comment fmt comment + (** Extract a function declaration. This function is for all function declarations and all backends **at the exception** @@ -2916,8 +3030,8 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (* Add a break before *) if !backend <> HOL4 || not (decl_is_first_from_group kind) then F.pp_print_break fmt 0 0; - (* Print a comment to link the extracted type to its original rust definition *) - extract_comment fmt ("[" ^ Print.fun_name_to_string def.basename ^ "]"); + (* Print a comment to link the extracted definition to its original rust definition *) + extract_fun_comment ctx fmt def; F.pp_print_space fmt (); (* Open two boxes for the definition, so that whenever possible it gets printed on * one line and indents are correct *) @@ -2939,8 +3053,11 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) let use_forall = is_opaque_coq && def.signature.type_params <> [] in (* Print the qualifier ("assume", etc.). - For Lean: we generate a record of assumed functions *) - (if not (!backend = Lean && (kind = Assumed || kind = Declared)) then + if `wrap_opaque_in_sig`: we generate a record of assumed funcions. + TODO: this is obsolete. + *) + (if not (!Config.wrap_opaque_in_sig && (kind = Assumed || kind = Declared)) + then let qualif = ctx.fmt.fun_decl_kind_to_qualif kind in match qualif with | Some qualif -> @@ -3034,7 +3151,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) List.fold_left (fun ctx (lv : typed_pattern) -> F.pp_print_space fmt (); - let ctx = extract_typed_pattern ctx fmt false lv in + let ctx = extract_typed_pattern ctx fmt true false lv in ctx) ctx inputs_lvs in @@ -3168,6 +3285,8 @@ let extract_fun_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_break fmt 0 0; (* Open a box for the whole definition *) F.pp_open_hvbox fmt ctx.indent_incr; + (* Print a comment to link the extracted definition to its original rust definition *) + extract_fun_comment ctx fmt def; (* Generate: `val _ = new_constant ("...",` *) F.pp_print_string fmt ("val _ = new_constant (\"" ^ def_name ^ "\","); F.pp_print_space fmt (); @@ -3343,7 +3462,7 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) (* Add a break then the name of the corresponding LLBC declaration *) F.pp_print_break fmt 0 0; - extract_comment fmt ("[" ^ Print.global_name_to_string global.name ^ "]"); + extract_comment fmt [ "[" ^ Print.global_name_to_string global.name ^ "]" ]; F.pp_print_space fmt (); let with_opaque_pre = false in @@ -3417,7 +3536,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_break fmt 0 0; (* Print a comment *) extract_comment fmt - ("Unit test for [" ^ Print.fun_name_to_string def.basename ^ "]"); + [ "Unit test for [" ^ Print.fun_name_to_string def.basename ^ "]" ]; F.pp_print_space fmt (); (* Open a box for the test *) F.pp_open_hovbox fmt ctx.indent_incr; diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 0a5d7df2..655bb033 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -240,7 +240,9 @@ type formatter = { - loop identifier, if this is for a loop *) opaque_pre : unit -> string; - (** The prefix to use for opaque definitions. + (** TODO: obsolete, remove. + + The prefix to use for opaque definitions. We need this because for some backends like Lean and Coq, we group opaque definitions in module signatures, meaning that using those @@ -414,7 +416,7 @@ module IdSet = Collections.MakeSet (IdOrderedType) We use it for lookups (during the translation) and to check for name clashes. - [id_to_string] is for debugging. + [id_to_name] is for debugging. *) type names_map = { id_to_name : string IdMap.t; @@ -425,7 +427,9 @@ type names_map = { *) names_set : StringSet.t; opaque_ids : IdSet.t; - (** The set of opaque definitions. + (** TODO: this is obsolete. Remove. + + The set of opaque definitions. See {!formatter.opaque_pre} for detailed explanations about why we need to know which definitions are opaque to compute names. @@ -486,6 +490,20 @@ let names_map_add_function (id_to_string : id -> string) (is_opaque : bool) (fid : fun_id) (name : string) (nm : names_map) : names_map = names_map_add id_to_string is_opaque (FunId fid) name nm +(** The unsafe names map stores mappings from identifiers to names which might + collide. For some backends and some names, it might be acceptable to have + collisions. For instance, in Lean, different records can have fields with + the same name because Lean uses the typing information to resolve the + ambiguities. + + This map complements the {!names_map}, which checks for collisions. + *) +type unsafe_names_map = { id_to_name : string IdMap.t } + +let unsafe_names_map_add (id : id) (name : string) (nm : unsafe_names_map) : + unsafe_names_map = + { id_to_name = IdMap.add id name nm.id_to_name } + (** Make a (variable) basename unique (by adding an index). We do this in an inefficient manner (by testing all indices starting from @@ -518,6 +536,8 @@ let basename_to_unique (names_set : StringSet.t) in if StringSet.mem basename names_set then gen 0 else basename +type fun_name_info = { keep_fwd : bool; num_backs : int } + (** Extraction context. Note that the extraction context contains information coming from the @@ -528,6 +548,11 @@ let basename_to_unique (names_set : StringSet.t) type extraction_ctx = { trans_ctx : trans_ctx; names_map : names_map; + (** The map for id to names, where we forbid name collisions + (ex.: we always forbid function name collisions). *) + unsafe_names_map : unsafe_names_map; + (** The map for id to names, where we allow name collisions + (ex.: we might allow record field name collisions). *) fmt : formatter; indent_incr : int; (** The indent increment we insert whenever we need to indent more *) @@ -539,6 +564,25 @@ type extraction_ctx = { use it. Also see {!names_map.opaque_ids}. *) + use_dep_ite : bool; + (** For Lean: do we use dependent-if then else expressions? + + Example: + {[ + if h: b then ... else ... + -- ^^ + -- makes the if then else dependent + ]} + *) + fun_name_info : fun_name_info PureUtils.RegularFunIdMap.t; + (** Information used to filter and name functions - we use it + to print comments in the generated code, to help link + the generated code to the original code (information such + as: "this function is the backward function of ...", or + "this function is the merged forward/backward function of ..." + in case a Rust function only has one backward translation + and we filter the forward function because it returns unit. + *) } (** Debugging function, used when communicating name collisions to the user, @@ -667,23 +711,42 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = | TypeVarId id -> "type_var_id: " ^ TypeVarId.to_string id | VarId id -> "var_id: " ^ VarId.to_string id +(** We might not check for collisions for some specific ids (ex.: field names) *) +let allow_collisions (id : id) : bool = + match id with + | FieldId (_, _) -> !Config.record_fields_short_names + | _ -> false + let ctx_add (is_opaque : bool) (id : id) (name : string) (ctx : extraction_ctx) : extraction_ctx = - (* The id_to_string function to print nice debugging messages if there are - * collisions *) - let id_to_string (id : id) : string = id_to_string id ctx in - let names_map = names_map_add id_to_string is_opaque id name ctx.names_map in - { ctx with names_map } + (* We do not use the same name map if we allow/disallow collisions *) + if allow_collisions id then ( + assert (not is_opaque); + { + ctx with + unsafe_names_map = unsafe_names_map_add id name ctx.unsafe_names_map; + }) + else + (* The id_to_string function to print nice debugging messages if there are + * collisions *) + let id_to_string (id : id) : string = id_to_string id ctx in + let names_map = + names_map_add id_to_string is_opaque id name ctx.names_map + in + { ctx with names_map } (** [with_opaque_pre]: if [true] and the definition is opaque, add the opaque prefix *) let ctx_get (with_opaque_pre : bool) (id : id) (ctx : extraction_ctx) : string = - match IdMap.find_opt id ctx.names_map.id_to_name with - | Some s -> - let is_opaque = IdSet.mem id ctx.names_map.opaque_ids in - if with_opaque_pre && is_opaque then ctx.fmt.opaque_pre () ^ s else s - | None -> - log#serror ("Could not find: " ^ id_to_string id ctx); - raise Not_found + (* We do not use the same name map if we allow/disallow collisions *) + if allow_collisions id then IdMap.find id ctx.unsafe_names_map.id_to_name + else + match IdMap.find_opt id ctx.names_map.id_to_name with + | Some s -> + let is_opaque = IdSet.mem id ctx.names_map.opaque_ids in + if with_opaque_pre && is_opaque then ctx.fmt.opaque_pre () ^ s else s + | None -> + log#serror ("Could not find: " ^ id_to_string id ctx); + raise Not_found let ctx_get_global (with_opaque_pre : bool) (id : A.GlobalDeclId.id) (ctx : extraction_ctx) : string = @@ -918,9 +981,15 @@ let ctx_add_fun_decl (trans_group : bool * pure_fun_translation) ctx.fmt.fun_name def.basename def.num_loops def.loop_id num_rgs rg_info (keep_fwd, num_backs) in - ctx_add is_opaque - (FunId (FromLlbc (A.Regular def_id, def.loop_id, def.back_id))) - def_name ctx + let fun_id = (A.Regular def_id, def.loop_id, def.back_id) in + let ctx = ctx_add is_opaque (FunId (FromLlbc fun_id)) def_name ctx in + (* Add the name info *) + { + ctx with + fun_name_info = + PureUtils.RegularFunIdMap.add fun_id { keep_fwd; num_backs } + ctx.fun_name_info; + } type names_map_init = { keywords : string list; @@ -1041,13 +1110,24 @@ let default_fun_suffix (num_loops : int) (loop_id : LoopId.id option) definitions (in particular between type and function definitions). *) let rg_suff = + (* TODO: make all the backends match what is done for Lean *) match rg with - | None -> "_fwd" + | None -> ( + match !Config.backend with + | FStar | Coq | HOL4 -> "_fwd" + | Lean -> + (* In order to avoid name conflicts: + * - if the forward is eliminated, we add the suffix "_fwd" (it won't be used) + * - otherwise, no suffix (because the backward functions will have a suffix) + *) + if num_backs = 1 && not keep_fwd then "_fwd" else "") | Some rg -> assert (num_region_groups > 0 && num_backs > 0); if num_backs = 1 then (* Exactly one backward function *) - if not keep_fwd then "_fwd_back" else "_back" + match !Config.backend with + | FStar | Coq | HOL4 -> if not keep_fwd then "_fwd_back" else "_back" + | Lean -> if not keep_fwd then "" else "_back" else if (* Several region groups/backward functions: - if all the regions in the group have names, we use those names diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 4a00dfb2..b251a005 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -18,16 +18,19 @@ module GlobalDeclId = A.GlobalDeclId (they monotonically increase across functions) while in {!module:Pure} we want the indices to start at 0 for every function. *) -module LoopId = IdGen () +module LoopId = +IdGen () type loop_id = LoopId.id [@@deriving show, ord] (** We give an identifier to every phase of the synthesis (forward, backward for group of regions 0, etc.) *) -module SynthPhaseId = IdGen () +module SynthPhaseId = +IdGen () (** Pay attention to the fact that we also define a {!E.VarId} module in Values *) -module VarId = IdGen () +module VarId = +IdGen () type integer_type = T.integer_type [@@deriving show, ord] @@ -723,6 +726,7 @@ type fun_sig_info = { *) type fun_sig = { type_params : type_var list; + (** TODO: we should analyse the signature to make the type parameters implicit whenever possible *) inputs : ty list; (** The input types. diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 75fc7fe9..c5f7df92 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -750,18 +750,17 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) if config.extract_state_type && config.extract_fun_decls then export_state_type (); - (* For Lean, we parameterize the entire development by a section variable - called opaque_defs, of type OpaqueDefs. The code below emits the type - definition for OpaqueDefs, which is a structure, in which each field is one - of the functions marked as Opaque. We emit the `structure ...` bit here, - then rely on `extract_fun_decl` to be aware of this, and skip the keyword - (e.g. "axiom" or "val") so as to generate valid syntax for records. - - We also generate such a structure only if there actually are opaque - definitions. - *) + (* Obsolete: (TODO: remove) For Lean we parameterize the entire development by a section + variable called opaque_defs, of type OpaqueDefs. The code below emits the type + definition for OpaqueDefs, which is a structure, in which each field is one of the + functions marked as Opaque. We emit the `structure ...` bit here, then rely on + `extract_fun_decl` to be aware of this, and skip the keyword (e.g. "axiom" or "val") + so as to generate valid syntax for records. + + We also generate such a structure only if there actually are opaque definitions. *) let wrap_in_sig = - config.extract_opaque && config.extract_fun_decls && !Config.backend = Lean + config.extract_opaque && config.extract_fun_decls + && !Config.wrap_opaque_in_sig && let _, opaque_funs = module_has_opaque_decls ctx in opaque_funs @@ -783,11 +782,22 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) if wrap_in_sig then Format.pp_close_box fmt () -let extract_file (config : gen_config) (ctx : gen_ctx) (filename : string) - (rust_module_name : string) (module_name : string) (custom_msg : string) - (custom_imports : string list) (custom_includes : string list) : unit = +type extract_file_info = { + filename : string; + namespace : string; + in_namespace : bool; + crate_name : string; + rust_module_name : string; + module_name : string; + custom_msg : string; + custom_imports : string list; + custom_includes : string list; +} + +let extract_file (config : gen_config) (ctx : gen_ctx) (fi : extract_file_info) + : unit = (* Open the file and create the formatter *) - let out = open_out filename in + let out = open_out fi.filename in let fmt = Format.formatter_of_out_channel out in (* Print the headers. @@ -801,19 +811,22 @@ let extract_file (config : gen_config) (ctx : gen_ctx) (filename : string) (match !Config.backend with | Lean -> Printf.fprintf out "-- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS\n"; - Printf.fprintf out "-- [%s]%s\n" rust_module_name custom_msg + Printf.fprintf out "-- [%s]%s\n" fi.rust_module_name fi.custom_msg | Coq | FStar | HOL4 -> Printf.fprintf out "(** THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS *)\n"; - Printf.fprintf out "(** [%s]%s *)\n" rust_module_name custom_msg); + Printf.fprintf out "(** [%s]%s *)\n" fi.rust_module_name fi.custom_msg); + (* Generate the imports *) (match !Config.backend with | FStar -> - Printf.fprintf out "module %s\n" module_name; + Printf.fprintf out "module %s\n" fi.module_name; Printf.fprintf out "open Primitives\n"; (* Add the custom imports *) - List.iter (fun m -> Printf.fprintf out "open %s\n" m) custom_imports; + List.iter (fun m -> Printf.fprintf out "open %s\n" m) fi.custom_imports; (* Add the custom includes *) - List.iter (fun m -> Printf.fprintf out "include %s\n" m) custom_includes; + List.iter + (fun m -> Printf.fprintf out "include %s\n" m) + fi.custom_includes; (* Z3 options - note that we use fuel 1 because it its useful for the decrease clauses *) Printf.fprintf out "\n#set-options \"--z3rlimit 50 --fuel 1 --ifuel 1\"\n" | Coq -> @@ -825,24 +838,29 @@ let extract_file (config : gen_config) (ctx : gen_ctx) (filename : string) (* Add the custom imports *) List.iter (fun m -> Printf.fprintf out "Require Import %s.\n" m) - custom_imports; + fi.custom_imports; (* Add the custom includes *) List.iter (fun m -> Printf.fprintf out "Require Export %s.\n" m; Printf.fprintf out "Import %s.\n" m) - custom_includes; - Printf.fprintf out "Module %s.\n" module_name + fi.custom_includes; + Printf.fprintf out "Module %s.\n" fi.module_name | Lean -> - Printf.fprintf out "import Base.Primitives\n"; + Printf.fprintf out "import Base\n"; (* Add the custom imports *) - List.iter (fun m -> Printf.fprintf out "import %s\n" m) custom_imports; + List.iter (fun m -> Printf.fprintf out "import %s\n" m) fi.custom_imports; (* Add the custom includes *) - List.iter (fun m -> Printf.fprintf out "import %s\n" m) custom_includes + List.iter (fun m -> Printf.fprintf out "import %s\n" m) fi.custom_includes; + (* Always open the Primitives namespace *) + Printf.fprintf out "open Primitives\n"; + (* If we are inside the namespace: declare it, otherwise: open it *) + if fi.in_namespace then Printf.fprintf out "namespace %s\n" fi.namespace + else Printf.fprintf out "open %s\n" fi.namespace | HOL4 -> Printf.fprintf out "open primitivesLib divDefLib\n"; (* Add the custom imports and includes *) - let imports = custom_imports @ custom_includes in + let imports = fi.custom_imports @ fi.custom_includes in (* The imports are a list of module names: we need to add a "Theory" suffix *) let imports = List.map (fun s -> s ^ "Theory") imports in if imports <> [] then @@ -850,7 +868,7 @@ let extract_file (config : gen_config) (ctx : gen_ctx) (filename : string) Printf.fprintf out "open %s\n\n" imports else Printf.fprintf out "\n"; (* Initialize the theory *) - Printf.fprintf out "val _ = new_theory \"%s\"\n\n" module_name); + Printf.fprintf out "val _ = new_theory \"%s\"\n\n" fi.module_name); (* From now onwards, we use the formatter *) (* Set the margin *) Format.pp_set_margin fmt 80; @@ -867,12 +885,13 @@ let extract_file (config : gen_config) (ctx : gen_ctx) (filename : string) (* Close the module *) (match !Config.backend with - | FStar | Lean -> () + | FStar -> () + | Lean -> if fi.in_namespace then Printf.fprintf out "end %s\n" fi.namespace | HOL4 -> Printf.fprintf out "val _ = export_theory ()\n" - | Coq -> Printf.fprintf out "End %s .\n" module_name); + | Coq -> Printf.fprintf out "End %s .\n" fi.module_name); (* Some logging *) - log#linfo (lazy ("Generated: " ^ filename)); + log#linfo (lazy ("Generated: " ^ fi.filename)); (* Flush and close the file *) close_out out @@ -891,18 +910,24 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : prefixed with the type name to prevent collisions *) match !Config.backend with Coq | FStar | HOL4 -> true | Lean -> false in + (* Initialize the names map (we insert the names of the "primitives" + declarations, and insert the names of the local declarations later) *) let mk_formatter_and_names_map = Extract.mk_formatter_and_names_map in let fmt, names_map = mk_formatter_and_names_map trans_ctx crate.name variant_concatenate_type_name in + (* Put everything in the context *) let ctx = { ExtractBase.trans_ctx; names_map; + unsafe_names_map = { id_to_name = ExtractBase.IdMap.empty }; fmt; indent_incr = 2; use_opaque_pre = !Config.split_files; + use_dep_ite = !Config.backend = Lean && !Config.extract_decreases_clauses; + fun_name_info = PureUtils.RegularFunIdMap.empty; } in @@ -968,7 +993,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : (* Open the output file *) (* First compute the filename by replacing the extension and converting the * case (rust module names are snake case) *) - let module_name, extract_filebasename = + let namespace, crate_name, extract_filebasename = match Filename.chop_suffix_opt ~suffix:".llbc" filename with | None -> (* Note that we already checked the suffix upon opening the file *) @@ -977,14 +1002,20 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : (* Retrieve the file basename *) let basename = Filename.basename filename in (* Convert the case *) - let module_name = StringUtils.to_camel_case basename in - let module_name = + let crate_name = StringUtils.to_camel_case basename in + let crate_name = if !Config.backend = HOL4 then - StringUtils.lowercase_first_letter module_name - else module_name + StringUtils.lowercase_first_letter crate_name + else crate_name + in + (* We use the raw crate name for the namespaces *) + let namespace = + match !Config.backend with + | FStar | Coq | HOL4 -> crate.name + | Lean -> crate.name in (* Concatenate *) - (module_name, Filename.concat dest_dir module_name) + (namespace, crate_name, Filename.concat dest_dir crate_name) in (* Put the translated definitions in maps *) @@ -1019,11 +1050,10 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : create more directories *) if !Config.backend = Lean then ( let ( ^^ ) = Filename.concat in - mkdir_if (dest_dir ^^ "Base"); - if !Config.split_files then mkdir_if (dest_dir ^^ module_name); + if !Config.split_files then mkdir_if (dest_dir ^^ crate_name); if needs_clauses_module then ( assert !Config.split_files; - mkdir_if (dest_dir ^^ module_name ^^ "Clauses"))); + mkdir_if (dest_dir ^^ crate_name ^^ "Clauses"))); (* Copy the "Primitives" file, if necessary *) let _ = @@ -1033,7 +1063,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : match !Config.backend with | FStar -> Some ("/backends/fstar/Primitives.fst", "Primitives.fst") | Coq -> Some ("/backends/coq/Primitives.v", "Primitives.v") - | Lean -> Some ("/backends/lean/Primitives.lean", "Base/Primitives.lean") + | Lean -> None | HOL4 -> None in match primitives_src_dest with @@ -1117,6 +1147,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : (* Extract the types *) (* If there are opaque types, we extract in an interface *) + (* TODO: for Lean and Coq: generate a template file *) let types_filename_ext = match !Config.backend with | FStar -> if has_opaque_types then ".fsti" else ".fst" @@ -1127,7 +1158,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : let types_filename = extract_filebasename ^ file_delimiter ^ "Types" ^ types_filename_ext in - let types_module = module_name ^ module_delimiter ^ "Types" in + let types_module = crate_name ^ module_delimiter ^ "Types" in let types_config = { base_gen_config with @@ -1137,8 +1168,20 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : interface = has_opaque_types; } in - extract_file types_config gen_ctx types_filename crate.A.name types_module - ": type definitions" [] []; + let file_info = + { + filename = types_filename; + namespace; + in_namespace = true; + crate_name; + rust_module_name = crate.A.name; + module_name = types_module; + custom_msg = ": type definitions"; + custom_imports = []; + custom_includes = []; + } + in + extract_file types_config gen_ctx file_info; (* Extract the template clauses *) (if needs_clauses_module && !Config.extract_template_decreases_clauses then @@ -1147,33 +1190,49 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : ^ "Template" ^ ext in let template_clauses_module = - module_name ^ module_delimiter ^ "Clauses" ^ module_delimiter + crate_name ^ module_delimiter ^ "Clauses" ^ module_delimiter ^ "Template" in let template_clauses_config = { base_gen_config with extract_template_decreases_clauses = true } in - extract_file template_clauses_config gen_ctx template_clauses_filename - crate.A.name template_clauses_module - ": templates for the decreases clauses" [ types_module ] []); + let file_info = + { + filename = template_clauses_filename; + namespace; + in_namespace = true; + crate_name; + rust_module_name = crate.A.name; + module_name = template_clauses_module; + custom_msg = ": templates for the decreases clauses"; + custom_imports = [ types_module ]; + custom_includes = []; + } + in + extract_file template_clauses_config gen_ctx file_info); (* Extract the opaque functions, if needed *) let opaque_funs_module = if has_opaque_funs then ( + (* In the case of Lean we generate a template file *) + let module_suffix, opaque_imported_suffix, custom_msg = + match !Config.backend with + | FStar | Coq | HOL4 -> + ("Opaque", "Opaque", ": external function declarations") + | Lean -> + ( "FunsExternal_Template", + "FunsExternal", + ": external functions.\n\ + -- This is a template file: rename it to \ + \"FunsExternal.lean\" and fill the holes." ) + in let opaque_filename = - extract_filebasename ^ file_delimiter ^ "Opaque" ^ opaque_ext + extract_filebasename ^ file_delimiter ^ module_suffix ^ opaque_ext in - let opaque_module = module_name ^ module_delimiter ^ "Opaque" in + let opaque_module = crate_name ^ module_delimiter ^ module_suffix in let opaque_imported_module = - (* In the case of Lean, we declare an interface (a record) containing - the opaque definitions, and we leave it to the user to provide an - instance of this module. - - TODO: do the same for Coq. - TODO: do the same for the type definitions. - *) if !Config.backend = Lean then - module_name ^ module_delimiter ^ "ExternalFuns" + crate_name ^ module_delimiter ^ opaque_imported_suffix else opaque_module in let opaque_config = @@ -1191,15 +1250,28 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : extract_ctx = { gen_ctx.extract_ctx with use_opaque_pre = false }; } in - extract_file opaque_config gen_ctx opaque_filename crate.A.name - opaque_module ": opaque function definitions" [] [ types_module ]; + let file_info = + { + filename = opaque_filename; + namespace; + in_namespace = false; + crate_name; + rust_module_name = crate.A.name; + module_name = opaque_module; + custom_msg; + custom_imports = []; + custom_includes = [ types_module ]; + } + in + extract_file opaque_config gen_ctx file_info; + (* Return the additional dependencies *) [ opaque_imported_module ]) else [] in (* Extract the functions *) let fun_filename = extract_filebasename ^ file_delimiter ^ "Funs" ^ ext in - let fun_module = module_name ^ module_delimiter ^ "Funs" in + let fun_module = crate_name ^ module_delimiter ^ "Funs" in let fun_config = { base_gen_config with @@ -1213,12 +1285,24 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : let clauses_submodule = if !Config.backend = Lean then module_delimiter ^ "Clauses" else "" in - [ module_name ^ clauses_submodule ^ module_delimiter ^ "Clauses" ] + [ crate_name ^ clauses_submodule ^ module_delimiter ^ "Clauses" ] else [] in - extract_file fun_config gen_ctx fun_filename crate.A.name fun_module - ": function definitions" [] - ([ types_module ] @ opaque_funs_module @ clauses_module)) + let file_info = + { + filename = fun_filename; + namespace; + in_namespace = true; + crate_name; + rust_module_name = crate.A.name; + module_name = fun_module; + custom_msg = ": function definitions"; + custom_imports = []; + custom_includes = + [ types_module ] @ opaque_funs_module @ clauses_module; + } + in + extract_file fun_config gen_ctx file_info) else let gen_config = { @@ -1235,10 +1319,21 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : test_trans_unit_functions = !Config.test_trans_unit_functions; } in + let file_info = + { + filename = extract_filebasename ^ ext; + namespace; + in_namespace = true; + crate_name; + rust_module_name = crate.A.name; + module_name = crate_name; + custom_msg = ""; + custom_imports = []; + custom_includes = []; + } + in (* Add the extension for F* *) - let extract_filename = extract_filebasename ^ ext in - extract_file gen_config gen_ctx extract_filename crate.A.name module_name - "" [] []); + extract_file gen_config gen_ctx file_info); (* Generate the build file *) match !Config.backend with @@ -1256,47 +1351,45 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : * different files. *) if !Config.split_files then ( - let filename = Filename.concat dest_dir (module_name ^ ".lean") in + let filename = Filename.concat dest_dir (crate_name ^ ".lean") in let out = open_out filename in (* Write *) - Printf.fprintf out "import %s.Funs\n" module_name; + Printf.fprintf out "import %s.Funs\n" crate_name; (* Flush and close the file, log *) close_out out; log#linfo (lazy ("Generated: " ^ filename))); (* - * Generate the lakefile.lean file + * Generate the lakefile.lean file, if the user asks for it *) + if !Config.lean_gen_lakefile then ( + (* Open the file *) + let filename = Filename.concat dest_dir "lakefile.lean" in + let out = open_out filename in - (* Open the file *) - let filename = Filename.concat dest_dir "lakefile.lean" in - let out = open_out filename in - - (* Generate the content *) - Printf.fprintf out "import Lake\nopen Lake DSL\n\n"; - Printf.fprintf out "require mathlib from git\n"; - Printf.fprintf out - " \"https://github.com/leanprover-community/mathlib4.git\"\n\n"; - - let package_name = StringUtils.to_snake_case module_name in - Printf.fprintf out "package «%s» {}\n\n" package_name; - - Printf.fprintf out "lean_lib «Base» {}\n\n"; - - Printf.fprintf out "@[default_target]\nlean_lib «%s» {}\n" module_name; - - (* No default target for now. - Format would be: - {[ - @[default_target] - lean_exe «package_name» { - root := `Main - } - ]} - *) + (* Generate the content *) + Printf.fprintf out "import Lake\nopen Lake DSL\n\n"; + Printf.fprintf out "require mathlib from git\n"; + Printf.fprintf out + " \"https://github.com/leanprover-community/mathlib4.git\"\n\n"; + + let package_name = StringUtils.to_snake_case crate_name in + Printf.fprintf out "package «%s» {}\n\n" package_name; + + Printf.fprintf out "@[default_target]\nlean_lib «%s» {}\n" crate_name; + + (* No default target for now. + Format would be: + {[ + @[default_target] + lean_exe «package_name» { + root := `Main + } + ]} + *) - (* Flush and close the file *) - close_out out; + (* Flush and close the file *) + close_out out; - (* Logging *) - log#linfo (lazy ("Generated: " ^ filename)) + (* Logging *) + log#linfo (lazy ("Generated: " ^ filename))) |