diff options
author | Son Ho | 2022-01-28 11:06:13 +0100 |
---|---|---|
committer | Son Ho | 2022-01-28 11:06:13 +0100 |
commit | d00dd80b8b752a17c2027d6daccf74974ebf4292 (patch) | |
tree | 7d1b345a6d24dc6698c4040d8277f5eb5eea37fb | |
parent | 7deb7a2bde6d6bcdf14aac4f68f336bc498b964b (diff) |
Simplify the let-bindings in the pure AST
-rw-r--r-- | src/PrintPure.ml | 56 | ||||
-rw-r--r-- | src/Pure.ml | 155 | ||||
-rw-r--r-- | src/PureMicroPasses.ml | 41 | ||||
-rw-r--r-- | src/SymbolicToPure.ml | 28 |
4 files changed, 158 insertions, 122 deletions
diff --git a/src/PrintPure.ml b/src/PrintPure.ml index 3e68db90..9d28c959 100644 --- a/src/PrintPure.ml +++ b/src/PrintPure.ml @@ -353,12 +353,23 @@ let meta_to_string (fmt : ast_formatter) (meta : meta) : string = in "@meta[" ^ meta ^ "]" +let call_to_string (fmt : ast_formatter) (call : call) : string = + let val_fmt = ast_to_value_formatter fmt in + let ty_fmt = ast_to_type_formatter fmt in + let tys = List.map (ty_to_string ty_fmt) call.type_params in + let args = List.map (typed_rvalue_to_string fmt) call.args in + let all_args = List.append tys args in + let fun_id = fun_id_to_string fmt call.func in + if all_args = [] then fun_id else fun_id ^ " " ^ String.concat " " all_args + let rec expression_to_string (fmt : ast_formatter) (indent : string) (indent_incr : string) (e : expression) : string = match e with - | Return v -> indent ^ "return " ^ typed_rvalue_to_string fmt v - | Fail -> indent ^ "fail" - | Let (lb, e) -> let_to_string fmt indent indent_incr lb e + | Return v -> "return " ^ typed_rvalue_to_string fmt v + | Fail -> "fail" + | Value (v, _) -> typed_rvalue_to_string fmt v + | Call call -> call_to_string fmt call + | Let (lv, re, e) -> let_to_string fmt indent indent_incr lv re e | Switch (scrutinee, _, body) -> switch_to_string fmt indent indent_incr scrutinee body | Meta (meta, e) -> @@ -367,26 +378,13 @@ let rec expression_to_string (fmt : ast_formatter) (indent : string) indent ^ meta ^ "\n" ^ e and let_to_string (fmt : ast_formatter) (indent : string) (indent_incr : string) - (lb : let_bindings) (e : expression) : string = - let e = expression_to_string fmt indent indent_incr e in + (lv : typed_lvalue) (re : expression) (e : expression) : string = + let indent1 = indent ^ indent_incr in let val_fmt = ast_to_value_formatter fmt in - match lb with - | Call (lv, call) -> - let lv = typed_lvalue_to_string val_fmt lv in - let ty_fmt = ast_to_type_formatter fmt in - let tys = List.map (ty_to_string ty_fmt) call.type_params in - let args = List.map (typed_rvalue_to_string fmt) call.args in - let all_args = List.append tys args in - let fun_id = fun_id_to_string fmt call.func in - let call = - if all_args = [] then fun_id - else fun_id ^ " " ^ String.concat " " all_args - in - indent ^ "let " ^ lv ^ " = " ^ call ^ " in\n" ^ e - | Assign (lv, rv, _) -> - let lv = typed_lvalue_to_string val_fmt lv in - let rv = typed_rvalue_to_string fmt rv in - indent ^ "let " ^ lv ^ " = " ^ rv ^ " in\n" ^ e + let re = expression_to_string fmt indent1 indent_incr re in + let e = expression_to_string fmt indent indent_incr e in + let lv = typed_lvalue_to_string val_fmt lv in + "let " ^ lv ^ " = " ^ re ^ " in\n" ^ indent ^ e and switch_to_string (fmt : ast_formatter) (indent : string) (indent_incr : string) (scrutinee : typed_rvalue) (body : switch_body) : @@ -397,13 +395,13 @@ and switch_to_string (fmt : ast_formatter) (indent : string) | If (e_true, e_false) -> let e_true = expression_to_string fmt indent1 indent_incr e_true in let e_false = expression_to_string fmt indent1 indent_incr e_false in - indent ^ "if " ^ scrut ^ "\n" ^ indent ^ "then\n" ^ e_true ^ "\n" ^ indent - ^ "else\n" ^ e_false + "if " ^ scrut ^ "\n" ^ indent ^ "then\n" ^ indent ^ e_true ^ "\n" ^ indent + ^ "else\n" ^ indent ^ e_false | SwitchInt (_, branches, otherwise) -> let branches = List.map (fun (v, be) -> - indent ^ "| " ^ scalar_value_to_string v ^ " ->\n" + indent ^ "| " ^ scalar_value_to_string v ^ " ->\n" ^ indent1 ^ expression_to_string fmt indent1 indent_incr be) branches in @@ -412,16 +410,16 @@ and switch_to_string (fmt : ast_formatter) (indent : string) ^ expression_to_string fmt indent1 indent_incr otherwise in let all_branches = List.append branches [ otherwise ] in - indent ^ "switch " ^ scrut ^ " with\n" ^ String.concat "\n" all_branches + "switch " ^ scrut ^ " with\n" ^ String.concat "\n" all_branches | Match branches -> let val_fmt = ast_to_value_formatter fmt in let branch_to_string (b : match_branch) : string = let pat = typed_lvalue_to_string val_fmt b.pat in - indent ^ "| " ^ pat ^ " ->\n" + indent ^ "| " ^ pat ^ " ->\n" ^ indent1 ^ expression_to_string fmt indent1 indent_incr b.branch in let branches = List.map branch_to_string branches in - indent ^ "match " ^ scrut ^ " with\n" ^ String.concat "\n" branches + "match " ^ scrut ^ " with\n" ^ String.concat "\n" branches let fun_def_to_string (fmt : ast_formatter) (def : fun_def) : string = let type_fmt = ast_to_type_formatter fmt in @@ -432,4 +430,4 @@ let fun_def_to_string (fmt : ast_formatter) (def : fun_def) : string = if inputs = [] then "" else " fun " ^ String.concat " " inputs ^ " ->\n" in let body = expression_to_string fmt " " " " def.body in - "let " ^ name ^ " :\n " ^ signature ^ " =\n" ^ inputs ^ body + "let " ^ name ^ " :\n " ^ signature ^ " =\n" ^ inputs ^ " " ^ body diff --git a/src/Pure.ml b/src/Pure.ml index 61d2d130..64851449 100644 --- a/src/Pure.ml +++ b/src/Pure.ml @@ -18,6 +18,20 @@ module SynthPhaseId = IdGen () module VarId = IdGen () (** Pay attention to the fact that we also define a [VarId] module in Values *) +(* TODO + (** The assumed types for the pure AST. + + In comparison with CFIM: + - we removed `Box` (because it is translated as the identity: `Box T == T`) + - we added `Result`, which is the type used in the error monad. This allows + us to have a unified treatment of expressions. + *) + type assumed_ty = unit + + type type_id = AdtId of TypeDefId.id | Tuple | Assumed of assumed_ty + [@@deriving show, ord] +*) + type ty = | Adt of T.type_id * ty list (** [Adt] encodes ADTs and tuples and assumed types. @@ -238,55 +252,6 @@ type fun_id = | Unop of unop | Binop of E.binop * T.integer_type -type call = { - func : fun_id; - type_params : ty list; - args : typed_rvalue list; - (** Note that at this point, some functions have no arguments. For instance: - ``` - fn f(); - ``` - - In the extracted code, we add a unit argument. This is unit argument is - added later, when going from the "pure" AST to the "extracted" AST. - *) - args_mplaces : mplace option list; (** Meta data *) -} - -(* TODO: we might want to merge Call and Assign *) -type let_bindings = - | Call of typed_lvalue * call - (** The called function and the tuple of returned values. *) - | Assign of typed_lvalue * typed_rvalue * mplace option - (** Variable assignment: the introduced pattern and the place we read. - - We are quite general for the left-value on purpose; this is used - in several situations: - - 1. When deconstructing a tuple: - ``` - let (x, y) = p in ... - ``` - (not all languages have syntax like `p.0`, `p.1`... and it is more - readable anyway). - - 2. When expanding an enumeration with one variant. - - In this case, [Deconstruct] has to be understood as: - ``` - let Cons x tl = ls in - ... - ``` - - Note that later, depending on the language we extract to, we can - eventually update it to something like this (for F*, for instance): - ``` - let x = Cons?.v ls in - let tl = Cons?.tl ls in - ... - ``` - *) - (** Meta-information stored in the AST *) type meta = Assignment of mplace * typed_rvalue @@ -295,12 +260,12 @@ class ['self] iter_expression_base = object (_self : 'self) inherit [_] VisitorsRuntime.iter + method visit_ty : 'env -> ty -> unit = fun _ _ -> () + method visit_typed_rvalue : 'env -> typed_rvalue -> unit = fun _ _ -> () method visit_typed_lvalue : 'env -> typed_lvalue -> unit = fun _ _ -> () - method visit_let_bindings : 'env -> let_bindings -> unit = fun _ _ -> () - method visit_mplace : 'env -> mplace -> unit = fun _ _ -> () method visit_meta : 'env -> meta -> unit = fun _ _ -> () @@ -311,7 +276,7 @@ class ['self] iter_expression_base = method visit_id : 'env -> VariantId.id -> unit = fun _ _ -> () - method visit_var_or_dummy : 'env -> var_or_dummy -> unit = fun _ _ -> () + method visit_fun_id : 'env -> fun_id -> unit = fun _ _ -> () end (** Ancestor for [map_expression] map visitor *) @@ -319,15 +284,14 @@ class ['self] map_expression_base = object (_self : 'self) inherit [_] VisitorsRuntime.map + method visit_ty : 'env -> ty -> ty = fun _ x -> x + method visit_typed_rvalue : 'env -> typed_rvalue -> typed_rvalue = fun _ x -> x method visit_typed_lvalue : 'env -> typed_lvalue -> typed_lvalue = fun _ x -> x - method visit_let_bindings : 'env -> let_bindings -> let_bindings = - fun _ x -> x - method visit_mplace : 'env -> mplace -> mplace = fun _ x -> x method visit_meta : 'env -> meta -> meta = fun _ x -> x @@ -340,10 +304,43 @@ class ['self] map_expression_base = method visit_id : 'env -> VariantId.id -> VariantId.id = fun _ x -> x - method visit_var_or_dummy : 'env -> var_or_dummy -> var_or_dummy = - fun _ x -> x + method visit_fun_id : 'env -> fun_id -> fun_id = fun _ x -> x end +type call = { + func : fun_id; + type_params : ty list; + args : typed_rvalue list; + (** Note that immediately after we converted the symbolic AST to a pure AST, + some functions may have no arguments. For instance: + ``` + fn f(); + ``` + We later add a unit argument. + + TODO: we should use expressions... + *) + args_mplaces : mplace option list; (** Meta data *) +} +[@@deriving + visitors + { + name = "iter_call"; + variety = "iter"; + ancestors = [ "iter_expression_base" ]; + nude = true (* Don't inherit [VisitorsRuntime.iter] *); + concrete = true; + }, + visitors + { + name = "map_call"; + variety = "map"; + ancestors = [ "map_expression_base" ]; + nude = true (* Don't inherit [VisitorsRuntime.iter] *); + concrete = true; + }] +(** "Regular" typed value (we map variables to typed values) *) + (** **Rk.:** here, [expression] is not at all equivalent to the expressions used in CFIM. They are lambda-calculus expressions, and are thus actually more general than the CFIM statements, in a sense. @@ -352,12 +349,48 @@ class ['self] map_expression_base = it is not a "textbook" lambda calculus expression (still quite constrained). As we want to do transformations on it, through micro-passes, it would be good to update it and make it more "regular". + + TODO: remove `Return` and `Fail` (they should be "normal" values, I think) *) type expression = | Return of typed_rvalue | Fail - | Let of let_bindings * expression - (** Let bindings include the let-bindings introduced because of function calls *) + | Value of typed_rvalue * mplace option + | Call of call + | Let of typed_lvalue * expression * expression + (** Let binding. + + TODO: add a boolean to control whether the let is monadic or not. + For instance, in F*: + - non-monadic: `let x = ... in ...` + - monadic: `x <-- ...; ...` + + Note that we are quite general for the left-value on purpose; this + is used in several situations: + + 1. When deconstructing a tuple: + ``` + let (x, y) = p in ... + ``` + (not all languages have syntax like `p.0`, `p.1`... and it is more + readable anyway). + + 2. When expanding an enumeration with one variant. + + In this case, [Deconstruct] has to be understood as: + ``` + let Cons x tl = ls in + ... + ``` + + Note that later, depending on the language we extract to, we can + eventually update it to something like this (for F*, for instance): + ``` + let x = Cons?.v ls in + let tl = Cons?.tl ls in + ... + ``` + *) | Switch of typed_rvalue * mplace option * switch_body | Meta of meta * expression (** Meta-information *) @@ -372,7 +405,7 @@ and match_branch = { pat : typed_lvalue; branch : expression } { name = "iter_expression"; variety = "iter"; - ancestors = [ "iter_expression_base" ]; + ancestors = [ "iter_call" ]; nude = true (* Don't inherit [VisitorsRuntime.iter] *); concrete = true; }, @@ -380,7 +413,7 @@ and match_branch = { pat : typed_lvalue; branch : expression } { name = "map_expression"; variety = "map"; - ancestors = [ "map_expression_base" ]; + ancestors = [ "map_call" ]; nude = true (* Don't inherit [VisitorsRuntime.iter] *); concrete = true; }] diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index 80c35124..985d9ecc 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -159,28 +159,31 @@ let compute_pretty_names (def : fun_def) : fun_def = pn_ctx * expression = match e with | Return _ | Fail -> (ctx, e) - | Let (lb, e) -> update_let lb e ctx + | Value (v, mp) -> update_value v mp ctx + | Call call -> update_call call ctx + | Let (lb, re, e) -> update_let lb re e ctx | Switch (scrut, mp, body) -> update_switch_body scrut mp body ctx | Meta (meta, e) -> update_meta meta e ctx (* *) - and update_let (lb : let_bindings) (e : expression) (ctx : pn_ctx) : + and update_value (v : typed_rvalue) (mp : mplace option) (ctx : pn_ctx) : pn_ctx * expression = - match lb with - | Call (lv, call) -> - let ctx = - add_opt_right_constraint_list ctx - (List.combine call.args_mplaces call.args) - in - let ctx = add_left_constraint lv ctx in - let ctx, e = update_expression e ctx in - let lv = update_typed_lvalue ctx lv in - (ctx, Let (Call (lv, call), e)) - | Assign (lv, rv, rmp) -> - let ctx = add_left_constraint lv ctx in - let ctx = add_opt_right_constraint rmp rv ctx in - let ctx, e = update_expression e ctx in - let lv = update_typed_lvalue ctx lv in - (ctx, Let (Assign (lv, rv, rmp), e)) + let ctx = add_opt_right_constraint mp v ctx in + (ctx, Value (v, mp)) + (* *) + and update_call (call : call) (ctx : pn_ctx) : pn_ctx * expression = + let ctx = + add_opt_right_constraint_list ctx + (List.combine call.args_mplaces call.args) + in + (ctx, Call call) + (* *) + and update_let (lv : typed_lvalue) (re : expression) (e : expression) + (ctx : pn_ctx) : pn_ctx * expression = + let ctx = add_left_constraint lv ctx in + let ctx, re = update_expression re ctx in + let ctx, e = update_expression e ctx in + let lv = update_typed_lvalue ctx lv in + (ctx, Let (lv, re, e)) (* *) and update_switch_body (scrut : typed_rvalue) (mp : mplace option) (body : switch_body) (ctx : pn_ctx) : pn_ctx * expression = @@ -304,6 +307,8 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_def) : fun_def = (* TODO: deconstruct the monadic bindings into matches *) + (* TODO: add unit arguments for functions with no arguments *) + (* We are done *) def diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index 7fd72926..f4b92dff 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -955,10 +955,11 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : | _ -> failwith "Unreachable") in let call = { func; type_params; args; args_mplaces } in + let call = Call call in (* Translate the next expression *) let e = translate_expression e ctx in (* Put together *) - Let (Call (mk_typed_lvalue_from_var dest dest_mplace, call), e) + Let (mk_typed_lvalue_from_var dest dest_mplace, call, e) and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : expression = @@ -1013,7 +1014,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : (* Generate the assignemnts *) List.fold_right (fun (var, value) e -> - Let (Assign (mk_typed_lvalue_from_var var None, value, None), e)) + Let (mk_typed_lvalue_from_var var None, Value (value, None), e)) variables_values e | V.FunCall -> let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in @@ -1069,7 +1070,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : (* Put everything together *) let args_mplaces = List.map (fun _ -> None) inputs in let call = { func; type_params; args = inputs; args_mplaces } in - Let (Call (output, call), e) + Let (output, Call call, e) | V.SynthRet -> (* If we end the abstraction which consumed the return value of the function * we are synthesizing, we get back the borrows which were inside. Those borrows @@ -1122,7 +1123,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : (* Generate the assignments *) List.fold_right (fun (given_back, input_var) e -> - Let (Assign (given_back, mk_typed_rvalue_from_var input_var, None), e)) + Let (given_back, Value (mk_typed_rvalue_from_var input_var, None), e)) given_back_inputs e and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) @@ -1145,8 +1146,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) let ctx, var = fresh_var_for_symbolic_value nsv ctx in let e = translate_expression e ctx in Let - ( Assign - (mk_typed_lvalue_from_var var None, scrutinee, scrutinee_mplace), + ( mk_typed_lvalue_from_var var None, + Value (scrutinee, scrutinee_mplace), e ) | SeAdt _ -> (* Should be in the [ExpandAdt] case *) @@ -1171,7 +1172,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) List.map (fun v -> mk_typed_lvalue_from_var v None) vars in let lv = mk_adt_lvalue scrutinee.ty variant_id lvars in - Let (Assign (lv, scrutinee, scrutinee_mplace), branch) + Let (lv, Value (scrutinee, scrutinee_mplace), branch) else (* This is not an enumeration: introduce let-bindings for every * field. @@ -1192,8 +1193,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) (fun (fid, var) e -> let field_proj = gen_field_proj fid var in Let - ( Assign - (mk_typed_lvalue_from_var var None, field_proj, None), + ( mk_typed_lvalue_from_var var None, + Value (field_proj, None), e )) id_var_pairs branch | T.Tuple -> @@ -1201,7 +1202,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) List.map (fun x -> mk_typed_lvalue_from_var x None) vars in Let - ( Assign (mk_tuple_lvalue vars, scrutinee, scrutinee_mplace), + ( mk_tuple_lvalue vars, + Value (scrutinee, scrutinee_mplace), branch ) | T.Assumed T.Box -> (* There should be exactly one variable *) @@ -1211,10 +1213,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) (* We simply introduce an assignment - the box type is the * identity when extracted (`box a == a`) *) Let - ( Assign - ( mk_typed_lvalue_from_var var None, - scrutinee, - scrutinee_mplace ), + ( mk_typed_lvalue_from_var var None, + Value (scrutinee, scrutinee_mplace), branch )) | branches -> let translate_branch (variant_id : T.VariantId.id option) |