
Commit 6dcb50c

jeffra authored and sfc-gh-reyazda committed
FP [6,8,12] quantizer op (deepspeedai#5336)

Flexible-bit quantizer-dequantizer library with fp6/fp12/fp8 support. Requires Ampere+ architecture; this is due to the op's initial focus on `bfloat16` input types.

Co-authored-by: Reza Yazdani <reza.yazdani@snowflake.com>
1 parent f6e9adb commit 6dcb50c
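
For context, here is a minimal host-side sketch of the intended round trip, assuming one compiles and links the op's C++ sources directly (rather than loading the pybind extension from Python). The declarations mirror csrc/fp_quantizer/quantize.cpp below; shapes, the group size, and the fp8 (e4m3) configuration are illustrative.

// Round-trip sketch (assumptions: libtorch, a CUDA device, and the compiled
// fp_quantizer sources linked in; declarations mirror quantize.cpp below).
#include <torch/torch.h>

at::Tensor quantize(torch::Tensor& val, int group_size, int stochastic_rounding,
                    int q_bits, int q_mantisa_bits);
void dequantize(torch::Tensor& val, torch::Tensor& val_q, int group_size,
                int q_mantisa_bits, int q_exponent_bits);

int main()
{
    // bfloat16 input: the op's initial focus, hence the Ampere+ requirement.
    auto x = torch::randn({32, 128}, torch::dtype(torch::kBFloat16).device(torch::kCUDA));

    // fp8 (e4m3): 8 total bits with 3 mantissa bits, deterministic rounding.
    auto packed = quantize(x, /*group_size=*/128, /*stochastic_rounding=*/0,
                           /*q_bits=*/8, /*q_mantisa_bits=*/3);

    // Recover a bfloat16 tensor of the original shape.
    auto y = torch::empty_like(x);
    dequantize(y, packed, /*group_size=*/128, /*q_mantisa_bits=*/3, /*q_exponent_bits=*/4);
    return 0;
}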

12 files changed: +1014 -1 lines changed

.github/workflows/nv-pre-compile-ops.yml

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ jobs:
           #python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
       - name: Compile DeepSpeed Ops
         run: |
-          DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
+          DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
       - name: DS Report
         run: |
           ds_report
Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime_api.h>
#include <cassert>
#include <iostream>
#include <vector>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"

#include <cuda.h>
#include <cuda_runtime_api.h>
#include <stdlib.h>
#include <sys/time.h>
#include <map>
#include <memory>
#include <stack>
#include <string>
#define WARP_SIZE 32

class FPContext {
public:
    FPContext() : _seed(42)
    {
        curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT);
        curandSetPseudoRandomGeneratorSeed(_gen, 123);
    }

    virtual ~FPContext() {}

    static FPContext& Instance()
    {
        static FPContext _ctx;
        return _ctx;
    }

    curandGenerator_t& GetRandGenerator() { return _gen; }

    cudaStream_t GetCurrentStream()
    {
        // get current pytorch stream.
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        return stream;
    }

    std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
    {
        uint64_t offset = _curr_offset;
        _curr_offset += offset_inc;
        return std::pair<uint64_t, uint64_t>(_seed, offset);
    }

    void SetSeed(uint64_t new_seed) { _seed = new_seed; }

private:
    curandGenerator_t _gen;
    cublasHandle_t _cublasHandle;
    uint64_t _seed;
    uint64_t _curr_offset;
};
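
FPContext is a process-wide singleton: it owns a cuRAND generator and hands callers the current PyTorch stream plus a (seed, offset) pair. A hedged sketch of a consumer follows; the function name and include path are hypothetical, since this file is not referenced by the other code in this commit's shown diff.

// Hypothetical consumer of FPContext (sketch only).
#include "context.h"  // hypothetical path for the header above
#include <cstdint>
#include <utility>

void example_launch_site(uint64_t rand_draws_needed)
{
    // Launch work on whatever stream PyTorch is currently using.
    cudaStream_t stream = FPContext::Instance().GetCurrentStream();

    // Reserve a contiguous offset range so that successive stochastic-rounding
    // calls consume non-overlapping random draws.
    std::pair<uint64_t, uint64_t> seed_offset =
        FPContext::Instance().IncrementOffset(rand_draws_needed);

    uint64_t seed = seed_offset.first;    // per-context seed (42 by default)
    uint64_t offset = seed_offset.second; // start of this call's range
    (void)stream; (void)seed; (void)offset; // a real kernel launch would use these
}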
Lines changed: 115 additions & 0 deletions

@@ -0,0 +1,115 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include <cuda.h>
#include <stdint.h>

#include <cuda_fp16.h>

#include <cuda_bf16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>

#define QUANT_SWITCH(Q_BITS, ...)                        \
    [&] {                                                \
        if (12 == Q_BITS) {                              \
            constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
            constexpr int CONST_Q_BITS = 8;              \
            constexpr int CONST_Q_MANTISA_BITS = 3;      \
            __VA_ARGS__();                               \
        } else if (13 == Q_BITS) {                       \
            constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
            constexpr int CONST_Q_BITS = 8;              \
            constexpr int CONST_Q_MANTISA_BITS = 3;      \
            __VA_ARGS__();                               \
        } else if (10 == Q_BITS) {                       \
            constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
            constexpr int CONST_Q_BITS = 8;              \
            constexpr int CONST_Q_MANTISA_BITS = 2;      \
            __VA_ARGS__();                               \
        } else if (11 == Q_BITS) {                       \
            constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
            constexpr int CONST_Q_BITS = 8;              \
            constexpr int CONST_Q_MANTISA_BITS = 2;      \
            __VA_ARGS__();                               \
        } else if (28 == Q_BITS) {                       \
            constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
            constexpr int CONST_Q_BITS = 12;             \
            constexpr int CONST_Q_MANTISA_BITS = 7;      \
            __VA_ARGS__();                               \
        } else if (29 == Q_BITS) {                       \
            constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
            constexpr int CONST_Q_BITS = 12;             \
            constexpr int CONST_Q_MANTISA_BITS = 7;      \
            __VA_ARGS__();                               \
        } else if (6 == Q_BITS) {                        \
            constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
            constexpr int CONST_Q_BITS = 6;              \
            constexpr int CONST_Q_MANTISA_BITS = 2;      \
            __VA_ARGS__();                               \
        } else if (7 == Q_BITS) {                        \
            constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
            constexpr int CONST_Q_BITS = 6;              \
            constexpr int CONST_Q_MANTISA_BITS = 2;      \
            __VA_ARGS__();                               \
        } else if (2 == Q_BITS) {                        \
            constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
            constexpr int CONST_Q_BITS = 4;              \
            constexpr int CONST_Q_MANTISA_BITS = 1;      \
            __VA_ARGS__();                               \
        } else {                                         \
            constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
            constexpr int CONST_Q_BITS = 4;              \
            constexpr int CONST_Q_MANTISA_BITS = 1;      \
            __VA_ARGS__();                               \
        }                                                \
    }()

#define DEQUANT_SWITCH(Q_MANTISA_EXPONENT_BITS, ...)     \
    [&] {                                                \
        if (12 == Q_MANTISA_EXPONENT_BITS) {             \
            constexpr int CONST_Q_MANTISA_BITS = 3;      \
            constexpr int CONST_Q_EXPONENT_BITS = 4;     \
            __VA_ARGS__();                               \
        } else if (10 == Q_MANTISA_EXPONENT_BITS) {      \
            constexpr int CONST_Q_MANTISA_BITS = 2;      \
            constexpr int CONST_Q_EXPONENT_BITS = 5;     \
            __VA_ARGS__();                               \
        } else if (28 == Q_MANTISA_EXPONENT_BITS) {      \
            constexpr int CONST_Q_MANTISA_BITS = 7;      \
            constexpr int CONST_Q_EXPONENT_BITS = 4;     \
            __VA_ARGS__();                               \
        } else if (6 == Q_MANTISA_EXPONENT_BITS) {       \
            constexpr int CONST_Q_MANTISA_BITS = 2;      \
            constexpr int CONST_Q_EXPONENT_BITS = 3;     \
            __VA_ARGS__();                               \
        } else {                                         \
            constexpr int CONST_Q_MANTISA_BITS = 1;      \
            constexpr int CONST_Q_EXPONENT_BITS = 2;     \
            __VA_ARGS__();                               \
        }                                                \
    }()

template <typename T, int mantisa, int exponent>
void launch_quantization(T* val,
                         uint8_t* q_val,
                         int num_groups,
                         int group_size,
                         cudaStream_t stream,
                         float q_range,
                         int q_bits,
                         int q_mantisa_bits,
                         int stochastic_rounding);

template <typename T, int mantisa>
void launch_dequantization(uint8_t* val,
                           T* q_val,
                           int num_groups,
                           int group_size,
                           int q_mantisa_bits,
                           int q_exponent_bits,
                           cudaStream_t stream);
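
These switch macros use the familiar lambda-dispatch pattern: a runtime code selects a branch that declares compile-time constants, then invokes the caller-supplied lambda, which can use those constants as template arguments. Note how the bit width and rounding mode are folded into one code, e.g. 12 selects fp8-e4m3 with deterministic rounding and 13 selects fp8-e4m3 with stochastic rounding. Below is a hedged sketch of a call site; run_quant_kernel and the include path are hypothetical stand-ins.

// Sketch of dispatching through QUANT_SWITCH (run_quant_kernel is hypothetical).
#include "quantize.h"  // hypothetical path for the header above
#include <cstdio>

template <int Q_BITS, int Q_MANTISA_BITS, int STOCHASTIC_ROUNDING>
void run_quant_kernel(const float* in, unsigned char* out, int n)
{
    // A real kernel would pack n values into Q_BITS-bit floats here.
    std::printf("q_bits=%d mantissa=%d stochastic=%d n=%d\n",
                Q_BITS, Q_MANTISA_BITS, STOCHASTIC_ROUNDING, n);
}

void dispatch_quant(int q_bits_code, const float* in, unsigned char* out, int n)
{
    // __VA_ARGS__ re-joins the template-argument commas, so the lambda passes
    // through the macro intact and runs inside the branch that defines the
    // CONST_* constants.
    QUANT_SWITCH(q_bits_code, [&] {
        run_quant_kernel<CONST_Q_BITS, CONST_Q_MANTISA_BITS, CONST_STOCHASTIC_ROUNDING>(
            in, out, n);
    });
}
// dispatch_quant(13, ...) would take the fp8-e4m3 stochastic-rounding branch.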

csrc/fp_quantizer/quantize.cpp

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "quantize.h"

#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <vector>

#define DISPATCH_QUANTIZE(T_TYPE, C_TYPE, mantisa, exponent)                      \
    if (val.options().dtype() == torch::T_TYPE) {                                 \
        launch_quantization<C_TYPE, mantisa, exponent>((C_TYPE*)val.data_ptr(),   \
                                                       (uint8_t*)out.data_ptr(),  \
                                                       num_groups,                \
                                                       group_size,                \
                                                       at::cuda::getCurrentCUDAStream(), \
                                                       q_range,                   \
                                                       q_bits,                    \
                                                       q_mantisa_bits,            \
                                                       stochastic_rounding);      \
    }

at::Tensor quantize(torch::Tensor& val,
                    int group_size,
                    int stochastic_rounding,
                    int q_bits,
                    int q_mantisa_bits)
{
    int total_elems = at::numel(val);
    auto options = at::TensorOptions()
                       .dtype(torch::kInt8)
                       .layout(val.layout())
                       .device(val.device())
                       .requires_grad(false);
    float q_range = q_bits == 8 ? (q_mantisa_bits == 3 ? 480.0 : 114688.0) :  // fp8 ranges
                        (q_bits == 12 ? 510.0 :                               // fp12 range
                             (q_bits == 6 ? 28.0 :                            // fp6 range
                                  6.0));  // fp4 range (using power 2); TODO (Reza): add the power-4
                                          // in case accuracy is not matching!
    int num_groups = total_elems / group_size;
    auto out = torch::empty({num_groups, group_size * q_bits / 8 + 4}, options);

    DISPATCH_QUANTIZE(kHalf, __half, 23, 8);
#ifdef BF16_AVAILABLE
    DISPATCH_QUANTIZE(kBFloat16, __nv_bfloat16, 23, 8);
#endif

    return out;
}

#define DISPATCH_DEQUANTIZE(T_TYPE, C_TYPE, mantisa)                              \
    if (val.options().dtype() == torch::T_TYPE) {                                 \
        launch_dequantization<C_TYPE, mantisa>((uint8_t*)val_q.data_ptr(),        \
                                               (C_TYPE*)val.data_ptr(),           \
                                               num_groups,                        \
                                               group_size,                        \
                                               q_mantisa_bits,                    \
                                               q_exponent_bits,                   \
                                               at::cuda::getCurrentCUDAStream()); \
        return;                                                                   \
    }

void dequantize(torch::Tensor& val,
                torch::Tensor& val_q,
                int group_size,
                int q_mantisa_bits,
                int q_exponent_bits)
{
    int total_elems = at::numel(val);

    int num_groups = total_elems / group_size;

    DISPATCH_DEQUANTIZE(kHalf, __half, 10);
#ifdef BF16_AVAILABLE
    DISPATCH_DEQUANTIZE(kBFloat16, __nv_bfloat16, 7);
#endif
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("quantize", &quantize, "quantize function");
    m.def("dequantize", &dequantize, "dequantize function");
}
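
One detail worth spelling out from quantize(): each group row of the output holds group_size * q_bits / 8 bytes of packed values plus 4 extra bytes, which presumably carry the per-group scale. A small worked check with illustrative sizes:

// Worked check of the packed output shape produced by quantize() above
// (sizes are illustrative; the trailing 4 bytes are assumed to hold the scale).
#include <cassert>
#include <cstdio>

int main()
{
    int total_elems = 4096;
    int group_size = 128;
    int q_bits = 8;  // fp8-e4m3

    int num_groups = total_elems / group_size;    // 4096 / 128 = 32 rows
    int row_bytes = group_size * q_bits / 8 + 4;  // 128 packed bytes + 4 = 132

    assert(num_groups == 32 && row_bytes == 132);
    std::printf("out tensor: {%d, %d} int8 elements\n", num_groups, row_bytes);
    return 0;
}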
