diff options
92 files changed, 6276 insertions, 2195 deletions
@@ -166,7 +166,7 @@ tleanp-polonius_list: OPTIONS += thol4p-polonius_list: SUBDIR := misc-polonius_list thol4p-polonius_list: OPTIONS += -trans-constants: OPTIONS += -test-units -test-trans-units -no-split-files -no-state +trans-constants: OPTIONS += -test-trans-units -no-split-files -no-state trans-constants: SUBDIR := misc tfstar-constants: OPTIONS += tcoq-constants: OPTIONS += diff --git a/backends/coq/Primitives.v b/backends/coq/Primitives.v index 71a2d9c3..b92eb967 100644 --- a/backends/coq/Primitives.v +++ b/backends/coq/Primitives.v @@ -394,6 +394,20 @@ Notation "x s< y" := (scalar_ltb x y) (at level 80) : Primitives_scope. Notation "x s>= y" := (scalar_geb x y) (at level 80) : Primitives_scope. Notation "x s> y" := (scalar_gtb x y) (at level 80) : Primitives_scope. +(** Constants *) +Definition core_u8_max := u8_max %u32. +Definition core_u16_max := u16_max %u32. +Definition core_u32_max := u32_max %u32. +Definition core_u64_max := u64_max %u64. +Definition core_u128_max := u64_max %u128. +Axiom core_usize_max : usize. (** TODO *) +Definition core_i8_max := i8_max %i32. +Definition core_i16_max := i16_max %i32. +Definition core_i32_max := i32_max %i32. +Definition core_i64_max := i64_max %i64. +Definition core_i128_max := i64_max %i128. +Axiom core_isize_max : isize. (** TODO *) + (*** Range *) Record range (T : Type) := mk_range { start: T; @@ -419,6 +433,9 @@ Qed. (* TODO: finish the definitions *) Axiom mk_array : forall (T : Type) (n : usize) (l : list T), array T n. +(* For initialization *) +Axiom array_repeat : forall {T : Type} (n : usize) (x : T), array T n. + Axiom array_index_shared : forall (T : Type) (n : usize) (x : array T n) (i : usize), result T. Axiom array_index_mut_fwd : forall (T : Type) (n : usize) (x : array T n) (i : usize), result T. Axiom array_index_mut_back : forall (T : Type) (n : usize) (x : array T n) (i : usize) (nx : T), result (array T n). diff --git a/backends/fstar/Primitives.fst b/backends/fstar/Primitives.fst index 9db82069..e9391834 100644 --- a/backends/fstar/Primitives.fst +++ b/backends/fstar/Primitives.fst @@ -169,17 +169,44 @@ let scalar_cast (src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : /// The scalar types type isize : eqtype = scalar Isize -type i8 : eqtype = scalar I8 -type i16 : eqtype = scalar I16 -type i32 : eqtype = scalar I32 -type i64 : eqtype = scalar I64 -type i128 : eqtype = scalar I128 +type i8 : eqtype = scalar I8 +type i16 : eqtype = scalar I16 +type i32 : eqtype = scalar I32 +type i64 : eqtype = scalar I64 +type i128 : eqtype = scalar I128 type usize : eqtype = scalar Usize -type u8 : eqtype = scalar U8 -type u16 : eqtype = scalar U16 -type u32 : eqtype = scalar U32 -type u64 : eqtype = scalar U64 -type u128 : eqtype = scalar U128 +type u8 : eqtype = scalar U8 +type u16 : eqtype = scalar U16 +type u32 : eqtype = scalar U32 +type u64 : eqtype = scalar U64 +type u128 : eqtype = scalar U128 + + +let core_isize_min : isize = isize_min +let core_isize_max : isize = isize_max +let core_i8_min : i8 = i8_min +let core_i8_max : i8 = i8_max +let core_i16_min : i16 = i16_min +let core_i16_max : i16 = i16_max +let core_i32_min : i32 = i32_min +let core_i32_max : i32 = i32_max +let core_i64_min : i64 = i64_min +let core_i64_max : i64 = i64_max +let core_i128_min : i128 = i128_min +let core_i128_max : i128 = i128_max + +let core_usize_min : usize = usize_min +let core_usize_max : usize = usize_max +let core_u8_min : u8 = u8_min +let core_u8_max : u8 = u8_max +let core_u16_min : u16 = u16_min +let core_u16_max : u16 = u16_max +let core_u32_min : u32 = u32_min +let core_u32_max : u32 = u32_max +let core_u64_min : u64 = u64_min +let core_u64_max : u64 = u64_max +let core_u128_min : u128 = u128_min +let core_u128_max : u128 = u128_max /// Negation let isize_neg = scalar_neg #Isize @@ -325,6 +352,9 @@ let array_subslice_mut_fwd (a : Type0) (n : usize) (x : array a n) (r : range us let array_subslice_mut_back (a : Type0) (n : usize) (x : array a n) (r : range usize) (ns : slice a) : result (array a n) = admit() +let array_repeat (a : Type0) (n : usize) (x : a) : array a n = + admit() + let slice_subslice_shared (a : Type0) (x : slice a) (r : range usize) : result (slice a) = admit() diff --git a/backends/hol4/primitivesScript.sml b/backends/hol4/primitivesScript.sml index 82da4de9..916988be 100644 --- a/backends/hol4/primitivesScript.sml +++ b/backends/hol4/primitivesScript.sml @@ -555,6 +555,32 @@ Proof QED val _ = evalLib.add_unfold_thm "mk_isize_unfold" +(** Constants *) +val core_i8_min_def = Define ‘core_i8_min = int_to_i8 i8_min’ +val core_i8_max_def = Define ‘core_i8_max = int_to_i8 i8_max’ +val core_i16_min_def = Define ‘core_i16_min = int_to_i16 i16_min’ +val core_i16_max_def = Define ‘core_i16_max = int_to_i16 i16_max’ +val core_i32_min_def = Define ‘core_i32_min = int_to_i32 i32_min’ +val core_i32_max_def = Define ‘core_i32_max = int_to_i32 i32_max’ +val core_i64_min_def = Define ‘core_i64_min = int_to_i64 i64_min’ +val core_i64_max_def = Define ‘core_i64_max = int_to_i64 i64_max’ +val core_i128_min_def = Define ‘core_i128_min = int_to_i128 i128_min’ +val core_i128_max_def = Define ‘core_i128_max = int_to_i128 i128_max’ +val core_isize_min_def = Define ‘core_isize_min = int_to_isize isize_min’ +val core_isize_max_def = Define ‘core_isize_max = int_to_isize isize_max’ +val core_u8_min_def = Define ‘core_u8_min = int_to_u8 0’ +val core_u8_max_def = Define ‘core_u8_max = int_to_u8 u8_max’ +val core_u16_min_def = Define ‘core_u16_min = int_to_u16 0’ +val core_u16_max_def = Define ‘core_u16_max = int_to_u16 u16_max’ +val core_u32_min_def = Define ‘core_u32_min = int_to_u32 0’ +val core_u32_max_def = Define ‘core_u32_max = int_to_u32 u32_max’ +val core_u64_min_def = Define ‘core_u64_min = int_to_u64 0’ +val core_u64_max_def = Define ‘core_u64_max = int_to_u64 u64_max’ +val core_u128_min_def = Define ‘core_u128_min = int_to_u128 0’ +val core_u128_max_def = Define ‘core_u128_max = int_to_u128 u128_max’ +val core_usize_min_def = Define ‘core_usize_min = int_to_usize 0’ +val core_usize_max_def = Define ‘core_usize_max = int_to_usize usize_max’ + val isize_neg_def = Define ‘isize_neg x = mk_isize (- (isize_to_int x))’ val i8_neg_def = Define ‘i8_neg x = mk_i8 (- (i8_to_int x))’ val i16_neg_def = Define ‘i16_neg x = mk_i16 (- (i16_to_int x))’ diff --git a/backends/hol4/primitivesTheory.sig b/backends/hol4/primitivesTheory.sig index 6660b02d..4ae6bb3e 100644 --- a/backends/hol4/primitivesTheory.sig +++ b/backends/hol4/primitivesTheory.sig @@ -46,6 +46,30 @@ sig (* Definitions *) val bind_def : thm + val core_i128_max_def : thm + val core_i128_min_def : thm + val core_i16_max_def : thm + val core_i16_min_def : thm + val core_i32_max_def : thm + val core_i32_min_def : thm + val core_i64_max_def : thm + val core_i64_min_def : thm + val core_i8_max_def : thm + val core_i8_min_def : thm + val core_isize_max_def : thm + val core_isize_min_def : thm + val core_u128_max_def : thm + val core_u128_min_def : thm + val core_u16_max_def : thm + val core_u16_min_def : thm + val core_u32_max_def : thm + val core_u32_min_def : thm + val core_u64_max_def : thm + val core_u64_min_def : thm + val core_u8_max_def : thm + val core_u8_min_def : thm + val core_usize_max_def : thm + val core_usize_min_def : thm val error_BIJ : thm val error_CASE : thm val error_TY_DEF : thm @@ -566,6 +590,102 @@ sig monad_bind x f = case x of Return y => f y | Fail e => Fail e | Diverge => Diverge + [core_i128_max_def] Definition + + ⊢ core_i128_max = int_to_i128 i128_max + + [core_i128_min_def] Definition + + ⊢ core_i128_min = int_to_i128 i128_min + + [core_i16_max_def] Definition + + ⊢ core_i16_max = int_to_i16 i16_max + + [core_i16_min_def] Definition + + ⊢ core_i16_min = int_to_i16 i16_min + + [core_i32_max_def] Definition + + ⊢ core_i32_max = int_to_i32 i32_max + + [core_i32_min_def] Definition + + ⊢ core_i32_min = int_to_i32 i32_min + + [core_i64_max_def] Definition + + ⊢ core_i64_max = int_to_i64 i64_max + + [core_i64_min_def] Definition + + ⊢ core_i64_min = int_to_i64 i64_min + + [core_i8_max_def] Definition + + ⊢ core_i8_max = int_to_i8 i8_max + + [core_i8_min_def] Definition + + ⊢ core_i8_min = int_to_i8 i8_min + + [core_isize_max_def] Definition + + ⊢ core_isize_max = int_to_isize isize_max + + [core_isize_min_def] Definition + + ⊢ core_isize_min = int_to_isize isize_min + + [core_u128_max_def] Definition + + ⊢ core_u128_max = int_to_u128 u128_max + + [core_u128_min_def] Definition + + ⊢ core_u128_min = int_to_u128 0 + + [core_u16_max_def] Definition + + ⊢ core_u16_max = int_to_u16 u16_max + + [core_u16_min_def] Definition + + ⊢ core_u16_min = int_to_u16 0 + + [core_u32_max_def] Definition + + ⊢ core_u32_max = int_to_u32 u32_max + + [core_u32_min_def] Definition + + ⊢ core_u32_min = int_to_u32 0 + + [core_u64_max_def] Definition + + ⊢ core_u64_max = int_to_u64 u64_max + + [core_u64_min_def] Definition + + ⊢ core_u64_min = int_to_u64 0 + + [core_u8_max_def] Definition + + ⊢ core_u8_max = int_to_u8 u8_max + + [core_u8_min_def] Definition + + ⊢ core_u8_min = int_to_u8 0 + + [core_usize_max_def] Definition + + ⊢ core_usize_max = int_to_usize usize_max + + [core_usize_min_def] Definition + + ⊢ core_usize_min = int_to_usize 0 + [error_BIJ] Definition ⊢ (∀a. num2error (error2num a) = a) ∧ diff --git a/backends/lean/Base/IList/IList.lean b/backends/lean/Base/IList/IList.lean index a940da25..79de93d5 100644 --- a/backends/lean/Base/IList/IList.lean +++ b/backends/lean/Base/IList/IList.lean @@ -112,7 +112,19 @@ def pairwise_rel section Lemmas -variable {α : Type u} +variable {α : Type u} + +def ireplicate {α : Type u} (i : ℤ) (x : α) : List α := + if i ≤ 0 then [] + else x :: ireplicate (i - 1) x +termination_by ireplicate i x => i.toNat +decreasing_by + simp_wf + -- TODO: simplify this kind of proofs + simp at * + have : 0 ≤ i := by linarith + have : 1 ≤ i := by linarith + simp [Int.toNat_sub_of_le, *] @[simp] theorem update_nil : update ([] : List α) i y = [] := by simp [update] @[simp] theorem update_zero_cons : update ((x :: tl) : List α) 0 y = y :: tl := by simp [update] @@ -129,6 +141,10 @@ variable {α : Type u} @[simp] theorem slice_nil : slice i j ([] : List α) = [] := by simp [slice] @[simp] theorem slice_zero : slice 0 0 (ls : List α) = [] := by cases ls <;> simp [slice] +@[simp] theorem ireplicate_zero : ireplicate 0 x = [] := by rw [ireplicate]; simp +@[simp] theorem ireplicate_nzero_cons (hne : 0 < i) : ireplicate i x = x :: ireplicate (i - 1) x := by + rw [ireplicate]; simp [*]; intro; linarith + @[simp] theorem slice_nzero_cons (i j : Int) (x : α) (tl : List α) (hne : i ≠ 0) : slice i j ((x :: tl) : List α) = slice (i - 1) (j - 1) tl := match tl with @@ -144,6 +160,45 @@ theorem slice_nzero_cons (i j : Int) (x : α) (tl : List α) (hne : i ≠ 0) : s conv at this => lhs; simp [slice, *] simp [*, slice] +@[simp] +theorem ireplicate_replicate {α : Type u} (l : ℤ) (x : α) (h : 0 ≤ l) : + ireplicate l x = replicate l.toNat x := + if hz: l = 0 then by + simp [*] + else by + have : 0 < l := by int_tac + have hr := ireplicate_replicate (l - 1) x (by int_tac) + simp [*] + have hl : l.toNat = .succ (l.toNat - 1) := by + cases hl: l.toNat <;> simp_all + conv => rhs; rw[hl] +termination_by ireplicate_replicate l x h => l.toNat +decreasing_by + simp_wf + -- TODO: simplify this kind of proofs + simp at * + have : 0 ≤ l := by linarith + have : 1 ≤ l := by linarith + simp [Int.toNat_sub_of_le, *] + +@[simp] +theorem ireplicate_len {α : Type u} (l : ℤ) (x : α) (h : 0 ≤ l) : + (ireplicate l x).len = l := + if hz: l = 0 then by + simp [*] + else by + have : 0 < l := by int_tac + have hr := ireplicate_len (l - 1) x (by int_tac) + simp [*] +termination_by ireplicate_len l x h => l.toNat +decreasing_by + simp_wf + -- TODO: simplify this kind of proofs + simp at * + have : 0 ≤ l := by linarith + have : 1 ≤ l := by linarith + simp [Int.toNat_sub_of_le, *] + theorem len_eq_length (ls : List α) : ls.len = ls.length := by induction ls . rfl diff --git a/backends/lean/Base/Primitives/Array.lean b/backends/lean/Base/Primitives/Array.lean index 6c95fd78..49c84bee 100644 --- a/backends/lean/Base/Primitives/Array.lean +++ b/backends/lean/Base/Primitives/Array.lean @@ -51,6 +51,15 @@ def Array.index_shared (α : Type u) (n : Usize) (v: Array α n) (i: Usize) : Re | none => fail .arrayOutOfBounds | some x => ret x +-- For initialization +def Array.repeat (α : Type u) (n : Usize) (x : α) : Array α n := + ⟨ List.ireplicate n.val x, by have h := n.hmin; simp_all [Scalar.min] ⟩ + +@[pspec] +theorem Array.repeat_spec {α : Type u} (n : Usize) (x : α) : + ∃ a, Array.repeat α n x = a ∧ a.val = List.ireplicate n.val x := by + simp [Array.repeat] + /- In the theorems below: we don't always need the `∃ ..`, but we use one so that `progress` introduces an opaque variable and an equality. This helps control the context. diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index 55227a9f..ec9665a5 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -230,6 +230,20 @@ def Scalar.cMax (ty : ScalarTy) : Int := | .Usize => Scalar.max .U32 | _ => Scalar.max ty +theorem Scalar.min_lt_max (ty : ScalarTy) : Scalar.min ty < Scalar.max ty := by + cases ty <;> simp [Scalar.min, Scalar.max] + . simp [Isize.min, Isize.max] + have h1 := Isize.refined_min.property + have h2 := Isize.refined_max.property + cases h1 <;> cases h2 <;> simp [*] + . simp [Usize.max] + have h := Usize.refined_max.property + cases h <;> simp [*] + +theorem Scalar.min_le_max (ty : ScalarTy) : Scalar.min ty ≤ Scalar.max ty := by + have := Scalar.min_lt_max ty + int_tac + theorem Scalar.cMin_bound ty : Scalar.min ty ≤ Scalar.cMin ty := by cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at * have h := Isize.refined_min.property @@ -395,6 +409,34 @@ def Scalar.cast {src_ty : ScalarTy} (tgt_ty : ScalarTy) (x : Scalar src_ty) : Re @[reducible] def U64 := Scalar .U64 @[reducible] def U128 := Scalar .U128 +-- TODO: reducible? +@[reducible] def core_isize_min : Isize := Scalar.ofInt Isize.min (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Isize)) +@[reducible] def core_isize_max : Isize := Scalar.ofInt Isize.max (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Isize)) +@[reducible] def core_i8_min : I8 := Scalar.ofInt I8.min +@[reducible] def core_i8_max : I8 := Scalar.ofInt I8.max +@[reducible] def core_i16_min : I16 := Scalar.ofInt I16.min +@[reducible] def core_i16_max : I16 := Scalar.ofInt I16.max +@[reducible] def core_i32_min : I32 := Scalar.ofInt I32.min +@[reducible] def core_i32_max : I32 := Scalar.ofInt I32.max +@[reducible] def core_i64_min : I64 := Scalar.ofInt I64.min +@[reducible] def core_i64_max : I64 := Scalar.ofInt I64.max +@[reducible] def core_i128_min : I128 := Scalar.ofInt I128.min +@[reducible] def core_i128_max : I128 := Scalar.ofInt I128.max + +-- TODO: reducible? +@[reducible] def core_usize_min : Usize := Scalar.ofInt Usize.min +@[reducible] def core_usize_max : Usize := Scalar.ofInt Usize.max (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Usize)) +@[reducible] def core_u8_min : U8 := Scalar.ofInt U8.min +@[reducible] def core_u8_max : U8 := Scalar.ofInt U8.max +@[reducible] def core_u16_min : U16 := Scalar.ofInt U16.min +@[reducible] def core_u16_max : U16 := Scalar.ofInt U16.max +@[reducible] def core_u32_min : U32 := Scalar.ofInt U32.min +@[reducible] def core_u32_max : U32 := Scalar.ofInt U32.max +@[reducible] def core_u64_min : U64 := Scalar.ofInt U64.min +@[reducible] def core_u64_max : U64 := Scalar.ofInt U64.max +@[reducible] def core_u128_min : U128 := Scalar.ofInt U128.min +@[reducible] def core_u128_max : U128 := Scalar.ofInt U128.max + -- TODO: below: not sure this is the best way. -- Should we rather overload operations like +, -, etc.? -- Also, it is possible to automate the generation of those definitions @@ -861,33 +903,33 @@ theorem Scalar.rem_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : S -- ofIntCore -- TODO: typeclass? -@[reducible] def Isize.ofIntCore := @Scalar.ofIntCore .Isize -@[reducible] def I8.ofIntCore := @Scalar.ofIntCore .I8 -@[reducible] def I16.ofIntCore := @Scalar.ofIntCore .I16 -@[reducible] def I32.ofIntCore := @Scalar.ofIntCore .I32 -@[reducible] def I64.ofIntCore := @Scalar.ofIntCore .I64 -@[reducible] def I128.ofIntCore := @Scalar.ofIntCore .I128 -@[reducible] def Usize.ofIntCore := @Scalar.ofIntCore .Usize -@[reducible] def U8.ofIntCore := @Scalar.ofIntCore .U8 -@[reducible] def U16.ofIntCore := @Scalar.ofIntCore .U16 -@[reducible] def U32.ofIntCore := @Scalar.ofIntCore .U32 -@[reducible] def U64.ofIntCore := @Scalar.ofIntCore .U64 -@[reducible] def U128.ofIntCore := @Scalar.ofIntCore .U128 +def Isize.ofIntCore := @Scalar.ofIntCore .Isize +def I8.ofIntCore := @Scalar.ofIntCore .I8 +def I16.ofIntCore := @Scalar.ofIntCore .I16 +def I32.ofIntCore := @Scalar.ofIntCore .I32 +def I64.ofIntCore := @Scalar.ofIntCore .I64 +def I128.ofIntCore := @Scalar.ofIntCore .I128 +def Usize.ofIntCore := @Scalar.ofIntCore .Usize +def U8.ofIntCore := @Scalar.ofIntCore .U8 +def U16.ofIntCore := @Scalar.ofIntCore .U16 +def U32.ofIntCore := @Scalar.ofIntCore .U32 +def U64.ofIntCore := @Scalar.ofIntCore .U64 +def U128.ofIntCore := @Scalar.ofIntCore .U128 -- ofInt -- TODO: typeclass? -@[reducible] def Isize.ofInt := @Scalar.ofInt .Isize -@[reducible] def I8.ofInt := @Scalar.ofInt .I8 -@[reducible] def I16.ofInt := @Scalar.ofInt .I16 -@[reducible] def I32.ofInt := @Scalar.ofInt .I32 -@[reducible] def I64.ofInt := @Scalar.ofInt .I64 -@[reducible] def I128.ofInt := @Scalar.ofInt .I128 -@[reducible] def Usize.ofInt := @Scalar.ofInt .Usize -@[reducible] def U8.ofInt := @Scalar.ofInt .U8 -@[reducible] def U16.ofInt := @Scalar.ofInt .U16 -@[reducible] def U32.ofInt := @Scalar.ofInt .U32 -@[reducible] def U64.ofInt := @Scalar.ofInt .U64 -@[reducible] def U128.ofInt := @Scalar.ofInt .U128 +abbrev Isize.ofInt := @Scalar.ofInt .Isize +abbrev I8.ofInt := @Scalar.ofInt .I8 +abbrev I16.ofInt := @Scalar.ofInt .I16 +abbrev I32.ofInt := @Scalar.ofInt .I32 +abbrev I64.ofInt := @Scalar.ofInt .I64 +abbrev I128.ofInt := @Scalar.ofInt .I128 +abbrev Usize.ofInt := @Scalar.ofInt .Usize +abbrev U8.ofInt := @Scalar.ofInt .U8 +abbrev U16.ofInt := @Scalar.ofInt .U16 +abbrev U32.ofInt := @Scalar.ofInt .U32 +abbrev U64.ofInt := @Scalar.ofInt .U64 +abbrev U128.ofInt := @Scalar.ofInt .U128 postfix:max "#isize" => Isize.ofInt postfix:max "#i8" => I8.ofInt @@ -905,9 +947,46 @@ postfix:max "#u128" => U128.ofInt -- Testing the notations example : Result Usize := 0#usize + 1#usize +-- TODO: factor those lemmas out @[simp] theorem Scalar.ofInt_val_eq {ty} (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : (Scalar.ofInt x h).val = x := by simp [Scalar.ofInt, Scalar.ofIntCore] +@[simp] theorem Isize.ofInt_val_eq (h : Scalar.min ScalarTy.Isize ≤ x ∧ x ≤ Scalar.max ScalarTy.Isize) : (Isize.ofInt x h).val = x := by + apply Scalar.ofInt_val_eq h + +@[simp] theorem I8.ofInt_val_eq (h : Scalar.min ScalarTy.I8 ≤ x ∧ x ≤ Scalar.max ScalarTy.I8) : (I8.ofInt x h).val = x := by + apply Scalar.ofInt_val_eq h + +@[simp] theorem I16.ofInt_val_eq (h : Scalar.min ScalarTy.I16 ≤ x ∧ x ≤ Scalar.max ScalarTy.I16) : (I16.ofInt x h).val = x := by + apply Scalar.ofInt_val_eq h + +@[simp] theorem I32.ofInt_val_eq (h : Scalar.min ScalarTy.I32 ≤ x ∧ x ≤ Scalar.max ScalarTy.I32) : (I32.ofInt x h).val = x := by + apply Scalar.ofInt_val_eq h + +@[simp] theorem I64.ofInt_val_eq (h : Scalar.min ScalarTy.I64 ≤ x ∧ x ≤ Scalar.max ScalarTy.I64) : (I64.ofInt x h).val = x := by + apply Scalar.ofInt_val_eq h + +@[simp] theorem I128.ofInt_val_eq (h : Scalar.min ScalarTy.I128 ≤ x ∧ x ≤ Scalar.max ScalarTy.I128) : (I128.ofInt x h).val = x := by + apply Scalar.ofInt_val_eq h + +@[simp] theorem Usize.ofInt_val_eq (h : Scalar.min ScalarTy.Usize ≤ x ∧ x ≤ Scalar.max ScalarTy.Usize) : (Usize.ofInt x h).val = x := by + apply Scalar.ofInt_val_eq h + +@[simp] theorem U8.ofInt_val_eq (h : Scalar.min ScalarTy.U8 ≤ x ∧ x ≤ Scalar.max ScalarTy.U8) : (U8.ofInt x h).val = x := by + apply Scalar.ofInt_val_eq h + +@[simp] theorem U16.ofInt_val_eq (h : Scalar.min ScalarTy.U16 ≤ x ∧ x ≤ Scalar.max ScalarTy.U16) : (U16.ofInt x h).val = x := by + apply Scalar.ofInt_val_eq h + +@[simp] theorem U32.ofInt_val_eq (h : Scalar.min ScalarTy.U32 ≤ x ∧ x ≤ Scalar.max ScalarTy.U32) : (U32.ofInt x h).val = x := by + apply Scalar.ofInt_val_eq h + +@[simp] theorem U64.ofInt_val_eq (h : Scalar.min ScalarTy.U64 ≤ x ∧ x ≤ Scalar.max ScalarTy.U64) : (U64.ofInt x h).val = x := by + apply Scalar.ofInt_val_eq h + +@[simp] theorem U128.ofInt_val_eq (h : Scalar.min ScalarTy.U128 ≤ x ∧ x ≤ Scalar.max ScalarTy.U128) : (U128.ofInt x h).val = x := by + apply Scalar.ofInt_val_eq h + -- Comparisons instance {ty} : LT (Scalar ty) where lt a b := LT.lt a.val b.val diff --git a/compiler/AssociatedTypes.ml b/compiler/AssociatedTypes.ml new file mode 100644 index 00000000..992dade9 --- /dev/null +++ b/compiler/AssociatedTypes.ml @@ -0,0 +1,577 @@ +(** This file implements utilities to handle trait associated types, in + particular with normalization helpers. + + When normalizing a type, we simplify the references to the trait associated + types, and choose a representative when there are equalities between types + enforced by local clauses (i.e., clauses of the shape [where Trait1::T = Trait2::U]). + *) + +module T = Types +module TU = TypesUtils +module V = Values +module E = Expressions +module A = LlbcAst +module C = Contexts +module Subst = Substitute +module L = Logging +module UF = UnionFind +module PA = Print.EvalCtxLlbcAst + +(** The local logger *) +let log = L.associated_types_log + +let trait_type_ref_substitute (subst : ('r, 'r1) Subst.subst) + (r : 'r C.trait_type_ref) : 'r1 C.trait_type_ref = + let { C.trait_ref; type_name } = r in + let trait_ref = Subst.trait_ref_substitute subst trait_ref in + { C.trait_ref; type_name } + +(* TODO: how not to duplicate below? *) +module RTyOrd = struct + type t = T.rty + + let compare = T.compare_rty + let to_string = T.show_rty + let pp_t = T.pp_rty + let show_t = T.show_rty +end + +module STyOrd = struct + type t = T.sty + + let compare = T.compare_sty + let to_string = T.show_sty + let pp_t = T.pp_sty + let show_t = T.show_sty +end + +module RTyMap = Collections.MakeMap (RTyOrd) +module STyMap = Collections.MakeMap (STyOrd) + +(* TODO: is it possible not to have this? *) +module type TypeWrapper = sig + type t +end + +(* TODO: don't manage to get the syntax right so using a functor *) +module MakeNormalizer + (R : TypeWrapper) + (RTyMap : Collections.Map with type key = R.t T.region T.ty) + (M : Collections.Map with type key = R.t T.region C.trait_type_ref) = +struct + let compute_norm_trait_types_from_preds + (trait_type_constraints : R.t T.region T.trait_type_constraint list) : + R.t T.region T.ty M.t = + (* Compute a union-find structure by recursively exploring the predicates and clauses *) + let norm : R.t T.region T.ty UF.elem RTyMap.t ref = ref RTyMap.empty in + let get_ref (ty : R.t T.region T.ty) : R.t T.region T.ty UF.elem = + match RTyMap.find_opt ty !norm with + | Some r -> r + | None -> + let r = UF.make ty in + norm := RTyMap.add ty r !norm; + r + in + let add_trait_type_constraint (c : R.t T.region T.trait_type_constraint) = + let trait_ty = T.TraitType (c.trait_ref, c.generics, c.type_name) in + let trait_ty_ref = get_ref trait_ty in + let ty_ref = get_ref c.ty in + let new_repr = UF.get ty_ref in + let merged = UF.union trait_ty_ref ty_ref in + (* Not sure the set operation is necessary, but I want to control which + representative is chosen *) + UF.set merged new_repr + in + (* Explore the local predicates *) + List.iter add_trait_type_constraint trait_type_constraints; + (* TODO: explore the local clauses *) + (* Compute the norm maps *) + let rbindings = + List.map (fun (k, v) -> (k, UF.get v)) (RTyMap.bindings !norm) + in + (* Filter the keys to keep only the trait type aliases *) + let rbindings = + List.filter_map + (fun (k, v) -> + match k with + | T.TraitType (trait_ref, generics, type_name) -> + assert (generics = TypesUtils.mk_empty_generic_args); + Some ({ C.trait_ref; type_name }, v) + | _ -> None) + rbindings + in + M.of_list rbindings +end + +(** Compute the representative classes of trait associated types, for normalization *) +let compute_norm_trait_stypes_from_preds + (trait_type_constraints : T.strait_type_constraint list) : + T.sty C.STraitTypeRefMap.t = + (* Compute the normalization map for the types with regions *) + let module R = struct + type t = T.region_var_id + end in + let module M = C.STraitTypeRefMap in + let module Norm = MakeNormalizer (R) (STyMap) (M) in + Norm.compute_norm_trait_types_from_preds trait_type_constraints + +(** Compute the representative classes of trait associated types, for normalization *) +let compute_norm_trait_types_from_preds + (trait_type_constraints : T.rtrait_type_constraint list) : + T.ety C.ETraitTypeRefMap.t * T.rty C.RTraitTypeRefMap.t = + (* Compute the normalization map for the types with regions *) + let module R = struct + type t = T.region_id + end in + let module M = C.RTraitTypeRefMap in + let module Norm = MakeNormalizer (R) (RTyMap) (M) in + let rbindings = + Norm.compute_norm_trait_types_from_preds trait_type_constraints + in + (* Compute the normalization map for the types with erased regions *) + let ebindings = + List.map + (fun (k, v) -> + ( trait_type_ref_substitute Subst.erase_regions_subst k, + Subst.erase_regions v )) + (M.bindings rbindings) + in + (C.ETraitTypeRefMap.of_list ebindings, rbindings) + +let ctx_add_norm_trait_stypes_from_preds (ctx : C.eval_ctx) + (trait_type_constraints : T.strait_type_constraint list) : C.eval_ctx = + let norm_trait_stypes = + compute_norm_trait_stypes_from_preds trait_type_constraints + in + { ctx with C.norm_trait_stypes } + +let ctx_add_norm_trait_types_from_preds (ctx : C.eval_ctx) + (trait_type_constraints : T.rtrait_type_constraint list) : C.eval_ctx = + let norm_trait_etypes, norm_trait_rtypes = + compute_norm_trait_types_from_preds trait_type_constraints + in + { ctx with C.norm_trait_etypes; norm_trait_rtypes } + +(** A trait instance id refers to a local clause if it only uses the variants: + [Self], [Clause], [ParentClause], [ItemClause] *) +let rec trait_instance_id_is_local_clause (id : 'r T.trait_instance_id) : bool = + match id with + | T.Self | Clause _ -> true + | TraitImpl _ | BuiltinOrAuto _ | TraitRef _ | UnknownTrait _ -> false + | ParentClause (id, _, _) | ItemClause (id, _, _, _) -> + trait_instance_id_is_local_clause id + +(** About the conversion functions: for now we need them (TODO: merge ety, rty, etc.), + but they should be applied to types without regions. + *) +type 'r norm_ctx = { + ctx : C.eval_ctx; + get_ty_repr : 'r C.trait_type_ref -> 'r T.ty option; + convert_ety : T.ety -> 'r T.ty; + convert_etrait_ref : T.etrait_ref -> 'r T.trait_ref; + ty_to_string : 'r T.ty -> string; + trait_ref_to_string : 'r T.trait_ref -> string; + trait_instance_id_to_string : 'r T.trait_instance_id -> string; + pp_r : Format.formatter -> 'r -> unit; +} + +(** Normalize a type by simplyfying the references to trait associated types + and choosing a representative when there are equalities between types + enforced by local clauses (i.e., `where Trait1::T = Trait2::U`. *) +let rec ctx_normalize_ty : 'r. 'r norm_ctx -> 'r T.ty -> 'r T.ty = + fun ctx ty -> + log#ldebug (lazy ("ctx_normalize_ty: " ^ ctx.ty_to_string ty)); + match ty with + | T.Adt (id, generics) -> Adt (id, ctx_normalize_generic_args ctx generics) + | TypeVar _ | Literal _ | Never -> ty + | Ref (r, ty, rkind) -> + let ty = ctx_normalize_ty ctx ty in + T.Ref (r, ty, rkind) + | TraitType (trait_ref, generics, type_name) -> ( + log#ldebug + (lazy + ("ctx_normalize_ty: trait type: " ^ ctx.ty_to_string ty + ^ "\n- trait_ref: " + ^ ctx.trait_ref_to_string trait_ref + ^ "\n- raw trait ref:\n" + ^ T.show_trait_ref ctx.pp_r trait_ref)); + (* Normalize and attempt to project the type from the trait ref *) + let trait_ref = ctx_normalize_trait_ref ctx trait_ref in + let generics = ctx_normalize_generic_args ctx generics in + let ty : 'r T.ty = + match trait_ref.trait_id with + | T.TraitRef + { T.trait_id = T.TraitImpl impl_id; generics = ref_generics; _ } -> + assert (ref_generics = TypesUtils.mk_empty_generic_args); + log#ldebug + (lazy + ("ctx_normalize_ty: trait type: trait ref: " + ^ ctx.ty_to_string ty)); + (* Lookup the implementation *) + let trait_impl = C.ctx_lookup_trait_impl ctx.ctx impl_id in + (* Lookup the type *) + let ty = snd (List.assoc type_name trait_impl.types) in + (* Annoying: convert etype to an stype - TODO: hwo to avoid that? *) + let ty : T.sty = TypesUtils.ety_no_regions_to_gr_ty ty in + (* Substitute *) + let tr_self = T.UnknownTrait __FUNCTION__ in + let subst = + Subst.make_subst_from_generics_no_regions trait_impl.generics + generics tr_self + in + let ty = Subst.ty_substitute subst ty in + (* Reconvert *) + let ty : 'r T.ty = ctx.convert_ety (Subst.erase_regions ty) in + (* Normalize *) + ctx_normalize_ty ctx ty + | T.TraitImpl impl_id -> + (* This happens. This doesn't come from the substituations + performed by Aeneas (the [TraitImpl] would be wrapped in a + [TraitRef] but from non-normalized traits translated from + the Rustc AST. + TODO: factor out with the branch above. + *) + (* Lookup the implementation *) + let trait_impl = C.ctx_lookup_trait_impl ctx.ctx impl_id in + (* Lookup the type *) + let ty = snd (List.assoc type_name trait_impl.types) in + (* Annoying: convert etype to an stype - TODO: hwo to avoid that? *) + let ty : T.sty = TypesUtils.ety_no_regions_to_gr_ty ty in + (* Substitute *) + let tr_self = T.UnknownTrait __FUNCTION__ in + let subst = + Subst.make_subst_from_generics_no_regions trait_impl.generics + generics tr_self + in + let ty = Subst.ty_substitute subst ty in + (* Reconvert *) + let ty : 'r T.ty = ctx.convert_ety (Subst.erase_regions ty) in + (* Normalize *) + ctx_normalize_ty ctx ty + | _ -> + log#ldebug + (lazy + ("ctx_normalize_ty: trait type: not a trait ref: " + ^ ctx.ty_to_string ty ^ "\n- trait_ref: " + ^ ctx.trait_ref_to_string trait_ref + ^ "\n- raw trait ref:\n" + ^ T.show_trait_ref ctx.pp_r trait_ref)); + (* We can't project *) + assert (trait_instance_id_is_local_clause trait_ref.trait_id); + T.TraitType (trait_ref, generics, type_name) + in + let tr : 'r C.trait_type_ref = { C.trait_ref; type_name } in + (* Lookup the representative, if there is *) + match ctx.get_ty_repr tr with None -> ty | Some ty -> ty) + +(** This returns the normalized trait instance id together with an optional + reference to a trait **implementation**. + + We need this in particular to simplify the trait instance ids after we + performed a substitution. + + Example: + ======== + {[ + trait Trait { + type S + } + + impl TraitImpl for Foo { + type S = usize + } + + fn f<T : Trait>(...) -> T::S; + + ... + let x = f<Foo>[TraitImpl](...); // T::S ~~> TraitImpl::S ~~> usize + ]} + + Several remarks: + - as we do not allow higher-order types (yet) then local clauses (and + sub-clauses) can't have generic arguments + - the [TraitRef] case only happens because of substitution, the role of + the normalization is in particular to eliminate it. Inside a [TraitRef] + there is necessarily: + - an id referencing a local (sub-)clause, that is an id using the variants + [Self], [Clause], [ItemClause] and [ParentClause] exclusively. We can't + simplify those cases: all we can do is remove the [TraitRef] wrapper + by leveraging the fact that the generic arguments must be empty. + - a [TraitImpl]. Note that the [TraitImpl] is necessarily just a [TraitImpl], + it can't be for instance a [ParentClause(TraitImpl ...)] because the + trait resolution would then directly reference the implementation + designated by [ParentClause(TraitImpl ...)] (and same for the other cases). + In this case we can lookup the trait implementation and recursively project + over it. + *) +and ctx_normalize_trait_instance_id : + 'r. + 'r norm_ctx -> + 'r T.trait_instance_id -> + 'r T.trait_instance_id * 'r T.trait_ref option = + fun ctx id -> + match id with + | Self -> (id, None) + | TraitImpl _ -> + (* The [TraitImpl] shouldn't be inside any projection - we check this + elsewhere by asserting that whenever we return [None] for the impl + trait ref, then the id actually refers to a local clause. *) + (id, None) + | Clause _ -> (id, None) + | BuiltinOrAuto _ -> (id, None) + | ParentClause (inst_id, decl_id, clause_id) -> ( + let inst_id, impl = ctx_normalize_trait_instance_id ctx inst_id in + (* Check if the inst_id refers to a specific implementation, if yes project *) + match impl with + | None -> + (* This is actually a local clause *) + assert (trait_instance_id_is_local_clause inst_id); + (ParentClause (inst_id, decl_id, clause_id), None) + | Some impl -> + (* We figure out the parent clause by doing the following: + {[ + // The implementation we are looking at + impl Impl1 : Trait1 { ... } + + // Check the trait it implements + trait Trait1 : ParentTrait1 + ParentTrait2 { ... } + ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + those are the parent clauses + ]} + + We can find the parent clauses in the [trait_decl_ref] field, which + tells us which specific instantiation of [Trait1] is implemented + by [Impl1]. + *) + let clause = + T.TraitClauseId.nth impl.trait_decl_ref.decl_generics.trait_refs + clause_id + in + (* Sanity check: the clause necessarily refers to an impl *) + let _ = TypesUtils.trait_instance_id_as_trait_impl clause.trait_id in + (TraitRef clause, Some clause)) + | ItemClause (inst_id, decl_id, item_name, clause_id) -> ( + let inst_id, impl = ctx_normalize_trait_instance_id ctx inst_id in + (* Check if the inst_id refers to a specific implementation, if yes project *) + match impl with + | None -> + (* This is actually a local clause *) + assert (trait_instance_id_is_local_clause inst_id); + (ParentClause (inst_id, decl_id, clause_id), None) + | Some impl -> + (* We figure out the item clause by doing the following: + {[ + // The implementation we are looking at + impl Impl1 : Trait1<R> { + type S = ... + with Impl2 : Trait2 ... // Instances satisfying the declared bounds + ^^^^^^^^^^^^^^^^^^ + Lookup the clause from here + } + ]} + *) + (* The referenced instance should be an impl *) + let impl_id = + TypesUtils.trait_instance_id_as_trait_impl impl.trait_id + in + let trait_impl = C.ctx_lookup_trait_impl ctx.ctx impl_id in + (* Lookup the clause *) + let item = List.assoc item_name trait_impl.types in + let clause = T.TraitClauseId.nth (fst item) clause_id in + (* Sanity check: the clause necessarily refers to an impl *) + let _ = TypesUtils.trait_instance_id_as_trait_impl clause.trait_id in + (* We need to convert the clause type - + TODO: we have too many problems with those conversions, we should + merge the types. *) + let clause = ctx.convert_etrait_ref clause in + (TraitRef clause, Some clause)) + | TraitRef { T.trait_id = T.TraitImpl trait_id; generics; trait_decl_ref } -> + (* We can't simplify the id *yet* : we will simplify it when projecting. + However, we have an implementation to return *) + (* Normalize the generics *) + let generics = ctx_normalize_generic_args ctx generics in + let trait_decl_ref = ctx_normalize_trait_decl_ref ctx trait_decl_ref in + let trait_ref : 'r T.trait_ref = + { T.trait_id = T.TraitImpl trait_id; generics; trait_decl_ref } + in + (TraitRef trait_ref, Some trait_ref) + | TraitRef trait_ref -> + (* The trait instance id necessarily refers to a local sub-clause. We + can't project over it and can only peel off the [TraitRef] wrapper *) + assert (trait_instance_id_is_local_clause trait_ref.trait_id); + assert (trait_ref.generics = TypesUtils.mk_empty_generic_args); + (trait_ref.trait_id, None) + | UnknownTrait _ -> + (* This is actually an error case *) + (id, None) + +and ctx_normalize_generic_args (ctx : 'r norm_ctx) + (generics : 'r T.generic_args) : 'r T.generic_args = + let { T.regions; types; const_generics; trait_refs } = generics in + let types = List.map (ctx_normalize_ty ctx) types in + let trait_refs = List.map (ctx_normalize_trait_ref ctx) trait_refs in + { T.regions; types; const_generics; trait_refs } + +and ctx_normalize_trait_ref (ctx : 'r norm_ctx) (trait_ref : 'r T.trait_ref) : + 'r T.trait_ref = + log#ldebug + (lazy + ("ctx_normalize_trait_ref: " + ^ ctx.trait_ref_to_string trait_ref + ^ "\n- raw trait ref:\n" + ^ T.show_trait_ref ctx.pp_r trait_ref)); + let { T.trait_id; generics; trait_decl_ref } = trait_ref in + (* Check if the id is an impl, otherwise normalize it *) + let trait_id, norm_trait_ref = ctx_normalize_trait_instance_id ctx trait_id in + match norm_trait_ref with + | None -> + log#ldebug + (lazy + ("ctx_normalize_trait_ref: no norm: " + ^ ctx.trait_instance_id_to_string trait_id)); + let generics = ctx_normalize_generic_args ctx generics in + let trait_decl_ref = ctx_normalize_trait_decl_ref ctx trait_decl_ref in + { T.trait_id; generics; trait_decl_ref } + | Some trait_ref -> + log#ldebug + (lazy + ("ctx_normalize_trait_ref: normalized to: " + ^ ctx.trait_ref_to_string trait_ref)); + assert (generics = TypesUtils.mk_empty_generic_args); + trait_ref + +(* Not sure this one is really necessary *) +and ctx_normalize_trait_decl_ref (ctx : 'r norm_ctx) + (trait_decl_ref : 'r T.trait_decl_ref) : 'r T.trait_decl_ref = + let { T.trait_decl_id; decl_generics } = trait_decl_ref in + let decl_generics = ctx_normalize_generic_args ctx decl_generics in + { T.trait_decl_id; decl_generics } + +let ctx_normalize_trait_type_constraint (ctx : 'r norm_ctx) + (ttc : 'r T.trait_type_constraint) : 'r T.trait_type_constraint = + let { T.trait_ref; generics; type_name; ty } = ttc in + let trait_ref = ctx_normalize_trait_ref ctx trait_ref in + let generics = ctx_normalize_generic_args ctx generics in + let ty = ctx_normalize_ty ctx ty in + { T.trait_ref; generics; type_name; ty } + +let mk_snorm_ctx (ctx : C.eval_ctx) : T.RegionVarId.id T.region norm_ctx = + let get_ty_repr x = C.STraitTypeRefMap.find_opt x ctx.norm_trait_stypes in + { + ctx; + get_ty_repr; + convert_ety = TypesUtils.ety_no_regions_to_sty; + convert_etrait_ref = TypesUtils.etrait_ref_no_regions_to_gr_trait_ref; + ty_to_string = PA.sty_to_string ctx; + trait_ref_to_string = PA.strait_ref_to_string ctx; + trait_instance_id_to_string = PA.strait_instance_id_to_string ctx; + pp_r = T.pp_region T.pp_region_var_id; + } + +let mk_rnorm_ctx (ctx : C.eval_ctx) : T.RegionId.id T.region norm_ctx = + let get_ty_repr x = C.RTraitTypeRefMap.find_opt x ctx.norm_trait_rtypes in + { + ctx; + get_ty_repr; + convert_ety = TypesUtils.ety_no_regions_to_rty; + convert_etrait_ref = TypesUtils.etrait_ref_no_regions_to_gr_trait_ref; + ty_to_string = PA.rty_to_string ctx; + trait_ref_to_string = PA.rtrait_ref_to_string ctx; + trait_instance_id_to_string = PA.rtrait_instance_id_to_string ctx; + pp_r = T.pp_region T.pp_region_id; + } + +let mk_enorm_ctx (ctx : C.eval_ctx) : T.erased_region norm_ctx = + let get_ty_repr x = C.ETraitTypeRefMap.find_opt x ctx.norm_trait_etypes in + { + ctx; + get_ty_repr; + convert_ety = (fun x -> x); + convert_etrait_ref = (fun x -> x); + ty_to_string = PA.ety_to_string ctx; + trait_ref_to_string = PA.etrait_ref_to_string ctx; + trait_instance_id_to_string = PA.etrait_instance_id_to_string ctx; + pp_r = T.pp_erased_region; + } + +let ctx_normalize_sty (ctx : C.eval_ctx) (ty : T.sty) : T.sty = + ctx_normalize_ty (mk_snorm_ctx ctx) ty + +let ctx_normalize_rty (ctx : C.eval_ctx) (ty : T.rty) : T.rty = + ctx_normalize_ty (mk_rnorm_ctx ctx) ty + +let ctx_normalize_ety (ctx : C.eval_ctx) (ty : T.ety) : T.ety = + ctx_normalize_ty (mk_enorm_ctx ctx) ty + +let ctx_normalize_rtrait_type_constraint (ctx : C.eval_ctx) + (ttc : T.rtrait_type_constraint) : T.rtrait_type_constraint = + ctx_normalize_trait_type_constraint (mk_rnorm_ctx ctx) ttc + +(** Same as [type_decl_get_instantiated_variants_fields_rtypes] but normalizes the types *) +let type_decl_get_inst_norm_variants_fields_rtypes (ctx : C.eval_ctx) + (def : T.type_decl) (generics : T.rgeneric_args) : + (T.VariantId.id option * T.rty list) list = + let res = + Subst.type_decl_get_instantiated_variants_fields_rtypes def generics + in + List.map + (fun (variant_id, types) -> + (variant_id, List.map (ctx_normalize_rty ctx) types)) + res + +(** Same as [type_decl_get_instantiated_field_rtypes] but normalizes the types *) +let type_decl_get_inst_norm_field_rtypes (ctx : C.eval_ctx) (def : T.type_decl) + (opt_variant_id : T.VariantId.id option) (generics : T.rgeneric_args) : + T.rty list = + let types = + Subst.type_decl_get_instantiated_field_rtypes def opt_variant_id generics + in + List.map (ctx_normalize_rty ctx) types + +(** Same as [ctx_adt_value_get_instantiated_field_rtypes] but normalizes the types *) +let ctx_adt_value_get_inst_norm_field_rtypes (ctx : C.eval_ctx) + (adt : V.adt_value) (id : T.type_id) (generics : T.rgeneric_args) : + T.rty list = + let types = + Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id generics + in + List.map (ctx_normalize_rty ctx) types + +(** Same as [ctx_adt_value_get_instantiated_field_etypes] but normalizes the types *) +let type_decl_get_inst_norm_field_etypes (ctx : C.eval_ctx) (def : T.type_decl) + (opt_variant_id : T.VariantId.id option) (generics : T.egeneric_args) : + T.ety list = + let types = + Subst.type_decl_get_instantiated_field_etypes def opt_variant_id generics + in + List.map (ctx_normalize_ety ctx) types + +(** Same as [ctx_adt_get_instantiated_field_etypes] but normalizes the types *) +let ctx_adt_get_inst_norm_field_etypes (ctx : C.eval_ctx) + (def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option) + (generics : T.egeneric_args) : T.ety list = + let types = + Subst.ctx_adt_get_instantiated_field_etypes ctx def_id opt_variant_id + generics + in + List.map (ctx_normalize_ety ctx) types + +(** Same as [substitute_signature] but normalizes the types *) +let ctx_subst_norm_signature (ctx : C.eval_ctx) + (asubst : T.RegionGroupId.id -> V.AbstractionId.id) + (r_subst : T.RegionVarId.id -> T.RegionId.id) + (ty_subst : T.TypeVarId.id -> T.rty) + (cg_subst : T.ConstGenericVarId.id -> T.const_generic) + (tr_subst : T.TraitClauseId.id -> T.rtrait_instance_id) + (tr_self : T.rtrait_instance_id) (sg : A.fun_sig) : A.inst_fun_sig = + let sg = + Subst.substitute_signature asubst r_subst ty_subst cg_subst tr_subst tr_self + sg + in + let { A.regions_hierarchy; inputs; output; trait_type_constraints } = sg in + let inputs = List.map (ctx_normalize_rty ctx) inputs in + let output = ctx_normalize_rty ctx output in + let trait_type_constraints = + List.map (ctx_normalize_rtrait_type_constraint ctx) trait_type_constraints + in + { regions_hierarchy; inputs; output; trait_type_constraints } diff --git a/compiler/Assumed.ml b/compiler/Assumed.ml index 11cd5666..109175af 100644 --- a/compiler/Assumed.ml +++ b/compiler/Assumed.ml @@ -63,79 +63,81 @@ module Sig = struct let empty_const_generic_params : T.const_generic_var list = [] + let mk_generic_args regions types const_generics : T.sgeneric_args = + { regions; types; const_generics; trait_refs = [] } + + let mk_generic_params regions types const_generics : T.generic_params = + { regions; types; const_generics; trait_clauses = [] } + let mk_ref_ty (r : T.RegionVarId.id T.region) (ty : T.sty) (is_mut : bool) : T.sty = let ref_kind = if is_mut then T.Mut else T.Shared in mk_ref_ty r ty ref_kind let mk_array_ty (ty : T.sty) (cg : T.const_generic) : T.sty = - Adt (Assumed Array, [], [ ty ], [ cg ]) + Adt (Assumed Array, mk_generic_args [] [ ty ] [ cg ]) + + let mk_slice_ty (ty : T.sty) : T.sty = + Adt (Assumed Slice, mk_generic_args [] [ ty ] []) - let mk_slice_ty (ty : T.sty) : T.sty = Adt (Assumed Slice, [], [ ty ], []) - let range_ty : T.sty = Adt (Assumed Range, [], [ usize_ty ], []) + let range_ty : T.sty = Adt (Assumed Range, mk_generic_args [] [ usize_ty ] []) + + let mk_sig generics regions_hierarchy inputs output : A.fun_sig = + let preds : T.predicates = + { regions_outlive = []; types_outlive = []; trait_type_constraints = [] } + in + { + generics; + preds; + parent_params_info = None; + regions_hierarchy; + inputs; + output; + } (** [fn<T>(&'a mut T, T) -> T] *) let mem_replace_sig : A.fun_sig = (* The signature fields *) - let region_params = [ region_param_0 ] (* <'a> *) in + let regions = [ region_param_0 ] (* <'a> *) in let regions_hierarchy = [ region_group_0 ] (* [{<'a>}] *) in - let type_params = [ type_param_0 ] (* <T> *) in + let types = [ type_param_0 ] (* <T> *) in + let generics = mk_generic_params regions types [] in let inputs = [ mk_ref_ty rvar_0 tvar_0 true (* &'a mut T *); tvar_0 (* T *) ] in let output = tvar_0 (* T *) in - { - region_params; - num_early_bound_regions = 0; - regions_hierarchy; - type_params; - const_generic_params = empty_const_generic_params; - inputs; - output; - } + mk_sig generics regions_hierarchy inputs output (** [fn<T>(T) -> Box<T>] *) let box_new_sig : A.fun_sig = - { - region_params = []; - num_early_bound_regions = 0; - regions_hierarchy = []; - type_params = [ type_param_0 ] (* <T> *); - const_generic_params = empty_const_generic_params; - inputs = [ tvar_0 (* T *) ]; - output = mk_box_ty tvar_0 (* Box<T> *); - } + let generics = mk_generic_params [] [ type_param_0 ] [] (* <T> *) in + let regions_hierarchy = [] in + let inputs = [ tvar_0 (* T *) ] in + let output = mk_box_ty tvar_0 (* Box<T> *) in + mk_sig generics regions_hierarchy inputs output (** [fn<T>(Box<T>) -> ()] *) let box_free_sig : A.fun_sig = - { - region_params = []; - num_early_bound_regions = 0; - regions_hierarchy = []; - type_params = [ type_param_0 ] (* <T> *); - const_generic_params = empty_const_generic_params; - inputs = [ mk_box_ty tvar_0 (* Box<T> *) ]; - output = mk_unit_ty (* () *); - } + let generics = mk_generic_params [] [ type_param_0 ] [] (* <T> *) in + let regions_hierarchy = [] in + let inputs = [ mk_box_ty tvar_0 (* Box<T> *) ] in + let output = mk_unit_ty (* () *) in + mk_sig generics regions_hierarchy inputs output (** Helper for [Box::deref_shared] and [Box::deref_mut]. Returns: [fn<'a, T>(&'a (mut) Box<T>) -> &'a (mut) T] *) let box_deref_gen_sig (is_mut : bool) : A.fun_sig = - (* The signature fields *) - let region_params = [ region_param_0 ] in + let generics = + mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *) + in let regions_hierarchy = [ region_group_0 ] (* <'a> *) in - { - region_params; - num_early_bound_regions = 0; - regions_hierarchy; - type_params = [ type_param_0 ] (* <T> *); - const_generic_params = empty_const_generic_params; - inputs = - [ mk_ref_ty rvar_0 (mk_box_ty tvar_0) is_mut (* &'a (mut) Box<T> *) ]; - output = mk_ref_ty rvar_0 tvar_0 is_mut (* &'a (mut) T *); - } + let inputs = + [ mk_ref_ty rvar_0 (mk_box_ty tvar_0) is_mut (* &'a (mut) Box<T> *) ] + in + let output = mk_ref_ty rvar_0 tvar_0 is_mut (* &'a (mut) T *) in + mk_sig generics regions_hierarchy inputs output (** [fn<'a, T>(&'a Box<T>) -> &'a T] *) let box_deref_shared_sig = box_deref_gen_sig false @@ -145,27 +147,18 @@ module Sig = struct (** [fn<T>() -> Vec<T>] *) let vec_new_sig : A.fun_sig = - let region_params = [] in + let generics = mk_generic_params [] [ type_param_0 ] [] (* <T> *) in let regions_hierarchy = [] in - let type_params = [ type_param_0 ] (* <T> *) in let inputs = [] in let output = mk_vec_ty tvar_0 (* Vec<T> *) in - { - region_params; - num_early_bound_regions = 0; - regions_hierarchy; - type_params; - const_generic_params = empty_const_generic_params; - inputs; - output; - } + mk_sig generics regions_hierarchy inputs output (** [fn<T>(&'a mut Vec<T>, T)] *) let vec_push_sig : A.fun_sig = - (* The signature fields *) - let region_params = [ region_param_0 ] in + let generics = + mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *) + in let regions_hierarchy = [ region_group_0 ] (* <'a> *) in - let type_params = [ type_param_0 ] (* <T> *) in let inputs = [ mk_ref_ty rvar_0 (mk_vec_ty tvar_0) true (* &'a mut Vec<T> *); @@ -173,22 +166,14 @@ module Sig = struct ] in let output = mk_unit_ty (* () *) in - { - region_params; - num_early_bound_regions = 0; - regions_hierarchy; - type_params; - const_generic_params = empty_const_generic_params; - inputs; - output; - } + mk_sig generics regions_hierarchy inputs output (** [fn<T>(&'a mut Vec<T>, usize, T)] *) let vec_insert_sig : A.fun_sig = - (* The signature fields *) - let region_params = [ region_param_0 ] in + let generics = + mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *) + in let regions_hierarchy = [ region_group_0 ] (* <'a> *) in - let type_params = [ type_param_0 ] (* <T> *) in let inputs = [ mk_ref_ty rvar_0 (mk_vec_ty tvar_0) true (* &'a mut Vec<T> *); @@ -197,44 +182,28 @@ module Sig = struct ] in let output = mk_unit_ty (* () *) in - { - region_params; - num_early_bound_regions = 0; - regions_hierarchy; - type_params; - const_generic_params = empty_const_generic_params; - inputs; - output; - } + mk_sig generics regions_hierarchy inputs output (** [fn<T>(&'a Vec<T>) -> usize] *) let vec_len_sig : A.fun_sig = - (* The signature fields *) - let region_params = [ region_param_0 ] in + let generics = + mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *) + in let regions_hierarchy = [ region_group_0 ] (* <'a> *) in - let type_params = [ type_param_0 ] (* <T> *) in let inputs = [ mk_ref_ty rvar_0 (mk_vec_ty tvar_0) false (* &'a Vec<T> *) ] in let output = mk_usize_ty (* usize *) in - { - region_params; - num_early_bound_regions = 0; - regions_hierarchy; - type_params; - const_generic_params = empty_const_generic_params; - inputs; - output; - } + mk_sig generics regions_hierarchy inputs output (** Helper: [fn<T>(&'a (mut) Vec<T>, usize) -> &'a (mut) T] *) let vec_index_gen_sig (is_mut : bool) : A.fun_sig = - (* The signature fields *) - let region_params = [ region_param_0 ] in + let generics = + mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *) + in let regions_hierarchy = [ region_group_0 ] (* <'a> *) in - let type_params = [ type_param_0 ] (* <T> *) in let inputs = [ mk_ref_ty rvar_0 (mk_vec_ty tvar_0) is_mut (* &'a (mut) Vec<T> *); @@ -242,15 +211,7 @@ module Sig = struct ] in let output = mk_ref_ty rvar_0 tvar_0 is_mut (* &'a (mut) T *) in - { - region_params; - num_early_bound_regions = 0; - regions_hierarchy; - type_params; - const_generic_params = empty_const_generic_params; - inputs; - output; - } + mk_sig generics regions_hierarchy inputs output (** [fn<T>(&'a Vec<T>, usize) -> &'a T] *) let vec_index_shared_sig : A.fun_sig = vec_index_gen_sig false @@ -275,10 +236,10 @@ module Sig = struct let mk_array_slice_borrow_sig (cgs : T.const_generic_var list) (input_ty : T.TypeVarId.id -> T.sty) (index_ty : T.sty option) (output_ty : T.TypeVarId.id -> T.sty) (is_mut : bool) : A.fun_sig = - (* The signature fields *) - let region_params = [ region_param_0 ] in + let generics = + mk_generic_params [ region_param_0 ] [ type_param_0 ] cgs (* <'a, T> *) + in let regions_hierarchy = [ region_group_0 ] (* <'a> *) in - let type_params = [ type_param_0 ] (* <T> *) in let inputs = [ mk_ref_ty rvar_0 @@ -294,15 +255,7 @@ module Sig = struct (output_ty type_param_0.index) is_mut (* &'a (mut) output_ty<T> *) in - { - region_params; - num_early_bound_regions = 0; - regions_hierarchy; - type_params; - const_generic_params = cgs; - inputs; - output; - } + mk_sig generics regions_hierarchy inputs output let mk_array_slice_index_sig (is_array : bool) (is_mut : bool) : A.fun_sig = (* Array<T, N> *) @@ -345,6 +298,19 @@ module Sig = struct let array_subslice_sig (is_mut : bool) = mk_array_slice_subslice_sig true is_mut + let array_repeat_sig = + let generics = + (* <T, N> *) + mk_generic_params [] [ type_param_0 ] [ cg_param_0 ] + in + let regions_hierarchy = [] (* <> *) in + let inputs = [ tvar_0 (* T *) ] in + let output = + (* [T; N] *) + mk_array_ty tvar_0 cgvar_0 + in + mk_sig generics regions_hierarchy inputs output + let slice_subslice_sig (is_mut : bool) = mk_array_slice_subslice_sig false is_mut @@ -352,23 +318,15 @@ module Sig = struct [fn<T>(&'a [T]) -> usize] *) let slice_len_sig : A.fun_sig = - (* The signature fields *) - let region_params = [ region_param_0 ] in + let generics = + mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *) + in let regions_hierarchy = [ region_group_0 ] (* <'a> *) in - let type_params = [ type_param_0 ] (* <T> *) in let inputs = [ mk_ref_ty rvar_0 (mk_slice_ty tvar_0) false (* &'a [T] *) ] in let output = mk_usize_ty (* usize *) in - { - region_params; - num_early_bound_regions = 0; - regions_hierarchy; - type_params; - const_generic_params = empty_const_generic_params; - inputs; - output; - } + mk_sig generics regions_hierarchy inputs output end type assumed_info = A.assumed_fun_id * A.fun_sig * bool * name @@ -439,6 +397,8 @@ let assumed_infos : assumed_info list = Sig.array_subslice_sig true, true, to_name [ "@ArraySubsliceMut" ] ); + (* Array Repeat *) + (ArrayRepeat, Sig.array_repeat_sig, false, to_name [ "@ArrayRepeat" ]); (* Slice Index *) ( SliceIndexShared, Sig.slice_index_sig false, diff --git a/compiler/Config.ml b/compiler/Config.ml index bd80769f..62f6c300 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -323,3 +323,23 @@ let wrap_opaque_in_sig = ref false information), we use short names (i.e., the original field names). *) let record_fields_short_names = ref false + +(** Parameterize the traits with their associated types, so as not to use + types as first class objects. + + This is useful for some backends with limited expressiveness like HOL4, + and to account for type constraints (like [fn f<T : Foo>(...) where T::bar = usize]). + *) +let parameterize_trait_types = ref false + +(** For sanity check: type check the generated pure code (activates checks in + several places). + + TODO: deactivated for now because we need to implement the normalization of + trait associated types in the pure code. + *) +let type_check_pure_code = ref false + +(** Shall we fail hard if there is an issue at code-generation time? + We may not want in case outputting a code with holes helps debugging *) +let extract_fail_hard = ref false diff --git a/compiler/Contexts.ml b/compiler/Contexts.ml index 2ca5653d..dac64a9a 100644 --- a/compiler/Contexts.ml +++ b/compiler/Contexts.ml @@ -5,6 +5,7 @@ open LlbcAst module V = Values open ValuesUtils open Identifiers +module L = Logging (** The [Id] module for dummy variables. @@ -17,6 +18,9 @@ IdGen () type dummy_var_id = DummyVarId.id [@@deriving show, ord] +(** The local logger *) +let log = L.contexts_log + (** Some global counters. Note that those counters were initially stored in {!eval_ctx} values, @@ -40,6 +44,7 @@ type dummy_var_id = DummyVarId.id [@@deriving show, ord] fn f x : fun_type = let id = fresh_id () in ... + fun () -> ... let g = f x in // <-- the fresh identifier gets generated here let x1 = g () in // <-- no fresh generation here @@ -250,27 +255,127 @@ type type_context = { } [@@deriving show] -type fun_context = { fun_decls : fun_decl FunDeclId.Map.t } [@@deriving show] +type fun_context = { + fun_decls : fun_decl FunDeclId.Map.t; + fun_infos : FunsAnalysis.fun_info FunDeclId.Map.t; +} +[@@deriving show] type global_context = { global_decls : global_decl GlobalDeclId.Map.t } [@@deriving show] +type trait_decls_context = { trait_decls : trait_decl TraitDeclId.Map.t } +[@@deriving show] + +type trait_impls_context = { trait_impls : trait_impl TraitImplId.Map.t } +[@@deriving show] + +type decls_ctx = { + type_ctx : type_context; + fun_ctx : fun_context; + global_ctx : global_context; + trait_decls_ctx : trait_decls_context; + trait_impls_ctx : trait_impls_context; +} +[@@deriving show] + +(** A reference to a trait associated type *) +type 'r trait_type_ref = { trait_ref : 'r trait_ref; type_name : string } +[@@deriving show, ord] + +type etrait_type_ref = erased_region trait_type_ref [@@deriving show, ord] + +type rtrait_type_ref = Types.RegionId.id Types.region trait_type_ref +[@@deriving show, ord] + +type strait_type_ref = Types.RegionVarId.id Types.region trait_type_ref +[@@deriving show, ord] + +(* TODO: correctly use the functors so as not to have a duplication below *) +module ETraitTypeRefOrd = struct + type t = etrait_type_ref + + let compare = compare_etrait_type_ref + let to_string = show_etrait_type_ref + let pp_t = pp_etrait_type_ref + let show_t = show_etrait_type_ref +end + +module RTraitTypeRefOrd = struct + type t = rtrait_type_ref + + let compare = compare_rtrait_type_ref + let to_string = show_rtrait_type_ref + let pp_t = pp_rtrait_type_ref + let show_t = show_rtrait_type_ref +end + +module STraitTypeRefOrd = struct + type t = strait_type_ref + + let compare = compare_strait_type_ref + let to_string = show_strait_type_ref + let pp_t = pp_strait_type_ref + let show_t = show_strait_type_ref +end + +module ETraitTypeRefMap = Collections.MakeMap (ETraitTypeRefOrd) +module RTraitTypeRefMap = Collections.MakeMap (RTraitTypeRefOrd) +module STraitTypeRefMap = Collections.MakeMap (STraitTypeRefOrd) + (** Evaluation context *) type eval_ctx = { type_context : type_context; fun_context : fun_context; global_context : global_context; + trait_decls_context : trait_decls_context; + trait_impls_context : trait_impls_context; region_groups : RegionGroupId.id list; type_vars : type_var list; const_generic_vars : const_generic_var list; + const_generic_vars_map : typed_value Types.ConstGenericVarId.Map.t; + (** The map from const generic vars to their values. Those values + can be symbolic values or concrete values (in the latter case: + if we run in interpreter mode) *) + norm_trait_etypes : ety ETraitTypeRefMap.t; + (** The normalized trait types (a map from trait types to their representatives). + Note that this doesn't support account higher-order types. *) + norm_trait_rtypes : rty RTraitTypeRefMap.t; + (** We need this because we manipulate two kinds of types. + Note that we actually forbid regions from appearing both in the trait + references and in the constraints given to the associated types, + meaning that we don't have to worry about mismatches due to changes + in region ids. + + TODO: how not to duplicate? + *) + norm_trait_stypes : sty STraitTypeRefMap.t; + (** We sometimes need to normalize types in non-instantiated signatures. + + Note that we either need to use the etypes/rtypes maps, or the stypes map. + This means that we either compute the maps for etypes and rtypes, or compute + the one for stypes (we don't always compute and carry all the maps). + *) env : env; ended_regions : RegionId.Set.t; } [@@deriving show] +let lookup_type_var_opt (ctx : eval_ctx) (vid : TypeVarId.id) : type_var option + = + if TypeVarId.to_int vid < List.length ctx.type_vars then + Some (TypeVarId.nth ctx.type_vars vid) + else None + let lookup_type_var (ctx : eval_ctx) (vid : TypeVarId.id) : type_var = TypeVarId.nth ctx.type_vars vid +let lookup_const_generic_var_opt (ctx : eval_ctx) (vid : ConstGenericVarId.id) : + const_generic_var option = + if ConstGenericVarId.to_int vid < List.length ctx.const_generic_vars then + Some (ConstGenericVarId.nth ctx.const_generic_vars vid) + else None + let lookup_const_generic_var (ctx : eval_ctx) (vid : ConstGenericVarId.id) : const_generic_var = ConstGenericVarId.nth ctx.const_generic_vars vid @@ -304,6 +409,12 @@ let ctx_lookup_global_decl (ctx : eval_ctx) (gid : GlobalDeclId.id) : global_decl = GlobalDeclId.Map.find gid ctx.global_context.global_decls +let ctx_lookup_trait_decl (ctx : eval_ctx) (id : TraitDeclId.id) : trait_decl = + TraitDeclId.Map.find id ctx.trait_decls_context.trait_decls + +let ctx_lookup_trait_impl (ctx : eval_ctx) (id : TraitImplId.id) : trait_impl = + TraitImplId.Map.find id ctx.trait_impls_context.trait_impls + (** Retrieve a variable's value in the current frame *) let env_lookup_var_value (env : env) (vid : VarId.id) : typed_value = snd (env_lookup_var env vid) @@ -312,6 +423,11 @@ let env_lookup_var_value (env : env) (vid : VarId.id) : typed_value = let ctx_lookup_var_value (ctx : eval_ctx) (vid : VarId.id) : typed_value = env_lookup_var_value ctx.env vid +(** Retrieve a const generic value in an evaluation context *) +let ctx_lookup_const_generic_value (ctx : eval_ctx) (vid : ConstGenericVarId.id) + : typed_value = + Types.ConstGenericVarId.Map.find vid ctx.const_generic_vars_map + (** Update a variable's value in the current frame. This is a helper function: it can break invariants and doesn't perform @@ -361,6 +477,15 @@ let ctx_push_var (ctx : eval_ctx) (var : var) (v : typed_value) : eval_ctx = *) let ctx_push_vars (ctx : eval_ctx) (vars : (var * typed_value) list) : eval_ctx = + log#ldebug + (lazy + ("push_vars:\n" + ^ String.concat "\n" + (List.map + (fun (var, value) -> + (* We can unfortunately not use Print because it depends on Contexts... *) + show_var var ^ " -> " ^ V.show_typed_value value) + vars))); assert ( List.for_all (fun (var, (value : typed_value)) -> var.var_ty = value.ty) diff --git a/compiler/Driver.ml b/compiler/Driver.ml index b646a53d..414b042d 100644 --- a/compiler/Driver.ml +++ b/compiler/Driver.ml @@ -17,11 +17,15 @@ let log = main_log let _ = (* Set up the logging - for now we use default values - TODO: use the * command-line arguments *) - (* By setting a level for the main_logger_handler, we filter everything *) + (* By setting a level for the main_logger_handler, we filter everything. + To have a good trace: one should switch between Info and Debug. + *) Easy_logging.Handlers.set_level main_logger_handler EL.Debug; main_log#set_level EL.Info; llbc_of_json_logger#set_level EL.Info; pre_passes_log#set_level EL.Info; + associated_types_log#set_level EL.Info; + contexts_log#set_level EL.Info; interpreter_log#set_level EL.Info; statements_log#set_level EL.Info; loops_match_ctxs_log#set_level EL.Info; @@ -164,8 +168,6 @@ let () = decompose_monadic_let_bindings := true; decompose_nested_let_patterns := true | Lean -> - (* The Lean backend is experimental: print a warning *) - log#lwarning (lazy "The Lean backend is experimental"); (* We don't support fuel for the Lean backend *) if !use_fuel then ( log#error "The Lean backend doesn't support the -use-fuel option"; @@ -220,20 +222,6 @@ let () = in if has_loops then log#lwarning (lazy "Support for loops is experimental"); - (* If we target Lean, we request the crates to be split into several files - whenever there are opaque functions *) - if - !backend = Lean - && A.FunDeclId.Map.exists - (fun _ (d : A.fun_decl) -> d.body = None) - m.functions - && not !split_files - then ( - log#error - "For Lean, we request the -split-file option whenever using opaque \ - functions"; - fail ()); - (* We don't support mutually recursive definitions with decreases clauses in Lean *) if !backend = Lean && !extract_decreases_clauses diff --git a/compiler/Extract.ml b/compiler/Extract.ml index c4238d83..ac81d6f3 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -338,6 +338,7 @@ let assumed_llbc_functions () : (ArraySubsliceShared, None, "array_subslice_shared"); (ArraySubsliceMut, None, "array_subslice_mut_fwd"); (ArraySubsliceMut, rg0, "array_subslice_mut_back"); + (ArrayRepeat, None, "array_repeat"); (SliceIndexShared, None, "slice_index_shared"); (SliceIndexMut, None, "slice_index_mut_fwd"); (SliceIndexMut, rg0, "slice_index_mut_back"); @@ -369,6 +370,7 @@ let assumed_llbc_functions () : (ArraySubsliceShared, None, "Array.subslice_shared"); (ArraySubsliceMut, None, "Array.subslice_mut"); (ArraySubsliceMut, rg0, "Array.subslice_mut_back"); + (ArrayRepeat, None, "Array.repeat"); (SliceIndexShared, None, "Slice.index_shared"); (SliceIndexMut, None, "Slice.index_mut"); (SliceIndexMut, rg0, "Slice.index_mut_back"); @@ -529,7 +531,15 @@ let type_decl_kind_to_qualif (kind : decl_kind) *) Some "with" | (Assumed | Declared), None -> Some "Axiom" - | _ -> raise (Failure "Unexpected")) + | SingleNonRec, None -> + (* This is for traits *) + Some "Record" + | _ -> + raise + (Failure + ("Unexpected: (" ^ show_decl_kind kind ^ ", " + ^ Print.option_to_string show_type_decl_kind type_kind + ^ ")"))) | Lean -> ( match kind with | SingleNonRec -> @@ -647,6 +657,11 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) | _ -> raise (Failure ("Unexpected name shape: " ^ Print.name_to_string name)) in + let flatten_name (name : string list) : string = + match !backend with + | FStar | Coq | HOL4 -> String.concat "_" name + | Lean -> String.concat "." name + in let get_type_name = get_name in let type_name_to_camel_case name = let name = get_type_name name in @@ -668,15 +683,20 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) in let field_name (def_name : name) (field_id : FieldId.id) (field_name : string option) : string = - let field_name = + let field_name_s = match field_name with | Some field_name -> field_name - | None -> FieldId.to_string field_id + | None -> + (* TODO: extract structs with no field names to tuples *) + FieldId.to_string field_id in - if !Config.record_fields_short_names then field_name + if !Config.record_fields_short_names then + if field_name = None then (* TODO: this is a bit ugly *) + "_" ^ field_name_s + else field_name_s else let def_name = type_name_to_snake_case def_name ^ "_" in - def_name ^ field_name + def_name ^ field_name_s in let variant_name (def_name : name) (variant : string) : string = match !backend with @@ -700,9 +720,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) let get_fun_name fname = let fname = get_name fname in (* TODO: don't convert to snake case for Coq, HOL4, F* *) - match !backend with - | FStar | Coq | HOL4 -> String.concat "_" (List.map to_snake_case fname) - | Lean -> String.concat "." fname + flatten_name fname in let global_name (name : global_name) : string = (* Converting to snake case also lowercases the letters (in Rust, global @@ -720,6 +738,50 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) fname ^ suffix in + let trait_decl_name (trait_decl : trait_decl) : string = + type_name trait_decl.name + in + + let trait_impl_name (trait_decl : trait_decl) (trait_impl : trait_impl) : + string = + (* TODO: provisional: we concatenate the trait impl name (which is its type) + with the trait decl name *) + let trait_decl = + let name = trait_decl.name in + match !backend with + | FStar | Coq | HOL4 -> type_name_to_snake_case name ^ "_inst" + | Lean -> String.concat "" (get_type_name name) ^ "Inst" + in + flatten_name (get_type_name trait_impl.name @ [ trait_decl ]) + in + + let trait_parent_clause_name (trait_decl : trait_decl) (clause : trait_clause) + : string = + (* TODO: improve - it would be better to not use indices *) + let clause = "parent_clause_" ^ TraitClauseId.to_string clause.clause_id in + if !Config.record_fields_short_names then clause + else trait_decl_name trait_decl ^ "_" ^ clause + in + let trait_type_name (trait_decl : trait_decl) (item : string) : string = + if !Config.record_fields_short_names then item + else trait_decl_name trait_decl ^ "_" ^ item + in + let trait_const_name (trait_decl : trait_decl) (item : string) : string = + if !Config.record_fields_short_names then item + else trait_decl_name trait_decl ^ "_" ^ item + in + let trait_method_name (trait_decl : trait_decl) (item : string) : string = + if !Config.record_fields_short_names then item + else trait_decl_name trait_decl ^ "_" ^ item + in + let trait_type_clause_name (trait_decl : trait_decl) (item : string) + (clause : trait_clause) : string = + (* TODO: improve - it would be better to not use indices *) + trait_type_name trait_decl item + ^ "_clause_" + ^ TraitClauseId.to_string clause.clause_id + in + let termination_measure_name (_fid : A.FunDeclId.id) (fname : fun_name) (num_loops : int) (loop_id : LoopId.id option) : string = let fname = get_fun_name fname in @@ -757,6 +819,21 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) let var_basename (_varset : StringSet.t) (basename : string option) (ty : ty) : string = + (* Small helper to derive var names from ADT type names. + + We do the following: + - convert the type name to snake case + - take the first letter of every "letter group" + Ex.: "HashMap" -> "hash_map" -> "hm" + *) + let name_from_type_ident (name : string) : string = + let cl = to_snake_case name in + let cl = String.split_on_char '_' cl in + let cl = List.filter (fun s -> String.length s > 0) cl in + assert (List.length cl > 0); + let cl = List.map (fun s -> s.[0]) cl in + StringUtils.string_of_chars cl + in (* If there is a basename, we use it *) match basename with | Some basename -> @@ -765,11 +842,11 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) | None -> ( (* No basename: we use the first letter of the type *) match ty with - | Adt (type_id, tys, _) -> ( + | Adt (type_id, generics) -> ( match type_id with | Tuple -> (* The "pair" case is frequent enough to have its special treatment *) - if List.length tys = 2 then "p" else "t" + if List.length generics.types = 2 then "p" else "t" | Assumed Result -> "r" | Assumed Error -> ConstStrings.error_basename | Assumed Fuel -> ConstStrings.fuel_basename @@ -781,24 +858,14 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) | Assumed Range -> "r" | Assumed State -> ConstStrings.state_basename | AdtId adt_id -> - let def = - TypeDeclId.Map.find adt_id ctx.type_context.type_decls - in - (* We do the following: - * - compute the type name, and retrieve the last ident - * - convert this to snake case - * - take the first letter of every "letter group" + let def = TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls in + (* Derive the var name from the last ident of the type name * Ex.: ["hashmap"; "HashMap"] ~~> "HashMap" -> "hash_map" -> "hm" *) - (* Thename shouldn't be empty, and its last element should + (* The name shouldn't be empty, and its last element should * be an ident *) let cl = List.nth def.name (List.length def.name - 1) in - let cl = to_snake_case (Names.as_ident cl) in - let cl = String.split_on_char '_' cl in - let cl = List.filter (fun s -> String.length s > 0) cl in - assert (List.length cl > 0); - let cl = List.map (fun s -> s.[0]) cl in - StringUtils.string_of_chars cl) + name_from_type_ident (Names.as_ident cl)) | TypeVar _ -> ( (* TODO: use "t" also for F* *) match !backend with @@ -806,7 +873,8 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) | Coq | Lean | HOL4 -> "t" (* lacking inspiration here... *)) | Literal lty -> ( match lty with Bool -> "b" | Char -> "c" | Integer _ -> "i") - | Arrow _ -> "f") + | Arrow _ -> "f" + | TraitType (_, _, name) -> name_from_type_ident name) in let type_var_basename (_varset : StringSet.t) (basename : string) : string = (* Rust type variables are snake-case and start with a capital letter *) @@ -828,6 +896,12 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) to_snake_case basename | Coq | Lean -> basename in + let trait_clause_basename (_varset : StringSet.t) (_clause : trait_clause) : + string = + (* TODO: actually use the clause to derive the name *) + "inst" + in + let trait_self_clause_basename = "self_clause" in let append_index (basename : string) (i : int) : string = basename ^ string_of_int i in @@ -838,11 +912,11 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) | Scalar sv -> ( match !backend with | FStar -> F.pp_print_string fmt (Z.to_string sv.PV.value) - | Coq | HOL4 -> + | Coq | HOL4 | Lean -> let print_brackets = inside && !backend = HOL4 in if print_brackets then F.pp_print_string fmt "("; (match !backend with - | Coq -> () + | Coq | Lean -> () | HOL4 -> F.pp_print_string fmt ("int_to_" ^ int_name sv.PV.int_ty); F.pp_print_space fmt () @@ -850,30 +924,20 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) (* We need to add parentheses if the value is negative *) if sv.PV.value >= Z.of_int 0 then F.pp_print_string fmt (Z.to_string sv.PV.value) - else F.pp_print_string fmt ("(" ^ Z.to_string sv.PV.value ^ ")"); + else + F.pp_print_string fmt + ("(" ^ Z.to_string sv.PV.value + ^ if !backend = Lean then ":Int" else "" ^ ")"); (match !backend with - | Coq -> F.pp_print_string fmt ("%" ^ int_name sv.PV.int_ty) + | Coq -> + let iname = int_name sv.PV.int_ty in + F.pp_print_string fmt ("%" ^ iname) + | Lean -> + let iname = String.lowercase_ascii (int_name sv.PV.int_ty) in + F.pp_print_string fmt ("#" ^ iname) | HOL4 -> () | _ -> raise (Failure "Unreachable")); - if print_brackets then F.pp_print_string fmt ")" - | Lean -> - F.pp_print_string fmt "("; - F.pp_print_string fmt (int_name sv.int_ty); - F.pp_print_string fmt ".ofInt "; - (* Something very annoying: negated values like `-3` are - ambiguous in Lean because of conversions, so we have to - be extremely explicit with negative numbers. - *) - if Z.lt sv.value Z.zero then ( - F.pp_print_string fmt "("; - F.pp_print_string fmt "-"; - F.pp_print_string fmt "("; - Z.pp_print fmt (Z.neg sv.value); - F.pp_print_string fmt ":Int"; - F.pp_print_string fmt ")"; - F.pp_print_string fmt ")") - else Z.pp_print fmt sv.value; - F.pp_print_string fmt ")") + if print_brackets then F.pp_print_string fmt ")") | Bool b -> let b = match !backend with @@ -919,10 +983,19 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) fun_name; termination_measure_name; decreases_proof_name; + trait_decl_name; + trait_impl_name; + trait_parent_clause_name; + trait_const_name; + trait_type_name; + trait_method_name; + trait_type_clause_name; opaque_pre; var_basename; type_var_basename; const_generic_var_basename; + trait_self_clause_basename; + trait_clause_basename; append_index; extract_literal; extract_unop; @@ -1131,13 +1204,13 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (no_params_tys : TypeDeclId.Set.t) (inside : bool) (ty : ty) : unit = let extract_rec = extract_ty ctx fmt no_params_tys in match ty with - | Adt (type_id, tys, cgs) -> ( - let has_params = tys <> [] || cgs <> [] in + | Adt (type_id, generics) -> ( + let has_params = generics <> empty_generic_args in match type_id with | Tuple -> (* This is a bit annoying, but in F*/Coq/HOL4 [()] is not the unit type: * we have to write [unit]... *) - if tys = [] then F.pp_print_string fmt (unit_name ()) + if generics.types = [] then F.pp_print_string fmt (unit_name ()) else ( F.pp_print_string fmt "("; Collections.List.iter_link @@ -1152,7 +1225,7 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) in F.pp_print_string fmt product; F.pp_print_space fmt ()) - (extract_rec true) tys; + (extract_rec true) generics.types; F.pp_print_string fmt ")") | AdtId _ | Assumed _ -> ( (* HOL4 behaves differently. Where in Coq/FStar/Lean we would write: @@ -1169,36 +1242,34 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (* TODO: for now, only the opaque *functions* are extracted in the opaque module. The opaque *types* are assumed. *) F.pp_print_string fmt (ctx_get_type with_opaque_pre type_id ctx); - if tys <> [] then ( - F.pp_print_space fmt (); - Collections.List.iter_link (F.pp_print_space fmt) - (extract_rec true) tys); - if cgs <> [] then ( - F.pp_print_space fmt (); - Collections.List.iter_link (F.pp_print_space fmt) - (extract_const_generic ctx fmt true) - cgs); + extract_generic_args ctx fmt no_params_tys generics; if print_paren then F.pp_print_string fmt ")" | HOL4 -> - (* Const generics are unsupported in HOL4 *) - assert (cgs = []); + let { types; const_generics; trait_refs } = generics in + (* Const generics are not supported in HOL4 *) + assert (const_generics = []); let print_tys = match type_id with | AdtId id -> not (TypeDeclId.Set.mem id no_params_tys) | Assumed _ -> true | _ -> raise (Failure "Unreachable") in - if tys <> [] && print_tys then ( - let print_paren = List.length tys > 1 in + if types <> [] && print_tys then ( + let print_paren = List.length types > 1 in if print_paren then F.pp_print_string fmt "("; Collections.List.iter_link (fun () -> F.pp_print_string fmt ","; F.pp_print_space fmt ()) - (extract_rec true) tys; + (extract_rec true) types; if print_paren then F.pp_print_string fmt ")"; F.pp_print_space fmt ()); - F.pp_print_string fmt (ctx_get_type with_opaque_pre type_id ctx))) + F.pp_print_string fmt (ctx_get_type with_opaque_pre type_id ctx); + if trait_refs <> [] then ( + F.pp_print_space fmt (); + Collections.List.iter_link (F.pp_print_space fmt) + (extract_trait_ref ctx fmt no_params_tys true) + trait_refs))) | TypeVar vid -> F.pp_print_string fmt (ctx_get_type_var vid ctx) | Literal lty -> extract_literal_type ctx fmt lty | Arrow (arg_ty, ret_ty) -> @@ -1209,6 +1280,113 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_space fmt (); extract_rec false ret_ty; if inside then F.pp_print_string fmt ")" + | TraitType (trait_ref, generics, type_name) -> + if !parameterize_trait_types then raise (Failure "Unimplemented") + else if trait_ref.trait_id <> Self then ( + (* HOL4 doesn't have 1st class types *) + assert (!backend <> HOL4); + let use_brackets = generics <> empty_generic_args in + if use_brackets then F.pp_print_string fmt "("; + extract_trait_ref ctx fmt no_params_tys false trait_ref; + extract_generic_args ctx fmt no_params_tys generics; + let name = + ctx_get_trait_type trait_ref.trait_decl_ref.trait_decl_id type_name + ctx + in + if use_brackets then F.pp_print_string fmt ")"; + F.pp_print_string fmt ("." ^ name)) + else + (* There are two situations: + - we are extracting a declared item (typically a function signature) + for a trait declaration. We directly refer to the item (we extract + trait declarations as structures, so we can refer to their fields) + - we are extracting a provided method for a trait declaration. We + refer to the item in the self trait clause (see {!SelfTraitClauseId}). + + Remark: we can't get there for trait *implementations* because then the + types should have been normalized. + *) + let trait_decl_id = Option.get ctx.trait_decl_id in + let item_name = ctx_get_trait_type trait_decl_id type_name ctx in + assert (generics = empty_generic_args); + if ctx.is_provided_method then + (* Provided method: use the trait self clause *) + let self_clause = ctx_get_trait_self_clause ctx in + F.pp_print_string fmt (self_clause ^ "." ^ item_name) + else + (* Declaration: directly refer to the item *) + F.pp_print_string fmt item_name + +and extract_trait_ref (ctx : extraction_ctx) (fmt : F.formatter) + (no_params_tys : TypeDeclId.Set.t) (inside : bool) (tr : trait_ref) : unit = + let use_brackets = tr.generics <> empty_generic_args && inside in + if use_brackets then F.pp_print_string fmt "("; + extract_trait_instance_id ctx fmt no_params_tys inside tr.trait_id; + extract_generic_args ctx fmt no_params_tys tr.generics; + if use_brackets then F.pp_print_string fmt ")" + +and extract_trait_decl_ref (ctx : extraction_ctx) (fmt : F.formatter) + (no_params_tys : TypeDeclId.Set.t) (inside : bool) (tr : trait_decl_ref) : + unit = + let use_brackets = tr.decl_generics <> empty_generic_args && inside in + let is_opaque = false in + let name = ctx_get_trait_decl is_opaque tr.trait_decl_id ctx in + if use_brackets then F.pp_print_string fmt "("; + F.pp_print_string fmt name; + extract_generic_args ctx fmt no_params_tys tr.decl_generics; + if use_brackets then F.pp_print_string fmt ")" + +and extract_generic_args (ctx : extraction_ctx) (fmt : F.formatter) + (no_params_tys : TypeDeclId.Set.t) (generics : generic_args) : unit = + let { types; const_generics; trait_refs } = generics in + if !backend <> HOL4 then ( + if types <> [] then ( + F.pp_print_space fmt (); + Collections.List.iter_link (F.pp_print_space fmt) + (extract_ty ctx fmt no_params_tys true) + types); + if const_generics <> [] then ( + assert (!backend <> HOL4); + F.pp_print_space fmt (); + Collections.List.iter_link (F.pp_print_space fmt) + (extract_const_generic ctx fmt true) + const_generics)); + if trait_refs <> [] then ( + F.pp_print_space fmt (); + Collections.List.iter_link (F.pp_print_space fmt) + (extract_trait_ref ctx fmt no_params_tys true) + trait_refs) + +and extract_trait_instance_id (ctx : extraction_ctx) (fmt : F.formatter) + (no_params_tys : TypeDeclId.Set.t) (inside : bool) (id : trait_instance_id) + : unit = + let with_opaque_pre = false in + match id with + | Self -> + (* This has specific treatment depending on the item we're extracting + (associated type, etc.). We should have caught this elsewhere. *) + raise (Failure "Unexpected") + | TraitImpl id -> + let name = ctx_get_trait_impl with_opaque_pre id ctx in + F.pp_print_string fmt name + | Clause id -> + let name = ctx_get_local_trait_clause id ctx in + F.pp_print_string fmt name + | ParentClause (inst_id, decl_id, clause_id) -> + (* Use the trait decl id to lookup the name *) + let name = ctx_get_trait_parent_clause decl_id clause_id ctx in + extract_trait_instance_id ctx fmt no_params_tys true inst_id; + F.pp_print_string fmt ("." ^ name) + | ItemClause (inst_id, decl_id, item_name, clause_id) -> + (* Use the trait decl id to lookup the name *) + let name = ctx_get_trait_item_clause decl_id item_name clause_id ctx in + extract_trait_instance_id ctx fmt no_params_tys true inst_id; + F.pp_print_string fmt ("." ^ name) + | TraitRef trait_ref -> + extract_trait_ref ctx fmt no_params_tys inside trait_ref + | UnknownTrait _ -> + (* This is an error case *) + raise (Failure "Unexpected") (** Compute the names for all the top-level identifiers used in a type definition (type name, variant names, field names, etc. but not type @@ -1524,6 +1702,172 @@ let extract_comment (fmt : F.formatter) (sl : string list) : unit = F.pp_print_string fmt rd; F.pp_close_box fmt () +let extract_trait_clause_type (ctx : extraction_ctx) (fmt : F.formatter) + (no_params_tys : TypeDeclId.Set.t) (clause : trait_clause) : unit = + let with_opaque_pre = false in + let trait_name = ctx_get_trait_decl with_opaque_pre clause.trait_id ctx in + F.pp_print_string fmt trait_name; + extract_generic_args ctx fmt no_params_tys clause.generics + +(** Insert a space, if necessary *) +let insert_req_space (fmt : F.formatter) (space : bool ref) : unit = + if !space then space := false else F.pp_print_space fmt () + +(** Extract the trait self clause. + + We add the trait self clause for provided methods (see {!TraitSelfClauseId}). + *) +let extract_trait_self_clause (insert_req_space : unit -> unit) + (ctx : extraction_ctx) (fmt : F.formatter) (trait_decl : trait_decl) + (params : string list) : unit = + insert_req_space (); + F.pp_print_string fmt "("; + let self_clause = ctx_get_trait_self_clause ctx in + F.pp_print_string fmt self_clause; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + let with_opaque_pre = false in + let trait_id = ctx_get_trait_decl with_opaque_pre trait_decl.def_id ctx in + F.pp_print_string fmt trait_id; + List.iter + (fun p -> + F.pp_print_space fmt (); + F.pp_print_string fmt p) + params; + F.pp_print_string fmt ")" + +(** + - [trait_decl]: if [Some], it means we are extracting the generics for a provided + method and need to insert a trait self clause (see {!TraitSelfClauseId}). + *) +let extract_generic_params (ctx : extraction_ctx) (fmt : F.formatter) + (no_params_tys : TypeDeclId.Set.t) ?(use_forall = false) + ?(use_forall_use_sep = true) ?(as_implicits : bool = false) + ?(space : bool ref option = None) ?(trait_decl : trait_decl option = None) + (generics : generic_params) (type_params : string list) + (cg_params : string list) (trait_clauses : string list) : unit = + let all_params = List.concat [ type_params; cg_params; trait_clauses ] in + (* HOL4 doesn't support const generics *) + assert (cg_params = [] || !backend <> HOL4); + let left_bracket (implicit : bool) = + if implicit then F.pp_print_string fmt "{" else F.pp_print_string fmt "(" + in + let right_bracket (implicit : bool) = + if implicit then F.pp_print_string fmt "}" else F.pp_print_string fmt ")" + in + let insert_req_space () = + match space with + | None -> F.pp_print_space fmt () + | Some space -> insert_req_space fmt space + in + (* Print the type/const generic parameters *) + if all_params <> [] then ( + if use_forall then ( + if use_forall_use_sep then ( + insert_req_space (); + F.pp_print_string fmt ":"); + insert_req_space (); + F.pp_print_string fmt "forall"); + (* Small helper - we may need to split the parameters *) + let print_generics (as_implicits : bool) (type_params : string list) + (const_generics : const_generic_var list) + (trait_clauses : trait_clause list) : unit = + (* Note that in HOL4 we don't print the type parameters. *) + if !backend <> HOL4 then ( + (* Print the type parameters *) + if type_params <> [] then ( + insert_req_space (); + (* ( *) + left_bracket as_implicits; + List.iter + (fun s -> + F.pp_print_string fmt s; + F.pp_print_space fmt ()) + type_params; + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + F.pp_print_string fmt (type_keyword ()); + (* ) *) + right_bracket as_implicits); + (* Print the const generic parameters *) + List.iter + (fun (var : const_generic_var) -> + insert_req_space (); + (* ( *) + left_bracket as_implicits; + let n = ctx_get_const_generic_var var.index ctx in + F.pp_print_string fmt n; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_literal_type ctx fmt var.ty; + (* ) *) + right_bracket as_implicits) + const_generics); + (* Print the trait clauses *) + List.iter + (fun (clause : trait_clause) -> + insert_req_space (); + (* ( *) + left_bracket as_implicits; + let n = ctx_get_local_trait_clause clause.clause_id ctx in + F.pp_print_string fmt n; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_trait_clause_type ctx fmt no_params_tys clause; + (* ) *) + right_bracket as_implicits) + trait_clauses + in + (* If we extract the generics for a provided method for a trait declaration + (indicated by the trait decl given as input), we need to split the generics: + - we print the generics for the trait decl + - we print the trait self clause + - we print the generics for the trait method + *) + match trait_decl with + | None -> + print_generics as_implicits type_params generics.const_generics + generics.trait_clauses + | Some trait_decl -> + (* Split the generics between the generics specific to the trait decl + and those specific to the trait method *) + let open Collections.List in + let dtype_params, mtype_params = + split_at type_params (length trait_decl.generics.types) + in + let dcgs, mcgs = + split_at generics.const_generics + (length trait_decl.generics.const_generics) + in + let dtrait_clauses, mtrait_clauses = + split_at generics.trait_clauses + (length trait_decl.generics.trait_clauses) + in + (* Extract the trait decl generics - note that we can always deduce + those parameters from the trait self clause: for this reason + they are always implicit *) + print_generics true dtype_params dcgs dtrait_clauses; + (* Extract the trait self clause *) + let params = + concat + [ + dtype_params; + map + (fun (cg : const_generic_var) -> + ctx_get_const_generic_var cg.index ctx) + dcgs; + map + (fun c -> ctx_get_local_trait_clause c.clause_id ctx) + dtrait_clauses; + ] + in + extract_trait_self_clause insert_req_space ctx fmt trait_decl params; + (* Extract the method generics *) + print_generics as_implicits mtype_params mcgs mtrait_clauses) + (** Extract a type declaration. This function is for all type declarations and all backends **at the exception** @@ -1551,19 +1895,15 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) *) let is_opaque = type_kind = None in let is_opaque_coq = !backend = Coq && is_opaque in - let use_forall = - is_opaque_coq && (def.type_params <> [] || def.const_generic_params <> []) - in + let use_forall = is_opaque_coq && def.generics <> empty_generic_params in (* Retrieve the definition name *) let with_opaque_pre = false in let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in (* Add the type and const generic params - note that we need those bindings only for the * body translation (they are not top-level) *) - let ctx_body, type_params, cg_params = - ctx_add_type_const_generic_params def.type_params def.const_generic_params - ctx + let ctx_body, type_params, cg_params, trait_clauses = + ctx_add_generic_params def.generics ctx in - let ty_cg_params = List.append type_params cg_params in (* Add a break before *) if !backend <> HOL4 || not (decl_is_first_from_group kind) then F.pp_print_break fmt 0 0; @@ -1583,40 +1923,12 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (match qualif with | Some qualif -> F.pp_print_string fmt (qualif ^ " " ^ def_name) | None -> F.pp_print_string fmt def_name); - (* HOL4 doesn't support const generics *) - assert (cg_params = [] || !backend <> HOL4); - (* Print the type/const generic parameters *) - if ty_cg_params <> [] && !backend <> HOL4 then ( - if use_forall then ( - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - F.pp_print_string fmt "forall"); - (* Print the type parameters *) - if type_params <> [] then ( - F.pp_print_space fmt (); - F.pp_print_string fmt "("; - List.iter - (fun s -> - F.pp_print_string fmt s; - F.pp_print_space fmt ()) - type_params; - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - F.pp_print_string fmt (type_keyword () ^ ")")); - (* Print the const generic parameters *) - List.iter - (fun (var : const_generic_var) -> - F.pp_print_space fmt (); - F.pp_print_string fmt "("; - let n = ctx_get_const_generic_var var.index ctx in - F.pp_print_string fmt n; - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - extract_literal_type ctx fmt var.ty; - F.pp_print_string fmt ")") - def.const_generic_params); + (* HOL4 doesn't support const generics, and type definitions in HOL4 don't + support trait clauses *) + assert ((cg_params = [] && trait_clauses = []) || !backend <> HOL4); + (* Print the generic parameters *) + extract_generic_params ctx_body fmt type_decl_group ~use_forall def.generics + type_params cg_params trait_clauses; (* Print the "=" if we extract the body*) if extract_body then ( F.pp_print_space fmt (); @@ -1676,9 +1988,12 @@ let extract_type_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter) let with_opaque_pre = false in let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in (* Generic parameters are unsupported *) - assert (def.const_generic_params = []); + assert (def.generics.const_generics = []); + (* Trait clauses on type definitions are unsupported *) + assert (def.generics.trait_clauses = []); + (* Types *) (* Count the number of parameters *) - let num_params = List.length def.type_params in + let num_params = List.length def.generics.types in (* Generate the declaration *) F.pp_print_space fmt (); F.pp_print_string fmt @@ -1699,8 +2014,7 @@ let extract_type_decl_hol4_empty_record (ctx : extraction_ctx) let with_opaque_pre = false in let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in (* Sanity check *) - assert (def.type_params = []); - assert (def.const_generic_params = []); + assert (def.generics = empty_generic_params); (* Generate the declaration *) F.pp_print_space fmt (); F.pp_print_string fmt ("Type " ^ def_name ^ " = “: unit”"); @@ -1740,13 +2054,12 @@ let extract_type_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter) (kind : decl_kind) (decl : type_decl) : unit = assert (!backend = Coq); (* Generating the [Arguments] instructions is useful only if there are type parameters *) - if decl.type_params = [] && decl.const_generic_params = [] then () + if decl.generics.types = [] && decl.generics.const_generics = [] then () else (* Add the type params - note that we need those bindings only for the * body translation (they are not top-level) *) - let _ctx_body, type_params, cg_params = - ctx_add_type_const_generic_params decl.type_params - decl.const_generic_params ctx + let _ctx_body, type_params, cg_params, trait_clauses = + ctx_add_generic_params decl.generics ctx in (* Auxiliary function to extract an [Arguments Cons {T} _ _.] instruction *) let extract_arguments_info (cons_name : string) (fields : 'a list) : unit = @@ -1754,27 +2067,22 @@ let extract_type_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_break fmt 0 0; (* Open a box *) F.pp_open_hovbox fmt ctx.indent_incr; - (* Small utility *) - let print_vars () = - List.iter - (fun (var : string) -> - F.pp_print_space fmt (); - F.pp_print_string fmt ("{" ^ var ^ "}")) - (List.append type_params cg_params) - in - let print_fields () = - List.iter - (fun _ -> - F.pp_print_space fmt (); - F.pp_print_string fmt "_") - fields - in F.pp_print_break fmt 0 0; F.pp_print_string fmt "Arguments"; F.pp_print_space fmt (); F.pp_print_string fmt cons_name; - print_vars (); - print_fields (); + (* Print the type/const params and the trait clauses (`{T}`) *) + List.iter + (fun (var : string) -> + F.pp_print_space fmt (); + F.pp_print_string fmt ("{" ^ var ^ "}")) + (List.concat [ type_params; cg_params; trait_clauses ]); + (* Print the fields (`_`) *) + List.iter + (fun _ -> + F.pp_print_space fmt (); + F.pp_print_string fmt "_") + fields; F.pp_print_string fmt "."; (* Close the box *) @@ -1827,9 +2135,8 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx) let is_rec = decl_is_from_rec_group kind in if is_rec then (* Add the type params *) - let ctx, type_params, cg_params = - ctx_add_type_const_generic_params decl.type_params - decl.const_generic_params ctx + let ctx, type_params, cg_params, trait_clauses = + ctx_add_generic_params decl.generics ctx in let ctx, record_var = ctx_add_var "x" (VarId.of_int 0) ctx in let ctx, field_var = ctx_add_var "x" (VarId.of_int 1) ctx in @@ -1850,33 +2157,12 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx) F.pp_print_space fmt (); let field_name = ctx_get_field (AdtId decl.def_id) field_id ctx in F.pp_print_string fmt field_name; - F.pp_print_space fmt (); - (* Print the type parameters *) - if type_params <> [] then ( - F.pp_print_string fmt "{"; - List.iter - (fun p -> - F.pp_print_string fmt p; - F.pp_print_space fmt ()) - type_params; - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - F.pp_print_string fmt "Type}"; - F.pp_print_space fmt ()); - (* Print the const generic parameters *) - if cg_params <> [] then - List.iter - (fun (v : const_generic_var) -> - F.pp_print_string fmt "{"; - let n = ctx_get_const_generic_var v.index ctx in - F.pp_print_string fmt n; - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - extract_literal_type ctx fmt v.ty; - F.pp_print_string fmt "}"; - F.pp_print_space fmt ()) - decl.const_generic_params; + (* Print the generics *) + let as_implicits = true in + extract_generic_params ctx fmt TypeDeclId.Set.empty ~as_implicits + decl.generics type_params cg_params trait_clauses; (* Print the record parameter *) + F.pp_print_space fmt (); F.pp_print_string fmt "("; F.pp_print_string fmt record_var; F.pp_print_space fmt (); @@ -2067,10 +2353,11 @@ let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx) (** Compute the names for all the pure functions generated from a rust function (forward function and backward functions). *) -let extract_fun_decl_register_names (ctx : extraction_ctx) (keep_fwd : bool) +let extract_fun_decl_register_names (ctx : extraction_ctx) (has_decreases_clause : fun_decl -> bool) (def : pure_fun_translation) : extraction_ctx = - let (fwd, loop_fwds), back_ls = def in + let fwd = def.fwd in + let backs = def.backs in (* Register the decrease clauses, if necessary *) let register_decreases ctx def = if has_decreases_clause def then @@ -2083,22 +2370,19 @@ let extract_fun_decl_register_names (ctx : extraction_ctx) (keep_fwd : bool) | Lean -> ctx_add_decreases_proof def ctx else ctx in - let ctx = List.fold_left register_decreases ctx (fwd :: loop_fwds) in - let register_fun ctx f = ctx_add_fun_decl (keep_fwd, def) f ctx in + let ctx = List.fold_left register_decreases ctx (fwd.f :: fwd.loops) in + let register_fun ctx f = ctx_add_fun_decl def f ctx in let register_funs ctx fl = List.fold_left register_fun ctx fl in - (* Register the forward functions' names *) - let ctx = register_funs ctx (fwd :: loop_fwds) in - (* Register the backward functions' names *) + (* Register the names of the forward functions *) let ctx = - List.fold_left - (fun ctx (back, loop_backs) -> - let ctx = register_fun ctx back in - register_funs ctx loop_backs) - ctx back_ls + if def.keep_fwd then register_funs ctx (fwd.f :: fwd.loops) else ctx in - - (* Return *) - ctx + (* Register the names of the backward functions *) + List.fold_left + (fun ctx { f = back; loops = loop_backs } -> + let ctx = register_fun ctx back in + register_funs ctx loop_backs) + ctx backs (** Simply add the global name to the context. *) let extract_global_decl_register_names (ctx : extraction_ctx) @@ -2122,11 +2406,11 @@ let extract_adt_g_value (inside : bool) (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) : extraction_ctx = match ty with - | Adt (Tuple, type_args, cg_args) -> + | Adt (Tuple, generics) -> (* Tuple *) (* For now, we only support fully applied tuple constructors *) - assert (List.length type_args = List.length field_values); - assert (cg_args = []); + assert (List.length generics.types = List.length field_values); + assert (generics.const_generics = [] && generics.trait_refs = []); (* This is very annoying: in Coq, we can't write [()] for the value of type [unit], we have to write [tt]. *) if !backend = Coq && field_values = [] then ( @@ -2144,7 +2428,7 @@ let extract_adt_g_value in F.pp_print_string fmt ")"; ctx) - | Adt (adt_id, _, _) -> + | Adt (adt_id, _) -> (* "Regular" ADT *) (* If we are generating a pattern for a let-binding and we target Lean, @@ -2249,6 +2533,9 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) | Var var_id -> let var_name = ctx_get_var var_id ctx in F.pp_print_string fmt var_name + | CVar var_id -> + let var_name = ctx_get_const_generic_var var_id ctx in + F.pp_print_string fmt var_name | Const cv -> ctx.fmt.extract_literal fmt inside cv | App _ -> let app, args = destruct_apps e in @@ -2279,14 +2566,23 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (* Top-level qualifier *) match qualif.id with | FunOrOp fun_id -> - extract_function_call ctx fmt inside fun_id qualif.type_args - qualif.const_generic_args args + extract_function_call ctx fmt inside fun_id qualif.generics args | Global global_id -> extract_global ctx fmt global_id | AdtCons adt_cons_id -> - extract_adt_cons ctx fmt inside adt_cons_id qualif.type_args - qualif.const_generic_args args + extract_adt_cons ctx fmt inside adt_cons_id qualif.generics args | Proj proj -> - extract_field_projector ctx fmt inside app proj qualif.type_args args) + extract_field_projector ctx fmt inside app proj qualif.generics args + | TraitConst (trait_ref, generics, const_name) -> + let use_brackets = generics <> empty_generic_args in + if use_brackets then F.pp_print_string fmt "("; + extract_trait_ref ctx fmt TypeDeclId.Set.empty false trait_ref; + extract_generic_args ctx fmt TypeDeclId.Set.empty generics; + let name = + ctx_get_trait_const trait_ref.trait_decl_ref.trait_decl_id + const_name ctx + in + if use_brackets then F.pp_print_string fmt ")"; + F.pp_print_string fmt ("." ^ name)) | _ -> (* "Regular" expression *) (* Open parentheses *) @@ -2309,8 +2605,8 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (** Subcase of the app case: function call *) and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) - (inside : bool) (fid : fun_or_op_id) (type_args : ty list) - (cg_args : const_generic list) (args : texpression list) : unit = + (inside : bool) (fid : fun_or_op_id) (generics : generic_args) + (args : texpression list) : unit = match (fid, args) with | Unop unop, [ arg ] -> (* A unop can have *at most* one argument (the result can't be a function!). @@ -2329,22 +2625,97 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) F.pp_open_hovbox fmt ctx.indent_incr; (* Print the function name *) let with_opaque_pre = ctx.use_opaque_pre in - let fun_name = ctx_get_function with_opaque_pre fun_id ctx in - F.pp_print_string fmt fun_name; - (* Sanity check: HOL4 doesn't support const generics *) - assert (cg_args = [] || !backend <> HOL4); - (* Print the type parameters, if the backend is not HOL4 *) - if !backend <> HOL4 then ( - List.iter - (fun ty -> - F.pp_print_space fmt (); - extract_ty ctx fmt TypeDeclId.Set.empty true ty) - type_args; - List.iter - (fun cg -> + (* For the function name: the id is not the same depending on whether + we call a trait method and a "regular" function (remark: trait + method *implementations* are considered as regular functions here; + only calls to method of traits which are parameterized in a where + clause have a special treatment. + + Remark: the reason why trait method declarations have a special + treatment is that, as traits are extracted to records, we may + allow collisions between trait item names and some other names, + while we do not allow collisions between function names. + + # Impl trait refs: + ================== + When the trait ref refers to an impl, in + [InterpreterStatement.eval_transparent_function_call_symbolic] we + replace the call to the trait impl method to a call to the function + which implements the trait method (that is, we "forget" that we + called a trait method, and treat it as a regular function call). + + # Provided trait methods: + ========================= + Calls to provided trait methods also have a special treatment. + For now, we do not allow overriding provided trait methods (methods + for which a default implementation is provided in the trait declaration). + Whenever we translate a provided trait method, we translate it once as + a function which takes a trait ref as input. We have to handle this + case below. + + With an example, if in Rust we write: + {[ + fn Foo { + fn f(&self) -> u32; // Required + fn ret_true(&self) -> bool { true } // Provided + } + ]} + + We generate: + {[ + structure Foo (Self : Type) = { + f : Self -> result u32 + } + + let ret_true (Self : Type) (self_clause : Foo Self) (self : Self) : result bool = + true + ]} + *) + (match fun_id with + | FromLlbc + (TraitMethod (trait_ref, method_name, _fun_decl_id), lp_id, rg_id) -> + (* We have to check whether the trait method is required or provided *) + let trait_decl_id = trait_ref.trait_decl_ref.trait_decl_id in + let trait_decl = + TraitDeclId.Map.find trait_decl_id ctx.trans_trait_decls + in + let method_id = + PureUtils.trait_decl_get_method trait_decl method_name + in + + if not method_id.is_provided then ( + (* Required method *) + assert (lp_id = None); + extract_trait_ref ctx fmt TypeDeclId.Set.empty true trait_ref; + let fun_name = + ctx_get_trait_method trait_ref.trait_decl_ref.trait_decl_id + method_name rg_id ctx + in + F.pp_print_string fmt ("." ^ fun_name)) + else + (* Provided method: we see it as a regular function call, and use + the function name *) + let fun_id = + FromLlbc (FunId (A.Regular method_id.id), lp_id, rg_id) + in + let fun_name = ctx_get_function with_opaque_pre fun_id ctx in + F.pp_print_string fmt fun_name; + + (* Note that we do not need to print the generics for the trait + declaration: they are always implicit as they can be deduced + from the trait self clause. + + Print the trait ref (to instantate the self clause) *) F.pp_print_space fmt (); - extract_const_generic ctx fmt true cg) - cg_args); + extract_trait_ref ctx fmt TypeDeclId.Set.empty true trait_ref + | _ -> + let fun_name = ctx_get_function with_opaque_pre fun_id ctx in + F.pp_print_string fmt fun_name); + + (* Sanity check: HOL4 doesn't support const generics *) + assert (generics.const_generics = [] || !backend <> HOL4); + (* Print the generics *) + extract_generic_args ctx fmt TypeDeclId.Set.empty generics; (* Print the arguments *) List.iter (fun ve -> @@ -2366,9 +2737,9 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) (** Subcase of the app case: ADT constructor *) and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) - (adt_cons : adt_cons_id) (type_args : ty list) - (cg_args : const_generic list) (args : texpression list) : unit = - let e_ty = Adt (adt_cons.adt_id, type_args, cg_args) in + (adt_cons : adt_cons_id) (generics : generic_args) (args : texpression list) + : unit = + let e_ty = Adt (adt_cons.adt_id, generics) in let is_single_pat = false in let _ = extract_adt_g_value @@ -2382,7 +2753,7 @@ and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (** Subcase of the app case: ADT field projector. *) and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (original_app : texpression) (proj : projection) - (_proj_type_params : ty list) (args : texpression list) : unit = + (_generics : generic_args) (args : texpression list) : unit = (* We isolate the first argument (if there is), in order to pretty print the * projection ([x.field] instead of [MkAdt?.field x] *) match args with @@ -2734,9 +3105,7 @@ and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter) let extract_as_unit = match (!backend, supd.struct_id) with | HOL4, AdtId adt_id -> - let d = - TypeDeclId.Map.find adt_id ctx.trans_ctx.type_context.type_decls - in + let d = TypeDeclId.Map.find adt_id ctx.trans_ctx.type_ctx.type_decls in d.kind = Struct [] | _ -> false in @@ -2835,17 +3204,17 @@ and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter) F.pp_open_hvbox fmt ctx.indent_incr; let need_paren = inside in if need_paren then F.pp_print_string fmt "("; - (* Open the box for `Array.mk T N [` *) + (* Open the box for `Array.replicate T N [` *) F.pp_open_hovbox fmt ctx.indent_incr; (* Print the array constructor *) let cs = ctx_get_struct false (Assumed Array) ctx in F.pp_print_string fmt cs; (* Print the parameters *) - let _, tys, cgs = ty_as_adt e_ty in - let ty = Collections.List.to_cons_nil tys in + let _, generics = ty_as_adt e_ty in + let ty = Collections.List.to_cons_nil generics.types in F.pp_print_space fmt (); extract_ty ctx fmt TypeDeclId.Set.empty true ty; - let cg = Collections.List.to_cons_nil cgs in + let cg = Collections.List.to_cons_nil generics.const_generics in F.pp_print_space fmt (); extract_const_generic ctx fmt true cg; F.pp_print_space fmt (); @@ -2872,17 +3241,15 @@ and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter) F.pp_close_box fmt () | _ -> raise (Failure "Unreachable") -(** Insert a space, if necessary *) -let insert_req_space (fmt : F.formatter) (space : bool ref) : unit = - if !space then space := false else F.pp_print_space fmt () - (** A small utility to print the parameters of a function signature. We return two contexts: - - the context augmented with bindings for the type parameters - - the context augmented with bindings for the type parameters *and* + - the context augmented with bindings for the generics + - the context augmented with bindings for the generics *and* bindings for the input values + We also return names for the type parameters, const generics, etc. + TODO: do we really need the first one? We should probably always use the second one. It comes from the fact that when we print the input values for the @@ -2890,57 +3257,40 @@ let insert_req_space (fmt : F.formatter) (space : bool ref) : unit = patterns, not the variables). We should figure a cleaner way. *) let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx) - (fmt : F.formatter) (def : fun_decl) : extraction_ctx * extraction_ctx = + (fmt : F.formatter) (def : fun_decl) : + extraction_ctx * extraction_ctx * string list = + (* First, add the associated types and constants if the function is a method + in a trait declaration. + + About the order: we want to make sure the names are reserved for + those (variable names might collide with them but it is ok, we will add + suffixes to the variables). + + TODO: micro-pass to update what happens when calling trait provided + functions. + *) + let ctx, trait_decl = + match def.kind with + | TraitMethodProvided (decl_id, _) -> + let trait_decl = T.TraitDeclId.Map.find decl_id ctx.trans_trait_decls in + let ctx, _ = ctx_add_trait_self_clause ctx in + let ctx = { ctx with is_provided_method = true } in + (ctx, Some trait_decl) + | _ -> (ctx, None) + in (* Add the type parameters - note that we need those bindings only for the * body translation (they are not top-level) *) - let ctx, type_params, cg_params = - ctx_add_type_const_generic_params def.signature.type_params - def.signature.const_generic_params ctx + let ctx, type_params, cg_params, trait_clauses = + ctx_add_generic_params def.signature.generics ctx in - (* Print the parameters - rem.: we should have filtered the functions - * with no input parameters *) - (* The type parameters. - - Note that in HOL4 we don't print the type parameters. - *) - if (type_params <> [] || cg_params <> []) && !backend <> HOL4 then ( - (* Open a box for the type and const generic parameters *) - F.pp_open_hovbox fmt 0; - (* The type parameters *) - if type_params <> [] then ( - insert_req_space fmt space; - F.pp_print_string fmt "("; - List.iter - (fun (p : type_var) -> - let pname = ctx_get_type_var p.index ctx in - F.pp_print_string fmt pname; - F.pp_print_space fmt ()) - def.signature.type_params; - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - let type_keyword = - match !backend with - | FStar -> "Type0" - | Coq | Lean -> "Type" - | HOL4 -> raise (Failure "Unreachable") - in - F.pp_print_string fmt (type_keyword ^ ")")); - (* The const generic parameters *) - if cg_params <> [] then - List.iter - (fun (p : const_generic_var) -> - let pname = ctx_get_const_generic_var p.index ctx in - insert_req_space fmt space; - F.pp_print_string fmt "("; - F.pp_print_string fmt pname; - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - extract_literal_type ctx fmt p.ty; - F.pp_print_string fmt ")") - def.signature.const_generic_params; - (* Close the box for the type parameters *) - F.pp_close_box fmt ()); + (* Print the generics *) + (* Open a box for the generics *) + F.pp_open_hovbox fmt 0; + (let space = Some space in + extract_generic_params ctx fmt TypeDeclId.Set.empty ~space ~trait_decl + def.signature.generics type_params cg_params trait_clauses); + (* Close the box for the generics *) + F.pp_close_box fmt (); (* The input parameters - note that doing this adds bindings to the context *) let ctx_body = match def.body with @@ -2963,7 +3313,7 @@ let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx) ctx) ctx body.inputs_lvs in - (ctx, ctx_body) + (ctx, ctx_body, List.concat [ type_params; cg_params; trait_clauses ]) (** A small utility to print the types of the input parameters in the form: [u32 -> list u32 -> ...] @@ -2982,6 +3332,11 @@ let extract_fun_input_parameters_types (ctx : extraction_ctx) in List.iter extract_param def.signature.inputs +let extract_fun_inputs_output_parameters_types (ctx : extraction_ctx) + (fmt : F.formatter) (def : fun_decl) : unit = + extract_fun_input_parameters_types ctx fmt def; + extract_ty ctx fmt TypeDeclId.Set.empty false def.signature.output + let assert_backend_supports_decreases_clauses () = match !backend with | FStar | Lean -> () @@ -3032,7 +3387,7 @@ let extract_template_fstar_decreases_clause (ctx : extraction_ctx) F.pp_print_space fmt (); (* Extract the parameters *) let space = ref true in - let _, _ = extract_fun_parameters space ctx fmt def in + let _, _, _ = extract_fun_parameters space ctx fmt def in insert_req_space fmt space; F.pp_print_string fmt ":"; (* Print the signature *) @@ -3094,7 +3449,7 @@ let extract_template_lean_termination_and_decreasing (ctx : extraction_ctx) F.pp_print_space fmt (); (* Extract the parameters *) let space = ref true in - let _, ctx_body = extract_fun_parameters space ctx fmt def in + let _, ctx_body, _ = extract_fun_parameters space ctx fmt def in (* Print the ":=" *) F.pp_print_space fmt (); F.pp_print_string fmt ":="; @@ -3164,7 +3519,7 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter) (def : fun_decl) : unit = let { keep_fwd; num_backs } = PureUtils.RegularFunIdMap.find - (A.Regular def.def_id, def.loop_id, def.back_id) + (Pure.FunId (A.Regular def.def_id), def.loop_id, def.back_id) ctx.fun_name_info in let comment_pre = "[" ^ Print.fun_name_to_string def.basename ^ "]: " in @@ -3234,9 +3589,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) *) let is_opaque_coq = !backend = Coq && is_opaque in let use_forall = - is_opaque_coq - && (def.signature.type_params <> [] - || def.signature.const_generic_params <> []) + is_opaque_coq && def.signature.generics <> empty_generic_params in (* Print the qualifier ("assume", etc.). @@ -3262,7 +3615,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (* Open a box for "(PARAMS) :" *) F.pp_open_hovbox fmt 0; let space = ref true in - let ctx, ctx_body = extract_fun_parameters space ctx fmt def in + let ctx, ctx_body, all_params = extract_fun_parameters space ctx fmt def in (* Print the return type - note that we have to be careful when * printing the input values for the decrease clause, because * it introduces bindings in the context... We thus "forget" @@ -3310,20 +3663,13 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (* The name of the decrease clause *) let decr_name = ctx_get_termination_measure def.def_id def.loop_id ctx in F.pp_print_string fmt decr_name; - (* Print the type/const generic parameters - TODO: we do this many + (* Print the generic parameters - TODO: we do this many times, we should have a helper to factor it out *) List.iter - (fun (p : type_var) -> - let pname = ctx_get_type_var p.index ctx in + (fun (name : string) -> F.pp_print_space fmt (); - F.pp_print_string fmt pname) - def.signature.type_params; - List.iter - (fun (p : const_generic_var) -> - let pname = ctx_get_const_generic_var p.index ctx in - F.pp_print_space fmt (); - F.pp_print_string fmt pname) - def.signature.const_generic_params; + F.pp_print_string fmt name) + all_params; (* Print the input values: we have to be careful here to print * only the input values which are in common with the *forward* * function (the additional input values "given back" to the @@ -3410,19 +3756,12 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (* Open the box for [DECREASES] *) F.pp_open_hovbox fmt ctx.indent_incr; F.pp_print_string fmt terminates_name; - (* Print the type/const generic params - TODO: factor out *) + (* Print the generic params - TODO: factor out *) List.iter - (fun (p : type_var) -> - let pname = ctx_get_type_var p.index ctx in + (fun (name : string) -> F.pp_print_space fmt (); - F.pp_print_string fmt pname) - def.signature.type_params; - List.iter - (fun (p : const_generic_var) -> - let pname = ctx_get_const_generic_var p.index ctx in - F.pp_print_space fmt (); - F.pp_print_string fmt pname) - def.signature.const_generic_params; + F.pp_print_string fmt name) + all_params; (* Print the variables *) List.iter (fun v -> @@ -3480,13 +3819,10 @@ let extract_fun_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter) ctx_get_local_function with_opaque_pre def.def_id def.loop_id def.back_id ctx in - assert (def.signature.const_generic_params = []); + assert (def.signature.generics.const_generics = []); (* Add the type/const gen parameters - note that we need those bindings only for the generation of the type (they are not top-level) *) - let ctx, _, _ = - ctx_add_type_const_generic_params def.signature.type_params - def.signature.const_generic_params ctx - in + let ctx, _, _, _ = ctx_add_generic_params def.signature.generics ctx in (* Add breaks to insert new lines between definitions *) F.pp_print_break fmt 0 0; (* Open a box for the whole definition *) @@ -3635,8 +3971,13 @@ let extract_global_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter) (* Print the type *) F.pp_open_hovbox fmt 0; extract_ty ctx fmt TypeDeclId.Set.empty false ty; + (* Close the definition *) + F.pp_print_string fmt ")"; F.pp_close_box fmt (); - (* Close the definition boxe *) F.pp_close_box fmt () + (* Close the definition box *) + F.pp_close_box fmt (); + (* Add a line *) + F.pp_print_space fmt () (** Extract a global declaration. @@ -3662,10 +4003,9 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) (global : A.global_decl) (body : fun_decl) (interface : bool) : unit = assert body.is_global_decl_body; assert (Option.is_none body.back_id); - assert (List.length body.signature.inputs = 0); + assert (body.signature.inputs = []); assert (List.length body.signature.doutputs = 1); - assert (List.length body.signature.type_params = 0); - assert (List.length body.signature.const_generic_params = 0); + assert (body.signature.generics = empty_generic_params); (* Add a break then the name of the corresponding LLBC declaration *) F.pp_print_break fmt 0 0; @@ -3676,7 +4016,7 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) let decl_name = ctx_get_global with_opaque_pre global.def_id ctx in let body_name = ctx_get_function with_opaque_pre - (FromLlbc (Regular global.body_id, None, None)) + (FromLlbc (Pure.FunId (Regular global.body_id), None, None)) ctx in @@ -3713,6 +4053,433 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) (* Add a break to insert lines between declarations *) F.pp_print_break fmt 0 0 +(** Register the names for one trait method item *) +let extract_trait_decl_method_register_names (ctx : extraction_ctx) + (trait_decl : trait_decl) (name : string) (id : fun_decl_id) : + extraction_ctx = + (* We add one field per required forward/backward function *) + let trans = A.FunDeclId.Map.find id ctx.trans_funs in + + let register_fun ctx f = ctx_add_trait_method trait_decl name f.f ctx in + (* Register the names *) + let funs = trans.fwd :: trans.backs in + List.fold_left register_fun ctx funs + +(** Similar to {!extract_type_decl_register_names} *) +let extract_trait_decl_register_names (ctx : extraction_ctx) + (trait_decl : trait_decl) : extraction_ctx = + let { + def_id = _; + name = _; + generics; + preds = _; + all_trait_clauses = _; + consts; + types; + required_methods; + provided_methods = _; + } = + trait_decl + in + let ctx = ctx_add_trait_decl trait_decl ctx in + (* Parent clauses *) + let ctx = + List.fold_left + (fun ctx clause -> ctx_add_trait_parent_clause trait_decl clause ctx) + ctx generics.trait_clauses + in + (* Constants *) + let ctx = + List.fold_left + (fun ctx (name, (_, _)) -> ctx_add_trait_const trait_decl name ctx) + ctx consts + in + (* Types *) + let ctx = + List.fold_left + (fun ctx (name, (clauses, _)) -> + let ctx = ctx_add_trait_type trait_decl name ctx in + List.fold_left + (fun ctx clause -> + ctx_add_trait_type_clause trait_decl name clause ctx) + ctx clauses) + ctx types + in + (* Required methods *) + List.fold_left + (fun ctx (name, id) -> + (* We add one field per required forward/backward function *) + extract_trait_decl_method_register_names ctx trait_decl name id) + ctx required_methods + +(** Similar to {!extract_type_decl_register_names} *) +let extract_trait_impl_register_names (ctx : extraction_ctx) + (trait_impl : trait_impl) : extraction_ctx = + (* For now we do not support overriding provided methods *) + assert (trait_impl.provided_methods = []); + (* Everything is taken care of by {!extract_trait_decl_register_names} *but* + the name of the implementation itself *) + ctx_add_trait_impl trait_impl ctx + +(** Small helper. + + The type `ty` is to be understood in a very general sense. + *) +let extract_trait_item (ctx : extraction_ctx) (fmt : F.formatter) + (item_name : string) (separator : string) (ty : unit -> unit) : unit = + F.pp_print_space fmt (); + F.pp_open_hovbox fmt ctx.indent_incr; + F.pp_print_string fmt item_name; + F.pp_print_space fmt (); + (* ":" or "=" *) + F.pp_print_string fmt separator; + ty (); + (match !Config.backend with Lean -> () | _ -> F.pp_print_string fmt ";"); + F.pp_close_box fmt () + +let extract_trait_decl_item (ctx : extraction_ctx) (fmt : F.formatter) + (item_name : string) (ty : unit -> unit) : unit = + extract_trait_item ctx fmt item_name ":" ty + +let extract_trait_impl_item (ctx : extraction_ctx) (fmt : F.formatter) + (item_name : string) (ty : unit -> unit) : unit = + let assign = match !Config.backend with Lean -> ":=" | _ -> "=" in + extract_trait_item ctx fmt item_name assign ty + +(** Small helper - TODO: move *) +let generic_params_drop_prefix (g1 : generic_params) (g2 : generic_params) : + generic_params = + let open Collections.List in + let types = drop (length g1.types) g2.types in + let const_generics = drop (length g1.const_generics) g2.const_generics in + let trait_clauses = drop (length g1.trait_clauses) g2.trait_clauses in + { types; const_generics; trait_clauses } + +(** Small helper. + + Extract the items for a method in a trait decl. + *) +let extract_trait_decl_method_items (ctx : extraction_ctx) (fmt : F.formatter) + (decl : trait_decl) (item_name : string) (id : fun_decl_id) : unit = + (* Lookup the definition *) + let trans = A.FunDeclId.Map.find id ctx.trans_funs in + (* Extract the items *) + let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in + let extract_method (f : fun_and_loops) = + let f = f.f in + let fun_name = ctx_get_trait_method decl.def_id item_name f.back_id ctx in + let ty () = + (* Extract the generics *) + (* We need to add the generics specific to the method, by removing those + which actually apply to the trait decl *) + let generics = + generic_params_drop_prefix decl.generics f.signature.generics + in + let ctx, type_params, cg_params, trait_clauses = + ctx_add_generic_params generics ctx + in + let use_forall = generics <> empty_generic_params in + let use_forall_use_sep = false in + extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall + ~use_forall_use_sep generics type_params cg_params trait_clauses; + if use_forall then F.pp_print_string fmt ","; + (* Extract the inputs and output *) + F.pp_print_space fmt (); + extract_fun_inputs_output_parameters_types ctx fmt f + in + extract_trait_decl_item ctx fmt fun_name ty + in + List.iter extract_method funs + +(** Extract a trait declaration *) +let extract_trait_decl (ctx : extraction_ctx) (fmt : F.formatter) + (decl : trait_decl) : unit = + (* Retrieve the trait name *) + let with_opaque_pre = false in + let decl_name = ctx_get_trait_decl with_opaque_pre decl.def_id ctx in + (* Add a break before *) + F.pp_print_break fmt 0 0; + (* Print a comment to link the extracted type to its original rust definition *) + extract_comment fmt + [ "Trait declaration: [" ^ Print.name_to_string decl.name ^ "]" ]; + F.pp_print_break fmt 0 0; + (* Open two outer boxes for the definition, so that whenever possible it gets printed on + one line and indents are correct. + + There is just an exception with Lean: in this backend, line breaks are important + for the parsing, so we always open a vertical box. + *) + if !Config.backend = Lean then F.pp_open_vbox fmt ctx.indent_incr + else ( + F.pp_open_hvbox fmt 0; + F.pp_open_hvbox fmt ctx.indent_incr); + + (* `struct Trait (....) =` *) + (* Open the box for the name + generics *) + F.pp_open_hovbox fmt ctx.indent_incr; + let qualif = + Option.get (ctx.fmt.type_decl_kind_to_qualif SingleNonRec (Some Struct)) + in + F.pp_print_string fmt qualif; + F.pp_print_space fmt (); + F.pp_print_string fmt decl_name; + (* Print the generics *) + (* We ignore the trait clauses, which we extract as *fields* *) + let generics = { decl.generics with trait_clauses = [] } in + (* Add the type and const generic params - note that we need those bindings only for the + * body translation (they are not top-level) *) + let ctx, type_params, cg_params, trait_clauses = + ctx_add_generic_params generics ctx + in + extract_generic_params ctx fmt TypeDeclId.Set.empty generics type_params + cg_params trait_clauses; + + F.pp_print_space fmt (); + (match !backend with + | Lean -> F.pp_print_string fmt "where" + | _ -> F.pp_print_string fmt "{"); + + (* Close the box for the name + generics *) + F.pp_close_box fmt (); + + (* + * Extract the items + *) + + (* The parent clauses *) + List.iter + (fun clause -> + let item_name = + ctx_get_trait_parent_clause decl.def_id clause.clause_id ctx + in + let ty () = + extract_trait_clause_type ctx fmt TypeDeclId.Set.empty clause + in + extract_trait_decl_item ctx fmt item_name ty) + decl.generics.trait_clauses; + + (* The constants *) + List.iter + (fun (name, (ty, _)) -> + let item_name = ctx_get_trait_const decl.def_id name ctx in + let ty () = + let inside = false in + F.pp_print_space fmt (); + extract_ty ctx fmt TypeDeclId.Set.empty inside ty + in + extract_trait_decl_item ctx fmt item_name ty) + decl.consts; + + (* The types *) + List.iter + (fun (name, (clauses, _)) -> + (* Extract the type *) + let item_name = ctx_get_trait_type decl.def_id name ctx in + let ty () = + F.pp_print_space fmt (); + F.pp_print_string fmt (type_keyword ()) + in + extract_trait_decl_item ctx fmt item_name ty; + (* Extract the clauses *) + List.iter + (fun clause -> + let item_name = + ctx_get_trait_item_clause decl.def_id name clause.clause_id ctx + in + let ty () = + F.pp_print_space fmt (); + extract_trait_clause_type ctx fmt TypeDeclId.Set.empty clause + in + extract_trait_decl_item ctx fmt item_name ty) + clauses) + decl.types; + + (* The required methods *) + List.iter + (fun (name, id) -> extract_trait_decl_method_items ctx fmt decl name id) + decl.required_methods; + + (* Close the brackets *) + (match !Config.backend with + | Lean -> () + | _ -> + F.pp_print_space fmt (); + F.pp_print_string fmt "}"); + + (* Close the outer boxes for the definition *) + if !Config.backend <> Lean then F.pp_close_box fmt (); + F.pp_close_box fmt (); + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0 + +(** Small helper. + + Extract the items for a method in a trait impl. + *) +let extract_trait_impl_method_items (ctx : extraction_ctx) (fmt : F.formatter) + (impl : trait_impl) (item_name : string) (id : fun_decl_id) + (impl_generics : string list * string list * string list) : unit = + let trait_decl_id = impl.impl_trait.trait_decl_id in + (* Lookup the definition *) + let trans = A.FunDeclId.Map.find id ctx.trans_funs in + (* Extract the items *) + let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in + let extract_method (f : fun_and_loops) = + let f = f.f in + let fun_name = ctx_get_trait_method trait_decl_id item_name f.back_id ctx in + let ty () = + (* Extract the generics - we need to quantify over the generics which + are specific to the method, and call it will all the generics + (trait impl + method generics) *) + let f_generics = + generic_params_drop_prefix impl.generics f.signature.generics + in + let ctx, f_tys, f_cgs, f_tcs = ctx_add_generic_params f_generics ctx in + let use_forall = f_generics <> empty_generic_params in + extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall f_generics + f_tys f_cgs f_tcs; + if use_forall then F.pp_print_string fmt ","; + (* Extract the function call *) + F.pp_print_space fmt (); + let id = ctx_get_local_function false f.def_id None f.back_id ctx in + F.pp_print_string fmt id; + let all_generics = + let i_tys, i_cgs, i_tcs = impl_generics in + List.concat [ i_tys; f_tys; i_cgs; f_cgs; i_tcs; f_tcs ] + in + List.iter + (fun p -> + F.pp_print_space fmt (); + F.pp_print_string fmt p) + all_generics + in + extract_trait_impl_item ctx fmt fun_name ty + in + List.iter extract_method funs + +(** Extract a trait implementation *) +let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter) + (impl : trait_impl) : unit = + (* Retrieve the impl name *) + let with_opaque_pre = false in + let impl_name = ctx_get_trait_impl with_opaque_pre impl.def_id ctx in + (* Add a break before *) + F.pp_print_break fmt 0 0; + (* Print a comment to link the extracted type to its original rust definition *) + extract_comment fmt + [ "Trait implementation: [" ^ Print.name_to_string impl.name ^ "]" ]; + F.pp_print_break fmt 0 0; + + (* Open two outer boxes for the definition, so that whenever possible it gets printed on + one line and indents are correct. + + There is just an exception with Lean: in this backend, line breaks are important + for the parsing, so we always open a vertical box. + *) + if !Config.backend = Lean then ( + F.pp_open_vbox fmt 0; + F.pp_open_vbox fmt ctx.indent_incr) + else ( + F.pp_open_hvbox fmt 0; + F.pp_open_hvbox fmt ctx.indent_incr); + + (* `let (....) : Trait ... =` *) + (* Open the box for the name + generics *) + F.pp_open_hovbox fmt ctx.indent_incr; + let qualif = Option.get (ctx.fmt.fun_decl_kind_to_qualif SingleNonRec) in + F.pp_print_string fmt qualif; + F.pp_print_space fmt (); + F.pp_print_string fmt impl_name; + + (* Print the generics *) + (* Add the type and const generic params - note that we need those bindings only for the + * body translation (they are not top-level) *) + let ctx, type_params, cg_params, trait_clauses = + ctx_add_generic_params impl.generics ctx + in + let all_generics = (type_params, cg_params, trait_clauses) in + extract_generic_params ctx fmt TypeDeclId.Set.empty impl.generics type_params + cg_params trait_clauses; + + (* Print the type *) + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_trait_decl_ref ctx fmt TypeDeclId.Set.empty false impl.impl_trait; + + F.pp_print_space fmt (); + if !Config.backend = Lean then F.pp_print_string fmt ":= {" + else F.pp_print_string fmt "= {"; + + (* Close the box for the name + generics *) + F.pp_close_box fmt (); + + (* + * Extract the items + *) + + (* The parent clauses - we retrieve those from the impl_ref *) + let trait_decl_id = impl.impl_trait.trait_decl_id in + TraitClauseId.iteri + (fun clause_id trait_ref -> + let item_name = ctx_get_trait_parent_clause trait_decl_id clause_id ctx in + let ty () = + F.pp_print_space fmt (); + extract_trait_ref ctx fmt TypeDeclId.Set.empty false trait_ref + in + extract_trait_impl_item ctx fmt item_name ty) + impl.impl_trait.decl_generics.trait_refs; + + (* The constants *) + List.iter + (fun (name, (_, id)) -> + let item_name = ctx_get_trait_const trait_decl_id name ctx in + let ty () = + F.pp_print_space fmt (); + F.pp_print_string fmt (ctx_get_global false id ctx) + in + + extract_trait_impl_item ctx fmt item_name ty) + impl.consts; + + (* The types *) + List.iter + (fun (name, (trait_refs, ty)) -> + (* Extract the type *) + let item_name = ctx_get_trait_type trait_decl_id name ctx in + let ty () = + F.pp_print_space fmt (); + extract_ty ctx fmt TypeDeclId.Set.empty false ty + in + extract_trait_impl_item ctx fmt item_name ty; + (* Extract the clauses *) + TraitClauseId.iteri + (fun clause_id trait_ref -> + let item_name = + ctx_get_trait_item_clause trait_decl_id name clause_id ctx + in + let ty () = + F.pp_print_space fmt (); + extract_trait_ref ctx fmt TypeDeclId.Set.empty false trait_ref + in + extract_trait_impl_item ctx fmt item_name ty) + trait_refs) + impl.types; + + (* The required methods *) + List.iter + (fun (name, id) -> + extract_trait_impl_method_items ctx fmt impl name id all_generics) + impl.required_methods; + + (* Close the outer boxes for the definition, as well as the brackets *) + F.pp_close_box fmt (); + F.pp_print_space fmt (); + F.pp_print_string fmt "}"; + F.pp_close_box fmt (); + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0 + (** Extract a unit test, if the function is a unit function (takes no parameters, returns unit). @@ -3735,8 +4502,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) (* Check if this is a unit function *) let sg = def.signature in if - sg.type_params = [] - && sg.const_generic_params = [] + sg.generics = empty_generic_params && (sg.inputs = [ mk_unit_ty ] || sg.inputs = []) && sg.output = mk_result_ty mk_unit_ty then ( diff --git a/compiler/ExtractAssumed.ml b/compiler/ExtractAssumed.ml new file mode 100644 index 00000000..7f094b24 --- /dev/null +++ b/compiler/ExtractAssumed.ml @@ -0,0 +1,61 @@ +(** This file declares external identifiers that we catch to map them to + definitions coming from the standard libraries in our backends. *) + +open Names + +type simple_name = string list [@@deriving show, ord] + +let name_to_simple_name (s : name) : simple_name = + (* We simply ignore the disambiguators *) + List.filter_map (function Ident id -> Some id | Disambiguator _ -> None) s + +(** Small helper which cuts a string at the occurrences of "::" *) +let string_to_simple_name (s : string) : simple_name = + (* No function to split by using string separator?? *) + let name = String.split_on_char ':' s in + List.filter (fun s -> s <> "") name + +module SimpleNameOrd = struct + type t = simple_name + + let compare = compare_simple_name + let to_string = show_simple_name + let pp_t = pp_simple_name + let show_t = show_simple_name +end + +module SimpleNameMap = Collections.MakeMap (SimpleNameOrd) + +let assumed_globals : (string * string) list = + [ + (* Min *) + ("core::num::usize::MIN", "core_usize_min"); + ("core::num::u8::MIN", "core_u8_min"); + ("core::num::u16::MIN", "core_u16_min"); + ("core::num::u32::MIN", "core_u32_min"); + ("core::num::u64::MIN", "core_u64_min"); + ("core::num::u128::MIN", "core_u128_min"); + ("core::num::isize::MIN", "core_isize_min"); + ("core::num::i8::MIN", "core_i8_min"); + ("core::num::i16::MIN", "core_i16_min"); + ("core::num::i32::MIN", "core_i32_min"); + ("core::num::i64::MIN", "core_i64_min"); + ("core::num::i128::MIN", "core_i128_min"); + (* Max *) + ("core::num::usize::MAX", "core_usize_max"); + ("core::num::u8::MAX", "core_u8_max"); + ("core::num::u16::MAX", "core_u16_max"); + ("core::num::u32::MAX", "core_u32_max"); + ("core::num::u64::MAX", "core_u64_max"); + ("core::num::u128::MAX", "core_u128_max"); + ("core::num::isize::MAX", "core_isize_max"); + ("core::num::i8::MAX", "core_i8_max"); + ("core::num::i16::MAX", "core_i16_max"); + ("core::num::i32::MAX", "core_i32_max"); + ("core::num::i64::MAX", "core_i64_max"); + ("core::num::i128::MAX", "core_i128_max"); + ] + +let assumed_globals_map : string SimpleNameMap.t = + SimpleNameMap.of_list + (List.map (fun (x, y) -> (string_to_simple_name x, y)) assumed_globals) diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index d733c763..1586e6ed 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -5,6 +5,7 @@ open TranslateCore module C = Contexts module RegionVarId = T.RegionVarId module F = Format +open ExtractAssumed (** The local logger *) let log = L.pure_to_extract_log @@ -77,6 +78,7 @@ type decl_kind = F*: [val x : Type0] Coq: [Axiom x : Type.] *) +[@@deriving show] (** Return [true] if the declaration is the last from its group of declarations. @@ -111,9 +113,9 @@ let decl_is_first_from_group (kind : decl_kind) : bool = let decl_is_not_last_from_group (kind : decl_kind) : bool = not (decl_is_last_from_group kind) -(* TODO: this should a module we give to a functor! *) +type type_decl_kind = Enum | Struct [@@deriving show] -type type_decl_kind = Enum | Struct +(* TODO: this should be a module we give to a functor! *) (** A formatter's role is twofold: 1. Come up with name suggestions. @@ -125,6 +127,9 @@ type type_decl_kind = Enum | Struct snake case, adding prefixes/suffixes, etc. 2. Format some specific terms, like constants. + + TODO: unclear that this is useful now that all the backends are so much + entangled in Extract.ml *) type formatter = { bool_name : string; @@ -239,6 +244,13 @@ type formatter = { the same purpose as in {!field:fun_name}. - loop identifier, if this is for a loop *) + trait_decl_name : trait_decl -> string; + trait_impl_name : trait_decl -> trait_impl -> string; + trait_parent_clause_name : trait_decl -> trait_clause -> string; + trait_const_name : trait_decl -> string -> string; + trait_type_name : trait_decl -> string -> string; + trait_method_name : trait_decl -> string -> string; + trait_type_clause_name : trait_decl -> string -> trait_clause -> string; opaque_pre : unit -> string; (** TODO: obsolete, remove. @@ -288,6 +300,14 @@ type formatter = { (** Generates a type variable basename. *) const_generic_var_basename : StringSet.t -> string -> string; (** Generates a const generic variable basename. *) + trait_self_clause_basename : string; + trait_clause_basename : StringSet.t -> trait_clause -> string; + (** Return a base name for a trait clause. We might add a suffix to prevent + collisions. + + In the traduction we explicitely manipulate the trait clause instances, + that is we introduce one input variable for each trait clause. + *) append_index : string -> int -> string; (** Appends an index to a name - we use this to generate unique names: when doing so, the role of the formatter is just to concatenate @@ -396,10 +416,59 @@ type id = | TypeVarId of TypeVarId.id | ConstGenericVarId of ConstGenericVarId.id | VarId of VarId.id + | TraitDeclId of TraitDeclId.id + | TraitImplId of TraitImplId.id + | LocalTraitClauseId of TraitClauseId.id + | TraitMethodId of TraitDeclId.id * string * T.RegionGroupId.id option + (** Something peculiar with trait methods: because we have to take into + account forward/backward functions, we may need to generate fields + items per method. + *) + | TraitItemId of TraitDeclId.id * string + (** A trait associated item which is not a method *) + | TraitParentClauseId of TraitDeclId.id * TraitClauseId.id + | TraitItemClauseId of TraitDeclId.id * string * TraitClauseId.id + | TraitSelfClauseId + (** Specifically for the clause: [Self : Trait]. + + For now, we forbid provided methods (methods in trait declarations + with a default implementation) from being overriden in trait implementations. + We extract trait provided methods such that they take an instance of + the trait as input: this instance is given by the trait self clause. + + For instance: + {[ + // + // Rust + // + trait ToU64 { + fn to_u64(&self) -> u64; + + // Provided method + fn is_pos(&self) -> bool { + self.to_u64() > 0 + } + } + + // + // Generated code + // + struct ToU64 (T : Type) { + to_u64 : T -> u64; + } + + // The trait self clause + // vvvvvvvvvvvvvvvvvvvvvv + let is_pos (T : Type) (trait_self : ToU64 T) (self : T) : bool = + trait_self.to_u64 self > 0 + ]} + *) | UnknownId (** Used for stored various strings like keywords, definitions which should always be in context, etc. and which can't be linked to one of the above. + + TODO: rename to "keyword" *) [@@deriving show, ord] @@ -445,23 +514,52 @@ type names_map = { *) } -let names_map_add (id_to_string : id -> string) (is_opaque : bool) (id : id) - (name : string) (nm : names_map) : names_map = - (* Check if there is a clash *) - (match StringMap.find_opt name nm.name_to_id with +let empty_names_map : names_map = + { + id_to_name = IdMap.empty; + name_to_id = StringMap.empty; + names_set = StringSet.empty; + opaque_ids = IdSet.empty; + } + +(** Small helper to report name collision *) +let report_name_collision (id_to_string : id -> string) (id1 : id) (id2 : id) + (name : string) : unit = + let id1 = "\n- " ^ id_to_string id1 in + let id2 = "\n- " ^ id_to_string id2 in + let err = + "Name clash detected: the following identifiers are bound to the same name \ + \"" ^ name ^ "\":" ^ id1 ^ id2 + ^ "\nYou may want to rename some of your definitions, or report an issue." + in + log#serror err; + (* If we fail hard on errors, raise an exception *) + if !Config.extract_fail_hard then raise (Failure err) + +let names_map_get_id_from_name (name : string) (nm : names_map) : id option = + StringMap.find_opt name nm.name_to_id + +let names_map_check_collision (id_to_string : id -> string) (id : id) + (name : string) (nm : names_map) : unit = + match names_map_get_id_from_name name nm with | None -> () (* Ok *) | Some clash -> (* There is a clash: print a nice debugging message for the user *) - let id1 = "\n- " ^ id_to_string clash in - let id2 = "\n- " ^ id_to_string id in - let err = - "Name clash detected: the following identifiers are bound to the same \ - name \"" ^ name ^ "\":" ^ id1 ^ id2 - in - log#serror err; - raise (Failure err)); + report_name_collision id_to_string clash id name + +let names_map_add (id_to_string : id -> string) (is_opaque : bool) (id : id) + (name : string) (nm : names_map) : names_map = + (* Check if there is a clash *) + names_map_check_collision id_to_string id name nm; (* Sanity check *) - assert (not (StringSet.mem name nm.names_set)); + if StringSet.mem name nm.names_set then ( + let err = + "Error when registering the name for id: " ^ id_to_string id + ^ ":\nThe chosen name is already in the names set: " ^ name + in + log#serror err; + (* If we fail hard on errors, raise an exception *) + if !Config.extract_fail_hard then raise (Failure err)); (* Insert *) let id_to_name = IdMap.add id name nm.id_to_name in let name_to_id = StringMap.add name id nm.name_to_id in @@ -549,6 +647,7 @@ type fun_name_info = { keep_fwd : bool; num_backs : int } functions, etc. *) type extraction_ctx = { + crate : A.crate; trans_ctx : trans_ctx; names_map : names_map; (** The map for id to names, where we forbid name collisions @@ -556,6 +655,15 @@ type extraction_ctx = { unsafe_names_map : unsafe_names_map; (** The map for id to names, where we allow name collisions (ex.: we might allow record field name collisions). *) + strict_names_map : names_map; + (** This map is a sub-map of [names_map]. For the ids in this map we also + forbid collisions with names in the [unsafe_names_map]. + + We do so for keywords for instance, but also for types (in a dependently + typed language, we might have an issue if the field of a record has, say, + the name "u32", and another field of the same record refers to "u32" + (for instance in its type). + *) fmt : formatter; indent_incr : int; (** The indent increment we insert whenever we need to indent more *) @@ -586,6 +694,14 @@ type extraction_ctx = { in case a Rust function only has one backward translation and we filter the forward function because it returns unit. *) + trait_decl_id : trait_decl_id option; + (** If we are extracting a trait declaration, identifies it *) + is_provided_method : bool; + trans_types : Pure.type_decl Pure.TypeDeclId.Map.t; + trans_funs : pure_fun_translation A.FunDeclId.Map.t; + functions_with_decreases_clause : PureUtils.FunLoopIdSet.t; + trans_trait_decls : Pure.trait_decl Pure.TraitDeclId.Map.t; + trans_trait_impls : Pure.trait_impl Pure.TraitImplId.Map.t; } (** Debugging function, used when communicating name collisions to the user, @@ -593,9 +709,16 @@ type extraction_ctx = { instance). *) let id_to_string (id : id) (ctx : extraction_ctx) : string = - let global_decls = ctx.trans_ctx.global_context.global_decls in - let fun_decls = ctx.trans_ctx.fun_context.fun_decls in - let type_decls = ctx.trans_ctx.type_context.type_decls in + let global_decls = ctx.trans_ctx.global_ctx.global_decls in + let fun_decls = ctx.trans_ctx.fun_ctx.fun_decls in + let type_decls = ctx.trans_ctx.type_ctx.type_decls in + let trait_decls = ctx.trans_ctx.trait_decls_ctx.trait_decls in + let trait_decl_id_to_string (id : A.TraitDeclId.id) : string = + let trait_name = + Print.fun_name_to_string (A.TraitDeclId.Map.find id trait_decls).name + in + "trait_decl: " ^ trait_name ^ " (id: " ^ A.TraitDeclId.to_string id ^ ")" + in (* TODO: factorize the pretty-printing with what is in PrintPure *) let get_type_name (id : type_id) : string = match id with @@ -614,10 +737,17 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = | FromLlbc (fid, lp_id, rg_id) -> let fun_name = match fid with - | Regular fid -> + | FunId (Regular fid) -> Print.fun_name_to_string (A.FunDeclId.Map.find fid fun_decls).name - | Assumed aid -> A.show_assumed_fun_id aid + | FunId (Assumed aid) -> A.show_assumed_fun_id aid + | TraitMethod (trait_ref, method_name, _) -> + (* Shouldn't happen *) + if !Config.extract_fail_hard then raise (Failure "Unexpected") + else + "Trait method: decl: " + ^ TraitDeclId.to_string trait_ref.trait_decl_ref.trait_decl_id + ^ ", method_name: " ^ method_name in let lp_kind = @@ -677,8 +807,16 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = if variant_id = option_some_id then "@option::Some" else if variant_id = option_none_id then "@option::None" else raise (Failure "Unreachable") - | Assumed (State | Vec | Fuel | Array | Slice | Str | Range) -> - raise (Failure "Unreachable") + | Assumed Fuel -> + if variant_id = fuel_zero_id then "@fuel::0" + else if variant_id = fuel_succ_id then "@fuel::Succ" + else raise (Failure "Unreachable") + | Assumed (State | Vec | Array | Slice | Str | Range) -> + raise + (Failure + ("Unreachable: variant id (" + ^ VariantId.to_string variant_id + ^ ") for " ^ show_type_id id)) | AdtId id -> ( let def = TypeDeclId.Map.find id type_decls in match def.kind with @@ -716,43 +854,120 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = | ConstGenericVarId id -> "const_generic_var_id: " ^ ConstGenericVarId.to_string id | VarId id -> "var_id: " ^ VarId.to_string id + | TraitDeclId id -> "trait_decl_id: " ^ TraitDeclId.to_string id + | TraitImplId id -> "trait_impl_id: " ^ TraitImplId.to_string id + | LocalTraitClauseId id -> + "local_trait_clause_id: " ^ TraitClauseId.to_string id + | TraitParentClauseId (id, clause_id) -> + "trait_parent_clause_id: " ^ trait_decl_id_to_string id ^ ", clause_id: " + ^ TraitClauseId.to_string clause_id + | TraitItemClauseId (id, item_name, clause_id) -> + "trait_item_clause_id: " ^ trait_decl_id_to_string id ^ ", item name: " + ^ item_name ^ ", clause_id: " + ^ TraitClauseId.to_string clause_id + | TraitItemId (id, name) -> + "trait_item_id: " ^ trait_decl_id_to_string id ^ ", type name: " ^ name + | TraitMethodId (trait_decl_id, fun_name, rg_id) -> + let fwd_back_kind = + match rg_id with + | None -> "forward" + | Some rg_id -> "backward " ^ RegionGroupId.to_string rg_id + in + trait_decl_id_to_string trait_decl_id + ^ ", method name (" ^ fwd_back_kind ^ "): " ^ fun_name + | TraitSelfClauseId -> "trait_self_clause" + +(** Return [true] if we are strict on collisions for this id (i.e., we forbid + collisions even with the ids in the unsafe names map) *) +let strict_collisions (id : id) : bool = + match id with UnknownId | TypeId _ -> true | _ -> false (** We might not check for collisions for some specific ids (ex.: field names) *) let allow_collisions (id : id) : bool = match id with - | FieldId (_, _) -> !Config.record_fields_short_names + | FieldId _ | TraitItemClauseId _ | TraitParentClauseId _ | TraitItemId _ + | TraitMethodId _ -> + !Config.record_fields_short_names | _ -> false let ctx_add (is_opaque : bool) (id : id) (name : string) (ctx : extraction_ctx) : extraction_ctx = - (* We do not use the same name map if we allow/disallow collisions *) + (* The id_to_string function to print nice debugging messages if there are + * collisions *) + let id_to_string (id : id) : string = id_to_string id ctx in + (* We do not use the same name map if we allow/disallow collisions. + We notably use it for field names: some backends like Lean can use the + type information to disambiguate field projections. + + Remark: we still need to check that those "unsafe" ids don't collide with + the ids that we mark as "strict on collision". + + For instance, we don't allow naming a field "let". We enforce this by + not checking collision between ids for which we permit collisions (ex.: + between fields), but still checking collisions between those ids and the + others (ex.: fields and keywords). + *) if allow_collisions id then ( assert (not is_opaque); + (* Check with the ids which are considered to be strict on collisions *) + names_map_check_collision id_to_string id name ctx.strict_names_map; { ctx with unsafe_names_map = unsafe_names_map_add id name ctx.unsafe_names_map; }) else - (* The id_to_string function to print nice debugging messages if there are - * collisions *) - let id_to_string (id : id) : string = id_to_string id ctx in + (* Remark: if we are strict on collisions: + - we add the id to the strict collisions map + - we check that the id doesn't collide with the unsafe map + *) + let strict_names_map = + if strict_collisions id then + names_map_add id_to_string is_opaque id name ctx.strict_names_map + else ctx.strict_names_map + in let names_map = names_map_add id_to_string is_opaque id name ctx.names_map in - { ctx with names_map } + { ctx with strict_names_map; names_map } (** [with_opaque_pre]: if [true] and the definition is opaque, add the opaque prefix *) let ctx_get (with_opaque_pre : bool) (id : id) (ctx : extraction_ctx) : string = (* We do not use the same name map if we allow/disallow collisions *) - if allow_collisions id then IdMap.find id ctx.unsafe_names_map.id_to_name + let map_to_string (m : string IdMap.t) : string = + "[\n" + ^ String.concat "," + (List.map + (fun (id, n) -> "\n " ^ id_to_string id ctx ^ " -> " ^ n) + (IdMap.bindings m)) + ^ "\n]" + in + if allow_collisions id then ( + let m = ctx.unsafe_names_map.id_to_name in + match IdMap.find_opt id m with + | Some s -> s + | None -> + let err = + "Could not find: " ^ id_to_string id ctx ^ "\nNames map:\n" + ^ map_to_string m + in + log#serror err; + if !Config.extract_fail_hard then raise (Failure err) + else + "(%%%ERROR: unknown identifier\": " ^ id_to_string id ctx ^ "\"%%%)") else - match IdMap.find_opt id ctx.names_map.id_to_name with + let m = ctx.names_map.id_to_name in + match IdMap.find_opt id m with | Some s -> let is_opaque = IdSet.mem id ctx.names_map.opaque_ids in if with_opaque_pre && is_opaque then ctx.fmt.opaque_pre () ^ s else s | None -> - log#serror ("Could not find: " ^ id_to_string id ctx); - raise Not_found + let err = + "Could not find: " ^ id_to_string id ctx ^ "\nNames map:\n" + ^ map_to_string m + in + log#serror err; + if !Config.extract_fail_hard then raise (Failure err) + else "(ERROR: \"" ^ id_to_string id ctx ^ "\")" let ctx_get_global (with_opaque_pre : bool) (id : A.GlobalDeclId.id) (ctx : extraction_ctx) : string = @@ -765,7 +980,7 @@ let ctx_get_function (with_opaque_pre : bool) (id : fun_id) let ctx_get_local_function (with_opaque_pre : bool) (id : A.FunDeclId.id) (lp : LoopId.id option) (rg : RegionGroupId.id option) (ctx : extraction_ctx) : string = - ctx_get_function with_opaque_pre (FromLlbc (Regular id, lp, rg)) ctx + ctx_get_function with_opaque_pre (FromLlbc (FunId (Regular id), lp, rg)) ctx let ctx_get_type (with_opaque_pre : bool) (id : type_id) (ctx : extraction_ctx) : string = @@ -785,6 +1000,46 @@ let ctx_get_assumed_type (id : assumed_ty) (ctx : extraction_ctx) : string = let is_opaque = false in ctx_get_type is_opaque (Assumed id) ctx +let ctx_get_trait_self_clause (ctx : extraction_ctx) : string = + let with_opaque_pre = false in + ctx_get with_opaque_pre TraitSelfClauseId ctx + +let ctx_get_trait_decl (with_opaque_pre : bool) (id : trait_decl_id) + (ctx : extraction_ctx) : string = + ctx_get with_opaque_pre (TraitDeclId id) ctx + +let ctx_get_trait_impl (with_opaque_pre : bool) (id : trait_impl_id) + (ctx : extraction_ctx) : string = + ctx_get with_opaque_pre (TraitImplId id) ctx + +let ctx_get_trait_item (id : trait_decl_id) (item_name : string) + (ctx : extraction_ctx) : string = + let is_opaque = false in + ctx_get is_opaque (TraitItemId (id, item_name)) ctx + +let ctx_get_trait_const (id : trait_decl_id) (item_name : string) + (ctx : extraction_ctx) : string = + ctx_get_trait_item id item_name ctx + +let ctx_get_trait_type (id : trait_decl_id) (item_name : string) + (ctx : extraction_ctx) : string = + ctx_get_trait_item id item_name ctx + +let ctx_get_trait_method (id : trait_decl_id) (item_name : string) + (rg_id : T.RegionGroupId.id option) (ctx : extraction_ctx) : string = + let with_opaque_pre = false in + ctx_get with_opaque_pre (TraitMethodId (id, item_name, rg_id)) ctx + +let ctx_get_trait_parent_clause (id : trait_decl_id) (clause : trait_clause_id) + (ctx : extraction_ctx) : string = + let with_opaque_pre = false in + ctx_get with_opaque_pre (TraitParentClauseId (id, clause)) ctx + +let ctx_get_trait_item_clause (id : trait_decl_id) (item : string) + (clause : trait_clause_id) (ctx : extraction_ctx) : string = + let with_opaque_pre = false in + ctx_get with_opaque_pre (TraitItemClauseId (id, item, clause)) ctx + let ctx_get_var (id : VarId.id) (ctx : extraction_ctx) : string = let is_opaque = false in ctx_get is_opaque (VarId id) ctx @@ -798,6 +1053,11 @@ let ctx_get_const_generic_var (id : ConstGenericVarId.id) (ctx : extraction_ctx) let is_opaque = false in ctx_get is_opaque (ConstGenericVarId id) ctx +let ctx_get_local_trait_clause (id : TraitClauseId.id) (ctx : extraction_ctx) : + string = + let is_opaque = false in + ctx_get is_opaque (LocalTraitClauseId id) ctx + let ctx_get_field (type_id : type_id) (field_id : FieldId.id) (ctx : extraction_ctx) : string = let is_opaque = false in @@ -863,6 +1123,26 @@ let ctx_add_var (basename : string) (id : VarId.id) (ctx : extraction_ctx) : let ctx = ctx_add is_opaque (VarId id) name ctx in (ctx, name) +(** Generate a unique variable name for the trait self clause and add it to the context *) +let ctx_add_trait_self_clause (ctx : extraction_ctx) : extraction_ctx * string = + let is_opaque = false in + let basename = ctx.fmt.trait_self_clause_basename in + let name = + basename_to_unique ctx.names_map.names_set ctx.fmt.append_index basename + in + let ctx = ctx_add is_opaque TraitSelfClauseId name ctx in + (ctx, name) + +(** Generate a unique trait clause name and add it to the context *) +let ctx_add_local_trait_clause (basename : string) (id : TraitClauseId.id) + (ctx : extraction_ctx) : extraction_ctx * string = + let is_opaque = false in + let name = + basename_to_unique ctx.names_map.names_set ctx.fmt.append_index basename + in + let ctx = ctx_add is_opaque (LocalTraitClauseId id) name ctx in + (ctx, name) + (** See {!ctx_add_var} *) let ctx_add_vars (vars : var list) (ctx : extraction_ctx) : extraction_ctx * string list = @@ -885,12 +1165,26 @@ let ctx_add_const_generic_params (vars : const_generic_var list) ctx_add_const_generic_var var.name var.index ctx) ctx vars -let ctx_add_type_const_generic_params (tvars : type_var list) - (cgvars : const_generic_var list) (ctx : extraction_ctx) : - extraction_ctx * string list * string list = - let ctx, tys = ctx_add_type_params tvars ctx in - let ctx, cgs = ctx_add_const_generic_params cgvars ctx in - (ctx, tys, cgs) +let ctx_add_local_trait_clauses (clauses : trait_clause list) + (ctx : extraction_ctx) : extraction_ctx * string list = + List.fold_left_map + (fun ctx (c : trait_clause) -> + let basename = ctx.fmt.trait_clause_basename ctx.names_map.names_set c in + ctx_add_local_trait_clause basename c.clause_id ctx) + ctx clauses + +(** Returns the lists of names for: + - the type variables + - the const generic variables + - the trait clauses + *) +let ctx_add_generic_params (generics : generic_params) (ctx : extraction_ctx) : + extraction_ctx * string list * string list * string list = + let { types; const_generics; trait_clauses } = generics in + let ctx, tys = ctx_add_type_params types ctx in + let ctx, cgs = ctx_add_const_generic_params const_generics ctx in + let ctx, tcs = ctx_add_local_trait_clauses trait_clauses ctx in + (ctx, tys, cgs, tcs) let ctx_add_type_decl_struct (def : type_decl) (ctx : extraction_ctx) : extraction_ctx * string = @@ -977,49 +1271,62 @@ let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) : extraction_ctx = (* TODO: update once the body id can be an option *) let is_opaque = false in - let name = ctx.fmt.global_name def.name in let decl = GlobalId def.def_id in - let body = FunId (FromLlbc (Regular def.body_id, None, None)) in - let ctx = ctx_add is_opaque decl (name ^ "_c") ctx in - let ctx = ctx_add is_opaque body (name ^ "_body") ctx in - ctx -let ctx_add_fun_decl (trans_group : bool * pure_fun_translation) - (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx = - (* Sanity check: the function should not be a global body - those are handled - * separately *) - assert (not def.is_global_decl_body); + (* Check if the global corresponds to an assumed global that we should map + to a custom definition in our standard library (for instance, happens + with "core::num::usize::MAX") *) + let sname = name_to_simple_name def.name in + match SimpleNameMap.find_opt sname assumed_globals_map with + | Some name -> + (* Yes: register the custom binding *) + ctx_add is_opaque decl name ctx + | None -> + (* Not the case: "standard" registration *) + let name = ctx.fmt.global_name def.name in + let body = FunId (FromLlbc (FunId (Regular def.body_id), None, None)) in + let ctx = ctx_add is_opaque decl (name ^ "_c") ctx in + let ctx = ctx_add is_opaque body (name ^ "_body") ctx in + ctx + +let ctx_compute_fun_name (trans_group : pure_fun_translation) (def : fun_decl) + (ctx : extraction_ctx) : string = (* Lookup the LLBC def to compute the region group information *) let def_id = def.def_id in - let llbc_def = - A.FunDeclId.Map.find def_id ctx.trans_ctx.fun_context.fun_decls - in + let llbc_def = A.FunDeclId.Map.find def_id ctx.trans_ctx.fun_ctx.fun_decls in let sg = llbc_def.signature in let num_rgs = List.length sg.regions_hierarchy in - let keep_fwd, (_, backs) = trans_group in + let { keep_fwd; fwd = _; backs } = trans_group in let num_backs = List.length backs in let rg_info = match def.back_id with | None -> None | Some rg_id -> let rg = T.RegionGroupId.nth sg.regions_hierarchy rg_id in - let regions = + let region_names = List.map - (fun rid -> T.RegionVarId.nth sg.region_params rid) + (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name) rg.regions in - let region_names = - List.map (fun (r : T.region_var) -> r.name) regions - in Some { id = rg_id; region_names } in + (* Add the function name *) + ctx.fmt.fun_name def.basename def.num_loops def.loop_id num_rgs rg_info + (keep_fwd, num_backs) + +let ctx_add_fun_decl (trans_group : pure_fun_translation) (def : fun_decl) + (ctx : extraction_ctx) : extraction_ctx = + (* Sanity check: the function should not be a global body - those are handled + * separately *) + assert (not def.is_global_decl_body); + (* Lookup the LLBC def to compute the region group information *) + let def_id = def.def_id in + let { keep_fwd; fwd = _; backs } = trans_group in + let num_backs = List.length backs in let is_opaque = def.body = None in (* Add the function name *) - let def_name = - ctx.fmt.fun_name def.basename def.num_loops def.loop_id num_rgs rg_info - (keep_fwd, num_backs) - in - let fun_id = (A.Regular def_id, def.loop_id, def.back_id) in + let def_name = ctx_compute_fun_name trans_group def ctx in + let fun_id = (Pure.FunId (Regular def_id), def.loop_id, def.back_id) in let ctx = ctx_add is_opaque (FunId (FromLlbc fun_id)) def_name ctx in (* Add the name info *) { @@ -1029,6 +1336,91 @@ let ctx_add_fun_decl (trans_group : bool * pure_fun_translation) ctx.fun_name_info; } +let ctx_add_trait_decl (d : trait_decl) (ctx : extraction_ctx) : extraction_ctx + = + let is_opaque = false in + let name = ctx.fmt.trait_decl_name d in + ctx_add is_opaque (TraitDeclId d.def_id) name ctx + +let ctx_add_trait_impl (d : trait_impl) (ctx : extraction_ctx) : extraction_ctx + = + (* We need to lookup the trait decl that is implemented by the trait impl *) + let decl = + Pure.TraitDeclId.Map.find d.impl_trait.trait_decl_id ctx.trans_trait_decls + in + (* Compute the name *) + let is_opaque = false in + let name = ctx.fmt.trait_impl_name decl d in + ctx_add is_opaque (TraitImplId d.def_id) name ctx + +let ctx_add_trait_const (d : trait_decl) (item : string) (ctx : extraction_ctx) + : extraction_ctx = + let is_opaque = false in + let name = ctx.fmt.trait_const_name d item in + (* Add a prefix if necessary *) + let name = + if !Config.record_fields_short_names then name + else ctx.fmt.trait_decl_name d ^ name + in + ctx_add is_opaque (TraitItemId (d.def_id, item)) name ctx + +let ctx_add_trait_type (d : trait_decl) (item : string) (ctx : extraction_ctx) : + extraction_ctx = + let is_opaque = false in + let name = ctx.fmt.trait_type_name d item in + (* Add a prefix if necessary *) + let name = + if !Config.record_fields_short_names then name + else ctx.fmt.trait_decl_name d ^ name + in + ctx_add is_opaque (TraitItemId (d.def_id, item)) name ctx + +let ctx_add_trait_method (d : trait_decl) (item_name : string) (f : fun_decl) + (ctx : extraction_ctx) : extraction_ctx = + (* We do something special: we use the base name but remove everything + but the crate (because [get_name] removes it) and the last ident. + This allows us to reuse the [ctx_compute_fun_decl] function. + *) + let basename : name = + match (f.basename : name) with + | Ident crate :: name -> Ident crate :: [ Collections.List.last name ] + | _ -> raise (Failure "Unexpected") + in + let f = { f with basename } in + let trans = A.FunDeclId.Map.find f.def_id ctx.trans_funs in + let name = ctx_compute_fun_name trans f ctx in + (* Add a prefix if necessary *) + let name = + if !Config.record_fields_short_names then name + else ctx.fmt.trait_decl_name d ^ "_" ^ name + in + let is_opaque = false in + ctx_add is_opaque (TraitMethodId (d.def_id, item_name, f.back_id)) name ctx + +let ctx_add_trait_parent_clause (d : trait_decl) (clause : trait_clause) + (ctx : extraction_ctx) : extraction_ctx = + let is_opaque = false in + let name = ctx.fmt.trait_parent_clause_name d clause in + (* Add a prefix if necessary *) + let name = + if !Config.record_fields_short_names then name + else ctx.fmt.trait_decl_name d ^ name + in + ctx_add is_opaque (TraitParentClauseId (d.def_id, clause.clause_id)) name ctx + +let ctx_add_trait_type_clause (d : trait_decl) (item : string) + (clause : trait_clause) (ctx : extraction_ctx) : extraction_ctx = + let is_opaque = false in + let name = ctx.fmt.trait_type_clause_name d item clause in + (* Add a prefix if necessary *) + let name = + if !Config.record_fields_short_names then name + else ctx.fmt.trait_decl_name d ^ name + in + ctx_add is_opaque + (TraitItemClauseId (d.def_id, item, clause.clause_id)) + name ctx + type names_map_init = { keywords : string list; assumed_adts : (assumed_ty * string) list; @@ -1089,7 +1481,8 @@ let initialize_names_map (fmt : formatter) (init : names_map_init) : names_map = in let assumed_functions = List.map - (fun (fid, rg, name) -> (FromLlbc (A.Assumed fid, None, rg), name)) + (fun (fid, rg, name) -> + (FromLlbc (Pure.FunId (A.Assumed fid), None, rg), name)) init.assumed_llbc_functions @ List.map (fun (fid, name) -> (Pure fid, name)) init.assumed_pure_functions in diff --git a/compiler/FunsAnalysis.ml b/compiler/FunsAnalysis.ml index b72fa078..f4406653 100644 --- a/compiler/FunsAnalysis.ml +++ b/compiler/FunsAnalysis.ml @@ -70,14 +70,14 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t) method! visit_rvalue _env rv = match rv with - | Use _ | Ref _ | Global _ | Discriminant _ | Aggregate _ -> () + | Use _ | RvRef _ | Global _ | Discriminant _ | Aggregate _ -> () | UnaryOp (uop, _) -> can_fail := EU.unop_can_fail uop || !can_fail | BinaryOp (bop, _, _) -> can_fail := EU.binop_can_fail bop || !can_fail method! visit_Call env call = (match call.func with - | Regular id -> + | FunId (Regular id) -> if FunDeclId.Set.mem id fun_ids then ( can_diverge := true; is_rec := true) @@ -86,9 +86,13 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t) self#may_fail info.can_fail; stateful := !stateful || info.stateful; can_diverge := !can_diverge || info.can_diverge - | Assumed id -> + | FunId (Assumed id) -> (* None of the assumed functions can diverge nor are considered stateful *) - can_fail := !can_fail || Assumed.assumed_can_fail id); + can_fail := !can_fail || Assumed.assumed_can_fail id + | TraitMethod _ -> + (* We consider trait functions can fail, diverge, and are not stateful *) + can_fail := true; + can_diverge := true); super#visit_Call env call method! visit_Panic env = @@ -141,7 +145,8 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t) let rec analyze_decl_groups (decls : declaration_group list) : unit = match decls with | [] -> () - | Type _ :: decls' -> analyze_decl_groups decls' + | (Type _ | TraitDecl _ | TraitImpl _) :: decls' -> + analyze_decl_groups decls' | Fun decl :: decls' -> analyze_fun_decl_group decl; analyze_decl_groups decls' diff --git a/compiler/Interpreter.ml b/compiler/Interpreter.ml index 154c5a21..24ff4808 100644 --- a/compiler/Interpreter.ml +++ b/compiler/Interpreter.ml @@ -12,55 +12,165 @@ module SA = SymbolicAst (** The local logger *) let log = L.interpreter_log -let compute_type_fun_global_contexts (m : A.crate) : - C.type_context * C.fun_context * C.global_context = - let type_decls_list, _, _ = split_declarations m.declarations in +let compute_contexts (m : A.crate) : C.decls_ctx = + let type_decls_list, _, _, _, _ = split_declarations m.declarations in let type_decls = m.types in let fun_decls = m.functions in let global_decls = m.globals in - let type_decls_groups, _funs_defs_groups, _globals_defs_groups = + let trait_decls = m.trait_decls in + let trait_impls = m.trait_impls in + let type_decls_groups, _, _, _, _ = split_declarations_to_group_maps m.declarations in let type_infos = TypesAnalysis.analyze_type_declarations type_decls type_decls_list in - let type_context = { C.type_decls_groups; type_decls; type_infos } in - let fun_context = { C.fun_decls } in - let global_context = { C.global_decls } in - (type_context, fun_context, global_context) - -let initialize_eval_context (type_context : C.type_context) - (fun_context : C.fun_context) (global_context : C.global_context) - (region_groups : T.RegionGroupId.id list) (type_vars : T.type_var list) - (const_generic_vars : T.const_generic_var list) : C.eval_ctx = - C.reset_global_counters (); - { - C.type_context; - C.fun_context; - C.global_context; - C.region_groups; - C.type_vars; - C.const_generic_vars; - C.env = [ C.Frame ]; - C.ended_regions = T.RegionId.Set.empty; - } + let type_ctx = { C.type_decls_groups; type_decls; type_infos } in + let fun_infos = + FunsAnalysis.analyze_module m fun_decls global_decls !Config.use_state + in + let fun_ctx = { C.fun_decls; fun_infos } in + let global_ctx = { C.global_decls } in + let trait_decls_ctx = { C.trait_decls } in + let trait_impls_ctx = { C.trait_impls } in + { C.type_ctx; fun_ctx; global_ctx; trait_decls_ctx; trait_impls_ctx } + +(** Small helper. + + Normalize an instantiated function signature provided we used this signature + to compute a normalization map (for the associated types) and that we added + it in the context. + *) +let normalize_inst_fun_sig (ctx : C.eval_ctx) (sg : A.inst_fun_sig) : + A.inst_fun_sig = + let { A.regions_hierarchy = _; trait_type_constraints = _; inputs; output } = + sg + in + let norm = AssociatedTypes.ctx_normalize_rty ctx in + let inputs = List.map norm inputs in + let output = norm output in + { sg with A.inputs; output } + +(** Instantiate a function signature for a symbolic execution. + + We return a new context because we compute and add the type normalization + map in the same step. + + **WARNING**: this doesn't normalize the types. This step has to be done + separately. Remark: we need to normalize essentially because of the where + clauses (we are not considering a function call, so we don't need to + normalize because a trait clause was instantiated with a specific trait ref). + *) +let symbolic_instantiate_fun_sig (ctx : C.eval_ctx) (sg : A.fun_sig) + (kind : A.fun_kind) : C.eval_ctx * A.inst_fun_sig = + let tr_self = + match kind with + | RegularKind | TraitMethodImpl _ -> T.UnknownTrait __FUNCTION__ + | TraitMethodDecl _ | TraitMethodProvided _ -> T.Self + in + let generics = + let { T.regions; types; const_generics; trait_clauses } = sg.generics in + let regions = List.map (fun _ -> T.Erased) regions in + let types = List.map (fun (v : T.type_var) -> T.TypeVar v.T.index) types in + let const_generics = + List.map + (fun (v : T.const_generic_var) -> T.ConstGenericVar v.T.index) + const_generics + in + (* Annoying that we have to generate this substitution here *) + let r_subst _ = raise (Failure "Unexpected region") in + let ty_subst = Subst.make_type_subst_from_vars sg.generics.types types in + let cg_subst = + Subst.make_const_generic_subst_from_vars sg.generics.const_generics + const_generics + in + (* TODO: some clauses may use the types of other clauses, so we may have to + reorder them. + + Example: + If in Rust we write: + {[ + pub fn use_get<'a, T: Get>(x: &'a mut T) -> u32 + where + T::Item: ToU32, + { + x.get().to_u32() + } + ]} + + In LLBC we get: + {[ + fn demo::use_get<'a, T>(@1: &'a mut (T)) -> u32 + where + [@TraitClause0]: demo::Get<T>, + [@TraitClause1]: demo::ToU32<@TraitClause0::Item>, // HERE + { + ... // Omitted + } + ]} + *) + (* We will need to update the trait refs map while we perform the instantiations *) + let mk_tr_subst + (tr_map : T.erased_region T.trait_instance_id T.TraitClauseId.Map.t) + clause_id : T.erased_region T.trait_instance_id = + match T.TraitClauseId.Map.find_opt clause_id tr_map with + | Some tr -> tr + | None -> raise (Failure "Local trait clause not found") + in + let mk_subst tr_map = + let tr_subst = mk_tr_subst tr_map in + { Subst.r_subst; ty_subst; cg_subst; tr_subst; tr_self } + in + let _, trait_refs = + List.fold_left_map + (fun tr_map (c : T.trait_clause) -> + let subst = mk_subst tr_map in + let { T.trait_id = trait_decl_id; generics; _ } = c in + let generics = Subst.generic_args_substitute subst generics in + let trait_decl_ref = { T.trait_decl_id; decl_generics = generics } in + (* Note that because we directly refer to the clause, we give it + empty generics *) + let trait_id = T.Clause c.clause_id in + let trait_ref = + { + T.trait_id; + generics = TypesUtils.mk_empty_generic_args; + trait_decl_ref; + } + in + (* Update the traits map *) + let tr_map = T.TraitClauseId.Map.add c.T.clause_id trait_id tr_map in + (tr_map, trait_ref)) + T.TraitClauseId.Map.empty trait_clauses + in + { T.regions; types; const_generics; trait_refs } + in + let inst_sg = instantiate_fun_sig ctx generics tr_self sg in + (* Compute the normalization maps *) + let ctx = + AssociatedTypes.ctx_add_norm_trait_types_from_preds ctx + inst_sg.trait_type_constraints + in + (* Normalize the signature *) + let inst_sg = normalize_inst_fun_sig ctx inst_sg in + (* Return *) + (ctx, inst_sg) (** Initialize an evaluation context to execute a function. - Introduces local variables initialized in the following manner: - - input arguments are initialized as symbolic values - - the remaining locals are initialized as [⊥] - Abstractions are introduced for the regions present in the function - signature. - - We return: - - the initialized evaluation context - - the list of symbolic values introduced for the input values - - the instantiated function signature + Introduces local variables initialized in the following manner: + - input arguments are initialized as symbolic values + - the remaining locals are initialized as [⊥] + Abstractions are introduced for the regions present in the function + signature. + + We return: + - the initialized evaluation context + - the list of symbolic values introduced for the input values + - the instantiated function signature *) -let initialize_symbolic_context_for_fun (type_context : C.type_context) - (fun_context : C.fun_context) (global_context : C.global_context) - (fdef : A.fun_decl) : C.eval_ctx * V.symbolic_value list * A.inst_fun_sig = +let initialize_symbolic_context_for_fun (ctx : C.decls_ctx) (fdef : A.fun_decl) + : C.eval_ctx * V.symbolic_value list * A.inst_fun_sig = (* The abstractions are not initialized the same way as for function * calls: they contain *loan* projectors, because they "provide" us * with the input values (which behave as if they had been returned @@ -78,19 +188,15 @@ let initialize_symbolic_context_for_fun (type_context : C.type_context) List.map (fun (g : T.region_var_group) -> g.id) sg.regions_hierarchy in let ctx = - initialize_eval_context type_context fun_context global_context - region_groups sg.type_params sg.const_generic_params + initialize_eval_context ctx region_groups sg.generics.types + sg.generics.const_generics in - (* Instantiate the signature *) - let type_params = - List.map (fun (v : T.type_var) -> T.TypeVar v.T.index) sg.type_params + (* Instantiate the signature. This updates the context because we compute + at the same time the normalization map for the associated types. + *) + let ctx, inst_sg = + symbolic_instantiate_fun_sig ctx fdef.signature fdef.kind in - let cg_params = - List.map - (fun (v : T.const_generic_var) -> T.ConstGenericVar v.T.index) - sg.const_generic_params - in - let inst_sg = instantiate_fun_sig type_params cg_params sg in (* Create fresh symbolic values for the inputs *) let input_svs = List.map (fun ty -> mk_fresh_symbolic_value V.SynthInput ty) inst_sg.inputs @@ -165,15 +271,9 @@ let evaluate_function_symbolic_synthesize_backward_from_return * an instantiation of the signature, so that we use fresh * region ids for the return abstractions. *) let sg = fdef.signature in - let type_params = - List.map (fun (v : T.type_var) -> T.TypeVar v.T.index) sg.type_params - in - let cg_params = - List.map - (fun (v : T.const_generic_var) -> T.ConstGenericVar v.T.index) - sg.const_generic_params + let _, ret_inst_sg = + symbolic_instantiate_fun_sig ctx fdef.signature fdef.kind in - let ret_inst_sg = instantiate_fun_sig type_params cg_params sg in let ret_rty = ret_inst_sg.output in (* Move the return value out of the return variable *) let pop_return_value = is_regular_return in @@ -347,19 +447,14 @@ let evaluate_function_symbolic_synthesize_backward_from_return for the synthesis) - the symbolic AST generated by the symbolic execution *) -let evaluate_function_symbolic (synthesize : bool) - (type_context : C.type_context) (fun_context : C.fun_context) - (global_context : C.global_context) (fdef : A.fun_decl) : - V.symbolic_value list * SA.expression option = +let evaluate_function_symbolic (synthesize : bool) (ctx : C.decls_ctx) + (fdef : A.fun_decl) : V.symbolic_value list * SA.expression option = (* Debug *) let name_to_string () = Print.fun_name_to_string fdef.A.name in log#ldebug (lazy ("evaluate_function_symbolic: " ^ name_to_string ())); (* Create the evaluation context *) - let ctx, input_svs, inst_sg = - initialize_symbolic_context_for_fun type_context fun_context global_context - fdef - in + let ctx, input_svs, inst_sg = initialize_symbolic_context_for_fun ctx fdef in (* Create the continuation to finish the evaluation *) let config = C.mk_config C.SymbolicMode in @@ -488,7 +583,8 @@ module Test = struct (** Test a unit function (taking no arguments) by evaluating it in an empty environment. *) - let test_unit_function (crate : A.crate) (fid : A.FunDeclId.id) : unit = + let test_unit_function (crate : A.crate) (decls_ctx : C.decls_ctx) + (fid : A.FunDeclId.id) : unit = (* Retrieve the function declaration *) let fdef = A.FunDeclId.Map.find fid crate.functions in let body = Option.get fdef.body in @@ -498,17 +594,11 @@ module Test = struct (lazy ("test_unit_function: " ^ Print.fun_name_to_string fdef.A.name)); (* Sanity check - *) - assert (List.length fdef.A.signature.region_params = 0); - assert (List.length fdef.A.signature.type_params = 0); + assert (fdef.A.signature.generics = TypesUtils.mk_empty_generic_params); assert (body.A.arg_count = 0); (* Create the evaluation context *) - let type_context, fun_context, global_context = - compute_type_fun_global_contexts crate - in - let ctx = - initialize_eval_context type_context fun_context global_context [] [] [] - in + let ctx = initialize_eval_context decls_ctx [] [] [] in (* Insert the (uninitialized) local variables *) let ctx = C.ctx_push_uninitialized_vars ctx body.A.locals in @@ -536,9 +626,7 @@ module Test = struct (no parameters, no arguments) - TODO: move *) let fun_decl_is_transparent_unit (def : A.fun_decl) : bool = Option.is_some def.body - && def.A.signature.region_params = [] - && def.A.signature.type_params = [] - && def.A.signature.const_generic_params = [] + && def.A.signature.generics = TypesUtils.mk_empty_generic_params && def.A.signature.inputs = [] (** Test all the unit functions in a list of function definitions *) @@ -548,24 +636,9 @@ module Test = struct (fun _ -> fun_decl_is_transparent_unit) crate.functions in + let decls_ctx = compute_contexts crate in let test_unit_fun _ (def : A.fun_decl) : unit = - test_unit_function crate def.A.def_id + test_unit_function crate decls_ctx def.A.def_id in A.FunDeclId.Map.iter test_unit_fun unit_funs - - (** Execute the symbolic interpreter on a function. *) - let test_function_symbolic (synthesize : bool) (type_context : C.type_context) - (fun_context : C.fun_context) (global_context : C.global_context) - (fdef : A.fun_decl) : unit = - (* Debug *) - log#ldebug - (lazy ("test_function_symbolic: " ^ Print.fun_name_to_string fdef.A.name)); - - (* Evaluate *) - let _ = - evaluate_function_symbolic synthesize type_context fun_context - global_context fdef - in - - () end diff --git a/compiler/InterpreterBorrows.ml b/compiler/InterpreterBorrows.ml index 4d67a4e4..e97795a1 100644 --- a/compiler/InterpreterBorrows.ml +++ b/compiler/InterpreterBorrows.ml @@ -452,7 +452,8 @@ let give_back_symbolic_value (_config : C.config) | V.SynthInputGivenBack | SynthRetGivenBack | FunCallGivenBack | LoopGivenBack -> () - | FunCallRet | SynthInput | Global | LoopOutput | LoopJoin | Aggregate -> + | FunCallRet | SynthInput | Global | LoopOutput | LoopJoin | Aggregate + | ConstGeneric | TraitConst -> raise (Failure "Unreachable")); (* Store the given-back value as a meta-value for synthesis purposes *) let mv = nsv in diff --git a/compiler/InterpreterBorrowsCore.ml b/compiler/InterpreterBorrowsCore.ml index bf083aa4..e7da045c 100644 --- a/compiler/InterpreterBorrowsCore.ml +++ b/compiler/InterpreterBorrowsCore.ml @@ -100,15 +100,18 @@ let rec compare_rtys (default : bool) (combine : bool -> bool -> bool) (compare_regions : T.RegionId.id T.region -> T.RegionId.id T.region -> bool) (ty1 : T.rty) (ty2 : T.rty) : bool = let compare = compare_rtys default combine compare_regions in + (* Normalize the associated types *) match (ty1, ty2) with | T.Literal lit1, T.Literal lit2 -> assert (lit1 = lit2); default - | T.Adt (id1, regions1, tys1, cgs1), T.Adt (id2, regions2, tys2, cgs2) -> + | T.Adt (id1, generics1), T.Adt (id2, generics2) -> assert (id1 = id2); (* There are no regions in the const generics, so we ignore them, but we still check they are the same, for sanity *) - assert (cgs1 = cgs2); + assert (generics1.const_generics = generics2.const_generics); + + (* We also ignore the trait refs *) (* The check for the ADTs is very crude: we simply compare the arguments * two by two. @@ -123,14 +126,14 @@ let rec compare_rtys (default : bool) (combine : bool -> bool -> bool) * this check would still be a reasonable conservative approximation. *) (* Check the region parameters *) - let regions = List.combine regions1 regions2 in + let regions = List.combine generics1.regions generics2.regions in let params_b = List.fold_left (fun b (r1, r2) -> combine b (compare_regions r1 r2)) default regions in (* Check the type parameters *) - let tys = List.combine tys1 tys2 in + let tys = List.combine generics1.types generics2.types in let tys_b = List.fold_left (fun b (ty1, ty2) -> combine b (compare ty1 ty2)) @@ -150,6 +153,11 @@ let rec compare_rtys (default : bool) (combine : bool -> bool -> bool) | T.TypeVar id1, T.TypeVar id2 -> assert (id1 = id2); default + | T.TraitType _, T.TraitType _ -> + (* The types should have been normalized. If after normalization we + get trait types, we can consider them as variables *) + assert (ty1 = ty2); + default | _ -> log#lerror (lazy diff --git a/compiler/InterpreterExpansion.ml b/compiler/InterpreterExpansion.ml index 81e73e3e..ea692386 100644 --- a/compiler/InterpreterExpansion.ml +++ b/compiler/InterpreterExpansion.ml @@ -9,6 +9,7 @@ module V = Values module E = Expressions module C = Contexts module Subst = Substitute +module Assoc = AssociatedTypes module L = Logging open TypesUtils module Inv = Invariants @@ -204,7 +205,7 @@ let apply_symbolic_expansion_non_borrow (config : C.config) apply_symbolic_expansion_to_avalues config allow_reborrows original_sv expansion ctx -(** Compute the expansion of a non-assumed (i.e.: not [Option], [Box], etc.) +(** Compute the expansion of a non-assumed (i.e.: not [Box], etc.) adt value. The function might return a list of values if the symbolic value to expand @@ -214,18 +215,15 @@ let apply_symbolic_expansion_non_borrow (config : C.config) doesn't allow the expansion of enumerations *containing several variants*. *) let compute_expanded_symbolic_non_assumed_adt_value (expand_enumerations : bool) - (kind : V.sv_kind) (def_id : T.TypeDeclId.id) - (regions : T.RegionId.id T.region list) (types : T.rty list) - (cgs : T.const_generic list) (ctx : C.eval_ctx) : V.symbolic_expansion list - = + (kind : V.sv_kind) (def_id : T.TypeDeclId.id) (generics : T.rgeneric_args) + (ctx : C.eval_ctx) : V.symbolic_expansion list = (* Lookup the definition and check if it is an enumeration with several * variants *) let def = C.ctx_lookup_type_decl ctx def_id in - assert (List.length regions = List.length def.T.region_params); + assert (List.length generics.regions = List.length def.T.generics.regions); (* Retrieve, for every variant, the list of its instantiated field types *) let variants_fields_types = - Subst.type_decl_get_instantiated_variants_fields_rtypes def regions types - cgs + Assoc.type_decl_get_inst_norm_variants_fields_rtypes ctx def generics in (* Check if there is strictly more than one variant *) if List.length variants_fields_types > 1 && not expand_enumerations then @@ -280,15 +278,14 @@ let compute_expanded_symbolic_box_value (kind : V.sv_kind) (boxed_ty : T.rty) : doesn't allow the expansion of enumerations *containing several variants*. *) let compute_expanded_symbolic_adt_value (expand_enumerations : bool) - (kind : V.sv_kind) (adt_id : T.type_id) - (regions : T.RegionId.id T.region list) (types : T.rty list) - (cgs : T.const_generic list) (ctx : C.eval_ctx) : V.symbolic_expansion list - = - match (adt_id, regions, types) with + (kind : V.sv_kind) (adt_id : T.type_id) (generics : T.rgeneric_args) + (ctx : C.eval_ctx) : V.symbolic_expansion list = + match (adt_id, generics.regions, generics.types) with | T.AdtId def_id, _, _ -> compute_expanded_symbolic_non_assumed_adt_value expand_enumerations kind - def_id regions types cgs ctx - | T.Tuple, [], _ -> [ compute_expanded_symbolic_tuple_value kind types ] + def_id generics ctx + | T.Tuple, [], _ -> + [ compute_expanded_symbolic_tuple_value kind generics.types ] | T.Assumed T.Option, [], [ ty ] -> compute_expanded_symbolic_option_value expand_enumerations kind ty | T.Assumed T.Box, [], [ boxed_ty ] -> @@ -543,12 +540,12 @@ let expand_symbolic_value_no_branching (config : C.config) fun cf ctx -> match rty with (* ADTs *) - | T.Adt (adt_id, regions, types, cgs) -> + | T.Adt (adt_id, generics) -> (* Compute the expanded value *) let allow_branching = false in let seel = compute_expanded_symbolic_adt_value allow_branching sv.sv_kind adt_id - regions types cgs ctx + generics ctx in (* There should be exacly one branch *) let see = Collections.List.to_cons_nil seel in @@ -600,12 +597,12 @@ let expand_symbolic_adt (config : C.config) (sv : V.symbolic_value) (* Execute *) match rty with (* ADTs *) - | T.Adt (adt_id, regions, types, cgs) -> + | T.Adt (adt_id, generics) -> let allow_branching = true in (* Compute the expanded value *) let seel = compute_expanded_symbolic_adt_value allow_branching sv.sv_kind adt_id - regions types cgs ctx + generics ctx in (* Apply *) let seel = List.map (fun see -> (Some see, cf_branches)) seel in @@ -679,7 +676,7 @@ let greedy_expand_symbolics_with_borrows (config : C.config) : cm_fun = ^ symbolic_value_to_string ctx sv)); let cc : cm_fun = match sv.V.sv_ty with - | T.Adt (AdtId def_id, _, _, _) -> + | T.Adt (AdtId def_id, _) -> (* {!expand_symbolic_value_no_branching} checks if there are branchings, * but we prefer to also check it here - this leads to cleaner messages * and debugging *) @@ -704,16 +701,16 @@ let greedy_expand_symbolics_with_borrows (config : C.config) : cm_fun = [config]): " ^ Print.name_to_string def.name)) else expand_symbolic_value_no_branching config sv None - | T.Adt ((Tuple | Assumed Box), _, _, _) | T.Ref (_, _, _) -> + | T.Adt ((Tuple | Assumed Box), _) | T.Ref (_, _, _) -> (* Ok *) expand_symbolic_value_no_branching config sv None - | T.Adt (Assumed (Vec | Option | Array | Slice | Str | Range), _, _, _) - -> + | T.Adt (Assumed (Vec | Option | Array | Slice | Str | Range), _) -> (* We can't expand those *) raise (Failure "Attempted to greedily expand an ADT which can't be expanded ") - | T.TypeVar _ | T.Literal _ | Never -> raise (Failure "Unreachable") + | T.TypeVar _ | T.Literal _ | Never | T.TraitType _ -> + raise (Failure "Unreachable") in (* Compose and continue *) comp cc expand cf ctx diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml index 8b2070c6..29826233 100644 --- a/compiler/InterpreterExpressions.ml +++ b/compiler/InterpreterExpressions.ml @@ -7,6 +7,7 @@ module E = Expressions open Utils module C = Contexts module Subst = Substitute +module Assoc = AssociatedTypes module L = Logging open TypesUtils open ValuesUtils @@ -141,11 +142,18 @@ let rec copy_value (allow_adt_copy : bool) (config : C.config) | V.Adt av -> (* Sanity check *) (match v.V.ty with - | T.Adt (T.Assumed (T.Box | Vec), _, _, _) -> + | T.Adt (T.Assumed (T.Box | Vec), _) -> raise (Failure "Can't copy an assumed value other than Option") - | T.Adt (T.AdtId _, _, _, _) -> assert allow_adt_copy - | T.Adt ((T.Assumed Option | T.Tuple), _, _, _) -> () (* Ok *) - | T.Adt (T.Assumed (Slice | T.Array), [], [ ty ], []) -> + | T.Adt (T.AdtId _, _) -> assert allow_adt_copy + | T.Adt ((T.Assumed Option | T.Tuple), _) -> () (* Ok *) + | T.Adt + ( T.Assumed (Slice | T.Array), + { + regions = []; + types = [ ty ]; + const_generics = []; + trait_refs = []; + } ) -> assert (ty_is_primitively_copyable ty) | _ -> raise (Failure "Unreachable")); let ctx, fields = @@ -230,17 +238,16 @@ let prepare_eval_operand_reorganize (config : C.config) (op : E.operand) : let prepare : cm_fun = fun cf ctx -> match op with - | Expressions.Constant (ty, cv) -> + | E.Constant _ -> (* No need to reorganize the context *) - literal_to_typed_value (TypesUtils.ty_as_literal ty) cv |> ignore; cf ctx - | Expressions.Copy p -> + | E.Copy p -> (* Access the value *) let access = Read in (* Expand the symbolic values, if necessary *) let expand_prim_copy = true in access_rplace_reorganize config expand_prim_copy access p cf ctx - | Expressions.Move p -> + | E.Move p -> (* Access the value *) let access = Move in let expand_prim_copy = false in @@ -260,9 +267,70 @@ let eval_operand_no_reorganize (config : C.config) (op : E.operand) ^ "\n- ctx:\n" ^ eval_ctx_to_string ctx ^ "\n")); (* Evaluate *) match op with - | Expressions.Constant (ty, cv) -> - cf (literal_to_typed_value (TypesUtils.ty_as_literal ty) cv) ctx - | Expressions.Copy p -> + | E.Constant cv -> ( + match cv.value with + | E.CLiteral lit -> + cf (literal_to_typed_value (TypesUtils.ty_as_literal cv.ty) lit) ctx + | E.TraitConst (trait_ref, generics, const_name) -> ( + assert (generics = TypesUtils.mk_empty_generic_args); + match trait_ref.trait_id with + | T.TraitImpl _ -> + (* This shouldn't happen: if we refer to a concrete implementation, we + should directly refer to the top-level constant *) + raise (Failure "Unreachable") + | _ -> ( + (* We refer to a constant defined in a local clause: simply + introduce a fresh symbolic value *) + let ctx0 = ctx in + (* Lookup the trait declaration to retrieve the type of the symbolic value *) + let trait_decl = + C.ctx_lookup_trait_decl ctx + trait_ref.trait_decl_ref.trait_decl_id + in + let _, (ty, _) = + List.find (fun (name, _) -> name = const_name) trait_decl.consts + in + (* Introduce a fresh symbolic value *) + let v = mk_fresh_symbolic_typed_value_from_ety V.TraitConst ty in + (* Continue the evaluation *) + let e = cf v ctx in + (* We have to wrap the generated expression *) + match e with + | None -> None + | Some e -> + Some + (SymbolicAst.IntroSymbolic + ( ctx0, + None, + value_as_symbolic v.value, + SymbolicAst.TraitConstValue + (trait_ref, generics, const_name), + e )))) + | E.CVar vid -> ( + let ctx0 = ctx in + (* Lookup the const generic value *) + let cv = C.ctx_lookup_const_generic_value ctx vid in + (* Copy the value *) + let allow_adt_copy = false in + let ctx, v = copy_value allow_adt_copy config ctx cv in + (* Continue *) + let e = cf v ctx in + (* We have to wrap the generated expression *) + match e with + | None -> None + | Some e -> + (* If we are synthesizing a symbolic AST, it means that we are in symbolic + mode: the value of the const generic is necessarily symbolic. *) + assert (is_symbolic cv.V.value); + (* *) + Some + (SymbolicAst.IntroSymbolic + ( ctx0, + None, + value_as_symbolic v.value, + SymbolicAst.ConstGenericValue vid, + e )))) + | E.Copy p -> (* Access the value *) let access = Read in let cc = read_place access p in @@ -283,7 +351,7 @@ let eval_operand_no_reorganize (config : C.config) (op : E.operand) in (* Compose and apply *) comp cc copy cf ctx - | Expressions.Move p -> + | E.Move p -> (* Access the value *) let access = Move in let cc = read_place access p in @@ -656,7 +724,8 @@ let eval_rvalue_aggregate (config : C.config) | E.AggregatedTuple -> let tys = List.map (fun (v : V.typed_value) -> v.V.ty) values in let v = V.Adt { variant_id = None; field_values = values } in - let ty = T.Adt (T.Tuple, [], tys, []) in + let generics = TypesUtils.mk_generic_args [] tys [] [] in + let ty = T.Adt (T.Tuple, generics) in let aggregated : V.typed_value = { V.value = v; ty } in (* Call the continuation *) cf aggregated ctx @@ -667,20 +736,22 @@ let eval_rvalue_aggregate (config : C.config) assert (List.length values = 1) else raise (Failure "Unreachable"); (* Construt the value *) - let aty = T.Adt (T.Assumed T.Option, [], [ ty ], []) in + let generics = TypesUtils.mk_generic_args [] [ ty ] [] [] in + let aty = T.Adt (T.Assumed T.Option, generics) in let av : V.adt_value = { V.variant_id = Some variant_id; V.field_values = values } in let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in (* Call the continuation *) cf aggregated ctx - | E.AggregatedAdt (def_id, opt_variant_id, regions, types, cgs) -> + | E.AggregatedAdt (def_id, opt_variant_id, generics) -> (* Sanity checks *) let type_decl = C.ctx_lookup_type_decl ctx def_id in - assert (List.length type_decl.region_params = List.length regions); + assert ( + List.length type_decl.generics.regions = List.length generics.regions); let expected_field_types = - Subst.ctx_adt_get_instantiated_field_etypes ctx def_id opt_variant_id - types cgs + Assoc.ctx_adt_get_inst_norm_field_etypes ctx def_id opt_variant_id + generics in assert ( expected_field_types @@ -689,7 +760,7 @@ let eval_rvalue_aggregate (config : C.config) let av : V.adt_value = { V.variant_id = opt_variant_id; V.field_values = values } in - let aty = T.Adt (T.AdtId def_id, regions, types, cgs) in + let aty = T.Adt (T.AdtId def_id, generics) in let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in (* Call the continuation *) cf aggregated ctx @@ -709,7 +780,8 @@ let eval_rvalue_aggregate (config : C.config) let av : V.adt_value = { V.variant_id = None; V.field_values = values } in - let aty = T.Adt (T.Assumed T.Range, [], [ ety ], []) in + let generics = TypesUtils.mk_generic_args_from_types [ ety ] in + let aty = T.Adt (T.Assumed T.Range, generics) in let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in (* Call the continuation *) cf aggregated ctx @@ -719,7 +791,8 @@ let eval_rvalue_aggregate (config : C.config) (* Sanity check: the number of values is consistent with the length *) let len = (literal_as_scalar (const_generic_as_literal cg)).value in assert (len = Z.of_int (List.length values)); - let ty = T.Adt (T.Assumed T.Array, [], [ ety ], [ cg ]) in + let generics = TypesUtils.mk_generic_args [] [ ety ] [ cg ] [] in + let ty = T.Adt (T.Assumed T.Array, generics) in (* In order to generate a better AST, we introduce a symbolic value equal to the array. The reason is that otherwise, the array we introduce here might be duplicated in the generated @@ -752,7 +825,7 @@ let eval_rvalue_not_global (config : C.config) (rvalue : E.rvalue) (* Delegate to the proper auxiliary function *) match rvalue with | E.Use op -> comp_wrap (eval_operand config op) ctx - | E.Ref (p, bkind) -> comp_wrap (eval_rvalue_ref config p bkind) ctx + | E.RvRef (p, bkind) -> comp_wrap (eval_rvalue_ref config p bkind) ctx | E.UnaryOp (unop, op) -> eval_unary_op config unop op cf ctx | E.BinaryOp (binop, op1, op2) -> eval_binary_op config binop op1 op2 cf ctx | E.Aggregate (aggregate_kind, ops) -> diff --git a/compiler/InterpreterLoopsJoinCtxs.ml b/compiler/InterpreterLoopsJoinCtxs.ml index bf88e055..6d3ecb18 100644 --- a/compiler/InterpreterLoopsJoinCtxs.ml +++ b/compiler/InterpreterLoopsJoinCtxs.ml @@ -554,9 +554,15 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx) C.type_context; fun_context; global_context; + trait_decls_context; + trait_impls_context; region_groups; type_vars; const_generic_vars; + const_generic_vars_map; + norm_trait_etypes; + norm_trait_rtypes; + norm_trait_stypes; env = _; ended_regions = ended_regions0; } = @@ -566,9 +572,15 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx) C.type_context = _; fun_context = _; global_context = _; + trait_decls_context = _; + trait_impls_context = _; region_groups = _; type_vars = _; const_generic_vars = _; + const_generic_vars_map = _; + norm_trait_etypes = _; + norm_trait_rtypes = _; + norm_trait_stypes = _; env = _; ended_regions = ended_regions1; } = @@ -580,9 +592,15 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx) C.type_context; fun_context; global_context; + trait_decls_context; + trait_impls_context; region_groups; type_vars; const_generic_vars; + const_generic_vars_map; + norm_trait_etypes; + norm_trait_rtypes; + norm_trait_stypes; env; ended_regions; } diff --git a/compiler/InterpreterLoopsMatchCtxs.ml b/compiler/InterpreterLoopsMatchCtxs.ml index 9248e513..8cab546e 100644 --- a/compiler/InterpreterLoopsMatchCtxs.ml +++ b/compiler/InterpreterLoopsMatchCtxs.ml @@ -149,20 +149,25 @@ let rec match_types (match_distinct_types : 'r T.ty -> 'r T.ty -> 'r T.ty) (match_regions : 'r -> 'r -> 'r) (ty0 : 'r T.ty) (ty1 : 'r T.ty) : 'r T.ty = let match_rec = match_types match_distinct_types match_regions in match (ty0, ty1) with - | Adt (id0, regions0, tys0, cgs0), Adt (id1, regions1, tys1, cgs1) -> + | Adt (id0, generics0), Adt (id1, generics1) -> assert (id0 = id1); - assert (cgs0 = cgs1); + assert (generics0.const_generics = generics1.const_generics); + assert (generics0.trait_refs = generics1.trait_refs); let id = id0 in - let cgs = cgs1 in + let const_generics = generics1.const_generics in + let trait_refs = generics1.trait_refs in let regions = List.map (fun (id0, id1) -> match_regions id0 id1) - (List.combine regions0 regions1) + (List.combine generics0.regions generics1.regions) in - let tys = - List.map (fun (ty0, ty1) -> match_rec ty0 ty1) (List.combine tys0 tys1) + let types = + List.map + (fun (ty0, ty1) -> match_rec ty0 ty1) + (List.combine generics0.types generics1.types) in - Adt (id, regions, tys, cgs) + let generics = { T.regions; types; const_generics; trait_refs } in + Adt (id, generics) | TypeVar vid0, TypeVar vid1 -> assert (vid0 = vid1); let vid = vid0 in diff --git a/compiler/InterpreterPaths.ml b/compiler/InterpreterPaths.ml index 04dc8892..465d0028 100644 --- a/compiler/InterpreterPaths.ml +++ b/compiler/InterpreterPaths.ml @@ -3,6 +3,7 @@ module V = Values module E = Expressions module C = Contexts module Subst = Substitute +module Assoc = AssociatedTypes module L = Logging open Cps open ValuesUtils @@ -97,7 +98,7 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx) match (pe, v.V.value, v.V.ty) with | ( Field (((ProjAdt (_, _) | ProjOption _) as proj_kind), field_id), V.Adt adt, - T.Adt (type_id, _, _, _) ) -> ( + T.Adt (type_id, _) ) -> ( (* Check consistency *) (match (proj_kind, type_id) with | ProjAdt (def_id, opt_variant_id), T.AdtId def_id' -> @@ -119,8 +120,7 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx) let updated = { v with value = nadt } in Ok (ctx, { res with updated })) (* Tuples *) - | Field (ProjTuple arity, field_id), V.Adt adt, T.Adt (T.Tuple, _, _, _) - -> ( + | Field (ProjTuple arity, field_id), V.Adt adt, T.Adt (T.Tuple, _) -> ( assert (arity = List.length adt.field_values); let fv = T.FieldId.nth adt.field_values field_id in (* Project *) @@ -145,9 +145,9 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx) (* Box dereferencement *) | ( DerefBox, Adt { variant_id = None; field_values = [ bv ] }, - T.Adt (T.Assumed T.Box, _, _, _) ) -> ( - (* We allow moving inside of boxes. In practice, this kind of - * manipulations should happen only inside unsage code, so + T.Adt (T.Assumed T.Box, _) ) -> ( + (* We allow moving outside of boxes. In practice, this kind of + * manipulations should happen only inside unsafe code, so * it shouldn't happen due to user code, and we leverage it * when implementing box dereferencement for the concrete * interpreter *) @@ -357,24 +357,23 @@ let write_place (access : access_kind) (p : E.place) (nv : V.typed_value) | Error e -> raise (Failure ("Unreachable: " ^ show_path_fail_kind e)) | Ok ctx -> ctx -let compute_expanded_bottom_adt_value (tyctx : T.type_decl T.TypeDeclId.Map.t) +let compute_expanded_bottom_adt_value (ctx : C.eval_ctx) (def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option) - (regions : T.erased_region list) (types : T.ety list) - (cgs : T.const_generic list) : V.typed_value = + (generics : T.egeneric_args) : V.typed_value = (* Lookup the definition and check if it is an enumeration - it should be an enumeration if and only if the projection element is a field projection with *some* variant id. Retrieve the list of fields at the same time. *) - let def = T.TypeDeclId.Map.find def_id tyctx in - assert (List.length regions = List.length def.T.region_params); + let def = C.ctx_lookup_type_decl ctx def_id in + assert (List.length generics.regions = List.length def.T.generics.regions); (* Compute the field types *) let field_types = - Subst.type_decl_get_instantiated_field_etypes def opt_variant_id types cgs + Assoc.type_decl_get_inst_norm_field_etypes ctx def opt_variant_id generics in (* Initialize the expanded value *) let fields = List.map mk_bottom field_types in let av = V.Adt { variant_id = opt_variant_id; field_values = fields } in - let ty = T.Adt (T.AdtId def_id, regions, types, cgs) in + let ty = T.Adt (T.AdtId def_id, generics) in { V.value = av; V.ty } let compute_expanded_bottom_option_value (variant_id : T.VariantId.id) @@ -387,7 +386,8 @@ let compute_expanded_bottom_option_value (variant_id : T.VariantId.id) else raise (Failure "Unreachable") in let av = V.Adt { variant_id = Some variant_id; field_values } in - let ty = T.Adt (T.Assumed T.Option, [], [ param_ty ], []) in + let generics = TypesUtils.mk_generic_args [] [ param_ty ] [] [] in + let ty = T.Adt (T.Assumed T.Option, generics) in { V.value = av; ty } let compute_expanded_bottom_tuple_value (field_types : T.ety list) : @@ -395,7 +395,8 @@ let compute_expanded_bottom_tuple_value (field_types : T.ety list) : (* Generate the field values *) let fields = List.map mk_bottom field_types in let v = V.Adt { variant_id = None; field_values = fields } in - let ty = T.Adt (T.Tuple, [], field_types, []) in + let generics = TypesUtils.mk_generic_args [] field_types [] [] in + let ty = T.Adt (T.Tuple, generics) in { V.value = v; V.ty } (** Auxiliary helper to expand {!V.Bottom} values. @@ -447,19 +448,29 @@ let expand_bottom_value_from_projection (access : access_kind) (p : E.place) match (pe, ty) with (* "Regular" ADTs *) | ( Field (ProjAdt (def_id, opt_variant_id), _), - T.Adt (T.AdtId def_id', regions, types, cgs) ) -> + T.Adt (T.AdtId def_id', generics) ) -> assert (def_id = def_id'); - compute_expanded_bottom_adt_value ctx.type_context.type_decls def_id - opt_variant_id regions types cgs + compute_expanded_bottom_adt_value ctx def_id opt_variant_id generics (* Option *) | ( Field (ProjOption variant_id, _), - T.Adt (T.Assumed T.Option, [], [ ty ], []) ) -> + T.Adt + ( T.Assumed T.Option, + { + T.regions = []; + types = [ ty ]; + const_generics = []; + trait_refs = []; + } ) ) -> compute_expanded_bottom_option_value variant_id ty (* Tuples *) - | Field (ProjTuple arity, _), T.Adt (T.Tuple, [], tys, []) -> - assert (arity = List.length tys); + | ( Field (ProjTuple arity, _), + T.Adt + ( T.Tuple, + { T.regions = []; types; const_generics = []; trait_refs = [] } ) ) + -> + assert (arity = List.length types); (* Generate the field values *) - compute_expanded_bottom_tuple_value tys + compute_expanded_bottom_tuple_value types | _ -> raise (Failure diff --git a/compiler/InterpreterPaths.mli b/compiler/InterpreterPaths.mli index 4a9f3b41..041b0a97 100644 --- a/compiler/InterpreterPaths.mli +++ b/compiler/InterpreterPaths.mli @@ -3,6 +3,7 @@ module V = Values module E = Expressions module C = Contexts module Subst = Substitute +module Assoc = AssociatedTypes module L = Logging open Cps open InterpreterExpansion @@ -56,12 +57,10 @@ val compute_expanded_bottom_tuple_value : T.ety list -> V.typed_value (** Compute an expanded ADT ⊥ value *) val compute_expanded_bottom_adt_value : - T.type_decl T.TypeDeclId.Map.t -> + C.eval_ctx -> T.TypeDeclId.id -> T.VariantId.id option -> - T.erased_region list -> - T.ety list -> - T.const_generic list -> + T.egeneric_args -> V.typed_value (** Compute an expanded [Option] ⊥ value *) diff --git a/compiler/InterpreterProjectors.ml b/compiler/InterpreterProjectors.ml index faed066b..9e0c2b75 100644 --- a/compiler/InterpreterProjectors.ml +++ b/compiler/InterpreterProjectors.ml @@ -3,6 +3,7 @@ module V = Values module E = Expressions module C = Contexts module Subst = Substitute +module Assoc = AssociatedTypes module L = Logging open TypesUtils open InterpreterUtils @@ -24,12 +25,12 @@ let rec apply_proj_borrows_on_shared_borrow (ctx : C.eval_ctx) else match (v.V.value, ty) with | V.Literal _, T.Literal _ -> [] - | V.Adt adt, T.Adt (id, region_params, tys, cgs) -> + | V.Adt adt, T.Adt (id, generics) -> (* Retrieve the types of the fields *) let field_types = - Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id - region_params tys cgs + Assoc.ctx_adt_value_get_inst_norm_field_rtypes ctx adt id generics in + (* Project over the field values *) let fields_types = List.combine adt.V.field_values field_types in let proj_fields = @@ -103,11 +104,10 @@ let rec apply_proj_borrows (check_symbolic_no_ended : bool) (ctx : C.eval_ctx) let value : V.avalue = match (v.V.value, ty) with | V.Literal _, T.Literal _ -> V.AIgnored - | V.Adt adt, T.Adt (id, region_params, tys, cgs) -> + | V.Adt adt, T.Adt (id, generics) -> (* Retrieve the types of the fields *) let field_types = - Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id - region_params tys cgs + Assoc.ctx_adt_value_get_inst_norm_field_rtypes ctx adt id generics in (* Project over the field values *) let fields_types = List.combine adt.V.field_values field_types in @@ -268,8 +268,7 @@ let apply_proj_loans_on_symbolic_expansion (regions : T.RegionId.Set.t) let (value, ty) : V.avalue * T.rty = match (see, original_sv_ty) with | SeLiteral _, T.Literal _ -> (V.AIgnored, original_sv_ty) - | SeAdt (variant_id, field_values), T.Adt (_id, _region_params, _tys, _cgs) - -> + | SeAdt (variant_id, field_values), T.Adt (_id, _generics) -> (* Project over the field values *) let field_values = List.map diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml index 045c4484..36bc3492 100644 --- a/compiler/InterpreterStatements.ml +++ b/compiler/InterpreterStatements.ml @@ -17,6 +17,7 @@ open InterpreterProjectors open InterpreterExpansion open InterpreterPaths open InterpreterExpressions +module PCtx = Print.EvalCtxLlbcAst (** The local logger *) let log = L.statements_log @@ -232,9 +233,8 @@ let set_discriminant (config : C.config) (p : E.place) let update_value cf (v : V.typed_value) : m_fun = fun ctx -> match (v.V.ty, v.V.value) with - | ( T.Adt - (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types, cgs), - V.Adt av ) -> ( + | T.Adt (((T.AdtId _ | T.Assumed T.Option) as type_id), generics), V.Adt av + -> ( (* There are two situations: - either the discriminant is already the proper one (in which case we don't do anything) @@ -251,28 +251,26 @@ let set_discriminant (config : C.config) (p : E.place) let bottom_v = match type_id with | T.AdtId def_id -> - compute_expanded_bottom_adt_value - ctx.type_context.type_decls def_id (Some variant_id) - regions types cgs + compute_expanded_bottom_adt_value ctx def_id + (Some variant_id) generics | T.Assumed T.Option -> - assert (regions = []); + assert (generics.regions = []); compute_expanded_bottom_option_value variant_id - (Collections.List.to_cons_nil types) + (Collections.List.to_cons_nil generics.types) | _ -> raise (Failure "Unreachable") in assign_to_place config bottom_v p (cf Unit) ctx) - | ( T.Adt - (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types, cgs), - V.Bottom ) -> + | T.Adt (((T.AdtId _ | T.Assumed T.Option) as type_id), generics), V.Bottom + -> let bottom_v = match type_id with | T.AdtId def_id -> - compute_expanded_bottom_adt_value ctx.type_context.type_decls - def_id (Some variant_id) regions types cgs + compute_expanded_bottom_adt_value ctx def_id (Some variant_id) + generics | T.Assumed T.Option -> - assert (regions = []); + assert (generics.regions = []); compute_expanded_bottom_option_value variant_id - (Collections.List.to_cons_nil types) + (Collections.List.to_cons_nil generics.types) | _ -> raise (Failure "Unreachable") in assign_to_place config bottom_v p (cf Unit) ctx @@ -301,24 +299,34 @@ let ctx_push_frame (ctx : C.eval_ctx) : C.eval_ctx = let push_frame : cm_fun = fun cf ctx -> cf (ctx_push_frame ctx) (** Small helper: compute the type of the return value for a specific - instantiation of a non-local function. + instantiation of an assumed function. *) -let get_non_local_function_return_type (fid : A.assumed_fun_id) - (region_params : T.erased_region list) (type_params : T.ety list) - (const_generic_params : T.const_generic list) : T.ety = +let get_assumed_function_return_type (ctx : C.eval_ctx) (fid : A.assumed_fun_id) + (generics : T.egeneric_args) : T.ety = + assert (generics.trait_refs = []); (* [Box::free] has a special treatment *) - match (fid, region_params, type_params, const_generic_params) with - | A.BoxFree, [], [ _ ], [] -> mk_unit_ty + match fid with + | A.BoxFree -> + assert (generics.regions = []); + assert (List.length generics.types = 1); + assert (generics.const_generics = []); + mk_unit_ty | _ -> (* Retrieve the function's signature *) let sg = Assumed.get_assumed_sig fid in (* Instantiate the return type *) - let tsubst = Subst.make_type_subst_from_vars sg.type_params type_params in - let cgsubst = - Subst.make_const_generic_subst_from_vars sg.const_generic_params - const_generic_params + (* There shouldn't be any reference to Self *) + let tr_self : T.erased_region T.trait_instance_id = + T.UnknownTrait __FUNCTION__ + in + let { Subst.r_subst = _; ty_subst; cg_subst; tr_subst; tr_self } = + Subst.make_esubst_from_generics sg.generics generics tr_self in - Subst.erase_regions_substitute_types tsubst cgsubst sg.output + let ty = + Subst.erase_regions_substitute_types ty_subst cg_subst tr_subst tr_self + sg.output + in + Assoc.ctx_normalize_ety ctx ty let move_return_value (config : C.config) (pop_return_value : bool) (cf : V.typed_value option -> m_fun) : m_fun = @@ -418,19 +426,19 @@ let pop_frame_assign (config : C.config) (dest : E.place) : cm_fun = in comp cf_pop cf_assign -(** Auxiliary function - see {!eval_non_local_function_call} *) -let eval_replace_concrete (_config : C.config) - (_region_params : T.erased_region list) (_type_params : T.ety list) - (_cg_params : T.const_generic list) : cm_fun = +(** Auxiliary function - see {!eval_assumed_function_call} *) +let eval_replace_concrete (_config : C.config) (_generics : T.egeneric_args) : + cm_fun = fun _cf _ctx -> raise Unimplemented -(** Auxiliary function - see {!eval_non_local_function_call} *) -let eval_box_new_concrete (config : C.config) - (region_params : T.erased_region list) (type_params : T.ety list) - (cg_params : T.const_generic list) : cm_fun = +(** Auxiliary function - see {!eval_assumed_function_call} *) +let eval_box_new_concrete (config : C.config) (generics : T.egeneric_args) : + cm_fun = fun cf ctx -> (* Check and retrieve the arguments *) - match (region_params, type_params, cg_params, ctx.env) with + match + (generics.regions, generics.types, generics.const_generics, ctx.env) + with | ( [], [ boxed_ty ], [], @@ -448,7 +456,8 @@ let eval_box_new_concrete (config : C.config) (* Create the new box *) let cf_create cf (moved_input_value : V.typed_value) : m_fun = (* Create the box value *) - let box_ty = T.Adt (T.Assumed T.Box, [], [ boxed_ty ], []) in + let generics = TypesUtils.mk_generic_args_from_types [ boxed_ty ] in + let box_ty = T.Adt (T.Assumed T.Box, generics) in let box_v = V.Adt { variant_id = None; field_values = [ moved_input_value ] } in @@ -467,13 +476,14 @@ let eval_box_new_concrete (config : C.config) | _ -> raise (Failure "Inconsistent state") (** Auxiliary function which factorizes code to evaluate [std::Deref::deref] - and [std::DerefMut::deref_mut] - see {!eval_non_local_function_call} *) + and [std::DerefMut::deref_mut] - see {!eval_assumed_function_call} *) let eval_box_deref_mut_or_shared_concrete (config : C.config) - (region_params : T.erased_region list) (type_params : T.ety list) - (cg_params : T.const_generic list) (is_mut : bool) : cm_fun = + (generics : T.egeneric_args) (is_mut : bool) : cm_fun = fun cf ctx -> (* Check the arguments *) - match (region_params, type_params, cg_params, ctx.env) with + match + (generics.regions, generics.types, generics.const_generics, ctx.env) + with | ( [], [ boxed_ty ], [], @@ -495,7 +505,7 @@ let eval_box_deref_mut_or_shared_concrete (config : C.config) { E.var_id = input_var.C.index; projection = [ E.Deref; E.DerefBox ] } in let borrow_kind = if is_mut then E.Mut else E.Shared in - let rv = E.Ref (p, borrow_kind) in + let rv = E.RvRef (p, borrow_kind) in let cf_borrow = eval_rvalue_not_global config rv in (* Move the borrow to its destination *) @@ -514,23 +524,19 @@ let eval_box_deref_mut_or_shared_concrete (config : C.config) comp cf_borrow cf_move cf ctx | _ -> raise (Failure "Inconsistent state") -(** Auxiliary function - see {!eval_non_local_function_call} *) -let eval_box_deref_concrete (config : C.config) - (region_params : T.erased_region list) (type_params : T.ety list) - (cg_params : T.const_generic list) : cm_fun = +(** Auxiliary function - see {!eval_assumed_function_call} *) +let eval_box_deref_concrete (config : C.config) (generics : T.egeneric_args) : + cm_fun = let is_mut = false in - eval_box_deref_mut_or_shared_concrete config region_params type_params - cg_params is_mut + eval_box_deref_mut_or_shared_concrete config generics is_mut -(** Auxiliary function - see {!eval_non_local_function_call} *) -let eval_box_deref_mut_concrete (config : C.config) - (region_params : T.erased_region list) (type_params : T.ety list) - (cg_params : T.const_generic list) : cm_fun = +(** Auxiliary function - see {!eval_assumed_function_call} *) +let eval_box_deref_mut_concrete (config : C.config) (generics : T.egeneric_args) + : cm_fun = let is_mut = true in - eval_box_deref_mut_or_shared_concrete config region_params type_params - cg_params is_mut + eval_box_deref_mut_or_shared_concrete config generics is_mut -(** Auxiliary function - see {!eval_non_local_function_call}. +(** Auxiliary function - see {!eval_assumed_function_call}. [Box::free] is not handled the same way as the other assumed functions: - in the regular case, whenever we need to evaluate an assumed function, @@ -549,11 +555,10 @@ let eval_box_deref_mut_concrete (config : C.config) It thus updates the box value (by calling {!drop_value}) and updates the destination (by setting it to [()]). *) -let eval_box_free (config : C.config) (region_params : T.erased_region list) - (type_params : T.ety list) (cg_params : T.const_generic list) +let eval_box_free (config : C.config) (generics : T.egeneric_args) (args : E.operand list) (dest : E.place) : cm_fun = fun cf ctx -> - match (region_params, type_params, cg_params, args) with + match (generics.regions, generics.types, generics.const_generics, args) with | [], [ boxed_ty ], [], [ E.Move input_box_place ] -> (* Required type checking *) let input_box = InterpreterPaths.read_place Write input_box_place ctx in @@ -570,17 +575,20 @@ let eval_box_free (config : C.config) (region_params : T.erased_region list) cc cf ctx | _ -> raise (Failure "Inconsistent state") -(** Auxiliary function - see {!eval_non_local_function_call} *) +(** Auxiliary function - see {!eval_assumed_function_call} *) let eval_vec_function_concrete (_config : C.config) (_fid : A.assumed_fun_id) - (_region_params : T.erased_region list) (_type_params : T.ety list) - (_cg_params : T.const_generic list) : cm_fun = + (_generics : T.egeneric_args) : cm_fun = fun _cf _ctx -> raise Unimplemented (** Evaluate a non-local function call in concrete mode *) -let eval_non_local_function_call_concrete (config : C.config) - (fid : A.assumed_fun_id) (region_params : T.erased_region list) - (type_params : T.ety list) (cg_params : T.const_generic list) - (args : E.operand list) (dest : E.place) : cm_fun = +let eval_assumed_function_call_concrete (config : C.config) + (fid : A.assumed_fun_id) (call : A.call) : cm_fun = + let generics = call.generics in + let args = call.args in + let dest = call.dest in + (* Sanity check: we don't fully handle the const generic vars environment + in concrete mode yet *) + assert (generics.const_generics = []); (* There are two cases (and this is extremely annoying): - the function is not box_free - the function is box_free @@ -589,7 +597,7 @@ let eval_non_local_function_call_concrete (config : C.config) match fid with | A.BoxFree -> (* Degenerate case: box_free *) - eval_box_free config region_params type_params cg_params args dest + eval_box_free config generics args dest | _ -> (* "Normal" case: not box_free *) (* Evaluate the operands *) @@ -604,16 +612,14 @@ let eval_non_local_function_call_concrete (config : C.config) * but it made it less clear where the computed values came from, * so we reversed the modifications. *) let cf_eval_call cf (args_vl : V.typed_value list) : m_fun = + fun ctx -> (* Push the stack frame: we initialize the frame with the return variable, and one variable per input argument *) let cc = push_frame in (* Create and push the return variable *) let ret_vid = E.VarId.zero in - let ret_ty = - get_non_local_function_return_type fid region_params type_params - cg_params - in + let ret_ty = get_assumed_function_return_type ctx fid generics in let ret_var = mk_var ret_vid (Some "@return") ret_ty in let cc = comp cc (push_uninitialized_var ret_var) in @@ -630,23 +636,17 @@ let eval_non_local_function_call_concrete (config : C.config) * access to a body. *) let cf_eval_body : cm_fun = match fid with - | A.Replace -> - eval_replace_concrete config region_params type_params cg_params - | BoxNew -> - eval_box_new_concrete config region_params type_params cg_params - | BoxDeref -> - eval_box_deref_concrete config region_params type_params cg_params - | BoxDerefMut -> - eval_box_deref_mut_concrete config region_params type_params - cg_params + | A.Replace -> eval_replace_concrete config generics + | BoxNew -> eval_box_new_concrete config generics + | BoxDeref -> eval_box_deref_concrete config generics + | BoxDerefMut -> eval_box_deref_mut_concrete config generics | BoxFree -> (* Should have been treated above *) raise (Failure "Unreachable") | VecNew | VecPush | VecInsert | VecLen | VecIndex | VecIndexMut -> - eval_vec_function_concrete config fid region_params type_params - cg_params + eval_vec_function_concrete config fid generics | ArrayIndexShared | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut | ArraySubsliceShared | ArraySubsliceMut - | SliceIndexShared | SliceIndexMut | SliceSubsliceShared + | ArrayRepeat | SliceIndexShared | SliceIndexMut | SliceSubsliceShared | SliceSubsliceMut | SliceLen -> raise (Failure "Unimplemented") in @@ -657,50 +657,11 @@ let eval_non_local_function_call_concrete (config : C.config) let cc = comp cc (pop_frame_assign config dest) in (* Continue *) - cc cf + cc cf ctx in (* Compose and apply *) comp cf_eval_ops cf_eval_call -let instantiate_fun_sig (type_params : T.ety list) - (cg_params : T.const_generic list) (sg : A.fun_sig) : A.inst_fun_sig = - (* Generate fresh abstraction ids and create a substitution from region - * group ids to abstraction ids *) - let rg_abs_ids_bindings = - List.map - (fun rg -> - let abs_id = C.fresh_abstraction_id () in - (rg.T.id, abs_id)) - sg.regions_hierarchy - in - let asubst_map : V.AbstractionId.id T.RegionGroupId.Map.t = - List.fold_left - (fun mp (rg_id, abs_id) -> T.RegionGroupId.Map.add rg_id abs_id mp) - T.RegionGroupId.Map.empty rg_abs_ids_bindings - in - let asubst (rg_id : T.RegionGroupId.id) : V.AbstractionId.id = - T.RegionGroupId.Map.find rg_id asubst_map - in - (* Generate fresh regions and their substitutions *) - let _, rsubst, _ = Subst.fresh_regions_with_substs sg.region_params in - (* Generate the type substitution - * Note that we need the substitution to map the type variables to - * {!rty} types (not {!ety}). In order to do that, we convert the - * type parameters to types with regions. This is possible only - * if those types don't contain any regions. - * This is a current limitation of the analysis: there is still some - * work to do to properly handle full type parametrization. - * *) - let rtype_params = List.map ety_no_regions_to_rty type_params in - let tsubst = Subst.make_type_subst_from_vars sg.type_params rtype_params in - let cgsubst = - Subst.make_const_generic_subst_from_vars sg.const_generic_params cg_params - in - (* Substitute the signature *) - let inst_sig = Subst.substitute_signature asubst rsubst tsubst cgsubst sg in - (* Return *) - inst_sig - (** Helper Create abstractions (with no avalues, which have to be inserted afterwards) @@ -836,7 +797,7 @@ let rec eval_statement (config : C.config) (st : A.statement) : st_cm_fun = match rvalue with | E.Global _ -> raise (Failure "Unreachable") | E.Use _ - | E.Ref (_, (E.Shared | E.Mut | E.TwoPhaseMut | E.Shallow)) + | E.RvRef (_, (E.Shared | E.Mut | E.TwoPhaseMut | E.Shallow)) | E.UnaryOp _ | E.BinaryOp _ | E.Discriminant _ | E.Aggregate _ -> let rp = rvalue_get_place rvalue in @@ -893,7 +854,16 @@ and eval_global (config : C.config) (dest : E.place) (gid : LA.GlobalDeclId.id) match config.mode with | ConcreteMode -> (* Treat the evaluation of the global as a call to the global body (without arguments) *) - (eval_local_function_call_concrete config global.body_id [] [] [] [] dest) + let call = + { + A.func = A.FunId (A.Regular global.body_id); + generics = TypesUtils.mk_empty_generic_args; + trait_and_method_generic_args = None; + args = []; + dest; + } + in + (eval_transparent_function_call_concrete config global.body_id call) cf ctx | SymbolicMode -> (* Generate a fresh symbolic value. In the translation, this fresh symbolic value will be @@ -1037,126 +1007,354 @@ and eval_switch (config : C.config) (switch : A.switch) : st_cm_fun = (** Evaluate a function call (auxiliary helper for [eval_statement]) *) and eval_function_call (config : C.config) (call : A.call) : st_cm_fun = - (* There are two cases: + (* There are several cases: - this is a local function, in which case we execute its body - - this is a non-local function, in which case there is a special treatment + - this is an assumed function, in which case there is a special treatment + - this is a trait method *) - match call.func with - | A.Regular fid -> - eval_local_function_call config fid call.region_args call.type_args - call.const_generic_args call.args call.dest - | A.Assumed fid -> - eval_non_local_function_call config fid call.region_args call.type_args - call.const_generic_args call.args call.dest + match config.mode with + | C.ConcreteMode -> eval_function_call_concrete config call + | C.SymbolicMode -> eval_function_call_symbolic config call -(** Evaluate a local (i.e., non-assumed) function call in concrete mode *) -and eval_local_function_call_concrete (config : C.config) (fid : A.FunDeclId.id) - (_region_args : T.erased_region list) (type_args : T.ety list) - (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) : - st_cm_fun = +and eval_function_call_concrete (config : C.config) (call : A.call) : st_cm_fun + = fun cf ctx -> - (* Retrieve the (correctly instantiated) body *) - let def = C.ctx_lookup_fun_decl ctx fid in - (* We can evaluate the function call only if it is not opaque *) - let body = - match def.body with - | None -> - raise - (Failure - ("Can't evaluate a call to an opaque function: " - ^ Print.name_to_string def.name)) - | Some body -> body - in - let tsubst = - Subst.make_type_subst_from_vars def.A.signature.type_params type_args - in - let cgsubst = - Subst.make_const_generic_subst_from_vars - def.A.signature.const_generic_params cg_args - in - let locals, body_st = Subst.fun_body_substitute_in_body tsubst cgsubst body in - - (* Evaluate the input operands *) - assert (List.length args = body.A.arg_count); - let cc = eval_operands config args in - - (* Push a frame delimiter - we use {!comp_transmit} to transmit the result - * of the operands evaluation from above to the functions afterwards, while - * ignoring it in this function *) - let cc = comp_transmit cc push_frame in - - (* Compute the initial values for the local variables *) - (* 1. Push the return value *) - let ret_var, locals = - match locals with - | ret_ty :: locals -> (ret_ty, locals) - | _ -> raise (Failure "Unreachable") - in - let input_locals, locals = - Collections.List.split_at locals body.A.arg_count - in - - let cc = comp_transmit cc (push_var ret_var (mk_bottom ret_var.var_ty)) in - - (* 2. Push the input values *) - let cf_push_inputs cf args = - let inputs = List.combine input_locals args in - (* Note that this function checks that the variables and their values - * have the same type (this is important) *) - push_vars inputs cf - in - let cc = comp cc cf_push_inputs in + match call.func with + | A.FunId (A.Regular fid) -> + eval_transparent_function_call_concrete config fid call cf ctx + | A.FunId (A.Assumed fid) -> + (* Continue - note that we do as if the function call has been successful, + * by giving {!Unit} to the continuation, because we place us in the case + * where we haven't panicked. Of course, the translation needs to take the + * panic case into account... *) + eval_assumed_function_call_concrete config fid call (cf Unit) ctx + | A.TraitMethod _ -> raise (Failure "Unimplemented") + +and eval_function_call_symbolic (config : C.config) (call : A.call) : st_cm_fun + = + match call.func with + | A.FunId (A.Regular _) | A.TraitMethod _ -> + eval_transparent_function_call_symbolic config call + | A.FunId (A.Assumed fid) -> + eval_assumed_function_call_symbolic config fid call - (* 3. Push the remaining local variables (initialized as {!Bottom}) *) - let cc = comp cc (push_uninitialized_vars locals) in +(** Evaluate a local (i.e., non-assumed) function call in concrete mode *) +and eval_transparent_function_call_concrete (config : C.config) + (fid : A.FunDeclId.id) (call : A.call) : st_cm_fun = + let generics = call.A.generics in + let args = call.A.args in + let dest = call.A.dest in + (* Sanity check: we don't fully handle the const generic vars environment + in concrete mode yet *) + assert (generics.const_generics = []); + fun cf ctx -> + (* Retrieve the (correctly instantiated) body *) + let def = C.ctx_lookup_fun_decl ctx fid in + (* We can evaluate the function call only if it is not opaque *) + let body = + match def.body with + | None -> + raise + (Failure + ("Can't evaluate a call to an opaque function: " + ^ Print.name_to_string def.name)) + | Some body -> body + in + (* TODO: we need to normalize the types if we want to correctly support traits *) + assert (generics.trait_refs = []); + (* There shouldn't be any reference to Self *) + let tr_self = T.UnknownTrait __FUNCTION__ in + let subst = + Subst.make_esubst_from_generics def.A.signature.generics generics tr_self + in + let locals, body_st = Subst.fun_body_substitute_in_body subst body in + + (* Evaluate the input operands *) + assert (List.length args = body.A.arg_count); + let cc = eval_operands config args in + + (* Push a frame delimiter - we use {!comp_transmit} to transmit the result + * of the operands evaluation from above to the functions afterwards, while + * ignoring it in this function *) + let cc = comp_transmit cc push_frame in + + (* Compute the initial values for the local variables *) + (* 1. Push the return value *) + let ret_var, locals = + match locals with + | ret_ty :: locals -> (ret_ty, locals) + | _ -> raise (Failure "Unreachable") + in + let input_locals, locals = + Collections.List.split_at locals body.A.arg_count + in - (* Execute the function body *) - let cc = comp cc (eval_function_body config body_st) in + let cc = comp_transmit cc (push_var ret_var (mk_bottom ret_var.var_ty)) in - (* Pop the stack frame and move the return value to its destination *) - let cf_finish cf res = - match res with - | Panic -> cf Panic - | Return -> - (* Pop the stack frame, retrieve the return value, move it to - * its destination and continue *) - pop_frame_assign config dest (cf Unit) - | Break _ | Continue _ | Unit | LoopReturn _ | EndEnterLoop _ - | EndContinue _ -> - raise (Failure "Unreachable") - in - let cc = comp cc cf_finish in + (* 2. Push the input values *) + let cf_push_inputs cf args = + let inputs = List.combine input_locals args in + (* Note that this function checks that the variables and their values + * have the same type (this is important) *) + push_vars inputs cf + in + let cc = comp cc cf_push_inputs in + + (* 3. Push the remaining local variables (initialized as {!Bottom}) *) + let cc = comp cc (push_uninitialized_vars locals) in + + (* Execute the function body *) + let cc = comp cc (eval_function_body config body_st) in + + (* Pop the stack frame and move the return value to its destination *) + let cf_finish cf res = + match res with + | Panic -> cf Panic + | Return -> + (* Pop the stack frame, retrieve the return value, move it to + * its destination and continue *) + pop_frame_assign config dest (cf Unit) + | Break _ | Continue _ | Unit | LoopReturn _ | EndEnterLoop _ + | EndContinue _ -> + raise (Failure "Unreachable") + in + let cc = comp cc cf_finish in - (* Continue *) - cc cf ctx + (* Continue *) + cc cf ctx (** Evaluate a local (i.e., non-assumed) function call in symbolic mode *) -and eval_local_function_call_symbolic (config : C.config) (fid : A.FunDeclId.id) - (region_args : T.erased_region list) (type_args : T.ety list) - (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) : - st_cm_fun = +and eval_transparent_function_call_symbolic (config : C.config) (call : A.call) + : st_cm_fun = fun cf ctx -> - (* Retrieve the (correctly instantiated) signature *) - let def = C.ctx_lookup_fun_decl ctx fid in - let sg = def.A.signature in - (* Instantiate the signature and introduce fresh abstraction and region ids - * while doing so *) - let inst_sg = instantiate_fun_sig type_args cg_args sg in + (* Instantiate the signature and introduce fresh abstractions and region ids while doing so. + + We perform some manipulations when instantiating the signature. + + # Trait impl calls + ================== + In particular, we have a special treatment of trait method calls when + the trait ref is a known impl. + + For instance: + {[ + trait HasValue { + fn has_value(&self) -> bool; + } + + impl<T> HasValue for Option<T> { + fn has_value(&self) { + match self { + None => false, + Some(_) => true, + } + } + } + + fn option_has_value<T>(x: &Option<T>) -> bool { + x.has_value() + } + ]} + + The generated code looks like this: + {[ + structure HasValue (Self : Type) = { + has_value : Self -> result bool + } + + let OptionHasValueImpl.has_value (Self : Type) (self : Self) : result bool = + match self with + | None => false + | Some _ => true + + let OptionHasValueInstance (T : Type) : HasValue (Option T) = { + has_value = OptionHasValueInstance.has_value + } + ]} + + In [option_has_value], we don't want to refer to the [has_value] method + of the instance of [HasValue] for [Option<T>]. We want to refer directly + to the function which implements [has_value] for [Option<T>]. + That is, instead of generating this: + {[ + let option_has_value (T : Type) (x : Option T) : result bool = + (OptionHasValueInstance T).has_value x + ]} + + We want to generate this: + {[ + let option_has_value (T : Type) (x : Option T) : result bool = + OptionHasValueImpl.has_value T x + ]} + + # Provided trait methods + ======================== + Calls to provided trait methods also have a special treatment because + for now we forbid overriding provided trait methods in the trait implementations, + which means that whenever we call a provided trait method, we do not refer + to a trait clause but directly to the method provided in the trait declaration. + *) + let func, generics, def, inst_sg = + match call.func with + | A.FunId (A.Regular fid) -> + let def = C.ctx_lookup_fun_decl ctx fid in + let tr_self = T.UnknownTrait __FUNCTION__ in + let inst_sg = + instantiate_fun_sig ctx call.generics tr_self def.A.signature + in + (call.func, call.generics, def, inst_sg) + | A.FunId (A.Assumed _) -> + (* Unreachable: must be a transparent function *) + raise (Failure "Unreachable") + | A.TraitMethod (trait_ref, method_name, _) -> ( + log#ldebug + (lazy + ("trait method call:\n- call: " ^ call_to_string ctx call + ^ "\n- method name: " ^ method_name ^ "\n- call.generics:\n" + ^ egeneric_args_to_string ctx call.generics + ^ "\n- trait and method generics:\n" + ^ egeneric_args_to_string ctx + (Option.get call.trait_and_method_generic_args))); + (* When instantiating, we need to group the generics for the trait ref + and the method *) + let generics = Option.get call.trait_and_method_generic_args in + (* Lookup the trait method signature - there are several possibilities + depending on whethere we call a top-level trait method impl or the + method from a local clause *) + match trait_ref.trait_id with + | TraitImpl impl_id -> ( + (* Lookup the trait impl *) + let trait_impl = C.ctx_lookup_trait_impl ctx impl_id in + log#ldebug + (lazy ("trait impl: " ^ trait_impl_to_string ctx trait_impl)); + (* First look in the required methods *) + let method_id = + List.find_opt + (fun (s, _) -> s = method_name) + trait_impl.required_methods + in + match method_id with + | Some (_, id) -> + (* This is a required method *) + let method_def = C.ctx_lookup_fun_decl ctx id in + (* Instantiate *) + let tr_self = + T.TraitRef (etrait_ref_no_regions_to_gr_trait_ref trait_ref) + in + let inst_sg = + instantiate_fun_sig ctx generics tr_self + method_def.A.signature + in + (* Also update the function identifier: we want to forget + the fact that we called a trait method, and treat it as + a regular function call to the top-level function + which implements the method. In order to do this properly, + we also need to update the generics. + *) + let func = A.FunId (A.Regular id) in + (func, generics, method_def, inst_sg) + | None -> + (* If not found, lookup the methods provided by the trait *declaration* + (remember: for now, we forbid overriding provided methods) *) + assert (trait_impl.provided_methods = []); + let trait_decl = + C.ctx_lookup_trait_decl ctx + trait_ref.trait_decl_ref.trait_decl_id + in + let _, method_id = + List.find + (fun (s, _) -> s = method_name) + trait_decl.provided_methods + in + let method_id = Option.get method_id in + let method_def = C.ctx_lookup_fun_decl ctx method_id in + (* For the instantiation we have to do something peculiar + because the method was defined for the trait declaration. + We have to group: + - the parameters given to the trait decl reference + - the parameters given to the method itself + For instance: + {[ + trait Foo<T> { + fn f<U>(...) { ... } + } + + fn g<G>(x : G) where Clause0: Foo<G, bool> + { + x.f::<u32>(...) // The arguments to f are: <G, bool, u32> + } + ]} + *) + let all_generics = + TypesUtils.merge_generic_args + trait_ref.trait_decl_ref.decl_generics call.generics + in + log#ldebug + (lazy + ("provided method call:" ^ "\n- method name: " ^ method_name + ^ "\n- all_generics:\n" + ^ egeneric_args_to_string ctx all_generics + ^ "\n- parent params info: " + ^ Print.option_to_string A.show_params_info + method_def.signature.parent_params_info)); + let tr_self = + T.TraitRef (etrait_ref_no_regions_to_gr_trait_ref trait_ref) + in + let inst_sg = + instantiate_fun_sig ctx all_generics tr_self + method_def.A.signature + in + (call.func, call.generics, method_def, inst_sg)) + | _ -> + (* We are using a local clause - we lookup the trait decl *) + let trait_decl = + C.ctx_lookup_trait_decl ctx trait_ref.trait_decl_ref.trait_decl_id + in + (* Lookup the method decl in the required *and* the provided methods *) + let _, method_id = + let provided = + List.filter_map + (fun (id, f) -> + match f with None -> None | Some f -> Some (id, f)) + trait_decl.provided_methods + in + List.find + (fun (s, _) -> s = method_name) + (List.append trait_decl.required_methods provided) + in + let method_def = C.ctx_lookup_fun_decl ctx method_id in + log#ldebug (lazy ("method:\n" ^ fun_decl_to_string ctx method_def)); + (* Instantiate *) + let tr_self = T.TraitRef trait_ref in + let tr_self = + TypesUtils.etrait_instance_id_no_regions_to_gr_trait_instance_id + tr_self + in + let inst_sg = + instantiate_fun_sig ctx generics tr_self method_def.A.signature + in + (call.func, call.generics, method_def, inst_sg)) + in (* Sanity check *) - assert (List.length args = List.length def.A.signature.inputs); + assert (List.length call.args = List.length def.A.signature.inputs); (* Evaluate the function call *) - eval_function_call_symbolic_from_inst_sig config (A.Regular fid) inst_sg - region_args type_args cg_args args dest cf ctx + eval_function_call_symbolic_from_inst_sig config func inst_sg generics + call.args call.dest cf ctx (** Evaluate a function call in symbolic mode by using the function signature. This allows us to factorize the evaluation of local and non-local function calls in symbolic mode: only their signatures matter. + + The [self_trait_ref] trait ref refers to [Self]. We use it when calling + a provided trait method, because those methods have a special treatment: + we dot not group them with the required trait methods, and forbid (for now) + overriding them. We treat them as regular method, which take an additional + trait ref as input. *) and eval_function_call_symbolic_from_inst_sig (config : C.config) - (fid : A.fun_id) (inst_sg : A.inst_fun_sig) - (_region_args : T.erased_region list) (type_args : T.ety list) - (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) : + (fid : A.fun_id_or_trait_method_ref) (inst_sg : A.inst_fun_sig) + (generics : T.egeneric_args) (args : E.operand list) (dest : E.place) : st_cm_fun = fun cf ctx -> (* Generate a fresh symbolic value for the return value *) @@ -1224,8 +1422,8 @@ and eval_function_call_symbolic_from_inst_sig (config : C.config) let expr = cf ctx in (* Synthesize the symbolic AST *) - S.synthesize_regular_function_call fid call_id ctx abs_ids type_args cg_args - args args_places ret_spc dest_place expr + S.synthesize_regular_function_call fid call_id ctx abs_ids generics args + args_places ret_spc dest_place expr in let cc = comp cc cf_call in @@ -1286,17 +1484,18 @@ and eval_function_call_symbolic_from_inst_sig (config : C.config) cc (cf Unit) ctx (** Evaluate a non-local function call in symbolic mode *) -and eval_non_local_function_call_symbolic (config : C.config) - (fid : A.assumed_fun_id) (region_args : T.erased_region list) - (type_args : T.ety list) (cg_args : T.const_generic list) - (args : E.operand list) (dest : E.place) : st_cm_fun = +and eval_assumed_function_call_symbolic (config : C.config) + (fid : A.assumed_fun_id) (call : A.call) : st_cm_fun = fun cf ctx -> + let generics = call.generics in + let args = call.args in + let dest = call.dest in (* Sanity check: make sure the type parameters don't contain regions - * this is a current limitation of our synthesis *) assert ( List.for_all (fun ty -> not (ty_has_borrows ctx.type_context.type_infos ty)) - type_args); + generics.types); (* There are two cases (and this is extremely annoying): - the function is not box_free @@ -1307,7 +1506,7 @@ and eval_non_local_function_call_symbolic (config : C.config) | A.BoxFree -> (* Degenerate case: box_free - note that this is not really a function * call: no need to call a "synthesize_..." function *) - eval_box_free config region_args type_args cg_args args dest (cf Unit) ctx + eval_box_free config generics args dest (cf Unit) ctx | _ -> (* "Normal" case: not box_free *) (* In symbolic mode, the behaviour of a function call is completely defined @@ -1319,55 +1518,15 @@ and eval_non_local_function_call_symbolic (config : C.config) (* should have been treated above *) raise (Failure "Unreachable") | _ -> - instantiate_fun_sig type_args cg_args (Assumed.get_assumed_sig fid) + (* There shouldn't be any reference to Self *) + let tr_self = T.UnknownTrait __FUNCTION__ in + instantiate_fun_sig ctx generics tr_self + (Assumed.get_assumed_sig fid) in (* Evaluate the function call *) - eval_function_call_symbolic_from_inst_sig config (A.Assumed fid) inst_sig - region_args type_args cg_args args dest cf ctx - -(** Evaluate a non-local (i.e, assumed) function call such as [Box::deref] - (auxiliary helper for [eval_statement]) *) -and eval_non_local_function_call (config : C.config) (fid : A.assumed_fun_id) - (region_args : T.erased_region list) (type_args : T.ety list) - (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) : - st_cm_fun = - fun cf ctx -> - (* Debug *) - log#ldebug - (lazy - (let type_args = - "[" ^ String.concat ", " (List.map (ety_to_string ctx) type_args) ^ "]" - in - let args = - "[" ^ String.concat ", " (List.map (operand_to_string ctx) args) ^ "]" - in - let dest = place_to_string ctx dest in - "eval_non_local_function_call:\n- fid:" ^ A.show_assumed_fun_id fid - ^ "\n- type_args: " ^ type_args ^ "\n- args: " ^ args ^ "\n- dest: " - ^ dest)); - - match config.mode with - | C.ConcreteMode -> - eval_non_local_function_call_concrete config fid region_args type_args - cg_args args dest (cf Unit) ctx - | C.SymbolicMode -> - eval_non_local_function_call_symbolic config fid region_args type_args - cg_args args dest cf ctx - -(** Evaluate a local (i.e, not assumed) function call (auxiliary helper for - [eval_statement]) *) -and eval_local_function_call (config : C.config) (fid : A.FunDeclId.id) - (region_args : T.erased_region list) (type_args : T.ety list) - (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) : - st_cm_fun = - match config.mode with - | ConcreteMode -> - eval_local_function_call_concrete config fid region_args type_args cg_args - args dest - | SymbolicMode -> - eval_local_function_call_symbolic config fid region_args type_args cg_args - args dest + eval_function_call_symbolic_from_inst_sig config (A.FunId (A.Assumed fid)) + inst_sig generics args dest cf ctx (** Evaluate a statement seen as a function body *) and eval_function_body (config : C.config) (body : A.statement) : st_cm_fun = diff --git a/compiler/InterpreterStatements.mli b/compiler/InterpreterStatements.mli index 814bc964..e65758ae 100644 --- a/compiler/InterpreterStatements.mli +++ b/compiler/InterpreterStatements.mli @@ -25,15 +25,6 @@ open InterpreterExpressions *) val pop_frame : C.config -> bool -> (V.typed_value option -> m_fun) -> m_fun -(** Instantiate a function signature, introducing **fresh** abstraction ids and - region ids. This is mostly used in preparation of function calls, when - evaluating in symbolic mode of course. - - Note: there are no region parameters, because they should be erased. - *) -val instantiate_fun_sig : - T.ety list -> T.const_generic list -> LA.fun_sig -> LA.inst_fun_sig - (** Helper. Create a list of abstractions from a list of regions groups, and insert diff --git a/compiler/InterpreterUtils.ml b/compiler/InterpreterUtils.ml index 7bd37550..7aaee6ff 100644 --- a/compiler/InterpreterUtils.ml +++ b/compiler/InterpreterUtils.ml @@ -10,6 +10,11 @@ open TypesUtils module PA = Print.EvalCtxLlbcAst open Cps +(* TODO: we should probably rename the file to ContextsUtils *) + +(** The local logger *) +let log = L.interpreter_log + (** Some utilities *) (** Auxiliary function - call a function which requires a continuation, @@ -38,6 +43,15 @@ let typed_value_to_string = PA.typed_value_to_string let typed_avalue_to_string = PA.typed_avalue_to_string let place_to_string = PA.place_to_string let operand_to_string = PA.operand_to_string +let egeneric_args_to_string = PA.egeneric_args_to_string +let rtrait_instance_id_to_string = PA.rtrait_instance_id_to_string +let fun_sig_to_string = PA.fun_sig_to_string +let fun_decl_to_string = PA.fun_decl_to_string +let call_to_string = PA.call_to_string + +let trait_impl_to_string ctx = + PA.trait_impl_to_string { ctx with type_vars = []; const_generic_vars = [] } + let statement_to_string ctx = PA.statement_to_string ctx "" " " let statement_to_string_with_tab ctx = PA.statement_to_string ctx " " " " let env_elem_to_string ctx = PA.env_elem_to_string ctx "" " " @@ -255,7 +269,8 @@ let value_has_ret_symbolic_value_with_borrow_under_mut (ctx : C.eval_ctx) raise Found else () | V.SynthInput | V.SynthInputGivenBack | V.FunCallGivenBack - | V.SynthRetGivenBack | V.Global | V.LoopGivenBack | V.Aggregate -> + | V.SynthRetGivenBack | V.Global | V.LoopGivenBack | V.Aggregate + | V.ConstGeneric | V.TraitConst -> () end in @@ -272,7 +287,7 @@ let rvalue_get_place (rv : E.rvalue) : E.place option = match rv with | Use (Copy p | Move p) -> Some p | Use (Constant _) -> None - | Ref (p, _) -> Some p + | RvRef (p, _) -> Some p | UnaryOp _ | BinaryOp _ | Global _ | Discriminant _ | Aggregate _ -> None (** See {!ValuesUtils.symbolic_value_has_borrows} *) @@ -403,3 +418,103 @@ let compute_contexts_ids (ctxl : C.eval_ctx list) : ids_sets * ids_to_values = (** Compute the sets of ids found in a context. *) let compute_context_ids (ctx : C.eval_ctx) : ids_sets * ids_to_values = compute_contexts_ids [ ctx ] + +(** **WARNING**: this function doesn't compute the normalized types + (for the trait type aliases). This should be computed afterwards. + *) +let initialize_eval_context (ctx : C.decls_ctx) + (region_groups : T.RegionGroupId.id list) (type_vars : T.type_var list) + (const_generic_vars : T.const_generic_var list) : C.eval_ctx = + C.reset_global_counters (); + let const_generic_vars_map = + T.ConstGenericVarId.Map.of_list + (List.map + (fun (cg : T.const_generic_var) -> + let ty = TypesUtils.ety_no_regions_to_rty (T.Literal cg.ty) in + let cv = mk_fresh_symbolic_typed_value V.ConstGeneric ty in + (cg.index, cv)) + const_generic_vars) + in + { + C.type_context = ctx.type_ctx; + C.fun_context = ctx.fun_ctx; + C.global_context = ctx.global_ctx; + C.trait_decls_context = ctx.trait_decls_ctx; + C.trait_impls_context = ctx.trait_impls_ctx; + C.region_groups; + C.type_vars; + C.const_generic_vars; + C.const_generic_vars_map; + C.norm_trait_etypes = C.ETraitTypeRefMap.empty (* Empty for now *); + C.norm_trait_rtypes = C.RTraitTypeRefMap.empty (* Empty for now *); + C.norm_trait_stypes = C.STraitTypeRefMap.empty (* Empty for now *); + C.env = [ C.Frame ]; + C.ended_regions = T.RegionId.Set.empty; + } + +(** Instantiate a function signature, introducing **fresh** abstraction ids and + region ids. This is mostly used in preparation of function calls (when + evaluating in symbolic mode). + + Note: there are no region parameters, because they should be erased. + *) +let instantiate_fun_sig (ctx : C.eval_ctx) (generics : T.egeneric_args) + (tr_self : T.rtrait_instance_id) (sg : A.fun_sig) : A.inst_fun_sig = + log#ldebug + (lazy + ("instantiate_fun_sig:" ^ "\n- generics: " + ^ egeneric_args_to_string ctx generics + ^ "\n- tr_self: " + ^ rtrait_instance_id_to_string ctx tr_self + ^ "\n- sg: " ^ fun_sig_to_string ctx sg)); + (* Generate fresh abstraction ids and create a substitution from region + * group ids to abstraction ids *) + let rg_abs_ids_bindings = + List.map + (fun rg -> + let abs_id = C.fresh_abstraction_id () in + (rg.T.id, abs_id)) + sg.regions_hierarchy + in + let asubst_map : V.AbstractionId.id T.RegionGroupId.Map.t = + List.fold_left + (fun mp (rg_id, abs_id) -> T.RegionGroupId.Map.add rg_id abs_id mp) + T.RegionGroupId.Map.empty rg_abs_ids_bindings + in + let asubst (rg_id : T.RegionGroupId.id) : V.AbstractionId.id = + T.RegionGroupId.Map.find rg_id asubst_map + in + (* Generate fresh regions and their substitutions *) + let _, rsubst, _ = Subst.fresh_regions_with_substs sg.generics.regions in + (* Generate the type substitution + * Note that we need the substitution to map the type variables to + * {!rty} types (not {!ety}). In order to do that, we convert the + * type parameters to types with regions. This is possible only + * if those types don't contain any regions. + * This is a current limitation of the analysis: there is still some + * work to do to properly handle full type parametrization. + * *) + let rtype_params = List.map ety_no_regions_to_rty generics.types in + let tsubst = Subst.make_type_subst_from_vars sg.generics.types rtype_params in + let cgsubst = + Subst.make_const_generic_subst_from_vars sg.generics.const_generics + generics.const_generics + in + (* TODO: something annoying with the trait ref subst: we need to use region + types, but the arguments use erased regions. For now we use the fact + that no regions should appear inside. In the future: we should merge + ety and rty. *) + let trait_refs = + List.map TypesUtils.etrait_ref_no_regions_to_gr_trait_ref + generics.trait_refs + in + let tr_subst = + Subst.make_trait_subst_from_clauses sg.generics.trait_clauses trait_refs + in + (* Substitute the signature *) + let inst_sig = + AssociatedTypes.ctx_subst_norm_signature ctx asubst rsubst tsubst cgsubst + tr_subst tr_self sg + in + (* Return *) + inst_sig diff --git a/compiler/Invariants.ml b/compiler/Invariants.ml index f29c7f88..9ac5ce13 100644 --- a/compiler/Invariants.ml +++ b/compiler/Invariants.ml @@ -7,6 +7,7 @@ module V = Values module E = Expressions module C = Contexts module Subst = Substitute +module Assoc = AssociatedTypes module A = LlbcAst module L = Logging open Cps @@ -406,13 +407,14 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (match (tv.V.value, tv.V.ty) with | V.Literal cv, T.Literal ty -> check_literal_type cv ty (* ADT case *) - | V.Adt av, T.Adt (T.AdtId def_id, regions, tys, cgs) -> + | V.Adt av, T.Adt (T.AdtId def_id, generics) -> (* Retrieve the definition to check the variant id, the number of * parameters, etc. *) let def = C.ctx_lookup_type_decl ctx def_id in (* Check the number of parameters *) - assert (List.length regions = List.length def.region_params); - assert (List.length tys = List.length def.type_params); + assert ( + List.length generics.regions = List.length def.generics.regions); + assert (List.length generics.types = List.length def.generics.types); (* Check that the variant id is consistent *) (match (av.V.variant_id, def.T.kind) with | Some variant_id, T.Enum variants -> @@ -421,8 +423,8 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = | _ -> raise (Failure "Erroneous typing")); (* Check that the field types are correct *) let field_types = - Subst.type_decl_get_instantiated_field_etypes def av.V.variant_id - tys cgs + Assoc.type_decl_get_inst_norm_field_etypes ctx def av.V.variant_id + generics in let fields_with_types = List.combine av.V.field_values field_types @@ -431,20 +433,28 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (fun ((v, ty) : V.typed_value * T.ety) -> assert (v.V.ty = ty)) fields_with_types (* Tuple case *) - | V.Adt av, T.Adt (T.Tuple, regions, tys, cgs) -> - assert (regions = []); - assert (cgs = []); + | V.Adt av, T.Adt (T.Tuple, generics) -> + assert (generics.regions = []); + assert (generics.const_generics = []); assert (av.V.variant_id = None); (* Check that the fields have the proper values - and check that there * are as many fields as field types at the same time *) - let fields_with_types = List.combine av.V.field_values tys in + let fields_with_types = + List.combine av.V.field_values generics.types + in List.iter (fun ((v, ty) : V.typed_value * T.ety) -> assert (v.V.ty = ty)) fields_with_types (* Assumed type case *) - | V.Adt av, T.Adt (T.Assumed aty_id, regions, tys, cgs) -> ( + | V.Adt av, T.Adt (T.Assumed aty_id, generics) -> ( assert (av.V.variant_id = None || aty_id = T.Option); - match (aty_id, av.V.field_values, regions, tys, cgs) with + match + ( aty_id, + av.V.field_values, + generics.regions, + generics.types, + generics.const_generics ) + with (* Box *) | T.Box, [ inner_value ], [], [ inner_ty ], [] | T.Option, [ inner_value ], [], [ inner_ty ], [] -> @@ -520,14 +530,17 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (* Check the current pair (value, type) *) (match (atv.V.value, atv.V.ty) with (* ADT case *) - | V.AAdt av, T.Adt (T.AdtId def_id, regions, tys, cgs) -> + | V.AAdt av, T.Adt (T.AdtId def_id, generics) -> (* Retrieve the definition to check the variant id, the number of * parameters, etc. *) let def = C.ctx_lookup_type_decl ctx def_id in (* Check the number of parameters *) - assert (List.length regions = List.length def.region_params); - assert (List.length tys = List.length def.type_params); - assert (List.length cgs = List.length def.const_generic_params); + assert ( + List.length generics.regions = List.length def.generics.regions); + assert (List.length generics.types = List.length def.generics.types); + assert ( + List.length generics.const_generics + = List.length def.generics.const_generics); (* Check that the variant id is consistent *) (match (av.V.variant_id, def.T.kind) with | Some variant_id, T.Enum variants -> @@ -536,8 +549,8 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = | _ -> raise (Failure "Erroneous typing")); (* Check that the field types are correct *) let field_types = - Subst.type_decl_get_instantiated_field_rtypes def av.V.variant_id - regions tys cgs + Assoc.type_decl_get_inst_norm_field_rtypes ctx def av.V.variant_id + generics in let fields_with_types = List.combine av.V.field_values field_types @@ -546,20 +559,28 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (fun ((v, ty) : V.typed_avalue * T.rty) -> assert (v.V.ty = ty)) fields_with_types (* Tuple case *) - | V.AAdt av, T.Adt (T.Tuple, regions, tys, cgs) -> - assert (regions = []); - assert (cgs = []); + | V.AAdt av, T.Adt (T.Tuple, generics) -> + assert (generics.regions = []); + assert (generics.const_generics = []); assert (av.V.variant_id = None); (* Check that the fields have the proper values - and check that there * are as many fields as field types at the same time *) - let fields_with_types = List.combine av.V.field_values tys in + let fields_with_types = + List.combine av.V.field_values generics.types + in List.iter (fun ((v, ty) : V.typed_avalue * T.rty) -> assert (v.V.ty = ty)) fields_with_types (* Assumed type case *) - | V.AAdt av, T.Adt (T.Assumed aty_id, regions, tys, cgs) -> ( + | V.AAdt av, T.Adt (T.Assumed aty_id, generics) -> ( assert (av.V.variant_id = None); - match (aty_id, av.V.field_values, regions, tys, cgs) with + match + ( aty_id, + av.V.field_values, + generics.regions, + generics.types, + generics.const_generics ) + with (* Box *) | T.Box, [ boxed_value ], [], [ boxed_ty ], [] -> assert (boxed_value.V.ty = boxed_ty) diff --git a/compiler/LlbcAst.ml b/compiler/LlbcAst.ml index f4d26e18..2db859b2 100644 --- a/compiler/LlbcAst.ml +++ b/compiler/LlbcAst.ml @@ -11,6 +11,7 @@ type abs_region_groups = (AbstractionId.id, RegionId.id) g_region_groups (** A function signature, after instantiation *) type inst_fun_sig = { regions_hierarchy : abs_region_groups; + trait_type_constraints : rtrait_type_constraint list; inputs : rty list; output : rty; } diff --git a/compiler/LlbcAstUtils.ml b/compiler/LlbcAstUtils.ml index 1111c297..a982af30 100644 --- a/compiler/LlbcAstUtils.ml +++ b/compiler/LlbcAstUtils.ml @@ -12,3 +12,32 @@ let lookup_fun_name (fun_id : fun_id) (fun_decls : fun_decl FunDeclId.Map.t) : match fun_id with | Regular id -> (FunDeclId.Map.find id fun_decls).name | Assumed aid -> Assumed.get_assumed_name aid + +(** Return the opaque declarations found in the crate. + + [filter_assumed]: if [true], do not consider as opaque the external definitions + that we will map to definitions from the standard library. + + Remark: the list of functions also contains the list of opaque global bodies. + *) +let crate_get_opaque_decls (k : crate) (filter_assumed : bool) : + T.type_decl list * fun_decl list = + let open ExtractAssumed in + let is_opaque_fun (d : fun_decl) : bool = + let sname = name_to_simple_name d.name in + d.body = None + (* Something to pay attention to: we must ignore trait method *declarations* + (which don't have a body but must not be considered as opaque) *) + && (match d.kind with TraitMethodDecl _ -> false | _ -> true) + && ((not filter_assumed) + || not (SimpleNameMap.mem sname assumed_globals_map)) + in + let is_opaque_type (d : T.type_decl) : bool = d.kind = T.Opaque in + (* Note that by checking the function bodies we also the globals *) + ( List.filter is_opaque_type (T.TypeDeclId.Map.values k.types), + List.filter is_opaque_fun (FunDeclId.Map.values k.functions) ) + +(** Return true if the crate contains opaque declarations, ignoring the assumed + definitions. *) +let crate_has_opaque_decls (k : crate) (filter_assumed : bool) : bool = + crate_get_opaque_decls k filter_assumed <> ([], []) diff --git a/compiler/Logging.ml b/compiler/Logging.ml index 9dc1f5e3..59abbfc7 100644 --- a/compiler/Logging.ml +++ b/compiler/Logging.ml @@ -9,6 +9,9 @@ let pre_passes_log = L.get_logger "MainLogger.PrePasses" (** Logger for Translate *) let translate_log = L.get_logger "MainLogger.Translate" +(** Logger for Contexts *) +let contexts_log = L.get_logger "MainLogger.Contexts" + (** Logger for PureUtils *) let pure_utils_log = L.get_logger "MainLogger.PureUtils" @@ -57,6 +60,9 @@ let borrows_log = L.get_logger "MainLogger.Interpreter.Borrows" (** Logger for Invariants *) let invariants_log = L.get_logger "MainLogger.Interpreter.Invariants" +(** Logger for AssociatedTypes *) +let associated_types_log = L.get_logger "MainLogger.AssociatedTypes" + (** Logger for SCC *) let scc_log = L.get_logger "MainLogger.Graph.SCC" diff --git a/compiler/PrePasses.ml b/compiler/PrePasses.ml index b348ba1d..1058fab0 100644 --- a/compiler/PrePasses.ml +++ b/compiler/PrePasses.ml @@ -107,7 +107,7 @@ let remove_useless_cf_merges (crate : A.crate) (f : A.fun_decl) : A.fun_decl = false | Assign (_, rv) -> ( match rv with - | Use _ | Ref _ -> not must_end_with_exit + | Use _ | RvRef _ -> not must_end_with_exit | Aggregate (AggregatedTuple, []) -> not must_end_with_exit | _ -> false) | FakeRead _ | Drop _ | Nop -> not must_end_with_exit @@ -376,7 +376,7 @@ let remove_shallow_borrows (crate : A.crate) (f : A.fun_decl) : A.fun_decl = method! visit_Assign env p rv = match (p.projection, rv) with - | [], E.Ref (_, E.Shallow) -> + | [], E.RvRef (_, E.Shallow) -> (* Filter *) filtered := E.VarId.Set.add p.var_id !filtered; Nop diff --git a/compiler/Print.ml b/compiler/Print.ml index 9aa73d7c..5d5c16ee 100644 --- a/compiler/Print.ml +++ b/compiler/Print.ml @@ -21,6 +21,9 @@ module Values = struct type_decl_id_to_string : T.TypeDeclId.id -> string; const_generic_var_id_to_string : T.ConstGenericVarId.id -> string; global_decl_id_to_string : T.GlobalDeclId.id -> string; + trait_decl_id_to_string : T.TraitDeclId.id -> string; + trait_impl_id_to_string : T.TraitImplId.id -> string; + trait_clause_id_to_string : T.TraitClauseId.id -> string; adt_variant_to_string : T.TypeDeclId.id -> T.VariantId.id -> string; var_id_to_string : E.VarId.id -> string; adt_field_names : @@ -34,6 +37,9 @@ module Values = struct PT.type_decl_id_to_string = fmt.type_decl_id_to_string; PT.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string; PT.global_decl_id_to_string = fmt.global_decl_id_to_string; + PT.trait_decl_id_to_string = fmt.trait_decl_id_to_string; + PT.trait_impl_id_to_string = fmt.trait_impl_id_to_string; + PT.trait_clause_id_to_string = fmt.trait_clause_id_to_string; } let value_to_rtype_formatter (fmt : value_formatter) : PT.rtype_formatter = @@ -43,6 +49,9 @@ module Values = struct PT.type_decl_id_to_string = fmt.type_decl_id_to_string; PT.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string; PT.global_decl_id_to_string = fmt.global_decl_id_to_string; + PT.trait_decl_id_to_string = fmt.trait_decl_id_to_string; + PT.trait_impl_id_to_string = fmt.trait_impl_id_to_string; + PT.trait_clause_id_to_string = fmt.trait_clause_id_to_string; } let value_to_stype_formatter (fmt : value_formatter) : PT.stype_formatter = @@ -52,6 +61,9 @@ module Values = struct PT.type_decl_id_to_string = fmt.type_decl_id_to_string; PT.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string; PT.global_decl_id_to_string = fmt.global_decl_id_to_string; + PT.trait_decl_id_to_string = fmt.trait_decl_id_to_string; + PT.trait_impl_id_to_string = fmt.trait_impl_id_to_string; + PT.trait_clause_id_to_string = fmt.trait_clause_id_to_string; } let var_id_to_string (id : E.VarId.id) : string = @@ -86,10 +98,10 @@ module Values = struct List.map (typed_value_to_string fmt) av.field_values in match v.ty with - | T.Adt (T.Tuple, _, _, _) -> + | T.Adt (T.Tuple, _) -> (* Tuple *) "(" ^ String.concat ", " field_values ^ ")" - | T.Adt (T.AdtId def_id, _, _, _) -> + | T.Adt (T.AdtId def_id, _) -> (* "Regular" ADT *) let adt_ident = match av.variant_id with @@ -111,7 +123,7 @@ module Values = struct let field_values = String.concat " " field_values in adt_ident ^ " { " ^ field_values ^ " }" else adt_ident - | T.Adt (T.Assumed aty, _, _, _) -> ( + | T.Adt (T.Assumed aty, _) -> ( (* Assumed type *) match (aty, field_values) with | Box, [ bv ] -> "@Box(" ^ bv ^ ")" @@ -201,10 +213,10 @@ module Values = struct List.map (typed_avalue_to_string fmt) av.field_values in match v.ty with - | T.Adt (T.Tuple, _, _, _) -> + | T.Adt (T.Tuple, _) -> (* Tuple *) "(" ^ String.concat ", " field_values ^ ")" - | T.Adt (T.AdtId def_id, _, _, _) -> + | T.Adt (T.AdtId def_id, _) -> (* "Regular" ADT *) let adt_ident = match av.variant_id with @@ -226,7 +238,7 @@ module Values = struct let field_values = String.concat " " field_values in adt_ident ^ " { " ^ field_values ^ " }" else adt_ident - | T.Adt (T.Assumed aty, _, _, _) -> ( + | T.Adt (T.Assumed aty, _) -> ( (* Assumed type *) match (aty, field_values) with | Box, [ bv ] -> "@Box(" ^ bv ^ ")" @@ -452,6 +464,9 @@ module Contexts = struct PV.adt_variant_to_string = fmt.adt_variant_to_string; PV.var_id_to_string = fmt.var_id_to_string; PV.adt_field_names = fmt.adt_field_names; + PV.trait_decl_id_to_string = fmt.trait_decl_id_to_string; + PV.trait_impl_id_to_string = fmt.trait_impl_id_to_string; + PV.trait_clause_id_to_string = fmt.trait_clause_id_to_string; } let ast_to_value_formatter (fmt : PA.ast_formatter) : PV.value_formatter = @@ -463,20 +478,27 @@ module Contexts = struct let ctx_to_rtype_formatter (fmt : ctx_formatter) : PT.rtype_formatter = PV.value_to_rtype_formatter fmt + let ctx_to_stype_formatter (fmt : ctx_formatter) : PT.stype_formatter = + PV.value_to_stype_formatter fmt + let eval_ctx_to_ctx_formatter (ctx : C.eval_ctx) : ctx_formatter = - (* We shouldn't use rvar_to_string *) - let rvar_to_string _r = - raise (Failure "Unexpected use of rvar_to_string") + let rvar_to_string r = + (* In theory we shouldn't use rvar_to_string, but it can happen + when printing definitions for instance... *) + T.RegionVarId.to_string r in let r_to_string r = PT.region_id_to_string r in let type_var_id_to_string vid = - let v = C.lookup_type_var ctx vid in - v.name + (* The context may be invalid *) + match C.lookup_type_var_opt ctx vid with + | None -> T.TypeVarId.to_string vid + | Some v -> v.name in let const_generic_var_id_to_string vid = - let v = C.lookup_const_generic_var ctx vid in - v.name + match C.lookup_const_generic_var_opt ctx vid with + | None -> T.ConstGenericVarId.to_string vid + | Some v -> v.name in let type_decl_id_to_string def_id = let def = C.ctx_lookup_type_decl ctx def_id in @@ -486,6 +508,15 @@ module Contexts = struct let def = C.ctx_lookup_global_decl ctx def_id in name_to_string def.name in + let trait_decl_id_to_string def_id = + let def = C.ctx_lookup_trait_decl ctx def_id in + name_to_string def.name + in + let trait_impl_id_to_string def_id = + let def = C.ctx_lookup_trait_impl ctx def_id in + name_to_string def.name + in + let trait_clause_id_to_string id = PT.trait_clause_id_to_pretty_string id in let adt_variant_to_string = PT.type_ctx_to_adt_variant_to_string_fun ctx.type_context.type_decls in @@ -506,6 +537,9 @@ module Contexts = struct adt_variant_to_string; var_id_to_string; adt_field_names; + trait_decl_id_to_string; + trait_impl_id_to_string; + trait_clause_id_to_string; } let eval_ctx_to_ast_formatter (ctx : C.eval_ctx) : PA.ast_formatter = @@ -521,6 +555,15 @@ module Contexts = struct let def = C.ctx_lookup_global_decl ctx def_id in global_name_to_string def.name in + let trait_decl_id_to_string def_id = + let def = C.ctx_lookup_trait_decl ctx def_id in + name_to_string def.name + in + let trait_impl_id_to_string def_id = + let def = C.ctx_lookup_trait_impl ctx def_id in + name_to_string def.name + in + let trait_clause_id_to_string id = PT.trait_clause_id_to_pretty_string id in { rvar_to_string = ctx_fmt.PV.rvar_to_string; r_to_string = ctx_fmt.PV.r_to_string; @@ -533,6 +576,9 @@ module Contexts = struct adt_field_to_string; fun_decl_id_to_string; global_decl_id_to_string; + trait_decl_id_to_string; + trait_impl_id_to_string; + trait_clause_id_to_string; } (** Split an [env] at every occurrence of [Frame], eliminating those elements. @@ -608,6 +654,50 @@ module EvalCtxLlbcAst = struct let fmt = PC.ctx_to_rtype_formatter fmt in PT.rty_to_string fmt t + let sty_to_string (ctx : C.eval_ctx) (t : T.sty) : string = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_stype_formatter fmt in + PT.sty_to_string fmt t + + let etrait_ref_to_string (ctx : C.eval_ctx) (x : T.etrait_ref) : string = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_etype_formatter fmt in + PT.etrait_ref_to_string fmt x + + let rtrait_ref_to_string (ctx : C.eval_ctx) (x : T.rtrait_ref) : string = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_rtype_formatter fmt in + PT.rtrait_ref_to_string fmt x + + let strait_ref_to_string (ctx : C.eval_ctx) (x : T.strait_ref) : string = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_stype_formatter fmt in + PT.strait_ref_to_string fmt x + + let etrait_instance_id_to_string (ctx : C.eval_ctx) (x : T.etrait_instance_id) + : string = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_etype_formatter fmt in + PT.etrait_instance_id_to_string fmt x + + let rtrait_instance_id_to_string (ctx : C.eval_ctx) (x : T.rtrait_instance_id) + : string = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_rtype_formatter fmt in + PT.rtrait_instance_id_to_string fmt x + + let strait_instance_id_to_string (ctx : C.eval_ctx) (x : T.strait_instance_id) + : string = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_stype_formatter fmt in + PT.strait_instance_id_to_string fmt x + + let egeneric_args_to_string (ctx : C.eval_ctx) (x : T.egeneric_args) : string + = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_etype_formatter fmt in + PT.egeneric_args_to_string fmt x + let borrow_content_to_string (ctx : C.eval_ctx) (bc : V.borrow_content) : string = let fmt = PC.eval_ctx_to_ctx_formatter ctx in @@ -653,11 +743,27 @@ module EvalCtxLlbcAst = struct let fmt = PC.eval_ctx_to_ast_formatter ctx in PE.operand_to_string fmt op + let call_to_string (ctx : C.eval_ctx) (call : A.call) : string = + let fmt = PC.eval_ctx_to_ast_formatter ctx in + PA.call_to_string fmt "" call + + let fun_decl_to_string (ctx : C.eval_ctx) (f : A.fun_decl) : string = + let fmt = PC.eval_ctx_to_ast_formatter ctx in + PA.fun_decl_to_string fmt "" " " f + + let fun_sig_to_string (ctx : C.eval_ctx) (x : A.fun_sig) : string = + let fmt = PC.eval_ctx_to_ast_formatter ctx in + PA.fun_sig_to_string fmt "" " " x + let statement_to_string (ctx : C.eval_ctx) (indent : string) (indent_incr : string) (e : A.statement) : string = let fmt = PC.eval_ctx_to_ast_formatter ctx in PA.statement_to_string fmt indent indent_incr e + let trait_impl_to_string (ctx : C.eval_ctx) (timpl : A.trait_impl) : string = + let fmt = PC.eval_ctx_to_ast_formatter ctx in + PA.trait_impl_to_string fmt " " " " timpl + let env_elem_to_string (ctx : C.eval_ctx) (indent : string) (indent_incr : string) (ev : C.env_elem) : string = let fmt = PC.eval_ctx_to_ctx_formatter ctx in diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index cfb63ec2..5fb5978b 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -8,6 +8,9 @@ type type_formatter = { type_decl_id_to_string : TypeDeclId.id -> string; const_generic_var_id_to_string : ConstGenericVarId.id -> string; global_decl_id_to_string : GlobalDeclId.id -> string; + trait_decl_id_to_string : TraitDeclId.id -> string; + trait_impl_id_to_string : TraitImplId.id -> string; + trait_clause_id_to_string : TraitClauseId.id -> string; } type value_formatter = { @@ -18,6 +21,9 @@ type value_formatter = { adt_variant_to_string : TypeDeclId.id -> VariantId.id -> string; var_id_to_string : VarId.id -> string; adt_field_names : TypeDeclId.id -> VariantId.id option -> string list option; + trait_decl_id_to_string : TraitDeclId.id -> string; + trait_impl_id_to_string : TraitImplId.id -> string; + trait_clause_id_to_string : TraitClauseId.id -> string; } let value_to_type_formatter (fmt : value_formatter) : type_formatter = @@ -26,6 +32,9 @@ let value_to_type_formatter (fmt : value_formatter) : type_formatter = type_decl_id_to_string = fmt.type_decl_id_to_string; const_generic_var_id_to_string = fmt.const_generic_var_id_to_string; global_decl_id_to_string = fmt.global_decl_id_to_string; + trait_decl_id_to_string = fmt.trait_decl_id_to_string; + trait_impl_id_to_string = fmt.trait_impl_id_to_string; + trait_clause_id_to_string = fmt.trait_clause_id_to_string; } (* TODO: we need to store which variables we have encountered so far, and @@ -42,6 +51,9 @@ type ast_formatter = { adt_field_names : TypeDeclId.id -> VariantId.id option -> string list option; fun_decl_id_to_string : FunDeclId.id -> string; global_decl_id_to_string : GlobalDeclId.id -> string; + trait_decl_id_to_string : TraitDeclId.id -> string; + trait_impl_id_to_string : TraitImplId.id -> string; + trait_clause_id_to_string : TraitClauseId.id -> string; } let ast_to_value_formatter (fmt : ast_formatter) : value_formatter = @@ -53,6 +65,9 @@ let ast_to_value_formatter (fmt : ast_formatter) : value_formatter = adt_variant_to_string = fmt.adt_variant_to_string; var_id_to_string = fmt.var_id_to_string; adt_field_names = fmt.adt_field_names; + trait_decl_id_to_string = fmt.trait_decl_id_to_string; + trait_impl_id_to_string = fmt.trait_impl_id_to_string; + trait_clause_id_to_string = fmt.trait_clause_id_to_string; } let ast_to_type_formatter (fmt : ast_formatter) : type_formatter = @@ -70,31 +85,51 @@ let literal_type_to_string = Print.PrimitiveValues.literal_type_to_string let scalar_value_to_string = Print.PrimitiveValues.scalar_value_to_string let literal_to_string = Print.PrimitiveValues.literal_to_string +(* Remark: not using generic_params on purpose, because we may use parameters + which either come from LLBC or from pure, and the [generic_params] type + for those ASTs is not the same. Note that it works because we actually don't + need to know the trait clauses to print the AST: we can thus ignore them. +*) let mk_type_formatter (type_decls : T.type_decl TypeDeclId.Map.t) (global_decls : A.global_decl GlobalDeclId.Map.t) - (type_params : type_var list) + (trait_decls : A.trait_decl TraitDeclId.Map.t) + (trait_impls : A.trait_impl TraitImplId.Map.t) (type_params : type_var list) (const_generic_params : const_generic_var list) : type_formatter = let type_var_id_to_string vid = - let var = T.TypeVarId.nth type_params vid in + let var = TypeVarId.nth type_params vid in type_var_to_string var in let const_generic_var_id_to_string vid = - let var = T.ConstGenericVarId.nth const_generic_params vid in + let var = ConstGenericVarId.nth const_generic_params vid in const_generic_var_to_string var in let type_decl_id_to_string def_id = - let def = T.TypeDeclId.Map.find def_id type_decls in + let def = TypeDeclId.Map.find def_id type_decls in name_to_string def.name in let global_decl_id_to_string def_id = - let def = T.GlobalDeclId.Map.find def_id global_decls in + let def = GlobalDeclId.Map.find def_id global_decls in name_to_string def.name in + let trait_decl_id_to_string def_id = + let def = TraitDeclId.Map.find def_id trait_decls in + name_to_string def.name + in + let trait_impl_id_to_string def_id = + let def = TraitImplId.Map.find def_id trait_impls in + name_to_string def.name + in + let trait_clause_id_to_string id = + Print.PT.trait_clause_id_to_pretty_string id + in { type_var_id_to_string; type_decl_id_to_string; const_generic_var_id_to_string; global_decl_id_to_string; + trait_decl_id_to_string; + trait_impl_id_to_string; + trait_clause_id_to_string; } (* TODO: there is a bit of duplication with Print.fun_decl_to_ast_formatter. @@ -106,19 +141,21 @@ let mk_type_formatter (type_decls : T.type_decl TypeDeclId.Map.t) let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t) (fun_decls : A.fun_decl FunDeclId.Map.t) (global_decls : A.global_decl GlobalDeclId.Map.t) - (type_params : type_var list) + (trait_decls : A.trait_decl TraitDeclId.Map.t) + (trait_impls : A.trait_impl TraitImplId.Map.t) (type_params : type_var list) (const_generic_params : const_generic_var list) : ast_formatter = - let type_var_id_to_string vid = - let var = T.TypeVarId.nth type_params vid in - type_var_to_string var - in - let const_generic_var_id_to_string vid = - let var = T.ConstGenericVarId.nth const_generic_params vid in - const_generic_var_to_string var - in - let type_decl_id_to_string def_id = - let def = T.TypeDeclId.Map.find def_id type_decls in - name_to_string def.name + let ({ + type_var_id_to_string; + type_decl_id_to_string; + const_generic_var_id_to_string; + global_decl_id_to_string; + trait_decl_id_to_string; + trait_impl_id_to_string; + trait_clause_id_to_string; + } + : type_formatter) = + mk_type_formatter type_decls global_decls trait_decls trait_impls + type_params const_generic_params in let adt_variant_to_string = Print.Types.type_ctx_to_adt_variant_to_string_fun type_decls @@ -137,10 +174,6 @@ let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t) let def = FunDeclId.Map.find def_id fun_decls in fun_name_to_string def.name in - let global_decl_id_to_string def_id = - let def = GlobalDeclId.Map.find def_id global_decls in - global_name_to_string def.name - in { type_var_id_to_string; const_generic_var_id_to_string; @@ -151,6 +184,9 @@ let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t) adt_field_to_string; fun_decl_id_to_string; global_decl_id_to_string; + trait_decl_id_to_string; + trait_impl_id_to_string; + trait_clause_id_to_string; } let assumed_ty_to_string (aty : assumed_ty) : string = @@ -182,20 +218,18 @@ let const_generic_to_string (fmt : type_formatter) (cg : T.const_generic) : let rec ty_to_string (fmt : type_formatter) (inside : bool) (ty : ty) : string = match ty with - | Adt (id, tys, cgs) -> ( - let tys = List.map (ty_to_string fmt false) tys in - let cgs = List.map (const_generic_to_string fmt) cgs in - let params = List.append tys cgs in + | Adt (id, generics) -> ( match id with | Tuple -> - assert (cgs = []); - "(" ^ String.concat " * " tys ^ ")" + let generics = generic_args_to_strings fmt false generics in + "(" ^ String.concat " * " generics ^ ")" | AdtId _ | Assumed _ -> - let params_s = - if params = [] then "" else " " ^ String.concat " " params + let generics = generic_args_to_strings fmt true generics in + let generics_s = + if generics = [] then "" else " " ^ String.concat " " generics in - let ty_s = type_id_to_string fmt id ^ params_s in - if params <> [] && inside then "(" ^ ty_s ^ ")" else ty_s) + let ty_s = type_id_to_string fmt id ^ generics_s in + if generics <> [] && inside then "(" ^ ty_s ^ ")" else ty_s) | TypeVar tv -> fmt.type_var_id_to_string tv | Literal lty -> literal_type_to_string lty | Arrow (arg_ty, ret_ty) -> @@ -203,6 +237,71 @@ let rec ty_to_string (fmt : type_formatter) (inside : bool) (ty : ty) : string = ty_to_string fmt true arg_ty ^ " -> " ^ ty_to_string fmt false ret_ty in if inside then "(" ^ ty ^ ")" else ty + | TraitType (trait_ref, generics, type_name) -> + let trait_ref = trait_ref_to_string fmt false trait_ref in + let s = + if generics = empty_generic_args then trait_ref ^ "::" ^ type_name + else + let generics = generic_args_to_string fmt generics in + "(" ^ trait_ref ^ " " ^ generics ^ ")::" ^ type_name + in + if inside then "(" ^ s ^ ")" else s + +and generic_args_to_strings (fmt : type_formatter) (inside : bool) + (generics : generic_args) : string list = + let tys = List.map (ty_to_string fmt inside) generics.types in + let cgs = List.map (const_generic_to_string fmt) generics.const_generics in + let trait_refs = + List.map (trait_ref_to_string fmt inside) generics.trait_refs + in + List.concat [ tys; cgs; trait_refs ] + +and generic_args_to_string (fmt : type_formatter) (generics : generic_args) : + string = + String.concat " " (generic_args_to_strings fmt true generics) + +and trait_ref_to_string (fmt : type_formatter) (inside : bool) (tr : trait_ref) + : string = + let trait_id = trait_instance_id_to_string fmt false tr.trait_id in + let generics = generic_args_to_string fmt tr.generics in + let s = trait_id ^ generics in + if tr.generics = empty_generic_args || not inside then s else "(" ^ s ^ ")" + +and trait_instance_id_to_string (fmt : type_formatter) (inside : bool) + (id : trait_instance_id) : string = + match id with + | Self -> "Self" + | TraitImpl id -> fmt.trait_impl_id_to_string id + | Clause id -> fmt.trait_clause_id_to_string id + | ParentClause (inst_id, _decl_id, clause_id) -> + let inst_id = trait_instance_id_to_string fmt false inst_id in + let clause_id = fmt.trait_clause_id_to_string clause_id in + "parent(" ^ inst_id ^ ")::" ^ clause_id + | ItemClause (inst_id, _decl_id, item_name, clause_id) -> + let inst_id = trait_instance_id_to_string fmt false inst_id in + let clause_id = fmt.trait_clause_id_to_string clause_id in + "(" ^ inst_id ^ ")::" ^ item_name ^ "::[" ^ clause_id ^ "]" + | TraitRef tr -> trait_ref_to_string fmt inside tr + | UnknownTrait msg -> "UNKNOWN(" ^ msg ^ ")" + +let trait_clause_to_string (fmt : type_formatter) (clause : trait_clause) : + string = + let clause_id = fmt.trait_clause_id_to_string clause.clause_id in + let trait_id = fmt.trait_decl_id_to_string clause.trait_id in + let generics = generic_args_to_strings fmt true clause.generics in + let generics = + if generics = [] then "" else " " ^ String.concat " " generics + in + "[" ^ clause_id ^ "]: " ^ trait_id ^ generics + +let generic_params_to_strings (fmt : type_formatter) (generics : generic_params) + : string list = + let tys = List.map type_var_to_string generics.types in + let cgs = List.map const_generic_var_to_string generics.const_generics in + let trait_clauses = + List.map (trait_clause_to_string fmt) generics.trait_clauses + in + List.concat [ tys; cgs; trait_clauses ] let field_to_string fmt inside (f : field) : string = match f.field_name with @@ -217,11 +316,10 @@ let variant_to_string fmt (v : variant) : string = ^ ")" let type_decl_to_string (fmt : type_formatter) (def : type_decl) : string = - let types = def.type_params in let name = name_to_string def.name in let params = - if types = [] then "" - else " " ^ String.concat " " (List.map type_var_to_string types) + if def.generics = empty_generic_params then "" + else " " ^ String.concat " " (generic_params_to_strings fmt def.generics) in match def.kind with | Struct fields -> @@ -353,10 +451,10 @@ let adt_g_value_to_string (fmt : value_formatter) (field_values : 'v list) (ty : ty) : string = let field_values = List.map value_to_string field_values in match ty with - | Adt (Tuple, _, _) -> + | Adt (Tuple, _) -> (* Tuple *) "(" ^ String.concat ", " field_values ^ ")" - | Adt (AdtId def_id, _, _) -> + | Adt (AdtId def_id, _) -> (* "Regular" ADT *) let adt_ident = match variant_id with @@ -378,7 +476,7 @@ let adt_g_value_to_string (fmt : value_formatter) let field_values = String.concat " " field_values in adt_ident ^ " { " ^ field_values ^ " }" else adt_ident - | Adt (Assumed aty, _, _) -> ( + | Adt (Assumed aty, _) -> ( (* Assumed type *) match aty with | State -> @@ -464,10 +562,10 @@ let rec typed_pattern_to_string (fmt : ast_formatter) (v : typed_pattern) : let fun_sig_to_string (fmt : ast_formatter) (sg : fun_sig) : string = let ty_fmt = ast_to_type_formatter fmt in - let type_params = List.map type_var_to_string sg.type_params in + let generics = generic_params_to_strings ty_fmt sg.generics in let inputs = List.map (ty_to_string ty_fmt false) sg.inputs in let output = ty_to_string ty_fmt false sg.output in - let all_types = List.concat [ type_params; inputs; [ output ] ] in + let all_types = List.concat [ generics; inputs; [ output ] ] in String.concat " -> " all_types let inst_fun_sig_to_string (fmt : ast_formatter) (sg : inst_fun_sig) : string = @@ -512,6 +610,7 @@ let llbc_assumed_fun_id_to_string (fid : A.assumed_fun_id) : string = | ArrayToSliceMut -> "@ArrayToSliceMut" | ArraySubsliceShared -> "@ArraySubsliceShared" | ArraySubsliceMut -> "@ArraySubsliceMut" + | ArrayRepeat -> "@ArrayRepeat" | SliceLen -> "@SliceLen" | SliceIndexShared -> "@SliceIndexShared" | SliceIndexMut -> "@SliceIndexMut" @@ -531,8 +630,11 @@ let regular_fun_id_to_string (fmt : ast_formatter) (fun_id : fun_id) : string = | FromLlbc (fid, lp_id, rg_id) -> let f = match fid with - | Regular fid -> fmt.fun_decl_id_to_string fid - | Assumed fid -> llbc_assumed_fun_id_to_string fid + | FunId (Regular fid) -> fmt.fun_decl_id_to_string fid + | FunId (Assumed fid) -> llbc_assumed_fun_id_to_string fid + | TraitMethod (trait_ref, method_name, _) -> + let fmt = ast_to_type_formatter fmt in + trait_ref_to_string fmt true trait_ref ^ "." ^ method_name in f ^ fun_suffix lp_id rg_id | Pure fid -> pure_assumed_fun_id_to_string fid @@ -559,9 +661,8 @@ let fun_or_op_id_to_string (fmt : ast_formatter) (fun_id : fun_or_op_id) : let rec texpression_to_string (fmt : ast_formatter) (inside : bool) (indent : string) (indent_incr : string) (e : texpression) : string = match e.e with - | Var var_id -> - let s = fmt.var_id_to_string var_id in - if inside then "(" ^ s ^ ")" else s + | Var var_id -> fmt.var_id_to_string var_id + | CVar cg_id -> fmt.const_generic_var_id_to_string cg_id | Const cv -> literal_to_string cv | App _ -> (* Recursively destruct the app, to have a pair (app, arguments list) *) @@ -632,10 +733,11 @@ and app_to_string (fmt : ast_formatter) (inside : bool) (indent : string) (* There are two possibilities: either the [app] is an instantiated, * top-level qualifier (function, ADT constructore...), or it is a "regular" * expression *) - let app, tys = + let app, generics = match app.e with | Qualif qualif -> (* Qualifier case *) + let ty_fmt = ast_to_type_formatter fmt in (* Convert the qualifier identifier *) let qualif_s = match qualif.id with @@ -654,12 +756,17 @@ and app_to_string (fmt : ast_formatter) (inside : bool) (indent : string) let field_s = adt_field_to_string value_fmt adt_id field_id in (* Adopting an F*-like syntax *) ConstStrings.constructor_prefix ^ adt_s ^ "?." ^ field_s + | TraitConst (trait_ref, generics, const_name) -> + let trait_ref = trait_ref_to_string ty_fmt true trait_ref in + let generics_s = generic_args_to_string ty_fmt generics in + if generics <> empty_generic_args then + "(" ^ trait_ref ^ generics_s ^ ")." ^ const_name + else trait_ref ^ "." ^ const_name in (* Convert the type instantiation *) - let ty_fmt = ast_to_type_formatter fmt in - let tys = List.map (ty_to_string ty_fmt true) qualif.type_args in + let generics = generic_args_to_strings ty_fmt true qualif.generics in (* *) - (qualif_s, tys) + (qualif_s, generics) | _ -> (* "Regular" expression case *) let inside = args <> [] || (args = [] && inside) in @@ -674,7 +781,7 @@ and app_to_string (fmt : ast_formatter) (inside : bool) (indent : string) texpression_to_string fmt inside indent1 indent_incr in let args = List.map arg_to_string args in - let all_args = List.append tys args in + let all_args = List.append generics args in (* Put together *) let e = if all_args = [] then app else app ^ " " ^ String.concat " " all_args diff --git a/compiler/Pure.ml b/compiler/Pure.ml index ac4ca081..47c7beb4 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -13,6 +13,9 @@ module FieldId = T.FieldId module SymbolicValueId = V.SymbolicValueId module FunDeclId = A.FunDeclId module GlobalDeclId = A.GlobalDeclId +module TraitDeclId = T.TraitDeclId +module TraitImplId = T.TraitImplId +module TraitClauseId = T.TraitClauseId (** We redefine identifiers for loop: in {!Values}, the identifiers are global (they monotonically increase across functions) while in {!module:Pure} we want @@ -37,6 +40,13 @@ module ConstGenericVarId = T.ConstGenericVarId type integer_type = T.integer_type [@@deriving show, ord] type const_generic_var = T.const_generic_var [@@deriving show, ord] type const_generic = T.const_generic [@@deriving show, ord] +type const_generic_var_id = T.const_generic_var_id [@@deriving show, ord] +type trait_decl_id = T.trait_decl_id [@@deriving show, ord] +type trait_impl_id = T.trait_impl_id [@@deriving show, ord] +type trait_clause_id = T.trait_clause_id [@@deriving show, ord] +type trait_item_name = T.trait_item_name [@@deriving show, ord] +type global_decl_id = T.global_decl_id [@@deriving show, ord] +type fun_decl_id = A.fun_decl_id [@@deriving show, ord] (** The assumed types for the pure AST. @@ -176,6 +186,14 @@ class ['self] iter_ty_base = inherit! [_] T.iter_const_generic inherit! [_] PV.iter_literal_type method visit_type_var_id : 'env -> type_var_id -> unit = fun _ _ -> () + method visit_trait_decl_id : 'env -> trait_decl_id -> unit = fun _ _ -> () + method visit_trait_impl_id : 'env -> trait_impl_id -> unit = fun _ _ -> () + + method visit_trait_clause_id : 'env -> trait_clause_id -> unit = + fun _ _ -> () + + method visit_trait_item_name : 'env -> trait_item_name -> unit = + fun _ _ -> () end (** Ancestor for map visitor for [ty] *) @@ -185,6 +203,18 @@ class ['self] map_ty_base = inherit! [_] T.map_const_generic inherit! [_] PV.map_literal_type method visit_type_var_id : 'env -> type_var_id -> type_var_id = fun _ x -> x + + method visit_trait_decl_id : 'env -> trait_decl_id -> trait_decl_id = + fun _ x -> x + + method visit_trait_impl_id : 'env -> trait_impl_id -> trait_impl_id = + fun _ x -> x + + method visit_trait_clause_id : 'env -> trait_clause_id -> trait_clause_id = + fun _ x -> x + + method visit_trait_item_name : 'env -> trait_item_name -> trait_item_name = + fun _ x -> x end (** Ancestor for reduce visitor for [ty] *) @@ -194,6 +224,18 @@ class virtual ['self] reduce_ty_base = inherit! [_] T.reduce_const_generic inherit! [_] PV.reduce_literal_type method visit_type_var_id : 'env -> type_var_id -> 'a = fun _ _ -> self#zero + + method visit_trait_decl_id : 'env -> trait_decl_id -> 'a = + fun _ _ -> self#zero + + method visit_trait_impl_id : 'env -> trait_impl_id -> 'a = + fun _ _ -> self#zero + + method visit_trait_clause_id : 'env -> trait_clause_id -> 'a = + fun _ _ -> self#zero + + method visit_trait_item_name : 'env -> trait_item_name -> 'a = + fun _ _ -> self#zero end (** Ancestor for mapreduce visitor for [ty] *) @@ -205,10 +247,24 @@ class virtual ['self] mapreduce_ty_base = method visit_type_var_id : 'env -> type_var_id -> type_var_id * 'a = fun _ x -> (x, self#zero) + + method visit_trait_decl_id : 'env -> trait_decl_id -> trait_decl_id * 'a = + fun _ x -> (x, self#zero) + + method visit_trait_impl_id : 'env -> trait_impl_id -> trait_impl_id * 'a = + fun _ x -> (x, self#zero) + + method visit_trait_clause_id + : 'env -> trait_clause_id -> trait_clause_id * 'a = + fun _ x -> (x, self#zero) + + method visit_trait_item_name + : 'env -> trait_item_name -> trait_item_name * 'a = + fun _ x -> (x, self#zero) end type ty = - | Adt of type_id * ty list * const_generic list + | Adt of type_id * generic_args (** {!Adt} encodes ADTs and tuples and assumed types. TODO: what about the ended regions? (ADTs may be parameterized @@ -219,8 +275,38 @@ type ty = | TypeVar of type_var_id | Literal of literal_type | Arrow of ty * ty + | TraitType of trait_ref * generic_args * string + (** The string is for the name of the associated type *) + +and trait_ref = { + trait_id : trait_instance_id; + generics : generic_args; + trait_decl_ref : trait_decl_ref; +} + +and trait_decl_ref = { + trait_decl_id : trait_decl_id; + decl_generics : generic_args; (* The name: annoying field collisions... *) +} + +and generic_args = { + types : ty list; + const_generics : const_generic list; + trait_refs : trait_ref list; +} + +and trait_instance_id = + | Self + | TraitImpl of trait_impl_id + | Clause of trait_clause_id + | ParentClause of trait_instance_id * trait_decl_id * trait_clause_id + | ItemClause of + trait_instance_id * trait_decl_id * trait_item_name * trait_clause_id + | TraitRef of trait_ref + | UnknownTrait of string [@@deriving show, + ord, visitors { name = "iter_ty"; @@ -264,12 +350,37 @@ type type_decl_kind = Struct of field list | Enum of variant list | Opaque type type_var = T.type_var [@@deriving show] +type trait_clause = { + clause_id : trait_clause_id; + trait_id : trait_decl_id; + generics : generic_args; +} +[@@deriving show] + +type generic_params = { + types : type_var list; + const_generics : const_generic_var list; + trait_clauses : trait_clause list; +} +[@@deriving show] + +type trait_type_constraint = { + trait_ref : trait_ref; + generics : generic_args; + type_name : trait_item_name; + ty : ty; +} +[@@deriving show, ord] + +type predicates = { trait_type_constraints : trait_type_constraint list } +[@@deriving show] + type type_decl = { def_id : TypeDeclId.id; name : name; - type_params : type_var list; - const_generic_params : const_generic_var list; + generics : generic_params; kind : type_decl_kind; + preds : predicates; } [@@deriving show] @@ -420,8 +531,15 @@ type pure_assumed_fun_id = | FuelEqZero (** Test if some fuel is equal to 0 - TODO: ugly *) [@@deriving show, ord] +type fun_id_or_trait_method_ref = + | FunId of A.fun_id + | TraitMethod of trait_ref * string * fun_decl_id + (** The fun decl id is not really needed and here for convenience purposes *) +[@@deriving show, ord] + (** A function id for a non-assumed function *) -type regular_fun_id = A.fun_id * LoopId.id option * T.RegionGroupId.id option +type regular_fun_id = + fun_id_or_trait_method_ref * LoopId.id option * T.RegionGroupId.id option [@@deriving show, ord] (** A function identifier *) @@ -457,23 +575,20 @@ type projection = { adt_id : type_id; field_id : FieldId.id } [@@deriving show] type qualif_id = | FunOrOp of fun_or_op_id (** A function or an operation *) - | Global of GlobalDeclId.id + | Global of global_decl_id | AdtCons of adt_cons_id (** A function or ADT constructor identifier *) | Proj of projection (** Field projector *) + | TraitConst of trait_ref * generic_args * string + (** A trait associated constant *) [@@deriving show] -(** An instantiated qualified. +(** An instantiated qualifier. Note that for now we have a clear separation between types and expressions, - which explains why we have the [type_params] field: a function or ADT + which explains why we have the [generics] field: a function or ADT constructor is always fully instantiated. *) -type qualif = { - id : qualif_id; - type_args : ty list; - const_generic_args : const_generic list; -} -[@@deriving show] +type qualif = { id : qualif_id; generics : generic_args } [@@deriving show] type field_id = FieldId.id [@@deriving show, ord] type var_id = VarId.id [@@deriving show, ord] @@ -536,6 +651,7 @@ class virtual ['self] mapreduce_expression_base = *) type expression = | Var of var_id (** a variable *) + | CVar of const_generic_var_id (** a const generic var *) | Const of literal | App of texpression * texpression (** Application of a function to an argument. @@ -787,11 +903,11 @@ type fun_sig_info = { - etc. *) type fun_sig = { - type_params : type_var list; - const_generic_params : const_generic_var list; + generics : generic_params; (** TODO: we should analyse the signature to make the type parameters implicit whenever possible *) + preds : predicates; inputs : ty list; - (** The input types. + (** The types of the inputs. Note that those input types take into account the [fuel] parameter, if the function uses fuel for termination, and the [state] parameter, @@ -861,8 +977,11 @@ type fun_body = { } [@@deriving show] +type fun_kind = A.fun_kind [@@deriving show] + type fun_decl = { def_id : FunDeclId.id; + kind : fun_kind; num_loops : int; (** The number of loops in the parent forward function (basically the number of loops appearing in the original Rust functions, unless some loops are @@ -882,3 +1001,29 @@ type fun_decl = { body : fun_body option; } [@@deriving show] + +type trait_decl = { + def_id : trait_decl_id; + name : name; + generics : generic_params; + preds : predicates; + all_trait_clauses : trait_clause list; + consts : (trait_item_name * (ty * global_decl_id option)) list; + types : (trait_item_name * (trait_clause list * ty option)) list; + required_methods : (trait_item_name * fun_decl_id) list; + provided_methods : (trait_item_name * fun_decl_id option) list; +} +[@@deriving show] + +type trait_impl = { + def_id : trait_impl_id; + name : name; + impl_trait : trait_decl_ref; + generics : generic_params; + preds : predicates; + consts : (trait_item_name * (ty * global_decl_id)) list; + types : (trait_item_name * (trait_ref list * ty)) list; + required_methods : (trait_item_name * fun_decl_id) list; + provided_methods : (trait_item_name * fun_decl_id) list; +} +[@@deriving show] diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index b6025df4..cedc3559 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -376,8 +376,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl = let ty = e.ty in let ctx, e = match e.e with - | Var _ -> (* Nothing to do *) (ctx, e.e) - | Const _ -> (* Nothing to do *) (ctx, e.e) + | Var _ | CVar _ | Const _ -> (* Nothing to do *) (ctx, e.e) | App (app, arg) -> let ctx, app = update_texpression app ctx in let ctx, arg = update_texpression arg ctx in @@ -584,13 +583,10 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = | Qualif { id = AdtCons { adt_id = AdtId adt_id; variant_id = None }; - type_args = _; - const_generic_args = _; + generics = _; } -> (* Lookup the def *) - let decl = - TypeDeclId.Map.find adt_id ctx.type_context.type_decls - in + let decl = TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls in (* Check that there are as many arguments as there are fields - note that the def should have a body (otherwise we couldn't use the constructor) *) @@ -599,8 +595,7 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = (* Check if the definition is recursive *) let is_rec = match - TypeDeclId.Map.find adt_id - ctx.type_context.type_decls_groups + TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls_groups with | NonRec _ -> false | Rec _ -> true @@ -682,8 +677,8 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) | _ -> false in (* And either: - * 2.1 the right-expression is a variable or a global *) - let var_or_global = is_var re || is_global re in + * 2.1 the right-expression is a variable, a global or a const generic var *) + let var_or_global = is_var re || is_cvar re || is_global re in (* Or: * 2.2 the right-expression is a constant value, an ADT value, * a projection or a primitive function call *and* the flag @@ -767,10 +762,10 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) In this situation, we can remove the call [f@fwd x]. *) let expression_contains_child_call_in_all_paths (ctx : trans_ctx) - (id0 : A.fun_id) (lp_id0 : LoopId.id option) - (rg_id0 : T.RegionGroupId.id option) (tys0 : ty list) + (id0 : fun_id_or_trait_method_ref) (lp_id0 : LoopId.id option) + (rg_id0 : T.RegionGroupId.id option) (generics0 : generic_args) (args0 : texpression list) (e : texpression) : bool = - let check_call (fun_id1 : fun_or_op_id) (tys1 : ty list) + let check_call (fun_id1 : fun_or_op_id) (generics1 : generic_args) (args1 : texpression list) : bool = (* Check the fun_ids, to see if call1's function is a child of call0's function *) match fun_id1 with @@ -793,7 +788,12 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (* We need to use the regions hierarchy *) (* First, lookup the signature of the LLBC function *) let sg = - LlbcAstUtils.lookup_fun_sig id0 ctx.fun_context.fun_decls + let id0 = + match id0 with + | FunId fun_id -> fun_id + | TraitMethod (_, _, fun_decl_id) -> A.Regular fun_decl_id + in + LlbcAstUtils.lookup_fun_sig id0 ctx.fun_ctx.fun_decls in (* Compute the set of ancestors of the function in call1 *) let call1_ancestors = @@ -817,8 +817,8 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) let input_eq (v0, v1) = PureUtils.remove_meta v0 = PureUtils.remove_meta v1 in - (* Compare the input types and the prefix of the input arguments *) - tys0 = tys1 && List.for_all input_eq args + (* Compare the generics and the prefix of the input arguments *) + generics0 = generics1 && List.for_all input_eq args else (* Not a child *) false else (* Not the same function *) @@ -834,7 +834,7 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) method! visit_texpression env e = match e.e with - | Var _ | Const _ -> fun _ -> false + | Var _ | CVar _ | Const _ -> fun _ -> false | StructUpdate _ -> (* There shouldn't be monadic calls in structure updates - also note that by returning [false] we are conservative: we might @@ -844,8 +844,8 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) | Let (_, _, re, e) -> ( match opt_destruct_function_call re with | None -> fun () -> self#visit_texpression env e () - | Some (func1, tys1, args1) -> - let call_is_child = check_call func1 tys1 args1 in + | Some (func1, generics1, args1) -> + let call_is_child = check_call func1 generics1 args1 in if call_is_child then fun () -> true else fun () -> self#visit_texpression env e ()) | App _ -> ( @@ -930,7 +930,7 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) method! visit_expression env e = match e with - | Var _ | Const _ | App _ | Qualif _ + | Var _ | CVar _ | Const _ | App _ | Qualif _ | Switch (_, _) | Meta (_, _) | StructUpdate _ | Abs _ -> @@ -1086,13 +1086,12 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = | Qualif { id = AdtCons { adt_id = AdtId adt_id; variant_id = None }; - type_args; - const_generic_args; + generics; } -> (* This is a struct *) (* Retrieve the definiton, to find how many fields there are *) let adt_decl = - TypeDeclId.Map.find adt_id ctx.type_context.type_decls + TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls in let fields = match adt_decl.kind with @@ -1108,7 +1107,7 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = * [x.field] for some variable [x], and where the projection * is for the proper ADT *) let to_var_proj (i : int) (arg : texpression) : - (ty list * const_generic list * var_id) option = + (generic_args * var_id) option = match arg.e with | App (proj, x) -> ( match (proj.e, x.e) with @@ -1116,16 +1115,14 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = { id = Proj { adt_id = AdtId proj_adt_id; field_id }; - type_args = proj_type_args; - const_generic_args = proj_const_generic_args; + generics = proj_generics; }, Var v ) -> (* We check that this is the proper ADT, and the proper field *) if proj_adt_id = adt_id && FieldId.to_int field_id = i - then - Some (proj_type_args, proj_const_generic_args, v) + then Some (proj_generics, v) else None | _ -> None) | _ -> None @@ -1136,14 +1133,13 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = if List.length args = num_fields then (* Check that this is the same variable we project from - * note that we checked above that there is at least one field *) - let (_, _, x), end_args = Collections.List.pop args in - if List.for_all (fun (_, _, y) -> y = x) end_args then ( + let (_, x), end_args = Collections.List.pop args in + if List.for_all (fun (_, y) -> y = x) end_args then ( (* We can substitute *) (* Sanity check: all types correct *) assert ( List.for_all - (fun (tys, cgs, _) -> - tys = type_args && cgs = const_generic_args) + (fun (generics1, _) -> generics1 = generics) args); { e with e = Var x }) else super#visit_texpression env e @@ -1162,8 +1158,7 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = | ( Qualif { id = Proj { adt_id = AdtId proj_adt_id; field_id }; - type_args = _; - const_generic_args = _; + generics = _; }, Var v ) -> (* We check that this is the proper ADT, and the proper field *) @@ -1361,8 +1356,8 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = let loop_sig = { - type_params = fun_sig.type_params; - const_generic_params = fun_sig.const_generic_params; + generics = fun_sig.generics; + preds = fun_sig.preds; inputs = inputs_tys; output; doutputs; @@ -1427,6 +1422,7 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = let loop_def = { def_id = def.def_id; + kind = def.kind; num_loops; loop_id = Some loop.loop_id; back_id = def.back_id; @@ -1466,13 +1462,12 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = In such situation, we can remove the forward function definition altogether. *) -let keep_forward (trans : pure_fun_translation) : bool = - let (fwd, _), backs = trans in +let keep_forward (fwd : fun_and_loops) (backs : fun_and_loops list) : bool = (* Note that at this point, the output types are no longer seen as tuples: * they should be lists of length 1. *) if !Config.filter_useless_functions - && fwd.signature.output = mk_result_ty mk_unit_ty + && fwd.f.signature.output = mk_result_ty mk_unit_ty && backs <> [] then false else true @@ -1528,7 +1523,7 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = match opt_destruct_function_call e with | Some (fun_id, _tys, args) -> ( match fun_id with - | Fun (FromLlbc (A.Assumed aid, _lp_id, rg_id)) -> ( + | Fun (FromLlbc (FunId (A.Assumed aid), _lp_id, rg_id)) -> ( (* Below, when dealing with the arguments: we consider the very * general case, where functions could be boxed (meaning we * could have: [box_new f x]) @@ -1568,7 +1563,7 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = | ArraySubsliceMut | SliceIndexShared | SliceIndexMut | SliceSubsliceShared | SliceSubsliceMut | ArrayIndexShared | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut - | SliceLen ), + | ArrayRepeat | SliceLen ), _ ) -> super#visit_texpression env e) | _ -> super#visit_texpression env e) @@ -1914,7 +1909,7 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = [ctx]: used only for printing. *) let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : - (fun_decl * fun_decl list) option = + fun_and_loops option = (* Debug *) log#ldebug (lazy @@ -1955,9 +1950,9 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : let def, loops = decompose_loops def in (* Apply the remaining passes *) - let def = apply_end_passes_to_def ctx def in + let f = apply_end_passes_to_def ctx def in let loops = List.map (apply_end_passes_to_def ctx) loops in - Some (def, loops) + Some { f; loops } (** Small utility for {!filter_loop_inputs} *) let filter_prefix (keep : bool list) (ls : 'a list) : 'a list = @@ -1983,8 +1978,8 @@ end module FunLoopIdMap = Collections.MakeMap (FunLoopIdOrderedType) (** Filter the useless loop input parameters. *) -let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : - (bool * pure_fun_translation) list = +let filter_loop_inputs (transl : pure_fun_translation list) : + pure_fun_translation list = (* We need to explore groups of mutually recursive functions. In order to compute which parameters are useless, we need to explore the functions by groups of mutually recursive definitions. @@ -2002,10 +1997,11 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : (List.concat (List.concat (List.map - (fun (_, ((fwd, loops_fwd), backs)) -> - [ fwd :: loops_fwd ] + (fun { fwd; backs; _ } -> + [ fwd.f :: fwd.loops ] :: List.map - (fun (back, loops_back) -> [ back :: loops_back ]) + (fun { f = back; loops = loops_back } -> + [ back :: loops_back ]) backs) transl))) in @@ -2030,7 +2026,6 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : additional parameters. *) let used_map = ref FunLoopIdMap.empty in - let fun_id_to_fun_loop_id (fid, loop_id, _) = (fid, loop_id) in (* We start by computing the filtering information, for each function *) let compute_one_filter_info (decl : fun_decl) = @@ -2075,8 +2070,8 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : match e_app.e with | Qualif qualif -> ( match qualif.id with - | FunOrOp (Fun (FromLlbc fun_id')) -> - if fun_id_to_fun_loop_id fun_id' = fun_id then ( + | FunOrOp (Fun (FromLlbc (FunId fun_id', loop_id', _))) -> + if (fun_id', loop_id') = fun_id then ( (* For each argument, check if it is exactly the original input parameter. Note that there shouldn't be partial applications of loop functions: the number of arguments @@ -2135,22 +2130,15 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : (* We then apply the filtering to all the function definitions at once *) let filter_in_one (decl : fun_decl) : fun_decl = (* Filter the function signature *) - let fun_id = (A.Regular decl.def_id, decl.loop_id, decl.back_id) in + let fun_id = (A.Regular decl.def_id, decl.loop_id) in let decl = - match FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map with + match FunLoopIdMap.find_opt fun_id !used_map with | None -> (* Nothing to filter *) decl | Some used_info -> let num_filtered = List.length (List.filter (fun b -> not b) used_info) in - let { - type_params; - const_generic_params; - inputs; - output; - doutputs; - info; - } = + let { generics; preds; inputs; output; doutputs; info } = decl.signature in let { @@ -2178,16 +2166,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : effect_info; } in - let signature = - { - type_params; - const_generic_params; - inputs; - output; - doutputs; - info; - } - in + let signature = { generics; preds; inputs; output; doutputs; info } in { decl with signature } in @@ -2201,9 +2180,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : let { inputs; inputs_lvs; body } = body in let inputs, inputs_lvs = - match - FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map - with + match FunLoopIdMap.find_opt fun_id !used_map with | None -> (* Nothing to filter *) (inputs, inputs_lvs) | Some used_info -> let inputs = filter_prefix used_info inputs in @@ -2223,11 +2200,10 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : match e_app.e with | Qualif qualif -> ( match qualif.id with - | FunOrOp (Fun (FromLlbc fun_id)) -> ( + | FunOrOp (Fun (FromLlbc (FunId fun_id, loop_id, _))) + -> ( match - FunLoopIdMap.find_opt - (fun_id_to_fun_loop_id fun_id) - !used_map + FunLoopIdMap.find_opt (fun_id, loop_id) !used_map with | None -> super#visit_texpression env e | Some used_info -> @@ -2267,13 +2243,13 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : in let transl = List.map - (fun (b, (fwd, backs)) -> - let filter_fun_and_loops (f, fl) = - (filter_in_one f, List.map filter_in_one fl) + (fun trans -> + let filter_fun_and_loops f = + { f = filter_in_one f.f; loops = List.map filter_in_one f.loops } in - let fwd = filter_fun_and_loops fwd in - let backs = List.map filter_fun_and_loops backs in - (b, (fwd, backs))) + let fwd = filter_fun_and_loops trans.fwd in + let backs = List.map filter_fun_and_loops trans.backs in + { trans with fwd; backs }) transl in @@ -2294,18 +2270,17 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : but convenient. *) let apply_passes_to_pure_fun_translations (ctx : trans_ctx) - (transl : (fun_decl * fun_decl list) list) : - (bool * pure_fun_translation) list = - let apply_to_one (trans : fun_decl * fun_decl list) : - bool * pure_fun_translation = + (transl : (fun_decl * fun_decl list) list) : pure_fun_translation list = + let apply_to_one (trans : fun_decl * fun_decl list) : pure_fun_translation = (* Apply the passes to the individual functions *) - let forward, backwards = trans in - let forward = Option.get (apply_passes_to_def ctx forward) in - let backwards = List.filter_map (apply_passes_to_def ctx) backwards in - let trans = (forward, backwards) in + let fwd, backs = trans in + let fwd = Option.get (apply_passes_to_def ctx fwd) in + let backs = List.filter_map (apply_passes_to_def ctx) backs in (* Compute whether we need to filter the forward function or not *) - (keep_forward trans, trans) + let keep_fwd = keep_forward fwd backs in + { keep_fwd; fwd; backs } in + let transl = List.map apply_to_one transl in (* Filter the useless inputs in the loop functions *) diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index 8d28bb8a..b80ff72f 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -9,17 +9,19 @@ open PureUtils of fields is fixed: it shouldn't be used for arrays, slices, etc. *) let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) - (type_id : type_id) (variant_id : VariantId.id option) (tys : ty list) - (cgs : const_generic list) : ty list = + (type_id : type_id) (variant_id : VariantId.id option) + (generics : generic_args) : ty list = match type_id with | Tuple -> (* Tuple *) + assert (generics.const_generics = []); + assert (generics.trait_refs = []); assert (variant_id = None); - tys + generics.types | AdtId def_id -> (* "Regular" ADT *) let def = TypeDeclId.Map.find def_id type_decls in - type_decl_get_instantiated_fields_types def variant_id tys cgs + type_decl_get_instantiated_fields_types def variant_id generics | Assumed aty -> ( (* Assumed type *) match aty with @@ -27,14 +29,14 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) (* This type is opaque *) raise (Failure "Unreachable: opaque type") | Result -> - let ty = Collections.List.to_cons_nil tys in + let ty = Collections.List.to_cons_nil generics.types in let variant_id = Option.get variant_id in if variant_id = result_return_id then [ ty ] else if variant_id = result_fail_id then [ mk_error_ty ] else raise (Failure "Unreachable: improper variant id for result type") | Error -> - assert (tys = []); + assert (generics = empty_generic_args); let variant_id = Option.get variant_id in assert ( variant_id = error_failure_id || variant_id = error_out_of_fuel_id); @@ -45,14 +47,14 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) else if variant_id = fuel_succ_id then [ mk_fuel_ty ] else raise (Failure "Unreachable: improper variant id for fuel type") | Option -> - let ty = Collections.List.to_cons_nil tys in + let ty = Collections.List.to_cons_nil generics.types in let variant_id = Option.get variant_id in if variant_id = option_some_id then [ ty ] else if variant_id = option_none_id then [] else raise (Failure "Unreachable: improper variant id for option type") | Range -> - let ty = Collections.List.to_cons_nil tys in + let ty = Collections.List.to_cons_nil generics.types in assert (variant_id = None); [ ty; ty ] | Vec | Array | Slice | Str -> @@ -65,6 +67,9 @@ type tc_ctx = { global_decls : A.global_decl A.GlobalDeclId.Map.t; (** The global declarations *) env : ty VarId.Map.t; (** Environment from variables to types *) + const_generics : ty T.ConstGenericVarId.Map.t; + (** The types of the const generics *) + (* TODO: add trait type constraints *) } let check_literal (v : literal) (ty : literal_type) : unit = @@ -86,12 +91,13 @@ let rec check_typed_pattern (ctx : tc_ctx) (v : typed_pattern) : tc_ctx = { ctx with env } | PatAdt av -> (* Compute the field types *) - let type_id, tys, cgs = ty_as_adt v.ty in + let type_id, generics = ty_as_adt v.ty in let field_tys = - get_adt_field_types ctx.type_decls type_id av.variant_id tys cgs + get_adt_field_types ctx.type_decls type_id av.variant_id generics in let check_value (ctx : tc_ctx) (ty : ty) (v : typed_pattern) : tc_ctx = if ty <> v.ty then ( + (* TODO: we need to normalize the types *) log#serror ("check_typed_pattern: not the same types:" ^ "\n- ty: " ^ show_ty ty ^ "\n- v.ty: " ^ show_ty v.ty); @@ -115,6 +121,9 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = match VarId.Map.find_opt var_id ctx.env with | None -> () | Some ty -> assert (ty = e.ty)) + | CVar cg_id -> + let ty = T.ConstGenericVarId.Map.find cg_id ctx.const_generics in + assert (ty = e.ty) | Const cv -> check_literal cv (ty_as_literal e.ty) | App (app, arg) -> let input_ty, output_ty = destruct_arrow app.ty in @@ -133,35 +142,34 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = match qualif.id with | FunOrOp _ -> () (* TODO *) | Global _ -> () (* TODO *) + | TraitConst _ -> () (* TODO *) | Proj { adt_id = proj_adt_id; field_id } -> (* Note we can only project fields of structures (not enumerations) *) (* Deconstruct the projector type *) let adt_ty, field_ty = destruct_arrow e.ty in - let adt_id, adt_type_args, adt_cg_args = ty_as_adt adt_ty in + let adt_id, adt_generics = ty_as_adt adt_ty in (* Check the ADT type *) assert (adt_id = proj_adt_id); - assert (adt_type_args = qualif.type_args); - assert (adt_cg_args = qualif.const_generic_args); + assert (adt_generics = qualif.generics); (* Retrieve and check the expected field type *) let variant_id = None in let expected_field_tys = get_adt_field_types ctx.type_decls proj_adt_id variant_id - qualif.type_args qualif.const_generic_args + qualif.generics in let expected_field_ty = FieldId.nth expected_field_tys field_id in assert (expected_field_ty = field_ty) | AdtCons id -> ( let expected_field_tys = get_adt_field_types ctx.type_decls id.adt_id id.variant_id - qualif.type_args qualif.const_generic_args + qualif.generics in let field_tys, adt_ty = destruct_arrows e.ty in assert (expected_field_tys = field_tys); match adt_ty with - | Adt (type_id, tys, cgs) -> + | Adt (type_id, generics) -> assert (type_id = id.adt_id); - assert (tys = qualif.type_args); - assert (cgs = qualif.const_generic_args) + assert (generics = qualif.generics) | _ -> raise (Failure "Unreachable"))) | Let (monadic, pat, re, e_next) -> let expected_pat_ty = if monadic then destruct_result re.ty else re.ty in @@ -207,15 +215,14 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = | Some ty -> assert (ty = e.ty)); (* Check the fields *) (* Retrieve and check the expected field type *) - let adt_id, adt_type_args, adt_cg_args = ty_as_adt e.ty in + let adt_id, adt_generics = ty_as_adt e.ty in assert (adt_id = supd.struct_id); (* The id can only be: a custom type decl or an array *) match adt_id with | AdtId _ -> let variant_id = None in let expected_field_tys = - get_adt_field_types ctx.type_decls adt_id variant_id adt_type_args - adt_cg_args + get_adt_field_types ctx.type_decls adt_id variant_id adt_generics in List.iter (fun (fid, fe) -> @@ -224,7 +231,9 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = check_texpression ctx fe) supd.updates | Assumed Array -> - let expected_field_ty = Collections.List.to_cons_nil adt_type_args in + let expected_field_ty = + Collections.List.to_cons_nil adt_generics.types + in List.iter (fun (_, fe) -> assert (expected_field_ty = fe.ty); diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 1c8d8921..4e44f252 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -89,14 +89,31 @@ let mk_mplace (var_id : E.VarId.id) (name : string option) (projection : mprojection) : mplace = { var_id; name; projection } +let empty_generic_params : generic_params = + { types = []; const_generics = []; trait_clauses = [] } + +let empty_generic_args : generic_args = + { types = []; const_generics = []; trait_refs = [] } + +let mk_generic_args_from_types (types : ty list) : generic_args = + { types; const_generics = []; trait_refs = [] } + +type subst = { + ty_subst : TypeVarId.id -> ty; + cg_subst : ConstGenericVarId.id -> const_generic; + tr_subst : TraitClauseId.id -> trait_instance_id; + tr_self : trait_instance_id; +} + (** Type substitution *) -let ty_substitute (tsubst : TypeVarId.id -> ty) - (cgsubst : ConstGenericVarId.id -> const_generic) (ty : ty) : ty = +let ty_substitute (subst : subst) (ty : ty) : ty = let obj = object inherit [_] map_ty - method! visit_TypeVar _ var_id = tsubst var_id - method! visit_ConstGenericVar _ var_id = cgsubst var_id + method! visit_TypeVar _ var_id = subst.ty_subst var_id + method! visit_ConstGenericVar _ var_id = subst.cg_subst var_id + method! visit_Clause _ id = subst.tr_subst id + method! visit_Self _ = subst.tr_self end in obj#visit_ty () ty @@ -115,6 +132,18 @@ let make_const_generic_subst (vars : const_generic_var list) (cgs : const_generic list) : ConstGenericVarId.id -> const_generic = Substitute.make_const_generic_subst_from_vars vars cgs +let make_trait_subst (clauses : trait_clause list) (refs : trait_ref list) : + TraitClauseId.id -> trait_instance_id = + let clauses = List.map (fun x -> x.clause_id) clauses in + let refs = List.map (fun x -> TraitRef x) refs in + let ls = List.combine clauses refs in + let mp = + List.fold_left + (fun mp (k, v) -> TraitClauseId.Map.add k v mp) + TraitClauseId.Map.empty ls + in + fun id -> TraitClauseId.Map.find id mp + (** Retrieve the list of fields for the given variant of a {!type:Aeneas.Pure.type_decl}. Raises [Invalid_argument] if the arguments are incorrect. @@ -135,20 +164,27 @@ let type_decl_get_fields (def : type_decl) - def: " ^ show_type_decl def ^ "\n- opt_variant_id: " ^ opt_variant_id)) +let make_subst_from_generics (params : generic_params) (args : generic_args) + (tr_self : trait_instance_id) : subst = + let ty_subst = make_type_subst params.types args.types in + let cg_subst = + make_const_generic_subst params.const_generics args.const_generics + in + let tr_subst = make_trait_subst params.trait_clauses args.trait_refs in + { ty_subst; cg_subst; tr_subst; tr_self } + (** Instantiate the type variables for the chosen variant in an ADT definition, and return the list of the types of its fields *) let type_decl_get_instantiated_fields_types (def : type_decl) - (opt_variant_id : VariantId.id option) (types : ty list) - (cgs : const_generic list) : ty list = - let ty_subst = make_type_subst def.type_params types in - let cg_subst = make_const_generic_subst def.const_generic_params cgs in + (opt_variant_id : VariantId.id option) (generics : generic_args) : ty list = + (* There shouldn't be any reference to Self *) + let tr_self = UnknownTrait __FUNCTION__ in + let subst = make_subst_from_generics def.generics generics tr_self in let fields = type_decl_get_fields def opt_variant_id in - List.map (fun f -> ty_substitute ty_subst cg_subst f.field_ty) fields + List.map (fun f -> ty_substitute subst f.field_ty) fields -let fun_sig_substitute (tsubst : TypeVarId.id -> ty) - (cgsubst : ConstGenericVarId.id -> const_generic) (sg : fun_sig) : - inst_fun_sig = - let subst = ty_substitute tsubst cgsubst in +let fun_sig_substitute (subst : subst) (sg : fun_sig) : inst_fun_sig = + let subst = ty_substitute subst in let inputs = List.map subst sg.inputs in let output = subst sg.output in let doutputs = List.map subst sg.doutputs in @@ -164,7 +200,8 @@ let fun_sig_substitute (tsubst : TypeVarId.id -> ty) *) let rec let_group_requires_parentheses (e : texpression) : bool = match e.e with - | Var _ | Const _ | App _ | Abs _ | Qualif _ | StructUpdate _ -> false + | Var _ | CVar _ | Const _ | App _ | Abs _ | Qualif _ | StructUpdate _ -> + false | Let (monadic, _, _, next_e) -> if monadic then true else let_group_requires_parentheses next_e | Switch (_, _) -> false @@ -184,15 +221,18 @@ let is_var (e : texpression) : bool = let as_var (e : texpression) : VarId.id = match e.e with Var v -> v | _ -> raise (Failure "Unreachable") +let is_cvar (e : texpression) : bool = + match e.e with CVar _ -> true | _ -> false + let is_global (e : texpression) : bool = match e.e with Qualif { id = Global _; _ } -> true | _ -> false let is_const (e : texpression) : bool = match e.e with Const _ -> true | _ -> false -let ty_as_adt (ty : ty) : type_id * ty list * const_generic list = +let ty_as_adt (ty : ty) : type_id * generic_args = match ty with - | Adt (id, tys, cgs) -> (id, tys, cgs) + | Adt (id, generics) -> (id, generics) | _ -> raise (Failure "Unreachable") (** Remove the external occurrences of {!Meta} *) @@ -290,28 +330,30 @@ let destruct_qualif_app (e : texpression) : qualif * texpression list = (** Destruct an expression into a function call, if possible *) let opt_destruct_function_call (e : texpression) : - (fun_or_op_id * ty list * texpression list) option = + (fun_or_op_id * generic_args * texpression list) option = match opt_destruct_qualif_app e with | None -> None | Some (qualif, args) -> ( match qualif.id with - | FunOrOp fun_id -> Some (fun_id, qualif.type_args, args) + | FunOrOp fun_id -> Some (fun_id, qualif.generics, args) | _ -> None) let opt_destruct_result (ty : ty) : ty option = match ty with - | Adt (Assumed Result, tys, cgs) -> - assert (cgs = []); - Some (Collections.List.to_cons_nil tys) + | Adt (Assumed Result, generics) -> + assert (generics.const_generics = []); + assert (generics.trait_refs = []); + Some (Collections.List.to_cons_nil generics.types) | _ -> None let destruct_result (ty : ty) : ty = Option.get (opt_destruct_result ty) let opt_destruct_tuple (ty : ty) : ty list option = match ty with - | Adt (Tuple, tys, cgs) -> - assert (cgs = []); - Some tys + | Adt (Tuple, generics) -> + assert (generics.const_generics = []); + assert (generics.trait_refs = []); + Some generics.types | _ -> None let mk_abs (x : typed_pattern) (e : texpression) : texpression = @@ -383,14 +425,16 @@ let mk_switch (scrut : texpression) (sb : switch_body) : texpression = - if there is > one type: wrap them in a tuple *) let mk_simpl_tuple_ty (tys : ty list) : ty = - match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys, []) + match tys with + | [ ty ] -> ty + | _ -> Adt (Tuple, mk_generic_args_from_types tys) let mk_bool_ty : ty = Literal Bool -let mk_unit_ty : ty = Adt (Tuple, [], []) +let mk_unit_ty : ty = Adt (Tuple, empty_generic_args) let mk_unit_rvalue : texpression = let id = AdtCons { adt_id = Tuple; variant_id = None } in - let qualif = { id; type_args = []; const_generic_args = [] } in + let qualif = { id; generics = empty_generic_args } in let e = Qualif qualif in let ty = mk_unit_ty in { e; ty } @@ -430,7 +474,7 @@ let mk_simpl_tuple_pattern (vl : typed_pattern list) : typed_pattern = | [ v ] -> v | _ -> let tys = List.map (fun (v : typed_pattern) -> v.ty) vl in - let ty = Adt (Tuple, tys, []) in + let ty = Adt (Tuple, mk_generic_args_from_types tys) in let value = PatAdt { variant_id = None; field_values = vl } in { value; ty } @@ -441,11 +485,11 @@ let mk_simpl_tuple_texpression (vl : texpression list) : texpression = | _ -> (* Compute the types of the fields, and the type of the tuple constructor *) let tys = List.map (fun (v : texpression) -> v.ty) vl in - let ty = Adt (Tuple, tys, []) in + let ty = Adt (Tuple, mk_generic_args_from_types tys) in let ty = mk_arrows tys ty in (* Construct the tuple constructor qualifier *) let id = AdtCons { adt_id = Tuple; variant_id = None } in - let qualif = { id; type_args = tys; const_generic_args = [] } in + let qualif = { id; generics = mk_generic_args_from_types tys } in (* Put everything together *) let cons = { e = Qualif qualif; ty } in mk_apps cons vl @@ -463,32 +507,36 @@ let ty_as_integer (t : ty) : T.integer_type = let ty_as_literal (t : ty) : T.literal_type = match t with Literal ty -> ty | _ -> raise (Failure "Unreachable") -let mk_state_ty : ty = Adt (Assumed State, [], []) -let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ], []) -let mk_error_ty : ty = Adt (Assumed Error, [], []) -let mk_fuel_ty : ty = Adt (Assumed Fuel, [], []) +let mk_state_ty : ty = Adt (Assumed State, empty_generic_args) + +let mk_result_ty (ty : ty) : ty = + Adt (Assumed Result, mk_generic_args_from_types [ ty ]) + +let mk_error_ty : ty = Adt (Assumed Error, empty_generic_args) +let mk_fuel_ty : ty = Adt (Assumed Fuel, empty_generic_args) let mk_error (error : VariantId.id) : texpression = let ty = mk_error_ty in let id = AdtCons { adt_id = Assumed Error; variant_id = Some error } in - let qualif = { id; type_args = []; const_generic_args = [] } in + let qualif = { id; generics = empty_generic_args } in let e = Qualif qualif in { e; ty } let unwrap_result_ty (ty : ty) : ty = match ty with - | Adt (Assumed Result, [ ty ], cgs) -> - assert (cgs = []); + | Adt + (Assumed Result, { types = [ ty ]; const_generics = []; trait_refs = [] }) + -> ty | _ -> raise (Failure "not a result type") let mk_result_fail_texpression (error : texpression) (ty : ty) : texpression = let type_args = [ ty ] in - let ty = Adt (Assumed Result, type_args, []) in + let ty = Adt (Assumed Result, mk_generic_args_from_types type_args) in let id = AdtCons { adt_id = Assumed Result; variant_id = Some result_fail_id } in - let qualif = { id; type_args; const_generic_args = [] } in + let qualif = { id; generics = mk_generic_args_from_types type_args } in let cons_e = Qualif qualif in let cons_ty = mk_arrow error.ty ty in let cons = { e = cons_e; ty = cons_ty } in @@ -501,11 +549,11 @@ let mk_result_fail_texpression_with_error_id (error : VariantId.id) (ty : ty) : let mk_result_return_texpression (v : texpression) : texpression = let type_args = [ v.ty ] in - let ty = Adt (Assumed Result, type_args, []) in + let ty = Adt (Assumed Result, mk_generic_args_from_types type_args) in let id = AdtCons { adt_id = Assumed Result; variant_id = Some result_return_id } in - let qualif = { id; type_args; const_generic_args = [] } in + let qualif = { id; generics = mk_generic_args_from_types type_args } in let cons_e = Qualif qualif in let cons_ty = mk_arrow v.ty ty in let cons = { e = cons_e; ty = cons_ty } in @@ -514,7 +562,7 @@ let mk_result_return_texpression (v : texpression) : texpression = (** Create a [Fail err] pattern which captures the error *) let mk_result_fail_pattern (error_pat : pattern) (ty : ty) : typed_pattern = let error_pat : typed_pattern = { value = error_pat; ty = mk_error_ty } in - let ty = Adt (Assumed Result, [ ty ], []) in + let ty = Adt (Assumed Result, mk_generic_args_from_types [ ty ]) in let value = PatAdt { variant_id = Some result_fail_id; field_values = [ error_pat ] } in @@ -526,7 +574,7 @@ let mk_result_fail_pattern_ignore_error (ty : ty) : typed_pattern = mk_result_fail_pattern error_pat ty let mk_result_return_pattern (v : typed_pattern) : typed_pattern = - let ty = Adt (Assumed Result, [ v.ty ], []) in + let ty = Adt (Assumed Result, mk_generic_args_from_types [ v.ty ]) in let value = PatAdt { variant_id = Some result_return_id; field_values = [ v ] } in @@ -561,11 +609,11 @@ let rec typed_pattern_to_texpression (pat : typed_pattern) : texpression option let fields_values = List.map (fun e -> Option.get e) fields in (* Retrieve the type id and the type args from the pat type (simpler this way *) - let adt_id, type_args, const_generic_args = ty_as_adt pat.ty in + let adt_id, generics = ty_as_adt pat.ty in (* Create the constructor *) let qualif_id = AdtCons { adt_id; variant_id = av.variant_id } in - let qualif = { id = qualif_id; type_args; const_generic_args } in + let qualif = { id = qualif_id; generics } in let cons_e = Qualif qualif in let field_tys = List.map (fun (v : texpression) -> v.ty) fields_values @@ -577,3 +625,20 @@ let rec typed_pattern_to_texpression (pat : typed_pattern) : texpression option Some (mk_apps cons fields_values).e in match e_opt with None -> None | Some e -> Some { e; ty = pat.ty } + +type trait_decl_method_decl_id = { is_provided : bool; id : fun_decl_id } + +let trait_decl_get_method (trait_decl : trait_decl) (method_name : string) : + trait_decl_method_decl_id = + (* First look in the required methods *) + let method_id = + List.find_opt (fun (s, _) -> s = method_name) trait_decl.required_methods + in + match method_id with + | Some (_, id) -> { is_provided = false; id } + | None -> + (* Must be a provided method *) + let _, id = + List.find (fun (s, _) -> s = method_name) trait_decl.provided_methods + in + { is_provided = true; id = Option.get id } diff --git a/compiler/ReorderDecls.ml b/compiler/ReorderDecls.ml index fc4744bc..10b68da3 100644 --- a/compiler/ReorderDecls.ml +++ b/compiler/ReorderDecls.ml @@ -38,14 +38,16 @@ let compute_body_fun_deps (e : texpression) : FunIdSet.t = method! visit_qualif _ id = match id.id with - | FunOrOp (Unop _ | Binop _) | Global _ | AdtCons _ | Proj _ -> () + | FunOrOp (Unop _ | Binop _) + | Global _ | AdtCons _ | Proj _ | TraitConst _ -> + () | FunOrOp (Fun fid) -> ( match fid with | Pure _ -> () | FromLlbc (fid, lp_id, rg_id) -> ( match fid with - | Assumed _ -> () - | Regular fid -> + | FunId (Assumed _) -> () + | TraitMethod (_, _, fid) | FunId (Regular fid) -> let id = { def_id = fid; lp_id; rg_id } in ids := FunIdSet.add id !ids)) end diff --git a/compiler/Substitute.ml b/compiler/Substitute.ml index 38850243..b1680282 100644 --- a/compiler/Substitute.ml +++ b/compiler/Substitute.ml @@ -9,51 +9,67 @@ module E = Expressions module A = LlbcAst module C = Contexts -(** Substitute types variables and regions in a type. *) -let ty_substitute (rsubst : 'r1 -> 'r2) (tsubst : T.TypeVarId.id -> 'r2 T.ty) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : 'r1 T.ty) : - 'r2 T.ty = - let open T in - let visitor = - object - inherit [_] map_ty - method visit_'r _ r = rsubst r - method! visit_TypeVar _ id = tsubst id +type ('r1, 'r2) subst = { + r_subst : 'r1 -> 'r2; + ty_subst : T.TypeVarId.id -> 'r2 T.ty; + cg_subst : T.ConstGenericVarId.id -> T.const_generic; + (** Substitution from *local* trait clause to trait instance *) + tr_subst : T.TraitClauseId.id -> 'r2 T.trait_instance_id; + (** Substitution for the [Self] trait instance *) + tr_self : 'r2 T.trait_instance_id; +} + +let ty_substitute_visitor (subst : ('r1, 'r2) subst) = + object + inherit [_] T.map_ty + method visit_'r _ r = subst.r_subst r + method! visit_TypeVar _ id = subst.ty_subst id - method! visit_type_var_id _ _ = - (* We should never get here because we reimplemented [visit_TypeVar] *) - raise (Failure "Unexpected") + method! visit_type_var_id _ _ = + (* We should never get here because we reimplemented [visit_TypeVar] *) + raise (Failure "Unexpected") - method! visit_ConstGenericVar _ id = cgsubst id + method! visit_ConstGenericVar _ id = subst.cg_subst id - method! visit_const_generic_var_id _ _ = - (* We should never get here because we reimplemented [visit_Var] *) - raise (Failure "Unexpected") - end - in + method! visit_const_generic_var_id _ _ = + (* We should never get here because we reimplemented [visit_Var] *) + raise (Failure "Unexpected") - visitor#visit_ty () ty + method! visit_Clause _ id = subst.tr_subst id + method! visit_Self _ = subst.tr_self + end -let rty_substitute (rsubst : T.RegionId.id -> T.RegionId.id) - (tsubst : T.TypeVarId.id -> T.rty) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : T.rty) : T.rty = - let rsubst r = - match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid) - in - ty_substitute rsubst tsubst cgsubst ty +(** Substitute types variables and regions in a type. + + **IMPORTANT**: this doesn't normalize the types. + *) +let ty_substitute (subst : ('r1, 'r2) subst) (ty : 'r1 T.ty) : 'r2 T.ty = + let visitor = ty_substitute_visitor subst in + visitor#visit_ty () ty -let ety_substitute (tsubst : T.TypeVarId.id -> T.ety) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : T.ety) : T.ety = - let rsubst r = r in - ty_substitute rsubst tsubst cgsubst ty +(** **IMPORTANT**: this doesn't normalize the types. *) +let trait_ref_substitute (subst : ('r1, 'r2) subst) (tr : 'r1 T.trait_ref) : + 'r2 T.trait_ref = + let visitor = ty_substitute_visitor subst in + visitor#visit_trait_ref () tr + +(** **IMPORTANT**: this doesn't normalize the types. *) +let generic_args_substitute (subst : ('r1, 'r2) subst) (g : 'r1 T.generic_args) + : 'r2 T.generic_args = + let visitor = ty_substitute_visitor subst in + visitor#visit_generic_args () g + +let erase_regions_subst : ('r, T.erased_region) subst = + { + r_subst = (fun _ -> T.Erased); + ty_subst = (fun vid -> T.TypeVar vid); + cg_subst = (fun id -> T.ConstGenericVar id); + tr_subst = (fun id -> T.Clause id); + tr_self = T.Self; + } (** Convert an {!T.rty} to an {!T.ety} by erasing the region variables *) -let erase_regions (ty : T.rty) : T.ety = - ty_substitute - (fun _ -> T.Erased) - (fun vid -> T.TypeVar vid) - (fun id -> T.ConstGenericVar id) - ty +let erase_regions (ty : 'r T.ty) : T.ety = ty_substitute erase_regions_subst ty (** Generate fresh regions for region variables. @@ -78,18 +94,20 @@ let fresh_regions_with_substs (region_vars : T.region_var list) : (* Generate the substitution from region var id to region *) let rid_subst id = T.RegionVarId.Map.find id rid_map in (* Generate the substitution from region to region *) - let rsubst r = + let r_subst r = match r with T.Static -> T.Static | T.Var id -> T.Var (rid_subst id) in (* Return *) - (fresh_region_ids, rid_subst, rsubst) + (fresh_region_ids, rid_subst, r_subst) -(** Erase the regions in a type and substitute the type variables *) -let erase_regions_substitute_types (tsubst : T.TypeVarId.id -> T.ety) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) - (ty : 'r T.region T.ty) : T.ety = - let rsubst (_ : 'r T.region) : T.erased_region = T.Erased in - ty_substitute rsubst tsubst cgsubst ty +(** Erase the regions in a type and perform a substitution *) +let erase_regions_substitute_types (ty_subst : T.TypeVarId.id -> T.ety) + (cg_subst : T.ConstGenericVarId.id -> T.const_generic) + (tr_subst : T.TraitClauseId.id -> T.etrait_instance_id) + (tr_self : T.etrait_instance_id) (ty : 'r T.ty) : T.ety = + let r_subst (_ : 'r) : T.erased_region = T.Erased in + let subst = { r_subst; ty_subst; cg_subst; tr_subst; tr_self } in + ty_substitute subst ty (** Create a region substitution from a list of region variable ids and a list of regions (with which to substitute the region variable ids *) @@ -146,16 +164,81 @@ let make_const_generic_subst_from_vars (vars : T.const_generic_var list) (List.map (fun (x : T.const_generic_var) -> x.T.index) vars) cgs -(** Instantiate the type variables in an ADT definition, and return, for - every variant, the list of the types of its fields *) -let type_decl_get_instantiated_variants_fields_rtypes (def : T.type_decl) - (regions : T.RegionId.id T.region list) (types : T.rty list) - (cgs : T.const_generic list) : (T.VariantId.id option * T.rty list) list = - let r_subst = make_region_subst_from_vars def.T.region_params regions in - let ty_subst = make_type_subst_from_vars def.T.type_params types in +(** Create a trait substitution from a list of trait clause ids and a list of + trait refs *) +let make_trait_subst (clause_ids : T.TraitClauseId.id list) + (trs : 'r T.trait_ref list) : T.TraitClauseId.id -> 'r T.trait_instance_id = + let ls = List.combine clause_ids trs in + let mp = + List.fold_left + (fun mp (k, v) -> T.TraitClauseId.Map.add k (T.TraitRef v) mp) + T.TraitClauseId.Map.empty ls + in + fun id -> T.TraitClauseId.Map.find id mp + +let make_trait_subst_from_clauses (clauses : T.trait_clause list) + (trs : 'r T.trait_ref list) : T.TraitClauseId.id -> 'r T.trait_instance_id = + make_trait_subst + (List.map (fun (x : T.trait_clause) -> x.T.clause_id) clauses) + trs + +let make_subst_from_generics (params : T.generic_params) + (args : 'r T.region T.generic_args) + (tr_self : 'r T.region T.trait_instance_id) : + (T.region_var_id T.region, 'r T.region) subst = + let r_subst = make_region_subst_from_vars params.T.regions args.T.regions in + let ty_subst = make_type_subst_from_vars params.T.types args.T.types in let cg_subst = - make_const_generic_subst_from_vars def.T.const_generic_params cgs + make_const_generic_subst_from_vars params.T.const_generics + args.T.const_generics + in + let tr_subst = + make_trait_subst_from_clauses params.T.trait_clauses args.T.trait_refs + in + { r_subst; ty_subst; cg_subst; tr_subst; tr_self } + +let make_subst_from_generics_no_regions : + 'r. + T.generic_params -> + 'r T.generic_args -> + 'r T.trait_instance_id -> + (T.region_var_id T.region, 'r) subst = + fun params args tr_self -> + let r_subst _ = raise (Failure "Unexpected region") in + let ty_subst = make_type_subst_from_vars params.T.types args.T.types in + let cg_subst = + make_const_generic_subst_from_vars params.T.const_generics + args.T.const_generics + in + let tr_subst = + make_trait_subst_from_clauses params.T.trait_clauses args.T.trait_refs + in + { r_subst; ty_subst; cg_subst; tr_subst; tr_self } + +let make_esubst_from_generics (params : T.generic_params) + (generics : T.egeneric_args) (tr_self : T.etrait_instance_id) = + let r_subst _ = T.Erased in + let ty_subst = make_type_subst_from_vars params.types generics.T.types in + let cg_subst = + make_const_generic_subst_from_vars params.const_generics + generics.T.const_generics + in + let tr_subst = + make_trait_subst_from_clauses params.trait_clauses generics.T.trait_refs in + { r_subst; ty_subst; cg_subst; tr_subst; tr_self } + +(** Instantiate the type variables in an ADT definition, and return, for + every variant, the list of the types of its fields. + + **IMPORTANT**: this function doesn't normalize the types, you may want to + use the [AssociatedTypes] equivalent instead. +*) +let type_decl_get_instantiated_variants_fields_rtypes (def : T.type_decl) + (generics : T.rgeneric_args) : (T.VariantId.id option * T.rty list) list = + (* There shouldn't be any reference to Self *) + let tr_self = T.UnknownTrait __FUNCTION__ in + let subst = make_subst_from_generics def.T.generics generics tr_self in let (variants_fields : (T.VariantId.id option * T.field list) list) = match def.T.kind with | T.Enum variants -> @@ -171,191 +254,232 @@ let type_decl_get_instantiated_variants_fields_rtypes (def : T.type_decl) in List.map (fun (id, fields) -> - ( id, - List.map - (fun f -> ty_substitute r_subst ty_subst cg_subst f.T.field_ty) - fields )) + (id, List.map (fun f -> ty_substitute subst f.T.field_ty) fields)) variants_fields (** Instantiate the type variables in an ADT definition, and return the list - of types of the fields for the chosen variant *) + of types of the fields for the chosen variant. + + **IMPORTANT**: this function doesn't normalize the types, you may want to + use the [AssociatedTypes] equivalent instead. +*) let type_decl_get_instantiated_field_rtypes (def : T.type_decl) - (opt_variant_id : T.VariantId.id option) - (regions : T.RegionId.id T.region list) (types : T.rty list) - (cgs : T.const_generic list) : T.rty list = - let r_subst = make_region_subst_from_vars def.T.region_params regions in - let ty_subst = make_type_subst_from_vars def.T.type_params types in - let cg_subst = - make_const_generic_subst_from_vars def.T.const_generic_params cgs - in + (opt_variant_id : T.VariantId.id option) (generics : T.rgeneric_args) : + T.rty list = + (* For now, check that there are no clauses - otherwise we might need + to normalize the types *) + assert (def.generics.trait_clauses = []); + (* There shouldn't be any reference to Self *) + let tr_self = T.UnknownTrait __FUNCTION__ in + let subst = make_subst_from_generics def.T.generics generics tr_self in let fields = TU.type_decl_get_fields def opt_variant_id in - List.map - (fun f -> ty_substitute r_subst ty_subst cg_subst f.T.field_ty) - fields + List.map (fun f -> ty_substitute subst f.T.field_ty) fields (** Return the types of the properly instantiated ADT's variant, provided a - context *) + context. + + **IMPORTANT**: this function doesn't normalize the types, you may want to + use the [AssociatedTypes] equivalent instead. +*) let ctx_adt_get_instantiated_field_rtypes (ctx : C.eval_ctx) (def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option) - (regions : T.RegionId.id T.region list) (types : T.rty list) - (cgs : T.const_generic list) : T.rty list = + (generics : T.rgeneric_args) : T.rty list = let def = C.ctx_lookup_type_decl ctx def_id in - type_decl_get_instantiated_field_rtypes def opt_variant_id regions types cgs + type_decl_get_instantiated_field_rtypes def opt_variant_id generics (** Return the types of the properly instantiated ADT value (note that - here, ADT is understood in its broad meaning: ADT, assumed value or tuple) *) + here, ADT is understood in its broad meaning: ADT, assumed value or tuple). + + **IMPORTANT**: this function doesn't normalize the types, you may want to + use the [AssociatedTypes] equivalent instead. + *) let ctx_adt_value_get_instantiated_field_rtypes (ctx : C.eval_ctx) - (adt : V.adt_value) (id : T.type_id) - (region_params : T.RegionId.id T.region list) (type_params : T.rty list) - (cg_params : T.const_generic list) : T.rty list = + (adt : V.adt_value) (id : T.type_id) (generics : T.rgeneric_args) : + T.rty list = match id with | T.AdtId id -> (* Retrieve the types of the fields *) - ctx_adt_get_instantiated_field_rtypes ctx id adt.V.variant_id - region_params type_params cg_params + ctx_adt_get_instantiated_field_rtypes ctx id adt.V.variant_id generics | T.Tuple -> - assert (List.length region_params = 0); - type_params + assert (generics.regions = []); + generics.types | T.Assumed aty -> ( match aty with | T.Box | T.Vec -> - assert (List.length region_params = 0); - assert (List.length type_params = 1); - assert (List.length cg_params = 0); - type_params + assert (generics.regions = []); + assert (List.length generics.types = 1); + assert (generics.const_generics = []); + generics.types | T.Option -> - assert (List.length region_params = 0); - assert (List.length type_params = 1); - assert (List.length cg_params = 0); - if adt.V.variant_id = Some T.option_some_id then type_params + assert (generics.regions = []); + assert (List.length generics.types = 1); + assert (generics.const_generics = []); + if adt.V.variant_id = Some T.option_some_id then generics.types else if adt.V.variant_id = Some T.option_none_id then [] else raise (Failure "Unreachable") | T.Range -> - assert (List.length region_params = 0); - assert (List.length type_params = 1); - assert (List.length cg_params = 0); - type_params + assert (generics.regions = []); + assert (List.length generics.types = 1); + assert (generics.const_generics = []); + generics.types | T.Array | T.Slice | T.Str -> (* Those types don't have fields *) raise (Failure "Unreachable")) (** Instantiate the type variables in an ADT definition, and return the list - of types of the fields for the chosen variant *) + of types of the fields for the chosen variant. + + **IMPORTANT**: this function doesn't normalize the types, you may want to + use the [AssociatedTypes] equivalent instead. +*) let type_decl_get_instantiated_field_etypes (def : T.type_decl) - (opt_variant_id : T.VariantId.id option) (types : T.ety list) - (cgs : T.const_generic list) : T.ety list = - let ty_subst = make_type_subst_from_vars def.T.type_params types in - let cg_subst = - make_const_generic_subst_from_vars def.T.const_generic_params cgs + (opt_variant_id : T.VariantId.id option) (generics : T.egeneric_args) : + T.ety list = + (* For now, check that there are no clauses - otherwise we might need + to normalize the types *) + assert (def.generics.trait_clauses = []); + (* There shouldn't be any reference to Self *) + let tr_self : T.erased_region T.trait_instance_id = + T.UnknownTrait __FUNCTION__ + in + let { r_subst = _; ty_subst; cg_subst; tr_subst; tr_self } = + make_esubst_from_generics def.T.generics generics tr_self in let fields = TU.type_decl_get_fields def opt_variant_id in List.map - (fun f -> erase_regions_substitute_types ty_subst cg_subst f.T.field_ty) + (fun (f : T.field) -> + erase_regions_substitute_types ty_subst cg_subst tr_subst tr_self + f.T.field_ty) fields (** Return the types of the properly instantiated ADT's variant, provided a - context *) + context. + + **IMPORTANT**: this function doesn't normalize the types, you may want to + use the [AssociatedTypes] equivalent instead. + *) let ctx_adt_get_instantiated_field_etypes (ctx : C.eval_ctx) (def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option) - (types : T.ety list) (cgs : T.const_generic list) : T.ety list = + (generics : T.egeneric_args) : T.ety list = let def = C.ctx_lookup_type_decl ctx def_id in - type_decl_get_instantiated_field_etypes def opt_variant_id types cgs + type_decl_get_instantiated_field_etypes def opt_variant_id generics -let statement_substitute_visitor (tsubst : T.TypeVarId.id -> T.ety) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) = +let statement_substitute_visitor + (subst : (T.erased_region, T.erased_region) subst) = + (* Keep in synch with [ty_substitute_visitor] *) object inherit [_] A.map_statement - method! visit_ety _ ty = ety_substitute tsubst cgsubst ty - method! visit_ConstGenericVar _ id = cgsubst id + method! visit_'r _ r = subst.r_subst r + method! visit_TypeVar _ id = subst.ty_subst id + + method! visit_type_var_id _ _ = + (* We should never get here because we reimplemented [visit_TypeVar] *) + raise (Failure "Unexpected") + + method! visit_ConstGenericVar _ id = subst.cg_subst id method! visit_const_generic_var_id _ _ = (* We should never get here because we reimplemented [visit_Var] *) raise (Failure "Unexpected") + + method! visit_Clause _ id = subst.tr_subst id + method! visit_Self _ = subst.tr_self end (** Apply a type substitution to a place *) -let place_substitute (tsubst : T.TypeVarId.id -> T.ety) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (p : E.place) : - E.place = +let place_substitute (subst : (T.erased_region, T.erased_region) subst) + (p : E.place) : E.place = (* There is in fact nothing to do *) - (statement_substitute_visitor tsubst cgsubst)#visit_place () p + (statement_substitute_visitor subst)#visit_place () p (** Apply a type substitution to an operand *) -let operand_substitute (tsubst : T.TypeVarId.id -> T.ety) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (op : E.operand) : - E.operand = - (statement_substitute_visitor tsubst cgsubst)#visit_operand () op +let operand_substitute (subst : (T.erased_region, T.erased_region) subst) + (op : E.operand) : E.operand = + (statement_substitute_visitor subst)#visit_operand () op (** Apply a type substitution to an rvalue *) -let rvalue_substitute (tsubst : T.TypeVarId.id -> T.ety) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (rv : E.rvalue) : - E.rvalue = - (statement_substitute_visitor tsubst cgsubst)#visit_rvalue () rv +let rvalue_substitute (subst : (T.erased_region, T.erased_region) subst) + (rv : E.rvalue) : E.rvalue = + (statement_substitute_visitor subst)#visit_rvalue () rv (** Apply a type substitution to an assertion *) -let assertion_substitute (tsubst : T.TypeVarId.id -> T.ety) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (a : A.assertion) : - A.assertion = - (statement_substitute_visitor tsubst cgsubst)#visit_assertion () a +let assertion_substitute (subst : (T.erased_region, T.erased_region) subst) + (a : A.assertion) : A.assertion = + (statement_substitute_visitor subst)#visit_assertion () a (** Apply a type substitution to a call *) -let call_substitute (tsubst : T.TypeVarId.id -> T.ety) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (call : A.call) : - A.call = - (statement_substitute_visitor tsubst cgsubst)#visit_call () call +let call_substitute (subst : (T.erased_region, T.erased_region) subst) + (call : A.call) : A.call = + (statement_substitute_visitor subst)#visit_call () call (** Apply a type substitution to a statement *) -let statement_substitute (tsubst : T.TypeVarId.id -> T.ety) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (st : A.statement) : - A.statement = - (statement_substitute_visitor tsubst cgsubst)#visit_statement () st +let statement_substitute (subst : (T.erased_region, T.erased_region) subst) + (st : A.statement) : A.statement = + (statement_substitute_visitor subst)#visit_statement () st (** Apply a type substitution to a function body. Return the local variables and the body. *) -let fun_body_substitute_in_body (tsubst : T.TypeVarId.id -> T.ety) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (body : A.fun_body) : +let fun_body_substitute_in_body + (subst : (T.erased_region, T.erased_region) subst) (body : A.fun_body) : A.var list * A.statement = - let rsubst r = r in let locals = List.map - (fun (v : A.var) -> - { v with A.var_ty = ty_substitute rsubst tsubst cgsubst v.A.var_ty }) + (fun (v : A.var) -> { v with A.var_ty = ty_substitute subst v.A.var_ty }) body.A.locals in - let body = statement_substitute tsubst cgsubst body.body in + let body = statement_substitute subst body.body in (locals, body) -(** Substitute a function signature *) +let trait_type_constraint_substitute (subst : ('r1, 'r2) subst) + (ttc : 'r1 T.trait_type_constraint) : 'r2 T.trait_type_constraint = + let { T.trait_ref; generics; type_name; ty } = ttc in + let visitor = ty_substitute_visitor subst in + let trait_ref = visitor#visit_trait_ref () trait_ref in + let generics = visitor#visit_generic_args () generics in + let ty = visitor#visit_ty () ty in + { T.trait_ref; generics; type_name; ty } + +(** Substitute a function signature. + + **IMPORTANT:** this function doesn't normalize the types. + *) let substitute_signature (asubst : T.RegionGroupId.id -> V.AbstractionId.id) - (rsubst : T.RegionVarId.id -> T.RegionId.id) - (tsubst : T.TypeVarId.id -> T.rty) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (sg : A.fun_sig) : - A.inst_fun_sig = - let rsubst' (r : T.RegionVarId.id T.region) : T.RegionId.id T.region = - match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid) + (r_subst : T.RegionVarId.id -> T.RegionId.id) + (ty_subst : T.TypeVarId.id -> T.rty) + (cg_subst : T.ConstGenericVarId.id -> T.const_generic) + (tr_subst : T.TraitClauseId.id -> T.rtrait_instance_id) + (tr_self : T.rtrait_instance_id) (sg : A.fun_sig) : A.inst_fun_sig = + let r_subst' (r : T.RegionVarId.id T.region) : T.RegionId.id T.region = + match r with T.Static -> T.Static | T.Var rid -> T.Var (r_subst rid) in - let inputs = List.map (ty_substitute rsubst' tsubst cgsubst) sg.A.inputs in - let output = ty_substitute rsubst' tsubst cgsubst sg.A.output in + let subst = { r_subst = r_subst'; ty_subst; cg_subst; tr_subst; tr_self } in + let inputs = List.map (ty_substitute subst) sg.A.inputs in + let output = ty_substitute subst sg.A.output in let subst_region_group (rg : T.region_var_group) : A.abs_region_group = let id = asubst rg.id in - let regions = List.map rsubst rg.regions in + let regions = List.map r_subst rg.regions in let parents = List.map asubst rg.parents in { id; regions; parents } in let regions_hierarchy = List.map subst_region_group sg.A.regions_hierarchy in - { A.regions_hierarchy; inputs; output } + let trait_type_constraints = + List.map + (trait_type_constraint_substitute subst) + sg.preds.trait_type_constraints + in + { A.inputs; output; regions_hierarchy; trait_type_constraints } -(** Substitute type variable identifiers in a type *) -let ty_substitute_ids (tsubst : T.TypeVarId.id -> T.TypeVarId.id) - (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ty : 'r T.ty) +(** Substitute variable identifiers in a type *) +let ty_substitute_ids (ty_subst : T.TypeVarId.id -> T.TypeVarId.id) + (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ty : 'r T.ty) : 'r T.ty = let open T in let visitor = object inherit [_] map_ty method visit_'r _ r = r - method! visit_type_var_id _ id = tsubst id - method! visit_const_generic_var_id _ id = cgsubst id + method! visit_type_var_id _ id = ty_subst id + method! visit_const_generic_var_id _ id = cg_subst id end in @@ -371,10 +495,10 @@ let ty_substitute_ids (tsubst : T.TypeVarId.id -> T.TypeVarId.id) [visit_'r] if we define a class which visits objects of types [ety] and [rty] while inheriting a class which visit [ty]... *) -let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id) +let subst_ids_visitor (r_subst : T.RegionId.id -> T.RegionId.id) (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) - (tsubst : T.TypeVarId.id -> T.TypeVarId.id) - (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) + (ty_subst : T.TypeVarId.id -> T.TypeVarId.id) + (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) (bsubst : V.BorrowId.id -> V.BorrowId.id) (asubst : V.AbstractionId.id -> V.AbstractionId.id) = @@ -383,10 +507,10 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id) inherit [_] T.map_ty method visit_'r _ r = - match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid) + match r with T.Static -> T.Static | T.Var rid -> T.Var (r_subst rid) - method! visit_type_var_id _ id = tsubst id - method! visit_const_generic_var_id _ id = cgsubst id + method! visit_type_var_id _ id = ty_subst id + method! visit_const_generic_var_id _ id = cg_subst id end in @@ -395,7 +519,7 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id) inherit [_] C.map_env method! visit_borrow_id _ bid = bsubst bid method! visit_loan_id _ bid = bsubst bid - method! visit_ety _ ty = ty_substitute_ids tsubst cgsubst ty + method! visit_ety _ ty = ty_substitute_ids ty_subst cg_subst ty method! visit_rty env ty = subst_rty#visit_ty env ty method! visit_symbolic_value_id _ id = ssubst id @@ -405,7 +529,7 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id) (** We *do* visit meta-values *) method! visit_mvalue env v = self#visit_typed_value env v - method! visit_region_id _ id = rsubst id + method! visit_region_id _ id = r_subst id method! visit_region_var_id _ id = rvsubst id method! visit_abstraction_id _ id = asubst id end @@ -425,20 +549,20 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id) method visit_env (env : C.env) : C.env = visitor#visit_env () env end -let typed_value_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id) +let typed_value_subst_ids (r_subst : T.RegionId.id -> T.RegionId.id) (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) - (tsubst : T.TypeVarId.id -> T.TypeVarId.id) - (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) + (ty_subst : T.TypeVarId.id -> T.TypeVarId.id) + (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) (bsubst : V.BorrowId.id -> V.BorrowId.id) (v : V.typed_value) : V.typed_value = let asubst _ = raise (Failure "Unreachable") in - (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst) + (subst_ids_visitor r_subst rvsubst ty_subst cg_subst ssubst bsubst asubst) #visit_typed_value v -let typed_value_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) +let typed_value_subst_rids (r_subst : T.RegionId.id -> T.RegionId.id) (v : V.typed_value) : V.typed_value = - typed_value_subst_ids rsubst + typed_value_subst_ids r_subst (fun x -> x) (fun x -> x) (fun x -> x) @@ -446,41 +570,41 @@ let typed_value_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) (fun x -> x) v -let typed_avalue_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id) +let typed_avalue_subst_ids (r_subst : T.RegionId.id -> T.RegionId.id) (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) - (tsubst : T.TypeVarId.id -> T.TypeVarId.id) - (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) + (ty_subst : T.TypeVarId.id -> T.TypeVarId.id) + (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) (bsubst : V.BorrowId.id -> V.BorrowId.id) (v : V.typed_avalue) : V.typed_avalue = let asubst _ = raise (Failure "Unreachable") in - (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst) + (subst_ids_visitor r_subst rvsubst ty_subst cg_subst ssubst bsubst asubst) #visit_typed_avalue v -let abs_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id) +let abs_subst_ids (r_subst : T.RegionId.id -> T.RegionId.id) (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) - (tsubst : T.TypeVarId.id -> T.TypeVarId.id) - (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) + (ty_subst : T.TypeVarId.id -> T.TypeVarId.id) + (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) (bsubst : V.BorrowId.id -> V.BorrowId.id) (asubst : V.AbstractionId.id -> V.AbstractionId.id) (x : V.abs) : V.abs = - (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst) + (subst_ids_visitor r_subst rvsubst ty_subst cg_subst ssubst bsubst asubst) #visit_abs x -let env_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id) +let env_subst_ids (r_subst : T.RegionId.id -> T.RegionId.id) (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) - (tsubst : T.TypeVarId.id -> T.TypeVarId.id) - (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) + (ty_subst : T.TypeVarId.id -> T.TypeVarId.id) + (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) (bsubst : V.BorrowId.id -> V.BorrowId.id) (asubst : V.AbstractionId.id -> V.AbstractionId.id) (x : C.env) : C.env = - (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst) + (subst_ids_visitor r_subst rvsubst ty_subst cg_subst ssubst bsubst asubst) #visit_env x -let typed_avalue_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) +let typed_avalue_subst_rids (r_subst : T.RegionId.id -> T.RegionId.id) (x : V.typed_avalue) : V.typed_avalue = let asubst _ = raise (Failure "Unreachable") in - (subst_ids_visitor rsubst + (subst_ids_visitor r_subst (fun x -> x) (fun x -> x) (fun x -> x) @@ -490,9 +614,9 @@ let typed_avalue_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) #visit_typed_avalue x -let env_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) (x : C.env) : C.env - = - (subst_ids_visitor rsubst +let env_subst_rids (r_subst : T.RegionId.id -> T.RegionId.id) (x : C.env) : + C.env = + (subst_ids_visitor r_subst (fun x -> x) (fun x -> x) (fun x -> x) diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml index 7dc94dcd..4df8fec7 100644 --- a/compiler/SymbolicAst.ml +++ b/compiler/SymbolicAst.ml @@ -29,7 +29,7 @@ type mplace = { [@@deriving show] type call_id = - | Fun of A.fun_id * V.FunCallId.id + | Fun of A.fun_id_or_trait_method_ref * V.FunCallId.id (** A "regular" function (i.e., a function which is not a primitive operation) *) | Unop of E.unop | Binop of E.binop @@ -43,10 +43,7 @@ type call = { borrows (we need to perform lookups). *) abstractions : V.AbstractionId.id list; - (* TODO: rename to "...args" *) - type_params : T.ety list; - (* TODO: rename to "...args" *) - const_generic_params : T.const_generic list; + generics : T.egeneric_args; args : V.typed_value list; args_places : mplace option list; (** Meta information *) dest : V.symbolic_value; @@ -79,6 +76,9 @@ class ['self] iter_expression_base = method visit_loop_id : 'env -> V.loop_id -> unit = fun _ _ -> () method visit_variant_id : 'env -> variant_id -> unit = fun _ _ -> () + method visit_const_generic_var_id : 'env -> T.const_generic_var_id -> unit = + fun _ _ -> () + method visit_symbolic_value_id : 'env -> V.symbolic_value_id -> unit = fun _ _ -> () @@ -120,6 +120,9 @@ class ['self] iter_expression_base = method visit_symbolic_expansion : 'env -> V.symbolic_expansion -> unit = fun _ _ -> () + + method visit_etrait_ref : 'env -> T.etrait_ref -> unit = fun _ _ -> () + method visit_egeneric_args : 'env -> T.egeneric_args -> unit = fun _ _ -> () end (** **Rem.:** here, {!expression} is not at all equivalent to the expressions @@ -171,14 +174,15 @@ type expression = * expression (** We introduce a new symbolic value, equal to some other value. - This is used for instance when reorganizing the environment to compute - fixed points: we duplicate some shared symbolic values to destructure - the shared values, in order to make the environment a bit more general - (while losing precision of course). + This is used for instance when reorganizing the environment to compute + fixed points: we duplicate some shared symbolic values to destructure + the shared values, in order to make the environment a bit more general + (while losing precision of course). We also use it to introduce symbolic + values when evaluating constant generics, or trait constants. - The context is the evaluation context from before introducing the new - value. It has the same purpose as for the {!Return} case. - *) + The context is the evaluation context from before introducing the new + value. It has the same purpose as for the {!Return} case. + *) | ForwardEnd of Contexts.eval_ctx * V.typed_value symbolic_value_id_map option @@ -253,6 +257,11 @@ and value_aggregate = | SingleValue of V.typed_value (** Regular case *) | Array of V.typed_value list (** This is used when introducing array aggregates *) + | ConstGenericValue of T.const_generic_var_id + (** This is used when evaluating a const generic value: in the interpreter, + we introduce a fresh symbolic value. *) + | TraitConstValue of T.etrait_ref * T.egeneric_args * string + (** A trait constant value *) [@@deriving show, visitors diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 3512270a..429198ad 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -4,6 +4,7 @@ open Pure open PureUtils module Id = Identifiers module C = Contexts +module A = LlbcAst module S = SymbolicAst module TA = TypesAnalysis module L = Logging @@ -52,6 +53,9 @@ type fun_context = { type global_context = { llbc_global_decls : A.global_decl A.GlobalDeclId.Map.t } [@@deriving show] +type trait_decls_context = A.trait_decl A.TraitDeclId.Map.t [@@deriving show] +type trait_impls_context = A.trait_impl A.TraitImplId.Map.t [@@deriving show] + (** Whenever we translate a function call or an ended abstraction, we store the related information (this is useful when translating ended children abstractions). @@ -106,8 +110,7 @@ type loop_info = { loop_id : LoopId.id; input_vars : var list; input_svl : V.symbolic_value list; - type_args : ty list; - const_generic_args : const_generic list; + generics : generic_args; forward_inputs : texpression list option; (** The forward inputs are initialized at [None] *) forward_output_no_state_no_result : var option; @@ -120,6 +123,8 @@ type bs_ctx = { type_context : type_context; fun_context : fun_context; global_context : global_context; + trait_decls_ctx : trait_decls_context; + trait_impls_ctx : trait_impls_context; fun_decl : A.fun_decl; bid : T.RegionGroupId.id option; (** TODO: rename *) sg : fun_sig; @@ -201,34 +206,11 @@ type bs_ctx = { } [@@deriving show] -let type_check_pattern (ctx : bs_ctx) (v : typed_pattern) : unit = - let env = VarId.Map.empty in - let ctx = - { - PureTypeCheck.type_decls = ctx.type_context.type_decls; - global_decls = ctx.global_context.llbc_global_decls; - env; - } - in - let _ = PureTypeCheck.check_typed_pattern ctx v in - () - -let type_check_texpression (ctx : bs_ctx) (e : texpression) : unit = - let env = VarId.Map.empty in - let ctx = - { - PureTypeCheck.type_decls = ctx.type_context.type_decls; - global_decls = ctx.global_context.llbc_global_decls; - env; - } - in - PureTypeCheck.check_texpression ctx e - (* TODO: move *) let bs_ctx_to_ast_formatter (ctx : bs_ctx) : Print.Ast.ast_formatter = Print.Ast.decls_and_fun_decl_to_ast_formatter ctx.type_context.llbc_type_decls ctx.fun_context.llbc_fun_decls ctx.global_context.llbc_global_decls - ctx.fun_decl + ctx.trait_decls_ctx ctx.trait_impls_ctx ctx.fun_decl let bs_ctx_to_ctx_formatter (ctx : bs_ctx) : Print.Contexts.ctx_formatter = let rvar_to_string = Print.Types.region_var_id_to_string in @@ -246,16 +228,25 @@ let bs_ctx_to_ctx_formatter (ctx : bs_ctx) : Print.Contexts.ctx_formatter = adt_variant_to_string = ast_fmt.adt_variant_to_string; var_id_to_string; adt_field_names = ast_fmt.adt_field_names; + trait_decl_id_to_string = ast_fmt.trait_decl_id_to_string; + trait_impl_id_to_string = ast_fmt.trait_impl_id_to_string; + trait_clause_id_to_string = ast_fmt.trait_clause_id_to_string; } let bs_ctx_to_pp_ast_formatter (ctx : bs_ctx) : PrintPure.ast_formatter = - let type_params = ctx.fun_decl.signature.type_params in - let cg_params = ctx.fun_decl.signature.const_generic_params in + let generics = ctx.fun_decl.signature.generics in let type_decls = ctx.type_context.llbc_type_decls in let fun_decls = ctx.fun_context.llbc_fun_decls in let global_decls = ctx.global_context.llbc_global_decls in - PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params - cg_params + PrintPure.mk_ast_formatter type_decls fun_decls global_decls + ctx.trait_decls_ctx ctx.trait_impls_ctx generics.types + generics.const_generics + +let ctx_egeneric_args_to_string (ctx : bs_ctx) (args : T.egeneric_args) : string + = + let fmt = bs_ctx_to_ctx_formatter ctx in + let fmt = Print.PC.ctx_to_etype_formatter fmt in + Print.PT.egeneric_args_to_string fmt args let symbolic_value_to_string (ctx : bs_ctx) (sv : V.symbolic_value) : string = let fmt = bs_ctx_to_ctx_formatter ctx in @@ -277,12 +268,11 @@ let rty_to_string (ctx : bs_ctx) (ty : T.rty) : string = Print.PT.rty_to_string fmt ty let type_decl_to_string (ctx : bs_ctx) (def : type_decl) : string = - let type_params = def.type_params in - let cg_params = def.const_generic_params in let type_decls = ctx.type_context.llbc_type_decls in let global_decls = ctx.global_context.llbc_global_decls in let fmt = - PrintPure.mk_type_formatter type_decls global_decls type_params cg_params + PrintPure.mk_type_formatter type_decls global_decls ctx.trait_decls_ctx + ctx.trait_impls_ctx def.generics.types def.generics.const_generics in PrintPure.type_decl_to_string fmt def @@ -291,26 +281,27 @@ let texpression_to_string (ctx : bs_ctx) (e : texpression) : string = PrintPure.texpression_to_string fmt false "" " " e let fun_sig_to_string (ctx : bs_ctx) (sg : fun_sig) : string = - let type_params = sg.type_params in - let cg_params = sg.const_generic_params in + let type_params = sg.generics.types in + let cg_params = sg.generics.const_generics in let type_decls = ctx.type_context.llbc_type_decls in let fun_decls = ctx.fun_context.llbc_fun_decls in let global_decls = ctx.global_context.llbc_global_decls in let fmt = - PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params - cg_params + PrintPure.mk_ast_formatter type_decls fun_decls global_decls + ctx.trait_decls_ctx ctx.trait_impls_ctx type_params cg_params in PrintPure.fun_sig_to_string fmt sg let fun_decl_to_string (ctx : bs_ctx) (def : Pure.fun_decl) : string = - let type_params = def.signature.type_params in - let cg_params = def.signature.const_generic_params in + let generics = def.signature.generics in + let type_params = generics.types in + let cg_params = generics.const_generics in let type_decls = ctx.type_context.llbc_type_decls in let fun_decls = ctx.fun_context.llbc_fun_decls in let global_decls = ctx.global_context.llbc_global_decls in let fmt = - PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params - cg_params + PrintPure.mk_ast_formatter type_decls fun_decls global_decls + ctx.trait_decls_ctx ctx.trait_impls_ctx type_params cg_params in PrintPure.fun_decl_to_string fmt def @@ -328,17 +319,18 @@ let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string = Print.Values.abs_to_string fmt verbose indent indent_incr abs let get_instantiated_fun_sig (fun_id : A.fun_id) - (back_id : T.RegionGroupId.id option) (tys : ty list) - (cgs : const_generic list) (ctx : bs_ctx) : inst_fun_sig = + (back_id : T.RegionGroupId.id option) (generics : generic_args) + (ctx : bs_ctx) : inst_fun_sig = (* Lookup the non-instantiated function signature *) let sg = (RegularFunIdNotLoopMap.find (fun_id, back_id) ctx.fun_context.fun_sigs).sg in (* Create the substitution *) - let tsubst = make_type_subst sg.type_params tys in - let cgsubst = make_const_generic_subst sg.const_generic_params cgs in + (* There shouldn't be any reference to Self *) + let tr_self = UnknownTrait __FUNCTION__ in + let subst = make_subst_from_generics sg.generics generics tr_self in (* Apply *) - fun_sig_substitute tsubst cgsubst sg + fun_sig_substitute subst sg let bs_ctx_lookup_llbc_type_decl (id : TypeDeclId.id) (ctx : bs_ctx) : T.type_decl = @@ -354,74 +346,121 @@ let bs_ctx_lookup_local_function_sig (def_id : A.FunDeclId.id) let id = (A.Regular def_id, back_id) in (RegularFunIdNotLoopMap.find id ctx.fun_context.fun_sigs).sg -let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) - (args : texpression list) (ctx : bs_ctx) : bs_ctx = - let calls = ctx.calls in - assert (not (V.FunCallId.Map.mem call_id calls)); - let info = - { forward; forward_inputs = args; backwards = T.RegionGroupId.Map.empty } - in - let calls = V.FunCallId.Map.add call_id info calls in - { ctx with calls } - -(** [back_args]: the *additional* list of inputs received by the backward function *) -let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) - (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx) - : bs_ctx * fun_or_op_id = - (* Insert the abstraction in the call informations *) - let info = V.FunCallId.Map.find call_id ctx.calls in - assert (not (T.RegionGroupId.Map.mem back_id info.backwards)); - let backwards = - T.RegionGroupId.Map.add back_id (abs, back_args) info.backwards - in - let info = { info with backwards } in - let calls = V.FunCallId.Map.add call_id info ctx.calls in - (* Insert the abstraction in the abstractions map *) - let abstractions = ctx.abstractions in - assert (not (V.AbstractionId.Map.mem abs.abs_id abstractions)); - let abstractions = - V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions - in - (* Retrieve the fun_id *) - let fun_id = - match info.forward.call_id with - | S.Fun (fid, _) -> Fun (FromLlbc (fid, None, Some back_id)) - | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable") - in - (* Update the context and return *) - ({ ctx with calls; abstractions }, fun_id) +(* Some generic translation functions (we need to translate different "flavours" + of types: sty, forward types, backward types, etc.) *) +let rec translate_generic_args (translate_ty : 'r T.ty -> ty) + (generics : 'r T.generic_args) : generic_args = + (* We ignore the regions: if they didn't cause trouble for the symbolic execution, + then everything's fine *) + let types = List.map translate_ty generics.types in + let const_generics = generics.const_generics in + let trait_refs = + List.map (translate_trait_ref translate_ty) generics.trait_refs + in + { types; const_generics; trait_refs } + +and translate_trait_ref (translate_ty : 'r T.ty -> ty) (tr : 'r T.trait_ref) : + trait_ref = + let trait_id = translate_trait_instance_id translate_ty tr.trait_id in + let generics = translate_generic_args translate_ty tr.generics in + let trait_decl_ref = + translate_trait_decl_ref translate_ty tr.trait_decl_ref + in + { trait_id; generics; trait_decl_ref } + +and translate_trait_decl_ref (translate_ty : 'r T.ty -> ty) + (tr : 'r T.trait_decl_ref) : trait_decl_ref = + let decl_generics = translate_generic_args translate_ty tr.decl_generics in + { trait_decl_id = tr.trait_decl_id; decl_generics } + +and translate_trait_instance_id (translate_ty : 'r T.ty -> ty) + (id : 'r T.trait_instance_id) : trait_instance_id = + let translate_trait_instance_id = translate_trait_instance_id translate_ty in + match id with + | T.Self -> Self + | TraitImpl id -> TraitImpl id + | BuiltinOrAuto _ -> + (* We should have eliminated those in the prepasses *) + raise (Failure "Unreachable") + | Clause id -> Clause id + | ParentClause (inst_id, decl_id, clause_id) -> + let inst_id = translate_trait_instance_id inst_id in + ParentClause (inst_id, decl_id, clause_id) + | ItemClause (inst_id, decl_id, item_name, clause_id) -> + let inst_id = translate_trait_instance_id inst_id in + ItemClause (inst_id, decl_id, item_name, clause_id) + | TraitRef tr -> TraitRef (translate_trait_ref translate_ty tr) + | UnknownTrait s -> raise (Failure ("Unknown trait found: " ^ s)) let rec translate_sty (ty : T.sty) : ty = let translate = translate_sty in match ty with - | T.Adt (type_id, regions, tys, cgs) -> ( - (* Can't translate types with regions for now *) - assert (regions = []); - let tys = List.map translate tys in + | T.Adt (type_id, generics) -> ( + let generics = translate_sgeneric_args generics in match type_id with - | T.AdtId adt_id -> Adt (AdtId adt_id, tys, cgs) - | T.Tuple -> mk_simpl_tuple_ty tys + | T.AdtId adt_id -> Adt (AdtId adt_id, generics) + | T.Tuple -> + assert (generics.const_generics = []); + mk_simpl_tuple_ty generics.types | T.Assumed aty -> ( match aty with - | T.Vec -> Adt (Assumed Vec, tys, cgs) - | T.Option -> Adt (Assumed Option, tys, cgs) + | T.Vec -> Adt (Assumed Vec, generics) + | T.Option -> Adt (Assumed Option, generics) | T.Box -> ( (* Eliminate the boxes *) - match tys with + match generics.types with | [ ty ] -> ty | _ -> raise (Failure "Box/vec/option type with incorrect number of arguments") ) - | T.Array -> Adt (Assumed Array, tys, cgs) - | T.Slice -> Adt (Assumed Slice, tys, cgs) - | T.Str -> Adt (Assumed Str, tys, cgs) - | T.Range -> Adt (Assumed Range, tys, cgs))) + | T.Array -> Adt (Assumed Array, generics) + | T.Slice -> Adt (Assumed Slice, generics) + | T.Str -> Adt (Assumed Str, generics) + | T.Range -> Adt (Assumed Range, generics))) | TypeVar vid -> TypeVar vid | Literal ty -> Literal ty | Never -> raise (Failure "Unreachable") | Ref (_, rty, _) -> translate rty + | TraitType (trait_ref, generics, type_name) -> + let trait_ref = translate_strait_ref trait_ref in + let generics = translate_sgeneric_args generics in + TraitType (trait_ref, generics, type_name) + +and translate_sgeneric_args (generics : T.sgeneric_args) : generic_args = + translate_generic_args translate_sty generics + +and translate_strait_ref (tr : T.strait_ref) : trait_ref = + translate_trait_ref translate_sty tr + +and translate_strait_instance_id (id : T.strait_instance_id) : trait_instance_id + = + translate_trait_instance_id translate_sty id + +let translate_trait_clause (clause : T.trait_clause) : trait_clause = + let { T.clause_id; meta = _; trait_id; generics } = clause in + let generics = translate_sgeneric_args generics in + { clause_id; trait_id; generics } + +let translate_strait_type_constraint (ttc : T.strait_type_constraint) : + trait_type_constraint = + let { T.trait_ref; generics; type_name; ty } = ttc in + let trait_ref = translate_strait_ref trait_ref in + let generics = translate_sgeneric_args generics in + let ty = translate_sty ty in + { trait_ref; generics; type_name; ty } + +let translate_predicates (preds : T.predicates) : predicates = + let trait_type_constraints = + List.map translate_strait_type_constraint preds.trait_type_constraints + in + { trait_type_constraints } + +let translate_generic_params (generics : T.generic_params) : generic_params = + let { T.regions = _; types; const_generics; trait_clauses } = generics in + let trait_clauses = List.map translate_trait_clause trait_clauses in + { types; const_generics; trait_clauses } let translate_field (f : T.field) : field = let field_name = f.field_name in @@ -452,15 +491,16 @@ let translate_type_decl_kind (kind : T.type_decl_kind) : type_decl_kind = point of moving this definition for now. *) let translate_type_decl (def : T.type_decl) : type_decl = - (* Translate *) let def_id = def.T.def_id in let name = def.name in + let { T.regions; types; const_generics; trait_clauses } = def.generics in (* Can't translate types with regions for now *) - assert (def.region_params = []); - let type_params = def.type_params in - let const_generic_params = def.const_generic_params in + assert (regions = []); + let trait_clauses = List.map translate_trait_clause trait_clauses in + let generics = { types; const_generics; trait_clauses } in let kind = translate_type_decl_kind def.T.kind in - { def_id; name; type_params; const_generic_params; kind } + let preds = translate_predicates def.preds in + { def_id; name; generics; kind; preds } let translate_type_id (id : T.type_id) : type_id = match id with @@ -488,28 +528,33 @@ let translate_type_id (id : T.type_id) : type_id = let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty = let translate = translate_fwd_ty type_infos in match ty with - | T.Adt (type_id, regions, tys, cgs) -> ( - (* Can't translate types with regions for now *) - assert (regions = []); - (* Translate the type parameters *) - let t_tys = List.map translate tys in + | T.Adt (type_id, generics) -> ( + let t_generics = translate_fwd_generic_args type_infos generics in (* Eliminate boxes and simplify tuples *) match type_id with | AdtId _ | T.Assumed (T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) -> (* No general parametricity for now *) - assert (not (List.exists (TypesUtils.ty_has_borrows type_infos) tys)); + assert ( + not + (List.exists + (TypesUtils.ty_has_borrows type_infos) + generics.types)); let type_id = translate_type_id type_id in - Adt (type_id, t_tys, cgs) + Adt (type_id, t_generics) | Tuple -> (* Note that if there is exactly one type, [mk_simpl_tuple_ty] is the identity *) - mk_simpl_tuple_ty t_tys + mk_simpl_tuple_ty t_generics.types | T.Assumed T.Box -> ( (* We eliminate boxes *) (* No general parametricity for now *) - assert (not (List.exists (TypesUtils.ty_has_borrows type_infos) tys)); - match t_tys with + assert ( + not + (List.exists + (TypesUtils.ty_has_borrows type_infos) + generics.types)); + match t_generics.types with | [ bty ] -> bty | _ -> raise @@ -520,12 +565,34 @@ let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty = | Never -> raise (Failure "Unreachable") | Literal lty -> Literal lty | Ref (_, rty, _) -> translate rty + | TraitType (trait_ref, generics, type_name) -> + let trait_ref = translate_fwd_trait_ref type_infos trait_ref in + let generics = translate_fwd_generic_args type_infos generics in + TraitType (trait_ref, generics, type_name) + +and translate_fwd_generic_args (type_infos : TA.type_infos) + (generics : 'r T.generic_args) : generic_args = + translate_generic_args (translate_fwd_ty type_infos) generics + +and translate_fwd_trait_ref (type_infos : TA.type_infos) (tr : 'r T.trait_ref) : + trait_ref = + translate_trait_ref (translate_fwd_ty type_infos) tr + +and translate_fwd_trait_instance_id (type_infos : TA.type_infos) + (id : 'r T.trait_instance_id) : trait_instance_id = + translate_trait_instance_id (translate_fwd_ty type_infos) id (** Simply calls [translate_fwd_ty] *) let ctx_translate_fwd_ty (ctx : bs_ctx) (ty : 'r T.ty) : ty = let type_infos = ctx.type_context.type_infos in translate_fwd_ty type_infos ty +(** Simply calls [translate_fwd_generic_args] *) +let ctx_translate_fwd_generic_args (ctx : bs_ctx) (generics : 'r T.generic_args) + : generic_args = + let type_infos = ctx.type_context.type_infos in + translate_fwd_generic_args type_infos generics + (** Translate a type, when some regions may have ended. We return an option, because the translated type may be empty. @@ -538,7 +605,7 @@ let rec translate_back_ty (type_infos : TA.type_infos) (* A small helper for "leave" types *) let wrap ty = if inside_mut then Some ty else None in match ty with - | T.Adt (type_id, _, tys, cgs) -> ( + | T.Adt (type_id, generics) -> ( match type_id with | T.AdtId _ | Assumed (T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) -> @@ -546,22 +613,24 @@ let rec translate_back_ty (type_infos : TA.type_infos) assert (not (TypesUtils.ty_has_borrows type_infos ty)); let type_id = translate_type_id type_id in if inside_mut then - let tys_t = List.filter_map translate tys in - Some (Adt (type_id, tys_t, cgs)) + (* We do not want to filter anything, so we translate the generics + as "forward" types *) + let generics = translate_fwd_generic_args type_infos generics in + Some (Adt (type_id, generics)) else None | Assumed T.Box -> ( (* Don't accept ADTs (which are not tuples) with borrows for now *) assert (not (TypesUtils.ty_has_borrows type_infos ty)); (* Eliminate the box *) - match tys with + match generics.types with | [ bty ] -> translate bty | _ -> raise (Failure "Unreachable: boxes receive exactly one type parameter") ) | T.Tuple -> ( - (* Tuples can contain borrows (which we eliminated) *) - let tys_t = List.filter_map translate tys in + (* Tuples can contain borrows (which we eliminate) *) + let tys_t = List.filter_map translate generics.types in match tys_t with | [] -> None | _ -> @@ -582,6 +651,13 @@ let rec translate_back_ty (type_infos : TA.type_infos) if keep_region r then translate_back_ty type_infos keep_region inside_mut rty else None) + | TraitType (trait_ref, generics, type_name) -> + assert (generics.regions = []); + (* Translate the trait ref and the generics as "forward" generics - + we do not want to filter any type *) + let trait_ref = translate_fwd_trait_ref type_infos trait_ref in + let generics = translate_fwd_generic_args type_infos generics in + Some (TraitType (trait_ref, generics, type_name)) (** Simply calls [translate_back_ty] *) let ctx_translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool) @@ -589,6 +665,80 @@ let ctx_translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool) let type_infos = ctx.type_context.type_infos in translate_back_ty type_infos keep_region inside_mut ty +let mk_type_check_ctx (ctx : bs_ctx) : PureTypeCheck.tc_ctx = + let const_generics = + T.ConstGenericVarId.Map.of_list + (List.map + (fun (cg : T.const_generic_var) -> + (cg.index, ctx_translate_fwd_ty ctx (T.Literal cg.ty))) + ctx.sg.generics.const_generics) + in + let env = VarId.Map.empty in + { + PureTypeCheck.type_decls = ctx.type_context.type_decls; + global_decls = ctx.global_context.llbc_global_decls; + env; + const_generics; + } + +let type_check_pattern (ctx : bs_ctx) (v : typed_pattern) : unit = + let ctx = mk_type_check_ctx ctx in + let _ = PureTypeCheck.check_typed_pattern ctx v in + () + +let type_check_texpression (ctx : bs_ctx) (e : texpression) : unit = + if !Config.type_check_pure_code then + let ctx = mk_type_check_ctx ctx in + PureTypeCheck.check_texpression ctx e + +let translate_fun_id_or_trait_method_ref (ctx : bs_ctx) + (id : A.fun_id_or_trait_method_ref) : fun_id_or_trait_method_ref = + match id with + | A.FunId fun_id -> FunId fun_id + | TraitMethod (trait_ref, method_name, fun_decl_id) -> + let type_infos = ctx.type_context.type_infos in + let trait_ref = translate_fwd_trait_ref type_infos trait_ref in + TraitMethod (trait_ref, method_name, fun_decl_id) + +let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) + (args : texpression list) (ctx : bs_ctx) : bs_ctx = + let calls = ctx.calls in + assert (not (V.FunCallId.Map.mem call_id calls)); + let info = + { forward; forward_inputs = args; backwards = T.RegionGroupId.Map.empty } + in + let calls = V.FunCallId.Map.add call_id info calls in + { ctx with calls } + +(** [back_args]: the *additional* list of inputs received by the backward function *) +let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) + (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx) + : bs_ctx * fun_or_op_id = + (* Insert the abstraction in the call informations *) + let info = V.FunCallId.Map.find call_id ctx.calls in + assert (not (T.RegionGroupId.Map.mem back_id info.backwards)); + let backwards = + T.RegionGroupId.Map.add back_id (abs, back_args) info.backwards + in + let info = { info with backwards } in + let calls = V.FunCallId.Map.add call_id info ctx.calls in + (* Insert the abstraction in the abstractions map *) + let abstractions = ctx.abstractions in + assert (not (V.AbstractionId.Map.mem abs.abs_id abstractions)); + let abstractions = + V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions + in + (* Retrieve the fun_id *) + let fun_id = + match info.forward.call_id with + | S.Fun (fid, _) -> + let fid = translate_fun_id_or_trait_method_ref ctx fid in + Fun (FromLlbc (fid, None, Some back_id)) + | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable") + in + (* Update the context and return *) + ({ ctx with calls; abstractions }, fun_id) + (** List the ancestors of an abstraction *) let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs) (call_id : V.FunCallId.id) : V.AbstractionId.id list = @@ -642,10 +792,10 @@ let mk_fuel_input_as_list (ctx : bs_ctx) (info : fun_effect_info) : (** Small utility. *) let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) - (fun_id : A.fun_id) (lid : V.LoopId.id option) + (fun_id : A.fun_id_or_trait_method_ref) (lid : V.LoopId.id option) (gid : T.RegionGroupId.id option) : fun_effect_info = match fun_id with - | A.Regular fid -> + | A.TraitMethod (_, _, fid) | A.FunId (A.Regular fid) -> let info = A.FunDeclId.Map.find fid fun_infos in let stateful_group = info.stateful in let stateful = @@ -658,7 +808,7 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) can_diverge = info.can_diverge; is_rec = info.is_rec || Option.is_some lid; } - | A.Assumed aid -> + | A.FunId (A.Assumed aid) -> assert (lid = None); { can_fail = Assumed.assumed_can_fail aid; @@ -673,12 +823,14 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) Note that the function also takes a list of names for the inputs, and computes, for every output for the backward functions, a corresponding name (outputs for backward functions come from borrows in the inputs - of the forward function) which we use as hints to generate pretty names. + of the forward function) which we use as hints to generate pretty names + in the extracted code. *) -let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) - (fun_id : A.fun_id) (type_infos : TA.type_infos) (sg : A.fun_sig) - (input_names : string option list) (bid : T.RegionGroupId.id option) : - fun_sig_named_outputs = +let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) + (sg : A.fun_sig) (input_names : string option list) + (bid : T.RegionGroupId.id option) : fun_sig_named_outputs = + let fun_infos = decls_ctx.fun_ctx.fun_infos in + let type_infos = decls_ctx.type_ctx.type_infos in (* Retrieve the list of parent backward functions *) let gid, parents = match bid with @@ -689,7 +841,34 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) in (* Is the function stateful, and can it fail? *) let lid = None in - let effect_info = get_fun_effect_info fun_infos fun_id lid bid in + let effect_info = get_fun_effect_info fun_infos (A.FunId fun_id) lid bid in + (* We need an evaluation context to normalize the types (to normalize the + associated types, etc. - for instance it may happen that the types + refer to the types associated to a trait ref, but where the trait ref + is a known impl). *) + (* Create the context *) + let ctx = + let region_groups = + List.map (fun (g : T.region_var_group) -> g.id) sg.regions_hierarchy + in + let ctx = + InterpreterUtils.initialize_eval_context decls_ctx region_groups + sg.generics.types sg.generics.const_generics + in + (* Compute the normalization map for the *sty* types and add it to the context *) + AssociatedTypes.ctx_add_norm_trait_stypes_from_preds ctx + sg.preds.trait_type_constraints + in + + (* Normalize the signature *) + let sg = + let ({ A.inputs; output; _ } : A.fun_sig) = sg in + let norm = AssociatedTypes.ctx_normalize_sty ctx in + let inputs = List.map norm inputs in + let output = norm output in + { sg with A.inputs; output } + in + (* List the inputs for: * - the fuel * - the forward function @@ -806,9 +985,8 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) (* Wrap in a result type *) if effect_info.can_fail then mk_result_ty output else output in - (* Type/const generic parameters *) - let type_params = sg.type_params in - let const_generic_params = sg.const_generic_params in + (* Generic parameters *) + let generics = translate_generic_params sg.generics in (* Return *) let has_fuel = fuel <> [] in let num_fwd_inputs_no_state = List.length fwd_inputs in @@ -836,9 +1014,8 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) effect_info; } in - let sg = - { type_params; const_generic_params; inputs; output; doutputs; info } - in + let preds = translate_predicates sg.A.preds in + let sg = { generics; preds; inputs; output; doutputs; info } in { sg; output_names } let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern = @@ -917,7 +1094,7 @@ let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var = (** Peel boxes as long as the value is of the form [Box<T>] *) let rec unbox_typed_value (v : V.typed_value) : V.typed_value = match (v.value, v.ty) with - | V.Adt av, T.Adt (T.Assumed T.Box, _, _, _) -> ( + | V.Adt av, T.Adt (T.Assumed T.Box, _) -> ( match av.field_values with | [ bv ] -> unbox_typed_value bv | _ -> raise (Failure "Unreachable")) @@ -962,16 +1139,16 @@ let rec typed_value_to_texpression (ctx : bs_ctx) (ectx : C.eval_ctx) let field_values = List.map translate av.field_values in (* Eliminate the tuple wrapper if it is a tuple with exactly one field *) match v.ty with - | T.Adt (T.Tuple, _, _, _) -> + | T.Adt (T.Tuple, _) -> assert (variant_id = None); mk_simpl_tuple_texpression field_values | _ -> - (* Retrieve the type, the translated type arguments and the - * const generic arguments from the translated type (simpler this way) *) - let adt_id, type_args, const_generic_args = ty_as_adt ty in + (* Retrieve the type and the translated generics from the translated + type (simpler this way) *) + let adt_id, generics = ty_as_adt ty in (* Create the constructor *) let qualif_id = AdtCons { adt_id; variant_id = av.variant_id } in - let qualif = { id = qualif_id; type_args; const_generic_args } in + let qualif = { id = qualif_id; generics } in let cons_e = Qualif qualif in let field_tys = List.map (fun (v : texpression) -> v.ty) field_values @@ -1038,7 +1215,7 @@ let rec typed_avalue_to_consumed (ctx : bs_ctx) (ectx : C.eval_ctx) (* Translate the field values *) let field_values = List.filter_map translate adt_v.field_values in (* For now, only tuples can contain borrows *) - let adt_id, _, _, _ = TypesUtils.ty_as_adt av.ty in + let adt_id, _ = TypesUtils.ty_as_adt av.ty in match adt_id with | T.AdtId _ | T.Assumed @@ -1185,7 +1362,7 @@ let rec typed_avalue_to_given_back (mp : mplace option) (av : V.typed_avalue) (* For now, only tuples can contain borrows - note that if we gave * something like a [&mut Vec] to a function, we give back the * vector value upon visiting the "abstraction borrow" node *) - let adt_id, _, _, _ = TypesUtils.ty_as_adt av.ty in + let adt_id, _ = TypesUtils.ty_as_adt av.ty in match adt_id with | T.AdtId _ | T.Assumed @@ -1457,9 +1634,12 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : texpression = + log#ldebug + (lazy + ("translate_function_call:\n" + ^ ctx_egeneric_args_to_string ctx call.generics)); (* Translate the function call *) - let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in - let const_generic_args = call.const_generic_params in + let generics = ctx_translate_fwd_generic_args ctx call.generics in let args = let args = List.map (typed_value_to_texpression ctx call.ctx) call.args in let args_mplaces = List.map translate_opt_mplace call.args_places in @@ -1475,7 +1655,8 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : match call.call_id with | S.Fun (fid, call_id) -> (* Regular function call *) - let func = Fun (FromLlbc (fid, None, None)) in + let fid_t = translate_fun_id_or_trait_method_ref ctx fid in + let func = Fun (FromLlbc (fid_t, None, None)) in (* Retrieve the effect information about this function (can fail, * takes a state as input, etc.) *) let effect_info = @@ -1561,7 +1742,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : | None -> dest | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ] in - let func = { id = FunOrOp fun_id; type_args; const_generic_args } in + let func = { id = FunOrOp fun_id; generics } in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in let ret_ty = if effect_info.can_fail then mk_result_ty dest_v.ty else dest_v.ty @@ -1665,9 +1846,11 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs) (* Group the two lists *) let variables_values = List.combine given_back_variables consumed_values in (* Sanity check: the two lists match (same types) *) - List.iter - (fun (var, v) -> assert ((var : var).ty = (v : texpression).ty)) - variables_values; + (* TODO: normalize the types *) + if !Config.type_check_pure_code then + List.iter + (fun (var, v) -> assert ((var : var).ty = (v : texpression).ty)) + variables_values; (* Translate the next expression *) let next_e = translate_expression e ctx in (* Generate the assignemnts *) @@ -1692,8 +1875,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) let effect_info = get_fun_effect_info ctx.fun_context.fun_infos fun_id None (Some rg_id) in - let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in - let const_generic_args = call.const_generic_params in + let generics = ctx_translate_fwd_generic_args ctx call.generics in (* Retrieve the original call and the parent abstractions *) let _forward, backwards = get_abs_ancestors ctx abs call_id in (* Retrieve the values consumed when we called the forward function and @@ -1741,34 +1923,35 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) | Some nstate -> mk_simpl_tuple_pattern [ nstate; output ] in (* Sanity check: there is the proper number of inputs and outputs, and they have the proper type *) - let _ = - let inst_sg = - get_instantiated_fun_sig fun_id (Some rg_id) type_args const_generic_args - ctx - in - log#ldebug - (lazy - ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs (" - ^ string_of_int (List.length inputs) - ^ "): " - ^ String.concat ", " (List.map (texpression_to_string ctx) inputs) - ^ "\n- inst_sg.inputs (" - ^ string_of_int (List.length inst_sg.inputs) - ^ "): " - ^ String.concat ", " (List.map (ty_to_string ctx) inst_sg.inputs))); - List.iter - (fun (x, ty) -> assert ((x : texpression).ty = ty)) - (List.combine inputs inst_sg.inputs); - log#ldebug - (lazy - ("\n- outputs: " - ^ string_of_int (List.length outputs) - ^ "\n- expected outputs: " - ^ string_of_int (List.length inst_sg.doutputs))); - List.iter - (fun (x, ty) -> assert ((x : typed_pattern).ty = ty)) - (List.combine outputs inst_sg.doutputs) - in + (if (* TODO: normalize the types *) !Config.type_check_pure_code then + match fun_id with + | A.FunId fun_id -> + let inst_sg = + get_instantiated_fun_sig fun_id (Some rg_id) generics ctx + in + log#ldebug + (lazy + ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs (" + ^ string_of_int (List.length inputs) + ^ "): " + ^ String.concat ", " (List.map (texpression_to_string ctx) inputs) + ^ "\n- inst_sg.inputs (" + ^ string_of_int (List.length inst_sg.inputs) + ^ "): " + ^ String.concat ", " (List.map (ty_to_string ctx) inst_sg.inputs))); + List.iter + (fun (x, ty) -> assert ((x : texpression).ty = ty)) + (List.combine inputs inst_sg.inputs); + log#ldebug + (lazy + ("\n- outputs: " + ^ string_of_int (List.length outputs) + ^ "\n- expected outputs: " + ^ string_of_int (List.length inst_sg.doutputs))); + List.iter + (fun (x, ty) -> assert ((x : typed_pattern).ty = ty)) + (List.combine outputs inst_sg.doutputs) + | _ -> (* TODO: trait methods *) ()); (* Retrieve the function id, and register the function call in the context * if necessary *) let ctx, func = @@ -1788,7 +1971,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) if effect_info.can_fail then mk_result_ty output.ty else output.ty in let func_ty = mk_arrows input_tys ret_ty in - let func = { id = FunOrOp func; type_args; const_generic_args } in + let func = { id = FunOrOp func; generics } in let func = { e = Qualif func; ty = func_ty } in let call = mk_apps func args in (* **Optimization**: @@ -1907,12 +2090,11 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) | V.LoopCall -> let fun_id = A.Regular ctx.fun_decl.A.def_id in let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos fun_id (Some vloop_id) - (Some rg_id) + get_fun_effect_info ctx.fun_context.fun_infos (A.FunId fun_id) + (Some vloop_id) (Some rg_id) in let loop_info = LoopId.Map.find loop_id ctx.loops in - let type_args = loop_info.type_args in - let const_generic_args = loop_info.const_generic_args in + let generics = loop_info.generics in let fwd_inputs = Option.get loop_info.forward_inputs in (* Retrieve the additional backward inputs. Note that those are actually the backward inputs of the function we are synthesizing (and that we @@ -1960,8 +2142,8 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) if effect_info.can_fail then mk_result_ty output.ty else output.ty in let func_ty = mk_arrows input_tys ret_ty in - let func = Fun (FromLlbc (fun_id, Some loop_id, Some rg_id)) in - let func = { id = FunOrOp func; type_args; const_generic_args } in + let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in + let func = { id = FunOrOp func; generics } in let func = { e = Qualif func; ty = func_ty } in let call = mk_apps func args in (* **Optimization**: @@ -2021,9 +2203,7 @@ and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value) (e : S.expression) (ctx : bs_ctx) : texpression = let ctx, var = fresh_var_for_symbolic_value sval ctx in let decl = A.GlobalDeclId.Map.find gid ctx.global_context.llbc_global_decls in - let global_expr = - { id = Global gid; type_args = []; const_generic_args = [] } - in + let global_expr = { id = Global gid; generics = empty_generic_args } in (* We use translate_fwd_ty to translate the global type *) let ty = ctx_translate_fwd_ty ctx decl.ty in let gval = { e = Qualif global_expr; ty } in @@ -2037,11 +2217,7 @@ and translate_assertion (ectx : C.eval_ctx) (v : V.typed_value) let v = typed_value_to_texpression ctx ectx v in let args = [ v ] in let func = - { - id = FunOrOp (Fun (Pure Assert)); - type_args = []; - const_generic_args = []; - } + { id = FunOrOp (Fun (Pure Assert)); generics = empty_generic_args } in let func_ty = mk_arrow (Literal Bool) mk_unit_ty in let func = { e = Qualif func; ty = func_ty } in @@ -2189,7 +2365,7 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value) (branch : S.expression) (ctx : bs_ctx) : texpression = (* TODO: always introduce a match, and use micro-passes to turn the the match into a let? *) - let type_id, _, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in + let type_id, _ = TypesUtils.ty_as_adt sv.V.sv_ty in let ctx, vars = fresh_vars_for_symbolic_values svl ctx in let branch = translate_expression branch ctx in match type_id with @@ -2224,10 +2400,10 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value) * field. * We use the [dest] variable in order not to have to recompute * the type of the result of the projection... *) - let adt_id, type_args, const_generic_args = ty_as_adt scrutinee.ty in + let adt_id, generics = ty_as_adt scrutinee.ty in let gen_field_proj (field_id : FieldId.id) (dest : var) : texpression = let proj_kind = { adt_id; field_id } in - let qualif = { id = Proj proj_kind; type_args; const_generic_args } in + let qualif = { id = Proj proj_kind; generics } in let proj_e = Qualif qualif in let proj_ty = mk_arrow scrutinee.ty dest.ty in let proj = { e = proj_e; ty = proj_ty } in @@ -2282,8 +2458,9 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option) (* Translate the next expression *) let next_e = translate_expression e ctx in - (* Translate the value: there are two cases, depending on whether this - is a "regular" let-binding or an array aggregate. + (* Translate the value: there are several cases, depending on whether this + is a "regular" let-binding, an array aggregate, a const generic or + a trait associated constant. *) let v = match v with @@ -2298,6 +2475,14 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option) { struct_id = Assumed Array; init = None; updates = values } in { e = StructUpdate su; ty = var.ty } + | ConstGenericValue cg_id -> { e = CVar cg_id; ty = var.ty } + | TraitConstValue (trait_ref, generics, const_name) -> + let type_infos = ctx.type_context.type_infos in + let trait_ref = translate_fwd_trait_ref type_infos trait_ref in + let generics = translate_fwd_generic_args type_infos generics in + let qualif_id = TraitConst (trait_ref, generics, const_name) in + let qualif = { id = qualif_id; generics = empty_generic_args } in + { e = Qualif qualif; ty = var.ty } in (* Make the let-binding *) @@ -2370,7 +2555,7 @@ and translate_forward_end (ectx : C.eval_ctx) (* Lookup the effect info for the loop function *) let fid = A.Regular ctx.fun_decl.A.def_id in let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos fid None ctx.bid + get_fun_effect_info ctx.fun_context.fun_infos (A.FunId fid) None ctx.bid in (* Introduce a fresh output value for the forward function *) @@ -2415,14 +2600,8 @@ and translate_forward_end (ectx : C.eval_ctx) let out_pat = mk_simpl_tuple_pattern out_pats in let loop_call = - let fun_id = Fun (FromLlbc (fid, Some loop_id, None)) in - let func = - { - id = FunOrOp fun_id; - type_args = loop_info.type_args; - const_generic_args = loop_info.const_generic_args; - } - in + let fun_id = Fun (FromLlbc (FunId fid, Some loop_id, None)) in + let func = { id = FunOrOp fun_id; generics = loop_info.generics } in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in let ret_ty = if effect_info.can_fail then mk_result_ty out_pat.ty else out_pat.ty @@ -2541,14 +2720,31 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = (* Note that we will retrieve the input values later in the [ForwardEnd] (and will introduce the outputs at that moment, together with the actual - call to the loop forward function *) - let type_args = - List.map (fun (ty : T.type_var) -> TypeVar ty.T.index) ctx.sg.type_params - in - let const_generic_args = - List.map - (fun (cg : T.const_generic_var) -> T.ConstGenericVar cg.T.index) - ctx.sg.const_generic_params + call to the loop forward function) *) + let generics = + let { types; const_generics; trait_clauses } = ctx.sg.generics in + let types = + List.map (fun (ty : T.type_var) -> TypeVar ty.T.index) types + in + let const_generics = + List.map + (fun (cg : T.const_generic_var) -> T.ConstGenericVar cg.T.index) + const_generics + in + let trait_refs = + List.map + (fun (c : trait_clause) -> + let trait_decl_ref = + { trait_decl_id = c.trait_id; decl_generics = empty_generic_args } + in + { + trait_id = Clause c.clause_id; + generics = empty_generic_args; + trait_decl_ref; + }) + trait_clauses + in + { types; const_generics; trait_refs } in let loop_info = @@ -2556,8 +2752,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = loop_id; input_vars = inputs; input_svl = loop.input_svalues; - type_args; - const_generic_args; + generics; forward_inputs = None; forward_output_no_state_no_result = None; } @@ -2648,8 +2843,7 @@ let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression) let func = { id = FunOrOp (Fun (Pure FuelEqZero)); - type_args = []; - const_generic_args = []; + generics = empty_generic_args; } in let func_ty = mk_arrow mk_fuel_ty mk_bool_ty in @@ -2661,8 +2855,7 @@ let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression) let func = { id = FunOrOp (Fun (Pure FuelDecrease)); - type_args = []; - const_generic_args = []; + generics = empty_generic_args; } in let func_ty = mk_arrow mk_fuel_ty mk_fuel_ty in @@ -2727,8 +2920,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = | None -> None | Some body -> let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) None - bid + get_fun_effect_info ctx.fun_context.fun_infos (FunId (Regular def_id)) + None bid in let body = translate_expression body ctx in (* Add a match over the fuel, if necessary *) @@ -2803,10 +2996,12 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = ^ "\n- signature.inputs: " ^ String.concat ", " (List.map (ty_to_string ctx) signature.inputs) )); - assert ( - List.for_all - (fun (var, ty) -> (var : var).ty = ty) - (List.combine inputs signature.inputs)); + (* TODO: we need to normalize the types *) + if !Config.type_check_pure_code then + assert ( + List.for_all + (fun (var, ty) -> (var : var).ty = ty) + (List.combine inputs signature.inputs)); Some { inputs; inputs_lvs; body } in @@ -2821,6 +3016,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = let def = { def_id; + kind = def.kind; num_loops; loop_id; back_id = bid; @@ -2853,8 +3049,7 @@ let translate_type_decls (type_decls : T.type_decl list) : type_decl list = - optional names for the outputs values (we derive them for the backward functions) *) -let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) - (type_infos : TA.type_infos) +let translate_fun_signatures (decls_ctx : C.decls_ctx) (functions : (A.fun_id * string option list * A.fun_sig) list) : fun_sig_named_outputs RegularFunIdNotLoopMap.t = (* For every function, translate the signatures of: @@ -2865,17 +3060,14 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) (sg : A.fun_sig) : (regular_fun_id_not_loop * fun_sig_named_outputs) list = (* The forward function *) - let fwd_sg = - translate_fun_sig fun_infos fun_id type_infos sg input_names None - in + let fwd_sg = translate_fun_sig decls_ctx fun_id sg input_names None in let fwd_id = (fun_id, None) in (* The backward functions *) let back_sgs = List.map (fun (rg : T.region_var_group) -> let tsg = - translate_fun_sig fun_infos fun_id type_infos sg input_names - (Some rg.id) + translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id) in let id = (fun_id, Some rg.id) in (id, tsg)) @@ -2891,3 +3083,91 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) List.fold_left (fun m (id, sg) -> RegularFunIdNotLoopMap.add id sg m) RegularFunIdNotLoopMap.empty translated + +let translate_trait_decl (type_infos : TA.type_infos) + (trait_decl : A.trait_decl) : trait_decl = + let { + A.def_id; + name; + generics; + preds; + all_trait_clauses; + consts; + types; + required_methods; + provided_methods; + } = + trait_decl + in + let generics = translate_generic_params generics in + let preds = translate_predicates preds in + let all_trait_clauses = List.map translate_trait_clause all_trait_clauses in + let consts = + List.map + (fun (name, (ty, id)) -> (name, (translate_fwd_ty type_infos ty, id))) + consts + in + let types = + List.map + (fun (name, (trait_clauses, ty)) -> + ( name, + ( List.map translate_trait_clause trait_clauses, + Option.map (translate_fwd_ty type_infos) ty ) )) + types + in + { + def_id; + name; + generics; + preds; + all_trait_clauses; + consts; + types; + required_methods; + provided_methods; + } + +let translate_trait_impl (type_infos : TA.type_infos) + (trait_impl : A.trait_impl) : trait_impl = + let { + A.def_id; + name; + impl_trait; + generics; + preds; + consts; + types; + required_methods; + provided_methods; + } = + trait_impl + in + let impl_trait = + translate_trait_decl_ref (translate_fwd_ty type_infos) impl_trait + in + let generics = translate_generic_params generics in + let preds = translate_predicates preds in + let consts = + List.map + (fun (name, (ty, id)) -> (name, (translate_fwd_ty type_infos ty, id))) + consts + in + let types = + List.map + (fun (name, (trait_refs, ty)) -> + ( name, + ( List.map (translate_fwd_trait_ref type_infos) trait_refs, + translate_fwd_ty type_infos ty ) )) + types + in + { + def_id; + name; + impl_trait; + generics; + preds; + consts; + types; + required_methods; + provided_methods; + } diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml index 857fea97..aeb6899f 100644 --- a/compiler/SynthesizeSymbolic.ml +++ b/compiler/SynthesizeSymbolic.ml @@ -64,7 +64,7 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value) assert (otherwise_see = None); (* Return *) ExpandInt (int_ty, branches, otherwise) - | T.Adt (_, _, _, _) -> + | T.Adt (_, _) -> (* Branching: it is necessarily an enumeration expansion *) let get_variant (see : V.symbolic_expansion option) : T.VariantId.id option * V.symbolic_value list = @@ -85,7 +85,7 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value) match ls with | [ (Some see, exp) ] -> ExpandNoBranch (see, exp) | _ -> raise (Failure "Ill-formed borrow expansion")) - | T.TypeVar _ | T.Literal Char | Never -> + | T.TypeVar _ | T.Literal Char | Never | T.TraitType _ -> raise (Failure "Ill-formed symbolic expansion") in Some (Expansion (place, sv, expansion)) @@ -97,10 +97,10 @@ let synthesize_symbolic_expansion_no_branching (sv : V.symbolic_value) synthesize_symbolic_expansion sv place [ Some see ] el let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx) - (abstractions : V.AbstractionId.id list) (type_params : T.ety list) - (const_generic_params : T.const_generic list) (args : V.typed_value list) - (args_places : mplace option list) (dest : V.symbolic_value) - (dest_place : mplace option) (e : expression option) : expression option = + (abstractions : V.AbstractionId.id list) (generics : T.egeneric_args) + (args : V.typed_value list) (args_places : mplace option list) + (dest : V.symbolic_value) (dest_place : mplace option) + (e : expression option) : expression option = Option.map (fun e -> let call = @@ -108,8 +108,7 @@ let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx) call_id; ctx; abstractions; - type_params; - const_generic_params; + generics; args; dest; args_places; @@ -123,28 +122,29 @@ let synthesize_global_eval (gid : A.GlobalDeclId.id) (dest : V.symbolic_value) (e : expression option) : expression option = Option.map (fun e -> EvalGlobal (gid, dest, e)) e -let synthesize_regular_function_call (fun_id : A.fun_id) +let synthesize_regular_function_call (fun_id : A.fun_id_or_trait_method_ref) (call_id : V.FunCallId.id) (ctx : Contexts.eval_ctx) - (abstractions : V.AbstractionId.id list) (type_params : T.ety list) - (const_generic_params : T.const_generic list) (args : V.typed_value list) - (args_places : mplace option list) (dest : V.symbolic_value) - (dest_place : mplace option) (e : expression option) : expression option = + (abstractions : V.AbstractionId.id list) (generics : T.egeneric_args) + (args : V.typed_value list) (args_places : mplace option list) + (dest : V.symbolic_value) (dest_place : mplace option) + (e : expression option) : expression option = synthesize_function_call (Fun (fun_id, call_id)) - ctx abstractions type_params const_generic_params args args_places dest - dest_place e + ctx abstractions generics args args_places dest dest_place e let synthesize_unary_op (ctx : Contexts.eval_ctx) (unop : E.unop) (arg : V.typed_value) (arg_place : mplace option) (dest : V.symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = - synthesize_function_call (Unop unop) ctx [] [] [] [ arg ] [ arg_place ] dest - dest_place e + let generics = TypesUtils.mk_empty_generic_args in + synthesize_function_call (Unop unop) ctx [] generics [ arg ] [ arg_place ] + dest dest_place e let synthesize_binary_op (ctx : Contexts.eval_ctx) (binop : E.binop) (arg0 : V.typed_value) (arg0_place : mplace option) (arg1 : V.typed_value) (arg1_place : mplace option) (dest : V.symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = - synthesize_function_call (Binop binop) ctx [] [] [] [ arg0; arg1 ] + let generics = TypesUtils.mk_empty_generic_args in + synthesize_function_call (Binop binop) ctx [] generics [ arg0; arg1 ] [ arg0_place; arg1_place ] dest dest_place e let synthesize_end_abstraction (ctx : Contexts.eval_ctx) (abs : V.abs) diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 70ef5e3d..e69abee1 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -5,6 +5,7 @@ module T = Types module A = LlbcAst module SA = SymbolicAst module Micro = PureMicroPasses +module C = Contexts open PureUtils open TranslateCore @@ -28,18 +29,12 @@ let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : A.fun_decl) ("translate_function_to_symbolics: " ^ Print.fun_name_to_string fdef.A.name)); - let { type_context; fun_context; global_context } = trans_ctx in - let fun_context = { C.fun_decls = fun_context.fun_decls } in - match fdef.body with | None -> None | Some _ -> (* Evaluate *) let synthesize = true in - let inputs, symb = - evaluate_function_symbolic synthesize type_context fun_context - global_context fdef - in + let inputs, symb = evaluate_function_symbolic synthesize trans_ctx fdef in Some (inputs, Option.get symb) (** Translate a function, by generating its forward and backward translations. @@ -57,7 +52,6 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (lazy ("translate_function_to_pure: " ^ Print.fun_name_to_string fdef.A.name)); - let { type_context; fun_context; global_context } = trans_ctx in let def_id = fdef.def_id in (* Compute the symbolic ASTs, if the function is transparent *) @@ -82,25 +76,25 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (List.filter_map (fun (tid, g) -> match g with Charon.GAst.NonRec _ -> None | Rec _ -> Some tid) - (T.TypeDeclId.Map.bindings trans_ctx.type_context.type_decls_groups)) + (T.TypeDeclId.Map.bindings trans_ctx.type_ctx.type_decls_groups)) in let type_context = { - SymbolicToPure.type_infos = type_context.type_infos; - llbc_type_decls = type_context.type_decls; + SymbolicToPure.type_infos = trans_ctx.type_ctx.type_infos; + llbc_type_decls = trans_ctx.type_ctx.type_decls; type_decls = pure_type_decls; recursive_decls = recursive_type_decls; } in let fun_context = { - SymbolicToPure.llbc_fun_decls = fun_context.fun_decls; + SymbolicToPure.llbc_fun_decls = trans_ctx.fun_ctx.fun_decls; fun_sigs; - fun_infos = fun_context.fun_infos; + fun_infos = trans_ctx.fun_ctx.fun_infos; } in let global_context = - { SymbolicToPure.llbc_global_decls = global_context.global_decls } + { SymbolicToPure.llbc_global_decls = trans_ctx.global_ctx.global_decls } in (* Compute the set of loops, and find better ids for them (starting at 0). @@ -148,6 +142,8 @@ let translate_function_to_pure (trans_ctx : trans_ctx) type_context; fun_context; global_context; + trait_decls_ctx = trans_ctx.trait_decls_ctx.trait_decls; + trait_impls_ctx = trans_ctx.trait_impls_ctx.trait_impls; fun_decl = fdef; forward_inputs = []; (* Empty for now *) @@ -274,21 +270,18 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (* Return *) (pure_forward, pure_backwards) +(* TODO: factor out the return type *) let translate_crate_to_pure (crate : A.crate) : - trans_ctx * Pure.type_decl list * (bool * pure_fun_translation) list = + trans_ctx + * Pure.type_decl list + * pure_fun_translation list + * Pure.trait_decl list + * Pure.trait_impl list = (* Debug *) log#ldebug (lazy "translate_crate_to_pure"); - (* Compute the type and function contexts *) - let type_context, fun_context, global_context = - compute_type_fun_global_contexts crate - in - let fun_infos = - FA.analyze_module crate fun_context.C.fun_decls - global_context.C.global_decls !Config.use_state - in - let fun_context = { fun_decls = fun_context.fun_decls; fun_infos } in - let trans_ctx = { type_context; fun_context; global_context } in + (* Compute the translation context *) + let trans_ctx = compute_contexts crate in (* Translate all the type definitions *) let type_decls = @@ -323,10 +316,7 @@ let translate_crate_to_pure (crate : A.crate) : (A.FunDeclId.Map.values crate.functions) in let sigs = List.append assumed_sigs local_sigs in - let fun_sigs = - SymbolicToPure.translate_fun_signatures fun_context.fun_infos - type_context.type_infos sigs - in + let fun_sigs = SymbolicToPure.translate_fun_signatures trans_ctx sigs in (* Translate all the *transparent* functions *) let pure_translations = @@ -335,28 +325,38 @@ let translate_crate_to_pure (crate : A.crate) : (A.FunDeclId.Map.values crate.functions) in + (* Translate the trait declarations *) + let type_infos = trans_ctx.type_ctx.type_infos in + let trait_decls = + List.map + (SymbolicToPure.translate_trait_decl type_infos) + (T.TraitDeclId.Map.values trans_ctx.trait_decls_ctx.trait_decls) + in + + (* Translate the trait implementations *) + let trait_impls = + List.map + (SymbolicToPure.translate_trait_impl type_infos) + (T.TraitImplId.Map.values trans_ctx.trait_impls_ctx.trait_impls) + in + (* Apply the micro-passes *) let pure_translations = Micro.apply_passes_to_pure_fun_translations trans_ctx pure_translations in (* Return *) - (trans_ctx, type_decls, pure_translations) - -(** Extraction context *) -type gen_ctx = { - crate : A.crate; - extract_ctx : ExtractBase.extraction_ctx; - trans_types : Pure.type_decl Pure.TypeDeclId.Map.t; - trans_funs : (bool * pure_fun_translation) A.FunDeclId.Map.t; - functions_with_decreases_clause : PureUtils.FunLoopIdSet.t; -} + (trans_ctx, type_decls, pure_translations, trait_decls, trait_impls) + +type gen_ctx = ExtractBase.extraction_ctx type gen_config = { extract_types : bool; extract_decreases_clauses : bool; extract_template_decreases_clauses : bool; extract_fun_decls : bool; + extract_trait_decls : bool; + extract_trait_impls : bool; extract_transparent : bool; (** If [true], extract the transparent declarations, otherwise ignore. *) extract_opaque : bool; @@ -383,21 +383,23 @@ type gen_config = { test_trans_unit_functions : bool; } -(** Returns the pair: (has opaque type decls, has opaque fun decls) *) -let module_has_opaque_decls (ctx : gen_ctx) : bool * bool = - let has_opaque_types = - Pure.TypeDeclId.Map.exists - (fun _ (d : Pure.type_decl) -> - match d.kind with Opaque -> true | _ -> false) - ctx.trans_types - in - let has_opaque_funs = - A.FunDeclId.Map.exists - (fun _ ((_, ((t_fwd, _), _)) : bool * pure_fun_translation) -> - Option.is_none t_fwd.body) - ctx.trans_funs +(** Returns the pair: (has opaque type decls, has opaque fun decls). + + [filter_assumed]: if [true], do not consider as opaque the external definitions + that we will map to definitions from the standard library. + *) +let crate_has_opaque_decls (ctx : gen_ctx) (filter_assumed : bool) : bool * bool + = + let types, funs = + LlbcAstUtils.crate_get_opaque_decls ctx.crate filter_assumed in - (has_opaque_types, has_opaque_funs) + log#ldebug + (lazy + ("Opaque decls:" ^ "\n- types:\n" + ^ String.concat ",\n" (List.map T.show_type_decl types) + ^ "\n- functions:\n" + ^ String.concat ",\n" (List.map A.show_fun_decl funs))); + (types <> [], funs <> []) (** Export a type declaration. @@ -429,9 +431,9 @@ let export_type (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) || ((not is_opaque) && config.extract_transparent) then ( if extract_decl then - Extract.extract_type_decl ctx.extract_ctx fmt type_decl_group kind def; + Extract.extract_type_decl ctx fmt type_decl_group kind def; if extract_extra_info then - Extract.extract_type_decl_extra_info ctx.extract_ctx fmt kind def) + Extract.extract_type_decl_extra_info ctx fmt kind def) (** Export a group of types. @@ -483,7 +485,7 @@ let export_types_group (fmt : Format.formatter) (config : gen_config) End ]} *) - Extract.start_type_decl_group ctx.extract_ctx fmt is_rec defs; + Extract.start_type_decl_group ctx fmt is_rec defs; List.iteri (fun i def -> let kind = kind_from_index i in @@ -504,26 +506,34 @@ let export_types_group (fmt : Format.formatter) (config : gen_config) *) let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) (id : A.GlobalDeclId.id) : unit = - let global_decls = ctx.extract_ctx.trans_ctx.global_context.global_decls in + let global_decls = ctx.trans_ctx.global_ctx.global_decls in let global = A.GlobalDeclId.Map.find id global_decls in - let _, ((body, loop_fwds), body_backs) = - A.FunDeclId.Map.find global.body_id ctx.trans_funs - in - assert (body_backs = []); - assert (loop_fwds = []); + let trans = A.FunDeclId.Map.find global.body_id ctx.trans_funs in + assert (trans.fwd.loops = []); + assert (trans.backs = []); + let body = trans.fwd.f in let is_opaque = Option.is_none body.Pure.body in - if + (* Check if we extract the global *) + let extract = config.extract_globals && (((not is_opaque) && config.extract_transparent) || (is_opaque && config.extract_opaque)) - then + in + (* Check if it is an assumed global - if yes, we ignore it because we + map the definition to one in the standard library *) + let open ExtractAssumed in + let sname = name_to_simple_name global.name in + let extract = + extract && SimpleNameMap.find_opt sname assumed_globals_map = None + in + if extract then (* We don't wrap global declaration groups between calls to functions [{start, end}_global_decl_group] (which don't exist): global declaration groups are always singletons, so the [extract_global_decl] function takes care of generating the delimiters. *) - Extract.extract_global_decl ctx.extract_ctx fmt global body config.interface + Extract.extract_global_decl ctx fmt global body config.interface (** Utility. @@ -604,14 +614,13 @@ let export_functions_group_scc (fmt : Format.formatter) (config : gen_config) then Some (fun () -> - Extract.extract_fun_decl ctx.extract_ctx fmt kind has_decr_clause - def) + Extract.extract_fun_decl ctx fmt kind has_decr_clause def) else None) decls in let extract_defs = List.filter_map (fun x -> x) extract_defs in if extract_defs <> [] then ( - Extract.start_fun_decl_group ctx.extract_ctx fmt is_rec decls; + Extract.start_fun_decl_group ctx fmt is_rec decls; List.iter (fun f -> f ()) extract_defs; Extract.end_fun_decl_group fmt is_rec decls) @@ -621,7 +630,7 @@ let export_functions_group_scc (fmt : Format.formatter) (config : gen_config) check if the forward and backward functions are mutually recursive. *) let export_functions_group (fmt : Format.formatter) (config : gen_config) - (ctx : gen_ctx) (pure_ls : (bool * pure_fun_translation) list) : unit = + (ctx : gen_ctx) (pure_ls : pure_fun_translation list) : unit = (* Utility to check a function has a decrease clause *) let has_decreases_clause (def : Pure.fun_decl) : bool = PureUtils.FunLoopIdSet.mem (def.def_id, def.loop_id) @@ -631,7 +640,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) (* Extract the decrease clauses template bodies *) if config.extract_template_decreases_clauses then List.iter - (fun (_, ((fwd, loop_fwds), _)) -> + (fun { fwd; _ } -> (* We only generate decreases clauses for the forward functions, because the termination argument should only depend on the forward inputs. The backward functions thus use the same decreases clauses as the @@ -647,19 +656,18 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) if has_decr_clause then match !Config.backend with | Lean -> - Extract.extract_template_lean_termination_and_decreasing - ctx.extract_ctx fmt decl + Extract.extract_template_lean_termination_and_decreasing ctx fmt + decl | FStar -> - Extract.extract_template_fstar_decreases_clause ctx.extract_ctx - fmt decl + Extract.extract_template_fstar_decreases_clause ctx fmt decl | Coq -> raise (Failure "Coq doesn't have decreases/termination clauses") | HOL4 -> raise (Failure "HOL4 doesn't have decreases/termination clauses") in - extract_decrease fwd; - List.iter extract_decrease loop_fwds) + extract_decrease fwd.f; + List.iter extract_decrease fwd.loops) pure_ls; (* Concatenate the function definitions, filtering the useless forward @@ -667,15 +675,13 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) let decls = List.concat (List.map - (fun (keep_fwd, ((fwd, fwd_loops), (back_ls : fun_and_loops list))) -> - let fwd = if keep_fwd then List.append fwd_loops [ fwd ] else [] in - let back : Pure.fun_decl list = + (fun { keep_fwd; fwd; backs } -> + let fwd = if keep_fwd then List.append fwd.loops [ fwd.f ] else [] in + let backs : Pure.fun_decl list = List.concat - (List.map - (fun (back, loop_backs) -> List.append loop_backs [ back ]) - back_ls) + (List.map (fun back -> List.append back.loops [ back.f ]) backs) in - List.append fwd back) + List.append fwd backs) pure_ls) in @@ -693,11 +699,24 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) (* Insert unit tests if necessary *) if config.test_trans_unit_functions then List.iter - (fun (keep_fwd, ((fwd, _), _)) -> - if keep_fwd then - Extract.extract_unit_test_if_unit_fun ctx.extract_ctx fmt fwd) + (fun trans -> + if trans.keep_fwd then + Extract.extract_unit_test_if_unit_fun ctx fmt trans.fwd.f) pure_ls +(** Export a trait declaration. *) +let export_trait_decl (fmt : Format.formatter) (_config : gen_config) + (ctx : gen_ctx) (trait_decl_id : Pure.trait_decl_id) : unit = + let trait_decl = T.TraitDeclId.Map.find trait_decl_id ctx.trans_trait_decls in + let ctx = { ctx with trait_decl_id = Some trait_decl.def_id } in + Extract.extract_trait_decl ctx fmt trait_decl + +(** Export a trait implementation. *) +let export_trait_impl (fmt : Format.formatter) (_config : gen_config) + (ctx : gen_ctx) (trait_impl_id : Pure.trait_impl_id) : unit = + let trait_impl = T.TraitImplId.Map.find trait_impl_id ctx.trans_trait_impls in + Extract.extract_trait_impl ctx fmt trait_impl + (** A generic utility to generate the extracted definitions: as we may want to split the definitions between different files (or not), we can control what is precisely extracted. @@ -712,12 +731,14 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) let export_functions_group = export_functions_group fmt config ctx in let export_global = export_global fmt config ctx in let export_types_group = export_types_group fmt config ctx in + let export_trait_decl = export_trait_decl fmt config ctx in + let export_trait_impl = export_trait_impl fmt config ctx in let export_state_type () : unit = let kind = if config.interface then ExtractBase.Declared else ExtractBase.Assumed in - Extract.extract_state_type fmt ctx.extract_ctx kind + Extract.extract_state_type fmt ctx kind in let export_decl_group (dg : A.declaration_group) : unit = @@ -725,11 +746,18 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) | Type (NonRec id) -> if config.extract_types then export_types_group false [ id ] | Type (Rec ids) -> if config.extract_types then export_types_group true ids - | Fun (NonRec id) -> + | Fun (NonRec id) -> ( (* Lookup *) let pure_fun = A.FunDeclId.Map.find id ctx.trans_funs in - (* Translate *) - export_functions_group [ pure_fun ] + (* Special case: we skip trait method *declarations* (we will + extract their type directly in the records we generate for + the trait declarations themselves, there is no point in having + separate type definitions) *) + match pure_fun.fwd.f.Pure.kind with + | TraitMethodDecl _ -> () + | _ -> + (* Translate *) + export_functions_group [ pure_fun ]) | Fun (Rec ids) -> (* General case of mutually recursive functions *) (* Lookup *) @@ -739,11 +767,17 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) (* Translate *) export_functions_group pure_funs | Global id -> export_global id + | TraitDecl id -> + if config.extract_trait_decls && config.extract_transparent then + export_trait_decl id + | TraitImpl id -> + if config.extract_trait_impls && config.extract_transparent then + export_trait_impl id in (* If we need to export the state type: we try to export it after we defined * the type definitions, because if the user wants to define a model for the - * type, he might want to reuse those in the state type. + * type, they might want to reuse those in the state type. * More specifically: if we extract functions in the same file as the type, * we have no choice but to define the state type before the functions, * because they may reuse this state type: in this case, we define/declare @@ -764,7 +798,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) config.extract_opaque && config.extract_fun_decls && !Config.wrap_opaque_in_sig && - let _, opaque_funs = module_has_opaque_decls ctx in + let _, opaque_funs = crate_has_opaque_decls ctx true in opaque_funs in if wrap_in_sig then ( @@ -774,7 +808,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) if config.extract_transparent then "Definitions" else "OpaqueDefs" in Format.pp_print_break fmt 0 0; - Format.pp_open_vbox fmt ctx.extract_ctx.indent_incr; + Format.pp_open_vbox fmt ctx.indent_incr; Format.pp_print_string fmt ("structure " ^ struct_name ^ " where"); Format.pp_print_break fmt 0 0); List.iter export_decl_group ctx.crate.declarations; @@ -904,7 +938,9 @@ let extract_file (config : gen_config) (ctx : gen_ctx) (fi : extract_file_info) let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : unit = (* Translate the module to the pure AST *) - let trans_ctx, trans_types, trans_funs = translate_crate_to_pure crate in + let trans_ctx, trans_types, trans_funs, trans_trait_decls, trans_trait_impls = + translate_crate_to_pure crate + in (* Initialize the extraction context - for now we extract only to F*. * We initialize the names map by registering the keywords used in the @@ -921,36 +957,36 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : mk_formatter_and_names_map trans_ctx crate.name variant_concatenate_type_name in - (* Put everything in the context *) - let ctx = - { - ExtractBase.trans_ctx; - names_map; - unsafe_names_map = { id_to_name = ExtractBase.IdMap.empty }; - fmt; - indent_incr = 2; - use_opaque_pre = !Config.split_files; - use_dep_ite = !Config.backend = Lean && !Config.extract_decreases_clauses; - fun_name_info = PureUtils.RegularFunIdMap.empty; - } + let strict_names_map = + let open ExtractBase in + let ids = + List.filter + (fun (id, _) -> strict_collisions id) + (IdMap.bindings names_map.id_to_name) + in + let is_opaque = false in + List.fold_left + (* id_to_string: we shouldn't need to use it *) + (fun m (id, n) -> names_map_add show_id is_opaque id n m) + empty_names_map ids in (* We need to compute which functions are recursive, in order to know * whether we should generate a decrease clause or not. *) let rec_functions = List.map - (fun (_, ((fwd, loop_fwds), _)) -> - let fwd = - if fwd.Pure.signature.info.effect_info.is_rec then - [ (fwd.def_id, None) ] + (fun { fwd; _ } -> + let fwd_f = + if fwd.f.Pure.signature.info.effect_info.is_rec then + [ (fwd.f.def_id, None) ] else [] in let loop_fwds = List.map (fun (def : Pure.fun_decl) -> [ (def.def_id, def.loop_id) ]) - loop_fwds + fwd.loops in - fwd :: loop_fwds) + fwd_f :: loop_fwds) trans_funs in let rec_functions : PureUtils.fun_loop_id list = @@ -958,22 +994,70 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : in let rec_functions = PureUtils.FunLoopIdSet.of_list rec_functions in - (* Register unique names for all the top-level types, globals and functions. + (* Put the translated definitions in maps *) + let trans_types = + Pure.TypeDeclId.Map.of_list + (List.map (fun (d : Pure.type_decl) -> (d.def_id, d)) trans_types) + in + let trans_funs : pure_fun_translation A.FunDeclId.Map.t = + A.FunDeclId.Map.of_list + (List.map + (fun (trans : pure_fun_translation) -> (trans.fwd.f.def_id, trans)) + trans_funs) + in + + (* Put everything in the context *) + let ctx = + let trans_trait_decls = + T.TraitDeclId.Map.of_list + (List.map + (fun (d : Pure.trait_decl) -> (d.def_id, d)) + trans_trait_decls) + in + let trans_trait_impls = + T.TraitImplId.Map.of_list + (List.map + (fun (d : Pure.trait_impl) -> (d.def_id, d)) + trans_trait_impls) + in + { + ExtractBase.crate; + trans_ctx; + names_map; + unsafe_names_map = { id_to_name = ExtractBase.IdMap.empty }; + strict_names_map; + fmt; + indent_incr = 2; + use_opaque_pre = !Config.split_files; + use_dep_ite = !Config.backend = Lean && !Config.extract_decreases_clauses; + fun_name_info = PureUtils.RegularFunIdMap.empty; + trait_decl_id = None (* None by default *); + is_provided_method = false (* false by default *); + trans_trait_decls; + trans_trait_impls; + trans_types; + trans_funs; + functions_with_decreases_clause = rec_functions; + } + in + + (* Register unique names for all the top-level types, globals, functions... * Note that the order in which we generate the names doesn't matter: * we just need to generate a mapping from identifier to name, and make * sure there are no name clashes. *) let ctx = List.fold_left (fun ctx def -> Extract.extract_type_decl_register_names ctx def) - ctx trans_types + ctx + (Pure.TypeDeclId.Map.values trans_types) in let ctx = List.fold_left - (fun ctx (keep_fwd, defs) -> + (fun ctx (trans : pure_fun_translation) -> (* If requested by the user, register termination measures and decreases proofs for all the recursive functions *) - let fwd_def = fst (fst defs) in + let fwd_def = trans.fwd.f in let gen_decr_clause (def : Pure.fun_decl) = !Config.extract_decreases_clauses && PureUtils.FunLoopIdSet.mem @@ -984,10 +1068,9 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : * those are handled later *) let is_global = fwd_def.Pure.is_global_decl_body in if is_global then ctx - else - Extract.extract_fun_decl_register_names ctx keep_fwd gen_decr_clause - defs) - ctx trans_funs + else Extract.extract_fun_decl_register_names ctx gen_decr_clause trans) + ctx + (A.FunDeclId.Map.values trans_funs) in let ctx = @@ -995,6 +1078,16 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : (A.GlobalDeclId.Map.values crate.globals) in + let ctx = + List.fold_left Extract.extract_trait_decl_register_names ctx + trans_trait_decls + in + + let ctx = + List.fold_left Extract.extract_trait_impl_register_names ctx + trans_trait_impls + in + (* Open the output file *) (* First compute the filename by replacing the extension and converting the * case (rust module names are snake case) *) @@ -1023,19 +1116,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : (namespace, crate_name, Filename.concat dest_dir crate_name) in - (* Put the translated definitions in maps *) - let trans_types = - Pure.TypeDeclId.Map.of_list - (List.map (fun (d : Pure.type_decl) -> (d.def_id, d)) trans_types) - in - let trans_funs = - A.FunDeclId.Map.of_list - (List.map - (fun ((keep_fwd, (fd, bdl)) : bool * pure_fun_translation) -> - ((fst fd).def_id, (keep_fwd, (fd, bdl)))) - trans_funs) - in - let mkdir_if dest_dir = if not (Sys.file_exists dest_dir) then ( log#linfo (lazy ("Creating missing directory: " ^ dest_dir)); @@ -1091,16 +1171,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : in (* Extract the file(s) *) - let gen_ctx = - { - crate; - extract_ctx = ctx; - trans_types; - trans_funs; - functions_with_decreases_clause = rec_functions; - } - in - let module_delimiter = match !Config.backend with | FStar -> "." @@ -1136,6 +1206,8 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : extract_decreases_clauses = !Config.extract_decreases_clauses; extract_template_decreases_clauses = false; extract_fun_decls = false; + extract_trait_decls = false; + extract_trait_impls = false; extract_transparent = true; extract_opaque = false; extract_state_type = false; @@ -1147,7 +1219,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : (* Check if there are opaque types and functions - in which case we need * to split *) - let has_opaque_types, has_opaque_funs = module_has_opaque_decls gen_ctx in + let has_opaque_types, has_opaque_funs = crate_has_opaque_decls ctx true in let has_opaque_types = has_opaque_types || !Config.use_state in (* Extract the types *) @@ -1168,6 +1240,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : { base_gen_config with extract_types = true; + extract_trait_decls = true; extract_opaque = true; extract_state_type = !Config.use_state; interface = has_opaque_types; @@ -1186,7 +1259,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : custom_includes = []; } in - extract_file types_config gen_ctx file_info; + extract_file types_config ctx file_info; (* Extract the template clauses *) (if needs_clauses_module && !Config.extract_template_decreases_clauses then @@ -1214,9 +1287,9 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : custom_includes = []; } in - extract_file template_clauses_config gen_ctx file_info); + extract_file template_clauses_config ctx file_info); - (* Extract the opaque functions, if needed *) + (* Extract the opaque declarations, if needed *) let opaque_funs_module = if has_opaque_funs then ( (* In the case of Lean we generate a template file *) @@ -1244,17 +1317,14 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : { base_gen_config with extract_fun_decls = true; + extract_trait_impls = true; + extract_globals = true; extract_transparent = false; extract_opaque = true; interface = true; } in - let gen_ctx = - { - gen_ctx with - extract_ctx = { gen_ctx.extract_ctx with use_opaque_pre = false }; - } - in + let ctx = { ctx with use_opaque_pre = false } in let file_info = { filename = opaque_filename; @@ -1268,7 +1338,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : custom_includes = [ types_module ]; } in - extract_file opaque_config gen_ctx file_info; + extract_file opaque_config ctx file_info; (* Return the additional dependencies *) [ opaque_imported_module ]) else [] @@ -1281,6 +1351,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : { base_gen_config with extract_fun_decls = true; + extract_trait_impls = true; extract_globals = true; test_trans_unit_functions = !Config.test_trans_unit_functions; } @@ -1307,7 +1378,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : [ types_module ] @ opaque_funs_module @ clauses_module; } in - extract_file fun_config gen_ctx file_info) + extract_file fun_config ctx file_info) else let gen_config = { @@ -1316,6 +1387,8 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : extract_template_decreases_clauses = !Config.extract_template_decreases_clauses; extract_fun_decls = true; + extract_trait_decls = true; + extract_trait_impls = true; extract_transparent = true; extract_opaque = true; extract_state_type = !Config.use_state; @@ -1337,7 +1410,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : custom_includes = []; } in - extract_file gen_config gen_ctx file_info); + extract_file gen_config ctx file_info); (* Generate the build file *) match !Config.backend with diff --git a/compiler/TranslateCore.ml b/compiler/TranslateCore.ml index ba5e237b..3427fd43 100644 --- a/compiler/TranslateCore.ml +++ b/compiler/TranslateCore.ml @@ -10,64 +10,69 @@ module FA = FunsAnalysis (** The local logger *) let log = L.translate_log -type type_context = C.type_context [@@deriving show] - -type fun_context = { - fun_decls : A.fun_decl A.FunDeclId.Map.t; - fun_infos : FA.fun_info A.FunDeclId.Map.t; -} -[@@deriving show] +type trans_ctx = C.decls_ctx [@@deriving show] +type fun_and_loops = { f : Pure.fun_decl; loops : Pure.fun_decl list } +type pure_fun_translation_no_loops = Pure.fun_decl * Pure.fun_decl list -type global_context = C.global_context [@@deriving show] +type pure_fun_translation = { + keep_fwd : bool; + (** Should we extract the forward function? -type trans_ctx = { - type_context : type_context; - fun_context : fun_context; - global_context : global_context; + If the forward function returns `()` and there is exactly one + backward function, we may merge the forward into the backward + function and thus don't extract the forward function)? + *) + fwd : fun_and_loops; + backs : fun_and_loops list; } -type fun_and_loops = Pure.fun_decl * Pure.fun_decl list -type pure_fun_translation_no_loops = Pure.fun_decl * Pure.fun_decl list -type pure_fun_translation = fun_and_loops * fun_and_loops list +let trans_ctx_to_type_formatter (ctx : trans_ctx) + (type_params : Pure.type_var list) + (const_generic_params : Pure.const_generic_var list) : + PrintPure.type_formatter = + let type_decls = ctx.type_ctx.type_decls in + let global_decls = ctx.global_ctx.global_decls in + let trait_decls = ctx.trait_decls_ctx.trait_decls in + let trait_impls = ctx.trait_impls_ctx.trait_impls in + PrintPure.mk_type_formatter type_decls global_decls trait_decls trait_impls + type_params const_generic_params let type_decl_to_string (ctx : trans_ctx) (def : Pure.type_decl) : string = - let type_params = def.type_params in - let cg_params = def.const_generic_params in - let type_decls = ctx.type_context.type_decls in - let global_decls = ctx.global_context.global_decls in + let generics = def.generics in let fmt = - PrintPure.mk_type_formatter type_decls global_decls type_params cg_params + trans_ctx_to_type_formatter ctx generics.types generics.const_generics in PrintPure.type_decl_to_string fmt def let type_id_to_string (ctx : trans_ctx) (id : Pure.TypeDeclId.id) : string = Print.fun_name_to_string - (Pure.TypeDeclId.Map.find id ctx.type_context.type_decls).name + (Pure.TypeDeclId.Map.find id ctx.type_ctx.type_decls).name + +let trans_ctx_to_ast_formatter (ctx : trans_ctx) + (type_params : Pure.type_var list) + (const_generic_params : Pure.const_generic_var list) : + PrintPure.ast_formatter = + let type_decls = ctx.type_ctx.type_decls in + let fun_decls = ctx.fun_ctx.fun_decls in + let global_decls = ctx.global_ctx.global_decls in + let trait_decls = ctx.trait_decls_ctx.trait_decls in + let trait_impls = ctx.trait_impls_ctx.trait_impls in + PrintPure.mk_ast_formatter type_decls fun_decls global_decls trait_decls + trait_impls type_params const_generic_params let fun_sig_to_string (ctx : trans_ctx) (sg : Pure.fun_sig) : string = - let type_params = sg.type_params in - let cg_params = sg.const_generic_params in - let type_decls = ctx.type_context.type_decls in - let fun_decls = ctx.fun_context.fun_decls in - let global_decls = ctx.global_context.global_decls in + let generics = sg.generics in let fmt = - PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params - cg_params + trans_ctx_to_ast_formatter ctx generics.types generics.const_generics in PrintPure.fun_sig_to_string fmt sg let fun_decl_to_string (ctx : trans_ctx) (def : Pure.fun_decl) : string = - let type_params = def.signature.type_params in - let cg_params = def.signature.const_generic_params in - let type_decls = ctx.type_context.type_decls in - let fun_decls = ctx.fun_context.fun_decls in - let global_decls = ctx.global_context.global_decls in + let generics = def.signature.generics in let fmt = - PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params - cg_params + trans_ctx_to_ast_formatter ctx generics.types generics.const_generics in PrintPure.fun_decl_to_string fmt def let fun_decl_id_to_string (ctx : trans_ctx) (id : A.FunDeclId.id) : string = - Print.fun_name_to_string - (A.FunDeclId.Map.find id ctx.fun_context.fun_decls).name + Print.fun_name_to_string (A.FunDeclId.Map.find id ctx.fun_ctx.fun_decls).name diff --git a/compiler/TypesAnalysis.ml b/compiler/TypesAnalysis.ml index 925f6d39..95c7206a 100644 --- a/compiler/TypesAnalysis.ml +++ b/compiler/TypesAnalysis.ml @@ -14,11 +14,10 @@ type expl_info = subtype_info [@@deriving show] type type_borrows_info = { contains_static : bool; - (** Does the type (transitively) contains a static borrow? *) - contains_borrow : bool; - (** Does the type (transitively) contains a borrow? *) + (** Does the type (transitively) contain a static borrow? *) + contains_borrow : bool; (** Does the type (transitively) contain a borrow? *) contains_nested_borrows : bool; - (** Does the type (transitively) contains nested borrows? *) + (** Does the type (transitively) contain nested borrows? *) contains_borrow_under_mut : bool; } [@@deriving show] @@ -61,7 +60,7 @@ let initialize_g_type_info (param_infos : 'p) : 'p g_type_info = let initialize_type_decl_info (def : type_decl) : type_decl_info = let param_info = { under_borrow = false; under_mut_borrow = false } in - let param_infos = List.map (fun _ -> param_info) def.type_params in + let param_infos = List.map (fun _ -> param_info) def.generics.types in initialize_g_type_info param_infos let type_decl_info_to_partial_type_info (info : type_decl_info) : @@ -122,7 +121,7 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref) let rec analyze (expl_info : expl_info) (ty_info : partial_type_info) (ty : 'r ty) : partial_type_info = match ty with - | Literal _ | Never -> ty_info + | Literal _ | Never | TraitType _ -> ty_info | TypeVar var_id -> ( (* Update the information for the proper parameter, if necessary *) match ty_info.param_infos with @@ -171,20 +170,18 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref) analyze expl_info ty_info rty | Adt ( (Tuple | Assumed (Box | Vec | Option | Slice | Array | Str | Range)), - _, - tys, - _ ) -> + generics ) -> (* Nothing to update: just explore the type parameters *) List.fold_left (fun ty_info ty -> analyze expl_info ty_info ty) - ty_info tys - | Adt (AdtId adt_id, regions, tys, _cgs) -> + ty_info generics.types + | Adt (AdtId adt_id, generics) -> (* Lookup the information for this type definition *) let adt_info = TypeDeclId.Map.find adt_id infos in (* Update the type info with the information from the adt *) let ty_info = update_ty_info ty_info adt_info.borrows_info in (* Check if 'static appears in the region parameters *) - let found_static = List.exists r_is_static regions in + let found_static = List.exists r_is_static generics.regions in let borrows_info = ty_info.borrows_info in let borrows_info = { @@ -196,7 +193,7 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref) let ty_info = { ty_info with borrows_info } in (* For every instantiated type parameter: update the exploration info * then explore the type *) - let params_tys = List.combine adt_info.param_infos tys in + let params_tys = List.combine adt_info.param_infos generics.types in let ty_info = List.fold_left (fun ty_info (param_info, ty) -> diff --git a/compiler/Values.ml b/compiler/Values.ml index d884c319..de27e7a9 100644 --- a/compiler/Values.ml +++ b/compiler/Values.ml @@ -52,6 +52,10 @@ type sv_kind = (** The result of a loop join (when computing loop fixed points) *) | Aggregate (** A symbolic value we introduce in place of an aggregate value *) + | ConstGeneric + (** A symbolic value we introduce when using a const generic as a value *) + | TraitConst + (** A symbolic value we introduce when evaluating a trait associated constant *) [@@deriving show, ord] (** Ancestor for {!symbolic_value} iter visitor *) diff --git a/compiler/dune b/compiler/dune index 6785cad4..2f5a0a44 100644 --- a/compiler/dune +++ b/compiler/dune @@ -12,6 +12,7 @@ (pps ppx_deriving.show ppx_deriving.ord visitors.ppx)) (libraries charon core_unix unionFind ocamlgraph) (modules + AssociatedTypes Assumed Collections Config @@ -21,6 +22,7 @@ Expressions ExpressionsUtils Extract + ExtractAssumed ExtractBase FunsAnalysis Identifiers diff --git a/tests/coq/array/Array_Funs.v b/tests/coq/array/Array_Funs.v index 6d791873..6ff3066a 100644 --- a/tests/coq/array/Array_Funs.v +++ b/tests/coq/array/Array_Funs.v @@ -183,6 +183,10 @@ Definition index_index_array_fwd array_index_shared u32 32%usize a j . +(** [array::const_gen_ret]: forward function *) +Definition const_gen_ret_fwd (N : usize) : result usize := + Return N. + (** [array::update_update_array]: forward function *) Definition update_update_array_fwd (s : array (array u32 32%usize) 32%usize) (i : usize) (j : usize) : @@ -447,6 +451,15 @@ Definition f3_fwd (n : nat) : result u32 := sum2_fwd n s s0 . +(** [array::SZ] *) +Definition sz_body : result usize := Return 32%usize. +Definition sz_c : usize := sz_body%global. + +(** [array::f5]: forward function *) +Definition f5_fwd (x : array u32 32%usize) : result u32 := + array_index_shared u32 32%usize x 0%usize +. + (** [array::ite]: forward function *) Definition ite_fwd : result unit := s <- diff --git a/tests/coq/array/Primitives.v b/tests/coq/array/Primitives.v index 71a2d9c3..8d6c9c8d 100644 --- a/tests/coq/array/Primitives.v +++ b/tests/coq/array/Primitives.v @@ -394,6 +394,20 @@ Notation "x s< y" := (scalar_ltb x y) (at level 80) : Primitives_scope. Notation "x s>= y" := (scalar_geb x y) (at level 80) : Primitives_scope. Notation "x s> y" := (scalar_gtb x y) (at level 80) : Primitives_scope. +(** Constants *) +Definition core_u8_max := u8_max %u32. +Definition core_u16_max := u16_max %u32. +Definition core_u32_max := u32_max %u32. +Definition core_u64_max := u64_max %u64. +Definition core_u128_max := u64_max %u128. +Axiom core_usize_max : usize. (** TODO *) +Definition core_i8_max := i8_max %i32. +Definition core_i16_max := i16_max %i32. +Definition core_i32_max := i32_max %i32. +Definition core_i64_max := i64_max %i64. +Definition core_i128_max := i64_max %i128. +Axiom core_isize_max : isize. (** TODO *) + (*** Range *) Record range (T : Type) := mk_range { start: T; diff --git a/tests/coq/betree/BetreeMain_Funs.v b/tests/coq/betree/BetreeMain_Funs.v index 1e457433..cfa1f8fb 100644 --- a/tests/coq/betree/BetreeMain_Funs.v +++ b/tests/coq/betree/BetreeMain_Funs.v @@ -75,12 +75,6 @@ Definition betree_node_id_counter_fresh_id_back Return {| Betree_node_id_counter_next_node_id := i |} . -(** [core::num::u64::{9}::MAX] *) -Definition core_num_u64_max_body : result u64 := - Return 18446744073709551615%u64 -. -Definition core_num_u64_max_c : u64 := core_num_u64_max_body%global. - (** [betree_main::betree::upsert_update]: forward function *) Definition betree_upsert_update_fwd (prev : option u64) (st : Betree_upsert_fun_state_t) : result u64 := @@ -93,8 +87,8 @@ Definition betree_upsert_update_fwd | Some prev0 => match st with | BetreeUpsertFunStateAdd v => - margin <- u64_sub core_num_u64_max_c prev0; - if margin s>= v then u64_add prev0 v else Return core_num_u64_max_c + margin <- u64_sub core_u64_max prev0; + if margin s>= v then u64_add prev0 v else Return core_u64_max | BetreeUpsertFunStateSub v => if prev0 s>= v then u64_sub prev0 v else Return 0%u64 end diff --git a/tests/coq/betree/Primitives.v b/tests/coq/betree/Primitives.v index 71a2d9c3..8d6c9c8d 100644 --- a/tests/coq/betree/Primitives.v +++ b/tests/coq/betree/Primitives.v @@ -394,6 +394,20 @@ Notation "x s< y" := (scalar_ltb x y) (at level 80) : Primitives_scope. Notation "x s>= y" := (scalar_geb x y) (at level 80) : Primitives_scope. Notation "x s> y" := (scalar_gtb x y) (at level 80) : Primitives_scope. +(** Constants *) +Definition core_u8_max := u8_max %u32. +Definition core_u16_max := u16_max %u32. +Definition core_u32_max := u32_max %u32. +Definition core_u64_max := u64_max %u64. +Definition core_u128_max := u64_max %u128. +Axiom core_usize_max : usize. (** TODO *) +Definition core_i8_max := i8_max %i32. +Definition core_i16_max := i16_max %i32. +Definition core_i32_max := i32_max %i32. +Definition core_i64_max := i64_max %i64. +Definition core_i128_max := i64_max %i128. +Axiom core_isize_max : isize. (** TODO *) + (*** Range *) Record range (T : Type) := mk_range { start: T; diff --git a/tests/coq/hashmap/Hashmap_Funs.v b/tests/coq/hashmap/Hashmap_Funs.v index e950ba0b..054880d4 100644 --- a/tests/coq/hashmap/Hashmap_Funs.v +++ b/tests/coq/hashmap/Hashmap_Funs.v @@ -190,10 +190,6 @@ Definition hash_map_insert_no_resize_fwd_back |}) . -(** [core::num::u32::{8}::MAX] *) -Definition core_num_u32_max_body : result u32 := Return 4294967295%u32. -Definition core_num_u32_max_c : u32 := core_num_u32_max_body%global. - (** [hashmap::HashMap::{0}::move_elements_from_list]: loop 0: merged forward/backward function (there is a single backward function, and the forward function returns ()) *) Fixpoint hash_map_move_elements_from_list_loop_fwd_back @@ -259,7 +255,7 @@ Definition hash_map_move_elements_fwd_back (there is a single backward function, and the forward function returns ()) *) Definition hash_map_try_resize_fwd_back (T : Type) (n : nat) (self : Hash_map_t T) : result (Hash_map_t T) := - max_usize <- scalar_cast U32 Usize core_num_u32_max_c; + max_usize <- scalar_cast U32 Usize core_u32_max; let capacity := vec_len (List_t T) self.(Hash_map_slots) in n1 <- usize_div max_usize 2%usize; let (i, i0) := self.(Hash_map_max_load_factor) in diff --git a/tests/coq/hashmap/Primitives.v b/tests/coq/hashmap/Primitives.v index 71a2d9c3..8d6c9c8d 100644 --- a/tests/coq/hashmap/Primitives.v +++ b/tests/coq/hashmap/Primitives.v @@ -394,6 +394,20 @@ Notation "x s< y" := (scalar_ltb x y) (at level 80) : Primitives_scope. Notation "x s>= y" := (scalar_geb x y) (at level 80) : Primitives_scope. Notation "x s> y" := (scalar_gtb x y) (at level 80) : Primitives_scope. +(** Constants *) +Definition core_u8_max := u8_max %u32. +Definition core_u16_max := u16_max %u32. +Definition core_u32_max := u32_max %u32. +Definition core_u64_max := u64_max %u64. +Definition core_u128_max := u64_max %u128. +Axiom core_usize_max : usize. (** TODO *) +Definition core_i8_max := i8_max %i32. +Definition core_i16_max := i16_max %i32. +Definition core_i32_max := i32_max %i32. +Definition core_i64_max := i64_max %i64. +Definition core_i128_max := i64_max %i128. +Axiom core_isize_max : isize. (** TODO *) + (*** Range *) Record range (T : Type) := mk_range { start: T; diff --git a/tests/coq/hashmap_on_disk/HashmapMain_Funs.v b/tests/coq/hashmap_on_disk/HashmapMain_Funs.v index 657d5590..a85adbf2 100644 --- a/tests/coq/hashmap_on_disk/HashmapMain_Funs.v +++ b/tests/coq/hashmap_on_disk/HashmapMain_Funs.v @@ -208,10 +208,6 @@ Definition hashmap_hash_map_insert_no_resize_fwd_back |}) . -(** [core::num::u32::{8}::MAX] *) -Definition core_num_u32_max_body : result u32 := Return 4294967295%u32. -Definition core_num_u32_max_c : u32 := core_num_u32_max_body%global. - (** [hashmap_main::hashmap::HashMap::{0}::move_elements_from_list]: loop 0: merged forward/backward function (there is a single backward function, and the forward function returns ()) *) Fixpoint hashmap_hash_map_move_elements_from_list_loop_fwd_back @@ -282,7 +278,7 @@ Definition hashmap_hash_map_try_resize_fwd_back (T : Type) (n : nat) (self : Hashmap_hash_map_t T) : result (Hashmap_hash_map_t T) := - max_usize <- scalar_cast U32 Usize core_num_u32_max_c; + max_usize <- scalar_cast U32 Usize core_u32_max; let capacity := vec_len (Hashmap_list_t T) self.(Hashmap_hash_map_slots) in n1 <- usize_div max_usize 2%usize; let (i, i0) := self.(Hashmap_hash_map_max_load_factor) in diff --git a/tests/coq/hashmap_on_disk/Primitives.v b/tests/coq/hashmap_on_disk/Primitives.v index 71a2d9c3..8d6c9c8d 100644 --- a/tests/coq/hashmap_on_disk/Primitives.v +++ b/tests/coq/hashmap_on_disk/Primitives.v @@ -394,6 +394,20 @@ Notation "x s< y" := (scalar_ltb x y) (at level 80) : Primitives_scope. Notation "x s>= y" := (scalar_geb x y) (at level 80) : Primitives_scope. Notation "x s> y" := (scalar_gtb x y) (at level 80) : Primitives_scope. +(** Constants *) +Definition core_u8_max := u8_max %u32. +Definition core_u16_max := u16_max %u32. +Definition core_u32_max := u32_max %u32. +Definition core_u64_max := u64_max %u64. +Definition core_u128_max := u64_max %u128. +Axiom core_usize_max : usize. (** TODO *) +Definition core_i8_max := i8_max %i32. +Definition core_i16_max := i16_max %i32. +Definition core_i32_max := i32_max %i32. +Definition core_i64_max := i64_max %i64. +Definition core_i128_max := i64_max %i128. +Axiom core_isize_max : isize. (** TODO *) + (*** Range *) Record range (T : Type) := mk_range { start: T; diff --git a/tests/coq/misc/Constants.v b/tests/coq/misc/Constants.v index f1c32730..5dd78a09 100644 --- a/tests/coq/misc/Constants.v +++ b/tests/coq/misc/Constants.v @@ -12,12 +12,8 @@ Module Constants. Definition x0_body : result u32 := Return 0%u32. Definition x0_c : u32 := x0_body%global. -(** [core::num::u32::{8}::MAX] *) -Definition core_num_u32_max_body : result u32 := Return 4294967295%u32. -Definition core_num_u32_max_c : u32 := core_num_u32_max_body%global. - (** [constants::X1] *) -Definition x1_body : result u32 := Return core_num_u32_max_c. +Definition x1_body : result u32 := Return core_u32_max. Definition x1_c : u32 := x1_body%global. (** [constants::X2] *) diff --git a/tests/coq/misc/Primitives.v b/tests/coq/misc/Primitives.v index 71a2d9c3..8d6c9c8d 100644 --- a/tests/coq/misc/Primitives.v +++ b/tests/coq/misc/Primitives.v @@ -394,6 +394,20 @@ Notation "x s< y" := (scalar_ltb x y) (at level 80) : Primitives_scope. Notation "x s>= y" := (scalar_geb x y) (at level 80) : Primitives_scope. Notation "x s> y" := (scalar_gtb x y) (at level 80) : Primitives_scope. +(** Constants *) +Definition core_u8_max := u8_max %u32. +Definition core_u16_max := u16_max %u32. +Definition core_u32_max := u32_max %u32. +Definition core_u64_max := u64_max %u64. +Definition core_u128_max := u64_max %u128. +Axiom core_usize_max : usize. (** TODO *) +Definition core_i8_max := i8_max %i32. +Definition core_i16_max := i16_max %i32. +Definition core_i32_max := i32_max %i32. +Definition core_i64_max := i64_max %i64. +Definition core_i128_max := i64_max %i128. +Axiom core_isize_max : isize. (** TODO *) + (*** Range *) Record range (T : Type) := mk_range { start: T; diff --git a/tests/fstar/array/Array.Funs.fst b/tests/fstar/array/Array.Funs.fst index 7c1d0b09..83256398 100644 --- a/tests/fstar/array/Array.Funs.fst +++ b/tests/fstar/array/Array.Funs.fst @@ -139,6 +139,10 @@ let index_index_array_fwd let* a = array_index_shared (array u32 32) 32 s i in array_index_shared u32 32 a j +(** [array::const_gen_ret]: forward function *) +let const_gen_ret_fwd (n : usize) : result usize = + Return n + (** [array::update_update_array]: forward function *) let update_update_array_fwd (s : array (array u32 32) 32) (i : usize) (j : usize) : result unit = @@ -343,6 +347,14 @@ let f3_fwd : result u32 = ]) 16 18 in sum2_fwd s s0 +(** [array::SZ] *) +let sz_body : result usize = Return 32 +let sz_c : usize = eval_global sz_body + +(** [array::f5]: forward function *) +let f5_fwd (x : array u32 32) : result u32 = + array_index_shared u32 32 x 0 + (** [array::ite]: forward function *) let ite_fwd : result unit = let* s = array_to_slice_mut_fwd u32 2 (mk_array u32 2 [ 0; 0 ]) in diff --git a/tests/fstar/array/Primitives.fst b/tests/fstar/array/Primitives.fst index 9db82069..cd18cf29 100644 --- a/tests/fstar/array/Primitives.fst +++ b/tests/fstar/array/Primitives.fst @@ -169,17 +169,44 @@ let scalar_cast (src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : /// The scalar types type isize : eqtype = scalar Isize -type i8 : eqtype = scalar I8 -type i16 : eqtype = scalar I16 -type i32 : eqtype = scalar I32 -type i64 : eqtype = scalar I64 -type i128 : eqtype = scalar I128 +type i8 : eqtype = scalar I8 +type i16 : eqtype = scalar I16 +type i32 : eqtype = scalar I32 +type i64 : eqtype = scalar I64 +type i128 : eqtype = scalar I128 type usize : eqtype = scalar Usize -type u8 : eqtype = scalar U8 -type u16 : eqtype = scalar U16 -type u32 : eqtype = scalar U32 -type u64 : eqtype = scalar U64 -type u128 : eqtype = scalar U128 +type u8 : eqtype = scalar U8 +type u16 : eqtype = scalar U16 +type u32 : eqtype = scalar U32 +type u64 : eqtype = scalar U64 +type u128 : eqtype = scalar U128 + + +let core_isize_min : isize = isize_min +let core_isize_max : isize = isize_max +let core_i8_min : i8 = i8_min +let core_i8_max : i8 = i8_max +let core_i16_min : i16 = i16_min +let core_i16_max : i16 = i16_max +let core_i32_min : i32 = i32_min +let core_i32_max : i32 = i32_max +let core_i64_min : i64 = i64_min +let core_i64_max : i64 = i64_max +let core_i128_min : i128 = i128_min +let core_i128_max : i128 = i128_max + +let core_usize_min : usize = usize_min +let core_usize_max : usize = usize_max +let core_u8_min : u8 = u8_min +let core_u8_max : u8 = u8_max +let core_u16_min : u16 = u16_min +let core_u16_max : u16 = u16_max +let core_u32_min : u32 = u32_min +let core_u32_max : u32 = u32_max +let core_u64_min : u64 = u64_min +let core_u64_max : u64 = u64_max +let core_u128_min : u128 = u128_min +let core_u128_max : u128 = u128_max /// Negation let isize_neg = scalar_neg #Isize diff --git a/tests/fstar/betree/BetreeMain.Funs.fst b/tests/fstar/betree/BetreeMain.Funs.fst index 847dc865..2bb2352b 100644 --- a/tests/fstar/betree/BetreeMain.Funs.fst +++ b/tests/fstar/betree/BetreeMain.Funs.fst @@ -60,10 +60,6 @@ let betree_node_id_counter_fresh_id_back let* i = u64_add self.betree_node_id_counter_next_node_id 1 in Return { betree_node_id_counter_next_node_id = i } -(** [core::num::u64::{9}::MAX] *) -let core_num_u64_max_body : result u64 = Return 18446744073709551615 -let core_num_u64_max_c : u64 = eval_global core_num_u64_max_body - (** [betree_main::betree::upsert_update]: forward function *) let betree_upsert_update_fwd (prev : option u64) (st : betree_upsert_fun_state_t) : result u64 = @@ -76,8 +72,8 @@ let betree_upsert_update_fwd | Some prev0 -> begin match st with | BetreeUpsertFunStateAdd v -> - let* margin = u64_sub core_num_u64_max_c prev0 in - if margin >= v then u64_add prev0 v else Return core_num_u64_max_c + let* margin = u64_sub core_u64_max prev0 in + if margin >= v then u64_add prev0 v else Return core_u64_max | BetreeUpsertFunStateSub v -> if prev0 >= v then u64_sub prev0 v else Return 0 end diff --git a/tests/fstar/betree/Primitives.fst b/tests/fstar/betree/Primitives.fst index 9db82069..cd18cf29 100644 --- a/tests/fstar/betree/Primitives.fst +++ b/tests/fstar/betree/Primitives.fst @@ -169,17 +169,44 @@ let scalar_cast (src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : /// The scalar types type isize : eqtype = scalar Isize -type i8 : eqtype = scalar I8 -type i16 : eqtype = scalar I16 -type i32 : eqtype = scalar I32 -type i64 : eqtype = scalar I64 -type i128 : eqtype = scalar I128 +type i8 : eqtype = scalar I8 +type i16 : eqtype = scalar I16 +type i32 : eqtype = scalar I32 +type i64 : eqtype = scalar I64 +type i128 : eqtype = scalar I128 type usize : eqtype = scalar Usize -type u8 : eqtype = scalar U8 -type u16 : eqtype = scalar U16 -type u32 : eqtype = scalar U32 -type u64 : eqtype = scalar U64 -type u128 : eqtype = scalar U128 +type u8 : eqtype = scalar U8 +type u16 : eqtype = scalar U16 +type u32 : eqtype = scalar U32 +type u64 : eqtype = scalar U64 +type u128 : eqtype = scalar U128 + + +let core_isize_min : isize = isize_min +let core_isize_max : isize = isize_max +let core_i8_min : i8 = i8_min +let core_i8_max : i8 = i8_max +let core_i16_min : i16 = i16_min +let core_i16_max : i16 = i16_max +let core_i32_min : i32 = i32_min +let core_i32_max : i32 = i32_max +let core_i64_min : i64 = i64_min +let core_i64_max : i64 = i64_max +let core_i128_min : i128 = i128_min +let core_i128_max : i128 = i128_max + +let core_usize_min : usize = usize_min +let core_usize_max : usize = usize_max +let core_u8_min : u8 = u8_min +let core_u8_max : u8 = u8_max +let core_u16_min : u16 = u16_min +let core_u16_max : u16 = u16_max +let core_u32_min : u32 = u32_min +let core_u32_max : u32 = u32_max +let core_u64_min : u64 = u64_min +let core_u64_max : u64 = u64_max +let core_u128_min : u128 = u128_min +let core_u128_max : u128 = u128_max /// Negation let isize_neg = scalar_neg #Isize diff --git a/tests/fstar/betree_back_stateful/BetreeMain.Funs.fst b/tests/fstar/betree_back_stateful/BetreeMain.Funs.fst index 3d08cd3c..8083ee8f 100644 --- a/tests/fstar/betree_back_stateful/BetreeMain.Funs.fst +++ b/tests/fstar/betree_back_stateful/BetreeMain.Funs.fst @@ -60,10 +60,6 @@ let betree_node_id_counter_fresh_id_back let* i = u64_add self.betree_node_id_counter_next_node_id 1 in Return { betree_node_id_counter_next_node_id = i } -(** [core::num::u64::{9}::MAX] *) -let core_num_u64_max_body : result u64 = Return 18446744073709551615 -let core_num_u64_max_c : u64 = eval_global core_num_u64_max_body - (** [betree_main::betree::upsert_update]: forward function *) let betree_upsert_update_fwd (prev : option u64) (st : betree_upsert_fun_state_t) : result u64 = @@ -76,8 +72,8 @@ let betree_upsert_update_fwd | Some prev0 -> begin match st with | BetreeUpsertFunStateAdd v -> - let* margin = u64_sub core_num_u64_max_c prev0 in - if margin >= v then u64_add prev0 v else Return core_num_u64_max_c + let* margin = u64_sub core_u64_max prev0 in + if margin >= v then u64_add prev0 v else Return core_u64_max | BetreeUpsertFunStateSub v -> if prev0 >= v then u64_sub prev0 v else Return 0 end diff --git a/tests/fstar/betree_back_stateful/Primitives.fst b/tests/fstar/betree_back_stateful/Primitives.fst index 9db82069..cd18cf29 100644 --- a/tests/fstar/betree_back_stateful/Primitives.fst +++ b/tests/fstar/betree_back_stateful/Primitives.fst @@ -169,17 +169,44 @@ let scalar_cast (src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : /// The scalar types type isize : eqtype = scalar Isize -type i8 : eqtype = scalar I8 -type i16 : eqtype = scalar I16 -type i32 : eqtype = scalar I32 -type i64 : eqtype = scalar I64 -type i128 : eqtype = scalar I128 +type i8 : eqtype = scalar I8 +type i16 : eqtype = scalar I16 +type i32 : eqtype = scalar I32 +type i64 : eqtype = scalar I64 +type i128 : eqtype = scalar I128 type usize : eqtype = scalar Usize -type u8 : eqtype = scalar U8 -type u16 : eqtype = scalar U16 -type u32 : eqtype = scalar U32 -type u64 : eqtype = scalar U64 -type u128 : eqtype = scalar U128 +type u8 : eqtype = scalar U8 +type u16 : eqtype = scalar U16 +type u32 : eqtype = scalar U32 +type u64 : eqtype = scalar U64 +type u128 : eqtype = scalar U128 + + +let core_isize_min : isize = isize_min +let core_isize_max : isize = isize_max +let core_i8_min : i8 = i8_min +let core_i8_max : i8 = i8_max +let core_i16_min : i16 = i16_min +let core_i16_max : i16 = i16_max +let core_i32_min : i32 = i32_min +let core_i32_max : i32 = i32_max +let core_i64_min : i64 = i64_min +let core_i64_max : i64 = i64_max +let core_i128_min : i128 = i128_min +let core_i128_max : i128 = i128_max + +let core_usize_min : usize = usize_min +let core_usize_max : usize = usize_max +let core_u8_min : u8 = u8_min +let core_u8_max : u8 = u8_max +let core_u16_min : u16 = u16_min +let core_u16_max : u16 = u16_max +let core_u32_min : u32 = u32_min +let core_u32_max : u32 = u32_max +let core_u64_min : u64 = u64_min +let core_u64_max : u64 = u64_max +let core_u128_min : u128 = u128_min +let core_u128_max : u128 = u128_max /// Negation let isize_neg = scalar_neg #Isize diff --git a/tests/fstar/hashmap/Hashmap.Funs.fst b/tests/fstar/hashmap/Hashmap.Funs.fst index f4c13a7b..40cd0477 100644 --- a/tests/fstar/hashmap/Hashmap.Funs.fst +++ b/tests/fstar/hashmap/Hashmap.Funs.fst @@ -139,10 +139,6 @@ let hash_map_insert_no_resize_fwd_back let* v = vec_index_mut_back (list_t t) self.hash_map_slots hash_mod l0 in Return { self with hash_map_slots = v } -(** [core::num::u32::{8}::MAX] *) -let core_num_u32_max_body : result u32 = Return 4294967295 -let core_num_u32_max_c : u32 = eval_global core_num_u32_max_body - (** [hashmap::HashMap::{0}::move_elements_from_list]: loop 0: merged forward/backward function (there is a single backward function, and the forward function returns ()) *) let rec hash_map_move_elements_from_list_loop_fwd_back @@ -194,7 +190,7 @@ let hash_map_move_elements_fwd_back (there is a single backward function, and the forward function returns ()) *) let hash_map_try_resize_fwd_back (t : Type0) (self : hash_map_t t) : result (hash_map_t t) = - let* max_usize = scalar_cast U32 Usize core_num_u32_max_c in + let* max_usize = scalar_cast U32 Usize core_u32_max in let capacity = vec_len (list_t t) self.hash_map_slots in let* n1 = usize_div max_usize 2 in let (i, i0) = self.hash_map_max_load_factor in diff --git a/tests/fstar/hashmap/Primitives.fst b/tests/fstar/hashmap/Primitives.fst index 9db82069..cd18cf29 100644 --- a/tests/fstar/hashmap/Primitives.fst +++ b/tests/fstar/hashmap/Primitives.fst @@ -169,17 +169,44 @@ let scalar_cast (src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : /// The scalar types type isize : eqtype = scalar Isize -type i8 : eqtype = scalar I8 -type i16 : eqtype = scalar I16 -type i32 : eqtype = scalar I32 -type i64 : eqtype = scalar I64 -type i128 : eqtype = scalar I128 +type i8 : eqtype = scalar I8 +type i16 : eqtype = scalar I16 +type i32 : eqtype = scalar I32 +type i64 : eqtype = scalar I64 +type i128 : eqtype = scalar I128 type usize : eqtype = scalar Usize -type u8 : eqtype = scalar U8 -type u16 : eqtype = scalar U16 -type u32 : eqtype = scalar U32 -type u64 : eqtype = scalar U64 -type u128 : eqtype = scalar U128 +type u8 : eqtype = scalar U8 +type u16 : eqtype = scalar U16 +type u32 : eqtype = scalar U32 +type u64 : eqtype = scalar U64 +type u128 : eqtype = scalar U128 + + +let core_isize_min : isize = isize_min +let core_isize_max : isize = isize_max +let core_i8_min : i8 = i8_min +let core_i8_max : i8 = i8_max +let core_i16_min : i16 = i16_min +let core_i16_max : i16 = i16_max +let core_i32_min : i32 = i32_min +let core_i32_max : i32 = i32_max +let core_i64_min : i64 = i64_min +let core_i64_max : i64 = i64_max +let core_i128_min : i128 = i128_min +let core_i128_max : i128 = i128_max + +let core_usize_min : usize = usize_min +let core_usize_max : usize = usize_max +let core_u8_min : u8 = u8_min +let core_u8_max : u8 = u8_max +let core_u16_min : u16 = u16_min +let core_u16_max : u16 = u16_max +let core_u32_min : u32 = u32_min +let core_u32_max : u32 = u32_max +let core_u64_min : u64 = u64_min +let core_u64_max : u64 = u64_max +let core_u128_min : u128 = u128_min +let core_u128_max : u128 = u128_max /// Negation let isize_neg = scalar_neg #Isize diff --git a/tests/fstar/hashmap_on_disk/HashmapMain.Funs.fst b/tests/fstar/hashmap_on_disk/HashmapMain.Funs.fst index 1c94209c..5af90bd8 100644 --- a/tests/fstar/hashmap_on_disk/HashmapMain.Funs.fst +++ b/tests/fstar/hashmap_on_disk/HashmapMain.Funs.fst @@ -157,10 +157,6 @@ let hashmap_hash_map_insert_no_resize_fwd_back hash_mod l0 in Return { self with hashmap_hash_map_slots = v } -(** [core::num::u32::{8}::MAX] *) -let core_num_u32_max_body : result u32 = Return 4294967295 -let core_num_u32_max_c : u32 = eval_global core_num_u32_max_body - (** [hashmap_main::hashmap::HashMap::{0}::move_elements_from_list]: loop 0: merged forward/backward function (there is a single backward function, and the forward function returns ()) *) let rec hashmap_hash_map_move_elements_from_list_loop_fwd_back @@ -218,7 +214,7 @@ let hashmap_hash_map_move_elements_fwd_back (there is a single backward function, and the forward function returns ()) *) let hashmap_hash_map_try_resize_fwd_back (t : Type0) (self : hashmap_hash_map_t t) : result (hashmap_hash_map_t t) = - let* max_usize = scalar_cast U32 Usize core_num_u32_max_c in + let* max_usize = scalar_cast U32 Usize core_u32_max in let capacity = vec_len (hashmap_list_t t) self.hashmap_hash_map_slots in let* n1 = usize_div max_usize 2 in let (i, i0) = self.hashmap_hash_map_max_load_factor in diff --git a/tests/fstar/hashmap_on_disk/Primitives.fst b/tests/fstar/hashmap_on_disk/Primitives.fst index 9db82069..cd18cf29 100644 --- a/tests/fstar/hashmap_on_disk/Primitives.fst +++ b/tests/fstar/hashmap_on_disk/Primitives.fst @@ -169,17 +169,44 @@ let scalar_cast (src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : /// The scalar types type isize : eqtype = scalar Isize -type i8 : eqtype = scalar I8 -type i16 : eqtype = scalar I16 -type i32 : eqtype = scalar I32 -type i64 : eqtype = scalar I64 -type i128 : eqtype = scalar I128 +type i8 : eqtype = scalar I8 +type i16 : eqtype = scalar I16 +type i32 : eqtype = scalar I32 +type i64 : eqtype = scalar I64 +type i128 : eqtype = scalar I128 type usize : eqtype = scalar Usize -type u8 : eqtype = scalar U8 -type u16 : eqtype = scalar U16 -type u32 : eqtype = scalar U32 -type u64 : eqtype = scalar U64 -type u128 : eqtype = scalar U128 +type u8 : eqtype = scalar U8 +type u16 : eqtype = scalar U16 +type u32 : eqtype = scalar U32 +type u64 : eqtype = scalar U64 +type u128 : eqtype = scalar U128 + + +let core_isize_min : isize = isize_min +let core_isize_max : isize = isize_max +let core_i8_min : i8 = i8_min +let core_i8_max : i8 = i8_max +let core_i16_min : i16 = i16_min +let core_i16_max : i16 = i16_max +let core_i32_min : i32 = i32_min +let core_i32_max : i32 = i32_max +let core_i64_min : i64 = i64_min +let core_i64_max : i64 = i64_max +let core_i128_min : i128 = i128_min +let core_i128_max : i128 = i128_max + +let core_usize_min : usize = usize_min +let core_usize_max : usize = usize_max +let core_u8_min : u8 = u8_min +let core_u8_max : u8 = u8_max +let core_u16_min : u16 = u16_min +let core_u16_max : u16 = u16_max +let core_u32_min : u32 = u32_min +let core_u32_max : u32 = u32_max +let core_u64_min : u64 = u64_min +let core_u64_max : u64 = u64_max +let core_u128_min : u128 = u128_min +let core_u128_max : u128 = u128_max /// Negation let isize_neg = scalar_neg #Isize diff --git a/tests/fstar/misc/Constants.fst b/tests/fstar/misc/Constants.fst index d2b0415e..7dfb6f36 100644 --- a/tests/fstar/misc/Constants.fst +++ b/tests/fstar/misc/Constants.fst @@ -9,12 +9,8 @@ open Primitives let x0_body : result u32 = Return 0 let x0_c : u32 = eval_global x0_body -(** [core::num::u32::{8}::MAX] *) -let core_num_u32_max_body : result u32 = Return 4294967295 -let core_num_u32_max_c : u32 = eval_global core_num_u32_max_body - (** [constants::X1] *) -let x1_body : result u32 = Return core_num_u32_max_c +let x1_body : result u32 = Return core_u32_max let x1_c : u32 = eval_global x1_body (** [constants::X2] *) diff --git a/tests/fstar/misc/Primitives.fst b/tests/fstar/misc/Primitives.fst index 9db82069..cd18cf29 100644 --- a/tests/fstar/misc/Primitives.fst +++ b/tests/fstar/misc/Primitives.fst @@ -169,17 +169,44 @@ let scalar_cast (src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : /// The scalar types type isize : eqtype = scalar Isize -type i8 : eqtype = scalar I8 -type i16 : eqtype = scalar I16 -type i32 : eqtype = scalar I32 -type i64 : eqtype = scalar I64 -type i128 : eqtype = scalar I128 +type i8 : eqtype = scalar I8 +type i16 : eqtype = scalar I16 +type i32 : eqtype = scalar I32 +type i64 : eqtype = scalar I64 +type i128 : eqtype = scalar I128 type usize : eqtype = scalar Usize -type u8 : eqtype = scalar U8 -type u16 : eqtype = scalar U16 -type u32 : eqtype = scalar U32 -type u64 : eqtype = scalar U64 -type u128 : eqtype = scalar U128 +type u8 : eqtype = scalar U8 +type u16 : eqtype = scalar U16 +type u32 : eqtype = scalar U32 +type u64 : eqtype = scalar U64 +type u128 : eqtype = scalar U128 + + +let core_isize_min : isize = isize_min +let core_isize_max : isize = isize_max +let core_i8_min : i8 = i8_min +let core_i8_max : i8 = i8_max +let core_i16_min : i16 = i16_min +let core_i16_max : i16 = i16_max +let core_i32_min : i32 = i32_min +let core_i32_max : i32 = i32_max +let core_i64_min : i64 = i64_min +let core_i64_max : i64 = i64_max +let core_i128_min : i128 = i128_min +let core_i128_max : i128 = i128_max + +let core_usize_min : usize = usize_min +let core_usize_max : usize = usize_max +let core_u8_min : u8 = u8_min +let core_u8_max : u8 = u8_max +let core_u16_min : u16 = u16_min +let core_u16_max : u16 = u16_max +let core_u32_min : u32 = u32_min +let core_u32_max : u32 = u32_max +let core_u64_min : u64 = u64_min +let core_u64_max : u64 = u64_max +let core_u128_min : u128 = u128_min +let core_u128_max : u128 = u128_max /// Negation let isize_neg = scalar_neg #Isize diff --git a/tests/hol4/betree/betreeMain_FunsScript.sml b/tests/hol4/betree/betreeMain_FunsScript.sml index 5e604f8c..bd16c16c 100644 --- a/tests/hol4/betree/betreeMain_FunsScript.sml +++ b/tests/hol4/betree/betreeMain_FunsScript.sml @@ -88,14 +88,6 @@ val betree_node_id_counter_fresh_id_back_def = Define ‘ od ’ -(** [core::num::u64::{9}::MAX] *) -Definition core_num_u64_max_body_def: - core_num_u64_max_body : u64 result = Return (int_to_u64 18446744073709551615) -End -Definition core_num_u64_max_c_def: - core_num_u64_max_c : u64 = get_return_value core_num_u64_max_body -End - val betree_upsert_update_fwd_def = Define ‘ (** [betree_main::betree::upsert_update]: forward function *) betree_upsert_update_fwd @@ -109,8 +101,8 @@ val betree_upsert_update_fwd_def = Define ‘ (case st of | BetreeUpsertFunStateAdd v => do - margin <- u64_sub core_num_u64_max_c prev0; - if u64_ge margin v then u64_add prev0 v else Return core_num_u64_max_c + margin <- u64_sub core_u64_max prev0; + if u64_ge margin v then u64_add prev0 v else Return core_u64_max od | BetreeUpsertFunStateSub v => if u64_ge prev0 v then u64_sub prev0 v else Return (int_to_u64 0))) diff --git a/tests/hol4/betree/betreeMain_FunsTheory.sig b/tests/hol4/betree/betreeMain_FunsTheory.sig index 6c249f70..c922ca9f 100644 --- a/tests/hol4/betree/betreeMain_FunsTheory.sig +++ b/tests/hol4/betree/betreeMain_FunsTheory.sig @@ -58,8 +58,6 @@ sig val betree_store_internal_node_fwd_def : thm val betree_store_leaf_node_fwd_def : thm val betree_upsert_update_fwd_def : thm - val core_num_u64_max_body_def : thm - val core_num_u64_max_c_def : thm val main_fwd_def : thm val betreeMain_Funs_grammars : type_grammar.grammar * term_grammar.grammar @@ -1215,22 +1213,14 @@ sig case st of BetreeUpsertFunStateAdd v => do - margin <- u64_sub core_num_u64_max_c prev0; + margin <- u64_sub core_u64_max prev0; if u64_ge margin v then u64_add prev0 v - else Return core_num_u64_max_c + else Return core_u64_max od | BetreeUpsertFunStateSub v' => if u64_ge prev0 v' then u64_sub prev0 v' else Return (int_to_u64 0) - [core_num_u64_max_body_def] Definition - - ⊢ core_num_u64_max_body = Return (int_to_u64 18446744073709551615) - - [core_num_u64_max_c_def] Definition - - ⊢ core_num_u64_max_c = get_return_value core_num_u64_max_body - [main_fwd_def] Definition ⊢ main_fwd = Return () diff --git a/tests/hol4/hashmap/hashmap_FunsScript.sml b/tests/hol4/hashmap/hashmap_FunsScript.sml index e3c3d2a5..682c5760 100644 --- a/tests/hol4/hashmap/hashmap_FunsScript.sml +++ b/tests/hol4/hashmap/hashmap_FunsScript.sml @@ -170,14 +170,6 @@ val hash_map_insert_no_resize_fwd_back_def = Define ‘ od ’ -(** [core::num::u32::{8}::MAX] *) -Definition core_num_u32_max_body_def: - core_num_u32_max_body : u32 result = Return (int_to_u32 4294967295) -End -Definition core_num_u32_max_c_def: - core_num_u32_max_c : u32 = get_return_value core_num_u32_max_body -End - val [hash_map_move_elements_from_list_loop_fwd_back_def] = DefineDiv ‘ (** [hashmap::HashMap::{0}::move_elements_from_list]: loop 0: merged forward/backward function (there is a single backward function, and the forward function returns ()) *) @@ -241,7 +233,7 @@ val hash_map_try_resize_fwd_back_def = Define ‘ (there is a single backward function, and the forward function returns ()) *) hash_map_try_resize_fwd_back (self : 't hash_map_t) : 't hash_map_t result = do - max_usize <- mk_usize (u32_to_int core_num_u32_max_c); + max_usize <- mk_usize (u32_to_int core_u32_max); let capacity = vec_len self.hash_map_slots in do n1 <- usize_div max_usize (int_to_usize 2); diff --git a/tests/hol4/hashmap/hashmap_FunsTheory.sig b/tests/hol4/hashmap/hashmap_FunsTheory.sig index 50482547..bb3e192b 100644 --- a/tests/hol4/hashmap/hashmap_FunsTheory.sig +++ b/tests/hol4/hashmap/hashmap_FunsTheory.sig @@ -3,8 +3,6 @@ sig type thm = Thm.thm (* Definitions *) - val core_num_u32_max_body_def : thm - val core_num_u32_max_c_def : thm val hash_key_fwd_def : thm val hash_map_allocate_slots_fwd_def : thm val hash_map_allocate_slots_loop_fwd_def : thm @@ -48,14 +46,6 @@ sig (* [hashmap_Types] Parent theory of "hashmap_Funs" - [core_num_u32_max_body_def] Definition - - ⊢ core_num_u32_max_body = Return (int_to_u32 4294967295) - - [core_num_u32_max_c_def] Definition - - ⊢ core_num_u32_max_c = get_return_value core_num_u32_max_body - [hash_key_fwd_def] Definition ⊢ ∀k. hash_key_fwd k = Return k @@ -472,7 +462,7 @@ sig ⊢ ∀self. hash_map_try_resize_fwd_back self = do - max_usize <- mk_usize (u32_to_int core_num_u32_max_c); + max_usize <- mk_usize (u32_to_int core_u32_max); capacity <<- vec_len self.hash_map_slots; n1 <- usize_div max_usize (int_to_usize 2); (i,i0) <<- self.hash_map_max_load_factor; diff --git a/tests/hol4/hashmap/hashmap_PropertiesScript.sml b/tests/hol4/hashmap/hashmap_PropertiesScript.sml index 7259f2f5..8bc12fa5 100644 --- a/tests/hol4/hashmap/hashmap_PropertiesScript.sml +++ b/tests/hol4/hashmap/hashmap_PropertiesScript.sml @@ -1296,7 +1296,7 @@ Proof rw [hash_map_try_resize_fwd_back_def] >> (* “_ <-- mk_usize (u32_to_int core_num_u32_max_c)” *) assume_tac usize_u32_bounds >> - fs [core_num_u32_max_c_def, core_num_u32_max_body_def, get_return_value_def, u32_max_def] >> + fs [core_u32_max_def, u32_max_def] >> massage >> fs [mk_usize_def, u32_max_def] >> (* / 2 *) progress >> diff --git a/tests/hol4/hashmap_on_disk/hashmapMain_FunsScript.sml b/tests/hol4/hashmap_on_disk/hashmapMain_FunsScript.sml index b21c4f58..c1e30aa6 100644 --- a/tests/hol4/hashmap_on_disk/hashmapMain_FunsScript.sml +++ b/tests/hol4/hashmap_on_disk/hashmapMain_FunsScript.sml @@ -193,14 +193,6 @@ val hashmap_hash_map_insert_no_resize_fwd_back_def = Define ‘ od ’ -(** [core::num::u32::{8}::MAX] *) -Definition core_num_u32_max_body_def: - core_num_u32_max_body : u32 result = Return (int_to_u32 4294967295) -End -Definition core_num_u32_max_c_def: - core_num_u32_max_c : u32 = get_return_value core_num_u32_max_body -End - val [hashmap_hash_map_move_elements_from_list_loop_fwd_back_def] = DefineDiv ‘ (** [hashmap_main::hashmap::HashMap::{0}::move_elements_from_list]: loop 0: merged forward/backward function (there is a single backward function, and the forward function returns ()) *) @@ -271,7 +263,7 @@ val hashmap_hash_map_try_resize_fwd_back_def = Define ‘ hashmap_hash_map_try_resize_fwd_back (self : 't hashmap_hash_map_t) : 't hashmap_hash_map_t result = do - max_usize <- mk_usize (u32_to_int core_num_u32_max_c); + max_usize <- mk_usize (u32_to_int core_u32_max); let capacity = vec_len self.hashmap_hash_map_slots in do n1 <- usize_div max_usize (int_to_usize 2); diff --git a/tests/hol4/hashmap_on_disk/hashmapMain_FunsTheory.sig b/tests/hol4/hashmap_on_disk/hashmapMain_FunsTheory.sig index 1d24cb26..d4e43d9a 100644 --- a/tests/hol4/hashmap_on_disk/hashmapMain_FunsTheory.sig +++ b/tests/hol4/hashmap_on_disk/hashmapMain_FunsTheory.sig @@ -3,8 +3,6 @@ sig type thm = Thm.thm (* Definitions *) - val core_num_u32_max_body_def : thm - val core_num_u32_max_c_def : thm val hashmap_hash_key_fwd_def : thm val hashmap_hash_map_allocate_slots_fwd_def : thm val hashmap_hash_map_allocate_slots_loop_fwd_def : thm @@ -50,14 +48,6 @@ sig (* [hashmapMain_Opaque] Parent theory of "hashmapMain_Funs" - [core_num_u32_max_body_def] Definition - - ⊢ core_num_u32_max_body = Return (int_to_u32 4294967295) - - [core_num_u32_max_c_def] Definition - - ⊢ core_num_u32_max_c = get_return_value core_num_u32_max_body - [hashmap_hash_key_fwd_def] Definition ⊢ ∀k. hashmap_hash_key_fwd k = Return k @@ -506,7 +496,7 @@ sig ⊢ ∀self. hashmap_hash_map_try_resize_fwd_back self = do - max_usize <- mk_usize (u32_to_int core_num_u32_max_c); + max_usize <- mk_usize (u32_to_int core_u32_max); capacity <<- vec_len self.hashmap_hash_map_slots; n1 <- usize_div max_usize (int_to_usize 2); (i,i0) <<- self.hashmap_hash_map_max_load_factor; diff --git a/tests/hol4/misc-constants/constantsScript.sml b/tests/hol4/misc-constants/constantsScript.sml index d589d348..40a319c6 100644 --- a/tests/hol4/misc-constants/constantsScript.sml +++ b/tests/hol4/misc-constants/constantsScript.sml @@ -13,17 +13,9 @@ Definition x0_c_def: x0_c : u32 = get_return_value x0_body End -(** [core::num::u32::{8}::MAX] *) -Definition core_num_u32_max_body_def: - core_num_u32_max_body : u32 result = Return (int_to_u32 4294967295) -End -Definition core_num_u32_max_c_def: - core_num_u32_max_c : u32 = get_return_value core_num_u32_max_body -End - (** [constants::X1] *) Definition x1_body_def: - x1_body : u32 result = Return core_num_u32_max_c + x1_body : u32 result = Return core_u32_max End Definition x1_c_def: x1_c : u32 = get_return_value x1_body diff --git a/tests/hol4/misc-constants/constantsTheory.sig b/tests/hol4/misc-constants/constantsTheory.sig index 149d7e22..287ad5f5 100644 --- a/tests/hol4/misc-constants/constantsTheory.sig +++ b/tests/hol4/misc-constants/constantsTheory.sig @@ -4,8 +4,6 @@ sig (* Definitions *) val add_fwd_def : thm - val core_num_u32_max_body_def : thm - val core_num_u32_max_c_def : thm val get_z1_fwd_def : thm val get_z1_z1_body_def : thm val get_z1_z1_c_def : thm @@ -110,14 +108,6 @@ sig ⊢ ∀a b. add_fwd a b = i32_add a b - [core_num_u32_max_body_def] Definition - - ⊢ core_num_u32_max_body = Return (int_to_u32 4294967295) - - [core_num_u32_max_c_def] Definition - - ⊢ core_num_u32_max_c = get_return_value core_num_u32_max_body - [get_z1_fwd_def] Definition ⊢ get_z1_fwd = Return get_z1_z1_c @@ -321,7 +311,7 @@ sig [x1_body_def] Definition - ⊢ x1_body = Return core_num_u32_max_c + ⊢ x1_body = Return core_u32_max [x1_c_def] Definition diff --git a/tests/lean/Array/Funs.lean b/tests/lean/Array/Funs.lean index ad737dca..d3cb29d2 100644 --- a/tests/lean/Array/Funs.lean +++ b/tests/lean/Array/Funs.lean @@ -53,16 +53,6 @@ def index_array_u32 (s : Array U32 (Usize.ofInt 32)) (i : Usize) : Result U32 := Array.index_shared U32 (Usize.ofInt 32) s i -/- [array::index_array_generic]: forward function -/ -def index_array_generic - (N : Usize) (s : Array U32 N) (i : Usize) : Result U32 := - Array.index_shared U32 N s i - -/- [array::index_array_generic_call]: forward function -/ -def index_array_generic_call - (N : Usize) (s : Array U32 N) (i : Usize) : Result U32 := - index_array_generic N s i - /- [array::index_array_copy]: forward function -/ def index_array_copy (x : Array U32 (Usize.ofInt 32)) : Result U32 := Array.index_shared U32 (Usize.ofInt 32) x (Usize.ofInt 0) @@ -426,23 +416,21 @@ def f3 : Result U32 := (Array.make U32 (Usize.ofInt 2) [ (U32.ofInt 1), (U32.ofInt 2) ]) (Usize.ofInt 0) let _ ← f2 i + let b := Array.repeat U32 (Usize.ofInt 32) (U32.ofInt 0) let s ← Array.to_slice_shared U32 (Usize.ofInt 2) (Array.make U32 (Usize.ofInt 2) [ (U32.ofInt 1), (U32.ofInt 2) ]) - let s0 ← - f4 - (Array.make U32 (Usize.ofInt 32) [ - (U32.ofInt 0), (U32.ofInt 0), (U32.ofInt 0), (U32.ofInt 0), - (U32.ofInt 0), (U32.ofInt 0), (U32.ofInt 0), (U32.ofInt 0), - (U32.ofInt 0), (U32.ofInt 0), (U32.ofInt 0), (U32.ofInt 0), - (U32.ofInt 0), (U32.ofInt 0), (U32.ofInt 0), (U32.ofInt 0), - (U32.ofInt 0), (U32.ofInt 0), (U32.ofInt 0), (U32.ofInt 0), - (U32.ofInt 0), (U32.ofInt 0), (U32.ofInt 0), (U32.ofInt 0), - (U32.ofInt 0), (U32.ofInt 0), (U32.ofInt 0), (U32.ofInt 0), - (U32.ofInt 0), (U32.ofInt 0), (U32.ofInt 0), (U32.ofInt 0) - ]) (Usize.ofInt 16) (Usize.ofInt 18) + let s0 ← f4 b (Usize.ofInt 16) (Usize.ofInt 18) sum2 s s0 +/- [array::SZ] -/ +def sz_body : Result Usize := Result.ret (Usize.ofInt 32) +def sz_c : Usize := eval_global sz_body (by simp) + +/- [array::f5]: forward function -/ +def f5 (x : Array U32 (Usize.ofInt 32)) : Result U32 := + Array.index_shared U32 (Usize.ofInt 32) x (Usize.ofInt 0) + /- [array::ite]: forward function -/ def ite : Result Unit := do @@ -462,4 +450,9 @@ def ite : Result Unit := (Array.make U32 (Usize.ofInt 2) [ (U32.ofInt 0), (U32.ofInt 0) ]) s2 Result.ret () +/- [array::array]: forward function -/ +def array (LEN : Usize) : Result (Array U8 LEN) := + let a := Array.repeat U8 LEN (U8.ofInt 0) + Result.ret a + end array diff --git a/tests/lean/BetreeMain/Funs.lean b/tests/lean/BetreeMain/Funs.lean index 07ef08dc..6681731f 100644 --- a/tests/lean/BetreeMain/Funs.lean +++ b/tests/lean/BetreeMain/Funs.lean @@ -64,11 +64,6 @@ def betree.NodeIdCounter.fresh_id_back let i ← self.next_node_id + (U64.ofInt 1) Result.ret { next_node_id := i } -/- [core::num::u64::{9}::MAX] -/ -def core_num_u64_max_body : Result U64 := - Result.ret (U64.ofInt 18446744073709551615) -def core_num_u64_max_c : U64 := eval_global core_num_u64_max_body (by simp) - /- [betree_main::betree::upsert_update]: forward function -/ def betree.upsert_update (prev : Option U64) (st : betree.UpsertFunState) : Result U64 := @@ -81,10 +76,10 @@ def betree.upsert_update match st with | betree.UpsertFunState.Add v => do - let margin ← core_num_u64_max_c - prev0 + let margin ← core_u64_max - prev0 if margin >= v then prev0 + v - else Result.ret core_num_u64_max_c + else Result.ret core_u64_max | betree.UpsertFunState.Sub v => if prev0 >= v then prev0 - v diff --git a/tests/lean/Constants.lean b/tests/lean/Constants.lean index 51b415d6..b0cdaa90 100644 --- a/tests/lean/Constants.lean +++ b/tests/lean/Constants.lean @@ -9,12 +9,8 @@ namespace constants def x0_body : Result U32 := Result.ret (U32.ofInt 0) def x0_c : U32 := eval_global x0_body (by simp) -/- [core::num::u32::{8}::MAX] -/ -def core_num_u32_max_body : Result U32 := Result.ret (U32.ofInt 4294967295) -def core_num_u32_max_c : U32 := eval_global core_num_u32_max_body (by simp) - /- [constants::X1] -/ -def x1_body : Result U32 := Result.ret core_num_u32_max_c +def x1_body : Result U32 := Result.ret core_u32_max def x1_c : U32 := eval_global x1_body (by simp) /- [constants::X2] -/ diff --git a/tests/lean/Hashmap/Funs.lean b/tests/lean/Hashmap/Funs.lean index 30b30e0b..01c61de4 100644 --- a/tests/lean/Hashmap/Funs.lean +++ b/tests/lean/Hashmap/Funs.lean @@ -132,10 +132,6 @@ def HashMap.insert_no_resize let v ← Vec.index_mut_back (List T) self.slots hash_mod l0 Result.ret { self with slots := v } -/- [core::num::u32::{8}::MAX] -/ -def core_num_u32_max_body : Result U32 := Result.ret (U32.ofInt 4294967295) -def core_num_u32_max_c : U32 := eval_global core_num_u32_max_body (by simp) - /- [hashmap::HashMap::{0}::move_elements_from_list]: loop 0: merged forward/backward function (there is a single backward function, and the forward function returns ()) -/ divergent def HashMap.move_elements_from_list_loop @@ -184,7 +180,7 @@ def HashMap.move_elements (there is a single backward function, and the forward function returns ()) -/ def HashMap.try_resize (T : Type) (self : HashMap T) : Result (HashMap T) := do - let max_usize ← Scalar.cast .Usize core_num_u32_max_c + let max_usize ← Scalar.cast .Usize core_u32_max let capacity := Vec.len (List T) self.slots let n1 ← max_usize / (Usize.ofInt 2) let (i, i0) := self.max_load_factor diff --git a/tests/lean/Hashmap/Properties.lean b/tests/lean/Hashmap/Properties.lean index fe00ab14..4db54316 100644 --- a/tests/lean/Hashmap/Properties.lean +++ b/tests/lean/Hashmap/Properties.lean @@ -303,19 +303,17 @@ theorem insert_no_resize_spec {α : Type} (hm : HashMap α) (key : Usize) (value | some _ => nhm.len_s = hm.len_s) := by rw [insert_no_resize] simp only [hash_key, bind_tc_ret] -- TODO: annoying - have _ : (Vec.len (List α) hm.slots).val ≠ 0 := by checkpoint + have _ : (Vec.len (List α) hm.slots).val ≠ 0 := by intro simp_all [inv] - progress keep _ as ⟨ hash_mod, hhm ⟩ - have _ : 0 ≤ hash_mod.val := by checkpoint scalar_tac + progress as ⟨ hash_mod, hhm ⟩ + have _ : 0 ≤ hash_mod.val := by scalar_tac have _ : hash_mod.val < Vec.length hm.slots := by have : 0 < hm.slots.val.len := by simp [inv] at hinv simp [hinv] -- TODO: we want to automate that simp [*, Int.emod_lt_of_pos] - -- TODO: change the spec of Vec.index_mut to introduce a let-binding. - -- or: make progress introduce the let-binding by itself (this is clearer) progress as ⟨ l, h_leq ⟩ -- TODO: make progress use the names written in the goal progress as ⟨ inserted ⟩ diff --git a/tests/lean/HashmapMain/Funs.lean b/tests/lean/HashmapMain/Funs.lean index aec957ec..848b1a35 100644 --- a/tests/lean/HashmapMain/Funs.lean +++ b/tests/lean/HashmapMain/Funs.lean @@ -147,10 +147,6 @@ def hashmap.HashMap.insert_no_resize let v ← Vec.index_mut_back (hashmap.List T) self.slots hash_mod l0 Result.ret { self with slots := v } -/- [core::num::u32::{8}::MAX] -/ -def core_num_u32_max_body : Result U32 := Result.ret (U32.ofInt 4294967295) -def core_num_u32_max_c : U32 := eval_global core_num_u32_max_body (by simp) - /- [hashmap_main::hashmap::HashMap::{0}::move_elements_from_list]: loop 0: merged forward/backward function (there is a single backward function, and the forward function returns ()) -/ divergent def hashmap.HashMap.move_elements_from_list_loop @@ -206,7 +202,7 @@ def hashmap.HashMap.move_elements def hashmap.HashMap.try_resize (T : Type) (self : hashmap.HashMap T) : Result (hashmap.HashMap T) := do - let max_usize ← Scalar.cast .Usize core_num_u32_max_c + let max_usize ← Scalar.cast .Usize core_u32_max let capacity := Vec.len (hashmap.List T) self.slots let n1 ← max_usize / (Usize.ofInt 2) let (i, i0) := self.max_load_factor diff --git a/tests/lean/Traits.lean b/tests/lean/Traits.lean new file mode 100644 index 00000000..5e812e95 --- /dev/null +++ b/tests/lean/Traits.lean @@ -0,0 +1 @@ +import Traits.Funs diff --git a/tests/lean/Traits/Funs.lean b/tests/lean/Traits/Funs.lean new file mode 100644 index 00000000..52ff0c0a --- /dev/null +++ b/tests/lean/Traits/Funs.lean @@ -0,0 +1,232 @@ +-- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS +-- [traits]: function definitions +import Base +import Traits.Types +open Primitives + +namespace traits + +/- [traits::Bool::{0}::get_bool]: forward function -/ +def Bool.get_bool (self : Bool) : Result Bool := + Result.ret self + +/- Trait implementation: [traits::Bool::{0}] -/ +def Bool.BoolTraitInst : BoolTrait Bool := { + get_bool := Bool.get_bool +} + +/- [traits::BoolTrait::ret_true]: forward function -/ +def BoolTrait.ret_true + {Self : Type} (self_clause : BoolTrait Self) (self : Self) : Result Bool := + Result.ret true + +/- [traits::test_bool_trait_bool]: forward function -/ +def test_bool_trait_bool (x : Bool) : Result Bool := + do + let b ← Bool.get_bool x + if b + then BoolTrait.ret_true Bool.BoolTraitInst x + else Result.ret false + +/- [traits::Option::{1}::get_bool]: forward function -/ +def Option.get_bool (T : Type) (self : Option T) : Result Bool := + match self with + | Option.none => Result.ret false + | Option.some t => Result.ret true + +/- Trait implementation: [traits::Option::{1}] -/ +def Option.BoolTraitInst (T : Type) : BoolTrait (Option T) := { + get_bool := Option.get_bool T +} + +/- [traits::test_bool_trait_option]: forward function -/ +def test_bool_trait_option (T : Type) (x : Option T) : Result Bool := + do + let b ← Option.get_bool T x + if b + then BoolTrait.ret_true (Option.BoolTraitInst T) x + else Result.ret false + +/- [traits::test_bool_trait]: forward function -/ +def test_bool_trait (T : Type) (inst : BoolTrait T) (x : T) : Result Bool := + inst.get_bool x + +/- [traits::u64::{2}::to_u64]: forward function -/ +def u64.to_u64 (self : U64) : Result U64 := + Result.ret self + +/- Trait implementation: [traits::u64::{2}] -/ +def u64.ToU64Inst : ToU64 U64 := { + to_u64 := u64.to_u64 +} + +/- [traits::Tuple2::{3}::to_u64]: forward function -/ +def Tuple2.to_u64 (A : Type) (inst : ToU64 A) (self : (A × A)) : Result U64 := + do + let (t, t0) := self + let i ← inst.to_u64 t + let i0 ← inst.to_u64 t0 + i + i0 + +/- Trait implementation: [traits::Tuple2::{3}] -/ +def Tuple2.ToU64Inst (A : Type) (inst : ToU64 A) : ToU64 (A × A) := { + to_u64 := Tuple2.to_u64 A inst +} + +/- [traits::f]: forward function -/ +def f (T : Type) (inst : ToU64 T) (x : (T × T)) : Result U64 := + Tuple2.to_u64 T inst x + +/- [traits::g]: forward function -/ +def g (T : Type) (inst : ToU64 (T × T)) (x : (T × T)) : Result U64 := + inst.to_u64 x + +/- [traits::h0]: forward function -/ +def h0 (x : U64) : Result U64 := + u64.to_u64 x + +/- [traits::Wrapper::{4}::to_u64]: forward function -/ +def Wrapper.to_u64 + (T : Type) (inst : ToU64 T) (self : Wrapper T) : Result U64 := + inst.to_u64 self.x + +/- Trait implementation: [traits::Wrapper::{4}] -/ +def Wrapper.ToU64Inst (T : Type) (inst : ToU64 T) : ToU64 (Wrapper T) := { + to_u64 := Wrapper.to_u64 T inst +} + +/- [traits::h1]: forward function -/ +def h1 (x : Wrapper U64) : Result U64 := + Wrapper.to_u64 U64 u64.ToU64Inst x + +/- [traits::h2]: forward function -/ +def h2 (T : Type) (inst : ToU64 T) (x : Wrapper T) : Result U64 := + Wrapper.to_u64 T inst x + +/- [traits::u64::{5}::to_type]: forward function -/ +def u64.to_type (self : U64) : Result Bool := + Result.ret (self > (U64.ofInt 0)) + +/- Trait implementation: [traits::u64::{5}] -/ +def u64.ToTypeInst : ToType U64 Bool := { + to_type := u64.to_type +} + +/- [traits::h3]: forward function -/ +def h3 + (T1 T2 : Type) (inst : OfType T1) (inst0 : ToType T2 T1) (y : T2) : + Result T1 + := + inst.of_type T2 inst0 y + +/- [traits::h4]: forward function -/ +def h4 + (T1 T2 : Type) (inst : OfTypeBis T1 T2) (inst0 : ToType T2 T1) (y : T2) : + Result T1 + := + inst.of_type y + +/- [traits::TestType::{6}::test::TestType1::{0}::test]: forward function -/ +def TestType.test.TestType1.test + (self : TestType.test.TestType1) : Result Bool := + Result.ret (self._0 > (U64.ofInt 1)) + +/- Trait implementation: [traits::TestType::{6}::test::TestType1::{0}] -/ +def TestType.test.TestType1.TestTypetestTestTraitInst : TestType.test.TestTrait + TestType.test.TestType1 := { + test := TestType.test.TestType1.test +} + +/- [traits::TestType::{6}::test]: forward function -/ +def TestType.test + (T : Type) (inst : ToU64 T) (self : TestType T) (x : T) : Result Bool := + do + let x0 ← inst.to_u64 x + if x0 > (U64.ofInt 0) + then TestType.test.TestType1.test { _0 := (U64.ofInt 0) } + else Result.ret false + +/- [traits::BoolWrapper::{7}::to_type]: forward function -/ +def BoolWrapper.to_type + (T : Type) (inst : ToType Bool T) (self : BoolWrapper) : Result T := + inst.to_type self._0 + +/- Trait implementation: [traits::BoolWrapper::{7}] -/ +def BoolWrapper.ToTypeInst (T : Type) (inst : ToType Bool T) : ToType + BoolWrapper T := { + to_type := BoolWrapper.to_type T inst +} + +/- [traits::WithConstTy::LEN2] -/ +def with_const_ty_len2_body : Result Usize := Result.ret (Usize.ofInt 32) +def with_const_ty_len2_c : Usize := + eval_global with_const_ty_len2_body (by simp) + +/- [traits::Bool::{8}::LEN1] -/ +def bool_len1_body : Result Usize := Result.ret (Usize.ofInt 12) +def bool_len1_c : Usize := eval_global bool_len1_body (by simp) + +/- [traits::Bool::{8}::f]: merged forward/backward function + (there is a single backward function, and the forward function returns ()) -/ +def Bool.f (i : U64) (a : Array U8 (Usize.ofInt 32)) : Result U64 := + Result.ret i + +/- Trait implementation: [traits::Bool::{8}] -/ +def Bool.WithConstTyInst : WithConstTy Bool (Usize.ofInt 32) := { + LEN1 := bool_len1_c + LEN2 := with_const_ty_len2_c + V := U8 + W := U64 + W_clause_0 := u64.ToU64Inst + f := Bool.f +} + +/- [traits::use_with_const_ty1]: forward function -/ +def use_with_const_ty1 + (H : Type) (LEN : Usize) (inst : WithConstTy H LEN) : Result Usize := + let i := inst.LEN1 + Result.ret i + +/- [traits::use_with_const_ty2]: forward function -/ +def use_with_const_ty2 + (H : Type) (LEN : Usize) (inst : WithConstTy H LEN) (w : inst.W) : + Result Unit + := + Result.ret () + +/- [traits::use_with_const_ty3]: forward function -/ +def use_with_const_ty3 + (H : Type) (LEN : Usize) (inst : WithConstTy H LEN) (x : inst.W) : + Result U64 + := + inst.W_clause_0.to_u64 x + +/- [traits::test_where1]: forward function -/ +def test_where1 (T : Type) (_x : T) : Result Unit := + Result.ret () + +/- [traits::test_where2]: forward function -/ +def test_where2 + (T : Type) (inst : WithConstTy T (Usize.ofInt 32)) (_x : U32) : + Result Unit + := + Result.ret () + +/- [traits::test_child_trait1]: forward function -/ +def test_child_trait1 + (T : Type) (inst : ChildTrait T) (x : T) : Result alloc.string.String := + inst.parent_clause_0.get_name x + +/- [traits::test_child_trait2]: forward function -/ +def test_child_trait2 + (T : Type) (inst : ChildTrait T) (x : T) : Result inst.parent_clause_0.W := + inst.parent_clause_0.get_w x + +/- [traits::order1]: forward function -/ +def order1 + (T U : Type) (inst : ParentTrait0 T) (inst0 : ParentTrait0 U) : + Result Unit + := + Result.ret () + +end traits diff --git a/tests/lean/Traits/Types.lean b/tests/lean/Traits/Types.lean new file mode 100644 index 00000000..a8c12fe5 --- /dev/null +++ b/tests/lean/Traits/Types.lean @@ -0,0 +1,86 @@ +-- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS +-- [traits]: type definitions +import Base +open Primitives + +namespace traits + +/- Trait declaration: [traits::BoolTrait] -/ +structure BoolTrait (Self : Type) where + get_bool : Self → Result Bool + +/- Trait declaration: [traits::ToU64] -/ +structure ToU64 (Self : Type) where + to_u64 : Self → Result U64 + +/- [traits::Wrapper] -/ +structure Wrapper (T : Type) where + x : T + +/- Trait declaration: [traits::ToType] -/ +structure ToType (Self T : Type) where + to_type : Self → Result T + +/- Trait declaration: [traits::OfType] -/ +structure OfType (Self : Type) where + of_type : forall (T : Type) (inst : ToType T Self), T → Result Self + +/- Trait declaration: [traits::OfTypeBis] -/ +structure OfTypeBis (Self T : Type) where + parent_clause_0 :ToType T Self + of_type : T → Result Self + +/- [traits::TestType] -/ +structure TestType (T : Type) where + _0 : T + +/- [traits::TestType::{6}::test::TestType1] -/ +structure TestType.test.TestType1 where + _0 : U64 + +/- Trait declaration: [traits::TestType::{6}::test::TestTrait] -/ +structure TestType.test.TestTrait (Self : Type) where + test : Self → Result Bool + +/- [traits::BoolWrapper] -/ +structure BoolWrapper where + _0 : Bool + +/- Trait declaration: [traits::WithConstTy] -/ +structure WithConstTy (Self : Type) (LEN : Usize) where + LEN1 : Usize + LEN2 : Usize + V : Type + W : Type + W_clause_0 : ToU64 W + f : W → Array U8 LEN → Result W + +/- [alloc::string::String] -/ +axiom alloc.string.String : Type + +/- Trait declaration: [traits::ParentTrait0] -/ +structure ParentTrait0 (Self : Type) where + W : Type + get_name : Self → Result alloc.string.String + get_w : Self → Result W + +/- Trait declaration: [traits::ParentTrait1] -/ +structure ParentTrait1 (Self : Type) where + +/- Trait declaration: [traits::ChildTrait] -/ +structure ChildTrait (Self : Type) where + parent_clause_0 :ParentTrait0 Self + parent_clause_1 :ParentTrait1 Self + +/- Trait declaration: [traits::Iterator] -/ +structure Iterator (Self : Type) where + Item : Type + +/- Trait declaration: [traits::IntoIterator] -/ +structure IntoIterator (Self : Type) where + Item : Type + IntoIter : Type + IntoIter_clause_0 : Iterator IntoIter + into_iter : Self → Result IntoIter + +end traits diff --git a/tests/lean/lakefile.lean b/tests/lean/lakefile.lean index 8acf6973..fef94971 100644 --- a/tests/lean/lakefile.lean +++ b/tests/lean/lakefile.lean @@ -19,3 +19,4 @@ package «tests» {} @[default_target] lean_lib paper @[default_target] lean_lib poloniusList @[default_target] lean_lib array +@[default_target] lean_lib traits |