diff --git a/numpy/_core/meson.build b/numpy/_core/meson.build index 3bffed752474..01c333209d6b 100644 --- a/numpy/_core/meson.build +++ b/numpy/_core/meson.build @@ -965,13 +965,14 @@ foreach gen_mtargets : [ ], [ 'loops_minmax.dispatch.h', - src_file.process('src/umath/loops_minmax.dispatch.c.src'), + 'src/umath/loops_minmax.dispatch.cpp', [ ASIMD, NEON, AVX512_SKX, AVX2, SSE2, VSX2, VXE, VX, LSX, + RVV, ] ], [ diff --git a/numpy/_core/src/common/simd/README.md b/numpy/_core/src/common/simd/README.md new file mode 100644 index 000000000000..9a68d1aa1bfc --- /dev/null +++ b/numpy/_core/src/common/simd/README.md @@ -0,0 +1,263 @@ +# NumPy SIMD Wrapper for Highway + +This directory contains a lightweight C++ wrapper over Google's [Highway](https://github.com/google/highway) SIMD library, designed specifically for NumPy's needs. + +> **Note**: This directory also contains the C interface of universal intrinsics (under `simd.h`) which is no longer supported. The Highway wrapper described in this document should be used instead for all new SIMD code. + +## Overview + +The wrapper simplifies Highway's SIMD interface by eliminating class tags and using lane types directly, which can be deduced from arguments in most cases. This design makes the SIMD code more intuitive and easier to maintain while still leveraging Highway generic intrinsics. + +## Architecture + +The wrapper consists of two main headers: + +1. `simd.hpp`: The main header that defines namespaces and includes configuration macros +2. `simd.inc.hpp`: Implementation details included by `simd.hpp` multiple times for different namespaces + +Additionally, this directory contains legacy C interface files for universal intrinsics (`simd.h` and related files) which are deprecated and should not be used for new code. All new SIMD code should use the Highway wrapper. + + +## Usage + +### Basic Usage + +```cpp +#include "simd/simd.hpp" + +// Use np::simd for maximum width SIMD operations +using namespace np::simd; +float *data = /* ... */; +Vec v = LoadU(data); +v = Add(v, v); +StoreU(v, data); + +// Use np::simd128 for fixed 128-bit SIMD operations +using namespace np::simd128; +Vec v128 = LoadU(data); +v128 = Add(v128, v128); +StoreU(v128, data); +``` + +### Checking for SIMD Support + +```cpp +#include "simd/simd.hpp" + +// Check if SIMD is enabled +#if NPY_SIMDX + // SIMD code +#else + // Scalar fallback code +#endif + +// Check for float64 support +#if NPY_SIMDX_F64 + // Use float64 SIMD operations +#endif + +// Check for FMA support +#if NPY_SIMDX_FMA + // Use FMA operations +#endif +``` + +## Type Support and Constraints + +The wrapper provides type constraints to help with SFINAE (Substitution Failure Is Not An Error) and compile-time type checking: + +- `kSupportLane`: Determines whether the specified lane type is supported by the SIMD extension. + ```cpp + // Base template - always defined, even when SIMD is not enabled (for SFINAE) + template + constexpr bool kSupportLane = NPY_SIMDX != 0; + template <> + constexpr bool kSupportLane = NPY_SIMDX_F64 != 0; + ``` + +- `kMaxLanes`: Maximum number of lanes supported by the SIMD extension for the specified lane type. 
+ ```cpp + template + constexpr size_t kMaxLanes = HWY_MAX_LANES_D(_Tag); + ``` + +```cpp +#include "simd/simd.hpp" + +// Check if float64 operations are supported +if constexpr (np::simd::kSupportLane) { + // Use float64 operations +} +``` + +These constraints allow for compile-time checking of which lane types are supported, which can be used in SFINAE contexts to enable or disable functions based on type support. + +## Available Operations + +The wrapper provides the following common operations that are used in NumPy: + +- Vector creation operations: + - `Zero`: Returns a vector with all lanes set to zero + - `Set`: Returns a vector with all lanes set to the given value + - `Undefined`: Returns an uninitialized vector + +- Memory operations: + - `LoadU`: Unaligned load of a vector from memory + - `StoreU`: Unaligned store of a vector to memory + +- Vector information: + - `Lanes`: Returns the number of vector lanes based on the lane type + +- Type conversion: + - `BitCast`: Reinterprets a vector to a different type without modifying the underlying data + - `VecFromMask`: Converts a mask to a vector + +- Comparison operations: + - `Eq`: Element-wise equality comparison + - `Le`: Element-wise less than or equal comparison + - `Lt`: Element-wise less than comparison + - `Gt`: Element-wise greater than comparison + - `Ge`: Element-wise greater than or equal comparison + +- Arithmetic operations: + - `Add`: Element-wise addition + - `Sub`: Element-wise subtraction + - `Mul`: Element-wise multiplication + - `Div`: Element-wise division + - `Min`: Element-wise minimum + - `Max`: Element-wise maximum + - `Abs`: Element-wise absolute value + - `Sqrt`: Element-wise square root + +- Logical operations: + - `And`: Bitwise AND + - `Or`: Bitwise OR + - `Xor`: Bitwise XOR + - `AndNot`: Bitwise AND NOT (a & ~b) + +Additional Highway operations can be accessed via the `hn` namespace alias inside the `simd` or `simd128` namespaces. + +## Extending + +To add more operations from Highway: + +1. Import them in the `simd.inc.hpp` file using the `using` directive if they don't require a tag: + ```cpp + // For operations that don't require a tag + using hn::FunctionName; + ``` + +2. Define wrapper functions for intrinsics that require a class tag: + ```cpp + // For operations that require a tag + template + HWY_API ReturnType FunctionName(Args... args) { + return hn::FunctionName(_Tag(), args...); + } + ``` + +3. Add appropriate documentation and SFINAE constraints if needed + + +## Build Configuration + +The SIMD wrapper automatically disables SIMD operations when optimizations are disabled: + +- When `NPY_DISABLE_OPTIMIZATION` is defined, SIMD operations are disabled +- SIMD is enabled only when the Highway target is not scalar (`HWY_TARGET != HWY_SCALAR`) + +## Design Notes + +1. **Why avoid Highway scalar operations?** + - NumPy already provides kernels for scalar operations + - Compilers can better optimize standard library implementations + - Not all Highway intrinsics are fully supported in scalar mode + +2. **Legacy Universal Intrinsics** + - The older universal intrinsics C interface (in `simd.h` and accessible via `NPY_SIMD` macros) is deprecated + - All new SIMD code should use this Highway-based wrapper (accessible via `NPY_SIMDX` macros) + - The legacy code is maintained for compatibility but will eventually be removed + +3. **Feature Detection Constants vs. 
Highway Constants**
+   - NumPy-specific constants (`NPY_SIMDX_F16`, `NPY_SIMDX_F64`, `NPY_SIMDX_FMA`) provide additional safety beyond raw Highway constants
+   - Highway constants (e.g., `HWY_HAVE_FLOAT16`) only check platform capabilities but don't consider NumPy's build configuration
+   - Our constants combine both checks:
+     ```cpp
+     #define NPY_SIMDX_F16 (NPY_SIMDX && HWY_HAVE_FLOAT16)
+     ```
+   - This ensures SIMD features won't be used when:
+     - The platform supports them but NumPy optimization is disabled via the meson option:
+       ```
+       option('disable-optimization', type: 'boolean', value: false,
+              description: 'Disable CPU optimized code (dispatch,simd,unroll...)')
+       ```
+     - The Highway target is scalar (`HWY_TARGET == HWY_SCALAR`)
+   - Using these constants ensures consistent behavior across different compilation settings
+   - Without this additional layer, code might incorrectly try to use SIMD paths in scalar mode
+
+4. **Namespace Design**
+   - `np::simd`: Maximum width SIMD operations (scalable)
+   - `np::simd128`: Fixed 128-bit SIMD operations
+   - `hn`: Highway namespace alias (available within the SIMD namespaces)
+
+5. **Why Namespaces and Why Not Just Use Highway Directly?**
+   - Highway's design uses class tag types as template parameters (e.g., `Vec<ScalableTag<float>>`) when defining vector types
+   - Many Highway functions require explicitly passing a tag instance as the first parameter
+   - This class tag-based approach increases verbosity and complexity in user code
+   - Our wrapper eliminates this by internally managing tags through namespaces, letting users work with lane-typed vectors directly, e.g. `Vec<float>`
+   - Simple example with raw Highway:
+     ```cpp
+     // Highway's approach
+     float *data = /* ... */;
+
+     namespace hn = hwy::HWY_NAMESPACE;
+     using namespace hn;
+
+     // Full-width operations
+     ScalableTag<float> df;                         // Create a tag instance
+     Vec<ScalableTag<float>> v = LoadU(df, data);   // LoadU requires a tag instance
+     StoreU(v, df, data);                           // StoreU requires a tag instance
+
+     // 128-bit operations
+     Full128<float> df128;                          // Create a 128-bit tag instance
+     Vec<Full128<float>> v128 = LoadU(df128, data); // LoadU requires a tag instance
+     StoreU(v128, df128, data);                     // StoreU requires a tag instance
+     ```
+
+   - Simple example with our wrapper:
+     ```cpp
+     // Our wrapper approach
+     float *data = /* ... */;
+
+     // Full-width operations
+     using namespace np::simd;
+     Vec<float> v = LoadU(data);      // Full-width vector load
+     StoreU(v, data);
+
+     // 128-bit operations
+     using namespace np::simd128;
+     Vec<float> v128 = LoadU(data);   // 128-bit vector load
+     StoreU(v128, data);
+     ```
+
+   - The namespaced approach simplifies code, reduces errors, and provides a more intuitive interface
+   - It preserves all the benefits of Highway's operations while reducing cognitive overhead
+
+6. **Why Namespaces Are Essential for This Design?**
+   - Namespaces allow us to define different internal tag types (`hn::ScalableTag<TLane>` in `np::simd` vs `hn::Full128<TLane>` in `np::simd128`)
+   - This provides a consistent type-based interface (`Vec<TLane>`) without requiring users to manually create tags
+   - Enables using the same function names (like `LoadU`) with different implementations based on SIMD width
+   - Without namespaces, we'd have to either reintroduce tags (defeating the purpose of the wrapper) or create different function names for each variant (e.g., `LoadU` vs `LoadU128`); the sketch below shows the shared spelling in practice
+
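+   - A minimal sketch (the helper name `AddFirstBlocks` is hypothetical, not part of this patch) showing the same call pattern compiling against either namespace; only the `using` directive changes:
+     ```cpp
+     #include "simd/simd.hpp"
+
+     #if NPY_SIMDX
+     // Writes the leading lanes of `out` twice, once per width; the point is only
+     // that both widths use identical function names (LoadU, Add, StoreU).
+     HWY_ATTR void AddFirstBlocks(const float *a, const float *b, float *out)
+     {
+         {
+             using namespace np::simd;      // scalable, maximum-width vectors
+             StoreU(Add(LoadU(a), LoadU(b)), out);
+         }
+         {
+             using namespace np::simd128;   // fixed 128-bit vectors, same spelling
+             StoreU(Add(LoadU(a), LoadU(b)), out);
+         }
+     }
+     #endif
+     ```
+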
+7. **Template Type Parameters**
+   - `TLane`: The scalar type for each vector lane (e.g., `uint8_t`, `float`, `double`)
+
+## Requirements
+
+- C++17 or later
+- Google Highway library
+
+## License
+
+Same as NumPy's license
diff --git a/numpy/_core/src/common/simd/simd.hpp b/numpy/_core/src/common/simd/simd.hpp
new file mode 100644
index 000000000000..698da4adf865
--- /dev/null
+++ b/numpy/_core/src/common/simd/simd.hpp
@@ -0,0 +1,80 @@
+#ifndef NUMPY__CORE_SRC_COMMON_SIMD_SIMD_HPP_
+#define NUMPY__CORE_SRC_COMMON_SIMD_SIMD_HPP_
+
+/**
+ * This header provides a thin wrapper over Google's Highway SIMD library.
+ *
+ * The wrapper aims to simplify the SIMD interface of Google's Highway by
+ * getting rid of its class tags and using lane types directly, which can be
+ * deduced from the arguments in most cases.
+ */
+/**
+ * Since `NPY_SIMD` is limited to NumPy's C universal intrinsics,
+ * `NPY_SIMDX` is defined to indicate SIMD availability for Google's Highway
+ * C++ code.
+ *
+ * Highway SIMD is only available when optimization is enabled.
+ * When NPY_DISABLE_OPTIMIZATION is defined, SIMD operations are disabled
+ * and the code falls back to scalar implementations.
+ */
+#ifndef NPY_DISABLE_OPTIMIZATION
+#include <hwy/highway.h>
+
+/**
+ * We avoid using Highway scalar operations for the following reasons:
+ * 1. We already provide kernels for scalar operations, so falling back to
+ *    the NumPy implementation is more appropriate. Compilers can often
+ *    optimize these better since they rely on standard libraries.
+ * 2. Not all Highway intrinsics are fully supported in scalar mode.
+ *
+ * Therefore, we only enable SIMD when the Highway target is not scalar.
+ */
+#define NPY_SIMDX (HWY_TARGET != HWY_SCALAR)
+
+// Indicates if the SIMD operations are available for float16.
+#define NPY_SIMDX_F16 (NPY_SIMDX && HWY_HAVE_FLOAT16)
+// Note: Highway requires SIMD extensions with native float32 support, so we don't need
+// to check for it.
+
+// Indicates if the SIMD operations are available for float64.
+#define NPY_SIMDX_F64 (NPY_SIMDX && HWY_HAVE_FLOAT64)
+
+// Indicates if the SIMD floating-point operations natively support FMA.
+#define NPY_SIMDX_FMA (NPY_SIMDX && HWY_NATIVE_FMA)
+
+#else
+#define NPY_SIMDX 0
+#define NPY_SIMDX_F16 0
+#define NPY_SIMDX_F64 0
+#define NPY_SIMDX_FMA 0
+#endif
+
+namespace np {
+
+/// Represents the max SIMD width supported by the platform.
+namespace simd {
+#if NPY_SIMDX
+/// The Highway namespace alias.
+/// We cannot import all the symbols from HWY_NAMESPACE because they would
+/// conflict with the existing symbols in the numpy namespace.
+namespace hn = hwy::HWY_NAMESPACE;
+// Internally used by the template header.
+template <typename TLane>
+using _Tag = hn::ScalableTag<TLane>;
+#endif
+#include "simd.inc.hpp"
+} // namespace simd
+
+/// Represents the 128-bit SIMD width.
+namespace simd128 {
+#if NPY_SIMDX
+namespace hn = hwy::HWY_NAMESPACE;
+template <typename TLane>
+using _Tag = hn::Full128<TLane>;
+#endif
+#include "simd.inc.hpp"
+} // namespace simd128
+
+} // namespace np
+
+#endif // NUMPY__CORE_SRC_COMMON_SIMD_SIMD_HPP_
diff --git a/numpy/_core/src/common/simd/simd.inc.hpp b/numpy/_core/src/common/simd/simd.inc.hpp
new file mode 100644
index 000000000000..64d28bc47118
--- /dev/null
+++ b/numpy/_core/src/common/simd/simd.inc.hpp
@@ -0,0 +1,132 @@
+#ifndef NPY_SIMDX
+#error "This is not a standalone header. Include simd.hpp instead."
+#define NPY_SIMDX 1 // Prevent editors from graying out the happy branch
+#endif
+
+// An anonymous namespace is used instead of an inline namespace so that each
+// translation unit gets its own copy of these constants, based on its local
+// compilation flags.
+namespace {
+
+// NOTE: This file is included by simd.hpp multiple times with different namespaces,
+// so avoid including any headers here.
+
+/**
+ * Determines whether the specified lane type is supported by the SIMD extension.
+ * Always defined as false when SIMD is not enabled, so it can be used in SFINAE.
+ *
+ * @tparam TLane The lane type to check for support.
+ */
+template <typename TLane>
+constexpr bool kSupportLane = NPY_SIMDX != 0;
+
+#if NPY_SIMDX
+// Define lane type support based on Highway capabilities
+template <>
+constexpr bool kSupportLane<hwy::float16_t> = HWY_HAVE_FLOAT16 != 0;
+template <>
+constexpr bool kSupportLane<double> = HWY_HAVE_FLOAT64 != 0;
+template <>
+constexpr bool kSupportLane<long double> =
+    HWY_HAVE_FLOAT64 != 0 && sizeof(long double) == sizeof(double);
+
+/// Maximum number of lanes supported by the SIMD extension for the specified lane type.
+template <typename TLane>
+constexpr size_t kMaxLanes = HWY_MAX_LANES_D(_Tag<TLane>);
+
+/// Represents an N-lane vector based on the specified lane type.
+/// @tparam TLane The scalar type for each vector lane
+template <typename TLane>
+using Vec = hn::Vec<_Tag<TLane>>;
+
+/// Represents a mask vector with boolean values or as a bitmask.
+/// @tparam TLane The scalar type the mask corresponds to
+template <typename TLane>
+using Mask = hn::Mask<_Tag<TLane>>;
+
+/// Unaligned load of a vector from memory.
+template <typename TLane>
+HWY_API Vec<TLane>
+LoadU(const TLane *ptr)
+{
+    return hn::LoadU(_Tag<TLane>(), ptr);
+}
+
+/// Unaligned store of a vector to memory.
+template <typename TLane>
+HWY_API void
+StoreU(const Vec<TLane> &a, TLane *ptr)
+{
+    hn::StoreU(a, _Tag<TLane>(), ptr);
+}
+
+/// Returns the number of vector lanes based on the lane type.
+template <typename TLane>
+HWY_API HWY_LANES_CONSTEXPR size_t
+Lanes(TLane tag = 0)
+{
+    return hn::Lanes(_Tag<TLane>());
+}
+
+/// Returns an uninitialized N-lane vector.
+template <typename TLane>
+HWY_API Vec<TLane>
+Undefined(TLane tag = 0)
+{
+    return hn::Undefined(_Tag<TLane>());
+}
+
+/// Returns an N-lane vector with all lanes set to zero.
+template <typename TLane>
+HWY_API Vec<TLane>
+Zero(TLane tag = 0)
+{
+    return hn::Zero(_Tag<TLane>());
+}
+
+/// Returns an N-lane vector with all lanes set to the given value of type `TLane`.
+template <typename TLane>
+HWY_API Vec<TLane>
+Set(TLane val)
+{
+    return hn::Set(_Tag<TLane>(), val);
+}
+
+/// Converts a mask to a vector based on the specified lane type.
+template <typename TLane, typename TMask>
+HWY_API Vec<TLane>
+VecFromMask(const TMask &m)
+{
+    return hn::VecFromMask(_Tag<TLane>(), m);
+}
+
+/// Converts (reinterprets) an N-lane vector to a different type without modifying the
+/// underlying data.
+template +HWY_API Vec +BitCast(const TVec &v) +{ + return hn::BitCast(_Tag(), v); +} + +// Import common Highway intrinsics +using hn::Abs; +using hn::Add; +using hn::And; +using hn::AndNot; +using hn::Div; +using hn::Eq; +using hn::Ge; +using hn::Gt; +using hn::Le; +using hn::Lt; +using hn::Max; +using hn::Min; +using hn::Mul; +using hn::Or; +using hn::Sqrt; +using hn::Sub; +using hn::Xor; + +#endif // NPY_SIMDX + +} // namespace anonymous diff --git a/numpy/_core/src/highway b/numpy/_core/src/highway index 0b696633f9ad..12b325bc1793 160000 --- a/numpy/_core/src/highway +++ b/numpy/_core/src/highway @@ -1 +1 @@ -Subproject commit 0b696633f9ad89497dd5532b55eaa01625ad71ca +Subproject commit 12b325bc1793dee68ab2157995a690db859fe9e0 diff --git a/numpy/_core/src/umath/loops_hyperbolic.dispatch.cpp.src b/numpy/_core/src/umath/loops_hyperbolic.dispatch.cpp.src index 8c66229942ee..93d288fbdb2e 100755 --- a/numpy/_core/src/umath/loops_hyperbolic.dispatch.cpp.src +++ b/numpy/_core/src/umath/loops_hyperbolic.dispatch.cpp.src @@ -385,7 +385,7 @@ simd_tanh_f64(const double *src, npy_intp ssrc, double *dst, npy_intp sdst, npy_ vec_f64 b, c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16; if constexpr(hn::MaxLanes(f64) == 2){ vec_f64 e0e1_0, e0e1_1; - uint64_t index[hn::Lanes(f64)]; + uint64_t index[hn::MaxLanes(f64)]; hn::StoreU(idx, u64, index); /**begin repeat diff --git a/numpy/_core/src/umath/loops_minmax.dispatch.c.src b/numpy/_core/src/umath/loops_minmax.dispatch.c.src deleted file mode 100644 index c11f391f9159..000000000000 --- a/numpy/_core/src/umath/loops_minmax.dispatch.c.src +++ /dev/null @@ -1,476 +0,0 @@ -#define _UMATHMODULE -#define _MULTIARRAYMODULE -#define NPY_NO_DEPRECATED_API NPY_API_VERSION - -#include "simd/simd.h" -#include "loops_utils.h" -#include "loops.h" -#include "lowlevel_strided_loops.h" -// Provides the various *_LOOP macros -#include "fast_loop_macros.h" - -/******************************************************************************* - ** Scalar intrinsics - ******************************************************************************/ -// signed/unsigned int -#define scalar_max_i(A, B) ((A > B) ? A : B) -#define scalar_min_i(A, B) ((A < B) ? A : B) -// fp, propagates NaNs -#define scalar_max(A, B) ((A >= B || npy_isnan(A)) ? A : B) -#define scalar_max_f scalar_max -#define scalar_max_d scalar_max -#define scalar_max_l scalar_max -#define scalar_min(A, B) ((A <= B || npy_isnan(A)) ? 
A : B) -#define scalar_min_f scalar_min -#define scalar_min_d scalar_min -#define scalar_min_l scalar_min -// fp, ignores NaNs -#define scalar_maxp_f fmaxf -#define scalar_maxp_d fmax -#define scalar_maxp_l fmaxl -#define scalar_minp_f fminf -#define scalar_minp_d fmin -#define scalar_minp_l fminl - -// special optimization for fp scalars propagates NaNs -// since there're no C99 support for it -#ifndef NPY_DISABLE_OPTIMIZATION -/**begin repeat - * #type = npy_float, npy_double# - * #sfx = f32, f64# - * #c_sfx = f, d# - * #isa_sfx = s, d# - * #sse_type = __m128, __m128d# - */ -/**begin repeat1 - * #op = max, min# - * #neon_instr = fmax, fmin# - */ -#ifdef NPY_HAVE_SSE2 -#undef scalar_@op@_@c_sfx@ -NPY_FINLINE @type@ scalar_@op@_@c_sfx@(@type@ a, @type@ b) { - @sse_type@ va = _mm_set_s@isa_sfx@(a); - @sse_type@ vb = _mm_set_s@isa_sfx@(b); - @sse_type@ rv = _mm_@op@_s@isa_sfx@(va, vb); - // X86 handle second operand - @sse_type@ nn = _mm_cmpord_s@isa_sfx@(va, va); - #ifdef NPY_HAVE_SSE41 - rv = _mm_blendv_p@isa_sfx@(va, rv, nn); - #else - rv = _mm_xor_p@isa_sfx@(va, _mm_and_p@isa_sfx@(_mm_xor_p@isa_sfx@(va, rv), nn)); - #endif - return _mm_cvts@isa_sfx@_@sfx@(rv); -} -#endif // SSE2 -#ifdef __aarch64__ -#undef scalar_@op@_@c_sfx@ -NPY_FINLINE @type@ scalar_@op@_@c_sfx@(@type@ a, @type@ b) { - @type@ result = 0; - __asm( - "@neon_instr@ %@isa_sfx@[result], %@isa_sfx@[a], %@isa_sfx@[b]" - : [result] "=w" (result) - : [a] "w" (a), [b] "w" (b) - ); - return result; -} -#endif // __aarch64__ -/**end repeat1**/ -/**end repeat**/ -#endif // NPY_DISABLE_OPTIMIZATION -// mapping to double if its possible -#if NPY_BITSOF_DOUBLE == NPY_BITSOF_LONGDOUBLE -/**begin repeat - * #op = max, min, maxp, minp# - */ - #undef scalar_@op@_l - #define scalar_@op@_l scalar_@op@_d -/**end repeat**/ -#endif - -/******************************************************************************* - ** Defining the SIMD kernels - ******************************************************************************/ -/**begin repeat - * #sfx = s8, u8, s16, u16, s32, u32, s64, u64, f32, f64# - * #simd_chk = NPY_SIMD*8, NPY_SIMD_F32, NPY_SIMD_F64# - * #is_fp = 0*8, 1, 1# - * #scalar_sfx = i*8, f, d# - */ -/**begin repeat1 - * # intrin = max, min, maxp, minp# - * # fp_only = 0, 0, 1, 1# - */ -#define SCALAR_OP scalar_@intrin@_@scalar_sfx@ -#if @simd_chk@ && (!@fp_only@ || (@is_fp@ && @fp_only@)) - -#if @is_fp@ && !@fp_only@ - #define V_INTRIN npyv_@intrin@n_@sfx@ // propagates NaNs - #define V_REDUCE_INTRIN npyv_reduce_@intrin@n_@sfx@ -#else - #define V_INTRIN npyv_@intrin@_@sfx@ - #define V_REDUCE_INTRIN npyv_reduce_@intrin@_@sfx@ -#endif - -// contiguous input. 
-static inline void -simd_reduce_c_@intrin@_@sfx@(const npyv_lanetype_@sfx@ *ip, npyv_lanetype_@sfx@ *op1, npy_intp len) -{ - if (len < 1) { - return; - } - const int vstep = npyv_nlanes_@sfx@; - const int wstep = vstep*8; - npyv_@sfx@ acc = npyv_setall_@sfx@(op1[0]); - for (; len >= wstep; len -= wstep, ip += wstep) { - #ifdef NPY_HAVE_SSE2 - NPY_PREFETCH(ip + wstep, 0, 3); - #endif - npyv_@sfx@ v0 = npyv_load_@sfx@(ip + vstep * 0); - npyv_@sfx@ v1 = npyv_load_@sfx@(ip + vstep * 1); - npyv_@sfx@ v2 = npyv_load_@sfx@(ip + vstep * 2); - npyv_@sfx@ v3 = npyv_load_@sfx@(ip + vstep * 3); - - npyv_@sfx@ v4 = npyv_load_@sfx@(ip + vstep * 4); - npyv_@sfx@ v5 = npyv_load_@sfx@(ip + vstep * 5); - npyv_@sfx@ v6 = npyv_load_@sfx@(ip + vstep * 6); - npyv_@sfx@ v7 = npyv_load_@sfx@(ip + vstep * 7); - - npyv_@sfx@ r01 = V_INTRIN(v0, v1); - npyv_@sfx@ r23 = V_INTRIN(v2, v3); - npyv_@sfx@ r45 = V_INTRIN(v4, v5); - npyv_@sfx@ r67 = V_INTRIN(v6, v7); - acc = V_INTRIN(acc, V_INTRIN(V_INTRIN(r01, r23), V_INTRIN(r45, r67))); - } - for (; len >= vstep; len -= vstep, ip += vstep) { - acc = V_INTRIN(acc, npyv_load_@sfx@(ip)); - } - npyv_lanetype_@sfx@ r = V_REDUCE_INTRIN(acc); - // Scalar - finish up any remaining iterations - for (; len > 0; --len, ++ip) { - const npyv_lanetype_@sfx@ in2 = *ip; - r = SCALAR_OP(r, in2); - } - op1[0] = r; -} - -// contiguous inputs and output. -static inline void -simd_binary_ccc_@intrin@_@sfx@(const npyv_lanetype_@sfx@ *ip1, const npyv_lanetype_@sfx@ *ip2, - npyv_lanetype_@sfx@ *op1, npy_intp len) -{ -#if NPY_SIMD_WIDTH == 128 - // Note, 6x unroll was chosen for best results on Apple M1 - const int vectorsPerLoop = 6; -#else - // To avoid memory bandwidth bottleneck - const int vectorsPerLoop = 2; -#endif - const int elemPerVector = npyv_nlanes_@sfx@; - int elemPerLoop = vectorsPerLoop * elemPerVector; - - npy_intp i = 0; - - for (; (i+elemPerLoop) <= len; i += elemPerLoop) { - npyv_@sfx@ v0 = npyv_load_@sfx@(&ip1[i + 0 * elemPerVector]); - npyv_@sfx@ v1 = npyv_load_@sfx@(&ip1[i + 1 * elemPerVector]); - #if NPY_SIMD_WIDTH == 128 - npyv_@sfx@ v2 = npyv_load_@sfx@(&ip1[i + 2 * elemPerVector]); - npyv_@sfx@ v3 = npyv_load_@sfx@(&ip1[i + 3 * elemPerVector]); - npyv_@sfx@ v4 = npyv_load_@sfx@(&ip1[i + 4 * elemPerVector]); - npyv_@sfx@ v5 = npyv_load_@sfx@(&ip1[i + 5 * elemPerVector]); - #endif - npyv_@sfx@ u0 = npyv_load_@sfx@(&ip2[i + 0 * elemPerVector]); - npyv_@sfx@ u1 = npyv_load_@sfx@(&ip2[i + 1 * elemPerVector]); - #if NPY_SIMD_WIDTH == 128 - npyv_@sfx@ u2 = npyv_load_@sfx@(&ip2[i + 2 * elemPerVector]); - npyv_@sfx@ u3 = npyv_load_@sfx@(&ip2[i + 3 * elemPerVector]); - npyv_@sfx@ u4 = npyv_load_@sfx@(&ip2[i + 4 * elemPerVector]); - npyv_@sfx@ u5 = npyv_load_@sfx@(&ip2[i + 5 * elemPerVector]); - #endif - npyv_@sfx@ m0 = V_INTRIN(v0, u0); - npyv_@sfx@ m1 = V_INTRIN(v1, u1); - #if NPY_SIMD_WIDTH == 128 - npyv_@sfx@ m2 = V_INTRIN(v2, u2); - npyv_@sfx@ m3 = V_INTRIN(v3, u3); - npyv_@sfx@ m4 = V_INTRIN(v4, u4); - npyv_@sfx@ m5 = V_INTRIN(v5, u5); - #endif - npyv_store_@sfx@(&op1[i + 0 * elemPerVector], m0); - npyv_store_@sfx@(&op1[i + 1 * elemPerVector], m1); - #if NPY_SIMD_WIDTH == 128 - npyv_store_@sfx@(&op1[i + 2 * elemPerVector], m2); - npyv_store_@sfx@(&op1[i + 3 * elemPerVector], m3); - npyv_store_@sfx@(&op1[i + 4 * elemPerVector], m4); - npyv_store_@sfx@(&op1[i + 5 * elemPerVector], m5); - #endif - } - for (; (i+elemPerVector) <= len; i += elemPerVector) { - npyv_@sfx@ v0 = npyv_load_@sfx@(ip1 + i); - npyv_@sfx@ u0 = npyv_load_@sfx@(ip2 + i); - npyv_@sfx@ m0 = V_INTRIN(v0, u0); 
- npyv_store_@sfx@(op1 + i, m0); - } - // Scalar - finish up any remaining iterations - for (; i < len; ++i) { - const npyv_lanetype_@sfx@ in1 = ip1[i]; - const npyv_lanetype_@sfx@ in2 = ip2[i]; - op1[i] = SCALAR_OP(in1, in2); - } -} -// non-contiguous for float 32/64-bit memory access -#if @is_fp@ && !defined(NPY_HAVE_NEON) -// unroll scalars faster than non-contiguous vector load/store on Arm -static inline void -simd_binary_@intrin@_@sfx@(const npyv_lanetype_@sfx@ *ip1, npy_intp sip1, - const npyv_lanetype_@sfx@ *ip2, npy_intp sip2, - npyv_lanetype_@sfx@ *op1, npy_intp sop1, - npy_intp len) -{ - const int vstep = npyv_nlanes_@sfx@; - for (; len >= vstep; len -= vstep, ip1 += sip1*vstep, - ip2 += sip2*vstep, op1 += sop1*vstep - ) { - npyv_@sfx@ a, b; - if (sip1 == 1) { - a = npyv_load_@sfx@(ip1); - } else { - a = npyv_loadn_@sfx@(ip1, sip1); - } - if (sip2 == 1) { - b = npyv_load_@sfx@(ip2); - } else { - b = npyv_loadn_@sfx@(ip2, sip2); - } - npyv_@sfx@ r = V_INTRIN(a, b); - if (sop1 == 1) { - npyv_store_@sfx@(op1, r); - } else { - npyv_storen_@sfx@(op1, sop1, r); - } - } - for (; len > 0; --len, ip1 += sip1, ip2 += sip2, op1 += sop1) { - const npyv_lanetype_@sfx@ a = *ip1; - const npyv_lanetype_@sfx@ b = *ip2; - *op1 = SCALAR_OP(a, b); - } -} -#endif - -#undef V_INTRIN -#undef V_REDUCE_INTRIN - -#endif // simd_chk && (!fp_only || (is_fp && fp_only)) - -#undef SCALAR_OP -/**end repeat1**/ -/**end repeat**/ - -/******************************************************************************* - ** Defining ufunc inner functions - ******************************************************************************/ -/**begin repeat - * #TYPE = UBYTE, USHORT, UINT, ULONG, ULONGLONG, - * BYTE, SHORT, INT, LONG, LONGLONG, - * FLOAT, DOUBLE, LONGDOUBLE# - * - * #BTYPE = BYTE, SHORT, INT, LONG, LONGLONG, - * BYTE, SHORT, INT, LONG, LONGLONG, - * FLOAT, DOUBLE, LONGDOUBLE# - * #type = npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong, - * npy_byte, npy_short, npy_int, npy_long, npy_longlong, - * npy_float, npy_double, npy_longdouble# - * - * #is_fp = 0*10, 1*3# - * #is_unsigned = 1*5, 0*5, 0*3# - * #scalar_sfx = i*10, f, d, l# - */ -#undef TO_SIMD_SFX -#if 0 -/**begin repeat1 - * #len = 8, 16, 32, 64# - */ -#elif NPY_SIMD && NPY_BITSOF_@BTYPE@ == @len@ - #if @is_fp@ - #define TO_SIMD_SFX(X) X##_f@len@ - #if NPY_BITSOF_@BTYPE@ == 32 && !NPY_SIMD_F32 - #undef TO_SIMD_SFX - #endif - #if NPY_BITSOF_@BTYPE@ == 64 && !NPY_SIMD_F64 - #undef TO_SIMD_SFX - #endif - #elif @is_unsigned@ - #define TO_SIMD_SFX(X) X##_u@len@ - #else - #define TO_SIMD_SFX(X) X##_s@len@ - #endif -/**end repeat1**/ -#endif - -/**begin repeat1 - * # kind = maximum, minimum, fmax, fmin# - * # intrin = max, min, maxp, minp# - * # fp_only = 0, 0, 1, 1# - */ -#if !@fp_only@ || (@is_fp@ && @fp_only@) -#define SCALAR_OP scalar_@intrin@_@scalar_sfx@ - -NPY_NO_EXPORT void NPY_CPU_DISPATCH_CURFX(@TYPE@_@kind@) -(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func)) -{ - char *ip1 = args[0], *ip2 = args[1], *op1 = args[2]; - npy_intp is1 = steps[0], is2 = steps[1], os1 = steps[2], - len = dimensions[0]; - npy_intp i = 0; -#ifdef TO_SIMD_SFX - #undef STYPE - #define STYPE TO_SIMD_SFX(npyv_lanetype) - if (IS_BINARY_REDUCE) { - // reduce and contiguous - if (is2 == sizeof(@type@)) { - TO_SIMD_SFX(simd_reduce_c_@intrin@)( - (STYPE*)ip2, (STYPE*)op1, len - ); - goto clear_fp; - } - } - else if (!is_mem_overlap(ip1, is1, op1, os1, len) && - !is_mem_overlap(ip2, is2, op1, os1, len) - ) { - // no overlap and 
operands are binary contiguous - if (IS_BINARY_CONT(@type@, @type@)) { - TO_SIMD_SFX(simd_binary_ccc_@intrin@)( - (STYPE*)ip1, (STYPE*)ip2, (STYPE*)op1, len - ); - goto clear_fp; - } - // unroll scalars faster than non-contiguous vector load/store on Arm - #if !defined(NPY_HAVE_NEON) && @is_fp@ - if (TO_SIMD_SFX(npyv_loadable_stride)(is1) && - TO_SIMD_SFX(npyv_loadable_stride)(is2) && - TO_SIMD_SFX(npyv_storable_stride)(os1) - ) { - TO_SIMD_SFX(simd_binary_@intrin@)( - (STYPE*)ip1, is1/sizeof(STYPE), - (STYPE*)ip2, is2/sizeof(STYPE), - (STYPE*)op1, os1/sizeof(STYPE), len - ); - goto clear_fp; - } - #endif - } -#endif // TO_SIMD_SFX -#ifndef NPY_DISABLE_OPTIMIZATION - // scalar unrolls - if (IS_BINARY_REDUCE) { - // Note, 8x unroll was chosen for best results on Apple M1 - npy_intp elemPerLoop = 8; - if((i+elemPerLoop) <= len){ - @type@ m0 = *((@type@ *)(ip2 + (i + 0) * is2)); - @type@ m1 = *((@type@ *)(ip2 + (i + 1) * is2)); - @type@ m2 = *((@type@ *)(ip2 + (i + 2) * is2)); - @type@ m3 = *((@type@ *)(ip2 + (i + 3) * is2)); - @type@ m4 = *((@type@ *)(ip2 + (i + 4) * is2)); - @type@ m5 = *((@type@ *)(ip2 + (i + 5) * is2)); - @type@ m6 = *((@type@ *)(ip2 + (i + 6) * is2)); - @type@ m7 = *((@type@ *)(ip2 + (i + 7) * is2)); - - i += elemPerLoop; - for(; (i+elemPerLoop)<=len; i+=elemPerLoop){ - @type@ v0 = *((@type@ *)(ip2 + (i + 0) * is2)); - @type@ v1 = *((@type@ *)(ip2 + (i + 1) * is2)); - @type@ v2 = *((@type@ *)(ip2 + (i + 2) * is2)); - @type@ v3 = *((@type@ *)(ip2 + (i + 3) * is2)); - @type@ v4 = *((@type@ *)(ip2 + (i + 4) * is2)); - @type@ v5 = *((@type@ *)(ip2 + (i + 5) * is2)); - @type@ v6 = *((@type@ *)(ip2 + (i + 6) * is2)); - @type@ v7 = *((@type@ *)(ip2 + (i + 7) * is2)); - - m0 = SCALAR_OP(m0, v0); - m1 = SCALAR_OP(m1, v1); - m2 = SCALAR_OP(m2, v2); - m3 = SCALAR_OP(m3, v3); - m4 = SCALAR_OP(m4, v4); - m5 = SCALAR_OP(m5, v5); - m6 = SCALAR_OP(m6, v6); - m7 = SCALAR_OP(m7, v7); - } - - m0 = SCALAR_OP(m0, m1); - m2 = SCALAR_OP(m2, m3); - m4 = SCALAR_OP(m4, m5); - m6 = SCALAR_OP(m6, m7); - - m0 = SCALAR_OP(m0, m2); - m4 = SCALAR_OP(m4, m6); - - m0 = SCALAR_OP(m0, m4); - - *((@type@ *)op1) = SCALAR_OP(*((@type@ *)op1), m0); - } - } else{ - // Note, 4x unroll was chosen for best results on Apple M1 - npy_intp elemPerLoop = 4; - for(; (i+elemPerLoop)<=len; i+=elemPerLoop){ - /* Note, we can't just load all, do all ops, then store all here. - * Sometimes ufuncs are called with `accumulate`, which makes the - * assumption that previous iterations have finished before next - * iteration. For example, the output of iteration 2 depends on the - * result of iteration 1. 
- */ - - /**begin repeat2 - * #unroll = 0, 1, 2, 3# - */ - @type@ v@unroll@ = *((@type@ *)(ip1 + (i + @unroll@) * is1)); - @type@ u@unroll@ = *((@type@ *)(ip2 + (i + @unroll@) * is2)); - *((@type@ *)(op1 + (i + @unroll@) * os1)) = SCALAR_OP(v@unroll@, u@unroll@); - /**end repeat2**/ - } - } -#endif // NPY_DISABLE_OPTIMIZATION - ip1 += is1 * i; - ip2 += is2 * i; - op1 += os1 * i; - for (; i < len; ++i, ip1 += is1, ip2 += is2, op1 += os1) { - const @type@ in1 = *(@type@ *)ip1; - const @type@ in2 = *(@type@ *)ip2; - *((@type@ *)op1) = SCALAR_OP(in1, in2); - } -#ifdef TO_SIMD_SFX -clear_fp: - npyv_cleanup(); -#endif -#if @is_fp@ - npy_clear_floatstatus_barrier((char*)dimensions); -#endif -} - - -NPY_NO_EXPORT int NPY_CPU_DISPATCH_CURFX(@TYPE@_@kind@_indexed) -(PyArrayMethod_Context *NPY_UNUSED(context), char *const *args, npy_intp const *dimensions, npy_intp const *steps, NpyAuxData *NPY_UNUSED(func)) -{ - char *ip1 = args[0]; - char *indxp = args[1]; - char *value = args[2]; - npy_intp is1 = steps[0], isindex = steps[1], isb = steps[2]; - npy_intp n = dimensions[0]; - npy_intp shape = steps[3]; - npy_intp i; - @type@ *indexed; - for(i = 0; i < n; i++, indxp += isindex, value += isb) { - npy_intp indx = *(npy_intp *)indxp; - if (indx < 0) { - indx += shape; - } - indexed = (@type@ *)(ip1 + is1 * indx); - *indexed = SCALAR_OP(*indexed, *(@type@ *)value); - } - return 0; -} - -#undef SCALAR_OP - -#endif // !fp_only || (is_fp && fp_only) -/**end repeat1**/ -/**end repeat**/ - diff --git a/numpy/_core/src/umath/loops_minmax.dispatch.cpp b/numpy/_core/src/umath/loops_minmax.dispatch.cpp new file mode 100644 index 000000000000..09afaf7ff7c8 --- /dev/null +++ b/numpy/_core/src/umath/loops_minmax.dispatch.cpp @@ -0,0 +1,525 @@ +#include "loops_utils.h" +#include "loops.h" + +#include +#include +#include "simd/simd.hpp" +#include "numpy/npy_common.h" +#include "common.hpp" +#include "fast_loop_macros.h" + +namespace { +using namespace np::simd; + +template struct OpMax { + using Degraded = std::conditional_t, OpMax, OpMax>; +#if NPY_HWY + template >, typename V = Vec> + HWY_INLINE HWY_ATTR auto operator()(const V& a, const V& b) const { + if constexpr (std::is_floating_point_v) { + return hn::IfThenElse(hn::IsEitherNaN(a, b), Set(NAN), hn::Max(a, b)); + } else { + return hn::Max(a, b); + } + } + + template >, typename V = Vec> + HWY_INLINE HWY_ATTR auto operator()(const V& v) const { + if constexpr (std::is_floating_point_v) { + return hn::AllFalse(_Tag(), hn::IsNaN(v)) ? hn::ReduceMax(_Tag(), v) : NAN; + } else { + return hn::ReduceMax(_Tag(), v); + } + } +#endif + + NPY_INLINE T operator()(T a, T b) const { + if constexpr (std::is_floating_point_v) { + return (a >= b || npy_isnan(a)) ? a : b; + } else { + return a > b ? a : b; + } + } +}; + +template struct OpMin { + using Degraded = std::conditional_t, OpMin, OpMin>; +#if NPY_HWY + template >, typename V = Vec> + HWY_INLINE HWY_ATTR auto operator()(const V& a, const V& b) const { + if constexpr (std::is_floating_point_v) { + return hn::IfThenElse(hn::IsEitherNaN(a, b), Set(NAN), hn::Min(a, b)); + } else { + return hn::Min(a, b); + } + } + + template >, typename V = Vec> + HWY_INLINE HWY_ATTR auto operator()(const V& v) const { + if constexpr (std::is_floating_point_v) { + return hn::AllFalse(_Tag(), hn::IsNaN(v)) ? hn::ReduceMin(_Tag(), v) : NAN; + } else { + return hn::ReduceMin(_Tag(), v); + } + } +#endif + + NPY_INLINE T operator()(T a, T b) const { + if constexpr (std::is_floating_point_v) { + return (a <= b || npy_isnan(a)) ? 
a : b; + } else { + return a < b ? a : b; + } + } +}; + +template >> struct OpMaxp { + using Degraded = std::conditional_t, OpMaxp, OpMaxp>; +#if NPY_HWY + template >, typename V = Vec> + HWY_INLINE HWY_ATTR auto operator()(const V& a, const V& b) const { + return hn::Max(hn::IfThenElse(hn::IsNaN(a), b, a), hn::IfThenElse(hn::IsNaN(b), a, b)); + } + + template >, typename V = Vec> + HWY_INLINE HWY_ATTR auto operator()(const V& v) const { + auto m = hn::IsNaN(v); + return hn::AllTrue(_Tag(), m) ? NAN : hn::MaskedReduceMax(_Tag(), hn::Not(m), v); + } +#endif + + NPY_INLINE T operator()(T a, T b) const { + if constexpr (std::is_same_v) { + return fmaxf(a, b); + } else if constexpr (std::is_same_v) { + return fmax(a, b); + } else { + return fmaxl(a, b); + } + } +}; + +template >> struct OpMinp { + using Degraded = std::conditional_t, OpMinp, OpMinp>; +#if NPY_HWY + template >, typename V = Vec> + HWY_INLINE HWY_ATTR auto operator()(const V& a, const V& b) const { + return hn::Min(hn::IfThenElse(hn::IsNaN(a), b, a), hn::IfThenElse(hn::IsNaN(b), a, b)); + } + + template >, typename V = Vec> + HWY_INLINE HWY_ATTR auto operator()(const V& v) const { + auto m = hn::IsNaN(v); + return hn::AllTrue(_Tag(), m) ? NAN : hn::MaskedReduceMin(_Tag(), hn::Not(m), v); + } +#endif + + NPY_INLINE T operator()(T a, T b) const { + if constexpr (std::is_same_v) { + return fminf(a, b); + } else if constexpr (std::is_same_v) { + return fmin(a, b); + } else { + return fminl(a, b); + } + } +}; + +#if NPY_HWY +template > +HWY_INLINE HWY_ATTR auto LoadWithStride(const T* src, npy_intp ssrc) { + auto index = hn::Mul(hn::Iota(_Tag(), 0), Set(ssrc)); + return hn::GatherIndex(_Tag(), src, index); +} + +template > +HWY_INLINE HWY_ATTR void StoreWithStride(Vec vec, T* dst, npy_intp sdst) { + auto index = hn::Mul(hn::Iota(_Tag(), 0), Set(sdst)); + hn::ScatterIndex(vec, _Tag(), dst, index); +} + +/******************************************************************************** + ** Defining the SIMD kernels + ********************************************************************************/ + +// contiguous input. +template +HWY_INLINE HWY_ATTR void +simd_reduce_c(const T* ip, T* op1, npy_intp len) +{ + const OP op_func; + if (len < 1) { + return; + } + const int vstep = Lanes(); + const int wstep = vstep*8; + auto acc = Set(op1[0]); + for (; len >= wstep; len -= wstep, ip += wstep) { + /* + * error: '_mm_prefetch' needs target feature mmx on clang-cl + */ +#if !(defined(_MSC_VER) && defined(__clang__)) + hwy::Prefetch(ip + wstep); +#endif + auto v0 = LoadU(ip + vstep * 0); + auto v1 = LoadU(ip + vstep * 1); + auto v2 = LoadU(ip + vstep * 2); + auto v3 = LoadU(ip + vstep * 3); + + auto v4 = LoadU(ip + vstep * 4); + auto v5 = LoadU(ip + vstep * 5); + auto v6 = LoadU(ip + vstep * 6); + auto v7 = LoadU(ip + vstep * 7); + + auto r01 = op_func(v0, v1); + auto r23 = op_func(v2, v3); + auto r45 = op_func(v4, v5); + auto r67 = op_func(v6, v7); + acc = op_func(acc, op_func(op_func(r01, r23), op_func(r45, r67))); + } + for (; len >= vstep; len -= vstep, ip += vstep) { + acc = op_func(acc, LoadU(ip)); + } + T r = op_func(acc); + // Scalar - finish up any remaining iterations + for (; len > 0; --len, ++ip) { + const T in2 = *ip; + r = op_func(r, in2); + } + op1[0] = r; +} + +// contiguous inputs and output. 
+template +HWY_INLINE HWY_ATTR void +simd_binary_ccc(const T*ip1, const T*ip2, + T*op1, npy_intp len) +{ + const OP op_func; +#if HWY_MAX_BYTES == 16 + // Note, 6x unroll was chosen for best results on Apple M1 + const int vectorsPerLoop = 6; +#else + // To avoid memory bandwidth bottleneck + const int vectorsPerLoop = 2; +#endif + const int elemPerVector = Lanes(); + int elemPerLoop = vectorsPerLoop * elemPerVector; + + npy_intp i = 0; + + for (; (i+elemPerLoop) <= len; i += elemPerLoop) { + auto v0 = LoadU(&ip1[i + 0 * elemPerVector]); + auto v1 = LoadU(&ip1[i + 1 * elemPerVector]); + #if HWY_MAX_BYTES == 16 + auto v2 = LoadU(&ip1[i + 2 * elemPerVector]); + auto v3 = LoadU(&ip1[i + 3 * elemPerVector]); + auto v4 = LoadU(&ip1[i + 4 * elemPerVector]); + auto v5 = LoadU(&ip1[i + 5 * elemPerVector]); + #endif + auto u0 = LoadU(&ip2[i + 0 * elemPerVector]); + auto u1 = LoadU(&ip2[i + 1 * elemPerVector]); + #if HWY_MAX_BYTES == 16 + auto u2 = LoadU(&ip2[i + 2 * elemPerVector]); + auto u3 = LoadU(&ip2[i + 3 * elemPerVector]); + auto u4 = LoadU(&ip2[i + 4 * elemPerVector]); + auto u5 = LoadU(&ip2[i + 5 * elemPerVector]); + #endif + auto m0 = op_func(v0, u0); + auto m1 = op_func(v1, u1); + #if HWY_MAX_BYTES == 16 + auto m2 = op_func(v2, u2); + auto m3 = op_func(v3, u3); + auto m4 = op_func(v4, u4); + auto m5 = op_func(v5, u5); + #endif + StoreU(m0, &op1[i + 0 * elemPerVector]); + StoreU(m1, &op1[i + 1 * elemPerVector]); + #if HWY_MAX_BYTES == 16 + StoreU(m2, &op1[i + 2 * elemPerVector]); + StoreU(m3, &op1[i + 3 * elemPerVector]); + StoreU(m4, &op1[i + 4 * elemPerVector]); + StoreU(m5, &op1[i + 5 * elemPerVector]); + #endif + } + for (; (i+elemPerVector) <= len; i += elemPerVector) { + auto v0 = LoadU(ip1 + i); + auto u0 = LoadU(ip2 + i); + auto m0 = op_func(v0, u0); + StoreU(m0, op1 + i); + } + // Scalar - finish up any remaining iterations + for (; i < len; ++i) { + const T in1 = ip1[i]; + const T in2 = ip2[i]; + op1[i] = op_func(in1, in2); + } +} + +// non-contiguous for float 32/64-bit memory access +template +HWY_INLINE HWY_ATTR void +simd_binary(const T* ip1, npy_intp sip1, + const T* ip2, npy_intp sip2, + T* op1, npy_intp sop1, + npy_intp len) +{ + const OP op_func; + const int vstep = Lanes(); + for (; len >= vstep; len -= vstep, ip1 += sip1*vstep, + ip2 += sip2*vstep, op1 += sop1*vstep + ) { + Vec a, b; + if (sip1 == 1) { + a = LoadU(ip1); + } else { + a = LoadWithStride(ip1, sip1); + } + if (sip2 == 1) { + b = LoadU(ip2); + } else { + b = LoadWithStride(ip2, sip2); + } + auto r = op_func(a, b); + if (sop1 == 1) { + StoreU(r, op1); + } else { + StoreWithStride(r, op1, sop1); + } + } + for (; len > 0; --len, ip1 += sip1, ip2 += sip2, op1 += sop1) { + const T a = *ip1; + const T b = *ip2; + *op1 = op_func(a, b); + } +} +#endif // NPY_HWY + +template , double, T>> +HWY_INLINE HWY_ATTR void +minmax(char **args, npy_intp const*dimensions, npy_intp const*steps) +{ + const OP op_func; + char *ip1 = args[0], *ip2 = args[1], *op1 = args[2]; + npy_intp is1 = steps[0], is2 = steps[1], os1 = steps[2], + len = dimensions[0]; + npy_intp i = 0; +#if NPY_HWY + if constexpr (kSupportLane) { + if (IS_BINARY_REDUCE) { + // reduce and contiguous + if (is2 == sizeof(T)) { + simd_reduce_c( + (D*)ip2, (D*)op1, len + ); + goto clear_fp; + } + } + else if (!is_mem_overlap(ip1, is1, op1, os1, len) && + !is_mem_overlap(ip2, is2, op1, os1, len) + ) { + // no overlap and operands are binary contiguous + if (IS_BINARY_CONT(T, T)) { + simd_binary_ccc( + (D*)ip1, (D*)ip2, (D*)op1, len + ); + goto clear_fp; + } + // 
unroll scalars faster than non-contiguous vector load/store on Arm + #if !HWY_TARGET_IS_NEON + if constexpr (std::is_floating_point_v) { + if (alignof(T) == sizeof(T) && is1 % sizeof(T) == 0 && is2 % sizeof(T) == 0 && os1 % sizeof(T) == 0) { + simd_binary( + (D*)ip1, is1/sizeof(D), + (D*)ip2, is2/sizeof(D), + (D*)op1, os1/sizeof(D), len + ); + goto clear_fp; + } + } + #endif + } + } +#endif +#ifndef NPY_DISABLE_OPTIMIZATION + // scalar unrolls + if (IS_BINARY_REDUCE) { + // Note, 8x unroll was chosen for best results on Apple M1 + npy_intp elemPerLoop = 8; + if((i+elemPerLoop) <= len){ + T m0 = *((T*)(ip2 + (i + 0) * is2)); + T m1 = *((T*)(ip2 + (i + 1) * is2)); + T m2 = *((T*)(ip2 + (i + 2) * is2)); + T m3 = *((T*)(ip2 + (i + 3) * is2)); + T m4 = *((T*)(ip2 + (i + 4) * is2)); + T m5 = *((T*)(ip2 + (i + 5) * is2)); + T m6 = *((T*)(ip2 + (i + 6) * is2)); + T m7 = *((T*)(ip2 + (i + 7) * is2)); + + i += elemPerLoop; + for(; (i+elemPerLoop)<=len; i+=elemPerLoop){ + T v0 = *((T*)(ip2 + (i + 0) * is2)); + T v1 = *((T*)(ip2 + (i + 1) * is2)); + T v2 = *((T*)(ip2 + (i + 2) * is2)); + T v3 = *((T*)(ip2 + (i + 3) * is2)); + T v4 = *((T*)(ip2 + (i + 4) * is2)); + T v5 = *((T*)(ip2 + (i + 5) * is2)); + T v6 = *((T*)(ip2 + (i + 6) * is2)); + T v7 = *((T*)(ip2 + (i + 7) * is2)); + + m0 = op_func(m0, v0); + m1 = op_func(m1, v1); + m2 = op_func(m2, v2); + m3 = op_func(m3, v3); + m4 = op_func(m4, v4); + m5 = op_func(m5, v5); + m6 = op_func(m6, v6); + m7 = op_func(m7, v7); + } + + m0 = op_func(m0, m1); + m2 = op_func(m2, m3); + m4 = op_func(m4, m5); + m6 = op_func(m6, m7); + + m0 = op_func(m0, m2); + m4 = op_func(m4, m6); + + m0 = op_func(m0, m4); + + *((T*)op1) = op_func(*((T*)op1), m0); + } + } else{ + // Note, 4x unroll was chosen for best results on Apple M1 + npy_intp elemPerLoop = 4; + for(; (i+elemPerLoop)<=len; i+=elemPerLoop){ + /* Note, we can't just load all, do all ops, then store all here. + * Sometimes ufuncs are called with `accumulate`, which makes the + * assumption that previous iterations have finished before next + * iteration. For example, the output of iteration 2 depends on the + * result of iteration 1. 
+ */ + + T v0 = *((T*)(ip1 + (i + 0) * is1)); + T u0 = *((T*)(ip2 + (i + 0) * is2)); + *((T*)(op1 + (i + 0) * os1)) = op_func(v0, u0); + T v1 = *((T*)(ip1 + (i + 1) * is1)); + T u1 = *((T*)(ip2 + (i + 1) * is2)); + *((T*)(op1 + (i + 1) * os1)) = op_func(v1, u1); + T v2 = *((T*)(ip1 + (i + 2) * is1)); + T u2 = *((T*)(ip2 + (i + 2) * is2)); + *((T*)(op1 + (i + 2) * os1)) = op_func(v2, u2); + T v3 = *((T*)(ip1 + (i + 3) * is1)); + T u3 = *((T*)(ip2 + (i + 3) * is2)); + *((T*)(op1 + (i + 3) * os1)) = op_func(v3, u3); + } + } +#endif // NPY_DISABLE_OPTIMIZATION + ip1 += is1 * i; + ip2 += is2 * i; + op1 += os1 * i; + for (; i < len; ++i, ip1 += is1, ip2 += is2, op1 += os1) { + const T in1 = *(T*)ip1; + const T in2 = *(T*)ip2; + *((T*)op1) = op_func(in1, in2); + } + + goto clear_fp; // suppress warnings +clear_fp: + if constexpr (std::is_floating_point_v) { + npy_clear_floatstatus_barrier((char*)dimensions); + } +} + +template +HWY_INLINE HWY_ATTR void +minmax_indexed(char *const* args, npy_intp const*dimensions, npy_intp const*steps) +{ + const OP op_func; + char *ip1 = args[0]; + char *indxp = args[1]; + char *value = args[2]; + npy_intp is1 = steps[0], isindex = steps[1], isb = steps[2]; + npy_intp n = dimensions[0]; + npy_intp shape = steps[3]; + npy_intp i; + T *indexed; + for(i = 0; i < n; i++, indxp += isindex, value += isb) { + npy_intp indx = *(npy_intp *)indxp; + if (indx < 0) { + indx += shape; + } + indexed = (T* )(ip1 + is1 * indx); + *indexed = op_func(*indexed, *(T*)value); + } +} + +} // anonymous namespace + +/******************************************************************************* + ** Defining ufunc inner functions + *******************************************************************************/ +#define DEFINE_UNARY_MINMAX_FUNCTION(TYPE, KIND, INTR, T) \ +NPY_NO_EXPORT void NPY_CPU_DISPATCH_CURFX(TYPE##_##KIND) \ +(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func)) \ +{ \ + using FixedType = typename np::meta::FixedWidth::Type; \ + minmax>(args, dimensions, steps); \ +} \ +NPY_NO_EXPORT int NPY_CPU_DISPATCH_CURFX(TYPE##_##KIND##_indexed) \ +(PyArrayMethod_Context *NPY_UNUSED(context), char *const *args, \ + npy_intp const *dimensions, npy_intp const *steps, NpyAuxData *NPY_UNUSED(func)) \ +{ \ + using FixedType = typename np::meta::FixedWidth::Type; \ + minmax_indexed>(args, dimensions, steps); \ + return 0; \ +} + +#define DEFINE_UNARY_MINMAX_FUNCTION_LD(TYPE, KIND, INTR) \ +NPY_NO_EXPORT void NPY_CPU_DISPATCH_CURFX(TYPE##_##KIND) \ +(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func)) \ +{ \ + minmax>(args, dimensions, steps); \ +} \ +NPY_NO_EXPORT int NPY_CPU_DISPATCH_CURFX(TYPE##_##KIND##_indexed) \ +(PyArrayMethod_Context *NPY_UNUSED(context), char *const *args, \ + npy_intp const *dimensions, npy_intp const *steps, NpyAuxData *NPY_UNUSED(func)) \ +{ \ + minmax_indexed>(args, dimensions, steps); \ + return 0; \ +} + +DEFINE_UNARY_MINMAX_FUNCTION(UBYTE, maximum, Max, npy_ubyte) +DEFINE_UNARY_MINMAX_FUNCTION(USHORT, maximum, Max, npy_ushort) +DEFINE_UNARY_MINMAX_FUNCTION(UINT, maximum, Max, npy_uint) +DEFINE_UNARY_MINMAX_FUNCTION(ULONG, maximum, Max, npy_ulong) +DEFINE_UNARY_MINMAX_FUNCTION(ULONGLONG, maximum, Max, npy_ulonglong) +DEFINE_UNARY_MINMAX_FUNCTION(UBYTE, minimum, Min, npy_ubyte) +DEFINE_UNARY_MINMAX_FUNCTION(USHORT, minimum, Min, npy_ushort) +DEFINE_UNARY_MINMAX_FUNCTION(UINT, minimum, Min, npy_uint) +DEFINE_UNARY_MINMAX_FUNCTION(ULONG, minimum, Min, npy_ulong) 
+DEFINE_UNARY_MINMAX_FUNCTION(ULONGLONG, minimum, Min, npy_ulonglong) +DEFINE_UNARY_MINMAX_FUNCTION(BYTE, maximum, Max, npy_byte) +DEFINE_UNARY_MINMAX_FUNCTION(SHORT, maximum, Max, npy_short) +DEFINE_UNARY_MINMAX_FUNCTION(INT, maximum, Max, npy_int) +DEFINE_UNARY_MINMAX_FUNCTION(LONG, maximum, Max, npy_long) +DEFINE_UNARY_MINMAX_FUNCTION(LONGLONG, maximum, Max, npy_longlong) +DEFINE_UNARY_MINMAX_FUNCTION(BYTE, minimum, Min, npy_byte) +DEFINE_UNARY_MINMAX_FUNCTION(SHORT, minimum, Min, npy_short) +DEFINE_UNARY_MINMAX_FUNCTION(INT, minimum, Min, npy_int) +DEFINE_UNARY_MINMAX_FUNCTION(LONG, minimum, Min, npy_long) +DEFINE_UNARY_MINMAX_FUNCTION(LONGLONG, minimum, Min, npy_longlong) +DEFINE_UNARY_MINMAX_FUNCTION(FLOAT, maximum, Max, npy_float) +DEFINE_UNARY_MINMAX_FUNCTION(DOUBLE, maximum, Max, npy_double) +DEFINE_UNARY_MINMAX_FUNCTION_LD(LONGDOUBLE, maximum, Max) +DEFINE_UNARY_MINMAX_FUNCTION(FLOAT, fmax, Maxp, npy_float) +DEFINE_UNARY_MINMAX_FUNCTION(DOUBLE, fmax, Maxp, npy_double) +DEFINE_UNARY_MINMAX_FUNCTION_LD(LONGDOUBLE, fmax, Maxp) +DEFINE_UNARY_MINMAX_FUNCTION(FLOAT, minimum, Min, npy_float) +DEFINE_UNARY_MINMAX_FUNCTION(DOUBLE, minimum, Min, npy_double) +DEFINE_UNARY_MINMAX_FUNCTION_LD(LONGDOUBLE, minimum, Min) +DEFINE_UNARY_MINMAX_FUNCTION(FLOAT, fmin, Minp, npy_float) +DEFINE_UNARY_MINMAX_FUNCTION(DOUBLE, fmin, Minp, npy_double) +DEFINE_UNARY_MINMAX_FUNCTION_LD(LONGDOUBLE, fmin, Minp) +#undef DEFINE_UNARY_MINMAX_FUNCTION +#undef DEFINE_UNARY_MINMAX_FUNCTION_LD diff --git a/numpy/_core/src/umath/loops_trigonometric.dispatch.cpp b/numpy/_core/src/umath/loops_trigonometric.dispatch.cpp index ae696db4cd4a..d298a8596cc4 100644 --- a/numpy/_core/src/umath/loops_trigonometric.dispatch.cpp +++ b/numpy/_core/src/umath/loops_trigonometric.dispatch.cpp @@ -3,7 +3,9 @@ #include "loops_utils.h" #include "simd/simd.h" +#include "simd/simd.hpp" #include + namespace hn = hwy::HWY_NAMESPACE; /* @@ -184,7 +186,7 @@ simd_sincos_f32(const float *src, npy_intp ssrc, float *dst, npy_intp sdst, "larger than 256 bits."); simd_maski = ((uint8_t *)&simd_maski)[0]; #endif - float NPY_DECL_ALIGNED(NPY_SIMD_WIDTH) ip_fback[hn::Lanes(f32)]; + float NPY_DECL_ALIGNED(NPY_SIMD_WIDTH) ip_fback[hn::MaxLanes(f32)]; hn::Store(x_in, f32, ip_fback); // process elements using libc for large elements