Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unit test for bias add kernel #2298

Merged
merged 6 commits into from
Sep 9, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions csrc/transformer/inference/csrc/gelu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ void launch_bias_gelu(T* input,
template void launch_bias_gelu<float>(float*, const float*, int, int, cudaStream_t);
template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStream_t);

// Not called directly from DeepSpeed, but used in ds_qkv_gemm_int8, ds_linear_layer, etc.
__global__ void fused_bias_add(float* input, const float* bias, int total_count, int hidden_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
Expand Down
26 changes: 22 additions & 4 deletions csrc/transformer/inference/csrc/pt_binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,22 @@ at::Tensor ds_bias_relu(at::Tensor& input, at::Tensor& bias)
return input_cont;
}

template <typename T>
at::Tensor ds_bias_add(at::Tensor& input, at::Tensor& bias)
cmikeh2 marked this conversation as resolved.
Show resolved Hide resolved
{
auto input_cont = input.contiguous();

int bsz = input_cont.size(0) * input_cont.size(1);
int hidden_size = input_cont.size(2);

launch_bias_add((T*)input_cont.data_ptr(),
(T*)bias.data_ptr(),
hidden_size,
bsz,
Context::Instance().GetCurrentStream());
return input_cont;
}

template <typename T>
at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& bias)
{
Expand Down Expand Up @@ -1323,25 +1339,27 @@ at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& out
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("softmax_fp32", &ds_softmax<float>, "DeepSpeed SoftMax with fp32 (CUDA)");
m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp32 (CUDA)");
m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp16 (CUDA)");
m.def(
"softmax_context_fp32", &ds_softmax_context<float>, "DeepSpeed attention with fp32 (CUDA)");
m.def("softmax_context_fp16",
&ds_softmax_context<__half>,
"DeepSpeed attention with fp32 (CUDA)");
"DeepSpeed attention with fp16 (CUDA)");
m.def("softmax_context_int8",
&ds_softmax_context1<__half>,
"DeepSpeed attention with fp32 (CUDA)");
"DeepSpeed attention with int8 (CUDA)");
m.def("bias_gelu_fp32", &ds_bias_gelu<float>, "DeepSpeed Gelu with fp32 (CUDA)");
m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp16 (CUDA)");
m.def("bias_add_fp32", &ds_bias_add<float>, "DeepSpeed Bias Add with fp32 (CUDA)");
m.def("bias_add_fp16", &ds_bias_add<__half>, "DeepSpeed Gelu with fp16 (CUDA)");
m.def("bias_relu_fp32", &ds_bias_relu<float>, "DeepSpeed ReLU with fp32 (CUDA)");
m.def("bias_relu_fp16", &ds_bias_relu<__half>, "DeepSpeed ReLU with fp16 (CUDA)");
m.def("bias_residual_fp32",
&ds_bias_residual<float>,
"DeepSpeed residual-bias add with fp32 (CUDA)");
m.def("bias_residual_fp16",
&ds_bias_residual<__half>,
"DeepSpeed residual-bias add with fp32 (CUDA)");
"DeepSpeed residual-bias add with fp16 (CUDA)");
m.def("layer_norm_fp32", &ds_layernorm<float>, "DeepSpeed layer-norm with fp32 (CUDA)");
m.def("layer_norm_fp16", &ds_layernorm<__half>, "DeepSpeed layer-norm with fp16 (CUDA)");
m.def("qkv_gemm_fp32", &ds_qkv_gemm<float>, "DeepSpeed qkv gemm with fp32 (CUDA)");
Expand Down
49 changes: 49 additions & 0 deletions tests/unit/ops/transformer/inference/test_bias_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
import torch
import deepspeed
from deepspeed.ops.op_builder import InferenceBuilder

if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system",
allow_module_level=True)

inference_module = None
torch_minor_version = None


def allclose(x, y):
assert x.dtype == y.dtype
rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}[x.dtype]
return torch.allclose(x, y, rtol=rtol, atol=atol)


def run_bias_add_reference(activations, bias):
# Expected behavior is that of casting to float32 internally and using the tanh approximation
mrwyattii marked this conversation as resolved.
Show resolved Hide resolved
return activations + bias


def run_bias_add_ds(activations, bias):
global inference_module
if inference_module is None:
inference_module = InferenceBuilder().load()
if activations.dtype == torch.float16:
return inference_module.bias_add_fp16(activations, bias)
else:
return inference_module.bias_add_fp32(activations, bias)


@pytest.mark.inference
@pytest.mark.parametrize("batch", [1, 2])
@pytest.mark.parametrize("sequence", [1, 128, 255])
@pytest.mark.parametrize("channels", [512, 1232, 4096])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
def test_bias_add(batch, sequence, channels, dtype):
activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device='cuda')
bias_ds = torch.randn((channels), dtype=dtype, device='cuda')

activations_ref = activations_ds.clone().detach()
bias_ref = bias_ds.clone().detach()

ds_out = run_bias_add_ds(activations_ds, bias_ds)
ref_out = run_bias_add_reference(activations_ref, bias_ref)
assert allclose(ds_out, ref_out)