Skip to content

Commit

Permalink
Cranelift: Bitpack the egraph Cost structure
Browse files Browse the repository at this point in the history
Co-Authored-By: Chris Fallin <chris@cfallin.org>
Co-Authored-By: Trevor Elliott <telliott@fastly.com>
  • Loading branch information
3 people committed Nov 2, 2023
1 parent 89d3601 commit 3d031f6
Showing 1 changed file with 60 additions and 41 deletions.
101 changes: 60 additions & 41 deletions cranelift/codegen/src/egraph/cost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,36 @@ use crate::ir::Opcode;
/// `finite()` method.) An infinite cost is used to represent a value
/// that cannot be computed, or otherwise serve as a sentinel when
/// performing search for the lowest-cost representation of a value.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) struct Cost {
opcode_cost: u32,
depth: u32,
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) struct Cost(u32);

impl core::fmt::Debug for Cost {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
if *self == Cost::infinity() {
write!(f, "Cost::Infinite")
} else {
f.debug_struct("Cost::Finite")
.field("op_cost", &self.op_cost())
.field("depth", &self.depth())
.finish()
}
}
}

impl Ord for Cost {
#[inline]
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
// Break `opcode_cost` ties with `depth`. This means that, when the
// opcode cost is the same, we prefer shallow and wide expressions to
// narrow and deep. For example, `(a + b) + (c + d)` is preferred to
// `((a + b) + c) + d`. This is beneficial because it exposes more
// We make sure that the high bits are the op cost and the low bits are
// the depth. This means that we can use normal integer comparison to
// order by op cost and then depth.
//
// We want to break op cost ties with depth (rather than the other way
// around). When the op cost is the same, we prefer shallow and wide
// expressions to narrow and deep expressions and breaking ties with
// `depth` gives us that. For example, `(a + b) + (c + d)` is preferred
// to `((a + b) + c) + d`. This is beneficial because it exposes more
// instruction-level parallelism and shortens live ranges.
self.opcode_cost
.cmp(&other.opcode_cost)
.then_with(|| self.depth.cmp(&other.depth))
self.0.cmp(&other.0)
}
}

Expand All @@ -58,48 +71,53 @@ impl PartialOrd for Cost {
}

impl Cost {
const DEPTH_BITS: u8 = 8;
const DEPTH_MASK: u32 = (1 << Self::DEPTH_BITS) - 1;
const OP_COST_MASK: u32 = !Self::DEPTH_MASK;
const MAX_OP_COST: u32 = Self::OP_COST_MASK >> Self::DEPTH_BITS;

pub(crate) fn infinity() -> Cost {
// 2^32 - 1 is, uh, pretty close to infinite... (we use `Cost`
// only for heuristics and always saturate so this suffices!)
Cost {
opcode_cost: u32::MAX,
depth: u32::MAX,
}
Cost(u32::MAX)
}

pub(crate) fn zero() -> Cost {
Cost {
opcode_cost: 0,
depth: 0,
}
Cost(0)
}

pub(crate) fn new(opcode_cost: u32) -> Cost {
let cost = Cost {
opcode_cost,
depth: 0,
};
cost.finite()
fn new(opcode_cost: u32, depth: u8) -> Cost {
debug_assert!(
opcode_cost <= Self::MAX_OP_COST,
"Cost::new: given opcode cost of {opcode_cost} is larger than max of {}",
Self::MAX_OP_COST,
);
Cost((opcode_cost << Self::DEPTH_BITS) | u32::from(depth))
}

fn depth(&self) -> u8 {
let depth = self.0 & Self::DEPTH_MASK;
u8::try_from(depth).unwrap()
}

fn op_cost(&self) -> u32 {
(self.0 & Self::OP_COST_MASK) >> Self::DEPTH_BITS
}

/// Clamp this cost at a "finite" value. Can be used in
/// conjunction with saturating ops to avoid saturating into
/// `infinity()`.
fn finite(self) -> Cost {
Cost {
opcode_cost: std::cmp::min(u32::MAX - 1, self.opcode_cost),
depth: std::cmp::min(u32::MAX - 1, self.depth),
}
Cost(std::cmp::min(u32::MAX - 1, self.0))
}

/// Compute the cost of the operation and its given operands.
///
/// Caller is responsible for checking that the opcode came from an instruction
/// that satisfies `inst_predicates::is_pure_for_egraph()`.
pub(crate) fn of_pure_op(op: Opcode, operand_costs: impl IntoIterator<Item = Self>) -> Self {
let mut c: Self = pure_op_cost(op) + operand_costs.into_iter().sum();
c.depth = c.depth.saturating_add(1);
c
let c = pure_op_cost(op) + operand_costs.into_iter().sum();
Cost::new(c.op_cost(), c.depth().saturating_add(1)).finite()
}
}

Expand All @@ -123,11 +141,12 @@ impl std::ops::Add<Cost> for Cost {
type Output = Cost;

fn add(self, other: Cost) -> Cost {
let cost = Cost {
opcode_cost: self.opcode_cost.saturating_add(other.opcode_cost),
depth: std::cmp::max(self.depth, other.depth),
};
cost.finite()
let op_cost = std::cmp::min(
self.op_cost().saturating_add(other.op_cost()),
Self::MAX_OP_COST,
);
let depth = std::cmp::max(self.depth(), other.depth());
Cost::new(op_cost, depth).finite()
}
}

Expand All @@ -138,11 +157,11 @@ impl std::ops::Add<Cost> for Cost {
fn pure_op_cost(op: Opcode) -> Cost {
match op {
// Constants.
Opcode::Iconst | Opcode::F32const | Opcode::F64const => Cost::new(1),
Opcode::Iconst | Opcode::F32const | Opcode::F64const => Cost::new(1, 0),

// Extends/reduces.
Opcode::Uextend | Opcode::Sextend | Opcode::Ireduce | Opcode::Iconcat | Opcode::Isplit => {
Cost::new(2)
Cost::new(2, 0)
}

// "Simple" arithmetic.
Expand All @@ -154,9 +173,9 @@ fn pure_op_cost(op: Opcode) -> Cost {
| Opcode::Bnot
| Opcode::Ishl
| Opcode::Ushr
| Opcode::Sshr => Cost::new(3),
| Opcode::Sshr => Cost::new(3, 0),

// Everything else (pure.)
_ => Cost::new(4),
_ => Cost::new(4, 0),
}
}

0 comments on commit 3d031f6

Please sign in to comment.