diff options
author | Jonathan Protzenko | 2023-01-23 18:17:42 -0800 |
---|---|---|
committer | Son HO | 2023-06-04 21:44:33 +0200 |
commit | cbcaa965c4ee5597bb8f4f8bee7fba87729e7154 (patch) | |
tree | fb6d55d9cd9248e26fc6f16f8521e6c7e3fb2ef0 | |
parent | df2e79f88b04ae7bf43586eb83ea0461fb547b3b (diff) |
Initial Lean backend, WIP
-rw-r--r-- | Makefile | 8 | ||||
-rw-r--r-- | backends/lean/primitives.lean | 155 | ||||
-rw-r--r-- | compiler/Config.ml | 6 | ||||
-rw-r--r-- | compiler/Driver.ml | 2 | ||||
-rw-r--r-- | compiler/Extract.ml | 288 | ||||
-rw-r--r-- | compiler/ExtractBase.ml | 2 | ||||
-rw-r--r-- | compiler/SymbolicToPure.ml | 2 | ||||
-rw-r--r-- | compiler/Translate.ml | 56 | ||||
-rw-r--r-- | compiler/dune | 2 | ||||
-rw-r--r-- | tests/Makefile | 7 |
10 files changed, 406 insertions, 122 deletions
@@ -186,7 +186,7 @@ gen-llbcp-%: .PHONY: trans-% trans-%: CHARON_TEST_DIR = $(CHARON_TESTS_REGULAR_DIR) trans-%: FILE = $* -trans-%: gen-llbc-% tfstar-% tcoq-% +trans-%: gen-llbc-% tfstar-% tcoq-% tlean-% echo "# Test $* done" # "p" stands for "Polonius" @@ -222,6 +222,12 @@ tcoqp-%: BACKEND_SUBDIR := coq tcoqp-%: $(AENEAS_CMD) +.PHONY: tlean-% +tlean-%: OPTIONS += -backend lean -test-trans-units +tlean-%: BACKEND_SUBDIR := lean +tlean-%: + $(AENEAS_CMD) + # Nix .PHONY: nix nix: nix-aeneas-tests nix-aeneas-verify-fstar nix-aeneas-verify-coq diff --git a/backends/lean/primitives.lean b/backends/lean/primitives.lean new file mode 100644 index 00000000..b68df5f0 --- /dev/null +++ b/backends/lean/primitives.lean @@ -0,0 +1,155 @@ +------------- +-- PRELUDE -- +------------- + +-- Results & monadic combinators + +inductive error where + | assertionFailure: error + | integerOverflow: error + | arrayOutOfBounds: error + | maximumSizeExceeded: error + | panic: error +deriving Repr + +open error + +inductive result (α : Type u) where + | ret (v: α): result α + | fail (e: error): result α +deriving Repr + +open result + +-- TODO: is there automated syntax for these discriminators? +def is_ret {α: Type} (r: result α): Bool := + match r with + | result.ret _ => true + | result.fail _ => false + +def eval_global {α: Type} (x: result α) (h: is_ret x): α := + match x with + | result.fail _ => by contradiction + | result.ret x => x + +def bind (x: result α) (f: α -> result β) : result β := + match x with + | ret v => f v + | fail v => fail v + +-- Allows using result in do-blocks +instance : Bind result where + bind := bind + +-- Allows using return x in do-blocks +instance : Pure result where + pure := fun x => ret x + +def massert (b:Bool) : result Unit := + if b then return () else fail assertionFailure + +-- Machine integers + +-- NOTE: we reuse the USize type from prelude.lean, because at least we know +-- it's defined in an idiomatic style that is going to make proofs easy (and +-- indeed, several proofs here are much shortened compared to Aymeric's earlier +-- attempt.) This is not stricto sensu the *correct* thing to do, because one +-- can query at run-time the value of USize, which we do *not* want to do (we +-- don't know what target we'll run on!), but when the day comes, we'll just +-- define our own USize. +-- ANOTHER NOTE: there is USize.sub but it has wraparound semantics, which is +-- not something we want to define (I think), so we use our own monadic sub (but +-- is it in line with the Rust behavior?) + +-- TODO: I am somewhat under the impression that subtraction is defined as a +-- total function over nats...? the hypothesis in the if condition is not used +-- in the then-branch which confuses me quite a bit + +-- TODO: add a refinement for the result (just like vec_push_back below) that +-- explains that the toNat of the result (in the case of success) is the sub of +-- the toNat of the arguments (i.e. intrinsic specification) +-- ... do we want intrinsic specifications for the builtins? that might require +-- some careful type annotations in the monadic notation for clients, but may +-- give us more "for free" + +-- Note from Chris Bailey: "If there's more than one salient property of your +-- definition then the subtyping strategy might get messy, and the property part +-- of a subtype is less discoverable by the simplifier or tactics like +-- library_search." Try to settle this with a Lean expert on what is the most +-- productive way to go about this? + +-- Further thoughts: look at what has been done here: +-- https://github.com/leanprover-community/mathlib4/blob/master/Mathlib/Data/Fin/Basic.lean +-- and +-- https://github.com/leanprover-community/mathlib4/blob/master/Mathlib/Data/UInt.lean +-- which both contain a fair amount of reasoning already! +def USize.checked_sub (n: USize) (m: Nat): result USize := + -- NOTE: the test USize.toNat n - m >= 0 seems to always succeed? + if USize.toNat n >= m then + let n' := USize.toNat n + let r := USize.ofNatCore (n' - m) (by + have h: n' - m <= n' := by + apply Nat.sub_le_of_le_add + case h => rewrite [ Nat.add_comm ]; apply Nat.le_add_left + apply Nat.lt_of_le_of_lt h + apply n.val.isLt + ) + return r + else + fail integerOverflow + +-- TODO: settle the style for usize_sub before we write these +def USize.checked_mul (n: USize) (m: USize): result USize := sorry +def USize.checked_div (n: USize) (m: USize): result USize := sorry + +#eval USize.checked_sub 10 20 +#eval USize.checked_sub 20 10 +-- NOTE: compare with concrete behavior here, which I do not think we want +#eval USize.sub 0 1 +#eval UInt8.add 255 255 + +-- Vectors + +def vec (α : Type u) := { l : List α // List.length l < USize.size } + +def vec_new : result (vec α) := return ⟨ [], by { + match USize.size, usize_size_eq with + | _, Or.inl rfl => simp + | _, Or.inr rfl => simp + } ⟩ + +def vec_len (v : vec α) : USize := + let ⟨ v, l ⟩ := v + USize.ofNatCore (List.length v) l + +#eval do + return (vec_len (<- @vec_new Nat)) + +def vec_push_fwd (_ : vec α) (_ : α) : Unit := () + +-- TODO: more precise error condition here for the fail case +-- TODO: I originally wrote `List.length v.val < USize.size - 1`; how can one +-- make the proof work in that case? Probably need to import tactics from +-- mathlib to deal with inequalities... would love to see an example. +def vec_push_back (v : vec α) (x : α) : { res: result (vec α) // + match res with | fail _ => True | ret v' => List.length v'.val = List.length v.val + 1} + := + if h : List.length v.val + 1 < USize.size then + ⟨ return ⟨List.concat v.val x, + by + rw [List.length_concat] + assumption + ⟩, by simp ⟩ + else + ⟨ fail maximumSizeExceeded, by simp ⟩ + +#eval do + -- NOTE: the // notation is syntactic sugar for Subtype, a refinement with + -- fields val and property. However, Lean's elaborator can automatically + -- select the `val` field if the context provides a type annotation. We + -- annotate `x`, which relieves us of having to write `.val` on the right-hand + -- side of the monadic let. + let x: vec Nat ← vec_push_back (<- vec_new) 1 + -- TODO: strengthen post-condition above and do a demo to show that we can + -- safely eliminate the `fail` case + return (vec_len x) diff --git a/compiler/Config.ml b/compiler/Config.ml index cc80e452..1c3d14ff 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -3,13 +3,13 @@ (** {1 Backend choice} *) (** The choice of backend *) -type backend = FStar | Coq +type backend = FStar | Coq | Lean -let backend_names = [ "fstar"; "coq" ] +let backend_names = [ "fstar"; "coq"; "lean" ] (** Utility to compute the backend from an input parameter *) let backend_of_string (b : string) : backend option = - match b with "fstar" -> Some FStar | "coq" -> Some Coq | _ -> None + match b with "fstar" -> Some FStar | "coq" -> Some Coq | "lean" -> Some Lean | _ -> None let opt_backend : backend option ref = ref None diff --git a/compiler/Driver.ml b/compiler/Driver.ml index 73a2c974..4350c9ae 100644 --- a/compiler/Driver.ml +++ b/compiler/Driver.ml @@ -147,6 +147,8 @@ let () = (* Some patterns are not supported *) decompose_monadic_let_bindings := true; decompose_nested_let_patterns := true + | Lean -> + () in (* Retrieve and check the filename *) diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 518e8979..6bda6376 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -14,25 +14,36 @@ module F = Format (** Small helper to compute the name of an int type *) let int_name (int_ty : integer_type) = + let isize, usize, i_format, u_format = + match !backend with + | FStar | Coq -> + "isize", "usize", format_of_string "i%d", format_of_string "u%d" + | Lean -> + "ISize", "USize", format_of_string "Int%d", format_of_string "UInt%d" + in match int_ty with - | Isize -> "isize" - | I8 -> "i8" - | I16 -> "i16" - | I32 -> "i32" - | I64 -> "i64" - | I128 -> "i128" - | Usize -> "usize" - | U8 -> "u8" - | U16 -> "u16" - | U32 -> "u32" - | U64 -> "u64" - | U128 -> "u128" + | Isize -> isize + | I8 -> Printf.sprintf i_format 8 + | I16 -> Printf.sprintf i_format 16 + | I32 -> Printf.sprintf i_format 32 + | I64 -> Printf.sprintf i_format 64 + | I128 -> Printf.sprintf i_format 128 + | Usize -> usize + | U8 -> Printf.sprintf u_format 8 + | U16 -> Printf.sprintf u_format 16 + | U32 -> Printf.sprintf u_format 32 + | U64 -> Printf.sprintf u_format 64 + | U128 -> Printf.sprintf u_format 128 (** Small helper to compute the name of a unary operation *) let unop_name (unop : unop) : string = match unop with - | Not -> ( match !backend with FStar -> "not" | Coq -> "negb") - | Neg int_ty -> int_name int_ty ^ "_neg" + | Not -> ( match !backend with FStar | Lean -> "not" | Coq -> "negb") + | Neg (int_ty: integer_type) -> + begin match !backend with + | Lean -> int_name int_ty ^ ".checked_neg" + | _ -> int_name int_ty ^ "_neg" + end | Cast _ -> raise (Failure "Unsupported") (** Small helper to compute the name of a binary operation (note that many @@ -49,7 +60,9 @@ let named_binop_name (binop : E.binop) (int_ty : integer_type) : string = | Mul -> "mul" | _ -> raise (Failure "Unreachable") in - int_name int_ty ^ "_" ^ binop + match !backend with + | Lean -> int_name int_ty ^ ".checked_" ^ binop + | FStar | Coq -> int_name int_ty ^ "_" ^ binop (** A list of keywords/identifiers used by the backend and with which we want to check collision. *) @@ -60,10 +73,9 @@ let keywords () = in let named_binops = [ E.Div; Rem; Add; Sub; Mul ] in let named_binops = - List.concat - (List.map - (fun bn -> List.map (fun it -> named_binop_name bn it) T.all_int_types) - named_binops) + List.concat_map + (fun bn -> List.map (fun it -> named_binop_name bn it) T.all_int_types) + named_binops in let misc = match !backend with @@ -120,6 +132,8 @@ let keywords () = "tt"; "char_of_byte"; ] + | Lean -> + [] (* TODO *) in List.concat [ named_unops; named_binops; misc ] @@ -159,6 +173,16 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list = (Option, option_some_id, "Some"); (Option, option_none_id, "None"); ] + | Lean -> + [ + (Result, result_return_id, "ret"); + (Result, result_fail_id, "fail_"); (* TODO: why the _ *) + (Error, error_failure_id, "panic"); + (* No Fuel::Zero on purpose *) + (* No Fuel::Succ on purpose *) + (Option, option_some_id, "@some"); + (Option, option_none_id, "@none"); + ] let assumed_llbc_functions : (A.assumed_fun_id * T.RegionGroupId.id option * string) list = @@ -191,6 +215,15 @@ let assumed_pure_functions : (pure_assumed_fun_id * string) list = | Coq -> (* We don't provide [FuelDecrease] and [FuelEqZero] on purpose *) [ (Return, "return_"); (Fail, "fail_"); (Assert, "massert") ] + | Lean -> + [ + (Return, "return"); + (Fail, "fail_"); + (Assert, "massert"); + (* TODO: figure out how to deal with this *) + (FuelDecrease, "decrease"); + (FuelEqZero, "is_zero"); + ] let names_map_init () : names_map_init = { @@ -241,12 +274,12 @@ let extract_binop (extract_expr : bool -> texpression -> unit) | Eq -> "=" | Lt -> "<" | Le -> "<=" - | Ne -> "<>" + | Ne -> if !backend = Lean then "!=" else "<>" | Ge -> ">=" | Gt -> ">" | _ -> raise (Failure "Unreachable") in - let binop = match !backend with FStar -> binop | Coq -> "s" ^ binop in + let binop = match !backend with FStar | Lean -> binop | Coq -> "s" ^ binop in extract_expr false arg0; F.pp_print_space fmt (); F.pp_print_string fmt binop; @@ -263,7 +296,7 @@ let extract_binop (extract_expr : bool -> texpression -> unit) if inside then F.pp_print_string fmt ")" let type_decl_kind_to_qualif (kind : decl_kind) - (type_kind : type_decl_kind option) : string = + (type_kind : type_decl_kind option) (is_rec: bool): string = match !backend with | FStar -> ( match kind with @@ -286,6 +319,16 @@ let type_decl_kind_to_qualif (kind : decl_kind) "with" | (Assumed | Declared), None -> "Axiom" | _ -> raise (Failure "Unexpected")) + | Lean -> ( + match kind with + | SingleNonRec -> + if type_kind = Some Struct && not is_rec then "structure" else "inductive" + | SingleRec -> "inductive" + | MutRecFirst -> "mutual inductive" + | MutRecInner -> "inductive" + | MutRecLast -> "inductive" (* TODO: need to print end afterwards *) + | Assumed -> "axiom" + | Declared -> "axiom") let fun_decl_kind_to_qualif (kind : decl_kind) = match !backend with @@ -307,12 +350,22 @@ let fun_decl_kind_to_qualif (kind : decl_kind) = | MutRecLast -> "with" | Assumed -> "Axiom" | Declared -> "Axiom") + | Lean -> ( + match kind with + | SingleNonRec -> "def" + | SingleRec -> "def" + | MutRecFirst -> "mutual def" + | MutRecInner -> "def" + | MutRecLast -> "def" (* TODO: need to print end afterwards *) + | Assumed -> "axiom" + | Declared -> "axiom") + (** [ctx]: we use the context to lookup type definitions, to retrieve type names. This is used to compute variable names, when they have no basenames: in this case we use the first letter of the type name. - + [variant_concatenate_type_name]: if true, add the type name as a prefix to the variant names. Ex.: @@ -323,21 +376,21 @@ let fun_decl_kind_to_qualif (kind : decl_kind) = Nil, } ]} - + F*, if option activated: {[ type list = | ListCons : u32 -> list -> list | ListNil : list ]} - + F*, if option not activated: {[ type list = | Cons : u32 -> list -> list | Nil : list ]} - + Rk.: this should be true by default, because in Rust all the variant names are actively uniquely identifier by the type name [List::Cons(...)], while in other languages it is not necessarily the case, and thus clashes can mess @@ -382,7 +435,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) let name = get_type_name name in let name = List.map to_snake_case name in let name = String.concat "_" name in - match !backend with FStar -> name | Coq -> capitalize_first_letter name + match !backend with FStar | Lean -> name | Coq -> capitalize_first_letter name in let type_name name = type_name_to_snake_case name ^ "_t" in let field_name (def_name : name) (field_id : FieldId.id) @@ -400,7 +453,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) in let struct_constructor (basename : name) : string = let tname = type_name basename in - let prefix = match !backend with FStar -> "Mk" | Coq -> "mk" in + let prefix = match !backend with FStar -> "Mk" | Lean | Coq -> "mk" in prefix ^ tname in let get_fun_name = get_name in @@ -481,7 +534,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) (* TODO: use "t" also for F* *) match !backend with | FStar -> "x" (* lacking inspiration here... *) - | Coq -> "t" (* lacking inspiration here... *)) + | Coq | Lean -> "t" (* lacking inspiration here... *)) | Bool -> "b" | Char -> "c" | Integer _ -> "i" @@ -495,7 +548,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) | FStar -> (* This is *not* a no-op: this removes the capital letter *) to_snake_case basename - | Coq -> basename + | Coq | Lean -> basename in let append_index (basename : string) (i : int) : string = basename ^ string_of_int i @@ -514,13 +567,19 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) F.pp_print_string fmt (Z.to_string sv.PV.value) else F.pp_print_string fmt ("(" ^ Z.to_string sv.PV.value ^ ")"); F.pp_print_string fmt ("%" ^ int_name sv.PV.int_ty); - if inside then F.pp_print_string fmt ")") + if inside then F.pp_print_string fmt ")" + | Lean -> + F.pp_print_string fmt "("; + F.pp_print_string fmt (int_name sv.int_ty); + F.pp_print_string fmt ".ofNatCore "; + Z.pp_print fmt sv.value; + F.pp_print_string fmt (" (by simp))")) | Bool b -> let b = if b then "true" else "false" in F.pp_print_string fmt b | Char c -> ( match !backend with - | FStar -> F.pp_print_string fmt ("'" ^ String.make 1 c ^ "'") + | FStar | Lean -> F.pp_print_string fmt ("'" ^ String.make 1 c ^ "'") | Coq -> if inside then F.pp_print_string fmt "("; F.pp_print_string fmt "char_of_byte"; @@ -534,6 +593,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) in F.pp_print_string fmt c; if inside then F.pp_print_string fmt ")") + | String s -> (* We need to replace all the line breaks *) let s = @@ -543,11 +603,14 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) in F.pp_print_string fmt ("\"" ^ s ^ "\"") in + let bool_name = if !backend = Lean then "Bool" else "bool" in + let char_name = if !backend = Lean then "Char" else "char" in + let str_name = if !backend = Lean then "String" else "string" in { - bool_name = "bool"; - char_name = "char"; + bool_name; + char_name; int_name; - str_name = "string"; + str_name; type_decl_kind_to_qualif; fun_decl_kind_to_qualif; field_name; @@ -577,6 +640,8 @@ let print_decl_end_delimiter (fmt : F.formatter) (kind : decl_kind) = F.pp_print_cut fmt (); F.pp_print_string fmt ".") +let unit_name = match !backend with | Lean -> "Unit" | Coq | FStar -> "unit" + (** [inside] constrols whether we should add parentheses or not around type applications (if [true] we add parentheses). *) @@ -588,13 +653,13 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) | Tuple -> (* This is a bit annoying, but in F*/Coq [()] is not the unit type: * we have to write [unit]... *) - if tys = [] then F.pp_print_string fmt "unit" + if tys = [] then F.pp_print_string fmt unit_name else ( F.pp_print_string fmt "("; Collections.List.iter_link (fun () -> F.pp_print_space fmt (); - let product = match !backend with FStar -> "&" | Coq -> "*" in + let product = match !backend with FStar -> "&" | Coq -> "*" | Lean -> "×" in F.pp_print_string fmt product; F.pp_print_space fmt ()) (extract_ty ctx fmt true) tys; @@ -625,7 +690,7 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (** Compute the names for all the top-level identifiers used in a type definition (type name, variant names, field names, etc. but not type parameters). - + We need to do this preemptively, beforce extracting any definition, because of recursive definitions. *) @@ -694,7 +759,7 @@ let extract_type_decl_variant (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt (field_name ^ " :"); F.pp_print_space fmt (); ctx) - | Coq -> ctx + | Coq | Lean -> ctx in (* Print the field type *) extract_ty ctx fmt false f.field_ty; @@ -817,14 +882,16 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) let _ = if !backend = FStar && fields = [] then ( F.pp_print_space fmt (); - F.pp_print_string fmt "unit") + F.pp_print_string fmt unit_name) else if (not is_rec) || !backend = FStar then ( F.pp_print_space fmt (); (* If Coq: print the constructor name *) + (* TODO: remove superfluous test not is_rec below *) if !backend = Coq && not is_rec then ( F.pp_print_string fmt (ctx_get_struct (AdtId def.def_id) ctx); F.pp_print_string fmt " "); - F.pp_print_string fmt "{"; + if !backend <> Lean then + F.pp_print_string fmt "{"; F.pp_print_break fmt 1 ctx.indent_incr; (* The body itself *) F.pp_open_hvbox fmt 0; @@ -837,7 +904,8 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt ":"; F.pp_print_space fmt (); extract_ty ctx fmt false f.field_ty; - F.pp_print_string fmt ";"; + if !backend <> Lean then + F.pp_print_string fmt ";"; F.pp_close_box fmt () in let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in @@ -847,7 +915,8 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) (* Close *) F.pp_close_box fmt (); F.pp_print_space fmt (); - F.pp_print_string fmt "}") + if !backend <> Lean then + F.pp_print_string fmt "}") else ( (* We extract for Coq, and we have a recursive record, or a record in a group of mutually recursive types: we extract it as an inductive type @@ -859,6 +928,18 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) in () +(** Extract a nestable, muti-line comment *) +let extract_comment (fmt: F.formatter) (s: string): unit = + match !backend with + | Coq | FStar -> + F.pp_print_string fmt "(** "; + F.pp_print_string fmt s; + F.pp_print_string fmt " *)"; + | Lean -> + F.pp_print_string fmt "/- "; + F.pp_print_string fmt s; + F.pp_print_string fmt " -/" + (** Extract a type declaration. Note that all the names used for extraction should already have been @@ -894,7 +975,7 @@ let extract_type_decl (ctx : extraction_ctx) (fmt : F.formatter) (* Add a break before *) F.pp_print_break fmt 0 0; (* Print a comment to link the extracted type to its original rust definition *) - F.pp_print_string fmt ("(** [" ^ Print.name_to_string def.name ^ "] *)"); + extract_comment fmt ("[" ^ Print.name_to_string def.name ^ "]"); F.pp_print_space fmt (); (* Open a box for the definition, so that whenever possible it gets printed on * one line *) @@ -902,10 +983,11 @@ let extract_type_decl (ctx : extraction_ctx) (fmt : F.formatter) (* Open a box for "type TYPE_NAME (TYPE_PARAMS) =" *) F.pp_open_hovbox fmt ctx.indent_incr; (* > "type TYPE_NAME" *) - let qualif = ctx.fmt.type_decl_kind_to_qualif kind type_kind in + let is_rec = decl_is_from_rec_group kind in + let qualif = ctx.fmt.type_decl_kind_to_qualif kind type_kind is_rec in F.pp_print_string fmt (qualif ^ " " ^ def_name); (* Print the type parameters *) - let type_keyword = match !backend with FStar -> "Type0" | Coq -> "Type" in + let type_keyword = match !backend with FStar -> "Type0" | Coq | Lean -> "Type" in if def.type_params <> [] then ( if use_forall then ( F.pp_print_space fmt (); @@ -926,7 +1008,11 @@ let extract_type_decl (ctx : extraction_ctx) (fmt : F.formatter) (* Print the "=" if we extract the body*) if extract_body then ( F.pp_print_space fmt (); - let eq = match !backend with FStar -> "=" | Coq -> ":=" in + let eq = match !backend with + | FStar -> "=" + | Coq -> ":=" + | Lean -> if type_kind = Some Struct && not is_rec then "where" else ":=" + in F.pp_print_string fmt eq) else ( (* Otherwise print ": Type0" *) @@ -1186,7 +1272,7 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx) let extract_type_decl_extra_info (ctx : extraction_ctx) (fmt : F.formatter) (kind : decl_kind) (decl : type_decl) : unit = match !backend with - | FStar -> () + | FStar | Lean -> () | Coq -> extract_type_decl_coq_arguments ctx fmt kind decl; extract_type_decl_record_field_projectors ctx fmt kind decl @@ -1197,13 +1283,28 @@ let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx) (* Add a break before *) F.pp_print_break fmt 0 0; (* Print a comment *) - F.pp_print_string fmt "(** The state type used in the state-error monad *)"; + extract_comment fmt "The state type used in the state-error monad"; F.pp_print_space fmt (); (* Open a box for the definition, so that whenever possible it gets printed on * one line *) F.pp_open_hvbox fmt 0; (* Retrieve the name *) let state_name = ctx_get_assumed_type State ctx in + (* The syntax for Lean and Coq is almost identical. *) + let print_axiom () = + if !backend = Coq then + F.pp_print_string fmt "Axiom" + else + F.pp_print_string fmt "axiom"; + F.pp_print_space fmt (); + F.pp_print_string fmt state_name; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + F.pp_print_string fmt "Type"; + if !backend = Coq then + F.pp_print_string fmt "." + in (* The kind should be [Assumed] or [Declared] *) (match kind with | SingleNonRec | SingleRec | MutRecFirst | MutRecInner | MutRecLast -> @@ -1220,14 +1321,8 @@ let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx) F.pp_print_string fmt ":"; F.pp_print_space fmt (); F.pp_print_string fmt "Type0" - | Coq -> - F.pp_print_string fmt "Axiom"; - F.pp_print_space fmt (); - F.pp_print_string fmt state_name; - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - F.pp_print_string fmt "Type.") + | Coq | Lean -> + print_axiom ()) | Declared -> ( match !backend with | FStar -> @@ -1238,14 +1333,8 @@ let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx) F.pp_print_string fmt ":"; F.pp_print_space fmt (); F.pp_print_string fmt "Type0" - | Coq -> - F.pp_print_string fmt "Axiom"; - F.pp_print_space fmt (); - F.pp_print_string fmt state_name; - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - F.pp_print_string fmt "Type.")); + | Coq | Lean -> + print_axiom ())); (* Close the box for the definition *) F.pp_close_box fmt (); (* Add breaks to insert new lines between definitions *) @@ -1289,7 +1378,7 @@ let extract_global_decl_register_names (ctx : extraction_ctx) Note that patterns can introduce new variables: we thus return an extraction context updated with new bindings. - + TODO: we don't need something very generic anymore (some definitions used to be polymorphic). *) @@ -1374,7 +1463,7 @@ let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter) av.field_values v.ty (** [inside]: controls the introduction of parentheses. See [extract_ty] - + TODO: replace the formatting boolean [inside] with something more general? Also, it seems we don't really use it... Cases to consider: @@ -1555,7 +1644,9 @@ and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter) (* If in Coq, the field projection has to be parenthesized *) (match !backend with | FStar -> F.pp_print_string fmt field_name - | Coq -> F.pp_print_string fmt ("(" ^ field_name ^ ")")); + | Coq -> F.pp_print_string fmt ("(" ^ field_name ^ ")") + | Lean -> F.pp_print_string fmt field_name + ); (* Close the box *) F.pp_close_box fmt () | arg :: args -> @@ -1619,7 +1710,7 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) if monadic && !backend = Coq then ( let ctx = extract_typed_pattern ctx fmt true lv in F.pp_print_space fmt (); - let arrow = match !backend with FStar -> "<--" | Coq -> "<-" in + let arrow = match !backend with FStar -> "<--" | Coq -> "<-" | Lean -> failwith "impossible" in F.pp_print_string fmt arrow; F.pp_print_space fmt (); extract_texpression ctx fmt false re; @@ -1630,7 +1721,7 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) F.pp_print_space fmt (); let ctx = extract_typed_pattern ctx fmt true lv in F.pp_print_space fmt (); - let eq = match !backend with FStar -> "=" | Coq -> ":=" in + let eq = match !backend with FStar -> "=" | Coq -> ":=" | Lean -> "<-" in F.pp_print_string fmt eq; F.pp_print_space fmt (); extract_texpression ctx fmt false re; @@ -1698,9 +1789,10 @@ and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) F.pp_print_space fmt (); F.pp_print_string fmt "begin"; F.pp_print_space fmt () - | Coq -> + | Coq | Lean -> F.pp_print_string fmt " ("; - F.pp_print_cut fmt ()); + F.pp_print_cut fmt () + ); (* Print the branch expression *) extract_texpression ctx fmt false e_branch; (* Close the parenthesized expression *) @@ -1709,7 +1801,7 @@ and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) | FStar -> F.pp_print_space fmt (); F.pp_print_string fmt "end" - | Coq -> F.pp_print_string fmt ")"); + | Coq | Lean -> F.pp_print_string fmt ")"); (* Close the box for the branch *) if not parenth then F.pp_close_box fmt (); (* Close the box for the then/else+branch *) @@ -1723,7 +1815,7 @@ and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) F.pp_open_hovbox fmt ctx.indent_incr; (* Print the [match ... with] *) let match_begin = - match !backend with FStar -> "begin match" | Coq -> "match" + match !backend with FStar -> "begin match" | Coq | Lean -> "match" in F.pp_print_string fmt match_begin; F.pp_print_space fmt (); @@ -1744,7 +1836,7 @@ and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) F.pp_print_space fmt (); let ctx = extract_typed_pattern ctx fmt false br.pat in F.pp_print_space fmt (); - let arrow = match !backend with FStar -> "->" | Coq -> "=>" in + let arrow = match !backend with FStar -> "->" | Coq | Lean -> "=>" in F.pp_print_string fmt arrow; F.pp_print_space fmt (); (* Open a box for the branch *) @@ -1798,7 +1890,7 @@ let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx) def.signature.type_params; F.pp_print_string fmt ":"; F.pp_print_space fmt (); - let type_keyword = match !backend with FStar -> "Type0" | Coq -> "Type" in + let type_keyword = match !backend with FStar -> "Type0" | Coq | Lean -> "Type" in F.pp_print_string fmt (type_keyword ^ ")"); (* Close the box for the type parameters *) F.pp_close_box fmt ()); @@ -1829,7 +1921,7 @@ let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx) (** A small utility to print the types of the input parameters in the form: [u32 -> list u32 -> ...] (we don't print the return type of the function) - + This is used for opaque function declarations, in particular. *) let extract_fun_input_parameters_types (ctx : extraction_ctx) @@ -1853,7 +1945,7 @@ let extract_fun_input_parameters_types (ctx : extraction_ctx) {[ let f_decrease (t : Type0) (x : t) : nat = admit() ]} - + Where the translated functions for [f] look like this: {[ let f_fwd (t : Type0) (x : t) : Tot ... (decreases (f_decrease t x)) = ... @@ -1867,8 +1959,7 @@ let extract_template_decreases_clause (ctx : extraction_ctx) (fmt : F.formatter) (* Add a break before *) F.pp_print_break fmt 0 0; (* Print a comment to link the extracted type to its original rust definition *) - F.pp_print_string fmt - ("(** [" ^ Print.fun_name_to_string def.basename ^ "]: decreases clause *)"); + extract_comment fmt ("[" ^ Print.fun_name_to_string def.basename ^ "]: decreases clause"); F.pp_print_space fmt (); (* Open a box for the definition, so that whenever possible it gets printed on * one line *) @@ -1910,7 +2001,7 @@ let extract_template_decreases_clause (ctx : extraction_ctx) (fmt : F.formatter) Note that all the names used for extraction should already have been registered. - + We take the definition of the forward translation as parameter (which is equal to the definition to extract, if we extract a forward function) because it is useful for the decrease clause. @@ -1925,8 +2016,7 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) (* Add a break before *) F.pp_print_break fmt 0 0; (* Print a comment to link the extracted type to its original rust definition *) - F.pp_print_string fmt - ("(** [" ^ Print.fun_name_to_string def.basename ^ "] *)"); + extract_comment fmt ("[" ^ Print.fun_name_to_string def.basename ^ "]"); F.pp_print_space fmt (); (* Open two boxes for the definition, so that whenever possible it gets printed on * one line and indents are correct *) @@ -2043,7 +2133,7 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) (* Print the "=" *) if not is_opaque then ( F.pp_print_space fmt (); - let eq = match !backend with FStar -> "=" | Coq -> ":=" in + let eq = match !backend with FStar -> "=" | Coq | Lean -> ":=" in F.pp_print_string fmt eq); (* Close the box for "(PARAMS) : EFFECT =" *) F.pp_close_box fmt (); @@ -2102,7 +2192,7 @@ let extract_global_decl_body (ctx : extraction_ctx) (fmt : F.formatter) if not is_opaque then ( (* Print " =" *) F.pp_print_space fmt (); - let eq = match !backend with FStar -> "=" | Coq -> ":=" in + let eq = match !backend with FStar -> "=" | Coq | Lean -> ":=" in F.pp_print_string fmt eq); (* Close ": TYPE =" box (depth=2) *) F.pp_close_box fmt (); @@ -2149,8 +2239,7 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) (* Add a break then the name of the corresponding LLBC declaration *) F.pp_print_break fmt 0 0; - F.pp_print_string fmt - ("(** [" ^ Print.global_name_to_string global.name ^ "] *)"); + extract_comment fmt ("[" ^ Print.global_name_to_string global.name ^ "]"); F.pp_print_space fmt (); let decl_name = ctx_get_global global.def_id ctx in @@ -2177,6 +2266,7 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) let body = match !backend with | FStar -> "eval_global " ^ body_name + | Lean -> "eval_global " ^ body_name ^ " (by simp)" | Coq -> body_name ^ "%global" in F.pp_print_string fmt body)); @@ -2185,7 +2275,7 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) (** Extract a unit test, if the function is a unit function (takes no parameters, returns unit). - + A unit test simply checks that the function normalizes to [Return ()]. F*: @@ -2212,8 +2302,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) (* Add a break before *) F.pp_print_break fmt 0 0; (* Print a comment *) - F.pp_print_string fmt - ("(** Unit test for [" ^ Print.fun_name_to_string def.basename ^ "] *)"); + extract_comment fmt ("Unit test for [" ^ Print.fun_name_to_string def.basename ^ "]"); F.pp_print_space fmt (); (* Open a box for the test *) F.pp_open_hovbox fmt ctx.indent_incr; @@ -2249,7 +2338,24 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_space fmt (); F.pp_print_string fmt "()"); F.pp_print_space fmt (); - F.pp_print_string fmt ")%return."); + F.pp_print_string fmt ")%return." + | Lean -> + F.pp_print_string fmt "#assert"; + F.pp_print_space fmt (); + F.pp_print_string fmt "("; + let fun_name = + ctx_get_local_function def.def_id def.loop_id def.back_id ctx + in + F.pp_print_string fmt fun_name; + if sg.inputs <> [] then ( + F.pp_print_space fmt (); + F.pp_print_string fmt "()"); + F.pp_print_space fmt (); + F.pp_print_string fmt "="; + F.pp_print_space fmt (); + let success = ctx_get_variant (Assumed Result) result_return_id ctx in + F.pp_print_string fmt (success ^ " ())") + ); (* Close the box for the test *) F.pp_close_box fmt (); (* Add a break after *) diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 3ad55d37..c8094128 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -118,7 +118,7 @@ type formatter = { char_name : string; int_name : integer_type -> string; str_name : string; - type_decl_kind_to_qualif : decl_kind -> type_decl_kind option -> string; + type_decl_kind_to_qualif : decl_kind -> type_decl_kind option -> bool -> string; (** Compute the qualified for a type definition/declaration. For instance: "type", "and", etc. diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index c6ef4297..f357f33b 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -2615,6 +2615,8 @@ let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression) let match_ty = body.ty in let match_e = Switch (fuel0, Match [ fail_branch; success_branch ]) in { e = match_e; ty = match_ty } + | Lean -> + failwith "Not handling fuel in Lean" let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = (* Translate *) diff --git a/compiler/Translate.ml b/compiler/Translate.ml index c42f3a27..6b3d00f3 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -672,8 +672,14 @@ let extract_file (config : gen_config) (ctx : gen_ctx) (filename : string) * internal count is consistent with the state of the file. *) (* Create the header *) - Printf.fprintf out "(** THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS *)\n"; - Printf.fprintf out "(** [%s]%s *)\n" rust_module_name custom_msg; + begin match !Config.backend with + | Lean -> + Printf.fprintf out "-- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS\n"; + Printf.fprintf out "-- [%s]%s\n" rust_module_name custom_msg; + | Coq | FStar -> + Printf.fprintf out "(** THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS *)\n"; + Printf.fprintf out "(** [%s]%s *)\n" rust_module_name custom_msg + end; (match !Config.backend with | FStar -> Printf.fprintf out "module %s\n" module_name; @@ -700,7 +706,14 @@ let extract_file (config : gen_config) (ctx : gen_ctx) (filename : string) Printf.fprintf out "Require Export %s.\n" m; Printf.fprintf out "Import %s.\n" m) custom_includes; - Printf.fprintf out "Module %s.\n" module_name); + Printf.fprintf out "Module %s.\n" module_name + | Lean -> + Printf.fprintf out "import Primitives\nopen result\n\n"; + (* Add the custom imports *) + List.iter (fun m -> Printf.fprintf out "import %s\n" m) custom_imports; + (* Add the custom includes *) + List.iter (fun m -> Printf.fprintf out "import %s\n" m) custom_includes + ); (* From now onwards, we use the formatter *) (* Set the margin *) Format.pp_set_margin fmt 80; @@ -717,7 +730,7 @@ let extract_file (config : gen_config) (ctx : gen_ctx) (filename : string) (* Close the module *) (match !Config.backend with - | FStar -> () + | FStar | Lean -> () | Coq -> Printf.fprintf out "End %s .\n" module_name); (* Some logging *) @@ -846,6 +859,7 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : match !Config.backend with | Config.FStar -> ("/backends/fstar/Primitives.fst", "Primitives.fst") | Config.Coq -> ("/backends/coq/Primitives.v", "Primitives.v") + | Config.Lean -> ("/backends/lean/Primitives.lean", "Primitives.lean") in let src = open_in (exe_dir ^ primitives_src) in let tgt_filename = Filename.concat dest_dir primitives_destname in @@ -875,8 +889,9 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : in let module_delimiter = - match !Config.backend with FStar -> "." | Coq -> "_" + match !Config.backend with FStar | Lean -> "." | Coq -> "_" in + let ext = match !Config.backend with FStar -> ".fst" | Coq -> ".v" | Lean -> ".lean" in (* Extract one or several files, depending on the configuration *) if !Config.split_files then ( @@ -904,7 +919,8 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : let types_filename_ext = match !Config.backend with | FStar -> if has_opaque_types then ".fsti" else ".fst" - | Coq -> if has_opaque_types then ".v" else ".v" + | Coq -> ".v" + | Lean -> ".lean" in let types_file_suffix = module_delimiter ^ "Types" in let types_filename = @@ -928,24 +944,22 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : !Config.extract_decreases_clauses && not (PureUtils.FunLoopIdSet.is_empty rec_functions) in - (if needs_clauses_module && !Config.extract_template_decreases_clauses then - let ext = match !Config.backend with FStar -> ".fst" | Coq -> ".v" in - let clauses_file_suffix = - module_delimiter ^ "Clauses" ^ module_delimiter ^ "Template" - in - let clauses_filename = extract_filebasename ^ clauses_file_suffix ^ ext in - let clauses_module = module_name ^ clauses_file_suffix in - let clauses_config = - { base_gen_config with extract_template_decreases_clauses = true } - in - extract_file clauses_config gen_ctx clauses_filename crate.A.name - clauses_module ": templates for the decreases clauses" [ types_module ] - []); + if needs_clauses_module && !Config.extract_template_decreases_clauses then ( + let clauses_file_suffix = + module_delimiter ^ "Clauses" ^ module_delimiter ^ "Template" + in + let clauses_filename = extract_filebasename ^ clauses_file_suffix ^ ext in + let clauses_module = module_name ^ clauses_file_suffix in + let clauses_config = + { base_gen_config with extract_template_decreases_clauses = true } + in + extract_file clauses_config gen_ctx clauses_filename crate.A.name + clauses_module ": templates for the decreases clauses" [ types_module ] + []); (* Extract the opaque functions, if needed *) let opaque_funs_module = if has_opaque_funs then ( - let ext = match !Config.backend with FStar -> ".fsti" | Coq -> ".v" in let opaque_file_suffix = module_delimiter ^ "Opaque" in let opaque_filename = extract_filebasename ^ opaque_file_suffix ^ ext in let opaque_module = module_name ^ opaque_file_suffix in @@ -965,7 +979,6 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : in (* Extract the functions *) - let ext = match !Config.backend with FStar -> ".fst" | Coq -> ".v" in let fun_file_suffix = module_delimiter ^ "Funs" in let fun_filename = extract_filebasename ^ fun_file_suffix ^ ext in let fun_module = module_name ^ fun_file_suffix in @@ -1000,7 +1013,6 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : } in (* Add the extension for F* *) - let ext = match !Config.backend with FStar -> ".fst" | Coq -> ".v" in let extract_filename = extract_filebasename ^ ext in extract_file gen_config gen_ctx extract_filename crate.A.name module_name "" [] [] diff --git a/compiler/dune b/compiler/dune index ae9cef04..b74b65fa 100644 --- a/compiler/dune +++ b/compiler/dune @@ -81,7 +81,7 @@ -g ;-dsource -warn-error - -5-8-9-11-14-33-20-21-26-27-39)) + -5@8-9-11-14-33-20-21-26-27-39)) (release (flags :standard diff --git a/tests/Makefile b/tests/Makefile index dfb20cc4..a6a85d2d 100644 --- a/tests/Makefile +++ b/tests/Makefile @@ -1,3 +1,4 @@ -all: - cd fstar && $(MAKE) all - cd coq && $(MAKE) all +all: test-fstar test-coq test-lean + +test-%: + cd $* && $(MAKE) all |