summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2022-11-08 16:35:55 +0100
committerSon HO2022-11-10 11:35:30 +0100
commitdcb1a77150d26875ab67b5e12cb299a3d9369d4a (patch)
treec5eba364eab9f975bf6fc454320c6e38dfb330c1
parent357782ba4c039ac6d83b4fd8344121e89f87eb7b (diff)
Update `switch` to have a specific treatment of ADTs
-rw-r--r--compiler/InterpreterExpressions.ml96
-rw-r--r--compiler/InterpreterExpressions.mli40
-rw-r--r--compiler/InterpreterPaths.ml32
-rw-r--r--compiler/InterpreterPaths.mli8
-rw-r--r--compiler/InterpreterStatements.ml185
-rw-r--r--compiler/PrePasses.ml5
-rw-r--r--compiler/Pure.ml4
-rw-r--r--compiler/PureUtils.ml2
-rw-r--r--compiler/Substitute.ml31
9 files changed, 207 insertions, 196 deletions
diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml
index ea0e1aa9..a0fb97d1 100644
--- a/compiler/InterpreterExpressions.ml
+++ b/compiler/InterpreterExpressions.ml
@@ -66,22 +66,6 @@ let read_place (config : C.config) (access : access_kind) (p : E.place)
(* Call the continuation *)
cf v ctx
-(** Small utility.
-
- Prepare the access to a place in a right-value (typically an operand) by
- reorganizing the environment.
-
- We reorganize the environment so that:
- - we can access the place (we prepare *along* the path)
- - the value at the place itself doesn't contain loans (the [access_kind]
- controls whether we only end mutable loans, or also shared loans).
-
- We also check, after the reorganization, that the value at the place
- *doesn't contain any bottom nor reserved borrows*.
-
- [expand_prim_copy]: if [true], expand the symbolic values which are
- primitively copyable and contain borrows.
- *)
let access_rplace_reorganize_and_read (config : C.config)
(expand_prim_copy : bool) (access : access_kind) (p : E.place)
(cf : V.typed_value -> m_fun) : m_fun =
@@ -560,80 +544,6 @@ let eval_binary_op (config : C.config) (binop : E.binop) (op1 : E.operand)
| C.ConcreteMode -> eval_binary_op_concrete config binop op1 op2 cf
| C.SymbolicMode -> eval_binary_op_symbolic config binop op1 op2 cf
-(** Evaluate the discriminant of a concrete (i.e., non symbolic) ADT value *)
-let eval_rvalue_discriminant_concrete (config : C.config) (p : E.place)
- (cf : V.typed_value -> m_fun) : m_fun =
- (* Note that discriminant values have type [isize] *)
- (* Access the value *)
- let access = Read in
- let expand_prim_copy = false in
- let prepare =
- access_rplace_reorganize_and_read config expand_prim_copy access p
- in
- (* Read the value *)
- let read (cf : V.typed_value -> m_fun) (v : V.typed_value) : m_fun =
- (* The value may be shared: we need to ignore the shared loans *)
- let v = value_strip_shared_loans v in
- match v.V.value with
- | Adt av -> (
- match av.variant_id with
- | None ->
- raise
- (Failure
- "Invalid input for `discriminant`: structure instead of enum")
- | Some variant_id -> (
- let id = Z.of_int (T.VariantId.to_int variant_id) in
- match mk_scalar Isize id with
- | Error _ -> raise (Failure "Disciminant id out of range")
- (* Should really never happen *)
- | Ok sv ->
- cf { V.value = V.Primitive (PV.Scalar sv); ty = Integer Isize })
- )
- | _ ->
- raise
- (Failure ("Invalid input for `discriminant`: " ^ V.show_typed_value v))
- in
- (* Compose and apply *)
- comp prepare read cf
-
-(** Evaluate the discriminant of an ADT value.
-
- Might lead to branching, if the value is symbolic.
- *)
-let eval_rvalue_discriminant (config : C.config) (p : E.place)
- (cf : V.typed_value -> m_fun) : m_fun =
- fun ctx ->
- log#ldebug (lazy "eval_rvalue_discriminant");
- (* Note that discriminant values have type [isize] *)
- (* Access the value *)
- let access = Read in
- let expand_prim_copy = false in
- let prepare =
- access_rplace_reorganize_and_read config expand_prim_copy access p
- in
- (* Read the value *)
- let read (cf : V.typed_value -> m_fun) (v : V.typed_value) : m_fun =
- fun ctx ->
- (* The value may be shared: we need to ignore the shared loans *)
- let v = value_strip_shared_loans v in
- match v.V.value with
- | Adt _ -> eval_rvalue_discriminant_concrete config p cf ctx
- | Symbolic sv ->
- (* Expand the symbolic value - may lead to branching *)
- let allow_branching = true in
- let cc =
- expand_symbolic_value config allow_branching sv
- (Some (S.mk_mplace p ctx))
- in
- (* This time the value is concrete: reevaluate *)
- comp cc (eval_rvalue_discriminant_concrete config p) cf ctx
- | _ ->
- raise
- (Failure ("Invalid input for `discriminant`: " ^ V.show_typed_value v))
- in
- (* Compose and apply *)
- comp prepare read cf ctx
-
let eval_rvalue_ref (config : C.config) (p : E.place) (bkind : E.borrow_kind)
(cf : V.typed_value -> m_fun) : m_fun =
fun ctx ->
@@ -780,7 +690,11 @@ let eval_rvalue_not_global (config : C.config) (rvalue : E.rvalue)
| E.BinaryOp (binop, op1, op2) -> eval_binary_op config binop op1 op2 cf ctx
| E.Aggregate (aggregate_kind, ops) ->
comp_wrap (eval_rvalue_aggregate config aggregate_kind ops) ctx
- | E.Discriminant p -> comp_wrap (eval_rvalue_discriminant config p) ctx
+ | E.Discriminant _ ->
+ raise
+ (Failure
+ "Unreachable: discriminant reads should have been eliminated from \
+ the AST")
| E.Global _ -> raise (Failure "Unreachable")
let eval_fake_read (config : C.config) (p : E.place) : cm_fun =
diff --git a/compiler/InterpreterExpressions.mli b/compiler/InterpreterExpressions.mli
index c610e939..2ea3c6df 100644
--- a/compiler/InterpreterExpressions.mli
+++ b/compiler/InterpreterExpressions.mli
@@ -11,6 +11,41 @@ module S = SynthesizeSymbolic
open Cps
open InterpreterPaths
+(** Read a place (CPS-style function).
+
+ We also check that the value *doesn't contain bottoms or reserved
+ borrows*.
+
+ This function doesn't reorganize the context to make sure we can read
+ the place. If needs be, you should call {!update_ctx_along_read_place} first.
+ *)
+val read_place :
+ C.config -> access_kind -> E.place -> (V.typed_value -> m_fun) -> m_fun
+
+(** Auxiliary function.
+
+ Prepare the access to a place in a right-value (typically an operand) by
+ reorganizing the environment.
+
+ We reorganize the environment so that:
+ - we can access the place (we prepare *along* the path)
+ - the value at the place itself doesn't contain loans (the [access_kind]
+ controls whether we only end mutable loans, or also shared loans).
+
+ We also check, after the reorganization, that the value at the place
+ *doesn't contain any bottom nor reserved borrows*.
+
+ [expand_prim_copy]: if [true], expand the symbolic values which are
+ primitively copyable and contain borrows.
+ *)
+val access_rplace_reorganize_and_read :
+ C.config ->
+ bool ->
+ access_kind ->
+ E.place ->
+ (V.typed_value -> m_fun) ->
+ m_fun
+
(** Evaluate an operand.
Reorganize the context, then evaluate the operand.
@@ -26,9 +61,12 @@ val eval_operand : C.config -> E.operand -> (V.typed_value -> m_fun) -> m_fun
val eval_operands :
C.config -> E.operand list -> (V.typed_value list -> m_fun) -> m_fun
-(** Evaluate an rvalue which is not a global.
+(** Evaluate an rvalue which is not a global (globals are handled elsewhere).
Transmits the computed rvalue to the received continuation.
+
+ Note that this function fails on {!E.Discriminant}: discriminant reads should
+ have been eliminated from the AST.
*)
val eval_rvalue_not_global :
C.config -> E.rvalue -> ((V.typed_value, eval_error) result -> m_fun) -> m_fun
diff --git a/compiler/InterpreterPaths.ml b/compiler/InterpreterPaths.ml
index 3d0c69e8..63e03e31 100644
--- a/compiler/InterpreterPaths.ml
+++ b/compiler/InterpreterPaths.ml
@@ -569,23 +569,21 @@ let rec end_loans_at_place (config : C.config) (access : access_kind)
in
(* First, retrieve the value *)
- match try_read_place config access p ctx with
- | Error _ -> raise (Failure "Unreachable")
- | Ok v -> (
- (* Inspect the value and update the context while doing so.
- If the context gets updated: perform a recursive call (many things
- may have been updated in the context: we need to re-read the value
- at place [p] - and this value may actually not be accessible
- anymore...)
- *)
- try
- obj#visit_typed_value () v;
- (* No context update required: apply the continuation *)
- cf ctx
- with UpdateCtx cc ->
- (* We need to update the context: compose the caugth continuation with
- * a recursive call to reinspect the value *)
- comp cc (end_loans_at_place config access p) cf ctx)
+ let v = read_place config access p ctx in
+ (* Inspect the value and update the context while doing so.
+ If the context gets updated: perform a recursive call (many things
+ may have been updated in the context: we need to re-read the value
+ at place [p] - and this value may actually not be accessible
+ anymore...)
+ *)
+ try
+ obj#visit_typed_value () v;
+ (* No context update required: apply the continuation *)
+ cf ctx
+ with UpdateCtx cc ->
+ (* We need to update the context: compose the caugth continuation with
+ * a recursive call to reinspect the value *)
+ comp cc (end_loans_at_place config access p) cf ctx
let drop_outer_loans_at_lplace (config : C.config) (p : E.place) : cm_fun =
fun cf ctx ->
diff --git a/compiler/InterpreterPaths.mli b/compiler/InterpreterPaths.mli
index ed00b7c5..14baf128 100644
--- a/compiler/InterpreterPaths.mli
+++ b/compiler/InterpreterPaths.mli
@@ -27,6 +27,9 @@ val update_ctx_along_write_place : C.config -> access_kind -> E.place -> cm_fun
(** Read the value at a given place.
+ This function doesn't update the environment to make sure the value is
+ accessible: if needs be, you should call {!update_ctx_along_read_place} first.
+
Note that we only access the value at the place, and do not check that
the value is "well-formed" (for instance that it doesn't contain bottoms).
*)
@@ -35,7 +38,10 @@ val read_place :
(** Update the value at a given place.
- This function is an auxiliary function and is not safe: it will not check if
+ This function doesn't update the environment to make sure the value is
+ accessible: if needs be, you should call {!update_ctx_along_write_place} first.
+
+ This function is a helper function and is **not safe**: it will not check if
the overwritten value contains borrows, loans, etc. and will simply
overwrite it.
*)
diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml
index 08a03885..21027ff8 100644
--- a/compiler/InterpreterStatements.ml
+++ b/compiler/InterpreterStatements.ml
@@ -889,7 +889,7 @@ let rec eval_statement (config : C.config) (st : A.statement) : st_cm_fun =
in
(* Apply *)
eval_statement config loop_body reeval_loop_body ctx
- | A.Switch (op, tgts) -> eval_switch config op tgts cf ctx
+ | A.Switch switch -> eval_switch config switch cf ctx
in
(* Compose and apply *)
comp cc cf_eval_st cf ctx
@@ -916,8 +916,7 @@ and eval_global (config : C.config) (dest : E.place) (gid : LA.GlobalDeclId.id)
S.synthesize_global_eval gid sval e
(** Evaluate a switch *)
-and eval_switch (config : C.config) (op : E.operand) (tgts : A.switch_targets) :
- st_cm_fun =
+and eval_switch (config : C.config) (switch : A.switch) : st_cm_fun =
fun cf ctx ->
(* We evaluate the operand in two steps:
* first we prepare it, then we check if its value is concrete or
@@ -927,74 +926,124 @@ and eval_switch (config : C.config) (op : E.operand) (tgts : A.switch_targets) :
* value if it is symbolic, because the value may have been move
* (and would thus floating in thin air...)!
* *)
- (* Prepare the operand *)
- let cf_eval_op cf : m_fun = eval_operand config op cf in
(* Match on the targets *)
- let cf_match (cf : st_m_fun) (op_v : V.typed_value) : m_fun =
- fun ctx ->
- match tgts with
- | A.If (st1, st2) -> (
- match op_v.value with
- | V.Primitive (PV.Bool b) ->
- (* Evaluate the if and the branch body *)
- let cf_branch cf : m_fun =
- (* Branch *)
- if b then eval_statement config st1 cf
- else eval_statement config st2 cf
- in
- (* Compose the continuations *)
- cf_branch cf ctx
- | V.Symbolic sv ->
- (* Expand the symbolic boolean, and continue by evaluating
- * the branches *)
- let cf_true : m_fun = eval_statement config st1 cf in
- let cf_false : m_fun = eval_statement config st2 cf in
- expand_symbolic_bool config sv
- (S.mk_opt_place_from_op op ctx)
- cf_true cf_false ctx
- | _ -> raise (Failure "Inconsistent state"))
- | A.SwitchInt (int_ty, stgts, otherwise) -> (
- match op_v.value with
- | V.Primitive (PV.Scalar sv) ->
- (* Evaluate the branch *)
- let cf_eval_branch cf =
- (* Sanity check *)
- assert (sv.PV.int_ty = int_ty);
- (* Find the branch *)
- match List.find_opt (fun (svl, _) -> List.mem sv svl) stgts with
- | None -> eval_statement config otherwise cf
- | Some (_, tgt) -> eval_statement config tgt cf
- in
- (* Compose *)
- cf_eval_branch cf ctx
- | V.Symbolic sv ->
- (* Expand the symbolic value and continue by evaluating the
- * proper branches *)
- let stgts =
- List.map
- (fun (cv, tgt_st) -> (cv, eval_statement config tgt_st cf))
- stgts
- in
- (* Several branches may be grouped together: every branch is described
- * by a pair (list of values, branch expression).
- * In order to do a symbolic evaluation, we make this "flat" by
- * de-grouping the branches. *)
- let stgts =
- List.concat
- (List.map
- (fun (vl, st) -> List.map (fun v -> (v, st)) vl)
- stgts)
- in
- (* Translate the otherwise branch *)
- let otherwise = eval_statement config otherwise cf in
- (* Expand and continue *)
- expand_symbolic_int config sv
- (S.mk_opt_place_from_op op ctx)
- int_ty stgts otherwise ctx
- | _ -> raise (Failure "Inconsistent state"))
+ let cf_match : st_cm_fun =
+ fun cf ctx ->
+ match switch with
+ | A.If (op, st1, st2) ->
+ (* Evaluate the operand *)
+ let cf_eval_op = eval_operand config op in
+ (* Switch on the value *)
+ let cf_if (cf : st_m_fun) (op_v : V.typed_value) : m_fun =
+ fun ctx ->
+ match op_v.value with
+ | V.Primitive (PV.Bool b) ->
+ (* Evaluate the if and the branch body *)
+ let cf_branch cf : m_fun =
+ (* Branch *)
+ if b then eval_statement config st1 cf
+ else eval_statement config st2 cf
+ in
+ (* Compose the continuations *)
+ cf_branch cf ctx
+ | V.Symbolic sv ->
+ (* Expand the symbolic boolean, and continue by evaluating
+ * the branches *)
+ let cf_true : m_fun = eval_statement config st1 cf in
+ let cf_false : m_fun = eval_statement config st2 cf in
+ expand_symbolic_bool config sv
+ (S.mk_opt_place_from_op op ctx)
+ cf_true cf_false ctx
+ | _ -> raise (Failure "Inconsistent state")
+ in
+ (* Compose *)
+ comp cf_eval_op cf_if cf ctx
+ | A.SwitchInt (op, int_ty, stgts, otherwise) ->
+ (* Evaluate the operand *)
+ let cf_eval_op = eval_operand config op in
+ (* Switch on the value *)
+ let cf_switch (cf : st_m_fun) (op_v : V.typed_value) : m_fun =
+ fun ctx ->
+ match op_v.value with
+ | V.Primitive (PV.Scalar sv) ->
+ (* Evaluate the branch *)
+ let cf_eval_branch cf =
+ (* Sanity check *)
+ assert (sv.PV.int_ty = int_ty);
+ (* Find the branch *)
+ match List.find_opt (fun (svl, _) -> List.mem sv svl) stgts with
+ | None -> eval_statement config otherwise cf
+ | Some (_, tgt) -> eval_statement config tgt cf
+ in
+ (* Compose *)
+ cf_eval_branch cf ctx
+ | V.Symbolic sv ->
+ (* Expand the symbolic value and continue by evaluating the
+ * proper branches *)
+ let stgts =
+ List.map
+ (fun (cv, tgt_st) -> (cv, eval_statement config tgt_st cf))
+ stgts
+ in
+ (* Several branches may be grouped together: every branch is described
+ * by a pair (list of values, branch expression).
+ * In order to do a symbolic evaluation, we make this "flat" by
+ * de-grouping the branches. *)
+ let stgts =
+ List.concat
+ (List.map
+ (fun (vl, st) -> List.map (fun v -> (v, st)) vl)
+ stgts)
+ in
+ (* Translate the otherwise branch *)
+ let otherwise = eval_statement config otherwise cf in
+ (* Expand and continue *)
+ expand_symbolic_int config sv
+ (S.mk_opt_place_from_op op ctx)
+ int_ty stgts otherwise ctx
+ | _ -> raise (Failure "Inconsistent state")
+ in
+ (* Compose *)
+ comp cf_eval_op cf_switch cf ctx
+ | A.Match (p, stgts, otherwise) ->
+ (* Access the place *)
+ let access = Read in
+ let expand_prim_copy = false in
+ let cf_read_p cf : m_fun =
+ access_rplace_reorganize_and_read config expand_prim_copy access p cf
+ in
+ (* Match on the value *)
+ let cf_match (cf : st_m_fun) (p_v : V.typed_value) : m_fun =
+ fun ctx ->
+ (* The value may be shared: we need to ignore the shared loans
+ to read the value itself *)
+ let p_v = value_strip_shared_loans p_v in
+ (* Match *)
+ match p_v.value with
+ | V.Adt adt -> (
+ (* Evaluate the discriminant *)
+ let dv = Option.get adt.variant_id in
+ (* Find the branch, evaluate and continue *)
+ match List.find_opt (fun (svl, _) -> List.mem dv svl) stgts with
+ | None -> eval_statement config otherwise cf ctx
+ | Some (_, tgt) -> eval_statement config tgt cf ctx)
+ | V.Symbolic sv ->
+ (* Expand the symbolic value - may lead to branching *)
+ let allow_branching = true in
+ let cf_expand =
+ expand_symbolic_value config allow_branching sv
+ (Some (S.mk_mplace p ctx))
+ in
+ (* Re-evaluate the switch - the value is not symbolic anymore,
+ which means we will go to the other branch *)
+ comp cf_expand (eval_switch config switch) cf ctx
+ | _ -> raise (Failure "Inconsistent state")
+ in
+ (* Compose *)
+ comp cf_read_p cf_match cf ctx
in
(* Compose the continuations *)
- comp cf_eval_op cf_match cf ctx
+ cf_match cf ctx
(** Evaluate a function call (auxiliary helper for [eval_statement]) *)
and eval_function_call (config : C.config) (call : A.call) : st_cm_fun =
diff --git a/compiler/PrePasses.ml b/compiler/PrePasses.ml
index 1bdaf174..082a81ba 100644
--- a/compiler/PrePasses.ml
+++ b/compiler/PrePasses.ml
@@ -121,10 +121,9 @@ let remove_useless_cf_merges (crate : A.crate) (f : A.fun_decl) : A.fun_decl =
method! visit_Sequence env st1 st2 =
match st1.content with
- | Switch (op, tgts) ->
+ | Switch switch ->
if can_be_moved st2 then
- super#visit_Switch env op
- (chain_statements_in_switch_targets tgts st2)
+ super#visit_Switch env (chain_statements_in_switch switch st2)
else super#visit_Sequence env st1 st2
| _ -> super#visit_Sequence env st1 st2
end
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index 7f71cf7f..62657fc7 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -48,7 +48,7 @@ let option_none_id = T.option_none_id
type type_id = AdtId of TypeDeclId.id | Tuple | Assumed of assumed_ty
[@@deriving show, ord]
-(** Ancestor for iter visitor for {!ty} *)
+(** Ancestor for iter visitor for {!Pure.ty} *)
class ['self] iter_ty_base =
object (_self : 'self)
inherit [_] VisitorsRuntime.iter
@@ -57,7 +57,7 @@ class ['self] iter_ty_base =
method visit_integer_type : 'env -> integer_type -> unit = fun _ _ -> ()
end
-(** Ancestor for map visitor for {!ty} *)
+(** Ancestor for map visitor for {!Pure.ty} *)
class ['self] map_ty_base =
object (_self : 'self)
inherit [_] VisitorsRuntime.map
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index 2ef97e59..9d364dc7 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -73,7 +73,7 @@ let make_type_subst (vars : type_var list) (tys : ty list) : TypeVarId.id -> ty
in
fun id -> TypeVarId.Map.find id mp
-(** Retrieve the list of fields for the given variant of a {!Pure.type_decl}.
+(** Retrieve the list of fields for the given variant of a {!type_decl}.
Raises [Invalid_argument] if the arguments are incorrect.
*)
diff --git a/compiler/Substitute.ml b/compiler/Substitute.ml
index 30469c2f..eb61f076 100644
--- a/compiler/Substitute.ml
+++ b/compiler/Substitute.ml
@@ -274,7 +274,7 @@ let call_substitute (tsubst : T.TypeVarId.id -> T.ety) (call : A.call) : A.call
dest;
}
-(** Apply a type substitution to a statement *)
+(** Apply a type substitution to a statement - TODO: use a map iterator *)
let rec statement_substitute (tsubst : T.TypeVarId.id -> T.ety)
(st : A.statement) : A.statement =
{ st with A.content = raw_statement_substitute tsubst st.content }
@@ -305,23 +305,30 @@ and raw_statement_substitute (tsubst : T.TypeVarId.id -> T.ety)
| A.Sequence (st1, st2) ->
A.Sequence
(statement_substitute tsubst st1, statement_substitute tsubst st2)
- | A.Switch (op, tgts) ->
- A.Switch
- (operand_substitute tsubst op, switch_targets_substitute tsubst tgts)
+ | A.Switch switch -> A.Switch (switch_substitute tsubst switch)
| A.Loop le -> A.Loop (statement_substitute tsubst le)
-(** Apply a type substitution to switch targets *)
-and switch_targets_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (tgts : A.switch_targets) : A.switch_targets =
- match tgts with
- | A.If (st1, st2) ->
- A.If (statement_substitute tsubst st1, statement_substitute tsubst st2)
- | A.SwitchInt (int_ty, tgts, otherwise) ->
+(** Apply a type substitution to a switch *)
+and switch_substitute (tsubst : T.TypeVarId.id -> T.ety) (switch : A.switch) :
+ A.switch =
+ match switch with
+ | A.If (op, st1, st2) ->
+ A.If
+ ( operand_substitute tsubst op,
+ statement_substitute tsubst st1,
+ statement_substitute tsubst st2 )
+ | A.SwitchInt (op, int_ty, tgts, otherwise) ->
let tgts =
List.map (fun (sv, st) -> (sv, statement_substitute tsubst st)) tgts
in
let otherwise = statement_substitute tsubst otherwise in
- A.SwitchInt (int_ty, tgts, otherwise)
+ A.SwitchInt (operand_substitute tsubst op, int_ty, tgts, otherwise)
+ | A.Match (p, tgts, otherwise) ->
+ let tgts =
+ List.map (fun (sv, st) -> (sv, statement_substitute tsubst st)) tgts
+ in
+ let otherwise = statement_substitute tsubst otherwise in
+ A.Match (place_substitute tsubst p, tgts, otherwise)
(** Apply a type substitution to a function body. Return the local variables
and the body. *)