summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-08-02 11:03:59 +0200
committerSon Ho2023-08-02 11:03:59 +0200
commit9d27e2e27db06eaad7565b55366ca8734b364fca (patch)
tree7cb450a93c538d671486e1d9f40aa1258401a31e
parent50af296306bfee9f0b127dde8abe5fb0ec1b0acb (diff)
Make progress proapagating the changes
-rw-r--r--compiler/Interpreter.ml43
-rw-r--r--compiler/InterpreterBorrows.ml4
-rw-r--r--compiler/InterpreterBorrowsCore.ml15
-rw-r--r--compiler/InterpreterExpansion.ml42
-rw-r--r--compiler/InterpreterExpressions.ml29
-rw-r--r--compiler/InterpreterLoopsCore.ml3
-rw-r--r--compiler/InterpreterLoopsFixedPoint.ml1
-rw-r--r--compiler/InterpreterLoopsJoinCtxs.ml4
-rw-r--r--compiler/InterpreterLoopsMatchCtxs.ml37
-rw-r--r--compiler/InterpreterPaths.ml9
-rw-r--r--compiler/InterpreterPaths.mli1
-rw-r--r--compiler/InterpreterProjectors.ml20
-rw-r--r--compiler/InterpreterProjectors.mli11
-rw-r--r--compiler/InterpreterStatements.ml41
-rw-r--r--compiler/InterpreterStatements.mli3
-rw-r--r--compiler/Invariants.ml43
-rw-r--r--compiler/Print.ml2
-rw-r--r--compiler/PrintPure.ml113
-rw-r--r--compiler/Pure.ml43
-rw-r--r--compiler/PureTypeCheck.ml50
-rw-r--r--compiler/PureUtils.ml92
-rw-r--r--compiler/SymbolicToPure.ml46
-rw-r--r--compiler/SynthesizeSymbolic.ml6
-rw-r--r--compiler/TranslateCore.ml18
-rw-r--r--compiler/Values.ml4
25 files changed, 407 insertions, 273 deletions
diff --git a/compiler/Interpreter.ml b/compiler/Interpreter.ml
index ccb9009e..dc2bb700 100644
--- a/compiler/Interpreter.ml
+++ b/compiler/Interpreter.ml
@@ -29,8 +29,8 @@ let compute_type_fun_global_contexts (m : A.crate) :
let initialize_eval_context (type_context : C.type_context)
(fun_context : C.fun_context) (global_context : C.global_context)
- (region_groups : T.RegionGroupId.id list) (type_vars : T.type_var list) :
- C.eval_ctx =
+ (region_groups : T.RegionGroupId.id list) (type_vars : T.type_var list)
+ (const_generic_vars : T.const_generic_var list) : C.eval_ctx =
C.reset_global_counters ();
{
C.type_context;
@@ -38,6 +38,7 @@ let initialize_eval_context (type_context : C.type_context)
C.global_context;
C.region_groups;
C.type_vars;
+ C.const_generic_vars;
C.env = [ C.Frame ];
C.ended_regions = T.RegionId.Set.empty;
}
@@ -76,11 +77,18 @@ let initialize_symbolic_context_for_fun (type_context : C.type_context)
in
let ctx =
initialize_eval_context type_context fun_context global_context
- region_groups sg.type_params
+ region_groups sg.type_params sg.const_generic_params
in
(* Instantiate the signature *)
- let type_params = List.map (fun tv -> T.TypeVar tv.T.index) sg.type_params in
- let inst_sg = instantiate_fun_sig type_params sg in
+ let type_params =
+ List.map (fun (v : T.type_var) -> T.TypeVar v.T.index) sg.type_params
+ in
+ let cg_params =
+ List.map
+ (fun (v : T.const_generic_var) -> T.ConstGenericVar v.T.index)
+ sg.const_generic_params
+ in
+ let inst_sg = instantiate_fun_sig type_params cg_params sg in
(* Create fresh symbolic values for the inputs *)
let input_svs =
List.map (fun ty -> mk_fresh_symbolic_value V.SynthInput ty) inst_sg.inputs
@@ -155,8 +163,15 @@ let evaluate_function_symbolic_synthesize_backward_from_return
* an instantiation of the signature, so that we use fresh
* region ids for the return abstractions. *)
let sg = fdef.signature in
- let type_params = List.map (fun tv -> T.TypeVar tv.T.index) sg.type_params in
- let ret_inst_sg = instantiate_fun_sig type_params sg in
+ let type_params =
+ List.map (fun (v : T.type_var) -> T.TypeVar v.T.index) sg.type_params
+ in
+ let cg_params =
+ List.map
+ (fun (v : T.const_generic_var) -> T.ConstGenericVar v.T.index)
+ sg.const_generic_params
+ in
+ let ret_inst_sg = instantiate_fun_sig type_params cg_params sg in
let ret_rty = ret_inst_sg.output in
(* Move the return value out of the return variable *)
let pop_return_value = is_regular_return in
@@ -490,7 +505,7 @@ module Test = struct
compute_type_fun_global_contexts crate
in
let ctx =
- initialize_eval_context type_context fun_context global_context [] []
+ initialize_eval_context type_context fun_context global_context [] [] []
in
(* Insert the (uninitialized) local variables *)
@@ -518,13 +533,11 @@ module Test = struct
(** Small helper: return true if the function is a *transparent* unit function
(no parameters, no arguments) - TODO: move *)
let fun_decl_is_transparent_unit (def : A.fun_decl) : bool =
- match def.body with
- | None -> false
- | Some body ->
- body.arg_count = 0
- && List.length def.A.signature.region_params = 0
- && List.length def.A.signature.type_params = 0
- && List.length def.A.signature.inputs = 0
+ Option.is_some def.body
+ && def.A.signature.region_params = []
+ && def.A.signature.type_params = []
+ && def.A.signature.const_generic_params = []
+ && def.A.signature.inputs = []
(** Test all the unit functions in a list of function definitions *)
let test_unit_functions (crate : A.crate) : unit =
diff --git a/compiler/InterpreterBorrows.ml b/compiler/InterpreterBorrows.ml
index 38c6df3d..3d258b32 100644
--- a/compiler/InterpreterBorrows.ml
+++ b/compiler/InterpreterBorrows.ml
@@ -1733,7 +1733,7 @@ let destructure_abs (abs_kind : V.abs_kind) (can_end : bool)
and list_values (v : V.typed_value) : V.typed_avalue list * V.typed_value =
let ty = v.V.ty in
match v.V.value with
- | Primitive _ -> ([], v)
+ | Literal _ -> ([], v)
| Adt adt ->
let avll, field_values =
List.split (List.map list_values adt.field_values)
@@ -1841,7 +1841,7 @@ let convert_value_to_abstractions (abs_kind : V.abs_kind) (can_end : bool)
let ty = v.V.ty in
match v.V.value with
- | V.Primitive _ -> ([], v)
+ | V.Literal _ -> ([], v)
| V.Bottom ->
(* Can happen: we *do* convert dummy values to abstractions, and dummy
values can contain bottoms *)
diff --git a/compiler/InterpreterBorrowsCore.ml b/compiler/InterpreterBorrowsCore.ml
index 55365043..bf083aa4 100644
--- a/compiler/InterpreterBorrowsCore.ml
+++ b/compiler/InterpreterBorrowsCore.ml
@@ -87,24 +87,28 @@ let add_borrow_or_abs_id_to_chain (msg : string) (id : borrow_or_abs_id)
(** Helper function.
- This function allows to define in a generic way a comparison of region types.
+ This function allows to define in a generic way a comparison of **region types**.
See [projections_interesect] for instance.
[default]: default boolean to return, when comparing types with no regions
[combine]: how to combine booleans
[compare_regions]: how to compare regions
+
+ TODO: is there a way of deriving such a comparison?
*)
let rec compare_rtys (default : bool) (combine : bool -> bool -> bool)
(compare_regions : T.RegionId.id T.region -> T.RegionId.id T.region -> bool)
(ty1 : T.rty) (ty2 : T.rty) : bool =
let compare = compare_rtys default combine compare_regions in
match (ty1, ty2) with
- | T.Bool, T.Bool | T.Char, T.Char | T.Str, T.Str -> default
- | T.Integer int_ty1, T.Integer int_ty2 ->
- assert (int_ty1 = int_ty2);
+ | T.Literal lit1, T.Literal lit2 ->
+ assert (lit1 = lit2);
default
- | T.Adt (id1, regions1, tys1), T.Adt (id2, regions2, tys2) ->
+ | T.Adt (id1, regions1, tys1, cgs1), T.Adt (id2, regions2, tys2, cgs2) ->
assert (id1 = id2);
+ (* There are no regions in the const generics, so we ignore them,
+ but we still check they are the same, for sanity *)
+ assert (cgs1 = cgs2);
(* The check for the ADTs is very crude: we simply compare the arguments
* two by two.
@@ -134,7 +138,6 @@ let rec compare_rtys (default : bool) (combine : bool -> bool -> bool)
in
(* Combine *)
combine params_b tys_b
- | T.Array ty1, T.Array ty2 | T.Slice ty1, T.Slice ty2 -> compare ty1 ty2
| T.Ref (r1, ty1, kind1), T.Ref (r2, ty2, kind2) ->
(* Sanity check *)
assert (kind1 = kind2);
diff --git a/compiler/InterpreterExpansion.ml b/compiler/InterpreterExpansion.ml
index 64a90217..3b196571 100644
--- a/compiler/InterpreterExpansion.ml
+++ b/compiler/InterpreterExpansion.ml
@@ -216,7 +216,8 @@ let apply_symbolic_expansion_non_borrow (config : C.config)
let compute_expanded_symbolic_non_assumed_adt_value (expand_enumerations : bool)
(kind : V.sv_kind) (def_id : T.TypeDeclId.id)
(regions : T.RegionId.id T.region list) (types : T.rty list)
- (ctx : C.eval_ctx) : V.symbolic_expansion list =
+ (cgs : T.const_generic list) (ctx : C.eval_ctx) : V.symbolic_expansion list
+ =
(* Lookup the definition and check if it is an enumeration with several
* variants *)
let def = C.ctx_lookup_type_decl ctx def_id in
@@ -224,6 +225,7 @@ let compute_expanded_symbolic_non_assumed_adt_value (expand_enumerations : bool)
(* Retrieve, for every variant, the list of its instantiated field types *)
let variants_fields_types =
Subst.type_decl_get_instantiated_variants_fields_rtypes def regions types
+ cgs
in
(* Check if there is strictly more than one variant *)
if List.length variants_fields_types > 1 && not expand_enumerations then
@@ -280,11 +282,12 @@ let compute_expanded_symbolic_box_value (kind : V.sv_kind) (boxed_ty : T.rty) :
let compute_expanded_symbolic_adt_value (expand_enumerations : bool)
(kind : V.sv_kind) (adt_id : T.type_id)
(regions : T.RegionId.id T.region list) (types : T.rty list)
- (ctx : C.eval_ctx) : V.symbolic_expansion list =
+ (cgs : T.const_generic list) (ctx : C.eval_ctx) : V.symbolic_expansion list
+ =
match (adt_id, regions, types) with
| T.AdtId def_id, _, _ ->
compute_expanded_symbolic_non_assumed_adt_value expand_enumerations kind
- def_id regions types ctx
+ def_id regions types cgs ctx
| T.Tuple, [], _ -> [ compute_expanded_symbolic_tuple_value kind types ]
| T.Assumed T.Option, [], [ ty ] ->
compute_expanded_symbolic_option_value expand_enumerations kind ty
@@ -513,10 +516,10 @@ let expand_symbolic_bool (config : C.config) (sv : V.symbolic_value)
let original_sv = sv in
let original_sv_place = sv_place in
let rty = original_sv.V.sv_ty in
- assert (rty = T.Bool);
+ assert (rty = T.Literal PV.Bool);
(* Expand the symbolic value to true or false and continue execution *)
- let see_true = V.SePrimitive (PV.Bool true) in
- let see_false = V.SePrimitive (PV.Bool false) in
+ let see_true = V.SeLiteral (PV.Bool true) in
+ let see_false = V.SeLiteral (PV.Bool false) in
let seel = [ (Some see_true, cf_true); (Some see_false, cf_false) ] in
(* Apply the symbolic expansion (this also outputs the updated symbolic AST) *)
apply_branching_symbolic_expansions_non_borrow config original_sv
@@ -540,12 +543,12 @@ let expand_symbolic_value_no_branching (config : C.config)
fun cf ctx ->
match rty with
(* ADTs *)
- | T.Adt (adt_id, regions, types) ->
+ | T.Adt (adt_id, regions, types, cgs) ->
(* Compute the expanded value *)
let allow_branching = false in
let seel =
compute_expanded_symbolic_adt_value allow_branching sv.sv_kind adt_id
- regions types ctx
+ regions types cgs ctx
in
(* There should be exacly one branch *)
let see = Collections.List.to_cons_nil seel in
@@ -597,12 +600,12 @@ let expand_symbolic_adt (config : C.config) (sv : V.symbolic_value)
(* Execute *)
match rty with
(* ADTs *)
- | T.Adt (adt_id, regions, types) ->
+ | T.Adt (adt_id, regions, types, cgs) ->
let allow_branching = true in
(* Compute the expanded value *)
let seel =
compute_expanded_symbolic_adt_value allow_branching sv.sv_kind adt_id
- regions types ctx
+ regions types cgs ctx
in
(* Apply *)
let seel = List.map (fun see -> (Some see, cf_branches)) seel in
@@ -617,7 +620,7 @@ let expand_symbolic_int (config : C.config) (sv : V.symbolic_value)
(tgts : (V.scalar_value * st_cm_fun) list) (otherwise : st_cm_fun)
(cf_after_join : st_m_fun) : m_fun =
(* Sanity check *)
- assert (sv.V.sv_ty = T.Integer int_type);
+ assert (sv.V.sv_ty = T.Literal (PV.Integer int_type));
(* For all the branches of the switch, we expand the symbolic value
* to the value given by the branch and execute the branch statement.
* For the otherwise branch, we leave the symbolic value as it is
@@ -628,7 +631,7 @@ let expand_symbolic_int (config : C.config) (sv : V.symbolic_value)
* (optional expansion, statement to execute)
*)
let seel =
- List.map (fun (v, cf) -> (Some (V.SePrimitive (PV.Scalar v)), cf)) tgts
+ List.map (fun (v, cf) -> (Some (V.SeLiteral (PV.Scalar v)), cf)) tgts
in
let seel = List.append seel [ (None, otherwise) ] in
(* Then expand and evaluate - this generates the proper symbolic AST *)
@@ -676,7 +679,7 @@ let greedy_expand_symbolics_with_borrows (config : C.config) : cm_fun =
^ symbolic_value_to_string ctx sv));
let cc : cm_fun =
match sv.V.sv_ty with
- | T.Adt (AdtId def_id, _, _) ->
+ | T.Adt (AdtId def_id, _, _, _) ->
(* {!expand_symbolic_value_no_branching} checks if there are branchings,
* but we prefer to also check it here - this leads to cleaner messages
* and debugging *)
@@ -701,16 +704,15 @@ let greedy_expand_symbolics_with_borrows (config : C.config) : cm_fun =
[config]): "
^ Print.name_to_string def.name))
else expand_symbolic_value_no_branching config sv None
- | T.Adt ((Tuple | Assumed Box), _, _) | T.Ref (_, _, _) ->
+ | T.Adt ((Tuple | Assumed Box), _, _, _) | T.Ref (_, _, _) ->
(* Ok *)
expand_symbolic_value_no_branching config sv None
- | T.Adt (Assumed (Vec | Option), _, _) ->
+ | T.Adt (Assumed (Vec | Option | Array | Slice | Str), _, _, _) ->
(* We can't expand those *)
- raise (Failure "Attempted to greedily expand a Vec or an Option ")
- | T.Array _ -> raise Utils.Unimplemented
- | T.Slice _ -> raise (Failure "Can't expand symbolic slices")
- | T.TypeVar _ | Bool | Char | Never | Integer _ | Str ->
- raise (Failure "Unreachable")
+ raise
+ (Failure
+ "Attempted to greedily expand an ADT which can't be expanded ")
+ | T.TypeVar _ | T.Literal _ | Never -> raise (Failure "Unreachable")
in
(* Compose and continue *)
comp cc expand cf ctx
diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml
index d75f5a26..bb159f05 100644
--- a/compiler/InterpreterExpressions.ml
+++ b/compiler/InterpreterExpressions.ml
@@ -94,24 +94,23 @@ let access_rplace_reorganize (config : C.config) (expand_prim_copy : bool)
ctx
(** Convert an operand constant operand value to a typed value *)
-let primitive_to_typed_value (ty : T.ety) (cv : V.primitive_value) :
+let literal_to_typed_value (ty : PV.literal_type) (cv : V.literal) :
V.typed_value =
(* Check the type while converting - we actually need some information
* contained in the type *)
log#ldebug
(lazy
- ("primitive_to_typed_value:" ^ "\n- cv: "
- ^ Print.PrimitiveValues.primitive_value_to_string cv));
+ ("literal_to_typed_value:" ^ "\n- cv: "
+ ^ Print.PrimitiveValues.literal_to_string cv));
match (ty, cv) with
(* Scalar, boolean... *)
- | T.Bool, Bool v -> { V.value = V.Primitive (Bool v); ty }
- | T.Char, Char v -> { V.value = V.Primitive (Char v); ty }
- | T.Str, String v -> { V.value = V.Primitive (String v); ty }
- | T.Integer int_ty, PV.Scalar v ->
+ | PV.Bool, Bool v -> { V.value = V.Literal (Bool v); ty = T.Literal ty }
+ | Char, Char v -> { V.value = V.Literal (Char v); ty = T.Literal ty }
+ | Integer int_ty, PV.Scalar v ->
(* Check the type and the ranges *)
assert (int_ty = v.int_ty);
assert (check_scalar_value_in_range v);
- { V.value = V.Primitive (PV.Scalar v); ty }
+ { V.value = V.Literal (PV.Scalar v); ty = T.Literal ty }
(* Remaining cases (invalid) *)
| _, _ -> raise (Failure "Improperly typed constant value")
@@ -138,14 +137,16 @@ let rec copy_value (allow_adt_copy : bool) (config : C.config)
* the fact that we have exhaustive matches below makes very obvious the cases
* in which we need to fail *)
match v.V.value with
- | V.Primitive _ -> (ctx, v)
+ | V.Literal _ -> (ctx, v)
| V.Adt av ->
(* Sanity check *)
(match v.V.ty with
- | T.Adt (T.Assumed (T.Box | Vec), _, _) ->
+ | T.Adt (T.Assumed (T.Box | Vec), _, _, _) ->
raise (Failure "Can't copy an assumed value other than Option")
- | T.Adt (T.AdtId _, _, _) -> assert allow_adt_copy
- | T.Adt ((T.Assumed Option | T.Tuple), _, _) -> () (* Ok *)
+ | T.Adt (T.AdtId _, _, _, _) -> assert allow_adt_copy
+ | T.Adt ((T.Assumed Option | T.Tuple), _, _, _) -> () (* Ok *)
+ | T.Adt (T.Assumed (Slice | T.Array), [], [ ty ], []) ->
+ assert (ty_is_primitively_copyable ty)
| _ -> raise (Failure "Unreachable"));
let ctx, fields =
List.fold_left_map
@@ -231,7 +232,7 @@ let prepare_eval_operand_reorganize (config : C.config) (op : E.operand) :
match op with
| Expressions.Constant (ty, cv) ->
(* No need to reorganize the context *)
- primitive_to_typed_value ty cv |> ignore;
+ literal_to_typed_value ty cv |> ignore;
cf ctx
| Expressions.Copy p ->
(* Access the value *)
@@ -259,7 +260,7 @@ let eval_operand_no_reorganize (config : C.config) (op : E.operand)
^ "\n- ctx:\n" ^ eval_ctx_to_string ctx ^ "\n"));
(* Evaluate *)
match op with
- | Expressions.Constant (ty, cv) -> cf (primitive_to_typed_value ty cv) ctx
+ | Expressions.Constant (ty, cv) -> cf (literal_to_typed_value ty cv) ctx
| Expressions.Copy p ->
(* Access the value *)
let access = Read in
diff --git a/compiler/InterpreterLoopsCore.ml b/compiler/InterpreterLoopsCore.ml
index 209fce1c..6e33c75b 100644
--- a/compiler/InterpreterLoopsCore.ml
+++ b/compiler/InterpreterLoopsCore.ml
@@ -60,8 +60,7 @@ module type PrimMatcher = sig
val match_rtys : T.rty -> T.rty -> T.rty
(** The input primitive values are not equal *)
- val match_distinct_primitive_values :
- T.ety -> V.primitive_value -> V.primitive_value -> V.typed_value
+ val match_distinct_literals : T.ety -> V.literal -> V.literal -> V.typed_value
(** The input ADTs don't have the same variant *)
val match_distinct_adts : T.ety -> V.adt_value -> V.adt_value -> V.typed_value
diff --git a/compiler/InterpreterLoopsFixedPoint.ml b/compiler/InterpreterLoopsFixedPoint.ml
index aff8f3fe..a9ec9ecf 100644
--- a/compiler/InterpreterLoopsFixedPoint.ml
+++ b/compiler/InterpreterLoopsFixedPoint.ml
@@ -109,6 +109,7 @@ let prepare_ashared_loans (loop_id : V.LoopId.id option) : cm_fun =
(fun r -> if T.RegionId.Set.mem r rids then nrid else r)
(fun x -> x)
(fun x -> x)
+ (fun x -> x)
(fun id ->
let nid = C.fresh_symbolic_value_id () in
let sv = V.SymbolicValueId.Map.find id absl_id_maps.sids_to_values in
diff --git a/compiler/InterpreterLoopsJoinCtxs.ml b/compiler/InterpreterLoopsJoinCtxs.ml
index 6fb0449d..bf88e055 100644
--- a/compiler/InterpreterLoopsJoinCtxs.ml
+++ b/compiler/InterpreterLoopsJoinCtxs.ml
@@ -556,6 +556,7 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
global_context;
region_groups;
type_vars;
+ const_generic_vars;
env = _;
ended_regions = ended_regions0;
} =
@@ -567,6 +568,7 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
global_context = _;
region_groups = _;
type_vars = _;
+ const_generic_vars = _;
env = _;
ended_regions = ended_regions1;
} =
@@ -580,6 +582,7 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
global_context;
region_groups;
type_vars;
+ const_generic_vars;
env;
ended_regions;
}
@@ -635,6 +638,7 @@ let refresh_abs (old_abs : V.AbstractionId.Set.t) (ctx : C.eval_ctx) :
(fun x -> x)
(fun x -> x)
(fun x -> x)
+ (fun x -> x)
subst ctx.env
in
{ ctx with C.env }
diff --git a/compiler/InterpreterLoopsMatchCtxs.ml b/compiler/InterpreterLoopsMatchCtxs.ml
index 80cd93cf..9248e513 100644
--- a/compiler/InterpreterLoopsMatchCtxs.ml
+++ b/compiler/InterpreterLoopsMatchCtxs.ml
@@ -44,11 +44,11 @@ let compute_abs_borrows_loans_maps (no_duplicates : bool)
(id0 : Id0.id) (id1 : Id1.id) : unit =
(* Sanity check *)
(if check_singleton_sets || check_not_already_registered then
- match Id0.Map.find_opt id0 !map with
- | None -> ()
- | Some set ->
- assert (
- (not check_not_already_registered) || not (Id1.Set.mem id1 set)));
+ match Id0.Map.find_opt id0 !map with
+ | None -> ()
+ | Some set ->
+ assert (
+ (not check_not_already_registered) || not (Id1.Set.mem id1 set)));
(* Update the mapping *)
map :=
Id0.Map.update id0
@@ -149,9 +149,11 @@ let rec match_types (match_distinct_types : 'r T.ty -> 'r T.ty -> 'r T.ty)
(match_regions : 'r -> 'r -> 'r) (ty0 : 'r T.ty) (ty1 : 'r T.ty) : 'r T.ty =
let match_rec = match_types match_distinct_types match_regions in
match (ty0, ty1) with
- | Adt (id0, regions0, tys0), Adt (id1, regions1, tys1) ->
+ | Adt (id0, regions0, tys0, cgs0), Adt (id1, regions1, tys1, cgs1) ->
assert (id0 = id1);
+ assert (cgs0 = cgs1);
let id = id0 in
+ let cgs = cgs1 in
let regions =
List.map
(fun (id0, id1) -> match_regions id0 id1)
@@ -160,16 +162,15 @@ let rec match_types (match_distinct_types : 'r T.ty -> 'r T.ty -> 'r T.ty)
let tys =
List.map (fun (ty0, ty1) -> match_rec ty0 ty1) (List.combine tys0 tys1)
in
- Adt (id, regions, tys)
+ Adt (id, regions, tys, cgs)
| TypeVar vid0, TypeVar vid1 ->
assert (vid0 = vid1);
let vid = vid0 in
TypeVar vid
- | Bool, Bool | Char, Char | Never, Never | Str, Str -> ty0
- | Integer int_ty0, Integer int_ty1 ->
- assert (int_ty0 = int_ty1);
+ | Literal lty0, Literal lty1 ->
+ assert (lty0 = lty1);
ty0
- | Array ty0, Array ty1 | Slice ty0, Slice ty1 -> match_rec ty0 ty1
+ | Never, Never -> ty0
| Ref (r0, ty0, k0), Ref (r1, ty1, k1) ->
let r = match_regions r0 r1 in
let ty = match_rec ty0 ty1 in
@@ -184,8 +185,8 @@ module MakeMatcher (M : PrimMatcher) : Matcher = struct
let match_rec = match_typed_values ctx in
let ty = M.match_etys v0.V.ty v1.V.ty in
match (v0.V.value, v1.V.value) with
- | V.Primitive pv0, V.Primitive pv1 ->
- if pv0 = pv1 then v1 else M.match_distinct_primitive_values ty pv0 pv1
+ | V.Literal lv0, V.Literal lv1 ->
+ if lv0 = lv1 then v1 else M.match_distinct_literals ty lv0 lv1
| V.Adt av0, V.Adt av1 ->
if av0.variant_id = av1.variant_id then
let fields = List.combine av0.field_values av1.field_values in
@@ -385,8 +386,8 @@ module MakeJoinMatcher (S : MatchJoinState) : PrimMatcher = struct
assert (ty0 = ty1);
ty0
- let match_distinct_primitive_values (ty : T.ety) (_ : V.primitive_value)
- (_ : V.primitive_value) : V.typed_value =
+ let match_distinct_literals (ty : T.ety) (_ : V.literal) (_ : V.literal) :
+ V.typed_value =
mk_fresh_symbolic_typed_value_from_ety V.LoopJoin ty
let match_distinct_adts (ty : T.ety) (adt0 : V.adt_value) (adt1 : V.adt_value)
@@ -834,8 +835,8 @@ struct
in
match_types match_distinct_types match_regions ty0 ty1
- let match_distinct_primitive_values (ty : T.ety) (_ : V.primitive_value)
- (_ : V.primitive_value) : V.typed_value =
+ let match_distinct_literals (ty : T.ety) (_ : V.literal) (_ : V.literal) :
+ V.typed_value =
mk_fresh_symbolic_typed_value_from_ety V.LoopJoin ty
let match_distinct_adts (_ty : T.ety) (_adt0 : V.adt_value)
@@ -1616,7 +1617,7 @@ let match_ctx_with_target (config : C.config) (loop_id : V.LoopId.id)
cc
(cf
(if is_loop_entry then EndEnterLoop (loop_id, input_values)
- else EndContinue (loop_id, input_values)))
+ else EndContinue (loop_id, input_values)))
tgt_ctx
in
diff --git a/compiler/InterpreterPaths.ml b/compiler/InterpreterPaths.ml
index 619815b3..4a439250 100644
--- a/compiler/InterpreterPaths.ml
+++ b/compiler/InterpreterPaths.ml
@@ -97,7 +97,7 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx)
match (pe, v.V.value, v.V.ty) with
| ( Field (((ProjAdt (_, _) | ProjOption _) as proj_kind), field_id),
V.Adt adt,
- T.Adt (type_id, _, _) ) -> (
+ T.Adt (type_id, _, _, _) ) -> (
(* Check consistency *)
(match (proj_kind, type_id) with
| ProjAdt (def_id, opt_variant_id), T.AdtId def_id' ->
@@ -119,7 +119,8 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx)
let updated = { v with value = nadt } in
Ok (ctx, { res with updated }))
(* Tuples *)
- | Field (ProjTuple arity, field_id), V.Adt adt, T.Adt (T.Tuple, _, _) -> (
+ | Field (ProjTuple arity, field_id), V.Adt adt, T.Adt (T.Tuple, _, _, _)
+ -> (
assert (arity = List.length adt.field_values);
let fv = T.FieldId.nth adt.field_values field_id in
(* Project *)
@@ -144,7 +145,7 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx)
(* Box dereferencement *)
| ( DerefBox,
Adt { variant_id = None; field_values = [ bv ] },
- T.Adt (T.Assumed T.Box, _, _) ) -> (
+ T.Adt (T.Assumed T.Box, _, _, _) ) -> (
(* We allow moving inside of boxes. In practice, this kind of
* manipulations should happen only inside unsage code, so
* it shouldn't happen due to user code, and we leverage it
@@ -249,7 +250,7 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx)
in
Ok (ctx, { res with updated = nv })
else Error (FailSharedLoan bids))
- | (_, (V.Primitive _ | V.Adt _ | V.Bottom | V.Borrow _), _) as r ->
+ | (_, (V.Literal _ | V.Adt _ | V.Bottom | V.Borrow _), _) as r ->
let pe, v, ty = r in
let pe = "- pe: " ^ E.show_projection_elem pe in
let v = "- v:\n" ^ V.show_value v in
diff --git a/compiler/InterpreterPaths.mli b/compiler/InterpreterPaths.mli
index 6e9286cd..4a9f3b41 100644
--- a/compiler/InterpreterPaths.mli
+++ b/compiler/InterpreterPaths.mli
@@ -61,6 +61,7 @@ val compute_expanded_bottom_adt_value :
T.VariantId.id option ->
T.erased_region list ->
T.ety list ->
+ T.const_generic list ->
V.typed_value
(** Compute an expanded [Option] ⊥ value *)
diff --git a/compiler/InterpreterProjectors.ml b/compiler/InterpreterProjectors.ml
index 9487df84..faed066b 100644
--- a/compiler/InterpreterProjectors.ml
+++ b/compiler/InterpreterProjectors.ml
@@ -23,12 +23,12 @@ let rec apply_proj_borrows_on_shared_borrow (ctx : C.eval_ctx)
if not (ty_has_regions_in_set regions ty) then []
else
match (v.V.value, ty) with
- | V.Primitive _, (T.Bool | T.Char | T.Integer _ | T.Str) -> []
- | V.Adt adt, T.Adt (id, region_params, tys) ->
+ | V.Literal _, T.Literal _ -> []
+ | V.Adt adt, T.Adt (id, region_params, tys, cgs) ->
(* Retrieve the types of the fields *)
let field_types =
Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id
- region_params tys
+ region_params tys cgs
in
(* Project over the field values *)
let fields_types = List.combine adt.V.field_values field_types in
@@ -102,12 +102,12 @@ let rec apply_proj_borrows (check_symbolic_no_ended : bool) (ctx : C.eval_ctx)
else
let value : V.avalue =
match (v.V.value, ty) with
- | V.Primitive _, (T.Bool | T.Char | T.Integer _ | T.Str) -> V.AIgnored
- | V.Adt adt, T.Adt (id, region_params, tys) ->
+ | V.Literal _, T.Literal _ -> V.AIgnored
+ | V.Adt adt, T.Adt (id, region_params, tys, cgs) ->
(* Retrieve the types of the fields *)
let field_types =
Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id
- region_params tys
+ region_params tys cgs
in
(* Project over the field values *)
let fields_types = List.combine adt.V.field_values field_types in
@@ -231,7 +231,7 @@ let symbolic_expansion_non_borrow_to_value (sv : V.symbolic_value)
let ty = Subst.erase_regions sv.V.sv_ty in
let value =
match see with
- | SePrimitive cv -> V.Primitive cv
+ | SeLiteral cv -> V.Literal cv
| SeAdt (variant_id, field_values) ->
let field_values =
List.map mk_typed_value_from_symbolic_value field_values
@@ -267,9 +267,9 @@ let apply_proj_loans_on_symbolic_expansion (regions : T.RegionId.Set.t)
(* Match *)
let (value, ty) : V.avalue * T.rty =
match (see, original_sv_ty) with
- | SePrimitive _, (T.Bool | T.Char | T.Integer _ | T.Str) ->
- (V.AIgnored, original_sv_ty)
- | SeAdt (variant_id, field_values), T.Adt (_id, _region_params, _tys) ->
+ | SeLiteral _, T.Literal _ -> (V.AIgnored, original_sv_ty)
+ | SeAdt (variant_id, field_values), T.Adt (_id, _region_params, _tys, _cgs)
+ ->
(* Project over the field values *)
let field_values =
List.map
diff --git a/compiler/InterpreterProjectors.mli b/compiler/InterpreterProjectors.mli
index 1afb9d53..bcc3dee2 100644
--- a/compiler/InterpreterProjectors.mli
+++ b/compiler/InterpreterProjectors.mli
@@ -55,7 +55,16 @@ val prepare_reborrows :
bool ->
(V.BorrowId.id -> V.BorrowId.id) * (C.eval_ctx -> C.eval_ctx)
-(** Apply (and reduce) a projector over borrows to a value.
+(** Apply (and reduce) a projector over borrows to an avalue.
+ We use this for instance to spread the borrows present in the inputs
+ of a function into the regions introduced for this function. For instance:
+ {[
+ fn f<'a, 'b, T>(x: &'a T, y: &'b T)
+ ]}
+ If we call `f` with `x -> shared_borrow l0` and `y -> shared_borrow l1`, then
+ for the region introduced for `'a` we need to project the value for `x` to
+ a shared aborrow, and we need to ignore the borrow in `y`, because it belongs
+ to the other region.
Parameters:
- [check_symbolic_no_ended]: controls whether we check or not whether
diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml
index f5b1111e..d181ca4b 100644
--- a/compiler/InterpreterStatements.ml
+++ b/compiler/InterpreterStatements.ml
@@ -149,7 +149,7 @@ let eval_assertion_concrete (config : C.config) (assertion : A.assertion) :
let eval_assert cf (v : V.typed_value) : m_fun =
fun ctx ->
match v.value with
- | Primitive (Bool b) ->
+ | Literal (Bool b) ->
(* Branch *)
if b = assertion.expected then cf Unit ctx else cf Panic ctx
| _ ->
@@ -172,26 +172,26 @@ let eval_assertion (config : C.config) (assertion : A.assertion) : st_cm_fun =
(* Evaluate the assertion *)
let eval_assert cf (v : V.typed_value) : m_fun =
fun ctx ->
- assert (v.ty = T.Bool);
+ assert (v.ty = T.Literal PV.Bool);
(* We make a choice here: we could completely decouple the concrete and
* symbolic executions here but choose not to. In the case where we
* know the concrete value of the boolean we test, we use this value
* even if we are in symbolic mode. Note that this case should be
* extremely rare... *)
match v.value with
- | Primitive (Bool _) ->
+ | Literal (Bool _) ->
(* Delegate to the concrete evaluation function *)
eval_assertion_concrete config assertion cf ctx
| Symbolic sv ->
assert (config.mode = C.SymbolicMode);
- assert (sv.V.sv_ty = T.Bool);
+ assert (sv.V.sv_ty = T.Literal PV.Bool);
(* We continue the execution as if the test had succeeded, and thus
* perform the symbolic expansion: sv ~~> true.
* We will of course synthesize an assertion in the generated code
* (see below). *)
let ctx =
apply_symbolic_expansion_non_borrow config sv
- (V.SePrimitive (PV.Bool true)) ctx
+ (V.SeLiteral (PV.Bool true)) ctx
in
(* Continue *)
let expr = cf Unit ctx in
@@ -232,7 +232,8 @@ let set_discriminant (config : C.config) (p : E.place)
let update_value cf (v : V.typed_value) : m_fun =
fun ctx ->
match (v.V.ty, v.V.value) with
- | ( T.Adt (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types),
+ | ( T.Adt
+ (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types, cgs),
V.Adt av ) -> (
(* There are two situations:
- either the discriminant is already the proper one (in which case we
@@ -252,7 +253,7 @@ let set_discriminant (config : C.config) (p : E.place)
| T.AdtId def_id ->
compute_expanded_bottom_adt_value
ctx.type_context.type_decls def_id (Some variant_id)
- regions types
+ regions types cgs
| T.Assumed T.Option ->
assert (regions = []);
compute_expanded_bottom_option_value variant_id
@@ -260,13 +261,14 @@ let set_discriminant (config : C.config) (p : E.place)
| _ -> raise (Failure "Unreachable")
in
assign_to_place config bottom_v p (cf Unit) ctx)
- | ( T.Adt (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types),
+ | ( T.Adt
+ (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types, cgs),
V.Bottom ) ->
let bottom_v =
match type_id with
| T.AdtId def_id ->
compute_expanded_bottom_adt_value ctx.type_context.type_decls
- def_id (Some variant_id) regions types
+ def_id (Some variant_id) regions types cgs
| T.Assumed T.Option ->
assert (regions = []);
compute_expanded_bottom_option_value variant_id
@@ -285,7 +287,7 @@ let set_discriminant (config : C.config) (p : E.place)
* or reset an already initialized value, really. *)
raise (Failure "Unexpected value")
| _, (V.Adt _ | V.Bottom) -> raise (Failure "Inconsistent state")
- | _, (V.Primitive _ | V.Borrow _ | V.Loan _) ->
+ | _, (V.Literal _ | V.Borrow _ | V.Loan _) ->
raise (Failure "Unexpected value")
in
(* Compose and apply *)
@@ -302,20 +304,21 @@ let push_frame : cm_fun = fun cf ctx -> cf (ctx_push_frame ctx)
instantiation of a non-local function.
*)
let get_non_local_function_return_type (fid : A.assumed_fun_id)
- (region_params : T.erased_region list) (type_params : T.ety list) : T.ety =
+ (region_params : T.erased_region list) (type_params : T.ety list)
+ (const_generic_params : T.const_generic list) : T.ety =
(* [Box::free] has a special treatment *)
- match (fid, region_params, type_params) with
- | A.BoxFree, [], [ _ ] -> mk_unit_ty
+ match (fid, region_params, type_params, const_generic_params) with
+ | A.BoxFree, [], [ _ ], [] -> mk_unit_ty
| _ ->
(* Retrieve the function's signature *)
let sg = Assumed.get_assumed_sig fid in
(* Instantiate the return type *)
- let tsubst =
- Subst.make_type_subst
- (List.map (fun v -> v.T.index) sg.type_params)
- type_params
+ let tsubst = Subst.make_type_subst_from_vars sg.type_params type_params in
+ let cgsubst =
+ Subst.make_const_generic_subst_from_vars sg.const_generic_params
+ const_generic_params
in
- Subst.erase_regions_substitute_types tsubst sg.output
+ Subst.erase_regions_substitute_types tsubst cgsubst sg.output
let move_return_value (config : C.config) (pop_return_value : bool)
(cf : V.typed_value option -> m_fun) : m_fun =
@@ -443,7 +446,7 @@ let eval_box_new_concrete (config : C.config)
(* Create the new box *)
let cf_create cf (moved_input_value : V.typed_value) : m_fun =
(* Create the box value *)
- let box_ty = T.Adt (T.Assumed T.Box, [], [ boxed_ty ]) in
+ let box_ty = T.Adt (T.Assumed T.Box, [], [ boxed_ty ], []) in
let box_v =
V.Adt { variant_id = None; field_values = [ moved_input_value ] }
in
diff --git a/compiler/InterpreterStatements.mli b/compiler/InterpreterStatements.mli
index f28bf2ea..814bc964 100644
--- a/compiler/InterpreterStatements.mli
+++ b/compiler/InterpreterStatements.mli
@@ -31,7 +31,8 @@ val pop_frame : C.config -> bool -> (V.typed_value option -> m_fun) -> m_fun
Note: there are no region parameters, because they should be erased.
*)
-val instantiate_fun_sig : T.ety list -> LA.fun_sig -> LA.inst_fun_sig
+val instantiate_fun_sig :
+ T.ety list -> T.const_generic list -> LA.fun_sig -> LA.inst_fun_sig
(** Helper.
diff --git a/compiler/Invariants.ml b/compiler/Invariants.ml
index 981c2c46..a726eda0 100644
--- a/compiler/Invariants.ml
+++ b/compiler/Invariants.ml
@@ -377,10 +377,10 @@ let check_borrowed_values_invariant (ctx : C.eval_ctx) : unit =
let info = { outer_borrow = false; outer_shared = false } in
visitor#visit_eval_ctx info ctx
-let check_primitive_value_type (cv : V.primitive_value) (ty : T.ety) : unit =
+let check_literal_type (cv : V.literal) (ty : PV.literal_type) : unit =
match (cv, ty) with
- | PV.Scalar sv, T.Integer int_ty -> assert (sv.int_ty = int_ty)
- | PV.Bool _, T.Bool | PV.Char _, T.Char | PV.String _, T.Str -> ()
+ | PV.Scalar sv, PV.Integer int_ty -> assert (sv.int_ty = int_ty)
+ | PV.Bool _, PV.Bool | PV.Char _, PV.Char -> ()
| _ -> raise (Failure "Erroneous typing")
let check_typing_invariant (ctx : C.eval_ctx) : unit =
@@ -404,9 +404,9 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
method! visit_typed_value info tv =
(* Check the current pair (value, type) *)
(match (tv.V.value, tv.V.ty) with
- | V.Primitive cv, ty -> check_primitive_value_type cv ty
+ | V.Literal cv, T.Literal ty -> check_literal_type cv ty
(* ADT case *)
- | V.Adt av, T.Adt (T.AdtId def_id, regions, tys) ->
+ | V.Adt av, T.Adt (T.AdtId def_id, regions, tys, cgs) ->
(* Retrieve the definition to check the variant id, the number of
* parameters, etc. *)
let def = C.ctx_lookup_type_decl ctx def_id in
@@ -422,7 +422,7 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
(* Check that the field types are correct *)
let field_types =
Subst.type_decl_get_instantiated_field_etypes def av.V.variant_id
- tys
+ tys cgs
in
let fields_with_types =
List.combine av.V.field_values field_types
@@ -431,8 +431,9 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
(fun ((v, ty) : V.typed_value * T.ety) -> assert (v.V.ty = ty))
fields_with_types
(* Tuple case *)
- | V.Adt av, T.Adt (T.Tuple, regions, tys) ->
+ | V.Adt av, T.Adt (T.Tuple, regions, tys, cgs) ->
assert (regions = []);
+ assert (cgs = []);
assert (av.V.variant_id = None);
(* Check that the fields have the proper values - and check that there
* are as many fields as field types at the same time *)
@@ -441,20 +442,22 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
(fun ((v, ty) : V.typed_value * T.ety) -> assert (v.V.ty = ty))
fields_with_types
(* Assumed type case *)
- | V.Adt av, T.Adt (T.Assumed aty_id, regions, tys) -> (
+ | V.Adt av, T.Adt (T.Assumed aty_id, regions, tys, cgs) -> (
assert (av.V.variant_id = None || aty_id = T.Option);
- match (aty_id, av.V.field_values, regions, tys) with
+ match (aty_id, av.V.field_values, regions, tys, cgs) with
(* Box *)
- | T.Box, [ inner_value ], [], [ inner_ty ]
- | T.Option, [ inner_value ], [], [ inner_ty ] ->
+ | T.Box, [ inner_value ], [], [ inner_ty ], []
+ | T.Option, [ inner_value ], [], [ inner_ty ], [] ->
assert (inner_value.V.ty = inner_ty)
- | T.Option, _, [], [ _ ] ->
+ | T.Option, _, [], [ _ ], [] ->
(* Option::None: nothing to check *)
()
- | T.Vec, fvs, [], [ vec_ty ] ->
+ | T.Vec, fvs, [], [ vec_ty ], [] ->
List.iter
(fun (v : V.typed_value) -> assert (v.ty = vec_ty))
fvs
+ | (T.Array | T.Slice | T.Str), _, _, _, _ ->
+ raise (Failure "Unexpected")
| _ -> raise (Failure "Erroneous type"))
| V.Bottom, _ -> (* Nothing to check *) ()
| V.Borrow bc, T.Ref (_, ref_ty, rkind) -> (
@@ -502,13 +505,14 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
(* Check the current pair (value, type) *)
(match (atv.V.value, atv.V.ty) with
(* ADT case *)
- | V.AAdt av, T.Adt (T.AdtId def_id, regions, tys) ->
+ | V.AAdt av, T.Adt (T.AdtId def_id, regions, tys, cgs) ->
(* Retrieve the definition to check the variant id, the number of
* parameters, etc. *)
let def = C.ctx_lookup_type_decl ctx def_id in
(* Check the number of parameters *)
assert (List.length regions = List.length def.region_params);
assert (List.length tys = List.length def.type_params);
+ assert (List.length cgs = List.length def.const_generic_params);
(* Check that the variant id is consistent *)
(match (av.V.variant_id, def.T.kind) with
| Some variant_id, T.Enum variants ->
@@ -518,7 +522,7 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
(* Check that the field types are correct *)
let field_types =
Subst.type_decl_get_instantiated_field_rtypes def av.V.variant_id
- regions tys
+ regions tys cgs
in
let fields_with_types =
List.combine av.V.field_values field_types
@@ -527,8 +531,9 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
(fun ((v, ty) : V.typed_avalue * T.rty) -> assert (v.V.ty = ty))
fields_with_types
(* Tuple case *)
- | V.AAdt av, T.Adt (T.Tuple, regions, tys) ->
+ | V.AAdt av, T.Adt (T.Tuple, regions, tys, cgs) ->
assert (regions = []);
+ assert (cgs = []);
assert (av.V.variant_id = None);
(* Check that the fields have the proper values - and check that there
* are as many fields as field types at the same time *)
@@ -537,11 +542,11 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
(fun ((v, ty) : V.typed_avalue * T.rty) -> assert (v.V.ty = ty))
fields_with_types
(* Assumed type case *)
- | V.AAdt av, T.Adt (T.Assumed aty_id, regions, tys) -> (
+ | V.AAdt av, T.Adt (T.Assumed aty_id, regions, tys, cgs) -> (
assert (av.V.variant_id = None);
- match (aty_id, av.V.field_values, regions, tys) with
+ match (aty_id, av.V.field_values, regions, tys, cgs) with
(* Box *)
- | T.Box, [ boxed_value ], [], [ boxed_ty ] ->
+ | T.Box, [ boxed_value ], [], [ boxed_ty ], [] ->
assert (boxed_value.V.ty = boxed_ty)
| _ -> raise (Failure "Erroneous type"))
| V.ABottom, _ -> (* Nothing to check *) ()
diff --git a/compiler/Print.ml b/compiler/Print.ml
index 23cebd4c..410b45e6 100644
--- a/compiler/Print.ml
+++ b/compiler/Print.ml
@@ -80,7 +80,7 @@ module Values = struct
string =
let ty_fmt : PT.etype_formatter = value_to_etype_formatter fmt in
match v.value with
- | Primitive cv -> PPV.literal_to_string cv
+ | Literal cv -> PPV.literal_to_string cv
| Adt av -> (
let field_values =
List.map (typed_value_to_string fmt) av.field_values
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index 03252200..6f857b4f 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -6,11 +6,15 @@ open PureUtils
type type_formatter = {
type_var_id_to_string : TypeVarId.id -> string;
type_decl_id_to_string : TypeDeclId.id -> string;
+ const_generic_var_id_to_string : ConstGenericVarId.id -> string;
+ global_decl_id_to_string : GlobalDeclId.id -> string;
}
type value_formatter = {
type_var_id_to_string : TypeVarId.id -> string;
type_decl_id_to_string : TypeDeclId.id -> string;
+ const_generic_var_id_to_string : ConstGenericVarId.id -> string;
+ global_decl_id_to_string : GlobalDeclId.id -> string;
adt_variant_to_string : TypeDeclId.id -> VariantId.id -> string;
var_id_to_string : VarId.id -> string;
adt_field_names : TypeDeclId.id -> VariantId.id option -> string list option;
@@ -20,6 +24,8 @@ let value_to_type_formatter (fmt : value_formatter) : type_formatter =
{
type_var_id_to_string = fmt.type_var_id_to_string;
type_decl_id_to_string = fmt.type_decl_id_to_string;
+ const_generic_var_id_to_string = fmt.const_generic_var_id_to_string;
+ global_decl_id_to_string = fmt.global_decl_id_to_string;
}
(* TODO: we need to store which variables we have encountered so far, and
@@ -28,6 +34,7 @@ let value_to_type_formatter (fmt : value_formatter) : type_formatter =
type ast_formatter = {
type_var_id_to_string : TypeVarId.id -> string;
type_decl_id_to_string : TypeDeclId.id -> string;
+ const_generic_var_id_to_string : ConstGenericVarId.id -> string;
adt_variant_to_string : TypeDeclId.id -> VariantId.id -> string;
var_id_to_string : VarId.id -> string;
adt_field_to_string :
@@ -41,6 +48,8 @@ let ast_to_value_formatter (fmt : ast_formatter) : value_formatter =
{
type_var_id_to_string = fmt.type_var_id_to_string;
type_decl_id_to_string = fmt.type_decl_id_to_string;
+ const_generic_var_id_to_string = fmt.const_generic_var_id_to_string;
+ global_decl_id_to_string = fmt.global_decl_id_to_string;
adt_variant_to_string = fmt.adt_variant_to_string;
var_id_to_string = fmt.var_id_to_string;
adt_field_names = fmt.adt_field_names;
@@ -55,22 +64,38 @@ let fun_name_to_string = Print.fun_name_to_string
let global_name_to_string = Print.global_name_to_string
let option_to_string = Print.option_to_string
let type_var_to_string = Print.Types.type_var_to_string
+let const_generic_var_to_string = Print.Types.const_generic_var_to_string
let integer_type_to_string = Print.PrimitiveValues.integer_type_to_string
let literal_type_to_string = Print.PrimitiveValues.literal_type_to_string
let scalar_value_to_string = Print.PrimitiveValues.scalar_value_to_string
let literal_to_string = Print.PrimitiveValues.literal_to_string
let mk_type_formatter (type_decls : T.type_decl TypeDeclId.Map.t)
- (type_params : type_var list) : type_formatter =
+ (global_decls : A.global_decl GlobalDeclId.Map.t)
+ (type_params : type_var list)
+ (const_generic_params : const_generic_var list) : type_formatter =
let type_var_id_to_string vid =
let var = T.TypeVarId.nth type_params vid in
type_var_to_string var
in
+ let const_generic_var_id_to_string vid =
+ let var = T.ConstGenericVarId.nth const_generic_params vid in
+ const_generic_var_to_string var
+ in
let type_decl_id_to_string def_id =
let def = T.TypeDeclId.Map.find def_id type_decls in
name_to_string def.name
in
- { type_var_id_to_string; type_decl_id_to_string }
+ let global_decl_id_to_string def_id =
+ let def = T.GlobalDeclId.Map.find def_id global_decls in
+ name_to_string def.name
+ in
+ {
+ type_var_id_to_string;
+ type_decl_id_to_string;
+ const_generic_var_id_to_string;
+ global_decl_id_to_string;
+ }
(* TODO: there is a bit of duplication with Print.fun_decl_to_ast_formatter.
@@ -81,11 +106,16 @@ let mk_type_formatter (type_decls : T.type_decl TypeDeclId.Map.t)
let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t)
(fun_decls : A.fun_decl FunDeclId.Map.t)
(global_decls : A.global_decl GlobalDeclId.Map.t)
- (type_params : type_var list) : ast_formatter =
+ (type_params : type_var list)
+ (const_generic_params : const_generic_var list) : ast_formatter =
let type_var_id_to_string vid =
let var = T.TypeVarId.nth type_params vid in
type_var_to_string var
in
+ let const_generic_var_id_to_string vid =
+ let var = T.ConstGenericVarId.nth const_generic_params vid in
+ const_generic_var_to_string var
+ in
let type_decl_id_to_string def_id =
let def = T.TypeDeclId.Map.find def_id type_decls in
name_to_string def.name
@@ -113,6 +143,7 @@ let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t)
in
{
type_var_id_to_string;
+ const_generic_var_id_to_string;
type_decl_id_to_string;
adt_variant_to_string;
var_id_to_string;
@@ -122,36 +153,50 @@ let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t)
global_decl_id_to_string;
}
+let assumed_ty_to_string (aty : assumed_ty) : string =
+ match aty with
+ | State -> "State"
+ | Result -> "Result"
+ | Error -> "Error"
+ | Fuel -> "Fuel"
+ | Option -> "Option"
+ | Vec -> "Vec"
+ | Array -> "Array"
+ | Slice -> "Slice"
+ | Str -> "Str"
+
let type_id_to_string (fmt : type_formatter) (id : type_id) : string =
match id with
| AdtId id -> fmt.type_decl_id_to_string id
| Tuple -> ""
- | Assumed aty -> (
- match aty with
- | State -> "State"
- | Result -> "Result"
- | Error -> "Error"
- | Fuel -> "Fuel"
- | Option -> "Option"
- | Vec -> "Vec")
+ | Assumed aty -> assumed_ty_to_string aty
+
+(* TODO: duplicates Charon.PrintTypes.const_generic_to_string *)
+let const_generic_to_string (fmt : type_formatter) (cg : T.const_generic) :
+ string =
+ match cg with
+ | ConstGenericGlobal id -> fmt.global_decl_id_to_string id
+ | ConstGenericVar id -> fmt.const_generic_var_id_to_string id
+ | ConstGenericValue lit -> literal_to_string lit
let rec ty_to_string (fmt : type_formatter) (inside : bool) (ty : ty) : string =
match ty with
- | Adt (id, tys) -> (
+ | Adt (id, tys, cgs) -> (
let tys = List.map (ty_to_string fmt false) tys in
+ let cgs = List.map (const_generic_to_string fmt) cgs in
+ let params = List.append tys cgs in
match id with
- | Tuple -> "(" ^ String.concat " * " tys ^ ")"
+ | Tuple ->
+ assert (cgs = []);
+ "(" ^ String.concat " * " tys ^ ")"
| AdtId _ | Assumed _ ->
- let tys_s = if tys = [] then "" else " " ^ String.concat " " tys in
- let ty_s = type_id_to_string fmt id ^ tys_s in
- if tys <> [] && inside then "(" ^ ty_s ^ ")" else ty_s)
+ let params_s =
+ if params = [] then "" else " " ^ String.concat " " params
+ in
+ let ty_s = type_id_to_string fmt id ^ params_s in
+ if params <> [] && inside then "(" ^ ty_s ^ ")" else ty_s)
| TypeVar tv -> fmt.type_var_id_to_string tv
- | Bool -> "bool"
- | Char -> "char"
- | Integer int_ty -> integer_type_to_string int_ty
- | Str -> "str"
- | Array aty -> "[" ^ ty_to_string fmt false aty ^ "; ?]"
- | Slice sty -> "[" ^ ty_to_string fmt false sty ^ "]"
+ | Literal lty -> literal_type_to_string lty
| Arrow (arg_ty, ret_ty) ->
let ty =
ty_to_string fmt true arg_ty ^ " -> " ^ ty_to_string fmt false ret_ty
@@ -248,8 +293,8 @@ let adt_variant_to_string (fmt : value_formatter) (adt_id : type_id)
| Assumed aty -> (
(* Assumed type *)
match aty with
- | State ->
- (* This type is opaque: we can't get there *)
+ | State | Vec | Array | Slice | Str ->
+ (* Those types are opaque: we can't get there *)
raise (Failure "Unreachable")
| Result ->
let variant_id = Option.get variant_id in
@@ -272,10 +317,7 @@ let adt_variant_to_string (fmt : value_formatter) (adt_id : type_id)
if variant_id = option_some_id then "@Option::Some "
else if variant_id = option_none_id then "@Option::None"
else
- raise (Failure "Unreachable: improper variant id for result type")
- | Vec ->
- assert (variant_id = None);
- "Vec")
+ raise (Failure "Unreachable: improper variant id for result type"))
let adt_field_to_string (fmt : value_formatter) (adt_id : type_id)
(field_id : FieldId.id) : string =
@@ -292,7 +334,7 @@ let adt_field_to_string (fmt : value_formatter) (adt_id : type_id)
| Assumed aty -> (
(* Assumed type *)
match aty with
- | State | Fuel | Vec ->
+ | State | Fuel | Vec | Array | Slice | Str ->
(* Opaque types: we can't get there *)
raise (Failure "Unreachable")
| Result | Error | Option ->
@@ -300,17 +342,17 @@ let adt_field_to_string (fmt : value_formatter) (adt_id : type_id)
raise (Failure "Unreachable"))
(** TODO: we don't need a general function anymore (it is now only used for
- patterns (i.e., patterns)
+ patterns)
*)
let adt_g_value_to_string (fmt : value_formatter)
(value_to_string : 'v -> string) (variant_id : VariantId.id option)
(field_values : 'v list) (ty : ty) : string =
let field_values = List.map value_to_string field_values in
match ty with
- | Adt (Tuple, _) ->
+ | Adt (Tuple, _, _) ->
(* Tuple *)
"(" ^ String.concat ", " field_values ^ ")"
- | Adt (AdtId def_id, _) ->
+ | Adt (AdtId def_id, _, _) ->
(* "Regular" ADT *)
let adt_ident =
match variant_id with
@@ -332,7 +374,7 @@ let adt_g_value_to_string (fmt : value_formatter)
let field_values = String.concat " " field_values in
adt_ident ^ " { " ^ field_values ^ " }"
else adt_ident
- | Adt (Assumed aty, _) -> (
+ | Adt (Assumed aty, _, _) -> (
(* Assumed type *)
match aty with
| State ->
@@ -377,12 +419,13 @@ let adt_g_value_to_string (fmt : value_formatter)
"@Option::None")
else
raise (Failure "Unreachable: improper variant id for result type")
- | Vec ->
+ | Vec | Array | Slice | Str ->
assert (variant_id = None);
let field_values =
List.mapi (fun i v -> string_of_int i ^ " -> " ^ v) field_values
in
- "Vec [" ^ String.concat "; " field_values ^ "]")
+ let id = assumed_ty_to_string aty in
+ id ^ " [" ^ String.concat "; " field_values ^ "]")
| _ ->
let fmt = value_to_type_formatter fmt in
raise
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index 5af28efd..9b5d9236 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -32,7 +32,11 @@ IdGen ()
module VarId =
IdGen ()
+module ConstGenericVarId = T.ConstGenericVarId
+
type integer_type = T.integer_type [@@deriving show, ord]
+type const_generic_var = T.const_generic_var [@@deriving show, ord]
+type const_generic = T.const_generic [@@deriving show, ord]
(** The assumed types for the pure AST.
@@ -50,7 +54,16 @@ type integer_type = T.integer_type [@@deriving show, ord]
this state is opaque to Aeneas (the user can define it, or leave it as
assumed)
*)
-type assumed_ty = State | Result | Error | Fuel | Vec | Option
+type assumed_ty =
+ | State
+ | Result
+ | Error
+ | Fuel
+ | Vec
+ | Option
+ | Array
+ | Slice
+ | Str
[@@deriving show, ord]
(* TODO: we should never directly manipulate [Return] and [Fail], but rather
@@ -114,26 +127,28 @@ type type_id = AdtId of type_decl_id | Tuple | Assumed of assumed_ty
polymorphic = false;
}]
+type literal_type = T.literal_type [@@deriving show, ord]
+
(** Ancestor for iter visitor for [ty] *)
class ['self] iter_ty_base =
object (_self : 'self)
inherit [_] iter_type_id
+ inherit! [_] T.iter_const_generic
+ inherit! [_] PV.iter_literal_type
method visit_type_var_id : 'env -> type_var_id -> unit = fun _ _ -> ()
- method visit_integer_type : 'env -> integer_type -> unit = fun _ _ -> ()
end
(** Ancestor for map visitor for [ty] *)
class ['self] map_ty_base =
object (_self : 'self)
inherit [_] map_type_id
+ inherit! [_] T.map_const_generic
+ inherit! [_] PV.map_literal_type
method visit_type_var_id : 'env -> type_var_id -> type_var_id = fun _ x -> x
-
- method visit_integer_type : 'env -> integer_type -> integer_type =
- fun _ x -> x
end
type ty =
- | Adt of type_id * ty list
+ | Adt of type_id * ty list * const_generic list
(** {!Adt} encodes ADTs and tuples and assumed types.
TODO: what about the ended regions? (ADTs may be parameterized
@@ -142,12 +157,7 @@ type ty =
such "partial" ADTs.
*)
| TypeVar of type_var_id
- | Bool
- | Char
- | Integer of integer_type
- | Str
- | Array of ty (* TODO: this should be an assumed type?... *)
- | Slice of ty (* TODO: this should be an assumed type?... *)
+ | Literal of literal_type
| Arrow of ty * ty
[@@deriving
show,
@@ -182,6 +192,7 @@ type type_decl = {
def_id : TypeDeclId.id;
name : name;
type_params : type_var list;
+ const_generic_params : const_generic_var list;
kind : type_decl_kind;
}
[@@deriving show]
@@ -393,7 +404,12 @@ type qualif_id =
which explains why we have the [type_params] field: a function or ADT
constructor is always fully instantiated.
*)
-type qualif = { id : qualif_id; type_args : ty list } [@@deriving show]
+type qualif = {
+ id : qualif_id;
+ type_args : ty list;
+ const_generic_args : const_generic list;
+}
+[@@deriving show]
type field_id = FieldId.id [@@deriving show, ord]
type var_id = VarId.id [@@deriving show, ord]
@@ -716,6 +732,7 @@ type fun_sig_info = {
*)
type fun_sig = {
type_params : type_var list;
+ const_generic_params : const_generic_var list;
(** TODO: we should analyse the signature to make the type parameters implicit whenever possible *)
inputs : ty list;
(** The input types.
diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml
index 72084dfc..ef8bac37 100644
--- a/compiler/PureTypeCheck.ml
+++ b/compiler/PureTypeCheck.ml
@@ -5,8 +5,8 @@ open PureUtils
(** Utility function, used for type checking *)
let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t)
- (type_id : type_id) (variant_id : VariantId.id option) (tys : ty list) :
- ty list =
+ (type_id : type_id) (variant_id : VariantId.id option) (tys : ty list)
+ (cgs : const_generic list) : ty list =
match type_id with
| Tuple ->
(* Tuple *)
@@ -15,7 +15,7 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t)
| AdtId def_id ->
(* "Regular" ADT *)
let def = TypeDeclId.Map.find def_id type_decls in
- type_decl_get_instantiated_fields_types def variant_id tys
+ type_decl_get_instantiated_fields_types def variant_id tys cgs
| Assumed aty -> (
(* Assumed type *)
match aty with
@@ -47,7 +47,10 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t)
else if variant_id = option_none_id then []
else
raise (Failure "Unreachable: improper variant id for result type")
- | Vec -> raise (Failure "Unreachable: `Vector` values are opaque"))
+ | Vec | Array | Slice | Str ->
+ raise
+ (Failure
+ "Unreachable: trying to access the fields of an opaque type"))
type tc_ctx = {
type_decls : type_decl TypeDeclId.Map.t; (** The type declarations *)
@@ -56,7 +59,7 @@ type tc_ctx = {
env : ty VarId.Map.t; (** Environment from variables to types *)
}
-let check_literal (v : literal) (ty : ty) : unit =
+let check_literal (v : literal) (ty : literal_type) : unit =
match (ty, v) with
| Integer int_ty, PV.Scalar sv -> assert (int_ty = sv.PV.int_ty)
| Bool, Bool _ | Char, Char _ -> ()
@@ -66,7 +69,7 @@ let rec check_typed_pattern (ctx : tc_ctx) (v : typed_pattern) : tc_ctx =
log#ldebug (lazy ("check_typed_pattern: " ^ show_typed_pattern v));
match v.value with
| PatConstant cv ->
- check_literal cv v.ty;
+ check_literal cv (ty_as_literal v.ty);
ctx
| PatDummy -> ctx
| PatVar (var, _) ->
@@ -75,13 +78,9 @@ let rec check_typed_pattern (ctx : tc_ctx) (v : typed_pattern) : tc_ctx =
{ ctx with env }
| PatAdt av ->
(* Compute the field types *)
- let type_id, tys =
- match v.ty with
- | Adt (type_id, tys) -> (type_id, tys)
- | _ -> raise (Failure "Inconsistently typed value")
- in
+ let type_id, tys, cgs = ty_as_adt v.ty in
let field_tys =
- get_adt_field_types ctx.type_decls type_id av.variant_id tys
+ get_adt_field_types ctx.type_decls type_id av.variant_id tys cgs
in
let check_value (ctx : tc_ctx) (ty : ty) (v : typed_pattern) : tc_ctx =
if ty <> v.ty then (
@@ -108,7 +107,7 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
match VarId.Map.find_opt var_id ctx.env with
| None -> ()
| Some ty -> assert (ty = e.ty))
- | Const cv -> check_literal cv e.ty
+ | Const cv -> check_literal cv (ty_as_literal e.ty)
| App (app, arg) ->
let input_ty, output_ty = destruct_arrow app.ty in
assert (input_ty = arg.ty);
@@ -130,33 +129,31 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
(* Note we can only project fields of structures (not enumerations) *)
(* Deconstruct the projector type *)
let adt_ty, field_ty = destruct_arrow e.ty in
- let adt_id, adt_type_args =
- match adt_ty with
- | Adt (type_id, tys) -> (type_id, tys)
- | _ -> raise (Failure "Unreachable")
- in
+ let adt_id, adt_type_args, adt_cg_args = ty_as_adt adt_ty in
(* Check the ADT type *)
assert (adt_id = proj_adt_id);
assert (adt_type_args = qualif.type_args);
+ assert (adt_cg_args = qualif.const_generic_args);
(* Retrieve and check the expected field type *)
let variant_id = None in
let expected_field_tys =
get_adt_field_types ctx.type_decls proj_adt_id variant_id
- qualif.type_args
+ qualif.type_args qualif.const_generic_args
in
let expected_field_ty = FieldId.nth expected_field_tys field_id in
assert (expected_field_ty = field_ty)
| AdtCons id -> (
let expected_field_tys =
get_adt_field_types ctx.type_decls id.adt_id id.variant_id
- qualif.type_args
+ qualif.type_args qualif.const_generic_args
in
let field_tys, adt_ty = destruct_arrows e.ty in
assert (expected_field_tys = field_tys);
match adt_ty with
- | Adt (type_id, tys) ->
+ | Adt (type_id, tys, cgs) ->
assert (type_id = id.adt_id);
- assert (tys = qualif.type_args)
+ assert (tys = qualif.type_args);
+ assert (cgs = qualif.const_generic_args)
| _ -> raise (Failure "Unreachable")))
| Let (monadic, pat, re, e_next) ->
let expected_pat_ty = if monadic then destruct_result re.ty else re.ty in
@@ -172,7 +169,7 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
check_texpression ctx scrut;
match switch_body with
| If (e_then, e_else) ->
- assert (scrut.ty = Bool);
+ assert (scrut.ty = Literal Bool);
assert (e_then.ty = e.ty);
assert (e_else.ty = e.ty);
check_texpression ctx e_then;
@@ -202,15 +199,12 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
| Some ty -> assert (ty = e.ty));
(* Check the fields *)
(* Retrieve and check the expected field type *)
- let adt_id, adt_type_args =
- match e.ty with
- | Adt (type_id, tys) -> (type_id, tys)
- | _ -> raise (Failure "Unreachable")
- in
+ let adt_id, adt_type_args, adt_cg_args = ty_as_adt e.ty in
assert (adt_id = AdtId supd.struct_id);
let variant_id = None in
let expected_field_tys =
get_adt_field_types ctx.type_decls adt_id variant_id adt_type_args
+ adt_cg_args
in
List.iter
(fun (fid, fe) ->
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index 88b18e89..1c8d8921 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -62,7 +62,7 @@ let dest_arrow_ty (ty : ty) : ty * ty =
| Arrow (arg_ty, ret_ty) -> (arg_ty, ret_ty)
| _ -> raise (Failure "Unreachable")
-let compute_literal_ty (cv : literal) : ty =
+let compute_literal_type (cv : literal) : literal_type =
match cv with
| PV.Scalar sv -> Integer sv.PV.int_ty
| Bool _ -> Bool
@@ -71,7 +71,7 @@ let compute_literal_ty (cv : literal) : ty =
let var_get_id (v : var) : VarId.id = v.id
let mk_typed_pattern_from_literal (cv : literal) : typed_pattern =
- let ty = compute_literal_ty cv in
+ let ty = Literal (compute_literal_type cv) in
{ value = PatConstant cv; ty }
let mk_let (monadic : bool) (lv : typed_pattern) (re : texpression)
@@ -90,11 +90,13 @@ let mk_mplace (var_id : E.VarId.id) (name : string option)
{ var_id; name; projection }
(** Type substitution *)
-let ty_substitute (tsubst : TypeVarId.id -> ty) (ty : ty) : ty =
+let ty_substitute (tsubst : TypeVarId.id -> ty)
+ (cgsubst : ConstGenericVarId.id -> const_generic) (ty : ty) : ty =
let obj =
object
inherit [_] map_ty
method! visit_TypeVar _ var_id = tsubst var_id
+ method! visit_ConstGenericVar _ var_id = cgsubst var_id
end
in
obj#visit_ty () ty
@@ -109,6 +111,10 @@ let make_type_subst (vars : type_var list) (tys : ty list) : TypeVarId.id -> ty
in
fun id -> TypeVarId.Map.find id mp
+let make_const_generic_subst (vars : const_generic_var list)
+ (cgs : const_generic list) : ConstGenericVarId.id -> const_generic =
+ Substitute.make_const_generic_subst_from_vars vars cgs
+
(** Retrieve the list of fields for the given variant of a {!type:Aeneas.Pure.type_decl}.
Raises [Invalid_argument] if the arguments are incorrect.
@@ -132,14 +138,17 @@ let type_decl_get_fields (def : type_decl)
(** Instantiate the type variables for the chosen variant in an ADT definition,
and return the list of the types of its fields *)
let type_decl_get_instantiated_fields_types (def : type_decl)
- (opt_variant_id : VariantId.id option) (types : ty list) : ty list =
+ (opt_variant_id : VariantId.id option) (types : ty list)
+ (cgs : const_generic list) : ty list =
let ty_subst = make_type_subst def.type_params types in
+ let cg_subst = make_const_generic_subst def.const_generic_params cgs in
let fields = type_decl_get_fields def opt_variant_id in
- List.map (fun f -> ty_substitute ty_subst f.field_ty) fields
+ List.map (fun f -> ty_substitute ty_subst cg_subst f.field_ty) fields
-let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) :
+let fun_sig_substitute (tsubst : TypeVarId.id -> ty)
+ (cgsubst : ConstGenericVarId.id -> const_generic) (sg : fun_sig) :
inst_fun_sig =
- let subst = ty_substitute tsubst in
+ let subst = ty_substitute tsubst cgsubst in
let inputs = List.map subst sg.inputs in
let output = subst sg.output in
let doutputs = List.map subst sg.doutputs in
@@ -181,9 +190,9 @@ let is_global (e : texpression) : bool =
let is_const (e : texpression) : bool =
match e.e with Const _ -> true | _ -> false
-let ty_as_adt (ty : ty) : type_id * ty list =
+let ty_as_adt (ty : ty) : type_id * ty list * const_generic list =
match ty with
- | Adt (id, tys) -> (id, tys)
+ | Adt (id, tys, cgs) -> (id, tys, cgs)
| _ -> raise (Failure "Unreachable")
(** Remove the external occurrences of {!Meta} *)
@@ -291,13 +300,19 @@ let opt_destruct_function_call (e : texpression) :
let opt_destruct_result (ty : ty) : ty option =
match ty with
- | Adt (Assumed Result, tys) -> Some (Collections.List.to_cons_nil tys)
+ | Adt (Assumed Result, tys, cgs) ->
+ assert (cgs = []);
+ Some (Collections.List.to_cons_nil tys)
| _ -> None
let destruct_result (ty : ty) : ty = Option.get (opt_destruct_result ty)
let opt_destruct_tuple (ty : ty) : ty list option =
- match ty with Adt (Tuple, tys) -> Some tys | _ -> None
+ match ty with
+ | Adt (Tuple, tys, cgs) ->
+ assert (cgs = []);
+ Some tys
+ | _ -> None
let mk_abs (x : typed_pattern) (e : texpression) : texpression =
let ty = Arrow (x.ty, e.ty) in
@@ -351,7 +366,7 @@ let iter_switch_body_branches (f : texpression -> unit) (sb : switch_body) :
let mk_switch (scrut : texpression) (sb : switch_body) : texpression =
(* Sanity check: the scrutinee has the proper type *)
(match sb with
- | If (_, _) -> assert (scrut.ty = Bool)
+ | If (_, _) -> assert (scrut.ty = Literal Bool)
| Match branches ->
List.iter
(fun (b : match_branch) -> assert (b.pat.ty = scrut.ty))
@@ -368,14 +383,14 @@ let mk_switch (scrut : texpression) (sb : switch_body) : texpression =
- if there is > one type: wrap them in a tuple
*)
let mk_simpl_tuple_ty (tys : ty list) : ty =
- match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys)
+ match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys, [])
-let mk_bool_ty : ty = Bool
-let mk_unit_ty : ty = Adt (Tuple, [])
+let mk_bool_ty : ty = Literal Bool
+let mk_unit_ty : ty = Adt (Tuple, [], [])
let mk_unit_rvalue : texpression =
let id = AdtCons { adt_id = Tuple; variant_id = None } in
- let qualif = { id; type_args = [] } in
+ let qualif = { id; type_args = []; const_generic_args = [] } in
let e = Qualif qualif in
let ty = mk_unit_ty in
{ e; ty }
@@ -415,7 +430,7 @@ let mk_simpl_tuple_pattern (vl : typed_pattern list) : typed_pattern =
| [ v ] -> v
| _ ->
let tys = List.map (fun (v : typed_pattern) -> v.ty) vl in
- let ty = Adt (Tuple, tys) in
+ let ty = Adt (Tuple, tys, []) in
let value = PatAdt { variant_id = None; field_values = vl } in
{ value; ty }
@@ -426,11 +441,11 @@ let mk_simpl_tuple_texpression (vl : texpression list) : texpression =
| _ ->
(* Compute the types of the fields, and the type of the tuple constructor *)
let tys = List.map (fun (v : texpression) -> v.ty) vl in
- let ty = Adt (Tuple, tys) in
+ let ty = Adt (Tuple, tys, []) in
let ty = mk_arrows tys ty in
(* Construct the tuple constructor qualifier *)
let id = AdtCons { adt_id = Tuple; variant_id = None } in
- let qualif = { id; type_args = tys } in
+ let qualif = { id; type_args = tys; const_generic_args = [] } in
(* Put everything together *)
let cons = { e = Qualif qualif; ty } in
mk_apps cons vl
@@ -441,36 +456,39 @@ let mk_adt_pattern (adt_ty : ty) (variant_id : VariantId.id option)
{ value; ty = adt_ty }
let ty_as_integer (t : ty) : T.integer_type =
- match t with Integer int_ty -> int_ty | _ -> raise (Failure "Unreachable")
+ match t with
+ | Literal (Integer int_ty) -> int_ty
+ | _ -> raise (Failure "Unreachable")
-(* TODO: move *)
-let type_decl_is_enum (def : T.type_decl) : bool =
- match def.kind with T.Struct _ -> false | Enum _ -> true | Opaque -> false
+let ty_as_literal (t : ty) : T.literal_type =
+ match t with Literal ty -> ty | _ -> raise (Failure "Unreachable")
-let mk_state_ty : ty = Adt (Assumed State, [])
-let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ])
-let mk_error_ty : ty = Adt (Assumed Error, [])
-let mk_fuel_ty : ty = Adt (Assumed Fuel, [])
+let mk_state_ty : ty = Adt (Assumed State, [], [])
+let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ], [])
+let mk_error_ty : ty = Adt (Assumed Error, [], [])
+let mk_fuel_ty : ty = Adt (Assumed Fuel, [], [])
let mk_error (error : VariantId.id) : texpression =
let ty = mk_error_ty in
let id = AdtCons { adt_id = Assumed Error; variant_id = Some error } in
- let qualif = { id; type_args = [] } in
+ let qualif = { id; type_args = []; const_generic_args = [] } in
let e = Qualif qualif in
{ e; ty }
let unwrap_result_ty (ty : ty) : ty =
match ty with
- | Adt (Assumed Result, [ ty ]) -> ty
+ | Adt (Assumed Result, [ ty ], cgs) ->
+ assert (cgs = []);
+ ty
| _ -> raise (Failure "not a result type")
let mk_result_fail_texpression (error : texpression) (ty : ty) : texpression =
let type_args = [ ty ] in
- let ty = Adt (Assumed Result, type_args) in
+ let ty = Adt (Assumed Result, type_args, []) in
let id =
AdtCons { adt_id = Assumed Result; variant_id = Some result_fail_id }
in
- let qualif = { id; type_args } in
+ let qualif = { id; type_args; const_generic_args = [] } in
let cons_e = Qualif qualif in
let cons_ty = mk_arrow error.ty ty in
let cons = { e = cons_e; ty = cons_ty } in
@@ -483,11 +501,11 @@ let mk_result_fail_texpression_with_error_id (error : VariantId.id) (ty : ty) :
let mk_result_return_texpression (v : texpression) : texpression =
let type_args = [ v.ty ] in
- let ty = Adt (Assumed Result, type_args) in
+ let ty = Adt (Assumed Result, type_args, []) in
let id =
AdtCons { adt_id = Assumed Result; variant_id = Some result_return_id }
in
- let qualif = { id; type_args } in
+ let qualif = { id; type_args; const_generic_args = [] } in
let cons_e = Qualif qualif in
let cons_ty = mk_arrow v.ty ty in
let cons = { e = cons_e; ty = cons_ty } in
@@ -496,7 +514,7 @@ let mk_result_return_texpression (v : texpression) : texpression =
(** Create a [Fail err] pattern which captures the error *)
let mk_result_fail_pattern (error_pat : pattern) (ty : ty) : typed_pattern =
let error_pat : typed_pattern = { value = error_pat; ty = mk_error_ty } in
- let ty = Adt (Assumed Result, [ ty ]) in
+ let ty = Adt (Assumed Result, [ ty ], []) in
let value =
PatAdt { variant_id = Some result_fail_id; field_values = [ error_pat ] }
in
@@ -508,7 +526,7 @@ let mk_result_fail_pattern_ignore_error (ty : ty) : typed_pattern =
mk_result_fail_pattern error_pat ty
let mk_result_return_pattern (v : typed_pattern) : typed_pattern =
- let ty = Adt (Assumed Result, [ v.ty ]) in
+ let ty = Adt (Assumed Result, [ v.ty ], []) in
let value =
PatAdt { variant_id = Some result_return_id; field_values = [ v ] }
in
@@ -543,11 +561,11 @@ let rec typed_pattern_to_texpression (pat : typed_pattern) : texpression option
let fields_values = List.map (fun e -> Option.get e) fields in
(* Retrieve the type id and the type args from the pat type (simpler this way *)
- let adt_id, type_args = ty_as_adt pat.ty in
+ let adt_id, type_args, const_generic_args = ty_as_adt pat.ty in
(* Create the constructor *)
let qualif_id = AdtCons { adt_id; variant_id = av.variant_id } in
- let qualif = { id = qualif_id; type_args } in
+ let qualif = { id = qualif_id; type_args; const_generic_args } in
let cons_e = Qualif qualif in
let field_tys =
List.map (fun (v : texpression) -> v.ty) fields_values
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 5dc8664a..ba2a6525 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -240,6 +240,8 @@ let bs_ctx_to_ctx_formatter (ctx : bs_ctx) : Print.Contexts.ctx_formatter =
r_to_string;
type_var_id_to_string;
type_decl_id_to_string = ast_fmt.type_decl_id_to_string;
+ const_generic_var_id_to_string = ast_fmt.const_generic_var_id_to_string;
+ global_decl_id_to_string = ast_fmt.global_decl_id_to_string;
adt_variant_to_string = ast_fmt.adt_variant_to_string;
var_id_to_string;
adt_field_names = ast_fmt.adt_field_names;
@@ -247,10 +249,12 @@ let bs_ctx_to_ctx_formatter (ctx : bs_ctx) : Print.Contexts.ctx_formatter =
let bs_ctx_to_pp_ast_formatter (ctx : bs_ctx) : PrintPure.ast_formatter =
let type_params = ctx.fun_decl.signature.type_params in
+ let cg_params = ctx.fun_decl.signature.const_generic_params in
let type_decls = ctx.type_context.llbc_type_decls in
let fun_decls = ctx.fun_context.llbc_fun_decls in
let global_decls = ctx.global_context.llbc_global_decls in
PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params
+ cg_params
let symbolic_value_to_string (ctx : bs_ctx) (sv : V.symbolic_value) : string =
let fmt = bs_ctx_to_ctx_formatter ctx in
@@ -273,8 +277,12 @@ let rty_to_string (ctx : bs_ctx) (ty : T.rty) : string =
let type_decl_to_string (ctx : bs_ctx) (def : type_decl) : string =
let type_params = def.type_params in
+ let cg_params = def.const_generic_params in
let type_decls = ctx.type_context.llbc_type_decls in
- let fmt = PrintPure.mk_type_formatter type_decls type_params in
+ let global_decls = ctx.global_context.llbc_global_decls in
+ let fmt =
+ PrintPure.mk_type_formatter type_decls global_decls type_params cg_params
+ in
PrintPure.type_decl_to_string fmt def
let texpression_to_string (ctx : bs_ctx) (e : texpression) : string =
@@ -283,21 +291,25 @@ let texpression_to_string (ctx : bs_ctx) (e : texpression) : string =
let fun_sig_to_string (ctx : bs_ctx) (sg : fun_sig) : string =
let type_params = sg.type_params in
+ let cg_params = sg.const_generic_params in
let type_decls = ctx.type_context.llbc_type_decls in
let fun_decls = ctx.fun_context.llbc_fun_decls in
let global_decls = ctx.global_context.llbc_global_decls in
let fmt =
PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params
+ cg_params
in
PrintPure.fun_sig_to_string fmt sg
let fun_decl_to_string (ctx : bs_ctx) (def : Pure.fun_decl) : string =
let type_params = def.signature.type_params in
+ let cg_params = def.signature.const_generic_params in
let type_decls = ctx.type_context.llbc_type_decls in
let fun_decls = ctx.fun_context.llbc_fun_decls in
let global_decls = ctx.global_context.llbc_global_decls in
let fmt =
PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params
+ cg_params
in
PrintPure.fun_decl_to_string fmt def
@@ -315,16 +327,17 @@ let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string =
Print.Values.abs_to_string fmt verbose indent indent_incr abs
let get_instantiated_fun_sig (fun_id : A.fun_id)
- (back_id : T.RegionGroupId.id option) (tys : ty list) (ctx : bs_ctx) :
- inst_fun_sig =
+ (back_id : T.RegionGroupId.id option) (tys : ty list)
+ (cgs : const_generic list) (ctx : bs_ctx) : inst_fun_sig =
(* Lookup the non-instantiated function signature *)
let sg =
(RegularFunIdNotLoopMap.find (fun_id, back_id) ctx.fun_context.fun_sigs).sg
in
(* Create the substitution *)
let tsubst = make_type_subst sg.type_params tys in
+ let cgsubst = make_const_generic_subst sg.const_generic_params cgs in
(* Apply *)
- fun_sig_substitute tsubst sg
+ fun_sig_substitute tsubst cgsubst sg
let bs_ctx_lookup_llbc_type_decl (id : TypeDeclId.id) (ctx : bs_ctx) :
T.type_decl =
@@ -380,17 +393,17 @@ let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
let rec translate_sty (ty : T.sty) : ty =
let translate = translate_sty in
match ty with
- | T.Adt (type_id, regions, tys) -> (
+ | T.Adt (type_id, regions, tys, cgs) -> (
(* Can't translate types with regions for now *)
assert (regions = []);
let tys = List.map translate tys in
match type_id with
- | T.AdtId adt_id -> Adt (AdtId adt_id, tys)
+ | T.AdtId adt_id -> Adt (AdtId adt_id, tys, cgs)
| T.Tuple -> mk_simpl_tuple_ty tys
| T.Assumed aty -> (
match aty with
- | T.Vec -> Adt (Assumed Vec, tys)
- | T.Option -> Adt (Assumed Option, tys)
+ | T.Vec -> Adt (Assumed Vec, tys, cgs)
+ | T.Option -> Adt (Assumed Option, tys, cgs)
| T.Box -> (
(* Eliminate the boxes *)
match tys with
@@ -399,15 +412,13 @@ let rec translate_sty (ty : T.sty) : ty =
raise
(Failure
"Box/vec/option type with incorrect number of arguments")
- )))
+ )
+ | T.Array -> Adt (Assumed Array, tys, cgs)
+ | T.Slice -> Adt (Assumed Slice, tys, cgs)
+ | T.Str -> Adt (Assumed Str, tys, cgs)))
| TypeVar vid -> TypeVar vid
- | Bool -> Bool
- | Char -> Char
+ | Literal ty -> Literal ty
| Never -> raise (Failure "Unreachable")
- | Integer int_ty -> Integer int_ty
- | Str -> Str
- | Array ty -> Array (translate ty)
- | Slice ty -> Slice (translate ty)
| Ref (_, rty, _) -> translate rty
let translate_field (f : T.field) : field =
@@ -445,8 +456,9 @@ let translate_type_decl (def : T.type_decl) : type_decl =
(* Can't translate types with regions for now *)
assert (def.region_params = []);
let type_params = def.type_params in
+ let const_generic_params = def.const_generic_params in
let kind = translate_type_decl_kind def.T.kind in
- { def_id; name; type_params; kind }
+ { def_id; name; type_params; const_generic_params; kind }
(** Translate a type, seen as an input/output of a forward function
(preserve all borrows, etc.)
@@ -455,7 +467,7 @@ let translate_type_decl (def : T.type_decl) : type_decl =
let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty =
let translate = translate_fwd_ty type_infos in
match ty with
- | T.Adt (type_id, regions, tys) -> (
+ | T.Adt (type_id, regions, tys, cgs) -> (
(* Can't translate types with regions for now *)
assert (regions = []);
(* Translate the type parameters *)
diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml
index a6e11363..e2cdc726 100644
--- a/compiler/SynthesizeSymbolic.ml
+++ b/compiler/SynthesizeSymbolic.ml
@@ -36,8 +36,8 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value)
(* Boolean expansion: there should be two branches *)
match ls with
| [
- (Some (V.SePrimitive (PV.Bool true)), true_exp);
- (Some (V.SePrimitive (PV.Bool false)), false_exp);
+ (Some (V.SeLiteral (PV.Bool true)), true_exp);
+ (Some (V.SeLiteral (PV.Bool false)), false_exp);
] ->
ExpandBool (true_exp, false_exp)
| _ -> raise (Failure "Ill-formed boolean expansion"))
@@ -50,7 +50,7 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value)
let get_scalar (see : V.symbolic_expansion option) : V.scalar_value
=
match see with
- | Some (V.SePrimitive (PV.Scalar cv)) ->
+ | Some (V.SeLiteral (PV.Scalar cv)) ->
assert (cv.PV.int_ty = int_ty);
cv
| _ -> raise (Failure "Unreachable")
diff --git a/compiler/TranslateCore.ml b/compiler/TranslateCore.ml
index 9ba73c7e..ba5e237b 100644
--- a/compiler/TranslateCore.ml
+++ b/compiler/TranslateCore.ml
@@ -32,33 +32,39 @@ type pure_fun_translation = fun_and_loops * fun_and_loops list
let type_decl_to_string (ctx : trans_ctx) (def : Pure.type_decl) : string =
let type_params = def.type_params in
+ let cg_params = def.const_generic_params in
let type_decls = ctx.type_context.type_decls in
- let fmt = PrintPure.mk_type_formatter type_decls type_params in
+ let global_decls = ctx.global_context.global_decls in
+ let fmt =
+ PrintPure.mk_type_formatter type_decls global_decls type_params cg_params
+ in
PrintPure.type_decl_to_string fmt def
-let type_id_to_string (ctx : trans_ctx) (def : Pure.type_decl) : string =
- let type_params = def.type_params in
- let type_decls = ctx.type_context.type_decls in
- let fmt = PrintPure.mk_type_formatter type_decls type_params in
- PrintPure.type_decl_to_string fmt def
+let type_id_to_string (ctx : trans_ctx) (id : Pure.TypeDeclId.id) : string =
+ Print.fun_name_to_string
+ (Pure.TypeDeclId.Map.find id ctx.type_context.type_decls).name
let fun_sig_to_string (ctx : trans_ctx) (sg : Pure.fun_sig) : string =
let type_params = sg.type_params in
+ let cg_params = sg.const_generic_params in
let type_decls = ctx.type_context.type_decls in
let fun_decls = ctx.fun_context.fun_decls in
let global_decls = ctx.global_context.global_decls in
let fmt =
PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params
+ cg_params
in
PrintPure.fun_sig_to_string fmt sg
let fun_decl_to_string (ctx : trans_ctx) (def : Pure.fun_decl) : string =
let type_params = def.signature.type_params in
+ let cg_params = def.signature.const_generic_params in
let type_decls = ctx.type_context.type_decls in
let fun_decls = ctx.fun_context.fun_decls in
let global_decls = ctx.global_context.global_decls in
let fmt =
PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params
+ cg_params
in
PrintPure.fun_decl_to_string fmt def
diff --git a/compiler/Values.ml b/compiler/Values.ml
index 3d6bc9c1..f70b9b4b 100644
--- a/compiler/Values.ml
+++ b/compiler/Values.ml
@@ -147,7 +147,7 @@ class ['self] map_typed_value_base =
(** An untyped value, used in the environments *)
type value =
- | Primitive of literal (** Non-symbolic primitive value *)
+ | Literal of literal (** Non-symbolic primitive value *)
| Adt of adt_value (** Enumerations and structures *)
| Bottom (** No value (uninitialized or moved value) *)
| Borrow of borrow_content (** A borrowed value *)
@@ -1014,7 +1014,7 @@ type abs = {
TODO: this should rather be name "expanded_symbolic"
*)
type symbolic_expansion =
- | SePrimitive of literal
+ | SeLiteral of literal
| SeAdt of (VariantId.id option * symbolic_value list)
| SeMutRef of BorrowId.id * symbolic_value
| SeSharedRef of BorrowId.Set.t * symbolic_value