diff --git a/nestedtensor/csrc/cuda/padding.cu b/nestedtensor/csrc/cuda/padding.cu index fb9550d3..61ae6166 100644 --- a/nestedtensor/csrc/cuda/padding.cu +++ b/nestedtensor/csrc/cuda/padding.cu @@ -19,8 +19,8 @@ void add_padding_1( const int* output_sizes, const int batch_size) { - const int batch_id = blockIdx.x; - const int grid_id = blockIdx.y; + const int batch_id = blockIdx.y; + const int grid_id = blockIdx.x; const int tid = threadIdx.x + grid_id * 256; const int grainsize = 16 * 256; const int batch_input_offset = offsets[batch_id]; @@ -59,8 +59,8 @@ void add_padding_2( const int* output_sizes, const int batch_size) { - const int batch_id = blockIdx.x; - const int grid_id = blockIdx.y; + const int batch_id = blockIdx.y; + const int grid_id = blockIdx.x; const int tid = threadIdx.x + grid_id * 256; const int grainsize = 16 * 256; const int offset = offsets[batch_id]; @@ -92,7 +92,7 @@ void add_padding_2( } } -template +template __global__ void add_padding_3( const T* input, @@ -101,40 +101,35 @@ void add_padding_3( const int* offsets, const int* input_sizes, int input_dim, - const int* output_sizes, - const int batch_size) + int output_sizes_3, + int output_sizes_2_3, + int output_numel) { const int batch_id = blockIdx.x; - const int grid_id = blockIdx.y; - const int tid = threadIdx.x + grid_id * 256; - const int grainsize = 16 * 256; - const int offset = offsets[batch_id]; + const int i0 = blockIdx.y; + const int tid = threadIdx.x; const int* sizes_i = input_sizes + batch_id * input_dim; - const int numel_i = sizes_i[0] * sizes_i[1] * sizes_i[2]; - const int output_offset = batch_id * output_sizes[1] * output_sizes[2] * output_sizes[3]; - const int output_numel = output_sizes[1] * output_sizes[2] * output_sizes[3]; - for (int ii = 0; ii < (output_numel / grainsize); ii++) { - const int i = ii * grainsize + tid; - const int i0 = i / (output_sizes[2] * output_sizes[3]); - const int i1 = (i % (output_sizes[2] * output_sizes[3])) / output_sizes[3]; - const int i2 = i % output_sizes[3]; - if (i0 < sizes_i[0] && i1 < sizes_i[1] && i2 < sizes_i[2]) { - const int input_offset = offset + i0 * (sizes_i[1] * sizes_i[2]) + i1 * sizes_i[2] + i2; - output[output_offset + i] = input[input_offset]; - } else { - output[output_offset + i] = padding_value; + const int sizes_0 = sizes_i[0]; + int i = tid; + output = output + batch_id * output_numel + i0 * output_sizes_2_3; + if (i0 < sizes_0) { + const int sizes_1 = sizes_i[1]; + const int sizes_2 = sizes_i[2]; + const int sizes_1_2 = sizes_1 * sizes_2; + input = input + offsets[batch_id] + i0 * sizes_1_2; + bool valid_0 = i0 < sizes_0; + for (;i < output_sizes_2_3;) { + const int i1 = i / output_sizes_3; + const int i2 = i % output_sizes_3; + const bool valid = i1 < sizes_1 && i2 < sizes_2; + const int input_offset = valid ? i1 * sizes_2 + i2 : 0; + output[i] = valid ? input[input_offset] : padding_value; + i += grainsize; } - } - const int i = (output_numel / grainsize) * grainsize + tid; - if (i < output_numel) { - const int i0 = i / (output_sizes[2] * output_sizes[3]); - const int i1 = (i % (output_sizes[2] * output_sizes[3])) / output_sizes[3]; - const int i2 = i % output_sizes[3]; - if (i0 < sizes_i[0] && i1 < sizes_i[1] && i2 < sizes_i[2]) { - const int input_offset = offset + i0 * (sizes_i[1] * sizes_i[2]) + i1 * sizes_i[2] + i2; - output[output_offset + i] = input[input_offset]; - } else { - output[output_offset + i] = padding_value; + } else { + for (;i < output_sizes_2_3;) { + output[i] = padding_value; + i += grainsize; } } } @@ -153,7 +148,7 @@ void add_padding_kernelLauncher( { dim3 grid; grid.x = batch_size; - grid.y = 16; + grid.y = output_sizes[1]; if (input_dim == 1) { add_padding_1<<>>( input, @@ -177,15 +172,16 @@ void add_padding_kernelLauncher( batch_size); } if (input_dim == 3) { - add_padding_3<<>>( + add_padding_3<<>>( input, output, padding_value, offsets, input_sizes, input_dim, - output_sizes, - batch_size); + output_sizes[3], + output_sizes[2] * output_sizes[3], + output_sizes[1] * output_sizes[2] * output_sizes[3]); } } diff --git a/nestedtensor/csrc/masking.cpp b/nestedtensor/csrc/masking.cpp index 2ca43b68..3a98cdcd 100644 --- a/nestedtensor/csrc/masking.cpp +++ b/nestedtensor/csrc/masking.cpp @@ -507,23 +507,21 @@ Tensor to_padded_tensor(Tensor nt, double padding) { std::vector new_size = padded_size_from_efficient_size(esize); at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream(); Tensor output = at::empty(IntArrayRef(new_size), nt_buffer.options()); - Tensor new_size_tensor = torch::tensor(new_size); + Tensor new_size_tensor = torch::tensor(new_size, torch::kInt32); int64_t input_dim = nt_sizes.size(1); int64_t batch_size = nt_sizes.size(0); - at::Tensor metadata = at::cat({new_size_tensor, offsets, nt_sizes.reshape(-1)}); + at::Tensor metadata = at::cat({offsets, nt_sizes.reshape(-1)}); metadata = metadata.to(at::Device(kCUDA), torch::kInt32, true, true); std::vector split_sizes; - split_sizes.push_back(new_size_tensor.numel()); split_sizes.push_back(offsets.numel()); split_sizes.push_back(nt_sizes.numel()); std::vector split = at::split_with_sizes(metadata, IntArrayRef(split_sizes), 0); - new_size_tensor = split[0]; - offsets = split[1]; - nt_sizes = split[2]; + offsets = split[0]; + nt_sizes = split[1]; if (nt_buffer.dtype() == torch::kFloat16) { nested_tensor::cuda::add_padding_kernelLauncher( diff --git a/nestedtensor/version.py b/nestedtensor/version.py index 46a23e4b..8ec67334 100644 --- a/nestedtensor/version.py +++ b/nestedtensor/version.py @@ -1,5 +1,5 @@ -__version__ = '0.1.4+33fb247' -git_version = '33fb2477c856f8185f1e9c1e9a6ca28065e43cf9' +__version__ = '0.1.4+2719e68' +git_version = '2719e6833bcdec69084953381aa05e53e9df9baa' from nestedtensor import _C if hasattr(_C, 'CUDA_VERSION'): cuda = _C.CUDA_VERSION