diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake
index 00670bd398b5..fb763db9fc35 100644
--- a/cmake/cpu_extension.cmake
+++ b/cmake/cpu_extension.cmake
@@ -167,6 +167,33 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
     FetchContent_MakeAvailable(oneDNN)
 
     list(APPEND LIBS dnnl)
+elseif(POWER10_FOUND)
+    FetchContent_Declare(
+        oneDNN
+        GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
+        GIT_TAG v3.7.2
+        GIT_PROGRESS TRUE
+        GIT_SHALLOW TRUE
+    )
+
+    set(ONEDNN_LIBRARY_TYPE "STATIC")
+    set(ONEDNN_BUILD_DOC "OFF")
+    set(ONEDNN_BUILD_EXAMPLES "OFF")
+    set(ONEDNN_BUILD_TESTS "OFF")
+    set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
+    set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
+    set(ONEDNN_BUILD_GRAPH "OFF")
+    set(ONEDNN_ENABLE_JIT_PROFILING "OFF")
+    set(ONEDNN_ENABLE_ITT_TASKS "OFF")
+    set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
+    set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
+    set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
+
+    set(DNNL_CPU_RUNTIME "OMP")
+
+    FetchContent_MakeAvailable(oneDNN)
+
+    list(APPEND LIBS dnnl)
 endif()
 
@@ -197,6 +224,10 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
         "csrc/cpu/quant.cpp"
         "csrc/cpu/shm.cpp"
         ${VLLM_EXT_SRC})
+elseif(POWER10_FOUND)
+    set(VLLM_EXT_SRC
+        "csrc/cpu/quant.cpp"
+        ${VLLM_EXT_SRC})
 endif()
 
 #
@@ -214,4 +245,4 @@ define_gpu_extension_target(
     WITH_SOABI
 )
 
-message(STATUS "Enabling C extension.")
\ No newline at end of file
+message(STATUS "Enabling C extension.")
diff --git a/csrc/cpu/cpu_types_vsx.hpp b/csrc/cpu/cpu_types_vsx.hpp
index a8e1be37eb41..089b9840ea2e 100644
--- a/csrc/cpu/cpu_types_vsx.hpp
+++ b/csrc/cpu/cpu_types_vsx.hpp
@@ -4,6 +4,7 @@
 #include <altivec.h>
 #include <cmath>
+#include <cstdint>
 #include <torch/all.h>
 
 namespace vec_op {
@@ -62,6 +63,10 @@ typedef struct f32x4x4_t {
   __vector float val[4];
 } f32x4x4_t;
 
+typedef struct i32x4x4_t {
+  __vector int32_t val[4];
+} i32x4x4_t;
+
 struct FP32Vec8;
 struct FP32Vec16;
 
@@ -98,6 +103,28 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
     vec_xst(reg.val[0], 0, (signed short*)ptr);
     vec_xst(reg.val[1], 16, (signed short*)ptr);
   }
+
+  void save(void* ptr, const int elem_num) const {
+    const int clamped_elem = std::max(0, std::min(elem_num, 16));
+
+    // Calculate elements to store in each 128-bit part (8 elements each)
+    const int elements_val0 = std::min(clamped_elem, 8);
+    const int elements_val1 = std::max(clamped_elem - 8, 0);
+
+    // Convert elements to bytes (2 bytes per element)
+    const size_t bytes_val0 = elements_val0 * sizeof(signed short);
+    const size_t bytes_val1 = elements_val1 * sizeof(signed short);
+
+    signed short* dest = static_cast<signed short*>(ptr);
+    // Store the first part using vec_xst_len
+    if (bytes_val0 > 0) {
+      vec_xst_len(reg.val[0], dest, bytes_val0);
+    }
+    // Store the second part if needed
+    if (bytes_val1 > 0) {
+      vec_xst_len(reg.val[1], dest + elements_val0, bytes_val1);
+    }
+  }
 };
 
 const static __vector signed short zero = vec_splats((signed short)0);
@@ -257,6 +284,64 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
   }
 };
 
+struct INT32Vec16 : public Vec<INT32Vec16> {
+  constexpr static int VEC_ELEM_NUM = 16;
+  union AliasReg {
+    i32x4x4_t reg;
+    int32_t values[VEC_ELEM_NUM];
+  };
+
+  i32x4x4_t reg;
+
+  explicit INT32Vec16(const void* data_ptr) {
+    reg.val[0] = vec_xl(0, reinterpret_cast<const __vector int32_t*>(data_ptr));
+    reg.val[1] =
+        vec_xl(16, reinterpret_cast<const __vector int32_t*>(data_ptr));
+    reg.val[2] =
+        vec_xl(32, reinterpret_cast<const __vector int32_t*>(data_ptr));
+    reg.val[3] =
+        vec_xl(48, reinterpret_cast<const __vector int32_t*>(data_ptr));
+  }
+
+  void save(int32_t* ptr) const {
+    vec_xst(reg.val[0], 0, reinterpret_cast<__vector int32_t*>(ptr));
+    vec_xst(reg.val[1], 16, reinterpret_cast<__vector int32_t*>(ptr));
+    vec_xst(reg.val[2], 32, reinterpret_cast<__vector int32_t*>(ptr));
+    vec_xst(reg.val[3], 48, reinterpret_cast<__vector int32_t*>(ptr));
+  }
+
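+  // Stores only the first `elem_num` (0..16) int32 lanes; each 128-bit
+  // chunk is written with a length-limited vec_xst_len.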
+  void save(int32_t* ptr, const int elem_num) const {
+    const int elements_in_chunk1 =
+        (elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0;
+    const int elements_in_chunk2 =
+        (elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0;
+    const int elements_in_chunk3 =
+        (elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0;
+    const int elements_in_chunk4 =
+        (elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0;
+
+    const size_t bytes_chunk1 =
+        static_cast<size_t>(elements_in_chunk1 * sizeof(int32_t));
+    const size_t bytes_chunk2 =
+        static_cast<size_t>(elements_in_chunk2 * sizeof(int32_t));
+    const size_t bytes_chunk3 =
+        static_cast<size_t>(elements_in_chunk3 * sizeof(int32_t));
+    const size_t bytes_chunk4 =
+        static_cast<size_t>(elements_in_chunk4 * sizeof(int32_t));
+
+    vec_xst_len(reg.val[0], reinterpret_cast<int32_t*>(ptr), bytes_chunk1);
+    vec_xst_len(reg.val[1],
+                reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 16),
+                bytes_chunk2);
+    vec_xst_len(reg.val[2],
+                reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 32),
+                bytes_chunk3);
+    vec_xst_len(reg.val[3],
+                reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 48),
+                bytes_chunk4);
+  }
+};
+
 struct FP32Vec16 : public Vec<FP32Vec16> {
   constexpr static int VEC_ELEM_NUM = 16;
   union AliasReg {
@@ -319,6 +404,13 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
 
   explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
 
+  explicit FP32Vec16(const INT32Vec16& v) {
+    reg.val[0] = vec_ctf(v.reg.val[0], 0);
+    reg.val[1] = vec_ctf(v.reg.val[1], 0);
+    reg.val[2] = vec_ctf(v.reg.val[2], 0);
+    reg.val[3] = vec_ctf(v.reg.val[3], 0);
+  }
+
   FP32Vec16 operator*(const FP32Vec16& b) const {
     return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
                                 vec_mul(reg.val[1], b.reg.val[1]),
@@ -347,6 +439,117 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
                                 vec_div(reg.val[3], b.reg.val[3])}));
   }
 
+  FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
+    return FP32Vec16(f32x4x4_t(
+        {vec_min(max.reg.val[0], vec_max(min.reg.val[0], reg.val[0])),
+         vec_min(max.reg.val[1], vec_max(min.reg.val[1], reg.val[1])),
+         vec_min(max.reg.val[2], vec_max(min.reg.val[2], reg.val[2])),
+         vec_min(max.reg.val[3], vec_max(min.reg.val[3], reg.val[3]))}));
+  }
+
+  FP32Vec16 max(const FP32Vec16& b) const {
+    return FP32Vec16(f32x4x4_t({vec_max(reg.val[0], b.reg.val[0]),
+                                vec_max(reg.val[1], b.reg.val[1]),
+                                vec_max(reg.val[2], b.reg.val[2]),
+                                vec_max(reg.val[3], b.reg.val[3])}));
+  }
+
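+  // Lane-masked variant: lanes with index < elem_num take max(this, b);
+  // the remaining lanes keep their original value. Used on the tail of a
+  // row when hidden_size is not a multiple of VEC_ELEM_NUM.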
+  FP32Vec16 max(const FP32Vec16& b, int elem_num) const {
+    FP32Vec16 result;
+
+    // Create a vector of element indices for each chunk
+    __vector unsigned int indices = {0, 1, 2, 3};
+    __vector unsigned int elem_num_vec =
+        vec_splats(static_cast<unsigned int>(elem_num));
+
+    // Offsets of each 4-element chunk
+    __vector unsigned int chunk_offset0 = {0, 0, 0,
+                                           0};  // Chunk 0: Elements 0-3
+    __vector unsigned int chunk_offset1 = {4, 4, 4,
+                                           4};  // Chunk 1: Elements 4-7
+    __vector unsigned int chunk_offset2 = {8, 8, 8,
+                                           8};  // Chunk 2: Elements 8-11
+    __vector unsigned int chunk_offset3 = {12, 12, 12,
+                                           12};  // Chunk 3: Elements 12-15
+
+    // Compute masks for each chunk
+    __vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec);
+    __vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec);
+    __vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec);
+    __vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec);
+
+    // Apply masks to compute the result for each chunk
+    result.reg.val[0] = vec_sel(this->reg.val[0],
+                                vec_max(this->reg.val[0], b.reg.val[0]), mask0);
+    result.reg.val[1] = vec_sel(this->reg.val[1],
+                                vec_max(this->reg.val[1], b.reg.val[1]), mask1);
+    result.reg.val[2] = vec_sel(this->reg.val[2],
+                                vec_max(this->reg.val[2], b.reg.val[2]), mask2);
+    result.reg.val[3] = vec_sel(this->reg.val[3],
+                                vec_max(this->reg.val[3], b.reg.val[3]), mask3);
+
+    return FP32Vec16(result.reg);
+  }
+
+  FP32Vec16 min(const FP32Vec16& b) const {
+    return FP32Vec16(f32x4x4_t({vec_min(reg.val[0], b.reg.val[0]),
+                                vec_min(reg.val[1], b.reg.val[1]),
+                                vec_min(reg.val[2], b.reg.val[2]),
+                                vec_min(reg.val[3], b.reg.val[3])}));
+  }
+
+  FP32Vec16 min(const FP32Vec16& b, int elem_num) const {
+    FP32Vec16 result;
+
+    __vector unsigned int indices = {0, 1, 2, 3};
+    __vector unsigned int elem_num_vec =
+        vec_splats(static_cast<unsigned int>(elem_num));
+
+    __vector unsigned int chunk_offset0 = {0, 0, 0, 0};
+    __vector unsigned int chunk_offset1 = {4, 4, 4, 4};
+    __vector unsigned int chunk_offset2 = {8, 8, 8, 8};
+    __vector unsigned int chunk_offset3 = {12, 12, 12, 12};
+
+    __vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec);
+    __vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec);
+    __vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec);
+    __vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec);
+
+    result.reg.val[0] = vec_sel(this->reg.val[0],
+                                vec_min(this->reg.val[0], b.reg.val[0]), mask0);
+    result.reg.val[1] = vec_sel(this->reg.val[1],
+                                vec_min(this->reg.val[1], b.reg.val[1]), mask1);
+    result.reg.val[2] = vec_sel(this->reg.val[2],
+                                vec_min(this->reg.val[2], b.reg.val[2]), mask2);
+    result.reg.val[3] = vec_sel(this->reg.val[3],
+                                vec_min(this->reg.val[3], b.reg.val[3]), mask3);
+
+    return FP32Vec16(result.reg);
+  }
+
+  FP32Vec16 abs() const {
+    return FP32Vec16(f32x4x4_t({vec_abs(reg.val[0]), vec_abs(reg.val[1]),
+                                vec_abs(reg.val[2]), vec_abs(reg.val[3])}));
+  }
+
+  float reduce_max() {
+    __vector float max01 = vec_max(reg.val[0], reg.val[1]);
+    __vector float max23 = vec_max(reg.val[2], reg.val[3]);
+    __vector float max_all = vec_max(max01, max23);
+    __vector float temp = vec_max(max_all, vec_sld(max_all, max_all, 8));
+    temp = vec_max(temp, vec_sld(temp, temp, 4));
+    return vec_extract(temp, 0);
+  }
+
+  float reduce_min() {
+    __vector float min01 = vec_min(reg.val[0], reg.val[1]);
+    __vector float min23 = vec_min(reg.val[2], reg.val[3]);
+    __vector float min_all = vec_min(min01, min23);
+    __vector float temp = vec_min(min_all, vec_sld(min_all, min_all, 8));
+    temp = vec_min(temp, vec_sld(temp, temp, 4));
+    return vec_extract(temp, 0);
+  }
+
   float reduce_sum() const {
     AliasReg ar;
     ar.reg = reg;
@@ -377,6 +580,68 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
     vec_xst(reg.val[2], 32, ptr);
     vec_xst(reg.val[3], 48, ptr);
   }
+
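+  // Partial store of the first `elem_num` (0..16) floats via vec_xst_len,
+  // mirroring INT32Vec16::save(ptr, elem_num).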
+  void save(float* ptr, const int elem_num) const {
+    const int elements_in_chunk1 =
+        (elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0;
+    const int elements_in_chunk2 =
+        (elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0;
+    const int elements_in_chunk3 =
+        (elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0;
+    const int elements_in_chunk4 =
+        (elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0;
+
+    const size_t bytes_chunk1 =
+        static_cast<size_t>(elements_in_chunk1 * sizeof(float));
+    const size_t bytes_chunk2 =
+        static_cast<size_t>(elements_in_chunk2 * sizeof(float));
+    const size_t bytes_chunk3 =
+        static_cast<size_t>(elements_in_chunk3 * sizeof(float));
+    const size_t bytes_chunk4 =
+        static_cast<size_t>(elements_in_chunk4 * sizeof(float));
+
+    vec_xst_len(reg.val[0], ptr, bytes_chunk1);
+    vec_xst_len(reg.val[1],
+                reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 16),
+                bytes_chunk2);
+    vec_xst_len(reg.val[2],
+                reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 32),
+                bytes_chunk3);
+    vec_xst_len(reg.val[3],
+                reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 48),
+                bytes_chunk4);
+  }
+};
+
+struct INT8Vec16 : public Vec<INT8Vec16> {
+  constexpr static int VEC_NUM_ELEM = 16;  // 128 bits / 8 bits = 16
+
+  union AliasReg {
+    __vector signed char reg;
+    int8_t values[VEC_NUM_ELEM];
+  };
+
+  __vector signed char reg;
+
+  explicit INT8Vec16(const FP32Vec16& vec) {
+    __vector signed int ret[4];
+    ret[0] = vec_cts(vec.reg.val[0], 0);
+    ret[1] = vec_cts(vec.reg.val[1], 0);
+    ret[2] = vec_cts(vec.reg.val[2], 0);
+    ret[3] = vec_cts(vec.reg.val[3], 0);
+
+    __vector signed short packed1 = vec_packs(ret[0], ret[1]);
+    __vector signed short packed2 = vec_packs(ret[2], ret[3]);
+
+    reg = vec_packs(packed1, packed2);
+  }
+
+  void save(void* ptr) const {
+    *reinterpret_cast<__vector signed char*>(ptr) = reg;
+  }
+  void save(signed char* ptr, const int elem_num) {
+    vec_xst_len(reg, ptr, static_cast<size_t>(elem_num));
+  }
 };
 
 template <typename T>
diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp
index 6751e7e55fc5..f61dbcc948e8 100644
--- a/csrc/cpu/quant.cpp
+++ b/csrc/cpu/quant.cpp
@@ -239,6 +239,280 @@ void static_quant_epilogue(const float* input, scalar_t* output,
   }
 }
 
+template <bool AZP, bool Bias, bool PerChannel, typename scalar_t>
+void dynamic_quant_epilogue(const float* input, scalar_t* output,
+                            const float* a_scale, const float* b_scale,
+                            const int32_t* azp, const int32_t* azp_adj,
+                            const scalar_t* bias, const int num_tokens,
+                            const int hidden_size) {
+  CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
+  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
+  using azp_adj_load_vec_t =
+      typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
+  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
+  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
+
+  #pragma omp parallel for
+  for (int i = 0; i < num_tokens; ++i) {
+    int j = 0;
+    cvt_vec_t token_scale_vec(a_scale[i]);
+    cvt_vec_t token_zp_scale_vec;
+    if constexpr (AZP) {
+      float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]);
+      if constexpr (!PerChannel) {
+        zp_scale_val *= *b_scale;
+      }
+      token_zp_scale_vec = cvt_vec_t(zp_scale_val);
+    }
+
+    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
+      cvt_vec_t elems_fp32(input + i * hidden_size + j);
+      elems_fp32 = elems_fp32 * token_scale_vec;
+
+      if constexpr (AZP) {
+        azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
+        cvt_vec_t azp_adj_fp32(azp_adj_vec);
+        azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
+
+        if constexpr (PerChannel) {
+          cvt_vec_t b_scale_vec(b_scale + j);
+          azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
+        }
+
+        elems_fp32 = elems_fp32 - azp_adj_fp32;
+      }
+
+      if constexpr (Bias) {
+        load_vec_t bias_vec(bias + j);
+        cvt_vec_t bias_vec_fp32(bias_vec);
+        elems_fp32 = elems_fp32 + bias_vec_fp32;
+      }
+
+      load_vec_t elems_out(elems_fp32);
+      elems_out.save(output + i * hidden_size + j);
+    }
+
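+    // Handle the last (possibly partial) vector of the row.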
+    cvt_vec_t elems_fp32(input + i * hidden_size + j);
+    elems_fp32 = elems_fp32 * token_scale_vec;
+
+    if constexpr (AZP) {
+      azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
+      cvt_vec_t azp_adj_fp32(azp_adj_vec);
+      azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
+
+      if constexpr (PerChannel) {
+        cvt_vec_t b_scale_vec(b_scale + j);
+        azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
+      }
+
+      elems_fp32 = elems_fp32 - azp_adj_fp32;
+    }
+
+    if constexpr (Bias) {
+      load_vec_t bias_vec(bias + j);
+      cvt_vec_t bias_vec_fp32(bias_vec);
+      elems_fp32 = elems_fp32 + bias_vec_fp32;
+    }
+
+    load_vec_t elems_out(elems_fp32);
+    elems_out.save(output + i * hidden_size + j, hidden_size - j);
+  }
+}
+#elif defined(__powerpc64__)
+template <bool AZP, typename scalar_t>
+void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
+                                   const float* scale, const int32_t* azp,
+                                   const int num_tokens,
+                                   const int hidden_size) {
+  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
+  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
+  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
+
+  constexpr float i8_min =
+      static_cast<float>(std::numeric_limits<int8_t>::min());
+  constexpr float i8_max =
+      static_cast<float>(std::numeric_limits<int8_t>::max());
+
+  const cvt_vec_t inv_scale(1.0 / *scale);
+  const cvt_vec_t i8_min_vec(i8_min);
+  const cvt_vec_t i8_max_vec(i8_max);
+
+  cvt_vec_t zp_vec;
+  if constexpr (AZP) {
+    zp_vec = cvt_vec_t(static_cast<float>(*azp));
+  }
+  #pragma omp parallel for
+  for (int i = 0; i < num_tokens; ++i) {
+    int j = 0;
+    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
+      load_vec_t elems(input + i * hidden_size + j);
+      cvt_vec_t elems_fp32(elems);
+      elems_fp32 = elems_fp32 * inv_scale;
+      if constexpr (AZP) {
+        elems_fp32 = elems_fp32 + zp_vec;
+      }
+      elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
+      vec_op::INT8Vec16 elems_int8(elems_fp32);
+      elems_int8.save(output + i * hidden_size + j);
+    }
+    load_vec_t elems(input + i * hidden_size + j);
+    cvt_vec_t elems_fp32(elems);
+    elems_fp32 = elems_fp32 * inv_scale;
+
+    if constexpr (AZP) {
+      elems_fp32 = elems_fp32 + zp_vec;
+    }
+
+    elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
+    vec_op::INT8Vec16 elems_int8(elems_fp32);
+    elems_int8.save(output + i * hidden_size + j, hidden_size - j);
+  }
+}
+template <bool AZP, typename scalar_t>
+void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
+                                    float* scale, int32_t* azp,
+                                    const int num_tokens,
+                                    const int hidden_size) {
+  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
+  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
+  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
+
+  constexpr float i8_min =
+      static_cast<float>(std::numeric_limits<int8_t>::min());
+  constexpr float i8_max =
+      static_cast<float>(std::numeric_limits<int8_t>::max());
+  const cvt_vec_t i8_min_vec(i8_min);
+  const cvt_vec_t i8_max_vec(i8_max);
+
+  #pragma omp parallel for
+  for (int i = 0; i < num_tokens; ++i) {
+    cvt_vec_t max_value(std::numeric_limits<float>::lowest());
+    cvt_vec_t min_value(std::numeric_limits<float>::max());
+    {
+      int j = 0;
+      for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
+        load_vec_t elems(input + i * hidden_size + j);
+        cvt_vec_t elems_fp32(elems);
+        if constexpr (AZP) {
+          max_value = max_value.max(elems_fp32);
+          min_value = min_value.min(elems_fp32);
+        } else {
+          max_value = max_value.max(elems_fp32.abs());
+        }
+      }
+
+      load_vec_t elems(input + i * hidden_size + j);
+      cvt_vec_t elems_fp32(elems);
+
+      if (j + vec_elem_num == hidden_size) {
+        if constexpr (AZP) {
+          max_value = max_value.max(elems_fp32);
+          min_value = min_value.min(elems_fp32);
+        } else {
+          max_value = max_value.max(elems_fp32.abs());
+        }
+      } else {
+        if constexpr (AZP) {
+          max_value = max_value.max(elems_fp32, hidden_size - j);
+          min_value = min_value.min(elems_fp32, hidden_size - j);
+        } else {
+          max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
+        }
+      }
+    }
+
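+    // Derive the per-token scale (and zero point in the asymmetric case)
+    // from the min/max gathered above.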
+    float scale_val, azp_val;
+    if constexpr (AZP) {
+      float max_scalar = max_value.reduce_max();
+      float min_scalar = min_value.reduce_min();
+      scale_val = (max_scalar - min_scalar) / 255.0f;
+      azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
+      azp[i] = static_cast<int32_t>(azp_val);
+      scale[i] = scale_val;
+    } else {
+      scale_val = max_value.reduce_max() / 127.0f;
+      scale[i] = scale_val;
+    }
+
+    const cvt_vec_t inv_scale(1.0 / scale_val);
+    const cvt_vec_t azp_vec(azp_val);
+
+    {
+      int j = 0;
+      for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
+        load_vec_t elems(input + i * hidden_size + j);
+        cvt_vec_t elems_fp32(elems);
+        elems_fp32 = (elems_fp32 * inv_scale);
+
+        if constexpr (AZP) {
+          elems_fp32 = elems_fp32 + azp_vec;
+        }
+        elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
+        vec_op::INT8Vec16 elems_int8(elems_fp32);
+        elems_int8.save(output + i * hidden_size + j);
+      }
+
+      load_vec_t elems(input + i * hidden_size + j);
+      cvt_vec_t elems_fp32(elems);
+      elems_fp32 = (elems_fp32 * inv_scale);
+
+      if constexpr (AZP) {
+        elems_fp32 = elems_fp32 + azp_vec;
+      }
+      elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
+      vec_op::INT8Vec16 elems_int8(elems_fp32);
+      elems_int8.save(output + i * hidden_size + j, hidden_size - j);
+    }
+  }
+}
+template <bool PerChannel, typename scalar_t>
+void static_quant_epilogue(const float* input, scalar_t* output,
+                           const float a_scale, const float* b_scale,
+                           const int32_t* azp_with_adj, const int num_tokens,
+                           const int hidden_size) {
+  CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
+  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
+  using azp_adj_load_vec_t =
+      typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
+  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
+  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
+
+  #pragma omp parallel for
+  for (int i = 0; i < num_tokens; ++i) {
+    cvt_vec_t a_scale_vec(a_scale);
+    cvt_vec_t b_scale_vec(*b_scale);
+    cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;
+
+    int j = 0;
+    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
+      cvt_vec_t elems_fp32(input + i * hidden_size + j);
+      azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
+      cvt_vec_t azp_adj_fp32(azp_adj_vec);
+
+      if constexpr (PerChannel) {
+        b_scale_vec = cvt_vec_t(b_scale + j);
+        scale_vec = b_scale_vec * a_scale_vec;
+      }
+      elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
+      load_vec_t elems_out(elems_fp32);
+      elems_out.save(output + i * hidden_size + j);
+    }
+
+    cvt_vec_t elems_fp32(input + i * hidden_size + j);
+    azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
+    cvt_vec_t azp_adj_fp32(azp_adj_vec);
+
+    if constexpr (PerChannel) {
+      b_scale_vec = cvt_vec_t(b_scale + j);
+      scale_vec = b_scale_vec * a_scale_vec;
+    }
+
+    elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
+
+    load_vec_t elems_out(elems_fp32);
+    elems_out.save(output + i * hidden_size + j, hidden_size - j);
+  }
+}
 template <bool AZP, bool Bias, bool PerChannel, typename scalar_t>
 void dynamic_quant_epilogue(const float* input, scalar_t* output,
                             const float* a_scale, const float* b_scale,
@@ -324,7 +598,8 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                    const float* scale, const int32_t* azp,
                                    const int num_tokens,
                                    const int hidden_size) {
-  TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.")
+  TORCH_CHECK(
+      false, "static_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
 }
 
 template <bool AZP, typename scalar_t>
@@ -332,7 +607,9 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                     float* scale, int32_t* azp,
                                     const int num_tokens,
                                     const int hidden_size) {
-  TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.")
+  TORCH_CHECK(
+      false,
+      "dynamic_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
 }
 
 template <bool PerChannel, typename scalar_t>
@@ -340,7 +617,7 @@ void static_quant_epilogue(const float* input, scalar_t* output,
                            const float a_scale, const float* b_scale,
                            const int32_t* azp_with_adj, const int num_tokens,
                            const int hidden_size) {
-  TORCH_CHECK(false, "static_quant_epilogue requires AVX512 support.")
+  TORCH_CHECK(false, "static_quant_epilogue requires AVX512/powerpc64 support.")
 }
 
 template <bool AZP, bool Bias, bool PerChannel, typename scalar_t>
@@ -349,7 +626,8 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output,
                             const float* b_scale, const int32_t* azp,
                             const int32_t* azp_with_adj, const scalar_t* bias,
                             const int num_tokens, const int hidden_size) {
-  TORCH_CHECK(false, "dynamic_quant_epilogue requires AVX512 support.")
+  TORCH_CHECK(false,
+              "dynamic_quant_epilogue requires AVX512/powerpc64 support.")
 }
 #endif
 }  // namespace
@@ -611,3 +889,58 @@ void dynamic_scaled_int8_quant(
     }
   });
 }
+
+#if defined(__powerpc64__)
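+// W8A8 GEMM for ppc64le: computes C = s_a * (s_b * (A @ B)) [+ bias].
+// The int8 GEMM itself runs through oneDNN; the activation scale and the
+// optional bias are applied by dynamic_quant_epilogue defined above.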
+void int8_scaled_mm_ppc64le(torch::Tensor& c,        // [M, OC], row-major
+                            const torch::Tensor& a,  // [M, IC], row-major
+                            const torch::Tensor& b,  // [IC, OC], column-major
+                            const torch::Tensor& a_scales,
+                            const torch::Tensor& b_scales,
+                            const std::optional<torch::Tensor>& bias  // [OC]
+) {
+  CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
+  // Checks for conformality
+  TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
+              "int8_scaled_mm_ppc64le only supports INT8 inputs.");
+  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
+  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
+              b.size(1) == c.size(1));
+  // We don't need this
+  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
+  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
+
+  // Check for strides and alignment
+  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
+  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
+  TORCH_CHECK(c.stride(0) % 16 == 0 &&
+              b.stride(1) % 16 == 0);  // 16 Byte Alignment
+  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+
+  if (bias) {
+    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
+                bias->dim() == 1);
+  }
+  VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_ppc64le", [&] {
+    torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float);
+    // Compute C_inter = s_b * (A @ B)
+    DNNLPrimitiveHelper<false>::gemm_s8s8_jit<float, void>(
+        a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
+        tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
+        a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
+    if (bias.has_value()) {
+      // Compute C = s_a * C_inter + bias
+      dynamic_quant_epilogue<false, true, false>(
+          tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
+          a_scales.data_ptr<float>(), nullptr, nullptr, nullptr,
+          bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
+    } else {
+      // Compute C = s_a * C_inter
+      dynamic_quant_epilogue<false, false, false>(
+          tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
+          a_scales.data_ptr<float>(), nullptr, nullptr, nullptr, nullptr,
+          c.size(0), c.size(1));
+    }
+  });
+}
+
+#endif
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index 7ae7e3386b4e..248b42ab4127 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -18,6 +18,14 @@ void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
                         const std::optional<torch::Tensor>& azp,
                         const std::optional<torch::Tensor>& bias);
 
+#if defined(__powerpc64__)
+void int8_scaled_mm_ppc64le(torch::Tensor& c, const torch::Tensor& a,
+                            const torch::Tensor& b,
+                            const torch::Tensor& a_scales,
+                            const torch::Tensor& b_scales,
+                            const std::optional<torch::Tensor>& bias);
+#endif
+
 void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
                         torch::Tensor& kv_cache, double scale,
                         torch::Tensor& block_tables, torch::Tensor& seq_lens);
@@ -150,6 +158,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      " Tensor b_scales, Tensor azp_adj,"
      " Tensor? azp, Tensor? bias) -> ()");
   ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
+#elif defined(__powerpc64__)
+  // Compute int8 quantized tensor for given scaling factor.
+  ops.def(
+      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
+      "Tensor? azp) -> ()");
+  ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
+
+  // Compute int8 quantized tensor and scaling factor
+  ops.def(
+      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
+      "Tensor!? azp) -> ()");
+  ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
+           &dynamic_scaled_int8_quant);
+  // W8A8 GEMM, supporting symmetric quantization.
+  ops.def(
+      "cutlass_scaled_mm(Tensor! out, Tensor a,"
+      " Tensor b, Tensor a_scales,"
+      " Tensor b_scales, Tensor? bias) -> ()");
+  ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm_ppc64le);
+  // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
+  // quantization.
+  ops.def(
+      "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
+      " Tensor b, Tensor a_scales,"
+      " Tensor b_scales, Tensor azp_adj,"
+      " Tensor? azp, Tensor? bias) -> ()");
+  ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
 #endif
 
   // SHM CCL