summaryrefslogtreecommitdiff
path: root/src/PureToExtract.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/PureToExtract.ml')
-rw-r--r--src/PureToExtract.ml25
1 files changed, 24 insertions, 1 deletions
diff --git a/src/PureToExtract.ml b/src/PureToExtract.ml
index 7bafad08..f9c021fb 100644
--- a/src/PureToExtract.ml
+++ b/src/PureToExtract.ml
@@ -40,7 +40,7 @@ type name_formatter = {
- type name
- field name
*)
- variant_name : string -> string -> string;
+ variant_name : name -> string -> string;
(** Inputs:
- type name
- variant name
@@ -156,6 +156,10 @@ let compute_fun_def_name (ctx : trans_ctx) (fmt : name_formatter)
type id =
| FunId of A.fun_id * RegionGroupId.id option
| TypeId of type_id
+ | VariantId of TypeDefId.id * VariantId.id
+ (** If often happens that variant names must be unique (it is the case in
+ F* ) which is why we register them here.
+ *)
| TypeVarId of TypeVarId.id
| VarId of VarId.id
| UnknownId
@@ -258,6 +262,7 @@ type extraction_ctx = {
*)
let ctx_add (id : id) (name : string) (ctx : extraction_ctx) : extraction_ctx =
+ (* TODO : nice debugging message if collision *)
let names_map = names_map_add id name ctx.names_map in
{ ctx with names_map }
@@ -306,3 +311,21 @@ let ctx_add_type_params (vars : type_var list) (ctx : extraction_ctx) :
List.fold_left_map
(fun ctx (var : type_var) -> ctx_add_type_var var.name var.index ctx)
ctx vars
+
+let ctx_add_type_def (def : type_def) (ctx : extraction_ctx) :
+ extraction_ctx * string =
+ let def_name = ctx.fmt.type_name def.name in
+ let ctx = ctx_add (TypeId (AdtId def.def_id)) def_name ctx in
+ (ctx, def_name)
+
+let ctx_add_variant (def : type_def) (variant_id : VariantId.id)
+ (variant : variant) (ctx : extraction_ctx) : extraction_ctx * string =
+ let name = ctx.fmt.variant_name def.name variant.variant_name in
+ let ctx = ctx_add (VariantId (def.def_id, variant_id)) name ctx in
+ (ctx, name)
+
+let ctx_add_variants (def : type_def) (variants : (VariantId.id * variant) list)
+ (ctx : extraction_ctx) : extraction_ctx * string list =
+ List.fold_left_map
+ (fun ctx (vid, v) -> ctx_add_variant def vid v ctx)
+ ctx variants