summaryrefslogtreecommitdiff
path: root/src/ExtractToFstar.ml
blob: 56a8c33821d1b791a351af544e77e0c962dedb76 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
(** Extract to F* *)

open Errors
open Pure
open TranslateCore
open PureToExtract
module F = Format

(** Iterate over a list, calling [between ()] exactly once between every two
    consecutive elements: never before the first element, never after the
    last one.

    For instance, [list_iterb sep f [a; b; c]] performs
    [f a; sep (); f b; sep (); f c]. *)
let list_iterb (between : unit -> unit) (f : 'a -> unit) (ls : 'a list) : unit =
  List.iteri
    (fun index elem ->
      (* Every element but the first is preceded by a separator *)
      if index > 0 then between ();
      f elem)
    ls

(** Extract a type to F*.

    [inside] controls whether we add parentheses around a type application
    (if [true], [t a b] is printed as [(t a b)]). It should be [true] when the
    type appears as an argument of another type application. A type
    constructor applied to no arguments is never parenthesized.

    Raises [Unimplemented] on arrays and slices. *)
let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
    (ty : ty) : unit =
  match ty with
  | Adt (type_id, tys) -> (
      match type_id with
      | Tuple ->
          (* In F*, [()] is a value, not a type: the empty tuple type is
             [unit]. Non-empty tuple types are written [(t1 & t2 & ...)] -
             juxtaposing the component types would be parsed as a type
             application. *)
          if tys = [] then F.pp_print_string fmt "unit"
          else (
            F.pp_print_string fmt "(";
            list_iterb
              (fun () ->
                F.pp_print_space fmt ();
                F.pp_print_string fmt "&";
                F.pp_print_space fmt ())
              (extract_ty ctx fmt true) tys;
            F.pp_print_string fmt ")")
      | AdtId _ | Assumed _ ->
          (* Only an actual application needs parentheses *)
          let print_paren = inside && tys <> [] in
          if print_paren then F.pp_print_string fmt "(";
          F.pp_print_string fmt (ctx_find_type type_id ctx);
          if tys <> [] then F.pp_print_space fmt ();
          list_iterb (F.pp_print_space fmt) (extract_ty ctx fmt true) tys;
          if print_paren then F.pp_print_string fmt ")")
  | TypeVar vid -> F.pp_print_string fmt (ctx_find_type_var vid ctx)
  | Bool -> F.pp_print_string fmt ctx.fmt.bool_name
  | Char -> F.pp_print_string fmt ctx.fmt.char_name
  | Integer int_ty -> F.pp_print_string fmt (ctx.fmt.int_name int_ty)
  | Str -> F.pp_print_string fmt ctx.fmt.str_name
  | Array _ | Slice _ -> raise Unimplemented

let extract_type_def_struct_body (ctx : extraction_ctx) (fmt : F.formatter)
    (type_name : string) (type_params : string list) (fields : field list) :
    unit =
  (* We want to generate a definition which looks like this:
   * ```
   * type s = { x : int; y : bool; }
   * ```
   *
   * Or, if there isn't enough space on one line, something like:
   * ```
   * type s =
   * {
   *   x : int;
   *   y : bool;
   * }
   * ```
   * The breaks before "{", before the first field and before "}" belong to
   * the box the caller has open, which decides whether they become newlines.
   * The fields themselves live in an hvbox of their own: either they are all
   * printed on one line, or each of them gets its own line.
   *
   * Note that we already printed: `type s =`
   *)
  match fields with
  | [] ->
      (* F* doesn't accept records without fields: extract to `unit` *)
      F.pp_print_space fmt ();
      F.pp_print_string fmt "unit"
  | _ ->
      F.pp_print_space fmt ();
      F.pp_print_string fmt "{";
      (* The fields are indented by [ctx.indent_incr] if we go vertical *)
      F.pp_print_break fmt 1 ctx.indent_incr;
      F.pp_open_hvbox fmt 0;
      (* > "FIELD_NAME : FIELD_TYPE;" *)
      let print_field (f : field) : unit =
        let field_name = ctx.fmt.field_name type_name f.field_name in
        F.pp_open_box fmt ctx.indent_incr;
        F.pp_print_string fmt field_name;
        F.pp_print_space fmt ();
        F.pp_print_string fmt ":";
        F.pp_print_space fmt ();
        extract_ty ctx fmt false f.field_ty;
        F.pp_print_string fmt ";";
        F.pp_close_box fmt ()
      in
      (* Separate the fields with breaks (the original code concatenated
         them with no separator at all) *)
      list_iterb (F.pp_print_space fmt) print_field fields;
      F.pp_close_box fmt ();
      F.pp_print_space fmt ();
      F.pp_print_string fmt "}"

(** Extract the body of an enumeration type definition (what comes after
    [type TYPE_NAME =]).

    Not supported yet: always raises [Unimplemented]. *)
let extract_type_def_enum_body (ctx : extraction_ctx) (fmt : F.formatter)
    (type_name : string) (type_params : string list) (variants : variant list) :
    unit =
  (* TODO: extract the variants. Until then, explicitly discard the
     arguments. *)
  ignore (ctx, fmt, type_name, type_params, variants);
  raise Unimplemented

(** Extract a type definition, registering its name (and the names of its
    variants, if it is an enumeration) in the context.

    Prints [type TYPE_NAME (a : Type0) ... = BODY], where the body is
    produced by [extract_type_def_struct_body] or
    [extract_type_def_enum_body].

    The returned context contains the newly registered names, but not the
    bindings for the type parameters: those are only used to extract the
    body. *)
let extract_type_def (ctx : extraction_ctx) (fmt : F.formatter)
    (def : type_def) : extraction_ctx =
  (* Compute and register the type def name *)
  let ctx, def_name = ctx_add_type_def def ctx in
  (* Compute and register the variant names, if this is an enumeration *)
  let ctx, _variant_names =
    match def.kind with
    | Struct _ -> (ctx, [])
    | Enum variants ->
        ctx_add_variants def (VariantId.mapi (fun id v -> (id, v)) variants) ctx
  in
  (* Add the type params - note that we remember those bindings only for the
   * body translation: the updated ctx we return at the end of the function
   * only contains the registered type def and variant names *)
  let ctx_body, type_params = ctx_add_type_params def.type_params ctx in
  (* Open a box for the whole definition, so that it gets printed on one line
     whenever possible *)
  F.pp_open_hvbox fmt 0;
  (* Open a box for "type TYPE_NAME (TYPE_PARAMS) =" *)
  F.pp_open_hovbox fmt ctx.indent_incr;
  (* > "type TYPE_NAME" *)
  F.pp_print_string fmt "type";
  F.pp_print_space fmt ();
  F.pp_print_string fmt def_name;
  (* > " (a : Type0) (b : Type0)" - the bodies expect the type parameters to
     be in scope *)
  List.iter
    (fun type_param ->
      F.pp_print_space fmt ();
      F.pp_print_string fmt ("(" ^ type_param ^ " : Type0)"))
    type_params;
  (* > " =" - the body extractors assume this has already been printed *)
  F.pp_print_space fmt ();
  F.pp_print_string fmt "=";
  F.pp_close_box fmt ();
  (match def.kind with
  | Struct fields ->
      extract_type_def_struct_body ctx_body fmt def_name type_params fields
  | Enum variants ->
      extract_type_def_enum_body ctx_body fmt def_name type_params variants);
  (* Close the box for the whole definition *)
  F.pp_close_box fmt ();
  ctx