
Exclude unsupported data types (#1951)
* Exclude unsupported data types
naoyam committed Sep 2, 2022
1 parent 992e17c commit ddc01e4
Showing 3 changed files with 37 additions and 0 deletions.
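In brief: the commit introduces an isSupportedTypeByDevice() helper in the nvFuser type utilities and calls it from the tensor-factory tests so that data types the current GPU cannot execute are skipped. The only type excluded so far is BFloat16 on devices with compute capability below 8.0.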
22 changes: 22 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu_tensor_factories.cpp
@@ -43,6 +43,9 @@ TEST_F(NVFuserTest, FusionStandaloneFull_CUDA) {
   fusion->addInput(fill_val2);
   fusion->addInput(fill_val3);
   for (auto dtype : dtypes) {
+    if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) {
+      continue;
+    }
     auto out_tv = full({size}, fill_val1, aten_to_data_type(dtype));
     fusion->addOutput(out_tv);
     out_tv = full({size, size}, fill_val2, aten_to_data_type(dtype));
@@ -57,6 +60,9 @@ TEST_F(NVFuserTest, FusionStandaloneFull_CUDA) {
   std::vector<at::Tensor> expect;
   expect.reserve(dtypes.size());
   for (auto dtype : dtypes) {
+    if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) {
+      continue;
+    }
     const auto options =
         at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
     expect.emplace_back(at::full({size}, 11, options));
@@ -94,6 +100,9 @@ TEST_F(NVFuserTest, FusionStandaloneZeros_CUDA) {
   Val* size = IrBuilder::create<Int>();
   fusion->addInput(size);
   for (auto dtype : dtypes) {
+    if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) {
+      continue;
+    }
     auto out_tv = zeros({size}, aten_to_data_type(dtype));
     fusion->addOutput(out_tv);
     out_tv = zeros({size, size}, aten_to_data_type(dtype));
@@ -108,6 +117,9 @@ TEST_F(NVFuserTest, FusionStandaloneZeros_CUDA) {
   std::vector<at::Tensor> expect;
   expect.reserve(dtypes.size());
   for (auto dtype : dtypes) {
+    if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) {
+      continue;
+    }
     const auto options =
         at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
     expect.emplace_back(at::zeros({size}, options));
@@ -145,6 +157,9 @@ TEST_F(NVFuserTest, FusionStandaloneOnes_CUDA) {
   Val* size = IrBuilder::create<Int>();
   fusion->addInput(size);
   for (auto dtype : dtypes) {
+    if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) {
+      continue;
+    }
     auto out_tv = ones({size}, aten_to_data_type(dtype));
     fusion->addOutput(out_tv);
     out_tv = ones({size, size}, aten_to_data_type(dtype));
@@ -159,6 +174,9 @@ TEST_F(NVFuserTest, FusionStandaloneOnes_CUDA) {
   std::vector<at::Tensor> expect;
   expect.reserve(dtypes.size());
   for (auto dtype : dtypes) {
+    if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) {
+      continue;
+    }
     const auto options =
         at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
     expect.emplace_back(at::ones({size}, options));
@@ -183,6 +201,10 @@ TEST_F(NVFuserTest, FusionStandaloneARange_CUDA) {
   auto dtypes = {kFloat, kLong, kDouble};
 
   for (auto dtype : dtypes) {
+    if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) {
+      continue;
+    }
+
     auto fusion = std::make_unique<Fusion>();
     FusionGuard fg(fusion.get());
 
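Each test applies the new guard twice: once in the loop that adds fusion outputs and once in the loop that builds the expect vector. Because expected tensors are paired with fusion outputs by position, both loops must skip exactly the same dtypes. A minimal, self-contained sketch of that invariant (hypothetical stand-ins, not nvFuser code):

#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for isSupportedTypeByDevice(): pretend bfloat16 is
// the one dtype the current device cannot handle.
bool supported(const std::string& dtype) {
  return dtype != "bfloat16";
}

int main() {
  const std::vector<std::string> dtypes = {"float", "bfloat16", "double"};
  std::vector<std::string> outputs;
  std::vector<std::string> expect;
  // First loop: what the fusion would produce for each kept dtype.
  for (const auto& d : dtypes) {
    if (!supported(d)) {
      continue;
    }
    outputs.push_back("full_" + d);
  }
  // Second loop: the expected results, filtered with the identical guard.
  for (const auto& d : dtypes) {
    if (!supported(d)) {
      continue;
    }
    expect.push_back("full_" + d);
  }
  // Skipping a dtype in only one loop would shift every later element and
  // pair outputs[i] with the wrong expect[i].
  assert(outputs.size() == expect.size());
  for (size_t i = 0; i < outputs.size(); ++i) {
    assert(outputs[i] == expect[i]);
  }
  return 0;
}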
13 changes: 13 additions & 0 deletions torch/csrc/jit/codegen/cuda/type.cpp
@@ -1,5 +1,7 @@
 #include <torch/csrc/jit/codegen/cuda/type.h>
 
+#include <ATen/cuda/CUDAContext.h>
+
 #include <stdexcept>
 #include <unordered_map>
 
@@ -160,6 +162,17 @@ DataType getTypeFromComplexType(DataType dtype) {
   }
 }
 
+bool isSupportedTypeByDevice(DataType dtype) {
+  auto prop = at::cuda::getCurrentDeviceProperties();
+  auto major_ver = prop->major;
+  switch (dtype) {
+    case DataType::BFloat16:
+      return major_ver >= 8;
+    default:
+      return true;
+  }
+}
+
 bool isIntegerOp(const BinaryOpType bopt) {
   return bopt >= BinaryOpType::Mod && bopt <= BinaryOpType::Rshift;
 }
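The helper reads the current device's compute capability through at::cuda::getCurrentDeviceProperties() and reports BFloat16 as unsupported below major version 8 (pre-Ampere GPUs); every other dtype passes. For readers outside the PyTorch tree, here is a standalone sketch of the same check against the raw CUDA runtime API (assumes a CUDA toolchain and at least one visible device):

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  cudaDeviceProp prop;
  if (cudaGetDeviceProperties(&prop, /*device=*/0) != cudaSuccess) {
    std::fprintf(stderr, "no usable CUDA device\n");
    return 1;
  }
  // Same rule as isSupportedTypeByDevice(): bfloat16 needs compute
  // capability 8.0 (Ampere) or newer.
  const bool bf16_ok = prop.major >= 8;
  std::printf("sm_%d%d: bfloat16 %s\n", prop.major, prop.minor,
              bf16_ok ? "supported" : "unsupported");
  return 0;
}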
2 changes: 2 additions & 0 deletions torch/csrc/jit/codegen/cuda/type.h
@@ -101,6 +101,8 @@ int getVectorSizeFromType(DataType dtype);
 DataType getTypeFromVectorType(DataType dtype);
 // Return the corresponding scalar of a complex type
 DataType getTypeFromComplexType(DataType dtype);
+// Return if the datatype is supported on the current device
+TORCH_CUDA_CU_API bool isSupportedTypeByDevice(DataType dtype);
 
 enum class ExprType {
   Invalid,
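Unlike the neighboring declarations visible above, the new one carries TORCH_CUDA_CU_API, the export annotation that makes the symbol visible outside the library, so external callers such as the updated C++ tests can link against it.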
