Skip to content

Commit

Permalink
C2: optimize constant addends in masked sums
Browse files Browse the repository at this point in the history
Extends the optimization of masked sums introduced in openjdk#6697 to cover constant values, which currently break the optimization.

Such constant values arise in an expression of the following form, for example from MemorySegmentImpl#isAlignedForElement:

(base + (index + 1) << 8) & 255
=> MulNode
(base + (index << 8 + 256)) & 255
=> AddNode
((base + index << 8) + 256) & 255

Currently, "256" is not being recognized as a shifted value. This PR enables:

((base + index << 8) + 256) & 255
=> MulNode
(base + index << 8) & 255
=> MulNode (PR openjdk#6697)
base & 255
  • Loading branch information
mernst-github committed Dec 21, 2024
1 parent 43b7e9f commit bd77da8
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 79 deletions.
144 changes: 69 additions & 75 deletions src/hotspot/share/opto/mulnode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ const Type *AndINode::mul_ring( const Type *t0, const Type *t1 ) const {

const Type* AndINode::Value(PhaseGVN* phase) const {
// patterns similar to (v << 2) & 3
if (AndIL_shift_and_mask_is_always_zero(phase, in(1), in(2), T_INT, true)) {
if (AndIL_is_always_zero(phase, in(1), in(2), T_INT, true)) {
return TypeInt::ZERO;
}

Expand Down Expand Up @@ -719,7 +719,7 @@ Node* AndINode::Identity(PhaseGVN* phase) {
//------------------------------Ideal------------------------------------------
Node *AndINode::Ideal(PhaseGVN *phase, bool can_reshape) {
// pattern similar to (v1 + (v2 << 2)) & 3 transformed to v1 & 3
Node* progress = AndIL_add_shift_and_mask(phase, T_INT);
Node* progress = AndIL_sum_and_mask(phase, T_INT);
if (progress != nullptr) {
return progress;
}
Expand Down Expand Up @@ -803,7 +803,7 @@ const Type *AndLNode::mul_ring( const Type *t0, const Type *t1 ) const {

const Type* AndLNode::Value(PhaseGVN* phase) const {
// patterns similar to (v << 2) & 3
if (AndIL_shift_and_mask_is_always_zero(phase, in(1), in(2), T_LONG, true)) {
if (AndIL_is_always_zero(phase, in(1), in(2), T_LONG, true)) {
return TypeLong::ZERO;
}

Expand Down Expand Up @@ -851,7 +851,7 @@ Node* AndLNode::Identity(PhaseGVN* phase) {
//------------------------------Ideal------------------------------------------
Node *AndLNode::Ideal(PhaseGVN *phase, bool can_reshape) {
// pattern similar to (v1 + (v2 << 2)) & 3 transformed to v1 & 3
Node* progress = AndIL_add_shift_and_mask(phase, T_LONG);
Node* progress = AndIL_sum_and_mask(phase, T_LONG);
if (progress != nullptr) {
return progress;
}
Expand Down Expand Up @@ -2052,94 +2052,88 @@ const Type* RotateRightNode::Value(PhaseGVN* phase) const {
}
}

// Given an expression (AndX shift mask) or (AndX mask shift),
// Returns a lower bound of the number of trailing zeros in expr.
jint MulNode::AndIL_min_trailing_zeros(PhaseGVN* phase, Node* expr, BasicType bt) {
expr = expr->uncast();
if (expr == nullptr) {
return 0;
}
const TypeInteger* type = phase->type(expr)->isa_integer(bt);
if (type == nullptr) {
return 0;
}

if (type->is_con()) {
long con = type->get_con_as_long(type->basic_type());
return con == 0L ? 0 : count_trailing_zeros(con);
}

if (expr->Opcode() == Op_ConvI2L) {
expr = expr->in(1);
if (expr == nullptr) {
return 0;
}
expr = expr->uncast();
if (expr == nullptr) {
return 0;
}
type = phase->type(expr)->isa_int();
}

if (expr->Opcode() == Op_LShift(type->basic_type())) {
Node* rhs = expr->in(2);
if (rhs == nullptr) {
return 0;
}
const TypeInt* rhs_t = phase->type(rhs)->isa_int();
if (!rhs_t || !rhs_t->is_con()) {
return 0;
}
return rhs_t->get_con() & ((type->isa_int() ? BitsPerJavaInteger : BitsPerJavaLong) - 1);
}

return 0;
}

// Given an expression (AndX expr mask) or (AndX mask expr),
// determine if the AndX must always produce zero, because the
// the shift (x<<N) is bitwise disjoint from the mask #M.
// expr is bitwise disjoint from the mask.
// The X in AndX must be I or L, depending on bt.
// Specifically, the following cases fold to zero,
// when the shift value N is large enough to zero out
// all the set positions of the and-mask M.
// (AndI (LShiftI _ #N) #M) => #0
// (AndL (LShiftL _ #N) #M) => #0
// (AndL (ConvI2L (LShiftI _ #N)) #M) => #0
// as well as for constant operands:
// (AndI (ConI [+-] _ << #N) #M) => #0
// (AndL (ConL [+-] _ << #N) #M) => #0
// The M and N values must satisfy ((-1 << N) & M) == 0.
// Because the optimization might work for a non-constant
// mask M, we check the AndX for both operand orders.
bool MulNode::AndIL_shift_and_mask_is_always_zero(PhaseGVN* phase, Node* shift, Node* mask, BasicType bt, bool check_reverse) {
if (mask == nullptr || shift == nullptr) {
// mask M, we check for both operand orders.
bool MulNode::AndIL_is_always_zero(PhaseGVN* phase, Node* expr, Node* mask, BasicType bt, bool check_reverse) {
if (mask == nullptr || expr == nullptr) {
return false;
}
const TypeInteger* mask_t = phase->type(mask)->isa_integer(bt);
if (mask_t == nullptr || phase->type(shift)->isa_integer(bt) == nullptr) {
if (mask_t == nullptr) {
return false;
}
shift = shift->uncast();
if (shift == nullptr) {
return false;
}
if (phase->type(shift)->isa_integer(bt) == nullptr) {
return false;
}
BasicType shift_bt = bt;
if (bt == T_LONG && shift->Opcode() == Op_ConvI2L) {
bt = T_INT;
Node* val = shift->in(1);
if (val == nullptr) {
return false;
}
val = val->uncast();
if (val == nullptr) {
return false;
}
if (val->Opcode() == Op_LShiftI) {
shift_bt = T_INT;
shift = val;
if (phase->type(shift)->isa_integer(bt) == nullptr) {
return false;
}
}
}
if (shift->Opcode() != Op_LShift(shift_bt)) {
if (check_reverse &&
(mask->Opcode() == Op_LShift(bt) ||
(bt == T_LONG && mask->Opcode() == Op_ConvI2L))) {
// try it the other way around
return AndIL_shift_and_mask_is_always_zero(phase, mask, shift, bt, false);
}
return false;
}
Node* shift2 = shift->in(2);
if (shift2 == nullptr) {
return false;
}
const Type* shift2_t = phase->type(shift2);
if (!shift2_t->isa_int() || !shift2_t->is_int()->is_con()) {
return false;
}

jint shift_con = shift2_t->is_int()->get_con() & ((shift_bt == T_INT ? BitsPerJavaInteger : BitsPerJavaLong) - 1);
if ((((jlong)1) << shift_con) > mask_t->hi_as_long() && mask_t->lo_as_long() >= 0) {
return true;
jint zeros = AndIL_min_trailing_zeros(phase, expr, bt);
if (zeros == 0) {
// try it the other way around
return check_reverse && AndIL_is_always_zero(phase, mask, expr, bt, false);
}

return false;
return ((((jlong)1) << zeros) > mask_t->hi_as_long() && mask_t->lo_as_long() >= 0);
}

// Given an expression (AndX (AddX v1 (LShiftX v2 #N)) #M)
// determine if the AndX must always produce (AndX v1 #M),
// because the shift (v2<<N) is bitwise disjoint from the mask #M.
// The X in AndX will be I or L, depending on bt.
// Specifically, the following cases fold,
// when the shift value N is large enough to zero out
// all the set positions of the and-mask M.
// (AndI (AddI v1 (LShiftI _ #N)) #M) => (AndI v1 #M)
// (AndL (AddI v1 (LShiftL _ #N)) #M) => (AndL v1 #M)
// (AndL (AddL v1 (ConvI2L (LShiftI _ #N))) #M) => (AndL v1 #M)
// The M and N values must satisfy ((-1 << N) & M) == 0.
// Because the optimization might work for a non-constant
// mask M, and because the AddX operands can come in either
// order, we check for every operand order.
Node* MulNode::AndIL_add_shift_and_mask(PhaseGVN* phase, BasicType bt) {
// Given an expression (AndX (AddX v1 v2) mask)
// determine if the AndX must always produce (AndX v1 mask),
// because v2 is bitwise disjoint from the mask.
// Because the AddX operands can come in either
// order, we check for both orders.
Node* MulNode::AndIL_sum_and_mask(PhaseGVN* phase, BasicType bt) {
Node* add = in(1);
Node* mask = in(2);
if (add == nullptr || mask == nullptr) {
Expand All @@ -2157,10 +2151,10 @@ Node* MulNode::AndIL_add_shift_and_mask(PhaseGVN* phase, BasicType bt) {
Node* add1 = add->in(1);
Node* add2 = add->in(2);
if (add1 != nullptr && add2 != nullptr) {
if (AndIL_shift_and_mask_is_always_zero(phase, add1, mask, bt, false)) {
if (AndIL_is_always_zero(phase, add1, mask, bt, false)) {
set_req_X(addidx, add2, phase);
return this;
} else if (AndIL_shift_and_mask_is_always_zero(phase, add2, mask, bt, false)) {
} else if (AndIL_is_always_zero(phase, add2, mask, bt, false)) {
set_req_X(addidx, add1, phase);
return this;
}
Expand Down
5 changes: 3 additions & 2 deletions src/hotspot/share/opto/mulnode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ class MulNode : public Node {

static MulNode* make(Node* in1, Node* in2, BasicType bt);

static bool AndIL_shift_and_mask_is_always_zero(PhaseGVN* phase, Node* shift, Node* mask, BasicType bt, bool check_reverse);
Node* AndIL_add_shift_and_mask(PhaseGVN* phase, BasicType bt);
static jint AndIL_min_trailing_zeros(PhaseGVN* phase, Node* addend, BasicType bt);
static bool AndIL_is_always_zero(PhaseGVN* phase, Node* expr, Node* mask, BasicType bt, bool check_reverse);
Node* AndIL_sum_and_mask(PhaseGVN* phase, BasicType bt);
};

//------------------------------MulINode---------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions test/hotspot/jtreg/compiler/c2/irTests/TestShiftAndMask.java
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ public static void checkShiftNonConstMaskLong(long res) {
@IR(counts = { IRNode.AND_I, "1" })
@IR(failOn = { IRNode.ADD_I, IRNode.LSHIFT_I })
public static int addShiftMaskInt(int i, int j) {
return (j + (i << 2)) & 3; // transformed to: return j & 3;
return (j + ((i + 1) << 2)) & 3; // transformed to: return j & 3;
}

@Run(test = "addShiftMaskInt")
Expand Down Expand Up @@ -165,7 +165,7 @@ public static void addSshiftNonConstMaskInt_runner() {
@IR(counts = { IRNode.AND_L, "1" })
@IR(failOn = { IRNode.ADD_L, IRNode.LSHIFT_L })
public static long addShiftMaskLong(long i, long j) {
return (j + (i << 2)) & 3; // transformed to: return j & 3;
return (j + ((i - 3) << 2)) & 3; // transformed to: return j & 3;
}

@Run(test = "addShiftMaskLong")
Expand Down

0 comments on commit bd77da8

Please sign in to comment.