-
Notifications
You must be signed in to change notification settings - Fork 3.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add BasicDecimal256 Multiplication Support (PR for decimal256 branch, not master) #8344
Changes from 17 commits
825f24e
43cec0e
7f22b19
0725cf6
6c1b7b1
4db42dd
2dac4d6
8b272a9
cd50114
ac700d9
3d5f74b
a21131b
cdacfaa
a5140d8
23abc2a
dbc7266
40b4503
8774ef2
0c6ab8e
0f0c907
1d3d624
f2854a6
41431e0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -122,8 +122,6 @@ static const BasicDecimal128 ScaleMultipliersHalf[] = { | |
|
||
#ifdef ARROW_USE_NATIVE_INT128 | ||
static constexpr uint64_t kInt64Mask = 0xFFFFFFFFFFFFFFFF; | ||
#else | ||
static constexpr uint64_t kIntMask = 0xFFFFFFFF; | ||
#endif | ||
|
||
// same as ScaleMultipliers[38] - 1 | ||
|
@@ -254,69 +252,148 @@ BasicDecimal128& BasicDecimal128::operator>>=(uint32_t bits) { | |
|
||
namespace { | ||
|
||
// TODO: Remove this guard once it's used by BasicDecimal256 | ||
#ifndef ARROW_USE_NATIVE_INT128 | ||
// This method losslessly multiplies x and y into a 128 bit unsigned integer | ||
// whose high bits will be stored in hi and low bits in lo. | ||
void ExtendAndMultiplyUint64(uint64_t x, uint64_t y, uint64_t* hi, uint64_t* lo) { | ||
#ifdef ARROW_USE_NATIVE_INT128 | ||
const __uint128_t r = static_cast<__uint128_t>(x) * y; | ||
*lo = r & kInt64Mask; | ||
*hi = r >> 64; | ||
#else | ||
// If we can't use a native fallback, perform multiplication | ||
// by splitting up x and y into 32 bit high/low bit components, | ||
// Multiply two N bit word components into a 2*N bit result, with high bits | ||
// stored in hi and low bits in lo. | ||
template <typename Word> | ||
void ExtendAndMultiplyUint(Word x, Word y, Word* hi, Word* lo) { | ||
// Perform multiplication on two N bit words x and y into a 2*N bit result | ||
// by splitting up x and y into N/2 bit high/low bit components, | ||
// allowing us to represent the multiplication as | ||
// x * y = x_lo * y_lo + x_hi * y_lo * 2^32 + y_hi * x_lo * 2^32 | ||
// + x_hi * y_hi * 2^64. | ||
// x * y = x_lo * y_lo + x_hi * y_lo * 2^N/2 + y_hi * x_lo * 2^N/2 | ||
// + x_hi * y_hi * 2^N | ||
// | ||
// Now, consider the final output as lo_lo || lo_hi || hi_lo || hi_hi. | ||
// Now, consider the final output as lo_lo || lo_hi || hi_lo || hi_hi | ||
// Therefore, | ||
// lo_lo is (x_lo * y_lo)_lo, | ||
// lo_hi is ((x_lo * y_lo)_hi + (x_hi * y_lo)_lo + (x_lo * y_hi)_lo)_lo, | ||
// hi_lo is ((x_hi * y_hi)_lo + (x_hi * y_lo)_hi + (x_lo * y_hi)_hi)_hi, | ||
// hi_hi is (x_hi * y_hi)_hi | ||
const uint64_t x_lo = x & kIntMask; | ||
const uint64_t y_lo = y & kIntMask; | ||
const uint64_t x_hi = x >> 32; | ||
const uint64_t y_hi = y >> 32; | ||
constexpr Word kHighBitShift = sizeof(Word) * 4; | ||
constexpr Word kLowBitMask = (static_cast<Word>(1) << kHighBitShift) - 1; | ||
|
||
const uint64_t t = x_lo * y_lo; | ||
const uint64_t t_lo = t & kIntMask; | ||
const uint64_t t_hi = t >> 32; | ||
const Word x_lo = x & kLowBitMask; | ||
const Word y_lo = y & kLowBitMask; | ||
const Word x_hi = x >> kHighBitShift; | ||
const Word y_hi = y >> kHighBitShift; | ||
|
||
const uint64_t u = x_hi * y_lo + t_hi; | ||
const uint64_t u_lo = u & kIntMask; | ||
const uint64_t u_hi = u >> 32; | ||
const Word t = x_lo * y_lo; | ||
const Word t_lo = t & kLowBitMask; | ||
const Word t_hi = t >> kHighBitShift; | ||
|
||
const uint64_t v = x_lo * y_hi + u_lo; | ||
const uint64_t v_hi = v >> 32; | ||
const Word u = x_hi * y_lo + t_hi; | ||
const Word u_lo = u & kLowBitMask; | ||
const Word u_hi = u >> kHighBitShift; | ||
|
||
const Word v = x_lo * y_hi + u_lo; | ||
const Word v_hi = v >> kHighBitShift; | ||
|
||
*hi = x_hi * y_hi + u_hi + v_hi; | ||
*lo = (v << 32) | t_lo; | ||
#endif | ||
*lo = (v << kHighBitShift) + t_lo; | ||
} | ||
|
||
// Convenience wrapper type over 128 bit unsigned integers | ||
#ifdef ARROW_USE_NATIVE_INT128 | ||
struct uint128_t { | ||
Luminarys marked this conversation as resolved.
Show resolved
Hide resolved
|
||
uint128_t() {} | ||
uint128_t(uint64_t hi, uint64_t lo) : val_((static_cast<__uint128_t>(hi) << 64) | lo) {} | ||
uint128_t(const BasicDecimal128& decimal) { | ||
val_ = (static_cast<__uint128_t>(decimal.high_bits()) << 64) | decimal.low_bits(); | ||
} | ||
|
||
uint64_t hi() { return val_ >> 64; } | ||
uint64_t lo() { return val_ & kInt64Mask; } | ||
|
||
uint128_t& operator+=(const uint128_t& other) { | ||
val_ += other.val_; | ||
return *this; | ||
} | ||
|
||
__uint128_t val_; | ||
}; | ||
|
||
uint128_t operator*(const uint128_t& left, const uint128_t& right) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please try defining operator*= instead of operator*. Maybe this can help the compiler generate more efficient code. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This (or perhaps some other change I made) seems to have improved performance significantly, it takes 13 ns~ with native int128 and 40 ns~ with uint64 fallback. |
||
uint128_t r; | ||
r.val_ = left.val_ * right.val_; | ||
return r; | ||
} | ||
#else | ||
struct uint128_t { | ||
uint128_t() {} | ||
uint128_t(uint64_t hi, uint64_t lo) : hi_(hi), lo_(lo) {} | ||
uint128_t(const BasicDecimal128& decimal) { | ||
hi_ = decimal.high_bits(); | ||
lo_ = decimal.low_bits(); | ||
} | ||
|
||
uint64_t hi() const { return hi_; } | ||
uint64_t lo() const { return lo_; } | ||
|
||
uint128_t& operator+=(const uint128_t& other) { | ||
// To deduce the carry bit, we perform "65 bit" addition on the low bits and | ||
// see if the resulting high bit is 1. This is accomplished by shifting the | ||
// low bits to the right by 1 (chopping off the lowest bit), then adding 1 if the | ||
// result of adding the two chopped bits would have produced a carry. | ||
uint64_t carry = (((lo_ & other.lo_) & 1) + (lo_ >> 1) + (other.lo_ >> 1)) >> 63; | ||
hi_ += other.hi_ + carry; | ||
lo_ += other.lo_; | ||
return *this; | ||
} | ||
|
||
uint64_t hi_; | ||
uint64_t lo_; | ||
}; | ||
|
||
uint128_t operator*(const uint128_t& left, const uint128_t& right) { | ||
uint128_t r; | ||
ExtendAndMultiplyUint(left.lo_, right.lo_, &r.hi_, &r.lo_); | ||
r.hi_ += (left.hi_ * right.lo_) + (left.lo_ * right.hi_); | ||
return r; | ||
} | ||
#endif | ||
|
||
void MultiplyUint128(uint64_t x_hi, uint64_t x_lo, uint64_t y_hi, uint64_t y_lo, | ||
uint64_t* hi, uint64_t* lo) { | ||
void ExtendAndMultiplyUint128(uint128_t x, uint128_t y, uint128_t* hi, uint128_t* lo) { | ||
Luminarys marked this conversation as resolved.
Show resolved
Hide resolved
|
||
#ifdef ARROW_USE_NATIVE_INT128 | ||
const __uint128_t x = (static_cast<__uint128_t>(x_hi) << 64) | x_lo; | ||
const __uint128_t y = (static_cast<__uint128_t>(y_hi) << 64) | y_lo; | ||
const __uint128_t r = x * y; | ||
*lo = r & kInt64Mask; | ||
*hi = r >> 64; | ||
return ExtendAndMultiplyUint(x.val_, y.val_, &hi->val_, &lo->val_); | ||
#else | ||
// To perform 128 bit multiplication without a native fallback | ||
// we first perform lossless 64 bit multiplication of the low | ||
// bits, and then add x_hi * y_lo and x_lo * y_hi to the high | ||
// bits. Note that we can skip adding x_hi * y_hi because it | ||
// always will be over 128 bits. | ||
ExtendAndMultiplyUint64(x_lo, y_lo, hi, lo); | ||
*hi += (x_hi * y_lo) + (x_lo * y_hi); | ||
// This follows the same algorithm as in ExtendAndMultiplyUint, but must | ||
// perform manual overflow checks. | ||
ExtendAndMultiplyUint(x.hi_, y.hi_, &hi->hi_, &hi->lo_); | ||
ExtendAndMultiplyUint(x.lo_, y.lo_, &lo->hi_, &lo->lo_); | ||
|
||
uint128_t t; | ||
ExtendAndMultiplyUint(x.hi_, y.lo_, &t.hi_, &t.lo_); | ||
lo->hi_ += t.lo_; | ||
// Check for overflow in lo.hi | ||
if (lo->hi_ < t.lo_) { | ||
hi->lo_++; | ||
} | ||
hi->lo_ += t.hi_; | ||
// Check for overflow in hi.lo | ||
if (hi->lo_ < t.hi_) { | ||
hi->hi_++; | ||
} | ||
|
||
ExtendAndMultiplyUint(x.lo_, y.hi_, &t.hi_, &t.lo_); | ||
lo->hi_ += t.lo_; | ||
// Check for overflow in lo.hi | ||
if (lo->hi_ < t.lo_) { | ||
hi->lo_++; | ||
} | ||
hi->lo_ += t.hi_; | ||
// Check for overflow in hi.lo | ||
if (hi->lo_ < t.hi_) { | ||
hi->hi_++; | ||
} | ||
#endif | ||
} | ||
|
||
void MultiplyUint256(uint128_t x_hi, uint128_t x_lo, uint128_t y_hi, uint128_t y_lo, | ||
uint128_t* hi, uint128_t* lo) { | ||
ExtendAndMultiplyUint128(x_lo, y_lo, hi, lo); | ||
*hi += x_hi * y_lo; | ||
*hi += x_lo * y_hi; | ||
} | ||
|
||
} // namespace | ||
|
||
BasicDecimal128& BasicDecimal128::operator*=(const BasicDecimal128& right) { | ||
|
@@ -325,10 +402,9 @@ BasicDecimal128& BasicDecimal128::operator*=(const BasicDecimal128& right) { | |
const bool negate = Sign() != right.Sign(); | ||
BasicDecimal128 x = BasicDecimal128::Abs(*this); | ||
BasicDecimal128 y = BasicDecimal128::Abs(right); | ||
uint64_t hi; | ||
MultiplyUint128(x.high_bits(), x.low_bits(), y.high_bits(), y.low_bits(), &hi, | ||
&low_bits_); | ||
high_bits_ = hi; | ||
uint128_t r = uint128_t(x) * uint128_t(y); | ||
high_bits_ = r.hi(); | ||
low_bits_ = r.lo(); | ||
if (negate) { | ||
Negate(); | ||
} | ||
|
@@ -800,6 +876,13 @@ BasicDecimal256& BasicDecimal256::Negate() { | |
return *this; | ||
} | ||
|
||
BasicDecimal256& BasicDecimal256::Abs() { return *this < 0 ? Negate() : *this; } | ||
|
||
BasicDecimal256 BasicDecimal256::Abs(const BasicDecimal256& in) { | ||
BasicDecimal256 result(in); | ||
return result.Abs(); | ||
} | ||
|
||
std::array<uint8_t, 32> BasicDecimal256::ToBytes() const { | ||
std::array<uint8_t, 32> out{{0}}; | ||
ToBytes(out.data()); | ||
|
@@ -821,12 +904,41 @@ void BasicDecimal256::ToBytes(uint8_t* out) const { | |
#endif | ||
} | ||
|
||
BasicDecimal256& BasicDecimal256::operator*=(const BasicDecimal256& right) { | ||
// Since the max value of BasicDecimal256 is supposed to be 1e76 - 1 and the | ||
// min value its negation, taking the absolute values here should always be safe. | ||
const bool negate = Sign() != right.Sign(); | ||
BasicDecimal256 x = BasicDecimal256::Abs(*this); | ||
BasicDecimal256 y = BasicDecimal256::Abs(right); | ||
|
||
uint128_t r_hi; | ||
uint128_t r_lo; | ||
MultiplyUint256({x.little_endian_array_[3], x.little_endian_array_[2]}, | ||
{x.little_endian_array_[1], x.little_endian_array_[0]}, | ||
{y.little_endian_array_[3], y.little_endian_array_[2]}, | ||
{y.little_endian_array_[1], y.little_endian_array_[0]}, &r_hi, &r_lo); | ||
little_endian_array_[0] = r_lo.lo(); | ||
little_endian_array_[1] = r_lo.hi(); | ||
little_endian_array_[2] = r_hi.lo(); | ||
little_endian_array_[3] = r_hi.hi(); | ||
if (negate) { | ||
Negate(); | ||
} | ||
return *this; | ||
} | ||
|
||
DecimalStatus BasicDecimal256::Rescale(int32_t original_scale, int32_t new_scale, | ||
BasicDecimal256* out) const { | ||
// TODO: implement. | ||
return DecimalStatus::kSuccess; | ||
} | ||
|
||
BasicDecimal256 operator*(const BasicDecimal256& left, const BasicDecimal256& right) { | ||
BasicDecimal256 result = left; | ||
result *= right; | ||
return result; | ||
} | ||
|
||
bool operator==(const BasicDecimal256& left, const BasicDecimal256& right) { | ||
return left.little_endian_array() == right.little_endian_array(); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's simpler if this method handles only uint64_t, and there is another method that takes std::array<uint64_t, n> and uses for loops like https://github.com/google/zetasql/blob/master/zetasql/common/multiprecision_int.h#L723. This way, ExtendAndMultiplyUint128 doesn't need to repeat the similar pattern.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. This saves a lot of code, though does take 60 ns for multiplication as opposed to 20 ns prior.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you try making ExtendAndMultiplyUint inline?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I realized that I wasn't using the native path prior, which is why the benchmark was so slow. Updated, new results are 32 ns when __uint128_t is used and 65 ns when uint64_t is used, which I think is more reasonable.