summaryrefslogtreecommitdiff
path: root/compiler/SymbolicToPure.ml
diff options
context:
space:
mode:
authorSon HO2023-11-10 18:21:06 +0100
committerGitHub2023-11-10 18:21:06 +0100
commit587f1ebc0178acb19029d3fc9a729c197082aba7 (patch)
treef29805e5426f9f3fabe12d3fdadda96a1e987880 /compiler/SymbolicToPure.ml
parent7fc7c82aa61d782b335e7cf37231fd9998cd0d89 (diff)
parentd300be95c28ff3147bb6f6a65992df5b9b571bdf (diff)
Merge pull request #44 from AeneasVerif/son_traits_types
Add support for traits
Diffstat (limited to '')
-rw-r--r--compiler/SymbolicToPure.ml830
1 files changed, 559 insertions, 271 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 3512270a..2ce8c706 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -4,6 +4,7 @@ open Pure
open PureUtils
module Id = Identifiers
module C = Contexts
+module A = LlbcAst
module S = SymbolicAst
module TA = TypesAnalysis
module L = Logging
@@ -52,6 +53,9 @@ type fun_context = {
type global_context = { llbc_global_decls : A.global_decl A.GlobalDeclId.Map.t }
[@@deriving show]
+type trait_decls_context = A.trait_decl A.TraitDeclId.Map.t [@@deriving show]
+type trait_impls_context = A.trait_impl A.TraitImplId.Map.t [@@deriving show]
+
(** Whenever we translate a function call or an ended abstraction, we
store the related information (this is useful when translating ended
children abstractions).
@@ -106,8 +110,7 @@ type loop_info = {
loop_id : LoopId.id;
input_vars : var list;
input_svl : V.symbolic_value list;
- type_args : ty list;
- const_generic_args : const_generic list;
+ generics : generic_args;
forward_inputs : texpression list option;
(** The forward inputs are initialized at [None] *)
forward_output_no_state_no_result : var option;
@@ -120,6 +123,8 @@ type bs_ctx = {
type_context : type_context;
fun_context : fun_context;
global_context : global_context;
+ trait_decls_ctx : trait_decls_context;
+ trait_impls_ctx : trait_impls_context;
fun_decl : A.fun_decl;
bid : T.RegionGroupId.id option; (** TODO: rename *)
sg : fun_sig;
@@ -201,34 +206,11 @@ type bs_ctx = {
}
[@@deriving show]
-let type_check_pattern (ctx : bs_ctx) (v : typed_pattern) : unit =
- let env = VarId.Map.empty in
- let ctx =
- {
- PureTypeCheck.type_decls = ctx.type_context.type_decls;
- global_decls = ctx.global_context.llbc_global_decls;
- env;
- }
- in
- let _ = PureTypeCheck.check_typed_pattern ctx v in
- ()
-
-let type_check_texpression (ctx : bs_ctx) (e : texpression) : unit =
- let env = VarId.Map.empty in
- let ctx =
- {
- PureTypeCheck.type_decls = ctx.type_context.type_decls;
- global_decls = ctx.global_context.llbc_global_decls;
- env;
- }
- in
- PureTypeCheck.check_texpression ctx e
-
(* TODO: move *)
let bs_ctx_to_ast_formatter (ctx : bs_ctx) : Print.Ast.ast_formatter =
Print.Ast.decls_and_fun_decl_to_ast_formatter ctx.type_context.llbc_type_decls
ctx.fun_context.llbc_fun_decls ctx.global_context.llbc_global_decls
- ctx.fun_decl
+ ctx.trait_decls_ctx ctx.trait_impls_ctx ctx.fun_decl
let bs_ctx_to_ctx_formatter (ctx : bs_ctx) : Print.Contexts.ctx_formatter =
let rvar_to_string = Print.Types.region_var_id_to_string in
@@ -246,16 +228,25 @@ let bs_ctx_to_ctx_formatter (ctx : bs_ctx) : Print.Contexts.ctx_formatter =
adt_variant_to_string = ast_fmt.adt_variant_to_string;
var_id_to_string;
adt_field_names = ast_fmt.adt_field_names;
+ trait_decl_id_to_string = ast_fmt.trait_decl_id_to_string;
+ trait_impl_id_to_string = ast_fmt.trait_impl_id_to_string;
+ trait_clause_id_to_string = ast_fmt.trait_clause_id_to_string;
}
let bs_ctx_to_pp_ast_formatter (ctx : bs_ctx) : PrintPure.ast_formatter =
- let type_params = ctx.fun_decl.signature.type_params in
- let cg_params = ctx.fun_decl.signature.const_generic_params in
+ let generics = ctx.fun_decl.signature.generics in
let type_decls = ctx.type_context.llbc_type_decls in
let fun_decls = ctx.fun_context.llbc_fun_decls in
let global_decls = ctx.global_context.llbc_global_decls in
- PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params
- cg_params
+ PrintPure.mk_ast_formatter type_decls fun_decls global_decls
+ ctx.trait_decls_ctx ctx.trait_impls_ctx generics.types
+ generics.const_generics
+
+let ctx_egeneric_args_to_string (ctx : bs_ctx) (args : T.egeneric_args) : string
+ =
+ let fmt = bs_ctx_to_ctx_formatter ctx in
+ let fmt = Print.PC.ctx_to_etype_formatter fmt in
+ Print.PT.egeneric_args_to_string fmt args
let symbolic_value_to_string (ctx : bs_ctx) (sv : V.symbolic_value) : string =
let fmt = bs_ctx_to_ctx_formatter ctx in
@@ -277,12 +268,11 @@ let rty_to_string (ctx : bs_ctx) (ty : T.rty) : string =
Print.PT.rty_to_string fmt ty
let type_decl_to_string (ctx : bs_ctx) (def : type_decl) : string =
- let type_params = def.type_params in
- let cg_params = def.const_generic_params in
let type_decls = ctx.type_context.llbc_type_decls in
let global_decls = ctx.global_context.llbc_global_decls in
let fmt =
- PrintPure.mk_type_formatter type_decls global_decls type_params cg_params
+ PrintPure.mk_type_formatter type_decls global_decls ctx.trait_decls_ctx
+ ctx.trait_impls_ctx def.generics.types def.generics.const_generics
in
PrintPure.type_decl_to_string fmt def
@@ -291,26 +281,27 @@ let texpression_to_string (ctx : bs_ctx) (e : texpression) : string =
PrintPure.texpression_to_string fmt false "" " " e
let fun_sig_to_string (ctx : bs_ctx) (sg : fun_sig) : string =
- let type_params = sg.type_params in
- let cg_params = sg.const_generic_params in
+ let type_params = sg.generics.types in
+ let cg_params = sg.generics.const_generics in
let type_decls = ctx.type_context.llbc_type_decls in
let fun_decls = ctx.fun_context.llbc_fun_decls in
let global_decls = ctx.global_context.llbc_global_decls in
let fmt =
- PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params
- cg_params
+ PrintPure.mk_ast_formatter type_decls fun_decls global_decls
+ ctx.trait_decls_ctx ctx.trait_impls_ctx type_params cg_params
in
PrintPure.fun_sig_to_string fmt sg
let fun_decl_to_string (ctx : bs_ctx) (def : Pure.fun_decl) : string =
- let type_params = def.signature.type_params in
- let cg_params = def.signature.const_generic_params in
+ let generics = def.signature.generics in
+ let type_params = generics.types in
+ let cg_params = generics.const_generics in
let type_decls = ctx.type_context.llbc_type_decls in
let fun_decls = ctx.fun_context.llbc_fun_decls in
let global_decls = ctx.global_context.llbc_global_decls in
let fmt =
- PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params
- cg_params
+ PrintPure.mk_ast_formatter type_decls fun_decls global_decls
+ ctx.trait_decls_ctx ctx.trait_impls_ctx type_params cg_params
in
PrintPure.fun_decl_to_string fmt def
@@ -328,17 +319,18 @@ let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string =
Print.Values.abs_to_string fmt verbose indent indent_incr abs
let get_instantiated_fun_sig (fun_id : A.fun_id)
- (back_id : T.RegionGroupId.id option) (tys : ty list)
- (cgs : const_generic list) (ctx : bs_ctx) : inst_fun_sig =
+ (back_id : T.RegionGroupId.id option) (generics : generic_args)
+ (ctx : bs_ctx) : inst_fun_sig =
(* Lookup the non-instantiated function signature *)
let sg =
(RegularFunIdNotLoopMap.find (fun_id, back_id) ctx.fun_context.fun_sigs).sg
in
(* Create the substitution *)
- let tsubst = make_type_subst sg.type_params tys in
- let cgsubst = make_const_generic_subst sg.const_generic_params cgs in
+ (* There shouldn't be any reference to Self *)
+ let tr_self = UnknownTrait __FUNCTION__ in
+ let subst = make_subst_from_generics sg.generics generics tr_self in
(* Apply *)
- fun_sig_substitute tsubst cgsubst sg
+ fun_sig_substitute subst sg
let bs_ctx_lookup_llbc_type_decl (id : TypeDeclId.id) (ctx : bs_ctx) :
T.type_decl =
@@ -351,77 +343,128 @@ let bs_ctx_lookup_llbc_fun_decl (id : A.FunDeclId.id) (ctx : bs_ctx) :
(* TODO: move *)
let bs_ctx_lookup_local_function_sig (def_id : A.FunDeclId.id)
(back_id : T.RegionGroupId.id option) (ctx : bs_ctx) : fun_sig =
- let id = (A.Regular def_id, back_id) in
+ let id = (E.Regular def_id, back_id) in
(RegularFunIdNotLoopMap.find id ctx.fun_context.fun_sigs).sg
-let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
- (args : texpression list) (ctx : bs_ctx) : bs_ctx =
- let calls = ctx.calls in
- assert (not (V.FunCallId.Map.mem call_id calls));
- let info =
- { forward; forward_inputs = args; backwards = T.RegionGroupId.Map.empty }
- in
- let calls = V.FunCallId.Map.add call_id info calls in
- { ctx with calls }
-
-(** [back_args]: the *additional* list of inputs received by the backward function *)
-let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
- (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx)
- : bs_ctx * fun_or_op_id =
- (* Insert the abstraction in the call informations *)
- let info = V.FunCallId.Map.find call_id ctx.calls in
- assert (not (T.RegionGroupId.Map.mem back_id info.backwards));
- let backwards =
- T.RegionGroupId.Map.add back_id (abs, back_args) info.backwards
- in
- let info = { info with backwards } in
- let calls = V.FunCallId.Map.add call_id info ctx.calls in
- (* Insert the abstraction in the abstractions map *)
- let abstractions = ctx.abstractions in
- assert (not (V.AbstractionId.Map.mem abs.abs_id abstractions));
- let abstractions =
- V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions
- in
- (* Retrieve the fun_id *)
- let fun_id =
- match info.forward.call_id with
- | S.Fun (fid, _) -> Fun (FromLlbc (fid, None, Some back_id))
- | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
- in
- (* Update the context and return *)
- ({ ctx with calls; abstractions }, fun_id)
+(* Some generic translation functions (we need to translate different "flavours"
+ of types: sty, forward types, backward types, etc.) *)
+let rec translate_generic_args (translate_ty : 'r T.ty -> ty)
+ (generics : 'r T.generic_args) : generic_args =
+ (* We ignore the regions: if they didn't cause trouble for the symbolic execution,
+ then everything's fine *)
+ let types = List.map translate_ty generics.types in
+ let const_generics = generics.const_generics in
+ let trait_refs =
+ List.map (translate_trait_ref translate_ty) generics.trait_refs
+ in
+ { types; const_generics; trait_refs }
+
+and translate_trait_ref (translate_ty : 'r T.ty -> ty) (tr : 'r T.trait_ref) :
+ trait_ref =
+ let trait_id = translate_trait_instance_id translate_ty tr.trait_id in
+ let generics = translate_generic_args translate_ty tr.generics in
+ let trait_decl_ref =
+ translate_trait_decl_ref translate_ty tr.trait_decl_ref
+ in
+ { trait_id; generics; trait_decl_ref }
+
+and translate_trait_decl_ref (translate_ty : 'r T.ty -> ty)
+ (tr : 'r T.trait_decl_ref) : trait_decl_ref =
+ let decl_generics = translate_generic_args translate_ty tr.decl_generics in
+ { trait_decl_id = tr.trait_decl_id; decl_generics }
+
+and translate_trait_instance_id (translate_ty : 'r T.ty -> ty)
+ (id : 'r T.trait_instance_id) : trait_instance_id =
+ let translate_trait_instance_id = translate_trait_instance_id translate_ty in
+ match id with
+ | T.Self -> Self
+ | TraitImpl id -> TraitImpl id
+ | BuiltinOrAuto _ ->
+ (* We should have eliminated those in the prepasses *)
+ raise (Failure "Unreachable")
+ | Clause id -> Clause id
+ | ParentClause (inst_id, decl_id, clause_id) ->
+ let inst_id = translate_trait_instance_id inst_id in
+ ParentClause (inst_id, decl_id, clause_id)
+ | ItemClause (inst_id, decl_id, item_name, clause_id) ->
+ let inst_id = translate_trait_instance_id inst_id in
+ ItemClause (inst_id, decl_id, item_name, clause_id)
+ | TraitRef tr -> TraitRef (translate_trait_ref translate_ty tr)
+ | FnPointer _ -> raise (Failure "TODO")
+ | UnknownTrait s -> raise (Failure ("Unknown trait found: " ^ s))
let rec translate_sty (ty : T.sty) : ty =
let translate = translate_sty in
match ty with
- | T.Adt (type_id, regions, tys, cgs) -> (
- (* Can't translate types with regions for now *)
- assert (regions = []);
- let tys = List.map translate tys in
+ | T.Adt (type_id, generics) -> (
+ let generics = translate_sgeneric_args generics in
match type_id with
- | T.AdtId adt_id -> Adt (AdtId adt_id, tys, cgs)
- | T.Tuple -> mk_simpl_tuple_ty tys
+ | T.AdtId adt_id -> Adt (AdtId adt_id, generics)
+ | T.Tuple ->
+ assert (generics.const_generics = []);
+ mk_simpl_tuple_ty generics.types
| T.Assumed aty -> (
match aty with
- | T.Vec -> Adt (Assumed Vec, tys, cgs)
- | T.Option -> Adt (Assumed Option, tys, cgs)
| T.Box -> (
(* Eliminate the boxes *)
- match tys with
+ match generics.types with
| [ ty ] -> ty
| _ ->
raise
(Failure
"Box/vec/option type with incorrect number of arguments")
)
- | T.Array -> Adt (Assumed Array, tys, cgs)
- | T.Slice -> Adt (Assumed Slice, tys, cgs)
- | T.Str -> Adt (Assumed Str, tys, cgs)
- | T.Range -> Adt (Assumed Range, tys, cgs)))
+ | T.Array -> Adt (Assumed Array, generics)
+ | T.Slice -> Adt (Assumed Slice, generics)
+ | T.Str -> Adt (Assumed Str, generics)))
| TypeVar vid -> TypeVar vid
| Literal ty -> Literal ty
| Never -> raise (Failure "Unreachable")
| Ref (_, rty, _) -> translate rty
+ | RawPtr (ty, rkind) ->
+ let mut = match rkind with Mut -> Mut | Shared -> Const in
+ let ty = translate ty in
+ let generics = { types = [ ty ]; const_generics = []; trait_refs = [] } in
+ Adt (Assumed (RawPtr mut), generics)
+ | TraitType (trait_ref, generics, type_name) ->
+ let trait_ref = translate_strait_ref trait_ref in
+ let generics = translate_sgeneric_args generics in
+ TraitType (trait_ref, generics, type_name)
+ | Arrow _ -> raise (Failure "TODO")
+
+and translate_sgeneric_args (generics : T.sgeneric_args) : generic_args =
+ translate_generic_args translate_sty generics
+
+and translate_strait_ref (tr : T.strait_ref) : trait_ref =
+ translate_trait_ref translate_sty tr
+
+and translate_strait_instance_id (id : T.strait_instance_id) : trait_instance_id
+ =
+ translate_trait_instance_id translate_sty id
+
+let translate_trait_clause (clause : T.trait_clause) : trait_clause =
+ let { T.clause_id; meta = _; trait_id; generics } = clause in
+ let generics = translate_sgeneric_args generics in
+ { clause_id; trait_id; generics }
+
+let translate_strait_type_constraint (ttc : T.strait_type_constraint) :
+ trait_type_constraint =
+ let { T.trait_ref; generics; type_name; ty } = ttc in
+ let trait_ref = translate_strait_ref trait_ref in
+ let generics = translate_sgeneric_args generics in
+ let ty = translate_sty ty in
+ { trait_ref; generics; type_name; ty }
+
+let translate_predicates (preds : T.predicates) : predicates =
+ let trait_type_constraints =
+ List.map translate_strait_type_constraint preds.trait_type_constraints
+ in
+ { trait_type_constraints }
+
+let translate_generic_params (generics : T.generic_params) : generic_params =
+ let { T.regions = _; types; const_generics; trait_clauses } = generics in
+ let trait_clauses = List.map translate_trait_clause trait_clauses in
+ { types; const_generics; trait_clauses }
let translate_field (f : T.field) : field =
let field_name = f.field_name in
@@ -452,15 +495,16 @@ let translate_type_decl_kind (kind : T.type_decl_kind) : type_decl_kind =
point of moving this definition for now.
*)
let translate_type_decl (def : T.type_decl) : type_decl =
- (* Translate *)
let def_id = def.T.def_id in
let name = def.name in
+ let { T.regions; types; const_generics; trait_clauses } = def.generics in
(* Can't translate types with regions for now *)
- assert (def.region_params = []);
- let type_params = def.type_params in
- let const_generic_params = def.const_generic_params in
+ assert (regions = []);
+ let trait_clauses = List.map translate_trait_clause trait_clauses in
+ let generics = { types; const_generics; trait_clauses } in
let kind = translate_type_decl_kind def.T.kind in
- { def_id; name; type_params; const_generic_params; kind }
+ let preds = translate_predicates def.preds in
+ { def_id; name; generics; kind; preds }
let translate_type_id (id : T.type_id) : type_id =
match id with
@@ -468,12 +512,9 @@ let translate_type_id (id : T.type_id) : type_id =
| T.Assumed aty ->
let aty =
match aty with
- | T.Vec -> Vec
- | T.Option -> Option
| T.Array -> Array
| T.Slice -> Slice
| T.Str -> Str
- | T.Range -> Range
| T.Box ->
(* Boxes have to be eliminated: this type id shouldn't
be translated *)
@@ -488,28 +529,26 @@ let translate_type_id (id : T.type_id) : type_id =
let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty =
let translate = translate_fwd_ty type_infos in
match ty with
- | T.Adt (type_id, regions, tys, cgs) -> (
- (* Can't translate types with regions for now *)
- assert (regions = []);
- (* Translate the type parameters *)
- let t_tys = List.map translate tys in
+ | T.Adt (type_id, generics) -> (
+ let t_generics = translate_fwd_generic_args type_infos generics in
(* Eliminate boxes and simplify tuples *)
match type_id with
- | AdtId _
- | T.Assumed (T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) ->
- (* No general parametricity for now *)
- assert (not (List.exists (TypesUtils.ty_has_borrows type_infos) tys));
+ | AdtId _ | T.Assumed (T.Array | T.Slice | T.Str) ->
let type_id = translate_type_id type_id in
- Adt (type_id, t_tys, cgs)
+ Adt (type_id, t_generics)
| Tuple ->
(* Note that if there is exactly one type, [mk_simpl_tuple_ty] is the
identity *)
- mk_simpl_tuple_ty t_tys
+ mk_simpl_tuple_ty t_generics.types
| T.Assumed T.Box -> (
(* We eliminate boxes *)
(* No general parametricity for now *)
- assert (not (List.exists (TypesUtils.ty_has_borrows type_infos) tys));
- match t_tys with
+ assert (
+ not
+ (List.exists
+ (TypesUtils.ty_has_borrows type_infos)
+ generics.types));
+ match t_generics.types with
| [ bty ] -> bty
| _ ->
raise
@@ -520,12 +559,40 @@ let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty =
| Never -> raise (Failure "Unreachable")
| Literal lty -> Literal lty
| Ref (_, rty, _) -> translate rty
+ | RawPtr (ty, rkind) ->
+ let mut = match rkind with Mut -> Mut | Shared -> Const in
+ let ty = translate ty in
+ let generics = { types = [ ty ]; const_generics = []; trait_refs = [] } in
+ Adt (Assumed (RawPtr mut), generics)
+ | TraitType (trait_ref, generics, type_name) ->
+ let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
+ let generics = translate_fwd_generic_args type_infos generics in
+ TraitType (trait_ref, generics, type_name)
+ | Arrow _ -> raise (Failure "TODO")
+
+and translate_fwd_generic_args (type_infos : TA.type_infos)
+ (generics : 'r T.generic_args) : generic_args =
+ translate_generic_args (translate_fwd_ty type_infos) generics
+
+and translate_fwd_trait_ref (type_infos : TA.type_infos) (tr : 'r T.trait_ref) :
+ trait_ref =
+ translate_trait_ref (translate_fwd_ty type_infos) tr
+
+and translate_fwd_trait_instance_id (type_infos : TA.type_infos)
+ (id : 'r T.trait_instance_id) : trait_instance_id =
+ translate_trait_instance_id (translate_fwd_ty type_infos) id
(** Simply calls [translate_fwd_ty] *)
let ctx_translate_fwd_ty (ctx : bs_ctx) (ty : 'r T.ty) : ty =
let type_infos = ctx.type_context.type_infos in
translate_fwd_ty type_infos ty
+(** Simply calls [translate_fwd_generic_args] *)
+let ctx_translate_fwd_generic_args (ctx : bs_ctx) (generics : 'r T.generic_args)
+ : generic_args =
+ let type_infos = ctx.type_context.type_infos in
+ translate_fwd_generic_args type_infos generics
+
(** Translate a type, when some regions may have ended.
We return an option, because the translated type may be empty.
@@ -538,30 +605,40 @@ let rec translate_back_ty (type_infos : TA.type_infos)
(* A small helper for "leave" types *)
let wrap ty = if inside_mut then Some ty else None in
match ty with
- | T.Adt (type_id, _, tys, cgs) -> (
+ | T.Adt (type_id, generics) -> (
match type_id with
- | T.AdtId _
- | Assumed (T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) ->
- (* Don't accept ADTs (which are not tuples) with borrows for now *)
- assert (not (TypesUtils.ty_has_borrows type_infos ty));
+ | T.AdtId _ | Assumed (T.Array | T.Slice | T.Str) ->
let type_id = translate_type_id type_id in
if inside_mut then
- let tys_t = List.filter_map translate tys in
- Some (Adt (type_id, tys_t, cgs))
- else None
+ (* We do not want to filter anything, so we translate the generics
+ as "forward" types *)
+ let generics = translate_fwd_generic_args type_infos generics in
+ Some (Adt (type_id, generics))
+ else
+ (* If not inside a mutable reference: check if at least one
+ of the generics contains a mutable reference (i.e., is not
+ translated to `None`. If yes, keep the whole type, and
+ translate all the generics as "forward" types (the backward
+ function will extract the proper information from the ADT value)
+ *)
+ let types = List.filter_map translate generics.types in
+ if types <> [] then
+ let generics = translate_fwd_generic_args type_infos generics in
+ Some (Adt (type_id, generics))
+ else None
| Assumed T.Box -> (
(* Don't accept ADTs (which are not tuples) with borrows for now *)
assert (not (TypesUtils.ty_has_borrows type_infos ty));
(* Eliminate the box *)
- match tys with
+ match generics.types with
| [ bty ] -> translate bty
| _ ->
raise
(Failure "Unreachable: boxes receive exactly one type parameter")
)
| T.Tuple -> (
- (* Tuples can contain borrows (which we eliminated) *)
- let tys_t = List.filter_map translate tys in
+ (* Tuples can contain borrows (which we eliminate) *)
+ let tys_t = List.filter_map translate generics.types in
match tys_t with
| [] -> None
| _ ->
@@ -582,6 +659,17 @@ let rec translate_back_ty (type_infos : TA.type_infos)
if keep_region r then
translate_back_ty type_infos keep_region inside_mut rty
else None)
+ | RawPtr _ ->
+ (* TODO: not sure what to do here *)
+ None
+ | TraitType (trait_ref, generics, type_name) ->
+ assert (generics.regions = []);
+ (* Translate the trait ref and the generics as "forward" generics -
+ we do not want to filter any type *)
+ let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
+ let generics = translate_fwd_generic_args type_infos generics in
+ Some (TraitType (trait_ref, generics, type_name))
+ | Arrow _ -> raise (Failure "TODO")
(** Simply calls [translate_back_ty] *)
let ctx_translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool)
@@ -589,6 +677,80 @@ let ctx_translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool)
let type_infos = ctx.type_context.type_infos in
translate_back_ty type_infos keep_region inside_mut ty
+let mk_type_check_ctx (ctx : bs_ctx) : PureTypeCheck.tc_ctx =
+ let const_generics =
+ T.ConstGenericVarId.Map.of_list
+ (List.map
+ (fun (cg : T.const_generic_var) ->
+ (cg.index, ctx_translate_fwd_ty ctx (T.Literal cg.ty)))
+ ctx.sg.generics.const_generics)
+ in
+ let env = VarId.Map.empty in
+ {
+ PureTypeCheck.type_decls = ctx.type_context.type_decls;
+ global_decls = ctx.global_context.llbc_global_decls;
+ env;
+ const_generics;
+ }
+
+let type_check_pattern (ctx : bs_ctx) (v : typed_pattern) : unit =
+ let ctx = mk_type_check_ctx ctx in
+ let _ = PureTypeCheck.check_typed_pattern ctx v in
+ ()
+
+let type_check_texpression (ctx : bs_ctx) (e : texpression) : unit =
+ if !Config.type_check_pure_code then
+ let ctx = mk_type_check_ctx ctx in
+ PureTypeCheck.check_texpression ctx e
+
+let translate_fun_id_or_trait_method_ref (ctx : bs_ctx)
+ (id : A.fun_id_or_trait_method_ref) : fun_id_or_trait_method_ref =
+ match id with
+ | FunId fun_id -> FunId fun_id
+ | TraitMethod (trait_ref, method_name, fun_decl_id) ->
+ let type_infos = ctx.type_context.type_infos in
+ let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
+ TraitMethod (trait_ref, method_name, fun_decl_id)
+
+let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
+ (args : texpression list) (ctx : bs_ctx) : bs_ctx =
+ let calls = ctx.calls in
+ assert (not (V.FunCallId.Map.mem call_id calls));
+ let info =
+ { forward; forward_inputs = args; backwards = T.RegionGroupId.Map.empty }
+ in
+ let calls = V.FunCallId.Map.add call_id info calls in
+ { ctx with calls }
+
+(** [back_args]: the *additional* list of inputs received by the backward function *)
+let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
+ (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx)
+ : bs_ctx * fun_or_op_id =
+ (* Insert the abstraction in the call informations *)
+ let info = V.FunCallId.Map.find call_id ctx.calls in
+ assert (not (T.RegionGroupId.Map.mem back_id info.backwards));
+ let backwards =
+ T.RegionGroupId.Map.add back_id (abs, back_args) info.backwards
+ in
+ let info = { info with backwards } in
+ let calls = V.FunCallId.Map.add call_id info ctx.calls in
+ (* Insert the abstraction in the abstractions map *)
+ let abstractions = ctx.abstractions in
+ assert (not (V.AbstractionId.Map.mem abs.abs_id abstractions));
+ let abstractions =
+ V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions
+ in
+ (* Retrieve the fun_id *)
+ let fun_id =
+ match info.forward.call_id with
+ | S.Fun (fid, _) ->
+ let fid = translate_fun_id_or_trait_method_ref ctx fid in
+ Fun (FromLlbc (fid, None, Some back_id))
+ | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
+ in
+ (* Update the context and return *)
+ ({ ctx with calls; abstractions }, fun_id)
+
(** List the ancestors of an abstraction *)
let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs)
(call_id : V.FunCallId.id) : V.AbstractionId.id list =
@@ -642,10 +804,10 @@ let mk_fuel_input_as_list (ctx : bs_ctx) (info : fun_effect_info) :
(** Small utility. *)
let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
- (fun_id : A.fun_id) (lid : V.LoopId.id option)
+ (fun_id : A.fun_id_or_trait_method_ref) (lid : V.LoopId.id option)
(gid : T.RegionGroupId.id option) : fun_effect_info =
match fun_id with
- | A.Regular fid ->
+ | TraitMethod (_, _, fid) | FunId (Regular fid) ->
let info = A.FunDeclId.Map.find fid fun_infos in
let stateful_group = info.stateful in
let stateful =
@@ -658,10 +820,10 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
can_diverge = info.can_diverge;
is_rec = info.is_rec || Option.is_some lid;
}
- | A.Assumed aid ->
+ | FunId (Assumed aid) ->
assert (lid = None);
{
- can_fail = Assumed.assumed_can_fail aid;
+ can_fail = Assumed.assumed_fun_can_fail aid;
stateful_group = false;
stateful = false;
can_diverge = false;
@@ -673,12 +835,14 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
Note that the function also takes a list of names for the inputs, and
computes, for every output for the backward functions, a corresponding
name (outputs for backward functions come from borrows in the inputs
- of the forward function) which we use as hints to generate pretty names.
+ of the forward function) which we use as hints to generate pretty names
+ in the extracted code.
*)
-let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
- (fun_id : A.fun_id) (type_infos : TA.type_infos) (sg : A.fun_sig)
- (input_names : string option list) (bid : T.RegionGroupId.id option) :
- fun_sig_named_outputs =
+let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id)
+ (sg : A.fun_sig) (input_names : string option list)
+ (bid : T.RegionGroupId.id option) : fun_sig_named_outputs =
+ let fun_infos = decls_ctx.fun_ctx.fun_infos in
+ let type_infos = decls_ctx.type_ctx.type_infos in
(* Retrieve the list of parent backward functions *)
let gid, parents =
match bid with
@@ -689,7 +853,34 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
in
(* Is the function stateful, and can it fail? *)
let lid = None in
- let effect_info = get_fun_effect_info fun_infos fun_id lid bid in
+ let effect_info = get_fun_effect_info fun_infos (FunId fun_id) lid bid in
+ (* We need an evaluation context to normalize the types (to normalize the
+ associated types, etc. - for instance it may happen that the types
+ refer to the types associated to a trait ref, but where the trait ref
+ is a known impl). *)
+ (* Create the context *)
+ let ctx =
+ let region_groups =
+ List.map (fun (g : T.region_var_group) -> g.id) sg.regions_hierarchy
+ in
+ let ctx =
+ InterpreterUtils.initialize_eval_context decls_ctx region_groups
+ sg.generics.types sg.generics.const_generics
+ in
+ (* Compute the normalization map for the *sty* types and add it to the context *)
+ AssociatedTypes.ctx_add_norm_trait_stypes_from_preds ctx
+ sg.preds.trait_type_constraints
+ in
+
+ (* Normalize the signature *)
+ let sg =
+ let ({ A.inputs; output; _ } : A.fun_sig) = sg in
+ let norm = AssociatedTypes.ctx_normalize_sty ctx in
+ let inputs = List.map norm inputs in
+ let output = norm output in
+ { sg with A.inputs; output }
+ in
+
(* List the inputs for:
* - the fuel
* - the forward function
@@ -806,9 +997,8 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
(* Wrap in a result type *)
if effect_info.can_fail then mk_result_ty output else output
in
- (* Type/const generic parameters *)
- let type_params = sg.type_params in
- let const_generic_params = sg.const_generic_params in
+ (* Generic parameters *)
+ let generics = translate_generic_params sg.generics in
(* Return *)
let has_fuel = fuel <> [] in
let num_fwd_inputs_no_state = List.length fwd_inputs in
@@ -836,9 +1026,8 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
effect_info;
}
in
- let sg =
- { type_params; const_generic_params; inputs; output; doutputs; info }
- in
+ let preds = translate_predicates sg.A.preds in
+ let sg = { generics; preds; inputs; output; doutputs; info } in
{ sg; output_names }
let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern =
@@ -917,7 +1106,7 @@ let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var =
(** Peel boxes as long as the value is of the form [Box<T>] *)
let rec unbox_typed_value (v : V.typed_value) : V.typed_value =
match (v.value, v.ty) with
- | V.Adt av, T.Adt (T.Assumed T.Box, _, _, _) -> (
+ | V.Adt av, T.Adt (T.Assumed T.Box, _) -> (
match av.field_values with
| [ bv ] -> unbox_typed_value bv
| _ -> raise (Failure "Unreachable"))
@@ -962,16 +1151,16 @@ let rec typed_value_to_texpression (ctx : bs_ctx) (ectx : C.eval_ctx)
let field_values = List.map translate av.field_values in
(* Eliminate the tuple wrapper if it is a tuple with exactly one field *)
match v.ty with
- | T.Adt (T.Tuple, _, _, _) ->
+ | T.Adt (T.Tuple, _) ->
assert (variant_id = None);
mk_simpl_tuple_texpression field_values
| _ ->
- (* Retrieve the type, the translated type arguments and the
- * const generic arguments from the translated type (simpler this way) *)
- let adt_id, type_args, const_generic_args = ty_as_adt ty in
+ (* Retrieve the type and the translated generics from the translated
+ type (simpler this way) *)
+ let adt_id, generics = ty_as_adt ty in
(* Create the constructor *)
let qualif_id = AdtCons { adt_id; variant_id = av.variant_id } in
- let qualif = { id = qualif_id; type_args; const_generic_args } in
+ let qualif = { id = qualif_id; generics } in
let cons_e = Qualif qualif in
let field_tys =
List.map (fun (v : texpression) -> v.ty) field_values
@@ -1038,11 +1227,9 @@ let rec typed_avalue_to_consumed (ctx : bs_ctx) (ectx : C.eval_ctx)
(* Translate the field values *)
let field_values = List.filter_map translate adt_v.field_values in
(* For now, only tuples can contain borrows *)
- let adt_id, _, _, _ = TypesUtils.ty_as_adt av.ty in
+ let adt_id, _ = TypesUtils.ty_as_adt av.ty in
match adt_id with
- | T.AdtId _
- | T.Assumed
- (T.Box | T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) ->
+ | T.AdtId _ | T.Assumed (T.Box | T.Array | T.Slice | T.Str) ->
assert (field_values = []);
None
| T.Tuple ->
@@ -1185,11 +1372,9 @@ let rec typed_avalue_to_given_back (mp : mplace option) (av : V.typed_avalue)
(* For now, only tuples can contain borrows - note that if we gave
* something like a [&mut Vec] to a function, we give back the
* vector value upon visiting the "abstraction borrow" node *)
- let adt_id, _, _, _ = TypesUtils.ty_as_adt av.ty in
+ let adt_id, _ = TypesUtils.ty_as_adt av.ty in
match adt_id with
- | T.AdtId _
- | T.Assumed
- (T.Box | T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) ->
+ | T.AdtId _ | T.Assumed (T.Box | T.Array | T.Slice | T.Str) ->
assert (field_values = []);
(ctx, None)
| T.Tuple ->
@@ -1457,9 +1642,12 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool)
and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
texpression =
+ log#ldebug
+ (lazy
+ ("translate_function_call:\n"
+ ^ ctx_egeneric_args_to_string ctx call.generics));
(* Translate the function call *)
- let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in
- let const_generic_args = call.const_generic_params in
+ let generics = ctx_translate_fwd_generic_args ctx call.generics in
let args =
let args = List.map (typed_value_to_texpression ctx call.ctx) call.args in
let args_mplaces = List.map translate_opt_mplace call.args_places in
@@ -1475,7 +1663,8 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
match call.call_id with
| S.Fun (fid, call_id) ->
(* Regular function call *)
- let func = Fun (FromLlbc (fid, None, None)) in
+ let fid_t = translate_fun_id_or_trait_method_ref ctx fid in
+ let func = Fun (FromLlbc (fid_t, None, None)) in
(* Retrieve the effect information about this function (can fail,
* takes a state as input, etc.) *)
let effect_info =
@@ -1525,18 +1714,20 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
in
(ctx, Unop (Neg int_ty), effect_info, args, None)
| _ -> raise (Failure "Unreachable"))
- | S.Unop (E.Cast (src_ty, tgt_ty)) ->
- (* Note that cast can fail *)
- let effect_info =
- {
- can_fail = true;
- stateful_group = false;
- stateful = false;
- can_diverge = false;
- is_rec = false;
- }
- in
- (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None)
+ | S.Unop (E.Cast cast_kind) -> (
+ match cast_kind with
+ | CastInteger (src_ty, tgt_ty) ->
+ (* Note that cast can fail *)
+ let effect_info =
+ {
+ can_fail = true;
+ stateful_group = false;
+ stateful = false;
+ can_diverge = false;
+ is_rec = false;
+ }
+ in
+ (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None))
| S.Binop binop -> (
match args with
| [ arg0; arg1 ] ->
@@ -1561,7 +1752,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
| None -> dest
| Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ]
in
- let func = { id = FunOrOp fun_id; type_args; const_generic_args } in
+ let func = { id = FunOrOp fun_id; generics } in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
let ret_ty =
if effect_info.can_fail then mk_result_ty dest_v.ty else dest_v.ty
@@ -1665,9 +1856,11 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs)
(* Group the two lists *)
let variables_values = List.combine given_back_variables consumed_values in
(* Sanity check: the two lists match (same types) *)
- List.iter
- (fun (var, v) -> assert ((var : var).ty = (v : texpression).ty))
- variables_values;
+ (* TODO: normalize the types *)
+ if !Config.type_check_pure_code then
+ List.iter
+ (fun (var, v) -> assert ((var : var).ty = (v : texpression).ty))
+ variables_values;
(* Translate the next expression *)
let next_e = translate_expression e ctx in
(* Generate the assignemnts *)
@@ -1692,8 +1885,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
let effect_info =
get_fun_effect_info ctx.fun_context.fun_infos fun_id None (Some rg_id)
in
- let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in
- let const_generic_args = call.const_generic_params in
+ let generics = ctx_translate_fwd_generic_args ctx call.generics in
(* Retrieve the original call and the parent abstractions *)
let _forward, backwards = get_abs_ancestors ctx abs call_id in
(* Retrieve the values consumed when we called the forward function and
@@ -1741,34 +1933,35 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
| Some nstate -> mk_simpl_tuple_pattern [ nstate; output ]
in
(* Sanity check: there is the proper number of inputs and outputs, and they have the proper type *)
- let _ =
- let inst_sg =
- get_instantiated_fun_sig fun_id (Some rg_id) type_args const_generic_args
- ctx
- in
- log#ldebug
- (lazy
- ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs ("
- ^ string_of_int (List.length inputs)
- ^ "): "
- ^ String.concat ", " (List.map (texpression_to_string ctx) inputs)
- ^ "\n- inst_sg.inputs ("
- ^ string_of_int (List.length inst_sg.inputs)
- ^ "): "
- ^ String.concat ", " (List.map (ty_to_string ctx) inst_sg.inputs)));
- List.iter
- (fun (x, ty) -> assert ((x : texpression).ty = ty))
- (List.combine inputs inst_sg.inputs);
- log#ldebug
- (lazy
- ("\n- outputs: "
- ^ string_of_int (List.length outputs)
- ^ "\n- expected outputs: "
- ^ string_of_int (List.length inst_sg.doutputs)));
- List.iter
- (fun (x, ty) -> assert ((x : typed_pattern).ty = ty))
- (List.combine outputs inst_sg.doutputs)
- in
+ (if (* TODO: normalize the types *) !Config.type_check_pure_code then
+ match fun_id with
+ | FunId fun_id ->
+ let inst_sg =
+ get_instantiated_fun_sig fun_id (Some rg_id) generics ctx
+ in
+ log#ldebug
+ (lazy
+ ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs ("
+ ^ string_of_int (List.length inputs)
+ ^ "): "
+ ^ String.concat ", " (List.map (texpression_to_string ctx) inputs)
+ ^ "\n- inst_sg.inputs ("
+ ^ string_of_int (List.length inst_sg.inputs)
+ ^ "): "
+ ^ String.concat ", " (List.map (ty_to_string ctx) inst_sg.inputs)));
+ List.iter
+ (fun (x, ty) -> assert ((x : texpression).ty = ty))
+ (List.combine inputs inst_sg.inputs);
+ log#ldebug
+ (lazy
+ ("\n- outputs: "
+ ^ string_of_int (List.length outputs)
+ ^ "\n- expected outputs: "
+ ^ string_of_int (List.length inst_sg.doutputs)));
+ List.iter
+ (fun (x, ty) -> assert ((x : typed_pattern).ty = ty))
+ (List.combine outputs inst_sg.doutputs)
+ | _ -> (* TODO: trait methods *) ());
(* Retrieve the function id, and register the function call in the context
* if necessary *)
let ctx, func =
@@ -1788,7 +1981,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
if effect_info.can_fail then mk_result_ty output.ty else output.ty
in
let func_ty = mk_arrows input_tys ret_ty in
- let func = { id = FunOrOp func; type_args; const_generic_args } in
+ let func = { id = FunOrOp func; generics } in
let func = { e = Qualif func; ty = func_ty } in
let call = mk_apps func args in
(* **Optimization**:
@@ -1905,14 +2098,13 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
(* Actually the same case as [SynthInput] *)
translate_end_abstraction_synth_input ectx abs e ctx rg_id
| V.LoopCall ->
- let fun_id = A.Regular ctx.fun_decl.A.def_id in
+ let fun_id = E.Regular ctx.fun_decl.A.def_id in
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos fun_id (Some vloop_id)
- (Some rg_id)
+ get_fun_effect_info ctx.fun_context.fun_infos (FunId fun_id)
+ (Some vloop_id) (Some rg_id)
in
let loop_info = LoopId.Map.find loop_id ctx.loops in
- let type_args = loop_info.type_args in
- let const_generic_args = loop_info.const_generic_args in
+ let generics = loop_info.generics in
let fwd_inputs = Option.get loop_info.forward_inputs in
(* Retrieve the additional backward inputs. Note that those are actually
the backward inputs of the function we are synthesizing (and that we
@@ -1960,8 +2152,8 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
if effect_info.can_fail then mk_result_ty output.ty else output.ty
in
let func_ty = mk_arrows input_tys ret_ty in
- let func = Fun (FromLlbc (fun_id, Some loop_id, Some rg_id)) in
- let func = { id = FunOrOp func; type_args; const_generic_args } in
+ let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in
+ let func = { id = FunOrOp func; generics } in
let func = { e = Qualif func; ty = func_ty } in
let call = mk_apps func args in
(* **Optimization**:
@@ -2021,9 +2213,7 @@ and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value)
(e : S.expression) (ctx : bs_ctx) : texpression =
let ctx, var = fresh_var_for_symbolic_value sval ctx in
let decl = A.GlobalDeclId.Map.find gid ctx.global_context.llbc_global_decls in
- let global_expr =
- { id = Global gid; type_args = []; const_generic_args = [] }
- in
+ let global_expr = { id = Global gid; generics = empty_generic_args } in
(* We use translate_fwd_ty to translate the global type *)
let ty = ctx_translate_fwd_ty ctx decl.ty in
let gval = { e = Qualif global_expr; ty } in
@@ -2037,11 +2227,7 @@ and translate_assertion (ectx : C.eval_ctx) (v : V.typed_value)
let v = typed_value_to_texpression ctx ectx v in
let args = [ v ] in
let func =
- {
- id = FunOrOp (Fun (Pure Assert));
- type_args = [];
- const_generic_args = [];
- }
+ { id = FunOrOp (Fun (Pure Assert)); generics = empty_generic_args }
in
let func_ty = mk_arrow (Literal Bool) mk_unit_ty in
let func = { e = Qualif func; ty = func_ty } in
@@ -2189,7 +2375,7 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value)
(branch : S.expression) (ctx : bs_ctx) : texpression =
(* TODO: always introduce a match, and use micro-passes to turn the
the match into a let? *)
- let type_id, _, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in
+ let type_id, _ = TypesUtils.ty_as_adt sv.V.sv_ty in
let ctx, vars = fresh_vars_for_symbolic_values svl ctx in
let branch = translate_expression branch ctx in
match type_id with
@@ -2224,10 +2410,10 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value)
* field.
* We use the [dest] variable in order not to have to recompute
* the type of the result of the projection... *)
- let adt_id, type_args, const_generic_args = ty_as_adt scrutinee.ty in
+ let adt_id, generics = ty_as_adt scrutinee.ty in
let gen_field_proj (field_id : FieldId.id) (dest : var) : texpression =
let proj_kind = { adt_id; field_id } in
- let qualif = { id = Proj proj_kind; type_args; const_generic_args } in
+ let qualif = { id = Proj proj_kind; generics } in
let proj_e = Qualif qualif in
let proj_ty = mk_arrow scrutinee.ty dest.ty in
let proj = { e = proj_e; ty = proj_ty } in
@@ -2259,17 +2445,12 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value)
(mk_typed_pattern_from_var var None)
(mk_opt_mplace_texpression scrutinee_mplace scrutinee)
branch
- | T.Assumed (T.Vec | T.Array | T.Slice | T.Str) ->
+ | T.Assumed (T.Array | T.Slice | T.Str) ->
(* We can't expand those values: we can access the fields only
* through the functions provided by the API (note that we don't
* know how to expand values like vectors or arrays, because they have a variable number
* of fields!) *)
raise (Failure "Attempt to expand a non-expandable value")
- | T.Assumed Range -> raise (Failure "Unimplemented")
- | T.Assumed T.Option ->
- (* We shouldn't get there in the "one-branch" case: options have
- * two variants *)
- raise (Failure "Unreachable")
and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option)
(sv : V.symbolic_value) (v : S.value_aggregate) (e : S.expression)
@@ -2282,8 +2463,9 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option)
(* Translate the next expression *)
let next_e = translate_expression e ctx in
- (* Translate the value: there are two cases, depending on whether this
- is a "regular" let-binding or an array aggregate.
+ (* Translate the value: there are several cases, depending on whether this
+ is a "regular" let-binding, an array aggregate, a const generic or
+ a trait associated constant.
*)
let v =
match v with
@@ -2298,6 +2480,14 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option)
{ struct_id = Assumed Array; init = None; updates = values }
in
{ e = StructUpdate su; ty = var.ty }
+ | ConstGenericValue cg_id -> { e = CVar cg_id; ty = var.ty }
+ | TraitConstValue (trait_ref, generics, const_name) ->
+ let type_infos = ctx.type_context.type_infos in
+ let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
+ let generics = translate_fwd_generic_args type_infos generics in
+ let qualif_id = TraitConst (trait_ref, generics, const_name) in
+ let qualif = { id = qualif_id; generics = empty_generic_args } in
+ { e = Qualif qualif; ty = var.ty }
in
(* Make the let-binding *)
@@ -2368,9 +2558,9 @@ and translate_forward_end (ectx : C.eval_ctx)
let org_args = args in
(* Lookup the effect info for the loop function *)
- let fid = A.Regular ctx.fun_decl.A.def_id in
+ let fid = E.Regular ctx.fun_decl.A.def_id in
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos fid None ctx.bid
+ get_fun_effect_info ctx.fun_context.fun_infos (FunId fid) None ctx.bid
in
(* Introduce a fresh output value for the forward function *)
@@ -2415,14 +2605,8 @@ and translate_forward_end (ectx : C.eval_ctx)
let out_pat = mk_simpl_tuple_pattern out_pats in
let loop_call =
- let fun_id = Fun (FromLlbc (fid, Some loop_id, None)) in
- let func =
- {
- id = FunOrOp fun_id;
- type_args = loop_info.type_args;
- const_generic_args = loop_info.const_generic_args;
- }
- in
+ let fun_id = Fun (FromLlbc (FunId fid, Some loop_id, None)) in
+ let func = { id = FunOrOp fun_id; generics = loop_info.generics } in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
let ret_ty =
if effect_info.can_fail then mk_result_ty out_pat.ty else out_pat.ty
@@ -2541,14 +2725,31 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
(* Note that we will retrieve the input values later in the [ForwardEnd]
(and will introduce the outputs at that moment, together with the actual
- call to the loop forward function *)
- let type_args =
- List.map (fun (ty : T.type_var) -> TypeVar ty.T.index) ctx.sg.type_params
- in
- let const_generic_args =
- List.map
- (fun (cg : T.const_generic_var) -> T.ConstGenericVar cg.T.index)
- ctx.sg.const_generic_params
+ call to the loop forward function) *)
+ let generics =
+ let { types; const_generics; trait_clauses } = ctx.sg.generics in
+ let types =
+ List.map (fun (ty : T.type_var) -> TypeVar ty.T.index) types
+ in
+ let const_generics =
+ List.map
+ (fun (cg : T.const_generic_var) -> T.ConstGenericVar cg.T.index)
+ const_generics
+ in
+ let trait_refs =
+ List.map
+ (fun (c : trait_clause) ->
+ let trait_decl_ref =
+ { trait_decl_id = c.trait_id; decl_generics = empty_generic_args }
+ in
+ {
+ trait_id = Clause c.clause_id;
+ generics = empty_generic_args;
+ trait_decl_ref;
+ })
+ trait_clauses
+ in
+ { types; const_generics; trait_refs }
in
let loop_info =
@@ -2556,8 +2757,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
loop_id;
input_vars = inputs;
input_svl = loop.input_svalues;
- type_args;
- const_generic_args;
+ generics;
forward_inputs = None;
forward_output_no_state_no_result = None;
}
@@ -2648,8 +2848,7 @@ let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression)
let func =
{
id = FunOrOp (Fun (Pure FuelEqZero));
- type_args = [];
- const_generic_args = [];
+ generics = empty_generic_args;
}
in
let func_ty = mk_arrow mk_fuel_ty mk_bool_ty in
@@ -2661,8 +2860,7 @@ let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression)
let func =
{
id = FunOrOp (Fun (Pure FuelDecrease));
- type_args = [];
- const_generic_args = [];
+ generics = empty_generic_args;
}
in
let func_ty = mk_arrow mk_fuel_ty mk_fuel_ty in
@@ -2727,8 +2925,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
| None -> None
| Some body ->
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) None
- bid
+ get_fun_effect_info ctx.fun_context.fun_infos (FunId (Regular def_id))
+ None bid
in
let body = translate_expression body ctx in
(* Add a match over the fuel, if necessary *)
@@ -2803,10 +3001,12 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
^ "\n- signature.inputs: "
^ String.concat ", " (List.map (ty_to_string ctx) signature.inputs)
));
- assert (
- List.for_all
- (fun (var, ty) -> (var : var).ty = ty)
- (List.combine inputs signature.inputs));
+ (* TODO: we need to normalize the types *)
+ if !Config.type_check_pure_code then
+ assert (
+ List.for_all
+ (fun (var, ty) -> (var : var).ty = ty)
+ (List.combine inputs signature.inputs));
Some { inputs; inputs_lvs; body }
in
@@ -2821,6 +3021,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
let def =
{
def_id;
+ kind = def.kind;
num_loops;
loop_id;
back_id = bid;
@@ -2853,8 +3054,7 @@ let translate_type_decls (type_decls : T.type_decl list) : type_decl list =
- optional names for the outputs values (we derive them for the backward
functions)
*)
-let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t)
- (type_infos : TA.type_infos)
+let translate_fun_signatures (decls_ctx : C.decls_ctx)
(functions : (A.fun_id * string option list * A.fun_sig) list) :
fun_sig_named_outputs RegularFunIdNotLoopMap.t =
(* For every function, translate the signatures of:
@@ -2865,17 +3065,14 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t)
(sg : A.fun_sig) : (regular_fun_id_not_loop * fun_sig_named_outputs) list
=
(* The forward function *)
- let fwd_sg =
- translate_fun_sig fun_infos fun_id type_infos sg input_names None
- in
+ let fwd_sg = translate_fun_sig decls_ctx fun_id sg input_names None in
let fwd_id = (fun_id, None) in
(* The backward functions *)
let back_sgs =
List.map
(fun (rg : T.region_var_group) ->
let tsg =
- translate_fun_sig fun_infos fun_id type_infos sg input_names
- (Some rg.id)
+ translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id)
in
let id = (fun_id, Some rg.id) in
(id, tsg))
@@ -2891,3 +3088,94 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t)
List.fold_left
(fun m (id, sg) -> RegularFunIdNotLoopMap.add id sg m)
RegularFunIdNotLoopMap.empty translated
+
+let translate_trait_decl (type_infos : TA.type_infos)
+ (trait_decl : A.trait_decl) : trait_decl =
+ let {
+ def_id;
+ name;
+ generics;
+ preds;
+ parent_clauses;
+ consts;
+ types;
+ required_methods;
+ provided_methods;
+ } : A.trait_decl =
+ trait_decl
+ in
+ let generics = translate_generic_params generics in
+ let preds = translate_predicates preds in
+ let parent_clauses = List.map translate_trait_clause parent_clauses in
+ let consts =
+ List.map
+ (fun (name, (ty, id)) -> (name, (translate_fwd_ty type_infos ty, id)))
+ consts
+ in
+ let types =
+ List.map
+ (fun (name, (trait_clauses, ty)) ->
+ ( name,
+ ( List.map translate_trait_clause trait_clauses,
+ Option.map (translate_fwd_ty type_infos) ty ) ))
+ types
+ in
+ {
+ def_id;
+ name;
+ generics;
+ preds;
+ parent_clauses;
+ consts;
+ types;
+ required_methods;
+ provided_methods;
+ }
+
+let translate_trait_impl (type_infos : TA.type_infos)
+ (trait_impl : A.trait_impl) : trait_impl =
+ let {
+ A.def_id;
+ name;
+ impl_trait;
+ generics;
+ preds;
+ parent_trait_refs;
+ consts;
+ types;
+ required_methods;
+ provided_methods;
+ } =
+ trait_impl
+ in
+ let impl_trait =
+ translate_trait_decl_ref (translate_fwd_ty type_infos) impl_trait
+ in
+ let generics = translate_generic_params generics in
+ let preds = translate_predicates preds in
+ let parent_trait_refs = List.map translate_strait_ref parent_trait_refs in
+ let consts =
+ List.map
+ (fun (name, (ty, id)) -> (name, (translate_fwd_ty type_infos ty, id)))
+ consts
+ in
+ let types =
+ List.map
+ (fun (name, (trait_refs, ty)) ->
+ ( name,
+ ( List.map (translate_fwd_trait_ref type_infos) trait_refs,
+ translate_fwd_ty type_infos ty ) ))
+ types
+ in
+ {
+ def_id;
+ name;
+ impl_trait;
+ generics;
+ preds;
+ parent_trait_refs;
+ consts;
+ types;
+ required_methods;
+ provided_methods;
+ }