Skip to content

Commit

Permalink
Cranelift: Break op cost ties with expression depth in egraphs (#7456)
Browse files Browse the repository at this point in the history
* Cranelift: Switch egraph `Cost` to a struct with named fields

Mechanical change.

* Cranelift: Break op cost ties with expression depth in egraphs

This means that, when the opcode cost is the same, we prefer shallow and wide
expressions to narrow and deep. For example, `(a + b) + (c + d)` is preferred to
`((a + b) + c) + d`. This is beneficial because it exposes more
instruction-level parallelism and shortens live ranges.

Co-Authored-By: Trevor Elliott <telliott@fastly.com>

* Cranelift: Bitpack the egraph `Cost` structure

Co-Authored-By: Chris Fallin <chris@cfallin.org>
Co-Authored-By: Trevor Elliott <telliott@fastly.com>

* Make it so you can't construct `Cost::infinity()` by accident

* Use fold to code golf

---------

Co-authored-by: Trevor Elliott <telliott@fastly.com>
Co-authored-by: Chris Fallin <chris@cfallin.org>
  • Loading branch information
3 people authored Nov 7, 2023
1 parent 54aed0b commit b9f2a30
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 23 deletions.
107 changes: 92 additions & 15 deletions cranelift/codegen/src/egraph/cost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,52 @@ use crate::ir::Opcode;
/// `finite()` method.) An infinite cost is used to represent a value
/// that cannot be computed, or otherwise serve as a sentinel when
/// performing search for the lowest-cost representation of a value.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) struct Cost(u32);

/// Human-readable rendering of a packed `Cost`.
///
/// The infinite cost prints as `Cost::Infinite`; every other value prints as
/// a `Cost::Finite` struct showing its decoded `op_cost` and `depth` parts.
impl core::fmt::Debug for Cost {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Finite case: decode both packed components and print them by name.
        if *self != Cost::infinity() {
            return f
                .debug_struct("Cost::Finite")
                .field("op_cost", &self.op_cost())
                .field("depth", &self.depth())
                .finish();
        }

        // Otherwise this is the infinite cost, which has no meaningful parts.
        write!(f, "Cost::Infinite")
    }
}

impl Ord for Cost {
    /// Order costs by op cost first, then by depth.
    ///
    /// The packed representation keeps the op cost in the high bits and the
    /// depth in the low bits, so comparing the raw integers yields exactly
    /// that ordering.
    ///
    /// Depth is deliberately only the tie-breaker (not the other way around):
    /// for equal op cost, a shallow and wide expression beats a narrow and
    /// deep one. For example, `(a + b) + (c + d)` is preferred to
    /// `((a + b) + c) + d`, which exposes more instruction-level parallelism
    /// and shortens live ranges.
    #[inline]
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let (lhs, rhs) = (self.0, other.0);
        lhs.cmp(&rhs)
    }
}

