diff options
author | Son Ho | 2023-03-07 23:31:57 +0100 |
---|---|---|
committer | Son HO | 2023-06-04 21:44:33 +0200 |
commit | fa76f1b94e1f68d520b02c0dc1072cb73fa9d8be (patch) | |
tree | 6d301b14dc1909beff34691796c4abae88490408 /compiler | |
parent | a946df8b716695f4d387d852b7e74cf288ddb03e (diff) |
Add a special expression for structure creation/update
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/Config.ml | 1 | ||||
-rw-r--r-- | compiler/Extract.ml | 105 | ||||
-rw-r--r-- | compiler/PrintPure.ml | 19 | ||||
-rw-r--r-- | compiler/Pure.ml | 101 | ||||
-rw-r--r-- | compiler/PureMicroPasses.ml | 110 | ||||
-rw-r--r-- | compiler/PureTypeCheck.ml | 24 | ||||
-rw-r--r-- | compiler/PureUtils.ml | 2 | ||||
-rw-r--r-- | compiler/SymbolicToPure.ml | 261 | ||||
-rw-r--r-- | compiler/Translate.ml | 10 |
9 files changed, 451 insertions, 182 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml index 1baed7fa..15818938 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -118,7 +118,6 @@ let dont_use_field_projectors = ref false (** Deconstructing ADTs which have only one variant with let-bindings is not always supported: this parameter controls whether we use let-bindings in such situations or not. *) - let always_deconstruct_adts_with_matches = ref false (** Controls whether we need to use a state to model the external world diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 0b8d8bdf..8dd5910f 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -1617,6 +1617,7 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) | Let (_, _, _, _) -> extract_lets ctx fmt inside e | Switch (scrut, body) -> extract_Switch ctx fmt inside scrut body | Meta (_, e) -> extract_texpression ctx fmt inside e + | StructUpdate supd -> extract_StructUpdate ctx fmt inside supd | Loop _ -> (* The loop nodes should have been eliminated in {!PureMicroPasses} *) raise (Failure "Unreachable") @@ -1748,57 +1749,15 @@ and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) else ctx_get_variant adt_cons.adt_id vid ctx | None -> ctx_get_struct with_opaque_pre adt_cons.adt_id ctx in - let is_lean_struct = !backend = Lean && adt_cons.variant_id = None in - if is_lean_struct then ( - (* TODO: when only one or two fields differ, considering using the with - syntax (peephole optimization) *) - let decl_id = - match adt_cons.adt_id with AdtId id -> id | _ -> assert false - in - let def_kind = - (TypeDeclId.Map.find decl_id ctx.trans_ctx.type_context.type_decls) - .kind - in - let fields = - match def_kind with T.Struct fields -> fields | _ -> assert false - in - let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in - F.pp_open_hvbox fmt 0; - F.pp_open_hvbox fmt ctx.indent_incr; - F.pp_print_string fmt "{"; - F.pp_print_space fmt (); - F.pp_open_hvbox fmt ctx.indent_incr; - F.pp_open_hvbox fmt 0; - Collections.List.iter_link - (fun () -> - F.pp_print_string fmt ","; - F.pp_print_space fmt ()) - (fun ((fid, _), e) -> - F.pp_open_hvbox fmt ctx.indent_incr; - let f = ctx_get_field adt_cons.adt_id fid ctx in - F.pp_print_string fmt f; - F.pp_print_string fmt " := "; - F.pp_open_hvbox fmt ctx.indent_incr; - extract_texpression ctx fmt true e; - F.pp_close_box fmt (); - F.pp_close_box fmt ()) - (List.combine fields args); - F.pp_close_box fmt (); - F.pp_close_box fmt (); - F.pp_close_box fmt (); - F.pp_print_space fmt (); - F.pp_print_string fmt "}"; - F.pp_close_box fmt ()) - else - let use_parentheses = inside && args <> [] in - if use_parentheses then F.pp_print_string fmt "("; - F.pp_print_string fmt cons; - Collections.List.iter - (fun v -> - F.pp_print_space fmt (); - extract_texpression ctx fmt true v) - args; - if use_parentheses then F.pp_print_string fmt ")" + let use_parentheses = inside && args <> [] in + if use_parentheses then F.pp_print_string fmt "("; + F.pp_print_string fmt cons; + Collections.List.iter + (fun v -> + F.pp_print_space fmt (); + extract_texpression ctx fmt true v) + args; + if use_parentheses then F.pp_print_string fmt ")" (** Subcase of the app case: ADT field projector. *) and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter) @@ -2078,6 +2037,50 @@ and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (* Close the box for the whole expression *) F.pp_close_box fmt () +and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter) + (_inside : bool) (supd : struct_update) : unit = + (* We can't update a subset of the fields in Coq (i.e., we can do + [{| x:= 3; y := 4 |}], but there is no syntax for [{| s with x := 3 |}]) *) + assert (!backend <> Coq || supd.init = None); + F.pp_open_hvbox fmt 0; + F.pp_open_hvbox fmt ctx.indent_incr; + let lb, rb = + match !backend with Lean | FStar -> ("{", "}") | Coq -> ("{|", "|}") + in + F.pp_print_string fmt lb; + F.pp_print_space fmt (); + F.pp_open_hvbox fmt ctx.indent_incr; + if supd.init <> None then ( + let var_name = ctx_get_var (Option.get supd.init) ctx in + F.pp_print_string fmt var_name; + F.pp_print_space fmt (); + F.pp_print_string fmt "where"; + F.pp_print_space fmt ()); + F.pp_open_hvbox fmt 0; + let delimiter = match !backend with Lean -> "," | Coq | FStar -> ";" in + let assign = match !backend with Coq | Lean -> ":=" | FStar -> "=" in + Collections.List.iter_link + (fun () -> + F.pp_print_string fmt delimiter; + F.pp_print_space fmt ()) + (fun (fid, fe) -> + F.pp_open_hvbox fmt ctx.indent_incr; + let f = ctx_get_field (AdtId supd.struct_id) fid ctx in + F.pp_print_string fmt f; + F.pp_print_string fmt (" " ^ assign); + F.pp_print_space fmt (); + F.pp_open_hvbox fmt ctx.indent_incr; + extract_texpression ctx fmt true fe; + F.pp_close_box fmt (); + F.pp_close_box fmt ()) + supd.updates; + F.pp_close_box fmt (); + F.pp_close_box fmt (); + F.pp_close_box fmt (); + F.pp_print_space fmt (); + F.pp_print_string fmt rb; + F.pp_close_box fmt () + (** Insert a space, if necessary *) let insert_req_space (fmt : F.formatter) (space : bool ref) : unit = if !space then space := false else F.pp_print_space fmt () diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 2c8d5081..3f35a023 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -517,6 +517,25 @@ let rec texpression_to_string (fmt : ast_formatter) (inside : bool) | Loop loop -> let e = loop_to_string fmt indent indent_incr loop in if inside then "(" ^ e ^ ")" else e + | StructUpdate supd -> + let s = + match supd.init with + | None -> "" + | Some vid -> " " ^ fmt.var_id_to_string vid ^ " with" + in + let field_names = Option.get (fmt.adt_field_names supd.struct_id None) in + let indent1 = indent ^ indent_incr in + let indent2 = indent1 ^ indent_incr in + let fields = + List.map + (fun (fid, fe) -> + let field = FieldId.nth field_names fid in + let fe = texpression_to_string fmt false indent2 indent_incr fe in + "\n" ^ indent1 ^ field ^ " := " ^ fe ^ ";") + supd.updates + in + let bl = if fields = [] then "" else "\n" ^ indent in + "{" ^ s ^ String.concat "" fields ^ bl ^ "}" | Meta (meta, e) -> ( let meta_s = meta_to_string fmt meta in let e = texpression_to_string fmt inside indent indent_incr e in diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 5b2fca7d..4a00dfb2 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -66,27 +66,67 @@ let error_out_of_fuel_id = VariantId.of_int 1 let fuel_zero_id = VariantId.of_int 0 let fuel_succ_id = VariantId.of_int 1 -type type_id = AdtId of TypeDeclId.id | Tuple | Assumed of assumed_ty -[@@deriving show, ord] +type type_decl_id = TypeDeclId.id [@@deriving show, ord] +type type_var_id = TypeVarId.id [@@deriving show, ord] (** Ancestor for iter visitor for [ty] *) -class ['self] iter_ty_base = +class ['self] iter_type_id_base = object (_self : 'self) inherit [_] VisitorsRuntime.iter - method visit_id : 'env -> TypeVarId.id -> unit = fun _ _ -> () - method visit_type_id : 'env -> type_id -> unit = fun _ _ -> () + method visit_type_decl_id : 'env -> type_decl_id -> unit = fun _ _ -> () + method visit_assumed_ty : 'env -> assumed_ty -> unit = fun _ _ -> () + end + +(** Ancestor for map visitor for [ty] *) +class ['self] map_type_id_base = + object (_self : 'self) + inherit [_] VisitorsRuntime.map + + method visit_type_decl_id : 'env -> type_decl_id -> type_decl_id = + fun _ x -> x + + method visit_assumed_ty : 'env -> assumed_ty -> assumed_ty = fun _ x -> x + end + +type type_id = AdtId of type_decl_id | Tuple | Assumed of assumed_ty +[@@deriving + show, + ord, + visitors + { + name = "iter_type_id"; + variety = "iter"; + ancestors = [ "iter_type_id_base" ]; + nude = true (* Don't inherit {!VisitorsRuntime.iter} *); + concrete = true; + polymorphic = false; + }, + visitors + { + name = "map_type_id"; + variety = "map"; + ancestors = [ "map_type_id_base" ]; + nude = true (* Don't inherit {!VisitorsRuntime.iter} *); + concrete = true; + polymorphic = false; + }] + +(** Ancestor for iter visitor for [ty] *) +class ['self] iter_ty_base = + object (_self : 'self) + inherit [_] iter_type_id + method visit_type_var_id : 'env -> type_var_id -> unit = fun _ _ -> () method visit_integer_type : 'env -> integer_type -> unit = fun _ _ -> () end (** Ancestor for map visitor for [ty] *) class ['self] map_ty_base = object (_self : 'self) - inherit [_] VisitorsRuntime.map - method visit_id : 'env -> TypeVarId.id -> TypeVarId.id = fun _ id -> id - method visit_type_id : 'env -> type_id -> type_id = fun _ id -> id + inherit [_] map_type_id + method visit_type_var_id : 'env -> type_var_id -> type_var_id = fun _ x -> x method visit_integer_type : 'env -> integer_type -> integer_type = - fun _ ity -> ity + fun _ x -> x end type ty = @@ -98,7 +138,7 @@ type ty = be able to only give back part of the ADT. We need a way to encode such "partial" ADTs. *) - | TypeVar of TypeVarId.id + | TypeVar of type_var_id | Bool | Char | Integer of integer_type @@ -362,6 +402,7 @@ type qualif_id = *) type qualif = { id : qualif_id; type_args : ty list } [@@deriving show] +type field_id = FieldId.id [@@deriving show, ord] type var_id = VarId.id [@@deriving show, ord] (** Ancestor for {!iter_expression} visitor *) @@ -372,6 +413,8 @@ class ['self] iter_expression_base = method visit_var_id : 'env -> var_id -> unit = fun _ _ -> () method visit_qualif : 'env -> qualif -> unit = fun _ _ -> () method visit_loop_id : 'env -> loop_id -> unit = fun _ _ -> () + method visit_type_decl_id : 'env -> type_decl_id -> unit = fun _ _ -> () + method visit_field_id : 'env -> field_id -> unit = fun _ _ -> () end (** Ancestor for {!map_expression} visitor *) @@ -385,6 +428,11 @@ class ['self] map_expression_base = method visit_var_id : 'env -> var_id -> var_id = fun _ x -> x method visit_qualif : 'env -> qualif -> qualif = fun _ x -> x method visit_loop_id : 'env -> loop_id -> loop_id = fun _ x -> x + + method visit_type_decl_id : 'env -> type_decl_id -> type_decl_id = + fun _ x -> x + + method visit_field_id : 'env -> field_id -> field_id = fun _ x -> x end (** Ancestor for {!reduce_expression} visitor *) @@ -398,6 +446,11 @@ class virtual ['self] reduce_expression_base = method visit_var_id : 'env -> var_id -> 'a = fun _ _ -> self#zero method visit_qualif : 'env -> qualif -> 'a = fun _ _ -> self#zero method visit_loop_id : 'env -> loop_id -> 'a = fun _ _ -> self#zero + + method visit_type_decl_id : 'env -> type_decl_id -> 'a = + fun _ _ -> self#zero + + method visit_field_id : 'env -> field_id -> 'a = fun _ _ -> self#zero end (** Ancestor for {!mapreduce_expression} visitor *) @@ -416,6 +469,12 @@ class virtual ['self] mapreduce_expression_base = method visit_loop_id : 'env -> loop_id -> loop_id * 'a = fun _ x -> (x, self#zero) + + method visit_type_decl_id : 'env -> type_decl_id -> type_decl_id * 'a = + fun _ x -> (x, self#zero) + + method visit_field_id : 'env -> field_id -> field_id * 'a = + fun _ x -> (x, self#zero) end (** **Rk.:** here, {!expression} is not at all equivalent to the expressions @@ -477,6 +536,7 @@ type expression = *) | Switch of texpression * switch_body | Loop of loop (** See the comments for {!loop} *) + | StructUpdate of struct_update (** See the comments for {!struct_update} *) | Meta of (meta[@opaque]) * texpression (** Meta-information *) and switch_body = If of texpression * texpression | Match of match_branch list @@ -508,6 +568,27 @@ and loop = { loop_body : texpression; } +(** Structure creation/update. + + This expression is not strictly necessary, but allows for nice syntax, which + is important to work easily with the generated code. + + If {!init} is [None], it defines a structure creation: + {[ + { x := 3; y := true; } + ]} + + If {!init} is [Some], it defines a structure update: + {[ + { s with x := 3 } + ]} + *) +and struct_update = { + struct_id : type_decl_id; + init : var_id option; + updates : (field_id * texpression) list; +} + and texpression = { e : expression; ty : ty } (** Meta-value (converted to an expression). It is important that the content diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 3614487e..7e6ca822 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -388,6 +388,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl = | Let (monadic, lb, re, e) -> update_let monadic lb re e ctx | Switch (scrut, body) -> update_switch_body scrut body ctx | Loop loop -> update_loop loop ctx + | StructUpdate supd -> update_struct_update supd ctx | Meta (meta, e) -> update_meta meta e ctx in (ctx, { e; ty }) @@ -474,6 +475,18 @@ let compute_pretty_names (def : fun_decl) : fun_decl = } in (ctx, Loop loop) + and update_struct_update (supd : struct_update) (ctx : pn_ctx) : + pn_ctx * expression = + let { struct_id; init; updates } = supd in + let ctx, updates = + List.fold_left_map + (fun ctx (fid, fe) -> + let ctx, fe = update_texpression fe ctx in + (ctx, (fid, fe))) + ctx updates + in + let supd = { struct_id; init; updates } in + (ctx, StructUpdate supd) (* *) and update_meta (meta : meta) (e : texpression) (ctx : pn_ctx) : pn_ctx * expression = @@ -536,6 +549,89 @@ let remove_meta (def : fun_decl) : fun_decl = let body = { body with body = PureUtils.remove_meta body.body } in { def with body = Some body } +(** Introduce the special structure create/update expressions. + + Upon generating the pure code, we introduce structure values by using + the structure constructors: + {[ + Cons x0 ... xn + ]} + + This micro-pass turns those into expressions which use structure syntax: + {[ + { + f0 := x0; + ... + fn := xn; + } + ]} + *) +let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = + let obj = + object (self) + inherit [_] map_expression as super + + method! visit_texpression env (e : texpression) = + match e.e with + | App _ -> ( + let app, args = destruct_apps e in + let ignore () = + mk_apps + (self#visit_texpression env app) + (List.map (self#visit_texpression env) args) + in + match app.e with + | Qualif + { + id = AdtCons { adt_id = AdtId adt_id; variant_id = None }; + type_args = _; + } -> + (* Lookup the def *) + let decl = + TypeDeclId.Map.find adt_id ctx.type_context.type_decls + in + (* Check that there are as many arguments as there are fields - note + that the def should have a body (otherwise we couldn't use the + constructor) *) + let fields = TypesUtils.type_decl_get_fields decl None in + if List.length fields = List.length args then + (* Check if the definition is recursive *) + let is_rec = + match + TypeDeclId.Map.find adt_id + ctx.type_context.type_decls_groups + with + | NonRec _ -> false + | Rec _ -> true + in + (* Convert, if possible - note that for now for Lean and Coq + we don't support the structure syntax on recursive structures *) + if + (!Config.backend <> Lean && !Config.backend <> Coq) + || not is_rec + then + let struct_id = adt_id in + let init = None in + let updates = + FieldId.mapi + (fun fid fe -> (fid, self#visit_texpression env fe)) + args + in + let ne = { struct_id; init; updates } in + let nty = e.ty in + { e = StructUpdate ne; ty = nty } + else ignore () + else ignore () + | _ -> ignore ()) + | _ -> super#visit_texpression env e + end + in + match def.body with + | None -> def + | Some body -> + let body = { body with body = obj#visit_texpression () body.body } in + { def with body = Some body } + (** Inline the useless variable (re-)assignments: A lot of intermediate variable assignments are introduced through the @@ -604,6 +700,7 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) true (* primitive function call *) | FunOrOp (Fun _) -> false (* non-primitive function call *) | _ -> false) + | StructUpdate _ -> true (* ADT constructor *) | _ -> false in let filter = @@ -737,6 +834,12 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) method! visit_texpression env e = match e.e with | Var _ | Const _ -> fun _ -> false + | StructUpdate _ -> + (* There shouldn't be monadic calls in structure updates - also + note that by returning [false] we are conservative: we might + *prevent* possible optimisations (i.e., filtering some function + calls), which is sound. *) + fun _ -> false | Let (_, _, re, e) -> ( match opt_destruct_function_call re with | None -> fun () -> self#visit_texpression env e () @@ -829,7 +932,7 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) | Var _ | Const _ | App _ | Qualif _ | Switch (_, _) | Meta (_, _) - | Abs _ -> + | StructUpdate _ | Abs _ -> super#visit_expression env e | Let (monadic, lv, re, e) -> (* Compute the set of values used in the next expression *) @@ -1602,6 +1705,11 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = log#ldebug (lazy ("unit_vars_to_unit:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + (* Introduce the special structure create/update expressions *) + let def = intro_struct_updates ctx def in + log#ldebug + (lazy ("intro_struct_updates:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + (* Inline the useless variable reassignments *) let inline_named_vars = true in let inline_pure = true in diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index 1871f1bc..018ea6b5 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -194,6 +194,30 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = assert (Option.is_some loop.back_output_tys || loop.loop_body.ty = e.ty); check_texpression ctx loop.fun_end; check_texpression ctx loop.loop_body + | StructUpdate supd -> + (* Check the init value *) + (if Option.is_some supd.init then + match VarId.Map.find_opt (Option.get supd.init) ctx.env with + | None -> () + | Some ty -> assert (ty = e.ty)); + (* Check the fields *) + (* Retrieve and check the expected field type *) + let adt_id, adt_type_args = + match e.ty with + | Adt (type_id, tys) -> (type_id, tys) + | _ -> raise (Failure "Unreachable") + in + assert (adt_id = AdtId supd.struct_id); + let variant_id = None in + let expected_field_tys = + get_adt_field_types ctx.type_decls adt_id variant_id adt_type_args + in + List.iter + (fun (fid, fe) -> + let expected_field_ty = FieldId.nth expected_field_tys fid in + assert (expected_field_ty = fe.ty); + check_texpression ctx fe) + supd.updates | Meta (_, e_next) -> assert (e_next.ty = e.ty); check_texpression ctx e_next diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 40005671..1f5d1ed8 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -157,7 +157,7 @@ let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) : *) let rec let_group_requires_parentheses (e : texpression) : bool = match e.e with - | Var _ | Const _ | App _ | Abs _ | Qualif _ -> false + | Var _ | Const _ | App _ | Abs _ | Qualif _ | StructUpdate _ -> false | Let (monadic, _, _, next_e) -> if monadic then true else let_group_requires_parentheses next_e | Switch (_, _) -> false diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 2c103177..5252495d 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -22,7 +22,8 @@ type type_context = { This map is empty when we translate the types, then contains all the translated types when we translate the functions. *) - types_infos : TA.type_infos; (* TODO: rename to type_infos *) + type_infos : TA.type_infos; + recursive_decls : T.TypeDeclId.Set.t; } [@@deriving show] @@ -451,8 +452,8 @@ let translate_type_decl (def : T.type_decl) : type_decl = (preserve all borrows, etc.) *) -let rec translate_fwd_ty (types_infos : TA.type_infos) (ty : 'r T.ty) : ty = - let translate = translate_fwd_ty types_infos in +let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty = + let translate = translate_fwd_ty type_infos in match ty with | T.Adt (type_id, regions, tys) -> ( (* Can't translate types with regions for now *) @@ -463,7 +464,7 @@ let rec translate_fwd_ty (types_infos : TA.type_infos) (ty : 'r T.ty) : ty = match type_id with | AdtId _ | T.Assumed (T.Vec | T.Option) -> (* No general parametricity for now *) - assert (not (List.exists (TypesUtils.ty_has_borrows types_infos) tys)); + assert (not (List.exists (TypesUtils.ty_has_borrows type_infos) tys)); let type_id = match type_id with | AdtId adt_id -> AdtId adt_id @@ -479,7 +480,7 @@ let rec translate_fwd_ty (types_infos : TA.type_infos) (ty : 'r T.ty) : ty = | T.Assumed T.Box -> ( (* We eliminate boxes *) (* No general parametricity for now *) - assert (not (List.exists (TypesUtils.ty_has_borrows types_infos) tys)); + assert (not (List.exists (TypesUtils.ty_has_borrows type_infos) tys)); match t_tys with | [ bty ] -> bty | _ -> @@ -494,17 +495,17 @@ let rec translate_fwd_ty (types_infos : TA.type_infos) (ty : 'r T.ty) : ty = | Integer int_ty -> Integer int_ty | Str -> Str | Array ty -> - assert (not (TypesUtils.ty_has_borrows types_infos ty)); + assert (not (TypesUtils.ty_has_borrows type_infos ty)); Array (translate ty) | Slice ty -> - assert (not (TypesUtils.ty_has_borrows types_infos ty)); + assert (not (TypesUtils.ty_has_borrows type_infos ty)); Slice (translate ty) | Ref (_, rty, _) -> translate rty (** Simply calls [translate_fwd_ty] *) let ctx_translate_fwd_ty (ctx : bs_ctx) (ty : 'r T.ty) : ty = - let types_infos = ctx.type_context.types_infos in - translate_fwd_ty types_infos ty + let type_infos = ctx.type_context.type_infos in + translate_fwd_ty type_infos ty (** Translate a type, when some regions may have ended. @@ -512,9 +513,9 @@ let ctx_translate_fwd_ty (ctx : bs_ctx) (ty : 'r T.ty) : ty = [inside_mut]: are we inside a mutable borrow? *) -let rec translate_back_ty (types_infos : TA.type_infos) +let rec translate_back_ty (type_infos : TA.type_infos) (keep_region : 'r -> bool) (inside_mut : bool) (ty : 'r T.ty) : ty option = - let translate = translate_back_ty types_infos keep_region inside_mut in + let translate = translate_back_ty type_infos keep_region inside_mut in (* A small helper for "leave" types *) let wrap ty = if inside_mut then Some ty else None in match ty with @@ -522,7 +523,7 @@ let rec translate_back_ty (types_infos : TA.type_infos) match type_id with | T.AdtId _ | Assumed (T.Vec | T.Option) -> (* Don't accept ADTs (which are not tuples) with borrows for now *) - assert (not (TypesUtils.ty_has_borrows types_infos ty)); + assert (not (TypesUtils.ty_has_borrows type_infos ty)); let type_id = match type_id with | T.AdtId id -> AdtId id @@ -536,7 +537,7 @@ let rec translate_back_ty (types_infos : TA.type_infos) else None | Assumed T.Box -> ( (* Don't accept ADTs (which are not tuples) with borrows for now *) - assert (not (TypesUtils.ty_has_borrows types_infos ty)); + assert (not (TypesUtils.ty_has_borrows type_infos ty)); (* Eliminate the box *) match tys with | [ bty ] -> translate bty @@ -560,10 +561,10 @@ let rec translate_back_ty (types_infos : TA.type_infos) | Integer int_ty -> wrap (Integer int_ty) | Str -> wrap Str | Array ty -> ( - assert (not (TypesUtils.ty_has_borrows types_infos ty)); + assert (not (TypesUtils.ty_has_borrows type_infos ty)); match translate ty with None -> None | Some ty -> Some (Array ty)) | Slice ty -> ( - assert (not (TypesUtils.ty_has_borrows types_infos ty)); + assert (not (TypesUtils.ty_has_borrows type_infos ty)); match translate ty with None -> None | Some ty -> Some (Slice ty)) | Ref (r, rty, rkind) -> ( match rkind with @@ -574,14 +575,14 @@ let rec translate_back_ty (types_infos : TA.type_infos) (* Dive in, remembering the fact that we are inside a mutable borrow *) let inside_mut = true in if keep_region r then - translate_back_ty types_infos keep_region inside_mut rty + translate_back_ty type_infos keep_region inside_mut rty else None) (** Simply calls [translate_back_ty] *) let ctx_translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool) (inside_mut : bool) (ty : 'r T.ty) : ty option = - let types_infos = ctx.type_context.types_infos in - translate_back_ty types_infos keep_region inside_mut ty + let type_infos = ctx.type_context.type_infos in + translate_back_ty type_infos keep_region inside_mut ty (** List the ancestors of an abstraction *) let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs) @@ -670,7 +671,7 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) of the forward function) which we use as hints to generate pretty names. *) let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) - (fun_id : A.fun_id) (types_infos : TA.type_infos) (sg : A.fun_sig) + (fun_id : A.fun_id) (type_infos : TA.type_infos) (sg : A.fun_sig) (input_names : string option list) (bid : T.RegionGroupId.id option) : fun_sig_named_outputs = (* Retrieve the list of parent backward functions *) @@ -691,7 +692,7 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) * - the current backward function (if it is a backward function) *) let fuel = mk_fuel_input_ty_as_list effect_info in - let fwd_inputs = List.map (translate_fwd_ty types_infos) sg.inputs in + let fwd_inputs = List.map (translate_fwd_ty type_infos) sg.inputs in (* For the backward functions: for now we don't supported nested borrows, * so just check that there aren't parent regions *) assert (T.RegionGroupId.Set.is_empty parents); @@ -706,7 +707,7 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) | T.Var r -> T.RegionVarId.Set.mem r regions in let inside_mut = false in - translate_back_ty types_infos keep_region inside_mut + translate_back_ty type_infos keep_region inside_mut in (* Compute the additinal inputs for the current function, if it is a backward * function *) @@ -762,7 +763,7 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) match gid with | None -> (* This is a forward function: there is one (unnamed) output *) - ([ None ], [ translate_fwd_ty types_infos sg.output ]) + ([ None ], [ translate_fwd_ty type_infos sg.output ]) | Some gid -> (* This is a backward function: there might be several outputs. The outputs are the borrows inside the regions of the abstractions @@ -2057,11 +2058,9 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) match branches with | [] -> raise (Failure "Unreachable") | [ (variant_id, svl, branch) ] - (* TODO: always introduce a match, and use micro-passes to turn the - the match into a let *) when not (TypesUtils.ty_is_custom_adt sv.V.sv_ty - && !Config.always_deconstruct_adts_with_matches) -> ( + && !Config.always_deconstruct_adts_with_matches) -> (* There is exactly one branch: no branching. We can decompose the ADT value with a let-binding, unless @@ -2069,94 +2068,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) we *ignore* this branch (and go to the next one) if the ADT is a custom adt, and [always_deconstruct_adts_with_matches] is true. *) - let type_id, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in - let ctx, vars = fresh_vars_for_symbolic_values svl ctx in - let branch = translate_expression branch ctx in - match type_id with - | T.AdtId adt_id -> - (* Detect if this is an enumeration or not *) - let tdef = bs_ctx_lookup_llbc_type_decl adt_id ctx in - let is_enum = type_decl_is_enum tdef in - (* We deconstruct the ADT with a let-binding in two situations: - - if the ADT is an enumeration (which must have exactly one branch) - - if we forbid using field projectors. - - We forbid using field projectors in some situations, for example - if the backend is Coq. See '!Config.dont_use_field_projectors}. - *) - let use_let = is_enum || !Config.dont_use_field_projectors in - if use_let then - (* Introduce a let binding which expands the ADT *) - let lvars = - List.map (fun v -> mk_typed_pattern_from_var v None) vars - in - let lv = mk_adt_pattern scrutinee.ty variant_id lvars in - let monadic = false in - - mk_let monadic lv - (mk_opt_mplace_texpression scrutinee_mplace scrutinee) - branch - else - (* This is not an enumeration: introduce let-bindings for every - * field. - * We use the [dest] variable in order not to have to recompute - * the type of the result of the projection... *) - let adt_id, type_args = - match scrutinee.ty with - | Adt (adt_id, tys) -> (adt_id, tys) - | _ -> raise (Failure "Unreachable") - in - let gen_field_proj (field_id : FieldId.id) (dest : var) : - texpression = - let proj_kind = { adt_id; field_id } in - let qualif = { id = Proj proj_kind; type_args } in - let proj_e = Qualif qualif in - let proj_ty = mk_arrow scrutinee.ty dest.ty in - let proj = { e = proj_e; ty = proj_ty } in - mk_app proj scrutinee - in - let id_var_pairs = FieldId.mapi (fun fid v -> (fid, v)) vars in - let monadic = false in - List.fold_right - (fun (fid, var) e -> - let field_proj = gen_field_proj fid var in - mk_let monadic - (mk_typed_pattern_from_var var None) - field_proj e) - id_var_pairs branch - | T.Tuple -> - let vars = - List.map (fun x -> mk_typed_pattern_from_var x None) vars - in - let monadic = false in - mk_let monadic - (mk_simpl_tuple_pattern vars) - (mk_opt_mplace_texpression scrutinee_mplace scrutinee) - branch - | T.Assumed T.Box -> - (* There should be exactly one variable *) - let var = - match vars with - | [ v ] -> v - | _ -> raise (Failure "Unreachable") - in - (* We simply introduce an assignment - the box type is the - * identity when extracted ([box a = a]) *) - let monadic = false in - mk_let monadic - (mk_typed_pattern_from_var var None) - (mk_opt_mplace_texpression scrutinee_mplace scrutinee) - branch - | T.Assumed T.Vec -> - (* We can't expand vector values: we can access the fields only - * through the functions provided by the API (note that we don't - * know how to expand a vector, because it has a variable number - * of fields!) *) - raise (Failure "Can't expand a vector value") - | T.Assumed T.Option -> - (* We shouldn't get there in the "one-branch" case: options have - * two variants *) - raise (Failure "Unreachable")) + translate_ExpandAdt_one_branch sv scrutinee scrutinee_mplace + variant_id svl branch ctx | branches -> let translate_branch (variant_id : T.VariantId.id option) (svl : V.symbolic_value list) (branch : S.expression) : @@ -2225,6 +2138,120 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) List.for_all (fun (br : match_branch) -> br.branch.ty = ty) branches); { e; ty } +(* Translate and [ExpandAdt] when there is no branching (i.e., one branch). + + There are several possibilities: + - if the ADT is an enumeration, we attempt to deconstruct it with a let-binding: + {[ + let Cons x0 ... xn = y in + ... + ]} + + - if the ADT is a structure, we attempt to introduce one let-binding per field: + {[ + let x0 = y.f0 in + ... + let xn = y.fn in + ... + ]} + + Of course, this is not always possible depending on the backend. + Also, recursive structures, and more specifically structures mutually recursive + with inductives, are usually not supported. We define such recursive structures + as inductives, in which case it is not always possible to use a notation + for the field projections. +*) +and translate_ExpandAdt_one_branch (sv : V.symbolic_value) + (scrutinee : texpression) (scrutinee_mplace : mplace option) + (variant_id : variant_id option) (svl : V.symbolic_value list) + (branch : S.expression) (ctx : bs_ctx) : texpression = + (* TODO: always introduce a match, and use micro-passes to turn the + the match into a let? *) + let type_id, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in + let ctx, vars = fresh_vars_for_symbolic_values svl ctx in + let branch = translate_expression branch ctx in + match type_id with + | T.AdtId adt_id -> + (* Detect if this is an enumeration or not *) + let tdef = bs_ctx_lookup_llbc_type_decl adt_id ctx in + let is_enum = type_decl_is_enum tdef in + (* We deconstruct the ADT with a let-binding in two situations: + - if the ADT is an enumeration (which must have exactly one branch) + - if we forbid using field projectors. + *) + let is_rec_def = + T.TypeDeclId.Set.mem adt_id ctx.type_context.recursive_decls + in + let use_let = + is_enum + || !Config.dont_use_field_projectors + (* TODO: for now, we don't have field projectors over recursive ADTs in Lean. *) + || (!Config.backend = Lean && is_rec_def) + in + if use_let then + (* Introduce a let binding which expands the ADT *) + let lvars = List.map (fun v -> mk_typed_pattern_from_var v None) vars in + let lv = mk_adt_pattern scrutinee.ty variant_id lvars in + let monadic = false in + + mk_let monadic lv + (mk_opt_mplace_texpression scrutinee_mplace scrutinee) + branch + else + (* This is not an enumeration: introduce let-bindings for every + * field. + * We use the [dest] variable in order not to have to recompute + * the type of the result of the projection... *) + let adt_id, type_args = + match scrutinee.ty with + | Adt (adt_id, tys) -> (adt_id, tys) + | _ -> raise (Failure "Unreachable") + in + let gen_field_proj (field_id : FieldId.id) (dest : var) : texpression = + let proj_kind = { adt_id; field_id } in + let qualif = { id = Proj proj_kind; type_args } in + let proj_e = Qualif qualif in + let proj_ty = mk_arrow scrutinee.ty dest.ty in + let proj = { e = proj_e; ty = proj_ty } in + mk_app proj scrutinee + in + let id_var_pairs = FieldId.mapi (fun fid v -> (fid, v)) vars in + let monadic = false in + List.fold_right + (fun (fid, var) e -> + let field_proj = gen_field_proj fid var in + mk_let monadic (mk_typed_pattern_from_var var None) field_proj e) + id_var_pairs branch + | T.Tuple -> + let vars = List.map (fun x -> mk_typed_pattern_from_var x None) vars in + let monadic = false in + mk_let monadic + (mk_simpl_tuple_pattern vars) + (mk_opt_mplace_texpression scrutinee_mplace scrutinee) + branch + | T.Assumed T.Box -> + (* There should be exactly one variable *) + let var = + match vars with [ v ] -> v | _ -> raise (Failure "Unreachable") + in + (* We simply introduce an assignment - the box type is the + * identity when extracted ([box a = a]) *) + let monadic = false in + mk_let monadic + (mk_typed_pattern_from_var var None) + (mk_opt_mplace_texpression scrutinee_mplace scrutinee) + branch + | T.Assumed T.Vec -> + (* We can't expand vector values: we can access the fields only + * through the functions provided by the API (note that we don't + * know how to expand a vector, because it has a variable number + * of fields!) *) + raise (Failure "Can't expand a vector value") + | T.Assumed T.Option -> + (* We shouldn't get there in the "one-branch" case: options have + * two variants *) + raise (Failure "Unreachable") + and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option) (sv : V.symbolic_value) (v : V.typed_value) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -2445,7 +2472,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = List.map (fun ty -> assert ( - not (TypesUtils.ty_has_borrows !ctx.type_context.types_infos ty)); + not (TypesUtils.ty_has_borrows !ctx.type_context.type_infos ty)); (None, ctx_translate_fwd_ty !ctx ty)) tys in @@ -2769,7 +2796,7 @@ let translate_type_decls (type_decls : T.type_decl list) : type_decl list = functions) *) let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) - (types_infos : TA.type_infos) + (type_infos : TA.type_infos) (functions : (A.fun_id * string option list * A.fun_sig) list) : fun_sig_named_outputs RegularFunIdNotLoopMap.t = (* For every function, translate the signatures of: @@ -2781,7 +2808,7 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) = (* The forward function *) let fwd_sg = - translate_fun_sig fun_infos fun_id types_infos sg input_names None + translate_fun_sig fun_infos fun_id type_infos sg input_names None in let fwd_id = (fun_id, None) in (* The backward functions *) @@ -2789,7 +2816,7 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) List.map (fun (rg : T.region_var_group) -> let tsg = - translate_fun_sig fun_infos fun_id types_infos sg input_names + translate_fun_sig fun_infos fun_id type_infos sg input_names (Some rg.id) in let id = (fun_id, Some rg.id) in diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 6bff936b..347052a8 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -77,11 +77,19 @@ let translate_function_to_pure (trans_ctx : trans_ctx) let fuel, var_counter = Pure.VarId.fresh var_counter in let calls = V.FunCallId.Map.empty in let abstractions = V.AbstractionId.Map.empty in + let recursive_type_decls = + T.TypeDeclId.Set.of_list + (List.filter_map + (fun (tid, g) -> + match g with Charon.GAst.NonRec _ -> None | Rec _ -> Some tid) + (T.TypeDeclId.Map.bindings trans_ctx.type_context.type_decls_groups)) + in let type_context = { - SymbolicToPure.types_infos = type_context.type_infos; + SymbolicToPure.type_infos = type_context.type_infos; llbc_type_decls = type_context.type_decls; type_decls = pure_type_decls; + recursive_decls = recursive_type_decls; } in let fun_context = |