From eb05c2e3b63377c323c33c1296495baa9357596a Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 5 Dec 2023 17:50:38 +0100 Subject: Remove the type sv_kind ("symbolic value kind") --- compiler/Interpreter.ml | 2 +- compiler/InterpreterBorrows.ml | 60 ++++++++++------------------------- compiler/InterpreterExpansion.ml | 41 +++++++++++------------- compiler/InterpreterExpressions.ml | 14 +++----- compiler/InterpreterLoopsMatchCtxs.ml | 24 +++++--------- compiler/InterpreterStatements.ml | 4 +-- compiler/InterpreterUtils.ml | 32 +++++++------------ compiler/Values.ml | 44 ------------------------- 8 files changed, 63 insertions(+), 158 deletions(-) (limited to 'compiler') diff --git a/compiler/Interpreter.ml b/compiler/Interpreter.ml index 4ecafd31..76432faa 100644 --- a/compiler/Interpreter.ml +++ b/compiler/Interpreter.ml @@ -206,7 +206,7 @@ let initialize_symbolic_context_for_fun (ctx : decls_ctx) (fdef : fun_decl) : in (* Create fresh symbolic values for the inputs *) let input_svs = - List.map (fun ty -> mk_fresh_symbolic_value SynthInput ty) inst_sg.inputs + List.map (fun ty -> mk_fresh_symbolic_value ty) inst_sg.inputs in (* Initialize the abstractions as empty (i.e., with no avalues) abstractions *) let call_id = fresh_fun_call_id () in diff --git a/compiler/InterpreterBorrows.ml b/compiler/InterpreterBorrows.ml index 6a7ac095..19b9fd3b 100644 --- a/compiler/InterpreterBorrows.ml +++ b/compiler/InterpreterBorrows.ml @@ -444,13 +444,6 @@ let give_back_symbolic_value (_config : config) (proj_regions : RegionId.Set.t) (ctx : eval_ctx) : eval_ctx = (* Sanity checks *) assert (sv.sv_id <> nsv.sv_id && ty_is_rty proj_ty); - (match nsv.sv_kind with - | SynthInputGivenBack | SynthRetGivenBack | FunCallGivenBack | LoopGivenBack - -> - () - | FunCallRet | SynthInput | Global | KindConstGeneric | LoopOutput | LoopJoin - | Aggregate | ConstGeneric | TraitConst -> - raise (Failure "Unreachable")); (* Store the given-back value as a meta-value for synthesis purposes *) let mv = nsv in (* Substitution function, to replace the borrow projectors over symbolic values *) @@ -459,31 +452,20 @@ let give_back_symbolic_value (_config : config) (proj_regions : RegionId.Set.t) let _ = raise Utils.Unimplemented in (* Compute the projection over the given back value *) let child_proj = - match nsv.sv_kind with - | SynthRetGivenBack -> - (* The given back value comes from the return value of the function - we are currently synthesizing (as it is given back, it means - we ended one of the regions appearing in the signature: we are - currently synthesizing one of the backward functions). - - As we don't allow borrow overwrites on returned value, we can - (and MUST) forget the borrows *) - AIgnoredProjBorrows - | FunCallGivenBack -> - (* TODO: there is something wrong here. - Consider this: - {[ - abs0 {'a} { AProjLoans (s0 : &'a mut T) [] } - abs1 {'b} { AProjBorrows (s0 : &'a mut T <: &'b mut T) } - ]} - - Upon ending abs1, we give back some fresh symbolic value [s1], - that we reinsert where the loan for [s0] is. However, the mutable - borrow in the type [&'a mut T] was ended: we give back a value of - type [T]! We thus *mustn't* introduce a projector here. - *) - AProjBorrows (nsv, sv.sv_ty) - | _ -> raise (Failure "Unreachable") + (* TODO: there is something wrong here. + Consider this: + {[ + abs0 {'a} { AProjLoans (s0 : &'a mut T) [] } + abs1 {'b} { AProjBorrows (s0 : &'a mut T <: &'b mut T) } + ]} + + Upon ending abs1, we give back some fresh symbolic value [s1], + that we reinsert where the loan for [s0] is. However, the mutable + borrow in the type [&'a mut T] was ended: we give back a value of + type [T]! We thus *mustn't* introduce a projector here. + *) + (* AProjBorrows (nsv, sv.sv_ty) *) + raise (Failure "TODO") in AProjLoans (sv, (mv, child_proj) :: local_given_back) in @@ -739,17 +721,7 @@ let reborrow_shared (original_bid : BorrowId.id) (new_bid : BorrowId.id) *) let convert_avalue_to_given_back_value (abs_kind : abs_kind) (av : typed_avalue) : symbolic_value = - let sv_kind = - match abs_kind with - | FunCall _ -> FunCallGivenBack - | SynthRet _ -> SynthRetGivenBack - | SynthInput _ -> SynthInputGivenBack - | Loop _ -> LoopGivenBack - | Identity -> - (* Identity abstractions give back nothing *) - raise (Failure "Unreachable") - in - mk_fresh_symbolic_value sv_kind av.ty + mk_fresh_symbolic_value av.ty (** Auxiliary function: see {!end_borrow_aux}. @@ -1239,7 +1211,7 @@ and end_abstraction_borrows (config : config) (chain : borrow_or_abs_ids) ("end_abstraction_borrows: found aproj borrows: " ^ aproj_to_string ctx (AProjBorrows (sv, proj_ty)))); (* Generate a fresh symbolic value *) - let nsv = mk_fresh_symbolic_value FunCallGivenBack proj_ty in + let nsv = mk_fresh_symbolic_value proj_ty in (* Replace the proj_borrows - there should be exactly one *) let ended_borrow = AEndedProjBorrows nsv in let ctx = update_aproj_borrows abs.abs_id sv ended_borrow ctx in diff --git a/compiler/InterpreterExpansion.ml b/compiler/InterpreterExpansion.ml index d7f5fcd5..bbf4d9d5 100644 --- a/compiler/InterpreterExpansion.ml +++ b/compiler/InterpreterExpansion.ml @@ -209,8 +209,8 @@ let apply_symbolic_expansion_non_borrow (config : config) doesn't allow the expansion of enumerations *containing several variants*. *) let compute_expanded_symbolic_non_assumed_adt_value (expand_enumerations : bool) - (kind : sv_kind) (def_id : TypeDeclId.id) (generics : generic_args) - (ctx : eval_ctx) : symbolic_expansion list = + (def_id : TypeDeclId.id) (generics : generic_args) (ctx : eval_ctx) : + symbolic_expansion list = (* Lookup the definition and check if it is an enumeration with several * variants *) let def = ctx_lookup_type_decl ctx def_id in @@ -227,7 +227,7 @@ let compute_expanded_symbolic_non_assumed_adt_value (expand_enumerations : bool) let initialize ((variant_id, field_types) : VariantId.id option * rty list) : symbolic_expansion = let field_values = - List.map (fun (ty : rty) -> mk_fresh_symbolic_value kind ty) field_types + List.map (fun (ty : rty) -> mk_fresh_symbolic_value ty) field_types in let see = SeAdt (variant_id, field_values) in see @@ -235,20 +235,19 @@ let compute_expanded_symbolic_non_assumed_adt_value (expand_enumerations : bool) (* Initialize all the expanded values of all the variants *) List.map initialize variants_fields_types -let compute_expanded_symbolic_tuple_value (kind : sv_kind) - (field_types : rty list) : symbolic_expansion = +let compute_expanded_symbolic_tuple_value (field_types : rty list) : + symbolic_expansion = (* Generate the field values *) let field_values = - List.map (fun sv_ty -> mk_fresh_symbolic_value kind sv_ty) field_types + List.map (fun sv_ty -> mk_fresh_symbolic_value sv_ty) field_types in let variant_id = None in let see = SeAdt (variant_id, field_values) in see -let compute_expanded_symbolic_box_value (kind : sv_kind) (boxed_ty : rty) : - symbolic_expansion = +let compute_expanded_symbolic_box_value (boxed_ty : rty) : symbolic_expansion = (* Introduce a fresh symbolic value *) - let boxed_value = mk_fresh_symbolic_value kind boxed_ty in + let boxed_value = mk_fresh_symbolic_value boxed_ty in let see = SeAdt (None, [ boxed_value ]) in see @@ -262,16 +261,15 @@ let compute_expanded_symbolic_box_value (kind : sv_kind) (boxed_ty : rty) : doesn't allow the expansion of enumerations *containing several variants*. *) let compute_expanded_symbolic_adt_value (expand_enumerations : bool) - (kind : sv_kind) (adt_id : type_id) (generics : generic_args) - (ctx : eval_ctx) : symbolic_expansion list = + (adt_id : type_id) (generics : generic_args) (ctx : eval_ctx) : + symbolic_expansion list = match (adt_id, generics.regions, generics.types) with | TAdtId def_id, _, _ -> - compute_expanded_symbolic_non_assumed_adt_value expand_enumerations kind - def_id generics ctx - | TTuple, [], _ -> - [ compute_expanded_symbolic_tuple_value kind generics.types ] + compute_expanded_symbolic_non_assumed_adt_value expand_enumerations def_id + generics ctx + | TTuple, [], _ -> [ compute_expanded_symbolic_tuple_value generics.types ] | TAssumed TBox, [], [ boxed_ty ] -> - [ compute_expanded_symbolic_box_value kind boxed_ty ] + [ compute_expanded_symbolic_box_value boxed_ty ] | _ -> raise (Failure "compute_expanded_symbolic_adt_value: unexpected combination") @@ -313,7 +311,7 @@ let expand_symbolic_value_shared_borrow (config : config) else None in (* The fresh symbolic value for the shared value *) - let shared_sv = mk_fresh_symbolic_value original_sv.sv_kind ref_ty in + let shared_sv = mk_fresh_symbolic_value ref_ty in (* Visitor to replace the projectors on borrows *) let obj = object (self) @@ -403,7 +401,7 @@ let expand_symbolic_value_borrow (config : config) | RMut -> (* Simple case: simply create a fresh symbolic value and a fresh * borrow id *) - let sv = mk_fresh_symbolic_value original_sv.sv_kind ref_ty in + let sv = mk_fresh_symbolic_value ref_ty in let bid = fresh_borrow_id () in let see = SeMutRef (bid, sv) in (* Expand the symbolic values - we simply perform a substitution (and @@ -524,8 +522,8 @@ let expand_symbolic_value_no_branching (config : config) (sv : symbolic_value) (* Compute the expanded value *) let allow_branching = false in let seel = - compute_expanded_symbolic_adt_value allow_branching sv.sv_kind adt_id - generics ctx + compute_expanded_symbolic_adt_value allow_branching adt_id generics + ctx in (* There should be exacly one branch *) let see = Collections.List.to_cons_nil seel in @@ -581,8 +579,7 @@ let expand_symbolic_adt (config : config) (sv : symbolic_value) let allow_branching = true in (* Compute the expanded value *) let seel = - compute_expanded_symbolic_adt_value allow_branching sv.sv_kind adt_id - generics ctx + compute_expanded_symbolic_adt_value allow_branching adt_id generics ctx in (* Apply *) let seel = List.map (fun see -> (Some see, cf_branches)) seel in diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml index af545fb9..9f117933 100644 --- a/compiler/InterpreterExpressions.ml +++ b/compiler/InterpreterExpressions.ml @@ -283,7 +283,7 @@ let eval_operand_no_reorganize (config : config) (op : operand) List.find (fun (name, _) -> name = const_name) trait_decl.consts in (* Introduce a fresh symbolic value *) - let v = mk_fresh_symbolic_typed_value TraitConst ty in + let v = mk_fresh_symbolic_typed_value ty in (* Continue the evaluation *) let e = cf v ctx in (* We have to wrap the generated expression *) @@ -315,7 +315,7 @@ let eval_operand_no_reorganize (config : config) (op : operand) copy_value allow_adt_copy config ctx cv | SymbolicMode -> (* We use the looked up value only for its type *) - let v = mk_fresh_symbolic_typed_value KindConstGeneric cv.ty in + let v = mk_fresh_symbolic_typed_value cv.ty in (ctx, v) in (* Continue *) @@ -464,9 +464,7 @@ let eval_unary_op_symbolic (config : config) (unop : unop) (op : operand) | Cast (CastInteger (_, tgt_ty)), _ -> TLiteral (TInteger tgt_ty) | _ -> raise (Failure "Invalid input for unop") in - let res_sv = - { sv_kind = FunCallRet; sv_id = res_sv_id; sv_ty = res_sv_ty } - in + let res_sv = { sv_id = res_sv_id; sv_ty = res_sv_ty } in (* Call the continuation *) let expr = cf (Ok (mk_typed_value_from_symbolic_value res_sv)) ctx in (* Synthesize the symbolic AST *) @@ -603,9 +601,7 @@ let eval_binary_op_symbolic (config : config) (binop : binop) (op1 : operand) | Ne | Eq -> raise (Failure "Unreachable")) | _ -> raise (Failure "Invalid inputs for binop") in - let res_sv = - { sv_kind = FunCallRet; sv_id = res_sv_id; sv_ty = res_sv_ty } - in + let res_sv = { sv_id = res_sv_id; sv_ty = res_sv_ty } in (* Call the continuattion *) let v = mk_typed_value_from_symbolic_value res_sv in let expr = cf (Ok v) ctx in @@ -769,7 +765,7 @@ let eval_rvalue_aggregate (config : config) (aggregate_kind : aggregate_kind) array we introduce here might be duplicated in the generated code: by introducing a symbolic value we introduce a let-binding in the generated code. *) - let saggregated = mk_fresh_symbolic_typed_value Aggregate ty in + let saggregated = mk_fresh_symbolic_typed_value ty in (* Call the continuation *) match cf saggregated ctx with | None -> None diff --git a/compiler/InterpreterLoopsMatchCtxs.ml b/compiler/InterpreterLoopsMatchCtxs.ml index c21dab71..90559c29 100644 --- a/compiler/InterpreterLoopsMatchCtxs.ml +++ b/compiler/InterpreterLoopsMatchCtxs.ml @@ -389,7 +389,7 @@ module MakeJoinMatcher (S : MatchJoinState) : PrimMatcher = struct let match_distinct_literals (ty : ety) (_ : literal) (_ : literal) : typed_value = - mk_fresh_symbolic_typed_value_from_no_regions_ty LoopJoin ty + mk_fresh_symbolic_typed_value_from_no_regions_ty ty let match_distinct_adts (ty : ety) (adt0 : adt_value) (adt1 : adt_value) : typed_value = @@ -418,7 +418,7 @@ module MakeJoinMatcher (S : MatchJoinState) : PrimMatcher = struct check_loans false adt1.field_values; (* No borrows, no loans: we can introduce a symbolic value *) - mk_fresh_symbolic_typed_value_from_no_regions_ty LoopJoin ty + mk_fresh_symbolic_typed_value_from_no_regions_ty ty let match_shared_borrows _ (ty : ety) (bid0 : borrow_id) (bid1 : borrow_id) : borrow_id = @@ -435,9 +435,7 @@ module MakeJoinMatcher (S : MatchJoinState) : PrimMatcher = struct (* Generate a fresh symbolic value for the shared value *) let _, bv_ty, kind = ty_as_ref ty in - let sv = - mk_fresh_symbolic_typed_value_from_no_regions_ty LoopJoin bv_ty - in + let sv = mk_fresh_symbolic_typed_value_from_no_regions_ty bv_ty in let borrow_ty = mk_ref_ty (RFVar rid) bv_ty kind in @@ -582,9 +580,7 @@ module MakeJoinMatcher (S : MatchJoinState) : PrimMatcher = struct (* Generate a fresh symbolic value for the borrowed value *) let _, bv_ty, kind = ty_as_ref ty in - let sv = - mk_fresh_symbolic_typed_value_from_no_regions_ty LoopJoin bv_ty - in + let sv = mk_fresh_symbolic_typed_value_from_no_regions_ty bv_ty in let borrow_ty = mk_ref_ty (RFVar rid) bv_ty kind in @@ -664,7 +660,7 @@ module MakeJoinMatcher (S : MatchJoinState) : PrimMatcher = struct borrows *) assert (not (ty_has_borrows S.ctx.type_context.type_infos sv0.sv_ty)); (* We simply introduce a fresh symbolic value *) - mk_fresh_symbolic_value LoopJoin sv0.sv_ty) + mk_fresh_symbolic_value sv0.sv_ty) let match_symbolic_with_other (left : bool) (sv : symbolic_value) (v : typed_value) : typed_value = @@ -685,7 +681,7 @@ module MakeJoinMatcher (S : MatchJoinState) : PrimMatcher = struct if value_is_left then raise (ValueMatchFailure (LoanInLeft id)) else raise (ValueMatchFailure (LoanInRight id))); (* Return a fresh symbolic value *) - mk_fresh_symbolic_typed_value LoopJoin sv.sv_ty + mk_fresh_symbolic_typed_value sv.sv_ty let match_bottom_with_other (left : bool) (v : typed_value) : typed_value = (* If there are outer loans in the non-bottom value, raise an exception. @@ -834,7 +830,7 @@ struct let match_distinct_literals (ty : ety) (_ : literal) (_ : literal) : typed_value = - mk_fresh_symbolic_typed_value_from_no_regions_ty LoopJoin ty + mk_fresh_symbolic_typed_value_from_no_regions_ty ty let match_distinct_adts (_ty : ety) (_adt0 : adt_value) (_adt1 : adt_value) : typed_value = @@ -904,11 +900,7 @@ struct GetSetSid.match_e "match_symbolic_values: ids: " S.sid_map id0 id1 in let sv_ty = match_rtys sv0.sv_ty sv1.sv_ty in - let sv_kind = - if sv0.sv_kind = sv1.sv_kind then sv0.sv_kind - else raise (Distinct "match_symbolic_values: sv_kind") - in - let sv = { sv_id; sv_ty; sv_kind } in + let sv = { sv_id; sv_ty } in sv else ( (* Check: fixed values are fixed *) diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml index 437b358a..66b8492a 100644 --- a/compiler/InterpreterStatements.ml +++ b/compiler/InterpreterStatements.ml @@ -1008,7 +1008,7 @@ and eval_global (config : config) (dest : place) (gid : GlobalDeclId.id) : (* Generate a fresh symbolic value. In the translation, this fresh symbolic value will be * defined as equal to the value of the global (see {!S.synthesize_global_eval}). *) assert (ty_no_regions global.ty); - let sval = mk_fresh_symbolic_value Global global.ty in + let sval = mk_fresh_symbolic_value global.ty in let cc = assign_to_place config (mk_typed_value_from_symbolic_value sval) dest in @@ -1312,7 +1312,7 @@ and eval_function_call_symbolic_from_inst_sig (config : config) (* Generate a fresh symbolic value for the return value *) let ret_sv_ty = inst_sg.output in - let ret_spc = mk_fresh_symbolic_value FunCallRet ret_sv_ty in + let ret_spc = mk_fresh_symbolic_value ret_sv_ty in let ret_value = mk_typed_value_from_symbolic_value ret_spc in let ret_av regions = mk_aproj_loans_value_from_symbolic_value regions ret_spc diff --git a/compiler/InterpreterUtils.ml b/compiler/InterpreterUtils.ml index d3f8f4fa..e04a6b90 100644 --- a/compiler/InterpreterUtils.ml +++ b/compiler/InterpreterUtils.ml @@ -74,31 +74,29 @@ let mk_place_from_var_id (var_id : VarId.id) : place = { var_id; projection = [] } (** Create a fresh symbolic value *) -let mk_fresh_symbolic_value (sv_kind : sv_kind) (ty : ty) : symbolic_value = +let mk_fresh_symbolic_value (ty : ty) : symbolic_value = (* Sanity check *) assert (ty_is_rty ty); let sv_id = fresh_symbolic_value_id () in - let svalue = { sv_kind; sv_id; sv_ty = ty } in + let svalue = { sv_id; sv_ty = ty } in svalue -let mk_fresh_symbolic_value_from_no_regions_ty (sv_kind : sv_kind) (ty : ty) : - symbolic_value = +let mk_fresh_symbolic_value_from_no_regions_ty (ty : ty) : symbolic_value = assert (ty_no_regions ty); - mk_fresh_symbolic_value sv_kind ty + mk_fresh_symbolic_value ty (** Create a fresh symbolic value *) -let mk_fresh_symbolic_typed_value (sv_kind : sv_kind) (rty : ty) : typed_value = +let mk_fresh_symbolic_typed_value (rty : ty) : typed_value = assert (ty_is_rty rty); let ty = Substitute.erase_regions rty in (* Generate the fresh a symbolic value *) - let value = mk_fresh_symbolic_value sv_kind rty in + let value = mk_fresh_symbolic_value rty in let value = VSymbolic value in { value; ty } -let mk_fresh_symbolic_typed_value_from_no_regions_ty (sv_kind : sv_kind) - (ty : ty) : typed_value = +let mk_fresh_symbolic_typed_value_from_no_regions_ty (ty : ty) : typed_value = assert (ty_no_regions ty); - mk_fresh_symbolic_typed_value sv_kind ty + mk_fresh_symbolic_typed_value ty (** Create a typed value from a symbolic value. *) let mk_typed_value_from_symbolic_value (svalue : symbolic_value) : typed_value = @@ -267,15 +265,9 @@ let value_has_ret_symbolic_value_with_borrow_under_mut (ctx : eval_ctx) inherit [_] iter_typed_value method! visit_symbolic_value _ s = - match s.sv_kind with - | FunCallRet | LoopOutput | LoopJoin -> - if ty_has_borrow_under_mut ctx.type_context.type_infos s.sv_ty then - raise Found - else () - | SynthInput | SynthInputGivenBack | FunCallGivenBack - | SynthRetGivenBack | Global | KindConstGeneric | LoopGivenBack - | Aggregate | ConstGeneric | TraitConst -> - () + if ty_has_borrow_under_mut ctx.type_context.type_infos s.sv_ty then + raise Found + else () end in (* We use exceptions *) @@ -430,7 +422,7 @@ let initialize_eval_context (ctx : decls_ctx) (List.map (fun (cg : const_generic_var) -> let ty = TLiteral cg.ty in - let cv = mk_fresh_symbolic_typed_value ConstGeneric ty in + let cv = mk_fresh_symbolic_typed_value ty in (cg.index, cv)) const_generic_vars) in diff --git a/compiler/Values.ml b/compiler/Values.ml index 60cbcc8b..5473ce3e 100644 --- a/compiler/Values.ml +++ b/compiler/Values.ml @@ -15,47 +15,6 @@ module LoopId = IdGen () type symbolic_value_id = SymbolicValueId.id [@@deriving show, ord] type symbolic_value_id_set = SymbolicValueId.Set.t [@@deriving show, ord] type loop_id = LoopId.id [@@deriving show, ord] - -(** The kind of a symbolic value, which precises how the value was generated. - - TODO: remove? This is not useful anymore. - *) -type sv_kind = - | FunCallRet (** The value is the return value of a function call *) - | FunCallGivenBack - (** The value is a borrowed value given back by an abstraction - (happens when giving a borrow to a function: when the abstraction - introduced to model the function call ends we reintroduce a symbolic - value in the context for the value modified by the abstraction through - the borrow). - *) - | SynthInput - (** The value is an input value of the function whose body we are - currently synthesizing. - *) - | SynthRetGivenBack - (** The value is a borrowed value that the function whose body we are - synthesizing returned, and which was given back because we ended - one of the lifetimes of this function (we do this to synthesize - the backward functions). - *) - | SynthInputGivenBack - (** The value was given back upon ending one of the input abstractions *) - | Global (** The value is a global *) - | KindConstGeneric (** The value is a const generic *) - | LoopOutput (** The output of a loop (seen as a forward computation) *) - | LoopGivenBack - (** A value given back by a loop (when ending abstractions while going backwards) *) - | LoopJoin - (** The result of a loop join (when computing loop fixed points) *) - | Aggregate - (** A symbolic value we introduce in place of an aggregate value *) - | ConstGeneric - (** A symbolic value we introduce when using a const generic as a value *) - | TraitConst - (** A symbolic value we introduce when evaluating a trait associated constant *) -[@@deriving show, ord] - type borrow_id = BorrowId.id [@@deriving show, ord] type borrow_id_set = BorrowId.Set.t [@@deriving show, ord] type loan_id = BorrowId.id [@@deriving show, ord] @@ -65,7 +24,6 @@ type loan_id_set = BorrowId.Set.t [@@deriving show, ord] class ['self] iter_typed_value_base = object (self : 'self) inherit [_] iter_ty - method visit_sv_kind : 'env -> sv_kind -> unit = fun _ _ -> () method visit_symbolic_value_id : 'env -> symbolic_value_id -> unit = fun _ _ -> () @@ -85,7 +43,6 @@ class ['self] iter_typed_value_base = class ['self] map_typed_value_base = object (self : 'self) inherit [_] map_ty - method visit_sv_kind : 'env -> sv_kind -> sv_kind = fun _ x -> x method visit_symbolic_value_id : 'env -> symbolic_value_id -> symbolic_value_id = @@ -104,7 +61,6 @@ class ['self] map_typed_value_base = (** A symbolic value *) type symbolic_value = { - sv_kind : sv_kind; sv_id : symbolic_value_id; sv_ty : ty; (** This should be a type with regions *) } -- cgit v1.2.3 From 0209fee47a11b371d258fe02b8cc59b325de21d6 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 7 Dec 2023 12:07:39 +0100 Subject: Use a better syntax when extracting tuple types (structures with unnamed fields) --- compiler/Config.ml | 18 +++- compiler/Extract.ml | 74 ++++++++----- compiler/ExtractBase.ml | 17 +-- compiler/ExtractTypes.ml | 230 ++++++++++++++++++++++++----------------- compiler/InterpreterBorrows.ml | 19 ++-- compiler/PureMicroPasses.ml | 82 ++++++++------- compiler/PureUtils.ml | 11 ++ compiler/SymbolicToPure.ml | 30 ++++-- compiler/TypesAnalysis.ml | 47 +++++++-- compiler/TypesUtils.ml | 18 ++++ 10 files changed, 347 insertions(+), 199 deletions(-) (limited to 'compiler') diff --git a/compiler/Config.ml b/compiler/Config.ml index 364ef748..b09544ba 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -338,7 +338,7 @@ let type_check_pure_code = ref false as far as possible while leaving "holes" in the generated code? *) let fail_hard = ref true -(** if true, add the type name as a prefix +(** If true, add the type name as a prefix to the variant names. Ex.: In Rust: @@ -364,3 +364,19 @@ let fail_hard = ref true ]} *) let variant_concatenate_type_name = ref true + +(** If true, extract the structures with unnamed fields as tuples. + + ex.: + {[ + // Rust + struct Foo(u32) + + // OCaml + type Foo = (u32) + ]} + *) +let use_tuple_structs = ref true + +let backend_has_tuple_projectors () = + match !backend with Lean -> true | Coq | FStar | HOL4 -> false diff --git a/compiler/Extract.ml b/compiler/Extract.ml index e48e6ae6..85bdd929 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -111,7 +111,7 @@ let extract_global_decl_register_names (ctx : extraction_ctx) context updated with new bindings. [is_single_pat]: are we extracting a single pattern (a pattern for a let-binding - or a lambda). + or a lambda)? TODO: we don't need something very generic anymore (some definitions used to be polymorphic). @@ -121,38 +121,53 @@ let extract_adt_g_value (fmt : F.formatter) (ctx : extraction_ctx) (is_single_pat : bool) (inside : bool) (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) : extraction_ctx = + let extract_as_tuple () = + (* This is very annoying: in Coq, we can't write [()] for the value of + type [unit], we have to write [tt]. *) + if !backend = Coq && field_values = [] then ( + F.pp_print_string fmt "tt"; + ctx) + else + (* If there is exactly one value, we don't print the parentheses *) + let lb, rb = + if List.length field_values = 1 then ("", "") else ("(", ")") + in + F.pp_print_string fmt lb; + let ctx = + Collections.List.fold_left_link + (fun () -> + F.pp_print_string fmt ","; + F.pp_print_space fmt ()) + (fun ctx v -> extract_value ctx false v) + ctx field_values + in + F.pp_print_string fmt rb; + ctx + in match ty with | TAdt (TTuple, generics) -> (* Tuple *) (* For now, we only support fully applied tuple constructors *) assert (List.length generics.types = List.length field_values); assert (generics.const_generics = [] && generics.trait_refs = []); - (* This is very annoying: in Coq, we can't write [()] for the value of - type [unit], we have to write [tt]. *) - if !backend = Coq && field_values = [] then ( - F.pp_print_string fmt "tt"; - ctx) - else ( - F.pp_print_string fmt "("; - let ctx = - Collections.List.fold_left_link - (fun () -> - F.pp_print_string fmt ","; - F.pp_print_space fmt ()) - (fun ctx v -> extract_value ctx false v) - ctx field_values - in - F.pp_print_string fmt ")"; - ctx) + extract_as_tuple () | TAdt (adt_id, _) -> (* "Regular" ADT *) - - (* If we are generating a pattern for a let-binding and we target Lean, - the syntax is: `let ⟨ x0, ..., xn ⟩ := ...`. - - Otherwise, it is: `let Cons x0 ... xn = ...` - *) - if is_single_pat && !Config.backend = Lean then ( + (* We may still extract the ADT as a tuple, if none of the fields are + named *) + if + PureUtils.type_decl_from_type_id_is_tuple_struct + ctx.trans_ctx.type_ctx.type_infos adt_id + then (* Extract as a tuple *) + extract_as_tuple () + else if + (* If we are generating a pattern for a let-binding and we target Lean, + the syntax is: `let ⟨ x0, ..., xn ⟩ := ...`. + + Otherwise, it is: `let Cons x0 ... xn = ...` + *) + is_single_pat && !Config.backend = Lean + then ( F.pp_print_string fmt "⟨"; F.pp_print_space fmt (); let ctx = @@ -517,7 +532,14 @@ and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter) match args with | [ arg ] -> (* Exactly one argument: pretty-print *) - let field_name = ctx_get_field proj.adt_id proj.field_id ctx in + let field_name = + (* Check if we need to extract the type as a structure *) + if + PureUtils.type_decl_from_type_id_is_tuple_struct + ctx.trans_ctx.type_ctx.type_infos proj.adt_id + then FieldId.to_string proj.field_id + else ctx_get_field proj.adt_id proj.field_id ctx + in (* Open a box *) F.pp_open_hovbox fmt ctx.indent_incr; (* Extract the expression *) diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 43658b6e..93204515 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -109,7 +109,7 @@ let decl_is_first_from_group (kind : decl_kind) : bool = let decl_is_not_last_from_group (kind : decl_kind) : bool = not (decl_is_last_from_group kind) -type type_decl_kind = Enum | Struct [@@deriving show] +type type_decl_kind = Enum | Struct | Tuple [@@deriving show] (** We use identifiers to look for name clashes *) type id = @@ -1194,12 +1194,13 @@ let type_decl_kind_to_qualif (kind : decl_kind) | Declared -> Some "val") | Coq -> ( match (kind, type_kind) with + | SingleNonRec, Some Tuple -> Some "Definition" | SingleNonRec, Some Enum -> Some "Inductive" | SingleNonRec, Some Struct -> Some "Record" | (SingleRec | MutRecFirst), Some _ -> Some "Inductive" | (MutRecInner | MutRecLast), Some _ -> (* Coq doesn't support groups of mutually recursive definitions which mix - * records and inducties: we convert everything to records if this happens + * records and inductives: we convert everything to records if this happens *) Some "with" | (Assumed | Declared), None -> Some "Axiom" @@ -1214,12 +1215,12 @@ let type_decl_kind_to_qualif (kind : decl_kind) ^ ")"))) | Lean -> ( match kind with - | SingleNonRec -> - if type_kind = Some Struct then Some "structure" else Some "inductive" - | SingleRec -> Some "inductive" - | MutRecFirst -> Some "inductive" - | MutRecInner -> Some "inductive" - | MutRecLast -> Some "inductive" + | SingleNonRec -> ( + match type_kind with + | Some Tuple -> Some "def" + | Some Struct -> Some "structure" + | _ -> Some "inductive") + | SingleRec | MutRecFirst | MutRecInner | MutRecLast -> Some "inductive" | Assumed -> Some "axiom" | Declared -> Some "axiom") | HOL4 -> None diff --git a/compiler/ExtractTypes.ml b/compiler/ExtractTypes.ml index 3657627b..22243a4a 100644 --- a/compiler/ExtractTypes.ml +++ b/compiler/ExtractTypes.ml @@ -1,7 +1,4 @@ (** The generic extraction *) -(* Turn the whole module into a functor: it is very annoying to carry the - the formatter everywhere... -*) open Pure open PureUtils @@ -696,92 +693,101 @@ let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) : * - the field names, if this is a structure *) let ctx = - match def.kind with - | Struct fields -> - (* Compute the names *) - let field_names, cons_name = - match info with - | None | Some { body_info = None; _ } -> - let field_names = - FieldId.mapi - (fun fid (field : field) -> - ( fid, - ctx_compute_field_name ctx def.llbc_name fid - field.field_name )) - fields - in - let cons_name = - ctx_compute_struct_constructor ctx def.llbc_name - in - (field_names, cons_name) - | Some { body_info = Some (Struct (cons_name, field_names)); _ } -> - let field_names = - FieldId.mapi - (fun fid (field : field) -> - let rust_name = Option.get field.field_name in + (* Ignore this if the type is to be extracted as a tuple *) + if + TypesUtils.type_decl_from_decl_id_is_tuple_struct + ctx.trans_ctx.type_ctx.type_infos def.def_id + then ctx + else + match def.kind with + | Struct fields -> + (* Compute the names *) + let field_names, cons_name = + match info with + | None | Some { body_info = None; _ } -> + let field_names = + FieldId.mapi + (fun fid (field : field) -> + ( fid, + ctx_compute_field_name ctx def.llbc_name fid + field.field_name )) + fields + in + let cons_name = + ctx_compute_struct_constructor ctx def.llbc_name + in + (field_names, cons_name) + | Some { body_info = Some (Struct (cons_name, field_names)); _ } -> + let field_names = + FieldId.mapi + (fun fid (field : field) -> + let rust_name = Option.get field.field_name in + let name = + snd + (List.find (fun (n, _) -> n = rust_name) field_names) + in + (fid, name)) + fields + in + (field_names, cons_name) + | Some info -> + raise + (Failure + ("Invalid builtin information: " + ^ show_builtin_type_info info)) + in + (* Add the fields *) + let ctx = + List.fold_left + (fun ctx (fid, name) -> + ctx_add (FieldId (TAdtId def.def_id, fid)) name ctx) + ctx field_names + in + (* Add the constructor name *) + ctx_add (StructId (TAdtId def.def_id)) cons_name ctx + | Enum variants -> + let variant_names = + match info with + | None -> + VariantId.mapi + (fun variant_id (variant : variant) -> let name = - snd (List.find (fun (n, _) -> n = rust_name) field_names) + ctx_compute_variant_name ctx def.llbc_name + variant.variant_name in - (fid, name)) - fields - in - (field_names, cons_name) - | Some info -> - raise - (Failure - ("Invalid builtin information: " - ^ show_builtin_type_info info)) - in - (* Add the fields *) - let ctx = + (* Add the type name prefix for Lean *) + let name = + if !Config.backend = Lean then + let type_name = + ctx_compute_type_name ctx def.llbc_name + in + type_name ^ "." ^ name + else name + in + (variant_id, name)) + variants + | Some { body_info = Some (Enum variant_infos); _ } -> + (* We need to compute the map from variant to variant *) + let variant_map = + StringMap.of_list + (List.map + (fun (info : builtin_enum_variant_info) -> + (info.rust_variant_name, info.extract_variant_name)) + variant_infos) + in + VariantId.mapi + (fun variant_id (variant : variant) -> + (variant_id, StringMap.find variant.variant_name variant_map)) + variants + | _ -> raise (Failure "Invalid builtin information") + in List.fold_left - (fun ctx (fid, name) -> - ctx_add (FieldId (TAdtId def.def_id, fid)) name ctx) - ctx field_names - in - (* Add the constructor name *) - ctx_add (StructId (TAdtId def.def_id)) cons_name ctx - | Enum variants -> - let variant_names = - match info with - | None -> - VariantId.mapi - (fun variant_id (variant : variant) -> - let name = - ctx_compute_variant_name ctx def.llbc_name - variant.variant_name - in - (* Add the type name prefix for Lean *) - let name = - if !Config.backend = Lean then - let type_name = ctx_compute_type_name ctx def.llbc_name in - type_name ^ "." ^ name - else name - in - (variant_id, name)) - variants - | Some { body_info = Some (Enum variant_infos); _ } -> - (* We need to compute the map from variant to variant *) - let variant_map = - StringMap.of_list - (List.map - (fun (info : builtin_enum_variant_info) -> - (info.rust_variant_name, info.extract_variant_name)) - variant_infos) - in - VariantId.mapi - (fun variant_id (variant : variant) -> - (variant_id, StringMap.find variant.variant_name variant_map)) - variants - | _ -> raise (Failure "Invalid builtin information") - in - List.fold_left - (fun ctx (vid, vname) -> - ctx_add (VariantId (TAdtId def.def_id, vid)) vname ctx) - ctx variant_names - | Opaque -> - (* Nothing to do *) - ctx + (fun ctx (vid, vname) -> + ctx_add (VariantId (TAdtId def.def_id, vid)) vname ctx) + ctx variant_names + | Opaque -> + (* Nothing to do *) + ctx in (* Return *) ctx @@ -906,6 +912,19 @@ let extract_type_decl_enum_body (ctx : extraction_ctx) (fmt : F.formatter) let variants = VariantId.mapi (fun vid v -> (vid, v)) variants in List.iter (fun (vid, v) -> print_variant vid v) variants +(** Extract a struct as a tuple *) +let extract_type_decl_tuple_struct_body (ctx : extraction_ctx) + (fmt : F.formatter) (fields : field list) : unit = + let sep = match !backend with Coq | FStar | HOL4 -> "*" | Lean -> "×" in + Collections.List.iter_link + (fun () -> + F.pp_print_space fmt (); + F.pp_print_string fmt sep) + (fun (f : field) -> + F.pp_print_space fmt (); + extract_ty ctx fmt TypeDeclId.Set.empty true f.field_ty) + fields + let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) (type_decl_group : TypeDeclId.Set.t) (kind : decl_kind) (def : type_decl) (type_params : string list) (cg_params : string list) (fields : field list) @@ -1264,12 +1283,18 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (extract_body : bool) : unit = (* Sanity check *) assert (extract_body || !backend <> HOL4); + let is_tuple_struct = + TypesUtils.type_decl_from_decl_id_is_tuple_struct + ctx.trans_ctx.type_ctx.type_infos def.def_id + in let type_kind = if extract_body then - match def.kind with - | Struct _ -> Some Struct - | Enum _ -> Some Enum - | Opaque -> None + if is_tuple_struct then Some Tuple + else + match def.kind with + | Struct _ -> Some Struct + | Enum _ -> Some Enum + | Opaque -> None else None in (* If in Coq and the declaration is opaque, it must have the shape: @@ -1300,7 +1325,8 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) * for parsing: we thus use a hovbox. *) (match !backend with | Coq | FStar | HOL4 -> F.pp_open_hvbox fmt 0 - | Lean -> F.pp_open_vbox fmt 0); + | Lean -> + if is_tuple_struct then F.pp_open_hvbox fmt 0 else F.pp_open_vbox fmt 0); (* Open a box for "type TYPE_NAME (TYPE_PARAMS CONST_GEN_PARAMS) =" *) F.pp_open_hovbox fmt ctx.indent_incr; (* > "type TYPE_NAME" *) @@ -1320,7 +1346,11 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) let eq = match !backend with | FStar -> "=" - | Coq -> ":=" + | Coq -> + (* For Coq, the `*` is overloaded. If we want to extract a product + type (and not a product between, say, integers) we need to help Coq + a bit *) + if is_tuple_struct then ": Type :=" else ":=" | Lean -> if type_kind = Some Struct && kind = SingleNonRec then "where" else ":=" @@ -1341,8 +1371,11 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (if extract_body then match def.kind with | Struct fields -> - extract_type_decl_struct_body ctx_body fmt type_decl_group kind def - type_params cg_params fields + if is_tuple_struct then + extract_type_decl_tuple_struct_body ctx_body fmt fields + else + extract_type_decl_struct_body ctx_body fmt type_decl_group kind def + type_params cg_params fields | Enum variants -> extract_type_decl_enum_body ctx_body fmt type_decl_group def def_name type_params cg_params variants @@ -1670,8 +1703,13 @@ let extract_type_decl_extra_info (ctx : extraction_ctx) (fmt : F.formatter) match !backend with | FStar | Lean | HOL4 -> () | Coq -> - extract_type_decl_coq_arguments ctx fmt kind decl; - extract_type_decl_record_field_projectors ctx fmt kind decl + if + not + (TypesUtils.type_decl_from_decl_id_is_tuple_struct + ctx.trans_ctx.type_ctx.type_infos decl.def_id) + then ( + extract_type_decl_coq_arguments ctx fmt kind decl; + extract_type_decl_record_field_projectors ctx fmt kind decl) (** Extract the state type declaration. *) let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx) diff --git a/compiler/InterpreterBorrows.ml b/compiler/InterpreterBorrows.ml index 19b9fd3b..e56919fa 100644 --- a/compiler/InterpreterBorrows.ml +++ b/compiler/InterpreterBorrows.ml @@ -706,7 +706,7 @@ let reborrow_shared (original_bid : BorrowId.id) (new_bid : BorrowId.id) (** Convert an {!type:avalue} to a {!type:value}. This function is used when ending abstractions: whenever we end a borrow - in an abstraction, we converted the borrowed {!avalue} to a fresh symbolic + in an abstraction, we convert the borrowed {!avalue} to a fresh symbolic {!type:value}, then give back this {!type:value} to the context. Note that some regions may have ended in the symbolic value we generate. @@ -719,8 +719,7 @@ let reborrow_shared (original_bid : BorrowId.id) (new_bid : BorrowId.id) be expanded (because expanding this symbolic value would require expanding a reference whose region has already ended). *) -let convert_avalue_to_given_back_value (abs_kind : abs_kind) (av : typed_avalue) - : symbolic_value = +let convert_avalue_to_given_back_value (av : typed_avalue) : symbolic_value = mk_fresh_symbolic_value av.ty (** Auxiliary function: see {!end_borrow_aux}. @@ -739,8 +738,8 @@ let convert_avalue_to_given_back_value (abs_kind : abs_kind) (av : typed_avalue) borrows. This kind of internal reshuffling. should be similar to ending abstractions (it is tantamount to ending *sub*-abstractions). *) -let give_back (config : config) (abs_id_opt : AbstractionId.id option) - (l : BorrowId.id) (bc : g_borrow_content) (ctx : eval_ctx) : eval_ctx = +let give_back (config : config) (l : BorrowId.id) (bc : g_borrow_content) + (ctx : eval_ctx) : eval_ctx = (* Debug *) log#ldebug (lazy @@ -781,9 +780,7 @@ let give_back (config : config) (abs_id_opt : AbstractionId.id option) Rem.: we shouldn't do this here. We should do this in a function which takes care of ending *sub*-abstractions. *) - let abs_id = Option.get abs_id_opt in - let abs = ctx_lookup_abs ctx abs_id in - let sv = convert_avalue_to_given_back_value abs.kind av in + let sv = convert_avalue_to_given_back_value av in (* Update the context *) give_back_avalue_to_same_abstraction config l av (mk_typed_value_from_symbolic_value sv) @@ -929,14 +926,14 @@ let rec end_borrow_aux (config : config) (chain : borrow_or_abs_ids) cf_check cf ctx (* We found a borrow and replaced it with [Bottom]: give it back (i.e., update the corresponding loan) *) - | Ok (ctx, Some (abs_id_opt, bc)) -> + | Ok (ctx, Some (_, bc)) -> (* Sanity check: the borrowed value shouldn't contain loans *) (match bc with | Concrete (VMutBorrow (_, bv)) -> assert (Option.is_none (get_first_loan_in_value bv)) | _ -> ()); (* Give back the value *) - let ctx = give_back config abs_id_opt l bc ctx in + let ctx = give_back config l bc ctx in (* Do a sanity check and continue *) cf_check cf ctx @@ -1161,7 +1158,7 @@ and end_abstraction_borrows (config : config) (chain : borrow_or_abs_ids) match bc with | AMutBorrow (bid, av) -> (* First, convert the avalue to a (fresh symbolic) value *) - let sv = convert_avalue_to_given_back_value abs.kind av in + let sv = convert_avalue_to_given_back_value av in (* Replace the mut borrow to register the fact that we ended * it and store with it the freshly generated given back value *) let ended_borrow = ABorrow (AEndedMutBorrow (sv, av)) in diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index d0741b29..68f8943a 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -563,12 +563,13 @@ let remove_meta (def : fun_decl) : fun_decl = This micro-pass turns those into expressions which use structure syntax: {[ - { - f0 := x0; - ... - fn := xn; - } + type struct = { f0 : nat; f1 : nat; f2 : nat } + + Mkstruct x.f0 x.f1 y ~~> { x with f2 = y } ]} + + Note however that we do not apply this transformation if the + structure is to be extracted as a tuple. *) let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = let obj = @@ -592,37 +593,44 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = } -> (* Lookup the def *) let decl = TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls in - (* Check that there are as many arguments as there are fields - note - that the def should have a body (otherwise we couldn't use the - constructor) *) - let fields = TypesUtils.type_decl_get_fields decl None in - if List.length fields = List.length args then - (* Check if the definition is recursive *) - let is_rec = - match - TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls_groups - with - | NonRecGroup _ -> false - | RecGroup _ -> true - in - (* Convert, if possible - note that for now for Lean and Coq - we don't support the structure syntax on recursive structures *) - if - (!Config.backend <> Lean && !Config.backend <> Coq) - || not is_rec - then - let struct_id = TAdtId adt_id in - let init = None in - let updates = - FieldId.mapi - (fun fid fe -> (fid, self#visit_texpression env fe)) - args + (* Check if the def will be extracted as a tuple *) + if + TypesUtils.type_decl_from_decl_id_is_tuple_struct + ctx.type_ctx.type_infos adt_id + then ignore () + else + (* Check that there are as many arguments as there are fields - note + that the def should have a body (otherwise we couldn't use the + constructor) *) + let fields = TypesUtils.type_decl_get_fields decl None in + if List.length fields = List.length args then + (* Check if the definition is recursive *) + let is_rec = + match + TypeDeclId.Map.find adt_id + ctx.type_ctx.type_decls_groups + with + | NonRecGroup _ -> false + | RecGroup _ -> true in - let ne = { struct_id; init; updates } in - let nty = e.ty in - { e = StructUpdate ne; ty = nty } + (* Convert, if possible - note that for now for Lean and Coq + we don't support the structure syntax on recursive structures *) + if + (!Config.backend <> Lean && !Config.backend <> Coq) + || not is_rec + then + let struct_id = TAdtId adt_id in + let init = None in + let updates = + FieldId.mapi + (fun fid fe -> (fid, self#visit_texpression env fe)) + args + in + let ne = { struct_id; init; updates } in + let nty = e.ty in + { e = StructUpdate ne; ty = nty } + else ignore () else ignore () - else ignore () | _ -> ignore ()) | _ -> super#visit_texpression env e end @@ -1069,12 +1077,10 @@ let simplify_let_then_return _ctx def = (** Simplify the aggregated ADTs. Ex.: {[ - type struct = { f0 : nat; f1 : nat } + type struct = { f0 : nat; f1 : nat; f2 : nat } - Mkstruct x.f0 x.f1 ~~> x + Mkstruct x.f0 x.f1 x.f2 ~~> x ]} - - TODO: introduce a notation for [{ x with field = ... }], and use it. *) let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = let expr_visitor = diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index a5143f3c..39dcd52d 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -687,3 +687,14 @@ let trait_impl_is_empty (trait_impl : trait_impl) : bool = in parent_trait_refs = [] && consts = [] && types = [] && required_methods = [] && provided_methods = [] + +(** Return true if a type declaration should be extracted as a tuple, because + it is a non-recursive structure with unnamed fields. *) +let type_decl_from_type_id_is_tuple_struct (ctx : TypesAnalysis.type_infos) + (id : type_id) : bool = + match id with + | TTuple -> true + | TAdtId id -> + let info = TypeDeclId.Map.find id ctx in + info.is_tuple_struct + | TAssumed _ -> false diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 3b30549c..bf4d26f2 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -2299,11 +2299,11 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) match sexp with | V.SeLiteral _ -> (* We do not *register* symbolic expansions to literal - * values in the symbolic ADT *) + values in the symbolic ADT *) raise (Failure "Unreachable") | SeMutRef (_, nsv) | SeSharedRef (_, nsv) -> (* The (mut/shared) borrow type is extracted to identity: we thus simply - * introduce an reassignment *) + introduce an reassignment *) let ctx, var = fresh_var_for_symbolic_value nsv ctx in let next_e = translate_expression e ctx in let monadic = false in @@ -2324,10 +2324,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) && !Config.always_deconstruct_adts_with_matches) -> (* There is exactly one branch: no branching. - We can decompose the ADT value with a let-binding, unless - the backend doesn't support this (see {!Config.always_deconstruct_adts_with_matches}): - we *ignore* this branch (and go to the next one) if the ADT is a custom - adt, and [always_deconstruct_adts_with_matches] is true. + We can decompose the ADT value with a let-binding, unless + the backend doesn't support this (see {!Config.always_deconstruct_adts_with_matches}): + we *ignore* this branch (and go to the next one) if the ADT is a custom + adt, and [always_deconstruct_adts_with_matches] is true. *) translate_ExpandAdt_one_branch sv scrutinee scrutinee_mplace variant_id svl branch ctx @@ -2361,7 +2361,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) { e; ty }) | ExpandBool (true_e, false_e) -> (* We don't need to update the context: we don't introduce any - * new values/variables *) + new values/variables *) let true_e = translate_expression true_e ctx in let false_e = translate_expression false_e ctx in let e = @@ -2376,7 +2376,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) let translate_branch ((v, branch_e) : V.scalar_value * S.expression) : match_branch = (* We don't need to update the context: we don't introduce any - * new values/variables *) + new values/variables *) let branch = translate_expression branch_e ctx in let pat = mk_typed_pattern_from_literal (VScalar v) in { pat; branch } @@ -2436,20 +2436,28 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value) (* Detect if this is an enumeration or not *) let tdef = bs_ctx_lookup_llbc_type_decl adt_id ctx in let is_enum = TypesUtils.type_decl_is_enum tdef in - (* We deconstruct the ADT with a let-binding in two situations: + (* We deconstruct the ADT with a single let-binding in two situations: - if the ADT is an enumeration (which must have exactly one branch) - if we forbid using field projectors. *) let is_rec_def = T.TypeDeclId.Set.mem adt_id ctx.type_context.recursive_decls in - let use_let = + let use_let_with_cons = is_enum || !Config.dont_use_field_projectors (* TODO: for now, we don't have field projectors over recursive ADTs in Lean. *) || (!Config.backend = Lean && is_rec_def) + (* Also, there is a special case when the ADT is to be extracted as + a tuple, because it is a structure with unnamed fields. Some backends + like Lean have projectors for tuples (like so: `x.3`), but others + like Coq don't, in which case we have to deconstruct the whole ADT + at once (`let (a, b, c) = x in`) *) + || TypesUtils.type_decl_from_type_id_is_tuple_struct + ctx.type_context.type_infos type_id + && not (Config.backend_has_tuple_projectors ()) in - if use_let then + if use_let_with_cons then (* Introduce a let binding which expands the ADT *) let lvars = List.map (fun v -> mk_typed_pattern_from_var v None) vars in let lv = mk_adt_pattern scrutinee.ty variant_id lvars in diff --git a/compiler/TypesAnalysis.ml b/compiler/TypesAnalysis.ml index 589c380c..12c20262 100644 --- a/compiler/TypesAnalysis.ml +++ b/compiler/TypesAnalysis.ml @@ -27,6 +27,10 @@ type 'p g_type_info = { borrows_info : type_borrows_info; (** Various informations about the borrows *) param_infos : 'p; (** Gives information about the type parameters *) + is_tuple_struct : bool; + (** If true, it means the type is a record that we should extract as a tuple. + This field is only valid for type declarations. + *) } [@@deriving show] @@ -55,22 +59,43 @@ let type_borrows_info_init : type_borrows_info = contains_borrow_under_mut = false; } -let initialize_g_type_info (param_infos : 'p) : 'p g_type_info = - { borrows_info = type_borrows_info_init; param_infos } +(** Return true if a type declaration is a structure with unnamed fields. -let initialize_type_decl_info (def : type_decl) : type_decl_info = + Note that there are two possibilities: + - either all the fields are named + - or none of the fields are named + *) +let type_decl_is_tuple_struct (x : type_decl) : bool = + match x.kind with + | Struct fields -> List.for_all (fun f -> f.field_name = None) fields + | _ -> false + +let initialize_g_type_info (is_tuple_struct : bool) (param_infos : 'p) : + 'p g_type_info = + { borrows_info = type_borrows_info_init; is_tuple_struct; param_infos } + +let initialize_type_decl_info (is_rec : bool) (def : type_decl) : type_decl_info + = let param_info = { under_borrow = false; under_mut_borrow = false } in let param_infos = List.map (fun _ -> param_info) def.generics.types in - initialize_g_type_info param_infos + let is_tuple_struct = + !Config.use_tuple_structs && (not is_rec) && type_decl_is_tuple_struct def + in + initialize_g_type_info is_tuple_struct param_infos let type_decl_info_to_partial_type_info (info : type_decl_info) : partial_type_info = - { borrows_info = info.borrows_info; param_infos = Some info.param_infos } + { + borrows_info = info.borrows_info; + is_tuple_struct = info.is_tuple_struct; + param_infos = Some info.param_infos; + } let partial_type_info_to_type_decl_info (info : partial_type_info) : type_decl_info = { borrows_info = info.borrows_info; + is_tuple_struct = info.is_tuple_struct; param_infos = Option.get info.param_infos; } @@ -283,14 +308,20 @@ let analyze_type_decl (updated : bool ref) (infos : type_infos) let analyze_type_declaration_group (type_decls : type_decl TypeDeclId.Map.t) (infos : type_infos) (decl : type_declaration_group) : type_infos = (* Collect the identifiers used in the declaration group *) - let ids = match decl with NonRecGroup id -> [ id ] | RecGroup ids -> ids in + let is_rec, ids = + match decl with + | NonRecGroup id -> (false, [ id ]) + | RecGroup ids -> (true, ids) + in (* Retrieve the type definitions *) let decl_defs = List.map (fun id -> TypeDeclId.Map.find id type_decls) ids in (* Initialize the type information for the current definitions *) let infos = List.fold_left (fun infos (def : type_decl) -> - TypeDeclId.Map.add def.def_id (initialize_type_decl_info def) infos) + TypeDeclId.Map.add def.def_id + (initialize_type_decl_info is_rec def) + infos) infos decl_defs in (* Analyze the types - this function simply computes a fixed-point *) @@ -327,7 +358,7 @@ let analyze_ty (infos : type_infos) (ty : ty) : ty_info = (* We don't use [updated] but need to give it as parameter *) let updated = ref false in (* We don't need to compute whether the type contains 'static or not *) - let ty_info = initialize_g_type_info None in + let ty_info = initialize_g_type_info false None in let ty_info = analyze_full_ty updated infos ty_info ty in (* Convert the ty_info *) partial_type_info_to_ty_info ty_info diff --git a/compiler/TypesUtils.ml b/compiler/TypesUtils.ml index c8418ba0..28db59ec 100644 --- a/compiler/TypesUtils.ml +++ b/compiler/TypesUtils.ml @@ -111,3 +111,21 @@ let trait_type_constraint_no_regions (x : trait_type_constraint) : bool = raise_if_region_ty_visitor#visit_ty () ty; true with Found -> false + +(** Return true if a type declaration should be extracted as a tuple, because + it is a non-recursive structure with unnamed fields. *) +let type_decl_from_decl_id_is_tuple_struct (ctx : TypesAnalysis.type_infos) + (id : TypeDeclId.id) : bool = + let info = TypeDeclId.Map.find id ctx in + info.is_tuple_struct + +(** Return true if a type declaration should be extracted as a tuple, because + it is a non-recursive structure with unnamed fields. *) +let type_decl_from_type_id_is_tuple_struct (ctx : TypesAnalysis.type_infos) + (id : type_id) : bool = + match id with + | TTuple -> true + | TAdtId id -> + let info = TypeDeclId.Map.find id ctx in + info.is_tuple_struct + | TAssumed _ -> false -- cgit v1.2.3 From 6dbe9e153043e5091a4d17da9bc7c3ed7d4093b1 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 7 Dec 2023 12:23:57 +0100 Subject: Fix minor issues when extracting a structure with one field as a tuple --- compiler/Extract.ml | 55 ++++++++++++++++++++++++++++++------------------ compiler/ExtractTypes.ml | 11 ++++++++++ 2 files changed, 45 insertions(+), 21 deletions(-) (limited to 'compiler') diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 85bdd929..20cdb20b 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -531,28 +531,41 @@ and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter) * projection ([x.field] instead of [MkAdt?.field x] *) match args with | [ arg ] -> - (* Exactly one argument: pretty-print *) - let field_name = - (* Check if we need to extract the type as a structure *) - if - PureUtils.type_decl_from_type_id_is_tuple_struct - ctx.trans_ctx.type_ctx.type_infos proj.adt_id - then FieldId.to_string proj.field_id - else ctx_get_field proj.adt_id proj.field_id ctx + let is_tuple_struct = + PureUtils.type_decl_from_type_id_is_tuple_struct + ctx.trans_ctx.type_ctx.type_infos proj.adt_id in - (* Open a box *) - F.pp_open_hovbox fmt ctx.indent_incr; - (* Extract the expression *) - extract_texpression ctx fmt true arg; - (* We allow to break where the "." appears (except Lean, it's a syntax error) *) - if !backend <> Lean then F.pp_print_break fmt 0 0; - F.pp_print_string fmt "."; - (* If in Coq, the field projection has to be parenthesized *) - (match !backend with - | FStar | Lean | HOL4 -> F.pp_print_string fmt field_name - | Coq -> F.pp_print_string fmt ("(" ^ field_name ^ ")")); - (* Close the box *) - F.pp_close_box fmt () + (* Check if we extract the type as a tuple, and it only has one field. + In this case, there is no projection. *) + let has_one_field = + match proj.adt_id with + | TAdtId id -> ( + let d = TypeDeclId.Map.find id ctx.trans_types in + match d.kind with Struct [ _ ] -> true | _ -> false) + | _ -> false + in + if is_tuple_struct && has_one_field then + extract_texpression ctx fmt inside arg + else + (* Exactly one argument: pretty-print *) + let field_name = + (* Check if we need to extract the type as a tuple *) + if is_tuple_struct then FieldId.to_string proj.field_id + else ctx_get_field proj.adt_id proj.field_id ctx + in + (* Open a box *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Extract the expression *) + extract_texpression ctx fmt true arg; + (* We allow to break where the "." appears (except Lean, it's a syntax error) *) + if !backend <> Lean then F.pp_print_break fmt 0 0; + F.pp_print_string fmt "."; + (* If in Coq, the field projection has to be parenthesized *) + (match !backend with + | FStar | Lean | HOL4 -> F.pp_print_string fmt field_name + | Coq -> F.pp_print_string fmt ("(" ^ field_name ^ ")")); + (* Close the box *) + F.pp_close_box fmt () | arg :: args -> (* Call extract_App again, but in such a way that the first argument is * isolated *) diff --git a/compiler/ExtractTypes.ml b/compiler/ExtractTypes.ml index 22243a4a..08064a06 100644 --- a/compiler/ExtractTypes.ml +++ b/compiler/ExtractTypes.ml @@ -1287,6 +1287,9 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) TypesUtils.type_decl_from_decl_id_is_tuple_struct ctx.trans_ctx.type_ctx.type_infos def.def_id in + let is_tuple_struct_one_field = + is_tuple_struct && match def.kind with Struct [ _ ] -> true | _ -> false + in let type_kind = if extract_body then if is_tuple_struct then Some Tuple @@ -1329,6 +1332,14 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) if is_tuple_struct then F.pp_open_hvbox fmt 0 else F.pp_open_vbox fmt 0); (* Open a box for "type TYPE_NAME (TYPE_PARAMS CONST_GEN_PARAMS) =" *) F.pp_open_hovbox fmt ctx.indent_incr; + (* > "@[reducible]" + We need this annotation, otherwise Lean sometimes doesn't manage to typecheck + the expressions when it needs to coerce the type. + *) + if is_tuple_struct_one_field && !backend = Lean then ( + F.pp_print_string fmt "@[reducible]"; + F.pp_print_space fmt ()) + else (); (* > "type TYPE_NAME" *) let qualif = type_decl_kind_to_qualif kind type_kind in (match qualif with -- cgit v1.2.3 From c17d8cbb7c32d2c2ce9d737fe5359cfbe7d4418c Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 7 Dec 2023 12:44:54 +0100 Subject: Update the micro passes to inline deconstruction of tuples with one field --- compiler/PureMicroPasses.ml | 40 ++++++++++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 6 deletions(-) (limited to 'compiler') diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 68f8943a..959ec1c8 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -667,8 +667,8 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = leave the let-bindings where they are, and eliminated them in a subsequent pass (if they are useless). *) -let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) - (def : fun_decl) : fun_decl = +let inline_useless_var_reassignments (ctx : trans_ctx) (inline_named : bool) + (inline_pure : bool) (def : fun_decl) : fun_decl = let obj = object (self) inherit [_] map_expression as super @@ -677,9 +677,12 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) the substitution map while doing so *) method! visit_Let (env : texpression VarId.Map.t) monadic lv re e = (* In order to filter, we need to check first that: - * - the let-binding is not monadic - * - the left-value is a variable - *) + - the let-binding is not monadic + - the left-value is a variable + + We also inline if the binding decomposes a structure that is to be + extracted as a tuple, and the right value is a variable. + *) match (monadic, lv.value) with | false, PatVar (lv_var, _) -> (* We can filter if: *) @@ -725,6 +728,31 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) let e = self#visit_texpression env e in (* Reconstruct the [let], only if the binding is not filtered *) if filter then e.e else Let (monadic, lv, re, e) + | ( false, + PatAdt + { + variant_id = None; + field_values = [ { value = PatVar (lv_var, _); ty = _ } ]; + } ) -> + (* Second case: we deconstruct a structure with one field that we will + extract as tuple. *) + let adt_id, _ = PureUtils.ty_as_adt re.ty in + (* Update the rhs (we may perform substitutions inside, and it is + * better to do them *before* we inline it *) + let re = self#visit_texpression env re in + if + PureUtils.is_var re + && type_decl_from_type_id_is_tuple_struct ctx.type_ctx.type_infos + adt_id + then + (* Update the substitution environment *) + let env = VarId.Map.add lv_var.id re env in + (* Update the next expression *) + let e = self#visit_texpression env e in + (* We filter the [let], and thus do not reconstruct it *) + e.e + else (* Nothing to do *) + super#visit_Let env monadic lv re e | _ -> super#visit_Let env monadic lv re e (** Substitute the variables *) @@ -1792,7 +1820,7 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = let inline_named_vars = true in let inline_pure = true in let def = - inline_useless_var_reassignments inline_named_vars inline_pure def + inline_useless_var_reassignments ctx inline_named_vars inline_pure def in log#ldebug (lazy -- cgit v1.2.3 From 2fc876ab40bed10e36f6ee6581f516cdda3b9bc4 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 7 Dec 2023 14:42:08 +0100 Subject: Fix the extraction of the empty type --- compiler/ExtractTypes.ml | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) (limited to 'compiler') diff --git a/compiler/ExtractTypes.ml b/compiler/ExtractTypes.ml index 08064a06..785f7629 100644 --- a/compiler/ExtractTypes.ml +++ b/compiler/ExtractTypes.ml @@ -915,15 +915,20 @@ let extract_type_decl_enum_body (ctx : extraction_ctx) (fmt : F.formatter) (** Extract a struct as a tuple *) let extract_type_decl_tuple_struct_body (ctx : extraction_ctx) (fmt : F.formatter) (fields : field list) : unit = - let sep = match !backend with Coq | FStar | HOL4 -> "*" | Lean -> "×" in - Collections.List.iter_link - (fun () -> - F.pp_print_space fmt (); - F.pp_print_string fmt sep) - (fun (f : field) -> - F.pp_print_space fmt (); - extract_ty ctx fmt TypeDeclId.Set.empty true f.field_ty) - fields + (* If the type is empty, we need to have a special treatment *) + if fields = [] then ( + F.pp_print_space fmt (); + F.pp_print_string fmt (unit_name ())) + else + let sep = match !backend with Coq | FStar | HOL4 -> "*" | Lean -> "×" in + Collections.List.iter_link + (fun () -> + F.pp_print_space fmt (); + F.pp_print_string fmt sep) + (fun (f : field) -> + F.pp_print_space fmt (); + extract_ty ctx fmt TypeDeclId.Set.empty true f.field_ty) + fields let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) (type_decl_group : TypeDeclId.Set.t) (kind : decl_kind) (def : type_decl) @@ -1287,8 +1292,9 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) TypesUtils.type_decl_from_decl_id_is_tuple_struct ctx.trans_ctx.type_ctx.type_infos def.def_id in - let is_tuple_struct_one_field = - is_tuple_struct && match def.kind with Struct [ _ ] -> true | _ -> false + let is_tuple_struct_one_or_zero_field = + is_tuple_struct + && match def.kind with Struct [] | Struct [ _ ] -> true | _ -> false in let type_kind = if extract_body then @@ -1336,7 +1342,7 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) We need this annotation, otherwise Lean sometimes doesn't manage to typecheck the expressions when it needs to coerce the type. *) - if is_tuple_struct_one_field && !backend = Lean then ( + if is_tuple_struct_one_or_zero_field && !backend = Lean then ( F.pp_print_string fmt "@[reducible]"; F.pp_print_space fmt ()) else (); -- cgit v1.2.3