Fix ir_utils::hasBlockSync + misc fixes in transpose scheduler (#1924)
zasdfgbnm committed Aug 24, 2022
1 parent 14a53e6 commit b34e3b9
Showing 5 changed files with 92 additions and 5 deletions.
10 changes: 6 additions & 4 deletions torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -754,10 +754,12 @@ class CudaKernelGenerator : private OptOutConstDispatch {
   auto out_tv = rop->output(0)->as<kir::TensorIndex>()->view();
   auto index = genTensorIndex(rop->getPhiloxIndex()->as<kir::TensorIndex>());
   int multiple = out_tv->getDataType() == DataType::Double ? 2 : 4;
-  indent() << "nvfuser_index_t rng_subseq" << rop->name() << " = (" << index
-           << ") / " << multiple << ";\n";
-  indent() << "nvfuser_index_t rng_component" << rop->name() << " = ("
-           << index << ") % " << multiple << ";\n";
+  indent() << "nvfuser_index_t linear_index" << rop->name() << " = " << index
+           << ";\n";
+  indent() << "nvfuser_index_t rng_subseq" << rop->name() << " = linear_index"
+           << rop->name() << " / " << multiple << ";\n";
+  indent() << "nvfuser_index_t rng_component" << rop->name()
+           << " = linear_index" << rop->name() << " % " << multiple << ";\n";
   indent() << "nvfuser_index_t rng_offset" << rop->name() << " = "
            << rop->getRNGOffset() << ";\n";
   indent() << "if (rng_subseq != rng_subseq" << rop->name()
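A note on the change above: the generator now materializes the tensor index once as linear_index<name> and derives the Philox subsequence and component from that single value, rather than re-emitting the full index expression in both the division and the modulo. The divisor is 2 for double and 4 otherwise because one Philox4x32 call produces four 32-bit words, i.e. four floats or two doubles. As a hedged sketch (the identifier suffix 7 and the index expression i50 are illustrative assumptions, not from the commit), the emitted device code for a float op would look roughly like:

  nvfuser_index_t linear_index7 = i50;               // genTensorIndex(...)
  nvfuser_index_t rng_subseq7 = linear_index7 / 4;   // 4 floats per Philox call
  nvfuser_index_t rng_component7 = linear_index7 % 4;
  nvfuser_index_t rng_offset7 = 0;                   // rop->getRNGOffset()
  if (rng_subseq != rng_subseq7 || rng_offset != rng_offset7) {
    // re-seed Philox for the new (subsequence, offset) pair; elided here
  }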
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower_utils.cpp
@@ -204,6 +204,10 @@ bool isScalarOp(const Expr* expr) {
 }

 bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map) {
+  if (expr->isA<kir::BlockSync>()) {
+    return true;
+  }
+
   if (!isTvOp(expr)) {
     return false;
   }
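This early return is the hasBlockSync fix named in the commit title: a kir::BlockSync expression is not a tensor-view op, so the old code fell through to the isTvOp check below and answered false for the one expression that is, by definition, a block sync. A self-contained toy of the corrected control flow, using stand-in types rather than nvfuser's (a sketch under that assumption):

  #include <iostream>

  struct Expr { virtual ~Expr() = default; };
  struct BlockSync : Expr {};  // stand-in for kir::BlockSync
  struct TensorOp : Expr {};   // stand-in for a tensor-view op

  bool isTvOp(const Expr* e) {
    return dynamic_cast<const TensorOp*>(e) != nullptr;
  }

  bool hasBlockSync(const Expr* e) {
    if (dynamic_cast<const BlockSync*>(e) != nullptr) {
      return true;  // the fix: a sync node is itself a block sync
    }
    if (!isTvOp(e)) {
      return false;  // before the fix, BlockSync landed here
    }
    return false;  // per-output thread-predicate analysis elided
  }

  int main() {
    BlockSync sync;
    std::cout << std::boolalpha << hasBlockSync(&sync) << "\n";  // true
  }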
5 changes: 4 additions & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp
@@ -109,6 +109,9 @@ class DomainMap : public pointwise_utils::DomainMap {
   decltype(input_tvs)* tv_filtered_groups[2] = {&output_tvs, &input_tvs};
   for (auto tv_filtered_group : tv_filtered_groups) {
     for (auto tv : *tv_filtered_group) {
+      if (tv->isFusionInput() && tv->uses().empty()) {
+        continue;
+      }
       if (grouped.count(tv) > 0) {
         continue;
       }
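The new guard skips fusion inputs that nothing consumes: a dead input participates in no expression, so the domain grouping below has nothing to map it against, and skipping it keeps the scheduler from treating it as a schedulable group member. A hedged sketch of a fusion that would exercise this path (built from the same APIs the tests in this commit use):

  // tv0 is registered as an input but never used; only tv1 reaches the output.
  TensorView* tv0 = makeSymbolicTensor(2);
  TensorView* tv1 = makeSymbolicTensor(2);
  fusion->addInput(tv0);  // tv0->isFusionInput() && tv0->uses().empty()
  fusion->addInput(tv1);
  fusion->addOutput(set(tv1));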
@@ -653,7 +656,7 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) {
   if (inner_most_pos2_in_ref1 > inner_most_pos1_in_ref1) {
     inner_most_pos2_in_ref1--;
   }
-  if (!merged2.has_value() && *merged2 > inner_most_pos1_in_ref1) {
+  if (merged2.has_value() && *merged2 > inner_most_pos1_in_ref1) {
     (*merged2)--;
   }
   reference1->merge(*merged1, inner_most_pos1_in_ref1);
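The one-character fix above also removes undefined behavior, not just a logic error: with the old !merged2.has_value() test, the short-circuiting && went on to dereference *merged2 precisely when the optional was empty. A minimal standalone illustration (toy values; merged2 here is a stand-in for the scheduler's variable):

  #include <cstddef>
  #include <optional>

  int main() {
    std::optional<size_t> merged2;  // empty: no second merge was recorded
    const size_t inner_most_pos1_in_ref1 = 2;

    // Buggy form: the condition holds only when merged2 is empty, so the
    // dereference that follows is undefined behavior.
    // if (!merged2.has_value() && *merged2 > inner_most_pos1_in_ref1) {}

    // Fixed form: dereference only after confirming a value is present.
    if (merged2.has_value() && *merged2 > inner_most_pos1_in_ref1) {
      (*merged2)--;
    }
  }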
43 changes: 43 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -25512,6 +25512,49 @@ TEST_F(NVFuserTest, FusionSizeDependentData_CUDA) {
       executor_cache.fusion(), cg_outputs, {a}, {a + 123}, __LINE__, __FILE__);
 }

+TEST_F(NVFuserTest, FusionPredicateUnshare_CUDA) {
+  // https://github.com/csarofeen/pytorch/issues/1926
+  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+  auto fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
+
+  TensorView* tv0 = makeSymbolicTensor(2);
+  fusion->addInput(tv0);
+  auto tv1 = set(tv0);
+  auto tv2 = set(tv1);
+  fusion->addOutput(tv2);
+
+  tv1->setMemoryType(MemoryType::Shared);
+  for (auto tv : {tv1, tv2}) {
+    tv->split(0, 4);
+    tv->reorder({{1, -1}});
+    tv->split(1, 8);
+    tv->merge(0);
+    tv->split(0, 1);
+    tv->axis(0)->parallelize(ParallelType::BIDx);
+    tv->axis(1)->parallelize(ParallelType::Unswitch);
+  }
+  tv1->merge(2);
+  tv2->reorder({{2, 3}});
+  tv2->merge(2);
+  for (auto tv : {tv1, tv2}) {
+    tv->axis(-1)->parallelize(ParallelType::TIDx);
+  }
+
+  InlinePropagator propagator(tv2, -1, ComputeAtMode::MostInlined);
+  MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({5, 5}, options);
+
+  FusionExecutor fe;
+  fe.compileFusion(fusion, {t0});
+  auto cg_outputs = fe.runFusion({t0});
+  auto out = cg_outputs[0];
+
+  testValidate(fusion, {out}, {t0}, {t0}, __LINE__, __FILE__);
+}
+
 } // namespace jit
 } // namespace torch
 #endif // #if defined(USE_CUDA)
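Since the fusion under test is two chained set (copy) ops, the reference outputs handed to testValidate are just the inputs; an equivalent direct assertion would be the one-liner below (a sketch, not part of the commit):

  TORCH_CHECK(out.equal(t0));  // set(set(tv0)) must reproduce the input exactly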
35 changes: 35 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu
@@ -264,5 +264,40 @@ TEST_F(NVFuserTest, FusionBroadcastingRNGSmem_CUDA) {
 }
 }

+TEST_F(NVFuserTest, FusionBroadcastingRNGSmemNonSquareTile_CUDA) {
+  // https://github.com/csarofeen/pytorch/issues/1926
+  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+  auto fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
+
+  TensorView* tv0 = makeConcreteTensor({5, 1});
+  TensorView* tv1 = makeConcreteTensor({5, 5});
+  fusion->addInput(tv0);
+  fusion->addInput(tv1);
+  auto tv2 = randlike(tv0);
+  auto tv3 = add(tv1, tv2);
+  auto tv4 = add(tv0, tv3);
+  fusion->addOutput(tv4);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::zeros({5, 1}, options);
+  at::Tensor t1 = at::zeros({5, 5}, options);
+
+  TransposeParams heuristics;
+  heuristics.tile_size1 = 8;
+  heuristics.tile_size2 = 4;
+
+  scheduleTranspose(fusion, heuristics);
+
+  FusionExecutor fe;
+  fe.compileFusion(fusion, {t0, t1});
+  auto cg_outputs = fe.runFusion({t0, t1});
+  auto out = cg_outputs[0];
+
+  TORCH_CHECK((out.select(1, 0) == out.select(1, 1)).all().item<bool>());
+  TORCH_CHECK((out.select(1, 0) == out.select(1, 2)).all().item<bool>());
+  TORCH_CHECK((out.select(1, 0) == out.select(1, 3)).all().item<bool>());
+  TORCH_CHECK((out.select(1, 0) == out.select(1, 4)).all().item<bool>());
+}
 } // namespace jit
 } // namespace torch
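The four TORCH_CHECKs assert that every column of the 5x5 output matches column 0: the value randlike generates for tv0's size-1 broadcast dimension must be reused across the broadcast rather than regenerated per column. A compact equivalent using ATen broadcasting (a sketch, not the committed code):

  // All columns equal column 0 <=> the broadcast RNG value was shared.
  TORCH_CHECK((out == out.select(1, 0).unsqueeze(1)).all().item<bool>());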
