Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

8346664: C2: Optimize mask check with constant offset #22856

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 57 additions & 74 deletions src/hotspot/share/opto/mulnode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -670,9 +670,13 @@ const Type *AndINode::mul_ring( const Type *t0, const Type *t1 ) const {
return and_value<TypeInt>(r0, r1);
}

// Is expr a neutral element wrt addition under mask?
static bool AndIL_is_zero_element(const PhaseGVN* phase, const Node* expr, const Node* mask, BasicType bt);

const Type* AndINode::Value(PhaseGVN* phase) const {
// patterns similar to (v << 2) & 3
if (AndIL_is_always_zero(phase, in(1), in(2), T_INT, true)) {
if (AndIL_is_zero_element(phase, in(1), in(2), T_INT) ||
AndIL_is_zero_element(phase, in(2), in(1), T_INT)) {
Comment on lines +678 to +679
Copy link
Author

@mernst-github mernst-github Dec 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @mernst-github, thanks for making a comment in an OpenJDK project!

All comments and discussions in the OpenJDK Community must be made available under the OpenJDK Terms of Use. If you already are an OpenJDK Author, Committer or Reviewer, please click here to open a new issue so that we can record that fact. Please Use "Add GitHub user mernst-github" for the summary.

If you are not an OpenJDK Author, Committer or Reviewer, simply check the box below to accept the OpenJDK Terms of Use for your comments.

Your comment will be automatically restored once you have accepted the OpenJDK Terms of Use.

return TypeInt::ZERO;
}

Expand Down Expand Up @@ -803,7 +807,8 @@ const Type *AndLNode::mul_ring( const Type *t0, const Type *t1 ) const {

const Type* AndLNode::Value(PhaseGVN* phase) const {
// patterns similar to (v << 2) & 3
if (AndIL_is_always_zero(phase, in(1), in(2), T_LONG, true)) {
if (AndIL_is_zero_element(phase, in(1), in(2), T_LONG) ||
AndIL_is_zero_element(phase, in(2), in(1), T_LONG)) {
return TypeLong::ZERO;
}

Expand Down Expand Up @@ -2052,54 +2057,76 @@ const Type* RotateRightNode::Value(PhaseGVN* phase) const {
}
}

// Returns a lower bound of the number of trailing zeros in expr.
jint MulNode::AndIL_min_trailing_zeros(PhaseGVN* phase, Node* expr, BasicType bt) {
expr = expr->uncast();
if (expr == nullptr) {
return 0;
// Given an expression (AndX (AddX v1 v2) mask)
// determine if the AndX must always produce (AndX v1 mask),
// because v2 is zero wrt addition under mask.
// Because the AddX operands can come in either
// order, we check for both orders.
Node* MulNode::AndIL_sum_and_mask(PhaseGVN* phase, BasicType bt) {
Node* add = in(1);
Node* mask = in(2);
if (add == nullptr || mask == nullptr) {
return nullptr;
}
int addidx = 0;
if (add->Opcode() == Op_Add(bt)) {
addidx = 1;
} else if (mask->Opcode() == Op_Add(bt)) {
mask = add;
addidx = 2;
add = in(addidx);
}
if (addidx > 0) {
Node* add1 = add->in(1);
Node* add2 = add->in(2);
if (add1 != nullptr && add2 != nullptr) {
if (AndIL_is_zero_element(phase, add1, mask, bt)) {
set_req_X(addidx, add2, phase);
return this;
} else if (AndIL_is_zero_element(phase, add2, mask, bt)) {
set_req_X(addidx, add1, phase);
return this;
}
}
}
return nullptr;
}

// Returns a lower bound on the number of trailing zeros in expr.
static jint AndIL_min_trailing_zeros(const PhaseGVN* phase, const Node* expr, BasicType bt) {
expr = expr->uncast();
const TypeInteger* type = phase->type(expr)->isa_integer(bt);
if (type == nullptr) {
return 0;
}

if (type->is_con()) {
long con = type->get_con_as_long(type->basic_type());
return con == 0L ? 0 : count_trailing_zeros(con);
return con == 0L ? (type2aelembytes(bt) * BitsPerByte) : count_trailing_zeros(con);
}

if (expr->Opcode() == Op_ConvI2L) {
expr = expr->in(1);
if (expr == nullptr) {
return 0;
}
expr = expr->uncast();
if (expr == nullptr) {
return 0;
}
expr = expr->in(1)->uncast();
bt = T_INT;
type = phase->type(expr)->isa_int();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are trying to look through a ConvI2L, I think for the sake of consistency, you can reassign bt to T_INT at this point.

Copy link
Author

@mernst-github mernst-github Dec 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @mernst-github, thanks for making a comment in an OpenJDK project!

All comments and discussions in the OpenJDK Community must be made available under the OpenJDK Terms of Use. If you already are an OpenJDK Author, Committer or Reviewer, please click here to open a new issue so that we can record that fact. Please Use "Add GitHub user mernst-github" for the summary.

If you are not an OpenJDK Author, Committer or Reviewer, simply check the box below to accept the OpenJDK Terms of Use for your comments.

Your comment will be automatically restored once you have accepted the OpenJDK Terms of Use.

}

if (expr->Opcode() == Op_LShift(type->basic_type())) {
Node* rhs = expr->in(2);
if (rhs == nullptr) {
const TypeInt* rhs_t = phase->type(expr->in(2))->isa_int();
if (rhs_t == nullptr || !rhs_t->is_con()) {
return 0;
}
const TypeInt* rhs_t = phase->type(rhs)->isa_int();
if (!rhs_t || !rhs_t->is_con()) {
return 0;
}
return rhs_t->get_con() & ((type->isa_int() ? BitsPerJavaInteger : BitsPerJavaLong) - 1);
return rhs_t->get_con() % (type2aelembytes(bt) * BitsPerByte);
}

return 0;
}

// Given an expression (AndX expr mask) or (AndX mask expr),
// determine if the AndX must always produce zero, because the
// expr is bitwise disjoint from the mask.
// Given an expression (AndX X+expr mask), determine
// whether expr is neutral wrt addition under mask
// and hence the result is always equivalent to (AndX X mask),
// The X in AndX must be I or L, depending on bt.
// Specifically, the following cases fold to zero,
// Specifically, this holds for the following cases,
// when the shift value N is large enough to zero out
// all the set positions of the and-mask M.
// (AndI (LShiftI _ #N) #M) => #0
Expand All @@ -2109,56 +2136,12 @@ jint MulNode::AndIL_min_trailing_zeros(PhaseGVN* phase, Node* expr, BasicType bt
// (AndI (ConI [+-] _ << #N) #M) => #0
// (AndL (ConL [+-] _ << #N) #M) => #0
// The M and N values must satisfy ((-1 << N) & M) == 0.
// Because the optimization might work for a non-constant
// mask M, we check for both operand orders.
bool MulNode::AndIL_is_always_zero(PhaseGVN* phase, Node* expr, Node* mask, BasicType bt, bool check_reverse) {
if (mask == nullptr || expr == nullptr) {
return false;
}
static bool AndIL_is_zero_element(const PhaseGVN* phase, const Node* expr, const Node* mask, BasicType bt) {
const TypeInteger* mask_t = phase->type(mask)->isa_integer(bt);
if (mask_t == nullptr) {
return false;
}
jint zeros = AndIL_min_trailing_zeros(phase, expr, bt);
if (zeros == 0) {
// try it the other way around
return check_reverse && AndIL_is_always_zero(phase, mask, expr, bt, false);
}

return ((((jlong)1) << zeros) > mask_t->hi_as_long() && mask_t->lo_as_long() >= 0);
}

// Given an expression (AndX (AddX v1 v2) mask)
// determine if the AndX must always produce (AndX v1 mask),
// because v2 is bitwise disjoint from the mask.
// Because the AddX operands can come in either
// order, we check for both orders.
Node* MulNode::AndIL_sum_and_mask(PhaseGVN* phase, BasicType bt) {
Node* add = in(1);
Node* mask = in(2);
if (add == nullptr || mask == nullptr) {
return nullptr;
}
int addidx = 0;
if (add->Opcode() == Op_Add(bt)) {
addidx = 1;
} else if (mask->Opcode() == Op_Add(bt)) {
mask = add;
addidx = 2;
add = in(addidx);
}
if (addidx > 0) {
Node* add1 = add->in(1);
Node* add2 = add->in(2);
if (add1 != nullptr && add2 != nullptr) {
if (AndIL_is_always_zero(phase, add1, mask, bt, false)) {
set_req_X(addidx, add2, phase);
return this;
} else if (AndIL_is_always_zero(phase, add2, mask, bt, false)) {
set_req_X(addidx, add1, phase);
return this;
}
}
}
return nullptr;
jint zeros = AndIL_min_trailing_zeros(phase, expr, bt);
return zeros > 0 && ((((jlong)1) << zeros) > mask_t->hi_as_long() && mask_t->lo_as_long() >= 0);
}
3 changes: 1 addition & 2 deletions src/hotspot/share/opto/mulnode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ class MulNode : public Node {

static MulNode* make(Node* in1, Node* in2, BasicType bt);

static jint AndIL_min_trailing_zeros(PhaseGVN* phase, Node* addend, BasicType bt);
static bool AndIL_is_always_zero(PhaseGVN* phase, Node* expr, Node* mask, BasicType bt, bool check_reverse);
protected:
Node* AndIL_sum_and_mask(PhaseGVN* phase, BasicType bt);
};

Expand Down
40 changes: 37 additions & 3 deletions test/hotspot/jtreg/compiler/c2/irTests/TestShiftAndMask.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

/*
* @test
* @bug 8277850 8278949 8285793
* @bug 8277850 8278949 8285793 8346664
* @summary C2: optimize mask checks in counted loops
* @library /test/lib /
* @run driver compiler.c2.irTests.TestShiftAndMask
Expand Down Expand Up @@ -120,7 +120,7 @@ public static void checkShiftNonConstMaskLong(long res) {
@IR(counts = { IRNode.AND_I, "1" })
@IR(failOn = { IRNode.ADD_I, IRNode.LSHIFT_I })
public static int addShiftMaskInt(int i, int j) {
return (j + ((i + 1) << 2)) & 3; // transformed to: return j & 3;
return (j + (i << 2)) & 3; // transformed to: return j & 3;
}

@Run(test = "addShiftMaskInt")
Expand All @@ -133,6 +133,23 @@ public static void addShiftMaskInt_runner() {
}
}

@Test
@IR(counts = { IRNode.AND_I, "1" })
@IR(failOn = { IRNode.ADD_I, IRNode.LSHIFT_I })
public static int addShiftPlusConstMaskInt(int i, int j) {
return (j + ((i + 5) << 2)) & 3; // transformed to: return j & 3;
}

@Run(test = "addShiftPlusConstMaskInt")
public static void addShiftPlusConstMaskInt_runner() {
int i = RANDOM.nextInt();
int j = RANDOM.nextInt();
int res = addShiftPlusConstMaskInt(i, j);
if (res != (j & 3)) {
throw new RuntimeException("incorrect result: " + res);
}
}

@Test
@IR(counts = { IRNode.AND_I, "1" })
@IR(failOn = { IRNode.ADD_I, IRNode.LSHIFT_I })
Expand Down Expand Up @@ -165,7 +182,7 @@ public static void addSshiftNonConstMaskInt_runner() {
@IR(counts = { IRNode.AND_L, "1" })
@IR(failOn = { IRNode.ADD_L, IRNode.LSHIFT_L })
public static long addShiftMaskLong(long i, long j) {
return (j + ((i - 3) << 2)) & 3; // transformed to: return j & 3;
return (j + (i << 2)) & 3; // transformed to: return j & 3;
}

@Run(test = "addShiftMaskLong")
Expand All @@ -178,6 +195,23 @@ public static void addShiftMaskLong_runner() {
}
}

@Test
@IR(counts = { IRNode.AND_L, "1" })
@IR(failOn = { IRNode.ADD_L, IRNode.LSHIFT_L })
public static long addShiftPlusConstMaskLong(long i, long j) {
return (j + ((i - 5) << 2)) & 3; // transformed to: return j & 3;
}

@Run(test = "addShiftPlusConstMaskLong")
public static void addShiftPlusConstMaskLong_runner() {
long i = RANDOM.nextLong();
long j = RANDOM.nextLong();
long res = addShiftPlusConstMaskLong(i, j);
if (res != (j & 3)) {
throw new RuntimeException("incorrect result: " + res);
}
}

@Test
@IR(counts = { IRNode.AND_L, "1" })
@IR(failOn = { IRNode.ADD_L, IRNode.LSHIFT_L })
Expand Down