Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BasicDecimal256 Multiplication Support (PR for decimal256 branch, not master) #8344

Merged
merged 23 commits into from
Oct 12, 2020
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 161 additions & 49 deletions cpp/src/arrow/util/basic_decimal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,6 @@ static const BasicDecimal128 ScaleMultipliersHalf[] = {

#ifdef ARROW_USE_NATIVE_INT128
static constexpr uint64_t kInt64Mask = 0xFFFFFFFFFFFFFFFF;
#else
static constexpr uint64_t kIntMask = 0xFFFFFFFF;
#endif

// same as ScaleMultipliers[38] - 1
Expand Down Expand Up @@ -254,69 +252,148 @@ BasicDecimal128& BasicDecimal128::operator>>=(uint32_t bits) {

namespace {

// TODO: Remove this guard once it's used by BasicDecimal256
#ifndef ARROW_USE_NATIVE_INT128
// This method losslessly multiplies x and y into a 128 bit unsigned integer
// whose high bits will be stored in hi and low bits in lo.
void ExtendAndMultiplyUint64(uint64_t x, uint64_t y, uint64_t* hi, uint64_t* lo) {
#ifdef ARROW_USE_NATIVE_INT128
const __uint128_t r = static_cast<__uint128_t>(x) * y;
*lo = r & kInt64Mask;
*hi = r >> 64;
#else
// If we can't use a native fallback, perform multiplication
// by splitting up x and y into 32 bit high/low bit components,
// Multiply two N bit word components into a 2*N bit result, with high bits
// stored in hi and low bits in lo.
template <typename Word>
void ExtendAndMultiplyUint(Word x, Word y, Word* hi, Word* lo) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's simpler if this method handles only uint64_t, and there is another method that takes std::array<uint64_t, n> and uses for loops like https://github.com/google/zetasql/blob/master/zetasql/common/multiprecision_int.h#L723. This way, ExtendAndMultiplyUint128 doesn't need to repeat the similar pattern.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. This saves a lot of code, though does take 60 ns for multiplication as opposed to 20 ns prior.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you try making ExtendAndMultiplyUint inline?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realized that I wasn't using the native path prior, which is why the benchmark was so slow. Updated, new results are 32 ns when __uint128_t is used and 65 ns when uint64_t is used, which I think is more reasonable.

// Perform multiplication on two N bit words x and y into a 2*N bit result
// by splitting up x and y into N/2 bit high/low bit components,
// allowing us to represent the multiplication as
// x * y = x_lo * y_lo + x_hi * y_lo * 2^32 + y_hi * x_lo * 2^32
// + x_hi * y_hi * 2^64.
// x * y = x_lo * y_lo + x_hi * y_lo * 2^N/2 + y_hi * x_lo * 2^N/2
// + x_hi * y_hi * 2^N
//
// Now, consider the final output as lo_lo || lo_hi || hi_lo || hi_hi.
// Now, consider the final output as lo_lo || lo_hi || hi_lo || hi_hi
// Therefore,
// lo_lo is (x_lo * y_lo)_lo,
// lo_hi is ((x_lo * y_lo)_hi + (x_hi * y_lo)_lo + (x_lo * y_hi)_lo)_lo,
// hi_lo is ((x_hi * y_hi)_lo + (x_hi * y_lo)_hi + (x_lo * y_hi)_hi)_hi,
// hi_hi is (x_hi * y_hi)_hi
const uint64_t x_lo = x & kIntMask;
const uint64_t y_lo = y & kIntMask;
const uint64_t x_hi = x >> 32;
const uint64_t y_hi = y >> 32;
constexpr Word kHighBitShift = sizeof(Word) * 4;
constexpr Word kLowBitMask = (static_cast<Word>(1) << kHighBitShift) - 1;

const uint64_t t = x_lo * y_lo;
const uint64_t t_lo = t & kIntMask;
const uint64_t t_hi = t >> 32;
const Word x_lo = x & kLowBitMask;
const Word y_lo = y & kLowBitMask;
const Word x_hi = x >> kHighBitShift;
const Word y_hi = y >> kHighBitShift;

const uint64_t u = x_hi * y_lo + t_hi;
const uint64_t u_lo = u & kIntMask;
const uint64_t u_hi = u >> 32;
const Word t = x_lo * y_lo;
const Word t_lo = t & kLowBitMask;
const Word t_hi = t >> kHighBitShift;

const uint64_t v = x_lo * y_hi + u_lo;
const uint64_t v_hi = v >> 32;
const Word u = x_hi * y_lo + t_hi;
const Word u_lo = u & kLowBitMask;
const Word u_hi = u >> kHighBitShift;

const Word v = x_lo * y_hi + u_lo;
const Word v_hi = v >> kHighBitShift;

*hi = x_hi * y_hi + u_hi + v_hi;
*lo = (v << 32) | t_lo;
#endif
*lo = (v << kHighBitShift) + t_lo;
}

// Convenience wrapper type over 128 bit unsigned integers
#ifdef ARROW_USE_NATIVE_INT128
struct uint128_t {
  uint128_t() {}
  uint128_t(uint64_t hi, uint64_t lo) : val_((static_cast<__uint128_t>(hi) << 64) | lo) {}
  uint128_t(const BasicDecimal128& decimal)
      : val_((static_cast<__uint128_t>(decimal.high_bits()) << 64) | decimal.low_bits()) {}

  // Upper and lower 64 bit halves of the native 128 bit value.
  uint64_t hi() { return val_ >> 64; }
  uint64_t lo() { return val_ & kInt64Mask; }

  uint128_t& operator+=(const uint128_t& other) {
    val_ += other.val_;
    return *this;
  }

  __uint128_t val_;
};

uint128_t operator*(const uint128_t& left, const uint128_t& right) {
  uint128_t product;
  product.val_ = left.val_ * right.val_;
  return product;
}
#else
struct uint128_t {
  uint128_t() {}
  uint128_t(uint64_t hi, uint64_t lo) : hi_(hi), lo_(lo) {}
  uint128_t(const BasicDecimal128& decimal)
      : hi_(static_cast<uint64_t>(decimal.high_bits())), lo_(decimal.low_bits()) {}

  uint64_t hi() const { return hi_; }
  uint64_t lo() const { return lo_; }

  // 128 bit addition built from two 64 bit words. A carry out of the low
  // word is detected by checking for wraparound: with unsigned arithmetic,
  // the sum of the low words is smaller than either operand iff it overflowed.
  uint128_t& operator+=(const uint128_t& other) {
    const uint64_t new_lo = lo_ + other.lo_;
    const uint64_t carry = new_lo < lo_ ? 1 : 0;
    hi_ += other.hi_ + carry;
    lo_ = new_lo;
    return *this;
  }

  uint64_t hi_;
  uint64_t lo_;
};

// 128 bit multiplication truncated to 128 bits: the full product of the low
// words plus the two cross terms shifted into the high word. The hi*hi term
// lies entirely above bit 127 and is dropped.
uint128_t operator*(const uint128_t& left, const uint128_t& right) {
  uint128_t product;
  ExtendAndMultiplyUint(left.lo_, right.lo_, &product.hi_, &product.lo_);
  product.hi_ += left.hi_ * right.lo_ + left.lo_ * right.hi_;
  return product;
}
#endif

void MultiplyUint128(uint64_t x_hi, uint64_t x_lo, uint64_t y_hi, uint64_t y_lo,
uint64_t* hi, uint64_t* lo) {
void ExtendAndMultiplyUint128(uint128_t x, uint128_t y, uint128_t* hi, uint128_t* lo) {
Luminarys marked this conversation as resolved.
Show resolved Hide resolved
#ifdef ARROW_USE_NATIVE_INT128
const __uint128_t x = (static_cast<__uint128_t>(x_hi) << 64) | x_lo;
const __uint128_t y = (static_cast<__uint128_t>(y_hi) << 64) | y_lo;
const __uint128_t r = x * y;
*lo = r & kInt64Mask;
*hi = r >> 64;
return ExtendAndMultiplyUint(x.val_, y.val_, &hi->val_, &lo->val_);
#else
// To perform 128 bit multiplication without a native fallback
// we first perform lossless 64 bit multiplication of the low
// bits, and then add x_hi * y_lo and x_lo * y_hi to the high
// bits. Note that we can skip adding x_hi * y_hi because it
// always will be over 128 bits.
ExtendAndMultiplyUint64(x_lo, y_lo, hi, lo);
*hi += (x_hi * y_lo) + (x_lo * y_hi);
// This follows the same algorithm as in ExtendAndMultiplyUint, but must
// perform manual overflow checks.
ExtendAndMultiplyUint(x.hi_, y.hi_, &hi->hi_, &hi->lo_);
ExtendAndMultiplyUint(x.lo_, y.lo_, &lo->hi_, &lo->lo_);

uint128_t t;
ExtendAndMultiplyUint(x.hi_, y.lo_, &t.hi_, &t.lo_);
lo->hi_ += t.lo_;
// Check for overflow in lo.hi
if (lo->hi_ < t.lo_) {
hi->lo_++;
}
hi->lo_ += t.hi_;
// Check for overflow in hi.lo
if (hi->lo_ < t.hi_) {
hi->hi_++;
}

ExtendAndMultiplyUint(x.lo_, y.hi_, &t.hi_, &t.lo_);
lo->hi_ += t.lo_;
// Check for overflow in lo.hi
if (lo->hi_ < t.lo_) {
hi->lo_++;
}
hi->lo_ += t.hi_;
// Check for overflow in hi.lo
if (hi->lo_ < t.hi_) {
hi->hi_++;
}
#endif
}

// Multiply two 256 bit numbers, each given as hi/lo 128 bit halves, into a
// 256 bit result (truncated to 256 bits).
void MultiplyUint256(uint128_t x_hi, uint128_t x_lo, uint128_t y_hi, uint128_t y_lo,
                     uint128_t* hi, uint128_t* lo) {
  // Full 256 bit product of the low halves.
  ExtendAndMultiplyUint128(x_lo, y_lo, hi, lo);
  // The cross terms only affect the high 128 bits; the x_hi * y_hi term lies
  // entirely above bit 255 and is dropped by truncation.
  uint128_t cross = x_hi * y_lo;
  cross += x_lo * y_hi;
  *hi += cross;
}

} // namespace

BasicDecimal128& BasicDecimal128::operator*=(const BasicDecimal128& right) {
Expand All @@ -325,10 +402,9 @@ BasicDecimal128& BasicDecimal128::operator*=(const BasicDecimal128& right) {
const bool negate = Sign() != right.Sign();
BasicDecimal128 x = BasicDecimal128::Abs(*this);
BasicDecimal128 y = BasicDecimal128::Abs(right);
uint64_t hi;
MultiplyUint128(x.high_bits(), x.low_bits(), y.high_bits(), y.low_bits(), &hi,
&low_bits_);
high_bits_ = hi;
uint128_t r = uint128_t(x) * uint128_t(y);
high_bits_ = r.hi();
low_bits_ = r.lo();
if (negate) {
Negate();
}
Expand Down Expand Up @@ -800,6 +876,13 @@ BasicDecimal256& BasicDecimal256::Negate() {
return *this;
}

// Replace the value with its absolute value, in place.
BasicDecimal256& BasicDecimal256::Abs() {
  if (*this < 0) {
    return Negate();
  }
  return *this;
}

// Return the absolute value of `in` without modifying it.
BasicDecimal256 BasicDecimal256::Abs(const BasicDecimal256& in) {
  BasicDecimal256 absolute = in;
  absolute.Abs();
  return absolute;
}

std::array<uint8_t, 32> BasicDecimal256::ToBytes() const {
std::array<uint8_t, 32> out{{0}};
ToBytes(out.data());
Expand All @@ -821,12 +904,41 @@ void BasicDecimal256::ToBytes(uint8_t* out) const {
#endif
}

BasicDecimal256& BasicDecimal256::operator*=(const BasicDecimal256& right) {
  // The max value of BasicDecimal256 is supposed to be 1e76 - 1 and the min
  // its negation, so taking absolute values here should always be safe
  // (no overflow on Negate).
  const bool result_is_negative = Sign() != right.Sign();
  const BasicDecimal256 left_abs = BasicDecimal256::Abs(*this);
  const BasicDecimal256 right_abs = BasicDecimal256::Abs(right);

  // Multiply the magnitudes as unsigned 256 bit numbers; words are stored
  // little-endian, so [3],[2] form the high 128 bits and [1],[0] the low.
  uint128_t product_hi;
  uint128_t product_lo;
  MultiplyUint256({left_abs.little_endian_array_[3], left_abs.little_endian_array_[2]},
                  {left_abs.little_endian_array_[1], left_abs.little_endian_array_[0]},
                  {right_abs.little_endian_array_[3], right_abs.little_endian_array_[2]},
                  {right_abs.little_endian_array_[1], right_abs.little_endian_array_[0]},
                  &product_hi, &product_lo);
  little_endian_array_[0] = product_lo.lo();
  little_endian_array_[1] = product_lo.hi();
  little_endian_array_[2] = product_hi.lo();
  little_endian_array_[3] = product_hi.hi();
  // Restore the sign of the result.
  if (result_is_negative) {
    Negate();
  }
  return *this;
}

// Adjust the value from original_scale to new_scale.
//
// NOTE(review): this is currently an unimplemented stub — `out` is never
// written and kSuccess is returned unconditionally regardless of the scales.
// Callers on this branch must not rely on *out until the real rescaling
// logic lands.
DecimalStatus BasicDecimal256::Rescale(int32_t original_scale, int32_t new_scale,
                                       BasicDecimal256* out) const {
  // TODO: implement.
  return DecimalStatus::kSuccess;
}

// Non-mutating multiplication, implemented in terms of operator*=.
BasicDecimal256 operator*(const BasicDecimal256& left, const BasicDecimal256& right) {
  BasicDecimal256 product(left);
  product *= right;
  return product;
}

// Two decimals are equal iff all four 64 bit words match.
bool operator==(const BasicDecimal256& left, const BasicDecimal256& right) {
  const auto& left_words = left.little_endian_array();
  const auto& right_words = right.little_endian_array();
  return left_words == right_words;
}
Expand Down
24 changes: 22 additions & 2 deletions cpp/src/arrow/util/basic_decimal.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ class ARROW_EXPORT BasicDecimal128 {
BasicDecimal128& operator>>=(uint32_t bits);

/// \brief Get the high bits of the two's complement representation of the number.
inline int64_t high_bits() const { return high_bits_; }
inline constexpr int64_t high_bits() const { return high_bits_; }

/// \brief Get the low bits of the two's complement representation of the number.
inline uint64_t low_bits() const { return low_bits_; }
inline constexpr uint64_t low_bits() const { return low_bits_; }

/// \brief Return the raw bytes of the value in native-endian byte order.
std::array<uint8_t, 16> ToBytes() const;
Expand Down Expand Up @@ -195,13 +195,23 @@ class ARROW_EXPORT BasicDecimal256 {
: little_endian_array_({static_cast<uint64_t>(value), extend(value), extend(value),
extend(value)}) {}

/// \brief Convert a BasicDecimal128 into a BasicDecimal256. The two lower
/// 64 bit words are copied from `value`; the two upper words are produced by
/// extend(value.high_bits()) — expected to sign-extend the 128 bit value
/// (confirm against extend()'s definition). Intentionally implicit to allow
/// widening conversions.
constexpr BasicDecimal256(BasicDecimal128 value) noexcept
    : little_endian_array_({value.low_bits(), static_cast<uint64_t>(value.high_bits()),
                            extend(value.high_bits()), extend(value.high_bits())}) {}

/// \brief Create a BasicDecimal256 from an array of bytes. Bytes are assumed to be in
/// native-endian byte order.
explicit BasicDecimal256(const uint8_t* bytes);

/// \brief Negate the current value (in-place)
BasicDecimal256& Negate();

/// \brief Absolute value (in-place)
BasicDecimal256& Abs();

/// \brief Absolute value
static BasicDecimal256 Abs(const BasicDecimal256& left);

/// \brief Get the bits of the two's complement representation of the number. The 4
/// elements are in little endian order. The bits within each uint64_t element are in
/// native endian order. For example,
Expand All @@ -220,6 +230,13 @@ class ARROW_EXPORT BasicDecimal256 {
DecimalStatus Rescale(int32_t original_scale, int32_t new_scale,
BasicDecimal256* out) const;

/// \brief Return -1 for negative values and +1 otherwise (zero reports +1).
/// The right shift smears the top word's sign bit across all 64 bits
/// (relying on arithmetic shift of the signed operand), then OR with 1
/// forces the result to exactly -1 or +1.
inline int64_t Sign() const {
  return 1 | (static_cast<int64_t>(little_endian_array_[3]) >> 63);
}

/// \brief Multiply this number by another number. The result is truncated to 256 bits.
BasicDecimal256& operator*=(const BasicDecimal256& right);

private:
template <typename T>
inline static constexpr uint64_t extend(T low_bits) noexcept {
Expand All @@ -234,4 +251,7 @@ ARROW_EXPORT bool operator<(const BasicDecimal256& left, const BasicDecimal256&
ARROW_EXPORT bool operator<=(const BasicDecimal256& left, const BasicDecimal256& right);
ARROW_EXPORT bool operator>(const BasicDecimal256& left, const BasicDecimal256& right);
ARROW_EXPORT bool operator>=(const BasicDecimal256& left, const BasicDecimal256& right);

ARROW_EXPORT BasicDecimal256 operator*(const BasicDecimal256& left,
const BasicDecimal256& right);
} // namespace arrow
5 changes: 5 additions & 0 deletions cpp/src/arrow/util/decimal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -721,4 +721,9 @@ Result<Decimal256> Decimal256::FromString(const char* s) {
Status Decimal256::ToArrowStatus(DecimalStatus dstatus) const {
return arrow::ToArrowStatus(dstatus, 256);
}

// Stream the decimal as its unscaled integer digits.
std::ostream& operator<<(std::ostream& os, const Decimal256& decimal) {
  return os << decimal.ToIntegerString();
}
} // namespace arrow
3 changes: 3 additions & 0 deletions cpp/src/arrow/util/decimal.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ class ARROW_EXPORT Decimal256 : public BasicDecimal256 {
return std::move(out);
}

friend ARROW_EXPORT std::ostream& operator<<(std::ostream& os,
const Decimal256& decimal);

private:
/// Converts internal error code to Status
Status ToArrowStatus(DecimalStatus dstatus) const;
Expand Down
16 changes: 16 additions & 0 deletions cpp/src/arrow/util/decimal_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,21 @@ static void BinaryMathOp(benchmark::State& state) { // NOLINT non-const referen
state.SetItemsProcessed(state.iterations() * kValueSize);
}

static void BinaryMathOp256(benchmark::State& state) {  // NOLINT non-const reference
  // Build two vectors of distinct 256 bit operands outside the timed loop.
  std::vector<BasicDecimal256> lhs;
  std::vector<BasicDecimal256> rhs;
  lhs.reserve(kValueSize);
  rhs.reserve(kValueSize);
  for (uint64_t i = 0; i < kValueSize; ++i) {
    lhs.push_back(BasicDecimal256({100 + i, 100 + i, 100 + i, 100 + i}));
    rhs.push_back(BasicDecimal256({200 + i, 200 + i, 200 + i, 200 + i}));
  }

  for (auto _ : state) {
    for (int i = 0; i < kValueSize; i += 5) {
      benchmark::DoNotOptimize(lhs[i + 2] * rhs[i + 2]);
    }
  }
  state.SetItemsProcessed(state.iterations() * kValueSize);
}

static void UnaryOp(benchmark::State& state) { // NOLINT non-const reference
std::vector<BasicDecimal128> v;
for (int x = 0; x < kValueSize; x++) {
Expand Down Expand Up @@ -191,6 +206,7 @@ static void BinaryBitOp(benchmark::State& state) { // NOLINT non-const referenc
BENCHMARK(FromString);
BENCHMARK(ToString);
BENCHMARK(BinaryMathOp);
BENCHMARK(BinaryMathOp256);
BENCHMARK(BinaryMathOpAggregate);
BENCHMARK(BinaryCompareOp);
BENCHMARK(BinaryCompareOpConstant);
Expand Down
Loading