diff options
Diffstat (limited to 'compiler/SymbolicToPure.ml')
-rw-r--r-- | compiler/SymbolicToPure.ml | 261 |
1 files changed, 144 insertions, 117 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 2c103177..5252495d 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -22,7 +22,8 @@ type type_context = { This map is empty when we translate the types, then contains all the translated types when we translate the functions. *) - types_infos : TA.type_infos; (* TODO: rename to type_infos *) + type_infos : TA.type_infos; + recursive_decls : T.TypeDeclId.Set.t; } [@@deriving show] @@ -451,8 +452,8 @@ let translate_type_decl (def : T.type_decl) : type_decl = (preserve all borrows, etc.) *) -let rec translate_fwd_ty (types_infos : TA.type_infos) (ty : 'r T.ty) : ty = - let translate = translate_fwd_ty types_infos in +let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty = + let translate = translate_fwd_ty type_infos in match ty with | T.Adt (type_id, regions, tys) -> ( (* Can't translate types with regions for now *) @@ -463,7 +464,7 @@ let rec translate_fwd_ty (types_infos : TA.type_infos) (ty : 'r T.ty) : ty = match type_id with | AdtId _ | T.Assumed (T.Vec | T.Option) -> (* No general parametricity for now *) - assert (not (List.exists (TypesUtils.ty_has_borrows types_infos) tys)); + assert (not (List.exists (TypesUtils.ty_has_borrows type_infos) tys)); let type_id = match type_id with | AdtId adt_id -> AdtId adt_id @@ -479,7 +480,7 @@ let rec translate_fwd_ty (types_infos : TA.type_infos) (ty : 'r T.ty) : ty = | T.Assumed T.Box -> ( (* We eliminate boxes *) (* No general parametricity for now *) - assert (not (List.exists (TypesUtils.ty_has_borrows types_infos) tys)); + assert (not (List.exists (TypesUtils.ty_has_borrows type_infos) tys)); match t_tys with | [ bty ] -> bty | _ -> @@ -494,17 +495,17 @@ let rec translate_fwd_ty (types_infos : TA.type_infos) (ty : 'r T.ty) : ty = | Integer int_ty -> Integer int_ty | Str -> Str | Array ty -> - assert (not (TypesUtils.ty_has_borrows types_infos ty)); + assert (not (TypesUtils.ty_has_borrows type_infos ty)); Array (translate ty) | Slice ty -> - assert (not (TypesUtils.ty_has_borrows types_infos ty)); + assert (not (TypesUtils.ty_has_borrows type_infos ty)); Slice (translate ty) | Ref (_, rty, _) -> translate rty (** Simply calls [translate_fwd_ty] *) let ctx_translate_fwd_ty (ctx : bs_ctx) (ty : 'r T.ty) : ty = - let types_infos = ctx.type_context.types_infos in - translate_fwd_ty types_infos ty + let type_infos = ctx.type_context.type_infos in + translate_fwd_ty type_infos ty (** Translate a type, when some regions may have ended. @@ -512,9 +513,9 @@ let ctx_translate_fwd_ty (ctx : bs_ctx) (ty : 'r T.ty) : ty = [inside_mut]: are we inside a mutable borrow? *) -let rec translate_back_ty (types_infos : TA.type_infos) +let rec translate_back_ty (type_infos : TA.type_infos) (keep_region : 'r -> bool) (inside_mut : bool) (ty : 'r T.ty) : ty option = - let translate = translate_back_ty types_infos keep_region inside_mut in + let translate = translate_back_ty type_infos keep_region inside_mut in (* A small helper for "leave" types *) let wrap ty = if inside_mut then Some ty else None in match ty with @@ -522,7 +523,7 @@ let rec translate_back_ty (types_infos : TA.type_infos) match type_id with | T.AdtId _ | Assumed (T.Vec | T.Option) -> (* Don't accept ADTs (which are not tuples) with borrows for now *) - assert (not (TypesUtils.ty_has_borrows types_infos ty)); + assert (not (TypesUtils.ty_has_borrows type_infos ty)); let type_id = match type_id with | T.AdtId id -> AdtId id @@ -536,7 +537,7 @@ let rec translate_back_ty (types_infos : TA.type_infos) else None | Assumed T.Box -> ( (* Don't accept ADTs (which are not tuples) with borrows for now *) - assert (not (TypesUtils.ty_has_borrows types_infos ty)); + assert (not (TypesUtils.ty_has_borrows type_infos ty)); (* Eliminate the box *) match tys with | [ bty ] -> translate bty @@ -560,10 +561,10 @@ let rec translate_back_ty (types_infos : TA.type_infos) | Integer int_ty -> wrap (Integer int_ty) | Str -> wrap Str | Array ty -> ( - assert (not (TypesUtils.ty_has_borrows types_infos ty)); + assert (not (TypesUtils.ty_has_borrows type_infos ty)); match translate ty with None -> None | Some ty -> Some (Array ty)) | Slice ty -> ( - assert (not (TypesUtils.ty_has_borrows types_infos ty)); + assert (not (TypesUtils.ty_has_borrows type_infos ty)); match translate ty with None -> None | Some ty -> Some (Slice ty)) | Ref (r, rty, rkind) -> ( match rkind with @@ -574,14 +575,14 @@ let rec translate_back_ty (types_infos : TA.type_infos) (* Dive in, remembering the fact that we are inside a mutable borrow *) let inside_mut = true in if keep_region r then - translate_back_ty types_infos keep_region inside_mut rty + translate_back_ty type_infos keep_region inside_mut rty else None) (** Simply calls [translate_back_ty] *) let ctx_translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool) (inside_mut : bool) (ty : 'r T.ty) : ty option = - let types_infos = ctx.type_context.types_infos in - translate_back_ty types_infos keep_region inside_mut ty + let type_infos = ctx.type_context.type_infos in + translate_back_ty type_infos keep_region inside_mut ty (** List the ancestors of an abstraction *) let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs) @@ -670,7 +671,7 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) of the forward function) which we use as hints to generate pretty names. *) let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) - (fun_id : A.fun_id) (types_infos : TA.type_infos) (sg : A.fun_sig) + (fun_id : A.fun_id) (type_infos : TA.type_infos) (sg : A.fun_sig) (input_names : string option list) (bid : T.RegionGroupId.id option) : fun_sig_named_outputs = (* Retrieve the list of parent backward functions *) @@ -691,7 +692,7 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) * - the current backward function (if it is a backward function) *) let fuel = mk_fuel_input_ty_as_list effect_info in - let fwd_inputs = List.map (translate_fwd_ty types_infos) sg.inputs in + let fwd_inputs = List.map (translate_fwd_ty type_infos) sg.inputs in (* For the backward functions: for now we don't supported nested borrows, * so just check that there aren't parent regions *) assert (T.RegionGroupId.Set.is_empty parents); @@ -706,7 +707,7 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) | T.Var r -> T.RegionVarId.Set.mem r regions in let inside_mut = false in - translate_back_ty types_infos keep_region inside_mut + translate_back_ty type_infos keep_region inside_mut in (* Compute the additinal inputs for the current function, if it is a backward * function *) @@ -762,7 +763,7 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) match gid with | None -> (* This is a forward function: there is one (unnamed) output *) - ([ None ], [ translate_fwd_ty types_infos sg.output ]) + ([ None ], [ translate_fwd_ty type_infos sg.output ]) | Some gid -> (* This is a backward function: there might be several outputs. The outputs are the borrows inside the regions of the abstractions @@ -2057,11 +2058,9 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) match branches with | [] -> raise (Failure "Unreachable") | [ (variant_id, svl, branch) ] - (* TODO: always introduce a match, and use micro-passes to turn the - the match into a let *) when not (TypesUtils.ty_is_custom_adt sv.V.sv_ty - && !Config.always_deconstruct_adts_with_matches) -> ( + && !Config.always_deconstruct_adts_with_matches) -> (* There is exactly one branch: no branching. We can decompose the ADT value with a let-binding, unless @@ -2069,94 +2068,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) we *ignore* this branch (and go to the next one) if the ADT is a custom adt, and [always_deconstruct_adts_with_matches] is true. *) - let type_id, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in - let ctx, vars = fresh_vars_for_symbolic_values svl ctx in - let branch = translate_expression branch ctx in - match type_id with - | T.AdtId adt_id -> - (* Detect if this is an enumeration or not *) - let tdef = bs_ctx_lookup_llbc_type_decl adt_id ctx in - let is_enum = type_decl_is_enum tdef in - (* We deconstruct the ADT with a let-binding in two situations: - - if the ADT is an enumeration (which must have exactly one branch) - - if we forbid using field projectors. - - We forbid using field projectors in some situations, for example - if the backend is Coq. See '!Config.dont_use_field_projectors}. - *) - let use_let = is_enum || !Config.dont_use_field_projectors in - if use_let then - (* Introduce a let binding which expands the ADT *) - let lvars = - List.map (fun v -> mk_typed_pattern_from_var v None) vars - in - let lv = mk_adt_pattern scrutinee.ty variant_id lvars in - let monadic = false in - - mk_let monadic lv - (mk_opt_mplace_texpression scrutinee_mplace scrutinee) - branch - else - (* This is not an enumeration: introduce let-bindings for every - * field. - * We use the [dest] variable in order not to have to recompute - * the type of the result of the projection... *) - let adt_id, type_args = - match scrutinee.ty with - | Adt (adt_id, tys) -> (adt_id, tys) - | _ -> raise (Failure "Unreachable") - in - let gen_field_proj (field_id : FieldId.id) (dest : var) : - texpression = - let proj_kind = { adt_id; field_id } in - let qualif = { id = Proj proj_kind; type_args } in - let proj_e = Qualif qualif in - let proj_ty = mk_arrow scrutinee.ty dest.ty in - let proj = { e = proj_e; ty = proj_ty } in - mk_app proj scrutinee - in - let id_var_pairs = FieldId.mapi (fun fid v -> (fid, v)) vars in - let monadic = false in - List.fold_right - (fun (fid, var) e -> - let field_proj = gen_field_proj fid var in - mk_let monadic - (mk_typed_pattern_from_var var None) - field_proj e) - id_var_pairs branch - | T.Tuple -> - let vars = - List.map (fun x -> mk_typed_pattern_from_var x None) vars - in - let monadic = false in - mk_let monadic - (mk_simpl_tuple_pattern vars) - (mk_opt_mplace_texpression scrutinee_mplace scrutinee) - branch - | T.Assumed T.Box -> - (* There should be exactly one variable *) - let var = - match vars with - | [ v ] -> v - | _ -> raise (Failure "Unreachable") - in - (* We simply introduce an assignment - the box type is the - * identity when extracted ([box a = a]) *) - let monadic = false in - mk_let monadic - (mk_typed_pattern_from_var var None) - (mk_opt_mplace_texpression scrutinee_mplace scrutinee) - branch - | T.Assumed T.Vec -> - (* We can't expand vector values: we can access the fields only - * through the functions provided by the API (note that we don't - * know how to expand a vector, because it has a variable number - * of fields!) *) - raise (Failure "Can't expand a vector value") - | T.Assumed T.Option -> - (* We shouldn't get there in the "one-branch" case: options have - * two variants *) - raise (Failure "Unreachable")) + translate_ExpandAdt_one_branch sv scrutinee scrutinee_mplace + variant_id svl branch ctx | branches -> let translate_branch (variant_id : T.VariantId.id option) (svl : V.symbolic_value list) (branch : S.expression) : @@ -2225,6 +2138,120 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) List.for_all (fun (br : match_branch) -> br.branch.ty = ty) branches); { e; ty } +(* Translate and [ExpandAdt] when there is no branching (i.e., one branch). + + There are several possibilities: + - if the ADT is an enumeration, we attempt to deconstruct it with a let-binding: + {[ + let Cons x0 ... xn = y in + ... + ]} + + - if the ADT is a structure, we attempt to introduce one let-binding per field: + {[ + let x0 = y.f0 in + ... + let xn = y.fn in + ... + ]} + + Of course, this is not always possible depending on the backend. + Also, recursive structures, and more specifically structures mutually recursive + with inductives, are usually not supported. We define such recursive structures + as inductives, in which case it is not always possible to use a notation + for the field projections. +*) +and translate_ExpandAdt_one_branch (sv : V.symbolic_value) + (scrutinee : texpression) (scrutinee_mplace : mplace option) + (variant_id : variant_id option) (svl : V.symbolic_value list) + (branch : S.expression) (ctx : bs_ctx) : texpression = + (* TODO: always introduce a match, and use micro-passes to turn the + the match into a let? *) + let type_id, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in + let ctx, vars = fresh_vars_for_symbolic_values svl ctx in + let branch = translate_expression branch ctx in + match type_id with + | T.AdtId adt_id -> + (* Detect if this is an enumeration or not *) + let tdef = bs_ctx_lookup_llbc_type_decl adt_id ctx in + let is_enum = type_decl_is_enum tdef in + (* We deconstruct the ADT with a let-binding in two situations: + - if the ADT is an enumeration (which must have exactly one branch) + - if we forbid using field projectors. + *) + let is_rec_def = + T.TypeDeclId.Set.mem adt_id ctx.type_context.recursive_decls + in + let use_let = + is_enum + || !Config.dont_use_field_projectors + (* TODO: for now, we don't have field projectors over recursive ADTs in Lean. *) + || (!Config.backend = Lean && is_rec_def) + in + if use_let then + (* Introduce a let binding which expands the ADT *) + let lvars = List.map (fun v -> mk_typed_pattern_from_var v None) vars in + let lv = mk_adt_pattern scrutinee.ty variant_id lvars in + let monadic = false in + + mk_let monadic lv + (mk_opt_mplace_texpression scrutinee_mplace scrutinee) + branch + else + (* This is not an enumeration: introduce let-bindings for every + * field. + * We use the [dest] variable in order not to have to recompute + * the type of the result of the projection... *) + let adt_id, type_args = + match scrutinee.ty with + | Adt (adt_id, tys) -> (adt_id, tys) + | _ -> raise (Failure "Unreachable") + in + let gen_field_proj (field_id : FieldId.id) (dest : var) : texpression = + let proj_kind = { adt_id; field_id } in + let qualif = { id = Proj proj_kind; type_args } in + let proj_e = Qualif qualif in + let proj_ty = mk_arrow scrutinee.ty dest.ty in + let proj = { e = proj_e; ty = proj_ty } in + mk_app proj scrutinee + in + let id_var_pairs = FieldId.mapi (fun fid v -> (fid, v)) vars in + let monadic = false in + List.fold_right + (fun (fid, var) e -> + let field_proj = gen_field_proj fid var in + mk_let monadic (mk_typed_pattern_from_var var None) field_proj e) + id_var_pairs branch + | T.Tuple -> + let vars = List.map (fun x -> mk_typed_pattern_from_var x None) vars in + let monadic = false in + mk_let monadic + (mk_simpl_tuple_pattern vars) + (mk_opt_mplace_texpression scrutinee_mplace scrutinee) + branch + | T.Assumed T.Box -> + (* There should be exactly one variable *) + let var = + match vars with [ v ] -> v | _ -> raise (Failure "Unreachable") + in + (* We simply introduce an assignment - the box type is the + * identity when extracted ([box a = a]) *) + let monadic = false in + mk_let monadic + (mk_typed_pattern_from_var var None) + (mk_opt_mplace_texpression scrutinee_mplace scrutinee) + branch + | T.Assumed T.Vec -> + (* We can't expand vector values: we can access the fields only + * through the functions provided by the API (note that we don't + * know how to expand a vector, because it has a variable number + * of fields!) *) + raise (Failure "Can't expand a vector value") + | T.Assumed T.Option -> + (* We shouldn't get there in the "one-branch" case: options have + * two variants *) + raise (Failure "Unreachable") + and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option) (sv : V.symbolic_value) (v : V.typed_value) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -2445,7 +2472,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = List.map (fun ty -> assert ( - not (TypesUtils.ty_has_borrows !ctx.type_context.types_infos ty)); + not (TypesUtils.ty_has_borrows !ctx.type_context.type_infos ty)); (None, ctx_translate_fwd_ty !ctx ty)) tys in @@ -2769,7 +2796,7 @@ let translate_type_decls (type_decls : T.type_decl list) : type_decl list = functions) *) let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) - (types_infos : TA.type_infos) + (type_infos : TA.type_infos) (functions : (A.fun_id * string option list * A.fun_sig) list) : fun_sig_named_outputs RegularFunIdNotLoopMap.t = (* For every function, translate the signatures of: @@ -2781,7 +2808,7 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) = (* The forward function *) let fwd_sg = - translate_fun_sig fun_infos fun_id types_infos sg input_names None + translate_fun_sig fun_infos fun_id type_infos sg input_names None in let fwd_id = (fun_id, None) in (* The backward functions *) @@ -2789,7 +2816,7 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) List.map (fun (rg : T.region_var_group) -> let tsg = - translate_fun_sig fun_infos fun_id types_infos sg input_names + translate_fun_sig fun_infos fun_id type_infos sg input_names (Some rg.id) in let id = (fun_id, Some rg.id) in |