summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorSon Ho2023-03-07 23:31:57 +0100
committerSon HO2023-06-04 21:44:33 +0200
commitfa76f1b94e1f68d520b02c0dc1072cb73fa9d8be (patch)
tree6d301b14dc1909beff34691796c4abae88490408 /compiler
parenta946df8b716695f4d387d852b7e74cf288ddb03e (diff)
Add a special expression for structure creation/update
Diffstat (limited to 'compiler')
-rw-r--r--compiler/Config.ml1
-rw-r--r--compiler/Extract.ml105
-rw-r--r--compiler/PrintPure.ml19
-rw-r--r--compiler/Pure.ml101
-rw-r--r--compiler/PureMicroPasses.ml110
-rw-r--r--compiler/PureTypeCheck.ml24
-rw-r--r--compiler/PureUtils.ml2
-rw-r--r--compiler/SymbolicToPure.ml261
-rw-r--r--compiler/Translate.ml10
9 files changed, 451 insertions, 182 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml
index 1baed7fa..15818938 100644
--- a/compiler/Config.ml
+++ b/compiler/Config.ml
@@ -118,7 +118,6 @@ let dont_use_field_projectors = ref false
(** Deconstructing ADTs which have only one variant with let-bindings is not always
supported: this parameter controls whether we use let-bindings in such situations or not.
*)
-
let always_deconstruct_adts_with_matches = ref false
(** Controls whether we need to use a state to model the external world
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 0b8d8bdf..8dd5910f 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -1617,6 +1617,7 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter)
| Let (_, _, _, _) -> extract_lets ctx fmt inside e
| Switch (scrut, body) -> extract_Switch ctx fmt inside scrut body
| Meta (_, e) -> extract_texpression ctx fmt inside e
+ | StructUpdate supd -> extract_StructUpdate ctx fmt inside supd
| Loop _ ->
(* The loop nodes should have been eliminated in {!PureMicroPasses} *)
raise (Failure "Unreachable")
@@ -1748,57 +1749,15 @@ and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
else ctx_get_variant adt_cons.adt_id vid ctx
| None -> ctx_get_struct with_opaque_pre adt_cons.adt_id ctx
in
- let is_lean_struct = !backend = Lean && adt_cons.variant_id = None in
- if is_lean_struct then (
- (* TODO: when only one or two fields differ, considering using the with
- syntax (peephole optimization) *)
- let decl_id =
- match adt_cons.adt_id with AdtId id -> id | _ -> assert false
- in
- let def_kind =
- (TypeDeclId.Map.find decl_id ctx.trans_ctx.type_context.type_decls)
- .kind
- in
- let fields =
- match def_kind with T.Struct fields -> fields | _ -> assert false
- in
- let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in
- F.pp_open_hvbox fmt 0;
- F.pp_open_hvbox fmt ctx.indent_incr;
- F.pp_print_string fmt "{";
- F.pp_print_space fmt ();
- F.pp_open_hvbox fmt ctx.indent_incr;
- F.pp_open_hvbox fmt 0;
- Collections.List.iter_link
- (fun () ->
- F.pp_print_string fmt ",";
- F.pp_print_space fmt ())
- (fun ((fid, _), e) ->
- F.pp_open_hvbox fmt ctx.indent_incr;
- let f = ctx_get_field adt_cons.adt_id fid ctx in
- F.pp_print_string fmt f;
- F.pp_print_string fmt " := ";
- F.pp_open_hvbox fmt ctx.indent_incr;
- extract_texpression ctx fmt true e;
- F.pp_close_box fmt ();
- F.pp_close_box fmt ())
- (List.combine fields args);
- F.pp_close_box fmt ();
- F.pp_close_box fmt ();
- F.pp_close_box fmt ();
- F.pp_print_space fmt ();
- F.pp_print_string fmt "}";
- F.pp_close_box fmt ())
- else
- let use_parentheses = inside && args <> [] in
- if use_parentheses then F.pp_print_string fmt "(";
- F.pp_print_string fmt cons;
- Collections.List.iter
- (fun v ->
- F.pp_print_space fmt ();
- extract_texpression ctx fmt true v)
- args;
- if use_parentheses then F.pp_print_string fmt ")"
+ let use_parentheses = inside && args <> [] in
+ if use_parentheses then F.pp_print_string fmt "(";
+ F.pp_print_string fmt cons;
+ Collections.List.iter
+ (fun v ->
+ F.pp_print_space fmt ();
+ extract_texpression ctx fmt true v)
+ args;
+ if use_parentheses then F.pp_print_string fmt ")"
(** Subcase of the app case: ADT field projector. *)
and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter)
@@ -2078,6 +2037,50 @@ and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(* Close the box for the whole expression *)
F.pp_close_box fmt ()
+and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter)
+ (_inside : bool) (supd : struct_update) : unit =
+ (* We can't update a subset of the fields in Coq (i.e., we can do
+ [{| x:= 3; y := 4 |}], but there is no syntax for [{| s with x := 3 |}]) *)
+ assert (!backend <> Coq || supd.init = None);
+ F.pp_open_hvbox fmt 0;
+ F.pp_open_hvbox fmt ctx.indent_incr;
+ let lb, rb =
+ match !backend with Lean | FStar -> ("{", "}") | Coq -> ("{|", "|}")
+ in
+ F.pp_print_string fmt lb;
+ F.pp_print_space fmt ();
+ F.pp_open_hvbox fmt ctx.indent_incr;
+ if supd.init <> None then (
+ let var_name = ctx_get_var (Option.get supd.init) ctx in
+ F.pp_print_string fmt var_name;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "where";
+ F.pp_print_space fmt ());
+ F.pp_open_hvbox fmt 0;
+ let delimiter = match !backend with Lean -> "," | Coq | FStar -> ";" in
+ let assign = match !backend with Coq | Lean -> ":=" | FStar -> "=" in
+ Collections.List.iter_link
+ (fun () ->
+ F.pp_print_string fmt delimiter;
+ F.pp_print_space fmt ())
+ (fun (fid, fe) ->
+ F.pp_open_hvbox fmt ctx.indent_incr;
+ let f = ctx_get_field (AdtId supd.struct_id) fid ctx in
+ F.pp_print_string fmt f;
+ F.pp_print_string fmt (" " ^ assign);
+ F.pp_print_space fmt ();
+ F.pp_open_hvbox fmt ctx.indent_incr;
+ extract_texpression ctx fmt true fe;
+ F.pp_close_box fmt ();
+ F.pp_close_box fmt ())
+ supd.updates;
+ F.pp_close_box fmt ();
+ F.pp_close_box fmt ();
+ F.pp_close_box fmt ();
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt rb;
+ F.pp_close_box fmt ()
+
(** Insert a space, if necessary *)
let insert_req_space (fmt : F.formatter) (space : bool ref) : unit =
if !space then space := false else F.pp_print_space fmt ()
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index 2c8d5081..3f35a023 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -517,6 +517,25 @@ let rec texpression_to_string (fmt : ast_formatter) (inside : bool)
| Loop loop ->
let e = loop_to_string fmt indent indent_incr loop in
if inside then "(" ^ e ^ ")" else e
+ | StructUpdate supd ->
+ let s =
+ match supd.init with
+ | None -> ""
+ | Some vid -> " " ^ fmt.var_id_to_string vid ^ " with"
+ in
+ let field_names = Option.get (fmt.adt_field_names supd.struct_id None) in
+ let indent1 = indent ^ indent_incr in
+ let indent2 = indent1 ^ indent_incr in
+ let fields =
+ List.map
+ (fun (fid, fe) ->
+ let field = FieldId.nth field_names fid in
+ let fe = texpression_to_string fmt false indent2 indent_incr fe in
+ "\n" ^ indent1 ^ field ^ " := " ^ fe ^ ";")
+ supd.updates
+ in
+ let bl = if fields = [] then "" else "\n" ^ indent in
+ "{" ^ s ^ String.concat "" fields ^ bl ^ "}"
| Meta (meta, e) -> (
let meta_s = meta_to_string fmt meta in
let e = texpression_to_string fmt inside indent indent_incr e in
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index 5b2fca7d..4a00dfb2 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -66,27 +66,67 @@ let error_out_of_fuel_id = VariantId.of_int 1
let fuel_zero_id = VariantId.of_int 0
let fuel_succ_id = VariantId.of_int 1
-type type_id = AdtId of TypeDeclId.id | Tuple | Assumed of assumed_ty
-[@@deriving show, ord]
+type type_decl_id = TypeDeclId.id [@@deriving show, ord]
+type type_var_id = TypeVarId.id [@@deriving show, ord]
(** Ancestor for iter visitor for [ty] *)
-class ['self] iter_ty_base =
+class ['self] iter_type_id_base =
object (_self : 'self)
inherit [_] VisitorsRuntime.iter
- method visit_id : 'env -> TypeVarId.id -> unit = fun _ _ -> ()
- method visit_type_id : 'env -> type_id -> unit = fun _ _ -> ()
+ method visit_type_decl_id : 'env -> type_decl_id -> unit = fun _ _ -> ()
+ method visit_assumed_ty : 'env -> assumed_ty -> unit = fun _ _ -> ()
+ end
+
+(** Ancestor for map visitor for [ty] *)
+class ['self] map_type_id_base =
+ object (_self : 'self)
+ inherit [_] VisitorsRuntime.map
+
+ method visit_type_decl_id : 'env -> type_decl_id -> type_decl_id =
+ fun _ x -> x
+
+ method visit_assumed_ty : 'env -> assumed_ty -> assumed_ty = fun _ x -> x
+ end
+
+type type_id = AdtId of type_decl_id | Tuple | Assumed of assumed_ty
+[@@deriving
+ show,
+ ord,
+ visitors
+ {
+ name = "iter_type_id";
+ variety = "iter";
+ ancestors = [ "iter_type_id_base" ];
+ nude = true (* Don't inherit {!VisitorsRuntime.iter} *);
+ concrete = true;
+ polymorphic = false;
+ },
+ visitors
+ {
+ name = "map_type_id";
+ variety = "map";
+ ancestors = [ "map_type_id_base" ];
+ nude = true (* Don't inherit {!VisitorsRuntime.iter} *);
+ concrete = true;
+ polymorphic = false;
+ }]
+
+(** Ancestor for iter visitor for [ty] *)
+class ['self] iter_ty_base =
+ object (_self : 'self)
+ inherit [_] iter_type_id
+ method visit_type_var_id : 'env -> type_var_id -> unit = fun _ _ -> ()
method visit_integer_type : 'env -> integer_type -> unit = fun _ _ -> ()
end
(** Ancestor for map visitor for [ty] *)
class ['self] map_ty_base =
object (_self : 'self)
- inherit [_] VisitorsRuntime.map
- method visit_id : 'env -> TypeVarId.id -> TypeVarId.id = fun _ id -> id
- method visit_type_id : 'env -> type_id -> type_id = fun _ id -> id
+ inherit [_] map_type_id
+ method visit_type_var_id : 'env -> type_var_id -> type_var_id = fun _ x -> x
method visit_integer_type : 'env -> integer_type -> integer_type =
- fun _ ity -> ity
+ fun _ x -> x
end
type ty =
@@ -98,7 +138,7 @@ type ty =
be able to only give back part of the ADT. We need a way to encode
such "partial" ADTs.
*)
- | TypeVar of TypeVarId.id
+ | TypeVar of type_var_id
| Bool
| Char
| Integer of integer_type
@@ -362,6 +402,7 @@ type qualif_id =
*)
type qualif = { id : qualif_id; type_args : ty list } [@@deriving show]
+type field_id = FieldId.id [@@deriving show, ord]
type var_id = VarId.id [@@deriving show, ord]
(** Ancestor for {!iter_expression} visitor *)
@@ -372,6 +413,8 @@ class ['self] iter_expression_base =
method visit_var_id : 'env -> var_id -> unit = fun _ _ -> ()
method visit_qualif : 'env -> qualif -> unit = fun _ _ -> ()
method visit_loop_id : 'env -> loop_id -> unit = fun _ _ -> ()
+ method visit_type_decl_id : 'env -> type_decl_id -> unit = fun _ _ -> ()
+ method visit_field_id : 'env -> field_id -> unit = fun _ _ -> ()
end
(** Ancestor for {!map_expression} visitor *)
@@ -385,6 +428,11 @@ class ['self] map_expression_base =
method visit_var_id : 'env -> var_id -> var_id = fun _ x -> x
method visit_qualif : 'env -> qualif -> qualif = fun _ x -> x
method visit_loop_id : 'env -> loop_id -> loop_id = fun _ x -> x
+
+ method visit_type_decl_id : 'env -> type_decl_id -> type_decl_id =
+ fun _ x -> x
+
+ method visit_field_id : 'env -> field_id -> field_id = fun _ x -> x
end
(** Ancestor for {!reduce_expression} visitor *)
@@ -398,6 +446,11 @@ class virtual ['self] reduce_expression_base =
method visit_var_id : 'env -> var_id -> 'a = fun _ _ -> self#zero
method visit_qualif : 'env -> qualif -> 'a = fun _ _ -> self#zero
method visit_loop_id : 'env -> loop_id -> 'a = fun _ _ -> self#zero
+
+ method visit_type_decl_id : 'env -> type_decl_id -> 'a =
+ fun _ _ -> self#zero
+
+ method visit_field_id : 'env -> field_id -> 'a = fun _ _ -> self#zero
end
(** Ancestor for {!mapreduce_expression} visitor *)
@@ -416,6 +469,12 @@ class virtual ['self] mapreduce_expression_base =
method visit_loop_id : 'env -> loop_id -> loop_id * 'a =
fun _ x -> (x, self#zero)
+
+ method visit_type_decl_id : 'env -> type_decl_id -> type_decl_id * 'a =
+ fun _ x -> (x, self#zero)
+
+ method visit_field_id : 'env -> field_id -> field_id * 'a =
+ fun _ x -> (x, self#zero)
end
(** **Rk.:** here, {!expression} is not at all equivalent to the expressions
@@ -477,6 +536,7 @@ type expression =
*)
| Switch of texpression * switch_body
| Loop of loop (** See the comments for {!loop} *)
+ | StructUpdate of struct_update (** See the comments for {!struct_update} *)
| Meta of (meta[@opaque]) * texpression (** Meta-information *)
and switch_body = If of texpression * texpression | Match of match_branch list
@@ -508,6 +568,27 @@ and loop = {
loop_body : texpression;
}
+(** Structure creation/update.
+
+ This expression is not strictly necessary, but allows for nice syntax, which
+ is important to work easily with the generated code.
+
+ If {!init} is [None], it defines a structure creation:
+ {[
+ { x := 3; y := true; }
+ ]}
+
+ If {!init} is [Some], it defines a structure update:
+ {[
+ { s with x := 3 }
+ ]}
+ *)
+and struct_update = {
+ struct_id : type_decl_id;
+ init : var_id option;
+ updates : (field_id * texpression) list;
+}
+
and texpression = { e : expression; ty : ty }
(** Meta-value (converted to an expression). It is important that the content
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 3614487e..7e6ca822 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -388,6 +388,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
| Let (monadic, lb, re, e) -> update_let monadic lb re e ctx
| Switch (scrut, body) -> update_switch_body scrut body ctx
| Loop loop -> update_loop loop ctx
+ | StructUpdate supd -> update_struct_update supd ctx
| Meta (meta, e) -> update_meta meta e ctx
in
(ctx, { e; ty })
@@ -474,6 +475,18 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
}
in
(ctx, Loop loop)
+ and update_struct_update (supd : struct_update) (ctx : pn_ctx) :
+ pn_ctx * expression =
+ let { struct_id; init; updates } = supd in
+ let ctx, updates =
+ List.fold_left_map
+ (fun ctx (fid, fe) ->
+ let ctx, fe = update_texpression fe ctx in
+ (ctx, (fid, fe)))
+ ctx updates
+ in
+ let supd = { struct_id; init; updates } in
+ (ctx, StructUpdate supd)
(* *)
and update_meta (meta : meta) (e : texpression) (ctx : pn_ctx) :
pn_ctx * expression =
@@ -536,6 +549,89 @@ let remove_meta (def : fun_decl) : fun_decl =
let body = { body with body = PureUtils.remove_meta body.body } in
{ def with body = Some body }
+(** Introduce the special structure create/update expressions.
+
+ Upon generating the pure code, we introduce structure values by using
+ the structure constructors:
+ {[
+ Cons x0 ... xn
+ ]}
+
+ This micro-pass turns those into expressions which use structure syntax:
+ {[
+ {
+ f0 := x0;
+ ...
+ fn := xn;
+ }
+ ]}
+ *)
+let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
+ let obj =
+ object (self)
+ inherit [_] map_expression as super
+
+ method! visit_texpression env (e : texpression) =
+ match e.e with
+ | App _ -> (
+ let app, args = destruct_apps e in
+ let ignore () =
+ mk_apps
+ (self#visit_texpression env app)
+ (List.map (self#visit_texpression env) args)
+ in
+ match app.e with
+ | Qualif
+ {
+ id = AdtCons { adt_id = AdtId adt_id; variant_id = None };
+ type_args = _;
+ } ->
+ (* Lookup the def *)
+ let decl =
+ TypeDeclId.Map.find adt_id ctx.type_context.type_decls
+ in
+ (* Check that there are as many arguments as there are fields - note
+ that the def should have a body (otherwise we couldn't use the
+ constructor) *)
+ let fields = TypesUtils.type_decl_get_fields decl None in
+ if List.length fields = List.length args then
+ (* Check if the definition is recursive *)
+ let is_rec =
+ match
+ TypeDeclId.Map.find adt_id
+ ctx.type_context.type_decls_groups
+ with
+ | NonRec _ -> false
+ | Rec _ -> true
+ in
+ (* Convert, if possible - note that for now for Lean and Coq
+ we don't support the structure syntax on recursive structures *)
+ if
+ (!Config.backend <> Lean && !Config.backend <> Coq)
+ || not is_rec
+ then
+ let struct_id = adt_id in
+ let init = None in
+ let updates =
+ FieldId.mapi
+ (fun fid fe -> (fid, self#visit_texpression env fe))
+ args
+ in
+ let ne = { struct_id; init; updates } in
+ let nty = e.ty in
+ { e = StructUpdate ne; ty = nty }
+ else ignore ()
+ else ignore ()
+ | _ -> ignore ())
+ | _ -> super#visit_texpression env e
+ end
+ in
+ match def.body with
+ | None -> def
+ | Some body ->
+ let body = { body with body = obj#visit_texpression () body.body } in
+ { def with body = Some body }
+
(** Inline the useless variable (re-)assignments:
A lot of intermediate variable assignments are introduced through the
@@ -604,6 +700,7 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool)
true (* primitive function call *)
| FunOrOp (Fun _) -> false (* non-primitive function call *)
| _ -> false)
+ | StructUpdate _ -> true (* ADT constructor *)
| _ -> false
in
let filter =
@@ -737,6 +834,12 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
method! visit_texpression env e =
match e.e with
| Var _ | Const _ -> fun _ -> false
+ | StructUpdate _ ->
+ (* There shouldn't be monadic calls in structure updates - also
+ note that by returning [false] we are conservative: we might
+ *prevent* possible optimisations (i.e., filtering some function
+ calls), which is sound. *)
+ fun _ -> false
| Let (_, _, re, e) -> (
match opt_destruct_function_call re with
| None -> fun () -> self#visit_texpression env e ()
@@ -829,7 +932,7 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
| Var _ | Const _ | App _ | Qualif _
| Switch (_, _)
| Meta (_, _)
- | Abs _ ->
+ | StructUpdate _ | Abs _ ->
super#visit_expression env e
| Let (monadic, lv, re, e) ->
(* Compute the set of values used in the next expression *)
@@ -1602,6 +1705,11 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl =
log#ldebug
(lazy ("unit_vars_to_unit:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+ (* Introduce the special structure create/update expressions *)
+ let def = intro_struct_updates ctx def in
+ log#ldebug
+ (lazy ("intro_struct_updates:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+
(* Inline the useless variable reassignments *)
let inline_named_vars = true in
let inline_pure = true in
diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml
index 1871f1bc..018ea6b5 100644
--- a/compiler/PureTypeCheck.ml
+++ b/compiler/PureTypeCheck.ml
@@ -194,6 +194,30 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
assert (Option.is_some loop.back_output_tys || loop.loop_body.ty = e.ty);
check_texpression ctx loop.fun_end;
check_texpression ctx loop.loop_body
+ | StructUpdate supd ->
+ (* Check the init value *)
+ (if Option.is_some supd.init then
+ match VarId.Map.find_opt (Option.get supd.init) ctx.env with
+ | None -> ()
+ | Some ty -> assert (ty = e.ty));
+ (* Check the fields *)
+ (* Retrieve and check the expected field type *)
+ let adt_id, adt_type_args =
+ match e.ty with
+ | Adt (type_id, tys) -> (type_id, tys)
+ | _ -> raise (Failure "Unreachable")
+ in
+ assert (adt_id = AdtId supd.struct_id);
+ let variant_id = None in
+ let expected_field_tys =
+ get_adt_field_types ctx.type_decls adt_id variant_id adt_type_args
+ in
+ List.iter
+ (fun (fid, fe) ->
+ let expected_field_ty = FieldId.nth expected_field_tys fid in
+ assert (expected_field_ty = fe.ty);
+ check_texpression ctx fe)
+ supd.updates
| Meta (_, e_next) ->
assert (e_next.ty = e.ty);
check_texpression ctx e_next
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index 40005671..1f5d1ed8 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -157,7 +157,7 @@ let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) :
*)
let rec let_group_requires_parentheses (e : texpression) : bool =
match e.e with
- | Var _ | Const _ | App _ | Abs _ | Qualif _ -> false
+ | Var _ | Const _ | App _ | Abs _ | Qualif _ | StructUpdate _ -> false
| Let (monadic, _, _, next_e) ->
if monadic then true else let_group_requires_parentheses next_e
| Switch (_, _) -> false
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 2c103177..5252495d 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -22,7 +22,8 @@ type type_context = {
This map is empty when we translate the types, then contains all
the translated types when we translate the functions.
*)
- types_infos : TA.type_infos; (* TODO: rename to type_infos *)
+ type_infos : TA.type_infos;
+ recursive_decls : T.TypeDeclId.Set.t;
}
[@@deriving show]
@@ -451,8 +452,8 @@ let translate_type_decl (def : T.type_decl) : type_decl =
(preserve all borrows, etc.)
*)
-let rec translate_fwd_ty (types_infos : TA.type_infos) (ty : 'r T.ty) : ty =
- let translate = translate_fwd_ty types_infos in
+let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty =
+ let translate = translate_fwd_ty type_infos in
match ty with
| T.Adt (type_id, regions, tys) -> (
(* Can't translate types with regions for now *)
@@ -463,7 +464,7 @@ let rec translate_fwd_ty (types_infos : TA.type_infos) (ty : 'r T.ty) : ty =
match type_id with
| AdtId _ | T.Assumed (T.Vec | T.Option) ->
(* No general parametricity for now *)
- assert (not (List.exists (TypesUtils.ty_has_borrows types_infos) tys));
+ assert (not (List.exists (TypesUtils.ty_has_borrows type_infos) tys));
let type_id =
match type_id with
| AdtId adt_id -> AdtId adt_id
@@ -479,7 +480,7 @@ let rec translate_fwd_ty (types_infos : TA.type_infos) (ty : 'r T.ty) : ty =
| T.Assumed T.Box -> (
(* We eliminate boxes *)
(* No general parametricity for now *)
- assert (not (List.exists (TypesUtils.ty_has_borrows types_infos) tys));
+ assert (not (List.exists (TypesUtils.ty_has_borrows type_infos) tys));
match t_tys with
| [ bty ] -> bty
| _ ->
@@ -494,17 +495,17 @@ let rec translate_fwd_ty (types_infos : TA.type_infos) (ty : 'r T.ty) : ty =
| Integer int_ty -> Integer int_ty
| Str -> Str
| Array ty ->
- assert (not (TypesUtils.ty_has_borrows types_infos ty));
+ assert (not (TypesUtils.ty_has_borrows type_infos ty));
Array (translate ty)
| Slice ty ->
- assert (not (TypesUtils.ty_has_borrows types_infos ty));
+ assert (not (TypesUtils.ty_has_borrows type_infos ty));
Slice (translate ty)
| Ref (_, rty, _) -> translate rty
(** Simply calls [translate_fwd_ty] *)
let ctx_translate_fwd_ty (ctx : bs_ctx) (ty : 'r T.ty) : ty =
- let types_infos = ctx.type_context.types_infos in
- translate_fwd_ty types_infos ty
+ let type_infos = ctx.type_context.type_infos in
+ translate_fwd_ty type_infos ty
(** Translate a type, when some regions may have ended.
@@ -512,9 +513,9 @@ let ctx_translate_fwd_ty (ctx : bs_ctx) (ty : 'r T.ty) : ty =
[inside_mut]: are we inside a mutable borrow?
*)
-let rec translate_back_ty (types_infos : TA.type_infos)
+let rec translate_back_ty (type_infos : TA.type_infos)
(keep_region : 'r -> bool) (inside_mut : bool) (ty : 'r T.ty) : ty option =
- let translate = translate_back_ty types_infos keep_region inside_mut in
+ let translate = translate_back_ty type_infos keep_region inside_mut in
(* A small helper for "leave" types *)
let wrap ty = if inside_mut then Some ty else None in
match ty with
@@ -522,7 +523,7 @@ let rec translate_back_ty (types_infos : TA.type_infos)
match type_id with
| T.AdtId _ | Assumed (T.Vec | T.Option) ->
(* Don't accept ADTs (which are not tuples) with borrows for now *)
- assert (not (TypesUtils.ty_has_borrows types_infos ty));
+ assert (not (TypesUtils.ty_has_borrows type_infos ty));
let type_id =
match type_id with
| T.AdtId id -> AdtId id
@@ -536,7 +537,7 @@ let rec translate_back_ty (types_infos : TA.type_infos)
else None
| Assumed T.Box -> (
(* Don't accept ADTs (which are not tuples) with borrows for now *)
- assert (not (TypesUtils.ty_has_borrows types_infos ty));
+ assert (not (TypesUtils.ty_has_borrows type_infos ty));
(* Eliminate the box *)
match tys with
| [ bty ] -> translate bty
@@ -560,10 +561,10 @@ let rec translate_back_ty (types_infos : TA.type_infos)
| Integer int_ty -> wrap (Integer int_ty)
| Str -> wrap Str
| Array ty -> (
- assert (not (TypesUtils.ty_has_borrows types_infos ty));
+ assert (not (TypesUtils.ty_has_borrows type_infos ty));
match translate ty with None -> None | Some ty -> Some (Array ty))
| Slice ty -> (
- assert (not (TypesUtils.ty_has_borrows types_infos ty));
+ assert (not (TypesUtils.ty_has_borrows type_infos ty));
match translate ty with None -> None | Some ty -> Some (Slice ty))
| Ref (r, rty, rkind) -> (
match rkind with
@@ -574,14 +575,14 @@ let rec translate_back_ty (types_infos : TA.type_infos)
(* Dive in, remembering the fact that we are inside a mutable borrow *)
let inside_mut = true in
if keep_region r then
- translate_back_ty types_infos keep_region inside_mut rty
+ translate_back_ty type_infos keep_region inside_mut rty
else None)
(** Simply calls [translate_back_ty] *)
let ctx_translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool)
(inside_mut : bool) (ty : 'r T.ty) : ty option =
- let types_infos = ctx.type_context.types_infos in
- translate_back_ty types_infos keep_region inside_mut ty
+ let type_infos = ctx.type_context.type_infos in
+ translate_back_ty type_infos keep_region inside_mut ty
(** List the ancestors of an abstraction *)
let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs)
@@ -670,7 +671,7 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
of the forward function) which we use as hints to generate pretty names.
*)
let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
- (fun_id : A.fun_id) (types_infos : TA.type_infos) (sg : A.fun_sig)
+ (fun_id : A.fun_id) (type_infos : TA.type_infos) (sg : A.fun_sig)
(input_names : string option list) (bid : T.RegionGroupId.id option) :
fun_sig_named_outputs =
(* Retrieve the list of parent backward functions *)
@@ -691,7 +692,7 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
* - the current backward function (if it is a backward function)
*)
let fuel = mk_fuel_input_ty_as_list effect_info in
- let fwd_inputs = List.map (translate_fwd_ty types_infos) sg.inputs in
+ let fwd_inputs = List.map (translate_fwd_ty type_infos) sg.inputs in
(* For the backward functions: for now we don't supported nested borrows,
* so just check that there aren't parent regions *)
assert (T.RegionGroupId.Set.is_empty parents);
@@ -706,7 +707,7 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
| T.Var r -> T.RegionVarId.Set.mem r regions
in
let inside_mut = false in
- translate_back_ty types_infos keep_region inside_mut
+ translate_back_ty type_infos keep_region inside_mut
in
(* Compute the additinal inputs for the current function, if it is a backward
* function *)
@@ -762,7 +763,7 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
match gid with
| None ->
(* This is a forward function: there is one (unnamed) output *)
- ([ None ], [ translate_fwd_ty types_infos sg.output ])
+ ([ None ], [ translate_fwd_ty type_infos sg.output ])
| Some gid ->
(* This is a backward function: there might be several outputs.
The outputs are the borrows inside the regions of the abstractions
@@ -2057,11 +2058,9 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
match branches with
| [] -> raise (Failure "Unreachable")
| [ (variant_id, svl, branch) ]
- (* TODO: always introduce a match, and use micro-passes to turn the
- the match into a let *)
when not
(TypesUtils.ty_is_custom_adt sv.V.sv_ty
- && !Config.always_deconstruct_adts_with_matches) -> (
+ && !Config.always_deconstruct_adts_with_matches) ->
(* There is exactly one branch: no branching.
We can decompose the ADT value with a let-binding, unless
@@ -2069,94 +2068,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
we *ignore* this branch (and go to the next one) if the ADT is a custom
adt, and [always_deconstruct_adts_with_matches] is true.
*)
- let type_id, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in
- let ctx, vars = fresh_vars_for_symbolic_values svl ctx in
- let branch = translate_expression branch ctx in
- match type_id with
- | T.AdtId adt_id ->
- (* Detect if this is an enumeration or not *)
- let tdef = bs_ctx_lookup_llbc_type_decl adt_id ctx in
- let is_enum = type_decl_is_enum tdef in
- (* We deconstruct the ADT with a let-binding in two situations:
- - if the ADT is an enumeration (which must have exactly one branch)
- - if we forbid using field projectors.
-
- We forbid using field projectors in some situations, for example
- if the backend is Coq. See '!Config.dont_use_field_projectors}.
- *)
- let use_let = is_enum || !Config.dont_use_field_projectors in
- if use_let then
- (* Introduce a let binding which expands the ADT *)
- let lvars =
- List.map (fun v -> mk_typed_pattern_from_var v None) vars
- in
- let lv = mk_adt_pattern scrutinee.ty variant_id lvars in
- let monadic = false in
-
- mk_let monadic lv
- (mk_opt_mplace_texpression scrutinee_mplace scrutinee)
- branch
- else
- (* This is not an enumeration: introduce let-bindings for every
- * field.
- * We use the [dest] variable in order not to have to recompute
- * the type of the result of the projection... *)
- let adt_id, type_args =
- match scrutinee.ty with
- | Adt (adt_id, tys) -> (adt_id, tys)
- | _ -> raise (Failure "Unreachable")
- in
- let gen_field_proj (field_id : FieldId.id) (dest : var) :
- texpression =
- let proj_kind = { adt_id; field_id } in
- let qualif = { id = Proj proj_kind; type_args } in
- let proj_e = Qualif qualif in
- let proj_ty = mk_arrow scrutinee.ty dest.ty in
- let proj = { e = proj_e; ty = proj_ty } in
- mk_app proj scrutinee
- in
- let id_var_pairs = FieldId.mapi (fun fid v -> (fid, v)) vars in
- let monadic = false in
- List.fold_right
- (fun (fid, var) e ->
- let field_proj = gen_field_proj fid var in
- mk_let monadic
- (mk_typed_pattern_from_var var None)
- field_proj e)
- id_var_pairs branch
- | T.Tuple ->
- let vars =
- List.map (fun x -> mk_typed_pattern_from_var x None) vars
- in
- let monadic = false in
- mk_let monadic
- (mk_simpl_tuple_pattern vars)
- (mk_opt_mplace_texpression scrutinee_mplace scrutinee)
- branch
- | T.Assumed T.Box ->
- (* There should be exactly one variable *)
- let var =
- match vars with
- | [ v ] -> v
- | _ -> raise (Failure "Unreachable")
- in
- (* We simply introduce an assignment - the box type is the
- * identity when extracted ([box a = a]) *)
- let monadic = false in
- mk_let monadic
- (mk_typed_pattern_from_var var None)
- (mk_opt_mplace_texpression scrutinee_mplace scrutinee)
- branch
- | T.Assumed T.Vec ->
- (* We can't expand vector values: we can access the fields only
- * through the functions provided by the API (note that we don't
- * know how to expand a vector, because it has a variable number
- * of fields!) *)
- raise (Failure "Can't expand a vector value")
- | T.Assumed T.Option ->
- (* We shouldn't get there in the "one-branch" case: options have
- * two variants *)
- raise (Failure "Unreachable"))
+ translate_ExpandAdt_one_branch sv scrutinee scrutinee_mplace
+ variant_id svl branch ctx
| branches ->
let translate_branch (variant_id : T.VariantId.id option)
(svl : V.symbolic_value list) (branch : S.expression) :
@@ -2225,6 +2138,120 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
List.for_all (fun (br : match_branch) -> br.branch.ty = ty) branches);
{ e; ty }
+(* Translate and [ExpandAdt] when there is no branching (i.e., one branch).
+
+ There are several possibilities:
+ - if the ADT is an enumeration, we attempt to deconstruct it with a let-binding:
+ {[
+ let Cons x0 ... xn = y in
+ ...
+ ]}
+
+ - if the ADT is a structure, we attempt to introduce one let-binding per field:
+ {[
+ let x0 = y.f0 in
+ ...
+ let xn = y.fn in
+ ...
+ ]}
+
+ Of course, this is not always possible depending on the backend.
+ Also, recursive structures, and more specifically structures mutually recursive
+ with inductives, are usually not supported. We define such recursive structures
+ as inductives, in which case it is not always possible to use a notation
+ for the field projections.
+*)
+and translate_ExpandAdt_one_branch (sv : V.symbolic_value)
+ (scrutinee : texpression) (scrutinee_mplace : mplace option)
+ (variant_id : variant_id option) (svl : V.symbolic_value list)
+ (branch : S.expression) (ctx : bs_ctx) : texpression =
+ (* TODO: always introduce a match, and use micro-passes to turn the
+ the match into a let? *)
+ let type_id, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in
+ let ctx, vars = fresh_vars_for_symbolic_values svl ctx in
+ let branch = translate_expression branch ctx in
+ match type_id with
+ | T.AdtId adt_id ->
+ (* Detect if this is an enumeration or not *)
+ let tdef = bs_ctx_lookup_llbc_type_decl adt_id ctx in
+ let is_enum = type_decl_is_enum tdef in
+ (* We deconstruct the ADT with a let-binding in two situations:
+ - if the ADT is an enumeration (which must have exactly one branch)
+ - if we forbid using field projectors.
+ *)
+ let is_rec_def =
+ T.TypeDeclId.Set.mem adt_id ctx.type_context.recursive_decls
+ in
+ let use_let =
+ is_enum
+ || !Config.dont_use_field_projectors
+ (* TODO: for now, we don't have field projectors over recursive ADTs in Lean. *)
+ || (!Config.backend = Lean && is_rec_def)
+ in
+ if use_let then
+ (* Introduce a let binding which expands the ADT *)
+ let lvars = List.map (fun v -> mk_typed_pattern_from_var v None) vars in
+ let lv = mk_adt_pattern scrutinee.ty variant_id lvars in
+ let monadic = false in
+
+ mk_let monadic lv
+ (mk_opt_mplace_texpression scrutinee_mplace scrutinee)
+ branch
+ else
+ (* This is not an enumeration: introduce let-bindings for every
+ * field.
+ * We use the [dest] variable in order not to have to recompute
+ * the type of the result of the projection... *)
+ let adt_id, type_args =
+ match scrutinee.ty with
+ | Adt (adt_id, tys) -> (adt_id, tys)
+ | _ -> raise (Failure "Unreachable")
+ in
+ let gen_field_proj (field_id : FieldId.id) (dest : var) : texpression =
+ let proj_kind = { adt_id; field_id } in
+ let qualif = { id = Proj proj_kind; type_args } in
+ let proj_e = Qualif qualif in
+ let proj_ty = mk_arrow scrutinee.ty dest.ty in
+ let proj = { e = proj_e; ty = proj_ty } in
+ mk_app proj scrutinee
+ in
+ let id_var_pairs = FieldId.mapi (fun fid v -> (fid, v)) vars in
+ let monadic = false in
+ List.fold_right
+ (fun (fid, var) e ->
+ let field_proj = gen_field_proj fid var in
+ mk_let monadic (mk_typed_pattern_from_var var None) field_proj e)
+ id_var_pairs branch
+ | T.Tuple ->
+ let vars = List.map (fun x -> mk_typed_pattern_from_var x None) vars in
+ let monadic = false in
+ mk_let monadic
+ (mk_simpl_tuple_pattern vars)
+ (mk_opt_mplace_texpression scrutinee_mplace scrutinee)
+ branch
+ | T.Assumed T.Box ->
+ (* There should be exactly one variable *)
+ let var =
+ match vars with [ v ] -> v | _ -> raise (Failure "Unreachable")
+ in
+ (* We simply introduce an assignment - the box type is the
+ * identity when extracted ([box a = a]) *)
+ let monadic = false in
+ mk_let monadic
+ (mk_typed_pattern_from_var var None)
+ (mk_opt_mplace_texpression scrutinee_mplace scrutinee)
+ branch
+ | T.Assumed T.Vec ->
+ (* We can't expand vector values: we can access the fields only
+ * through the functions provided by the API (note that we don't
+ * know how to expand a vector, because it has a variable number
+ * of fields!) *)
+ raise (Failure "Can't expand a vector value")
+ | T.Assumed T.Option ->
+ (* We shouldn't get there in the "one-branch" case: options have
+ * two variants *)
+ raise (Failure "Unreachable")
+
and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option)
(sv : V.symbolic_value) (v : V.typed_value) (e : S.expression)
(ctx : bs_ctx) : texpression =
@@ -2445,7 +2472,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
List.map
(fun ty ->
assert (
- not (TypesUtils.ty_has_borrows !ctx.type_context.types_infos ty));
+ not (TypesUtils.ty_has_borrows !ctx.type_context.type_infos ty));
(None, ctx_translate_fwd_ty !ctx ty))
tys
in
@@ -2769,7 +2796,7 @@ let translate_type_decls (type_decls : T.type_decl list) : type_decl list =
functions)
*)
let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t)
- (types_infos : TA.type_infos)
+ (type_infos : TA.type_infos)
(functions : (A.fun_id * string option list * A.fun_sig) list) :
fun_sig_named_outputs RegularFunIdNotLoopMap.t =
(* For every function, translate the signatures of:
@@ -2781,7 +2808,7 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t)
=
(* The forward function *)
let fwd_sg =
- translate_fun_sig fun_infos fun_id types_infos sg input_names None
+ translate_fun_sig fun_infos fun_id type_infos sg input_names None
in
let fwd_id = (fun_id, None) in
(* The backward functions *)
@@ -2789,7 +2816,7 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t)
List.map
(fun (rg : T.region_var_group) ->
let tsg =
- translate_fun_sig fun_infos fun_id types_infos sg input_names
+ translate_fun_sig fun_infos fun_id type_infos sg input_names
(Some rg.id)
in
let id = (fun_id, Some rg.id) in
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 6bff936b..347052a8 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -77,11 +77,19 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
let fuel, var_counter = Pure.VarId.fresh var_counter in
let calls = V.FunCallId.Map.empty in
let abstractions = V.AbstractionId.Map.empty in
+ let recursive_type_decls =
+ T.TypeDeclId.Set.of_list
+ (List.filter_map
+ (fun (tid, g) ->
+ match g with Charon.GAst.NonRec _ -> None | Rec _ -> Some tid)
+ (T.TypeDeclId.Map.bindings trans_ctx.type_context.type_decls_groups))
+ in
let type_context =
{
- SymbolicToPure.types_infos = type_context.type_infos;
+ SymbolicToPure.type_infos = type_context.type_infos;
llbc_type_decls = type_context.type_decls;
type_decls = pure_type_decls;
+ recursive_decls = recursive_type_decls;
}
in
let fun_context =