2 changes: 1 addition & 1 deletion .github/workflows/nv-pre-compile-ops.yml
@@ -36,7 +36,7 @@ jobs:
#python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Compile DeepSpeed Ops
run: |
DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
- name: DS Report
run: |
ds_report
66 changes: 66 additions & 0 deletions csrc/fp_quantizer/includes/context.h
@@ -0,0 +1,66 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime_api.h>
#include <cassert>
#include <iostream>
#include <vector>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"

#include <cuda.h>
#include <cuda_runtime_api.h>
#include <stdlib.h>
#include <sys/time.h>
#include <map>
#include <memory>
#include <stack>
#include <string>
#define WARP_SIZE 32

class FPContext {
public:
FPContext() : _seed(42), _curr_offset(0)
{
curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT);
curandSetPseudoRandomGeneratorSeed(_gen, 123);
}

virtual ~FPContext() {}

static FPContext& Instance()
{
static FPContext _ctx;
return _ctx;
}

curandGenerator_t& GetRandGenerator() { return _gen; }

cudaStream_t GetCurrentStream()
{
// get current pytorch stream.
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
return stream;
}

std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
{
uint64_t offset = _curr_offset;
_curr_offset += offset_inc;
return std::pair<uint64_t, uint64_t>(_seed, offset);
}

void SetSeed(uint64_t new_seed) { _seed = new_seed; }

private:
curandGenerator_t _gen;
cublasHandle_t _cublasHandle;
uint64_t _seed;
uint64_t _curr_offset;
};
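
FPContext is a lazily constructed singleton: the first call to Instance() creates the cuRAND generator, and IncrementOffset() hands out a disjoint (seed, offset) window that a stochastic-rounding kernel can feed to a counter-based RNG without reusing random draws. Below is a minimal usage sketch under that reading; launch_with_fresh_rng is an illustrative caller, not part of this header:

```cpp
// Illustrative caller (not in the PR): reserves a private slice of the RNG
// sequence for one kernel launch that performs stochastic rounding.
#include <cstdint>
#include <utility>
#include "context.h"

void launch_with_fresh_rng(int num_threads, uint64_t draws_per_thread)
{
    cudaStream_t stream = FPContext::Instance().GetCurrentStream();

    // (seed, start_offset): each launch gets a non-overlapping offset range,
    // assuming every thread consumes at most draws_per_thread random values.
    std::pair<uint64_t, uint64_t> seed_offset =
        FPContext::Instance().IncrementOffset(uint64_t(num_threads) * draws_per_thread);

    // seed_offset and stream would then be forwarded to the quantization kernel.
    (void)stream;
    (void)seed_offset;
}
```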
115 changes: 115 additions & 0 deletions csrc/fp_quantizer/includes/quantize.h
@@ -0,0 +1,115 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include <cuda.h>
#include <stdint.h>

#include <cuda_fp16.h>

#include <cuda_bf16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>

#define QUANT_SWITCH(Q_BITS, ...) \
[&] { \
if (12 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
constexpr int CONST_Q_BITS = 8; \
constexpr int CONST_Q_MANTISA_BITS = 3; \
__VA_ARGS__(); \
} else if (13 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
constexpr int CONST_Q_BITS = 8; \
constexpr int CONST_Q_MANTISA_BITS = 3; \
__VA_ARGS__(); \
} else if (10 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
constexpr int CONST_Q_BITS = 8; \
constexpr int CONST_Q_MANTISA_BITS = 2; \
__VA_ARGS__(); \
} else if (11 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
constexpr int CONST_Q_BITS = 8; \
constexpr int CONST_Q_MANTISA_BITS = 2; \
__VA_ARGS__(); \
} else if (28 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
constexpr int CONST_Q_BITS = 12; \
constexpr int CONST_Q_MANTISA_BITS = 7; \
__VA_ARGS__(); \
} else if (29 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
constexpr int CONST_Q_BITS = 12; \
constexpr int CONST_Q_MANTISA_BITS = 7; \
__VA_ARGS__(); \
} else if (6 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
constexpr int CONST_Q_BITS = 6; \
constexpr int CONST_Q_MANTISA_BITS = 2; \
__VA_ARGS__(); \
} else if (7 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
constexpr int CONST_Q_BITS = 6; \
constexpr int CONST_Q_MANTISA_BITS = 2; \
__VA_ARGS__(); \
} else if (2 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
constexpr int CONST_Q_BITS = 4; \
constexpr int CONST_Q_MANTISA_BITS = 1; \
__VA_ARGS__(); \
} else { \
constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
constexpr int CONST_Q_BITS = 4; \
constexpr int CONST_Q_MANTISA_BITS = 1; \
__VA_ARGS__(); \
} \
}()

#define DEQUANT_SWITCH(Q_MANTISA_EXPONENT_BITS, ...) \
[&] { \
if (12 == Q_MANTISA_EXPONENT_BITS) { \
constexpr int CONST_Q_MANTISA_BITS = 3; \
constexpr int CONST_Q_EXPONENT_BITS = 4; \
__VA_ARGS__(); \
} else if (10 == Q_MANTISA_EXPONENT_BITS) { \
constexpr int CONST_Q_MANTISA_BITS = 2; \
constexpr int CONST_Q_EXPONENT_BITS = 5; \
__VA_ARGS__(); \
} else if (28 == Q_MANTISA_EXPONENT_BITS) { \
constexpr int CONST_Q_MANTISA_BITS = 7; \
constexpr int CONST_Q_EXPONENT_BITS = 4; \
__VA_ARGS__(); \
} else if (6 == Q_MANTISA_EXPONENT_BITS) { \
constexpr int CONST_Q_MANTISA_BITS = 2; \
constexpr int CONST_Q_EXPONENT_BITS = 3; \
__VA_ARGS__(); \
} else { \
constexpr int CONST_Q_MANTISA_BITS = 1; \
constexpr int CONST_Q_EXPONENT_BITS = 2; \
__VA_ARGS__(); \
} \
}()

template <typename T, int mantisa, int exponent>
void launch_quantization(T* val,
uint8_t* q_val,
int num_groups,
int group_size,
cudaStream_t stream,
float q_range,
int q_bits,
int q_mantisa_bits,
int stochastic_rounding);

template <typename T, int mantisa>
void launch_dequantization(uint8_t* val,
T* q_val,
int num_groups,
int group_size,
int q_mantisa_bits,
int q_exponent_bits,
cudaStream_t stream);
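
Note that the QUANT_SWITCH selector is not the raw bit width: each branch above corresponds to q_mantisa_bits × q_exponent_bits (with q_exponent_bits = q_bits − q_mantisa_bits − 1 for the sign bit), plus 1 when stochastic rounding is requested, and DEQUANT_SWITCH keys on the same product without the rounding bit. The sketch below shows that encoding; quant_selector and dequant_selector are illustrative helpers derived from the branches, not functions defined in this header:

```cpp
// Illustrative only: how the QUANT_SWITCH / DEQUANT_SWITCH selectors in
// quantize.h decode, assuming q_bits = 1 sign bit + exponent + mantissa bits.
#include <cassert>

// Selector for QUANT_SWITCH: mantissa_bits * exponent_bits, +1 for stochastic rounding.
inline int quant_selector(int q_bits, int q_mantisa_bits, bool stochastic_rounding)
{
    const int q_exponent_bits = q_bits - q_mantisa_bits - 1;
    return q_mantisa_bits * q_exponent_bits + (stochastic_rounding ? 1 : 0);
}

// Selector for DEQUANT_SWITCH: the same product, with no rounding bit.
inline int dequant_selector(int q_mantisa_bits, int q_exponent_bits)
{
    return q_mantisa_bits * q_exponent_bits;
}

int main()
{
    assert(quant_selector(8, 3, false) == 12);   // fp8  e4m3, nearest rounding
    assert(quant_selector(8, 2, true) == 11);    // fp8  e5m2, stochastic rounding
    assert(quant_selector(12, 7, false) == 28);  // fp12 e4m7
    assert(quant_selector(6, 2, true) == 7);     // fp6  e3m2
    assert(quant_selector(4, 1, false) == 2);    // fp4  e2m1
    assert(dequant_selector(3, 4) == 12);        // fp8  e4m3
    return 0;
}
```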
85 changes: 85 additions & 0 deletions csrc/fp_quantizer/quantize.cpp
@@ -0,0 +1,85 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "quantize.h"

#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <vector>

#define DISPATCH_QUANTIZE(T_TYPE, C_TYPE, mantisa, exponent) \
if (val.options().dtype() == torch::T_TYPE) { \
launch_quantization<C_TYPE, mantisa, exponent>((C_TYPE*)val.data_ptr(), \
(uint8_t*)out.data_ptr(), \
num_groups, \
group_size, \
at::cuda::getCurrentCUDAStream(), \
q_range, \
q_bits, \
q_mantisa_bits, \
stochastic_rounding); \
}

at::Tensor quantize(torch::Tensor& val,
int group_size,
int stochastic_rounding,
int q_bits,
int q_mantisa_bits)
{
int total_elems = at::numel(val);
auto options = at::TensorOptions()
.dtype(torch::kInt8)
.layout(val.layout())
.device(val.device())
.requires_grad(false);
float q_range = q_bits == 8 ? (q_mantisa_bits == 3 ? 480.0 : 114688.0) : // fp8 ranges
(q_bits == 12 ? 510.0 : // fp12 range
(q_bits == 6 ? 28.0 : // fp6 range
6.0)); // fp4 range (using power 2); TODO (Reza): add the power-4
// in case accuracy is not matching!
int num_groups = total_elems / group_size;
auto out = torch::empty({num_groups, group_size * q_bits / 8 + 4}, options);

DISPATCH_QUANTIZE(kHalf, __half, 23, 8);
#ifdef BF16_AVAILABLE
DISPATCH_QUANTIZE(kBFloat16, __nv_bfloat16, 23, 8);
#endif

return out;
}

#define DISPATCH_DEQUANTIZE(T_TYPE, C_TYPE, mantisa) \
if (val.options().dtype() == torch::T_TYPE) { \
launch_dequantization<C_TYPE, mantisa>((uint8_t*)val_q.data_ptr(), \
(C_TYPE*)val.data_ptr(), \
num_groups, \
group_size, \
q_mantisa_bits, \
q_exponent_bits, \
at::cuda::getCurrentCUDAStream()); \
return; \
}

void dequantize(torch::Tensor& val,
torch::Tensor& val_q,
int group_size,
int q_mantisa_bits,
int q_exponent_bits)
{
int total_elems = at::numel(val);

int num_groups = total_elems / group_size;

DISPATCH_DEQUANTIZE(kHalf, __half, 10);
#ifdef BF16_AVAILABLE
DISPATCH_DEQUANTIZE(kBFloat16, __nv_bfloat16, 7);
#endif
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("quantize", &quantize, "quantize function");
m.def("dequantize", &dequantize, "dequantize function");
}
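
The q_range values in quantize() are the largest finite value of each target format when the all-ones exponent code is treated as a normal number (hence 480 for fp8 e4m3 rather than the 448 of the Inf/NaN-reserving variant); the extra 4 bytes per group in the output tensor presumably carry per-group scale metadata. The sketch below, not part of the extension, reproduces those constants from the exponent/mantissa widths; fp_max is an illustrative helper:

```cpp
// Sketch (not in the PR): derives the q_range constants used in quantize().
// max = (2 - 2^-mantissa_bits) * 2^(2^(exponent_bits - 1)), assuming the
// all-ones exponent code encodes normal values rather than Inf/NaN.
#include <cassert>
#include <cmath>

float fp_max(int exponent_bits, int mantisa_bits)
{
    const float top_mantissa = 2.0f - std::ldexp(1.0f, -mantisa_bits);
    const int top_exponent = 1 << (exponent_bits - 1);  // (2^e - 1) - bias, bias = 2^(e-1) - 1
    return std::ldexp(top_mantissa, top_exponent);
}

int main()
{
    assert(fp_max(4, 3) == 480.0f);     // fp8  e4m3
    assert(fp_max(5, 2) == 114688.0f);  // fp8  e5m2
    assert(fp_max(4, 7) == 510.0f);     // fp12 e4m7
    assert(fp_max(3, 2) == 28.0f);      // fp6  e3m2
    assert(fp_max(2, 1) == 6.0f);       // fp4  e2m1
    return 0;
}
```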