From e71e1ecefe67219846070590bbed54bbc7416b79 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Thu, 11 Aug 2022 13:19:16 -0700 Subject: [PATCH] Fix name clash of RNG with shared memory (#1904) --- torch/csrc/jit/codegen/cuda/codegen.cpp | 38 +++++++++---------- .../jit/codegen/cuda/test/test_gpu_rng.cu | 34 +++++++++++++++++ 2 files changed, 53 insertions(+), 19 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 0d597dbaba0bd..1fa8425c465de 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -260,7 +260,7 @@ class CudaKernelGenerator : private OptOutConstDispatch { // Random number generator (optional) if (kernel_summary.max_rng_offsets >= 0) { - indent() << "auto offset = philox_args.captured_ ?\n"; + indent() << "auto philox_offset = philox_args.captured_ ?\n"; indent() << " static_cast(*(philox_args.offset_.ptr) + philox_args.offset_intragraph_) :\n"; indent() << " philox_args.offset_.val;\n"; @@ -290,18 +290,18 @@ class CudaKernelGenerator : private OptOutConstDispatch { << ") extern __shared__ char array[];\n"; if (has_dynamic_smem) { - indent() << "unsigned offset = 0;\n"; + indent() << "unsigned smem_offset = 0;\n"; } if (has_reductions || has_parallel_welford) { indent() << "void* shared_mem = array;\n"; if (has_dynamic_smem) { if (has_parallel_welford) { - indent() << "offset += " + indent() << "smem_offset += " << "((blockDim.x * blockDim.y * blockDim.z) * 3 * sizeof(" << kernel_summary.largest_smem_data_type << "));\n"; } else { - indent() << "offset += " + indent() << "smem_offset += " << "((blockDim.x * blockDim.y * blockDim.z) * sizeof(" << kernel_summary.largest_smem_data_type << "));\n"; } @@ -710,19 +710,19 @@ class CudaKernelGenerator : private OptOutConstDispatch { auto out_tv = uop->out()->as()->view(); auto index = genTensorIndex(uop->in()->as()); int multiple = out_tv->getDataType() == DataType::Double ? 2 : 4; - indent() << "nvfuser_index_t subseq" << uop->name() << " = (" << index - << ") / " << multiple << ";\n"; - indent() << "nvfuser_index_t component" << uop->name() << " = (" + indent() << "nvfuser_index_t rng_subseq" << uop->name() << " = (" + << index << ") / " << multiple << ";\n"; + indent() << "nvfuser_index_t rng_component" << uop->name() << " = (" << index << ") % " << multiple << ";\n"; - indent() << "nvfuser_index_t offset" << uop->name() << " = " + indent() << "nvfuser_index_t rng_offset" << uop->name() << " = " << uop->getRNGOffset() << ";\n"; - indent() << "if (rng_subseq != subseq" << uop->name() - << " || rng_offset != offset" << uop->name() << ") {\n"; - indent() << " rng_result = philox(philox_args.seed_, subseq" - << uop->name() << ", offset / 4 + offset" << uop->name() - << ");\n"; - indent() << " rng_subseq = subseq" << uop->name() << ";\n"; - indent() << " rng_offset = offset" << uop->name() << ";\n"; + indent() << "if (rng_subseq != rng_subseq" << uop->name() + << " || rng_offset != rng_offset" << uop->name() << ") {\n"; + indent() << " rng_result = philox(philox_args.seed_, rng_subseq" + << uop->name() << ", philox_offset / 4 + rng_offset" + << uop->name() << ");\n"; + indent() << " rng_subseq = rng_subseq" << uop->name() << ";\n"; + indent() << " rng_offset = rng_offset" << uop->name() << ";\n"; indent() << "}\n"; } @@ -764,7 +764,7 @@ class CudaKernelGenerator : private OptOutConstDispatch { code_ << "("; if (op_type == UnaryOpType::RandLike) { - code_ << "rng_result, component" << uop->name(); + code_ << "rng_result, rng_component" << uop->name(); } else { code_ << gen(uop->in()); } @@ -2329,15 +2329,15 @@ class CudaKernelGenerator : private OptOutConstDispatch { break; case MemoryType::Shared: // Align Offset Position - indent() << "offset = alignBufferSize(offset, " + indent() << "smem_offset = alignBufferSize(smem_offset, " // Always align to 128b / 16B << 16 << ");\n"; // Shared Memory Pointer indent() << buffer_dtype << "* " << varName(tv) << " = reinterpret_cast<" << buffer_dtype << "*>" - << "(array + offset);\n"; + << "(array + smem_offset);\n"; // Increment Offset Position - indent() << "offset += (" << genInline(size) << " * sizeof(" + indent() << "smem_offset += (" << genInline(size) << " * sizeof(" << buffer_dtype << "));\n"; break; case MemoryType::Local: { diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu b/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu index 765d951df1579..21360b1510914 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -230,5 +231,38 @@ TEST_F(NVFuserTest, FusionBroadcastingRNG2_CUDA) { } } +TEST_F(NVFuserTest, FusionBroadcastingRNGSmem_CUDA) { + for (auto dtype : {kFloat, kDouble}) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = makeConcreteTensor({5, 1}, aten_to_data_type(dtype)); + TensorView* tv1 = makeConcreteTensor({5, 5}, aten_to_data_type(dtype)); + fusion->addInput(tv0); + fusion->addInput(tv1); + auto tv2 = randlike(tv0); + auto tv3 = add(tv1, tv2); + auto tv4 = add(tv0, tv3); + fusion->addOutput(tv4); + + auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); + at::Tensor t0 = at::zeros({5, 1}, options); + at::Tensor t1 = at::zeros({5, 5}, options); + + auto lparams = scheduleTranspose(fusion, {t0, t1}); + + FusionExecutor fe; + fe.compileFusion(fusion, {t0, t1}, lparams); + auto cg_outputs = fe.runFusion({t0, t1}, lparams); + auto out = cg_outputs[0]; + + TORCH_CHECK((out.select(1, 0) == out.select(1, 1)).all().item()) + TORCH_CHECK((out.select(1, 0) == out.select(1, 2)).all().item()) + TORCH_CHECK((out.select(1, 0) == out.select(1, 3)).all().item()) + TORCH_CHECK((out.select(1, 0) == out.select(1, 4)).all().item()) + } +} + } // namespace jit } // namespace torch