diff options
Diffstat (limited to '')
-rw-r--r-- | src/ExtractToFStar.ml | 137 |
1 files changed, 108 insertions, 29 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index 6e09dfa6..fb781939 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -39,6 +39,21 @@ let fstar_keywords = "Type0"; ] +let fstar_assumed_adts : (assumed_ty * string) list = [ (Result, "result") ] + +let fstar_assumed_structs : (assumed_ty * string) list = [] + +let fstar_assumed_variants : (assumed_ty * VariantId.id * string) list = + [ (Result, result_return_id, "Return"); (Result, result_fail_id, "Fail") ] + +let fstar_names_map_init = + { + keywords = fstar_keywords; + assumed_adts = fstar_assumed_adts; + assumed_structs = fstar_assumed_structs; + assumed_variants = fstar_assumed_variants; + } + (** * [ctx]: we use the context to lookup type definitions, to retrieve type names. * This is used to compute variable names, when they have no basenames: in this @@ -121,6 +136,10 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) = type_name_to_camel_case def_name ^ variant else variant in + let struct_constructor (basename : name) : string = + let tname = type_name basename in + "Mk" ^ tname + in (* For now, we treat only the case where function names are of the * form: `function` (no module prefix) *) @@ -182,6 +201,7 @@ let mk_name_formatter (ctx : trans_ctx) (variant_concatenate_type_name : bool) = str_name = "string"; field_name; variant_name; + struct_constructor; type_name; fun_name; var_basename; @@ -283,7 +303,7 @@ let extract_type_def_struct_body (ctx : extraction_ctx) (fmt : F.formatter) F.pp_open_hvbox fmt 0; (* Print the fields *) let print_field (field_id : FieldId.id) (f : field) : unit = - let field_name = ctx_get_field def.def_id field_id ctx in + let field_name = ctx_get_field (AdtId def.def_id) field_id ctx in F.pp_open_box fmt ctx.indent_incr; F.pp_print_string fmt field_name; F.pp_print_space fmt (); @@ -339,7 +359,7 @@ let extract_type_def_enum_body (ctx : extraction_ctx) (fmt : F.formatter) *) (* Print the variants *) let print_variant (variant_id : VariantId.id) (variant : variant) : unit = - let variant_name = ctx_get_variant def.def_id variant_id ctx in + let variant_name = ctx_get_variant (AdtId def.def_id) variant_id ctx in F.pp_print_space fmt (); F.pp_open_hvbox fmt ctx.indent_incr; (* variant box *) @@ -469,10 +489,74 @@ let extract_fun_def_register_names (ctx : extraction_ctx) (* Return *) ctx -(** [inside]: see [extract_ty] *) +(** The following function factorizes the extraction of ADT values. + + Note that lvalues can introduce new variables: we thus return an extraction + context updated with new bindings. + *) +let extract_adt_g_value + (extract_value : extraction_ctx -> bool -> 'v -> extraction_ctx) + (fmt : F.formatter) (ctx : extraction_ctx) (inside : bool) + (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) : + extraction_ctx = + match ty with + | Adt (Tuple, _) -> + (* Tuple *) + F.pp_print_string fmt "("; + let ctx = + Collections.List.fold_left_link + (fun () -> + F.pp_print_string fmt ","; + F.pp_print_space fmt ()) + (fun ctx v -> extract_value ctx false v) + ctx field_values + in + F.pp_print_string fmt ")"; + ctx + | Adt (adt_id, _) -> + (* "Regular" ADT *) + (* We print something of the form: `Cons field0 ... fieldn`. + * We could update the code to print something of the form: + * `{ field0=...; ...; fieldn=...; }` in case of structures. + *) + let adt_ident = + match variant_id with + | Some vid -> ctx_get_variant adt_id vid ctx + | None -> ctx_get_struct adt_id ctx + in + if inside && field_values <> [] then F.pp_print_string fmt "("; + let ctx = + Collections.List.fold_left_link + (fun () -> F.pp_print_space fmt ()) + (fun ctx v -> extract_value ctx true v) + ctx field_values + in + if inside && field_values <> [] then F.pp_print_string fmt ")"; + ctx + | _ -> failwith "Inconsistent typed value" + +(** [inside]: see [extract_ty]. + + As an lvalue can introduce new variables, we return an extraction context + updated with new bindings. + *) let rec extract_typed_lvalue (ctx : extraction_ctx) (fmt : F.formatter) - (inside : bool) (v : typed_lvalue) : unit = - raise Unimplemented + (inside : bool) (v : typed_lvalue) : extraction_ctx = + match v.value with + | LvVar (Var (v, _)) -> + let vname = + ctx.fmt.var_basename ctx.names_map.names_set v.basename v.ty + in + let ctx, vname = ctx_add_var vname v.id ctx in + F.pp_print_string fmt vname; + ctx + | LvVar Dummy -> + F.pp_print_string fmt "_"; + ctx + | LvAdt av -> + let extract_value ctx inside v = extract_typed_lvalue ctx fmt inside v in + extract_adt_g_value extract_value fmt ctx inside av.variant_id + av.field_values v.ty (** [inside]: see [extract_ty] *) let rec extract_typed_rvalue (ctx : extraction_ctx) (fmt : F.formatter) @@ -493,17 +577,9 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter) (qualif : fun_def_qualif) (def : fun_def) : unit = (* Retrieve the function name *) let def_name = ctx_get_local_function def.def_id def.back_id ctx in - (* Add the parameters - note that we need those bindings only for the + (* Add the type parameters - note that we need those bindings only for the * body translation (they are not top-level) *) - let ctx_body, type_params = - ctx_add_type_params def.signature.type_params ctx - in - (* Note that some of the input parameters might not be used, in which case - * they could be ignored (they will be printed as `_` and thus won't appear, - * but if we don't ignore them they will still be used to check for name - * clashes, and will have an influence on the computation of indices for - * the local variables). This is mostly a detail, though. *) - let ctx_body, input_params = ctx_add_vars def.inputs ctx in + let ctx, type_params = ctx_add_type_params def.signature.type_params ctx in (* Add a break before *) F.pp_print_break fmt 0 0; (* Print a comment to link the extracted type to its original rust definition *) @@ -527,25 +603,28 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "("; List.iter (fun (p : type_var) -> - let pname = ctx_get_type_var p.index ctx_body in + let pname = ctx_get_type_var p.index ctx in F.pp_print_string fmt pname; F.pp_print_space fmt ()) def.signature.type_params; F.pp_print_string fmt ":"; F.pp_print_space fmt (); F.pp_print_string fmt "Type0)"); - (* The input parameters *) - List.iter - (fun (lv : typed_lvalue) -> - F.pp_print_string fmt "("; - extract_typed_lvalue ctx_body fmt false lv; - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - extract_ty ctx_body fmt false lv.ty; - F.pp_print_string fmt ")"; - F.pp_print_space fmt ()) - def.inputs_lvs; + (* The input parameters - note that doing this adds bindings in the context *) + let ctx = + List.fold_left + (fun ctx (lv : typed_lvalue) -> + F.pp_print_string fmt "("; + let ctx = extract_typed_lvalue ctx fmt false lv in + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_ty ctx fmt false lv.ty; + F.pp_print_string fmt ")"; + F.pp_print_space fmt (); + ctx) + ctx def.inputs_lvs + in (* Print the "=" *) F.pp_print_space fmt (); F.pp_print_string fmt "="; @@ -555,7 +634,7 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter) (* Open a box for the body *) F.pp_open_hvbox fmt 0; (* Extract the body *) - extract_texpression ctx_body fmt false def.body; + extract_texpression ctx fmt false def.body; F.pp_close_box fmt (); (* Close the box for the body *) (* Close the box for the definition *) |