diff options
Diffstat (limited to '')
25 files changed, 407 insertions, 273 deletions
diff --git a/compiler/Interpreter.ml b/compiler/Interpreter.ml index ccb9009e..dc2bb700 100644 --- a/compiler/Interpreter.ml +++ b/compiler/Interpreter.ml @@ -29,8 +29,8 @@ let compute_type_fun_global_contexts (m : A.crate) : let initialize_eval_context (type_context : C.type_context) (fun_context : C.fun_context) (global_context : C.global_context) - (region_groups : T.RegionGroupId.id list) (type_vars : T.type_var list) : - C.eval_ctx = + (region_groups : T.RegionGroupId.id list) (type_vars : T.type_var list) + (const_generic_vars : T.const_generic_var list) : C.eval_ctx = C.reset_global_counters (); { C.type_context; @@ -38,6 +38,7 @@ let initialize_eval_context (type_context : C.type_context) C.global_context; C.region_groups; C.type_vars; + C.const_generic_vars; C.env = [ C.Frame ]; C.ended_regions = T.RegionId.Set.empty; } @@ -76,11 +77,18 @@ let initialize_symbolic_context_for_fun (type_context : C.type_context) in let ctx = initialize_eval_context type_context fun_context global_context - region_groups sg.type_params + region_groups sg.type_params sg.const_generic_params in (* Instantiate the signature *) - let type_params = List.map (fun tv -> T.TypeVar tv.T.index) sg.type_params in - let inst_sg = instantiate_fun_sig type_params sg in + let type_params = + List.map (fun (v : T.type_var) -> T.TypeVar v.T.index) sg.type_params + in + let cg_params = + List.map + (fun (v : T.const_generic_var) -> T.ConstGenericVar v.T.index) + sg.const_generic_params + in + let inst_sg = instantiate_fun_sig type_params cg_params sg in (* Create fresh symbolic values for the inputs *) let input_svs = List.map (fun ty -> mk_fresh_symbolic_value V.SynthInput ty) inst_sg.inputs @@ -155,8 +163,15 @@ let evaluate_function_symbolic_synthesize_backward_from_return * an instantiation of the signature, so that we use fresh * region ids for the return abstractions. *) let sg = fdef.signature in - let type_params = List.map (fun tv -> T.TypeVar tv.T.index) sg.type_params in - let ret_inst_sg = instantiate_fun_sig type_params sg in + let type_params = + List.map (fun (v : T.type_var) -> T.TypeVar v.T.index) sg.type_params + in + let cg_params = + List.map + (fun (v : T.const_generic_var) -> T.ConstGenericVar v.T.index) + sg.const_generic_params + in + let ret_inst_sg = instantiate_fun_sig type_params cg_params sg in let ret_rty = ret_inst_sg.output in (* Move the return value out of the return variable *) let pop_return_value = is_regular_return in @@ -490,7 +505,7 @@ module Test = struct compute_type_fun_global_contexts crate in let ctx = - initialize_eval_context type_context fun_context global_context [] [] + initialize_eval_context type_context fun_context global_context [] [] [] in (* Insert the (uninitialized) local variables *) @@ -518,13 +533,11 @@ module Test = struct (** Small helper: return true if the function is a *transparent* unit function (no parameters, no arguments) - TODO: move *) let fun_decl_is_transparent_unit (def : A.fun_decl) : bool = - match def.body with - | None -> false - | Some body -> - body.arg_count = 0 - && List.length def.A.signature.region_params = 0 - && List.length def.A.signature.type_params = 0 - && List.length def.A.signature.inputs = 0 + Option.is_some def.body + && def.A.signature.region_params = [] + && def.A.signature.type_params = [] + && def.A.signature.const_generic_params = [] + && def.A.signature.inputs = [] (** Test all the unit functions in a list of function definitions *) let test_unit_functions (crate : A.crate) : unit = diff --git a/compiler/InterpreterBorrows.ml b/compiler/InterpreterBorrows.ml index 38c6df3d..3d258b32 100644 --- a/compiler/InterpreterBorrows.ml +++ b/compiler/InterpreterBorrows.ml @@ -1733,7 +1733,7 @@ let destructure_abs (abs_kind : V.abs_kind) (can_end : bool) and list_values (v : V.typed_value) : V.typed_avalue list * V.typed_value = let ty = v.V.ty in match v.V.value with - | Primitive _ -> ([], v) + | Literal _ -> ([], v) | Adt adt -> let avll, field_values = List.split (List.map list_values adt.field_values) @@ -1841,7 +1841,7 @@ let convert_value_to_abstractions (abs_kind : V.abs_kind) (can_end : bool) let ty = v.V.ty in match v.V.value with - | V.Primitive _ -> ([], v) + | V.Literal _ -> ([], v) | V.Bottom -> (* Can happen: we *do* convert dummy values to abstractions, and dummy values can contain bottoms *) diff --git a/compiler/InterpreterBorrowsCore.ml b/compiler/InterpreterBorrowsCore.ml index 55365043..bf083aa4 100644 --- a/compiler/InterpreterBorrowsCore.ml +++ b/compiler/InterpreterBorrowsCore.ml @@ -87,24 +87,28 @@ let add_borrow_or_abs_id_to_chain (msg : string) (id : borrow_or_abs_id) (** Helper function. - This function allows to define in a generic way a comparison of region types. + This function allows to define in a generic way a comparison of **region types**. See [projections_interesect] for instance. [default]: default boolean to return, when comparing types with no regions [combine]: how to combine booleans [compare_regions]: how to compare regions + + TODO: is there a way of deriving such a comparison? *) let rec compare_rtys (default : bool) (combine : bool -> bool -> bool) (compare_regions : T.RegionId.id T.region -> T.RegionId.id T.region -> bool) (ty1 : T.rty) (ty2 : T.rty) : bool = let compare = compare_rtys default combine compare_regions in match (ty1, ty2) with - | T.Bool, T.Bool | T.Char, T.Char | T.Str, T.Str -> default - | T.Integer int_ty1, T.Integer int_ty2 -> - assert (int_ty1 = int_ty2); + | T.Literal lit1, T.Literal lit2 -> + assert (lit1 = lit2); default - | T.Adt (id1, regions1, tys1), T.Adt (id2, regions2, tys2) -> + | T.Adt (id1, regions1, tys1, cgs1), T.Adt (id2, regions2, tys2, cgs2) -> assert (id1 = id2); + (* There are no regions in the const generics, so we ignore them, + but we still check they are the same, for sanity *) + assert (cgs1 = cgs2); (* The check for the ADTs is very crude: we simply compare the arguments * two by two. @@ -134,7 +138,6 @@ let rec compare_rtys (default : bool) (combine : bool -> bool -> bool) in (* Combine *) combine params_b tys_b - | T.Array ty1, T.Array ty2 | T.Slice ty1, T.Slice ty2 -> compare ty1 ty2 | T.Ref (r1, ty1, kind1), T.Ref (r2, ty2, kind2) -> (* Sanity check *) assert (kind1 = kind2); diff --git a/compiler/InterpreterExpansion.ml b/compiler/InterpreterExpansion.ml index 64a90217..3b196571 100644 --- a/compiler/InterpreterExpansion.ml +++ b/compiler/InterpreterExpansion.ml @@ -216,7 +216,8 @@ let apply_symbolic_expansion_non_borrow (config : C.config) let compute_expanded_symbolic_non_assumed_adt_value (expand_enumerations : bool) (kind : V.sv_kind) (def_id : T.TypeDeclId.id) (regions : T.RegionId.id T.region list) (types : T.rty list) - (ctx : C.eval_ctx) : V.symbolic_expansion list = + (cgs : T.const_generic list) (ctx : C.eval_ctx) : V.symbolic_expansion list + = (* Lookup the definition and check if it is an enumeration with several * variants *) let def = C.ctx_lookup_type_decl ctx def_id in @@ -224,6 +225,7 @@ let compute_expanded_symbolic_non_assumed_adt_value (expand_enumerations : bool) (* Retrieve, for every variant, the list of its instantiated field types *) let variants_fields_types = Subst.type_decl_get_instantiated_variants_fields_rtypes def regions types + cgs in (* Check if there is strictly more than one variant *) if List.length variants_fields_types > 1 && not expand_enumerations then @@ -280,11 +282,12 @@ let compute_expanded_symbolic_box_value (kind : V.sv_kind) (boxed_ty : T.rty) : let compute_expanded_symbolic_adt_value (expand_enumerations : bool) (kind : V.sv_kind) (adt_id : T.type_id) (regions : T.RegionId.id T.region list) (types : T.rty list) - (ctx : C.eval_ctx) : V.symbolic_expansion list = + (cgs : T.const_generic list) (ctx : C.eval_ctx) : V.symbolic_expansion list + = match (adt_id, regions, types) with | T.AdtId def_id, _, _ -> compute_expanded_symbolic_non_assumed_adt_value expand_enumerations kind - def_id regions types ctx + def_id regions types cgs ctx | T.Tuple, [], _ -> [ compute_expanded_symbolic_tuple_value kind types ] | T.Assumed T.Option, [], [ ty ] -> compute_expanded_symbolic_option_value expand_enumerations kind ty @@ -513,10 +516,10 @@ let expand_symbolic_bool (config : C.config) (sv : V.symbolic_value) let original_sv = sv in let original_sv_place = sv_place in let rty = original_sv.V.sv_ty in - assert (rty = T.Bool); + assert (rty = T.Literal PV.Bool); (* Expand the symbolic value to true or false and continue execution *) - let see_true = V.SePrimitive (PV.Bool true) in - let see_false = V.SePrimitive (PV.Bool false) in + let see_true = V.SeLiteral (PV.Bool true) in + let see_false = V.SeLiteral (PV.Bool false) in let seel = [ (Some see_true, cf_true); (Some see_false, cf_false) ] in (* Apply the symbolic expansion (this also outputs the updated symbolic AST) *) apply_branching_symbolic_expansions_non_borrow config original_sv @@ -540,12 +543,12 @@ let expand_symbolic_value_no_branching (config : C.config) fun cf ctx -> match rty with (* ADTs *) - | T.Adt (adt_id, regions, types) -> + | T.Adt (adt_id, regions, types, cgs) -> (* Compute the expanded value *) let allow_branching = false in let seel = compute_expanded_symbolic_adt_value allow_branching sv.sv_kind adt_id - regions types ctx + regions types cgs ctx in (* There should be exacly one branch *) let see = Collections.List.to_cons_nil seel in @@ -597,12 +600,12 @@ let expand_symbolic_adt (config : C.config) (sv : V.symbolic_value) (* Execute *) match rty with (* ADTs *) - | T.Adt (adt_id, regions, types) -> + | T.Adt (adt_id, regions, types, cgs) -> let allow_branching = true in (* Compute the expanded value *) let seel = compute_expanded_symbolic_adt_value allow_branching sv.sv_kind adt_id - regions types ctx + regions types cgs ctx in (* Apply *) let seel = List.map (fun see -> (Some see, cf_branches)) seel in @@ -617,7 +620,7 @@ let expand_symbolic_int (config : C.config) (sv : V.symbolic_value) (tgts : (V.scalar_value * st_cm_fun) list) (otherwise : st_cm_fun) (cf_after_join : st_m_fun) : m_fun = (* Sanity check *) - assert (sv.V.sv_ty = T.Integer int_type); + assert (sv.V.sv_ty = T.Literal (PV.Integer int_type)); (* For all the branches of the switch, we expand the symbolic value * to the value given by the branch and execute the branch statement. * For the otherwise branch, we leave the symbolic value as it is @@ -628,7 +631,7 @@ let expand_symbolic_int (config : C.config) (sv : V.symbolic_value) * (optional expansion, statement to execute) *) let seel = - List.map (fun (v, cf) -> (Some (V.SePrimitive (PV.Scalar v)), cf)) tgts + List.map (fun (v, cf) -> (Some (V.SeLiteral (PV.Scalar v)), cf)) tgts in let seel = List.append seel [ (None, otherwise) ] in (* Then expand and evaluate - this generates the proper symbolic AST *) @@ -676,7 +679,7 @@ let greedy_expand_symbolics_with_borrows (config : C.config) : cm_fun = ^ symbolic_value_to_string ctx sv)); let cc : cm_fun = match sv.V.sv_ty with - | T.Adt (AdtId def_id, _, _) -> + | T.Adt (AdtId def_id, _, _, _) -> (* {!expand_symbolic_value_no_branching} checks if there are branchings, * but we prefer to also check it here - this leads to cleaner messages * and debugging *) @@ -701,16 +704,15 @@ let greedy_expand_symbolics_with_borrows (config : C.config) : cm_fun = [config]): " ^ Print.name_to_string def.name)) else expand_symbolic_value_no_branching config sv None - | T.Adt ((Tuple | Assumed Box), _, _) | T.Ref (_, _, _) -> + | T.Adt ((Tuple | Assumed Box), _, _, _) | T.Ref (_, _, _) -> (* Ok *) expand_symbolic_value_no_branching config sv None - | T.Adt (Assumed (Vec | Option), _, _) -> + | T.Adt (Assumed (Vec | Option | Array | Slice | Str), _, _, _) -> (* We can't expand those *) - raise (Failure "Attempted to greedily expand a Vec or an Option ") - | T.Array _ -> raise Utils.Unimplemented - | T.Slice _ -> raise (Failure "Can't expand symbolic slices") - | T.TypeVar _ | Bool | Char | Never | Integer _ | Str -> - raise (Failure "Unreachable") + raise + (Failure + "Attempted to greedily expand an ADT which can't be expanded ") + | T.TypeVar _ | T.Literal _ | Never -> raise (Failure "Unreachable") in (* Compose and continue *) comp cc expand cf ctx diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml index d75f5a26..bb159f05 100644 --- a/compiler/InterpreterExpressions.ml +++ b/compiler/InterpreterExpressions.ml @@ -94,24 +94,23 @@ let access_rplace_reorganize (config : C.config) (expand_prim_copy : bool) ctx (** Convert an operand constant operand value to a typed value *) -let primitive_to_typed_value (ty : T.ety) (cv : V.primitive_value) : +let literal_to_typed_value (ty : PV.literal_type) (cv : V.literal) : V.typed_value = (* Check the type while converting - we actually need some information * contained in the type *) log#ldebug (lazy - ("primitive_to_typed_value:" ^ "\n- cv: " - ^ Print.PrimitiveValues.primitive_value_to_string cv)); + ("literal_to_typed_value:" ^ "\n- cv: " + ^ Print.PrimitiveValues.literal_to_string cv)); match (ty, cv) with (* Scalar, boolean... *) - | T.Bool, Bool v -> { V.value = V.Primitive (Bool v); ty } - | T.Char, Char v -> { V.value = V.Primitive (Char v); ty } - | T.Str, String v -> { V.value = V.Primitive (String v); ty } - | T.Integer int_ty, PV.Scalar v -> + | PV.Bool, Bool v -> { V.value = V.Literal (Bool v); ty = T.Literal ty } + | Char, Char v -> { V.value = V.Literal (Char v); ty = T.Literal ty } + | Integer int_ty, PV.Scalar v -> (* Check the type and the ranges *) assert (int_ty = v.int_ty); assert (check_scalar_value_in_range v); - { V.value = V.Primitive (PV.Scalar v); ty } + { V.value = V.Literal (PV.Scalar v); ty = T.Literal ty } (* Remaining cases (invalid) *) | _, _ -> raise (Failure "Improperly typed constant value") @@ -138,14 +137,16 @@ let rec copy_value (allow_adt_copy : bool) (config : C.config) * the fact that we have exhaustive matches below makes very obvious the cases * in which we need to fail *) match v.V.value with - | V.Primitive _ -> (ctx, v) + | V.Literal _ -> (ctx, v) | V.Adt av -> (* Sanity check *) (match v.V.ty with - | T.Adt (T.Assumed (T.Box | Vec), _, _) -> + | T.Adt (T.Assumed (T.Box | Vec), _, _, _) -> raise (Failure "Can't copy an assumed value other than Option") - | T.Adt (T.AdtId _, _, _) -> assert allow_adt_copy - | T.Adt ((T.Assumed Option | T.Tuple), _, _) -> () (* Ok *) + | T.Adt (T.AdtId _, _, _, _) -> assert allow_adt_copy + | T.Adt ((T.Assumed Option | T.Tuple), _, _, _) -> () (* Ok *) + | T.Adt (T.Assumed (Slice | T.Array), [], [ ty ], []) -> + assert (ty_is_primitively_copyable ty) | _ -> raise (Failure "Unreachable")); let ctx, fields = List.fold_left_map @@ -231,7 +232,7 @@ let prepare_eval_operand_reorganize (config : C.config) (op : E.operand) : match op with | Expressions.Constant (ty, cv) -> (* No need to reorganize the context *) - primitive_to_typed_value ty cv |> ignore; + literal_to_typed_value ty cv |> ignore; cf ctx | Expressions.Copy p -> (* Access the value *) @@ -259,7 +260,7 @@ let eval_operand_no_reorganize (config : C.config) (op : E.operand) ^ "\n- ctx:\n" ^ eval_ctx_to_string ctx ^ "\n")); (* Evaluate *) match op with - | Expressions.Constant (ty, cv) -> cf (primitive_to_typed_value ty cv) ctx + | Expressions.Constant (ty, cv) -> cf (literal_to_typed_value ty cv) ctx | Expressions.Copy p -> (* Access the value *) let access = Read in diff --git a/compiler/InterpreterLoopsCore.ml b/compiler/InterpreterLoopsCore.ml index 209fce1c..6e33c75b 100644 --- a/compiler/InterpreterLoopsCore.ml +++ b/compiler/InterpreterLoopsCore.ml @@ -60,8 +60,7 @@ module type PrimMatcher = sig val match_rtys : T.rty -> T.rty -> T.rty (** The input primitive values are not equal *) - val match_distinct_primitive_values : - T.ety -> V.primitive_value -> V.primitive_value -> V.typed_value + val match_distinct_literals : T.ety -> V.literal -> V.literal -> V.typed_value (** The input ADTs don't have the same variant *) val match_distinct_adts : T.ety -> V.adt_value -> V.adt_value -> V.typed_value diff --git a/compiler/InterpreterLoopsFixedPoint.ml b/compiler/InterpreterLoopsFixedPoint.ml index aff8f3fe..a9ec9ecf 100644 --- a/compiler/InterpreterLoopsFixedPoint.ml +++ b/compiler/InterpreterLoopsFixedPoint.ml @@ -109,6 +109,7 @@ let prepare_ashared_loans (loop_id : V.LoopId.id option) : cm_fun = (fun r -> if T.RegionId.Set.mem r rids then nrid else r) (fun x -> x) (fun x -> x) + (fun x -> x) (fun id -> let nid = C.fresh_symbolic_value_id () in let sv = V.SymbolicValueId.Map.find id absl_id_maps.sids_to_values in diff --git a/compiler/InterpreterLoopsJoinCtxs.ml b/compiler/InterpreterLoopsJoinCtxs.ml index 6fb0449d..bf88e055 100644 --- a/compiler/InterpreterLoopsJoinCtxs.ml +++ b/compiler/InterpreterLoopsJoinCtxs.ml @@ -556,6 +556,7 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx) global_context; region_groups; type_vars; + const_generic_vars; env = _; ended_regions = ended_regions0; } = @@ -567,6 +568,7 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx) global_context = _; region_groups = _; type_vars = _; + const_generic_vars = _; env = _; ended_regions = ended_regions1; } = @@ -580,6 +582,7 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx) global_context; region_groups; type_vars; + const_generic_vars; env; ended_regions; } @@ -635,6 +638,7 @@ let refresh_abs (old_abs : V.AbstractionId.Set.t) (ctx : C.eval_ctx) : (fun x -> x) (fun x -> x) (fun x -> x) + (fun x -> x) subst ctx.env in { ctx with C.env } diff --git a/compiler/InterpreterLoopsMatchCtxs.ml b/compiler/InterpreterLoopsMatchCtxs.ml index 80cd93cf..9248e513 100644 --- a/compiler/InterpreterLoopsMatchCtxs.ml +++ b/compiler/InterpreterLoopsMatchCtxs.ml @@ -44,11 +44,11 @@ let compute_abs_borrows_loans_maps (no_duplicates : bool) (id0 : Id0.id) (id1 : Id1.id) : unit = (* Sanity check *) (if check_singleton_sets || check_not_already_registered then - match Id0.Map.find_opt id0 !map with - | None -> () - | Some set -> - assert ( - (not check_not_already_registered) || not (Id1.Set.mem id1 set))); + match Id0.Map.find_opt id0 !map with + | None -> () + | Some set -> + assert ( + (not check_not_already_registered) || not (Id1.Set.mem id1 set))); (* Update the mapping *) map := Id0.Map.update id0 @@ -149,9 +149,11 @@ let rec match_types (match_distinct_types : 'r T.ty -> 'r T.ty -> 'r T.ty) (match_regions : 'r -> 'r -> 'r) (ty0 : 'r T.ty) (ty1 : 'r T.ty) : 'r T.ty = let match_rec = match_types match_distinct_types match_regions in match (ty0, ty1) with - | Adt (id0, regions0, tys0), Adt (id1, regions1, tys1) -> + | Adt (id0, regions0, tys0, cgs0), Adt (id1, regions1, tys1, cgs1) -> assert (id0 = id1); + assert (cgs0 = cgs1); let id = id0 in + let cgs = cgs1 in let regions = List.map (fun (id0, id1) -> match_regions id0 id1) @@ -160,16 +162,15 @@ let rec match_types (match_distinct_types : 'r T.ty -> 'r T.ty -> 'r T.ty) let tys = List.map (fun (ty0, ty1) -> match_rec ty0 ty1) (List.combine tys0 tys1) in - Adt (id, regions, tys) + Adt (id, regions, tys, cgs) | TypeVar vid0, TypeVar vid1 -> assert (vid0 = vid1); let vid = vid0 in TypeVar vid - | Bool, Bool | Char, Char | Never, Never | Str, Str -> ty0 - | Integer int_ty0, Integer int_ty1 -> - assert (int_ty0 = int_ty1); + | Literal lty0, Literal lty1 -> + assert (lty0 = lty1); ty0 - | Array ty0, Array ty1 | Slice ty0, Slice ty1 -> match_rec ty0 ty1 + | Never, Never -> ty0 | Ref (r0, ty0, k0), Ref (r1, ty1, k1) -> let r = match_regions r0 r1 in let ty = match_rec ty0 ty1 in @@ -184,8 +185,8 @@ module MakeMatcher (M : PrimMatcher) : Matcher = struct let match_rec = match_typed_values ctx in let ty = M.match_etys v0.V.ty v1.V.ty in match (v0.V.value, v1.V.value) with - | V.Primitive pv0, V.Primitive pv1 -> - if pv0 = pv1 then v1 else M.match_distinct_primitive_values ty pv0 pv1 + | V.Literal lv0, V.Literal lv1 -> + if lv0 = lv1 then v1 else M.match_distinct_literals ty lv0 lv1 | V.Adt av0, V.Adt av1 -> if av0.variant_id = av1.variant_id then let fields = List.combine av0.field_values av1.field_values in @@ -385,8 +386,8 @@ module MakeJoinMatcher (S : MatchJoinState) : PrimMatcher = struct assert (ty0 = ty1); ty0 - let match_distinct_primitive_values (ty : T.ety) (_ : V.primitive_value) - (_ : V.primitive_value) : V.typed_value = + let match_distinct_literals (ty : T.ety) (_ : V.literal) (_ : V.literal) : + V.typed_value = mk_fresh_symbolic_typed_value_from_ety V.LoopJoin ty let match_distinct_adts (ty : T.ety) (adt0 : V.adt_value) (adt1 : V.adt_value) @@ -834,8 +835,8 @@ struct in match_types match_distinct_types match_regions ty0 ty1 - let match_distinct_primitive_values (ty : T.ety) (_ : V.primitive_value) - (_ : V.primitive_value) : V.typed_value = + let match_distinct_literals (ty : T.ety) (_ : V.literal) (_ : V.literal) : + V.typed_value = mk_fresh_symbolic_typed_value_from_ety V.LoopJoin ty let match_distinct_adts (_ty : T.ety) (_adt0 : V.adt_value) @@ -1616,7 +1617,7 @@ let match_ctx_with_target (config : C.config) (loop_id : V.LoopId.id) cc (cf (if is_loop_entry then EndEnterLoop (loop_id, input_values) - else EndContinue (loop_id, input_values))) + else EndContinue (loop_id, input_values))) tgt_ctx in diff --git a/compiler/InterpreterPaths.ml b/compiler/InterpreterPaths.ml index 619815b3..4a439250 100644 --- a/compiler/InterpreterPaths.ml +++ b/compiler/InterpreterPaths.ml @@ -97,7 +97,7 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx) match (pe, v.V.value, v.V.ty) with | ( Field (((ProjAdt (_, _) | ProjOption _) as proj_kind), field_id), V.Adt adt, - T.Adt (type_id, _, _) ) -> ( + T.Adt (type_id, _, _, _) ) -> ( (* Check consistency *) (match (proj_kind, type_id) with | ProjAdt (def_id, opt_variant_id), T.AdtId def_id' -> @@ -119,7 +119,8 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx) let updated = { v with value = nadt } in Ok (ctx, { res with updated })) (* Tuples *) - | Field (ProjTuple arity, field_id), V.Adt adt, T.Adt (T.Tuple, _, _) -> ( + | Field (ProjTuple arity, field_id), V.Adt adt, T.Adt (T.Tuple, _, _, _) + -> ( assert (arity = List.length adt.field_values); let fv = T.FieldId.nth adt.field_values field_id in (* Project *) @@ -144,7 +145,7 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx) (* Box dereferencement *) | ( DerefBox, Adt { variant_id = None; field_values = [ bv ] }, - T.Adt (T.Assumed T.Box, _, _) ) -> ( + T.Adt (T.Assumed T.Box, _, _, _) ) -> ( (* We allow moving inside of boxes. In practice, this kind of * manipulations should happen only inside unsage code, so * it shouldn't happen due to user code, and we leverage it @@ -249,7 +250,7 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx) in Ok (ctx, { res with updated = nv }) else Error (FailSharedLoan bids)) - | (_, (V.Primitive _ | V.Adt _ | V.Bottom | V.Borrow _), _) as r -> + | (_, (V.Literal _ | V.Adt _ | V.Bottom | V.Borrow _), _) as r -> let pe, v, ty = r in let pe = "- pe: " ^ E.show_projection_elem pe in let v = "- v:\n" ^ V.show_value v in diff --git a/compiler/InterpreterPaths.mli b/compiler/InterpreterPaths.mli index 6e9286cd..4a9f3b41 100644 --- a/compiler/InterpreterPaths.mli +++ b/compiler/InterpreterPaths.mli @@ -61,6 +61,7 @@ val compute_expanded_bottom_adt_value : T.VariantId.id option -> T.erased_region list -> T.ety list -> + T.const_generic list -> V.typed_value (** Compute an expanded [Option] ⊥ value *) diff --git a/compiler/InterpreterProjectors.ml b/compiler/InterpreterProjectors.ml index 9487df84..faed066b 100644 --- a/compiler/InterpreterProjectors.ml +++ b/compiler/InterpreterProjectors.ml @@ -23,12 +23,12 @@ let rec apply_proj_borrows_on_shared_borrow (ctx : C.eval_ctx) if not (ty_has_regions_in_set regions ty) then [] else match (v.V.value, ty) with - | V.Primitive _, (T.Bool | T.Char | T.Integer _ | T.Str) -> [] - | V.Adt adt, T.Adt (id, region_params, tys) -> + | V.Literal _, T.Literal _ -> [] + | V.Adt adt, T.Adt (id, region_params, tys, cgs) -> (* Retrieve the types of the fields *) let field_types = Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id - region_params tys + region_params tys cgs in (* Project over the field values *) let fields_types = List.combine adt.V.field_values field_types in @@ -102,12 +102,12 @@ let rec apply_proj_borrows (check_symbolic_no_ended : bool) (ctx : C.eval_ctx) else let value : V.avalue = match (v.V.value, ty) with - | V.Primitive _, (T.Bool | T.Char | T.Integer _ | T.Str) -> V.AIgnored - | V.Adt adt, T.Adt (id, region_params, tys) -> + | V.Literal _, T.Literal _ -> V.AIgnored + | V.Adt adt, T.Adt (id, region_params, tys, cgs) -> (* Retrieve the types of the fields *) let field_types = Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id - region_params tys + region_params tys cgs in (* Project over the field values *) let fields_types = List.combine adt.V.field_values field_types in @@ -231,7 +231,7 @@ let symbolic_expansion_non_borrow_to_value (sv : V.symbolic_value) let ty = Subst.erase_regions sv.V.sv_ty in let value = match see with - | SePrimitive cv -> V.Primitive cv + | SeLiteral cv -> V.Literal cv | SeAdt (variant_id, field_values) -> let field_values = List.map mk_typed_value_from_symbolic_value field_values @@ -267,9 +267,9 @@ let apply_proj_loans_on_symbolic_expansion (regions : T.RegionId.Set.t) (* Match *) let (value, ty) : V.avalue * T.rty = match (see, original_sv_ty) with - | SePrimitive _, (T.Bool | T.Char | T.Integer _ | T.Str) -> - (V.AIgnored, original_sv_ty) - | SeAdt (variant_id, field_values), T.Adt (_id, _region_params, _tys) -> + | SeLiteral _, T.Literal _ -> (V.AIgnored, original_sv_ty) + | SeAdt (variant_id, field_values), T.Adt (_id, _region_params, _tys, _cgs) + -> (* Project over the field values *) let field_values = List.map diff --git a/compiler/InterpreterProjectors.mli b/compiler/InterpreterProjectors.mli index 1afb9d53..bcc3dee2 100644 --- a/compiler/InterpreterProjectors.mli +++ b/compiler/InterpreterProjectors.mli @@ -55,7 +55,16 @@ val prepare_reborrows : bool -> (V.BorrowId.id -> V.BorrowId.id) * (C.eval_ctx -> C.eval_ctx) -(** Apply (and reduce) a projector over borrows to a value. +(** Apply (and reduce) a projector over borrows to an avalue. + We use this for instance to spread the borrows present in the inputs + of a function into the regions introduced for this function. For instance: + {[ + fn f<'a, 'b, T>(x: &'a T, y: &'b T) + ]} + If we call `f` with `x -> shared_borrow l0` and `y -> shared_borrow l1`, then + for the region introduced for `'a` we need to project the value for `x` to + a shared aborrow, and we need to ignore the borrow in `y`, because it belongs + to the other region. Parameters: - [check_symbolic_no_ended]: controls whether we check or not whether diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml index f5b1111e..d181ca4b 100644 --- a/compiler/InterpreterStatements.ml +++ b/compiler/InterpreterStatements.ml @@ -149,7 +149,7 @@ let eval_assertion_concrete (config : C.config) (assertion : A.assertion) : let eval_assert cf (v : V.typed_value) : m_fun = fun ctx -> match v.value with - | Primitive (Bool b) -> + | Literal (Bool b) -> (* Branch *) if b = assertion.expected then cf Unit ctx else cf Panic ctx | _ -> @@ -172,26 +172,26 @@ let eval_assertion (config : C.config) (assertion : A.assertion) : st_cm_fun = (* Evaluate the assertion *) let eval_assert cf (v : V.typed_value) : m_fun = fun ctx -> - assert (v.ty = T.Bool); + assert (v.ty = T.Literal PV.Bool); (* We make a choice here: we could completely decouple the concrete and * symbolic executions here but choose not to. In the case where we * know the concrete value of the boolean we test, we use this value * even if we are in symbolic mode. Note that this case should be * extremely rare... *) match v.value with - | Primitive (Bool _) -> + | Literal (Bool _) -> (* Delegate to the concrete evaluation function *) eval_assertion_concrete config assertion cf ctx | Symbolic sv -> assert (config.mode = C.SymbolicMode); - assert (sv.V.sv_ty = T.Bool); + assert (sv.V.sv_ty = T.Literal PV.Bool); (* We continue the execution as if the test had succeeded, and thus * perform the symbolic expansion: sv ~~> true. * We will of course synthesize an assertion in the generated code * (see below). *) let ctx = apply_symbolic_expansion_non_borrow config sv - (V.SePrimitive (PV.Bool true)) ctx + (V.SeLiteral (PV.Bool true)) ctx in (* Continue *) let expr = cf Unit ctx in @@ -232,7 +232,8 @@ let set_discriminant (config : C.config) (p : E.place) let update_value cf (v : V.typed_value) : m_fun = fun ctx -> match (v.V.ty, v.V.value) with - | ( T.Adt (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types), + | ( T.Adt + (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types, cgs), V.Adt av ) -> ( (* There are two situations: - either the discriminant is already the proper one (in which case we @@ -252,7 +253,7 @@ let set_discriminant (config : C.config) (p : E.place) | T.AdtId def_id -> compute_expanded_bottom_adt_value ctx.type_context.type_decls def_id (Some variant_id) - regions types + regions types cgs | T.Assumed T.Option -> assert (regions = []); compute_expanded_bottom_option_value variant_id @@ -260,13 +261,14 @@ let set_discriminant (config : C.config) (p : E.place) | _ -> raise (Failure "Unreachable") in assign_to_place config bottom_v p (cf Unit) ctx) - | ( T.Adt (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types), + | ( T.Adt + (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types, cgs), V.Bottom ) -> let bottom_v = match type_id with | T.AdtId def_id -> compute_expanded_bottom_adt_value ctx.type_context.type_decls - def_id (Some variant_id) regions types + def_id (Some variant_id) regions types cgs | T.Assumed T.Option -> assert (regions = []); compute_expanded_bottom_option_value variant_id @@ -285,7 +287,7 @@ let set_discriminant (config : C.config) (p : E.place) * or reset an already initialized value, really. *) raise (Failure "Unexpected value") | _, (V.Adt _ | V.Bottom) -> raise (Failure "Inconsistent state") - | _, (V.Primitive _ | V.Borrow _ | V.Loan _) -> + | _, (V.Literal _ | V.Borrow _ | V.Loan _) -> raise (Failure "Unexpected value") in (* Compose and apply *) @@ -302,20 +304,21 @@ let push_frame : cm_fun = fun cf ctx -> cf (ctx_push_frame ctx) instantiation of a non-local function. *) let get_non_local_function_return_type (fid : A.assumed_fun_id) - (region_params : T.erased_region list) (type_params : T.ety list) : T.ety = + (region_params : T.erased_region list) (type_params : T.ety list) + (const_generic_params : T.const_generic list) : T.ety = (* [Box::free] has a special treatment *) - match (fid, region_params, type_params) with - | A.BoxFree, [], [ _ ] -> mk_unit_ty + match (fid, region_params, type_params, const_generic_params) with + | A.BoxFree, [], [ _ ], [] -> mk_unit_ty | _ -> (* Retrieve the function's signature *) let sg = Assumed.get_assumed_sig fid in (* Instantiate the return type *) - let tsubst = - Subst.make_type_subst - (List.map (fun v -> v.T.index) sg.type_params) - type_params + let tsubst = Subst.make_type_subst_from_vars sg.type_params type_params in + let cgsubst = + Subst.make_const_generic_subst_from_vars sg.const_generic_params + const_generic_params in - Subst.erase_regions_substitute_types tsubst sg.output + Subst.erase_regions_substitute_types tsubst cgsubst sg.output let move_return_value (config : C.config) (pop_return_value : bool) (cf : V.typed_value option -> m_fun) : m_fun = @@ -443,7 +446,7 @@ let eval_box_new_concrete (config : C.config) (* Create the new box *) let cf_create cf (moved_input_value : V.typed_value) : m_fun = (* Create the box value *) - let box_ty = T.Adt (T.Assumed T.Box, [], [ boxed_ty ]) in + let box_ty = T.Adt (T.Assumed T.Box, [], [ boxed_ty ], []) in let box_v = V.Adt { variant_id = None; field_values = [ moved_input_value ] } in diff --git a/compiler/InterpreterStatements.mli b/compiler/InterpreterStatements.mli index f28bf2ea..814bc964 100644 --- a/compiler/InterpreterStatements.mli +++ b/compiler/InterpreterStatements.mli @@ -31,7 +31,8 @@ val pop_frame : C.config -> bool -> (V.typed_value option -> m_fun) -> m_fun Note: there are no region parameters, because they should be erased. *) -val instantiate_fun_sig : T.ety list -> LA.fun_sig -> LA.inst_fun_sig +val instantiate_fun_sig : + T.ety list -> T.const_generic list -> LA.fun_sig -> LA.inst_fun_sig (** Helper. diff --git a/compiler/Invariants.ml b/compiler/Invariants.ml index 981c2c46..a726eda0 100644 --- a/compiler/Invariants.ml +++ b/compiler/Invariants.ml @@ -377,10 +377,10 @@ let check_borrowed_values_invariant (ctx : C.eval_ctx) : unit = let info = { outer_borrow = false; outer_shared = false } in visitor#visit_eval_ctx info ctx -let check_primitive_value_type (cv : V.primitive_value) (ty : T.ety) : unit = +let check_literal_type (cv : V.literal) (ty : PV.literal_type) : unit = match (cv, ty) with - | PV.Scalar sv, T.Integer int_ty -> assert (sv.int_ty = int_ty) - | PV.Bool _, T.Bool | PV.Char _, T.Char | PV.String _, T.Str -> () + | PV.Scalar sv, PV.Integer int_ty -> assert (sv.int_ty = int_ty) + | PV.Bool _, PV.Bool | PV.Char _, PV.Char -> () | _ -> raise (Failure "Erroneous typing") let check_typing_invariant (ctx : C.eval_ctx) : unit = @@ -404,9 +404,9 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = method! visit_typed_value info tv = (* Check the current pair (value, type) *) (match (tv.V.value, tv.V.ty) with - | V.Primitive cv, ty -> check_primitive_value_type cv ty + | V.Literal cv, T.Literal ty -> check_literal_type cv ty (* ADT case *) - | V.Adt av, T.Adt (T.AdtId def_id, regions, tys) -> + | V.Adt av, T.Adt (T.AdtId def_id, regions, tys, cgs) -> (* Retrieve the definition to check the variant id, the number of * parameters, etc. *) let def = C.ctx_lookup_type_decl ctx def_id in @@ -422,7 +422,7 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (* Check that the field types are correct *) let field_types = Subst.type_decl_get_instantiated_field_etypes def av.V.variant_id - tys + tys cgs in let fields_with_types = List.combine av.V.field_values field_types @@ -431,8 +431,9 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (fun ((v, ty) : V.typed_value * T.ety) -> assert (v.V.ty = ty)) fields_with_types (* Tuple case *) - | V.Adt av, T.Adt (T.Tuple, regions, tys) -> + | V.Adt av, T.Adt (T.Tuple, regions, tys, cgs) -> assert (regions = []); + assert (cgs = []); assert (av.V.variant_id = None); (* Check that the fields have the proper values - and check that there * are as many fields as field types at the same time *) @@ -441,20 +442,22 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (fun ((v, ty) : V.typed_value * T.ety) -> assert (v.V.ty = ty)) fields_with_types (* Assumed type case *) - | V.Adt av, T.Adt (T.Assumed aty_id, regions, tys) -> ( + | V.Adt av, T.Adt (T.Assumed aty_id, regions, tys, cgs) -> ( assert (av.V.variant_id = None || aty_id = T.Option); - match (aty_id, av.V.field_values, regions, tys) with + match (aty_id, av.V.field_values, regions, tys, cgs) with (* Box *) - | T.Box, [ inner_value ], [], [ inner_ty ] - | T.Option, [ inner_value ], [], [ inner_ty ] -> + | T.Box, [ inner_value ], [], [ inner_ty ], [] + | T.Option, [ inner_value ], [], [ inner_ty ], [] -> assert (inner_value.V.ty = inner_ty) - | T.Option, _, [], [ _ ] -> + | T.Option, _, [], [ _ ], [] -> (* Option::None: nothing to check *) () - | T.Vec, fvs, [], [ vec_ty ] -> + | T.Vec, fvs, [], [ vec_ty ], [] -> List.iter (fun (v : V.typed_value) -> assert (v.ty = vec_ty)) fvs + | (T.Array | T.Slice | T.Str), _, _, _, _ -> + raise (Failure "Unexpected") | _ -> raise (Failure "Erroneous type")) | V.Bottom, _ -> (* Nothing to check *) () | V.Borrow bc, T.Ref (_, ref_ty, rkind) -> ( @@ -502,13 +505,14 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (* Check the current pair (value, type) *) (match (atv.V.value, atv.V.ty) with (* ADT case *) - | V.AAdt av, T.Adt (T.AdtId def_id, regions, tys) -> + | V.AAdt av, T.Adt (T.AdtId def_id, regions, tys, cgs) -> (* Retrieve the definition to check the variant id, the number of * parameters, etc. *) let def = C.ctx_lookup_type_decl ctx def_id in (* Check the number of parameters *) assert (List.length regions = List.length def.region_params); assert (List.length tys = List.length def.type_params); + assert (List.length cgs = List.length def.const_generic_params); (* Check that the variant id is consistent *) (match (av.V.variant_id, def.T.kind) with | Some variant_id, T.Enum variants -> @@ -518,7 +522,7 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (* Check that the field types are correct *) let field_types = Subst.type_decl_get_instantiated_field_rtypes def av.V.variant_id - regions tys + regions tys cgs in let fields_with_types = List.combine av.V.field_values field_types @@ -527,8 +531,9 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (fun ((v, ty) : V.typed_avalue * T.rty) -> assert (v.V.ty = ty)) fields_with_types (* Tuple case *) - | V.AAdt av, T.Adt (T.Tuple, regions, tys) -> + | V.AAdt av, T.Adt (T.Tuple, regions, tys, cgs) -> assert (regions = []); + assert (cgs = []); assert (av.V.variant_id = None); (* Check that the fields have the proper values - and check that there * are as many fields as field types at the same time *) @@ -537,11 +542,11 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (fun ((v, ty) : V.typed_avalue * T.rty) -> assert (v.V.ty = ty)) fields_with_types (* Assumed type case *) - | V.AAdt av, T.Adt (T.Assumed aty_id, regions, tys) -> ( + | V.AAdt av, T.Adt (T.Assumed aty_id, regions, tys, cgs) -> ( assert (av.V.variant_id = None); - match (aty_id, av.V.field_values, regions, tys) with + match (aty_id, av.V.field_values, regions, tys, cgs) with (* Box *) - | T.Box, [ boxed_value ], [], [ boxed_ty ] -> + | T.Box, [ boxed_value ], [], [ boxed_ty ], [] -> assert (boxed_value.V.ty = boxed_ty) | _ -> raise (Failure "Erroneous type")) | V.ABottom, _ -> (* Nothing to check *) () diff --git a/compiler/Print.ml b/compiler/Print.ml index 23cebd4c..410b45e6 100644 --- a/compiler/Print.ml +++ b/compiler/Print.ml @@ -80,7 +80,7 @@ module Values = struct string = let ty_fmt : PT.etype_formatter = value_to_etype_formatter fmt in match v.value with - | Primitive cv -> PPV.literal_to_string cv + | Literal cv -> PPV.literal_to_string cv | Adt av -> ( let field_values = List.map (typed_value_to_string fmt) av.field_values diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 03252200..6f857b4f 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -6,11 +6,15 @@ open PureUtils type type_formatter = { type_var_id_to_string : TypeVarId.id -> string; type_decl_id_to_string : TypeDeclId.id -> string; + const_generic_var_id_to_string : ConstGenericVarId.id -> string; + global_decl_id_to_string : GlobalDeclId.id -> string; } type value_formatter = { type_var_id_to_string : TypeVarId.id -> string; type_decl_id_to_string : TypeDeclId.id -> string; + const_generic_var_id_to_string : ConstGenericVarId.id -> string; + global_decl_id_to_string : GlobalDeclId.id -> string; adt_variant_to_string : TypeDeclId.id -> VariantId.id -> string; var_id_to_string : VarId.id -> string; adt_field_names : TypeDeclId.id -> VariantId.id option -> string list option; @@ -20,6 +24,8 @@ let value_to_type_formatter (fmt : value_formatter) : type_formatter = { type_var_id_to_string = fmt.type_var_id_to_string; type_decl_id_to_string = fmt.type_decl_id_to_string; + const_generic_var_id_to_string = fmt.const_generic_var_id_to_string; + global_decl_id_to_string = fmt.global_decl_id_to_string; } (* TODO: we need to store which variables we have encountered so far, and @@ -28,6 +34,7 @@ let value_to_type_formatter (fmt : value_formatter) : type_formatter = type ast_formatter = { type_var_id_to_string : TypeVarId.id -> string; type_decl_id_to_string : TypeDeclId.id -> string; + const_generic_var_id_to_string : ConstGenericVarId.id -> string; adt_variant_to_string : TypeDeclId.id -> VariantId.id -> string; var_id_to_string : VarId.id -> string; adt_field_to_string : @@ -41,6 +48,8 @@ let ast_to_value_formatter (fmt : ast_formatter) : value_formatter = { type_var_id_to_string = fmt.type_var_id_to_string; type_decl_id_to_string = fmt.type_decl_id_to_string; + const_generic_var_id_to_string = fmt.const_generic_var_id_to_string; + global_decl_id_to_string = fmt.global_decl_id_to_string; adt_variant_to_string = fmt.adt_variant_to_string; var_id_to_string = fmt.var_id_to_string; adt_field_names = fmt.adt_field_names; @@ -55,22 +64,38 @@ let fun_name_to_string = Print.fun_name_to_string let global_name_to_string = Print.global_name_to_string let option_to_string = Print.option_to_string let type_var_to_string = Print.Types.type_var_to_string +let const_generic_var_to_string = Print.Types.const_generic_var_to_string let integer_type_to_string = Print.PrimitiveValues.integer_type_to_string let literal_type_to_string = Print.PrimitiveValues.literal_type_to_string let scalar_value_to_string = Print.PrimitiveValues.scalar_value_to_string let literal_to_string = Print.PrimitiveValues.literal_to_string let mk_type_formatter (type_decls : T.type_decl TypeDeclId.Map.t) - (type_params : type_var list) : type_formatter = + (global_decls : A.global_decl GlobalDeclId.Map.t) + (type_params : type_var list) + (const_generic_params : const_generic_var list) : type_formatter = let type_var_id_to_string vid = let var = T.TypeVarId.nth type_params vid in type_var_to_string var in + let const_generic_var_id_to_string vid = + let var = T.ConstGenericVarId.nth const_generic_params vid in + const_generic_var_to_string var + in let type_decl_id_to_string def_id = let def = T.TypeDeclId.Map.find def_id type_decls in name_to_string def.name in - { type_var_id_to_string; type_decl_id_to_string } + let global_decl_id_to_string def_id = + let def = T.GlobalDeclId.Map.find def_id global_decls in + name_to_string def.name + in + { + type_var_id_to_string; + type_decl_id_to_string; + const_generic_var_id_to_string; + global_decl_id_to_string; + } (* TODO: there is a bit of duplication with Print.fun_decl_to_ast_formatter. @@ -81,11 +106,16 @@ let mk_type_formatter (type_decls : T.type_decl TypeDeclId.Map.t) let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t) (fun_decls : A.fun_decl FunDeclId.Map.t) (global_decls : A.global_decl GlobalDeclId.Map.t) - (type_params : type_var list) : ast_formatter = + (type_params : type_var list) + (const_generic_params : const_generic_var list) : ast_formatter = let type_var_id_to_string vid = let var = T.TypeVarId.nth type_params vid in type_var_to_string var in + let const_generic_var_id_to_string vid = + let var = T.ConstGenericVarId.nth const_generic_params vid in + const_generic_var_to_string var + in let type_decl_id_to_string def_id = let def = T.TypeDeclId.Map.find def_id type_decls in name_to_string def.name @@ -113,6 +143,7 @@ let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t) in { type_var_id_to_string; + const_generic_var_id_to_string; type_decl_id_to_string; adt_variant_to_string; var_id_to_string; @@ -122,36 +153,50 @@ let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t) global_decl_id_to_string; } +let assumed_ty_to_string (aty : assumed_ty) : string = + match aty with + | State -> "State" + | Result -> "Result" + | Error -> "Error" + | Fuel -> "Fuel" + | Option -> "Option" + | Vec -> "Vec" + | Array -> "Array" + | Slice -> "Slice" + | Str -> "Str" + let type_id_to_string (fmt : type_formatter) (id : type_id) : string = match id with | AdtId id -> fmt.type_decl_id_to_string id | Tuple -> "" - | Assumed aty -> ( - match aty with - | State -> "State" - | Result -> "Result" - | Error -> "Error" - | Fuel -> "Fuel" - | Option -> "Option" - | Vec -> "Vec") + | Assumed aty -> assumed_ty_to_string aty + +(* TODO: duplicates Charon.PrintTypes.const_generic_to_string *) +let const_generic_to_string (fmt : type_formatter) (cg : T.const_generic) : + string = + match cg with + | ConstGenericGlobal id -> fmt.global_decl_id_to_string id + | ConstGenericVar id -> fmt.const_generic_var_id_to_string id + | ConstGenericValue lit -> literal_to_string lit let rec ty_to_string (fmt : type_formatter) (inside : bool) (ty : ty) : string = match ty with - | Adt (id, tys) -> ( + | Adt (id, tys, cgs) -> ( let tys = List.map (ty_to_string fmt false) tys in + let cgs = List.map (const_generic_to_string fmt) cgs in + let params = List.append tys cgs in match id with - | Tuple -> "(" ^ String.concat " * " tys ^ ")" + | Tuple -> + assert (cgs = []); + "(" ^ String.concat " * " tys ^ ")" | AdtId _ | Assumed _ -> - let tys_s = if tys = [] then "" else " " ^ String.concat " " tys in - let ty_s = type_id_to_string fmt id ^ tys_s in - if tys <> [] && inside then "(" ^ ty_s ^ ")" else ty_s) + let params_s = + if params = [] then "" else " " ^ String.concat " " params + in + let ty_s = type_id_to_string fmt id ^ params_s in + if params <> [] && inside then "(" ^ ty_s ^ ")" else ty_s) | TypeVar tv -> fmt.type_var_id_to_string tv - | Bool -> "bool" - | Char -> "char" - | Integer int_ty -> integer_type_to_string int_ty - | Str -> "str" - | Array aty -> "[" ^ ty_to_string fmt false aty ^ "; ?]" - | Slice sty -> "[" ^ ty_to_string fmt false sty ^ "]" + | Literal lty -> literal_type_to_string lty | Arrow (arg_ty, ret_ty) -> let ty = ty_to_string fmt true arg_ty ^ " -> " ^ ty_to_string fmt false ret_ty @@ -248,8 +293,8 @@ let adt_variant_to_string (fmt : value_formatter) (adt_id : type_id) | Assumed aty -> ( (* Assumed type *) match aty with - | State -> - (* This type is opaque: we can't get there *) + | State | Vec | Array | Slice | Str -> + (* Those types are opaque: we can't get there *) raise (Failure "Unreachable") | Result -> let variant_id = Option.get variant_id in @@ -272,10 +317,7 @@ let adt_variant_to_string (fmt : value_formatter) (adt_id : type_id) if variant_id = option_some_id then "@Option::Some " else if variant_id = option_none_id then "@Option::None" else - raise (Failure "Unreachable: improper variant id for result type") - | Vec -> - assert (variant_id = None); - "Vec") + raise (Failure "Unreachable: improper variant id for result type")) let adt_field_to_string (fmt : value_formatter) (adt_id : type_id) (field_id : FieldId.id) : string = @@ -292,7 +334,7 @@ let adt_field_to_string (fmt : value_formatter) (adt_id : type_id) | Assumed aty -> ( (* Assumed type *) match aty with - | State | Fuel | Vec -> + | State | Fuel | Vec | Array | Slice | Str -> (* Opaque types: we can't get there *) raise (Failure "Unreachable") | Result | Error | Option -> @@ -300,17 +342,17 @@ let adt_field_to_string (fmt : value_formatter) (adt_id : type_id) raise (Failure "Unreachable")) (** TODO: we don't need a general function anymore (it is now only used for - patterns (i.e., patterns) + patterns) *) let adt_g_value_to_string (fmt : value_formatter) (value_to_string : 'v -> string) (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) : string = let field_values = List.map value_to_string field_values in match ty with - | Adt (Tuple, _) -> + | Adt (Tuple, _, _) -> (* Tuple *) "(" ^ String.concat ", " field_values ^ ")" - | Adt (AdtId def_id, _) -> + | Adt (AdtId def_id, _, _) -> (* "Regular" ADT *) let adt_ident = match variant_id with @@ -332,7 +374,7 @@ let adt_g_value_to_string (fmt : value_formatter) let field_values = String.concat " " field_values in adt_ident ^ " { " ^ field_values ^ " }" else adt_ident - | Adt (Assumed aty, _) -> ( + | Adt (Assumed aty, _, _) -> ( (* Assumed type *) match aty with | State -> @@ -377,12 +419,13 @@ let adt_g_value_to_string (fmt : value_formatter) "@Option::None") else raise (Failure "Unreachable: improper variant id for result type") - | Vec -> + | Vec | Array | Slice | Str -> assert (variant_id = None); let field_values = List.mapi (fun i v -> string_of_int i ^ " -> " ^ v) field_values in - "Vec [" ^ String.concat "; " field_values ^ "]") + let id = assumed_ty_to_string aty in + id ^ " [" ^ String.concat "; " field_values ^ "]") | _ -> let fmt = value_to_type_formatter fmt in raise diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 5af28efd..9b5d9236 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -32,7 +32,11 @@ IdGen () module VarId = IdGen () +module ConstGenericVarId = T.ConstGenericVarId + type integer_type = T.integer_type [@@deriving show, ord] +type const_generic_var = T.const_generic_var [@@deriving show, ord] +type const_generic = T.const_generic [@@deriving show, ord] (** The assumed types for the pure AST. @@ -50,7 +54,16 @@ type integer_type = T.integer_type [@@deriving show, ord] this state is opaque to Aeneas (the user can define it, or leave it as assumed) *) -type assumed_ty = State | Result | Error | Fuel | Vec | Option +type assumed_ty = + | State + | Result + | Error + | Fuel + | Vec + | Option + | Array + | Slice + | Str [@@deriving show, ord] (* TODO: we should never directly manipulate [Return] and [Fail], but rather @@ -114,26 +127,28 @@ type type_id = AdtId of type_decl_id | Tuple | Assumed of assumed_ty polymorphic = false; }] +type literal_type = T.literal_type [@@deriving show, ord] + (** Ancestor for iter visitor for [ty] *) class ['self] iter_ty_base = object (_self : 'self) inherit [_] iter_type_id + inherit! [_] T.iter_const_generic + inherit! [_] PV.iter_literal_type method visit_type_var_id : 'env -> type_var_id -> unit = fun _ _ -> () - method visit_integer_type : 'env -> integer_type -> unit = fun _ _ -> () end (** Ancestor for map visitor for [ty] *) class ['self] map_ty_base = object (_self : 'self) inherit [_] map_type_id + inherit! [_] T.map_const_generic + inherit! [_] PV.map_literal_type method visit_type_var_id : 'env -> type_var_id -> type_var_id = fun _ x -> x - - method visit_integer_type : 'env -> integer_type -> integer_type = - fun _ x -> x end type ty = - | Adt of type_id * ty list + | Adt of type_id * ty list * const_generic list (** {!Adt} encodes ADTs and tuples and assumed types. TODO: what about the ended regions? (ADTs may be parameterized @@ -142,12 +157,7 @@ type ty = such "partial" ADTs. *) | TypeVar of type_var_id - | Bool - | Char - | Integer of integer_type - | Str - | Array of ty (* TODO: this should be an assumed type?... *) - | Slice of ty (* TODO: this should be an assumed type?... *) + | Literal of literal_type | Arrow of ty * ty [@@deriving show, @@ -182,6 +192,7 @@ type type_decl = { def_id : TypeDeclId.id; name : name; type_params : type_var list; + const_generic_params : const_generic_var list; kind : type_decl_kind; } [@@deriving show] @@ -393,7 +404,12 @@ type qualif_id = which explains why we have the [type_params] field: a function or ADT constructor is always fully instantiated. *) -type qualif = { id : qualif_id; type_args : ty list } [@@deriving show] +type qualif = { + id : qualif_id; + type_args : ty list; + const_generic_args : const_generic list; +} +[@@deriving show] type field_id = FieldId.id [@@deriving show, ord] type var_id = VarId.id [@@deriving show, ord] @@ -716,6 +732,7 @@ type fun_sig_info = { *) type fun_sig = { type_params : type_var list; + const_generic_params : const_generic_var list; (** TODO: we should analyse the signature to make the type parameters implicit whenever possible *) inputs : ty list; (** The input types. diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index 72084dfc..ef8bac37 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -5,8 +5,8 @@ open PureUtils (** Utility function, used for type checking *) let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) - (type_id : type_id) (variant_id : VariantId.id option) (tys : ty list) : - ty list = + (type_id : type_id) (variant_id : VariantId.id option) (tys : ty list) + (cgs : const_generic list) : ty list = match type_id with | Tuple -> (* Tuple *) @@ -15,7 +15,7 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) | AdtId def_id -> (* "Regular" ADT *) let def = TypeDeclId.Map.find def_id type_decls in - type_decl_get_instantiated_fields_types def variant_id tys + type_decl_get_instantiated_fields_types def variant_id tys cgs | Assumed aty -> ( (* Assumed type *) match aty with @@ -47,7 +47,10 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) else if variant_id = option_none_id then [] else raise (Failure "Unreachable: improper variant id for result type") - | Vec -> raise (Failure "Unreachable: `Vector` values are opaque")) + | Vec | Array | Slice | Str -> + raise + (Failure + "Unreachable: trying to access the fields of an opaque type")) type tc_ctx = { type_decls : type_decl TypeDeclId.Map.t; (** The type declarations *) @@ -56,7 +59,7 @@ type tc_ctx = { env : ty VarId.Map.t; (** Environment from variables to types *) } -let check_literal (v : literal) (ty : ty) : unit = +let check_literal (v : literal) (ty : literal_type) : unit = match (ty, v) with | Integer int_ty, PV.Scalar sv -> assert (int_ty = sv.PV.int_ty) | Bool, Bool _ | Char, Char _ -> () @@ -66,7 +69,7 @@ let rec check_typed_pattern (ctx : tc_ctx) (v : typed_pattern) : tc_ctx = log#ldebug (lazy ("check_typed_pattern: " ^ show_typed_pattern v)); match v.value with | PatConstant cv -> - check_literal cv v.ty; + check_literal cv (ty_as_literal v.ty); ctx | PatDummy -> ctx | PatVar (var, _) -> @@ -75,13 +78,9 @@ let rec check_typed_pattern (ctx : tc_ctx) (v : typed_pattern) : tc_ctx = { ctx with env } | PatAdt av -> (* Compute the field types *) - let type_id, tys = - match v.ty with - | Adt (type_id, tys) -> (type_id, tys) - | _ -> raise (Failure "Inconsistently typed value") - in + let type_id, tys, cgs = ty_as_adt v.ty in let field_tys = - get_adt_field_types ctx.type_decls type_id av.variant_id tys + get_adt_field_types ctx.type_decls type_id av.variant_id tys cgs in let check_value (ctx : tc_ctx) (ty : ty) (v : typed_pattern) : tc_ctx = if ty <> v.ty then ( @@ -108,7 +107,7 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = match VarId.Map.find_opt var_id ctx.env with | None -> () | Some ty -> assert (ty = e.ty)) - | Const cv -> check_literal cv e.ty + | Const cv -> check_literal cv (ty_as_literal e.ty) | App (app, arg) -> let input_ty, output_ty = destruct_arrow app.ty in assert (input_ty = arg.ty); @@ -130,33 +129,31 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = (* Note we can only project fields of structures (not enumerations) *) (* Deconstruct the projector type *) let adt_ty, field_ty = destruct_arrow e.ty in - let adt_id, adt_type_args = - match adt_ty with - | Adt (type_id, tys) -> (type_id, tys) - | _ -> raise (Failure "Unreachable") - in + let adt_id, adt_type_args, adt_cg_args = ty_as_adt adt_ty in (* Check the ADT type *) assert (adt_id = proj_adt_id); assert (adt_type_args = qualif.type_args); + assert (adt_cg_args = qualif.const_generic_args); (* Retrieve and check the expected field type *) let variant_id = None in let expected_field_tys = get_adt_field_types ctx.type_decls proj_adt_id variant_id - qualif.type_args + qualif.type_args qualif.const_generic_args in let expected_field_ty = FieldId.nth expected_field_tys field_id in assert (expected_field_ty = field_ty) | AdtCons id -> ( let expected_field_tys = get_adt_field_types ctx.type_decls id.adt_id id.variant_id - qualif.type_args + qualif.type_args qualif.const_generic_args in let field_tys, adt_ty = destruct_arrows e.ty in assert (expected_field_tys = field_tys); match adt_ty with - | Adt (type_id, tys) -> + | Adt (type_id, tys, cgs) -> assert (type_id = id.adt_id); - assert (tys = qualif.type_args) + assert (tys = qualif.type_args); + assert (cgs = qualif.const_generic_args) | _ -> raise (Failure "Unreachable"))) | Let (monadic, pat, re, e_next) -> let expected_pat_ty = if monadic then destruct_result re.ty else re.ty in @@ -172,7 +169,7 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = check_texpression ctx scrut; match switch_body with | If (e_then, e_else) -> - assert (scrut.ty = Bool); + assert (scrut.ty = Literal Bool); assert (e_then.ty = e.ty); assert (e_else.ty = e.ty); check_texpression ctx e_then; @@ -202,15 +199,12 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = | Some ty -> assert (ty = e.ty)); (* Check the fields *) (* Retrieve and check the expected field type *) - let adt_id, adt_type_args = - match e.ty with - | Adt (type_id, tys) -> (type_id, tys) - | _ -> raise (Failure "Unreachable") - in + let adt_id, adt_type_args, adt_cg_args = ty_as_adt e.ty in assert (adt_id = AdtId supd.struct_id); let variant_id = None in let expected_field_tys = get_adt_field_types ctx.type_decls adt_id variant_id adt_type_args + adt_cg_args in List.iter (fun (fid, fe) -> diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 88b18e89..1c8d8921 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -62,7 +62,7 @@ let dest_arrow_ty (ty : ty) : ty * ty = | Arrow (arg_ty, ret_ty) -> (arg_ty, ret_ty) | _ -> raise (Failure "Unreachable") -let compute_literal_ty (cv : literal) : ty = +let compute_literal_type (cv : literal) : literal_type = match cv with | PV.Scalar sv -> Integer sv.PV.int_ty | Bool _ -> Bool @@ -71,7 +71,7 @@ let compute_literal_ty (cv : literal) : ty = let var_get_id (v : var) : VarId.id = v.id let mk_typed_pattern_from_literal (cv : literal) : typed_pattern = - let ty = compute_literal_ty cv in + let ty = Literal (compute_literal_type cv) in { value = PatConstant cv; ty } let mk_let (monadic : bool) (lv : typed_pattern) (re : texpression) @@ -90,11 +90,13 @@ let mk_mplace (var_id : E.VarId.id) (name : string option) { var_id; name; projection } (** Type substitution *) -let ty_substitute (tsubst : TypeVarId.id -> ty) (ty : ty) : ty = +let ty_substitute (tsubst : TypeVarId.id -> ty) + (cgsubst : ConstGenericVarId.id -> const_generic) (ty : ty) : ty = let obj = object inherit [_] map_ty method! visit_TypeVar _ var_id = tsubst var_id + method! visit_ConstGenericVar _ var_id = cgsubst var_id end in obj#visit_ty () ty @@ -109,6 +111,10 @@ let make_type_subst (vars : type_var list) (tys : ty list) : TypeVarId.id -> ty in fun id -> TypeVarId.Map.find id mp +let make_const_generic_subst (vars : const_generic_var list) + (cgs : const_generic list) : ConstGenericVarId.id -> const_generic = + Substitute.make_const_generic_subst_from_vars vars cgs + (** Retrieve the list of fields for the given variant of a {!type:Aeneas.Pure.type_decl}. Raises [Invalid_argument] if the arguments are incorrect. @@ -132,14 +138,17 @@ let type_decl_get_fields (def : type_decl) (** Instantiate the type variables for the chosen variant in an ADT definition, and return the list of the types of its fields *) let type_decl_get_instantiated_fields_types (def : type_decl) - (opt_variant_id : VariantId.id option) (types : ty list) : ty list = + (opt_variant_id : VariantId.id option) (types : ty list) + (cgs : const_generic list) : ty list = let ty_subst = make_type_subst def.type_params types in + let cg_subst = make_const_generic_subst def.const_generic_params cgs in let fields = type_decl_get_fields def opt_variant_id in - List.map (fun f -> ty_substitute ty_subst f.field_ty) fields + List.map (fun f -> ty_substitute ty_subst cg_subst f.field_ty) fields -let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) : +let fun_sig_substitute (tsubst : TypeVarId.id -> ty) + (cgsubst : ConstGenericVarId.id -> const_generic) (sg : fun_sig) : inst_fun_sig = - let subst = ty_substitute tsubst in + let subst = ty_substitute tsubst cgsubst in let inputs = List.map subst sg.inputs in let output = subst sg.output in let doutputs = List.map subst sg.doutputs in @@ -181,9 +190,9 @@ let is_global (e : texpression) : bool = let is_const (e : texpression) : bool = match e.e with Const _ -> true | _ -> false -let ty_as_adt (ty : ty) : type_id * ty list = +let ty_as_adt (ty : ty) : type_id * ty list * const_generic list = match ty with - | Adt (id, tys) -> (id, tys) + | Adt (id, tys, cgs) -> (id, tys, cgs) | _ -> raise (Failure "Unreachable") (** Remove the external occurrences of {!Meta} *) @@ -291,13 +300,19 @@ let opt_destruct_function_call (e : texpression) : let opt_destruct_result (ty : ty) : ty option = match ty with - | Adt (Assumed Result, tys) -> Some (Collections.List.to_cons_nil tys) + | Adt (Assumed Result, tys, cgs) -> + assert (cgs = []); + Some (Collections.List.to_cons_nil tys) | _ -> None let destruct_result (ty : ty) : ty = Option.get (opt_destruct_result ty) let opt_destruct_tuple (ty : ty) : ty list option = - match ty with Adt (Tuple, tys) -> Some tys | _ -> None + match ty with + | Adt (Tuple, tys, cgs) -> + assert (cgs = []); + Some tys + | _ -> None let mk_abs (x : typed_pattern) (e : texpression) : texpression = let ty = Arrow (x.ty, e.ty) in @@ -351,7 +366,7 @@ let iter_switch_body_branches (f : texpression -> unit) (sb : switch_body) : let mk_switch (scrut : texpression) (sb : switch_body) : texpression = (* Sanity check: the scrutinee has the proper type *) (match sb with - | If (_, _) -> assert (scrut.ty = Bool) + | If (_, _) -> assert (scrut.ty = Literal Bool) | Match branches -> List.iter (fun (b : match_branch) -> assert (b.pat.ty = scrut.ty)) @@ -368,14 +383,14 @@ let mk_switch (scrut : texpression) (sb : switch_body) : texpression = - if there is > one type: wrap them in a tuple *) let mk_simpl_tuple_ty (tys : ty list) : ty = - match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys) + match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys, []) -let mk_bool_ty : ty = Bool -let mk_unit_ty : ty = Adt (Tuple, []) +let mk_bool_ty : ty = Literal Bool +let mk_unit_ty : ty = Adt (Tuple, [], []) let mk_unit_rvalue : texpression = let id = AdtCons { adt_id = Tuple; variant_id = None } in - let qualif = { id; type_args = [] } in + let qualif = { id; type_args = []; const_generic_args = [] } in let e = Qualif qualif in let ty = mk_unit_ty in { e; ty } @@ -415,7 +430,7 @@ let mk_simpl_tuple_pattern (vl : typed_pattern list) : typed_pattern = | [ v ] -> v | _ -> let tys = List.map (fun (v : typed_pattern) -> v.ty) vl in - let ty = Adt (Tuple, tys) in + let ty = Adt (Tuple, tys, []) in let value = PatAdt { variant_id = None; field_values = vl } in { value; ty } @@ -426,11 +441,11 @@ let mk_simpl_tuple_texpression (vl : texpression list) : texpression = | _ -> (* Compute the types of the fields, and the type of the tuple constructor *) let tys = List.map (fun (v : texpression) -> v.ty) vl in - let ty = Adt (Tuple, tys) in + let ty = Adt (Tuple, tys, []) in let ty = mk_arrows tys ty in (* Construct the tuple constructor qualifier *) let id = AdtCons { adt_id = Tuple; variant_id = None } in - let qualif = { id; type_args = tys } in + let qualif = { id; type_args = tys; const_generic_args = [] } in (* Put everything together *) let cons = { e = Qualif qualif; ty } in mk_apps cons vl @@ -441,36 +456,39 @@ let mk_adt_pattern (adt_ty : ty) (variant_id : VariantId.id option) { value; ty = adt_ty } let ty_as_integer (t : ty) : T.integer_type = - match t with Integer int_ty -> int_ty | _ -> raise (Failure "Unreachable") + match t with + | Literal (Integer int_ty) -> int_ty + | _ -> raise (Failure "Unreachable") -(* TODO: move *) -let type_decl_is_enum (def : T.type_decl) : bool = - match def.kind with T.Struct _ -> false | Enum _ -> true | Opaque -> false +let ty_as_literal (t : ty) : T.literal_type = + match t with Literal ty -> ty | _ -> raise (Failure "Unreachable") -let mk_state_ty : ty = Adt (Assumed State, []) -let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ]) -let mk_error_ty : ty = Adt (Assumed Error, []) -let mk_fuel_ty : ty = Adt (Assumed Fuel, []) +let mk_state_ty : ty = Adt (Assumed State, [], []) +let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ], []) +let mk_error_ty : ty = Adt (Assumed Error, [], []) +let mk_fuel_ty : ty = Adt (Assumed Fuel, [], []) let mk_error (error : VariantId.id) : texpression = let ty = mk_error_ty in let id = AdtCons { adt_id = Assumed Error; variant_id = Some error } in - let qualif = { id; type_args = [] } in + let qualif = { id; type_args = []; const_generic_args = [] } in let e = Qualif qualif in { e; ty } let unwrap_result_ty (ty : ty) : ty = match ty with - | Adt (Assumed Result, [ ty ]) -> ty + | Adt (Assumed Result, [ ty ], cgs) -> + assert (cgs = []); + ty | _ -> raise (Failure "not a result type") let mk_result_fail_texpression (error : texpression) (ty : ty) : texpression = let type_args = [ ty ] in - let ty = Adt (Assumed Result, type_args) in + let ty = Adt (Assumed Result, type_args, []) in let id = AdtCons { adt_id = Assumed Result; variant_id = Some result_fail_id } in - let qualif = { id; type_args } in + let qualif = { id; type_args; const_generic_args = [] } in let cons_e = Qualif qualif in let cons_ty = mk_arrow error.ty ty in let cons = { e = cons_e; ty = cons_ty } in @@ -483,11 +501,11 @@ let mk_result_fail_texpression_with_error_id (error : VariantId.id) (ty : ty) : let mk_result_return_texpression (v : texpression) : texpression = let type_args = [ v.ty ] in - let ty = Adt (Assumed Result, type_args) in + let ty = Adt (Assumed Result, type_args, []) in let id = AdtCons { adt_id = Assumed Result; variant_id = Some result_return_id } in - let qualif = { id; type_args } in + let qualif = { id; type_args; const_generic_args = [] } in let cons_e = Qualif qualif in let cons_ty = mk_arrow v.ty ty in let cons = { e = cons_e; ty = cons_ty } in @@ -496,7 +514,7 @@ let mk_result_return_texpression (v : texpression) : texpression = (** Create a [Fail err] pattern which captures the error *) let mk_result_fail_pattern (error_pat : pattern) (ty : ty) : typed_pattern = let error_pat : typed_pattern = { value = error_pat; ty = mk_error_ty } in - let ty = Adt (Assumed Result, [ ty ]) in + let ty = Adt (Assumed Result, [ ty ], []) in let value = PatAdt { variant_id = Some result_fail_id; field_values = [ error_pat ] } in @@ -508,7 +526,7 @@ let mk_result_fail_pattern_ignore_error (ty : ty) : typed_pattern = mk_result_fail_pattern error_pat ty let mk_result_return_pattern (v : typed_pattern) : typed_pattern = - let ty = Adt (Assumed Result, [ v.ty ]) in + let ty = Adt (Assumed Result, [ v.ty ], []) in let value = PatAdt { variant_id = Some result_return_id; field_values = [ v ] } in @@ -543,11 +561,11 @@ let rec typed_pattern_to_texpression (pat : typed_pattern) : texpression option let fields_values = List.map (fun e -> Option.get e) fields in (* Retrieve the type id and the type args from the pat type (simpler this way *) - let adt_id, type_args = ty_as_adt pat.ty in + let adt_id, type_args, const_generic_args = ty_as_adt pat.ty in (* Create the constructor *) let qualif_id = AdtCons { adt_id; variant_id = av.variant_id } in - let qualif = { id = qualif_id; type_args } in + let qualif = { id = qualif_id; type_args; const_generic_args } in let cons_e = Qualif qualif in let field_tys = List.map (fun (v : texpression) -> v.ty) fields_values diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 5dc8664a..ba2a6525 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -240,6 +240,8 @@ let bs_ctx_to_ctx_formatter (ctx : bs_ctx) : Print.Contexts.ctx_formatter = r_to_string; type_var_id_to_string; type_decl_id_to_string = ast_fmt.type_decl_id_to_string; + const_generic_var_id_to_string = ast_fmt.const_generic_var_id_to_string; + global_decl_id_to_string = ast_fmt.global_decl_id_to_string; adt_variant_to_string = ast_fmt.adt_variant_to_string; var_id_to_string; adt_field_names = ast_fmt.adt_field_names; @@ -247,10 +249,12 @@ let bs_ctx_to_ctx_formatter (ctx : bs_ctx) : Print.Contexts.ctx_formatter = let bs_ctx_to_pp_ast_formatter (ctx : bs_ctx) : PrintPure.ast_formatter = let type_params = ctx.fun_decl.signature.type_params in + let cg_params = ctx.fun_decl.signature.const_generic_params in let type_decls = ctx.type_context.llbc_type_decls in let fun_decls = ctx.fun_context.llbc_fun_decls in let global_decls = ctx.global_context.llbc_global_decls in PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params + cg_params let symbolic_value_to_string (ctx : bs_ctx) (sv : V.symbolic_value) : string = let fmt = bs_ctx_to_ctx_formatter ctx in @@ -273,8 +277,12 @@ let rty_to_string (ctx : bs_ctx) (ty : T.rty) : string = let type_decl_to_string (ctx : bs_ctx) (def : type_decl) : string = let type_params = def.type_params in + let cg_params = def.const_generic_params in let type_decls = ctx.type_context.llbc_type_decls in - let fmt = PrintPure.mk_type_formatter type_decls type_params in + let global_decls = ctx.global_context.llbc_global_decls in + let fmt = + PrintPure.mk_type_formatter type_decls global_decls type_params cg_params + in PrintPure.type_decl_to_string fmt def let texpression_to_string (ctx : bs_ctx) (e : texpression) : string = @@ -283,21 +291,25 @@ let texpression_to_string (ctx : bs_ctx) (e : texpression) : string = let fun_sig_to_string (ctx : bs_ctx) (sg : fun_sig) : string = let type_params = sg.type_params in + let cg_params = sg.const_generic_params in let type_decls = ctx.type_context.llbc_type_decls in let fun_decls = ctx.fun_context.llbc_fun_decls in let global_decls = ctx.global_context.llbc_global_decls in let fmt = PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params + cg_params in PrintPure.fun_sig_to_string fmt sg let fun_decl_to_string (ctx : bs_ctx) (def : Pure.fun_decl) : string = let type_params = def.signature.type_params in + let cg_params = def.signature.const_generic_params in let type_decls = ctx.type_context.llbc_type_decls in let fun_decls = ctx.fun_context.llbc_fun_decls in let global_decls = ctx.global_context.llbc_global_decls in let fmt = PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params + cg_params in PrintPure.fun_decl_to_string fmt def @@ -315,16 +327,17 @@ let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string = Print.Values.abs_to_string fmt verbose indent indent_incr abs let get_instantiated_fun_sig (fun_id : A.fun_id) - (back_id : T.RegionGroupId.id option) (tys : ty list) (ctx : bs_ctx) : - inst_fun_sig = + (back_id : T.RegionGroupId.id option) (tys : ty list) + (cgs : const_generic list) (ctx : bs_ctx) : inst_fun_sig = (* Lookup the non-instantiated function signature *) let sg = (RegularFunIdNotLoopMap.find (fun_id, back_id) ctx.fun_context.fun_sigs).sg in (* Create the substitution *) let tsubst = make_type_subst sg.type_params tys in + let cgsubst = make_const_generic_subst sg.const_generic_params cgs in (* Apply *) - fun_sig_substitute tsubst sg + fun_sig_substitute tsubst cgsubst sg let bs_ctx_lookup_llbc_type_decl (id : TypeDeclId.id) (ctx : bs_ctx) : T.type_decl = @@ -380,17 +393,17 @@ let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) let rec translate_sty (ty : T.sty) : ty = let translate = translate_sty in match ty with - | T.Adt (type_id, regions, tys) -> ( + | T.Adt (type_id, regions, tys, cgs) -> ( (* Can't translate types with regions for now *) assert (regions = []); let tys = List.map translate tys in match type_id with - | T.AdtId adt_id -> Adt (AdtId adt_id, tys) + | T.AdtId adt_id -> Adt (AdtId adt_id, tys, cgs) | T.Tuple -> mk_simpl_tuple_ty tys | T.Assumed aty -> ( match aty with - | T.Vec -> Adt (Assumed Vec, tys) - | T.Option -> Adt (Assumed Option, tys) + | T.Vec -> Adt (Assumed Vec, tys, cgs) + | T.Option -> Adt (Assumed Option, tys, cgs) | T.Box -> ( (* Eliminate the boxes *) match tys with @@ -399,15 +412,13 @@ let rec translate_sty (ty : T.sty) : ty = raise (Failure "Box/vec/option type with incorrect number of arguments") - ))) + ) + | T.Array -> Adt (Assumed Array, tys, cgs) + | T.Slice -> Adt (Assumed Slice, tys, cgs) + | T.Str -> Adt (Assumed Str, tys, cgs))) | TypeVar vid -> TypeVar vid - | Bool -> Bool - | Char -> Char + | Literal ty -> Literal ty | Never -> raise (Failure "Unreachable") - | Integer int_ty -> Integer int_ty - | Str -> Str - | Array ty -> Array (translate ty) - | Slice ty -> Slice (translate ty) | Ref (_, rty, _) -> translate rty let translate_field (f : T.field) : field = @@ -445,8 +456,9 @@ let translate_type_decl (def : T.type_decl) : type_decl = (* Can't translate types with regions for now *) assert (def.region_params = []); let type_params = def.type_params in + let const_generic_params = def.const_generic_params in let kind = translate_type_decl_kind def.T.kind in - { def_id; name; type_params; kind } + { def_id; name; type_params; const_generic_params; kind } (** Translate a type, seen as an input/output of a forward function (preserve all borrows, etc.) @@ -455,7 +467,7 @@ let translate_type_decl (def : T.type_decl) : type_decl = let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty = let translate = translate_fwd_ty type_infos in match ty with - | T.Adt (type_id, regions, tys) -> ( + | T.Adt (type_id, regions, tys, cgs) -> ( (* Can't translate types with regions for now *) assert (regions = []); (* Translate the type parameters *) diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml index a6e11363..e2cdc726 100644 --- a/compiler/SynthesizeSymbolic.ml +++ b/compiler/SynthesizeSymbolic.ml @@ -36,8 +36,8 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value) (* Boolean expansion: there should be two branches *) match ls with | [ - (Some (V.SePrimitive (PV.Bool true)), true_exp); - (Some (V.SePrimitive (PV.Bool false)), false_exp); + (Some (V.SeLiteral (PV.Bool true)), true_exp); + (Some (V.SeLiteral (PV.Bool false)), false_exp); ] -> ExpandBool (true_exp, false_exp) | _ -> raise (Failure "Ill-formed boolean expansion")) @@ -50,7 +50,7 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value) let get_scalar (see : V.symbolic_expansion option) : V.scalar_value = match see with - | Some (V.SePrimitive (PV.Scalar cv)) -> + | Some (V.SeLiteral (PV.Scalar cv)) -> assert (cv.PV.int_ty = int_ty); cv | _ -> raise (Failure "Unreachable") diff --git a/compiler/TranslateCore.ml b/compiler/TranslateCore.ml index 9ba73c7e..ba5e237b 100644 --- a/compiler/TranslateCore.ml +++ b/compiler/TranslateCore.ml @@ -32,33 +32,39 @@ type pure_fun_translation = fun_and_loops * fun_and_loops list let type_decl_to_string (ctx : trans_ctx) (def : Pure.type_decl) : string = let type_params = def.type_params in + let cg_params = def.const_generic_params in let type_decls = ctx.type_context.type_decls in - let fmt = PrintPure.mk_type_formatter type_decls type_params in + let global_decls = ctx.global_context.global_decls in + let fmt = + PrintPure.mk_type_formatter type_decls global_decls type_params cg_params + in PrintPure.type_decl_to_string fmt def -let type_id_to_string (ctx : trans_ctx) (def : Pure.type_decl) : string = - let type_params = def.type_params in - let type_decls = ctx.type_context.type_decls in - let fmt = PrintPure.mk_type_formatter type_decls type_params in - PrintPure.type_decl_to_string fmt def +let type_id_to_string (ctx : trans_ctx) (id : Pure.TypeDeclId.id) : string = + Print.fun_name_to_string + (Pure.TypeDeclId.Map.find id ctx.type_context.type_decls).name let fun_sig_to_string (ctx : trans_ctx) (sg : Pure.fun_sig) : string = let type_params = sg.type_params in + let cg_params = sg.const_generic_params in let type_decls = ctx.type_context.type_decls in let fun_decls = ctx.fun_context.fun_decls in let global_decls = ctx.global_context.global_decls in let fmt = PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params + cg_params in PrintPure.fun_sig_to_string fmt sg let fun_decl_to_string (ctx : trans_ctx) (def : Pure.fun_decl) : string = let type_params = def.signature.type_params in + let cg_params = def.signature.const_generic_params in let type_decls = ctx.type_context.type_decls in let fun_decls = ctx.fun_context.fun_decls in let global_decls = ctx.global_context.global_decls in let fmt = PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params + cg_params in PrintPure.fun_decl_to_string fmt def diff --git a/compiler/Values.ml b/compiler/Values.ml index 3d6bc9c1..f70b9b4b 100644 --- a/compiler/Values.ml +++ b/compiler/Values.ml @@ -147,7 +147,7 @@ class ['self] map_typed_value_base = (** An untyped value, used in the environments *) type value = - | Primitive of literal (** Non-symbolic primitive value *) + | Literal of literal (** Non-symbolic primitive value *) | Adt of adt_value (** Enumerations and structures *) | Bottom (** No value (uninitialized or moved value) *) | Borrow of borrow_content (** A borrowed value *) @@ -1014,7 +1014,7 @@ type abs = { TODO: this should rather be name "expanded_symbolic" *) type symbolic_expansion = - | SePrimitive of literal + | SeLiteral of literal | SeAdt of (VariantId.id option * symbolic_value list) | SeMutRef of BorrowId.id * symbolic_value | SeSharedRef of BorrowId.Set.t * symbolic_value |