diff options
author | Son HO | 2023-08-07 10:42:15 +0200 |
---|---|---|
committer | GitHub | 2023-08-07 10:42:15 +0200 |
commit | 1cbc7ce007cf3433a6df9bdeb12c4e27511fad9c (patch) | |
tree | c15a16b591cf25df3ccff87ad4cd7c46ddecc489 /compiler | |
parent | 887d0ef1efc8912c6273b5ebcf979384e9d7fa97 (diff) | |
parent | 9e14cdeaf429e9faff2d1efdcf297c1ac7dc7f1f (diff) |
Merge pull request #32 from AeneasVerif/son_arrays
Add support for arrays/slices and const generics
Diffstat (limited to 'compiler')
41 files changed, 1935 insertions, 1053 deletions
diff --git a/compiler/Assumed.ml b/compiler/Assumed.ml index e751d0ba..11cd5666 100644 --- a/compiler/Assumed.ml +++ b/compiler/Assumed.ml @@ -42,6 +42,8 @@ module Sig = struct let rg_id_0 = T.RegionGroupId.of_int 0 let tvar_id_0 = T.TypeVarId.of_int 0 let tvar_0 : T.sty = T.TypeVar tvar_id_0 + let cgvar_id_0 = T.ConstGenericVarId.of_int 0 + let cgvar_0 : T.const_generic = T.ConstGenericVar cgvar_id_0 (** Region 'a of id 0 *) let region_param_0 : T.region_var = { T.index = rvar_id_0; name = Some "'a" } @@ -53,11 +55,25 @@ module Sig = struct (** Type parameter [T] of id 0 *) let type_param_0 : T.type_var = { T.index = tvar_id_0; name = "T" } + let usize_ty : T.sty = T.Literal (Integer Usize) + + (** Const generic parameter [const N : usize] of id 0 *) + let cg_param_0 : T.const_generic_var = + { T.index = cgvar_id_0; name = "N"; ty = Integer Usize } + + let empty_const_generic_params : T.const_generic_var list = [] + let mk_ref_ty (r : T.RegionVarId.id T.region) (ty : T.sty) (is_mut : bool) : T.sty = let ref_kind = if is_mut then T.Mut else T.Shared in mk_ref_ty r ty ref_kind + let mk_array_ty (ty : T.sty) (cg : T.const_generic) : T.sty = + Adt (Assumed Array, [], [ ty ], [ cg ]) + + let mk_slice_ty (ty : T.sty) : T.sty = Adt (Assumed Slice, [], [ ty ], []) + let range_ty : T.sty = Adt (Assumed Range, [], [ usize_ty ], []) + (** [fn<T>(&'a mut T, T) -> T] *) let mem_replace_sig : A.fun_sig = (* The signature fields *) @@ -73,6 +89,7 @@ module Sig = struct num_early_bound_regions = 0; regions_hierarchy; type_params; + const_generic_params = empty_const_generic_params; inputs; output; } @@ -84,6 +101,7 @@ module Sig = struct num_early_bound_regions = 0; regions_hierarchy = []; type_params = [ type_param_0 ] (* <T> *); + const_generic_params = empty_const_generic_params; inputs = [ tvar_0 (* T *) ]; output = mk_box_ty tvar_0 (* Box<T> *); } @@ -95,6 +113,7 @@ module Sig = struct num_early_bound_regions = 0; regions_hierarchy = []; type_params = [ type_param_0 ] (* <T> *); + const_generic_params = empty_const_generic_params; inputs = [ mk_box_ty tvar_0 (* Box<T> *) ]; output = mk_unit_ty (* () *); } @@ -112,6 +131,7 @@ module Sig = struct num_early_bound_regions = 0; regions_hierarchy; type_params = [ type_param_0 ] (* <T> *); + const_generic_params = empty_const_generic_params; inputs = [ mk_ref_ty rvar_0 (mk_box_ty tvar_0) is_mut (* &'a (mut) Box<T> *) ]; output = mk_ref_ty rvar_0 tvar_0 is_mut (* &'a (mut) T *); @@ -135,6 +155,7 @@ module Sig = struct num_early_bound_regions = 0; regions_hierarchy; type_params; + const_generic_params = empty_const_generic_params; inputs; output; } @@ -157,6 +178,7 @@ module Sig = struct num_early_bound_regions = 0; regions_hierarchy; type_params; + const_generic_params = empty_const_generic_params; inputs; output; } @@ -180,6 +202,7 @@ module Sig = struct num_early_bound_regions = 0; regions_hierarchy; type_params; + const_generic_params = empty_const_generic_params; inputs; output; } @@ -199,6 +222,7 @@ module Sig = struct num_early_bound_regions = 0; regions_hierarchy; type_params; + const_generic_params = empty_const_generic_params; inputs; output; } @@ -223,6 +247,7 @@ module Sig = struct num_early_bound_regions = 0; regions_hierarchy; type_params; + const_generic_params = empty_const_generic_params; inputs; output; } @@ -232,20 +257,132 @@ module Sig = struct (** [fn<T>(&'a mut Vec<T>, usize) -> &'a mut T] *) let vec_index_mut_sig : A.fun_sig = vec_index_gen_sig true + + (** Array/slice functions *) + + (** Small helper. + + Return the type: + {[ + fn<'a, T>(&'a input_ty, index_ty) -> &'a output_ty + ]} + + Remarks: + The [input_ty] and [output_ty] are parameterized by a type variable id. + The [mut_borrow] boolean controls whether we use a shared or a mutable + borrow. + *) + let mk_array_slice_borrow_sig (cgs : T.const_generic_var list) + (input_ty : T.TypeVarId.id -> T.sty) (index_ty : T.sty option) + (output_ty : T.TypeVarId.id -> T.sty) (is_mut : bool) : A.fun_sig = + (* The signature fields *) + let region_params = [ region_param_0 ] in + let regions_hierarchy = [ region_group_0 ] (* <'a> *) in + let type_params = [ type_param_0 ] (* <T> *) in + let inputs = + [ + mk_ref_ty rvar_0 + (input_ty type_param_0.index) + is_mut (* &'a (mut) input_ty<T> *); + ] + in + let inputs = + List.append inputs (match index_ty with None -> [] | Some ty -> [ ty ]) + in + let output = + mk_ref_ty rvar_0 + (output_ty type_param_0.index) + is_mut (* &'a (mut) output_ty<T> *) + in + { + region_params; + num_early_bound_regions = 0; + regions_hierarchy; + type_params; + const_generic_params = cgs; + inputs; + output; + } + + let mk_array_slice_index_sig (is_array : bool) (is_mut : bool) : A.fun_sig = + (* Array<T, N> *) + let input_ty id = + if is_array then mk_array_ty (T.TypeVar id) cgvar_0 + else mk_slice_ty (T.TypeVar id) + in + (* usize *) + let index_ty = usize_ty in + (* T *) + let output_ty id = T.TypeVar id in + let cgs = if is_array then [ cg_param_0 ] else [] in + mk_array_slice_borrow_sig cgs input_ty (Some index_ty) output_ty is_mut + + let array_index_sig (is_mut : bool) = mk_array_slice_index_sig true is_mut + let slice_index_sig (is_mut : bool) = mk_array_slice_index_sig false is_mut + + let array_to_slice_sig (is_mut : bool) : A.fun_sig = + (* Array<T, N> *) + let input_ty id = mk_array_ty (T.TypeVar id) cgvar_0 in + (* Slice<T> *) + let output_ty id = mk_slice_ty (T.TypeVar id) in + let cgs = [ cg_param_0 ] in + mk_array_slice_borrow_sig cgs input_ty None output_ty is_mut + + let mk_array_slice_subslice_sig (is_array : bool) (is_mut : bool) : A.fun_sig + = + (* Array<T, N> *) + let input_ty id = + if is_array then mk_array_ty (T.TypeVar id) cgvar_0 + else mk_slice_ty (T.TypeVar id) + in + (* Range *) + let index_ty = range_ty in + (* Slice<T> *) + let output_ty id = mk_slice_ty (T.TypeVar id) in + let cgs = if is_array then [ cg_param_0 ] else [] in + mk_array_slice_borrow_sig cgs input_ty (Some index_ty) output_ty is_mut + + let array_subslice_sig (is_mut : bool) = + mk_array_slice_subslice_sig true is_mut + + let slice_subslice_sig (is_mut : bool) = + mk_array_slice_subslice_sig false is_mut + + (** Helper: + [fn<T>(&'a [T]) -> usize] + *) + let slice_len_sig : A.fun_sig = + (* The signature fields *) + let region_params = [ region_param_0 ] in + let regions_hierarchy = [ region_group_0 ] (* <'a> *) in + let type_params = [ type_param_0 ] (* <T> *) in + let inputs = + [ mk_ref_ty rvar_0 (mk_slice_ty tvar_0) false (* &'a [T] *) ] + in + let output = mk_usize_ty (* usize *) in + { + region_params; + num_early_bound_regions = 0; + regions_hierarchy; + type_params; + const_generic_params = empty_const_generic_params; + inputs; + output; + } end type assumed_info = A.assumed_fun_id * A.fun_sig * bool * name (** The list of assumed functions and all their information: - their signature - - a boolean indicating whether the function can fail or not + - a boolean indicating whether the function can fail or not (if true: can fail) - their name Rk.: following what is written above, we don't include [Box::free]. - + Remark about the vector functions: for [Vec::len] to be correct and return a [usize], we have to make sure that vectors are bounded by the max usize. - Followingly, [Vec::push] is monadic. + As a consequence, [Vec::push] is monadic. *) let assumed_infos : assumed_info list = let deref_pre = [ "core"; "ops"; "deref" ] in @@ -278,6 +415,46 @@ let assumed_infos : assumed_info list = Sig.vec_index_mut_sig, true, to_name (index_pre @ [ "IndexMut"; "index_mut" ]) ); + (* Array Index *) + ( ArrayIndexShared, + Sig.array_index_sig false, + true, + to_name [ "@ArrayIndexShared" ] ); + (ArrayIndexMut, Sig.array_index_sig true, true, to_name [ "@ArrayIndexMut" ]); + (* Array to slice*) + ( ArrayToSliceShared, + Sig.array_to_slice_sig false, + true, + to_name [ "@ArrayToSliceShared" ] ); + ( ArrayToSliceMut, + Sig.array_to_slice_sig true, + true, + to_name [ "@ArrayToSliceMut" ] ); + (* Array Subslice *) + ( ArraySubsliceShared, + Sig.array_subslice_sig false, + true, + to_name [ "@ArraySubsliceShared" ] ); + ( ArraySubsliceMut, + Sig.array_subslice_sig true, + true, + to_name [ "@ArraySubsliceMut" ] ); + (* Slice Index *) + ( SliceIndexShared, + Sig.slice_index_sig false, + true, + to_name [ "@SliceIndexShared" ] ); + (SliceIndexMut, Sig.slice_index_sig true, true, to_name [ "@SliceIndexMut" ]); + (* Slice Subslice *) + ( SliceSubsliceShared, + Sig.slice_subslice_sig false, + true, + to_name [ "@SliceSubsliceShared" ] ); + ( SliceSubsliceMut, + Sig.slice_subslice_sig true, + true, + to_name [ "@SliceSubsliceMut" ] ); + (SliceLen, Sig.slice_len_sig, false, to_name [ "@SliceLen" ]); ] let get_assumed_info (id : A.assumed_fun_id) : assumed_info = diff --git a/compiler/Config.ml b/compiler/Config.ml index 0475899c..bd80769f 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -162,6 +162,14 @@ let backward_no_state_update = ref false *) let split_files = ref true +(** Generate the library entry point, if the crate is split between different files. + + Sometimes we want to skip this: the library entry points just includes all the + files in the project, and the user may want to write their own entry point, to + add custom includes (to include the files containing the proofs, for instance). + *) +let generate_lib_entry_point = ref true + (** For Lean, controls whether we generate a lakefile or not. *) diff --git a/compiler/Contexts.ml b/compiler/Contexts.ml index a425d42b..2ca5653d 100644 --- a/compiler/Contexts.ml +++ b/compiler/Contexts.ml @@ -12,7 +12,8 @@ open Identifiers in the environment, because they contain borrows for instance, typically because they might be overwritten during an assignment. *) -module DummyVarId = IdGen () +module DummyVarId = +IdGen () type dummy_var_id = DummyVarId.id [@@deriving show, ord] @@ -261,6 +262,7 @@ type eval_ctx = { global_context : global_context; region_groups : RegionGroupId.id list; type_vars : type_var list; + const_generic_vars : const_generic_var list; env : env; ended_regions : RegionId.Set.t; } @@ -269,6 +271,10 @@ type eval_ctx = { let lookup_type_var (ctx : eval_ctx) (vid : TypeVarId.id) : type_var = TypeVarId.nth ctx.type_vars vid +let lookup_const_generic_var (ctx : eval_ctx) (vid : ConstGenericVarId.id) : + const_generic_var = + ConstGenericVarId.nth ctx.const_generic_vars vid + (** Lookup a variable in the current frame *) let env_lookup_var (env : env) (vid : VarId.id) : var_binder * typed_value = (* We take care to stop at the end of current frame: different variables diff --git a/compiler/Driver.ml b/compiler/Driver.ml index 166ef11b..d819768b 100644 --- a/compiler/Driver.ml +++ b/compiler/Driver.ml @@ -107,6 +107,10 @@ let () = Arg.Clear check_invariants, " Deactivate the invariant sanity checks performed at every evaluation \ step. Dramatically increases speed." ); + ( "-no-gen-lib-entry", + Arg.Clear generate_lib_entry_point, + " Do not generate the entry point file for the generated library (only \ + valid if the crate is split between different files)" ); ( "-lean-default-lakefile", Arg.Clear lean_gen_lakefile, " Generate a default lakefile.lean (Lean only)" ); @@ -133,6 +137,8 @@ let () = (not !use_fuel) || (not !extract_decreases_clauses) && not !extract_template_decreases_clauses); + (* We have: not generate_lib_entry_point ==> split_files *) + assert (!split_files || !generate_lib_entry_point); if !lean_gen_lakefile && not (!backend = Lean) then log#error "The -lean-default-lakefile option is valid only for the Lean backend"; diff --git a/compiler/Extract.ml b/compiler/Extract.ml index b16f9639..c4238d83 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -41,7 +41,9 @@ let unop_name (unop : unop) : string = match !backend with FStar | Lean -> "not" | Coq -> "negb" | HOL4 -> "~") | Neg (int_ty : integer_type) -> ( match !backend with Lean -> "-" | _ -> int_name int_ty ^ "_neg") - | Cast _ -> raise (Failure "Unsupported") + | Cast _ -> + (* We never directly use the unop name in this case *) + raise (Failure "Unsupported") (** Small helper to compute the name of a binary operation (note that many binary operations like "less than" are extracted to primitive operations, @@ -216,44 +218,53 @@ let keywords () = List.concat [ named_unops; named_binops; misc ] let assumed_adts () : (assumed_ty * string) list = - List.map - (fun (t, s) -> - if !backend = Lean then - ( t, - Printf.sprintf "%c%s" - (Char.uppercase_ascii s.[0]) - (String.sub s 1 (String.length s - 1)) ) - else (t, s)) - (match !backend with - | Lean -> - [ - (State, "State"); - (Result, "Result"); - (Error, "Error"); - (Fuel, "Nat"); - (Option, "Option"); - (Vec, "Vec"); - ] - | Coq | FStar -> - [ - (State, "state"); - (Result, "result"); - (Error, "error"); - (Fuel, "nat"); - (Option, "option"); - (Vec, "vec"); - ] - | HOL4 -> - [ - (State, "state"); - (Result, "result"); - (Error, "error"); - (Fuel, "num"); - (Option, "option"); - (Vec, "vec"); - ]) + match !backend with + | Lean -> + [ + (State, "State"); + (Result, "Result"); + (Error, "Error"); + (Fuel, "Nat"); + (Option, "Option"); + (Vec, "Vec"); + (Array, "Array"); + (Slice, "Slice"); + (Str, "Str"); + (Range, "Range"); + ] + | Coq | FStar -> + [ + (State, "state"); + (Result, "result"); + (Error, "error"); + (Fuel, "nat"); + (Option, "option"); + (Vec, "vec"); + (Array, "array"); + (Slice, "slice"); + (Str, "str"); + (Range, "range"); + ] + | HOL4 -> + [ + (State, "state"); + (Result, "result"); + (Error, "error"); + (Fuel, "num"); + (Option, "option"); + (Vec, "vec"); + (Array, "array"); + (Slice, "slice"); + (Str, "str"); + (Range, "range"); + ] -let assumed_structs : (assumed_ty * string) list = [] +let assumed_struct_constructors () : (assumed_ty * string) list = + match !backend with + | Lean -> [ (Range, "Range.mk"); (Array, "Array.make") ] + | Coq -> [ (Range, "mk_range"); (Array, "mk_array") ] + | FStar -> [ (Range, "Mkrange"); (Array, "mk_array") ] + | HOL4 -> [ (Range, "mk_range"); (Array, "mk_array") ] let assumed_variants () : (assumed_ty * VariantId.id * string) list = match !backend with @@ -318,6 +329,22 @@ let assumed_llbc_functions () : (VecIndex, rg0, "vec_index_back") (* shouldn't be used *); (VecIndexMut, None, "vec_index_mut_fwd"); (VecIndexMut, rg0, "vec_index_mut_back"); + (ArrayIndexShared, None, "array_index_shared"); + (ArrayIndexMut, None, "array_index_mut_fwd"); + (ArrayIndexMut, rg0, "array_index_mut_back"); + (ArrayToSliceShared, None, "array_to_slice_shared"); + (ArrayToSliceMut, None, "array_to_slice_mut_fwd"); + (ArrayToSliceMut, rg0, "array_to_slice_mut_back"); + (ArraySubsliceShared, None, "array_subslice_shared"); + (ArraySubsliceMut, None, "array_subslice_mut_fwd"); + (ArraySubsliceMut, rg0, "array_subslice_mut_back"); + (SliceIndexShared, None, "slice_index_shared"); + (SliceIndexMut, None, "slice_index_mut_fwd"); + (SliceIndexMut, rg0, "slice_index_mut_back"); + (SliceSubsliceShared, None, "slice_subslice_shared"); + (SliceSubsliceMut, None, "slice_subslice_mut_fwd"); + (SliceSubsliceMut, rg0, "slice_subslice_mut_back"); + (SliceLen, None, "slice_len"); ] | Lean -> [ @@ -329,10 +356,26 @@ let assumed_llbc_functions () : (VecInsert, None, "Vec.insert_fwd") (* Shouldn't be used *); (VecInsert, rg0, "Vec.insert"); (VecLen, None, "Vec.len"); - (VecIndex, None, "Vec.index"); - (VecIndex, rg0, "Vec.index_back") (* shouldn't be used *); + (VecIndex, None, "Vec.index_shared"); + (VecIndex, rg0, "Vec.index_shared_back") (* shouldn't be used *); (VecIndexMut, None, "Vec.index_mut"); (VecIndexMut, rg0, "Vec.index_mut_back"); + (ArrayIndexShared, None, "Array.index_shared"); + (ArrayIndexMut, None, "Array.index_mut"); + (ArrayIndexMut, rg0, "Array.index_mut_back"); + (ArrayToSliceShared, None, "Array.to_slice_shared"); + (ArrayToSliceMut, None, "Array.to_slice_mut"); + (ArrayToSliceMut, rg0, "Array.to_slice_mut_back"); + (ArraySubsliceShared, None, "Array.subslice_shared"); + (ArraySubsliceMut, None, "Array.subslice_mut"); + (ArraySubsliceMut, rg0, "Array.subslice_mut_back"); + (SliceIndexShared, None, "Slice.index_shared"); + (SliceIndexMut, None, "Slice.index_mut"); + (SliceIndexMut, rg0, "Slice.index_mut_back"); + (SliceSubsliceShared, None, "Slice.subslice_shared"); + (SliceSubsliceMut, None, "Slice.subslice_mut"); + (SliceSubsliceMut, rg0, "Slice.subslice_mut_back"); + (SliceLen, None, "Slice.len"); ] let assumed_pure_functions () : (pure_assumed_fun_id * string) list = @@ -359,7 +402,7 @@ let names_map_init () : names_map_init = { keywords = keywords (); assumed_adts = assumed_adts (); - assumed_structs; + assumed_structs = assumed_struct_constructors (); assumed_variants = assumed_variants (); assumed_llbc_functions = assumed_llbc_functions (); assumed_pure_functions = assumed_pure_functions (); @@ -722,7 +765,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) | None -> ( (* No basename: we use the first letter of the type *) match ty with - | Adt (type_id, tys) -> ( + | Adt (type_id, tys, _) -> ( match type_id with | Tuple -> (* The "pair" case is frequent enough to have its special treatment *) @@ -732,6 +775,10 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) | Assumed Fuel -> ConstStrings.fuel_basename | Assumed Option -> "opt" | Assumed Vec -> "v" + | Assumed Array -> "a" + | Assumed Slice -> "s" + | Assumed Str -> "s" + | Assumed Range -> "r" | Assumed State -> ConstStrings.state_basename | AdtId adt_id -> let def = @@ -757,12 +804,9 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) match !backend with | FStar -> "x" (* lacking inspiration here... *) | Coq | Lean | HOL4 -> "t" (* lacking inspiration here... *)) - | Bool -> "b" - | Char -> "c" - | Integer _ -> "i" - | Str -> "s" - | Arrow _ -> "f" - | Array _ | Slice _ -> raise Unimplemented) + | Literal lty -> ( + match lty with Bool -> "b" | Char -> "c" | Integer _ -> "i") + | Arrow _ -> "f") in let type_var_basename (_varset : StringSet.t) (basename : string) : string = (* Rust type variables are snake-case and start with a capital letter *) @@ -775,12 +819,21 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) "'" ^ to_snake_case basename | Coq | Lean -> basename in + let const_generic_var_basename (_varset : StringSet.t) (basename : string) : + string = + (* Rust type variables are snake-case and start with a capital letter *) + match !backend with + | FStar | HOL4 -> + (* This is *not* a no-op: this removes the capital letter *) + to_snake_case basename + | Coq | Lean -> basename + in let append_index (basename : string) (i : int) : string = basename ^ string_of_int i in - let extract_primitive_value (fmt : F.formatter) (inside : bool) - (cv : primitive_value) : unit = + let extract_literal (fmt : F.formatter) (inside : bool) (cv : literal) : unit + = match cv with | Scalar sv -> ( match !backend with @@ -847,14 +900,6 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) in F.pp_print_string fmt c; if inside then F.pp_print_string fmt ")") - | String s -> - (* We need to replace all the line breaks *) - let s = - StringUtils.map - (fun c -> if c = '\n' then "\n" else String.make 1 c) - s - in - F.pp_print_string fmt ("\"" ^ s ^ "\"") in let bool_name = if !backend = Lean then "Bool" else "bool" in let char_name = if !backend = Lean then "Char" else "char" in @@ -877,8 +922,9 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) opaque_pre; var_basename; type_var_basename; + const_generic_var_basename; append_index; - extract_primitive_value; + extract_literal; extract_unop; extract_binop; } @@ -1043,6 +1089,24 @@ let extract_arrow (fmt : F.formatter) () : unit = if !Config.backend = Lean then F.pp_print_string fmt "→" else F.pp_print_string fmt "->" +let extract_const_generic (ctx : extraction_ctx) (fmt : F.formatter) + (inside : bool) (cg : const_generic) : unit = + match cg with + | ConstGenericGlobal id -> + let s = ctx_get_global ctx.use_opaque_pre id ctx in + F.pp_print_string fmt s + | ConstGenericValue v -> ctx.fmt.extract_literal fmt inside v + | ConstGenericVar id -> + let s = ctx_get_const_generic_var id ctx in + F.pp_print_string fmt s + +let extract_literal_type (ctx : extraction_ctx) (fmt : F.formatter) + (ty : literal_type) : unit = + match ty with + | Bool -> F.pp_print_string fmt ctx.fmt.bool_name + | Char -> F.pp_print_string fmt ctx.fmt.char_name + | Integer int_ty -> F.pp_print_string fmt (ctx.fmt.int_name int_ty) + (** [inside] constrols whether we should add parentheses or not around type applications (if [true] we add parentheses). @@ -1067,7 +1131,8 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (no_params_tys : TypeDeclId.Set.t) (inside : bool) (ty : ty) : unit = let extract_rec = extract_ty ctx fmt no_params_tys in match ty with - | Adt (type_id, tys) -> ( + | Adt (type_id, tys, cgs) -> ( + let has_params = tys <> [] || cgs <> [] in match type_id with | Tuple -> (* This is a bit annoying, but in F*/Coq/HOL4 [()] is not the unit type: @@ -1099,7 +1164,7 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) let with_opaque_pre = false in match !backend with | FStar | Coq | Lean -> - let print_paren = inside && tys <> [] in + let print_paren = inside && has_params in if print_paren then F.pp_print_string fmt "("; (* TODO: for now, only the opaque *functions* are extracted in the opaque module. The opaque *types* are assumed. *) @@ -1108,8 +1173,15 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_space fmt (); Collections.List.iter_link (F.pp_print_space fmt) (extract_rec true) tys); + if cgs <> [] then ( + F.pp_print_space fmt (); + Collections.List.iter_link (F.pp_print_space fmt) + (extract_const_generic ctx fmt true) + cgs); if print_paren then F.pp_print_string fmt ")" | HOL4 -> + (* Const generics are unsupported in HOL4 *) + assert (cgs = []); let print_tys = match type_id with | AdtId id -> not (TypeDeclId.Set.mem id no_params_tys) @@ -1128,10 +1200,7 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_space fmt ()); F.pp_print_string fmt (ctx_get_type with_opaque_pre type_id ctx))) | TypeVar vid -> F.pp_print_string fmt (ctx_get_type_var vid ctx) - | Bool -> F.pp_print_string fmt ctx.fmt.bool_name - | Char -> F.pp_print_string fmt ctx.fmt.char_name - | Integer int_ty -> F.pp_print_string fmt (ctx.fmt.int_name int_ty) - | Str -> F.pp_print_string fmt ctx.fmt.str_name + | Literal lty -> extract_literal_type ctx fmt lty | Arrow (arg_ty, ret_ty) -> if inside then F.pp_print_string fmt "("; extract_rec false arg_ty; @@ -1140,7 +1209,6 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_space fmt (); extract_rec false ret_ty; if inside then F.pp_print_string fmt ")" - | Array _ | Slice _ -> raise Unimplemented (** Compute the names for all the top-level identifiers used in a type definition (type name, variant names, field names, etc. but not type @@ -1182,8 +1250,8 @@ let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) : (** Print the variants *) let extract_type_decl_variant (ctx : extraction_ctx) (fmt : F.formatter) (type_decl_group : TypeDeclId.Set.t) (type_name : string) - (type_params : string list) (cons_name : string) (fields : field list) : - unit = + (type_params : string list) (cg_params : string list) (cons_name : string) + (fields : field list) : unit = F.pp_print_space fmt (); (* variant box *) F.pp_open_hvbox fmt ctx.indent_incr; @@ -1235,16 +1303,18 @@ let extract_type_decl_variant (ctx : extraction_ctx) (fmt : F.formatter) let _ = List.fold_left (fun ctx (fid, f) -> print_field fid f ctx) ctx fields in + (* Sanity check: HOL4 doesn't support const generics *) + assert (cg_params = [] || !backend <> HOL4); (* Print the final type *) if !backend <> HOL4 then ( F.pp_print_space fmt (); F.pp_open_hovbox fmt 0; F.pp_print_string fmt type_name; List.iter - (fun type_param -> + (fun p -> F.pp_print_space fmt (); - F.pp_print_string fmt type_param) - type_params; + F.pp_print_string fmt p) + (List.append type_params cg_params); F.pp_close_box fmt ()); (* Close the variant box *) F.pp_close_box fmt () @@ -1252,7 +1322,8 @@ let extract_type_decl_variant (ctx : extraction_ctx) (fmt : F.formatter) (* TODO: we don' need the [def_name] paramter: it can be retrieved from the context *) let extract_type_decl_enum_body (ctx : extraction_ctx) (fmt : F.formatter) (type_decl_group : TypeDeclId.Set.t) (def : type_decl) (def_name : string) - (type_params : string list) (variants : variant list) : unit = + (type_params : string list) (cg_params : string list) + (variants : variant list) : unit = (* We want to generate a definition which looks like this (taking F* as example): {[ type list a = | Cons : a -> list a -> list a | Nil : list a @@ -1291,7 +1362,7 @@ let extract_type_decl_enum_body (ctx : extraction_ctx) (fmt : F.formatter) let cons_name = ctx.fmt.variant_name def.name v.variant_name in let fields = v.fields in extract_type_decl_variant ctx fmt type_decl_group def_name type_params - cons_name fields + cg_params cons_name fields in (* Print the variants *) let variants = VariantId.mapi (fun vid v -> (vid, v)) variants in @@ -1299,7 +1370,8 @@ let extract_type_decl_enum_body (ctx : extraction_ctx) (fmt : F.formatter) let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) (type_decl_group : TypeDeclId.Set.t) (kind : decl_kind) (def : type_decl) - (type_params : string list) (fields : field list) : unit = + (type_params : string list) (cg_params : string list) (fields : field list) + : unit = (* We want to generate a definition which looks like this (taking F* as example): {[ type t = { x : int; y : bool; } @@ -1426,7 +1498,7 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) in let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in extract_type_decl_variant ctx fmt type_decl_group def_name type_params - cons_name fields) + cg_params cons_name fields) in () @@ -1473,19 +1545,25 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) else None in (* If in Coq and the declaration is opaque, it must have the shape: - [Axiom Ident : forall (T0 ... Tn : Type), ... -> ... -> ...]. + [Axiom Ident : forall (T0 ... Tn : Type) (N0 : ...) ... (Nn : ...), ... -> ... -> ...]. The boolean [is_opaque_coq] is used to detect this case. *) let is_opaque = type_kind = None in let is_opaque_coq = !backend = Coq && is_opaque in - let use_forall = is_opaque_coq && def.type_params <> [] in + let use_forall = + is_opaque_coq && (def.type_params <> [] || def.const_generic_params <> []) + in (* Retrieve the definition name *) let with_opaque_pre = false in let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in - (* Add the type params - note that we need those bindings only for the + (* Add the type and const generic params - note that we need those bindings only for the * body translation (they are not top-level) *) - let ctx_body, type_params = ctx_add_type_params def.type_params ctx in + let ctx_body, type_params, cg_params = + ctx_add_type_const_generic_params def.type_params def.const_generic_params + ctx + in + let ty_cg_params = List.append type_params cg_params in (* Add a break before *) if !backend <> HOL4 || not (decl_is_first_from_group kind) then F.pp_print_break fmt 0 0; @@ -1498,31 +1576,47 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (match !backend with | Coq | FStar | HOL4 -> F.pp_open_hvbox fmt 0 | Lean -> F.pp_open_vbox fmt 0); - (* Open a box for "type TYPE_NAME (TYPE_PARAMS) =" *) + (* Open a box for "type TYPE_NAME (TYPE_PARAMS CONST_GEN_PARAMS) =" *) F.pp_open_hovbox fmt ctx.indent_incr; (* > "type TYPE_NAME" *) let qualif = ctx.fmt.type_decl_kind_to_qualif kind type_kind in (match qualif with | Some qualif -> F.pp_print_string fmt (qualif ^ " " ^ def_name) | None -> F.pp_print_string fmt def_name); - (* Print the type parameters *) - if def.type_params <> [] && !backend <> HOL4 then ( + (* HOL4 doesn't support const generics *) + assert (cg_params = [] || !backend <> HOL4); + (* Print the type/const generic parameters *) + if ty_cg_params <> [] && !backend <> HOL4 then ( if use_forall then ( F.pp_print_space fmt (); F.pp_print_string fmt ":"; F.pp_print_space fmt (); F.pp_print_string fmt "forall"); - F.pp_print_space fmt (); - F.pp_print_string fmt "("; + (* Print the type parameters *) + if type_params <> [] then ( + F.pp_print_space fmt (); + F.pp_print_string fmt "("; + List.iter + (fun s -> + F.pp_print_string fmt s; + F.pp_print_space fmt ()) + type_params; + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + F.pp_print_string fmt (type_keyword () ^ ")")); + (* Print the const generic parameters *) List.iter - (fun (p : type_var) -> - let pname = ctx_get_type_var p.index ctx_body in - F.pp_print_string fmt pname; - F.pp_print_space fmt ()) - def.type_params; - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - F.pp_print_string fmt (type_keyword () ^ ")")); + (fun (var : const_generic_var) -> + F.pp_print_space fmt (); + F.pp_print_string fmt "("; + let n = ctx_get_const_generic_var var.index ctx in + F.pp_print_string fmt n; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_literal_type ctx fmt var.ty; + F.pp_print_string fmt ")") + def.const_generic_params); (* Print the "=" if we extract the body*) if extract_body then ( F.pp_print_space fmt (); @@ -1551,10 +1645,10 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) match def.kind with | Struct fields -> extract_type_decl_struct_body ctx_body fmt type_decl_group kind def - type_params fields + type_params cg_params fields | Enum variants -> extract_type_decl_enum_body ctx_body fmt type_decl_group def def_name - type_params variants + type_params cg_params variants | Opaque -> raise (Failure "Unreachable")); (* Add the definition end delimiter *) if !backend = HOL4 && decl_is_not_last_from_group kind then ( @@ -1581,6 +1675,8 @@ let extract_type_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter) (* Retrieve the definition name *) let with_opaque_pre = false in let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in + (* Generic parameters are unsupported *) + assert (def.const_generic_params = []); (* Count the number of parameters *) let num_params = List.length def.type_params in (* Generate the declaration *) @@ -1604,6 +1700,7 @@ let extract_type_decl_hol4_empty_record (ctx : extraction_ctx) let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in (* Sanity check *) assert (def.type_params = []); + assert (def.const_generic_params = []); (* Generate the declaration *) F.pp_print_space fmt (); F.pp_print_string fmt ("Type " ^ def_name ^ " = “: unit”"); @@ -1643,11 +1740,14 @@ let extract_type_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter) (kind : decl_kind) (decl : type_decl) : unit = assert (!backend = Coq); (* Generating the [Arguments] instructions is useful only if there are type parameters *) - if decl.type_params = [] then () + if decl.type_params = [] && decl.const_generic_params = [] then () else (* Add the type params - note that we need those bindings only for the * body translation (they are not top-level) *) - let _ctx_body, type_params = ctx_add_type_params decl.type_params ctx in + let _ctx_body, type_params, cg_params = + ctx_add_type_const_generic_params decl.type_params + decl.const_generic_params ctx + in (* Auxiliary function to extract an [Arguments Cons {T} _ _.] instruction *) let extract_arguments_info (cons_name : string) (fields : 'a list) : unit = (* Add a break before *) @@ -1655,12 +1755,12 @@ let extract_type_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter) (* Open a box *) F.pp_open_hovbox fmt ctx.indent_incr; (* Small utility *) - let print_type_vars () = + let print_vars () = List.iter (fun (var : string) -> F.pp_print_space fmt (); F.pp_print_string fmt ("{" ^ var ^ "}")) - type_params + (List.append type_params cg_params) in let print_fields () = List.iter @@ -1673,7 +1773,7 @@ let extract_type_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "Arguments"; F.pp_print_space fmt (); F.pp_print_string fmt cons_name; - print_type_vars (); + print_vars (); print_fields (); F.pp_print_string fmt "."; @@ -1727,7 +1827,10 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx) let is_rec = decl_is_from_rec_group kind in if is_rec then (* Add the type params *) - let ctx, type_params = ctx_add_type_params decl.type_params ctx in + let ctx, type_params, cg_params = + ctx_add_type_const_generic_params decl.type_params + decl.const_generic_params ctx + in let ctx, record_var = ctx_add_var "x" (VarId.of_int 0) ctx in let ctx, field_var = ctx_add_var "x" (VarId.of_int 1) ctx in let with_opaque_pre = false in @@ -1760,6 +1863,19 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx) F.pp_print_space fmt (); F.pp_print_string fmt "Type}"; F.pp_print_space fmt ()); + (* Print the const generic parameters *) + if cg_params <> [] then + List.iter + (fun (v : const_generic_var) -> + F.pp_print_string fmt "{"; + let n = ctx_get_const_generic_var v.index ctx in + F.pp_print_string fmt n; + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_literal_type ctx fmt v.ty; + F.pp_print_string fmt "}"; + F.pp_print_space fmt ()) + decl.const_generic_params; (* Print the record parameter *) F.pp_print_string fmt "("; F.pp_print_string fmt record_var; @@ -1948,16 +2064,6 @@ let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx) (* Add breaks to insert new lines between definitions *) F.pp_print_break fmt 0 0 -(** In some provers like HOL4, we don't print the type parameters when calling - functions (and the polymorphism is uniform...). - - TODO: we may want to check that the polymorphism is indeed uniform when - generating code for such backends, and print at least a warning to the - user when it is not the case. - *) -let print_fun_type_params () : bool = - match !backend with FStar | Coq | Lean -> true | HOL4 -> false - (** Compute the names for all the pure functions generated from a rust function (forward function and backward functions). *) @@ -2016,10 +2122,11 @@ let extract_adt_g_value (inside : bool) (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) : extraction_ctx = match ty with - | Adt (Tuple, type_args) -> + | Adt (Tuple, type_args, cg_args) -> (* Tuple *) (* For now, we only support fully applied tuple constructors *) assert (List.length type_args = List.length field_values); + assert (cg_args = []); (* This is very annoying: in Coq, we can't write [()] for the value of type [unit], we have to write [tt]. *) if !backend = Coq && field_values = [] then ( @@ -2037,7 +2144,7 @@ let extract_adt_g_value in F.pp_print_string fmt ")"; ctx) - | Adt (adt_id, _) -> + | Adt (adt_id, _, _) -> (* "Regular" ADT *) (* If we are generating a pattern for a let-binding and we target Lean, @@ -2107,7 +2214,7 @@ let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter) (is_let : bool) (inside : bool) (v : typed_pattern) : extraction_ctx = match v.value with | PatConstant cv -> - ctx.fmt.extract_primitive_value fmt inside cv; + ctx.fmt.extract_literal fmt inside cv; ctx | PatVar (v, _) -> let vname = @@ -2142,7 +2249,7 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) | Var var_id -> let var_name = ctx_get_var var_id ctx in F.pp_print_string fmt var_name - | Const cv -> ctx.fmt.extract_primitive_value fmt inside cv + | Const cv -> ctx.fmt.extract_literal fmt inside cv | App _ -> let app, args = destruct_apps e in extract_App ctx fmt inside app args @@ -2155,7 +2262,7 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) | Let (_, _, _, _) -> extract_lets ctx fmt inside e | Switch (scrut, body) -> extract_Switch ctx fmt inside scrut body | Meta (_, e) -> extract_texpression ctx fmt inside e - | StructUpdate supd -> extract_StructUpdate ctx fmt inside supd + | StructUpdate supd -> extract_StructUpdate ctx fmt inside e.ty supd | Loop _ -> (* The loop nodes should have been eliminated in {!PureMicroPasses} *) raise (Failure "Unreachable") @@ -2172,10 +2279,12 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (* Top-level qualifier *) match qualif.id with | FunOrOp fun_id -> - extract_function_call ctx fmt inside fun_id qualif.type_args args + extract_function_call ctx fmt inside fun_id qualif.type_args + qualif.const_generic_args args | Global global_id -> extract_global ctx fmt global_id | AdtCons adt_cons_id -> - extract_adt_cons ctx fmt inside adt_cons_id qualif.type_args args + extract_adt_cons ctx fmt inside adt_cons_id qualif.type_args + qualif.const_generic_args args | Proj proj -> extract_field_projector ctx fmt inside app proj qualif.type_args args) | _ -> @@ -2201,7 +2310,7 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (** Subcase of the app case: function call *) and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (fid : fun_or_op_id) (type_args : ty list) - (args : texpression list) : unit = + (cg_args : const_generic list) (args : texpression list) : unit = match (fid, args) with | Unop unop, [ arg ] -> (* A unop can have *at most* one argument (the result can't be a function!). @@ -2222,13 +2331,20 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) let with_opaque_pre = ctx.use_opaque_pre in let fun_name = ctx_get_function with_opaque_pre fun_id ctx in F.pp_print_string fmt fun_name; + (* Sanity check: HOL4 doesn't support const generics *) + assert (cg_args = [] || !backend <> HOL4); (* Print the type parameters, if the backend is not HOL4 *) - if !backend <> HOL4 then + if !backend <> HOL4 then ( List.iter (fun ty -> F.pp_print_space fmt (); extract_ty ctx fmt TypeDeclId.Set.empty true ty) type_args; + List.iter + (fun cg -> + F.pp_print_space fmt (); + extract_const_generic ctx fmt true cg) + cg_args); (* Print the arguments *) List.iter (fun ve -> @@ -2250,9 +2366,9 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) (** Subcase of the app case: ADT constructor *) and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) - (adt_cons : adt_cons_id) (type_args : ty list) (args : texpression list) : - unit = - let e_ty = Adt (adt_cons.adt_id, type_args) in + (adt_cons : adt_cons_id) (type_args : ty list) + (cg_args : const_generic list) (args : texpression list) : unit = + let e_ty = Adt (adt_cons.adt_id, type_args, cg_args) in let is_single_pat = false in let _ = extract_adt_g_value @@ -2609,104 +2725,152 @@ and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (_inside : bool) F.pp_close_box fmt () and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter) - (inside : bool) (supd : struct_update) : unit = + (inside : bool) (e_ty : ty) (supd : struct_update) : unit = (* We can't update a subset of the fields in Coq (i.e., we can do [{| x:= 3; y := 4 |}], but there is no syntax for [{| s with x := 3 |}]) *) assert (!backend <> Coq || supd.init = None); (* In the case of HOL4, records with no fields are not supported and are thus extracted to unit. We need to check that by looking up the definition *) let extract_as_unit = - if !backend = HOL4 then - let d = - TypeDeclId.Map.find supd.struct_id ctx.trans_ctx.type_context.type_decls - in - d.kind = Struct [] - else false + match (!backend, supd.struct_id) with + | HOL4, AdtId adt_id -> + let d = + TypeDeclId.Map.find adt_id ctx.trans_ctx.type_context.type_decls + in + d.kind = Struct [] + | _ -> false in if extract_as_unit then (* Remark: this is only valid for HOL4 (for instance the Coq unit value is [tt]) *) F.pp_print_string fmt "()" - else ( - F.pp_open_hvbox fmt 0; - F.pp_open_hvbox fmt ctx.indent_incr; - (* Inner/outer brackets: there are several syntaxes for the field updates. - - For instance, in F*: - {[ - { x with f = ..., ... } - ]} - - In HOL4: - {[ - x with <| f = ..., ... |> - ]} - - In the above examples: - - in F*, the { } brackets are "outer" brackets - - in HOL4, the <| |> brackets are "inner" brackets + else + (* There are two cases: + - this is a regular struct + - this is an array *) - (* Outer brackets *) - let olb, orb = - match !backend with - | Lean | FStar -> (Some "{", Some "}") - | Coq -> (Some "{|", Some "|}") - | HOL4 -> (None, None) - in - (* Inner brackets *) - let ilb, irb = - match !backend with - | Lean | FStar | Coq -> (None, None) - | HOL4 -> (Some "<|", Some "|>") - in - (* Helper *) - let print_bracket (is_left : bool) b = - match b with - | Some b -> - if not is_left then F.pp_print_space fmt (); - F.pp_print_string fmt b; - if is_left then F.pp_print_space fmt () - | None -> () - in - print_bracket true olb; - let need_paren = inside && !backend = HOL4 in - if need_paren then F.pp_print_string fmt "("; - F.pp_open_hvbox fmt ctx.indent_incr; - if supd.init <> None then ( - let var_name = ctx_get_var (Option.get supd.init) ctx in - F.pp_print_string fmt var_name; - F.pp_print_space fmt (); - F.pp_print_string fmt "with"; - F.pp_print_space fmt ()); - print_bracket true ilb; - F.pp_open_hvbox fmt 0; - let delimiter = - match !backend with Lean -> "," | Coq | FStar | HOL4 -> ";" - in - let assign = - match !backend with Coq | Lean | HOL4 -> ":=" | FStar -> "=" - in - Collections.List.iter_link - (fun () -> - F.pp_print_string fmt delimiter; - F.pp_print_space fmt ()) - (fun (fid, fe) -> + match supd.struct_id with + | AdtId _ -> + F.pp_open_hvbox fmt 0; F.pp_open_hvbox fmt ctx.indent_incr; - let f = ctx_get_field (AdtId supd.struct_id) fid ctx in - F.pp_print_string fmt f; - F.pp_print_string fmt (" " ^ assign); - F.pp_print_space fmt (); + (* Inner/outer brackets: there are several syntaxes for the field updates. + + For instance, in F*: + {[ + { x with f = ..., ... } + ]} + + In HOL4: + {[ + x with <| f = ..., ... |> + ]} + + In the above examples: + - in F*, the { } brackets are "outer" brackets + - in HOL4, the <| |> brackets are "inner" brackets + *) + (* Outer brackets *) + let olb, orb = + match !backend with + | Lean | FStar -> (Some "{", Some "}") + | Coq -> (Some "{|", Some "|}") + | HOL4 -> (None, None) + in + (* Inner brackets *) + let ilb, irb = + match !backend with + | Lean | FStar | Coq -> (None, None) + | HOL4 -> (Some "<|", Some "|>") + in + (* Helper *) + let print_bracket (is_left : bool) b = + match b with + | Some b -> + if not is_left then F.pp_print_space fmt (); + F.pp_print_string fmt b; + if is_left then F.pp_print_space fmt () + | None -> () + in + print_bracket true olb; + let need_paren = inside && !backend = HOL4 in + if need_paren then F.pp_print_string fmt "("; F.pp_open_hvbox fmt ctx.indent_incr; - extract_texpression ctx fmt true fe; + if supd.init <> None then ( + let var_name = ctx_get_var (Option.get supd.init) ctx in + F.pp_print_string fmt var_name; + F.pp_print_space fmt (); + F.pp_print_string fmt "with"; + F.pp_print_space fmt ()); + print_bracket true ilb; + F.pp_open_hvbox fmt 0; + let delimiter = + match !backend with Lean -> "," | Coq | FStar | HOL4 -> ";" + in + let assign = + match !backend with Coq | Lean | HOL4 -> ":=" | FStar -> "=" + in + Collections.List.iter_link + (fun () -> + F.pp_print_string fmt delimiter; + F.pp_print_space fmt ()) + (fun (fid, fe) -> + F.pp_open_hvbox fmt ctx.indent_incr; + let f = ctx_get_field supd.struct_id fid ctx in + F.pp_print_string fmt f; + F.pp_print_string fmt (" " ^ assign); + F.pp_print_space fmt (); + F.pp_open_hvbox fmt ctx.indent_incr; + extract_texpression ctx fmt true fe; + F.pp_close_box fmt (); + F.pp_close_box fmt ()) + supd.updates; F.pp_close_box fmt (); - F.pp_close_box fmt ()) - supd.updates; - F.pp_close_box fmt (); - print_bracket false irb; - F.pp_close_box fmt (); - F.pp_close_box fmt (); - if need_paren then F.pp_print_string fmt ")"; - print_bracket false orb; - F.pp_close_box fmt ()) + print_bracket false irb; + F.pp_close_box fmt (); + F.pp_close_box fmt (); + if need_paren then F.pp_print_string fmt ")"; + print_bracket false orb; + F.pp_close_box fmt () + | Assumed Array -> + (* Open the boxes *) + F.pp_open_hvbox fmt ctx.indent_incr; + let need_paren = inside in + if need_paren then F.pp_print_string fmt "("; + (* Open the box for `Array.mk T N [` *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Print the array constructor *) + let cs = ctx_get_struct false (Assumed Array) ctx in + F.pp_print_string fmt cs; + (* Print the parameters *) + let _, tys, cgs = ty_as_adt e_ty in + let ty = Collections.List.to_cons_nil tys in + F.pp_print_space fmt (); + extract_ty ctx fmt TypeDeclId.Set.empty true ty; + let cg = Collections.List.to_cons_nil cgs in + F.pp_print_space fmt (); + extract_const_generic ctx fmt true cg; + F.pp_print_space fmt (); + F.pp_print_string fmt "["; + (* Close the box for `Array.mk T N [` *) + F.pp_close_box fmt (); + (* Print the values *) + let delimiter = + match !backend with Lean -> "," | Coq | FStar | HOL4 -> ";" + in + F.pp_print_space fmt (); + F.pp_open_hovbox fmt 0; + Collections.List.iter_link + (fun () -> + F.pp_print_string fmt delimiter; + F.pp_print_space fmt ()) + (fun (_, fe) -> extract_texpression ctx fmt false fe) + supd.updates; + (* Close the boxes *) + F.pp_close_box fmt (); + if supd.updates <> [] then F.pp_print_space fmt (); + F.pp_print_string fmt "]"; + if need_paren then F.pp_print_string fmt ")"; + F.pp_close_box fmt () + | _ -> raise (Failure "Unreachable") (** Insert a space, if necessary *) let insert_req_space (fmt : F.formatter) (space : bool ref) : unit = @@ -2729,33 +2893,52 @@ let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx) (fmt : F.formatter) (def : fun_decl) : extraction_ctx * extraction_ctx = (* Add the type parameters - note that we need those bindings only for the * body translation (they are not top-level) *) - let ctx, _ = ctx_add_type_params def.signature.type_params ctx in + let ctx, type_params, cg_params = + ctx_add_type_const_generic_params def.signature.type_params + def.signature.const_generic_params ctx + in (* Print the parameters - rem.: we should have filtered the functions * with no input parameters *) (* The type parameters. Note that in HOL4 we don't print the type parameters. *) - if def.signature.type_params <> [] && !backend <> HOL4 then ( - (* Open a box for the type parameters *) + if (type_params <> [] || cg_params <> []) && !backend <> HOL4 then ( + (* Open a box for the type and const generic parameters *) F.pp_open_hovbox fmt 0; - insert_req_space fmt space; - F.pp_print_string fmt "("; - List.iter - (fun (p : type_var) -> - let pname = ctx_get_type_var p.index ctx in - F.pp_print_string fmt pname; - F.pp_print_space fmt ()) - def.signature.type_params; - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - let type_keyword = - match !backend with - | FStar -> "Type0" - | Coq | Lean -> "Type" - | HOL4 -> raise (Failure "Unreachable") - in - F.pp_print_string fmt (type_keyword ^ ")"); + (* The type parameters *) + if type_params <> [] then ( + insert_req_space fmt space; + F.pp_print_string fmt "("; + List.iter + (fun (p : type_var) -> + let pname = ctx_get_type_var p.index ctx in + F.pp_print_string fmt pname; + F.pp_print_space fmt ()) + def.signature.type_params; + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + let type_keyword = + match !backend with + | FStar -> "Type0" + | Coq | Lean -> "Type" + | HOL4 -> raise (Failure "Unreachable") + in + F.pp_print_string fmt (type_keyword ^ ")")); + (* The const generic parameters *) + if cg_params <> [] then + List.iter + (fun (p : const_generic_var) -> + let pname = ctx_get_const_generic_var p.index ctx in + insert_req_space fmt space; + F.pp_print_string fmt "("; + F.pp_print_string fmt pname; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_literal_type ctx fmt p.ty; + F.pp_print_string fmt ")") + def.signature.const_generic_params; (* Close the box for the type parameters *) F.pp_close_box fmt ()); (* The input parameters - note that doing this adds bindings to the context *) @@ -3050,7 +3233,11 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) The boolean [is_opaque_coq] is used to detect this case. *) let is_opaque_coq = !backend = Coq && is_opaque in - let use_forall = is_opaque_coq && def.signature.type_params <> [] in + let use_forall = + is_opaque_coq + && (def.signature.type_params <> [] + || def.signature.const_generic_params <> []) + in (* Print the qualifier ("assume", etc.). if `wrap_opaque_in_sig`: we generate a record of assumed funcions. @@ -3123,13 +3310,20 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (* The name of the decrease clause *) let decr_name = ctx_get_termination_measure def.def_id def.loop_id ctx in F.pp_print_string fmt decr_name; - (* Print the type parameters *) + (* Print the type/const generic parameters - TODO: we do this many + times, we should have a helper to factor it out *) List.iter (fun (p : type_var) -> let pname = ctx_get_type_var p.index ctx in F.pp_print_space fmt (); F.pp_print_string fmt pname) def.signature.type_params; + List.iter + (fun (p : const_generic_var) -> + let pname = ctx_get_const_generic_var p.index ctx in + F.pp_print_space fmt (); + F.pp_print_string fmt pname) + def.signature.const_generic_params; (* Print the input values: we have to be careful here to print * only the input values which are in common with the *forward* * function (the additional input values "given back" to the @@ -3216,6 +3410,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (* Open the box for [DECREASES] *) F.pp_open_hovbox fmt ctx.indent_incr; F.pp_print_string fmt terminates_name; + (* Print the type/const generic params - TODO: factor out *) List.iter (fun (p : type_var) -> let pname = ctx_get_type_var p.index ctx in @@ -3223,6 +3418,13 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt pname) def.signature.type_params; List.iter + (fun (p : const_generic_var) -> + let pname = ctx_get_const_generic_var p.index ctx in + F.pp_print_space fmt (); + F.pp_print_string fmt pname) + def.signature.const_generic_params; + (* Print the variables *) + List.iter (fun v -> F.pp_print_space fmt (); F.pp_print_string fmt (ctx_get_var v ctx_body)) @@ -3278,9 +3480,13 @@ let extract_fun_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter) ctx_get_local_function with_opaque_pre def.def_id def.loop_id def.back_id ctx in - (* Add the type parameters - note that we need those bindings only for the - * generation of the type (they are not top-level) *) - let ctx, _ = ctx_add_type_params def.signature.type_params ctx in + assert (def.signature.const_generic_params = []); + (* Add the type/const gen parameters - note that we need those bindings + only for the generation of the type (they are not top-level) *) + let ctx, _, _ = + ctx_add_type_const_generic_params def.signature.type_params + def.signature.const_generic_params ctx + in (* Add breaks to insert new lines between definitions *) F.pp_print_break fmt 0 0; (* Open a box for the whole definition *) @@ -3459,6 +3665,7 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) assert (List.length body.signature.inputs = 0); assert (List.length body.signature.doutputs = 1); assert (List.length body.signature.type_params = 0); + assert (List.length body.signature.const_generic_params = 0); (* Add a break then the name of the corresponding LLBC declaration *) F.pp_print_break fmt 0 0; @@ -3529,6 +3736,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) let sg = def.signature in if sg.type_params = [] + && sg.const_generic_params = [] && (sg.inputs = [ mk_unit_ty ] || sg.inputs = []) && sg.output = mk_result_ty mk_unit_ty then ( diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 655bb033..d733c763 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -286,13 +286,15 @@ type formatter = { *) type_var_basename : StringSet.t -> string -> string; (** Generates a type variable basename. *) + const_generic_var_basename : StringSet.t -> string -> string; + (** Generates a const generic variable basename. *) append_index : string -> int -> string; (** Appends an index to a name - we use this to generate unique names: when doing so, the role of the formatter is just to concatenate indices to names, the responsability of finding a proper index is delegated to helper functions. *) - extract_primitive_value : F.formatter -> bool -> primitive_value -> unit; + extract_literal : F.formatter -> bool -> literal -> unit; (** Format a constant value. Inputs: @@ -392,6 +394,7 @@ type id = them here. *) | TypeVarId of TypeVarId.id + | ConstGenericVarId of ConstGenericVarId.id | VarId of VarId.id | UnknownId (** Used for stored various strings like keywords, definitions which @@ -674,7 +677,8 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = if variant_id = option_some_id then "@option::Some" else if variant_id = option_none_id then "@option::None" else raise (Failure "Unreachable") - | Assumed (State | Vec | Fuel) -> raise (Failure "Unreachable") + | Assumed (State | Vec | Fuel | Array | Slice | Str | Range) -> + raise (Failure "Unreachable") | AdtId id -> ( let def = TypeDeclId.Map.find id type_decls in match def.kind with @@ -688,10 +692,10 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = let field_name = match id with | Tuple -> raise (Failure "Unreachable") - | Assumed (State | Result | Error | Fuel | Option) -> - raise (Failure "Unreachable") - | Assumed Vec -> - (* We can't directly have access to the fields of a vector *) + | Assumed + ( State | Result | Error | Fuel | Option | Vec | Array | Slice | Str + | Range ) -> + (* We can't directly have access to the fields of those types *) raise (Failure "Unreachable") | AdtId id -> ( let def = TypeDeclId.Map.find id type_decls in @@ -709,6 +713,8 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = "field name: " ^ field_name | UnknownId -> "keyword" | TypeVarId id -> "type_var_id: " ^ TypeVarId.to_string id + | ConstGenericVarId id -> + "const_generic_var_id: " ^ ConstGenericVarId.to_string id | VarId id -> "var_id: " ^ VarId.to_string id (** We might not check for collisions for some specific ids (ex.: field names) *) @@ -787,6 +793,11 @@ let ctx_get_type_var (id : TypeVarId.id) (ctx : extraction_ctx) : string = let is_opaque = false in ctx_get is_opaque (TypeVarId id) ctx +let ctx_get_const_generic_var (id : ConstGenericVarId.id) (ctx : extraction_ctx) + : string = + let is_opaque = false in + ctx_get is_opaque (ConstGenericVarId id) ctx + let ctx_get_field (type_id : type_id) (field_id : FieldId.id) (ctx : extraction_ctx) : string = let is_opaque = false in @@ -822,6 +833,19 @@ let ctx_add_type_var (basename : string) (id : TypeVarId.id) let ctx = ctx_add is_opaque (TypeVarId id) name ctx in (ctx, name) +(** Generate a unique const generic variable name and add it to the context *) +let ctx_add_const_generic_var (basename : string) (id : ConstGenericVarId.id) + (ctx : extraction_ctx) : extraction_ctx * string = + let is_opaque = false in + let name = + ctx.fmt.const_generic_var_basename ctx.names_map.names_set basename + in + let name = + basename_to_unique ctx.names_map.names_set ctx.fmt.append_index name + in + let ctx = ctx_add is_opaque (ConstGenericVarId id) name ctx in + (ctx, name) + (** See {!ctx_add_type_var} *) let ctx_add_type_vars (vars : (string * TypeVarId.id) list) (ctx : extraction_ctx) : extraction_ctx * string list = @@ -854,6 +878,20 @@ let ctx_add_type_params (vars : type_var list) (ctx : extraction_ctx) : (fun ctx (var : type_var) -> ctx_add_type_var var.name var.index ctx) ctx vars +let ctx_add_const_generic_params (vars : const_generic_var list) + (ctx : extraction_ctx) : extraction_ctx * string list = + List.fold_left_map + (fun ctx (var : const_generic_var) -> + ctx_add_const_generic_var var.name var.index ctx) + ctx vars + +let ctx_add_type_const_generic_params (tvars : type_var list) + (cgvars : const_generic_var list) (ctx : extraction_ctx) : + extraction_ctx * string list * string list = + let ctx, tys = ctx_add_type_params tvars ctx in + let ctx, cgs = ctx_add_const_generic_params cgvars ctx in + (ctx, tys, cgs) + let ctx_add_type_decl_struct (def : type_decl) (ctx : extraction_ctx) : extraction_ctx * string = assert (match def.kind with Struct _ -> true | _ -> false); diff --git a/compiler/Interpreter.ml b/compiler/Interpreter.ml index ccb9009e..449463c8 100644 --- a/compiler/Interpreter.ml +++ b/compiler/Interpreter.ml @@ -29,8 +29,8 @@ let compute_type_fun_global_contexts (m : A.crate) : let initialize_eval_context (type_context : C.type_context) (fun_context : C.fun_context) (global_context : C.global_context) - (region_groups : T.RegionGroupId.id list) (type_vars : T.type_var list) : - C.eval_ctx = + (region_groups : T.RegionGroupId.id list) (type_vars : T.type_var list) + (const_generic_vars : T.const_generic_var list) : C.eval_ctx = C.reset_global_counters (); { C.type_context; @@ -38,6 +38,7 @@ let initialize_eval_context (type_context : C.type_context) C.global_context; C.region_groups; C.type_vars; + C.const_generic_vars; C.env = [ C.Frame ]; C.ended_regions = T.RegionId.Set.empty; } @@ -76,11 +77,18 @@ let initialize_symbolic_context_for_fun (type_context : C.type_context) in let ctx = initialize_eval_context type_context fun_context global_context - region_groups sg.type_params + region_groups sg.type_params sg.const_generic_params in (* Instantiate the signature *) - let type_params = List.map (fun tv -> T.TypeVar tv.T.index) sg.type_params in - let inst_sg = instantiate_fun_sig type_params sg in + let type_params = + List.map (fun (v : T.type_var) -> T.TypeVar v.T.index) sg.type_params + in + let cg_params = + List.map + (fun (v : T.const_generic_var) -> T.ConstGenericVar v.T.index) + sg.const_generic_params + in + let inst_sg = instantiate_fun_sig type_params cg_params sg in (* Create fresh symbolic values for the inputs *) let input_svs = List.map (fun ty -> mk_fresh_symbolic_value V.SynthInput ty) inst_sg.inputs @@ -149,14 +157,21 @@ let evaluate_function_symbolic_synthesize_backward_from_return ^ "\n- inside_loop: " ^ Print.bool_to_string inside_loop ^ "\n- ctx:\n" - ^ Print.Contexts.eval_ctx_to_string_gen true true ctx)); + ^ Print.Contexts.eval_ctx_to_string ctx)); (* We need to instantiate the function signature - to retrieve * the return type. Note that it is important to re-generate * an instantiation of the signature, so that we use fresh * region ids for the return abstractions. *) let sg = fdef.signature in - let type_params = List.map (fun tv -> T.TypeVar tv.T.index) sg.type_params in - let ret_inst_sg = instantiate_fun_sig type_params sg in + let type_params = + List.map (fun (v : T.type_var) -> T.TypeVar v.T.index) sg.type_params + in + let cg_params = + List.map + (fun (v : T.const_generic_var) -> T.ConstGenericVar v.T.index) + sg.const_generic_params + in + let ret_inst_sg = instantiate_fun_sig type_params cg_params sg in let ret_rty = ret_inst_sg.output in (* Move the return value out of the return variable *) let pop_return_value = is_regular_return in @@ -490,7 +505,7 @@ module Test = struct compute_type_fun_global_contexts crate in let ctx = - initialize_eval_context type_context fun_context global_context [] [] + initialize_eval_context type_context fun_context global_context [] [] [] in (* Insert the (uninitialized) local variables *) @@ -518,13 +533,11 @@ module Test = struct (** Small helper: return true if the function is a *transparent* unit function (no parameters, no arguments) - TODO: move *) let fun_decl_is_transparent_unit (def : A.fun_decl) : bool = - match def.body with - | None -> false - | Some body -> - body.arg_count = 0 - && List.length def.A.signature.region_params = 0 - && List.length def.A.signature.type_params = 0 - && List.length def.A.signature.inputs = 0 + Option.is_some def.body + && def.A.signature.region_params = [] + && def.A.signature.type_params = [] + && def.A.signature.const_generic_params = [] + && def.A.signature.inputs = [] (** Test all the unit functions in a list of function definitions *) let test_unit_functions (crate : A.crate) : unit = diff --git a/compiler/InterpreterBorrows.ml b/compiler/InterpreterBorrows.ml index 38c6df3d..4d67a4e4 100644 --- a/compiler/InterpreterBorrows.ml +++ b/compiler/InterpreterBorrows.ml @@ -452,8 +452,8 @@ let give_back_symbolic_value (_config : C.config) | V.SynthInputGivenBack | SynthRetGivenBack | FunCallGivenBack | LoopGivenBack -> () - | FunCallRet | SynthInput | Global | LoopOutput | LoopJoin -> - raise (Failure "Unrechable")); + | FunCallRet | SynthInput | Global | LoopOutput | LoopJoin | Aggregate -> + raise (Failure "Unreachable")); (* Store the given-back value as a meta-value for synthesis purposes *) let mv = nsv in (* Substitution function, to replace the borrow projectors over symbolic values *) @@ -1733,7 +1733,7 @@ let destructure_abs (abs_kind : V.abs_kind) (can_end : bool) and list_values (v : V.typed_value) : V.typed_avalue list * V.typed_value = let ty = v.V.ty in match v.V.value with - | Primitive _ -> ([], v) + | Literal _ -> ([], v) | Adt adt -> let avll, field_values = List.split (List.map list_values adt.field_values) @@ -1841,7 +1841,7 @@ let convert_value_to_abstractions (abs_kind : V.abs_kind) (can_end : bool) let ty = v.V.ty in match v.V.value with - | V.Primitive _ -> ([], v) + | V.Literal _ -> ([], v) | V.Bottom -> (* Can happen: we *do* convert dummy values to abstractions, and dummy values can contain bottoms *) diff --git a/compiler/InterpreterBorrowsCore.ml b/compiler/InterpreterBorrowsCore.ml index 55365043..bf083aa4 100644 --- a/compiler/InterpreterBorrowsCore.ml +++ b/compiler/InterpreterBorrowsCore.ml @@ -87,24 +87,28 @@ let add_borrow_or_abs_id_to_chain (msg : string) (id : borrow_or_abs_id) (** Helper function. - This function allows to define in a generic way a comparison of region types. + This function allows to define in a generic way a comparison of **region types**. See [projections_interesect] for instance. [default]: default boolean to return, when comparing types with no regions [combine]: how to combine booleans [compare_regions]: how to compare regions + + TODO: is there a way of deriving such a comparison? *) let rec compare_rtys (default : bool) (combine : bool -> bool -> bool) (compare_regions : T.RegionId.id T.region -> T.RegionId.id T.region -> bool) (ty1 : T.rty) (ty2 : T.rty) : bool = let compare = compare_rtys default combine compare_regions in match (ty1, ty2) with - | T.Bool, T.Bool | T.Char, T.Char | T.Str, T.Str -> default - | T.Integer int_ty1, T.Integer int_ty2 -> - assert (int_ty1 = int_ty2); + | T.Literal lit1, T.Literal lit2 -> + assert (lit1 = lit2); default - | T.Adt (id1, regions1, tys1), T.Adt (id2, regions2, tys2) -> + | T.Adt (id1, regions1, tys1, cgs1), T.Adt (id2, regions2, tys2, cgs2) -> assert (id1 = id2); + (* There are no regions in the const generics, so we ignore them, + but we still check they are the same, for sanity *) + assert (cgs1 = cgs2); (* The check for the ADTs is very crude: we simply compare the arguments * two by two. @@ -134,7 +138,6 @@ let rec compare_rtys (default : bool) (combine : bool -> bool -> bool) in (* Combine *) combine params_b tys_b - | T.Array ty1, T.Array ty2 | T.Slice ty1, T.Slice ty2 -> compare ty1 ty2 | T.Ref (r1, ty1, kind1), T.Ref (r2, ty2, kind2) -> (* Sanity check *) assert (kind1 = kind2); diff --git a/compiler/InterpreterExpansion.ml b/compiler/InterpreterExpansion.ml index 64a90217..81e73e3e 100644 --- a/compiler/InterpreterExpansion.ml +++ b/compiler/InterpreterExpansion.ml @@ -216,7 +216,8 @@ let apply_symbolic_expansion_non_borrow (config : C.config) let compute_expanded_symbolic_non_assumed_adt_value (expand_enumerations : bool) (kind : V.sv_kind) (def_id : T.TypeDeclId.id) (regions : T.RegionId.id T.region list) (types : T.rty list) - (ctx : C.eval_ctx) : V.symbolic_expansion list = + (cgs : T.const_generic list) (ctx : C.eval_ctx) : V.symbolic_expansion list + = (* Lookup the definition and check if it is an enumeration with several * variants *) let def = C.ctx_lookup_type_decl ctx def_id in @@ -224,6 +225,7 @@ let compute_expanded_symbolic_non_assumed_adt_value (expand_enumerations : bool) (* Retrieve, for every variant, the list of its instantiated field types *) let variants_fields_types = Subst.type_decl_get_instantiated_variants_fields_rtypes def regions types + cgs in (* Check if there is strictly more than one variant *) if List.length variants_fields_types > 1 && not expand_enumerations then @@ -280,11 +282,12 @@ let compute_expanded_symbolic_box_value (kind : V.sv_kind) (boxed_ty : T.rty) : let compute_expanded_symbolic_adt_value (expand_enumerations : bool) (kind : V.sv_kind) (adt_id : T.type_id) (regions : T.RegionId.id T.region list) (types : T.rty list) - (ctx : C.eval_ctx) : V.symbolic_expansion list = + (cgs : T.const_generic list) (ctx : C.eval_ctx) : V.symbolic_expansion list + = match (adt_id, regions, types) with | T.AdtId def_id, _, _ -> compute_expanded_symbolic_non_assumed_adt_value expand_enumerations kind - def_id regions types ctx + def_id regions types cgs ctx | T.Tuple, [], _ -> [ compute_expanded_symbolic_tuple_value kind types ] | T.Assumed T.Option, [], [ ty ] -> compute_expanded_symbolic_option_value expand_enumerations kind ty @@ -513,10 +516,10 @@ let expand_symbolic_bool (config : C.config) (sv : V.symbolic_value) let original_sv = sv in let original_sv_place = sv_place in let rty = original_sv.V.sv_ty in - assert (rty = T.Bool); + assert (rty = T.Literal PV.Bool); (* Expand the symbolic value to true or false and continue execution *) - let see_true = V.SePrimitive (PV.Bool true) in - let see_false = V.SePrimitive (PV.Bool false) in + let see_true = V.SeLiteral (PV.Bool true) in + let see_false = V.SeLiteral (PV.Bool false) in let seel = [ (Some see_true, cf_true); (Some see_false, cf_false) ] in (* Apply the symbolic expansion (this also outputs the updated symbolic AST) *) apply_branching_symbolic_expansions_non_borrow config original_sv @@ -540,12 +543,12 @@ let expand_symbolic_value_no_branching (config : C.config) fun cf ctx -> match rty with (* ADTs *) - | T.Adt (adt_id, regions, types) -> + | T.Adt (adt_id, regions, types, cgs) -> (* Compute the expanded value *) let allow_branching = false in let seel = compute_expanded_symbolic_adt_value allow_branching sv.sv_kind adt_id - regions types ctx + regions types cgs ctx in (* There should be exacly one branch *) let see = Collections.List.to_cons_nil seel in @@ -597,12 +600,12 @@ let expand_symbolic_adt (config : C.config) (sv : V.symbolic_value) (* Execute *) match rty with (* ADTs *) - | T.Adt (adt_id, regions, types) -> + | T.Adt (adt_id, regions, types, cgs) -> let allow_branching = true in (* Compute the expanded value *) let seel = compute_expanded_symbolic_adt_value allow_branching sv.sv_kind adt_id - regions types ctx + regions types cgs ctx in (* Apply *) let seel = List.map (fun see -> (Some see, cf_branches)) seel in @@ -617,7 +620,7 @@ let expand_symbolic_int (config : C.config) (sv : V.symbolic_value) (tgts : (V.scalar_value * st_cm_fun) list) (otherwise : st_cm_fun) (cf_after_join : st_m_fun) : m_fun = (* Sanity check *) - assert (sv.V.sv_ty = T.Integer int_type); + assert (sv.V.sv_ty = T.Literal (PV.Integer int_type)); (* For all the branches of the switch, we expand the symbolic value * to the value given by the branch and execute the branch statement. * For the otherwise branch, we leave the symbolic value as it is @@ -628,7 +631,7 @@ let expand_symbolic_int (config : C.config) (sv : V.symbolic_value) * (optional expansion, statement to execute) *) let seel = - List.map (fun (v, cf) -> (Some (V.SePrimitive (PV.Scalar v)), cf)) tgts + List.map (fun (v, cf) -> (Some (V.SeLiteral (PV.Scalar v)), cf)) tgts in let seel = List.append seel [ (None, otherwise) ] in (* Then expand and evaluate - this generates the proper symbolic AST *) @@ -676,7 +679,7 @@ let greedy_expand_symbolics_with_borrows (config : C.config) : cm_fun = ^ symbolic_value_to_string ctx sv)); let cc : cm_fun = match sv.V.sv_ty with - | T.Adt (AdtId def_id, _, _) -> + | T.Adt (AdtId def_id, _, _, _) -> (* {!expand_symbolic_value_no_branching} checks if there are branchings, * but we prefer to also check it here - this leads to cleaner messages * and debugging *) @@ -701,16 +704,16 @@ let greedy_expand_symbolics_with_borrows (config : C.config) : cm_fun = [config]): " ^ Print.name_to_string def.name)) else expand_symbolic_value_no_branching config sv None - | T.Adt ((Tuple | Assumed Box), _, _) | T.Ref (_, _, _) -> + | T.Adt ((Tuple | Assumed Box), _, _, _) | T.Ref (_, _, _) -> (* Ok *) expand_symbolic_value_no_branching config sv None - | T.Adt (Assumed (Vec | Option), _, _) -> + | T.Adt (Assumed (Vec | Option | Array | Slice | Str | Range), _, _, _) + -> (* We can't expand those *) - raise (Failure "Attempted to greedily expand a Vec or an Option ") - | T.Array _ -> raise Utils.Unimplemented - | T.Slice _ -> raise (Failure "Can't expand symbolic slices") - | T.TypeVar _ | Bool | Char | Never | Integer _ | Str -> - raise (Failure "Unreachable") + raise + (Failure + "Attempted to greedily expand an ADT which can't be expanded ") + | T.TypeVar _ | T.Literal _ | Never -> raise (Failure "Unreachable") in (* Compose and continue *) comp cc expand cf ctx diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml index d75f5a26..8b2070c6 100644 --- a/compiler/InterpreterExpressions.ml +++ b/compiler/InterpreterExpressions.ml @@ -94,24 +94,23 @@ let access_rplace_reorganize (config : C.config) (expand_prim_copy : bool) ctx (** Convert an operand constant operand value to a typed value *) -let primitive_to_typed_value (ty : T.ety) (cv : V.primitive_value) : +let literal_to_typed_value (ty : PV.literal_type) (cv : V.literal) : V.typed_value = (* Check the type while converting - we actually need some information * contained in the type *) log#ldebug (lazy - ("primitive_to_typed_value:" ^ "\n- cv: " - ^ Print.PrimitiveValues.primitive_value_to_string cv)); + ("literal_to_typed_value:" ^ "\n- cv: " + ^ Print.PrimitiveValues.literal_to_string cv)); match (ty, cv) with (* Scalar, boolean... *) - | T.Bool, Bool v -> { V.value = V.Primitive (Bool v); ty } - | T.Char, Char v -> { V.value = V.Primitive (Char v); ty } - | T.Str, String v -> { V.value = V.Primitive (String v); ty } - | T.Integer int_ty, PV.Scalar v -> + | PV.Bool, Bool v -> { V.value = V.Literal (Bool v); ty = T.Literal ty } + | Char, Char v -> { V.value = V.Literal (Char v); ty = T.Literal ty } + | Integer int_ty, PV.Scalar v -> (* Check the type and the ranges *) assert (int_ty = v.int_ty); assert (check_scalar_value_in_range v); - { V.value = V.Primitive (PV.Scalar v); ty } + { V.value = V.Literal (PV.Scalar v); ty = T.Literal ty } (* Remaining cases (invalid) *) | _, _ -> raise (Failure "Improperly typed constant value") @@ -138,14 +137,16 @@ let rec copy_value (allow_adt_copy : bool) (config : C.config) * the fact that we have exhaustive matches below makes very obvious the cases * in which we need to fail *) match v.V.value with - | V.Primitive _ -> (ctx, v) + | V.Literal _ -> (ctx, v) | V.Adt av -> (* Sanity check *) (match v.V.ty with - | T.Adt (T.Assumed (T.Box | Vec), _, _) -> + | T.Adt (T.Assumed (T.Box | Vec), _, _, _) -> raise (Failure "Can't copy an assumed value other than Option") - | T.Adt (T.AdtId _, _, _) -> assert allow_adt_copy - | T.Adt ((T.Assumed Option | T.Tuple), _, _) -> () (* Ok *) + | T.Adt (T.AdtId _, _, _, _) -> assert allow_adt_copy + | T.Adt ((T.Assumed Option | T.Tuple), _, _, _) -> () (* Ok *) + | T.Adt (T.Assumed (Slice | T.Array), [], [ ty ], []) -> + assert (ty_is_primitively_copyable ty) | _ -> raise (Failure "Unreachable")); let ctx, fields = List.fold_left_map @@ -231,7 +232,7 @@ let prepare_eval_operand_reorganize (config : C.config) (op : E.operand) : match op with | Expressions.Constant (ty, cv) -> (* No need to reorganize the context *) - primitive_to_typed_value ty cv |> ignore; + literal_to_typed_value (TypesUtils.ty_as_literal ty) cv |> ignore; cf ctx | Expressions.Copy p -> (* Access the value *) @@ -259,7 +260,8 @@ let eval_operand_no_reorganize (config : C.config) (op : E.operand) ^ "\n- ctx:\n" ^ eval_ctx_to_string ctx ^ "\n")); (* Evaluate *) match op with - | Expressions.Constant (ty, cv) -> cf (primitive_to_typed_value ty cv) ctx + | Expressions.Constant (ty, cv) -> + cf (literal_to_typed_value (TypesUtils.ty_as_literal ty) cv) ctx | Expressions.Copy p -> (* Access the value *) let access = Read in @@ -349,21 +351,21 @@ let eval_unary_op_concrete (config : C.config) (unop : E.unop) (op : E.operand) (* Apply the unop *) let apply cf (v : V.typed_value) : m_fun = match (unop, v.V.value) with - | E.Not, V.Primitive (Bool b) -> - cf (Ok { v with V.value = V.Primitive (Bool (not b)) }) - | E.Neg, V.Primitive (PV.Scalar sv) -> ( + | E.Not, V.Literal (Bool b) -> + cf (Ok { v with V.value = V.Literal (Bool (not b)) }) + | E.Neg, V.Literal (PV.Scalar sv) -> ( let i = Z.neg sv.PV.value in match mk_scalar sv.int_ty i with | Error _ -> cf (Error EPanic) - | Ok sv -> cf (Ok { v with V.value = V.Primitive (PV.Scalar sv) })) - | E.Cast (src_ty, tgt_ty), V.Primitive (PV.Scalar sv) -> ( + | Ok sv -> cf (Ok { v with V.value = V.Literal (PV.Scalar sv) })) + | E.Cast (src_ty, tgt_ty), V.Literal (PV.Scalar sv) -> ( assert (src_ty = sv.int_ty); let i = sv.PV.value in match mk_scalar tgt_ty i with | Error _ -> cf (Error EPanic) | Ok sv -> - let ty = T.Integer tgt_ty in - let value = V.Primitive (PV.Scalar sv) in + let ty = T.Literal (Integer tgt_ty) in + let value = V.Literal (PV.Scalar sv) in cf (Ok { V.ty; value })) | _ -> raise (Failure "Invalid input for unop") in @@ -380,9 +382,9 @@ let eval_unary_op_symbolic (config : C.config) (unop : E.unop) (op : E.operand) let res_sv_id = C.fresh_symbolic_value_id () in let res_sv_ty = match (unop, v.V.ty) with - | E.Not, T.Bool -> T.Bool - | E.Neg, T.Integer int_ty -> T.Integer int_ty - | E.Cast (_, tgt_ty), _ -> T.Integer tgt_ty + | E.Not, (T.Literal Bool as lty) -> lty + | E.Neg, (T.Literal (Integer _) as lty) -> lty + | E.Cast (_, tgt_ty), _ -> T.Literal (Integer tgt_ty) | _ -> raise (Failure "Invalid input for unop") in let res_sv = @@ -417,11 +419,11 @@ let eval_binary_op_concrete_compute (binop : E.binop) (v1 : V.typed_value) (* Equality/inequality check is primitive only for a subset of types *) assert (ty_is_primitively_copyable v1.ty); let b = v1 = v2 in - Ok { V.value = V.Primitive (Bool b); ty = T.Bool }) + Ok { V.value = V.Literal (Bool b); ty = T.Literal Bool }) else (* For the non-equality operations, the input values are necessarily scalars *) match (v1.V.value, v2.V.value) with - | V.Primitive (PV.Scalar sv1), V.Primitive (PV.Scalar sv2) -> ( + | V.Literal (PV.Scalar sv1), V.Literal (PV.Scalar sv2) -> ( (* There are binops which require the two operands to have the same type, and binops for which it is not the case. There are also binops which return booleans, and binops which @@ -441,7 +443,9 @@ let eval_binary_op_concrete_compute (binop : E.binop) (v1 : V.typed_value) | E.BitOr | E.Shl | E.Shr | E.Ne | E.Eq -> raise (Failure "Unreachable") in - Ok ({ V.value = V.Primitive (Bool b); ty = T.Bool } : V.typed_value) + Ok + ({ V.value = V.Literal (Bool b); ty = T.Literal Bool } + : V.typed_value) | E.Div | E.Rem | E.Add | E.Sub | E.Mul | E.BitXor | E.BitAnd | E.BitOr -> ( (* The two operands must have the same type and the result is an integer *) @@ -469,8 +473,8 @@ let eval_binary_op_concrete_compute (binop : E.binop) (v1 : V.typed_value) | Ok sv -> Ok { - V.value = V.Primitive (PV.Scalar sv); - ty = Integer sv1.int_ty; + V.value = V.Literal (PV.Scalar sv); + ty = T.Literal (Integer sv1.int_ty); }) | E.Shl | E.Shr -> raise Unimplemented | E.Ne | E.Eq -> raise (Failure "Unreachable")) @@ -506,19 +510,19 @@ let eval_binary_op_symbolic (config : C.config) (binop : E.binop) assert (v1.ty = v2.ty); (* Equality/inequality check is primitive only for a subset of types *) assert (ty_is_primitively_copyable v1.ty); - T.Bool) + T.Literal Bool) else (* Other operations: input types are integers *) match (v1.V.ty, v2.V.ty) with - | T.Integer int_ty1, T.Integer int_ty2 -> ( + | T.Literal (Integer int_ty1), T.Literal (Integer int_ty2) -> ( match binop with | E.Lt | E.Le | E.Ge | E.Gt -> assert (int_ty1 = int_ty2); - T.Bool + T.Literal Bool | E.Div | E.Rem | E.Add | E.Sub | E.Mul | E.BitXor | E.BitAnd | E.BitOr -> assert (int_ty1 = int_ty2); - T.Integer int_ty1 + T.Literal (Integer int_ty1) | E.Shl | E.Shr -> raise Unimplemented | E.Ne | E.Eq -> raise (Failure "Unreachable")) | _ -> raise (Failure "Invalid inputs for binop") @@ -652,7 +656,7 @@ let eval_rvalue_aggregate (config : C.config) | E.AggregatedTuple -> let tys = List.map (fun (v : V.typed_value) -> v.V.ty) values in let v = V.Adt { variant_id = None; field_values = values } in - let ty = T.Adt (T.Tuple, [], tys) in + let ty = T.Adt (T.Tuple, [], tys, []) in let aggregated : V.typed_value = { V.value = v; ty } in (* Call the continuation *) cf aggregated ctx @@ -663,20 +667,20 @@ let eval_rvalue_aggregate (config : C.config) assert (List.length values = 1) else raise (Failure "Unreachable"); (* Construt the value *) - let aty = T.Adt (T.Assumed T.Option, [], [ ty ]) in + let aty = T.Adt (T.Assumed T.Option, [], [ ty ], []) in let av : V.adt_value = { V.variant_id = Some variant_id; V.field_values = values } in let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in (* Call the continuation *) cf aggregated ctx - | E.AggregatedAdt (def_id, opt_variant_id, regions, types) -> + | E.AggregatedAdt (def_id, opt_variant_id, regions, types, cgs) -> (* Sanity checks *) let type_decl = C.ctx_lookup_type_decl ctx def_id in assert (List.length type_decl.region_params = List.length regions); let expected_field_types = Subst.ctx_adt_get_instantiated_field_etypes ctx def_id opt_variant_id - types + types cgs in assert ( expected_field_types @@ -685,10 +689,52 @@ let eval_rvalue_aggregate (config : C.config) let av : V.adt_value = { V.variant_id = opt_variant_id; V.field_values = values } in - let aty = T.Adt (T.AdtId def_id, regions, types) in + let aty = T.Adt (T.AdtId def_id, regions, types, cgs) in let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in (* Call the continuation *) cf aggregated ctx + | E.AggregatedRange ety -> + (* There should be two fields exactly *) + let v0, v1 = + match values with + | [ v0; v1 ] -> (v0, v1) + | _ -> raise (Failure "Unreachable") + in + (* Ranges are parametric over the type of indices. For now we only + support scalars, which can be of any type *) + assert (literal_type_is_integer (ty_as_literal ety)); + assert (v0.ty = ety); + assert (v1.ty = ety); + (* Construct the value *) + let av : V.adt_value = + { V.variant_id = None; V.field_values = values } + in + let aty = T.Adt (T.Assumed T.Range, [], [ ety ], []) in + let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in + (* Call the continuation *) + cf aggregated ctx + | E.AggregatedArray (ety, cg) -> ( + (* Sanity check: all the values have the proper type *) + assert (List.for_all (fun (v : V.typed_value) -> v.V.ty = ety) values); + (* Sanity check: the number of values is consistent with the length *) + let len = (literal_as_scalar (const_generic_as_literal cg)).value in + assert (len = Z.of_int (List.length values)); + let ty = T.Adt (T.Assumed T.Array, [], [ ety ], [ cg ]) in + (* In order to generate a better AST, we introduce a symbolic + value equal to the array. The reason is that otherwise, the + array we introduce here might be duplicated in the generated + code: by introducing a symbolic value we introduce a let-binding + in the generated code. *) + let saggregated = + mk_fresh_symbolic_typed_value_from_ety V.Aggregate ty + in + (* Call the continuation *) + match cf saggregated ctx with + | None -> None + | Some e -> + (* Introduce the symbolic value in the AST *) + let sv = ValuesUtils.value_as_symbolic saggregated.value in + Some (SymbolicAst.IntroSymbolic (ctx, None, sv, Array values, e))) in (* Compose and apply *) comp eval_ops compute cf diff --git a/compiler/InterpreterLoopsCore.ml b/compiler/InterpreterLoopsCore.ml index 209fce1c..6e33c75b 100644 --- a/compiler/InterpreterLoopsCore.ml +++ b/compiler/InterpreterLoopsCore.ml @@ -60,8 +60,7 @@ module type PrimMatcher = sig val match_rtys : T.rty -> T.rty -> T.rty (** The input primitive values are not equal *) - val match_distinct_primitive_values : - T.ety -> V.primitive_value -> V.primitive_value -> V.typed_value + val match_distinct_literals : T.ety -> V.literal -> V.literal -> V.typed_value (** The input ADTs don't have the same variant *) val match_distinct_adts : T.ety -> V.adt_value -> V.adt_value -> V.typed_value diff --git a/compiler/InterpreterLoopsFixedPoint.ml b/compiler/InterpreterLoopsFixedPoint.ml index aff8f3fe..4310f017 100644 --- a/compiler/InterpreterLoopsFixedPoint.ml +++ b/compiler/InterpreterLoopsFixedPoint.ml @@ -109,6 +109,7 @@ let prepare_ashared_loans (loop_id : V.LoopId.id option) : cm_fun = (fun r -> if T.RegionId.Set.mem r rids then nrid else r) (fun x -> x) (fun x -> x) + (fun x -> x) (fun id -> let nid = C.fresh_symbolic_value_id () in let sv = V.SymbolicValueId.Map.find id absl_id_maps.sids_to_values in @@ -321,7 +322,7 @@ let prepare_ashared_loans (loop_id : V.LoopId.id option) : cm_fun = let sv = V.SymbolicValueId.Map.find sid new_ctx_ids_map.sids_to_values in - SymbolicAst.IntroSymbolic (ctx, None, sv, v, e)) + SymbolicAst.IntroSymbolic (ctx, None, sv, SingleValue v, e)) e !sid_subst) let prepare_ashared_loans_no_synth (loop_id : V.LoopId.id) (ctx : C.eval_ctx) : diff --git a/compiler/InterpreterLoopsJoinCtxs.ml b/compiler/InterpreterLoopsJoinCtxs.ml index 6fb0449d..bf88e055 100644 --- a/compiler/InterpreterLoopsJoinCtxs.ml +++ b/compiler/InterpreterLoopsJoinCtxs.ml @@ -556,6 +556,7 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx) global_context; region_groups; type_vars; + const_generic_vars; env = _; ended_regions = ended_regions0; } = @@ -567,6 +568,7 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx) global_context = _; region_groups = _; type_vars = _; + const_generic_vars = _; env = _; ended_regions = ended_regions1; } = @@ -580,6 +582,7 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx) global_context; region_groups; type_vars; + const_generic_vars; env; ended_regions; } @@ -635,6 +638,7 @@ let refresh_abs (old_abs : V.AbstractionId.Set.t) (ctx : C.eval_ctx) : (fun x -> x) (fun x -> x) (fun x -> x) + (fun x -> x) subst ctx.env in { ctx with C.env } diff --git a/compiler/InterpreterLoopsMatchCtxs.ml b/compiler/InterpreterLoopsMatchCtxs.ml index 80cd93cf..9248e513 100644 --- a/compiler/InterpreterLoopsMatchCtxs.ml +++ b/compiler/InterpreterLoopsMatchCtxs.ml @@ -44,11 +44,11 @@ let compute_abs_borrows_loans_maps (no_duplicates : bool) (id0 : Id0.id) (id1 : Id1.id) : unit = (* Sanity check *) (if check_singleton_sets || check_not_already_registered then - match Id0.Map.find_opt id0 !map with - | None -> () - | Some set -> - assert ( - (not check_not_already_registered) || not (Id1.Set.mem id1 set))); + match Id0.Map.find_opt id0 !map with + | None -> () + | Some set -> + assert ( + (not check_not_already_registered) || not (Id1.Set.mem id1 set))); (* Update the mapping *) map := Id0.Map.update id0 @@ -149,9 +149,11 @@ let rec match_types (match_distinct_types : 'r T.ty -> 'r T.ty -> 'r T.ty) (match_regions : 'r -> 'r -> 'r) (ty0 : 'r T.ty) (ty1 : 'r T.ty) : 'r T.ty = let match_rec = match_types match_distinct_types match_regions in match (ty0, ty1) with - | Adt (id0, regions0, tys0), Adt (id1, regions1, tys1) -> + | Adt (id0, regions0, tys0, cgs0), Adt (id1, regions1, tys1, cgs1) -> assert (id0 = id1); + assert (cgs0 = cgs1); let id = id0 in + let cgs = cgs1 in let regions = List.map (fun (id0, id1) -> match_regions id0 id1) @@ -160,16 +162,15 @@ let rec match_types (match_distinct_types : 'r T.ty -> 'r T.ty -> 'r T.ty) let tys = List.map (fun (ty0, ty1) -> match_rec ty0 ty1) (List.combine tys0 tys1) in - Adt (id, regions, tys) + Adt (id, regions, tys, cgs) | TypeVar vid0, TypeVar vid1 -> assert (vid0 = vid1); let vid = vid0 in TypeVar vid - | Bool, Bool | Char, Char | Never, Never | Str, Str -> ty0 - | Integer int_ty0, Integer int_ty1 -> - assert (int_ty0 = int_ty1); + | Literal lty0, Literal lty1 -> + assert (lty0 = lty1); ty0 - | Array ty0, Array ty1 | Slice ty0, Slice ty1 -> match_rec ty0 ty1 + | Never, Never -> ty0 | Ref (r0, ty0, k0), Ref (r1, ty1, k1) -> let r = match_regions r0 r1 in let ty = match_rec ty0 ty1 in @@ -184,8 +185,8 @@ module MakeMatcher (M : PrimMatcher) : Matcher = struct let match_rec = match_typed_values ctx in let ty = M.match_etys v0.V.ty v1.V.ty in match (v0.V.value, v1.V.value) with - | V.Primitive pv0, V.Primitive pv1 -> - if pv0 = pv1 then v1 else M.match_distinct_primitive_values ty pv0 pv1 + | V.Literal lv0, V.Literal lv1 -> + if lv0 = lv1 then v1 else M.match_distinct_literals ty lv0 lv1 | V.Adt av0, V.Adt av1 -> if av0.variant_id = av1.variant_id then let fields = List.combine av0.field_values av1.field_values in @@ -385,8 +386,8 @@ module MakeJoinMatcher (S : MatchJoinState) : PrimMatcher = struct assert (ty0 = ty1); ty0 - let match_distinct_primitive_values (ty : T.ety) (_ : V.primitive_value) - (_ : V.primitive_value) : V.typed_value = + let match_distinct_literals (ty : T.ety) (_ : V.literal) (_ : V.literal) : + V.typed_value = mk_fresh_symbolic_typed_value_from_ety V.LoopJoin ty let match_distinct_adts (ty : T.ety) (adt0 : V.adt_value) (adt1 : V.adt_value) @@ -834,8 +835,8 @@ struct in match_types match_distinct_types match_regions ty0 ty1 - let match_distinct_primitive_values (ty : T.ety) (_ : V.primitive_value) - (_ : V.primitive_value) : V.typed_value = + let match_distinct_literals (ty : T.ety) (_ : V.literal) (_ : V.literal) : + V.typed_value = mk_fresh_symbolic_typed_value_from_ety V.LoopJoin ty let match_distinct_adts (_ty : T.ety) (_adt0 : V.adt_value) @@ -1616,7 +1617,7 @@ let match_ctx_with_target (config : C.config) (loop_id : V.LoopId.id) cc (cf (if is_loop_entry then EndEnterLoop (loop_id, input_values) - else EndContinue (loop_id, input_values))) + else EndContinue (loop_id, input_values))) tgt_ctx in diff --git a/compiler/InterpreterLoopsMatchCtxs.mli b/compiler/InterpreterLoopsMatchCtxs.mli index d0f57f32..20b997ce 100644 --- a/compiler/InterpreterLoopsMatchCtxs.mli +++ b/compiler/InterpreterLoopsMatchCtxs.mli @@ -34,13 +34,13 @@ val compute_abs_borrows_loans_maps : We use it for joins, to check if two environments are convertible, etc. See for instance {!MakeJoinMatcher} and {!MakeCheckEquivMatcher}. - The functor is parameterized by a {!InterpreterLoopsCore.PrimMatcher}, which implements the - non-generic part of the match. More precisely, the role of {!InterpreterLoopsCore.PrimMatcher} is two + The functor is parameterized by a {!PrimMatcher}, which implements the + non-generic part of the match. More precisely, the role of {!PrimMatcher} is two provide generic functions which recursively match two values (by recursively matching the fields of ADT values for instance). When it does need to match values in a non-trivial manner (if two ADT values don't have the same variant for instance) it calls the corresponding specialized function from - {!InterpreterLoopsCore.PrimMatcher}. + {!PrimMatcher}. *) module MakeMatcher : functor (_ : PrimMatcher) -> Matcher diff --git a/compiler/InterpreterPaths.ml b/compiler/InterpreterPaths.ml index 619815b3..04dc8892 100644 --- a/compiler/InterpreterPaths.ml +++ b/compiler/InterpreterPaths.ml @@ -97,7 +97,7 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx) match (pe, v.V.value, v.V.ty) with | ( Field (((ProjAdt (_, _) | ProjOption _) as proj_kind), field_id), V.Adt adt, - T.Adt (type_id, _, _) ) -> ( + T.Adt (type_id, _, _, _) ) -> ( (* Check consistency *) (match (proj_kind, type_id) with | ProjAdt (def_id, opt_variant_id), T.AdtId def_id' -> @@ -119,7 +119,8 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx) let updated = { v with value = nadt } in Ok (ctx, { res with updated })) (* Tuples *) - | Field (ProjTuple arity, field_id), V.Adt adt, T.Adt (T.Tuple, _, _) -> ( + | Field (ProjTuple arity, field_id), V.Adt adt, T.Adt (T.Tuple, _, _, _) + -> ( assert (arity = List.length adt.field_values); let fv = T.FieldId.nth adt.field_values field_id in (* Project *) @@ -144,7 +145,7 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx) (* Box dereferencement *) | ( DerefBox, Adt { variant_id = None; field_values = [ bv ] }, - T.Adt (T.Assumed T.Box, _, _) ) -> ( + T.Adt (T.Assumed T.Box, _, _, _) ) -> ( (* We allow moving inside of boxes. In practice, this kind of * manipulations should happen only inside unsage code, so * it shouldn't happen due to user code, and we leverage it @@ -249,7 +250,7 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx) in Ok (ctx, { res with updated = nv }) else Error (FailSharedLoan bids)) - | (_, (V.Primitive _ | V.Adt _ | V.Bottom | V.Borrow _), _) as r -> + | (_, (V.Literal _ | V.Adt _ | V.Bottom | V.Borrow _), _) as r -> let pe, v, ty = r in let pe = "- pe: " ^ E.show_projection_elem pe in let v = "- v:\n" ^ V.show_value v in @@ -358,7 +359,8 @@ let write_place (access : access_kind) (p : E.place) (nv : V.typed_value) let compute_expanded_bottom_adt_value (tyctx : T.type_decl T.TypeDeclId.Map.t) (def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option) - (regions : T.erased_region list) (types : T.ety list) : V.typed_value = + (regions : T.erased_region list) (types : T.ety list) + (cgs : T.const_generic list) : V.typed_value = (* Lookup the definition and check if it is an enumeration - it should be an enumeration if and only if the projection element is a field projection with *some* variant id. Retrieve the list @@ -367,12 +369,12 @@ let compute_expanded_bottom_adt_value (tyctx : T.type_decl T.TypeDeclId.Map.t) assert (List.length regions = List.length def.T.region_params); (* Compute the field types *) let field_types = - Subst.type_decl_get_instantiated_field_etypes def opt_variant_id types + Subst.type_decl_get_instantiated_field_etypes def opt_variant_id types cgs in (* Initialize the expanded value *) let fields = List.map mk_bottom field_types in let av = V.Adt { variant_id = opt_variant_id; field_values = fields } in - let ty = T.Adt (T.AdtId def_id, regions, types) in + let ty = T.Adt (T.AdtId def_id, regions, types, cgs) in { V.value = av; V.ty } let compute_expanded_bottom_option_value (variant_id : T.VariantId.id) @@ -385,7 +387,7 @@ let compute_expanded_bottom_option_value (variant_id : T.VariantId.id) else raise (Failure "Unreachable") in let av = V.Adt { variant_id = Some variant_id; field_values } in - let ty = T.Adt (T.Assumed T.Option, [], [ param_ty ]) in + let ty = T.Adt (T.Assumed T.Option, [], [ param_ty ], []) in { V.value = av; ty } let compute_expanded_bottom_tuple_value (field_types : T.ety list) : @@ -393,7 +395,7 @@ let compute_expanded_bottom_tuple_value (field_types : T.ety list) : (* Generate the field values *) let fields = List.map mk_bottom field_types in let v = V.Adt { variant_id = None; field_values = fields } in - let ty = T.Adt (T.Tuple, [], field_types) in + let ty = T.Adt (T.Tuple, [], field_types, []) in { V.value = v; V.ty } (** Auxiliary helper to expand {!V.Bottom} values. @@ -445,16 +447,16 @@ let expand_bottom_value_from_projection (access : access_kind) (p : E.place) match (pe, ty) with (* "Regular" ADTs *) | ( Field (ProjAdt (def_id, opt_variant_id), _), - T.Adt (T.AdtId def_id', regions, types) ) -> + T.Adt (T.AdtId def_id', regions, types, cgs) ) -> assert (def_id = def_id'); compute_expanded_bottom_adt_value ctx.type_context.type_decls def_id - opt_variant_id regions types + opt_variant_id regions types cgs (* Option *) - | Field (ProjOption variant_id, _), T.Adt (T.Assumed T.Option, [], [ ty ]) - -> + | ( Field (ProjOption variant_id, _), + T.Adt (T.Assumed T.Option, [], [ ty ], []) ) -> compute_expanded_bottom_option_value variant_id ty (* Tuples *) - | Field (ProjTuple arity, _), T.Adt (T.Tuple, [], tys) -> + | Field (ProjTuple arity, _), T.Adt (T.Tuple, [], tys, []) -> assert (arity = List.length tys); (* Generate the field values *) compute_expanded_bottom_tuple_value tys diff --git a/compiler/InterpreterPaths.mli b/compiler/InterpreterPaths.mli index 6e9286cd..4a9f3b41 100644 --- a/compiler/InterpreterPaths.mli +++ b/compiler/InterpreterPaths.mli @@ -61,6 +61,7 @@ val compute_expanded_bottom_adt_value : T.VariantId.id option -> T.erased_region list -> T.ety list -> + T.const_generic list -> V.typed_value (** Compute an expanded [Option] ⊥ value *) diff --git a/compiler/InterpreterProjectors.ml b/compiler/InterpreterProjectors.ml index 9487df84..faed066b 100644 --- a/compiler/InterpreterProjectors.ml +++ b/compiler/InterpreterProjectors.ml @@ -23,12 +23,12 @@ let rec apply_proj_borrows_on_shared_borrow (ctx : C.eval_ctx) if not (ty_has_regions_in_set regions ty) then [] else match (v.V.value, ty) with - | V.Primitive _, (T.Bool | T.Char | T.Integer _ | T.Str) -> [] - | V.Adt adt, T.Adt (id, region_params, tys) -> + | V.Literal _, T.Literal _ -> [] + | V.Adt adt, T.Adt (id, region_params, tys, cgs) -> (* Retrieve the types of the fields *) let field_types = Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id - region_params tys + region_params tys cgs in (* Project over the field values *) let fields_types = List.combine adt.V.field_values field_types in @@ -102,12 +102,12 @@ let rec apply_proj_borrows (check_symbolic_no_ended : bool) (ctx : C.eval_ctx) else let value : V.avalue = match (v.V.value, ty) with - | V.Primitive _, (T.Bool | T.Char | T.Integer _ | T.Str) -> V.AIgnored - | V.Adt adt, T.Adt (id, region_params, tys) -> + | V.Literal _, T.Literal _ -> V.AIgnored + | V.Adt adt, T.Adt (id, region_params, tys, cgs) -> (* Retrieve the types of the fields *) let field_types = Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id - region_params tys + region_params tys cgs in (* Project over the field values *) let fields_types = List.combine adt.V.field_values field_types in @@ -231,7 +231,7 @@ let symbolic_expansion_non_borrow_to_value (sv : V.symbolic_value) let ty = Subst.erase_regions sv.V.sv_ty in let value = match see with - | SePrimitive cv -> V.Primitive cv + | SeLiteral cv -> V.Literal cv | SeAdt (variant_id, field_values) -> let field_values = List.map mk_typed_value_from_symbolic_value field_values @@ -267,9 +267,9 @@ let apply_proj_loans_on_symbolic_expansion (regions : T.RegionId.Set.t) (* Match *) let (value, ty) : V.avalue * T.rty = match (see, original_sv_ty) with - | SePrimitive _, (T.Bool | T.Char | T.Integer _ | T.Str) -> - (V.AIgnored, original_sv_ty) - | SeAdt (variant_id, field_values), T.Adt (_id, _region_params, _tys) -> + | SeLiteral _, T.Literal _ -> (V.AIgnored, original_sv_ty) + | SeAdt (variant_id, field_values), T.Adt (_id, _region_params, _tys, _cgs) + -> (* Project over the field values *) let field_values = List.map diff --git a/compiler/InterpreterProjectors.mli b/compiler/InterpreterProjectors.mli index 1afb9d53..bcc3dee2 100644 --- a/compiler/InterpreterProjectors.mli +++ b/compiler/InterpreterProjectors.mli @@ -55,7 +55,16 @@ val prepare_reborrows : bool -> (V.BorrowId.id -> V.BorrowId.id) * (C.eval_ctx -> C.eval_ctx) -(** Apply (and reduce) a projector over borrows to a value. +(** Apply (and reduce) a projector over borrows to an avalue. + We use this for instance to spread the borrows present in the inputs + of a function into the regions introduced for this function. For instance: + {[ + fn f<'a, 'b, T>(x: &'a T, y: &'b T) + ]} + If we call `f` with `x -> shared_borrow l0` and `y -> shared_borrow l1`, then + for the region introduced for `'a` we need to project the value for `x` to + a shared aborrow, and we need to ignore the borrow in `y`, because it belongs + to the other region. Parameters: - [check_symbolic_no_ended]: controls whether we check or not whether diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml index f5b1111e..045c4484 100644 --- a/compiler/InterpreterStatements.ml +++ b/compiler/InterpreterStatements.ml @@ -149,7 +149,7 @@ let eval_assertion_concrete (config : C.config) (assertion : A.assertion) : let eval_assert cf (v : V.typed_value) : m_fun = fun ctx -> match v.value with - | Primitive (Bool b) -> + | Literal (Bool b) -> (* Branch *) if b = assertion.expected then cf Unit ctx else cf Panic ctx | _ -> @@ -172,26 +172,26 @@ let eval_assertion (config : C.config) (assertion : A.assertion) : st_cm_fun = (* Evaluate the assertion *) let eval_assert cf (v : V.typed_value) : m_fun = fun ctx -> - assert (v.ty = T.Bool); + assert (v.ty = T.Literal PV.Bool); (* We make a choice here: we could completely decouple the concrete and * symbolic executions here but choose not to. In the case where we * know the concrete value of the boolean we test, we use this value * even if we are in symbolic mode. Note that this case should be * extremely rare... *) match v.value with - | Primitive (Bool _) -> + | Literal (Bool _) -> (* Delegate to the concrete evaluation function *) eval_assertion_concrete config assertion cf ctx | Symbolic sv -> assert (config.mode = C.SymbolicMode); - assert (sv.V.sv_ty = T.Bool); + assert (sv.V.sv_ty = T.Literal PV.Bool); (* We continue the execution as if the test had succeeded, and thus * perform the symbolic expansion: sv ~~> true. * We will of course synthesize an assertion in the generated code * (see below). *) let ctx = apply_symbolic_expansion_non_borrow config sv - (V.SePrimitive (PV.Bool true)) ctx + (V.SeLiteral (PV.Bool true)) ctx in (* Continue *) let expr = cf Unit ctx in @@ -232,7 +232,8 @@ let set_discriminant (config : C.config) (p : E.place) let update_value cf (v : V.typed_value) : m_fun = fun ctx -> match (v.V.ty, v.V.value) with - | ( T.Adt (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types), + | ( T.Adt + (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types, cgs), V.Adt av ) -> ( (* There are two situations: - either the discriminant is already the proper one (in which case we @@ -252,7 +253,7 @@ let set_discriminant (config : C.config) (p : E.place) | T.AdtId def_id -> compute_expanded_bottom_adt_value ctx.type_context.type_decls def_id (Some variant_id) - regions types + regions types cgs | T.Assumed T.Option -> assert (regions = []); compute_expanded_bottom_option_value variant_id @@ -260,13 +261,14 @@ let set_discriminant (config : C.config) (p : E.place) | _ -> raise (Failure "Unreachable") in assign_to_place config bottom_v p (cf Unit) ctx) - | ( T.Adt (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types), + | ( T.Adt + (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types, cgs), V.Bottom ) -> let bottom_v = match type_id with | T.AdtId def_id -> compute_expanded_bottom_adt_value ctx.type_context.type_decls - def_id (Some variant_id) regions types + def_id (Some variant_id) regions types cgs | T.Assumed T.Option -> assert (regions = []); compute_expanded_bottom_option_value variant_id @@ -285,7 +287,7 @@ let set_discriminant (config : C.config) (p : E.place) * or reset an already initialized value, really. *) raise (Failure "Unexpected value") | _, (V.Adt _ | V.Bottom) -> raise (Failure "Inconsistent state") - | _, (V.Primitive _ | V.Borrow _ | V.Loan _) -> + | _, (V.Literal _ | V.Borrow _ | V.Loan _) -> raise (Failure "Unexpected value") in (* Compose and apply *) @@ -302,20 +304,21 @@ let push_frame : cm_fun = fun cf ctx -> cf (ctx_push_frame ctx) instantiation of a non-local function. *) let get_non_local_function_return_type (fid : A.assumed_fun_id) - (region_params : T.erased_region list) (type_params : T.ety list) : T.ety = + (region_params : T.erased_region list) (type_params : T.ety list) + (const_generic_params : T.const_generic list) : T.ety = (* [Box::free] has a special treatment *) - match (fid, region_params, type_params) with - | A.BoxFree, [], [ _ ] -> mk_unit_ty + match (fid, region_params, type_params, const_generic_params) with + | A.BoxFree, [], [ _ ], [] -> mk_unit_ty | _ -> (* Retrieve the function's signature *) let sg = Assumed.get_assumed_sig fid in (* Instantiate the return type *) - let tsubst = - Subst.make_type_subst - (List.map (fun v -> v.T.index) sg.type_params) - type_params + let tsubst = Subst.make_type_subst_from_vars sg.type_params type_params in + let cgsubst = + Subst.make_const_generic_subst_from_vars sg.const_generic_params + const_generic_params in - Subst.erase_regions_substitute_types tsubst sg.output + Subst.erase_regions_substitute_types tsubst cgsubst sg.output let move_return_value (config : C.config) (pop_return_value : bool) (cf : V.typed_value option -> m_fun) : m_fun = @@ -417,18 +420,20 @@ let pop_frame_assign (config : C.config) (dest : E.place) : cm_fun = (** Auxiliary function - see {!eval_non_local_function_call} *) let eval_replace_concrete (_config : C.config) - (_region_params : T.erased_region list) (_type_params : T.ety list) : cm_fun - = + (_region_params : T.erased_region list) (_type_params : T.ety list) + (_cg_params : T.const_generic list) : cm_fun = fun _cf _ctx -> raise Unimplemented (** Auxiliary function - see {!eval_non_local_function_call} *) let eval_box_new_concrete (config : C.config) - (region_params : T.erased_region list) (type_params : T.ety list) : cm_fun = + (region_params : T.erased_region list) (type_params : T.ety list) + (cg_params : T.const_generic list) : cm_fun = fun cf ctx -> (* Check and retrieve the arguments *) - match (region_params, type_params, ctx.env) with + match (region_params, type_params, cg_params, ctx.env) with | ( [], [ boxed_ty ], + [], Var (VarBinder input_var, input_value) :: Var (_ret_var, _) :: C.Frame :: _ ) -> @@ -443,7 +448,7 @@ let eval_box_new_concrete (config : C.config) (* Create the new box *) let cf_create cf (moved_input_value : V.typed_value) : m_fun = (* Create the box value *) - let box_ty = T.Adt (T.Assumed T.Box, [], [ boxed_ty ]) in + let box_ty = T.Adt (T.Assumed T.Box, [], [ boxed_ty ], []) in let box_v = V.Adt { variant_id = None; field_values = [ moved_input_value ] } in @@ -465,12 +470,13 @@ let eval_box_new_concrete (config : C.config) and [std::DerefMut::deref_mut] - see {!eval_non_local_function_call} *) let eval_box_deref_mut_or_shared_concrete (config : C.config) (region_params : T.erased_region list) (type_params : T.ety list) - (is_mut : bool) : cm_fun = + (cg_params : T.const_generic list) (is_mut : bool) : cm_fun = fun cf ctx -> (* Check the arguments *) - match (region_params, type_params, ctx.env) with + match (region_params, type_params, cg_params, ctx.env) with | ( [], [ boxed_ty ], + [], Var (VarBinder input_var, input_value) :: Var (_ret_var, _) :: C.Frame :: _ ) -> @@ -510,15 +516,19 @@ let eval_box_deref_mut_or_shared_concrete (config : C.config) (** Auxiliary function - see {!eval_non_local_function_call} *) let eval_box_deref_concrete (config : C.config) - (region_params : T.erased_region list) (type_params : T.ety list) : cm_fun = + (region_params : T.erased_region list) (type_params : T.ety list) + (cg_params : T.const_generic list) : cm_fun = let is_mut = false in - eval_box_deref_mut_or_shared_concrete config region_params type_params is_mut + eval_box_deref_mut_or_shared_concrete config region_params type_params + cg_params is_mut (** Auxiliary function - see {!eval_non_local_function_call} *) let eval_box_deref_mut_concrete (config : C.config) - (region_params : T.erased_region list) (type_params : T.ety list) : cm_fun = + (region_params : T.erased_region list) (type_params : T.ety list) + (cg_params : T.const_generic list) : cm_fun = let is_mut = true in - eval_box_deref_mut_or_shared_concrete config region_params type_params is_mut + eval_box_deref_mut_or_shared_concrete config region_params type_params + cg_params is_mut (** Auxiliary function - see {!eval_non_local_function_call}. @@ -540,11 +550,11 @@ let eval_box_deref_mut_concrete (config : C.config) the destination (by setting it to [()]). *) let eval_box_free (config : C.config) (region_params : T.erased_region list) - (type_params : T.ety list) (args : E.operand list) (dest : E.place) : cm_fun - = + (type_params : T.ety list) (cg_params : T.const_generic list) + (args : E.operand list) (dest : E.place) : cm_fun = fun cf ctx -> - match (region_params, type_params, args) with - | [], [ boxed_ty ], [ E.Move input_box_place ] -> + match (region_params, type_params, cg_params, args) with + | [], [ boxed_ty ], [], [ E.Move input_box_place ] -> (* Required type checking *) let input_box = InterpreterPaths.read_place Write input_box_place ctx in (let input_ty = ty_get_box input_box.V.ty in @@ -562,15 +572,15 @@ let eval_box_free (config : C.config) (region_params : T.erased_region list) (** Auxiliary function - see {!eval_non_local_function_call} *) let eval_vec_function_concrete (_config : C.config) (_fid : A.assumed_fun_id) - (_region_params : T.erased_region list) (_type_params : T.ety list) : cm_fun - = + (_region_params : T.erased_region list) (_type_params : T.ety list) + (_cg_params : T.const_generic list) : cm_fun = fun _cf _ctx -> raise Unimplemented (** Evaluate a non-local function call in concrete mode *) let eval_non_local_function_call_concrete (config : C.config) (fid : A.assumed_fun_id) (region_params : T.erased_region list) - (type_params : T.ety list) (args : E.operand list) (dest : E.place) : cm_fun - = + (type_params : T.ety list) (cg_params : T.const_generic list) + (args : E.operand list) (dest : E.place) : cm_fun = (* There are two cases (and this is extremely annoying): - the function is not box_free - the function is box_free @@ -579,7 +589,7 @@ let eval_non_local_function_call_concrete (config : C.config) match fid with | A.BoxFree -> (* Degenerate case: box_free *) - eval_box_free config region_params type_params args dest + eval_box_free config region_params type_params cg_params args dest | _ -> (* "Normal" case: not box_free *) (* Evaluate the operands *) @@ -602,6 +612,7 @@ let eval_non_local_function_call_concrete (config : C.config) let ret_vid = E.VarId.zero in let ret_ty = get_non_local_function_return_type fid region_params type_params + cg_params in let ret_var = mk_var ret_vid (Some "@return") ret_ty in let cc = comp cc (push_uninitialized_var ret_var) in @@ -619,15 +630,25 @@ let eval_non_local_function_call_concrete (config : C.config) * access to a body. *) let cf_eval_body : cm_fun = match fid with - | A.Replace -> eval_replace_concrete config region_params type_params - | BoxNew -> eval_box_new_concrete config region_params type_params - | BoxDeref -> eval_box_deref_concrete config region_params type_params + | A.Replace -> + eval_replace_concrete config region_params type_params cg_params + | BoxNew -> + eval_box_new_concrete config region_params type_params cg_params + | BoxDeref -> + eval_box_deref_concrete config region_params type_params cg_params | BoxDerefMut -> eval_box_deref_mut_concrete config region_params type_params + cg_params | BoxFree -> (* Should have been treated above *) raise (Failure "Unreachable") | VecNew | VecPush | VecInsert | VecLen | VecIndex | VecIndexMut -> eval_vec_function_concrete config fid region_params type_params + cg_params + | ArrayIndexShared | ArrayIndexMut | ArrayToSliceShared + | ArrayToSliceMut | ArraySubsliceShared | ArraySubsliceMut + | SliceIndexShared | SliceIndexMut | SliceSubsliceShared + | SliceSubsliceMut | SliceLen -> + raise (Failure "Unimplemented") in let cc = comp cc cf_eval_body in @@ -641,8 +662,8 @@ let eval_non_local_function_call_concrete (config : C.config) (* Compose and apply *) comp cf_eval_ops cf_eval_call -let instantiate_fun_sig (type_params : T.ety list) (sg : A.fun_sig) : - A.inst_fun_sig = +let instantiate_fun_sig (type_params : T.ety list) + (cg_params : T.const_generic list) (sg : A.fun_sig) : A.inst_fun_sig = (* Generate fresh abstraction ids and create a substitution from region * group ids to abstraction ids *) let rg_abs_ids_bindings = @@ -671,13 +692,12 @@ let instantiate_fun_sig (type_params : T.ety list) (sg : A.fun_sig) : * work to do to properly handle full type parametrization. * *) let rtype_params = List.map ety_no_regions_to_rty type_params in - let tsubst = - Subst.make_type_subst - (List.map (fun v -> v.T.index) sg.type_params) - rtype_params + let tsubst = Subst.make_type_subst_from_vars sg.type_params rtype_params in + let cgsubst = + Subst.make_const_generic_subst_from_vars sg.const_generic_params cg_params in (* Substitute the signature *) - let inst_sig = Subst.substitute_signature asubst rsubst tsubst sg in + let inst_sig = Subst.substitute_signature asubst rsubst tsubst cgsubst sg in (* Return *) inst_sig @@ -873,7 +893,7 @@ and eval_global (config : C.config) (dest : E.place) (gid : LA.GlobalDeclId.id) match config.mode with | ConcreteMode -> (* Treat the evaluation of the global as a call to the global body (without arguments) *) - (eval_local_function_call_concrete config global.body_id [] [] [] dest) + (eval_local_function_call_concrete config global.body_id [] [] [] [] dest) cf ctx | SymbolicMode -> (* Generate a fresh symbolic value. In the translation, this fresh symbolic value will be @@ -909,7 +929,7 @@ and eval_switch (config : C.config) (switch : A.switch) : st_cm_fun = let cf_if (cf : st_m_fun) (op_v : V.typed_value) : m_fun = fun ctx -> match op_v.value with - | V.Primitive (PV.Bool b) -> + | V.Literal (PV.Bool b) -> (* Evaluate the if and the branch body *) let cf_branch cf : m_fun = (* Branch *) @@ -937,7 +957,7 @@ and eval_switch (config : C.config) (switch : A.switch) : st_cm_fun = let cf_switch (cf : st_m_fun) (op_v : V.typed_value) : m_fun = fun ctx -> match op_v.value with - | V.Primitive (PV.Scalar sv) -> + | V.Literal (PV.Scalar sv) -> (* Evaluate the branch *) let cf_eval_branch cf = (* Sanity check *) @@ -1024,18 +1044,17 @@ and eval_function_call (config : C.config) (call : A.call) : st_cm_fun = match call.func with | A.Regular fid -> eval_local_function_call config fid call.region_args call.type_args - call.args call.dest + call.const_generic_args call.args call.dest | A.Assumed fid -> eval_non_local_function_call config fid call.region_args call.type_args - call.args call.dest + call.const_generic_args call.args call.dest (** Evaluate a local (i.e., non-assumed) function call in concrete mode *) and eval_local_function_call_concrete (config : C.config) (fid : A.FunDeclId.id) - (region_args : T.erased_region list) (type_args : T.ety list) - (args : E.operand list) (dest : E.place) : st_cm_fun = + (_region_args : T.erased_region list) (type_args : T.ety list) + (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) : + st_cm_fun = fun cf ctx -> - assert (region_args = []); - (* Retrieve the (correctly instantiated) body *) let def = C.ctx_lookup_fun_decl ctx fid in (* We can evaluate the function call only if it is not opaque *) @@ -1049,11 +1068,13 @@ and eval_local_function_call_concrete (config : C.config) (fid : A.FunDeclId.id) | Some body -> body in let tsubst = - Subst.make_type_subst - (List.map (fun v -> v.T.index) def.A.signature.type_params) - type_args + Subst.make_type_subst_from_vars def.A.signature.type_params type_args in - let locals, body_st = Subst.fun_body_substitute_in_body tsubst body in + let cgsubst = + Subst.make_const_generic_subst_from_vars + def.A.signature.const_generic_params cg_args + in + let locals, body_st = Subst.fun_body_substitute_in_body tsubst cgsubst body in (* Evaluate the input operands *) assert (List.length args = body.A.arg_count); @@ -1112,19 +1133,20 @@ and eval_local_function_call_concrete (config : C.config) (fid : A.FunDeclId.id) (** Evaluate a local (i.e., non-assumed) function call in symbolic mode *) and eval_local_function_call_symbolic (config : C.config) (fid : A.FunDeclId.id) (region_args : T.erased_region list) (type_args : T.ety list) - (args : E.operand list) (dest : E.place) : st_cm_fun = + (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) : + st_cm_fun = fun cf ctx -> (* Retrieve the (correctly instantiated) signature *) let def = C.ctx_lookup_fun_decl ctx fid in let sg = def.A.signature in (* Instantiate the signature and introduce fresh abstraction and region ids * while doing so *) - let inst_sg = instantiate_fun_sig type_args sg in + let inst_sg = instantiate_fun_sig type_args cg_args sg in (* Sanity check *) assert (List.length args = List.length def.A.signature.inputs); (* Evaluate the function call *) eval_function_call_symbolic_from_inst_sig config (A.Regular fid) inst_sg - region_args type_args args dest cf ctx + region_args type_args cg_args args dest cf ctx (** Evaluate a function call in symbolic mode by using the function signature. @@ -1133,10 +1155,10 @@ and eval_local_function_call_symbolic (config : C.config) (fid : A.FunDeclId.id) *) and eval_function_call_symbolic_from_inst_sig (config : C.config) (fid : A.fun_id) (inst_sg : A.inst_fun_sig) - (region_args : T.erased_region list) (type_args : T.ety list) - (args : E.operand list) (dest : E.place) : st_cm_fun = + (_region_args : T.erased_region list) (type_args : T.ety list) + (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) : + st_cm_fun = fun cf ctx -> - assert (region_args = []); (* Generate a fresh symbolic value for the return value *) let ret_sv_ty = inst_sg.A.output in let ret_spc = mk_fresh_symbolic_value V.FunCallRet ret_sv_ty in @@ -1202,8 +1224,8 @@ and eval_function_call_symbolic_from_inst_sig (config : C.config) let expr = cf ctx in (* Synthesize the symbolic AST *) - S.synthesize_regular_function_call fid call_id ctx abs_ids type_args args - args_places ret_spc dest_place expr + S.synthesize_regular_function_call fid call_id ctx abs_ids type_args cg_args + args args_places ret_spc dest_place expr in let cc = comp cc cf_call in @@ -1266,8 +1288,8 @@ and eval_function_call_symbolic_from_inst_sig (config : C.config) (** Evaluate a non-local function call in symbolic mode *) and eval_non_local_function_call_symbolic (config : C.config) (fid : A.assumed_fun_id) (region_args : T.erased_region list) - (type_args : T.ety list) (args : E.operand list) (dest : E.place) : - st_cm_fun = + (type_args : T.ety list) (cg_args : T.const_generic list) + (args : E.operand list) (dest : E.place) : st_cm_fun = fun cf ctx -> (* Sanity check: make sure the type parameters don't contain regions - * this is a current limitation of our synthesis *) @@ -1285,7 +1307,7 @@ and eval_non_local_function_call_symbolic (config : C.config) | A.BoxFree -> (* Degenerate case: box_free - note that this is not really a function * call: no need to call a "synthesize_..." function *) - eval_box_free config region_args type_args args dest (cf Unit) ctx + eval_box_free config region_args type_args cg_args args dest (cf Unit) ctx | _ -> (* "Normal" case: not box_free *) (* In symbolic mode, the behaviour of a function call is completely defined @@ -1296,18 +1318,20 @@ and eval_non_local_function_call_symbolic (config : C.config) | A.BoxFree -> (* should have been treated above *) raise (Failure "Unreachable") - | _ -> instantiate_fun_sig type_args (Assumed.get_assumed_sig fid) + | _ -> + instantiate_fun_sig type_args cg_args (Assumed.get_assumed_sig fid) in (* Evaluate the function call *) eval_function_call_symbolic_from_inst_sig config (A.Assumed fid) inst_sig - region_args type_args args dest cf ctx + region_args type_args cg_args args dest cf ctx (** Evaluate a non-local (i.e, assumed) function call such as [Box::deref] (auxiliary helper for [eval_statement]) *) and eval_non_local_function_call (config : C.config) (fid : A.assumed_fun_id) (region_args : T.erased_region list) (type_args : T.ety list) - (args : E.operand list) (dest : E.place) : st_cm_fun = + (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) : + st_cm_fun = fun cf ctx -> (* Debug *) log#ldebug @@ -1326,23 +1350,24 @@ and eval_non_local_function_call (config : C.config) (fid : A.assumed_fun_id) match config.mode with | C.ConcreteMode -> eval_non_local_function_call_concrete config fid region_args type_args - args dest (cf Unit) ctx + cg_args args dest (cf Unit) ctx | C.SymbolicMode -> eval_non_local_function_call_symbolic config fid region_args type_args - args dest cf ctx + cg_args args dest cf ctx (** Evaluate a local (i.e, not assumed) function call (auxiliary helper for [eval_statement]) *) and eval_local_function_call (config : C.config) (fid : A.FunDeclId.id) (region_args : T.erased_region list) (type_args : T.ety list) - (args : E.operand list) (dest : E.place) : st_cm_fun = + (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) : + st_cm_fun = match config.mode with | ConcreteMode -> - eval_local_function_call_concrete config fid region_args type_args args - dest + eval_local_function_call_concrete config fid region_args type_args cg_args + args dest | SymbolicMode -> - eval_local_function_call_symbolic config fid region_args type_args args - dest + eval_local_function_call_symbolic config fid region_args type_args cg_args + args dest (** Evaluate a statement seen as a function body *) and eval_function_body (config : C.config) (body : A.statement) : st_cm_fun = diff --git a/compiler/InterpreterStatements.mli b/compiler/InterpreterStatements.mli index f28bf2ea..814bc964 100644 --- a/compiler/InterpreterStatements.mli +++ b/compiler/InterpreterStatements.mli @@ -31,7 +31,8 @@ val pop_frame : C.config -> bool -> (V.typed_value option -> m_fun) -> m_fun Note: there are no region parameters, because they should be erased. *) -val instantiate_fun_sig : T.ety list -> LA.fun_sig -> LA.inst_fun_sig +val instantiate_fun_sig : + T.ety list -> T.const_generic list -> LA.fun_sig -> LA.inst_fun_sig (** Helper. diff --git a/compiler/InterpreterUtils.ml b/compiler/InterpreterUtils.ml index a50bb7ac..7bd37550 100644 --- a/compiler/InterpreterUtils.ml +++ b/compiler/InterpreterUtils.ml @@ -255,7 +255,7 @@ let value_has_ret_symbolic_value_with_borrow_under_mut (ctx : C.eval_ctx) raise Found else () | V.SynthInput | V.SynthInputGivenBack | V.FunCallGivenBack - | V.SynthRetGivenBack | V.Global | V.LoopGivenBack -> + | V.SynthRetGivenBack | V.Global | V.LoopGivenBack | V.Aggregate -> () end in diff --git a/compiler/Invariants.ml b/compiler/Invariants.ml index 981c2c46..f29c7f88 100644 --- a/compiler/Invariants.ml +++ b/compiler/Invariants.ml @@ -377,10 +377,10 @@ let check_borrowed_values_invariant (ctx : C.eval_ctx) : unit = let info = { outer_borrow = false; outer_shared = false } in visitor#visit_eval_ctx info ctx -let check_primitive_value_type (cv : V.primitive_value) (ty : T.ety) : unit = +let check_literal_type (cv : V.literal) (ty : PV.literal_type) : unit = match (cv, ty) with - | PV.Scalar sv, T.Integer int_ty -> assert (sv.int_ty = int_ty) - | PV.Bool _, T.Bool | PV.Char _, T.Char | PV.String _, T.Str -> () + | PV.Scalar sv, PV.Integer int_ty -> assert (sv.int_ty = int_ty) + | PV.Bool _, PV.Bool | PV.Char _, PV.Char -> () | _ -> raise (Failure "Erroneous typing") let check_typing_invariant (ctx : C.eval_ctx) : unit = @@ -404,9 +404,9 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = method! visit_typed_value info tv = (* Check the current pair (value, type) *) (match (tv.V.value, tv.V.ty) with - | V.Primitive cv, ty -> check_primitive_value_type cv ty + | V.Literal cv, T.Literal ty -> check_literal_type cv ty (* ADT case *) - | V.Adt av, T.Adt (T.AdtId def_id, regions, tys) -> + | V.Adt av, T.Adt (T.AdtId def_id, regions, tys, cgs) -> (* Retrieve the definition to check the variant id, the number of * parameters, etc. *) let def = C.ctx_lookup_type_decl ctx def_id in @@ -422,7 +422,7 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (* Check that the field types are correct *) let field_types = Subst.type_decl_get_instantiated_field_etypes def av.V.variant_id - tys + tys cgs in let fields_with_types = List.combine av.V.field_values field_types @@ -431,8 +431,9 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (fun ((v, ty) : V.typed_value * T.ety) -> assert (v.V.ty = ty)) fields_with_types (* Tuple case *) - | V.Adt av, T.Adt (T.Tuple, regions, tys) -> + | V.Adt av, T.Adt (T.Tuple, regions, tys, cgs) -> assert (regions = []); + assert (cgs = []); assert (av.V.variant_id = None); (* Check that the fields have the proper values - and check that there * are as many fields as field types at the same time *) @@ -441,20 +442,37 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (fun ((v, ty) : V.typed_value * T.ety) -> assert (v.V.ty = ty)) fields_with_types (* Assumed type case *) - | V.Adt av, T.Adt (T.Assumed aty_id, regions, tys) -> ( + | V.Adt av, T.Adt (T.Assumed aty_id, regions, tys, cgs) -> ( assert (av.V.variant_id = None || aty_id = T.Option); - match (aty_id, av.V.field_values, regions, tys) with + match (aty_id, av.V.field_values, regions, tys, cgs) with (* Box *) - | T.Box, [ inner_value ], [], [ inner_ty ] - | T.Option, [ inner_value ], [], [ inner_ty ] -> + | T.Box, [ inner_value ], [], [ inner_ty ], [] + | T.Option, [ inner_value ], [], [ inner_ty ], [] -> assert (inner_value.V.ty = inner_ty) - | T.Option, _, [], [ _ ] -> + | T.Option, _, [], [ _ ], [] -> (* Option::None: nothing to check *) () - | T.Vec, fvs, [], [ vec_ty ] -> + | T.Vec, fvs, [], [ vec_ty ], [] -> List.iter (fun (v : V.typed_value) -> assert (v.ty = vec_ty)) fvs + | T.Range, [ v0; v1 ], [], [ inner_ty ], [] -> + assert (v0.V.ty = inner_ty); + assert (v1.V.ty = inner_ty) + | T.Array, inner_values, _, [ inner_ty ], [ cg ] -> + (* *) + assert ( + List.for_all + (fun (v : V.typed_value) -> v.V.ty = inner_ty) + inner_values); + (* The length is necessarily concrete *) + let len = + (PrimitiveValuesUtils.literal_as_scalar + (TypesUtils.const_generic_as_literal cg)) + .value + in + assert (Z.of_int (List.length inner_values) = len) + | (T.Slice | T.Str), _, _, _, _ -> raise (Failure "Unexpected") | _ -> raise (Failure "Erroneous type")) | V.Bottom, _ -> (* Nothing to check *) () | V.Borrow bc, T.Ref (_, ref_ty, rkind) -> ( @@ -502,13 +520,14 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (* Check the current pair (value, type) *) (match (atv.V.value, atv.V.ty) with (* ADT case *) - | V.AAdt av, T.Adt (T.AdtId def_id, regions, tys) -> + | V.AAdt av, T.Adt (T.AdtId def_id, regions, tys, cgs) -> (* Retrieve the definition to check the variant id, the number of * parameters, etc. *) let def = C.ctx_lookup_type_decl ctx def_id in (* Check the number of parameters *) assert (List.length regions = List.length def.region_params); assert (List.length tys = List.length def.type_params); + assert (List.length cgs = List.length def.const_generic_params); (* Check that the variant id is consistent *) (match (av.V.variant_id, def.T.kind) with | Some variant_id, T.Enum variants -> @@ -518,7 +537,7 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (* Check that the field types are correct *) let field_types = Subst.type_decl_get_instantiated_field_rtypes def av.V.variant_id - regions tys + regions tys cgs in let fields_with_types = List.combine av.V.field_values field_types @@ -527,8 +546,9 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (fun ((v, ty) : V.typed_avalue * T.rty) -> assert (v.V.ty = ty)) fields_with_types (* Tuple case *) - | V.AAdt av, T.Adt (T.Tuple, regions, tys) -> + | V.AAdt av, T.Adt (T.Tuple, regions, tys, cgs) -> assert (regions = []); + assert (cgs = []); assert (av.V.variant_id = None); (* Check that the fields have the proper values - and check that there * are as many fields as field types at the same time *) @@ -537,11 +557,11 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (fun ((v, ty) : V.typed_avalue * T.rty) -> assert (v.V.ty = ty)) fields_with_types (* Assumed type case *) - | V.AAdt av, T.Adt (T.Assumed aty_id, regions, tys) -> ( + | V.AAdt av, T.Adt (T.Assumed aty_id, regions, tys, cgs) -> ( assert (av.V.variant_id = None); - match (aty_id, av.V.field_values, regions, tys) with + match (aty_id, av.V.field_values, regions, tys, cgs) with (* Box *) - | T.Box, [ boxed_value ], [], [ boxed_ty ] -> + | T.Box, [ boxed_value ], [], [ boxed_ty ], [] -> assert (boxed_value.V.ty = boxed_ty) | _ -> raise (Failure "Erroneous type")) | V.ABottom, _ -> (* Nothing to check *) () diff --git a/compiler/PrimitiveValuesUtils.ml b/compiler/PrimitiveValuesUtils.ml new file mode 100644 index 00000000..0000916d --- /dev/null +++ b/compiler/PrimitiveValuesUtils.ml @@ -0,0 +1 @@ +include Charon.PrimitiveValuesUtils diff --git a/compiler/Print.ml b/compiler/Print.ml index f544c0db..9aa73d7c 100644 --- a/compiler/Print.ml +++ b/compiler/Print.ml @@ -19,6 +19,8 @@ module Values = struct r_to_string : T.RegionId.id -> string; type_var_id_to_string : T.TypeVarId.id -> string; type_decl_id_to_string : T.TypeDeclId.id -> string; + const_generic_var_id_to_string : T.ConstGenericVarId.id -> string; + global_decl_id_to_string : T.GlobalDeclId.id -> string; adt_variant_to_string : T.TypeDeclId.id -> T.VariantId.id -> string; var_id_to_string : E.VarId.id -> string; adt_field_names : @@ -30,6 +32,8 @@ module Values = struct PT.r_to_string = PT.erased_region_to_string; PT.type_var_id_to_string = fmt.type_var_id_to_string; PT.type_decl_id_to_string = fmt.type_decl_id_to_string; + PT.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string; + PT.global_decl_id_to_string = fmt.global_decl_id_to_string; } let value_to_rtype_formatter (fmt : value_formatter) : PT.rtype_formatter = @@ -37,6 +41,8 @@ module Values = struct PT.r_to_string = PT.region_to_string fmt.r_to_string; PT.type_var_id_to_string = fmt.type_var_id_to_string; PT.type_decl_id_to_string = fmt.type_decl_id_to_string; + PT.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string; + PT.global_decl_id_to_string = fmt.global_decl_id_to_string; } let value_to_stype_formatter (fmt : value_formatter) : PT.stype_formatter = @@ -44,6 +50,8 @@ module Values = struct PT.r_to_string = PT.region_to_string fmt.rvar_to_string; PT.type_var_id_to_string = fmt.type_var_id_to_string; PT.type_decl_id_to_string = fmt.type_decl_id_to_string; + PT.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string; + PT.global_decl_id_to_string = fmt.global_decl_id_to_string; } let var_id_to_string (id : E.VarId.id) : string = @@ -72,16 +80,16 @@ module Values = struct string = let ty_fmt : PT.etype_formatter = value_to_etype_formatter fmt in match v.value with - | Primitive cv -> PPV.primitive_value_to_string cv + | Literal cv -> PPV.literal_to_string cv | Adt av -> ( let field_values = List.map (typed_value_to_string fmt) av.field_values in match v.ty with - | T.Adt (T.Tuple, _, _) -> + | T.Adt (T.Tuple, _, _, _) -> (* Tuple *) "(" ^ String.concat ", " field_values ^ ")" - | T.Adt (T.AdtId def_id, _, _) -> + | T.Adt (T.AdtId def_id, _, _, _) -> (* "Regular" ADT *) let adt_ident = match av.variant_id with @@ -103,7 +111,7 @@ module Values = struct let field_values = String.concat " " field_values in adt_ident ^ " { " ^ field_values ^ " }" else adt_ident - | T.Adt (T.Assumed aty, _, _) -> ( + | T.Adt (T.Assumed aty, _, _, _) -> ( (* Assumed type *) match (aty, field_values) with | Box, [ bv ] -> "@Box(" ^ bv ^ ")" @@ -116,8 +124,13 @@ module Values = struct assert (field_values = []); "@Option::None") else raise (Failure "Unreachable") + | Range, _ -> "@Range{ " ^ String.concat ", " field_values ^ "}" | Vec, _ -> "@Vec[" ^ String.concat ", " field_values ^ "]" - | _ -> raise (Failure "Inconsistent value")) + | Array, _ -> + (* Happens when we aggregate values *) + "@Array[" ^ String.concat ", " field_values ^ "]" + | _ -> + raise (Failure ("Inconsistent value: " ^ V.show_typed_value v))) | _ -> raise (Failure "Inconsistent typed value")) | Bottom -> "⊥ : " ^ PT.ty_to_string ty_fmt v.ty | Borrow bc -> borrow_content_to_string fmt bc @@ -188,10 +201,10 @@ module Values = struct List.map (typed_avalue_to_string fmt) av.field_values in match v.ty with - | T.Adt (T.Tuple, _, _) -> + | T.Adt (T.Tuple, _, _, _) -> (* Tuple *) "(" ^ String.concat ", " field_values ^ ")" - | T.Adt (T.AdtId def_id, _, _) -> + | T.Adt (T.AdtId def_id, _, _, _) -> (* "Regular" ADT *) let adt_ident = match av.variant_id with @@ -213,7 +226,7 @@ module Values = struct let field_values = String.concat " " field_values in adt_ident ^ " { " ^ field_values ^ " }" else adt_ident - | T.Adt (T.Assumed aty, _, _) -> ( + | T.Adt (T.Assumed aty, _, _, _) -> ( (* Assumed type *) match (aty, field_values) with | Box, [ bv ] -> "@Box(" ^ bv ^ ")" @@ -354,20 +367,27 @@ module Contexts = struct | DummyBinder bid -> dummy_var_id_to_string bid let env_elem_to_string (fmt : PV.value_formatter) (verbose : bool) - (indent : string) (indent_incr : string) (ev : C.env_elem) : string = + (with_var_types : bool) (indent : string) (indent_incr : string) + (ev : C.env_elem) : string = match ev with | Var (var, tv) -> let bv = binder_to_string var in - indent ^ bv ^ " -> " ^ PV.typed_value_to_string fmt tv ^ " ;" + let ty = + if with_var_types then + " : " ^ PT.ty_to_string (PV.value_to_etype_formatter fmt) tv.V.ty + else "" + in + indent ^ bv ^ ty ^ " -> " ^ PV.typed_value_to_string fmt tv ^ " ;" | Abs abs -> PV.abs_to_string fmt verbose indent indent_incr abs | Frame -> raise (Failure "Can't print a Frame element") let opt_env_elem_to_string (fmt : PV.value_formatter) (verbose : bool) - (indent : string) (indent_incr : string) (ev : C.env_elem option) : string - = + (with_var_types : bool) (indent : string) (indent_incr : string) + (ev : C.env_elem option) : string = match ev with | None -> indent ^ "..." - | Some ev -> env_elem_to_string fmt verbose indent indent_incr ev + | Some ev -> + env_elem_to_string fmt verbose with_var_types indent indent_incr ev (** Filters "dummy" bindings from an environment, to gain space and clarity/ See [env_to_string]. *) @@ -404,16 +424,18 @@ module Contexts = struct (** Environments can have a lot of dummy or uninitialized values: [filter] allows to filter them when printing, replacing groups of such bindings with "..." to gain space and clarity. + [with_var_types]: if true, print the type of the variables *) let env_to_string (filter : bool) (fmt : PV.value_formatter) (verbose : bool) - (env : C.env) : string = + (with_var_types : bool) (env : C.env) : string = let env = if filter then filter_env env else List.map (fun ev -> Some ev) env in "{\n" ^ String.concat "\n" (List.map - (fun ev -> opt_env_elem_to_string fmt verbose " " " " ev) + (fun ev -> + opt_env_elem_to_string fmt verbose with_var_types " " " " ev) env) ^ "\n}" @@ -425,6 +447,8 @@ module Contexts = struct PV.r_to_string = fmt.r_to_string; PV.type_var_id_to_string = fmt.type_var_id_to_string; PV.type_decl_id_to_string = fmt.type_decl_id_to_string; + PV.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string; + PV.global_decl_id_to_string = fmt.global_decl_id_to_string; PV.adt_variant_to_string = fmt.adt_variant_to_string; PV.var_id_to_string = fmt.var_id_to_string; PV.adt_field_names = fmt.adt_field_names; @@ -450,10 +474,18 @@ module Contexts = struct let v = C.lookup_type_var ctx vid in v.name in + let const_generic_var_id_to_string vid = + let v = C.lookup_const_generic_var ctx vid in + v.name + in let type_decl_id_to_string def_id = let def = C.ctx_lookup_type_decl ctx def_id in name_to_string def.name in + let global_decl_id_to_string def_id = + let def = C.ctx_lookup_global_decl ctx def_id in + name_to_string def.name + in let adt_variant_to_string = PT.type_ctx_to_adt_variant_to_string_fun ctx.type_context.type_decls in @@ -469,6 +501,8 @@ module Contexts = struct r_to_string; type_var_id_to_string; type_decl_id_to_string; + const_generic_var_id_to_string; + global_decl_id_to_string; adt_variant_to_string; var_id_to_string; adt_field_names; @@ -492,6 +526,7 @@ module Contexts = struct r_to_string = ctx_fmt.PV.r_to_string; type_var_id_to_string = ctx_fmt.PV.type_var_id_to_string; type_decl_id_to_string = ctx_fmt.PV.type_decl_id_to_string; + const_generic_var_id_to_string = ctx_fmt.PV.const_generic_var_id_to_string; adt_variant_to_string = ctx_fmt.PV.adt_variant_to_string; var_id_to_string = ctx_fmt.PV.var_id_to_string; adt_field_names = ctx_fmt.PV.adt_field_names; @@ -518,7 +553,7 @@ module Contexts = struct frames let fmt_eval_ctx_to_string_gen (fmt : ctx_formatter) (verbose : bool) - (filter : bool) (ctx : C.eval_ctx) : string = + (filter : bool) (with_var_types : bool) (ctx : C.eval_ctx) : string = let ended_regions = T.RegionId.Set.to_string None ctx.ended_regions in let frames = split_env_according_to_frames ctx.env in let num_frames = List.length frames in @@ -540,23 +575,23 @@ module Contexts = struct ^ string_of_int !num_bindings ^ "\n- dummy bindings: " ^ string_of_int !num_dummies ^ "\n- abstractions: " ^ string_of_int !num_abs ^ "\n" - ^ env_to_string filter fmt verbose f + ^ env_to_string filter fmt verbose with_var_types f ^ "\n") frames in "# Ended regions: " ^ ended_regions ^ "\n" ^ "# " ^ string_of_int num_frames ^ " frame(s)\n" ^ String.concat "" frames - let eval_ctx_to_string_gen (verbose : bool) (filter : bool) (ctx : C.eval_ctx) - : string = + let eval_ctx_to_string_gen (verbose : bool) (filter : bool) + (with_var_types : bool) (ctx : C.eval_ctx) : string = let fmt = eval_ctx_to_ctx_formatter ctx in - fmt_eval_ctx_to_string_gen fmt verbose filter ctx + fmt_eval_ctx_to_string_gen fmt verbose filter with_var_types ctx let eval_ctx_to_string (ctx : C.eval_ctx) : string = - eval_ctx_to_string_gen false true ctx + eval_ctx_to_string_gen false true true ctx let eval_ctx_to_string_no_filter (ctx : C.eval_ctx) : string = - eval_ctx_to_string_gen false false ctx + eval_ctx_to_string_gen false false true ctx end module PC = Contexts (* local module *) @@ -626,7 +661,7 @@ module EvalCtxLlbcAst = struct let env_elem_to_string (ctx : C.eval_ctx) (indent : string) (indent_incr : string) (ev : C.env_elem) : string = let fmt = PC.eval_ctx_to_ctx_formatter ctx in - PC.env_elem_to_string fmt false indent indent_incr ev + PC.env_elem_to_string fmt false true indent indent_incr ev let abs_to_string (ctx : C.eval_ctx) (indent : string) (indent_incr : string) (abs : V.abs) : string = diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 3f35a023..cfb63ec2 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -6,11 +6,15 @@ open PureUtils type type_formatter = { type_var_id_to_string : TypeVarId.id -> string; type_decl_id_to_string : TypeDeclId.id -> string; + const_generic_var_id_to_string : ConstGenericVarId.id -> string; + global_decl_id_to_string : GlobalDeclId.id -> string; } type value_formatter = { type_var_id_to_string : TypeVarId.id -> string; type_decl_id_to_string : TypeDeclId.id -> string; + const_generic_var_id_to_string : ConstGenericVarId.id -> string; + global_decl_id_to_string : GlobalDeclId.id -> string; adt_variant_to_string : TypeDeclId.id -> VariantId.id -> string; var_id_to_string : VarId.id -> string; adt_field_names : TypeDeclId.id -> VariantId.id option -> string list option; @@ -20,6 +24,8 @@ let value_to_type_formatter (fmt : value_formatter) : type_formatter = { type_var_id_to_string = fmt.type_var_id_to_string; type_decl_id_to_string = fmt.type_decl_id_to_string; + const_generic_var_id_to_string = fmt.const_generic_var_id_to_string; + global_decl_id_to_string = fmt.global_decl_id_to_string; } (* TODO: we need to store which variables we have encountered so far, and @@ -28,6 +34,7 @@ let value_to_type_formatter (fmt : value_formatter) : type_formatter = type ast_formatter = { type_var_id_to_string : TypeVarId.id -> string; type_decl_id_to_string : TypeDeclId.id -> string; + const_generic_var_id_to_string : ConstGenericVarId.id -> string; adt_variant_to_string : TypeDeclId.id -> VariantId.id -> string; var_id_to_string : VarId.id -> string; adt_field_to_string : @@ -41,6 +48,8 @@ let ast_to_value_formatter (fmt : ast_formatter) : value_formatter = { type_var_id_to_string = fmt.type_var_id_to_string; type_decl_id_to_string = fmt.type_decl_id_to_string; + const_generic_var_id_to_string = fmt.const_generic_var_id_to_string; + global_decl_id_to_string = fmt.global_decl_id_to_string; adt_variant_to_string = fmt.adt_variant_to_string; var_id_to_string = fmt.var_id_to_string; adt_field_names = fmt.adt_field_names; @@ -55,20 +64,38 @@ let fun_name_to_string = Print.fun_name_to_string let global_name_to_string = Print.global_name_to_string let option_to_string = Print.option_to_string let type_var_to_string = Print.Types.type_var_to_string -let integer_type_to_string = Print.Types.integer_type_to_string +let const_generic_var_to_string = Print.Types.const_generic_var_to_string +let integer_type_to_string = Print.PrimitiveValues.integer_type_to_string +let literal_type_to_string = Print.PrimitiveValues.literal_type_to_string let scalar_value_to_string = Print.PrimitiveValues.scalar_value_to_string +let literal_to_string = Print.PrimitiveValues.literal_to_string let mk_type_formatter (type_decls : T.type_decl TypeDeclId.Map.t) - (type_params : type_var list) : type_formatter = + (global_decls : A.global_decl GlobalDeclId.Map.t) + (type_params : type_var list) + (const_generic_params : const_generic_var list) : type_formatter = let type_var_id_to_string vid = let var = T.TypeVarId.nth type_params vid in type_var_to_string var in + let const_generic_var_id_to_string vid = + let var = T.ConstGenericVarId.nth const_generic_params vid in + const_generic_var_to_string var + in let type_decl_id_to_string def_id = let def = T.TypeDeclId.Map.find def_id type_decls in name_to_string def.name in - { type_var_id_to_string; type_decl_id_to_string } + let global_decl_id_to_string def_id = + let def = T.GlobalDeclId.Map.find def_id global_decls in + name_to_string def.name + in + { + type_var_id_to_string; + type_decl_id_to_string; + const_generic_var_id_to_string; + global_decl_id_to_string; + } (* TODO: there is a bit of duplication with Print.fun_decl_to_ast_formatter. @@ -79,11 +106,16 @@ let mk_type_formatter (type_decls : T.type_decl TypeDeclId.Map.t) let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t) (fun_decls : A.fun_decl FunDeclId.Map.t) (global_decls : A.global_decl GlobalDeclId.Map.t) - (type_params : type_var list) : ast_formatter = + (type_params : type_var list) + (const_generic_params : const_generic_var list) : ast_formatter = let type_var_id_to_string vid = let var = T.TypeVarId.nth type_params vid in type_var_to_string var in + let const_generic_var_id_to_string vid = + let var = T.ConstGenericVarId.nth const_generic_params vid in + const_generic_var_to_string var + in let type_decl_id_to_string def_id = let def = T.TypeDeclId.Map.find def_id type_decls in name_to_string def.name @@ -111,6 +143,7 @@ let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t) in { type_var_id_to_string; + const_generic_var_id_to_string; type_decl_id_to_string; adt_variant_to_string; var_id_to_string; @@ -120,36 +153,51 @@ let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t) global_decl_id_to_string; } +let assumed_ty_to_string (aty : assumed_ty) : string = + match aty with + | State -> "State" + | Result -> "Result" + | Error -> "Error" + | Fuel -> "Fuel" + | Option -> "Option" + | Vec -> "Vec" + | Array -> "Array" + | Slice -> "Slice" + | Str -> "Str" + | Range -> "Range" + let type_id_to_string (fmt : type_formatter) (id : type_id) : string = match id with | AdtId id -> fmt.type_decl_id_to_string id | Tuple -> "" - | Assumed aty -> ( - match aty with - | State -> "State" - | Result -> "Result" - | Error -> "Error" - | Fuel -> "Fuel" - | Option -> "Option" - | Vec -> "Vec") + | Assumed aty -> assumed_ty_to_string aty + +(* TODO: duplicates Charon.PrintTypes.const_generic_to_string *) +let const_generic_to_string (fmt : type_formatter) (cg : T.const_generic) : + string = + match cg with + | ConstGenericGlobal id -> fmt.global_decl_id_to_string id + | ConstGenericVar id -> fmt.const_generic_var_id_to_string id + | ConstGenericValue lit -> literal_to_string lit let rec ty_to_string (fmt : type_formatter) (inside : bool) (ty : ty) : string = match ty with - | Adt (id, tys) -> ( + | Adt (id, tys, cgs) -> ( let tys = List.map (ty_to_string fmt false) tys in + let cgs = List.map (const_generic_to_string fmt) cgs in + let params = List.append tys cgs in match id with - | Tuple -> "(" ^ String.concat " * " tys ^ ")" + | Tuple -> + assert (cgs = []); + "(" ^ String.concat " * " tys ^ ")" | AdtId _ | Assumed _ -> - let tys_s = if tys = [] then "" else " " ^ String.concat " " tys in - let ty_s = type_id_to_string fmt id ^ tys_s in - if tys <> [] && inside then "(" ^ ty_s ^ ")" else ty_s) + let params_s = + if params = [] then "" else " " ^ String.concat " " params + in + let ty_s = type_id_to_string fmt id ^ params_s in + if params <> [] && inside then "(" ^ ty_s ^ ")" else ty_s) | TypeVar tv -> fmt.type_var_id_to_string tv - | Bool -> "bool" - | Char -> "char" - | Integer int_ty -> integer_type_to_string int_ty - | Str -> "str" - | Array aty -> "[" ^ ty_to_string fmt false aty ^ "; ?]" - | Slice sty -> "[" ^ ty_to_string fmt false sty ^ "]" + | Literal lty -> literal_type_to_string lty | Arrow (arg_ty, ret_ty) -> let ty = ty_to_string fmt true arg_ty ^ " -> " ^ ty_to_string fmt false ret_ty @@ -246,9 +294,11 @@ let adt_variant_to_string (fmt : value_formatter) (adt_id : type_id) | Assumed aty -> ( (* Assumed type *) match aty with - | State -> - (* This type is opaque: we can't get there *) + | State | Array | Slice | Str -> + (* Those types are opaque: we can't get there *) raise (Failure "Unreachable") + | Vec -> "@Vec" + | Range -> "@Range" | Result -> let variant_id = Option.get variant_id in if variant_id = result_return_id then "@Result::Return" @@ -270,10 +320,7 @@ let adt_variant_to_string (fmt : value_formatter) (adt_id : type_id) if variant_id = option_some_id then "@Option::Some " else if variant_id = option_none_id then "@Option::None" else - raise (Failure "Unreachable: improper variant id for result type") - | Vec -> - assert (variant_id = None); - "Vec") + raise (Failure "Unreachable: improper variant id for result type")) let adt_field_to_string (fmt : value_formatter) (adt_id : type_id) (field_id : FieldId.id) : string = @@ -290,7 +337,8 @@ let adt_field_to_string (fmt : value_formatter) (adt_id : type_id) | Assumed aty -> ( (* Assumed type *) match aty with - | State | Fuel | Vec -> + | Range -> FieldId.to_string field_id + | State | Fuel | Vec | Array | Slice | Str -> (* Opaque types: we can't get there *) raise (Failure "Unreachable") | Result | Error | Option -> @@ -298,17 +346,17 @@ let adt_field_to_string (fmt : value_formatter) (adt_id : type_id) raise (Failure "Unreachable")) (** TODO: we don't need a general function anymore (it is now only used for - patterns (i.e., patterns) + patterns) *) let adt_g_value_to_string (fmt : value_formatter) (value_to_string : 'v -> string) (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) : string = let field_values = List.map value_to_string field_values in match ty with - | Adt (Tuple, _) -> + | Adt (Tuple, _, _) -> (* Tuple *) "(" ^ String.concat ", " field_values ^ ")" - | Adt (AdtId def_id, _) -> + | Adt (AdtId def_id, _, _) -> (* "Regular" ADT *) let adt_ident = match variant_id with @@ -330,7 +378,7 @@ let adt_g_value_to_string (fmt : value_formatter) let field_values = String.concat " " field_values in adt_ident ^ " { " ^ field_values ^ " }" else adt_ident - | Adt (Assumed aty, _) -> ( + | Adt (Assumed aty, _, _) -> ( (* Assumed type *) match aty with | State -> @@ -375,12 +423,20 @@ let adt_g_value_to_string (fmt : value_formatter) "@Option::None") else raise (Failure "Unreachable: improper variant id for result type") - | Vec -> + | Vec | Array | Slice | Str -> assert (variant_id = None); let field_values = List.mapi (fun i v -> string_of_int i ^ " -> " ^ v) field_values in - "Vec [" ^ String.concat "; " field_values ^ "]") + let id = assumed_ty_to_string aty in + id ^ " [" ^ String.concat "; " field_values ^ "]" + | Range -> + assert (variant_id = None); + let field_values = + List.mapi (fun i v -> string_of_int i ^ " -> " ^ v) field_values + in + let id = assumed_ty_to_string aty in + id ^ " {" ^ String.concat "; " field_values ^ "}") | _ -> let fmt = value_to_type_formatter fmt in raise @@ -392,7 +448,7 @@ let adt_g_value_to_string (fmt : value_formatter) let rec typed_pattern_to_string (fmt : ast_formatter) (v : typed_pattern) : string = match v.value with - | PatConstant cv -> Print.PrimitiveValues.primitive_value_to_string cv + | PatConstant cv -> literal_to_string cv | PatVar (v, None) -> var_to_string (ast_to_type_formatter fmt) v | PatVar (v, Some mp) -> let mp = "[@mplace=" ^ mplace_to_string fmt mp ^ "]" in @@ -450,6 +506,17 @@ let llbc_assumed_fun_id_to_string (fid : A.assumed_fun_id) : string = | A.VecLen -> "alloc::vec::Vec::len" | A.VecIndex -> "core::ops::index::Index<alloc::vec::Vec>::index" | A.VecIndexMut -> "core::ops::index::IndexMut<alloc::vec::Vec>::index_mut" + | ArrayIndexShared -> "@ArrayIndexShared" + | ArrayIndexMut -> "@ArrayIndexMut" + | ArrayToSliceShared -> "@ArrayToSliceShared" + | ArrayToSliceMut -> "@ArrayToSliceMut" + | ArraySubsliceShared -> "@ArraySubsliceShared" + | ArraySubsliceMut -> "@ArraySubsliceMut" + | SliceLen -> "@SliceLen" + | SliceIndexShared -> "@SliceIndexShared" + | SliceIndexMut -> "@SliceIndexMut" + | SliceSubsliceShared -> "@SliceSubsliceShared" + | SliceSubsliceMut -> "@SliceSubsliceMut" let pure_assumed_fun_id_to_string (fid : pure_assumed_fun_id) : string = match fid with @@ -495,7 +562,7 @@ let rec texpression_to_string (fmt : ast_formatter) (inside : bool) | Var var_id -> let s = fmt.var_id_to_string var_id in if inside then "(" ^ s ^ ")" else s - | Const cv -> Print.PrimitiveValues.primitive_value_to_string cv + | Const cv -> literal_to_string cv | App _ -> (* Recursively destruct the app, to have a pair (app, arguments list) *) let app, args = destruct_apps e in @@ -517,25 +584,39 @@ let rec texpression_to_string (fmt : ast_formatter) (inside : bool) | Loop loop -> let e = loop_to_string fmt indent indent_incr loop in if inside then "(" ^ e ^ ")" else e - | StructUpdate supd -> + | StructUpdate supd -> ( let s = match supd.init with | None -> "" | Some vid -> " " ^ fmt.var_id_to_string vid ^ " with" in - let field_names = Option.get (fmt.adt_field_names supd.struct_id None) in let indent1 = indent ^ indent_incr in let indent2 = indent1 ^ indent_incr in - let fields = - List.map - (fun (fid, fe) -> - let field = FieldId.nth field_names fid in - let fe = texpression_to_string fmt false indent2 indent_incr fe in - "\n" ^ indent1 ^ field ^ " := " ^ fe ^ ";") - supd.updates - in - let bl = if fields = [] then "" else "\n" ^ indent in - "{" ^ s ^ String.concat "" fields ^ bl ^ "}" + (* The id should be a custom type decl id or an array *) + match supd.struct_id with + | AdtId aid -> + let field_names = Option.get (fmt.adt_field_names aid None) in + let fields = + List.map + (fun (fid, fe) -> + let field = FieldId.nth field_names fid in + let fe = + texpression_to_string fmt false indent2 indent_incr fe + in + "\n" ^ indent1 ^ field ^ " := " ^ fe ^ ";") + supd.updates + in + let bl = if fields = [] then "" else "\n" ^ indent in + "{" ^ s ^ String.concat "" fields ^ bl ^ "}" + | Assumed Array -> + let fields = + List.map + (fun (_, fe) -> + texpression_to_string fmt false indent2 indent_incr fe) + supd.updates + in + "[ " ^ String.concat ", " fields ^ " ]" + | _ -> raise (Failure "Unexpected")) | Meta (meta, e) -> ( let meta_s = meta_to_string fmt meta in let e = texpression_to_string fmt inside indent indent_incr e in diff --git a/compiler/Pure.ml b/compiler/Pure.ml index b251a005..ac4ca081 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -32,7 +32,11 @@ IdGen () module VarId = IdGen () +module ConstGenericVarId = T.ConstGenericVarId + type integer_type = T.integer_type [@@deriving show, ord] +type const_generic_var = T.const_generic_var [@@deriving show, ord] +type const_generic = T.const_generic [@@deriving show, ord] (** The assumed types for the pure AST. @@ -50,7 +54,17 @@ type integer_type = T.integer_type [@@deriving show, ord] this state is opaque to Aeneas (the user can define it, or leave it as assumed) *) -type assumed_ty = State | Result | Error | Fuel | Vec | Option +type assumed_ty = + | State + | Result + | Error + | Fuel + | Vec + | Option + | Array + | Slice + | Str + | Range [@@deriving show, ord] (* TODO: we should never directly manipulate [Return] and [Fail], but rather @@ -91,6 +105,29 @@ class ['self] map_type_id_base = method visit_assumed_ty : 'env -> assumed_ty -> assumed_ty = fun _ x -> x end +(** Ancestor for reduce visitor for [ty] *) +class virtual ['self] reduce_type_id_base = + object (self : 'self) + inherit [_] VisitorsRuntime.reduce + + method visit_type_decl_id : 'env -> type_decl_id -> 'a = + fun _ _ -> self#zero + + method visit_assumed_ty : 'env -> assumed_ty -> 'a = fun _ _ -> self#zero + end + +(** Ancestor for mapreduce visitor for [ty] *) +class virtual ['self] mapreduce_type_id_base = + object (self : 'self) + inherit [_] VisitorsRuntime.mapreduce + + method visit_type_decl_id : 'env -> type_decl_id -> type_decl_id * 'a = + fun _ x -> (x, self#zero) + + method visit_assumed_ty : 'env -> assumed_ty -> assumed_ty * 'a = + fun _ x -> (x, self#zero) + end + type type_id = AdtId of type_decl_id | Tuple | Assumed of assumed_ty [@@deriving show, @@ -112,28 +149,66 @@ type type_id = AdtId of type_decl_id | Tuple | Assumed of assumed_ty nude = true (* Don't inherit {!VisitorsRuntime.iter} *); concrete = true; polymorphic = false; + }, + visitors + { + name = "reduce_type_id"; + variety = "reduce"; + ancestors = [ "reduce_type_id_base" ]; + nude = true (* Don't inherit {!VisitorsRuntime.iter} *); + polymorphic = false; + }, + visitors + { + name = "mapreduce_type_id"; + variety = "mapreduce"; + ancestors = [ "mapreduce_type_id_base" ]; + nude = true (* Don't inherit {!VisitorsRuntime.iter} *); + polymorphic = false; }] +type literal_type = T.literal_type [@@deriving show, ord] + (** Ancestor for iter visitor for [ty] *) class ['self] iter_ty_base = object (_self : 'self) inherit [_] iter_type_id + inherit! [_] T.iter_const_generic + inherit! [_] PV.iter_literal_type method visit_type_var_id : 'env -> type_var_id -> unit = fun _ _ -> () - method visit_integer_type : 'env -> integer_type -> unit = fun _ _ -> () end (** Ancestor for map visitor for [ty] *) class ['self] map_ty_base = object (_self : 'self) inherit [_] map_type_id + inherit! [_] T.map_const_generic + inherit! [_] PV.map_literal_type method visit_type_var_id : 'env -> type_var_id -> type_var_id = fun _ x -> x + end - method visit_integer_type : 'env -> integer_type -> integer_type = - fun _ x -> x +(** Ancestor for reduce visitor for [ty] *) +class virtual ['self] reduce_ty_base = + object (self : 'self) + inherit [_] reduce_type_id + inherit! [_] T.reduce_const_generic + inherit! [_] PV.reduce_literal_type + method visit_type_var_id : 'env -> type_var_id -> 'a = fun _ _ -> self#zero + end + +(** Ancestor for mapreduce visitor for [ty] *) +class virtual ['self] mapreduce_ty_base = + object (self : 'self) + inherit [_] mapreduce_type_id + inherit! [_] T.mapreduce_const_generic + inherit! [_] PV.mapreduce_literal_type + + method visit_type_var_id : 'env -> type_var_id -> type_var_id * 'a = + fun _ x -> (x, self#zero) end type ty = - | Adt of type_id * ty list + | Adt of type_id * ty list * const_generic list (** {!Adt} encodes ADTs and tuples and assumed types. TODO: what about the ended regions? (ADTs may be parameterized @@ -142,12 +217,7 @@ type ty = such "partial" ADTs. *) | TypeVar of type_var_id - | Bool - | Char - | Integer of integer_type - | Str - | Array of ty (* TODO: this should be an assumed type?... *) - | Slice of ty (* TODO: this should be an assumed type?... *) + | Literal of literal_type | Arrow of ty * ty [@@deriving show, @@ -165,9 +235,25 @@ type ty = name = "map_ty"; variety = "map"; ancestors = [ "map_ty_base" ]; - nude = true (* Don't inherit {!VisitorsRuntime.iter} *); + nude = true (* Don't inherit {!VisitorsRuntime.map} *); concrete = true; polymorphic = false; + }, + visitors + { + name = "reduce_ty"; + variety = "reduce"; + ancestors = [ "reduce_ty_base" ]; + nude = true (* Don't inherit {!VisitorsRuntime.reduce} *); + polymorphic = false; + }, + visitors + { + name = "mapreduce_ty"; + variety = "mapreduce"; + ancestors = [ "mapreduce_ty_base" ]; + nude = true (* Don't inherit {!VisitorsRuntime.mapreduce} *); + polymorphic = false; }] type field = { field_name : string option; field_ty : ty } [@@deriving show] @@ -182,12 +268,13 @@ type type_decl = { def_id : TypeDeclId.id; name : name; type_params : type_var list; + const_generic_params : const_generic_var list; kind : type_decl_kind; } [@@deriving show] -type scalar_value = V.scalar_value [@@deriving show] -type primitive_value = V.primitive_value [@@deriving show] +type scalar_value = V.scalar_value [@@deriving show, ord] +type literal = V.literal [@@deriving show, ord] (** Because we introduce a lot of temporary variables, the list of variables is not fixed: we thus must carry all its information with the variable @@ -231,68 +318,46 @@ type variant_id = VariantId.id [@@deriving show] (** Ancestor for {!iter_typed_pattern} visitor *) class ['self] iter_typed_pattern_base = object (_self : 'self) - inherit [_] VisitorsRuntime.iter - - method visit_primitive_value : 'env -> primitive_value -> unit = - fun _ _ -> () - + inherit [_] iter_ty method visit_var : 'env -> var -> unit = fun _ _ -> () method visit_mplace : 'env -> mplace -> unit = fun _ _ -> () - method visit_ty : 'env -> ty -> unit = fun _ _ -> () method visit_variant_id : 'env -> variant_id -> unit = fun _ _ -> () end (** Ancestor for {!map_typed_pattern} visitor *) class ['self] map_typed_pattern_base = object (_self : 'self) - inherit [_] VisitorsRuntime.map - - method visit_primitive_value : 'env -> primitive_value -> primitive_value = - fun _ x -> x - + inherit [_] map_ty method visit_var : 'env -> var -> var = fun _ x -> x method visit_mplace : 'env -> mplace -> mplace = fun _ x -> x - method visit_ty : 'env -> ty -> ty = fun _ x -> x method visit_variant_id : 'env -> variant_id -> variant_id = fun _ x -> x end (** Ancestor for {!reduce_typed_pattern} visitor *) class virtual ['self] reduce_typed_pattern_base = object (self : 'self) - inherit [_] VisitorsRuntime.reduce - - method visit_primitive_value : 'env -> primitive_value -> 'a = - fun _ _ -> self#zero - + inherit [_] reduce_ty method visit_var : 'env -> var -> 'a = fun _ _ -> self#zero method visit_mplace : 'env -> mplace -> 'a = fun _ _ -> self#zero - method visit_ty : 'env -> ty -> 'a = fun _ _ -> self#zero method visit_variant_id : 'env -> variant_id -> 'a = fun _ _ -> self#zero end (** Ancestor for {!mapreduce_typed_pattern} visitor *) class virtual ['self] mapreduce_typed_pattern_base = object (self : 'self) - inherit [_] VisitorsRuntime.mapreduce - - method visit_primitive_value - : 'env -> primitive_value -> primitive_value * 'a = - fun _ x -> (x, self#zero) - + inherit [_] mapreduce_ty method visit_var : 'env -> var -> var * 'a = fun _ x -> (x, self#zero) method visit_mplace : 'env -> mplace -> mplace * 'a = fun _ x -> (x, self#zero) - method visit_ty : 'env -> ty -> ty * 'a = fun _ x -> (x, self#zero) - method visit_variant_id : 'env -> variant_id -> variant_id * 'a = fun _ x -> (x, self#zero) end (** A pattern (which appears on the left of assignments, in matches, etc.). *) type pattern = - | PatConstant of primitive_value + | PatConstant of literal (** {!PatConstant} is necessary because we merge the switches over integer values and the matches over enumerations *) | PatVar of var * mplace option @@ -403,7 +468,12 @@ type qualif_id = which explains why we have the [type_params] field: a function or ADT constructor is always fully instantiated. *) -type qualif = { id : qualif_id; type_args : ty list } [@@deriving show] +type qualif = { + id : qualif_id; + type_args : ty list; + const_generic_args : const_generic list; +} +[@@deriving show] type field_id = FieldId.id [@@deriving show, ord] type var_id = VarId.id [@@deriving show, ord] @@ -412,11 +482,10 @@ type var_id = VarId.id [@@deriving show, ord] class ['self] iter_expression_base = object (_self : 'self) inherit [_] iter_typed_pattern - method visit_integer_type : 'env -> integer_type -> unit = fun _ _ -> () + inherit! [_] iter_type_id method visit_var_id : 'env -> var_id -> unit = fun _ _ -> () method visit_qualif : 'env -> qualif -> unit = fun _ _ -> () method visit_loop_id : 'env -> loop_id -> unit = fun _ _ -> () - method visit_type_decl_id : 'env -> type_decl_id -> unit = fun _ _ -> () method visit_field_id : 'env -> field_id -> unit = fun _ _ -> () end @@ -424,17 +493,10 @@ class ['self] iter_expression_base = class ['self] map_expression_base = object (_self : 'self) inherit [_] map_typed_pattern - - method visit_integer_type : 'env -> integer_type -> integer_type = - fun _ x -> x - + inherit! [_] map_type_id method visit_var_id : 'env -> var_id -> var_id = fun _ x -> x method visit_qualif : 'env -> qualif -> qualif = fun _ x -> x method visit_loop_id : 'env -> loop_id -> loop_id = fun _ x -> x - - method visit_type_decl_id : 'env -> type_decl_id -> type_decl_id = - fun _ x -> x - method visit_field_id : 'env -> field_id -> field_id = fun _ x -> x end @@ -442,17 +504,10 @@ class ['self] map_expression_base = class virtual ['self] reduce_expression_base = object (self : 'self) inherit [_] reduce_typed_pattern - - method visit_integer_type : 'env -> integer_type -> 'a = - fun _ _ -> self#zero - + inherit! [_] reduce_type_id method visit_var_id : 'env -> var_id -> 'a = fun _ _ -> self#zero method visit_qualif : 'env -> qualif -> 'a = fun _ _ -> self#zero method visit_loop_id : 'env -> loop_id -> 'a = fun _ _ -> self#zero - - method visit_type_decl_id : 'env -> type_decl_id -> 'a = - fun _ _ -> self#zero - method visit_field_id : 'env -> field_id -> 'a = fun _ _ -> self#zero end @@ -460,9 +515,7 @@ class virtual ['self] reduce_expression_base = class virtual ['self] mapreduce_expression_base = object (self : 'self) inherit [_] mapreduce_typed_pattern - - method visit_integer_type : 'env -> integer_type -> integer_type * 'a = - fun _ x -> (x, self#zero) + inherit! [_] mapreduce_type_id method visit_var_id : 'env -> var_id -> var_id * 'a = fun _ x -> (x, self#zero) @@ -473,9 +526,6 @@ class virtual ['self] mapreduce_expression_base = method visit_loop_id : 'env -> loop_id -> loop_id * 'a = fun _ x -> (x, self#zero) - method visit_type_decl_id : 'env -> type_decl_id -> type_decl_id * 'a = - fun _ x -> (x, self#zero) - method visit_field_id : 'env -> field_id -> field_id * 'a = fun _ x -> (x, self#zero) end @@ -486,7 +536,7 @@ class virtual ['self] mapreduce_expression_base = *) type expression = | Var of var_id (** a variable *) - | Const of primitive_value + | Const of literal | App of texpression * texpression (** Application of a function to an argument. @@ -585,9 +635,21 @@ and loop = { {[ { s with x := 3 } ]} + + We also use struct updates to encode array aggregates, so that whenever + the user writes code like: + {[ + let a : [u32; 2] = [0, 1]; + ... + ]} + this gets encoded to: + {[ + let a : Array u32 2 = Array.mk [0, 1] in + ... + ]} *) and struct_update = { - struct_id : type_decl_id; + struct_id : type_id; init : var_id option; updates : (field_id * texpression) list; } @@ -726,6 +788,7 @@ type fun_sig_info = { *) type fun_sig = { type_params : type_var list; + const_generic_params : const_generic_var list; (** TODO: we should analyse the signature to make the type parameters implicit whenever possible *) inputs : ty list; (** The input types. diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 74f3c576..b6025df4 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -585,6 +585,7 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = { id = AdtCons { adt_id = AdtId adt_id; variant_id = None }; type_args = _; + const_generic_args = _; } -> (* Lookup the def *) let decl = @@ -610,7 +611,7 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = (!Config.backend <> Lean && !Config.backend <> Coq) || not is_rec then - let struct_id = adt_id in + let struct_id = AdtId adt_id in let init = None in let updates = FieldId.mapi @@ -1086,6 +1087,7 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = { id = AdtCons { adt_id = AdtId adt_id; variant_id = None }; type_args; + const_generic_args; } -> (* This is a struct *) (* Retrieve the definiton, to find how many fields there are *) @@ -1106,7 +1108,7 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = * [x.field] for some variable [x], and where the projection * is for the proper ADT *) let to_var_proj (i : int) (arg : texpression) : - (ty list * var_id) option = + (ty list * const_generic list * var_id) option = match arg.e with | App (proj, x) -> ( match (proj.e, x.e) with @@ -1115,13 +1117,15 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = id = Proj { adt_id = AdtId proj_adt_id; field_id }; type_args = proj_type_args; + const_generic_args = proj_const_generic_args; }, Var v ) -> (* We check that this is the proper ADT, and the proper field *) if proj_adt_id = adt_id && FieldId.to_int field_id = i - then Some (proj_type_args, v) + then + Some (proj_type_args, proj_const_generic_args, v) else None | _ -> None) | _ -> None @@ -1132,12 +1136,15 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = if List.length args = num_fields then (* Check that this is the same variable we project from - * note that we checked above that there is at least one field *) - let (_, x), end_args = Collections.List.pop args in - if List.for_all (fun (_, y) -> y = x) end_args then ( + let (_, _, x), end_args = Collections.List.pop args in + if List.for_all (fun (_, _, y) -> y = x) end_args then ( (* We can substitute *) (* Sanity check: all types correct *) assert ( - List.for_all (fun (tys, _) -> tys = type_args) args); + List.for_all + (fun (tys, cgs, _) -> + tys = type_args && cgs = const_generic_args) + args); { e with e = Var x }) else super#visit_texpression env e else super#visit_texpression env e @@ -1156,12 +1163,13 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = { id = Proj { adt_id = AdtId proj_adt_id; field_id }; type_args = _; + const_generic_args = _; }, Var v ) -> (* We check that this is the proper ADT, and the proper field *) if - proj_adt_id = struct_id && field_id = fid - && x.ty = adt_ty + AdtId proj_adt_id = struct_id + && field_id = fid && x.ty = adt_ty then Some v else None | _ -> None) @@ -1354,6 +1362,7 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = let loop_sig = { type_params = fun_sig.type_params; + const_generic_params = fun_sig.const_generic_params; inputs = inputs_tys; output; doutputs; @@ -1554,8 +1563,12 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = | A.BoxFree, _ -> assert (args = []); mk_unit_rvalue - | ( ( A.Replace | A.VecNew | A.VecPush | A.VecInsert | A.VecLen - | A.VecIndex | A.VecIndexMut ), + | ( ( A.Replace | VecNew | VecPush | VecInsert | VecLen + | VecIndex | VecIndexMut | ArraySubsliceShared + | ArraySubsliceMut | SliceIndexShared | SliceIndexMut + | SliceSubsliceShared | SliceSubsliceMut | ArrayIndexShared + | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut + | SliceLen ), _ ) -> super#visit_texpression env e) | _ -> super#visit_texpression env e) @@ -2130,7 +2143,14 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : let num_filtered = List.length (List.filter (fun b -> not b) used_info) in - let { type_params; inputs; output; doutputs; info } = + let { + type_params; + const_generic_params; + inputs; + output; + doutputs; + info; + } = decl.signature in let { @@ -2158,7 +2178,16 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : effect_info; } in - let signature = { type_params; inputs; output; doutputs; info } in + let signature = + { + type_params; + const_generic_params; + inputs; + output; + doutputs; + info; + } + in { decl with signature } in diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index 018ea6b5..8d28bb8a 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -3,10 +3,14 @@ open Pure open PureUtils -(** Utility function, used for type checking *) +(** Utility function, used for type checking. + + This function should only be used for "regular" ADTs, where the number + of fields is fixed: it shouldn't be used for arrays, slices, etc. + *) let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) - (type_id : type_id) (variant_id : VariantId.id option) (tys : ty list) : - ty list = + (type_id : type_id) (variant_id : VariantId.id option) (tys : ty list) + (cgs : const_generic list) : ty list = match type_id with | Tuple -> (* Tuple *) @@ -15,7 +19,7 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) | AdtId def_id -> (* "Regular" ADT *) let def = TypeDeclId.Map.find def_id type_decls in - type_decl_get_instantiated_fields_types def variant_id tys + type_decl_get_instantiated_fields_types def variant_id tys cgs | Assumed aty -> ( (* Assumed type *) match aty with @@ -46,8 +50,15 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) if variant_id = option_some_id then [ ty ] else if variant_id = option_none_id then [] else - raise (Failure "Unreachable: improper variant id for result type") - | Vec -> raise (Failure "Unreachable: `Vector` values are opaque")) + raise (Failure "Unreachable: improper variant id for option type") + | Range -> + let ty = Collections.List.to_cons_nil tys in + assert (variant_id = None); + [ ty; ty ] + | Vec | Array | Slice | Str -> + (* Array: when not symbolic values (for instance, because of aggregates), + the array expressions are introduced as struct updates *) + raise (Failure "Attempting to access the fields of an opaque type")) type tc_ctx = { type_decls : type_decl TypeDeclId.Map.t; (** The type declarations *) @@ -56,17 +67,17 @@ type tc_ctx = { env : ty VarId.Map.t; (** Environment from variables to types *) } -let check_primitive_value (v : primitive_value) (ty : ty) : unit = +let check_literal (v : literal) (ty : literal_type) : unit = match (ty, v) with | Integer int_ty, PV.Scalar sv -> assert (int_ty = sv.PV.int_ty) - | Bool, Bool _ | Char, Char _ | Str, String _ -> () + | Bool, Bool _ | Char, Char _ -> () | _ -> raise (Failure "Inconsistent type") let rec check_typed_pattern (ctx : tc_ctx) (v : typed_pattern) : tc_ctx = log#ldebug (lazy ("check_typed_pattern: " ^ show_typed_pattern v)); match v.value with | PatConstant cv -> - check_primitive_value cv v.ty; + check_literal cv (ty_as_literal v.ty); ctx | PatDummy -> ctx | PatVar (var, _) -> @@ -75,13 +86,9 @@ let rec check_typed_pattern (ctx : tc_ctx) (v : typed_pattern) : tc_ctx = { ctx with env } | PatAdt av -> (* Compute the field types *) - let type_id, tys = - match v.ty with - | Adt (type_id, tys) -> (type_id, tys) - | _ -> raise (Failure "Inconsistently typed value") - in + let type_id, tys, cgs = ty_as_adt v.ty in let field_tys = - get_adt_field_types ctx.type_decls type_id av.variant_id tys + get_adt_field_types ctx.type_decls type_id av.variant_id tys cgs in let check_value (ctx : tc_ctx) (ty : ty) (v : typed_pattern) : tc_ctx = if ty <> v.ty then ( @@ -108,7 +115,7 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = match VarId.Map.find_opt var_id ctx.env with | None -> () | Some ty -> assert (ty = e.ty)) - | Const cv -> check_primitive_value cv e.ty + | Const cv -> check_literal cv (ty_as_literal e.ty) | App (app, arg) -> let input_ty, output_ty = destruct_arrow app.ty in assert (input_ty = arg.ty); @@ -130,33 +137,31 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = (* Note we can only project fields of structures (not enumerations) *) (* Deconstruct the projector type *) let adt_ty, field_ty = destruct_arrow e.ty in - let adt_id, adt_type_args = - match adt_ty with - | Adt (type_id, tys) -> (type_id, tys) - | _ -> raise (Failure "Unreachable") - in + let adt_id, adt_type_args, adt_cg_args = ty_as_adt adt_ty in (* Check the ADT type *) assert (adt_id = proj_adt_id); assert (adt_type_args = qualif.type_args); + assert (adt_cg_args = qualif.const_generic_args); (* Retrieve and check the expected field type *) let variant_id = None in let expected_field_tys = get_adt_field_types ctx.type_decls proj_adt_id variant_id - qualif.type_args + qualif.type_args qualif.const_generic_args in let expected_field_ty = FieldId.nth expected_field_tys field_id in assert (expected_field_ty = field_ty) | AdtCons id -> ( let expected_field_tys = get_adt_field_types ctx.type_decls id.adt_id id.variant_id - qualif.type_args + qualif.type_args qualif.const_generic_args in let field_tys, adt_ty = destruct_arrows e.ty in assert (expected_field_tys = field_tys); match adt_ty with - | Adt (type_id, tys) -> + | Adt (type_id, tys, cgs) -> assert (type_id = id.adt_id); - assert (tys = qualif.type_args) + assert (tys = qualif.type_args); + assert (cgs = qualif.const_generic_args) | _ -> raise (Failure "Unreachable"))) | Let (monadic, pat, re, e_next) -> let expected_pat_ty = if monadic then destruct_result re.ty else re.ty in @@ -172,7 +177,7 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = check_texpression ctx scrut; match switch_body with | If (e_then, e_else) -> - assert (scrut.ty = Bool); + assert (scrut.ty = Literal Bool); assert (e_then.ty = e.ty); assert (e_else.ty = e.ty); check_texpression ctx e_then; @@ -194,30 +199,38 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = assert (Option.is_some loop.back_output_tys || loop.loop_body.ty = e.ty); check_texpression ctx loop.fun_end; check_texpression ctx loop.loop_body - | StructUpdate supd -> + | StructUpdate supd -> ( (* Check the init value *) (if Option.is_some supd.init then - match VarId.Map.find_opt (Option.get supd.init) ctx.env with - | None -> () - | Some ty -> assert (ty = e.ty)); + match VarId.Map.find_opt (Option.get supd.init) ctx.env with + | None -> () + | Some ty -> assert (ty = e.ty)); (* Check the fields *) (* Retrieve and check the expected field type *) - let adt_id, adt_type_args = - match e.ty with - | Adt (type_id, tys) -> (type_id, tys) - | _ -> raise (Failure "Unreachable") - in - assert (adt_id = AdtId supd.struct_id); - let variant_id = None in - let expected_field_tys = - get_adt_field_types ctx.type_decls adt_id variant_id adt_type_args - in - List.iter - (fun (fid, fe) -> - let expected_field_ty = FieldId.nth expected_field_tys fid in - assert (expected_field_ty = fe.ty); - check_texpression ctx fe) - supd.updates + let adt_id, adt_type_args, adt_cg_args = ty_as_adt e.ty in + assert (adt_id = supd.struct_id); + (* The id can only be: a custom type decl or an array *) + match adt_id with + | AdtId _ -> + let variant_id = None in + let expected_field_tys = + get_adt_field_types ctx.type_decls adt_id variant_id adt_type_args + adt_cg_args + in + List.iter + (fun (fid, fe) -> + let expected_field_ty = FieldId.nth expected_field_tys fid in + assert (expected_field_ty = fe.ty); + check_texpression ctx fe) + supd.updates + | Assumed Array -> + let expected_field_ty = Collections.List.to_cons_nil adt_type_args in + List.iter + (fun (_, fe) -> + assert (expected_field_ty = fe.ty); + check_texpression ctx fe) + supd.updates + | _ -> raise (Failure "Unexpected")) | Meta (_, e_next) -> assert (e_next.ty = e.ty); check_texpression ctx e_next diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 647678c1..1c8d8921 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -62,18 +62,16 @@ let dest_arrow_ty (ty : ty) : ty * ty = | Arrow (arg_ty, ret_ty) -> (arg_ty, ret_ty) | _ -> raise (Failure "Unreachable") -let compute_primitive_value_ty (cv : primitive_value) : ty = +let compute_literal_type (cv : literal) : literal_type = match cv with | PV.Scalar sv -> Integer sv.PV.int_ty | Bool _ -> Bool | Char _ -> Char - | String _ -> Str let var_get_id (v : var) : VarId.id = v.id -let mk_typed_pattern_from_primitive_value (cv : primitive_value) : typed_pattern - = - let ty = compute_primitive_value_ty cv in +let mk_typed_pattern_from_literal (cv : literal) : typed_pattern = + let ty = Literal (compute_literal_type cv) in { value = PatConstant cv; ty } let mk_let (monadic : bool) (lv : typed_pattern) (re : texpression) @@ -92,11 +90,13 @@ let mk_mplace (var_id : E.VarId.id) (name : string option) { var_id; name; projection } (** Type substitution *) -let ty_substitute (tsubst : TypeVarId.id -> ty) (ty : ty) : ty = +let ty_substitute (tsubst : TypeVarId.id -> ty) + (cgsubst : ConstGenericVarId.id -> const_generic) (ty : ty) : ty = let obj = object inherit [_] map_ty method! visit_TypeVar _ var_id = tsubst var_id + method! visit_ConstGenericVar _ var_id = cgsubst var_id end in obj#visit_ty () ty @@ -111,6 +111,10 @@ let make_type_subst (vars : type_var list) (tys : ty list) : TypeVarId.id -> ty in fun id -> TypeVarId.Map.find id mp +let make_const_generic_subst (vars : const_generic_var list) + (cgs : const_generic list) : ConstGenericVarId.id -> const_generic = + Substitute.make_const_generic_subst_from_vars vars cgs + (** Retrieve the list of fields for the given variant of a {!type:Aeneas.Pure.type_decl}. Raises [Invalid_argument] if the arguments are incorrect. @@ -134,14 +138,17 @@ let type_decl_get_fields (def : type_decl) (** Instantiate the type variables for the chosen variant in an ADT definition, and return the list of the types of its fields *) let type_decl_get_instantiated_fields_types (def : type_decl) - (opt_variant_id : VariantId.id option) (types : ty list) : ty list = + (opt_variant_id : VariantId.id option) (types : ty list) + (cgs : const_generic list) : ty list = let ty_subst = make_type_subst def.type_params types in + let cg_subst = make_const_generic_subst def.const_generic_params cgs in let fields = type_decl_get_fields def opt_variant_id in - List.map (fun f -> ty_substitute ty_subst f.field_ty) fields + List.map (fun f -> ty_substitute ty_subst cg_subst f.field_ty) fields -let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) : +let fun_sig_substitute (tsubst : TypeVarId.id -> ty) + (cgsubst : ConstGenericVarId.id -> const_generic) (sg : fun_sig) : inst_fun_sig = - let subst = ty_substitute tsubst in + let subst = ty_substitute tsubst cgsubst in let inputs = List.map subst sg.inputs in let output = subst sg.output in let doutputs = List.map subst sg.doutputs in @@ -183,9 +190,9 @@ let is_global (e : texpression) : bool = let is_const (e : texpression) : bool = match e.e with Const _ -> true | _ -> false -let ty_as_adt (ty : ty) : type_id * ty list = +let ty_as_adt (ty : ty) : type_id * ty list * const_generic list = match ty with - | Adt (id, tys) -> (id, tys) + | Adt (id, tys, cgs) -> (id, tys, cgs) | _ -> raise (Failure "Unreachable") (** Remove the external occurrences of {!Meta} *) @@ -293,13 +300,19 @@ let opt_destruct_function_call (e : texpression) : let opt_destruct_result (ty : ty) : ty option = match ty with - | Adt (Assumed Result, tys) -> Some (Collections.List.to_cons_nil tys) + | Adt (Assumed Result, tys, cgs) -> + assert (cgs = []); + Some (Collections.List.to_cons_nil tys) | _ -> None let destruct_result (ty : ty) : ty = Option.get (opt_destruct_result ty) let opt_destruct_tuple (ty : ty) : ty list option = - match ty with Adt (Tuple, tys) -> Some tys | _ -> None + match ty with + | Adt (Tuple, tys, cgs) -> + assert (cgs = []); + Some tys + | _ -> None let mk_abs (x : typed_pattern) (e : texpression) : texpression = let ty = Arrow (x.ty, e.ty) in @@ -353,7 +366,7 @@ let iter_switch_body_branches (f : texpression -> unit) (sb : switch_body) : let mk_switch (scrut : texpression) (sb : switch_body) : texpression = (* Sanity check: the scrutinee has the proper type *) (match sb with - | If (_, _) -> assert (scrut.ty = Bool) + | If (_, _) -> assert (scrut.ty = Literal Bool) | Match branches -> List.iter (fun (b : match_branch) -> assert (b.pat.ty = scrut.ty)) @@ -370,14 +383,14 @@ let mk_switch (scrut : texpression) (sb : switch_body) : texpression = - if there is > one type: wrap them in a tuple *) let mk_simpl_tuple_ty (tys : ty list) : ty = - match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys) + match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys, []) -let mk_bool_ty : ty = Bool -let mk_unit_ty : ty = Adt (Tuple, []) +let mk_bool_ty : ty = Literal Bool +let mk_unit_ty : ty = Adt (Tuple, [], []) let mk_unit_rvalue : texpression = let id = AdtCons { adt_id = Tuple; variant_id = None } in - let qualif = { id; type_args = [] } in + let qualif = { id; type_args = []; const_generic_args = [] } in let e = Qualif qualif in let ty = mk_unit_ty in { e; ty } @@ -417,7 +430,7 @@ let mk_simpl_tuple_pattern (vl : typed_pattern list) : typed_pattern = | [ v ] -> v | _ -> let tys = List.map (fun (v : typed_pattern) -> v.ty) vl in - let ty = Adt (Tuple, tys) in + let ty = Adt (Tuple, tys, []) in let value = PatAdt { variant_id = None; field_values = vl } in { value; ty } @@ -428,11 +441,11 @@ let mk_simpl_tuple_texpression (vl : texpression list) : texpression = | _ -> (* Compute the types of the fields, and the type of the tuple constructor *) let tys = List.map (fun (v : texpression) -> v.ty) vl in - let ty = Adt (Tuple, tys) in + let ty = Adt (Tuple, tys, []) in let ty = mk_arrows tys ty in (* Construct the tuple constructor qualifier *) let id = AdtCons { adt_id = Tuple; variant_id = None } in - let qualif = { id; type_args = tys } in + let qualif = { id; type_args = tys; const_generic_args = [] } in (* Put everything together *) let cons = { e = Qualif qualif; ty } in mk_apps cons vl @@ -443,36 +456,39 @@ let mk_adt_pattern (adt_ty : ty) (variant_id : VariantId.id option) { value; ty = adt_ty } let ty_as_integer (t : ty) : T.integer_type = - match t with Integer int_ty -> int_ty | _ -> raise (Failure "Unreachable") + match t with + | Literal (Integer int_ty) -> int_ty + | _ -> raise (Failure "Unreachable") -(* TODO: move *) -let type_decl_is_enum (def : T.type_decl) : bool = - match def.kind with T.Struct _ -> false | Enum _ -> true | Opaque -> false +let ty_as_literal (t : ty) : T.literal_type = + match t with Literal ty -> ty | _ -> raise (Failure "Unreachable") -let mk_state_ty : ty = Adt (Assumed State, []) -let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ]) -let mk_error_ty : ty = Adt (Assumed Error, []) -let mk_fuel_ty : ty = Adt (Assumed Fuel, []) +let mk_state_ty : ty = Adt (Assumed State, [], []) +let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ], []) +let mk_error_ty : ty = Adt (Assumed Error, [], []) +let mk_fuel_ty : ty = Adt (Assumed Fuel, [], []) let mk_error (error : VariantId.id) : texpression = let ty = mk_error_ty in let id = AdtCons { adt_id = Assumed Error; variant_id = Some error } in - let qualif = { id; type_args = [] } in + let qualif = { id; type_args = []; const_generic_args = [] } in let e = Qualif qualif in { e; ty } let unwrap_result_ty (ty : ty) : ty = match ty with - | Adt (Assumed Result, [ ty ]) -> ty + | Adt (Assumed Result, [ ty ], cgs) -> + assert (cgs = []); + ty | _ -> raise (Failure "not a result type") let mk_result_fail_texpression (error : texpression) (ty : ty) : texpression = let type_args = [ ty ] in - let ty = Adt (Assumed Result, type_args) in + let ty = Adt (Assumed Result, type_args, []) in let id = AdtCons { adt_id = Assumed Result; variant_id = Some result_fail_id } in - let qualif = { id; type_args } in + let qualif = { id; type_args; const_generic_args = [] } in let cons_e = Qualif qualif in let cons_ty = mk_arrow error.ty ty in let cons = { e = cons_e; ty = cons_ty } in @@ -485,11 +501,11 @@ let mk_result_fail_texpression_with_error_id (error : VariantId.id) (ty : ty) : let mk_result_return_texpression (v : texpression) : texpression = let type_args = [ v.ty ] in - let ty = Adt (Assumed Result, type_args) in + let ty = Adt (Assumed Result, type_args, []) in let id = AdtCons { adt_id = Assumed Result; variant_id = Some result_return_id } in - let qualif = { id; type_args } in + let qualif = { id; type_args; const_generic_args = [] } in let cons_e = Qualif qualif in let cons_ty = mk_arrow v.ty ty in let cons = { e = cons_e; ty = cons_ty } in @@ -498,7 +514,7 @@ let mk_result_return_texpression (v : texpression) : texpression = (** Create a [Fail err] pattern which captures the error *) let mk_result_fail_pattern (error_pat : pattern) (ty : ty) : typed_pattern = let error_pat : typed_pattern = { value = error_pat; ty = mk_error_ty } in - let ty = Adt (Assumed Result, [ ty ]) in + let ty = Adt (Assumed Result, [ ty ], []) in let value = PatAdt { variant_id = Some result_fail_id; field_values = [ error_pat ] } in @@ -510,7 +526,7 @@ let mk_result_fail_pattern_ignore_error (ty : ty) : typed_pattern = mk_result_fail_pattern error_pat ty let mk_result_return_pattern (v : typed_pattern) : typed_pattern = - let ty = Adt (Assumed Result, [ v.ty ]) in + let ty = Adt (Assumed Result, [ v.ty ], []) in let value = PatAdt { variant_id = Some result_return_id; field_values = [ v ] } in @@ -545,11 +561,11 @@ let rec typed_pattern_to_texpression (pat : typed_pattern) : texpression option let fields_values = List.map (fun e -> Option.get e) fields in (* Retrieve the type id and the type args from the pat type (simpler this way *) - let adt_id, type_args = ty_as_adt pat.ty in + let adt_id, type_args, const_generic_args = ty_as_adt pat.ty in (* Create the constructor *) let qualif_id = AdtCons { adt_id; variant_id = av.variant_id } in - let qualif = { id = qualif_id; type_args } in + let qualif = { id = qualif_id; type_args; const_generic_args } in let cons_e = Qualif qualif in let field_tys = List.map (fun (v : texpression) -> v.ty) fields_values diff --git a/compiler/Substitute.ml b/compiler/Substitute.ml index 5040fd9f..38850243 100644 --- a/compiler/Substitute.ml +++ b/compiler/Substitute.ml @@ -11,7 +11,8 @@ module C = Contexts (** Substitute types variables and regions in a type. *) let ty_substitute (rsubst : 'r1 -> 'r2) (tsubst : T.TypeVarId.id -> 'r2 T.ty) - (ty : 'r1 T.ty) : 'r2 T.ty = + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : 'r1 T.ty) : + 'r2 T.ty = let open T in let visitor = object @@ -22,25 +23,37 @@ let ty_substitute (rsubst : 'r1 -> 'r2) (tsubst : T.TypeVarId.id -> 'r2 T.ty) method! visit_type_var_id _ _ = (* We should never get here because we reimplemented [visit_TypeVar] *) raise (Failure "Unexpected") + + method! visit_ConstGenericVar _ id = cgsubst id + + method! visit_const_generic_var_id _ _ = + (* We should never get here because we reimplemented [visit_Var] *) + raise (Failure "Unexpected") end in visitor#visit_ty () ty let rty_substitute (rsubst : T.RegionId.id -> T.RegionId.id) - (tsubst : T.TypeVarId.id -> T.rty) (ty : T.rty) : T.rty = + (tsubst : T.TypeVarId.id -> T.rty) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : T.rty) : T.rty = let rsubst r = match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid) in - ty_substitute rsubst tsubst ty + ty_substitute rsubst tsubst cgsubst ty -let ety_substitute (tsubst : T.TypeVarId.id -> T.ety) (ty : T.ety) : T.ety = +let ety_substitute (tsubst : T.TypeVarId.id -> T.ety) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : T.ety) : T.ety = let rsubst r = r in - ty_substitute rsubst tsubst ty + ty_substitute rsubst tsubst cgsubst ty (** Convert an {!T.rty} to an {!T.ety} by erasing the region variables *) let erase_regions (ty : T.rty) : T.ety = - ty_substitute (fun _ -> T.Erased) (fun vid -> T.TypeVar vid) ty + ty_substitute + (fun _ -> T.Erased) + (fun vid -> T.TypeVar vid) + (fun id -> T.ConstGenericVar id) + ty (** Generate fresh regions for region variables. @@ -59,7 +72,7 @@ let fresh_regions_with_substs (region_vars : T.region_var list) : let ls = List.combine region_vars fresh_region_ids in let rid_map = List.fold_left - (fun mp (k, v) -> T.RegionVarId.Map.add k.T.index v mp) + (fun mp ((k : T.region_var), v) -> T.RegionVarId.Map.add k.T.index v mp) T.RegionVarId.Map.empty ls in (* Generate the substitution from region var id to region *) @@ -73,9 +86,10 @@ let fresh_regions_with_substs (region_vars : T.region_var list) : (** Erase the regions in a type and substitute the type variables *) let erase_regions_substitute_types (tsubst : T.TypeVarId.id -> T.ety) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : 'r T.region T.ty) : T.ety = let rsubst (_ : 'r T.region) : T.erased_region = T.Erased in - ty_substitute rsubst tsubst ty + ty_substitute rsubst tsubst cgsubst ty (** Create a region substitution from a list of region variable ids and a list of regions (with which to substitute the region variable ids *) @@ -92,6 +106,12 @@ let make_region_subst (var_ids : T.RegionVarId.id list) | T.Static -> T.Static | T.Var id -> T.RegionVarId.Map.find id mp +let make_region_subst_from_vars (vars : T.region_var list) + (regions : 'r T.region list) : T.RegionVarId.id T.region -> 'r T.region = + make_region_subst + (List.map (fun (x : T.region_var) -> x.T.index) vars) + regions + (** Create a type substitution from a list of type variable ids and a list of types (with which to substitute the type variable ids) *) let make_type_subst (var_ids : T.TypeVarId.id list) (tys : 'r T.ty list) : @@ -104,18 +124,37 @@ let make_type_subst (var_ids : T.TypeVarId.id list) (tys : 'r T.ty list) : in fun id -> T.TypeVarId.Map.find id mp +let make_type_subst_from_vars (vars : T.type_var list) (tys : 'r T.ty list) : + T.TypeVarId.id -> 'r T.ty = + make_type_subst (List.map (fun (x : T.type_var) -> x.T.index) vars) tys + +(** Create a const generic substitution from a list of const generic variable ids and a list of + const generics (with which to substitute the const generic variable ids) *) +let make_const_generic_subst (var_ids : T.ConstGenericVarId.id list) + (cgs : T.const_generic list) : T.ConstGenericVarId.id -> T.const_generic = + let ls = List.combine var_ids cgs in + let mp = + List.fold_left + (fun mp (k, v) -> T.ConstGenericVarId.Map.add k v mp) + T.ConstGenericVarId.Map.empty ls + in + fun id -> T.ConstGenericVarId.Map.find id mp + +let make_const_generic_subst_from_vars (vars : T.const_generic_var list) + (cgs : T.const_generic list) : T.ConstGenericVarId.id -> T.const_generic = + make_const_generic_subst + (List.map (fun (x : T.const_generic_var) -> x.T.index) vars) + cgs + (** Instantiate the type variables in an ADT definition, and return, for every variant, the list of the types of its fields *) let type_decl_get_instantiated_variants_fields_rtypes (def : T.type_decl) - (regions : T.RegionId.id T.region list) (types : T.rty list) : - (T.VariantId.id option * T.rty list) list = - let r_subst = - make_region_subst - (List.map (fun x -> x.T.index) def.T.region_params) - regions - in - let ty_subst = - make_type_subst (List.map (fun x -> x.T.index) def.T.type_params) types + (regions : T.RegionId.id T.region list) (types : T.rty list) + (cgs : T.const_generic list) : (T.VariantId.id option * T.rty list) list = + let r_subst = make_region_subst_from_vars def.T.region_params regions in + let ty_subst = make_type_subst_from_vars def.T.type_params types in + let cg_subst = + make_const_generic_subst_from_vars def.T.const_generic_params cgs in let (variants_fields : (T.VariantId.id option * T.field list) list) = match def.T.kind with @@ -133,45 +172,47 @@ let type_decl_get_instantiated_variants_fields_rtypes (def : T.type_decl) List.map (fun (id, fields) -> ( id, - List.map (fun f -> ty_substitute r_subst ty_subst f.T.field_ty) fields - )) + List.map + (fun f -> ty_substitute r_subst ty_subst cg_subst f.T.field_ty) + fields )) variants_fields (** Instantiate the type variables in an ADT definition, and return the list of types of the fields for the chosen variant *) let type_decl_get_instantiated_field_rtypes (def : T.type_decl) (opt_variant_id : T.VariantId.id option) - (regions : T.RegionId.id T.region list) (types : T.rty list) : T.rty list = - let r_subst = - make_region_subst - (List.map (fun x -> x.T.index) def.T.region_params) - regions - in - let ty_subst = - make_type_subst (List.map (fun x -> x.T.index) def.T.type_params) types + (regions : T.RegionId.id T.region list) (types : T.rty list) + (cgs : T.const_generic list) : T.rty list = + let r_subst = make_region_subst_from_vars def.T.region_params regions in + let ty_subst = make_type_subst_from_vars def.T.type_params types in + let cg_subst = + make_const_generic_subst_from_vars def.T.const_generic_params cgs in let fields = TU.type_decl_get_fields def opt_variant_id in - List.map (fun f -> ty_substitute r_subst ty_subst f.T.field_ty) fields + List.map + (fun f -> ty_substitute r_subst ty_subst cg_subst f.T.field_ty) + fields (** Return the types of the properly instantiated ADT's variant, provided a context *) let ctx_adt_get_instantiated_field_rtypes (ctx : C.eval_ctx) (def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option) - (regions : T.RegionId.id T.region list) (types : T.rty list) : T.rty list = + (regions : T.RegionId.id T.region list) (types : T.rty list) + (cgs : T.const_generic list) : T.rty list = let def = C.ctx_lookup_type_decl ctx def_id in - type_decl_get_instantiated_field_rtypes def opt_variant_id regions types + type_decl_get_instantiated_field_rtypes def opt_variant_id regions types cgs (** Return the types of the properly instantiated ADT value (note that here, ADT is understood in its broad meaning: ADT, assumed value or tuple) *) let ctx_adt_value_get_instantiated_field_rtypes (ctx : C.eval_ctx) (adt : V.adt_value) (id : T.type_id) - (region_params : T.RegionId.id T.region list) (type_params : T.rty list) : - T.rty list = + (region_params : T.RegionId.id T.region list) (type_params : T.rty list) + (cg_params : T.const_generic list) : T.rty list = match id with | T.AdtId id -> (* Retrieve the types of the fields *) ctx_adt_get_instantiated_field_rtypes ctx id adt.V.variant_id - region_params type_params + region_params type_params cg_params | T.Tuple -> assert (List.length region_params = 0); type_params @@ -180,182 +221,121 @@ let ctx_adt_value_get_instantiated_field_rtypes (ctx : C.eval_ctx) | T.Box | T.Vec -> assert (List.length region_params = 0); assert (List.length type_params = 1); + assert (List.length cg_params = 0); type_params | T.Option -> assert (List.length region_params = 0); assert (List.length type_params = 1); + assert (List.length cg_params = 0); if adt.V.variant_id = Some T.option_some_id then type_params else if adt.V.variant_id = Some T.option_none_id then [] - else raise (Failure "Unreachable")) + else raise (Failure "Unreachable") + | T.Range -> + assert (List.length region_params = 0); + assert (List.length type_params = 1); + assert (List.length cg_params = 0); + type_params + | T.Array | T.Slice | T.Str -> + (* Those types don't have fields *) + raise (Failure "Unreachable")) (** Instantiate the type variables in an ADT definition, and return the list of types of the fields for the chosen variant *) let type_decl_get_instantiated_field_etypes (def : T.type_decl) - (opt_variant_id : T.VariantId.id option) (types : T.ety list) : T.ety list = - let ty_subst = - make_type_subst (List.map (fun x -> x.T.index) def.T.type_params) types + (opt_variant_id : T.VariantId.id option) (types : T.ety list) + (cgs : T.const_generic list) : T.ety list = + let ty_subst = make_type_subst_from_vars def.T.type_params types in + let cg_subst = + make_const_generic_subst_from_vars def.T.const_generic_params cgs in let fields = TU.type_decl_get_fields def opt_variant_id in List.map - (fun f -> erase_regions_substitute_types ty_subst f.T.field_ty) + (fun f -> erase_regions_substitute_types ty_subst cg_subst f.T.field_ty) fields (** Return the types of the properly instantiated ADT's variant, provided a context *) let ctx_adt_get_instantiated_field_etypes (ctx : C.eval_ctx) (def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option) - (types : T.ety list) : T.ety list = + (types : T.ety list) (cgs : T.const_generic list) : T.ety list = let def = C.ctx_lookup_type_decl ctx def_id in - type_decl_get_instantiated_field_etypes def opt_variant_id types + type_decl_get_instantiated_field_etypes def opt_variant_id types cgs + +let statement_substitute_visitor (tsubst : T.TypeVarId.id -> T.ety) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) = + object + inherit [_] A.map_statement + method! visit_ety _ ty = ety_substitute tsubst cgsubst ty + method! visit_ConstGenericVar _ id = cgsubst id + + method! visit_const_generic_var_id _ _ = + (* We should never get here because we reimplemented [visit_Var] *) + raise (Failure "Unexpected") + end (** Apply a type substitution to a place *) -let place_substitute (_tsubst : T.TypeVarId.id -> T.ety) (p : E.place) : E.place - = - (* There is nothing to do *) - p +let place_substitute (tsubst : T.TypeVarId.id -> T.ety) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (p : E.place) : + E.place = + (* There is in fact nothing to do *) + (statement_substitute_visitor tsubst cgsubst)#visit_place () p (** Apply a type substitution to an operand *) -let operand_substitute (tsubst : T.TypeVarId.id -> T.ety) (op : E.operand) : +let operand_substitute (tsubst : T.TypeVarId.id -> T.ety) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (op : E.operand) : E.operand = - let p_subst = place_substitute tsubst in - match op with - | E.Copy p -> E.Copy (p_subst p) - | E.Move p -> E.Move (p_subst p) - | E.Constant (ety, cv) -> - let rsubst x = x in - E.Constant (ty_substitute rsubst tsubst ety, cv) + (statement_substitute_visitor tsubst cgsubst)#visit_operand () op (** Apply a type substitution to an rvalue *) -let rvalue_substitute (tsubst : T.TypeVarId.id -> T.ety) (rv : E.rvalue) : +let rvalue_substitute (tsubst : T.TypeVarId.id -> T.ety) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (rv : E.rvalue) : E.rvalue = - let op_subst = operand_substitute tsubst in - let p_subst = place_substitute tsubst in - match rv with - | E.Use op -> E.Use (op_subst op) - | E.Ref (p, bkind) -> E.Ref (p_subst p, bkind) - | E.UnaryOp (unop, op) -> E.UnaryOp (unop, op_subst op) - | E.BinaryOp (binop, op1, op2) -> - E.BinaryOp (binop, op_subst op1, op_subst op2) - | E.Discriminant p -> E.Discriminant (p_subst p) - | E.Global _ -> (* Globals don't have type parameters *) rv - | E.Aggregate (kind, ops) -> - let ops = List.map op_subst ops in - let kind = - match kind with - | E.AggregatedTuple -> E.AggregatedTuple - | E.AggregatedOption (variant_id, ty) -> - let rsubst r = r in - E.AggregatedOption (variant_id, ty_substitute rsubst tsubst ty) - | E.AggregatedAdt (def_id, variant_id, regions, tys) -> - let rsubst r = r in - E.AggregatedAdt - ( def_id, - variant_id, - regions, - List.map (ty_substitute rsubst tsubst) tys ) - in - E.Aggregate (kind, ops) + (statement_substitute_visitor tsubst cgsubst)#visit_rvalue () rv (** Apply a type substitution to an assertion *) -let assertion_substitute (tsubst : T.TypeVarId.id -> T.ety) (a : A.assertion) : +let assertion_substitute (tsubst : T.TypeVarId.id -> T.ety) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (a : A.assertion) : A.assertion = - { A.cond = operand_substitute tsubst a.A.cond; A.expected = a.A.expected } + (statement_substitute_visitor tsubst cgsubst)#visit_assertion () a (** Apply a type substitution to a call *) -let call_substitute (tsubst : T.TypeVarId.id -> T.ety) (call : A.call) : A.call - = - let rsubst x = x in - let type_args = List.map (ty_substitute rsubst tsubst) call.A.type_args in - let args = List.map (operand_substitute tsubst) call.A.args in - let dest = place_substitute tsubst call.A.dest in - (* Putting all the paramters on purpose: we want to get a compiler error if - something moves - we may add a field on which we need to apply a substitution *) - { - func = call.A.func; - region_args = call.A.region_args; - A.type_args; - args; - dest; - } - -(** Apply a type substitution to a statement - TODO: use a map iterator *) -let rec statement_substitute (tsubst : T.TypeVarId.id -> T.ety) - (st : A.statement) : A.statement = - { st with A.content = raw_statement_substitute tsubst st.content } - -and raw_statement_substitute (tsubst : T.TypeVarId.id -> T.ety) - (st : A.raw_statement) : A.raw_statement = - match st with - | A.Assign (p, rvalue) -> - let p = place_substitute tsubst p in - let rvalue = rvalue_substitute tsubst rvalue in - A.Assign (p, rvalue) - | A.FakeRead p -> - let p = place_substitute tsubst p in - A.FakeRead p - | A.SetDiscriminant (p, vid) -> - let p = place_substitute tsubst p in - A.SetDiscriminant (p, vid) - | A.Drop p -> - let p = place_substitute tsubst p in - A.Drop p - | A.Assert assertion -> - let assertion = assertion_substitute tsubst assertion in - A.Assert assertion - | A.Call call -> - let call = call_substitute tsubst call in - A.Call call - | A.Panic | A.Return | A.Break _ | A.Continue _ | A.Nop -> st - | A.Sequence (st1, st2) -> - A.Sequence - (statement_substitute tsubst st1, statement_substitute tsubst st2) - | A.Switch switch -> A.Switch (switch_substitute tsubst switch) - | A.Loop le -> A.Loop (statement_substitute tsubst le) - -(** Apply a type substitution to a switch *) -and switch_substitute (tsubst : T.TypeVarId.id -> T.ety) (switch : A.switch) : - A.switch = - match switch with - | A.If (op, st1, st2) -> - A.If - ( operand_substitute tsubst op, - statement_substitute tsubst st1, - statement_substitute tsubst st2 ) - | A.SwitchInt (op, int_ty, tgts, otherwise) -> - let tgts = - List.map (fun (sv, st) -> (sv, statement_substitute tsubst st)) tgts - in - let otherwise = statement_substitute tsubst otherwise in - A.SwitchInt (operand_substitute tsubst op, int_ty, tgts, otherwise) - | A.Match (p, tgts, otherwise) -> - let tgts = - List.map (fun (sv, st) -> (sv, statement_substitute tsubst st)) tgts - in - let otherwise = statement_substitute tsubst otherwise in - A.Match (place_substitute tsubst p, tgts, otherwise) +let call_substitute (tsubst : T.TypeVarId.id -> T.ety) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (call : A.call) : + A.call = + (statement_substitute_visitor tsubst cgsubst)#visit_call () call + +(** Apply a type substitution to a statement *) +let statement_substitute (tsubst : T.TypeVarId.id -> T.ety) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (st : A.statement) : + A.statement = + (statement_substitute_visitor tsubst cgsubst)#visit_statement () st (** Apply a type substitution to a function body. Return the local variables and the body. *) let fun_body_substitute_in_body (tsubst : T.TypeVarId.id -> T.ety) - (body : A.fun_body) : A.var list * A.statement = + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (body : A.fun_body) : + A.var list * A.statement = let rsubst r = r in let locals = List.map - (fun v -> { v with A.var_ty = ty_substitute rsubst tsubst v.A.var_ty }) + (fun (v : A.var) -> + { v with A.var_ty = ty_substitute rsubst tsubst cgsubst v.A.var_ty }) body.A.locals in - let body = statement_substitute tsubst body.body in + let body = statement_substitute tsubst cgsubst body.body in (locals, body) (** Substitute a function signature *) let substitute_signature (asubst : T.RegionGroupId.id -> V.AbstractionId.id) (rsubst : T.RegionVarId.id -> T.RegionId.id) - (tsubst : T.TypeVarId.id -> T.rty) (sg : A.fun_sig) : A.inst_fun_sig = + (tsubst : T.TypeVarId.id -> T.rty) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (sg : A.fun_sig) : + A.inst_fun_sig = let rsubst' (r : T.RegionVarId.id T.region) : T.RegionId.id T.region = match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid) in - let inputs = List.map (ty_substitute rsubst' tsubst) sg.A.inputs in - let output = ty_substitute rsubst' tsubst sg.A.output in + let inputs = List.map (ty_substitute rsubst' tsubst cgsubst) sg.A.inputs in + let output = ty_substitute rsubst' tsubst cgsubst sg.A.output in let subst_region_group (rg : T.region_var_group) : A.abs_region_group = let id = asubst rg.id in let regions = List.map rsubst rg.regions in @@ -366,7 +346,8 @@ let substitute_signature (asubst : T.RegionGroupId.id -> V.AbstractionId.id) { A.regions_hierarchy; inputs; output } (** Substitute type variable identifiers in a type *) -let ty_substitute_ids (tsubst : T.TypeVarId.id -> T.TypeVarId.id) (ty : 'r T.ty) +let ty_substitute_ids (tsubst : T.TypeVarId.id -> T.TypeVarId.id) + (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ty : 'r T.ty) : 'r T.ty = let open T in let visitor = @@ -374,6 +355,7 @@ let ty_substitute_ids (tsubst : T.TypeVarId.id -> T.TypeVarId.id) (ty : 'r T.ty) inherit [_] map_ty method visit_'r _ r = r method! visit_type_var_id _ id = tsubst id + method! visit_const_generic_var_id _ id = cgsubst id end in @@ -384,7 +366,7 @@ let ty_substitute_ids (tsubst : T.TypeVarId.id -> T.TypeVarId.id) (ty : 'r T.ty) We want to write a class which visits abstractions, values, etc. *and their types* to substitute identifiers. - The issue comes is that we derive two specialized types (ety and rty) from a + The issue is that we derive two specialized types (ety and rty) from a polymorphic type (ty). Because of this, there are typing issues with [visit_'r] if we define a class which visits objects of types [ety] and [rty] while inheriting a class which visit [ty]... @@ -392,6 +374,7 @@ let ty_substitute_ids (tsubst : T.TypeVarId.id -> T.TypeVarId.id) (ty : 'r T.ty) let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id) (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) (tsubst : T.TypeVarId.id -> T.TypeVarId.id) + (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) (bsubst : V.BorrowId.id -> V.BorrowId.id) (asubst : V.AbstractionId.id -> V.AbstractionId.id) = @@ -403,6 +386,7 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id) match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid) method! visit_type_var_id _ id = tsubst id + method! visit_const_generic_var_id _ id = cgsubst id end in @@ -411,7 +395,7 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id) inherit [_] C.map_env method! visit_borrow_id _ bid = bsubst bid method! visit_loan_id _ bid = bsubst bid - method! visit_ety _ ty = ty_substitute_ids tsubst ty + method! visit_ety _ ty = ty_substitute_ids tsubst cgsubst ty method! visit_rty env ty = subst_rty#visit_ty env ty method! visit_symbolic_value_id _ id = ssubst id @@ -444,11 +428,12 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id) let typed_value_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id) (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) (tsubst : T.TypeVarId.id -> T.TypeVarId.id) + (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) (bsubst : V.BorrowId.id -> V.BorrowId.id) (v : V.typed_value) : V.typed_value = let asubst _ = raise (Failure "Unreachable") in - (subst_ids_visitor rsubst rvsubst tsubst ssubst bsubst asubst) + (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst) #visit_typed_value v let typed_value_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) @@ -458,33 +443,39 @@ let typed_value_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) (fun x -> x) (fun x -> x) (fun x -> x) + (fun x -> x) v let typed_avalue_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id) (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) (tsubst : T.TypeVarId.id -> T.TypeVarId.id) + (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) (bsubst : V.BorrowId.id -> V.BorrowId.id) (v : V.typed_avalue) : V.typed_avalue = let asubst _ = raise (Failure "Unreachable") in - (subst_ids_visitor rsubst rvsubst tsubst ssubst bsubst asubst) + (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst) #visit_typed_avalue v let abs_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id) (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) (tsubst : T.TypeVarId.id -> T.TypeVarId.id) + (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) (bsubst : V.BorrowId.id -> V.BorrowId.id) (asubst : V.AbstractionId.id -> V.AbstractionId.id) (x : V.abs) : V.abs = - (subst_ids_visitor rsubst rvsubst tsubst ssubst bsubst asubst)#visit_abs x + (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst) + #visit_abs x let env_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id) (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) (tsubst : T.TypeVarId.id -> T.TypeVarId.id) + (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) (bsubst : V.BorrowId.id -> V.BorrowId.id) (asubst : V.AbstractionId.id -> V.AbstractionId.id) (x : C.env) : C.env = - (subst_ids_visitor rsubst rvsubst tsubst ssubst bsubst asubst)#visit_env x + (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst) + #visit_env x let typed_avalue_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) (x : V.typed_avalue) : V.typed_avalue = @@ -494,6 +485,7 @@ let typed_avalue_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) (fun x -> x) (fun x -> x) (fun x -> x) + (fun x -> x) asubst) #visit_typed_avalue x @@ -505,6 +497,7 @@ let env_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) (x : C.env) : C.env (fun x -> x) (fun x -> x) (fun x -> x) + (fun x -> x) (fun x -> x)) #visit_env x diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml index 0e68d2fd..7dc94dcd 100644 --- a/compiler/SymbolicAst.ml +++ b/compiler/SymbolicAst.ml @@ -43,7 +43,10 @@ type call = { borrows (we need to perform lookups). *) abstractions : V.AbstractionId.id list; + (* TODO: rename to "...args" *) type_params : T.ety list; + (* TODO: rename to "...args" *) + const_generic_params : T.const_generic list; args : V.typed_value list; args_places : mplace option list; (** Meta information *) dest : V.symbolic_value; @@ -164,7 +167,7 @@ type expression = Contexts.eval_ctx * mplace option * V.symbolic_value - * V.typed_value + * value_aggregate * expression (** We introduce a new symbolic value, equal to some other value. @@ -243,6 +246,13 @@ and expansion = T.integer_type * (V.scalar_value * expression) list * expression (** An integer expansion (i.e, a switch over an integer). The last expression is for the "otherwise" branch. *) + +(* Remark: this type doesn't have to be mutually recursive with the other + types, but it makes it easy to generate the visitors *) +and value_aggregate = + | SingleValue of V.typed_value (** Regular case *) + | Array of V.typed_value list + (** This is used when introducing array aggregates *) [@@deriving show, visitors diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 5dc8664a..3512270a 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -107,6 +107,7 @@ type loop_info = { input_vars : var list; input_svl : V.symbolic_value list; type_args : ty list; + const_generic_args : const_generic list; forward_inputs : texpression list option; (** The forward inputs are initialized at [None] *) forward_output_no_state_no_result : var option; @@ -240,6 +241,8 @@ let bs_ctx_to_ctx_formatter (ctx : bs_ctx) : Print.Contexts.ctx_formatter = r_to_string; type_var_id_to_string; type_decl_id_to_string = ast_fmt.type_decl_id_to_string; + const_generic_var_id_to_string = ast_fmt.const_generic_var_id_to_string; + global_decl_id_to_string = ast_fmt.global_decl_id_to_string; adt_variant_to_string = ast_fmt.adt_variant_to_string; var_id_to_string; adt_field_names = ast_fmt.adt_field_names; @@ -247,10 +250,12 @@ let bs_ctx_to_ctx_formatter (ctx : bs_ctx) : Print.Contexts.ctx_formatter = let bs_ctx_to_pp_ast_formatter (ctx : bs_ctx) : PrintPure.ast_formatter = let type_params = ctx.fun_decl.signature.type_params in + let cg_params = ctx.fun_decl.signature.const_generic_params in let type_decls = ctx.type_context.llbc_type_decls in let fun_decls = ctx.fun_context.llbc_fun_decls in let global_decls = ctx.global_context.llbc_global_decls in PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params + cg_params let symbolic_value_to_string (ctx : bs_ctx) (sv : V.symbolic_value) : string = let fmt = bs_ctx_to_ctx_formatter ctx in @@ -273,8 +278,12 @@ let rty_to_string (ctx : bs_ctx) (ty : T.rty) : string = let type_decl_to_string (ctx : bs_ctx) (def : type_decl) : string = let type_params = def.type_params in + let cg_params = def.const_generic_params in let type_decls = ctx.type_context.llbc_type_decls in - let fmt = PrintPure.mk_type_formatter type_decls type_params in + let global_decls = ctx.global_context.llbc_global_decls in + let fmt = + PrintPure.mk_type_formatter type_decls global_decls type_params cg_params + in PrintPure.type_decl_to_string fmt def let texpression_to_string (ctx : bs_ctx) (e : texpression) : string = @@ -283,21 +292,25 @@ let texpression_to_string (ctx : bs_ctx) (e : texpression) : string = let fun_sig_to_string (ctx : bs_ctx) (sg : fun_sig) : string = let type_params = sg.type_params in + let cg_params = sg.const_generic_params in let type_decls = ctx.type_context.llbc_type_decls in let fun_decls = ctx.fun_context.llbc_fun_decls in let global_decls = ctx.global_context.llbc_global_decls in let fmt = PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params + cg_params in PrintPure.fun_sig_to_string fmt sg let fun_decl_to_string (ctx : bs_ctx) (def : Pure.fun_decl) : string = let type_params = def.signature.type_params in + let cg_params = def.signature.const_generic_params in let type_decls = ctx.type_context.llbc_type_decls in let fun_decls = ctx.fun_context.llbc_fun_decls in let global_decls = ctx.global_context.llbc_global_decls in let fmt = PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params + cg_params in PrintPure.fun_decl_to_string fmt def @@ -315,16 +328,17 @@ let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string = Print.Values.abs_to_string fmt verbose indent indent_incr abs let get_instantiated_fun_sig (fun_id : A.fun_id) - (back_id : T.RegionGroupId.id option) (tys : ty list) (ctx : bs_ctx) : - inst_fun_sig = + (back_id : T.RegionGroupId.id option) (tys : ty list) + (cgs : const_generic list) (ctx : bs_ctx) : inst_fun_sig = (* Lookup the non-instantiated function signature *) let sg = (RegularFunIdNotLoopMap.find (fun_id, back_id) ctx.fun_context.fun_sigs).sg in (* Create the substitution *) let tsubst = make_type_subst sg.type_params tys in + let cgsubst = make_const_generic_subst sg.const_generic_params cgs in (* Apply *) - fun_sig_substitute tsubst sg + fun_sig_substitute tsubst cgsubst sg let bs_ctx_lookup_llbc_type_decl (id : TypeDeclId.id) (ctx : bs_ctx) : T.type_decl = @@ -380,17 +394,17 @@ let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) let rec translate_sty (ty : T.sty) : ty = let translate = translate_sty in match ty with - | T.Adt (type_id, regions, tys) -> ( + | T.Adt (type_id, regions, tys, cgs) -> ( (* Can't translate types with regions for now *) assert (regions = []); let tys = List.map translate tys in match type_id with - | T.AdtId adt_id -> Adt (AdtId adt_id, tys) + | T.AdtId adt_id -> Adt (AdtId adt_id, tys, cgs) | T.Tuple -> mk_simpl_tuple_ty tys | T.Assumed aty -> ( match aty with - | T.Vec -> Adt (Assumed Vec, tys) - | T.Option -> Adt (Assumed Option, tys) + | T.Vec -> Adt (Assumed Vec, tys, cgs) + | T.Option -> Adt (Assumed Option, tys, cgs) | T.Box -> ( (* Eliminate the boxes *) match tys with @@ -399,15 +413,14 @@ let rec translate_sty (ty : T.sty) : ty = raise (Failure "Box/vec/option type with incorrect number of arguments") - ))) + ) + | T.Array -> Adt (Assumed Array, tys, cgs) + | T.Slice -> Adt (Assumed Slice, tys, cgs) + | T.Str -> Adt (Assumed Str, tys, cgs) + | T.Range -> Adt (Assumed Range, tys, cgs))) | TypeVar vid -> TypeVar vid - | Bool -> Bool - | Char -> Char + | Literal ty -> Literal ty | Never -> raise (Failure "Unreachable") - | Integer int_ty -> Integer int_ty - | Str -> Str - | Array ty -> Array (translate ty) - | Slice ty -> Slice (translate ty) | Ref (_, rty, _) -> translate rty let translate_field (f : T.field) : field = @@ -445,34 +458,49 @@ let translate_type_decl (def : T.type_decl) : type_decl = (* Can't translate types with regions for now *) assert (def.region_params = []); let type_params = def.type_params in + let const_generic_params = def.const_generic_params in let kind = translate_type_decl_kind def.T.kind in - { def_id; name; type_params; kind } + { def_id; name; type_params; const_generic_params; kind } + +let translate_type_id (id : T.type_id) : type_id = + match id with + | AdtId adt_id -> AdtId adt_id + | T.Assumed aty -> + let aty = + match aty with + | T.Vec -> Vec + | T.Option -> Option + | T.Array -> Array + | T.Slice -> Slice + | T.Str -> Str + | T.Range -> Range + | T.Box -> + (* Boxes have to be eliminated: this type id shouldn't + be translated *) + raise (Failure "Unreachable") + in + Assumed aty + | T.Tuple -> Tuple (** Translate a type, seen as an input/output of a forward function (preserve all borrows, etc.) *) - let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty = let translate = translate_fwd_ty type_infos in match ty with - | T.Adt (type_id, regions, tys) -> ( + | T.Adt (type_id, regions, tys, cgs) -> ( (* Can't translate types with regions for now *) assert (regions = []); (* Translate the type parameters *) let t_tys = List.map translate tys in (* Eliminate boxes and simplify tuples *) match type_id with - | AdtId _ | T.Assumed (T.Vec | T.Option) -> + | AdtId _ + | T.Assumed (T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) -> (* No general parametricity for now *) assert (not (List.exists (TypesUtils.ty_has_borrows type_infos) tys)); - let type_id = - match type_id with - | AdtId adt_id -> AdtId adt_id - | T.Assumed T.Vec -> Assumed Vec - | T.Assumed T.Option -> Assumed Option - | _ -> raise (Failure "Unreachable") - in - Adt (type_id, t_tys) + let type_id = translate_type_id type_id in + Adt (type_id, t_tys, cgs) | Tuple -> (* Note that if there is exactly one type, [mk_simpl_tuple_ty] is the identity *) @@ -489,17 +517,8 @@ let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty = "Unreachable: box/vec/option receives exactly one type \ parameter"))) | TypeVar vid -> TypeVar vid - | Bool -> Bool - | Char -> Char | Never -> raise (Failure "Unreachable") - | Integer int_ty -> Integer int_ty - | Str -> Str - | Array ty -> - assert (not (TypesUtils.ty_has_borrows type_infos ty)); - Array (translate ty) - | Slice ty -> - assert (not (TypesUtils.ty_has_borrows type_infos ty)); - Slice (translate ty) + | Literal lty -> Literal lty | Ref (_, rty, _) -> translate rty (** Simply calls [translate_fwd_ty] *) @@ -519,21 +538,16 @@ let rec translate_back_ty (type_infos : TA.type_infos) (* A small helper for "leave" types *) let wrap ty = if inside_mut then Some ty else None in match ty with - | T.Adt (type_id, _, tys) -> ( + | T.Adt (type_id, _, tys, cgs) -> ( match type_id with - | T.AdtId _ | Assumed (T.Vec | T.Option) -> + | T.AdtId _ + | Assumed (T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) -> (* Don't accept ADTs (which are not tuples) with borrows for now *) assert (not (TypesUtils.ty_has_borrows type_infos ty)); - let type_id = - match type_id with - | T.AdtId id -> AdtId id - | T.Assumed T.Vec -> Assumed Vec - | T.Assumed T.Option -> Assumed Option - | T.Tuple | T.Assumed T.Box -> raise (Failure "Unreachable") - in + let type_id = translate_type_id type_id in if inside_mut then let tys_t = List.filter_map translate tys in - Some (Adt (type_id, tys_t)) + Some (Adt (type_id, tys_t, cgs)) else None | Assumed T.Box -> ( (* Don't accept ADTs (which are not tuples) with borrows for now *) @@ -555,17 +569,8 @@ let rec translate_back_ty (type_infos : TA.type_infos) * is the identity *) Some (mk_simpl_tuple_ty tys_t))) | TypeVar vid -> wrap (TypeVar vid) - | Bool -> wrap Bool - | Char -> wrap Char | Never -> raise (Failure "Unreachable") - | Integer int_ty -> wrap (Integer int_ty) - | Str -> wrap Str - | Array ty -> ( - assert (not (TypesUtils.ty_has_borrows type_infos ty)); - match translate ty with None -> None | Some ty -> Some (Array ty)) - | Slice ty -> ( - assert (not (TypesUtils.ty_has_borrows type_infos ty)); - match translate ty with None -> None | Some ty -> Some (Slice ty)) + | Literal lty -> wrap (Literal lty) | Ref (r, rty, rkind) -> ( match rkind with | T.Shared -> @@ -801,8 +806,9 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) (* Wrap in a result type *) if effect_info.can_fail then mk_result_ty output else output in - (* Type parameters *) + (* Type/const generic parameters *) let type_params = sg.type_params in + let const_generic_params = sg.const_generic_params in (* Return *) let has_fuel = fuel <> [] in let num_fwd_inputs_no_state = List.length fwd_inputs in @@ -830,7 +836,9 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) effect_info; } in - let sg = { type_params; inputs; output; doutputs; info } in + let sg = + { type_params; const_generic_params; inputs; output; doutputs; info } + in { sg; output_names } let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern = @@ -909,7 +917,7 @@ let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var = (** Peel boxes as long as the value is of the form [Box<T>] *) let rec unbox_typed_value (v : V.typed_value) : V.typed_value = match (v.value, v.ty) with - | V.Adt av, T.Adt (T.Assumed T.Box, _, _) -> ( + | V.Adt av, T.Adt (T.Assumed T.Box, _, _, _) -> ( match av.field_values with | [ bv ] -> unbox_typed_value bv | _ -> raise (Failure "Unreachable")) @@ -948,26 +956,22 @@ let rec typed_value_to_texpression (ctx : bs_ctx) (ectx : C.eval_ctx) (* Translate the value *) let value = match v.value with - | V.Primitive cv -> { e = Const cv; ty } + | V.Literal cv -> { e = Const cv; ty } | Adt av -> ( let variant_id = av.variant_id in let field_values = List.map translate av.field_values in (* Eliminate the tuple wrapper if it is a tuple with exactly one field *) match v.ty with - | T.Adt (T.Tuple, _, _) -> + | T.Adt (T.Tuple, _, _, _) -> assert (variant_id = None); mk_simpl_tuple_texpression field_values | _ -> - (* Retrieve the type and the translated type arguments from the - * translated type (simpler this way) *) - let adt_id, type_args = - match ty with - | Adt (type_id, tys) -> (type_id, tys) - | _ -> raise (Failure "Unreachable") - in + (* Retrieve the type, the translated type arguments and the + * const generic arguments from the translated type (simpler this way) *) + let adt_id, type_args, const_generic_args = ty_as_adt ty in (* Create the constructor *) let qualif_id = AdtCons { adt_id; variant_id = av.variant_id } in - let qualif = { id = qualif_id; type_args } in + let qualif = { id = qualif_id; type_args; const_generic_args } in let cons_e = Qualif qualif in let field_tys = List.map (fun (v : texpression) -> v.ty) field_values @@ -1034,9 +1038,11 @@ let rec typed_avalue_to_consumed (ctx : bs_ctx) (ectx : C.eval_ctx) (* Translate the field values *) let field_values = List.filter_map translate adt_v.field_values in (* For now, only tuples can contain borrows *) - let adt_id, _, _ = TypesUtils.ty_as_adt av.ty in + let adt_id, _, _, _ = TypesUtils.ty_as_adt av.ty in match adt_id with - | T.AdtId _ | T.Assumed (T.Box | T.Vec | T.Option) -> + | T.AdtId _ + | T.Assumed + (T.Box | T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) -> assert (field_values = []); None | T.Tuple -> @@ -1177,11 +1183,13 @@ let rec typed_avalue_to_given_back (mp : mplace option) (av : V.typed_avalue) in let field_values = List.filter_map (fun x -> x) field_values in (* For now, only tuples can contain borrows - note that if we gave - * something like a [&mut Vec] to a function, we give give back the + * something like a [&mut Vec] to a function, we give back the * vector value upon visiting the "abstraction borrow" node *) - let adt_id, _, _ = TypesUtils.ty_as_adt av.ty in + let adt_id, _, _, _ = TypesUtils.ty_as_adt av.ty in match adt_id with - | T.AdtId _ | T.Assumed (T.Box | T.Vec | T.Option) -> + | T.AdtId _ + | T.Assumed + (T.Box | T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) -> assert (field_values = []); (ctx, None) | T.Tuple -> @@ -1451,6 +1459,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : texpression = (* Translate the function call *) let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in + let const_generic_args = call.const_generic_params in let args = let args = List.map (typed_value_to_texpression ctx call.ctx) call.args in let args_mplaces = List.map translate_opt_mplace call.args_places in @@ -1552,7 +1561,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : | None -> dest | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ] in - let func = { id = FunOrOp fun_id; type_args } in + let func = { id = FunOrOp fun_id; type_args; const_generic_args } in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in let ret_ty = if effect_info.can_fail then mk_result_ty dest_v.ty else dest_v.ty @@ -1613,13 +1622,13 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs) * to the backward function, and which consumed the values [consumed_i], * we introduce: * {[ - * let v_i = consumed_i in - * ... - * ]} + * let v_i = consumed_i in + * ... + * ]} * Then, when we reach the [Return] node, we introduce: * {[ - * (v_i) - * ]} + * (v_i) + * ]} * *) (* First, get the given back variables. @@ -1684,6 +1693,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) get_fun_effect_info ctx.fun_context.fun_infos fun_id None (Some rg_id) in let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in + let const_generic_args = call.const_generic_params in (* Retrieve the original call and the parent abstractions *) let _forward, backwards = get_abs_ancestors ctx abs call_id in (* Retrieve the values consumed when we called the forward function and @@ -1732,7 +1742,10 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) in (* Sanity check: there is the proper number of inputs and outputs, and they have the proper type *) let _ = - let inst_sg = get_instantiated_fun_sig fun_id (Some rg_id) type_args ctx in + let inst_sg = + get_instantiated_fun_sig fun_id (Some rg_id) type_args const_generic_args + ctx + in log#ldebug (lazy ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs (" @@ -1775,7 +1788,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) if effect_info.can_fail then mk_result_ty output.ty else output.ty in let func_ty = mk_arrows input_tys ret_ty in - let func = { id = FunOrOp func; type_args } in + let func = { id = FunOrOp func; type_args; const_generic_args } in let func = { e = Qualif func; ty = func_ty } in let call = mk_apps func args in (* **Optimization**: @@ -1838,7 +1851,7 @@ and translate_end_abstraction_synth_ret (ectx : C.eval_ctx) (abs : V.abs) {[ let id_back x nx = let s = nx in // the name [s] is not important (only collision matters) - ... + ... ]} This let-binding later gets inlined, during a micro-pass. @@ -1899,6 +1912,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) in let loop_info = LoopId.Map.find loop_id ctx.loops in let type_args = loop_info.type_args in + let const_generic_args = loop_info.const_generic_args in let fwd_inputs = Option.get loop_info.forward_inputs in (* Retrieve the additional backward inputs. Note that those are actually the backward inputs of the function we are synthesizing (and that we @@ -1947,7 +1961,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) in let func_ty = mk_arrows input_tys ret_ty in let func = Fun (FromLlbc (fun_id, Some loop_id, Some rg_id)) in - let func = { id = FunOrOp func; type_args } in + let func = { id = FunOrOp func; type_args; const_generic_args } in let func = { e = Qualif func; ty = func_ty } in let call = mk_apps func args in (* **Optimization**: @@ -2007,7 +2021,9 @@ and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value) (e : S.expression) (ctx : bs_ctx) : texpression = let ctx, var = fresh_var_for_symbolic_value sval ctx in let decl = A.GlobalDeclId.Map.find gid ctx.global_context.llbc_global_decls in - let global_expr = { id = Global gid; type_args = [] } in + let global_expr = + { id = Global gid; type_args = []; const_generic_args = [] } + in (* We use translate_fwd_ty to translate the global type *) let ty = ctx_translate_fwd_ty ctx decl.ty in let gval = { e = Qualif global_expr; ty } in @@ -2020,8 +2036,14 @@ and translate_assertion (ectx : C.eval_ctx) (v : V.typed_value) let monadic = true in let v = typed_value_to_texpression ctx ectx v in let args = [ v ] in - let func = { id = FunOrOp (Fun (Pure Assert)); type_args = [] } in - let func_ty = mk_arrow Bool mk_unit_ty in + let func = + { + id = FunOrOp (Fun (Pure Assert)); + type_args = []; + const_generic_args = []; + } + in + let func_ty = mk_arrow (Literal Bool) mk_unit_ty in let func = { e = Qualif func; ty = func_ty } in let assertion = mk_apps func args in mk_let monadic (mk_dummy_pattern mk_unit_ty) assertion next_e @@ -2036,13 +2058,13 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) match exp with | ExpandNoBranch (sexp, e) -> ( match sexp with - | V.SePrimitive _ -> - (* Actually, we don't *register* symbolic expansions to constant - * values in the symbolic ADT *) + | V.SeLiteral _ -> + (* We do not *register* symbolic expansions to literal + * values in the symbolic ADT *) raise (Failure "Unreachable") | SeMutRef (_, nsv) | SeSharedRef (_, nsv) -> (* The (mut/shared) borrow type is extracted to identity: we thus simply - * introduce an reassignment *) + * introduce an reassignment *) let ctx, var = fresh_var_for_symbolic_value nsv ctx in let next_e = translate_expression e ctx in let monadic = false in @@ -2063,10 +2085,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) && !Config.always_deconstruct_adts_with_matches) -> (* There is exactly one branch: no branching. - We can decompose the ADT value with a let-binding, unless - the backend doesn't support this (see {!Config.always_deconstruct_adts_with_matches}): - we *ignore* this branch (and go to the next one) if the ADT is a custom - adt, and [always_deconstruct_adts_with_matches] is true. + We can decompose the ADT value with a let-binding, unless + the backend doesn't support this (see {!Config.always_deconstruct_adts_with_matches}): + we *ignore* this branch (and go to the next one) if the ADT is a custom + adt, and [always_deconstruct_adts_with_matches] is true. *) translate_ExpandAdt_one_branch sv scrutinee scrutinee_mplace variant_id svl branch ctx @@ -2115,14 +2137,14 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) let translate_branch ((v, branch_e) : V.scalar_value * S.expression) : match_branch = (* We don't need to update the context: we don't introduce any - * new values/variables *) + * new values/variables *) let branch = translate_expression branch_e ctx in - let pat = mk_typed_pattern_from_primitive_value (PV.Scalar v) in + let pat = mk_typed_pattern_from_literal (PV.Scalar v) in { pat; branch } in let branches = List.map translate_branch branches in let otherwise = translate_expression otherwise ctx in - let pat_ty = Integer int_ty in + let pat_ty = Literal (Integer int_ty) in let otherwise_pat : typed_pattern = { value = PatDummy; ty = pat_ty } in let otherwise : match_branch = { pat = otherwise_pat; branch = otherwise } @@ -2142,18 +2164,18 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) There are several possibilities: - if the ADT is an enumeration, we attempt to deconstruct it with a let-binding: - {[ - let Cons x0 ... xn = y in - ... - ]} + {[ + let Cons x0 ... xn = y in + ... + ]} - if the ADT is a structure, we attempt to introduce one let-binding per field: - {[ - let x0 = y.f0 in - ... + {[ + let x0 = y.f0 in + ... let xn = y.fn in ... - ]} + ]} Of course, this is not always possible depending on the backend. Also, recursive structures, and more specifically structures mutually recursive @@ -2167,14 +2189,14 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value) (branch : S.expression) (ctx : bs_ctx) : texpression = (* TODO: always introduce a match, and use micro-passes to turn the the match into a let? *) - let type_id, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in + let type_id, _, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in let ctx, vars = fresh_vars_for_symbolic_values svl ctx in let branch = translate_expression branch ctx in match type_id with | T.AdtId adt_id -> (* Detect if this is an enumeration or not *) let tdef = bs_ctx_lookup_llbc_type_decl adt_id ctx in - let is_enum = type_decl_is_enum tdef in + let is_enum = TypesUtils.type_decl_is_enum tdef in (* We deconstruct the ADT with a let-binding in two situations: - if the ADT is an enumeration (which must have exactly one branch) - if we forbid using field projectors. @@ -2202,14 +2224,10 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value) * field. * We use the [dest] variable in order not to have to recompute * the type of the result of the projection... *) - let adt_id, type_args = - match scrutinee.ty with - | Adt (adt_id, tys) -> (adt_id, tys) - | _ -> raise (Failure "Unreachable") - in + let adt_id, type_args, const_generic_args = ty_as_adt scrutinee.ty in let gen_field_proj (field_id : FieldId.id) (dest : var) : texpression = let proj_kind = { adt_id; field_id } in - let qualif = { id = Proj proj_kind; type_args } in + let qualif = { id = Proj proj_kind; type_args; const_generic_args } in let proj_e = Qualif qualif in let proj_ty = mk_arrow scrutinee.ty dest.ty in let proj = { e = proj_e; ty = proj_ty } in @@ -2241,31 +2259,47 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value) (mk_typed_pattern_from_var var None) (mk_opt_mplace_texpression scrutinee_mplace scrutinee) branch - | T.Assumed T.Vec -> - (* We can't expand vector values: we can access the fields only + | T.Assumed (T.Vec | T.Array | T.Slice | T.Str) -> + (* We can't expand those values: we can access the fields only * through the functions provided by the API (note that we don't - * know how to expand a vector, because it has a variable number + * know how to expand values like vectors or arrays, because they have a variable number * of fields!) *) - raise (Failure "Can't expand a vector value") + raise (Failure "Attempt to expand a non-expandable value") + | T.Assumed Range -> raise (Failure "Unimplemented") | T.Assumed T.Option -> (* We shouldn't get there in the "one-branch" case: options have * two variants *) raise (Failure "Unreachable") and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option) - (sv : V.symbolic_value) (v : V.typed_value) (e : S.expression) + (sv : V.symbolic_value) (v : S.value_aggregate) (e : S.expression) (ctx : bs_ctx) : texpression = let mplace = translate_opt_mplace p in (* Introduce a fresh variable for the symbolic value *) let ctx, var = fresh_var_for_symbolic_value sv ctx in - (* Translate the value *) - let v = typed_value_to_texpression ctx ectx v in - (* Translate the next expression *) let next_e = translate_expression e ctx in + (* Translate the value: there are two cases, depending on whether this + is a "regular" let-binding or an array aggregate. + *) + let v = + match v with + | SingleValue v -> typed_value_to_texpression ctx ectx v + | Array values -> + (* We use a struct update to encode the array aggregate, in order + to preserve the structure and allow generating code of the shape + `[x0, ...., xn]` *) + let values = List.map (typed_value_to_texpression ctx ectx) values in + let values = FieldId.mapi (fun fid v -> (fid, v)) values in + let su : struct_update = + { struct_id = Assumed Array; init = None; updates = values } + in + { e = StructUpdate su; ty = var.ty } + in + (* Make the let-binding *) let monadic = false in let var = mk_typed_pattern_from_var var mplace in @@ -2382,7 +2416,13 @@ and translate_forward_end (ectx : C.eval_ctx) let loop_call = let fun_id = Fun (FromLlbc (fid, Some loop_id, None)) in - let func = { id = FunOrOp fun_id; type_args = loop_info.type_args } in + let func = + { + id = FunOrOp fun_id; + type_args = loop_info.type_args; + const_generic_args = loop_info.const_generic_args; + } + in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in let ret_ty = if effect_info.can_fail then mk_result_ty out_pat.ty else out_pat.ty @@ -2503,7 +2543,12 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = (and will introduce the outputs at that moment, together with the actual call to the loop forward function *) let type_args = - List.map (fun ty -> TypeVar ty.T.index) ctx.sg.type_params + List.map (fun (ty : T.type_var) -> TypeVar ty.T.index) ctx.sg.type_params + in + let const_generic_args = + List.map + (fun (cg : T.const_generic_var) -> T.ConstGenericVar cg.T.index) + ctx.sg.const_generic_params in let loop_info = @@ -2512,6 +2557,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = input_vars = inputs; input_svl = loop.input_svalues; type_args; + const_generic_args; forward_inputs = None; forward_output_no_state_no_result = None; } @@ -2599,14 +2645,26 @@ let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression) *) (* Create the expression: [fuel0 = 0] *) let check_fuel = - let func = { id = FunOrOp (Fun (Pure FuelEqZero)); type_args = [] } in + let func = + { + id = FunOrOp (Fun (Pure FuelEqZero)); + type_args = []; + const_generic_args = []; + } + in let func_ty = mk_arrow mk_fuel_ty mk_bool_ty in let func = { e = Qualif func; ty = func_ty } in mk_app func fuel0 in (* Create the expression: [decrease fuel0] *) let decrease_fuel = - let func = { id = FunOrOp (Fun (Pure FuelDecrease)); type_args = [] } in + let func = + { + id = FunOrOp (Fun (Pure FuelDecrease)); + type_args = []; + const_generic_args = []; + } + in let func_ty = mk_arrow mk_fuel_ty mk_fuel_ty in let func = { e = Qualif func; ty = func_ty } in mk_app func fuel0 diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml index 6668c043..857fea97 100644 --- a/compiler/SynthesizeSymbolic.ml +++ b/compiler/SynthesizeSymbolic.ml @@ -32,16 +32,16 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value) (* Match on the symbolic value type to know which can of expansion happened *) let expansion = match sv.V.sv_ty with - | T.Bool -> ( + | T.Literal PV.Bool -> ( (* Boolean expansion: there should be two branches *) match ls with | [ - (Some (V.SePrimitive (PV.Bool true)), true_exp); - (Some (V.SePrimitive (PV.Bool false)), false_exp); + (Some (V.SeLiteral (PV.Bool true)), true_exp); + (Some (V.SeLiteral (PV.Bool false)), false_exp); ] -> ExpandBool (true_exp, false_exp) | _ -> raise (Failure "Ill-formed boolean expansion")) - | T.Integer int_ty -> + | T.Literal (PV.Integer int_ty) -> (* Switch over an integer: split between the "regular" branches and the "otherwise" branch (which should be the last branch) *) let branches, otherwise = C.List.pop_last ls in @@ -50,7 +50,7 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value) let get_scalar (see : V.symbolic_expansion option) : V.scalar_value = match see with - | Some (V.SePrimitive (PV.Scalar cv)) -> + | Some (V.SeLiteral (PV.Scalar cv)) -> assert (cv.PV.int_ty = int_ty); cv | _ -> raise (Failure "Unreachable") @@ -64,7 +64,7 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value) assert (otherwise_see = None); (* Return *) ExpandInt (int_ty, branches, otherwise) - | T.Adt (_, _, _) -> + | T.Adt (_, _, _, _) -> (* Branching: it is necessarily an enumeration expansion *) let get_variant (see : V.symbolic_expansion option) : T.VariantId.id option * V.symbolic_value list = @@ -85,7 +85,7 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value) match ls with | [ (Some see, exp) ] -> ExpandNoBranch (see, exp) | _ -> raise (Failure "Ill-formed borrow expansion")) - | T.TypeVar _ | Char | Never | Str | Array _ | Slice _ -> + | T.TypeVar _ | T.Literal Char | Never -> raise (Failure "Ill-formed symbolic expansion") in Some (Expansion (place, sv, expansion)) @@ -98,9 +98,9 @@ let synthesize_symbolic_expansion_no_branching (sv : V.symbolic_value) let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx) (abstractions : V.AbstractionId.id list) (type_params : T.ety list) - (args : V.typed_value list) (args_places : mplace option list) - (dest : V.symbolic_value) (dest_place : mplace option) - (e : expression option) : expression option = + (const_generic_params : T.const_generic list) (args : V.typed_value list) + (args_places : mplace option list) (dest : V.symbolic_value) + (dest_place : mplace option) (e : expression option) : expression option = Option.map (fun e -> let call = @@ -109,6 +109,7 @@ let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx) ctx; abstractions; type_params; + const_generic_params; args; dest; args_places; @@ -125,24 +126,25 @@ let synthesize_global_eval (gid : A.GlobalDeclId.id) (dest : V.symbolic_value) let synthesize_regular_function_call (fun_id : A.fun_id) (call_id : V.FunCallId.id) (ctx : Contexts.eval_ctx) (abstractions : V.AbstractionId.id list) (type_params : T.ety list) - (args : V.typed_value list) (args_places : mplace option list) - (dest : V.symbolic_value) (dest_place : mplace option) - (e : expression option) : expression option = + (const_generic_params : T.const_generic list) (args : V.typed_value list) + (args_places : mplace option list) (dest : V.symbolic_value) + (dest_place : mplace option) (e : expression option) : expression option = synthesize_function_call (Fun (fun_id, call_id)) - ctx abstractions type_params args args_places dest dest_place e + ctx abstractions type_params const_generic_params args args_places dest + dest_place e let synthesize_unary_op (ctx : Contexts.eval_ctx) (unop : E.unop) (arg : V.typed_value) (arg_place : mplace option) (dest : V.symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = - synthesize_function_call (Unop unop) ctx [] [] [ arg ] [ arg_place ] dest + synthesize_function_call (Unop unop) ctx [] [] [] [ arg ] [ arg_place ] dest dest_place e let synthesize_binary_op (ctx : Contexts.eval_ctx) (binop : E.binop) (arg0 : V.typed_value) (arg0_place : mplace option) (arg1 : V.typed_value) (arg1_place : mplace option) (dest : V.symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = - synthesize_function_call (Binop binop) ctx [] [] [ arg0; arg1 ] + synthesize_function_call (Binop binop) ctx [] [] [] [ arg0; arg1 ] [ arg0_place; arg1_place ] dest dest_place e let synthesize_end_abstraction (ctx : Contexts.eval_ctx) (abs : V.abs) diff --git a/compiler/Translate.ml b/compiler/Translate.ml index c5f7df92..6c6435c5 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -833,6 +833,8 @@ let extract_file (config : gen_config) (ctx : gen_ctx) (fi : extract_file_info) Printf.fprintf out "Require Import Primitives.\n"; Printf.fprintf out "Import Primitives.\n"; Printf.fprintf out "Require Import Coq.ZArith.ZArith.\n"; + Printf.fprintf out "Require Import List.\n"; + Printf.fprintf out "Import ListNotations.\n"; Printf.fprintf out "Local Open Scope Primitives_scope.\n"; (* Add the custom imports *) @@ -855,7 +857,7 @@ let extract_file (config : gen_config) (ctx : gen_ctx) (fi : extract_file_info) (* Always open the Primitives namespace *) Printf.fprintf out "open Primitives\n"; (* If we are inside the namespace: declare it, otherwise: open it *) - if fi.in_namespace then Printf.fprintf out "namespace %s\n" fi.namespace + if fi.in_namespace then Printf.fprintf out "\nnamespace %s\n" fi.namespace else Printf.fprintf out "open %s\n" fi.namespace | HOL4 -> Printf.fprintf out "open primitivesLib divDefLib\n"; @@ -1332,7 +1334,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : custom_includes = []; } in - (* Add the extension for F* *) extract_file gen_config gen_ctx file_info); (* Generate the build file *) @@ -1350,7 +1351,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : * Generate the library entry point, if the crate is split between * different files. *) - if !Config.split_files then ( + if !Config.split_files && !Config.generate_lib_entry_point then ( let filename = Filename.concat dest_dir (crate_name ^ ".lean") in let out = open_out filename in (* Write *) diff --git a/compiler/TranslateCore.ml b/compiler/TranslateCore.ml index 9ba73c7e..ba5e237b 100644 --- a/compiler/TranslateCore.ml +++ b/compiler/TranslateCore.ml @@ -32,33 +32,39 @@ type pure_fun_translation = fun_and_loops * fun_and_loops list let type_decl_to_string (ctx : trans_ctx) (def : Pure.type_decl) : string = let type_params = def.type_params in + let cg_params = def.const_generic_params in let type_decls = ctx.type_context.type_decls in - let fmt = PrintPure.mk_type_formatter type_decls type_params in + let global_decls = ctx.global_context.global_decls in + let fmt = + PrintPure.mk_type_formatter type_decls global_decls type_params cg_params + in PrintPure.type_decl_to_string fmt def -let type_id_to_string (ctx : trans_ctx) (def : Pure.type_decl) : string = - let type_params = def.type_params in - let type_decls = ctx.type_context.type_decls in - let fmt = PrintPure.mk_type_formatter type_decls type_params in - PrintPure.type_decl_to_string fmt def +let type_id_to_string (ctx : trans_ctx) (id : Pure.TypeDeclId.id) : string = + Print.fun_name_to_string + (Pure.TypeDeclId.Map.find id ctx.type_context.type_decls).name let fun_sig_to_string (ctx : trans_ctx) (sg : Pure.fun_sig) : string = let type_params = sg.type_params in + let cg_params = sg.const_generic_params in let type_decls = ctx.type_context.type_decls in let fun_decls = ctx.fun_context.fun_decls in let global_decls = ctx.global_context.global_decls in let fmt = PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params + cg_params in PrintPure.fun_sig_to_string fmt sg let fun_decl_to_string (ctx : trans_ctx) (def : Pure.fun_decl) : string = let type_params = def.signature.type_params in + let cg_params = def.signature.const_generic_params in let type_decls = ctx.type_context.type_decls in let fun_decls = ctx.fun_context.fun_decls in let global_decls = ctx.global_context.global_decls in let fmt = PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params + cg_params in PrintPure.fun_decl_to_string fmt def diff --git a/compiler/TypesAnalysis.ml b/compiler/TypesAnalysis.ml index db6d3b98..925f6d39 100644 --- a/compiler/TypesAnalysis.ml +++ b/compiler/TypesAnalysis.ml @@ -122,7 +122,7 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref) let rec analyze (expl_info : expl_info) (ty_info : partial_type_info) (ty : 'r ty) : partial_type_info = match ty with - | Bool | Char | Never | Integer _ | Str -> ty_info + | Literal _ | Never -> ty_info | TypeVar var_id -> ( (* Update the information for the proper parameter, if necessary *) match ty_info.param_infos with @@ -145,9 +145,6 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref) in let param_infos = Some param_infos in { ty_info with param_infos }) - | Array ty | Slice ty -> - (* Just dive in *) - analyze expl_info ty_info ty | Ref (r, rty, rkind) -> (* Update the type info *) let contains_static = r_is_static r in @@ -172,12 +169,16 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref) in (* Continue exploring *) analyze expl_info ty_info rty - | Adt ((Tuple | Assumed (Box | Vec | Option)), _, tys) -> + | Adt + ( (Tuple | Assumed (Box | Vec | Option | Slice | Array | Str | Range)), + _, + tys, + _ ) -> (* Nothing to update: just explore the type parameters *) List.fold_left (fun ty_info ty -> analyze expl_info ty_info ty) ty_info tys - | Adt (AdtId adt_id, regions, tys) -> + | Adt (AdtId adt_id, regions, tys, _cgs) -> (* Lookup the information for this type definition *) let adt_info = TypeDeclId.Map.find adt_id infos in (* Update the type info with the information from the adt *) diff --git a/compiler/Values.ml b/compiler/Values.ml index d4eff58a..d884c319 100644 --- a/compiler/Values.ml +++ b/compiler/Values.ml @@ -14,7 +14,7 @@ module LoopId = IdGen () type big_int = PrimitiveValues.big_int [@@deriving show, ord] type scalar_value = PrimitiveValues.scalar_value [@@deriving show, ord] -type primitive_value = PrimitiveValues.primitive_value [@@deriving show, ord] +type literal = PrimitiveValues.literal [@@deriving show, ord] type symbolic_value_id = SymbolicValueId.id [@@deriving show, ord] type symbolic_value_id_set = SymbolicValueId.Set.t [@@deriving show, ord] type loop_id = LoopId.id [@@deriving show, ord] @@ -50,6 +50,8 @@ type sv_kind = (** A value given back by a loop (when ending abstractions while going backwards) *) | LoopJoin (** The result of a loop join (when computing loop fixed points) *) + | Aggregate + (** A symbolic value we introduce in place of an aggregate value *) [@@deriving show, ord] (** Ancestor for {!symbolic_value} iter visitor *) @@ -110,10 +112,7 @@ type loan_id_set = BorrowId.Set.t [@@deriving show, ord] class ['self] iter_typed_value_base = object (self : 'self) inherit [_] iter_symbolic_value - - method visit_primitive_value : 'env -> primitive_value -> unit = - fun _ _ -> () - + method visit_literal : 'env -> literal -> unit = fun _ _ -> () method visit_erased_region : 'env -> erased_region -> unit = fun _ _ -> () method visit_variant_id : 'env -> variant_id -> unit = fun _ _ -> () method visit_ety : 'env -> ety -> unit = fun _ _ -> () @@ -131,9 +130,7 @@ class ['self] iter_typed_value_base = class ['self] map_typed_value_base = object (self : 'self) inherit [_] map_symbolic_value - - method visit_primitive_value : 'env -> primitive_value -> primitive_value = - fun _ cv -> cv + method visit_literal : 'env -> literal -> literal = fun _ cv -> cv method visit_erased_region : 'env -> erased_region -> erased_region = fun _ r -> r @@ -152,7 +149,7 @@ class ['self] map_typed_value_base = (** An untyped value, used in the environments *) type value = - | Primitive of primitive_value (** Non-symbolic primitive value *) + | Literal of literal (** Non-symbolic primitive value *) | Adt of adt_value (** Enumerations and structures *) | Bottom (** No value (uninitialized or moved value) *) | Borrow of borrow_content (** A borrowed value *) @@ -1019,7 +1016,7 @@ type abs = { TODO: this should rather be name "expanded_symbolic" *) type symbolic_expansion = - | SePrimitive of primitive_value + | SeLiteral of literal | SeAdt of (VariantId.id option * symbolic_value list) | SeMutRef of BorrowId.id * symbolic_value | SeSharedRef of BorrowId.Set.t * symbolic_value diff --git a/compiler/ValuesUtils.ml b/compiler/ValuesUtils.ml index abbfad31..527434c1 100644 --- a/compiler/ValuesUtils.ml +++ b/compiler/ValuesUtils.ml @@ -3,6 +3,7 @@ open TypesUtils open Types open Values module TA = TypesAnalysis +include PrimitiveValuesUtils (** Utility exception *) exception FoundSymbolicValue of symbolic_value @@ -16,6 +17,9 @@ let mk_bottom (ty : ety) : typed_value = { value = Bottom; ty } let mk_abottom (ty : rty) : typed_avalue = { value = ABottom; ty } let mk_aignored (ty : rty) : typed_avalue = { value = AIgnored; ty } +let value_as_symbolic (v : value) : symbolic_value = + match v with Symbolic v -> v | _ -> raise (Failure "Unexpected") + (** Box a value *) let mk_box_value (v : typed_value) : typed_value = let box_ty = mk_box_ty v.ty in diff --git a/compiler/dune b/compiler/dune index b74b65fa..6785cad4 100644 --- a/compiler/dune +++ b/compiler/dune @@ -48,6 +48,8 @@ PrePasses Print PrintPure + PrimitiveValues + PrimitiveValuesUtils PureMicroPasses Pure PureTypeCheck @@ -67,8 +69,7 @@ TypesUtils Utils Values - ValuesUtils - PrimitiveValues)) + ValuesUtils)) (documentation (package aeneas)) |