summaryrefslogtreecommitdiff
path: root/src/SymbolicToPure.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/SymbolicToPure.ml')
-rw-r--r--src/SymbolicToPure.ml102
1 files changed, 93 insertions, 9 deletions
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index afca9398..a3aad6df 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -109,19 +109,67 @@ let mk_typed_lvalue_from_var (v : var) : typed_lvalue =
let ty = v.ty in
{ value; ty }
+let ty_as_integer (t : ty) : T.integer_type =
+ match t with Integer int_ty -> int_ty | _ -> failwith "Unreachable"
+
(* TODO: move *)
let type_def_is_enum (def : T.type_def) : bool =
match def.kind with T.Struct _ -> false | Enum _ -> true
+(** Type substitution *)
+let ty_substitute (tsubst : TypeVarId.id -> ty) (ty : ty) : ty =
+ let obj =
+ object
+ inherit [_] map_ty
+
+ method! visit_TypeVar _ var_id = tsubst var_id
+ end
+ in
+ obj#visit_ty () ty
+
+let make_type_subst (vars : type_var list) (tys : ty list) : TypeVarId.id -> ty
+ =
+ let ls = List.combine vars tys in
+ let mp =
+ List.fold_left
+ (fun mp (k, v) -> TypeVarId.Map.add (k : type_var).index v mp)
+ TypeVarId.Map.empty ls
+ in
+ fun id -> TypeVarId.Map.find id mp
+
+let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) :
+ inst_fun_sig =
+ let subst = ty_substitute tsubst in
+ let inputs = List.map subst sg.inputs in
+ let outputs = List.map subst sg.outputs in
+ { inputs; outputs }
+
+type regular_fun_id = { fun_id : A.fun_id; back_id : T.RegionGroupId.id option }
+[@@deriving show, ord]
+(** We use this type as a key for lookups *)
+
+module RegularFunIdOrderedType = struct
+ type t = regular_fun_id
+
+ let compare = compare_regular_fun_id
+
+ let to_string = show_regular_fun_id
+
+ let pp_t = pp_regular_fun_id
+
+ let show_t = show_regular_fun_id
+end
+
+module RegularFunIdMap = Collections.MakeMap (RegularFunIdOrderedType)
+
type type_context = {
types_infos : TA.type_infos;
cfim_type_defs : T.type_def TypeDefId.Map.t;
- type_defs : type_def TypeDefId.Map.t;
}
type fun_context = {
cfim_fun_defs : A.fun_def FunDefId.Map.t;
- fun_defs : fun_def FunDefId.Map.t;
+ fun_sigs : fun_sig RegularFunIdMap.t;
}
type call_info = {
@@ -199,8 +247,16 @@ let fs_ctx_to_bs_ctx (fs_ctx : fs_ctx) : bs_ctx =
abstractions;
}
-(*let bs_ctx_lookup_type_def (id : TypeDefId.id) (ctx : bs_ctx) : type_def =
- TypeDefId.Map.find id ctx.type_context.type_defs*)
+let get_instantiated_fun_sig (fun_id : A.fun_id)
+ (back_id : T.RegionGroupId.id option) (tys : ty list) (ctx : bs_ctx) :
+ inst_fun_sig =
+ (* Lookup the non-instantiated function signature *)
+ let sg = RegularFunIdMap.find { fun_id; back_id } ctx.fun_context.fun_sigs in
+ (* Create the substitution *)
+ let tsubst = make_type_subst sg.type_params tys in
+ (* Apply *)
+ fun_sig_substitute tsubst sg
+
let bs_ctx_lookup_cfim_type_def (id : TypeDefId.id) (ctx : bs_ctx) : T.type_def
=
TypeDefId.Map.find id ctx.type_context.cfim_type_defs
@@ -815,15 +871,28 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
let args = List.map (typed_value_to_rvalue ctx) call.args in
let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
(* Retrieve the function id, and register the function call in the context
- * if necessary *)
+ * if necessary. *)
let ctx, func =
match call.call_id with
| S.Fun (fid, call_id) ->
let ctx = bs_ctx_register_forward_call call_id call ctx in
let func = Regular (fid, None) in
(ctx, func)
- | S.Unop unop -> (ctx, Unop unop)
- | S.Binop binop -> (ctx, Binop binop)
+ | S.Unop E.Not -> (ctx, Unop Not)
+ | S.Unop E.Neg -> (
+ match args with
+ | [ arg ] ->
+ let int_ty = ty_as_integer arg.ty in
+ (ctx, Unop (Neg int_ty))
+ | _ -> failwith "Unreachable")
+ | S.Binop binop -> (
+ match args with
+ | [ arg0; arg1 ] ->
+ let int_ty0 = ty_as_integer arg0.ty in
+ let int_ty1 = ty_as_integer arg1.ty in
+ assert (int_ty0 = int_ty1);
+ (ctx, Binop (binop, int_ty0))
+ | _ -> failwith "Unreachable")
in
let call = { func; type_params; args } in
(* Translate the next expression *)
@@ -856,11 +925,26 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
let inputs =
List.concat [ fwd_inputs; back_ancestors_inputs; back_inputs ]
in
- (* TODO: check that the inputs have the proper number and the proper type *)
(* Retrieve the values given back by this function: those are the output
* values *)
let ctx, outputs = abs_to_given_back abs ctx in
- (* TODO: check that the outputs have the proper number and the proper type *)
+ (* Sanity check: the inputs and outputs have the proper number and the proper type *)
+ let fun_id =
+ match call.call_id with
+ | S.Fun (fun_id, _) -> fun_id
+ | Unop _ | Binop _ ->
+ (* Those don't have backward functions *) failwith "Unreachable"
+ in
+
+ let inst_sg =
+ get_instantiated_fun_sig fun_id (Some abs.back_id) type_params ctx
+ in
+ List.iter
+ (fun (x, ty) -> assert ((x : typed_rvalue).ty = ty))
+ (List.combine inputs inst_sg.inputs);
+ List.iter
+ (fun (x, ty) -> assert ((x : typed_lvalue).ty = ty))
+ (List.combine outputs inst_sg.outputs);
(* Retrieve the function id, and register the function call in the context
* if necessary *)
let ctx, func = bs_ctx_register_backward_call abs ctx in