summaryrefslogtreecommitdiff
path: root/src/ExtractToFStar.ml
diff options
context:
space:
mode:
authorSon Ho2022-02-02 22:59:24 +0100
committerSon Ho2022-02-02 22:59:24 +0100
commit6739ab801801519f118cbb992b04c57f77c0cd17 (patch)
tree58caf5dc56e0d8d14ab72f553f5cc67dbeb0394e /src/ExtractToFStar.ml
parent6ee61aa87a564768d954ad767673b2b25a340516 (diff)
Make minor modifications to extract mutually recursive types
Diffstat (limited to '')
-rw-r--r--src/ExtractToFStar.ml118
1 files changed, 77 insertions, 41 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml
index 26316bc4..a1b56964 100644
--- a/src/ExtractToFStar.ml
+++ b/src/ExtractToFStar.ml
@@ -7,6 +7,19 @@ open PureToExtract
open StringUtils
module F = Format
+(** A qualifier for a type definition.
+
+ Controls whether we should use `type ...` or `and ...` (for mutually
+ recursive datatypes).
+ *)
+type type_def_qualif = Type | And
+
+(** A qualifier for function definitions.
+
+ Controls whether we should use `let ...`, `let rec ...` or `and ...`
+ *)
+type fun_def_qualif = Let | LetRec | And
+
(** A list of keywords/identifiers used in F* and with which we want to check
collision. *)
let fstar_keywords =
@@ -78,18 +91,20 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) =
| U64 -> "u64"
| U128 -> "u128"
in
- (* For now, we treat only the case where type and function names are of the
- * form: `Module::Type` and `Module:function`.
+ (* For now, we treat only the case where type names are of the
+ * form: `Module::Type`
*)
- let get_name (name : name) : string =
- match name with [ _module; name ] -> name | _ -> failwith "Unexpected"
+ let get_type_name (name : name) : string =
+ match name with
+ | [ _module; name ] -> name
+ | _ -> failwith ("Unexpected name shape: " ^ Print.name_to_string name)
in
let type_name_to_camel_case name =
- let name = get_name name in
+ let name = get_type_name name in
to_camel_case name
in
let type_name_to_snake_case name =
- let name = get_name name in
+ let name = get_type_name name in
to_snake_case name
in
let type_name name = type_name_to_snake_case name ^ "_t" in
@@ -106,12 +121,17 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) =
type_name_to_camel_case def_name ^ variant
else variant
in
- (* For now, we only treat the case where the type name is:
- * `Module::Type`
+ (* For now, we treat only the case where function names are of the
+ * form: `function` (no module prefix)
*)
+ let get_fun_name (name : name) : string =
+ match name with
+ | [ name ] -> name
+ | _ -> failwith ("Unexpected name shape: " ^ Print.name_to_string name)
+ in
let fun_name (_fid : A.fun_id) (fname : name) (num_rgs : int)
(rg : region_group_info option) : string =
- let fname = get_name fname in
+ let fname = get_fun_name fname in
(* Converting to snake case should be a no-op, but it doesn't cost much *)
let fname = to_snake_case fname in
(* Compute the suffix *)
@@ -139,7 +159,7 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) =
let def =
TypeDefId.Map.find adt_id ctx.type_context.type_defs
in
- StringUtils.string_of_chars [ (get_name def.name).[0] ])
+ StringUtils.string_of_chars [ (get_type_name def.name).[0] ])
| TypeVar _ -> "x" (* lacking imagination here... *)
| Bool -> "b"
| Char -> "c"
@@ -196,6 +216,34 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
| Str -> F.pp_print_string fmt ctx.fmt.str_name
| Array _ | Slice _ -> raise Unimplemented
+(** Compute the names for all the top-level identifiers used in a type
+ definition (type name, variant names, field names, etc. but not type
+ parameters).
+
+ We need to do this preemptively, beforce extracting any definition,
+ because of recursive definitions.
+ *)
+let extract_type_def_register_names (ctx : extraction_ctx) (def : type_def) :
+ extraction_ctx =
+ (* Compute and register the type def name *)
+ let ctx, def_name = ctx_add_type_def def ctx in
+ (* Compute and register:
+ * - the variant names, if this is an enumeration
+ * - the field names, if this is a structure
+ *)
+ let ctx =
+ match def.kind with
+ | Struct fields ->
+ fst (ctx_add_fields def (FieldId.mapi (fun id f -> (id, f)) fields) ctx)
+ | Enum variants ->
+ fst
+ (ctx_add_variants def
+ (VariantId.mapi (fun id v -> (id, v)) variants)
+ ctx)
+ in
+ (* Return *)
+ ctx
+
let extract_type_def_struct_body (ctx : extraction_ctx) (fmt : F.formatter)
(def : type_def) (fields : field list) : unit =
(* We want to generate a definition which looks like this:
@@ -354,41 +402,13 @@ let extract_type_def_enum_body (ctx : extraction_ctx) (fmt : F.formatter)
let variants = VariantId.mapi (fun vid v -> (vid, v)) variants in
List.iter (fun (vid, v) -> print_variant vid v) variants
-(** Compute the names for all the top-level identifiers used in a type
- definition (type name, variant names, field names, etc. but not type
- parameters).
-
- We need to do this preemptively, beforce extracting any definition,
- because of recursive definitions.
- *)
-let extract_type_def_register_names (ctx : extraction_ctx) (def : type_def) :
- extraction_ctx =
- (* Compute and register the type def name *)
- let ctx, def_name = ctx_add_type_def def ctx in
- (* Compute and register:
- * - the variant names, if this is an enumeration
- * - the field names, if this is a structure
- *)
- let ctx =
- match def.kind with
- | Struct fields ->
- fst (ctx_add_fields def (FieldId.mapi (fun id f -> (id, f)) fields) ctx)
- | Enum variants ->
- fst
- (ctx_add_variants def
- (VariantId.mapi (fun id v -> (id, v)) variants)
- ctx)
- in
- (* Return *)
- ctx
-
(** Extract a type definition.
Note that all the names used for extraction should already have been
registered.
*)
-let extract_type_def (ctx : extraction_ctx) (fmt : F.formatter) (def : type_def)
- : unit =
+let extract_type_def (ctx : extraction_ctx) (fmt : F.formatter)
+ (qualif : type_def_qualif) (def : type_def) : unit =
(* Retrieve the definition name *)
let def_name = ctx_get_local_type def.def_id ctx in
(* Add the type params - note that we remember those bindings only for the
@@ -406,7 +426,8 @@ let extract_type_def (ctx : extraction_ctx) (fmt : F.formatter) (def : type_def)
(* Open a box for "type TYPE_NAME (TYPE_PARAMS) =" *)
F.pp_open_hovbox fmt ctx.indent_incr;
(* > "type TYPE_NAME" *)
- F.pp_print_string fmt ("type " ^ def_name);
+ let qualif = match qualif with Type -> "type" | And -> "and" in
+ F.pp_print_string fmt (qualif ^ " " ^ def_name);
(* Print the type parameters *)
if def.type_params <> [] then (
F.pp_print_space fmt ();
@@ -433,3 +454,18 @@ let extract_type_def (ctx : extraction_ctx) (fmt : F.formatter) (def : type_def)
F.pp_close_box fmt ();
(* Add breaks to insert new lines between definitions *)
F.pp_print_break fmt 0 0
+
+(** Compute the names for all the pure functions generated from a rust function
+ (forward function and backward functions).
+ *)
+let extract_fun_def_register_names (ctx : extraction_ctx)
+ (def : pure_fun_translation) : extraction_ctx =
+ let fwd, back_ls = def in
+ (* Register the forward function name *)
+ let ctx = ctx_add_fun_def fwd ctx in
+ (* Register the backward functions' names *)
+ let ctx =
+ List.fold_left (fun ctx back -> ctx_add_fun_def back ctx) ctx back_ls
+ in
+ (* Return *)
+ ctx