AT_DISPATCH_FLOATING_TYPES_AND2 fails with ScalarType::Byte #34826

Closed
r-zenine opened this issue Mar 16, 2020 · 4 comments
Assignees
Labels
module: cpp-extensions (Related to torch.utils.cpp_extension)
module: dispatch (DispatchStub, Type, void pointer table, c10 dispatch)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@r-zenine (Contributor) commented Mar 16, 2020

🐛 Bug

Using AT_DISPATCH_FLOATING_TYPES_AND2 in CUDA code causes a compilation error.

I am trying to enable uint8 support for nearest-neighbor upsampling. To do so, I replaced AT_DISPATCH_FLOATING_TYPES_AND_HALF with AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Byte, ...), which causes a compilation error. On the other hand, AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...) works fine. A minimal sketch of both variants is shown below.
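
For concreteness, here is a minimal sketch of the two variants (the dispatch name, `input`, and the lambda body are illustrative placeholders, not the actual UpSampleNearest3d.cu code):

// Works: dispatches over floating types plus Half.
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    input.scalar_type(), "upsample_nearest3d", [&] {
      using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
      // ... launch kernel templated on scalar_t / accscalar_t ...
    });

// Fails to compile: floating types plus Half plus Byte.
AT_DISPATCH_FLOATING_TYPES_AND2(
    at::ScalarType::Half, at::ScalarType::Byte,
    input.scalar_type(), "upsample_nearest3d", [&] {
      using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
      // ... launch kernel templated on scalar_t / accscalar_t ...
    });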

To Reproduce

Steps to reproduce the behavior:

  1. Check out commit bdd7dbfd4b
  2. Change /home/ryad/workspace/pytorch/aten/src/ATen/native/cuda/UpSampleNearest3d.cu(176) to use AT_DISPATCH_FLOATING_TYPES_AND2
  3. Compile
[1/3] Building NVCC (Device) object caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/torch_cuda_generated_UpSampleNearest3d.cu.o
FAILED: caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/torch_cuda_generated_UpSampleNearest3d.cu.o 
cd /home/ryad/workspace/pytorch/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda && /home/ryad/miniconda3/envs/pytorch-dev/bin/cmake -E make_directory /home/ryad/workspace/pytorch/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/. && /home/ryad/miniconda3/envs/pytorch-dev/bin/cmake -D verbose:BOOL=OFF -D build_configuration:STRING=Release -D generated_file:STRING=/home/ryad/workspace/pytorch/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/./torch_cuda_generated_UpSampleNearest3d.cu.o -D generated_cubin_file:STRING=/home/ryad/workspace/pytorch/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/./torch_cuda_generated_UpSampleNearest3d.cu.o.cubin.txt -P /home/ryad/workspace/pytorch/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/torch_cuda_generated_UpSampleNearest3d.cu.o.Release.cmake
/home/ryad/workspace/pytorch/aten/src/ATen/native/cuda/UpSampleNearest3d.cu(176): error: incomplete type is not allowed

/home/ryad/workspace/pytorch/aten/src/ATen/AccumulateType.h(48): error: class "at::AccumulateType<<error-type>, true>" has no member "type"
          detected during instantiation of type "at::acc_type<<error-type>, true>" 
/home/ryad/workspace/pytorch/aten/src/ATen/native/cuda/UpSampleNearest3d.cu(176): here

/home/ryad/workspace/pytorch/aten/src/ATen/native/cuda/UpSampleNearest3d.cu(269): error: incomplete type is not allowed

3 errors detected in the compilation of "/tmp/tmpxft_00007f4c_00000000-6_UpSampleNearest3d.cpp1.ii".
CMake Error at torch_cuda_generated_UpSampleNearest3d.cu.o.Release.cmake:281 (message):
  Error generating file
  /home/ryad/workspace/pytorch/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/./torch_cuda_generated_UpSampleNearest3d.cu.o


[2/3] Building NVCC (Device) object caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/torch_cuda_generated_UpSampleNearest2d.cu.o
/home/ryad/workspace/pytorch/aten/src/ATen/native/cuda/UpSampleNearest2d.cu: In function ‘at::Tensor at::native::upsample_nearest2d_cuda(const at::Tensor&, c10::IntArrayRef, c10::optional<double>, c10::optional<double>)’:
/home/ryad/workspace/pytorch/aten/src/ATen/native/cuda/UpSampleNearest2d.cu:305:44: warning: ‘c10::MemoryFormat c10::get_contiguous_memory_format()’ is deprecated [-Wdeprecated-declarations]
   Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
                                            ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/ryad/workspace/pytorch/c10/core/MemoryFormat.h:34:36: note: declared here
 C10_DEPRECATED inline MemoryFormat get_contiguous_memory_format() {
                                    ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/ryad/workspace/pytorch/aten/src/ATen/native/cuda/UpSampleNearest2d.cu:305:73: warning: ‘c10::MemoryFormat c10::get_contiguous_memory_format()’ is deprecated [-Wdeprecated-declarations]
   Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
                                                                         ^
/home/ryad/workspace/pytorch/c10/core/MemoryFormat.h:34:36: note: declared here
 C10_DEPRECATED inline MemoryFormat get_contiguous_memory_format() {
                                    ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/ryad/workspace/pytorch/aten/src/ATen/native/cuda/UpSampleNearest2d.cu: In function ‘at::Tensor at::native::upsample_nearest2d_backward_cuda(const at::Tensor&, c10::IntArrayRef, c10::IntArrayRef, c10::optional<double>, c10::optional<double>)’:
/home/ryad/workspace/pytorch/aten/src/ATen/native/cuda/UpSampleNearest2d.cu:328:54: warning: ‘c10::MemoryFormat c10::get_contiguous_memory_format()’ is deprecated [-Wdeprecated-declarations]
   Tensor grad_input = at::empty_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
                                                      ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/ryad/workspace/pytorch/c10/core/MemoryFormat.h:34:36: note: declared here
 C10_DEPRECATED inline MemoryFormat get_contiguous_memory_format() {
                                    ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/ryad/workspace/pytorch/aten/src/ATen/native/cuda/UpSampleNearest2d.cu:328:83: warning: ‘c10::MemoryFormat c10::get_contiguous_memory_format()’ is deprecated [-Wdeprecated-declarations]
   Tensor grad_input = at::empty_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
                                                                                   ^
/home/ryad/workspace/pytorch/c10/core/MemoryFormat.h:34:36: note: declared here
 C10_DEPRECATED inline MemoryFormat get_contiguous_memory_format() {
                                    ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
ninja: build stopped: subcommand failed.

Expected behavior

My understanding is that the code should compile, since the equivalent dispatch using AT_DISPATCH_ALL_TYPES_AND compiles fine.

cc @yf225

@colesbury (Member) commented

You probably also need to add a ScalarTypeToCPPType mapping from Byte to uint8_t here:

template <>
struct ScalarTypeToCPPType<c10::ScalarType::Bool> {
  using type = bool;

  // This is a workaround for a CUDA bug that prevents
  // ::detail::ScalarTypeToCType<T>::type from being used directly, due to an
  // ambiguous reference that can't be resolved: for some reason the compiler
  // can't pick between at::detail and at::cuda::detail.
  // For a repro example, see: https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba
  // TODO: remove once the bug is fixed.
  static type t;
};

If that's not sufficient, there might be other similar mappings missing.
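
Following that suggestion, the missing specialization would presumably look like this (a sketch modeled on the Bool specialization above, not necessarily the final fix):

template <>
struct ScalarTypeToCPPType<c10::ScalarType::Byte> {
  using type = uint8_t;

  // Same CUDA workaround as for Bool above.
  static type t;
};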

@r-zenine (Contributor, Author) commented

Hi @colesbury,
Thanks for the tip. I'll give it a try.

@colesbury changed the title from "AT_DISPATCH_FLOATING_TYPES_AND2 causing compilation error." to "AT_DISPATCH_FLOATING_TYPES_AND2 fails with ScalarType::Byte" Mar 18, 2020
@colesbury added the module: cpp-extensions, module: dispatch, and triaged labels Mar 18, 2020
@colesbury (Member) commented

PyTorch devs: we might want to ensure that the AT_DISPATCH macros work with all the built-in scalar types.

@r-zenine (Contributor, Author) commented

Hi @colesbury ,

I'll submit the fix for Byte.

r-zenine added a commit to r-zenine/pytorch that referenced this issue Mar 19, 2020
Notes:
Due to a bug in AT_DISPATCH_FLOATING_TYPES_AND2 (see pytorch#34826), I used
AT_DISPATCH_ALL_TYPES_AND.
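
That workaround would look roughly like the following (a sketch under the same illustrative assumptions as above; the commit diff itself is not shown here):

// Workaround: dispatch over all integral and floating types plus Half,
// which covers Byte without going through AT_DISPATCH_FLOATING_TYPES_AND2.
AT_DISPATCH_ALL_TYPES_AND(
    at::ScalarType::Half,
    input.scalar_type(), "upsample_nearest3d", [&] {
      // ... launch kernel templated on scalar_t ...
    });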
facebook-github-bot pushed a commit that referenced this issue Mar 24, 2020
Summary:
This PR adds a workaround for `cuda` for `ScalarType::Byte` for the `AT_DISPATCH_*` macros.
As discussed here:
#34826
Pull Request resolved: #35027

Differential Revision: D20596555

Pulled By: colesbury

fbshipit-source-id: 72e842603723a5aa146e4224e79befafc62f2624
@gchanan self-assigned this May 7, 2020
gchanan added a commit that referenced this issue May 7, 2020
gchanan added a commit that referenced this issue May 7, 2020
gchanan added a commit that referenced this issue May 8, 2020
gchanan added a commit that referenced this issue May 8, 2020