summaryrefslogtreecommitdiff
path: root/compiler/SymbolicToPure.ml
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/SymbolicToPure.ml')
-rw-r--r--compiler/SymbolicToPure.ml261
1 files changed, 144 insertions, 117 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 2c103177..5252495d 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -22,7 +22,8 @@ type type_context = {
This map is empty when we translate the types, then contains all
the translated types when we translate the functions.
*)
- types_infos : TA.type_infos; (* TODO: rename to type_infos *)
+ type_infos : TA.type_infos;
+ recursive_decls : T.TypeDeclId.Set.t;
}
[@@deriving show]
@@ -451,8 +452,8 @@ let translate_type_decl (def : T.type_decl) : type_decl =
(preserve all borrows, etc.)
*)
-let rec translate_fwd_ty (types_infos : TA.type_infos) (ty : 'r T.ty) : ty =
- let translate = translate_fwd_ty types_infos in
+let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty =
+ let translate = translate_fwd_ty type_infos in
match ty with
| T.Adt (type_id, regions, tys) -> (
(* Can't translate types with regions for now *)
@@ -463,7 +464,7 @@ let rec translate_fwd_ty (types_infos : TA.type_infos) (ty : 'r T.ty) : ty =
match type_id with
| AdtId _ | T.Assumed (T.Vec | T.Option) ->
(* No general parametricity for now *)
- assert (not (List.exists (TypesUtils.ty_has_borrows types_infos) tys));
+ assert (not (List.exists (TypesUtils.ty_has_borrows type_infos) tys));
let type_id =
match type_id with
| AdtId adt_id -> AdtId adt_id
@@ -479,7 +480,7 @@ let rec translate_fwd_ty (types_infos : TA.type_infos) (ty : 'r T.ty) : ty =
| T.Assumed T.Box -> (
(* We eliminate boxes *)
(* No general parametricity for now *)
- assert (not (List.exists (TypesUtils.ty_has_borrows types_infos) tys));
+ assert (not (List.exists (TypesUtils.ty_has_borrows type_infos) tys));
match t_tys with
| [ bty ] -> bty
| _ ->
@@ -494,17 +495,17 @@ let rec translate_fwd_ty (types_infos : TA.type_infos) (ty : 'r T.ty) : ty =
| Integer int_ty -> Integer int_ty
| Str -> Str
| Array ty ->
- assert (not (TypesUtils.ty_has_borrows types_infos ty));
+ assert (not (TypesUtils.ty_has_borrows type_infos ty));
Array (translate ty)
| Slice ty ->
- assert (not (TypesUtils.ty_has_borrows types_infos ty));
+ assert (not (TypesUtils.ty_has_borrows type_infos ty));
Slice (translate ty)
| Ref (_, rty, _) -> translate rty
(** Simply calls [translate_fwd_ty] *)
let ctx_translate_fwd_ty (ctx : bs_ctx) (ty : 'r T.ty) : ty =
- let types_infos = ctx.type_context.types_infos in
- translate_fwd_ty types_infos ty
+ let type_infos = ctx.type_context.type_infos in
+ translate_fwd_ty type_infos ty
(** Translate a type, when some regions may have ended.
@@ -512,9 +513,9 @@ let ctx_translate_fwd_ty (ctx : bs_ctx) (ty : 'r T.ty) : ty =
[inside_mut]: are we inside a mutable borrow?
*)
-let rec translate_back_ty (types_infos : TA.type_infos)
+let rec translate_back_ty (type_infos : TA.type_infos)
(keep_region : 'r -> bool) (inside_mut : bool) (ty : 'r T.ty) : ty option =
- let translate = translate_back_ty types_infos keep_region inside_mut in
+ let translate = translate_back_ty type_infos keep_region inside_mut in
(* A small helper for "leave" types *)
let wrap ty = if inside_mut then Some ty else None in
match ty with
@@ -522,7 +523,7 @@ let rec translate_back_ty (types_infos : TA.type_infos)
match type_id with
| T.AdtId _ | Assumed (T.Vec | T.Option) ->
(* Don't accept ADTs (which are not tuples) with borrows for now *)
- assert (not (TypesUtils.ty_has_borrows types_infos ty));
+ assert (not (TypesUtils.ty_has_borrows type_infos ty));
let type_id =
match type_id with
| T.AdtId id -> AdtId id
@@ -536,7 +537,7 @@ let rec translate_back_ty (types_infos : TA.type_infos)
else None
| Assumed T.Box -> (
(* Don't accept ADTs (which are not tuples) with borrows for now *)
- assert (not (TypesUtils.ty_has_borrows types_infos ty));
+ assert (not (TypesUtils.ty_has_borrows type_infos ty));
(* Eliminate the box *)
match tys with
| [ bty ] -> translate bty
@@ -560,10 +561,10 @@ let rec translate_back_ty (types_infos : TA.type_infos)
| Integer int_ty -> wrap (Integer int_ty)
| Str -> wrap Str
| Array ty -> (
- assert (not (TypesUtils.ty_has_borrows types_infos ty));
+ assert (not (TypesUtils.ty_has_borrows type_infos ty));
match translate ty with None -> None | Some ty -> Some (Array ty))
| Slice ty -> (
- assert (not (TypesUtils.ty_has_borrows types_infos ty));
+ assert (not (TypesUtils.ty_has_borrows type_infos ty));
match translate ty with None -> None | Some ty -> Some (Slice ty))
| Ref (r, rty, rkind) -> (
match rkind with
@@ -574,14 +575,14 @@ let rec translate_back_ty (types_infos : TA.type_infos)
(* Dive in, remembering the fact that we are inside a mutable borrow *)
let inside_mut = true in
if keep_region r then
- translate_back_ty types_infos keep_region inside_mut rty
+ translate_back_ty type_infos keep_region inside_mut rty
else None)
(** Simply calls [translate_back_ty] *)
let ctx_translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool)
(inside_mut : bool) (ty : 'r T.ty) : ty option =
- let types_infos = ctx.type_context.types_infos in
- translate_back_ty types_infos keep_region inside_mut ty
+ let type_infos = ctx.type_context.type_infos in
+ translate_back_ty type_infos keep_region inside_mut ty
(** List the ancestors of an abstraction *)
let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs)
@@ -670,7 +671,7 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
of the forward function) which we use as hints to generate pretty names.
*)
let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
- (fun_id : A.fun_id) (types_infos : TA.type_infos) (sg : A.fun_sig)
+ (fun_id : A.fun_id) (type_infos : TA.type_infos) (sg : A.fun_sig)
(input_names : string option list) (bid : T.RegionGroupId.id option) :
fun_sig_named_outputs =
(* Retrieve the list of parent backward functions *)
@@ -691,7 +692,7 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
* - the current backward function (if it is a backward function)
*)
let fuel = mk_fuel_input_ty_as_list effect_info in
- let fwd_inputs = List.map (translate_fwd_ty types_infos) sg.inputs in
+ let fwd_inputs = List.map (translate_fwd_ty type_infos) sg.inputs in
(* For the backward functions: for now we don't supported nested borrows,
* so just check that there aren't parent regions *)
assert (T.RegionGroupId.Set.is_empty parents);
@@ -706,7 +707,7 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
| T.Var r -> T.RegionVarId.Set.mem r regions
in
let inside_mut = false in
- translate_back_ty types_infos keep_region inside_mut
+ translate_back_ty type_infos keep_region inside_mut
in
(* Compute the additinal inputs for the current function, if it is a backward
* function *)
@@ -762,7 +763,7 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
match gid with
| None ->
(* This is a forward function: there is one (unnamed) output *)
- ([ None ], [ translate_fwd_ty types_infos sg.output ])
+ ([ None ], [ translate_fwd_ty type_infos sg.output ])
| Some gid ->
(* This is a backward function: there might be several outputs.
The outputs are the borrows inside the regions of the abstractions
@@ -2057,11 +2058,9 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
match branches with
| [] -> raise (Failure "Unreachable")
| [ (variant_id, svl, branch) ]
- (* TODO: always introduce a match, and use micro-passes to turn the
- the match into a let *)
when not
(TypesUtils.ty_is_custom_adt sv.V.sv_ty
- && !Config.always_deconstruct_adts_with_matches) -> (
+ && !Config.always_deconstruct_adts_with_matches) ->
(* There is exactly one branch: no branching.
We can decompose the ADT value with a let-binding, unless
@@ -2069,94 +2068,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
we *ignore* this branch (and go to the next one) if the ADT is a custom
adt, and [always_deconstruct_adts_with_matches] is true.
*)
- let type_id, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in
- let ctx, vars = fresh_vars_for_symbolic_values svl ctx in
- let branch = translate_expression branch ctx in
- match type_id with
- | T.AdtId adt_id ->
- (* Detect if this is an enumeration or not *)
- let tdef = bs_ctx_lookup_llbc_type_decl adt_id ctx in
- let is_enum = type_decl_is_enum tdef in
- (* We deconstruct the ADT with a let-binding in two situations:
- - if the ADT is an enumeration (which must have exactly one branch)
- - if we forbid using field projectors.
-
- We forbid using field projectors in some situations, for example
- if the backend is Coq. See '!Config.dont_use_field_projectors}.
- *)
- let use_let = is_enum || !Config.dont_use_field_projectors in
- if use_let then
- (* Introduce a let binding which expands the ADT *)
- let lvars =
- List.map (fun v -> mk_typed_pattern_from_var v None) vars
- in
- let lv = mk_adt_pattern scrutinee.ty variant_id lvars in
- let monadic = false in
-
- mk_let monadic lv
- (mk_opt_mplace_texpression scrutinee_mplace scrutinee)
- branch
- else
- (* This is not an enumeration: introduce let-bindings for every
- * field.
- * We use the [dest] variable in order not to have to recompute
- * the type of the result of the projection... *)
- let adt_id, type_args =
- match scrutinee.ty with
- | Adt (adt_id, tys) -> (adt_id, tys)
- | _ -> raise (Failure "Unreachable")
- in
- let gen_field_proj (field_id : FieldId.id) (dest : var) :
- texpression =
- let proj_kind = { adt_id; field_id } in
- let qualif = { id = Proj proj_kind; type_args } in
- let proj_e = Qualif qualif in
- let proj_ty = mk_arrow scrutinee.ty dest.ty in
- let proj = { e = proj_e; ty = proj_ty } in
- mk_app proj scrutinee
- in
- let id_var_pairs = FieldId.mapi (fun fid v -> (fid, v)) vars in
- let monadic = false in
- List.fold_right
- (fun (fid, var) e ->
- let field_proj = gen_field_proj fid var in
- mk_let monadic
- (mk_typed_pattern_from_var var None)
- field_proj e)
- id_var_pairs branch
- | T.Tuple ->
- let vars =
- List.map (fun x -> mk_typed_pattern_from_var x None) vars
- in
- let monadic = false in
- mk_let monadic
- (mk_simpl_tuple_pattern vars)
- (mk_opt_mplace_texpression scrutinee_mplace scrutinee)
- branch
- | T.Assumed T.Box ->
- (* There should be exactly one variable *)
- let var =
- match vars with
- | [ v ] -> v
- | _ -> raise (Failure "Unreachable")
- in
- (* We simply introduce an assignment - the box type is the
- * identity when extracted ([box a = a]) *)
- let monadic = false in
- mk_let monadic
- (mk_typed_pattern_from_var var None)
- (mk_opt_mplace_texpression scrutinee_mplace scrutinee)
- branch
- | T.Assumed T.Vec ->
- (* We can't expand vector values: we can access the fields only
- * through the functions provided by the API (note that we don't
- * know how to expand a vector, because it has a variable number
- * of fields!) *)
- raise (Failure "Can't expand a vector value")
- | T.Assumed T.Option ->
- (* We shouldn't get there in the "one-branch" case: options have
- * two variants *)
- raise (Failure "Unreachable"))
+ translate_ExpandAdt_one_branch sv scrutinee scrutinee_mplace
+ variant_id svl branch ctx
| branches ->
let translate_branch (variant_id : T.VariantId.id option)
(svl : V.symbolic_value list) (branch : S.expression) :
@@ -2225,6 +2138,120 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
List.for_all (fun (br : match_branch) -> br.branch.ty = ty) branches);
{ e; ty }
+(* Translate and [ExpandAdt] when there is no branching (i.e., one branch).
+
+ There are several possibilities:
+ - if the ADT is an enumeration, we attempt to deconstruct it with a let-binding:
+ {[
+ let Cons x0 ... xn = y in
+ ...
+ ]}
+
+ - if the ADT is a structure, we attempt to introduce one let-binding per field:
+ {[
+ let x0 = y.f0 in
+ ...
+ let xn = y.fn in
+ ...
+ ]}
+
+ Of course, this is not always possible depending on the backend.
+ Also, recursive structures, and more specifically structures mutually recursive
+ with inductives, are usually not supported. We define such recursive structures
+ as inductives, in which case it is not always possible to use a notation
+ for the field projections.
+*)
+and translate_ExpandAdt_one_branch (sv : V.symbolic_value)
+ (scrutinee : texpression) (scrutinee_mplace : mplace option)
+ (variant_id : variant_id option) (svl : V.symbolic_value list)
+ (branch : S.expression) (ctx : bs_ctx) : texpression =
+ (* TODO: always introduce a match, and use micro-passes to turn the
+ the match into a let? *)
+ let type_id, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in
+ let ctx, vars = fresh_vars_for_symbolic_values svl ctx in
+ let branch = translate_expression branch ctx in
+ match type_id with
+ | T.AdtId adt_id ->
+ (* Detect if this is an enumeration or not *)
+ let tdef = bs_ctx_lookup_llbc_type_decl adt_id ctx in
+ let is_enum = type_decl_is_enum tdef in
+ (* We deconstruct the ADT with a let-binding in two situations:
+ - if the ADT is an enumeration (which must have exactly one branch)
+ - if we forbid using field projectors.
+ *)
+ let is_rec_def =
+ T.TypeDeclId.Set.mem adt_id ctx.type_context.recursive_decls
+ in
+ let use_let =
+ is_enum
+ || !Config.dont_use_field_projectors
+ (* TODO: for now, we don't have field projectors over recursive ADTs in Lean. *)
+ || (!Config.backend = Lean && is_rec_def)
+ in
+ if use_let then
+ (* Introduce a let binding which expands the ADT *)
+ let lvars = List.map (fun v -> mk_typed_pattern_from_var v None) vars in
+ let lv = mk_adt_pattern scrutinee.ty variant_id lvars in
+ let monadic = false in
+
+ mk_let monadic lv
+ (mk_opt_mplace_texpression scrutinee_mplace scrutinee)
+ branch
+ else
+ (* This is not an enumeration: introduce let-bindings for every
+ * field.
+ * We use the [dest] variable in order not to have to recompute
+ * the type of the result of the projection... *)
+ let adt_id, type_args =
+ match scrutinee.ty with
+ | Adt (adt_id, tys) -> (adt_id, tys)
+ | _ -> raise (Failure "Unreachable")
+ in
+ let gen_field_proj (field_id : FieldId.id) (dest : var) : texpression =
+ let proj_kind = { adt_id; field_id } in
+ let qualif = { id = Proj proj_kind; type_args } in
+ let proj_e = Qualif qualif in
+ let proj_ty = mk_arrow scrutinee.ty dest.ty in
+ let proj = { e = proj_e; ty = proj_ty } in
+ mk_app proj scrutinee
+ in
+ let id_var_pairs = FieldId.mapi (fun fid v -> (fid, v)) vars in
+ let monadic = false in
+ List.fold_right
+ (fun (fid, var) e ->
+ let field_proj = gen_field_proj fid var in
+ mk_let monadic (mk_typed_pattern_from_var var None) field_proj e)
+ id_var_pairs branch
+ | T.Tuple ->
+ let vars = List.map (fun x -> mk_typed_pattern_from_var x None) vars in
+ let monadic = false in
+ mk_let monadic
+ (mk_simpl_tuple_pattern vars)
+ (mk_opt_mplace_texpression scrutinee_mplace scrutinee)
+ branch
+ | T.Assumed T.Box ->
+ (* There should be exactly one variable *)
+ let var =
+ match vars with [ v ] -> v | _ -> raise (Failure "Unreachable")
+ in
+ (* We simply introduce an assignment - the box type is the
+ * identity when extracted ([box a = a]) *)
+ let monadic = false in
+ mk_let monadic
+ (mk_typed_pattern_from_var var None)
+ (mk_opt_mplace_texpression scrutinee_mplace scrutinee)
+ branch
+ | T.Assumed T.Vec ->
+ (* We can't expand vector values: we can access the fields only
+ * through the functions provided by the API (note that we don't
+ * know how to expand a vector, because it has a variable number
+ * of fields!) *)
+ raise (Failure "Can't expand a vector value")
+ | T.Assumed T.Option ->
+ (* We shouldn't get there in the "one-branch" case: options have
+ * two variants *)
+ raise (Failure "Unreachable")
+
and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option)
(sv : V.symbolic_value) (v : V.typed_value) (e : S.expression)
(ctx : bs_ctx) : texpression =
@@ -2445,7 +2472,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
List.map
(fun ty ->
assert (
- not (TypesUtils.ty_has_borrows !ctx.type_context.types_infos ty));
+ not (TypesUtils.ty_has_borrows !ctx.type_context.type_infos ty));
(None, ctx_translate_fwd_ty !ctx ty))
tys
in
@@ -2769,7 +2796,7 @@ let translate_type_decls (type_decls : T.type_decl list) : type_decl list =
functions)
*)
let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t)
- (types_infos : TA.type_infos)
+ (type_infos : TA.type_infos)
(functions : (A.fun_id * string option list * A.fun_sig) list) :
fun_sig_named_outputs RegularFunIdNotLoopMap.t =
(* For every function, translate the signatures of:
@@ -2781,7 +2808,7 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t)
=
(* The forward function *)
let fwd_sg =
- translate_fun_sig fun_infos fun_id types_infos sg input_names None
+ translate_fun_sig fun_infos fun_id type_infos sg input_names None
in
let fwd_id = (fun_id, None) in
(* The backward functions *)
@@ -2789,7 +2816,7 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t)
List.map
(fun (rg : T.region_var_group) ->
let tsg =
- translate_fun_sig fun_infos fun_id types_infos sg input_names
+ translate_fun_sig fun_infos fun_id type_infos sg input_names
(Some rg.id)
in
let id = (fun_id, Some rg.id) in