Skip to content

Commit

Permalink
Move shift operators into the class
Browse files Browse the repository at this point in the history
  • Loading branch information
chfast committed Sep 3, 2024
1 parent 0ec0e70 commit 2c58a45
Showing 1 changed file with 128 additions and 148 deletions.
276 changes: 128 additions & 148 deletions include/intx/intx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,11 @@ struct uint<128>
(shift < 128) ? uint{0, x[0] << (shift - 64)} : 0;
}

friend constexpr uint operator<<(uint x, std::convertible_to<uint64_t> auto shift) noexcept
{
return x << static_cast<uint64_t>(shift);
}

friend constexpr uint operator<<(uint x, uint shift) noexcept
{
if (shift[1] != 0) [[unlikely]]
Expand All @@ -350,6 +355,11 @@ struct uint<128>
(shift < 128) ? uint{x[1] >> (shift - 64)} : 0;
}

friend constexpr uint operator>>(uint x, std::convertible_to<uint64_t> auto shift) noexcept
{
return x >> static_cast<uint64_t>(shift);
}

friend constexpr uint operator>>(uint x, uint shift) noexcept
{
if (shift[1] != 0) [[unlikely]]
Expand All @@ -373,8 +383,8 @@ struct uint<128>
constexpr uint& operator|=(uint y) noexcept { return *this = *this | y; }
constexpr uint& operator&=(uint y) noexcept { return *this = *this & y; }
constexpr uint& operator^=(uint y) noexcept { return *this = *this ^ y; }
constexpr uint& operator<<=(uint64_t shift) noexcept { return *this = *this << shift; }
constexpr uint& operator>>=(uint64_t shift) noexcept { return *this = *this >> shift; }
constexpr uint& operator<<=(uint shift) noexcept { return *this = *this << shift; }
constexpr uint& operator>>=(uint shift) noexcept { return *this = *this >> shift; }
constexpr uint& operator/=(uint y) noexcept { return *this = *this / y; }
constexpr uint& operator%=(uint y) noexcept { return *this = *this % y; }
};
Expand Down Expand Up @@ -1000,171 +1010,158 @@ struct uint
friend constexpr bool operator>(const uint& x, const uint& y) noexcept { return y < x; }
friend constexpr bool operator>=(const uint& x, const uint& y) noexcept { return !(x < y); }
friend constexpr bool operator<=(const uint& x, const uint& y) noexcept { return !(y < x); }
};

using uint256 = uint<256>;


