From 037a75a42048f1d8a9c30efb466f1ffbfd2894ad Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 7 Jul 2022 13:08:13 -0700 Subject: [PATCH] Dropout prob extremal patch (#1804) Fixes #1799 1. Updates rand_like by changing output==1 to 0 via `where`; 2. Patches codegen float output. --- test/test_jit_cuda_fuser.py | 43 ++++++++++++++++--------- torch/csrc/jit/codegen/cuda/arith.cpp | 25 +++++++++++++- torch/csrc/jit/codegen/cuda/codegen.cpp | 19 +++++++---- 3 files changed, 63 insertions(+), 24 deletions(-) diff --git a/test/test_jit_cuda_fuser.py b/test/test_jit_cuda_fuser.py index dee87fae0935fd..45c762b2b426ad 100644 --- a/test/test_jit_cuda_fuser.py +++ b/test/test_jit_cuda_fuser.py @@ -82,6 +82,11 @@ def is_pre_volta(): TEST_BF16 = RUN_NVFUSER and torch.cuda.is_bf16_supported() +TEST_LARGE_TENSOR = RUN_NVFUSER +if RUN_NVFUSER: + torch.ones(1).cuda() # initialize cuda context + TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 12e9 + class CudaFuserTestOptions(): def __init__(self): self.old_cpu_fuse = torch._C._jit_can_fuse_on_cpu() @@ -183,23 +188,27 @@ def tearDown(self): self.cuda_fuser_options.restore() super(TestCudaFuser, self).tearDown() - def _run_helper(self, jit_op, op, *args, check_stride=False, num_fusion=1): - torch.cuda.manual_seed_all(123) - jit_o = jit_op(*args) - torch.cuda.manual_seed_all(123) + def _run_helper(self, jit_op, op, *args, check_stride=False, num_fusion=1, check_runs=1): + seed = 123 + torch.cuda.manual_seed_all(seed) jit_o = jit_op(*args) - torch.cuda.manual_seed_all(123) - o = op(*args) - if type(jit_o) is torch.Tensor: - jit_o = [jit_o, ] - o = [o, ] + for i in range(check_runs): + torch.cuda.manual_seed_all(seed + i) + jit_o = jit_op(*args) + torch.cuda.manual_seed_all(seed + i) + o = op(*args) + + if type(jit_o) is torch.Tensor: + jit_o = [jit_o, ] + o = [o, ] + + for oo, jit_oo in zip(o, jit_o): + self.assertEqual(oo.dtype, jit_oo.dtype) + self.assertEqual(oo, jit_oo) + if check_stride: + self.assertEqual(oo.stride(), jit_oo.stride()) - for oo, jit_oo in zip(o, jit_o): - self.assertEqual(oo.dtype, jit_oo.dtype) - self.assertEqual(oo, jit_oo) - if check_stride: - self.assertEqual(oo.stride(), jit_oo.stride()) self.assertGraphContainsExactly(jit_op.graph_for(*args), FUSION_GUARD, num_fusion, consider_subgraphs=True) def _run_training_helper(self, jit_op, op, grads, *args): @@ -2562,13 +2571,14 @@ def t(x: torch.Tensor, p: float, train: bool): self._run_helper(t_jit, t, x, 0.15, False) + @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") def test_dropout_train_nograd_fusion(self): dtype = torch.float device = "cuda" - x = torch.randn([10, 4, 8], dtype=dtype, device=device) + x = torch.randn([64, 128, 1024], dtype=dtype, device=device) def t(x: torch.Tensor, p: float, train: bool): o = torch.nn.functional.dropout(x, p, training=train) @@ -2577,7 +2587,8 @@ def t(x: torch.Tensor, p: float, train: bool): t_jit = torch.jit.script(t) - self._run_helper(t_jit, t, x, 0.0, True) + self._run_helper(t_jit, t, x, 0.0, True, check_runs=20) + self._run_helper(t_jit, t, x, 1.0, True, check_runs=20) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp index 32edaec3c1faf1..5adaf5b28b59d6 100644 --- a/torch/csrc/jit/codegen/cuda/arith.cpp +++ b/torch/csrc/jit/codegen/cuda/arith.cpp @@ -458,7 +458,6 @@ TensorView* unaryOp( } NVFUSER_DEFINE_UNARY_OP(set, Set) -NVFUSER_DEFINE_UNARY_OP(randlike, RandLike) NVFUSER_DEFINE_UNARY_OP(ceil, Ceil) NVFUSER_DEFINE_UNARY_OP(floor, Floor) NVFUSER_DEFINE_UNARY_OP(frac, Frac) @@ -469,6 +468,30 @@ NVFUSER_DEFINE_UNARY_OP(silu, Silu) NVFUSER_DEFINE_UNARY_OP(trunc, Trunc) #undef NVFUSER_DEFINE_UNARY_OP +Val* randlike(Val* v) { + TORCH_CHECK( + isFloatingPointType(v->dtype()), + "input must have floating point type, but got ", + v->dtype()); + auto rand_vals = unaryOp(UnaryOpType::RandLike, v); + return where( + eq(rand_vals, IrBuilder::create(1.0)), + IrBuilder::create(0.0), + rand_vals); +} + +TensorView* randlike(TensorView* v) { + TORCH_CHECK( + isFloatingPointType(v->dtype()), + "input must have floating point type, but got ", + v->dtype()); + auto rand_vals = unaryOp(UnaryOpType::RandLike, v); + return where( + eq(rand_vals, IrBuilder::create(1.0)), + IrBuilder::create(0.0), + rand_vals); +} + Val* bitwise_not(Val* v) { TORCH_CHECK( isIntegralType(v->dtype()) || v->dtype() == DataType::Bool, diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 2d8d3e3f2d60fd..20b104ca2c6709 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -170,7 +170,15 @@ class CudaKernelGenerator : private OptOutConstDispatch { } private: - explicit CudaKernelGenerator(const kir::Kernel* kernel) : kernel_(kernel) {} + explicit CudaKernelGenerator(const kir::Kernel* kernel) : kernel_(kernel) { + initStringStreamFormat(code_); + } + + void initStringStreamFormat(std::stringstream& ss) { + const int digits = std::numeric_limits::max_digits10; + ss.imbue(std::locale("C")); + ss << std::scientific << std::setprecision(digits); + } // Generates the kernel function declaration void genDeclaration(const std::string& kernel_name) { @@ -358,6 +366,7 @@ class CudaKernelGenerator : private OptOutConstDispatch { std::string gen(const Statement* stmt) { std::stringstream tmp_code; + initStringStreamFormat(tmp_code); std::swap(tmp_code, code_); OptOutConstDispatch::handle(stmt); std::swap(tmp_code, code_); @@ -419,9 +428,7 @@ class CudaKernelGenerator : private OptOutConstDispatch { } else if (std::isnan(val)) { code_ << "NAN"; } else { - const int digits = - std::numeric_limits::max_digits10; - code_ << std::setprecision(digits) << val; + code_ << val; } } else { code_ << varName(d); @@ -454,9 +461,7 @@ class CudaKernelGenerator : private OptOutConstDispatch { if (def != nullptr && !has_alloc) { code_ << "(" << gen(def) << ")"; } else if (c->isConst()) { - const int digits = std::numeric_limits::max_digits10; - code_ << "std::complex" << std::setprecision(digits) - << *c->value(); + code_ << "std::complex" << *c->value(); } else { code_ << varName(c); }