From dcb1a77150d26875ab67b5e12cb299a3d9369d4a Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 8 Nov 2022 16:35:55 +0100 Subject: Update `switch` to have a specific treatment of ADTs --- compiler/InterpreterExpressions.ml | 96 +------------------ compiler/InterpreterExpressions.mli | 40 +++++++- compiler/InterpreterPaths.ml | 32 +++---- compiler/InterpreterPaths.mli | 8 +- compiler/InterpreterStatements.ml | 185 +++++++++++++++++++++++------------- compiler/PrePasses.ml | 5 +- compiler/Pure.ml | 4 +- compiler/PureUtils.ml | 2 +- compiler/Substitute.ml | 31 +++--- 9 files changed, 207 insertions(+), 196 deletions(-) (limited to 'compiler') diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml index ea0e1aa9..a0fb97d1 100644 --- a/compiler/InterpreterExpressions.ml +++ b/compiler/InterpreterExpressions.ml @@ -66,22 +66,6 @@ let read_place (config : C.config) (access : access_kind) (p : E.place) (* Call the continuation *) cf v ctx -(** Small utility. - - Prepare the access to a place in a right-value (typically an operand) by - reorganizing the environment. - - We reorganize the environment so that: - - we can access the place (we prepare *along* the path) - - the value at the place itself doesn't contain loans (the [access_kind] - controls whether we only end mutable loans, or also shared loans). - - We also check, after the reorganization, that the value at the place - *doesn't contain any bottom nor reserved borrows*. - - [expand_prim_copy]: if [true], expand the symbolic values which are - primitively copyable and contain borrows. - *) let access_rplace_reorganize_and_read (config : C.config) (expand_prim_copy : bool) (access : access_kind) (p : E.place) (cf : V.typed_value -> m_fun) : m_fun = @@ -560,80 +544,6 @@ let eval_binary_op (config : C.config) (binop : E.binop) (op1 : E.operand) | C.ConcreteMode -> eval_binary_op_concrete config binop op1 op2 cf | C.SymbolicMode -> eval_binary_op_symbolic config binop op1 op2 cf -(** Evaluate the discriminant of a concrete (i.e., non symbolic) ADT value *) -let eval_rvalue_discriminant_concrete (config : C.config) (p : E.place) - (cf : V.typed_value -> m_fun) : m_fun = - (* Note that discriminant values have type [isize] *) - (* Access the value *) - let access = Read in - let expand_prim_copy = false in - let prepare = - access_rplace_reorganize_and_read config expand_prim_copy access p - in - (* Read the value *) - let read (cf : V.typed_value -> m_fun) (v : V.typed_value) : m_fun = - (* The value may be shared: we need to ignore the shared loans *) - let v = value_strip_shared_loans v in - match v.V.value with - | Adt av -> ( - match av.variant_id with - | None -> - raise - (Failure - "Invalid input for `discriminant`: structure instead of enum") - | Some variant_id -> ( - let id = Z.of_int (T.VariantId.to_int variant_id) in - match mk_scalar Isize id with - | Error _ -> raise (Failure "Disciminant id out of range") - (* Should really never happen *) - | Ok sv -> - cf { V.value = V.Primitive (PV.Scalar sv); ty = Integer Isize }) - ) - | _ -> - raise - (Failure ("Invalid input for `discriminant`: " ^ V.show_typed_value v)) - in - (* Compose and apply *) - comp prepare read cf - -(** Evaluate the discriminant of an ADT value. - - Might lead to branching, if the value is symbolic. - *) -let eval_rvalue_discriminant (config : C.config) (p : E.place) - (cf : V.typed_value -> m_fun) : m_fun = - fun ctx -> - log#ldebug (lazy "eval_rvalue_discriminant"); - (* Note that discriminant values have type [isize] *) - (* Access the value *) - let access = Read in - let expand_prim_copy = false in - let prepare = - access_rplace_reorganize_and_read config expand_prim_copy access p - in - (* Read the value *) - let read (cf : V.typed_value -> m_fun) (v : V.typed_value) : m_fun = - fun ctx -> - (* The value may be shared: we need to ignore the shared loans *) - let v = value_strip_shared_loans v in - match v.V.value with - | Adt _ -> eval_rvalue_discriminant_concrete config p cf ctx - | Symbolic sv -> - (* Expand the symbolic value - may lead to branching *) - let allow_branching = true in - let cc = - expand_symbolic_value config allow_branching sv - (Some (S.mk_mplace p ctx)) - in - (* This time the value is concrete: reevaluate *) - comp cc (eval_rvalue_discriminant_concrete config p) cf ctx - | _ -> - raise - (Failure ("Invalid input for `discriminant`: " ^ V.show_typed_value v)) - in - (* Compose and apply *) - comp prepare read cf ctx - let eval_rvalue_ref (config : C.config) (p : E.place) (bkind : E.borrow_kind) (cf : V.typed_value -> m_fun) : m_fun = fun ctx -> @@ -780,7 +690,11 @@ let eval_rvalue_not_global (config : C.config) (rvalue : E.rvalue) | E.BinaryOp (binop, op1, op2) -> eval_binary_op config binop op1 op2 cf ctx | E.Aggregate (aggregate_kind, ops) -> comp_wrap (eval_rvalue_aggregate config aggregate_kind ops) ctx - | E.Discriminant p -> comp_wrap (eval_rvalue_discriminant config p) ctx + | E.Discriminant _ -> + raise + (Failure + "Unreachable: discriminant reads should have been eliminated from \ + the AST") | E.Global _ -> raise (Failure "Unreachable") let eval_fake_read (config : C.config) (p : E.place) : cm_fun = diff --git a/compiler/InterpreterExpressions.mli b/compiler/InterpreterExpressions.mli index c610e939..2ea3c6df 100644 --- a/compiler/InterpreterExpressions.mli +++ b/compiler/InterpreterExpressions.mli @@ -11,6 +11,41 @@ module S = SynthesizeSymbolic open Cps open InterpreterPaths +(** Read a place (CPS-style function). + + We also check that the value *doesn't contain bottoms or reserved + borrows*. + + This function doesn't reorganize the context to make sure we can read + the place. If needs be, you should call {!update_ctx_along_read_place} first. + *) +val read_place : + C.config -> access_kind -> E.place -> (V.typed_value -> m_fun) -> m_fun + +(** Auxiliary function. + + Prepare the access to a place in a right-value (typically an operand) by + reorganizing the environment. + + We reorganize the environment so that: + - we can access the place (we prepare *along* the path) + - the value at the place itself doesn't contain loans (the [access_kind] + controls whether we only end mutable loans, or also shared loans). + + We also check, after the reorganization, that the value at the place + *doesn't contain any bottom nor reserved borrows*. + + [expand_prim_copy]: if [true], expand the symbolic values which are + primitively copyable and contain borrows. + *) +val access_rplace_reorganize_and_read : + C.config -> + bool -> + access_kind -> + E.place -> + (V.typed_value -> m_fun) -> + m_fun + (** Evaluate an operand. Reorganize the context, then evaluate the operand. @@ -26,9 +61,12 @@ val eval_operand : C.config -> E.operand -> (V.typed_value -> m_fun) -> m_fun val eval_operands : C.config -> E.operand list -> (V.typed_value list -> m_fun) -> m_fun -(** Evaluate an rvalue which is not a global. +(** Evaluate an rvalue which is not a global (globals are handled elsewhere). Transmits the computed rvalue to the received continuation. + + Note that this function fails on {!E.Discriminant}: discriminant reads should + have been eliminated from the AST. *) val eval_rvalue_not_global : C.config -> E.rvalue -> ((V.typed_value, eval_error) result -> m_fun) -> m_fun diff --git a/compiler/InterpreterPaths.ml b/compiler/InterpreterPaths.ml index 3d0c69e8..63e03e31 100644 --- a/compiler/InterpreterPaths.ml +++ b/compiler/InterpreterPaths.ml @@ -569,23 +569,21 @@ let rec end_loans_at_place (config : C.config) (access : access_kind) in (* First, retrieve the value *) - match try_read_place config access p ctx with - | Error _ -> raise (Failure "Unreachable") - | Ok v -> ( - (* Inspect the value and update the context while doing so. - If the context gets updated: perform a recursive call (many things - may have been updated in the context: we need to re-read the value - at place [p] - and this value may actually not be accessible - anymore...) - *) - try - obj#visit_typed_value () v; - (* No context update required: apply the continuation *) - cf ctx - with UpdateCtx cc -> - (* We need to update the context: compose the caugth continuation with - * a recursive call to reinspect the value *) - comp cc (end_loans_at_place config access p) cf ctx) + let v = read_place config access p ctx in + (* Inspect the value and update the context while doing so. + If the context gets updated: perform a recursive call (many things + may have been updated in the context: we need to re-read the value + at place [p] - and this value may actually not be accessible + anymore...) + *) + try + obj#visit_typed_value () v; + (* No context update required: apply the continuation *) + cf ctx + with UpdateCtx cc -> + (* We need to update the context: compose the caugth continuation with + * a recursive call to reinspect the value *) + comp cc (end_loans_at_place config access p) cf ctx let drop_outer_loans_at_lplace (config : C.config) (p : E.place) : cm_fun = fun cf ctx -> diff --git a/compiler/InterpreterPaths.mli b/compiler/InterpreterPaths.mli index ed00b7c5..14baf128 100644 --- a/compiler/InterpreterPaths.mli +++ b/compiler/InterpreterPaths.mli @@ -27,6 +27,9 @@ val update_ctx_along_write_place : C.config -> access_kind -> E.place -> cm_fun (** Read the value at a given place. + This function doesn't update the environment to make sure the value is + accessible: if needs be, you should call {!update_ctx_along_read_place} first. + Note that we only access the value at the place, and do not check that the value is "well-formed" (for instance that it doesn't contain bottoms). *) @@ -35,7 +38,10 @@ val read_place : (** Update the value at a given place. - This function is an auxiliary function and is not safe: it will not check if + This function doesn't update the environment to make sure the value is + accessible: if needs be, you should call {!update_ctx_along_write_place} first. + + This function is a helper function and is **not safe**: it will not check if the overwritten value contains borrows, loans, etc. and will simply overwrite it. *) diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml index 08a03885..21027ff8 100644 --- a/compiler/InterpreterStatements.ml +++ b/compiler/InterpreterStatements.ml @@ -889,7 +889,7 @@ let rec eval_statement (config : C.config) (st : A.statement) : st_cm_fun = in (* Apply *) eval_statement config loop_body reeval_loop_body ctx - | A.Switch (op, tgts) -> eval_switch config op tgts cf ctx + | A.Switch switch -> eval_switch config switch cf ctx in (* Compose and apply *) comp cc cf_eval_st cf ctx @@ -916,8 +916,7 @@ and eval_global (config : C.config) (dest : E.place) (gid : LA.GlobalDeclId.id) S.synthesize_global_eval gid sval e (** Evaluate a switch *) -and eval_switch (config : C.config) (op : E.operand) (tgts : A.switch_targets) : - st_cm_fun = +and eval_switch (config : C.config) (switch : A.switch) : st_cm_fun = fun cf ctx -> (* We evaluate the operand in two steps: * first we prepare it, then we check if its value is concrete or @@ -927,74 +926,124 @@ and eval_switch (config : C.config) (op : E.operand) (tgts : A.switch_targets) : * value if it is symbolic, because the value may have been move * (and would thus floating in thin air...)! * *) - (* Prepare the operand *) - let cf_eval_op cf : m_fun = eval_operand config op cf in (* Match on the targets *) - let cf_match (cf : st_m_fun) (op_v : V.typed_value) : m_fun = - fun ctx -> - match tgts with - | A.If (st1, st2) -> ( - match op_v.value with - | V.Primitive (PV.Bool b) -> - (* Evaluate the if and the branch body *) - let cf_branch cf : m_fun = - (* Branch *) - if b then eval_statement config st1 cf - else eval_statement config st2 cf - in - (* Compose the continuations *) - cf_branch cf ctx - | V.Symbolic sv -> - (* Expand the symbolic boolean, and continue by evaluating - * the branches *) - let cf_true : m_fun = eval_statement config st1 cf in - let cf_false : m_fun = eval_statement config st2 cf in - expand_symbolic_bool config sv - (S.mk_opt_place_from_op op ctx) - cf_true cf_false ctx - | _ -> raise (Failure "Inconsistent state")) - | A.SwitchInt (int_ty, stgts, otherwise) -> ( - match op_v.value with - | V.Primitive (PV.Scalar sv) -> - (* Evaluate the branch *) - let cf_eval_branch cf = - (* Sanity check *) - assert (sv.PV.int_ty = int_ty); - (* Find the branch *) - match List.find_opt (fun (svl, _) -> List.mem sv svl) stgts with - | None -> eval_statement config otherwise cf - | Some (_, tgt) -> eval_statement config tgt cf - in - (* Compose *) - cf_eval_branch cf ctx - | V.Symbolic sv -> - (* Expand the symbolic value and continue by evaluating the - * proper branches *) - let stgts = - List.map - (fun (cv, tgt_st) -> (cv, eval_statement config tgt_st cf)) - stgts - in - (* Several branches may be grouped together: every branch is described - * by a pair (list of values, branch expression). - * In order to do a symbolic evaluation, we make this "flat" by - * de-grouping the branches. *) - let stgts = - List.concat - (List.map - (fun (vl, st) -> List.map (fun v -> (v, st)) vl) - stgts) - in - (* Translate the otherwise branch *) - let otherwise = eval_statement config otherwise cf in - (* Expand and continue *) - expand_symbolic_int config sv - (S.mk_opt_place_from_op op ctx) - int_ty stgts otherwise ctx - | _ -> raise (Failure "Inconsistent state")) + let cf_match : st_cm_fun = + fun cf ctx -> + match switch with + | A.If (op, st1, st2) -> + (* Evaluate the operand *) + let cf_eval_op = eval_operand config op in + (* Switch on the value *) + let cf_if (cf : st_m_fun) (op_v : V.typed_value) : m_fun = + fun ctx -> + match op_v.value with + | V.Primitive (PV.Bool b) -> + (* Evaluate the if and the branch body *) + let cf_branch cf : m_fun = + (* Branch *) + if b then eval_statement config st1 cf + else eval_statement config st2 cf + in + (* Compose the continuations *) + cf_branch cf ctx + | V.Symbolic sv -> + (* Expand the symbolic boolean, and continue by evaluating + * the branches *) + let cf_true : m_fun = eval_statement config st1 cf in + let cf_false : m_fun = eval_statement config st2 cf in + expand_symbolic_bool config sv + (S.mk_opt_place_from_op op ctx) + cf_true cf_false ctx + | _ -> raise (Failure "Inconsistent state") + in + (* Compose *) + comp cf_eval_op cf_if cf ctx + | A.SwitchInt (op, int_ty, stgts, otherwise) -> + (* Evaluate the operand *) + let cf_eval_op = eval_operand config op in + (* Switch on the value *) + let cf_switch (cf : st_m_fun) (op_v : V.typed_value) : m_fun = + fun ctx -> + match op_v.value with + | V.Primitive (PV.Scalar sv) -> + (* Evaluate the branch *) + let cf_eval_branch cf = + (* Sanity check *) + assert (sv.PV.int_ty = int_ty); + (* Find the branch *) + match List.find_opt (fun (svl, _) -> List.mem sv svl) stgts with + | None -> eval_statement config otherwise cf + | Some (_, tgt) -> eval_statement config tgt cf + in + (* Compose *) + cf_eval_branch cf ctx + | V.Symbolic sv -> + (* Expand the symbolic value and continue by evaluating the + * proper branches *) + let stgts = + List.map + (fun (cv, tgt_st) -> (cv, eval_statement config tgt_st cf)) + stgts + in + (* Several branches may be grouped together: every branch is described + * by a pair (list of values, branch expression). + * In order to do a symbolic evaluation, we make this "flat" by + * de-grouping the branches. *) + let stgts = + List.concat + (List.map + (fun (vl, st) -> List.map (fun v -> (v, st)) vl) + stgts) + in + (* Translate the otherwise branch *) + let otherwise = eval_statement config otherwise cf in + (* Expand and continue *) + expand_symbolic_int config sv + (S.mk_opt_place_from_op op ctx) + int_ty stgts otherwise ctx + | _ -> raise (Failure "Inconsistent state") + in + (* Compose *) + comp cf_eval_op cf_switch cf ctx + | A.Match (p, stgts, otherwise) -> + (* Access the place *) + let access = Read in + let expand_prim_copy = false in + let cf_read_p cf : m_fun = + access_rplace_reorganize_and_read config expand_prim_copy access p cf + in + (* Match on the value *) + let cf_match (cf : st_m_fun) (p_v : V.typed_value) : m_fun = + fun ctx -> + (* The value may be shared: we need to ignore the shared loans + to read the value itself *) + let p_v = value_strip_shared_loans p_v in + (* Match *) + match p_v.value with + | V.Adt adt -> ( + (* Evaluate the discriminant *) + let dv = Option.get adt.variant_id in + (* Find the branch, evaluate and continue *) + match List.find_opt (fun (svl, _) -> List.mem dv svl) stgts with + | None -> eval_statement config otherwise cf ctx + | Some (_, tgt) -> eval_statement config tgt cf ctx) + | V.Symbolic sv -> + (* Expand the symbolic value - may lead to branching *) + let allow_branching = true in + let cf_expand = + expand_symbolic_value config allow_branching sv + (Some (S.mk_mplace p ctx)) + in + (* Re-evaluate the switch - the value is not symbolic anymore, + which means we will go to the other branch *) + comp cf_expand (eval_switch config switch) cf ctx + | _ -> raise (Failure "Inconsistent state") + in + (* Compose *) + comp cf_read_p cf_match cf ctx in (* Compose the continuations *) - comp cf_eval_op cf_match cf ctx + cf_match cf ctx (** Evaluate a function call (auxiliary helper for [eval_statement]) *) and eval_function_call (config : C.config) (call : A.call) : st_cm_fun = diff --git a/compiler/PrePasses.ml b/compiler/PrePasses.ml index 1bdaf174..082a81ba 100644 --- a/compiler/PrePasses.ml +++ b/compiler/PrePasses.ml @@ -121,10 +121,9 @@ let remove_useless_cf_merges (crate : A.crate) (f : A.fun_decl) : A.fun_decl = method! visit_Sequence env st1 st2 = match st1.content with - | Switch (op, tgts) -> + | Switch switch -> if can_be_moved st2 then - super#visit_Switch env op - (chain_statements_in_switch_targets tgts st2) + super#visit_Switch env (chain_statements_in_switch switch st2) else super#visit_Sequence env st1 st2 | _ -> super#visit_Sequence env st1 st2 end diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 7f71cf7f..62657fc7 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -48,7 +48,7 @@ let option_none_id = T.option_none_id type type_id = AdtId of TypeDeclId.id | Tuple | Assumed of assumed_ty [@@deriving show, ord] -(** Ancestor for iter visitor for {!ty} *) +(** Ancestor for iter visitor for {!Pure.ty} *) class ['self] iter_ty_base = object (_self : 'self) inherit [_] VisitorsRuntime.iter @@ -57,7 +57,7 @@ class ['self] iter_ty_base = method visit_integer_type : 'env -> integer_type -> unit = fun _ _ -> () end -(** Ancestor for map visitor for {!ty} *) +(** Ancestor for map visitor for {!Pure.ty} *) class ['self] map_ty_base = object (_self : 'self) inherit [_] VisitorsRuntime.map diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 2ef97e59..9d364dc7 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -73,7 +73,7 @@ let make_type_subst (vars : type_var list) (tys : ty list) : TypeVarId.id -> ty in fun id -> TypeVarId.Map.find id mp -(** Retrieve the list of fields for the given variant of a {!Pure.type_decl}. +(** Retrieve the list of fields for the given variant of a {!type_decl}. Raises [Invalid_argument] if the arguments are incorrect. *) diff --git a/compiler/Substitute.ml b/compiler/Substitute.ml index 30469c2f..eb61f076 100644 --- a/compiler/Substitute.ml +++ b/compiler/Substitute.ml @@ -274,7 +274,7 @@ let call_substitute (tsubst : T.TypeVarId.id -> T.ety) (call : A.call) : A.call dest; } -(** Apply a type substitution to a statement *) +(** Apply a type substitution to a statement - TODO: use a map iterator *) let rec statement_substitute (tsubst : T.TypeVarId.id -> T.ety) (st : A.statement) : A.statement = { st with A.content = raw_statement_substitute tsubst st.content } @@ -305,23 +305,30 @@ and raw_statement_substitute (tsubst : T.TypeVarId.id -> T.ety) | A.Sequence (st1, st2) -> A.Sequence (statement_substitute tsubst st1, statement_substitute tsubst st2) - | A.Switch (op, tgts) -> - A.Switch - (operand_substitute tsubst op, switch_targets_substitute tsubst tgts) + | A.Switch switch -> A.Switch (switch_substitute tsubst switch) | A.Loop le -> A.Loop (statement_substitute tsubst le) -(** Apply a type substitution to switch targets *) -and switch_targets_substitute (tsubst : T.TypeVarId.id -> T.ety) - (tgts : A.switch_targets) : A.switch_targets = - match tgts with - | A.If (st1, st2) -> - A.If (statement_substitute tsubst st1, statement_substitute tsubst st2) - | A.SwitchInt (int_ty, tgts, otherwise) -> +(** Apply a type substitution to a switch *) +and switch_substitute (tsubst : T.TypeVarId.id -> T.ety) (switch : A.switch) : + A.switch = + match switch with + | A.If (op, st1, st2) -> + A.If + ( operand_substitute tsubst op, + statement_substitute tsubst st1, + statement_substitute tsubst st2 ) + | A.SwitchInt (op, int_ty, tgts, otherwise) -> let tgts = List.map (fun (sv, st) -> (sv, statement_substitute tsubst st)) tgts in let otherwise = statement_substitute tsubst otherwise in - A.SwitchInt (int_ty, tgts, otherwise) + A.SwitchInt (operand_substitute tsubst op, int_ty, tgts, otherwise) + | A.Match (p, tgts, otherwise) -> + let tgts = + List.map (fun (sv, st) -> (sv, statement_substitute tsubst st)) tgts + in + let otherwise = statement_substitute tsubst otherwise in + A.Match (place_substitute tsubst p, tgts, otherwise) (** Apply a type substitution to a function body. Return the local variables and the body. *) -- cgit v1.2.3