diff options
Diffstat (limited to 'src/SymbolicToPure.ml')
-rw-r--r-- | src/SymbolicToPure.ml | 102 |
1 files changed, 93 insertions, 9 deletions
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index afca9398..a3aad6df 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -109,19 +109,67 @@ let mk_typed_lvalue_from_var (v : var) : typed_lvalue = let ty = v.ty in { value; ty } +let ty_as_integer (t : ty) : T.integer_type = + match t with Integer int_ty -> int_ty | _ -> failwith "Unreachable" + (* TODO: move *) let type_def_is_enum (def : T.type_def) : bool = match def.kind with T.Struct _ -> false | Enum _ -> true +(** Type substitution *) +let ty_substitute (tsubst : TypeVarId.id -> ty) (ty : ty) : ty = + let obj = + object + inherit [_] map_ty + + method! visit_TypeVar _ var_id = tsubst var_id + end + in + obj#visit_ty () ty + +let make_type_subst (vars : type_var list) (tys : ty list) : TypeVarId.id -> ty + = + let ls = List.combine vars tys in + let mp = + List.fold_left + (fun mp (k, v) -> TypeVarId.Map.add (k : type_var).index v mp) + TypeVarId.Map.empty ls + in + fun id -> TypeVarId.Map.find id mp + +let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) : + inst_fun_sig = + let subst = ty_substitute tsubst in + let inputs = List.map subst sg.inputs in + let outputs = List.map subst sg.outputs in + { inputs; outputs } + +type regular_fun_id = { fun_id : A.fun_id; back_id : T.RegionGroupId.id option } +[@@deriving show, ord] +(** We use this type as a key for lookups *) + +module RegularFunIdOrderedType = struct + type t = regular_fun_id + + let compare = compare_regular_fun_id + + let to_string = show_regular_fun_id + + let pp_t = pp_regular_fun_id + + let show_t = show_regular_fun_id +end + +module RegularFunIdMap = Collections.MakeMap (RegularFunIdOrderedType) + type type_context = { types_infos : TA.type_infos; cfim_type_defs : T.type_def TypeDefId.Map.t; - type_defs : type_def TypeDefId.Map.t; } type fun_context = { cfim_fun_defs : A.fun_def FunDefId.Map.t; - fun_defs : fun_def FunDefId.Map.t; + fun_sigs : fun_sig RegularFunIdMap.t; } type call_info = { @@ -199,8 +247,16 @@ let fs_ctx_to_bs_ctx (fs_ctx : fs_ctx) : bs_ctx = abstractions; } -(*let bs_ctx_lookup_type_def (id : TypeDefId.id) (ctx : bs_ctx) : type_def = - TypeDefId.Map.find id ctx.type_context.type_defs*) +let get_instantiated_fun_sig (fun_id : A.fun_id) + (back_id : T.RegionGroupId.id option) (tys : ty list) (ctx : bs_ctx) : + inst_fun_sig = + (* Lookup the non-instantiated function signature *) + let sg = RegularFunIdMap.find { fun_id; back_id } ctx.fun_context.fun_sigs in + (* Create the substitution *) + let tsubst = make_type_subst sg.type_params tys in + (* Apply *) + fun_sig_substitute tsubst sg + let bs_ctx_lookup_cfim_type_def (id : TypeDefId.id) (ctx : bs_ctx) : T.type_def = TypeDefId.Map.find id ctx.type_context.cfim_type_defs @@ -815,15 +871,28 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let args = List.map (typed_value_to_rvalue ctx) call.args in let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in (* Retrieve the function id, and register the function call in the context - * if necessary *) + * if necessary. *) let ctx, func = match call.call_id with | S.Fun (fid, call_id) -> let ctx = bs_ctx_register_forward_call call_id call ctx in let func = Regular (fid, None) in (ctx, func) - | S.Unop unop -> (ctx, Unop unop) - | S.Binop binop -> (ctx, Binop binop) + | S.Unop E.Not -> (ctx, Unop Not) + | S.Unop E.Neg -> ( + match args with + | [ arg ] -> + let int_ty = ty_as_integer arg.ty in + (ctx, Unop (Neg int_ty)) + | _ -> failwith "Unreachable") + | S.Binop binop -> ( + match args with + | [ arg0; arg1 ] -> + let int_ty0 = ty_as_integer arg0.ty in + let int_ty1 = ty_as_integer arg1.ty in + assert (int_ty0 = int_ty1); + (ctx, Binop (binop, int_ty0)) + | _ -> failwith "Unreachable") in let call = { func; type_params; args } in (* Translate the next expression *) @@ -856,11 +925,26 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : let inputs = List.concat [ fwd_inputs; back_ancestors_inputs; back_inputs ] in - (* TODO: check that the inputs have the proper number and the proper type *) (* Retrieve the values given back by this function: those are the output * values *) let ctx, outputs = abs_to_given_back abs ctx in - (* TODO: check that the outputs have the proper number and the proper type *) + (* Sanity check: the inputs and outputs have the proper number and the proper type *) + let fun_id = + match call.call_id with + | S.Fun (fun_id, _) -> fun_id + | Unop _ | Binop _ -> + (* Those don't have backward functions *) failwith "Unreachable" + in + + let inst_sg = + get_instantiated_fun_sig fun_id (Some abs.back_id) type_params ctx + in + List.iter + (fun (x, ty) -> assert ((x : typed_rvalue).ty = ty)) + (List.combine inputs inst_sg.inputs); + List.iter + (fun (x, ty) -> assert ((x : typed_lvalue).ty = ty)) + (List.combine outputs inst_sg.outputs); (* Retrieve the function id, and register the function call in the context * if necessary *) let ctx, func = bs_ctx_register_backward_call abs ctx in |