diff --git a/g3doc/quick_reference.md b/g3doc/quick_reference.md
index 8220e9b718..8062d0b09d 100644
--- a/g3doc/quick_reference.md
+++ b/g3doc/quick_reference.md
@@ -1004,6 +1004,24 @@ Per-lane variable shifts (slow if SSSE3/SSE4, or 16-bit, or Shr i64 on AVX2):
neither NaN nor infinity, i.e. normal, subnormal or zero. Equivalent to
`Not(Or(IsNaN(v), IsInf(v)))`.
+#### Masked floating-point classification
+
+All ops in this section return `false` for `mask=false` lanes. These are
+equivalent to, and potentially more efficient than, `And(m, IsNaN(v))` etc.
+
+* `V`: `{f}` \
+  M **MaskedIsNaN**(M m, V v): returns mask indicating whether `v[i]` is
+  "not a number" (unordered), or `false` if `m[i]` is false.
+
+* `V`: `{f}` \
+  M **MaskedIsInf**(M m, V v): returns mask indicating whether `v[i]` is
+  positive or negative infinity, or `false` if `m[i]` is false.
+
+* `V`: `{f}` \
+  M **MaskedIsFinite**(M m, V v): returns mask indicating whether `v[i]` is
+  neither NaN nor infinity, i.e. normal, subnormal or zero, or `false` if
+  `m[i]` is false. Equivalent to `And(m, Not(Or(IsNaN(v), IsInf(v))))`.
+
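+For illustration only, a sketch of predicated classification; the `hn`
+namespace alias and the `count` variable are assumptions, not part of these
+ops:
+
+```c++
+const hn::ScalableTag<float> d;
+const auto m = hn::FirstN(d, count);  // only classify the first `count` lanes
+const auto v = hn::Div(hn::Set(d, 1.0f), hn::Iota(d, 0));  // lane 0 is +inf
+const size_t num_inf = hn::CountTrue(d, hn::MaskedIsInf(m, v));     // 0 or 1
+const size_t num_fin = hn::CountTrue(d, hn::MaskedIsFinite(m, v));  // rest of `m`
+```
+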
### Logical
* `V`: `{u,i}` \
@@ -1477,6 +1495,29 @@ These return a mask (see above) indicating whether the condition is true.
for comparing 64-bit keys alongside 64-bit values. Only available if
`HWY_TARGET != HWY_SCALAR`.
+#### Masked comparison
+
+All ops in this section return `false` for `mask=false` lanes. These are
+equivalent to, and potentially more efficient than, `And(m, Eq(a, b))` etc.
+
+* M **MaskedCompEq**(M m, V a, V b): returns `a[i] == b[i]` or `false` if
+  `m[i]` is false.
+
+* M **MaskedCompNe**(M m, V a, V b): returns `a[i] != b[i]` or `false` if
+  `m[i]` is false.
+
+* M **MaskedCompLt**(M m, V a, V b): returns `a[i] < b[i]` or `false` if
+  `m[i]` is false.
+
+* M **MaskedCompGt**(M m, V a, V b): returns `a[i] > b[i]` or `false` if
+  `m[i]` is false.
+
+* M **MaskedCompLe**(M m, V a, V b): returns `a[i] <= b[i]` or `false` if
+  `m[i]` is false.
+
+* M **MaskedCompGe**(M m, V a, V b): returns `a[i] >= b[i]` or `false` if
+  `m[i]` is false.
+
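+For illustration only, a sketch of a predicated comparison; the `hn` namespace
+alias and the `count` variable are assumptions, not part of these ops:
+
+```c++
+const hn::ScalableTag<int32_t> d;
+const auto m = hn::FirstN(d, count);  // only compare the first `count` lanes
+const auto a = hn::Iota(d, 0);
+const auto b = hn::Set(d, 3);
+// Equivalent to And(m, Eq(a, b)); lanes outside `m` compare as false.
+const size_t num_equal = hn::CountTrue(d, hn::MaskedCompEq(m, a, b));
+```
+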
### Memory
Memory operands are little-endian, otherwise their order would depend on the
diff --git a/hwy/ops/arm_sve-inl.h b/hwy/ops/arm_sve-inl.h
index 2dde1479de..2200d9f2c5 100644
--- a/hwy/ops/arm_sve-inl.h
+++ b/hwy/ops/arm_sve-inl.h
@@ -1783,6 +1783,77 @@ HWY_API svbool_t IsFinite(const V v) {
return RebindMask(d, detail::LtN(exp, hwy::MaxExponentField()));
}
+// ------------------------------ MaskedCompEq etc.
+#ifdef HWY_NATIVE_MASKED_COMP
+#undef HWY_NATIVE_MASKED_COMP
+#else
+#define HWY_NATIVE_MASKED_COMP
+#endif
+
+// mask = f(mask, vector, vector)
+#define HWY_SVE_COMPARE_Z(BASE, CHAR, BITS, HALF, NAME, OP) \
+ HWY_API svbool_t NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, \
+ HWY_SVE_V(BASE, BITS) b) { \
+ return sv##OP##_##CHAR##BITS(m, a, b); \
+ }
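+// E.g. HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedEq, cmpeq) generates, for f32:
+//   HWY_API svbool_t MaskedEq(svbool_t m, svfloat32_t a, svfloat32_t b) {
+//     return svcmpeq_f32(m, a, b);  // predicated: lanes outside m are false
+//   }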
+
+namespace detail {
+HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedEq, cmpeq)
+HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedNe, cmpne)
+HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLt, cmplt)
+HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLe, cmple)
+
+} // namespace detail
+
+#undef HWY_SVE_COMPARE_Z
+
+template <class V, class M = MFromD<DFromV<V>>>
+HWY_API MFromD<DFromV<V>> MaskedCompEq(M m, V a, V b) {
+ return detail::MaskedEq(m, a, b);
+}
+
+template <class V, class M = MFromD<DFromV<V>>>
+HWY_API MFromD<DFromV<V>> MaskedCompNe(M m, V a, V b) {
+ return detail::MaskedNe(m, a, b);
+}
+
+template <class V, class M = MFromD<DFromV<V>>>
+HWY_API MFromD<DFromV<V>> MaskedCompLt(M m, V a, V b) {
+ return detail::MaskedLt(m, a, b);
+}
+
+template <class V, class M = MFromD<DFromV<V>>>
+HWY_API MFromD<DFromV<V>> MaskedCompGt(M m, V a, V b) {
+ // Swap args to reverse comparison
+ return detail::MaskedLt(m, b, a);
+}
+
+template <class V, class M = MFromD<DFromV<V>>>
+HWY_API MFromD<DFromV<V>> MaskedCompLe(M m, V a, V b) {
+ return detail::MaskedLe(m, a, b);
+}
+
+template <class V, class M = MFromD<DFromV<V>>>
+HWY_API MFromD<DFromV<V>> MaskedCompGe(M m, V a, V b) {
+ // Swap args to reverse comparison
+ return detail::MaskedLe(m, b, a);
+}
+
+template <class V, class M = MFromD<DFromV<V>>>
+HWY_API MFromD<DFromV<V>> MaskedIsInf(const M m, const V v) {
+ return And(m, IsInf(v));
+}
+
+template <class V, class M = MFromD<DFromV<V>>>
+HWY_API MFromD<DFromV<V>> MaskedIsFinite(const M m, const V v) {
+ return And(m, IsFinite(v));
+}
+
+template <class V, class M = MFromD<DFromV<V>>>
+HWY_API MFromD<DFromV<V>> MaskedIsNaN(const M m, const V v) {
+ return detail::MaskedNe(m, v, v);
+}
+
// ================================================== MEMORY
// ------------------------------ LoadU/MaskedLoad/LoadDup128/StoreU/Stream
diff --git a/hwy/ops/generic_ops-inl.h b/hwy/ops/generic_ops-inl.h
index 99b518d99c..ad094a1f37 100644
--- a/hwy/ops/generic_ops-inl.h
+++ b/hwy/ops/generic_ops-inl.h
@@ -574,6 +574,60 @@ HWY_API V MaskedSatSubOr(V no, M m, V a, V b) {
}
#endif // HWY_NATIVE_MASKED_ARITH
+// ------------------------------ MaskedCompEq etc.
+#if (defined(HWY_NATIVE_MASKED_COMP) == defined(HWY_TARGET_TOGGLE))
+#ifdef HWY_NATIVE_MASKED_COMP
+#undef HWY_NATIVE_MASKED_COMP
+#else
+#define HWY_NATIVE_MASKED_COMP
+#endif
+
+template <class M, class V>
+HWY_API auto MaskedCompEq(M m, V a, V b) -> decltype(a == b) {
+ return And(m, Eq(a, b));
+}
+
+template <class M, class V>
+HWY_API auto MaskedCompNe(M m, V a, V b) -> decltype(a == b) {
+ return And(m, Ne(a, b));
+}
+
+template <class M, class V>
+HWY_API auto MaskedCompLt(M m, V a, V b) -> decltype(a == b) {
+ return And(m, Lt(a, b));
+}
+
+template <class M, class V>
+HWY_API auto MaskedCompGt(M m, V a, V b) -> decltype(a == b) {
+ return And(m, Gt(a, b));
+}
+
+template <class M, class V>
+HWY_API auto MaskedCompLe(M m, V a, V b) -> decltype(a == b) {
+ return And(m, Le(a, b));
+}
+
+template <class M, class V>
+HWY_API auto MaskedCompGe(M m, V a, V b) -> decltype(a == b) {
+ return And(m, Ge(a, b));
+}
+
+template <class V, class M = MFromD<DFromV<V>>>
+HWY_API MFromD<DFromV<V>> MaskedIsInf(const M m, const V v) {
+ return And(m, IsInf(v));
+}
+
+template <class V, class M = MFromD<DFromV<V>>>
+HWY_API MFromD<DFromV<V>> MaskedIsFinite(const M m, const V v) {
+ return And(m, IsFinite(v));
+}
+
+template <class V, class M = MFromD<DFromV<V>>>
+HWY_API MFromD<DFromV<V>> MaskedIsNaN(const M m, const V v) {
+ return And(m, IsNaN(v));
+}
+#endif // HWY_NATIVE_MASKED_COMP
+
// ------------------------------ IfNegativeThenNegOrUndefIfZero
#if (defined(HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG) == \
diff --git a/hwy/tests/compare_test.cc b/hwy/tests/compare_test.cc
index 728b58c3dc..38fed3aa1c 100644
--- a/hwy/tests/compare_test.cc
+++ b/hwy/tests/compare_test.cc
@@ -673,7 +673,176 @@ HWY_NOINLINE void TestAllEq128Upper() {
ForGEVectors<128, TestEq128Upper>()(uint64_t());
}
-} // namespace
+struct TestMaskedComparison {
+  template <typename T, class D>
+ HWY_NOINLINE void operator()(T /*unused*/, D d) {
+ RandomState rng;
+
+    const Vec<D> v0 = Zero(d);
+    const Vec<D> v2 = Iota(d, 2);
+    const Vec<D> v2b = Iota(d, 2);
+    const Vec<D> v3 = Iota(d, 3);
+ const size_t N = Lanes(d);
+
+    const Mask<D> mask_false = MaskFalse(d);
+    const Mask<D> mask_true = MaskTrue(d);
+
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompEq(mask_true, v2, v3));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompEq(mask_true, v3, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompEq(mask_true, v2, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompEq(mask_true, v2, v2b));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompNe(mask_true, v2, v3));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompNe(mask_true, v3, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompNe(mask_true, v2, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompNe(mask_true, v2, v2b));
+
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask_true, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask_true, v0, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask_true, v0, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask_true, v0, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask_true, v2, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask_true, v2, v2));
+
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompGt(mask_true, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompLt(mask_true, v0, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask_true, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask_true, v0, v2));
+
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLe(mask_true, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGe(mask_true, v0, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompLe(mask_true, v0, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompGe(mask_true, v0, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompLe(mask_true, v2, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompGe(mask_true, v2, v2));
+
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompGe(mask_true, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompLe(mask_true, v0, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLe(mask_true, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGe(mask_true, v0, v2));
+
+    auto bool_lanes = AllocateAligned<T>(N);
+ HWY_ASSERT(bool_lanes);
+
+ for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
+ for (size_t i = 0; i < N; ++i) {
+ bool_lanes[i] = (Random32(&rng) & 1024) ? T(1) : T(0);
+ }
+
+      const Vec<D> mask_i = Load(d, bool_lanes.get());
+      const Mask<D> mask = RebindMask(d, Gt(mask_i, Zero(d)));
+
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompEq(mask, v2, v3));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompEq(mask, v3, v2));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedCompEq(mask, v2, v2));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedCompEq(mask, v2, v2b));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedCompNe(mask, v2, v3));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedCompNe(mask, v3, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompNe(mask, v2, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompNe(mask, v2, v2b));
+
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask, v0, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask, v0, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask, v0, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask, v2, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask, v2, v2));
+
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedCompGt(mask, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedCompLt(mask, v0, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask, v0, v2));
+
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLe(mask, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGe(mask, v0, v2));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedCompLe(mask, v0, v0));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedCompGe(mask, v0, v0));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedCompLe(mask, v2, v2));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedCompGe(mask, v2, v2));
+
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedCompGe(mask, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedCompLe(mask, v0, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLe(mask, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGe(mask, v0, v2));
+ }
+ }
+};
+
+HWY_NOINLINE void TestAllMaskedComparison() {
+  ForAllTypes(ForPartialVectors<TestMaskedComparison>());
+}
+
+struct TestMaskedFloatClassification {
+  template <typename T, class D>
+ HWY_NOINLINE void operator()(T /*unused*/, D d) {
+ RandomState rng;
+
+    const Vec<D> v0 = Zero(d);
+    const Vec<D> v1 = Iota(d, 2);
+    const Vec<D> v2 = Inf(d);
+    const Vec<D> v3 = NaN(d);
+ const size_t N = Lanes(d);
+
+    const Mask<D> mask_false = MaskFalse(d);
+    const Mask<D> mask_true = MaskTrue(d);
+
+ // Test against all zeros
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask_true, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedIsFinite(mask_true, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask_true, v0));
+
+ // Test against finite values
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask_true, v1));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedIsFinite(mask_true, v1));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask_true, v1));
+
+ // Test against infinite values
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedIsInf(mask_true, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsFinite(mask_true, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask_true, v2));
+
+ // Test against NaN values
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask_true, v3));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsFinite(mask_true, v3));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedIsNaN(mask_true, v3));
+
+    auto bool_lanes = AllocateAligned<T>(N);
+ HWY_ASSERT(bool_lanes);
+
+ for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
+ for (size_t i = 0; i < N; ++i) {
+ bool_lanes[i] = (Random32(&rng) & 1024) ? T(1) : T(0);
+ }
+
+      const Vec<D> mask_i = Load(d, bool_lanes.get());
+      const Mask<D> mask = RebindMask(d, Gt(mask_i, Zero(d)));
+
+ // Test against all zeros
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask, v0));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedIsFinite(mask, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask, v0));
+
+ // Test against finite values
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask, v1));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedIsFinite(mask, v1));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask, v1));
+
+ // Test against infinite values
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedIsInf(mask, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsFinite(mask, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask, v2));
+
+ // Test against NaN values
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask, v3));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsFinite(mask, v3));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedIsNaN(mask, v3));
+ }
+ }
+};
+
+HWY_NOINLINE void TestAllMaskedFloatClassification() {
+  ForFloatTypes(ForPartialVectors<TestMaskedFloatClassification>());
+}
+} // namespace
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace hwy
@@ -695,6 +864,9 @@ HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllLt128);
HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllLt128Upper);
HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllEq128);
HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllEq128Upper);
+
+HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllMaskedComparison);
+HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllMaskedFloatClassification);
HWY_AFTER_TEST();
} // namespace
} // namespace hwy