Fix sync map #2047

Merged (3 commits, Oct 8, 2022)
torch/csrc/jit/codegen/cuda/lower2device.cpp (3 additions, 0 deletions)

@@ -303,6 +303,9 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) {
   // Depends on thread_pred_map_, validates parallelization collects which
   // tensor views need WAR or RAW syncs
   sync_map_.build(fusion_);
+  if (isDebugDumpEnabled(DebugDumpOption::SyncMap)) {
+    std::cout << sync_map_.toString() << std::endl;
+  }

   partialSplitMap().build(fusion_);
torch/csrc/jit/codegen/cuda/lower_sync_information.cpp (8 additions, 4 deletions)

@@ -483,7 +483,7 @@ void SyncMap::build(Fusion* fusion) {
     } // end for consumers

     if (raw_dims.any()) {
-      needs_raw_sync_[producer] = raw_dims;
+      needs_raw_sync_[producer] |= raw_dims;
     }

   } // end producer
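The one-character change above (`=` to `|=`) is the actual bug fix: `needs_raw_sync_` is keyed by producer, and a producer is visited once per consumer. With plain assignment, the sync dimensions demanded by a later consumer overwrite those demanded by an earlier one, so one of the required syncs is silently dropped. A minimal, self-contained sketch of the difference (the `std::bitset` and map below stand in for nvfuser's `ParallelTypeBitmap` and `needs_raw_sync_`; the names and bit positions are illustrative, not the real API):

```cpp
#include <bitset>
#include <iostream>
#include <map>
#include <string>

int main() {
  // Stand-in for ParallelTypeBitmap: one bit per parallel dimension.
  using ParallelBits = std::bitset<6>;
  constexpr size_t kBIDx = 0;
  constexpr size_t kTIDy = 4;

  std::map<std::string, ParallelBits> needs_raw_sync;

  // Consumer tv2 requires a sync on TIDy for producer tv1 ...
  needs_raw_sync["tv1"] |= ParallelBits().set(kTIDy);
  // ... and consumer tv3 requires a sync on BIDx for the same producer.
  needs_raw_sync["tv1"] |= ParallelBits().set(kBIDx);

  // Both bits survive: prints 010001. With plain '=', the second store
  // would have discarded the TIDy bit and a required sync would be missed.
  std::cout << needs_raw_sync["tv1"] << std::endl;
}
```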
@@ -492,10 +492,14 @@

 std::string SyncMap::toString() const {
   std::stringstream ss;
-  ss << "TVs requiring RAW:" << std::endl;
+  ss << "SyncMap:";
+  bool is_first = true;
   for (auto entry : needs_raw_sync_) {
-    ss << " " << entry.first->toString() << " :: " << entry.second.toString()
-       << std::endl;
+    if (!is_first) {
+      ss << ",";
+    }
+    ss << " " << entry.first->toString() << " -> " << entry.second.toString();
+    is_first = false;
   }
   return ss.str();
 }
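The rewrite also drops the per-entry `std::endl`, so the whole map now prints on one line. For illustration, the new formatting logic can be exercised in isolation; this sketch reproduces the comma-separated output shape (plain strings substitute for the `TensorView` and bitmap `toString()` calls, and the tensor names are made up):

```cpp
#include <iostream>
#include <map>
#include <sstream>
#include <string>

int main() {
  // Hypothetical producer -> sync-dimension entries.
  std::map<std::string, std::string> needs_raw_sync{
      {"T1", "threadIdx.y"}, {"T4", "blockIdx.x"}};

  std::stringstream ss;
  ss << "SyncMap:";
  bool is_first = true;
  for (const auto& entry : needs_raw_sync) {
    if (!is_first) {
      ss << ",";
    }
    ss << " " << entry.first << " -> " << entry.second;
    is_first = false;
  }
  // Prints a single line: SyncMap: T1 -> threadIdx.y, T4 -> blockIdx.x
  std::cout << ss.str() << std::endl;
}
```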
torch/csrc/jit/codegen/cuda/test/test_gpu2.cpp (34 additions, 0 deletions)

@@ -5040,6 +5040,40 @@ TEST_F(NVFuserTest, FusionValidateParallelize6_CUDA) {
   ASSERT_ANY_THROW(fusion.printKernel());
 }

+// Repro of #2046
+TEST_F(NVFuserTest, FusionValidateParallelize7_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeSymbolicTensor(2);
+  fusion.addInput(tv0);
+
+  auto tv1 = set(tv0);
+  auto tv2 = set(tv1);
+  auto tv3 = set(tv1);
+  fusion.addOutput(tv2);
+  fusion.addOutput(tv3);
+
+  tv1->setMemoryType(MemoryType::Global);
+
+  tv1->axis(0)->parallelize(ParallelType::BIDx);
+  tv1->axis(1)->parallelize(ParallelType::TIDx);
+
+  tv2->axis(1)->parallelize(ParallelType::TIDy);
+  tv3->axis(0)->parallelize(ParallelType::BIDx);
+
+  // tv2 uses tv1 but is not parallelized with BIDx, so a grid sync is
+  // required. It should be placed as a top-level expression.
+
+  GpuLower gpulw(&fusion);
+  TORCH_CHECK(
+      std::any_of(
+          gpulw.kernel()->topLevelExprs().begin(),
+          gpulw.kernel()->topLevelExprs().end(),
+          [](Expr* expr) { return expr->isA<kir::GridSync>(); }),
+      "Grid sync not found");
+}
+
 TEST_F(NVFuserTest, FusionDAGMerging_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
torch/csrc/jit/codegen/cuda/utils.cpp (5 additions, 2 deletions)

@@ -43,7 +43,8 @@ auto parseDebugDumpOptions() {
       {DebugDumpOption::TransformPropagator, false},
       {DebugDumpOption::Cubin, false},
       {DebugDumpOption::Ptx, false},
-      {DebugDumpOption::BankConflictInfo, false}};
+      {DebugDumpOption::BankConflictInfo, false},
+      {DebugDumpOption::SyncMap, false}};

   if (const char* dump_options = std::getenv("PYTORCH_NVFUSER_DUMP")) {
     c10::string_view options_view(dump_options);
@@ -106,6 +107,8 @@ auto parseDebugDumpOptions() {
         options_map[DebugDumpOption::Ptx] = true;
       } else if (token == "bank_conflict") {
         options_map[DebugDumpOption::BankConflictInfo] = true;
+      } else if (token == "sync_map") {
+        options_map[DebugDumpOption::SyncMap] = true;
       } else {
         TORCH_CHECK(
             false,
@@ -118,7 +121,7 @@
             "\tdraw_segmented_fusion, scheduler_params, parallel_dimensions,\n",
             "\tbuffer_reuse_verbose, ptxas_verbose, halo, segmenter_logging,\n",
             "\tperf_debug_verbose, python_definition, python_frontend_debug,\n",
-            "\ttransform_propagator, cubin, ptx, bank_conflict\n");
+            "\ttransform_propagator, cubin, ptx, bank_conflict, sync_map\n");
       }
       options_view = (end_pos != c10::string_view::npos)
           ? options_view.substr(end_pos + 1)
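Usage note: with the parser change above, running with `PYTORCH_NVFUSER_DUMP=sync_map` set in the environment enables the new option, and the guarded print added in lower2device.cpp then emits `sync_map_.toString()` during lowering, in the single-line format shown earlier.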
torch/csrc/jit/codegen/cuda/utils.h (2 additions, 1 deletion)

@@ -59,7 +59,8 @@ enum class DebugDumpOption {
   //! path and replay result
   Cubin, //! Dump compiled CUBIN
   Ptx, //! Dump compiled PTX
-  BankConflictInfo //! Dump bank confliction info
+  BankConflictInfo, //! Dump bank confliction info
+  SyncMap //! RAW dependency info
 };

 TORCH_CUDA_CU_API bool isDebugDumpEnabled(DebugDumpOption option);