Masked compare and floating point classifications #2427

Open · wants to merge 1 commit into `master`
41 changes: 41 additions & 0 deletions g3doc/quick_reference.md

@@ -1004,6 +1004,24 @@ Per-lane variable shifts (slow if SSSE3/SSE4, or 16-bit, or Shr i64 on AVX2):
neither NaN nor infinity, i.e. normal, subnormal or zero. Equivalent to
`Not(Or(IsNaN(v), IsInf(v)))`.

#### Masked floating-point classification

All ops in this section return `false` for `mask=false` lanes. These are
equivalent to, and potentially more efficient than, `And(m, IsNaN(v))` etc.
A short usage sketch follows the list below.
> **Review comment (Member):** `And(m, IsNaN)`?


* `V`: `{f}` \
  <code>M **MaskedIsNaN**(M m, V v)</code>: returns mask indicating whether
  `v[i]` is "not a number" (unordered), or `false` if `m[i]` is false.

> **Review comment (Member):** Let's add the mask argument to the documentation :)

* `V`: `{f}` \
  <code>M **MaskedIsInf**(M m, V v)</code>: returns mask indicating whether
  `v[i]` is positive or negative infinity, or `false` if `m[i]` is false.

* `V`: `{f}` \
  <code>M **MaskedIsFinite**(M m, V v)</code>: returns mask indicating whether
  `v[i]` is neither NaN nor infinity, i.e. normal, subnormal or zero, or
  `false` if `m[i]` is false. Equivalent to
  `And(m, Not(Or(IsNaN(v), IsInf(v))))`.
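
A minimal sketch of intended usage, assuming the ops above are available via
the usual Highway headers (the helper name `ZeroNaNsInRegion` is illustrative,
not part of Highway):

```c++
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Zeroes lanes of `v` that are NaN within the active region `m`; lanes
// outside `m` are left unchanged, because MaskedIsNaN returns false there.
template <class V, class M>
V ZeroNaNsInRegion(M m, V v) {
  return hn::IfThenZeroElse(hn::MaskedIsNaN(m, v), v);
}
```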

### Logical

* `V`: `{u,i}` \
@@ -1477,6 +1495,29 @@ These return a mask (see above) indicating whether the condition is true.
for comparing 64-bit keys alongside 64-bit values. Only available if
`HWY_TARGET != HWY_SCALAR`.

#### Masked comparison

All ops in this section return `false` for `mask=false` lanes. These are
equivalent to, and potentially more efficient than, `And(m, Eq(a, b))` etc.
A short usage sketch follows the list below.

* <code>M **MaskedCompEq**(M m, V a, V b)</code>: returns `a[i] == b[i]` or
  `false` if `m[i]` is false.

> **Review comment (Member):** Should we just call it `MaskedEq` for consistency with the usual naming convention, which is only to prepend `Masked` to the existing `Eq`?

* <code>M **MaskedCompNe**(M m, V a, V b)</code>: returns `a[i] != b[i]` or
`false` if `m[i]` is false.

* <code>M **MaskedCompLt**(M m, V a, V b)</code>: returns `a[i] < b[i]` or
`false` if `m[i]` is false.

* <code>M **MaskedCompGt**(M m, V a, V b)</code>: returns `a[i] > b[i]` or
`false` if `m[i]` is false.

* <code>M **MaskedCompLe**(M m, V a, V b)</code>: returns `a[i] <= b[i]` or
`false` if `m[i]` is false.

* <code>M **MaskedCompGe**(M m, V a, V b)</code>: returns `a[i] >= b[i]` or
`false` if `m[i]` is false.
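
Continuing the sketch above, a hedged example of how these compose with
existing mask ops (the helper name `CountMaskedLess` is illustrative):

```c++
// Counts lanes where a[i] < b[i], restricted to lanes selected by `m`.
// Inactive lanes compare as false, so they do not contribute to the count.
template <class D, class V = hn::Vec<D>, class M = hn::Mask<D>>
size_t CountMaskedLess(D d, M m, V a, V b) {
  return hn::CountTrue(d, hn::MaskedCompLt(m, a, b));
}
```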

### Memory

Memory operands are little-endian, otherwise their order would depend on the
71 changes: 71 additions & 0 deletions hwy/ops/arm_sve-inl.h

@@ -1783,6 +1783,77 @@ HWY_API svbool_t IsFinite(const V v) {
return RebindMask(d, detail::LtN(exp, hwy::MaxExponentField<T>()));
}

// ------------------------------ MaskedCompEq etc.
#ifdef HWY_NATIVE_MASKED_COMP
#undef HWY_NATIVE_MASKED_COMP
#else
#define HWY_NATIVE_MASKED_COMP
#endif
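// (Signals to generic_ops-inl.h that this target provides native masked
// compares, so the generic And(m, Eq(a, b)) fallback is skipped.)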

// mask = f(mask, vector, vector)
#define HWY_SVE_COMPARE_Z(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API svbool_t NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, \
HWY_SVE_V(BASE, BITS) b) { \
return sv##OP##_##CHAR##BITS(m, a, b); \
}
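
// For illustration, HWY_SVE_FOREACH expands the macro above into per-type
// overloads roughly like the following f32 instance (a sketch; the exact
// intrinsic names come from the SVE ACLE):
//   HWY_API svbool_t MaskedEq(svbool_t m, svfloat32_t a, svfloat32_t b) {
//     return svcmpeq_f32(m, a, b);  // inactive lanes yield false
//   }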

namespace detail {
HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedEq, cmpeq)
> **Review comment (Member):** I think we can expose these directly. Rather than putting them in `detail::` and adding a wrapper function, you can just remove the `namespace detail`, and specify the desired name of the op as the second-to-last argument (`MaskedEq` is already good IMO).

HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedNe, cmpne)
HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLt, cmplt)
HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLe, cmple)

} // namespace detail

#undef HWY_SVE_COMPARE_Z

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedCompEq(M m, V a, V b) {
return detail::MaskedEq(m, a, b);
}

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedCompNe(M m, V a, V b) {
return detail::MaskedNe(m, a, b);
}

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedCompLt(M m, V a, V b) {
return detail::MaskedLt(m, a, b);
}

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedCompGt(M m, V a, V b) {
// Swap args to reverse comparison
return detail::MaskedLt(m, b, a);
}

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedCompLe(M m, V a, V b) {
return detail::MaskedLe(m, a, b);
}

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedCompGe(M m, V a, V b) {
// Swap args to reverse comparison
return detail::MaskedLe(m, b, a);
}

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedIsInf(const M m, const V v) {
  return And(m, IsInf(v));
}

> **Review comment (Member):** Do we ever plan to provide a faster implementation of MaskedIsInf/MaskedIsFinite, or can those be removed?

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedIsFinite(const M m, const V v) {
return And(m, IsFinite(v));
}

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedIsNaN(const M m, const V v) {
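  // NaN is the only value that compares unequal to itself, so a masked
  // self-inequality test yields exactly the NaN lanes within m.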
return detail::MaskedNe(m, v, v);
}

// ================================================== MEMORY

// ------------------------------ LoadU/MaskedLoad/LoadDup128/StoreU/Stream
54 changes: 54 additions & 0 deletions hwy/ops/generic_ops-inl.h

@@ -574,6 +574,60 @@ HWY_API V MaskedSatSubOr(V no, M m, V a, V b) {
}
#endif // HWY_NATIVE_MASKED_ARITH

// ------------------------------ MaskedCompEq etc.
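// Generic fallback for targets that do not set HWY_NATIVE_MASKED_COMP
// (e.g. SVE provides native versions in arm_sve-inl.h); the toggle pattern
// below ensures exactly one definition is active per target.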
#if (defined(HWY_NATIVE_MASKED_COMP) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MASKED_COMP
#undef HWY_NATIVE_MASKED_COMP
#else
#define HWY_NATIVE_MASKED_COMP
#endif

template <class V, class M>
HWY_API auto MaskedCompEq(M m, V a, V b) -> decltype(a == b) {
return And(m, Eq(a, b));
}

template <class V, class M>
HWY_API auto MaskedCompNe(M m, V a, V b) -> decltype(a == b) {
return And(m, Ne(a, b));
}

template <class V, class M>
HWY_API auto MaskedCompLt(M m, V a, V b) -> decltype(a == b) {
return And(m, Lt(a, b));
}

template <class V, class M>
HWY_API auto MaskedCompGt(M m, V a, V b) -> decltype(a == b) {
return And(m, Gt(a, b));
}

template <class V, class M>
HWY_API auto MaskedCompLe(M m, V a, V b) -> decltype(a == b) {
return And(m, Le(a, b));
}

template <class V, class M>
HWY_API auto MaskedCompGe(M m, V a, V b) -> decltype(a == b) {
return And(m, Ge(a, b));
}

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedIsInf(const M m, const V v) {
return And(m, IsInf(v));
}

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedIsFinite(const M m, const V v) {
return And(m, IsFinite(v));
}

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedIsNaN(const M m, const V v) {
return And(m, IsNaN(v));
}
#endif // HWY_NATIVE_MASKED_COMP

// ------------------------------ IfNegativeThenNegOrUndefIfZero

#if (defined(HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG) == \
174 changes: 173 additions & 1 deletion hwy/tests/compare_test.cc

@@ -673,7 +673,176 @@ HWY_NOINLINE void TestAllEq128Upper() {
ForGEVectors<128, TestEq128Upper>()(uint64_t());
}

struct TestMaskedComparison {
template <typename T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
RandomState rng;

const Vec<D> v0 = Zero(d);
const Vec<D> v2 = Iota(d, 2);
const Vec<D> v2b = Iota(d, 2);
const Vec<D> v3 = Iota(d, 3);
const size_t N = Lanes(d);

const Mask<D> mask_false = MaskFalse(d);
const Mask<D> mask_true = MaskTrue(d);

HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompEq(mask_true, v2, v3));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompEq(mask_true, v3, v2));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompEq(mask_true, v2, v2));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompEq(mask_true, v2, v2b));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompNe(mask_true, v2, v3));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompNe(mask_true, v3, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompNe(mask_true, v2, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompNe(mask_true, v2, v2b));

HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask_true, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask_true, v0, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask_true, v0, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask_true, v0, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask_true, v2, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask_true, v2, v2));

HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompGt(mask_true, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompLt(mask_true, v0, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask_true, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask_true, v0, v2));

HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLe(mask_true, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGe(mask_true, v0, v2));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompLe(mask_true, v0, v0));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompGe(mask_true, v0, v0));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompLe(mask_true, v2, v2));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompGe(mask_true, v2, v2));

HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompGe(mask_true, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompLe(mask_true, v0, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLe(mask_true, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGe(mask_true, v0, v2));

auto bool_lanes = AllocateAligned<T>(N);
HWY_ASSERT(bool_lanes);

for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
for (size_t i = 0; i < N; ++i) {
bool_lanes[i] = (Random32(&rng) & 1024) ? T(1) : T(0);
}

const Vec<D> mask_i = Load(d, bool_lanes.get());
const Mask<D> mask = RebindMask(d, Gt(mask_i, Zero(d)));

HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompEq(mask, v2, v3));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompEq(mask, v3, v2));
HWY_ASSERT_MASK_EQ(d, mask, MaskedCompEq(mask, v2, v2));
HWY_ASSERT_MASK_EQ(d, mask, MaskedCompEq(mask, v2, v2b));
HWY_ASSERT_MASK_EQ(d, mask, MaskedCompNe(mask, v2, v3));
HWY_ASSERT_MASK_EQ(d, mask, MaskedCompNe(mask, v3, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompNe(mask, v2, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompNe(mask, v2, v2b));

HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask, v0, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask, v0, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask, v0, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask, v2, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask, v2, v2));

HWY_ASSERT_MASK_EQ(d, mask, MaskedCompGt(mask, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask, MaskedCompLt(mask, v0, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask, v0, v2));

HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLe(mask, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGe(mask, v0, v2));
HWY_ASSERT_MASK_EQ(d, mask, MaskedCompLe(mask, v0, v0));
HWY_ASSERT_MASK_EQ(d, mask, MaskedCompGe(mask, v0, v0));
HWY_ASSERT_MASK_EQ(d, mask, MaskedCompLe(mask, v2, v2));
HWY_ASSERT_MASK_EQ(d, mask, MaskedCompGe(mask, v2, v2));

HWY_ASSERT_MASK_EQ(d, mask, MaskedCompGe(mask, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask, MaskedCompLe(mask, v0, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLe(mask, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGe(mask, v0, v2));
}
}
};

HWY_NOINLINE void TestAllMaskedComparison() {
  ForAllTypes(ForPartialVectors<TestMaskedComparison>());
}

struct TestMaskedFloatClassification {
template <typename T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
RandomState rng;

const Vec<D> v0 = Zero(d);
const Vec<D> v1 = Iota(d, 2);
const Vec<D> v2 = Inf(d);
const Vec<D> v3 = NaN(d);
const size_t N = Lanes(d);

const Mask<D> mask_false = MaskFalse(d);
const Mask<D> mask_true = MaskTrue(d);

// Test against all zeros
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask_true, v0));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedIsFinite(mask_true, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask_true, v0));

// Test against finite values
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask_true, v1));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedIsFinite(mask_true, v1));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask_true, v1));

// Test against infinite values
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedIsInf(mask_true, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsFinite(mask_true, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask_true, v2));

// Test against NaN values
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask_true, v3));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsFinite(mask_true, v3));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedIsNaN(mask_true, v3));

auto bool_lanes = AllocateAligned<T>(N);
HWY_ASSERT(bool_lanes);

for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
for (size_t i = 0; i < N; ++i) {
bool_lanes[i] = (Random32(&rng) & 1024) ? T(1) : T(0);
}

const Vec<D> mask_i = Load(d, bool_lanes.get());
const Mask<D> mask = RebindMask(d, Gt(mask_i, Zero(d)));

// Test against all zeros
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask, v0));
HWY_ASSERT_MASK_EQ(d, mask, MaskedIsFinite(mask, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask, v0));

// Test against finite values
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask, v1));
HWY_ASSERT_MASK_EQ(d, mask, MaskedIsFinite(mask, v1));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask, v1));

// Test against infinite values
HWY_ASSERT_MASK_EQ(d, mask, MaskedIsInf(mask, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsFinite(mask, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask, v2));

// Test against NaN values
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask, v3));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsFinite(mask, v3));
HWY_ASSERT_MASK_EQ(d, mask, MaskedIsNaN(mask, v3));
}
}
};

HWY_NOINLINE void TestAllMaskedFloatClassification() {
ForFloatTypes(ForPartialVectors<TestMaskedFloatClassification>());
}
} // namespace
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace hwy
@@ -695,6 +864,9 @@ HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllLt128);
HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllLt128Upper);
HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllEq128);
HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllEq128Upper);

HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllMaskedComparison);
HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllMaskedFloatClassification);
HWY_AFTER_TEST();
} // namespace
} // namespace hwy