diff options
-rw-r--r-- | compiler/Extract.ml | 90 | ||||
-rw-r--r-- | compiler/ExtractBase.ml | 23 | ||||
-rw-r--r-- | compiler/PrintPure.ml | 1 | ||||
-rw-r--r-- | compiler/Pure.ml | 6 | ||||
-rw-r--r-- | compiler/SymbolicToPure.ml | 13 |
5 files changed, 80 insertions, 53 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml index b16f9639..3cb2be2c 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -41,7 +41,9 @@ let unop_name (unop : unop) : string = match !backend with FStar | Lean -> "not" | Coq -> "negb" | HOL4 -> "~") | Neg (int_ty : integer_type) -> ( match !backend with Lean -> "-" | _ -> int_name int_ty ^ "_neg") - | Cast _ -> raise (Failure "Unsupported") + | Cast _ -> + (* We never directly use the unop name in this case *) + raise (Failure "Unsupported") (** Small helper to compute the name of a binary operation (note that many binary operations like "less than" are extracted to primitive operations, @@ -722,7 +724,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) | None -> ( (* No basename: we use the first letter of the type *) match ty with - | Adt (type_id, tys) -> ( + | Adt (type_id, tys, _) -> ( match type_id with | Tuple -> (* The "pair" case is frequent enough to have its special treatment *) @@ -732,6 +734,10 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) | Assumed Fuel -> ConstStrings.fuel_basename | Assumed Option -> "opt" | Assumed Vec -> "v" + | Assumed Array -> "a" + | Assumed Slice -> "s" + | Assumed Str -> "s" + | Assumed Range -> "r" | Assumed State -> ConstStrings.state_basename | AdtId adt_id -> let def = @@ -757,12 +763,9 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) match !backend with | FStar -> "x" (* lacking inspiration here... *) | Coq | Lean | HOL4 -> "t" (* lacking inspiration here... *)) - | Bool -> "b" - | Char -> "c" - | Integer _ -> "i" - | Str -> "s" - | Arrow _ -> "f" - | Array _ | Slice _ -> raise Unimplemented) + | Literal lty -> ( + match lty with Bool -> "b" | Char -> "c" | Integer _ -> "i") + | Arrow _ -> "f") in let type_var_basename (_varset : StringSet.t) (basename : string) : string = (* Rust type variables are snake-case and start with a capital letter *) @@ -775,12 +778,21 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) "'" ^ to_snake_case basename | Coq | Lean -> basename in + let const_generic_var_basename (_varset : StringSet.t) (basename : string) : + string = + (* Rust type variables are snake-case and start with a capital letter *) + match !backend with + | FStar | HOL4 -> + (* This is *not* a no-op: this removes the capital letter *) + to_snake_case basename + | Coq | Lean -> basename + in let append_index (basename : string) (i : int) : string = basename ^ string_of_int i in - let extract_primitive_value (fmt : F.formatter) (inside : bool) - (cv : primitive_value) : unit = + let extract_literal (fmt : F.formatter) (inside : bool) (cv : literal) : unit + = match cv with | Scalar sv -> ( match !backend with @@ -847,14 +859,6 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) in F.pp_print_string fmt c; if inside then F.pp_print_string fmt ")") - | String s -> - (* We need to replace all the line breaks *) - let s = - StringUtils.map - (fun c -> if c = '\n' then "\n" else String.make 1 c) - s - in - F.pp_print_string fmt ("\"" ^ s ^ "\"") in let bool_name = if !backend = Lean then "Bool" else "bool" in let char_name = if !backend = Lean then "Char" else "char" in @@ -877,8 +881,9 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) opaque_pre; var_basename; type_var_basename; + const_generic_var_basename; append_index; - extract_primitive_value; + extract_literal; extract_unop; extract_binop; } @@ -1043,6 +1048,17 @@ let extract_arrow (fmt : F.formatter) () : unit = if !Config.backend = Lean then F.pp_print_string fmt "→" else F.pp_print_string fmt "->" +let extract_const_generic (ctx : extraction_ctx) (fmt : F.formatter) + (inside : bool) (cg : const_generic) : unit = + match cg with + | ConstGenericGlobal id -> + let s = ctx_get_global ctx.use_opaque_pre id ctx in + F.pp_print_string fmt s + | ConstGenericValue v -> ctx.fmt.extract_literal fmt inside v + | ConstGenericVar id -> + let s = ctx_get_const_generic_var id ctx in + F.pp_print_string fmt s + (** [inside] constrols whether we should add parentheses or not around type applications (if [true] we add parentheses). @@ -1067,7 +1083,8 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (no_params_tys : TypeDeclId.Set.t) (inside : bool) (ty : ty) : unit = let extract_rec = extract_ty ctx fmt no_params_tys in match ty with - | Adt (type_id, tys) -> ( + | Adt (type_id, tys, cgs) -> ( + let has_params = tys <> [] || cgs <> [] in match type_id with | Tuple -> (* This is a bit annoying, but in F*/Coq/HOL4 [()] is not the unit type: @@ -1099,7 +1116,7 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) let with_opaque_pre = false in match !backend with | FStar | Coq | Lean -> - let print_paren = inside && tys <> [] in + let print_paren = inside && has_params in if print_paren then F.pp_print_string fmt "("; (* TODO: for now, only the opaque *functions* are extracted in the opaque module. The opaque *types* are assumed. *) @@ -1108,6 +1125,11 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_space fmt (); Collections.List.iter_link (F.pp_print_space fmt) (extract_rec true) tys); + if cgs <> [] then ( + F.pp_print_space fmt (); + Collections.List.iter_link (F.pp_print_space fmt) + (extract_const_generic ctx fmt true) + cgs); if print_paren then F.pp_print_string fmt ")" | HOL4 -> let print_tys = @@ -1128,10 +1150,9 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_space fmt ()); F.pp_print_string fmt (ctx_get_type with_opaque_pre type_id ctx))) | TypeVar vid -> F.pp_print_string fmt (ctx_get_type_var vid ctx) - | Bool -> F.pp_print_string fmt ctx.fmt.bool_name - | Char -> F.pp_print_string fmt ctx.fmt.char_name - | Integer int_ty -> F.pp_print_string fmt (ctx.fmt.int_name int_ty) - | Str -> F.pp_print_string fmt ctx.fmt.str_name + | Literal Bool -> F.pp_print_string fmt ctx.fmt.bool_name + | Literal Char -> F.pp_print_string fmt ctx.fmt.char_name + | Literal (Integer int_ty) -> F.pp_print_string fmt (ctx.fmt.int_name int_ty) | Arrow (arg_ty, ret_ty) -> if inside then F.pp_print_string fmt "("; extract_rec false arg_ty; @@ -1140,7 +1161,6 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_space fmt (); extract_rec false ret_ty; if inside then F.pp_print_string fmt ")" - | Array _ | Slice _ -> raise Unimplemented (** Compute the names for all the top-level identifiers used in a type definition (type name, variant names, field names, etc. but not type @@ -2016,10 +2036,11 @@ let extract_adt_g_value (inside : bool) (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) : extraction_ctx = match ty with - | Adt (Tuple, type_args) -> + | Adt (Tuple, type_args, cg_args) -> (* Tuple *) (* For now, we only support fully applied tuple constructors *) assert (List.length type_args = List.length field_values); + assert (cg_args = []); (* This is very annoying: in Coq, we can't write [()] for the value of type [unit], we have to write [tt]. *) if !backend = Coq && field_values = [] then ( @@ -2037,7 +2058,7 @@ let extract_adt_g_value in F.pp_print_string fmt ")"; ctx) - | Adt (adt_id, _) -> + | Adt (adt_id, _, _) -> (* "Regular" ADT *) (* If we are generating a pattern for a let-binding and we target Lean, @@ -2107,7 +2128,7 @@ let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter) (is_let : bool) (inside : bool) (v : typed_pattern) : extraction_ctx = match v.value with | PatConstant cv -> - ctx.fmt.extract_primitive_value fmt inside cv; + ctx.fmt.extract_literal fmt inside cv; ctx | PatVar (v, _) -> let vname = @@ -2142,7 +2163,7 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) | Var var_id -> let var_name = ctx_get_var var_id ctx in F.pp_print_string fmt var_name - | Const cv -> ctx.fmt.extract_primitive_value fmt inside cv + | Const cv -> ctx.fmt.extract_literal fmt inside cv | App _ -> let app, args = destruct_apps e in extract_App ctx fmt inside app args @@ -2175,7 +2196,8 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) extract_function_call ctx fmt inside fun_id qualif.type_args args | Global global_id -> extract_global ctx fmt global_id | AdtCons adt_cons_id -> - extract_adt_cons ctx fmt inside adt_cons_id qualif.type_args args + extract_adt_cons ctx fmt inside adt_cons_id qualif.type_args + qualif.const_generic_args args | Proj proj -> extract_field_projector ctx fmt inside app proj qualif.type_args args) | _ -> @@ -2250,9 +2272,9 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) (** Subcase of the app case: ADT constructor *) and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) - (adt_cons : adt_cons_id) (type_args : ty list) (args : texpression list) : - unit = - let e_ty = Adt (adt_cons.adt_id, type_args) in + (adt_cons : adt_cons_id) (type_args : ty list) + (cg_args : const_generic list) (args : texpression list) : unit = + let e_ty = Adt (adt_cons.adt_id, type_args, cg_args) in let is_single_pat = false in let _ = extract_adt_g_value diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index bff6a360..3ba507a6 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -286,6 +286,8 @@ type formatter = { *) type_var_basename : StringSet.t -> string -> string; (** Generates a type variable basename. *) + const_generic_var_basename : StringSet.t -> string -> string; + (** Generates a const generic variable basename. *) append_index : string -> int -> string; (** Appends an index to a name - we use this to generate unique names: when doing so, the role of the formatter is just to concatenate @@ -392,6 +394,7 @@ type id = them here. *) | TypeVarId of TypeVarId.id + | ConstGenericVarId of ConstGenericVarId.id | VarId of VarId.id | UnknownId (** Used for stored various strings like keywords, definitions which @@ -710,6 +713,8 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = "field name: " ^ field_name | UnknownId -> "keyword" | TypeVarId id -> "type_var_id: " ^ TypeVarId.to_string id + | ConstGenericVarId id -> + "const_generic_var_id: " ^ ConstGenericVarId.to_string id | VarId id -> "var_id: " ^ VarId.to_string id (** We might not check for collisions for some specific ids (ex.: field names) *) @@ -788,6 +793,11 @@ let ctx_get_type_var (id : TypeVarId.id) (ctx : extraction_ctx) : string = let is_opaque = false in ctx_get is_opaque (TypeVarId id) ctx +let ctx_get_const_generic_var (id : ConstGenericVarId.id) (ctx : extraction_ctx) + : string = + let is_opaque = false in + ctx_get is_opaque (ConstGenericVarId id) ctx + let ctx_get_field (type_id : type_id) (field_id : FieldId.id) (ctx : extraction_ctx) : string = let is_opaque = false in @@ -823,6 +833,19 @@ let ctx_add_type_var (basename : string) (id : TypeVarId.id) let ctx = ctx_add is_opaque (TypeVarId id) name ctx in (ctx, name) +(** Generate a unique const generic variable name and add it to the context *) +let ctx_add_const_generic_var (basename : string) (id : ConstGenericVarId.id) + (ctx : extraction_ctx) : extraction_ctx * string = + let is_opaque = false in + let name = + ctx.fmt.const_generic_var_basename ctx.names_map.names_set basename + in + let name = + basename_to_unique ctx.names_map.names_set ctx.fmt.append_index name + in + let ctx = ctx_add is_opaque (ConstGenericVarId id) name ctx in + (ctx, name) + (** See {!ctx_add_type_var} *) let ctx_add_type_vars (vars : (string * TypeVarId.id) list) (ctx : extraction_ctx) : extraction_ctx * string list = diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 8fb6d644..211fb2c2 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -543,7 +543,6 @@ let unop_to_string (unop : unop) : string = | Cast (src, tgt) -> "cast<" ^ integer_type_to_string src ^ "," ^ integer_type_to_string tgt ^ ">" - | SliceNew tgt_len -> "array_to_slice<" ^ scalar_value_to_string tgt_len ^ ">" let binop_to_string = Print.Expressions.binop_to_string diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 551ebf7b..e202b170 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -344,11 +344,7 @@ and typed_pattern = { value : pattern; ty : ty } polymorphic = false; }] -type unop = - | Not - | Neg of integer_type - | Cast of integer_type * integer_type - | SliceNew of scalar_value +type unop = Not | Neg of integer_type | Cast of integer_type * integer_type [@@deriving show, ord] (** Identifiers of assumed functions that we use only in the pure translation *) diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 958c1bc8..5e47459d 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -1537,19 +1537,6 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : } in (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None) - | S.Unop (E.SliceNew tgt_len) -> - (* The cast can fail if the length of the source array is not - big enough *) - let effect_info = - { - can_fail = true; - stateful_group = false; - stateful = false; - can_diverge = false; - is_rec = false; - } - in - (ctx, Unop (SliceNew tgt_len), effect_info, args, None) | S.Binop binop -> ( match args with | [ arg0; arg1 ] -> |