Skip to content

Commit

Permalink
Fix name clash of RNG with shared memory (#1904)
Browse files — browse the repository at this point in the history
  • Loading branch information
zasdfgbnm committed Aug 11, 2022
1 parent 3381793 commit e71e1ec
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 19 deletions.
38 changes: 19 additions & 19 deletions torch/csrc/jit/codegen/cuda/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {

// Random number generator (optional)
if (kernel_summary.max_rng_offsets >= 0) {
indent() << "auto offset = philox_args.captured_ ?\n";
indent() << "auto philox_offset = philox_args.captured_ ?\n";
indent()
<< " static_cast<uint64_t>(*(philox_args.offset_.ptr) + philox_args.offset_intragraph_) :\n";
indent() << " philox_args.offset_.val;\n";
Expand Down Expand Up @@ -290,18 +290,18 @@ class CudaKernelGenerator : private OptOutConstDispatch {
<< ") extern __shared__ char array[];\n";

if (has_dynamic_smem) {
indent() << "unsigned offset = 0;\n";
indent() << "unsigned smem_offset = 0;\n";
}

if (has_reductions || has_parallel_welford) {
indent() << "void* shared_mem = array;\n";
if (has_dynamic_smem) {
if (has_parallel_welford) {
indent() << "offset += "
indent() << "smem_offset += "
<< "((blockDim.x * blockDim.y * blockDim.z) * 3 * sizeof("
<< kernel_summary.largest_smem_data_type << "));\n";
} else {
indent() << "offset += "
indent() << "smem_offset += "
<< "((blockDim.x * blockDim.y * blockDim.z) * sizeof("
<< kernel_summary.largest_smem_data_type << "));\n";
}
Expand Down Expand Up @@ -710,19 +710,19 @@ class CudaKernelGenerator : private OptOutConstDispatch {
auto out_tv = uop->out()->as<kir::TensorIndex>()->view();
auto index = genTensorIndex(uop->in()->as<kir::TensorIndex>());
int multiple = out_tv->getDataType() == DataType::Double ? 2 : 4;
indent() << "nvfuser_index_t subseq" << uop->name() << " = (" << index
<< ") / " << multiple << ";\n";
indent() << "nvfuser_index_t component" << uop->name() << " = ("
indent() << "nvfuser_index_t rng_subseq" << uop->name() << " = ("
<< index << ") / " << multiple << ";\n";
indent() << "nvfuser_index_t rng_component" << uop->name() << " = ("
<< index << ") % " << multiple << ";\n";
indent() << "nvfuser_index_t offset" << uop->name() << " = "
indent() << "nvfuser_index_t rng_offset" << uop->name() << " = "
<< uop->getRNGOffset() << ";\n";
indent() << "if (rng_subseq != subseq" << uop->name()
<< " || rng_offset != offset" << uop->name() << ") {\n";
indent() << " rng_result = philox(philox_args.seed_, subseq"
<< uop->name() << ", offset / 4 + offset" << uop->name()
<< ");\n";
indent() << " rng_subseq = subseq" << uop->name() << ";\n";
indent() << " rng_offset = offset" << uop->name() << ";\n";
indent() << "if (rng_subseq != rng_subseq" << uop->name()
<< " || rng_offset != rng_offset" << uop->name() << ") {\n";
indent() << " rng_result = philox(philox_args.seed_, rng_subseq"
<< uop->name() << ", philox_offset / 4 + rng_offset"
<< uop->name() << ");\n";
indent() << " rng_subseq = rng_subseq" << uop->name() << ";\n";
indent() << " rng_offset = rng_offset" << uop->name() << ";\n";
indent() << "}\n";
}

Expand Down Expand Up @@ -764,7 +764,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {

code_ << "(";
if (op_type == UnaryOpType::RandLike) {
code_ << "rng_result, component" << uop->name();
code_ << "rng_result, rng_component" << uop->name();
} else {
code_ << gen(uop->in());
}
Expand Down Expand Up @@ -2329,15 +2329,15 @@ class CudaKernelGenerator : private OptOutConstDispatch {
break;
case MemoryType::Shared:
// Align Offset Position
indent() << "offset = alignBufferSize(offset, "
indent() << "smem_offset = alignBufferSize(smem_offset, "
// Always align to 128b / 16B
<< 16 << ");\n";
// Shared Memory Pointer
indent() << buffer_dtype << "* " << varName(tv)
<< " = reinterpret_cast<" << buffer_dtype << "*>"
<< "(array + offset);\n";
<< "(array + smem_offset);\n";
// Increment Offset Position
indent() << "offset += (" << genInline(size) << " * sizeof("
indent() << "smem_offset += (" << genInline(size) << " * sizeof("
<< buffer_dtype << "));\n";
break;
case MemoryType::Local: {
Expand Down
34 changes: 34 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
#include <torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h>
#include <torch/csrc/jit/codegen/cuda/test/test_utils.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
Expand Down Expand Up @@ -230,5 +231,38 @@ TEST_F(NVFuserTest, FusionBroadcastingRNG2_CUDA) {
}
}
// Regression test for the RNG / shared-memory name clash fixed in this
// commit (#1904): codegen previously emitted both the philox RNG offset and
// the dynamic shared-memory offset as a variable named "offset", so a kernel
// that uses random numbers together with dynamic smem was broken. This
// fusion forces both features into one kernel: randlike() on a broadcast
// input (RNG) scheduled by the transpose scheduler (dynamic smem).
TEST_F(NVFuserTest, FusionBroadcastingRNGSmem_CUDA) {
  // Run once per supported RNG element type (float uses 4 values per philox
  // call, double uses 2 — see the `multiple` logic in codegen.cpp).
  for (auto dtype : {kFloat, kDouble}) {
    std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
    auto fusion = fusion_ptr.get();
    FusionGuard fg(fusion);
    // tv0 is {5, 1} so its random values broadcast across the second
    // dimension of the {5, 5} arithmetic below.
    TensorView* tv0 = makeConcreteTensor({5, 1}, aten_to_data_type(dtype));
    TensorView* tv1 = makeConcreteTensor({5, 5}, aten_to_data_type(dtype));
    fusion->addInput(tv0);
    fusion->addInput(tv1);
    auto tv2 = randlike(tv0);
    auto tv3 = add(tv1, tv2);
    auto tv4 = add(tv0, tv3);
    fusion->addOutput(tv4);
    auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
    // Zero inputs: the output is exactly the broadcast random tensor, which
    // makes the column-equality checks below meaningful.
    at::Tensor t0 = at::zeros({5, 1}, options);
    at::Tensor t1 = at::zeros({5, 5}, options);
    // The transpose scheduler is what introduces dynamic shared memory into
    // the generated kernel, triggering the original name clash.
    auto lparams = scheduleTranspose(fusion, {t0, t1});
    FusionExecutor fe;
    fe.compileFusion(fusion, {t0, t1}, lparams);
    auto cg_outputs = fe.runFusion({t0, t1}, lparams);
    auto out = cg_outputs[0];
    // Each row's single random value was broadcast across all 5 columns, so
    // every column must equal column 0. (Fixed: these statements were
    // missing their terminating semicolons.)
    TORCH_CHECK((out.select(1, 0) == out.select(1, 1)).all().item<bool>());
    TORCH_CHECK((out.select(1, 0) == out.select(1, 2)).all().item<bool>());
    TORCH_CHECK((out.select(1, 0) == out.select(1, 3)).all().item<bool>());
    TORCH_CHECK((out.select(1, 0) == out.select(1, 4)).all().item<bool>());
  }
}
} // namespace jit
} // namespace torch

0 comments on commit e71e1ec

Please sign in to comment.