diff options
author | Son Ho | 2022-05-15 21:30:49 +0200 |
---|---|---|
committer | Son Ho | 2022-05-15 21:30:49 +0200 |
commit | a25d820b6eb02f573ad2c274a35e3496a9dacd40 (patch) | |
tree | d491994904b8f57b4b5ed993f61cec2127ebe20c | |
parent | f8f07a3135e69529407dfd9359197cb09e78776f (diff) |
Treat integer casts in a general manner
-rw-r--r-- | fstar/Primitives.fst | 4 | ||||
-rw-r--r-- | src/Expressions.ml | 8 | ||||
-rw-r--r-- | src/ExtractToFStar.ml | 33 | ||||
-rw-r--r-- | src/InterpreterExpressions.ml | 10 | ||||
-rw-r--r-- | src/LlbcOfJson.ml | 4 | ||||
-rw-r--r-- | src/Print.ml | 10 | ||||
-rw-r--r-- | src/PrintPure.ml | 12 | ||||
-rw-r--r-- | src/Pure.ml | 31 | ||||
-rw-r--r-- | src/StringUtils.ml | 11 | ||||
-rw-r--r-- | src/SymbolicToPure.ml | 6 | ||||
-rw-r--r-- | tests/betree/Primitives.fst | 4 | ||||
-rw-r--r-- | tests/hashmap/Primitives.fst | 4 | ||||
-rw-r--r-- | tests/hashmap_on_disk/Primitives.fst | 4 | ||||
-rw-r--r-- | tests/misc/NoNestedBorrows.fst | 4 | ||||
-rw-r--r-- | tests/misc/Primitives.fst | 4 |
15 files changed, 99 insertions, 50 deletions
diff --git a/fstar/Primitives.fst b/fstar/Primitives.fst index 77cf59aa..f73c8c09 100644 --- a/fstar/Primitives.fst +++ b/fstar/Primitives.fst @@ -145,6 +145,10 @@ let scalar_sub (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scala let scalar_mul (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scalar ty) = mk_scalar ty (x * y) +(** Cast an integer from a [src_ty] to a [tgt_ty] *) +let scalar_cast (#src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : result (scalar tgt_ty) = + mk_scalar tgt_ty x + /// The scalar types type isize : eqtype = scalar Isize type i8 : eqtype = scalar I8 diff --git a/src/Expressions.ml b/src/Expressions.ml index 61a2f95c..6bf14c66 100644 --- a/src/Expressions.ml +++ b/src/Expressions.ml @@ -18,7 +18,13 @@ type projection_elem = type projection = projection_elem list [@@deriving show] type place = { var_id : VarId.id; projection : projection } [@@deriving show] type borrow_kind = Shared | Mut | TwoPhaseMut [@@deriving show] -type unop = Not | Neg [@@deriving show, ord] + +type unop = + | Not + | Neg + | Cast of integer_type * integer_type + (** Cast an integer from a source type to a target type *) +[@@deriving show, ord] (** A binary operation diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index b5190a45..84e447a8 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -44,7 +44,10 @@ let fstar_int_name (int_ty : integer_type) = (** Small helper to compute the name of a unary operation *) let fstar_unop_name (unop : unop) : string = - match unop with Not -> "not" | Neg int_ty -> fstar_int_name int_ty ^ "_neg" + match unop with + | Not -> "not" + | Neg int_ty -> fstar_int_name int_ty ^ "_neg" + | Cast _ -> raise (Failure "Unsupported") (** Small helper to compute the name of a binary operation (note that many binary operations like "less than" are extracted to primitive operations, @@ -83,7 +86,9 @@ let fstar_keywords = "rec"; "in"; "fn"; + "val"; "int"; + "nat"; "list"; "FStar"; "FStar.Mul"; @@ -95,6 +100,7 @@ let fstar_keywords = "Type0"; "unit"; "not"; + "scalar_cast"; ] in List.concat [ named_unops; named_binops; misc ] @@ -142,12 +148,25 @@ let fstar_names_map_init = let fstar_extract_unop (extract_expr : bool -> texpression -> unit) (fmt : F.formatter) (inside : bool) (unop : unop) (arg : texpression) : unit = - let unop = fstar_unop_name unop in - if inside then F.pp_print_string fmt "("; - F.pp_print_string fmt unop; - F.pp_print_space fmt (); - extract_expr true arg; - if inside then F.pp_print_string fmt ")" + match unop with + | Not | Neg _ -> + let unop = fstar_unop_name unop in + if inside then F.pp_print_string fmt "("; + F.pp_print_string fmt unop; + F.pp_print_space fmt (); + extract_expr true arg; + if inside then F.pp_print_string fmt ")" + | Cast (_src, tgt) -> + (* The source type is an implicit parameter *) + if inside then F.pp_print_string fmt "("; + F.pp_print_string fmt "scalar_cast"; + F.pp_print_space fmt (); + F.pp_print_string fmt + (StringUtils.capitalize_first_letter + (PrintPure.integer_type_to_string tgt)); + F.pp_print_space fmt (); + extract_expr true arg; + if inside then F.pp_print_string fmt ")" let fstar_extract_binop (extract_expr : bool -> texpression -> unit) (fmt : F.formatter) (inside : bool) (binop : E.binop) diff --git a/src/InterpreterExpressions.ml b/src/InterpreterExpressions.ml index e46ca721..4549365d 100644 --- a/src/InterpreterExpressions.ml +++ b/src/InterpreterExpressions.ml @@ -274,6 +274,15 @@ let eval_unary_op_concrete (config : C.config) (unop : E.unop) (op : E.operand) match mk_scalar sv.int_ty i with | Error _ -> cf (Error EPanic) | Ok sv -> cf (Ok { v with V.value = V.Concrete (V.Scalar sv) })) + | E.Cast (src_ty, tgt_ty), V.Concrete (V.Scalar sv) -> ( + assert (src_ty == sv.int_ty); + let i = sv.V.value in + match mk_scalar tgt_ty i with + | Error _ -> cf (Error EPanic) + | Ok sv -> + let ty = T.Integer tgt_ty in + let value = V.Concrete (V.Scalar sv) in + cf (Ok { V.ty; value })) | _ -> raise (Failure "Invalid input for unop") in comp eval_op apply cf @@ -291,6 +300,7 @@ let eval_unary_op_symbolic (config : C.config) (unop : E.unop) (op : E.operand) match (unop, v.V.ty) with | E.Not, T.Bool -> T.Bool | E.Neg, T.Integer int_ty -> T.Integer int_ty + | E.Cast (_, tgt_ty), _ -> T.Integer tgt_ty | _ -> raise (Failure "Invalid input for unop") in let res_sv = diff --git a/src/LlbcOfJson.ml b/src/LlbcOfJson.ml index 32ca802e..99d652ec 100644 --- a/src/LlbcOfJson.ml +++ b/src/LlbcOfJson.ml @@ -367,6 +367,10 @@ let unop_of_json (js : json) : (E.unop, string) result = match js with | `String "Not" -> Ok E.Not | `String "Neg" -> Ok E.Neg + | `Assoc [ ("Cast", `List [ src_ty; tgt_ty ]) ] -> + let* src_ty = integer_type_of_json src_ty in + let* tgt_ty = integer_type_of_json tgt_ty in + Ok (E.Cast (src_ty, tgt_ty)) | _ -> Error ("unop_of_json failed on:" ^ show js) let binop_of_json (js : json) : (E.binop, string) result = diff --git a/src/Print.ml b/src/Print.ml index 0c4ef20a..8e29bc67 100644 --- a/src/Print.ml +++ b/src/Print.ml @@ -830,7 +830,15 @@ module LlbcAst = struct projection_to_string fmt var p.E.projection let unop_to_string (unop : E.unop) : string = - match unop with E.Not -> "¬" | E.Neg -> "-" + match unop with + | E.Not -> "¬" + | E.Neg -> "-" + | E.Cast (src, tgt) -> + "cast<" + ^ PT.integer_type_to_string src + ^ "," + ^ PT.integer_type_to_string tgt + ^ ">" let binop_to_string (binop : E.binop) : string = match binop with diff --git a/src/PrintPure.ml b/src/PrintPure.ml index 07144d3e..5e817dde 100644 --- a/src/PrintPure.ml +++ b/src/PrintPure.ml @@ -61,15 +61,10 @@ let ast_to_type_formatter (fmt : ast_formatter) : type_formatter = value_to_type_formatter fmt let name_to_string = Print.name_to_string - let fun_name_to_string = Print.fun_name_to_string - let option_to_string = Print.option_to_string - let type_var_to_string = Print.Types.type_var_to_string - let integer_type_to_string = Print.Types.integer_type_to_string - let scalar_value_to_string = Print.Values.scalar_value_to_string let mk_type_formatter (type_decls : T.type_decl TypeDeclId.Map.t) @@ -419,7 +414,12 @@ let fun_suffix (rg_id : T.RegionGroupId.id option) : string = | Some rg_id -> "@" ^ T.RegionGroupId.to_string rg_id let unop_to_string (unop : unop) : string = - match unop with Not -> "¬" | Neg _ -> "-" + match unop with + | Not -> "¬" + | Neg _ -> "-" + | Cast (src, tgt) -> + "cast<" ^ integer_type_to_string src ^ "," ^ integer_type_to_string tgt + ^ ">" let binop_to_string = Print.LlbcAst.binop_to_string diff --git a/src/Pure.ml b/src/Pure.ml index f5bed43d..5834b87f 100644 --- a/src/Pure.ml +++ b/src/Pure.ml @@ -39,11 +39,8 @@ type assumed_ty = State | Result | Vec | Option [@@deriving show, ord] * the monadic functions `return` and `fail` (makes treatment of error and * state-error monads more uniform) *) let result_return_id = VariantId.of_int 0 - let result_fail_id = VariantId.of_int 1 - let option_some_id = T.option_some_id - let option_none_id = T.option_none_id type type_id = AdtId of TypeDeclId.id | Tuple | Assumed of assumed_ty @@ -53,11 +50,8 @@ type type_id = AdtId of TypeDeclId.id | Tuple | Assumed of assumed_ty class ['self] iter_ty_base = object (_self : 'self) inherit [_] VisitorsRuntime.iter - method visit_id : 'env -> TypeVarId.id -> unit = fun _ _ -> () - method visit_type_id : 'env -> type_id -> unit = fun _ _ -> () - method visit_integer_type : 'env -> integer_type -> unit = fun _ _ -> () end @@ -65,9 +59,7 @@ class ['self] iter_ty_base = class ['self] map_ty_base = object (_self : 'self) inherit [_] VisitorsRuntime.map - method visit_id : 'env -> TypeVarId.id -> TypeVarId.id = fun _ id -> id - method visit_type_id : 'env -> type_id -> type_id = fun _ id -> id method visit_integer_type : 'env -> integer_type -> integer_type = @@ -113,7 +105,6 @@ type ty = }] type field = { field_name : string option; field_ty : ty } [@@deriving show] - type variant = { variant_name : string; fields : field list } [@@deriving show] type type_decl_kind = Struct of field list | Enum of variant list | Opaque @@ -130,7 +121,6 @@ type type_decl = { [@@deriving show] type scalar_value = V.scalar_value [@@deriving show] - type constant_value = V.constant_value [@@deriving show] type var = { @@ -176,15 +166,10 @@ type variant_id = VariantId.id [@@deriving show] class ['self] iter_value_base = object (_self : 'self) inherit [_] VisitorsRuntime.iter - method visit_constant_value : 'env -> constant_value -> unit = fun _ _ -> () - method visit_var : 'env -> var -> unit = fun _ _ -> () - method visit_mplace : 'env -> mplace -> unit = fun _ _ -> () - method visit_ty : 'env -> ty -> unit = fun _ _ -> () - method visit_variant_id : 'env -> variant_id -> unit = fun _ _ -> () end @@ -197,11 +182,8 @@ class ['self] map_value_base = fun _ x -> x method visit_var : 'env -> var -> var = fun _ x -> x - method visit_mplace : 'env -> mplace -> mplace = fun _ x -> x - method visit_ty : 'env -> ty -> ty = fun _ x -> x - method visit_variant_id : 'env -> variant_id -> variant_id = fun _ x -> x end @@ -214,11 +196,8 @@ class virtual ['self] reduce_value_base = fun _ _ -> self#zero method visit_var : 'env -> var -> 'a = fun _ _ -> self#zero - method visit_mplace : 'env -> mplace -> 'a = fun _ _ -> self#zero - method visit_ty : 'env -> ty -> 'a = fun _ _ -> self#zero - method visit_variant_id : 'env -> variant_id -> 'a = fun _ _ -> self#zero end @@ -294,7 +273,8 @@ and typed_pattern = { value : pattern; ty : ty } polymorphic = false; }] -type unop = Not | Neg of integer_type [@@deriving show, ord] +type unop = Not | Neg of integer_type | Cast of integer_type * integer_type +[@@deriving show, ord] type fun_id = | Regular of A.fun_id * T.RegionGroupId.id option @@ -341,11 +321,8 @@ type var_id = VarId.id [@@deriving show] class ['self] iter_expression_base = object (_self : 'self) inherit [_] iter_typed_pattern - method visit_integer_type : 'env -> integer_type -> unit = fun _ _ -> () - method visit_var_id : 'env -> var_id -> unit = fun _ _ -> () - method visit_qualif : 'env -> qualif -> unit = fun _ _ -> () end @@ -358,7 +335,6 @@ class ['self] map_expression_base = fun _ x -> x method visit_var_id : 'env -> var_id -> var_id = fun _ x -> x - method visit_qualif : 'env -> qualif -> qualif = fun _ x -> x end @@ -371,7 +347,6 @@ class virtual ['self] reduce_expression_base = fun _ _ -> self#zero method visit_var_id : 'env -> var_id -> 'a = fun _ _ -> self#zero - method visit_qualif : 'env -> qualif -> 'a = fun _ _ -> self#zero end @@ -451,9 +426,7 @@ type expression = | Meta of (meta[@opaque]) * texpression (** Meta-information *) and switch_body = If of texpression * texpression | Match of match_branch list - and match_branch = { pat : typed_pattern; branch : texpression } - and texpression = { e : expression; ty : ty } and mvalue = (texpression[@opaque]) diff --git a/src/StringUtils.ml b/src/StringUtils.ml index adf63151..601249ca 100644 --- a/src/StringUtils.ml +++ b/src/StringUtils.ml @@ -7,15 +7,10 @@ *) let code_0 = 48 - let code_9 = 57 - let code_A = 65 - let code_Z = 90 - let code_a = 97 - let code_z = 122 let is_lowercase_ascii (c : char) : bool = @@ -34,7 +29,6 @@ let is_digit_ascii (c : char) : bool = code_0 <= c && c <= code_9 let lowercase_ascii = Char.lowercase_ascii - let uppercase_ascii = Char.uppercase_ascii (** Using buffers as per: @@ -97,6 +91,11 @@ let map (f : char -> string) (s : string) : string = let sl = List.map string_to_chars sl in string_of_chars (List.concat sl) +let capitalize_first_letter (s : string) : string = + let s = string_to_chars s in + let s = match s with [] -> s | c :: s' -> uppercase_ascii c :: s' in + string_of_chars s + (** Unit tests *) let _ = assert (to_camel_case "hello_world" = "HelloWorld"); diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index 3ac68365..42479a6e 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -1177,6 +1177,12 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) in (ctx, Unop (Neg int_ty), effect_info, args, None) | _ -> raise (Failure "Unreachable")) + | S.Unop (E.Cast (src_ty, tgt_ty)) -> + (* Note that cast can fail *) + let effect_info = + { can_fail = true; input_state = false; output_state = false } + in + (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None) | S.Binop binop -> ( match args with | [ arg0; arg1 ] -> diff --git a/tests/betree/Primitives.fst b/tests/betree/Primitives.fst index 77cf59aa..f73c8c09 100644 --- a/tests/betree/Primitives.fst +++ b/tests/betree/Primitives.fst @@ -145,6 +145,10 @@ let scalar_sub (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scala let scalar_mul (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scalar ty) = mk_scalar ty (x * y) +(** Cast an integer from a [src_ty] to a [tgt_ty] *) +let scalar_cast (#src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : result (scalar tgt_ty) = + mk_scalar tgt_ty x + /// The scalar types type isize : eqtype = scalar Isize type i8 : eqtype = scalar I8 diff --git a/tests/hashmap/Primitives.fst b/tests/hashmap/Primitives.fst index 77cf59aa..f73c8c09 100644 --- a/tests/hashmap/Primitives.fst +++ b/tests/hashmap/Primitives.fst @@ -145,6 +145,10 @@ let scalar_sub (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scala let scalar_mul (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scalar ty) = mk_scalar ty (x * y) +(** Cast an integer from a [src_ty] to a [tgt_ty] *) +let scalar_cast (#src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : result (scalar tgt_ty) = + mk_scalar tgt_ty x + /// The scalar types type isize : eqtype = scalar Isize type i8 : eqtype = scalar I8 diff --git a/tests/hashmap_on_disk/Primitives.fst b/tests/hashmap_on_disk/Primitives.fst index 77cf59aa..f73c8c09 100644 --- a/tests/hashmap_on_disk/Primitives.fst +++ b/tests/hashmap_on_disk/Primitives.fst @@ -145,6 +145,10 @@ let scalar_sub (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scala let scalar_mul (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scalar ty) = mk_scalar ty (x * y) +(** Cast an integer from a [src_ty] to a [tgt_ty] *) +let scalar_cast (#src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : result (scalar tgt_ty) = + mk_scalar tgt_ty x + /// The scalar types type isize : eqtype = scalar Isize type i8 : eqtype = scalar I8 diff --git a/tests/misc/NoNestedBorrows.fst b/tests/misc/NoNestedBorrows.fst index c83656bf..97688191 100644 --- a/tests/misc/NoNestedBorrows.fst +++ b/tests/misc/NoNestedBorrows.fst @@ -54,6 +54,10 @@ let div_test1_fwd (x : u32) : result u32 = let rem_test_fwd (x : u32) (y : u32) : result u32 = begin match u32_rem x y with | Fail -> Fail | Return i -> Return i end +(** [no_nested_borrows::cast_test] *) +let cast_test_fwd (x : u32) : result i32 = + begin match scalar_cast I32 x with | Fail -> Fail | Return i -> Return i end + (** [no_nested_borrows::test2] *) let test2_fwd : result unit = Return () diff --git a/tests/misc/Primitives.fst b/tests/misc/Primitives.fst index 77cf59aa..f73c8c09 100644 --- a/tests/misc/Primitives.fst +++ b/tests/misc/Primitives.fst @@ -145,6 +145,10 @@ let scalar_sub (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scala let scalar_mul (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scalar ty) = mk_scalar ty (x * y) +(** Cast an integer from a [src_ty] to a [tgt_ty] *) +let scalar_cast (#src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : result (scalar tgt_ty) = + mk_scalar tgt_ty x + /// The scalar types type isize : eqtype = scalar Isize type i8 : eqtype = scalar I8 |