diff options
Diffstat (limited to 'src/SymbolicToPure.ml')
-rw-r--r-- | src/SymbolicToPure.ml | 74 |
1 files changed, 50 insertions, 24 deletions
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index 4e15d921..49bf3559 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -948,17 +948,26 @@ let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) : S.call * V.abs list = (** Small utility. - Return true if a function return type is monadic. - Always true, at the exception of some assumed functions. + Return: (function is monadic, function uses state monad) + + Note that all functions are monadic except some assumed functions. *) -let fun_is_monadic (fun_id : A.fun_id) : bool = +let fun_is_monadic (fun_id : A.fun_id) : bool * bool = match fun_id with - | A.Regular _ -> true - | A.Assumed aid -> Assumed.assumed_is_monadic aid + | A.Regular _ -> (true, true) + | A.Assumed aid -> (Assumed.assumed_is_monadic aid, false) + +(** Utility for function return types. -let mk_function_ret_ty (config : config) (monadic : bool) (out_ty : ty) : ty = + A function return type can have the shape: + - ty + - result ty (* error-monad *) + - state -> result (state & ty) (* state-error monad *) + *) +let mk_function_ret_ty (config : config) (monadic : bool) (state_monad : bool) + (out_ty : ty) : ty = if monadic then - if config.use_state_monad then + if config.use_state_monad && state_monad then let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in let ret = mk_arrow_ty mk_state_ty ret in ret @@ -969,19 +978,34 @@ let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx) : texpression = match e with | S.Return opt_v -> translate_return config opt_v ctx - | Panic -> - (* Here we use the function return type - note that it is ok because - * we don't match on panics which happen inside the function body - - * but it won't be true anymore once we translate individual blocks *) - let v = mk_result_fail_rvalue ctx.ret_ty in - let e = Value (v, None) in - let ty = v.ty in - { e; ty } + | Panic -> translate_panic config ctx | FunCall (call, e) -> translate_function_call config call e ctx | EndAbstraction (abs, e) -> translate_end_abstraction config abs e ctx | Expansion (p, sv, exp) -> translate_expansion config p sv exp ctx | Meta (meta, e) -> translate_meta config meta e ctx +and translate_panic (config : config) (ctx : bs_ctx) : texpression = + (* Here we use the function return type - note that it is ok because + * we don't match on panics which happen inside the function body - + * but it won't be true anymore once we translate individual blocks *) + (* If we use a state monad, we need to add a lambda for the state variable *) + if config.use_state_monad then + (* Create the `Fail` value *) + let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; ctx.ret_ty ] in + let v = mk_result_fail_rvalue ret_ty in + let e = Value (v, None) in + let ty = v.ty in + let e = { e; ty } in + (* Add the lambda *) + let _, state_var = fresh_var (Some "st") mk_state_ty ctx in + let state_lvalue = mk_typed_lvalue_from_var state_var None in + mk_abs state_lvalue e + else + let v = mk_result_fail_rvalue ctx.ret_ty in + let e = Value (v, None) in + let ty = v.ty in + { e; ty } + and translate_return (config : config) (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = (* There are two cases: @@ -1058,21 +1082,21 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in (* Retrieve the function id, and register the function call in the context * if necessary. *) - let ctx, func, monadic = + let ctx, func, monadic, state_monad = match call.call_id with | S.Fun (fid, call_id) -> let ctx = bs_ctx_register_forward_call call_id call ctx in let func = Regular (fid, None) in - let monadic = fun_is_monadic fid in - (ctx, func, monadic) - | S.Unop E.Not -> (ctx, Unop Not, false) + let monadic, state_monad = fun_is_monadic fid in + (ctx, func, monadic, state_monad) + | S.Unop E.Not -> (ctx, Unop Not, false, false) | S.Unop E.Neg -> ( match args with | [ arg ] -> let int_ty = ty_as_integer arg.ty in (* Note that negation can lead to an overflow and thus fail (it * is thus monadic) *) - (ctx, Unop (Neg int_ty), true) + (ctx, Unop (Neg int_ty), true, false) | _ -> failwith "Unreachable") | S.Binop binop -> ( match args with @@ -1081,7 +1105,7 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) let int_ty1 = ty_as_integer arg1.ty in assert (int_ty0 = int_ty1); let monadic = binop_can_fail binop in - (ctx, Binop (binop, int_ty0), monadic) + (ctx, Binop (binop, int_ty0), monadic, false) | _ -> failwith "Unreachable") in let args = @@ -1092,7 +1116,7 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) let dest_v = mk_typed_lvalue_from_var dest dest_mplace in let func = { func; type_params } in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in - let ret_ty = mk_function_ret_ty config monadic dest_v.ty in + let ret_ty = mk_function_ret_ty config monadic state_monad dest_v.ty in let func_ty = mk_arrows input_tys ret_ty in let func = { e = Func func; ty = func_ty } in let call = mk_apps func args in @@ -1223,9 +1247,9 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) (fun (arg, mp) -> mk_value_expression arg mp) (List.combine inputs args_mplaces) in - let monadic = fun_is_monadic fun_id in + let monadic, state_monad = fun_is_monadic fun_id in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in - let ret_ty = mk_function_ret_ty config monadic output.ty in + let ret_ty = mk_function_ret_ty config monadic state_monad output.ty in let func_ty = mk_arrows input_tys ret_ty in let func = { func; type_params } in let func = { e = Func func; ty = func_ty } in @@ -1444,7 +1468,9 @@ and translate_expansion (config : config) (p : S.mplace option) (* There should be at least one branch *) let branch = List.hd branches in let ty = branch.branch.ty in + (* Sanity check *) assert (List.for_all (fun br -> br.branch.ty = ty) branches); + (* Return *) { e; ty }) | ExpandBool (true_e, false_e) -> (* We don't need to update the context: we don't introduce any |