/// Signed less than comparison.
///
/// Interprets the arguments as two's complement signed integers
/// and checks the "less than" relation.
template <unsigned N>
inline constexpr bool slt(const uint<N>& x, const uint<N>& y) noexcept
{
constexpr auto top_word_idx = uint<N>::num_words - 1;
const auto x_neg = static_cast<int64_t>(x[top_word_idx]) < 0;
const auto y_neg = static_cast<int64_t>(y[top_word_idx]) < 0;
return ((x_neg ^ y_neg) != 0) ? x_neg : x < y;
}
friend inline constexpr uint operator<<(const uint& x, uint64_t shift) noexcept
{
if (shift >= num_bits) [[unlikely]]
return 0;

template <unsigned N>
inline constexpr uint<N> operator<<(const uint<N>& x, uint64_t shift) noexcept
{
if (shift >= uint<N>::num_bits) [[unlikely]]
return 0;
if constexpr (N == 256)
{
constexpr auto half_bits = num_bits / 2;

if constexpr (N == 256)
{
constexpr auto half_bits = uint<N>::num_bits / 2;
const auto xlo = uint128{x[0], x[1]};

const auto xlo = uint128{x[0], x[1]};
if (shift < half_bits)
{
const auto lo = xlo << shift;

const auto xhi = uint128{x[2], x[3]};

// Find the part moved from lo to hi.
// The shift right here can be invalid:
// for shift == 0 => rshift == half_bits.
// Split it into 2 valid shifts by (rshift - 1) and 1.
const auto rshift = half_bits - shift;
const auto lo_overflow = (xlo >> (rshift - 1)) >> 1;
const auto hi = (xhi << shift) | lo_overflow;
return {lo[0], lo[1], hi[0], hi[1]};
}

if (shift < half_bits)
const auto hi = xlo << (shift - half_bits);
return {0, 0, hi[0], hi[1]};
}
else
{
const auto lo = xlo << shift;
constexpr auto word_bits = sizeof(uint64_t) * 8;

const auto xhi = uint128{x[2], x[3]};
const auto s = shift % word_bits;
const auto skip = static_cast<size_t>(shift / word_bits);

// Find the part moved from lo to hi.
// The shift right here can be invalid:
// for shift == 0 => rshift == half_bits.
// Split it into 2 valid shifts by (rshift - 1) and 1.
const auto rshift = half_bits - shift;
const auto lo_overflow = (xlo >> (rshift - 1)) >> 1;
const auto hi = (xhi << shift) | lo_overflow;
return {lo[0], lo[1], hi[0], hi[1]};
uint r;
uint64_t carry = 0;
for (size_t i = 0; i < (num_words - skip); ++i)
{
r[i + skip] = (x[i] << s) | carry;
carry = (x[i] >> (word_bits - s - 1)) >> 1;
}
return r;
}
}

const auto hi = xlo << (shift - half_bits);
return {0, 0, hi[0], hi[1]};
friend inline constexpr uint operator<<(
const uint& x, std::convertible_to<uint64_t> auto shift) noexcept
{
return x << static_cast<uint64_t>(shift);
}
else

friend inline constexpr uint operator<<(const uint& x, const uint& shift) noexcept
{
constexpr auto word_bits = sizeof(uint64_t) * 8;
// TODO: This optimisation should be handled by operator<.
uint64_t high_words_fold = 0;
for (size_t i = 1; i < num_words; ++i)
high_words_fold |= shift[i];

const auto s = shift % word_bits;
const auto skip = static_cast<size_t>(shift / word_bits);
if (high_words_fold != 0) [[unlikely]]
return 0;

uint<N> r;
uint64_t carry = 0;
for (size_t i = 0; i < (uint<N>::num_words - skip); ++i)
{
r[i + skip] = (x[i] << s) | carry;
carry = (x[i] >> (word_bits - s - 1)) >> 1;
}
return r;
return x << shift[0];
}
}

template <unsigned N>
inline constexpr uint<N> operator>>(const uint<N>& x, uint64_t shift) noexcept
{
if (shift >= uint<N>::num_bits) [[unlikely]]
return 0;

if constexpr (N == 256)
friend inline constexpr uint operator>>(const uint& x, uint64_t shift) noexcept
{
constexpr auto half_bits = uint<N>::num_bits / 2;

const auto xhi = uint128{x[2], x[3]};
if (shift >= num_bits) [[unlikely]]
return 0;

if (shift < half_bits)
if constexpr (N == 256)
{
const auto hi = xhi >> shift;
constexpr auto half_bits = num_bits / 2;

const auto xlo = uint128{x[0], x[1]};
const auto xhi = uint128{x[2], x[3]};

// Find the part moved from hi to lo.
// The shift left here can be invalid:
// for shift == 0 => lshift == half_bits.
// Split it into 2 valid shifts by (lshift - 1) and 1.
const auto lshift = half_bits - shift;
const auto hi_overflow = (xhi << (lshift - 1)) << 1;
const auto lo = (xlo >> shift) | hi_overflow;
return {lo[0], lo[1], hi[0], hi[1]};
}
if (shift < half_bits)
{
const auto hi = xhi >> shift;

const auto xlo = uint128{x[0], x[1]};

// Find the part moved from hi to lo.
// The shift left here can be invalid:
// for shift == 0 => lshift == half_bits.
// Split it into 2 valid shifts by (lshift - 1) and 1.
const auto lshift = half_bits - shift;
const auto hi_overflow = (xhi << (lshift - 1)) << 1;
const auto lo = (xlo >> shift) | hi_overflow;
return {lo[0], lo[1], hi[0], hi[1]};
}

const auto lo = xhi >> (shift - half_bits);
return {lo[0], lo[1], 0, 0};
}
else
{
constexpr auto num_words = uint<N>::num_words;
constexpr auto word_bits = sizeof(uint64_t) * 8;
const auto lo = xhi >> (shift - half_bits);
return {lo[0], lo[1], 0, 0};
}
else
{
constexpr auto word_bits = sizeof(uint64_t) * 8;

const auto s = shift % word_bits;
const auto skip = static_cast<size_t>(shift / word_bits);
const auto s = shift % word_bits;
const auto skip = static_cast<size_t>(shift / word_bits);

uint<N> r;
uint64_t carry = 0;
for (size_t i = 0; i < (num_words - skip); ++i)
{
r[num_words - 1 - i - skip] = (x[num_words - 1 - i] >> s) | carry;
carry = (x[num_words - 1 - i] << (word_bits - s - 1)) << 1;
uint r;
uint64_t carry = 0;
for (size_t i = 0; i < (num_words - skip); ++i)
{
r[num_words - 1 - i - skip] = (x[num_words - 1 - i] >> s) | carry;
carry = (x[num_words - 1 - i] << (word_bits - s - 1)) << 1;
}
return r;
}
return r;
}
}

template <unsigned N>
inline constexpr uint<N> operator<<(const uint<N>& x, const uint<N>& shift) noexcept
{
uint64_t high_words_fold = 0;
for (size_t i = 1; i < uint<N>::num_words; ++i)
high_words_fold |= shift[i];

if (INTX_UNLIKELY(high_words_fold != 0))
return 0;
friend inline constexpr uint operator>>(
const uint& x, std::convertible_to<uint64_t> auto shift) noexcept
{
return x >> static_cast<uint64_t>(shift);
}

return x << shift[0];
}
friend inline constexpr uint operator>>(const uint& x, const uint& shift) noexcept
{
uint64_t high_words_fold = 0;
for (size_t i = 1; i < num_words; ++i)
high_words_fold |= shift[i];

template <unsigned N>
inline constexpr uint<N> operator>>(const uint<N>& x, const uint<N>& shift) noexcept
{
uint64_t high_words_fold = 0;
for (size_t i = 1; i < uint<N>::num_words; ++i)
high_words_fold |= shift[i];
if (high_words_fold != 0) [[unlikely]]
return 0;

if (INTX_UNLIKELY(high_words_fold != 0))
return 0;
return x >> shift[0];
}

return x >> shift[0];
}
constexpr uint& operator<<=(uint shift) noexcept { return *this = *this << shift; }
constexpr uint& operator>>=(uint shift) noexcept { return *this = *this >> shift; }
};

