Fix quantized-inference & Add generic support of checkpoint loading (#2547)

* fix checkpoint loading when it is a dictionary

* fix some issues with saving ckpt & int8 inference

* fix quantized-inference & add generic support of checkpoint loading

* remove int8 hard-coded flag

* fix mlp return tensors

* fix several issues when loading checkpoints of GPT-J, GPT-NEOX, and OPT with different TP sizes

* add more comments & description for checkpoint-loading module
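
The checkpoint-loading module referenced in the bullets above is not among the hunks shown below. As a rough, hypothetical sketch of what "generic support of checkpoint loading" can look like in practice (the function name `load_model_checkpoint` and the `"checkpoints"`/`"module"` keys are illustrative assumptions, not DeepSpeed's actual API), a loader might accept either a raw `torch.save` file or a JSON descriptor listing tensor-parallel shards:

```python
import json
import os

import torch


def load_model_checkpoint(path):
    """Minimal sketch: load either a raw torch save file or a JSON
    descriptor listing shard files. Hypothetical helper for illustration,
    not DeepSpeed's actual loader."""
    if path.endswith(".json"):
        with open(path) as f:
            desc = json.load(f)
        base = os.path.dirname(path)
        state_dict = {}
        # Merge the shards; later shards overwrite duplicate keys.
        for shard in desc["checkpoints"]:
            state_dict.update(torch.load(os.path.join(base, shard),
                                         map_location="cpu"))
        return state_dict
    ckpt = torch.load(path, map_location="cpu")
    # Some checkpoints wrap the actual weights under a "module" key.
    if isinstance(ckpt, dict) and "module" in ckpt:
        return ckpt["module"]
    return ckpt
```

A real loader would additionally split or concatenate individual tensors to match the target tensor-parallel degree; this sketch omits that step.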

Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
RezaYazdaniAminabadi and mrwyattii authored Dec 6, 2022
1 parent b841628 commit 35b350b
Showing 8 changed files with 473 additions and 134 deletions.
141 changes: 95 additions & 46 deletions csrc/transformer/inference/csrc/pt_binding.cpp
@@ -763,11 +763,18 @@ void quantized_gemm(void* output,
at::Tensor& weight,
at::Tensor& qscale,
int groups,
int bsz)
int bsz,
int hidden_size)
{
T* weight16 = (T*)Context::Instance().GetWorkSpace() +
12 * Context::Instance().GetMaxTokenLenght() * weight.size(1);

T* weight16 = (T*)Context::Instance().GetWorkSpace() + 12 * hidden_size * bsz;

// auto options = at::TensorOptions()
// .dtype(at::kHalf)
// .layout(at::kStrided)
// .device(at::kCUDA)
// .requires_grad(false);
// auto tmp = torch::empty(weight.sizes(), options);
// T* weight16 = (T*)tmp.data_ptr();
launch_dequantize(weight16,
(int8_t*)weight.data_ptr(),
(float*)qscale.data_ptr(),
@@ -814,7 +821,8 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
ds_layer_norm_internal<T>(workspace, input, gamma, beta, epsilon);

if (q_int8) {
quantized_gemm<T>(output.data_ptr(), workspace, weight, q_scale, q_scale.size(0), bsz);
quantized_gemm<T>(
output.data_ptr(), workspace, weight, q_scale, q_scale.size(0), bsz, input.size(2));
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
@@ -1202,15 +1210,19 @@ at::Tensor ds_vector_matmul(at::Tensor& input,
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);

int out_size = q_int8 ? weight.size(0) : weight.size(1);
int bsz = input.size(0) * input.size(1);

T* workspace = (T*)Context::Instance().GetWorkSpace();
auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options);
if (q_int8) {
quantized_gemm<T>(
output.data_ptr(), (T*)input.data_ptr(), weight, q_scale, q_scale.size(0), bsz);
quantized_gemm<T>(output.data_ptr(),
(T*)input.data_ptr(),
weight,
q_scale,
q_scale.size(0),
bsz,
input.size(2));
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
@@ -1293,9 +1305,9 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
} else {
ds_layer_norm_internal(inp_norm, input, gamma, beta, epsilon);
}

if (q_int8) {
quantized_gemm<T>(intermediate, inp_norm, weight, q_scale, q_scale.size(0), bsz);
quantized_gemm<T>(
intermediate, inp_norm, weight, q_scale, q_scale.size(0), bsz, input.size(2));
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
@@ -1331,9 +1343,15 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
bsz,
Context::Instance().GetCurrentStream());
}

if (q_int8) {
quantized_gemm<T>(
output.data_ptr(), intermediate, weight1, q_scale1, q_scale1.size(0), bsz);
quantized_gemm<T>(output.data_ptr(),
intermediate,
weight1,
q_scale1,
q_scale1.size(0),
bsz,
input.size(2));
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
@@ -1449,64 +1467,95 @@ std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input,
template <typename T>
at::Tensor fused_gemm_gelu(at::Tensor& input,
at::Tensor& weight,
at::Tensor& weight_scale,
at::Tensor& bias,
at::Tensor& weight_out,
at::Tensor& weight_out_scale,
const float epsilon,
bool preLayerNorm,
bool q_int8,
bool async_op)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.dtype(input.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);

auto intermediate =
at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight_out.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
int intm_dim = q_int8 ? weight.size(0) : weight.size(1);

// auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input),
// {input.size(0), input.size(1), out_size},
// options);
// T* intermediate = (T*)input.data_ptr() + torch::numel(input);
auto intermediate = at::empty({input.size(0), input.size(1), intm_dim}, options);

int bsz = input.size(0) * input.size(1);

float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)intermediate.data_ptr(),
if (q_int8) {
quantized_gemm<T>(intermediate.data_ptr(),
(T*)input.data_ptr(),
weight,
weight_scale,
weight_scale.size(0),
bsz,
input.size(2));
} else {
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
intm_dim,
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input.data_ptr(),
(T*)intermediate.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
launch_bias_gelu((T*)intermediate.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
intm_dim,
bsz,
Context::Instance().GetCurrentStream());

cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight_out.size(1),
bsz,
intermediate.size(2),
&alpha,
&gemm_beta,
(T*)weight_out.data_ptr(),
(T*)intermediate.data_ptr(),
(T*)output.data_ptr(),
int out_size = q_int8 ? weight_out.size(0) : weight_out.size(1);
auto output = at::empty({input.size(0), input.size(1), out_size}, options);
if (q_int8) {
quantized_gemm<T>(output.data_ptr(),
(T*)intermediate.data_ptr(),
weight_out,
weight_out_scale,
weight_out_scale.size(0),
bsz,
input.size(2));
} else {
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
out_size,
bsz,
intm_dim,
&alpha,
&gemm_beta,
(T*)weight_out.data_ptr(),
(T*)intermediate.data_ptr(),
(T*)output.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
// cudaEventRecord(Context::Instance().GetCompEvent(2),
// Context::Instance().GetCurrentStream(true));
return output;
@@ -68,7 +68,7 @@ def __init__(self,
merge_count,
mlp_extra_grouping)

device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu'
device = torch.cuda.current_device() #if config.bigscience_bloom else 'cpu'
self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size,
dtype=data_type,
device=device),
@@ -131,7 +131,6 @@ def forward(
if (self.config.fp16 or self.config.q_int8) \
and input.dtype == torch.float:
input = input.half()

with torch.no_grad():
attention_output, key, value, context_outputtn_ctx, inp_norm = \
self.attention(input,
33 changes: 30 additions & 3 deletions deepspeed/module_inject/layers.py
@@ -23,7 +23,7 @@ def forward(self, input):


class LinearLayer(nn.Module):
def __init__(self, weight_shape=None, dtype=None, weight=None, bias=None):
def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
super(LinearLayer, self).__init__()
if weight is not None:
self.weight = weight
@@ -33,10 +33,12 @@ def __init__(self, weight_shape=None, dtype=None, weight=None, bias=None):
torch.empty(weight_shape,
dtype=dtype,
device=torch.cuda.current_device()))

self.bias = Parameter(
torch.empty(weight_shape[0],
dtype=dtype,
device=torch.cuda.current_device()))
device=torch.cuda.current_device())) \
if bias is not None else None

def forward(self, input):
output = torch.matmul(input, self.weight.transpose(-1, -2))
@@ -57,7 +59,7 @@ def forward(self, input):


class EmbeddingLayer(nn.Module):
def __init__(self, weight_shape, dtype=torch.float):
def __init__(self, weight_shape, dtype=torch.half):
super(EmbeddingLayer, self).__init__()
self.weight = Parameter(
torch.empty(weight_shape[0],
@@ -67,3 +69,28 @@ def __init__(self, weight_shape, dtype=torch.float):

def forward(self, input):
return F.embedding(input, self.weight)


class OPTEmbedding(EmbeddingLayer):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, weight_shape):
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = 2
super().__init__(weight_shape)

def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
attention_mask = attention_mask.long()

# create positions depending on attention_mask
positions = (torch.cumsum(attention_mask,
dim=1).type_as(attention_mask) *
attention_mask).long() - 1

# cut positions if `past_key_values_length` is > 0
positions = positions[:, past_key_values_length:]

return super().forward(positions + self.offset)
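
For a quick sanity check of the position arithmetic in `OPTEmbedding.forward`, here is a small standalone PyTorch snippet (illustrative only; it assumes nothing beyond `torch` and reproduces the cumulative-sum trick on a left-padded batch):

```python
import torch

# Left-padded batch: the first two tokens of the second row are padding.
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [0, 0, 1, 1]])

# Same arithmetic as OPTEmbedding.forward: a running count of real tokens,
# masked so that padding slots collapse to -1 before the offset is added.
positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask)
             * attention_mask).long() - 1

print(positions)        # tensor([[ 0,  1,  2,  3], [-1, -1,  0,  1]])
print(positions + 2)    # ids actually looked up, after OPT's offset of 2
```

Padding slots compute to -1 before the offset, so after `+ self.offset` they all map to index 1, while real positions start at index 2.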