Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Float operations SqrtLower, MulSubAdd, GetExponent etc. #2425

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,10 @@ from left to right, of the arguments passed to `Create{2-4}`.
* `V`: `{f}` \
<code>V **Sqrt**(V a)</code>: returns `sqrt(a[i])`.

* `V`: `{f}` \
<code>V **SqrtLower**(V a)</code>: returns `sqrt(a[0])` in lowest lane and
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned in other PRs, I think we'd be better off with a new First1 op, and removing the other newly added *Lower ops.

`a[i]` elsewhere.

* `V`: `{f}` \
<code>V **ApproximateReciprocalSqrt**(V a)</code>: returns an approximation
of `1.0 / sqrt(a[i])`. `sqrt(a) ~= ApproximateReciprocalSqrt(a) * a`. x86
Expand All @@ -666,6 +670,10 @@ from left to right, of the arguments passed to `Create{2-4}`.
<code>V **ApproximateReciprocal**(V a)</code>: returns an approximation of
`1.0 / a[i]`.

* `V`: `{f}` \
<code>V **GetExponent**(V v)</code>: returns the exponent of `v[i]` as a floating point value.
Essentially calculates `floor(log2(x))`.

#### Min/Max

**Note**: Min/Max corner cases are target-specific and may change. If either
Expand Down Expand Up @@ -846,6 +854,10 @@ variants are somewhat slower on Arm, and unavailable for integer inputs; if the
c))` or `MulAddSub(a, b, OddEven(c, Neg(c))`, but `MulSub(a, b, c)` is more
efficient on some targets (including AVX2/AVX3).

* <code>V **MulSubAdd**(V a, V b, V c)</code>: returns `a[i] * b[i] + c[i]` in
the even lanes and `a[i] * b[i] - c[i]` in the odd lanes. Essentially,
MulAddSub with `c[i]` negated.

* `V`: `bf16`, `D`: `RepartitionToWide<DFromV<V>>`, `VW`: `Vec<D>` \
<code>VW **MulEvenAdd**(D d, V a, V b, VW c)</code>: equivalent to and
potentially more efficient than `MulAdd(PromoteEvenTo(d, a),
Expand Down Expand Up @@ -887,6 +899,21 @@ not a concern, these are equivalent to, and potentially more efficient than,
b[i]` saturated to the minimum/maximum representable value, or `no[i]` if
`m[i]` is false.

#### Zero masked arithmetic

All ops in this section return `0` for `mask=false` lanes. These are equivalent
to, and potentially more efficient than, `IfThenElseZero(m, Add(a, b));` etc.

* `V`: `{f}` \
<code>V **MaskedSqrtOrZero**(M m, V a)</code>: returns `sqrt(a[i])` where
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's also drop the OrZero suffix here.

m is true, and zero otherwise.
* `V`: `{f}` \
<code>V **MaskedApproximateReciprocalSqrtOrZero**(M m, V a)</code>: returns
the result of ApproximateReciprocalSqrt where m is true and zero otherwise.
* `V`: `{f}` \
<code>V **MaskedApproximateReciprocalOrZero**(M m, V a)</code>: returns the
result of ApproximateReciprocal where m is true and zero otherwise.

#### Shifts

**Note**: Counts not in `[0, sizeof(T)*8)` yield implementation-defined results.
Expand Down
60 changes: 60 additions & 0 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,15 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
return sv##OP##_##CHAR##BITS(v); \
}
#define HWY_SVE_RETV_ARGMV_M(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
return sv##OP##_##CHAR##BITS##_m(b, m, a); \
}
#define HWY_SVE_RETV_ARGMV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a) { \
return sv##OP##_##CHAR##BITS##_z(m, a); \
}

// vector = f(vector, scalar), e.g. detail::AddN
#define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, HALF, NAME, OP) \
Expand Down Expand Up @@ -1234,6 +1243,29 @@ HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGV, ApproximateReciprocal, recpe)
// ------------------------------ Sqrt
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Sqrt, sqrt)

// ------------------------------ MaskedSqrt
namespace detail {
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMV_M, MaskedSqrt, sqrt)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we rely on First1 and remove the SqrtLower wrapper, then we can instead expose this op as MaskedSqrtOr(V no, V a).

}

// ------------------------------ SqrtLower
#ifdef HWY_NATIVE_SQRT_LOWER
#undef HWY_NATIVE_SQRT_LOWER
#else
#define HWY_NATIVE_SQRT_LOWER
#endif

#define HWY_SVE_SQRT_LOWER(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) a) { \
return detail::MaskedSqrt(svptrue_pat_b##BITS(SV_VL1), a, a); \
}

HWY_SVE_FOREACH_F(HWY_SVE_SQRT_LOWER, SqrtLower, _)
#undef HWY_SVE_SQRT_LOWER

// ------------------------------ MaskedSqrtOrZero
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMV_Z, MaskedSqrtOrZero, sqrt)

// ------------------------------ ApproximateReciprocalSqrt
#ifdef HWY_NATIVE_F64_APPROX_RSQRT
#undef HWY_NATIVE_F64_APPROX_RSQRT
Expand Down Expand Up @@ -2883,6 +2915,34 @@ HWY_API VFromD<D> Iota(const D d, T2 first) {
ConvertScalarTo<TFromD<D>>(first));
}

// ------------------------------ GetExponent

#if HWY_SVE_HAVE_2 || HWY_IDE
#ifdef HWY_NATIVE_GET_EXPONENT
#undef HWY_NATIVE_GET_EXPONENT
#else
#define HWY_NATIVE_GET_EXPONENT
#endif

namespace detail {
#define HWY_SVE_GET_EXP(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(int, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \
}
HWY_SVE_FOREACH_F(HWY_SVE_GET_EXP, GetExponent, logb)
#undef HWY_SVE_GET_EXP
} // namespace detail

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V GetExponent(V v) {
const DFromV<V> d;
const RebindToSigned<decltype(d)> di;
const VFromD<decltype(di)> exponent_int = detail::GetExponent(v);
// convert integer to original type
return ConvertTo(d, exponent_int);
}
#endif // HWY_SVE_HAVE_2

// ------------------------------ InterleaveLower

template <class D, class V>
Expand Down
77 changes: 77 additions & 0 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1118,6 +1118,34 @@ HWY_API V MulByFloorPow2(V v, V exp) {

#endif // HWY_NATIVE_MUL_BY_POW2

// ------------------------------ GetExponent

#if (defined(HWY_NATIVE_GET_EXPONENT) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_GET_EXPONENT
#undef HWY_NATIVE_GET_EXPONENT
#else
#define HWY_NATIVE_GET_EXPONENT
#endif

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V GetExponent(V v) {
const DFromV<V> d;
using T = TFromV<V>;
const RebindToUnsigned<decltype(d)> du;
const RebindToSigned<decltype(d)> di;

constexpr uint8_t mantissa_bits = MantissaBits<T>();
const auto exponent_offset = Set(di, MaxExponentField<T>() >> 1);

// extract exponent bits as integer
const auto encoded_exponent = ShiftRight<mantissa_bits>(BitCast(du, Abs(v)));
const auto exponent_int = Sub(BitCast(di, encoded_exponent), exponent_offset);

// convert integer to original type
return ConvertTo(d, exponent_int);
}

#endif // HWY_NATIVE_GET_EXPONENT
// ------------------------------ LoadInterleaved2

#if HWY_IDE || \
Expand Down Expand Up @@ -4359,6 +4387,19 @@ HWY_API V MulAddSub(V mul, V x, V sub_or_add) {
OddEven(sub_or_add, BitCast(d, Neg(BitCast(d_negate, sub_or_add))));
return MulAdd(mul, x, add);
}
// ------------------------------ MulSubAdd

template <class V>
HWY_API V MulSubAdd(V mul, V x, V sub_or_add) {
using D = DFromV<V>;
using T = TFromD<D>;
using TNegate = If<!IsSigned<T>(), MakeSigned<T>, T>;

const D d;
const Rebind<TNegate, D> d_negate;

return MulAddSub(mul, x, BitCast(d, Neg(BitCast(d_negate, sub_or_add))));
}

// ------------------------------ Integer division
#if (defined(HWY_NATIVE_INT_DIV) == defined(HWY_TARGET_TOGGLE))
Expand Down Expand Up @@ -5184,6 +5225,30 @@ HWY_API VFromD<DI32> SatWidenMulAccumFixedPoint(DI32 di32,

#endif // HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT

// ------------------------------ SqrtLower
#if (defined(HWY_NATIVE_SQRT_LOWER) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_SQRT_LOWER
#undef HWY_NATIVE_SQRT_LOWER
#else
#define HWY_NATIVE_SQRT_LOWER
#endif

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V SqrtLower(V a) {
const DFromV<V> d;
const auto first_mask = FirstN(d, 1);
return IfThenElse(first_mask, Sqrt(a), a);
}

#undef HWY_SVE_SQRT_LOWER
#endif // HWY_NATIVE_SQRT_LOWER

// ------------------------------ MaskedSqrtOrZero
template <class V, HWY_IF_FLOAT_V(V), class M>
HWY_API V MaskedSqrtOrZero(M m, V v) {
return IfThenElseZero(m, Sqrt(v));
}

// ------------------------------ SumOfMulQuadAccumulate

#if (defined(HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE) == \
Expand Down Expand Up @@ -5368,6 +5433,12 @@ HWY_API V ApproximateReciprocal(V v) {

#endif // HWY_NATIVE_F64_APPROX_RECIP

// ------------------------------ MaskedApproximateReciprocalOrZero
template <class V, HWY_IF_FLOAT_V(V), class M>
HWY_API V MaskedApproximateReciprocalOrZero(M m, V v) {
return IfThenElseZero(m, ApproximateReciprocal(v));
}

// ------------------------------ F64 ApproximateReciprocalSqrt

#if (defined(HWY_NATIVE_F64_APPROX_RSQRT) == defined(HWY_TARGET_TOGGLE))
Expand All @@ -5393,6 +5464,12 @@ HWY_API V ApproximateReciprocalSqrt(V v) {

#endif // HWY_NATIVE_F64_APPROX_RSQRT

// ------------------------------ MaskedApproximateReciprocalSqrtOrZero
template <class V, HWY_IF_FLOAT_V(V), class M>
HWY_API V MaskedApproximateReciprocalSqrtOrZero(M m, V v) {
return IfThenElseZero(m, ApproximateReciprocalSqrt(v));
}

// ------------------------------ Compress*

#if (defined(HWY_NATIVE_COMPRESS8) == defined(HWY_TARGET_TOGGLE))
Expand Down
Loading
Loading