
Commit

[Hardware][Intel] Support compressed-tensor W8A8 for CPU backend (vllm-project#7257)

Signed-off-by: Alvant <alvasian@yandex.ru>
bigPYJ1151 authored and Alvant committed Oct 26, 2024
1 parent c4425ab commit 508e3d7
Showing 18 changed files with 686 additions and 43 deletions.
6 changes: 6 additions & 0 deletions .buildkite/run-cpu-test.sh
@@ -30,6 +30,12 @@ docker exec cpu-test bash -c "
--ignore=tests/models/test_jamba.py \
--ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B are not supported on CPU

# Run compressed-tensor test
docker exec cpu-test bash -c "
pytest -s -v \
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynanmic_per_token"

# online inference
docker exec cpu-test bash -c "
export VLLM_CPU_KVCACHE_SPACE=10
18 changes: 17 additions & 1 deletion Dockerfile.cpu
@@ -2,6 +2,10 @@

FROM ubuntu:22.04 AS cpu-test-1

ENV CCACHE_DIR=/root/.cache/ccache

ENV CMAKE_CXX_COMPILER_LAUNCHER=ccache

RUN --mount=type=cache,target=/var/cache/apt \
apt-get update -y \
&& apt-get install -y curl ccache git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \
@@ -26,6 +30,19 @@ RUN --mount=type=cache,target=/root/.cache/pip \
pip install --upgrade pip && \
pip install -r requirements-build.txt

# install oneDNN
RUN git clone -b rls-v3.5 https://github.com/oneapi-src/oneDNN.git

RUN --mount=type=cache,target=/root/.cache/ccache \
cmake -B ./oneDNN/build -S ./oneDNN -G Ninja -DONEDNN_LIBRARY_TYPE=STATIC \
-DONEDNN_BUILD_DOC=OFF \
-DONEDNN_BUILD_EXAMPLES=OFF \
-DONEDNN_BUILD_TESTS=OFF \
-DONEDNN_BUILD_GRAPH=OFF \
-DONEDNN_ENABLE_WORKLOAD=INFERENCE \
-DONEDNN_ENABLE_PRIMITIVE=MATMUL && \
cmake --build ./oneDNN/build --target install --config Release

FROM cpu-test-1 AS build

WORKDIR /workspace/vllm
@@ -41,7 +58,6 @@ COPY ./ ./
ARG VLLM_CPU_DISABLE_AVX512
ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512}

ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=cache,target=/root/.cache/ccache \
VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \
18 changes: 12 additions & 6 deletions cmake/cpu_extension.cmake
@@ -1,4 +1,5 @@
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CXX_STANDARD 17)

#
# Define environment variables for special configurations
@@ -83,12 +84,7 @@ endif()

message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")

list(APPEND LIBS "numa")


#
# Define extension targets
#
list(APPEND LIBS dnnl numa)

#
# _C extension
@@ -102,6 +98,16 @@ set(VLLM_EXT_SRC
"csrc/cpu/pos_encoding.cpp"
"csrc/cpu/torch_bindings.cpp")

if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(VLLM_EXT_SRC
"csrc/cpu/quant.cpp"
${VLLM_EXT_SRC})
endif()

#
# Define extension targets
#

define_gpu_extension_target(
_C
DESTINATION vllm
62 changes: 60 additions & 2 deletions csrc/cpu/cpu_types_x86.hpp
@@ -24,8 +24,8 @@ namespace vec_op {
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
RECORD_FUNCTION(#NAME, c10::ArrayRef<c10::IValue>({}));
#define CPU_KERNEL_GUARD_OUT(NAME)
#endif

#define FORCE_INLINE __attribute__((always_inline)) inline
@@ -106,6 +106,12 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
explicit BF16Vec16(const FP32Vec16 &);

void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }

void save(void* ptr, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
__mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
_mm256_mask_storeu_epi16(ptr, mask, reg);
}
};

#ifdef __AVX512F__
@@ -313,8 +319,28 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
return FP32Vec16(_mm512_div_ps(reg, b.reg));
}

FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
return FP32Vec16(_mm512_min_ps(max.reg, _mm512_max_ps(min.reg, reg)));
}

FP32Vec16 max(const FP32Vec16& b) const {
return FP32Vec16(_mm512_max_ps(reg, b.reg));
}

FP32Vec16 max(const FP32Vec16& b, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
__mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
return FP32Vec16(_mm512_mask_max_ps(reg, mask, reg, b.reg));
}

FP32Vec16 abs() const {
return FP32Vec16(_mm512_abs_ps(reg));
}

float reduce_sum() const { return _mm512_reduce_add_ps(reg); }

float reduce_max() const { return _mm512_reduce_max_ps(reg); }

template <int group_size> float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
@@ -323,6 +349,12 @@
}

void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }

void save(float* ptr, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
__mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
_mm512_mask_storeu_ps(ptr, mask, reg);
}
};
#else
struct FP32Vec16 : public Vec<FP32Vec16> {
@@ -433,6 +465,32 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
};
#endif

#ifdef __AVX512F__
struct INT8Vec16: public Vec<INT8Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
__m128i reg;
int8_t values[VEC_ELEM_NUM];
};

__m128i reg;

explicit INT8Vec16(const FP32Vec16& vec) : reg(
_mm512_cvtepi32_epi8(_mm512_cvt_roundps_epi32(vec.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
) {}

void save(int8_t* ptr) const {
_mm_storeu_epi8(ptr, reg);
}

void save(int8_t* ptr, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
__mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
_mm_mask_storeu_epi8(ptr, mask, reg);
}
};
#endif

template <typename T> struct VecType { using vec_type = void; };

template <typename T> using vec_t = typename VecType<T>::vec_type;
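Taken together, the new FP32Vec16 helpers (abs, reduce_max, clamp) and the INT8Vec16 conversion are the building blocks for dynamic per-token int8 quantization on AVX-512. The sketch below shows how they compose; it is illustrative only, and the FP32Vec16 load/broadcast constructors and the elementwise multiply it uses are assumptions based on the rest of cpu_types_x86.hpp, not lines shown in this diff.

// Illustrative sketch (not part of this commit): quantize one 16-element row
// to int8 with a dynamically computed per-token scale, using the vector
// helpers declared above. Assumes vec_op::FP32Vec16(const float*),
// vec_op::FP32Vec16(float) and an elementwise operator* exist (not shown here).
#include <cstdint>

static void quantize_token16(const float* input, int8_t* output,
                             float* scale_out) {
  vec_op::FP32Vec16 row(input);                          // load 16 floats (assumed ctor)
  const float absmax = row.abs().reduce_max();           // per-token max magnitude
  const float scale = absmax > 0.0f ? absmax / 127.0f : 1.0f;
  vec_op::FP32Vec16 scaled = row * vec_op::FP32Vec16(1.0f / scale);
  vec_op::FP32Vec16 clamped = scaled.clamp(vec_op::FP32Vec16(-127.0f),
                                           vec_op::FP32Vec16(127.0f));
  vec_op::INT8Vec16 q(clamped);                          // round to nearest, narrow to int8
  q.save(output);                                        // contiguous 16-element store
  *scale_out = scale;
}

The masked max(b, elem_num) and save(ptr, elem_num) overloads added in this hunk exist for exactly this pattern when the row length is not a multiple of 16.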
168 changes: 168 additions & 0 deletions csrc/cpu/dnnl_helper.hpp
@@ -0,0 +1,168 @@
#ifndef DNNL_HELPER_HPP
#define DNNL_HELPER_HPP

#include <c10/util/BFloat16.h>

#include "oneapi/dnnl/dnnl.hpp"

namespace {
template <typename T>
struct DNNLType {
static constexpr dnnl::memory::data_type type =
dnnl::memory::data_type::undef;
};

template <>
struct DNNLType<int8_t> {
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
};

template <>
struct DNNLType<int32_t> {
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
};

template <>
struct DNNLType<float> {
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
};

template <>
struct DNNLType<c10::BFloat16> {
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
};

template <typename T>
constexpr inline dnnl::memory::data_type get_dnnl_type() {
return DNNLType<std::decay_t<T>>::type;
}
}; // namespace

template <bool InputNoScale>
class DNNLPrimitiveHelper {
public:
// I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias)
// A: [M, K], row-major
// B: [K, N], column-major
// C: [M, N], row-major
// bias: [N], row-major, optional
// a_scales: [MS]
// b_scales: [NS]
// Note: Due to the limitation of oneDNN
// (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
// not supported.
template <typename OutputT, typename BiasT>
static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
dnnl_dim_t K, const float* a_scales,
const float* b_scales, dnnl_dim_t MS,
dnnl_dim_t NS) {
auto&& OutputType = get_dnnl_type<OutputT>();
auto&& BiasType = get_dnnl_type<BiasT>();

dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1});
dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K});
dnnl::memory::desc c_md({M, N}, OutputType, {N, 1});

dnnl::primitive_attr attr;
if constexpr (!InputNoScale) {
if (MS == 1) {
// per-tensor
attr.set_scales_mask(DNNL_ARG_SRC, 0);
} else {
// per-token
TORCH_CHECK(false, "per-token quantization is unsupported.");
}
}

if (NS == 1) {
// per-tensor
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
} else {
// per-channel
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
}

dnnl::matmul::primitive_desc matmul_pd;
if (bias) {
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
bias_md, c_md, attr);
} else {
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
c_md, attr);
}
dnnl::matmul matmul(matmul_pd);

auto& engine = default_engine();

dnnl::memory a_m(a_md, engine, (void*)a);
dnnl::memory b_m(b_md, engine, (void*)b);
dnnl::memory c_m(c_md, engine, (void*)c);
dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine,
(void*)a_scales);
dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine,
(void*)b_scales);

auto& stream = default_stream();
if constexpr (InputNoScale) {
if (bias) {
dnnl::memory::desc bias_md({N}, BiasType, {1});
dnnl::memory bias_m(bias_md, engine, (void*)bias);
matmul.execute(
stream, {
{DNNL_ARG_SRC, a_m},
{DNNL_ARG_WEIGHTS, b_m},
{DNNL_ARG_BIAS, bias_m},
{DNNL_ARG_DST, c_m},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
});
} else {
matmul.execute(
stream, {
{DNNL_ARG_SRC, a_m},
{DNNL_ARG_WEIGHTS, b_m},
{DNNL_ARG_DST, c_m},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
});
}
} else {
if (bias) {
dnnl::memory::desc bias_md({N}, BiasType, {1});
dnnl::memory bias_m(bias_md, engine, (void*)bias);
matmul.execute(
stream, {
{DNNL_ARG_SRC, a_m},
{DNNL_ARG_WEIGHTS, b_m},
{DNNL_ARG_BIAS, bias_m},
{DNNL_ARG_DST, c_m},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
});
} else {
matmul.execute(
stream, {
{DNNL_ARG_SRC, a_m},
{DNNL_ARG_WEIGHTS, b_m},
{DNNL_ARG_DST, c_m},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
});
}
}
stream.wait();
}

private:
static dnnl::engine& default_engine() {
static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
return engine;
}

static dnnl::stream& default_stream() {
static dnnl::stream stream(default_engine());
return stream;
}
};

#endif
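
For reference, gemm_s8s8_jit is self-describing from its signature and comments: A is row-major int8, B is column-major int8, and the scale counts MS/NS select per-tensor versus per-channel dequantization. A hedged call sketch follows; the wrapper name and the float output/bias types are illustrative choices, not taken from this diff.

// Illustrative sketch (not part of this commit): one int8 x int8 GEMM through
// DNNLPrimitiveHelper, with a per-tensor activation scale (MS == 1) and
// per-channel weight scales (NS == N). dnnl_dim_t is oneDNN's index type.
// Assumes dnnl_helper.hpp (this file) is included.
#include <cstdint>

void w8a8_gemm_example(const int8_t* a,        // [M, K], row-major
                       const int8_t* b,        // [K, N], column-major
                       float* c,               // [M, N], row-major output
                       const float* bias,      // [N] or nullptr
                       dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K,
                       const float* a_scale,   // single per-tensor activation scale
                       const float* b_scales)  // N per-channel weight scales
{
  // InputNoScale == false: oneDNN applies both activation and weight scales.
  DNNLPrimitiveHelper<false>::gemm_s8s8_jit<float, float>(
      a, b, c, bias, M, N, K, a_scale, b_scales, /*MS=*/1, /*NS=*/N);
}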