template <unsigned N, typename T>
inline constexpr uint<N> operator<<(const uint<N>& x, const T& shift) noexcept
requires std::is_convertible_v<T, uint<N>>
{
if (shift < T{sizeof(x) * 8})
return x << static_cast<uint64_t>(shift);
return 0;
}
using uint256 = uint<256>;

template <unsigned N, typename T>
inline constexpr uint<N> operator>>(const uint<N>& x, const T& shift) noexcept
requires std::is_convertible_v<T, uint<N>>
{
if (shift < T{sizeof(x) * 8})
return x >> static_cast<uint64_t>(shift);
return 0;
}

/// Signed less than comparison.
///
/// Interprets the arguments as two's complement signed integers
/// and checks the "less than" relation.
template <unsigned N>
inline constexpr uint<N>& operator>>=(uint<N>& x, uint64_t shift) noexcept
inline constexpr bool slt(const uint<N>& x, const uint<N>& y) noexcept
{
return x = x >> shift;
constexpr auto top_word_idx = uint<N>::num_words - 1;
const auto x_neg = static_cast<int64_t>(x[top_word_idx]) < 0;
const auto y_neg = static_cast<int64_t>(y[top_word_idx]) < 0;
return ((x_neg ^ y_neg) != 0) ? x_neg : x < y;
}


Expand Down Expand Up @@ -1541,23 +1538,6 @@ inline constexpr uint<N> bswap(const uint<N>& x) noexcept
}


// Support for type conversions for binary operators.

template <unsigned N, typename T>
inline constexpr uint<N>& operator<<=(uint<N>& x, const T& y) noexcept
requires std::is_convertible_v<T, uint<N>>
{
return x = x << y;
}

template <unsigned N, typename T>
inline constexpr uint<N>& operator>>=(uint<N>& x, const T& y) noexcept
requires std::is_convertible_v<T, uint<N>>
{
return x = x >> y;
}


inline constexpr uint256 addmod(const uint256& x, const uint256& y, const uint256& mod) noexcept
{
// Fast path for mod >= 2^192, with x and y at most slightly bigger than mod.
Expand Down

0 comments on commit 2c58a45

Please sign in to comment.