summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorSon Ho2022-02-02 22:59:24 +0100
committerSon Ho2022-02-02 22:59:24 +0100
commit6739ab801801519f118cbb992b04c57f77c0cd17 (patch)
tree58caf5dc56e0d8d14ab72f553f5cc67dbeb0394e /src
parent6ee61aa87a564768d954ad767673b2b25a340516 (diff)
Make minor modifications to extract mutually recursive types
Diffstat (limited to 'src')
-rw-r--r--src/ExtractToFStar.ml118
-rw-r--r--src/PureToExtract.ml26
-rw-r--r--src/Translate.ml24
3 files changed, 122 insertions, 46 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml
index 26316bc4..a1b56964 100644
--- a/src/ExtractToFStar.ml
+++ b/src/ExtractToFStar.ml
@@ -7,6 +7,19 @@ open PureToExtract
open StringUtils
module F = Format
+(** A qualifier for a type definition.
+
+ Controls whether we should use `type ...` or `and ...` (for mutually
+ recursive datatypes).
+ *)
+type type_def_qualif = Type | And
+
+(** A qualifier for function definitions.
+
+ Controls whether we should use `let ...`, `let rec ...` or `and ...`
+ *)
+type fun_def_qualif = Let | LetRec | And
+
(** A list of keywords/identifiers used in F* and with which we want to check
collision. *)
let fstar_keywords =
@@ -78,18 +91,20 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) =
| U64 -> "u64"
| U128 -> "u128"
in
- (* For now, we treat only the case where type and function names are of the
- * form: `Module::Type` and `Module:function`.
+ (* For now, we treat only the case where type names are of the
+ * form: `Module::Type`
*)
- let get_name (name : name) : string =
- match name with [ _module; name ] -> name | _ -> failwith "Unexpected"
+ let get_type_name (name : name) : string =
+ match name with
+ | [ _module; name ] -> name
+ | _ -> failwith ("Unexpected name shape: " ^ Print.name_to_string name)
in
let type_name_to_camel_case name =
- let name = get_name name in
+ let name = get_type_name name in
to_camel_case name
in
let type_name_to_snake_case name =
- let name = get_name name in
+ let name = get_type_name name in
to_snake_case name
in
let type_name name = type_name_to_snake_case name ^ "_t" in
@@ -106,12 +121,17 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) =
type_name_to_camel_case def_name ^ variant
else variant
in
- (* For now, we only treat the case where the type name is:
- * `Module::Type`
+ (* For now, we treat only the case where function names are of the
+ * form: `function` (no module prefix)
*)
+ let get_fun_name (name : name) : string =
+ match name with
+ | [ name ] -> name
+ | _ -> failwith ("Unexpected name shape: " ^ Print.name_to_string name)
+ in
let fun_name (_fid : A.fun_id) (fname : name) (num_rgs : int)
(rg : region_group_info option) : string =
- let fname = get_name fname in
+ let fname = get_fun_name fname in
(* Converting to snake case should be a no-op, but it doesn't cost much *)
let fname = to_snake_case fname in
(* Compute the suffix *)
@@ -139,7 +159,7 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) =
let def =
TypeDefId.Map.find adt_id ctx.type_context.type_defs
in
- StringUtils.string_of_chars [ (get_name def.name).[0] ])
+ StringUtils.string_of_chars [ (get_type_name def.name).[0] ])
| TypeVar _ -> "x" (* lacking imagination here... *)
| Bool -> "b"
| Char -> "c"
@@ -196,6 +216,34 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
| Str -> F.pp_print_string fmt ctx.fmt.str_name
| Array _ | Slice _ -> raise Unimplemented
+(** Compute the names for all the top-level identifiers used in a type
+ definition (type name, variant names, field names, etc. but not type
+ parameters).
+
+ We need to do this preemptively, beforce extracting any definition,
+ because of recursive definitions.
+ *)
+let extract_type_def_register_names (ctx : extraction_ctx) (def : type_def) :
+ extraction_ctx =
+ (* Compute and register the type def name *)
+ let ctx, def_name = ctx_add_type_def def ctx in
+ (* Compute and register:
+ * - the variant names, if this is an enumeration
+ * - the field names, if this is a structure
+ *)
+ let ctx =
+ match def.kind with
+ | Struct fields ->
+ fst (ctx_add_fields def (FieldId.mapi (fun id f -> (id, f)) fields) ctx)
+ | Enum variants ->
+ fst
+ (ctx_add_variants def
+ (VariantId.mapi (fun id v -> (id, v)) variants)
+ ctx)
+ in
+ (* Return *)
+ ctx
+
let extract_type_def_struct_body (ctx : extraction_ctx) (fmt : F.formatter)
(def : type_def) (fields : field list) : unit =
(* We want to generate a definition which looks like this:
@@ -354,41 +402,13 @@ let extract_type_def_enum_body (ctx : extraction_ctx) (fmt : F.formatter)
let variants = VariantId.mapi (fun vid v -> (vid, v)) variants in
List.iter (fun (vid, v) -> print_variant vid v) variants
-(** Compute the names for all the top-level identifiers used in a type
- definition (type name, variant names, field names, etc. but not type
- parameters).
-
- We need to do this preemptively, beforce extracting any definition,
- because of recursive definitions.
- *)
-let extract_type_def_register_names (ctx : extraction_ctx) (def : type_def) :
- extraction_ctx =
- (* Compute and register the type def name *)
- let ctx, def_name = ctx_add_type_def def ctx in
- (* Compute and register:
- * - the variant names, if this is an enumeration
- * - the field names, if this is a structure
- *)
- let ctx =
- match def.kind with
- | Struct fields ->
- fst (ctx_add_fields def (FieldId.mapi (fun id f -> (id, f)) fields) ctx)
- | Enum variants ->
- fst
- (ctx_add_variants def
- (VariantId.mapi (fun id v -> (id, v)) variants)
- ctx)
- in
- (* Return *)
- ctx
-
(** Extract a type definition.
Note that all the names used for extraction should already have been
registered.
*)
-let extract_type_def (ctx : extraction_ctx) (fmt : F.formatter) (def : type_def)
- : unit =
+let extract_type_def (ctx : extraction_ctx) (fmt : F.formatter)
+ (qualif : type_def_qualif) (def : type_def) : unit =
(* Retrieve the definition name *)
let def_name = ctx_get_local_type def.def_id ctx in
(* Add the type params - note that we remember those bindings only for the
@@ -406,7 +426,8 @@ let extract_type_def (ctx : extraction_ctx) (fmt : F.formatter) (def : type_def)
(* Open a box for "type TYPE_NAME (TYPE_PARAMS) =" *)
F.pp_open_hovbox fmt ctx.indent_incr;
(* > "type TYPE_NAME" *)
- F.pp_print_string fmt ("type " ^ def_name);
+ let qualif = match qualif with Type -> "type" | And -> "and" in
+ F.pp_print_string fmt (qualif ^ " " ^ def_name);
(* Print the type parameters *)
if def.type_params <> [] then (
F.pp_print_space fmt ();
@@ -433,3 +454,18 @@ let extract_type_def (ctx : extraction_ctx) (fmt : F.formatter) (def : type_def)
F.pp_close_box fmt ();
(* Add breaks to insert new lines between definitions *)
F.pp_print_break fmt 0 0
+
+(** Compute the names for all the pure functions generated from a rust function
+ (forward function and backward functions).
+ *)
+let extract_fun_def_register_names (ctx : extraction_ctx)
+ (def : pure_fun_translation) : extraction_ctx =
+ let fwd, back_ls = def in
+ (* Register the forward function name *)
+ let ctx = ctx_add_fun_def fwd ctx in
+ (* Register the backward functions' names *)
+ let ctx =
+ List.fold_left (fun ctx back -> ctx_add_fun_def back ctx) ctx back_ls
+ in
+ (* Return *)
+ ctx
diff --git a/src/PureToExtract.ml b/src/PureToExtract.ml
index c36ed8fe..226f178a 100644
--- a/src/PureToExtract.ml
+++ b/src/PureToExtract.ml
@@ -386,3 +386,29 @@ let ctx_add_variants (def : type_def) (variants : (VariantId.id * variant) list)
List.fold_left_map
(fun ctx (vid, v) -> ctx_add_variant def vid v ctx)
ctx variants
+
+let ctx_add_fun_def (def : fun_def) (ctx : extraction_ctx) : extraction_ctx =
+ (* Lookup the CFIM def to compute the region group information *)
+ let def_id = def.def_id in
+ let cfim_def = FunDefId.Map.find def_id ctx.trans_ctx.fun_context.fun_defs in
+ let sg = cfim_def.signature in
+ let num_rgs = List.length sg.regions_hierarchy in
+ let rg_info =
+ match def.back_id with
+ | None -> None
+ | Some rg_id ->
+ let rg = T.RegionGroupId.nth sg.regions_hierarchy rg_id in
+ let regions =
+ List.map
+ (fun rid -> T.RegionVarId.nth sg.region_params rid)
+ rg.regions
+ in
+ let region_names =
+ List.map (fun (r : T.region_var) -> r.name) regions
+ in
+ Some { id = rg_id; region_names }
+ in
+ let def_id = A.Local def_id in
+ let name = ctx.fmt.fun_name def_id def.basename num_rgs rg_info in
+ let ctx = ctx_add (FunId (def_id, def.back_id)) name ctx in
+ ctx
diff --git a/src/Translate.ml b/src/Translate.ml
index 75975704..b840b7bc 100644
--- a/src/Translate.ml
+++ b/src/Translate.ml
@@ -288,7 +288,12 @@ let translate_module (filename : string) (config : C.partial_config)
extract_ctx trans_types
in
- (* TODO: register the functions *)
+ let extract_ctx =
+ List.fold_left
+ (fun extract_ctx def ->
+ ExtractToFStar.extract_fun_def_register_names extract_ctx def)
+ extract_ctx trans_funs
+ in
(* Open the output file *)
(* First compute the filename by replacing the extension and converting the
@@ -340,18 +345,27 @@ let translate_module (filename : string) (config : C.partial_config)
Format.pp_print_break fmt 0 0;
(* Export the definition groups to the file, in the proper order *)
- let export_type (id : Pure.TypeDefId.id) : unit =
+ let export_type (qualif : ExtractToFStar.type_def_qualif)
+ (id : Pure.TypeDefId.id) : unit =
let def = Pure.TypeDefId.Map.find id trans_types in
- ExtractToFStar.extract_type_def extract_ctx fmt def
+ ExtractToFStar.extract_type_def extract_ctx fmt qualif def
in
let export_function (id : Pure.FunDefId.id) : unit =
(* TODO *)
+ (* let pure_defs = Pure.FunDefId.Map.find id trans_funs in *)
()
in
let export_decl (decl : M.declaration_group) : unit =
match decl with
- | Type (NonRec id) -> export_type id
- | Type (Rec ids) -> List.iter export_type ids
+ | Type (NonRec id) -> export_type ExtractToFStar.Type id
+ | Type (Rec ids) ->
+ List.iteri
+ (fun i id ->
+ let qualif =
+ if i = 0 then ExtractToFStar.Type else ExtractToFStar.And
+ in
+ export_type qualif id)
+ ids
| Fun (NonRec id) -> export_function id
| Fun (Rec ids) -> List.iter export_function ids
in