summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--fstar/Primitives.fst4
-rw-r--r--src/Expressions.ml8
-rw-r--r--src/ExtractToFStar.ml33
-rw-r--r--src/InterpreterExpressions.ml10
-rw-r--r--src/LlbcOfJson.ml4
-rw-r--r--src/Print.ml10
-rw-r--r--src/PrintPure.ml12
-rw-r--r--src/Pure.ml31
-rw-r--r--src/StringUtils.ml11
-rw-r--r--src/SymbolicToPure.ml6
-rw-r--r--tests/betree/Primitives.fst4
-rw-r--r--tests/hashmap/Primitives.fst4
-rw-r--r--tests/hashmap_on_disk/Primitives.fst4
-rw-r--r--tests/misc/NoNestedBorrows.fst4
-rw-r--r--tests/misc/Primitives.fst4
15 files changed, 99 insertions, 50 deletions
diff --git a/fstar/Primitives.fst b/fstar/Primitives.fst
index 77cf59aa..f73c8c09 100644
--- a/fstar/Primitives.fst
+++ b/fstar/Primitives.fst
@@ -145,6 +145,10 @@ let scalar_sub (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scala
let scalar_mul (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scalar ty) =
mk_scalar ty (x * y)
+(** Cast an integer from a [src_ty] to a [tgt_ty] *)
+let scalar_cast (#src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : result (scalar tgt_ty) =
+ mk_scalar tgt_ty x
+
/// The scalar types
type isize : eqtype = scalar Isize
type i8 : eqtype = scalar I8
diff --git a/src/Expressions.ml b/src/Expressions.ml
index 61a2f95c..6bf14c66 100644
--- a/src/Expressions.ml
+++ b/src/Expressions.ml
@@ -18,7 +18,13 @@ type projection_elem =
type projection = projection_elem list [@@deriving show]
type place = { var_id : VarId.id; projection : projection } [@@deriving show]
type borrow_kind = Shared | Mut | TwoPhaseMut [@@deriving show]
-type unop = Not | Neg [@@deriving show, ord]
+
+type unop =
+ | Not
+ | Neg
+ | Cast of integer_type * integer_type
+ (** Cast an integer from a source type to a target type *)
+[@@deriving show, ord]
(** A binary operation
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml
index b5190a45..84e447a8 100644
--- a/src/ExtractToFStar.ml
+++ b/src/ExtractToFStar.ml
@@ -44,7 +44,10 @@ let fstar_int_name (int_ty : integer_type) =
(** Small helper to compute the name of a unary operation *)
let fstar_unop_name (unop : unop) : string =
- match unop with Not -> "not" | Neg int_ty -> fstar_int_name int_ty ^ "_neg"
+ match unop with
+ | Not -> "not"
+ | Neg int_ty -> fstar_int_name int_ty ^ "_neg"
+ | Cast _ -> raise (Failure "Unsupported")
(** Small helper to compute the name of a binary operation (note that many
binary operations like "less than" are extracted to primitive operations,
@@ -83,7 +86,9 @@ let fstar_keywords =
"rec";
"in";
"fn";
+ "val";
"int";
+ "nat";
"list";
"FStar";
"FStar.Mul";
@@ -95,6 +100,7 @@ let fstar_keywords =
"Type0";
"unit";
"not";
+ "scalar_cast";
]
in
List.concat [ named_unops; named_binops; misc ]
@@ -142,12 +148,25 @@ let fstar_names_map_init =
let fstar_extract_unop (extract_expr : bool -> texpression -> unit)
(fmt : F.formatter) (inside : bool) (unop : unop) (arg : texpression) : unit
=
- let unop = fstar_unop_name unop in
- if inside then F.pp_print_string fmt "(";
- F.pp_print_string fmt unop;
- F.pp_print_space fmt ();
- extract_expr true arg;
- if inside then F.pp_print_string fmt ")"
+ match unop with
+ | Not | Neg _ ->
+ let unop = fstar_unop_name unop in
+ if inside then F.pp_print_string fmt "(";
+ F.pp_print_string fmt unop;
+ F.pp_print_space fmt ();
+ extract_expr true arg;
+ if inside then F.pp_print_string fmt ")"
+ | Cast (_src, tgt) ->
+ (* The source type is an implicit parameter *)
+ if inside then F.pp_print_string fmt "(";
+ F.pp_print_string fmt "scalar_cast";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt
+ (StringUtils.capitalize_first_letter
+ (PrintPure.integer_type_to_string tgt));
+ F.pp_print_space fmt ();
+ extract_expr true arg;
+ if inside then F.pp_print_string fmt ")"
let fstar_extract_binop (extract_expr : bool -> texpression -> unit)
(fmt : F.formatter) (inside : bool) (binop : E.binop)
diff --git a/src/InterpreterExpressions.ml b/src/InterpreterExpressions.ml
index e46ca721..4549365d 100644
--- a/src/InterpreterExpressions.ml
+++ b/src/InterpreterExpressions.ml
@@ -274,6 +274,15 @@ let eval_unary_op_concrete (config : C.config) (unop : E.unop) (op : E.operand)
match mk_scalar sv.int_ty i with
| Error _ -> cf (Error EPanic)
| Ok sv -> cf (Ok { v with V.value = V.Concrete (V.Scalar sv) }))
+ | E.Cast (src_ty, tgt_ty), V.Concrete (V.Scalar sv) -> (
+ assert (src_ty == sv.int_ty);
+ let i = sv.V.value in
+ match mk_scalar tgt_ty i with
+ | Error _ -> cf (Error EPanic)
+ | Ok sv ->
+ let ty = T.Integer tgt_ty in
+ let value = V.Concrete (V.Scalar sv) in
+ cf (Ok { V.ty; value }))
| _ -> raise (Failure "Invalid input for unop")
in
comp eval_op apply cf
@@ -291,6 +300,7 @@ let eval_unary_op_symbolic (config : C.config) (unop : E.unop) (op : E.operand)
match (unop, v.V.ty) with
| E.Not, T.Bool -> T.Bool
| E.Neg, T.Integer int_ty -> T.Integer int_ty
+ | E.Cast (_, tgt_ty), _ -> T.Integer tgt_ty
| _ -> raise (Failure "Invalid input for unop")
in
let res_sv =
diff --git a/src/LlbcOfJson.ml b/src/LlbcOfJson.ml
index 32ca802e..99d652ec 100644
--- a/src/LlbcOfJson.ml
+++ b/src/LlbcOfJson.ml
@@ -367,6 +367,10 @@ let unop_of_json (js : json) : (E.unop, string) result =
match js with
| `String "Not" -> Ok E.Not
| `String "Neg" -> Ok E.Neg
+ | `Assoc [ ("Cast", `List [ src_ty; tgt_ty ]) ] ->
+ let* src_ty = integer_type_of_json src_ty in
+ let* tgt_ty = integer_type_of_json tgt_ty in
+ Ok (E.Cast (src_ty, tgt_ty))
| _ -> Error ("unop_of_json failed on:" ^ show js)
let binop_of_json (js : json) : (E.binop, string) result =
diff --git a/src/Print.ml b/src/Print.ml
index 0c4ef20a..8e29bc67 100644
--- a/src/Print.ml
+++ b/src/Print.ml
@@ -830,7 +830,15 @@ module LlbcAst = struct
projection_to_string fmt var p.E.projection
let unop_to_string (unop : E.unop) : string =
- match unop with E.Not -> "¬" | E.Neg -> "-"
+ match unop with
+ | E.Not -> "¬"
+ | E.Neg -> "-"
+ | E.Cast (src, tgt) ->
+ "cast<"
+ ^ PT.integer_type_to_string src
+ ^ ","
+ ^ PT.integer_type_to_string tgt
+ ^ ">"
let binop_to_string (binop : E.binop) : string =
match binop with
diff --git a/src/PrintPure.ml b/src/PrintPure.ml
index 07144d3e..5e817dde 100644
--- a/src/PrintPure.ml
+++ b/src/PrintPure.ml
@@ -61,15 +61,10 @@ let ast_to_type_formatter (fmt : ast_formatter) : type_formatter =
value_to_type_formatter fmt
let name_to_string = Print.name_to_string
-
let fun_name_to_string = Print.fun_name_to_string
-
let option_to_string = Print.option_to_string
-
let type_var_to_string = Print.Types.type_var_to_string
-
let integer_type_to_string = Print.Types.integer_type_to_string
-
let scalar_value_to_string = Print.Values.scalar_value_to_string
let mk_type_formatter (type_decls : T.type_decl TypeDeclId.Map.t)
@@ -419,7 +414,12 @@ let fun_suffix (rg_id : T.RegionGroupId.id option) : string =
| Some rg_id -> "@" ^ T.RegionGroupId.to_string rg_id
let unop_to_string (unop : unop) : string =
- match unop with Not -> "¬" | Neg _ -> "-"
+ match unop with
+ | Not -> "¬"
+ | Neg _ -> "-"
+ | Cast (src, tgt) ->
+ "cast<" ^ integer_type_to_string src ^ "," ^ integer_type_to_string tgt
+ ^ ">"
let binop_to_string = Print.LlbcAst.binop_to_string
diff --git a/src/Pure.ml b/src/Pure.ml
index f5bed43d..5834b87f 100644
--- a/src/Pure.ml
+++ b/src/Pure.ml
@@ -39,11 +39,8 @@ type assumed_ty = State | Result | Vec | Option [@@deriving show, ord]
* the monadic functions `return` and `fail` (makes treatment of error and
* state-error monads more uniform) *)
let result_return_id = VariantId.of_int 0
-
let result_fail_id = VariantId.of_int 1
-
let option_some_id = T.option_some_id
-
let option_none_id = T.option_none_id
type type_id = AdtId of TypeDeclId.id | Tuple | Assumed of assumed_ty
@@ -53,11 +50,8 @@ type type_id = AdtId of TypeDeclId.id | Tuple | Assumed of assumed_ty
class ['self] iter_ty_base =
object (_self : 'self)
inherit [_] VisitorsRuntime.iter
-
method visit_id : 'env -> TypeVarId.id -> unit = fun _ _ -> ()
-
method visit_type_id : 'env -> type_id -> unit = fun _ _ -> ()
-
method visit_integer_type : 'env -> integer_type -> unit = fun _ _ -> ()
end
@@ -65,9 +59,7 @@ class ['self] iter_ty_base =
class ['self] map_ty_base =
object (_self : 'self)
inherit [_] VisitorsRuntime.map
-
method visit_id : 'env -> TypeVarId.id -> TypeVarId.id = fun _ id -> id
-
method visit_type_id : 'env -> type_id -> type_id = fun _ id -> id
method visit_integer_type : 'env -> integer_type -> integer_type =
@@ -113,7 +105,6 @@ type ty =
}]
type field = { field_name : string option; field_ty : ty } [@@deriving show]
-
type variant = { variant_name : string; fields : field list } [@@deriving show]
type type_decl_kind = Struct of field list | Enum of variant list | Opaque
@@ -130,7 +121,6 @@ type type_decl = {
[@@deriving show]
type scalar_value = V.scalar_value [@@deriving show]
-
type constant_value = V.constant_value [@@deriving show]
type var = {
@@ -176,15 +166,10 @@ type variant_id = VariantId.id [@@deriving show]
class ['self] iter_value_base =
object (_self : 'self)
inherit [_] VisitorsRuntime.iter
-
method visit_constant_value : 'env -> constant_value -> unit = fun _ _ -> ()
-
method visit_var : 'env -> var -> unit = fun _ _ -> ()
-
method visit_mplace : 'env -> mplace -> unit = fun _ _ -> ()
-
method visit_ty : 'env -> ty -> unit = fun _ _ -> ()
-
method visit_variant_id : 'env -> variant_id -> unit = fun _ _ -> ()
end
@@ -197,11 +182,8 @@ class ['self] map_value_base =
fun _ x -> x
method visit_var : 'env -> var -> var = fun _ x -> x
-
method visit_mplace : 'env -> mplace -> mplace = fun _ x -> x
-
method visit_ty : 'env -> ty -> ty = fun _ x -> x
-
method visit_variant_id : 'env -> variant_id -> variant_id = fun _ x -> x
end
@@ -214,11 +196,8 @@ class virtual ['self] reduce_value_base =
fun _ _ -> self#zero
method visit_var : 'env -> var -> 'a = fun _ _ -> self#zero
-
method visit_mplace : 'env -> mplace -> 'a = fun _ _ -> self#zero
-
method visit_ty : 'env -> ty -> 'a = fun _ _ -> self#zero
-
method visit_variant_id : 'env -> variant_id -> 'a = fun _ _ -> self#zero
end
@@ -294,7 +273,8 @@ and typed_pattern = { value : pattern; ty : ty }
polymorphic = false;
}]
-type unop = Not | Neg of integer_type [@@deriving show, ord]
+type unop = Not | Neg of integer_type | Cast of integer_type * integer_type
+[@@deriving show, ord]
type fun_id =
| Regular of A.fun_id * T.RegionGroupId.id option
@@ -341,11 +321,8 @@ type var_id = VarId.id [@@deriving show]
class ['self] iter_expression_base =
object (_self : 'self)
inherit [_] iter_typed_pattern
-
method visit_integer_type : 'env -> integer_type -> unit = fun _ _ -> ()
-
method visit_var_id : 'env -> var_id -> unit = fun _ _ -> ()
-
method visit_qualif : 'env -> qualif -> unit = fun _ _ -> ()
end
@@ -358,7 +335,6 @@ class ['self] map_expression_base =
fun _ x -> x
method visit_var_id : 'env -> var_id -> var_id = fun _ x -> x
-
method visit_qualif : 'env -> qualif -> qualif = fun _ x -> x
end
@@ -371,7 +347,6 @@ class virtual ['self] reduce_expression_base =
fun _ _ -> self#zero
method visit_var_id : 'env -> var_id -> 'a = fun _ _ -> self#zero
-
method visit_qualif : 'env -> qualif -> 'a = fun _ _ -> self#zero
end
@@ -451,9 +426,7 @@ type expression =
| Meta of (meta[@opaque]) * texpression (** Meta-information *)
and switch_body = If of texpression * texpression | Match of match_branch list
-
and match_branch = { pat : typed_pattern; branch : texpression }
-
and texpression = { e : expression; ty : ty }
and mvalue = (texpression[@opaque])
diff --git a/src/StringUtils.ml b/src/StringUtils.ml
index adf63151..601249ca 100644
--- a/src/StringUtils.ml
+++ b/src/StringUtils.ml
@@ -7,15 +7,10 @@
*)
let code_0 = 48
-
let code_9 = 57
-
let code_A = 65
-
let code_Z = 90
-
let code_a = 97
-
let code_z = 122
let is_lowercase_ascii (c : char) : bool =
@@ -34,7 +29,6 @@ let is_digit_ascii (c : char) : bool =
code_0 <= c && c <= code_9
let lowercase_ascii = Char.lowercase_ascii
-
let uppercase_ascii = Char.uppercase_ascii
(** Using buffers as per:
@@ -97,6 +91,11 @@ let map (f : char -> string) (s : string) : string =
let sl = List.map string_to_chars sl in
string_of_chars (List.concat sl)
+let capitalize_first_letter (s : string) : string =
+ let s = string_to_chars s in
+ let s = match s with [] -> s | c :: s' -> uppercase_ascii c :: s' in
+ string_of_chars s
+
(** Unit tests *)
let _ =
assert (to_camel_case "hello_world" = "HelloWorld");
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index 3ac68365..42479a6e 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -1177,6 +1177,12 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression)
in
(ctx, Unop (Neg int_ty), effect_info, args, None)
| _ -> raise (Failure "Unreachable"))
+ | S.Unop (E.Cast (src_ty, tgt_ty)) ->
+ (* Note that cast can fail *)
+ let effect_info =
+ { can_fail = true; input_state = false; output_state = false }
+ in
+ (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None)
| S.Binop binop -> (
match args with
| [ arg0; arg1 ] ->
diff --git a/tests/betree/Primitives.fst b/tests/betree/Primitives.fst
index 77cf59aa..f73c8c09 100644
--- a/tests/betree/Primitives.fst
+++ b/tests/betree/Primitives.fst
@@ -145,6 +145,10 @@ let scalar_sub (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scala
let scalar_mul (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scalar ty) =
mk_scalar ty (x * y)
+(** Cast an integer from a [src_ty] to a [tgt_ty] *)
+let scalar_cast (#src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : result (scalar tgt_ty) =
+ mk_scalar tgt_ty x
+
/// The scalar types
type isize : eqtype = scalar Isize
type i8 : eqtype = scalar I8
diff --git a/tests/hashmap/Primitives.fst b/tests/hashmap/Primitives.fst
index 77cf59aa..f73c8c09 100644
--- a/tests/hashmap/Primitives.fst
+++ b/tests/hashmap/Primitives.fst
@@ -145,6 +145,10 @@ let scalar_sub (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scala
let scalar_mul (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scalar ty) =
mk_scalar ty (x * y)
+(** Cast an integer from a [src_ty] to a [tgt_ty] *)
+let scalar_cast (#src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : result (scalar tgt_ty) =
+ mk_scalar tgt_ty x
+
/// The scalar types
type isize : eqtype = scalar Isize
type i8 : eqtype = scalar I8
diff --git a/tests/hashmap_on_disk/Primitives.fst b/tests/hashmap_on_disk/Primitives.fst
index 77cf59aa..f73c8c09 100644
--- a/tests/hashmap_on_disk/Primitives.fst
+++ b/tests/hashmap_on_disk/Primitives.fst
@@ -145,6 +145,10 @@ let scalar_sub (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scala
let scalar_mul (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scalar ty) =
mk_scalar ty (x * y)
+(** Cast an integer from a [src_ty] to a [tgt_ty] *)
+let scalar_cast (#src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : result (scalar tgt_ty) =
+ mk_scalar tgt_ty x
+
/// The scalar types
type isize : eqtype = scalar Isize
type i8 : eqtype = scalar I8
diff --git a/tests/misc/NoNestedBorrows.fst b/tests/misc/NoNestedBorrows.fst
index c83656bf..97688191 100644
--- a/tests/misc/NoNestedBorrows.fst
+++ b/tests/misc/NoNestedBorrows.fst
@@ -54,6 +54,10 @@ let div_test1_fwd (x : u32) : result u32 =
let rem_test_fwd (x : u32) (y : u32) : result u32 =
begin match u32_rem x y with | Fail -> Fail | Return i -> Return i end
+(** [no_nested_borrows::cast_test] *)
+let cast_test_fwd (x : u32) : result i32 =
+ begin match scalar_cast I32 x with | Fail -> Fail | Return i -> Return i end
+
(** [no_nested_borrows::test2] *)
let test2_fwd : result unit = Return ()
diff --git a/tests/misc/Primitives.fst b/tests/misc/Primitives.fst
index 77cf59aa..f73c8c09 100644
--- a/tests/misc/Primitives.fst
+++ b/tests/misc/Primitives.fst
@@ -145,6 +145,10 @@ let scalar_sub (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scala
let scalar_mul (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scalar ty) =
mk_scalar ty (x * y)
+(** Cast an integer from a [src_ty] to a [tgt_ty] *)
+let scalar_cast (#src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : result (scalar tgt_ty) =
+ mk_scalar tgt_ty x
+
/// The scalar types
type isize : eqtype = scalar Isize
type i8 : eqtype = scalar I8