summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2022-11-09 11:47:21 +0100
committerSon HO2022-11-10 11:35:30 +0100
commitb970183881379ff676b232e47e353e924de8cfdd (patch)
tree60709cf395439d1b53d03fc5bfbcfd4f05552716
parenta68926f574b23e75fe13ef3a500df7648a3c23d8 (diff)
Update the way function names are handled in Pure
-rw-r--r--compiler/ExtractToFStar.ml34
-rw-r--r--compiler/PrintPure.ml57
-rw-r--r--compiler/Pure.ml34
-rw-r--r--compiler/PureMicroPasses.ml33
-rw-r--r--compiler/PureToExtract.ml97
-rw-r--r--compiler/PureTypeCheck.ml2
-rw-r--r--compiler/PureUtils.ml35
-rw-r--r--compiler/SymbolicToPure.ml12
8 files changed, 162 insertions, 142 deletions
diff --git a/compiler/ExtractToFStar.ml b/compiler/ExtractToFStar.ml
index 6d680984..2a7d6a6c 100644
--- a/compiler/ExtractToFStar.ml
+++ b/compiler/ExtractToFStar.ml
@@ -128,7 +128,7 @@ let fstar_assumed_variants : (assumed_ty * VariantId.id * string) list =
(Option, option_none_id, "None");
]
-let fstar_assumed_functions :
+let fstar_assumed_llbc_functions :
(A.assumed_fun_id * T.RegionGroupId.id option * string) list =
let rg0 = Some T.RegionGroupId.zero in
[
@@ -146,13 +146,17 @@ let fstar_assumed_functions :
(VecIndexMut, rg0, "vec_index_mut_back");
]
+let fstar_assumed_pure_functions : (pure_assumed_fun_id * string) list =
+ [ (Return, "return"); (Fail, "fail"); (Assert, "massert") ]
+
let fstar_names_map_init : names_map_init =
{
keywords = fstar_keywords;
assumed_adts = fstar_assumed_adts;
assumed_structs = fstar_assumed_structs;
assumed_variants = fstar_assumed_variants;
- assumed_functions = fstar_assumed_functions;
+ assumed_llbc_functions = fstar_assumed_llbc_functions;
+ assumed_pure_functions = fstar_assumed_pure_functions;
}
let fstar_extract_unop (extract_expr : bool -> texpression -> unit)
@@ -321,7 +325,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
let parts = List.map to_snake_case (get_name name) in
String.concat "_" parts
in
- let fun_name (_fid : A.fun_id) (fname : fun_name) (num_rgs : int)
+ let fun_name (fname : fun_name) (num_rgs : int)
(rg : region_group_info option) (filter_info : bool * int) : string =
let fname = fun_name_to_snake_case fname in
(* Compute the suffix *)
@@ -416,7 +420,6 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
char_name = "char";
int_name;
str_name = "string";
- assert_name = "massert";
field_name;
variant_name;
struct_constructor;
@@ -433,7 +436,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
}
(** [inside] constrols whether we should add parentheses or not around type
- application (if [true] we add parentheses).
+ applications (if [true] we add parentheses).
*)
let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(ty : ty) : unit =
@@ -928,7 +931,7 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
| Qualif qualif -> (
(* Top-level qualifier *)
match qualif.id with
- | Func fun_id ->
+ | FunOrOp fun_id ->
extract_function_call ctx fmt inside fun_id qualif.type_args args
| Global global_id -> extract_global ctx fmt global_id
| AdtCons adt_cons_id ->
@@ -957,7 +960,7 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(** Subcase of the app case: function call *)
and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
- (inside : bool) (fid : fun_id) (type_args : ty list)
+ (inside : bool) (fid : fun_or_op_id) (type_args : ty list)
(args : texpression list) : unit =
match (fid, args) with
| Unop unop, [ arg ] ->
@@ -971,17 +974,12 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
ctx.fmt.extract_binop
(extract_texpression ctx fmt)
fmt inside binop int_ty arg0 arg1
- | Regular (_, _), _ | Assert, [ _ ] ->
+ | Fun fun_id, _ ->
if inside then F.pp_print_string fmt "(";
(* Open a box for the function call *)
F.pp_open_hovbox fmt ctx.indent_incr;
(* Print the function name *)
- let fun_name =
- match fid with
- | Regular (fun_id, rg_id) -> ctx_get_function fun_id rg_id ctx
- | Assert -> ctx.fmt.assert_name
- | _ -> raise (Failure "Unreachable")
- in
+ let fun_name = ctx_get_function fun_id ctx in
F.pp_print_string fmt fun_name;
(* Print the type parameters *)
List.iter
@@ -999,10 +997,10 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_close_box fmt ();
(* Return *)
if inside then F.pp_print_string fmt ")"
- | (Unop _ | Binop _ | Assert), _ ->
+ | (Unop _ | Binop _), _ ->
raise
(Failure
- ("Unreachable:\n" ^ "Function: " ^ show_fun_id fid
+ ("Unreachable:\n" ^ "Function: " ^ show_fun_or_op_id fid
^ ",\nNumber of arguments: "
^ string_of_int (List.length args)
^ ",\nArguments: "
@@ -1576,7 +1574,9 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_space fmt ();
let decl_name = ctx_get_global global.def_id ctx in
- let body_name = ctx_get_function (Regular global.body_id) None ctx in
+ let body_name =
+ ctx_get_function (FromLlbc (Regular global.body_id, None)) ctx
+ in
let decl_ty, body_ty =
let ty = body.signature.output in
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index 0d1288d7..b4ab26b8 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -386,30 +386,39 @@ let inst_fun_sig_to_string (fmt : ast_formatter) (sg : inst_fun_sig) : string =
let all_types = List.append inputs [ output ] in
String.concat " -> " all_types
-let regular_fun_id_to_string (fmt : ast_formatter) (fun_id : A.fun_id) : string
- =
- match fun_id with
- | A.Regular fid -> fmt.fun_decl_id_to_string fid
- | A.Assumed fid -> (
- match fid with
- | A.Replace -> "core::mem::replace"
- | A.BoxNew -> "alloc::boxed::Box::new"
- | A.BoxDeref -> "core::ops::deref::Deref::deref"
- | A.BoxDerefMut -> "core::ops::deref::DerefMut::deref_mut"
- | A.BoxFree -> "alloc::alloc::box_free"
- | A.VecNew -> "alloc::vec::Vec::new"
- | A.VecPush -> "alloc::vec::Vec::push"
- | A.VecInsert -> "alloc::vec::Vec::insert"
- | A.VecLen -> "alloc::vec::Vec::len"
- | A.VecIndex -> "core::ops::index::Index<alloc::vec::Vec>::index"
- | A.VecIndexMut ->
- "core::ops::index::IndexMut<alloc::vec::Vec>::index_mut")
-
let fun_suffix (rg_id : T.RegionGroupId.id option) : string =
match rg_id with
| None -> ""
| Some rg_id -> "@" ^ T.RegionGroupId.to_string rg_id
+let llbc_assumed_fun_id_to_string (fid : A.assumed_fun_id) : string =
+ match fid with
+ | A.Replace -> "core::mem::replace"
+ | A.BoxNew -> "alloc::boxed::Box::new"
+ | A.BoxDeref -> "core::ops::deref::Deref::deref"
+ | A.BoxDerefMut -> "core::ops::deref::DerefMut::deref_mut"
+ | A.BoxFree -> "alloc::alloc::box_free"
+ | A.VecNew -> "alloc::vec::Vec::new"
+ | A.VecPush -> "alloc::vec::Vec::push"
+ | A.VecInsert -> "alloc::vec::Vec::insert"
+ | A.VecLen -> "alloc::vec::Vec::len"
+ | A.VecIndex -> "core::ops::index::Index<alloc::vec::Vec>::index"
+ | A.VecIndexMut -> "core::ops::index::IndexMut<alloc::vec::Vec>::index_mut"
+
+let pure_assumed_fun_id_to_string (fid : pure_assumed_fun_id) : string =
+ match fid with Return -> "return" | Fail -> "fail" | Assert -> "assert"
+
+let regular_fun_id_to_string (fmt : ast_formatter) (fun_id : fun_id) : string =
+ match fun_id with
+ | FromLlbc (fid, rg_id) ->
+ let f =
+ match fid with
+ | Regular fid -> fmt.fun_decl_id_to_string fid
+ | Assumed fid -> llbc_assumed_fun_id_to_string fid
+ in
+ f ^ fun_suffix rg_id
+ | Pure fid -> pure_assumed_fun_id_to_string fid
+
let unop_to_string (unop : unop) : string =
match unop with
| Not -> "¬"
@@ -420,15 +429,13 @@ let unop_to_string (unop : unop) : string =
let binop_to_string = Print.Expressions.binop_to_string
-let fun_id_to_string (fmt : ast_formatter) (fun_id : fun_id) : string =
+let fun_or_op_id_to_string (fmt : ast_formatter) (fun_id : fun_or_op_id) :
+ string =
match fun_id with
- | Regular (fun_id, rg_id) ->
- let f = regular_fun_id_to_string fmt fun_id in
- f ^ fun_suffix rg_id
+ | Fun fun_id -> regular_fun_id_to_string fmt fun_id
| Unop unop -> unop_to_string unop
| Binop (binop, int_ty) ->
binop_to_string binop ^ "<" ^ integer_type_to_string int_ty ^ ">"
- | Assert -> "assert"
(** [inside]: controls the introduction of parentheses *)
let rec texpression_to_string (fmt : ast_formatter) (inside : bool)
@@ -478,7 +485,7 @@ and app_to_string (fmt : ast_formatter) (inside : bool) (indent : string)
(* Convert the qualifier identifier *)
let qualif_s =
match qualif.id with
- | Func fun_id -> fun_id_to_string fmt fun_id
+ | FunOrOp fun_id -> fun_or_op_id_to_string fmt fun_id
| Global global_id -> fmt.global_decl_id_to_string global_id
| AdtCons adt_cons_id ->
let variant_s =
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index f11397e9..a50dd5f9 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -281,22 +281,30 @@ and typed_pattern = { value : pattern; ty : ty }
type unop = Not | Neg of integer_type | Cast of integer_type * integer_type
[@@deriving show, ord]
+(** Identifiers of assumed functions that we use only in the pure translation *)
+type pure_assumed_fun_id =
+ | Return (** The monadic return *)
+ | Fail (** The monadic fail *)
+ | Assert (** Assertion *)
+[@@deriving show, ord]
+
+(** A function identifier *)
type fun_id =
- | Regular of A.fun_id * T.RegionGroupId.id option
- (** Backward id: [Some] if the function is a backward function, [None]
- if it is a forward function.
-
- TODO: we need to redefine A.fun_id here, to add [fail] and
- [return] (important to get a unified treatment of the state-error
- monad). For now, when using the state-error monad: extraction
- works only if we unfold all the monadic let-bindings, and we
- then replace the content of the occurrences of [Return] to also
- return the state (which is really super ugly).
- TODO: also add Assert...
+ | FromLlbc of A.fun_id * T.RegionGroupId.id option
+ (** A function coming from LLBC.
+
+ The region group id is the backward id:: [Some] if the function is a
+ backward function, [None] if it is a forward function.
*)
+ | Pure of pure_assumed_fun_id
+ (** A function only used in the pure translation *)
+[@@deriving show, ord]
+
+(** A function or an operation id *)
+type fun_or_op_id =
+ | Fun of fun_id
| Unop of unop
| Binop of E.binop * integer_type
- | Assert
[@@deriving show, ord]
(** An identifier for an ADT constructor *)
@@ -309,7 +317,7 @@ type adt_cons_id = { adt_id : type_id; variant_id : variant_id option }
type projection = { adt_id : type_id; field_id : FieldId.id } [@@deriving show]
type qualif_id =
- | Func of fun_id
+ | FunOrOp of fun_or_op_id (** A function or an operation *)
| Global of GlobalDeclId.id
| AdtCons of adt_cons_id (** A function or ADT constructor identifier *)
| Proj of projection (** Field projector *)
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 239b0b4f..9d604626 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -610,9 +610,9 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool)
match qualif.id with
| AdtCons _ -> true (* ADT constructor *)
| Proj _ -> true (* Projector *)
- | Func (Unop _ | Binop _) ->
+ | FunOrOp (Unop _ | Binop _) ->
true (* primitive function call *)
- | Func (Regular _) -> false (* non-primitive function call *)
+ | FunOrOp (Fun _) -> false (* non-primitive function call *)
| _ -> false)
| _ -> false
in
@@ -667,24 +667,25 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool)
fn f<'a>(x : &'a mut T);
]}
- We often have things like this in the synthesized code:
+ We often have things like this in the synthesized code:
{[
- _ <-- f x;
+ _ <-- f@fwd x;
...
nx <-- f@back'a x y;
...
]}
- In this situation, we can remove the call [f x].
+ If [f@back'a x y] fails, then necessarily [f@fwd x] also fails.
+ In this situation, we can remove the call [f@fwd x].
*)
let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
- (fun_id0 : fun_id) (tys0 : ty list) (args0 : texpression list)
- (e : texpression) : bool =
- let check_call (fun_id1 : fun_id) (tys1 : ty list) (args1 : texpression list)
- : bool =
+ (id0 : A.fun_id) (rg_id0 : T.RegionGroupId.id option) (tys0 : ty list)
+ (args0 : texpression list) (e : texpression) : bool =
+ let check_call (fun_id1 : fun_or_op_id) (tys1 : ty list)
+ (args1 : texpression list) : bool =
(* Check the fun_ids, to see if call1's function is a child of call0's function *)
- match (fun_id0, fun_id1) with
- | Regular (id0, rg_id0), Regular (id1, rg_id1) ->
+ match fun_id1 with
+ | Fun (FromLlbc (id1, rg_id1)) ->
(* Both are "regular" calls: check if they come from the same rust function *)
if id0 = id1 then
(* Same rust functions: check the regions hierarchy *)
@@ -858,20 +859,20 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
* We can filter if the right-expression is a function call,
* under some conditions. *)
match (filter_monadic_calls, opt_destruct_function_call re) with
- | true, Some (func, tys, args) ->
+ | true, Some (Fun (FromLlbc (fid, rg_id)), tys, args) ->
(* We need to check if there is a child call - see
* the comments for:
* [expression_contains_child_call_in_all_paths] *)
let has_child_call =
- expression_contains_child_call_in_all_paths ctx func tys
- args e
+ expression_contains_child_call_in_all_paths ctx fid rg_id
+ tys args e
in
if has_child_call then (* Filter *)
(e.e, fun _ -> used)
else (* No child call: don't filter *)
dont_filter ()
| _ ->
- (* Not a call or not allowed to filter: we can't filter *)
+ (* Not an LLBC function call or not allowed to filter: we can't filter *)
dont_filter ()
else (* There are used variables: don't filter *)
dont_filter ()
@@ -1088,7 +1089,7 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
match opt_destruct_function_call e with
| Some (fun_id, _tys, args) -> (
match fun_id with
- | Regular (A.Assumed aid, rg_id) -> (
+ | Fun (FromLlbc (A.Assumed aid, rg_id)) -> (
(* Below, when dealing with the arguments: we consider the very
* general case, where functions could be boxed (meaning we
* could have: [box_new f x])
diff --git a/compiler/PureToExtract.ml b/compiler/PureToExtract.ml
index 860949a7..25ad6713 100644
--- a/compiler/PureToExtract.ml
+++ b/compiler/PureToExtract.ml
@@ -51,7 +51,6 @@ type formatter = {
char_name : string;
int_name : integer_type -> string;
str_name : string;
- assert_name : string;
field_name : name -> FieldId.id -> string option -> string;
(** Inputs:
- type name
@@ -86,16 +85,12 @@ type formatter = {
global_name : global_name -> string;
(** Provided a basename, compute a global name. *)
fun_name :
- A.fun_id ->
- fun_name ->
- int ->
- region_group_info option ->
- bool * int ->
- string;
- (** Inputs:
- - function id: this is especially useful to identify whether the
- function is an assumed function or a local function
- - function basename
+ fun_name -> int -> region_group_info option -> bool * int -> string;
+ (** Compute the name of a regular (non-assumed) function.
+
+ Inputs:
+ - function id
+ - function basename (TODO: shouldn't appear for assumed functions?...)
- number of region groups
- region group information in case of a backward function
([None] if forward function)
@@ -191,7 +186,7 @@ type formatter = {
(** We use identifiers to look for name clashes *)
type id =
| GlobalId of A.GlobalDeclId.id
- | FunId of A.fun_id * RegionGroupId.id option
+ | FunId of fun_id
| DecreasesClauseId of A.fun_id
(** The definition which provides the decreases/termination clause.
We insert calls to this clause to prove/reason about termination:
@@ -290,10 +285,9 @@ let names_map_add_assumed_variant (id_to_string : id -> string)
(nm : names_map) : names_map =
names_map_add id_to_string (VariantId (Assumed id, variant_id)) name nm
-let names_map_add_assumed_function (id_to_string : id -> string)
- (fid : A.assumed_fun_id) (rg_id : RegionGroupId.id option) (name : string)
- (nm : names_map) : names_map =
- names_map_add id_to_string (FunId (A.Assumed fid, rg_id)) name nm
+let names_map_add_function (id_to_string : id -> string) (fid : fun_id)
+ (name : string) (nm : names_map) : names_map =
+ names_map_add id_to_string (FunId fid) name nm
(** Make a (variable) basename unique (by adding an index).
@@ -360,25 +354,29 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
| GlobalId gid ->
let name = (A.GlobalDeclId.Map.find gid global_decls).name in
"global name: " ^ Print.global_name_to_string name
- | FunId (fid, rg_id) ->
- let fun_name =
- match fid with
- | A.Regular fid ->
- Print.fun_name_to_string (A.FunDeclId.Map.find fid fun_decls).name
- | A.Assumed aid -> A.show_assumed_fun_id aid
- in
- let fun_kind =
- match rg_id with
- | None -> "forward"
- | Some rg_id -> "backward " ^ RegionGroupId.to_string rg_id
- in
- "fun name (" ^ fun_kind ^ "): " ^ fun_name
+ | FunId fid -> (
+ match fid with
+ | FromLlbc (fid, rg_id) ->
+ let fun_name =
+ match fid with
+ | Regular fid ->
+ Print.fun_name_to_string
+ (A.FunDeclId.Map.find fid fun_decls).name
+ | Assumed aid -> A.show_assumed_fun_id aid
+ in
+ let fun_kind =
+ match rg_id with
+ | None -> "forward"
+ | Some rg_id -> "backward " ^ RegionGroupId.to_string rg_id
+ in
+ "fun name (" ^ fun_kind ^ "): " ^ fun_name
+ | Pure fid -> PrintPure.pure_assumed_fun_id_to_string fid)
| DecreasesClauseId fid ->
let fun_name =
match fid with
- | A.Regular fid ->
+ | Regular fid ->
Print.fun_name_to_string (A.FunDeclId.Map.find fid fun_decls).name
- | A.Assumed aid -> A.show_assumed_fun_id aid
+ | Assumed aid -> A.show_assumed_fun_id aid
in
"decreases clause for function: " ^ fun_name
| TypeId id -> "type name: " ^ get_type_name id
@@ -451,13 +449,12 @@ let ctx_get (id : id) (ctx : extraction_ctx) : string =
let ctx_get_global (id : A.GlobalDeclId.id) (ctx : extraction_ctx) : string =
ctx_get (GlobalId id) ctx
-let ctx_get_function (id : A.fun_id) (rg : RegionGroupId.id option)
- (ctx : extraction_ctx) : string =
- ctx_get (FunId (id, rg)) ctx
+let ctx_get_function (id : fun_id) (ctx : extraction_ctx) : string =
+ ctx_get (FunId id) ctx
let ctx_get_local_function (id : A.FunDeclId.id) (rg : RegionGroupId.id option)
(ctx : extraction_ctx) : string =
- ctx_get_function (A.Regular id) rg ctx
+ ctx_get_function (FromLlbc (Regular id, rg)) ctx
let ctx_get_type (id : type_id) (ctx : extraction_ctx) : string =
assert (id <> Tuple);
@@ -488,7 +485,7 @@ let ctx_get_variant (def_id : type_id) (variant_id : VariantId.id)
let ctx_get_decreases_clause (def_id : A.FunDeclId.id) (ctx : extraction_ctx) :
string =
- ctx_get (DecreasesClauseId (A.Regular def_id)) ctx
+ ctx_get (DecreasesClauseId (Regular def_id)) ctx
(** Generate a unique type variable name and add it to the context *)
let ctx_add_type_var (basename : string) (id : TypeVarId.id)
@@ -577,13 +574,13 @@ let ctx_add_struct (def : type_decl) (ctx : extraction_ctx) :
let ctx_add_decrases_clause (def : fun_decl) (ctx : extraction_ctx) :
extraction_ctx =
let name = ctx.fmt.decreases_clause_name def.def_id def.basename in
- ctx_add (DecreasesClauseId (A.Regular def.def_id)) name ctx
+ ctx_add (DecreasesClauseId (Regular def.def_id)) name ctx
let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) :
extraction_ctx =
let name = ctx.fmt.global_name def.name in
let decl = GlobalId def.def_id in
- let body = FunId (Regular def.body_id, None) in
+ let body = FunId (FromLlbc (Regular def.body_id, None)) in
let ctx = ctx_add decl (name ^ "_c") ctx in
let ctx = ctx_add body (name ^ "_body") ctx in
ctx
@@ -617,18 +614,19 @@ let ctx_add_fun_decl (trans_group : bool * pure_fun_translation)
in
Some { id = rg_id; region_names }
in
- let def_id = A.Regular def_id in
let name =
- ctx.fmt.fun_name def_id def.basename num_rgs rg_info (keep_fwd, num_backs)
+ ctx.fmt.fun_name def.basename num_rgs rg_info (keep_fwd, num_backs)
in
- ctx_add (FunId (def_id, def.back_id)) name ctx
+ ctx_add (FunId (FromLlbc (A.Regular def_id, def.back_id))) name ctx
type names_map_init = {
keywords : string list;
assumed_adts : (assumed_ty * string) list;
assumed_structs : (assumed_ty * string) list;
assumed_variants : (assumed_ty * VariantId.id * string) list;
- assumed_functions : (A.assumed_fun_id * RegionGroupId.id option * string) list;
+ assumed_llbc_functions :
+ (A.assumed_fun_id * RegionGroupId.id option * string) list;
+ assumed_pure_functions : (pure_assumed_fun_id * string) list;
}
(** Initialize a names map with a proper set of keywords/names coming from the
@@ -638,9 +636,7 @@ let initialize_names_map (fmt : formatter) (init : names_map_init) : names_map =
let keywords =
List.concat
[
- [ fmt.bool_name; fmt.char_name; fmt.str_name; fmt.assert_name ];
- int_names;
- init.keywords;
+ [ fmt.bool_name; fmt.char_name; fmt.str_name ]; int_names; init.keywords;
]
in
let names_set = StringSet.of_list keywords in
@@ -680,11 +676,16 @@ let initialize_names_map (fmt : formatter) (init : names_map_init) : names_map =
names_map_add_assumed_variant id_to_string type_id variant_id name nm)
nm init.assumed_variants
in
+ let assumed_functions =
+ List.map
+ (fun (fid, rg, name) -> (FromLlbc (A.Assumed fid, rg), name))
+ init.assumed_llbc_functions
+ @ List.map (fun (fid, name) -> (Pure fid, name)) init.assumed_pure_functions
+ in
let nm =
List.fold_left
- (fun nm (fun_id, rg_id, name) ->
- names_map_add_assumed_function id_to_string fun_id rg_id name nm)
- nm init.assumed_functions
+ (fun nm (fid, name) -> names_map_add_function id_to_string fid name nm)
+ nm assumed_functions
in
(* Return *)
nm
diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml
index 8c19a53b..6b6a82ad 100644
--- a/compiler/PureTypeCheck.ml
+++ b/compiler/PureTypeCheck.ml
@@ -113,7 +113,7 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
check_texpression ctx body
| Qualif qualif -> (
match qualif.id with
- | Func _ -> () (* TODO *)
+ | FunOrOp _ -> () (* TODO *)
| Global _ -> () (* TODO *)
| Proj { adt_id = proj_adt_id; field_id } ->
(* Note we can only project fields of structures (not enumerations) *)
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index e292576c..ff379bf5 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -18,17 +18,17 @@ end
module RegularFunIdMap = Collections.MakeMap (RegularFunIdOrderedType)
-module FunIdOrderedType = struct
- type t = fun_id
+module FunOrOpIdOrderedType = struct
+ type t = fun_or_op_id
- let compare = compare_fun_id
- let to_string = show_fun_id
- let pp_t = pp_fun_id
- let show_t = show_fun_id
+ let compare = compare_fun_or_op_id
+ let to_string = show_fun_or_op_id
+ let pp_t = pp_fun_or_op_id
+ let show_t = show_fun_or_op_id
end
-module FunIdMap = Collections.MakeMap (FunIdOrderedType)
-module FunIdSet = Collections.MakeSet (FunIdOrderedType)
+module FunOrOpIdMap = Collections.MakeMap (FunOrOpIdOrderedType)
+module FunOrOpIdSet = Collections.MakeSet (FunOrOpIdOrderedType)
let dest_arrow_ty (ty : ty) : ty * ty =
match ty with
@@ -114,23 +114,26 @@ let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) :
This function is meant to be applied on a set of (forward, backwards) functions
generated for one recursive function.
The way we do the test is very simple:
- - we explore the functions one by one, in the order
+ - we explore the functions one by one, in the order in which they are provided
- if all functions only call functions we already explored, they are not
mutually recursive
*)
let functions_not_mutually_recursive (funs : fun_decl list) : bool =
(* Compute the set of function identifiers in the group *)
let ids =
- FunIdSet.of_list
+ FunOrOpIdSet.of_list
(List.map
- (fun (f : fun_decl) -> Regular (A.Regular f.def_id, f.back_id))
+ (fun (f : fun_decl) -> Fun (FromLlbc (A.Regular f.def_id, f.back_id)))
funs)
in
let ids = ref ids in
(* Explore every body *)
let body_only_calls_itself (fdef : fun_decl) : bool =
(* Remove the current id from the id set *)
- ids := FunIdSet.remove (Regular (A.Regular fdef.def_id, fdef.back_id)) !ids;
+ ids :=
+ FunOrOpIdSet.remove
+ (Fun (FromLlbc (A.Regular fdef.def_id, fdef.back_id)))
+ !ids;
(* Check if we call functions from the updated id set *)
let obj =
@@ -139,8 +142,8 @@ let functions_not_mutually_recursive (funs : fun_decl list) : bool =
method! visit_qualif env qualif =
match qualif.id with
- | Func fun_id ->
- if FunIdSet.mem fun_id !ids then raise Utils.Found
+ | FunOrOp fun_id ->
+ if FunOrOpIdSet.mem fun_id !ids then raise Utils.Found
else super#visit_qualif env qualif
| _ -> super#visit_qualif env qualif
end
@@ -242,12 +245,12 @@ let destruct_qualif_app (e : texpression) : qualif * texpression list =
(** Destruct an expression into a function call, if possible *)
let opt_destruct_function_call (e : texpression) :
- (fun_id * ty list * texpression list) option =
+ (fun_or_op_id * ty list * texpression list) option =
match opt_destruct_qualif_app e with
| None -> None
| Some (qualif, args) -> (
match qualif.id with
- | Func fun_id -> Some (fun_id, qualif.type_args, args)
+ | FunOrOp fun_id -> Some (fun_id, qualif.type_args, args)
| _ -> None)
let opt_destruct_result (ty : ty) : ty option =
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 8329d80e..6d01614d 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -240,7 +240,7 @@ let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
(** [back_args]: the *additional* list of inputs received by the backward function *)
let bs_ctx_register_backward_call (abs : V.abs) (back_args : texpression list)
- (ctx : bs_ctx) : bs_ctx * fun_id =
+ (ctx : bs_ctx) : bs_ctx * fun_or_op_id =
(* Insert the abstraction in the call informations *)
let back_id = abs.back_id in
let info = V.FunCallId.Map.find abs.call_id ctx.calls in
@@ -259,7 +259,7 @@ let bs_ctx_register_backward_call (abs : V.abs) (back_args : texpression list)
(* Retrieve the fun_id *)
let fun_id =
match info.forward.call_id with
- | S.Fun (fid, _) -> Regular (fid, Some abs.back_id)
+ | S.Fun (fid, _) -> Fun (FromLlbc (fid, Some abs.back_id))
| S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
in
(* Update the context and return *)
@@ -1167,7 +1167,7 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression)
match call.call_id with
| S.Fun (fid, call_id) ->
(* Regular function call *)
- let func = Regular (fid, None) in
+ let func = Fun (FromLlbc (fid, None)) in
(* Retrieve the effect information about this function (can fail,
* takes a state as input, etc.) *)
let effect_info =
@@ -1235,7 +1235,7 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression)
| None -> dest
| Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ]
in
- let func = { id = Func fun_id; type_args } in
+ let func = { id = FunOrOp fun_id; type_args } in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
let ret_ty =
if effect_info.can_fail then mk_result_ty dest_v.ty else dest_v.ty
@@ -1390,7 +1390,7 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression)
if effect_info.can_fail then mk_result_ty output.ty else output.ty
in
let func_ty = mk_arrows input_tys ret_ty in
- let func = { id = Func func; type_args } in
+ let func = { id = FunOrOp func; type_args } in
let func = { e = Qualif func; ty = func_ty } in
let call = mk_apps func args in
(* **Optimization**:
@@ -1487,7 +1487,7 @@ and translate_assertion (config : config) (v : V.typed_value) (e : S.expression)
let monadic = true in
let v = typed_value_to_texpression ctx v in
let args = [ v ] in
- let func = { id = Func Assert; type_args = [] } in
+ let func = { id = FunOrOp (Fun (Pure Assert)); type_args = [] } in
let func_ty = mk_arrow Bool mk_unit_ty in
let func = { e = Qualif func; ty = func_ty } in
let assertion = mk_apps func args in