From 6739ab801801519f118cbb992b04c57f77c0cd17 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 2 Feb 2022 22:59:24 +0100 Subject: Make minor modifications to extract mutually recursive types --- src/ExtractToFStar.ml | 118 ++++++++++++++++++++++++++++++++------------------ src/PureToExtract.ml | 26 +++++++++++ src/Translate.ml | 24 +++++++--- 3 files changed, 122 insertions(+), 46 deletions(-) (limited to 'src') diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index 26316bc4..a1b56964 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -7,6 +7,19 @@ open PureToExtract open StringUtils module F = Format +(** A qualifier for a type definition. + + Controls whether we should use `type ...` or `and ...` (for mutually + recursive datatypes). + *) +type type_def_qualif = Type | And + +(** A qualifier for function definitions. + + Controls whether we should use `let ...`, `let rec ...` or `and ...` + *) +type fun_def_qualif = Let | LetRec | And + (** A list of keywords/identifiers used in F* and with which we want to check collision. *) let fstar_keywords = @@ -78,18 +91,20 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) = | U64 -> "u64" | U128 -> "u128" in - (* For now, we treat only the case where type and function names are of the - * form: `Module::Type` and `Module:function`. + (* For now, we treat only the case where type names are of the + * form: `Module::Type` *) - let get_name (name : name) : string = - match name with [ _module; name ] -> name | _ -> failwith "Unexpected" + let get_type_name (name : name) : string = + match name with + | [ _module; name ] -> name + | _ -> failwith ("Unexpected name shape: " ^ Print.name_to_string name) in let type_name_to_camel_case name = - let name = get_name name in + let name = get_type_name name in to_camel_case name in let type_name_to_snake_case name = - let name = get_name name in + let name = get_type_name name in to_snake_case name in let type_name name = type_name_to_snake_case name ^ "_t" in @@ -106,12 +121,17 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) = type_name_to_camel_case def_name ^ variant else variant in - (* For now, we only treat the case where the type name is: - * `Module::Type` + (* For now, we treat only the case where function names are of the + * form: `function` (no module prefix) *) + let get_fun_name (name : name) : string = + match name with + | [ name ] -> name + | _ -> failwith ("Unexpected name shape: " ^ Print.name_to_string name) + in let fun_name (_fid : A.fun_id) (fname : name) (num_rgs : int) (rg : region_group_info option) : string = - let fname = get_name fname in + let fname = get_fun_name fname in (* Converting to snake case should be a no-op, but it doesn't cost much *) let fname = to_snake_case fname in (* Compute the suffix *) @@ -139,7 +159,7 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) = let def = TypeDefId.Map.find adt_id ctx.type_context.type_defs in - StringUtils.string_of_chars [ (get_name def.name).[0] ]) + StringUtils.string_of_chars [ (get_type_name def.name).[0] ]) | TypeVar _ -> "x" (* lacking imagination here... *) | Bool -> "b" | Char -> "c" @@ -196,6 +216,34 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) | Str -> F.pp_print_string fmt ctx.fmt.str_name | Array _ | Slice _ -> raise Unimplemented +(** Compute the names for all the top-level identifiers used in a type + definition (type name, variant names, field names, etc. but not type + parameters). + + We need to do this preemptively, beforce extracting any definition, + because of recursive definitions. + *) +let extract_type_def_register_names (ctx : extraction_ctx) (def : type_def) : + extraction_ctx = + (* Compute and register the type def name *) + let ctx, def_name = ctx_add_type_def def ctx in + (* Compute and register: + * - the variant names, if this is an enumeration + * - the field names, if this is a structure + *) + let ctx = + match def.kind with + | Struct fields -> + fst (ctx_add_fields def (FieldId.mapi (fun id f -> (id, f)) fields) ctx) + | Enum variants -> + fst + (ctx_add_variants def + (VariantId.mapi (fun id v -> (id, v)) variants) + ctx) + in + (* Return *) + ctx + let extract_type_def_struct_body (ctx : extraction_ctx) (fmt : F.formatter) (def : type_def) (fields : field list) : unit = (* We want to generate a definition which looks like this: @@ -354,41 +402,13 @@ let extract_type_def_enum_body (ctx : extraction_ctx) (fmt : F.formatter) let variants = VariantId.mapi (fun vid v -> (vid, v)) variants in List.iter (fun (vid, v) -> print_variant vid v) variants -(** Compute the names for all the top-level identifiers used in a type - definition (type name, variant names, field names, etc. but not type - parameters). - - We need to do this preemptively, beforce extracting any definition, - because of recursive definitions. - *) -let extract_type_def_register_names (ctx : extraction_ctx) (def : type_def) : - extraction_ctx = - (* Compute and register the type def name *) - let ctx, def_name = ctx_add_type_def def ctx in - (* Compute and register: - * - the variant names, if this is an enumeration - * - the field names, if this is a structure - *) - let ctx = - match def.kind with - | Struct fields -> - fst (ctx_add_fields def (FieldId.mapi (fun id f -> (id, f)) fields) ctx) - | Enum variants -> - fst - (ctx_add_variants def - (VariantId.mapi (fun id v -> (id, v)) variants) - ctx) - in - (* Return *) - ctx - (** Extract a type definition. Note that all the names used for extraction should already have been registered. *) -let extract_type_def (ctx : extraction_ctx) (fmt : F.formatter) (def : type_def) - : unit = +let extract_type_def (ctx : extraction_ctx) (fmt : F.formatter) + (qualif : type_def_qualif) (def : type_def) : unit = (* Retrieve the definition name *) let def_name = ctx_get_local_type def.def_id ctx in (* Add the type params - note that we remember those bindings only for the @@ -406,7 +426,8 @@ let extract_type_def (ctx : extraction_ctx) (fmt : F.formatter) (def : type_def) (* Open a box for "type TYPE_NAME (TYPE_PARAMS) =" *) F.pp_open_hovbox fmt ctx.indent_incr; (* > "type TYPE_NAME" *) - F.pp_print_string fmt ("type " ^ def_name); + let qualif = match qualif with Type -> "type" | And -> "and" in + F.pp_print_string fmt (qualif ^ " " ^ def_name); (* Print the type parameters *) if def.type_params <> [] then ( F.pp_print_space fmt (); @@ -433,3 +454,18 @@ let extract_type_def (ctx : extraction_ctx) (fmt : F.formatter) (def : type_def) F.pp_close_box fmt (); (* Add breaks to insert new lines between definitions *) F.pp_print_break fmt 0 0 + +(** Compute the names for all the pure functions generated from a rust function + (forward function and backward functions). + *) +let extract_fun_def_register_names (ctx : extraction_ctx) + (def : pure_fun_translation) : extraction_ctx = + let fwd, back_ls = def in + (* Register the forward function name *) + let ctx = ctx_add_fun_def fwd ctx in + (* Register the backward functions' names *) + let ctx = + List.fold_left (fun ctx back -> ctx_add_fun_def back ctx) ctx back_ls + in + (* Return *) + ctx diff --git a/src/PureToExtract.ml b/src/PureToExtract.ml index c36ed8fe..226f178a 100644 --- a/src/PureToExtract.ml +++ b/src/PureToExtract.ml @@ -386,3 +386,29 @@ let ctx_add_variants (def : type_def) (variants : (VariantId.id * variant) list) List.fold_left_map (fun ctx (vid, v) -> ctx_add_variant def vid v ctx) ctx variants + +let ctx_add_fun_def (def : fun_def) (ctx : extraction_ctx) : extraction_ctx = + (* Lookup the CFIM def to compute the region group information *) + let def_id = def.def_id in + let cfim_def = FunDefId.Map.find def_id ctx.trans_ctx.fun_context.fun_defs in + let sg = cfim_def.signature in + let num_rgs = List.length sg.regions_hierarchy in + let rg_info = + match def.back_id with + | None -> None + | Some rg_id -> + let rg = T.RegionGroupId.nth sg.regions_hierarchy rg_id in + let regions = + List.map + (fun rid -> T.RegionVarId.nth sg.region_params rid) + rg.regions + in + let region_names = + List.map (fun (r : T.region_var) -> r.name) regions + in + Some { id = rg_id; region_names } + in + let def_id = A.Local def_id in + let name = ctx.fmt.fun_name def_id def.basename num_rgs rg_info in + let ctx = ctx_add (FunId (def_id, def.back_id)) name ctx in + ctx diff --git a/src/Translate.ml b/src/Translate.ml index 75975704..b840b7bc 100644 --- a/src/Translate.ml +++ b/src/Translate.ml @@ -288,7 +288,12 @@ let translate_module (filename : string) (config : C.partial_config) extract_ctx trans_types in - (* TODO: register the functions *) + let extract_ctx = + List.fold_left + (fun extract_ctx def -> + ExtractToFStar.extract_fun_def_register_names extract_ctx def) + extract_ctx trans_funs + in (* Open the output file *) (* First compute the filename by replacing the extension and converting the @@ -340,18 +345,27 @@ let translate_module (filename : string) (config : C.partial_config) Format.pp_print_break fmt 0 0; (* Export the definition groups to the file, in the proper order *) - let export_type (id : Pure.TypeDefId.id) : unit = + let export_type (qualif : ExtractToFStar.type_def_qualif) + (id : Pure.TypeDefId.id) : unit = let def = Pure.TypeDefId.Map.find id trans_types in - ExtractToFStar.extract_type_def extract_ctx fmt def + ExtractToFStar.extract_type_def extract_ctx fmt qualif def in let export_function (id : Pure.FunDefId.id) : unit = (* TODO *) + (* let pure_defs = Pure.FunDefId.Map.find id trans_funs in *) () in let export_decl (decl : M.declaration_group) : unit = match decl with - | Type (NonRec id) -> export_type id - | Type (Rec ids) -> List.iter export_type ids + | Type (NonRec id) -> export_type ExtractToFStar.Type id + | Type (Rec ids) -> + List.iteri + (fun i id -> + let qualif = + if i = 0 then ExtractToFStar.Type else ExtractToFStar.And + in + export_type qualif id) + ids | Fun (NonRec id) -> export_function id | Fun (Rec ids) -> List.iter export_function ids in -- cgit v1.2.3