summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon HO2024-05-24 17:08:42 +0200
committerGitHub2024-05-24 17:08:42 +0200
commitfbfa0e13ab56ee847e891fa7d798d2eb226b6794 (patch)
tree58f0a2de653f57a986bb5e5f26453a1fbdf0ef17
parent0baa0519cf477fe1fa447417585960fc811bcae9 (diff)
parent169af47945f013e61b14d67e7ebdc9c03636c5a2 (diff)
Merge pull request #194 from AeneasVerif/afromher/recursive_projectors
Support field projectors for recursive structs in Lean backend
Diffstat (limited to '')
-rw-r--r--compiler/ExtractTypes.ml71
-rw-r--r--compiler/SymbolicToPure.ml5
-rw-r--r--tests/lean/BetreeMain/Funs.lean81
-rw-r--r--tests/lean/BetreeMain/Types.lean16
-rw-r--r--tests/test_runner/aeneas_test_runner.opam2
5 files changed, 106 insertions, 69 deletions
diff --git a/compiler/ExtractTypes.ml b/compiler/ExtractTypes.ml
index a2d4758b..cc0c351d 100644
--- a/compiler/ExtractTypes.ml
+++ b/compiler/ExtractTypes.ml
@@ -1666,14 +1666,15 @@ let extract_type_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter)
(** Auxiliary function.
- Generate field projectors in Coq.
+ Generate field projectors for Lean and Coq.
- Sometimes we extract records as inductives in Coq: when this happens we
- have to define the field projectors afterwards.
+ Recursive structs are defined as inductives in Lean and Coq.
+ Field projectors allow to retrieve the facilities provided by
+ Lean structures.
*)
let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
(fmt : F.formatter) (kind : decl_kind) (decl : type_decl) : unit =
- sanity_check __FILE__ __LINE__ (!backend = Coq) decl.span;
+ sanity_check __FILE__ __LINE__ (!backend = Coq || !backend = Lean) decl.span;
match decl.kind with
| Opaque | Enum _ -> ()
| Struct fields ->
@@ -1685,29 +1686,60 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
ctx_add_generic_params decl.span decl.llbc_name decl.llbc_generics
decl.generics ctx
in
+ (* Record_var will be the ADT argument to the projector *)
let ctx, record_var = ctx_add_var decl.span "x" (VarId.of_int 0) ctx in
+ (* Field_var will be the variable in the constructor that is returned by the projector *)
let ctx, field_var = ctx_add_var decl.span "x" (VarId.of_int 1) ctx in
+ (* Name of the ADT *)
let def_name = ctx_get_local_type decl.span decl.def_id ctx in
+ (* Name of the ADT constructor. As we are in the struct case, we only have
+ one constructor *)
let cons_name = ctx_get_struct decl.span (TAdtId decl.def_id) ctx in
+
let extract_field_proj (field_id : FieldId.id) (_ : field) : unit =
F.pp_print_space fmt ();
(* Outer box for the projector definition *)
F.pp_open_hvbox fmt 0;
(* Inner box for the projector definition *)
F.pp_open_hvbox fmt ctx.indent_incr;
- (* Open a box for the [Definition PROJ ... :=] *)
+
+ (* For Lean: add some attributes *)
+ if !backend = Lean then (
+ (* Box for the attributes *)
+ F.pp_open_vbox fmt 0;
+ (* Annotate the projectors with both simp and reducible.
+ The first one allows to automatically unfold when calling simp in proofs.
+ The second ensures that projectors will interact well with the unifier *)
+ F.pp_print_string fmt "@[simp, reducible]";
+ F.pp_print_break fmt 0 0;
+ (* Close box for the attributes *)
+ F.pp_close_box fmt ());
+
+ (* Box for the [def ADT.proj ... :=] *)
F.pp_open_hovbox fmt ctx.indent_incr;
- F.pp_print_string fmt "Definition";
+ (match !backend with
+ | Lean -> F.pp_print_string fmt "def"
+ | Coq -> F.pp_print_string fmt "Definition"
+ | _ -> internal_error __FILE__ __LINE__ decl.span);
F.pp_print_space fmt ();
+
+ (* Print the function name. In Lean, the syntax ADT.proj will
+ allow us to call x.proj for any x of type ADT. In Coq,
+ we will have to introduce a notation for the projector. *)
let field_name =
ctx_get_field decl.span (TAdtId decl.def_id) field_id ctx
in
+ if !backend = Lean then (
+ F.pp_print_string fmt def_name;
+ F.pp_print_string fmt ".");
F.pp_print_string fmt field_name;
+
(* Print the generics *)
let as_implicits = true in
extract_generic_params decl.span ctx fmt TypeDeclId.Set.empty
~as_implicits decl.generics type_params cg_params trait_clauses;
- (* Print the record parameter *)
+
+ (* Print the record parameter as "(x : ADT)" *)
F.pp_print_space fmt ();
F.pp_print_string fmt "(";
F.pp_print_string fmt record_var;
@@ -1721,14 +1753,17 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
F.pp_print_string fmt p)
type_params;
F.pp_print_string fmt ")";
- (* *)
+
F.pp_print_space fmt ();
F.pp_print_string fmt ":=";
- (* Close the box for the [Definition PROJ ... :=] *)
+
+ (* Close the box for the [def ADT.proj ... :=] *)
F.pp_close_box fmt ();
F.pp_print_space fmt ();
+
(* Open a box for the whole match *)
F.pp_open_hvbox fmt 0;
+
(* Open a box for the [match ... with] *)
F.pp_open_hovbox fmt ctx.indent_incr;
F.pp_print_string fmt "match";
@@ -1758,9 +1793,12 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
F.pp_print_string fmt field_var;
(* Close the box for the branch *)
F.pp_close_box fmt ();
+
(* Print the [end] *)
- F.pp_print_space fmt ();
- F.pp_print_string fmt "end";
+ if !backend = Coq then (
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "end");
+
(* Close the box for the whole match *)
F.pp_close_box fmt ();
(* Close the inner box projector *)
@@ -1769,12 +1807,13 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
if !backend = Coq then (
F.pp_print_cut fmt ();
F.pp_print_string fmt ".");
- (* Close the outer box projector *)
+ (* Close the outer box for projector definition *)
F.pp_close_box fmt ();
(* Add breaks to insert new lines between definitions *)
F.pp_print_break fmt 0 0
in
+ (* Only for Coq: we need to define a notation for the projector *)
let extract_proj_notation (field_id : FieldId.id) (_ : field) : unit =
F.pp_print_space fmt ();
(* Outer box for the projector definition *)
@@ -1815,7 +1854,7 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
let extract_field_proj_and_notation (field_id : FieldId.id)
(field : field) : unit =
extract_field_proj field_id field;
- extract_proj_notation field_id field
+ if !backend = Coq then extract_proj_notation field_id field
in
FieldId.iteri extract_field_proj_and_notation fields
@@ -1828,14 +1867,14 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
let extract_type_decl_extra_info (ctx : extraction_ctx) (fmt : F.formatter)
(kind : decl_kind) (decl : type_decl) : unit =
match !backend with
- | FStar | Lean | HOL4 -> ()
- | Coq ->
+ | FStar | HOL4 -> ()
+ | Lean | Coq ->
if
not
(TypesUtils.type_decl_from_decl_id_is_tuple_struct
ctx.trans_ctx.type_ctx.type_infos decl.def_id)
then (
- extract_type_decl_coq_arguments ctx fmt kind decl;
+ if !backend = Coq then extract_type_decl_coq_arguments ctx fmt kind decl;
extract_type_decl_record_field_projectors ctx fmt kind decl)
(** Extract the state type declaration. *)
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 8dfe0abe..d6d2e018 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -2903,14 +2903,9 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value)
- if the ADT is an enumeration (which must have exactly one branch)
- if we forbid using field projectors.
*)
- let is_rec_def =
- T.TypeDeclId.Set.mem adt_id ctx.type_ctx.recursive_decls
- in
let use_let_with_cons =
is_enum
|| !Config.dont_use_field_projectors
- (* TODO: for now, we don't have field projectors over recursive ADTs in Lean. *)
- || (!Config.backend = Lean && is_rec_def)
(* Also, there is a special case when the ADT is to be extracted as
a tuple, because it is a structure with unnamed fields. Some backends
like Lean have projectors for tuples (like so: `x.3`), but others
diff --git a/tests/lean/BetreeMain/Funs.lean b/tests/lean/BetreeMain/Funs.lean
index 7cc52159..f6fda6db 100644
--- a/tests/lean/BetreeMain/Funs.lean
+++ b/tests/lean/BetreeMain/Funs.lean
@@ -250,16 +250,15 @@ mutual divergent def betree.Internal.lookup_in_children
(self : betree.Internal) (key : U64) (st : State) :
Result (State × ((Option U64) × betree.Internal))
:=
- let ⟨ i, i1, n, n1 ⟩ := self
- if key < i1
+ if key < self.pivot
then
do
- let (st1, (o, n2)) ← betree.Node.lookup n key st
- Result.ok (st1, (o, betree.Internal.mk i i1 n2 n1))
+ let (st1, (o, n)) ← betree.Node.lookup self.left key st
+ Result.ok (st1, (o, betree.Internal.mk self.id self.pivot n self.right))
else
do
- let (st1, (o, n2)) ← betree.Node.lookup n1 key st
- Result.ok (st1, (o, betree.Internal.mk i i1 n n2))
+ let (st1, (o, n)) ← betree.Node.lookup self.right key st
+ Result.ok (st1, (o, betree.Internal.mk self.id self.pivot self.left n))
/- [betree_main::betree::{betree_main::betree::Node#5}::lookup]:
Source: 'src/betree.rs', lines 709:4-709:58 -/
@@ -270,8 +269,7 @@ divergent def betree.Node.lookup
match self with
| betree.Node.Internal node =>
do
- let ⟨ i, i1, n, n1 ⟩ := node
- let (st1, msgs) ← betree.load_internal_node i st
+ let (st1, msgs) ← betree.load_internal_node node.id st
let (pending, lookup_first_message_for_key_back) ←
betree.Node.lookup_first_message_for_key key msgs
match pending with
@@ -281,8 +279,7 @@ divergent def betree.Node.lookup
then
do
let (st2, (o, node1)) ←
- betree.Internal.lookup_in_children (betree.Internal.mk i i1 n n1) key
- st1
+ betree.Internal.lookup_in_children node key st1
let _ ←
lookup_first_message_for_key_back (betree.List.Cons (k, msg) l)
Result.ok (st2, (o, betree.Node.Internal node1))
@@ -293,33 +290,26 @@ divergent def betree.Node.lookup
let _ ←
lookup_first_message_for_key_back (betree.List.Cons (k,
betree.Message.Insert v) l)
- Result.ok (st1, (some v, betree.Node.Internal (betree.Internal.mk i
- i1 n n1)))
+ Result.ok (st1, (some v, betree.Node.Internal node))
| betree.Message.Delete =>
do
let _ ←
lookup_first_message_for_key_back (betree.List.Cons (k,
betree.Message.Delete) l)
- Result.ok (st1, (none, betree.Node.Internal (betree.Internal.mk i i1
- n n1)))
+ Result.ok (st1, (none, betree.Node.Internal node))
| betree.Message.Upsert ufs =>
do
let (st2, (v, node1)) ←
- betree.Internal.lookup_in_children (betree.Internal.mk i i1 n n1)
- key st1
+ betree.Internal.lookup_in_children node key st1
let (v1, pending1) ←
betree.Node.apply_upserts (betree.List.Cons (k,
betree.Message.Upsert ufs) l) v key
- let ⟨ i2, i3, n2, n3 ⟩ := node1
let msgs1 ← lookup_first_message_for_key_back pending1
- let (st3, _) ← betree.store_internal_node i2 msgs1 st2
- Result.ok (st3, (some v1, betree.Node.Internal (betree.Internal.mk i2
- i3 n2 n3)))
+ let (st3, _) ← betree.store_internal_node node1.id msgs1 st2
+ Result.ok (st3, (some v1, betree.Node.Internal node1))
| betree.List.Nil =>
do
- let (st2, (o, node1)) ←
- betree.Internal.lookup_in_children (betree.Internal.mk i i1 n n1) key
- st1
+ let (st2, (o, node1)) ← betree.Internal.lookup_in_children node key st1
let _ ← lookup_first_message_for_key_back betree.List.Nil
Result.ok (st2, (o, betree.Node.Internal node1))
| betree.Node.Leaf node =>
@@ -541,34 +531,36 @@ mutual divergent def betree.Internal.flush
× betree.NodeIdCounter)))
:=
do
- let ⟨ i, i1, n, n1 ⟩ := self
- let p ← betree.ListPairU64T.partition_at_pivot betree.Message content i1
+ let p ←
+ betree.ListPairU64T.partition_at_pivot betree.Message content self.pivot
let (msgs_left, msgs_right) := p
let len_left ← betree.List.len (U64 × betree.Message) msgs_left
if len_left >= params.min_flush_size
then
do
let (st1, p1) ←
- betree.Node.apply_messages n params node_id_cnt msgs_left st
- let (n2, node_id_cnt1) := p1
+ betree.Node.apply_messages self.left params node_id_cnt msgs_left st
+ let (n, node_id_cnt1) := p1
let len_right ← betree.List.len (U64 × betree.Message) msgs_right
if len_right >= params.min_flush_size
then
do
let (st2, p2) ←
- betree.Node.apply_messages n1 params node_id_cnt1 msgs_right st1
- let (n3, node_id_cnt2) := p2
- Result.ok (st2, (betree.List.Nil, (betree.Internal.mk i i1 n2 n3,
- node_id_cnt2)))
+ betree.Node.apply_messages self.right params node_id_cnt1 msgs_right
+ st1
+ let (n1, node_id_cnt2) := p2
+ Result.ok (st2, (betree.List.Nil, (betree.Internal.mk self.id self.pivot
+ n n1, node_id_cnt2)))
else
- Result.ok (st1, (msgs_right, (betree.Internal.mk i i1 n2 n1,
- node_id_cnt1)))
+ Result.ok (st1, (msgs_right, (betree.Internal.mk self.id self.pivot n
+ self.right, node_id_cnt1)))
else
do
let (st1, p1) ←
- betree.Node.apply_messages n1 params node_id_cnt msgs_right st
- let (n2, node_id_cnt1) := p1
- Result.ok (st1, (msgs_left, (betree.Internal.mk i i1 n n2, node_id_cnt1)))
+ betree.Node.apply_messages self.right params node_id_cnt msgs_right st
+ let (n, node_id_cnt1) := p1
+ Result.ok (st1, (msgs_left, (betree.Internal.mk self.id self.pivot
+ self.left n, node_id_cnt1)))
/- [betree_main::betree::{betree_main::betree::Node#5}::apply_messages]:
Source: 'src/betree.rs', lines 588:4-593:5 -/
@@ -581,26 +573,21 @@ divergent def betree.Node.apply_messages
match self with
| betree.Node.Internal node =>
do
- let ⟨ i, i1, n, n1 ⟩ := node
- let (st1, content) ← betree.load_internal_node i st
+ let (st1, content) ← betree.load_internal_node node.id st
let content1 ← betree.Node.apply_messages_to_internal content msgs
let num_msgs ← betree.List.len (U64 × betree.Message) content1
if num_msgs >= params.min_flush_size
then
do
let (st2, (content2, p)) ←
- betree.Internal.flush (betree.Internal.mk i i1 n n1) params node_id_cnt
- content1 st1
+ betree.Internal.flush node params node_id_cnt content1 st1
let (node1, node_id_cnt1) := p
- let ⟨ i2, i3, n2, n3 ⟩ := node1
- let (st3, _) ← betree.store_internal_node i2 content2 st2
- Result.ok (st3, (betree.Node.Internal (betree.Internal.mk i2 i3 n2 n3),
- node_id_cnt1))
+ let (st3, _) ← betree.store_internal_node node1.id content2 st2
+ Result.ok (st3, (betree.Node.Internal node1, node_id_cnt1))
else
do
- let (st2, _) ← betree.store_internal_node i content1 st1
- Result.ok (st2, (betree.Node.Internal (betree.Internal.mk i i1 n n1),
- node_id_cnt))
+ let (st2, _) ← betree.store_internal_node node.id content1 st1
+ Result.ok (st2, (betree.Node.Internal node, node_id_cnt))
| betree.Node.Leaf node =>
do
let (st1, content) ← betree.load_leaf_node node.id st
diff --git a/tests/lean/BetreeMain/Types.lean b/tests/lean/BetreeMain/Types.lean
index 877508f6..e79da43f 100644
--- a/tests/lean/BetreeMain/Types.lean
+++ b/tests/lean/BetreeMain/Types.lean
@@ -46,6 +46,22 @@ inductive betree.Node :=
end
+@[simp, reducible]
+def betree.Internal.id (x : betree.Internal) :=
+ match x with | betree.Internal.mk x1 _ _ _ => x1
+
+@[simp, reducible]
+def betree.Internal.pivot (x : betree.Internal) :=
+ match x with | betree.Internal.mk _ x1 _ _ => x1
+
+@[simp, reducible]
+def betree.Internal.left (x : betree.Internal) :=
+ match x with | betree.Internal.mk _ _ x1 _ => x1
+
+@[simp, reducible]
+def betree.Internal.right (x : betree.Internal) :=
+ match x with | betree.Internal.mk _ _ _ x1 => x1
+
/- [betree_main::betree::Params]
Source: 'src/betree.rs', lines 187:0-187:13 -/
structure betree.Params where
diff --git a/tests/test_runner/aeneas_test_runner.opam b/tests/test_runner/aeneas_test_runner.opam
index b57cc9f6..1539c521 100644
--- a/tests/test_runner/aeneas_test_runner.opam
+++ b/tests/test_runner/aeneas_test_runner.opam
@@ -7,7 +7,7 @@ homepage: "https://github.com/AeneasVerif/aeneas"
bug-reports: "https://github.com/AeneasVerif/aeneas/issues"
depends: [
"ocaml"
- "dune" {>= "3.12"}
+ "dune" {>= "3.7"}
"odoc" {with-doc}
]
build: [