diff options
-rw-r--r-- | src/ExtractToFStar.ml | 2 | ||||
-rw-r--r-- | src/LlbcOfJson.ml | 2 | ||||
-rw-r--r-- | src/Print.ml | 2 | ||||
-rw-r--r-- | src/PrintPure.ml | 96 | ||||
-rw-r--r-- | src/Pure.ml | 54 | ||||
-rw-r--r-- | src/PureMicroPasses.ml | 101 | ||||
-rw-r--r-- | src/PureToExtract.ml | 4 | ||||
-rw-r--r-- | src/PureUtils.ml | 60 | ||||
-rw-r--r-- | src/SymbolicAst.ml | 2 | ||||
-rw-r--r-- | src/SymbolicToPure.ml | 19 |
10 files changed, 218 insertions, 124 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index 9d96d058..4ef40d8b 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -899,7 +899,7 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) let _ = extract_typed_rvalue ctx fmt inside rv in if not inner then F.pp_close_box fmt (); () - | Call call -> ( + | App call -> ( match (call.func, call.args) with | Unop unop, [ arg ] -> ctx.fmt.extract_unop diff --git a/src/LlbcOfJson.ml b/src/LlbcOfJson.ml index 19ffc279..7604ec2b 100644 --- a/src/LlbcOfJson.ml +++ b/src/LlbcOfJson.ml @@ -1,4 +1,4 @@ -(** Functions to load CFIM ASTs from json. +(** Functions to load LLBC ASTs from json. Initially, we used `ppx_derive_yojson` to automate this. However, `ppx_derive_yojson` expects formatting to be slightly diff --git a/src/Print.ml b/src/Print.ml index 98876acb..841fa9b2 100644 --- a/src/Print.ml +++ b/src/Print.ml @@ -1198,7 +1198,7 @@ module Module = struct String.concat "\n\n" all_defs end -(** Pretty-printing for CFIM ASTs (functions based on an evaluation context) *) +(** Pretty-printing for LLBC ASTs (functions based on an evaluation context) *) module EvalCtxLlbcAst = struct let ety_to_string (ctx : C.eval_ctx) (t : T.ety) : string = let fmt = PC.eval_ctx_to_ctx_formatter ctx in diff --git a/src/PrintPure.ml b/src/PrintPure.ml index 8344ee41..158d4c3c 100644 --- a/src/PrintPure.ml +++ b/src/PrintPure.ml @@ -1,6 +1,7 @@ (** This module defines printing functions for the types defined in Pure.ml *) open Pure +open PureUtils module T = Types module V = Values module E = Expressions @@ -425,50 +426,86 @@ let meta_to_string (fmt : ast_formatter) (meta : meta) : string = in "@meta[" ^ meta ^ "]" -let rec expression_to_string (fmt : ast_formatter) (indent : string) - (indent_incr : string) (e : expression) : string = - match e with +let rec texpression_to_string (fmt : ast_formatter) (inner : bool) + (indent : string) (indent_incr : string) (e : texpression) : string = + match e.e with | Value (v, mp) -> let mp = match mp with | None -> "" | Some mp -> " [@mplace=" ^ mplace_to_string fmt mp ^ "]" in - "(" ^ typed_rvalue_to_string fmt v ^ mp ^ ")" - | Call call -> call_to_string fmt indent indent_incr call + let e = typed_rvalue_to_string fmt v ^ mp in + if inner then "(" ^ e ^ ")" else e + | App _ -> + (* Recursively destruct the app, to have a pair (app, arguments list) *) + let app, args = destruct_apps e in + (* Convert to string *) + app_to_string fmt inner indent indent_incr app args + | Func _ -> + (* Func without arguments *) + app_to_string fmt inner indent indent_incr e [] | Let (monadic, lv, re, e) -> - let_to_string fmt indent indent_incr monadic lv re e + let e = let_to_string fmt indent indent_incr monadic lv re e in + if inner then "(" ^ e ^ ")" else e | Switch (scrutinee, body) -> - switch_to_string fmt indent indent_incr scrutinee body + let e = switch_to_string fmt indent indent_incr scrutinee body in + if inner then "(" ^ e ^ ")" else e | Meta (meta, e) -> let meta = meta_to_string fmt meta in - let e = texpression_to_string fmt indent indent_incr e in - meta ^ "\n" ^ indent ^ e + let e = texpression_to_string fmt inner indent indent_incr e in + let e = meta ^ "\n" ^ indent ^ e in + if inner then "(" ^ e ^ ")" else e -and texpression_to_string (fmt : ast_formatter) (indent : string) +(*and texpression_to_string (fmt : ast_formatter) (inner : bool) (indent : string) (indent_incr : string) (e : texpression) : string = - expression_to_string fmt indent indent_incr e.e + expression_to_string fmt inner indent indent_incr inner e.e*) -and call_to_string (fmt : ast_formatter) (indent : string) - (indent_incr : string) (call : call) : string = - let ty_fmt = ast_to_type_formatter fmt in - let tys = List.map (ty_to_string ty_fmt) call.type_params in - (* The arguments are expressions, so indentation might get weird... (though +and app_to_string (fmt : ast_formatter) (inner : bool) (indent : string) + (indent_incr : string) (app : texpression) (args : texpression list) : + string = + (* There are two possibilities: either the `app` is an instantiated, + * top-level function, or it is a "regular" expression *) + let app, tys = + match app.e with + | Func func -> + (* Function case *) + (* Convert the function identifier *) + let fun_id = fun_id_to_string fmt func.func in + (* Convert the type instantiation *) + let ty_fmt = ast_to_type_formatter fmt in + let tys = List.map (ty_to_string ty_fmt) func.type_params in + (* *) + (fun_id, tys) + | _ -> + (* "Regular" expression case *) + let inner = args <> [] || (args = [] && inner) in + (texpression_to_string fmt inner indent indent_incr app, []) + in + (* Convert the arguments. + * The arguments are expressions, so indentation might get weird... (though * those expressions will in most cases just be values) *) - let indent1 = indent ^ indent_incr in - let args = - List.map (texpression_to_string fmt indent1 indent_incr) call.args + let arg_to_string = + let inner = true in + let indent1 = indent ^ indent_incr in + texpression_to_string fmt inner indent1 indent_incr in + let args = List.map arg_to_string args in let all_args = List.append tys args in - let fun_id = fun_id_to_string fmt call.func in - if all_args = [] then fun_id else fun_id ^ " " ^ String.concat " " all_args + (* Put together *) + let e = + if all_args = [] then app else app ^ " " ^ String.concat " " all_args + in + (* Add parentheses *) + if all_args <> [] && inner then "(" ^ e ^ ")" else e and let_to_string (fmt : ast_formatter) (indent : string) (indent_incr : string) (monadic : bool) (lv : typed_lvalue) (re : texpression) (e : texpression) : string = let indent1 = indent ^ indent_incr in - let re = texpression_to_string fmt indent1 indent_incr re in - let e = texpression_to_string fmt indent indent_incr e in + let inner = false in + let re = texpression_to_string fmt inner indent1 indent_incr re in + let e = texpression_to_string fmt inner indent indent_incr e in let lv = typed_lvalue_to_string fmt lv in if monadic then lv ^ " <-- " ^ re ^ ";\n" ^ indent ^ e else "let " ^ lv ^ " = " ^ re ^ " in\n" ^ indent ^ e @@ -480,18 +517,18 @@ and switch_to_string (fmt : ast_formatter) (indent : string) (* Printing can mess up on the scrutinee, because it is an expression - but * in most situations it will be a value or a function call, so it should be * ok*) - let scrut = texpression_to_string fmt indent1 indent_incr scrutinee in + let scrut = texpression_to_string fmt true indent1 indent_incr scrutinee in + let e_to_string = texpression_to_string fmt false indent1 indent_incr in match body with | If (e_true, e_false) -> - let e_true = texpression_to_string fmt indent1 indent_incr e_true in - let e_false = texpression_to_string fmt indent1 indent_incr e_false in + let e_true = e_to_string e_true in + let e_false = e_to_string e_false in "if " ^ scrut ^ "\n" ^ indent ^ "then\n" ^ indent1 ^ e_true ^ "\n" ^ indent ^ "else\n" ^ indent1 ^ e_false | Match branches -> let branch_to_string (b : match_branch) : string = let pat = typed_lvalue_to_string fmt b.pat in - indent ^ "| " ^ pat ^ " ->\n" ^ indent1 - ^ texpression_to_string fmt indent1 indent_incr b.branch + indent ^ "| " ^ pat ^ " ->\n" ^ indent1 ^ e_to_string b.branch in let branches = List.map branch_to_string branches in "match " ^ scrut ^ " with\n" ^ String.concat "\n" branches @@ -503,11 +540,12 @@ let fun_decl_to_string (fmt : ast_formatter) (def : fun_decl) : string = match def.body with | None -> "val " ^ name ^ " :\n " ^ signature | Some body -> + let inner = false in let indent = " " in let inputs = List.map (var_to_string type_fmt) body.inputs in let inputs = if inputs = [] then indent else " fun " ^ String.concat " " inputs ^ " ->\n" ^ indent in - let body = texpression_to_string fmt indent indent body.body in + let body = texpression_to_string fmt inner indent indent body.body in "let " ^ name ^ " :\n " ^ signature ^ " =\n" ^ inputs ^ body diff --git a/src/Pure.ml b/src/Pure.ml index e7721b41..ebc92258 100644 --- a/src/Pure.ml +++ b/src/Pure.ml @@ -23,7 +23,7 @@ type integer_type = T.integer_type [@@deriving show, ord] (** The assumed types for the pure AST. - In comparison with CFIM: + In comparison with LLBC: - we removed `Box` (because it is translated as the identity: `Box T == T`) - we added: - `Result`: the type used in the error monad. This allows us to have a @@ -168,6 +168,7 @@ type mplace = { we introduce. *) +(* TODO: there shouldn't be places *) type place = { var : VarId.id; projection : projection } [@@deriving show] (** Ancestor for [iter_var_or_dummy] visitor *) @@ -239,13 +240,16 @@ class virtual ['self] mapreduce_value_base = method visit_ty : 'env -> ty -> ty * 'a = fun _ x -> (x, self#zero) end +(* TODO: merge with expressions *) type rvalue = | RvConcrete of constant_value - | RvPlace of place + | RvPlace of place (* TODO: field projectors should be expressions *) | RvAdt of adt_rvalue and adt_rvalue = { variant_id : (VariantId.id option[@opaque]); + (* TODO: variant constructors should be expressions, treated in a manner + * similar to functions *) field_values : typed_rvalue list; } @@ -332,7 +336,10 @@ type var_or_dummy = polymorphic = false; }] -(** A left value (which appears on the left of assignments *) +(** A left value (which appears on the left of assignments. + + TODO: rename to "pattern" + *) type lvalue = | LvConcrete of constant_value (** [LvConcrete] is necessary because we merge the switches over integer @@ -405,6 +412,14 @@ type fun_id = type meta = Assignment of mplace * typed_rvalue * mplace option [@@deriving show] +type func = { func : fun_id; type_params : ty list } [@@deriving show] +(** A function. + + Note that for now we have a clear separation between types and expressions, + which explains why we have the `type_params` field: a function is always + fully instantiated. + *) + (** Ancestor for [iter_expression] visitor *) class ['self] iter_expression_base = object (_self : 'self) @@ -416,7 +431,7 @@ class ['self] iter_expression_base = method visit_scalar_value : 'env -> scalar_value -> unit = fun _ _ -> () - method visit_fun_id : 'env -> fun_id -> unit = fun _ _ -> () + method visit_func : 'env -> func -> unit = fun _ _ -> () end (** Ancestor for [map_expression] visitor *) @@ -432,7 +447,7 @@ class ['self] map_expression_base = method visit_scalar_value : 'env -> scalar_value -> scalar_value = fun _ x -> x - method visit_fun_id : 'env -> fun_id -> fun_id = fun _ x -> x + method visit_func : 'env -> func -> func = fun _ x -> x end (** Ancestor for [reduce_expression] visitor *) @@ -448,7 +463,7 @@ class virtual ['self] reduce_expression_base = method visit_scalar_value : 'env -> scalar_value -> 'a = fun _ _ -> self#zero - method visit_fun_id : 'env -> fun_id -> 'a = fun _ _ -> self#zero + method visit_func : 'env -> func -> 'a = fun _ _ -> self#zero end (** Ancestor for [mapreduce_expression] visitor *) @@ -464,23 +479,25 @@ class virtual ['self] mapreduce_expression_base = method visit_scalar_value : 'env -> scalar_value -> scalar_value * 'a = fun _ x -> (x, self#zero) - method visit_fun_id : 'env -> fun_id -> fun_id * 'a = - fun _ x -> (x, self#zero) + method visit_func : 'env -> func -> func * 'a = fun _ x -> (x, self#zero) end (** **Rk.:** here, [expression] is not at all equivalent to the expressions - used in CFIM. They are lambda-calculus expressions, and are thus actually - more general than the CFIM statements, in a sense. + used in LLBC. They are lambda-calculus expressions, and are thus actually + more general than the LLBC statements, in a sense. *) type expression = | Value of typed_rvalue * mplace option - | Call of call - (** The function calls are still quite structured. + | App of texpression * texpression + (** Application of a function to an argument. + + The function calls are still quite structured. Change that?... We might want to have a "normal" lambda calculus app (with head and argument): this would allow us to replace some field accesses with calls to projectors over fields (when there are clashes of field names, some provers like F* get pretty bad...) *) + | Func of func (** A function - TODO: change to Qualifier *) | Let of bool * typed_lvalue * texpression * texpression (** Let binding. @@ -523,19 +540,6 @@ type expression = | Switch of texpression * switch_body | Meta of meta * texpression (** Meta-information *) -and call = { - func : fun_id; - type_params : ty list; - args : texpression list; - (** Note that immediately after we converted the symbolic AST to a pure AST, - some functions may have no arguments. For instance: - ``` - fn f(); - ``` - We later add a unit argument. - *) -} - and switch_body = If of texpression * texpression | Match of match_branch list and match_branch = { pat : typed_lvalue; branch : texpression } diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index bc4cdc3c..198a4d89 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -69,8 +69,6 @@ type config = { borrows as inputs, it can't return mutable borrows; we actually dynamically check for that). *) - add_unit_args : bool; - (** Add unit input arguments to functions with no arguments. *) } (** A configuration to control the application of the passes *) @@ -444,7 +442,12 @@ let compute_pretty_names (def : fun_decl) : fun_decl = let ctx, e = match e.e with | Value (v, mp) -> update_value v mp ctx - | Call call -> update_call call ctx + | App (app, arg) -> + let ctx, app = update_texpression app ctx in + let ctx, arg = update_texpression app ctx in + let e = App (app, arg) in + (ctx, e) + | Func _ -> (* nothing to do *) (ctx, e.e) | Let (monadic, lb, re, e) -> update_let monadic lb re e ctx | Switch (scrut, body) -> update_switch_body scrut body ctx | Meta (meta, e) -> update_meta meta e ctx @@ -456,15 +459,6 @@ let compute_pretty_names (def : fun_decl) : fun_decl = let ctx = add_opt_right_constraint mp v ctx in (ctx, Value (v, mp)) (* *) - and update_call (call : call) (ctx : pn_ctx) : pn_ctx * expression = - let ctx, args = - List.fold_left_map - (fun ctx arg -> update_texpression arg ctx) - ctx call.args - in - let call = { call with args } in - (ctx, Call call) - (* *) and update_let (monadic : bool) (lv : typed_lvalue) (re : texpression) (e : texpression) (ctx : pn_ctx) : pn_ctx * expression = (* We first add the left-constraint *) @@ -597,13 +591,13 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) inherit [_] map_expression as super method! visit_Let env monadic lv re e = - (* Check that: + (* In order to filter, we need to check first that: * - the let-binding is not monadic * - the left-value is a variable *) match (monadic, lv.value) with | false, LvVar (Var (lv_var, _)) -> - (* Check that: *) + (* We can filter if: *) let filter = false in (* 1. Either: * - the left variable is unnamed or [inline_named] is true @@ -622,10 +616,14 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) if inline_pure then match re.e with | Value _ -> true - | Call call -> ( - match call.func with - | Regular _ -> false - | Unop _ | Binop _ -> true) + | App _ -> ( + (* Application: decompose, and check that function call *) + match opt_destruct_function_call re with + | Some (func, _) -> ( + match func.func with + | Regular _ -> false + | Unop _ | Binop _ -> true) + | _ -> false) | _ -> filter else false in @@ -716,11 +714,11 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) In this situation, we can remove the call `f x`. *) -let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (call0 : call) - (e : texpression) : bool = - let check_call call1 : bool = +let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (func0 : func) + (args0 : texpression list) (e : texpression) : bool = + let check_call (func1 : func) (args1 : texpression list) : bool = (* Check the func_ids, to see if call1's function is a child of call0's function *) - match (call0.func, call1.func) with + match (func0.func, func1.func) with | Regular (id0, rg_id0), Regular (id1, rg_id1) -> (* Both are "regular" calls: check if they come from the same rust function *) if id0 = id1 then @@ -738,7 +736,7 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (call0 : call) if rg_id0 = rg_id1 then true else (* We need to use the regions hierarchy *) - (* First, lookup the signature of the CFIM function *) + (* First, lookup the signature of the LLBC function *) let sg = LlbcAstUtils.lookup_fun_sig id0 ctx.fun_context.fun_decls in @@ -756,9 +754,9 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (call0 : call) * *) if call1_is_child then let call1_args = - Collections.List.prefix (List.length call0.args) call1.args + Collections.List.prefix (List.length args0) args1 in - let args = List.combine call0.args call1_args in + let args = List.combine args0 call1_args in (* Note that the input values are expressions, *which may contain * meta-values* (which we need to ignore). We only consider the * case where both expressions are actually values. *) @@ -767,7 +765,8 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (call0 : call) | Value (v0, _), Value (v1, _) -> v0 = v1 | _ -> false in - call0.type_params = call1.type_params && List.for_all input_eq args + (* Compare the input types and the prefix of the input arguments *) + func0.type_params = func1.type_params && List.for_all input_eq args else (* Not a child *) false else (* Not the same function *) @@ -783,25 +782,24 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (call0 : call) method plus b0 b1 _ = b0 () && b1 () - method! visit_expression env e = - match e with + method! visit_texpression env e = + match e.e with | Value (_, _) -> fun _ -> false - | Let (_, _, { e = Call call1; ty = _ }, e) -> - let call_is_child = check_call call1 in - if call_is_child then fun () -> true - else self#visit_texpression env e - | Let (_, _, re, e) -> + | Let (_, _, re, e) -> ( + match opt_destruct_function_call re with + | None -> fun () -> self#visit_texpression env e () + | Some (func1, args1) -> + let call_is_child = check_call func1 args1 in + if call_is_child then fun () -> true + else fun () -> self#visit_texpression env e ()) + | App _ -> ( fun () -> - self#visit_texpression env re () - && self#visit_texpression env e () - | Call call1 -> fun () -> check_call call1 + match opt_destruct_function_call e with + | Some (func1, args1) -> check_call func1 args1 + | None -> false) + | Func _ -> fun () -> false | Meta (_, e) -> self#visit_texpression env e | Switch (_, body) -> self#visit_switch_body env body - (** We need to reimplement the way we compose the booleans *) - - method! visit_texpression env e = - (* We take care not to visit the type *) - self#visit_expression env e.e method! visit_switch_body env body = match body with @@ -877,7 +875,7 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) method! visit_expression env e = match e with - | Value (_, _) | Call _ | Switch (_, _) | Meta (_, _) -> + | Value (_, _) | App _ | Switch (_, _) | Meta (_, _) -> super#visit_expression env e | Let (monadic, lv, re, e) -> (* Compute the set of values used in the next expression *) @@ -900,13 +898,14 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) (* Monadic let-binding: trickier. * We can filter if the right-expression is a function call, * under some conditions. *) - match (filter_monadic_calls, re.e) with - | true, Call call -> + match (filter_monadic_calls, opt_destruct_function_call re) with + | true, Some (func, args) -> (* We need to check if there is a child call - see * the comments for: * [expression_contains_child_call_in_all_paths] *) let has_child_call = - expression_contains_child_call_in_all_paths ctx call e + expression_contains_child_call_in_all_paths ctx func args + e in if has_child_call then (* Filter *) (e.e, fun _ -> used) @@ -987,6 +986,8 @@ let keep_forward (config : config) (trans : pure_fun_translation) : bool = (** Add unit arguments (optionally) to functions with no arguments, and change their output type to use `result` + + TODO: remove this *) let to_monadic (config : config) (def : fun_decl) : fun_decl = (* Update the body *) @@ -1109,7 +1110,7 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = object inherit [_] map_expression as super - method! visit_Call env call = + method! visit_App env call = match call.func with | Regular (A.Assumed aid, rg_id) -> ( match (aid, rg_id) with @@ -1139,8 +1140,8 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = | ( ( A.Replace | A.VecNew | A.VecPush | A.VecInsert | A.VecLen | A.VecIndex | A.VecIndexMut ), _ ) -> - super#visit_Call env call) - | _ -> super#visit_Call env call + super#visit_App env call) + | _ -> super#visit_App env call end in (* Update the body *) @@ -1229,7 +1230,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) *) let re_call = match re.e with - | Call call -> call + | App call -> call | _ -> raise (Failure "Unreachable: expected a function call") in (* TODO: this information should be computed in SymbolicToPure and @@ -1248,7 +1249,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) let args = call.args @ [ mk_value_expression state_value None ] in - Call { call with args } + App { call with args } in let re = { re with e = re_call } in (* Create the match *) diff --git a/src/PureToExtract.ml b/src/PureToExtract.ml index bbcf2cec..1c530011 100644 --- a/src/PureToExtract.ml +++ b/src/PureToExtract.ml @@ -333,7 +333,7 @@ type extraction_ctx = { (** Extraction context. Note that the extraction context contains information coming from the - CFIM AST (not only the pure AST). This is useful for naming, for instance: + LLBC AST (not only the pure AST). This is useful for naming, for instance: we use the region information to generate the names of the backward functions, etc. *) @@ -570,7 +570,7 @@ let ctx_add_decrases_clause (def : fun_decl) (ctx : extraction_ctx) : let ctx_add_fun_decl (trans_group : bool * pure_fun_translation) (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx = - (* Lookup the CFIM def to compute the region group information *) + (* Lookup the LLBC def to compute the region group information *) let def_id = def.def_id in let llbc_def = FunDeclId.Map.find def_id ctx.trans_ctx.fun_context.fun_decls diff --git a/src/PureUtils.ml b/src/PureUtils.ml index 23bffa1d..0045cc1d 100644 --- a/src/PureUtils.ml +++ b/src/PureUtils.ml @@ -244,9 +244,9 @@ let functions_not_mutually_recursive (funs : fun_decl list) : bool = object inherit [_] iter_expression as super - method! visit_call env call = - if FunIdSet.mem call.func !ids then raise Utils.Found - else super#visit_call env call + method! visit_func env func = + if FunIdSet.mem func.func !ids then raise Utils.Found + else super#visit_func env func end in @@ -264,13 +264,13 @@ let functions_not_mutually_recursive (funs : fun_decl list) : bool = We only look for outer monadic let-bindings. This is used when printing the branches of `if ... then ... else ...`. *) -let rec expression_requires_parentheses (e : texpression) : bool = +let rec let_group_requires_parentheses (e : texpression) : bool = match e.e with - | Value _ | Call _ -> false + | Value _ | App _ | Func _ -> false | Let (monadic, _, _, next_e) -> - if monadic then true else expression_requires_parentheses next_e + if monadic then true else let_group_requires_parentheses next_e | Switch (_, _) -> false - | Meta (_, next_e) -> expression_requires_parentheses next_e + | Meta (_, next_e) -> let_group_requires_parentheses next_e (** Module to perform type checking - we use this for sanity checks only *) module TypeCheck = struct @@ -385,3 +385,49 @@ let as_var (e : texpression) : VarId.id = (** Remove the external occurrences of [Meta] *) let rec unmeta (e : texpression) : texpression = match e.e with Meta (_, e) -> unmeta e | _ -> e + +(** Construct a type as a list of arrows: ty1 -> ... tyn *) +let mk_arrows (inputs : ty list) (output : ty) = + let rec aux (tys : ty list) : ty = + match tys with [] -> output | ty :: tys' -> Arrow (ty, aux tys') + in + aux inputs + +(** Destruct an `App` expression into an expression and a list of arguments. + + We simply destruct the expression as long as it is of the form `App (f, x)`. + *) +let destruct_apps (e : texpression) : texpression * texpression list = + let rec aux (args : texpression list) (e : texpression) : + texpression * texpression list = + match e.e with App (f, x) -> aux (x :: args) f | _ -> (e, args) + in + aux [] e + +(** The reverse of [destruct_app] *) +let mk_apps (e : texpression) (args : texpression list) : texpression = + (* Reverse the arguments *) + let args = List.rev args in + (* Apply *) + let rec aux (e : texpression) (args : texpression list) : texpression = + match args with + | [] -> e + | arg :: args' -> ( + let e' = aux e args' in + match e'.ty with + | Arrow (ty0, ty1) -> + (* Sanity check *) + assert (ty0 == arg.ty); + let e'' = App (e', arg) in + let ty'' = ty1 in + { e = e''; ty = ty'' } + | _ -> raise (Failure "Expected an arrow type")) + in + aux e args + +(* Destruct an expression into a function identifier and a list of arguments, + * if possible *) +let opt_destruct_function_call (e : texpression) : + (func * texpression list) option = + let app, args = destruct_apps e in + match app.e with Func func -> Some (func, args) | _ -> None diff --git a/src/SymbolicAst.ml b/src/SymbolicAst.ml index 5fa7d754..9cab092d 100644 --- a/src/SymbolicAst.ml +++ b/src/SymbolicAst.ml @@ -53,7 +53,7 @@ type meta = (** We generated an assignment (destination, assigned value, src) *) (** **Rk.:** here, [expression] is not at all equivalent to the expressions - used in CFIM: they are a first step towards lambda-calculus expressions. + used in LLBC: they are a first step towards lambda-calculus expressions. *) type expression = | Return of V.typed_value option diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index e32e28d6..18e2b873 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -1045,10 +1045,12 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) (List.combine args args_mplaces) in let dest_v = mk_typed_lvalue_from_var dest dest_mplace in - let call = { func; type_params; args } in - let call = Call call in - let call_ty = if monadic then mk_result_ty dest_v.ty else dest_v.ty in - let call = { e = call; ty = call_ty } in + let func = { func; type_params } in + let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in + let ret_ty = if monadic then mk_result_ty dest_v.ty else dest_v.ty in + let func_ty = mk_arrows input_tys ret_ty in + let func = { e = Func func; ty = func_ty } in + let call = mk_apps func args in (* Translate the next expression *) let next_e = translate_expression config e ctx in (* Put together *) @@ -1177,9 +1179,12 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) (List.combine inputs args_mplaces) in let monadic = fun_is_monadic fun_id in - let call = { func; type_params; args } in - let call_ty = mk_result_ty output.ty in - let call = { e = Call call; ty = call_ty } in + let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in + let ret_ty = if monadic then mk_result_ty output.ty else output.ty in + let func_ty = mk_arrows input_tys ret_ty in + let func = { func; type_params } in + let func = { e = Func func; ty = func_ty } in + let call = mk_apps func args in (* **Optimization**: * ================= * We do a small optimization here: if the backward function doesn't |