summaryrefslogtreecommitdiff
path: root/src/ExtractToFStar.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/ExtractToFStar.ml')
-rw-r--r--src/ExtractToFStar.ml137
1 files changed, 108 insertions, 29 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml
index 6e09dfa6..fb781939 100644
--- a/src/ExtractToFStar.ml
+++ b/src/ExtractToFStar.ml
@@ -39,6 +39,21 @@ let fstar_keywords =
"Type0";
]
+let fstar_assumed_adts : (assumed_ty * string) list = [ (Result, "result") ]
+
+let fstar_assumed_structs : (assumed_ty * string) list = []
+
+let fstar_assumed_variants : (assumed_ty * VariantId.id * string) list =
+ [ (Result, result_return_id, "Return"); (Result, result_fail_id, "Fail") ]
+
+let fstar_names_map_init =
+ {
+ keywords = fstar_keywords;
+ assumed_adts = fstar_assumed_adts;
+ assumed_structs = fstar_assumed_structs;
+ assumed_variants = fstar_assumed_variants;
+ }
+
(**
* [ctx]: we use the context to lookup type definitions, to retrieve type names.
* This is used to compute variable names, when they have no basenames: in this
@@ -121,6 +136,10 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) =
type_name_to_camel_case def_name ^ variant
else variant
in
+ let struct_constructor (basename : name) : string =
+ let tname = type_name basename in
+ "Mk" ^ tname
+ in
(* For now, we treat only the case where function names are of the
* form: `function` (no module prefix)
*)
@@ -182,6 +201,7 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) =
str_name = "string";
field_name;
variant_name;
+ struct_constructor;
type_name;
fun_name;
var_basename;
@@ -283,7 +303,7 @@ let extract_type_def_struct_body (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_open_hvbox fmt 0;
(* Print the fields *)
let print_field (field_id : FieldId.id) (f : field) : unit =
- let field_name = ctx_get_field def.def_id field_id ctx in
+ let field_name = ctx_get_field (AdtId def.def_id) field_id ctx in
F.pp_open_box fmt ctx.indent_incr;
F.pp_print_string fmt field_name;
F.pp_print_space fmt ();
@@ -339,7 +359,7 @@ let extract_type_def_enum_body (ctx : extraction_ctx) (fmt : F.formatter)
*)
(* Print the variants *)
let print_variant (variant_id : VariantId.id) (variant : variant) : unit =
- let variant_name = ctx_get_variant def.def_id variant_id ctx in
+ let variant_name = ctx_get_variant (AdtId def.def_id) variant_id ctx in
F.pp_print_space fmt ();
F.pp_open_hvbox fmt ctx.indent_incr;
(* variant box *)
@@ -469,10 +489,74 @@ let extract_fun_def_register_names (ctx : extraction_ctx)
(* Return *)
ctx
-(** [inside]: see [extract_ty] *)
+(** The following function factorizes the extraction of ADT values.
+
+ Note that lvalues can introduce new variables: we thus return an extraction
+ context updated with new bindings.
+ *)
+let extract_adt_g_value
+ (extract_value : extraction_ctx -> bool -> 'v -> extraction_ctx)
+ (fmt : F.formatter) (ctx : extraction_ctx) (inside : bool)
+ (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) :
+ extraction_ctx =
+ match ty with
+ | Adt (Tuple, _) ->
+ (* Tuple *)
+ F.pp_print_string fmt "(";
+ let ctx =
+ Collections.List.fold_left_link
+ (fun () ->
+ F.pp_print_string fmt ",";
+ F.pp_print_space fmt ())
+ (fun ctx v -> extract_value ctx false v)
+ ctx field_values
+ in
+ F.pp_print_string fmt ")";
+ ctx
+ | Adt (adt_id, _) ->
+ (* "Regular" ADT *)
+ (* We print something of the form: `Cons field0 ... fieldn`.
+ * We could update the code to print something of the form:
+ * `{ field0=...; ...; fieldn=...; }` in case of structures.
+ *)
+ let adt_ident =
+ match variant_id with
+ | Some vid -> ctx_get_variant adt_id vid ctx
+ | None -> ctx_get_struct adt_id ctx
+ in
+ if inside && field_values <> [] then F.pp_print_string fmt "(";
+ let ctx =
+ Collections.List.fold_left_link
+ (fun () -> F.pp_print_space fmt ())
+ (fun ctx v -> extract_value ctx true v)
+ ctx field_values
+ in
+ if inside && field_values <> [] then F.pp_print_string fmt ")";
+ ctx
+ | _ -> failwith "Inconsistent typed value"
+
+(** [inside]: see [extract_ty].
+
+ As an lvalue can introduce new variables, we return an extraction context
+ updated with new bindings.
+ *)
let rec extract_typed_lvalue (ctx : extraction_ctx) (fmt : F.formatter)
- (inside : bool) (v : typed_lvalue) : unit =
- raise Unimplemented
+ (inside : bool) (v : typed_lvalue) : extraction_ctx =
+ match v.value with
+ | LvVar (Var (v, _)) ->
+ let vname =
+ ctx.fmt.var_basename ctx.names_map.names_set v.basename v.ty
+ in
+ let ctx, vname = ctx_add_var vname v.id ctx in
+ F.pp_print_string fmt vname;
+ ctx
+ | LvVar Dummy ->
+ F.pp_print_string fmt "_";
+ ctx
+ | LvAdt av ->
+ let extract_value ctx inside v = extract_typed_lvalue ctx fmt inside v in
+ extract_adt_g_value extract_value fmt ctx inside av.variant_id
+ av.field_values v.ty
(** [inside]: see [extract_ty] *)
let rec extract_typed_rvalue (ctx : extraction_ctx) (fmt : F.formatter)
@@ -493,17 +577,9 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter)
(qualif : fun_def_qualif) (def : fun_def) : unit =
(* Retrieve the function name *)
let def_name = ctx_get_local_function def.def_id def.back_id ctx in
- (* Add the parameters - note that we need those bindings only for the
+ (* Add the type parameters - note that we need those bindings only for the
* body translation (they are not top-level) *)
- let ctx_body, type_params =
- ctx_add_type_params def.signature.type_params ctx
- in
- (* Note that some of the input parameters might not be used, in which case
- * they could be ignored (they will be printed as `_` and thus won't appear,
- * but if we don't ignore them they will still be used to check for name
- * clashes, and will have an influence on the computation of indices for
- * the local variables). This is mostly a detail, though. *)
- let ctx_body, input_params = ctx_add_vars def.inputs ctx in
+ let ctx, type_params = ctx_add_type_params def.signature.type_params ctx in
(* Add a break before *)
F.pp_print_break fmt 0 0;
(* Print a comment to link the extracted type to its original rust definition *)
@@ -527,25 +603,28 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt "(";
List.iter
(fun (p : type_var) ->
- let pname = ctx_get_type_var p.index ctx_body in
+ let pname = ctx_get_type_var p.index ctx in
F.pp_print_string fmt pname;
F.pp_print_space fmt ())
def.signature.type_params;
F.pp_print_string fmt ":";
F.pp_print_space fmt ();
F.pp_print_string fmt "Type0)");
- (* The input parameters *)
- List.iter
- (fun (lv : typed_lvalue) ->
- F.pp_print_string fmt "(";
- extract_typed_lvalue ctx_body fmt false lv;
- F.pp_print_space fmt ();
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- extract_ty ctx_body fmt false lv.ty;
- F.pp_print_string fmt ")";
- F.pp_print_space fmt ())
- def.inputs_lvs;
+ (* The input parameters - note that doing this adds bindings in the context *)
+ let ctx =
+ List.fold_left
+ (fun ctx (lv : typed_lvalue) ->
+ F.pp_print_string fmt "(";
+ let ctx = extract_typed_lvalue ctx fmt false lv in
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ extract_ty ctx fmt false lv.ty;
+ F.pp_print_string fmt ")";
+ F.pp_print_space fmt ();
+ ctx)
+ ctx def.inputs_lvs
+ in
(* Print the "=" *)
F.pp_print_space fmt ();
F.pp_print_string fmt "=";
@@ -555,7 +634,7 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter)
(* Open a box for the body *)
F.pp_open_hvbox fmt 0;
(* Extract the body *)
- extract_texpression ctx_body fmt false def.body;
+ extract_texpression ctx fmt false def.body;
F.pp_close_box fmt ();
(* Close the box for the body *)
(* Close the box for the definition *)