From 9a8e43df626400aacdfcb9d2cf2eec38d71d2d73 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 22 Dec 2023 23:04:31 +0100 Subject: Fix minor issues --- backends/coq/Primitives.v | 21 +++++++ backends/fstar/merge/Primitives.fst | 21 ++++++- backends/lean/Base/Primitives/ArraySlice.lean | 42 ++++++++++++++ compiler/ExtractBase.ml | 81 ++++++++++++++++++--------- compiler/PureUtils.ml | 2 +- compiler/SymbolicToPure.ml | 57 ++++++++++++++++--- 6 files changed, 186 insertions(+), 38 deletions(-) diff --git a/backends/coq/Primitives.v b/backends/coq/Primitives.v index e6d3118f..990e27e4 100644 --- a/backends/coq/Primitives.v +++ b/backends/coq/Primitives.v @@ -585,6 +585,13 @@ Axiom array_repeat : forall (T : Type) (n : usize) (x : T), array T n. Axiom array_index_usize : forall (T : Type) (n : usize) (x : array T n) (i : usize), result T. Axiom array_update_usize : forall (T : Type) (n : usize) (x : array T n) (i : usize) (nx : T), result (array T n). +Definition array_index_mut_usize (T : Type) (n : usize) (a : array T n) (i : usize) : + result (T * (T -> result (array T n))) := + match array_index_usize T n a i with + | Fail_ e => Fail_ e + | Return x => Return (x, array_update_usize T n a i) + end. + (*** Slice *) Definition slice T := { l: list T | Z.of_nat (length l) <= usize_max}. @@ -592,11 +599,25 @@ Axiom slice_len : forall (T : Type) (s : slice T), usize. Axiom slice_index_usize : forall (T : Type) (x : slice T) (i : usize), result T. Axiom slice_update_usize : forall (T : Type) (x : slice T) (i : usize) (nx : T), result (slice T). +Definition slice_index_mut_usize (T : Type) (s : slice T) (i : usize) : + result (T * (T -> result (slice T))) := + match slice_index_usize T s i with + | Fail_ e => Fail_ e + | Return x => Return (x, slice_update_usize T s i) + end. + (*** Subslices *) Axiom array_to_slice : forall (T : Type) (n : usize) (x : array T n), result (slice T). Axiom array_from_slice : forall (T : Type) (n : usize) (x : array T n) (s : slice T), result (array T n). +Definition array_to_slice_mut (T : Type) (n : usize) (a : array T n) : + result (slice T * (slice T -> result (array T n))) := + match array_to_slice T n a with + | Fail_ e => Fail_ e + | Return x => Return (x, array_from_slice T n a) + end. + Axiom array_subslice: forall (T : Type) (n : usize) (x : array T n) (r : core_ops_range_Range usize), result (slice T). Axiom array_update_subslice: forall (T : Type) (n : usize) (x : array T n) (r : core_ops_range_Range usize) (ns : slice T), result (array T n). diff --git a/backends/fstar/merge/Primitives.fst b/backends/fstar/merge/Primitives.fst index 8011efa1..fca80829 100644 --- a/backends/fstar/merge/Primitives.fst +++ b/backends/fstar/merge/Primitives.fst @@ -531,10 +531,18 @@ let array_index_usize (a : Type0) (n : usize) (x : array a n) (i : usize) : resu if i < length x then Return (index x i) else Fail Failure -let array_update_usize (a : Type0) (n : usize) (x : array a n) (i : usize) (nx : a) : result (array a n) = +let array_update_usize (a : Type0) (n : usize) (x : array a n) (i : usize) (nx : a) : + result (array a n) = if i < length x then Return (list_update x i nx) else Fail Failure +let array_index_mut_usize (a : Type0) (n : usize) (x : array a n) (i : usize) : + result (a & (a -> result (array a n))) = + match array_index_usize a n x i with + | Fail e -> Fail e + | Return v -> + Return (v, array_update_usize a n x i) + (*** Slice *) type slice (a : Type0) = s:list a{length s <= usize_max} @@ -548,6 +556,13 @@ let slice_update_usize (a : Type0) (x : slice a) (i : usize) (nx : a) : result ( if i < length x then Return (list_update x i nx) else Fail Failure +let slice_index_mut_usize (a : Type0) (s : slice a) (i : usize) : + result (a & (a -> result (slice a))) = + match slice_index_usize a s i with + | Fail e -> Fail e + | Return x -> + Return (x, slice_update_usize a s i) + (*** Subslices *) let array_to_slice (a : Type0) (n : usize) (x : array a n) : result (slice a) = Return x @@ -555,6 +570,10 @@ let array_from_slice (a : Type0) (n : usize) (x : array a n) (s : slice a) : res if length s = n then Return s else Fail Failure +let array_to_slice_mut (a : Type0) (n : usize) (x : array a n) : + result (slice a & (slice a -> result (array a n))) = + Return (x, array_from_slice a n x) + // TODO: finish the definitions below (there lacks [List.drop] and [List.take] in the standard library *) let array_subslice (a : Type0) (n : usize) (x : array a n) (r : core_ops_range_Range usize) : result (slice a) = admit() diff --git a/backends/lean/Base/Primitives/ArraySlice.lean b/backends/lean/Base/Primitives/ArraySlice.lean index 59432a0b..5057fb01 100644 --- a/backends/lean/Base/Primitives/ArraySlice.lean +++ b/backends/lean/Base/Primitives/ArraySlice.lean @@ -93,6 +93,21 @@ theorem Array.update_usize_spec {α : Type u} {n : Usize} (v: Array α n) (i: Us . simp_all [length]; cases h <;> scalar_tac . simp_all +def Array.index_mut_usize (α : Type u) (n : Usize) (v: Array α n) (i: Usize) : + Result (α × (α -> Result (Array α n))) := do + let x ← index_usize α n v i + ret (x, update_usize α n v i) + +@[pspec] +theorem Array.index_mut_usize_spec {α : Type u} {n : Usize} [Inhabited α] (v: Array α n) (i: Usize) + (hbound : i.val < v.length) : + ∃ x back, v.index_mut_usize α n i = ret (x, back) ∧ + x = v.val.index i.val ∧ + back = update_usize α n v i := by + simp only [index_mut_usize, Bind.bind, bind] + have ⟨ x, h ⟩ := index_usize_spec v i hbound + simp [h] + def Slice (α : Type u) := { l : List α // l.length ≤ Usize.max } instance (a : Type u) : Arith.HasIntProp (Slice a) where @@ -167,6 +182,21 @@ theorem Slice.update_usize_spec {α : Type u} (v: Slice α) (i: Usize) (x : α) . simp_all [length]; cases h <;> scalar_tac . simp_all +def Slice.index_mut_usize (α : Type u) (v: Slice α) (i: Usize) : + Result (α × (α → Result (Slice α))) := do + let x ← Slice.index_usize α v i + ret (x, Slice.update_usize α v i) + +@[pspec] +theorem Slice.index_mut_usize_spec {α : Type u} [Inhabited α] (v: Slice α) (i: Usize) + (hbound : i.val < v.length) : + ∃ x back, v.index_mut_usize α i = ret (x, back) ∧ + x = v.val.index i.val ∧ + back = Slice.update_usize α v i := by + simp only [index_mut_usize, Bind.bind, bind] + have ⟨ x, h ⟩ := Slice.index_usize_spec v i hbound + simp [h] + /- Array to slice/subslices -/ /- We could make this function not use the `Result` type. By making it monadic, we @@ -190,6 +220,18 @@ theorem Array.from_slice_spec {α : Type u} {n : Usize} (a : Array α n) (ns : S ∃ na, from_slice α n a ns = ret na ∧ na.val = ns.val := by simp [from_slice, *] +def Array.to_slice_mut (α : Type u) (n : Usize) (a : Array α n) : + Result (Slice α × (Slice α → Result (Array α n))) := do + let s ← Array.to_slice α n a + ret (s, Array.from_slice α n a) + +@[pspec] +theorem Array.to_slice_mut_spec {α : Type u} {n : Usize} (v : Array α n) : + ∃ s back, to_slice_mut α n v = ret (s, back) ∧ + v.val = s.val ∧ + back = Array.from_slice α n v + := by simp [to_slice_mut, to_slice] + def Array.subslice (α : Type u) (n : Usize) (a : Array α n) (r : Range Usize) : Result (Slice α) := -- TODO: not completely sure here if r.start.val < r.end_.val ∧ r.end_.val ≤ a.val.len then diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 0af7a9b4..db887539 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -1051,33 +1051,60 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list = let assumed_llbc_functions () : (A.assumed_fun_id * T.RegionGroupId.id option * string) list = let rg0 = Some T.RegionGroupId.zero in - match !backend with - | FStar | Coq | HOL4 -> - [ - (ArrayIndexShared, None, "array_index_usize"); - (ArrayIndexMut, None, "array_index_usize"); - (ArrayIndexMut, rg0, "array_update_usize"); - (ArrayToSliceShared, None, "array_to_slice"); - (ArrayToSliceMut, None, "array_to_slice"); - (ArrayToSliceMut, rg0, "array_from_slice"); - (ArrayRepeat, None, "array_repeat"); - (SliceIndexShared, None, "slice_index_usize"); - (SliceIndexMut, None, "slice_index_usize"); - (SliceIndexMut, rg0, "slice_update_usize"); - ] - | Lean -> - [ - (ArrayIndexShared, None, "Array.index_usize"); - (ArrayIndexMut, None, "Array.index_usize"); - (ArrayIndexMut, rg0, "Array.update_usize"); - (ArrayToSliceShared, None, "Array.to_slice"); - (ArrayToSliceMut, None, "Array.to_slice"); - (ArrayToSliceMut, rg0, "Array.from_slice"); - (ArrayRepeat, None, "Array.repeat"); - (SliceIndexShared, None, "Slice.index_usize"); - (SliceIndexMut, None, "Slice.index_usize"); - (SliceIndexMut, rg0, "Slice.update_usize"); - ] + let regular : (A.assumed_fun_id * T.RegionGroupId.id option * string) list = + match !backend with + | FStar | Coq | HOL4 -> + [ + (ArrayIndexShared, None, "array_index_usize"); + (ArrayToSliceShared, None, "array_to_slice"); + (ArrayRepeat, None, "array_repeat"); + (SliceIndexShared, None, "slice_index_usize"); + ] + | Lean -> + [ + (ArrayIndexShared, None, "Array.index_usize"); + (ArrayToSliceShared, None, "Array.to_slice"); + (ArrayRepeat, None, "Array.repeat"); + (SliceIndexShared, None, "Slice.index_usize"); + ] + in + let mut_funs : (A.assumed_fun_id * T.RegionGroupId.id option * string) list = + if !Config.return_back_funs then + match !backend with + | FStar | Coq | HOL4 -> + [ + (ArrayIndexMut, None, "array_index_mut_usize"); + (ArrayToSliceMut, None, "array_to_slice_mut"); + (SliceIndexMut, None, "slice_index_mut_usize"); + ] + | Lean -> + [ + (ArrayIndexMut, None, "Array.index_mut_usize"); + (ArrayToSliceMut, None, "Array.to_slice_mut"); + (SliceIndexMut, None, "Slice.index_mut_usize"); + ] + else + match !backend with + | FStar | Coq | HOL4 -> + [ + (ArrayIndexMut, None, "array_index_usize"); + (ArrayIndexMut, rg0, "array_update_usize"); + (ArrayToSliceMut, None, "array_to_slice"); + (ArrayToSliceMut, rg0, "array_from_slice"); + (SliceIndexMut, None, "slice_index_usize"); + (SliceIndexMut, rg0, "slice_update_usize"); + ] + | Lean -> + [ + (ArrayIndexMut, None, "Array.index_usize"); + (ArrayIndexMut, rg0, "Array.update_usize"); + (ArrayToSliceMut, None, "Array.to_slice"); + (ArrayToSliceMut, rg0, "Array.from_slice"); + (SliceIndexMut, None, "Slice.index_usize"); + (SliceIndexMut, rg0, "Slice.update_usize"); + ] + in + regular @ mut_funs let assumed_pure_functions () : (pure_assumed_fun_id * string) list = match !backend with diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index cc439e64..80bf3c42 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -731,7 +731,7 @@ let mk_lambda_from_var (var : var) (mp : mplace option) (e : texpression) : let mk_lambdas_from_vars (vars : var list) (mps : mplace option list) (e : texpression) : texpression = let vars = List.combine vars mps in - List.fold_left (fun e (v, mp) -> mk_lambda_from_var v mp e) e vars + List.fold_right (fun (v, mp) e -> mk_lambda_from_var v mp e) vars e let rec destruct_lambdas (e : texpression) : typed_pattern list * texpression = match e.e with diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index f0d1ca62..3a50e495 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -734,11 +734,15 @@ let rec translate_back_ty (type_infos : type_infos) None | TTraitType (trait_ref, generics, type_name) -> assert (generics.regions = []); - (* Translate the trait ref and the generics as "forward" generics - - we do not want to filter any type *) - let trait_ref = translate_fwd_trait_ref type_infos trait_ref in - let generics = translate_fwd_generic_args type_infos generics in - Some (TTraitType (trait_ref, generics, type_name)) + assert ( + AssociatedTypes.trait_instance_id_is_local_clause trait_ref.trait_id); + if inside_mut then + (* Translate the trait ref and the generics as "forward" generics - + we do not want to filter any type *) + let trait_ref = translate_fwd_trait_ref type_infos trait_ref in + let generics = translate_fwd_generic_args type_infos generics in + Some (TTraitType (trait_ref, generics, type_name)) + else None | TArrow _ -> raise (Failure "TODO") (** Simply calls [translate_back_ty] *) @@ -1056,7 +1060,21 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed Upon ending the abstraction for 'a, we need to get back the borrow the function returned. *) - List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] + let inputs = + List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] + in + log#ldebug + (lazy + (let ctx = Print.Contexts.decls_ctx_to_fmt_env decls_ctx in + let pctx = PrintPure.decls_ctx_to_fmt_env decls_ctx in + let output = Print.Types.ty_to_string ctx sg.output in + let inputs = + Print.list_to_string (PrintPure.ty_to_string pctx false) inputs + in + "translate_back_inputs_for_gid:" ^ "\n- gid: " + ^ RegionGroupId.to_string gid + ^ "\n- output: " ^ output ^ "\n- back inputs: " ^ inputs ^ "\n")); + inputs in let compute_back_outputs_for_gid (gid : RegionGroupId.id) : string option list * ty list = @@ -1080,7 +1098,21 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed let outputs = List.map (fun (name, opt_ty) -> (name, Option.get opt_ty)) outputs in - List.split outputs + let names, outputs = List.split outputs in + log#ldebug + (lazy + (let ctx = Print.Contexts.decls_ctx_to_fmt_env decls_ctx in + let pctx = PrintPure.decls_ctx_to_fmt_env decls_ctx in + let inputs = + Print.list_to_string (Print.Types.ty_to_string ctx) sg.inputs + in + let outputs = + Print.list_to_string (PrintPure.ty_to_string pctx false) outputs + in + "compute_back_outputs_for_gid:" ^ "\n- gid: " + ^ RegionGroupId.to_string gid + ^ "\n- inputs: " ^ inputs ^ "\n- back outputs: " ^ outputs ^ "\n")); + (names, outputs) in let compute_back_info_for_group (rg : T.region_var_group) : RegionGroupId.id * back_sg_info = @@ -1201,8 +1233,15 @@ let translate_fun_sig_from_decl_to_decomposed (decls_ctx : C.decls_ctx) (fun (v : LlbcAst.var) -> v.name) (LlbcAstUtils.fun_body_get_input_vars body) in - translate_fun_sig_to_decomposed decls_ctx fdef.def_id fdef.signature - input_names + let sg = + translate_fun_sig_to_decomposed decls_ctx fdef.def_id fdef.signature + input_names + in + log#ldebug + (lazy + ("translate_fun_sig_from_decl_to_decomposed:" ^ "\n- name: " + ^ T.show_name fdef.name ^ "\n- sg:\n" ^ show_decomposed_fun_sig sg ^ "\n")); + sg let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty = -- cgit v1.2.3