From b8416282072e977fdc0f9c77019852faf586a418 Mon Sep 17 00:00:00 2001 From: Connor Holmes Date: Tue, 6 Dec 2022 10:42:32 -0800 Subject: [PATCH] Drop Maxwell Support (#2574) * Officially drop Maxwell support * Formatting * Comparison mismatch fix --- .../inference/csrc/apply_rotary_pos_emb.cu | 4 -- csrc/transformer/inference/csrc/dequantize.cu | 3 - csrc/transformer/inference/csrc/gelu.cu | 67 +++++++------------ csrc/transformer/inference/csrc/relu.cu | 53 ++++----------- csrc/transformer/inference/csrc/softmax.cu | 3 - csrc/transformer/inference/csrc/transform.cu | 19 ------ op_builder/builder.py | 16 ++++- op_builder/transformer_inference.py | 17 +++++ 8 files changed, 72 insertions(+), 110 deletions(-) diff --git a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu index 4a91975a73ca..b898bf92cf71 100644 --- a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu +++ b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu @@ -67,7 +67,6 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query, unsigned total_count, int max_out_tokens) { -#if __CUDA_ARCH__ >= 700 cg::thread_block b = cg::this_thread_block(); cg::thread_block_tile g = cg::tiled_partition(b); @@ -102,7 +101,6 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query, lane += WARP_SIZE; } } -#endif } __global__ void apply_rotary_pos_emb1(float* mixed_query, float* key_layer, @@ -159,7 +157,6 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query, unsigned total_count, int max_out_tokens) { -#if __CUDA_ARCH__ >= 700 cg::thread_block b = cg::this_thread_block(); cg::thread_block_tile g = cg::tiled_partition(b); @@ -205,7 +202,6 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query, lane += WARP_SIZE; } } -#endif } template diff --git a/csrc/transformer/inference/csrc/dequantize.cu b/csrc/transformer/inference/csrc/dequantize.cu index 959016bf10e3..33605e1f54e0 100644 --- a/csrc/transformer/inference/csrc/dequantize.cu +++ b/csrc/transformer/inference/csrc/dequantize.cu @@ -50,8 +50,6 @@ __global__ void dequantize_kernel(__half* output, unsigned groups, unsigned merge_count) { -#ifdef HALF_PRECISION_AVAILABLE - unsigned merge_hidden = hidden_dim >> merge_count; unsigned quantization_stride = (merge_hidden * output_size) / groups; @@ -75,7 +73,6 @@ __global__ void dequantize_kernel(__half* output, output[q_index] = __float2half(scale_data * (float)q); tid += blockDim.x; } -#endif } template diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index 557369abcc60..71a37bb368c7 100644 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -17,6 +17,9 @@ inline __device__ float gelu(const float x) return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x))); } +/* +In-place gelu(biasAdd(x)) for channels last +*/ template __global__ void fused_bias_gelu(T* input, const T* bias, int total_count, int intermediate_size) { @@ -64,63 +67,51 @@ void launch_bias_gelu(T* input, template void launch_bias_gelu(float*, const float*, int, int, cudaStream_t); template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStream_t); -// Not called directly from DeepSpeed, but used in ds_qkv_gemm_int8, ds_linear_layer, etc. -__global__ void fused_bias_add(float* input, const float* bias, int total_count, int hidden_size) -{ - constexpr int granularity = 16; - constexpr int vals_per_access = granularity / sizeof(float); - const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access; - - if (offset < total_count) { - float data[vals_per_access]; - float bias_data[vals_per_access]; - mem_access::load_global(data, input + offset); - mem_access::load_global(bias_data, bias + (offset % hidden_size)); - -#pragma unroll - for (int i = 0; i < vals_per_access; i++) { data[i] += bias_data[i]; } - - mem_access::store_global(input + offset, data); - } -} - -__global__ void fused_bias_add(__half* input, const __half* bias, int total_count, int hidden_size) +/* +In-place channels-last bias add +*/ +template +__global__ void fused_bias_add(T* input, const T* bias, int total_count, int intermediate_size) { -#ifdef HALF_PRECISION_AVAILABLE + // Input restriction: intermediate_size % vals_per_access == 0 constexpr int granularity = 16; - constexpr int vals_per_access = granularity / sizeof(__half); - const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access; + constexpr int values_per_access = granularity / sizeof(T); + const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * values_per_access; if (offset < total_count) { - __half2 data[vals_per_access / 2]; - __half2 bias_data[vals_per_access / 2]; + T data[values_per_access]; + T data_bias[values_per_access]; mem_access::load_global(data, input + offset); - mem_access::load_global(bias_data, bias + (offset % hidden_size)); + mem_access::load_global(data_bias, bias + (offset % intermediate_size)); #pragma unroll - for (int i = 0; i < vals_per_access / 2; i++) { - float2 data_f = __half22float2(data[i]); - float2 bias_f = __half22float2(bias_data[i]); - data[i] = __floats2half2_rn(data_f.x + bias_f.x, data_f.y + bias_f.y); + for (int i = 0; i < values_per_access; i++) { + float data_f = conversion::to(data[i]); + float bias_f = conversion::to(data_bias[i]); + data[i] = conversion::to(data_f + bias_f); } mem_access::store_global(input + offset, data); } -#endif } template -void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream) +void launch_bias_add(T* input, + const T* bias, + int intermediate_size, + int batch_size, + cudaStream_t stream) { constexpr int threads = 1024; constexpr int granularity = 16; - const int total_count = batch_size * hidden_size; + const int total_count = batch_size * intermediate_size; const int elems_per_block = threads * (granularity / sizeof(T)); dim3 block_dims(threads); dim3 grid_dims((total_count + elems_per_block - 1) / elems_per_block); - fused_bias_add<<>>(input, bias, total_count, hidden_size); + fused_bias_add<<>>( + input, bias, total_count, intermediate_size); } template void launch_bias_add(float*, const float*, int, int, cudaStream_t); @@ -181,8 +172,6 @@ __global__ void fused_bias_residual(__half* residual, const float mp_scale, const bool preln) { -#ifdef HALF_PRECISION_AVAILABLE - float2* res_fl2_ptr = reinterpret_cast(residual); const float2* hs_fl2_ptr = reinterpret_cast(hidden_state); const float2* attn_fl2_ptr = reinterpret_cast(attn); @@ -241,7 +230,6 @@ __global__ void fused_bias_residual(__half* residual, res_fl2_ptr[offset] = res_fl2; } -#endif } template @@ -325,8 +313,6 @@ __global__ void gptj_residual_add(__half* residual, const int intermediate_size, const float mp_scale) { -#ifdef HALF_PRECISION_AVAILABLE - float2* res_fl2_ptr = reinterpret_cast(residual); const float2* hs_fl2_ptr = reinterpret_cast(hidden_state); const float2* attn_fl2_ptr = reinterpret_cast(attn); @@ -379,7 +365,6 @@ __global__ void gptj_residual_add(__half* residual, res_fl2_ptr[offset] = res_fl2; } -#endif } template diff --git a/csrc/transformer/inference/csrc/relu.cu b/csrc/transformer/inference/csrc/relu.cu index 26445b74e87c..87e169a9194f 100644 --- a/csrc/transformer/inference/csrc/relu.cu +++ b/csrc/transformer/inference/csrc/relu.cu @@ -2,6 +2,7 @@ Copyright 2022 The Microsoft DeepSpeed Team */ +#include "conversion_utils.h" #include "inference_cuda_layers.h" #include "memory_access_utils.h" @@ -11,58 +12,32 @@ namespace cg = cooperative_groups; inline __device__ float relu(const float x) { return x < 0 ? 0 : x; } -__global__ void fused_bias_relu(float* input, - const float* bias, - int total_count, - int intermediate_size) +/* +In-place relu(biasAdd(x)) for channels last +*/ +template +__global__ void fused_bias_relu(T* input, const T* bias, int total_count, int intermediate_size) { // Input restriction: intermediate_size % vals_per_access == 0 constexpr int granularity = 16; - constexpr int vals_per_access = granularity / sizeof(float); - const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access; + constexpr int values_per_access = granularity / sizeof(T); + const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * values_per_access; if (offset < total_count) { - float data[vals_per_access]; - float data_bias[vals_per_access]; + T data[values_per_access]; + T data_bias[values_per_access]; mem_access::load_global(data, input + offset); mem_access::load_global(data_bias, bias + (offset % intermediate_size)); #pragma unroll - for (int i = 0; i < vals_per_access; i++) { data[i] = relu(data[i] + data_bias[i]); } - - mem_access::store_global(input + offset, data); - } -} - -__global__ void fused_bias_relu(__half* input, - const __half* bias, - int total_count, - int intermediate_size) -{ - // Input restriction: intermediate_size % vals_per_access == 0 - // This kernel doubles the per-thread ALU workload as compared to the float implementation -#ifdef HALF_PRECISION_AVAILABLE - constexpr int granularity = 16; - constexpr int vals_per_access = granularity / sizeof(__half); - int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access; - - if (offset < total_count) { - // Divide by 2 since we store two values per __half2 - __half2 data[vals_per_access / 2]; - __half2 bias_data[vals_per_access / 2]; - mem_access::load_global(data, input + offset); - mem_access::load_global(bias_data, bias + (offset % intermediate_size)); - -#pragma unroll - for (int i = 0; i < vals_per_access / 2; i++) { - float2 data_f = __half22float2(data[i]); - float2 bias_f = __half22float2(bias_data[i]); - data[i] = __floats2half2_rn(relu(data_f.x + bias_f.x), relu(data_f.y + bias_f.y)); + for (int i = 0; i < values_per_access; i++) { + float data_f = conversion::to(data[i]); + float bias_f = conversion::to(data_bias[i]); + data[i] = conversion::to(relu(data_f + bias_f)); } mem_access::store_global(input + offset, data); } -#endif } template diff --git a/csrc/transformer/inference/csrc/softmax.cu b/csrc/transformer/inference/csrc/softmax.cu index e3790d5c5d40..c5f04176203b 100644 --- a/csrc/transformer/inference/csrc/softmax.cu +++ b/csrc/transformer/inference/csrc/softmax.cu @@ -48,8 +48,6 @@ __global__ void attn_softmax_v2(__half* vals, int iterations, int reduceWidth) { -#ifdef HALF_PRECISION_AVAILABLE - cg::thread_block b = cg::this_thread_block(); cg::thread_block_tile g = cg::tiled_partition(b); @@ -232,7 +230,6 @@ __global__ void attn_softmax_v2(__half* vals, } } } -#endif } __global__ void attn_softmax_v2(float* vals, diff --git a/csrc/transformer/inference/csrc/transform.cu b/csrc/transformer/inference/csrc/transform.cu index a5a43c364ed6..023e02fe1c52 100644 --- a/csrc/transformer/inference/csrc/transform.cu +++ b/csrc/transformer/inference/csrc/transform.cu @@ -90,8 +90,6 @@ __global__ void bias_add_transform_0213(__half* output, // q int head_ext, int max_out_tokens) { -#if __CUDA_ARCH__ >= 700 - unsigned half_dim = (rotary_dim << 3) >> 1; int d0_stride = hidden_dim * seq_length; int d1_stride = hidden_dim; @@ -146,8 +144,6 @@ __global__ void bias_add_transform_0213(__half* output, // q output_vec[d3] = q; } else output_vec[d3] = vals_vec[d3]; - -#endif } // [B S C*H] - > C * [B A S N] @@ -269,7 +265,6 @@ __global__ void pad_add_transform_0213(__half* output, int heads, int padded_head_size) { -#if __CUDA_ARCH__ >= 700 float4 ZERO; const __half2 zero_h = __float2half2_rn(0.f); __half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO); @@ -303,8 +298,6 @@ __global__ void pad_add_transform_0213(__half* output, output_vec[d3] = vals_vec[d3]; else output_vec[d3] = ZERO; - -#endif } template @@ -409,8 +402,6 @@ __global__ void bias_add_transform_0213<__half>(__half* output, int heads, int head_ext) { -#ifdef HALF_PRECISION_AVAILABLE - int d0_stride = hidden_dim * seq_length; int d1_stride = hidden_dim; int d2_stride = hidden_dim / heads; @@ -455,8 +446,6 @@ __global__ void bias_add_transform_0213<__half>(__half* output, output_half[2] = vals_half[2] + bias_half[2]; output_half[3] = vals_half[3] + bias_half[3]; output_vec[d3] = output_arr; - -#endif } __global__ void bias_add_transform_0213_v2(__half* output, @@ -466,7 +455,6 @@ __global__ void bias_add_transform_0213_v2(__half* output, int seq_length, int heads) { -#ifdef HALF_PRECISION_AVAILABLE __shared__ float4 in_data[3072]; int d0_stride = hidden_dim * seq_length; @@ -528,7 +516,6 @@ __global__ void bias_add_transform_0213_v2(__half* output, output_vec[out_index + iter_offset] = in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)]; } -#endif } template @@ -580,8 +567,6 @@ __global__ void transform4d_0213<__half>(__half* out, int hidden_dim, int head_ext) { -#if __CUDA_ARCH__ >= 700 - int d0_stride = hidden_dim * (seq_length / head_ext); int d1_stride = hidden_dim; int d2_stride = hidden_dim / heads; @@ -606,8 +591,6 @@ __global__ void transform4d_0213<__half>(__half* out, out_vec += (d2 * d1_stride * gridDim.y); out_vec[d3] = in_vec[d3]; - -#endif } __global__ void transform4d_0213_v2(__half* out, @@ -616,7 +599,6 @@ __global__ void transform4d_0213_v2(__half* out, int seq_length, int hidden_dim) { -#if __CUDA_ARCH__ >= 700 __shared__ float4 in_data[3072]; int d0_stride = hidden_dim * seq_length; @@ -657,7 +639,6 @@ __global__ void transform4d_0213_v2(__half* out, int iter_id = iter * iteration_stride + iter_index; out_vec[output_offset + iter_id] = in_data[iter_id]; } -#endif } // 3 * [B A S N] - > [B S C*H] diff --git a/op_builder/builder.py b/op_builder/builder.py index b197e7047d4a..6dc87ca7ce80 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -15,6 +15,7 @@ import distutils.sysconfig from distutils.errors import CompileError, LinkError from abc import ABC, abstractmethod +from typing import List YELLOW = '\033[93m' END = '\033[0m' @@ -524,7 +525,7 @@ def compute_capability_args(self, cross_compile_archs=None): - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples: TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ... - TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ... + TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ... - `cross_compile_archs` uses ; separator. @@ -554,6 +555,12 @@ def compute_capability_args(self, cross_compile_archs=None): cross_compile_archs = get_default_compute_capabilities() ccs = cross_compile_archs.split(';') + ccs = self.filter_ccs(ccs) + if len(ccs) == 0: + raise RuntimeError( + f"Unable to load {self.name} op due to no compute capabilities remaining after filtering" + ) + args = [] for cc in ccs: num = cc[0] + cc[2] @@ -563,6 +570,13 @@ def compute_capability_args(self, cross_compile_archs=None): return args + def filter_ccs(self, ccs: List[str]): + """ + Prune any compute capabilities that are not compatible with the builder. Should log + which CCs have been pruned. + """ + return ccs + def version_dependent_macros(self): # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456 version_ge_1_1 = [] diff --git a/op_builder/transformer_inference.py b/op_builder/transformer_inference.py index 14ea975eb1cb..20797b7938a4 100755 --- a/op_builder/transformer_inference.py +++ b/op_builder/transformer_inference.py @@ -25,6 +25,11 @@ def is_compatible(self, verbose=True): sys_cuda_major, _ = installed_cuda_version() torch_cuda_major = int(torch.version.cuda.split('.')[0]) cuda_capability = torch.cuda.get_device_properties(0).major + if cuda_capability < 6: + self.warning( + "NVIDIA Inference is only supported on Pascal and newer architectures" + ) + cuda_okay = False if cuda_capability >= 8: if torch_cuda_major < 11 or sys_cuda_major < 11: self.warning( @@ -32,6 +37,18 @@ def is_compatible(self, verbose=True): cuda_okay = False return super().is_compatible(verbose) and cuda_okay + def filter_ccs(self, ccs): + ccs_retained = [] + ccs_pruned = [] + for cc in ccs: + if int(cc[0]) >= 6: + ccs_retained.append(cc) + else: + ccs_pruned.append(cc) + if len(ccs_pruned) > 0: + self.warning(f"Filtered compute capabilities {ccs_pruned}") + return ccs_retained + def sources(self): return [ 'csrc/transformer/inference/csrc/pt_binding.cpp',