summaryrefslogtreecommitdiff
path: root/compiler/Extract.ml
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/Extract.ml')
-rw-r--r--compiler/Extract.ml90
1 files changed, 56 insertions, 34 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index b16f9639..3cb2be2c 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -41,7 +41,9 @@ let unop_name (unop : unop) : string =
match !backend with FStar | Lean -> "not" | Coq -> "negb" | HOL4 -> "~")
| Neg (int_ty : integer_type) -> (
match !backend with Lean -> "-" | _ -> int_name int_ty ^ "_neg")
- | Cast _ -> raise (Failure "Unsupported")
+ | Cast _ ->
+ (* We never directly use the unop name in this case *)
+ raise (Failure "Unsupported")
(** Small helper to compute the name of a binary operation (note that many
binary operations like "less than" are extracted to primitive operations,
@@ -722,7 +724,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
| None -> (
(* No basename: we use the first letter of the type *)
match ty with
- | Adt (type_id, tys) -> (
+ | Adt (type_id, tys, _) -> (
match type_id with
| Tuple ->
(* The "pair" case is frequent enough to have its special treatment *)
@@ -732,6 +734,10 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
| Assumed Fuel -> ConstStrings.fuel_basename
| Assumed Option -> "opt"
| Assumed Vec -> "v"
+ | Assumed Array -> "a"
+ | Assumed Slice -> "s"
+ | Assumed Str -> "s"
+ | Assumed Range -> "r"
| Assumed State -> ConstStrings.state_basename
| AdtId adt_id ->
let def =
@@ -757,12 +763,9 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
match !backend with
| FStar -> "x" (* lacking inspiration here... *)
| Coq | Lean | HOL4 -> "t" (* lacking inspiration here... *))
- | Bool -> "b"
- | Char -> "c"
- | Integer _ -> "i"
- | Str -> "s"
- | Arrow _ -> "f"
- | Array _ | Slice _ -> raise Unimplemented)
+ | Literal lty -> (
+ match lty with Bool -> "b" | Char -> "c" | Integer _ -> "i")
+ | Arrow _ -> "f")
in
let type_var_basename (_varset : StringSet.t) (basename : string) : string =
(* Rust type variables are snake-case and start with a capital letter *)
@@ -775,12 +778,21 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
"'" ^ to_snake_case basename
| Coq | Lean -> basename
in
+ let const_generic_var_basename (_varset : StringSet.t) (basename : string) :
+ string =
+ (* Rust type variables are snake-case and start with a capital letter *)
+ match !backend with
+ | FStar | HOL4 ->
+ (* This is *not* a no-op: this removes the capital letter *)
+ to_snake_case basename
+ | Coq | Lean -> basename
+ in
let append_index (basename : string) (i : int) : string =
basename ^ string_of_int i
in
- let extract_primitive_value (fmt : F.formatter) (inside : bool)
- (cv : primitive_value) : unit =
+ let extract_literal (fmt : F.formatter) (inside : bool) (cv : literal) : unit
+ =
match cv with
| Scalar sv -> (
match !backend with
@@ -847,14 +859,6 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
in
F.pp_print_string fmt c;
if inside then F.pp_print_string fmt ")")
- | String s ->
- (* We need to replace all the line breaks *)
- let s =
- StringUtils.map
- (fun c -> if c = '\n' then "\n" else String.make 1 c)
- s
- in
- F.pp_print_string fmt ("\"" ^ s ^ "\"")
in
let bool_name = if !backend = Lean then "Bool" else "bool" in
let char_name = if !backend = Lean then "Char" else "char" in
@@ -877,8 +881,9 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
opaque_pre;
var_basename;
type_var_basename;
+ const_generic_var_basename;
append_index;
- extract_primitive_value;
+ extract_literal;
extract_unop;
extract_binop;
}
@@ -1043,6 +1048,17 @@ let extract_arrow (fmt : F.formatter) () : unit =
if !Config.backend = Lean then F.pp_print_string fmt "→"
else F.pp_print_string fmt "->"
+let extract_const_generic (ctx : extraction_ctx) (fmt : F.formatter)
+ (inside : bool) (cg : const_generic) : unit =
+ match cg with
+ | ConstGenericGlobal id ->
+ let s = ctx_get_global ctx.use_opaque_pre id ctx in
+ F.pp_print_string fmt s
+ | ConstGenericValue v -> ctx.fmt.extract_literal fmt inside v
+ | ConstGenericVar id ->
+ let s = ctx_get_const_generic_var id ctx in
+ F.pp_print_string fmt s
+
(** [inside] constrols whether we should add parentheses or not around type
applications (if [true] we add parentheses).
@@ -1067,7 +1083,8 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter)
(no_params_tys : TypeDeclId.Set.t) (inside : bool) (ty : ty) : unit =
let extract_rec = extract_ty ctx fmt no_params_tys in
match ty with
- | Adt (type_id, tys) -> (
+ | Adt (type_id, tys, cgs) -> (
+ let has_params = tys <> [] || cgs <> [] in
match type_id with
| Tuple ->
(* This is a bit annoying, but in F*/Coq/HOL4 [()] is not the unit type:
@@ -1099,7 +1116,7 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter)
let with_opaque_pre = false in
match !backend with
| FStar | Coq | Lean ->
- let print_paren = inside && tys <> [] in
+ let print_paren = inside && has_params in
if print_paren then F.pp_print_string fmt "(";
(* TODO: for now, only the opaque *functions* are extracted in the
opaque module. The opaque *types* are assumed. *)
@@ -1108,6 +1125,11 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_space fmt ();
Collections.List.iter_link (F.pp_print_space fmt)
(extract_rec true) tys);
+ if cgs <> [] then (
+ F.pp_print_space fmt ();
+ Collections.List.iter_link (F.pp_print_space fmt)
+ (extract_const_generic ctx fmt true)
+ cgs);
if print_paren then F.pp_print_string fmt ")"
| HOL4 ->
let print_tys =
@@ -1128,10 +1150,9 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_space fmt ());
F.pp_print_string fmt (ctx_get_type with_opaque_pre type_id ctx)))
| TypeVar vid -> F.pp_print_string fmt (ctx_get_type_var vid ctx)
- | Bool -> F.pp_print_string fmt ctx.fmt.bool_name
- | Char -> F.pp_print_string fmt ctx.fmt.char_name
- | Integer int_ty -> F.pp_print_string fmt (ctx.fmt.int_name int_ty)
- | Str -> F.pp_print_string fmt ctx.fmt.str_name
+ | Literal Bool -> F.pp_print_string fmt ctx.fmt.bool_name
+ | Literal Char -> F.pp_print_string fmt ctx.fmt.char_name
+ | Literal (Integer int_ty) -> F.pp_print_string fmt (ctx.fmt.int_name int_ty)
| Arrow (arg_ty, ret_ty) ->
if inside then F.pp_print_string fmt "(";
extract_rec false arg_ty;
@@ -1140,7 +1161,6 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_space fmt ();
extract_rec false ret_ty;
if inside then F.pp_print_string fmt ")"
- | Array _ | Slice _ -> raise Unimplemented
(** Compute the names for all the top-level identifiers used in a type
definition (type name, variant names, field names, etc. but not type
@@ -2016,10 +2036,11 @@ let extract_adt_g_value
(inside : bool) (variant_id : VariantId.id option) (field_values : 'v list)
(ty : ty) : extraction_ctx =
match ty with
- | Adt (Tuple, type_args) ->
+ | Adt (Tuple, type_args, cg_args) ->
(* Tuple *)
(* For now, we only support fully applied tuple constructors *)
assert (List.length type_args = List.length field_values);
+ assert (cg_args = []);
(* This is very annoying: in Coq, we can't write [()] for the value of
type [unit], we have to write [tt]. *)
if !backend = Coq && field_values = [] then (
@@ -2037,7 +2058,7 @@ let extract_adt_g_value
in
F.pp_print_string fmt ")";
ctx)
- | Adt (adt_id, _) ->
+ | Adt (adt_id, _, _) ->
(* "Regular" ADT *)
(* If we are generating a pattern for a let-binding and we target Lean,
@@ -2107,7 +2128,7 @@ let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter)
(is_let : bool) (inside : bool) (v : typed_pattern) : extraction_ctx =
match v.value with
| PatConstant cv ->
- ctx.fmt.extract_primitive_value fmt inside cv;
+ ctx.fmt.extract_literal fmt inside cv;
ctx
| PatVar (v, _) ->
let vname =
@@ -2142,7 +2163,7 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter)
| Var var_id ->
let var_name = ctx_get_var var_id ctx in
F.pp_print_string fmt var_name
- | Const cv -> ctx.fmt.extract_primitive_value fmt inside cv
+ | Const cv -> ctx.fmt.extract_literal fmt inside cv
| App _ ->
let app, args = destruct_apps e in
extract_App ctx fmt inside app args
@@ -2175,7 +2196,8 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
extract_function_call ctx fmt inside fun_id qualif.type_args args
| Global global_id -> extract_global ctx fmt global_id
| AdtCons adt_cons_id ->
- extract_adt_cons ctx fmt inside adt_cons_id qualif.type_args args
+ extract_adt_cons ctx fmt inside adt_cons_id qualif.type_args
+ qualif.const_generic_args args
| Proj proj ->
extract_field_projector ctx fmt inside app proj qualif.type_args args)
| _ ->
@@ -2250,9 +2272,9 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
(** Subcase of the app case: ADT constructor *)
and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
- (adt_cons : adt_cons_id) (type_args : ty list) (args : texpression list) :
- unit =
- let e_ty = Adt (adt_cons.adt_id, type_args) in
+ (adt_cons : adt_cons_id) (type_args : ty list)
+ (cg_args : const_generic list) (args : texpression list) : unit =
+ let e_ty = Adt (adt_cons.adt_id, type_args, cg_args) in
let is_single_pat = false in
let _ =
extract_adt_g_value