diff --git a/.github/workflows/nv-pre-compile-ops.yml b/.github/workflows/nv-pre-compile-ops.yml index 18db40380577..6e308242ecf0 100644 --- a/.github/workflows/nv-pre-compile-ops.yml +++ b/.github/workflows/nv-pre-compile-ops.yml @@ -36,7 +36,7 @@ jobs: #python -c "import torch; print('CUDA available:', torch.cuda.is_available())" - name: Compile DeepSpeed Ops run: | - DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install . + DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install . - name: DS Report run: | ds_report diff --git a/csrc/fp_quantizer/includes/context.h b/csrc/fp_quantizer/includes/context.h new file mode 100644 index 000000000000..5bd9badbcb4f --- /dev/null +++ b/csrc/fp_quantizer/includes/context.h @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include +#include +#include +#include "cublas_v2.h" +#include "cuda.h" +#include "curand.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#define WARP_SIZE 32 + +class FPContext { +public: + FPContext() : _seed(42) + { + curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT); + curandSetPseudoRandomGeneratorSeed(_gen, 123); + } + + virtual ~FPContext() {} + + static FPContext& Instance() + { + static FPContext _ctx; + return _ctx; + } + + curandGenerator_t& GetRandGenerator() { return _gen; } + + cudaStream_t GetCurrentStream() + { + // get current pytorch stream. + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + return stream; + } + + std::pair IncrementOffset(uint64_t offset_inc) + { + uint64_t offset = _curr_offset; + _curr_offset += offset_inc; + return std::pair(_seed, offset); + } + + void SetSeed(uint64_t new_seed) { _seed = new_seed; } + +private: + curandGenerator_t _gen; + cublasHandle_t _cublasHandle; + uint64_t _seed; + uint64_t _curr_offset; +}; diff --git a/csrc/fp_quantizer/includes/quantize.h b/csrc/fp_quantizer/includes/quantize.h new file mode 100644 index 000000000000..2204c1ba74fc --- /dev/null +++ b/csrc/fp_quantizer/includes/quantize.h @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include + +#include + +#include +#include +#include + +#define QUANT_SWITCH(Q_BITS, ...) 
\ + [&] { \ + if (12 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 0; \ + constexpr int CONST_Q_BITS = 8; \ + constexpr int CONST_Q_MANTISA_BITS = 3; \ + __VA_ARGS__(); \ + } else if (13 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 1; \ + constexpr int CONST_Q_BITS = 8; \ + constexpr int CONST_Q_MANTISA_BITS = 3; \ + __VA_ARGS__(); \ + } else if (10 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 0; \ + constexpr int CONST_Q_BITS = 8; \ + constexpr int CONST_Q_MANTISA_BITS = 2; \ + __VA_ARGS__(); \ + } else if (11 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 1; \ + constexpr int CONST_Q_BITS = 8; \ + constexpr int CONST_Q_MANTISA_BITS = 2; \ + __VA_ARGS__(); \ + } else if (28 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 0; \ + constexpr int CONST_Q_BITS = 12; \ + constexpr int CONST_Q_MANTISA_BITS = 7; \ + __VA_ARGS__(); \ + } else if (29 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 1; \ + constexpr int CONST_Q_BITS = 12; \ + constexpr int CONST_Q_MANTISA_BITS = 7; \ + __VA_ARGS__(); \ + } else if (6 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 0; \ + constexpr int CONST_Q_BITS = 6; \ + constexpr int CONST_Q_MANTISA_BITS = 2; \ + __VA_ARGS__(); \ + } else if (7 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 1; \ + constexpr int CONST_Q_BITS = 6; \ + constexpr int CONST_Q_MANTISA_BITS = 2; \ + __VA_ARGS__(); \ + } else if (2 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 0; \ + constexpr int CONST_Q_BITS = 4; \ + constexpr int CONST_Q_MANTISA_BITS = 1; \ + __VA_ARGS__(); \ + } else { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 1; \ + constexpr int CONST_Q_BITS = 4; \ + constexpr int CONST_Q_MANTISA_BITS = 1; \ + __VA_ARGS__(); \ + } \ + }() + +#define DEQUANT_SWITCH(Q_MANTISA_EXPONENT_BITS, ...) \ + [&] { \ + if (12 == Q_MANTISA_EXPONENT_BITS) { \ + constexpr int CONST_Q_MANTISA_BITS = 3; \ + constexpr int CONST_Q_EXPONENT_BITS = 4; \ + __VA_ARGS__(); \ + } else if (10 == Q_MANTISA_EXPONENT_BITS) { \ + constexpr int CONST_Q_MANTISA_BITS = 2; \ + constexpr int CONST_Q_EXPONENT_BITS = 5; \ + __VA_ARGS__(); \ + } else if (28 == Q_MANTISA_EXPONENT_BITS) { \ + constexpr int CONST_Q_MANTISA_BITS = 7; \ + constexpr int CONST_Q_EXPONENT_BITS = 4; \ + __VA_ARGS__(); \ + } else if (6 == Q_MANTISA_EXPONENT_BITS) { \ + constexpr int CONST_Q_MANTISA_BITS = 2; \ + constexpr int CONST_Q_EXPONENT_BITS = 3; \ + __VA_ARGS__(); \ + } else { \ + constexpr int CONST_Q_MANTISA_BITS = 1; \ + constexpr int CONST_Q_EXPONENT_BITS = 2; \ + __VA_ARGS__(); \ + } \ + }() + +template +void launch_quantization(T* val, + uint8_t* q_val, + int num_groups, + int group_size, + cudaStream_t stream, + float q_range, + int q_bits, + int q_mantisa_bits, + int stochastic_rounding); + +template +void launch_dequantization(uint8_t* val, + T* q_val, + int num_groups, + int group_size, + int q_mantisa_bits, + int q_exponent_bits, + cudaStream_t stream); diff --git a/csrc/fp_quantizer/quantize.cpp b/csrc/fp_quantizer/quantize.cpp new file mode 100644 index 000000000000..4a88ff767636 --- /dev/null +++ b/csrc/fp_quantizer/quantize.cpp @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "quantize.h" + +#include +#include +#include + +#define DISPATCH_QUANTIZE(T_TYPE, C_TYPE, mantisa, exponent) \ + if (val.options().dtype() == torch::T_TYPE) { \ + launch_quantization((C_TYPE*)val.data_ptr(), \ + (uint8_t*)out.data_ptr(), \ + num_groups, \ + group_size, \ + at::cuda::getCurrentCUDAStream(), \ + q_range, \ + q_bits, \ + q_mantisa_bits, \ + stochastic_rounding); \ + } + +at::Tensor quantize(torch::Tensor& val, + int group_size, + int stochastic_rounding, + int q_bits, + int q_mantisa_bits) +{ + int total_elems = at::numel(val); + auto options = at::TensorOptions() + .dtype(torch::kInt8) + .layout(val.layout()) + .device(val.device()) + .requires_grad(false); + float q_range = q_bits == 8 ? (q_mantisa_bits == 3 ? 480.0 : 114688.0) : // fp8 ranges + (q_bits == 12 ? 510.0 : // fp12 range + (q_bits == 6 ? 28.0 : // fp6 range + 6.0)); // fp4 range (using power 2); TODO (Reza): add the power-4 + // in case accuracy is not matching! + int num_groups = total_elems / group_size; + auto out = torch::empty({num_groups, group_size * q_bits / 8 + 4}, options); + + DISPATCH_QUANTIZE(kHalf, __half, 23, 8); +#ifdef BF16_AVAILABLE + DISPATCH_QUANTIZE(kBFloat16, __nv_bfloat16, 23, 8); +#endif + + return out; +} + +#define DISPATCH_DEQUANTIZE(T_TYPE, C_TYPE, mantisa) \ + if (val.options().dtype() == torch::T_TYPE) { \ + launch_dequantization((uint8_t*)val_q.data_ptr(), \ + (C_TYPE*)val.data_ptr(), \ + num_groups, \ + group_size, \ + q_mantisa_bits, \ + q_exponent_bits, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } + +void dequantize(torch::Tensor& val, + torch::Tensor& val_q, + int group_size, + int q_mantisa_bits, + int q_exponent_bits) +{ + int total_elems = at::numel(val); + + int num_groups = total_elems / group_size; + + DISPATCH_DEQUANTIZE(kHalf, __half, 10); +#ifdef BF16_AVAILABLE + DISPATCH_DEQUANTIZE(kBFloat16, __nv_bfloat16, 7); +#endif +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("quantize", &quantize, "quantize function"); + m.def("dequantize", &dequantize, "dequantize function"); +} diff --git a/csrc/fp_quantizer/quantize.cu b/csrc/fp_quantizer/quantize.cu new file mode 100644 index 000000000000..37be6cc0657c --- /dev/null +++ b/csrc/fp_quantizer/quantize.cu @@ -0,0 +1,427 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include "context.h" +#include "memory_access_utils.h" +#include "quantize.h" +#include "reduction_utils.h" + +#include +#include + +#include +#include + +#include +#include + +using ROp = reduce::ROpType; + +namespace quantization { + +constexpr int access_granularity = 16; +constexpr int quanitzed_access_granularity = 4; +constexpr int quanitzed_access_granularity_6bits = 2; +constexpr int threads = 256; +constexpr int warps = threads / 32; + +} // namespace quantization + +template +__device__ void round(uint32_t& mantisa, uint32_t& dst_exponent, curandStatePhilox4_32_10_t* state) +{ + constexpr uint32_t mantisa_mask = (1 << (_mantisa_bits - q_mantisa_bits)) - 1; + uint32_t offset = stochastic_rounding ? (curand_poisson(state, 10) & mantisa_mask) + : 1 << (_mantisa_bits - q_mantisa_bits - 1); + mantisa += offset; + dst_exponent += (((mantisa & ~mantisa_mask) == (1 << _mantisa_bits)) ? 
1 : 0); +} + +template +__device__ void clip(uint32_t& exponent, uint32_t& mantisa) +{ + constexpr uint32_t max_exponent = (1 << (q_exponent_bits - 1)) + (1 << (_exponent_bits - 1)); + constexpr uint32_t min_exponent = + (1 << (_exponent_bits - 1)) - ((1 << (q_exponent_bits - 1)) - 1); + if (exponent > max_exponent) { + exponent = max_exponent; + mantisa = (((uint32_t)-1) >> (32 - q_mantisa_bits)) << 1; //.11 .. 10 + } + if (exponent < min_exponent) { + exponent = min_exponent; + mantisa = 0; + } +} + +template +__global__ void apply_quantization(T* val, + uint8_t* q_val, + int group_size, + std::pair seed, + float q_range) +{ + int tidx = threadIdx.x; + int wid = tidx >> 5; + int lane = tidx & 0x1f; + int gid = blockIdx.x * quantization::warps + wid; + + constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1; + constexpr uint32_t _mantisa_mask = (1 << _mantisa_bits) - 1; + constexpr uint32_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits; + constexpr uint32_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits); + // CG helpers + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T); + constexpr uint32_t load_stride = vector_size * hw_warp_size; + constexpr uint32_t store_stride = (total_q_bits * vector_size / 8) * hw_warp_size; + const uint32_t thread_offset = lane * vector_size; + const uint32_t store_thread_offset = lane * (total_q_bits * vector_size / 8); + const uint32_t base_load_offset = gid * group_size + thread_offset; + const uint32_t base_store_offset = + gid * ((group_size * total_q_bits / 8) + 4) + + store_thread_offset; // 4-byte for saving the scale per group + const T* load_base_ptr = val + base_load_offset; + T tmp_buf[unroll * vector_size]; + T cur_max; + reduce::init(&cur_max); + + int idx = blockIdx.x * blockDim.x + threadIdx.x; + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + +#pragma unroll + for (int i = 0; i < unroll; i++) { + if (i * load_stride + thread_offset < group_size) { + mem_access::load_global( + &tmp_buf[vector_size * i], load_base_ptr + i * load_stride); + for (int j = 0; j < vector_size; j++) + cur_max = reduce::element(cur_max, __habs(tmp_buf[i * vector_size + j])); + } + } + reduce::_block(tb, warp, &cur_max); + + int mantisa_mask = ((1 << q_mantisa_bits) - 1); + mantisa_mask <<= (_mantisa_bits - q_mantisa_bits); + + uint8_t* store_base_ptr = q_val + base_store_offset; + float scale = (float)q_range / conversion::to(cur_max); +#pragma unroll + for (int i = 0; i < unroll; i++) { + if (i * load_stride + thread_offset < group_size) { + uint64_t q_buf = 0; + uint64_t q_buf1 = 0; +#pragma unroll + for (int j = 0; j < vector_size; j++) { + float val_f = conversion::to(tmp_buf[i * vector_size + j]) * scale; + uint32_t* data = reinterpret_cast(&val_f); + uint32_t sign = (data[0] & _sign_mask) >> (_mantisa_bits + _exponent_bits); + uint32_t cur_exponent = (data[0] & _exponent_mask) >> _mantisa_bits; + uint32_t dst_mantisa = (data[0] & _mantisa_mask); + + uint32_t dst_exponent = cur_exponent; + + round<_mantisa_bits, q_mantisa_bits, stochastic_rounding>( + dst_mantisa, dst_exponent, &state); + if (cur_exponent != 0) + clip<_mantisa_bits, _exponent_bits, q_mantisa_bits, q_exponent_bits>( + dst_exponent, dst_mantisa); + + dst_mantisa = (dst_mantisa & mantisa_mask) >> (_mantisa_bits - q_mantisa_bits); + + if (dst_exponent != (1 << q_exponent_bits) - 1) + dst_exponent = 
(dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) + + (1 << (q_exponent_bits - 1)) - 1; + if (total_q_bits == 8 || total_q_bits == 4 || total_q_bits == 6) + q_buf = q_buf | + ((uint64_t)((uint8_t)(sign << (q_exponent_bits + q_mantisa_bits) | + (dst_exponent << q_mantisa_bits) | dst_mantisa)) + << j * total_q_bits); + else if (total_q_bits == 12) { + if (j < 5) + q_buf = + q_buf | + ((uint64_t)((uint16_t)(sign << (q_exponent_bits + q_mantisa_bits) | + (dst_exponent << q_mantisa_bits) | dst_mantisa)) + << j * total_q_bits); + else + q_buf1 = + q_buf1 | + ((uint64_t)((uint16_t)(sign << (q_exponent_bits + q_mantisa_bits) | + (dst_exponent << q_mantisa_bits) | dst_mantisa)) + << (j - 5) * total_q_bits); + } + } + if (total_q_bits == 12) { + uint64_t last_nibble_mask = 0xf; + last_nibble_mask = q_buf1 & last_nibble_mask; + q_buf = (last_nibble_mask << 60) | q_buf; + q_buf1 >>= 4; + } + uint8_t* int8_data = reinterpret_cast(&q_buf); + uint8_t* int8_data1 = reinterpret_cast(&q_buf1); + if (total_q_bits == 6) { + mem_access::store_global( + store_base_ptr + i * store_stride, int8_data); + mem_access::store_global( + store_base_ptr + i * store_stride + + quantization::quanitzed_access_granularity_6bits, + int8_data + quantization::quanitzed_access_granularity_6bits); + mem_access::store_global( + store_base_ptr + i * store_stride + + quantization::quanitzed_access_granularity_6bits * 2, + int8_data + 2 * quantization::quanitzed_access_granularity_6bits); + } else { + mem_access::store_global( + store_base_ptr + i * store_stride, int8_data); + + if (total_q_bits > 4) { + mem_access::store_global( + store_base_ptr + i * store_stride + + quantization::quanitzed_access_granularity, + int8_data + quantization::quanitzed_access_granularity); + if (total_q_bits == 12) { + mem_access::store_global( + store_base_ptr + i * store_stride + + quantization::quanitzed_access_granularity * 2, + int8_data1); + } + } + } + } + } + if (lane == 0) { + float q_scale = conversion::to(cur_max) / (float)q_range; + uint8_t* scale_as_int8 = reinterpret_cast(&q_scale); + uint32_t scale_offset = + gid * ((group_size * total_q_bits / 8) + 4) + (group_size * total_q_bits / 8); + if (total_q_bits != 6) + mem_access::store_global( + q_val + scale_offset, scale_as_int8); + else { + mem_access::store_global( + q_val + scale_offset, scale_as_int8); + mem_access::store_global( + q_val + scale_offset + quantization::quanitzed_access_granularity_6bits, + scale_as_int8 + quantization::quanitzed_access_granularity_6bits); + } + } +} + +template +__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size) +{ + int tidx = threadIdx.x; + int wid = tidx >> 5; + int lane = tidx & 0x1f; + int gid = blockIdx.x * quantization::warps + wid; + constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1; + constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1; + constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1; + constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits; + constexpr uint16_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits); + + constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T); + constexpr uint32_t load_stride = vector_size * hw_warp_size; + const uint32_t thread_offset = lane * vector_size; + const uint32_t thread_load_offset = lane * vector_size * quantized_bits / 8; + const uint32_t base_load_offset = + gid * (group_size * quantized_bits / 8 + 4) + thread_load_offset; // 4-byte scale offset + const uint32_t base_store_offset = gid 
* group_size + thread_offset; + const uint8_t* load_base_ptr = val + base_load_offset; + + int mantisa_mask = ((1 << q_mantisa_bits) - 1); + mantisa_mask <<= (_mantisa_bits - q_mantisa_bits); + + T* store_base_ptr = q_val + base_store_offset; + float scale; //= q_scale[gid]; + + uint8_t* scale_as_int8 = reinterpret_cast(&scale); + if (quantized_bits == 6) { + mem_access::load_global( + scale_as_int8, + val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8)); + mem_access::load_global( + scale_as_int8 + quantization::quanitzed_access_granularity_6bits, + val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8) + + quantization::quanitzed_access_granularity_6bits); + } else + mem_access::load_global( + scale_as_int8, + val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8)); + +#pragma unroll + for (int i = 0; i < unroll; i++) { + if (i * load_stride + thread_offset < group_size) { + uint64_t q_buf_in; + uint64_t q_buf_in1; + uint8_t* int8_data = reinterpret_cast(&q_buf_in); + uint8_t* int8_data1 = reinterpret_cast(&q_buf_in1); + uint32_t loading_offset = i * load_stride * quantized_bits / 8; + if (quantized_bits == 6) { + mem_access::load_global( + int8_data, load_base_ptr + loading_offset); + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity_6bits, + load_base_ptr + loading_offset + + quantization::quanitzed_access_granularity_6bits); + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity_6bits * 2, + load_base_ptr + loading_offset + + quantization::quanitzed_access_granularity_6bits * 2); + } else { + mem_access::load_global( + int8_data, load_base_ptr + loading_offset); + if (quantized_bits > 4) { + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity, + load_base_ptr + loading_offset + + quantization::quanitzed_access_granularity); + if (quantized_bits == 12) { + mem_access::load_global( + int8_data1, + load_base_ptr + loading_offset + + quantization::quanitzed_access_granularity * 2); + } + } + } + T store_buf[vector_size]; + uint16_t* q_buf = reinterpret_cast(store_buf); +#pragma unroll + for (int j = 0; j < vector_size; j++) { + uint16_t new_data; + if (j < 5 || quantized_bits != 12) { + new_data = (uint16_t)(q_buf_in >> (j * quantized_bits)); + } else { + if (j == 5) { + new_data = (uint16_t)(q_buf_in1); + new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60)); + } else + new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8)); + } + + uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits); + uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits; + uint16_t dst_mantisa = (new_data & _mantisa_mask); + + if (dst_exponent != (1 << q_exponent_bits) - 1) + dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) + + (1 << (q_exponent_bits - 1)) - 1; + + q_buf[j] = ((sign << (q_exponent_bits + q_mantisa_bits)) | + (dst_exponent << q_mantisa_bits) | + (dst_mantisa << (q_mantisa_bits - _mantisa_bits))); + float up_cast = conversion::to(store_buf[j]); + store_buf[j] = conversion::to(up_cast * scale); + } + mem_access::store_global( + store_base_ptr + i * load_stride, store_buf); + } + } +} + +#define LAUNCH_FOR_QUANTIZATION_UNROLL(COUNT) \ + case COUNT: \ + apply_quantization \ + <<>>(val, q_val, group_size, seed, q_range); \ + break; + +template +void launch_quantization(T* val, + uint8_t* q_val, + int num_groups, + int group_size, + cudaStream_t stream, 
+ float q_range, + int q_bits, + int q_mantisa_bits, + int stochastic_rounding) +{ + const dim3 grid((num_groups + quantization::warps - 1) / quantization::warps); + const dim3 block(quantization::threads); + + std::pair seed = FPContext::Instance().IncrementOffset(16); + + constexpr int vals_per_unroll = hw_warp_size * quantization::access_granularity / sizeof(T); + + const int copy_unroll = (group_size + vals_per_unroll - 1) / vals_per_unroll; + QUANT_SWITCH((q_bits - q_mantisa_bits - 1) * q_mantisa_bits + stochastic_rounding, [&] { + switch (copy_unroll) { + LAUNCH_FOR_QUANTIZATION_UNROLL(1) + LAUNCH_FOR_QUANTIZATION_UNROLL(2) + LAUNCH_FOR_QUANTIZATION_UNROLL(3) + LAUNCH_FOR_QUANTIZATION_UNROLL(4) + LAUNCH_FOR_QUANTIZATION_UNROLL(5) + LAUNCH_FOR_QUANTIZATION_UNROLL(6) + } + }); +} +#define INSTANTIATE_LAUNCH_QUANTIZATION(T, mantisa, exponent) \ + template void launch_quantization( \ + T*, uint8_t*, int, int, cudaStream_t, float q_range, int, int, int); +// fp8(E4M3), nearest-rounding +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_QUANTIZATION(__nv_bfloat16, 23, 8); +#endif +INSTANTIATE_LAUNCH_QUANTIZATION(__half, 23, 8); + +#define LAUNCH_FOR_DEQUANTIZATION_UNROLL(COUNT) \ + case COUNT: \ + apply_dequantization \ + <<>>(val, q_val, group_size); \ + break; + +template +void launch_dequantization(uint8_t* val, + T* q_val, + int num_groups, + int group_size, + int q_mantisa_bits, + int q_exponent_bits, + cudaStream_t stream) +{ + const dim3 grid((num_groups + quantization::warps - 1) / quantization::warps); + const dim3 block(quantization::threads); + + constexpr int vals_per_unroll = hw_warp_size * quantization::access_granularity / sizeof(T); + const int copy_unroll = (group_size + vals_per_unroll - 1) / vals_per_unroll; + + DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] { + switch (copy_unroll) { + LAUNCH_FOR_DEQUANTIZATION_UNROLL(1) + LAUNCH_FOR_DEQUANTIZATION_UNROLL(2) + LAUNCH_FOR_DEQUANTIZATION_UNROLL(3) + LAUNCH_FOR_DEQUANTIZATION_UNROLL(4) + LAUNCH_FOR_DEQUANTIZATION_UNROLL(5) + LAUNCH_FOR_DEQUANTIZATION_UNROLL(6) + } + }); +} +#define INSTANTIATE_LAUNCH_DEQUANTIZATION(T, mantisa) \ + template void launch_dequantization(uint8_t*, T*, int, int, int, int, cudaStream_t); +// fp8(E4M3) +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_DEQUANTIZATION(__nv_bfloat16, 7); +#endif +INSTANTIATE_LAUNCH_DEQUANTIZATION(__half, 10); diff --git a/csrc/includes/memory_access_utils.h b/csrc/includes/memory_access_utils.h index 6789714d27c7..cb990b75bbe8 100644 --- a/csrc/includes/memory_access_utils.h +++ b/csrc/includes/memory_access_utils.h @@ -868,6 +868,35 @@ __device__ __forceinline__ void store_global<4, StorePolicy::CacheStreaming>(voi #endif } +template <> +__device__ __forceinline__ void store_global<2>(void* dst, const void* src) +{ + const int16_t* data = reinterpret_cast(src); + + int16_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +} + +template <> +__device__ __forceinline__ void store_global<2, StorePolicy::CacheGlobal>(void* dst, + const void* src) +{ + const int16_t* data = reinterpret_cast(src); + + int16_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +} + +template <> +__device__ __forceinline__ void store_global<2, StorePolicy::CacheStreaming>(void* dst, + const void* src) +{ + const int16_t* data = reinterpret_cast(src); + + int16_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +} + /////////// Store Shared /////////// template <> diff --git a/csrc/includes/reduction_utils.h b/csrc/includes/reduction_utils.h index 
eb8efab77ac1..eb9afb66a894 100644 --- a/csrc/includes/reduction_utils.h +++ b/csrc/includes/reduction_utils.h @@ -159,6 +159,12 @@ DS_D_INLINE float element(const float lhs, const float rhs) return lhs + rhs; } +template <> +DS_D_INLINE double element(const double lhs, const double rhs) +{ + return lhs + rhs; +} + template <> DS_D_INLINE float element(const float lhs, const float rhs) { @@ -189,6 +195,19 @@ DS_D_INLINE __half element(const __half lhs, const __half rhs) #endif } +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat16 element(const __nv_bfloat16 lhs, const __nv_bfloat16 rhs) +{ +#if __CUDA_ARCH__ >= 800 + // Intrinsic limited to Ampere + newer + return __hmax(lhs, rhs); +#else + return (lhs > rhs) ? lhs : rhs; +#endif +} +#endif + template <> DS_D_INLINE __half element(const __half lhs, const __half rhs) { @@ -220,6 +239,21 @@ DS_D_INLINE __half2 element(const __half2 lhs, const __half2 rhs) #endif } +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat162 element(const __nv_bfloat162 lhs, const __nv_bfloat162 rhs) +{ +#if __CUDA_ARCH__ >= 800 + return __hmax2(lhs, rhs); +#else + __nv_bfloat162 ret_val; + ret_val.x = (lhs.x > rhs.x) ? lhs.x : rhs.x; + ret_val.y = (lhs.y > rhs.y) ? lhs.y : rhs.y; + return ret_val; +#endif +} +#endif + template <> DS_D_INLINE __half2 element(const __half2 lhs, const __half2 rhs) { @@ -295,6 +329,11 @@ DS_D_INLINE float init() { return 0.0f; } +template <> +DS_D_INLINE double init() +{ + return (double)0.0f; +} template <> DS_D_INLINE float init() @@ -331,6 +370,15 @@ DS_D_INLINE __half init() return __half(neg_inf); } +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat16 init() +{ + constexpr __nv_bfloat16_raw neg_inf = {0xFF80}; + return __nv_bfloat16(neg_inf); +} +#endif + template <> DS_D_INLINE __half2 init() { diff --git a/deepspeed/ops/fp_quantizer/__init__.py b/deepspeed/ops/fp_quantizer/__init__.py new file mode 100644 index 000000000000..5575f3567185 --- /dev/null +++ b/deepspeed/ops/fp_quantizer/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .quantize import FP_Quantize diff --git a/deepspeed/ops/fp_quantizer/quantize.py b/deepspeed/ops/fp_quantizer/quantize.py new file mode 100644 index 000000000000..5dc3c190ae5d --- /dev/null +++ b/deepspeed/ops/fp_quantizer/quantize.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from deepspeed.ops.op_builder import FPQuantizerBuilder + +fp_quant_module = None + + +class FP_Quantize: + + def __init__(self, group_size=512) -> None: + global fp_quant_module + if fp_quant_module is None: + fp_quant_module = FPQuantizerBuilder().load() + + self.group_size = group_size + self.orig_dtype = None + + def quantize(self, + input, + q_bits=8, + q_mantisa_bits=3, + stochastic_mode=False, + return_meta_tensor=False) -> torch.Tensor: + assert input.dtype == torch.bfloat16, "only support bf16 for now" + if return_meta_tensor: + assert q_bits == 8, "meta tensor is only supported with q_bit=8" + + self.orig_dtype = input.dtype + self.orig_shape = input.shape + + if q_bits == 8: + pass + elif q_bits == 12: + q_mantisa_bits = 4 + elif q_bits == 6: + q_mantisa_bits = 2 + elif q_bits == 4: + q_mantisa_bits = 1 + else: + assert (0), \ + f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!" 
+ + out = fp_quant_module.quantize(input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits) + + if return_meta_tensor: + data, scale = out.split(self.group_size, dim=-1) + return data.contiguous().reshape(input.shape), scale.contiguous() + + return out + + def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor: + assert (self.orig_dtype is not None), \ + "[De-quantization Error]: you need to call quantize before dequantizing!" + fp_out = torch.empty(self.orig_shape, dtype=self.orig_dtype, + device=input_q.device) if fp_out is None else fp_out + if q_bits == 8: + pass + elif q_bits == 12: + q_mantisa_bits = 4 + elif q_bits == 6: + q_mantisa_bits = 2 + elif q_bits == 4: + q_mantisa_bits = 1 + else: + assert (0), \ + f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!" + + if scale is not None: + assert input_q.numel() == fp_out.numel(), \ + f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!' + input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous() + + fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1) + return fp_out diff --git a/op_builder/fp_quantizer.py b/op_builder/fp_quantizer.py new file mode 100644 index 000000000000..bafd3e0c33f6 --- /dev/null +++ b/op_builder/fp_quantizer.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CUDAOpBuilder, installed_cuda_version + + +class FPQuantizerBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_FP_QUANTIZER" + NAME = "fp_quantizer" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.fp_quantizer.{self.NAME}_op' + + def is_compatible(self, verbose=True): + try: + import torch + except ImportError: + self.warning("Please install torch if trying to pre-compile inference kernels") + return False + + cuda_okay = True + if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda + sys_cuda_major, _ = installed_cuda_version() + torch_cuda_major = int(torch.version.cuda.split('.')[0]) + cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda + if cuda_capability < 8: + self.warning("NVIDIA Inference is only supported on Ampere and newer architectures") + cuda_okay = False + if cuda_capability >= 8: + if torch_cuda_major < 11 or sys_cuda_major < 11: + self.warning("On Ampere and higher architectures please use CUDA 11+") + cuda_okay = False + return super().is_compatible(verbose) and cuda_okay + + def filter_ccs(self, ccs): + ccs_retained = [] + ccs_pruned = [] + for cc in ccs: + if int(cc[0]) >= 8: + ccs_retained.append(cc) + else: + ccs_pruned.append(cc) + if len(ccs_pruned) > 0: + self.warning(f"Filtered compute capabilities {ccs_pruned}") + return ccs_retained + + def sources(self): + return [ + "csrc/fp_quantizer/quantize.cu", + "csrc/fp_quantizer/quantize.cpp", + ] + + def extra_ldflags(self): + return ['-lcurand'] + + def include_paths(self): + return ['csrc/fp_quantizer/includes', 'csrc/includes'] diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index f28c1ecb165c..dd13ac163517 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -10,6 +10,7 @@ pytest<=8.0.0 pytest-forked pytest-randomly 
pytest-xdist +qtorch==0.3.0 recommonmark sphinx sphinx-rtd-theme diff --git a/tests/unit/ops/fp_quantizer/test_fp_quant.py b/tests/unit/ops/fp_quantizer/test_fp_quant.py new file mode 100644 index 000000000000..101f4cd69811 --- /dev/null +++ b/tests/unit/ops/fp_quantizer/test_fp_quant.py @@ -0,0 +1,94 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed + +from deepspeed.ops.fp_quantizer import FP_Quantize +from deepspeed.ops.op_builder import FPQuantizerBuilder + +if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]: + pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True) + +# warning: this import silently JIT builds a set of kernels and may take a minute +from qtorch.quant import float_quantize + + +def qtorch_quantize(input, exp_bits=4, man_bits=3, rounding="nearest", group_size=1024): + ori_dt = input.dtype + ori_shape = input.shape + last_dim = group_size + input = input.view(-1, last_dim) + + q_bits = exp_bits + man_bits + 1 + input_to_float = input.float() + if q_bits == 8: + q_range = 480. + elif q_bits == 6: + q_range = 28. + elif q_bits == 12: + q_range = 510. + else: + assert (0), \ + "Please specify the right quantization range for the selected precision!" + input_max = input_to_float.abs().amax(dim=-1, keepdim=True) + return ((float_quantize(input_to_float / input_max * q_range, exp_bits, man_bits, rounding=rounding) * \ + input_max / q_range).to(ori_dt)).reshape(ori_shape) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) +def test_fp_quant_meta(dtype): + group_size = 128 + q_bits = 8 + exp_bits = 4 + man_bits = 3 + + fpq = FP_Quantize(group_size=group_size) + for i in range(10): + x = torch.rand(4, 1024, dtype=dtype, device='cuda') + + ds_x = x.clone() + x_quantized, meta_tensor = fpq.quantize(ds_x, q_bits=q_bits, return_meta_tensor=True) + x_dequantized = fpq.dequantize(x_quantized, q_bits=q_bits, scale=meta_tensor) + + qtorch_out = qtorch_quantize(x, exp_bits=exp_bits, man_bits=man_bits, group_size=group_size) + qtorch_error = (qtorch_out - x).abs().sum() / x.numel() + ds_error = (x_dequantized - x).abs().sum() / x.numel() + + assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}" + + +@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) +@pytest.mark.parametrize("q_bits", [8, 6, 12], ids=["qbits8", "qbits6", "qbits12"]) +def test_fp_quant(dtype, q_bits): + group_size = 128 + fpq = FP_Quantize(group_size=group_size) + + for i in range(10): + x = torch.rand(4, 1024, dtype=dtype, device='cuda') + + ds_x = x.clone() + x_quantized = fpq.quantize(ds_x, q_bits=q_bits) + x_dequantized = fpq.dequantize(x_quantized, q_bits=q_bits) + + if q_bits == 8: + exp_bits = 4 + man_bits = 3 + elif q_bits == 6: + exp_bits = 3 + man_bits = 2 + elif q_bits == 12: + exp_bits = 4 + man_bits = 7 + else: + raise ValueError(f"unknown {q_bits=}") + + qtorch_out = qtorch_quantize(x, exp_bits=exp_bits, man_bits=man_bits, group_size=group_size) + + qtorch_error = (qtorch_out - x).abs().sum() / x.numel() + ds_error = (x_dequantized - x).abs().sum() / x.numel() + + assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}"
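
The unit tests above exercise the new op end to end; the snippet below is a minimal usage sketch of the `FP_Quantize` API added in `deepspeed/ops/fp_quantizer/quantize.py`. It assumes the `fp_quantizer` extension builds or JIT-loads (CUDA 11+ and compute capability 8.0+, per `op_builder/fp_quantizer.py`) and a bf16 CUDA input; the shapes, group size, and the final print are illustrative only.

```python
import torch

from deepspeed.ops.fp_quantizer import FP_Quantize

# Group size must evenly divide the number of elements; 128 matches the unit tests.
fpq = FP_Quantize(group_size=128)

# The quantize() path currently asserts bf16 inputs.
x = torch.rand(4, 1024, dtype=torch.bfloat16, device="cuda")

# Plain round trip: the quantized buffer carries one fp32 scale per group inline.
x_q = fpq.quantize(x.clone(), q_bits=8)   # fp8 (E4M3): 4 exponent + 3 mantissa bits
x_dq = fpq.dequantize(x_q, q_bits=8)
print("mean abs error:", (x_dq - x).abs().mean().item())

# Meta-tensor path (q_bits=8 only): payload and per-group scales come back
# separately, and the scale tensor is handed back to dequantize().
data, scale = fpq.quantize(x.clone(), q_bits=8, return_meta_tensor=True)
x_dq_meta = fpq.dequantize(data, q_bits=8, scale=scale)
```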
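
On the kernel side, `launch_quantization` picks its template instantiation through `QUANT_SWITCH` using the key `(q_bits - q_mantisa_bits - 1) * q_mantisa_bits + stochastic_rounding`, and `launch_dequantization` keys `DEQUANT_SWITCH` on `q_mantisa_bits * q_exponent_bits`. The sketch below simply evaluates those expressions for the formats wired up in `FP_Quantize` (note that the Python layer passes 4 mantissa bits for fp12, which the switch then resolves to the 7-mantissa-bit kernel constants); it is a standalone illustration, not part of the op.

```python
# Reproduces the dispatch keys fed into QUANT_SWITCH / DEQUANT_SWITCH above.
# (q_bits, q_mantisa_bits) pairs mirror what FP_Quantize passes to the C++ op.
formats = {
    "fp8 (E4M3)": (8, 3),
    "fp8 (E5M2)": (8, 2),
    "fp12": (12, 4),  # QUANT_SWITCH maps keys 28/29 to 7 mantissa bits
    "fp6": (6, 2),
    "fp4": (4, 1),
}

for name, (q_bits, q_mantisa_bits) in formats.items():
    q_exponent_bits = q_bits - q_mantisa_bits - 1
    for stochastic_rounding in (0, 1):
        quant_key = q_exponent_bits * q_mantisa_bits + stochastic_rounding
        print(f"{name}: QUANT_SWITCH key {quant_key} "
              f"(stochastic_rounding={stochastic_rounding})")
    print(f"{name}: DEQUANT_SWITCH key {q_mantisa_bits * q_exponent_bits}")
```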