From fc21cf96f80ccb7e6455c057987bb0ff4597c0bb Mon Sep 17 00:00:00 2001 From: Son Ho Date: Sun, 13 Nov 2022 23:00:38 +0100 Subject: Make good progress on the Coq backend --- .gitignore | 11 + Makefile | 96 +- backends/coq/Primitives.v | 478 ++++++++ compiler/Config.ml | 67 +- compiler/Driver.ml | 36 +- compiler/Extract.ml | 2001 ++++++++++++++++++++++++++++++++++ compiler/ExtractBase.ml | 811 ++++++++++++++ compiler/ExtractToBackend.ml | 1639 ---------------------------- compiler/ExtractToCoq.ml | 8 + compiler/ExtractToFStar.ml | 8 + compiler/Logging.ml | 4 +- compiler/PureMicroPasses.ml | 1 + compiler/PureToExtract.ml | 734 ------------- compiler/PureUtils.ml | 4 +- compiler/SymbolicAst.ml | 4 + compiler/SymbolicToPure.ml | 28 +- compiler/Translate.ml | 224 ++-- compiler/Values.ml | 1 + compiler/dune | 6 +- tests/Makefile | 3 + tests/coq/Makefile | 3 + tests/coq/misc/Constants.v | 138 +++ tests/coq/misc/External__Funs.v | 100 ++ tests/coq/misc/External__Opaque.v | 36 + tests/coq/misc/External__Types.v | 15 + tests/coq/misc/Makefile | 22 + tests/coq/misc/NoNestedBorrows.v | 510 +++++++++ tests/coq/misc/Paper.v | 114 ++ tests/coq/misc/Primitives.v | 478 ++++++++ tests/coq/misc/_CoqProject | 12 + tests/fstar/misc/NoNestedBorrows.fst | 53 - 31 files changed, 5094 insertions(+), 2551 deletions(-) create mode 100644 backends/coq/Primitives.v create mode 100644 compiler/Extract.ml create mode 100644 compiler/ExtractBase.ml delete mode 100644 compiler/ExtractToBackend.ml create mode 100644 compiler/ExtractToCoq.ml create mode 100644 compiler/ExtractToFStar.ml delete mode 100644 compiler/PureToExtract.ml create mode 100644 tests/Makefile create mode 100644 tests/coq/Makefile create mode 100644 tests/coq/misc/Constants.v create mode 100644 tests/coq/misc/External__Funs.v create mode 100644 tests/coq/misc/External__Opaque.v create mode 100644 tests/coq/misc/External__Types.v create mode 100644 tests/coq/misc/Makefile create mode 100644 tests/coq/misc/NoNestedBorrows.v create mode 100644 tests/coq/misc/Paper.v create mode 100644 tests/coq/misc/Primitives.v create mode 100644 tests/coq/misc/_CoqProject diff --git a/.gitignore b/.gitignore index c0bbbf28..d0a4be47 100644 --- a/.gitignore +++ b/.gitignore @@ -46,6 +46,17 @@ tests/fstar/hashmap/obj/ tests/fstar/hashmap_on_disk/obj/ tests/fstar/misc/obj/ +# Coq +*.vo +*.vok +*.vos +*.glob +*.aux +*.lia.cache +*.d +*Makefile.coq +*CoqMakefile.conf + # Misc /fstar-tests *~ diff --git a/Makefile b/Makefile index 6f35eaac..c4ec50ef 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,7 @@ AENEAS_EXE ?= bin/aeneas.exe # - unfold all the monadic let bindings to matches (required by F*) # - insert calls to the normalizer in the translated code to test the # generated unit functions -OPTIONS += -unfold-monads -test-trans-units +OPTIONS += # # The rules use (and update) the following variables @@ -43,6 +43,8 @@ OPTIONS += -unfold-monads -test-trans-units CHARON_TEST_DIR = # The options with which to call Charon CHARON_OPTIONS = +# The backend sub-directory in which to generate the files +BACKEND_SUBDIR := # The directory in which to extract the result of the translation SUBDIR := @@ -86,8 +88,8 @@ clean: tests: trans-no_nested_borrows trans-paper \ trans-hashmap trans-hashmap_main \ trans-external trans-constants \ - trans-polonius-polonius_list trans-polonius-betree_main \ - test-trans-polonius-betree_main + transp-polonius_list transp-betree_main \ + test-transp-betree_main # Verify the F* files generated by the translation .PHONY: verify @@ -106,70 +108,102 @@ CHARON_CMD = cd $(CHARON_TEST_DIR) && NOT_ALL_TESTS=1 $(MAKE) test-$* endif # The command to run Aeneas on the proper llbc file -AENEAS_CMD = $(AENEAS_EXE) $(CHARON_TEST_DIR)/llbc/$(FILE).llbc -dest tests/fstar/$(SUBDIR) $(OPTIONS) +AENEAS_CMD = $(AENEAS_EXE) $(CHARON_TEST_DIR)/llbc/$(FILE).llbc -dest tests/$(BACKEND_SUBDIR)/$(SUBDIR) $(OPTIONS) # Add specific options to some tests trans-no_nested_borrows trans-paper: \ - OPTIONS += -test-units -test-trans-units -no-split-files -no-state -no-decreases-clauses + OPTIONS += -test-units -test-trans-units -no-split-files -no-state trans-no_nested_borrows trans-paper: SUBDIR:=misc +tfstar-no_nested_borrows tfstar-paper: -trans-hashmap: OPTIONS += -template-clauses -no-state +trans-hashmap: OPTIONS += -no-state trans-hashmap: SUBDIR:=hashmap +tfstar-hashmap: OPTIONS += -decreases-clauses -template-clauses -trans-hashmap_main: OPTIONS += -template-clauses +trans-hashmap_main: OPTIONS += trans-hashmap_main: SUBDIR:=hashmap_on_disk +tfstar-hashmap_main: OPTIONS += -decreases-clauses -template-clauses -trans-polonius-polonius_list: OPTIONS += -test-units -test-trans-units -no-split-files -no-state -no-decreases-clauses -trans-polonius-polonius_list: SUBDIR:=misc +transp-polonius_list: OPTIONS += -test-units -test-trans-units -no-split-files -no-state +transp-polonius_list: SUBDIR:=misc +tfstarp-polonius_list: OPTIONS += -trans-constants: OPTIONS += -test-units -test-trans-units -no-split-files -no-state -no-decreases-clauses +trans-constants: OPTIONS += -test-units -test-trans-units -no-split-files -no-state trans-constants: SUBDIR:=misc +tfstar-constants: OPTIONS += trans-external: OPTIONS += trans-external: SUBDIR:=misc +tfstar-external: OPTIONS += -BETREE_OPTIONS = -template-clauses -trans-polonius-betree_main: OPTIONS += $(BETREE_OPTIONS) -backward-no-state-update -trans-polonius-betree_main: SUBDIR:=betree +BETREE_FSTAR_OPTIONS = -decreases-clauses -template-clauses +transp-betree_main: OPTIONS += -backward-no-state-update +transp-betree_main: SUBDIR:=betree +tfstarp-betree_main: OPTIONS += $(BETREE_FSTAR_OPTIONS) # Additional test on the betree: translate it without `-backward-no-state-update`. # This generates very ugly code, but is good to test the translation. -test-trans-polonius-betree_main: trans-polonius-betree_main -test-trans-polonius-betree_main: OPTIONS += $(BETREE_OPTIONS) -test-trans-polonius-betree_main: SUBDIR:=betree_back_stateful -test-trans-polonius-betree_main: CHARON_TEST_DIR = $(CHARON_TESTS_POLONIUS_DIR) -test-trans-polonius-betree_main: FILE = betree_main -test-trans-polonius-betree_main: +.PHONY: test-transp-betree_main +test-transp-betree_main: transp-betree_main +test-transp-betree_main: OPTIONS += -backend fstar -unfold-monads -test-trans-units +test-transp-betree_main: OPTIONS += $(BETREE_FSTAR_OPTIONS) +test-transp-betree_main: BACKEND_SUBDIR := "fstar" +test-transp-betree_main: SUBDIR:=betree_back_stateful +test-transp-betree_main: CHARON_TEST_DIR = $(CHARON_TESTS_POLONIUS_DIR) +test-transp-betree_main: FILE = betree_main +test-transp-betree_main: $(AENEAS_CMD) # Generic rules to extract the LLBC from a rust file # We use the rules in Charon's Makefile to generate the .llbc files: the options # vary with the test files. -.PHONY: gen-llbc-polonius-% -gen-llbc-polonius-%: CHARON_TEST_DIR = $(CHARON_TESTS_POLONIUS_DIR) -gen-llbc-polonius-%: - $(CHARON_CMD) - .PHONY: gen-llbc-% gen-llbc-%: CHARON_TEST_DIR = $(CHARON_TESTS_REGULAR_DIR) gen-llbc-%: $(CHARON_CMD) -# Generic rule to test the translation of an LLBC file. +# "p" stands for "Polonius" +.PHONY: gen-llbcp-% +gen-llbcp-%: CHARON_TEST_DIR = $(CHARON_TESTS_POLONIUS_DIR) +gen-llbcp-%: + $(CHARON_CMD) + +# Generic rules to test the translation of an LLBC file. # Note that the files requiring the Polonius borrow-checker are generated # in the tests-polonius subdirectory. .PHONY: trans-% -trans-%: -trans-polonius-%: CHARON_TEST_DIR = $(CHARON_TESTS_POLONIUS_DIR) trans-%: CHARON_TEST_DIR = $(CHARON_TESTS_REGULAR_DIR) +trans-%: FILE = $* +trans-%: gen-llbc-% tfstar-% tcoq-% + echo "# Test $* done" + +# "p" stands for "Polonius" +.PHONY: transp-% +transp-%: CHARON_TEST_DIR = $(CHARON_TESTS_POLONIUS_DIR) +transp-%: FILE = $* +transp-%: gen-llbcp-% tfstarp-% + echo "# Test $* done" + +.PHONY: tfstar-% +tfstar-%: OPTIONS += -backend fstar -unfold-monads -test-trans-units +tfstar-%: BACKEND_SUBDIR := fstar +tfstar-%: + $(AENEAS_CMD) -trans-polonius-%: FILE = $* -trans-polonius-%: gen-llbc-polonius-% +# "p" stands for "Polonius" +.PHONY: tfstarp-% +tfstarp-%: OPTIONS += -backend fstar -unfold-monads -test-trans-units +tfstarp-%: BACKEND_SUBDIR := fstar +tfstarp-%: $(AENEAS_CMD) -trans-%: FILE = $* -trans-%: gen-llbc-% +# TODO: -test-trans-units +# It doesn't work on vec_push_fwd, I don't understand why. +.PHONY: tcoq-% +tcoq-%: OPTIONS += -backend coq -decompose-monads +tcoq-%: BACKEND_SUBDIR := coq +tcoq-%: $(AENEAS_CMD) # Nix diff --git a/backends/coq/Primitives.v b/backends/coq/Primitives.v new file mode 100644 index 00000000..c27b8aed --- /dev/null +++ b/backends/coq/Primitives.v @@ -0,0 +1,478 @@ +Require Import Lia. +Require Coq.Strings.Ascii. +Require Coq.Strings.String. +Require Import Coq.Program.Equality. +Require Import Coq.ZArith.ZArith. +Require Import Coq.ZArith.Znat. +Require Import List. +Import ListNotations. + +Module Primitives. + + (* TODO: use more *) +Declare Scope Primitives_scope. + +(*** Result *) + +Inductive result A := + | Return : A -> result A + | Fail_ : result A. + +Arguments Return {_} a. +Arguments Fail_ {_}. + +Definition bind {A B} (m: result A) (f: A -> result B) : result B := + match m with + | Fail_ => Fail_ + | Return x => f x + end. + +Definition return_ {A: Type} (x: A) : result A := Return x . +Definition fail_ {A: Type} : result A := Fail_ . + +Notation "x <- c1 ; c2" := (bind c1 (fun x => c2)) + (at level 61, c1 at next level, right associativity). + +(** Monadic assert *) +Definition massert (b: bool) : result unit := + if b then Return tt else Fail_. + +(** Normalize and unwrap a successful result (used for globals) *) +Definition eval_result_refl {A} {x} (a: result A) (p: a = Return x) : A := + match a as r return (r = Return x -> A) with + | Return a' => fun _ => a' + | Fail_ => fun p' => + False_rect _ (eq_ind Fail_ + (fun e : result A => + match e with + | Return _ => False + | Fail_ => True + end) + I (Return x) p') + end p. + +Notation "x %global" := (eval_result_refl x eq_refl) (at level 40). +Notation "x %return" := (eval_result_refl x eq_refl) (at level 40). + +(* Sanity check *) +Check (if true then Return (1 + 2) else Fail_)%global = 3. + +(*** Misc *) + + +Definition string := Coq.Strings.String.string. +Definition char := Coq.Strings.Ascii.ascii. +Definition char_of_byte := Coq.Strings.Ascii.ascii_of_byte. + +Definition mem_replace_fwd (a : Type) (x : a) (y : a) : a := x . +Definition mem_replace_back (a : Type) (x : a) (y : a) : a := y . + +(*** Scalars *) + +Definition i8_min : Z := -128%Z. +Definition i8_max : Z := 127%Z. +Definition i16_min : Z := -32768%Z. +Definition i16_max : Z := 32767%Z. +Definition i32_min : Z := -2147483648%Z. +Definition i32_max : Z := 2147483647%Z. +Definition i64_min : Z := -9223372036854775808%Z. +Definition i64_max : Z := 9223372036854775807%Z. +Definition i128_min : Z := -170141183460469231731687303715884105728%Z. +Definition i128_max : Z := 170141183460469231731687303715884105727%Z. +Definition u8_min : Z := 0%Z. +Definition u8_max : Z := 255%Z. +Definition u16_min : Z := 0%Z. +Definition u16_max : Z := 65535%Z. +Definition u32_min : Z := 0%Z. +Definition u32_max : Z := 4294967295%Z. +Definition u64_min : Z := 0%Z. +Definition u64_max : Z := 18446744073709551615%Z. +Definition u128_min : Z := 0%Z. +Definition u128_max : Z := 340282366920938463463374607431768211455%Z. + +(** The bounds of [isize] and [usize] vary with the architecture. *) +Axiom isize_min : Z. +Axiom isize_max : Z. +Definition usize_min : Z := 0%Z. +Axiom usize_max : Z. + +Open Scope Z_scope. + +(** We provide those lemmas to reason about the bounds of [isize] and [usize] *) +Axiom isize_min_bound : isize_min <= i32_min. +Axiom isize_max_bound : i32_max <= isize_max. +Axiom usize_max_bound : u32_max <= usize_max. + +Inductive scalar_ty := + | Isize + | I8 + | I16 + | I32 + | I64 + | I128 + | Usize + | U8 + | U16 + | U32 + | U64 + | U128 +. + +Definition scalar_min (ty: scalar_ty) : Z := + match ty with + | Isize => isize_min + | I8 => i8_min + | I16 => i16_min + | I32 => i32_min + | I64 => i64_min + | I128 => i128_min + | Usize => usize_min + | U8 => u8_min + | U16 => u16_min + | U32 => u32_min + | U64 => u64_min + | U128 => u128_min +end. + +Definition scalar_max (ty: scalar_ty) : Z := + match ty with + | Isize => isize_max + | I8 => i8_max + | I16 => i16_max + | I32 => i32_max + | I64 => i64_max + | I128 => i128_max + | Usize => usize_max + | U8 => u8_max + | U16 => u16_max + | U32 => u32_max + | U64 => u64_max + | U128 => u128_max +end. + +(** We use the following conservative bounds to make sure we can compute bound + checks in most situations *) +Definition scalar_min_cons (ty: scalar_ty) : Z := + match ty with + | Isize => i32_min + | Usize => u32_min + | _ => scalar_min ty +end. + +Definition scalar_max_cons (ty: scalar_ty) : Z := + match ty with + | Isize => i32_max + | Usize => u32_max + | _ => scalar_max ty +end. + +Lemma scalar_min_cons_valid : forall ty, scalar_min ty <= scalar_min_cons ty . +Proof. + destruct ty; unfold scalar_min_cons, scalar_min; try lia. + - pose isize_min_bound; lia. + - apply Z.le_refl. +Qed. + +Lemma scalar_max_cons_valid : forall ty, scalar_max ty >= scalar_max_cons ty . +Proof. + destruct ty; unfold scalar_max_cons, scalar_max; try lia. + - pose isize_max_bound; lia. + - pose usize_max_bound. lia. +Qed. + +Definition scalar (ty: scalar_ty) : Type := + { x: Z | scalar_min ty <= x <= scalar_max ty }. + +Definition to_Z {ty} (x: scalar ty) : Z := proj1_sig x. + +(** Bounds checks: we start by using the conservative bounds, to make sure we + can compute in most situations, then we use the real bounds (for [isize] + and [usize]). *) +Definition scalar_ge_min (ty: scalar_ty) (x: Z) : bool := + Z.leb (scalar_min_cons ty) x || Z.leb (scalar_min ty) x. + +Definition scalar_le_max (ty: scalar_ty) (x: Z) : bool := + Z.leb x (scalar_max_cons ty) || Z.leb x (scalar_max ty). + +Lemma scalar_ge_min_valid (ty: scalar_ty) (x: Z) : + scalar_ge_min ty x = true -> scalar_min ty <= x . +Proof. + unfold scalar_ge_min. + pose (scalar_min_cons_valid ty). + lia. +Qed. + +Lemma scalar_le_max_valid (ty: scalar_ty) (x: Z) : + scalar_le_max ty x = true -> x <= scalar_max ty . +Proof. + unfold scalar_le_max. + pose (scalar_max_cons_valid ty). + lia. +Qed. + +Definition scalar_in_bounds (ty: scalar_ty) (x: Z) : bool := + scalar_ge_min ty x && scalar_le_max ty x . + +Lemma scalar_in_bounds_valid (ty: scalar_ty) (x: Z) : + scalar_in_bounds ty x = true -> scalar_min ty <= x <= scalar_max ty . +Proof. + unfold scalar_in_bounds. + intros H. + destruct (scalar_ge_min ty x) eqn:Hmin. + - destruct (scalar_le_max ty x) eqn:Hmax. + + pose (scalar_ge_min_valid ty x Hmin). + pose (scalar_le_max_valid ty x Hmax). + lia. + + inversion H. + - inversion H. +Qed. + +Import Sumbool. + +Definition mk_scalar (ty: scalar_ty) (x: Z) : result (scalar ty) := + match sumbool_of_bool (scalar_in_bounds ty x) with + | left H => Return (exist _ x (scalar_in_bounds_valid _ _ H)) + | right _ => Fail_ + end. + +Definition scalar_add {ty} (x y: scalar ty) : result (scalar ty) := mk_scalar ty (to_Z x + to_Z y). + +Definition scalar_sub {ty} (x y: scalar ty) : result (scalar ty) := mk_scalar ty (to_Z x - to_Z y). + +Definition scalar_mul {ty} (x y: scalar ty) : result (scalar ty) := mk_scalar ty (to_Z x * to_Z y). + +Definition scalar_div {ty} (x y: scalar ty) : result (scalar ty) := + if to_Z y =? 0 then Fail_ else + mk_scalar ty (to_Z x / to_Z y). + +Definition scalar_rem {ty} (x y: scalar ty) : result (scalar ty) := mk_scalar ty (Z.rem (to_Z x) (to_Z y)). + +Definition scalar_neg {ty} (x: scalar ty) : result (scalar ty) := mk_scalar ty (-(to_Z x)). + +(** Cast an integer from a [src_ty] to a [tgt_ty] *) +(* TODO: check the semantics of casts in Rust *) +Definition scalar_cast (src_ty tgt_ty : scalar_ty) (x : scalar src_ty) : result (scalar tgt_ty) := + mk_scalar tgt_ty (to_Z x). + +(** Comparisons *) +Print Z.leb . + +Definition scalar_leb {ty : scalar_ty} (x : scalar ty) (y : scalar ty) : bool := + Z.leb (to_Z x) (to_Z y) . + +Definition scalar_ltb {ty : scalar_ty} (x : scalar ty) (y : scalar ty) : bool := + Z.ltb (to_Z x) (to_Z y) . + +Definition scalar_geb {ty : scalar_ty} (x : scalar ty) (y : scalar ty) : bool := + Z.geb (to_Z x) (to_Z y) . + +Definition scalar_gtb {ty : scalar_ty} (x : scalar ty) (y : scalar ty) : bool := + Z.gtb (to_Z x) (to_Z y) . + +Definition scalar_eqb {ty : scalar_ty} (x : scalar ty) (y : scalar ty) : bool := + Z.eqb (to_Z x) (to_Z y) . + +Definition scalar_neqb {ty : scalar_ty} (x : scalar ty) (y : scalar ty) : bool := + negb (Z.eqb (to_Z x) (to_Z y)) . + + +(** The scalar types *) +Definition isize := scalar Isize. +Definition i8 := scalar I8. +Definition i16 := scalar I16. +Definition i32 := scalar I32. +Definition i64 := scalar I64. +Definition i128 := scalar I128. +Definition usize := scalar Usize. +Definition u8 := scalar U8. +Definition u16 := scalar U16. +Definition u32 := scalar U32. +Definition u64 := scalar U64. +Definition u128 := scalar U128. + +(** Negaion *) +Definition isize_neg := @scalar_neg Isize. +Definition i8_neg := @scalar_neg I8. +Definition i16_neg := @scalar_neg I16. +Definition i32_neg := @scalar_neg I32. +Definition i64_neg := @scalar_neg I64. +Definition i128_neg := @scalar_neg I128. + +(** Division *) +Definition isize_div := @scalar_div Isize. +Definition i8_div := @scalar_div I8. +Definition i16_div := @scalar_div I16. +Definition i32_div := @scalar_div I32. +Definition i64_div := @scalar_div I64. +Definition i128_div := @scalar_div I128. +Definition usize_div := @scalar_div Usize. +Definition u8_div := @scalar_div U8. +Definition u16_div := @scalar_div U16. +Definition u32_div := @scalar_div U32. +Definition u64_div := @scalar_div U64. +Definition u128_div := @scalar_div U128. + +(** Remainder *) +Definition isize_rem := @scalar_rem Isize. +Definition i8_rem := @scalar_rem I8. +Definition i16_rem := @scalar_rem I16. +Definition i32_rem := @scalar_rem I32. +Definition i64_rem := @scalar_rem I64. +Definition i128_rem := @scalar_rem I128. +Definition usize_rem := @scalar_rem Usize. +Definition u8_rem := @scalar_rem U8. +Definition u16_rem := @scalar_rem U16. +Definition u32_rem := @scalar_rem U32. +Definition u64_rem := @scalar_rem U64. +Definition u128_rem := @scalar_rem U128. + +(** Addition *) +Definition isize_add := @scalar_add Isize. +Definition i8_add := @scalar_add I8. +Definition i16_add := @scalar_add I16. +Definition i32_add := @scalar_add I32. +Definition i64_add := @scalar_add I64. +Definition i128_add := @scalar_add I128. +Definition usize_add := @scalar_add Usize. +Definition u8_add := @scalar_add U8. +Definition u16_add := @scalar_add U16. +Definition u32_add := @scalar_add U32. +Definition u64_add := @scalar_add U64. +Definition u128_add := @scalar_add U128. + +(** Substraction *) +Definition isize_sub := @scalar_sub Isize. +Definition i8_sub := @scalar_sub I8. +Definition i16_sub := @scalar_sub I16. +Definition i32_sub := @scalar_sub I32. +Definition i64_sub := @scalar_sub I64. +Definition i128_sub := @scalar_sub I128. +Definition usize_sub := @scalar_sub Usize. +Definition u8_sub := @scalar_sub U8. +Definition u16_sub := @scalar_sub U16. +Definition u32_sub := @scalar_sub U32. +Definition u64_sub := @scalar_sub U64. +Definition u128_sub := @scalar_sub U128. + +(** Multiplication *) +Definition isize_mul := @scalar_mul Isize. +Definition i8_mul := @scalar_mul I8. +Definition i16_mul := @scalar_mul I16. +Definition i32_mul := @scalar_mul I32. +Definition i64_mul := @scalar_mul I64. +Definition i128_mul := @scalar_mul I128. +Definition usize_mul := @scalar_mul Usize. +Definition u8_mul := @scalar_mul U8. +Definition u16_mul := @scalar_mul U16. +Definition u32_mul := @scalar_mul U32. +Definition u64_mul := @scalar_mul U64. +Definition u128_mul := @scalar_mul U128. + +(** Small utility *) +Definition usize_to_nat (x: usize) : nat := Z.to_nat (to_Z x). + +(** Notations *) +Notation "x %isize" := ((mk_scalar Isize x)%return) (at level 9). +Notation "x %i8" := ((mk_scalar I8 x)%return) (at level 9). +Notation "x %i16" := ((mk_scalar I16 x)%return) (at level 9). +Notation "x %i32" := ((mk_scalar I32 x)%return) (at level 9). +Notation "x %i64" := ((mk_scalar I64 x)%return) (at level 9). +Notation "x %i128" := ((mk_scalar I128 x)%return) (at level 9). +Notation "x %usize" := ((mk_scalar Usize x)%return) (at level 9). +Notation "x %u8" := ((mk_scalar U8 x)%return) (at level 9). +Notation "x %u16" := ((mk_scalar U16 x)%return) (at level 9). +Notation "x %u32" := ((mk_scalar U32 x)%return) (at level 9). +Notation "x %u64" := ((mk_scalar U64 x)%return) (at level 9). +Notation "x %u128" := ((mk_scalar U128 x)%return) (at level 9). + +Notation "x s= y" := (scalar_eqb x y) (at level 80) : Primitives_scope. +Notation "x s<> y" := (scalar_neqb x y) (at level 80) : Primitives_scope. +Notation "x s<= y" := (scalar_leb x y) (at level 80) : Primitives_scope. +Notation "x s< y" := (scalar_ltb x y) (at level 80) : Primitives_scope. +Notation "x s>= y" := (scalar_geb x y) (at level 80) : Primitives_scope. +Notation "x s> y" := (scalar_gtb x y) (at level 80) : Primitives_scope. + +(*** Vectors *) + +Definition vec T := { l: list T | Z.of_nat (length l) <= usize_max }. + +Definition vec_to_list {T: Type} (v: vec T) : list T := proj1_sig v. + +Definition vec_length {T: Type} (v: vec T) : Z := Z.of_nat (length (vec_to_list v)). + +Lemma le_0_usize_max : 0 <= usize_max. +Proof. + pose (H := usize_max_bound). + unfold u32_max in H. + lia. +Qed. + +Definition vec_new (T: Type) : vec T := (exist _ [] le_0_usize_max). + +Lemma vec_len_in_usize {T} (v: vec T) : usize_min <= vec_length v <= usize_max. +Proof. + unfold vec_length, usize_min. + split. + - lia. + - apply (proj2_sig v). +Qed. + +Definition vec_len (T: Type) (v: vec T) : usize := + exist _ (vec_length v) (vec_len_in_usize v). + +Fixpoint list_update {A} (l: list A) (n: nat) (a: A) + : list A := + match l with + | [] => [] + | x :: t => match n with + | 0%nat => a :: t + | S m => x :: (list_update t m a) +end end. + +Definition vec_bind {A B} (v: vec A) (f: list A -> result (list B)) : result (vec B) := + l <- f (vec_to_list v) ; + match sumbool_of_bool (scalar_le_max Usize (Z.of_nat (length l))) with + | left H => Return (exist _ l (scalar_le_max_valid _ _ H)) + | right _ => Fail_ + end. + +(* The **forward** function shouldn't be used *) +Definition vec_push_fwd (T: Type) (v: vec T) (x: T) : unit := tt. + +Definition vec_push_back (T: Type) (v: vec T) (x: T) : result (vec T) := + vec_bind v (fun l => Return (l ++ [x])). + +(* The **forward** function shouldn't be used *) +Definition vec_insert_fwd (T: Type) (v: vec T) (i: usize) (x: T) : result unit := + if to_Z i + if to_Z i Return n + | None => Fail_ + end. + +Definition vec_index_back (T: Type) (v: vec T) (i: usize) (x: T) : result unit := + if to_Z i Return n + | None => Fail_ + end. + +Definition vec_index_mut_back (T: Type) (v: vec T) (i: usize) (x: T) : result (vec T) := + vec_bind v (fun l => + if to_Z i Some FStar | "coq" -> Some Coq | _ -> None + +let opt_backend : backend option ref = ref None + +let set_backend (b : string) : unit = + match backend_of_string b with + | Some b -> opt_backend := Some b + | None -> + (* We shouldn't get there: the string should have been checked as + belonging to the proper set *) + raise (Failure "Unexpected") + +(** The backend to which to extract. + + We initialize it with a default value, but it always gets overwritten: + we check that the user provides a backend argument. + *) +let backend = ref FStar + (** {1 Interpreter} *) (** Check that invariants are maintained whenever we execute a statement *) @@ -44,6 +72,39 @@ let allow_bottom_below_borrow = true *) let return_unit_end_abs_with_no_loans = true +(** Forbids using field projectors for structures. + + If we don't use field projectors, whenever we symbolically expand a structure + value (note that accessing a structure field in the symbolic execution triggers + its expansion), then instead of generating code like this: + {[ + let x1 = s.f1 in + let x2 = s.f2 in + ... + ]} + + we generate code like this: + {[ + let Mkstruct x1 x2 ... = s in + ... + ]} + + We use this for instance for Coq, because in Coq we can't define groups + of mutually recursive records and inductives. In such cases, we extract + the structures as inductives, which means that field projectors are not + always available. + + TODO: we could define a notation to take care of this. + TODO: today dont_use_field_projectors is not useful actually. + *) +let dont_use_field_projectors = ref false + +(** Deconstructing ADTs which have only one variant with let-bindings is not always + supported: this parameter controls whether we use let-bindings in such situations or not. + *) + +let always_deconstruct_adts_with_matches = ref false + (** {1 Translation} *) (** Controls whether we need to use a state to model the external world @@ -103,7 +164,7 @@ let test_trans_unit_functions = ref false The body of such clauses must be defined by the user. *) -let extract_decreases_clauses = ref true +let extract_decreases_clauses = ref false (** In order to help the user, we can generate "template" decrease clauses (i.e., definitions with proper signatures but dummy bodies) in a @@ -113,10 +174,10 @@ let extract_template_decreases_clauses = ref false (** {1 Micro passes} *) -(** Some provers like F* don't support the decomposition of return values +(** Some provers like F* and Coq don't support the decomposition of return values in monadic let-bindings: {[ - // NOT supported in F* + (* NOT supported in F*/Coq *) let (x, y) <-- f (); ... ]} diff --git a/compiler/Driver.ml b/compiler/Driver.ml index 3059ec2f..5089cb8e 100644 --- a/compiler/Driver.ml +++ b/compiler/Driver.ml @@ -33,6 +33,9 @@ let () = let spec = [ + ( "-backend", + Arg.Symbol (backend_names, set_backend), + " Specify the backend to which to extract" ); ("-dest", Arg.Set_string dest_dir, " Specify the output directory"); ( "-decompose-monads", Arg.Set decompose_monadic_let_bindings, @@ -61,9 +64,9 @@ let () = Arg.Set test_trans_unit_functions, " Test the translated unit functions with the target theorem\n\ \ prover's normalizer" ); - ( "-no-decreases-clauses", - Arg.Clear extract_decreases_clauses, - " Do not add decrease clauses to the recursive definitions" ); + ( "-decreases-clauses", + Arg.Set extract_decreases_clauses, + " Use decreases clauses for the recursive definitions" ); ( "-no-state", Arg.Clear use_state, " Do not use state-error monads, simply use error monads" ); @@ -73,8 +76,8 @@ let () = ( "-template-clauses", Arg.Set extract_template_decreases_clauses, " Generate templates for the required decreases clauses, in a\n\ - \ dedicated file. Incompatible with \ - -no-decreases-clauses" ); + \ dedicated file. Reauires -decreases-clauses" + ); ( "-no-split-files", Arg.Clear split_files, " Don't split the definitions between different files for types,\n\ @@ -98,6 +101,29 @@ let () = print_string usage; exit 1 in + + (* Check that the user specified a backend *) + let _ = + match !opt_backend with + | Some b -> backend := b + | None -> + print_string "Backend not specified (use the `-backend` argument)\n"; + fail () + in + + (* In the case of Coq, we forbid using field projectors (see the comments for + [dont_use_field_projectors]). + Also, we always decompose ADT values with matches (decomposing with + let-bindings is not supported). + *) + let _ = + match !backend with + | FStar -> () + | Coq -> + dont_use_field_projectors := true; + always_deconstruct_adts_with_matches := true + in + (* Retrieve and check the filename *) let filename = match !filenames with diff --git a/compiler/Extract.ml b/compiler/Extract.ml new file mode 100644 index 00000000..f9c4d10a --- /dev/null +++ b/compiler/Extract.ml @@ -0,0 +1,2001 @@ +(** The generic extraction *) +(* Turn the whole module into a functor: it is very annoying to carry the + the formatter everywhere... +*) + +open Utils +open Pure +open PureUtils +open TranslateCore +open ExtractBase +open StringUtils +open Config +module F = Format + +(** Small helper to compute the name of an int type *) +let int_name (int_ty : integer_type) = + match int_ty with + | Isize -> "isize" + | I8 -> "i8" + | I16 -> "i16" + | I32 -> "i32" + | I64 -> "i64" + | I128 -> "i128" + | Usize -> "usize" + | U8 -> "u8" + | U16 -> "u16" + | U32 -> "u32" + | U64 -> "u64" + | U128 -> "u128" + +(** Small helper to compute the name of a unary operation *) +let unop_name (unop : unop) : string = + match unop with + | Not -> ( match !backend with FStar -> "not" | Coq -> "negb") + | Neg int_ty -> int_name int_ty ^ "_neg" + | Cast _ -> raise (Failure "Unsupported") + +(** Small helper to compute the name of a binary operation (note that many + binary operations like "less than" are extracted to primitive operations, + like [<]. + *) +let named_binop_name (binop : E.binop) (int_ty : integer_type) : string = + let binop = + match binop with + | Div -> "div" + | Rem -> "rem" + | Add -> "add" + | Sub -> "sub" + | Mul -> "mul" + | _ -> raise (Failure "Unreachable") + in + int_name int_ty ^ "_" ^ binop + +(** A list of keywords/identifiers used by the backend and with which we + want to check collision. *) +let keywords () = + let named_unops = + unop_name Not + :: List.map (fun it -> unop_name (Neg it)) T.all_signed_int_types + in + let named_binops = [ E.Div; Rem; Add; Sub; Mul ] in + let named_binops = + List.concat + (List.map + (fun bn -> List.map (fun it -> named_binop_name bn it) T.all_int_types) + named_binops) + in + let misc = + match !backend with + | FStar -> + [ + "let"; + "rec"; + "in"; + "fun"; + "fn"; + "val"; + "int"; + "nat"; + "list"; + "FStar"; + "FStar.Mul"; + "type"; + "match"; + "with"; + "assert"; + "assert_norm"; + "assume"; + "Type0"; + "Type"; + "unit"; + "not"; + "scalar_cast"; + ] + | Coq -> + [ + "Record"; + "Inductive"; + "Fixpoint"; + "Definition"; + "Arguments"; + "Notation"; + "Check"; + "Search"; + "SearchPattern"; + "Axiom"; + "Type"; + "Set"; + "let"; + "rec"; + "in"; + "unit"; + "fun"; + "type"; + "int"; + "nat"; + "match"; + "with"; + "assert"; + "not"; + (* [tt] is unit *) + "tt"; + "char_of_byte"; + ] + in + List.concat [ named_unops; named_binops; misc ] + +let assumed_adts : (assumed_ty * string) list = + [ (State, "state"); (Result, "result"); (Option, "option"); (Vec, "vec") ] + +let assumed_structs : (assumed_ty * string) list = [] + +let assumed_variants () : (assumed_ty * VariantId.id * string) list = + match !backend with + | FStar -> + [ + (Result, result_return_id, "Return"); + (Result, result_fail_id, "Fail"); + (Option, option_some_id, "Some"); + (Option, option_none_id, "None"); + ] + | Coq -> + [ + (Result, result_return_id, "Return"); + (Result, result_fail_id, "Fail_"); + (Option, option_some_id, "Some"); + (Option, option_none_id, "None"); + ] + +let assumed_llbc_functions : + (A.assumed_fun_id * T.RegionGroupId.id option * string) list = + let rg0 = Some T.RegionGroupId.zero in + [ + (Replace, None, "mem_replace_fwd"); + (Replace, rg0, "mem_replace_back"); + (VecNew, None, "vec_new"); + (VecPush, None, "vec_push_fwd") (* Shouldn't be used *); + (VecPush, rg0, "vec_push_back"); + (VecInsert, None, "vec_insert_fwd") (* Shouldn't be used *); + (VecInsert, rg0, "vec_insert_back"); + (VecLen, None, "vec_len"); + (VecIndex, None, "vec_index_fwd"); + (VecIndex, rg0, "vec_index_back") (* shouldn't be used *); + (VecIndexMut, None, "vec_index_mut_fwd"); + (VecIndexMut, rg0, "vec_index_mut_back"); + ] + +let assumed_pure_functions : (pure_assumed_fun_id * string) list = + match !backend with + | FStar -> [ (Return, "return"); (Fail, "fail"); (Assert, "massert") ] + | Coq -> [ (Return, "return_"); (Fail, "fail_"); (Assert, "massert") ] + +let names_map_init () : names_map_init = + { + keywords = keywords (); + assumed_adts; + assumed_structs; + assumed_variants = assumed_variants (); + assumed_llbc_functions; + assumed_pure_functions; + } + +let extract_unop (extract_expr : bool -> texpression -> unit) + (fmt : F.formatter) (inside : bool) (unop : unop) (arg : texpression) : unit + = + match unop with + | Not | Neg _ -> + let unop = unop_name unop in + if inside then F.pp_print_string fmt "("; + F.pp_print_string fmt unop; + F.pp_print_space fmt (); + extract_expr true arg; + if inside then F.pp_print_string fmt ")" + | Cast (src, tgt) -> + (* The source type is an implicit parameter *) + if inside then F.pp_print_string fmt "("; + F.pp_print_string fmt "scalar_cast"; + F.pp_print_space fmt (); + F.pp_print_string fmt + (StringUtils.capitalize_first_letter + (PrintPure.integer_type_to_string src)); + F.pp_print_space fmt (); + F.pp_print_string fmt + (StringUtils.capitalize_first_letter + (PrintPure.integer_type_to_string tgt)); + F.pp_print_space fmt (); + extract_expr true arg; + if inside then F.pp_print_string fmt ")" + +let extract_binop (extract_expr : bool -> texpression -> unit) + (fmt : F.formatter) (inside : bool) (binop : E.binop) + (int_ty : integer_type) (arg0 : texpression) (arg1 : texpression) : unit = + if inside then F.pp_print_string fmt "("; + (* Some binary operations have a special treatment *) + (match binop with + | Eq | Lt | Le | Ne | Ge | Gt -> + let binop = + match binop with + | Eq -> "=" + | Lt -> "<" + | Le -> "<=" + | Ne -> "<>" + | Ge -> ">=" + | Gt -> ">" + | _ -> raise (Failure "Unreachable") + in + let binop = match !backend with FStar -> binop | Coq -> "s" ^ binop in + extract_expr false arg0; + F.pp_print_space fmt (); + F.pp_print_string fmt binop; + F.pp_print_space fmt (); + extract_expr false arg1 + | Div | Rem | Add | Sub | Mul -> + let binop = named_binop_name binop int_ty in + F.pp_print_string fmt binop; + F.pp_print_space fmt (); + extract_expr false arg0; + F.pp_print_space fmt (); + extract_expr false arg1 + | BitXor | BitAnd | BitOr | Shl | Shr -> raise Unimplemented); + if inside then F.pp_print_string fmt ")" + +let type_decl_kind_to_qualif (kind : decl_kind) + (type_kind : type_decl_kind option) : string = + match !backend with + | FStar -> ( + match kind with + | SingleNonRec -> "type" + | SingleRec -> "type" + | MutRecFirst -> "type" + | MutRecInner -> "and" + | MutRecLast -> "and" + | Assumed -> "assume type" + | Declared -> "val") + | Coq -> ( + match (kind, type_kind) with + | SingleNonRec, Some Enum -> "Inductive" + | SingleNonRec, Some Struct -> "Record" + | (SingleRec | MutRecFirst), Some _ -> "Inductive" + | (MutRecInner | MutRecLast), Some _ -> + (* Coq doesn't support groups of mutually recursive definitions which mix + * records and inducties: we convert everything to records if this happens + *) + "with" + | (Assumed | Declared), None -> "Axiom" + | _ -> raise (Failure "Unexpected")) + +let fun_decl_kind_to_qualif (kind : decl_kind) = + match !backend with + | FStar -> ( + match kind with + | SingleNonRec -> "let" + | SingleRec -> "let rec" + | MutRecFirst -> "let rec" + | MutRecInner -> "and" + | MutRecLast -> "and" + | Assumed -> "assume val" + | Declared -> "val") + | Coq -> ( + match kind with + | SingleNonRec -> "Definition" + | SingleRec -> "Fixpoint" + | MutRecFirst -> "Fixpoint" + | MutRecInner -> "with" + | MutRecLast -> "with" + | Assumed -> "Axiom" + | Declared -> "Axiom") + +(** + [ctx]: we use the context to lookup type definitions, to retrieve type names. + This is used to compute variable names, when they have no basenames: in this + case we use the first letter of the type name. + + [variant_concatenate_type_name]: if true, add the type name as a prefix + to the variant names. + Ex.: + In Rust: + {[ + enum List = { + Cons(u32, Box),x + Nil, + } + ]} + + F*, if option activated: + {[ + type list = + | ListCons : u32 -> list -> list + | ListNil : list + ]} + + F*, if option not activated: + {[ + type list = + | Cons : u32 -> list -> list + | Nil : list + ]} + + Rk.: this should be true by default, because in Rust all the variant names + are actively uniquely identifier by the type name [List::Cons(...)], while + in other languages it is not necessarily the case, and thus clashes can mess + up type checking. Note that some languages actually forbids the name clashes + (it is the case of F* ). + *) +let mk_formatter (ctx : trans_ctx) (crate_name : string) + (variant_concatenate_type_name : bool) : formatter = + let int_name = int_name in + + (* Prepare a name. + * The first id elem is always the crate: if it is the local crate, + * we remove it. + * We also remove all the disambiguators, then convert everything to strings. + * **Rmk:** because we remove the disambiguators, there may be name collisions + * (which is ok, because we check for name collisions and fail if there is any). + *) + let get_name (name : name) : string list = + (* Rmk.: initially we only filtered the disambiguators equal to 0 *) + let name = Names.filter_disambiguators name in + match name with + | Ident crate :: name -> + let name = if crate = crate_name then name else Ident crate :: name in + let name = + List.map + (function + | Names.Ident s -> s + | Disambiguator d -> Names.Disambiguator.to_string d) + name + in + name + | _ -> + raise (Failure ("Unexpected name shape: " ^ Print.name_to_string name)) + in + let get_type_name = get_name in + let type_name_to_camel_case name = + let name = get_type_name name in + let name = List.map to_camel_case name in + String.concat "" name + in + let type_name_to_snake_case name = + let name = get_type_name name in + let name = List.map to_snake_case name in + let name = String.concat "_" name in + match !backend with FStar -> name | Coq -> capitalize_first_letter name + in + let type_name name = type_name_to_snake_case name ^ "_t" in + let field_name (def_name : name) (field_id : FieldId.id) + (field_name : string option) : string = + let def_name = type_name_to_snake_case def_name ^ "_" in + match field_name with + | Some field_name -> def_name ^ field_name + | None -> def_name ^ FieldId.to_string field_id + in + let variant_name (def_name : name) (variant : string) : string = + let variant = to_camel_case variant in + if variant_concatenate_type_name then + type_name_to_camel_case def_name ^ variant + else variant + in + let struct_constructor (basename : name) : string = + let tname = type_name basename in + let prefix = match !backend with FStar -> "Mk" | Coq -> "mk" in + prefix ^ tname + in + let get_fun_name = get_name in + let fun_name_to_snake_case (fname : fun_name) : string = + let fname = get_fun_name fname in + (* Converting to snake case should be a no-op, but it doesn't cost much *) + let fname = List.map to_snake_case fname in + (* Concatenate the elements *) + String.concat "_" fname + in + let global_name (name : global_name) : string = + (* Converting to snake case also lowercases the letters (in Rust, global + * names are written in capital letters). *) + let parts = List.map to_snake_case (get_name name) in + String.concat "_" parts + in + let fun_name (fname : fun_name) (num_rgs : int) + (rg : region_group_info option) (filter_info : bool * int) : string = + let fname = fun_name_to_snake_case fname in + (* Compute the suffix *) + let suffix = default_fun_suffix num_rgs rg filter_info in + (* Concatenate *) + fname ^ suffix + in + + let decreases_clause_name (_fid : A.FunDeclId.id) (fname : fun_name) : string + = + let fname = fun_name_to_snake_case fname in + (* Compute the suffix *) + let suffix = "_decreases" in + (* Concatenate *) + fname ^ suffix + in + + let var_basename (_varset : StringSet.t) (basename : string option) (ty : ty) + : string = + (* If there is a basename, we use it *) + match basename with + | Some basename -> + (* This should be a no-op *) + to_snake_case basename + | None -> ( + (* No basename: we use the first letter of the type *) + match ty with + | Adt (type_id, tys) -> ( + match type_id with + | Tuple -> + (* The "pair" case is frequent enough to have its special treatment *) + if List.length tys = 2 then "p" else "t" + | Assumed Result -> "r" + | Assumed Option -> "opt" + | Assumed Vec -> "v" + | Assumed State -> "st" + | AdtId adt_id -> + let def = + TypeDeclId.Map.find adt_id ctx.type_context.type_decls + in + (* We do the following: + * - compute the type name, and retrieve the last ident + * - convert this to snake case + * - take the first letter of every "letter group" + * Ex.: ["hashmap"; "HashMap"] ~~> "HashMap" -> "hash_map" -> "hm" + *) + (* Thename shouldn't be empty, and its last element should + * be an ident *) + let cl = List.nth def.name (List.length def.name - 1) in + let cl = to_snake_case (Names.as_ident cl) in + let cl = String.split_on_char '_' cl in + let cl = List.filter (fun s -> String.length s > 0) cl in + assert (List.length cl > 0); + let cl = List.map (fun s -> s.[0]) cl in + StringUtils.string_of_chars cl) + | TypeVar _ -> ( + (* TODO: use "t" also for F* *) + match !backend with + | FStar -> "x" (* lacking inspiration here... *) + | Coq -> "t" (* lacking inspiration here... *)) + | Bool -> "b" + | Char -> "c" + | Integer _ -> "i" + | Str -> "s" + | Arrow _ -> "f" + | Array _ | Slice _ -> raise Unimplemented) + in + let type_var_basename (_varset : StringSet.t) (basename : string) : string = + (* Rust type variables are snake-case and start with a capital letter *) + match !backend with + | FStar -> + (* This is *not* a no-op: this removes the capital letter *) + to_snake_case basename + | Coq -> basename + in + let append_index (basename : string) (i : int) : string = + basename ^ string_of_int i + in + + let extract_primitive_value (fmt : F.formatter) (inside : bool) + (cv : primitive_value) : unit = + match cv with + | Scalar sv -> ( + match !backend with + | FStar -> F.pp_print_string fmt (Z.to_string sv.PV.value) + | Coq -> + if inside then F.pp_print_string fmt "("; + (* We need to add parentheses if the value is negative *) + if sv.PV.value >= Z.of_int 0 then + F.pp_print_string fmt (Z.to_string sv.PV.value) + else F.pp_print_string fmt ("(" ^ Z.to_string sv.PV.value ^ ")"); + F.pp_print_space fmt (); + F.pp_print_string fmt ("%" ^ int_name sv.PV.int_ty); + if inside then F.pp_print_string fmt ")") + | Bool b -> + let b = if b then "true" else "false" in + F.pp_print_string fmt b + | Char c -> ( + match !backend with + | FStar -> F.pp_print_string fmt ("'" ^ String.make 1 c ^ "'") + | Coq -> + if inside then F.pp_print_string fmt "("; + F.pp_print_string fmt "char_of_byte"; + F.pp_print_space fmt (); + (* Convert the the char to ascii *) + let c = + let i = Char.code c in + let x0 = i / 16 in + let x1 = i mod 16 in + "Coq.Init.Byte.x" ^ string_of_int x0 ^ string_of_int x1 + in + F.pp_print_string fmt c; + if inside then F.pp_print_string fmt ")") + | String s -> + (* We need to replace all the line breaks *) + let s = + StringUtils.map + (fun c -> if c = '\n' then "\n" else String.make 1 c) + s + in + F.pp_print_string fmt ("\"" ^ s ^ "\"") + in + { + bool_name = "bool"; + char_name = "char"; + int_name; + str_name = "string"; + type_decl_kind_to_qualif; + fun_decl_kind_to_qualif; + field_name; + variant_name; + struct_constructor; + type_name; + global_name; + fun_name; + decreases_clause_name; + var_basename; + type_var_basename; + append_index; + extract_primitive_value; + extract_unop; + extract_binop; + } + +let mk_formatter_and_names_map (ctx : trans_ctx) (crate_name : string) + (variant_concatenate_type_name : bool) : formatter * names_map = + let fmt = mk_formatter ctx crate_name variant_concatenate_type_name in + let names_map = initialize_names_map fmt (names_map_init ()) in + (fmt, names_map) + +(** In Coq, a group of definitions must be ended with a "." *) +let print_decl_end_delimiter (fmt : F.formatter) (kind : decl_kind) = + if !backend = Coq && decl_is_last_from_group kind then ( + F.pp_print_space fmt (); + F.pp_print_string fmt ".") + +(** [inside] constrols whether we should add parentheses or not around type + applications (if [true] we add parentheses). + *) +let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) + (ty : ty) : unit = + match ty with + | Adt (type_id, tys) -> ( + match type_id with + | Tuple -> + (* This is a bit annoying, but in F*/Coq [()] is not the unit type: + * we have to write [unit]... *) + if tys = [] then F.pp_print_string fmt "unit" + else ( + F.pp_print_string fmt "("; + Collections.List.iter_link + (fun () -> + F.pp_print_space fmt (); + let product = match !backend with FStar -> "&" | Coq -> "*" in + F.pp_print_string fmt product; + F.pp_print_space fmt ()) + (extract_ty ctx fmt true) tys; + F.pp_print_string fmt ")") + | AdtId _ | Assumed _ -> + let print_paren = inside && tys <> [] in + if print_paren then F.pp_print_string fmt "("; + F.pp_print_string fmt (ctx_get_type type_id ctx); + if tys <> [] then F.pp_print_space fmt (); + Collections.List.iter_link (F.pp_print_space fmt) + (extract_ty ctx fmt true) tys; + if print_paren then F.pp_print_string fmt ")") + | TypeVar vid -> F.pp_print_string fmt (ctx_get_type_var vid ctx) + | Bool -> F.pp_print_string fmt ctx.fmt.bool_name + | Char -> F.pp_print_string fmt ctx.fmt.char_name + | Integer int_ty -> F.pp_print_string fmt (ctx.fmt.int_name int_ty) + | Str -> F.pp_print_string fmt ctx.fmt.str_name + | Arrow (arg_ty, ret_ty) -> + if inside then F.pp_print_string fmt "("; + extract_ty ctx fmt false arg_ty; + F.pp_print_space fmt (); + F.pp_print_string fmt "->"; + F.pp_print_space fmt (); + extract_ty ctx fmt false ret_ty; + if inside then F.pp_print_string fmt ")" + | Array _ | Slice _ -> raise Unimplemented + +(** Compute the names for all the top-level identifiers used in a type + definition (type name, variant names, field names, etc. but not type + parameters). + + We need to do this preemptively, beforce extracting any definition, + because of recursive definitions. + *) +let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) : + extraction_ctx = + (* Compute and register the type def name *) + let ctx = ctx_add_type_decl def ctx in + (* Compute and register: + * - the variant names, if this is an enumeration + * - the field names, if this is a structure + *) + let ctx = + match def.kind with + | Struct fields -> + (* Add the fields *) + let ctx = + fst + (ctx_add_fields def (FieldId.mapi (fun id f -> (id, f)) fields) ctx) + in + (* Add the constructor name *) + fst (ctx_add_struct def ctx) + | Enum variants -> + fst + (ctx_add_variants def + (VariantId.mapi (fun id v -> (id, v)) variants) + ctx) + | Opaque -> + (* Nothing to do *) + ctx + in + (* Return *) + ctx + +(** Print the variants *) +let extract_type_decl_variant (ctx : extraction_ctx) (fmt : F.formatter) + (type_name : string) (type_params : string list) (cons_name : string) + (fields : field list) : unit = + F.pp_print_space fmt (); + F.pp_open_hvbox fmt ctx.indent_incr; + (* variant box *) + (* [| Cons :] + * Note that we really don't want any break above so we print everything + * at once. *) + F.pp_print_string fmt ("| " ^ cons_name ^ " :"); + F.pp_print_space fmt (); + let print_field (fid : FieldId.id) (f : field) (ctx : extraction_ctx) : + extraction_ctx = + (* Open the field box *) + F.pp_open_box fmt ctx.indent_incr; + (* Print the field names + * [ x :] + * Note that when printing fields, we register the field names as + * *variables*: they don't need to be unique at the top level. *) + let ctx = + match f.field_name with + | None -> ctx + | Some field_name -> + let var_id = VarId.of_int (FieldId.to_int fid) in + let field_name = + ctx.fmt.var_basename ctx.names_map.names_set (Some field_name) + f.field_ty + in + let ctx, field_name = ctx_add_var field_name var_id ctx in + F.pp_print_string fmt (field_name ^ " :"); + F.pp_print_space fmt (); + ctx + in + (* Print the field type *) + extract_ty ctx fmt false f.field_ty; + (* Print the arrow [->]*) + F.pp_print_space fmt (); + F.pp_print_string fmt "->"; + (* Close the field box *) + F.pp_close_box fmt (); + F.pp_print_space fmt (); + (* Return *) + ctx + in + (* Print the fields *) + let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in + let _ = + List.fold_left (fun ctx (fid, f) -> print_field fid f ctx) ctx fields + in + (* Print the final type *) + F.pp_open_hovbox fmt 0; + F.pp_print_string fmt type_name; + List.iter + (fun type_param -> + F.pp_print_space fmt (); + F.pp_print_string fmt type_param) + type_params; + F.pp_close_box fmt (); + (* Close the variant box *) + F.pp_close_box fmt () + +(* TODO: we don' need the [def_name] paramter: it can be retrieved from the context *) +let extract_type_decl_enum_body (ctx : extraction_ctx) (fmt : F.formatter) + (def : type_decl) (def_name : string) (type_params : string list) + (variants : variant list) : unit = + (* We want to generate a definition which looks like this (taking F* as example): + {[ + type list a = | Cons : a -> list a -> list a | Nil : list a + ]} + + If there isn't enough space on one line: + {[ + type s = + | Cons : a -> list a -> list a + | Nil : list a + ]} + + And if we need to write the type of a variant on several lines: + {[ + type s = + | Cons : + a -> + list a -> + list a + | Nil : list a + ]} + + Finally, it is possible to give names to the variant fields in Rust. + In this situation, we generate a definition like this: + {[ + type s = + | Cons : hd:a -> tl:list a -> list a + | Nil : list a + ]} + + Note that we already printed: [type s =] + *) + let print_variant variant_id v = + let cons_name = ctx_get_variant (AdtId def.def_id) variant_id ctx in + let fields = v.fields in + extract_type_decl_variant ctx fmt def_name type_params cons_name fields + in + (* Print the variants *) + let variants = VariantId.mapi (fun vid v -> (vid, v)) variants in + List.iter (fun (vid, v) -> print_variant vid v) variants + +let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) + (kind : decl_kind) (def : type_decl) (type_params : string list) + (fields : field list) : unit = + (* We want to generate a definition which looks like this (taking F* as example): + {[ + type t = { x : int; y : bool; } + ]} + + If there isn't enough space on one line: + {[ + type t = + { + x : int; y : bool; + } + ]} + + And if there is even less space: + {[ + type t = + { + x : int; + y : bool; + } + ]} + + Also, in case there are no fields, we need to define the type as [unit] + ([type t = {}] doesn't work in F* ). + + Coq: + ==== + We need to define the constructor name upon defining the struct (record, in Coq). + The syntex is: + {[ + Record Foo = mkFoo { x : int; y : bool; }. + }] + + Also, Coq doesn't support groups of mutually recursive inductives and records. + This is fine, because we can then define records as inductives, and leverage + the fact that when record fields are accessed, the records are symbolically + expanded which introduces let bindings of the form: [let RecordCons ... = x in ...]. + As a consequence, we never use the record projectors (unless we reconstruct + them in the micro passes of course). + *) + (* Note that we already printed: [type t =] *) + let is_rec = decl_is_from_rec_group kind in + (* If Coq: print the constructor name *) + if !backend = Coq && not is_rec then ( + F.pp_print_space fmt (); + F.pp_print_string fmt (ctx_get_struct (AdtId def.def_id) ctx)); + let _ = + if !backend = FStar && fields = [] then ( + F.pp_print_space fmt (); + F.pp_print_string fmt "unit") + else if (not is_rec) || !backend = FStar then ( + F.pp_print_space fmt (); + F.pp_print_string fmt "{"; + F.pp_print_break fmt 1 ctx.indent_incr; + (* The body itself *) + F.pp_open_hvbox fmt 0; + (* Print the fields *) + let print_field (field_id : FieldId.id) (f : field) : unit = + let field_name = ctx_get_field (AdtId def.def_id) field_id ctx in + F.pp_open_box fmt ctx.indent_incr; + F.pp_print_string fmt field_name; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_ty ctx fmt false f.field_ty; + F.pp_print_string fmt ";"; + F.pp_close_box fmt () + in + let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in + Collections.List.iter_link (F.pp_print_space fmt) + (fun (fid, f) -> print_field fid f) + fields; + (* Close *) + F.pp_close_box fmt (); + F.pp_print_space fmt (); + F.pp_print_string fmt "}") + else ( + (* We extract for Coq, and we have a recursive record, or a record in + a group of mutually recursive types: we extract it as an inductive type + *) + assert (is_rec && !backend = Coq); + let cons_name = ctx_get_struct (AdtId def.def_id) ctx in + let def_name = ctx_get_local_type def.def_id ctx in + extract_type_decl_variant ctx fmt def_name type_params cons_name fields) + in + () + +(** Extract a type declaration. + + Note that all the names used for extraction should already have been + registered. + *) +let extract_type_decl (ctx : extraction_ctx) (fmt : F.formatter) + (kind : decl_kind) (def : type_decl) : unit = + let extract_body = + match kind with + | SingleNonRec | SingleRec | MutRecFirst | MutRecInner | MutRecLast -> true + | Assumed | Declared -> false + in + let type_kind = + if extract_body then + match def.kind with + | Struct _ -> Some Struct + | Enum _ -> Some Enum + | Opaque -> None + else None + in + (* If in Coq and the declaration is opaque, it must have the shape: + [Axiom Ident : forall (T0 ... Tn : Type), ... -> ... -> ...]. + + The boolean [is_opaque_coq] is used to detect this case. + *) + let is_opaque_coq = !backend = Coq && type_kind = None in + let use_forall = is_opaque_coq && def.type_params <> [] in + (* Retrieve the definition name *) + let def_name = ctx_get_local_type def.def_id ctx in + (* Add the type params - note that we need those bindings only for the + * body translation (they are not top-level) *) + let ctx_body, type_params = ctx_add_type_params def.type_params ctx in + (* Add a break before *) + F.pp_print_break fmt 0 0; + (* Print a comment to link the extracted type to its original rust definition *) + F.pp_print_string fmt ("(** [" ^ Print.name_to_string def.name ^ "] *)"); + F.pp_print_space fmt (); + (* Open a box for the definition, so that whenever possible it gets printed on + * one line *) + F.pp_open_hvbox fmt 0; + (* Open a box for "type TYPE_NAME (TYPE_PARAMS) =" *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* > "type TYPE_NAME" *) + let qualif = ctx.fmt.type_decl_kind_to_qualif kind type_kind in + F.pp_print_string fmt (qualif ^ " " ^ def_name); + (* Print the type parameters *) + let type_keyword = match !backend with FStar -> "Type0" | Coq -> "Type" in + if def.type_params <> [] then ( + if use_forall then ( + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + F.pp_print_string fmt "forall"); + F.pp_print_space fmt (); + F.pp_print_string fmt "("; + List.iter + (fun (p : type_var) -> + let pname = ctx_get_type_var p.index ctx_body in + F.pp_print_string fmt pname; + F.pp_print_space fmt ()) + def.type_params; + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + F.pp_print_string fmt (type_keyword ^ ")")); + (* Print the "=" if we extract the body*) + if extract_body then ( + F.pp_print_space fmt (); + let eq = match !backend with FStar -> "=" | Coq -> ":=" in + F.pp_print_string fmt eq) + else ( + (* Otherwise print ": Type0" *) + if use_forall then F.pp_print_string fmt "," + else ( + F.pp_print_space fmt (); + F.pp_print_string fmt ":"); + F.pp_print_space fmt (); + F.pp_print_string fmt type_keyword); + (* Close the box for "type TYPE_NAME (TYPE_PARAMS) =" *) + F.pp_close_box fmt (); + (if extract_body then + match def.kind with + | Struct fields -> + extract_type_decl_struct_body ctx_body fmt kind def type_params fields + | Enum variants -> + extract_type_decl_enum_body ctx_body fmt def def_name type_params + variants + | Opaque -> raise (Failure "Unreachable")); + (* If Coq: end the definition with a "." *) + print_decl_end_delimiter fmt kind; + (* Close the box for the definition *) + F.pp_close_box fmt (); + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0 + +(** Extract extra information for a type (e.g., [Arguments] information in Coq). + + Note that all the names used for extraction should already have been + registered. + *) +let extract_type_decl_extra_info (ctx : extraction_ctx) (fmt : F.formatter) + (kind : decl_kind) (decl : type_decl) : unit = + match !backend with + | FStar -> () + | Coq -> ( + (* Add the type params - note that we need those bindings only for the + * body translation (they are not top-level) *) + let _ctx_body, type_params = ctx_add_type_params decl.type_params ctx in + (* Auxiliary function to extract an [Arguments Cons {T} _ _.] instruction *) + let extract_arguments_info (cons_name : string) (fields : 'a list) : unit + = + (* Add a break before *) + F.pp_print_break fmt 0 0; + (* Open a box *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Small utility *) + let print_type_vars () = + List.iter + (fun (var : string) -> + F.pp_print_string fmt ("{" ^ var ^ "}"); + F.pp_print_space fmt ()) + type_params + in + let print_fields () = + List.iter + (fun _ -> + F.pp_print_string fmt "_"; + F.pp_print_space fmt ()) + fields + in + F.pp_print_break fmt 0 0; + F.pp_print_string fmt "Arguments"; + F.pp_print_space fmt (); + F.pp_print_string fmt cons_name; + F.pp_print_space fmt (); + print_type_vars (); + print_fields (); + F.pp_print_space fmt (); + F.pp_print_string fmt "."; + + (* Close the box *) + F.pp_close_box fmt () + in + + (* Generate the [Arguments] instruction *) + match decl.kind with + | Opaque -> () + | Struct fields -> + let adt_id = AdtId decl.def_id in + (* Generate the instruction for the record constructor *) + let cons_name = ctx_get_struct adt_id ctx in + extract_arguments_info cons_name fields; + (* Generate the instruction for the record projectors, if there are *) + let is_rec = decl_is_from_rec_group kind in + if not is_rec then + FieldId.iteri + (fun fid _ -> + let cons_name = ctx_get_field adt_id fid ctx in + extract_arguments_info cons_name []) + fields; + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0 + | Enum variants -> + (* Generate the instructions *) + VariantId.iteri + (fun vid (v : variant) -> + let cons_name = ctx_get_variant (AdtId decl.def_id) vid ctx in + extract_arguments_info cons_name v.fields) + variants; + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0) + +(** Extract the state type declaration. *) +let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx) + (kind : decl_kind) : unit = + (* Add a break before *) + F.pp_print_break fmt 0 0; + (* Print a comment *) + F.pp_print_string fmt "(** The state type used in the state-error monad *)"; + F.pp_print_space fmt (); + (* Open a box for the definition, so that whenever possible it gets printed on + * one line *) + F.pp_open_hvbox fmt 0; + (* Retrieve the name *) + let state_name = ctx_get_assumed_type State ctx in + (* The kind should be [Assumed] or [Declared] *) + (match kind with + | SingleNonRec | SingleRec | MutRecFirst | MutRecInner | MutRecLast -> + raise (Failure "Unexpected") + | Assumed -> ( + match !backend with + | FStar -> + F.pp_print_string fmt "assume"; + F.pp_print_space fmt (); + F.pp_print_string fmt "type"; + F.pp_print_space fmt (); + F.pp_print_string fmt state_name; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + F.pp_print_string fmt "Type0" + | Coq -> + F.pp_print_string fmt "Axiom"; + F.pp_print_space fmt (); + F.pp_print_string fmt state_name; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + F.pp_print_string fmt "Type.") + | Declared -> ( + match !backend with + | FStar -> + F.pp_print_string fmt "val"; + F.pp_print_space fmt (); + F.pp_print_string fmt state_name; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + F.pp_print_string fmt "Type0" + | Coq -> + F.pp_print_string fmt "Axiom"; + F.pp_print_space fmt (); + F.pp_print_string fmt state_name; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + F.pp_print_string fmt "Type.")); + (* Close the box for the definition *) + F.pp_close_box fmt (); + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0 + +(** Compute the names for all the pure functions generated from a rust function + (forward function and backward functions). + *) +let extract_fun_decl_register_names (ctx : extraction_ctx) (keep_fwd : bool) + (has_decreases_clause : bool) (def : pure_fun_translation) : extraction_ctx + = + let fwd, back_ls = def in + (* Register the decrease clause, if necessary *) + let ctx = + if has_decreases_clause then ctx_add_decrases_clause fwd ctx else ctx + in + (* Register the forward function name *) + let ctx = ctx_add_fun_decl (keep_fwd, def) fwd ctx in + (* Register the backward functions' names *) + let ctx = + List.fold_left + (fun ctx back -> ctx_add_fun_decl (keep_fwd, def) back ctx) + ctx back_ls + in + (* Return *) + ctx + +(** Simply add the global name to the context. *) +let extract_global_decl_register_names (ctx : extraction_ctx) + (def : A.global_decl) : extraction_ctx = + ctx_add_global_decl_and_body def ctx + +(** The following function factorizes the extraction of ADT values. + + Note that patterns can introduce new variables: we thus return an extraction + context updated with new bindings. + + TODO: we don't need something very generic anymore (some definitions used + to be polymorphic). + *) +let extract_adt_g_value + (extract_value : extraction_ctx -> bool -> 'v -> extraction_ctx) + (fmt : F.formatter) (ctx : extraction_ctx) (inside : bool) + (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) : + extraction_ctx = + match ty with + | Adt (Tuple, _) -> + (* Tuple *) + (* This is very annoying: in Coq, we can't write [()] for the value of + type [unit], we have to write [tt]. *) + if !backend = Coq && field_values = [] then ( + F.pp_print_string fmt "tt"; + ctx) + else ( + F.pp_print_string fmt "("; + let ctx = + Collections.List.fold_left_link + (fun () -> + F.pp_print_string fmt ","; + F.pp_print_space fmt ()) + (fun ctx v -> extract_value ctx false v) + ctx field_values + in + F.pp_print_string fmt ")"; + ctx) + | Adt (adt_id, _) -> + (* "Regular" ADT *) + (* We print something of the form: [Cons field0 ... fieldn]. + * We could update the code to print something of the form: + * [{ field0=...; ...; fieldn=...; }] in case of structures. + *) + let cons = + match variant_id with + | Some vid -> ctx_get_variant adt_id vid ctx + | None -> ctx_get_struct adt_id ctx + in + if inside && field_values <> [] then F.pp_print_string fmt "("; + F.pp_print_string fmt cons; + let ctx = + Collections.List.fold_left + (fun ctx v -> + F.pp_print_space fmt (); + extract_value ctx true v) + ctx field_values + in + if inside && field_values <> [] then F.pp_print_string fmt ")"; + ctx + | _ -> raise (Failure "Inconsistent typed value") + +(* Extract globals in the same way as variables *) +let extract_global (ctx : extraction_ctx) (fmt : F.formatter) + (id : A.GlobalDeclId.id) : unit = + F.pp_print_string fmt (ctx_get_global id ctx) + +(** [inside]: see {!extract_ty}. + + As a pattern can introduce new variables, we return an extraction context + updated with new bindings. + *) +let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter) + (inside : bool) (v : typed_pattern) : extraction_ctx = + match v.value with + | PatConstant cv -> + ctx.fmt.extract_primitive_value fmt inside cv; + ctx + | PatVar (v, _) -> + let vname = + ctx.fmt.var_basename ctx.names_map.names_set v.basename v.ty + in + let ctx, vname = ctx_add_var vname v.id ctx in + F.pp_print_string fmt vname; + ctx + | PatDummy -> + F.pp_print_string fmt "_"; + ctx + | PatAdt av -> + let extract_value ctx inside v = extract_typed_pattern ctx fmt inside v in + extract_adt_g_value extract_value fmt ctx inside av.variant_id + av.field_values v.ty + +(** [inside]: controls the introduction of parentheses. See [extract_ty] + + TODO: replace the formatting boolean [inside] with something more general? + Also, it seems we don't really use it... + Cases to consider: + - right-expression in a let: [let x = re in _] (never parentheses?) + - next expression in a let: [let x = _ in next_e] (never parentheses?) + - application argument: [f (exp)] + - match/if scrutinee: [if exp then _ else _]/[match exp | _ -> _] + *) +let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) + (inside : bool) (e : texpression) : unit = + match e.e with + | Var var_id -> + let var_name = ctx_get_var var_id ctx in + F.pp_print_string fmt var_name + | Const cv -> ctx.fmt.extract_primitive_value fmt inside cv + | App _ -> + let app, args = destruct_apps e in + extract_App ctx fmt inside app args + | Abs _ -> + let xl, e = destruct_abs_list e in + extract_Abs ctx fmt inside xl e + | Qualif _ -> + (* We use the app case *) + extract_App ctx fmt inside e [] + | Let (monadic, lv, re, next_e) -> + extract_Let ctx fmt inside monadic lv re next_e + | Switch (scrut, body) -> extract_Switch ctx fmt inside scrut body + | Meta (_, e) -> extract_texpression ctx fmt inside e + +(* Extract an application *or* a top-level qualif (function extraction has + * to handle top-level qualifiers, so it seemed more natural to merge the + * two cases) *) +and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) + (app : texpression) (args : texpression list) : unit = + (* We don't do the same thing if the app is a top-level identifier (function, + * ADT constructor...) or a "regular" expression *) + match app.e with + | Qualif qualif -> ( + (* Top-level qualifier *) + match qualif.id with + | FunOrOp fun_id -> + extract_function_call ctx fmt inside fun_id qualif.type_args args + | Global global_id -> extract_global ctx fmt global_id + | AdtCons adt_cons_id -> + extract_adt_cons ctx fmt inside adt_cons_id qualif.type_args args + | Proj proj -> + extract_field_projector ctx fmt inside app proj qualif.type_args args) + | _ -> + (* "Regular" expression *) + (* Open parentheses *) + if inside then F.pp_print_string fmt "("; + (* Open a box for the application *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Print the app expression *) + let app_inside = (inside && args = []) || args <> [] in + extract_texpression ctx fmt app_inside app; + (* Print the arguments *) + List.iter + (fun ve -> + F.pp_print_space fmt (); + extract_texpression ctx fmt true ve) + args; + (* Close the box for the application *) + F.pp_close_box fmt (); + (* Close parentheses *) + if inside then F.pp_print_string fmt ")" + +(** Subcase of the app case: function call *) +and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) + (inside : bool) (fid : fun_or_op_id) (type_args : ty list) + (args : texpression list) : unit = + match (fid, args) with + | Unop unop, [ arg ] -> + (* A unop can have *at most* one argument (the result can't be a function!). + * Note that the way we generate the translation, we shouldn't get the + * case where we have no argument (all functions are fully instantiated, + * and no AST transformation introduces partial calls). *) + ctx.fmt.extract_unop (extract_texpression ctx fmt) fmt inside unop arg + | Binop (binop, int_ty), [ arg0; arg1 ] -> + (* Number of arguments: similar to unop *) + ctx.fmt.extract_binop + (extract_texpression ctx fmt) + fmt inside binop int_ty arg0 arg1 + | Fun fun_id, _ -> + if inside then F.pp_print_string fmt "("; + (* Open a box for the function call *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Print the function name *) + let fun_name = ctx_get_function fun_id ctx in + F.pp_print_string fmt fun_name; + (* Print the type parameters *) + List.iter + (fun ty -> + F.pp_print_space fmt (); + extract_ty ctx fmt true ty) + type_args; + (* Print the arguments *) + List.iter + (fun ve -> + F.pp_print_space fmt (); + extract_texpression ctx fmt true ve) + args; + (* Close the box for the function call *) + F.pp_close_box fmt (); + (* Return *) + if inside then F.pp_print_string fmt ")" + | (Unop _ | Binop _), _ -> + raise + (Failure + ("Unreachable:\n" ^ "Function: " ^ show_fun_or_op_id fid + ^ ",\nNumber of arguments: " + ^ string_of_int (List.length args) + ^ ",\nArguments: " + ^ String.concat " " (List.map show_texpression args))) + +(** Subcase of the app case: ADT constructor *) +and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) + (adt_cons : adt_cons_id) (type_args : ty list) (args : texpression list) : + unit = + match adt_cons.adt_id with + | Tuple -> + (* Tuple *) + (* For now, we only support fully applied tuple constructors *) + (* This is very annoying: in Coq, we can't write [()] for the value of + type [unit], we have to write [tt]. *) + assert (List.length type_args = List.length args); + if !backend = Coq && args = [] then F.pp_print_string fmt "tt" + else ( + F.pp_print_string fmt "("; + Collections.List.iter_link + (fun () -> + F.pp_print_string fmt ","; + F.pp_print_space fmt ()) + (fun v -> extract_texpression ctx fmt false v) + args; + F.pp_print_string fmt ")") + | _ -> + (* "Regular" ADT *) + (* We print something of the form: [Cons field0 ... fieldn]. + * We could update the code to print something of the form: + * [{ field0=...; ...; fieldn=...; }] in case of fully + * applied structure constructors. + *) + let cons = + match adt_cons.variant_id with + | Some vid -> ctx_get_variant adt_cons.adt_id vid ctx + | None -> ctx_get_struct adt_cons.adt_id ctx + in + let use_parentheses = inside && args <> [] in + if use_parentheses then F.pp_print_string fmt "("; + F.pp_print_string fmt cons; + Collections.List.iter + (fun v -> + F.pp_print_space fmt (); + extract_texpression ctx fmt true v) + args; + if use_parentheses then F.pp_print_string fmt ")" + +(** Subcase of the app case: ADT field projector. *) +and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter) + (inside : bool) (original_app : texpression) (proj : projection) + (_proj_type_params : ty list) (args : texpression list) : unit = + (* We isolate the first argument (if there is), in order to pretty print the + * projection ([x.field] instead of [MkAdt?.field x] *) + match args with + | [ arg ] -> + (* Exactly one argument: pretty-print *) + let field_name = ctx_get_field proj.adt_id proj.field_id ctx in + (* Open a box *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Extract the expression *) + extract_texpression ctx fmt true arg; + (* We allow to break where the "." appears *) + F.pp_print_break fmt 0 0; + F.pp_print_string fmt "."; + (* If in Coq, the field projection has to be parenthesized *) + (match !backend with + | FStar -> F.pp_print_string fmt field_name + | Coq -> F.pp_print_string fmt ("(" ^ field_name ^ ")")); + (* Close the box *) + F.pp_close_box fmt () + | arg :: args -> + (* Call extract_App again, but in such a way that the first argument is + * isolated *) + extract_App ctx fmt inside (mk_app original_app arg) args + | [] -> + (* No argument: shouldn't happen *) + raise (Failure "Unreachable") + +and extract_Abs (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) + (xl : typed_pattern list) (e : texpression) : unit = + (* Open a box for the abs expression *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Open parentheses *) + if inside then F.pp_print_string fmt "("; + (* Print the lambda - note that there should always be at least one variable *) + assert (xl <> []); + F.pp_print_string fmt "fun"; + let ctx = + List.fold_left + (fun ctx x -> + F.pp_print_space fmt (); + extract_typed_pattern ctx fmt true x) + ctx xl + in + F.pp_print_space fmt (); + F.pp_print_string fmt "->"; + F.pp_print_space fmt (); + (* Print the body *) + extract_texpression ctx fmt false e; + (* Close parentheses *) + if inside then F.pp_print_string fmt ")"; + (* Close the box for the abs expression *) + F.pp_close_box fmt () + +and extract_Let (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) + (monadic : bool) (lv : typed_pattern) (re : texpression) + (next_e : texpression) : unit = + (* Open a box for the whole expression *) + F.pp_open_hvbox fmt 0; + (* Open parentheses *) + if inside then F.pp_print_string fmt "("; + (* Open a box for the let-binding *) + F.pp_open_hovbox fmt ctx.indent_incr; + let ctx = + if monadic then ( + (* Note that in F*, the left value of a monadic let-binding can only be + * a variable *) + let ctx = extract_typed_pattern ctx fmt true lv in + F.pp_print_space fmt (); + let arrow = match !backend with FStar -> "<--" | Coq -> "<-" in + F.pp_print_string fmt arrow; + F.pp_print_space fmt (); + extract_texpression ctx fmt false re; + F.pp_print_string fmt ";"; + ctx) + else ( + F.pp_print_string fmt "let"; + F.pp_print_space fmt (); + let ctx = extract_typed_pattern ctx fmt true lv in + F.pp_print_space fmt (); + let eq = match !backend with FStar -> "=" | Coq -> ":=" in + F.pp_print_string fmt eq; + F.pp_print_space fmt (); + extract_texpression ctx fmt false re; + F.pp_print_space fmt (); + F.pp_print_string fmt "in"; + ctx) + in + (* Close the box for the let-binding *) + F.pp_close_box fmt (); + (* Print the next expression *) + F.pp_print_space fmt (); + extract_texpression ctx fmt false next_e; + (* Close parentheses *) + if inside then F.pp_print_string fmt ")"; + (* Close the box for the whole expression *) + F.pp_close_box fmt () + +and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) + (scrut : texpression) (body : switch_body) : unit = + (* Open a box for the whole expression *) + F.pp_open_hvbox fmt 0; + (* Open parentheses *) + if inside then F.pp_print_string fmt "("; + (* Extract the switch *) + (match body with + | If (e_then, e_else) -> + (* Open a box for the [if] *) + F.pp_open_hovbox fmt ctx.indent_incr; + F.pp_print_string fmt "if"; + F.pp_print_space fmt (); + let scrut_inside = PureUtils.let_group_requires_parentheses scrut in + extract_texpression ctx fmt scrut_inside scrut; + (* Close the box for the [if] *) + F.pp_close_box fmt (); + (* Extract the branches *) + let extract_branch (is_then : bool) (e_branch : texpression) : unit = + F.pp_print_space fmt (); + (* Open a box for the then/else+branch *) + F.pp_open_hovbox fmt ctx.indent_incr; + let then_or_else = if is_then then "then" else "else" in + F.pp_print_string fmt then_or_else; + F.pp_print_space fmt (); + (* Open a box for the branch *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Print the [begin] if necessary *) + let parenth = PureUtils.let_group_requires_parentheses e_branch in + let left_delim, right_delim = + match !backend with FStar -> ("begin", "end") | Coq -> ("(", ")") + in + if parenth then ( + F.pp_print_string fmt left_delim; + F.pp_print_space fmt ()); + (* Print the branch expression *) + extract_texpression ctx fmt false e_branch; + (* Close the [begin ... end ] *) + if parenth then ( + F.pp_print_space fmt (); + F.pp_print_string fmt right_delim); + (* Close the box for the branch *) + F.pp_close_box fmt (); + (* Close the box for the then/else+branch *) + F.pp_close_box fmt () + in + + extract_branch true e_then; + extract_branch false e_else + | Match branches -> + (* Open a box for the [match ... with] *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Print the [match ... with] *) + let match_begin = + match !backend with FStar -> "begin match" | Coq -> "match" + in + F.pp_print_string fmt match_begin; + F.pp_print_space fmt (); + let scrut_inside = PureUtils.let_group_requires_parentheses scrut in + extract_texpression ctx fmt scrut_inside scrut; + F.pp_print_space fmt (); + F.pp_print_string fmt "with"; + (* Close the box for the [match ... with] *) + F.pp_close_box fmt (); + + (* Extract the branches *) + let extract_branch (br : match_branch) : unit = + F.pp_print_space fmt (); + (* Open a box for the pattern+branch *) + F.pp_open_hovbox fmt ctx.indent_incr; + F.pp_print_string fmt "|"; + (* Print the pattern *) + F.pp_print_space fmt (); + let ctx = extract_typed_pattern ctx fmt false br.pat in + F.pp_print_space fmt (); + let arrow = match !backend with FStar -> "->" | Coq -> "=>" in + F.pp_print_string fmt arrow; + F.pp_print_space fmt (); + (* Open a box for the branch *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Print the branch itself *) + extract_texpression ctx fmt false br.branch; + (* Close the box for the branch *) + F.pp_close_box fmt (); + (* Close the box for the pattern+branch *) + F.pp_close_box fmt () + in + + List.iter extract_branch branches; + + (* End the match *) + F.pp_print_space fmt (); + F.pp_print_string fmt "end"); + (* Close parentheses *) + if inside then F.pp_print_string fmt ")"; + (* Close the box for the whole expression *) + F.pp_close_box fmt () + +(** A small utility to print the parameters of a function signature. + + We return two contexts: + - the context augmented with bindings for the type parameters + - the previous context augmented with bindings for the input values + *) +let extract_fun_parameters (ctx : extraction_ctx) (fmt : F.formatter) + (def : fun_decl) : extraction_ctx * extraction_ctx = + (* Add the type parameters - note that we need those bindings only for the + * body translation (they are not top-level) *) + let ctx, _ = ctx_add_type_params def.signature.type_params ctx in + (* Print the parameters - rk.: we should have filtered the functions + * with no input parameters *) + (* The type parameters *) + if def.signature.type_params <> [] then ( + (* Open a box for the type parameters *) + F.pp_open_hovbox fmt 0; + F.pp_print_string fmt "("; + List.iter + (fun (p : type_var) -> + let pname = ctx_get_type_var p.index ctx in + F.pp_print_string fmt pname; + F.pp_print_space fmt ()) + def.signature.type_params; + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + let type_keyword = match !backend with FStar -> "Type0" | Coq -> "Type" in + F.pp_print_string fmt (type_keyword ^ ")"); + (* Close the box for the type parameters *) + F.pp_close_box fmt (); + F.pp_print_space fmt ()); + (* The input parameters - note that doing this adds bindings to the context *) + let ctx_body = + match def.body with + | None -> ctx + | Some body -> + List.fold_left + (fun ctx (lv : typed_pattern) -> + (* Open a box for the input parameter *) + F.pp_open_hovbox fmt 0; + F.pp_print_string fmt "("; + let ctx = extract_typed_pattern ctx fmt false lv in + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_ty ctx fmt false lv.ty; + F.pp_print_string fmt ")"; + (* Close the box for the input parameters *) + F.pp_close_box fmt (); + F.pp_print_space fmt (); + ctx) + ctx body.inputs_lvs + in + (ctx, ctx_body) + +(** A small utility to print the types of the input parameters in the form: + [u32 -> list u32 -> ...] + (we don't print the return type of the function) + + This is used for opaque function declarations, in particular. + *) +let extract_fun_input_parameters_types (ctx : extraction_ctx) + (fmt : F.formatter) (def : fun_decl) : unit = + let extract_param (ty : ty) : unit = + let inside = false in + extract_ty ctx fmt inside ty; + F.pp_print_space fmt (); + F.pp_print_string fmt "->"; + F.pp_print_space fmt () + in + List.iter extract_param def.signature.inputs + +(** Extract a decrease clause function template body. + + Only for F*. + + In order to help the user, we can generate a template for the functions + required by the decreases clauses for. We simply generate definitions of + the following form in a separate file: + {[ + let f_decrease (t : Type0) (x : t) : nat = admit() + ]} + + Where the translated functions for [f] look like this: + {[ + let f_fwd (t : Type0) (x : t) : Tot ... (decreases (f_decrease t x)) = ... + ]} + *) +let extract_template_decreases_clause (ctx : extraction_ctx) (fmt : F.formatter) + (def : fun_decl) : unit = + assert (!backend = FStar); + (* Retrieve the function name *) + let def_name = ctx_get_decreases_clause def.def_id ctx in + (* Add a break before *) + F.pp_print_break fmt 0 0; + (* Print a comment to link the extracted type to its original rust definition *) + F.pp_print_string fmt + ("(** [" ^ Print.fun_name_to_string def.basename ^ "]: decreases clause *)"); + F.pp_print_space fmt (); + (* Open a box for the definition, so that whenever possible it gets printed on + * one line *) + F.pp_open_hvbox fmt 0; + (* Add the [unfold] keyword *) + F.pp_print_string fmt "unfold"; + F.pp_print_space fmt (); + (* Open a box for "let FUN_NAME (PARAMS) : EFFECT = admit()" *) + F.pp_open_hvbox fmt ctx.indent_incr; + (* Open a box for "let FUN_NAME (PARAMS) : EFFECT =" *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* > "let FUN_NAME" *) + F.pp_print_string fmt ("let " ^ def_name); + F.pp_print_space fmt (); + (* Extract the parameters *) + let _, _ = extract_fun_parameters ctx fmt def in + F.pp_print_string fmt ":"; + (* Print the signature *) + F.pp_print_space fmt (); + F.pp_print_string fmt "nat"; + (* Print the "=" *) + F.pp_print_space fmt (); + F.pp_print_string fmt "="; + (* Close the box for "let FUN_NAME (PARAMS) : EFFECT =" *) + F.pp_close_box fmt (); + F.pp_print_space fmt (); + (* Print the "admit ()" *) + F.pp_print_string fmt "admit ()"; + (* Close the box for "let FUN_NAME (PARAMS) : EFFECT = admit()" *) + F.pp_close_box fmt (); + (* Close the box for the whole definition *) + F.pp_close_box fmt (); + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0 + +(** Extract a function declaration. + + Note that all the names used for extraction should already have been + registered. + + We take the definition of the forward translation as parameter (which is + equal to the definition to extract, if we extract a forward function) because + it is useful for the decrease clause. + *) +let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) + (kind : decl_kind) (has_decreases_clause : bool) (def : fun_decl) : unit = + assert (not def.is_global_decl_body); + (* Retrieve the function name *) + let def_name = ctx_get_local_function def.def_id def.back_id ctx in + (* Add a break before *) + F.pp_print_break fmt 0 0; + (* Print a comment to link the extracted type to its original rust definition *) + F.pp_print_string fmt + ("(** [" ^ Print.fun_name_to_string def.basename ^ "] *)"); + F.pp_print_space fmt (); + (* Open a box for the definition, so that whenever possible it gets printed on + * one line *) + F.pp_open_hvbox fmt ctx.indent_incr; + (* Open a box for "let FUN_NAME (PARAMS) : EFFECT =" *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* > "let FUN_NAME" *) + let is_opaque = Option.is_none def.body in + (* If in Coq and the declaration is opaque, it must have the shape: + [Axiom Ident : forall (T0 ... Tn : Type), ... -> ... -> ...]. + + The boolean [is_opaque_coq] is used to detect this case. + *) + let is_opaque_coq = !backend = Coq && is_opaque in + let use_forall = is_opaque_coq && def.signature.type_params <> [] in + (* *) + let qualif = ctx.fmt.fun_decl_kind_to_qualif kind in + F.pp_print_string fmt (qualif ^ " " ^ def_name); + F.pp_print_space fmt (); + if use_forall then ( + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + F.pp_print_string fmt "forall"); + (* Open a box for "(PARAMS) : EFFECT =" *) + F.pp_open_hvbox fmt 0; + (* Open a box for "(PARAMS)" *) + F.pp_open_hovbox fmt 0; + let ctx, ctx_body = extract_fun_parameters ctx fmt def in + (* Close the box for "(PARAMS)" *) + F.pp_close_box fmt (); + (* Print the return type - note that we have to be careful when + * printing the input values for the decrease clause, because + * it introduces bindings in the context... We thus "forget" + * the bindings we introduced above. + * TODO: figure out a cleaner way *) + let _ = + if use_forall then F.pp_print_string fmt "," else F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + (* Open a box for the EFFECT *) + F.pp_open_hvbox fmt 0; + (* Open a box for the return type *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Print the return type *) + (* For opaque definitions, as we don't have named parameters under the hand, + * we don't print parameters in the form [(x : a) (y : b) ...] above, + * but wait until here to print the types: [a -> b -> ...]. *) + if is_opaque then extract_fun_input_parameters_types ctx fmt def; + (* [Tot] *) + if has_decreases_clause then ( + assert (!backend = FStar); + F.pp_print_string fmt "Tot"; + F.pp_print_space fmt ()); + extract_ty ctx fmt has_decreases_clause def.signature.output; + (* Close the box for the return type *) + F.pp_close_box fmt (); + (* Print the decrease clause - rk.: a function with a decreases clause + * is necessarily a transparent function *) + if has_decreases_clause then ( + assert (!backend = FStar); + F.pp_print_space fmt (); + (* Open a box for the decrease clause *) + F.pp_open_hovbox fmt 0; + (* *) + F.pp_print_string fmt "(decreases"; + F.pp_print_space fmt (); + F.pp_print_string fmt "("; + (* The name of the decrease clause *) + let decr_name = ctx_get_decreases_clause def.def_id ctx in + F.pp_print_string fmt decr_name; + (* Print the type parameters *) + List.iter + (fun (p : type_var) -> + let pname = ctx_get_type_var p.index ctx in + F.pp_print_space fmt (); + F.pp_print_string fmt pname) + def.signature.type_params; + (* Print the input values: we have to be careful here to print + * only the input values which are in common with the *forward* + * function (the additional input values "given back" to the + * backward functions have no influence on termination: we thus + * share the decrease clauses between the forward and the backward + * functions - we also ignore the additional state received by the + * backward function, if there is one). + *) + let inputs_lvs = + let all_inputs = (Option.get def.body).inputs_lvs in + let num_fwd_inputs = def.signature.info.num_fwd_inputs_with_state in + Collections.List.prefix num_fwd_inputs all_inputs + in + let _ = + List.fold_left + (fun ctx (lv : typed_pattern) -> + F.pp_print_space fmt (); + let ctx = extract_typed_pattern ctx fmt false lv in + ctx) + ctx inputs_lvs + in + F.pp_print_string fmt "))"; + (* Close the box for the decrease clause *) + F.pp_close_box fmt ()); + (* Close the box for the EFFECT *) + F.pp_close_box fmt () + in + (* Print the "=" *) + if not is_opaque then ( + F.pp_print_space fmt (); + let eq = match !backend with FStar -> "=" | Coq -> ":=" in + F.pp_print_string fmt eq); + (* Close the box for "(PARAMS) : EFFECT =" *) + F.pp_close_box fmt (); + (* Close the box for "let FUN_NAME (PARAMS) : EFFECT =" *) + F.pp_close_box fmt (); + if not is_opaque then ( + F.pp_print_space fmt (); + (* Open a box for the body *) + F.pp_open_hvbox fmt 0; + (* Extract the body *) + let _ = extract_texpression ctx_body fmt false (Option.get def.body).body in + (* Coq: add a "." *) + print_decl_end_delimiter fmt kind; + (* Close the box for the body *) + F.pp_close_box fmt ()); + (* Coq: add a "." *) + if is_opaque_coq then print_decl_end_delimiter fmt kind; + (* Close the box for the definition *) + F.pp_close_box fmt (); + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0 + +(** Extract a global declaration body of the shape "QUALIF NAME : TYPE = BODY" + with a custom body extractor + *) +let extract_global_decl_body (ctx : extraction_ctx) (fmt : F.formatter) + (kind : decl_kind) (name : string) (ty : ty) + (extract_body : (F.formatter -> unit) Option.t) : unit = + let is_opaque = Option.is_none extract_body in + + (* Open the definition box (depth=0) *) + F.pp_open_hvbox fmt ctx.indent_incr; + + (* Open "QUALIF NAME : TYPE =" box (depth=1) *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Print "QUALIF NAME " *) + F.pp_print_string fmt (ctx.fmt.fun_decl_kind_to_qualif kind); + F.pp_print_space fmt (); + F.pp_print_string fmt name; + F.pp_print_space fmt (); + + (* Open ": TYPE =" box (depth=2) *) + F.pp_open_hvbox fmt 0; + (* Print ": " *) + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + + (* Open "TYPE" box (depth=3) *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Print "TYPE" *) + extract_ty ctx fmt false ty; + (* Close "TYPE" box (depth=3) *) + F.pp_close_box fmt (); + + if not is_opaque then ( + (* Print " =" *) + F.pp_print_space fmt (); + let eq = match !backend with FStar -> "=" | Coq -> ":=" in + F.pp_print_string fmt eq); + (* Close ": TYPE =" box (depth=2) *) + F.pp_close_box fmt (); + (* Close "QUALIF NAME : TYPE =" box (depth=1) *) + F.pp_close_box fmt (); + + if not is_opaque then ( + F.pp_print_space fmt (); + (* Open "BODY" box (depth=1) *) + F.pp_open_hvbox fmt 0; + (* Print "BODY" *) + (Option.get extract_body) fmt; + (* Close "BODY" box (depth=1) *) + F.pp_close_box fmt ()); + + (* Coq: add a "." *) + print_decl_end_delimiter fmt Declared; + + (* Close the definition box (depth=0) *) + F.pp_close_box fmt () + +(** Extract a global declaration. + + We generate the body which computes the global value separately from the + value declaration itself. + + For example in Rust, + [static X: u32 = 3;] + + will be translated to the following F*: + [let x_body : result u32 = Return 3] + [let x_c : u32 = eval_global x_body] + *) +let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) + (global : A.global_decl) (body : fun_decl) (interface : bool) : unit = + assert body.is_global_decl_body; + assert (Option.is_none body.back_id); + assert (List.length body.signature.inputs = 0); + assert (List.length body.signature.doutputs = 1); + assert (List.length body.signature.type_params = 0); + + (* Add a break then the name of the corresponding LLBC declaration *) + F.pp_print_break fmt 0 0; + F.pp_print_string fmt + ("(** [" ^ Print.global_name_to_string global.name ^ "] *)"); + F.pp_print_space fmt (); + + let decl_name = ctx_get_global global.def_id ctx in + let body_name = + ctx_get_function (FromLlbc (Regular global.body_id, None)) ctx + in + + let decl_ty, body_ty = + let ty = body.signature.output in + if body.signature.info.effect_info.can_fail then (unwrap_result_ty ty, ty) + else (ty, mk_result_ty ty) + in + match body.body with + | None -> + let kind = if interface then Declared else Assumed in + extract_global_decl_body ctx fmt kind decl_name decl_ty None + | Some body -> + extract_global_decl_body ctx fmt SingleNonRec body_name body_ty + (Some (fun fmt -> extract_texpression ctx fmt false body.body)); + F.pp_print_break fmt 0 0; + extract_global_decl_body ctx fmt SingleNonRec decl_name decl_ty + (Some + (fun fmt -> + let body = + match !backend with + | FStar -> "eval_global " ^ body_name + | Coq -> body_name ^ "%global" + in + F.pp_print_string fmt body)); + (* Add a break to insert lines between declarations *) + F.pp_print_break fmt 0 0 + +(** Extract a unit test, if the function is a unit function (takes no + parameters, returns unit). + + A unit test simply checks that the function normalizes to [Return ()]. + + F*: + {[ + let _ = assert_norm (FUNCTION = Return ()) + ]} + + Coq: + {[ + Check (FUNCTION)%return). + ]} + *) +let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) + (def : fun_decl) : unit = + (* We only insert unit tests for forward functions *) + assert (def.back_id = None); + (* Check if this is a unit function *) + let sg = def.signature in + if + sg.type_params = [] + && (sg.inputs = [ mk_unit_ty ] || sg.inputs = []) + && sg.output = mk_result_ty mk_unit_ty + then ( + (* Add a break before *) + F.pp_print_break fmt 0 0; + (* Print a comment *) + F.pp_print_string fmt + ("(** Unit test for [" ^ Print.fun_name_to_string def.basename ^ "] *)"); + F.pp_print_space fmt (); + (* Open a box for the test *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Print the test *) + (match !backend with + | FStar -> + F.pp_print_string fmt "let _ ="; + F.pp_print_space fmt (); + F.pp_print_string fmt "assert_norm"; + F.pp_print_space fmt (); + F.pp_print_string fmt "("; + let fun_name = ctx_get_local_function def.def_id def.back_id ctx in + F.pp_print_string fmt fun_name; + if sg.inputs <> [] then ( + F.pp_print_space fmt (); + F.pp_print_string fmt "()"); + F.pp_print_space fmt (); + F.pp_print_string fmt "="; + F.pp_print_space fmt (); + let success = ctx_get_variant (Assumed Result) result_return_id ctx in + F.pp_print_string fmt (success ^ " ())") + | Coq -> + F.pp_print_string fmt "Check"; + F.pp_print_space fmt (); + F.pp_print_string fmt "("; + let fun_name = ctx_get_local_function def.def_id def.back_id ctx in + F.pp_print_string fmt fun_name; + if sg.inputs <> [] then ( + F.pp_print_space fmt (); + F.pp_print_string fmt "()"); + F.pp_print_space fmt (); + F.pp_print_string fmt ")%return."); + (* Close the box for the test *) + F.pp_close_box fmt (); + (* Add a break after *) + F.pp_print_break fmt 0 0) + else (* Do nothing *) + () diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml new file mode 100644 index 00000000..33939e6a --- /dev/null +++ b/compiler/ExtractBase.ml @@ -0,0 +1,811 @@ +(** Define base utilities for the extraction *) + +open Pure +open TranslateCore +module C = Contexts +module RegionVarId = T.RegionVarId +module F = Format + +(** The local logger *) +let log = L.pure_to_extract_log + +type region_group_info = { + id : RegionGroupId.id; + (** The id of the region group. + Note that a simple way of generating unique names for backward + functions is to use the region group ids. + *) + region_names : string option list; + (** The names of the region variables included in this group. + Note that names are not always available... + *) +} + +module StringSet = Collections.MakeSet (Collections.OrderedString) +module StringMap = Collections.MakeMap (Collections.OrderedString) + +type name = Names.name +type type_name = Names.type_name +type global_name = Names.global_name +type fun_name = Names.fun_name + +(** Characterizes a declaration. + + Is in particular useful to derive the proper keywords to introduce the + declarations/definitions. + *) +type decl_kind = + | SingleNonRec + (** A single, non-recursive definition. + + F*: [let x = ...] + Coq: [Definition x := ...] + *) + | SingleRec + (** A single, recursive definition. + + F*: [let rec x = ...] + Coq: [Fixpoint x := ...] + *) + | MutRecFirst + (** The first definition of a group of mutually-recursive definitions. + + F*: [type x0 = ... and x1 = ...] + Coq: [Fixpoing x0 := ... with x1 := ...] + *) + | MutRecInner + (** An inner definition in a group of mutually-recursive definitions. *) + | MutRecLast + (** The last definition in a group of mutually-recursive definitions. + + We need this because in some theorem provers like Coq, we need to + delimit group of mutually recursive definitions (in particular, we + need to insert an end delimiter). + *) + | Assumed + (** An assumed definition. + + F*: [assume val x] + Coq: [Axiom x : Type.] + *) + | Declared + (** Declare a type in an interface or a module signature. + + Rem.: for now, in Coq, we don't declare module signatures: we + thus assume the corresponding declarations. + + F*: [val x : Type0] + Coq: [Axiom x : Type.] + *) + +(** Return [true] if the declaration is the last from its group of declarations. + + We need this because in some provers (e.g., Coq), we need to delimit the + end of a (group of) definition(s) (in Coq: with a "."). + *) +let decl_is_last_from_group (kind : decl_kind) : bool = + match kind with + | SingleNonRec | SingleRec | MutRecLast | Assumed | Declared -> true + | MutRecFirst | MutRecInner -> false + +let decl_is_from_rec_group (kind : decl_kind) : bool = + match kind with + | SingleNonRec | Assumed | Declared -> false + | SingleRec | MutRecFirst | MutRecInner | MutRecLast -> true + +let decl_is_from_mut_rec_group (kind : decl_kind) : bool = + match kind with + | SingleNonRec | SingleRec | Assumed | Declared -> false + | MutRecFirst | MutRecInner | MutRecLast -> true + +(* TODO: this should a module we give to a functor! *) + +type type_decl_kind = Enum | Struct + +(** A formatter's role is twofold: + 1. Come up with name suggestions. + For instance, provided some information about a function (its basename, + information about the region group, etc.) it should come up with an + appropriate name for the forward/backward function. + + It can of course apply many transformations, like changing to camel case/ + snake case, adding prefixes/suffixes, etc. + + 2. Format some specific terms, like constants. + *) +type formatter = { + bool_name : string; + char_name : string; + int_name : integer_type -> string; + str_name : string; + type_decl_kind_to_qualif : decl_kind -> type_decl_kind option -> string; + (** Compute the qualified for a type definition/declaration. + + For instance: "type", "and", etc. + *) + fun_decl_kind_to_qualif : decl_kind -> string; + (** Compute the qualified for a function definition/declaration. + + For instance: "let", "let rec", "and", etc. + *) + field_name : name -> FieldId.id -> string option -> string; + (** Inputs: + - type name + - field id + - field name + + Note that fields don't always have names, but we still need to + generate some names if we want to extract the structures to records... + We might want to extract such structures to tuples, later, but field + access then causes trouble because not all provers accept syntax like + [x.3] where [x] is a tuple. + *) + variant_name : name -> string -> string; + (** Inputs: + - type name + - variant name + *) + struct_constructor : name -> string; + (** Structure constructors are used when constructing structure values. + + For instance, in F*: + {[ + type pair = { x : nat; y : nat } + let p : pair = Mkpair 0 1 + ]} + + Inputs: + - type name + *) + type_name : type_name -> string; + (** Provided a basename, compute a type name. *) + global_name : global_name -> string; + (** Provided a basename, compute a global name. *) + fun_name : + fun_name -> int -> region_group_info option -> bool * int -> string; + (** Compute the name of a regular (non-assumed) function. + + Inputs: + - function id + - function basename (TODO: shouldn't appear for assumed functions?...) + - number of region groups + - region group information in case of a backward function + ([None] if forward function) + - pair: + - do we generate the forward function (it may have been filtered)? + - the number of extracted backward functions (not necessarily equal + to the number of region groups, because we may have filtered + some of them) + TODO: use the fun id for the assumed functions. + *) + decreases_clause_name : A.FunDeclId.id -> fun_name -> string; + (** Generates the name of the definition used to prove/reason about + termination. The generated code uses this clause where needed, + but its body must be defined by the user. + + Inputs: + - function id: this is especially useful to identify whether the + function is an assumed function or a local function + - function basename + *) + var_basename : StringSet.t -> string option -> ty -> string; + (** Generates a variable basename. + + Inputs: + - the set of names used in the context so far + - the basename we got from the symbolic execution, if we have one + - the type of the variable (can be useful for heuristics, in order + not to always use "x" for instance, whenever naming anonymous + variables) + + Note that once the formatter generated a basename, we add an index + if necessary to prevent name clashes: the burden of name clashes checks + is thus on the caller's side. + *) + type_var_basename : StringSet.t -> string -> string; + (** Generates a type variable basename. *) + append_index : string -> int -> string; + (** Appends an index to a name - we use this to generate unique + names: when doing so, the role of the formatter is just to concatenate + indices to names, the responsability of finding a proper index is + delegated to helper functions. + *) + extract_primitive_value : F.formatter -> bool -> primitive_value -> unit; + (** Format a constant value. + + Inputs: + - formatter + - [inside]: if [true], the value should be wrapped in parentheses + if it is made of an application (ex.: [U32 3]) + - the constant value + *) + extract_unop : + (bool -> texpression -> unit) -> + F.formatter -> + bool -> + unop -> + texpression -> + unit; + (** Format a unary operation + + Inputs: + - a formatter for expressions (called on the argument of the unop) + - extraction context (see below) + - formatter + - expression formatter + - [inside] + - unop + - argument + *) + extract_binop : + (bool -> texpression -> unit) -> + F.formatter -> + bool -> + E.binop -> + integer_type -> + texpression -> + texpression -> + unit; + (** Format a binary operation + + Inputs: + - a formatter for expressions (called on the arguments of the binop) + - extraction context (see below) + - formatter + - expression formatter + - [inside] + - binop + - argument 0 + - argument 1 + *) +} + +(** We use identifiers to look for name clashes *) +type id = + | GlobalId of A.GlobalDeclId.id + | FunId of fun_id + | DecreasesClauseId of A.fun_id + (** The definition which provides the decreases/termination clause. + We insert calls to this clause to prove/reason about termination: + the body of those clauses must be defined by the user, in the + proper files. + *) + | TypeId of type_id + | StructId of type_id + (** We use this when we manipulate the names of the structure + constructors. + + For instance, in F*: + {[ + type pair = { x: nat; y : nat } + let p : pair = Mkpair 0 1 + ]} + *) + | VariantId of type_id * VariantId.id + (** If often happens that variant names must be unique (it is the case in + F* ) which is why we register them here. + *) + | FieldId of type_id * FieldId.id + (** If often happens that in the case of structures, the field names + must be unique (it is the case in F* ) which is why we register + them here. + *) + | TypeVarId of TypeVarId.id + | VarId of VarId.id + | UnknownId + (** Used for stored various strings like keywords, definitions which + should always be in context, etc. and which can't be linked to one + of the above. + *) +[@@deriving show, ord] + +module IdOrderedType = struct + type t = id + + let compare = compare_id + let to_string = show_id + let pp_t = pp_id + let show_t = show_id +end + +module IdMap = Collections.MakeMap (IdOrderedType) + +(** The names map stores the mappings from names to identifiers and vice-versa. + + We use it for lookups (during the translation) and to check for name clashes. + + [id_to_string] is for debugging. + *) +type names_map = { + id_to_name : string IdMap.t; + name_to_id : id StringMap.t; + (** The name to id map is used to look for name clashes, and generate nice + debugging messages: if there is a name clash, it is useful to know + precisely which identifiers are mapped to the same name... + *) + names_set : StringSet.t; +} + +let names_map_add (id_to_string : id -> string) (id : id) (name : string) + (nm : names_map) : names_map = + (* Check if there is a clash *) + (match StringMap.find_opt name nm.name_to_id with + | None -> () (* Ok *) + | Some clash -> + (* There is a clash: print a nice debugging message for the user *) + let id1 = "\n- " ^ id_to_string clash in + let id2 = "\n- " ^ id_to_string id in + let err = + "Name clash detected: the following identifiers are bound to the same \ + name \"" ^ name ^ "\":" ^ id1 ^ id2 + in + log#serror err; + raise (Failure err)); + (* Sanity check *) + assert (not (StringSet.mem name nm.names_set)); + (* Insert *) + let id_to_name = IdMap.add id name nm.id_to_name in + let name_to_id = StringMap.add name id nm.name_to_id in + let names_set = StringSet.add name nm.names_set in + { id_to_name; name_to_id; names_set } + +let names_map_add_assumed_type (id_to_string : id -> string) (id : assumed_ty) + (name : string) (nm : names_map) : names_map = + names_map_add id_to_string (TypeId (Assumed id)) name nm + +let names_map_add_assumed_struct (id_to_string : id -> string) (id : assumed_ty) + (name : string) (nm : names_map) : names_map = + names_map_add id_to_string (StructId (Assumed id)) name nm + +let names_map_add_assumed_variant (id_to_string : id -> string) + (id : assumed_ty) (variant_id : VariantId.id) (name : string) + (nm : names_map) : names_map = + names_map_add id_to_string (VariantId (Assumed id, variant_id)) name nm + +let names_map_add_function (id_to_string : id -> string) (fid : fun_id) + (name : string) (nm : names_map) : names_map = + names_map_add id_to_string (FunId fid) name nm + +(** Make a (variable) basename unique (by adding an index). + + We do this in an inefficient manner (by testing all indices starting from + 0) but it shouldn't be a bottleneck. + + Also note that at some point, we thought about trying to reuse names of + variables which are not used anymore, like here: + {[ + let x = ... in + ... + let x0 = ... in // We could use the name "x" if [x] is not used below + ... + ]} + + However it is a good idea to keep things as they are for F*: as F* is + designed for extrinsic proofs, a proof about a function follows this + function's structure. The consequence is that we often end up + copy-pasting function bodies. As in the proofs (in assertions and + when calling lemmas) we often need to talk about the "past" (i.e., + previous values), it is very useful to generate code where all variable + names are assigned at most once. + + [append]: function to append an index to a string + *) +let basename_to_unique (names_set : StringSet.t) + (append : string -> int -> string) (basename : string) : string = + let rec gen (i : int) : string = + let s = append basename i in + if StringSet.mem s names_set then gen (i + 1) else s + in + if StringSet.mem basename names_set then gen 0 else basename + +(** Extraction context. + + Note that the extraction context contains information coming from the + LLBC AST (not only the pure AST). This is useful for naming, for instance: + we use the region information to generate the names of the backward + functions, etc. + *) +type extraction_ctx = { + trans_ctx : trans_ctx; + names_map : names_map; + fmt : formatter; + indent_incr : int; + (** The indent increment we insert whenever we need to indent more *) +} + +(** Debugging function *) +let id_to_string (id : id) (ctx : extraction_ctx) : string = + let global_decls = ctx.trans_ctx.global_context.global_decls in + let fun_decls = ctx.trans_ctx.fun_context.fun_decls in + let type_decls = ctx.trans_ctx.type_context.type_decls in + (* TODO: factorize the pretty-printing with what is in PrintPure *) + let get_type_name (id : type_id) : string = + match id with + | AdtId id -> + let def = TypeDeclId.Map.find id type_decls in + Print.name_to_string def.name + | Assumed aty -> show_assumed_ty aty + | Tuple -> raise (Failure "Unreachable") + in + match id with + | GlobalId gid -> + let name = (A.GlobalDeclId.Map.find gid global_decls).name in + "global name: " ^ Print.global_name_to_string name + | FunId fid -> ( + match fid with + | FromLlbc (fid, rg_id) -> + let fun_name = + match fid with + | Regular fid -> + Print.fun_name_to_string + (A.FunDeclId.Map.find fid fun_decls).name + | Assumed aid -> A.show_assumed_fun_id aid + in + let fun_kind = + match rg_id with + | None -> "forward" + | Some rg_id -> "backward " ^ RegionGroupId.to_string rg_id + in + "fun name (" ^ fun_kind ^ "): " ^ fun_name + | Pure fid -> PrintPure.pure_assumed_fun_id_to_string fid) + | DecreasesClauseId fid -> + let fun_name = + match fid with + | Regular fid -> + Print.fun_name_to_string (A.FunDeclId.Map.find fid fun_decls).name + | Assumed aid -> A.show_assumed_fun_id aid + in + "decreases clause for function: " ^ fun_name + | TypeId id -> "type name: " ^ get_type_name id + | StructId id -> "struct constructor of: " ^ get_type_name id + | VariantId (id, variant_id) -> + let variant_name = + match id with + | Tuple -> raise (Failure "Unreachable") + | Assumed State -> raise (Failure "Unreachable") + | Assumed Result -> + if variant_id = result_return_id then "@result::Return" + else if variant_id = result_fail_id then "@result::Fail" + else raise (Failure "Unreachable") + | Assumed Option -> + if variant_id = option_some_id then "@option::Some" + else if variant_id = option_none_id then "@option::None" + else raise (Failure "Unreachable") + | Assumed Vec -> raise (Failure "Unreachable") + | AdtId id -> ( + let def = TypeDeclId.Map.find id type_decls in + match def.kind with + | Struct _ | Opaque -> raise (Failure "Unreachable") + | Enum variants -> + let variant = VariantId.nth variants variant_id in + Print.name_to_string def.name ^ "::" ^ variant.variant_name) + in + "variant name: " ^ variant_name + | FieldId (id, field_id) -> + let field_name = + match id with + | Tuple -> raise (Failure "Unreachable") + | Assumed (State | Result | Option) -> raise (Failure "Unreachable") + | Assumed Vec -> + (* We can't directly have access to the fields of a vector *) + raise (Failure "Unreachable") + | AdtId id -> ( + let def = TypeDeclId.Map.find id type_decls in + match def.kind with + | Enum _ | Opaque -> raise (Failure "Unreachable") + | Struct fields -> + let field = FieldId.nth fields field_id in + let field_name = + match field.field_name with + | None -> FieldId.to_string field_id + | Some name -> name + in + Print.name_to_string def.name ^ "." ^ field_name) + in + "field name: " ^ field_name + | UnknownId -> "keyword" + | TypeVarId _ | VarId _ -> + (* We should never get there: we add indices to make sure variable + * names are unique *) + raise (Failure "Unreachable") + +let ctx_add (id : id) (name : string) (ctx : extraction_ctx) : extraction_ctx = + (* The id_to_string function to print nice debugging messages if there are + * collisions *) + let id_to_string (id : id) : string = id_to_string id ctx in + let names_map = names_map_add id_to_string id name ctx.names_map in + { ctx with names_map } + +let ctx_get (id : id) (ctx : extraction_ctx) : string = + match IdMap.find_opt id ctx.names_map.id_to_name with + | Some s -> s + | None -> + log#serror ("Could not find: " ^ id_to_string id ctx); + raise Not_found + +let ctx_get_global (id : A.GlobalDeclId.id) (ctx : extraction_ctx) : string = + ctx_get (GlobalId id) ctx + +let ctx_get_function (id : fun_id) (ctx : extraction_ctx) : string = + ctx_get (FunId id) ctx + +let ctx_get_local_function (id : A.FunDeclId.id) (rg : RegionGroupId.id option) + (ctx : extraction_ctx) : string = + ctx_get_function (FromLlbc (Regular id, rg)) ctx + +let ctx_get_type (id : type_id) (ctx : extraction_ctx) : string = + assert (id <> Tuple); + ctx_get (TypeId id) ctx + +let ctx_get_local_type (id : TypeDeclId.id) (ctx : extraction_ctx) : string = + ctx_get_type (AdtId id) ctx + +let ctx_get_assumed_type (id : assumed_ty) (ctx : extraction_ctx) : string = + ctx_get_type (Assumed id) ctx + +let ctx_get_var (id : VarId.id) (ctx : extraction_ctx) : string = + ctx_get (VarId id) ctx + +let ctx_get_type_var (id : TypeVarId.id) (ctx : extraction_ctx) : string = + ctx_get (TypeVarId id) ctx + +let ctx_get_field (type_id : type_id) (field_id : FieldId.id) + (ctx : extraction_ctx) : string = + ctx_get (FieldId (type_id, field_id)) ctx + +let ctx_get_struct (def_id : type_id) (ctx : extraction_ctx) : string = + ctx_get (StructId def_id) ctx + +let ctx_get_variant (def_id : type_id) (variant_id : VariantId.id) + (ctx : extraction_ctx) : string = + ctx_get (VariantId (def_id, variant_id)) ctx + +let ctx_get_decreases_clause (def_id : A.FunDeclId.id) (ctx : extraction_ctx) : + string = + ctx_get (DecreasesClauseId (Regular def_id)) ctx + +(** Generate a unique type variable name and add it to the context *) +let ctx_add_type_var (basename : string) (id : TypeVarId.id) + (ctx : extraction_ctx) : extraction_ctx * string = + let name = ctx.fmt.type_var_basename ctx.names_map.names_set basename in + let name = + basename_to_unique ctx.names_map.names_set ctx.fmt.append_index name + in + let ctx = ctx_add (TypeVarId id) name ctx in + (ctx, name) + +(** See {!ctx_add_type_var} *) +let ctx_add_type_vars (vars : (string * TypeVarId.id) list) + (ctx : extraction_ctx) : extraction_ctx * string list = + List.fold_left_map + (fun ctx (name, id) -> ctx_add_type_var name id ctx) + ctx vars + +(** Generate a unique variable name and add it to the context *) +let ctx_add_var (basename : string) (id : VarId.id) (ctx : extraction_ctx) : + extraction_ctx * string = + let name = + basename_to_unique ctx.names_map.names_set ctx.fmt.append_index basename + in + let ctx = ctx_add (VarId id) name ctx in + (ctx, name) + +(** See {!ctx_add_var} *) +let ctx_add_vars (vars : var list) (ctx : extraction_ctx) : + extraction_ctx * string list = + List.fold_left_map + (fun ctx (v : var) -> + let name = ctx.fmt.var_basename ctx.names_map.names_set v.basename v.ty in + ctx_add_var name v.id ctx) + ctx vars + +let ctx_add_type_params (vars : type_var list) (ctx : extraction_ctx) : + extraction_ctx * string list = + List.fold_left_map + (fun ctx (var : type_var) -> ctx_add_type_var var.name var.index ctx) + ctx vars + +let ctx_add_type_decl_struct (def : type_decl) (ctx : extraction_ctx) : + extraction_ctx * string = + let cons_name = ctx.fmt.struct_constructor def.name in + let ctx = ctx_add (StructId (AdtId def.def_id)) cons_name ctx in + (ctx, cons_name) + +let ctx_add_type_decl (def : type_decl) (ctx : extraction_ctx) : extraction_ctx + = + let def_name = ctx.fmt.type_name def.name in + let ctx = ctx_add (TypeId (AdtId def.def_id)) def_name ctx in + ctx + +let ctx_add_field (def : type_decl) (field_id : FieldId.id) (field : field) + (ctx : extraction_ctx) : extraction_ctx * string = + let name = ctx.fmt.field_name def.name field_id field.field_name in + let ctx = ctx_add (FieldId (AdtId def.def_id, field_id)) name ctx in + (ctx, name) + +let ctx_add_fields (def : type_decl) (fields : (FieldId.id * field) list) + (ctx : extraction_ctx) : extraction_ctx * string list = + List.fold_left_map + (fun ctx (vid, v) -> ctx_add_field def vid v ctx) + ctx fields + +let ctx_add_variant (def : type_decl) (variant_id : VariantId.id) + (variant : variant) (ctx : extraction_ctx) : extraction_ctx * string = + let name = ctx.fmt.variant_name def.name variant.variant_name in + let ctx = ctx_add (VariantId (AdtId def.def_id, variant_id)) name ctx in + (ctx, name) + +let ctx_add_variants (def : type_decl) + (variants : (VariantId.id * variant) list) (ctx : extraction_ctx) : + extraction_ctx * string list = + List.fold_left_map + (fun ctx (vid, v) -> ctx_add_variant def vid v ctx) + ctx variants + +let ctx_add_struct (def : type_decl) (ctx : extraction_ctx) : + extraction_ctx * string = + let name = ctx.fmt.struct_constructor def.name in + let ctx = ctx_add (StructId (AdtId def.def_id)) name ctx in + (ctx, name) + +let ctx_add_decrases_clause (def : fun_decl) (ctx : extraction_ctx) : + extraction_ctx = + let name = ctx.fmt.decreases_clause_name def.def_id def.basename in + ctx_add (DecreasesClauseId (Regular def.def_id)) name ctx + +let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) : + extraction_ctx = + let name = ctx.fmt.global_name def.name in + let decl = GlobalId def.def_id in + let body = FunId (FromLlbc (Regular def.body_id, None)) in + let ctx = ctx_add decl (name ^ "_c") ctx in + let ctx = ctx_add body (name ^ "_body") ctx in + ctx + +let ctx_add_fun_decl (trans_group : bool * pure_fun_translation) + (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx = + (* Sanity check: the function should not be a global body - those are handled + * separately *) + assert (not def.is_global_decl_body); + (* Lookup the LLBC def to compute the region group information *) + let def_id = def.def_id in + let llbc_def = + A.FunDeclId.Map.find def_id ctx.trans_ctx.fun_context.fun_decls + in + let sg = llbc_def.signature in + let num_rgs = List.length sg.regions_hierarchy in + let keep_fwd, (_, backs) = trans_group in + let num_backs = List.length backs in + let rg_info = + match def.back_id with + | None -> None + | Some rg_id -> + let rg = T.RegionGroupId.nth sg.regions_hierarchy rg_id in + let regions = + List.map + (fun rid -> T.RegionVarId.nth sg.region_params rid) + rg.regions + in + let region_names = + List.map (fun (r : T.region_var) -> r.name) regions + in + Some { id = rg_id; region_names } + in + let name = + ctx.fmt.fun_name def.basename num_rgs rg_info (keep_fwd, num_backs) + in + ctx_add (FunId (FromLlbc (A.Regular def_id, def.back_id))) name ctx + +type names_map_init = { + keywords : string list; + assumed_adts : (assumed_ty * string) list; + assumed_structs : (assumed_ty * string) list; + assumed_variants : (assumed_ty * VariantId.id * string) list; + assumed_llbc_functions : + (A.assumed_fun_id * RegionGroupId.id option * string) list; + assumed_pure_functions : (pure_assumed_fun_id * string) list; +} + +(** Initialize a names map with a proper set of keywords/names coming from the + target language/prover. *) +let initialize_names_map (fmt : formatter) (init : names_map_init) : names_map = + let int_names = List.map fmt.int_name T.all_int_types in + let keywords = + List.concat + [ + [ fmt.bool_name; fmt.char_name; fmt.str_name ]; int_names; init.keywords; + ] + in + let names_set = StringSet.of_list keywords in + let name_to_id = + StringMap.of_list (List.map (fun x -> (x, UnknownId)) keywords) + in + (* We fist initialize [id_to_name] as empty, because the id of a keyword is [UnknownId]. + * Also note that we don't need this mapping for keywords: we insert keywords only + * to check collisions. *) + let id_to_name = IdMap.empty in + let nm = { id_to_name; name_to_id; names_set } in + (* For debugging - we are creating bindings for assumed types and functions, so + * it is ok if we simply use the "show" function (those aren't simply identified + * by numbers) *) + let id_to_string = show_id in + (* Then we add: + * - the assumed types + * - the assumed struct constructors + * - the assumed variants + * - the assumed functions + *) + let nm = + List.fold_left + (fun nm (type_id, name) -> + names_map_add_assumed_type id_to_string type_id name nm) + nm init.assumed_adts + in + let nm = + List.fold_left + (fun nm (type_id, name) -> + names_map_add_assumed_struct id_to_string type_id name nm) + nm init.assumed_structs + in + let nm = + List.fold_left + (fun nm (type_id, variant_id, name) -> + names_map_add_assumed_variant id_to_string type_id variant_id name nm) + nm init.assumed_variants + in + let assumed_functions = + List.map + (fun (fid, rg, name) -> (FromLlbc (A.Assumed fid, rg), name)) + init.assumed_llbc_functions + @ List.map (fun (fid, name) -> (Pure fid, name)) init.assumed_pure_functions + in + let nm = + List.fold_left + (fun nm (fid, name) -> names_map_add_function id_to_string fid name nm) + nm assumed_functions + in + (* Return *) + nm + +let compute_type_decl_name (fmt : formatter) (def : type_decl) : string = + fmt.type_name def.name + +(** A helper function: generates a function suffix from a region group + information. + TODO: move all those helpers. +*) +let default_fun_suffix (num_region_groups : int) (rg : region_group_info option) + ((keep_fwd, num_backs) : bool * int) : string = + (* There are several cases: + - [rg] is [Some]: this is a forward function: + - we add "_fwd" + - [rg] is [None]: this is a backward function: + - this function has one extracted backward function: + - if the forward function has been filtered, we add "_fwd_back": + the forward function is useless, so the unique backward function + takes its place, in a way + - otherwise we add "_back" + - this function has several backward functions: we add "_back" and an + additional suffix to identify the precise backward function + Note that we always add a suffix (in case there are no region groups, + we could not add the "_fwd" suffix) to prevent name clashes between + definitions (in particular between type and function definitions). + *) + match rg with + | None -> "_fwd" + | Some rg -> + assert (num_region_groups > 0 && num_backs > 0); + if num_backs = 1 then + (* Exactly one backward function *) + if not keep_fwd then "_fwd_back" else "_back" + else if + (* Several region groups/backward functions: + - if all the regions in the group have names, we use those names + - otherwise we use an index + *) + List.for_all Option.is_some rg.region_names + then + (* Concatenate the region names *) + "_back" ^ String.concat "" (List.map Option.get rg.region_names) + else (* Use the region index *) + "_back" ^ RegionGroupId.to_string rg.id diff --git a/compiler/ExtractToBackend.ml b/compiler/ExtractToBackend.ml deleted file mode 100644 index fc04ce90..00000000 --- a/compiler/ExtractToBackend.ml +++ /dev/null @@ -1,1639 +0,0 @@ -(** Extract to F* *) - -open Utils -open Pure -open PureUtils -open TranslateCore -open PureToExtract -open StringUtils -module F = Format - -(** A qualifier for a type definition. - - Controls whether we should use [type ...] or [and ...] (for mutually - recursive datatypes). - *) -type type_decl_qualif = - | Type (** [type t = ...] *) - | And (** [type t0 = ... and t1 = ...] *) - | AssumeType (** [assume type t] *) - | TypeVal (** In an fsti: [val t : Type0] *) - -(** A qualifier for function definitions. - - Controls whether we should use [let ...], [let rec ...] or [and ...], - or only generate a declaration with [val] or [assume val] - *) -type fun_decl_qualif = Let | LetRec | And | Val | AssumeVal - -let fun_decl_qualif_keyword (qualif : fun_decl_qualif) : string = - match qualif with - | Let -> "let" - | LetRec -> "let rec" - | And -> "and" - | Val -> "val" - | AssumeVal -> "assume val" - -(** Small helper to compute the name of an int type *) -let fstar_int_name (int_ty : integer_type) = - match int_ty with - | Isize -> "isize" - | I8 -> "i8" - | I16 -> "i16" - | I32 -> "i32" - | I64 -> "i64" - | I128 -> "i128" - | Usize -> "usize" - | U8 -> "u8" - | U16 -> "u16" - | U32 -> "u32" - | U64 -> "u64" - | U128 -> "u128" - -(** Small helper to compute the name of a unary operation *) -let fstar_unop_name (unop : unop) : string = - match unop with - | Not -> "not" - | Neg int_ty -> fstar_int_name int_ty ^ "_neg" - | Cast _ -> raise (Failure "Unsupported") - -(** Small helper to compute the name of a binary operation (note that many - binary operations like "less than" are extracted to primitive operations, - like [<]. - *) -let fstar_named_binop_name (binop : E.binop) (int_ty : integer_type) : string = - let binop = - match binop with - | Div -> "div" - | Rem -> "rem" - | Add -> "add" - | Sub -> "sub" - | Mul -> "mul" - | _ -> raise (Failure "Unreachable") - in - fstar_int_name int_ty ^ "_" ^ binop - -(** A list of keywords/identifiers used in F* and with which we want to check - collision. *) -let fstar_keywords = - let named_unops = - fstar_unop_name Not - :: List.map (fun it -> fstar_unop_name (Neg it)) T.all_signed_int_types - in - let named_binops = [ E.Div; Rem; Add; Sub; Mul ] in - let named_binops = - List.concat - (List.map - (fun bn -> - List.map (fun it -> fstar_named_binop_name bn it) T.all_int_types) - named_binops) - in - let misc = - [ - "let"; - "rec"; - "in"; - "fn"; - "val"; - "int"; - "nat"; - "list"; - "FStar"; - "FStar.Mul"; - "type"; - "match"; - "with"; - "assert"; - "assert_norm"; - "assume"; - "Type0"; - "Type"; - "unit"; - "not"; - "scalar_cast"; - ] - in - List.concat [ named_unops; named_binops; misc ] - -let fstar_assumed_adts : (assumed_ty * string) list = - [ (State, "state"); (Result, "result"); (Option, "option"); (Vec, "vec") ] - -let fstar_assumed_structs : (assumed_ty * string) list = [] - -let fstar_assumed_variants : (assumed_ty * VariantId.id * string) list = - [ - (Result, result_return_id, "Return"); - (Result, result_fail_id, "Fail"); - (Option, option_some_id, "Some"); - (Option, option_none_id, "None"); - ] - -let fstar_assumed_llbc_functions : - (A.assumed_fun_id * T.RegionGroupId.id option * string) list = - let rg0 = Some T.RegionGroupId.zero in - [ - (Replace, None, "mem_replace_fwd"); - (Replace, rg0, "mem_replace_back"); - (VecNew, None, "vec_new"); - (VecPush, None, "vec_push_fwd") (* Shouldn't be used *); - (VecPush, rg0, "vec_push_back"); - (VecInsert, None, "vec_insert_fwd") (* Shouldn't be used *); - (VecInsert, rg0, "vec_insert_back"); - (VecLen, None, "vec_len"); - (VecIndex, None, "vec_index_fwd"); - (VecIndex, rg0, "vec_index_back") (* shouldn't be used *); - (VecIndexMut, None, "vec_index_mut_fwd"); - (VecIndexMut, rg0, "vec_index_mut_back"); - ] - -let fstar_assumed_pure_functions : (pure_assumed_fun_id * string) list = - [ (Return, "return"); (Fail, "fail"); (Assert, "massert") ] - -let fstar_names_map_init : names_map_init = - { - keywords = fstar_keywords; - assumed_adts = fstar_assumed_adts; - assumed_structs = fstar_assumed_structs; - assumed_variants = fstar_assumed_variants; - assumed_llbc_functions = fstar_assumed_llbc_functions; - assumed_pure_functions = fstar_assumed_pure_functions; - } - -let fstar_extract_unop (extract_expr : bool -> texpression -> unit) - (fmt : F.formatter) (inside : bool) (unop : unop) (arg : texpression) : unit - = - match unop with - | Not | Neg _ -> - let unop = fstar_unop_name unop in - if inside then F.pp_print_string fmt "("; - F.pp_print_string fmt unop; - F.pp_print_space fmt (); - extract_expr true arg; - if inside then F.pp_print_string fmt ")" - | Cast (src, tgt) -> - (* The source type is an implicit parameter *) - if inside then F.pp_print_string fmt "("; - F.pp_print_string fmt "scalar_cast"; - F.pp_print_space fmt (); - F.pp_print_string fmt - (StringUtils.capitalize_first_letter - (PrintPure.integer_type_to_string src)); - F.pp_print_space fmt (); - F.pp_print_string fmt - (StringUtils.capitalize_first_letter - (PrintPure.integer_type_to_string tgt)); - F.pp_print_space fmt (); - extract_expr true arg; - if inside then F.pp_print_string fmt ")" - -let fstar_extract_binop (extract_expr : bool -> texpression -> unit) - (fmt : F.formatter) (inside : bool) (binop : E.binop) - (int_ty : integer_type) (arg0 : texpression) (arg1 : texpression) : unit = - if inside then F.pp_print_string fmt "("; - (* Some binary operations have a special treatment *) - (match binop with - | Eq | Lt | Le | Ne | Ge | Gt -> - let binop = - match binop with - | Eq -> "=" - | Lt -> "<" - | Le -> "<=" - | Ne -> "<>" - | Ge -> ">=" - | Gt -> ">" - | _ -> raise (Failure "Unreachable") - in - extract_expr false arg0; - F.pp_print_space fmt (); - F.pp_print_string fmt binop; - F.pp_print_space fmt (); - extract_expr false arg1 - | Div | Rem | Add | Sub | Mul -> - let binop = fstar_named_binop_name binop int_ty in - F.pp_print_string fmt binop; - F.pp_print_space fmt (); - extract_expr false arg0; - F.pp_print_space fmt (); - extract_expr false arg1 - | BitXor | BitAnd | BitOr | Shl | Shr -> raise Unimplemented); - if inside then F.pp_print_string fmt ")" - -(** - [ctx]: we use the context to lookup type definitions, to retrieve type names. - This is used to compute variable names, when they have no basenames: in this - case we use the first letter of the type name. - - [variant_concatenate_type_name]: if true, add the type name as a prefix - to the variant names. - Ex.: - In Rust: - {[ - enum List = { - Cons(u32, Box),x - Nil, - } - ]} - - F*, if option activated: - {[ - type list = - | ListCons : u32 -> list -> list - | ListNil : list - ]} - - F*, if option not activated: - {[ - type list = - | Cons : u32 -> list -> list - | Nil : list - ]} - - Rk.: this should be true by default, because in Rust all the variant names - are actively uniquely identifier by the type name [List::Cons(...)], while - in other languages it is not necessarily the case, and thus clashes can mess - up type checking. Note that some languages actually forbids the name clashes - (it is the case of F* ). - *) -let mk_formatter (ctx : trans_ctx) (crate_name : string) - (variant_concatenate_type_name : bool) : formatter = - let int_name = fstar_int_name in - - (* Prepare a name. - * The first id elem is always the crate: if it is the local crate, - * we remove it. - * We also remove all the disambiguators, then convert everything to strings. - * **Rmk:** because we remove the disambiguators, there may be name collisions - * (which is ok, because we check for name collisions and fail if there is any). - *) - let get_name (name : name) : string list = - (* Rmk.: initially we only filtered the disambiguators equal to 0 *) - let name = Names.filter_disambiguators name in - match name with - | Ident crate :: name -> - let name = if crate = crate_name then name else Ident crate :: name in - let name = - List.map - (function - | Names.Ident s -> s - | Disambiguator d -> Names.Disambiguator.to_string d) - name - in - name - | _ -> - raise (Failure ("Unexpected name shape: " ^ Print.name_to_string name)) - in - let get_type_name = get_name in - let type_name_to_camel_case name = - let name = get_type_name name in - let name = List.map to_camel_case name in - String.concat "" name - in - let type_name_to_snake_case name = - let name = get_type_name name in - let name = List.map to_snake_case name in - String.concat "_" name - in - let type_name name = type_name_to_snake_case name ^ "_t" in - let field_name (def_name : name) (field_id : FieldId.id) - (field_name : string option) : string = - let def_name = type_name_to_snake_case def_name ^ "_" in - match field_name with - | Some field_name -> def_name ^ field_name - | None -> def_name ^ FieldId.to_string field_id - in - let variant_name (def_name : name) (variant : string) : string = - let variant = to_camel_case variant in - if variant_concatenate_type_name then - type_name_to_camel_case def_name ^ variant - else variant - in - let struct_constructor (basename : name) : string = - let tname = type_name basename in - "Mk" ^ tname - in - let get_fun_name = get_name in - let fun_name_to_snake_case (fname : fun_name) : string = - let fname = get_fun_name fname in - (* Converting to snake case should be a no-op, but it doesn't cost much *) - let fname = List.map to_snake_case fname in - (* Concatenate the elements *) - String.concat "_" fname - in - let global_name (name : global_name) : string = - (* Converting to snake case also lowercases the letters (in Rust, global - * names are written in capital letters). *) - let parts = List.map to_snake_case (get_name name) in - String.concat "_" parts - in - let fun_name (fname : fun_name) (num_rgs : int) - (rg : region_group_info option) (filter_info : bool * int) : string = - let fname = fun_name_to_snake_case fname in - (* Compute the suffix *) - let suffix = default_fun_suffix num_rgs rg filter_info in - (* Concatenate *) - fname ^ suffix - in - - let decreases_clause_name (_fid : A.FunDeclId.id) (fname : fun_name) : string - = - let fname = fun_name_to_snake_case fname in - (* Compute the suffix *) - let suffix = "_decreases" in - (* Concatenate *) - fname ^ suffix - in - - let var_basename (_varset : StringSet.t) (basename : string option) (ty : ty) - : string = - (* If there is a basename, we use it *) - match basename with - | Some basename -> - (* This should be a no-op *) - to_snake_case basename - | None -> ( - (* No basename: we use the first letter of the type *) - match ty with - | Adt (type_id, tys) -> ( - match type_id with - | Tuple -> - (* The "pair" case is frequent enough to have its special treatment *) - if List.length tys = 2 then "p" else "t" - | Assumed Result -> "r" - | Assumed Option -> "opt" - | Assumed Vec -> "v" - | Assumed State -> "st" - | AdtId adt_id -> - let def = - TypeDeclId.Map.find adt_id ctx.type_context.type_decls - in - (* We do the following: - * - compute the type name, and retrieve the last ident - * - convert this to snake case - * - take the first letter of every "letter group" - * Ex.: ["hashmap"; "HashMap"] ~~> "HashMap" -> "hash_map" -> "hm" - *) - (* Thename shouldn't be empty, and its last element should - * be an ident *) - let cl = List.nth def.name (List.length def.name - 1) in - let cl = to_snake_case (Names.as_ident cl) in - let cl = String.split_on_char '_' cl in - let cl = List.filter (fun s -> String.length s > 0) cl in - assert (List.length cl > 0); - let cl = List.map (fun s -> s.[0]) cl in - StringUtils.string_of_chars cl) - | TypeVar _ -> "x" (* lacking imagination here... *) - | Bool -> "b" - | Char -> "c" - | Integer _ -> "i" - | Str -> "s" - | Arrow _ -> "f" - | Array _ | Slice _ -> raise Unimplemented) - in - let type_var_basename (_varset : StringSet.t) (basename : string) : string = - (* This is *not* a no-op: type variables in Rust often start with - * a capital letter *) - to_snake_case basename - in - let append_index (basename : string) (i : int) : string = - basename ^ string_of_int i - in - - let extract_primitive_value (fmt : F.formatter) (_inside : bool) - (cv : primitive_value) : unit = - match cv with - | Scalar sv -> F.pp_print_string fmt (Z.to_string sv.PV.value) - | Bool b -> - let b = if b then "true" else "false" in - F.pp_print_string fmt b - | Char c -> F.pp_print_string fmt ("'" ^ String.make 1 c ^ "'") - | String s -> - (* We need to replace all the line breaks *) - let s = - StringUtils.map - (fun c -> if c = '\n' then "\n" else String.make 1 c) - s - in - F.pp_print_string fmt ("\"" ^ s ^ "\"") - in - { - bool_name = "bool"; - char_name = "char"; - int_name; - str_name = "string"; - field_name; - variant_name; - struct_constructor; - type_name; - global_name; - fun_name; - decreases_clause_name; - var_basename; - type_var_basename; - append_index; - extract_primitive_value; - extract_unop = fstar_extract_unop; - extract_binop = fstar_extract_binop; - } - -(** [inside] constrols whether we should add parentheses or not around type - applications (if [true] we add parentheses). - *) -let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) - (ty : ty) : unit = - match ty with - | Adt (type_id, tys) -> ( - match type_id with - | Tuple -> - (* This is a bit annoying, but in F* [()] is not the unit type: - * we have to write [unit]... *) - if tys = [] then F.pp_print_string fmt "unit" - else ( - F.pp_print_string fmt "("; - Collections.List.iter_link - (fun () -> - F.pp_print_space fmt (); - F.pp_print_string fmt "&"; - F.pp_print_space fmt ()) - (extract_ty ctx fmt true) tys; - F.pp_print_string fmt ")") - | AdtId _ | Assumed _ -> - let print_paren = inside && tys <> [] in - if print_paren then F.pp_print_string fmt "("; - F.pp_print_string fmt (ctx_get_type type_id ctx); - if tys <> [] then F.pp_print_space fmt (); - Collections.List.iter_link (F.pp_print_space fmt) - (extract_ty ctx fmt true) tys; - if print_paren then F.pp_print_string fmt ")") - | TypeVar vid -> F.pp_print_string fmt (ctx_get_type_var vid ctx) - | Bool -> F.pp_print_string fmt ctx.fmt.bool_name - | Char -> F.pp_print_string fmt ctx.fmt.char_name - | Integer int_ty -> F.pp_print_string fmt (ctx.fmt.int_name int_ty) - | Str -> F.pp_print_string fmt ctx.fmt.str_name - | Arrow (arg_ty, ret_ty) -> - if inside then F.pp_print_string fmt "("; - extract_ty ctx fmt false arg_ty; - F.pp_print_space fmt (); - F.pp_print_string fmt "->"; - F.pp_print_space fmt (); - extract_ty ctx fmt false ret_ty; - if inside then F.pp_print_string fmt ")" - | Array _ | Slice _ -> raise Unimplemented - -(** Compute the names for all the top-level identifiers used in a type - definition (type name, variant names, field names, etc. but not type - parameters). - - We need to do this preemptively, beforce extracting any definition, - because of recursive definitions. - *) -let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) : - extraction_ctx = - (* Compute and register the type def name *) - let ctx = ctx_add_type_decl def ctx in - (* Compute and register: - * - the variant names, if this is an enumeration - * - the field names, if this is a structure - *) - let ctx = - match def.kind with - | Struct fields -> - (* Add the fields *) - let ctx = - fst - (ctx_add_fields def (FieldId.mapi (fun id f -> (id, f)) fields) ctx) - in - (* Add the constructor name *) - fst (ctx_add_struct def ctx) - | Enum variants -> - fst - (ctx_add_variants def - (VariantId.mapi (fun id v -> (id, v)) variants) - ctx) - | Opaque -> - (* Nothing to do *) - ctx - in - (* Return *) - ctx - -let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) - (def : type_decl) (fields : field list) : unit = - (* We want to generate a definition which looks like this: - {[ - type t = { x : int; y : bool; } - ]} - - If there isn't enough space on one line: - {[ - type t = - { - x : int; y : bool; - } - ]} - - And if there is even less space: - {[ - type t = - { - x : int; - y : bool; - } - ]} - - Also, in case there are no fields, we need to define the type as [unit] - ([type t = {}] doesn't work in F* ). - *) - (* Note that we already printed: [type t =] *) - if fields = [] then ( - F.pp_print_space fmt (); - F.pp_print_string fmt "unit") - else ( - F.pp_print_space fmt (); - F.pp_print_string fmt "{"; - F.pp_print_break fmt 1 ctx.indent_incr; - (* The body itself *) - F.pp_open_hvbox fmt 0; - (* Print the fields *) - let print_field (field_id : FieldId.id) (f : field) : unit = - let field_name = ctx_get_field (AdtId def.def_id) field_id ctx in - F.pp_open_box fmt ctx.indent_incr; - F.pp_print_string fmt field_name; - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - extract_ty ctx fmt false f.field_ty; - F.pp_print_string fmt ";"; - F.pp_close_box fmt () - in - let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in - Collections.List.iter_link (F.pp_print_space fmt) - (fun (fid, f) -> print_field fid f) - fields; - (* Close *) - F.pp_close_box fmt (); - F.pp_print_space fmt (); - F.pp_print_string fmt "}") - -let extract_type_decl_enum_body (ctx : extraction_ctx) (fmt : F.formatter) - (def : type_decl) (def_name : string) (type_params : string list) - (variants : variant list) : unit = - (* We want to generate a definition which looks like this: - {[ - type list a = | Cons : a -> list a -> list a | Nil : list a - ]} - - If there isn't enough space on one line: - {[ - type s = - | Cons : a -> list a -> list a - | Nil : list a - ]} - - And if we need to write the type of a variant on several lines: - {[ - type s = - | Cons : - a -> - list a -> - list a - | Nil : list a - ]} - - Finally, it is possible to give names to the variant fields in Rust. - In this situation, we generate a definition like this: - {[ - type s = - | Cons : hd:a -> tl:list a -> list a - | Nil : list a - ]} - - Note that we already printed: [type s =] - *) - (* Print the variants *) - let print_variant (variant_id : VariantId.id) (variant : variant) : unit = - let variant_name = ctx_get_variant (AdtId def.def_id) variant_id ctx in - F.pp_print_space fmt (); - F.pp_open_hvbox fmt ctx.indent_incr; - (* variant box *) - (* [| Cons :] - * Note that we really don't want any break above so we print everything - * at once. *) - F.pp_print_string fmt ("| " ^ variant_name ^ " :"); - F.pp_print_space fmt (); - let print_field (fid : FieldId.id) (f : field) (ctx : extraction_ctx) : - extraction_ctx = - (* Open the field box *) - F.pp_open_box fmt ctx.indent_incr; - (* Print the field names - * [ x :] - * Note that when printing fields, we register the field names as - * *variables*: they don't need to be unique at the top level. *) - let ctx = - match f.field_name with - | None -> ctx - | Some field_name -> - let var_id = VarId.of_int (FieldId.to_int fid) in - let field_name = - ctx.fmt.var_basename ctx.names_map.names_set (Some field_name) - f.field_ty - in - let ctx, field_name = ctx_add_var field_name var_id ctx in - F.pp_print_string fmt (field_name ^ " :"); - F.pp_print_space fmt (); - ctx - in - (* Print the field type *) - extract_ty ctx fmt false f.field_ty; - (* Print the arrow [->]*) - F.pp_print_space fmt (); - F.pp_print_string fmt "->"; - (* Close the field box *) - F.pp_close_box fmt (); - F.pp_print_space fmt (); - (* Return *) - ctx - in - (* Print the fields *) - let fields = FieldId.mapi (fun fid f -> (fid, f)) variant.fields in - let _ = - List.fold_left (fun ctx (fid, f) -> print_field fid f ctx) ctx fields - in - (* Print the final type *) - F.pp_open_hovbox fmt 0; - F.pp_print_string fmt def_name; - List.iter - (fun type_param -> - F.pp_print_space fmt (); - F.pp_print_string fmt type_param) - type_params; - F.pp_close_box fmt (); - (* Close the variant box *) - F.pp_close_box fmt () - in - (* Print the variants *) - let variants = VariantId.mapi (fun vid v -> (vid, v)) variants in - List.iter (fun (vid, v) -> print_variant vid v) variants - -(** Extract a type declaration. - - Note that all the names used for extraction should already have been - registered. - *) -let extract_type_decl (ctx : extraction_ctx) (fmt : F.formatter) - (qualif : type_decl_qualif) (def : type_decl) : unit = - (* Retrieve the definition name *) - let def_name = ctx_get_local_type def.def_id ctx in - (* Add the type params - note that we need those bindings only for the - * body translation (they are not top-level) *) - let ctx_body, type_params = ctx_add_type_params def.type_params ctx in - (* Add a break before *) - F.pp_print_break fmt 0 0; - (* Print a comment to link the extracted type to its original rust definition *) - F.pp_print_string fmt ("(** [" ^ Print.name_to_string def.name ^ "] *)"); - F.pp_print_space fmt (); - (* Open a box for the definition, so that whenever possible it gets printed on - * one line *) - F.pp_open_hvbox fmt 0; - (* Open a box for "type TYPE_NAME (TYPE_PARAMS) =" *) - F.pp_open_hovbox fmt ctx.indent_incr; - (* > "type TYPE_NAME" *) - let extract_body, qualif = - match qualif with - | Type -> (true, "type") - | And -> (true, "and") - | AssumeType -> (false, "assume type") - | TypeVal -> (false, "val") - in - F.pp_print_string fmt (qualif ^ " " ^ def_name); - (* Print the type parameters *) - if def.type_params <> [] then ( - F.pp_print_space fmt (); - F.pp_print_string fmt "("; - List.iter - (fun (p : type_var) -> - let pname = ctx_get_type_var p.index ctx_body in - F.pp_print_string fmt pname; - F.pp_print_space fmt ()) - def.type_params; - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - F.pp_print_string fmt "Type0)"); - (* Print the "=" if we extract the body*) - if extract_body then ( - F.pp_print_space fmt (); - F.pp_print_string fmt "=") - else ( - (* Otherwise print ": Type0" *) - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - F.pp_print_string fmt "Type0"); - (* Close the box for "type TYPE_NAME (TYPE_PARAMS) =" *) - F.pp_close_box fmt (); - (if extract_body then - match def.kind with - | Struct fields -> extract_type_decl_struct_body ctx_body fmt def fields - | Enum variants -> - extract_type_decl_enum_body ctx_body fmt def def_name type_params - variants - | Opaque -> raise (Failure "Unreachable")); - (* Close the box for the definition *) - F.pp_close_box fmt (); - (* Add breaks to insert new lines between definitions *) - F.pp_print_break fmt 0 0 - -(** Extract the state type declaration. *) -let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx) - (qualif : type_decl_qualif) : unit = - (* Add a break before *) - F.pp_print_break fmt 0 0; - (* Print a comment *) - F.pp_print_string fmt "(** The state type used in the state-error monad *)"; - F.pp_print_space fmt (); - (* Open a box for the definition, so that whenever possible it gets printed on - * one line *) - F.pp_open_hvbox fmt 0; - (* Retrieve the name *) - let state_name = ctx_get_assumed_type State ctx in - (* The qualif should be [AssumeType] or [TypeVal] *) - (match qualif with - | Type | And -> raise (Failure "Unexpected") - | AssumeType -> - F.pp_print_string fmt "assume"; - F.pp_print_space fmt (); - F.pp_print_string fmt "type"; - F.pp_print_space fmt (); - F.pp_print_string fmt state_name; - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - F.pp_print_string fmt "Type0" - | TypeVal -> - F.pp_print_string fmt "val"; - F.pp_print_space fmt (); - F.pp_print_string fmt state_name; - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - F.pp_print_string fmt "Type0"); - (* Close the box for the definition *) - F.pp_close_box fmt (); - (* Add breaks to insert new lines between definitions *) - F.pp_print_break fmt 0 0 - -(** Compute the names for all the pure functions generated from a rust function - (forward function and backward functions). - *) -let extract_fun_decl_register_names (ctx : extraction_ctx) (keep_fwd : bool) - (has_decreases_clause : bool) (def : pure_fun_translation) : extraction_ctx - = - let fwd, back_ls = def in - (* Register the decrease clause, if necessary *) - let ctx = - if has_decreases_clause then ctx_add_decrases_clause fwd ctx else ctx - in - (* Register the forward function name *) - let ctx = ctx_add_fun_decl (keep_fwd, def) fwd ctx in - (* Register the backward functions' names *) - let ctx = - List.fold_left - (fun ctx back -> ctx_add_fun_decl (keep_fwd, def) back ctx) - ctx back_ls - in - (* Return *) - ctx - -(** Simply add the global name to the context. *) -let extract_global_decl_register_names (ctx : extraction_ctx) - (def : A.global_decl) : extraction_ctx = - ctx_add_global_decl_and_body def ctx - -(** The following function factorizes the extraction of ADT values. - - Note that patterns can introduce new variables: we thus return an extraction - context updated with new bindings. - - TODO: we don't need something very generic anymore - *) -let extract_adt_g_value - (extract_value : extraction_ctx -> bool -> 'v -> extraction_ctx) - (fmt : F.formatter) (ctx : extraction_ctx) (inside : bool) - (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) : - extraction_ctx = - match ty with - | Adt (Tuple, _) -> - (* Tuple *) - F.pp_print_string fmt "("; - let ctx = - Collections.List.fold_left_link - (fun () -> - F.pp_print_string fmt ","; - F.pp_print_space fmt ()) - (fun ctx v -> extract_value ctx false v) - ctx field_values - in - F.pp_print_string fmt ")"; - ctx - | Adt (adt_id, _) -> - (* "Regular" ADT *) - (* We print something of the form: [Cons field0 ... fieldn]. - * We could update the code to print something of the form: - * [{ field0=...; ...; fieldn=...; }] in case of structures. - *) - let cons = - match variant_id with - | Some vid -> ctx_get_variant adt_id vid ctx - | None -> ctx_get_struct adt_id ctx - in - if inside && field_values <> [] then F.pp_print_string fmt "("; - F.pp_print_string fmt cons; - let ctx = - Collections.List.fold_left - (fun ctx v -> - F.pp_print_space fmt (); - extract_value ctx true v) - ctx field_values - in - if inside && field_values <> [] then F.pp_print_string fmt ")"; - ctx - | _ -> raise (Failure "Inconsistent typed value") - -(* Extract globals in the same way as variables *) -let extract_global (ctx : extraction_ctx) (fmt : F.formatter) - (id : A.GlobalDeclId.id) : unit = - F.pp_print_string fmt (ctx_get_global id ctx) - -(** [inside]: see {!extract_ty}. - - As a pattern can introduce new variables, we return an extraction context - updated with new bindings. - *) -let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter) - (inside : bool) (v : typed_pattern) : extraction_ctx = - match v.value with - | PatConstant cv -> - ctx.fmt.extract_primitive_value fmt inside cv; - ctx - | PatVar (v, _) -> - let vname = - ctx.fmt.var_basename ctx.names_map.names_set v.basename v.ty - in - let ctx, vname = ctx_add_var vname v.id ctx in - F.pp_print_string fmt vname; - ctx - | PatDummy -> - F.pp_print_string fmt "_"; - ctx - | PatAdt av -> - let extract_value ctx inside v = extract_typed_pattern ctx fmt inside v in - extract_adt_g_value extract_value fmt ctx inside av.variant_id - av.field_values v.ty - -(** [inside]: controls the introduction of parentheses. See [extract_ty] - - TODO: replace the formatting boolean [inside] with something more general? - Also, it seems we don't really use it... - Cases to consider: - - right-expression in a let: [let x = re in _] (never parentheses?) - - next expression in a let: [let x = _ in next_e] (never parentheses?) - - application argument: [f (exp)] - - match/if scrutinee: [if exp then _ else _]/[match exp | _ -> _] - *) -let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) - (inside : bool) (e : texpression) : unit = - match e.e with - | Var var_id -> - let var_name = ctx_get_var var_id ctx in - F.pp_print_string fmt var_name - | Const cv -> ctx.fmt.extract_primitive_value fmt inside cv - | App _ -> - let app, args = destruct_apps e in - extract_App ctx fmt inside app args - | Abs _ -> - let xl, e = destruct_abs_list e in - extract_Abs ctx fmt inside xl e - | Qualif _ -> - (* We use the app case *) - extract_App ctx fmt inside e [] - | Let (monadic, lv, re, next_e) -> - extract_Let ctx fmt inside monadic lv re next_e - | Switch (scrut, body) -> extract_Switch ctx fmt inside scrut body - | Meta (_, e) -> extract_texpression ctx fmt inside e - -(* Extract an application *or* a top-level qualif (function extraction has - * to handle top-level qualifiers, so it seemed more natural to merge the - * two cases) *) -and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) - (app : texpression) (args : texpression list) : unit = - (* We don't do the same thing if the app is a top-level identifier (function, - * ADT constructor...) or a "regular" expression *) - match app.e with - | Qualif qualif -> ( - (* Top-level qualifier *) - match qualif.id with - | FunOrOp fun_id -> - extract_function_call ctx fmt inside fun_id qualif.type_args args - | Global global_id -> extract_global ctx fmt global_id - | AdtCons adt_cons_id -> - extract_adt_cons ctx fmt inside adt_cons_id qualif.type_args args - | Proj proj -> - extract_field_projector ctx fmt inside app proj qualif.type_args args) - | _ -> - (* "Regular" expression *) - (* Open parentheses *) - if inside then F.pp_print_string fmt "("; - (* Open a box for the application *) - F.pp_open_hovbox fmt ctx.indent_incr; - (* Print the app expression *) - let app_inside = (inside && args = []) || args <> [] in - extract_texpression ctx fmt app_inside app; - (* Print the arguments *) - List.iter - (fun ve -> - F.pp_print_space fmt (); - extract_texpression ctx fmt true ve) - args; - (* Close the box for the application *) - F.pp_close_box fmt (); - (* Close parentheses *) - if inside then F.pp_print_string fmt ")" - -(** Subcase of the app case: function call *) -and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) - (inside : bool) (fid : fun_or_op_id) (type_args : ty list) - (args : texpression list) : unit = - match (fid, args) with - | Unop unop, [ arg ] -> - (* A unop can have *at most* one argument (the result can't be a function!). - * Note that the way we generate the translation, we shouldn't get the - * case where we have no argument (all functions are fully instantiated, - * and no AST transformation introduces partial calls). *) - ctx.fmt.extract_unop (extract_texpression ctx fmt) fmt inside unop arg - | Binop (binop, int_ty), [ arg0; arg1 ] -> - (* Number of arguments: similar to unop *) - ctx.fmt.extract_binop - (extract_texpression ctx fmt) - fmt inside binop int_ty arg0 arg1 - | Fun fun_id, _ -> - if inside then F.pp_print_string fmt "("; - (* Open a box for the function call *) - F.pp_open_hovbox fmt ctx.indent_incr; - (* Print the function name *) - let fun_name = ctx_get_function fun_id ctx in - F.pp_print_string fmt fun_name; - (* Print the type parameters *) - List.iter - (fun ty -> - F.pp_print_space fmt (); - extract_ty ctx fmt true ty) - type_args; - (* Print the arguments *) - List.iter - (fun ve -> - F.pp_print_space fmt (); - extract_texpression ctx fmt true ve) - args; - (* Close the box for the function call *) - F.pp_close_box fmt (); - (* Return *) - if inside then F.pp_print_string fmt ")" - | (Unop _ | Binop _), _ -> - raise - (Failure - ("Unreachable:\n" ^ "Function: " ^ show_fun_or_op_id fid - ^ ",\nNumber of arguments: " - ^ string_of_int (List.length args) - ^ ",\nArguments: " - ^ String.concat " " (List.map show_texpression args))) - -(** Subcase of the app case: ADT constructor *) -and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) - (adt_cons : adt_cons_id) (type_args : ty list) (args : texpression list) : - unit = - match adt_cons.adt_id with - | Tuple -> - (* Tuple *) - (* For now, we only support fully applied tuple constructors *) - assert (List.length type_args = List.length args); - F.pp_print_string fmt "("; - Collections.List.iter_link - (fun () -> - F.pp_print_string fmt ","; - F.pp_print_space fmt ()) - (fun v -> extract_texpression ctx fmt false v) - args; - F.pp_print_string fmt ")" - | _ -> - (* "Regular" ADT *) - (* We print something of the form: [Cons field0 ... fieldn]. - * We could update the code to print something of the form: - * [{ field0=...; ...; fieldn=...; }] in case of fully - * applied structure constructors. - *) - let cons = - match adt_cons.variant_id with - | Some vid -> ctx_get_variant adt_cons.adt_id vid ctx - | None -> ctx_get_struct adt_cons.adt_id ctx - in - let use_parentheses = inside && args <> [] in - if use_parentheses then F.pp_print_string fmt "("; - F.pp_print_string fmt cons; - Collections.List.iter - (fun v -> - F.pp_print_space fmt (); - extract_texpression ctx fmt true v) - args; - if use_parentheses then F.pp_print_string fmt ")" - -(** Subcase of the app case: ADT field projector. *) -and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter) - (inside : bool) (original_app : texpression) (proj : projection) - (_proj_type_params : ty list) (args : texpression list) : unit = - (* We isolate the first argument (if there is), in order to pretty print the - * projection ([x.field] instead of [MkAdt?.field x] *) - match args with - | [ arg ] -> - (* Exactly one argument: pretty-print *) - let field_name = ctx_get_field proj.adt_id proj.field_id ctx in - (* Open a box *) - F.pp_open_hovbox fmt ctx.indent_incr; - (* Extract the expression *) - extract_texpression ctx fmt true arg; - (* We allow to break where the "." appears *) - F.pp_print_break fmt 0 0; - F.pp_print_string fmt "."; - F.pp_print_string fmt field_name; - (* Close the box *) - F.pp_close_box fmt () - | arg :: args -> - (* Call extract_App again, but in such a way that the first argument is - * isolated *) - extract_App ctx fmt inside (mk_app original_app arg) args - | [] -> - (* No argument: shouldn't happen *) - raise (Failure "Unreachable") - -and extract_Abs (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) - (xl : typed_pattern list) (e : texpression) : unit = - (* Open a box for the abs expression *) - F.pp_open_hovbox fmt ctx.indent_incr; - (* Open parentheses *) - if inside then F.pp_print_string fmt "("; - (* Print the lambda - note that there should always be at least one variable *) - assert (xl <> []); - F.pp_print_string fmt "fun"; - let ctx = - List.fold_left - (fun ctx x -> - F.pp_print_space fmt (); - extract_typed_pattern ctx fmt true x) - ctx xl - in - F.pp_print_space fmt (); - F.pp_print_string fmt "->"; - F.pp_print_space fmt (); - (* Print the body *) - extract_texpression ctx fmt false e; - (* Close parentheses *) - if inside then F.pp_print_string fmt ")"; - (* Close the box for the abs expression *) - F.pp_close_box fmt () - -and extract_Let (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) - (monadic : bool) (lv : typed_pattern) (re : texpression) - (next_e : texpression) : unit = - (* Open a box for the whole expression *) - F.pp_open_hvbox fmt 0; - (* Open parentheses *) - if inside then F.pp_print_string fmt "("; - (* Open a box for the let-binding *) - F.pp_open_hovbox fmt ctx.indent_incr; - let ctx = - if monadic then ( - (* Note that in F*, the left value of a monadic let-binding can only be - * a variable *) - let ctx = extract_typed_pattern ctx fmt true lv in - F.pp_print_space fmt (); - F.pp_print_string fmt "<--"; - F.pp_print_space fmt (); - extract_texpression ctx fmt false re; - F.pp_print_string fmt ";"; - ctx) - else ( - F.pp_print_string fmt "let"; - F.pp_print_space fmt (); - let ctx = extract_typed_pattern ctx fmt true lv in - F.pp_print_space fmt (); - F.pp_print_string fmt "="; - F.pp_print_space fmt (); - extract_texpression ctx fmt false re; - F.pp_print_space fmt (); - F.pp_print_string fmt "in"; - ctx) - in - (* Close the box for the let-binding *) - F.pp_close_box fmt (); - (* Print the next expression *) - F.pp_print_space fmt (); - extract_texpression ctx fmt false next_e; - (* Close parentheses *) - if inside then F.pp_print_string fmt ")"; - (* Close the box for the whole expression *) - F.pp_close_box fmt () - -and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) - (scrut : texpression) (body : switch_body) : unit = - (* Open a box for the whole expression *) - F.pp_open_hvbox fmt 0; - (* Open parentheses *) - if inside then F.pp_print_string fmt "("; - (* Extract the switch *) - (match body with - | If (e_then, e_else) -> - (* Open a box for the [if] *) - F.pp_open_hovbox fmt ctx.indent_incr; - F.pp_print_string fmt "if"; - F.pp_print_space fmt (); - let scrut_inside = PureUtils.let_group_requires_parentheses scrut in - extract_texpression ctx fmt scrut_inside scrut; - (* Close the box for the [if] *) - F.pp_close_box fmt (); - (* Extract the branches *) - let extract_branch (is_then : bool) (e_branch : texpression) : unit = - F.pp_print_space fmt (); - (* Open a box for the then/else+branch *) - F.pp_open_hovbox fmt ctx.indent_incr; - let then_or_else = if is_then then "then" else "else" in - F.pp_print_string fmt then_or_else; - F.pp_print_space fmt (); - (* Open a box for the branch *) - F.pp_open_hovbox fmt ctx.indent_incr; - (* Print the [begin] if necessary *) - let parenth = PureUtils.let_group_requires_parentheses e_branch in - if parenth then ( - F.pp_print_string fmt "begin"; - F.pp_print_space fmt ()); - (* Print the branch expression *) - extract_texpression ctx fmt false e_branch; - (* Close the [begin ... end ] *) - if parenth then ( - F.pp_print_space fmt (); - F.pp_print_string fmt "end"); - (* Close the box for the branch *) - F.pp_close_box fmt (); - (* Close the box for the then/else+branch *) - F.pp_close_box fmt () - in - - extract_branch true e_then; - extract_branch false e_else - | Match branches -> - (* Open a box for the [match ... with] *) - F.pp_open_hovbox fmt ctx.indent_incr; - (* Print the [match ... with] *) - F.pp_print_string fmt "begin match"; - F.pp_print_space fmt (); - let scrut_inside = PureUtils.let_group_requires_parentheses scrut in - extract_texpression ctx fmt scrut_inside scrut; - F.pp_print_space fmt (); - F.pp_print_string fmt "with"; - (* Close the box for the [match ... with] *) - F.pp_close_box fmt (); - - (* Extract the branches *) - let extract_branch (br : match_branch) : unit = - F.pp_print_space fmt (); - (* Open a box for the pattern+branch *) - F.pp_open_hovbox fmt ctx.indent_incr; - F.pp_print_string fmt "|"; - (* Print the pattern *) - F.pp_print_space fmt (); - let ctx = extract_typed_pattern ctx fmt false br.pat in - F.pp_print_space fmt (); - F.pp_print_string fmt "->"; - F.pp_print_space fmt (); - (* Open a box for the branch *) - F.pp_open_hovbox fmt ctx.indent_incr; - (* Print the branch itself *) - extract_texpression ctx fmt false br.branch; - (* Close the box for the branch *) - F.pp_close_box fmt (); - (* Close the box for the pattern+branch *) - F.pp_close_box fmt () - in - - List.iter extract_branch branches; - - (* End the match *) - F.pp_print_space fmt (); - F.pp_print_string fmt "end"); - (* Close parentheses *) - if inside then F.pp_print_string fmt ")"; - (* Close the box for the whole expression *) - F.pp_close_box fmt () - -(** A small utility to print the parameters of a function signature. - - We return two contexts: - - the context augmented with bindings for the type parameters - - the previous context augmented with bindings for the input values - *) -let extract_fun_parameters (ctx : extraction_ctx) (fmt : F.formatter) - (def : fun_decl) : extraction_ctx * extraction_ctx = - (* Add the type parameters - note that we need those bindings only for the - * body translation (they are not top-level) *) - let ctx, _ = ctx_add_type_params def.signature.type_params ctx in - (* Print the parameters - rk.: we should have filtered the functions - * with no input parameters *) - (* The type parameters *) - if def.signature.type_params <> [] then ( - (* Open a box for the type parameters *) - F.pp_open_hovbox fmt 0; - F.pp_print_string fmt "("; - List.iter - (fun (p : type_var) -> - let pname = ctx_get_type_var p.index ctx in - F.pp_print_string fmt pname; - F.pp_print_space fmt ()) - def.signature.type_params; - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - F.pp_print_string fmt "Type0)"; - (* Close the box for the type parameters *) - F.pp_close_box fmt (); - F.pp_print_space fmt ()); - (* The input parameters - note that doing this adds bindings to the context *) - let ctx_body = - match def.body with - | None -> ctx - | Some body -> - List.fold_left - (fun ctx (lv : typed_pattern) -> - (* Open a box for the input parameter *) - F.pp_open_hovbox fmt 0; - F.pp_print_string fmt "("; - let ctx = extract_typed_pattern ctx fmt false lv in - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - extract_ty ctx fmt false lv.ty; - F.pp_print_string fmt ")"; - (* Close the box for the input parameters *) - F.pp_close_box fmt (); - F.pp_print_space fmt (); - ctx) - ctx body.inputs_lvs - in - (ctx, ctx_body) - -(** A small utility to print the types of the input parameters in the form: - [u32 -> list u32 -> ...] - (we don't print the return type of the function) - - This is used for opaque function declarations, in particular. - *) -let extract_fun_input_parameters_types (ctx : extraction_ctx) - (fmt : F.formatter) (def : fun_decl) : unit = - let extract_param (ty : ty) : unit = - let inside = false in - extract_ty ctx fmt inside ty; - F.pp_print_space fmt (); - F.pp_print_string fmt "->"; - F.pp_print_space fmt () - in - List.iter extract_param def.signature.inputs - -(** Extract a decrease clause function template body. - - In order to help the user, we can generate a template for the functions - required by the decreases clauses. We simply generate definitions of - the following form in a separate file: - {[ - let f_decrease (t : Type0) (x : t) : nat = admit() - ]} - - Where the translated functions for [f] look like this: - {[ - let f_fwd (t : Type0) (x : t) : Tot ... (decreases (f_decrease t x)) = ... - ]} - *) -let extract_template_decreases_clause (ctx : extraction_ctx) (fmt : F.formatter) - (def : fun_decl) : unit = - (* Retrieve the function name *) - let def_name = ctx_get_decreases_clause def.def_id ctx in - (* Add a break before *) - F.pp_print_break fmt 0 0; - (* Print a comment to link the extracted type to its original rust definition *) - F.pp_print_string fmt - ("(** [" ^ Print.fun_name_to_string def.basename ^ "]: decreases clause *)"); - F.pp_print_space fmt (); - (* Open a box for the definition, so that whenever possible it gets printed on - * one line *) - F.pp_open_hvbox fmt 0; - (* Add the [unfold] keyword *) - F.pp_print_string fmt "unfold"; - F.pp_print_space fmt (); - (* Open a box for "let FUN_NAME (PARAMS) : EFFECT = admit()" *) - F.pp_open_hvbox fmt ctx.indent_incr; - (* Open a box for "let FUN_NAME (PARAMS) : EFFECT =" *) - F.pp_open_hovbox fmt ctx.indent_incr; - (* > "let FUN_NAME" *) - F.pp_print_string fmt ("let " ^ def_name); - F.pp_print_space fmt (); - (* Extract the parameters *) - let _, _ = extract_fun_parameters ctx fmt def in - F.pp_print_string fmt ":"; - (* Print the signature *) - F.pp_print_space fmt (); - F.pp_print_string fmt "nat"; - (* Print the "=" *) - F.pp_print_space fmt (); - F.pp_print_string fmt "="; - (* Close the box for "let FUN_NAME (PARAMS) : EFFECT =" *) - F.pp_close_box fmt (); - F.pp_print_space fmt (); - (* Print the "admit ()" *) - F.pp_print_string fmt "admit ()"; - (* Close the box for "let FUN_NAME (PARAMS) : EFFECT = admit()" *) - F.pp_close_box fmt (); - (* Close the box for the whole definition *) - F.pp_close_box fmt (); - (* Add breaks to insert new lines between definitions *) - F.pp_print_break fmt 0 0 - -(** Extract a function declaration. - - Note that all the names used for extraction should already have been - registered. - - We take the definition of the forward translation as parameter (which is - equal to the definition to extract, if we extract a forward function) because - it is useful for the decrease clause. - *) -let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) - (qualif : fun_decl_qualif) (has_decreases_clause : bool) (def : fun_decl) : - unit = - assert (not def.is_global_decl_body); - (* Retrieve the function name *) - let def_name = ctx_get_local_function def.def_id def.back_id ctx in - (* (* Add the type parameters - note that we need those bindings only for the - * body translation (they are not top-level) *) - let ctx, _ = ctx_add_type_params def.signature.type_params ctx in *) - (* Add a break before *) - F.pp_print_break fmt 0 0; - (* Print a comment to link the extracted type to its original rust definition *) - F.pp_print_string fmt - ("(** [" ^ Print.fun_name_to_string def.basename ^ "] *)"); - F.pp_print_space fmt (); - (* Open a box for the definition, so that whenever possible it gets printed on - * one line *) - F.pp_open_hvbox fmt ctx.indent_incr; - (* Open a box for "let FUN_NAME (PARAMS) : EFFECT =" *) - F.pp_open_hovbox fmt ctx.indent_incr; - (* > "let FUN_NAME" *) - let is_opaque = Option.is_none def.body in - let qualif = fun_decl_qualif_keyword qualif in - F.pp_print_string fmt (qualif ^ " " ^ def_name); - F.pp_print_space fmt (); - (* Open a box for "(PARAMS) : EFFECT =" *) - F.pp_open_hvbox fmt 0; - (* Open a box for "(PARAMS)" *) - F.pp_open_hovbox fmt 0; - let ctx, ctx_body = extract_fun_parameters ctx fmt def in - (* Close the box for "(PARAMS)" *) - F.pp_close_box fmt (); - (* Print the return type - note that we have to be careful when - * printing the input values for the decrease clause, because - * it introduces bindings in the context... We thus "forget" - * the bindings we introduced above. - * TODO: figure out a cleaner way *) - let _ = - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - (* Open a box for the EFFECT *) - F.pp_open_hvbox fmt 0; - (* Open a box for the return type *) - F.pp_open_hovbox fmt ctx.indent_incr; - (* Print the return type *) - (* For opaque definitions, as we don't have named parameters under the hand, - * we don't print parameters in the form [(x : a) (y : b) ...] above, - * but wait until here to print the types: [a -> b -> ...]. *) - if is_opaque then extract_fun_input_parameters_types ctx fmt def; - (* [Tot] *) - if has_decreases_clause then ( - F.pp_print_string fmt "Tot"; - F.pp_print_space fmt ()); - extract_ty ctx fmt has_decreases_clause def.signature.output; - (* Close the box for the return type *) - F.pp_close_box fmt (); - (* Print the decrease clause - rk.: a function with a decreases clause - * is necessarily a transparent function *) - if has_decreases_clause then ( - F.pp_print_space fmt (); - (* Open a box for the decrease clause *) - F.pp_open_hovbox fmt 0; - (* *) - F.pp_print_string fmt "(decreases"; - F.pp_print_space fmt (); - F.pp_print_string fmt "("; - (* The name of the decrease clause *) - let decr_name = ctx_get_decreases_clause def.def_id ctx in - F.pp_print_string fmt decr_name; - (* Print the type parameters *) - List.iter - (fun (p : type_var) -> - let pname = ctx_get_type_var p.index ctx in - F.pp_print_space fmt (); - F.pp_print_string fmt pname) - def.signature.type_params; - (* Print the input values: we have to be careful here to print - * only the input values which are in common with the *forward* - * function (the additional input values "given back" to the - * backward functions have no influence on termination: we thus - * share the decrease clauses between the forward and the backward - * functions - we also ignore the additional state received by the - * backward function, if there is one). - *) - let inputs_lvs = - let all_inputs = (Option.get def.body).inputs_lvs in - let num_fwd_inputs = def.signature.info.num_fwd_inputs_with_state in - Collections.List.prefix num_fwd_inputs all_inputs - in - let _ = - List.fold_left - (fun ctx (lv : typed_pattern) -> - F.pp_print_space fmt (); - let ctx = extract_typed_pattern ctx fmt false lv in - ctx) - ctx inputs_lvs - in - F.pp_print_string fmt "))"; - (* Close the box for the decrease clause *) - F.pp_close_box fmt ()); - (* Close the box for the EFFECT *) - F.pp_close_box fmt () - in - (* Print the "=" *) - if not is_opaque then ( - F.pp_print_space fmt (); - F.pp_print_string fmt "="); - (* Close the box for "(PARAMS) : EFFECT =" *) - F.pp_close_box fmt (); - (* Close the box for "let FUN_NAME (PARAMS) : EFFECT =" *) - F.pp_close_box fmt (); - if not is_opaque then ( - F.pp_print_space fmt (); - (* Open a box for the body *) - F.pp_open_hvbox fmt 0; - (* Extract the body *) - let _ = extract_texpression ctx_body fmt false (Option.get def.body).body in - (* Close the box for the body *) - F.pp_close_box fmt ()); - (* Close the box for the definition *) - F.pp_close_box fmt (); - (* Add breaks to insert new lines between definitions *) - F.pp_print_break fmt 0 0 - -(** Extract a global declaration body of the shape "QUALIF NAME : TYPE = BODY" with a custom body extractor *) -let extract_global_decl_body (ctx : extraction_ctx) (fmt : F.formatter) - (qualif : fun_decl_qualif) (name : string) (ty : ty) - (extract_body : (F.formatter -> unit) Option.t) : unit = - let is_opaque = Option.is_none extract_body in - - (* Open the definition box (depth=0) *) - F.pp_open_hvbox fmt ctx.indent_incr; - - (* Open "QUALIF NAME : TYPE =" box (depth=1) *) - F.pp_open_hovbox fmt ctx.indent_incr; - (* Print "QUALIF NAME " *) - F.pp_print_string fmt (fun_decl_qualif_keyword qualif ^ " " ^ name); - F.pp_print_space fmt (); - - (* Open ": TYPE =" box (depth=2) *) - F.pp_open_hvbox fmt 0; - (* Print ": " *) - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - - (* Open "TYPE" box (depth=3) *) - F.pp_open_hovbox fmt ctx.indent_incr; - (* Print "TYPE" *) - extract_ty ctx fmt false ty; - (* Close "TYPE" box (depth=3) *) - F.pp_close_box fmt (); - - if not is_opaque then ( - (* Print " =" *) - F.pp_print_space fmt (); - F.pp_print_string fmt "="); - (* Close ": TYPE =" box (depth=2) *) - F.pp_close_box fmt (); - (* Close "QUALIF NAME : TYPE =" box (depth=1) *) - F.pp_close_box fmt (); - - if not is_opaque then ( - F.pp_print_space fmt (); - (* Open "BODY" box (depth=1) *) - F.pp_open_hvbox fmt 0; - (* Print "BODY" *) - (Option.get extract_body) fmt; - (* Close "BODY" box (depth=1) *) - F.pp_close_box fmt ()); - (* Close the definition box (depth=0) *) - F.pp_close_box fmt () - -(** Extract a global declaration. - We generate the body which computes the global value separately from the value declaration itself. - - For example in Rust, - [static X: u32 = 3;] - - will be translated to: - [let x_body : result u32 = Return 3] - [let x_c : u32 = eval_global x_body] - *) -let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) - (global : A.global_decl) (body : fun_decl) (interface : bool) : unit = - assert body.is_global_decl_body; - assert (Option.is_none body.back_id); - assert (List.length body.signature.inputs = 0); - assert (List.length body.signature.doutputs = 1); - assert (List.length body.signature.type_params = 0); - - (* Add a break then the name of the corresponding LLBC declaration *) - F.pp_print_break fmt 0 0; - F.pp_print_string fmt - ("(** [" ^ Print.global_name_to_string global.name ^ "] *)"); - F.pp_print_space fmt (); - - let decl_name = ctx_get_global global.def_id ctx in - let body_name = - ctx_get_function (FromLlbc (Regular global.body_id, None)) ctx - in - - let decl_ty, body_ty = - let ty = body.signature.output in - if body.signature.info.effect_info.can_fail then (unwrap_result_ty ty, ty) - else (ty, mk_result_ty ty) - in - match body.body with - | None -> - let qualif = if interface then Val else AssumeVal in - extract_global_decl_body ctx fmt qualif decl_name decl_ty None - | Some body -> - extract_global_decl_body ctx fmt Let body_name body_ty - (Some (fun fmt -> extract_texpression ctx fmt false body.body)); - F.pp_print_break fmt 0 0; - extract_global_decl_body ctx fmt Let decl_name decl_ty - (Some (fun fmt -> F.pp_print_string fmt ("eval_global " ^ body_name))); - F.pp_print_break fmt 0 0 - -(** Extract a unit test, if the function is a unit function (takes no - parameters, returns unit). - - A unit test simply checks that the function normalizes to [Return ()]: - {[ - let _ = assert_norm (FUNCTION () = Return ()) - ]} - *) -let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) - (def : fun_decl) : unit = - (* We only insert unit tests for forward functions *) - assert (def.back_id = None); - (* Check if this is a unit function *) - let sg = def.signature in - if - sg.type_params = [] - && (sg.inputs = [ mk_unit_ty ] || sg.inputs = []) - && sg.output = mk_result_ty mk_unit_ty - then ( - (* Add a break before *) - F.pp_print_break fmt 0 0; - (* Print a comment *) - F.pp_print_string fmt - ("(** Unit test for [" ^ Print.fun_name_to_string def.basename ^ "] *)"); - F.pp_print_space fmt (); - (* Open a box for the test *) - F.pp_open_hovbox fmt ctx.indent_incr; - (* Print the test *) - F.pp_print_string fmt "let _ ="; - F.pp_print_space fmt (); - F.pp_print_string fmt "assert_norm"; - F.pp_print_space fmt (); - F.pp_print_string fmt "("; - let fun_name = ctx_get_local_function def.def_id def.back_id ctx in - F.pp_print_string fmt fun_name; - if sg.inputs <> [] then ( - F.pp_print_space fmt (); - F.pp_print_string fmt "()"); - F.pp_print_space fmt (); - F.pp_print_string fmt "="; - F.pp_print_space fmt (); - let success = ctx_get_variant (Assumed Result) result_return_id ctx in - F.pp_print_string fmt (success ^ " ())"); - (* Close the box for the test *) - F.pp_close_box fmt (); - (* Add a break after *) - F.pp_print_break fmt 0 0) - else (* Do nothing *) - () diff --git a/compiler/ExtractToCoq.ml b/compiler/ExtractToCoq.ml new file mode 100644 index 00000000..3681adc3 --- /dev/null +++ b/compiler/ExtractToCoq.ml @@ -0,0 +1,8 @@ +(** Utilities for the extraction to Coq *) + +open Utils +open Pure +open TranslateCore +open ExtractBase +open StringUtils +module F = Format diff --git a/compiler/ExtractToFStar.ml b/compiler/ExtractToFStar.ml new file mode 100644 index 00000000..21a6fc8f --- /dev/null +++ b/compiler/ExtractToFStar.ml @@ -0,0 +1,8 @@ +(** Utilities for the extraction to F* *) + +open Utils +open Pure +open TranslateCore +open ExtractBase +open StringUtils +module F = Format diff --git a/compiler/Logging.ml b/compiler/Logging.ml index 5f22506b..1d57fa5b 100644 --- a/compiler/Logging.ml +++ b/compiler/Logging.ml @@ -18,8 +18,8 @@ let symbolic_to_pure_log = L.get_logger "MainLogger.SymbolicToPure" (** Logger for PureMicroPasses *) let pure_micro_passes_log = L.get_logger "MainLogger.PureMicroPasses" -(** Logger for PureToExtract *) -let pure_to_extract_log = L.get_logger "MainLogger.PureToExtract" +(** Logger for ExtractBase *) +let pure_to_extract_log = L.get_logger "MainLogger.ExtractBase" (** Logger for Interpreter *) let interpreter_log = L.get_logger "MainLogger.Interpreter" diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 30fc4989..7b261516 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -1250,6 +1250,7 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option = (lazy ("filter_useless:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); (* Simplify the aggregated ADTs. + Ex.: {[ type struct = { f0 : nat; f1 : nat } diff --git a/compiler/PureToExtract.ml b/compiler/PureToExtract.ml deleted file mode 100644 index 25ad6713..00000000 --- a/compiler/PureToExtract.ml +++ /dev/null @@ -1,734 +0,0 @@ -(** This module is used to extract the pure ASTs to various theorem provers. - It defines utilities and helpers to make the work as easy as possible: - we try to factorize as much as possible the different extractions to the - backends we target. - *) - -open Pure -open TranslateCore -module C = Contexts -module RegionVarId = T.RegionVarId -module F = Format - -(** The local logger *) -let log = L.pure_to_extract_log - -type region_group_info = { - id : RegionGroupId.id; - (** The id of the region group. - Note that a simple way of generating unique names for backward - functions is to use the region group ids. - *) - region_names : string option list; - (** The names of the region variables included in this group. - Note that names are not always available... - *) -} - -module StringSet = Collections.MakeSet (Collections.OrderedString) -module StringMap = Collections.MakeMap (Collections.OrderedString) - -type name = Names.name -type type_name = Names.type_name -type global_name = Names.global_name -type fun_name = Names.fun_name - -(* TODO: this should a module we give to a functor! *) - -(** A formatter's role is twofold: - 1. Come up with name suggestions. - For instance, provided some information about a function (its basename, - information about the region group, etc.) it should come up with an - appropriate name for the forward/backward function. - - It can of course apply many transformations, like changing to camel case/ - snake case, adding prefixes/suffixes, etc. - - 2. Format some specific terms, like constants. - *) -type formatter = { - bool_name : string; - char_name : string; - int_name : integer_type -> string; - str_name : string; - field_name : name -> FieldId.id -> string option -> string; - (** Inputs: - - type name - - field id - - field name - - Note that fields don't always have names, but we still need to - generate some names if we want to extract the structures to records... - We might want to extract such structures to tuples, later, but field - access then causes trouble because not all provers accept syntax like - [x.3] where [x] is a tuple. - *) - variant_name : name -> string -> string; - (** Inputs: - - type name - - variant name - *) - struct_constructor : name -> string; - (** Structure constructors are used when constructing structure values. - - For instance, in F*: - {[ - type pair = { x : nat; y : nat } - let p : pair = Mkpair 0 1 - ]} - - Inputs: - - type name - *) - type_name : type_name -> string; - (** Provided a basename, compute a type name. *) - global_name : global_name -> string; - (** Provided a basename, compute a global name. *) - fun_name : - fun_name -> int -> region_group_info option -> bool * int -> string; - (** Compute the name of a regular (non-assumed) function. - - Inputs: - - function id - - function basename (TODO: shouldn't appear for assumed functions?...) - - number of region groups - - region group information in case of a backward function - ([None] if forward function) - - pair: - - do we generate the forward function (it may have been filtered)? - - the number of extracted backward functions (not necessarily equal - to the number of region groups, because we may have filtered - some of them) - TODO: use the fun id for the assumed functions. - *) - decreases_clause_name : A.FunDeclId.id -> fun_name -> string; - (** Generates the name of the definition used to prove/reason about - termination. The generated code uses this clause where needed, - but its body must be defined by the user. - - Inputs: - - function id: this is especially useful to identify whether the - function is an assumed function or a local function - - function basename - *) - var_basename : StringSet.t -> string option -> ty -> string; - (** Generates a variable basename. - - Inputs: - - the set of names used in the context so far - - the basename we got from the symbolic execution, if we have one - - the type of the variable (can be useful for heuristics, in order - not to always use "x" for instance, whenever naming anonymous - variables) - - Note that once the formatter generated a basename, we add an index - if necessary to prevent name clashes: the burden of name clashes checks - is thus on the caller's side. - *) - type_var_basename : StringSet.t -> string -> string; - (** Generates a type variable basename. *) - append_index : string -> int -> string; - (** Appends an index to a name - we use this to generate unique - names: when doing so, the role of the formatter is just to concatenate - indices to names, the responsability of finding a proper index is - delegated to helper functions. - *) - extract_primitive_value : F.formatter -> bool -> primitive_value -> unit; - (** Format a constant value. - - Inputs: - - formatter - - [inside]: if [true], the value should be wrapped in parentheses - if it is made of an application (ex.: [U32 3]) - - the constant value - *) - extract_unop : - (bool -> texpression -> unit) -> - F.formatter -> - bool -> - unop -> - texpression -> - unit; - (** Format a unary operation - - Inputs: - - a formatter for expressions (called on the argument of the unop) - - extraction context (see below) - - formatter - - expression formatter - - [inside] - - unop - - argument - *) - extract_binop : - (bool -> texpression -> unit) -> - F.formatter -> - bool -> - E.binop -> - integer_type -> - texpression -> - texpression -> - unit; - (** Format a binary operation - - Inputs: - - a formatter for expressions (called on the arguments of the binop) - - extraction context (see below) - - formatter - - expression formatter - - [inside] - - binop - - argument 0 - - argument 1 - *) -} - -(** We use identifiers to look for name clashes *) -type id = - | GlobalId of A.GlobalDeclId.id - | FunId of fun_id - | DecreasesClauseId of A.fun_id - (** The definition which provides the decreases/termination clause. - We insert calls to this clause to prove/reason about termination: - the body of those clauses must be defined by the user, in the - proper files. - *) - | TypeId of type_id - | StructId of type_id - (** We use this when we manipulate the names of the structure - constructors. - - For instance, in F*: - {[ - type pair = { x: nat; y : nat } - let p : pair = Mkpair 0 1 - ]} - *) - | VariantId of type_id * VariantId.id - (** If often happens that variant names must be unique (it is the case in - F* ) which is why we register them here. - *) - | FieldId of type_id * FieldId.id - (** If often happens that in the case of structures, the field names - must be unique (it is the case in F* ) which is why we register - them here. - *) - | TypeVarId of TypeVarId.id - | VarId of VarId.id - | UnknownId - (** Used for stored various strings like keywords, definitions which - should always be in context, etc. and which can't be linked to one - of the above. - *) -[@@deriving show, ord] - -module IdOrderedType = struct - type t = id - - let compare = compare_id - let to_string = show_id - let pp_t = pp_id - let show_t = show_id -end - -module IdMap = Collections.MakeMap (IdOrderedType) - -(** The names map stores the mappings from names to identifiers and vice-versa. - - We use it for lookups (during the translation) and to check for name clashes. - - [id_to_string] is for debugging. - *) -type names_map = { - id_to_name : string IdMap.t; - name_to_id : id StringMap.t; - (** The name to id map is used to look for name clashes, and generate nice - debugging messages: if there is a name clash, it is useful to know - precisely which identifiers are mapped to the same name... - *) - names_set : StringSet.t; -} - -let names_map_add (id_to_string : id -> string) (id : id) (name : string) - (nm : names_map) : names_map = - (* Check if there is a clash *) - (match StringMap.find_opt name nm.name_to_id with - | None -> () (* Ok *) - | Some clash -> - (* There is a clash: print a nice debugging message for the user *) - let id1 = "\n- " ^ id_to_string clash in - let id2 = "\n- " ^ id_to_string id in - let err = - "Name clash detected: the following identifiers are bound to the same \ - name \"" ^ name ^ "\":" ^ id1 ^ id2 - in - log#serror err; - raise (Failure err)); - (* Sanity check *) - assert (not (StringSet.mem name nm.names_set)); - (* Insert *) - let id_to_name = IdMap.add id name nm.id_to_name in - let name_to_id = StringMap.add name id nm.name_to_id in - let names_set = StringSet.add name nm.names_set in - { id_to_name; name_to_id; names_set } - -let names_map_add_assumed_type (id_to_string : id -> string) (id : assumed_ty) - (name : string) (nm : names_map) : names_map = - names_map_add id_to_string (TypeId (Assumed id)) name nm - -let names_map_add_assumed_struct (id_to_string : id -> string) (id : assumed_ty) - (name : string) (nm : names_map) : names_map = - names_map_add id_to_string (StructId (Assumed id)) name nm - -let names_map_add_assumed_variant (id_to_string : id -> string) - (id : assumed_ty) (variant_id : VariantId.id) (name : string) - (nm : names_map) : names_map = - names_map_add id_to_string (VariantId (Assumed id, variant_id)) name nm - -let names_map_add_function (id_to_string : id -> string) (fid : fun_id) - (name : string) (nm : names_map) : names_map = - names_map_add id_to_string (FunId fid) name nm - -(** Make a (variable) basename unique (by adding an index). - - We do this in an inefficient manner (by testing all indices starting from - 0) but it shouldn't be a bottleneck. - - Also note that at some point, we thought about trying to reuse names of - variables which are not used anymore, like here: - {[ - let x = ... in - ... - let x0 = ... in // We could use the name "x" if [x] is not used below - ... - ]} - - However it is a good idea to keep things as they are for F*: as F* is - designed for extrinsic proofs, a proof about a function follows this - function's structure. The consequence is that we often end up - copy-pasting function bodies. As in the proofs (in assertions and - when calling lemmas) we often need to talk about the "past" (i.e., - previous values), it is very useful to generate code where all variable - names are assigned at most once. - - [append]: function to append an index to a string - *) -let basename_to_unique (names_set : StringSet.t) - (append : string -> int -> string) (basename : string) : string = - let rec gen (i : int) : string = - let s = append basename i in - if StringSet.mem s names_set then gen (i + 1) else s - in - if StringSet.mem basename names_set then gen 0 else basename - -(** Extraction context. - - Note that the extraction context contains information coming from the - LLBC AST (not only the pure AST). This is useful for naming, for instance: - we use the region information to generate the names of the backward - functions, etc. - *) -type extraction_ctx = { - trans_ctx : trans_ctx; - names_map : names_map; - fmt : formatter; - indent_incr : int; - (** The indent increment we insert whenever we need to indent more *) -} - -(** Debugging function *) -let id_to_string (id : id) (ctx : extraction_ctx) : string = - let global_decls = ctx.trans_ctx.global_context.global_decls in - let fun_decls = ctx.trans_ctx.fun_context.fun_decls in - let type_decls = ctx.trans_ctx.type_context.type_decls in - (* TODO: factorize the pretty-printing with what is in PrintPure *) - let get_type_name (id : type_id) : string = - match id with - | AdtId id -> - let def = TypeDeclId.Map.find id type_decls in - Print.name_to_string def.name - | Assumed aty -> show_assumed_ty aty - | Tuple -> raise (Failure "Unreachable") - in - match id with - | GlobalId gid -> - let name = (A.GlobalDeclId.Map.find gid global_decls).name in - "global name: " ^ Print.global_name_to_string name - | FunId fid -> ( - match fid with - | FromLlbc (fid, rg_id) -> - let fun_name = - match fid with - | Regular fid -> - Print.fun_name_to_string - (A.FunDeclId.Map.find fid fun_decls).name - | Assumed aid -> A.show_assumed_fun_id aid - in - let fun_kind = - match rg_id with - | None -> "forward" - | Some rg_id -> "backward " ^ RegionGroupId.to_string rg_id - in - "fun name (" ^ fun_kind ^ "): " ^ fun_name - | Pure fid -> PrintPure.pure_assumed_fun_id_to_string fid) - | DecreasesClauseId fid -> - let fun_name = - match fid with - | Regular fid -> - Print.fun_name_to_string (A.FunDeclId.Map.find fid fun_decls).name - | Assumed aid -> A.show_assumed_fun_id aid - in - "decreases clause for function: " ^ fun_name - | TypeId id -> "type name: " ^ get_type_name id - | StructId id -> "struct constructor of: " ^ get_type_name id - | VariantId (id, variant_id) -> - let variant_name = - match id with - | Tuple -> raise (Failure "Unreachable") - | Assumed State -> raise (Failure "Unreachable") - | Assumed Result -> - if variant_id = result_return_id then "@result::Return" - else if variant_id = result_fail_id then "@result::Fail" - else raise (Failure "Unreachable") - | Assumed Option -> - if variant_id = option_some_id then "@option::Some" - else if variant_id = option_none_id then "@option::None" - else raise (Failure "Unreachable") - | Assumed Vec -> raise (Failure "Unreachable") - | AdtId id -> ( - let def = TypeDeclId.Map.find id type_decls in - match def.kind with - | Struct _ | Opaque -> raise (Failure "Unreachable") - | Enum variants -> - let variant = VariantId.nth variants variant_id in - Print.name_to_string def.name ^ "::" ^ variant.variant_name) - in - "variant name: " ^ variant_name - | FieldId (id, field_id) -> - let field_name = - match id with - | Tuple -> raise (Failure "Unreachable") - | Assumed (State | Result | Option) -> raise (Failure "Unreachable") - | Assumed Vec -> - (* We can't directly have access to the fields of a vector *) - raise (Failure "Unreachable") - | AdtId id -> ( - let def = TypeDeclId.Map.find id type_decls in - match def.kind with - | Enum _ | Opaque -> raise (Failure "Unreachable") - | Struct fields -> - let field = FieldId.nth fields field_id in - let field_name = - match field.field_name with - | None -> FieldId.to_string field_id - | Some name -> name - in - Print.name_to_string def.name ^ "." ^ field_name) - in - "field name: " ^ field_name - | UnknownId -> "keyword" - | TypeVarId _ | VarId _ -> - (* We should never get there: we add indices to make sure variable - * names are unique *) - raise (Failure "Unreachable") - -let ctx_add (id : id) (name : string) (ctx : extraction_ctx) : extraction_ctx = - (* The id_to_string function to print nice debugging messages if there are - * collisions *) - let id_to_string (id : id) : string = id_to_string id ctx in - let names_map = names_map_add id_to_string id name ctx.names_map in - { ctx with names_map } - -let ctx_get (id : id) (ctx : extraction_ctx) : string = - match IdMap.find_opt id ctx.names_map.id_to_name with - | Some s -> s - | None -> - log#serror ("Could not find: " ^ id_to_string id ctx); - raise Not_found - -let ctx_get_global (id : A.GlobalDeclId.id) (ctx : extraction_ctx) : string = - ctx_get (GlobalId id) ctx - -let ctx_get_function (id : fun_id) (ctx : extraction_ctx) : string = - ctx_get (FunId id) ctx - -let ctx_get_local_function (id : A.FunDeclId.id) (rg : RegionGroupId.id option) - (ctx : extraction_ctx) : string = - ctx_get_function (FromLlbc (Regular id, rg)) ctx - -let ctx_get_type (id : type_id) (ctx : extraction_ctx) : string = - assert (id <> Tuple); - ctx_get (TypeId id) ctx - -let ctx_get_local_type (id : TypeDeclId.id) (ctx : extraction_ctx) : string = - ctx_get_type (AdtId id) ctx - -let ctx_get_assumed_type (id : assumed_ty) (ctx : extraction_ctx) : string = - ctx_get_type (Assumed id) ctx - -let ctx_get_var (id : VarId.id) (ctx : extraction_ctx) : string = - ctx_get (VarId id) ctx - -let ctx_get_type_var (id : TypeVarId.id) (ctx : extraction_ctx) : string = - ctx_get (TypeVarId id) ctx - -let ctx_get_field (type_id : type_id) (field_id : FieldId.id) - (ctx : extraction_ctx) : string = - ctx_get (FieldId (type_id, field_id)) ctx - -let ctx_get_struct (def_id : type_id) (ctx : extraction_ctx) : string = - ctx_get (StructId def_id) ctx - -let ctx_get_variant (def_id : type_id) (variant_id : VariantId.id) - (ctx : extraction_ctx) : string = - ctx_get (VariantId (def_id, variant_id)) ctx - -let ctx_get_decreases_clause (def_id : A.FunDeclId.id) (ctx : extraction_ctx) : - string = - ctx_get (DecreasesClauseId (Regular def_id)) ctx - -(** Generate a unique type variable name and add it to the context *) -let ctx_add_type_var (basename : string) (id : TypeVarId.id) - (ctx : extraction_ctx) : extraction_ctx * string = - let name = ctx.fmt.type_var_basename ctx.names_map.names_set basename in - let name = - basename_to_unique ctx.names_map.names_set ctx.fmt.append_index name - in - let ctx = ctx_add (TypeVarId id) name ctx in - (ctx, name) - -(** See {!ctx_add_type_var} *) -let ctx_add_type_vars (vars : (string * TypeVarId.id) list) - (ctx : extraction_ctx) : extraction_ctx * string list = - List.fold_left_map - (fun ctx (name, id) -> ctx_add_type_var name id ctx) - ctx vars - -(** Generate a unique variable name and add it to the context *) -let ctx_add_var (basename : string) (id : VarId.id) (ctx : extraction_ctx) : - extraction_ctx * string = - let name = - basename_to_unique ctx.names_map.names_set ctx.fmt.append_index basename - in - let ctx = ctx_add (VarId id) name ctx in - (ctx, name) - -(** See {!ctx_add_var} *) -let ctx_add_vars (vars : var list) (ctx : extraction_ctx) : - extraction_ctx * string list = - List.fold_left_map - (fun ctx (v : var) -> - let name = ctx.fmt.var_basename ctx.names_map.names_set v.basename v.ty in - ctx_add_var name v.id ctx) - ctx vars - -let ctx_add_type_params (vars : type_var list) (ctx : extraction_ctx) : - extraction_ctx * string list = - List.fold_left_map - (fun ctx (var : type_var) -> ctx_add_type_var var.name var.index ctx) - ctx vars - -let ctx_add_type_decl_struct (def : type_decl) (ctx : extraction_ctx) : - extraction_ctx * string = - let cons_name = ctx.fmt.struct_constructor def.name in - let ctx = ctx_add (StructId (AdtId def.def_id)) cons_name ctx in - (ctx, cons_name) - -let ctx_add_type_decl (def : type_decl) (ctx : extraction_ctx) : extraction_ctx - = - let def_name = ctx.fmt.type_name def.name in - let ctx = ctx_add (TypeId (AdtId def.def_id)) def_name ctx in - ctx - -let ctx_add_field (def : type_decl) (field_id : FieldId.id) (field : field) - (ctx : extraction_ctx) : extraction_ctx * string = - let name = ctx.fmt.field_name def.name field_id field.field_name in - let ctx = ctx_add (FieldId (AdtId def.def_id, field_id)) name ctx in - (ctx, name) - -let ctx_add_fields (def : type_decl) (fields : (FieldId.id * field) list) - (ctx : extraction_ctx) : extraction_ctx * string list = - List.fold_left_map - (fun ctx (vid, v) -> ctx_add_field def vid v ctx) - ctx fields - -let ctx_add_variant (def : type_decl) (variant_id : VariantId.id) - (variant : variant) (ctx : extraction_ctx) : extraction_ctx * string = - let name = ctx.fmt.variant_name def.name variant.variant_name in - let ctx = ctx_add (VariantId (AdtId def.def_id, variant_id)) name ctx in - (ctx, name) - -let ctx_add_variants (def : type_decl) - (variants : (VariantId.id * variant) list) (ctx : extraction_ctx) : - extraction_ctx * string list = - List.fold_left_map - (fun ctx (vid, v) -> ctx_add_variant def vid v ctx) - ctx variants - -let ctx_add_struct (def : type_decl) (ctx : extraction_ctx) : - extraction_ctx * string = - let name = ctx.fmt.struct_constructor def.name in - let ctx = ctx_add (StructId (AdtId def.def_id)) name ctx in - (ctx, name) - -let ctx_add_decrases_clause (def : fun_decl) (ctx : extraction_ctx) : - extraction_ctx = - let name = ctx.fmt.decreases_clause_name def.def_id def.basename in - ctx_add (DecreasesClauseId (Regular def.def_id)) name ctx - -let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) : - extraction_ctx = - let name = ctx.fmt.global_name def.name in - let decl = GlobalId def.def_id in - let body = FunId (FromLlbc (Regular def.body_id, None)) in - let ctx = ctx_add decl (name ^ "_c") ctx in - let ctx = ctx_add body (name ^ "_body") ctx in - ctx - -let ctx_add_fun_decl (trans_group : bool * pure_fun_translation) - (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx = - (* Sanity check: the function should not be a global body - those are handled - * separately *) - assert (not def.is_global_decl_body); - (* Lookup the LLBC def to compute the region group information *) - let def_id = def.def_id in - let llbc_def = - A.FunDeclId.Map.find def_id ctx.trans_ctx.fun_context.fun_decls - in - let sg = llbc_def.signature in - let num_rgs = List.length sg.regions_hierarchy in - let keep_fwd, (_, backs) = trans_group in - let num_backs = List.length backs in - let rg_info = - match def.back_id with - | None -> None - | Some rg_id -> - let rg = T.RegionGroupId.nth sg.regions_hierarchy rg_id in - let regions = - List.map - (fun rid -> T.RegionVarId.nth sg.region_params rid) - rg.regions - in - let region_names = - List.map (fun (r : T.region_var) -> r.name) regions - in - Some { id = rg_id; region_names } - in - let name = - ctx.fmt.fun_name def.basename num_rgs rg_info (keep_fwd, num_backs) - in - ctx_add (FunId (FromLlbc (A.Regular def_id, def.back_id))) name ctx - -type names_map_init = { - keywords : string list; - assumed_adts : (assumed_ty * string) list; - assumed_structs : (assumed_ty * string) list; - assumed_variants : (assumed_ty * VariantId.id * string) list; - assumed_llbc_functions : - (A.assumed_fun_id * RegionGroupId.id option * string) list; - assumed_pure_functions : (pure_assumed_fun_id * string) list; -} - -(** Initialize a names map with a proper set of keywords/names coming from the - target language/prover. *) -let initialize_names_map (fmt : formatter) (init : names_map_init) : names_map = - let int_names = List.map fmt.int_name T.all_int_types in - let keywords = - List.concat - [ - [ fmt.bool_name; fmt.char_name; fmt.str_name ]; int_names; init.keywords; - ] - in - let names_set = StringSet.of_list keywords in - let name_to_id = - StringMap.of_list (List.map (fun x -> (x, UnknownId)) keywords) - in - (* We fist initialize [id_to_name] as empty, because the id of a keyword is [UnknownId]. - * Also note that we don't need this mapping for keywords: we insert keywords only - * to check collisions. *) - let id_to_name = IdMap.empty in - let nm = { id_to_name; name_to_id; names_set } in - (* For debugging - we are creating bindings for assumed types and functions, so - * it is ok if we simply use the "show" function (those aren't simply identified - * by numbers) *) - let id_to_string = show_id in - (* Then we add: - * - the assumed types - * - the assumed struct constructors - * - the assumed variants - * - the assumed functions - *) - let nm = - List.fold_left - (fun nm (type_id, name) -> - names_map_add_assumed_type id_to_string type_id name nm) - nm init.assumed_adts - in - let nm = - List.fold_left - (fun nm (type_id, name) -> - names_map_add_assumed_struct id_to_string type_id name nm) - nm init.assumed_structs - in - let nm = - List.fold_left - (fun nm (type_id, variant_id, name) -> - names_map_add_assumed_variant id_to_string type_id variant_id name nm) - nm init.assumed_variants - in - let assumed_functions = - List.map - (fun (fid, rg, name) -> (FromLlbc (A.Assumed fid, rg), name)) - init.assumed_llbc_functions - @ List.map (fun (fid, name) -> (Pure fid, name)) init.assumed_pure_functions - in - let nm = - List.fold_left - (fun nm (fid, name) -> names_map_add_function id_to_string fid name nm) - nm assumed_functions - in - (* Return *) - nm - -let compute_type_decl_name (fmt : formatter) (def : type_decl) : string = - fmt.type_name def.name - -(** A helper function: generates a function suffix from a region group - information. - TODO: move all those helpers. -*) -let default_fun_suffix (num_region_groups : int) (rg : region_group_info option) - ((keep_fwd, num_backs) : bool * int) : string = - (* There are several cases: - - [rg] is [Some]: this is a forward function: - - we add "_fwd" - - [rg] is [None]: this is a backward function: - - this function has one extracted backward function: - - if the forward function has been filtered, we add "_fwd_back": - the forward function is useless, so the unique backward function - takes its place, in a way - - otherwise we add "_back" - - this function has several backward functions: we add "_back" and an - additional suffix to identify the precise backward function - Note that we always add a suffix (in case there are no region groups, - we could not add the "_fwd" suffix) to prevent name clashes between - definitions (in particular between type and function definitions). - *) - match rg with - | None -> "_fwd" - | Some rg -> - assert (num_region_groups > 0 && num_backs > 0); - if num_backs = 1 then - (* Exactly one backward function *) - if not keep_fwd then "_fwd_back" else "_back" - else if - (* Several region groups/backward functions: - - if all the regions in the group have names, we use those names - - otherwise we use an index - *) - List.for_all Option.is_some rg.region_names - then - (* Concatenate the region names *) - "_back" ^ String.concat "" (List.map Option.get rg.region_names) - else (* Use the region index *) - "_back" ^ RegionGroupId.to_string rg.id diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index ff55f322..5a024d9e 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -398,9 +398,9 @@ let mk_simpl_tuple_texpression (vl : texpression list) : texpression = let cons = { e = Qualif qualif; ty } in mk_apps cons vl -let mk_adt_pattern (adt_ty : ty) (variant_id : VariantId.id) +let mk_adt_pattern (adt_ty : ty) (variant_id : VariantId.id option) (vl : typed_pattern list) : typed_pattern = - let value = PatAdt { variant_id = Some variant_id; field_values = vl } in + let value = PatAdt { variant_id; field_values = vl } in { value; ty = adt_ty } let ty_as_integer (t : ty) : T.integer_type = diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml index 235e33e4..9d95db7f 100644 --- a/compiler/SymbolicAst.ml +++ b/compiler/SymbolicAst.ml @@ -26,6 +26,7 @@ type mplace = { projection : E.projection; (** We store the projection because we can, but it is actually not that useful *) } +[@@deriving show] type call_id = | Fun of A.fun_id * V.FunCallId.id @@ -43,6 +44,7 @@ type call = { dest : V.symbolic_value; dest_place : mplace option; (** Meta information *) } +[@@deriving show] (** Meta information, not necessary for synthesis but useful to guide it to generate a pretty output. @@ -51,6 +53,7 @@ type call = { type meta = | Assignment of mplace * V.typed_value * mplace option (** We generated an assignment (destination, assigned value, src) *) +[@@deriving show] (** **Rk.:** here, {!expression} is not at all equivalent to the expressions used in LLBC: they are a first step towards lambda-calculus expressions. @@ -108,3 +111,4 @@ and expansion = T.integer_type * (V.scalar_value * expression) list * expression (** An integer expansion (i.e, a switch over an integer). The last expression is for the "otherwise" branch. *) +[@@deriving show] diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index a327c785..62be5efd 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -1578,8 +1578,17 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) (* We don't do the same thing if there is a branching or not *) match branches with | [] -> raise (Failure "Unreachable") - | [ (variant_id, svl, branch) ] -> ( - (* There is exactly one branch: no branching *) + | [ (variant_id, svl, branch) ] + when not + (TypesUtils.ty_is_custom_adt sv.V.sv_ty + && !Config.always_deconstruct_adts_with_matches) -> ( + (* There is exactly one branch: no branching. + + We can decompose the ADT value with a let-binding, unless + the backend doesn't support this (see {!Config.always_deconstruct_adts_with_matches}): + we *ignore* this branch (and go to the next one) if the ADT is a custom + adt, and [always_deconstruct_adts_with_matches] is true. + *) let type_id, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in let ctx, vars = fresh_vars_for_symbolic_values svl ctx in let branch = translate_expression branch ctx in @@ -1588,9 +1597,16 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) (* Detect if this is an enumeration or not *) let tdef = bs_ctx_lookup_llbc_type_decl adt_id ctx in let is_enum = type_decl_is_enum tdef in - if is_enum then - (* This is an enumeration: introduce an [ExpandEnum] let-binding *) - let variant_id = Option.get variant_id in + (* We deconstruct the ADT with a let-binding in two situations: + - if the ADT is an enumeration (which must have exactly one branch) + - if we forbid using field projectors. + + We forbid using field projectors in some situations, for example + if the backend is Coq. See '!Config.dont_use_field_projectors}. + *) + let use_let = is_enum || !Config.dont_use_field_projectors in + if use_let then + (* Introduce a let binding which expands the ADT *) let lvars = List.map (fun v -> mk_typed_pattern_from_var v None) vars in @@ -1665,8 +1681,6 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) let translate_branch (variant_id : T.VariantId.id option) (svl : V.symbolic_value list) (branch : S.expression) : match_branch = - (* There *must* be a variant id - otherwise there can't be several branches *) - let variant_id = Option.get variant_id in let ctx, vars = fresh_vars_for_symbolic_values svl ctx in let vars = List.map (fun x -> mk_typed_pattern_from_var x None) vars diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 79d1c913..b2a28710 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -293,7 +293,7 @@ let translate_module_to_pure (crate : A.crate) : (** Extraction context *) type gen_ctx = { crate : A.crate; - extract_ctx : PureToExtract.extraction_ctx; + extract_ctx : ExtractBase.extraction_ctx; trans_types : Pure.type_decl Pure.TypeDeclId.Map.t; trans_funs : (bool * pure_fun_translation) A.FunDeclId.Map.t; functions_with_decreases_clause : A.FunDeclId.Set.t; @@ -342,30 +342,41 @@ let module_has_opaque_decls (ctx : gen_ctx) : bool * bool = *) let extract_definitions (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) : unit = - (* Export the definition groups to the file, in the proper order *) - let export_type (qualif : ExtractToBackend.type_decl_qualif) - (id : Pure.TypeDeclId.id) : unit = - (* Retrive the declaration *) + (* Export the definition groups to the file, in the proper order. + - [extract_decl]: extract the type declaration (if not filtered) + - [extract_extra_info]: extra the extra type information (e.g., + the [Arguments] information in Coq). + *) + let export_type (kind : ExtractBase.decl_kind) (id : Pure.TypeDeclId.id) + (extract_decl : bool) (extract_extra_info : bool) : unit = + (* Retrieve the declaration *) let def = Pure.TypeDeclId.Map.find id ctx.trans_types in - (* Update the qualifier, if the type is opaque *) - let is_opaque, qualif = + (* Update the kind, if the type is opaque *) + let is_opaque, kind = match def.kind with - | Enum _ | Struct _ -> (false, qualif) + | Enum _ | Struct _ -> (false, kind) | Opaque -> - let qualif = - if config.interface then ExtractToBackend.TypeVal - else ExtractToBackend.AssumeType + let kind = + if config.interface then ExtractBase.Declared + else ExtractBase.Assumed in - (true, qualif) + (true, kind) in (* Extract, if the config instructs to do so (depending on whether the type * is opaque or not) *) if (is_opaque && config.extract_opaque) || ((not is_opaque) && config.extract_transparent) - then ExtractToBackend.extract_type_decl ctx.extract_ctx fmt qualif def + then ( + if extract_decl then + Extract.extract_type_decl ctx.extract_ctx fmt kind def; + if extract_extra_info then + Extract.extract_type_decl_extra_info ctx.extract_ctx fmt kind def) in + let export_type_decl kind id = export_type kind id true false in + let export_type_extra_info kind id = export_type kind id false true in + (* Utility to check a function has a decrease clause *) let has_decreases_clause (def : Pure.fun_decl) : bool = A.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause @@ -393,8 +404,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) (fun (_, (fwd, _)) -> let has_decr_clause = has_decreases_clause fwd in if has_decr_clause then - ExtractToBackend.extract_template_decreases_clause ctx.extract_ctx fmt - fwd) + Extract.extract_template_decreases_clause ctx.extract_ctx fmt fwd) pure_ls; (* Extract the function definitions *) (if config.extract_fun_decls then @@ -408,17 +418,25 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) else true else false in + let fls_length = List.length fls in List.iteri (fun i (fwd_def, def) -> let is_opaque = Option.is_none fwd_def.Pure.body in - let qualif = + let kind = if is_opaque then - if config.interface then ExtractToBackend.Val - else ExtractToBackend.AssumeVal - else if not is_rec then ExtractToBackend.Let + if config.interface then ExtractBase.Declared + else ExtractBase.Assumed + else if not is_rec then ExtractBase.SingleNonRec else if is_mut_rec then - if i = 0 then ExtractToBackend.LetRec else ExtractToBackend.And - else ExtractToBackend.LetRec + (* If the functions are mutually recursive, we need to distinguish: + * - the first of the group + * - the last of the group + * - the inner functions + *) + if i = 0 then ExtractBase.MutRecFirst + else if i = fls_length - 1 then ExtractBase.MutRecLast + else ExtractBase.MutRecInner + else ExtractBase.SingleRec in let has_decr_clause = has_decreases_clause def && config.extract_decreases_clauses @@ -428,15 +446,14 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) ((not is_opaque) && config.extract_transparent) || (is_opaque && config.extract_opaque) then - ExtractToBackend.extract_fun_decl ctx.extract_ctx fmt qualif - has_decr_clause def) + Extract.extract_fun_decl ctx.extract_ctx fmt kind has_decr_clause def) fls); (* Insert unit tests if necessary *) if config.test_trans_unit_functions then List.iter (fun (keep_fwd, (fwd, _)) -> if keep_fwd then - ExtractToBackend.extract_unit_test_if_unit_fun ctx.extract_ctx fmt fwd) + Extract.extract_unit_test_if_unit_fun ctx.extract_ctx fmt fwd) pure_ls in @@ -454,32 +471,49 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) ((not is_opaque) && config.extract_transparent) || (is_opaque && config.extract_opaque) then - ExtractToBackend.extract_global_decl ctx.extract_ctx fmt global body + Extract.extract_global_decl ctx.extract_ctx fmt global body config.interface in let export_state_type () : unit = - let qualif = - if config.interface then ExtractToBackend.TypeVal - else ExtractToBackend.AssumeType + let kind = + if config.interface then ExtractBase.Declared else ExtractBase.Assumed in - ExtractToBackend.extract_state_type fmt ctx.extract_ctx qualif + Extract.extract_state_type fmt ctx.extract_ctx kind in let export_decl (decl : A.declaration_group) : unit = match decl with | Type (NonRec id) -> - if config.extract_types then export_type ExtractToBackend.Type id + if config.extract_types then ( + let kind = ExtractBase.SingleNonRec in + (* Export the type declaration *) + export_type_decl kind id; + (* Export the extra information (ex.: [Arguments] instructions in Coq) *) + export_type_extra_info kind id) | Type (Rec ids) -> (* Rk.: we shouldn't have (mutually) recursive opaque types *) - if config.extract_types then + let num_decls = List.length ids in + let is_mut_rec = num_decls > 1 in + if config.extract_types then ( + let kind_from_index i = + if not is_mut_rec then ExtractBase.SingleRec + else if i = 0 then ExtractBase.MutRecFirst + else if i = num_decls - 1 then ExtractBase.MutRecLast + else ExtractBase.MutRecInner + in + (* Extract the type declarations *) + List.iteri + (fun i id -> + let kind = kind_from_index i in + export_type_decl kind id) + ids; + (* Export the extra information (ex.: [Arguments] instructions in Coq) *) List.iteri (fun i id -> - let qualif = - if i = 0 then ExtractToBackend.Type else ExtractToBackend.And - in - export_type qualif id) - ids + let kind = kind_from_index i in + export_type_extra_info kind id) + ids) | Fun (NonRec id) -> (* Lookup *) let pure_fun = A.FunDeclId.Map.find id ctx.trans_funs in @@ -527,15 +561,33 @@ let extract_file (config : gen_config) (ctx : gen_ctx) (filename : string) (* Create the header *) Printf.fprintf out "(** THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS *)\n"; Printf.fprintf out "(** [%s]%s *)\n" rust_module_name custom_msg; - Printf.fprintf out "module %s\n" module_name; - Printf.fprintf out "open Primitives\n"; - (* Add the custom imports *) - List.iter (fun m -> Printf.fprintf out "open %s\n" m) custom_imports; - (* Add the custom includes *) - List.iter (fun m -> Printf.fprintf out "include %s\n" m) custom_includes; - (* Z3 options - note that we use fuel 1 because it its useful for the decrease clauses *) - Printf.fprintf out "\n#set-options \"--z3rlimit 50 --fuel 1 --ifuel 1\"\n"; - + (match !Config.backend with + | FStar -> + Printf.fprintf out "module %s\n" module_name; + Printf.fprintf out "open Primitives\n"; + (* Add the custom imports *) + List.iter (fun m -> Printf.fprintf out "open %s\n" m) custom_imports; + (* Add the custom includes *) + List.iter (fun m -> Printf.fprintf out "include %s\n" m) custom_includes; + (* Z3 options - note that we use fuel 1 because it its useful for the decrease clauses *) + Printf.fprintf out "\n#set-options \"--z3rlimit 50 --fuel 1 --ifuel 1\"\n" + | Coq -> + Printf.fprintf out "Require Import Primitives.\n"; + Printf.fprintf out "Import Primitives.\n"; + Printf.fprintf out "Require Import Coq.ZArith.ZArith.\n"; + Printf.fprintf out "Local Open Scope Primitives_scope.\n"; + + (* Add the custom imports *) + List.iter + (fun m -> Printf.fprintf out "Require Import %s .\n" m) + custom_imports; + (* Add the custom includes *) + List.iter + (fun m -> + Printf.fprintf out "Require Export %s .\n" m; + Printf.fprintf out "Import %s .\n" m) + custom_includes; + Printf.fprintf out "Module %s .\n" module_name); (* From now onwards, we use the formatter *) (* Set the margin *) Format.pp_set_margin fmt 80; @@ -550,6 +602,11 @@ let extract_file (config : gen_config) (ctx : gen_ctx) (filename : string) Format.pp_close_box fmt (); Format.pp_print_newline fmt (); + (* Close the module *) + (match !Config.backend with + | FStar -> () + | Coq -> Printf.fprintf out "End %s .\n" module_name); + (* Some logging *) log#linfo (lazy ("Generated: " ^ filename)); @@ -568,17 +625,12 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : * We initialize the names map by registering the keywords used in the * language, as well as some primitive names ("u32", etc.) *) let variant_concatenate_type_name = true in - let fstar_fmt = - ExtractToBackend.mk_formatter trans_ctx crate.name + let mk_formatter_and_names_map = Extract.mk_formatter_and_names_map in + let fmt, names_map = + mk_formatter_and_names_map trans_ctx crate.name variant_concatenate_type_name in - let names_map = - PureToExtract.initialize_names_map fstar_fmt - ExtractToBackend.fstar_names_map_init - in - let ctx = - { PureToExtract.trans_ctx; names_map; fmt = fstar_fmt; indent_incr = 2 } - in + let ctx = { ExtractBase.trans_ctx; names_map; fmt; indent_incr = 2 } in (* We need to compute which functions are recursive, in order to know * whether we should generate a decrease clause or not. *) @@ -596,7 +648,7 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : * sure there are no name clashes. *) let ctx = List.fold_left - (fun ctx def -> ExtractToBackend.extract_type_decl_register_names ctx def) + (fun ctx def -> Extract.extract_type_decl_register_names ctx def) ctx trans_types in @@ -612,14 +664,13 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : let is_global = (fst def).Pure.is_global_decl_body in if is_global then ctx else - ExtractToBackend.extract_fun_decl_register_names ctx keep_fwd - gen_decr_clause def) + Extract.extract_fun_decl_register_names ctx keep_fwd gen_decr_clause + def) ctx trans_funs in let ctx = - List.fold_left ExtractToBackend.extract_global_decl_register_names ctx - crate.globals + List.fold_left Extract.extract_global_decl_register_names ctx crate.globals in (* Open the output file *) @@ -658,12 +709,17 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : (* Create a directory with *default* permissions *) Core_unix.mkdir_p dest_dir); - (* Copy "Primitives.fst" *) + (* Copy the "Primitives" file *) let _ = (* Retrieve the executable's directory *) let exe_dir = Filename.dirname Sys.argv.(0) in - let src = open_in (exe_dir ^ "/backends/fstar/Primitives.fst") in - let tgt_filename = Filename.concat dest_dir "Primitives.fst" in + let primitives_src, primitives_destname = + match !Config.backend with + | Config.FStar -> ("/backends/fstar/Primitives.fst", "Primitives.fst") + | Config.Coq -> ("/backends/coq/Primitives.v", "Primitives.v") + in + let src = open_in (exe_dir ^ primitives_src) in + let tgt_filename = Filename.concat dest_dir primitives_destname in let tgt = open_out tgt_filename in (* Very annoying: I couldn't find a "cp" function in the OCaml libraries... *) try @@ -689,6 +745,10 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : } in + let module_delimiter = + match !Config.backend with FStar -> "." | Coq -> "__" + in + (* Extract one or several files, depending on the configuration *) if !Config.split_files then ( let base_gen_config = @@ -712,9 +772,16 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : (* Extract the types *) (* If there are opaque types, we extract in an interface *) - let types_filename_ext = if has_opaque_types then ".fsti" else ".fst" in - let types_filename = extract_filebasename ^ ".Types" ^ types_filename_ext in - let types_module = module_name ^ ".Types" in + let types_filename_ext = + match !Config.backend with + | FStar -> if has_opaque_types then ".fsti" else ".fst" + | Coq -> if has_opaque_types then ".v" else ".v" + in + let types_file_suffix = module_delimiter ^ "Types" in + let types_filename = + extract_filebasename ^ types_file_suffix ^ types_filename_ext + in + let types_module = module_name ^ types_file_suffix in let types_config = { base_gen_config with @@ -733,8 +800,12 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : && not (A.FunDeclId.Set.is_empty rec_functions) in (if needs_clauses_module && !Config.extract_template_decreases_clauses then - let clauses_filename = extract_filebasename ^ ".Clauses.Template.fst" in - let clauses_module = module_name ^ ".Clauses.Template" in + let ext = match !Config.backend with FStar -> ".fst" | Coq -> ".v" in + let clauses_file_suffix = + module_delimiter ^ "Clauses" ^ module_delimiter ^ "Template" + in + let clauses_filename = extract_filebasename ^ clauses_file_suffix ^ ext in + let clauses_module = module_name ^ clauses_file_suffix in let clauses_config = { base_gen_config with extract_template_decreases_clauses = true } in @@ -745,8 +816,10 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : (* Extract the opaque functions, if needed *) let opaque_funs_module = if has_opaque_funs then ( - let opaque_filename = extract_filebasename ^ ".Opaque.fsti" in - let opaque_module = module_name ^ ".Opaque" in + let ext = match !Config.backend with FStar -> ".fsti" | Coq -> ".v" in + let opaque_file_suffix = module_delimiter ^ "Opaque" in + let opaque_filename = extract_filebasename ^ opaque_file_suffix ^ ext in + let opaque_module = module_name ^ opaque_file_suffix in let opaque_config = { base_gen_config with @@ -763,8 +836,10 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : in (* Extract the functions *) - let fun_filename = extract_filebasename ^ ".Funs.fst" in - let fun_module = module_name ^ ".Funs" in + let ext = match !Config.backend with FStar -> ".fst" | Coq -> ".v" in + let fun_file_suffix = module_delimiter ^ "Funs" in + let fun_filename = extract_filebasename ^ fun_file_suffix ^ ext in + let fun_module = module_name ^ fun_file_suffix in let fun_config = { base_gen_config with @@ -773,7 +848,9 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : } in let clauses_module = - if needs_clauses_module then [ module_name ^ ".Clauses" ] else [] + if needs_clauses_module then + [ module_name ^ module_delimiter ^ "Clauses" ] + else [] in extract_file fun_config gen_ctx fun_filename crate.A.name fun_module ": function definitions" [] @@ -794,6 +871,7 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : } in (* Add the extension for F* *) - let extract_filename = extract_filebasename ^ ".fst" in + let ext = match !Config.backend with FStar -> ".fst" | Coq -> ".v" in + let extract_filename = extract_filebasename ^ ext in extract_file gen_config gen_ctx extract_filename crate.A.name module_name "" [] [] diff --git a/compiler/Values.ml b/compiler/Values.ml index 44b22f9f..046f0482 100644 --- a/compiler/Values.ml +++ b/compiler/Values.ml @@ -817,3 +817,4 @@ type symbolic_expansion = | SeAdt of (VariantId.id option * symbolic_value list) | SeMutRef of BorrowId.id * symbolic_value | SeSharedRef of BorrowId.Set.t * symbolic_value +[@@deriving show] diff --git a/compiler/dune b/compiler/dune index 10aa9b10..b530340b 100644 --- a/compiler/dune +++ b/compiler/dune @@ -20,7 +20,10 @@ Cps Expressions ExpressionsUtils - ExtractToBackend + Extract + ExtractBase + ExtractToCoq + ExtractToFStar FunsAnalysis Identifiers InterpreterBorrowsCore @@ -44,7 +47,6 @@ PrintPure PureMicroPasses Pure - PureToExtract PureTypeCheck PureUtils Scalars diff --git a/tests/Makefile b/tests/Makefile new file mode 100644 index 00000000..dfb20cc4 --- /dev/null +++ b/tests/Makefile @@ -0,0 +1,3 @@ +all: + cd fstar && $(MAKE) all + cd coq && $(MAKE) all diff --git a/tests/coq/Makefile b/tests/coq/Makefile new file mode 100644 index 00000000..5fead9c9 --- /dev/null +++ b/tests/coq/Makefile @@ -0,0 +1,3 @@ +# TODO: make this more general +all: + cd misc && $(MAKE) all diff --git a/tests/coq/misc/Constants.v b/tests/coq/misc/Constants.v new file mode 100644 index 00000000..677aae8c --- /dev/null +++ b/tests/coq/misc/Constants.v @@ -0,0 +1,138 @@ +(** THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS *) +(** [constants] *) +Require Import Primitives. +Import Primitives. +Require Import Coq.ZArith.ZArith. +Local Open Scope Primitives_scope. +Module Constants . + +(** [constants::X0] *) +Definition x0_body : result u32 := Return (0 %u32) . +Definition x0_c : u32 := x0_body%global . + +(** [core::num::u32::{9}::MAX] *) +Definition core_num_u32_max_body : result u32 := Return (4294967295 %u32) . +Definition core_num_u32_max_c : u32 := core_num_u32_max_body%global . + +(** [constants::X1] *) +Definition x1_body : result u32 := Return core_num_u32_max_c . +Definition x1_c : u32 := x1_body%global . + +(** [constants::X2] *) +Definition x2_body : result u32 := Return (3 %u32) . +Definition x2_c : u32 := x2_body%global . + +(** [constants::incr] *) +Definition incr_fwd (n : u32) : result u32 := i <- u32_add n 1 %u32; Return i . + +(** [constants::X3] *) +Definition x3_body : result u32 := i <- incr_fwd (32 %u32); Return i . +Definition x3_c : u32 := x3_body%global . + +(** [constants::mk_pair0] *) +Definition mk_pair0_fwd (x : u32) (y : u32) : result (u32 * u32) := + Return (x, y) . + +(** [constants::Pair] *) +Record Pair_t (T1 T2 : Type) := mkPair_t { Pair_x : T1; Pair_y : T2; } . + +Arguments mkPair_t {T1} {T2} _ _ . +Arguments Pair_x {T1} {T2} . +Arguments Pair_y {T1} {T2} . + +(** [constants::mk_pair1] *) +Definition mk_pair1_fwd (x : u32) (y : u32) : result (Pair_t u32 u32) := + Return (mkPair_t x y) . + +(** [constants::P0] *) +Definition p0_body : result (u32 * u32) := + p <- mk_pair0_fwd (0 %u32) (1 %u32); Return p + . +Definition p0_c : (u32 * u32) := p0_body%global . + +(** [constants::P1] *) +Definition p1_body : result (Pair_t u32 u32) := + p <- mk_pair1_fwd (0 %u32) (1 %u32); Return p + . +Definition p1_c : Pair_t u32 u32 := p1_body%global . + +(** [constants::P2] *) +Definition p2_body : result (u32 * u32) := Return (0 %u32, 1 %u32) . +Definition p2_c : (u32 * u32) := p2_body%global . + +(** [constants::P3] *) +Definition p3_body : result (Pair_t u32 u32) := + Return (mkPair_t (0 %u32) (1 %u32)) + . +Definition p3_c : Pair_t u32 u32 := p3_body%global . + +(** [constants::Wrap] *) +Record Wrap_t (T : Type) := mkWrap_t { Wrap_val : T; } . + +Arguments mkWrap_t {T} _ . +Arguments Wrap_val {T} . + +(** [constants::Wrap::{0}::new] *) +Definition wrap_new_fwd (T : Type) (val : T) : result (Wrap_t T) := + Return (mkWrap_t val) . + +(** [constants::Y] *) +Definition y_body : result (Wrap_t i32) := + w <- wrap_new_fwd i32 (2 %i32); Return w + . +Definition y_c : Wrap_t i32 := y_body%global . + +(** [constants::unwrap_y] *) +Definition unwrap_y_fwd : result i32 := + match y_c with | mkWrap_t i => Return i end . + +(** [constants::YVAL] *) +Definition yval_body : result i32 := i <- unwrap_y_fwd; Return i . +Definition yval_c : i32 := yval_body%global . + +(** [constants::get_z1::Z1] *) +Definition get_z1_z1_body : result i32 := Return (3 %i32) . +Definition get_z1_z1_c : i32 := get_z1_z1_body%global . + +(** [constants::get_z1] *) +Definition get_z1_fwd : result i32 := Return get_z1_z1_c . + +(** [constants::add] *) +Definition add_fwd (a : i32) (b : i32) : result i32 := + i <- i32_add a b; Return i . + +(** [constants::Q1] *) +Definition q1_body : result i32 := Return (5 %i32) . +Definition q1_c : i32 := q1_body%global . + +(** [constants::Q2] *) +Definition q2_body : result i32 := Return q1_c . +Definition q2_c : i32 := q2_body%global . + +(** [constants::Q3] *) +Definition q3_body : result i32 := i <- add_fwd q2_c (3 %i32); Return i . +Definition q3_c : i32 := q3_body%global . + +(** [constants::get_z2] *) +Definition get_z2_fwd : result i32 := + i <- get_z1_fwd; i0 <- add_fwd i q3_c; i1 <- add_fwd q1_c i0; Return i1 . + +(** [constants::S1] *) +Definition s1_body : result u32 := Return (6 %u32) . +Definition s1_c : u32 := s1_body%global . + +(** [constants::S2] *) +Definition s2_body : result u32 := i <- incr_fwd s1_c; Return i . +Definition s2_c : u32 := s2_body%global . + +(** [constants::S3] *) +Definition s3_body : result (Pair_t u32 u32) := Return p3_c . +Definition s3_c : Pair_t u32 u32 := s3_body%global . + +(** [constants::S4] *) +Definition s4_body : result (Pair_t u32 u32) := + p <- mk_pair1_fwd (7 %u32) (8 %u32); Return p + . +Definition s4_c : Pair_t u32 u32 := s4_body%global . + +End Constants . diff --git a/tests/coq/misc/External__Funs.v b/tests/coq/misc/External__Funs.v new file mode 100644 index 00000000..77b738b0 --- /dev/null +++ b/tests/coq/misc/External__Funs.v @@ -0,0 +1,100 @@ +(** THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS *) +(** [external]: function definitions *) +Require Import Primitives. +Import Primitives. +Require Import Coq.ZArith.ZArith. +Local Open Scope Primitives_scope. +Require Export External__Types . +Import External__Types . +Require Export External__Opaque . +Import External__Opaque . +Module External__Funs . + +(** [external::swap] *) +Definition swap_fwd + (T : Type) (x : T) (y : T) (st : state) : result (state * unit) := + p <- core_mem_swap_fwd T x y st; + let (st0, _) := p in + p0 <- core_mem_swap_back0 T x y st st0; + let (st1, _) := p0 in + p1 <- core_mem_swap_back1 T x y st st1; + let (st2, _) := p1 in Return (st2, tt) + . + +(** [external::swap] *) +Definition swap_back + (T : Type) (x : T) (y : T) (st : state) (st0 : state) : + result (state * (T * T)) + := + p <- core_mem_swap_fwd T x y st; + let (st1, _) := p in + p0 <- core_mem_swap_back0 T x y st st1; + let (st2, x0) := p0 in + p1 <- core_mem_swap_back1 T x y st st2; + let (_, y0) := p1 in Return (st0, (x0, y0)) + . + +(** [external::test_new_non_zero_u32] *) +Definition test_new_non_zero_u32_fwd + (x : u32) (st : state) : result (state * Core_num_nonzero_non_zero_u32_t) := + p <- core_num_nonzero_non_zero_u32_new_fwd x st; + let (st0, opt) := p in + p0 <- core_option_option_unwrap_fwd Core_num_nonzero_non_zero_u32_t opt st0; + let (st1, nzu) := p0 in Return (st1, nzu) + . + +(** [external::test_vec] *) +Definition test_vec_fwd : result unit := + let v := vec_new u32 in + v0 <- vec_push_back u32 v (0 %u32); let _ := v0 in Return tt + . + +(** [external::custom_swap] *) +Definition custom_swap_fwd + (T : Type) (x : T) (y : T) (st : state) : result (state * T) := + p <- core_mem_swap_fwd T x y st; + let (st0, _) := p in + p0 <- core_mem_swap_back0 T x y st st0; + let (st1, x0) := p0 in + p1 <- core_mem_swap_back1 T x y st st1; + let (st2, _) := p1 in Return (st2, x0) + . + +(** [external::custom_swap] *) +Definition custom_swap_back + (T : Type) (x : T) (y : T) (st : state) (ret : T) (st0 : state) : + result (state * (T * T)) + := + p <- core_mem_swap_fwd T x y st; + let (st1, _) := p in + p0 <- core_mem_swap_back0 T x y st st1; + let (st2, _) := p0 in + p1 <- core_mem_swap_back1 T x y st st2; + let (_, y0) := p1 in Return (st0, (ret, y0)) + . + +(** [external::test_custom_swap] *) +Definition test_custom_swap_fwd + (x : u32) (y : u32) (st : state) : result (state * unit) := + p <- custom_swap_fwd u32 x y st; let (st0, _) := p in Return (st0, tt) . + +(** [external::test_custom_swap] *) +Definition test_custom_swap_back + (x : u32) (y : u32) (st : state) (st0 : state) : + result (state * (u32 * u32)) + := + p <- custom_swap_back u32 x y st (1 %u32) st0; + let (st1, tmp) := p in + let (x0, y0) := tmp in Return (st1, (x0, y0)) + . + +(** [external::test_swap_non_zero] *) +Definition test_swap_non_zero_fwd + (x : u32) (st : state) : result (state * u32) := + p <- swap_fwd u32 x (0 %u32) st; + let (st0, _) := p in + p0 <- swap_back u32 x (0 %u32) st st0; + let (st1, (x0, _)) := p0 in if x0 s= 0 %u32 then Fail_ else Return (st1, x0) + . + +End External__Funs . diff --git a/tests/coq/misc/External__Opaque.v b/tests/coq/misc/External__Opaque.v new file mode 100644 index 00000000..19111a37 --- /dev/null +++ b/tests/coq/misc/External__Opaque.v @@ -0,0 +1,36 @@ +(** THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS *) +(** [external]: opaque function definitions *) +Require Import Primitives. +Import Primitives. +Require Import Coq.ZArith.ZArith. +Local Open Scope Primitives_scope. +Require Export External__Types . +Import External__Types . +Module External__Opaque . + +(** [core::mem::swap] *) +Axiom core_mem_swap_fwd : + forall(T : Type) , T -> T -> state -> result (state * unit) + . + +(** [core::mem::swap] *) +Axiom core_mem_swap_back0 : + forall(T : Type) , T -> T -> state -> state -> result (state * T) + . + +(** [core::mem::swap] *) +Axiom core_mem_swap_back1 : + forall(T : Type) , T -> T -> state -> state -> result (state * T) + . + +(** [core::num::nonzero::NonZeroU32::{14}::new] *) +Axiom core_num_nonzero_non_zero_u32_new_fwd + : u32 -> state -> result (state * (option Core_num_nonzero_non_zero_u32_t)) + . + +(** [core::option::Option::{0}::unwrap] *) +Axiom core_option_option_unwrap_fwd : + forall(T : Type) , option T -> state -> result (state * T) + . + +End External__Opaque . diff --git a/tests/coq/misc/External__Types.v b/tests/coq/misc/External__Types.v new file mode 100644 index 00000000..1513ec4a --- /dev/null +++ b/tests/coq/misc/External__Types.v @@ -0,0 +1,15 @@ +(** THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS *) +(** [external]: type definitions *) +Require Import Primitives. +Import Primitives. +Require Import Coq.ZArith.ZArith. +Local Open Scope Primitives_scope. +Module External__Types . + +(** [core::num::nonzero::NonZeroU32] *) +Axiom Core_num_nonzero_non_zero_u32_t : Type . + +(** The state type used in the state-error monad *) +Axiom state : Type. + +End External__Types . diff --git a/tests/coq/misc/Makefile b/tests/coq/misc/Makefile new file mode 100644 index 00000000..ff1ccd39 --- /dev/null +++ b/tests/coq/misc/Makefile @@ -0,0 +1,22 @@ +# Makefile originally taken from coq-club + +%: Makefile.coq phony + +make -f Makefile.coq $@ + +all: Makefile.coq + +make -f Makefile.coq all + +clean: Makefile.coq + +make -f Makefile.coq clean + rm -f Makefile.coq + +Makefile.coq: _CoqProject Makefile + coq_makefile -f _CoqProject | sed 's/$$(COQCHK) $$(COQCHKFLAGS) $$(COQLIBS)/$$(COQCHK) $$(COQCHKFLAGS) $$(subst -Q,-R,$$(COQLIBS))/' > Makefile.coq + +_CoqProject: ; + +Makefile: ; + +phony: ; + +.PHONY: all clean phony diff --git a/tests/coq/misc/NoNestedBorrows.v b/tests/coq/misc/NoNestedBorrows.v new file mode 100644 index 00000000..6dc41204 --- /dev/null +++ b/tests/coq/misc/NoNestedBorrows.v @@ -0,0 +1,510 @@ +(** THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS *) +(** [no_nested_borrows] *) +Require Import Primitives. +Import Primitives. +Require Import Coq.ZArith.ZArith. +Local Open Scope Primitives_scope. +Module NoNestedBorrows . + +(** [no_nested_borrows::Pair] *) +Record Pair_t (T1 T2 : Type) := mkPair_t { Pair_x : T1; Pair_y : T2; } . + +Arguments mkPair_t {T1} {T2} _ _ . +Arguments Pair_x {T1} {T2} . +Arguments Pair_y {T1} {T2} . + +(** [no_nested_borrows::List] *) +Inductive List_t (T : Type) := +| ListCons : T -> List_t T -> List_t T +| ListNil : List_t T +. + +Arguments ListCons {T} _ _ . +Arguments ListNil {T} . + +(** [no_nested_borrows::One] *) +Inductive One_t (T1 : Type) := | OneOne : T1 -> One_t T1 . + +Arguments OneOne {T1} _ . + +(** [no_nested_borrows::EmptyEnum] *) +Inductive Empty_enum_t := | EmptyEnumEmpty : Empty_enum_t . + +Arguments EmptyEnumEmpty . + +(** [no_nested_borrows::Enum] *) +Inductive Enum_t := | EnumVariant1 : Enum_t | EnumVariant2 : Enum_t . + +Arguments EnumVariant1 . +Arguments EnumVariant2 . + +(** [no_nested_borrows::EmptyStruct] *) +Record Empty_struct_t := mkEmpty_struct_t { } . + +Arguments mkEmpty_struct_t . + +(** [no_nested_borrows::Sum] *) +Inductive Sum_t (T1 T2 : Type) := +| SumLeft : T1 -> Sum_t T1 T2 +| SumRight : T2 -> Sum_t T1 T2 +. + +Arguments SumLeft {T1} {T2} _ . +Arguments SumRight {T1} {T2} _ . + +(** [no_nested_borrows::neg_test] *) +Definition neg_test_fwd (x : i32) : result i32 := i <- i32_neg x; Return i . + +(** [no_nested_borrows::add_test] *) +Definition add_test_fwd (x : u32) (y : u32) : result u32 := + i <- u32_add x y; Return i . + +(** [no_nested_borrows::subs_test] *) +Definition subs_test_fwd (x : u32) (y : u32) : result u32 := + i <- u32_sub x y; Return i . + +(** [no_nested_borrows::div_test] *) +Definition div_test_fwd (x : u32) (y : u32) : result u32 := + i <- u32_div x y; Return i . + +(** [no_nested_borrows::div_test1] *) +Definition div_test1_fwd (x : u32) : result u32 := + i <- u32_div x 2 %u32; Return i . + +(** [no_nested_borrows::rem_test] *) +Definition rem_test_fwd (x : u32) (y : u32) : result u32 := + i <- u32_rem x y; Return i . + +(** [no_nested_borrows::cast_test] *) +Definition cast_test_fwd (x : u32) : result i32 := + i <- scalar_cast U32 I32 x; Return i . + +(** [no_nested_borrows::test2] *) +Definition test2_fwd : result unit := + i <- u32_add 23 %u32 44 %u32; let _ := i in Return tt . + +(** Unit test for [no_nested_borrows::test2] *) +Check (test2_fwd )%return. + +(** [no_nested_borrows::get_max] *) +Definition get_max_fwd (x : u32) (y : u32) : result u32 := + if x s>= y then Return x else Return y . + +(** [no_nested_borrows::test3] *) +Definition test3_fwd : result unit := + x <- get_max_fwd (4 %u32) (3 %u32); + y <- get_max_fwd (10 %u32) (11 %u32); + z <- u32_add x y; if negb (z s= 15 %u32) then Fail_ else Return tt + . + +(** Unit test for [no_nested_borrows::test3] *) +Check (test3_fwd )%return. + +(** [no_nested_borrows::test_neg1] *) +Definition test_neg1_fwd : result unit := + y <- i32_neg (3 %i32); if negb (y s= (-3) %i32) then Fail_ else Return tt . + +(** Unit test for [no_nested_borrows::test_neg1] *) +Check (test_neg1_fwd )%return. + +(** [no_nested_borrows::refs_test1] *) +Definition refs_test1_fwd : result unit := + if negb (1 %i32 s= 1 %i32) then Fail_ else Return tt . + +(** Unit test for [no_nested_borrows::refs_test1] *) +Check (refs_test1_fwd )%return. + +(** [no_nested_borrows::refs_test2] *) +Definition refs_test2_fwd : result unit := + if negb (2 %i32 s= 2 %i32) + then Fail_ + else + if negb (0 %i32 s= 0 %i32) + then Fail_ + else + if negb (2 %i32 s= 2 %i32) + then Fail_ + else if negb (2 %i32 s= 2 %i32) then Fail_ else Return tt + . + +(** Unit test for [no_nested_borrows::refs_test2] *) +Check (refs_test2_fwd )%return. + +(** [no_nested_borrows::test_list1] *) +Definition test_list1_fwd : result unit := Return tt . + +(** Unit test for [no_nested_borrows::test_list1] *) +Check (test_list1_fwd )%return. + +(** [no_nested_borrows::test_box1] *) +Definition test_box1_fwd : result unit := + let b := 1 %i32 in + let x := b in if negb (x s= 1 %i32) then Fail_ else Return tt + . + +(** Unit test for [no_nested_borrows::test_box1] *) +Check (test_box1_fwd )%return. + +(** [no_nested_borrows::copy_int] *) +Definition copy_int_fwd (x : i32) : result i32 := Return x . + +(** [no_nested_borrows::test_unreachable] *) +Definition test_unreachable_fwd (b : bool) : result unit := + if b then Fail_ else Return tt . + +(** [no_nested_borrows::test_panic] *) +Definition test_panic_fwd (b : bool) : result unit := + if b then Fail_ else Return tt . + +(** [no_nested_borrows::test_copy_int] *) +Definition test_copy_int_fwd : result unit := + y <- copy_int_fwd (0 %i32); if negb (0 %i32 s= y) then Fail_ else Return tt . + +(** Unit test for [no_nested_borrows::test_copy_int] *) +Check (test_copy_int_fwd )%return. + +(** [no_nested_borrows::is_cons] *) +Definition is_cons_fwd (T : Type) (l : List_t T) : result bool := + match l with | ListCons t l0 => Return true | ListNil => Return false end . + +(** [no_nested_borrows::test_is_cons] *) +Definition test_is_cons_fwd : result unit := + let l := ListNil in + b <- is_cons_fwd i32 (ListCons (0 %i32) l); + if negb b then Fail_ else Return tt + . + +(** Unit test for [no_nested_borrows::test_is_cons] *) +Check (test_is_cons_fwd )%return. + +(** [no_nested_borrows::split_list] *) +Definition split_list_fwd + (T : Type) (l : List_t T) : result (T * (List_t T)) := + match l with | ListCons hd tl => Return (hd, tl) | ListNil => Fail_ end . + +(** [no_nested_borrows::test_split_list] *) +Definition test_split_list_fwd : result unit := + let l := ListNil in + p <- split_list_fwd i32 (ListCons (0 %i32) l); + let (hd, _) := p in if negb (hd s= 0 %i32) then Fail_ else Return tt + . + +(** Unit test for [no_nested_borrows::test_split_list] *) +Check (test_split_list_fwd )%return. + +(** [no_nested_borrows::choose] *) +Definition choose_fwd (T : Type) (b : bool) (x : T) (y : T) : result T := + if b then Return x else Return y . + +(** [no_nested_borrows::choose] *) +Definition choose_back + (T : Type) (b : bool) (x : T) (y : T) (ret : T) : result (T * T) := + if b then Return (ret, y) else Return (x, ret) . + +(** [no_nested_borrows::choose_test] *) +Definition choose_test_fwd : result unit := + z <- choose_fwd i32 true (0 %i32) (0 %i32); + z0 <- i32_add z 1 %i32; + if negb (z0 s= 1 %i32) + then Fail_ + else + ( + p <- choose_back i32 true (0 %i32) (0 %i32) z0; + let (x, y) := p in + if negb (x s= 1 %i32) + then Fail_ + else if negb (y s= 0 %i32) then Fail_ else Return tt ) + . + +(** Unit test for [no_nested_borrows::choose_test] *) +Check (choose_test_fwd )%return. + +(** [no_nested_borrows::test_char] *) +Definition test_char_fwd : result char := + Return (char_of_byte Coq.Init.Byte.x61) . + +(** [no_nested_borrows::NodeElem] *) +Inductive Node_elem_t (T : Type) := +| NodeElemCons : Tree_t T -> Node_elem_t T -> Node_elem_t T +| NodeElemNil : Node_elem_t T + +(** [no_nested_borrows::Tree] *) +with Tree_t (T : Type) := +| TreeLeaf : T -> Tree_t T +| TreeNode : T -> Node_elem_t T -> Tree_t T -> Tree_t T +. + +Arguments NodeElemCons {T} _ _ . +Arguments NodeElemNil {T} . + +Arguments TreeLeaf {T} _ . +Arguments TreeNode {T} _ _ _ . + +(** [no_nested_borrows::list_length] *) +Fixpoint list_length_fwd (T : Type) (l : List_t T) : result u32 := + match l with + | ListCons t l1 => + i <- list_length_fwd T l1; i0 <- u32_add 1 %u32 i; Return i0 + | ListNil => Return (0 %u32) + end + . + +(** [no_nested_borrows::list_nth_shared] *) +Fixpoint list_nth_shared_fwd (T : Type) (l : List_t T) (i : u32) : result T := + match l with + | ListCons x tl => + if i s= 0 %u32 + then Return x + else ( i0 <- u32_sub i 1 %u32; t <- list_nth_shared_fwd T tl i0; Return t ) + | ListNil => Fail_ + end + . + +(** [no_nested_borrows::list_nth_mut] *) +Fixpoint list_nth_mut_fwd (T : Type) (l : List_t T) (i : u32) : result T := + match l with + | ListCons x tl => + if i s= 0 %u32 + then Return x + else ( i0 <- u32_sub i 1 %u32; t <- list_nth_mut_fwd T tl i0; Return t ) + | ListNil => Fail_ + end + . + +(** [no_nested_borrows::list_nth_mut] *) +Fixpoint list_nth_mut_back + (T : Type) (l : List_t T) (i : u32) (ret : T) : result (List_t T) := + match l with + | ListCons x tl => + if i s= 0 %u32 + then Return (ListCons ret tl) + else + ( + i0 <- u32_sub i 1 %u32; + tl0 <- list_nth_mut_back T tl i0 ret; Return (ListCons x tl0) ) + | ListNil => Fail_ + end + . + +(** [no_nested_borrows::list_rev_aux] *) +Fixpoint list_rev_aux_fwd + (T : Type) (li : List_t T) (lo : List_t T) : result (List_t T) := + match li with + | ListCons hd tl => l <- list_rev_aux_fwd T tl (ListCons hd lo); Return l + | ListNil => Return lo + end + . + +(** [no_nested_borrows::list_rev] *) +Definition list_rev_fwd_back (T : Type) (l : List_t T) : result (List_t T) := + let li := mem_replace_fwd (List_t T) l ListNil in + l0 <- list_rev_aux_fwd T li ListNil; Return l0 + . + +(** [no_nested_borrows::test_list_functions] *) +Definition test_list_functions_fwd : result unit := + let l := ListNil in + let l0 := ListCons (2 %i32) l in + let l1 := ListCons (1 %i32) l0 in + i <- list_length_fwd i32 (ListCons (0 %i32) l1); + if negb (i s= 3 %u32) + then Fail_ + else + ( + i0 <- list_nth_shared_fwd i32 (ListCons (0 %i32) l1) (0 %u32); + if negb (i0 s= 0 %i32) + then Fail_ + else + ( + i1 <- list_nth_shared_fwd i32 (ListCons (0 %i32) l1) (1 %u32); + if negb (i1 s= 1 %i32) + then Fail_ + else + ( + i2 <- list_nth_shared_fwd i32 (ListCons (0 %i32) l1) (2 %u32); + if negb (i2 s= 2 %i32) + then Fail_ + else + ( + ls <- + list_nth_mut_back i32 (ListCons (0 %i32) l1) (1 %u32) (3 + %i32); + i3 <- list_nth_shared_fwd i32 ls (0 %u32); + if negb (i3 s= 0 %i32) + then Fail_ + else + ( + i4 <- list_nth_shared_fwd i32 ls (1 %u32); + if negb (i4 s= 3 %i32) + then Fail_ + else + ( + i5 <- list_nth_shared_fwd i32 ls (2 %u32); + if negb (i5 s= 2 %i32) then Fail_ else Return tt ) ) + ) ) ) ) + . + +(** Unit test for [no_nested_borrows::test_list_functions] *) +Check (test_list_functions_fwd )%return. + +(** [no_nested_borrows::id_mut_pair1] *) +Definition id_mut_pair1_fwd + (T1 T2 : Type) (x : T1) (y : T2) : result (T1 * T2) := + Return (x, y) . + +(** [no_nested_borrows::id_mut_pair1] *) +Definition id_mut_pair1_back + (T1 T2 : Type) (x : T1) (y : T2) (ret : (T1 * T2)) : result (T1 * T2) := + let (t, t0) := ret in Return (t, t0) . + +(** [no_nested_borrows::id_mut_pair2] *) +Definition id_mut_pair2_fwd + (T1 T2 : Type) (p : (T1 * T2)) : result (T1 * T2) := + let (t, t0) := p in Return (t, t0) . + +(** [no_nested_borrows::id_mut_pair2] *) +Definition id_mut_pair2_back + (T1 T2 : Type) (p : (T1 * T2)) (ret : (T1 * T2)) : result (T1 * T2) := + let (t, t0) := ret in Return (t, t0) . + +(** [no_nested_borrows::id_mut_pair3] *) +Definition id_mut_pair3_fwd + (T1 T2 : Type) (x : T1) (y : T2) : result (T1 * T2) := + Return (x, y) . + +(** [no_nested_borrows::id_mut_pair3] *) +Definition id_mut_pair3_back'a + (T1 T2 : Type) (x : T1) (y : T2) (ret : T1) : result T1 := + Return ret . + +(** [no_nested_borrows::id_mut_pair3] *) +Definition id_mut_pair3_back'b + (T1 T2 : Type) (x : T1) (y : T2) (ret : T2) : result T2 := + Return ret . + +(** [no_nested_borrows::id_mut_pair4] *) +Definition id_mut_pair4_fwd + (T1 T2 : Type) (p : (T1 * T2)) : result (T1 * T2) := + let (t, t0) := p in Return (t, t0) . + +(** [no_nested_borrows::id_mut_pair4] *) +Definition id_mut_pair4_back'a + (T1 T2 : Type) (p : (T1 * T2)) (ret : T1) : result T1 := + Return ret . + +(** [no_nested_borrows::id_mut_pair4] *) +Definition id_mut_pair4_back'b + (T1 T2 : Type) (p : (T1 * T2)) (ret : T2) : result T2 := + Return ret . + +(** [no_nested_borrows::StructWithTuple] *) +Record Struct_with_tuple_t (T1 T2 : Type) := +mkStruct_with_tuple_t +{ + Struct_with_tuple_p : (T1 * T2); +} +. + +Arguments mkStruct_with_tuple_t {T1} {T2} _ . +Arguments Struct_with_tuple_p {T1} {T2} . + +(** [no_nested_borrows::new_tuple1] *) +Definition new_tuple1_fwd : result (Struct_with_tuple_t u32 u32) := + Return (mkStruct_with_tuple_t (1 %u32, 2 %u32)) . + +(** [no_nested_borrows::new_tuple2] *) +Definition new_tuple2_fwd : result (Struct_with_tuple_t i16 i16) := + Return (mkStruct_with_tuple_t (1 %i16, 2 %i16)) . + +(** [no_nested_borrows::new_tuple3] *) +Definition new_tuple3_fwd : result (Struct_with_tuple_t u64 i64) := + Return (mkStruct_with_tuple_t (1 %u64, 2 %i64)) . + +(** [no_nested_borrows::StructWithPair] *) +Record Struct_with_pair_t (T1 T2 : Type) := +mkStruct_with_pair_t +{ + Struct_with_pair_p : Pair_t T1 T2; +} +. + +Arguments mkStruct_with_pair_t {T1} {T2} _ . +Arguments Struct_with_pair_p {T1} {T2} . + +(** [no_nested_borrows::new_pair1] *) +Definition new_pair1_fwd : result (Struct_with_pair_t u32 u32) := + Return (mkStruct_with_pair_t (mkPair_t (1 %u32) (2 %u32))) . + +(** [no_nested_borrows::test_constants] *) +Definition test_constants_fwd : result unit := + swt <- new_tuple1_fwd; + match swt with + | mkStruct_with_tuple_t p => + let (i, _) := p in + if negb (i s= 1 %u32) + then Fail_ + else + ( + swt0 <- new_tuple2_fwd; + match swt0 with + | mkStruct_with_tuple_t p0 => + let (i0, _) := p0 in + if negb (i0 s= 1 %i16) + then Fail_ + else + ( + swt1 <- new_tuple3_fwd; + match swt1 with + | mkStruct_with_tuple_t p1 => + let (i1, _) := p1 in + if negb (i1 s= 1 %u64) + then Fail_ + else + ( + swp <- new_pair1_fwd; + match swp with + | mkStruct_with_pair_t p2 => + match p2 with + | mkPair_t i2 i3 => + if negb (i2 s= 1 %u32) then Fail_ else Return tt + end + end ) + end ) + end ) + end + . + +(** Unit test for [no_nested_borrows::test_constants] *) +Check (test_constants_fwd )%return. + +(** [no_nested_borrows::test_weird_borrows1] *) +Definition test_weird_borrows1_fwd : result unit := Return tt . + +(** Unit test for [no_nested_borrows::test_weird_borrows1] *) +Check (test_weird_borrows1_fwd )%return. + +(** [no_nested_borrows::test_mem_replace] *) +Definition test_mem_replace_fwd_back (px : u32) : result u32 := + let y := mem_replace_fwd u32 px (1 %u32) in + if negb (y s= 0 %u32) then Fail_ else Return (2 %u32) + . + +(** [no_nested_borrows::test_shared_borrow_bool1] *) +Definition test_shared_borrow_bool1_fwd (b : bool) : result u32 := + if b then Return (0 %u32) else Return (1 %u32) . + +(** [no_nested_borrows::test_shared_borrow_bool2] *) +Definition test_shared_borrow_bool2_fwd : result u32 := Return (0 %u32) . + +(** [no_nested_borrows::test_shared_borrow_enum1] *) +Definition test_shared_borrow_enum1_fwd (l : List_t u32) : result u32 := + match l with + | ListCons i l0 => Return (1 %u32) + | ListNil => Return (0 %u32) + end + . + +(** [no_nested_borrows::test_shared_borrow_enum2] *) +Definition test_shared_borrow_enum2_fwd : result u32 := Return (0 %u32) . + +End NoNestedBorrows . diff --git a/tests/coq/misc/Paper.v b/tests/coq/misc/Paper.v new file mode 100644 index 00000000..5d9598eb --- /dev/null +++ b/tests/coq/misc/Paper.v @@ -0,0 +1,114 @@ +(** THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS *) +(** [paper] *) +Require Import Primitives. +Import Primitives. +Require Import Coq.ZArith.ZArith. +Local Open Scope Primitives_scope. +Module Paper . + +(** [paper::ref_incr] *) +Definition ref_incr_fwd_back (x : i32) : result i32 := + x0 <- i32_add x 1 %i32; Return x0 . + +(** [paper::test_incr] *) +Definition test_incr_fwd : result unit := + x <- ref_incr_fwd_back (0 %i32); + if negb (x s= 1 %i32) then Fail_ else Return tt + . + +(** Unit test for [paper::test_incr] *) +Check (test_incr_fwd )%return. + +(** [paper::choose] *) +Definition choose_fwd (T : Type) (b : bool) (x : T) (y : T) : result T := + if b then Return x else Return y . + +(** [paper::choose] *) +Definition choose_back + (T : Type) (b : bool) (x : T) (y : T) (ret : T) : result (T * T) := + if b then Return (ret, y) else Return (x, ret) . + +(** [paper::test_choose] *) +Definition test_choose_fwd : result unit := + z <- choose_fwd i32 true (0 %i32) (0 %i32); + z0 <- i32_add z 1 %i32; + if negb (z0 s= 1 %i32) + then Fail_ + else + ( + p <- choose_back i32 true (0 %i32) (0 %i32) z0; + let (x, y) := p in + if negb (x s= 1 %i32) + then Fail_ + else if negb (y s= 0 %i32) then Fail_ else Return tt ) + . + +(** Unit test for [paper::test_choose] *) +Check (test_choose_fwd )%return. + +(** [paper::List] *) +Inductive List_t (T : Type) := +| ListCons : T -> List_t T -> List_t T +| ListNil : List_t T +. + +Arguments ListCons {T} _ _ . +Arguments ListNil {T} . + +(** [paper::list_nth_mut] *) +Fixpoint list_nth_mut_fwd (T : Type) (l : List_t T) (i : u32) : result T := + match l with + | ListCons x tl => + if i s= 0 %u32 + then Return x + else ( i0 <- u32_sub i 1 %u32; t <- list_nth_mut_fwd T tl i0; Return t ) + | ListNil => Fail_ + end + . + +(** [paper::list_nth_mut] *) +Fixpoint list_nth_mut_back + (T : Type) (l : List_t T) (i : u32) (ret : T) : result (List_t T) := + match l with + | ListCons x tl => + if i s= 0 %u32 + then Return (ListCons ret tl) + else + ( + i0 <- u32_sub i 1 %u32; + tl0 <- list_nth_mut_back T tl i0 ret; Return (ListCons x tl0) ) + | ListNil => Fail_ + end + . + +(** [paper::sum] *) +Fixpoint sum_fwd (l : List_t i32) : result i32 := + match l with + | ListCons x tl => i <- sum_fwd tl; i0 <- i32_add x i; Return i0 + | ListNil => Return (0 %i32) + end + . + +(** [paper::test_nth] *) +Definition test_nth_fwd : result unit := + let l := ListNil in + let l0 := ListCons (3 %i32) l in + let l1 := ListCons (2 %i32) l0 in + x <- list_nth_mut_fwd i32 (ListCons (1 %i32) l1) (2 %u32); + x0 <- i32_add x 1 %i32; + l2 <- list_nth_mut_back i32 (ListCons (1 %i32) l1) (2 %u32) x0; + i <- sum_fwd l2; if negb (i s= 7 %i32) then Fail_ else Return tt + . + +(** Unit test for [paper::test_nth] *) +Check (test_nth_fwd )%return. + +(** [paper::call_choose] *) +Definition call_choose_fwd (p : (u32 * u32)) : result u32 := + let (px, py) := p in + pz <- choose_fwd u32 true px py; + pz0 <- u32_add pz 1 %u32; + p0 <- choose_back u32 true px py pz0; let (px0, _) := p0 in Return px0 + . + +End Paper . diff --git a/tests/coq/misc/Primitives.v b/tests/coq/misc/Primitives.v new file mode 100644 index 00000000..c27b8aed --- /dev/null +++ b/tests/coq/misc/Primitives.v @@ -0,0 +1,478 @@ +Require Import Lia. +Require Coq.Strings.Ascii. +Require Coq.Strings.String. +Require Import Coq.Program.Equality. +Require Import Coq.ZArith.ZArith. +Require Import Coq.ZArith.Znat. +Require Import List. +Import ListNotations. + +Module Primitives. + + (* TODO: use more *) +Declare Scope Primitives_scope. + +(*** Result *) + +Inductive result A := + | Return : A -> result A + | Fail_ : result A. + +Arguments Return {_} a. +Arguments Fail_ {_}. + +Definition bind {A B} (m: result A) (f: A -> result B) : result B := + match m with + | Fail_ => Fail_ + | Return x => f x + end. + +Definition return_ {A: Type} (x: A) : result A := Return x . +Definition fail_ {A: Type} : result A := Fail_ . + +Notation "x <- c1 ; c2" := (bind c1 (fun x => c2)) + (at level 61, c1 at next level, right associativity). + +(** Monadic assert *) +Definition massert (b: bool) : result unit := + if b then Return tt else Fail_. + +(** Normalize and unwrap a successful result (used for globals) *) +Definition eval_result_refl {A} {x} (a: result A) (p: a = Return x) : A := + match a as r return (r = Return x -> A) with + | Return a' => fun _ => a' + | Fail_ => fun p' => + False_rect _ (eq_ind Fail_ + (fun e : result A => + match e with + | Return _ => False + | Fail_ => True + end) + I (Return x) p') + end p. + +Notation "x %global" := (eval_result_refl x eq_refl) (at level 40). +Notation "x %return" := (eval_result_refl x eq_refl) (at level 40). + +(* Sanity check *) +Check (if true then Return (1 + 2) else Fail_)%global = 3. + +(*** Misc *) + + +Definition string := Coq.Strings.String.string. +Definition char := Coq.Strings.Ascii.ascii. +Definition char_of_byte := Coq.Strings.Ascii.ascii_of_byte. + +Definition mem_replace_fwd (a : Type) (x : a) (y : a) : a := x . +Definition mem_replace_back (a : Type) (x : a) (y : a) : a := y . + +(*** Scalars *) + +Definition i8_min : Z := -128%Z. +Definition i8_max : Z := 127%Z. +Definition i16_min : Z := -32768%Z. +Definition i16_max : Z := 32767%Z. +Definition i32_min : Z := -2147483648%Z. +Definition i32_max : Z := 2147483647%Z. +Definition i64_min : Z := -9223372036854775808%Z. +Definition i64_max : Z := 9223372036854775807%Z. +Definition i128_min : Z := -170141183460469231731687303715884105728%Z. +Definition i128_max : Z := 170141183460469231731687303715884105727%Z. +Definition u8_min : Z := 0%Z. +Definition u8_max : Z := 255%Z. +Definition u16_min : Z := 0%Z. +Definition u16_max : Z := 65535%Z. +Definition u32_min : Z := 0%Z. +Definition u32_max : Z := 4294967295%Z. +Definition u64_min : Z := 0%Z. +Definition u64_max : Z := 18446744073709551615%Z. +Definition u128_min : Z := 0%Z. +Definition u128_max : Z := 340282366920938463463374607431768211455%Z. + +(** The bounds of [isize] and [usize] vary with the architecture. *) +Axiom isize_min : Z. +Axiom isize_max : Z. +Definition usize_min : Z := 0%Z. +Axiom usize_max : Z. + +Open Scope Z_scope. + +(** We provide those lemmas to reason about the bounds of [isize] and [usize] *) +Axiom isize_min_bound : isize_min <= i32_min. +Axiom isize_max_bound : i32_max <= isize_max. +Axiom usize_max_bound : u32_max <= usize_max. + +Inductive scalar_ty := + | Isize + | I8 + | I16 + | I32 + | I64 + | I128 + | Usize + | U8 + | U16 + | U32 + | U64 + | U128 +. + +Definition scalar_min (ty: scalar_ty) : Z := + match ty with + | Isize => isize_min + | I8 => i8_min + | I16 => i16_min + | I32 => i32_min + | I64 => i64_min + | I128 => i128_min + | Usize => usize_min + | U8 => u8_min + | U16 => u16_min + | U32 => u32_min + | U64 => u64_min + | U128 => u128_min +end. + +Definition scalar_max (ty: scalar_ty) : Z := + match ty with + | Isize => isize_max + | I8 => i8_max + | I16 => i16_max + | I32 => i32_max + | I64 => i64_max + | I128 => i128_max + | Usize => usize_max + | U8 => u8_max + | U16 => u16_max + | U32 => u32_max + | U64 => u64_max + | U128 => u128_max +end. + +(** We use the following conservative bounds to make sure we can compute bound + checks in most situations *) +Definition scalar_min_cons (ty: scalar_ty) : Z := + match ty with + | Isize => i32_min + | Usize => u32_min + | _ => scalar_min ty +end. + +Definition scalar_max_cons (ty: scalar_ty) : Z := + match ty with + | Isize => i32_max + | Usize => u32_max + | _ => scalar_max ty +end. + +Lemma scalar_min_cons_valid : forall ty, scalar_min ty <= scalar_min_cons ty . +Proof. + destruct ty; unfold scalar_min_cons, scalar_min; try lia. + - pose isize_min_bound; lia. + - apply Z.le_refl. +Qed. + +Lemma scalar_max_cons_valid : forall ty, scalar_max ty >= scalar_max_cons ty . +Proof. + destruct ty; unfold scalar_max_cons, scalar_max; try lia. + - pose isize_max_bound; lia. + - pose usize_max_bound. lia. +Qed. + +Definition scalar (ty: scalar_ty) : Type := + { x: Z | scalar_min ty <= x <= scalar_max ty }. + +Definition to_Z {ty} (x: scalar ty) : Z := proj1_sig x. + +(** Bounds checks: we start by using the conservative bounds, to make sure we + can compute in most situations, then we use the real bounds (for [isize] + and [usize]). *) +Definition scalar_ge_min (ty: scalar_ty) (x: Z) : bool := + Z.leb (scalar_min_cons ty) x || Z.leb (scalar_min ty) x. + +Definition scalar_le_max (ty: scalar_ty) (x: Z) : bool := + Z.leb x (scalar_max_cons ty) || Z.leb x (scalar_max ty). + +Lemma scalar_ge_min_valid (ty: scalar_ty) (x: Z) : + scalar_ge_min ty x = true -> scalar_min ty <= x . +Proof. + unfold scalar_ge_min. + pose (scalar_min_cons_valid ty). + lia. +Qed. + +Lemma scalar_le_max_valid (ty: scalar_ty) (x: Z) : + scalar_le_max ty x = true -> x <= scalar_max ty . +Proof. + unfold scalar_le_max. + pose (scalar_max_cons_valid ty). + lia. +Qed. + +Definition scalar_in_bounds (ty: scalar_ty) (x: Z) : bool := + scalar_ge_min ty x && scalar_le_max ty x . + +Lemma scalar_in_bounds_valid (ty: scalar_ty) (x: Z) : + scalar_in_bounds ty x = true -> scalar_min ty <= x <= scalar_max ty . +Proof. + unfold scalar_in_bounds. + intros H. + destruct (scalar_ge_min ty x) eqn:Hmin. + - destruct (scalar_le_max ty x) eqn:Hmax. + + pose (scalar_ge_min_valid ty x Hmin). + pose (scalar_le_max_valid ty x Hmax). + lia. + + inversion H. + - inversion H. +Qed. + +Import Sumbool. + +Definition mk_scalar (ty: scalar_ty) (x: Z) : result (scalar ty) := + match sumbool_of_bool (scalar_in_bounds ty x) with + | left H => Return (exist _ x (scalar_in_bounds_valid _ _ H)) + | right _ => Fail_ + end. + +Definition scalar_add {ty} (x y: scalar ty) : result (scalar ty) := mk_scalar ty (to_Z x + to_Z y). + +Definition scalar_sub {ty} (x y: scalar ty) : result (scalar ty) := mk_scalar ty (to_Z x - to_Z y). + +Definition scalar_mul {ty} (x y: scalar ty) : result (scalar ty) := mk_scalar ty (to_Z x * to_Z y). + +Definition scalar_div {ty} (x y: scalar ty) : result (scalar ty) := + if to_Z y =? 0 then Fail_ else + mk_scalar ty (to_Z x / to_Z y). + +Definition scalar_rem {ty} (x y: scalar ty) : result (scalar ty) := mk_scalar ty (Z.rem (to_Z x) (to_Z y)). + +Definition scalar_neg {ty} (x: scalar ty) : result (scalar ty) := mk_scalar ty (-(to_Z x)). + +(** Cast an integer from a [src_ty] to a [tgt_ty] *) +(* TODO: check the semantics of casts in Rust *) +Definition scalar_cast (src_ty tgt_ty : scalar_ty) (x : scalar src_ty) : result (scalar tgt_ty) := + mk_scalar tgt_ty (to_Z x). + +(** Comparisons *) +Print Z.leb . + +Definition scalar_leb {ty : scalar_ty} (x : scalar ty) (y : scalar ty) : bool := + Z.leb (to_Z x) (to_Z y) . + +Definition scalar_ltb {ty : scalar_ty} (x : scalar ty) (y : scalar ty) : bool := + Z.ltb (to_Z x) (to_Z y) . + +Definition scalar_geb {ty : scalar_ty} (x : scalar ty) (y : scalar ty) : bool := + Z.geb (to_Z x) (to_Z y) . + +Definition scalar_gtb {ty : scalar_ty} (x : scalar ty) (y : scalar ty) : bool := + Z.gtb (to_Z x) (to_Z y) . + +Definition scalar_eqb {ty : scalar_ty} (x : scalar ty) (y : scalar ty) : bool := + Z.eqb (to_Z x) (to_Z y) . + +Definition scalar_neqb {ty : scalar_ty} (x : scalar ty) (y : scalar ty) : bool := + negb (Z.eqb (to_Z x) (to_Z y)) . + + +(** The scalar types *) +Definition isize := scalar Isize. +Definition i8 := scalar I8. +Definition i16 := scalar I16. +Definition i32 := scalar I32. +Definition i64 := scalar I64. +Definition i128 := scalar I128. +Definition usize := scalar Usize. +Definition u8 := scalar U8. +Definition u16 := scalar U16. +Definition u32 := scalar U32. +Definition u64 := scalar U64. +Definition u128 := scalar U128. + +(** Negaion *) +Definition isize_neg := @scalar_neg Isize. +Definition i8_neg := @scalar_neg I8. +Definition i16_neg := @scalar_neg I16. +Definition i32_neg := @scalar_neg I32. +Definition i64_neg := @scalar_neg I64. +Definition i128_neg := @scalar_neg I128. + +(** Division *) +Definition isize_div := @scalar_div Isize. +Definition i8_div := @scalar_div I8. +Definition i16_div := @scalar_div I16. +Definition i32_div := @scalar_div I32. +Definition i64_div := @scalar_div I64. +Definition i128_div := @scalar_div I128. +Definition usize_div := @scalar_div Usize. +Definition u8_div := @scalar_div U8. +Definition u16_div := @scalar_div U16. +Definition u32_div := @scalar_div U32. +Definition u64_div := @scalar_div U64. +Definition u128_div := @scalar_div U128. + +(** Remainder *) +Definition isize_rem := @scalar_rem Isize. +Definition i8_rem := @scalar_rem I8. +Definition i16_rem := @scalar_rem I16. +Definition i32_rem := @scalar_rem I32. +Definition i64_rem := @scalar_rem I64. +Definition i128_rem := @scalar_rem I128. +Definition usize_rem := @scalar_rem Usize. +Definition u8_rem := @scalar_rem U8. +Definition u16_rem := @scalar_rem U16. +Definition u32_rem := @scalar_rem U32. +Definition u64_rem := @scalar_rem U64. +Definition u128_rem := @scalar_rem U128. + +(** Addition *) +Definition isize_add := @scalar_add Isize. +Definition i8_add := @scalar_add I8. +Definition i16_add := @scalar_add I16. +Definition i32_add := @scalar_add I32. +Definition i64_add := @scalar_add I64. +Definition i128_add := @scalar_add I128. +Definition usize_add := @scalar_add Usize. +Definition u8_add := @scalar_add U8. +Definition u16_add := @scalar_add U16. +Definition u32_add := @scalar_add U32. +Definition u64_add := @scalar_add U64. +Definition u128_add := @scalar_add U128. + +(** Substraction *) +Definition isize_sub := @scalar_sub Isize. +Definition i8_sub := @scalar_sub I8. +Definition i16_sub := @scalar_sub I16. +Definition i32_sub := @scalar_sub I32. +Definition i64_sub := @scalar_sub I64. +Definition i128_sub := @scalar_sub I128. +Definition usize_sub := @scalar_sub Usize. +Definition u8_sub := @scalar_sub U8. +Definition u16_sub := @scalar_sub U16. +Definition u32_sub := @scalar_sub U32. +Definition u64_sub := @scalar_sub U64. +Definition u128_sub := @scalar_sub U128. + +(** Multiplication *) +Definition isize_mul := @scalar_mul Isize. +Definition i8_mul := @scalar_mul I8. +Definition i16_mul := @scalar_mul I16. +Definition i32_mul := @scalar_mul I32. +Definition i64_mul := @scalar_mul I64. +Definition i128_mul := @scalar_mul I128. +Definition usize_mul := @scalar_mul Usize. +Definition u8_mul := @scalar_mul U8. +Definition u16_mul := @scalar_mul U16. +Definition u32_mul := @scalar_mul U32. +Definition u64_mul := @scalar_mul U64. +Definition u128_mul := @scalar_mul U128. + +(** Small utility *) +Definition usize_to_nat (x: usize) : nat := Z.to_nat (to_Z x). + +(** Notations *) +Notation "x %isize" := ((mk_scalar Isize x)%return) (at level 9). +Notation "x %i8" := ((mk_scalar I8 x)%return) (at level 9). +Notation "x %i16" := ((mk_scalar I16 x)%return) (at level 9). +Notation "x %i32" := ((mk_scalar I32 x)%return) (at level 9). +Notation "x %i64" := ((mk_scalar I64 x)%return) (at level 9). +Notation "x %i128" := ((mk_scalar I128 x)%return) (at level 9). +Notation "x %usize" := ((mk_scalar Usize x)%return) (at level 9). +Notation "x %u8" := ((mk_scalar U8 x)%return) (at level 9). +Notation "x %u16" := ((mk_scalar U16 x)%return) (at level 9). +Notation "x %u32" := ((mk_scalar U32 x)%return) (at level 9). +Notation "x %u64" := ((mk_scalar U64 x)%return) (at level 9). +Notation "x %u128" := ((mk_scalar U128 x)%return) (at level 9). + +Notation "x s= y" := (scalar_eqb x y) (at level 80) : Primitives_scope. +Notation "x s<> y" := (scalar_neqb x y) (at level 80) : Primitives_scope. +Notation "x s<= y" := (scalar_leb x y) (at level 80) : Primitives_scope. +Notation "x s< y" := (scalar_ltb x y) (at level 80) : Primitives_scope. +Notation "x s>= y" := (scalar_geb x y) (at level 80) : Primitives_scope. +Notation "x s> y" := (scalar_gtb x y) (at level 80) : Primitives_scope. + +(*** Vectors *) + +Definition vec T := { l: list T | Z.of_nat (length l) <= usize_max }. + +Definition vec_to_list {T: Type} (v: vec T) : list T := proj1_sig v. + +Definition vec_length {T: Type} (v: vec T) : Z := Z.of_nat (length (vec_to_list v)). + +Lemma le_0_usize_max : 0 <= usize_max. +Proof. + pose (H := usize_max_bound). + unfold u32_max in H. + lia. +Qed. + +Definition vec_new (T: Type) : vec T := (exist _ [] le_0_usize_max). + +Lemma vec_len_in_usize {T} (v: vec T) : usize_min <= vec_length v <= usize_max. +Proof. + unfold vec_length, usize_min. + split. + - lia. + - apply (proj2_sig v). +Qed. + +Definition vec_len (T: Type) (v: vec T) : usize := + exist _ (vec_length v) (vec_len_in_usize v). + +Fixpoint list_update {A} (l: list A) (n: nat) (a: A) + : list A := + match l with + | [] => [] + | x :: t => match n with + | 0%nat => a :: t + | S m => x :: (list_update t m a) +end end. + +Definition vec_bind {A B} (v: vec A) (f: list A -> result (list B)) : result (vec B) := + l <- f (vec_to_list v) ; + match sumbool_of_bool (scalar_le_max Usize (Z.of_nat (length l))) with + | left H => Return (exist _ l (scalar_le_max_valid _ _ H)) + | right _ => Fail_ + end. + +(* The **forward** function shouldn't be used *) +Definition vec_push_fwd (T: Type) (v: vec T) (x: T) : unit := tt. + +Definition vec_push_back (T: Type) (v: vec T) (x: T) : result (vec T) := + vec_bind v (fun l => Return (l ++ [x])). + +(* The **forward** function shouldn't be used *) +Definition vec_insert_fwd (T: Type) (v: vec T) (i: usize) (x: T) : result unit := + if to_Z i + if to_Z i Return n + | None => Fail_ + end. + +Definition vec_index_back (T: Type) (v: vec T) (i: usize) (x: T) : result unit := + if to_Z i Return n + | None => Fail_ + end. + +Definition vec_index_mut_back (T: Type) (v: vec T) (i: usize) (x: T) : result (vec T) := + vec_bind v (fun l => + if to_Z i tree_t t | TreeNode : t -> node_elem_t t -> tree_t t -> tree_t t -(** [no_nested_borrows::odd] *) -let rec odd_fwd (x : u32) : result bool = - if x = 0 - then Return false - else - begin match u32_sub x 1 with - | Fail -> Fail - | Return i -> - begin match even_fwd i with | Fail -> Fail | Return b -> Return b end - end - -(** [no_nested_borrows::even] *) -and even_fwd (x : u32) : result bool = - if x = 0 - then Return true - else - begin match u32_sub x 1 with - | Fail -> Fail - | Return i -> - begin match odd_fwd i with | Fail -> Fail | Return b -> Return b end - end - -(** [no_nested_borrows::test_even_odd] *) -let test_even_odd_fwd : result unit = - begin match even_fwd 0 with - | Fail -> Fail - | Return b -> - if not b - then Fail - else - begin match even_fwd 4 with - | Fail -> Fail - | Return b0 -> - if not b0 - then Fail - else - begin match odd_fwd 1 with - | Fail -> Fail - | Return b1 -> - if not b1 - then Fail - else - begin match odd_fwd 5 with - | Fail -> Fail - | Return b2 -> if not b2 then Fail else Return () - end - end - end - end - -(** Unit test for [no_nested_borrows::test_even_odd] *) -let _ = assert_norm (test_even_odd_fwd = Return ()) - (** [no_nested_borrows::list_length] *) let rec list_length_fwd (t : Type0) (l : list_t t) : result u32 = begin match l with -- cgit v1.2.3