Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

8277850: C2: optimize mask checks in counted loops #6697

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 92 additions & 1 deletion src/hotspot/share/opto/mulnode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,15 @@ const Type *AndINode::mul_ring( const Type *t0, const Type *t1 ) const {
return TypeInt::INT; // No constants to be had
}

// Computes the type of this AndI, folding it to the constant zero when one
// input is a left shift whose cleared low bits cover every bit the other
// input (the mask) can have set, e.g. (v << 2) & 3.
// Returns TypeInt::ZERO in that case, otherwise whatever MulNode::Value yields.
const Type* AndINode::Value(PhaseGVN* phase) const {
  // The helper takes (mask, shift). AndI is commutative and the mask does not
  // have to be a constant (the helper only looks at its lo/hi bounds), so the
  // shift may sit on either input: try both operand orders.
  if (AndIL_shift_and_mask(phase, in(2), in(1), T_INT) ||
      AndIL_shift_and_mask(phase, in(1), in(2), T_INT)) {
    return TypeInt::ZERO;
  }

  // No shift/mask fold applies; defer to the shared MulNode implementation.
  return MulNode::Value(phase);
}

//------------------------------Identity---------------------------------------
// Masking off the high bits of an unsigned load is not required
Node* AndINode::Identity(PhaseGVN* phase) {
Expand Down Expand Up @@ -598,6 +607,12 @@ Node *AndINode::Ideal(PhaseGVN *phase, bool can_reshape) {
phase->type(load->in(1)) == TypeInt::ZERO )
return new AndINode( load->in(2), in(2) );

// pattern similar to (v1 + (v2 << 2)) & 3 transformed to v1 & 3
Node* progress = AndIL_add_shift_and_mask(phase, T_INT);
if (progress != NULL) {
return progress;
}

return MulNode::Ideal(phase, can_reshape);
}

Expand Down Expand Up @@ -629,6 +644,15 @@ const Type *AndLNode::mul_ring( const Type *t0, const Type *t1 ) const {
return TypeLong::LONG; // No constants to be had
}

// Computes the type of this AndL, folding it to the constant zero when one
// input is a left shift whose cleared low bits cover every bit the other
// input (the mask) can have set, e.g. (v << 2) & 3.
// Returns TypeLong::ZERO in that case, otherwise whatever MulNode::Value yields.
const Type* AndLNode::Value(PhaseGVN* phase) const {
  // The helper takes (mask, shift). AndL is commutative and the mask does not
  // have to be a constant (the helper only looks at its lo/hi bounds), so the
  // shift may sit on either input: try both operand orders.
  if (AndIL_shift_and_mask(phase, in(2), in(1), T_LONG) ||
      AndIL_shift_and_mask(phase, in(1), in(2), T_LONG)) {
    return TypeLong::ZERO;
  }

  // No shift/mask fold applies; defer to the shared MulNode implementation.
  return MulNode::Value(phase);
}

//------------------------------Identity---------------------------------------
// Masking off the high bits of an unsigned load is not required
Node* AndLNode::Identity(PhaseGVN* phase) {
Expand Down Expand Up @@ -675,7 +699,7 @@ Node *AndLNode::Ideal(PhaseGVN *phase, bool can_reshape) {
const jlong mask = t2->get_con();

Node* in1 = in(1);
uint op = in1->Opcode();
int op = in1->Opcode();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got curious and checked the code, we have similar patterns everywhere. Filed JDK-8278328 to clean up this mess.


// Are we masking a long that was converted from an int with a mask
// that fits in 32-bits? Commute them and use an AndINode. Don't
Expand Down Expand Up @@ -705,6 +729,12 @@ Node *AndLNode::Ideal(PhaseGVN *phase, bool can_reshape) {
}
}

// pattern similar to (v1 + (v2 << 2)) & 3 transformed to v1 & 3
Node* progress = AndIL_add_shift_and_mask(phase, T_LONG);
if (progress != NULL) {
return progress;
}

return MulNode::Ideal(phase, can_reshape);
}

Expand Down Expand Up @@ -1683,3 +1713,64 @@ const Type* RotateRightNode::Value(PhaseGVN* phase) const {
return TypeLong::LONG;
}
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment here too, describing which pattern it is looking for.

