From a1e24c2a13713b015abc9a93e6915b6d4a6f22fe Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 2 Aug 2023 14:10:15 +0200 Subject: Make progress --- compiler/ExtractBase.ml | 13 +++--- compiler/InterpreterExpansion.ml | 3 +- compiler/InterpreterExpressions.ml | 86 ++++++++++++++++++++++++++------------ compiler/InterpreterStatements.ml | 49 ++++++++++++---------- compiler/PrimitiveValuesUtils.ml | 1 + compiler/PrintPure.ml | 15 ++++++- compiler/Pure.ml | 1 + compiler/PureMicroPasses.ml | 46 ++++++++++++++++---- compiler/PureTypeCheck.ml | 6 ++- compiler/Substitute.ml | 5 +++ compiler/SymbolicToPure.ml | 17 +++++--- compiler/TypesAnalysis.ml | 6 ++- compiler/ValuesUtils.ml | 1 + compiler/dune | 5 ++- 14 files changed, 178 insertions(+), 76 deletions(-) create mode 100644 compiler/PrimitiveValuesUtils.ml (limited to 'compiler') diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 655bb033..bff6a360 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -292,7 +292,7 @@ type formatter = { indices to names, the responsability of finding a proper index is delegated to helper functions. *) - extract_primitive_value : F.formatter -> bool -> primitive_value -> unit; + extract_literal : F.formatter -> bool -> literal -> unit; (** Format a constant value. Inputs: @@ -674,7 +674,8 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = if variant_id = option_some_id then "@option::Some" else if variant_id = option_none_id then "@option::None" else raise (Failure "Unreachable") - | Assumed (State | Vec | Fuel) -> raise (Failure "Unreachable") + | Assumed (State | Vec | Fuel | Array | Slice | Str | Range) -> + raise (Failure "Unreachable") | AdtId id -> ( let def = TypeDeclId.Map.find id type_decls in match def.kind with @@ -688,10 +689,10 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = let field_name = match id with | Tuple -> raise (Failure "Unreachable") - | Assumed (State | Result | Error | Fuel | Option) -> - raise (Failure "Unreachable") - | Assumed Vec -> - (* We can't directly have access to the fields of a vector *) + | Assumed + ( State | Result | Error | Fuel | Option | Vec | Array | Slice | Str + | Range ) -> + (* We can't directly have access to the fields of those types *) raise (Failure "Unreachable") | AdtId id -> ( let def = TypeDeclId.Map.find id type_decls in diff --git a/compiler/InterpreterExpansion.ml b/compiler/InterpreterExpansion.ml index 3b196571..81e73e3e 100644 --- a/compiler/InterpreterExpansion.ml +++ b/compiler/InterpreterExpansion.ml @@ -707,7 +707,8 @@ let greedy_expand_symbolics_with_borrows (config : C.config) : cm_fun = | T.Adt ((Tuple | Assumed Box), _, _, _) | T.Ref (_, _, _) -> (* Ok *) expand_symbolic_value_no_branching config sv None - | T.Adt (Assumed (Vec | Option | Array | Slice | Str), _, _, _) -> + | T.Adt (Assumed (Vec | Option | Array | Slice | Str | Range), _, _, _) + -> (* We can't expand those *) raise (Failure diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml index bb159f05..c3ff8d4f 100644 --- a/compiler/InterpreterExpressions.ml +++ b/compiler/InterpreterExpressions.ml @@ -232,7 +232,7 @@ let prepare_eval_operand_reorganize (config : C.config) (op : E.operand) : match op with | Expressions.Constant (ty, cv) -> (* No need to reorganize the context *) - literal_to_typed_value ty cv |> ignore; + literal_to_typed_value (TypesUtils.ty_as_literal ty) cv |> ignore; cf ctx | Expressions.Copy p -> (* Access the value *) @@ -260,7 +260,8 @@ let eval_operand_no_reorganize (config : C.config) (op : E.operand) ^ "\n- ctx:\n" ^ eval_ctx_to_string ctx ^ "\n")); (* Evaluate *) match op with - | Expressions.Constant (ty, cv) -> cf (literal_to_typed_value ty cv) ctx + | Expressions.Constant (ty, cv) -> + cf (literal_to_typed_value (TypesUtils.ty_as_literal ty) cv) ctx | Expressions.Copy p -> (* Access the value *) let access = Read in @@ -350,21 +351,21 @@ let eval_unary_op_concrete (config : C.config) (unop : E.unop) (op : E.operand) (* Apply the unop *) let apply cf (v : V.typed_value) : m_fun = match (unop, v.V.value) with - | E.Not, V.Primitive (Bool b) -> - cf (Ok { v with V.value = V.Primitive (Bool (not b)) }) - | E.Neg, V.Primitive (PV.Scalar sv) -> ( + | E.Not, V.Literal (Bool b) -> + cf (Ok { v with V.value = V.Literal (Bool (not b)) }) + | E.Neg, V.Literal (PV.Scalar sv) -> ( let i = Z.neg sv.PV.value in match mk_scalar sv.int_ty i with | Error _ -> cf (Error EPanic) - | Ok sv -> cf (Ok { v with V.value = V.Primitive (PV.Scalar sv) })) - | E.Cast (src_ty, tgt_ty), V.Primitive (PV.Scalar sv) -> ( + | Ok sv -> cf (Ok { v with V.value = V.Literal (PV.Scalar sv) })) + | E.Cast (src_ty, tgt_ty), V.Literal (PV.Scalar sv) -> ( assert (src_ty = sv.int_ty); let i = sv.PV.value in match mk_scalar tgt_ty i with | Error _ -> cf (Error EPanic) | Ok sv -> - let ty = T.Integer tgt_ty in - let value = V.Primitive (PV.Scalar sv) in + let ty = T.Literal (Integer tgt_ty) in + let value = V.Literal (PV.Scalar sv) in cf (Ok { V.ty; value })) | _ -> raise (Failure "Invalid input for unop") in @@ -381,9 +382,9 @@ let eval_unary_op_symbolic (config : C.config) (unop : E.unop) (op : E.operand) let res_sv_id = C.fresh_symbolic_value_id () in let res_sv_ty = match (unop, v.V.ty) with - | E.Not, T.Bool -> T.Bool - | E.Neg, T.Integer int_ty -> T.Integer int_ty - | E.Cast (_, tgt_ty), _ -> T.Integer tgt_ty + | E.Not, (T.Literal Bool as lty) -> lty + | E.Neg, (T.Literal (Integer _) as lty) -> lty + | E.Cast (_, tgt_ty), _ -> T.Literal (Integer tgt_ty) | _ -> raise (Failure "Invalid input for unop") in let res_sv = @@ -418,11 +419,11 @@ let eval_binary_op_concrete_compute (binop : E.binop) (v1 : V.typed_value) (* Equality/inequality check is primitive only for a subset of types *) assert (ty_is_primitively_copyable v1.ty); let b = v1 = v2 in - Ok { V.value = V.Primitive (Bool b); ty = T.Bool }) + Ok { V.value = V.Literal (Bool b); ty = T.Literal Bool }) else (* For the non-equality operations, the input values are necessarily scalars *) match (v1.V.value, v2.V.value) with - | V.Primitive (PV.Scalar sv1), V.Primitive (PV.Scalar sv2) -> ( + | V.Literal (PV.Scalar sv1), V.Literal (PV.Scalar sv2) -> ( (* There are binops which require the two operands to have the same type, and binops for which it is not the case. There are also binops which return booleans, and binops which @@ -442,7 +443,9 @@ let eval_binary_op_concrete_compute (binop : E.binop) (v1 : V.typed_value) | E.BitOr | E.Shl | E.Shr | E.Ne | E.Eq -> raise (Failure "Unreachable") in - Ok ({ V.value = V.Primitive (Bool b); ty = T.Bool } : V.typed_value) + Ok + ({ V.value = V.Literal (Bool b); ty = T.Literal Bool } + : V.typed_value) | E.Div | E.Rem | E.Add | E.Sub | E.Mul | E.BitXor | E.BitAnd | E.BitOr -> ( (* The two operands must have the same type and the result is an integer *) @@ -470,8 +473,8 @@ let eval_binary_op_concrete_compute (binop : E.binop) (v1 : V.typed_value) | Ok sv -> Ok { - V.value = V.Primitive (PV.Scalar sv); - ty = Integer sv1.int_ty; + V.value = V.Literal (PV.Scalar sv); + ty = T.Literal (Integer sv1.int_ty); }) | E.Shl | E.Shr -> raise Unimplemented | E.Ne | E.Eq -> raise (Failure "Unreachable")) @@ -507,19 +510,19 @@ let eval_binary_op_symbolic (config : C.config) (binop : E.binop) assert (v1.ty = v2.ty); (* Equality/inequality check is primitive only for a subset of types *) assert (ty_is_primitively_copyable v1.ty); - T.Bool) + T.Literal Bool) else (* Other operations: input types are integers *) match (v1.V.ty, v2.V.ty) with - | T.Integer int_ty1, T.Integer int_ty2 -> ( + | T.Literal (Integer int_ty1), T.Literal (Integer int_ty2) -> ( match binop with | E.Lt | E.Le | E.Ge | E.Gt -> assert (int_ty1 = int_ty2); - T.Bool + T.Literal Bool | E.Div | E.Rem | E.Add | E.Sub | E.Mul | E.BitXor | E.BitAnd | E.BitOr -> assert (int_ty1 = int_ty2); - T.Integer int_ty1 + T.Literal (Integer int_ty1) | E.Shl | E.Shr -> raise Unimplemented | E.Ne | E.Eq -> raise (Failure "Unreachable")) | _ -> raise (Failure "Invalid inputs for binop") @@ -653,7 +656,7 @@ let eval_rvalue_aggregate (config : C.config) | E.AggregatedTuple -> let tys = List.map (fun (v : V.typed_value) -> v.V.ty) values in let v = V.Adt { variant_id = None; field_values = values } in - let ty = T.Adt (T.Tuple, [], tys) in + let ty = T.Adt (T.Tuple, [], tys, []) in let aggregated : V.typed_value = { V.value = v; ty } in (* Call the continuation *) cf aggregated ctx @@ -664,20 +667,20 @@ let eval_rvalue_aggregate (config : C.config) assert (List.length values = 1) else raise (Failure "Unreachable"); (* Construt the value *) - let aty = T.Adt (T.Assumed T.Option, [], [ ty ]) in + let aty = T.Adt (T.Assumed T.Option, [], [ ty ], []) in let av : V.adt_value = { V.variant_id = Some variant_id; V.field_values = values } in let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in (* Call the continuation *) cf aggregated ctx - | E.AggregatedAdt (def_id, opt_variant_id, regions, types) -> + | E.AggregatedAdt (def_id, opt_variant_id, regions, types, cgs) -> (* Sanity checks *) let type_decl = C.ctx_lookup_type_decl ctx def_id in assert (List.length type_decl.region_params = List.length regions); let expected_field_types = Subst.ctx_adt_get_instantiated_field_etypes ctx def_id opt_variant_id - types + types cgs in assert ( expected_field_types @@ -686,10 +689,41 @@ let eval_rvalue_aggregate (config : C.config) let av : V.adt_value = { V.variant_id = opt_variant_id; V.field_values = values } in - let aty = T.Adt (T.AdtId def_id, regions, types) in + let aty = T.Adt (T.AdtId def_id, regions, types, cgs) in let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in (* Call the continuation *) cf aggregated ctx + | E.AggregatedRange ety -> + (* There should be two fields exactly *) + let v0, v1 = + match values with + | [ v0; v1 ] -> (v0, v1) + | _ -> raise (Failure "Unreachable") + in + (* Ranges are parametric over the type of indices. For now we only + support scalars, which can be of any type *) + assert (literal_type_is_integer (ty_as_literal ety)); + assert (v0.ty = ety); + assert (v1.ty = ety); + (* Construct the value *) + let av : V.adt_value = + { V.variant_id = None; V.field_values = values } + in + let aty = T.Adt (T.Assumed T.Range, [], [ ety ], []) in + let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in + (* Call the continuation *) + cf aggregated ctx + | E.AggregatedArray (ety, cg) -> + (* Sanity check: all the values have the proper type *) + assert (List.for_all (fun (v : V.typed_value) -> v.V.ty = ety) values); + (* Sanity check: the number of values is consistent with the length *) + let len = (literal_as_scalar (const_generic_as_literal cg)).value in + assert (Z.to_int len = List.length values); + let v = V.Adt { variant_id = None; field_values = values } in + let ty = T.Adt (T.Assumed T.Array, [], [ ety ], [ cg ]) in + let aggregated : V.typed_value = { V.value = v; ty } in + (* Call the continuation *) + cf aggregated ctx in (* Compose and apply *) comp eval_ops compute cf diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml index cd5f8c3e..79fe79e7 100644 --- a/compiler/InterpreterStatements.ml +++ b/compiler/InterpreterStatements.ml @@ -893,7 +893,7 @@ and eval_global (config : C.config) (dest : E.place) (gid : LA.GlobalDeclId.id) match config.mode with | ConcreteMode -> (* Treat the evaluation of the global as a call to the global body (without arguments) *) - (eval_local_function_call_concrete config global.body_id [] [] [] dest) + (eval_local_function_call_concrete config global.body_id [] [] [] [] dest) cf ctx | SymbolicMode -> (* Generate a fresh symbolic value. In the translation, this fresh symbolic value will be @@ -1044,10 +1044,10 @@ and eval_function_call (config : C.config) (call : A.call) : st_cm_fun = match call.func with | A.Regular fid -> eval_local_function_call config fid call.region_args call.type_args - call.args call.dest + call.const_generic_args call.args call.dest | A.Assumed fid -> eval_non_local_function_call config fid call.region_args call.type_args - call.args call.dest + call.const_generic_args call.args call.dest (** Evaluate a local (i.e., non-assumed) function call in concrete mode *) and eval_local_function_call_concrete (config : C.config) (fid : A.FunDeclId.id) @@ -1135,19 +1135,20 @@ and eval_local_function_call_concrete (config : C.config) (fid : A.FunDeclId.id) (** Evaluate a local (i.e., non-assumed) function call in symbolic mode *) and eval_local_function_call_symbolic (config : C.config) (fid : A.FunDeclId.id) (region_args : T.erased_region list) (type_args : T.ety list) - (args : E.operand list) (dest : E.place) : st_cm_fun = + (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) : + st_cm_fun = fun cf ctx -> (* Retrieve the (correctly instantiated) signature *) let def = C.ctx_lookup_fun_decl ctx fid in let sg = def.A.signature in (* Instantiate the signature and introduce fresh abstraction and region ids * while doing so *) - let inst_sg = instantiate_fun_sig type_args sg in + let inst_sg = instantiate_fun_sig type_args cg_args sg in (* Sanity check *) assert (List.length args = List.length def.A.signature.inputs); (* Evaluate the function call *) eval_function_call_symbolic_from_inst_sig config (A.Regular fid) inst_sg - region_args type_args args dest cf ctx + region_args type_args cg_args args dest cf ctx (** Evaluate a function call in symbolic mode by using the function signature. @@ -1157,7 +1158,8 @@ and eval_local_function_call_symbolic (config : C.config) (fid : A.FunDeclId.id) and eval_function_call_symbolic_from_inst_sig (config : C.config) (fid : A.fun_id) (inst_sg : A.inst_fun_sig) (region_args : T.erased_region list) (type_args : T.ety list) - (args : E.operand list) (dest : E.place) : st_cm_fun = + (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) : + st_cm_fun = fun cf ctx -> assert (region_args = []); (* Generate a fresh symbolic value for the return value *) @@ -1225,8 +1227,8 @@ and eval_function_call_symbolic_from_inst_sig (config : C.config) let expr = cf ctx in (* Synthesize the symbolic AST *) - S.synthesize_regular_function_call fid call_id ctx abs_ids type_args args - args_places ret_spc dest_place expr + S.synthesize_regular_function_call fid call_id ctx abs_ids type_args cg_args + args args_places ret_spc dest_place expr in let cc = comp cc cf_call in @@ -1289,8 +1291,8 @@ and eval_function_call_symbolic_from_inst_sig (config : C.config) (** Evaluate a non-local function call in symbolic mode *) and eval_non_local_function_call_symbolic (config : C.config) (fid : A.assumed_fun_id) (region_args : T.erased_region list) - (type_args : T.ety list) (args : E.operand list) (dest : E.place) : - st_cm_fun = + (type_args : T.ety list) (cg_args : T.const_generic list) + (args : E.operand list) (dest : E.place) : st_cm_fun = fun cf ctx -> (* Sanity check: make sure the type parameters don't contain regions - * this is a current limitation of our synthesis *) @@ -1308,7 +1310,7 @@ and eval_non_local_function_call_symbolic (config : C.config) | A.BoxFree -> (* Degenerate case: box_free - note that this is not really a function * call: no need to call a "synthesize_..." function *) - eval_box_free config region_args type_args args dest (cf Unit) ctx + eval_box_free config region_args type_args cg_args args dest (cf Unit) ctx | _ -> (* "Normal" case: not box_free *) (* In symbolic mode, the behaviour of a function call is completely defined @@ -1319,18 +1321,20 @@ and eval_non_local_function_call_symbolic (config : C.config) | A.BoxFree -> (* should have been treated above *) raise (Failure "Unreachable") - | _ -> instantiate_fun_sig type_args (Assumed.get_assumed_sig fid) + | _ -> + instantiate_fun_sig type_args cg_args (Assumed.get_assumed_sig fid) in (* Evaluate the function call *) eval_function_call_symbolic_from_inst_sig config (A.Assumed fid) inst_sig - region_args type_args args dest cf ctx + region_args type_args cg_args args dest cf ctx (** Evaluate a non-local (i.e, assumed) function call such as [Box::deref] (auxiliary helper for [eval_statement]) *) and eval_non_local_function_call (config : C.config) (fid : A.assumed_fun_id) (region_args : T.erased_region list) (type_args : T.ety list) - (args : E.operand list) (dest : E.place) : st_cm_fun = + (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) : + st_cm_fun = fun cf ctx -> (* Debug *) log#ldebug @@ -1349,23 +1353,24 @@ and eval_non_local_function_call (config : C.config) (fid : A.assumed_fun_id) match config.mode with | C.ConcreteMode -> eval_non_local_function_call_concrete config fid region_args type_args - args dest (cf Unit) ctx + cg_args args dest (cf Unit) ctx | C.SymbolicMode -> eval_non_local_function_call_symbolic config fid region_args type_args - args dest cf ctx + cg_args args dest cf ctx (** Evaluate a local (i.e, not assumed) function call (auxiliary helper for [eval_statement]) *) and eval_local_function_call (config : C.config) (fid : A.FunDeclId.id) (region_args : T.erased_region list) (type_args : T.ety list) - (args : E.operand list) (dest : E.place) : st_cm_fun = + (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) : + st_cm_fun = match config.mode with | ConcreteMode -> - eval_local_function_call_concrete config fid region_args type_args args - dest + eval_local_function_call_concrete config fid region_args type_args cg_args + args dest | SymbolicMode -> - eval_local_function_call_symbolic config fid region_args type_args args - dest + eval_local_function_call_symbolic config fid region_args type_args cg_args + args dest (** Evaluate a statement seen as a function body *) and eval_function_body (config : C.config) (body : A.statement) : st_cm_fun = diff --git a/compiler/PrimitiveValuesUtils.ml b/compiler/PrimitiveValuesUtils.ml new file mode 100644 index 00000000..0000916d --- /dev/null +++ b/compiler/PrimitiveValuesUtils.ml @@ -0,0 +1 @@ +include Charon.PrimitiveValuesUtils diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 33a86df5..8fb6d644 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -164,6 +164,7 @@ let assumed_ty_to_string (aty : assumed_ty) : string = | Array -> "Array" | Slice -> "Slice" | Str -> "Str" + | Range -> "Range" let type_id_to_string (fmt : type_formatter) (id : type_id) : string = match id with @@ -293,9 +294,11 @@ let adt_variant_to_string (fmt : value_formatter) (adt_id : type_id) | Assumed aty -> ( (* Assumed type *) match aty with - | State | Vec | Array | Slice | Str -> + | State | Array | Slice | Str -> (* Those types are opaque: we can't get there *) raise (Failure "Unreachable") + | Vec -> "@Vec" + | Range -> "@Range" | Result -> let variant_id = Option.get variant_id in if variant_id = result_return_id then "@Result::Return" @@ -334,6 +337,7 @@ let adt_field_to_string (fmt : value_formatter) (adt_id : type_id) | Assumed aty -> ( (* Assumed type *) match aty with + | Range -> FieldId.to_string field_id | State | Fuel | Vec | Array | Slice | Str -> (* Opaque types: we can't get there *) raise (Failure "Unreachable") @@ -425,7 +429,14 @@ let adt_g_value_to_string (fmt : value_formatter) List.mapi (fun i v -> string_of_int i ^ " -> " ^ v) field_values in let id = assumed_ty_to_string aty in - id ^ " [" ^ String.concat "; " field_values ^ "]") + id ^ " [" ^ String.concat "; " field_values ^ "]" + | Range -> + assert (variant_id = None); + let field_values = + List.mapi (fun i v -> string_of_int i ^ " -> " ^ v) field_values + in + let id = assumed_ty_to_string aty in + id ^ " {" ^ String.concat "; " field_values ^ "}") | _ -> let fmt = value_to_type_formatter fmt in raise diff --git a/compiler/Pure.ml b/compiler/Pure.ml index b90ef60a..551ebf7b 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -64,6 +64,7 @@ type assumed_ty = | Array | Slice | Str + | Range [@@deriving show, ord] (* TODO: we should never directly manipulate [Return] and [Fail], but rather diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 74f3c576..00620c58 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -585,6 +585,7 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = { id = AdtCons { adt_id = AdtId adt_id; variant_id = None }; type_args = _; + const_generic_args = _; } -> (* Lookup the def *) let decl = @@ -1086,6 +1087,7 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = { id = AdtCons { adt_id = AdtId adt_id; variant_id = None }; type_args; + const_generic_args; } -> (* This is a struct *) (* Retrieve the definiton, to find how many fields there are *) @@ -1106,7 +1108,7 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = * [x.field] for some variable [x], and where the projection * is for the proper ADT *) let to_var_proj (i : int) (arg : texpression) : - (ty list * var_id) option = + (ty list * const_generic list * var_id) option = match arg.e with | App (proj, x) -> ( match (proj.e, x.e) with @@ -1115,13 +1117,15 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = id = Proj { adt_id = AdtId proj_adt_id; field_id }; type_args = proj_type_args; + const_generic_args = proj_const_generic_args; }, Var v ) -> (* We check that this is the proper ADT, and the proper field *) if proj_adt_id = adt_id && FieldId.to_int field_id = i - then Some (proj_type_args, v) + then + Some (proj_type_args, proj_const_generic_args, v) else None | _ -> None) | _ -> None @@ -1132,12 +1136,15 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = if List.length args = num_fields then (* Check that this is the same variable we project from - * note that we checked above that there is at least one field *) - let (_, x), end_args = Collections.List.pop args in - if List.for_all (fun (_, y) -> y = x) end_args then ( + let (_, _, x), end_args = Collections.List.pop args in + if List.for_all (fun (_, _, y) -> y = x) end_args then ( (* We can substitute *) (* Sanity check: all types correct *) assert ( - List.for_all (fun (tys, _) -> tys = type_args) args); + List.for_all + (fun (tys, cgs, _) -> + tys = type_args && cgs = const_generic_args) + args); { e with e = Var x }) else super#visit_texpression env e else super#visit_texpression env e @@ -1156,6 +1163,7 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = { id = Proj { adt_id = AdtId proj_adt_id; field_id }; type_args = _; + const_generic_args = _; }, Var v ) -> (* We check that this is the proper ADT, and the proper field *) @@ -1354,6 +1362,7 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = let loop_sig = { type_params = fun_sig.type_params; + const_generic_params = fun_sig.const_generic_params; inputs = inputs_tys; output; doutputs; @@ -1554,8 +1563,11 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = | A.BoxFree, _ -> assert (args = []); mk_unit_rvalue - | ( ( A.Replace | A.VecNew | A.VecPush | A.VecInsert | A.VecLen - | A.VecIndex | A.VecIndexMut ), + | ( ( A.Replace | VecNew | VecPush | VecInsert | VecLen + | VecIndex | VecIndexMut | ArraySharedSubslice + | ArrayMutSubslice | SliceSharedIndex | SliceMutIndex + | SliceSharedSubslice | SliceMutSubslice | ArraySharedIndex + | ArrayMutIndex | ArrayToSharedSlice | ArrayToMutSlice ), _ ) -> super#visit_texpression env e) | _ -> super#visit_texpression env e) @@ -2130,7 +2142,14 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : let num_filtered = List.length (List.filter (fun b -> not b) used_info) in - let { type_params; inputs; output; doutputs; info } = + let { + type_params; + const_generic_params; + inputs; + output; + doutputs; + info; + } = decl.signature in let { @@ -2158,7 +2177,16 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : effect_info; } in - let signature = { type_params; inputs; output; doutputs; info } in + let signature = + { + type_params; + const_generic_params; + inputs; + output; + doutputs; + info; + } + in { decl with signature } in diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index ef8bac37..0f64720e 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -46,7 +46,11 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) if variant_id = option_some_id then [ ty ] else if variant_id = option_none_id then [] else - raise (Failure "Unreachable: improper variant id for result type") + raise (Failure "Unreachable: improper variant id for option type") + | Range -> + let ty = Collections.List.to_cons_nil tys in + assert (variant_id = None); + [ ty; ty ] | Vec | Array | Slice | Str -> raise (Failure diff --git a/compiler/Substitute.ml b/compiler/Substitute.ml index a1b1572e..38850243 100644 --- a/compiler/Substitute.ml +++ b/compiler/Substitute.ml @@ -230,6 +230,11 @@ let ctx_adt_value_get_instantiated_field_rtypes (ctx : C.eval_ctx) if adt.V.variant_id = Some T.option_some_id then type_params else if adt.V.variant_id = Some T.option_none_id then [] else raise (Failure "Unreachable") + | T.Range -> + assert (List.length region_params = 0); + assert (List.length type_params = 1); + assert (List.length cg_params = 0); + type_params | T.Array | T.Slice | T.Str -> (* Those types don't have fields *) raise (Failure "Unreachable")) diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index a6d2784b..958c1bc8 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -416,7 +416,8 @@ let rec translate_sty (ty : T.sty) : ty = ) | T.Array -> Adt (Assumed Array, tys, cgs) | T.Slice -> Adt (Assumed Slice, tys, cgs) - | T.Str -> Adt (Assumed Str, tys, cgs))) + | T.Str -> Adt (Assumed Str, tys, cgs) + | T.Range -> Adt (Assumed Range, tys, cgs))) | TypeVar vid -> TypeVar vid | Literal ty -> Literal ty | Never -> raise (Failure "Unreachable") @@ -472,6 +473,7 @@ let translate_type_id (id : T.type_id) : type_id = | T.Array -> Array | T.Slice -> Slice | T.Str -> Str + | T.Range -> Range | T.Box -> (* Boxes have to be eliminated: this type id shouldn't be translated *) @@ -493,7 +495,8 @@ let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty = let t_tys = List.map translate tys in (* Eliminate boxes and simplify tuples *) match type_id with - | AdtId _ | T.Assumed (T.Vec | T.Option | T.Array | T.Slice | T.Str) -> + | AdtId _ + | T.Assumed (T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) -> (* No general parametricity for now *) assert (not (List.exists (TypesUtils.ty_has_borrows type_infos) tys)); let type_id = translate_type_id type_id in @@ -537,7 +540,8 @@ let rec translate_back_ty (type_infos : TA.type_infos) match ty with | T.Adt (type_id, _, tys, cgs) -> ( match type_id with - | T.AdtId _ | Assumed (T.Vec | T.Option | T.Array | T.Slice | T.Str) -> + | T.AdtId _ + | Assumed (T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) -> (* Don't accept ADTs (which are not tuples) with borrows for now *) assert (not (TypesUtils.ty_has_borrows type_infos ty)); let type_id = translate_type_id type_id in @@ -1037,7 +1041,8 @@ let rec typed_avalue_to_consumed (ctx : bs_ctx) (ectx : C.eval_ctx) let adt_id, _, _, _ = TypesUtils.ty_as_adt av.ty in match adt_id with | T.AdtId _ - | T.Assumed (T.Box | T.Vec | T.Option | T.Array | T.Slice | T.Str) -> + | T.Assumed + (T.Box | T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) -> assert (field_values = []); None | T.Tuple -> @@ -1183,7 +1188,8 @@ let rec typed_avalue_to_given_back (mp : mplace option) (av : V.typed_avalue) let adt_id, _, _, _ = TypesUtils.ty_as_adt av.ty in match adt_id with | T.AdtId _ - | T.Assumed (T.Box | T.Vec | T.Option | T.Array | T.Slice | T.Str) -> + | T.Assumed + (T.Box | T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) -> assert (field_values = []); (ctx, None) | T.Tuple -> @@ -2272,6 +2278,7 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value) * know how to expand values like vectors or arrays, because they have a variable number * of fields!) *) raise (Failure "Attempt to expand a non-expandable value") + | T.Assumed Range -> raise (Failure "Unimplemented") | T.Assumed T.Option -> (* We shouldn't get there in the "one-branch" case: options have * two variants *) diff --git a/compiler/TypesAnalysis.ml b/compiler/TypesAnalysis.ml index b4bb0386..925f6d39 100644 --- a/compiler/TypesAnalysis.ml +++ b/compiler/TypesAnalysis.ml @@ -170,8 +170,10 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref) (* Continue exploring *) analyze expl_info ty_info rty | Adt - ((Tuple | Assumed (Box | Vec | Option | Slice | Array | Str)), _, tys, _) - -> + ( (Tuple | Assumed (Box | Vec | Option | Slice | Array | Str | Range)), + _, + tys, + _ ) -> (* Nothing to update: just explore the type parameters *) List.fold_left (fun ty_info ty -> analyze expl_info ty_info ty) diff --git a/compiler/ValuesUtils.ml b/compiler/ValuesUtils.ml index abbfad31..d748cc2e 100644 --- a/compiler/ValuesUtils.ml +++ b/compiler/ValuesUtils.ml @@ -3,6 +3,7 @@ open TypesUtils open Types open Values module TA = TypesAnalysis +include PrimitiveValuesUtils (** Utility exception *) exception FoundSymbolicValue of symbolic_value diff --git a/compiler/dune b/compiler/dune index b74b65fa..6785cad4 100644 --- a/compiler/dune +++ b/compiler/dune @@ -48,6 +48,8 @@ PrePasses Print PrintPure + PrimitiveValues + PrimitiveValuesUtils PureMicroPasses Pure PureTypeCheck @@ -67,8 +69,7 @@ TypesUtils Utils Values - ValuesUtils - PrimitiveValues)) + ValuesUtils)) (documentation (package aeneas)) -- cgit v1.2.3