Commit 84dc8c4

r-zenine authored and facebook-github-bot committed
Adds a workaround for ScalarType::Byte for cuda (pytorch#35027)
Summary: This PR adds a workaround for `cuda` for `ScalarType::Byte` in the `AT_DISPATCH_*` macros, as discussed in pytorch#34826.

Pull Request resolved: pytorch#35027
Differential Revision: D20596555
Pulled By: colesbury
fbshipit-source-id: 72e842603723a5aa146e4224e79befafc62f2624
1 parent 39a101d commit 84dc8c4

File tree

1 file changed: +11 -0 lines changed


c10/core/ScalarType.h

@@ -108,6 +108,17 @@ struct ScalarTypeToCPPType<c10::ScalarType::Bool> {
   static type t;
 };
 
+template<>
+struct ScalarTypeToCPPType<c10::ScalarType::Byte> {
+  using type = uint8_t;
+
+  // This is a workaround for a CUDA bug which prevents ::detail::ScalarTypeToCType<T>::type from being used directly
+  // due to an ambiguous reference which can't be resolved. For some reason it can't pick between at::detail and at::cuda::detail.
+  // For a repro example, please see: https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba
+  // TODO: remove once the bug is fixed.
+  static type t;
+};
+
 template<>
 struct ScalarTypeToCPPType<c10::ScalarType::Long> {
   using type = int64_t;
