Skip to content

Commit 393e962

Browse files
Akashcodes732Akash Kaothalkarmgoin
authored andcommitted
[Hardware][Power] Enable compressed tensor W8A8 INT8 quantization for POWER (vllm-project#17153)
Signed-off-by: Akash Kaothalkar <akash.kaothalkar@ibm.com> Co-authored-by: Akash Kaothalkar <akash.kaothalkar@ibm.com> Co-authored-by: mgoin <mgoin64@gmail.com> Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
1 parent 7b32f2b commit 393e962

File tree

4 files changed

+669
-5
lines changed

4 files changed

+669
-5
lines changed

cmake/cpu_extension.cmake

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,33 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
167167

168168
FetchContent_MakeAvailable(oneDNN)
169169

170+
list(APPEND LIBS dnnl)
171+
elseif(POWER10_FOUND)
172+
FetchContent_Declare(
173+
oneDNN
174+
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
175+
GIT_TAG v3.7.2
176+
GIT_PROGRESS TRUE
177+
GIT_SHALLOW TRUE
178+
)
179+
180+
set(ONEDNN_LIBRARY_TYPE "STATIC")
181+
set(ONEDNN_BUILD_DOC "OFF")
182+
set(ONEDNN_BUILD_EXAMPLES "OFF")
183+
set(ONEDNN_BUILD_TESTS "OFF")
184+
set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
185+
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
186+
set(ONEDNN_BUILD_GRAPH "OFF")
187+
set(ONEDNN_ENABLE_JIT_PROFILING "OFF")
188+
set(ONEDNN_ENABLE_ITT_TASKS "OFF")
189+
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
190+
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
191+
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
192+
193+
set(DNNL_CPU_RUNTIME "OMP")
194+
195+
FetchContent_MakeAvailable(oneDNN)
196+
170197
list(APPEND LIBS dnnl)
171198
endif()
172199

@@ -197,6 +224,10 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
197224
"csrc/cpu/quant.cpp"
198225
"csrc/cpu/shm.cpp"
199226
${VLLM_EXT_SRC})
227+
elseif(POWER10_FOUND)
228+
set(VLLM_EXT_SRC
229+
"csrc/cpu/quant.cpp"
230+
${VLLM_EXT_SRC})
200231
endif()
201232

202233
#
@@ -214,4 +245,4 @@ define_gpu_extension_target(
214245
WITH_SOABI
215246
)
216247

217-
message(STATUS "Enabling C extension.")
248+
message(STATUS "Enabling C extension.")

csrc/cpu/cpu_types_vsx.hpp

Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include <altivec.h>
66
#include <cmath>
7+
#include <algorithm>
78
#include <torch/all.h>
89

910
namespace vec_op {
@@ -62,6 +63,10 @@ typedef struct f32x4x4_t {
6263
__vector float val[4];
6364
} f32x4x4_t;
6465

66+
typedef struct i32x4x4_t {
67+
__vector int32_t val[4];
68+
} i32x4x4_t;
69+
6570
struct FP32Vec8;
6671
struct FP32Vec16;
6772

@@ -98,6 +103,28 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
98103
vec_xst(reg.val[0], 0, (signed short*)ptr);
99104
vec_xst(reg.val[1], 16, (signed short*)ptr);
100105
}
106+
107+
void save(void* ptr, const int elem_num) const {
108+
const int clamped_elem = std::max(0, std::min(elem_num, 16));
109+
110+
// Calculate elements to store in each 128-bit part (8 elements each)
111+
const int elements_val0 = std::min(clamped_elem, 8);
112+
const int elements_val1 = std::max(clamped_elem - 8, 0);
113+
114+
// Convert elements to bytes (2 bytes per element)
115+
const size_t bytes_val0 = elements_val0 * sizeof(signed short);
116+
const size_t bytes_val1 = elements_val1 * sizeof(signed short);
117+
118+
signed short* dest = static_cast<signed short*>(ptr);
119+
// Store the first part using vec_xst_len
120+
if (bytes_val0 > 0) {
121+
vec_xst_len(reg.val[0], dest, bytes_val0);
122+
}
123+
// Store the second part if needed
124+
if (bytes_val1 > 0) {
125+
vec_xst_len(reg.val[1], dest + elements_val0, bytes_val1);
126+
}
127+
}
101128
};
102129

