summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2022-01-26 15:56:47 +0100
committerSon Ho2022-01-26 15:56:47 +0100
commitac2d7f421b8511c67614eb238865961e239486c2 (patch)
tree75e8141cb01a4175a387a2b1d0055f83815a514a
parentf10dff1e13c00eaa49d78f5d7ba79366fa028a73 (diff)
Implement sanity checks to check the types of the input/output arguments
given to backward functions
-rw-r--r--src/Pure.ml31
-rw-r--r--src/Substitute.ml5
-rw-r--r--src/SymbolicToPure.ml102
3 files changed, 125 insertions, 13 deletions
diff --git a/src/Pure.ml b/src/Pure.ml
index fd1e7763..9a16264d 100644
--- a/src/Pure.ml
+++ b/src/Pure.ml
@@ -34,7 +34,28 @@ type ty =
| Str
| Array of ty (* TODO: there should be a constant with the array *)
| Slice of ty
-[@@deriving show]
+[@@deriving
+ show,
+ visitors
+ {
+ name = "iter_ty";
+ variety = "iter";
+ ancestors = [ "T.iter_ty_base" ];
+ (* Reusing the visitor from Types.ml *)
+ nude = true (* Don't inherit [VisitorsRuntime.iter] *);
+ concrete = true;
+ polymorphic = false;
+ },
+ visitors
+ {
+ name = "map_ty";
+ variety = "map";
+ ancestors = [ "T.map_ty_base" ];
+ (* Reusing the visitor from Types.ml *)
+ nude = true (* Don't inherit [VisitorsRuntime.iter] *);
+ concrete = true;
+ polymorphic = false;
+ }]
type field = { field_name : string; field_ty : ty } [@@deriving show]
@@ -104,12 +125,14 @@ and typed_rvalue = { value : rvalue; ty : ty }
about ADTs, though.
*)
+type unop = Not | Neg of T.integer_type
+
type fun_id =
| Regular of A.fun_id * T.RegionGroupId.id option
(** Backward id: `Some` if the function is a backward function, `None`
if it is a forward function *)
- | Unop of E.unop
- | Binop of E.binop
+ | Unop of unop
+ | Binop of E.binop * T.integer_type
type call = { func : fun_id; type_params : ty list; args : typed_rvalue list }
@@ -189,6 +212,8 @@ type fun_sig = {
*)
}
+type inst_fun_sig = { inputs : ty list; outputs : ty list }
+
type fun_def = {
def_id : FunDefId.id;
name : name;
diff --git a/src/Substitute.ml b/src/Substitute.ml
index 9db58812..01ce3a4e 100644
--- a/src/Substitute.ml
+++ b/src/Substitute.ml
@@ -9,7 +9,10 @@ module E = Expressions
module A = CfimAst
module C = Contexts
-(** Substitute types variables and regions in a type *)
+(** Substitute types variables and regions in a type.
+
+ TODO: we can reimplement that with visitors.
+ *)
let rec ty_substitute (rsubst : 'r1 -> 'r2)
(tsubst : T.TypeVarId.id -> 'r2 T.ty) (ty : 'r1 T.ty) : 'r2 T.ty =
let open T in
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index afca9398..a3aad6df 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -109,19 +109,67 @@ let mk_typed_lvalue_from_var (v : var) : typed_lvalue =
let ty = v.ty in
{ value; ty }
+let ty_as_integer (t : ty) : T.integer_type =
+ match t with Integer int_ty -> int_ty | _ -> failwith "Unreachable"
+
(* TODO: move *)
let type_def_is_enum (def : T.type_def) : bool =
match def.kind with T.Struct _ -> false | Enum _ -> true
+(** Type substitution *)
+let ty_substitute (tsubst : TypeVarId.id -> ty) (ty : ty) : ty =
+ let obj =
+ object
+ inherit [_] map_ty
+
+ method! visit_TypeVar _ var_id = tsubst var_id
+ end
+ in
+ obj#visit_ty () ty
+
+let make_type_subst (vars : type_var list) (tys : ty list) : TypeVarId.id -> ty
+ =
+ let ls = List.combine vars tys in
+ let mp =
+ List.fold_left
+ (fun mp (k, v) -> TypeVarId.Map.add (k : type_var).index v mp)
+ TypeVarId.Map.empty ls
+ in
+ fun id -> TypeVarId.Map.find id mp
+
+let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) :
+ inst_fun_sig =
+ let subst = ty_substitute tsubst in
+ let inputs = List.map subst sg.inputs in
+ let outputs = List.map subst sg.outputs in
+ { inputs; outputs }
+
+type regular_fun_id = { fun_id : A.fun_id; back_id : T.RegionGroupId.id option }
+[@@deriving show, ord]
+(** We use this type as a key for lookups *)
+
+module RegularFunIdOrderedType = struct
+ type t = regular_fun_id
+
+ let compare = compare_regular_fun_id
+
+ let to_string = show_regular_fun_id
+
+ let pp_t = pp_regular_fun_id
+
+ let show_t = show_regular_fun_id
+end
+
+module RegularFunIdMap = Collections.MakeMap (RegularFunIdOrderedType)
+
type type_context = {
types_infos : TA.type_infos;
cfim_type_defs : T.type_def TypeDefId.Map.t;
- type_defs : type_def TypeDefId.Map.t;
}
type fun_context = {
cfim_fun_defs : A.fun_def FunDefId.Map.t;
- fun_defs : fun_def FunDefId.Map.t;
+ fun_sigs : fun_sig RegularFunIdMap.t;
}
type call_info = {
@@ -199,8 +247,16 @@ let fs_ctx_to_bs_ctx (fs_ctx : fs_ctx) : bs_ctx =
abstractions;
}
-(*let bs_ctx_lookup_type_def (id : TypeDefId.id) (ctx : bs_ctx) : type_def =
- TypeDefId.Map.find id ctx.type_context.type_defs*)
+let get_instantiated_fun_sig (fun_id : A.fun_id)
+ (back_id : T.RegionGroupId.id option) (tys : ty list) (ctx : bs_ctx) :
+ inst_fun_sig =
+ (* Lookup the non-instantiated function signature *)
+ let sg = RegularFunIdMap.find { fun_id; back_id } ctx.fun_context.fun_sigs in
+ (* Create the substitution *)
+ let tsubst = make_type_subst sg.type_params tys in
+ (* Apply *)
+ fun_sig_substitute tsubst sg
+
let bs_ctx_lookup_cfim_type_def (id : TypeDefId.id) (ctx : bs_ctx) : T.type_def
=
TypeDefId.Map.find id ctx.type_context.cfim_type_defs
@@ -815,15 +871,28 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
let args = List.map (typed_value_to_rvalue ctx) call.args in
let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
(* Retrieve the function id, and register the function call in the context
- * if necessary *)
+ * if necessary. *)
let ctx, func =
match call.call_id with
| S.Fun (fid, call_id) ->
let ctx = bs_ctx_register_forward_call call_id call ctx in
let func = Regular (fid, None) in
(ctx, func)
- | S.Unop unop -> (ctx, Unop unop)
- | S.Binop binop -> (ctx, Binop binop)
+ | S.Unop E.Not -> (ctx, Unop Not)
+ | S.Unop E.Neg -> (
+ match args with
+ | [ arg ] ->
+ let int_ty = ty_as_integer arg.ty in
+ (ctx, Unop (Neg int_ty))
+ | _ -> failwith "Unreachable")
+ | S.Binop binop -> (
+ match args with
+ | [ arg0; arg1 ] ->
+ let int_ty0 = ty_as_integer arg0.ty in
+ let int_ty1 = ty_as_integer arg1.ty in
+ assert (int_ty0 = int_ty1);
+ (ctx, Binop (binop, int_ty0))
+ | _ -> failwith "Unreachable")
in
let call = { func; type_params; args } in
(* Translate the next expression *)
@@ -856,11 +925,26 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
let inputs =
List.concat [ fwd_inputs; back_ancestors_inputs; back_inputs ]
in
- (* TODO: check that the inputs have the proper number and the proper type *)
(* Retrieve the values given back by this function: those are the output
* values *)
let ctx, outputs = abs_to_given_back abs ctx in
- (* TODO: check that the outputs have the proper number and the proper type *)
+ (* Sanity check: the inputs and outputs have the proper number and the proper type *)
+ let fun_id =
+ match call.call_id with
+ | S.Fun (fun_id, _) -> fun_id
+ | Unop _ | Binop _ ->
+ (* Those don't have backward functions *) failwith "Unreachable"
+ in
+
+ let inst_sg =
+ get_instantiated_fun_sig fun_id (Some abs.back_id) type_params ctx
+ in
+ List.iter
+ (fun (x, ty) -> assert ((x : typed_rvalue).ty = ty))
+ (List.combine inputs inst_sg.inputs);
+ List.iter
+ (fun (x, ty) -> assert ((x : typed_lvalue).ty = ty))
+ (List.combine outputs inst_sg.outputs);
(* Retrieve the function id, and register the function call in the context
* if necessary *)
let ctx, func = bs_ctx_register_backward_call abs ctx in