diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
index 4ab5a00c7cee2..e9080a72a2b87 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
@@ -692,7 +692,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
       // [outer, Unswitch | i-remainder, TIDx, Vectorization]
       reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
       reference_tv->axis(3)->parallelize(ParallelType::TIDx);
-
+      // Vectorization is propagated separately
       vectorize_id = reference_tv->axis(4);
 
       // [outer, Unswitch | i-remainder, TIDx, Vectorization]
@@ -708,6 +708,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
       reference_tv->reorder({{1, 2}});
       // [outer, i-remainder, unswitch, unroll, TIDx ]
       reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
+      // Here we do not set axis(3)->parallelize(Unroll) because we do not want
+      // it to be propagated. We manually unroll by splitting the inline
+      // propagation process into two steps:
+      // step 1: inline at the unswitch position for cached inputs and outputs
+      // step 2: inline at the innermost dim for the rest of the graph
       reference_tv->axis(4)->parallelize(ParallelType::TIDx);
 
       //[outer | i-remainder, Unswitch, Unroll, TIDx]
@@ -794,8 +799,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
       reference_tv->axis(0)->parallelize(ParallelType::BIDx);
       reference_tv->axis(1)->parallelize(ParallelType::TIDx);
       reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
-      // Aggressively mark with vectorized and cleanup later. That way we
-      // don't have to manually specify parallelization outside the reference.
+      // Vectorization is propagated separately
       vectorize_id = reference_tv->axis(3);
 
       //[BIDx, TIDx, Unswitch, Vectorization]
@@ -813,6 +817,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
       // [BIDx, Unswitch, Unroll, TIDx]
       reference_tv->axis(0)->parallelize(ParallelType::BIDx);
       reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
+      // Here we do not set axis(2)->parallelize(Unroll) because we do not want
+      // it to be propagated. We manually unroll by splitting the inline
+      // propagation process into two steps:
+      // step 1: inline at the unswitch position for cached inputs and outputs
+      // step 2: inline at the innermost dim for the rest of the graph
       reference_tv->axis(3)->parallelize(ParallelType::TIDx);
     }
     unswitch_pos = 2;
@@ -830,18 +839,20 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
     std::vector<TensorView*> vectorized_tvs;
     bool should_vectorize_reference_tv = false;
     for (auto tv : inputs_outputs) {
+      if (tv == reference_tv) {
+        should_vectorize_reference_tv = true;
+      }
       if (!tv->isFusionInput()) {
         vectorized_tvs.emplace_back(tv);
         continue;
       }
-      if (tv == reference_tv) {
-        should_vectorize_reference_tv = true;
-      }
       // move inputs to consumers of inputs
       auto consumer_tvs = ir_utils::consumerTvsOf(tv);
       vectorized_tvs.insert(
           vectorized_tvs.end(), consumer_tvs.begin(), consumer_tvs.end());
     }
+    // Aggressively mark with vectorized and cleanup later. That way we
+    // don't have to manually specify parallelization outside the reference.
     vectorize_id->parallelize(ParallelType::Vectorize);
     scheduler_utils::parallelizeAllLike(
         reference_tv, vectorized_tvs, {ParallelType::Vectorize});
@@ -852,7 +863,9 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
 
   // Begin by inlining at the unswitch position for the entire DAG. The cached
   // inputs, and outputs will keep this inline position, but other tensors will
-  // get a higher position in later inline propagation.
+  // get a higher position in later inline propagation. We need this separate
+  // step because we were not using ParallelType::Unroll, so we have to do
+  // unrolling manually.
   InlinePropagator inline_unswitch(
       reference_tv, unswitch_pos, ComputeAtMode::BestEffort);
   spanning_tree.traverse(&inline_unswitch);
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
index fdabf1301dd05..e0f5cb44baac6 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -18608,6 +18608,44 @@ TEST_F(NVFuserTest, FusionPointwiseBroadcast_CUDA) {
   testValidate(&fusion, outputs, aten_inputs, {aten_y}, __LINE__, __FILE__);
 }
 
+TEST_F(NVFuserTest, FusionPointwiseVectorize_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  const int size = 1024 * 64;
+
+  TensorView* x = makeContigTensor(1);
+  fusion.addInput(x);
+  auto y = sin(x);
+  fusion.addOutput(y);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+
+  // PyTorch's CUDA caching allocator should always return an aligned pointer
+  // for a freshly allocated tensor
+  at::Tensor at_x = at::randn({size}, options);
+
+  schedulePointwise(&fusion, {at_x});
+
+  for (auto x_consumer : ir_utils::consumerTvsOf(x)) {
+    bool found_vec_in_input = false;
+    for (auto id : x_consumer->domain()->domain()) {
+      if (isParallelTypeVectorize(id->getParallelType())) {
+        found_vec_in_input = true;
+        break;
+      }
+    }
+    TORCH_CHECK(found_vec_in_input, "Expect input to be vectorized");
+  }
+
+  for (auto id : y->domain()->domain()) {
+    if (isParallelTypeVectorize(id->getParallelType())) {
+      return;
+    }
+  }
+  TORCH_CHECK(false, "Expect output to be vectorized");
+}
+
 TEST_F(NVFuserTest, FusionSmemAliasSerial_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
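
For readers of the scheduler change above, a minimal sketch of the two-step manual-unroll inlining that the new comments describe. The first InlinePropagator call is quoted from the hunk above; the second call (inline_inner_most, inlining at position -1) and the detail of how cached inputs and outputs keep their unswitch-position inlining are assumptions about the surrounding scheduler code, not part of this patch.

  // Step 1 (from the hunk above): inline the entire DAG at the unswitch
  // position; cached inputs and outputs will keep this compute-at position.
  InlinePropagator inline_unswitch(
      reference_tv, unswitch_pos, ComputeAtMode::BestEffort);
  spanning_tree.traverse(&inline_unswitch);

  // Step 2 (assumed, not shown in this diff): inline the rest of the graph at
  // the innermost position. This is what replaces marking the unroll axis with
  // ParallelType::Unroll and letting parallelizeAllLike propagate it.
  InlinePropagator inline_inner_most(
      reference_tv, /*pos=*/-1, ComputeAtMode::BestEffort);
  spanning_tree.traverse(&inline_inner_most);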