summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/ExtractToFStar.ml142
-rw-r--r--src/PureToExtract.ml90
2 files changed, 184 insertions, 48 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml
index 56a8c338..35d15607 100644
--- a/src/ExtractToFStar.ml
+++ b/src/ExtractToFStar.ml
@@ -37,11 +37,11 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
F.pp_print_string fmt ")"
| AdtId _ | Assumed _ ->
if inside then F.pp_print_string fmt "(";
- F.pp_print_string fmt (ctx_find_type type_id ctx);
+ F.pp_print_string fmt (ctx_get_type type_id ctx);
if tys <> [] then F.pp_print_space fmt ();
list_iterb (F.pp_print_space fmt) (extract_ty ctx fmt true) tys;
if inside then F.pp_print_string fmt ")")
- | TypeVar vid -> F.pp_print_string fmt (ctx_find_type_var vid ctx)
+ | TypeVar vid -> F.pp_print_string fmt (ctx_get_type_var vid ctx)
| Bool -> F.pp_print_string fmt ctx.fmt.bool_name
| Char -> F.pp_print_string fmt ctx.fmt.char_name
| Integer int_ty -> F.pp_print_string fmt (ctx.fmt.int_name int_ty)
@@ -49,21 +49,20 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
| Array _ | Slice _ -> raise Unimplemented
let extract_type_def_struct_body (ctx : extraction_ctx) (fmt : F.formatter)
- (type_name : string) (type_params : string list) (fields : field list) :
- unit =
+ (def : type_def) (fields : field list) : unit =
(* We want to generate a definition which looks like this:
* ```
- * type s = { x : int; y : bool; }
+ * type t = { x : int; y : bool; }
* ```
*
* Or if there isn't enough space on one line:
* ```
- * type s = {
+ * type t = {
* x : int;
* y : bool;
* }
* ```
- * Note that we already printed: `type s =`
+ * Note that we already printed: `type t =`
*)
F.pp_print_space fmt ();
F.pp_print_string fmt "{";
@@ -72,8 +71,8 @@ let extract_type_def_struct_body (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_open_hvbox fmt ctx.indent_incr;
F.pp_print_space fmt ();
(* Print the fields *)
- let print_field (f : field) : unit =
- let field_name = ctx.fmt.field_name type_name f.field_name in
+ let print_field (field_id : FieldId.id) (f : field) : unit =
+ let field_name = ctx_get_field def.def_id field_id ctx in
F.pp_open_box fmt ctx.indent_incr;
F.pp_print_string fmt field_name;
F.pp_print_space fmt ();
@@ -82,27 +81,131 @@ let extract_type_def_struct_body (ctx : extraction_ctx) (fmt : F.formatter)
extract_ty ctx fmt false f.field_ty;
F.pp_close_box fmt ()
in
- List.iter print_field fields;
+ let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in
+ List.iter (fun (fid, f) -> print_field fid f) fields;
(* Close *)
F.pp_close_box fmt ();
F.pp_print_string fmt "}";
F.pp_close_box fmt ()
let extract_type_def_enum_body (ctx : extraction_ctx) (fmt : F.formatter)
- (type_name : string) (type_params : string list) (variants : variant list) :
- unit =
- raise Unimplemented
+ (def : type_def) (type_name : string) (type_params : string list)
+ (variants : variant list) : unit =
+ (* We want to generate a definition which looks like this:
+ * ```
+ * type list a = | Cons : a -> list a -> list a | Nil : list a
+ * ```
+ *
+ * If there isn't enough space on one line:
+ * ```
+ * type s =
+ * | Cons : a -> list a -> list a
+ * | Nil : list a
+ * ```
+ *
+ * And if we need to write the type of a variant on several lines:
+ * ```
+ * type s =
+ * | Cons :
+ * a ->
+ * list a ->
+ * list a
+ * | Nil : list a
+ * ```
+ *
+ * Finally, it is possible to give names to the variant fields in Rust.
+ * In this situation, we generate a definition like this:
+ * ```
+ * type s =
+ * | Cons : hd:a -> tl:list a -> list a
+ * | Nil : list a
+ * ```
+ *
+ * Note that we already printed: `type s =`
+ *)
+ (* Open the body box *)
+ F.pp_open_hvbox fmt 0;
+ (* Print the variants *)
+ let print_variant (variant_id : VariantId.id) (variant : variant) : unit =
+ let variant_name = ctx_get_variant def.def_id variant_id ctx in
+ F.pp_print_space fmt ();
+ F.pp_open_hvbox fmt ctx.indent_incr;
+ (* variant box *)
+ (* `| Cons :`
+ * Note that we really don't want any break above so we print everything
+ * at once. *)
+ F.pp_print_string fmt ("| " ^ variant_name ^ " :");
+ F.pp_print_space fmt ();
+ let print_field (fid : FieldId.id) (f : field) (ctx : extraction_ctx) :
+ extraction_ctx =
+ (* Open the field box *)
+ F.pp_open_box fmt ctx.indent_incr;
+ (* Print the field names
+ * ` x :`
+ * Note that when printing fields, we register the field names as
+ * *variables*: they don't need to be unique at the top level. *)
+ let ctx =
+ match f.field_name with
+ | None -> ctx
+ | Some field_name ->
+ let var_id = VarId.of_int (FieldId.to_int fid) in
+ let ctx, field_name = ctx_add_var field_name var_id ctx in
+ F.pp_print_string fmt (field_name ^ " :");
+ F.pp_space fmt ();
+ ctx
+ in
+ (* Print the field type *)
+ extract_ty ctx fmt false f.field_ty;
+ (* Print the arrow `->`*)
+ F.pp_space fmt ();
+ F.pp_print_string "->";
+ (* Close the field box *)
+ F.pp_close_box fmt ();
+ F.pp_space fmt ();
+ (* Return *)
+ ctx
+ in
+ (* Print the fields *)
+ let fields = FieldId.mapi (fun fid f -> (fid, f)) variant.fields in
+ let ctx =
+ List.fold_left (fun ctx (fid, f) -> print_field fid f ctx) ctx fields
+ in
+ (* Print the final type *)
+ F.pp_open_hovbox fmt ();
+ F.pp_string fmt def_name;
+ List.iter
+ (fun type_param ->
+ F.pp_space fmt ();
+ F.pp_string fmt type_param)
+ type_params;
+ F.pp_close_hovbox fmt ();
+ (* Close the variant box *)
+ F.pp_close_box fmt ()
+ in
+ (* Print the variants *)
+ let variants = VariantId.mapi (fun vid v -> (vid, v)) variants in
+ List.iter (fun (vid, v) -> print_variant vid v) variants;
+ (* Close the body box *)
+ F.pp_close_box fmt ()
let rec extract_type_def (ctx : extraction_ctx) (fmt : F.formatter)
(def : type_def) : extraction_ctx =
(* Compute and register the type def name *)
let ctx, def_name = ctx_add_type_def def ctx in
- (* Compute and register the variant names, if this is an enumeration *)
- let ctx, variant_names =
+ (* Compute and register:
+ * - the variant names, if this is an enumeration
+ * - the field names, if this is a structure
+ * We do this because in F*, they have to be unique at the top-level.
+ *)
+ let ctx =
match def.kind with
- | Struct _ -> (ctx, [])
+ | Struct fields ->
+ fst (ctx_add_fields def (FieldId.mapi (fun id f -> (id, f)) fields) ctx)
| Enum variants ->
- ctx_add_variants def (VariantId.mapi (fun id v -> (id, v)) variants) ctx
+ fst
+ (ctx_add_variants def
+ (VariantId.mapi (fun id v -> (id, v)) variants)
+ ctx)
in
(* Add the type params - note that we remember those bindings only for the
* body translation: the updated ctx we return at the end of the function
@@ -113,8 +216,7 @@ let rec extract_type_def (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_space fmt ();
F.pp_print_string fmt def_name;
(match def.kind with
- | Struct fields ->
- extract_type_def_struct_body ctx_body fmt def_name type_params fields
+ | Struct fields -> extract_type_def_struct_body ctx_body fmt fields
| Enum variants ->
- extract_type_def_enum_body ctx_body fmt def_name type_params variants);
+ extract_type_def_enum_body ctx_body fmt def def_name type_params variants);
ctx
diff --git a/src/PureToExtract.ml b/src/PureToExtract.ml
index f9c021fb..f2c03b90 100644
--- a/src/PureToExtract.ml
+++ b/src/PureToExtract.ml
@@ -35,7 +35,7 @@ type name_formatter = {
char_name : string;
int_name : integer_type -> string;
str_name : string;
- field_name : string -> string -> string;
+ field_name : name -> string -> string;
(** Inputs:
- type name
- field name
@@ -160,6 +160,11 @@ type id =
(** If often happens that variant names must be unique (it is the case in
F* ) which is why we register them here.
*)
+ | FieldId of TypeDefId.id * FieldId.id
+ (** If often happens that in the case of structures, the field names
+ must be unique (it is the case in F* ) which is why we register
+ them here.
+ *)
| TypeVarId of TypeVarId.id
| VarId of VarId.id
| UnknownId
@@ -207,29 +212,29 @@ let names_map_add (id : id) (name : string) (nm : names_map) : names_map =
{ id_to_name; name_to_id; names_set }
(* TODO: remove those functions? We use the ones of extraction_ctx *)
-let names_map_find (id : id) (nm : names_map) : string =
+let names_map_get (id : id) (nm : names_map) : string =
IdMap.find id nm.id_to_name
-let names_map_find_function (id : A.fun_id) (rg : RegionGroupId.id option)
+let names_map_get_function (id : A.fun_id) (rg : RegionGroupId.id option)
(nm : names_map) : string =
- names_map_find (FunId (id, rg)) nm
+ names_map_get (FunId (id, rg)) nm
-let names_map_find_local_function (id : FunDefId.id)
+let names_map_get_local_function (id : FunDefId.id)
(rg : RegionGroupId.id option) (nm : names_map) : string =
- names_map_find_function (A.Local id) rg nm
+ names_map_get_function (A.Local id) rg nm
-let names_map_find_type (id : type_id) (nm : names_map) : string =
+let names_map_get_type (id : type_id) (nm : names_map) : string =
assert (id <> Tuple);
- names_map_find (TypeId id) nm
+ names_map_get (TypeId id) nm
-let names_map_find_local_type (id : TypeDefId.id) (nm : names_map) : string =
- names_map_find_type (AdtId id) nm
+let names_map_get_local_type (id : TypeDefId.id) (nm : names_map) : string =
+ names_map_get_type (AdtId id) nm
-let names_map_find_var (id : VarId.id) (nm : names_map) : string =
- names_map_find (VarId id) nm
+let names_map_get_var (id : VarId.id) (nm : names_map) : string =
+ names_map_get (VarId id) nm
-let names_map_find_type_var (id : TypeVarId.id) (nm : names_map) : string =
- names_map_find (TypeVarId id) nm
+let names_map_get_type_var (id : TypeVarId.id) (nm : names_map) : string =
+ names_map_get (TypeVarId id) nm
(** Make a (variable) basename unique (by adding an index).
@@ -266,31 +271,39 @@ let ctx_add (id : id) (name : string) (ctx : extraction_ctx) : extraction_ctx =
let names_map = names_map_add id name ctx.names_map in
{ ctx with names_map }
-let ctx_find (id : id) (ctx : extraction_ctx) : string =
+let ctx_get (id : id) (ctx : extraction_ctx) : string =
IdMap.find id ctx.names_map.id_to_name
-let ctx_find_function (id : A.fun_id) (rg : RegionGroupId.id option)
+let ctx_get_function (id : A.fun_id) (rg : RegionGroupId.id option)
(ctx : extraction_ctx) : string =
- ctx_find (FunId (id, rg)) ctx
+ ctx_get (FunId (id, rg)) ctx
-let ctx_find_local_function (id : FunDefId.id) (rg : RegionGroupId.id option)
+let ctx_get_local_function (id : FunDefId.id) (rg : RegionGroupId.id option)
(ctx : extraction_ctx) : string =
- ctx_find_function (A.Local id) rg ctx
+ ctx_get_function (A.Local id) rg ctx
-let ctx_find_type (id : type_id) (ctx : extraction_ctx) : string =
+let ctx_get_type (id : type_id) (ctx : extraction_ctx) : string =
assert (id <> Tuple);
- ctx_find (TypeId id) ctx
+ ctx_get (TypeId id) ctx
+
+let ctx_get_local_type (id : TypeDefId.id) (ctx : extraction_ctx) : string =
+ ctx_get_type (AdtId id) ctx
-let ctx_find_local_type (id : TypeDefId.id) (ctx : extraction_ctx) : string =
- ctx_find_type (AdtId id) ctx
+let ctx_get_var (id : VarId.id) (ctx : extraction_ctx) : string =
+ ctx_get (VarId id) ctx
-let ctx_find_var (id : VarId.id) (ctx : extraction_ctx) : string =
- ctx_find (VarId id) ctx
+let ctx_get_type_var (id : TypeVarId.id) (ctx : extraction_ctx) : string =
+ ctx_get (TypeVarId id) ctx
-let ctx_find_type_var (id : TypeVarId.id) (ctx : extraction_ctx) : string =
- ctx_find (TypeVarId id) ctx
+let ctx_get_field (def_id : TypeDefId.id) (field_id : FieldId.id)
+ (ctx : extraction_ctx) : string =
+ ctx_get (FieldId (def_id, field_id)) ctx
+
+let ctx_get_variant (def_id : TypeDefId.id) (variant_id : VariantId.id)
+ (ctx : extraction_ctx) : string =
+ ctx_get (VariantId (def_id, variant_id)) ctx
-(** Generate a unique type variable name and add to the context *)
+(** Generate a unique type variable name and add it to the context *)
let ctx_add_type_var (basename : string) (id : TypeVarId.id)
(ctx : extraction_ctx) : extraction_ctx * string =
let name =
@@ -299,6 +312,15 @@ let ctx_add_type_var (basename : string) (id : TypeVarId.id)
let ctx = ctx_add (TypeVarId id) name ctx in
(ctx, name)
+(** Generate a unique variable name and add it to the context *)
+let ctx_add_var (basename : string) (id : VarId.id) (ctx : extraction_ctx) :
+ extraction_ctx * string =
+ let name =
+ basename_to_unique ctx.names_map.names_set ctx.fmt.append_index basename
+ in
+ let ctx = ctx_add (VarId id) name ctx in
+ (ctx, name)
+
(** See [ctx_add_type_var] *)
let ctx_add_type_vars (vars : (string * TypeVarId.id) list)
(ctx : extraction_ctx) : extraction_ctx * string list =
@@ -318,6 +340,18 @@ let ctx_add_type_def (def : type_def) (ctx : extraction_ctx) :
let ctx = ctx_add (TypeId (AdtId def.def_id)) def_name ctx in
(ctx, def_name)
+let ctx_add_field (def : type_def) (field_id : FieldId.id) (field : field)
+ (ctx : extraction_ctx) : extraction_ctx * string =
+ let name = ctx.fmt.field_name def.name field.field_name in
+ let ctx = ctx_add (FieldId (def.def_id, field_id)) name ctx in
+ (ctx, name)
+
+let ctx_add_fields (def : type_def) (fields : (FieldId.id * field) list)
+ (ctx : extraction_ctx) : extraction_ctx * string list =
+ List.fold_left_map
+ (fun ctx (vid, v) -> ctx_add_field def vid v ctx)
+ ctx fields
+
let ctx_add_variant (def : type_def) (variant_id : VariantId.id)
(variant : variant) (ctx : extraction_ctx) : extraction_ctx * string =
let name = ctx.fmt.variant_name def.name variant.variant_name in