summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/ExtractToFstar.ml29
-rw-r--r--src/PureToExtract.ml25
2 files changed, 43 insertions, 11 deletions
diff --git a/src/ExtractToFstar.ml b/src/ExtractToFstar.ml
index ef7d756a..56a8c338 100644
--- a/src/ExtractToFstar.ml
+++ b/src/ExtractToFstar.ml
@@ -94,18 +94,27 @@ let extract_type_def_enum_body (ctx : extraction_ctx) (fmt : F.formatter)
raise Unimplemented
let rec extract_type_def (ctx : extraction_ctx) (fmt : F.formatter)
- (def : type_def) : unit =
- let name = ctx_find_local_type def.def_id ctx in
- let ctx, type_params = ctx_add_type_params def.type_params ctx in
+ (def : type_def) : extraction_ctx =
+ (* Compute and register the type def name *)
+ let ctx, def_name = ctx_add_type_def def ctx in
+ (* Compute and register the variant names, if this is an enumeration *)
+ let ctx, variant_names =
+ match def.kind with
+ | Struct _ -> (ctx, [])
+ | Enum variants ->
+ ctx_add_variants def (VariantId.mapi (fun id v -> (id, v)) variants) ctx
+ in
+ (* Add the type params - note that we remember those bindings only for the
+ * body translation: the updated ctx we return at the end of the function
+ * only contains the registered type def and variant names *)
+ let ctx_body, type_params = ctx_add_type_params def.type_params ctx in
(* > "type TYPE_NAME =" *)
F.pp_print_string fmt "type";
F.pp_print_space fmt ();
- F.pp_print_string fmt name;
- match def.kind with
+ F.pp_print_string fmt def_name;
+ (match def.kind with
| Struct fields ->
- extract_type_def_struct_body ctx fmt name type_params fields
+ extract_type_def_struct_body ctx_body fmt def_name type_params fields
| Enum variants ->
- extract_type_def_enum_body ctx fmt name type_params variants
-
-(*let rec extract_field (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
- (ty : ty) : unit =*)
+ extract_type_def_enum_body ctx_body fmt def_name type_params variants);
+ ctx
diff --git a/src/PureToExtract.ml b/src/PureToExtract.ml
index 7bafad08..f9c021fb 100644
--- a/src/PureToExtract.ml
+++ b/src/PureToExtract.ml
@@ -40,7 +40,7 @@ type name_formatter = {
- type name
- field name
*)
- variant_name : string -> string -> string;
+ variant_name : name -> string -> string;
(** Inputs:
- type name
- variant name
@@ -156,6 +156,10 @@ let compute_fun_def_name (ctx : trans_ctx) (fmt : name_formatter)
type id =
| FunId of A.fun_id * RegionGroupId.id option
| TypeId of type_id
+ | VariantId of TypeDefId.id * VariantId.id
+ (** If often happens that variant names must be unique (it is the case in
+ F* ) which is why we register them here.
+ *)
| TypeVarId of TypeVarId.id
| VarId of VarId.id
| UnknownId
@@ -258,6 +262,7 @@ type extraction_ctx = {
*)
let ctx_add (id : id) (name : string) (ctx : extraction_ctx) : extraction_ctx =
+ (* TODO : nice debugging message if collision *)
let names_map = names_map_add id name ctx.names_map in
{ ctx with names_map }
@@ -306,3 +311,21 @@ let ctx_add_type_params (vars : type_var list) (ctx : extraction_ctx) :
List.fold_left_map
(fun ctx (var : type_var) -> ctx_add_type_var var.name var.index ctx)
ctx vars
+
+let ctx_add_type_def (def : type_def) (ctx : extraction_ctx) :
+ extraction_ctx * string =
+ let def_name = ctx.fmt.type_name def.name in
+ let ctx = ctx_add (TypeId (AdtId def.def_id)) def_name ctx in
+ (ctx, def_name)
+
+let ctx_add_variant (def : type_def) (variant_id : VariantId.id)
+ (variant : variant) (ctx : extraction_ctx) : extraction_ctx * string =
+ let name = ctx.fmt.variant_name def.name variant.variant_name in
+ let ctx = ctx_add (VariantId (def.def_id, variant_id)) name ctx in
+ (ctx, name)
+
+let ctx_add_variants (def : type_def) (variants : (VariantId.id * variant) list)
+ (ctx : extraction_ctx) : extraction_ctx * string list =
+ List.fold_left_map
+ (fun ctx (vid, v) -> ctx_add_variant def vid v ctx)
+ ctx variants