// Helper method to transform:
// patterns similar to (v << 2) & 3 to 0
// and
// patterns similar to (v1 + (v2 << 2)) & 3 transformed to v1 & 3
bool MulNode::AndIL_shift_and_mask(PhaseGVN* phase, Node* mask, Node* shift, BasicType bt) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. The operand order is distracting and perhaps confusing, since mask is in(2) and shift is in(1).
  2. The mask is not necessarily in(2), because it doesn't have to be a constant; it can be a bounded value.
  3. I wrote an expanded comment.
  4. There's a redundant NULL check.
  5. You should use a different name for the adjusted bt value.

Here are the changes I suggest to this function:

// Given an expression (AndX shift mask) or (AndX mask shift),
// determine if the AndX must always produce zero, because the
// shift (x<<N) is bitwise disjoint from the mask #M.
// The X in AndX must be I or L, depending on bt.
// Specifically, the following cases fold to zero,
// when the shift value N is large enough to zero out
// all the set positions of the and-mask M.
//   (AndI (LShiftI _ #N) #M) => #0
//   (AndL (LShiftL _ #N) #M) => #0
//   (AndL (ConvI2L (LShiftI _ #N)) #M) => #0
// The M and N values must satisfy ((-1 << N) & M) == 0.
// Because the optimization might work for a non-constant
// mask M, we check the AndX for both operand orders.
//
// Parameters:
//   phase         - used to look up the types of the inspected nodes
//   shift, mask   - the two AndX inputs, in the order the caller assumes
//   bt            - T_INT or T_LONG, the basic type of the AndX
//   check_reverse - when true and 'shift' turns out not to be a shift while
//                   'mask' looks like one, retry once with the operands swapped
// Returns true only when the AndX is proven to be zero.
bool MulNode::AndIL_shift_and_mask(PhaseGVN* phase, Node* shift, Node* mask, BasicType bt, bool check_reverse) const {
  if (mask == NULL || shift == NULL) {
    return false;
  }
  // isa_integer() returns NULL unless the node has an integer type of kind bt,
  // so this type check also filters out inputs without a usable integer type.
  const TypeInteger* mask_t = phase->type(mask)->isa_integer(bt);
  const TypeInteger* shift_t = phase->type(shift)->isa_integer(bt);
  if (mask_t == NULL || shift_t == NULL) {
    return false;
  }
  // Look through a ConvI2L onto an int shift. The shift then operates at int
  // width, which is tracked separately in shift_bt: bt must keep describing
  // the AndX itself because the reverse check below still needs it.
  BasicType shift_bt = bt;
  if (bt == T_LONG && shift->Opcode() == Op_ConvI2L) {
    Node* val = shift->in(1);
    if (val == NULL) {
      return false;
    }
    if (val->Opcode() == Op_LShiftI) {
      shift_bt = T_INT;
      shift = val;
    }
  }
  if (shift->Opcode() != Op_LShift(shift_bt)) {
    if (check_reverse &&
        (mask->Opcode() == Op_LShift(bt) ||
         (bt == T_LONG && mask->Opcode() == Op_ConvI2L))) {
      // Try it the other way around. The two operands really have to trade
      // places here, and check_reverse is cleared so we retry at most once.
      return AndIL_shift_and_mask(phase, mask, shift, bt, false);
    }
    return false;
  }
  Node* shift2 = shift->in(2);
  if (shift2 == NULL) {
    return false;
  }
  // Only a constant shift count lets us know which low bits are cleared.
  const Type* shift2_t = phase->type(shift2);
  if (!shift2_t->isa_int() || !shift2_t->is_int()->is_con()) {
    return false;
  }

  // Only the low bits of the count take part in the shift (count & (width - 1)).
  jint shift_con = shift2_t->is_int()->get_con() & ((shift_bt == T_INT ? BitsPerJavaInteger : BitsPerJavaLong) - 1);
  // The mask must lie in [0, 2^shift_con). The one is widened to jlong before
  // shifting: a plain 1L is only 32 bits wide on LLP64 platforms, while
  // shift_con can be as large as 63.
  if ((((jlong)1) << shift_con) > mask_t->hi_as_long() && mask_t->lo_as_long() >= 0) {
    return true;
  }

  return false;
}

Suggest signature change:

if (mask == NULL || shift == NULL) {
return false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to check shift for TOP.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code below:

const TypeInteger* shift_t = phase->type(shift)->isa_integer(bt);
  if (mask_t == NULL || shift_t == NULL) {

catches the case where shift is top, I think.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right. It is a type check.

}
const TypeInteger* mask_t = phase->type(mask)->isa_integer(bt);
const TypeInteger* shift_t = phase->type(shift)->isa_integer(bt);
if (mask_t == NULL || shift_t == NULL) {
return false;
}
if (bt == T_LONG && shift != NULL && shift->Opcode() == Op_ConvI2L) {
bt = T_INT;
shift = shift->in(1);
if (shift == NULL) {
return false;
}
}
if (shift->Opcode() != Op_LShift(bt)) {
return false;
}
Node* shift2 = shift->in(2);
if (shift2 == NULL) {
return false;
}
const Type* shift2_t = phase->type(shift2);
if (!shift2_t->isa_int() || !shift2_t->is_int()->is_con()) {
return false;
}

jint shift_con = shift2_t->is_int()->get_con() & ((bt == T_INT ? BitsPerJavaInteger : BitsPerJavaLong) - 1);
if ((((jlong)1) << shift_con) > mask_t->hi_as_long() && mask_t->lo_as_long() >= 0) {
return true;
}

return false;
}

// Helper method to transform:
// patterns similar to (v1 + (v2 << 2)) & 3 to v1 & 3
Node* MulNode::AndIL_add_shift_and_mask(PhaseGVN* phase, BasicType bt) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a comment describing the pattern.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar comments for this function also.

// Given an expression (AndX (AddX v1 (LShiftX v2 #N)) #M),
// decide whether the AndX can be narrowed to (AndX v1 #M) because the
// shifted addend (v2<<N) is bitwise disjoint from the mask #M and so
// cannot influence any bit that survives the mask.
// The X in AndX is I or L, depending on bt.
// The folds performed, whenever N is large enough to clear every bit
// position that M can have set, are:
//   (AndI (AddI v1 (LShiftI _ #N)) #M) => (AndI v1 #M)
//   (AndL (AddL v1 (LShiftL _ #N)) #M) => (AndL v1 #M)
//   (AndL (AddL v1 (ConvI2L (LShiftI _ #N))) #M) => (AndL v1 #M)
// M and N must satisfy ((-1 << N) & M) == 0.
// The mask M need not be a constant and the AddX operands can come in
// either order, so every operand order is examined.
// Returns this node after rewiring its AddX input to v1, or NULL when
// nothing was changed.
Node* MulNode::AndIL_add_shift_and_mask(PhaseGVN* phase, BasicType bt) {
  Node* first  = in(1);
  Node* second = in(2);
  if (first == NULL || second == NULL) {
    return NULL;
  }

  // Find which AndX input is the AddX; the remaining input plays the mask.
  int   add_idx;
  Node* sum;
  Node* and_mask;
  if (first->Opcode() == Op_Add(bt)) {
    add_idx  = 1;
    sum      = first;
    and_mask = second;
  } else if (second->Opcode() == Op_Add(bt)) {
    add_idx  = 2;
    sum      = second;
    and_mask = first;
  } else {
    return NULL;
  }

  Node* lhs = sum->in(1);
  Node* rhs = sum->in(2);
  if (lhs == NULL || rhs == NULL) {
    return NULL;
  }

  // Whichever addend is the disjoint shift gets dropped; the other survives.
  // The reverse check is disabled because the operand roles are fixed here.
  Node* survivor;
  if (AndIL_shift_and_mask(phase, lhs, and_mask, bt, false)) {
    survivor = rhs;
  } else if (AndIL_shift_and_mask(phase, rhs, and_mask, bt, false)) {
    survivor = lhs;
  } else {
    return NULL;
  }

  set_req_X(add_idx, survivor, phase);
  return this;
}

Node* in1 = in(1);
Node* in2 = in(2);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The caller has already determined that in2 is a constant (const int mask = t2->get_con()), so you can pass that value in here.