impl PartialOrd for Cost {
    /// Costs are totally ordered, so this always returns `Some` and simply
    /// defers to `Ord::cmp`, keeping the two orderings consistent.
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Cost {
/// Number of low bits of the packed `u32` that hold the expression depth.
const DEPTH_BITS: u8 = 8;
/// Mask selecting the depth component: the low `DEPTH_BITS` bits.
const DEPTH_MASK: u32 = (1 << Self::DEPTH_BITS) - 1;
/// Mask selecting the op-cost component: every bit above the depth bits.
const OP_COST_MASK: u32 = !Self::DEPTH_MASK;
/// Largest op cost a finite cost may carry: one less than the largest value
/// that fits in the op-cost bits, so the all-ones op-cost pattern is never
/// produced by clamping to this limit.
// NOTE(review): presumably the all-ones pattern is reserved for `infinity()`
// — confirm against that function's body.
const MAX_OP_COST: u32 = (Self::OP_COST_MASK >> Self::DEPTH_BITS) - 1;

pub(crate) fn infinity() -> Cost {
// 2^32 - 1 is, uh, pretty close to infinite... (we use `Cost`
// only for heuristics and always saturate so this suffices!)
Expand All @@ -43,11 +86,38 @@ impl Cost {
Cost(0)
}

/// Clamp this cost at a "finite" value. Can be used in
/// conjunction with saturating ops to avoid saturating into
/// `infinity()`.
fn finite(self) -> Cost {
Cost(std::cmp::min(u32::MAX - 1, self.0))
/// Pack an op cost and a depth into a single finite `Cost`.
///
/// `opcode_cost` saturates at `MAX_OP_COST`; `depth` is stored unchanged in
/// the low `DEPTH_BITS` bits.
fn new_finite(opcode_cost: u32, depth: u8) -> Cost {
    let clamped = opcode_cost.min(Self::MAX_OP_COST);
    let packed = (clamped << Self::DEPTH_BITS) | u32::from(depth);
    let cost = Cost(packed);
    // Debug-build sanity check: a cost built here must not equal infinity.
    debug_assert_ne!(cost, Cost::infinity());
    cost
}

/// Extract the depth component stored in the low `DEPTH_BITS` bits.
fn depth(&self) -> u8 {
    // After masking with `DEPTH_MASK` only the depth bits remain, so the
    // narrowing conversion is not expected to fail.
    u8::try_from(self.0 & Self::DEPTH_MASK).unwrap()
}

/// Extract the op-cost component stored above the depth bits.
fn op_cost(&self) -> u32 {
    // Shifting right by the depth width drops the depth bits and leaves only
    // the op cost; masking with `OP_COST_MASK` beforehand would clear exactly
    // the bits this shift discards.
    self.0 >> Self::DEPTH_BITS
}

/// Compute the cost of the operation and its given operands.
///
/// Caller is responsible for checking that the opcode came from an instruction
/// that satisfies `inst_predicates::is_pure_for_egraph()`.
pub(crate) fn of_pure_op(op: Opcode, operand_costs: impl IntoIterator<Item = Self>) -> Self {
let c = pure_op_cost(op) + operand_costs.into_iter().sum();
Cost::new_finite(c.op_cost(), c.depth().saturating_add(1))
}
}

impl std::iter::Sum<Cost> for Cost {
    /// Add up every cost yielded by `iter`, starting from `Cost::zero()`.
    ///
    /// An empty iterator therefore sums to `Cost::zero()`.
    fn sum<I: Iterator<Item = Cost>>(iter: I) -> Self {
        let mut total = Self::zero();
        for cost in iter {
            total = total + cost;
        }
        total
    }
}

Expand All @@ -59,22 +129,29 @@ impl std::default::Default for Cost {

// Combining two costs: the op costs are summed (saturating, and capped at
// `MAX_OP_COST`) while the depth of the result is the larger of the two
// depths.
impl std::ops::Add<Cost> for Cost {
type Output = Cost;

fn add(self, other: Cost) -> Cost {
// NOTE(review): the next line calls `finite()` and is followed by further
// statements; it looks like the pre-change body kept by the diff rendering
// rather than part of the new implementation — confirm.
Cost(self.0.saturating_add(other.0)).finite()
// Sum the op-cost components without wrapping, then cap the result at the
// largest finite op cost. `new_finite` applies the same cap to its
// `opcode_cost` argument.
let op_cost = std::cmp::min(
self.op_cost().saturating_add(other.op_cost()),
Self::MAX_OP_COST,
);
// Depths do not add up: the combined depth is that of the deeper operand.
let depth = std::cmp::max(self.depth(), other.depth());
Cost::new_finite(op_cost, depth)
}
}

/// Return the cost of a *pure* opcode. Caller is responsible for
/// checking that the opcode came from an instruction that satisfies
/// `inst_predicates::is_pure_for_egraph()`.
pub(crate) fn pure_op_cost(op: Opcode) -> Cost {
/// Return the cost of a *pure* opcode.
///
/// Caller is responsible for checking that the opcode came from an instruction
/// that satisfies `inst_predicates::is_pure_for_egraph()`.
fn pure_op_cost(op: Opcode) -> Cost {
match op {
// Constants.
Opcode::Iconst | Opcode::F32const | Opcode::F64const => Cost(1),
Opcode::Iconst | Opcode::F32const | Opcode::F64const => Cost::new_finite(1, 0),

// Extends/reduces.
Opcode::Uextend | Opcode::Sextend | Opcode::Ireduce | Opcode::Iconcat | Opcode::Isplit => {
Cost(2)
Cost::new_finite(2, 0)
}

// "Simple" arithmetic.
Expand All @@ -86,9 +163,9 @@ pub(crate) fn pure_op_cost(op: Opcode) -> Cost {
| Opcode::Bnot
| Opcode::Ishl
| Opcode::Ushr
| Opcode::Sshr => Cost(3),
| Opcode::Sshr => Cost::new_finite(3, 0),

// Everything else (pure.)
_ => Cost(4),
_ => Cost::new_finite(4, 0),
}
}
13 changes: 5 additions & 8 deletions cranelift/codegen/src/egraph/elaborate.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Elaboration phase: lowers EGraph back to sequences of operations
//! in CFG nodes.
use super::cost::{pure_op_cost, Cost};
use super::cost::Cost;
use super::domtree::DomTreeWithChildren;
use super::Stats;
use crate::dominator_tree::DominatorTree;
Expand Down Expand Up @@ -245,13 +245,10 @@ impl<'a> Elaborator<'a> {
// N.B.: at this point we know that the opcode is
// pure, so `Cost::of_pure_op`'s precondition is
// satisfied.
let cost = self
.func
.dfg
.inst_values(inst)
.fold(pure_op_cost(inst_data.opcode()), |cost, value| {
cost + best[value].0
});
let cost = Cost::of_pure_op(
inst_data.opcode(),
self.func.dfg.inst_values(inst).map(|value| best[value].0),
);
best[value] = BestEntry(cost, value);
}
}
Expand Down

0 comments on commit b9f2a30

Please sign in to comment.