From b09a4070722fbe656b76dfe8d49c8cb6c4c0740b Mon Sep 17 00:00:00 2001 From: Akash Kaothalkar Date: Tue, 22 Apr 2025 13:15:54 -0500 Subject: [PATCH 1/2] Feat: int8 w8a8 enablement for ppc64le Signed-off-by: Akash Kaothalkar --- cmake/cpu_extension.cmake | 33 ++- csrc/cpu/cpu_types_vsx.hpp | 506 ++++++++++++++++++++++++++---------- csrc/cpu/quant.cpp | 346 +++++++++++++++++++++++- csrc/cpu/torch_bindings.cpp | 34 +++ 4 files changed, 772 insertions(+), 147 deletions(-) diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 00670bd398b5..fb763db9fc35 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -167,6 +167,33 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) FetchContent_MakeAvailable(oneDNN) + list(APPEND LIBS dnnl) +elseif(POWER10_FOUND) + FetchContent_Declare( + oneDNN + GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git + GIT_TAG v3.7.2 + GIT_PROGRESS TRUE + GIT_SHALLOW TRUE + ) + + set(ONEDNN_LIBRARY_TYPE "STATIC") + set(ONEDNN_BUILD_DOC "OFF") + set(ONEDNN_BUILD_EXAMPLES "OFF") + set(ONEDNN_BUILD_TESTS "OFF") + set(ONEDNN_ENABLE_WORKLOAD "INFERENCE") + set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") + set(ONEDNN_BUILD_GRAPH "OFF") + set(ONEDNN_ENABLE_JIT_PROFILING "OFF") + set(ONEDNN_ENABLE_ITT_TASKS "OFF") + set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") + set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") + set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) + + set(DNNL_CPU_RUNTIME "OMP") + + FetchContent_MakeAvailable(oneDNN) + list(APPEND LIBS dnnl) endif() @@ -197,6 +224,10 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) "csrc/cpu/quant.cpp" "csrc/cpu/shm.cpp" ${VLLM_EXT_SRC}) +elseif(POWER10_FOUND) + set(VLLM_EXT_SRC + "csrc/cpu/quant.cpp" + ${VLLM_EXT_SRC}) endif() # @@ -214,4 +245,4 @@ define_gpu_extension_target( WITH_SOABI ) -message(STATUS "Enabling C extension.") \ No newline at end of file +message(STATUS "Enabling C extension.") diff --git a/csrc/cpu/cpu_types_vsx.hpp b/csrc/cpu/cpu_types_vsx.hpp index a8e1be37eb41..236a2b1b4161 100644 --- a/csrc/cpu/cpu_types_vsx.hpp +++ b/csrc/cpu/cpu_types_vsx.hpp @@ -4,45 +4,46 @@ #include #include +#include #include + + namespace vec_op { // FIXME: FP16 is not fully supported in Torch-CPU -#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) #ifndef CPU_OP_GUARD - #define CPU_KERNEL_GUARD_IN(NAME) - #define CPU_KERNEL_GUARD_OUT(NAME) +#define CPU_KERNEL_GUARD_IN(NAME) +#define CPU_KERNEL_GUARD_OUT(NAME) #else - #define CPU_KERNEL_GUARD_IN(NAME) \ - std::cout << #NAME << " invoked." << std::endl; - #define CPU_KERNEL_GUARD_OUT(NAME) \ - std::cout << #NAME << " exit." << std::endl; +#define CPU_KERNEL_GUARD_IN(NAME) \ + std::cout << #NAME << " invoked." << std::endl; +#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." 
<< std::endl; #endif #define FORCE_INLINE __attribute__((always_inline)) inline namespace { template -constexpr void unroll_loop_item(std::integer_sequence, F&& f) { +constexpr void unroll_loop_item(std::integer_sequence, F &&f) { (f(std::integral_constant{}), ...); } -}; // namespace +}; // namespace template >> -constexpr void unroll_loop(F&& f) { +constexpr void unroll_loop(F &&f) { unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); } -template -struct Vec { +template struct Vec { constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } }; @@ -62,6 +63,10 @@ typedef struct f32x4x4_t { __vector float val[4]; } f32x4x4_t; +typedef struct i32x4x4_t{ + __vector int32_t val[4]; +} i32x4x4_t; + struct FP32Vec8; struct FP32Vec16; @@ -70,14 +75,12 @@ struct BF16Vec8 : public Vec { __vector signed short reg; - explicit BF16Vec8(const void* ptr) - : reg((__vector signed short)vec_xl(0, (__vector signed short*)ptr)) {} + explicit BF16Vec8(const void *ptr) + : reg((__vector signed short)vec_xl(0, (__vector signed short *)ptr)) {} - explicit BF16Vec8(const FP32Vec8&); + explicit BF16Vec8(const FP32Vec8 &); - void save(void* ptr) const { - *reinterpret_cast<__vector signed short*>(ptr) = reg; - } + void save(void *ptr) const { *reinterpret_cast<__vector signed short *>(ptr) = reg; } }; struct BF16Vec16 : public Vec { @@ -85,18 +88,40 @@ struct BF16Vec16 : public Vec { ss16x8x2_t reg; - explicit BF16Vec16(const void* ptr) { + explicit BF16Vec16(const void *ptr) { // Load 256 bits in two parts - reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr); - reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr); + reg.val[0] = (__vector signed short)vec_xl(0, (signed short *)ptr); + reg.val[1] = (__vector signed short)vec_xl(16, (signed short *)ptr); } - explicit BF16Vec16(const FP32Vec16&); + explicit BF16Vec16(const FP32Vec16 &); - void save(void* ptr) const { + void save(void *ptr) const { // Save 256 bits in two parts - vec_xst(reg.val[0], 0, (signed short*)ptr); - vec_xst(reg.val[1], 16, (signed short*)ptr); + vec_xst(reg.val[0], 0, (signed short *)ptr); + vec_xst(reg.val[1], 16, (signed short *)ptr); + } + + void save(void* ptr, const int elem_num) const { + const int clamped_elem = std::max(0, std::min(elem_num, 16)); + + // Calculate elements to store in each 128-bit part (8 elements each) + const int elements_val0 = std::min(clamped_elem, 8); + const int elements_val1 = std::max(clamped_elem - 8, 0); + + // Convert elements to bytes (2 bytes per element) + const size_t bytes_val0 = elements_val0 * sizeof(signed short); + const size_t bytes_val1 = elements_val1 * sizeof(signed short); + + signed short* dest = static_cast(ptr); + // Store the first part using vec_xst_len + if (bytes_val0 > 0) { + vec_xst_len(reg.val[0], dest, bytes_val0); + } + // Store the second part if needed + if (bytes_val1 > 0) { + vec_xst_len(reg.val[1], dest + elements_val0, bytes_val1); + } } }; @@ -106,15 +131,19 @@ struct BF16Vec32 : public Vec { constexpr static int VEC_ELEM_NUM = 32; ss16x8x4_t reg; - explicit BF16Vec32(const void* ptr) - : reg(*reinterpret_cast(ptr)) {} + explicit BF16Vec32(const void *ptr) + : reg(*reinterpret_cast(ptr)) {} explicit BF16Vec32(ss16x8x4_t data) : reg(data) {} - explicit BF16Vec32(const BF16Vec8& vec8_data) - : reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {} + explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({ + vec8_data.reg, + vec8_data.reg, + vec8_data.reg, + vec8_data.reg + }) {} - void save(void* ptr) const { 
*reinterpret_cast(ptr) = reg; } + void save(void *ptr) const { *reinterpret_cast(ptr) = reg; } }; struct FP32Vec4 : public Vec { @@ -130,11 +159,11 @@ struct FP32Vec4 : public Vec { explicit FP32Vec4() : reg(vec_splats(0.0f)) {} - explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {} + explicit FP32Vec4(const float *ptr) : reg(vec_xl(0, ptr)) {} explicit FP32Vec4(__vector float data) : reg(data) {} - explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {} + explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} }; struct FP32Vec8 : public Vec { @@ -156,19 +185,19 @@ struct FP32Vec8 : public Vec { reg.val[1] = vec_splats(0.0f); } - explicit FP32Vec8(const float* ptr) { + explicit FP32Vec8(const float *ptr) { reg.val[0] = vec_xl(0, ptr); reg.val[1] = vec_xl(16, ptr); } explicit FP32Vec8(f32x4x2_t data) : reg(data) {} - explicit FP32Vec8(const FP32Vec8& data) { + explicit FP32Vec8(const FP32Vec8 &data) { reg.val[0] = data.reg.val[0]; reg.val[1] = data.reg.val[1]; } - explicit FP32Vec8(const BF16Vec8& v) { + explicit FP32Vec8(const BF16Vec8 &v) { reg.val[0] = (__vector float)vec_mergeh(zero, v.reg); reg.val[1] = (__vector float)vec_mergel(zero, v.reg); } @@ -177,8 +206,7 @@ struct FP32Vec8 : public Vec { AliasReg ar; ar.reg = reg; float result = 0; - unroll_loop( - [&result, &ar](int i) { result += ar.values[i]; }); + unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); return result; } @@ -231,32 +259,87 @@ struct FP32Vec8 : public Vec { return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); } - FP32Vec8 operator*(const FP32Vec8& b) const { - return FP32Vec8( - {vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])}); + FP32Vec8 operator*(const FP32Vec8 &b) const { + return FP32Vec8({vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])}); } - FP32Vec8 operator+(const FP32Vec8& b) const { - return FP32Vec8( - {vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])}); + FP32Vec8 operator+(const FP32Vec8 &b) const { + return FP32Vec8({vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])}); } - FP32Vec8 operator-(const FP32Vec8& b) const { - return FP32Vec8( - {vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])}); + FP32Vec8 operator-(const FP32Vec8 &b) const { + return FP32Vec8({vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])}); } - FP32Vec8 operator/(const FP32Vec8& b) const { - return FP32Vec8( - {vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])}); + FP32Vec8 operator/(const FP32Vec8 &b) const { + return FP32Vec8({vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])}); } - void save(float* ptr) const { + void save(float *ptr) const { vec_xst(reg.val[0], 0, ptr); vec_xst(reg.val[1], 16, ptr); } }; + +struct INT32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + i32x4x4_t reg; + int32_t values[VEC_ELEM_NUM]; + }; + + i32x4x4_t reg; + + explicit INT32Vec16(const void* data_ptr) { + reg.val[0] = vec_xl(0, reinterpret_cast(data_ptr)); + reg.val[1] = vec_xl(16, reinterpret_cast(data_ptr)); + reg.val[2] = vec_xl(32, reinterpret_cast(data_ptr)); + reg.val[3] = vec_xl(48, reinterpret_cast(data_ptr)); + } + + void save(int32_t* ptr) const { + vec_xst(reg.val[0], 0, reinterpret_cast<__vector int32_t*>(ptr)); + vec_xst(reg.val[1], 16, reinterpret_cast<__vector int32_t*>(ptr)); + vec_xst(reg.val[2], 32, reinterpret_cast<__vector int32_t*>(ptr)); + vec_xst(reg.val[3], 48, reinterpret_cast<__vector int32_t*>(ptr)); + 
} + + void save(int32_t* ptr, const int elem_num) const { + const int elements_in_chunk1 = (elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0; + const int elements_in_chunk2 = (elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0; + const int elements_in_chunk3 = (elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0; + const int elements_in_chunk4 = (elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0; + + const size_t bytes_chunk1 = static_cast(elements_in_chunk1 * sizeof(int32_t)); + const size_t bytes_chunk2 = static_cast(elements_in_chunk2 * sizeof(int32_t)); + const size_t bytes_chunk3 = static_cast(elements_in_chunk3 * sizeof(int32_t)); + const size_t bytes_chunk4 = static_cast(elements_in_chunk4 * sizeof(int32_t)); + + vec_xst_len( + reg.val[0], + reinterpret_cast(ptr), + bytes_chunk1 + ); + vec_xst_len( + reg.val[1], + reinterpret_cast(reinterpret_cast(ptr) + 16), + bytes_chunk2 + ); + vec_xst_len( + reg.val[2], + reinterpret_cast(reinterpret_cast(ptr) + 32), + bytes_chunk3 + ); + vec_xst_len( + reg.val[3], + reinterpret_cast(reinterpret_cast(ptr) + 48), + bytes_chunk4 + ); + } + +}; + struct FP32Vec16 : public Vec { constexpr static int VEC_ELEM_NUM = 16; union AliasReg { @@ -280,7 +363,7 @@ struct FP32Vec16 : public Vec { reg.val[3] = vec_splats(0.0f); } - explicit FP32Vec16(const float* ptr) { + explicit FP32Vec16(const float *ptr) { reg.val[0] = vec_xl(0, ptr); reg.val[1] = vec_xl(16, ptr); reg.val[2] = vec_xl(32, ptr); @@ -289,76 +372,191 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16(f32x4x4_t data) : reg(data) {} - explicit FP32Vec16(const FP32Vec16& data) { + explicit FP32Vec16(const FP32Vec16 &data) { reg.val[0] = data.reg.val[0]; reg.val[1] = data.reg.val[1]; reg.val[2] = data.reg.val[2]; reg.val[3] = data.reg.val[3]; } - explicit FP32Vec16(const FP32Vec4& data) { + explicit FP32Vec16(const FP32Vec4 &data) { reg.val[0] = data.reg; reg.val[1] = data.reg; reg.val[2] = data.reg; reg.val[3] = data.reg; } - explicit FP32Vec16(const FP32Vec8& data) { + explicit FP32Vec16(const FP32Vec8 &data) { reg.val[0] = data.reg.val[0]; reg.val[1] = data.reg.val[1]; reg.val[2] = data.reg.val[0]; reg.val[3] = data.reg.val[1]; } - explicit FP32Vec16(const BF16Vec16& v) { + explicit FP32Vec16(const BF16Vec16 &v) { reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]); reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]); reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]); reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]); } - explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} + explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + + explicit FP32Vec16(const INT32Vec16& v){ + reg.val[0] = vec_ctf(v.reg.val[0], 0); + reg.val[1] = vec_ctf(v.reg.val[1], 0); + reg.val[2] = vec_ctf(v.reg.val[2], 0); + reg.val[3] = vec_ctf(v.reg.val[3], 0); + } + + FP32Vec16 operator*(const FP32Vec16 &b) const { + return FP32Vec16(f32x4x4_t({ + vec_mul(reg.val[0], b.reg.val[0]), + vec_mul(reg.val[1], b.reg.val[1]), + vec_mul(reg.val[2], b.reg.val[2]), + vec_mul(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 operator+(const FP32Vec16 &b) const { + return FP32Vec16(f32x4x4_t({ + vec_add(reg.val[0], b.reg.val[0]), + vec_add(reg.val[1], b.reg.val[1]), + vec_add(reg.val[2], b.reg.val[2]), + vec_add(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 operator-(const FP32Vec16 &b) const { + return FP32Vec16(f32x4x4_t({ + vec_sub(reg.val[0], b.reg.val[0]), + vec_sub(reg.val[1], b.reg.val[1]), + vec_sub(reg.val[2], b.reg.val[2]), + 
vec_sub(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 operator/(const FP32Vec16 &b) const { + return FP32Vec16(f32x4x4_t({ + vec_div(reg.val[0], b.reg.val[0]), + vec_div(reg.val[1], b.reg.val[1]), + vec_div(reg.val[2], b.reg.val[2]), + vec_div(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const { + return FP32Vec16(f32x4x4_t({ + vec_min(max.reg.val[0], vec_max(min.reg.val[0], reg.val[0])), + vec_min(max.reg.val[1], vec_max(min.reg.val[1], reg.val[1])), + vec_min(max.reg.val[2], vec_max(min.reg.val[2], reg.val[2])), + vec_min(max.reg.val[3], vec_max(min.reg.val[3], reg.val[3])) + })); + } + + FP32Vec16 max(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({ + vec_max(reg.val[0], b.reg.val[0]), + vec_max(reg.val[1], b.reg.val[1]), + vec_max(reg.val[2], b.reg.val[2]), + vec_max(reg.val[3], b.reg.val[3]) + })); + } - FP32Vec16 operator*(const FP32Vec16& b) const { - return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]), - vec_mul(reg.val[1], b.reg.val[1]), - vec_mul(reg.val[2], b.reg.val[2]), - vec_mul(reg.val[3], b.reg.val[3])})); + FP32Vec16 max(const FP32Vec16& b, int elem_num) const { + FP32Vec16 result; + + // Create a vector of element indices for each chunk + __vector unsigned int indices = {0, 1, 2, 3}; + __vector unsigned int elem_num_vec = vec_splats(static_cast(elem_num)); + + // Compute masks for each chunk + __vector unsigned int chunk_offset0 = {0, 0, 0, 0}; // Chunk 0: Elements 0-3 + __vector unsigned int chunk_offset1 = {4, 4, 4, 4}; // Chunk 1: Elements 4-7 + __vector unsigned int chunk_offset2 = {8, 8, 8, 8}; // Chunk 2: Elements 8-11 + __vector unsigned int chunk_offset3 = {12, 12, 12, 12}; // Chunk 3: Elements 12-15 + + // Compute masks for each chunk + __vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec); + __vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec); + __vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec); + __vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec); + + // Apply masks to compute the result for each chunk + result.reg.val[0] = vec_sel(this->reg.val[0], vec_max(this->reg.val[0], b.reg.val[0]), mask0); + result.reg.val[1] = vec_sel(this->reg.val[1], vec_max(this->reg.val[1], b.reg.val[1]), mask1); + result.reg.val[2] = vec_sel(this->reg.val[2], vec_max(this->reg.val[2], b.reg.val[2]), mask2); + result.reg.val[3] = vec_sel(this->reg.val[3], vec_max(this->reg.val[3], b.reg.val[3]), mask3); + + return FP32Vec16(result.reg); + } + + FP32Vec16 min(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({ + vec_min(reg.val[0], b.reg.val[0]), + vec_min(reg.val[1], b.reg.val[1]), + vec_min(reg.val[2], b.reg.val[2]), + vec_min(reg.val[3], b.reg.val[3]) + })); } - FP32Vec16 operator+(const FP32Vec16& b) const { - return FP32Vec16(f32x4x4_t({vec_add(reg.val[0], b.reg.val[0]), - vec_add(reg.val[1], b.reg.val[1]), - vec_add(reg.val[2], b.reg.val[2]), - vec_add(reg.val[3], b.reg.val[3])})); +FP32Vec16 min(const FP32Vec16& b, int elem_num) const { + FP32Vec16 result; + + vector unsigned int indices = {0, 1, 2, 3}; + vector unsigned int elem_num_vec = vec_splats(static_cast(elem_num)); + + vector unsigned int chunk_offset0 = {0, 0, 0, 0}; + vector unsigned int chunk_offset1 = {4, 4, 4, 4}; + vector unsigned int chunk_offset2 = {8, 8, 8, 8}; + vector unsigned int chunk_offset3 = {12, 12, 12, 12}; + + vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec); + vector bool int mask1 = vec_cmplt(indices 
+ chunk_offset1, elem_num_vec); + vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec); + vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec); + + result.reg.val[0] = vec_sel(this->reg.val[0], vec_min(this->reg.val[0], b.reg.val[0]), mask0); + result.reg.val[1] = vec_sel(this->reg.val[1], vec_min(this->reg.val[1], b.reg.val[1]), mask1); + result.reg.val[2] = vec_sel(this->reg.val[2], vec_min(this->reg.val[2], b.reg.val[2]), mask2); + result.reg.val[3] = vec_sel(this->reg.val[3], vec_min(this->reg.val[3], b.reg.val[3]), mask3); + + return FP32Vec16(result.reg); + } + + FP32Vec16 abs() const { + return FP32Vec16(f32x4x4_t({ + vec_abs(reg.val[0]), + vec_abs(reg.val[1]), + vec_abs(reg.val[2]), + vec_abs(reg.val[3]) + })); } - FP32Vec16 operator-(const FP32Vec16& b) const { - return FP32Vec16(f32x4x4_t({vec_sub(reg.val[0], b.reg.val[0]), - vec_sub(reg.val[1], b.reg.val[1]), - vec_sub(reg.val[2], b.reg.val[2]), - vec_sub(reg.val[3], b.reg.val[3])})); + float reduce_max() { + __vector float max01 = vec_max(reg.val[0], reg.val[1]); + __vector float max23 = vec_max(reg.val[2], reg.val[3]); + __vector float max_all = vec_max(max01, max23); + __vector float temp = vec_max(max_all, vec_sld(max_all, max_all, 8)); + temp = vec_max(temp, vec_sld(temp, temp, 4)); + return vec_extract(temp, 0); } - FP32Vec16 operator/(const FP32Vec16& b) const { - return FP32Vec16(f32x4x4_t({vec_div(reg.val[0], b.reg.val[0]), - vec_div(reg.val[1], b.reg.val[1]), - vec_div(reg.val[2], b.reg.val[2]), - vec_div(reg.val[3], b.reg.val[3])})); + float reduce_min() { + __vector float min01 = vec_min(reg.val[0], reg.val[1]); + __vector float min23 = vec_min(reg.val[2], reg.val[3]); + __vector float min_all = vec_min(min01, min23); + __vector float temp = vec_min(min_all, vec_sld(min_all, min_all, 8)); + temp = vec_min(temp, vec_sld(temp, temp, 4)); + return vec_extract(temp, 0); } float reduce_sum() const { AliasReg ar; ar.reg = reg; float result = 0; - unroll_loop( - [&result, &ar](int i) { result += ar.values[i]; }); + unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); return result; } - template - float reduce_sub_sum(int idx) { + template float reduce_sub_sum(int idx) { static_assert(VEC_ELEM_NUM % group_size == 0); AliasReg ar; @@ -371,70 +569,104 @@ struct FP32Vec16 : public Vec { return result; } - void save(float* ptr) const { + void save(float *ptr) const { vec_xst(reg.val[0], 0, ptr); vec_xst(reg.val[1], 16, ptr); vec_xst(reg.val[2], 32, ptr); vec_xst(reg.val[3], 48, ptr); } -}; -template -struct VecType { - using vec_type = void; + void save(float* ptr, const int elem_num) const { + const int elements_in_chunk1 = (elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0; + const int elements_in_chunk2 = (elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0; + const int elements_in_chunk3 = (elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0; + const int elements_in_chunk4 = (elem_num > 12) ? ((elem_num >= 16) ? 
4 : elem_num - 12) : 0; + + const size_t bytes_chunk1 = static_cast(elements_in_chunk1 * sizeof(float)); + const size_t bytes_chunk2 = static_cast(elements_in_chunk2 * sizeof(float)); + const size_t bytes_chunk3 = static_cast(elements_in_chunk3 * sizeof(float)); + const size_t bytes_chunk4 = static_cast(elements_in_chunk4 * sizeof(float)); + + vec_xst_len(reg.val[0], ptr, bytes_chunk1); + vec_xst_len(reg.val[1], + reinterpret_cast(reinterpret_cast(ptr) + 16), bytes_chunk2); + vec_xst_len(reg.val[2], + reinterpret_cast(reinterpret_cast(ptr) + 32), bytes_chunk3); + vec_xst_len(reg.val[3], + reinterpret_cast(reinterpret_cast(ptr) + 48), bytes_chunk4); +} + }; -template -using vec_t = typename VecType::vec_type; +struct INT8Vec16 : public Vec { + constexpr static int VEC_NUM_ELEM = 16; // 128 bits / 8 bits = 16 -template <> -struct VecType { - using vec_type = FP32Vec8; -}; + union AliasReg { + __vector signed char reg; + int8_t values[VEC_NUM_ELEM]; + }; + + __vector signed char reg; + + explicit INT8Vec16(const FP32Vec16& vec) { + __vector signed int ret[4]; + ret[0] = vec_cts(vec.reg.val[0], 0); + ret[1] = vec_cts(vec.reg.val[1], 0); + ret[2] = vec_cts(vec.reg.val[2], 0); + ret[3] = vec_cts(vec.reg.val[3], 0); + + __vector signed short packed1 = vec_packs(ret[0], ret[1]); + __vector signed short packed2 = vec_packs(ret[2], ret[3]); + + reg = vec_packs(packed1, packed2); + } -template <> -struct VecType { - using vec_type = BF16Vec8; + void save(void *ptr) const { + *reinterpret_cast<__vector signed char *>(ptr) = reg; + } + void save(signed char* ptr, const int elem_num) { + vec_xst_len(reg, ptr, static_cast(elem_num)); + } }; -template -void storeFP32(float v, T* ptr) { - *ptr = v; -} -inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) { +template struct VecType { using vec_type = void; }; + +template using vec_t = typename VecType::vec_type; + +template <> struct VecType { using vec_type = FP32Vec8; }; + +template <> struct VecType { using vec_type = BF16Vec8; }; + +template void storeFP32(float v, T *ptr) { *ptr = v; } + +inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { acc = acc + a * b; } -template <> -inline void storeFP32(float v, c10::BFloat16* ptr) { - c10::BFloat16 __attribute__((__may_alias__))* v_ptr = - reinterpret_cast(&v); +template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { + c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = + reinterpret_cast(&v); *ptr = *(v_ptr + 1); } #ifndef __VEC_CLASS_FP_NAN - #define __VEC_CLASS_FP_NAN (1 << 6) +#define __VEC_CLASS_FP_NAN (1 << 6) #endif -const static __vector unsigned char omask = {0, 1, 4, 5, 8, 9, 12, 13, - 16, 17, 20, 21, 24, 25, 28, 29}; +const static __vector unsigned char omask = { 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 }; #ifndef _ARCH_PWR10 -const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff, - 0x00007fff}; -const static __vector unsigned int nan = {0x7fc00000, 0x7fc00000, 0x7fc00000, - 0x7fc00000}; -const static __vector unsigned int sh16 = {16, 16, 16, 16}; -const static __vector unsigned int one = {1, 1, 1, 1}; +const static __vector unsigned int bias = { 0x00007fff, 0x00007fff, 0x00007fff, 0x00007fff }; +const static __vector unsigned int nan = { 0x7fc00000, 0x7fc00000, 0x7fc00000, 0x7fc00000 }; +const static __vector unsigned int sh16 = { 16, 16, 16, 16 }; +const static __vector unsigned int one = { 1, 1, 1, 1 }; #endif -inline BF16Vec8::BF16Vec8(const FP32Vec8& v) { +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) { #ifdef _ARCH_PWR10 
__vector signed short ret[2]; - ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16( - (__vector unsigned char)v.reg.val[0]); - ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16( - (__vector unsigned char)v.reg.val[1]); + ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]); + ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]); reg = vec_perm(ret[0], ret[1], omask); #elif defined(_ARCH_PWR9) __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]); @@ -447,10 +679,8 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8& v) { __vector unsigned int rnd1 = vec_add(lsb1, bias); inp0 = vec_add(inp0, rnd0); inp1 = vec_add(inp1, rnd1); - __vector __bool int sel0 = - vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); - __vector __bool int sel1 = - vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); + __vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); + __vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); inp0 = vec_sel(inp0, nan, sel0); inp1 = vec_sel(inp1, nan, sel1); inp0 = vec_sr(inp0, sh16); @@ -459,17 +689,13 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8& v) { #endif } -inline BF16Vec16::BF16Vec16(const FP32Vec16& v) { +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { #ifdef _ARCH_PWR10 __vector signed short ret[4]; - ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16( - (__vector unsigned char)v.reg.val[0]); - ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16( - (__vector unsigned char)v.reg.val[1]); - ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16( - (__vector unsigned char)v.reg.val[2]); - ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16( - (__vector unsigned char)v.reg.val[3]); + ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]); + ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]); + ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[2]); + ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[3]); reg.val[0] = vec_perm(ret[0], ret[1], omask); reg.val[1] = vec_perm(ret[2], ret[3], omask); #elif defined(_ARCH_PWR9) @@ -493,14 +719,10 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) { inp1 = vec_add(inp1, rnd1); inp2 = vec_add(inp2, rnd2); inp3 = vec_add(inp3, rnd3); - __vector __bool int sel0 = - vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); - __vector __bool int sel1 = - vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); - __vector __bool int sel2 = - vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN); - __vector __bool int sel3 = - vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN); + __vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); + __vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); + __vector __bool int sel2 = vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN); + __vector __bool int sel3 = vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN); inp0 = vec_sel(inp0, nan, sel0); inp1 = vec_sel(inp1, nan, sel1); inp2 = vec_sel(inp2, nan, sel2); @@ -514,10 +736,10 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) { #endif } -inline void prefetch(const void* addr) { +inline void prefetch(const void *addr) { __asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory"); } -}; // namespace vec_op +}; // namespace vec_op #endif diff --git 
a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp
index 6751e7e55fc5..80047caf6048 100644
--- a/csrc/cpu/quant.cpp
+++ b/csrc/cpu/quant.cpp
@@ -318,13 +318,294 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output,
     elems_out.save(output + i * hidden_size + j, hidden_size - j);
   }
 }
+#elif defined(__powerpc64__)
+template <bool AZP, typename scalar_t>
+void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
+                                   const float* scale, const int32_t* azp,
+                                   const int num_tokens,
+                                   const int hidden_size) {
+  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
+  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
+  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
+
+  constexpr float i8_min =
+      static_cast<float>(std::numeric_limits<int8_t>::min());
+  constexpr float i8_max =
+      static_cast<float>(std::numeric_limits<int8_t>::max());
+
+  const cvt_vec_t inv_scale(1.0 / *scale);
+  const cvt_vec_t i8_min_vec(i8_min);
+  const cvt_vec_t i8_max_vec(i8_max);
+
+  cvt_vec_t zp_vec;
+  if constexpr (AZP) {
+    zp_vec = cvt_vec_t(static_cast<float>(*azp));
+  }
+  #pragma omp parallel for
+  for (int i = 0; i < num_tokens; ++i) {
+    int j = 0;
+    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
+      load_vec_t elems(input + i * hidden_size + j);
+      cvt_vec_t elems_fp32(elems);
+      elems_fp32 = elems_fp32 * inv_scale;
+      if constexpr (AZP) {
+        elems_fp32 = elems_fp32 + zp_vec;
+      }
+      elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
+      vec_op::INT8Vec16 elems_int8(elems_fp32);
+      elems_int8.save(output + i * hidden_size + j);
+    }
+    load_vec_t elems(input + i * hidden_size + j);
+    cvt_vec_t elems_fp32(elems);
+    elems_fp32 = elems_fp32 * inv_scale;
+
+    if constexpr (AZP) {
+      elems_fp32 = elems_fp32 + zp_vec;
+    }
+
+    elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
+    vec_op::INT8Vec16 elems_int8(elems_fp32);
+    elems_int8.save(output + i * hidden_size + j, hidden_size - j);
+  }
+}
+template <bool AZP, typename scalar_t>
+void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
+                                    float* scale, int32_t* azp,
+                                    const int num_tokens,
+                                    const int hidden_size) {
+  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
+  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
+  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
+
+  constexpr float i8_min =
+      static_cast<float>(std::numeric_limits<int8_t>::min());
+  constexpr float i8_max =
+      static_cast<float>(std::numeric_limits<int8_t>::max());
+  const cvt_vec_t i8_min_vec(i8_min);
+  const cvt_vec_t i8_max_vec(i8_max);
+
+  #pragma omp parallel for
+  for (int i = 0; i < num_tokens; ++i) {
+    cvt_vec_t max_value(std::numeric_limits<float>::lowest());
+    cvt_vec_t min_value(std::numeric_limits<float>::max());
+    {
+      int j = 0;
+      for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
+        load_vec_t elems(input + i * hidden_size + j);
+        cvt_vec_t elems_fp32(elems);
+        if constexpr (AZP) {
+          max_value = max_value.max(elems_fp32);
+          min_value = min_value.min(elems_fp32);
+        } else {
+          max_value = max_value.max(elems_fp32.abs());
+        }
+      }
+
+      load_vec_t elems(input + i * hidden_size + j);
+      cvt_vec_t elems_fp32(elems);
+
+      if (j + vec_elem_num == hidden_size) {
+        if constexpr (AZP) {
+          max_value = max_value.max(elems_fp32);
+          min_value = min_value.min(elems_fp32);
+        } else {
+          max_value = max_value.max(elems_fp32.abs());
+        }
+      } else {
+        if constexpr (AZP) {
+          max_value = max_value.max(elems_fp32, hidden_size - j);
+          min_value = min_value.min(elems_fp32, hidden_size - j);
+        } else {
+          max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
+        }
+      }
+    }
+
+    float scale_val, azp_val;
+    if constexpr (AZP) {
+      float max_scalar = 
max_value.reduce_max(); + float min_scalar = min_value.reduce_min(); + scale_val = (max_scalar - min_scalar) / 255.0f; + azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); + azp[i] = static_cast(azp_val); + scale[i] = scale_val; + } else { + scale_val = max_value.reduce_max() / 127.0f; + scale[i] = scale_val; + } + + const cvt_vec_t inv_scale(1.0 / scale_val); + const cvt_vec_t azp_vec(azp_val); + + { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j); + } + + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j, hidden_size - j); + } + } +} +template +void static_quant_epilogue(const float* input, scalar_t* output, + const float a_scale, const float* b_scale, + const int32_t* azp_with_adj, const int num_tokens, + const int hidden_size){ + + + CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) + using load_vec_t = typename KernelVecType::load_vec_type; + using azp_adj_load_vec_t = + typename KernelVecType::azp_adj_load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + cvt_vec_t a_scale_vec(a_scale); + cvt_vec_t b_scale_vec(*b_scale); + cvt_vec_t scale_vec = a_scale_vec * b_scale_vec; + + + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + cvt_vec_t elems_fp32(input + i * hidden_size + j); + azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + + + if constexpr (PerChannel) { + b_scale_vec = cvt_vec_t(b_scale + j); + scale_vec = b_scale_vec * a_scale_vec; + } + elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j); + } + + cvt_vec_t elems_fp32(input + i * hidden_size + j); + azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + + if constexpr (PerChannel) { + b_scale_vec = cvt_vec_t(b_scale + j); + scale_vec = b_scale_vec * a_scale_vec; + } + + elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j, hidden_size - j); + } +} +template +void dynamic_quant_epilogue(const float* input, scalar_t* output, + const float* a_scale, const float* b_scale, + const int32_t* azp, const int32_t* azp_adj, + const scalar_t* bias, const int num_tokens, + const int hidden_size){ + CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) + using load_vec_t = typename KernelVecType::load_vec_type; + using azp_adj_load_vec_t = + typename KernelVecType::azp_adj_load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + int j = 0; + cvt_vec_t token_scale_vec(a_scale[i]); + cvt_vec_t token_zp_scale_vec; + if constexpr (AZP) { + float zp_scale_val = 
a_scale[i] * static_cast(azp[i]); + if constexpr (!PerChannel) { + zp_scale_val *= *b_scale; + } + token_zp_scale_vec = cvt_vec_t(zp_scale_val); + } + + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + cvt_vec_t elems_fp32(input + i * hidden_size + j); + elems_fp32 = elems_fp32 * token_scale_vec; + + if constexpr (AZP) { + azp_adj_load_vec_t azp_adj_vec(azp_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; + + if constexpr (PerChannel) { + cvt_vec_t b_scale_vec(b_scale + j); + azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; + } + + elems_fp32 = elems_fp32 - azp_adj_fp32; + } + + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j); + } + + cvt_vec_t elems_fp32(input + i * hidden_size + j); + elems_fp32 = elems_fp32 * token_scale_vec; + + if constexpr (AZP) { + azp_adj_load_vec_t azp_adj_vec(azp_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; + + if constexpr (PerChannel) { + cvt_vec_t b_scale_vec(b_scale + j); + azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; + } + + elems_fp32 = elems_fp32 - azp_adj_fp32; + } + + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j, hidden_size - j); + } +} #else template void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, const float* scale, const int32_t* azp, const int num_tokens, const int hidden_size) { - TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.") + TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512/powerpc64 support.") } template @@ -332,7 +613,7 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, float* scale, int32_t* azp, const int num_tokens, const int hidden_size) { - TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.") + TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512/powerpc64 support.") } template @@ -340,7 +621,7 @@ void static_quant_epilogue(const float* input, scalar_t* output, const float a_scale, const float* b_scale, const int32_t* azp_with_adj, const int num_tokens, const int hidden_size) { - TORCH_CHECK(false, "static_quant_epilogue requires AVX512 support.") + TORCH_CHECK(false, "static_quant_epilogue requires AVX512/powerpc64 support.") } template @@ -349,7 +630,7 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output, const int32_t* azp, const int32_t* azp_with_adj, const scalar_t* bias, const int num_tokens, const int hidden_size) { - TORCH_CHECK(false, "dynamic_quant_epilogue requires AVX512 support.") + TORCH_CHECK(false, "dynamic_quant_epilogue requires AVX512/powerpc64 support.") } #endif } // namespace @@ -611,3 +892,60 @@ void dynamic_scaled_int8_quant( } }); } + +#if defined(__powerpc64__) +void int8_scaled_mm_ppc64le(torch::Tensor& c, // [M, OC], row-major + const torch::Tensor& a, // [M, IC], row-major + const torch::Tensor& b, // [IC, OC], column-major + const torch::Tensor& a_scales, + const torch::Tensor& b_scales, + const std::optional& bias // [OC] +){ + CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) + // Checks for conformality + TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, + "int8_scaled_mm_ppc64le only 
supports INT8 inputs.");
+  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
+  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
+              b.size(1) == c.size(1));
+  // a_scales may be per-tensor or per-token; b_scales per-tensor or
+  // per-channel.
+  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
+  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
+
+  // Check for strides and alignment
+  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
+  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
+  TORCH_CHECK(c.stride(0) % 16 == 0 &&
+              b.stride(1) % 16 == 0);  // 16 Byte Alignment
+  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+
+  if (bias) {
+    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
+                bias->dim() == 1);
+  }
+  VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_ppc64le", [&] {
+    torch::Tensor tmp_fp32_out =
+        torch::empty_like(c, ::at::ScalarType::Float);
+    // Compute C_inter=s_b * (A@B)
+    DNNLPrimitiveHelper::gemm_s8s8_jit(
+        a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
+        tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
+        a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
+    if (bias.has_value()) {
+      // Compute C=s_a * C_inter + bias
+      dynamic_quant_epilogue(
+          tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
+          a_scales.data_ptr<float>(), nullptr, nullptr, nullptr,
+          bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
+    } else {
+      // Compute C=s_a * C_inter
+      dynamic_quant_epilogue(
+          tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
+          a_scales.data_ptr<float>(), nullptr, nullptr, nullptr, nullptr,
+          c.size(0), c.size(1));
+    }
+  });
+}
+
+#endif
+
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index 7ae7e3386b4e..fae680f166c7 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -18,6 +18,13 @@ void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
                         const std::optional<torch::Tensor>& azp,
                         const std::optional<torch::Tensor>& bias);
 
+#if defined(__powerpc64__)
+void int8_scaled_mm_ppc64le(torch::Tensor& c, const torch::Tensor& a,
+                            const torch::Tensor& b, const torch::Tensor& a_scales,
+                            const torch::Tensor& b_scales,
+                            const std::optional<torch::Tensor>& bias);
+#endif
+
 void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
                         torch::Tensor& kv_cache, double scale,
                         torch::Tensor& block_tables, torch::Tensor& seq_lens);
@@ -150,6 +157,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor b_scales, Tensor azp_adj,"
       " Tensor? azp, Tensor? bias) -> ()");
   ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
+#elif defined(__powerpc64__)
+  // Compute int8 quantized tensor for given scaling factor.
+  ops.def(
+      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
+      "Tensor? azp) -> ()");
+  ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
+
+  // Compute int8 quantized tensor and scaling factor
+  ops.def(
+      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
+      "Tensor!? azp) -> ()");
+  ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
+           &dynamic_scaled_int8_quant);
+  // W8A8 GEMM, supporting symmetric quantization.
+  ops.def(
+      "cutlass_scaled_mm(Tensor! out, Tensor a,"
+      " Tensor b, Tensor a_scales,"
+      " Tensor b_scales, Tensor? bias) -> ()");
+  ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm_ppc64le);
+  // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
+  // quantization.
+  ops.def(
+      "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
+      " Tensor b, Tensor a_scales,"
+      " Tensor b_scales, Tensor azp_adj,"
+      " Tensor? azp, Tensor? 
bias) -> ()"); + ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); #endif // SHM CCL From f89176e5c2e7f229509593a30f6dd6b05753ce39 Mon Sep 17 00:00:00 2001 From: Akash Kaothalkar Date: Tue, 22 Apr 2025 13:32:02 -0500 Subject: [PATCH 2/2] chore: ran pre-commit Signed-off-by: Akash Kaothalkar --- csrc/cpu/cpu_types_vsx.hpp | 567 +++++++++++++++++++----------------- csrc/cpu/quant.cpp | 79 +++-- csrc/cpu/torch_bindings.cpp | 7 +- 3 files changed, 346 insertions(+), 307 deletions(-) diff --git a/csrc/cpu/cpu_types_vsx.hpp b/csrc/cpu/cpu_types_vsx.hpp index 236a2b1b4161..089b9840ea2e 100644 --- a/csrc/cpu/cpu_types_vsx.hpp +++ b/csrc/cpu/cpu_types_vsx.hpp @@ -7,43 +7,43 @@ #include #include - - namespace vec_op { // FIXME: FP16 is not fully supported in Torch-CPU -#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) #ifndef CPU_OP_GUARD -#define CPU_KERNEL_GUARD_IN(NAME) -#define CPU_KERNEL_GUARD_OUT(NAME) + #define CPU_KERNEL_GUARD_IN(NAME) + #define CPU_KERNEL_GUARD_OUT(NAME) #else -#define CPU_KERNEL_GUARD_IN(NAME) \ - std::cout << #NAME << " invoked." << std::endl; -#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; + #define CPU_KERNEL_GUARD_IN(NAME) \ + std::cout << #NAME << " invoked." << std::endl; + #define CPU_KERNEL_GUARD_OUT(NAME) \ + std::cout << #NAME << " exit." << std::endl; #endif #define FORCE_INLINE __attribute__((always_inline)) inline namespace { template -constexpr void unroll_loop_item(std::integer_sequence, F &&f) { +constexpr void unroll_loop_item(std::integer_sequence, F&& f) { (f(std::integral_constant{}), ...); } -}; // namespace +}; // namespace template >> -constexpr void unroll_loop(F &&f) { +constexpr void unroll_loop(F&& f) { unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); } -template struct Vec { +template +struct Vec { constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } }; @@ -63,7 +63,7 @@ typedef struct f32x4x4_t { __vector float val[4]; } f32x4x4_t; -typedef struct i32x4x4_t{ +typedef struct i32x4x4_t { __vector int32_t val[4]; } i32x4x4_t; @@ -75,12 +75,14 @@ struct BF16Vec8 : public Vec { __vector signed short reg; - explicit BF16Vec8(const void *ptr) - : reg((__vector signed short)vec_xl(0, (__vector signed short *)ptr)) {} + explicit BF16Vec8(const void* ptr) + : reg((__vector signed short)vec_xl(0, (__vector signed short*)ptr)) {} - explicit BF16Vec8(const FP32Vec8 &); + explicit BF16Vec8(const FP32Vec8&); - void save(void *ptr) const { *reinterpret_cast<__vector signed short *>(ptr) = reg; } + void save(void* ptr) const { + *reinterpret_cast<__vector signed short*>(ptr) = reg; + } }; struct BF16Vec16 : public Vec { @@ -88,39 +90,39 @@ struct BF16Vec16 : public Vec { ss16x8x2_t reg; - explicit BF16Vec16(const void *ptr) { + explicit BF16Vec16(const void* ptr) { // Load 256 bits in two parts - reg.val[0] = (__vector signed short)vec_xl(0, (signed short *)ptr); - reg.val[1] = (__vector signed short)vec_xl(16, (signed short *)ptr); + reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr); + reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr); } - 
explicit BF16Vec16(const FP32Vec16 &); + explicit BF16Vec16(const FP32Vec16&); - void save(void *ptr) const { + void save(void* ptr) const { // Save 256 bits in two parts - vec_xst(reg.val[0], 0, (signed short *)ptr); - vec_xst(reg.val[1], 16, (signed short *)ptr); + vec_xst(reg.val[0], 0, (signed short*)ptr); + vec_xst(reg.val[1], 16, (signed short*)ptr); } void save(void* ptr, const int elem_num) const { - const int clamped_elem = std::max(0, std::min(elem_num, 16)); - + const int clamped_elem = std::max(0, std::min(elem_num, 16)); + // Calculate elements to store in each 128-bit part (8 elements each) const int elements_val0 = std::min(clamped_elem, 8); const int elements_val1 = std::max(clamped_elem - 8, 0); - + // Convert elements to bytes (2 bytes per element) const size_t bytes_val0 = elements_val0 * sizeof(signed short); const size_t bytes_val1 = elements_val1 * sizeof(signed short); signed short* dest = static_cast(ptr); - // Store the first part using vec_xst_len + // Store the first part using vec_xst_len if (bytes_val0 > 0) { - vec_xst_len(reg.val[0], dest, bytes_val0); + vec_xst_len(reg.val[0], dest, bytes_val0); } - // Store the second part if needed + // Store the second part if needed if (bytes_val1 > 0) { - vec_xst_len(reg.val[1], dest + elements_val0, bytes_val1); + vec_xst_len(reg.val[1], dest + elements_val0, bytes_val1); } } }; @@ -131,19 +133,15 @@ struct BF16Vec32 : public Vec { constexpr static int VEC_ELEM_NUM = 32; ss16x8x4_t reg; - explicit BF16Vec32(const void *ptr) - : reg(*reinterpret_cast(ptr)) {} + explicit BF16Vec32(const void* ptr) + : reg(*reinterpret_cast(ptr)) {} explicit BF16Vec32(ss16x8x4_t data) : reg(data) {} - explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({ - vec8_data.reg, - vec8_data.reg, - vec8_data.reg, - vec8_data.reg - }) {} + explicit BF16Vec32(const BF16Vec8& vec8_data) + : reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {} - void save(void *ptr) const { *reinterpret_cast(ptr) = reg; } + void save(void* ptr) const { *reinterpret_cast(ptr) = reg; } }; struct FP32Vec4 : public Vec { @@ -159,11 +157,11 @@ struct FP32Vec4 : public Vec { explicit FP32Vec4() : reg(vec_splats(0.0f)) {} - explicit FP32Vec4(const float *ptr) : reg(vec_xl(0, ptr)) {} + explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {} explicit FP32Vec4(__vector float data) : reg(data) {} - explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} + explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {} }; struct FP32Vec8 : public Vec { @@ -185,19 +183,19 @@ struct FP32Vec8 : public Vec { reg.val[1] = vec_splats(0.0f); } - explicit FP32Vec8(const float *ptr) { + explicit FP32Vec8(const float* ptr) { reg.val[0] = vec_xl(0, ptr); reg.val[1] = vec_xl(16, ptr); } explicit FP32Vec8(f32x4x2_t data) : reg(data) {} - explicit FP32Vec8(const FP32Vec8 &data) { + explicit FP32Vec8(const FP32Vec8& data) { reg.val[0] = data.reg.val[0]; reg.val[1] = data.reg.val[1]; } - explicit FP32Vec8(const BF16Vec8 &v) { + explicit FP32Vec8(const BF16Vec8& v) { reg.val[0] = (__vector float)vec_mergeh(zero, v.reg); reg.val[1] = (__vector float)vec_mergel(zero, v.reg); } @@ -206,7 +204,8 @@ struct FP32Vec8 : public Vec { AliasReg ar; ar.reg = reg; float result = 0; - unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); + unroll_loop( + [&result, &ar](int i) { result += ar.values[i]; }); return result; } @@ -259,29 +258,32 @@ struct FP32Vec8 : public Vec { return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); } - FP32Vec8 operator*(const FP32Vec8 &b) const { - 
return FP32Vec8({vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])}); + FP32Vec8 operator*(const FP32Vec8& b) const { + return FP32Vec8( + {vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])}); } - FP32Vec8 operator+(const FP32Vec8 &b) const { - return FP32Vec8({vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])}); + FP32Vec8 operator+(const FP32Vec8& b) const { + return FP32Vec8( + {vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])}); } - FP32Vec8 operator-(const FP32Vec8 &b) const { - return FP32Vec8({vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])}); + FP32Vec8 operator-(const FP32Vec8& b) const { + return FP32Vec8( + {vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])}); } - FP32Vec8 operator/(const FP32Vec8 &b) const { - return FP32Vec8({vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])}); + FP32Vec8 operator/(const FP32Vec8& b) const { + return FP32Vec8( + {vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])}); } - void save(float *ptr) const { + void save(float* ptr) const { vec_xst(reg.val[0], 0, ptr); vec_xst(reg.val[1], 16, ptr); } }; - struct INT32Vec16 : public Vec { constexpr static int VEC_ELEM_NUM = 16; union AliasReg { @@ -293,9 +295,12 @@ struct INT32Vec16 : public Vec { explicit INT32Vec16(const void* data_ptr) { reg.val[0] = vec_xl(0, reinterpret_cast(data_ptr)); - reg.val[1] = vec_xl(16, reinterpret_cast(data_ptr)); - reg.val[2] = vec_xl(32, reinterpret_cast(data_ptr)); - reg.val[3] = vec_xl(48, reinterpret_cast(data_ptr)); + reg.val[1] = + vec_xl(16, reinterpret_cast(data_ptr)); + reg.val[2] = + vec_xl(32, reinterpret_cast(data_ptr)); + reg.val[3] = + vec_xl(48, reinterpret_cast(data_ptr)); } void save(int32_t* ptr) const { @@ -306,38 +311,35 @@ struct INT32Vec16 : public Vec { } void save(int32_t* ptr, const int elem_num) const { - const int elements_in_chunk1 = (elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0; - const int elements_in_chunk2 = (elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0; - const int elements_in_chunk3 = (elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0; - const int elements_in_chunk4 = (elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0; - - const size_t bytes_chunk1 = static_cast(elements_in_chunk1 * sizeof(int32_t)); - const size_t bytes_chunk2 = static_cast(elements_in_chunk2 * sizeof(int32_t)); - const size_t bytes_chunk3 = static_cast(elements_in_chunk3 * sizeof(int32_t)); - const size_t bytes_chunk4 = static_cast(elements_in_chunk4 * sizeof(int32_t)); - - vec_xst_len( - reg.val[0], - reinterpret_cast(ptr), - bytes_chunk1 - ); - vec_xst_len( - reg.val[1], - reinterpret_cast(reinterpret_cast(ptr) + 16), - bytes_chunk2 - ); - vec_xst_len( - reg.val[2], - reinterpret_cast(reinterpret_cast(ptr) + 32), - bytes_chunk3 - ); - vec_xst_len( - reg.val[3], - reinterpret_cast(reinterpret_cast(ptr) + 48), - bytes_chunk4 - ); + const int elements_in_chunk1 = + (elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0; + const int elements_in_chunk2 = + (elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0; + const int elements_in_chunk3 = + (elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0; + const int elements_in_chunk4 = + (elem_num > 12) ? ((elem_num >= 16) ? 
4 : elem_num - 12) : 0; + + const size_t bytes_chunk1 = + static_cast(elements_in_chunk1 * sizeof(int32_t)); + const size_t bytes_chunk2 = + static_cast(elements_in_chunk2 * sizeof(int32_t)); + const size_t bytes_chunk3 = + static_cast(elements_in_chunk3 * sizeof(int32_t)); + const size_t bytes_chunk4 = + static_cast(elements_in_chunk4 * sizeof(int32_t)); + + vec_xst_len(reg.val[0], reinterpret_cast(ptr), bytes_chunk1); + vec_xst_len(reg.val[1], + reinterpret_cast(reinterpret_cast(ptr) + 16), + bytes_chunk2); + vec_xst_len(reg.val[2], + reinterpret_cast(reinterpret_cast(ptr) + 32), + bytes_chunk3); + vec_xst_len(reg.val[3], + reinterpret_cast(reinterpret_cast(ptr) + 48), + bytes_chunk4); } - }; struct FP32Vec16 : public Vec { @@ -363,7 +365,7 @@ struct FP32Vec16 : public Vec { reg.val[3] = vec_splats(0.0f); } - explicit FP32Vec16(const float *ptr) { + explicit FP32Vec16(const float* ptr) { reg.val[0] = vec_xl(0, ptr); reg.val[1] = vec_xl(16, ptr); reg.val[2] = vec_xl(32, ptr); @@ -372,161 +374,162 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16(f32x4x4_t data) : reg(data) {} - explicit FP32Vec16(const FP32Vec16 &data) { + explicit FP32Vec16(const FP32Vec16& data) { reg.val[0] = data.reg.val[0]; reg.val[1] = data.reg.val[1]; reg.val[2] = data.reg.val[2]; reg.val[3] = data.reg.val[3]; } - explicit FP32Vec16(const FP32Vec4 &data) { + explicit FP32Vec16(const FP32Vec4& data) { reg.val[0] = data.reg; reg.val[1] = data.reg; reg.val[2] = data.reg; reg.val[3] = data.reg; } - explicit FP32Vec16(const FP32Vec8 &data) { + explicit FP32Vec16(const FP32Vec8& data) { reg.val[0] = data.reg.val[0]; reg.val[1] = data.reg.val[1]; reg.val[2] = data.reg.val[0]; reg.val[3] = data.reg.val[1]; } - explicit FP32Vec16(const BF16Vec16 &v) { + explicit FP32Vec16(const BF16Vec16& v) { reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]); reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]); reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]); reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]); } - explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} - explicit FP32Vec16(const INT32Vec16& v){ + explicit FP32Vec16(const INT32Vec16& v) { reg.val[0] = vec_ctf(v.reg.val[0], 0); reg.val[1] = vec_ctf(v.reg.val[1], 0); reg.val[2] = vec_ctf(v.reg.val[2], 0); reg.val[3] = vec_ctf(v.reg.val[3], 0); } - FP32Vec16 operator*(const FP32Vec16 &b) const { - return FP32Vec16(f32x4x4_t({ - vec_mul(reg.val[0], b.reg.val[0]), - vec_mul(reg.val[1], b.reg.val[1]), - vec_mul(reg.val[2], b.reg.val[2]), - vec_mul(reg.val[3], b.reg.val[3])})); + FP32Vec16 operator*(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]), + vec_mul(reg.val[1], b.reg.val[1]), + vec_mul(reg.val[2], b.reg.val[2]), + vec_mul(reg.val[3], b.reg.val[3])})); } - FP32Vec16 operator+(const FP32Vec16 &b) const { - return FP32Vec16(f32x4x4_t({ - vec_add(reg.val[0], b.reg.val[0]), - vec_add(reg.val[1], b.reg.val[1]), - vec_add(reg.val[2], b.reg.val[2]), - vec_add(reg.val[3], b.reg.val[3])})); + FP32Vec16 operator+(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_add(reg.val[0], b.reg.val[0]), + vec_add(reg.val[1], b.reg.val[1]), + vec_add(reg.val[2], b.reg.val[2]), + vec_add(reg.val[3], b.reg.val[3])})); } - FP32Vec16 operator-(const FP32Vec16 &b) const { - return FP32Vec16(f32x4x4_t({ - vec_sub(reg.val[0], b.reg.val[0]), - vec_sub(reg.val[1], b.reg.val[1]), - vec_sub(reg.val[2], b.reg.val[2]), - 
vec_sub(reg.val[3], b.reg.val[3])})); + FP32Vec16 operator-(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_sub(reg.val[0], b.reg.val[0]), + vec_sub(reg.val[1], b.reg.val[1]), + vec_sub(reg.val[2], b.reg.val[2]), + vec_sub(reg.val[3], b.reg.val[3])})); } - FP32Vec16 operator/(const FP32Vec16 &b) const { - return FP32Vec16(f32x4x4_t({ - vec_div(reg.val[0], b.reg.val[0]), - vec_div(reg.val[1], b.reg.val[1]), - vec_div(reg.val[2], b.reg.val[2]), - vec_div(reg.val[3], b.reg.val[3])})); + FP32Vec16 operator/(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_div(reg.val[0], b.reg.val[0]), + vec_div(reg.val[1], b.reg.val[1]), + vec_div(reg.val[2], b.reg.val[2]), + vec_div(reg.val[3], b.reg.val[3])})); } FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const { - return FP32Vec16(f32x4x4_t({ - vec_min(max.reg.val[0], vec_max(min.reg.val[0], reg.val[0])), - vec_min(max.reg.val[1], vec_max(min.reg.val[1], reg.val[1])), - vec_min(max.reg.val[2], vec_max(min.reg.val[2], reg.val[2])), - vec_min(max.reg.val[3], vec_max(min.reg.val[3], reg.val[3])) - })); + return FP32Vec16(f32x4x4_t( + {vec_min(max.reg.val[0], vec_max(min.reg.val[0], reg.val[0])), + vec_min(max.reg.val[1], vec_max(min.reg.val[1], reg.val[1])), + vec_min(max.reg.val[2], vec_max(min.reg.val[2], reg.val[2])), + vec_min(max.reg.val[3], vec_max(min.reg.val[3], reg.val[3]))})); } FP32Vec16 max(const FP32Vec16& b) const { - return FP32Vec16(f32x4x4_t({ - vec_max(reg.val[0], b.reg.val[0]), - vec_max(reg.val[1], b.reg.val[1]), - vec_max(reg.val[2], b.reg.val[2]), - vec_max(reg.val[3], b.reg.val[3]) - })); + return FP32Vec16(f32x4x4_t({vec_max(reg.val[0], b.reg.val[0]), + vec_max(reg.val[1], b.reg.val[1]), + vec_max(reg.val[2], b.reg.val[2]), + vec_max(reg.val[3], b.reg.val[3])})); } FP32Vec16 max(const FP32Vec16& b, int elem_num) const { - FP32Vec16 result; - - // Create a vector of element indices for each chunk - __vector unsigned int indices = {0, 1, 2, 3}; - __vector unsigned int elem_num_vec = vec_splats(static_cast(elem_num)); - - // Compute masks for each chunk - __vector unsigned int chunk_offset0 = {0, 0, 0, 0}; // Chunk 0: Elements 0-3 - __vector unsigned int chunk_offset1 = {4, 4, 4, 4}; // Chunk 1: Elements 4-7 - __vector unsigned int chunk_offset2 = {8, 8, 8, 8}; // Chunk 2: Elements 8-11 - __vector unsigned int chunk_offset3 = {12, 12, 12, 12}; // Chunk 3: Elements 12-15 - - // Compute masks for each chunk - __vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec); - __vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec); - __vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec); - __vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec); - - // Apply masks to compute the result for each chunk - result.reg.val[0] = vec_sel(this->reg.val[0], vec_max(this->reg.val[0], b.reg.val[0]), mask0); - result.reg.val[1] = vec_sel(this->reg.val[1], vec_max(this->reg.val[1], b.reg.val[1]), mask1); - result.reg.val[2] = vec_sel(this->reg.val[2], vec_max(this->reg.val[2], b.reg.val[2]), mask2); - result.reg.val[3] = vec_sel(this->reg.val[3], vec_max(this->reg.val[3], b.reg.val[3]), mask3); - - return FP32Vec16(result.reg); - } + FP32Vec16 result; + + // Create a vector of element indices for each chunk + __vector unsigned int indices = {0, 1, 2, 3}; + __vector unsigned int elem_num_vec = + vec_splats(static_cast(elem_num)); + + // Compute masks for each chunk + __vector unsigned int chunk_offset0 = {0, 0, 0, + 0}; // Chunk 0: Elements 0-3 
+    __vector unsigned int chunk_offset1 = {4, 4, 4,
+                                           4};  // Chunk 1: Elements 4-7
+    __vector unsigned int chunk_offset2 = {8, 8, 8,
+                                           8};  // Chunk 2: Elements 8-11
+    __vector unsigned int chunk_offset3 = {12, 12, 12,
+                                           12};  // Chunk 3: Elements 12-15
+
+    // Compute masks for each chunk
+    __vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec);
+    __vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec);
+    __vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec);
+    __vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec);
+
+    // Apply masks to compute the result for each chunk
+    result.reg.val[0] = vec_sel(this->reg.val[0],
+                                vec_max(this->reg.val[0], b.reg.val[0]), mask0);
+    result.reg.val[1] = vec_sel(this->reg.val[1],
+                                vec_max(this->reg.val[1], b.reg.val[1]), mask1);
+    result.reg.val[2] = vec_sel(this->reg.val[2],
+                                vec_max(this->reg.val[2], b.reg.val[2]), mask2);
+    result.reg.val[3] = vec_sel(this->reg.val[3],
+                                vec_max(this->reg.val[3], b.reg.val[3]), mask3);
+
+    return FP32Vec16(result.reg);
+  }
 
   FP32Vec16 min(const FP32Vec16& b) const {
-    return FP32Vec16(f32x4x4_t({
-      vec_min(reg.val[0], b.reg.val[0]),
-      vec_min(reg.val[1], b.reg.val[1]),
-      vec_min(reg.val[2], b.reg.val[2]),
-      vec_min(reg.val[3], b.reg.val[3])
-    }));
+    return FP32Vec16(f32x4x4_t({vec_min(reg.val[0], b.reg.val[0]),
+                                vec_min(reg.val[1], b.reg.val[1]),
+                                vec_min(reg.val[2], b.reg.val[2]),
+                                vec_min(reg.val[3], b.reg.val[3])}));
   }
 
-FP32Vec16 min(const FP32Vec16& b, int elem_num) const {
-  FP32Vec16 result;
+  FP32Vec16 min(const FP32Vec16& b, int elem_num) const {
+    FP32Vec16 result;
 
-  vector unsigned int indices = {0, 1, 2, 3};
-  vector unsigned int elem_num_vec = vec_splats(static_cast<unsigned int>(elem_num));
+    vector unsigned int indices = {0, 1, 2, 3};
+    vector unsigned int elem_num_vec =
+        vec_splats(static_cast<unsigned int>(elem_num));
 
-  vector unsigned int chunk_offset0 = {0, 0, 0, 0};
-  vector unsigned int chunk_offset1 = {4, 4, 4, 4};
-  vector unsigned int chunk_offset2 = {8, 8, 8, 8};
-  vector unsigned int chunk_offset3 = {12, 12, 12, 12};
+    vector unsigned int chunk_offset0 = {0, 0, 0, 0};
+    vector unsigned int chunk_offset1 = {4, 4, 4, 4};
+    vector unsigned int chunk_offset2 = {8, 8, 8, 8};
+    vector unsigned int chunk_offset3 = {12, 12, 12, 12};
 
-  vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec);
-  vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec);
-  vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec);
-  vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec);
+    vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec);
+    vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec);
+    vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec);
+    vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec);
 
-  result.reg.val[0] = vec_sel(this->reg.val[0], vec_min(this->reg.val[0], b.reg.val[0]), mask0);
-  result.reg.val[1] = vec_sel(this->reg.val[1], vec_min(this->reg.val[1], b.reg.val[1]), mask1);
-  result.reg.val[2] = vec_sel(this->reg.val[2], vec_min(this->reg.val[2], b.reg.val[2]), mask2);
-  result.reg.val[3] = vec_sel(this->reg.val[3], vec_min(this->reg.val[3], b.reg.val[3]), mask3);
+    result.reg.val[0] = vec_sel(this->reg.val[0],
+                                vec_min(this->reg.val[0], b.reg.val[0]), mask0);
+    result.reg.val[1] = vec_sel(this->reg.val[1],
+                                vec_min(this->reg.val[1], b.reg.val[1]), mask1);
+    result.reg.val[2] = vec_sel(this->reg.val[2],
+                                vec_min(this->reg.val[2], b.reg.val[2]), mask2);
+    result.reg.val[3] = vec_sel(this->reg.val[3],
+                                vec_min(this->reg.val[3], b.reg.val[3]), mask3);
 
-    return FP32Vec16(result.reg);
-  }
+    return FP32Vec16(result.reg);
+  }
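// Worked example for the two element-masked functions above: with
// elem_num = 6, indices + chunk_offset1 is {4, 5, 6, 7}, so
// vec_cmplt(..., elem_num_vec) yields {true, true, false, false}; vec_sel
// then takes max/min(a, b) only for lanes 4 and 5 and keeps this->reg
// unchanged for lanes 6 through 15.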
 
   FP32Vec16 abs() const {
-    return FP32Vec16(f32x4x4_t({
-      vec_abs(reg.val[0]),
-      vec_abs(reg.val[1]),
-      vec_abs(reg.val[2]),
-      vec_abs(reg.val[3])
-    }));
+    return FP32Vec16(f32x4x4_t({vec_abs(reg.val[0]), vec_abs(reg.val[1]),
+                                vec_abs(reg.val[2]), vec_abs(reg.val[3])}));
   }
 
   float reduce_max() {
@@ -551,12 +554,14 @@ FP32Vec16 min(const FP32Vec16& b, int elem_num) const {
     AliasReg ar;
     ar.reg = reg;
     float result = 0;
-    unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
+    unroll_loop<int, VEC_ELEM_NUM>(
+        [&result, &ar](int i) { result += ar.values[i]; });
 
     return result;
   }
 
-  template <int group_size> float reduce_sub_sum(int idx) {
+  template <int group_size>
+  float reduce_sub_sum(int idx) {
     static_assert(VEC_ELEM_NUM % group_size == 0);
 
     AliasReg ar;
@@ -569,7 +574,7 @@ FP32Vec16 min(const FP32Vec16& b, int elem_num) const {
     return result;
   }
 
-  void save(float *ptr) const {
+  void save(float* ptr) const {
     vec_xst(reg.val[0], 0, ptr);
     vec_xst(reg.val[1], 16, ptr);
     vec_xst(reg.val[2], 32, ptr);
@@ -577,96 +582,124 @@ FP32Vec16 min(const FP32Vec16& b, int elem_num) const {
   }
 
   void save(float* ptr, const int elem_num) const {
-    const int elements_in_chunk1 = (elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0;
-    const int elements_in_chunk2 = (elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0;
-    const int elements_in_chunk3 = (elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0;
-    const int elements_in_chunk4 = (elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0;
-
-    const size_t bytes_chunk1 = static_cast<size_t>(elements_in_chunk1 * sizeof(float));
-    const size_t bytes_chunk2 = static_cast<size_t>(elements_in_chunk2 * sizeof(float));
-    const size_t bytes_chunk3 = static_cast<size_t>(elements_in_chunk3 * sizeof(float));
-    const size_t bytes_chunk4 = static_cast<size_t>(elements_in_chunk4 * sizeof(float));
+    const int elements_in_chunk1 =
+        (elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0;
+    const int elements_in_chunk2 =
+        (elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0;
+    const int elements_in_chunk3 =
+        (elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0;
+    const int elements_in_chunk4 =
+        (elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0;
+
+    const size_t bytes_chunk1 =
+        static_cast<size_t>(elements_in_chunk1 * sizeof(float));
+    const size_t bytes_chunk2 =
+        static_cast<size_t>(elements_in_chunk2 * sizeof(float));
+    const size_t bytes_chunk3 =
+        static_cast<size_t>(elements_in_chunk3 * sizeof(float));
+    const size_t bytes_chunk4 =
+        static_cast<size_t>(elements_in_chunk4 * sizeof(float));
 
     vec_xst_len(reg.val[0], ptr, bytes_chunk1);
     vec_xst_len(reg.val[1],
-                reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 16), bytes_chunk2);
+                reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 16),
+                bytes_chunk2);
     vec_xst_len(reg.val[2],
-                reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 32), bytes_chunk3);
+                reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 32),
+                bytes_chunk3);
     vec_xst_len(reg.val[3],
-                reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 48), bytes_chunk4);
-}
-
+                reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 48),
+                bytes_chunk4);
+  }
 };
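// A scalar reference (illustrative only, not part of the kernels above) for
// the chunked partial store implemented by FP32Vec16::save(float*, const int):
// the 16 lanes live in four 128-bit registers, and vec_xst_len() writes only
// the leading bytes of each register. Plain C++, no VSX; the helper name is
// made up for the sketch.
#include <algorithm>
#include <cstddef>
#include <cstring>

inline void save_partial_reference(const float (&lanes)[16], float* ptr,
                                   int elem_num) {
  const int n = std::max(0, std::min(elem_num, 16));
  for (int chunk = 0; chunk < 4; ++chunk) {
    // Same arithmetic as elements_in_chunkN / bytes_chunkN above.
    const int elems = std::max(0, std::min(n - chunk * 4, 4));
    const std::size_t bytes = static_cast<std::size_t>(elems) * sizeof(float);
    std::memcpy(ptr + chunk * 4, lanes + chunk * 4, bytes);  // vec_xst_len analogue
  }
}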
 
 struct INT8Vec16 : public Vec<INT8Vec16> {
-  constexpr static int VEC_NUM_ELEM = 16;  // 128 bits / 8 bits = 16
-
-  union AliasReg {
-    __vector signed char reg;
-    int8_t values[VEC_NUM_ELEM];
-  };
+  constexpr static int VEC_NUM_ELEM = 16;  // 128 bits / 8 bits = 16
 
+  union AliasReg {
+    __vector signed char reg;
+    int8_t values[VEC_NUM_ELEM];
+  };
 
-  explicit INT8Vec16(const FP32Vec16& vec) {
-    __vector signed int ret[4];
-    ret[0] = vec_cts(vec.reg.val[0], 0);
-    ret[1] = vec_cts(vec.reg.val[1], 0);
-    ret[2] = vec_cts(vec.reg.val[2], 0);
-    ret[3] = vec_cts(vec.reg.val[3], 0);
+  __vector signed char reg;
 
-    __vector signed short packed1 = vec_packs(ret[0], ret[1]);
-    __vector signed short packed2 = vec_packs(ret[2], ret[3]);
+  explicit INT8Vec16(const FP32Vec16& vec) {
+    __vector signed int ret[4];
+    ret[0] = vec_cts(vec.reg.val[0], 0);
+    ret[1] = vec_cts(vec.reg.val[1], 0);
+    ret[2] = vec_cts(vec.reg.val[2], 0);
+    ret[3] = vec_cts(vec.reg.val[3], 0);
 
-    reg = vec_packs(packed1, packed2);
-  }
+    __vector signed short packed1 = vec_packs(ret[0], ret[1]);
+    __vector signed short packed2 = vec_packs(ret[2], ret[3]);
 
-  void save(void *ptr) const {
-    *reinterpret_cast<__vector signed char *>(ptr) = reg;
-  }
-  void save(signed char* ptr, const int elem_num) {
-    vec_xst_len(reg, ptr, static_cast<size_t>(elem_num));
-  }
-};
+    reg = vec_packs(packed1, packed2);
+  }
 
+  void save(void* ptr) const {
+    *reinterpret_cast<__vector signed char*>(ptr) = reg;
+  }
+  void save(signed char* ptr, const int elem_num) {
+    vec_xst_len(reg, ptr, static_cast<size_t>(elem_num));
+  }
+};
 
-template <typename T> struct VecType { using vec_type = void; };
+template <typename T>
+struct VecType {
+  using vec_type = void;
+};
 
-template <typename T> using vec_t = typename VecType<T>::vec_type;
+template <typename T>
+using vec_t = typename VecType<T>::vec_type;
 
-template <> struct VecType<float> { using vec_type = FP32Vec8; };
+template <>
+struct VecType<float> {
  using vec_type = FP32Vec8;
+};
 
-template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
+template <>
+struct VecType<c10::BFloat16> {
+  using vec_type = BF16Vec8;
+};
 
-template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
+template <typename T>
+void storeFP32(float v, T* ptr) {
+  *ptr = v;
+}
 
-inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
+inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
   acc = acc + a * b;
 }
 
-template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
-  c10::BFloat16 __attribute__((__may_alias__)) *v_ptr =
-      reinterpret_cast<c10::BFloat16*>(&v);
+template <>
+inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
+  c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
+      reinterpret_cast<c10::BFloat16*>(&v);
   *ptr = *(v_ptr + 1);
 }
 
 #ifndef __VEC_CLASS_FP_NAN
-#define __VEC_CLASS_FP_NAN (1 << 6)
+  #define __VEC_CLASS_FP_NAN (1 << 6)
#endif -const static __vector unsigned char omask = { 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 }; +const static __vector unsigned char omask = {0, 1, 4, 5, 8, 9, 12, 13, + 16, 17, 20, 21, 24, 25, 28, 29}; #ifndef _ARCH_PWR10 -const static __vector unsigned int bias = { 0x00007fff, 0x00007fff, 0x00007fff, 0x00007fff }; -const static __vector unsigned int nan = { 0x7fc00000, 0x7fc00000, 0x7fc00000, 0x7fc00000 }; -const static __vector unsigned int sh16 = { 16, 16, 16, 16 }; -const static __vector unsigned int one = { 1, 1, 1, 1 }; +const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff, + 0x00007fff}; +const static __vector unsigned int nan = {0x7fc00000, 0x7fc00000, 0x7fc00000, + 0x7fc00000}; +const static __vector unsigned int sh16 = {16, 16, 16, 16}; +const static __vector unsigned int one = {1, 1, 1, 1}; #endif -inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) { +inline BF16Vec8::BF16Vec8(const FP32Vec8& v) { #ifdef _ARCH_PWR10 __vector signed short ret[2]; - ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]); - ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]); + ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16( + (__vector unsigned char)v.reg.val[0]); + ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16( + (__vector unsigned char)v.reg.val[1]); reg = vec_perm(ret[0], ret[1], omask); #elif defined(_ARCH_PWR9) __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]); @@ -679,8 +712,10 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) { __vector unsigned int rnd1 = vec_add(lsb1, bias); inp0 = vec_add(inp0, rnd0); inp1 = vec_add(inp1, rnd1); - __vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); - __vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); + __vector __bool int sel0 = + vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); + __vector __bool int sel1 = + vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); inp0 = vec_sel(inp0, nan, sel0); inp1 = vec_sel(inp1, nan, sel1); inp0 = vec_sr(inp0, sh16); @@ -689,13 +724,17 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) { #endif } -inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { +inline BF16Vec16::BF16Vec16(const FP32Vec16& v) { #ifdef _ARCH_PWR10 __vector signed short ret[4]; - ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]); - ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]); - ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[2]); - ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[3]); + ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16( + (__vector unsigned char)v.reg.val[0]); + ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16( + (__vector unsigned char)v.reg.val[1]); + ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16( + (__vector unsigned char)v.reg.val[2]); + ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16( + (__vector unsigned char)v.reg.val[3]); reg.val[0] = vec_perm(ret[0], ret[1], omask); reg.val[1] = vec_perm(ret[2], ret[3], omask); #elif defined(_ARCH_PWR9) @@ -719,10 +758,14 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { inp1 = vec_add(inp1, rnd1); inp2 = vec_add(inp2, rnd2); inp3 = vec_add(inp3, rnd3); - __vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); - __vector __bool int 
sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
-  __vector __bool int sel2 = vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN);
-  __vector __bool int sel3 = vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN);
+  __vector __bool int sel0 =
+      vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
+  __vector __bool int sel1 =
+      vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
+  __vector __bool int sel2 =
+      vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN);
+  __vector __bool int sel3 =
+      vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN);
   inp0 = vec_sel(inp0, nan, sel0);
   inp1 = vec_sel(inp1, nan, sel1);
   inp2 = vec_sel(inp2, nan, sel2);
@@ -736,10 +779,10 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
 #endif
 }
 
-inline void prefetch(const void *addr) {
+inline void prefetch(const void* addr) {
   __asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory");
 }
 
-}; // namespace vec_op
+};  // namespace vec_op
 
 #endif
diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp
index 80047caf6048..f61dbcc948e8 100644
--- a/csrc/cpu/quant.cpp
+++ b/csrc/cpu/quant.cpp
@@ -323,8 +323,7 @@ template <typename scalar_t, bool AZP>
 void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                    const float* scale, const int32_t* azp,
                                    const int num_tokens,
-                                   const int hidden_size){
-
+                                   const int hidden_size) {
   using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
   using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
   constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
@@ -338,7 +337,6 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
   const cvt_vec_t i8_min_vec(i8_min);
   const cvt_vec_t i8_max_vec(i8_max);
 
-
   cvt_vec_t zp_vec;
   if constexpr (AZP) {
     zp_vec = cvt_vec_t(static_cast<float>(*azp));
@@ -471,9 +469,7 @@ template <typename scalar_t, bool PerChannel>
 void static_quant_epilogue(const float* input, scalar_t* output,
                            const float a_scale, const float* b_scale,
                            const int32_t* azp_with_adj, const int num_tokens,
-                           const int hidden_size){
-
-
+                           const int hidden_size) {
   CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
   using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
   using azp_adj_load_vec_t =
@@ -481,21 +477,18 @@ void static_quant_epilogue(const float* input, scalar_t* output,
   using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
   constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
 
-  #pragma omp parallel for
+#pragma omp parallel for
   for (int i = 0; i < num_tokens; ++i) {
     cvt_vec_t a_scale_vec(a_scale);
     cvt_vec_t b_scale_vec(*b_scale);
     cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;
-
     int j = 0;
     for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
       cvt_vec_t elems_fp32(input + i * hidden_size + j);
       azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
      cvt_vec_t azp_adj_fp32(azp_adj_vec);
-
       if constexpr (PerChannel) {
         b_scale_vec = cvt_vec_t(b_scale + j);
         scale_vec = b_scale_vec * a_scale_vec;
@@ -525,7 +518,7 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output,
                             const float* a_scale, const float* b_scale,
                             const int32_t* azp, const int32_t* azp_adj,
                             const scalar_t* bias, const int num_tokens,
-                            const int hidden_size){
+                            const int hidden_size) {
   CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
   using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
   using azp_adj_load_vec_t =
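// Illustrative scalar model (not part of the patch) of the per-element math
// that static_scaled_int8_quant_impl above performs: scale the activation,
// optionally add the zero point when AZP is enabled, clamp to the int8
// range, and convert. The VSX path converts with vec_cts, i.e. truncation
// toward zero, which the plain cast below mirrors.
#include <algorithm>
#include <cstdint>

inline int8_t quantize_element_reference(float x, float scale, int32_t azp,
                                         bool use_azp) {
  float v = x / scale;
  if (use_azp) {
    v += static_cast<float>(azp);
  }
  v = std::min(127.0f, std::max(-128.0f, v));  // clamp to [i8_min, i8_max]
  return static_cast<int8_t>(v);               // truncates, like vec_cts(..., 0)
}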
@@ -605,7 +598,8 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                    const float* scale, const int32_t* azp,
                                    const int num_tokens,
                                    const int hidden_size) {
-  TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
+  TORCH_CHECK(
+      false, "static_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
 }
 
 template <typename scalar_t, bool AZP>
@@ -613,7 +607,9 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                     float* scale, int32_t* azp,
                                     const int num_tokens,
                                     const int hidden_size) {
-  TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
+  TORCH_CHECK(
+      false,
+      "dynamic_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
 }
 
 template <typename scalar_t, bool PerChannel>
@@ -630,7 +626,8 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output,
                             const int32_t* azp, const int32_t* azp_with_adj,
                             const scalar_t* bias, const int num_tokens,
                             const int hidden_size) {
-  TORCH_CHECK(false, "dynamic_quant_epilogue requires AVX512/powerpc64 support.")
+  TORCH_CHECK(false,
+              "dynamic_quant_epilogue requires AVX512/powerpc64 support.")
 }
 #endif
 }  // namespace
@@ -894,21 +891,21 @@ void dynamic_scaled_int8_quant(
 }
 
 #if defined(__powerpc64__)
-void int8_scaled_mm_ppc64le(torch::Tensor& c,         // [M, OC], row-major
-                     const torch::Tensor& a,          // [M, IC], row-major
-                     const torch::Tensor& b,          // [IC, OC], column-major
-                     const torch::Tensor& a_scales,
-                     const torch::Tensor& b_scales,
-                     const std::optional<torch::Tensor>& bias  // [OC]
-){
+void int8_scaled_mm_ppc64le(torch::Tensor& c,        // [M, OC], row-major
+                            const torch::Tensor& a,  // [M, IC], row-major
+                            const torch::Tensor& b,  // [IC, OC], column-major
+                            const torch::Tensor& a_scales,
+                            const torch::Tensor& b_scales,
+                            const std::optional<torch::Tensor>& bias  // [OC]
+) {
   CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
   // Checks for conformality
   TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
-             "int8_scaled_mm_ppc64le only supports INT8 inputs.");
+              "int8_scaled_mm_ppc64le only supports INT8 inputs.");
   TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
   TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
-             b.size(1) == c.size(1));
-  //We dont need this
+              b.size(1) == c.size(1));
+  // We don't need this
   TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
   TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
 
@@ -916,36 +913,34 @@ void int8_scaled_mm_ppc64le(torch::Tensor& c,         // [M, OC], row-majo
   TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
   TORCH_CHECK(b.stride(0) == 1);                      // Column-major
   TORCH_CHECK(c.stride(0) % 16 == 0 &&
-             b.stride(1) % 16 == 0);  // 16 Byte Alignment
+              b.stride(1) % 16 == 0);  // 16 Byte Alignment
   TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
 
   if (bias) {
-    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
-                bias->dim() == 1);
+    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
+                bias->dim() == 1);
   }
 
   VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_ppc64le", [&] {
-    torch::Tensor tmp_fp32_out =
-        torch::empty_like(c, ::at::ScalarType::Float);
+    torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float);
     // Compute C_inter=s_b * (A@B)
     DNNLPrimitiveHelper::gemm_s8s8_jit(
-        a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
-        tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
-        a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
+        a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
+        tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
+        a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
     if (bias.has_value()) {
-      // Compute C=s_a * C_inter + bias
-      dynamic_quant_epilogue(
-          tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
-          a_scales.data_ptr<float>(), nullptr, nullptr, nullptr,
-          bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
+      // Compute C=s_a * C_inter + bias
+      dynamic_quant_epilogue(
+          tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
+          a_scales.data_ptr<float>(), nullptr, nullptr, nullptr,
+          bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
     } else {
-      // Compute C=s_a * C_inter
-      dynamic_quant_epilogue(
-          tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
-          a_scales.data_ptr<float>(), nullptr, nullptr, nullptr, nullptr,
-          c.size(0), c.size(1));
+      // Compute C=s_a * C_inter
+      dynamic_quant_epilogue(
+          tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
+          a_scales.data_ptr<float>(), nullptr, nullptr, nullptr, nullptr,
+          c.size(0), c.size(1));
     }
   });
 }
 #endif
-
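// Scalar reference (illustrative only; shapes, names and the row-major
// layouts here are assumptions, not the oneDNN call above) for the two-step
// scaling that int8_scaled_mm_ppc64le performs: the oneDNN GEMM yields
// C_inter = s_b * (A_q @ B_q) in fp32, and dynamic_quant_epilogue then
// applies the per-token activation scale and the optional bias,
// C = s_a * C_inter + bias.
#include <cstddef>
#include <cstdint>
#include <vector>

inline void int8_scaled_mm_reference(
    const std::vector<int8_t>& a_q,      // M x K, row-major
    const std::vector<int8_t>& b_q,      // K x N, row-major
    const std::vector<float>& a_scales,  // per-token (M) or per-tensor (1)
    const std::vector<float>& b_scales,  // per-channel (N) or per-tensor (1)
    const std::vector<float>& bias,      // N, may be empty
    int M, int N, int K, std::vector<float>& c) {
  c.assign(static_cast<std::size_t>(M) * N, 0.0f);
  for (int m = 0; m < M; ++m) {
    const float s_a = a_scales.size() == 1 ? a_scales[0] : a_scales[m];
    for (int n = 0; n < N; ++n) {
      int32_t acc = 0;
      for (int k = 0; k < K; ++k) {
        acc += static_cast<int32_t>(a_q[m * K + k]) *
               static_cast<int32_t>(b_q[k * N + n]);
      }
      const float s_b = b_scales.size() == 1 ? b_scales[0] : b_scales[n];
      const float c_inter = s_b * static_cast<float>(acc);  // GEMM + weight scale
      float out = s_a * c_inter;          // epilogue: per-token activation scale
      if (!bias.empty()) {
        out += bias[n];                   // epilogue: optional bias
      }
      c[static_cast<std::size_t>(m) * N + n] = out;
    }
  }
}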
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index fae680f166c7..248b42ab4127 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -20,9 +20,10 @@ void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
 
 #if defined(__powerpc64__)
 void int8_scaled_mm_ppc64le(torch::Tensor& c, const torch::Tensor& a,
-                     const torch::Tensor& b, const torch::Tensor& a_scales,
-                     const torch::Tensor& b_scales,
-                     const std::optional<torch::Tensor>& bias);
+                            const torch::Tensor& b,
+                            const torch::Tensor& a_scales,
+                            const torch::Tensor& b_scales,
+                            const std::optional<torch::Tensor>& bias);
 #endif
 
 void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,