if (in1 != NULL && in2 != NULL && in1->Opcode() == Op_Add(bt)) {
Node* add1 = in1->in(1);
Node* add2 = in1->in(2);
if (add1 != NULL && add2 != NULL) {
if (AndIL_shift_and_mask(phase, in2, add1, bt)) {
set_req_X(1, add2, phase);
return this;
} else if (AndIL_shift_and_mask(phase, in2, add2, bt)) {
set_req_X(1, add1, phase);
return this;
}
}
}
return NULL;
}
5 changes: 5 additions & 0 deletions src/hotspot/share/opto/mulnode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ class MulNode : public Node {
virtual int min_opcode() const = 0;

static MulNode* make(Node* in1, Node* in2, BasicType bt);

static bool AndIL_shift_and_mask(PhaseGVN* phase, Node* mask, Node* shift, BasicType bt);
Node* AndIL_add_shift_and_mask(PhaseGVN* phase, BasicType bt);
};

//------------------------------MulINode---------------------------------------
Expand Down Expand Up @@ -189,6 +192,7 @@ class AndINode : public MulINode {
virtual int Opcode() const;
virtual Node *Ideal(PhaseGVN *phase, bool can_reshape);
virtual Node* Identity(PhaseGVN* phase);
virtual const Type* Value(PhaseGVN* phase) const;
virtual const Type *mul_ring( const Type *, const Type * ) const;
const Type *mul_id() const { return TypeInt::MINUS_1; }
const Type *add_id() const { return TypeInt::ZERO; }
Expand All @@ -208,6 +212,7 @@ class AndLNode : public MulLNode {
virtual int Opcode() const;
virtual Node *Ideal(PhaseGVN *phase, bool can_reshape);
virtual Node* Identity(PhaseGVN* phase);
virtual const Type* Value(PhaseGVN* phase) const;
virtual const Type *mul_ring( const Type *, const Type * ) const;
const Type *mul_id() const { return TypeLong::MINUS_1; }
const Type *add_id() const { return TypeLong::ZERO; }
Expand Down
108 changes: 108 additions & 0 deletions test/hotspot/jtreg/compiler/c2/irTests/TestShiftAndMask.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright (c) 2021, Red Hat, Inc. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/

package compiler.c2.irTests;

import compiler.lib.ir_framework.*;

/*
* @test
* @bug 8277850
* @summary C2: optimize mask checks in counted loops
* @library /test/lib /
* @run driver compiler.c2.irTests.TestShiftAndMask
*/

public class TestShiftAndMask {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice tests!

// Entry point used by the jtreg "@run driver" line above: starts the IR test
// framework via TestFramework.run().
public static void main(String[] args) {
TestFramework.run();
}

// int case of (v << 2) & 3: the shift leaves the two low bits clear, so the
// masked value is always 0. The IR rule lists AndI and LShiftI under failOn,
// i.e. neither node is expected to remain in the compiled method.
@Test
@Arguments(Argument.RANDOM_EACH)
@IR(failOn = { IRNode.AND_I, IRNode.LSHIFT_I })
public static int shiftMaskInt(int i) {
return (i << 2) & 3; // transformed to: return 0;
}

// long case of (v << 2) & 3: the shift leaves the two low bits clear, so the
// masked value is always 0. The IR rule lists AndL and LShiftL under failOn,
// i.e. neither node is expected to remain in the compiled method.
@Test
@Arguments(Argument.RANDOM_EACH)
@IR(failOn = { IRNode.AND_L, IRNode.LSHIFT_L })
public static long shiftMaskLong(long i) {
return (i << 2) & 3; // transformed to: return 0;
}

// int case of (v1 + (v2 << 2)) & 3: the shifted addend cannot affect the two
// low bits of the sum, so only j & 3 is left. The IR rules pair AndI with a
// count of "1" and list AddI and LShiftI under failOn.
@Test
@Arguments({Argument.RANDOM_EACH, Argument.RANDOM_EACH})
@IR(counts = { IRNode.AND_I, "1" })
@IR(failOn = { IRNode.ADD_I, IRNode.LSHIFT_I })
public static int addShiftMaskInt(int i, int j) {
return (j + (i << 2)) & 3; // transformed to: return j & 3;
}

// long case of (v1 + (v2 << 2)) & 3: the shifted addend cannot affect the two
// low bits of the sum, so only j & 3 is left. The IR rules pair AndL with a
// count of "1" and list AddL and LShiftL under failOn.
@Test
@Arguments({Argument.RANDOM_EACH, Argument.RANDOM_EACH})
@IR(counts = { IRNode.AND_L, "1" })
@IR(failOn = { IRNode.ADD_L, IRNode.LSHIFT_L })
public static long addShiftMaskLong(long i, long j) {
return (j + (i << 2)) & 3; // transformed to: return j & 3;
}

// int case where both addends are shifts by 2: each has its two low bits
// clear, hence so does their sum, and the masked value is always 0.
// The IR rule lists AndI, AddI and LShiftI under failOn.
@Test
@Arguments({Argument.RANDOM_EACH, Argument.RANDOM_EACH})
@IR(failOn = { IRNode.AND_I, IRNode.ADD_I, IRNode.LSHIFT_I })
public static int addShiftMaskInt2(int i, int j) {
return ((j << 2) + (i << 2)) & 3; // transformed to: return 0;
}

// long case where both addends are shifts by 2: each has its two low bits
// clear, hence so does their sum, and the masked value is always 0.
// The IR rule lists AndL, AddL and LShiftL under failOn.
@Test
@Arguments({Argument.RANDOM_EACH, Argument.RANDOM_EACH})
@IR(failOn = { IRNode.AND_L, IRNode.ADD_L, IRNode.LSHIFT_L })
public static long addShiftMaskLong2(long i, long j) {
return ((j << 2) + (i << 2)) & 3; // transformed to: return 0;
}

// Mixed-width case: an int shift is widened to long before the long mask is
// applied. The two low bits are still clear after the widening, so the result
// is always 0. The IR rule lists AndL and LShiftI under failOn.
@Test
@Arguments(Argument.RANDOM_EACH)
@IR(failOn = { IRNode.AND_L, IRNode.LSHIFT_I })
public static long shiftConvMask(int i) {
return ((long)(i << 2)) & 3; // transformed to: return 0;
}

// Mixed-width add case: the int shift (i << 2) is implicitly widened to long
// for the addition with j, then masked with 3; only j & 3 is left.
// The IR rules pair AndL with a count of "1" and list AddL, LShiftI and
// ConvI2L under failOn.
@Test
@Arguments({Argument.RANDOM_EACH, Argument.RANDOM_EACH})
@IR(counts = { IRNode.AND_L, "1" })
@IR(failOn = { IRNode.ADD_L, IRNode.LSHIFT_I, IRNode.CONV_I2L })
public static long addShiftConvMask(int i, long j) {
return (j + (i << 2)) & 3; // transformed to: return j & 3;
}

// Mixed-width case where both addends are int shifts by 2 widened to long:
// each widened value has its two low bits clear, hence so does their sum, and
// the masked value is always 0.
// Both shifts here are int shifts followed by a conversion to long, so the
// nodes that must disappear are LShiftI and ConvI2L (together with AndL and
// AddL). A rule on LShiftL could never fire for this method.
@Test
@Arguments({Argument.RANDOM_EACH, Argument.RANDOM_EACH})
@IR(failOn = { IRNode.AND_L, IRNode.ADD_L, IRNode.LSHIFT_I, IRNode.CONV_I2L })
public static long addShiftConvMask2(int i, int j) {
return (((long)(j << 2)) + ((long)(i << 2))) & 3; // transformed to: return 0;
}

}

8 changes: 8 additions & 0 deletions test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,14 @@ public class IRNode {
public static final String SCOPE_OBJECT = "(.*# ScObj.*" + END;
public static final String MEMBAR = START + "MemBar" + MID + END;

// Match patterns for the and, left-shift, add and int-to-long conversion nodes,
// referenced from the @IR rules of compiler.c2.irTests.TestShiftAndMask.
// Each one is built as START + <node name> + MID + END.
public static final String AND_I = START + "AndI" + MID + END;
public static final String AND_L = START + "AndL" + MID + END;
public static final String LSHIFT_I = START + "LShiftI" + MID + END;
public static final String LSHIFT_L = START + "LShiftL" + MID + END;
public static final String ADD_I = START + "AddI" + MID + END;
public static final String ADD_L = START + "AddL" + MID + END;
public static final String CONV_I2L = START + "ConvI2L" + MID + END;

/**
* Called by {@link IRMatcher} to merge special composite nodes together with additional user-defined input.
*/
Expand Down