From c48859717d847f4492a0c3cc76e8f8b0b38fcc10 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 13 Dec 2023 16:54:10 +0100 Subject: Update the extraction to handle casts between integers/bools --- backends/coq/Primitives.v | 4 ++ backends/fstar/Primitives.fst | 4 ++ backends/lean/Base/Primitives/Scalar.lean | 4 ++ compiler/ExtractBase.ml | 10 +++- compiler/ExtractTypes.ml | 81 +++++++++++++++++++++++-------- compiler/InterpreterExpressions.ml | 25 +++++++++- compiler/PrintPure.ml | 2 +- compiler/Pure.ml | 2 +- compiler/SymbolicToPure.ml | 2 +- 9 files changed, 107 insertions(+), 27 deletions(-) diff --git a/backends/coq/Primitives.v b/backends/coq/Primitives.v index 99ffe070..84280b96 100644 --- a/backends/coq/Primitives.v +++ b/backends/coq/Primitives.v @@ -266,6 +266,10 @@ Axiom scalar_shr : forall ty0 ty1, scalar ty0 -> scalar ty1 -> result (scalar ty Definition scalar_cast (src_ty tgt_ty : scalar_ty) (x : scalar src_ty) : result (scalar tgt_ty) := mk_scalar tgt_ty (to_Z x). +(* This can't fail, but for now we make all casts faillible (easier for the translation) *) +Definition scalar_cast_bool (tgt_ty : scalar_ty) (x : bool) : result (scalar tgt_ty) := + mk_scalar tgt_ty (if x then 1 else 0). + (** Comparisons *) Definition scalar_leb {ty : scalar_ty} (x : scalar ty) (y : scalar ty) : bool := Z.leb (to_Z x) (to_Z y) . diff --git a/backends/fstar/Primitives.fst b/backends/fstar/Primitives.fst index dd340c00..a3ffbde4 100644 --- a/backends/fstar/Primitives.fst +++ b/backends/fstar/Primitives.fst @@ -273,6 +273,10 @@ let scalar_shr (#ty0 #ty1 : scalar_ty) let scalar_cast (src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : result (scalar tgt_ty) = mk_scalar tgt_ty x +// This can't fail, but for now we make all casts faillible (easier for the translation) +let scalar_cast_bool (tgt_ty : scalar_ty) (x : bool) : result (scalar tgt_ty) = + mk_scalar tgt_ty (if x then 1 else 0) + /// The scalar types type isize : eqtype = scalar Isize type i8 : eqtype = scalar I8 diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index db522df2..a8eda6d5 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -411,6 +411,10 @@ def Scalar.or {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Scalar ty := def Scalar.cast {src_ty : ScalarTy} (tgt_ty : ScalarTy) (x : Scalar src_ty) : Result (Scalar tgt_ty) := Scalar.tryMk tgt_ty x.val +-- This can't fail, but for now we make all casts faillible (easier for the translation) +def Scalar.cast_bool (tgt_ty : ScalarTy) (x : Bool) : Result (Scalar tgt_ty) := + Scalar.tryMk tgt_ty (if x then 1 else 0) + -- The scalar types -- We declare the definitions as reducible so that Lean can unfold them (useful -- for type class resolution for instance). diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 93204515..eb2a2ec9 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -498,7 +498,7 @@ let char_name () = if !backend = Lean then "Char" else "char" let str_name () = if !backend = Lean then "String" else "string" (** Small helper to compute the name of an int type *) -let int_name (int_ty : integer_type) = +let int_name (int_ty : integer_type) : string = let isize, usize, i_format, u_format = match !backend with | FStar | Coq | HOL4 -> @@ -519,6 +519,14 @@ let int_name (int_ty : integer_type) = | U64 -> Printf.sprintf u_format 64 | U128 -> Printf.sprintf u_format 128 +let scalar_name (ty : literal_type) : string = + match ty with + | TInteger ty -> int_name ty + | TBool -> ( + match !backend with FStar | Coq | HOL4 -> "bool" | Lean -> "Bool") + | TChar -> ( + match !backend with FStar | Coq | HOL4 -> "char" | Lean -> "Char") + (** Extraction context. Note that the extraction context contains information coming from the diff --git a/compiler/ExtractTypes.ml b/compiler/ExtractTypes.ml index 785f7629..51e3fd77 100644 --- a/compiler/ExtractTypes.ml +++ b/compiler/ExtractTypes.ml @@ -107,35 +107,74 @@ let extract_unop (extract_expr : bool -> texpression -> unit) ]} *) if inside then F.pp_print_string fmt "("; - F.pp_print_string fmt ("mk_" ^ int_name tgt); + F.pp_print_string fmt ("mk_" ^ scalar_name tgt); F.pp_print_space fmt (); F.pp_print_string fmt "("; - F.pp_print_string fmt (int_name src ^ "_to_int"); + F.pp_print_string fmt (scalar_name src ^ "_to_int"); F.pp_print_space fmt (); extract_expr true arg; F.pp_print_string fmt ")"; if inside then F.pp_print_string fmt ")" | FStar | Coq | Lean -> - (* Rem.: the source type is an implicit parameter *) if inside then F.pp_print_string fmt "("; - let cast_str = - match !backend with - | Coq | FStar -> "scalar_cast" - | Lean -> (* TODO: I8.cast, I16.cast, etc.*) "Scalar.cast" - | HOL4 -> raise (Failure "Unreachable") - in - F.pp_print_string fmt cast_str; - F.pp_print_space fmt (); - if !backend <> Lean then ( - F.pp_print_string fmt - (StringUtils.capitalize_first_letter - (PrintPure.integer_type_to_string src)); - F.pp_print_space fmt ()); - if !backend = Lean then F.pp_print_string fmt ("." ^ int_name tgt) - else - F.pp_print_string fmt - (StringUtils.capitalize_first_letter - (PrintPure.integer_type_to_string tgt)); + (* Rem.: the source type is an implicit parameter *) + (* Different cases depending on the conversion *) + (let cast_str, src, tgt = + let integer_type_to_string (ty : integer_type) : string = + if !backend = Lean then "." ^ int_name ty + else + StringUtils.capitalize_first_letter + (PrintPure.integer_type_to_string ty) + in + match (src, tgt) with + | TInteger src, TInteger tgt -> + let cast_str = + match !backend with + | Coq | FStar -> "scalar_cast" + | Lean -> "Scalar.cast" + | HOL4 -> raise (Failure "Unreachable") + in + let src = + if !backend <> Lean then Some (integer_type_to_string src) + else None + in + let tgt = integer_type_to_string tgt in + (cast_str, src, Some tgt) + | TBool, TInteger tgt -> + let cast_str = + match !backend with + | Coq | FStar -> "scalar_cast_bool" + | Lean -> "Scalar.cast_bool" + | HOL4 -> raise (Failure "Unreachable") + in + let tgt = integer_type_to_string tgt in + (cast_str, None, Some tgt) + | TInteger _, TBool -> + (* This is not allowed by rustc: the way of doing it in Rust is: [x != 0] *) + raise (Failure "Unexpected cast: integer to bool") + | TBool, TBool -> + (* There shouldn't be any cast here. Note that if + one writes [b as bool] in Rust (where [b] is a + boolean), it gets compiled to [b] (i.e., no cast + is introduced). *) + raise (Failure "Unexpected cast: bool to bool") + | _ -> raise (Failure "Unreachable") + in + (* Print the name of the function *) + F.pp_print_string fmt cast_str; + (* Print the src type argument *) + (match src with + | None -> () + | Some src -> + F.pp_print_space fmt (); + F.pp_print_string fmt src); + (* Print the tgt type argument *) + match tgt with + | None -> () + | Some tgt -> + F.pp_print_space fmt (); + F.pp_print_string fmt tgt); + (* Extract the argument *) F.pp_print_space fmt (); extract_expr true arg; if inside then F.pp_print_string fmt ")") diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml index 9f117933..1b5b79dd 100644 --- a/compiler/InterpreterExpressions.ml +++ b/compiler/InterpreterExpressions.ml @@ -435,7 +435,9 @@ let eval_unary_op_concrete (config : config) (unop : unop) (op : operand) match mk_scalar sv.int_ty i with | Error _ -> cf (Error EPanic) | Ok sv -> cf (Ok { v with value = VLiteral (VScalar sv) })) - | Cast (CastInteger (src_ty, tgt_ty)), VLiteral (VScalar sv) -> ( + | ( Cast (CastScalar (TInteger src_ty, TInteger tgt_ty)), + VLiteral (VScalar sv) ) -> ( + (* Cast between integers *) assert (src_ty = sv.int_ty); let i = sv.value in match mk_scalar tgt_ty i with @@ -444,6 +446,25 @@ let eval_unary_op_concrete (config : config) (unop : unop) (op : operand) let ty = TLiteral (TInteger tgt_ty) in let value = VLiteral (VScalar sv) in cf (Ok { ty; value })) + | Cast (CastScalar (TBool, TInteger tgt_ty)), VLiteral (VBool sv) -> ( + (* Cast bool -> int *) + let i = Z.of_int (if sv then 1 else 0) in + match mk_scalar tgt_ty i with + | Error _ -> cf (Error EPanic) + | Ok sv -> + let ty = TLiteral (TInteger tgt_ty) in + let value = VLiteral (VScalar sv) in + cf (Ok { ty; value })) + | Cast (CastScalar (TInteger _, TBool)), VLiteral (VScalar sv) -> + (* Cast int -> bool *) + let b = + if Z.of_int 0 = sv.value then false + else if Z.of_int 1 = sv.value then true + else raise (Failure "Conversion from int to bool: out of range") + in + let value = VLiteral (VBool b) in + let ty = TLiteral TBool in + cf (Ok { ty; value }) | _ -> raise (Failure "Invalid input for unop") in comp eval_op apply cf @@ -461,7 +482,7 @@ let eval_unary_op_symbolic (config : config) (unop : unop) (op : operand) match (unop, v.ty) with | Not, (TLiteral TBool as lty) -> lty | Neg, (TLiteral (TInteger _) as lty) -> lty - | Cast (CastInteger (_, tgt_ty)), _ -> TLiteral (TInteger tgt_ty) + | Cast (CastScalar (_, tgt_ty)), _ -> TLiteral tgt_ty | _ -> raise (Failure "Invalid input for unop") in let res_sv = { sv_id = res_sv_id; sv_ty = res_sv_ty } in diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index d33a2f18..2fe5843e 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -519,7 +519,7 @@ let unop_to_string (unop : unop) : string = | Not -> "¬" | Neg _ -> "-" | Cast (src, tgt) -> - "cast<" ^ integer_type_to_string src ^ "," ^ integer_type_to_string tgt + "cast<" ^ literal_type_to_string src ^ "," ^ literal_type_to_string tgt ^ ">" let binop_to_string = Print.Expressions.binop_to_string diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 0ae83007..8d39cc69 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -540,7 +540,7 @@ and typed_pattern = { value : pattern; ty : ty } polymorphic = false; }] -type unop = Not | Neg of integer_type | Cast of integer_type * integer_type +type unop = Not | Neg of integer_type | Cast of literal_type * literal_type [@@deriving show, ord] (** Identifiers of assumed functions that we use only in the pure translation *) diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index bf4d26f2..84f09280 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -1762,7 +1762,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : | _ -> raise (Failure "Unreachable")) | S.Unop (E.Cast cast_kind) -> ( match cast_kind with - | CastInteger (src_ty, tgt_ty) -> + | CastScalar (src_ty, tgt_ty) -> (* Note that cast can fail *) let effect_info = { -- cgit v1.2.3