Skip to content

Commit

Permalink
Auto merge of #107009 - cjgillot:jump-threading, r=pnkfelix
Browse files Browse the repository at this point in the history
Implement jump threading MIR opt

This pass is an attempt to generalize `ConstGoto` and `SeparateConstSwitch` passes into a more complete jump threading pass.

This pass is rather heavy, as it performs a truncated backwards DFS on MIR starting from each `SwitchInt` terminator. This backwards DFS remains very limited, as it only walks through `Goto` terminators.

It is build to support constants and discriminants, and a propagating through a very limited set of operations.

The pass successfully manages to disentangle the `Some(x?)` use case and the DFA use case. It still needs a few tests before being ready.
  • Loading branch information
bors committed Oct 23, 2023
2 parents e2068cd + dd08dd4 commit 1322f92
Show file tree
Hide file tree
Showing 31 changed files with 2,797 additions and 137 deletions.
1 change: 1 addition & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -4279,6 +4279,7 @@ dependencies = [
"coverage_test_macros",
"either",
"itertools",
"rustc_arena",
"rustc_ast",
"rustc_attr",
"rustc_const_eval",
Expand Down
9 changes: 9 additions & 0 deletions compiler/rustc_middle/src/mir/terminator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ impl SwitchTargets {
Self { values: smallvec![value], targets: smallvec![then, else_] }
}

/// Inverse of `SwitchTargets::static_if`.
pub fn as_static_if(&self) -> Option<(u128, BasicBlock, BasicBlock)> {
if let &[value] = &self.values[..] && let &[then, else_] = &self.targets[..] {
Some((value, then, else_))
} else {
None
}
}

/// Returns the fallback target that is jumped to when none of the values match the operand.
pub fn otherwise(&self) -> BasicBlock {
*self.targets.last().unwrap()
Expand Down
149 changes: 124 additions & 25 deletions compiler/rustc_mir_dataflow/src/value_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,19 @@ impl<V: Clone> Clone for State<V> {
}
}

impl<V: Clone + HasTop + HasBottom> State<V> {
impl<V: Clone> State<V> {
pub fn new(init: V, map: &Map) -> State<V> {
let values = IndexVec::from_elem_n(init, map.value_count);
State(StateData::Reachable(values))
}

pub fn all(&self, f: impl Fn(&V) -> bool) -> bool {
match self.0 {
StateData::Unreachable => true,
StateData::Reachable(ref values) => values.iter().all(f),
}
}

pub fn is_reachable(&self) -> bool {
matches!(&self.0, StateData::Reachable(_))
}
Expand All @@ -472,7 +484,10 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
self.0 = StateData::Unreachable;
}

pub fn flood_all(&mut self) {
pub fn flood_all(&mut self)
where
V: HasTop,
{
self.flood_all_with(V::TOP)
}

Expand All @@ -481,28 +496,52 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
values.raw.fill(value);
}

/// Assign `value` to all places that are contained in `place` or may alias one.
pub fn flood_with(&mut self, place: PlaceRef<'_>, map: &Map, value: V) {
let StateData::Reachable(values) = &mut self.0 else { return };
map.for_each_aliasing_place(place, None, &mut |vi| {
values[vi] = value.clone();
});
self.flood_with_tail_elem(place, None, map, value)
}

pub fn flood(&mut self, place: PlaceRef<'_>, map: &Map) {
/// Assign `TOP` to all places that are contained in `place` or may alias one.
pub fn flood(&mut self, place: PlaceRef<'_>, map: &Map)
where
V: HasTop,
{
self.flood_with(place, map, V::TOP)
}

/// Assign `value` to the discriminant of `place` and all places that may alias it.
pub fn flood_discr_with(&mut self, place: PlaceRef<'_>, map: &Map, value: V) {
let StateData::Reachable(values) = &mut self.0 else { return };
map.for_each_aliasing_place(place, Some(TrackElem::Discriminant), &mut |vi| {
values[vi] = value.clone();
});
self.flood_with_tail_elem(place, Some(TrackElem::Discriminant), map, value)
}

pub fn flood_discr(&mut self, place: PlaceRef<'_>, map: &Map) {
/// Assign `TOP` to the discriminant of `place` and all places that may alias it.
pub fn flood_discr(&mut self, place: PlaceRef<'_>, map: &Map)
where
V: HasTop,
{
self.flood_discr_with(place, map, V::TOP)
}

/// This method is the most general version of the `flood_*` method.
///
/// Assign `value` on the given place and all places that may alias it. In particular, when
/// the given place has a variant downcast, we invoke the function on all the other variants.
///
/// `tail_elem` allows to support discriminants that are not a place in MIR, but that we track
/// as such.
pub fn flood_with_tail_elem(
&mut self,
place: PlaceRef<'_>,
tail_elem: Option<TrackElem>,
map: &Map,
value: V,
) {
let StateData::Reachable(values) = &mut self.0 else { return };
map.for_each_aliasing_place(place, tail_elem, &mut |vi| {
values[vi] = value.clone();
});
}

/// Low-level method that assigns to a place.
/// This does nothing if the place is not tracked.
///
Expand Down Expand Up @@ -553,44 +592,104 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
}

/// Helper method to interpret `target = result`.
pub fn assign(&mut self, target: PlaceRef<'_>, result: ValueOrPlace<V>, map: &Map) {
pub fn assign(&mut self, target: PlaceRef<'_>, result: ValueOrPlace<V>, map: &Map)
where
V: HasTop,
{
self.flood(target, map);
if let Some(target) = map.find(target) {
self.insert_idx(target, result, map);
}
}

/// Helper method for assignments to a discriminant.
pub fn assign_discr(&mut self, target: PlaceRef<'_>, result: ValueOrPlace<V>, map: &Map) {
pub fn assign_discr(&mut self, target: PlaceRef<'_>, result: ValueOrPlace<V>, map: &Map)
where
V: HasTop,
{
self.flood_discr(target, map);
if let Some(target) = map.find_discr(target) {
self.insert_idx(target, result, map);
}
}

/// Retrieve the value stored for a place, or `None` if it is not tracked.
pub fn try_get(&self, place: PlaceRef<'_>, map: &Map) -> Option<V> {
let place = map.find(place)?;
self.try_get_idx(place, map)
}

/// Retrieve the discriminant stored for a place, or `None` if it is not tracked.
pub fn try_get_discr(&self, place: PlaceRef<'_>, map: &Map) -> Option<V> {
let place = map.find_discr(place)?;
self.try_get_idx(place, map)
}

/// Retrieve the slice length stored for a place, or `None` if it is not tracked.
pub fn try_get_len(&self, place: PlaceRef<'_>, map: &Map) -> Option<V> {
let place = map.find_len(place)?;
self.try_get_idx(place, map)
}

/// Retrieve the value stored for a place index, or `None` if it is not tracked.
pub fn try_get_idx(&self, place: PlaceIndex, map: &Map) -> Option<V> {
match &self.0 {
StateData::Reachable(values) => {
map.places[place].value_index.map(|v| values[v].clone())
}
StateData::Unreachable => None,
}
}

/// Retrieve the value stored for a place, or ⊤ if it is not tracked.
pub fn get(&self, place: PlaceRef<'_>, map: &Map) -> V {
map.find(place).map(|place| self.get_idx(place, map)).unwrap_or(V::TOP)
///
/// This method returns ⊥ if the place is tracked and the state is unreachable.
pub fn get(&self, place: PlaceRef<'_>, map: &Map) -> V
where
V: HasBottom + HasTop,
{
match &self.0 {
StateData::Reachable(_) => self.try_get(place, map).unwrap_or(V::TOP),
// Because this is unreachable, we can return any value we want.
StateData::Unreachable => V::BOTTOM,
}
}

/// Retrieve the value stored for a place, or ⊤ if it is not tracked.
pub fn get_discr(&self, place: PlaceRef<'_>, map: &Map) -> V {
match map.find_discr(place) {
Some(place) => self.get_idx(place, map),
None => V::TOP,
///
/// This method returns ⊥ the current state is unreachable.
pub fn get_discr(&self, place: PlaceRef<'_>, map: &Map) -> V
where
V: HasBottom + HasTop,
{
match &self.0 {
StateData::Reachable(_) => self.try_get_discr(place, map).unwrap_or(V::TOP),
// Because this is unreachable, we can return any value we want.
StateData::Unreachable => V::BOTTOM,
}
}

/// Retrieve the value stored for a place, or ⊤ if it is not tracked.
pub fn get_len(&self, place: PlaceRef<'_>, map: &Map) -> V {
match map.find_len(place) {
Some(place) => self.get_idx(place, map),
None => V::TOP,
///
/// This method returns ⊥ the current state is unreachable.
pub fn get_len(&self, place: PlaceRef<'_>, map: &Map) -> V
where
V: HasBottom + HasTop,
{
match &self.0 {
StateData::Reachable(_) => self.try_get_len(place, map).unwrap_or(V::TOP),
// Because this is unreachable, we can return any value we want.
StateData::Unreachable => V::BOTTOM,
}
}

/// Retrieve the value stored for a place index, or ⊤ if it is not tracked.
pub fn get_idx(&self, place: PlaceIndex, map: &Map) -> V {
///
/// This method returns ⊥ the current state is unreachable.
pub fn get_idx(&self, place: PlaceIndex, map: &Map) -> V
where
V: HasBottom + HasTop,
{
match &self.0 {
StateData::Reachable(values) => {
map.places[place].value_index.map(|v| values[v].clone()).unwrap_or(V::TOP)
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_mir_transform/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ smallvec = { version = "1.8.1", features = ["union", "may_dangle"] }
tracing = "0.1"
either = "1"
rustc_ast = { path = "../rustc_ast" }
rustc_arena = { path = "../rustc_arena" }
rustc_attr = { path = "../rustc_attr" }
rustc_data_structures = { path = "../rustc_data_structures" }
rustc_errors = { path = "../rustc_errors" }
Expand Down
98 changes: 98 additions & 0 deletions compiler/rustc_mir_transform/src/cost_checker.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use rustc_middle::mir::visit::*;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt};

const INSTR_COST: usize = 5;
const CALL_PENALTY: usize = 25;
const LANDINGPAD_PENALTY: usize = 50;
const RESUME_PENALTY: usize = 45;

/// Verify that the callee body is compatible with the caller.
#[derive(Clone)]
pub(crate) struct CostChecker<'b, 'tcx> {
tcx: TyCtxt<'tcx>,
param_env: ParamEnv<'tcx>,
cost: usize,
callee_body: &'b Body<'tcx>,
instance: Option<ty::Instance<'tcx>>,
}

impl<'b, 'tcx> CostChecker<'b, 'tcx> {
pub fn new(
tcx: TyCtxt<'tcx>,
param_env: ParamEnv<'tcx>,
instance: Option<ty::Instance<'tcx>>,
callee_body: &'b Body<'tcx>,
) -> CostChecker<'b, 'tcx> {
CostChecker { tcx, param_env, callee_body, instance, cost: 0 }
}

pub fn cost(&self) -> usize {
self.cost
}

fn instantiate_ty(&self, v: Ty<'tcx>) -> Ty<'tcx> {
if let Some(instance) = self.instance {
instance.instantiate_mir(self.tcx, ty::EarlyBinder::bind(&v))
} else {
v
}
}
}

impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) {
// Don't count StorageLive/StorageDead in the inlining cost.
match statement.kind {
StatementKind::StorageLive(_)
| StatementKind::StorageDead(_)
| StatementKind::Deinit(_)
| StatementKind::Nop => {}
_ => self.cost += INSTR_COST,
}
}

fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, _: Location) {
let tcx = self.tcx;
match terminator.kind {
TerminatorKind::Drop { ref place, unwind, .. } => {
// If the place doesn't actually need dropping, treat it like a regular goto.
let ty = self.instantiate_ty(place.ty(self.callee_body, tcx).ty);
if ty.needs_drop(tcx, self.param_env) {
self.cost += CALL_PENALTY;
if let UnwindAction::Cleanup(_) = unwind {
self.cost += LANDINGPAD_PENALTY;
}
} else {
self.cost += INSTR_COST;
}
}
TerminatorKind::Call { func: Operand::Constant(ref f), unwind, .. } => {
let fn_ty = self.instantiate_ty(f.const_.ty());
self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind() && tcx.is_intrinsic(def_id) {
// Don't give intrinsics the extra penalty for calls
INSTR_COST
} else {
CALL_PENALTY
};
if let UnwindAction::Cleanup(_) = unwind {
self.cost += LANDINGPAD_PENALTY;
}
}
TerminatorKind::Assert { unwind, .. } => {
self.cost += CALL_PENALTY;
if let UnwindAction::Cleanup(_) = unwind {
self.cost += LANDINGPAD_PENALTY;
}
}
TerminatorKind::UnwindResume => self.cost += RESUME_PENALTY,
TerminatorKind::InlineAsm { unwind, .. } => {
self.cost += INSTR_COST;
if let UnwindAction::Cleanup(_) = unwind {
self.cost += LANDINGPAD_PENALTY;
}
}
_ => self.cost += INSTR_COST,
}
}
}
Loading

0 comments on commit 1322f92

Please sign in to comment.