summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2022-02-03 12:30:28 +0100
committerSon Ho2022-02-03 12:30:28 +0100
commit972ed4288ff1f489fcf03b4cdca847abcc55674e (patch)
tree10878047f977de423bb2f43af5c8cf72617e7631
parent72a6a2830a257aad3e4da2d8a53ac07cd38e8f41 (diff)
Make more progress on implementing function extraction
Diffstat (limited to '')
-rw-r--r--src/Collections.ml23
-rw-r--r--src/ExtractToFStar.ml137
-rw-r--r--src/PureToExtract.ml116
-rw-r--r--src/Translate.ml2
4 files changed, 225 insertions, 53 deletions
diff --git a/src/Collections.ml b/src/Collections.ml
index 125cab1f..2bfdd18b 100644
--- a/src/Collections.ml
+++ b/src/Collections.ml
@@ -41,9 +41,9 @@ module List = struct
(** Iter and link the iterations.
- Iterate over a list, but call a function between every two elements
- (but not before the first element, and not after the last).
- *)
+ Iterate over a list, but call a function between every two elements
+ (but not before the first element, and not after the last).
+ *)
let iter_link (link : unit -> unit) (f : 'a -> unit) (ls : 'a list) : unit =
let rec iter ls =
match ls with
@@ -55,6 +55,23 @@ module List = struct
iter (y :: ls)
in
iter ls
+
+ (** Fold and link the iterations.
+
+ Similar to [iter_link] but for fold left operations.
+ *)
+ let fold_left_link (link : unit -> unit) (f : 'a -> 'b -> 'a) (init : 'a)
+ (ls : 'b list) : 'a =
+ let rec fold (acc : 'a) (ls : 'b list) : 'a =
+ match ls with
+ | [] -> acc
+ | [ x ] -> f acc x
+ | x :: y :: ls ->
+ let acc = f acc x in
+ link ();
+ fold acc (y :: ls)
+ in
+ fold init ls
end
module type OrderedType = sig
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml
index 6e09dfa6..fb781939 100644
--- a/src/ExtractToFStar.ml
+++ b/src/ExtractToFStar.ml
@@ -39,6 +39,21 @@ let fstar_keywords =
"Type0";
]
+let fstar_assumed_adts : (assumed_ty * string) list = [ (Result, "result") ]
+
+let fstar_assumed_structs : (assumed_ty * string) list = []
+
+let fstar_assumed_variants : (assumed_ty * VariantId.id * string) list =
+ [ (Result, result_return_id, "Return"); (Result, result_fail_id, "Fail") ]
+
+let fstar_names_map_init =
+ {
+ keywords = fstar_keywords;
+ assumed_adts = fstar_assumed_adts;
+ assumed_structs = fstar_assumed_structs;
+ assumed_variants = fstar_assumed_variants;
+ }
+
(**
* [ctx]: we use the context to lookup type definitions, to retrieve type names.
* This is used to compute variable names, when they have no basenames: in this
@@ -121,6 +136,10 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) =
type_name_to_camel_case def_name ^ variant
else variant
in
+ let struct_constructor (basename : name) : string =
+ let tname = type_name basename in
+ "Mk" ^ tname
+ in
(* For now, we treat only the case where function names are of the
* form: `function` (no module prefix)
*)
@@ -182,6 +201,7 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) =
str_name = "string";
field_name;
variant_name;
+ struct_constructor;
type_name;
fun_name;
var_basename;
@@ -283,7 +303,7 @@ let extract_type_def_struct_body (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_open_hvbox fmt 0;
(* Print the fields *)
let print_field (field_id : FieldId.id) (f : field) : unit =
- let field_name = ctx_get_field def.def_id field_id ctx in
+ let field_name = ctx_get_field (AdtId def.def_id) field_id ctx in
F.pp_open_box fmt ctx.indent_incr;
F.pp_print_string fmt field_name;
F.pp_print_space fmt ();
@@ -339,7 +359,7 @@ let extract_type_def_enum_body (ctx : extraction_ctx) (fmt : F.formatter)
*)
(* Print the variants *)
let print_variant (variant_id : VariantId.id) (variant : variant) : unit =
- let variant_name = ctx_get_variant def.def_id variant_id ctx in
+ let variant_name = ctx_get_variant (AdtId def.def_id) variant_id ctx in
F.pp_print_space fmt ();
F.pp_open_hvbox fmt ctx.indent_incr;
(* variant box *)
@@ -469,10 +489,74 @@ let extract_fun_def_register_names (ctx : extraction_ctx)
(* Return *)
ctx
-(** [inside]: see [extract_ty] *)
+(** The following function factorizes the extraction of ADT values.
+
+ Note that lvalues can introduce new variables: we thus return an extraction
+ context updated with new bindings.
+ *)
+let extract_adt_g_value
+ (extract_value : extraction_ctx -> bool -> 'v -> extraction_ctx)
+ (fmt : F.formatter) (ctx : extraction_ctx) (inside : bool)
+ (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) :
+ extraction_ctx =
+ match ty with
+ | Adt (Tuple, _) ->
+ (* Tuple *)
+ F.pp_print_string fmt "(";
+ let ctx =
+ Collections.List.fold_left_link
+ (fun () ->
+ F.pp_print_string fmt ",";
+ F.pp_print_space fmt ())
+ (fun ctx v -> extract_value ctx false v)
+ ctx field_values
+ in
+ F.pp_print_string fmt ")";
+ ctx
+ | Adt (adt_id, _) ->
+ (* "Regular" ADT *)
+ (* We print something of the form: `Cons field0 ... fieldn`.
+ * We could update the code to print something of the form:
+ * `{ field0=...; ...; fieldn=...; }` in case of structures.
+ *)
+ let adt_ident =
+ match variant_id with
+ | Some vid -> ctx_get_variant adt_id vid ctx
+ | None -> ctx_get_struct adt_id ctx
+ in
+ if inside && field_values <> [] then F.pp_print_string fmt "(";
+ let ctx =
+ Collections.List.fold_left_link
+ (fun () -> F.pp_print_space fmt ())
+ (fun ctx v -> extract_value ctx true v)
+ ctx field_values
+ in
+ if inside && field_values <> [] then F.pp_print_string fmt ")";
+ ctx
+ | _ -> failwith "Inconsistent typed value"
+
+(** [inside]: see [extract_ty].
+
+ As an lvalue can introduce new variables, we return an extraction context
+ updated with new bindings.
+ *)
let rec extract_typed_lvalue (ctx : extraction_ctx) (fmt : F.formatter)
- (inside : bool) (v : typed_lvalue) : unit =
- raise Unimplemented
+ (inside : bool) (v : typed_lvalue) : extraction_ctx =
+ match v.value with
+ | LvVar (Var (v, _)) ->
+ let vname =
+ ctx.fmt.var_basename ctx.names_map.names_set v.basename v.ty
+ in
+ let ctx, vname = ctx_add_var vname v.id ctx in
+ F.pp_print_string fmt vname;
+ ctx
+ | LvVar Dummy ->
+ F.pp_print_string fmt "_";
+ ctx
+ | LvAdt av ->
+ let extract_value ctx inside v = extract_typed_lvalue ctx fmt inside v in
+ extract_adt_g_value extract_value fmt ctx inside av.variant_id
+ av.field_values v.ty
(** [inside]: see [extract_ty] *)
let rec extract_typed_rvalue (ctx : extraction_ctx) (fmt : F.formatter)
@@ -493,17 +577,9 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter)
(qualif : fun_def_qualif) (def : fun_def) : unit =
(* Retrieve the function name *)
let def_name = ctx_get_local_function def.def_id def.back_id ctx in
- (* Add the parameters - note that we need those bindings only for the
+ (* Add the type parameters - note that we need those bindings only for the
* body translation (they are not top-level) *)
- let ctx_body, type_params =
- ctx_add_type_params def.signature.type_params ctx
- in
- (* Note that some of the input parameters might not be used, in which case
- * they could be ignored (they will be printed as `_` and thus won't appear,
- * but if we don't ignore them they will still be used to check for name
- * clashes, and will have an influence on the computation of indices for
- * the local variables). This is mostly a detail, though. *)
- let ctx_body, input_params = ctx_add_vars def.inputs ctx in
+ let ctx, type_params = ctx_add_type_params def.signature.type_params ctx in
(* Add a break before *)
F.pp_print_break fmt 0 0;
(* Print a comment to link the extracted type to its original rust definition *)
@@ -527,25 +603,28 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt "(";
List.iter
(fun (p : type_var) ->
- let pname = ctx_get_type_var p.index ctx_body in
+ let pname = ctx_get_type_var p.index ctx in
F.pp_print_string fmt pname;
F.pp_print_space fmt ())
def.signature.type_params;
F.pp_print_string fmt ":";
F.pp_print_space fmt ();
F.pp_print_string fmt "Type0)");
- (* The input parameters *)
- List.iter
- (fun (lv : typed_lvalue) ->
- F.pp_print_string fmt "(";
- extract_typed_lvalue ctx_body fmt false lv;
- F.pp_print_space fmt ();
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- extract_ty ctx_body fmt false lv.ty;
- F.pp_print_string fmt ")";
- F.pp_print_space fmt ())
- def.inputs_lvs;
+ (* The input parameters - note that doing this adds bindings in the context *)
+ let ctx =
+ List.fold_left
+ (fun ctx (lv : typed_lvalue) ->
+ F.pp_print_string fmt "(";
+ let ctx = extract_typed_lvalue ctx fmt false lv in
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ extract_ty ctx fmt false lv.ty;
+ F.pp_print_string fmt ")";
+ F.pp_print_space fmt ();
+ ctx)
+ ctx def.inputs_lvs
+ in
(* Print the "=" *)
F.pp_print_space fmt ();
F.pp_print_string fmt "=";
@@ -555,7 +634,7 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter)
(* Open a box for the body *)
F.pp_open_hvbox fmt 0;
(* Extract the body *)
- extract_texpression ctx_body fmt false def.body;
+ extract_texpression ctx fmt false def.body;
F.pp_close_box fmt ();
(* Close the box for the body *)
(* Close the box for the definition *)
diff --git a/src/PureToExtract.ml b/src/PureToExtract.ml
index be489952..02c507ef 100644
--- a/src/PureToExtract.ml
+++ b/src/PureToExtract.ml
@@ -52,6 +52,18 @@ type name_formatter = {
- type name
- variant name
*)
+ struct_constructor : name -> string;
+ (** Structure constructors are used when constructing structure values.
+
+ For instance, in F*:
+ ```
+ type pair = { x : nat; y : nat }
+ let p : pair = Mkpair 0 1
+ ```
+
+ Inputs:
+ - type name
+ *)
type_name : name -> string; (** Provided a basename, compute a type name. *)
fun_name : A.fun_id -> name -> int -> region_group_info option -> string;
(** Inputs:
@@ -165,11 +177,21 @@ let compute_fun_def_name (ctx : trans_ctx) (fmt : name_formatter)
type id =
| FunId of A.fun_id * RegionGroupId.id option
| TypeId of type_id
- | VariantId of TypeDefId.id * VariantId.id
+ | StructId of type_id
+ (** We use this when we manipulate the names of the structure
+ constructors.
+
+ For instance, in F*:
+ ```
+ type pair = { x: nat; y : nat }
+ let p : pair = Mkpair 0 1
+ ```
+ *)
+ | VariantId of type_id * VariantId.id
(** If often happens that variant names must be unique (it is the case in
F* ) which is why we register them here.
*)
- | FieldId of TypeDefId.id * FieldId.id
+ | FieldId of type_id * FieldId.id
(** If often happens that in the case of structures, the field names
must be unique (it is the case in F* ) which is why we register
them here.
@@ -211,19 +233,6 @@ type names_map = {
We use it for lookups (during the translation) and to check for name clashes.
*)
-(** Initialize a names map with a proper set of keywords/names coming from the
- target language/prover. *)
-let initialize_names_map (keywords : string list) : names_map =
- let name_to_id =
- StringMap.of_list (List.map (fun x -> (x, UnknownId)) keywords)
- in
- let names_set = StringSet.of_list keywords in
- (* We initialize [id_to_name] as empty, because the id of a keyword is [UnknownId].
- * Also note that we don't need this mapping for keywords: we insert keywords only
- * to check collisions. *)
- let id_to_name = IdMap.empty in
- { id_to_name; name_to_id; names_set }
-
let names_map_add (id : id) (name : string) (nm : names_map) : names_map =
(* Sanity check: no clashes *)
assert (not (StringSet.mem name nm.names_set));
@@ -233,6 +242,18 @@ let names_map_add (id : id) (name : string) (nm : names_map) : names_map =
let names_set = StringSet.add name nm.names_set in
{ id_to_name; name_to_id; names_set }
+let names_map_add_assumed_type (id : assumed_ty) (name : string)
+ (nm : names_map) : names_map =
+ names_map_add (TypeId (Assumed id)) name nm
+
+let names_map_add_assumed_struct (id : assumed_ty) (name : string)
+ (nm : names_map) : names_map =
+ names_map_add (StructId (Assumed id)) name nm
+
+let names_map_add_assumed_variant (id : assumed_ty) (variant_id : VariantId.id)
+ (name : string) (nm : names_map) : names_map =
+ names_map_add (VariantId (Assumed id, variant_id)) name nm
+
(* TODO: remove those functions? We use the ones of extraction_ctx *)
let names_map_get (id : id) (nm : names_map) : string =
IdMap.find id nm.id_to_name
@@ -311,17 +332,23 @@ let ctx_get_type (id : type_id) (ctx : extraction_ctx) : string =
let ctx_get_local_type (id : TypeDefId.id) (ctx : extraction_ctx) : string =
ctx_get_type (AdtId id) ctx
+let ctx_get_assumed_type (id : assumed_ty) (ctx : extraction_ctx) : string =
+ ctx_get_type (Assumed id) ctx
+
let ctx_get_var (id : VarId.id) (ctx : extraction_ctx) : string =
ctx_get (VarId id) ctx
let ctx_get_type_var (id : TypeVarId.id) (ctx : extraction_ctx) : string =
ctx_get (TypeVarId id) ctx
-let ctx_get_field (def_id : TypeDefId.id) (field_id : FieldId.id)
+let ctx_get_field (type_id : type_id) (field_id : FieldId.id)
(ctx : extraction_ctx) : string =
- ctx_get (FieldId (def_id, field_id)) ctx
+ ctx_get (FieldId (type_id, field_id)) ctx
+
+let ctx_get_struct (def_id : type_id) (ctx : extraction_ctx) : string =
+ ctx_get (StructId def_id) ctx
-let ctx_get_variant (def_id : TypeDefId.id) (variant_id : VariantId.id)
+let ctx_get_variant (def_id : type_id) (variant_id : VariantId.id)
(ctx : extraction_ctx) : string =
ctx_get (VariantId (def_id, variant_id)) ctx
@@ -366,6 +393,12 @@ let ctx_add_type_params (vars : type_var list) (ctx : extraction_ctx) :
(fun ctx (var : type_var) -> ctx_add_type_var var.name var.index ctx)
ctx vars
+let ctx_add_type_def_struct (def : type_def) (ctx : extraction_ctx) :
+ extraction_ctx * string =
+ let cons_name = ctx.fmt.struct_constructor def.name in
+ let ctx = ctx_add (StructId (AdtId def.def_id)) cons_name ctx in
+ (ctx, cons_name)
+
let ctx_add_type_def (def : type_def) (ctx : extraction_ctx) :
extraction_ctx * string =
let def_name = ctx.fmt.type_name def.name in
@@ -375,7 +408,7 @@ let ctx_add_type_def (def : type_def) (ctx : extraction_ctx) :
let ctx_add_field (def : type_def) (field_id : FieldId.id) (field : field)
(ctx : extraction_ctx) : extraction_ctx * string =
let name = ctx.fmt.field_name def.name field_id field.field_name in
- let ctx = ctx_add (FieldId (def.def_id, field_id)) name ctx in
+ let ctx = ctx_add (FieldId (AdtId def.def_id, field_id)) name ctx in
(ctx, name)
let ctx_add_fields (def : type_def) (fields : (FieldId.id * field) list)
@@ -387,7 +420,7 @@ let ctx_add_fields (def : type_def) (fields : (FieldId.id * field) list)
let ctx_add_variant (def : type_def) (variant_id : VariantId.id)
(variant : variant) (ctx : extraction_ctx) : extraction_ctx * string =
let name = ctx.fmt.variant_name def.name variant.variant_name in
- let ctx = ctx_add (VariantId (def.def_id, variant_id)) name ctx in
+ let ctx = ctx_add (VariantId (AdtId def.def_id, variant_id)) name ctx in
(ctx, name)
let ctx_add_variants (def : type_def) (variants : (VariantId.id * variant) list)
@@ -421,3 +454,46 @@ let ctx_add_fun_def (def : fun_def) (ctx : extraction_ctx) : extraction_ctx =
let name = ctx.fmt.fun_name def_id def.basename num_rgs rg_info in
let ctx = ctx_add (FunId (def_id, def.back_id)) name ctx in
ctx
+
+type names_map_init = {
+ keywords : string list;
+ assumed_adts : (assumed_ty * string) list;
+ assumed_structs : (assumed_ty * string) list;
+ assumed_variants : (assumed_ty * VariantId.id * string) list;
+}
+
+(** Initialize a names map with a proper set of keywords/names coming from the
+ target language/prover. *)
+let initialize_names_map (init : names_map_init) : names_map =
+ let name_to_id =
+ StringMap.of_list (List.map (fun x -> (x, UnknownId)) init.keywords)
+ in
+ let names_set = StringSet.of_list init.keywords in
+ (* We fist initialize [id_to_name] as empty, because the id of a keyword is [UnknownId].
+ * Also note that we don't need this mapping for keywords: we insert keywords only
+ * to check collisions. *)
+ let id_to_name = IdMap.empty in
+ let nm = { id_to_name; name_to_id; names_set } in
+ (* Then we add:
+ * - the assumed types
+ * - the assumed struct constructors
+ * - the assumed variants
+ *)
+ let nm =
+ List.fold_left
+ (fun nm (type_id, name) -> names_map_add_assumed_type type_id name nm)
+ nm init.assumed_adts
+ in
+ let nm =
+ List.fold_left
+ (fun nm (type_id, name) -> names_map_add_assumed_struct type_id name nm)
+ nm init.assumed_structs
+ in
+ let nm =
+ List.fold_left
+ (fun nm (type_id, variant_id, name) ->
+ names_map_add_assumed_variant type_id variant_id name nm)
+ nm init.assumed_variants
+ in
+ (* Return *)
+ nm
diff --git a/src/Translate.ml b/src/Translate.ml
index cff814f4..efaa43d6 100644
--- a/src/Translate.ml
+++ b/src/Translate.ml
@@ -267,7 +267,7 @@ let translate_module (filename : string) (config : C.partial_config)
(* Initialize the extraction context - for now we extract only to F* *)
let names_map =
- PureToExtract.initialize_names_map ExtractToFStar.fstar_keywords
+ PureToExtract.initialize_names_map ExtractToFStar.fstar_names_map_init
in
let variant_concatenate_type_name = true in
let fstar_fmt =