summaryrefslogtreecommitdiff
path: root/src/Pure.ml
blob: b1c8e2541d63b5478d5cb3505bfff4108fb7fd02 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
open Identifiers
module T = Types
module V = Values
module E = Expressions
module A = CfimAst
module TypeDeclId = T.TypeDeclId
module TypeVarId = T.TypeVarId
module RegionGroupId = T.RegionGroupId
module VariantId = T.VariantId
module FieldId = T.FieldId
module SymbolicValueId = V.SymbolicValueId
module FunDeclId = A.FunDeclId

module SynthPhaseId = IdGen ()
(** Identifiers for the phases of the synthesis: every phase (the forward
    function, the backward function for region group 0, etc.) receives its
    own identifier. *)

module VarId = IdGen ()
(** Identifiers for the variables of the pure AST.

    Pay attention to the fact that [Values] also defines a [VarId] module:
    this one is a separate module, generated here by [IdGen ()]. *)

(** Alias for [Types.integer_type], re-exported with [show] and [ord]. *)
type integer_type = T.integer_type [@@deriving show, ord]

(** The assumed types for the pure AST.

    In comparison with CFIM:
    - we removed `Box` (because it is translated as the identity: `Box T == T`)
    - we added:
      - `Result`: the type used in the error monad. This allows us to have a
        unified treatment of expressions (especially when we have to unfold the
        monadic binds)
      - `State`: the type of the state, when using state-error monads. Note that
        this state is opaque to Aeneas (the user can define it, or leave it as
        assumed)
    - `Vec` and `Option` are also part of the assumed types.

    Note: the constructor order determines the derived [compare].
  *)
type assumed_ty = State | Result | Vec | Option [@@deriving show, ord]

(* TODO: we should never directly manipulate `Return` and `Fail`, but rather
 * the monadic functions `return` and `fail` (makes treatment of error and
 * state-error monads more uniform) *)

(** Variant identifiers of the [Result] assumed type: [Return] is variant 0
    and [Fail] is variant 1. *)
let result_return_id, result_fail_id = (VariantId.of_int 0, VariantId.of_int 1)

(** Variant identifiers of the [Option] assumed type, shared with [Types]. *)
let option_some_id, option_none_id = (T.option_some_id, T.option_none_id)

(** The head of an [Adt] type: a user-defined type declaration, a tuple, or
    one of the assumed types. *)
type type_id = AdtId of TypeDeclId.id | Tuple | Assumed of assumed_ty
[@@deriving show, ord]

(** Base class of the [iter_ty] visitor.

    The leaves of [ty] (type variable identifiers, type identifiers and
    integer types) are not traversed: visiting them does nothing. *)
