summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorSon Ho2023-12-14 17:11:54 +0100
committerSon Ho2023-12-14 17:11:54 +0100
commitf074320eee2203857e669cfb72f7f8f94ab52151 (patch)
treed5f2f8d4a45f206e0a94e980ea4c6ad074f2bc19 /compiler
parentf69ac6a4a244c99a41a90ed57f74ea83b3835882 (diff)
parentc3e0b90e422cbd902ee6d2b47073940c0017b7fb (diff)
Merge remote-tracking branch 'origin/main' into son/merge_back
Diffstat (limited to '')
-rw-r--r--compiler/ExtractBase.ml10
-rw-r--r--compiler/ExtractTypes.ml81
-rw-r--r--compiler/InterpreterExpressions.ml25
-rw-r--r--compiler/PrintPure.ml2
-rw-r--r--compiler/Pure.ml2
-rw-r--r--compiler/SymbolicToPure.ml2
6 files changed, 95 insertions, 27 deletions
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index 93204515..eb2a2ec9 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -498,7 +498,7 @@ let char_name () = if !backend = Lean then "Char" else "char"
let str_name () = if !backend = Lean then "String" else "string"
(** Small helper to compute the name of an int type *)
-let int_name (int_ty : integer_type) =
+let int_name (int_ty : integer_type) : string =
let isize, usize, i_format, u_format =
match !backend with
| FStar | Coq | HOL4 ->
@@ -519,6 +519,14 @@ let int_name (int_ty : integer_type) =
| U64 -> Printf.sprintf u_format 64
| U128 -> Printf.sprintf u_format 128
+let scalar_name (ty : literal_type) : string =
+ match ty with
+ | TInteger ty -> int_name ty
+ | TBool -> (
+ match !backend with FStar | Coq | HOL4 -> "bool" | Lean -> "Bool")
+ | TChar -> (
+ match !backend with FStar | Coq | HOL4 -> "char" | Lean -> "Char")
+
(** Extraction context.
Note that the extraction context contains information coming from the
diff --git a/compiler/ExtractTypes.ml b/compiler/ExtractTypes.ml
index 785f7629..51e3fd77 100644
--- a/compiler/ExtractTypes.ml
+++ b/compiler/ExtractTypes.ml
@@ -107,35 +107,74 @@ let extract_unop (extract_expr : bool -> texpression -> unit)
]}
*)
if inside then F.pp_print_string fmt "(";
- F.pp_print_string fmt ("mk_" ^ int_name tgt);
+ F.pp_print_string fmt ("mk_" ^ scalar_name tgt);
F.pp_print_space fmt ();
F.pp_print_string fmt "(";
- F.pp_print_string fmt (int_name src ^ "_to_int");
+ F.pp_print_string fmt (scalar_name src ^ "_to_int");
F.pp_print_space fmt ();
extract_expr true arg;
F.pp_print_string fmt ")";
if inside then F.pp_print_string fmt ")"
| FStar | Coq | Lean ->
- (* Rem.: the source type is an implicit parameter *)
if inside then F.pp_print_string fmt "(";
- let cast_str =
- match !backend with
- | Coq | FStar -> "scalar_cast"
- | Lean -> (* TODO: I8.cast, I16.cast, etc.*) "Scalar.cast"
- | HOL4 -> raise (Failure "Unreachable")
- in
- F.pp_print_string fmt cast_str;
- F.pp_print_space fmt ();
- if !backend <> Lean then (
- F.pp_print_string fmt
- (StringUtils.capitalize_first_letter
- (PrintPure.integer_type_to_string src));
- F.pp_print_space fmt ());
- if !backend = Lean then F.pp_print_string fmt ("." ^ int_name tgt)
- else
- F.pp_print_string fmt
- (StringUtils.capitalize_first_letter
- (PrintPure.integer_type_to_string tgt));
+ (* Rem.: the source type is an implicit parameter *)
+ (* Different cases depending on the conversion *)
+ (let cast_str, src, tgt =
+ let integer_type_to_string (ty : integer_type) : string =
+ if !backend = Lean then "." ^ int_name ty
+ else
+ StringUtils.capitalize_first_letter
+ (PrintPure.integer_type_to_string ty)
+ in
+ match (src, tgt) with
+ | TInteger src, TInteger tgt ->
+ let cast_str =
+ match !backend with
+ | Coq | FStar -> "scalar_cast"
+ | Lean -> "Scalar.cast"
+ | HOL4 -> raise (Failure "Unreachable")
+ in
+ let src =
+ if !backend <> Lean then Some (integer_type_to_string src)
+ else None
+ in
+ let tgt = integer_type_to_string tgt in
+ (cast_str, src, Some tgt)
+ | TBool, TInteger tgt ->
+ let cast_str =
+ match !backend with
+ | Coq | FStar -> "scalar_cast_bool"
+ | Lean -> "Scalar.cast_bool"
+ | HOL4 -> raise (Failure "Unreachable")
+ in
+ let tgt = integer_type_to_string tgt in
+ (cast_str, None, Some tgt)
+ | TInteger _, TBool ->
+ (* This is not allowed by rustc: the way of doing it in Rust is: [x != 0] *)
+ raise (Failure "Unexpected cast: integer to bool")
+ | TBool, TBool ->
+ (* There shouldn't be any cast here. Note that if
+ one writes [b as bool] in Rust (where [b] is a
+ boolean), it gets compiled to [b] (i.e., no cast
+ is introduced). *)
+ raise (Failure "Unexpected cast: bool to bool")
+ | _ -> raise (Failure "Unreachable")
+ in
+ (* Print the name of the function *)
+ F.pp_print_string fmt cast_str;
+ (* Print the src type argument *)
+ (match src with
+ | None -> ()
+ | Some src ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt src);
+ (* Print the tgt type argument *)
+ match tgt with
+ | None -> ()
+ | Some tgt ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt tgt);
+ (* Extract the argument *)
F.pp_print_space fmt ();
extract_expr true arg;
if inside then F.pp_print_string fmt ")")
diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml
index 9f117933..1b5b79dd 100644
--- a/compiler/InterpreterExpressions.ml
+++ b/compiler/InterpreterExpressions.ml
@@ -435,7 +435,9 @@ let eval_unary_op_concrete (config : config) (unop : unop) (op : operand)
match mk_scalar sv.int_ty i with
| Error _ -> cf (Error EPanic)
| Ok sv -> cf (Ok { v with value = VLiteral (VScalar sv) }))
- | Cast (CastInteger (src_ty, tgt_ty)), VLiteral (VScalar sv) -> (
+ | ( Cast (CastScalar (TInteger src_ty, TInteger tgt_ty)),
+ VLiteral (VScalar sv) ) -> (
+ (* Cast between integers *)
assert (src_ty = sv.int_ty);
let i = sv.value in
match mk_scalar tgt_ty i with
@@ -444,6 +446,25 @@ let eval_unary_op_concrete (config : config) (unop : unop) (op : operand)
let ty = TLiteral (TInteger tgt_ty) in
let value = VLiteral (VScalar sv) in
cf (Ok { ty; value }))
+ | Cast (CastScalar (TBool, TInteger tgt_ty)), VLiteral (VBool sv) -> (
+ (* Cast bool -> int *)
+ let i = Z.of_int (if sv then 1 else 0) in
+ match mk_scalar tgt_ty i with
+ | Error _ -> cf (Error EPanic)
+ | Ok sv ->
+ let ty = TLiteral (TInteger tgt_ty) in
+ let value = VLiteral (VScalar sv) in
+ cf (Ok { ty; value }))
+ | Cast (CastScalar (TInteger _, TBool)), VLiteral (VScalar sv) ->
+ (* Cast int -> bool *)
+ let b =
+ if Z.of_int 0 = sv.value then false
+ else if Z.of_int 1 = sv.value then true
+ else raise (Failure "Conversion from int to bool: out of range")
+ in
+ let value = VLiteral (VBool b) in
+ let ty = TLiteral TBool in
+ cf (Ok { ty; value })
| _ -> raise (Failure "Invalid input for unop")
in
comp eval_op apply cf
@@ -461,7 +482,7 @@ let eval_unary_op_symbolic (config : config) (unop : unop) (op : operand)
match (unop, v.ty) with
| Not, (TLiteral TBool as lty) -> lty
| Neg, (TLiteral (TInteger _) as lty) -> lty
- | Cast (CastInteger (_, tgt_ty)), _ -> TLiteral (TInteger tgt_ty)
+ | Cast (CastScalar (_, tgt_ty)), _ -> TLiteral tgt_ty
| _ -> raise (Failure "Invalid input for unop")
in
let res_sv = { sv_id = res_sv_id; sv_ty = res_sv_ty } in
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index d33a2f18..2fe5843e 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -519,7 +519,7 @@ let unop_to_string (unop : unop) : string =
| Not -> "¬"
| Neg _ -> "-"
| Cast (src, tgt) ->
- "cast<" ^ integer_type_to_string src ^ "," ^ integer_type_to_string tgt
+ "cast<" ^ literal_type_to_string src ^ "," ^ literal_type_to_string tgt
^ ">"
let binop_to_string = Print.Expressions.binop_to_string
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index bb522623..c3716001 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -540,7 +540,7 @@ and typed_pattern = { value : pattern; ty : ty }
polymorphic = false;
}]
-type unop = Not | Neg of integer_type | Cast of integer_type * integer_type
+type unop = Not | Neg of integer_type | Cast of literal_type * literal_type
[@@deriving show, ord]
(** Identifiers of assumed functions that we use only in the pure translation *)
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 59205f08..1fd4896e 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -1793,7 +1793,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
| _ -> raise (Failure "Unreachable"))
| S.Unop (E.Cast cast_kind) -> (
match cast_kind with
- | CastInteger (src_ty, tgt_ty) ->
+ | CastScalar (src_ty, tgt_ty) ->
(* Note that cast can fail *)
let effect_info =
{