From ac2d7f421b8511c67614eb238865961e239486c2 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 26 Jan 2022 15:56:47 +0100 Subject: Implement sanity checks to check the types of the input/output arguments given to backward functions --- src/Pure.ml | 31 +++++++++++++-- src/Substitute.ml | 5 ++- src/SymbolicToPure.ml | 102 +++++++++++++++++++++++++++++++++++++++++++++----- 3 files changed, 125 insertions(+), 13 deletions(-) (limited to 'src') diff --git a/src/Pure.ml b/src/Pure.ml index fd1e7763..9a16264d 100644 --- a/src/Pure.ml +++ b/src/Pure.ml @@ -34,7 +34,28 @@ type ty = | Str | Array of ty (* TODO: there should be a constant with the array *) | Slice of ty -[@@deriving show] +[@@deriving + show, + visitors + { + name = "iter_ty"; + variety = "iter"; + ancestors = [ "T.iter_ty_base" ]; + (* Reusing the visitor from Types.ml *) + nude = true (* Don't inherit [VisitorsRuntime.iter] *); + concrete = true; + polymorphic = false; + }, + visitors + { + name = "map_ty"; + variety = "map"; + ancestors = [ "T.map_ty_base" ]; + (* Reusing the visitor from Types.ml *) + nude = true (* Don't inherit [VisitorsRuntime.iter] *); + concrete = true; + polymorphic = false; + }] type field = { field_name : string; field_ty : ty } [@@deriving show] @@ -104,12 +125,14 @@ and typed_rvalue = { value : rvalue; ty : ty } about ADTs, though. *) +type unop = Not | Neg of T.integer_type + type fun_id = | Regular of A.fun_id * T.RegionGroupId.id option (** Backward id: `Some` if the function is a backward function, `None` if it is a forward function *) - | Unop of E.unop - | Binop of E.binop + | Unop of unop + | Binop of E.binop * T.integer_type type call = { func : fun_id; type_params : ty list; args : typed_rvalue list } @@ -189,6 +212,8 @@ type fun_sig = { *) } +type inst_fun_sig = { inputs : ty list; outputs : ty list } + type fun_def = { def_id : FunDefId.id; name : name; diff --git a/src/Substitute.ml b/src/Substitute.ml index 9db58812..01ce3a4e 100644 --- a/src/Substitute.ml +++ b/src/Substitute.ml @@ -9,7 +9,10 @@ module E = Expressions module A = CfimAst module C = Contexts -(** Substitute types variables and regions in a type *) +(** Substitute types variables and regions in a type. + + TODO: we can reimplement that with visitors. + *) let rec ty_substitute (rsubst : 'r1 -> 'r2) (tsubst : T.TypeVarId.id -> 'r2 T.ty) (ty : 'r1 T.ty) : 'r2 T.ty = let open T in diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index afca9398..a3aad6df 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -109,19 +109,67 @@ let mk_typed_lvalue_from_var (v : var) : typed_lvalue = let ty = v.ty in { value; ty } +let ty_as_integer (t : ty) : T.integer_type = + match t with Integer int_ty -> int_ty | _ -> failwith "Unreachable" + (* TODO: move *) let type_def_is_enum (def : T.type_def) : bool = match def.kind with T.Struct _ -> false | Enum _ -> true +(** Type substitution *) +let ty_substitute (tsubst : TypeVarId.id -> ty) (ty : ty) : ty = + let obj = + object + inherit [_] map_ty + + method! visit_TypeVar _ var_id = tsubst var_id + end + in + obj#visit_ty () ty + +let make_type_subst (vars : type_var list) (tys : ty list) : TypeVarId.id -> ty + = + let ls = List.combine vars tys in + let mp = + List.fold_left + (fun mp (k, v) -> TypeVarId.Map.add (k : type_var).index v mp) + TypeVarId.Map.empty ls + in + fun id -> TypeVarId.Map.find id mp + +let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) : + inst_fun_sig = + let subst = ty_substitute tsubst in + let inputs = List.map subst sg.inputs in + let outputs = List.map subst sg.outputs in + { inputs; outputs } + +type regular_fun_id = { fun_id : A.fun_id; back_id : T.RegionGroupId.id option } +[@@deriving show, ord] +(** We use this type as a key for lookups *) + +module RegularFunIdOrderedType = struct + type t = regular_fun_id + + let compare = compare_regular_fun_id + + let to_string = show_regular_fun_id + + let pp_t = pp_regular_fun_id + + let show_t = show_regular_fun_id +end + +module RegularFunIdMap = Collections.MakeMap (RegularFunIdOrderedType) + type type_context = { types_infos : TA.type_infos; cfim_type_defs : T.type_def TypeDefId.Map.t; - type_defs : type_def TypeDefId.Map.t; } type fun_context = { cfim_fun_defs : A.fun_def FunDefId.Map.t; - fun_defs : fun_def FunDefId.Map.t; + fun_sigs : fun_sig RegularFunIdMap.t; } type call_info = { @@ -199,8 +247,16 @@ let fs_ctx_to_bs_ctx (fs_ctx : fs_ctx) : bs_ctx = abstractions; } -(*let bs_ctx_lookup_type_def (id : TypeDefId.id) (ctx : bs_ctx) : type_def = - TypeDefId.Map.find id ctx.type_context.type_defs*) +let get_instantiated_fun_sig (fun_id : A.fun_id) + (back_id : T.RegionGroupId.id option) (tys : ty list) (ctx : bs_ctx) : + inst_fun_sig = + (* Lookup the non-instantiated function signature *) + let sg = RegularFunIdMap.find { fun_id; back_id } ctx.fun_context.fun_sigs in + (* Create the substitution *) + let tsubst = make_type_subst sg.type_params tys in + (* Apply *) + fun_sig_substitute tsubst sg + let bs_ctx_lookup_cfim_type_def (id : TypeDefId.id) (ctx : bs_ctx) : T.type_def = TypeDefId.Map.find id ctx.type_context.cfim_type_defs @@ -815,15 +871,28 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let args = List.map (typed_value_to_rvalue ctx) call.args in let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in (* Retrieve the function id, and register the function call in the context - * if necessary *) + * if necessary. *) let ctx, func = match call.call_id with | S.Fun (fid, call_id) -> let ctx = bs_ctx_register_forward_call call_id call ctx in let func = Regular (fid, None) in (ctx, func) - | S.Unop unop -> (ctx, Unop unop) - | S.Binop binop -> (ctx, Binop binop) + | S.Unop E.Not -> (ctx, Unop Not) + | S.Unop E.Neg -> ( + match args with + | [ arg ] -> + let int_ty = ty_as_integer arg.ty in + (ctx, Unop (Neg int_ty)) + | _ -> failwith "Unreachable") + | S.Binop binop -> ( + match args with + | [ arg0; arg1 ] -> + let int_ty0 = ty_as_integer arg0.ty in + let int_ty1 = ty_as_integer arg1.ty in + assert (int_ty0 = int_ty1); + (ctx, Binop (binop, int_ty0)) + | _ -> failwith "Unreachable") in let call = { func; type_params; args } in (* Translate the next expression *) @@ -856,11 +925,26 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : let inputs = List.concat [ fwd_inputs; back_ancestors_inputs; back_inputs ] in - (* TODO: check that the inputs have the proper number and the proper type *) (* Retrieve the values given back by this function: those are the output * values *) let ctx, outputs = abs_to_given_back abs ctx in - (* TODO: check that the outputs have the proper number and the proper type *) + (* Sanity check: the inputs and outputs have the proper number and the proper type *) + let fun_id = + match call.call_id with + | S.Fun (fun_id, _) -> fun_id + | Unop _ | Binop _ -> + (* Those don't have backward functions *) failwith "Unreachable" + in + + let inst_sg = + get_instantiated_fun_sig fun_id (Some abs.back_id) type_params ctx + in + List.iter + (fun (x, ty) -> assert ((x : typed_rvalue).ty = ty)) + (List.combine inputs inst_sg.inputs); + List.iter + (fun (x, ty) -> assert ((x : typed_lvalue).ty = ty)) + (List.combine outputs inst_sg.outputs); (* Retrieve the function id, and register the function call in the context * if necessary *) let ctx, func = bs_ctx_register_backward_call abs ctx in -- cgit v1.2.3