diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5350512..cdd7043 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -44,13 +44,6 @@ if (WITH_OMP)
     set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler -fopenmp")
 endif()
 
-IF(NOT (CUDA_VERSION GREATER 10.2))
-    set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_30,code=sm_30 -O2")
-    set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_35,code=sm_35")
-    set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_50,code=sm_50")
-ENDIF()
-
-set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_52,code=sm_52")
 IF(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 5)
     set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -D_MWAITXINTRIN_H_INCLUDED -D_FORCE_INLINES")
 ENDIF()
diff --git a/include/detail/gpu_rnnt_kernel.h b/include/detail/gpu_rnnt_kernel.h
index f5e5c0d..5e2e7c0 100644
--- a/include/detail/gpu_rnnt_kernel.h
+++ b/include/detail/gpu_rnnt_kernel.h
@@ -177,6 +177,43 @@ __global__ void compute_grad_kernel(Tp* grads, const Tp* const acts, const Tp* c
         }
     }
 }
+template<>
+__global__ void compute_grad_kernel<128, half>(half* grads, const half* const acts, const half* const denom, const half* alphas, const half* betas, const half* const logll, const int* const xlen, const int* const ylen,
+    const int* const mlabels, const int minibatch, const int maxT, const int maxU, const int alphabet_size, const int blank_) {
+    int tid = threadIdx.x; // alphabet dim
+    int idx = tid;
+    int col = blockIdx.x; // mb, t, u
+
+    int u = col % maxU;
+    int bt = (col - u) / maxU;
+    int t = bt % maxT;
+    int mb = (bt - t) / maxT;
+
+    const int T = xlen[mb];
+    const int U = ylen[mb] + 1;
+    const int* labels = mlabels + mb * (maxU - 1);
+
+    if (t < T && u < U) {
+        while (idx < alphabet_size) {
+            half logpk = denom[col] + acts[col * alphabet_size + idx];
+            // half logpk = logp(denom, acts, maxT, maxU, alphabet_size, mb, t, u, idx);
+            half grad = hexp(alphas[col] + betas[col] + logpk - logll[mb]);
+            // grad to last blank transition
+            if (idx == blank_ && t == T-1 && u == U-1) {
+                grad -= hexp(alphas[col] + logpk - logll[mb]);
+            }
+            if (idx == blank_ && t < T-1) {
+                grad -= hexp(alphas[col] + logpk - logll[mb] + betas[col + maxU]);
+            }
+            if (u < U-1 && idx == labels[u]) {
+                grad -= hexp(alphas[col] + logpk - logll[mb] + betas[col+1]);
+            }
+            grads[col * alphabet_size + idx] = grad;
+
+            idx += 128;
+        }
+    }
+}
 
 template<int NT, typename Tp>
 __global__ void compute_fastemit_grad_kernel(Tp* grads, const Tp* const acts, const Tp* const denom, const Tp* alphas, const Tp* betas, const Tp* const logll, const int* const xlen, const int* const ylen,
@@ -223,4 +260,50 @@ __global__ void compute_fastemit_grad_kernel(Tp* grads, const Tp* const acts, co
             idx += NT;
         }
     }
+}
+template<>
+__global__ void compute_fastemit_grad_kernel<128, half>(half* grads, const half* const acts, const half* const denom, const half* alphas, const half* betas, const half* const logll, const int* const xlen, const int* const ylen,
+    const int* const mlabels, const int minibatch, const int maxT, const int maxU, const int alphabet_size, const int blank_, const half fastemit_lambda) {
+    int tid = threadIdx.x; // alphabet dim
+    int idx = tid;
+    int col = blockIdx.x; // mb, t, u
+
+    int u = col % maxU;
+    int bt = (col - u) / maxU;
+    int t = bt % maxT;
+    int mb = (bt - t) / maxT;
+
+    const int T = xlen[mb];
+    const int U = ylen[mb] + 1;
+    const int* labels = mlabels + mb * (maxU - 1);
+
+    if (t < T && u < U) {
+        while (idx < alphabet_size) {
+            half logpk = denom[col] + acts[col * alphabet_size + idx];
+            // half logpk = logp(denom, acts, maxT, maxU, alphabet_size, mb, t, u, idx);
+            half grad = hexp(alphas[col] + betas[col] + logpk - logll[mb]);
+
+            half logy_btu1 = rnnt_helper::neg_inf<half>(); // log(y(t,u)) + log(beta(t, u+1))
+            if (u < U-1) {
+                logy_btu1 = denom[col] + acts[col * alphabet_size + labels[u]] + betas[col+1];
+            }
+            grad += fastemit_lambda * hexp(alphas[col] + logy_btu1 + logpk - logll[mb]);
+
+            // grad to last blank transition
+            if (idx == blank_ && t == T-1 && u == U-1) {
+                grad -= hexp(alphas[col] + logpk - logll[mb]);
+                grad -= fastemit_lambda * hexp(alphas[col] + logy_btu1 + logpk - logll[mb]);
+            }
+            if (idx == blank_ && t < T-1) {
+                grad -= hexp(alphas[col] + logpk - logll[mb] + betas[col + maxU]);
+            }
+            if (u < U-1 && idx == labels[u]) {
+                grad -= hexp(alphas[col] + logpk - logll[mb] + betas[col+1]);
+                grad -= fastemit_lambda * hexp(alphas[col] + logy_btu1 - logll[mb]);
+            }
+            grads[col * alphabet_size + idx] = grad;
+
+            idx += 128;
+        }
+    }
 }
\ No newline at end of file
diff --git a/include/detail/hostdevice.h b/include/detail/hostdevice.h
index 42bc37c..0f63407 100644
--- a/include/detail/hostdevice.h
+++ b/include/detail/hostdevice.h
@@ -1,7 +1,8 @@
 #pragma once
 
 #ifdef __CUDACC__
-    #define HOSTDEVICE __host__ __device__
+// changed from '__host__ __device__' to '__device__' to call cuda math lib functions for half
+    #define HOSTDEVICE __device__
 #else
     #define HOSTDEVICE
 #endif
\ No newline at end of file
diff --git a/include/detail/reduce.h b/include/detail/reduce.h
index d2a6ae1..06a83eb 100644
--- a/include/detail/reduce.h
+++ b/include/detail/reduce.h
@@ -72,6 +72,11 @@ __global__ void reduce_rows(Iop f, Rop g, const T* const acts, T* output, int nu
     output[col] = curr;
 }
 
+__device__ half log(half number)
+{
+    return hlog(number);
+}
+
 template <int NT, typename Iop, typename Rop, typename T>
 __global__ void reduce_minus(Iop f, Rop g, const T* const acts, T* output, int num_rows) {
 
diff --git a/include/detail/rnnt_helper.h b/include/detail/rnnt_helper.h
index 2495fbd..4e80c0b 100644
--- a/include/detail/rnnt_helper.h
+++ b/include/detail/rnnt_helper.h
@@ -5,6 +5,7 @@
 #include <cmath>
 
 #include "hostdevice.h"
+#include <cuda_fp16.h>
 
 namespace rnnt_helper {
 
@@ -26,6 +27,16 @@ inline HOSTDEVICE T log_sum_exp(T a, T b) {
 inline int div_up(int x, int y) {
     return (x + y - 1) / y;
 }
+template<>
+inline HOSTDEVICE half log_sum_exp(half a, half b)
+{
+    if (__hisinf(a) == -1) return b;
+    if (__hisinf(b) == -1) return a;
+    if (__hgt(a, b))
+        return __hadd(hlog(__hadd((half)1, hexp(__hsub(b, a)))), a);
+    else
+        return __hadd(hlog(__hadd((half)1, hexp(__hsub(a, b)))), b);
+}
 
 template <typename T> struct maximum {
     HOSTDEVICE
@@ -33,6 +44,12 @@ template <typename T> struct maximum {
         return x < y ? y : x;
     }
 };
+template<> struct maximum<half> {
+    HOSTDEVICE
+    half operator()(const half& x, const half& y) const {
+        return __hlt(x, y) ? y : x;
+    }
+};
 
 template <typename T> struct add {
     HOSTDEVICE
@@ -41,6 +58,13 @@ template <typename T> struct add {
         return x + y;
     }
 };
+template<> struct add<half> {
+    HOSTDEVICE
+    half operator()(const half& x, const half& y) {
+        return __hadd(x, y);
+    }
+};
+
 template <typename Arg, typename Res = Arg> struct identity {
     HOSTDEVICE Res operator()(const Arg& x) const {return Res(x);}
 };
@@ -53,6 +77,13 @@ template <typename Arg, typename Res = Arg> struct exponential {
     HOSTDEVICE Res operator()(const Arg& x) const {return std::exp(x);}
 };
 
+template<> struct exponential<half, half>
+{
+    HOSTDEVICE half operator()(const half& x) const {
+        return hexp(x);
+    }
+};
+
 template<typename Arg1, typename Arg2 = Arg1, typename Res = Arg1>
 struct log_plus {
     typedef Res result_type;
diff --git a/include/rnnt.h b/include/rnnt.h
index 207065f..3c640b4 100644
--- a/include/rnnt.h
+++ b/include/rnnt.h
@@ -125,7 +125,20 @@ rnntStatus_t compute_rnnt_loss_fp64(const double* const activations,
                                     void *workspace,
                                     rnntOptions options);
 
-
+#ifdef WARPRNNT_ENABLE_GPU
+#include <cuda_fp16.h>
+
+rnntStatus_t compute_rnnt_loss_half(const half* const activations,
+                                    half* gradients,
+                                    const int* const flat_labels,
+                                    const int* const label_lengths,
+                                    const int* const input_lengths,
+                                    int alphabet_size,
+                                    int minibatch,
+                                    half *costs,
+                                    void *workspace,
+                                    rnntOptions options);
+#endif
 
 /** For a given set of max sequence length and minibatch size return the required
  * workspace size. This will need to be allocated in the same memory space as your
diff --git a/pytorch_binding/src/binding.cpp b/pytorch_binding/src/binding.cpp
index 8fb9fe8..be6889b 100644
--- a/pytorch_binding/src/binding.cpp
+++ b/pytorch_binding/src/binding.cpp
@@ -44,17 +44,17 @@ int cpu_rnnt(torch::Tensor acts,
 #endif
 
     size_t cpu_size_bytes = 0;
-    switch (acts.type().scalarType()) {
+    switch (acts.scalar_type()) {
       case torch::ScalarType::Float:
         {
         get_workspace_size(maxT, maxU, minibatch_size,
                            false, &cpu_size_bytes);
 
         float* cpu_workspace = (float*) new unsigned char[cpu_size_bytes];
-        compute_rnnt_loss(acts.data<float>(), grads.data<float>(),
-                          labels.data<int>(), label_lengths.data<int>(),
-                          input_lengths.data<int>(), alphabet_size,
-                          minibatch_size, costs.data<float>(),
+        compute_rnnt_loss(acts.data_ptr<float>(), grads.data_ptr<float>(),
+                          labels.data_ptr<int>(), label_lengths.data_ptr<int>(),
+                          input_lengths.data_ptr<int>(), alphabet_size,
+                          minibatch_size, costs.data_ptr<float>(),
                           cpu_workspace, options);
 
         delete cpu_workspace;
@@ -67,10 +67,10 @@ int cpu_rnnt(torch::Tensor acts,
                            sizeof(double));
 
         double* cpu_workspace = (double*) new unsigned char[cpu_size_bytes];
-        compute_rnnt_loss_fp64(acts.data<double>(), grads.data<double>(),
-                               labels.data<int>(), label_lengths.data<int>(),
-                               input_lengths.data<int>(), alphabet_size,
-                               minibatch_size, costs.data<double>(),
+        compute_rnnt_loss_fp64(acts.data_ptr<double>(), grads.data_ptr<double>(),
+                               labels.data_ptr<int>(), label_lengths.data_ptr<int>(),
+                               input_lengths.data_ptr<int>(), alphabet_size,
+                               minibatch_size, costs.data_ptr<double>(),
                                cpu_workspace, options);
 
         delete cpu_workspace;
@@ -111,7 +111,7 @@ int gpu_rnnt(torch::Tensor acts,
     options.num_threads = std::max(options.num_threads, (unsigned int) 1);
 #endif
 
-    switch (acts.type().scalarType()) {
+    switch (acts.scalar_type()) {
      case torch::ScalarType::Float:
         {
         size_t gpu_size_bytes;
@@ -122,10 +122,10 @@ int gpu_rnnt(torch::Tensor acts,
 
         void* gpu_workspace = c10::cuda::CUDACachingAllocator::raw_alloc(gpu_size_bytes);
 
-        compute_rnnt_loss(acts.data<float>(), grads.data<float>(),
-                          labels.data<int>(), label_lengths.data<int>(),
-                          input_lengths.data<int>(), alphabet_size,
-                          minibatch_size, costs.data<float>(),
+        compute_rnnt_loss(acts.data_ptr<float>(), grads.data_ptr<float>(),
+                          labels.data_ptr<int>(), label_lengths.data_ptr<int>(),
+                          input_lengths.data_ptr<int>(), alphabet_size,
+                          minibatch_size, costs.data_ptr<float>(),
                           gpu_workspace, options);
 
         c10::cuda::CUDACachingAllocator::raw_delete(gpu_workspace);
@@ -141,10 +141,29 @@ int gpu_rnnt(torch::Tensor acts,
 
         void* gpu_workspace = c10::cuda::CUDACachingAllocator::raw_alloc(gpu_size_bytes);
 
-        compute_rnnt_loss_fp64(acts.data<double>(), grads.data<double>(),
-                               labels.data<int>(), label_lengths.data<int>(),
-                               input_lengths.data<int>(), alphabet_size,
-                               minibatch_size, costs.data<double>(),
+        compute_rnnt_loss_fp64(acts.data_ptr<double>(), grads.data_ptr<double>(),
+                               labels.data_ptr<int>(), label_lengths.data_ptr<int>(),
+                               input_lengths.data_ptr<int>(), alphabet_size,
+                               minibatch_size, costs.data_ptr<double>(),
+                               gpu_workspace, options);
+
+        c10::cuda::CUDACachingAllocator::raw_delete(gpu_workspace);
+        return 0;
+        }
+      case torch::ScalarType::Half:
+        {
+        size_t gpu_size_bytes;
+        get_workspace_size(maxT, maxU, minibatch_size,
+                           true, &gpu_size_bytes);
+
+        cudaSetDevice(acts.get_device());
+
+        void* gpu_workspace = c10::cuda::CUDACachingAllocator::raw_alloc(gpu_size_bytes);
+
+        compute_rnnt_loss_half((half*)(acts.data_ptr()), (half*)(grads.data_ptr()),
+                               labels.data_ptr<int>(), label_lengths.data_ptr<int>(),
+                               input_lengths.data_ptr<int>(), alphabet_size,
+                               minibatch_size, (half*)(costs.data_ptr()),
                                gpu_workspace, options);
 
         c10::cuda::CUDACachingAllocator::raw_delete(gpu_workspace);
diff --git a/src/rnnt_entrypoint.cpp b/src/rnnt_entrypoint.cpp
index 9854470..4147dd9 100644
--- a/src/rnnt_entrypoint.cpp
+++ b/src/rnnt_entrypoint.cpp
@@ -186,4 +186,53 @@ rnntStatus_t compute_rnnt_loss_fp64(const double* const activations, //BTUV
     }
 }
 
+rnntStatus_t compute_rnnt_loss_half(const half* const activations, //BTUV
+                                    half* gradients,
+                                    const int* const flat_labels,
+                                    const int* const label_lengths,
+                                    const int* const input_lengths,
+                                    int alphabet_size,
+                                    int minibatch,
+                                    half *costs,
+                                    void *workspace,
+                                    rnntOptions options) {
+
+    if (activations == nullptr ||
+        flat_labels == nullptr ||
+        label_lengths == nullptr ||
+        input_lengths == nullptr ||
+        costs == nullptr ||
+        workspace == nullptr ||
+        alphabet_size <= 0 ||
+        minibatch <= 0 ||
+        options.maxT <= 0 ||
+        options.maxU <= 0 ||
+        options.fastemit_lambda < 0)
+        return RNNT_STATUS_INVALID_VALUE;
+
+    if (options.loc == RNNT_CPU) {
+        std::cerr << "CPU execution requested in half, but is not available for this type of data" << std::endl;
+        return RNNT_STATUS_EXECUTION_FAILED;
+    } else if (options.loc == RNNT_GPU) {
+#ifdef __CUDACC__
+        GpuRNNT<half> rnnt(minibatch, options.maxT, options.maxU, alphabet_size, workspace,
+                           options.blank_label, options.fastemit_lambda, options.num_threads, options.stream);
+
+        if (gradients != NULL)
+            return rnnt.cost_and_grad(activations, gradients,
+                                      costs,
+                                      flat_labels, label_lengths,
+                                      input_lengths);
+        else
+            return rnnt.score_forward(activations, costs, flat_labels,
+                                      label_lengths, input_lengths);
+#else
+        std::cerr << "GPU execution requested, but not compiled with GPU support" << std::endl;
+        return RNNT_STATUS_EXECUTION_FAILED;
+#endif
+    } else {
+        return RNNT_STATUS_INVALID_VALUE;
+    }
+}
+
 }
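Usage note (not part of the patch): the sketch below shows one way the new compute_rnnt_loss_half entry point could be driven directly from C++ once this change is applied. The get_workspace_size and compute_rnnt_loss_half calls mirror the GPU paths in binding.cpp above; the helper name, the rnntOptions defaults, the cudaMalloc-based workspace allocation, and the assumption that all buffers are prepared exactly as the existing float GPU path prepares its tensors are illustrative assumptions, not part of this diff.

// Hypothetical standalone driver for the half-precision entry point.
// Assumes device buffers are already populated the same way the float GPU
// path in binding.cpp populates them, and that rnnt.h is built with
// WARPRNNT_ENABLE_GPU so the half declaration is visible.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "rnnt.h"

rnntStatus_t rnnt_loss_half_example(const half* acts, half* grads,
                                    const int* labels,
                                    const int* label_lens,
                                    const int* input_lens,
                                    half* costs,
                                    int alphabet_size, int minibatch,
                                    int maxT, int maxU,
                                    cudaStream_t stream) {
    // Option fields are the ones the new entrypoint reads; the values here
    // are illustrative defaults, not mandated by the patch.
    rnntOptions options{};
    options.loc = RNNT_GPU;
    options.stream = stream;
    options.maxT = maxT;
    options.maxU = maxU;
    options.blank_label = 0;
    options.fastemit_lambda = 0.0f;
    options.num_threads = 1;

    // Same workspace query the binding uses for its GPU cases.
    size_t workspace_bytes = 0;
    get_workspace_size(maxT, maxU, minibatch, true, &workspace_bytes);

    // The binding uses the c10 caching allocator; plain cudaMalloc is used
    // here only to keep the sketch self-contained.
    void* workspace = nullptr;
    cudaMalloc(&workspace, workspace_bytes);

    rnntStatus_t status = compute_rnnt_loss_half(acts, grads,
                                                 labels, label_lens, input_lens,
                                                 alphabet_size, minibatch,
                                                 costs, workspace, options);
    cudaFree(workspace);
    return status;
}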