-
Notifications
You must be signed in to change notification settings - Fork 326
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Float operations SqrtLower, MulSubAdd, GetExponent etc. #2425
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -657,6 +657,10 @@ from left to right, of the arguments passed to `Create{2-4}`. | |
* `V`: `{f}` \ | ||
<code>V **Sqrt**(V a)</code>: returns `sqrt(a[i])`. | ||
|
||
* `V`: `{f}` \ | ||
<code>V **SqrtLower**(V a)</code>: returns `sqrt(a[0])` in lowest lane and | ||
`a[i]` elsewhere. | ||
|
||
* `V`: `{f}` \ | ||
<code>V **ApproximateReciprocalSqrt**(V a)</code>: returns an approximation | ||
of `1.0 / sqrt(a[i])`. `sqrt(a) ~= ApproximateReciprocalSqrt(a) * a`. x86 | ||
|
@@ -666,6 +670,10 @@ from left to right, of the arguments passed to `Create{2-4}`. | |
<code>V **ApproximateReciprocal**(V a)</code>: returns an approximation of | ||
`1.0 / a[i]`. | ||
|
||
* `V`: `{f}` \ | ||
<code>V **GetExponent**(V v)</code>: returns the exponent of `v[i]` as a floating point value. | ||
Essentially calculates `floor(log2(v[i]))`. | ||
|
||
#### Min/Max | ||
|
||
**Note**: Min/Max corner cases are target-specific and may change. If either | ||
|
@@ -846,6 +854,10 @@ variants are somewhat slower on Arm, and unavailable for integer inputs; if the | |
c))` or `MulAddSub(a, b, OddEven(c, Neg(c)))`, but `MulSub(a, b, c)` is more | ||
efficient on some targets (including AVX2/AVX3). | ||
|
||
* <code>V **MulSubAdd**(V a, V b, V c)</code>: returns `a[i] * b[i] + c[i]` in | ||
the even lanes and `a[i] * b[i] - c[i]` in the odd lanes. Essentially, | ||
`MulAddSub` with `c[i]` negated. | ||
|
||
* `V`: `bf16`, `D`: `RepartitionToWide<DFromV<V>>`, `VW`: `Vec<D>` \ | ||
<code>VW **MulEvenAdd**(D d, V a, V b, VW c)</code>: equivalent to and | ||
potentially more efficient than `MulAdd(PromoteEvenTo(d, a), | ||
|
@@ -887,6 +899,21 @@ not a concern, these are equivalent to, and potentially more efficient than, | |
b[i]` saturated to the minimum/maximum representable value, or `no[i]` if | ||
`m[i]` is false. | ||
|
||
#### Zero masked arithmetic | ||
|
||
All ops in this section return `0` for `mask=false` lanes. These are equivalent | ||
to, and potentially more efficient than, `IfThenElseZero(m, Add(a, b));` etc. | ||
|
||
* `V`: `{f}` \ | ||
<code>V **MaskedSqrtOrZero**(M m, V a)</code>: returns `sqrt(a[i])` where | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's also drop the OrZero suffix here. |
||
`m[i]` is true, and zero otherwise. | ||
* `V`: `{f}` \ | ||
<code>V **MaskedApproximateReciprocalSqrtOrZero**(M m, V a)</code>: returns | ||
the result of `ApproximateReciprocalSqrt(a)` where `m[i]` is true, and zero otherwise. | ||
* `V`: `{f}` \ | ||
<code>V **MaskedApproximateReciprocalOrZero**(M m, V a)</code>: returns the | ||
result of `ApproximateReciprocal(a)` where `m[i]` is true, and zero otherwise. | ||
|
||
#### Shifts | ||
|
||
**Note**: Counts not in `[0, sizeof(T)*8)` yield implementation-defined results. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -219,6 +219,15 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _) | |
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ | ||
return sv##OP##_##CHAR##BITS(v); \ | ||
} | ||
// vector = f(mask, vector, vector): defines NAME(m, a, b), which forwards to
// the `_m` form of the SVE intrinsic sv##OP. Note the argument reordering: `b`
// is passed first, then the predicate `m`, then the operand `a`. As used by
// SqrtLower, `a` is the value operated on where `m` is true and `b` supplies
// the result for the remaining lanes.
#define HWY_SVE_RETV_ARGMV_M(BASE, CHAR, BITS, HALF, NAME, OP) \ | ||
HWY_API HWY_SVE_V(BASE, BITS) \ | ||
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ | ||
return sv##OP##_##CHAR##BITS##_m(b, m, a); \ | ||
} | ||
// vector = f(mask, vector): defines NAME(m, a), which forwards to the `_z`
// form of the SVE intrinsic sv##OP with predicate `m` and operand `a`. Used
// for the zero-masked ops, which return zero in lanes where `m` is false.
#define HWY_SVE_RETV_ARGMV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \ | ||
HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a) { \ | ||
return sv##OP##_##CHAR##BITS##_z(m, a); \ | ||
} | ||
|
||
// vector = f(vector, scalar), e.g. detail::AddN | ||
#define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, HALF, NAME, OP) \ | ||
|
@@ -1234,6 +1243,29 @@ HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGV, ApproximateReciprocal, recpe) | |
// ------------------------------ Sqrt | ||
// Instantiates Sqrt using the `sqrt` intrinsic name.
// NOTE(review): HWY_SVE_RETV_ARGPV and HWY_SVE_FOREACH_F are defined outside
// this view; presumably a unary wrapper expanded once per float lane type -
// confirm.
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Sqrt, sqrt) | ||
|
||
// ------------------------------ MaskedSqrt | ||
namespace detail { | ||
// Instantiates detail::MaskedSqrt(m, a, b) from HWY_SVE_RETV_ARGMV_M with the
// `sqrt` intrinsic name. Kept in `detail`, i.e. not exposed as a public op; it
// is the building block for SqrtLower.
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMV_M, MaskedSqrt, sqrt) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we rely on First1 and remove the SqrtLower wrapper, then we can instead expose this op as MaskedSqrtOr(V no, V a). |
||
} | ||
|
||
// ------------------------------ SqrtLower | ||
// Toggle HWY_NATIVE_SQRT_LOWER: undefine it if already defined, else define it.
// NOTE(review): presumably this tells target-independent code that this target
// provides its own SqrtLower - confirm against the generic implementation.
#ifdef HWY_NATIVE_SQRT_LOWER | ||
#undef HWY_NATIVE_SQRT_LOWER | ||
#else | ||
#define HWY_NATIVE_SQRT_LOWER | ||
#endif | ||
|
||
// Defines NAME(a): calls detail::MaskedSqrt with an SV_VL1 predicate (only the
// first lane active) and passes `a` both as the operand and as the source for
// the remaining lanes, so lane 0 becomes sqrt(a[0]) and all other lanes keep
// a[i]. The OP parameter is not referenced by this macro.
#define HWY_SVE_SQRT_LOWER(BASE, CHAR, BITS, HALF, NAME, OP) \ | ||
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) a) { \ | ||
return detail::MaskedSqrt(svptrue_pat_b##BITS(SV_VL1), a, a); \ | ||
} | ||
|
||
// `_` is a placeholder for the unused OP argument.
HWY_SVE_FOREACH_F(HWY_SVE_SQRT_LOWER, SqrtLower, _) | ||
#undef HWY_SVE_SQRT_LOWER | ||
|
||
// ------------------------------ MaskedSqrtOrZero | ||
// Instantiates MaskedSqrtOrZero(m, a) from HWY_SVE_RETV_ARGMV_Z, i.e. via the
// `_z` form of the `sqrt` intrinsic.
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMV_Z, MaskedSqrtOrZero, sqrt) | ||
|
||
// ------------------------------ ApproximateReciprocalSqrt | ||
#ifdef HWY_NATIVE_F64_APPROX_RSQRT | ||
#undef HWY_NATIVE_F64_APPROX_RSQRT | ||
|
@@ -2883,6 +2915,34 @@ HWY_API VFromD<D> Iota(const D d, T2 first) { | |
ConvertScalarTo<TFromD<D>>(first)); | ||
} | ||
|
||
// ------------------------------ GetExponent | ||
|
||
// This implementation is only compiled when HWY_SVE_HAVE_2 (or HWY_IDE) is set.
#if HWY_SVE_HAVE_2 || HWY_IDE | ||
// Toggle HWY_NATIVE_GET_EXPONENT: undefine it if already defined, else define.
// NOTE(review): presumably signals that this target supplies its own
// GetExponent - confirm against the generic implementation.
#ifdef HWY_NATIVE_GET_EXPONENT | ||
#undef HWY_NATIVE_GET_EXPONENT | ||
#else | ||
#define HWY_NATIVE_GET_EXPONENT | ||
#endif | ||
|
||
namespace detail { | ||
// Defines NAME(v) returning a signed-integer vector with the same lane width
// as `v`: forwards to the `_x` form of sv##OP using the HWY_SVE_PTRUE(BITS)
// predicate (presumably all lanes active - HWY_SVE_PTRUE is defined elsewhere).
#define HWY_SVE_GET_EXP(BASE, CHAR, BITS, HALF, NAME, OP) \ | ||
HWY_API HWY_SVE_V(int, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ | ||
return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ | ||
} | ||
// Instantiated with OP = logb, giving detail::GetExponent.
HWY_SVE_FOREACH_F(HWY_SVE_GET_EXP, GetExponent, logb) | ||
#undef HWY_SVE_GET_EXP | ||
} // namespace detail | ||
|
||
// Returns the exponent of each lane of `v` as a vector of the same
// floating-point type V: the integer exponents come from detail::GetExponent
// and are then converted back via ConvertTo.
template <class V, HWY_IF_FLOAT_V(V)> | ||
HWY_API V GetExponent(V v) { | ||
const DFromV<V> d; | ||
const RebindToSigned<decltype(d)> di; | ||
const VFromD<decltype(di)> exponent_int = detail::GetExponent(v); | ||
// Convert the signed-integer exponents back to the original vector type | ||
return ConvertTo(d, exponent_int); | ||
} | ||
#endif // HWY_SVE_HAVE_2 || HWY_IDE | ||
|
||
// ------------------------------ InterleaveLower | ||
|
||
template <class D, class V> | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As mentioned in other PRs, I think we'd be better off with a new First1 op, and removing the other newly added *Lower ops.