From b34e3b93ee1a8030730c14af3995dd95665af07d Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Wed, 24 Aug 2022 05:49:49 -0700 Subject: [PATCH] Fix `ir_utils::hasBlockSync` + misc fixes in transpose scheduler (#1924) --- torch/csrc/jit/codegen/cuda/codegen.cpp | 10 +++-- torch/csrc/jit/codegen/cuda/lower_utils.cpp | 4 ++ .../jit/codegen/cuda/scheduler/transpose.cpp | 5 ++- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 43 +++++++++++++++++++ .../jit/codegen/cuda/test/test_gpu_rng.cu | 35 +++++++++++++++ 5 files changed, 92 insertions(+), 5 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 5768863efa696..a74d8e56b3c83 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -754,10 +754,12 @@ class CudaKernelGenerator : private OptOutConstDispatch { auto out_tv = rop->output(0)->as()->view(); auto index = genTensorIndex(rop->getPhiloxIndex()->as()); int multiple = out_tv->getDataType() == DataType::Double ? 2 : 4; - indent() << "nvfuser_index_t rng_subseq" << rop->name() << " = (" << index - << ") / " << multiple << ";\n"; - indent() << "nvfuser_index_t rng_component" << rop->name() << " = (" - << index << ") % " << multiple << ";\n"; + indent() << "nvfuser_index_t linear_index" << rop->name() << " = " << index + << ";\n"; + indent() << "nvfuser_index_t rng_subseq" << rop->name() << " = linear_index" + << rop->name() << " / " << multiple << ";\n"; + indent() << "nvfuser_index_t rng_component" << rop->name() + << " = linear_index" << rop->name() << " % " << multiple << ";\n"; indent() << "nvfuser_index_t rng_offset" << rop->name() << " = " << rop->getRNGOffset() << ";\n"; indent() << "if (rng_subseq != rng_subseq" << rop->name() diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index 17cd0c34dd123..955ce1974e21a 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -204,6 +204,10 @@ bool isScalarOp(const Expr* expr) { } bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map) { + if (expr->isA()) { + return true; + } + if (!isTvOp(expr)) { return false; } diff --git a/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp b/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp index afb5f09f2ec61..5ef502321b773 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp @@ -109,6 +109,9 @@ class DomainMap : public pointwise_utils::DomainMap { decltype(input_tvs)* tv_filtered_groups[2] = {&output_tvs, &input_tvs}; for (auto tv_filtered_group : tv_filtered_groups) { for (auto tv : *tv_filtered_group) { + if (tv->isFusionInput() && tv->uses().empty()) { + continue; + } if (grouped.count(tv) > 0) { continue; } @@ -653,7 +656,7 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) { if (inner_most_pos2_in_ref1 > inner_most_pos1_in_ref1) { inner_most_pos2_in_ref1--; } - if (!merged2.has_value() && *merged2 > inner_most_pos1_in_ref1) { + if (merged2.has_value() && *merged2 > inner_most_pos1_in_ref1) { (*merged2)--; } reference1->merge(*merged1, inner_most_pos1_in_ref1); diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index f063ba82b6816..4f72bf93ba36e 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -25512,6 +25512,49 @@ TEST_F(NVFuserTest, FusionSizeDependentData_CUDA) { executor_cache.fusion(), cg_outputs, {a}, {a + 123}, __LINE__, __FILE__); } +TEST_F(NVFuserTest, FusionPredicateUnshare_CUDA) { + // https://github.com/csarofeen/pytorch/issues/1926 + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion->addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = set(tv1); + fusion->addOutput(tv2); + + tv1->setMemoryType(MemoryType::Shared); + for (auto tv : {tv1, tv2}) { + tv->split(0, 4); + tv->reorder({{1, -1}}); + tv->split(1, 8); + tv->merge(0); + tv->split(0, 1); + tv->axis(0)->parallelize(ParallelType::BIDx); + tv->axis(1)->parallelize(ParallelType::Unswitch); + } + tv1->merge(2); + tv2->reorder({{2, 3}}); + tv2->merge(2); + for (auto tv : {tv1, tv2}) { + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + + InlinePropagator propagator(tv2, -1, ComputeAtMode::MostInlined); + MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + + auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({5, 5}, options); + + FusionExecutor fe; + fe.compileFusion(fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + auto out = cg_outputs[0]; + + testValidate(fusion, {out}, {t0}, {t0}, __LINE__, __FILE__); +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu b/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu index bb7f910b2a665..f6b570be2bb8a 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu @@ -264,5 +264,40 @@ TEST_F(NVFuserTest, FusionBroadcastingRNGSmem_CUDA) { } } +TEST_F(NVFuserTest, FusionBroadcastingRNGSmemNonSquareTile_CUDA) { + // https://github.com/csarofeen/pytorch/issues/1926 + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = makeConcreteTensor({5, 1}); + TensorView* tv1 = makeConcreteTensor({5, 5}); + fusion->addInput(tv0); + fusion->addInput(tv1); + auto tv2 = randlike(tv0); + auto tv3 = add(tv1, tv2); + auto tv4 = add(tv0, tv3); + fusion->addOutput(tv4); + + auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::zeros({5, 1}, options); + at::Tensor t1 = at::zeros({5, 5}, options); + + TransposeParams heuristics; + heuristics.tile_size1 = 8; + heuristics.tile_size2 = 4; + scheduleTranspose(fusion, heuristics); + + FusionExecutor fe; + fe.compileFusion(fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + auto out = cg_outputs[0]; + + TORCH_CHECK((out.select(1, 0) == out.select(1, 1)).all().item()); + TORCH_CHECK((out.select(1, 0) == out.select(1, 2)).all().item()); + TORCH_CHECK((out.select(1, 0) == out.select(1, 3)).all().item()); + TORCH_CHECK((out.select(1, 0) == out.select(1, 4)).all().item()); +} + } // namespace jit } // namespace torch