diff options
Diffstat (limited to 'src/SymbolicToPure.ml')
-rw-r--r-- | src/SymbolicToPure.ml | 108 |
1 files changed, 93 insertions, 15 deletions
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index 18e2b873..b25b7309 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -38,6 +38,12 @@ type config = { Note that we later filter the useless *forward* calls in the micro-passes, where it is more natural to do. *) + use_state_monad : bool; + (** If `true`, use a state-error monad. + If `false`, only use an error monad. + + Using a state-error monad is necessary when modelling I/O, for instance. + *) } type type_context = { @@ -950,6 +956,15 @@ let fun_is_monadic (fun_id : A.fun_id) : bool = | A.Regular _ -> true | A.Assumed aid -> Assumed.assumed_is_monadic aid +let mk_function_ret_ty (config : config) (monadic : bool) (out_ty : ty) : ty = + if monadic then + if config.use_state_monad then + let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in + let ret = mk_arrow_ty mk_state_ty ret in + ret + else mk_result_ty out_ty + else out_ty + let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx) : texpression = match e with @@ -967,7 +982,7 @@ let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx) | Expansion (p, sv, exp) -> translate_expansion config p sv exp ctx | Meta (meta, e) -> translate_meta config meta e ctx -and translate_return (_config : config) (opt_v : V.typed_value option) +and translate_return (config : config) (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = (* There are two cases: - either we are translating a forward function, in which case the optional @@ -979,13 +994,27 @@ and translate_return (_config : config) (opt_v : V.typed_value option) (* Forward function *) let v = Option.get opt_v in let v = typed_value_to_rvalue ctx v in - (* TODO: we need to use a `return` function (otherwise we have problems - * with the state-error monad). We also need to update the type when using - * a state-error monad. *) - let v = mk_result_return_rvalue v in - let e = Value (v, None) in - let ty = v.ty in - { e; ty } + (* We don't synthesize the same expression depending on the monad we use: + * - error-monad: Return x + * - state-error monad: fun state -> Return (state, x) + * *) + (* TODO: we should use a `return` function, it would be cleaner *) + if config.use_state_monad then + let _, state_var = fresh_var (Some "state") mk_state_ty ctx in + let state_rvalue = mk_typed_rvalue_from_var state_var in + let v = + mk_result_return_rvalue (mk_simpl_tuple_rvalue [ state_rvalue; v ]) + in + let e = Value (v, None) in + let ty = v.ty in + let e = { e; ty } in + let state_var = mk_typed_lvalue_from_var state_var None in + mk_abs state_var e + else + let v = mk_result_return_rvalue v in + let e = Value (v, None) in + let ty = v.ty in + { e; ty } | Some bid -> (* Backward function *) (* Sanity check *) @@ -997,11 +1026,27 @@ and translate_return (_config : config) (opt_v : V.typed_value option) T.RegionGroupId.Map.find bid ctx.backward_outputs in let field_values = List.map mk_typed_rvalue_from_var backward_outputs in - let ret_value = mk_simpl_tuple_rvalue field_values in - let ret_value = mk_result_return_rvalue ret_value in - let e = Value (ret_value, None) in - let ty = ret_value.ty in - { e; ty } + (* See the comment about the monads, for the forward function case *) + (* TODO: we should use a `fail` function, it would be cleaner *) + if config.use_state_monad then + let _, state_var = fresh_var (Some "state") mk_state_ty ctx in + let state_rvalue = mk_typed_rvalue_from_var state_var in + let ret_value = mk_simpl_tuple_rvalue field_values in + let ret_value = + mk_result_return_rvalue + (mk_simpl_tuple_rvalue [ state_rvalue; ret_value ]) + in + let e = Value (ret_value, None) in + let ty = ret_value.ty in + let e = { e; ty } in + let state_var = mk_typed_lvalue_from_var state_var None in + mk_abs state_var e + else + let ret_value = mk_simpl_tuple_rvalue field_values in + let ret_value = mk_result_return_rvalue ret_value in + let e = Value (ret_value, None) in + let ty = ret_value.ty in + { e; ty } and translate_function_call (config : config) (call : S.call) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -1047,7 +1092,7 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) let dest_v = mk_typed_lvalue_from_var dest dest_mplace in let func = { func; type_params } in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in - let ret_ty = if monadic then mk_result_ty dest_v.ty else dest_v.ty in + let ret_ty = mk_function_ret_ty config monadic dest_v.ty in let func_ty = mk_arrows input_tys ret_ty in let func = { e = Func func; ty = func_ty } in let call = mk_apps func args in @@ -1180,7 +1225,7 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) in let monadic = fun_is_monadic fun_id in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in - let ret_ty = if monadic then mk_result_ty output.ty else output.ty in + let ret_ty = mk_function_ret_ty config monadic output.ty in let func_ty = mk_arrows input_tys ret_ty in let func = { func; type_params } in let func = { e = Func func; ty = func_ty } in @@ -1470,6 +1515,7 @@ let translate_fun_decl (config : config) (ctx : bs_ctx) (* Translate the declaration *) let def_id = def.A.def_id in let basename = def.name in + (* Lookup the signature *) let signature = bs_ctx_lookup_local_function_sig def_id bid ctx in (* Translate the body, if there is *) let body = @@ -1502,6 +1548,38 @@ let translate_fun_decl (config : config) (ctx : bs_ctx) (List.combine inputs signature.inputs)); Some { inputs; inputs_lvs; body } in + (* Make the signature monadic *) + let output_ty = + match (bid, signature.outputs) with + | None, [ out_ty ] -> + (* Forward function: there is always exactly one output *) + (* We don't do the same thing if we use a state error monad or not: + * - error-monad: `result out_ty` + * - state-error: `state -> result (state & out_ty) + *) + if config.use_state_monad then + let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in + let ret = mk_arrow_ty mk_state_ty ret in + ret + else (* Simply wrap the type in `result` *) + mk_result_ty out_ty + | Some _, outputs -> + (* Backward function: we have to group the list of outputs into a tuple + * (and similarly to the forward function, we don't do the same thing + * if we use a state error monad or not): + * - error-monad: `result (out_ty1 & .. out_tyn)` + * - state-error: `state -> result (out_ty1 & .. out_tyn)` + *) + if config.use_state_monad then + let ret = mk_simpl_tuple_ty outputs in + let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; ret ]) in + let ret = mk_arrow_ty mk_state_ty ret in + ret + else mk_result_ty (mk_simpl_tuple_ty outputs) + | _ -> failwith "Unreachable" + in + let outputs = [ output_ty ] in + let signature = { signature with outputs } in (* Assemble the declaration *) let def = { def_id; back_id = bid; basename; signature; body } in (* Debugging *) |