diff --git a/g3doc/quick_reference.md b/g3doc/quick_reference.md
index 8220e9b718..8062d0b09d 100644
--- a/g3doc/quick_reference.md
+++ b/g3doc/quick_reference.md
@@ -1004,6 +1004,24 @@ Per-lane variable shifts (slow if SSSE3/SSE4, or 16-bit, or Shr i64 on AVX2):
     neither NaN nor infinity, i.e. normal, subnormal or zero. Equivalent to
     `Not(Or(IsNaN(v), IsInf(v)))`.
 
+#### Masked floating-point classification
+
+All ops in this section return `false` for `mask=false` lanes. These are
+equivalent to, and potentially more efficient than, `And(m, IsNaN(v))` etc.
+
+*   `V`: `{f}` \
+    <code>M **MaskedIsNaN**(M m, V v)</code>: returns mask indicating whether
+    `v[i]` is "not a number" (unordered), or `false` if `m[i]` is false.
+
+*   `V`: `{f}` \
+    <code>M **MaskedIsInf**(M m, V v)</code>: returns mask indicating whether
+    `v[i]` is positive or negative infinity, or `false` if `m[i]` is false.
+
+*   `V`: `{f}` \
+    <code>M **MaskedIsFinite**(M m, V v)</code>: returns mask indicating whether
+    `v[i]` is neither NaN nor infinity, i.e. normal, subnormal or zero, or
+    `false` if `m[i]` is false. Equivalent to `And(m, IsFinite(v))`.
+
 ### Logical
 
 *   `V`: `{u,i}` \
@@ -1477,6 +1495,29 @@ These return a mask (see above) indicating whether the condition is true.
     for comparing 64-bit keys alongside 64-bit values. Only available if
     `HWY_TARGET != HWY_SCALAR`.
 
+#### Masked comparison
+
+All ops in this section return `false` for `mask=false` lanes. These are
+equivalent to, and potentially more efficient than, `And(m, Eq(a, b))` etc.
+
+*   <code>M **MaskedCompEq**(M m, V a, V b)</code>: returns `a[i] == b[i]` or
+    `false` if `m[i]` is false.
+
+*   <code>M **MaskedCompNe**(M m, V a, V b)</code>: returns `a[i] != b[i]` or
+    `false` if `m[i]` is false.
+
+*   <code>M **MaskedCompLt**(M m, V a, V b)</code>: returns `a[i] < b[i]` or
+    `false` if `m[i]` is false.
+
+*   <code>M **MaskedCompGt**(M m, V a, V b)</code>: returns `a[i] > b[i]` or
+    `false` if `m[i]` is false.
+
+*   <code>M **MaskedCompLe**(M m, V a, V b)</code>: returns `a[i] <= b[i]` or
+    `false` if `m[i]` is false.
+
+*   <code>M **MaskedCompGe**(M m, V a, V b)</code>: returns `a[i] >= b[i]` or
+    `false` if `m[i]` is false.
+
 ### Memory
 
 Memory operands are little-endian, otherwise their order would depend on the
diff --git a/hwy/ops/arm_sve-inl.h b/hwy/ops/arm_sve-inl.h
index 2dde1479de..2200d9f2c5 100644
--- a/hwy/ops/arm_sve-inl.h
+++ b/hwy/ops/arm_sve-inl.h
@@ -1783,6 +1783,77 @@ HWY_API svbool_t IsFinite(const V v) {
   return RebindMask(d, detail::LtN(exp, hwy::MaxExponentField<T>()));
 }
 
+// ------------------------------ MaskedCompEq etc.
+#ifdef HWY_NATIVE_MASKED_COMP
+#undef HWY_NATIVE_MASKED_COMP
+#else
+#define HWY_NATIVE_MASKED_COMP
+#endif
+
+// mask = f(mask, vector, vector)
+#define HWY_SVE_COMPARE_Z(BASE, CHAR, BITS, HALF, NAME, OP)   \
+  HWY_API svbool_t NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a,  \
+                        HWY_SVE_V(BASE, BITS) b) {            \
+    return sv##OP##_##CHAR##BITS(m, a, b);                    \
+  }
+
+namespace detail {
+HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedEq, cmpeq)
+HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedNe, cmpne)
+HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLt, cmplt)
+HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLe, cmple)
+
+}  // namespace detail
+
+#undef HWY_SVE_COMPARE_Z
+
+template <class V, class M, class D = DFromV<V>>
+HWY_API MFromD<D> MaskedCompEq(M m, V a, V b) {
+  return detail::MaskedEq(m, a, b);
+}
+
+template <class V, class M, class D = DFromV<V>>
+HWY_API MFromD<D> MaskedCompNe(M m, V a, V b) {
+  return detail::MaskedNe(m, a, b);
+}
+
+template <class V, class M, class D = DFromV<V>>
+HWY_API MFromD<D> MaskedCompLt(M m, V a, V b) {
+  return detail::MaskedLt(m, a, b);
+}
+
+template <class V, class M, class D = DFromV<V>>
+HWY_API MFromD<D> MaskedCompGt(M m, V a, V b) {
+  // Swap args to reverse comparison
+  return detail::MaskedLt(m, b, a);
+}
+
+template <class V, class M, class D = DFromV<V>>
+HWY_API MFromD<D> MaskedCompLe(M m, V a, V b) {
+  return detail::MaskedLe(m, a, b);
+}
+
+template <class V, class M, class D = DFromV<V>>
+HWY_API MFromD<D> MaskedCompGe(M m, V a, V b) {
+  // Swap args to reverse comparison
+  return detail::MaskedLe(m, b, a);
+}
+
+template <class V, class M, class D = DFromV<V>>
+HWY_API MFromD<D> MaskedIsInf(const M m, const V v) {
+  return And(m, IsInf(v));
+}
+
+template <class V, class M, class D = DFromV<V>>
+HWY_API MFromD<D> MaskedIsFinite(const M m, const V v) {
+  return And(m, IsFinite(v));
+}
+
+template <class V, class M, class D = DFromV<V>>
+HWY_API MFromD<D> MaskedIsNaN(const M m, const V v) {
+  return detail::MaskedNe(m, v, v);
+}
+
 // ================================================== MEMORY
 
 // ------------------------------ LoadU/MaskedLoad/LoadDup128/StoreU/Stream
diff --git a/hwy/ops/generic_ops-inl.h b/hwy/ops/generic_ops-inl.h
index 99b518d99c..ad094a1f37 100644
--- a/hwy/ops/generic_ops-inl.h
+++ b/hwy/ops/generic_ops-inl.h
@@ -574,6 +574,60 @@ HWY_API V MaskedSatSubOr(V no, M m, V a, V b) {
 }
 #endif  // HWY_NATIVE_MASKED_ARITH
 
+// ------------------------------ MaskedCompEq etc.
+#if (defined(HWY_NATIVE_MASKED_COMP) == defined(HWY_TARGET_TOGGLE))
+#ifdef HWY_NATIVE_MASKED_COMP
+#undef HWY_NATIVE_MASKED_COMP
+#else
+#define HWY_NATIVE_MASKED_COMP
+#endif
+
+template <class M, class V>
+HWY_API auto MaskedCompEq(M m, V a, V b) -> decltype(a == b) {
+  return And(m, Eq(a, b));
+}
+
+template <class M, class V>
+HWY_API auto MaskedCompNe(M m, V a, V b) -> decltype(a == b) {
+  return And(m, Ne(a, b));
+}
+
+template <class M, class V>
+HWY_API auto MaskedCompLt(M m, V a, V b) -> decltype(a == b) {
+  return And(m, Lt(a, b));
+}
+
+template <class M, class V>
+HWY_API auto MaskedCompGt(M m, V a, V b) -> decltype(a == b) {
+  return And(m, Gt(a, b));
+}
+
+template <class M, class V>
+HWY_API auto MaskedCompLe(M m, V a, V b) -> decltype(a == b) {
+  return And(m, Le(a, b));
+}
+
+template <class M, class V>
+HWY_API auto MaskedCompGe(M m, V a, V b) -> decltype(a == b) {
+  return And(m, Ge(a, b));
+}
+
+template <class V, class M, class D = DFromV<V>>
+HWY_API MFromD<D> MaskedIsInf(const M m, const V v) {
+  return And(m, IsInf(v));
+}
+
+template <class V, class M, class D = DFromV<V>>
+HWY_API MFromD<D> MaskedIsFinite(const M m, const V v) {
+  return And(m, IsFinite(v));
+}
+
+template <class V, class M, class D = DFromV<V>>
+HWY_API MFromD<D> MaskedIsNaN(const M m, const V v) {
+  return And(m, IsNaN(v));
+}
+#endif  // HWY_NATIVE_MASKED_COMP
+
 // ------------------------------ IfNegativeThenNegOrUndefIfZero
 
 #if (defined(HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG) == \
diff --git a/hwy/tests/compare_test.cc b/hwy/tests/compare_test.cc
index 728b58c3dc..38fed3aa1c 100644
--- a/hwy/tests/compare_test.cc
+++ b/hwy/tests/compare_test.cc
@@ -673,7 +673,176 @@ HWY_NOINLINE void TestAllEq128Upper() {
   ForGEVectors<128, TestEq128Upper>()(uint64_t());
 }
 
-}  // namespace
+struct TestMaskedComparison {
+  template <typename T, class D>
+  HWY_NOINLINE void operator()(T /*unused*/, D d) {
+    RandomState rng;
+
+    const Vec<D> v0 = Zero(d);
+    const Vec<D> v2 = Iota(d, 2);
+    const Vec<D> v2b = Iota(d, 2);
+    const Vec<D> v3 = Iota(d, 3);
+    const size_t N = Lanes(d);
+
+    const Mask<D> mask_false = MaskFalse(d);
+    const Mask<D> mask_true = MaskTrue(d);
+
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompEq(mask_true, v2, v3));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompEq(mask_true, v3, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompEq(mask_true, v2, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompEq(mask_true, v2, v2b));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompNe(mask_true, v2, v3));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompNe(mask_true, v3, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompNe(mask_true, v2, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompNe(mask_true, v2, v2b));
+
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask_true, v2, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask_true, v0, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask_true, v0, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask_true, v0, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask_true, v2, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask_true, v2, v2));
+
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompGt(mask_true, v2, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompLt(mask_true, v0, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask_true, v2, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask_true, v0, v2));
+
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLe(mask_true, v2, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGe(mask_true, v0, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompLe(mask_true, v0, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompGe(mask_true, v0, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompLe(mask_true, v2, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompGe(mask_true, v2, v2));
+
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompGe(mask_true, v2, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompLe(mask_true, v0, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLe(mask_true, v2, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGe(mask_true, v0, v2));
+
+    auto bool_lanes = AllocateAligned<T>(N);
+    HWY_ASSERT(bool_lanes);
+
+    for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
+      for (size_t i = 0; i < N; ++i) {
+        bool_lanes[i] = (Random32(&rng) & 1024) ? T(1) : T(0);
+      }
+
+      const Vec<D> mask_i = Load(d, bool_lanes.get());
+      const Mask<D> mask = RebindMask(d, Gt(mask_i, Zero(d)));
+
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompEq(mask, v2, v3));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompEq(mask, v3, v2));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedCompEq(mask, v2, v2));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedCompEq(mask, v2, v2b));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedCompNe(mask, v2, v3));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedCompNe(mask, v3, v2));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompNe(mask, v2, v2));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompNe(mask, v2, v2b));
+
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask, v2, v0));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask, v0, v2));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask, v0, v0));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask, v0, v0));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask, v2, v2));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask, v2, v2));
+
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedCompGt(mask, v2, v0));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedCompLt(mask, v0, v2));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask, v2, v0));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask, v0, v2));
+
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLe(mask, v2, v0));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGe(mask, v0, v2));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedCompLe(mask, v0, v0));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedCompGe(mask, v0, v0));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedCompLe(mask, v2, v2));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedCompGe(mask, v2, v2));
+
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedCompGe(mask, v2, v0));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedCompLe(mask, v0, v2));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLe(mask, v2, v0));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGe(mask, v0, v2));
+    }
+  }
+};
+
+HWY_NOINLINE void TestAllMaskedComparison() {
+  ForAllTypes(ForPartialVectors<TestMaskedComparison>());
+}
+
+struct TestMaskedFloatClassification {
+  template <typename T, class D>
+  HWY_NOINLINE void operator()(T /*unused*/, D d) {
+    RandomState rng;
+
+    const Vec<D> v0 = Zero(d);
+    const Vec<D> v1 = Iota(d, 2);
+    const Vec<D> v2 = Inf(d);
+    const Vec<D> v3 = NaN(d);
+    const size_t N = Lanes(d);
+
+    const Mask<D> mask_false = MaskFalse(d);
+    const Mask<D> mask_true = MaskTrue(d);
+
+    // Test against all zeros
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask_true, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedIsFinite(mask_true, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask_true, v0));
+
+    // Test against finite values
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask_true, v1));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedIsFinite(mask_true, v1));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask_true, v1));
+
+    // Test against infinite values
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedIsInf(mask_true, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsFinite(mask_true, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask_true, v2));
+
+    // Test against NaN values
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask_true, v3));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsFinite(mask_true, v3));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedIsNaN(mask_true, v3));
+
+    auto bool_lanes = AllocateAligned<T>(N);
+    HWY_ASSERT(bool_lanes);
+
+    for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
+      for (size_t i = 0; i < N; ++i) {
+        bool_lanes[i] = (Random32(&rng) & 1024) ? T(1) : T(0);
+      }
+
+      const Vec<D> mask_i = Load(d, bool_lanes.get());
+      const Mask<D> mask = RebindMask(d, Gt(mask_i, Zero(d)));
+
+      // Test against all zeros
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask, v0));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedIsFinite(mask, v0));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask, v0));
+
+      // Test against finite values
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask, v1));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedIsFinite(mask, v1));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask, v1));
+
+      // Test against infinite values
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedIsInf(mask, v2));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsFinite(mask, v2));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask, v2));
+
+      // Test against NaN values
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask, v3));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsFinite(mask, v3));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedIsNaN(mask, v3));
+    }
+  }
+};
+
+HWY_NOINLINE void TestAllMaskedFloatClassification() {
+  ForFloatTypes(ForPartialVectors<TestMaskedFloatClassification>());
+}
+}  // namespace
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace hwy
@@ -695,6 +864,9 @@ HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllLt128);
 HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllLt128Upper);
 HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllEq128);
 HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllEq128Upper);
+
+HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllMaskedComparison);
+HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllMaskedFloatClassification);
 HWY_AFTER_TEST();
 }  // namespace
 }  // namespace hwy
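
Below is an illustrative usage sketch (not part of the diff) showing how the new masked ops compose with existing Highway ops. It assumes a single-target (static dispatch) translation unit; the function name `CountFiniteBelow` and its parameters are hypothetical.

```cpp
#include <stddef.h>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Counts elements in [p, p + num) that are finite and strictly below `limit`.
// NaN and infinity lanes never match because MaskedIsFinite returns false for
// them, and that mask then gates MaskedCompLt.
size_t CountFiniteBelow(const float* HWY_RESTRICT p, size_t num, float limit) {
  const hn::ScalableTag<float> d;
  const size_t N = hn::Lanes(d);
  const auto vlimit = hn::Set(d, limit);
  size_t count = 0;
  for (size_t i = 0; i < num; i += N) {
    // Lanes still inside the array; the remainder of the last vector is off.
    const auto valid = hn::FirstN(d, HWY_MIN(N, num - i));
    const auto v = hn::MaskedLoad(valid, d, p + i);
    // false for lanes that are out of bounds or not finite.
    const auto finite = hn::MaskedIsFinite(valid, v);
    // false wherever `finite` is false, so only in-bounds finite lanes count.
    const auto below = hn::MaskedCompLt(finite, v, vlimit);
    count += hn::CountTrue(d, below);
  }
  return count;
}
```

On SVE the masked comparisons lower to predicated `svcmp*` instructions, so the masking costs no extra instruction; the generic fallback is `And(m, Lt(a, b))` as added in generic_ops-inl.h above.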