From 00104884e101d3125e62dde9757b9c1cacb3feec Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 3 Mar 2022 17:36:33 +0100 Subject: Make good progress on adding support for external and opaque declarations --- src/ExtractToFStar.ml | 50 +++-- src/Interpreter.ml | 42 ++-- src/InterpreterStatements.ml | 23 ++- src/LlbcAst.ml | 19 -- src/LlbcAstUtils.ml | 11 +- src/LlbcOfJson.ml | 18 +- src/PrePasses.ml | 6 +- src/Print.ml | 73 ++++--- src/PrintPure.ml | 16 +- src/Pure.ml | 14 +- src/PureMicroPasses.ml | 474 ++++++++++++++++++++++++------------------- src/PureUtils.ml | 7 +- src/Substitute.ml | 8 +- src/SymbolicToPure.ml | 60 +++--- src/Translate.ml | 8 +- 15 files changed, 475 insertions(+), 354 deletions(-) diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index 6bbc21d7..2e0568c8 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -1025,22 +1025,25 @@ let extract_fun_parameters (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_space fmt ()); (* The input parameters - note that doing this adds bindings to the context *) let ctx_body = - List.fold_left - (fun ctx (lv : typed_lvalue) -> - (* Open a box for the input parameter *) - F.pp_open_hovbox fmt 0; - F.pp_print_string fmt "("; - let ctx = extract_typed_lvalue ctx fmt false lv in - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - extract_ty ctx fmt false lv.ty; - F.pp_print_string fmt ")"; - (* Close the box for the input parameters *) - F.pp_close_box fmt (); - F.pp_print_space fmt (); - ctx) - ctx def.inputs_lvs + match def.body with + | None -> ctx + | Some body -> + List.fold_left + (fun ctx (lv : typed_lvalue) -> + (* Open a box for the input parameter *) + F.pp_open_hovbox fmt 0; + F.pp_print_string fmt "("; + let ctx = extract_typed_lvalue ctx fmt false lv in + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_ty ctx fmt false lv.ty; + F.pp_print_string fmt ")"; + (* Close the box for the input parameters *) + F.pp_close_box fmt (); + F.pp_print_space fmt (); + ctx) + ctx body.inputs_lvs in (ctx, ctx_body) @@ -1169,7 +1172,8 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) (Collections.List.to_cons_nil def.signature.outputs); (* Close the box for the return type *) F.pp_close_box fmt (); - (* Print the decrease clause *) + (* Print the decrease clause - rk.: a function with a decreases clause + * is necessarily a transparent function *) if has_decreases_clause then ( F.pp_print_space fmt (); (* Open a box for the decrease clause *) @@ -1193,9 +1197,13 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) * function (the additional input values "given back" to the * backward functions have no influence on termination: we thus * share the decrease clauses between the forward and the backward - * functions) *) + * functions). + * Rk.: if a function has a decreases clause, it is necessarily + * a transparent function *) let inputs_lvs = - Collections.List.prefix (List.length fwd_def.inputs_lvs) def.inputs_lvs + Collections.List.prefix + (List.length (Option.get fwd_def.body).inputs_lvs) + (Option.get def.body).inputs_lvs in let _ = List.fold_left @@ -1224,7 +1232,9 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) (* Open a box for the body *) F.pp_open_hvbox fmt 0; (* Extract the body *) - let _ = extract_texpression ctx_body fmt false false def.body in + let _ = + extract_texpression ctx_body fmt false false (Option.get def.body).body + in (* Close the box for the body *) F.pp_close_box fmt ()); (* Close the box for the definition *) diff --git a/src/Interpreter.ml b/src/Interpreter.ml index 82eb4b35..f6ae268d 100644 --- a/src/Interpreter.ml +++ b/src/Interpreter.ml @@ -91,9 +91,10 @@ let initialize_symbolic_context_for_fun (type_context : C.type_context) inst_sg.A.regions_hierarchy compute_abs_avalues ctx in (* Split the variables between return var, inputs and remaining locals *) - let ret_var = List.hd fdef.locals in + let body = Option.get fdef.body in + let ret_var = List.hd body.locals in let input_vars, local_vars = - Collections.List.split_at (List.tl fdef.locals) fdef.arg_count + Collections.List.split_at (List.tl body.locals) body.arg_count in (* Push the return variable (initialized with ⊥) *) let ctx = C.ctx_push_uninitialized_var ctx ret_var in @@ -242,7 +243,9 @@ let evaluate_function_symbolic (config : C.partial_config) (synthesize : bool) in (* Evaluate the function *) - let symbolic = eval_function_body config fdef.A.body cf_finish ctx in + let symbolic = + eval_function_body config (Option.get fdef.A.body).body cf_finish ctx + in (* Return *) (input_svs, symbolic) @@ -255,6 +258,7 @@ module Test = struct (fid : A.FunDeclId.id) : unit = (* Retrieve the function declaration *) let fdef = A.FunDeclId.nth m.functions fid in + let body = Option.get fdef.body in (* Debug *) log#ldebug @@ -263,14 +267,14 @@ module Test = struct (* Sanity check - *) assert (List.length fdef.A.signature.region_params = 0); assert (List.length fdef.A.signature.type_params = 0); - assert (fdef.A.arg_count = 0); + assert (body.A.arg_count = 0); (* Create the evaluation context *) let type_context, fun_context = compute_type_fun_contexts m in let ctx = initialize_eval_context type_context fun_context [] in (* Insert the (uninitialized) local variables *) - let ctx = C.ctx_push_uninitialized_vars ctx fdef.A.locals in + let ctx = C.ctx_push_uninitialized_vars ctx body.A.locals in (* Create the continuation to check the function's result *) let config = C.config_of_partial C.ConcreteMode config in @@ -286,21 +290,24 @@ module Test = struct in (* Evaluate the function *) - let _ = eval_function_body config fdef.A.body cf_check ctx in + let _ = eval_function_body config body.body cf_check ctx in () - (** Small helper: return true if the function is a unit function (no parameters, - no arguments) - TODO: move *) - let fun_decl_is_unit (def : A.fun_decl) : bool = - def.A.arg_count = 0 - && List.length def.A.signature.region_params = 0 - && List.length def.A.signature.type_params = 0 - && List.length def.A.signature.inputs = 0 + (** Small helper: return true if the function is a *transparent* unit function + (no parameters, no arguments) - TODO: move *) + let fun_decl_is_transparent_unit (def : A.fun_decl) : bool = + match def.body with + | None -> false + | Some body -> + body.arg_count = 0 + && List.length def.A.signature.region_params = 0 + && List.length def.A.signature.type_params = 0 + && List.length def.A.signature.inputs = 0 (** Test all the unit functions in a list of function definitions *) let test_unit_functions (config : C.partial_config) (m : M.llbc_module) : unit = - let unit_funs = List.filter fun_decl_is_unit m.functions in + let unit_funs = List.filter fun_decl_is_transparent_unit m.functions in let test_unit_fun (def : A.fun_decl) : unit = test_unit_function config m def.A.def_id in @@ -329,6 +336,10 @@ module Test = struct () + (** Small helper *) + let fun_decl_is_transparent (def : A.fun_decl) : bool = + Option.is_some def.body + (** Execute the symbolic interpreter on a list of functions. TODO: for now we ignore the functions which contain loops, because @@ -336,9 +347,12 @@ module Test = struct *) let test_functions_symbolic (config : C.partial_config) (synthesize : bool) (m : M.llbc_module) : unit = + (* Filter the functions which contain loops *) let no_loop_funs = List.filter (fun f -> not (LlbcAstUtils.fun_decl_has_loops f)) m.functions in + (* Filter the opaque functions *) + let no_loop_funs = List.filter fun_decl_is_transparent no_loop_funs in let type_context, fun_context = compute_type_fun_contexts m in let test_fun (def : A.fun_decl) : unit = (* Execute the function - note that as the symbolic interpreter explores diff --git a/src/InterpreterStatements.ml b/src/InterpreterStatements.ml index 547f5ee3..3ea0a6fa 100644 --- a/src/InterpreterStatements.ml +++ b/src/InterpreterStatements.ml @@ -233,8 +233,9 @@ let set_discriminant (config : C.config) (p : E.place) let bottom_v = match type_id with | T.AdtId def_id -> - compute_expanded_bottom_adt_value ctx.type_context.type_decls - def_id (Some variant_id) regions types + compute_expanded_bottom_adt_value + ctx.type_context.type_decls def_id (Some variant_id) + regions types | T.Assumed T.Option -> assert (regions = []); compute_expanded_bottom_option_value variant_id @@ -1005,15 +1006,21 @@ and eval_local_function_call_concrete (config : C.config) (fid : A.FunDeclId.id) (* Retrieve the (correctly instantiated) body *) let def = C.ctx_lookup_fun_decl ctx fid in + (* We can evaluate the function call only if it is not opaque *) + let body = + match def.body with + | None -> raise (Failure "Can't evaluate a call to an opaque function") + | Some body -> body + in let tsubst = Subst.make_type_subst (List.map (fun v -> v.T.index) def.A.signature.type_params) type_params in - let locals, body = Subst.fun_decl_substitute_in_body tsubst def in + let locals, body_st = Subst.fun_body_substitute_in_body tsubst body in (* Evaluate the input operands *) - assert (List.length args = def.A.arg_count); + assert (List.length args = body.A.arg_count); let cc = eval_operands config args in (* Push a frame delimiter - we use [comp_transmit] to transmit the result @@ -1028,7 +1035,9 @@ and eval_local_function_call_concrete (config : C.config) (fid : A.FunDeclId.id) | ret_ty :: locals -> (ret_ty, locals) | _ -> raise (Failure "Unreachable") in - let input_locals, locals = Collections.List.split_at locals def.A.arg_count in + let input_locals, locals = + Collections.List.split_at locals body.A.arg_count + in let cc = comp_transmit cc (push_var ret_var (mk_bottom ret_var.var_ty)) in @@ -1045,7 +1054,7 @@ and eval_local_function_call_concrete (config : C.config) (fid : A.FunDeclId.id) let cc = comp cc (push_uninitialized_vars locals) in (* Execute the function body *) - let cc = comp cc (eval_function_body config body) in + let cc = comp cc (eval_function_body config body_st) in (* Pop the stack frame and move the return value to its destination *) let cf_finish cf res = @@ -1074,7 +1083,7 @@ and eval_local_function_call_symbolic (config : C.config) (fid : A.FunDeclId.id) * while doing so *) let inst_sg = instantiate_fun_sig type_params sg in (* Sanity check *) - assert (List.length args = def.A.arg_count); + assert (List.length args = List.length def.A.signature.inputs); (* Evaluate the function call *) eval_function_call_symbolic_from_inst_sig config (A.Local fid) inst_sg region_params type_params args dest cf ctx diff --git a/src/LlbcAst.ml b/src/LlbcAst.ml index 149fb23d..f5ffc956 100644 --- a/src/LlbcAst.ml +++ b/src/LlbcAst.ml @@ -170,24 +170,6 @@ and switch_targets = concrete = true; }] -type fun_decl = { - def_id : FunDeclId.id; - name : fun_name; - signature : fun_sig; - arg_count : int; - locals : var list; - body : statement; -} -[@@deriving show] -(** TODO: function definitions (and maybe type definitions in the future) - * contain information like `divergent`. I wonder if this information should - * be stored directly inside the definitions or inside separate maps/sets. - * Of course, if everything is stored in separate maps/sets, nothing - * prevents us from computing this info in Charon (and thus exporting directly - * it with the type/function defs), in which case we just have to implement special - * treatment when deserializing, to move the info to a separate map. *) - -(* type fun_body = { arg_count : int; locals : var list; body : statement } [@@deriving show] @@ -198,4 +180,3 @@ type fun_decl = { body : fun_body option; } [@@deriving show] -*) diff --git a/src/LlbcAstUtils.ml b/src/LlbcAstUtils.ml index 93ca4448..41d17bf4 100644 --- a/src/LlbcAstUtils.ml +++ b/src/LlbcAstUtils.ml @@ -17,7 +17,10 @@ let statement_has_loops (st : statement) : bool = with Found -> true (** Check if a [fun_decl] contains loops *) -let fun_decl_has_loops (fd : fun_decl) : bool = statement_has_loops fd.body +let fun_decl_has_loops (fd : fun_decl) : bool = + match fd.body with + | Some body -> statement_has_loops body.body + | None -> false let lookup_fun_sig (fun_id : fun_id) (fun_decls : fun_decl FunDeclId.Map.t) : fun_sig = @@ -64,6 +67,6 @@ let list_ordered_parent_region_groups (sg : fun_sig) (gid : T.RegionGroupId.id) let parents = List.map (fun (rg : T.region_var_group) -> rg.id) parents in parents -let fun_decl_get_input_vars (fdef : fun_decl) : var list = - let locals = List.tl fdef.locals in - Collections.List.prefix fdef.arg_count locals +let fun_body_get_input_vars (fbody : fun_body) : var list = + let locals = List.tl fbody.locals in + Collections.List.prefix fbody.arg_count locals diff --git a/src/LlbcOfJson.ml b/src/LlbcOfJson.ml index e293b030..6e0adfb6 100644 --- a/src/LlbcOfJson.ml +++ b/src/LlbcOfJson.ml @@ -607,6 +607,16 @@ and switch_targets_of_json (js : json) : (A.switch_targets, string) result = Ok (A.SwitchInt (int_ty, tgts, otherwise)) | _ -> Error "") +let fun_body_of_json (js : json) : (A.fun_body, string) result = + combine_error_msgs js "fun_body_of_json" + (match js with + | `Assoc [ ("arg_count", arg_count); ("locals", locals); ("body", body) ] -> + let* arg_count = int_of_json arg_count in + let* locals = list_of_json var_of_json locals in + let* body = statement_of_json body in + Ok { A.arg_count; locals; body } + | _ -> Error "") + let fun_decl_of_json (js : json) : (A.fun_decl, string) result = combine_error_msgs js "fun_decl_of_json" (match js with @@ -615,17 +625,13 @@ let fun_decl_of_json (js : json) : (A.fun_decl, string) result = ("def_id", def_id); ("name", name); ("signature", signature); - ("arg_count", arg_count); - ("locals", locals); ("body", body); ] -> let* def_id = A.FunDeclId.id_of_json def_id in let* name = fun_name_of_json name in let* signature = fun_sig_of_json signature in - let* arg_count = int_of_json arg_count in - let* locals = list_of_json var_of_json locals in - let* body = statement_of_json body in - Ok { A.def_id; name; signature; arg_count; locals; body } + let* body = option_of_json fun_body_of_json body in + Ok { A.def_id; name; signature; body } | _ -> Error "") let g_declaration_group_of_json (id_of_json : json -> ('id, string) result) diff --git a/src/PrePasses.ml b/src/PrePasses.ml index 9b1a6990..dda3c867 100644 --- a/src/PrePasses.ml +++ b/src/PrePasses.ml @@ -42,7 +42,11 @@ let filter_drop_assigns (f : A.fun_decl) : A.fun_decl = end in (* Map *) - let body = obj#visit_statement () f.body in + let body = + match f.body with + | Some body -> Some { body with body = obj#visit_statement () body.body } + | None -> None + in { f with body } let apply_passes (m : M.llbc_module) : M.llbc_module = diff --git a/src/Print.ml b/src/Print.ml index 18227c61..d2101dae 100644 --- a/src/Print.ml +++ b/src/Print.ml @@ -784,7 +784,7 @@ module LlbcAst = struct PC.type_ctx_to_adt_variant_to_string_fun type_decls in let var_id_to_string vid = - let var = V.VarId.nth fdef.locals vid in + let var = V.VarId.nth (Option.get fdef.body).locals vid in var_to_string var in let adt_field_names = PC.type_ctx_to_adt_field_names_fun type_decls in @@ -1062,41 +1062,56 @@ module LlbcAst = struct "<" ^ String.concat "," (List.append regions types) ^ ">" in - (* Arguments *) - let inputs = List.tl def.locals in - let inputs, _aux_locals = Collections.List.split_at inputs def.arg_count in - let args = List.combine inputs sg.inputs in - let args = - List.map - (fun (var, rty) -> var_to_string var ^ " : " ^ sty_to_string rty) - args - in - let args = String.concat ", " args in - (* Return type *) let ret_ty = sg.output in let ret_ty = if TU.ty_is_unit ret_ty then "" else " -> " ^ sty_to_string ret_ty in - (* All the locals (with erased regions) *) - let locals = - List.map - (fun var -> - indent ^ indent_incr ^ var_to_string var ^ " : " - ^ ety_to_string var.var_ty ^ ";") - def.locals - in - let locals = String.concat "\n" locals in + (* We print the declaration differently if it is opaque (no body) or transparent + * (we have access to a body) *) + match def.body with + | None -> + (* Arguments - we need to ignore the first input type which is actually + * the return type... TODO: fix that *) + let input_tys = List.tl sg.inputs in + let args = List.map sty_to_string input_tys in + let args = String.concat ", " args in + + (* Put everything together *) + indent ^ "opaque fn " ^ name ^ params ^ "(" ^ args ^ ")" ^ ret_ty + | Some body -> + (* Arguments *) + let inputs = List.tl body.locals in + let inputs, _aux_locals = + Collections.List.split_at inputs body.arg_count + in + let args = List.combine inputs sg.inputs in + let args = + List.map + (fun (var, rty) -> var_to_string var ^ " : " ^ sty_to_string rty) + args + in + let args = String.concat ", " args in + + (* All the locals (with erased regions) *) + let locals = + List.map + (fun var -> + indent ^ indent_incr ^ var_to_string var ^ " : " + ^ ety_to_string var.var_ty ^ ";") + body.locals + in + let locals = String.concat "\n" locals in - (* Body *) - let body = - statement_to_string fmt (indent ^ indent_incr) indent_incr def.body - in + (* Body *) + let body = + statement_to_string fmt (indent ^ indent_incr) indent_incr body.body + in - (* Put everything together *) - indent ^ "fn " ^ name ^ params ^ "(" ^ args ^ ")" ^ ret_ty ^ " {\n" ^ locals - ^ "\n\n" ^ body ^ "\n" ^ indent ^ "}" + (* Put everything together *) + indent ^ "fn " ^ name ^ params ^ "(" ^ args ^ ")" ^ ret_ty ^ " {\n" + ^ locals ^ "\n\n" ^ body ^ "\n" ^ indent ^ "}" end module PA = LlbcAst (* local module *) @@ -1139,7 +1154,7 @@ module Module = struct fun_name_to_string def.name in let var_id_to_string vid = - let var = V.VarId.nth def.locals vid in + let var = V.VarId.nth (Option.get def.body).locals vid in PA.var_to_string var in let adt_variant_to_string = diff --git a/src/PrintPure.ml b/src/PrintPure.ml index f47a1f06..52215019 100644 --- a/src/PrintPure.ml +++ b/src/PrintPure.ml @@ -479,9 +479,13 @@ let fun_decl_to_string (fmt : ast_formatter) (def : fun_decl) : string = let type_fmt = ast_to_type_formatter fmt in let name = fun_name_to_string def.basename ^ fun_suffix def.back_id in let signature = fun_sig_to_string fmt def.signature in - let inputs = List.map (var_to_string type_fmt) def.inputs in - let inputs = - if inputs = [] then "" else " fun " ^ String.concat " " inputs ^ " ->\n" - in - let body = texpression_to_string fmt " " " " def.body in - "let " ^ name ^ " :\n " ^ signature ^ " =\n" ^ inputs ^ " " ^ body + match def.body with + | None -> "val " ^ name ^ " :\n " ^ signature + | Some body -> + let inputs = List.map (var_to_string type_fmt) body.inputs in + let inputs = + if inputs = [] then "" + else " fun " ^ String.concat " " inputs ^ " ->\n" + in + let body = texpression_to_string fmt " " " " body.body in + "let " ^ name ^ " :\n " ^ signature ^ " =\n" ^ inputs ^ " " ^ body diff --git a/src/Pure.ml b/src/Pure.ml index a4b193f7..79801440 100644 --- a/src/Pure.ml +++ b/src/Pure.ml @@ -596,6 +596,14 @@ type fun_sig = { type inst_fun_sig = { inputs : ty list; outputs : ty list } +type fun_body = { + inputs : var list; + inputs_lvs : typed_lvalue list; + (** The inputs seen as lvalues. Allows to make transformations, for example + to replace unused variables by `_` *) + body : texpression; +} + type fun_decl = { def_id : FunDeclId.id; back_id : T.RegionGroupId.id option; @@ -606,9 +614,5 @@ type fun_decl = { (to identify the forward/backward functions) later. *) signature : fun_sig; - inputs : var list; - inputs_lvs : typed_lvalue list; - (** The inputs seen as lvalues. Allows to make transformations, for example - to replace unused variables by `_` *) - body : texpression; + body : fun_body option; } diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index b110f829..5227b2ad 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -311,14 +311,22 @@ let compute_pretty_names (def : fun_decl) : fun_decl = (ctx, e.e) in - let input_names = - List.filter_map - (fun (v : var) -> - match v.basename with None -> None | Some name -> Some (v.id, name)) - def.inputs + let body = + match def.body with + | None -> None + | Some body -> + let input_names = + List.filter_map + (fun (v : var) -> + match v.basename with + | None -> None + | Some name -> Some (v.id, name)) + body.inputs + in + let ctx = VarId.Map.of_list input_names in + let _, body_exp = update_texpression body.body ctx in + Some { body with body = body_exp } in - let ctx = VarId.Map.of_list input_names in - let _, body = update_texpression def.body ctx in { def with body } (** Remove the meta-information *) @@ -330,8 +338,11 @@ let remove_meta (def : fun_decl) : fun_decl = method! visit_Meta env _ e = super#visit_expression env e.e end in - let body = obj#visit_texpression () def.body in - { def with body } + match def.body with + | None -> def + | Some body -> + let body = { body with body = obj#visit_texpression () body.body } in + { def with body = Some body } (** Inline the useless variable (re-)assignments: @@ -452,8 +463,13 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) (** Visit the places used as rvalues, to substitute them if possible *) end in - let body = obj#visit_texpression VarId.Map.empty def.body in - { def with body } + match def.body with + | None -> def + | Some body -> + let body = + { body with body = obj#visit_texpression VarId.Map.empty body.body } + in + { def with body = Some body } (** Given a forward or backward function call, is there, for every execution path, a child backward function called later with exactly the same input @@ -683,22 +699,29 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) dont_filter () end in - (* Visit the body *) - let body, used_vars = expr_visitor#visit_texpression () def.body in - (* Visit the parameters - TODO: update: we can filter only if the definition - * is not recursive (otherwise it might mess up with the decrease clauses: - * the decrease clauses uses all the inputs given to the function, if some - * inputs are replaced by '_' we can't give it to the function used in the - * decreases clause). - * For now we deactivate the filtering. *) - let used_vars = used_vars () in - let inputs_lvs = - if false then - List.map (fun lv -> fst (filter_typed_lvalue used_vars lv)) def.inputs_lvs - else def.inputs_lvs - in - (* Return *) - { def with body; inputs_lvs } + (* We filter only inside of transparent (i.e., non-opaque) definitions *) + match def.body with + | None -> def + | Some body -> + (* Visit the body *) + let body_exp, used_vars = expr_visitor#visit_texpression () body.body in + (* Visit the parameters - TODO: update: we can filter only if the definition + * is not recursive (otherwise it might mess up with the decrease clauses: + * the decrease clauses uses all the inputs given to the function, if some + * inputs are replaced by '_' we can't give it to the function used in the + * decreases clause). + * For now we deactivate the filtering. *) + let used_vars = used_vars () in + let inputs_lvs = + if false then + List.map + (fun lv -> fst (filter_typed_lvalue used_vars lv)) + body.inputs_lvs + else body.inputs_lvs + in + (* Return *) + let body = { body with body = body_exp; inputs_lvs } in + { def with body = Some body } (** Return `None` if the function is a backward function with no outputs (so that we eliminate the definition which is useless). @@ -763,21 +786,31 @@ let to_monadic (config : config) (def : fun_decl) : fun_decl = super#visit_call env call end in - let body = obj#visit_texpression () def.body in - let def = { def with body } in + let def = + match def.body with + | None -> def + | Some body -> + let body = { body with body = obj#visit_texpression () body.body } in + { def with body = Some body } + in (* Update the signature: first the input types *) let def = - if def.inputs = [] && config.add_unit_args then ( - assert (def.signature.inputs = []); + if def.signature.inputs = [] && config.add_unit_args then let signature = { def.signature with inputs = [ unit_ty ] } in - let var_cnt = get_expression_min_var_counter def.body.e in - let id, _ = VarId.fresh var_cnt in - let var = { id; basename = None; ty = unit_ty } in - let inputs = [ var ] in - let input_lv = mk_typed_lvalue_from_var var None in - let inputs_lvs = [ input_lv ] in - { def with signature; inputs; inputs_lvs }) + let body = + match def.body with + | None -> None + | Some body -> + let var_cnt = get_expression_min_var_counter body.body.e in + let id, _ = VarId.fresh var_cnt in + let var = { id; basename = None; ty = unit_ty } in + let inputs = [ var ] in + let input_lv = mk_typed_lvalue_from_var var None in + let inputs_lvs = [ input_lv ] in + Some { body with inputs; inputs_lvs } + in + { def with signature; body } else def in (* Then the output type *) @@ -830,11 +863,15 @@ let unit_vars_to_unit (def : fun_decl) : fun_decl = end in (* Update the body *) - let body = obj#visit_texpression () def.body in - (* Update the input parameters *) - let inputs_lvs = List.map (obj#visit_typed_lvalue ()) def.inputs_lvs in - (* Return *) - { def with body; inputs_lvs } + match def.body with + | None -> def + | Some body -> + let body_exp = obj#visit_texpression () body.body in + (* Update the input parameters *) + let inputs_lvs = List.map (obj#visit_typed_lvalue ()) body.inputs_lvs in + (* Return *) + let body = Some { body with body = body_exp; inputs_lvs } in + { def with body } (** Eliminate the box functions like `Box::new`, `Box::deref`, etc. Most of them are translated to identity, and `Box::free` is translated to `()`. @@ -887,178 +924,201 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = end in (* Update the body *) - let body = obj#visit_texpression () def.body in - { def with body } + match def.body with + | None -> def + | Some body -> + let body = Some { body with body = obj#visit_texpression () body.body } in + { def with body } (** Decompose the monadic let-bindings. See the explanations in [config]. *) -let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl - = - (* Set up the var id generator *) - let cnt = get_expression_min_var_counter def.body.e in - let _, fresh_id = VarId.mk_stateful_generator cnt in - (* It is a very simple map *) - let obj = - object (self) - inherit [_] map_expression as super - - method! visit_Let env monadic lv re next_e = - if not monadic then super#visit_Let env monadic lv re next_e - else - (* If monadic, we need to check if the left-value is a variable: - * - if yes, don't decompose - * - if not, make the decomposition in two steps - *) - match lv.value with - | LvVar _ -> - (* Variable: nothing to do *) - super#visit_Let env monadic lv re next_e - | _ -> - (* Not a variable: decompose *) - (* Introduce a temporary variable to receive the value of the - * monadic binding *) - let vid = fresh_id () in - let tmp : var = { id = vid; basename = None; ty = lv.ty } in - let ltmp = mk_typed_lvalue_from_var tmp None in - let rtmp = mk_typed_rvalue_from_var tmp in - let rtmp = mk_value_expression rtmp None in - (* Visit the next expression *) - let next_e = self#visit_texpression env next_e in - (* Create the let-bindings *) - (mk_let true ltmp re (mk_let false lv rtmp next_e)).e - end - in - (* Update the body *) - let body = obj#visit_texpression () def.body in - (* Return *) - { def with body } +let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : + fun_decl = + match def.body with + | None -> def + | Some body -> + (* Set up the var id generator *) + let cnt = get_expression_min_var_counter body.body.e in + let _, fresh_id = VarId.mk_stateful_generator cnt in + (* It is a very simple map *) + let obj = + object (self) + inherit [_] map_expression as super + + method! visit_Let env monadic lv re next_e = + if not monadic then super#visit_Let env monadic lv re next_e + else + (* If monadic, we need to check if the left-value is a variable: + * - if yes, don't decompose + * - if not, make the decomposition in two steps + *) + match lv.value with + | LvVar _ -> + (* Variable: nothing to do *) + super#visit_Let env monadic lv re next_e + | _ -> + (* Not a variable: decompose *) + (* Introduce a temporary variable to receive the value of the + * monadic binding *) + let vid = fresh_id () in + let tmp : var = { id = vid; basename = None; ty = lv.ty } in + let ltmp = mk_typed_lvalue_from_var tmp None in + let rtmp = mk_typed_rvalue_from_var tmp in + let rtmp = mk_value_expression rtmp None in + (* Visit the next expression *) + let next_e = self#visit_texpression env next_e in + (* Create the let-bindings *) + (mk_let true ltmp re (mk_let false lv rtmp next_e)).e + end + in + (* Update the body *) + let body = Some { body with body = obj#visit_texpression () body.body } in + (* Return *) + { def with body } (** Unfold the monadic let-bindings to explicit matches. *) let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) (def : fun_decl) : fun_decl = - (* We may need to introduce fresh variables for the state *) - let var_cnt = get_expression_min_var_counter def.body.e in - let _, fresh_var_id = VarId.mk_stateful_generator var_cnt in - let fresh_state_var () = - let id = fresh_var_id () in - { id; basename = Some "st"; ty = mk_state_ty } - in - (* It is a very simple map *) - let obj = - object (self) - inherit [_] map_expression as super - - method! visit_Let state_var monadic lv re e = - if not monadic then super#visit_Let state_var monadic lv re e - else - (* We don't do the same thing if we use a state-error monad or simply - * an error monad. - * Note that some functions always live in the error monad (arithmetic - * operations, for instance). - *) - let re_call = - match re.e with - | Call call -> call - | _ -> raise (Failure "Unreachable: expected a function call") - in - (* TODO: this information should be computed in SymbolicToPure and - * store in an enum ("monadic" should be an enum, not a bool). - * Also: everything will be cleaner once we update the AST to make - * it more idiomatic lambda calculus... *) - let re_call_can_use_state = - match re_call.func with - | Regular (A.Local _, _) -> true - | Regular (A.Assumed _, _) | Unop _ | Binop _ -> false - in - if config.use_state_monad && re_call_can_use_state then - let re_call = - let call = re_call in - let state_value = mk_typed_rvalue_from_var state_var in - let args = call.args @ [ mk_value_expression state_value None ] in - Call { call with args } - in - let re = { re with e = re_call } in - (* Create the match *) - let fail_pat = mk_result_fail_lvalue lv.ty in - let fail_value = mk_result_fail_rvalue e.ty in - let fail_branch = - { pat = fail_pat; branch = mk_value_expression fail_value None } - in - (* The `Success` branch introduces a fresh state variable *) - let state_var = fresh_state_var () in - let state_value = mk_typed_lvalue_from_var state_var None in - let success_pat = - mk_result_return_lvalue - (mk_simpl_tuple_lvalue [ state_value; lv ]) - in - let success_branch = { pat = success_pat; branch = e } in - let switch_body = Match [ fail_branch; success_branch ] in - let e = Switch (re, switch_body) in - self#visit_expression state_var e - else - let fail_pat = mk_result_fail_lvalue lv.ty in - let fail_value = mk_result_fail_rvalue e.ty in - let fail_branch = - { pat = fail_pat; branch = mk_value_expression fail_value None } - in - let success_pat = mk_result_return_lvalue lv in - let success_branch = { pat = success_pat; branch = e } in - let switch_body = Match [ fail_branch; success_branch ] in - let e = Switch (re, switch_body) in - self#visit_expression state_var e - - method! visit_Value state_var rv mp = - if config.use_state_monad then - match rv.ty with - | Adt (Assumed Result, _) -> ( - match rv.value with - | RvAdt av -> - (* We only need to replace the content of `Return ...` *) - (* TODO: type checking is completely broken at this point... *) - let variant_id = Option.get av.variant_id in - if variant_id = result_return_id then - let res_v = Collections.List.to_cons_nil av.field_values in - let state_value = mk_typed_rvalue_from_var state_var in - let res = mk_simpl_tuple_rvalue [ state_value; res_v ] in - let res = mk_result_return_rvalue res in - (mk_value_expression res None).e - else super#visit_Value state_var rv mp - | _ -> raise (Failure "Unrechable")) - | _ -> super#visit_Value state_var rv mp - else super#visit_Value state_var rv mp - (** We also need to update values, in case this value is `Return ...`. - + match def.body with + | None -> def + | Some body -> + (* We may need to introduce fresh variables for the state *) + let var_cnt = get_expression_min_var_counter body.body.e in + let _, fresh_var_id = VarId.mk_stateful_generator var_cnt in + let fresh_state_var () = + let id = fresh_var_id () in + { id; basename = Some "st"; ty = mk_state_ty } + in + (* It is a very simple map *) + let obj = + object (self) + inherit [_] map_expression as super + + method! visit_Let state_var monadic lv re e = + if not monadic then super#visit_Let state_var monadic lv re e + else + (* We don't do the same thing if we use a state-error monad or simply + * an error monad. + * Note that some functions always live in the error monad (arithmetic + * operations, for instance). + *) + let re_call = + match re.e with + | Call call -> call + | _ -> raise (Failure "Unreachable: expected a function call") + in + (* TODO: this information should be computed in SymbolicToPure and + * store in an enum ("monadic" should be an enum, not a bool). + * Also: everything will be cleaner once we update the AST to make + * it more idiomatic lambda calculus... *) + let re_call_can_use_state = + match re_call.func with + | Regular (A.Local _, _) -> true + | Regular (A.Assumed _, _) | Unop _ | Binop _ -> false + in + if config.use_state_monad && re_call_can_use_state then + let re_call = + let call = re_call in + let state_value = mk_typed_rvalue_from_var state_var in + let args = + call.args @ [ mk_value_expression state_value None ] + in + Call { call with args } + in + let re = { re with e = re_call } in + (* Create the match *) + let fail_pat = mk_result_fail_lvalue lv.ty in + let fail_value = mk_result_fail_rvalue e.ty in + let fail_branch = + { + pat = fail_pat; + branch = mk_value_expression fail_value None; + } + in + (* The `Success` branch introduces a fresh state variable *) + let state_var = fresh_state_var () in + let state_value = mk_typed_lvalue_from_var state_var None in + let success_pat = + mk_result_return_lvalue + (mk_simpl_tuple_lvalue [ state_value; lv ]) + in + let success_branch = { pat = success_pat; branch = e } in + let switch_body = Match [ fail_branch; success_branch ] in + let e = Switch (re, switch_body) in + self#visit_expression state_var e + else + let fail_pat = mk_result_fail_lvalue lv.ty in + let fail_value = mk_result_fail_rvalue e.ty in + let fail_branch = + { + pat = fail_pat; + branch = mk_value_expression fail_value None; + } + in + let success_pat = mk_result_return_lvalue lv in + let success_branch = { pat = success_pat; branch = e } in + let switch_body = Match [ fail_branch; success_branch ] in + let e = Switch (re, switch_body) in + self#visit_expression state_var e + + method! visit_Value state_var rv mp = + if config.use_state_monad then + match rv.ty with + | Adt (Assumed Result, _) -> ( + match rv.value with + | RvAdt av -> + (* We only need to replace the content of `Return ...` *) + (* TODO: type checking is completely broken at this point... *) + let variant_id = Option.get av.variant_id in + if variant_id = result_return_id then + let res_v = + Collections.List.to_cons_nil av.field_values + in + let state_value = mk_typed_rvalue_from_var state_var in + let res = + mk_simpl_tuple_rvalue [ state_value; res_v ] + in + let res = mk_result_return_rvalue res in + (mk_value_expression res None).e + else super#visit_Value state_var rv mp + | _ -> raise (Failure "Unrechable")) + | _ -> super#visit_Value state_var rv mp + else super#visit_Value state_var rv mp + (** We also need to update values, in case this value is `Return ...`. + TODO: this is super ugly... We need to use the monadic functions - `fail` and `return` instead. - *) - end - in - (* Update the body *) - let input_state_var = fresh_state_var () in - let body = obj#visit_texpression input_state_var def.body in - let def = { def with body } in - (* We need to update the type if we revealed the state monad *) - let def = - if config.use_state_monad then - (* Update the signature *) - let sg = def.signature in - let sg_inputs = sg.inputs @ [ mk_state_ty ] in - let sg_outputs = Collections.List.to_cons_nil sg.outputs in - let _, sg_outputs = dest_arrow_ty sg_outputs in - let sg_outputs = [ sg_outputs ] in - let sg = { sg with inputs = sg_inputs; outputs = sg_outputs } in - (* Update the inputs list *) - let inputs = def.inputs @ [ input_state_var ] in - let input_lv = mk_typed_lvalue_from_var input_state_var None in - let inputs_lvs = def.inputs_lvs @ [ input_lv ] in - (* Update the definition *) - { def with signature = sg; inputs; inputs_lvs } - else def - in - (* Return *) - { def with body } + `fail` and `return` instead. + *) + end + in + (* Update the body *) + let input_state_var = fresh_state_var () in + let body = + { body with body = obj#visit_texpression input_state_var body.body } + in + (* We need to update the type if we revealed the state monad *) + let body, signature = + if config.use_state_monad then + (* Update the signature *) + let sg = def.signature in + let sg_inputs = sg.inputs @ [ mk_state_ty ] in + let sg_outputs = Collections.List.to_cons_nil sg.outputs in + let _, sg_outputs = dest_arrow_ty sg_outputs in + let sg_outputs = [ sg_outputs ] in + let sg = { sg with inputs = sg_inputs; outputs = sg_outputs } in + (* Update the inputs list *) + let inputs = body.inputs @ [ input_state_var ] in + let input_lv = mk_typed_lvalue_from_var input_state_var None in + let inputs_lvs = body.inputs_lvs @ [ input_lv ] in + (* Update the body *) + let body = { body with inputs; inputs_lvs } in + (body, sg) + else (body, def.signature) + in + (* Return *) + { def with body = Some body; signature } (** Apply all the micro-passes to a function. @@ -1149,8 +1209,8 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_decl) : let def = decompose_monadic_let_bindings ctx def in log#ldebug (lazy - ("decompose_monadic_let_bindings:\n\n" ^ fun_decl_to_string ctx def - ^ "\n")); + ("decompose_monadic_let_bindings:\n\n" + ^ fun_decl_to_string ctx def ^ "\n")); def) else ( log#ldebug diff --git a/src/PureUtils.ml b/src/PureUtils.ml index cfc8a270..2a56b6e0 100644 --- a/src/PureUtils.ml +++ b/src/PureUtils.ml @@ -251,8 +251,11 @@ let functions_not_mutually_recursive (funs : fun_decl list) : bool = in try - obj#visit_texpression () fdef.body; - true + match fdef.body with + | None -> true + | Some body -> + obj#visit_texpression () body.body; + true with Utils.Found -> false in List.for_all body_only_calls_itself funs diff --git a/src/Substitute.ml b/src/Substitute.ml index 10d6d419..81f6985b 100644 --- a/src/Substitute.ml +++ b/src/Substitute.ml @@ -322,15 +322,15 @@ and switch_targets_substitute (tsubst : T.TypeVarId.id -> T.ety) (** Apply a type substitution to a function body. Return the local variables and the body. *) -let fun_decl_substitute_in_body (tsubst : T.TypeVarId.id -> T.ety) - (def : A.fun_decl) : A.var list * A.statement = +let fun_body_substitute_in_body (tsubst : T.TypeVarId.id -> T.ety) + (body : A.fun_body) : A.var list * A.statement = let rsubst r = r in let locals = List.map (fun v -> { v with A.var_ty = ty_substitute rsubst tsubst v.A.var_ty }) - def.A.locals + body.A.locals in - let body = statement_substitute tsubst def.body in + let body = statement_substitute tsubst body.body in (locals, body) (** Substitute a function signature *) diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index fd41f094..6c35e541 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -1443,8 +1443,8 @@ and translate_meta (config : config) (meta : S.meta) (e : S.expression) let ty = next_e.ty in { e; ty } -let translate_fun_decl (config : config) (ctx : bs_ctx) (body : S.expression) : - fun_decl = +let translate_fun_decl (config : config) (ctx : bs_ctx) + (body : S.expression option) : fun_decl = let def = ctx.fun_decl in let bid = ctx.bid in log#ldebug @@ -1455,35 +1455,43 @@ let translate_fun_decl (config : config) (ctx : bs_ctx) (body : S.expression) : ^ Print.option_to_string T.RegionGroupId.to_string bid ^ ")")); - (* Translate the function *) + (* Translate the declaration *) let def_id = def.A.def_id in let basename = def.name in let signature = bs_ctx_lookup_local_function_sig def_id bid ctx in - let body = translate_expression config body ctx in - (* Compute the list of (properly ordered) input variables *) - let backward_inputs : var list = - match bid with - | None -> [] - | Some back_id -> - let parents_ids = - list_ordered_parent_region_groups def.signature back_id + (* Translate the body, if there is *) + let body = + match body with + | None -> None + | Some body -> + let body = translate_expression config body ctx in + (* Compute the list of (properly ordered) input variables *) + let backward_inputs : var list = + match bid with + | None -> [] + | Some back_id -> + let parents_ids = + list_ordered_parent_region_groups def.signature back_id + in + let backward_ids = List.append parents_ids [ back_id ] in + List.concat + (List.map + (fun id -> T.RegionGroupId.Map.find id ctx.backward_inputs) + backward_ids) in - let backward_ids = List.append parents_ids [ back_id ] in - List.concat - (List.map - (fun id -> T.RegionGroupId.Map.find id ctx.backward_inputs) - backward_ids) - in - let inputs = List.append ctx.forward_inputs backward_inputs in - let inputs_lvs = List.map (fun v -> mk_typed_lvalue_from_var v None) inputs in - (* Sanity check *) - assert ( - List.for_all - (fun (var, ty) -> (var : var).ty = ty) - (List.combine inputs signature.inputs)); - let def = - { def_id; back_id = bid; basename; signature; inputs; inputs_lvs; body } + let inputs = List.append ctx.forward_inputs backward_inputs in + let inputs_lvs = + List.map (fun v -> mk_typed_lvalue_from_var v None) inputs + in + (* Sanity check *) + assert ( + List.for_all + (fun (var, ty) -> (var : var).ty = ty) + (List.combine inputs signature.inputs)); + Some { inputs; inputs_lvs; body } in + (* Assemble the declaration *) + let def = { def_id; back_id = bid; basename; signature; body } in (* Debugging *) log#ldebug (lazy diff --git a/src/Translate.ml b/src/Translate.ml index ce669525..b522aeb7 100644 --- a/src/Translate.ml +++ b/src/Translate.ml @@ -152,11 +152,11 @@ let translate_function_to_pure (config : C.partial_config) in (* We need to initialize the input/output variables *) - let forward_input_vars = LlbcAstUtils.fun_decl_get_input_vars fdef in + let forward_input_vars = LlbcAstUtils.fun_body_get_input_vars body in let forward_input_varnames = List.map (fun (v : A.var) -> v.name) forward_input_vars in - let num_forward_inputs = fdef.arg_count in + let num_forward_inputs = body.arg_count in let add_forward_inputs input_svs ctx = let input_svs = List.combine forward_input_varnames input_svs in let ctx, forward_inputs = @@ -276,7 +276,7 @@ let translate_module_to_pure (config : C.partial_config) ( A.Local fdef.def_id, List.map (fun (v : A.var) -> v.name) - (LlbcAstUtils.fun_decl_get_input_vars fdef), + (LlbcAstUtils.fun_body_get_input_vars fdef), fdef.signature )) m.functions in @@ -285,7 +285,7 @@ let translate_module_to_pure (config : C.partial_config) SymbolicToPure.translate_fun_signatures type_context.type_infos sigs in - (* Translate all the functions *) + (* Translate all the *transparent* functions *) let pure_translations = List.map (translate_function_to_pure config mp_config trans_ctx fun_sigs -- cgit v1.2.3