From 45b8a56c20f2cf29c8bc9fc1de593e9c82a2fb6d Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 2 Feb 2022 16:22:47 +0100 Subject: Fix more issues when extracting types to F* --- src/ExtractToFStar.ml | 72 +++++++++++++++++++++++++-------------------------- src/PureToExtract.ml | 5 ++-- src/StringUtils.ml | 4 ++- src/Translate.ml | 3 ++- 4 files changed, 44 insertions(+), 40 deletions(-) diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index 63d1affd..65eaae60 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -86,18 +86,18 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) = in let type_name_to_camel_case name = let name = get_name name in - to_snake_case name + to_camel_case name in let type_name_to_snake_case name = let name = get_name name in to_snake_case name in - let type_name name = type_name_to_camel_case name ^ "_t" in + let type_name name = type_name_to_snake_case name ^ "_t" in let field_name (def_name : name) (field_id : FieldId.id) (field_name : string option) : string = let def_name = type_name_to_snake_case def_name ^ "_" in match field_name with - | Some field_name -> def_name ^ "_" ^ field_name + | Some field_name -> def_name ^ field_name | None -> def_name ^ FieldId.to_string field_id in let variant_name (def_name : name) (variant : string) : string = @@ -147,17 +147,10 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) = | Str -> "s" | Array _ | Slice _ -> raise Unimplemented) in - let type_var_basename (_varset : StringSet.t) (basename : string option) : - string = - (* If there is a basename, we use it *) - match basename with - | Some basename -> - (* This is *not* a no-op: type variables in Rust often start with - * a capital letter *) - to_snake_case basename - | None -> - (* For no, we use "a" *) - "a" + let type_var_basename (_varset : StringSet.t) (basename : string) : string = + (* This is *not* a no-op: type variables in Rust often start with + * a capital letter *) + to_snake_case basename in let append_index (basename : string) (i : int) : string = basename ^ string_of_int i @@ -218,32 +211,39 @@ let extract_type_def_struct_body (ctx : extraction_ctx) (fmt : F.formatter) * } * ``` * Note that we already printed: `type t =` + * + * Also, in case there are no fields, we need to define the type as `unit` + * (`type t = {}` doesn't work in F* ). *) - F.pp_print_space fmt (); - F.pp_print_string fmt "{"; - (* The body itself *) - F.pp_open_hvbox fmt 0; - F.pp_open_hvbox fmt ctx.indent_incr; - F.pp_print_space fmt (); - (* Print the fields *) - let print_field (field_id : FieldId.id) (f : field) : unit = - let field_name = ctx_get_field def.def_id field_id ctx in - F.pp_open_box fmt ctx.indent_incr; - F.pp_print_string fmt field_name; + if fields = [] then ( F.pp_print_space fmt (); - F.pp_print_string fmt ":"; + F.pp_print_string fmt "unit") + else ( F.pp_print_space fmt (); - extract_ty ctx fmt false f.field_ty; - F.pp_print_string fmt ";"; + F.pp_print_string fmt "{"; + (* The body itself *) + F.pp_open_hvbox fmt 0; + F.pp_open_hvbox fmt ctx.indent_incr; F.pp_print_space fmt (); - F.pp_close_box fmt () - in - let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in - List.iter (fun (fid, f) -> print_field fid f) fields; - (* Close *) - F.pp_close_box fmt (); - F.pp_print_string fmt "}"; - F.pp_close_box fmt () + (* Print the fields *) + let print_field (field_id : FieldId.id) (f : field) : unit = + let field_name = ctx_get_field def.def_id field_id ctx in + F.pp_open_box fmt ctx.indent_incr; + F.pp_print_string fmt field_name; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_ty ctx fmt false f.field_ty; + F.pp_print_string fmt ";"; + F.pp_print_space fmt (); + F.pp_close_box fmt () + in + let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in + List.iter (fun (fid, f) -> print_field fid f) fields; + (* Close *) + F.pp_close_box fmt (); + F.pp_print_string fmt "}"; + F.pp_close_box fmt ()) let extract_type_def_enum_body (ctx : extraction_ctx) (fmt : F.formatter) (def : type_def) (def_name : string) (type_params : string list) diff --git a/src/PureToExtract.ml b/src/PureToExtract.ml index 4e6c9014..c36ed8fe 100644 --- a/src/PureToExtract.ml +++ b/src/PureToExtract.ml @@ -76,7 +76,7 @@ type name_formatter = { if necessary to prevent name clashes: the burden of name clashes checks is thus on the caller's side. *) - type_var_basename : StringSet.t -> string option -> string; + type_var_basename : StringSet.t -> string -> string; (** Generates a type variable basename. *) append_index : string -> int -> string; (** Appends an index to a name - we use this to generate unique @@ -328,8 +328,9 @@ let ctx_get_variant (def_id : TypeDefId.id) (variant_id : VariantId.id) (** Generate a unique type variable name and add it to the context *) let ctx_add_type_var (basename : string) (id : TypeVarId.id) (ctx : extraction_ctx) : extraction_ctx * string = + let name = ctx.fmt.type_var_basename ctx.names_map.names_set basename in let name = - basename_to_unique ctx.names_map.names_set ctx.fmt.append_index basename + basename_to_unique ctx.names_map.names_set ctx.fmt.append_index name in let ctx = ctx_add (TypeVarId id) name ctx in (ctx, name) diff --git a/src/StringUtils.ml b/src/StringUtils.ml index 7c77a8d1..2e0e18f7 100644 --- a/src/StringUtils.ml +++ b/src/StringUtils.ml @@ -90,4 +90,6 @@ let to_snake_case (s : string) : string = let _ = assert (to_camel_case "hello_world" = "HelloWorld"); assert (to_snake_case "HelloWorld36Hello" = "hello_world36_hello"); - assert (to_snake_case "HELLO" = "hello") + assert (to_snake_case "HELLO" = "hello"); + assert (to_snake_case "T1" = "t1"); + assert (to_camel_case "list" = "List") diff --git a/src/Translate.ml b/src/Translate.ml index e43d5741..2ddc4adc 100644 --- a/src/Translate.ml +++ b/src/Translate.ml @@ -351,4 +351,5 @@ let translate_module (filename : string) (config : C.partial_config) List.iter export_decl m.declarations; (* Close the box and end the formatting *) - Format.pp_close_box fmt () + Format.pp_close_box fmt (); + Format.pp_print_newline fmt () -- cgit v1.2.3