Fix bug introduced in #1831 #1840

Merged · 2 commits · Jul 18, 2022
27 changes: 20 additions & 7 deletions torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
@@ -692,7 +692,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
// [outer, Unswitch | i-remainder, TIDx, Vectorization]
reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
reference_tv->axis(3)->parallelize(ParallelType::TIDx);

// Vectorization is propagated separately
vectorize_id = reference_tv->axis(4);

// [outer, Unswitch | i-remainder, TIDx, Vectorization]
@@ -708,6 +708,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
reference_tv->reorder({{1, 2}});
// [outer, i-remainder, unswitch, unroll, TIDx ]
reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
// Here we do not set axis(3)->parallelize(Unroll) because we do not want
// it to be propagated. We manually unroll by splitting the inline
// propagation process into two steps:
// step 1: inline at the unswitch position for cached inputs and outputs
// step 2: inline at the innermost dim for the rest of the graph
reference_tv->axis(4)->parallelize(ParallelType::TIDx);

//[outer | i-remainder, Unswitch, Unroll, TIDx]
@@ -794,8 +799,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
reference_tv->axis(0)->parallelize(ParallelType::BIDx);
reference_tv->axis(1)->parallelize(ParallelType::TIDx);
reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
// Aggressively mark with vectorized and cleanup later. That way we
// don't have to manually specify parallelization outside the reference.
// Vectorization is propagated separately
vectorize_id = reference_tv->axis(3);

//[BIDx, TIDx, Unswitch, Vectorization]
@@ -813,6 +817,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
// [BIDx, Unswitch, Unroll, TIDx]
reference_tv->axis(0)->parallelize(ParallelType::BIDx);
reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
// Here we do not set axis(2)->parallelize(Unroll) because we do not want
// it to be propagated. We manually unroll by splitting the inline
// propagation process into two steps:
// step 1: inline at the unswitch position for cached inputs and outputs
// step 2: inline at the innermost dim for the rest of the graph
reference_tv->axis(3)->parallelize(ParallelType::TIDx);
}
unswitch_pos = 2;
@@ -830,18 +839,20 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
std::vector<TensorView*> vectorized_tvs;
bool should_vectorize_reference_tv = false;
for (auto tv : inputs_outputs) {
if (tv == reference_tv) {
should_vectorize_reference_tv = true;
}
if (!tv->isFusionInput()) {
vectorized_tvs.emplace_back(tv);
continue;
}
if (tv == reference_tv) {
should_vectorize_reference_tv = true;
}
// move inputs to consumers of inputs
Comment on lines 841 to 849 (Collaborator, Author):

This is the bug fix
auto consumer_tvs = ir_utils::consumerTvsOf(tv);
vectorized_tvs.insert(
vectorized_tvs.end(), consumer_tvs.begin(), consumer_tvs.end());
}
// Aggressively mark with vectorized and cleanup later. That way we
// don't have to manually specify parallelization outside the reference.
vectorize_id->parallelize(ParallelType::Vectorize);
scheduler_utils::parallelizeAllLike(
reference_tv, vectorized_tvs, {ParallelType::Vectorize});
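
The review comment above marks the fix: the `tv == reference_tv` check now runs before the early `continue`, so the flag is set even when the reference tensor is an output rather than a fusion input. A minimal, self-contained sketch of the control-flow difference (hypothetical `Tv` stand-in and simplified loop bodies, not nvfuser's `TensorView`):

#include <iostream>
#include <vector>

// Hypothetical stand-in for a tensor in the fusion graph.
struct Tv {
  bool is_fusion_input;
};

int main() {
  Tv input{true};
  Tv output{false};          // outputs are not fusion inputs
  Tv* reference = &output;   // reference tensor chosen as an output
  std::vector<Tv*> inputs_outputs{&input, &output};

  // Old ordering (#1831): the early continue for non-inputs skips the
  // reference check, so the flag is never set for an output reference.
  bool vectorize_reference_old = false;
  for (Tv* tv : inputs_outputs) {
    if (!tv->is_fusion_input) {
      continue;
    }
    if (tv == reference) {
      vectorize_reference_old = true;
    }
  }

  // Fixed ordering (this PR): check the reference before the continue.
  bool vectorize_reference_new = false;
  for (Tv* tv : inputs_outputs) {
    if (tv == reference) {
      vectorize_reference_new = true;
    }
    if (!tv->is_fusion_input) {
      continue;
    }
  }

  std::cout << "old: " << vectorize_reference_old
            << ", new: " << vectorize_reference_new << '\n';  // prints "old: 0, new: 1"
}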
@@ -852,7 +863,9 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {

// Begin by inlining at the unswitch position for the entire DAG. The cached
// inputs and outputs will keep this inline position, but other tensors will
// get a higher position in later inline propagation.
// get a higher position in later inline propagation. We need this separate
// step because we were not using ParallelType::Unroll, so we have to do
// unrolling manually.
InlinePropagator inline_unswitch(
reference_tv, unswitch_pos, ComputeAtMode::BestEffort);
spanning_tree.traverse(&inline_unswitch);
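
The comments in this file describe the manual-unroll strategy: `ParallelType::Unroll` is deliberately not propagated, and inlining is instead done in two passes. The first pass is the `inline_unswitch` traversal shown above; the second pass, per the comments, inlines the rest of the graph at the innermost dimension. A sketch of that shape, reusing only calls visible in this hunk — the `-1` innermost position and the exact form of the second traversal are assumptions here, since the remainder of the diff is collapsed:

// Step 1 (from the hunk above): inline the whole DAG at the unswitch
// position; cached inputs and outputs keep this position.
InlinePropagator inline_unswitch(
    reference_tv, unswitch_pos, ComputeAtMode::BestEffort);
spanning_tree.traverse(&inline_unswitch);

// Step 2 (assumed form): inline the remaining tensors at the innermost
// dimension, completing the manual unroll described in the comments.
InlinePropagator inline_inner_most(
    reference_tv, -1, ComputeAtMode::BestEffort);
spanning_tree.traverse(&inline_inner_most);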
38 changes: 38 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -18608,6 +18608,44 @@ TEST_F(NVFuserTest, FusionPointwiseBroadcast_CUDA) {
testValidate(&fusion, outputs, aten_inputs, {aten_y}, __LINE__, __FILE__);
}

TEST_F(NVFuserTest, FusionPointwiseVectorize_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

const int size = 1024 * 64;

TensorView* x = makeContigTensor(1);
fusion.addInput(x);
auto y = sin(x);
fusion.addOutput(y);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

// PyTorch's CUDA caching allocator should always return an aligned pointer
// for a freshly allocated tensor
at::Tensor at_x = at::randn({size}, options);

schedulePointwise(&fusion, {at_x});

for (auto x_consumer : ir_utils::consumerTvsOf(x)) {
bool found_vec_in_input = false;
for (auto id : x_consumer->domain()->domain()) {
if (isParallelTypeVectorize(id->getParallelType())) {
found_vec_in_input = true;
break;
}
}
TORCH_CHECK(found_vec_in_input, "Expect input to be vectorized");
}

for (auto id : y->domain()->domain()) {
if (isParallelTypeVectorize(id->getParallelType())) {
return;
}
}
TORCH_CHECK(false, "Expect output to be vectorized");
}

TEST_F(NVFuserTest, FusionSmemAliasSerial_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
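
The allocator comment in the new test above matters because vectorized access requires the base address to be aligned to the vector width. A small standalone sanity check of that assumption (hypothetical, not part of this PR; assumes a 16-byte, float4-wide access is the widest vectorization):

#include <ATen/ATen.h>
#include <cstdint>
#include <iostream>

int main() {
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor at_x = at::randn({1024 * 64}, options);

  // A float4-wide (16-byte) vectorized load needs a 16-byte aligned base address.
  auto addr = reinterpret_cast<std::uintptr_t>(at_x.data_ptr());
  std::cout << "16-byte aligned: " << (addr % 16 == 0) << '\n';
}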