Various masked operations #2428
base: master
Changes from all commits
d68360c
49433a8
7896095
d9a19a6
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1050,6 +1050,9 @@ types, and on SVE/RVV. | |
|
||
* <code>V **AndNot**(V a, V b)</code>: returns `~a[i] & b[i]`. | ||
|
||
* <code>V **MaskedOrOrZero**(M m, V a, V b)</code>: returns `a[i] || b[i]` | ||
Review comment: I think we mean |
||
or `zero` if `m[i]` is false. | ||
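For reference, the generic fallback added later in this PR computes `IfThenElseZero(m, Or(a, b))`, which is a bitwise OR followed by zeroing of masked-off lanes. A minimal scalar sketch of that per-lane behaviour follows; `MaskedOrOrZeroModel` is a hypothetical name used only for illustration and is not part of Highway.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Scalar reference model (illustrative only, not Highway code):
// lane i is the bitwise OR a[i] | b[i] where m[i] is true, else zero.
template <typename T, std::size_t N>
std::array<T, N> MaskedOrOrZeroModel(const std::array<bool, N>& m,
                                     const std::array<T, N>& a,
                                     const std::array<T, N>& b) {
  std::array<T, N> out{};
  for (std::size_t i = 0; i < N; ++i) {
    // The cast undoes integer promotion for narrow lane types.
    out[i] = m[i] ? static_cast<T>(a[i] | b[i]) : T{0};
  }
  return out;
}
```

With 8-bit lanes, `0x0F | 0xF0` yields `0xFF` under a true mask, whereas a logical OR would yield 1; the model makes that distinction explicit.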
|
||
The following three-argument functions may be more efficient than assembling | ||
them from 2-argument functions: | ||
|
||
|
@@ -2237,6 +2240,22 @@ The following `ReverseN` must not be called if `Lanes(D()) < N`: | |
must be in the range `[0, 2 * Lanes(d))` but need not be unique. The index | ||
type `TI` must be an integer of the same size as `TFromD<D>`. | ||
|
||
* <code>V **TableLookupLanesOr**(M m, V a, V b, unspecified)</code> returns the | ||
Review comment: It looks like we don't yet have an optimized version of this op, and it's just a convenience wrapper over IfThenElse. Would it be an option to move this into a utility function within your codebase? It's not clear whether this provides enough value to be a documented op that all readers must know. |
||
result of `TableLookupLanes(a, unspecified)` where `m[i]` is true, and returns | ||
`b[i]` where `m[i]` is false. | ||
|
||
* <code>V **TableLookupLanesOrZero**(M m, V a, unspecified)</code> returns | ||
the result of `TableLookupLanes(a, unspecified)` where `m[i]` is true, and | ||
returns zero where `m[i]` is false. | ||
|
||
* <code>V **TwoTablesLookupLanesOr**(D d, M m, V a, V b, unspecified)</code> | ||
returns the result of `TwoTablesLookupLanes(V a, V b, unspecified)` where | ||
`m[i]` is true, and `a[i]` where `m[i]` is false. | ||
|
||
* <code>V **TwoTablesLookupLanesOrZero**(D d, M m, V a, V b, unspecified)</code> | ||
returns the result of `TwoTablesLookupLanes(V a, V b, unspecified)` where | ||
`m[i]` is true, and zero where `m[i]` is false. | ||
|
||
* <code>V **Per4LaneBlockShuffle**<size_t kIdx3, size_t kIdx2, size_t | ||
kIdx1, size_t kIdx0>(V v)</code> does a per 4-lane block shuffle of `v` | ||
if `Lanes(DFromV<V>())` is greater than or equal to 4 or a shuffle of the | ||
|
@@ -2377,6 +2396,24 @@ more efficient on some targets. | |
* <code>T **ReduceMin**(D, V v)</code>: returns the minimum of all lanes. | ||
* <code>T **ReduceMax**(D, V v)</code>: returns the maximum of all lanes. | ||
|
||
### Masked reductions | ||
|
||
**Note**: Horizontal operations (across lanes of the same vector) such as | ||
reductions are slower than normal SIMD operations and are typically used outside | ||
critical loops. | ||
|
||
All ops in this section ignore lanes where `mask=false`. These are equivalent | ||
to, and potentially more efficient than, `GetLane(SumOfLanes(d, | ||
IfThenElseZero(m, v)))` etc. The result is implementation-defined when all mask | ||
elements are false. | ||
|
||
* <code>T **MaskedReduceSum**(D, M m, V v)</code>: returns the sum of all lanes | ||
Review comment: Nice! This looks useful. |
||
where `m[i]` is `true`. | ||
* <code>T **MaskedReduceMin**(D, M m, V v)</code>: returns the minimum of all | ||
lanes where `m[i]` is `true`. | ||
* <code>T **MaskedReduceMax**(D, M m, V v)</code>: returns the maximum of all | ||
lanes where `m[i]` is `true`. | ||
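The documented equivalence for the sum works because zero is the additive identity, so replacing masked-off lanes with zero does not change the total. A minimal scalar sketch of that semantics, assuming `m` and `v` have the same length; `MaskedReduceSumModel` is a hypothetical name, not a Highway function.

```cpp
#include <cstddef>
#include <vector>

// Scalar reference model (illustrative only): sums the lanes of v whose
// mask bit is true. Lanes with a false mask bit are skipped, which gives
// the same total as adding zero for them.
template <typename T>
T MaskedReduceSumModel(const std::vector<bool>& m, const std::vector<T>& v) {
  T sum = T{0};
  for (std::size_t i = 0; i < v.size() && i < m.size(); ++i) {
    if (m[i]) sum += v[i];
  }
  return sum;
}
```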
|
||
### Crypto | ||
|
||
Ops in this section are only available if `HWY_TARGET != HWY_SCALAR`: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -219,6 +219,15 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _) | |
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ | ||
return sv##OP##_##CHAR##BITS(v); \ | ||
} | ||
#define HWY_SVE_RETV_ARGMV_M(BASE, CHAR, BITS, HALF, NAME, OP) \ | ||
Review comment: Minor: we have the naming convention P for predicate, for example in |
||
HWY_API HWY_SVE_V(BASE, BITS) \ | ||
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ | ||
return sv##OP##_##CHAR##BITS##_m(b, m, a); \ | ||
} | ||
#define HWY_SVE_RETV_ARGMV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \ | ||
HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a) { \ | ||
return sv##OP##_##CHAR##BITS##_z(m, a); \ | ||
} | ||
|
||
// vector = f(vector, scalar), e.g. detail::AddN | ||
#define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, HALF, NAME, OP) \ | ||
|
@@ -252,6 +261,17 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _) | |
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ | ||
return sv##OP##_##CHAR##BITS##_x(m, a, b); \ | ||
} | ||
#define HWY_SVE_RETV_ARGMVV_M(BASE, CHAR, BITS, HALF, NAME, OP) \ | ||
HWY_API HWY_SVE_V(BASE, BITS) \ | ||
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ | ||
return sv##OP##_##CHAR##BITS##_m(m, a, b); \ | ||
} | ||
// User-specified mask. Mask=false value is zero. | ||
#define HWY_SVE_RETV_ARGMVVZ(BASE, CHAR, BITS, HALF, NAME, OP) \ | ||
HWY_API HWY_SVE_V(BASE, BITS) \ | ||
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ | ||
return sv##OP##_##CHAR##BITS##_z(m, a, b); \ | ||
} | ||
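The `_m` and `_z` suffixes used by these macros select SVE's merging and zeroing predication. For a two-operand op, the merging form leaves inactive lanes equal to the first vector operand, while the zeroing form sets them to zero. The scalar sketch below models both for a bitwise OR; the function names are hypothetical and the code does not call any SVE intrinsic.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Model of merging predication (_m): inactive lanes keep a[i].
template <std::size_t N>
std::array<std::uint32_t, N> OrMergingModel(
    const std::array<bool, N>& m, const std::array<std::uint32_t, N>& a,
    const std::array<std::uint32_t, N>& b) {
  std::array<std::uint32_t, N> out{};
  for (std::size_t i = 0; i < N; ++i) {
    out[i] = m[i] ? (a[i] | b[i]) : a[i];
  }
  return out;
}

// Model of zeroing predication (_z): inactive lanes become zero.
template <std::size_t N>
std::array<std::uint32_t, N> OrZeroingModel(
    const std::array<bool, N>& m, const std::array<std::uint32_t, N>& a,
    const std::array<std::uint32_t, N>& b) {
  std::array<std::uint32_t, N> out{};
  for (std::size_t i = 0; i < N; ++i) {
    out[i] = m[i] ? (a[i] | b[i]) : 0u;
  }
  return out;
}
```

The zeroing model corresponds to what `HWY_SVE_RETV_ARGMVVZ` is used for in this diff (`MaskedOrOrZero`).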
|
||
#define HWY_SVE_RETV_ARGVVV(BASE, CHAR, BITS, HALF, NAME, OP) \ | ||
HWY_API HWY_SVE_V(BASE, BITS) \ | ||
|
@@ -260,6 +280,13 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _) | |
return sv##OP##_##CHAR##BITS(a, b, c); \ | ||
} | ||
|
||
#define HWY_SVE_RETV_ARGMVVV(BASE, CHAR, BITS, HALF, NAME, OP) \ | ||
HWY_API HWY_SVE_V(BASE, BITS) \ | ||
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \ | ||
HWY_SVE_V(BASE, BITS) c) { \ | ||
return sv##OP##_##CHAR##BITS##_m(m, a, b, c); \ | ||
} | ||
|
||
// ------------------------------ Lanes | ||
|
||
namespace detail { | ||
|
@@ -727,6 +754,9 @@ HWY_API V Or(const V a, const V b) { | |
return BitCast(df, Or(BitCast(du, a), BitCast(du, b))); | ||
} | ||
|
||
// ------------------------------ MaskedOrOrZero | ||
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVVZ, MaskedOrOrZero, orr) | ||
|
||
// ------------------------------ Xor | ||
|
||
namespace detail { | ||
|
@@ -3288,6 +3318,25 @@ HWY_API TFromD<D> ReduceMax(D d, VFromD<D> v) { | |
return detail::MaxOfLanesM(detail::MakeMask(d), v); | ||
} | ||
|
||
#ifdef HWY_NATIVE_MASKED_REDUCE_SCALAR | ||
#undef HWY_NATIVE_MASKED_REDUCE_SCALAR | ||
#else | ||
#define HWY_NATIVE_MASKED_REDUCE_SCALAR | ||
#endif | ||
|
||
template <class D, class M> | ||
Review comment: Please add a TODO here that we can remove the SumOfLanesM in favor of using MaskedReduceSum directly. This entails adding the D arg to |
||
HWY_API TFromD<D> MaskedReduceSum(D /*d*/, M m, VFromD<D> v) { | ||
return detail::SumOfLanesM(m, v); | ||
} | ||
template <class D, class M> | ||
HWY_API TFromD<D> MaskedReduceMin(D /*d*/, M m, VFromD<D> v) { | ||
return detail::MinOfLanesM(m, v); | ||
} | ||
template <class D, class M> | ||
HWY_API TFromD<D> MaskedReduceMax(D /*d*/, M m, VFromD<D> v) { | ||
return detail::MaxOfLanesM(m, v); | ||
} | ||
|
||
// ------------------------------ SumOfLanes | ||
|
||
template <class D, HWY_IF_LANES_GT_D(D, 1)> | ||
|
@@ -4755,6 +4804,23 @@ HWY_API V IfNegativeThenElse(V v, V yes, V no) { | |
static_assert(IsSigned<TFromV<V>>(), "Only works for signed/float"); | ||
return IfThenElse(IsNegative(v), yes, no); | ||
} | ||
// ------------------------------ IfNegativeThenNegOrUndefIfZero | ||
Review comment: This op is undocumented, do we intend to add it? If so, let's add documentation and test. |
||
|
||
#ifdef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG | ||
#undef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG | ||
#else | ||
#define HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG | ||
#endif | ||
|
||
#define HWY_SVE_NEG_IF(BASE, CHAR, BITS, HALF, NAME, OP) \ | ||
HWY_API HWY_SVE_V(BASE, BITS) \ | ||
NAME(HWY_SVE_V(BASE, BITS) mask, HWY_SVE_V(BASE, BITS) v) { \ | ||
return sv##OP##_##CHAR##BITS##_m(v, IsNegative(mask), v); \ | ||
} | ||
|
||
HWY_SVE_FOREACH_IF(HWY_SVE_NEG_IF, IfNegativeThenNegOrUndefIfZero, neg) | ||
|
||
#undef HWY_SVE_NEG_IF | ||
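Since the op is undocumented, a scalar sketch of what the SVE macro above appears to compute may help: a merging negate whose predicate is `IsNegative(mask)`, so lanes with a negative `mask` value yield `-v[i]` and all other lanes keep `v[i]`. The name `NegIfMaskNegativeModel` is hypothetical. Per the op's name, the result for lanes where `mask[i]` is zero should probably be treated as unspecified even though this model (like the SVE path) returns `v[i]`.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar reference model (illustrative only). Assumes mask and v have the
// same length and that no negated lane holds INT32_MIN, whose negation
// overflows in C++.
std::vector<std::int32_t> NegIfMaskNegativeModel(
    const std::vector<std::int32_t>& mask,
    const std::vector<std::int32_t>& v) {
  std::vector<std::int32_t> out(v.size());
  for (std::size_t i = 0; i < v.size(); ++i) {
    const bool negate = i < mask.size() && mask[i] < 0;
    out[i] = negate ? -v[i] : v[i];
  }
  return out;
}
```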
|
||
// ------------------------------ AverageRound (ShiftRight) | ||
|
||
|
@@ -6291,13 +6357,19 @@ HWY_API V HighestSetBitIndex(V v) { | |
#undef HWY_SVE_IF_NOT_EMULATED_D | ||
#undef HWY_SVE_PTRUE | ||
#undef HWY_SVE_RETV_ARGMVV | ||
#undef HWY_SVE_RETV_ARGMVVZ | ||
#undef HWY_SVE_RETV_ARGPV | ||
#undef HWY_SVE_RETV_ARGPVN | ||
#undef HWY_SVE_RETV_ARGPVV | ||
#undef HWY_SVE_RETV_ARGV | ||
#undef HWY_SVE_RETV_ARGVN | ||
#undef HWY_SVE_RETV_ARGMV | ||
#undef HWY_SVE_RETV_ARGMV_M | ||
#undef HWY_SVE_RETV_ARGMV_Z | ||
#undef HWY_SVE_RETV_ARGVV | ||
#undef HWY_SVE_RETV_ARGMVV_M | ||
#undef HWY_SVE_RETV_ARGVVV | ||
#undef HWY_SVE_RETV_ARGMVVV | ||
#undef HWY_SVE_T | ||
#undef HWY_SVE_UNDEFINED | ||
#undef HWY_SVE_V | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -882,6 +882,28 @@ HWY_API TFromD<D> ReduceMax(D d, VFromD<D> v) { | |
} | ||
#endif // HWY_NATIVE_REDUCE_MINMAX_4_UI8 | ||
|
||
#if (defined(HWY_NATIVE_MASKED_REDUCE_SCALAR) == defined(HWY_TARGET_TOGGLE)) | ||
#ifdef HWY_NATIVE_MASKED_REDUCE_SCALAR | ||
#undef HWY_NATIVE_MASKED_REDUCE_SCALAR | ||
#else | ||
#define HWY_NATIVE_MASKED_REDUCE_SCALAR | ||
#endif | ||
|
||
template <class D, class M> | ||
HWY_API TFromD<D> MaskedReduceSum(D d, M m, VFromD<D> v) { | ||
return ReduceSum(d, IfThenElseZero(m, v)); | ||
} | ||
template <class D, class M> | ||
HWY_API TFromD<D> MaskedReduceMin(D d, M m, VFromD<D> v) { | ||
return ReduceMin(d, IfThenElse(m, v, MaxOfLanes(d, v))); | ||
Review comment: This seems unnecessarily expensive, how about we replace MaxOfLanes with Set(d, hwy::HighestValue)? |
||
} | ||
template <class D, class M> | ||
HWY_API TFromD<D> MaskedReduceMax(D d, M m, VFromD<D> v) { | ||
return ReduceMax(d, IfThenElseZero(m, v)); | ||
Review comment: I think we can get into trouble for signed values. If all values are negative, the presence of mask=false elements changes the result. Can similarly use hwy::LowestValue here? |
||
} | ||
|
||
#endif // HWY_NATIVE_MASKED_REDUCE_SCALAR | ||
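The two review comments on this hunk come down to choosing the right identity element for masked-off lanes: the largest representable value for a minimum and the smallest for a maximum. The scalar sketch below illustrates the signed-value problem and the suggested fix. It uses `std::numeric_limits` as a stand-in for `hwy::HighestValue` and `hwy::LowestValue` (an assumption made for self-containment), and all function names are hypothetical.

```cpp
#include <algorithm>
#include <cstddef>
#include <limits>
#include <vector>

// Max over lanes with a true mask bit; masked-off lanes contribute the
// lowest representable value, which never wins a max.
template <typename T>
T MaskedReduceMaxModel(const std::vector<bool>& m, const std::vector<T>& v) {
  T result = std::numeric_limits<T>::lowest();
  for (std::size_t i = 0; i < v.size() && i < m.size(); ++i) {
    if (m[i]) result = std::max(result, v[i]);
  }
  return result;
}

// Min over lanes with a true mask bit; masked-off lanes contribute the
// highest representable value.
template <typename T>
T MaskedReduceMinModel(const std::vector<bool>& m, const std::vector<T>& v) {
  T result = std::numeric_limits<T>::max();
  for (std::size_t i = 0; i < v.size() && i < m.size(); ++i) {
    if (m[i]) result = std::min(result, v[i]);
  }
  return result;
}

// Mirrors the zero-fill approach in the diff: masked-off lanes become 0
// before the max, which is wrong when every selected lane is negative.
template <typename T>
T ZeroFilledMaxModel(const std::vector<bool>& m, const std::vector<T>& v) {
  T result = std::numeric_limits<T>::lowest();
  for (std::size_t i = 0; i < v.size() && i < m.size(); ++i) {
    result = std::max(result, m[i] ? v[i] : T{0});
  }
  return result;
}
```

For `v = {-5, -3, -9}` with the middle lane masked off, the zero-filled version returns 0, a value that is not in any selected lane, while the identity-based version returns -5.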
|
||
// ------------------------------ IsEitherNaN | ||
#if (defined(HWY_NATIVE_IS_EITHER_NAN) == defined(HWY_TARGET_TOGGLE)) | ||
#ifdef HWY_NATIVE_IS_EITHER_NAN | ||
|
@@ -6444,6 +6466,30 @@ HWY_API V ReverseBits(V v) { | |
} | ||
#endif // HWY_NATIVE_REVERSE_BITS_UI16_32_64 | ||
|
||
// ------------------------------ TableLookupLanesOr | ||
template <class V, class M> | ||
HWY_API V TableLookupLanesOr(M m, V a, V b, IndicesFromD<DFromV<V>> idx) { | ||
return IfThenElse(m, TableLookupLanes(a, idx), b); | ||
} | ||
|
||
// ------------------------------ TableLookupLanesOrZero | ||
template <class V, class M> | ||
HWY_API V TableLookupLanesOrZero(M m, V a, IndicesFromD<DFromV<V>> idx) { | ||
return IfThenElseZero(m, TableLookupLanes(a, idx)); | ||
} | ||
|
||
// ------------------------------ TwoTablesLookupLanesOr | ||
template <class D, class V, class M> | ||
HWY_API V TwoTablesLookupLanesOr(D d, M m, V a, V b, IndicesFromD<D> idx) { | ||
return IfThenElse(m, TwoTablesLookupLanes(d, a, b, idx), a); | ||
} | ||
|
||
// ------------------------------ TwoTablesLookupLanesOrZero | ||
template <class D, class V, class M> | ||
HWY_API V TwoTablesLookupLanesOrZero(D d, M m, V a, V b, IndicesFromD<D> idx) { | ||
return IfThenElse(m, TwoTablesLookupLanes(d, a, b, idx), Zero(d)); | ||
} | ||
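The four wrappers above share one pattern: perform the lookup, then select per lane between the looked-up value and a fallback. A minimal scalar sketch of the single-table variant, assuming `m`, `a`, `b`, and `idx` all have the same length and every index is in range; `TableLookupLanesOrModel` is a hypothetical name, not a Highway function.

```cpp
#include <cstddef>
#include <vector>

// Scalar reference model (illustrative only): lane i is a[idx[i]] where
// m[i] is true, and b[i] otherwise.
template <typename T>
std::vector<T> TableLookupLanesOrModel(const std::vector<bool>& m,
                                       const std::vector<T>& a,
                                       const std::vector<T>& b,
                                       const std::vector<std::size_t>& idx) {
  std::vector<T> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    out[i] = m[i] ? a[idx[i]] : b[i];
  }
  return out;
}
```

The `OrZero` variants replace the `b[i]` fallback with zero, and `TwoTablesLookupLanesOr` as written in this diff falls back to `a[i]`.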
|
||
// ------------------------------ Per4LaneBlockShuffle | ||
|
||
#if (defined(HWY_NATIVE_PER4LANEBLKSHUF_DUP32) == defined(HWY_TARGET_TOGGLE)) | ||
|
@@ -7299,6 +7345,10 @@ HWY_API V BitShuffle(V v, VI idx) { | |
|
||
#endif // HWY_NATIVE_BITSHUFFLE | ||
|
||
template <class V, class M> | ||
HWY_API V MaskedOrOrZero(M m, V a, V b) { | ||
return IfThenElseZero(m, Or(a, b)); | ||
} | ||
// ================================================== Operator wrapper | ||
|
||
// SVE* and RVV currently cannot define operators and have already defined | ||
|
Review comment: How about a different naming convention here which might be a bit more natural?
There is also a MaskedLoad, which returns 0 as the default, as opposed to MaskedLoadOr, which takes an explicit default value. If we apply that here, we can just call it MaskedOr(m, a, b). What do you think?