103130
const static __vector signed short zero = vec_splats((signed short)0);
@@ -257,6 +284,64 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
257284
}
258285
};
259286

287+
struct INT32Vec16 : public Vec<INT32Vec16> {
288+
constexpr static int VEC_ELEM_NUM = 16;
289+
union AliasReg {
290+
i32x4x4_t reg;
291+
int32_t values[VEC_ELEM_NUM];
292+
};
293+
294+
i32x4x4_t reg;
295+
296+
explicit INT32Vec16(const void* data_ptr) {
297+
reg.val[0] = vec_xl(0, reinterpret_cast<const __vector int32_t*>(data_ptr));
298+
reg.val[1] =
299+
vec_xl(16, reinterpret_cast<const __vector int32_t*>(data_ptr));
300+
reg.val[2] =
301+
vec_xl(32, reinterpret_cast<const __vector int32_t*>(data_ptr));
302+
reg.val[3] =
303+
vec_xl(48, reinterpret_cast<const __vector int32_t*>(data_ptr));
304+
}
305+
306+
void save(int32_t* ptr) const {
307+
vec_xst(reg.val[0], 0, reinterpret_cast<__vector int32_t*>(ptr));
308+
vec_xst(reg.val[1], 16, reinterpret_cast<__vector int32_t*>(ptr));
309+
vec_xst(reg.val[2], 32, reinterpret_cast<__vector int32_t*>(ptr));
310+
vec_xst(reg.val[3], 48, reinterpret_cast<__vector int32_t*>(ptr));
311+
}
312+
313+
void save(int32_t* ptr, const int elem_num) const {
314+
const int elements_in_chunk1 =
315+
(elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0;
316+
const int elements_in_chunk2 =
317+
(elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0;
318+
const int elements_in_chunk3 =
319+
(elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0;
320+
const int elements_in_chunk4 =
321+
(elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0;
322+
323+
const size_t bytes_chunk1 =
324+
static_cast<size_t>(elements_in_chunk1 * sizeof(int32_t));
325+
const size_t bytes_chunk2 =
326+
static_cast<size_t>(elements_in_chunk2 * sizeof(int32_t));
327+
const size_t bytes_chunk3 =
328+
static_cast<size_t>(elements_in_chunk3 * sizeof(int32_t));
329+
const size_t bytes_chunk4 =
330+
static_cast<size_t>(elements_in_chunk4 * sizeof(int32_t));
331+
332+
vec_xst_len(reg.val[0], reinterpret_cast<int32_t*>(ptr), bytes_chunk1);
333+
vec_xst_len(reg.val[1],
334+
reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 16),
335+
bytes_chunk2);
336+
vec_xst_len(reg.val[2],
337+
reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 32),
338+
bytes_chunk3);
339+
vec_xst_len(reg.val[3],
340+
reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 48),
341+
bytes_chunk4);
342+
}
343+
};
344+
260345
struct FP32Vec16 : public Vec<FP32Vec16> {
261346
constexpr static int VEC_ELEM_NUM = 16;
262347
union AliasReg {
@@ -319,6 +404,13 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
319404

320405
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
321406

407+
explicit FP32Vec16(const INT32Vec16& v) {
408+
reg.val[0] = vec_ctf(v.reg.val[0], 0);
409+
reg.val[1] = vec_ctf(v.reg.val[1], 0);
410+
reg.val[2] = vec_ctf(v.reg.val[2], 0);
411+
reg.val[3] = vec_ctf(v.reg.val[3], 0);
412+
}
413+
322414
FP32Vec16 operator*(const FP32Vec16& b) const {
323415
return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
324416
vec_mul(reg.val[1], b.reg.val[1]),
@@ -347,6 +439,117 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
347439
vec_div(reg.val[3], b.reg.val[3])}));
348440
}
349441

