diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Contexts.ml | 10 | ||||
-rw-r--r-- | src/Expressions.ml | 21 | ||||
-rw-r--r-- | src/ExtractToFStar.ml | 132 | ||||
-rw-r--r-- | src/FunsAnalysis.ml | 99 | ||||
-rw-r--r-- | src/Interpreter.ml | 2 | ||||
-rw-r--r-- | src/InterpreterExpressions.ml | 57 | ||||
-rw-r--r-- | src/InterpreterStatements.ml | 9 | ||||
-rw-r--r-- | src/InterpreterUtils.ml | 2 | ||||
-rw-r--r-- | src/LlbcAst.ml | 26 | ||||
-rw-r--r-- | src/LlbcOfJson.ml | 151 | ||||
-rw-r--r-- | src/Modules.ml | 4 | ||||
-rw-r--r-- | src/Print.ml | 70 | ||||
-rw-r--r-- | src/PrintPure.ml | 11 | ||||
-rw-r--r-- | src/PrintSymbolicAst.ml | 3 | ||||
-rw-r--r-- | src/Pure.ml | 5 | ||||
-rw-r--r-- | src/PureMicroPasses.ml | 5 | ||||
-rw-r--r-- | src/PureToExtract.ml | 28 | ||||
-rw-r--r-- | src/PureTypeCheck.ml | 2 | ||||
-rw-r--r-- | src/Substitute.ml | 11 | ||||
-rw-r--r-- | src/SymbolicToPure.ml | 43 | ||||
-rw-r--r-- | src/Translate.ml | 42 | ||||
-rw-r--r-- | src/TranslateCore.ml | 11 | ||||
-rw-r--r-- | src/TypesUtils.ml | 13 |
23 files changed, 477 insertions, 280 deletions
diff --git a/src/Contexts.ml b/src/Contexts.ml index a4551420..1fbc916b 100644 --- a/src/Contexts.ml +++ b/src/Contexts.ml @@ -217,7 +217,11 @@ type type_context = { } [@@deriving show] -type fun_context = { fun_decls : fun_decl FunDeclId.Map.t } [@@deriving show] +type fun_context = { + fun_decls : fun_decl FunDeclId.Map.t; + gid_conv : global_id_converter; +} +[@@deriving show] type eval_ctx = { type_context : type_context; @@ -255,6 +259,10 @@ let ctx_lookup_type_decl (ctx : eval_ctx) (tid : TypeDeclId.id) : type_decl = let ctx_lookup_fun_decl (ctx : eval_ctx) (fid : FunDeclId.id) : fun_decl = FunDeclId.Map.find fid ctx.fun_context.fun_decls +(** TODO: make this more efficient with maps *) +let ctx_lookup_global_decl (ctx : eval_ctx) (gid : GlobalDeclId.id) : fun_decl = + ctx_lookup_fun_decl ctx (global_to_fun_id ctx.fun_context.gid_conv gid) + (** Retrieve a variable's value in an environment *) let env_lookup_var_value (env : env) (vid : VarId.id) : typed_value = (* We take care to stop at the end of current frame: different variables diff --git a/src/Expressions.ml b/src/Expressions.ml index 6bf14c66..6645a77f 100644 --- a/src/Expressions.ml +++ b/src/Expressions.ml @@ -72,30 +72,11 @@ let all_binops = Shr; ] -(** Constant value for an operand - - It is a bit annoying, but rustc treats some ADT and tuple instances as - constants when generating MIR: - - an enumeration with one variant and no fields is a constant. - - a structure with no field is a constant. - - sometimes, Rust stores the initialization of an ADT as a constant - (if all the fields are constant) rather than as an aggregated value - - For our translation, we use the following enumeration to encode those - special cases in assignments. They are converted to "normal" values - when evaluating the assignment (which is why we don't put them in the - [ConstantValue] enumeration). - *) -type operand_constant_value = - | ConstantValue of constant_value - | ConstantAdt of VariantId.id option * operand_constant_value list -[@@deriving show] - (* TODO: symplify the operand constant values *) type operand = | Copy of place | Move of place - | Constant of ety * operand_constant_value + | Constant of ety * constant_value [@@deriving show] (** An aggregated ADT. diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index 0bbe591e..20b06bfa 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -26,6 +26,14 @@ type type_decl_qualif = *) type fun_decl_qualif = Let | LetRec | And | Val | AssumeVal +let fun_decl_qualif_name (qualif : fun_decl_qualif) : string = + match qualif with + | Let -> "let" + | LetRec -> "let rec" + | And -> "and" + | Val -> "val" + | AssumeVal -> "assume val" + (** Small helper to compute the name of an int type *) let fstar_int_name (int_ty : integer_type) = match int_ty with @@ -305,16 +313,16 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) (* Concatenate the elements *) String.concat "_" fname in - let fun_name (_fid : A.fun_id) (fname : fun_name) (num_rgs : int) + let fun_name (_fid : A.fun_id) (fname : fun_name) (is_global : bool) (num_rgs : int) (rg : region_group_info option) (filter_info : bool * int) : string = let fname = fun_name_to_snake_case fname in (* Compute the suffix *) - let suffix = default_fun_suffix num_rgs rg filter_info in + let suffix = default_fun_suffix is_global num_rgs rg filter_info in (* Concatenate *) fname ^ suffix in - let decreases_clause_name (_fid : FunDeclId.id) (fname : fun_name) : string = + let decreases_clause_name (_fid : A.FunDeclId.id) (fname : fun_name) : string = let fname = fun_name_to_snake_case fname in (* Compute the suffix *) let suffix = "_decreases" in @@ -898,10 +906,15 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) match qualif.id with | Func fun_id -> extract_function_call ctx fmt inside fun_id qualif.type_args args + | Global global_id -> + let fid = A.global_to_fun_id ctx.trans_ctx.fun_context.gid_conv global_id in + let fun_id = Regular (A.Regular fid, None) in + extract_function_call ctx fmt inside fun_id qualif.type_args args | AdtCons adt_cons_id -> extract_adt_cons ctx fmt inside adt_cons_id qualif.type_args args | Proj proj -> - extract_field_projector ctx fmt inside app proj qualif.type_args args) + extract_field_projector ctx fmt inside app proj qualif.type_args args + ) | _ -> (* "Regular" expression *) (* Open parentheses *) @@ -1355,14 +1368,7 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) F.pp_open_hovbox fmt ctx.indent_incr; (* > "let FUN_NAME" *) let is_opaque = Option.is_none def.body in - let qualif = - match qualif with - | Let -> "let" - | LetRec -> "let rec" - | And -> "and" - | Val -> "val" - | AssumeVal -> "assume val" - in + let qualif = fun_decl_qualif_name qualif in F.pp_print_string fmt (qualif ^ " " ^ def_name); F.pp_print_space fmt (); (* Open a box for "(PARAMS) : EFFECT =" *) @@ -1471,6 +1477,108 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) (* Add breaks to insert new lines between definitions *) F.pp_print_break fmt 0 0 +(* Change the suffix from "_c" to "_body" *) +let global_decl_to_body_name (decl : string) : string = + (* The declaration length without the global suffix *) + let base_len = String.length decl - 2 in + (* TODO: Use String.ends_with instead when a more recent version of OCaml is used *) + assert (String.sub decl base_len 2 = "_c"); + (String.sub decl 0 base_len) ^ "_body" + +(** Print a definition of the shape "QUALIF NAME : TYPE = BODY" with a custom body extractor *) +let extract_global_definition (ctx : extraction_ctx) (fmt : F.formatter) + (qualif : fun_decl_qualif) (name : string) (ty : ty) + (extract_body : (F.formatter -> unit) Option.t) + : unit = + let is_opaque = Option.is_none extract_body in + + (* Open the definition box (depth=0) *) + F.pp_open_hvbox fmt ctx.indent_incr; + + (* Open "QUALIF NAME : TYPE =" box (depth=1) *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Print "QUALIF NAME " *) + F.pp_print_string fmt (fun_decl_qualif_name qualif ^ " " ^ name); + F.pp_print_space fmt (); + + (* Open ": TYPE =" box (depth=2) *) + F.pp_open_hvbox fmt 0; + (* Print ": " *) + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + + (* Open "TYPE" box (depth=3) *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Print "TYPE" *) + extract_ty ctx fmt false ty; + (* Close "TYPE" box (depth=3) *) + F.pp_close_box fmt (); + + if not is_opaque then ( + (* Print " =" *) + F.pp_print_space fmt (); + F.pp_print_string fmt "="; + ); + (* Close ": TYPE =" box (depth=2) *) + F.pp_close_box fmt (); + (* Close "QUALIF NAME : TYPE =" box (depth=1) *) + F.pp_close_box fmt (); + + if not is_opaque then ( + F.pp_print_space fmt (); + (* Open "BODY" box (depth=1) *) + F.pp_open_hvbox fmt 0; + (* Print "BODY" *) + (Option.get extract_body) fmt; + (* Close "BODY" box (depth=1) *) + F.pp_close_box fmt () + ); + (* Close the definition box (depth=0) *) + F.pp_close_box fmt () + +(** Extract a global declaration. + This has similarity with the function extraction above (without parameters). + However, generate its body separately from its declaration to extract the result value. + + For example, + `let x = 3` + + will be translated to + `let x_body : result int = Return 3` + `let x_c : int = eval_global x_body` + *) +let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) + (qualif : fun_decl_qualif) (def : fun_decl) + : unit = + (* Sanity checks for globals *) + assert (def.is_global); + assert (Option.is_none def.back_id); + assert (List.length def.signature.inputs = 0); + assert (List.length def.signature.doutputs = 1); + assert (List.length def.signature.type_params = 0); + assert (not def.signature.info.effect_info.can_fail); + + (* Add a break then the corresponding Rust definition *) + F.pp_print_break fmt 0 0; + F.pp_print_string fmt ("(** [" ^ Print.fun_name_to_string def.basename ^ "] *)"); + F.pp_print_space fmt (); + + let def_name = ctx_get_local_function def.def_id def.back_id ctx in + match def.body with + | None -> + extract_global_definition ctx fmt qualif def_name def.signature.output None + | Some body -> + let body_name = global_decl_to_body_name def_name in + let body_ty = mk_result_ty def.signature.output in + extract_global_definition ctx fmt qualif body_name body_ty (Some (fun fmt -> + extract_texpression ctx fmt false body.body + )); + F.pp_print_break fmt 0 0; + extract_global_definition ctx fmt qualif def_name def.signature.output (Some (fun fmt -> + F.pp_print_string fmt ("eval_global " ^ body_name) + )); + F.pp_print_break fmt 0 0 + (** Extract a unit test, if the function is a unit function (takes no parameters, returns unit). diff --git a/src/FunsAnalysis.ml b/src/FunsAnalysis.ml index b8dd17d8..ee4f71c1 100644 --- a/src/FunsAnalysis.ml +++ b/src/FunsAnalysis.ml @@ -1,6 +1,6 @@ (** Compute various information, including: - can a function fail (by having `Fail` in its body, or transitively - calling a function which can fail) + calling a function which can fail), false for globals - can a function diverge (bu being recursive, containing a loop or transitively calling a function which can diverge) - does a function perform stateful operations (i.e., do we need a state @@ -50,52 +50,65 @@ let analyze_module (m : llbc_module) (funs_map : fun_decl FunDeclId.Map.t) let stateful = ref false in let divergent = ref false in - let obj = - object - inherit [_] iter_statement as super - - method! visit_Assert env a = - can_fail := true; - super#visit_Assert env a - - method! visit_rvalue _env rv = - match rv with - | Use _ | Ref _ | Discriminant _ | Aggregate _ -> () - | UnaryOp (uop, _) -> can_fail := EU.unop_can_fail uop || !can_fail - | BinaryOp (bop, _, _) -> - can_fail := EU.binop_can_fail bop || !can_fail - - method! visit_Call env call = - (match call.func with - | Regular id -> - if FunDeclId.Set.mem id fun_ids then divergent := true - else - let info = FunDeclId.Map.find id !infos in - can_fail := !can_fail || info.can_fail; - stateful := !stateful || info.stateful; - divergent := !divergent || info.divergent - | Assumed id -> - (* None of the assumed functions is stateful for now *) - can_fail := !can_fail || Assumed.assumed_can_fail id); - super#visit_Call env call - - method! visit_Panic env = - can_fail := true; - super#visit_Panic env - - method! visit_Loop env loop = - divergent := true; - super#visit_Loop env loop - end - in - let visit_fun (f : fun_decl) : unit = - match f.body with + let obj = + object (self) + inherit [_] iter_statement as super + + method may_fail b = + (* The fail flag is disabled for globals : the global body is + * normalised into its declaration, which is always successful. + *) + if f.is_global then () else + can_fail := !can_fail || b + + method! visit_Assert env a = + self#may_fail true; + super#visit_Assert env a + + method! visit_rvalue _env rv = + match rv with + | Use _ | Ref _ | Discriminant _ | Aggregate _ -> () + | UnaryOp (uop, _) -> can_fail := EU.unop_can_fail uop || !can_fail + | BinaryOp (bop, _, _) -> + can_fail := EU.binop_can_fail bop || !can_fail + + method! visit_Call env call = + pp_fun_id Format.std_formatter call.func; + print_newline (); + + (match call.func with + | Regular id -> + if FunDeclId.Set.mem id fun_ids then divergent := true + else + let info = FunDeclId.Map.find id !infos in + self#may_fail info.can_fail; + stateful := !stateful || info.stateful; + divergent := !divergent || info.divergent + | Assumed id -> + (* None of the assumed functions is stateful for now *) + can_fail := !can_fail || Assumed.assumed_can_fail id); + super#visit_Call env call + + method! visit_Panic env = + self#may_fail true; + super#visit_Panic env + + method! visit_Loop env loop = + divergent := true; + super#visit_Loop env loop + end + in + (match f.body with | None -> (* Opaque function *) - can_fail := true; + obj#may_fail true; stateful := use_state - | Some body -> obj#visit_statement () body.body + | Some body -> obj#visit_statement () body.body); + (* We ignore on purpose functions that cannot fail: the result of the analysis + * is not used yet to adjust the translation so that the functions which + * syntactically can't fail don't use an error monad. *) + can_fail := not f.is_global in List.iter visit_fun d; { can_fail = !can_fail; stateful = !stateful; divergent = !divergent } diff --git a/src/Interpreter.ml b/src/Interpreter.ml index cbbf2b2e..f4f01ff8 100644 --- a/src/Interpreter.ml +++ b/src/Interpreter.ml @@ -24,7 +24,7 @@ let compute_type_fun_contexts (m : M.llbc_module) : TypesAnalysis.analyze_type_declarations type_decls type_decls_list in let type_context = { C.type_decls_groups; type_decls; type_infos } in - let fun_context = { C.fun_decls } in + let fun_context = { C.fun_decls; gid_conv = m.gid_conv } in (type_context, fun_context) let initialize_eval_context (type_context : C.type_context) diff --git a/src/InterpreterExpressions.ml b/src/InterpreterExpressions.ml index 6bb2baf0..04ad1b3c 100644 --- a/src/InterpreterExpressions.ml +++ b/src/InterpreterExpressions.ml @@ -1,11 +1,13 @@ module T = Types module V = Values +module LA = LlbcAst open Scalars module E = Expressions open Errors module C = Contexts module Subst = Substitute module L = Logging +module PV = Print.Values open TypesUtils open ValuesUtils module Inv = Invariants @@ -108,52 +110,27 @@ let access_rplace_reorganize (config : C.config) (expand_prim_copy : bool) ctx (** Convert an operand constant operand value to a typed value *) -let rec operand_constant_value_to_typed_value (ctx : C.eval_ctx) (ty : T.ety) - (cv : E.operand_constant_value) : V.typed_value = +let typecheck_constant_value (ty : T.ety) + (cv : V.constant_value) : V.typed_value = (* Check the type while converting - we actually need some information - * contained in the type *) + * contained in the type *) log#ldebug (lazy - ("operand_constant_value_to_typed_value:" ^ "\n- ty: " - ^ ety_to_string ctx ty ^ "\n- cv: " - ^ operand_constant_value_to_string ctx cv)); + ("typecheck_constant_value:" ^ "\n- cv: " + ^ PV.constant_value_to_string cv)); match (ty, cv) with - (* Adt *) - | ( T.Adt (adt_id, region_params, type_params), - ConstantAdt (variant_id, field_values) ) -> - assert (region_params = []); - (* Compute the types of the fields *) - let field_tys = - match adt_id with - | T.AdtId def_id -> - let def = C.ctx_lookup_type_decl ctx def_id in - assert (def.region_params = []); - Subst.type_decl_get_instantiated_field_etypes def variant_id - type_params - | T.Tuple -> type_params - | T.Assumed _ -> failwith "Unreachable" - in - (* Compute the field values *) - let field_values = - List.map - (fun (ty, v) -> operand_constant_value_to_typed_value ctx ty v) - (List.combine field_tys field_values) - in - (* Put together *) - let value = V.Adt { variant_id; field_values } in - { value; ty } (* Scalar, boolean... *) - | T.Bool, ConstantValue (Bool v) -> { V.value = V.Concrete (Bool v); ty } - | T.Char, ConstantValue (Char v) -> { V.value = V.Concrete (Char v); ty } - | T.Str, ConstantValue (String v) -> { V.value = V.Concrete (String v); ty } - | T.Integer int_ty, ConstantValue (V.Scalar v) -> + | T.Bool, Bool v -> { V.value = V.Concrete (Bool v); ty } + | T.Char, Char v -> { V.value = V.Concrete (Char v); ty } + | T.Str, String v -> { V.value = V.Concrete (String v); ty } + | T.Integer int_ty, V.Scalar v -> (* Check the type and the ranges *) assert (int_ty = v.int_ty); assert (check_scalar_value_in_range v); { V.value = V.Concrete (V.Scalar v); ty } (* Remaining cases (invalid) - listing as much as we can on purpose - (allows to catch errors at compilation if the definitions change) *) - | _, ConstantAdt _ | _, ConstantValue _ -> + (allows to catch errors at compilation if the definitions change) *) + | _, _ -> failwith "Improperly typed constant value" (** Reorganize the environment in preparation for the evaluation of an operand. @@ -197,8 +174,8 @@ let prepare_eval_operand_reorganize (config : C.config) (op : E.operand) : let prepare : cm_fun = fun cf ctx -> match op with - | Expressions.Constant _ -> - (* No need to reorganize the context *) + | Expressions.Constant (ty, cv) -> + typecheck_constant_value ty cv |> ignore; cf ctx | Expressions.Copy p -> (* Access the value *) @@ -226,9 +203,7 @@ let eval_operand_no_reorganize (config : C.config) (op : E.operand) ^ "\n- ctx:\n" ^ eval_ctx_to_string ctx ^ "\n")); (* Evaluate *) match op with - | Expressions.Constant (ty, cv) -> - let v = operand_constant_value_to_typed_value ctx ty cv in - cf v ctx + | Expressions.Constant (ty, cv) -> cf (typecheck_constant_value ty cv) ctx | Expressions.Copy p -> (* Access the value *) let access = Read in diff --git a/src/InterpreterStatements.ml b/src/InterpreterStatements.ml index 585fa828..8f981174 100644 --- a/src/InterpreterStatements.ml +++ b/src/InterpreterStatements.ml @@ -831,6 +831,15 @@ let rec eval_statement (config : C.config) (st : A.statement) : st_cm_fun = (* Compose and apply *) comp cf_eval_rvalue cf_assign cf ctx + | A.AssignGlobal { dst; global } -> + let call : A.call = { + func = A.Regular (A.global_to_fun_id ctx.fun_context.gid_conv global); + region_args = []; + type_args = []; + args = []; + dest = { var_id = dst; projection = [] }; + } in + eval_function_call config call cf ctx | A.FakeRead p -> let expand_prim_copy = false in let cf_prepare cf = diff --git a/src/InterpreterUtils.ml b/src/InterpreterUtils.ml index 7a2e22f7..47323cc2 100644 --- a/src/InterpreterUtils.ml +++ b/src/InterpreterUtils.ml @@ -33,8 +33,6 @@ let typed_value_to_string = PA.typed_value_to_string let typed_avalue_to_string = PA.typed_avalue_to_string -let operand_constant_value_to_string = PA.operand_constant_value_to_string - let place_to_string = PA.place_to_string let operand_to_string = PA.operand_to_string diff --git a/src/LlbcAst.ml b/src/LlbcAst.ml index d35cd5d8..16733e20 100644 --- a/src/LlbcAst.ml +++ b/src/LlbcAst.ml @@ -1,10 +1,22 @@ -open Identifiers open Names open Types open Values open Expressions +open Identifiers module FunDeclId = IdGen () +module GlobalDeclId = IdGen () + +(* Strict type for the number of function declarations (see [global_to_fun_id] below) *) +type global_id_converter = { fun_count : int } +[@@deriving show] + +(** Converts a global id to its corresponding function id. + To do so, it adds the global id to the number of function declarations : + We have the bijection `global_id <=> fun_id + fun_id_count`. +*) +let global_to_fun_id (conv : global_id_converter) (gid : GlobalDeclId.id) : FunDeclId.id = + FunDeclId.of_int ((GlobalDeclId.to_int gid) + conv.fun_count) type var = { index : VarId.id; (** Unique variable identifier *) @@ -36,6 +48,12 @@ type assumed_fun_id = type fun_id = Regular of FunDeclId.id | Assumed of assumed_fun_id [@@deriving show, ord] +type assign_global = { + dst : VarId.id; + global : GlobalDeclId.id; +} +[@@deriving show] + type assertion = { cond : operand; expected : bool } [@@deriving show] type abs_region_group = (AbstractionId.id, RegionId.id) g_region_group @@ -77,6 +95,8 @@ class ['self] iter_statement_base = object (_self : 'self) inherit [_] VisitorsRuntime.iter + method visit_assign_global : 'env -> assign_global -> unit = fun _ _ -> () + method visit_place : 'env -> place -> unit = fun _ _ -> () method visit_rvalue : 'env -> rvalue -> unit = fun _ _ -> () @@ -99,6 +119,8 @@ class ['self] map_statement_base = object (_self : 'self) inherit [_] VisitorsRuntime.map + method visit_assign_global : 'env -> assign_global -> assign_global = fun _ x -> x + method visit_place : 'env -> place -> place = fun _ x -> x method visit_rvalue : 'env -> rvalue -> rvalue = fun _ x -> x @@ -120,6 +142,7 @@ class ['self] map_statement_base = type statement = | Assign of place * rvalue + | AssignGlobal of assign_global | FakeRead of place | SetDiscriminant of place * VariantId.id | Drop of place @@ -178,5 +201,6 @@ type fun_decl = { name : fun_name; signature : fun_sig; body : fun_body option; + is_global : bool; } [@@deriving show] diff --git a/src/LlbcOfJson.ml b/src/LlbcOfJson.ml index 99d652ec..3ff45077 100644 --- a/src/LlbcOfJson.ml +++ b/src/LlbcOfJson.ml @@ -17,6 +17,8 @@ module S = Scalars module M = Modules module E = Expressions module A = LlbcAst +module TU = TypesUtils +module AU = LlbcAstUtils (* The default logger *) let log = Logging.llbc_of_json_logger @@ -298,23 +300,6 @@ let scalar_value_of_json (js : json) : (V.scalar_value, string) result = raise (Failure ("Scalar value not in range: " ^ V.show_scalar_value sv))); res -let constant_value_of_json (js : json) : (V.constant_value, string) result = - combine_error_msgs js "constant_value_of_json" - (match js with - | `Assoc [ ("Scalar", scalar_value) ] -> - let* scalar_value = scalar_value_of_json scalar_value in - Ok (V.Scalar scalar_value) - | `Assoc [ ("Bool", v) ] -> - let* v = bool_of_json v in - Ok (V.Bool v) - | `Assoc [ ("Char", v) ] -> - let* v = char_of_json v in - Ok (V.Char v) - | `Assoc [ ("String", v) ] -> - let* v = string_of_json v in - Ok (V.String v) - | _ -> Error "") - let field_proj_kind_of_json (js : json) : (E.field_proj_kind, string) result = combine_error_msgs js "field_proj_kind_of_json" (match js with @@ -393,21 +378,23 @@ let binop_of_json (js : json) : (E.binop, string) result = | `String "Shr" -> Ok E.Shr | _ -> Error ("binop_of_json failed on:" ^ show js) -let rec operand_constant_value_of_json (js : json) : - (E.operand_constant_value, string) result = - combine_error_msgs js "operand_constant_value_of_json" +let constant_value_of_json (js : json) : (V.constant_value, string) result = + combine_error_msgs js "constant_value_of_json" (match js with - | `Assoc [ ("ConstantValue", `List [ cv ]) ] -> - let* cv = constant_value_of_json cv in - Ok (E.ConstantValue cv) - | `Assoc [ ("ConstantAdt", `List [ variant_id; field_values ]) ] -> - let* variant_id = option_of_json T.VariantId.id_of_json variant_id in - let* field_values = - list_of_json operand_constant_value_of_json field_values - in - Ok (E.ConstantAdt (variant_id, field_values)) + | `Assoc [ ("Scalar", scalar_value) ] -> + let* scalar_value = scalar_value_of_json scalar_value in + Ok (V.Scalar scalar_value) + | `Assoc [ ("Bool", v) ] -> + let* v = bool_of_json v in + Ok (V.Bool v) + | `Assoc [ ("Char", v) ] -> + let* v = char_of_json v in + Ok (V.Char v) + | `Assoc [ ("String", v) ] -> + let* v = string_of_json v in + Ok (V.String v) | _ -> Error "") - + let operand_of_json (js : json) : (E.operand, string) result = combine_error_msgs js "operand_of_json" (match js with @@ -417,9 +404,9 @@ let operand_of_json (js : json) : (E.operand, string) result = | `Assoc [ ("Move", place) ] -> let* place = place_of_json place in Ok (E.Move place) - | `Assoc [ ("Constant", `List [ ty; cv ]) ] -> + | `Assoc [ ("Const", `List [ ty; cv ]) ] -> let* ty = ety_of_json ty in - let* cv = operand_constant_value_of_json cv in + let* cv = constant_value_of_json cv in Ok (E.Constant (ty, cv)) | _ -> Error "") @@ -553,13 +540,17 @@ let call_of_json (js : json) : (A.call, string) result = Ok { A.func; region_args; type_args; args; dest } | _ -> Error "") -let rec statement_of_json (js : json) : (A.statement, string) result = +let rec statement_of_json (js : json) (gid_conv : A.global_id_converter) : (A.statement, string) result = combine_error_msgs js "statement_of_json" (match js with | `Assoc [ ("Assign", `List [ place; rvalue ]) ] -> let* place = place_of_json place in let* rvalue = rvalue_of_json rvalue in Ok (A.Assign (place, rvalue)) + | `Assoc [ ("AssignGlobal", `List [ dst; global ]) ] -> + let* dst = V.VarId.id_of_json dst in + let* global = A.GlobalDeclId.id_of_json global in + Ok (A.AssignGlobal { dst; global }) | `Assoc [ ("FakeRead", place) ] -> let* place = place_of_json place in Ok (A.FakeRead place) @@ -586,47 +577,48 @@ let rec statement_of_json (js : json) : (A.statement, string) result = Ok (A.Continue i) | `String "Nop" -> Ok A.Nop | `Assoc [ ("Sequence", `List [ st1; st2 ]) ] -> - let* st1 = statement_of_json st1 in - let* st2 = statement_of_json st2 in + let* st1 = statement_of_json st1 gid_conv in + let* st2 = statement_of_json st2 gid_conv in Ok (A.Sequence (st1, st2)) | `Assoc [ ("Switch", `List [ op; tgt ]) ] -> let* op = operand_of_json op in - let* tgt = switch_targets_of_json tgt in + let* tgt = switch_targets_of_json tgt gid_conv in Ok (A.Switch (op, tgt)) | `Assoc [ ("Loop", st) ] -> - let* st = statement_of_json st in + let* st = statement_of_json st gid_conv in Ok (A.Loop st) | _ -> Error "") -and switch_targets_of_json (js : json) : (A.switch_targets, string) result = +and switch_targets_of_json (js : json) (gid_conv : A.global_id_converter) : (A.switch_targets, string) result = combine_error_msgs js "switch_targets_of_json" (match js with | `Assoc [ ("If", `List [ st1; st2 ]) ] -> - let* st1 = statement_of_json st1 in - let* st2 = statement_of_json st2 in + let* st1 = statement_of_json st1 gid_conv in + let* st2 = statement_of_json st2 gid_conv in Ok (A.If (st1, st2)) | `Assoc [ ("SwitchInt", `List [ int_ty; tgts; otherwise ]) ] -> let* int_ty = integer_type_of_json int_ty in let* tgts = - list_of_json - (pair_of_json (list_of_json scalar_value_of_json) statement_of_json) + list_of_json (pair_of_json + (list_of_json scalar_value_of_json) + (fun js -> statement_of_json js gid_conv)) tgts in - let* otherwise = statement_of_json otherwise in + let* otherwise = statement_of_json otherwise gid_conv in Ok (A.SwitchInt (int_ty, tgts, otherwise)) | _ -> Error "") -let fun_body_of_json (js : json) : (A.fun_body, string) result = +let fun_body_of_json (js : json) (gid_conv : A.global_id_converter) : (A.fun_body, string) result = combine_error_msgs js "fun_body_of_json" (match js with | `Assoc [ ("arg_count", arg_count); ("locals", locals); ("body", body) ] -> let* arg_count = int_of_json arg_count in let* locals = list_of_json var_of_json locals in - let* body = statement_of_json body in + let* body = statement_of_json body gid_conv in Ok { A.arg_count; locals; body } | _ -> Error "") -let fun_decl_of_json (js : json) : (A.fun_decl, string) result = +let fun_decl_of_json (js : json) (gid_conv : A.global_id_converter) : (A.fun_decl, string) result = combine_error_msgs js "fun_decl_of_json" (match js with | `Assoc @@ -639,8 +631,36 @@ let fun_decl_of_json (js : json) : (A.fun_decl, string) result = let* def_id = A.FunDeclId.id_of_json def_id in let* name = fun_name_of_json name in let* signature = fun_sig_of_json signature in - let* body = option_of_json fun_body_of_json body in - Ok { A.def_id; name; signature; body } + let* body = option_of_json (fun js -> fun_body_of_json js gid_conv) body in + Ok { A.def_id; name; signature; body; is_global = false; } + | _ -> Error "") + +(* Converts a global declaration to a function declaration. + *) +let global_decl_of_json (js : json) (gid_conv : A.global_id_converter) : (A.fun_decl, string) result = + combine_error_msgs js "global_decl_of_json" + (match js with + | `Assoc + [ + ("def_id", def_id); + ("name", name); + ("type_", type_); + ("body", body); + ] -> + let* global_id = A.GlobalDeclId.id_of_json def_id in + let def_id = A.global_to_fun_id gid_conv global_id in + let* name = fun_name_of_json name in + let* type_ = ety_of_json type_ in + let* body = option_of_json (fun js -> fun_body_of_json js gid_conv) body in + let signature : A.fun_sig = { + region_params = []; + num_early_bound_regions = 0; + regions_hierarchy = []; + type_params = []; + inputs = []; + output = TU.ety_no_regions_to_sty type_; + } in + Ok { A.def_id; name; signature; body; is_global = true; } | _ -> Error "") let g_declaration_group_of_json (id_of_json : json -> ('id, string) result) @@ -665,7 +685,15 @@ let fun_declaration_group_of_json (js : json) : combine_error_msgs js "fun_declaration_group_of_json" (g_declaration_group_of_json A.FunDeclId.id_of_json js) -let declaration_group_of_json (js : json) : (M.declaration_group, string) result +let global_declaration_group_of_json (js : json) (gid_conv : A.global_id_converter) : + (M.fun_declaration_group, string) result = + combine_error_msgs js "global_declaration_group_of_json" + (g_declaration_group_of_json (fun js -> + let* id = A.GlobalDeclId.id_of_json js in + Ok (A.global_to_fun_id gid_conv id) + ) js) + +let declaration_group_of_json (js : json) (gid_conv : A.global_id_converter) : (M.declaration_group, string) result = combine_error_msgs js "declaration_of_json" (match js with @@ -675,8 +703,17 @@ let declaration_group_of_json (js : json) : (M.declaration_group, string) result | `Assoc [ ("Fun", `List [ decl ]) ] -> let* decl = fun_declaration_group_of_json decl in Ok (M.Fun decl) + | `Assoc [ ("Global", `List [ decl ]) ] -> + let* decl = global_declaration_group_of_json decl gid_conv in + Ok (M.Fun decl) | _ -> Error "") +let length_of_json_list (js: json) : (int, string) result = + combine_error_msgs js "get_json_list_len" + (match js with + | `List jsl -> Ok (List.length jsl) + | _ -> Error ("not a list: " ^ show js)) + let llbc_module_of_json (js : json) : (M.llbc_module, string) result = combine_error_msgs js "llbc_module_of_json" (match js with @@ -686,12 +723,22 @@ let llbc_module_of_json (js : json) : (M.llbc_module, string) result = ("declarations", declarations); ("types", types); ("functions", functions); + ("globals", globals); ] -> + let* fun_count = length_of_json_list functions in + let gid_conv = { A.fun_count = fun_count } in let* name = string_of_json name in let* declarations = - list_of_json declaration_group_of_json declarations + list_of_json (fun js -> declaration_group_of_json js gid_conv) declarations in let* types = list_of_json type_decl_of_json types in - let* functions = list_of_json fun_decl_of_json functions in - Ok { M.name; declarations; types; functions } + let* functions = list_of_json (fun js -> fun_decl_of_json js gid_conv) functions in + let* globals = list_of_json (fun js -> global_decl_of_json js gid_conv) globals in + Ok { + M.name; + declarations; + types; + functions = functions @ globals; + gid_conv; + } | _ -> Error "") diff --git a/src/Modules.ml b/src/Modules.ml index f52983c6..149de020 100644 --- a/src/Modules.ml +++ b/src/Modules.ml @@ -7,7 +7,8 @@ type 'id g_declaration_group = NonRec of 'id | Rec of 'id list type type_declaration_group = TypeDeclId.id g_declaration_group [@@deriving show] -type fun_declaration_group = FunDeclId.id g_declaration_group [@@deriving show] +type fun_declaration_group = FunDeclId.id g_declaration_group +[@@deriving show] (** Module declaration *) type declaration_group = @@ -20,6 +21,7 @@ type llbc_module = { declarations : declaration_group list; types : type_decl list; functions : fun_decl list; + gid_conv : global_id_converter; } (** LLBC module - TODO: rename to crate *) diff --git a/src/Print.ml b/src/Print.ml index af6fc982..337116ec 100644 --- a/src/Print.ml +++ b/src/Print.ml @@ -686,6 +686,7 @@ module LlbcAst = struct adt_field_names : T.TypeDeclId.id -> T.VariantId.id option -> string list option; fun_decl_id_to_string : A.FunDeclId.id -> string; + global_decl_id_to_string : A.GlobalDeclId.id -> string; } let ast_to_ctx_formatter (fmt : ast_formatter) : PC.ctx_formatter = @@ -742,6 +743,9 @@ module LlbcAst = struct let def = C.ctx_lookup_fun_decl ctx def_id in fun_name_to_string def.name in + let global_decl_id_to_string def_id = + fun_decl_id_to_string (A.global_to_fun_id ctx.fun_context.gid_conv def_id) + in { rvar_to_string = ctx_fmt.PV.rvar_to_string; r_to_string = ctx_fmt.PV.r_to_string; @@ -752,10 +756,14 @@ module LlbcAst = struct adt_field_names = ctx_fmt.PV.adt_field_names; adt_field_to_string; fun_decl_id_to_string; + global_decl_id_to_string; } - let fun_decl_to_ast_formatter (type_decls : T.type_decl T.TypeDeclId.Map.t) - (fun_decls : A.fun_decl A.FunDeclId.Map.t) (fdef : A.fun_decl) : + let fun_decl_to_ast_formatter + (type_decls : T.type_decl T.TypeDeclId.Map.t) + (fun_decls : A.fun_decl A.FunDeclId.Map.t) + (global_to_fun_id : A.GlobalDeclId.id -> A.FunDeclId.id) + (fdef : A.fun_decl) : ast_formatter = let rvar_to_string r = let rvar = T.RegionVarId.nth fdef.signature.region_params r in @@ -784,6 +792,9 @@ module LlbcAst = struct let def = A.FunDeclId.Map.find def_id fun_decls in fun_name_to_string def.name in + let global_decl_id_to_string def_id = + fun_decl_id_to_string (global_to_fun_id def_id) + in { rvar_to_string; r_to_string; @@ -794,6 +805,7 @@ module LlbcAst = struct adt_field_names; adt_field_to_string; fun_decl_id_to_string; + global_decl_id_to_string; } let rec projection_to_string (fmt : ast_formatter) (inside : string) @@ -859,35 +871,13 @@ module LlbcAst = struct | E.Shl -> "<<" | E.Shr -> ">>" - let rec operand_constant_value_to_string (fmt : ast_formatter) - (cv : E.operand_constant_value) : string = - match cv with - | E.ConstantValue cv -> PV.constant_value_to_string cv - | E.ConstantAdt (variant_id, field_values) -> - (* This is a bit annoying, because we don't have context information - * to convert the ADT to a value, so we do the best we can in the - * simplest manner. Anyway, those printing utilitites are only used - * for debugging, and complex constant values are not common. - * We might want to store type information in the operand constant values - * in the future. - *) - let variant_id = option_to_string T.VariantId.to_string variant_id in - let field_values = - List.map (operand_constant_value_to_string fmt) field_values - in - "ConstantAdt " ^ variant_id ^ " {" - ^ String.concat ", " field_values - ^ "}" - let operand_to_string (fmt : ast_formatter) (op : E.operand) : string = match op with | E.Copy p -> "copy " ^ place_to_string fmt p | E.Move p -> "move " ^ place_to_string fmt p | E.Constant (ty, cv) -> - (* For clarity, we also print the typing information: see the comment in - * [operand_constant_value_to_string] *) "(" - ^ operand_constant_value_to_string fmt cv + ^ PV.constant_value_to_string cv ^ " : " ^ PT.ety_to_string (ast_to_etype_formatter fmt) ty ^ ")" @@ -948,6 +938,8 @@ module LlbcAst = struct match st with | A.Assign (p, rv) -> indent ^ place_to_string fmt p ^ " := " ^ rvalue_to_string fmt rv + | A.AssignGlobal { dst; global } -> + indent ^ fmt.var_id_to_string dst ^ " := global " ^ fmt.global_decl_id_to_string global | A.FakeRead p -> indent ^ "fake_read " ^ place_to_string fmt p | A.SetDiscriminant (p, variant_id) -> (* TODO: improve this to lookup the variant name by using the def id *) @@ -1137,8 +1129,11 @@ module Module = struct (** Generate an [ast_formatter] by using a definition context in combination with the variables local to a function's definition *) - let def_ctx_to_ast_formatter (type_context : T.type_decl T.TypeDeclId.Map.t) - (fun_context : A.fun_decl A.FunDeclId.Map.t) (def : A.fun_decl) : + let def_ctx_to_ast_formatter + (type_context : T.type_decl T.TypeDeclId.Map.t) + (fun_context : A.fun_decl A.FunDeclId.Map.t) + (global_to_fun_id : A.GlobalDeclId.id -> A.FunDeclId.id) + (def : A.fun_decl) : PA.ast_formatter = let rvar_to_string vid = let var = T.RegionVarId.nth def.signature.region_params vid in @@ -1160,6 +1155,9 @@ module Module = struct let def = A.FunDeclId.Map.find def_id fun_context in fun_name_to_string def.name in + let global_decl_id_to_string def_id = + fun_decl_id_to_string (global_to_fun_id def_id) + in let var_id_to_string vid = let var = V.VarId.nth (Option.get def.body).locals vid in PA.var_to_string var @@ -1181,13 +1179,17 @@ module Module = struct var_id_to_string; adt_field_names; fun_decl_id_to_string; + global_decl_id_to_string; } (** This function pretty-prints a function definition by using a definition context *) - let fun_decl_to_string (type_context : T.type_decl T.TypeDeclId.Map.t) - (fun_context : A.fun_decl A.FunDeclId.Map.t) (def : A.fun_decl) : string = - let fmt = def_ctx_to_ast_formatter type_context fun_context def in + let fun_decl_to_string + (type_context : T.type_decl T.TypeDeclId.Map.t) + (fun_context : A.fun_decl A.FunDeclId.Map.t) + (global_to_fun_id : A.GlobalDeclId.id -> A.FunDeclId.id) + (def : A.fun_decl) : string = + let fmt = def_ctx_to_ast_formatter type_context fun_context global_to_fun_id def in PA.fun_decl_to_string fmt "" " " def let module_to_string (m : M.llbc_module) : string = @@ -1198,7 +1200,8 @@ module Module = struct (* The functions *) let fun_decls = - List.map (fun_decl_to_string types_defs_map funs_defs_map) m.M.functions + let gid_to_fid = fun gid -> A.global_to_fun_id m.gid_conv gid in + List.map (fun_decl_to_string types_defs_map funs_defs_map gid_to_fid) m.M.functions in (* Put everything together *) @@ -1255,11 +1258,6 @@ module EvalCtxLlbcAst = struct let fmt = PC.eval_ctx_to_ctx_formatter ctx in PV.typed_avalue_to_string fmt v - let operand_constant_value_to_string (ctx : C.eval_ctx) - (cv : E.operand_constant_value) : string = - let fmt = PA.eval_ctx_to_ast_formatter ctx in - PA.operand_constant_value_to_string fmt cv - let place_to_string (ctx : C.eval_ctx) (op : E.place) : string = let fmt = PA.eval_ctx_to_ast_formatter ctx in PA.place_to_string fmt op diff --git a/src/PrintPure.ml b/src/PrintPure.ml index 5e817dde..c13f967f 100644 --- a/src/PrintPure.ml +++ b/src/PrintPure.ml @@ -12,7 +12,6 @@ module RegionId = T.RegionId module VariantId = T.VariantId module FieldId = T.FieldId module SymbolicValueId = V.SymbolicValueId -module FunDeclId = A.FunDeclId type type_formatter = { type_var_id_to_string : TypeVarId.id -> string; @@ -45,6 +44,7 @@ type ast_formatter = { TypeDeclId.id -> VariantId.id option -> FieldId.id -> string option; adt_field_names : TypeDeclId.id -> VariantId.id option -> string list option; fun_decl_id_to_string : A.FunDeclId.id -> string; + global_decl_id_to_string : A.GlobalDeclId.id -> string; } let ast_to_value_formatter (fmt : ast_formatter) : value_formatter = @@ -86,7 +86,9 @@ let mk_type_formatter (type_decls : T.type_decl TypeDeclId.Map.t) while we only need those definitions to lookup proper names for the def ids. *) let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t) - (fun_decls : A.fun_decl FunDeclId.Map.t) (type_params : type_var list) : + (fun_decls : A.fun_decl A.FunDeclId.Map.t) + (gid_conv : A.global_id_converter) + (type_params : type_var list) : ast_formatter = let type_var_id_to_string vid = let var = T.TypeVarId.nth type_params vid in @@ -113,6 +115,9 @@ let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t) let def = A.FunDeclId.Map.find def_id fun_decls in fun_name_to_string def.name in + let global_decl_id_to_string def_id = + fun_decl_id_to_string (A.global_to_fun_id gid_conv def_id) + in { type_var_id_to_string; type_decl_id_to_string; @@ -121,6 +126,7 @@ let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t) adt_field_names; adt_field_to_string; fun_decl_id_to_string; + global_decl_id_to_string; } let type_id_to_string (fmt : type_formatter) (id : type_id) : string = @@ -481,6 +487,7 @@ and app_to_string (fmt : ast_formatter) (inside : bool) (indent : string) let qualif_s = match qualif.id with | Func fun_id -> fun_id_to_string fmt fun_id + | Global global_id -> fmt.global_decl_id_to_string global_id | AdtCons adt_cons_id -> let variant_s = adt_variant_to_string diff --git a/src/PrintSymbolicAst.ml b/src/PrintSymbolicAst.ml index 0ab68efc..e44b422a 100644 --- a/src/PrintSymbolicAst.ml +++ b/src/PrintSymbolicAst.ml @@ -7,6 +7,7 @@ open Errors open Identifiers +open FunIdentifier module T = Types module TU = TypesUtils module V = Values @@ -20,7 +21,7 @@ module PT = Print.Types type formatting_ctx = { type_context : C.type_context; - fun_context : A.fun_decl A.FunDeclId.Map.t; + fun_context : A.fun_decl FunDeclId.Map.t; type_vars : T.type_var list; } diff --git a/src/Pure.ml b/src/Pure.ml index 5834b87f..b3be2040 100644 --- a/src/Pure.ml +++ b/src/Pure.ml @@ -10,7 +10,6 @@ module RegionGroupId = T.RegionGroupId module VariantId = T.VariantId module FieldId = T.FieldId module SymbolicValueId = V.SymbolicValueId -module FunDeclId = A.FunDeclId module SynthPhaseId = IdGen () (** We give an identifier to every phase of the synthesis (forward, backward @@ -303,6 +302,7 @@ type projection = { adt_id : type_id; field_id : FieldId.id } [@@deriving show] type qualif_id = | Func of fun_id + | Global of A.GlobalDeclId.id | AdtCons of adt_cons_id (** A function or ADT constructor identifier *) | Proj of projection (** Field projector *) [@@deriving show] @@ -566,7 +566,7 @@ type fun_body = { } type fun_decl = { - def_id : FunDeclId.id; + def_id : A.FunDeclId.id; back_id : T.RegionGroupId.id option; basename : fun_name; (** The "base" name of the function. @@ -575,5 +575,6 @@ type fun_decl = { (to identify the forward/backward functions) later. *) signature : fun_sig; + is_global : bool; body : fun_body option; } diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index 826283ae..7927a068 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -611,7 +611,10 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) | Func (Unop _ | Binop _) -> true (* primitive function call *) | Func (Regular _) -> - false (* non-primitive function call *)) + false (* non-primitive function call *) + | Global _ -> + true (* Global constant or static *) + ) | _ -> filter else false in diff --git a/src/PureToExtract.ml b/src/PureToExtract.ml index 1c530011..e58fec2a 100644 --- a/src/PureToExtract.ml +++ b/src/PureToExtract.ml @@ -74,6 +74,7 @@ type formatter = { fun_name : A.fun_id -> fun_name -> + bool -> int -> region_group_info option -> bool * int -> @@ -92,7 +93,7 @@ type formatter = { (`None` if forward function) TODO: use the fun id for the assumed functions. *) - decreases_clause_name : FunDeclId.id -> fun_name -> string; + decreases_clause_name : A.FunDeclId.id -> fun_name -> string; (** Generates the name of the definition used to prove/reason about termination. The generated code uses this clause where needed, but its body must be defined by the user. @@ -356,7 +357,7 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = let fun_name = match fid with | A.Regular fid -> - Print.fun_name_to_string (FunDeclId.Map.find fid fun_decls).name + Print.fun_name_to_string (A.FunDeclId.Map.find fid fun_decls).name | A.Assumed aid -> A.show_assumed_fun_id aid in let fun_kind = @@ -369,7 +370,7 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = let fun_name = match fid with | A.Regular fid -> - Print.fun_name_to_string (FunDeclId.Map.find fid fun_decls).name + Print.fun_name_to_string (A.FunDeclId.Map.find fid fun_decls).name | A.Assumed aid -> A.show_assumed_fun_id aid in "decreases clause for function: " ^ fun_name @@ -440,11 +441,13 @@ let ctx_get (id : id) (ctx : extraction_ctx) : string = log#serror ("Could not find: " ^ id_to_string id ctx); raise Not_found -let ctx_get_function (id : A.fun_id) (rg : RegionGroupId.id option) +let ctx_get_function (id : A.fun_id) + (rg : RegionGroupId.id option) (ctx : extraction_ctx) : string = ctx_get (FunId (id, rg)) ctx -let ctx_get_local_function (id : FunDeclId.id) (rg : RegionGroupId.id option) +let ctx_get_local_function (id : A.FunDeclId.id) + (rg : RegionGroupId.id option) (ctx : extraction_ctx) : string = ctx_get_function (A.Regular id) rg ctx @@ -475,7 +478,7 @@ let ctx_get_variant (def_id : type_id) (variant_id : VariantId.id) (ctx : extraction_ctx) : string = ctx_get (VariantId (def_id, variant_id)) ctx -let ctx_get_decreases_clause (def_id : FunDeclId.id) (ctx : extraction_ctx) : +let ctx_get_decreases_clause (def_id : A.FunDeclId.id) (ctx : extraction_ctx) : string = ctx_get (DecreasesClauseId (A.Regular def_id)) ctx @@ -573,7 +576,7 @@ let ctx_add_fun_decl (trans_group : bool * pure_fun_translation) (* Lookup the LLBC def to compute the region group information *) let def_id = def.def_id in let llbc_def = - FunDeclId.Map.find def_id ctx.trans_ctx.fun_context.fun_decls + A.FunDeclId.Map.find def_id ctx.trans_ctx.fun_context.fun_decls in let sg = llbc_def.signature in let num_rgs = List.length sg.regions_hierarchy in @@ -596,7 +599,7 @@ let ctx_add_fun_decl (trans_group : bool * pure_fun_translation) in let def_id = A.Regular def_id in let name = - ctx.fmt.fun_name def_id def.basename num_rgs rg_info (keep_fwd, num_backs) + ctx.fmt.fun_name def_id def.basename def.is_global num_rgs rg_info (keep_fwd, num_backs) in (* Add the function name *) let ctx = ctx_add (FunId (def_id, def.back_id)) name ctx in @@ -666,8 +669,12 @@ let compute_type_decl_name (fmt : formatter) (def : type_decl) : string = information. TODO: move all those helpers. *) -let default_fun_suffix (num_region_groups : int) (rg : region_group_info option) - ((keep_fwd, num_backs) : bool * int) : string = +let default_fun_suffix + (is_global : bool) + (num_region_groups : int) + (rg : region_group_info option) + ((keep_fwd, num_backs) : bool * int) + : string = (* There are several cases: - [rg] is `Some`: this is a forward function: - we add "_fwd" @@ -683,6 +690,7 @@ let default_fun_suffix (num_region_groups : int) (rg : region_group_info option) we could not add the "_fwd" suffix) to prevent name clashes between definitions (in particular between type and function definitions). *) + if is_global then "_c" else match rg with | None -> "_fwd" | Some rg -> diff --git a/src/PureTypeCheck.ml b/src/PureTypeCheck.ml index 8848ff20..90b9ab09 100644 --- a/src/PureTypeCheck.ml +++ b/src/PureTypeCheck.ml @@ -111,7 +111,7 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = check_texpression ctx body | Qualif qualif -> ( match qualif.id with - | Func _ -> () (* TODO *) + | Func _ | Global _ -> () (* TODO *) | Proj { adt_id = proj_adt_id; field_id } -> (* Note we can only project fields of structures (not enumerations) *) (* Deconstruct the projector type *) diff --git a/src/Substitute.ml b/src/Substitute.ml index 711e438b..4b0a04ca 100644 --- a/src/Substitute.ml +++ b/src/Substitute.ml @@ -210,12 +210,6 @@ let place_substitute (_tsubst : T.TypeVarId.id -> T.ety) (p : E.place) : E.place (* There is nothing to do *) p -(** Apply a type substitution to an operand constant value *) -let operand_constant_value_substitute (_tsubst : T.TypeVarId.id -> T.ety) - (op : E.operand_constant_value) : E.operand_constant_value = - (* There is nothing to do *) - op - (** Apply a type substitution to an operand *) let operand_substitute (tsubst : T.TypeVarId.id -> T.ety) (op : E.operand) : E.operand = @@ -225,9 +219,7 @@ let operand_substitute (tsubst : T.TypeVarId.id -> T.ety) (op : E.operand) : | E.Move p -> E.Move (p_subst p) | E.Constant (ety, cv) -> let rsubst x = x in - E.Constant - ( ty_substitute rsubst tsubst ety, - operand_constant_value_substitute tsubst cv ) + E.Constant ( ty_substitute rsubst tsubst ety, cv ) (** Apply a type substitution to an rvalue *) let rvalue_substitute (tsubst : T.TypeVarId.id -> T.ety) (rv : E.rvalue) : @@ -289,6 +281,7 @@ let rec statement_substitute (tsubst : T.TypeVarId.id -> T.ety) let p = place_substitute tsubst p in let rvalue = rvalue_substitute tsubst rvalue in A.Assign (p, rvalue) + | A.AssignGlobal g -> A.AssignGlobal g | A.FakeRead p -> let p = place_substitute tsubst p in A.FakeRead p diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index 4c2ba4c8..a057b015 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -67,9 +67,10 @@ type fun_sig_named_outputs = { } type fun_context = { - llbc_fun_decls : A.fun_decl FunDeclId.Map.t; + llbc_fun_decls : A.fun_decl A.FunDeclId.Map.t; fun_sigs : fun_sig_named_outputs RegularFunIdMap.t; (** *) - fun_infos : FA.fun_info FunDeclId.Map.t; + fun_infos : FA.fun_info A.FunDeclId.Map.t; + gid_conv : A.global_id_converter; } type call_info = { @@ -133,14 +134,18 @@ let type_check_texpression (ctx : bs_ctx) (e : texpression) : unit = (* TODO: move *) let bs_ctx_to_ast_formatter (ctx : bs_ctx) : Print.LlbcAst.ast_formatter = - Print.LlbcAst.fun_decl_to_ast_formatter ctx.type_context.llbc_type_decls - ctx.fun_context.llbc_fun_decls ctx.fun_decl + Print.LlbcAst.fun_decl_to_ast_formatter + ctx.type_context.llbc_type_decls + ctx.fun_context.llbc_fun_decls + (A.global_to_fun_id ctx.fun_context.gid_conv) + ctx.fun_decl let bs_ctx_to_pp_ast_formatter (ctx : bs_ctx) : PrintPure.ast_formatter = let type_params = ctx.fun_decl.signature.type_params in let type_decls = ctx.type_context.llbc_type_decls in let fun_decls = ctx.fun_context.llbc_fun_decls in - PrintPure.mk_ast_formatter type_decls fun_decls type_params + let gid_conv = ctx.fun_context.gid_conv in + PrintPure.mk_ast_formatter type_decls fun_decls gid_conv type_params let ty_to_string (ctx : bs_ctx) (ty : ty) : string = let fmt = bs_ctx_to_pp_ast_formatter ctx in @@ -161,14 +166,16 @@ let fun_sig_to_string (ctx : bs_ctx) (sg : fun_sig) : string = let type_params = sg.type_params in let type_decls = ctx.type_context.llbc_type_decls in let fun_decls = ctx.fun_context.llbc_fun_decls in - let fmt = PrintPure.mk_ast_formatter type_decls fun_decls type_params in + let gid_conv = ctx.fun_context.gid_conv in + let fmt = PrintPure.mk_ast_formatter type_decls fun_decls gid_conv type_params in PrintPure.fun_sig_to_string fmt sg let fun_decl_to_string (ctx : bs_ctx) (def : Pure.fun_decl) : string = let type_params = def.signature.type_params in let type_decls = ctx.type_context.llbc_type_decls in let fun_decls = ctx.fun_context.llbc_fun_decls in - let fmt = PrintPure.mk_ast_formatter type_decls fun_decls type_params in + let gid_conv = ctx.fun_context.gid_conv in + let fmt = PrintPure.mk_ast_formatter type_decls fun_decls gid_conv type_params in PrintPure.fun_decl_to_string fmt def (* TODO: move *) @@ -195,12 +202,12 @@ let bs_ctx_lookup_llbc_type_decl (id : TypeDeclId.id) (ctx : bs_ctx) : T.type_decl = TypeDeclId.Map.find id ctx.type_context.llbc_type_decls -let bs_ctx_lookup_llbc_fun_decl (id : FunDeclId.id) (ctx : bs_ctx) : A.fun_decl +let bs_ctx_lookup_llbc_fun_decl (id : A.FunDeclId.id) (ctx : bs_ctx) : A.fun_decl = - FunDeclId.Map.find id ctx.fun_context.llbc_fun_decls + A.FunDeclId.Map.find id ctx.fun_context.llbc_fun_decls (* TODO: move *) -let bs_ctx_lookup_local_function_sig (def_id : FunDeclId.id) +let bs_ctx_lookup_local_function_sig (def_id : A.FunDeclId.id) (back_id : T.RegionGroupId.id option) (ctx : bs_ctx) : fun_sig = let id = (A.Regular def_id, back_id) in (RegularFunIdMap.find id ctx.fun_context.fun_sigs).sg @@ -471,17 +478,14 @@ let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) : List.map (fun id -> V.AbstractionId.Map.find id ctx.abstractions) abs_ids (** Small utility. *) -let get_fun_effect_info (fun_infos : FA.fun_info FunDeclId.Map.t) +let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) (fun_id : A.fun_id) (gid : T.RegionGroupId.id option) : fun_effect_info = match fun_id with | A.Regular fid -> - let info = FunDeclId.Map.find fid fun_infos in + let info = A.FunDeclId.Map.find fid fun_infos in let input_state = info.stateful in let output_state = input_state && gid = None in - (* We ignore on purpose info.can_fail: the result of the analysis is not - * used yet to adjust the translation so that the functions which syntactically - * can't fail don't use an error monad. *) - { can_fail = true; input_state; output_state } + { can_fail = info.can_fail; input_state; output_state } | A.Assumed aid -> { can_fail = Assumed.assumed_can_fail aid; @@ -496,7 +500,7 @@ let get_fun_effect_info (fun_infos : FA.fun_info FunDeclId.Map.t) name (outputs for backward functions come from borrows in the inputs of the forward function). *) -let translate_fun_sig (fun_infos : FA.fun_info FunDeclId.Map.t) +let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) (fun_id : A.fun_id) (types_infos : TA.type_infos) (sg : A.fun_sig) (input_names : string option list) (bid : T.RegionGroupId.id option) : fun_sig_named_outputs = @@ -1662,6 +1666,7 @@ let translate_fun_decl (config : config) (ctx : bs_ctx) (* Lookup the signature *) let signature = bs_ctx_lookup_local_function_sig def_id bid ctx in (* Translate the body, if there is *) + let is_global = def.A.is_global in let body = match body with | None -> None @@ -1722,7 +1727,7 @@ let translate_fun_decl (config : config) (ctx : bs_ctx) Some { inputs; inputs_lvs; body } in (* Assemble the declaration *) - let def = { def_id; back_id = bid; basename; signature; body } in + let def = { def_id; back_id = bid; basename; signature; is_global; body } in (* Debugging *) log#ldebug (lazy @@ -1746,7 +1751,7 @@ let translate_type_decls (type_decls : T.type_decl list) : type_decl list = - optional names for the outputs values (we derive them for the backward functions) *) -let translate_fun_signatures (fun_infos : FA.fun_info FunDeclId.Map.t) +let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) (types_infos : TA.type_infos) (functions : (A.fun_id * string option list * A.fun_sig) list) : fun_sig_named_outputs RegularFunIdMap.t = diff --git a/src/Translate.ml b/src/Translate.ml index 57b92e44..9412b8b8 100644 --- a/src/Translate.ml +++ b/src/Translate.ml @@ -64,7 +64,10 @@ let translate_function_to_symbolics (config : C.partial_config) ^ Print.fun_name_to_string fdef.A.name)); let { type_context; fun_context } = trans_ctx in - let fun_context = { C.fun_decls = fun_context.fun_decls } in + let fun_context = { + C.fun_decls = fun_context.fun_decls; + C.gid_conv = fun_context.gid_conv; + } in match fdef.body with | None -> None @@ -99,7 +102,8 @@ let translate_function_to_symbolics (config : C.partial_config) let translate_function_to_pure (config : C.partial_config) (mp_config : Micro.config) (trans_ctx : trans_ctx) (fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdMap.t) - (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : A.fun_decl) + (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) + (fdef : A.fun_decl) : pure_fun_translation = (* Debug *) log#ldebug @@ -138,6 +142,7 @@ let translate_function_to_pure (config : C.partial_config) SymbolicToPure.llbc_fun_decls = fun_context.fun_decls; fun_sigs; fun_infos = fun_context.fun_infos; + gid_conv = fun_context.gid_conv; } in let ctx = @@ -290,7 +295,11 @@ let translate_module_to_pure (config : C.partial_config) (* Compute the type and function contexts *) let type_context, fun_context = compute_type_fun_contexts m in let fun_infos = FA.analyze_module m fun_context.C.fun_decls use_state in - let fun_context = { fun_decls = fun_context.fun_decls; fun_infos } in + let fun_context = { + fun_decls = fun_context.fun_decls; + fun_infos; + gid_conv = m.gid_conv; + } in let trans_ctx = { type_context; fun_context } in (* Translate all the type definitions *) @@ -351,8 +360,8 @@ type gen_ctx = { m : M.llbc_module; extract_ctx : PureToExtract.extraction_ctx; trans_types : Pure.type_decl Pure.TypeDeclId.Map.t; - trans_funs : (bool * pure_fun_translation) Pure.FunDeclId.Map.t; - functions_with_decreases_clause : Pure.FunDeclId.Set.t; + trans_funs : (bool * pure_fun_translation) A.FunDeclId.Map.t; + functions_with_decreases_clause : A.FunDeclId.Set.t; } (** Extraction context *) @@ -388,7 +397,7 @@ let module_has_opaque_decls (ctx : gen_ctx) : bool * bool = ctx.trans_types in let has_opaque_funs = - Pure.FunDeclId.Map.exists + A.FunDeclId.Map.exists (fun _ ((_, (t_fwd, _)) : bool * pure_fun_translation) -> Option.is_none t_fwd.body) ctx.trans_funs @@ -427,7 +436,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) (* Utility to check a function has a decrease clause *) let has_decreases_clause (def : Pure.fun_decl) : bool = - Pure.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause + A.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause in (* In case of (non-mutually) recursive functions, we use a simple procedure to @@ -486,9 +495,10 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) if ((not is_opaque) && config.extract_transparent) || (is_opaque && config.extract_opaque) - then - ExtractToFStar.extract_fun_decl ctx.extract_ctx fmt qualif - has_decr_clause def) + then if def.is_global + then ExtractToFStar.extract_global_decl ctx.extract_ctx fmt qualif def + else ExtractToFStar.extract_fun_decl ctx.extract_ctx fmt qualif has_decr_clause def + ) fls); (* Insert unit tests if necessary *) if config.test_unit_functions then @@ -523,14 +533,14 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) ids | Fun (NonRec id) -> (* Lookup *) - let pure_fun = Pure.FunDeclId.Map.find id ctx.trans_funs in + let pure_fun = A.FunDeclId.Map.find id ctx.trans_funs in (* Translate *) export_functions false [ pure_fun ] | Fun (Rec ids) -> (* General case of mutually recursive functions *) (* Lookup *) let pure_funs = - List.map (fun id -> Pure.FunDeclId.Map.find id ctx.trans_funs) ids + List.map (fun id -> A.FunDeclId.Map.find id ctx.trans_funs) ids in (* Translate *) export_functions true pure_funs @@ -622,7 +632,7 @@ let translate_module (filename : string) (dest_dir : string) (config : config) (* We need to compute which functions are recursive, in order to know * whether we should generate a decrease clause or not. *) let rec_functions = - Pure.FunDeclId.Set.of_list + A.FunDeclId.Set.of_list (List.concat (List.map (fun decl -> match decl with M.Fun (Rec ids) -> ids | _ -> []) @@ -644,7 +654,7 @@ let translate_module (filename : string) (dest_dir : string) (config : config) (fun ctx (keep_fwd, def) -> (* Note that we generate a decrease clause for all the recursive functions *) let gen_decr_clause = - Pure.FunDeclId.Set.mem (fst def).Pure.def_id rec_functions + A.FunDeclId.Set.mem (fst def).Pure.def_id rec_functions in ExtractToFStar.extract_fun_decl_register_names ctx keep_fwd gen_decr_clause def) @@ -674,7 +684,7 @@ let translate_module (filename : string) (dest_dir : string) (config : config) (List.map (fun (d : Pure.type_decl) -> (d.def_id, d)) trans_types) in let trans_funs = - Pure.FunDeclId.Map.of_list + A.FunDeclId.Map.of_list (List.map (fun ((keep_fwd, (fd, bdl)) : bool * pure_fun_translation) -> (fd.def_id, (keep_fwd, (fd, bdl)))) @@ -761,7 +771,7 @@ let translate_module (filename : string) (dest_dir : string) (config : config) (* Extract the template clauses *) let needs_clauses_module = config.extract_decreases_clauses - && not (Pure.FunDeclId.Set.is_empty rec_functions) + && not (A.FunDeclId.Set.is_empty rec_functions) in (if needs_clauses_module && config.extract_template_decreases_clauses then let clauses_filename = extract_filebasename ^ ".Clauses.Template.fst" in diff --git a/src/TranslateCore.ml b/src/TranslateCore.ml index 17c35cbf..047219ad 100644 --- a/src/TranslateCore.ml +++ b/src/TranslateCore.ml @@ -16,6 +16,7 @@ type type_context = C.type_context [@@deriving show] type fun_context = { fun_decls : A.fun_decl A.FunDeclId.Map.t; fun_infos : FA.fun_info A.FunDeclId.Map.t; + gid_conv : A.global_id_converter; } [@@deriving show] @@ -39,16 +40,18 @@ let fun_sig_to_string (ctx : trans_ctx) (sg : Pure.fun_sig) : string = let type_params = sg.type_params in let type_decls = ctx.type_context.type_decls in let fun_decls = ctx.fun_context.fun_decls in - let fmt = PrintPure.mk_ast_formatter type_decls fun_decls type_params in + let gid_conv = ctx.fun_context.gid_conv in + let fmt = PrintPure.mk_ast_formatter type_decls fun_decls gid_conv type_params in PrintPure.fun_sig_to_string fmt sg let fun_decl_to_string (ctx : trans_ctx) (def : Pure.fun_decl) : string = let type_params = def.signature.type_params in let type_decls = ctx.type_context.type_decls in let fun_decls = ctx.fun_context.fun_decls in - let fmt = PrintPure.mk_ast_formatter type_decls fun_decls type_params in + let gid_conv = ctx.fun_context.gid_conv in + let fmt = PrintPure.mk_ast_formatter type_decls fun_decls gid_conv type_params in PrintPure.fun_decl_to_string fmt def -let fun_decl_id_to_string (ctx : trans_ctx) (id : Pure.FunDeclId.id) : string = +let fun_decl_id_to_string (ctx : trans_ctx) (id : A.FunDeclId.id) : string = Print.fun_name_to_string - (Pure.FunDeclId.Map.find id ctx.fun_context.fun_decls).name + (A.FunDeclId.Map.find id ctx.fun_context.fun_decls).name diff --git a/src/TypesUtils.ml b/src/TypesUtils.ml index bee7956e..8d0624ee 100644 --- a/src/TypesUtils.ml +++ b/src/TypesUtils.ml @@ -100,28 +100,31 @@ let rty_regions_intersect (ty : rty) (regions : RegionId.Set.t) : bool = let ty_regions = rty_regions ty in not (RegionId.Set.disjoint ty_regions regions) -(** Convert an [ety], containing no region variables, to an [rty]. +(** Convert an [ety], containing no region variables, to an [rty] or [sty]. In practice, it is the identity. *) -let rec ety_no_regions_to_rty (ty : ety) : rty = +let rec ety_no_regions_to_gr_ty (ty : ety) : 'a gr_ty = match ty with | Adt (type_id, regions, tys) -> assert (regions = []); - Adt (type_id, [], List.map ety_no_regions_to_rty tys) + Adt (type_id, [], List.map ety_no_regions_to_gr_ty tys) | TypeVar v -> TypeVar v | Bool -> Bool | Char -> Char | Never -> Never | Integer int_ty -> Integer int_ty | Str -> Str - | Array ty -> Array (ety_no_regions_to_rty ty) - | Slice ty -> Slice (ety_no_regions_to_rty ty) + | Array ty -> Array (ety_no_regions_to_gr_ty ty) + | Slice ty -> Slice (ety_no_regions_to_gr_ty ty) | Ref (_, _, _) -> failwith "Can't convert a ref with erased regions to a ref with non-erased \ regions" +let ety_no_regions_to_rty (ty : ety) : rty = ety_no_regions_to_gr_ty ty +let ety_no_regions_to_sty (ty : ety) : sty = ety_no_regions_to_gr_ty ty + (** Retuns true if the type contains borrows. Note that we can't simply explore the type and look for regions: sometimes |