Skip to content

Commit

Permalink
Implement more cases for getMaxBits (#2879)
Browse files Browse the repository at this point in the history
- Complete 64-bit cases in range `AddInt64` ... `ShrSInt64`
- `ExtendSInt32` and `ExtendUInt32` for unary cases
- For binary cases

  - `AddInt32` / `AddInt64`
  - `MulInt32` / `MulInt64`
  - `RemUInt32` / `RemUInt64`
  - `RemSInt32` / `RemSInt64`
  - `DivUInt32` / `DivUInt64`
  - `DivSInt32` / `DivSInt64`
  - and more

Also more fast paths for some getMaxBits calculations
  • Loading branch information
MaxGraey authored Sep 17, 2020
1 parent 6116553 commit 2d47c0b
Show file tree
Hide file tree
Showing 6 changed files with 719 additions and 40 deletions.
187 changes: 172 additions & 15 deletions src/ir/bits.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,35 +128,85 @@ struct DummyLocalInfoProvider {
template<typename LocalInfoProvider = DummyLocalInfoProvider>
Index getMaxBits(Expression* curr,
LocalInfoProvider* localInfoProvider = nullptr) {
if (auto* const_ = curr->dynCast<Const>()) {
if (auto* c = curr->dynCast<Const>()) {
switch (curr->type.getBasic()) {
case Type::i32:
return 32 - const_->value.countLeadingZeroes().geti32();
return 32 - c->value.countLeadingZeroes().geti32();
case Type::i64:
return 64 - const_->value.countLeadingZeroes().geti64();
return 64 - c->value.countLeadingZeroes().geti64();
default:
WASM_UNREACHABLE("invalid type");
}
} else if (auto* binary = curr->dynCast<Binary>()) {
switch (binary->op) {
// 32-bit
case AddInt32:
case SubInt32:
case MulInt32:
case DivSInt32:
case DivUInt32:
case RemSInt32:
case RemUInt32:
case RotLInt32:
case RotRInt32:
case SubInt32:
return 32;
case AddInt32: {
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
auto maxBitsRight = getMaxBits(binary->right, localInfoProvider);
return std::min(Index(32), std::max(maxBitsLeft, maxBitsRight) + 1);
}
case MulInt32: {
auto maxBitsRight = getMaxBits(binary->right, localInfoProvider);
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
return std::min(Index(32), maxBitsLeft + maxBitsRight);
}
case DivSInt32: {
if (auto* c = binary->right->dynCast<Const>()) {
int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
// If either side might be negative, then the result will be negative
if (maxBitsLeft == 32 || c->value.geti32() < 0) {
return 32;
}
int32_t bitsRight = getMaxBits(c);
return std::max(0, maxBitsLeft - bitsRight + 1);
}
return 32;
}
case DivUInt32: {
int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
if (auto* c = binary->right->dynCast<Const>()) {
int32_t bitsRight = getMaxBits(c);
return std::max(0, maxBitsLeft - bitsRight + 1);
}
return maxBitsLeft;
}
case RemSInt32: {
if (auto* c = binary->right->dynCast<Const>()) {
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
// if maxBitsLeft is negative
if (maxBitsLeft == 32) {
return 32;
}
auto bitsRight = Index(CeilLog2(c->value.geti32()));
return std::min(maxBitsLeft, bitsRight);
}
return 32;
}
case RemUInt32: {
if (auto* c = binary->right->dynCast<Const>()) {
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
auto bitsRight = Index(CeilLog2(c->value.geti32()));
return std::min(maxBitsLeft, bitsRight);
}
return 32;
case AndInt32:
}
case AndInt32: {
return std::min(getMaxBits(binary->left, localInfoProvider),
getMaxBits(binary->right, localInfoProvider));
}
case OrInt32:
case XorInt32:
return std::max(getMaxBits(binary->left, localInfoProvider),
getMaxBits(binary->right, localInfoProvider));
case XorInt32: {
auto maxBits = getMaxBits(binary->right, localInfoProvider);
// if maxBits is negative
if (maxBits == 32) {
return 32;
}
return std::max(getMaxBits(binary->left, localInfoProvider), maxBits);
}
case ShlInt32: {
if (auto* shifts = binary->right->dynCast<Const>()) {
return std::min(Index(32),
Expand All @@ -178,6 +228,7 @@ Index getMaxBits(Expression* curr,
case ShrSInt32: {
if (auto* shift = binary->right->dynCast<Const>()) {
auto maxBits = getMaxBits(binary->left, localInfoProvider);
// if maxBits is negative
if (maxBits == 32) {
return 32;
}
Expand All @@ -188,7 +239,105 @@ Index getMaxBits(Expression* curr,
}
return 32;
}
// 64-bit TODO
case RotLInt64:
case RotRInt64:
case SubInt64:
return 64;
case AddInt64: {
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
auto maxBitsRight = getMaxBits(binary->right, localInfoProvider);
return std::min(Index(64), std::max(maxBitsLeft, maxBitsRight) + 1);
}
case MulInt64: {
auto maxBitsRight = getMaxBits(binary->right, localInfoProvider);
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
return std::min(Index(64), maxBitsLeft + maxBitsRight);
}
case DivSInt64: {
if (auto* c = binary->right->dynCast<Const>()) {
int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
// if maxBitsLeft or right const value is negative
if (maxBitsLeft == 64 || c->value.geti64() < 0) {
return 64;
}
int32_t bitsRight = getMaxBits(c);
return std::max(0, maxBitsLeft - bitsRight + 1);
}
return 64;
}
case DivUInt64: {
int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
if (auto* c = binary->right->dynCast<Const>()) {
int32_t bitsRight = getMaxBits(c);
return std::max(0, maxBitsLeft - bitsRight + 1);
}
return maxBitsLeft;
}
case RemSInt64: {
if (auto* c = binary->right->dynCast<Const>()) {
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
// if maxBitsLeft is negative
if (maxBitsLeft == 64) {
return 64;
}
auto bitsRight = Index(CeilLog2(c->value.geti64()));
return std::min(maxBitsLeft, bitsRight);
}
return 64;
}
case RemUInt64: {
if (auto* c = binary->right->dynCast<Const>()) {
auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
auto bitsRight = Index(CeilLog2(c->value.geti64()));
return std::min(maxBitsLeft, bitsRight);
}
return 64;
}
case AndInt64: {
auto maxBits = getMaxBits(binary->right, localInfoProvider);
return std::min(getMaxBits(binary->left, localInfoProvider), maxBits);
}
case OrInt64:
case XorInt64: {
auto maxBits = getMaxBits(binary->right, localInfoProvider);
// if maxBits is negative
if (maxBits == 64) {
return 64;
}
return std::max(getMaxBits(binary->left, localInfoProvider), maxBits);
}
case ShlInt64: {
if (auto* shifts = binary->right->dynCast<Const>()) {
auto maxBits = getMaxBits(binary->left, localInfoProvider);
return std::min(Index(64),
Bits::getEffectiveShifts(shifts) + maxBits);
}
return 64;
}
case ShrUInt64: {
if (auto* shift = binary->right->dynCast<Const>()) {
auto maxBits = getMaxBits(binary->left, localInfoProvider);
auto shifts =
std::min(Index(Bits::getEffectiveShifts(shift)),
maxBits); // can ignore more shifts than zero us out
return std::max(Index(0), maxBits - shifts);
}
return 64;
}
case ShrSInt64: {
if (auto* shift = binary->right->dynCast<Const>()) {
auto maxBits = getMaxBits(binary->left, localInfoProvider);
// if maxBits is negative
if (maxBits == 64) {
return 64;
}
auto shifts =
std::min(Index(Bits::getEffectiveShifts(shift)),
maxBits); // can ignore more shifts than zero us out
return std::max(Index(0), maxBits - shifts);
}
return 64;
}
// comparisons
case EqInt32:
case NeInt32:
Expand All @@ -200,6 +349,7 @@ Index getMaxBits(Expression* curr,
case GtUInt32:
case GeSInt32:
case GeUInt32:

case EqInt64:
case NeInt64:
case LtSInt64:
Expand All @@ -210,12 +360,14 @@ Index getMaxBits(Expression* curr,
case GtUInt64:
case GeSInt64:
case GeUInt64:

case EqFloat32:
case NeFloat32:
case LtFloat32:
case LeFloat32:
case GtFloat32:
case GeFloat32:

case EqFloat64:
case NeFloat64:
case LtFloat64:
Expand All @@ -240,7 +392,12 @@ Index getMaxBits(Expression* curr,
case EqZInt64:
return 1;
case WrapInt64:
case ExtendUInt32:
return std::min(Index(32), getMaxBits(unary->value, localInfoProvider));
case ExtendSInt32: {
auto maxBits = getMaxBits(unary->value, localInfoProvider);
return maxBits == 32 ? Index(64) : maxBits;
}
default: {
}
}
Expand Down
8 changes: 8 additions & 0 deletions src/support/bits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,14 @@ template<> int CountLeadingZeroes<uint64_t>(uint64_t v) {
#endif
}

template<> int CeilLog2<uint32_t>(uint32_t v) {
return 32 - CountLeadingZeroes(v - 1);
}

template<> int CeilLog2<uint64_t>(uint64_t v) {
return 64 - CountLeadingZeroes(v - 1);
}

uint32_t Log2(uint32_t v) {
switch (v) {
default:
Expand Down
6 changes: 6 additions & 0 deletions src/support/bits.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ template<typename T> int PopCount(T);
template<typename T> uint32_t BitReverse(T);
template<typename T> int CountTrailingZeroes(T);
template<typename T> int CountLeadingZeroes(T);
template<typename T> int CeilLog2(T);

#ifndef wasm_support_bits_definitions
// The template specializations are provided elsewhere.
Expand All @@ -52,6 +53,8 @@ extern template int CountTrailingZeroes(uint32_t);
extern template int CountTrailingZeroes(uint64_t);
extern template int CountLeadingZeroes(uint32_t);
extern template int CountLeadingZeroes(uint64_t);
extern template int CeilLog2(uint32_t);
extern template int CeilLog2(uint64_t);
#endif

// Convenience signed -> unsigned. It usually doesn't make much sense to use bit
Expand All @@ -65,6 +68,9 @@ template<typename T> int CountTrailingZeroes(T v) {
template<typename T> int CountLeadingZeroes(T v) {
return CountLeadingZeroes(typename std::make_unsigned<T>::type(v));
}
template<typename T> int CeilLog2(T v) {
return CeilLog2(typename std::make_unsigned<T>::type(v));
}
template<typename T> bool IsPowerOf2(T v) {
return v != 0 && (v & (v - 1)) == 0;
}
Expand Down
Loading

0 comments on commit 2d47c0b

Please sign in to comment.