Fix vectorization bug introduced in #1831 (#1840)
zasdfgbnm authored Jul 18, 2022
1 parent 63630f1 commit 0b83645
Showing 2 changed files with 58 additions and 7 deletions.
27 changes: 20 additions & 7 deletions torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
@@ -692,7 +692,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
// [outer, Unswitch | i-remainder, TIDx, Vectorization]
reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
reference_tv->axis(3)->parallelize(ParallelType::TIDx);

// Vectorization is propagated separately
vectorize_id = reference_tv->axis(4);

// [outer, Unswitch | i-remainder, TIDx, Vectorization]
@@ -708,6 +708,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
reference_tv->reorder({{1, 2}});
// [outer, i-remainder, unswitch, unroll, TIDx ]
reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
// Here we do not set axis(3)->parallelize(Unroll) because we do not want
// it to be propagated. We manually unroll by splitting the inline
// propagation process into two steps:
// step 1: inline at the unswitch position for cached inputs and outputs
// step 2: inline at the innermost dim for the rest of the graph
reference_tv->axis(4)->parallelize(ParallelType::TIDx);

//[outer | i-remainder, Unswitch, Unroll, TIDx]
@@ -794,8 +799,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
reference_tv->axis(0)->parallelize(ParallelType::BIDx);
reference_tv->axis(1)->parallelize(ParallelType::TIDx);
reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
// Aggressively mark with vectorized and cleanup later. That way we
// don't have to manually specify parallelization outside the reference.
// Vectorization is propagated separately
vectorize_id = reference_tv->axis(3);

//[BIDx, TIDx, Unswitch, Vectorization]
@@ -813,6 +817,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
// [BIDx, Unswitch, Unroll, TIDx]
reference_tv->axis(0)->parallelize(ParallelType::BIDx);
reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
// Here we do not set axis(2)->parallelize(Unroll) because we do not want
// it to be propagated. We manually unroll by splitting the inline
// propagation process into two steps:
// step 1: inline at the unswitch position for cached inputs and outputs
// step 2: inline at the innermost dim for the rest of the graph
reference_tv->axis(3)->parallelize(ParallelType::TIDx);
}
unswitch_pos = 2;
@@ -830,18 +839,20 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
std::vector<TensorView*> vectorized_tvs;
bool should_vectorize_reference_tv = false;
for (auto tv : inputs_outputs) {
if (tv == reference_tv) {
should_vectorize_reference_tv = true;
}
if (!tv->isFusionInput()) {
vectorized_tvs.emplace_back(tv);
continue;
}
if (tv == reference_tv) {
should_vectorize_reference_tv = true;
}
// move inputs to consumers of inputs
auto consumer_tvs = ir_utils::consumerTvsOf(tv);
vectorized_tvs.insert(
vectorized_tvs.end(), consumer_tvs.begin(), consumer_tvs.end());
}
// Aggressively mark with vectorized and cleanup later. That way we
// don't have to manually specify parallelization outside the reference.
vectorize_id->parallelize(ParallelType::Vectorize);
scheduler_utils::parallelizeAllLike(
reference_tv, vectorized_tvs, {ParallelType::Vectorize});
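
Because the rendered diff drops the +/- markers, both the old and the new placement of the tv == reference_tv check appear above. A condensed sketch of how the loop reads after this fix, assuming the top-of-loop check is the added one and eliding the surrounding scheduler state:

// Flag the reference tensor before the fusion-input early-continue, so that
// should_vectorize_reference_tv is also set when the reference is not a
// fusion input.
for (auto tv : inputs_outputs) {
  if (tv == reference_tv) {
    should_vectorize_reference_tv = true;
  }
  if (!tv->isFusionInput()) {
    // Outputs and other non-input tensors are vectorized directly.
    vectorized_tvs.emplace_back(tv);
    continue;
  }
  // For fusion inputs, vectorize their consumers instead; the inputs
  // themselves are not scheduled.
  auto consumer_tvs = ir_utils::consumerTvsOf(tv);
  vectorized_tvs.insert(
      vectorized_tvs.end(), consumer_tvs.begin(), consumer_tvs.end());
}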
@@ -852,7 +863,9 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {

// Begin by inlining at the unswitch position for the entire DAG. The cached
// inputs, and outputs will keep this inline position, but other tensors will
// get a higher position in later inline propagation.
// get a higher position in later inline propagation. We need this separate
// step because we were not using ParallelType::Unroll, so we have to do
// unrolling manually.
InlinePropagator inline_unswitch(
reference_tv, unswitch_pos, ComputeAtMode::BestEffort);
spanning_tree.traverse(&inline_unswitch);
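
The comment above describes a two-step inline propagation, but only step 1 is visible in this hunk. A rough sketch of what the two steps could look like together; the second InlinePropagator call, its position argument, and its mode are assumptions for illustration and are not shown in this diff:

// Step 1 (shown above): inline the whole DAG at the unswitch position.
// Cached inputs and outputs keep this computeAt position.
InlinePropagator inline_unswitch(
    reference_tv, unswitch_pos, ComputeAtMode::BestEffort);
spanning_tree.traverse(&inline_unswitch);

// Step 2 (assumed, not part of this hunk): inline the rest of the graph at
// the innermost position, so only the cached inputs/outputs keep the
// unswitch position from step 1.
InlinePropagator inline_inner_most(
    reference_tv, /*pos=*/-1, ComputeAtMode::BestEffort);
spanning_tree.traverse(&inline_inner_most);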
38 changes: 38 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -18608,6 +18608,44 @@ TEST_F(NVFuserTest, FusionPointwiseBroadcast_CUDA) {
testValidate(&fusion, outputs, aten_inputs, {aten_y}, __LINE__, __FILE__);
}

TEST_F(NVFuserTest, FusionPointwiseVectorize_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

const int size = 1024 * 64;

TensorView* x = makeContigTensor(1);
fusion.addInput(x);
auto y = sin(x);
fusion.addOutput(y);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

// PyTorch's CUDA caching allocator should always return an aligned pointer
// for a freshly allocated tensor
at::Tensor at_x = at::randn({size}, options);

schedulePointwise(&fusion, {at_x});

for (auto x_consumer : ir_utils::consumerTvsOf(x)) {
bool found_vec_in_input = false;
for (auto id : x_consumer->domain()->domain()) {
if (isParallelTypeVectorize(id->getParallelType())) {
found_vec_in_input = true;
break;
}
}
TORCH_CHECK(found_vec_in_input, "Expect input to be vectorized");
}

for (auto id : y->domain()->domain()) {
if (isParallelTypeVectorize(id->getParallelType())) {
return;
}
}
TORCH_CHECK(false, "Expect output to be vectorized");
}

TEST_F(NVFuserTest, FusionSmemAliasSerial_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
