diff options
Diffstat (limited to '')
-rw-r--r-- | src/ExtractToFStar.ml | 142 | ||||
-rw-r--r-- | src/PureToExtract.ml | 90 |
2 files changed, 184 insertions, 48 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index 56a8c338..35d15607 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -37,11 +37,11 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) F.pp_print_string fmt ")" | AdtId _ | Assumed _ -> if inside then F.pp_print_string fmt "("; - F.pp_print_string fmt (ctx_find_type type_id ctx); + F.pp_print_string fmt (ctx_get_type type_id ctx); if tys <> [] then F.pp_print_space fmt (); list_iterb (F.pp_print_space fmt) (extract_ty ctx fmt true) tys; if inside then F.pp_print_string fmt ")") - | TypeVar vid -> F.pp_print_string fmt (ctx_find_type_var vid ctx) + | TypeVar vid -> F.pp_print_string fmt (ctx_get_type_var vid ctx) | Bool -> F.pp_print_string fmt ctx.fmt.bool_name | Char -> F.pp_print_string fmt ctx.fmt.char_name | Integer int_ty -> F.pp_print_string fmt (ctx.fmt.int_name int_ty) @@ -49,21 +49,20 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) | Array _ | Slice _ -> raise Unimplemented let extract_type_def_struct_body (ctx : extraction_ctx) (fmt : F.formatter) - (type_name : string) (type_params : string list) (fields : field list) : - unit = + (def : type_def) (fields : field list) : unit = (* We want to generate a definition which looks like this: * ``` - * type s = { x : int; y : bool; } + * type t = { x : int; y : bool; } * ``` * * Or if there isn't enough space on one line: * ``` - * type s = { + * type t = { * x : int; * y : bool; * } * ``` - * Note that we already printed: `type s =` + * Note that we already printed: `type t =` *) F.pp_print_space fmt (); F.pp_print_string fmt "{"; @@ -72,8 +71,8 @@ let extract_type_def_struct_body (ctx : extraction_ctx) (fmt : F.formatter) F.pp_open_hvbox fmt ctx.indent_incr; F.pp_print_space fmt (); (* Print the fields *) - let print_field (f : field) : unit = - let field_name = ctx.fmt.field_name type_name f.field_name in + let print_field (field_id : FieldId.id) (f : field) : unit = + let field_name = ctx_get_field def.def_id field_id ctx in F.pp_open_box fmt ctx.indent_incr; F.pp_print_string fmt field_name; F.pp_print_space fmt (); @@ -82,27 +81,131 @@ let extract_type_def_struct_body (ctx : extraction_ctx) (fmt : F.formatter) extract_ty ctx fmt false f.field_ty; F.pp_close_box fmt () in - List.iter print_field fields; + let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in + List.iter (fun (fid, f) -> print_field fid f) fields; (* Close *) F.pp_close_box fmt (); F.pp_print_string fmt "}"; F.pp_close_box fmt () let extract_type_def_enum_body (ctx : extraction_ctx) (fmt : F.formatter) - (type_name : string) (type_params : string list) (variants : variant list) : - unit = - raise Unimplemented + (def : type_def) (type_name : string) (type_params : string list) + (variants : variant list) : unit = + (* We want to generate a definition which looks like this: + * ``` + * type list a = | Cons : a -> list a -> list a | Nil : list a + * ``` + * + * If there isn't enough space on one line: + * ``` + * type s = + * | Cons : a -> list a -> list a + * | Nil : list a + * ``` + * + * And if we need to write the type of a variant on several lines: + * ``` + * type s = + * | Cons : + * a -> + * list a -> + * list a + * | Nil : list a + * ``` + * + * Finally, it is possible to give names to the variant fields in Rust. + * In this situation, we generate a definition like this: + * ``` + * type s = + * | Cons : hd:a -> tl:list a -> list a + * | Nil : list a + * ``` + * + * Note that we already printed: `type s =` + *) + (* Open the body box *) + F.pp_open_hvbox fmt 0; + (* Print the variants *) + let print_variant (variant_id : VariantId.id) (variant : variant) : unit = + let variant_name = ctx_get_variant def.def_id variant_id ctx in + F.pp_print_space fmt (); + F.pp_open_hvbox fmt ctx.indent_incr; + (* variant box *) + (* `| Cons :` + * Note that we really don't want any break above so we print everything + * at once. *) + F.pp_print_string fmt ("| " ^ variant_name ^ " :"); + F.pp_print_space fmt (); + let print_field (fid : FieldId.id) (f : field) (ctx : extraction_ctx) : + extraction_ctx = + (* Open the field box *) + F.pp_open_box fmt ctx.indent_incr; + (* Print the field names + * ` x :` + * Note that when printing fields, we register the field names as + * *variables*: they don't need to be unique at the top level. *) + let ctx = + match f.field_name with + | None -> ctx + | Some field_name -> + let var_id = VarId.of_int (FieldId.to_int fid) in + let ctx, field_name = ctx_add_var field_name var_id ctx in + F.pp_print_string fmt (field_name ^ " :"); + F.pp_space fmt (); + ctx + in + (* Print the field type *) + extract_ty ctx fmt false f.field_ty; + (* Print the arrow `->`*) + F.pp_space fmt (); + F.pp_print_string "->"; + (* Close the field box *) + F.pp_close_box fmt (); + F.pp_space fmt (); + (* Return *) + ctx + in + (* Print the fields *) + let fields = FieldId.mapi (fun fid f -> (fid, f)) variant.fields in + let ctx = + List.fold_left (fun ctx (fid, f) -> print_field fid f ctx) ctx fields + in + (* Print the final type *) + F.pp_open_hovbox fmt (); + F.pp_string fmt def_name; + List.iter + (fun type_param -> + F.pp_space fmt (); + F.pp_string fmt type_param) + type_params; + F.pp_close_hovbox fmt (); + (* Close the variant box *) + F.pp_close_box fmt () + in + (* Print the variants *) + let variants = VariantId.mapi (fun vid v -> (vid, v)) variants in + List.iter (fun (vid, v) -> print_variant vid v) variants; + (* Close the body box *) + F.pp_close_box fmt () let rec extract_type_def (ctx : extraction_ctx) (fmt : F.formatter) (def : type_def) : extraction_ctx = (* Compute and register the type def name *) let ctx, def_name = ctx_add_type_def def ctx in - (* Compute and register the variant names, if this is an enumeration *) - let ctx, variant_names = + (* Compute and register: + * - the variant names, if this is an enumeration + * - the field names, if this is a structure + * We do this because in F*, they have to be unique at the top-level. + *) + let ctx = match def.kind with - | Struct _ -> (ctx, []) + | Struct fields -> + fst (ctx_add_fields def (FieldId.mapi (fun id f -> (id, f)) fields) ctx) | Enum variants -> - ctx_add_variants def (VariantId.mapi (fun id v -> (id, v)) variants) ctx + fst + (ctx_add_variants def + (VariantId.mapi (fun id v -> (id, v)) variants) + ctx) in (* Add the type params - note that we remember those bindings only for the * body translation: the updated ctx we return at the end of the function @@ -113,8 +216,7 @@ let rec extract_type_def (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_space fmt (); F.pp_print_string fmt def_name; (match def.kind with - | Struct fields -> - extract_type_def_struct_body ctx_body fmt def_name type_params fields + | Struct fields -> extract_type_def_struct_body ctx_body fmt fields | Enum variants -> - extract_type_def_enum_body ctx_body fmt def_name type_params variants); + extract_type_def_enum_body ctx_body fmt def def_name type_params variants); ctx diff --git a/src/PureToExtract.ml b/src/PureToExtract.ml index f9c021fb..f2c03b90 100644 --- a/src/PureToExtract.ml +++ b/src/PureToExtract.ml @@ -35,7 +35,7 @@ type name_formatter = { char_name : string; int_name : integer_type -> string; str_name : string; - field_name : string -> string -> string; + field_name : name -> string -> string; (** Inputs: - type name - field name @@ -160,6 +160,11 @@ type id = (** If often happens that variant names must be unique (it is the case in F* ) which is why we register them here. *) + | FieldId of TypeDefId.id * FieldId.id + (** If often happens that in the case of structures, the field names + must be unique (it is the case in F* ) which is why we register + them here. + *) | TypeVarId of TypeVarId.id | VarId of VarId.id | UnknownId @@ -207,29 +212,29 @@ let names_map_add (id : id) (name : string) (nm : names_map) : names_map = { id_to_name; name_to_id; names_set } (* TODO: remove those functions? We use the ones of extraction_ctx *) -let names_map_find (id : id) (nm : names_map) : string = +let names_map_get (id : id) (nm : names_map) : string = IdMap.find id nm.id_to_name -let names_map_find_function (id : A.fun_id) (rg : RegionGroupId.id option) +let names_map_get_function (id : A.fun_id) (rg : RegionGroupId.id option) (nm : names_map) : string = - names_map_find (FunId (id, rg)) nm + names_map_get (FunId (id, rg)) nm -let names_map_find_local_function (id : FunDefId.id) +let names_map_get_local_function (id : FunDefId.id) (rg : RegionGroupId.id option) (nm : names_map) : string = - names_map_find_function (A.Local id) rg nm + names_map_get_function (A.Local id) rg nm -let names_map_find_type (id : type_id) (nm : names_map) : string = +let names_map_get_type (id : type_id) (nm : names_map) : string = assert (id <> Tuple); - names_map_find (TypeId id) nm + names_map_get (TypeId id) nm -let names_map_find_local_type (id : TypeDefId.id) (nm : names_map) : string = - names_map_find_type (AdtId id) nm +let names_map_get_local_type (id : TypeDefId.id) (nm : names_map) : string = + names_map_get_type (AdtId id) nm -let names_map_find_var (id : VarId.id) (nm : names_map) : string = - names_map_find (VarId id) nm +let names_map_get_var (id : VarId.id) (nm : names_map) : string = + names_map_get (VarId id) nm -let names_map_find_type_var (id : TypeVarId.id) (nm : names_map) : string = - names_map_find (TypeVarId id) nm +let names_map_get_type_var (id : TypeVarId.id) (nm : names_map) : string = + names_map_get (TypeVarId id) nm (** Make a (variable) basename unique (by adding an index). @@ -266,31 +271,39 @@ let ctx_add (id : id) (name : string) (ctx : extraction_ctx) : extraction_ctx = let names_map = names_map_add id name ctx.names_map in { ctx with names_map } -let ctx_find (id : id) (ctx : extraction_ctx) : string = +let ctx_get (id : id) (ctx : extraction_ctx) : string = IdMap.find id ctx.names_map.id_to_name -let ctx_find_function (id : A.fun_id) (rg : RegionGroupId.id option) +let ctx_get_function (id : A.fun_id) (rg : RegionGroupId.id option) (ctx : extraction_ctx) : string = - ctx_find (FunId (id, rg)) ctx + ctx_get (FunId (id, rg)) ctx -let ctx_find_local_function (id : FunDefId.id) (rg : RegionGroupId.id option) +let ctx_get_local_function (id : FunDefId.id) (rg : RegionGroupId.id option) (ctx : extraction_ctx) : string = - ctx_find_function (A.Local id) rg ctx + ctx_get_function (A.Local id) rg ctx -let ctx_find_type (id : type_id) (ctx : extraction_ctx) : string = +let ctx_get_type (id : type_id) (ctx : extraction_ctx) : string = assert (id <> Tuple); - ctx_find (TypeId id) ctx + ctx_get (TypeId id) ctx + +let ctx_get_local_type (id : TypeDefId.id) (ctx : extraction_ctx) : string = + ctx_get_type (AdtId id) ctx -let ctx_find_local_type (id : TypeDefId.id) (ctx : extraction_ctx) : string = - ctx_find_type (AdtId id) ctx +let ctx_get_var (id : VarId.id) (ctx : extraction_ctx) : string = + ctx_get (VarId id) ctx -let ctx_find_var (id : VarId.id) (ctx : extraction_ctx) : string = - ctx_find (VarId id) ctx +let ctx_get_type_var (id : TypeVarId.id) (ctx : extraction_ctx) : string = + ctx_get (TypeVarId id) ctx -let ctx_find_type_var (id : TypeVarId.id) (ctx : extraction_ctx) : string = - ctx_find (TypeVarId id) ctx +let ctx_get_field (def_id : TypeDefId.id) (field_id : FieldId.id) + (ctx : extraction_ctx) : string = + ctx_get (FieldId (def_id, field_id)) ctx + +let ctx_get_variant (def_id : TypeDefId.id) (variant_id : VariantId.id) + (ctx : extraction_ctx) : string = + ctx_get (VariantId (def_id, variant_id)) ctx -(** Generate a unique type variable name and add to the context *) +(** Generate a unique type variable name and add it to the context *) let ctx_add_type_var (basename : string) (id : TypeVarId.id) (ctx : extraction_ctx) : extraction_ctx * string = let name = @@ -299,6 +312,15 @@ let ctx_add_type_var (basename : string) (id : TypeVarId.id) let ctx = ctx_add (TypeVarId id) name ctx in (ctx, name) +(** Generate a unique variable name and add it to the context *) +let ctx_add_var (basename : string) (id : VarId.id) (ctx : extraction_ctx) : + extraction_ctx * string = + let name = + basename_to_unique ctx.names_map.names_set ctx.fmt.append_index basename + in + let ctx = ctx_add (VarId id) name ctx in + (ctx, name) + (** See [ctx_add_type_var] *) let ctx_add_type_vars (vars : (string * TypeVarId.id) list) (ctx : extraction_ctx) : extraction_ctx * string list = @@ -318,6 +340,18 @@ let ctx_add_type_def (def : type_def) (ctx : extraction_ctx) : let ctx = ctx_add (TypeId (AdtId def.def_id)) def_name ctx in (ctx, def_name) +let ctx_add_field (def : type_def) (field_id : FieldId.id) (field : field) + (ctx : extraction_ctx) : extraction_ctx * string = + let name = ctx.fmt.field_name def.name field.field_name in + let ctx = ctx_add (FieldId (def.def_id, field_id)) name ctx in + (ctx, name) + +let ctx_add_fields (def : type_def) (fields : (FieldId.id * field) list) + (ctx : extraction_ctx) : extraction_ctx * string list = + List.fold_left_map + (fun ctx (vid, v) -> ctx_add_field def vid v ctx) + ctx fields + let ctx_add_variant (def : type_def) (variant_id : VariantId.id) (variant : variant) (ctx : extraction_ctx) : extraction_ctx * string = let name = ctx.fmt.variant_name def.name variant.variant_name in |