Skip to content

Commit

Permalink
Fix sync map (#2047)
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam authored Oct 8, 2022
1 parent f5bca33 commit 6ac74d1
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 7 deletions.
3 changes: 3 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower2device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,9 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) {
// Depends on thread_pred_map_. Validates parallelization and collects which
// tensor views need WAR or RAW syncs
sync_map_.build(fusion_);
if (isDebugDumpEnabled(DebugDumpOption::SyncMap)) {
std::cout << sync_map_.toString() << std::endl;
}

partialSplitMap().build(fusion_);

Expand Down
12 changes: 8 additions & 4 deletions torch/csrc/jit/codegen/cuda/lower_sync_information.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ void SyncMap::build(Fusion* fusion) {
} // end for consumers

if (raw_dims.any()) {
needs_raw_sync_[producer] = raw_dims;
needs_raw_sync_[producer] |= raw_dims;
}

} // end producer
Expand All @@ -492,10 +492,14 @@ void SyncMap::build(Fusion* fusion) {

// Renders the RAW sync map as a single-line, human-readable summary:
//   "SyncMap: <producer> -> <parallel dims>, <producer> -> <parallel dims>, ..."
// listing, for each producer tensor, the parallel dimensions that require a
// RAW sync. Intended for debug dumps (PYTORCH_NVFUSER_DUMP=sync_map).
std::string SyncMap::toString() const {
  std::stringstream ss;
  ss << "SyncMap:";
  bool is_first = true;
  // Iterate by const reference: copying each map entry per iteration is an
  // unnecessary cost, and the entries are only read here.
  for (const auto& entry : needs_raw_sync_) {
    if (!is_first) {
      ss << ",";
    }
    ss << " " << entry.first->toString() << " -> " << entry.second.toString();
    is_first = false;
  }
  return ss.str();
}
Expand Down
34 changes: 34 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5040,6 +5040,40 @@ TEST_F(NVFuserTest, FusionValidateParallelize6_CUDA) {
ASSERT_ANY_THROW(fusion.printKernel());
}

// Repro of #2046
TEST_F(NVFuserTest, FusionValidateParallelize7_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // A 2-D symbolic input copied into a global-memory intermediate (tv1),
  // which then feeds two independently parallelized consumers.
  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  auto tv1 = set(tv0);
  auto tv2 = set(tv1);
  auto tv3 = set(tv1);
  fusion.addOutput(tv2);
  fusion.addOutput(tv3);

  tv1->setMemoryType(MemoryType::Global);

  tv1->axis(0)->parallelize(ParallelType::BIDx);
  tv1->axis(1)->parallelize(ParallelType::TIDx);

  tv2->axis(1)->parallelize(ParallelType::TIDy);
  tv3->axis(0)->parallelize(ParallelType::BIDx);

  // tv2 consumes tv1 without matching its BIDx parallelization, so lowering
  // must insert a grid-wide sync, and it should appear as a top-level
  // expression of the kernel.
  GpuLower gpulw(&fusion);
  bool found_grid_sync = false;
  for (auto expr : gpulw.kernel()->topLevelExprs()) {
    if (expr->isA<kir::GridSync>()) {
      found_grid_sync = true;
      break;
    }
  }
  TORCH_CHECK(found_grid_sync, "Grid sync not found");
}

TEST_F(NVFuserTest, FusionDAGMerging_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
Expand Down
7 changes: 5 additions & 2 deletions torch/csrc/jit/codegen/cuda/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ auto parseDebugDumpOptions() {
{DebugDumpOption::TransformPropagator, false},
{DebugDumpOption::Cubin, false},
{DebugDumpOption::Ptx, false},
{DebugDumpOption::BankConflictInfo, false}};
{DebugDumpOption::BankConflictInfo, false},
{DebugDumpOption::SyncMap, false}};

if (const char* dump_options = std::getenv("PYTORCH_NVFUSER_DUMP")) {
c10::string_view options_view(dump_options);
Expand Down Expand Up @@ -106,6 +107,8 @@ auto parseDebugDumpOptions() {
options_map[DebugDumpOption::Ptx] = true;
} else if (token == "bank_conflict") {
options_map[DebugDumpOption::BankConflictInfo] = true;
} else if (token == "sync_map") {
options_map[DebugDumpOption::SyncMap] = true;
} else {
TORCH_CHECK(
false,
Expand All @@ -118,7 +121,7 @@ auto parseDebugDumpOptions() {
"\tdraw_segmented_fusion, scheduler_params, parallel_dimensions,\n",
"\tbuffer_reuse_verbose, ptxas_verbose, halo, segmenter_logging,\n",
"\tperf_debug_verbose, python_definition, python_frontend_debug,\n",
"\ttransform_propagator, cubin, ptx, bank_conflict\n");
"\ttransform_propagator, cubin, ptx, bank_conflict, sync_map\n");
}
options_view = (end_pos != c10::string_view::npos)
? options_view.substr(end_pos + 1)
Expand Down
3 changes: 2 additions & 1 deletion torch/csrc/jit/codegen/cuda/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ enum class DebugDumpOption {
//! path and replay result
Cubin, //! Dump compiled CUBIN
Ptx, //! Dump compiled PTX
BankConflictInfo //! Dump bank conflict info
BankConflictInfo, //! Dump bank conflict info
SyncMap //! RAW dependency info
};

TORCH_CUDA_CU_API bool isDebugDumpEnabled(DebugDumpOption option);
Expand Down

0 comments on commit 6ac74d1

Please sign in to comment.