From 59ec03d37d2ad51cf77e456622703c4c84780f48 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 2 Aug 2023 11:46:09 +0200 Subject: Make progress --- compiler/InterpreterPaths.ml | 21 ++-- compiler/InterpreterStatements.ml | 92 ++++++++------ compiler/PrintPure.ml | 1 + compiler/Pure.ml | 10 +- compiler/SymbolicAst.ml | 3 + compiler/SymbolicToPure.ml | 249 ++++++++++++++++++++++---------------- compiler/SynthesizeSymbolic.ml | 20 +-- 7 files changed, 232 insertions(+), 164 deletions(-) (limited to 'compiler') diff --git a/compiler/InterpreterPaths.ml b/compiler/InterpreterPaths.ml index 4a439250..04dc8892 100644 --- a/compiler/InterpreterPaths.ml +++ b/compiler/InterpreterPaths.ml @@ -359,7 +359,8 @@ let write_place (access : access_kind) (p : E.place) (nv : V.typed_value) let compute_expanded_bottom_adt_value (tyctx : T.type_decl T.TypeDeclId.Map.t) (def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option) - (regions : T.erased_region list) (types : T.ety list) : V.typed_value = + (regions : T.erased_region list) (types : T.ety list) + (cgs : T.const_generic list) : V.typed_value = (* Lookup the definition and check if it is an enumeration - it should be an enumeration if and only if the projection element is a field projection with *some* variant id. Retrieve the list @@ -368,12 +369,12 @@ let compute_expanded_bottom_adt_value (tyctx : T.type_decl T.TypeDeclId.Map.t) assert (List.length regions = List.length def.T.region_params); (* Compute the field types *) let field_types = - Subst.type_decl_get_instantiated_field_etypes def opt_variant_id types + Subst.type_decl_get_instantiated_field_etypes def opt_variant_id types cgs in (* Initialize the expanded value *) let fields = List.map mk_bottom field_types in let av = V.Adt { variant_id = opt_variant_id; field_values = fields } in - let ty = T.Adt (T.AdtId def_id, regions, types) in + let ty = T.Adt (T.AdtId def_id, regions, types, cgs) in { V.value = av; V.ty } let compute_expanded_bottom_option_value (variant_id : T.VariantId.id) @@ -386,7 +387,7 @@ let compute_expanded_bottom_option_value (variant_id : T.VariantId.id) else raise (Failure "Unreachable") in let av = V.Adt { variant_id = Some variant_id; field_values } in - let ty = T.Adt (T.Assumed T.Option, [], [ param_ty ]) in + let ty = T.Adt (T.Assumed T.Option, [], [ param_ty ], []) in { V.value = av; ty } let compute_expanded_bottom_tuple_value (field_types : T.ety list) : @@ -394,7 +395,7 @@ let compute_expanded_bottom_tuple_value (field_types : T.ety list) : (* Generate the field values *) let fields = List.map mk_bottom field_types in let v = V.Adt { variant_id = None; field_values = fields } in - let ty = T.Adt (T.Tuple, [], field_types) in + let ty = T.Adt (T.Tuple, [], field_types, []) in { V.value = v; V.ty } (** Auxiliary helper to expand {!V.Bottom} values. @@ -446,16 +447,16 @@ let expand_bottom_value_from_projection (access : access_kind) (p : E.place) match (pe, ty) with (* "Regular" ADTs *) | ( Field (ProjAdt (def_id, opt_variant_id), _), - T.Adt (T.AdtId def_id', regions, types) ) -> + T.Adt (T.AdtId def_id', regions, types, cgs) ) -> assert (def_id = def_id'); compute_expanded_bottom_adt_value ctx.type_context.type_decls def_id - opt_variant_id regions types + opt_variant_id regions types cgs (* Option *) - | Field (ProjOption variant_id, _), T.Adt (T.Assumed T.Option, [], [ ty ]) - -> + | ( Field (ProjOption variant_id, _), + T.Adt (T.Assumed T.Option, [], [ ty ], []) ) -> compute_expanded_bottom_option_value variant_id ty (* Tuples *) - | Field (ProjTuple arity, _), T.Adt (T.Tuple, [], tys) -> + | Field (ProjTuple arity, _), T.Adt (T.Tuple, [], tys, []) -> assert (arity = List.length tys); (* Generate the field values *) compute_expanded_bottom_tuple_value tys diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml index d181ca4b..cd5f8c3e 100644 --- a/compiler/InterpreterStatements.ml +++ b/compiler/InterpreterStatements.ml @@ -420,18 +420,20 @@ let pop_frame_assign (config : C.config) (dest : E.place) : cm_fun = (** Auxiliary function - see {!eval_non_local_function_call} *) let eval_replace_concrete (_config : C.config) - (_region_params : T.erased_region list) (_type_params : T.ety list) : cm_fun - = + (_region_params : T.erased_region list) (_type_params : T.ety list) + (_cg_params : T.const_generic list) : cm_fun = fun _cf _ctx -> raise Unimplemented (** Auxiliary function - see {!eval_non_local_function_call} *) let eval_box_new_concrete (config : C.config) - (region_params : T.erased_region list) (type_params : T.ety list) : cm_fun = + (region_params : T.erased_region list) (type_params : T.ety list) + (cg_params : T.const_generic list) : cm_fun = fun cf ctx -> (* Check and retrieve the arguments *) - match (region_params, type_params, ctx.env) with + match (region_params, type_params, cg_params, ctx.env) with | ( [], [ boxed_ty ], + [], Var (VarBinder input_var, input_value) :: Var (_ret_var, _) :: C.Frame :: _ ) -> @@ -468,12 +470,13 @@ let eval_box_new_concrete (config : C.config) and [std::DerefMut::deref_mut] - see {!eval_non_local_function_call} *) let eval_box_deref_mut_or_shared_concrete (config : C.config) (region_params : T.erased_region list) (type_params : T.ety list) - (is_mut : bool) : cm_fun = + (cg_params : T.const_generic list) (is_mut : bool) : cm_fun = fun cf ctx -> (* Check the arguments *) - match (region_params, type_params, ctx.env) with + match (region_params, type_params, cg_params, ctx.env) with | ( [], [ boxed_ty ], + [], Var (VarBinder input_var, input_value) :: Var (_ret_var, _) :: C.Frame :: _ ) -> @@ -513,15 +516,19 @@ let eval_box_deref_mut_or_shared_concrete (config : C.config) (** Auxiliary function - see {!eval_non_local_function_call} *) let eval_box_deref_concrete (config : C.config) - (region_params : T.erased_region list) (type_params : T.ety list) : cm_fun = + (region_params : T.erased_region list) (type_params : T.ety list) + (cg_params : T.const_generic list) : cm_fun = let is_mut = false in - eval_box_deref_mut_or_shared_concrete config region_params type_params is_mut + eval_box_deref_mut_or_shared_concrete config region_params type_params + cg_params is_mut (** Auxiliary function - see {!eval_non_local_function_call} *) let eval_box_deref_mut_concrete (config : C.config) - (region_params : T.erased_region list) (type_params : T.ety list) : cm_fun = + (region_params : T.erased_region list) (type_params : T.ety list) + (cg_params : T.const_generic list) : cm_fun = let is_mut = true in - eval_box_deref_mut_or_shared_concrete config region_params type_params is_mut + eval_box_deref_mut_or_shared_concrete config region_params type_params + cg_params is_mut (** Auxiliary function - see {!eval_non_local_function_call}. @@ -543,11 +550,11 @@ let eval_box_deref_mut_concrete (config : C.config) the destination (by setting it to [()]). *) let eval_box_free (config : C.config) (region_params : T.erased_region list) - (type_params : T.ety list) (args : E.operand list) (dest : E.place) : cm_fun - = + (type_params : T.ety list) (cg_params : T.const_generic list) + (args : E.operand list) (dest : E.place) : cm_fun = fun cf ctx -> - match (region_params, type_params, args) with - | [], [ boxed_ty ], [ E.Move input_box_place ] -> + match (region_params, type_params, cg_params, args) with + | [], [ boxed_ty ], [], [ E.Move input_box_place ] -> (* Required type checking *) let input_box = InterpreterPaths.read_place Write input_box_place ctx in (let input_ty = ty_get_box input_box.V.ty in @@ -565,15 +572,15 @@ let eval_box_free (config : C.config) (region_params : T.erased_region list) (** Auxiliary function - see {!eval_non_local_function_call} *) let eval_vec_function_concrete (_config : C.config) (_fid : A.assumed_fun_id) - (_region_params : T.erased_region list) (_type_params : T.ety list) : cm_fun - = + (_region_params : T.erased_region list) (_type_params : T.ety list) + (_cg_params : T.const_generic list) : cm_fun = fun _cf _ctx -> raise Unimplemented (** Evaluate a non-local function call in concrete mode *) let eval_non_local_function_call_concrete (config : C.config) (fid : A.assumed_fun_id) (region_params : T.erased_region list) - (type_params : T.ety list) (args : E.operand list) (dest : E.place) : cm_fun - = + (type_params : T.ety list) (cg_params : T.const_generic list) + (args : E.operand list) (dest : E.place) : cm_fun = (* There are two cases (and this is extremely annoying): - the function is not box_free - the function is box_free @@ -582,7 +589,7 @@ let eval_non_local_function_call_concrete (config : C.config) match fid with | A.BoxFree -> (* Degenerate case: box_free *) - eval_box_free config region_params type_params args dest + eval_box_free config region_params type_params cg_params args dest | _ -> (* "Normal" case: not box_free *) (* Evaluate the operands *) @@ -605,6 +612,7 @@ let eval_non_local_function_call_concrete (config : C.config) let ret_vid = E.VarId.zero in let ret_ty = get_non_local_function_return_type fid region_params type_params + cg_params in let ret_var = mk_var ret_vid (Some "@return") ret_ty in let cc = comp cc (push_uninitialized_var ret_var) in @@ -622,15 +630,25 @@ let eval_non_local_function_call_concrete (config : C.config) * access to a body. *) let cf_eval_body : cm_fun = match fid with - | A.Replace -> eval_replace_concrete config region_params type_params - | BoxNew -> eval_box_new_concrete config region_params type_params - | BoxDeref -> eval_box_deref_concrete config region_params type_params + | A.Replace -> + eval_replace_concrete config region_params type_params cg_params + | BoxNew -> + eval_box_new_concrete config region_params type_params cg_params + | BoxDeref -> + eval_box_deref_concrete config region_params type_params cg_params | BoxDerefMut -> eval_box_deref_mut_concrete config region_params type_params + cg_params | BoxFree -> (* Should have been treated above *) raise (Failure "Unreachable") | VecNew | VecPush | VecInsert | VecLen | VecIndex | VecIndexMut -> eval_vec_function_concrete config fid region_params type_params + cg_params + | ArraySharedIndex | ArrayMutIndex | ArrayToSharedSlice + | ArrayToMutSlice | ArraySharedSubslice | ArrayMutSubslice + | SliceSharedIndex | SliceMutIndex | SliceSharedSubslice + | SliceMutSubslice -> + raise (Failure "Unimplemented") in let cc = comp cc cf_eval_body in @@ -644,8 +662,8 @@ let eval_non_local_function_call_concrete (config : C.config) (* Compose and apply *) comp cf_eval_ops cf_eval_call -let instantiate_fun_sig (type_params : T.ety list) (sg : A.fun_sig) : - A.inst_fun_sig = +let instantiate_fun_sig (type_params : T.ety list) + (cg_params : T.const_generic list) (sg : A.fun_sig) : A.inst_fun_sig = (* Generate fresh abstraction ids and create a substitution from region * group ids to abstraction ids *) let rg_abs_ids_bindings = @@ -674,13 +692,12 @@ let instantiate_fun_sig (type_params : T.ety list) (sg : A.fun_sig) : * work to do to properly handle full type parametrization. * *) let rtype_params = List.map ety_no_regions_to_rty type_params in - let tsubst = - Subst.make_type_subst - (List.map (fun v -> v.T.index) sg.type_params) - rtype_params + let tsubst = Subst.make_type_subst_from_vars sg.type_params rtype_params in + let cgsubst = + Subst.make_const_generic_subst_from_vars sg.const_generic_params cg_params in (* Substitute the signature *) - let inst_sig = Subst.substitute_signature asubst rsubst tsubst sg in + let inst_sig = Subst.substitute_signature asubst rsubst tsubst cgsubst sg in (* Return *) inst_sig @@ -912,7 +929,7 @@ and eval_switch (config : C.config) (switch : A.switch) : st_cm_fun = let cf_if (cf : st_m_fun) (op_v : V.typed_value) : m_fun = fun ctx -> match op_v.value with - | V.Primitive (PV.Bool b) -> + | V.Literal (PV.Bool b) -> (* Evaluate the if and the branch body *) let cf_branch cf : m_fun = (* Branch *) @@ -940,7 +957,7 @@ and eval_switch (config : C.config) (switch : A.switch) : st_cm_fun = let cf_switch (cf : st_m_fun) (op_v : V.typed_value) : m_fun = fun ctx -> match op_v.value with - | V.Primitive (PV.Scalar sv) -> + | V.Literal (PV.Scalar sv) -> (* Evaluate the branch *) let cf_eval_branch cf = (* Sanity check *) @@ -1035,7 +1052,8 @@ and eval_function_call (config : C.config) (call : A.call) : st_cm_fun = (** Evaluate a local (i.e., non-assumed) function call in concrete mode *) and eval_local_function_call_concrete (config : C.config) (fid : A.FunDeclId.id) (region_args : T.erased_region list) (type_args : T.ety list) - (args : E.operand list) (dest : E.place) : st_cm_fun = + (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) : + st_cm_fun = fun cf ctx -> assert (region_args = []); @@ -1052,11 +1070,13 @@ and eval_local_function_call_concrete (config : C.config) (fid : A.FunDeclId.id) | Some body -> body in let tsubst = - Subst.make_type_subst - (List.map (fun v -> v.T.index) def.A.signature.type_params) - type_args + Subst.make_type_subst_from_vars def.A.signature.type_params type_args + in + let cgsubst = + Subst.make_const_generic_subst_from_vars + def.A.signature.const_generic_params cg_args in - let locals, body_st = Subst.fun_body_substitute_in_body tsubst body in + let locals, body_st = Subst.fun_body_substitute_in_body tsubst cgsubst body in (* Evaluate the input operands *) assert (List.length args = body.A.arg_count); diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 6f857b4f..33a86df5 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -532,6 +532,7 @@ let unop_to_string (unop : unop) : string = | Cast (src, tgt) -> "cast<" ^ integer_type_to_string src ^ "," ^ integer_type_to_string tgt ^ ">" + | SliceNew tgt_len -> "array_to_slice<" ^ scalar_value_to_string tgt_len ^ ">" let binop_to_string = Print.Expressions.binop_to_string diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 9b5d9236..b90ef60a 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -197,8 +197,8 @@ type type_decl = { } [@@deriving show] -type scalar_value = V.scalar_value [@@deriving show] -type literal = V.literal [@@deriving show] +type scalar_value = V.scalar_value [@@deriving show, ord] +type literal = V.literal [@@deriving show, ord] (** Because we introduce a lot of temporary variables, the list of variables is not fixed: we thus must carry all its information with the variable @@ -343,7 +343,11 @@ and typed_pattern = { value : pattern; ty : ty } polymorphic = false; }] -type unop = Not | Neg of integer_type | Cast of integer_type * integer_type +type unop = + | Not + | Neg of integer_type + | Cast of integer_type * integer_type + | SliceNew of scalar_value [@@deriving show, ord] (** Identifiers of assumed functions that we use only in the pure translation *) diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml index 0e68d2fd..787fefc7 100644 --- a/compiler/SymbolicAst.ml +++ b/compiler/SymbolicAst.ml @@ -43,7 +43,10 @@ type call = { borrows (we need to perform lookups). *) abstractions : V.AbstractionId.id list; + (* TODO: rename to "...args" *) type_params : T.ety list; + (* TODO: rename to "...args" *) + const_generic_params : T.const_generic list; args : V.typed_value list; args_places : mplace option list; (** Meta information *) dest : V.symbolic_value; diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index ba2a6525..a6d2784b 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -107,6 +107,7 @@ type loop_info = { input_vars : var list; input_svl : V.symbolic_value list; type_args : ty list; + const_generic_args : const_generic list; forward_inputs : texpression list option; (** The forward inputs are initialized at [None] *) forward_output_no_state_no_result : var option; @@ -460,10 +461,28 @@ let translate_type_decl (def : T.type_decl) : type_decl = let kind = translate_type_decl_kind def.T.kind in { def_id; name; type_params; const_generic_params; kind } +let translate_type_id (id : T.type_id) : type_id = + match id with + | AdtId adt_id -> AdtId adt_id + | T.Assumed aty -> + let aty = + match aty with + | T.Vec -> Vec + | T.Option -> Option + | T.Array -> Array + | T.Slice -> Slice + | T.Str -> Str + | T.Box -> + (* Boxes have to be eliminated: this type id shouldn't + be translated *) + raise (Failure "Unreachable") + in + Assumed aty + | T.Tuple -> Tuple + (** Translate a type, seen as an input/output of a forward function (preserve all borrows, etc.) *) - let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty = let translate = translate_fwd_ty type_infos in match ty with @@ -474,17 +493,11 @@ let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty = let t_tys = List.map translate tys in (* Eliminate boxes and simplify tuples *) match type_id with - | AdtId _ | T.Assumed (T.Vec | T.Option) -> + | AdtId _ | T.Assumed (T.Vec | T.Option | T.Array | T.Slice | T.Str) -> (* No general parametricity for now *) assert (not (List.exists (TypesUtils.ty_has_borrows type_infos) tys)); - let type_id = - match type_id with - | AdtId adt_id -> AdtId adt_id - | T.Assumed T.Vec -> Assumed Vec - | T.Assumed T.Option -> Assumed Option - | _ -> raise (Failure "Unreachable") - in - Adt (type_id, t_tys) + let type_id = translate_type_id type_id in + Adt (type_id, t_tys, cgs) | Tuple -> (* Note that if there is exactly one type, [mk_simpl_tuple_ty] is the identity *) @@ -501,17 +514,8 @@ let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty = "Unreachable: box/vec/option receives exactly one type \ parameter"))) | TypeVar vid -> TypeVar vid - | Bool -> Bool - | Char -> Char | Never -> raise (Failure "Unreachable") - | Integer int_ty -> Integer int_ty - | Str -> Str - | Array ty -> - assert (not (TypesUtils.ty_has_borrows type_infos ty)); - Array (translate ty) - | Slice ty -> - assert (not (TypesUtils.ty_has_borrows type_infos ty)); - Slice (translate ty) + | Literal lty -> Literal lty | Ref (_, rty, _) -> translate rty (** Simply calls [translate_fwd_ty] *) @@ -531,21 +535,15 @@ let rec translate_back_ty (type_infos : TA.type_infos) (* A small helper for "leave" types *) let wrap ty = if inside_mut then Some ty else None in match ty with - | T.Adt (type_id, _, tys) -> ( + | T.Adt (type_id, _, tys, cgs) -> ( match type_id with - | T.AdtId _ | Assumed (T.Vec | T.Option) -> + | T.AdtId _ | Assumed (T.Vec | T.Option | T.Array | T.Slice | T.Str) -> (* Don't accept ADTs (which are not tuples) with borrows for now *) assert (not (TypesUtils.ty_has_borrows type_infos ty)); - let type_id = - match type_id with - | T.AdtId id -> AdtId id - | T.Assumed T.Vec -> Assumed Vec - | T.Assumed T.Option -> Assumed Option - | T.Tuple | T.Assumed T.Box -> raise (Failure "Unreachable") - in + let type_id = translate_type_id type_id in if inside_mut then let tys_t = List.filter_map translate tys in - Some (Adt (type_id, tys_t)) + Some (Adt (type_id, tys_t, cgs)) else None | Assumed T.Box -> ( (* Don't accept ADTs (which are not tuples) with borrows for now *) @@ -567,17 +565,8 @@ let rec translate_back_ty (type_infos : TA.type_infos) * is the identity *) Some (mk_simpl_tuple_ty tys_t))) | TypeVar vid -> wrap (TypeVar vid) - | Bool -> wrap Bool - | Char -> wrap Char | Never -> raise (Failure "Unreachable") - | Integer int_ty -> wrap (Integer int_ty) - | Str -> wrap Str - | Array ty -> ( - assert (not (TypesUtils.ty_has_borrows type_infos ty)); - match translate ty with None -> None | Some ty -> Some (Array ty)) - | Slice ty -> ( - assert (not (TypesUtils.ty_has_borrows type_infos ty)); - match translate ty with None -> None | Some ty -> Some (Slice ty)) + | Literal lty -> wrap (Literal lty) | Ref (r, rty, rkind) -> ( match rkind with | T.Shared -> @@ -813,8 +802,9 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) (* Wrap in a result type *) if effect_info.can_fail then mk_result_ty output else output in - (* Type parameters *) + (* Type/const generic parameters *) let type_params = sg.type_params in + let const_generic_params = sg.const_generic_params in (* Return *) let has_fuel = fuel <> [] in let num_fwd_inputs_no_state = List.length fwd_inputs in @@ -842,7 +832,9 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) effect_info; } in - let sg = { type_params; inputs; output; doutputs; info } in + let sg = + { type_params; const_generic_params; inputs; output; doutputs; info } + in { sg; output_names } let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern = @@ -921,7 +913,7 @@ let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var = (** Peel boxes as long as the value is of the form [Box] *) let rec unbox_typed_value (v : V.typed_value) : V.typed_value = match (v.value, v.ty) with - | V.Adt av, T.Adt (T.Assumed T.Box, _, _) -> ( + | V.Adt av, T.Adt (T.Assumed T.Box, _, _, _) -> ( match av.field_values with | [ bv ] -> unbox_typed_value bv | _ -> raise (Failure "Unreachable")) @@ -960,26 +952,22 @@ let rec typed_value_to_texpression (ctx : bs_ctx) (ectx : C.eval_ctx) (* Translate the value *) let value = match v.value with - | V.Primitive cv -> { e = Const cv; ty } + | V.Literal cv -> { e = Const cv; ty } | Adt av -> ( let variant_id = av.variant_id in let field_values = List.map translate av.field_values in (* Eliminate the tuple wrapper if it is a tuple with exactly one field *) match v.ty with - | T.Adt (T.Tuple, _, _) -> + | T.Adt (T.Tuple, _, _, _) -> assert (variant_id = None); mk_simpl_tuple_texpression field_values | _ -> - (* Retrieve the type and the translated type arguments from the - * translated type (simpler this way) *) - let adt_id, type_args = - match ty with - | Adt (type_id, tys) -> (type_id, tys) - | _ -> raise (Failure "Unreachable") - in + (* Retrieve the type, the translated type arguments and the + * const generic arguments from the translated type (simpler this way) *) + let adt_id, type_args, const_generic_args = ty_as_adt ty in (* Create the constructor *) let qualif_id = AdtCons { adt_id; variant_id = av.variant_id } in - let qualif = { id = qualif_id; type_args } in + let qualif = { id = qualif_id; type_args; const_generic_args } in let cons_e = Qualif qualif in let field_tys = List.map (fun (v : texpression) -> v.ty) field_values @@ -1046,9 +1034,10 @@ let rec typed_avalue_to_consumed (ctx : bs_ctx) (ectx : C.eval_ctx) (* Translate the field values *) let field_values = List.filter_map translate adt_v.field_values in (* For now, only tuples can contain borrows *) - let adt_id, _, _ = TypesUtils.ty_as_adt av.ty in + let adt_id, _, _, _ = TypesUtils.ty_as_adt av.ty in match adt_id with - | T.AdtId _ | T.Assumed (T.Box | T.Vec | T.Option) -> + | T.AdtId _ + | T.Assumed (T.Box | T.Vec | T.Option | T.Array | T.Slice | T.Str) -> assert (field_values = []); None | T.Tuple -> @@ -1189,11 +1178,12 @@ let rec typed_avalue_to_given_back (mp : mplace option) (av : V.typed_avalue) in let field_values = List.filter_map (fun x -> x) field_values in (* For now, only tuples can contain borrows - note that if we gave - * something like a [&mut Vec] to a function, we give give back the + * something like a [&mut Vec] to a function, we give back the * vector value upon visiting the "abstraction borrow" node *) - let adt_id, _, _ = TypesUtils.ty_as_adt av.ty in + let adt_id, _, _, _ = TypesUtils.ty_as_adt av.ty in match adt_id with - | T.AdtId _ | T.Assumed (T.Box | T.Vec | T.Option) -> + | T.AdtId _ + | T.Assumed (T.Box | T.Vec | T.Option | T.Array | T.Slice | T.Str) -> assert (field_values = []); (ctx, None) | T.Tuple -> @@ -1463,6 +1453,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : texpression = (* Translate the function call *) let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in + let const_generic_args = call.const_generic_params in let args = let args = List.map (typed_value_to_texpression ctx call.ctx) call.args in let args_mplaces = List.map translate_opt_mplace call.args_places in @@ -1540,6 +1531,19 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : } in (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None) + | S.Unop (E.SliceNew tgt_len) -> + (* The cast can fail if the length of the source array is not + big enough *) + let effect_info = + { + can_fail = true; + stateful_group = false; + stateful = false; + can_diverge = false; + is_rec = false; + } + in + (ctx, Unop (SliceNew tgt_len), effect_info, args, None) | S.Binop binop -> ( match args with | [ arg0; arg1 ] -> @@ -1564,7 +1568,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : | None -> dest | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ] in - let func = { id = FunOrOp fun_id; type_args } in + let func = { id = FunOrOp fun_id; type_args; const_generic_args } in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in let ret_ty = if effect_info.can_fail then mk_result_ty dest_v.ty else dest_v.ty @@ -1625,13 +1629,13 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs) * to the backward function, and which consumed the values [consumed_i], * we introduce: * {[ - * let v_i = consumed_i in - * ... - * ]} + * let v_i = consumed_i in + * ... + * ]} * Then, when we reach the [Return] node, we introduce: * {[ - * (v_i) - * ]} + * (v_i) + * ]} * *) (* First, get the given back variables. @@ -1696,6 +1700,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) get_fun_effect_info ctx.fun_context.fun_infos fun_id None (Some rg_id) in let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in + let const_generic_args = call.const_generic_params in (* Retrieve the original call and the parent abstractions *) let _forward, backwards = get_abs_ancestors ctx abs call_id in (* Retrieve the values consumed when we called the forward function and @@ -1744,7 +1749,10 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) in (* Sanity check: there is the proper number of inputs and outputs, and they have the proper type *) let _ = - let inst_sg = get_instantiated_fun_sig fun_id (Some rg_id) type_args ctx in + let inst_sg = + get_instantiated_fun_sig fun_id (Some rg_id) type_args const_generic_args + ctx + in log#ldebug (lazy ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs (" @@ -1787,7 +1795,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) if effect_info.can_fail then mk_result_ty output.ty else output.ty in let func_ty = mk_arrows input_tys ret_ty in - let func = { id = FunOrOp func; type_args } in + let func = { id = FunOrOp func; type_args; const_generic_args } in let func = { e = Qualif func; ty = func_ty } in let call = mk_apps func args in (* **Optimization**: @@ -1850,7 +1858,7 @@ and translate_end_abstraction_synth_ret (ectx : C.eval_ctx) (abs : V.abs) {[ let id_back x nx = let s = nx in // the name [s] is not important (only collision matters) - ... + ... ]} This let-binding later gets inlined, during a micro-pass. @@ -1911,6 +1919,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) in let loop_info = LoopId.Map.find loop_id ctx.loops in let type_args = loop_info.type_args in + let const_generic_args = loop_info.const_generic_args in let fwd_inputs = Option.get loop_info.forward_inputs in (* Retrieve the additional backward inputs. Note that those are actually the backward inputs of the function we are synthesizing (and that we @@ -1959,7 +1968,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) in let func_ty = mk_arrows input_tys ret_ty in let func = Fun (FromLlbc (fun_id, Some loop_id, Some rg_id)) in - let func = { id = FunOrOp func; type_args } in + let func = { id = FunOrOp func; type_args; const_generic_args } in let func = { e = Qualif func; ty = func_ty } in let call = mk_apps func args in (* **Optimization**: @@ -2019,7 +2028,9 @@ and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value) (e : S.expression) (ctx : bs_ctx) : texpression = let ctx, var = fresh_var_for_symbolic_value sval ctx in let decl = A.GlobalDeclId.Map.find gid ctx.global_context.llbc_global_decls in - let global_expr = { id = Global gid; type_args = [] } in + let global_expr = + { id = Global gid; type_args = []; const_generic_args = [] } + in (* We use translate_fwd_ty to translate the global type *) let ty = ctx_translate_fwd_ty ctx decl.ty in let gval = { e = Qualif global_expr; ty } in @@ -2032,8 +2043,14 @@ and translate_assertion (ectx : C.eval_ctx) (v : V.typed_value) let monadic = true in let v = typed_value_to_texpression ctx ectx v in let args = [ v ] in - let func = { id = FunOrOp (Fun (Pure Assert)); type_args = [] } in - let func_ty = mk_arrow Bool mk_unit_ty in + let func = + { + id = FunOrOp (Fun (Pure Assert)); + type_args = []; + const_generic_args = []; + } + in + let func_ty = mk_arrow (Literal Bool) mk_unit_ty in let func = { e = Qualif func; ty = func_ty } in let assertion = mk_apps func args in mk_let monadic (mk_dummy_pattern mk_unit_ty) assertion next_e @@ -2048,13 +2065,13 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) match exp with | ExpandNoBranch (sexp, e) -> ( match sexp with - | V.SePrimitive _ -> - (* Actually, we don't *register* symbolic expansions to constant - * values in the symbolic ADT *) + | V.SeLiteral _ -> + (* We do not *register* symbolic expansions to literal + * values in the symbolic ADT *) raise (Failure "Unreachable") | SeMutRef (_, nsv) | SeSharedRef (_, nsv) -> (* The (mut/shared) borrow type is extracted to identity: we thus simply - * introduce an reassignment *) + * introduce an reassignment *) let ctx, var = fresh_var_for_symbolic_value nsv ctx in let next_e = translate_expression e ctx in let monadic = false in @@ -2075,10 +2092,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) && !Config.always_deconstruct_adts_with_matches) -> (* There is exactly one branch: no branching. - We can decompose the ADT value with a let-binding, unless - the backend doesn't support this (see {!Config.always_deconstruct_adts_with_matches}): - we *ignore* this branch (and go to the next one) if the ADT is a custom - adt, and [always_deconstruct_adts_with_matches] is true. + We can decompose the ADT value with a let-binding, unless + the backend doesn't support this (see {!Config.always_deconstruct_adts_with_matches}): + we *ignore* this branch (and go to the next one) if the ADT is a custom + adt, and [always_deconstruct_adts_with_matches] is true. *) translate_ExpandAdt_one_branch sv scrutinee scrutinee_mplace variant_id svl branch ctx @@ -2127,14 +2144,14 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) let translate_branch ((v, branch_e) : V.scalar_value * S.expression) : match_branch = (* We don't need to update the context: we don't introduce any - * new values/variables *) + * new values/variables *) let branch = translate_expression branch_e ctx in - let pat = mk_typed_pattern_from_primitive_value (PV.Scalar v) in + let pat = mk_typed_pattern_from_literal (PV.Scalar v) in { pat; branch } in let branches = List.map translate_branch branches in let otherwise = translate_expression otherwise ctx in - let pat_ty = Integer int_ty in + let pat_ty = Literal (Integer int_ty) in let otherwise_pat : typed_pattern = { value = PatDummy; ty = pat_ty } in let otherwise : match_branch = { pat = otherwise_pat; branch = otherwise } @@ -2154,18 +2171,18 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) There are several possibilities: - if the ADT is an enumeration, we attempt to deconstruct it with a let-binding: - {[ - let Cons x0 ... xn = y in - ... - ]} + {[ + let Cons x0 ... xn = y in + ... + ]} - if the ADT is a structure, we attempt to introduce one let-binding per field: - {[ - let x0 = y.f0 in - ... + {[ + let x0 = y.f0 in + ... let xn = y.fn in ... - ]} + ]} Of course, this is not always possible depending on the backend. Also, recursive structures, and more specifically structures mutually recursive @@ -2179,14 +2196,14 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value) (branch : S.expression) (ctx : bs_ctx) : texpression = (* TODO: always introduce a match, and use micro-passes to turn the the match into a let? *) - let type_id, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in + let type_id, _, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in let ctx, vars = fresh_vars_for_symbolic_values svl ctx in let branch = translate_expression branch ctx in match type_id with | T.AdtId adt_id -> (* Detect if this is an enumeration or not *) let tdef = bs_ctx_lookup_llbc_type_decl adt_id ctx in - let is_enum = type_decl_is_enum tdef in + let is_enum = TypesUtils.type_decl_is_enum tdef in (* We deconstruct the ADT with a let-binding in two situations: - if the ADT is an enumeration (which must have exactly one branch) - if we forbid using field projectors. @@ -2214,14 +2231,10 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value) * field. * We use the [dest] variable in order not to have to recompute * the type of the result of the projection... *) - let adt_id, type_args = - match scrutinee.ty with - | Adt (adt_id, tys) -> (adt_id, tys) - | _ -> raise (Failure "Unreachable") - in + let adt_id, type_args, const_generic_args = ty_as_adt scrutinee.ty in let gen_field_proj (field_id : FieldId.id) (dest : var) : texpression = let proj_kind = { adt_id; field_id } in - let qualif = { id = Proj proj_kind; type_args } in + let qualif = { id = Proj proj_kind; type_args; const_generic_args } in let proj_e = Qualif qualif in let proj_ty = mk_arrow scrutinee.ty dest.ty in let proj = { e = proj_e; ty = proj_ty } in @@ -2253,12 +2266,12 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value) (mk_typed_pattern_from_var var None) (mk_opt_mplace_texpression scrutinee_mplace scrutinee) branch - | T.Assumed T.Vec -> - (* We can't expand vector values: we can access the fields only + | T.Assumed (T.Vec | T.Array | T.Slice | T.Str) -> + (* We can't expand those values: we can access the fields only * through the functions provided by the API (note that we don't - * know how to expand a vector, because it has a variable number + * know how to expand values like vectors or arrays, because they have a variable number * of fields!) *) - raise (Failure "Can't expand a vector value") + raise (Failure "Attempt to expand a non-expandable value") | T.Assumed T.Option -> (* We shouldn't get there in the "one-branch" case: options have * two variants *) @@ -2394,7 +2407,13 @@ and translate_forward_end (ectx : C.eval_ctx) let loop_call = let fun_id = Fun (FromLlbc (fid, Some loop_id, None)) in - let func = { id = FunOrOp fun_id; type_args = loop_info.type_args } in + let func = + { + id = FunOrOp fun_id; + type_args = loop_info.type_args; + const_generic_args = loop_info.const_generic_args; + } + in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in let ret_ty = if effect_info.can_fail then mk_result_ty out_pat.ty else out_pat.ty @@ -2515,7 +2534,12 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = (and will introduce the outputs at that moment, together with the actual call to the loop forward function *) let type_args = - List.map (fun ty -> TypeVar ty.T.index) ctx.sg.type_params + List.map (fun (ty : T.type_var) -> TypeVar ty.T.index) ctx.sg.type_params + in + let const_generic_args = + List.map + (fun (cg : T.const_generic_var) -> T.ConstGenericVar cg.T.index) + ctx.sg.const_generic_params in let loop_info = @@ -2524,6 +2548,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = input_vars = inputs; input_svl = loop.input_svalues; type_args; + const_generic_args; forward_inputs = None; forward_output_no_state_no_result = None; } @@ -2611,14 +2636,26 @@ let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression) *) (* Create the expression: [fuel0 = 0] *) let check_fuel = - let func = { id = FunOrOp (Fun (Pure FuelEqZero)); type_args = [] } in + let func = + { + id = FunOrOp (Fun (Pure FuelEqZero)); + type_args = []; + const_generic_args = []; + } + in let func_ty = mk_arrow mk_fuel_ty mk_bool_ty in let func = { e = Qualif func; ty = func_ty } in mk_app func fuel0 in (* Create the expression: [decrease fuel0] *) let decrease_fuel = - let func = { id = FunOrOp (Fun (Pure FuelDecrease)); type_args = [] } in + let func = + { + id = FunOrOp (Fun (Pure FuelDecrease)); + type_args = []; + const_generic_args = []; + } + in let func_ty = mk_arrow mk_fuel_ty mk_fuel_ty in let func = { e = Qualif func; ty = func_ty } in mk_app func fuel0 diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml index e2cdc726..857fea97 100644 --- a/compiler/SynthesizeSymbolic.ml +++ b/compiler/SynthesizeSymbolic.ml @@ -98,9 +98,9 @@ let synthesize_symbolic_expansion_no_branching (sv : V.symbolic_value) let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx) (abstractions : V.AbstractionId.id list) (type_params : T.ety list) - (args : V.typed_value list) (args_places : mplace option list) - (dest : V.symbolic_value) (dest_place : mplace option) - (e : expression option) : expression option = + (const_generic_params : T.const_generic list) (args : V.typed_value list) + (args_places : mplace option list) (dest : V.symbolic_value) + (dest_place : mplace option) (e : expression option) : expression option = Option.map (fun e -> let call = @@ -109,6 +109,7 @@ let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx) ctx; abstractions; type_params; + const_generic_params; args; dest; args_places; @@ -125,24 +126,25 @@ let synthesize_global_eval (gid : A.GlobalDeclId.id) (dest : V.symbolic_value) let synthesize_regular_function_call (fun_id : A.fun_id) (call_id : V.FunCallId.id) (ctx : Contexts.eval_ctx) (abstractions : V.AbstractionId.id list) (type_params : T.ety list) - (args : V.typed_value list) (args_places : mplace option list) - (dest : V.symbolic_value) (dest_place : mplace option) - (e : expression option) : expression option = + (const_generic_params : T.const_generic list) (args : V.typed_value list) + (args_places : mplace option list) (dest : V.symbolic_value) + (dest_place : mplace option) (e : expression option) : expression option = synthesize_function_call (Fun (fun_id, call_id)) - ctx abstractions type_params args args_places dest dest_place e + ctx abstractions type_params const_generic_params args args_places dest + dest_place e let synthesize_unary_op (ctx : Contexts.eval_ctx) (unop : E.unop) (arg : V.typed_value) (arg_place : mplace option) (dest : V.symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = - synthesize_function_call (Unop unop) ctx [] [] [ arg ] [ arg_place ] dest + synthesize_function_call (Unop unop) ctx [] [] [] [ arg ] [ arg_place ] dest dest_place e let synthesize_binary_op (ctx : Contexts.eval_ctx) (binop : E.binop) (arg0 : V.typed_value) (arg0_place : mplace option) (arg1 : V.typed_value) (arg1_place : mplace option) (dest : V.symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = - synthesize_function_call (Binop binop) ctx [] [] [ arg0; arg1 ] + synthesize_function_call (Binop binop) ctx [] [] [] [ arg0; arg1 ] [ arg0_place; arg1_place ] dest dest_place e let synthesize_end_abstraction (ctx : Contexts.eval_ctx) (abs : V.abs) -- cgit v1.2.3