summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-12-22 23:04:31 +0100
committerSon Ho2023-12-22 23:04:31 +0100
commit9a8e43df626400aacdfcb9d2cf2eec38d71d2d73 (patch)
tree2df260fb8340c64348f046c32cbb1c712a508341
parentd9ace7d5f1968f26b586fb712c725b2ce51086f8 (diff)
Fix minor issues
-rw-r--r--backends/coq/Primitives.v21
-rw-r--r--backends/fstar/merge/Primitives.fst21
-rw-r--r--backends/lean/Base/Primitives/ArraySlice.lean42
-rw-r--r--compiler/ExtractBase.ml81
-rw-r--r--compiler/PureUtils.ml2
-rw-r--r--compiler/SymbolicToPure.ml57
6 files changed, 186 insertions, 38 deletions
diff --git a/backends/coq/Primitives.v b/backends/coq/Primitives.v
index e6d3118f..990e27e4 100644
--- a/backends/coq/Primitives.v
+++ b/backends/coq/Primitives.v
@@ -585,6 +585,13 @@ Axiom array_repeat : forall (T : Type) (n : usize) (x : T), array T n.
Axiom array_index_usize : forall (T : Type) (n : usize) (x : array T n) (i : usize), result T.
Axiom array_update_usize : forall (T : Type) (n : usize) (x : array T n) (i : usize) (nx : T), result (array T n).
+Definition array_index_mut_usize (T : Type) (n : usize) (a : array T n) (i : usize) :
+ result (T * (T -> result (array T n))) :=
+ match array_index_usize T n a i with
+ | Fail_ e => Fail_ e
+ | Return x => Return (x, array_update_usize T n a i)
+ end.
+
(*** Slice *)
Definition slice T := { l: list T | Z.of_nat (length l) <= usize_max}.
@@ -592,11 +599,25 @@ Axiom slice_len : forall (T : Type) (s : slice T), usize.
Axiom slice_index_usize : forall (T : Type) (x : slice T) (i : usize), result T.
Axiom slice_update_usize : forall (T : Type) (x : slice T) (i : usize) (nx : T), result (slice T).
+Definition slice_index_mut_usize (T : Type) (s : slice T) (i : usize) :
+ result (T * (T -> result (slice T))) :=
+ match slice_index_usize T s i with
+ | Fail_ e => Fail_ e
+ | Return x => Return (x, slice_update_usize T s i)
+ end.
+
(*** Subslices *)
Axiom array_to_slice : forall (T : Type) (n : usize) (x : array T n), result (slice T).
Axiom array_from_slice : forall (T : Type) (n : usize) (x : array T n) (s : slice T), result (array T n).
+Definition array_to_slice_mut (T : Type) (n : usize) (a : array T n) :
+ result (slice T * (slice T -> result (array T n))) :=
+ match array_to_slice T n a with
+ | Fail_ e => Fail_ e
+ | Return x => Return (x, array_from_slice T n a)
+ end.
+
Axiom array_subslice: forall (T : Type) (n : usize) (x : array T n) (r : core_ops_range_Range usize), result (slice T).
Axiom array_update_subslice: forall (T : Type) (n : usize) (x : array T n) (r : core_ops_range_Range usize) (ns : slice T), result (array T n).
diff --git a/backends/fstar/merge/Primitives.fst b/backends/fstar/merge/Primitives.fst
index 8011efa1..fca80829 100644
--- a/backends/fstar/merge/Primitives.fst
+++ b/backends/fstar/merge/Primitives.fst
@@ -531,10 +531,18 @@ let array_index_usize (a : Type0) (n : usize) (x : array a n) (i : usize) : resu
if i < length x then Return (index x i)
else Fail Failure
-let array_update_usize (a : Type0) (n : usize) (x : array a n) (i : usize) (nx : a) : result (array a n) =
+let array_update_usize (a : Type0) (n : usize) (x : array a n) (i : usize) (nx : a) :
+ result (array a n) =
if i < length x then Return (list_update x i nx)
else Fail Failure
+let array_index_mut_usize (a : Type0) (n : usize) (x : array a n) (i : usize) :
+ result (a & (a -> result (array a n))) =
+ match array_index_usize a n x i with
+ | Fail e -> Fail e
+ | Return v ->
+ Return (v, array_update_usize a n x i)
+
(*** Slice *)
type slice (a : Type0) = s:list a{length s <= usize_max}
@@ -548,6 +556,13 @@ let slice_update_usize (a : Type0) (x : slice a) (i : usize) (nx : a) : result (
if i < length x then Return (list_update x i nx)
else Fail Failure
+let slice_index_mut_usize (a : Type0) (s : slice a) (i : usize) :
+ result (a & (a -> result (slice a))) =
+ match slice_index_usize a s i with
+ | Fail e -> Fail e
+ | Return x ->
+ Return (x, slice_update_usize a s i)
+
(*** Subslices *)
let array_to_slice (a : Type0) (n : usize) (x : array a n) : result (slice a) = Return x
@@ -555,6 +570,10 @@ let array_from_slice (a : Type0) (n : usize) (x : array a n) (s : slice a) : res
if length s = n then Return s
else Fail Failure
+let array_to_slice_mut (a : Type0) (n : usize) (x : array a n) :
+ result (slice a & (slice a -> result (array a n))) =
+ Return (x, array_from_slice a n x)
+
// TODO: finish the definitions below (there lacks [List.drop] and [List.take] in the standard library *)
let array_subslice (a : Type0) (n : usize) (x : array a n) (r : core_ops_range_Range usize) : result (slice a) =
admit()
diff --git a/backends/lean/Base/Primitives/ArraySlice.lean b/backends/lean/Base/Primitives/ArraySlice.lean
index 59432a0b..5057fb01 100644
--- a/backends/lean/Base/Primitives/ArraySlice.lean
+++ b/backends/lean/Base/Primitives/ArraySlice.lean
@@ -93,6 +93,21 @@ theorem Array.update_usize_spec {α : Type u} {n : Usize} (v: Array α n) (i: Us
. simp_all [length]; cases h <;> scalar_tac
. simp_all
+def Array.index_mut_usize (α : Type u) (n : Usize) (v: Array α n) (i: Usize) :
+ Result (α × (α -> Result (Array α n))) := do
+ let x ← index_usize α n v i
+ ret (x, update_usize α n v i)
+
+@[pspec]
+theorem Array.index_mut_usize_spec {α : Type u} {n : Usize} [Inhabited α] (v: Array α n) (i: Usize)
+ (hbound : i.val < v.length) :
+ ∃ x back, v.index_mut_usize α n i = ret (x, back) ∧
+ x = v.val.index i.val ∧
+ back = update_usize α n v i := by
+ simp only [index_mut_usize, Bind.bind, bind]
+ have ⟨ x, h ⟩ := index_usize_spec v i hbound
+ simp [h]
+
def Slice (α : Type u) := { l : List α // l.length ≤ Usize.max }
instance (a : Type u) : Arith.HasIntProp (Slice a) where
@@ -167,6 +182,21 @@ theorem Slice.update_usize_spec {α : Type u} (v: Slice α) (i: Usize) (x : α)
. simp_all [length]; cases h <;> scalar_tac
. simp_all
+def Slice.index_mut_usize (α : Type u) (v: Slice α) (i: Usize) :
+ Result (α × (α → Result (Slice α))) := do
+ let x ← Slice.index_usize α v i
+ ret (x, Slice.update_usize α v i)
+
+@[pspec]
+theorem Slice.index_mut_usize_spec {α : Type u} [Inhabited α] (v: Slice α) (i: Usize)
+ (hbound : i.val < v.length) :
+ ∃ x back, v.index_mut_usize α i = ret (x, back) ∧
+ x = v.val.index i.val ∧
+ back = Slice.update_usize α v i := by
+ simp only [index_mut_usize, Bind.bind, bind]
+ have ⟨ x, h ⟩ := Slice.index_usize_spec v i hbound
+ simp [h]
+
/- Array to slice/subslices -/
/- We could make this function not use the `Result` type. By making it monadic, we
@@ -190,6 +220,18 @@ theorem Array.from_slice_spec {α : Type u} {n : Usize} (a : Array α n) (ns : S
∃ na, from_slice α n a ns = ret na ∧ na.val = ns.val
:= by simp [from_slice, *]
+def Array.to_slice_mut (α : Type u) (n : Usize) (a : Array α n) :
+ Result (Slice α × (Slice α → Result (Array α n))) := do
+ let s ← Array.to_slice α n a
+ ret (s, Array.from_slice α n a)
+
+@[pspec]
+theorem Array.to_slice_mut_spec {α : Type u} {n : Usize} (v : Array α n) :
+ ∃ s back, to_slice_mut α n v = ret (s, back) ∧
+ v.val = s.val ∧
+ back = Array.from_slice α n v
+ := by simp [to_slice_mut, to_slice]
+
def Array.subslice (α : Type u) (n : Usize) (a : Array α n) (r : Range Usize) : Result (Slice α) :=
-- TODO: not completely sure here
if r.start.val < r.end_.val ∧ r.end_.val ≤ a.val.len then
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index 0af7a9b4..db887539 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -1051,33 +1051,60 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list =
let assumed_llbc_functions () :
(A.assumed_fun_id * T.RegionGroupId.id option * string) list =
let rg0 = Some T.RegionGroupId.zero in
- match !backend with
- | FStar | Coq | HOL4 ->
- [
- (ArrayIndexShared, None, "array_index_usize");
- (ArrayIndexMut, None, "array_index_usize");
- (ArrayIndexMut, rg0, "array_update_usize");
- (ArrayToSliceShared, None, "array_to_slice");
- (ArrayToSliceMut, None, "array_to_slice");
- (ArrayToSliceMut, rg0, "array_from_slice");
- (ArrayRepeat, None, "array_repeat");
- (SliceIndexShared, None, "slice_index_usize");
- (SliceIndexMut, None, "slice_index_usize");
- (SliceIndexMut, rg0, "slice_update_usize");
- ]
- | Lean ->
- [
- (ArrayIndexShared, None, "Array.index_usize");
- (ArrayIndexMut, None, "Array.index_usize");
- (ArrayIndexMut, rg0, "Array.update_usize");
- (ArrayToSliceShared, None, "Array.to_slice");
- (ArrayToSliceMut, None, "Array.to_slice");
- (ArrayToSliceMut, rg0, "Array.from_slice");
- (ArrayRepeat, None, "Array.repeat");
- (SliceIndexShared, None, "Slice.index_usize");
- (SliceIndexMut, None, "Slice.index_usize");
- (SliceIndexMut, rg0, "Slice.update_usize");
- ]
+ let regular : (A.assumed_fun_id * T.RegionGroupId.id option * string) list =
+ match !backend with
+ | FStar | Coq | HOL4 ->
+ [
+ (ArrayIndexShared, None, "array_index_usize");
+ (ArrayToSliceShared, None, "array_to_slice");
+ (ArrayRepeat, None, "array_repeat");
+ (SliceIndexShared, None, "slice_index_usize");
+ ]
+ | Lean ->
+ [
+ (ArrayIndexShared, None, "Array.index_usize");
+ (ArrayToSliceShared, None, "Array.to_slice");
+ (ArrayRepeat, None, "Array.repeat");
+ (SliceIndexShared, None, "Slice.index_usize");
+ ]
+ in
+ let mut_funs : (A.assumed_fun_id * T.RegionGroupId.id option * string) list =
+ if !Config.return_back_funs then
+ match !backend with
+ | FStar | Coq | HOL4 ->
+ [
+ (ArrayIndexMut, None, "array_index_mut_usize");
+ (ArrayToSliceMut, None, "array_to_slice_mut");
+ (SliceIndexMut, None, "slice_index_mut_usize");
+ ]
+ | Lean ->
+ [
+ (ArrayIndexMut, None, "Array.index_mut_usize");
+ (ArrayToSliceMut, None, "Array.to_slice_mut");
+ (SliceIndexMut, None, "Slice.index_mut_usize");
+ ]
+ else
+ match !backend with
+ | FStar | Coq | HOL4 ->
+ [
+ (ArrayIndexMut, None, "array_index_usize");
+ (ArrayIndexMut, rg0, "array_update_usize");
+ (ArrayToSliceMut, None, "array_to_slice");
+ (ArrayToSliceMut, rg0, "array_from_slice");
+ (SliceIndexMut, None, "slice_index_usize");
+ (SliceIndexMut, rg0, "slice_update_usize");
+ ]
+ | Lean ->
+ [
+ (ArrayIndexMut, None, "Array.index_usize");
+ (ArrayIndexMut, rg0, "Array.update_usize");
+ (ArrayToSliceMut, None, "Array.to_slice");
+ (ArrayToSliceMut, rg0, "Array.from_slice");
+ (SliceIndexMut, None, "Slice.index_usize");
+ (SliceIndexMut, rg0, "Slice.update_usize");
+ ]
+ in
+ regular @ mut_funs
let assumed_pure_functions () : (pure_assumed_fun_id * string) list =
match !backend with
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index cc439e64..80bf3c42 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -731,7 +731,7 @@ let mk_lambda_from_var (var : var) (mp : mplace option) (e : texpression) :
let mk_lambdas_from_vars (vars : var list) (mps : mplace option list)
(e : texpression) : texpression =
let vars = List.combine vars mps in
- List.fold_left (fun e (v, mp) -> mk_lambda_from_var v mp e) e vars
+ List.fold_right (fun (v, mp) e -> mk_lambda_from_var v mp e) vars e
let rec destruct_lambdas (e : texpression) : typed_pattern list * texpression =
match e.e with
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index f0d1ca62..3a50e495 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -734,11 +734,15 @@ let rec translate_back_ty (type_infos : type_infos)
None
| TTraitType (trait_ref, generics, type_name) ->
assert (generics.regions = []);
- (* Translate the trait ref and the generics as "forward" generics -
- we do not want to filter any type *)
- let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
- let generics = translate_fwd_generic_args type_infos generics in
- Some (TTraitType (trait_ref, generics, type_name))
+ assert (
+ AssociatedTypes.trait_instance_id_is_local_clause trait_ref.trait_id);
+ if inside_mut then
+ (* Translate the trait ref and the generics as "forward" generics -
+ we do not want to filter any type *)
+ let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
+ let generics = translate_fwd_generic_args type_infos generics in
+ Some (TTraitType (trait_ref, generics, type_name))
+ else None
| TArrow _ -> raise (Failure "TODO")
(** Simply calls [translate_back_ty] *)
@@ -1056,7 +1060,21 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed
Upon ending the abstraction for 'a, we need to get back the borrow
the function returned.
*)
- List.filter_map (translate_back_ty_for_gid gid) [ sg.output ]
+ let inputs =
+ List.filter_map (translate_back_ty_for_gid gid) [ sg.output ]
+ in
+ log#ldebug
+ (lazy
+ (let ctx = Print.Contexts.decls_ctx_to_fmt_env decls_ctx in
+ let pctx = PrintPure.decls_ctx_to_fmt_env decls_ctx in
+ let output = Print.Types.ty_to_string ctx sg.output in
+ let inputs =
+ Print.list_to_string (PrintPure.ty_to_string pctx false) inputs
+ in
+ "translate_back_inputs_for_gid:" ^ "\n- gid: "
+ ^ RegionGroupId.to_string gid
+ ^ "\n- output: " ^ output ^ "\n- back inputs: " ^ inputs ^ "\n"));
+ inputs
in
let compute_back_outputs_for_gid (gid : RegionGroupId.id) :
string option list * ty list =
@@ -1080,7 +1098,21 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed
let outputs =
List.map (fun (name, opt_ty) -> (name, Option.get opt_ty)) outputs
in
- List.split outputs
+ let names, outputs = List.split outputs in
+ log#ldebug
+ (lazy
+ (let ctx = Print.Contexts.decls_ctx_to_fmt_env decls_ctx in
+ let pctx = PrintPure.decls_ctx_to_fmt_env decls_ctx in
+ let inputs =
+ Print.list_to_string (Print.Types.ty_to_string ctx) sg.inputs
+ in
+ let outputs =
+ Print.list_to_string (PrintPure.ty_to_string pctx false) outputs
+ in
+ "compute_back_outputs_for_gid:" ^ "\n- gid: "
+ ^ RegionGroupId.to_string gid
+ ^ "\n- inputs: " ^ inputs ^ "\n- back outputs: " ^ outputs ^ "\n"));
+ (names, outputs)
in
let compute_back_info_for_group (rg : T.region_var_group) :
RegionGroupId.id * back_sg_info =
@@ -1201,8 +1233,15 @@ let translate_fun_sig_from_decl_to_decomposed (decls_ctx : C.decls_ctx)
(fun (v : LlbcAst.var) -> v.name)
(LlbcAstUtils.fun_body_get_input_vars body)
in
- translate_fun_sig_to_decomposed decls_ctx fdef.def_id fdef.signature
- input_names
+ let sg =
+ translate_fun_sig_to_decomposed decls_ctx fdef.def_id fdef.signature
+ input_names
+ in
+ log#ldebug
+ (lazy
+ ("translate_fun_sig_from_decl_to_decomposed:" ^ "\n- name: "
+ ^ T.show_name fdef.name ^ "\n- sg:\n" ^ show_decomposed_fun_sig sg ^ "\n"));
+ sg
let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty
=