diff options
author | Son Ho | 2022-02-03 12:30:28 +0100 |
---|---|---|
committer | Son Ho | 2022-02-03 12:30:28 +0100 |
commit | 972ed4288ff1f489fcf03b4cdca847abcc55674e (patch) | |
tree | 10878047f977de423bb2f43af5c8cf72617e7631 | |
parent | 72a6a2830a257aad3e4da2d8a53ac07cd38e8f41 (diff) |
Make more progress on implementing function extraction
Diffstat (limited to '')
-rw-r--r-- | src/Collections.ml | 23 | ||||
-rw-r--r-- | src/ExtractToFStar.ml | 137 | ||||
-rw-r--r-- | src/PureToExtract.ml | 116 | ||||
-rw-r--r-- | src/Translate.ml | 2 |
4 files changed, 225 insertions, 53 deletions
diff --git a/src/Collections.ml b/src/Collections.ml index 125cab1f..2bfdd18b 100644 --- a/src/Collections.ml +++ b/src/Collections.ml @@ -41,9 +41,9 @@ module List = struct (** Iter and link the iterations. - Iterate over a list, but call a function between every two elements - (but not before the first element, and not after the last). - *) + Iterate over a list, but call a function between every two elements + (but not before the first element, and not after the last). + *) let iter_link (link : unit -> unit) (f : 'a -> unit) (ls : 'a list) : unit = let rec iter ls = match ls with @@ -55,6 +55,23 @@ module List = struct iter (y :: ls) in iter ls + + (** Fold and link the iterations. + + Similar to [iter_link] but for fold left operations. + *) + let fold_left_link (link : unit -> unit) (f : 'a -> 'b -> 'a) (init : 'a) + (ls : 'b list) : 'a = + let rec fold (acc : 'a) (ls : 'b list) : 'a = + match ls with + | [] -> acc + | [ x ] -> f acc x + | x :: y :: ls -> + let acc = f acc x in + link (); + fold acc (y :: ls) + in + fold init ls end module type OrderedType = sig diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index 6e09dfa6..fb781939 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -39,6 +39,21 @@ let fstar_keywords = "Type0"; ] +let fstar_assumed_adts : (assumed_ty * string) list = [ (Result, "result") ] + +let fstar_assumed_structs : (assumed_ty * string) list = [] + +let fstar_assumed_variants : (assumed_ty * VariantId.id * string) list = + [ (Result, result_return_id, "Return"); (Result, result_fail_id, "Fail") ] + +let fstar_names_map_init = + { + keywords = fstar_keywords; + assumed_adts = fstar_assumed_adts; + assumed_structs = fstar_assumed_structs; + assumed_variants = fstar_assumed_variants; + } + (** * [ctx]: we use the context to lookup type definitions, to retrieve type names. * This is used to compute variable names, when they have no basenames: in this @@ -121,6 +136,10 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) = type_name_to_camel_case def_name ^ variant else variant in + let struct_constructor (basename : name) : string = + let tname = type_name basename in + "Mk" ^ tname + in (* For now, we treat only the case where function names are of the * form: `function` (no module prefix) *) @@ -182,6 +201,7 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) = str_name = "string"; field_name; variant_name; + struct_constructor; type_name; fun_name; var_basename; @@ -283,7 +303,7 @@ let extract_type_def_struct_body (ctx : extraction_ctx) (fmt : F.formatter) F.pp_open_hvbox fmt 0; (* Print the fields *) let print_field (field_id : FieldId.id) (f : field) : unit = - let field_name = ctx_get_field def.def_id field_id ctx in + let field_name = ctx_get_field (AdtId def.def_id) field_id ctx in F.pp_open_box fmt ctx.indent_incr; F.pp_print_string fmt field_name; F.pp_print_space fmt (); @@ -339,7 +359,7 @@ let extract_type_def_enum_body (ctx : extraction_ctx) (fmt : F.formatter) *) (* Print the variants *) let print_variant (variant_id : VariantId.id) (variant : variant) : unit = - let variant_name = ctx_get_variant def.def_id variant_id ctx in + let variant_name = ctx_get_variant (AdtId def.def_id) variant_id ctx in F.pp_print_space fmt (); F.pp_open_hvbox fmt ctx.indent_incr; (* variant box *) @@ -469,10 +489,74 @@ let extract_fun_def_register_names (ctx : extraction_ctx) (* Return *) ctx -(** [inside]: see [extract_ty] *) +(** The following function factorizes the extraction of ADT values. + + Note that lvalues can introduce new variables: we thus return an extraction + context updated with new bindings. + *) +let extract_adt_g_value + (extract_value : extraction_ctx -> bool -> 'v -> extraction_ctx) + (fmt : F.formatter) (ctx : extraction_ctx) (inside : bool) + (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) : + extraction_ctx = + match ty with + | Adt (Tuple, _) -> + (* Tuple *) + F.pp_print_string fmt "("; + let ctx = + Collections.List.fold_left_link + (fun () -> + F.pp_print_string fmt ","; + F.pp_print_space fmt ()) + (fun ctx v -> extract_value ctx false v) + ctx field_values + in + F.pp_print_string fmt ")"; + ctx + | Adt (adt_id, _) -> + (* "Regular" ADT *) + (* We print something of the form: `Cons field0 ... fieldn`. + * We could update the code to print something of the form: + * `{ field0=...; ...; fieldn=...; }` in case of structures. + *) + let adt_ident = + match variant_id with + | Some vid -> ctx_get_variant adt_id vid ctx + | None -> ctx_get_struct adt_id ctx + in + if inside && field_values <> [] then F.pp_print_string fmt "("; + let ctx = + Collections.List.fold_left_link + (fun () -> F.pp_print_space fmt ()) + (fun ctx v -> extract_value ctx true v) + ctx field_values + in + if inside && field_values <> [] then F.pp_print_string fmt ")"; + ctx + | _ -> failwith "Inconsistent typed value" + +(** [inside]: see [extract_ty]. + + As an lvalue can introduce new variables, we return an extraction context + updated with new bindings. + *) let rec extract_typed_lvalue (ctx : extraction_ctx) (fmt : F.formatter) - (inside : bool) (v : typed_lvalue) : unit = - raise Unimplemented + (inside : bool) (v : typed_lvalue) : extraction_ctx = + match v.value with + | LvVar (Var (v, _)) -> + let vname = + ctx.fmt.var_basename ctx.names_map.names_set v.basename v.ty + in + let ctx, vname = ctx_add_var vname v.id ctx in + F.pp_print_string fmt vname; + ctx + | LvVar Dummy -> + F.pp_print_string fmt "_"; + ctx + | LvAdt av -> + let extract_value ctx inside v = extract_typed_lvalue ctx fmt inside v in + extract_adt_g_value extract_value fmt ctx inside av.variant_id + av.field_values v.ty (** [inside]: see [extract_ty] *) let rec extract_typed_rvalue (ctx : extraction_ctx) (fmt : F.formatter) @@ -493,17 +577,9 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter) (qualif : fun_def_qualif) (def : fun_def) : unit = (* Retrieve the function name *) let def_name = ctx_get_local_function def.def_id def.back_id ctx in - (* Add the parameters - note that we need those bindings only for the + (* Add the type parameters - note that we need those bindings only for the * body translation (they are not top-level) *) - let ctx_body, type_params = - ctx_add_type_params def.signature.type_params ctx - in - (* Note that some of the input parameters might not be used, in which case - * they could be ignored (they will be printed as `_` and thus won't appear, - * but if we don't ignore them they will still be used to check for name - * clashes, and will have an influence on the computation of indices for - * the local variables). This is mostly a detail, though. *) - let ctx_body, input_params = ctx_add_vars def.inputs ctx in + let ctx, type_params = ctx_add_type_params def.signature.type_params ctx in (* Add a break before *) F.pp_print_break fmt 0 0; (* Print a comment to link the extracted type to its original rust definition *) @@ -527,25 +603,28 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "("; List.iter (fun (p : type_var) -> - let pname = ctx_get_type_var p.index ctx_body in + let pname = ctx_get_type_var p.index ctx in F.pp_print_string fmt pname; F.pp_print_space fmt ()) def.signature.type_params; F.pp_print_string fmt ":"; F.pp_print_space fmt (); F.pp_print_string fmt "Type0)"); - (* The input parameters *) - List.iter - (fun (lv : typed_lvalue) -> - F.pp_print_string fmt "("; - extract_typed_lvalue ctx_body fmt false lv; - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - extract_ty ctx_body fmt false lv.ty; - F.pp_print_string fmt ")"; - F.pp_print_space fmt ()) - def.inputs_lvs; + (* The input parameters - note that doing this adds bindings in the context *) + let ctx = + List.fold_left + (fun ctx (lv : typed_lvalue) -> + F.pp_print_string fmt "("; + let ctx = extract_typed_lvalue ctx fmt false lv in + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_ty ctx fmt false lv.ty; + F.pp_print_string fmt ")"; + F.pp_print_space fmt (); + ctx) + ctx def.inputs_lvs + in (* Print the "=" *) F.pp_print_space fmt (); F.pp_print_string fmt "="; @@ -555,7 +634,7 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter) (* Open a box for the body *) F.pp_open_hvbox fmt 0; (* Extract the body *) - extract_texpression ctx_body fmt false def.body; + extract_texpression ctx fmt false def.body; F.pp_close_box fmt (); (* Close the box for the body *) (* Close the box for the definition *) diff --git a/src/PureToExtract.ml b/src/PureToExtract.ml index be489952..02c507ef 100644 --- a/src/PureToExtract.ml +++ b/src/PureToExtract.ml @@ -52,6 +52,18 @@ type name_formatter = { - type name - variant name *) + struct_constructor : name -> string; + (** Structure constructors are used when constructing structure values. + + For instance, in F*: + ``` + type pair = { x : nat; y : nat } + let p : pair = Mkpair 0 1 + ``` + + Inputs: + - type name + *) type_name : name -> string; (** Provided a basename, compute a type name. *) fun_name : A.fun_id -> name -> int -> region_group_info option -> string; (** Inputs: @@ -165,11 +177,21 @@ let compute_fun_def_name (ctx : trans_ctx) (fmt : name_formatter) type id = | FunId of A.fun_id * RegionGroupId.id option | TypeId of type_id - | VariantId of TypeDefId.id * VariantId.id + | StructId of type_id + (** We use this when we manipulate the names of the structure + constructors. + + For instance, in F*: + ``` + type pair = { x: nat; y : nat } + let p : pair = Mkpair 0 1 + ``` + *) + | VariantId of type_id * VariantId.id (** If often happens that variant names must be unique (it is the case in F* ) which is why we register them here. *) - | FieldId of TypeDefId.id * FieldId.id + | FieldId of type_id * FieldId.id (** If often happens that in the case of structures, the field names must be unique (it is the case in F* ) which is why we register them here. @@ -211,19 +233,6 @@ type names_map = { We use it for lookups (during the translation) and to check for name clashes. *) -(** Initialize a names map with a proper set of keywords/names coming from the - target language/prover. *) -let initialize_names_map (keywords : string list) : names_map = - let name_to_id = - StringMap.of_list (List.map (fun x -> (x, UnknownId)) keywords) - in - let names_set = StringSet.of_list keywords in - (* We initialize [id_to_name] as empty, because the id of a keyword is [UnknownId]. - * Also note that we don't need this mapping for keywords: we insert keywords only - * to check collisions. *) - let id_to_name = IdMap.empty in - { id_to_name; name_to_id; names_set } - let names_map_add (id : id) (name : string) (nm : names_map) : names_map = (* Sanity check: no clashes *) assert (not (StringSet.mem name nm.names_set)); @@ -233,6 +242,18 @@ let names_map_add (id : id) (name : string) (nm : names_map) : names_map = let names_set = StringSet.add name nm.names_set in { id_to_name; name_to_id; names_set } +let names_map_add_assumed_type (id : assumed_ty) (name : string) + (nm : names_map) : names_map = + names_map_add (TypeId (Assumed id)) name nm + +let names_map_add_assumed_struct (id : assumed_ty) (name : string) + (nm : names_map) : names_map = + names_map_add (StructId (Assumed id)) name nm + +let names_map_add_assumed_variant (id : assumed_ty) (variant_id : VariantId.id) + (name : string) (nm : names_map) : names_map = + names_map_add (VariantId (Assumed id, variant_id)) name nm + (* TODO: remove those functions? We use the ones of extraction_ctx *) let names_map_get (id : id) (nm : names_map) : string = IdMap.find id nm.id_to_name @@ -311,17 +332,23 @@ let ctx_get_type (id : type_id) (ctx : extraction_ctx) : string = let ctx_get_local_type (id : TypeDefId.id) (ctx : extraction_ctx) : string = ctx_get_type (AdtId id) ctx +let ctx_get_assumed_type (id : assumed_ty) (ctx : extraction_ctx) : string = + ctx_get_type (Assumed id) ctx + let ctx_get_var (id : VarId.id) (ctx : extraction_ctx) : string = ctx_get (VarId id) ctx let ctx_get_type_var (id : TypeVarId.id) (ctx : extraction_ctx) : string = ctx_get (TypeVarId id) ctx -let ctx_get_field (def_id : TypeDefId.id) (field_id : FieldId.id) +let ctx_get_field (type_id : type_id) (field_id : FieldId.id) (ctx : extraction_ctx) : string = - ctx_get (FieldId (def_id, field_id)) ctx + ctx_get (FieldId (type_id, field_id)) ctx + +let ctx_get_struct (def_id : type_id) (ctx : extraction_ctx) : string = + ctx_get (StructId def_id) ctx -let ctx_get_variant (def_id : TypeDefId.id) (variant_id : VariantId.id) +let ctx_get_variant (def_id : type_id) (variant_id : VariantId.id) (ctx : extraction_ctx) : string = ctx_get (VariantId (def_id, variant_id)) ctx @@ -366,6 +393,12 @@ let ctx_add_type_params (vars : type_var list) (ctx : extraction_ctx) : (fun ctx (var : type_var) -> ctx_add_type_var var.name var.index ctx) ctx vars +let ctx_add_type_def_struct (def : type_def) (ctx : extraction_ctx) : + extraction_ctx * string = + let cons_name = ctx.fmt.struct_constructor def.name in + let ctx = ctx_add (StructId (AdtId def.def_id)) cons_name ctx in + (ctx, cons_name) + let ctx_add_type_def (def : type_def) (ctx : extraction_ctx) : extraction_ctx * string = let def_name = ctx.fmt.type_name def.name in @@ -375,7 +408,7 @@ let ctx_add_type_def (def : type_def) (ctx : extraction_ctx) : let ctx_add_field (def : type_def) (field_id : FieldId.id) (field : field) (ctx : extraction_ctx) : extraction_ctx * string = let name = ctx.fmt.field_name def.name field_id field.field_name in - let ctx = ctx_add (FieldId (def.def_id, field_id)) name ctx in + let ctx = ctx_add (FieldId (AdtId def.def_id, field_id)) name ctx in (ctx, name) let ctx_add_fields (def : type_def) (fields : (FieldId.id * field) list) @@ -387,7 +420,7 @@ let ctx_add_fields (def : type_def) (fields : (FieldId.id * field) list) let ctx_add_variant (def : type_def) (variant_id : VariantId.id) (variant : variant) (ctx : extraction_ctx) : extraction_ctx * string = let name = ctx.fmt.variant_name def.name variant.variant_name in - let ctx = ctx_add (VariantId (def.def_id, variant_id)) name ctx in + let ctx = ctx_add (VariantId (AdtId def.def_id, variant_id)) name ctx in (ctx, name) let ctx_add_variants (def : type_def) (variants : (VariantId.id * variant) list) @@ -421,3 +454,46 @@ let ctx_add_fun_def (def : fun_def) (ctx : extraction_ctx) : extraction_ctx = let name = ctx.fmt.fun_name def_id def.basename num_rgs rg_info in let ctx = ctx_add (FunId (def_id, def.back_id)) name ctx in ctx + +type names_map_init = { + keywords : string list; + assumed_adts : (assumed_ty * string) list; + assumed_structs : (assumed_ty * string) list; + assumed_variants : (assumed_ty * VariantId.id * string) list; +} + +(** Initialize a names map with a proper set of keywords/names coming from the + target language/prover. *) +let initialize_names_map (init : names_map_init) : names_map = + let name_to_id = + StringMap.of_list (List.map (fun x -> (x, UnknownId)) init.keywords) + in + let names_set = StringSet.of_list init.keywords in + (* We fist initialize [id_to_name] as empty, because the id of a keyword is [UnknownId]. + * Also note that we don't need this mapping for keywords: we insert keywords only + * to check collisions. *) + let id_to_name = IdMap.empty in + let nm = { id_to_name; name_to_id; names_set } in + (* Then we add: + * - the assumed types + * - the assumed struct constructors + * - the assumed variants + *) + let nm = + List.fold_left + (fun nm (type_id, name) -> names_map_add_assumed_type type_id name nm) + nm init.assumed_adts + in + let nm = + List.fold_left + (fun nm (type_id, name) -> names_map_add_assumed_struct type_id name nm) + nm init.assumed_structs + in + let nm = + List.fold_left + (fun nm (type_id, variant_id, name) -> + names_map_add_assumed_variant type_id variant_id name nm) + nm init.assumed_variants + in + (* Return *) + nm diff --git a/src/Translate.ml b/src/Translate.ml index cff814f4..efaa43d6 100644 --- a/src/Translate.ml +++ b/src/Translate.ml @@ -267,7 +267,7 @@ let translate_module (filename : string) (config : C.partial_config) (* Initialize the extraction context - for now we extract only to F* *) let names_map = - PureToExtract.initialize_names_map ExtractToFStar.fstar_keywords + PureToExtract.initialize_names_map ExtractToFStar.fstar_names_map_init in let variant_concatenate_type_name = true in let fstar_fmt = |