
Memory Access Utility #2276

Merged
merged 12 commits on Sep 1, 2022
665 changes: 665 additions & 0 deletions csrc/includes/memory_access_utils.h

Large diffs are not rendered by default.
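The new header itself is not shown in this view. As a rough orientation, the following is a minimal sketch of the interface that the call sites later in this diff appear to rely on; the signatures and the 16-byte specialization are inferred from usage in gelu.cu, not copied from the actual 665-line implementation.

// Illustrative sketch only: inferred from the gelu.cu call sites below,
// not the real contents of csrc/includes/memory_access_utils.h.
// Assumes the global pointers are 16-byte aligned for the vectorized case.
namespace mem_access {

// Load `granularity` bytes from global memory into registers with a single
// vectorized transaction (16 bytes corresponds to one uint4/float4 access).
template <int granularity>
__device__ __forceinline__ void load_global(void* dst, const void* src)
{
    static_assert(granularity == 16, "sketch covers only the 16-byte case");
    *reinterpret_cast<uint4*>(dst) = *reinterpret_cast<const uint4*>(src);
}

// Store `granularity` bytes from registers back to global memory.
template <int granularity>
__device__ __forceinline__ void store_global(void* dst, const void* src)
{
    static_assert(granularity == 16, "sketch covers only the 16-byte case");
    *reinterpret_cast<uint4*>(dst) = *reinterpret_cast<const uint4*>(src);
}

}  // namespace mem_access

In this scheme a thread stages granularity / sizeof(T) elements into a local array with load_global, works on them in registers, and writes them back with store_global, which is the pattern the gelu.cu changes below follow.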

6 changes: 5 additions & 1 deletion csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
@@ -1,4 +1,8 @@
#include "custom_cuda_layers.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#include "inference_cuda_layers.h"

#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
6 changes: 5 additions & 1 deletion csrc/transformer/inference/csrc/dequantize.cu
@@ -1,4 +1,8 @@
#include "custom_cuda_layers.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#include "inference_cuda_layers.h"

#define MAX_QUANTIZE_GROUPING 1024

92 changes: 42 additions & 50 deletions csrc/transformer/inference/csrc/gelu.cu
@@ -1,4 +1,9 @@
#include "custom_cuda_layers.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#include "inference_cuda_layers.h"
#include "memory_access_utils.h"

namespace cg = cooperative_groups;
#define MAX_CAP 4
@@ -16,25 +21,21 @@ __global__ void fused_bias_gelu(float* input,
int total_count,
int intermediate_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
// Input restriction: intermediate_size % vals_per_access == 0
constexpr int granularity = 16;
constexpr int vals_per_access = granularity / sizeof(float);
const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;

if (offset < total_count) {
float4 data = input_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];

data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
float data[vals_per_access];
float data_bias[vals_per_access];
mem_access::load_global<granularity>(data, input + offset);
mem_access::load_global<granularity>(data_bias, bias + (offset % intermediate_size));

data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
#pragma unroll
for (int i = 0; i < vals_per_access; i++) { data[i] = gelu(data[i] + data_bias[i]); }

input_cast[offset] = data;
mem_access::store_global<granularity>(input + offset, data);
}
}
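With this layout each thread covers vals_per_access = 16 / sizeof(float) = 4 contiguous elements, so offset advances in steps of four; for illustration, thread 1 of block 0 handles elements 4 through 7. Because intermediate_size must be a multiple of vals_per_access (the stated input restriction), the bias index offset % intermediate_size also stays a multiple of four, so the bias load can use the same 16-byte vectorized path.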

@@ -43,40 +44,28 @@ __global__ void fused_bias_gelu(__half* input,
int total_count,
int intermediate_size)
{
// Input restriction: intermediate_size % vals_per_access == 0
// This kernel doubles the per-thread ALU workload as compared to the float implementation
#ifdef HALF_PRECISION_AVAILABLE

float2* input_cast = reinterpret_cast<float2*>(input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);

int offset = blockIdx.x * blockDim.x + threadIdx.x;
constexpr int granularity = 16;
constexpr int vals_per_access = granularity / sizeof(__half);
int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;

if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 bias_vec = bias_cast[offset % intermediate_size];

__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);

float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);

float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);

low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;

low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);

vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
// Divide by 2 since we store two values per __half2
__half2 data[vals_per_access / 2];
__half2 bias_data[vals_per_access / 2];
mem_access::load_global<granularity>(data, input + offset);
mem_access::load_global<granularity>(bias_data, bias + (offset % intermediate_size));

#pragma unroll
for (int i = 0; i < vals_per_access / 2; i++) {
float2 data_f = __half22float2(data[i]);
float2 bias_f = __half22float2(bias_data[i]);
data[i] = __floats2half2_rn(gelu(data_f.x + bias_f.x), gelu(data_f.y + bias_f.y));
}

input_cast[offset] = vals_vec;
mem_access::store_global<granularity>(input + offset, data);
}
#endif
}
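In the __half path, vals_per_access = 16 / sizeof(__half) = 8, but the registers are declared as __half2, so the local arrays hold vals_per_access / 2 = 4 packed pairs. Each pair is widened to float2 for the bias add and GELU, then rounded back with __floats2half2_rn before the single 16-byte store.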
@@ -88,13 +77,16 @@ void launch_bias_gelu(T* input,
int batch_size,
cudaStream_t stream)
{
int total_count = batch_size * (intermediate_size / 4);
int threads = 1024; // intermediate_size / iterations / 4;
constexpr int threads = 1024;
constexpr int granularity = 16;

const int total_count = batch_size * intermediate_size;
const int elems_per_block = threads * (granularity / sizeof(T));
dim3 block_dims(threads);
dim3 grid_dims(((total_count - 1) / 1024 + 1)); // (batch_size);
dim3 grid_dims((total_count + elems_per_block - 1) / elems_per_block);

fused_bias_gelu<<<grid_dims, block_dims, 0, stream>>>(
input, bias, total_count, intermediate_size / 4);
input, bias, total_count, intermediate_size);
}
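Note that total_count is now measured in elements rather than in float4/float2 vectors, and each block covers elems_per_block = threads * (granularity / sizeof(T)) of them. As an illustrative example (numbers not taken from the PR): 8 tokens with intermediate_size = 4096 in fp16 give total_count = 32768 and elems_per_block = 1024 * 8 = 8192, so grid_dims = 4 blocks.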

template void launch_bias_gelu<float>(float*, const float*, int, int, cudaStream_t);
6 changes: 5 additions & 1 deletion csrc/transformer/inference/csrc/normalize.cu
@@ -1,5 +1,9 @@
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#include <limits>
#include "custom_cuda_layers.h"
#include "inference_cuda_layers.h"

#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
9 changes: 6 additions & 3 deletions csrc/transformer/inference/csrc/pt_binding.cpp
@@ -1,10 +1,13 @@
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <vector>
#include "context.h"
#include "cublas_wrappers.h"
#include "custom_cuda_layers.h"
#include "inference_context.h"
#include "inference_cublas_wrappers.h"
#include "inference_cuda_layers.h"

std::array<int, 3> gemm_algos = std::array<int, 3>({99, 99, 99});

6 changes: 5 additions & 1 deletion csrc/transformer/inference/csrc/relu.cu
@@ -1,4 +1,8 @@
#include "custom_cuda_layers.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#include "inference_cuda_layers.h"

#define MAX_CAP 4
#define MAX_SEQ 2048
6 changes: 5 additions & 1 deletion csrc/transformer/inference/csrc/softmax.cu
@@ -1,5 +1,9 @@
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#include <limits>
#include "custom_cuda_layers.h"
#include "inference_cuda_layers.h"

#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
6 changes: 5 additions & 1 deletion csrc/transformer/inference/csrc/transform.cu
@@ -1,7 +1,11 @@
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
#include "custom_cuda_layers.h"
#include "inference_cuda_layers.h"
namespace cg = cooperative_groups;

// Bias add
(file name not shown in this view)
@@ -1,3 +1,7 @@
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#pragma once

#include <c10/cuda/CUDAStream.h>
(file name not shown in this view)
@@ -1,3 +1,7 @@
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#pragma once

#include <assert.h>
(file name not shown in this view)
@@ -1,3 +1,7 @@
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#pragma once

#ifdef __HIP_PLATFORM_HCC__
2 changes: 1 addition & 1 deletion op_builder/transformer_inference.py
@@ -51,4 +51,4 @@ def extra_ldflags(self):
return []

def include_paths(self):
return ['csrc/transformer/inference/includes']
return ['csrc/transformer/inference/includes', 'csrc/includes']
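The extra 'csrc/includes' entry is what lets the inference kernels resolve the new memory_access_utils.h header at build time.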
4 changes: 4 additions & 0 deletions tests/unit/ops/transformer/inference/test_bias_gelu.py
@@ -1,3 +1,7 @@
"""
Copyright 2022 The Microsoft DeepSpeed Team
"""

import pytest
import torch
import deepspeed