From 6739ab801801519f118cbb992b04c57f77c0cd17 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 2 Feb 2022 22:59:24 +0100 Subject: Make minor modifications to extract mutually recursive types --- src/ExtractToFStar.ml | 118 ++++++++++++++++++++++++++++++++------------------ 1 file changed, 77 insertions(+), 41 deletions(-) (limited to 'src/ExtractToFStar.ml') diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index 26316bc4..a1b56964 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -7,6 +7,19 @@ open PureToExtract open StringUtils module F = Format +(** A qualifier for a type definition. + + Controls whether we should use `type ...` or `and ...` (for mutually + recursive datatypes). + *) +type type_def_qualif = Type | And + +(** A qualifier for function definitions. + + Controls whether we should use `let ...`, `let rec ...` or `and ...` + *) +type fun_def_qualif = Let | LetRec | And + (** A list of keywords/identifiers used in F* and with which we want to check collision. *) let fstar_keywords = @@ -78,18 +91,20 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) = | U64 -> "u64" | U128 -> "u128" in - (* For now, we treat only the case where type and function names are of the - * form: `Module::Type` and `Module:function`. + (* For now, we treat only the case where type names are of the + * form: `Module::Type` *) - let get_name (name : name) : string = - match name with [ _module; name ] -> name | _ -> failwith "Unexpected" + let get_type_name (name : name) : string = + match name with + | [ _module; name ] -> name + | _ -> failwith ("Unexpected name shape: " ^ Print.name_to_string name) in let type_name_to_camel_case name = - let name = get_name name in + let name = get_type_name name in to_camel_case name in let type_name_to_snake_case name = - let name = get_name name in + let name = get_type_name name in to_snake_case name in let type_name name = type_name_to_snake_case name ^ "_t" in @@ -106,12 +121,17 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) = type_name_to_camel_case def_name ^ variant else variant in - (* For now, we only treat the case where the type name is: - * `Module::Type` + (* For now, we treat only the case where function names are of the + * form: `function` (no module prefix) *) + let get_fun_name (name : name) : string = + match name with + | [ name ] -> name + | _ -> failwith ("Unexpected name shape: " ^ Print.name_to_string name) + in let fun_name (_fid : A.fun_id) (fname : name) (num_rgs : int) (rg : region_group_info option) : string = - let fname = get_name fname in + let fname = get_fun_name fname in (* Converting to snake case should be a no-op, but it doesn't cost much *) let fname = to_snake_case fname in (* Compute the suffix *) @@ -139,7 +159,7 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) = let def = TypeDefId.Map.find adt_id ctx.type_context.type_defs in - StringUtils.string_of_chars [ (get_name def.name).[0] ]) + StringUtils.string_of_chars [ (get_type_name def.name).[0] ]) | TypeVar _ -> "x" (* lacking imagination here... *) | Bool -> "b" | Char -> "c" @@ -196,6 +216,34 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) | Str -> F.pp_print_string fmt ctx.fmt.str_name | Array _ | Slice _ -> raise Unimplemented +(** Compute the names for all the top-level identifiers used in a type + definition (type name, variant names, field names, etc. but not type + parameters). + + We need to do this preemptively, beforce extracting any definition, + because of recursive definitions. + *) +let extract_type_def_register_names (ctx : extraction_ctx) (def : type_def) : + extraction_ctx = + (* Compute and register the type def name *) + let ctx, def_name = ctx_add_type_def def ctx in + (* Compute and register: + * - the variant names, if this is an enumeration + * - the field names, if this is a structure + *) + let ctx = + match def.kind with + | Struct fields -> + fst (ctx_add_fields def (FieldId.mapi (fun id f -> (id, f)) fields) ctx) + | Enum variants -> + fst + (ctx_add_variants def + (VariantId.mapi (fun id v -> (id, v)) variants) + ctx) + in + (* Return *) + ctx + let extract_type_def_struct_body (ctx : extraction_ctx) (fmt : F.formatter) (def : type_def) (fields : field list) : unit = (* We want to generate a definition which looks like this: @@ -354,41 +402,13 @@ let extract_type_def_enum_body (ctx : extraction_ctx) (fmt : F.formatter) let variants = VariantId.mapi (fun vid v -> (vid, v)) variants in List.iter (fun (vid, v) -> print_variant vid v) variants -(** Compute the names for all the top-level identifiers used in a type - definition (type name, variant names, field names, etc. but not type - parameters). - - We need to do this preemptively, beforce extracting any definition, - because of recursive definitions. - *) -let extract_type_def_register_names (ctx : extraction_ctx) (def : type_def) : - extraction_ctx = - (* Compute and register the type def name *) - let ctx, def_name = ctx_add_type_def def ctx in - (* Compute and register: - * - the variant names, if this is an enumeration - * - the field names, if this is a structure - *) - let ctx = - match def.kind with - | Struct fields -> - fst (ctx_add_fields def (FieldId.mapi (fun id f -> (id, f)) fields) ctx) - | Enum variants -> - fst - (ctx_add_variants def - (VariantId.mapi (fun id v -> (id, v)) variants) - ctx) - in - (* Return *) - ctx - (** Extract a type definition. Note that all the names used for extraction should already have been registered. *) -let extract_type_def (ctx : extraction_ctx) (fmt : F.formatter) (def : type_def) - : unit = +let extract_type_def (ctx : extraction_ctx) (fmt : F.formatter) + (qualif : type_def_qualif) (def : type_def) : unit = (* Retrieve the definition name *) let def_name = ctx_get_local_type def.def_id ctx in (* Add the type params - note that we remember those bindings only for the @@ -406,7 +426,8 @@ let extract_type_def (ctx : extraction_ctx) (fmt : F.formatter) (def : type_def) (* Open a box for "type TYPE_NAME (TYPE_PARAMS) =" *) F.pp_open_hovbox fmt ctx.indent_incr; (* > "type TYPE_NAME" *) - F.pp_print_string fmt ("type " ^ def_name); + let qualif = match qualif with Type -> "type" | And -> "and" in + F.pp_print_string fmt (qualif ^ " " ^ def_name); (* Print the type parameters *) if def.type_params <> [] then ( F.pp_print_space fmt (); @@ -433,3 +454,18 @@ let extract_type_def (ctx : extraction_ctx) (fmt : F.formatter) (def : type_def) F.pp_close_box fmt (); (* Add breaks to insert new lines between definitions *) F.pp_print_break fmt 0 0 + +(** Compute the names for all the pure functions generated from a rust function + (forward function and backward functions). + *) +let extract_fun_def_register_names (ctx : extraction_ctx) + (def : pure_fun_translation) : extraction_ctx = + let fwd, back_ls = def in + (* Register the forward function name *) + let ctx = ctx_add_fun_def fwd ctx in + (* Register the backward functions' names *) + let ctx = + List.fold_left (fun ctx back -> ctx_add_fun_def back ctx) ctx back_ls + in + (* Return *) + ctx -- cgit v1.2.3