diff --git a/csrc/includes/memory_access_utils.h b/csrc/includes/memory_access_utils.h
new file mode 100644
index 000000000000..548f9f5c7d3b
--- /dev/null
+++ b/csrc/includes/memory_access_utils.h
@@ -0,0 +1,665 @@
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+*/
+
+#pragma once
+
+#include <cuda.h>
+
+#if __CUDA_ARCH__ >= 800
+#define ASYNC_COPY_AVAILABLE
+#endif
+
+// Discuss basic GPUs to support
+#if __CUDA_ARCH__ >= 530
+#define PTX_AVAILABLE
+#endif
+
+/////////////////////////////// Memory Access Utils ///////////////////////////////
+namespace mem_access {
+
+enum class LoadPolicy {
+    CacheAll,       // Cache at all levels
+    CacheGlobal,    // Cache at L2 only
+    CacheStreaming  // Cache with evict first policy
+};
+
+enum class StorePolicy {
+    Writeback,      // Cache in L1, write-back on eviction
+    CacheGlobal,    // Bypass L1, write-back on eviction
+    CacheStreaming  // Allocate cache line with evict first policy
+};
+
+template <int AccessSize, LoadPolicy policy = LoadPolicy::CacheAll>
+__device__ __forceinline__ void load_global(void* dst, const void* src);
+
+// Shared accesses have no cache policy
+template <int AccessSize>
+__device__ __forceinline__ void load_shared(void* dst, const void* src);
+
+template <int AccessSize, StorePolicy policy = StorePolicy::Writeback>
+__device__ __forceinline__ void store_global(void* dst, const void* src);
+
+// Shared accesses have no cache policy
+template <int AccessSize>
+__device__ __forceinline__ void store_shared(void* dst, const void* src);
+
+#ifdef ASYNC_COPY_AVAILABLE
+template <int AccessSize>
+__device__ __forceinline__ void memcpy_async(void* shr, const void* gbl);
+
+template <int AccessSize>
+__device__ __forceinline__ void memcpy_async_nop(void* shr, const void* gbl, bool predicate);
+
+template <int AccessSize>
+__device__ __forceinline__ void memcpy_async_zero(void* shr, const void* gbl, bool predicate);
+
+__device__ __forceinline__ void memcpy_async_fence();
+
+template <int stages>
+__device__ __forceinline__ void memcpy_async_wait();
+
+template <int stages>
+__device__ __forceinline__ void tail_complete_wait(int remaining_stages);
+#endif
+
+// Util for tracking pipeline buffers
+// TODO: Evaluate whether this should also be guarded by ASYNC_COPY_AVAILABLE
+template <int max>
+class BufferTracker {
+public:
+    int current_state;
+
+    __device__ __forceinline__ BufferTracker() : current_state(0) {}
+
+    __device__ __forceinline__ int get()
+    {
+        int return_val = current_state++;
+        current_state = (current_state == max ? 0 : current_state);
+        return return_val;
+    }
+};
+
+__device__ __forceinline__ uint32_t lane_id()
+{
+#ifdef PTX_AVAILABLE
+    unsigned int lane_id;
+    asm volatile("mov.u32 %0, %%laneid;" : "=r"(lane_id));
+    return lane_id;
+#else
+    return threadIdx.x & (warpSize - 1);  // Portable
+#endif
+}
+
+/////////// Load Global ///////////
+template <>
+__device__ __forceinline__ void load_global<16>(void* dst, const void* src)
+{
+    uint4* data = reinterpret_cast<uint4*>(dst);
+#ifdef PTX_AVAILABLE
+    asm volatile("ld.global.ca.v4.u32 {%0, %1, %2, %3}, [%4];\n"
+                 : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
+                 : "l"(src));
+#else
+    const uint4* src_cast = reinterpret_cast<const uint4*>(src);
+    data[0] = src_cast[0];
+#endif
+}
+
+template <>
+__device__ __forceinline__ void load_global<16, LoadPolicy::CacheGlobal>(void* dst, const void* src)
+{
+    uint4* data = reinterpret_cast<uint4*>(dst);
+#ifdef PTX_AVAILABLE
+    asm volatile("ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];\n"
+                 : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
+                 : "l"(src));
+#else
+    const uint4* src_cast = reinterpret_cast<const uint4*>(src);
+    data[0] = src_cast[0];
+#endif
+}
+
+template <>
+__device__ __forceinline__ void load_global<16, LoadPolicy::CacheStreaming>(void* dst,
+                                                                            const void* src)
+{
+    uint4* data = reinterpret_cast<uint4*>(dst);
+#ifdef PTX_AVAILABLE
+    asm volatile("ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];\n"
+                 : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
+                 : "l"(src));
+#else
+    const uint4* src_cast = reinterpret_cast<const uint4*>(src);
+    data[0] = src_cast[0];
+#endif
+}
+
+template <>
+__device__ __forceinline__ void load_global<8>(void* dst, const void* src)
+{
+    uint2* data = reinterpret_cast<uint2*>(dst);
+#ifdef PTX_AVAILABLE
+    asm volatile("ld.global.ca.v2.u32 {%0, %1}, [%2];\n"
+                 : "=r"(data[0].x), "=r"(data[0].y)
+                 : "l"(src));
+#else
+    const uint2* src_cast = reinterpret_cast<const uint2*>(src);
+    data[0] = src_cast[0];
+#endif
+}
+
+template <>
+__device__ __forceinline__ void load_global<8, LoadPolicy::CacheGlobal>(void* dst, const void* src)
+{
+    uint2* data = reinterpret_cast<uint2*>(dst);
+#ifdef PTX_AVAILABLE
+    asm volatile("ld.global.cg.v2.u32 {%0, %1}, [%2];\n"
+                 : "=r"(data[0].x), "=r"(data[0].y)
+                 : "l"(src));
+#else
+    const uint2* src_cast = reinterpret_cast<const uint2*>(src);
+    data[0] = src_cast[0];
+#endif
+}
+
+template <>
+__device__ __forceinline__ void load_global<8, LoadPolicy::CacheStreaming>(void* dst,
+                                                                           const void* src)
+{
+    uint2* data = reinterpret_cast<uint2*>(dst);
+#ifdef PTX_AVAILABLE
+    asm volatile("ld.global.cs.v2.u32 {%0, %1}, [%2];\n"
+                 : "=r"(data[0].x), "=r"(data[0].y)
+                 : "l"(src));
+#else
+    const uint2* src_cast = reinterpret_cast<const uint2*>(src);
+    data[0] = src_cast[0];
+#endif
+}
+
+template <>
+__device__ __forceinline__ void load_global<4>(void* dst, const void* src)
+{
+    int32_t* data = reinterpret_cast<int32_t*>(dst);
+#ifdef PTX_AVAILABLE
+    asm volatile("ld.global.ca.u32 {%0}, [%1];\n" : "=r"(*data) : "l"(src));
+#else
+    const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
+    data[0] = src_cast[0];
+#endif
+}
+
+template <>
+__device__ __forceinline__ void load_global<4, LoadPolicy::CacheGlobal>(void* dst, const void* src)
+{
+    int32_t* data = reinterpret_cast<int32_t*>(dst);
+#ifdef PTX_AVAILABLE
+    asm volatile("ld.global.cg.u32 {%0}, [%1];\n" : "=r"(*data) : "l"(src));
+#else
+    const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
+    data[0] = src_cast[0];
+#endif
+}
+
+template <>
+__device__ __forceinline__ void load_global<4, LoadPolicy::CacheStreaming>(void* dst,
+                                                                           const void* src)
+{
+    int32_t* data = reinterpret_cast<int32_t*>(dst);
+#ifdef PTX_AVAILABLE
+    asm volatile("ld.global.cs.u32 {%0}, [%1];\n" : "=r"(*data) : "l"(src));
+#else
+    const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
+    data[0] = src_cast[0];
+#endif
+}
+
+/////////// Load Shared ///////////
+namespace internal {
+
+#ifdef PTX_AVAILABLE
+__device__ __forceinline__ unsigned convert_to_shared(const void* ptr)
+{
+#if __CUDACC_VER_MAJOR__ >= 11
+    // In CUDA 11 we have a builtin intrinsic
+    return __cvta_generic_to_shared(ptr);
+#else
+    unsigned ret_val;
+    asm volatile(
+        "{\n"
+        "\t.reg .u64 p1;\n"
+        "\tcvta.to.shared.u64 p1, %1\n"
+        "\tcvt.u32.u64 %0, p1;\n"
+        "}\n"
+        : "=r"(ret_val)
+        : "l"(ptr));
+    return ret_val;
+#endif
+}
+#endif
+
+} // namespace internal
+
+template <>
+__device__ __forceinline__ void load_shared<16>(void* dst, const void* src)
+{
+    uint4* data = reinterpret_cast<uint4*>(dst);
+#ifdef PTX_AVAILABLE
+    unsigned src_shr = internal::convert_to_shared(src);
+
+    asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];\n"
+                 : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
+                 : "r"(src_shr));
+#else
+    const uint4* src_cast = reinterpret_cast<const uint4*>(src);
+    data[0] = src_cast[0];
+#endif
+}
+
+template <>
+__device__ __forceinline__ void load_shared<8>(void* dst, const void* src)
+{
+    uint2* data = reinterpret_cast<uint2*>(dst);
+#ifdef PTX_AVAILABLE
+    unsigned src_shr = internal::convert_to_shared(src);
+
+    asm volatile("ld.shared.v2.u32 {%0, %1}, [%2];\n"
+                 : "=r"(data[0].x), "=r"(data[0].y)
+                 : "r"(src_shr));
+#else
+    const uint2* src_cast = reinterpret_cast<const uint2*>(src);
+    data[0] = src_cast[0];
+#endif
+}
+
+template <>
+__device__ __forceinline__ void load_shared<4>(void* dst, const void* src)
+{
+    int32_t* data = reinterpret_cast<int32_t*>(dst);
+#ifdef PTX_AVAILABLE
+    unsigned src_shr = internal::convert_to_shared(src);
+
+    asm volatile("ld.shared.u32 {%0}, [%1];\n" : "=r"(*data) : "r"(src_shr));
+#else
+    const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
+    data[0] = src_cast[0];
+#endif
+}
+
+/////////// Store Global ///////////
+
+template <>
+__device__ __forceinline__ void store_global<16>(void* dst, const void* src)
+{
+    const uint4* data = reinterpret_cast<const uint4*>(src);
+#ifdef PTX_AVAILABLE
+    asm volatile("st.global.wb.v4.u32 [%0], {%1, %2, %3, %4};\n"
+                 :
+                 : "l"(dst), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w)
+                 : "memory");
+#else
+    uint4* dst_cast = reinterpret_cast<uint4*>(dst);
+    dst_cast[0] = data[0];
+#endif
+}
+
+template <>
+__device__ __forceinline__ void store_global<16, StorePolicy::CacheGlobal>(void* dst,
+                                                                           const void* src)
+{
+    const uint4* data = reinterpret_cast<const uint4*>(src);
+#ifdef PTX_AVAILABLE
+    asm volatile("st.global.cg.v4.u32 [%0], {%1, %2, %3, %4};\n"
+                 :
+                 : "l"(dst), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w)
+                 : "memory");
+#else
+    uint4* dst_cast = reinterpret_cast<uint4*>(dst);
+    dst_cast[0] = data[0];
+#endif
+}
+
+template <>
+__device__ __forceinline__ void store_global<16, StorePolicy::CacheStreaming>(void* dst,
+                                                                              const void* src)
+{
+    const uint4* data = reinterpret_cast<const uint4*>(src);
+#ifdef PTX_AVAILABLE
+    asm volatile("st.global.cs.v4.u32 [%0], {%1, %2, %3, %4};\n"
+                 :
+                 : "l"(dst), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w)
+                 : "memory");
+#else
+    uint4* dst_cast = reinterpret_cast<uint4*>(dst);
+    dst_cast[0] = data[0];
+#endif
+}
+
+template <>
+__device__ __forceinline__ void store_global<8>(void* dst, const void* src)
+{
+    const uint2* data = reinterpret_cast<const uint2*>(src);
+#ifdef PTX_AVAILABLE
+    asm volatile("st.global.wb.v2.u32 [%0], {%1, %2};\n"
+                 :
+                 : "l"(dst), "r"(data[0].x), "r"(data[0].y));
+#else
+    uint2* dst_cast = reinterpret_cast<uint2*>(dst);
+    dst_cast[0] = data[0];
+#endif
+}
+
+template <>
+__device__ __forceinline__ void store_global<8, StorePolicy::CacheGlobal>(void* dst,
+                                                                          const void* src)
+{
+    const uint2* data = reinterpret_cast<const uint2*>(src);
+#ifdef PTX_AVAILABLE
+    asm volatile("st.global.cg.v2.u32 [%0], {%1, %2};\n"
+                 :
+                 : "l"(dst), "r"(data[0].x), "r"(data[0].y));
+#else
+    uint2* dst_cast = reinterpret_cast<uint2*>(dst);
+    dst_cast[0] = data[0];
+#endif
+}
+
+template <>
+__device__ __forceinline__ void store_global<8, StorePolicy::CacheStreaming>(void* dst,
+                                                                             const void* src)
+{
+    const uint2* data = reinterpret_cast<const uint2*>(src);
+#ifdef PTX_AVAILABLE
+    asm volatile("st.global.cs.v2.u32 [%0], {%1, %2};\n"
+                 :
+                 : "l"(dst), "r"(data[0].x), "r"(data[0].y));
+#else
+    uint2* dst_cast = reinterpret_cast<uint2*>(dst);
+    dst_cast[0] = data[0];
+#endif
+}
+
+template <>
+__device__ __forceinline__ void store_global<4>(void* dst, const void* src)
+{
+    const int32_t* data = reinterpret_cast<const int32_t*>(src);
+#ifdef PTX_AVAILABLE
+    asm volatile("st.global.wb.u32 [%0], %1;\n" : : "l"(dst), "r"(*data));
+#else
+    int32_t* dst_cast = reinterpret_cast<int32_t*>(dst);
+    dst_cast[0] = data[0];
+#endif
+}
+
+template <>
+__device__ __forceinline__ void store_global<4, StorePolicy::CacheGlobal>(void* dst,
+                                                                          const void* src)
+{
+    const int32_t* data = reinterpret_cast<const int32_t*>(src);
+#ifdef PTX_AVAILABLE
+    asm volatile("st.global.cg.u32 [%0], %1;\n" : : "l"(dst), "r"(*data));
+#else
+    int32_t* dst_cast = reinterpret_cast<int32_t*>(dst);
+    dst_cast[0] = data[0];
+#endif
+}
+
+template <>
+__device__ __forceinline__ void store_global<4, StorePolicy::CacheStreaming>(void* dst,
+                                                                             const void* src)
+{
+    const int32_t* data = reinterpret_cast<const int32_t*>(src);
+#ifdef PTX_AVAILABLE
+    asm volatile("st.global.cs.u32 [%0], %1;\n" : : "l"(dst), "r"(*data));
+#else
+    int32_t* dst_cast = reinterpret_cast<int32_t*>(dst);
+    dst_cast[0] = data[0];
+#endif
+}
+
+/////////// Store Shared ///////////
+
+template <>
+__device__ __forceinline__ void store_shared<16>(void* dst, const void* src)
+{
+    const uint4* data = reinterpret_cast<const uint4*>(src);
+#ifdef PTX_AVAILABLE
+    unsigned dst_int = internal::convert_to_shared(dst);
+
+    asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n"
+                 :
+                 : "r"(dst_int), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w));
+#else
+    uint4* dst_cast = reinterpret_cast<uint4*>(dst);
+    dst_cast[0] = data[0];
+#endif
+}
+
+template <>
+__device__ __forceinline__ void store_shared<8>(void* dst, const void* src)
+{
+    const uint2* data = reinterpret_cast<const uint2*>(src);
+#ifdef PTX_AVAILABLE
+    unsigned dst_int = internal::convert_to_shared(dst);
+
+    asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n"
+                 :
+                 : "r"(dst_int), "r"(data[0].x), "r"(data[0].y));
+#else
+    uint2* dst_cast = reinterpret_cast<uint2*>(dst);
+    dst_cast[0] = data[0];
+#endif
+}
+
+template <>
+__device__ __forceinline__ void store_shared<4>(void* dst, const void* src)
+{
+    const int32_t* data = reinterpret_cast<const int32_t*>(src);
+#ifdef PTX_AVAILABLE
+    unsigned dst_int = internal::convert_to_shared(dst);
+
+    asm volatile("st.shared.u32 [%0], %1;\n" : : "r"(dst_int), "r"(*data));
+#else
+    int32_t* dst_cast = reinterpret_cast<int32_t*>(dst);
+    dst_cast[0] = data[0];
+#endif
+}
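+
+// Illustrative usage sketch (a hypothetical kernel, shown only as an example of the
+// load/store utilities above): an element-wise copy that moves four floats per thread.
+// It assumes 16-byte aligned buffers whose length is a multiple of four floats; the
+// kernel and parameter names are placeholders.
+//
+//   __global__ void example_copy(float* dst, const float* src, int n_elems)
+//   {
+//       constexpr int granularity = 16;                    // bytes per access
+//       constexpr int vals = granularity / sizeof(float);  // four floats per thread
+//       const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * vals;
+//       if (idx < n_elems) {
+//           float buffer[vals];
+//           mem_access::load_global<granularity>(buffer, src + idx);
+//           mem_access::store_global<granularity>(dst + idx, buffer);
+//       }
+//   }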
+
+/////////// Asynchronous Memory Copy ///////////
+
+#ifdef ASYNC_COPY_AVAILABLE
+template <int AccessSize>
+__device__ __forceinline__ void memcpy_async(void* shr, const void* gbl)
+{
+    static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16));
+    unsigned shr_int = internal::convert_to_shared(shr);
+
+    asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n"
+                 :
+                 : "r"(shr_int), "l"(gbl), "n"(AccessSize));
+}
+
+template <int AccessSize>
+__device__ __forceinline__ void memcpy_async_nop(void* shr, const void* gbl, bool predicate)
+{
+    static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16));
+    unsigned shr_int = internal::convert_to_shared(shr);
+
+    asm volatile(
+        "{\n"
+        "   .reg .pred p;\n"
+        "   setp.ne.b32 p, %0, 0;\n"
+        "   @p cp.async.ca.shared.global [%1], [%2], %3;\n"
+        "}\n"
+        :
+        : "r"((int)predicate), "r"(shr_int), "l"(gbl), "n"(AccessSize));
+}
+
+template <int AccessSize>
+__device__ __forceinline__ void memcpy_async_zero(void* shr, const void* gbl, bool predicate)
+{
+    static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16));
+    unsigned shr_int = internal::convert_to_shared(shr);
+    int bytes_to_copy = (predicate ? AccessSize : 0);
+
+    asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;\n"
+                 :
+                 : "r"(shr_int), "l"(gbl), "n"(AccessSize), "r"(bytes_to_copy));
+}
+
+template <int AccessSize>
+__device__ __forceinline__ void memcpy_async_zero_nop(void* shr,
+                                                      const void* gbl,
+                                                      bool zero_predicate,
+                                                      bool nop_predicate)
+{
+    static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16));
+    unsigned shr_int = internal::convert_to_shared(shr);
+    int bytes_to_copy = (zero_predicate ? AccessSize : 0);
+
+    asm volatile(
+        "{\n"
+        "   .reg .pred p;\n"
+        "   setp.ne.b32 p, %0, 0;\n"
+        "   @p cp.async.ca.shared.global [%1], [%2], %3, %4;\n"
+        "}\n"
+        :
+        : "r"((int)nop_predicate), "r"(shr_int), "l"(gbl), "n"(AccessSize), "r"(bytes_to_copy));
+}
+
+// Cache global variants. Separate interface to require deliberate use of them.
+__device__ __forceinline__ void memcpy_async_cg(void* shr, const void* gbl)
+{
+    unsigned shr_int = internal::convert_to_shared(shr);
+
+    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" : : "r"(shr_int), "l"(gbl));
+}
+
+__device__ __forceinline__ void memcpy_async_nop_cg(void* shr, const void* gbl, bool predicate)
+{
+    unsigned shr_int = internal::convert_to_shared(shr);
+
+    asm volatile(
+        "{\n"
+        "   .reg .pred p;\n"
+        "   setp.ne.b32 p, %0, 0;\n"
+        "   @p cp.async.cg.shared.global [%1], [%2], 16;\n"
+        "}\n"
+        :
+        : "r"((int)predicate), "r"(shr_int), "l"(gbl));
+}
+
+__device__ __forceinline__ void memcpy_async_zero_cg(void* shr, const void* gbl, bool predicate)
+{
+    unsigned shr_int = internal::convert_to_shared(shr);
+    int bytes_to_copy = (predicate ? 16 : 0);
+
+    asm volatile("cp.async.cg.shared.global [%0], [%1], 16, %2;\n"
+                 :
+                 : "r"(shr_int), "l"(gbl), "r"(bytes_to_copy));
+}
+
+__device__ __forceinline__ void memcpy_async_zero_nop_cg(void* shr,
+                                                         const void* gbl,
+                                                         bool zero_predicate,
+                                                         bool nop_predicate)
+{
+    unsigned shr_int = internal::convert_to_shared(shr);
+    int bytes_to_copy = (zero_predicate ? 16 : 0);
+
+    asm volatile(
+        "{\n"
+        "   .reg .pred p;\n"
+        "   setp.ne.b32 p, %0, 0;\n"
+        "   @p cp.async.cg.shared.global [%1], [%2], 16, %3;\n"
+        "}\n"
+        :
+        : "r"((int)nop_predicate), "r"(shr_int), "l"(gbl), "r"(bytes_to_copy));
+}
+
+__device__ __forceinline__ void memcpy_async_fence() { asm volatile("cp.async.commit_group;\n"); }
+
+template <int stages>
+__device__ __forceinline__ void memcpy_async_wait()
+{
+    static_assert(stages <= 8);
+
+    asm volatile("cp.async.wait_group %0;\n" : : "n"(stages));
+}
+
+// TODO: The tail complete should be a known compile time artifact, should try and induce this
+// without all of the branches from the call-site. This is a hacky solution.
+template <>
+__device__ __forceinline__ void tail_complete_wait<1>(int remaining_stages)
+{
+    if (remaining_stages == 0) memcpy_async_wait<0>();
+}
+
+template <>
+__device__ __forceinline__ void tail_complete_wait<2>(int remaining_stages)
+{
+    if (remaining_stages == 1)
+        memcpy_async_wait<1>();
+    else if (remaining_stages == 0)
+        memcpy_async_wait<0>();
+}
+
+template <>
+__device__ __forceinline__ void tail_complete_wait<3>(int remaining_stages)
+{
+    if (remaining_stages == 2)
+        memcpy_async_wait<2>();
+    else if (remaining_stages == 1)
+        memcpy_async_wait<1>();
+    else if (remaining_stages == 0)
+        memcpy_async_wait<0>();
+}
+
+template <>
+__device__ __forceinline__ void tail_complete_wait<4>(int remaining_stages)
+{
+    if (remaining_stages == 3)
+        memcpy_async_wait<3>();
+    else if (remaining_stages == 2)
+        memcpy_async_wait<2>();
+    else if (remaining_stages == 1)
+        memcpy_async_wait<1>();
+    else if (remaining_stages == 0)
+        memcpy_async_wait<0>();
+}
+
+template <>
+__device__ __forceinline__ void tail_complete_wait<5>(int remaining_stages)
+{
+    if (remaining_stages == 4)
+        memcpy_async_wait<4>();
+    else if (remaining_stages == 3)
+        memcpy_async_wait<3>();
+    else if (remaining_stages == 2)
+        memcpy_async_wait<2>();
+    else if (remaining_stages == 1)
+        memcpy_async_wait<1>();
+    else if (remaining_stages == 0)
+        memcpy_async_wait<0>();
+}
+
+template <>
+__device__ __forceinline__ void tail_complete_wait<6>(int remaining_stages)
+{
+    if (remaining_stages == 5)
+        memcpy_async_wait<5>();
+    else if (remaining_stages == 4)
+        memcpy_async_wait<4>();
+    else if (remaining_stages == 3)
+        memcpy_async_wait<3>();
+    else if (remaining_stages == 2)
+        memcpy_async_wait<2>();
+    else if (remaining_stages == 1)
+        memcpy_async_wait<1>();
+    else if (remaining_stages == 0)
+        memcpy_async_wait<0>();
+}
+#endif
+
+} // namespace mem_access
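For reviewers new to the cp.async helpers above, a minimal usage sketch may help. It is illustrative only: the kernel name, the 256-thread block assumption, and the single-stage structure are invented for the example, and the code only applies on builds where ASYNC_COPY_AVAILABLE is defined (SM80+). It shows the intended issue/commit/wait sequence: memcpy_async stages one 16-byte fragment per thread into shared memory, memcpy_async_fence commits the group, and memcpy_async_wait<0> blocks until every committed group has landed.

__global__ void example_async_tile_sum(float* out, const float* in)
{
    constexpr int granularity = 16;                 // bytes per cp.async request
    __shared__ __align__(16) float tile[256 * 4];   // assumes blockDim.x == 256

    float* shr = tile + threadIdx.x * 4;
    const float* gbl = in + (blockIdx.x * blockDim.x + threadIdx.x) * 4;

    mem_access::memcpy_async<granularity>(shr, gbl);  // issue global -> shared copy
    mem_access::memcpy_async_fence();                 // commit the outstanding copies
    mem_access::memcpy_async_wait<0>();               // wait for all committed groups
    __syncthreads();

    float vals[4];
    mem_access::load_shared<granularity>(vals, shr);
    out[blockIdx.x * blockDim.x + threadIdx.x] = vals[0] + vals[1] + vals[2] + vals[3];
}

In a real pipeline the fence/wait pair would be split across loop iterations so that one stage's copies overlap the previous stage's compute, with BufferTracker cycling the shared-memory buffer index.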
diff --git a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
index 8a34bb2017f1..e7279d77e985 100644
--- a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
+++ b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
@@ -1,4 +1,8 @@
-#include "custom_cuda_layers.h"
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+*/
+
+#include "inference_cuda_layers.h"
 
 #ifndef __HIP_PLATFORM_HCC__
 #include
diff --git a/csrc/transformer/inference/csrc/dequantize.cu b/csrc/transformer/inference/csrc/dequantize.cu
index 3409f7ba7de8..3018845bc3b8 100644
--- a/csrc/transformer/inference/csrc/dequantize.cu
+++ b/csrc/transformer/inference/csrc/dequantize.cu
@@ -1,4 +1,8 @@
-#include "custom_cuda_layers.h"
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+*/
+
+#include "inference_cuda_layers.h"
 
 #define MAX_QUANTIZE_GROUPING 1024
diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu
index 94f5ad3b51c5..f6ac9386caaf 100644
--- a/csrc/transformer/inference/csrc/gelu.cu
+++ b/csrc/transformer/inference/csrc/gelu.cu
@@ -1,4 +1,9 @@
-#include "custom_cuda_layers.h"
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+*/
+
+#include "inference_cuda_layers.h"
+#include "memory_access_utils.h"
 
 namespace cg = cooperative_groups;
 #define MAX_CAP 4
@@ -16,25 +21,21 @@ __global__ void fused_bias_gelu(float* input,
                                 int total_count,
                                 int intermediate_size)
 {
-    float4* input_cast = reinterpret_cast<float4*>(input);
-    const float4* bias_cast = reinterpret_cast<const float4*>(bias);
-    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+    // Input restriction: intermediate_size % vals_per_access == 0
+    constexpr int granularity = 16;
+    constexpr int vals_per_access = granularity / sizeof(float);
+    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;
 
     if (offset < total_count) {
-        float4 data = input_cast[offset];
-        float4 bias_data = bias_cast[offset % intermediate_size];
-
-        data.x += bias_data.x;
-        data.y += bias_data.y;
-        data.z += bias_data.z;
-        data.w += bias_data.w;
+        float data[vals_per_access];
+        float data_bias[vals_per_access];
+        mem_access::load_global<granularity>(data, input + offset);
+        mem_access::load_global<granularity>(data_bias, bias + (offset % intermediate_size));
 
-        data.x = gelu(data.x);
-        data.y = gelu(data.y);
-        data.z = gelu(data.z);
-        data.w = gelu(data.w);
+#pragma unroll
+        for (int i = 0; i < vals_per_access; i++) { data[i] = gelu(data[i] + data_bias[i]); }
 
-        input_cast[offset] = data;
+        mem_access::store_global<granularity>(input + offset, data);
     }
 }
@@ -43,40 +44,28 @@ __global__ void fused_bias_gelu(__half* input,
                                 int total_count,
                                 int intermediate_size)
 {
+    // Input restriction: intermediate_size % vals_per_access == 0
+    // This kernel doubles the per-thread ALU workload as compared to the float implementation
 #ifdef HALF_PRECISION_AVAILABLE
-
-    float2* input_cast = reinterpret_cast<float2*>(input);
-    const float2* bias_cast = reinterpret_cast<const float2*>(bias);
-
-    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+    constexpr int granularity = 16;
+    constexpr int vals_per_access = granularity / sizeof(__half);
+    int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;
 
     if (offset < total_count) {
-        float2 vals_vec = input_cast[offset];
-        float2 bias_vec = bias_cast[offset % intermediate_size];
-
-        __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
-        __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
-
-        float2 low_data = __half22float2(vals_half[0]);
-        float2 high_data = __half22float2(vals_half[1]);
-
-        float2 low_bias = __half22float2(bias_half[0]);
-        float2 high_bias = __half22float2(bias_half[1]);
-
-        low_data.x += low_bias.x;
-        low_data.y += low_bias.y;
-        high_data.x += high_bias.x;
-        high_data.y += high_bias.y;
-
-        low_data.x = gelu(low_data.x);
-        low_data.y = gelu(low_data.y);
-        high_data.x = gelu(high_data.x);
-        high_data.y = gelu(high_data.y);
-
-        vals_half[0] = __float22half2_rn(low_data);
-        vals_half[1] = __float22half2_rn(high_data);
+        // Divide by 2 since we store two values per __half2
+        __half2 data[vals_per_access / 2];
+        __half2 bias_data[vals_per_access / 2];
+        mem_access::load_global<granularity>(data, input + offset);
+        mem_access::load_global<granularity>(bias_data, bias + (offset % intermediate_size));
+
+#pragma unroll
+        for (int i = 0; i < vals_per_access / 2; i++) {
+            float2 data_f = __half22float2(data[i]);
+            float2 bias_f = __half22float2(bias_data[i]);
+            data[i] = __floats2half2_rn(gelu(data_f.x + bias_f.x), gelu(data_f.y + bias_f.y));
+        }
 
-        input_cast[offset] = vals_vec;
+        mem_access::store_global<granularity>(input + offset, data);
     }
 #endif
 }
@@ -88,13 +77,16 @@ void launch_bias_gelu(T* input,
                       int batch_size,
                       cudaStream_t stream)
 {
-    int total_count = batch_size * (intermediate_size / 4);
-    int threads = 1024;  // intermediate_size / iterations / 4;
+    constexpr int threads = 1024;
+    constexpr int granularity = 16;
+
+    const int total_count = batch_size * intermediate_size;
+    const int elems_per_block = threads * (granularity / sizeof(T));
     dim3 block_dims(threads);
-    dim3 grid_dims(((total_count - 1) / 1024 + 1));  // (batch_size);
+    dim3 grid_dims((total_count + elems_per_block - 1) / elems_per_block);
 
     fused_bias_gelu<<<grid_dims, block_dims, 0, stream>>>(
-        input, bias, total_count, intermediate_size / 4);
+        input, bias, total_count, intermediate_size);
 }
 
 template void launch_bias_gelu<float>(float*, const float*, int, int, cudaStream_t);
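To make the reworked grid sizing in launch_bias_gelu concrete, here is a small worked example; the batch and intermediate sizes are assumed values, and the snippet only restates the arithmetic introduced above.

// Hypothetical sizes: batch_size = 8, intermediate_size = 4096, T = __half.
constexpr int threads = 1024;
constexpr int granularity = 16;                                          // bytes per thread access
const int total_count = 8 * 4096;                                        // 32768 elements
const int elems_per_block = threads * (granularity / sizeof(__half));    // 1024 * 8 = 8192
const int grid_blocks = (total_count + elems_per_block - 1) / elems_per_block;  // 4 blocks

Each thread now covers granularity / sizeof(T) contiguous elements, which is why the kernel receives the raw intermediate_size instead of the previous intermediate_size / 4.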
diff --git a/csrc/transformer/inference/csrc/normalize.cu b/csrc/transformer/inference/csrc/normalize.cu
index 22c23011ede8..aff5411dc8e1 100644
--- a/csrc/transformer/inference/csrc/normalize.cu
+++ b/csrc/transformer/inference/csrc/normalize.cu
@@ -1,5 +1,9 @@
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+*/
+
 #include
-#include "custom_cuda_layers.h"
+#include "inference_cuda_layers.h"
 
 #ifndef __HIP_PLATFORM_HCC__
 #include
diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp
index a4202043bbd8..affe3b0dd3f7 100644
--- a/csrc/transformer/inference/csrc/pt_binding.cpp
+++ b/csrc/transformer/inference/csrc/pt_binding.cpp
@@ -1,10 +1,13 @@
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+*/
 #include
 #include
 #include
-#include "context.h"
-#include "cublas_wrappers.h"
-#include "custom_cuda_layers.h"
+#include "inference_context.h"
+#include "inference_cublas_wrappers.h"
+#include "inference_cuda_layers.h"
 
 std::array<int, 3> gemm_algos = std::array<int, 3>({99, 99, 99});
diff --git a/csrc/transformer/inference/csrc/relu.cu b/csrc/transformer/inference/csrc/relu.cu
index 87011f65ea92..0472d0db3490 100644
--- a/csrc/transformer/inference/csrc/relu.cu
+++ b/csrc/transformer/inference/csrc/relu.cu
@@ -1,4 +1,8 @@
-#include "custom_cuda_layers.h"
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+*/
+
+#include "inference_cuda_layers.h"
 
 #define MAX_CAP 4
 #define MAX_SEQ 2048
diff --git a/csrc/transformer/inference/csrc/softmax.cu b/csrc/transformer/inference/csrc/softmax.cu
index dcc90c3b5cbb..cc7c784913d7 100644
--- a/csrc/transformer/inference/csrc/softmax.cu
+++ b/csrc/transformer/inference/csrc/softmax.cu
@@ -1,5 +1,9 @@
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+*/
+
 #include
-#include "custom_cuda_layers.h"
+#include "inference_cuda_layers.h"
 
 #ifndef __HIP_PLATFORM_HCC__
 #include
diff --git a/csrc/transformer/inference/csrc/transform.cu b/csrc/transformer/inference/csrc/transform.cu
index dd7adb7a0508..9230516238c0 100644
--- a/csrc/transformer/inference/csrc/transform.cu
+++ b/csrc/transformer/inference/csrc/transform.cu
@@ -1,7 +1,11 @@
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+*/
+
 #ifndef __HIP_PLATFORM_HCC__
 #include
 #endif
-#include "custom_cuda_layers.h"
+#include "inference_cuda_layers.h"
 
 namespace cg = cooperative_groups;
 // Bias add
diff --git a/csrc/transformer/inference/includes/context.h b/csrc/transformer/inference/includes/inference_context.h
similarity index 99%
rename from csrc/transformer/inference/includes/context.h
rename to csrc/transformer/inference/includes/inference_context.h
index ed739000b080..a6f6613fc6a5 100644
--- a/csrc/transformer/inference/includes/context.h
+++ b/csrc/transformer/inference/includes/inference_context.h
@@ -1,3 +1,7 @@
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+*/
+
 #pragma once
 
 #include
diff --git a/csrc/transformer/inference/includes/cublas_wrappers.h b/csrc/transformer/inference/includes/inference_cublas_wrappers.h
similarity index 99%
rename from csrc/transformer/inference/includes/cublas_wrappers.h
rename to csrc/transformer/inference/includes/inference_cublas_wrappers.h
index 75d18a40fc8e..9e55cc1c7423 100644
--- a/csrc/transformer/inference/includes/cublas_wrappers.h
+++ b/csrc/transformer/inference/includes/inference_cublas_wrappers.h
@@ -1,3 +1,7 @@
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+*/
+
 #pragma once
 
 #include
diff --git a/csrc/transformer/inference/includes/custom_cuda_layers.h b/csrc/transformer/inference/includes/inference_cuda_layers.h
similarity index 99%
rename from csrc/transformer/inference/includes/custom_cuda_layers.h
rename to csrc/transformer/inference/includes/inference_cuda_layers.h
index 32708bfe0b46..d1e10c516c69 100644
--- a/csrc/transformer/inference/includes/custom_cuda_layers.h
+++ b/csrc/transformer/inference/includes/inference_cuda_layers.h
@@ -1,3 +1,7 @@
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+*/
+
 #pragma once
 
 #ifdef __HIP_PLATFORM_HCC__
diff --git a/op_builder/transformer_inference.py b/op_builder/transformer_inference.py
index b0b86225e97c..5e78ee3ff5d8 100755
--- a/op_builder/transformer_inference.py
+++ b/op_builder/transformer_inference.py
@@ -51,4 +51,4 @@ def extra_ldflags(self):
         return []
 
     def include_paths(self):
-        return ['csrc/transformer/inference/includes']
+        return ['csrc/transformer/inference/includes', 'csrc/includes']
diff --git a/tests/unit/ops/transformer/inference/test_bias_gelu.py b/tests/unit/ops/transformer/inference/test_bias_gelu.py
index 773ea6556462..bf0b184fb5fe 100644
--- a/tests/unit/ops/transformer/inference/test_bias_gelu.py
+++ b/tests/unit/ops/transformer/inference/test_bias_gelu.py
@@ -1,3 +1,7 @@
+"""
+Copyright 2022 The Microsoft DeepSpeed Team
+"""
+
 import pytest
 import torch
 import deepspeed