diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h b/torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h
index b565a4beffa13..fc6831f24c423 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h
@@ -1,4 +1,5 @@
-#include
+#pragma once
+
 #include
 #include
 #include
@@ -6,44 +7,16 @@
 #include
 #include
-#include
-#include
 #include
+// Tests go in torch::jit
 namespace torch {
 namespace jit {
-namespace fuser {
-namespace cuda {
-
-inline bool deviceMajorMinorCheck(int major, int minor = 0) {
-  auto dev_prop = at::cuda::getCurrentDeviceProperties();
-  if (dev_prop->major < major ||
-      (dev_prop->major == major && dev_prop->minor < minor)) {
-    return false;
-  }
-  return true;
-}
-inline int deviceSMCount() {
-  int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
-  return sm_count;
-}
+using namespace torch::jit::fuser::cuda;
-class NVFuserTest : public ::testing::Test {
- protected:
-  void SetUp() override {
-    // requires PASCAL or newer
-    if (!deviceMajorMinorCheck(6)) {
-      GTEST_SKIP() << "skipping tests on pre-PASCAL GPUs";
-    }
-    setFillAllocationWithNan(true);
-  }
-
-  void TearDown() override {
-    c10::cuda::CUDACachingAllocator::emptyCache();
-  }
-};
+namespace {
 struct ValidationConstants {
   // Tolerances generated from randn + add + sum fusion
@@ -74,8 +47,6 @@ struct ValidationConstants {
   double base_float_rel_tol = -1;
 };
-namespace {
-
 // Returns abs and relative values to use for validation
 std::pair<double, double> getTolerance(
     DataType dtype,
@@ -338,15 +309,13 @@ ExpressionEvaluator bindInputsAndLaunchParams(
   return expr_eval;
 }
-} // namespace
-
 // Validation will look through the fusion and figure out how many elements were
 // reduced to create each output. It will then compute a tolerance to use for
 // allclose based on experimental results. The experimental results were based
 // on adding two tensors then summing them. This of course has an assumption
 // that we're always summing values between -2 and 2. If we start summing values
 // larger than that this approach might not hold.
-inline void testValidate(
+void testValidate(
     Fusion* fusion,
     const std::vector<at::Tensor>& fusion_outputs,
     const at::ArrayRef<IValue>& aten_inputs,
@@ -466,18 +435,6 @@ inline void testValidate(
   }
 }
-inline void clearL2Cache() {
-  torch::NoGradGuard no_grad;
-  auto l2_cache_size = at::cuda::getCurrentDeviceProperties()->l2CacheSize;
-  auto options =
-      torch::TensorOptions().dtype(torch::kFloat32).device(at::kCUDA, 0);
-
-  auto l2_elems = l2_cache_size / 4;
-  torch::Tensor t0 = torch::empty(l2_elems, options);
-  torch::Tensor t1 = torch::clone(t0);
-};
-
-} // namespace cuda
-} // namespace fuser
+} // namespace
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/test/test_utils.h b/torch/csrc/jit/codegen/cuda/test/test_utils.h
index c8bf546daf4a0..fb83459952a2e 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_utils.h
+++ b/torch/csrc/jit/codegen/cuda/test/test_utils.h
@@ -1,9 +1,17 @@
 #pragma once
-#include
-
+#include
+#include
 #include
+#include
+#include
+#include
+
+#include
+
+#include
+
 // Tests go in torch::jit
 namespace torch {
 namespace jit {
@@ -84,6 +92,45 @@ int64_t prime_numbers[] = {
     1087, 1091, 1093, 1097, 1103, 1109, 1117, 1123, 1129, 1151, 1153, 1163,
     1171, 1181, 1187, 1193, 1201, 1213, 1217, 1223};
+bool deviceMajorMinorCheck(int major, int minor = 0) {
+  auto dev_prop = at::cuda::getCurrentDeviceProperties();
+  if (dev_prop->major < major ||
+      (dev_prop->major == major && dev_prop->minor < minor)) {
+    return false;
+  }
+  return true;
+}
+
+int deviceSMCount() {
+  int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+  return sm_count;
+}
+
+void clearL2Cache() {
+  torch::NoGradGuard no_grad;
+  auto l2_cache_size = at::cuda::getCurrentDeviceProperties()->l2CacheSize;
+  auto options =
+      torch::TensorOptions().dtype(torch::kFloat32).device(at::kCUDA, 0);
+
+  auto l2_elems = l2_cache_size / 4;
+  torch::Tensor t0 = torch::empty(l2_elems, options);
+  torch::Tensor t1 = torch::clone(t0);
+};
+
 } // namespace
+
+// Fixture class must be uniquely identified, i.e., can't be in an
+// anonymous namespace
+class NVFuserTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    // requires PASCAL or newer
+    if (!deviceMajorMinorCheck(6)) {
+      GTEST_SKIP() << "skipping tests on pre-PASCAL GPUs";
+    }
+    setFillAllocationWithNan(true);
+  }
+};
+
 } // namespace jit
 } // namespace torch
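For context, a minimal sketch of how a test consumes the relocated pieces after this change. The test body is hypothetical: the fusion-building helpers (`makeSymbolicTensor`, `IrBuilder`, `FusionExecutor`) follow existing nvfuser test idioms rather than anything introduced by this diff, and the test name and tensor shapes are invented for illustration.

```cpp
#include <torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h>
#include <torch/csrc/jit/codegen/cuda/test/test_utils.h>

namespace torch {
namespace jit {

// NVFuserTest now lives in test_utils.h, outside the anonymous namespace,
// so TEST_F can name the fixture from any test translation unit.
TEST_F(NVFuserTest, FusionAddSketch_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Build a trivial pointwise fusion: out = in + 1.
  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  TensorView* tv1 = add(tv0, IrBuilder::create<Double>(1.0));
  fusion.addOutput(tv1);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({8, 8}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0});
  auto cg_outputs = fe.runFusion({t0});

  // testValidate derives an allclose tolerance from the reduction sizes it
  // finds in the fusion (none here), per the comment in test_gpu_validator.h.
  testValidate(&fusion, cg_outputs, {t0}, {t0 + 1.0}, __LINE__, __FILE__);
}

} // namespace jit
} // namespace torch
```

The fixture has to leave the anonymous namespace for the reason the new comment states: an anonymous-namespace class has a distinct type in every translation unit, so gtest could not treat NVFuserTest as one uniquely identified fixture across the test files that reference it.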
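Likewise, a hedged sketch of how the helpers now exported from test_utils.h tend to be used together; `runKernelUnderTest()` is a hypothetical stand-in for the timed work, not a function in this diff.

```cpp
TEST_F(NVFuserTest, FusionAmpereOnlySketch_CUDA) {
  // deviceMajorMinorCheck compares against the current device's compute
  // capability; SetUp() already enforces >= 6.0 (Pascal) for every test.
  if (!deviceMajorMinorCheck(8)) {
    GTEST_SKIP() << "requires Ampere (sm_80) or newer";
  }

  for (int rep = 0; rep < 3; ++rep) {
    // clearL2Cache writes an L2-sized float tensor and clones it, evicting
    // prior contents so each repetition starts from a cold cache.
    clearL2Cache();
    runKernelUnderTest(); // hypothetical stand-in for the measured kernel
  }
}
```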