From b970183881379ff676b232e47e353e924de8cfdd Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 9 Nov 2022 11:47:21 +0100 Subject: Update the way function names are handled in Pure --- compiler/ExtractToFStar.ml | 34 ++++++++-------- compiler/PrintPure.ml | 57 ++++++++++++++------------ compiler/Pure.ml | 34 ++++++++++------ compiler/PureMicroPasses.ml | 33 +++++++-------- compiler/PureToExtract.ml | 97 +++++++++++++++++++++++---------------------- compiler/PureTypeCheck.ml | 2 +- compiler/PureUtils.ml | 35 ++++++++-------- compiler/SymbolicToPure.ml | 12 +++--- 8 files changed, 162 insertions(+), 142 deletions(-) (limited to 'compiler') diff --git a/compiler/ExtractToFStar.ml b/compiler/ExtractToFStar.ml index 6d680984..2a7d6a6c 100644 --- a/compiler/ExtractToFStar.ml +++ b/compiler/ExtractToFStar.ml @@ -128,7 +128,7 @@ let fstar_assumed_variants : (assumed_ty * VariantId.id * string) list = (Option, option_none_id, "None"); ] -let fstar_assumed_functions : +let fstar_assumed_llbc_functions : (A.assumed_fun_id * T.RegionGroupId.id option * string) list = let rg0 = Some T.RegionGroupId.zero in [ @@ -146,13 +146,17 @@ let fstar_assumed_functions : (VecIndexMut, rg0, "vec_index_mut_back"); ] +let fstar_assumed_pure_functions : (pure_assumed_fun_id * string) list = + [ (Return, "return"); (Fail, "fail"); (Assert, "massert") ] + let fstar_names_map_init : names_map_init = { keywords = fstar_keywords; assumed_adts = fstar_assumed_adts; assumed_structs = fstar_assumed_structs; assumed_variants = fstar_assumed_variants; - assumed_functions = fstar_assumed_functions; + assumed_llbc_functions = fstar_assumed_llbc_functions; + assumed_pure_functions = fstar_assumed_pure_functions; } let fstar_extract_unop (extract_expr : bool -> texpression -> unit) @@ -321,7 +325,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) let parts = List.map to_snake_case (get_name name) in String.concat "_" parts in - let fun_name (_fid : A.fun_id) (fname : fun_name) (num_rgs : int) + let fun_name (fname : fun_name) (num_rgs : int) (rg : region_group_info option) (filter_info : bool * int) : string = let fname = fun_name_to_snake_case fname in (* Compute the suffix *) @@ -416,7 +420,6 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) char_name = "char"; int_name; str_name = "string"; - assert_name = "massert"; field_name; variant_name; struct_constructor; @@ -433,7 +436,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) } (** [inside] constrols whether we should add parentheses or not around type - application (if [true] we add parentheses). + applications (if [true] we add parentheses). *) let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (ty : ty) : unit = @@ -928,7 +931,7 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) | Qualif qualif -> ( (* Top-level qualifier *) match qualif.id with - | Func fun_id -> + | FunOrOp fun_id -> extract_function_call ctx fmt inside fun_id qualif.type_args args | Global global_id -> extract_global ctx fmt global_id | AdtCons adt_cons_id -> @@ -957,7 +960,7 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (** Subcase of the app case: function call *) and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) - (inside : bool) (fid : fun_id) (type_args : ty list) + (inside : bool) (fid : fun_or_op_id) (type_args : ty list) (args : texpression list) : unit = match (fid, args) with | Unop unop, [ arg ] -> @@ -971,17 +974,12 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) ctx.fmt.extract_binop (extract_texpression ctx fmt) fmt inside binop int_ty arg0 arg1 - | Regular (_, _), _ | Assert, [ _ ] -> + | Fun fun_id, _ -> if inside then F.pp_print_string fmt "("; (* Open a box for the function call *) F.pp_open_hovbox fmt ctx.indent_incr; (* Print the function name *) - let fun_name = - match fid with - | Regular (fun_id, rg_id) -> ctx_get_function fun_id rg_id ctx - | Assert -> ctx.fmt.assert_name - | _ -> raise (Failure "Unreachable") - in + let fun_name = ctx_get_function fun_id ctx in F.pp_print_string fmt fun_name; (* Print the type parameters *) List.iter @@ -999,10 +997,10 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) F.pp_close_box fmt (); (* Return *) if inside then F.pp_print_string fmt ")" - | (Unop _ | Binop _ | Assert), _ -> + | (Unop _ | Binop _), _ -> raise (Failure - ("Unreachable:\n" ^ "Function: " ^ show_fun_id fid + ("Unreachable:\n" ^ "Function: " ^ show_fun_or_op_id fid ^ ",\nNumber of arguments: " ^ string_of_int (List.length args) ^ ",\nArguments: " @@ -1576,7 +1574,9 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_space fmt (); let decl_name = ctx_get_global global.def_id ctx in - let body_name = ctx_get_function (Regular global.body_id) None ctx in + let body_name = + ctx_get_function (FromLlbc (Regular global.body_id, None)) ctx + in let decl_ty, body_ty = let ty = body.signature.output in diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 0d1288d7..b4ab26b8 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -386,30 +386,39 @@ let inst_fun_sig_to_string (fmt : ast_formatter) (sg : inst_fun_sig) : string = let all_types = List.append inputs [ output ] in String.concat " -> " all_types -let regular_fun_id_to_string (fmt : ast_formatter) (fun_id : A.fun_id) : string - = - match fun_id with - | A.Regular fid -> fmt.fun_decl_id_to_string fid - | A.Assumed fid -> ( - match fid with - | A.Replace -> "core::mem::replace" - | A.BoxNew -> "alloc::boxed::Box::new" - | A.BoxDeref -> "core::ops::deref::Deref::deref" - | A.BoxDerefMut -> "core::ops::deref::DerefMut::deref_mut" - | A.BoxFree -> "alloc::alloc::box_free" - | A.VecNew -> "alloc::vec::Vec::new" - | A.VecPush -> "alloc::vec::Vec::push" - | A.VecInsert -> "alloc::vec::Vec::insert" - | A.VecLen -> "alloc::vec::Vec::len" - | A.VecIndex -> "core::ops::index::Index::index" - | A.VecIndexMut -> - "core::ops::index::IndexMut::index_mut") - let fun_suffix (rg_id : T.RegionGroupId.id option) : string = match rg_id with | None -> "" | Some rg_id -> "@" ^ T.RegionGroupId.to_string rg_id +let llbc_assumed_fun_id_to_string (fid : A.assumed_fun_id) : string = + match fid with + | A.Replace -> "core::mem::replace" + | A.BoxNew -> "alloc::boxed::Box::new" + | A.BoxDeref -> "core::ops::deref::Deref::deref" + | A.BoxDerefMut -> "core::ops::deref::DerefMut::deref_mut" + | A.BoxFree -> "alloc::alloc::box_free" + | A.VecNew -> "alloc::vec::Vec::new" + | A.VecPush -> "alloc::vec::Vec::push" + | A.VecInsert -> "alloc::vec::Vec::insert" + | A.VecLen -> "alloc::vec::Vec::len" + | A.VecIndex -> "core::ops::index::Index::index" + | A.VecIndexMut -> "core::ops::index::IndexMut::index_mut" + +let pure_assumed_fun_id_to_string (fid : pure_assumed_fun_id) : string = + match fid with Return -> "return" | Fail -> "fail" | Assert -> "assert" + +let regular_fun_id_to_string (fmt : ast_formatter) (fun_id : fun_id) : string = + match fun_id with + | FromLlbc (fid, rg_id) -> + let f = + match fid with + | Regular fid -> fmt.fun_decl_id_to_string fid + | Assumed fid -> llbc_assumed_fun_id_to_string fid + in + f ^ fun_suffix rg_id + | Pure fid -> pure_assumed_fun_id_to_string fid + let unop_to_string (unop : unop) : string = match unop with | Not -> "¬" @@ -420,15 +429,13 @@ let unop_to_string (unop : unop) : string = let binop_to_string = Print.Expressions.binop_to_string -let fun_id_to_string (fmt : ast_formatter) (fun_id : fun_id) : string = +let fun_or_op_id_to_string (fmt : ast_formatter) (fun_id : fun_or_op_id) : + string = match fun_id with - | Regular (fun_id, rg_id) -> - let f = regular_fun_id_to_string fmt fun_id in - f ^ fun_suffix rg_id + | Fun fun_id -> regular_fun_id_to_string fmt fun_id | Unop unop -> unop_to_string unop | Binop (binop, int_ty) -> binop_to_string binop ^ "<" ^ integer_type_to_string int_ty ^ ">" - | Assert -> "assert" (** [inside]: controls the introduction of parentheses *) let rec texpression_to_string (fmt : ast_formatter) (inside : bool) @@ -478,7 +485,7 @@ and app_to_string (fmt : ast_formatter) (inside : bool) (indent : string) (* Convert the qualifier identifier *) let qualif_s = match qualif.id with - | Func fun_id -> fun_id_to_string fmt fun_id + | FunOrOp fun_id -> fun_or_op_id_to_string fmt fun_id | Global global_id -> fmt.global_decl_id_to_string global_id | AdtCons adt_cons_id -> let variant_s = diff --git a/compiler/Pure.ml b/compiler/Pure.ml index f11397e9..a50dd5f9 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -281,22 +281,30 @@ and typed_pattern = { value : pattern; ty : ty } type unop = Not | Neg of integer_type | Cast of integer_type * integer_type [@@deriving show, ord] +(** Identifiers of assumed functions that we use only in the pure translation *) +type pure_assumed_fun_id = + | Return (** The monadic return *) + | Fail (** The monadic fail *) + | Assert (** Assertion *) +[@@deriving show, ord] + +(** A function identifier *) type fun_id = - | Regular of A.fun_id * T.RegionGroupId.id option - (** Backward id: [Some] if the function is a backward function, [None] - if it is a forward function. - - TODO: we need to redefine A.fun_id here, to add [fail] and - [return] (important to get a unified treatment of the state-error - monad). For now, when using the state-error monad: extraction - works only if we unfold all the monadic let-bindings, and we - then replace the content of the occurrences of [Return] to also - return the state (which is really super ugly). - TODO: also add Assert... + | FromLlbc of A.fun_id * T.RegionGroupId.id option + (** A function coming from LLBC. + + The region group id is the backward id:: [Some] if the function is a + backward function, [None] if it is a forward function. *) + | Pure of pure_assumed_fun_id + (** A function only used in the pure translation *) +[@@deriving show, ord] + +(** A function or an operation id *) +type fun_or_op_id = + | Fun of fun_id | Unop of unop | Binop of E.binop * integer_type - | Assert [@@deriving show, ord] (** An identifier for an ADT constructor *) @@ -309,7 +317,7 @@ type adt_cons_id = { adt_id : type_id; variant_id : variant_id option } type projection = { adt_id : type_id; field_id : FieldId.id } [@@deriving show] type qualif_id = - | Func of fun_id + | FunOrOp of fun_or_op_id (** A function or an operation *) | Global of GlobalDeclId.id | AdtCons of adt_cons_id (** A function or ADT constructor identifier *) | Proj of projection (** Field projector *) diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 239b0b4f..9d604626 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -610,9 +610,9 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) match qualif.id with | AdtCons _ -> true (* ADT constructor *) | Proj _ -> true (* Projector *) - | Func (Unop _ | Binop _) -> + | FunOrOp (Unop _ | Binop _) -> true (* primitive function call *) - | Func (Regular _) -> false (* non-primitive function call *) + | FunOrOp (Fun _) -> false (* non-primitive function call *) | _ -> false) | _ -> false in @@ -667,24 +667,25 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) fn f<'a>(x : &'a mut T); ]} - We often have things like this in the synthesized code: + We often have things like this in the synthesized code: {[ - _ <-- f x; + _ <-- f@fwd x; ... nx <-- f@back'a x y; ... ]} - In this situation, we can remove the call [f x]. + If [f@back'a x y] fails, then necessarily [f@fwd x] also fails. + In this situation, we can remove the call [f@fwd x]. *) let expression_contains_child_call_in_all_paths (ctx : trans_ctx) - (fun_id0 : fun_id) (tys0 : ty list) (args0 : texpression list) - (e : texpression) : bool = - let check_call (fun_id1 : fun_id) (tys1 : ty list) (args1 : texpression list) - : bool = + (id0 : A.fun_id) (rg_id0 : T.RegionGroupId.id option) (tys0 : ty list) + (args0 : texpression list) (e : texpression) : bool = + let check_call (fun_id1 : fun_or_op_id) (tys1 : ty list) + (args1 : texpression list) : bool = (* Check the fun_ids, to see if call1's function is a child of call0's function *) - match (fun_id0, fun_id1) with - | Regular (id0, rg_id0), Regular (id1, rg_id1) -> + match fun_id1 with + | Fun (FromLlbc (id1, rg_id1)) -> (* Both are "regular" calls: check if they come from the same rust function *) if id0 = id1 then (* Same rust functions: check the regions hierarchy *) @@ -858,20 +859,20 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) * We can filter if the right-expression is a function call, * under some conditions. *) match (filter_monadic_calls, opt_destruct_function_call re) with - | true, Some (func, tys, args) -> + | true, Some (Fun (FromLlbc (fid, rg_id)), tys, args) -> (* We need to check if there is a child call - see * the comments for: * [expression_contains_child_call_in_all_paths] *) let has_child_call = - expression_contains_child_call_in_all_paths ctx func tys - args e + expression_contains_child_call_in_all_paths ctx fid rg_id + tys args e in if has_child_call then (* Filter *) (e.e, fun _ -> used) else (* No child call: don't filter *) dont_filter () | _ -> - (* Not a call or not allowed to filter: we can't filter *) + (* Not an LLBC function call or not allowed to filter: we can't filter *) dont_filter () else (* There are used variables: don't filter *) dont_filter () @@ -1088,7 +1089,7 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = match opt_destruct_function_call e with | Some (fun_id, _tys, args) -> ( match fun_id with - | Regular (A.Assumed aid, rg_id) -> ( + | Fun (FromLlbc (A.Assumed aid, rg_id)) -> ( (* Below, when dealing with the arguments: we consider the very * general case, where functions could be boxed (meaning we * could have: [box_new f x]) diff --git a/compiler/PureToExtract.ml b/compiler/PureToExtract.ml index 860949a7..25ad6713 100644 --- a/compiler/PureToExtract.ml +++ b/compiler/PureToExtract.ml @@ -51,7 +51,6 @@ type formatter = { char_name : string; int_name : integer_type -> string; str_name : string; - assert_name : string; field_name : name -> FieldId.id -> string option -> string; (** Inputs: - type name @@ -86,16 +85,12 @@ type formatter = { global_name : global_name -> string; (** Provided a basename, compute a global name. *) fun_name : - A.fun_id -> - fun_name -> - int -> - region_group_info option -> - bool * int -> - string; - (** Inputs: - - function id: this is especially useful to identify whether the - function is an assumed function or a local function - - function basename + fun_name -> int -> region_group_info option -> bool * int -> string; + (** Compute the name of a regular (non-assumed) function. + + Inputs: + - function id + - function basename (TODO: shouldn't appear for assumed functions?...) - number of region groups - region group information in case of a backward function ([None] if forward function) @@ -191,7 +186,7 @@ type formatter = { (** We use identifiers to look for name clashes *) type id = | GlobalId of A.GlobalDeclId.id - | FunId of A.fun_id * RegionGroupId.id option + | FunId of fun_id | DecreasesClauseId of A.fun_id (** The definition which provides the decreases/termination clause. We insert calls to this clause to prove/reason about termination: @@ -290,10 +285,9 @@ let names_map_add_assumed_variant (id_to_string : id -> string) (nm : names_map) : names_map = names_map_add id_to_string (VariantId (Assumed id, variant_id)) name nm -let names_map_add_assumed_function (id_to_string : id -> string) - (fid : A.assumed_fun_id) (rg_id : RegionGroupId.id option) (name : string) - (nm : names_map) : names_map = - names_map_add id_to_string (FunId (A.Assumed fid, rg_id)) name nm +let names_map_add_function (id_to_string : id -> string) (fid : fun_id) + (name : string) (nm : names_map) : names_map = + names_map_add id_to_string (FunId fid) name nm (** Make a (variable) basename unique (by adding an index). @@ -360,25 +354,29 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = | GlobalId gid -> let name = (A.GlobalDeclId.Map.find gid global_decls).name in "global name: " ^ Print.global_name_to_string name - | FunId (fid, rg_id) -> - let fun_name = - match fid with - | A.Regular fid -> - Print.fun_name_to_string (A.FunDeclId.Map.find fid fun_decls).name - | A.Assumed aid -> A.show_assumed_fun_id aid - in - let fun_kind = - match rg_id with - | None -> "forward" - | Some rg_id -> "backward " ^ RegionGroupId.to_string rg_id - in - "fun name (" ^ fun_kind ^ "): " ^ fun_name + | FunId fid -> ( + match fid with + | FromLlbc (fid, rg_id) -> + let fun_name = + match fid with + | Regular fid -> + Print.fun_name_to_string + (A.FunDeclId.Map.find fid fun_decls).name + | Assumed aid -> A.show_assumed_fun_id aid + in + let fun_kind = + match rg_id with + | None -> "forward" + | Some rg_id -> "backward " ^ RegionGroupId.to_string rg_id + in + "fun name (" ^ fun_kind ^ "): " ^ fun_name + | Pure fid -> PrintPure.pure_assumed_fun_id_to_string fid) | DecreasesClauseId fid -> let fun_name = match fid with - | A.Regular fid -> + | Regular fid -> Print.fun_name_to_string (A.FunDeclId.Map.find fid fun_decls).name - | A.Assumed aid -> A.show_assumed_fun_id aid + | Assumed aid -> A.show_assumed_fun_id aid in "decreases clause for function: " ^ fun_name | TypeId id -> "type name: " ^ get_type_name id @@ -451,13 +449,12 @@ let ctx_get (id : id) (ctx : extraction_ctx) : string = let ctx_get_global (id : A.GlobalDeclId.id) (ctx : extraction_ctx) : string = ctx_get (GlobalId id) ctx -let ctx_get_function (id : A.fun_id) (rg : RegionGroupId.id option) - (ctx : extraction_ctx) : string = - ctx_get (FunId (id, rg)) ctx +let ctx_get_function (id : fun_id) (ctx : extraction_ctx) : string = + ctx_get (FunId id) ctx let ctx_get_local_function (id : A.FunDeclId.id) (rg : RegionGroupId.id option) (ctx : extraction_ctx) : string = - ctx_get_function (A.Regular id) rg ctx + ctx_get_function (FromLlbc (Regular id, rg)) ctx let ctx_get_type (id : type_id) (ctx : extraction_ctx) : string = assert (id <> Tuple); @@ -488,7 +485,7 @@ let ctx_get_variant (def_id : type_id) (variant_id : VariantId.id) let ctx_get_decreases_clause (def_id : A.FunDeclId.id) (ctx : extraction_ctx) : string = - ctx_get (DecreasesClauseId (A.Regular def_id)) ctx + ctx_get (DecreasesClauseId (Regular def_id)) ctx (** Generate a unique type variable name and add it to the context *) let ctx_add_type_var (basename : string) (id : TypeVarId.id) @@ -577,13 +574,13 @@ let ctx_add_struct (def : type_decl) (ctx : extraction_ctx) : let ctx_add_decrases_clause (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx = let name = ctx.fmt.decreases_clause_name def.def_id def.basename in - ctx_add (DecreasesClauseId (A.Regular def.def_id)) name ctx + ctx_add (DecreasesClauseId (Regular def.def_id)) name ctx let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) : extraction_ctx = let name = ctx.fmt.global_name def.name in let decl = GlobalId def.def_id in - let body = FunId (Regular def.body_id, None) in + let body = FunId (FromLlbc (Regular def.body_id, None)) in let ctx = ctx_add decl (name ^ "_c") ctx in let ctx = ctx_add body (name ^ "_body") ctx in ctx @@ -617,18 +614,19 @@ let ctx_add_fun_decl (trans_group : bool * pure_fun_translation) in Some { id = rg_id; region_names } in - let def_id = A.Regular def_id in let name = - ctx.fmt.fun_name def_id def.basename num_rgs rg_info (keep_fwd, num_backs) + ctx.fmt.fun_name def.basename num_rgs rg_info (keep_fwd, num_backs) in - ctx_add (FunId (def_id, def.back_id)) name ctx + ctx_add (FunId (FromLlbc (A.Regular def_id, def.back_id))) name ctx type names_map_init = { keywords : string list; assumed_adts : (assumed_ty * string) list; assumed_structs : (assumed_ty * string) list; assumed_variants : (assumed_ty * VariantId.id * string) list; - assumed_functions : (A.assumed_fun_id * RegionGroupId.id option * string) list; + assumed_llbc_functions : + (A.assumed_fun_id * RegionGroupId.id option * string) list; + assumed_pure_functions : (pure_assumed_fun_id * string) list; } (** Initialize a names map with a proper set of keywords/names coming from the @@ -638,9 +636,7 @@ let initialize_names_map (fmt : formatter) (init : names_map_init) : names_map = let keywords = List.concat [ - [ fmt.bool_name; fmt.char_name; fmt.str_name; fmt.assert_name ]; - int_names; - init.keywords; + [ fmt.bool_name; fmt.char_name; fmt.str_name ]; int_names; init.keywords; ] in let names_set = StringSet.of_list keywords in @@ -680,11 +676,16 @@ let initialize_names_map (fmt : formatter) (init : names_map_init) : names_map = names_map_add_assumed_variant id_to_string type_id variant_id name nm) nm init.assumed_variants in + let assumed_functions = + List.map + (fun (fid, rg, name) -> (FromLlbc (A.Assumed fid, rg), name)) + init.assumed_llbc_functions + @ List.map (fun (fid, name) -> (Pure fid, name)) init.assumed_pure_functions + in let nm = List.fold_left - (fun nm (fun_id, rg_id, name) -> - names_map_add_assumed_function id_to_string fun_id rg_id name nm) - nm init.assumed_functions + (fun nm (fid, name) -> names_map_add_function id_to_string fid name nm) + nm assumed_functions in (* Return *) nm diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index 8c19a53b..6b6a82ad 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -113,7 +113,7 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = check_texpression ctx body | Qualif qualif -> ( match qualif.id with - | Func _ -> () (* TODO *) + | FunOrOp _ -> () (* TODO *) | Global _ -> () (* TODO *) | Proj { adt_id = proj_adt_id; field_id } -> (* Note we can only project fields of structures (not enumerations) *) diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index e292576c..ff379bf5 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -18,17 +18,17 @@ end module RegularFunIdMap = Collections.MakeMap (RegularFunIdOrderedType) -module FunIdOrderedType = struct - type t = fun_id +module FunOrOpIdOrderedType = struct + type t = fun_or_op_id - let compare = compare_fun_id - let to_string = show_fun_id - let pp_t = pp_fun_id - let show_t = show_fun_id + let compare = compare_fun_or_op_id + let to_string = show_fun_or_op_id + let pp_t = pp_fun_or_op_id + let show_t = show_fun_or_op_id end -module FunIdMap = Collections.MakeMap (FunIdOrderedType) -module FunIdSet = Collections.MakeSet (FunIdOrderedType) +module FunOrOpIdMap = Collections.MakeMap (FunOrOpIdOrderedType) +module FunOrOpIdSet = Collections.MakeSet (FunOrOpIdOrderedType) let dest_arrow_ty (ty : ty) : ty * ty = match ty with @@ -114,23 +114,26 @@ let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) : This function is meant to be applied on a set of (forward, backwards) functions generated for one recursive function. The way we do the test is very simple: - - we explore the functions one by one, in the order + - we explore the functions one by one, in the order in which they are provided - if all functions only call functions we already explored, they are not mutually recursive *) let functions_not_mutually_recursive (funs : fun_decl list) : bool = (* Compute the set of function identifiers in the group *) let ids = - FunIdSet.of_list + FunOrOpIdSet.of_list (List.map - (fun (f : fun_decl) -> Regular (A.Regular f.def_id, f.back_id)) + (fun (f : fun_decl) -> Fun (FromLlbc (A.Regular f.def_id, f.back_id))) funs) in let ids = ref ids in (* Explore every body *) let body_only_calls_itself (fdef : fun_decl) : bool = (* Remove the current id from the id set *) - ids := FunIdSet.remove (Regular (A.Regular fdef.def_id, fdef.back_id)) !ids; + ids := + FunOrOpIdSet.remove + (Fun (FromLlbc (A.Regular fdef.def_id, fdef.back_id))) + !ids; (* Check if we call functions from the updated id set *) let obj = @@ -139,8 +142,8 @@ let functions_not_mutually_recursive (funs : fun_decl list) : bool = method! visit_qualif env qualif = match qualif.id with - | Func fun_id -> - if FunIdSet.mem fun_id !ids then raise Utils.Found + | FunOrOp fun_id -> + if FunOrOpIdSet.mem fun_id !ids then raise Utils.Found else super#visit_qualif env qualif | _ -> super#visit_qualif env qualif end @@ -242,12 +245,12 @@ let destruct_qualif_app (e : texpression) : qualif * texpression list = (** Destruct an expression into a function call, if possible *) let opt_destruct_function_call (e : texpression) : - (fun_id * ty list * texpression list) option = + (fun_or_op_id * ty list * texpression list) option = match opt_destruct_qualif_app e with | None -> None | Some (qualif, args) -> ( match qualif.id with - | Func fun_id -> Some (fun_id, qualif.type_args, args) + | FunOrOp fun_id -> Some (fun_id, qualif.type_args, args) | _ -> None) let opt_destruct_result (ty : ty) : ty option = diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 8329d80e..6d01614d 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -240,7 +240,7 @@ let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) (** [back_args]: the *additional* list of inputs received by the backward function *) let bs_ctx_register_backward_call (abs : V.abs) (back_args : texpression list) - (ctx : bs_ctx) : bs_ctx * fun_id = + (ctx : bs_ctx) : bs_ctx * fun_or_op_id = (* Insert the abstraction in the call informations *) let back_id = abs.back_id in let info = V.FunCallId.Map.find abs.call_id ctx.calls in @@ -259,7 +259,7 @@ let bs_ctx_register_backward_call (abs : V.abs) (back_args : texpression list) (* Retrieve the fun_id *) let fun_id = match info.forward.call_id with - | S.Fun (fid, _) -> Regular (fid, Some abs.back_id) + | S.Fun (fid, _) -> Fun (FromLlbc (fid, Some abs.back_id)) | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable") in (* Update the context and return *) @@ -1167,7 +1167,7 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) match call.call_id with | S.Fun (fid, call_id) -> (* Regular function call *) - let func = Regular (fid, None) in + let func = Fun (FromLlbc (fid, None)) in (* Retrieve the effect information about this function (can fail, * takes a state as input, etc.) *) let effect_info = @@ -1235,7 +1235,7 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) | None -> dest | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ] in - let func = { id = Func fun_id; type_args } in + let func = { id = FunOrOp fun_id; type_args } in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in let ret_ty = if effect_info.can_fail then mk_result_ty dest_v.ty else dest_v.ty @@ -1390,7 +1390,7 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) if effect_info.can_fail then mk_result_ty output.ty else output.ty in let func_ty = mk_arrows input_tys ret_ty in - let func = { id = Func func; type_args } in + let func = { id = FunOrOp func; type_args } in let func = { e = Qualif func; ty = func_ty } in let call = mk_apps func args in (* **Optimization**: @@ -1487,7 +1487,7 @@ and translate_assertion (config : config) (v : V.typed_value) (e : S.expression) let monadic = true in let v = typed_value_to_texpression ctx v in let args = [ v ] in - let func = { id = Func Assert; type_args = [] } in + let func = { id = FunOrOp (Fun (Pure Assert)); type_args = [] } in let func_ty = mk_arrow Bool mk_unit_ty in let func = { e = Qualif func; ty = func_ty } in let assertion = mk_apps func args in -- cgit v1.2.3