diff options
author | Son HO | 2022-09-22 18:52:15 +0200 |
---|---|---|
committer | GitHub | 2022-09-22 18:52:15 +0200 |
commit | dd75894c85bbaa5dc6aa54d39980e160e5b7777f (patch) | |
tree | ece56b01bcadea24a3c373236f0254f47e32a98f /src | |
parent | d8f92140abd7e65b6f1c5dd7e511c0c0aa69e73f (diff) | |
parent | 0d5fb87166cc4eb4ddc783d871ad459479fc9fdc (diff) |
Merge pull request #1 from AeneasVerif/constants-v2
Implement support for globals
Diffstat (limited to 'src')
40 files changed, 741 insertions, 611 deletions
diff --git a/src/Assumed.ml b/src/Assumed.ml index b3128b9b..1e8bb669 100644 --- a/src/Assumed.ml +++ b/src/Assumed.ml @@ -38,13 +38,9 @@ module Sig = struct (** A few utilities *) let rvar_id_0 = T.RegionVarId.of_int 0 - let rvar_0 : T.RegionVarId.id T.region = T.Var rvar_id_0 - let rg_id_0 = T.RegionGroupId.of_int 0 - let tvar_id_0 = T.TypeVarId.of_int 0 - let tvar_0 : T.sty = T.TypeVar tvar_id_0 (** Region 'a of id 0 *) @@ -218,8 +214,7 @@ module Sig = struct let inputs = [ mk_ref_ty rvar_0 (mk_vec_ty tvar_0) is_mut (* &'a (mut) Vec<T> *); - mk_usize_ty; - (* usize *) + mk_usize_ty (* usize *); ] in let output = mk_ref_ty rvar_0 tvar_0 is_mut (* &'a (mut) T *) in diff --git a/src/Collections.ml b/src/Collections.ml index 614857e6..2cb298a7 100644 --- a/src/Collections.ml +++ b/src/Collections.ml @@ -88,9 +88,7 @@ module type OrderedType = sig include Map.OrderedType val to_string : t -> string - val pp_t : Format.formatter -> t -> unit - val show_t : t -> string end @@ -99,9 +97,7 @@ module OrderedString : OrderedType with type t = string = struct include String let to_string s = s - let pp_t fmt s = Format.pp_print_string fmt s - let show_t s = s end @@ -109,7 +105,6 @@ module type Map = sig include Map.S val add_list : (key * 'a) list -> 'a t -> 'a t - val of_list : (key * 'a) list -> 'a t val to_string : string option -> ('a -> string) -> 'a t -> string @@ -123,7 +118,6 @@ module type Map = sig *) val pp : (Format.formatter -> 'a -> unit) -> Format.formatter -> 'a t -> unit - val show : ('a -> string) -> 'a t -> string end @@ -132,7 +126,6 @@ module MakeMap (Ord : OrderedType) : Map with type key = Ord.t = struct include Map let add_list bl m = List.fold_left (fun s (key, e) -> add key e s) m bl - let of_list bl = add_list bl empty let to_string indent_opt a_to_string m = @@ -177,7 +170,6 @@ module type Set = sig include Set.S val add_list : elt list -> t -> t - val of_list : elt list -> t val to_string : string option -> t -> string @@ -191,7 +183,6 @@ module type Set = sig *) val pp : Format.formatter -> t -> unit - val show : t -> string end @@ -200,7 +191,6 @@ module MakeSet (Ord : OrderedType) : Set with type elt = Ord.t = struct include Set let add_list bl s = List.fold_left (fun s e -> add e s) s bl - let of_list bl = add_list bl empty let to_string indent_opt m = @@ -239,79 +229,43 @@ end *) module type InjMap = sig type key - type elem - type t val empty : t - val is_empty : t -> bool - val mem : key -> t -> bool - val add : key -> elem -> t -> t - val singleton : key -> elem -> t - val remove : key -> t -> t - val compare : (elem -> elem -> int) -> t -> t -> int - val equal : (elem -> elem -> bool) -> t -> t -> bool - val iter : (key -> elem -> unit) -> t -> unit - val fold : (key -> elem -> 'b -> 'b) -> t -> 'b -> 'b - val for_all : (key -> elem -> bool) -> t -> bool - val exists : (key -> elem -> bool) -> t -> bool - val filter : (key -> elem -> bool) -> t -> t - val partition : (key -> elem -> bool) -> t -> t * t - val cardinal : t -> int - val bindings : t -> (key * elem) list - val min_binding : t -> key * elem - val min_binding_opt : t -> (key * elem) option - val max_binding : t -> key * elem - val max_binding_opt : t -> (key * elem) option - val choose : t -> key * elem - val choose_opt : t -> (key * elem) option - val split : key -> t -> t * elem option * t - val find : key -> t -> elem - val find_opt : key -> t -> elem option - val find_first : (key -> bool) -> t -> key * elem - val find_first_opt : (key -> bool) -> t -> (key * elem) option - val find_last : (key -> bool) -> t -> key * elem - val find_last_opt : (key -> bool) -> t -> (key * elem) option - val to_seq : t -> (key * elem) Seq.t - val to_seq_from : key -> t -> (key * elem) Seq.t - val add_seq : (key * elem) Seq.t -> t -> t - val of_seq : (key * elem) Seq.t -> t - val add_list : (key * elem) list -> t -> t - val of_list : (key * elem) list -> t end @@ -322,15 +276,11 @@ module MakeInjMap (Key : OrderedType) (Elem : OrderedType) : module Set = MakeSet (Elem) type key = Key.t - type elem = Elem.t - type t = { map : elem Map.t; elems : Set.t } let empty = { map = Map.empty; elems = Set.empty } - let is_empty m = Map.is_empty m.map - let mem k m = Map.mem k m.map let add k e m = @@ -345,15 +295,10 @@ module MakeInjMap (Key : OrderedType) (Elem : OrderedType) : | Some x -> { map = Map.remove k m.map; elems = Set.remove x m.elems } let compare f m1 m2 = Map.compare f m1.map m2.map - let equal f m1 m2 = Map.equal f m1.map m2.map - let iter f m = Map.iter f m.map - let fold f m x = Map.fold f m.map x - let for_all f m = Map.for_all f m.map - let exists f m = Map.exists f m.map (** Small helper *) @@ -381,19 +326,12 @@ module MakeInjMap (Key : OrderedType) (Elem : OrderedType) : (map_to_t map1, map_to_t map2) let cardinal m = Map.cardinal m.map - let bindings m = Map.bindings m.map - let min_binding m = Map.min_binding m.map - let min_binding_opt m = Map.min_binding_opt m.map - let max_binding m = Map.max_binding m.map - let max_binding_opt m = Map.max_binding_opt m.map - let choose m = Map.choose m.map - let choose_opt m = Map.choose_opt m.map let split k m = @@ -403,19 +341,12 @@ module MakeInjMap (Key : OrderedType) (Elem : OrderedType) : (l, data, r) let find k m = Map.find k m.map - let find_opt k m = Map.find_opt k m.map - let find_first k m = Map.find_first k m.map - let find_first_opt k m = Map.find_first_opt k m.map - let find_last k m = Map.find_last k m.map - let find_last_opt k m = Map.find_last_opt k m.map - let to_seq m = Map.to_seq m.map - let to_seq_from k m = Map.to_seq_from k m.map let rec add_seq s m = @@ -428,8 +359,6 @@ module MakeInjMap (Key : OrderedType) (Elem : OrderedType) : add_seq s m let of_seq s = add_seq s empty - let add_list ls m = List.fold_left (fun m (key, elem) -> add key elem m) m ls - let of_list ls = add_list ls empty end diff --git a/src/Contexts.ml b/src/Contexts.ml index a4551420..716326cf 100644 --- a/src/Contexts.ml +++ b/src/Contexts.ml @@ -62,7 +62,6 @@ let symbolic_value_id_counter, fresh_symbolic_value_id = SymbolicValueId.fresh_stateful_generator () let borrow_id_counter, fresh_borrow_id = BorrowId.fresh_stateful_generator () - let region_id_counter, fresh_region_id = RegionId.fresh_stateful_generator () let abstraction_id_counter, fresh_abstraction_id = @@ -219,9 +218,13 @@ type type_context = { type fun_context = { fun_decls : fun_decl FunDeclId.Map.t } [@@deriving show] +type global_context = { global_decls : global_decl GlobalDeclId.Map.t } +[@@deriving show] + type eval_ctx = { type_context : type_context; fun_context : fun_context; + global_context : global_context; type_vars : type_var list; env : env; ended_regions : RegionId.Set.t; @@ -255,6 +258,11 @@ let ctx_lookup_type_decl (ctx : eval_ctx) (tid : TypeDeclId.id) : type_decl = let ctx_lookup_fun_decl (ctx : eval_ctx) (fid : FunDeclId.id) : fun_decl = FunDeclId.Map.find fid ctx.fun_context.fun_decls +(** TODO: make this more efficient with maps *) +let ctx_lookup_global_decl (ctx : eval_ctx) (gid : GlobalDeclId.id) : + global_decl = + GlobalDeclId.Map.find gid ctx.global_context.global_decls + (** Retrieve a variable's value in an environment *) let env_lookup_var_value (env : env) (vid : VarId.id) : typed_value = (* We take care to stop at the end of current frame: different variables diff --git a/src/Errors.ml b/src/Errors.ml index 69a030b1..31a53cf4 100644 --- a/src/Errors.ml +++ b/src/Errors.ml @@ -1,3 +1,2 @@ exception IntegerOverflow of unit - exception Unimplemented diff --git a/src/Expressions.ml b/src/Expressions.ml index 6bf14c66..bf06dd1e 100644 --- a/src/Expressions.ml +++ b/src/Expressions.ml @@ -72,30 +72,10 @@ let all_binops = Shr; ] -(** Constant value for an operand - - It is a bit annoying, but rustc treats some ADT and tuple instances as - constants when generating MIR: - - an enumeration with one variant and no fields is a constant. - - a structure with no field is a constant. - - sometimes, Rust stores the initialization of an ADT as a constant - (if all the fields are constant) rather than as an aggregated value - - For our translation, we use the following enumeration to encode those - special cases in assignments. They are converted to "normal" values - when evaluating the assignment (which is why we don't put them in the - [ConstantValue] enumeration). - *) -type operand_constant_value = - | ConstantValue of constant_value - | ConstantAdt of VariantId.id option * operand_constant_value list -[@@deriving show] - -(* TODO: symplify the operand constant values *) type operand = | Copy of place | Move of place - | Constant of ety * operand_constant_value + | Constant of ety * constant_value [@@deriving show] (** An aggregated ADT. diff --git a/src/ExtractAst.ml b/src/ExtractAst.ml deleted file mode 100644 index dd793291..00000000 --- a/src/ExtractAst.ml +++ /dev/null @@ -1,57 +0,0 @@ -(** This module defines the AST which is to be extracted to generate code. - This AST is voluntarily as simple as possible, so that the extraction - can focus on pretty-printing and on the syntax specific to the different - provers. - - TODO: we don't use this... - *) - -type constant_value = Pure.constant_value - -type pattern = - | PatVar of string - | PatDummy - | PatEnum of string * pattern list - (** Enum: the constructor name (tuple if `None`) and the fields. - Note that we never use structures as patters: we access the fields one - by one. - *) - | PatTuple of pattern list - -(** We want to keep terms a little bit structured, for pretty printing. - See the `FieldProj` and the `Record` cases, for instance. - *) -type term = - | Constant of constant_value - | Var of string - | FieldProj of term * term - (** `x.y` - - Of course, we can always use projectors like `record_get_y x`: - this variant is for pretty-printing. - - Note that `FieldProj` are generated when translating `place` from - the "pure" AST. - *) - | App of term * term - | Let of bool * pattern * term * term - | If of term * term * term - | Switch of term * (pattern * term) list - | Ascribed of term * term (** `x <: ty` *) - | Tuple of term list - | Record of (string * term) list - (** In case a record has named fields, we try to use them, to generate - code like: `{ x = 3; y = true; }` - Otherwise, we can use `App` (with the record constructor). - *) - -type fun_decl = { - name : string; - inputs : pattern list; - (** We can match over the inputs, hence the use of [pattern]. In practice, - we use [PatVar] and [PatDummy]. - *) - input_tys : term list; - output_ty : term; - body : term; -} diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index 0bbe591e..b537e181 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -26,6 +26,14 @@ type type_decl_qualif = *) type fun_decl_qualif = Let | LetRec | And | Val | AssumeVal +let fun_decl_qualif_keyword (qualif : fun_decl_qualif) : string = + match qualif with + | Let -> "let" + | LetRec -> "let rec" + | And -> "and" + | Val -> "val" + | AssumeVal -> "assume val" + (** Small helper to compute the name of an int type *) let fstar_int_name (int_ty : integer_type) = match int_ty with @@ -305,6 +313,12 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) (* Concatenate the elements *) String.concat "_" fname in + let global_name (name : global_name) : string = + (* Converting to snake case also lowercases the letters (in Rust, global + * names are written in capital letters). *) + let parts = List.map to_snake_case (get_name name) in + String.concat "_" parts + in let fun_name (_fid : A.fun_id) (fname : fun_name) (num_rgs : int) (rg : region_group_info option) (filter_info : bool * int) : string = let fname = fun_name_to_snake_case fname in @@ -314,7 +328,8 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) fname ^ suffix in - let decreases_clause_name (_fid : FunDeclId.id) (fname : fun_name) : string = + let decreases_clause_name (_fid : A.FunDeclId.id) (fname : fun_name) : string + = let fname = fun_name_to_snake_case fname in (* Compute the suffix *) let suffix = "_decreases" in @@ -403,6 +418,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) variant_name; struct_constructor; type_name; + global_name; fun_name; decreases_clause_name; var_basename; @@ -781,6 +797,11 @@ let extract_fun_decl_register_names (ctx : extraction_ctx) (keep_fwd : bool) (* Return *) ctx +(** Simply add the global name to the context. *) +let extract_global_decl_register_names (ctx : extraction_ctx) + (def : A.global_decl) : extraction_ctx = + ctx_add_global_decl_and_body def ctx + (** The following function factorizes the extraction of ADT values. Note that patterns can introduce new variables: we thus return an extraction @@ -831,9 +852,14 @@ let extract_adt_g_value ctx | _ -> raise (Failure "Inconsistent typed value") +(* Extract globals in the same way as variables *) +let extract_global (ctx : extraction_ctx) (fmt : F.formatter) + (id : A.GlobalDeclId.id) : unit = + F.pp_print_string fmt (ctx_get_global id ctx) + (** [inside]: see [extract_ty]. - As an pattern can introduce new variables, we return an extraction context + As a pattern can introduce new variables, we return an extraction context updated with new bindings. *) let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter) @@ -888,6 +914,9 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) | Switch (scrut, body) -> extract_Switch ctx fmt inside scrut body | Meta (_, e) -> extract_texpression ctx fmt inside e +(* Extract an application *or* a top-level qualif (function extraction has + * to handle top-level qualifiers, so it seemed more natural to merge the + * two cases) *) and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (app : texpression) (args : texpression list) : unit = (* We don't do the same thing if the app is a top-level identifier (function, @@ -898,6 +927,7 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) match qualif.id with | Func fun_id -> extract_function_call ctx fmt inside fun_id qualif.type_args args + | Global global_id -> extract_global ctx fmt global_id | AdtCons adt_cons_id -> extract_adt_cons ctx fmt inside adt_cons_id qualif.type_args args | Proj proj -> @@ -1337,6 +1367,7 @@ let extract_template_decreases_clause (ctx : extraction_ctx) (fmt : F.formatter) let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) (qualif : fun_decl_qualif) (has_decreases_clause : bool) (def : fun_decl) : unit = + assert (not def.is_global_decl_body); (* Retrieve the function name *) let def_name = ctx_get_local_function def.def_id def.back_id ctx in (* (* Add the type parameters - note that we need those bindings only for the @@ -1355,14 +1386,7 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) F.pp_open_hovbox fmt ctx.indent_incr; (* > "let FUN_NAME" *) let is_opaque = Option.is_none def.body in - let qualif = - match qualif with - | Let -> "let" - | LetRec -> "let rec" - | And -> "and" - | Val -> "val" - | AssumeVal -> "assume val" - in + let qualif = fun_decl_qualif_keyword qualif in F.pp_print_string fmt (qualif ^ " " ^ def_name); F.pp_print_space fmt (); (* Open a box for "(PARAMS) : EFFECT =" *) @@ -1471,6 +1495,98 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) (* Add breaks to insert new lines between definitions *) F.pp_print_break fmt 0 0 +(** Extract a global declaration body of the shape "QUALIF NAME : TYPE = BODY" with a custom body extractor *) +let extract_global_decl_body (ctx : extraction_ctx) (fmt : F.formatter) + (qualif : fun_decl_qualif) (name : string) (ty : ty) + (extract_body : (F.formatter -> unit) Option.t) : unit = + let is_opaque = Option.is_none extract_body in + + (* Open the definition box (depth=0) *) + F.pp_open_hvbox fmt ctx.indent_incr; + + (* Open "QUALIF NAME : TYPE =" box (depth=1) *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Print "QUALIF NAME " *) + F.pp_print_string fmt (fun_decl_qualif_keyword qualif ^ " " ^ name); + F.pp_print_space fmt (); + + (* Open ": TYPE =" box (depth=2) *) + F.pp_open_hvbox fmt 0; + (* Print ": " *) + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + + (* Open "TYPE" box (depth=3) *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Print "TYPE" *) + extract_ty ctx fmt false ty; + (* Close "TYPE" box (depth=3) *) + F.pp_close_box fmt (); + + if not is_opaque then ( + (* Print " =" *) + F.pp_print_space fmt (); + F.pp_print_string fmt "="); + (* Close ": TYPE =" box (depth=2) *) + F.pp_close_box fmt (); + (* Close "QUALIF NAME : TYPE =" box (depth=1) *) + F.pp_close_box fmt (); + + if not is_opaque then ( + F.pp_print_space fmt (); + (* Open "BODY" box (depth=1) *) + F.pp_open_hvbox fmt 0; + (* Print "BODY" *) + (Option.get extract_body) fmt; + (* Close "BODY" box (depth=1) *) + F.pp_close_box fmt ()); + (* Close the definition box (depth=0) *) + F.pp_close_box fmt () + +(** Extract a global declaration. + We generate the body which computes the global value separately from the value declaration itself. + + For example in Rust, + `static X: u32 = 3;` + + will be translated to: + `let x_body : result u32 = Return 3` + `let x_c : u32 = eval_global x_body` + *) +let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) + (global : A.global_decl) (body : fun_decl) (interface : bool) : unit = + assert body.is_global_decl_body; + assert (Option.is_none body.back_id); + assert (List.length body.signature.inputs = 0); + assert (List.length body.signature.doutputs = 1); + assert (List.length body.signature.type_params = 0); + + (* Add a break then the name of the corresponding LLBC declaration *) + F.pp_print_break fmt 0 0; + F.pp_print_string fmt + ("(** [" ^ Print.global_name_to_string global.name ^ "] *)"); + F.pp_print_space fmt (); + + let decl_name = ctx_get_global global.def_id ctx in + let body_name = ctx_get_function (Regular global.body_id) None ctx in + + let decl_ty, body_ty = + let ty = body.signature.output in + if body.signature.info.effect_info.can_fail then (unwrap_result_ty ty, ty) + else (ty, mk_result_ty ty) + in + match body.body with + | None -> + let qualif = if interface then Val else AssumeVal in + extract_global_decl_body ctx fmt qualif decl_name decl_ty None + | Some body -> + extract_global_decl_body ctx fmt Let body_name body_ty + (Some (fun fmt -> extract_texpression ctx fmt false body.body)); + F.pp_print_break fmt 0 0; + extract_global_decl_body ctx fmt Let decl_name decl_ty + (Some (fun fmt -> F.pp_print_string fmt ("eval_global " ^ body_name))); + F.pp_print_break fmt 0 0 + (** Extract a unit test, if the function is a unit function (takes no parameters, returns unit). diff --git a/src/FunsAnalysis.ml b/src/FunsAnalysis.ml index b8dd17d8..615f45b3 100644 --- a/src/FunsAnalysis.ml +++ b/src/FunsAnalysis.ml @@ -1,7 +1,7 @@ (** Compute various information, including: - can a function fail (by having `Fail` in its body, or transitively - calling a function which can fail) - - can a function diverge (bu being recursive, containing a loop or + calling a function which can fail - this is false for globals) + - can a function diverge (by being recursive, containing a loop or transitively calling a function which can diverge) - does a function perform stateful operations (i.e., do we need a state to translate it) @@ -27,7 +27,8 @@ type modules_funs_info = fun_info FunDeclId.Map.t (** Various information about a module's functions *) let analyze_module (m : llbc_module) (funs_map : fun_decl FunDeclId.Map.t) - (use_state : bool) : modules_funs_info = + (globals_map : global_decl GlobalDeclId.Map.t) (use_state : bool) : + modules_funs_info = let infos = ref FunDeclId.Map.empty in let register_info (id : FunDeclId.id) (info : fun_info) : unit = @@ -50,54 +51,66 @@ let analyze_module (m : llbc_module) (funs_map : fun_decl FunDeclId.Map.t) let stateful = ref false in let divergent = ref false in - let obj = - object - inherit [_] iter_statement as super - - method! visit_Assert env a = - can_fail := true; - super#visit_Assert env a - - method! visit_rvalue _env rv = - match rv with - | Use _ | Ref _ | Discriminant _ | Aggregate _ -> () - | UnaryOp (uop, _) -> can_fail := EU.unop_can_fail uop || !can_fail - | BinaryOp (bop, _, _) -> - can_fail := EU.binop_can_fail bop || !can_fail - - method! visit_Call env call = - (match call.func with - | Regular id -> - if FunDeclId.Set.mem id fun_ids then divergent := true - else - let info = FunDeclId.Map.find id !infos in - can_fail := !can_fail || info.can_fail; - stateful := !stateful || info.stateful; - divergent := !divergent || info.divergent - | Assumed id -> - (* None of the assumed functions is stateful for now *) - can_fail := !can_fail || Assumed.assumed_can_fail id); - super#visit_Call env call - - method! visit_Panic env = - can_fail := true; - super#visit_Panic env - - method! visit_Loop env loop = - divergent := true; - super#visit_Loop env loop - end - in - let visit_fun (f : fun_decl) : unit = + let obj = + object (self) + inherit [_] iter_statement as super + method may_fail b = can_fail := !can_fail || b + + method! visit_Assert env a = + self#may_fail true; + super#visit_Assert env a + + method! visit_rvalue _env rv = + match rv with + | Use _ | Ref _ | Discriminant _ | Aggregate _ -> () + | UnaryOp (uop, _) -> can_fail := EU.unop_can_fail uop || !can_fail + | BinaryOp (bop, _, _) -> + can_fail := EU.binop_can_fail bop || !can_fail + + method! visit_Call env call = + (match call.func with + | Regular id -> + if FunDeclId.Set.mem id fun_ids then divergent := true + else + let info = FunDeclId.Map.find id !infos in + self#may_fail info.can_fail; + stateful := !stateful || info.stateful; + divergent := !divergent || info.divergent + | Assumed id -> + (* None of the assumed functions is stateful for now *) + can_fail := !can_fail || Assumed.assumed_can_fail id); + super#visit_Call env call + + method! visit_Panic env = + self#may_fail true; + super#visit_Panic env + + method! visit_Loop env loop = + divergent := true; + super#visit_Loop env loop + end + in + (* Sanity check: global bodies don't contain stateful calls *) + assert ((not f.is_global_decl_body) || not !stateful); match f.body with | None -> - (* Opaque function *) - can_fail := true; - stateful := use_state + (* Opaque function: we consider they fail by default *) + obj#may_fail true; + stateful := (not f.is_global_decl_body) && use_state | Some body -> obj#visit_statement () body.body in List.iter visit_fun d; + (* We need to know if the declaration group contains a global - note that + * groups containing globals contain exactly one declaration *) + let is_global_decl_body = List.exists (fun f -> f.is_global_decl_body) d in + assert ((not is_global_decl_body) || List.length d == 1); + (* We ignore on purpose functions that cannot fail and consider they *can* + * fail: the result of the analysis is not used yet to adjust the translation + * so that the functions which syntactically can't fail don't use an error monad. + * However, we do keep the result of the analysis for global bodies. + * *) + can_fail := (not is_global_decl_body) || !can_fail; { can_fail = !can_fail; stateful = !stateful; divergent = !divergent } in @@ -118,6 +131,11 @@ let analyze_module (m : llbc_module) (funs_map : fun_decl FunDeclId.Map.t) | Fun decl :: decls' -> analyze_fun_decl_group decl; analyze_decl_groups decls' + | Global id :: decls' -> + (* Analyze a global by analyzing its body function *) + let global = GlobalDeclId.Map.find id globals_map in + analyze_fun_decl_group (NonRec global.body_id); + analyze_decl_groups decls' in analyze_decl_groups m.declarations; diff --git a/src/Identifiers.ml b/src/Identifiers.ml index 61238aac..9f6a863d 100644 --- a/src/Identifiers.ml +++ b/src/Identifiers.ml @@ -13,15 +13,10 @@ module type Id = sig (** Id generator - simply a counter *) val zero : id - val generator_zero : generator - val generator_from_incr_id : id -> generator - val fresh_stateful_generator : unit -> generator ref * (unit -> id) - val mk_stateful_generator : generator -> generator ref * (unit -> id) - val incr : id -> id (* TODO: this should be stateful! - but we may want to be able to duplicate @@ -30,29 +25,17 @@ module type Id = sig TODO: change the order of the returned types *) val fresh : generator -> id * generator - val to_string : id -> string - val pp_id : Format.formatter -> id -> unit - val show_id : id -> string - val id_of_json : Yojson.Basic.t -> (id, string) result - val compare_id : id -> id -> int - val max : id -> id -> id - val min : id -> id -> id - val pp_generator : Format.formatter -> generator -> unit - val show_generator : generator -> string - val to_int : id -> int - val of_int : int -> id - val nth : 'a list -> id -> 'a (* TODO: change the signature (invert the index and the list *) @@ -75,9 +58,7 @@ module type Id = sig val iteri : (id -> 'a -> unit) -> 'a list -> unit module Ord : C.OrderedType with type t = id - module Set : C.Set with type elt = id - module Map : C.Map with type key = id end @@ -88,11 +69,9 @@ end module IdGen () : Id = struct (* TODO: use Z.t *) type id = int [@@deriving show] - type generator = id [@@deriving show] let zero = 0 - let generator_zero = 0 let incr x = @@ -113,13 +92,9 @@ module IdGen () : Id = struct (g, fresh) let fresh_stateful_generator () = mk_stateful_generator 0 - let fresh gen = (gen, incr gen) - let to_string = string_of_int - let to_int x = x - let of_int x = x let id_of_json js = @@ -129,13 +104,9 @@ module IdGen () : Id = struct | _ -> Error ("id_of_json: failed on " ^ Yojson.Basic.show js) let compare_id = compare - let max id0 id1 = if id0 > id1 then id0 else id1 - let min id0 id1 = if id0 < id1 then id0 else id1 - let nth v id = List.nth v id - let nth_opt v id = List.nth_opt v id let rec update_nth vec id v = @@ -158,11 +129,8 @@ module IdGen () : Id = struct type t = id let compare = compare - let to_string = to_string - let pp_t = pp_id - let show_t = show_id end diff --git a/src/Interpreter.ml b/src/Interpreter.ml index cbbf2b2e..3a2939ef 100644 --- a/src/Interpreter.ml +++ b/src/Interpreter.ml @@ -13,11 +13,11 @@ module SA = SymbolicAst (** The local logger *) let log = L.interpreter_log -let compute_type_fun_contexts (m : M.llbc_module) : - C.type_context * C.fun_context = - let type_decls_list, _ = M.split_declarations m.declarations in - let type_decls, fun_decls = M.compute_defs_maps m in - let type_decls_groups, _funs_defs_groups = +let compute_type_fun_global_contexts (m : M.llbc_module) : + C.type_context * C.fun_context * C.global_context = + let type_decls_list, _, _ = M.split_declarations m.declarations in + let type_decls, fun_decls, global_decls = M.compute_defs_maps m in + let type_decls_groups, _funs_defs_groups, _globals_defs_groups = M.split_declarations_to_group_maps m.declarations in let type_infos = @@ -25,14 +25,17 @@ let compute_type_fun_contexts (m : M.llbc_module) : in let type_context = { C.type_decls_groups; type_decls; type_infos } in let fun_context = { C.fun_decls } in - (type_context, fun_context) + let global_context = { C.global_decls } in + (type_context, fun_context, global_context) let initialize_eval_context (type_context : C.type_context) - (fun_context : C.fun_context) (type_vars : T.type_var list) : C.eval_ctx = + (fun_context : C.fun_context) (global_context : C.global_context) + (type_vars : T.type_var list) : C.eval_ctx = C.reset_global_counters (); { C.type_context; C.fun_context; + C.global_context; C.type_vars; C.env = [ C.Frame ]; C.ended_regions = T.RegionId.Set.empty; @@ -52,8 +55,8 @@ let initialize_eval_context (type_context : C.type_context) - the instantiated function signature *) let initialize_symbolic_context_for_fun (type_context : C.type_context) - (fun_context : C.fun_context) (fdef : A.fun_decl) : - C.eval_ctx * V.symbolic_value list * A.inst_fun_sig = + (fun_context : C.fun_context) (global_context : C.global_context) + (fdef : A.fun_decl) : C.eval_ctx * V.symbolic_value list * A.inst_fun_sig = (* The abstractions are not initialized the same way as for function * calls: they contain *loan* projectors, because they "provide" us * with the input values (which behave as if they had been returned @@ -67,7 +70,10 @@ let initialize_symbolic_context_for_fun (type_context : C.type_context) * *) let sg = fdef.signature in (* Create the context *) - let ctx = initialize_eval_context type_context fun_context sg.type_params in + let ctx = + initialize_eval_context type_context fun_context global_context + sg.type_params + in (* Instantiate the signature *) let type_params = List.map (fun tv -> T.TypeVar tv.T.index) sg.type_params in let inst_sg = instantiate_fun_sig type_params sg in @@ -205,7 +211,8 @@ let evaluate_function_symbolic_synthesize_backward_from_return *) let evaluate_function_symbolic (config : C.partial_config) (synthesize : bool) (type_context : C.type_context) (fun_context : C.fun_context) - (fdef : A.fun_decl) (back_id : T.RegionGroupId.id option) : + (global_context : C.global_context) (fdef : A.fun_decl) + (back_id : T.RegionGroupId.id option) : V.symbolic_value list * SA.expression option = (* Debug *) let name_to_string () = @@ -218,7 +225,8 @@ let evaluate_function_symbolic (config : C.partial_config) (synthesize : bool) (* Create the evaluation context *) let ctx, input_svs, inst_sg = - initialize_symbolic_context_for_fun type_context fun_context fdef + initialize_symbolic_context_for_fun type_context fun_context global_context + fdef in (* Create the continuation to finish the evaluation *) @@ -285,8 +293,12 @@ module Test = struct assert (body.A.arg_count = 0); (* Create the evaluation context *) - let type_context, fun_context = compute_type_fun_contexts m in - let ctx = initialize_eval_context type_context fun_context [] in + let type_context, fun_context, global_context = + compute_type_fun_global_contexts m + in + let ctx = + initialize_eval_context type_context fun_context global_context [] + in (* Insert the (uninitialized) local variables *) let ctx = C.ctx_push_uninitialized_vars ctx body.A.locals in @@ -331,14 +343,15 @@ module Test = struct (** Execute the symbolic interpreter on a function. *) let test_function_symbolic (config : C.partial_config) (synthesize : bool) (type_context : C.type_context) (fun_context : C.fun_context) - (fdef : A.fun_decl) : unit = + (global_context : C.global_context) (fdef : A.fun_decl) : unit = (* Debug *) log#ldebug (lazy ("test_function_symbolic: " ^ Print.fun_name_to_string fdef.A.name)); (* Evaluate *) let evaluate = - evaluate_function_symbolic config synthesize type_context fun_context fdef + evaluate_function_symbolic config synthesize type_context fun_context + global_context fdef in (* Execute the forward function *) let _ = evaluate None in @@ -368,12 +381,15 @@ module Test = struct in (* Filter the opaque functions *) let no_loop_funs = List.filter fun_decl_is_transparent no_loop_funs in - let type_context, fun_context = compute_type_fun_contexts m in + let type_context, fun_context, global_context = + compute_type_fun_global_contexts m + in let test_fun (def : A.fun_decl) : unit = (* Execute the function - note that as the symbolic interpreter explores * all the path, some executions are expected to "panic": we thus don't * check the return value *) - test_function_symbolic config synthesize type_context fun_context def + test_function_symbolic config synthesize type_context fun_context + global_context def in List.iter test_fun no_loop_funs end diff --git a/src/InterpreterBorrows.ml b/src/InterpreterBorrows.ml index a13ac786..6b920a51 100644 --- a/src/InterpreterBorrows.ml +++ b/src/InterpreterBorrows.ml @@ -436,7 +436,7 @@ let give_back_symbolic_value (_config : C.config) assert (sv.sv_id <> nsv.sv_id); (match nsv.sv_kind with | V.SynthInputGivenBack | V.SynthRetGivenBack | V.FunCallGivenBack -> () - | V.FunCallRet | V.SynthInput -> failwith "Unrechable"); + | V.FunCallRet | V.SynthInput | V.Global -> failwith "Unrechable"); (* Store the given-back value as a meta-value for synthesis purposes *) let mv = nsv in (* Substitution function, to replace the borrow projectors over symbolic values *) diff --git a/src/InterpreterBorrowsCore.ml b/src/InterpreterBorrowsCore.ml index d47989c3..f2f10944 100644 --- a/src/InterpreterBorrowsCore.ml +++ b/src/InterpreterBorrowsCore.ml @@ -582,7 +582,6 @@ let get_first_loan_in_value (v : V.typed_value) : V.loan_content option = let obj = object inherit [_] V.iter_typed_value - method! visit_loan_content _ lc = raise (FoundLoanContent lc) end in @@ -597,7 +596,6 @@ let get_first_borrow_in_value (v : V.typed_value) : V.borrow_content option = let obj = object inherit [_] V.iter_typed_value - method! visit_borrow_content _ bc = raise (FoundBorrowContent bc) end in @@ -700,7 +698,6 @@ let lookup_intersecting_aproj_borrows_opt (lookup_shared : bool) let obj = object inherit [_] C.iter_eval_ctx as super - method! visit_abs _ abs = super#visit_abs (Some abs) abs method! visit_abstract_shared_borrows abs asb = @@ -791,7 +788,6 @@ let update_intersecting_aproj_borrows (can_update_shared : bool) let obj = object inherit [_] C.map_eval_ctx as super - method! visit_abs _ abs = super#visit_abs (Some abs) abs method! visit_abstract_shared_borrows abs asb = @@ -920,7 +916,6 @@ let update_intersecting_aproj_loans (proj_regions : T.RegionId.Set.t) let obj = object inherit [_] C.map_eval_ctx as super - method! visit_abs _ abs = super#visit_abs (Some abs) abs method! visit_aproj abs sproj = diff --git a/src/InterpreterExpressions.ml b/src/InterpreterExpressions.ml index 6bb2baf0..4a4f3353 100644 --- a/src/InterpreterExpressions.ml +++ b/src/InterpreterExpressions.ml @@ -1,11 +1,13 @@ module T = Types module V = Values +module LA = LlbcAst open Scalars module E = Expressions open Errors module C = Contexts module Subst = Substitute module L = Logging +module PV = Print.Values open TypesUtils open ValuesUtils module Inv = Invariants @@ -108,53 +110,25 @@ let access_rplace_reorganize (config : C.config) (expand_prim_copy : bool) ctx (** Convert an operand constant operand value to a typed value *) -let rec operand_constant_value_to_typed_value (ctx : C.eval_ctx) (ty : T.ety) - (cv : E.operand_constant_value) : V.typed_value = +let constant_to_typed_value (ty : T.ety) (cv : V.constant_value) : V.typed_value + = (* Check the type while converting - we actually need some information - * contained in the type *) + * contained in the type *) log#ldebug (lazy - ("operand_constant_value_to_typed_value:" ^ "\n- ty: " - ^ ety_to_string ctx ty ^ "\n- cv: " - ^ operand_constant_value_to_string ctx cv)); + ("constant_to_typed_value:" ^ "\n- cv: " ^ PV.constant_value_to_string cv)); match (ty, cv) with - (* Adt *) - | ( T.Adt (adt_id, region_params, type_params), - ConstantAdt (variant_id, field_values) ) -> - assert (region_params = []); - (* Compute the types of the fields *) - let field_tys = - match adt_id with - | T.AdtId def_id -> - let def = C.ctx_lookup_type_decl ctx def_id in - assert (def.region_params = []); - Subst.type_decl_get_instantiated_field_etypes def variant_id - type_params - | T.Tuple -> type_params - | T.Assumed _ -> failwith "Unreachable" - in - (* Compute the field values *) - let field_values = - List.map - (fun (ty, v) -> operand_constant_value_to_typed_value ctx ty v) - (List.combine field_tys field_values) - in - (* Put together *) - let value = V.Adt { variant_id; field_values } in - { value; ty } (* Scalar, boolean... *) - | T.Bool, ConstantValue (Bool v) -> { V.value = V.Concrete (Bool v); ty } - | T.Char, ConstantValue (Char v) -> { V.value = V.Concrete (Char v); ty } - | T.Str, ConstantValue (String v) -> { V.value = V.Concrete (String v); ty } - | T.Integer int_ty, ConstantValue (V.Scalar v) -> + | T.Bool, Bool v -> { V.value = V.Concrete (Bool v); ty } + | T.Char, Char v -> { V.value = V.Concrete (Char v); ty } + | T.Str, String v -> { V.value = V.Concrete (String v); ty } + | T.Integer int_ty, V.Scalar v -> (* Check the type and the ranges *) assert (int_ty = v.int_ty); assert (check_scalar_value_in_range v); { V.value = V.Concrete (V.Scalar v); ty } - (* Remaining cases (invalid) - listing as much as we can on purpose - (allows to catch errors at compilation if the definitions change) *) - | _, ConstantAdt _ | _, ConstantValue _ -> - failwith "Improperly typed constant value" + (* Remaining cases (invalid) *) + | _, _ -> failwith "Improperly typed constant value" (** Reorganize the environment in preparation for the evaluation of an operand. @@ -197,8 +171,9 @@ let prepare_eval_operand_reorganize (config : C.config) (op : E.operand) : let prepare : cm_fun = fun cf ctx -> match op with - | Expressions.Constant _ -> + | Expressions.Constant (ty, cv) -> (* No need to reorganize the context *) + constant_to_typed_value ty cv |> ignore; cf ctx | Expressions.Copy p -> (* Access the value *) @@ -226,9 +201,7 @@ let eval_operand_no_reorganize (config : C.config) (op : E.operand) ^ "\n- ctx:\n" ^ eval_ctx_to_string ctx ^ "\n")); (* Evaluate *) match op with - | Expressions.Constant (ty, cv) -> - let v = operand_constant_value_to_typed_value ctx ty cv in - cf v ctx + | Expressions.Constant (ty, cv) -> cf (constant_to_typed_value ty cv) ctx | Expressions.Copy p -> (* Access the value *) let access = Read in diff --git a/src/InterpreterStatements.ml b/src/InterpreterStatements.ml index 585fa828..34310ea1 100644 --- a/src/InterpreterStatements.ml +++ b/src/InterpreterStatements.ml @@ -831,6 +831,7 @@ let rec eval_statement (config : C.config) (st : A.statement) : st_cm_fun = (* Compose and apply *) comp cf_eval_rvalue cf_assign cf ctx + | A.AssignGlobal { dst; global } -> eval_global config dst global cf ctx | A.FakeRead p -> let expand_prim_copy = false in let cf_prepare cf = @@ -908,6 +909,28 @@ let rec eval_statement (config : C.config) (st : A.statement) : st_cm_fun = (* Compose and apply *) comp cc cf_eval_st cf ctx +and eval_global (config : C.config) (dest : V.VarId.id) + (gid : LA.GlobalDeclId.id) : st_cm_fun = + fun cf ctx -> + let global = C.ctx_lookup_global_decl ctx gid in + let place = { E.var_id = dest; projection = [] } in + match config.mode with + | ConcreteMode -> + (* Treat the evaluation of the global as a call to the global body (without arguments) *) + (eval_local_function_call_concrete config global.body_id [] [] [] place) + cf ctx + | SymbolicMode -> + (* Generate a fresh symbolic value. In the translation, this fresh symbolic value will be + * defined as equal to the value of the global (see `S.synthesize_global_eval`). *) + let sval = + mk_fresh_symbolic_value V.Global (ety_no_regions_to_rty global.ty) + in + let cc = + assign_to_place config (mk_typed_value_from_symbolic_value sval) place + in + let e = cc (cf Unit) ctx in + S.synthesize_global_eval gid sval e + (** Evaluate a switch *) and eval_switch (config : C.config) (op : E.operand) (tgts : A.switch_targets) : st_cm_fun = diff --git a/src/InterpreterUtils.ml b/src/InterpreterUtils.ml index 7a2e22f7..fed5ff9b 100644 --- a/src/InterpreterUtils.ml +++ b/src/InterpreterUtils.ml @@ -12,35 +12,19 @@ module PA = Print.EvalCtxLlbcAst (** Some utilities *) let eval_ctx_to_string = Print.Contexts.eval_ctx_to_string - let ety_to_string = PA.ety_to_string - let rty_to_string = PA.rty_to_string - let symbolic_value_to_string = PA.symbolic_value_to_string - let borrow_content_to_string = PA.borrow_content_to_string - let loan_content_to_string = PA.loan_content_to_string - let aborrow_content_to_string = PA.aborrow_content_to_string - let aloan_content_to_string = PA.aloan_content_to_string - let aproj_to_string = PA.aproj_to_string - let typed_value_to_string = PA.typed_value_to_string - let typed_avalue_to_string = PA.typed_avalue_to_string - -let operand_constant_value_to_string = PA.operand_constant_value_to_string - let place_to_string = PA.place_to_string - let operand_to_string = PA.operand_to_string - let statement_to_string ctx = PA.statement_to_string ctx "" " " - let statement_to_string_with_tab ctx = PA.statement_to_string ctx " " " " let same_symbolic_id (sv0 : V.symbolic_value) (sv1 : V.symbolic_value) : bool = @@ -213,7 +197,6 @@ let bottom_in_value (ended_regions : T.RegionId.Set.t) (v : V.typed_value) : let obj = object inherit [_] V.iter_typed_value - method! visit_Bottom _ = raise Found method! visit_symbolic_value _ s = @@ -242,6 +225,7 @@ let value_has_ret_symbolic_value_with_borrow_under_mut (ctx : C.eval_ctx) | V.SynthInput | V.SynthInputGivenBack | V.FunCallGivenBack | V.SynthRetGivenBack -> () + | V.Global -> () end in (* We use exceptions *) diff --git a/src/Invariants.ml b/src/Invariants.ml index 81e35de3..ef255010 100644 --- a/src/Invariants.ml +++ b/src/Invariants.ml @@ -399,7 +399,6 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = let visitor = object inherit [_] C.iter_eval_ctx as super - method! visit_abs _ abs = super#visit_abs (Some abs) abs method! visit_typed_value info tv = @@ -705,9 +704,7 @@ let check_symbolic_values (_config : C.config) (ctx : C.eval_ctx) : unit = let obj = object inherit [_] C.iter_eval_ctx as super - method! visit_abs _ abs = super#visit_abs (Some abs) abs - method! visit_Symbolic _ sv = add_env_sv sv method! visit_abstract_shared_borrows abs asb = diff --git a/src/LlbcAst.ml b/src/LlbcAst.ml index d35cd5d8..ccc870dc 100644 --- a/src/LlbcAst.ml +++ b/src/LlbcAst.ml @@ -1,10 +1,10 @@ -open Identifiers open Names open Types open Values open Expressions - +open Identifiers module FunDeclId = IdGen () +module GlobalDeclId = IdGen () type var = { index : VarId.id; (** Unique variable identifier *) @@ -36,6 +36,9 @@ type assumed_fun_id = type fun_id = Regular of FunDeclId.id | Assumed of assumed_fun_id [@@deriving show, ord] +type global_assignment = { dst : VarId.id; global : GlobalDeclId.id } +[@@deriving show] + type assertion = { cond : operand; expected : bool } [@@deriving show] type abs_region_group = (AbstractionId.id, RegionId.id) g_region_group @@ -77,20 +80,16 @@ class ['self] iter_statement_base = object (_self : 'self) inherit [_] VisitorsRuntime.iter - method visit_place : 'env -> place -> unit = fun _ _ -> () + method visit_global_assignment : 'env -> global_assignment -> unit = + fun _ _ -> () + method visit_place : 'env -> place -> unit = fun _ _ -> () method visit_rvalue : 'env -> rvalue -> unit = fun _ _ -> () - method visit_id : 'env -> VariantId.id -> unit = fun _ _ -> () - method visit_assertion : 'env -> assertion -> unit = fun _ _ -> () - method visit_operand : 'env -> operand -> unit = fun _ _ -> () - method visit_call : 'env -> call -> unit = fun _ _ -> () - method visit_integer_type : 'env -> integer_type -> unit = fun _ _ -> () - method visit_scalar_value : 'env -> scalar_value -> unit = fun _ _ -> () end @@ -99,16 +98,15 @@ class ['self] map_statement_base = object (_self : 'self) inherit [_] VisitorsRuntime.map - method visit_place : 'env -> place -> place = fun _ x -> x + method visit_global_assignment + : 'env -> global_assignment -> global_assignment = + fun _ x -> x + method visit_place : 'env -> place -> place = fun _ x -> x method visit_rvalue : 'env -> rvalue -> rvalue = fun _ x -> x - method visit_id : 'env -> VariantId.id -> VariantId.id = fun _ x -> x - method visit_assertion : 'env -> assertion -> assertion = fun _ x -> x - method visit_operand : 'env -> operand -> operand = fun _ x -> x - method visit_call : 'env -> call -> call = fun _ x -> x method visit_integer_type : 'env -> integer_type -> integer_type = @@ -120,6 +118,7 @@ class ['self] map_statement_base = type statement = | Assign of place * rvalue + | AssignGlobal of global_assignment | FakeRead of place | SetDiscriminant of place * VariantId.id | Drop of place @@ -178,5 +177,14 @@ type fun_decl = { name : fun_name; signature : fun_sig; body : fun_body option; + is_global_decl_body : bool; +} +[@@deriving show] + +type global_decl = { + def_id : GlobalDeclId.id; + body_id : FunDeclId.id; + name : global_name; + ty : ety; } [@@deriving show] diff --git a/src/LlbcOfJson.ml b/src/LlbcOfJson.ml index 99d652ec..4e10c642 100644 --- a/src/LlbcOfJson.ml +++ b/src/LlbcOfJson.ml @@ -17,6 +17,8 @@ module S = Scalars module M = Modules module E = Expressions module A = LlbcAst +module TU = TypesUtils +module AU = LlbcAstUtils (* The default logger *) let log = Logging.llbc_of_json_logger @@ -298,23 +300,6 @@ let scalar_value_of_json (js : json) : (V.scalar_value, string) result = raise (Failure ("Scalar value not in range: " ^ V.show_scalar_value sv))); res -let constant_value_of_json (js : json) : (V.constant_value, string) result = - combine_error_msgs js "constant_value_of_json" - (match js with - | `Assoc [ ("Scalar", scalar_value) ] -> - let* scalar_value = scalar_value_of_json scalar_value in - Ok (V.Scalar scalar_value) - | `Assoc [ ("Bool", v) ] -> - let* v = bool_of_json v in - Ok (V.Bool v) - | `Assoc [ ("Char", v) ] -> - let* v = char_of_json v in - Ok (V.Char v) - | `Assoc [ ("String", v) ] -> - let* v = string_of_json v in - Ok (V.String v) - | _ -> Error "") - let field_proj_kind_of_json (js : json) : (E.field_proj_kind, string) result = combine_error_msgs js "field_proj_kind_of_json" (match js with @@ -393,19 +378,21 @@ let binop_of_json (js : json) : (E.binop, string) result = | `String "Shr" -> Ok E.Shr | _ -> Error ("binop_of_json failed on:" ^ show js) -let rec operand_constant_value_of_json (js : json) : - (E.operand_constant_value, string) result = - combine_error_msgs js "operand_constant_value_of_json" +let constant_value_of_json (js : json) : (V.constant_value, string) result = + combine_error_msgs js "constant_value_of_json" (match js with - | `Assoc [ ("ConstantValue", `List [ cv ]) ] -> - let* cv = constant_value_of_json cv in - Ok (E.ConstantValue cv) - | `Assoc [ ("ConstantAdt", `List [ variant_id; field_values ]) ] -> - let* variant_id = option_of_json T.VariantId.id_of_json variant_id in - let* field_values = - list_of_json operand_constant_value_of_json field_values - in - Ok (E.ConstantAdt (variant_id, field_values)) + | `Assoc [ ("Scalar", scalar_value) ] -> + let* scalar_value = scalar_value_of_json scalar_value in + Ok (V.Scalar scalar_value) + | `Assoc [ ("Bool", v) ] -> + let* v = bool_of_json v in + Ok (V.Bool v) + | `Assoc [ ("Char", v) ] -> + let* v = char_of_json v in + Ok (V.Char v) + | `Assoc [ ("String", v) ] -> + let* v = string_of_json v in + Ok (V.String v) | _ -> Error "") let operand_of_json (js : json) : (E.operand, string) result = @@ -417,9 +404,9 @@ let operand_of_json (js : json) : (E.operand, string) result = | `Assoc [ ("Move", place) ] -> let* place = place_of_json place in Ok (E.Move place) - | `Assoc [ ("Constant", `List [ ty; cv ]) ] -> + | `Assoc [ ("Const", `List [ ty; cv ]) ] -> let* ty = ety_of_json ty in - let* cv = operand_constant_value_of_json cv in + let* cv = constant_value_of_json cv in Ok (E.Constant (ty, cv)) | _ -> Error "") @@ -560,6 +547,10 @@ let rec statement_of_json (js : json) : (A.statement, string) result = let* place = place_of_json place in let* rvalue = rvalue_of_json rvalue in Ok (A.Assign (place, rvalue)) + | `Assoc [ ("AssignGlobal", `List [ dst; global ]) ] -> + let* dst = V.VarId.id_of_json dst in + let* global = A.GlobalDeclId.id_of_json global in + Ok (A.AssignGlobal { dst; global }) | `Assoc [ ("FakeRead", place) ] -> let* place = place_of_json place in Ok (A.FakeRead place) @@ -640,7 +631,52 @@ let fun_decl_of_json (js : json) : (A.fun_decl, string) result = let* name = fun_name_of_json name in let* signature = fun_sig_of_json signature in let* body = option_of_json fun_body_of_json body in - Ok { A.def_id; name; signature; body } + Ok { A.def_id; name; signature; body; is_global_decl_body = false } + | _ -> Error "") + +(* Strict type for the number of function declarations (see [global_to_fun_id] below) *) +type global_id_converter = { fun_count : int } [@@deriving show] + +(** Converts a global id to its corresponding function id. + To do so, it adds the global id to the number of function declarations : + We have the bijection `global_fun_id <=> global_id + fun_id_count`. +*) +let global_to_fun_id (conv : global_id_converter) (gid : A.GlobalDeclId.id) : + A.FunDeclId.id = + A.FunDeclId.of_int (A.GlobalDeclId.to_int gid + conv.fun_count) + +(* Converts a global declaration to a function declaration. + *) +let global_decl_of_json (js : json) (gid_conv : global_id_converter) : + (A.global_decl * A.fun_decl, string) result = + combine_error_msgs js "global_decl_of_json" + (match js with + | `Assoc [ ("def_id", def_id); ("name", name); ("ty", ty); ("body", body) ] + -> + let* global_id = A.GlobalDeclId.id_of_json def_id in + let fun_id = global_to_fun_id gid_conv global_id in + let* name = fun_name_of_json name in + let* ty = ety_of_json ty in + let* body = option_of_json fun_body_of_json body in + let signature : A.fun_sig = + { + region_params = []; + num_early_bound_regions = 0; + regions_hierarchy = []; + type_params = []; + inputs = []; + output = TU.ety_no_regions_to_sty ty; + } + in + Ok + ( { A.def_id = global_id; body_id = fun_id; name; ty }, + { + A.def_id = fun_id; + name; + signature; + body; + is_global_decl_body = true; + } ) | _ -> Error "") let g_declaration_group_of_json (id_of_json : json -> ('id, string) result) @@ -665,6 +701,16 @@ let fun_declaration_group_of_json (js : json) : combine_error_msgs js "fun_declaration_group_of_json" (g_declaration_group_of_json A.FunDeclId.id_of_json js) +let global_declaration_group_of_json (js : json) : + (A.GlobalDeclId.id, string) result = + combine_error_msgs js "global_declaration_group_of_json" + (match js with + | `Assoc [ ("NonRec", `List [ id ]) ] -> + let* id = A.GlobalDeclId.id_of_json id in + Ok id + | `Assoc [ ("Rec", `List [ _ ]) ] -> Error "got mutually dependent globals" + | _ -> Error "") + let declaration_group_of_json (js : json) : (M.declaration_group, string) result = combine_error_msgs js "declaration_of_json" @@ -675,8 +721,17 @@ let declaration_group_of_json (js : json) : (M.declaration_group, string) result | `Assoc [ ("Fun", `List [ decl ]) ] -> let* decl = fun_declaration_group_of_json decl in Ok (M.Fun decl) + | `Assoc [ ("Global", `List [ decl ]) ] -> + let* id = global_declaration_group_of_json decl in + Ok (M.Global id) | _ -> Error "") +let length_of_json_list (js : json) : (int, string) result = + combine_error_msgs js "get_json_list_len" + (match js with + | `List jsl -> Ok (List.length jsl) + | _ -> Error ("not a list: " ^ show js)) + let llbc_module_of_json (js : json) : (M.llbc_module, string) result = combine_error_msgs js "llbc_module_of_json" (match js with @@ -686,12 +741,32 @@ let llbc_module_of_json (js : json) : (M.llbc_module, string) result = ("declarations", declarations); ("types", types); ("functions", functions); + ("globals", globals); ] -> + (* We first deserialize the declaration groups (which simply contain ids) + * and all the declarations *butù* the globals *) let* name = string_of_json name in let* declarations = list_of_json declaration_group_of_json declarations in let* types = list_of_json type_decl_of_json types in let* functions = list_of_json fun_decl_of_json functions in - Ok { M.name; declarations; types; functions } + (* When deserializing the globals, we split the global declarations + * between the globals themselves and their bodies, which are simply + * functions with no arguments. We add the global bodies to the list + * of function declarations: the (fresh) ids we use for those bodies + * are simply given by: `num_functions + global_id` *) + let gid_conv = { fun_count = List.length functions } in + let* globals = + list_of_json (fun js -> global_decl_of_json js gid_conv) globals + in + let globals, global_bodies = List.split globals in + Ok + { + M.name; + declarations; + types; + functions = functions @ global_bodies; + globals; + } | _ -> Error "") diff --git a/src/Modules.ml b/src/Modules.ml index f52983c6..7f372d09 100644 --- a/src/Modules.ml +++ b/src/Modules.ml @@ -9,10 +9,11 @@ type type_declaration_group = TypeDeclId.id g_declaration_group type fun_declaration_group = FunDeclId.id g_declaration_group [@@deriving show] -(** Module declaration *) +(** Module declaration. Globals cannot be mutually recursive. *) type declaration_group = | Type of type_declaration_group | Fun of fun_declaration_group + | Global of GlobalDeclId.id [@@deriving show] type llbc_module = { @@ -20,11 +21,14 @@ type llbc_module = { declarations : declaration_group list; types : type_decl list; functions : fun_decl list; + globals : global_decl list; } (** LLBC module - TODO: rename to crate *) let compute_defs_maps (m : llbc_module) : - type_decl TypeDeclId.Map.t * fun_decl FunDeclId.Map.t = + type_decl TypeDeclId.Map.t + * fun_decl FunDeclId.Map.t + * global_decl GlobalDeclId.Map.t = let types_map = List.fold_left (fun m (def : type_decl) -> TypeDeclId.Map.add def.def_id def m) @@ -35,28 +39,37 @@ let compute_defs_maps (m : llbc_module) : (fun m (def : fun_decl) -> FunDeclId.Map.add def.def_id def m) FunDeclId.Map.empty m.functions in - (types_map, funs_map) + let globals_map = + List.fold_left + (fun m (def : global_decl) -> GlobalDeclId.Map.add def.def_id def m) + GlobalDeclId.Map.empty m.globals + in + (types_map, funs_map, globals_map) -(** Split a module's declarations between types and functions *) +(** Split a module's declarations between types, functions and globals *) let split_declarations (decls : declaration_group list) : - type_declaration_group list * fun_declaration_group list = + type_declaration_group list + * fun_declaration_group list + * GlobalDeclId.id list = let rec split decls = match decls with - | [] -> ([], []) + | [] -> ([], [], []) | d :: decls' -> ( - let types, funs = split decls' in + let types, funs, globals = split decls' in match d with - | Type decl -> (decl :: types, funs) - | Fun decl -> (types, decl :: funs)) + | Type decl -> (decl :: types, funs, globals) + | Fun decl -> (types, decl :: funs, globals) + | Global decl -> (types, funs, decl :: globals)) in split decls -(** Split a module's declarations into two maps from type/fun ids to +(** Split a module's declarations into three maps from type/fun/global ids to declaration groups. *) let split_declarations_to_group_maps (decls : declaration_group list) : type_declaration_group TypeDeclId.Map.t - * fun_declaration_group FunDeclId.Map.t = + * fun_declaration_group FunDeclId.Map.t + * GlobalDeclId.Set.t = let module G (M : Map.S) = struct let add_group (map : M.key g_declaration_group M.t) (group : M.key g_declaration_group) : M.key g_declaration_group M.t = @@ -68,9 +81,10 @@ let split_declarations_to_group_maps (decls : declaration_group list) : M.key g_declaration_group M.t = List.fold_left add_group M.empty groups end in - let types, funs = split_declarations decls in + let types, funs, globals = split_declarations decls in let module TG = G (TypeDeclId.Map) in let types = TG.create_map types in let module FG = G (FunDeclId.Map) in let funs = FG.create_map funs in - (types, funs) + let globals = GlobalDeclId.Set.of_list globals in + (types, funs, globals) diff --git a/src/Names.ml b/src/Names.ml index 1308eccc..209f8547 100644 --- a/src/Names.ml +++ b/src/Names.ml @@ -1,5 +1,4 @@ open Identifiers - module Disambiguator = IdGen () (** See the comments for [Name] *) @@ -49,10 +48,9 @@ type name = path_elem list [@@deriving show, ord] let to_name (ls : string list) : name = List.map (fun s -> Ident s) ls type module_name = name [@@deriving show, ord] - type type_name = name [@@deriving show, ord] - type fun_name = name [@@deriving show, ord] +type global_name = name [@@deriving show, ord] (** Filter the disambiguators equal to 0 in a name *) let filter_disambiguators_zero (n : name) : name = diff --git a/src/Print.ml b/src/Print.ml index af6fc982..c10c5989 100644 --- a/src/Print.ml +++ b/src/Print.ml @@ -13,6 +13,7 @@ let option_to_string (to_string : 'a -> string) (x : 'a option) : string = let name_to_string (name : name) : string = Names.name_to_string name let fun_name_to_string (name : fun_name) : string = name_to_string name +let global_name_to_string (name : global_name) : string = name_to_string name (** Pretty-printing for types *) module Types = struct @@ -686,6 +687,7 @@ module LlbcAst = struct adt_field_names : T.TypeDeclId.id -> T.VariantId.id option -> string list option; fun_decl_id_to_string : A.FunDeclId.id -> string; + global_decl_id_to_string : A.GlobalDeclId.id -> string; } let ast_to_ctx_formatter (fmt : ast_formatter) : PC.ctx_formatter = @@ -742,6 +744,10 @@ module LlbcAst = struct let def = C.ctx_lookup_fun_decl ctx def_id in fun_name_to_string def.name in + let global_decl_id_to_string def_id = + let def = C.ctx_lookup_global_decl ctx def_id in + global_name_to_string def.name + in { rvar_to_string = ctx_fmt.PV.rvar_to_string; r_to_string = ctx_fmt.PV.r_to_string; @@ -752,10 +758,12 @@ module LlbcAst = struct adt_field_names = ctx_fmt.PV.adt_field_names; adt_field_to_string; fun_decl_id_to_string; + global_decl_id_to_string; } let fun_decl_to_ast_formatter (type_decls : T.type_decl T.TypeDeclId.Map.t) - (fun_decls : A.fun_decl A.FunDeclId.Map.t) (fdef : A.fun_decl) : + (fun_decls : A.fun_decl A.FunDeclId.Map.t) + (global_decls : A.global_decl A.GlobalDeclId.Map.t) (fdef : A.fun_decl) : ast_formatter = let rvar_to_string r = let rvar = T.RegionVarId.nth fdef.signature.region_params r in @@ -784,6 +792,10 @@ module LlbcAst = struct let def = A.FunDeclId.Map.find def_id fun_decls in fun_name_to_string def.name in + let global_decl_id_to_string def_id = + let def = A.GlobalDeclId.Map.find def_id global_decls in + global_name_to_string def.name + in { rvar_to_string; r_to_string; @@ -794,6 +806,7 @@ module LlbcAst = struct adt_field_names; adt_field_to_string; fun_decl_id_to_string; + global_decl_id_to_string; } let rec projection_to_string (fmt : ast_formatter) (inside : string) @@ -859,35 +872,13 @@ module LlbcAst = struct | E.Shl -> "<<" | E.Shr -> ">>" - let rec operand_constant_value_to_string (fmt : ast_formatter) - (cv : E.operand_constant_value) : string = - match cv with - | E.ConstantValue cv -> PV.constant_value_to_string cv - | E.ConstantAdt (variant_id, field_values) -> - (* This is a bit annoying, because we don't have context information - * to convert the ADT to a value, so we do the best we can in the - * simplest manner. Anyway, those printing utilitites are only used - * for debugging, and complex constant values are not common. - * We might want to store type information in the operand constant values - * in the future. - *) - let variant_id = option_to_string T.VariantId.to_string variant_id in - let field_values = - List.map (operand_constant_value_to_string fmt) field_values - in - "ConstantAdt " ^ variant_id ^ " {" - ^ String.concat ", " field_values - ^ "}" - let operand_to_string (fmt : ast_formatter) (op : E.operand) : string = match op with | E.Copy p -> "copy " ^ place_to_string fmt p | E.Move p -> "move " ^ place_to_string fmt p | E.Constant (ty, cv) -> - (* For clarity, we also print the typing information: see the comment in - * [operand_constant_value_to_string] *) "(" - ^ operand_constant_value_to_string fmt cv + ^ PV.constant_value_to_string cv ^ " : " ^ PT.ety_to_string (ast_to_etype_formatter fmt) ty ^ ")" @@ -948,6 +939,9 @@ module LlbcAst = struct match st with | A.Assign (p, rv) -> indent ^ place_to_string fmt p ^ " := " ^ rvalue_to_string fmt rv + | A.AssignGlobal { dst; global } -> + indent ^ fmt.var_id_to_string dst ^ " := global " + ^ fmt.global_decl_id_to_string global | A.FakeRead p -> indent ^ "fake_read " ^ place_to_string fmt p | A.SetDiscriminant (p, variant_id) -> (* TODO: improve this to lookup the variant name by using the def id *) @@ -1138,7 +1132,8 @@ module Module = struct (** Generate an [ast_formatter] by using a definition context in combination with the variables local to a function's definition *) let def_ctx_to_ast_formatter (type_context : T.type_decl T.TypeDeclId.Map.t) - (fun_context : A.fun_decl A.FunDeclId.Map.t) (def : A.fun_decl) : + (fun_context : A.fun_decl A.FunDeclId.Map.t) + (global_context : A.global_decl A.GlobalDeclId.Map.t) (def : A.fun_decl) : PA.ast_formatter = let rvar_to_string vid = let var = T.RegionVarId.nth def.signature.region_params vid in @@ -1160,6 +1155,10 @@ module Module = struct let def = A.FunDeclId.Map.find def_id fun_context in fun_name_to_string def.name in + let global_decl_id_to_string def_id = + let def = A.GlobalDeclId.Map.find def_id global_context in + global_name_to_string def.name + in let var_id_to_string vid = let var = V.VarId.nth (Option.get def.body).locals vid in PA.var_to_string var @@ -1181,24 +1180,33 @@ module Module = struct var_id_to_string; adt_field_names; fun_decl_id_to_string; + global_decl_id_to_string; } (** This function pretty-prints a function definition by using a definition context *) let fun_decl_to_string (type_context : T.type_decl T.TypeDeclId.Map.t) - (fun_context : A.fun_decl A.FunDeclId.Map.t) (def : A.fun_decl) : string = - let fmt = def_ctx_to_ast_formatter type_context fun_context def in + (fun_context : A.fun_decl A.FunDeclId.Map.t) + (global_context : A.global_decl A.GlobalDeclId.Map.t) (def : A.fun_decl) : + string = + let fmt = + def_ctx_to_ast_formatter type_context fun_context global_context def + in PA.fun_decl_to_string fmt "" " " def let module_to_string (m : M.llbc_module) : string = - let types_defs_map, funs_defs_map = M.compute_defs_maps m in + let types_defs_map, funs_defs_map, globals_defs_map = + M.compute_defs_maps m + in (* The types *) let type_decls = List.map (type_decl_to_string types_defs_map) m.M.types in (* The functions *) let fun_decls = - List.map (fun_decl_to_string types_defs_map funs_defs_map) m.M.functions + List.map + (fun_decl_to_string types_defs_map funs_defs_map globals_defs_map) + m.M.functions in (* Put everything together *) @@ -1255,11 +1263,6 @@ module EvalCtxLlbcAst = struct let fmt = PC.eval_ctx_to_ctx_formatter ctx in PV.typed_avalue_to_string fmt v - let operand_constant_value_to_string (ctx : C.eval_ctx) - (cv : E.operand_constant_value) : string = - let fmt = PA.eval_ctx_to_ast_formatter ctx in - PA.operand_constant_value_to_string fmt cv - let place_to_string (ctx : C.eval_ctx) (op : E.place) : string = let fmt = PA.eval_ctx_to_ast_formatter ctx in PA.place_to_string fmt op diff --git a/src/PrintPure.ml b/src/PrintPure.ml index 5e817dde..0a7091f0 100644 --- a/src/PrintPure.ml +++ b/src/PrintPure.ml @@ -2,17 +2,6 @@ open Pure open PureUtils -module T = Types -module V = Values -module E = Expressions -module A = LlbcAst -module TypeDeclId = T.TypeDeclId -module TypeVarId = T.TypeVarId -module RegionId = T.RegionId -module VariantId = T.VariantId -module FieldId = T.FieldId -module SymbolicValueId = V.SymbolicValueId -module FunDeclId = A.FunDeclId type type_formatter = { type_var_id_to_string : TypeVarId.id -> string; @@ -44,7 +33,8 @@ type ast_formatter = { adt_field_to_string : TypeDeclId.id -> VariantId.id option -> FieldId.id -> string option; adt_field_names : TypeDeclId.id -> VariantId.id option -> string list option; - fun_decl_id_to_string : A.FunDeclId.id -> string; + fun_decl_id_to_string : FunDeclId.id -> string; + global_decl_id_to_string : GlobalDeclId.id -> string; } let ast_to_value_formatter (fmt : ast_formatter) : value_formatter = @@ -62,6 +52,7 @@ let ast_to_type_formatter (fmt : ast_formatter) : type_formatter = let name_to_string = Print.name_to_string let fun_name_to_string = Print.fun_name_to_string +let global_name_to_string = Print.global_name_to_string let option_to_string = Print.option_to_string let type_var_to_string = Print.Types.type_var_to_string let integer_type_to_string = Print.Types.integer_type_to_string @@ -86,8 +77,9 @@ let mk_type_formatter (type_decls : T.type_decl TypeDeclId.Map.t) while we only need those definitions to lookup proper names for the def ids. *) let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t) - (fun_decls : A.fun_decl FunDeclId.Map.t) (type_params : type_var list) : - ast_formatter = + (fun_decls : A.fun_decl FunDeclId.Map.t) + (global_decls : A.global_decl GlobalDeclId.Map.t) + (type_params : type_var list) : ast_formatter = let type_var_id_to_string vid = let var = T.TypeVarId.nth type_params vid in type_var_to_string var @@ -110,9 +102,13 @@ let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t) Print.LlbcAst.type_ctx_to_adt_field_to_string_fun type_decls in let fun_decl_id_to_string def_id = - let def = A.FunDeclId.Map.find def_id fun_decls in + let def = FunDeclId.Map.find def_id fun_decls in fun_name_to_string def.name in + let global_decl_id_to_string def_id = + let def = GlobalDeclId.Map.find def_id global_decls in + global_name_to_string def.name + in { type_var_id_to_string; type_decl_id_to_string; @@ -121,6 +117,7 @@ let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t) adt_field_names; adt_field_to_string; fun_decl_id_to_string; + global_decl_id_to_string; } let type_id_to_string (fmt : type_formatter) (id : type_id) : string = @@ -481,6 +478,7 @@ and app_to_string (fmt : ast_formatter) (inside : bool) (indent : string) let qualif_s = match qualif.id with | Func fun_id -> fun_id_to_string fmt fun_id + | Global global_id -> fmt.global_decl_id_to_string global_id | AdtCons adt_cons_id -> let variant_s = adt_variant_to_string diff --git a/src/PrintSymbolicAst.ml b/src/PrintSymbolicAst.ml index 0ab68efc..e44b422a 100644 --- a/src/PrintSymbolicAst.ml +++ b/src/PrintSymbolicAst.ml @@ -7,6 +7,7 @@ open Errors open Identifiers +open FunIdentifier module T = Types module TU = TypesUtils module V = Values @@ -20,7 +21,7 @@ module PT = Print.Types type formatting_ctx = { type_context : C.type_context; - fun_context : A.fun_decl A.FunDeclId.Map.t; + fun_context : A.fun_decl FunDeclId.Map.t; type_vars : T.type_var list; } diff --git a/src/Pure.ml b/src/Pure.ml index 5834b87f..afda2caa 100644 --- a/src/Pure.ml +++ b/src/Pure.ml @@ -11,6 +11,7 @@ module VariantId = T.VariantId module FieldId = T.FieldId module SymbolicValueId = V.SymbolicValueId module FunDeclId = A.FunDeclId +module GlobalDeclId = A.GlobalDeclId module SynthPhaseId = IdGen () (** We give an identifier to every phase of the synthesis (forward, backward @@ -303,6 +304,7 @@ type projection = { adt_id : type_id; field_id : FieldId.id } [@@deriving show] type qualif_id = | Func of fun_id + | Global of GlobalDeclId.id | AdtCons of adt_cons_id (** A function or ADT constructor identifier *) | Proj of projection (** Field projector *) [@@deriving show] @@ -575,5 +577,6 @@ type fun_decl = { (to identify the forward/backward functions) later. *) signature : fun_sig; + is_global_decl_body : bool; body : fun_body option; } diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index 826283ae..c8ebfa6b 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -586,45 +586,47 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) match (monadic, lv.value) with | false, PatVar (lv_var, _) -> (* We can filter if: *) - let filter = false in - (* 1. Either: - * - the left variable is unnamed or [inline_named] is true - * - the right-expression is a variable - *) - let filter = + (* 1. the left variable is unnamed or [inline_named] is true *) + let filter_left = match (inline_named, lv_var.basename) with - | true, _ | _, None -> is_var re - | _ -> filter + | true, _ | _, None -> true + | _ -> false + in + (* And either: + * 2.1 the right-expression is a variable or a global *) + let var_or_global = is_var re || is_global re in + (* Or: + * 2.2 the right-expression is a constant value, an ADT value, + * a projection or a primitive function call *and* the flag + * `inline_pure` is set *) + let pure_re = + is_const re + || + let app, _ = destruct_apps re in + match app.e with + | Qualif qualif -> ( + match qualif.id with + | AdtCons _ -> true (* ADT constructor *) + | Proj _ -> true (* Projector *) + | Func (Unop _ | Binop _) -> + true (* primitive function call *) + | Func (Regular _) -> false (* non-primitive function call *) + | _ -> false) + | _ -> false in - (* 2. Or: - * - the left variable is an unnamed variable - * - the right-expression is a value or a primitive function call - *) let filter = - if inline_pure then - let app, _ = destruct_apps re in - match app.e with - | Const _ | Var _ -> true (* constant or variable *) - | Qualif qualif -> ( - match qualif.id with - | AdtCons _ | Proj _ -> true (* ADT constructor *) - | Func (Unop _ | Binop _) -> - true (* primitive function call *) - | Func (Regular _) -> - false (* non-primitive function call *)) - | _ -> filter - else false + filter_left && (var_or_global || (inline_pure && pure_re)) in - (* Update the environment and continue the exploration *) + (* Update the rhs (we may perform substitutions inside, and it is + * better to do them *before* we inline it *) let re = self#visit_texpression env re in - (* TODO: once rvalues and expressions are merged, filter the - * let-binding (note that for now we leave it, expect it to - * become useless, and wait for a subsequent pass to filter it) *) - (* let env = add_subst lv_var.id re env in *) + (* Update the substitution environment *) let env = if filter then VarId.Map.add lv_var.id re env else env in + (* Update the next expression *) let e = self#visit_texpression env e in - Let (monadic, lv, re, e) + (* Reconstruct the `let`, only if the binding is not filtered *) + if filter then e.e else Let (monadic, lv, re, e) | _ -> super#visit_Let env monadic lv re e (** Visit the let-bindings to filter the useless ones (and update the substitution map while doing so *) diff --git a/src/PureToExtract.ml b/src/PureToExtract.ml index 1c530011..07a1732c 100644 --- a/src/PureToExtract.ml +++ b/src/PureToExtract.ml @@ -29,9 +29,8 @@ module StringSet = Collections.MakeSet (Collections.OrderedString) module StringMap = Collections.MakeMap (Collections.OrderedString) type name = Names.name - type type_name = Names.type_name - +type global_name = Names.global_name type fun_name = Names.fun_name (* TODO: this should a module we give to a functor! *) @@ -71,6 +70,8 @@ type formatter = { *) type_name : type_name -> string; (** Provided a basename, compute a type name. *) + global_name : global_name -> string; + (** Provided a basename, compute a global name. *) fun_name : A.fun_id -> fun_name -> @@ -83,16 +84,16 @@ type formatter = { function is an assumed function or a local function - function basename - number of region groups + - region group information in case of a backward function + (`None` if forward function) - pair: - do we generate the forward function (it may have been filtered)? - the number of extracted backward functions (not necessarily equal to the number of region groups, because we may have filtered some of them) - - region group information in case of a backward function - (`None` if forward function) TODO: use the fun id for the assumed functions. *) - decreases_clause_name : FunDeclId.id -> fun_name -> string; + decreases_clause_name : A.FunDeclId.id -> fun_name -> string; (** Generates the name of the definition used to prove/reason about termination. The generated code uses this clause where needed, but its body must be defined by the user. @@ -184,6 +185,7 @@ type formatter = { (** We use identifiers to look for name clashes *) type id = + | GlobalId of A.GlobalDeclId.id | FunId of A.fun_id * RegionGroupId.id option | DecreasesClauseId of A.fun_id (** The definition which provides the decreases/termination clause. @@ -224,11 +226,8 @@ module IdOrderedType = struct type t = id let compare = compare_id - let to_string = show_id - let pp_t = pp_id - let show_t = show_id end @@ -340,6 +339,7 @@ type extraction_ctx = { (** Debugging function *) let id_to_string (id : id) (ctx : extraction_ctx) : string = + let global_decls = ctx.trans_ctx.global_context.global_decls in let fun_decls = ctx.trans_ctx.fun_context.fun_decls in let type_decls = ctx.trans_ctx.type_context.type_decls in (* TODO: factorize the pretty-printing with what is in PrintPure *) @@ -352,11 +352,14 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = | Tuple -> failwith "Unreachable" in match id with + | GlobalId gid -> + let name = (A.GlobalDeclId.Map.find gid global_decls).name in + "global name: " ^ Print.global_name_to_string name | FunId (fid, rg_id) -> let fun_name = match fid with | A.Regular fid -> - Print.fun_name_to_string (FunDeclId.Map.find fid fun_decls).name + Print.fun_name_to_string (A.FunDeclId.Map.find fid fun_decls).name | A.Assumed aid -> A.show_assumed_fun_id aid in let fun_kind = @@ -369,7 +372,7 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = let fun_name = match fid with | A.Regular fid -> - Print.fun_name_to_string (FunDeclId.Map.find fid fun_decls).name + Print.fun_name_to_string (A.FunDeclId.Map.find fid fun_decls).name | A.Assumed aid -> A.show_assumed_fun_id aid in "decreases clause for function: " ^ fun_name @@ -440,11 +443,14 @@ let ctx_get (id : id) (ctx : extraction_ctx) : string = log#serror ("Could not find: " ^ id_to_string id ctx); raise Not_found +let ctx_get_global (id : A.GlobalDeclId.id) (ctx : extraction_ctx) : string = + ctx_get (GlobalId id) ctx + let ctx_get_function (id : A.fun_id) (rg : RegionGroupId.id option) (ctx : extraction_ctx) : string = ctx_get (FunId (id, rg)) ctx -let ctx_get_local_function (id : FunDeclId.id) (rg : RegionGroupId.id option) +let ctx_get_local_function (id : A.FunDeclId.id) (rg : RegionGroupId.id option) (ctx : extraction_ctx) : string = ctx_get_function (A.Regular id) rg ctx @@ -475,7 +481,7 @@ let ctx_get_variant (def_id : type_id) (variant_id : VariantId.id) (ctx : extraction_ctx) : string = ctx_get (VariantId (def_id, variant_id)) ctx -let ctx_get_decreases_clause (def_id : FunDeclId.id) (ctx : extraction_ctx) : +let ctx_get_decreases_clause (def_id : A.FunDeclId.id) (ctx : extraction_ctx) : string = ctx_get (DecreasesClauseId (A.Regular def_id)) ctx @@ -568,12 +574,24 @@ let ctx_add_decrases_clause (def : fun_decl) (ctx : extraction_ctx) : let name = ctx.fmt.decreases_clause_name def.def_id def.basename in ctx_add (DecreasesClauseId (A.Regular def.def_id)) name ctx +let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) : + extraction_ctx = + let name = ctx.fmt.global_name def.name in + let decl = GlobalId def.def_id in + let body = FunId (Regular def.body_id, None) in + let ctx = ctx_add decl (name ^ "_c") ctx in + let ctx = ctx_add body (name ^ "_body") ctx in + ctx + let ctx_add_fun_decl (trans_group : bool * pure_fun_translation) (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx = + (* Sanity check: the function should not be a global body - those are handled + * separately *) + assert (not def.is_global_decl_body); (* Lookup the LLBC def to compute the region group information *) let def_id = def.def_id in let llbc_def = - FunDeclId.Map.find def_id ctx.trans_ctx.fun_context.fun_decls + A.FunDeclId.Map.find def_id ctx.trans_ctx.fun_context.fun_decls in let sg = llbc_def.signature in let num_rgs = List.length sg.regions_hierarchy in @@ -598,9 +616,7 @@ let ctx_add_fun_decl (trans_group : bool * pure_fun_translation) let name = ctx.fmt.fun_name def_id def.basename num_rgs rg_info (keep_fwd, num_backs) in - (* Add the function name *) - let ctx = ctx_add (FunId (def_id, def.back_id)) name ctx in - ctx + ctx_add (FunId (def_id, def.back_id)) name ctx type names_map_init = { keywords : string list; diff --git a/src/PureTypeCheck.ml b/src/PureTypeCheck.ml index 8848ff20..5aefb0be 100644 --- a/src/PureTypeCheck.ml +++ b/src/PureTypeCheck.ml @@ -40,6 +40,8 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) type tc_ctx = { type_decls : type_decl TypeDeclId.Map.t; (** The type declarations *) + global_decls : A.global_decl A.GlobalDeclId.Map.t; + (** The global declarations *) env : ty VarId.Map.t; (** Environment from variables to types *) } @@ -112,6 +114,7 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = | Qualif qualif -> ( match qualif.id with | Func _ -> () (* TODO *) + | Global _ -> () (* TODO *) | Proj { adt_id = proj_adt_id; field_id } -> (* Note we can only project fields of structures (not enumerations) *) (* Deconstruct the projector type *) diff --git a/src/PureUtils.ml b/src/PureUtils.ml index 8d3b5258..c3d4c983 100644 --- a/src/PureUtils.ml +++ b/src/PureUtils.ml @@ -173,6 +173,12 @@ let is_var (e : texpression) : bool = let as_var (e : texpression) : VarId.id = match e.e with Var v -> v | _ -> raise (Failure "Unreachable") +let is_global (e : texpression) : bool = + match e.e with Qualif { id = Global _; _ } -> true | _ -> false + +let is_const (e : texpression) : bool = + match e.e with Const _ -> true | _ -> false + (** Remove the external occurrences of [Meta] *) let rec unmeta (e : texpression) : texpression = match e.e with Meta (_, e) -> unmeta e | _ -> e @@ -399,6 +405,11 @@ let type_decl_is_enum (def : T.type_decl) : bool = let mk_state_ty : ty = Adt (Assumed State, []) let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ]) +let unwrap_result_ty (ty : ty) : ty = + match ty with + | Adt (Assumed Result, [ ty ]) -> ty + | _ -> failwith "not a result type" + let mk_result_fail_texpression (ty : ty) : texpression = let type_args = [ ty ] in let ty = Adt (Assumed Result, type_args) in diff --git a/src/Scalars.ml b/src/Scalars.ml index 3324c24b..03ca506c 100644 --- a/src/Scalars.ml +++ b/src/Scalars.ml @@ -4,43 +4,24 @@ open Values (** The minimum/maximum values an integer type can have depending on its type *) let i8_min = Z.of_string "-128" - let i8_max = Z.of_string "127" - let i16_min = Z.of_string "-32768" - let i16_max = Z.of_string "32767" - let i32_min = Z.of_string "-2147483648" - let i32_max = Z.of_string "2147483647" - let i64_min = Z.of_string "-9223372036854775808" - let i64_max = Z.of_string "9223372036854775807" - let i128_min = Z.of_string "-170141183460469231731687303715884105728" - let i128_max = Z.of_string "170141183460469231731687303715884105727" - let u8_min = Z.of_string "0" - let u8_max = Z.of_string "255" - let u16_min = Z.of_string "0" - let u16_max = Z.of_string "65535" - let u32_min = Z.of_string "0" - let u32_max = Z.of_string "4294967295" - let u64_min = Z.of_string "0" - let u64_max = Z.of_string "18446744073709551615" - let u128_min = Z.of_string "0" - let u128_max = Z.of_string "340282366920938463463374607431768211455" (** Being a bit conservative about isize/usize: depending on the system, @@ -48,11 +29,8 @@ let u128_max = Z.of_string "340282366920938463463374607431768211455" want to take that into account in the future *) let isize_min = i32_min - let isize_max = i32_max - let usize_min = u32_min - let usize_max = u32_max (** Check that an integer value is in range *) diff --git a/src/Substitute.ml b/src/Substitute.ml index 711e438b..5a21e637 100644 --- a/src/Substitute.ml +++ b/src/Substitute.ml @@ -210,12 +210,6 @@ let place_substitute (_tsubst : T.TypeVarId.id -> T.ety) (p : E.place) : E.place (* There is nothing to do *) p -(** Apply a type substitution to an operand constant value *) -let operand_constant_value_substitute (_tsubst : T.TypeVarId.id -> T.ety) - (op : E.operand_constant_value) : E.operand_constant_value = - (* There is nothing to do *) - op - (** Apply a type substitution to an operand *) let operand_substitute (tsubst : T.TypeVarId.id -> T.ety) (op : E.operand) : E.operand = @@ -225,9 +219,7 @@ let operand_substitute (tsubst : T.TypeVarId.id -> T.ety) (op : E.operand) : | E.Move p -> E.Move (p_subst p) | E.Constant (ety, cv) -> let rsubst x = x in - E.Constant - ( ty_substitute rsubst tsubst ety, - operand_constant_value_substitute tsubst cv ) + E.Constant (ty_substitute rsubst tsubst ety, cv) (** Apply a type substitution to an rvalue *) let rvalue_substitute (tsubst : T.TypeVarId.id -> T.ety) (rv : E.rvalue) : @@ -289,6 +281,9 @@ let rec statement_substitute (tsubst : T.TypeVarId.id -> T.ety) let p = place_substitute tsubst p in let rvalue = rvalue_substitute tsubst rvalue in A.Assign (p, rvalue) + | A.AssignGlobal g -> + (* Globals don't have type parameters *) + A.AssignGlobal g | A.FakeRead p -> let p = place_substitute tsubst p in A.FakeRead p diff --git a/src/SymbolicAst.ml b/src/SymbolicAst.ml index 9cab092d..ec2a80ca 100644 --- a/src/SymbolicAst.ml +++ b/src/SymbolicAst.ml @@ -65,6 +65,8 @@ type expression = | Panic | FunCall of call * expression | EndAbstraction of V.abs * expression + | EvalGlobal of A.GlobalDeclId.id * V.symbolic_value * expression + (** Evaluate a global to a fresh symbolic value *) | Expansion of mplace option * V.symbolic_value * expansion (** Expansion of a symbolic value. diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index 4c2ba4c8..f321ce8c 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -67,11 +67,13 @@ type fun_sig_named_outputs = { } type fun_context = { - llbc_fun_decls : A.fun_decl FunDeclId.Map.t; + llbc_fun_decls : A.fun_decl A.FunDeclId.Map.t; fun_sigs : fun_sig_named_outputs RegularFunIdMap.t; (** *) - fun_infos : FA.fun_info FunDeclId.Map.t; + fun_infos : FA.fun_info A.FunDeclId.Map.t; } +type global_context = { llbc_global_decls : A.global_decl A.GlobalDeclId.Map.t } + type call_info = { forward : S.call; forward_inputs : texpression list; @@ -95,6 +97,7 @@ type call_info = { type bs_ctx = { type_context : type_context; fun_context : fun_context; + global_context : global_context; fun_decl : A.fun_decl; bid : T.RegionGroupId.id option; (** TODO: rename *) sg : fun_sig; @@ -122,25 +125,39 @@ type bs_ctx = { let type_check_pattern (ctx : bs_ctx) (v : typed_pattern) : unit = let env = VarId.Map.empty in - let ctx = { PureTypeCheck.type_decls = ctx.type_context.type_decls; env } in + let ctx = + { + PureTypeCheck.type_decls = ctx.type_context.type_decls; + global_decls = ctx.global_context.llbc_global_decls; + env; + } + in let _ = PureTypeCheck.check_typed_pattern ctx v in () let type_check_texpression (ctx : bs_ctx) (e : texpression) : unit = let env = VarId.Map.empty in - let ctx = { PureTypeCheck.type_decls = ctx.type_context.type_decls; env } in + let ctx = + { + PureTypeCheck.type_decls = ctx.type_context.type_decls; + global_decls = ctx.global_context.llbc_global_decls; + env; + } + in PureTypeCheck.check_texpression ctx e (* TODO: move *) let bs_ctx_to_ast_formatter (ctx : bs_ctx) : Print.LlbcAst.ast_formatter = Print.LlbcAst.fun_decl_to_ast_formatter ctx.type_context.llbc_type_decls - ctx.fun_context.llbc_fun_decls ctx.fun_decl + ctx.fun_context.llbc_fun_decls ctx.global_context.llbc_global_decls + ctx.fun_decl let bs_ctx_to_pp_ast_formatter (ctx : bs_ctx) : PrintPure.ast_formatter = let type_params = ctx.fun_decl.signature.type_params in let type_decls = ctx.type_context.llbc_type_decls in let fun_decls = ctx.fun_context.llbc_fun_decls in - PrintPure.mk_ast_formatter type_decls fun_decls type_params + let global_decls = ctx.global_context.llbc_global_decls in + PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params let ty_to_string (ctx : bs_ctx) (ty : ty) : string = let fmt = bs_ctx_to_pp_ast_formatter ctx in @@ -161,14 +178,20 @@ let fun_sig_to_string (ctx : bs_ctx) (sg : fun_sig) : string = let type_params = sg.type_params in let type_decls = ctx.type_context.llbc_type_decls in let fun_decls = ctx.fun_context.llbc_fun_decls in - let fmt = PrintPure.mk_ast_formatter type_decls fun_decls type_params in + let global_decls = ctx.global_context.llbc_global_decls in + let fmt = + PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params + in PrintPure.fun_sig_to_string fmt sg let fun_decl_to_string (ctx : bs_ctx) (def : Pure.fun_decl) : string = let type_params = def.signature.type_params in let type_decls = ctx.type_context.llbc_type_decls in let fun_decls = ctx.fun_context.llbc_fun_decls in - let fmt = PrintPure.mk_ast_formatter type_decls fun_decls type_params in + let global_decls = ctx.global_context.llbc_global_decls in + let fmt = + PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params + in PrintPure.fun_decl_to_string fmt def (* TODO: move *) @@ -195,12 +218,12 @@ let bs_ctx_lookup_llbc_type_decl (id : TypeDeclId.id) (ctx : bs_ctx) : T.type_decl = TypeDeclId.Map.find id ctx.type_context.llbc_type_decls -let bs_ctx_lookup_llbc_fun_decl (id : FunDeclId.id) (ctx : bs_ctx) : A.fun_decl - = - FunDeclId.Map.find id ctx.fun_context.llbc_fun_decls +let bs_ctx_lookup_llbc_fun_decl (id : A.FunDeclId.id) (ctx : bs_ctx) : + A.fun_decl = + A.FunDeclId.Map.find id ctx.fun_context.llbc_fun_decls (* TODO: move *) -let bs_ctx_lookup_local_function_sig (def_id : FunDeclId.id) +let bs_ctx_lookup_local_function_sig (def_id : A.FunDeclId.id) (back_id : T.RegionGroupId.id option) (ctx : bs_ctx) : fun_sig = let id = (A.Regular def_id, back_id) in (RegularFunIdMap.find id ctx.fun_context.fun_sigs).sg @@ -471,17 +494,14 @@ let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) : List.map (fun id -> V.AbstractionId.Map.find id ctx.abstractions) abs_ids (** Small utility. *) -let get_fun_effect_info (fun_infos : FA.fun_info FunDeclId.Map.t) +let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) (fun_id : A.fun_id) (gid : T.RegionGroupId.id option) : fun_effect_info = match fun_id with | A.Regular fid -> - let info = FunDeclId.Map.find fid fun_infos in + let info = A.FunDeclId.Map.find fid fun_infos in let input_state = info.stateful in let output_state = input_state && gid = None in - (* We ignore on purpose info.can_fail: the result of the analysis is not - * used yet to adjust the translation so that the functions which syntactically - * can't fail don't use an error monad. *) - { can_fail = true; input_state; output_state } + { can_fail = info.can_fail; input_state; output_state } | A.Assumed aid -> { can_fail = Assumed.assumed_can_fail aid; @@ -496,7 +516,7 @@ let get_fun_effect_info (fun_infos : FA.fun_info FunDeclId.Map.t) name (outputs for backward functions come from borrows in the inputs of the forward function). *) -let translate_fun_sig (fun_infos : FA.fun_info FunDeclId.Map.t) +let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) (fun_id : A.fun_id) (types_infos : TA.type_infos) (sg : A.fun_sig) (input_names : string option list) (bid : T.RegionGroupId.id option) : fun_sig_named_outputs = @@ -1058,6 +1078,7 @@ let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx) | Panic -> translate_panic ctx | FunCall (call, e) -> translate_function_call config call e ctx | EndAbstraction (abs, e) -> translate_end_abstraction config abs e ctx + | EvalGlobal (gid, sv, e) -> translate_global_eval config gid sv e ctx | Expansion (p, sv, exp) -> translate_expansion config p sv exp ctx | Meta (meta, e) -> translate_meta config meta e ctx @@ -1444,6 +1465,17 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) mk_let monadic given_back (mk_texpression_from_var input_var) e) given_back_inputs next_e +and translate_global_eval (config : config) (gid : A.GlobalDeclId.id) + (sval : V.symbolic_value) (e : S.expression) (ctx : bs_ctx) : texpression = + let ctx, var = fresh_var_for_symbolic_value sval ctx in + let decl = A.GlobalDeclId.Map.find gid ctx.global_context.llbc_global_decls in + let global_expr = { id = Global gid; type_args = [] } in + (* We use translate_fwd_ty to translate the global type *) + let ty = ctx_translate_fwd_ty ctx decl.ty in + let gval = { e = Qualif global_expr; ty } in + let e = translate_expression config e ctx in + mk_let false (mk_typed_pattern_from_var var None) gval e + and translate_expansion (config : config) (p : S.mplace option) (sv : V.symbolic_value) (exp : S.expansion) (ctx : bs_ctx) : texpression = (* Translate the scrutinee *) @@ -1722,7 +1754,16 @@ let translate_fun_decl (config : config) (ctx : bs_ctx) Some { inputs; inputs_lvs; body } in (* Assemble the declaration *) - let def = { def_id; back_id = bid; basename; signature; body } in + let def = + { + def_id; + back_id = bid; + basename; + signature; + is_global_decl_body = def.is_global_decl_body; + body; + } + in (* Debugging *) log#ldebug (lazy @@ -1746,7 +1787,7 @@ let translate_type_decls (type_decls : T.type_decl list) : type_decl list = - optional names for the outputs values (we derive them for the backward functions) *) -let translate_fun_signatures (fun_infos : FA.fun_info FunDeclId.Map.t) +let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) (types_infos : TA.type_infos) (functions : (A.fun_id * string option list * A.fun_sig) list) : fun_sig_named_outputs RegularFunIdMap.t = diff --git a/src/SynthesizeSymbolic.ml b/src/SynthesizeSymbolic.ml index 95da38e6..a2256bdd 100644 --- a/src/SynthesizeSymbolic.ml +++ b/src/SynthesizeSymbolic.ml @@ -114,6 +114,10 @@ let synthesize_function_call (call_id : call_id) in Some (FunCall (call, expr)) +let synthesize_global_eval (gid : A.GlobalDeclId.id) (dest : V.symbolic_value) + (expr : expression option) : expression option = + match expr with None -> None | Some e -> Some (EvalGlobal (gid, dest, e)) + let synthesize_regular_function_call (fun_id : A.fun_id) (call_id : V.FunCallId.id) (abstractions : V.AbstractionId.id list) (type_params : T.ety list) (args : V.typed_value list) diff --git a/src/Translate.ml b/src/Translate.ml index 57b92e44..61300ed8 100644 --- a/src/Translate.ml +++ b/src/Translate.ml @@ -63,7 +63,7 @@ let translate_function_to_symbolics (config : C.partial_config) ("translate_function_to_symbolics: " ^ Print.fun_name_to_string fdef.A.name)); - let { type_context; fun_context } = trans_ctx in + let { type_context; fun_context; global_context } = trans_ctx in let fun_context = { C.fun_decls = fun_context.fun_decls } in match fdef.body with @@ -74,7 +74,7 @@ let translate_function_to_symbolics (config : C.partial_config) let evaluate gid = let inputs, symb = evaluate_function_symbolic config synthesize type_context fun_context - fdef gid + global_context fdef gid in (inputs, Option.get symb) in @@ -106,7 +106,7 @@ let translate_function_to_pure (config : C.partial_config) (lazy ("translate_function_to_pure: " ^ Print.fun_name_to_string fdef.A.name)); - let { type_context; fun_context } = trans_ctx in + let { type_context; fun_context; global_context } = trans_ctx in let def_id = fdef.def_id in (* Compute the symbolic ASTs, if the function is transparent *) @@ -140,6 +140,9 @@ let translate_function_to_pure (config : C.partial_config) fun_infos = fun_context.fun_infos; } in + let global_context = + { SymbolicToPure.llbc_global_decls = global_context.global_decls } + in let ctx = { SymbolicToPure.bid = None; @@ -151,6 +154,7 @@ let translate_function_to_pure (config : C.partial_config) state_var; type_context; fun_context; + global_context; fun_decl = fdef; forward_inputs = []; (* Empty for now *) @@ -288,10 +292,15 @@ let translate_module_to_pure (config : C.partial_config) log#ldebug (lazy "translate_module_to_pure"); (* Compute the type and function contexts *) - let type_context, fun_context = compute_type_fun_contexts m in - let fun_infos = FA.analyze_module m fun_context.C.fun_decls use_state in + let type_context, fun_context, global_context = + compute_type_fun_global_contexts m + in + let fun_infos = + FA.analyze_module m fun_context.C.fun_decls global_context.C.global_decls + use_state + in let fun_context = { fun_decls = fun_context.fun_decls; fun_infos } in - let trans_ctx = { type_context; fun_context } in + let trans_ctx = { type_context; fun_context; global_context } in (* Translate all the type definitions *) let type_decls = SymbolicToPure.translate_type_decls m.types in @@ -351,8 +360,8 @@ type gen_ctx = { m : M.llbc_module; extract_ctx : PureToExtract.extraction_ctx; trans_types : Pure.type_decl Pure.TypeDeclId.Map.t; - trans_funs : (bool * pure_fun_translation) Pure.FunDeclId.Map.t; - functions_with_decreases_clause : Pure.FunDeclId.Set.t; + trans_funs : (bool * pure_fun_translation) A.FunDeclId.Map.t; + functions_with_decreases_clause : A.FunDeclId.Set.t; } (** Extraction context *) @@ -388,7 +397,7 @@ let module_has_opaque_decls (ctx : gen_ctx) : bool * bool = ctx.trans_types in let has_opaque_funs = - Pure.FunDeclId.Map.exists + A.FunDeclId.Map.exists (fun _ ((_, (t_fwd, _)) : bool * pure_fun_translation) -> Option.is_none t_fwd.body) ctx.trans_funs @@ -427,7 +436,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) (* Utility to check a function has a decrease clause *) let has_decreases_clause (def : Pure.fun_decl) : bool = - Pure.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause + A.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause in (* In case of (non-mutually) recursive functions, we use a simple procedure to @@ -499,6 +508,24 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) pure_ls in + (* TODO: Check correct behaviour with opaque globals *) + let export_global (id : A.GlobalDeclId.id) : unit = + let global_decls = ctx.extract_ctx.trans_ctx.global_context.global_decls in + let global = A.GlobalDeclId.Map.find id global_decls in + let _, (body, body_backs) = + A.FunDeclId.Map.find global.body_id ctx.trans_funs + in + assert (List.length body_backs = 0); + + let is_opaque = Option.is_none body.Pure.body in + if + ((not is_opaque) && config.extract_transparent) + || (is_opaque && config.extract_opaque) + then + ExtractToFStar.extract_global_decl ctx.extract_ctx fmt global body + config.interface + in + let export_state_type () : unit = let qualif = if config.interface then ExtractToFStar.TypeVal @@ -523,17 +550,18 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) ids | Fun (NonRec id) -> (* Lookup *) - let pure_fun = Pure.FunDeclId.Map.find id ctx.trans_funs in + let pure_fun = A.FunDeclId.Map.find id ctx.trans_funs in (* Translate *) export_functions false [ pure_fun ] | Fun (Rec ids) -> (* General case of mutually recursive functions *) (* Lookup *) let pure_funs = - List.map (fun id -> Pure.FunDeclId.Map.find id ctx.trans_funs) ids + List.map (fun id -> A.FunDeclId.Map.find id ctx.trans_funs) ids in (* Translate *) export_functions true pure_funs + | Global id -> export_global id in (* If we need to export the state type: we try to export it after we defined @@ -622,14 +650,14 @@ let translate_module (filename : string) (dest_dir : string) (config : config) (* We need to compute which functions are recursive, in order to know * whether we should generate a decrease clause or not. *) let rec_functions = - Pure.FunDeclId.Set.of_list + A.FunDeclId.Set.of_list (List.concat (List.map (fun decl -> match decl with M.Fun (Rec ids) -> ids | _ -> []) m.declarations)) in - (* Register unique names for all the top-level types and functions. + (* Register unique names for all the top-level types, globals and functions. * Note that the order in which we generate the names doesn't matter: * we just need to generate a mapping from identifier to name, and make * sure there are no name clashes. *) @@ -642,15 +670,25 @@ let translate_module (filename : string) (dest_dir : string) (config : config) let ctx = List.fold_left (fun ctx (keep_fwd, def) -> - (* Note that we generate a decrease clause for all the recursive functions *) + (* We generate a decrease clause for all the recursive functions *) let gen_decr_clause = - Pure.FunDeclId.Set.mem (fst def).Pure.def_id rec_functions + A.FunDeclId.Set.mem (fst def).Pure.def_id rec_functions in - ExtractToFStar.extract_fun_decl_register_names ctx keep_fwd - gen_decr_clause def) + (* Register the names, only if the function is not a global body - + * those are handled later *) + let is_global = (fst def).Pure.is_global_decl_body in + if is_global then ctx + else + ExtractToFStar.extract_fun_decl_register_names ctx keep_fwd + gen_decr_clause def) ctx trans_funs in + let ctx = + List.fold_left ExtractToFStar.extract_global_decl_register_names ctx + m.globals + in + (* Open the output file *) (* First compute the filename by replacing the extension and converting the * case (rust module names are snake case) *) @@ -674,7 +712,7 @@ let translate_module (filename : string) (dest_dir : string) (config : config) (List.map (fun (d : Pure.type_decl) -> (d.def_id, d)) trans_types) in let trans_funs = - Pure.FunDeclId.Map.of_list + A.FunDeclId.Map.of_list (List.map (fun ((keep_fwd, (fd, bdl)) : bool * pure_fun_translation) -> (fd.def_id, (keep_fwd, (fd, bdl)))) @@ -761,7 +799,7 @@ let translate_module (filename : string) (dest_dir : string) (config : config) (* Extract the template clauses *) let needs_clauses_module = config.extract_decreases_clauses - && not (Pure.FunDeclId.Set.is_empty rec_functions) + && not (A.FunDeclId.Set.is_empty rec_functions) in (if needs_clauses_module && config.extract_template_decreases_clauses then let clauses_filename = extract_filebasename ^ ".Clauses.Template.fst" in diff --git a/src/TranslateCore.ml b/src/TranslateCore.ml index 17c35cbf..326bb05f 100644 --- a/src/TranslateCore.ml +++ b/src/TranslateCore.ml @@ -19,7 +19,13 @@ type fun_context = { } [@@deriving show] -type trans_ctx = { type_context : type_context; fun_context : fun_context } +type global_context = C.global_context [@@deriving show] + +type trans_ctx = { + type_context : type_context; + fun_context : fun_context; + global_context : global_context; +} type pure_fun_translation = Pure.fun_decl * Pure.fun_decl list @@ -39,16 +45,22 @@ let fun_sig_to_string (ctx : trans_ctx) (sg : Pure.fun_sig) : string = let type_params = sg.type_params in let type_decls = ctx.type_context.type_decls in let fun_decls = ctx.fun_context.fun_decls in - let fmt = PrintPure.mk_ast_formatter type_decls fun_decls type_params in + let global_decls = ctx.global_context.global_decls in + let fmt = + PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params + in PrintPure.fun_sig_to_string fmt sg let fun_decl_to_string (ctx : trans_ctx) (def : Pure.fun_decl) : string = let type_params = def.signature.type_params in let type_decls = ctx.type_context.type_decls in let fun_decls = ctx.fun_context.fun_decls in - let fmt = PrintPure.mk_ast_formatter type_decls fun_decls type_params in + let global_decls = ctx.global_context.global_decls in + let fmt = + PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params + in PrintPure.fun_decl_to_string fmt def -let fun_decl_id_to_string (ctx : trans_ctx) (id : Pure.FunDeclId.id) : string = +let fun_decl_id_to_string (ctx : trans_ctx) (id : A.FunDeclId.id) : string = Print.fun_name_to_string - (Pure.FunDeclId.Map.find id ctx.fun_context.fun_decls).name + (A.FunDeclId.Map.find id ctx.fun_context.fun_decls).name diff --git a/src/Types.ml b/src/Types.ml index 5ff407c9..5bd172cb 100644 --- a/src/Types.ml +++ b/src/Types.ml @@ -1,12 +1,8 @@ open Identifiers open Names - module TypeVarId = IdGen () - module TypeDeclId = IdGen () - module VariantId = IdGen () - module FieldId = IdGen () module RegionVarId = IdGen () @@ -24,7 +20,6 @@ type ('id, 'name) indexed_var = { [@@deriving show] type type_var = (TypeVarId.id, string) indexed_var [@@deriving show] - type region_var = (RegionVarId.id, string option) indexed_var [@@deriving show] (** A region. @@ -82,13 +77,10 @@ type integer_type = [@@deriving show, ord] let all_signed_int_types = [ Isize; I8; I16; I32; I64; I128 ] - let all_unsigned_int_types = [ Usize; U8; U16; U32; U64; U128 ] - let all_int_types = List.append all_signed_int_types all_unsigned_int_types type ref_kind = Mut | Shared [@@deriving show, ord] - type assumed_ty = Box | Vec | Option [@@deriving show, ord] (** The variant id for `Option::None` *) @@ -109,15 +101,10 @@ type type_id = AdtId of TypeDeclId.id | Tuple | Assumed of assumed_ty class ['self] iter_ty_base = object (_self : 'self) inherit [_] VisitorsRuntime.iter - method visit_'r : 'env -> 'r -> unit = fun _ _ -> () - method visit_id : 'env -> TypeVarId.id -> unit = fun _ _ -> () - method visit_type_id : 'env -> type_id -> unit = fun _ _ -> () - method visit_integer_type : 'env -> integer_type -> unit = fun _ _ -> () - method visit_ref_kind : 'env -> ref_kind -> unit = fun _ _ -> () end @@ -125,11 +112,8 @@ class ['self] iter_ty_base = class ['self] map_ty_base = object (_self : 'self) inherit [_] VisitorsRuntime.map - method visit_'r : 'env -> 'r -> 'r = fun _ r -> r - method visit_id : 'env -> TypeVarId.id -> TypeVarId.id = fun _ id -> id - method visit_type_id : 'env -> type_id -> type_id = fun _ id -> id method visit_integer_type : 'env -> integer_type -> integer_type = @@ -196,7 +180,6 @@ type ety = erased_region ty [@@deriving show, ord] *) type field = { field_name : string option; field_ty : sty } [@@deriving show] - type variant = { variant_name : string; fields : field list } [@@deriving show] type type_decl_kind = diff --git a/src/TypesUtils.ml b/src/TypesUtils.ml index bee7956e..b5ea6fca 100644 --- a/src/TypesUtils.ml +++ b/src/TypesUtils.ml @@ -87,7 +87,6 @@ let rty_regions (ty : rty) : RegionId.Set.t = let obj = object inherit [_] iter_ty - method! visit_'r _env r = add_region r end in @@ -100,28 +99,31 @@ let rty_regions_intersect (ty : rty) (regions : RegionId.Set.t) : bool = let ty_regions = rty_regions ty in not (RegionId.Set.disjoint ty_regions regions) -(** Convert an [ety], containing no region variables, to an [rty]. +(** Convert an [ety], containing no region variables, to an [rty] or an [sty]. In practice, it is the identity. *) -let rec ety_no_regions_to_rty (ty : ety) : rty = +let rec ety_no_regions_to_gr_ty (ty : ety) : 'a gr_ty = match ty with | Adt (type_id, regions, tys) -> assert (regions = []); - Adt (type_id, [], List.map ety_no_regions_to_rty tys) + Adt (type_id, [], List.map ety_no_regions_to_gr_ty tys) | TypeVar v -> TypeVar v | Bool -> Bool | Char -> Char | Never -> Never | Integer int_ty -> Integer int_ty | Str -> Str - | Array ty -> Array (ety_no_regions_to_rty ty) - | Slice ty -> Slice (ety_no_regions_to_rty ty) + | Array ty -> Array (ety_no_regions_to_gr_ty ty) + | Slice ty -> Slice (ety_no_regions_to_gr_ty ty) | Ref (_, _, _) -> failwith "Can't convert a ref with erased regions to a ref with non-erased \ regions" +let ety_no_regions_to_rty (ty : ety) : rty = ety_no_regions_to_gr_ty ty +let ety_no_regions_to_sty (ty : ety) : sty = ety_no_regions_to_gr_ty ty + (** Retuns true if the type contains borrows. Note that we can't simply explore the type and look for regions: sometimes diff --git a/src/Values.ml b/src/Values.ml index 4585b443..fb927fb5 100644 --- a/src/Values.ml +++ b/src/Values.ml @@ -65,6 +65,7 @@ type sv_kind = *) | SynthInputGivenBack (** The value was given back upon ending one of the input abstractions *) + | Global (** The value is a global *) [@@deriving show] type symbolic_value = { diff --git a/src/ValuesUtils.ml b/src/ValuesUtils.ml index 2814615c..bc205622 100644 --- a/src/ValuesUtils.ml +++ b/src/ValuesUtils.ml @@ -11,7 +11,6 @@ let mk_unit_value : typed_value = { value = Adt { variant_id = None; field_values = [] }; ty = mk_unit_ty } let mk_typed_value (ty : ety) (value : value) : typed_value = { value; ty } - let mk_bottom (ty : ety) : typed_value = { value = Bottom; ty } (** Box a value *) @@ -38,7 +37,6 @@ let borrows_in_value (v : typed_value) : bool = let obj = object inherit [_] iter_typed_value - method! visit_borrow_content _env _ = raise Found end in @@ -53,7 +51,6 @@ let inactivated_in_value (v : typed_value) : bool = let obj = object inherit [_] iter_typed_value - method! visit_InactivatedMutBorrow _env _ = raise Found end in @@ -68,7 +65,6 @@ let loans_in_value (v : typed_value) : bool = let obj = object inherit [_] iter_typed_value - method! visit_loan_content _env _ = raise Found end in @@ -84,9 +80,7 @@ let outer_loans_in_value (v : typed_value) : bool = let obj = object inherit [_] iter_typed_value - method! visit_loan_content _env _ = raise Found - method! visit_borrow_content _ _ = () end in @@ -1,21 +1,25 @@ ;; core: for Core.Unix.mkdir_p + (executable (name main) - (preprocess (pps ppx_deriving.show ppx_deriving.ord visitors.ppx)) + (preprocess + (pps ppx_deriving.show ppx_deriving.ord visitors.ppx)) (libraries ppx_deriving yojson zarith easy_logging core_unix)) (env - (dev (flags - :standard - -safe-string - -g - ;-dsource - -warn-error -5-8-9-11-14-33-20-21-26-27-39 - )) - (release (flags - :standard - -safe-string - -g - ;-dsource - -warn-error -5-8-9-11-14-33-20-21-26-27-39 - ))) + (dev + (flags + :standard + -safe-string + -g + ;-dsource + -warn-error + -5-8-9-11-14-33-20-21-26-27-39)) + (release + (flags + :standard + -safe-string + -g + ;-dsource + -warn-error + -5-8-9-11-14-33-20-21-26-27-39))) |