442+
FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
443+
return FP32Vec16(f32x4x4_t(
444+
{vec_min(max.reg.val[0], vec_max(min.reg.val[0], reg.val[0])),
445+
vec_min(max.reg.val[1], vec_max(min.reg.val[1], reg.val[1])),
446+
vec_min(max.reg.val[2], vec_max(min.reg.val[2], reg.val[2])),
447+
vec_min(max.reg.val[3], vec_max(min.reg.val[3], reg.val[3]))}));
448+
}
449+
450+
FP32Vec16 max(const FP32Vec16& b) const {
451+
return FP32Vec16(f32x4x4_t({vec_max(reg.val[0], b.reg.val[0]),
452+
vec_max(reg.val[1], b.reg.val[1]),
453+
vec_max(reg.val[2], b.reg.val[2]),
454+
vec_max(reg.val[3], b.reg.val[3])}));
455+
}
456+
457+
FP32Vec16 max(const FP32Vec16& b, int elem_num) const {
458+
FP32Vec16 result;
459+
460+
// Create a vector of element indices for each chunk
461+
__vector unsigned int indices = {0, 1, 2, 3};
462+
__vector unsigned int elem_num_vec =
463+
vec_splats(static_cast<unsigned int>(elem_num));
464+
465+
// Compute masks for each chunk
466+
__vector unsigned int chunk_offset0 = {0, 0, 0,
467+
0}; // Chunk 0: Elements 0-3
468+
__vector unsigned int chunk_offset1 = {4, 4, 4,
469+
4}; // Chunk 1: Elements 4-7
470+
__vector unsigned int chunk_offset2 = {8, 8, 8,
471+
8}; // Chunk 2: Elements 8-11
472+
__vector unsigned int chunk_offset3 = {12, 12, 12,
473+
12}; // Chunk 3: Elements 12-15
474+
475+
// Compute masks for each chunk
476+
__vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec);
477+
__vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec);
478+
__vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec);
479+
__vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec);
480+
481+
// Apply masks to compute the result for each chunk
482+
result.reg.val[0] = vec_sel(this->reg.val[0],
483+
vec_max(this->reg.val[0], b.reg.val[0]), mask0);
484+
result.reg.val[1] = vec_sel(this->reg.val[1],
485+
vec_max(this->reg.val[1], b.reg.val[1]), mask1);
486+
result.reg.val[2] = vec_sel(this->reg.val[2],
487+
vec_max(this->reg.val[2], b.reg.val[2]), mask2);
488+
result.reg.val[3] = vec_sel(this->reg.val[3],
489+
vec_max(this->reg.val[3], b.reg.val[3]), mask3);
490+
491+
return FP32Vec16(result.reg);
492+
}
493+
494+
FP32Vec16 min(const FP32Vec16& b) const {
495+
return FP32Vec16(f32x4x4_t({vec_min(reg.val[0], b.reg.val[0]),
496+
vec_min(reg.val[1], b.reg.val[1]),
497+
vec_min(reg.val[2], b.reg.val[2]),
498+
vec_min(reg.val[3], b.reg.val[3])}));
499+
}
500+
501+
FP32Vec16 min(const FP32Vec16& b, int elem_num) const {
502+
FP32Vec16 result;
503+
504+
vector unsigned int indices = {0, 1, 2, 3};
505+
vector unsigned int elem_num_vec =
506+
vec_splats(static_cast<unsigned int>(elem_num));
507+
508+
vector unsigned int chunk_offset0 = {0, 0, 0, 0};
509+
vector unsigned int chunk_offset1 = {4, 4, 4, 4};
510+
vector unsigned int chunk_offset2 = {8, 8, 8, 8};
511+
vector unsigned int chunk_offset3 = {12, 12, 12, 12};
512+
513+
vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec);
514+
vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec);
515+
vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec);
516+
vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec);
517+
518+
result.reg.val[0] = vec_sel(this->reg.val[0],
519+
vec_min(this->reg.val[0], b.reg.val[0]), mask0);
520+
result.reg.val[1] = vec_sel(this->reg.val[1],
521+
vec_min(this->reg.val[1], b.reg.val[1]), mask1);
522+
result.reg.val[2] = vec_sel(this->reg.val[2],
523+
vec_min(this->reg.val[2], b.reg.val[2]), mask2);
524+
result.reg.val[3] = vec_sel(this->reg.val[3],
525+
vec_min(this->reg.val[3], b.reg.val[3]), mask3);
526+
527+
return FP32Vec16(result.reg);
528+
}
529+
530+
FP32Vec16 abs() const {
531+
return FP32Vec16(f32x4x4_t({vec_abs(reg.val[0]), vec_abs(reg.val[1]),
532+
vec_abs(reg.val[2]), vec_abs(reg.val[3])}));
533+
}
534+
535+
float reduce_max() {
536+
__vector float max01 = vec_max(reg.val[0], reg.val[1]);
537+
__vector float max23 = vec_max(reg.val[2], reg.val[3]);
538+
__vector float max_all = vec_max(max01, max23);
539+
__vector float temp = vec_max(max_all, vec_sld(max_all, max_all, 8));
540+
temp = vec_max(temp, vec_sld(temp, temp, 4));
541+
return vec_extract(temp, 0);
542+
}
543+
544+
float reduce_min() {
545+
__vector float min01 = vec_min(reg.val[0], reg.val[1]);
546+
__vector float min23 = vec_min(reg.val[2], reg.val[3]);
547+
__vector float min_all = vec_min(min01, min23);
548+
__vector float temp = vec_min(min_all, vec_sld(min_all, min_all, 8));
549+
temp = vec_min(temp, vec_sld(temp, temp, 4));
550+
return vec_extract(temp, 0);
551+
}
552+
350553
float reduce_sum() const {
351554
AliasReg ar;
352555
ar.reg = reg;
@@ -377,6 +580,68 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
377580
vec_xst(reg.val[2], 32, ptr);
378581
vec_xst(reg.val[3], 48, ptr);
379582
}
583+
584+
void save(float* ptr, const int elem_num) const {
585+
const int elements_in_chunk1 =
586+
(elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0;
587+
const int elements_in_chunk2 =
588+
(elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0;
589+
const int elements_in_chunk3 =
590+
(elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0;
591+
const int elements_in_chunk4 =
592+
(elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0;
593+
594+
const size_t bytes_chunk1 =
595+
static_cast<size_t>(elements_in_chunk1 * sizeof(float));
596+
const size_t bytes_chunk2 =
597+
static_cast<size_t>(elements_in_chunk2 * sizeof(float));
598+
const size_t bytes_chunk3 =
599+
static_cast<size_t>(elements_in_chunk3 * sizeof(float));
600+
const size_t bytes_chunk4 =
601+
static_cast<size_t>(elements_in_chunk4 * sizeof(float));
602+
603+
vec_xst_len(reg.val[0], ptr, bytes_chunk1);
604+
vec_xst_len(reg.val[1],
605+
reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 16),
606+
bytes_chunk2);
607+
vec_xst_len(reg.val[2],
608+
reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 32),
609+
bytes_chunk3);
610+
vec_xst_len(reg.val[3],
611+
reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 48),
612+
bytes_chunk4);
613+
}
614+
};
615+
616+
struct INT8Vec16 : public Vec<INT8Vec16> {
617+
constexpr static int VEC_NUM_ELEM = 16; // 128 bits / 8 bits = 16
618+
619+
union AliasReg {
620+
__vector signed char reg;
621+
int8_t values[VEC_NUM_ELEM];
622+
};
623+
624+
__vector signed char reg;
625+
626+
explicit INT8Vec16(const FP32Vec16& vec) {
627+
__vector signed int ret[4];
628+
ret[0] = vec_cts(vec.reg.val[0], 0);
629+
ret[1] = vec_cts(vec.reg.val[1], 0);
630+
ret[2] = vec_cts(vec.reg.val[2], 0);
631+
ret[3] = vec_cts(vec.reg.val[3], 0);
632+
633+
__vector signed short packed1 = vec_packs(ret[0], ret[1]);
634+
__vector signed short packed2 = vec_packs(ret[2], ret[3]);
635+
636+
reg = vec_packs(packed1, packed2);
637+
}
638+
639+
void save(void* ptr) const {
640+
*reinterpret_cast<__vector signed char*>(ptr) = reg;
641+
}
642+
void save(signed char* ptr, const int elem_num) {
643+
vec_xst_len(reg, ptr, static_cast<size_t>(elem_num));
644+
}
380645
};
381646

382647
template <typename T>

0 commit comments

Comments
 (0)