diff options
author | Son Ho | 2022-01-28 12:07:53 +0100 |
---|---|---|
committer | Son Ho | 2022-01-28 12:07:53 +0100 |
commit | bb9d21e658630315a7e83bfbdfb7a1b53e3bcc1a (patch) | |
tree | 0a13b80013d64b7df469d7d5ef3528cfeb00cfec /src | |
parent | a96c9e10cec6b8af30dd1c70214ec9b6db66645f (diff) |
Remove the Return and Fail variants from Pure.expression and add a
`monadic` boolean field to `Let`
Diffstat (limited to 'src')
-rw-r--r-- | src/PrintPure.ml | 48 | ||||
-rw-r--r-- | src/Pure.ml | 69 | ||||
-rw-r--r-- | src/PureMicroPasses.ml | 20 | ||||
-rw-r--r-- | src/SymbolicToPure.ml | 100 | ||||
-rw-r--r-- | src/Translate.ml | 22 |
5 files changed, 178 insertions, 81 deletions
diff --git a/src/PrintPure.ml b/src/PrintPure.ml index 98c832a1..77e01c65 100644 --- a/src/PrintPure.ml +++ b/src/PrintPure.ml @@ -127,23 +127,19 @@ let mk_ast_formatter (type_defs : T.type_def TypeDefId.Map.t) fun_def_id_to_string; } -let type_id_to_string (fmt : type_formatter) (id : T.type_id) : string = +let type_id_to_string (fmt : type_formatter) (id : type_id) : string = match id with - | T.AdtId id -> fmt.type_def_id_to_string id - | T.Tuple -> "" - | T.Assumed aty -> ( - match aty with - | Box -> - (* Boxes should have been eliminated *) - failwith "Unreachable: boxes should have been eliminated") + | AdtId id -> fmt.type_def_id_to_string id + | Tuple -> "" + | Assumed aty -> ( match aty with Result -> "Result") let rec ty_to_string (fmt : type_formatter) (ty : ty) : string = match ty with | Adt (id, tys) -> ( let tys = List.map (ty_to_string fmt) tys in match id with - | T.Tuple -> "(" ^ String.concat " * " tys ^ ")" - | T.AdtId _ | T.Assumed _ -> + | Tuple -> "(" ^ String.concat " * " tys ^ ")" + | AdtId _ | Assumed _ -> let tys = if tys = [] then "" else " " ^ String.concat " " tys in type_id_to_string fmt id ^ tys) | TypeVar tv -> fmt.type_var_id_to_string tv @@ -226,10 +222,10 @@ let adt_g_value_to_string (fmt : value_formatter) (field_values : 'v list) (ty : ty) : string = let field_values = List.map value_to_string field_values in match ty with - | Adt (T.Tuple, _) -> + | Adt (Tuple, _) -> (* Tuple *) "(" ^ String.concat ", " field_values ^ ")" - | Adt (T.AdtId def_id, _) -> + | Adt (AdtId def_id, _) -> (* "Regular" ADT *) let adt_ident = match variant_id with @@ -251,12 +247,19 @@ let adt_g_value_to_string (fmt : value_formatter) let field_values = String.concat " " field_values in adt_ident ^ " { " ^ field_values ^ " }" else adt_ident - | Adt (T.Assumed aty, _) -> ( + | Adt (Assumed aty, _) -> ( (* Assumed type *) match aty with - | Box -> - (* Box values should have been eliminated *) - failwith "Unreachable") + | Result -> + let variant_id = Option.get variant_id in + if variant_id = result_return_id then + match field_values with + | [ v ] -> "@Result::Return " ^ v + | _ -> failwith "Result::Return takes exactly one value" + else if variant_id = result_fail_id then ( + assert (field_values = []); + "@Result::Fail") + else failwith "Unreachable: improper variant id for result type") | _ -> failwith "Inconsistent typed value" let rec typed_lvalue_to_string (fmt : value_formatter) (v : typed_lvalue) : @@ -356,11 +359,10 @@ let meta_to_string (fmt : ast_formatter) (meta : meta) : string = let rec expression_to_string (fmt : ast_formatter) (indent : string) (indent_incr : string) (e : expression) : string = match e with - | Return v -> "return " ^ typed_rvalue_to_string fmt v - | Fail -> "fail" | Value (v, _) -> typed_rvalue_to_string fmt v | Call call -> call_to_string fmt indent indent_incr call - | Let (lv, re, e) -> let_to_string fmt indent indent_incr lv re e + | Let (monadic, lv, re, e) -> + let_to_string fmt indent indent_incr monadic lv re e | Switch (scrutinee, _, body) -> switch_to_string fmt indent indent_incr scrutinee body | Meta (meta, e) -> @@ -382,13 +384,15 @@ and call_to_string (fmt : ast_formatter) (indent : string) if all_args = [] then fun_id else fun_id ^ " " ^ String.concat " " all_args and let_to_string (fmt : ast_formatter) (indent : string) (indent_incr : string) - (lv : typed_lvalue) (re : expression) (e : expression) : string = + (monadic : bool) (lv : typed_lvalue) (re : expression) (e : expression) : + string = let indent1 = indent ^ indent_incr in let val_fmt = ast_to_value_formatter fmt in let re = expression_to_string fmt indent1 indent_incr re in let e = expression_to_string fmt indent indent_incr e in let lv = typed_lvalue_to_string val_fmt lv in - "let " ^ lv ^ " = " ^ re ^ " in\n" ^ indent ^ e + if monadic then lv ^ " <-- " ^ re ^ ";\n" ^ indent ^ e + else "let " ^ lv ^ " = " ^ re ^ " in\n" ^ indent ^ e and switch_to_string (fmt : ast_formatter) (indent : string) (indent_incr : string) (scrutinee : typed_rvalue) (body : switch_body) : @@ -410,7 +414,7 @@ and switch_to_string (fmt : ast_formatter) (indent : string) branches in let otherwise = - indent ^ "| _ ->\n" + indent ^ "| _ ->\n" ^ indent1 ^ expression_to_string fmt indent1 indent_incr otherwise in let all_branches = List.append branches [ otherwise ] in diff --git a/src/Pure.ml b/src/Pure.ml index cba0a1f4..375cdb0f 100644 --- a/src/Pure.ml +++ b/src/Pure.ml @@ -18,32 +18,57 @@ module SynthPhaseId = IdGen () module VarId = IdGen () (** Pay attention to the fact that we also define a [VarId] module in Values *) -(* TODO - (** The assumed types for the pure AST. +type assumed_ty = + | Result + (** The assumed types for the pure AST. - In comparison with CFIM: - - we removed `Box` (because it is translated as the identity: `Box T == T`) - - we added `Result`, which is the type used in the error monad. This allows - us to have a unified treatment of expressions. - *) - type assumed_ty = unit + In comparison with CFIM: + - we removed `Box` (because it is translated as the identity: `Box T == T`) + - we added `Result`, which is the type used in the error monad. This allows + us to have a unified treatment of expressions. + *) +[@@deriving show, ord] - type type_id = AdtId of TypeDefId.id | Tuple | Assumed of assumed_ty - [@@deriving show, ord] -*) +let result_return_id = VariantId.of_int 0 + +let result_fail_id = VariantId.of_int 1 + +type type_id = AdtId of TypeDefId.id | Tuple | Assumed of assumed_ty +[@@deriving show, ord] + +(** Ancestor for iter visitor for [ty] *) +class ['self] iter_ty_base = + object (_self : 'self) + inherit [_] VisitorsRuntime.iter + + method visit_id : 'env -> TypeVarId.id -> unit = fun _ _ -> () + + method visit_type_id : 'env -> type_id -> unit = fun _ _ -> () + + method visit_integer_type : 'env -> T.integer_type -> unit = fun _ _ -> () + end + +(** Ancestor for map visitor for [ty] *) +class ['self] map_ty_base = + object (_self : 'self) + inherit [_] VisitorsRuntime.map + + method visit_id : 'env -> TypeVarId.id -> TypeVarId.id = fun _ id -> id + + method visit_type_id : 'env -> type_id -> type_id = fun _ id -> id + + method visit_integer_type : 'env -> T.integer_type -> T.integer_type = + fun _ ity -> ity + end type ty = - | Adt of T.type_id * ty list + | Adt of type_id * ty list (** [Adt] encodes ADTs and tuples and assumed types. TODO: what about the ended regions? (ADTs may be parameterized with several region variables. When giving back an ADT value, we may be able to only give back part of the ADT. We need a way to encode such "partial" ADTs. - - TODO: we may want to redefine type_id here, to remove some types like - boxe. But on the other hand, it might introduce a lot of administrative - manipulations just to remove boxe... *) | TypeVar of TypeVarId.id | Bool @@ -58,8 +83,7 @@ type ty = { name = "iter_ty"; variety = "iter"; - ancestors = [ "T.iter_ty_base" ]; - (* Reusing the visitor from Types.ml *) + ancestors = [ "iter_ty_base" ]; nude = true (* Don't inherit [VisitorsRuntime.iter] *); concrete = true; polymorphic = false; @@ -68,8 +92,7 @@ type ty = { name = "map_ty"; variety = "map"; - ancestors = [ "T.map_ty_base" ]; - (* Reusing the visitor from Types.ml *) + ancestors = [ "map_ty_base" ]; nude = true (* Don't inherit [VisitorsRuntime.iter] *); concrete = true; polymorphic = false; @@ -319,14 +342,12 @@ class ['self] map_expression_base = TODO: remove `Return` and `Fail` (they should be "normal" values, I think) *) type expression = - | Return of typed_rvalue - | Fail | Value of typed_rvalue * mplace option | Call of call - | Let of typed_lvalue * expression * expression + | Let of bool * typed_lvalue * expression * expression (** Let binding. - TODO: add a boolean to control whether the let is monadic or not. + The boolean controls whether the let is monadic or not. For instance, in F*: - non-monadic: `let x = ... in ...` - monadic: `x <-- ...; ...` diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index 80d4e8bf..9f261386 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -153,10 +153,9 @@ let compute_pretty_names (def : fun_def) : fun_def = let rec update_expression (e : expression) (ctx : pn_ctx) : pn_ctx * expression = match e with - | Return _ | Fail -> (ctx, e) | Value (v, mp) -> update_value v mp ctx | Call call -> update_call call ctx - | Let (lb, re, e) -> update_let lb re e ctx + | Let (monadic, lb, re, e) -> update_let monadic lb re e ctx | Switch (scrut, mp, body) -> update_switch_body scrut mp body ctx | Meta (meta, e) -> update_meta meta e ctx (* *) @@ -174,13 +173,13 @@ let compute_pretty_names (def : fun_def) : fun_def = let call = { call with args } in (ctx, Call call) (* *) - and update_let (lv : typed_lvalue) (re : expression) (e : expression) - (ctx : pn_ctx) : pn_ctx * expression = + and update_let (monadic : bool) (lv : typed_lvalue) (re : expression) + (e : expression) (ctx : pn_ctx) : pn_ctx * expression = let ctx = add_left_constraint lv ctx in let ctx, re = update_expression re ctx in let ctx, e = update_expression e ctx in let lv = update_typed_lvalue ctx lv in - (ctx, Let (lv, re, e)) + (ctx, Let (monadic, lv, re, e)) (* *) and update_switch_body (scrut : typed_rvalue) (mp : mplace option) (body : switch_body) (ctx : pn_ctx) : pn_ctx * expression = @@ -265,8 +264,8 @@ let filter_unused_assignments (def : fun_def) : fun_def = (* TODO *) def -(** Add unit arguments for functions with no arguments *) -let add_unit_arguments (def : fun_def) : fun_def = +(** Add unit arguments for functions with no arguments, and change their return type. *) +let to_monadic (def : fun_def) : fun_def = (* TODO *) def @@ -309,9 +308,10 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_def) : fun_def = (* TODO: deconstruct the monadic bindings into matches *) - (* Add unit arguments for functions with no arguments *) - let def = add_unit_arguments def in - log#ldebug (lazy ("add_unit_arguments:\n" ^ fun_def_to_string ctx def)); + (* Add unit arguments for functions with no arguments, and change their return type. + * TODO: move that at the beginning? *) + let def = to_monadic def in + log#ldebug (lazy ("to_monadic:\n" ^ fun_def_to_string ctx def)); (* We are done *) def diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index 5c0250f7..d65e929f 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -15,9 +15,18 @@ module PP = PrintPure (** The local logger *) let log = L.symbolic_to_pure_log +(* TODO : move *) +let binop_can_fail (binop : E.binop) : bool = + match binop with + | BitXor | BitAnd | BitOr | Eq | Lt | Le | Ne | Ge | Gt -> false + | Div | Rem | Add | Sub | Mul -> true + | Shl | Shr -> raise Unimplemented + (* TODO: move *) let mk_place_from_var (v : var) : place = { var = v.id; projection = [] } +let mk_tuple_ty (tys : ty list) : ty = Adt (Tuple, tys) + let mk_typed_rvalue_from_var (v : var) : typed_rvalue = let value = RvPlace (mk_place_from_var v) in let ty = v.ty in @@ -31,7 +40,7 @@ let mk_typed_lvalue_from_var (v : var) (mp : mplace option) : typed_lvalue = let mk_tuple_lvalue (vl : typed_lvalue list) : typed_lvalue = let tys = List.map (fun (v : typed_lvalue) -> v.ty) vl in - let ty = Adt (T.Tuple, tys) in + let ty = Adt (Tuple, tys) in let value = LvAdt { variant_id = None; field_values = vl } in { value; ty } @@ -47,6 +56,18 @@ let ty_as_integer (t : ty) : T.integer_type = let type_def_is_enum (def : T.type_def) : bool = match def.kind with T.Struct _ -> false | Enum _ -> true +let mk_result_fail_rvalue (ty : ty) : typed_rvalue = + let ty = Adt (Assumed Result, [ ty ]) in + let value = RvAdt { variant_id = Some result_fail_id; field_values = [] } in + { value; ty } + +let mk_result_return_rvalue (v : typed_rvalue) : typed_rvalue = + let ty = Adt (Assumed Result, [ v.ty ]) in + let value = + RvAdt { variant_id = Some result_return_id; field_values = [ v ] } + in + { value; ty } + (** Type substitution *) let ty_substitute (tsubst : TypeVarId.id -> ty) (ty : ty) : ty = let obj = @@ -132,6 +153,7 @@ type bs_ctx = { fun_context : fun_context; fun_def : A.fun_def; bid : T.RegionGroupId.id option; (** TODO: rename *) + ret_ty : ty; (** The return type - we use it to translate `Panic` *) sv_to_var : var V.SymbolicValueId.Map.t; (** Whenever we encounter a new symbolic value (introduced because of a symbolic expansion or upon ending an abstraction, for instance) @@ -253,11 +275,17 @@ let bs_ctx_register_backward_call (abs : V.abs) (ctx : bs_ctx) : bs_ctx * fun_id let rec translate_sty (ty : T.sty) : ty = let translate = translate_sty in match ty with - | T.Adt (type_id, regions, tys) -> + | T.Adt (type_id, regions, tys) -> ( (* Can't translate types with regions for now *) assert (regions = []); let tys = List.map translate tys in - Adt (type_id, tys) + match type_id with + | T.AdtId adt_id -> Adt (AdtId adt_id, tys) + | T.Tuple -> Adt (Tuple, tys) + | T.Assumed T.Box -> ( + match tys with + | [ ty ] -> ty + | _ -> failwith "Box type with incorrect number of arguments")) | TypeVar vid -> TypeVar vid | Bool -> Bool | Char -> Char @@ -321,7 +349,8 @@ let rec translate_fwd_ty (types_infos : TA.type_infos) (ty : 'r T.ty) : ty = let tys = List.map translate tys in (* Eliminate boxes *) match type_id with - | T.AdtId _ | Tuple -> Adt (type_id, tys) + | AdtId adt_id -> Adt (AdtId adt_id, tys) + | Tuple -> Adt (Tuple, tys) | T.Assumed T.Box -> ( match tys with | [ bty ] -> bty @@ -363,6 +392,11 @@ let rec translate_back_ty (types_infos : TA.type_infos) | T.AdtId _ -> (* Don't accept ADTs (which are not tuples) with borrows for now *) assert (not (TypesUtils.ty_has_borrows types_infos ty)); + let type_id = + match type_id with + | T.AdtId id -> AdtId id + | T.Tuple | T.Assumed T.Box -> failwith "Unreachable" + in if inside_mut then let tys_t = List.filter_map translate tys in Some (Adt (type_id, tys_t)) @@ -378,7 +412,7 @@ let rec translate_back_ty (types_infos : TA.type_infos) | T.Tuple -> ( (* Tuples can contain borrows (which we eliminated) *) let tys_t = List.filter_map translate tys in - match tys_t with [] -> None | _ -> Some (Adt (T.Tuple, tys_t)))) + match tys_t with [] -> None | _ -> Some (Adt (Tuple, tys_t)))) | TypeVar vid -> wrap (TypeVar vid) | Bool -> wrap Bool | Char -> wrap Char @@ -886,7 +920,7 @@ let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) : S.call * V.abs list = let rec translate_expression (e : S.expression) (ctx : bs_ctx) : expression = match e with | S.Return opt_v -> translate_return opt_v ctx - | Panic -> Fail + | Panic -> Value (mk_result_fail_rvalue ctx.ret_ty, None) | FunCall (call, e) -> translate_function_call call e ctx | EndAbstraction (abs, e) -> translate_end_abstraction abs e ctx | Expansion (p, sv, exp) -> translate_expansion p sv exp ctx @@ -904,7 +938,7 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : expression (* Forward function *) let v = Option.get opt_v in let v = typed_value_to_rvalue ctx v in - Return v + Value (mk_result_return_rvalue v, None) | Some bid -> (* Backward function *) (* Sanity check *) @@ -918,9 +952,9 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : expression let field_values = List.map mk_typed_rvalue_from_var backward_outputs in let ret_value = RvAdt { variant_id = None; field_values } in let ret_tys = List.map (fun (v : typed_rvalue) -> v.ty) field_values in - let ret_ty = Adt (T.Tuple, ret_tys) in + let ret_ty = Adt (Tuple, ret_tys) in let ret_value : typed_rvalue = { value = ret_value; ty = ret_ty } in - Return ret_value + Value (mk_result_return_rvalue ret_value, None) and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : expression = @@ -932,18 +966,20 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in (* Retrieve the function id, and register the function call in the context * if necessary. *) - let ctx, func = + let ctx, func, monadic = match call.call_id with | S.Fun (fid, call_id) -> let ctx = bs_ctx_register_forward_call call_id call ctx in let func = Regular (fid, None) in - (ctx, func) - | S.Unop E.Not -> (ctx, Unop Not) + (ctx, func, true) + | S.Unop E.Not -> (ctx, Unop Not, false) | S.Unop E.Neg -> ( match args with | [ arg ] -> let int_ty = ty_as_integer arg.ty in - (ctx, Unop (Neg int_ty)) + (* Note that negation can lead to an overflow and thus fail (it + * is thus monadic) *) + (ctx, Unop (Neg int_ty), true) | _ -> failwith "Unreachable") | S.Binop binop -> ( match args with @@ -951,7 +987,8 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let int_ty0 = ty_as_integer arg0.ty in let int_ty1 = ty_as_integer arg1.ty in assert (int_ty0 = int_ty1); - (ctx, Binop (binop, int_ty0)) + let monadic = binop_can_fail binop in + (ctx, Binop (binop, int_ty0), monadic) | _ -> failwith "Unreachable") in let args = @@ -962,7 +999,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (* Translate the next expression *) let e = translate_expression e ctx in (* Put together *) - Let (mk_typed_lvalue_from_var dest dest_mplace, call, e) + Let (monadic, mk_typed_lvalue_from_var dest dest_mplace, call, e) and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : expression = @@ -1015,9 +1052,11 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : (* Translate the next expression *) let e = translate_expression e ctx in (* Generate the assignemnts *) + let monadic = false in List.fold_right (fun (var, value) e -> - Let (mk_typed_lvalue_from_var var None, Value (value, None), e)) + Let + (monadic, mk_typed_lvalue_from_var var None, Value (value, None), e)) variables_values e | V.FunCall -> let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in @@ -1078,7 +1117,8 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : (List.combine inputs args_mplaces) in let call = { func; type_params; args } in - Let (output, Call call, e) + let monadic = true in + Let (monadic, output, Call call, e) | V.SynthRet -> (* If we end the abstraction which consumed the return value of the function * we are synthesizing, we get back the borrows which were inside. Those borrows @@ -1129,9 +1169,14 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : (* Translate the next expression *) let e = translate_expression e ctx in (* Generate the assignments *) + let monadic = false in List.fold_right (fun (given_back, input_var) e -> - Let (given_back, Value (mk_typed_rvalue_from_var input_var, None), e)) + Let + ( monadic, + given_back, + Value (mk_typed_rvalue_from_var input_var, None), + e )) given_back_inputs e and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) @@ -1153,8 +1198,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) * introduce an reassignment *) let ctx, var = fresh_var_for_symbolic_value nsv ctx in let e = translate_expression e ctx in + let monadic = false in Let - ( mk_typed_lvalue_from_var var None, + ( monadic, + mk_typed_lvalue_from_var var None, Value (scrutinee, scrutinee_mplace), e ) | SeAdt _ -> @@ -1180,7 +1227,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) List.map (fun v -> mk_typed_lvalue_from_var v None) vars in let lv = mk_adt_lvalue scrutinee.ty variant_id lvars in - Let (lv, Value (scrutinee, scrutinee_mplace), branch) + let monadic = false in + Let (monadic, lv, Value (scrutinee, scrutinee_mplace), branch) else (* This is not an enumeration: introduce let-bindings for every * field. @@ -1197,11 +1245,13 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) { value; ty } in let id_var_pairs = FieldId.mapi (fun fid v -> (fid, v)) vars in + let monadic = false in List.fold_right (fun (fid, var) e -> let field_proj = gen_field_proj fid var in Let - ( mk_typed_lvalue_from_var var None, + ( monadic, + mk_typed_lvalue_from_var var None, Value (field_proj, None), e )) id_var_pairs branch @@ -1209,8 +1259,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) let vars = List.map (fun x -> mk_typed_lvalue_from_var x None) vars in + let monadic = false in Let - ( mk_tuple_lvalue vars, + ( monadic, + mk_tuple_lvalue vars, Value (scrutinee, scrutinee_mplace), branch ) | T.Assumed T.Box -> @@ -1220,8 +1272,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) in (* We simply introduce an assignment - the box type is the * identity when extracted (`box a == a`) *) + let monadic = false in Let - ( mk_typed_lvalue_from_var var None, + ( monadic, + mk_typed_lvalue_from_var var None, Value (scrutinee, scrutinee_mplace), branch )) | branches -> diff --git a/src/Translate.ml b/src/Translate.ml index d70c1486..63b6027e 100644 --- a/src/Translate.ml +++ b/src/Translate.ml @@ -77,6 +77,13 @@ let translate_function_to_pure (config : C.partial_config) (* Convert the symbolic ASTs to pure ASTs: *) (* Initialize the context *) + let module RegularFunIdMap = SymbolicToPure.RegularFunIdMap in + let forward_sig = RegularFunIdMap.find (A.Local def_id, None) fun_sigs in + let forward_ret_ty = + match forward_sig.sg.outputs with + | [ ty ] -> ty + | _ -> failwith "Unreachable" + in let sv_to_var = V.SymbolicValueId.Map.empty in let var_counter = Pure.VarId.generator_zero in let calls = V.FunCallId.Map.empty in @@ -94,6 +101,8 @@ let translate_function_to_pure (config : C.partial_config) { SymbolicToPure.bid = None; (* Dummy for now *) + ret_ty = forward_ret_ty; + (* Will need to be updated for the backward functions *) sv_to_var; var_counter; type_context; @@ -111,7 +120,6 @@ let translate_function_to_pure (config : C.partial_config) in (* We need to initialize the input/output variables *) - let module RegularFunIdMap = SymbolicToPure.RegularFunIdMap in let forward_input_vars = CfimAstUtils.fun_def_get_input_vars fdef in let forward_input_varnames = List.map (fun (v : A.var) -> v.name) forward_input_vars @@ -164,6 +172,10 @@ let translate_function_to_pure (config : C.partial_config) let ctx, backward_outputs = SymbolicToPure.fresh_vars backward_outputs ctx in + let backward_output_tys = + List.map (fun (v : Pure.var) -> v.ty) backward_outputs + in + let backward_ret_ty = SymbolicToPure.mk_tuple_ty backward_output_tys in let backward_inputs = T.RegionGroupId.Map.singleton back_id backward_inputs in @@ -173,7 +185,13 @@ let translate_function_to_pure (config : C.partial_config) (* Put everything in the context *) let ctx = - { ctx with bid = Some back_id; backward_inputs; backward_outputs } + { + ctx with + bid = Some back_id; + ret_ty = backward_ret_ty; + backward_inputs; + backward_outputs; + } in (* Translate *) |