summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2022-01-28 11:06:13 +0100
committerSon Ho2022-01-28 11:06:13 +0100
commitd00dd80b8b752a17c2027d6daccf74974ebf4292 (patch)
tree7d1b345a6d24dc6698c4040d8277f5eb5eea37fb
parent7deb7a2bde6d6bcdf14aac4f68f336bc498b964b (diff)
Simplify the let-bindings in the pure AST
Diffstat (limited to '')
-rw-r--r--src/PrintPure.ml56
-rw-r--r--src/Pure.ml155
-rw-r--r--src/PureMicroPasses.ml41
-rw-r--r--src/SymbolicToPure.ml28
4 files changed, 158 insertions, 122 deletions
diff --git a/src/PrintPure.ml b/src/PrintPure.ml
index 3e68db90..9d28c959 100644
--- a/src/PrintPure.ml
+++ b/src/PrintPure.ml
@@ -353,12 +353,23 @@ let meta_to_string (fmt : ast_formatter) (meta : meta) : string =
in
"@meta[" ^ meta ^ "]"
+let call_to_string (fmt : ast_formatter) (call : call) : string =
+ let val_fmt = ast_to_value_formatter fmt in
+ let ty_fmt = ast_to_type_formatter fmt in
+ let tys = List.map (ty_to_string ty_fmt) call.type_params in
+ let args = List.map (typed_rvalue_to_string fmt) call.args in
+ let all_args = List.append tys args in
+ let fun_id = fun_id_to_string fmt call.func in
+ if all_args = [] then fun_id else fun_id ^ " " ^ String.concat " " all_args
+
let rec expression_to_string (fmt : ast_formatter) (indent : string)
(indent_incr : string) (e : expression) : string =
match e with
- | Return v -> indent ^ "return " ^ typed_rvalue_to_string fmt v
- | Fail -> indent ^ "fail"
- | Let (lb, e) -> let_to_string fmt indent indent_incr lb e
+ | Return v -> "return " ^ typed_rvalue_to_string fmt v
+ | Fail -> "fail"
+ | Value (v, _) -> typed_rvalue_to_string fmt v
+ | Call call -> call_to_string fmt call
+ | Let (lv, re, e) -> let_to_string fmt indent indent_incr lv re e
| Switch (scrutinee, _, body) ->
switch_to_string fmt indent indent_incr scrutinee body
| Meta (meta, e) ->
@@ -367,26 +378,13 @@ let rec expression_to_string (fmt : ast_formatter) (indent : string)
indent ^ meta ^ "\n" ^ e
and let_to_string (fmt : ast_formatter) (indent : string) (indent_incr : string)
- (lb : let_bindings) (e : expression) : string =
- let e = expression_to_string fmt indent indent_incr e in
+ (lv : typed_lvalue) (re : expression) (e : expression) : string =
+ let indent1 = indent ^ indent_incr in
let val_fmt = ast_to_value_formatter fmt in
- match lb with
- | Call (lv, call) ->
- let lv = typed_lvalue_to_string val_fmt lv in
- let ty_fmt = ast_to_type_formatter fmt in
- let tys = List.map (ty_to_string ty_fmt) call.type_params in
- let args = List.map (typed_rvalue_to_string fmt) call.args in
- let all_args = List.append tys args in
- let fun_id = fun_id_to_string fmt call.func in
- let call =
- if all_args = [] then fun_id
- else fun_id ^ " " ^ String.concat " " all_args
- in
- indent ^ "let " ^ lv ^ " = " ^ call ^ " in\n" ^ e
- | Assign (lv, rv, _) ->
- let lv = typed_lvalue_to_string val_fmt lv in
- let rv = typed_rvalue_to_string fmt rv in
- indent ^ "let " ^ lv ^ " = " ^ rv ^ " in\n" ^ e
+ let re = expression_to_string fmt indent1 indent_incr re in
+ let e = expression_to_string fmt indent indent_incr e in
+ let lv = typed_lvalue_to_string val_fmt lv in
+ "let " ^ lv ^ " = " ^ re ^ " in\n" ^ indent ^ e
and switch_to_string (fmt : ast_formatter) (indent : string)
(indent_incr : string) (scrutinee : typed_rvalue) (body : switch_body) :
@@ -397,13 +395,13 @@ and switch_to_string (fmt : ast_formatter) (indent : string)
| If (e_true, e_false) ->
let e_true = expression_to_string fmt indent1 indent_incr e_true in
let e_false = expression_to_string fmt indent1 indent_incr e_false in
- indent ^ "if " ^ scrut ^ "\n" ^ indent ^ "then\n" ^ e_true ^ "\n" ^ indent
- ^ "else\n" ^ e_false
+ "if " ^ scrut ^ "\n" ^ indent ^ "then\n" ^ indent ^ e_true ^ "\n" ^ indent
+ ^ "else\n" ^ indent ^ e_false
| SwitchInt (_, branches, otherwise) ->
let branches =
List.map
(fun (v, be) ->
- indent ^ "| " ^ scalar_value_to_string v ^ " ->\n"
+ indent ^ "| " ^ scalar_value_to_string v ^ " ->\n" ^ indent1
^ expression_to_string fmt indent1 indent_incr be)
branches
in
@@ -412,16 +410,16 @@ and switch_to_string (fmt : ast_formatter) (indent : string)
^ expression_to_string fmt indent1 indent_incr otherwise
in
let all_branches = List.append branches [ otherwise ] in
- indent ^ "switch " ^ scrut ^ " with\n" ^ String.concat "\n" all_branches
+ "switch " ^ scrut ^ " with\n" ^ String.concat "\n" all_branches
| Match branches ->
let val_fmt = ast_to_value_formatter fmt in
let branch_to_string (b : match_branch) : string =
let pat = typed_lvalue_to_string val_fmt b.pat in
- indent ^ "| " ^ pat ^ " ->\n"
+ indent ^ "| " ^ pat ^ " ->\n" ^ indent1
^ expression_to_string fmt indent1 indent_incr b.branch
in
let branches = List.map branch_to_string branches in
- indent ^ "match " ^ scrut ^ " with\n" ^ String.concat "\n" branches
+ "match " ^ scrut ^ " with\n" ^ String.concat "\n" branches
let fun_def_to_string (fmt : ast_formatter) (def : fun_def) : string =
let type_fmt = ast_to_type_formatter fmt in
@@ -432,4 +430,4 @@ let fun_def_to_string (fmt : ast_formatter) (def : fun_def) : string =
if inputs = [] then "" else " fun " ^ String.concat " " inputs ^ " ->\n"
in
let body = expression_to_string fmt " " " " def.body in
- "let " ^ name ^ " :\n " ^ signature ^ " =\n" ^ inputs ^ body
+ "let " ^ name ^ " :\n " ^ signature ^ " =\n" ^ inputs ^ " " ^ body
diff --git a/src/Pure.ml b/src/Pure.ml
index 61d2d130..64851449 100644
--- a/src/Pure.ml
+++ b/src/Pure.ml
@@ -18,6 +18,20 @@ module SynthPhaseId = IdGen ()
module VarId = IdGen ()
(** Pay attention to the fact that we also define a [VarId] module in Values *)
+(* TODO
+ (** The assumed types for the pure AST.
+
+ In comparison with CFIM:
+ - we removed `Box` (because it is translated as the identity: `Box T == T`)
+ - we added `Result`, which is the type used in the error monad. This allows
+ us to have a unified treatment of expressions.
+ *)
+ type assumed_ty = unit
+
+ type type_id = AdtId of TypeDefId.id | Tuple | Assumed of assumed_ty
+ [@@deriving show, ord]
+*)
+
type ty =
| Adt of T.type_id * ty list
(** [Adt] encodes ADTs and tuples and assumed types.
@@ -238,55 +252,6 @@ type fun_id =
| Unop of unop
| Binop of E.binop * T.integer_type
-type call = {
- func : fun_id;
- type_params : ty list;
- args : typed_rvalue list;
- (** Note that at this point, some functions have no arguments. For instance:
- ```
- fn f();
- ```
-
- In the extracted code, we add a unit argument. This is unit argument is
- added later, when going from the "pure" AST to the "extracted" AST.
- *)
- args_mplaces : mplace option list; (** Meta data *)
-}
-
-(* TODO: we might want to merge Call and Assign *)
-type let_bindings =
- | Call of typed_lvalue * call
- (** The called function and the tuple of returned values. *)
- | Assign of typed_lvalue * typed_rvalue * mplace option
- (** Variable assignment: the introduced pattern and the place we read.
-
- We are quite general for the left-value on purpose; this is used
- in several situations:
-
- 1. When deconstructing a tuple:
- ```
- let (x, y) = p in ...
- ```
- (not all languages have syntax like `p.0`, `p.1`... and it is more
- readable anyway).
-
- 2. When expanding an enumeration with one variant.
-
- In this case, [Deconstruct] has to be understood as:
- ```
- let Cons x tl = ls in
- ...
- ```
-
- Note that later, depending on the language we extract to, we can
- eventually update it to something like this (for F*, for instance):
- ```
- let x = Cons?.v ls in
- let tl = Cons?.tl ls in
- ...
- ```
- *)
-
(** Meta-information stored in the AST *)
type meta = Assignment of mplace * typed_rvalue
@@ -295,12 +260,12 @@ class ['self] iter_expression_base =
object (_self : 'self)
inherit [_] VisitorsRuntime.iter
+ method visit_ty : 'env -> ty -> unit = fun _ _ -> ()
+
method visit_typed_rvalue : 'env -> typed_rvalue -> unit = fun _ _ -> ()
method visit_typed_lvalue : 'env -> typed_lvalue -> unit = fun _ _ -> ()
- method visit_let_bindings : 'env -> let_bindings -> unit = fun _ _ -> ()
-
method visit_mplace : 'env -> mplace -> unit = fun _ _ -> ()
method visit_meta : 'env -> meta -> unit = fun _ _ -> ()
@@ -311,7 +276,7 @@ class ['self] iter_expression_base =
method visit_id : 'env -> VariantId.id -> unit = fun _ _ -> ()
- method visit_var_or_dummy : 'env -> var_or_dummy -> unit = fun _ _ -> ()
+ method visit_fun_id : 'env -> fun_id -> unit = fun _ _ -> ()
end
(** Ancestor for [map_expression] map visitor *)
@@ -319,15 +284,14 @@ class ['self] map_expression_base =
object (_self : 'self)
inherit [_] VisitorsRuntime.map
+ method visit_ty : 'env -> ty -> ty = fun _ x -> x
+
method visit_typed_rvalue : 'env -> typed_rvalue -> typed_rvalue =
fun _ x -> x
method visit_typed_lvalue : 'env -> typed_lvalue -> typed_lvalue =
fun _ x -> x
- method visit_let_bindings : 'env -> let_bindings -> let_bindings =
- fun _ x -> x
-
method visit_mplace : 'env -> mplace -> mplace = fun _ x -> x
method visit_meta : 'env -> meta -> meta = fun _ x -> x
@@ -340,10 +304,43 @@ class ['self] map_expression_base =
method visit_id : 'env -> VariantId.id -> VariantId.id = fun _ x -> x
- method visit_var_or_dummy : 'env -> var_or_dummy -> var_or_dummy =
- fun _ x -> x
+ method visit_fun_id : 'env -> fun_id -> fun_id = fun _ x -> x
end
+type call = {
+ func : fun_id;
+ type_params : ty list;
+ args : typed_rvalue list;
+ (** Note that immediately after we converted the symbolic AST to a pure AST,
+ some functions may have no arguments. For instance:
+ ```
+ fn f();
+ ```
+ We later add a unit argument.
+
+ TODO: we should use expressions...
+ *)
+ args_mplaces : mplace option list; (** Meta data *)
+}
+[@@deriving
+ visitors
+ {
+ name = "iter_call";
+ variety = "iter";
+ ancestors = [ "iter_expression_base" ];
+ nude = true (* Don't inherit [VisitorsRuntime.iter] *);
+ concrete = true;
+ },
+ visitors
+ {
+ name = "map_call";
+ variety = "map";
+ ancestors = [ "map_expression_base" ];
+ nude = true (* Don't inherit [VisitorsRuntime.iter] *);
+ concrete = true;
+ }]
+(** "Regular" typed value (we map variables to typed values) *)
+
(** **Rk.:** here, [expression] is not at all equivalent to the expressions
used in CFIM. They are lambda-calculus expressions, and are thus actually
more general than the CFIM statements, in a sense.
@@ -352,12 +349,48 @@ class ['self] map_expression_base =
it is not a "textbook" lambda calculus expression (still quite constrained).
As we want to do transformations on it, through micro-passes, it would be
good to update it and make it more "regular".
+
+ TODO: remove `Return` and `Fail` (they should be "normal" values, I think)
*)
type expression =
| Return of typed_rvalue
| Fail
- | Let of let_bindings * expression
- (** Let bindings include the let-bindings introduced because of function calls *)
+ | Value of typed_rvalue * mplace option
+ | Call of call
+ | Let of typed_lvalue * expression * expression
+ (** Let binding.
+
+ TODO: add a boolean to control whether the let is monadic or not.
+ For instance, in F*:
+ - non-monadic: `let x = ... in ...`
+ - monadic: `x <-- ...; ...`
+
+ Note that we are quite general for the left-value on purpose; this
+ is used in several situations:
+
+ 1. When deconstructing a tuple:
+ ```
+ let (x, y) = p in ...
+ ```
+ (not all languages have syntax like `p.0`, `p.1`... and it is more
+ readable anyway).
+
+ 2. When expanding an enumeration with one variant.
+
+ In this case, [Deconstruct] has to be understood as:
+ ```
+ let Cons x tl = ls in
+ ...
+ ```
+
+ Note that later, depending on the language we extract to, we can
+ eventually update it to something like this (for F*, for instance):
+ ```
+ let x = Cons?.v ls in
+ let tl = Cons?.tl ls in
+ ...
+ ```
+ *)
| Switch of typed_rvalue * mplace option * switch_body
| Meta of meta * expression (** Meta-information *)
@@ -372,7 +405,7 @@ and match_branch = { pat : typed_lvalue; branch : expression }
{
name = "iter_expression";
variety = "iter";
- ancestors = [ "iter_expression_base" ];
+ ancestors = [ "iter_call" ];
nude = true (* Don't inherit [VisitorsRuntime.iter] *);
concrete = true;
},
@@ -380,7 +413,7 @@ and match_branch = { pat : typed_lvalue; branch : expression }
{
name = "map_expression";
variety = "map";
- ancestors = [ "map_expression_base" ];
+ ancestors = [ "map_call" ];
nude = true (* Don't inherit [VisitorsRuntime.iter] *);
concrete = true;
}]
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index 80c35124..985d9ecc 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -159,28 +159,31 @@ let compute_pretty_names (def : fun_def) : fun_def =
pn_ctx * expression =
match e with
| Return _ | Fail -> (ctx, e)
- | Let (lb, e) -> update_let lb e ctx
+ | Value (v, mp) -> update_value v mp ctx
+ | Call call -> update_call call ctx
+ | Let (lb, re, e) -> update_let lb re e ctx
| Switch (scrut, mp, body) -> update_switch_body scrut mp body ctx
| Meta (meta, e) -> update_meta meta e ctx
(* *)
- and update_let (lb : let_bindings) (e : expression) (ctx : pn_ctx) :
+ and update_value (v : typed_rvalue) (mp : mplace option) (ctx : pn_ctx) :
pn_ctx * expression =
- match lb with
- | Call (lv, call) ->
- let ctx =
- add_opt_right_constraint_list ctx
- (List.combine call.args_mplaces call.args)
- in
- let ctx = add_left_constraint lv ctx in
- let ctx, e = update_expression e ctx in
- let lv = update_typed_lvalue ctx lv in
- (ctx, Let (Call (lv, call), e))
- | Assign (lv, rv, rmp) ->
- let ctx = add_left_constraint lv ctx in
- let ctx = add_opt_right_constraint rmp rv ctx in
- let ctx, e = update_expression e ctx in
- let lv = update_typed_lvalue ctx lv in
- (ctx, Let (Assign (lv, rv, rmp), e))
+ let ctx = add_opt_right_constraint mp v ctx in
+ (ctx, Value (v, mp))
+ (* *)
+ and update_call (call : call) (ctx : pn_ctx) : pn_ctx * expression =
+ let ctx =
+ add_opt_right_constraint_list ctx
+ (List.combine call.args_mplaces call.args)
+ in
+ (ctx, Call call)
+ (* *)
+ and update_let (lv : typed_lvalue) (re : expression) (e : expression)
+ (ctx : pn_ctx) : pn_ctx * expression =
+ let ctx = add_left_constraint lv ctx in
+ let ctx, re = update_expression re ctx in
+ let ctx, e = update_expression e ctx in
+ let lv = update_typed_lvalue ctx lv in
+ (ctx, Let (lv, re, e))
(* *)
and update_switch_body (scrut : typed_rvalue) (mp : mplace option)
(body : switch_body) (ctx : pn_ctx) : pn_ctx * expression =
@@ -304,6 +307,8 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_def) : fun_def =
(* TODO: deconstruct the monadic bindings into matches *)
+ (* TODO: add unit arguments for functions with no arguments *)
+
(* We are done *)
def
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index 7fd72926..f4b92dff 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -955,10 +955,11 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
| _ -> failwith "Unreachable")
in
let call = { func; type_params; args; args_mplaces } in
+ let call = Call call in
(* Translate the next expression *)
let e = translate_expression e ctx in
(* Put together *)
- Let (Call (mk_typed_lvalue_from_var dest dest_mplace, call), e)
+ Let (mk_typed_lvalue_from_var dest dest_mplace, call, e)
and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
expression =
@@ -1013,7 +1014,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
(* Generate the assignemnts *)
List.fold_right
(fun (var, value) e ->
- Let (Assign (mk_typed_lvalue_from_var var None, value, None), e))
+ Let (mk_typed_lvalue_from_var var None, Value (value, None), e))
variables_values e
| V.FunCall ->
let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in
@@ -1069,7 +1070,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
(* Put everything together *)
let args_mplaces = List.map (fun _ -> None) inputs in
let call = { func; type_params; args = inputs; args_mplaces } in
- Let (Call (output, call), e)
+ Let (output, Call call, e)
| V.SynthRet ->
(* If we end the abstraction which consumed the return value of the function
* we are synthesizing, we get back the borrows which were inside. Those borrows
@@ -1122,7 +1123,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
(* Generate the assignments *)
List.fold_right
(fun (given_back, input_var) e ->
- Let (Assign (given_back, mk_typed_rvalue_from_var input_var, None), e))
+ Let (given_back, Value (mk_typed_rvalue_from_var input_var, None), e))
given_back_inputs e
and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
@@ -1145,8 +1146,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
let ctx, var = fresh_var_for_symbolic_value nsv ctx in
let e = translate_expression e ctx in
Let
- ( Assign
- (mk_typed_lvalue_from_var var None, scrutinee, scrutinee_mplace),
+ ( mk_typed_lvalue_from_var var None,
+ Value (scrutinee, scrutinee_mplace),
e )
| SeAdt _ ->
(* Should be in the [ExpandAdt] case *)
@@ -1171,7 +1172,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
List.map (fun v -> mk_typed_lvalue_from_var v None) vars
in
let lv = mk_adt_lvalue scrutinee.ty variant_id lvars in
- Let (Assign (lv, scrutinee, scrutinee_mplace), branch)
+ Let (lv, Value (scrutinee, scrutinee_mplace), branch)
else
(* This is not an enumeration: introduce let-bindings for every
* field.
@@ -1192,8 +1193,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
(fun (fid, var) e ->
let field_proj = gen_field_proj fid var in
Let
- ( Assign
- (mk_typed_lvalue_from_var var None, field_proj, None),
+ ( mk_typed_lvalue_from_var var None,
+ Value (field_proj, None),
e ))
id_var_pairs branch
| T.Tuple ->
@@ -1201,7 +1202,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
List.map (fun x -> mk_typed_lvalue_from_var x None) vars
in
Let
- ( Assign (mk_tuple_lvalue vars, scrutinee, scrutinee_mplace),
+ ( mk_tuple_lvalue vars,
+ Value (scrutinee, scrutinee_mplace),
branch )
| T.Assumed T.Box ->
(* There should be exactly one variable *)
@@ -1211,10 +1213,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
(* We simply introduce an assignment - the box type is the
* identity when extracted (`box a == a`) *)
Let
- ( Assign
- ( mk_typed_lvalue_from_var var None,
- scrutinee,
- scrutinee_mplace ),
+ ( mk_typed_lvalue_from_var var None,
+ Value (scrutinee, scrutinee_mplace),
branch ))
| branches ->
let translate_branch (variant_id : T.VariantId.id option)