summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ExtractToFStar.ml50
-rw-r--r--src/Interpreter.ml42
-rw-r--r--src/InterpreterStatements.ml23
-rw-r--r--src/LlbcAst.ml19
-rw-r--r--src/LlbcAstUtils.ml11
-rw-r--r--src/LlbcOfJson.ml18
-rw-r--r--src/PrePasses.ml6
-rw-r--r--src/Print.ml73
-rw-r--r--src/PrintPure.ml16
-rw-r--r--src/Pure.ml14
-rw-r--r--src/PureMicroPasses.ml474
-rw-r--r--src/PureUtils.ml7
-rw-r--r--src/Substitute.ml8
-rw-r--r--src/SymbolicToPure.ml60
-rw-r--r--src/Translate.ml8
15 files changed, 475 insertions, 354 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml
index 6bbc21d7..2e0568c8 100644
--- a/src/ExtractToFStar.ml
+++ b/src/ExtractToFStar.ml
@@ -1025,22 +1025,25 @@ let extract_fun_parameters (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_space fmt ());
(* The input parameters - note that doing this adds bindings to the context *)
let ctx_body =
- List.fold_left
- (fun ctx (lv : typed_lvalue) ->
- (* Open a box for the input parameter *)
- F.pp_open_hovbox fmt 0;
- F.pp_print_string fmt "(";
- let ctx = extract_typed_lvalue ctx fmt false lv in
- F.pp_print_space fmt ();
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- extract_ty ctx fmt false lv.ty;
- F.pp_print_string fmt ")";
- (* Close the box for the input parameters *)
- F.pp_close_box fmt ();
- F.pp_print_space fmt ();
- ctx)
- ctx def.inputs_lvs
+ match def.body with
+ | None -> ctx
+ | Some body ->
+ List.fold_left
+ (fun ctx (lv : typed_lvalue) ->
+ (* Open a box for the input parameter *)
+ F.pp_open_hovbox fmt 0;
+ F.pp_print_string fmt "(";
+ let ctx = extract_typed_lvalue ctx fmt false lv in
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ extract_ty ctx fmt false lv.ty;
+ F.pp_print_string fmt ")";
+ (* Close the box for the input parameters *)
+ F.pp_close_box fmt ();
+ F.pp_print_space fmt ();
+ ctx)
+ ctx body.inputs_lvs
in
(ctx, ctx_body)
@@ -1169,7 +1172,8 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter)
(Collections.List.to_cons_nil def.signature.outputs);
(* Close the box for the return type *)
F.pp_close_box fmt ();
- (* Print the decrease clause *)
+ (* Print the decrease clause - rk.: a function with a decreases clause
+ * is necessarily a transparent function *)
if has_decreases_clause then (
F.pp_print_space fmt ();
(* Open a box for the decrease clause *)
@@ -1193,9 +1197,13 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter)
* function (the additional input values "given back" to the
* backward functions have no influence on termination: we thus
* share the decrease clauses between the forward and the backward
- * functions) *)
+ * functions).
+ * Rk.: if a function has a decreases clause, it is necessarily
+ * a transparent function *)
let inputs_lvs =
- Collections.List.prefix (List.length fwd_def.inputs_lvs) def.inputs_lvs
+ Collections.List.prefix
+ (List.length (Option.get fwd_def.body).inputs_lvs)
+ (Option.get def.body).inputs_lvs
in
let _ =
List.fold_left
@@ -1224,7 +1232,9 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter)
(* Open a box for the body *)
F.pp_open_hvbox fmt 0;
(* Extract the body *)
- let _ = extract_texpression ctx_body fmt false false def.body in
+ let _ =
+ extract_texpression ctx_body fmt false false (Option.get def.body).body
+ in
(* Close the box for the body *)
F.pp_close_box fmt ());
(* Close the box for the definition *)
diff --git a/src/Interpreter.ml b/src/Interpreter.ml
index 82eb4b35..f6ae268d 100644
--- a/src/Interpreter.ml
+++ b/src/Interpreter.ml
@@ -91,9 +91,10 @@ let initialize_symbolic_context_for_fun (type_context : C.type_context)
inst_sg.A.regions_hierarchy compute_abs_avalues ctx
in
(* Split the variables between return var, inputs and remaining locals *)
- let ret_var = List.hd fdef.locals in
+ let body = Option.get fdef.body in
+ let ret_var = List.hd body.locals in
let input_vars, local_vars =
- Collections.List.split_at (List.tl fdef.locals) fdef.arg_count
+ Collections.List.split_at (List.tl body.locals) body.arg_count
in
(* Push the return variable (initialized with ⊥) *)
let ctx = C.ctx_push_uninitialized_var ctx ret_var in
@@ -242,7 +243,9 @@ let evaluate_function_symbolic (config : C.partial_config) (synthesize : bool)
in
(* Evaluate the function *)
- let symbolic = eval_function_body config fdef.A.body cf_finish ctx in
+ let symbolic =
+ eval_function_body config (Option.get fdef.A.body).body cf_finish ctx
+ in
(* Return *)
(input_svs, symbolic)
@@ -255,6 +258,7 @@ module Test = struct
(fid : A.FunDeclId.id) : unit =
(* Retrieve the function declaration *)
let fdef = A.FunDeclId.nth m.functions fid in
+ let body = Option.get fdef.body in
(* Debug *)
log#ldebug
@@ -263,14 +267,14 @@ module Test = struct
(* Sanity check - *)
assert (List.length fdef.A.signature.region_params = 0);
assert (List.length fdef.A.signature.type_params = 0);
- assert (fdef.A.arg_count = 0);
+ assert (body.A.arg_count = 0);
(* Create the evaluation context *)
let type_context, fun_context = compute_type_fun_contexts m in
let ctx = initialize_eval_context type_context fun_context [] in
(* Insert the (uninitialized) local variables *)
- let ctx = C.ctx_push_uninitialized_vars ctx fdef.A.locals in
+ let ctx = C.ctx_push_uninitialized_vars ctx body.A.locals in
(* Create the continuation to check the function's result *)
let config = C.config_of_partial C.ConcreteMode config in
@@ -286,21 +290,24 @@ module Test = struct
in
(* Evaluate the function *)
- let _ = eval_function_body config fdef.A.body cf_check ctx in
+ let _ = eval_function_body config body.body cf_check ctx in
()
- (** Small helper: return true if the function is a unit function (no parameters,
- no arguments) - TODO: move *)
- let fun_decl_is_unit (def : A.fun_decl) : bool =
- def.A.arg_count = 0
- && List.length def.A.signature.region_params = 0
- && List.length def.A.signature.type_params = 0
- && List.length def.A.signature.inputs = 0
+ (** Small helper: return true if the function is a *transparent* unit function
+ (no parameters, no arguments) - TODO: move *)
+ let fun_decl_is_transparent_unit (def : A.fun_decl) : bool =
+ match def.body with
+ | None -> false
+ | Some body ->
+ body.arg_count = 0
+ && List.length def.A.signature.region_params = 0
+ && List.length def.A.signature.type_params = 0
+ && List.length def.A.signature.inputs = 0
(** Test all the unit functions in a list of function definitions *)
let test_unit_functions (config : C.partial_config) (m : M.llbc_module) : unit
=
- let unit_funs = List.filter fun_decl_is_unit m.functions in
+ let unit_funs = List.filter fun_decl_is_transparent_unit m.functions in
let test_unit_fun (def : A.fun_decl) : unit =
test_unit_function config m def.A.def_id
in
@@ -329,6 +336,10 @@ module Test = struct
()
+ (** Small helper *)
+ let fun_decl_is_transparent (def : A.fun_decl) : bool =
+ Option.is_some def.body
+
(** Execute the symbolic interpreter on a list of functions.
TODO: for now we ignore the functions which contain loops, because
@@ -336,9 +347,12 @@ module Test = struct
*)
let test_functions_symbolic (config : C.partial_config) (synthesize : bool)
(m : M.llbc_module) : unit =
+ (* Filter the functions which contain loops *)
let no_loop_funs =
List.filter (fun f -> not (LlbcAstUtils.fun_decl_has_loops f)) m.functions
in
+ (* Filter the opaque functions *)
+ let no_loop_funs = List.filter fun_decl_is_transparent no_loop_funs in
let type_context, fun_context = compute_type_fun_contexts m in
let test_fun (def : A.fun_decl) : unit =
(* Execute the function - note that as the symbolic interpreter explores
diff --git a/src/InterpreterStatements.ml b/src/InterpreterStatements.ml
index 547f5ee3..3ea0a6fa 100644
--- a/src/InterpreterStatements.ml
+++ b/src/InterpreterStatements.ml
@@ -233,8 +233,9 @@ let set_discriminant (config : C.config) (p : E.place)
let bottom_v =
match type_id with
| T.AdtId def_id ->
- compute_expanded_bottom_adt_value ctx.type_context.type_decls
- def_id (Some variant_id) regions types
+ compute_expanded_bottom_adt_value
+ ctx.type_context.type_decls def_id (Some variant_id)
+ regions types
| T.Assumed T.Option ->
assert (regions = []);
compute_expanded_bottom_option_value variant_id
@@ -1005,15 +1006,21 @@ and eval_local_function_call_concrete (config : C.config) (fid : A.FunDeclId.id)
(* Retrieve the (correctly instantiated) body *)
let def = C.ctx_lookup_fun_decl ctx fid in
+ (* We can evaluate the function call only if it is not opaque *)
+ let body =
+ match def.body with
+ | None -> raise (Failure "Can't evaluate a call to an opaque function")
+ | Some body -> body
+ in
let tsubst =
Subst.make_type_subst
(List.map (fun v -> v.T.index) def.A.signature.type_params)
type_params
in
- let locals, body = Subst.fun_decl_substitute_in_body tsubst def in
+ let locals, body_st = Subst.fun_body_substitute_in_body tsubst body in
(* Evaluate the input operands *)
- assert (List.length args = def.A.arg_count);
+ assert (List.length args = body.A.arg_count);
let cc = eval_operands config args in
(* Push a frame delimiter - we use [comp_transmit] to transmit the result
@@ -1028,7 +1035,9 @@ and eval_local_function_call_concrete (config : C.config) (fid : A.FunDeclId.id)
| ret_ty :: locals -> (ret_ty, locals)
| _ -> raise (Failure "Unreachable")
in
- let input_locals, locals = Collections.List.split_at locals def.A.arg_count in
+ let input_locals, locals =
+ Collections.List.split_at locals body.A.arg_count
+ in
let cc = comp_transmit cc (push_var ret_var (mk_bottom ret_var.var_ty)) in
@@ -1045,7 +1054,7 @@ and eval_local_function_call_concrete (config : C.config) (fid : A.FunDeclId.id)
let cc = comp cc (push_uninitialized_vars locals) in
(* Execute the function body *)
- let cc = comp cc (eval_function_body config body) in
+ let cc = comp cc (eval_function_body config body_st) in
(* Pop the stack frame and move the return value to its destination *)
let cf_finish cf res =
@@ -1074,7 +1083,7 @@ and eval_local_function_call_symbolic (config : C.config) (fid : A.FunDeclId.id)
* while doing so *)
let inst_sg = instantiate_fun_sig type_params sg in
(* Sanity check *)
- assert (List.length args = def.A.arg_count);
+ assert (List.length args = List.length def.A.signature.inputs);
(* Evaluate the function call *)
eval_function_call_symbolic_from_inst_sig config (A.Local fid) inst_sg
region_params type_params args dest cf ctx
diff --git a/src/LlbcAst.ml b/src/LlbcAst.ml
index 149fb23d..f5ffc956 100644
--- a/src/LlbcAst.ml
+++ b/src/LlbcAst.ml
@@ -170,24 +170,6 @@ and switch_targets =
concrete = true;
}]
-type fun_decl = {
- def_id : FunDeclId.id;
- name : fun_name;
- signature : fun_sig;
- arg_count : int;
- locals : var list;
- body : statement;
-}
-[@@deriving show]
-(** TODO: function definitions (and maybe type definitions in the future)
- * contain information like `divergent`. I wonder if this information should
- * be stored directly inside the definitions or inside separate maps/sets.
- * Of course, if everything is stored in separate maps/sets, nothing
- * prevents us from computing this info in Charon (and thus exporting directly
- * it with the type/function defs), in which case we just have to implement special
- * treatment when deserializing, to move the info to a separate map. *)
-
-(*
type fun_body = { arg_count : int; locals : var list; body : statement }
[@@deriving show]
@@ -198,4 +180,3 @@ type fun_decl = {
body : fun_body option;
}
[@@deriving show]
-*)
diff --git a/src/LlbcAstUtils.ml b/src/LlbcAstUtils.ml
index 93ca4448..41d17bf4 100644
--- a/src/LlbcAstUtils.ml
+++ b/src/LlbcAstUtils.ml
@@ -17,7 +17,10 @@ let statement_has_loops (st : statement) : bool =
with Found -> true
(** Check if a [fun_decl] contains loops *)
-let fun_decl_has_loops (fd : fun_decl) : bool = statement_has_loops fd.body
+let fun_decl_has_loops (fd : fun_decl) : bool =
+ match fd.body with
+ | Some body -> statement_has_loops body.body
+ | None -> false
let lookup_fun_sig (fun_id : fun_id) (fun_decls : fun_decl FunDeclId.Map.t) :
fun_sig =
@@ -64,6 +67,6 @@ let list_ordered_parent_region_groups (sg : fun_sig) (gid : T.RegionGroupId.id)
let parents = List.map (fun (rg : T.region_var_group) -> rg.id) parents in
parents
-let fun_decl_get_input_vars (fdef : fun_decl) : var list =
- let locals = List.tl fdef.locals in
- Collections.List.prefix fdef.arg_count locals
+let fun_body_get_input_vars (fbody : fun_body) : var list =
+ let locals = List.tl fbody.locals in
+ Collections.List.prefix fbody.arg_count locals
diff --git a/src/LlbcOfJson.ml b/src/LlbcOfJson.ml
index e293b030..6e0adfb6 100644
--- a/src/LlbcOfJson.ml
+++ b/src/LlbcOfJson.ml
@@ -607,6 +607,16 @@ and switch_targets_of_json (js : json) : (A.switch_targets, string) result =
Ok (A.SwitchInt (int_ty, tgts, otherwise))
| _ -> Error "")
+let fun_body_of_json (js : json) : (A.fun_body, string) result =
+ combine_error_msgs js "fun_body_of_json"
+ (match js with
+ | `Assoc [ ("arg_count", arg_count); ("locals", locals); ("body", body) ] ->
+ let* arg_count = int_of_json arg_count in
+ let* locals = list_of_json var_of_json locals in
+ let* body = statement_of_json body in
+ Ok { A.arg_count; locals; body }
+ | _ -> Error "")
+
let fun_decl_of_json (js : json) : (A.fun_decl, string) result =
combine_error_msgs js "fun_decl_of_json"
(match js with
@@ -615,17 +625,13 @@ let fun_decl_of_json (js : json) : (A.fun_decl, string) result =
("def_id", def_id);
("name", name);
("signature", signature);
- ("arg_count", arg_count);
- ("locals", locals);
("body", body);
] ->
let* def_id = A.FunDeclId.id_of_json def_id in
let* name = fun_name_of_json name in
let* signature = fun_sig_of_json signature in
- let* arg_count = int_of_json arg_count in
- let* locals = list_of_json var_of_json locals in
- let* body = statement_of_json body in
- Ok { A.def_id; name; signature; arg_count; locals; body }
+ let* body = option_of_json fun_body_of_json body in
+ Ok { A.def_id; name; signature; body }
| _ -> Error "")
let g_declaration_group_of_json (id_of_json : json -> ('id, string) result)
diff --git a/src/PrePasses.ml b/src/PrePasses.ml
index 9b1a6990..dda3c867 100644
--- a/src/PrePasses.ml
+++ b/src/PrePasses.ml
@@ -42,7 +42,11 @@ let filter_drop_assigns (f : A.fun_decl) : A.fun_decl =
end
in
(* Map *)
- let body = obj#visit_statement () f.body in
+ let body =
+ match f.body with
+ | Some body -> Some { body with body = obj#visit_statement () body.body }
+ | None -> None
+ in
{ f with body }
let apply_passes (m : M.llbc_module) : M.llbc_module =
diff --git a/src/Print.ml b/src/Print.ml
index 18227c61..d2101dae 100644
--- a/src/Print.ml
+++ b/src/Print.ml
@@ -784,7 +784,7 @@ module LlbcAst = struct
PC.type_ctx_to_adt_variant_to_string_fun type_decls
in
let var_id_to_string vid =
- let var = V.VarId.nth fdef.locals vid in
+ let var = V.VarId.nth (Option.get fdef.body).locals vid in
var_to_string var
in
let adt_field_names = PC.type_ctx_to_adt_field_names_fun type_decls in
@@ -1062,41 +1062,56 @@ module LlbcAst = struct
"<" ^ String.concat "," (List.append regions types) ^ ">"
in
- (* Arguments *)
- let inputs = List.tl def.locals in
- let inputs, _aux_locals = Collections.List.split_at inputs def.arg_count in
- let args = List.combine inputs sg.inputs in
- let args =
- List.map
- (fun (var, rty) -> var_to_string var ^ " : " ^ sty_to_string rty)
- args
- in
- let args = String.concat ", " args in
-
(* Return type *)
let ret_ty = sg.output in
let ret_ty =
if TU.ty_is_unit ret_ty then "" else " -> " ^ sty_to_string ret_ty
in
- (* All the locals (with erased regions) *)
- let locals =
- List.map
- (fun var ->
- indent ^ indent_incr ^ var_to_string var ^ " : "
- ^ ety_to_string var.var_ty ^ ";")
- def.locals
- in
- let locals = String.concat "\n" locals in
+ (* We print the declaration differently if it is opaque (no body) or transparent
+ * (we have access to a body) *)
+ match def.body with
+ | None ->
+ (* Arguments - we need to ignore the first input type which is actually
+ * the return type... TODO: fix that *)
+ let input_tys = List.tl sg.inputs in
+ let args = List.map sty_to_string input_tys in
+ let args = String.concat ", " args in
+
+ (* Put everything together *)
+ indent ^ "opaque fn " ^ name ^ params ^ "(" ^ args ^ ")" ^ ret_ty
+ | Some body ->
+ (* Arguments *)
+ let inputs = List.tl body.locals in
+ let inputs, _aux_locals =
+ Collections.List.split_at inputs body.arg_count
+ in
+ let args = List.combine inputs sg.inputs in
+ let args =
+ List.map
+ (fun (var, rty) -> var_to_string var ^ " : " ^ sty_to_string rty)
+ args
+ in
+ let args = String.concat ", " args in
+
+ (* All the locals (with erased regions) *)
+ let locals =
+ List.map
+ (fun var ->
+ indent ^ indent_incr ^ var_to_string var ^ " : "
+ ^ ety_to_string var.var_ty ^ ";")
+ body.locals
+ in
+ let locals = String.concat "\n" locals in
- (* Body *)
- let body =
- statement_to_string fmt (indent ^ indent_incr) indent_incr def.body
- in
+ (* Body *)
+ let body =
+ statement_to_string fmt (indent ^ indent_incr) indent_incr body.body
+ in
- (* Put everything together *)
- indent ^ "fn " ^ name ^ params ^ "(" ^ args ^ ")" ^ ret_ty ^ " {\n" ^ locals
- ^ "\n\n" ^ body ^ "\n" ^ indent ^ "}"
+ (* Put everything together *)
+ indent ^ "fn " ^ name ^ params ^ "(" ^ args ^ ")" ^ ret_ty ^ " {\n"
+ ^ locals ^ "\n\n" ^ body ^ "\n" ^ indent ^ "}"
end
module PA = LlbcAst (* local module *)
@@ -1139,7 +1154,7 @@ module Module = struct
fun_name_to_string def.name
in
let var_id_to_string vid =
- let var = V.VarId.nth def.locals vid in
+ let var = V.VarId.nth (Option.get def.body).locals vid in
PA.var_to_string var
in
let adt_variant_to_string =
diff --git a/src/PrintPure.ml b/src/PrintPure.ml
index f47a1f06..52215019 100644
--- a/src/PrintPure.ml
+++ b/src/PrintPure.ml
@@ -479,9 +479,13 @@ let fun_decl_to_string (fmt : ast_formatter) (def : fun_decl) : string =
let type_fmt = ast_to_type_formatter fmt in
let name = fun_name_to_string def.basename ^ fun_suffix def.back_id in
let signature = fun_sig_to_string fmt def.signature in
- let inputs = List.map (var_to_string type_fmt) def.inputs in
- let inputs =
- if inputs = [] then "" else " fun " ^ String.concat " " inputs ^ " ->\n"
- in
- let body = texpression_to_string fmt " " " " def.body in
- "let " ^ name ^ " :\n " ^ signature ^ " =\n" ^ inputs ^ " " ^ body
+ match def.body with
+ | None -> "val " ^ name ^ " :\n " ^ signature
+ | Some body ->
+ let inputs = List.map (var_to_string type_fmt) body.inputs in
+ let inputs =
+ if inputs = [] then ""
+ else " fun " ^ String.concat " " inputs ^ " ->\n"
+ in
+ let body = texpression_to_string fmt " " " " body.body in
+ "let " ^ name ^ " :\n " ^ signature ^ " =\n" ^ inputs ^ " " ^ body
diff --git a/src/Pure.ml b/src/Pure.ml
index a4b193f7..79801440 100644
--- a/src/Pure.ml
+++ b/src/Pure.ml
@@ -596,6 +596,14 @@ type fun_sig = {
type inst_fun_sig = { inputs : ty list; outputs : ty list }
+type fun_body = {
+ inputs : var list;
+ inputs_lvs : typed_lvalue list;
+ (** The inputs seen as lvalues. Allows to make transformations, for example
+ to replace unused variables by `_` *)
+ body : texpression;
+}
+
type fun_decl = {
def_id : FunDeclId.id;
back_id : T.RegionGroupId.id option;
@@ -606,9 +614,5 @@ type fun_decl = {
(to identify the forward/backward functions) later.
*)
signature : fun_sig;
- inputs : var list;
- inputs_lvs : typed_lvalue list;
- (** The inputs seen as lvalues. Allows to make transformations, for example
- to replace unused variables by `_` *)
- body : texpression;
+ body : fun_body option;
}
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index b110f829..5227b2ad 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -311,14 +311,22 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
(ctx, e.e)
in
- let input_names =
- List.filter_map
- (fun (v : var) ->
- match v.basename with None -> None | Some name -> Some (v.id, name))
- def.inputs
+ let body =
+ match def.body with
+ | None -> None
+ | Some body ->
+ let input_names =
+ List.filter_map
+ (fun (v : var) ->
+ match v.basename with
+ | None -> None
+ | Some name -> Some (v.id, name))
+ body.inputs
+ in
+ let ctx = VarId.Map.of_list input_names in
+ let _, body_exp = update_texpression body.body ctx in
+ Some { body with body = body_exp }
in
- let ctx = VarId.Map.of_list input_names in
- let _, body = update_texpression def.body ctx in
{ def with body }
(** Remove the meta-information *)
@@ -330,8 +338,11 @@ let remove_meta (def : fun_decl) : fun_decl =
method! visit_Meta env _ e = super#visit_expression env e.e
end
in
- let body = obj#visit_texpression () def.body in
- { def with body }
+ match def.body with
+ | None -> def
+ | Some body ->
+ let body = { body with body = obj#visit_texpression () body.body } in
+ { def with body = Some body }
(** Inline the useless variable (re-)assignments:
@@ -452,8 +463,13 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool)
(** Visit the places used as rvalues, to substitute them if possible *)
end
in
- let body = obj#visit_texpression VarId.Map.empty def.body in
- { def with body }
+ match def.body with
+ | None -> def
+ | Some body ->
+ let body =
+ { body with body = obj#visit_texpression VarId.Map.empty body.body }
+ in
+ { def with body = Some body }
(** Given a forward or backward function call, is there, for every execution
path, a child backward function called later with exactly the same input
@@ -683,22 +699,29 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
dont_filter ()
end
in
- (* Visit the body *)
- let body, used_vars = expr_visitor#visit_texpression () def.body in
- (* Visit the parameters - TODO: update: we can filter only if the definition
- * is not recursive (otherwise it might mess up with the decrease clauses:
- * the decrease clauses uses all the inputs given to the function, if some
- * inputs are replaced by '_' we can't give it to the function used in the
- * decreases clause).
- * For now we deactivate the filtering. *)
- let used_vars = used_vars () in
- let inputs_lvs =
- if false then
- List.map (fun lv -> fst (filter_typed_lvalue used_vars lv)) def.inputs_lvs
- else def.inputs_lvs
- in
- (* Return *)
- { def with body; inputs_lvs }
+ (* We filter only inside of transparent (i.e., non-opaque) definitions *)
+ match def.body with
+ | None -> def
+ | Some body ->
+ (* Visit the body *)
+ let body_exp, used_vars = expr_visitor#visit_texpression () body.body in
+ (* Visit the parameters - TODO: update: we can filter only if the definition
+ * is not recursive (otherwise it might mess up with the decrease clauses:
+ * the decrease clauses uses all the inputs given to the function, if some
+ * inputs are replaced by '_' we can't give it to the function used in the
+ * decreases clause).
+ * For now we deactivate the filtering. *)
+ let used_vars = used_vars () in
+ let inputs_lvs =
+ if false then
+ List.map
+ (fun lv -> fst (filter_typed_lvalue used_vars lv))
+ body.inputs_lvs
+ else body.inputs_lvs
+ in
+ (* Return *)
+ let body = { body with body = body_exp; inputs_lvs } in
+ { def with body = Some body }
(** Return `None` if the function is a backward function with no outputs (so
that we eliminate the definition which is useless).
@@ -763,21 +786,31 @@ let to_monadic (config : config) (def : fun_decl) : fun_decl =
super#visit_call env call
end
in
- let body = obj#visit_texpression () def.body in
- let def = { def with body } in
+ let def =
+ match def.body with
+ | None -> def
+ | Some body ->
+ let body = { body with body = obj#visit_texpression () body.body } in
+ { def with body = Some body }
+ in
(* Update the signature: first the input types *)
let def =
- if def.inputs = [] && config.add_unit_args then (
- assert (def.signature.inputs = []);
+ if def.signature.inputs = [] && config.add_unit_args then
let signature = { def.signature with inputs = [ unit_ty ] } in
- let var_cnt = get_expression_min_var_counter def.body.e in
- let id, _ = VarId.fresh var_cnt in
- let var = { id; basename = None; ty = unit_ty } in
- let inputs = [ var ] in
- let input_lv = mk_typed_lvalue_from_var var None in
- let inputs_lvs = [ input_lv ] in
- { def with signature; inputs; inputs_lvs })
+ let body =
+ match def.body with
+ | None -> None
+ | Some body ->
+ let var_cnt = get_expression_min_var_counter body.body.e in
+ let id, _ = VarId.fresh var_cnt in
+ let var = { id; basename = None; ty = unit_ty } in
+ let inputs = [ var ] in
+ let input_lv = mk_typed_lvalue_from_var var None in
+ let inputs_lvs = [ input_lv ] in
+ Some { body with inputs; inputs_lvs }
+ in
+ { def with signature; body }
else def
in
(* Then the output type *)
@@ -830,11 +863,15 @@ let unit_vars_to_unit (def : fun_decl) : fun_decl =
end
in
(* Update the body *)
- let body = obj#visit_texpression () def.body in
- (* Update the input parameters *)
- let inputs_lvs = List.map (obj#visit_typed_lvalue ()) def.inputs_lvs in
- (* Return *)
- { def with body; inputs_lvs }
+ match def.body with
+ | None -> def
+ | Some body ->
+ let body_exp = obj#visit_texpression () body.body in
+ (* Update the input parameters *)
+ let inputs_lvs = List.map (obj#visit_typed_lvalue ()) body.inputs_lvs in
+ (* Return *)
+ let body = Some { body with body = body_exp; inputs_lvs } in
+ { def with body }
(** Eliminate the box functions like `Box::new`, `Box::deref`, etc. Most of them
are translated to identity, and `Box::free` is translated to `()`.
@@ -887,178 +924,201 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
end
in
(* Update the body *)
- let body = obj#visit_texpression () def.body in
- { def with body }
+ match def.body with
+ | None -> def
+ | Some body ->
+ let body = Some { body with body = obj#visit_texpression () body.body } in
+ { def with body }
(** Decompose the monadic let-bindings.
See the explanations in [config].
*)
-let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl
- =
- (* Set up the var id generator *)
- let cnt = get_expression_min_var_counter def.body.e in
- let _, fresh_id = VarId.mk_stateful_generator cnt in
- (* It is a very simple map *)
- let obj =
- object (self)
- inherit [_] map_expression as super
-
- method! visit_Let env monadic lv re next_e =
- if not monadic then super#visit_Let env monadic lv re next_e
- else
- (* If monadic, we need to check if the left-value is a variable:
- * - if yes, don't decompose
- * - if not, make the decomposition in two steps
- *)
- match lv.value with
- | LvVar _ ->
- (* Variable: nothing to do *)
- super#visit_Let env monadic lv re next_e
- | _ ->
- (* Not a variable: decompose *)
- (* Introduce a temporary variable to receive the value of the
- * monadic binding *)
- let vid = fresh_id () in
- let tmp : var = { id = vid; basename = None; ty = lv.ty } in
- let ltmp = mk_typed_lvalue_from_var tmp None in
- let rtmp = mk_typed_rvalue_from_var tmp in
- let rtmp = mk_value_expression rtmp None in
- (* Visit the next expression *)
- let next_e = self#visit_texpression env next_e in
- (* Create the let-bindings *)
- (mk_let true ltmp re (mk_let false lv rtmp next_e)).e
- end
- in
- (* Update the body *)
- let body = obj#visit_texpression () def.body in
- (* Return *)
- { def with body }
+let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) :
+ fun_decl =
+ match def.body with
+ | None -> def
+ | Some body ->
+ (* Set up the var id generator *)
+ let cnt = get_expression_min_var_counter body.body.e in
+ let _, fresh_id = VarId.mk_stateful_generator cnt in
+ (* It is a very simple map *)
+ let obj =
+ object (self)
+ inherit [_] map_expression as super
+
+ method! visit_Let env monadic lv re next_e =
+ if not monadic then super#visit_Let env monadic lv re next_e
+ else
+ (* If monadic, we need to check if the left-value is a variable:
+ * - if yes, don't decompose
+ * - if not, make the decomposition in two steps
+ *)
+ match lv.value with
+ | LvVar _ ->
+ (* Variable: nothing to do *)
+ super#visit_Let env monadic lv re next_e
+ | _ ->
+ (* Not a variable: decompose *)
+ (* Introduce a temporary variable to receive the value of the
+ * monadic binding *)
+ let vid = fresh_id () in
+ let tmp : var = { id = vid; basename = None; ty = lv.ty } in
+ let ltmp = mk_typed_lvalue_from_var tmp None in
+ let rtmp = mk_typed_rvalue_from_var tmp in
+ let rtmp = mk_value_expression rtmp None in
+ (* Visit the next expression *)
+ let next_e = self#visit_texpression env next_e in
+ (* Create the let-bindings *)
+ (mk_let true ltmp re (mk_let false lv rtmp next_e)).e
+ end
+ in
+ (* Update the body *)
+ let body = Some { body with body = obj#visit_texpression () body.body } in
+ (* Return *)
+ { def with body }
(** Unfold the monadic let-bindings to explicit matches. *)
let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
(def : fun_decl) : fun_decl =
- (* We may need to introduce fresh variables for the state *)
- let var_cnt = get_expression_min_var_counter def.body.e in
- let _, fresh_var_id = VarId.mk_stateful_generator var_cnt in
- let fresh_state_var () =
- let id = fresh_var_id () in
- { id; basename = Some "st"; ty = mk_state_ty }
- in
- (* It is a very simple map *)
- let obj =
- object (self)
- inherit [_] map_expression as super
-
- method! visit_Let state_var monadic lv re e =
- if not monadic then super#visit_Let state_var monadic lv re e
- else
- (* We don't do the same thing if we use a state-error monad or simply
- * an error monad.
- * Note that some functions always live in the error monad (arithmetic
- * operations, for instance).
- *)
- let re_call =
- match re.e with
- | Call call -> call
- | _ -> raise (Failure "Unreachable: expected a function call")
- in
- (* TODO: this information should be computed in SymbolicToPure and
- * store in an enum ("monadic" should be an enum, not a bool).
- * Also: everything will be cleaner once we update the AST to make
- * it more idiomatic lambda calculus... *)
- let re_call_can_use_state =
- match re_call.func with
- | Regular (A.Local _, _) -> true
- | Regular (A.Assumed _, _) | Unop _ | Binop _ -> false
- in
- if config.use_state_monad && re_call_can_use_state then
- let re_call =
- let call = re_call in
- let state_value = mk_typed_rvalue_from_var state_var in
- let args = call.args @ [ mk_value_expression state_value None ] in
- Call { call with args }
- in
- let re = { re with e = re_call } in
- (* Create the match *)
- let fail_pat = mk_result_fail_lvalue lv.ty in
- let fail_value = mk_result_fail_rvalue e.ty in
- let fail_branch =
- { pat = fail_pat; branch = mk_value_expression fail_value None }
- in
- (* The `Success` branch introduces a fresh state variable *)
- let state_var = fresh_state_var () in
- let state_value = mk_typed_lvalue_from_var state_var None in
- let success_pat =
- mk_result_return_lvalue
- (mk_simpl_tuple_lvalue [ state_value; lv ])
- in
- let success_branch = { pat = success_pat; branch = e } in
- let switch_body = Match [ fail_branch; success_branch ] in
- let e = Switch (re, switch_body) in
- self#visit_expression state_var e
- else
- let fail_pat = mk_result_fail_lvalue lv.ty in
- let fail_value = mk_result_fail_rvalue e.ty in
- let fail_branch =
- { pat = fail_pat; branch = mk_value_expression fail_value None }
- in
- let success_pat = mk_result_return_lvalue lv in
- let success_branch = { pat = success_pat; branch = e } in
- let switch_body = Match [ fail_branch; success_branch ] in
- let e = Switch (re, switch_body) in
- self#visit_expression state_var e
-
- method! visit_Value state_var rv mp =
- if config.use_state_monad then
- match rv.ty with
- | Adt (Assumed Result, _) -> (
- match rv.value with
- | RvAdt av ->
- (* We only need to replace the content of `Return ...` *)
- (* TODO: type checking is completely broken at this point... *)
- let variant_id = Option.get av.variant_id in
- if variant_id = result_return_id then
- let res_v = Collections.List.to_cons_nil av.field_values in
- let state_value = mk_typed_rvalue_from_var state_var in
- let res = mk_simpl_tuple_rvalue [ state_value; res_v ] in
- let res = mk_result_return_rvalue res in
- (mk_value_expression res None).e
- else super#visit_Value state_var rv mp
- | _ -> raise (Failure "Unrechable"))
- | _ -> super#visit_Value state_var rv mp
- else super#visit_Value state_var rv mp
- (** We also need to update values, in case this value is `Return ...`.
-
+ match def.body with
+ | None -> def
+ | Some body ->
+ (* We may need to introduce fresh variables for the state *)
+ let var_cnt = get_expression_min_var_counter body.body.e in
+ let _, fresh_var_id = VarId.mk_stateful_generator var_cnt in
+ let fresh_state_var () =
+ let id = fresh_var_id () in
+ { id; basename = Some "st"; ty = mk_state_ty }
+ in
+ (* It is a very simple map *)
+ let obj =
+ object (self)
+ inherit [_] map_expression as super
+
+ method! visit_Let state_var monadic lv re e =
+ if not monadic then super#visit_Let state_var monadic lv re e
+ else
+ (* We don't do the same thing if we use a state-error monad or simply
+ * an error monad.
+ * Note that some functions always live in the error monad (arithmetic
+ * operations, for instance).
+ *)
+ let re_call =
+ match re.e with
+ | Call call -> call
+ | _ -> raise (Failure "Unreachable: expected a function call")
+ in
+ (* TODO: this information should be computed in SymbolicToPure and
+ * store in an enum ("monadic" should be an enum, not a bool).
+ * Also: everything will be cleaner once we update the AST to make
+ * it more idiomatic lambda calculus... *)
+ let re_call_can_use_state =
+ match re_call.func with
+ | Regular (A.Local _, _) -> true
+ | Regular (A.Assumed _, _) | Unop _ | Binop _ -> false
+ in
+ if config.use_state_monad && re_call_can_use_state then
+ let re_call =
+ let call = re_call in
+ let state_value = mk_typed_rvalue_from_var state_var in
+ let args =
+ call.args @ [ mk_value_expression state_value None ]
+ in
+ Call { call with args }
+ in
+ let re = { re with e = re_call } in
+ (* Create the match *)
+ let fail_pat = mk_result_fail_lvalue lv.ty in
+ let fail_value = mk_result_fail_rvalue e.ty in
+ let fail_branch =
+ {
+ pat = fail_pat;
+ branch = mk_value_expression fail_value None;
+ }
+ in
+ (* The `Success` branch introduces a fresh state variable *)
+ let state_var = fresh_state_var () in
+ let state_value = mk_typed_lvalue_from_var state_var None in
+ let success_pat =
+ mk_result_return_lvalue
+ (mk_simpl_tuple_lvalue [ state_value; lv ])
+ in
+ let success_branch = { pat = success_pat; branch = e } in
+ let switch_body = Match [ fail_branch; success_branch ] in
+ let e = Switch (re, switch_body) in
+ self#visit_expression state_var e
+ else
+ let fail_pat = mk_result_fail_lvalue lv.ty in
+ let fail_value = mk_result_fail_rvalue e.ty in
+ let fail_branch =
+ {
+ pat = fail_pat;
+ branch = mk_value_expression fail_value None;
+ }
+ in
+ let success_pat = mk_result_return_lvalue lv in
+ let success_branch = { pat = success_pat; branch = e } in
+ let switch_body = Match [ fail_branch; success_branch ] in
+ let e = Switch (re, switch_body) in
+ self#visit_expression state_var e
+
+ method! visit_Value state_var rv mp =
+ if config.use_state_monad then
+ match rv.ty with
+ | Adt (Assumed Result, _) -> (
+ match rv.value with
+ | RvAdt av ->
+ (* We only need to replace the content of `Return ...` *)
+ (* TODO: type checking is completely broken at this point... *)
+ let variant_id = Option.get av.variant_id in
+ if variant_id = result_return_id then
+ let res_v =
+ Collections.List.to_cons_nil av.field_values
+ in
+ let state_value = mk_typed_rvalue_from_var state_var in
+ let res =
+ mk_simpl_tuple_rvalue [ state_value; res_v ]
+ in
+ let res = mk_result_return_rvalue res in
+ (mk_value_expression res None).e
+ else super#visit_Value state_var rv mp
+ | _ -> raise (Failure "Unrechable"))
+ | _ -> super#visit_Value state_var rv mp
+ else super#visit_Value state_var rv mp
+ (** We also need to update values, in case this value is `Return ...`.
+
TODO: this is super ugly... We need to use the monadic functions
- `fail` and `return` instead.
- *)
- end
- in
- (* Update the body *)
- let input_state_var = fresh_state_var () in
- let body = obj#visit_texpression input_state_var def.body in
- let def = { def with body } in
- (* We need to update the type if we revealed the state monad *)
- let def =
- if config.use_state_monad then
- (* Update the signature *)
- let sg = def.signature in
- let sg_inputs = sg.inputs @ [ mk_state_ty ] in
- let sg_outputs = Collections.List.to_cons_nil sg.outputs in
- let _, sg_outputs = dest_arrow_ty sg_outputs in
- let sg_outputs = [ sg_outputs ] in
- let sg = { sg with inputs = sg_inputs; outputs = sg_outputs } in
- (* Update the inputs list *)
- let inputs = def.inputs @ [ input_state_var ] in
- let input_lv = mk_typed_lvalue_from_var input_state_var None in
- let inputs_lvs = def.inputs_lvs @ [ input_lv ] in
- (* Update the definition *)
- { def with signature = sg; inputs; inputs_lvs }
- else def
- in
- (* Return *)
- { def with body }
+ `fail` and `return` instead.
+ *)
+ end
+ in
+ (* Update the body *)
+ let input_state_var = fresh_state_var () in
+ let body =
+ { body with body = obj#visit_texpression input_state_var body.body }
+ in
+ (* We need to update the type if we revealed the state monad *)
+ let body, signature =
+ if config.use_state_monad then
+ (* Update the signature *)
+ let sg = def.signature in
+ let sg_inputs = sg.inputs @ [ mk_state_ty ] in
+ let sg_outputs = Collections.List.to_cons_nil sg.outputs in
+ let _, sg_outputs = dest_arrow_ty sg_outputs in
+ let sg_outputs = [ sg_outputs ] in
+ let sg = { sg with inputs = sg_inputs; outputs = sg_outputs } in
+ (* Update the inputs list *)
+ let inputs = body.inputs @ [ input_state_var ] in
+ let input_lv = mk_typed_lvalue_from_var input_state_var None in
+ let inputs_lvs = body.inputs_lvs @ [ input_lv ] in
+ (* Update the body *)
+ let body = { body with inputs; inputs_lvs } in
+ (body, sg)
+ else (body, def.signature)
+ in
+ (* Return *)
+ { def with body = Some body; signature }
(** Apply all the micro-passes to a function.
@@ -1149,8 +1209,8 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_decl) :
let def = decompose_monadic_let_bindings ctx def in
log#ldebug
(lazy
- ("decompose_monadic_let_bindings:\n\n" ^ fun_decl_to_string ctx def
- ^ "\n"));
+ ("decompose_monadic_let_bindings:\n\n"
+ ^ fun_decl_to_string ctx def ^ "\n"));
def)
else (
log#ldebug
diff --git a/src/PureUtils.ml b/src/PureUtils.ml
index cfc8a270..2a56b6e0 100644
--- a/src/PureUtils.ml
+++ b/src/PureUtils.ml
@@ -251,8 +251,11 @@ let functions_not_mutually_recursive (funs : fun_decl list) : bool =
in
try
- obj#visit_texpression () fdef.body;
- true
+ match fdef.body with
+ | None -> true
+ | Some body ->
+ obj#visit_texpression () body.body;
+ true
with Utils.Found -> false
in
List.for_all body_only_calls_itself funs
diff --git a/src/Substitute.ml b/src/Substitute.ml
index 10d6d419..81f6985b 100644
--- a/src/Substitute.ml
+++ b/src/Substitute.ml
@@ -322,15 +322,15 @@ and switch_targets_substitute (tsubst : T.TypeVarId.id -> T.ety)
(** Apply a type substitution to a function body. Return the local variables
and the body. *)
-let fun_decl_substitute_in_body (tsubst : T.TypeVarId.id -> T.ety)
- (def : A.fun_decl) : A.var list * A.statement =
+let fun_body_substitute_in_body (tsubst : T.TypeVarId.id -> T.ety)
+ (body : A.fun_body) : A.var list * A.statement =
let rsubst r = r in
let locals =
List.map
(fun v -> { v with A.var_ty = ty_substitute rsubst tsubst v.A.var_ty })
- def.A.locals
+ body.A.locals
in
- let body = statement_substitute tsubst def.body in
+ let body = statement_substitute tsubst body.body in
(locals, body)
(** Substitute a function signature *)
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index fd41f094..6c35e541 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -1443,8 +1443,8 @@ and translate_meta (config : config) (meta : S.meta) (e : S.expression)
let ty = next_e.ty in
{ e; ty }
-let translate_fun_decl (config : config) (ctx : bs_ctx) (body : S.expression) :
- fun_decl =
+let translate_fun_decl (config : config) (ctx : bs_ctx)
+ (body : S.expression option) : fun_decl =
let def = ctx.fun_decl in
let bid = ctx.bid in
log#ldebug
@@ -1455,35 +1455,43 @@ let translate_fun_decl (config : config) (ctx : bs_ctx) (body : S.expression) :
^ Print.option_to_string T.RegionGroupId.to_string bid
^ ")"));
- (* Translate the function *)
+ (* Translate the declaration *)
let def_id = def.A.def_id in
let basename = def.name in
let signature = bs_ctx_lookup_local_function_sig def_id bid ctx in
- let body = translate_expression config body ctx in
- (* Compute the list of (properly ordered) input variables *)
- let backward_inputs : var list =
- match bid with
- | None -> []
- | Some back_id ->
- let parents_ids =
- list_ordered_parent_region_groups def.signature back_id
+ (* Translate the body, if there is *)
+ let body =
+ match body with
+ | None -> None
+ | Some body ->
+ let body = translate_expression config body ctx in
+ (* Compute the list of (properly ordered) input variables *)
+ let backward_inputs : var list =
+ match bid with
+ | None -> []
+ | Some back_id ->
+ let parents_ids =
+ list_ordered_parent_region_groups def.signature back_id
+ in
+ let backward_ids = List.append parents_ids [ back_id ] in
+ List.concat
+ (List.map
+ (fun id -> T.RegionGroupId.Map.find id ctx.backward_inputs)
+ backward_ids)
in
- let backward_ids = List.append parents_ids [ back_id ] in
- List.concat
- (List.map
- (fun id -> T.RegionGroupId.Map.find id ctx.backward_inputs)
- backward_ids)
- in
- let inputs = List.append ctx.forward_inputs backward_inputs in
- let inputs_lvs = List.map (fun v -> mk_typed_lvalue_from_var v None) inputs in
- (* Sanity check *)
- assert (
- List.for_all
- (fun (var, ty) -> (var : var).ty = ty)
- (List.combine inputs signature.inputs));
- let def =
- { def_id; back_id = bid; basename; signature; inputs; inputs_lvs; body }
+ let inputs = List.append ctx.forward_inputs backward_inputs in
+ let inputs_lvs =
+ List.map (fun v -> mk_typed_lvalue_from_var v None) inputs
+ in
+ (* Sanity check *)
+ assert (
+ List.for_all
+ (fun (var, ty) -> (var : var).ty = ty)
+ (List.combine inputs signature.inputs));
+ Some { inputs; inputs_lvs; body }
in
+ (* Assemble the declaration *)
+ let def = { def_id; back_id = bid; basename; signature; body } in
(* Debugging *)
log#ldebug
(lazy
diff --git a/src/Translate.ml b/src/Translate.ml
index ce669525..b522aeb7 100644
--- a/src/Translate.ml
+++ b/src/Translate.ml
@@ -152,11 +152,11 @@ let translate_function_to_pure (config : C.partial_config)
in
(* We need to initialize the input/output variables *)
- let forward_input_vars = LlbcAstUtils.fun_decl_get_input_vars fdef in
+ let forward_input_vars = LlbcAstUtils.fun_body_get_input_vars body in
let forward_input_varnames =
List.map (fun (v : A.var) -> v.name) forward_input_vars
in
- let num_forward_inputs = fdef.arg_count in
+ let num_forward_inputs = body.arg_count in
let add_forward_inputs input_svs ctx =
let input_svs = List.combine forward_input_varnames input_svs in
let ctx, forward_inputs =
@@ -276,7 +276,7 @@ let translate_module_to_pure (config : C.partial_config)
( A.Local fdef.def_id,
List.map
(fun (v : A.var) -> v.name)
- (LlbcAstUtils.fun_decl_get_input_vars fdef),
+ (LlbcAstUtils.fun_body_get_input_vars fdef),
fdef.signature ))
m.functions
in
@@ -285,7 +285,7 @@ let translate_module_to_pure (config : C.partial_config)
SymbolicToPure.translate_fun_signatures type_context.type_infos sigs
in
- (* Translate all the functions *)
+ (* Translate all the *transparent* functions *)
let pure_translations =
List.map
(translate_function_to_pure config mp_config trans_ctx fun_sigs