class ['self] iter_ty_base =
  object (_ : 'self)
    inherit [_] VisitorsRuntime.iter

    method visit_integer_type : 'env -> integer_type -> unit =
      fun _env _ity -> ()

    method visit_type_id : 'env -> type_id -> unit = fun _env _tid -> ()

    method visit_id : 'env -> TypeVarId.id -> unit = fun _env _vid -> ()
  end

(** Base class of the [map_ty] visitor.

    The leaves of [ty] (type variable identifiers, type identifiers and
    integer types) are returned unchanged. *)
class ['self] map_ty_base =
  object (_ : 'self)
    inherit [_] VisitorsRuntime.map

    method visit_integer_type : 'env -> integer_type -> integer_type =
      fun _env ity -> ity

    method visit_type_id : 'env -> type_id -> type_id = fun _env tid -> tid

    method visit_id : 'env -> TypeVarId.id -> TypeVarId.id =
      fun _env vid -> vid
  end

(** Types of the pure AST.

    The [iter_ty] and [map_ty] visitors are derived for this type. Their
    ancestors [iter_ty_base] and [map_ty_base] provide the methods for the
    leaves (type variable ids, type ids and integer types), and themselves
    inherit from [VisitorsRuntime]. *)
type ty =
  | Adt of type_id * ty list
      (** [Adt] encodes ADTs and tuples and assumed types.

          TODO: what about the ended regions? ADTs may be parameterized
          with several region variables. When giving back an ADT value, we may
          be able to only give back part of the ADT. We need a way to encode
          such "partial" ADTs.
       *)
  | TypeVar of TypeVarId.id
  | Bool
  | Char
  | Integer of integer_type
  | Str
  | Array of ty (* TODO: this should be an assumed type?... *)
  | Slice of ty (* TODO: this should be an assumed type?... *)
  | Arrow of ty * ty
[@@deriving
  show,
    visitors
      {
        name = "iter_ty";
        variety = "iter";
        ancestors = [ "iter_ty_base" ];
        nude = true (* Don't inherit [VisitorsRuntime.iter]: the ancestor does *);
        concrete = true;
        polymorphic = false;
      },
    visitors
      {
        name = "map_ty";
        variety = "map";
        ancestors = [ "map_ty_base" ];
        nude = true (* Don't inherit [VisitorsRuntime.map]: the ancestor does *);
        concrete = true;
        polymorphic = false;
      }]

(** A field of a structure or of an enumeration variant.
    NOTE(review): [field_name] is presumably [None] for unnamed (tuple-like)
    fields; confirm. *)
type field = { field_name : string option; field_ty : ty } [@@deriving show]

(** A variant of an enumeration: its name and its fields. *)
type variant = { variant_name : string; fields : field list } [@@deriving show]

(** The body of a type declaration: a structure (list of fields) or an
    enumeration (list of variants). *)
type type_decl_kind = Struct of field list | Enum of variant list
[@@deriving show]

(** Alias for [Types.type_var]. *)
type type_var = T.type_var [@@deriving show]

(** A type declaration of the pure AST: its identifier, its name, its type
    parameters and its body ([kind]). *)
type type_decl = {
  def_id : TypeDeclId.id;
  name : name;
  type_params : type_var list;
  kind : type_decl_kind;
}
[@@deriving show]

(** Alias for [Values.scalar_value]. *)
type scalar_value = V.scalar_value [@@deriving show]

(** Alias for [Values.constant_value]. *)
type constant_value = V.constant_value [@@deriving show]

type var = {
  id : VarId.id;
  basename : string option;
      (** The "basename" is used to generate a meaningful name for the variable
          (by potentially adding an index to uniquely identify it).
       *)
  ty : ty;
}
[@@deriving show]
(** A variable of the pure AST.

    Because we introduce a lot of temporary variables, the list of variables
    is not fixed: we thus must carry all its information (identifier, name
    hint and type) with the variable itself.
 *)

(* TODO: we might want to redefine field_proj_kind here, to prevent field accesses
 * on enumerations.
 * Also: tuples... *)
(** One projection step: an access to the field [field_id]. The projection
    kind [pkind] is reused as-is from CFIM ([Expressions.field_proj_kind]). *)
type projection_elem = { pkind : E.field_proj_kind; field_id : FieldId.id }
[@@deriving show]

(** A projection: a sequence of field accesses. *)
type projection = projection_elem list [@@deriving show]

type mplace = { name : string option; projection : projection }
[@@deriving show]
(** "Meta" place.

    Meta-data retrieved from the symbolic execution, which gives provenance
    information about the values. We use this to generate names for the variables
    we introduce.

    NOTE(review): [name] is presumably the name of the variable the value
    comes from (if any), and [projection] the path inside it; confirm.
 *)

(** A place: the variable [var] followed by a [projection] into it. *)
type place = { var : VarId.id; projection : projection } [@@deriving show]

(** Base class of the [iter_var_or_dummy] visitor.

    Constant values, variables, places, meta-places and types are leaves:
    visiting them does nothing. The lvalue, rvalue and expression iterators
    all descend from [iter_var_or_dummy], so unless a subclass overrides these
    methods they do not look inside these leaves either. *)
class ['self] iter_value_base =
  object (_ : 'self)
    inherit [_] VisitorsRuntime.iter

    method visit_ty : 'env -> ty -> unit = fun _env _ty -> ()

    method visit_mplace : 'env -> mplace -> unit = fun _env _mp -> ()

    method visit_place : 'env -> place -> unit = fun _env _p -> ()

    method visit_var : 'env -> var -> unit = fun _env _v -> ()

    method visit_constant_value : 'env -> constant_value -> unit =
      fun _env _cv -> ()
  end

(** Base class of the [map_var_or_dummy] visitor.

    Constant values, variables, places, meta-places and types are leaves:
    they are returned unchanged. *)
class ['self] map_value_base =
  object (_ : 'self)
    inherit [_] VisitorsRuntime.map

    method visit_ty : 'env -> ty -> ty = fun _env t -> t

    method visit_mplace : 'env -> mplace -> mplace = fun _env mp -> mp

    method visit_place : 'env -> place -> place = fun _env p -> p

    method visit_var : 'env -> var -> var = fun _env v -> v

    method visit_constant_value : 'env -> constant_value -> constant_value =
      fun _env cv -> cv
  end

(** Base class of the [reduce_var_or_dummy] visitor.

    Every leaf (constant value, variable, place, meta-place, type) contributes
    [self#zero]. The class stays virtual: [zero] is left to subclasses. *)
class virtual ['self] reduce_value_base =
  object (self : 'self)
    inherit [_] VisitorsRuntime.reduce

    method visit_ty : 'env -> ty -> 'a = fun _env _t -> self#zero

    method visit_mplace : 'env -> mplace -> 'a = fun _env _mp -> self#zero

    method visit_place : 'env -> place -> 'a = fun _env _p -> self#zero

    method visit_var : 'env -> var -> 'a = fun _env _v -> self#zero

    method visit_constant_value : 'env -> constant_value -> 'a =
      fun _env _cv -> self#zero
  end

(** Base class of the [mapreduce_var_or_dummy] visitor.

    Every leaf (constant value, variable, place, meta-place, type) is returned
    unchanged, paired with [self#zero]. *)
class virtual ['self] mapreduce_value_base =
  object (self : 'self)
    inherit [_] VisitorsRuntime.mapreduce

    method visit_ty : 'env -> ty -> ty * 'a = fun _env t -> (t, self#zero)

    method visit_mplace : 'env -> mplace -> mplace * 'a =
      fun _env mp -> (mp, self#zero)

    method visit_place : 'env -> place -> place * 'a =
      fun _env p -> (p, self#zero)

    method visit_var : 'env -> var -> var * 'a = fun _env v -> (v, self#zero)

    method visit_constant_value :
        'env -> constant_value -> constant_value * 'a =
      fun _env cv -> (cv, self#zero)
  end

(** A binder in an lvalue: either a variable, with optional meta-place
    information (used to generate names), or the ignored value `_`.

    The derived visitors use the [*_value_base] classes as ancestors, which
    treat [var], [mplace], etc. as leaves. The reduce and mapreduce visitors
    are not [concrete] because their ancestors are virtual classes. *)
type var_or_dummy =
  | Var of var * mplace option
  | Dummy  (** Ignored value: `_`. *)
[@@deriving
  show,
    visitors
      {
        name = "iter_var_or_dummy";
        variety = "iter";
        ancestors = [ "iter_value_base" ];
        nude = true (* Don't inherit [VisitorsRuntime.iter] *);
        concrete = true;
        polymorphic = false;
      },
    visitors
      {
        name = "map_var_or_dummy";
        variety = "map";
        ancestors = [ "map_value_base" ];
        nude = true (* Don't inherit [VisitorsRuntime.map] *);
        concrete = true;
        polymorphic = false;
      },
    visitors
      {
        name = "reduce_var_or_dummy";
        variety = "reduce";
        ancestors = [ "reduce_value_base" ];
        nude = true (* Don't inherit [VisitorsRuntime.reduce] *);
        polymorphic = false;
      },
    visitors
      {
        name = "mapreduce_var_or_dummy";
        variety = "mapreduce";
        ancestors = [ "mapreduce_value_base" ];
        nude = true (* Don't inherit [VisitorsRuntime.mapreduce] *);
        polymorphic = false;
      }]

(** A left value: it appears on the left of assignments and let-bindings
    (see [Let]), and is also used as the pattern of match branches. *)
type lvalue =
  | LvConcrete of constant_value
      (** [LvConcrete] is necessary because we merge the switches over integer
          values and the matches over enumerations *)
  | LvVar of var_or_dummy
  | LvAdt of adt_lvalue

(** The deconstruction of an ADT value into its fields.
    NOTE(review): [variant_id] is presumably [None] for structures and tuples
    and [Some] for enumeration variants; confirm. *)
and adt_lvalue = {
  variant_id : (VariantId.id option[@opaque]);
  field_values : typed_lvalue list;
}

(** An lvalue together with its type. *)
and typed_lvalue = { value : lvalue; ty : ty }
[@@deriving
  show,
    visitors
      {
        name = "iter_typed_lvalue";
        variety = "iter";
        ancestors = [ "iter_var_or_dummy" ];
        nude = true (* Don't inherit [VisitorsRuntime.iter] *);
        concrete = true;
        polymorphic = false;
      },
    visitors
      {
        name = "map_typed_lvalue";
        variety = "map";
        ancestors = [ "map_var_or_dummy" ];
        nude = true (* Don't inherit [VisitorsRuntime.map] *);
        concrete = true;
        polymorphic = false;
      },
    visitors
      {
        name = "reduce_typed_lvalue";
        variety = "reduce";
        ancestors = [ "reduce_var_or_dummy" ];
        nude = true (* Don't inherit [VisitorsRuntime.reduce] *);
        polymorphic = false;
      },
    visitors
      {
        name = "mapreduce_typed_lvalue";
        variety = "mapreduce";
        ancestors = [ "mapreduce_var_or_dummy" ];
        nude = true (* Don't inherit [VisitorsRuntime.mapreduce] *);
        polymorphic = false;
      }]

(** A right value: a constant, a place, or an ADT value built from other
    rvalues. Its visitors extend the [typed_lvalue] visitors. *)
type rvalue =
  | RvConcrete of constant_value
  | RvPlace of place
  | RvAdt of adt_rvalue

(** The construction of an ADT value from its fields.
    NOTE(review): [variant_id] is presumably [None] for structures and tuples
    and [Some] for enumeration variants; confirm. *)
and adt_rvalue = {
  variant_id : (VariantId.id option[@opaque]);
  field_values : typed_rvalue list;
}

(** An rvalue together with its type. *)
and typed_rvalue = { value : rvalue; ty : ty }
[@@deriving
  show,
    visitors
      {
        name = "iter_typed_rvalue";
        variety = "iter";
        ancestors = [ "iter_typed_lvalue" ];
        nude = true (* Don't inherit [VisitorsRuntime.iter] *);
        concrete = true;
        polymorphic = false;
      },
    visitors
      {
        name = "map_typed_rvalue";
        variety = "map";
        ancestors = [ "map_typed_lvalue" ];
        nude = true (* Don't inherit [VisitorsRuntime.map] *);
        concrete = true;
        polymorphic = false;
      },
    visitors
      {
        name = "reduce_typed_rvalue";
        variety = "reduce";
        ancestors = [ "reduce_typed_lvalue" ];
        nude = true (* Don't inherit [VisitorsRuntime.reduce] *);
        polymorphic = false;
      },
    visitors
      {
        name = "mapreduce_typed_rvalue";
        variety = "mapreduce";
        ancestors = [ "mapreduce_typed_lvalue" ];
        nude = true (* Don't inherit [VisitorsRuntime.mapreduce] *);
        polymorphic = false;
      }]

(** Unary operators: [Not], and [Neg], which carries an integer type
    (presumably the type of its operand). *)
type unop = Not | Neg of integer_type [@@deriving show, ord]

(** The identifiers of the functions which can be called (see [call]). *)
type fun_id =
  | Regular of A.fun_id * T.RegionGroupId.id option
      (** A CFIM function, with its backward id: `Some` if the function is a
          backward function, `None` if it is a forward function.

          TODO: we need to redefine A.fun_id here, to add `fail` and
          `return` (important to get a unified treatment of the state-error
          monad). For now, when using the state-error monad: extraction
          works only if we unfold all the monadic let-bindings, and we
          then replace the content of the occurrences of `Return` to also
          return the state (which is really super ugly).
       *)
  | Unop of unop
  | Binop of E.binop * integer_type
      (** A CFIM binary operator, with an integer type (presumably the type
          of its operands). *)
[@@deriving show, ord]

(** Meta-information stored in the AST (see the [Meta] expression).

    [Assignment (mp, rv)] pairs a meta-place with an rvalue. NOTE(review):
    presumably it records an assignment of the original program; confirm. *)
type meta = Assignment of mplace * typed_rvalue [@@deriving show]

(** Base class of the [iter_expression] visitor.

    It extends [iter_typed_rvalue] with no-op visits for additional leaves:
    meta-information, integer types, scalar values, variant identifiers and
    function identifiers. *)
class ['self] iter_expression_base =
  object (_ : 'self)
    inherit [_] iter_typed_rvalue

    method visit_fun_id : 'env -> fun_id -> unit = fun _env _fid -> ()

    method visit_id : 'env -> VariantId.id -> unit = fun _env _vid -> ()

    method visit_scalar_value : 'env -> scalar_value -> unit =
      fun _env _sv -> ()

    method visit_integer_type : 'env -> integer_type -> unit =
      fun _env _ity -> ()

    method visit_meta : 'env -> meta -> unit = fun _env _m -> ()
  end

(** Base class of the [map_expression] visitor.

    It extends [map_typed_rvalue]: meta-information, integer types, scalar
    values, variant identifiers and function identifiers are returned
    unchanged. *)
class ['self] map_expression_base =
  object (_ : 'self)
    inherit [_] map_typed_rvalue

    method visit_fun_id : 'env -> fun_id -> fun_id = fun _env fid -> fid

    method visit_id : 'env -> VariantId.id -> VariantId.id =
      fun _env vid -> vid

    method visit_scalar_value : 'env -> scalar_value -> scalar_value =
      fun _env sv -> sv

    method visit_integer_type : 'env -> integer_type -> integer_type =
      fun _env ity -> ity

    method visit_meta : 'env -> meta -> meta = fun _env m -> m
  end

(** Base class of the [reduce_expression] visitor.

    It extends [reduce_typed_rvalue]: meta-information, integer types, scalar
    values, variant identifiers and function identifiers all contribute
    [self#zero]. *)
class virtual ['self] reduce_expression_base =
  object (self : 'self)
    inherit [_] reduce_typed_rvalue

    method visit_fun_id : 'env -> fun_id -> 'a = fun _env _fid -> self#zero

    method visit_id : 'env -> VariantId.id -> 'a = fun _env _vid -> self#zero

    method visit_scalar_value : 'env -> scalar_value -> 'a =
      fun _env _sv -> self#zero

    method visit_integer_type : 'env -> integer_type -> 'a =
      fun _env _ity -> self#zero

    method visit_meta : 'env -> meta -> 'a = fun _env _m -> self#zero
  end

(** Base class of the [mapreduce_expression] visitor.

    It extends [mapreduce_typed_rvalue]: meta-information, integer types,
    scalar values, variant identifiers and function identifiers are returned
    unchanged, paired with [self#zero]. *)
class virtual ['self] mapreduce_expression_base =
  object (self : 'self)
    inherit [_] mapreduce_typed_rvalue

    method visit_fun_id : 'env -> fun_id -> fun_id * 'a =
      fun _env fid -> (fid, self#zero)

    method visit_id : 'env -> VariantId.id -> VariantId.id * 'a =
      fun _env vid -> (vid, self#zero)

    method visit_scalar_value : 'env -> scalar_value -> scalar_value * 'a =
      fun _env sv -> (sv, self#zero)

    method visit_integer_type : 'env -> integer_type -> integer_type * 'a =
      fun _env ity -> (ity, self#zero)

    method visit_meta : 'env -> meta -> meta * 'a =
      fun _env m -> (m, self#zero)
  end

(** **Rk.:** here, [expression] is not at all equivalent to the expressions
    used in CFIM. They are lambda-calculus expressions, and are thus actually
    more general than the CFIM statements, in a sense.

    Sub-expressions are wrapped in a [texpression], which pairs an expression
    with its type.
 *)
type expression =
  | Value of typed_rvalue * mplace option
      (** A value, with optional meta-place information (used to generate
          variable names). *)
  | Call of call
      (** The function calls are still quite structured.
          Change that?... We might want to have a "normal" lambda calculus
          app (with head and argument): this would allow us to replace some
          field accesses with calls to projectors over fields (when there
          are clashes of field names, some provers like F* get pretty bad...)
       *)
  | Let of bool * typed_lvalue * texpression * texpression
      (** Let binding.

          TODO: the boolean should be replaced by an enum: sometimes we use
          the error-monad, sometimes we use the state-error monad (and we
          do this on a per-function basis! For instance, arithmetic functions
          are always in the error monad).

          The boolean controls whether the let is monadic or not.
          For instance, in F*:
          - non-monadic: `let x = ... in ...`
          - monadic:     `x <-- ...; ...`

          NOTE(review): the remaining components are presumably, in order,
          the bound lvalue, the bound expression and the body; confirm.

          Note that we are quite general for the left-value on purpose; this
          is used in several situations:

          1. When deconstructing a tuple:
          ```
          let (x, y) = p in ...
          ```
          (not all languages have syntax like `p.0`, `p.1`... and it is more
          readable anyway).

          2. When expanding an enumeration with one variant.

          In this case, the let-binding has to be understood as:
          ```
          let Cons x tl = ls in
          ...
          ```

          Note that later, depending on the language we extract to, we can
          eventually update it to something like this (for F*, for instance):
          ```
          let x = Cons?.v ls in
          let tl = Cons?.tl ls in
          ...
          ```
       *)
  | Switch of texpression * switch_body
      (** A branching, either an [If] or a [Match]. NOTE(review): the first
          component is presumably the condition or scrutinee; confirm. *)
  | Meta of meta * texpression  (** Meta-information *)

(** A function call: the called function, its type arguments and its
    arguments. *)
and call = {
  func : fun_id;
  type_params : ty list;
  args : texpression list;
      (** Note that immediately after we converted the symbolic AST to a pure AST,
          some functions may have no arguments. For instance:
          ```
          fn f();
          ```
          We later add a unit argument.
       *)
}

(** The branches of a [Switch]: the two branches of a conditional, or a list
    of pattern-matching branches. *)
and switch_body = If of texpression * texpression | Match of match_branch list

(** A match branch: the pattern [pat] and the expression [branch]. *)
and match_branch = { pat : typed_lvalue; branch : texpression }

(** An expression together with its type. *)
and texpression = { e : expression; ty : ty }
[@@deriving
  show,
    visitors
      {
        name = "iter_expression";
        variety = "iter";
        ancestors = [ "iter_expression_base" ];
        nude = true (* Don't inherit [VisitorsRuntime.iter] *);
        concrete = true;
      },
    visitors
      {
        name = "map_expression";
        variety = "map";
        ancestors = [ "map_expression_base" ];
        nude = true (* Don't inherit [VisitorsRuntime.map] *);
        concrete = true;
      },
    visitors
      {
        name = "reduce_expression";
        variety = "reduce";
        ancestors = [ "reduce_expression_base" ];
        nude = true (* Don't inherit [VisitorsRuntime.reduce] *);
      },
    visitors
      {
        name = "mapreduce_expression";
        variety = "mapreduce";
        ancestors = [ "mapreduce_expression_base" ];
        nude = true (* Don't inherit [VisitorsRuntime.mapreduce] *);
      }]

(** A function signature of the pure AST: type parameters, input types and
    output types. *)
type fun_sig = {
  type_params : type_var list;
  inputs : ty list;
  outputs : ty list;
      (** The list of outputs.

          Immediately after the translation from symbolic to pure we have
          the following: in case of a forward function, the list has length 1.
          However, in case of a backward function, the list may have length > 1.
          If the length is > 1, it gets extracted to a tuple type. We could
          thus have used a single (tuple) type instead of a list, but here we
          want to account for the fact that we immediately deconstruct the tuple
          upon calling the backward function (because the backward function is
          called to update a set of values in the environment).

          After the "to monadic" pass, the list has size exactly one (and we
          use the `Result` type).
       *)
}

(** An instantiated function signature: input and output types only.
    NOTE(review): presumably a [fun_sig] whose type parameters have been
    substituted; confirm. *)
type inst_fun_sig = { inputs : ty list; outputs : ty list }

(** A function declaration of the pure AST.

    NOTE(review): as for [Regular] function identifiers, [back_id] is
    presumably [Some] for a backward function and [None] for a forward
    function; confirm. *)
type fun_decl = {
  def_id : FunDeclId.id;
  back_id : T.RegionGroupId.id option;
  basename : fun_name;
      (** The "base" name of the function.

          The base name is the original name of the Rust function. We add suffixes
          (to identify the forward/backward functions) later.
       *)
  signature : fun_sig;
  inputs : var list;
  inputs_lvs : typed_lvalue list;
      (** The inputs seen as lvalues. This allows transformations, for example
          replacing unused variables by `_` *)
  body : texpression;
}