schedulePointwise cleanup: - computeAt + InlinePropagator #1815

Merged: 23 commits, Jul 14, 2022
4 changes: 2 additions & 2 deletions torch/csrc/jit/codegen/cuda/compute_at.cpp
@@ -185,7 +185,7 @@ void ComputeAt::runAt(
InlinePropagatorSelector selector(selected);

InlinePropagator inline_propagator(
selector.selected(), consumer, consumer_position, mode);
consumer, consumer_position, mode, selector.selected());
MaxProducerPosUpdater updater;

MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector);
@@ -227,7 +227,7 @@ void ComputeAt::runWith(
InlinePropagatorSelector selector(selected);

InlinePropagator inline_propagator(
selector.selected(), producer, producer_position, mode);
producer, producer_position, mode, selector.selected());
MaxProducerPosUpdater updater;

MaxRootDomainInfoSpanningTree path(producer, producer_position, &selector);
10 changes: 7 additions & 3 deletions torch/csrc/jit/codegen/cuda/inline_propagator.cpp
@@ -148,7 +148,7 @@ size_t InlinePropagator::getMaxPosAll(TensorView* tv, bool check_siblings) {

void InlinePropagator::setCAPos(TensorView* tv) {
size_t pos = mapped_reference_pos_.at(tv);
if (selected_.count(tv) && !tv->isFusionInput()) {
if ((selected_.empty() || selected_.count(tv)) && !tv->isFusionInput()) {
auto max_pos = getMaxPosAll(tv);
if (mode_ == ComputeAtMode::Standard) {
TORCH_INTERNAL_ASSERT(
@@ -171,10 +171,10 @@ void InlinePropagator::setCAPos(TensorView* tv) {
}

InlinePropagator::InlinePropagator(
std::unordered_set<TensorView*> selected,
TensorView* reference,
int64_t reference_pos,
ComputeAtMode mode)
ComputeAtMode mode,
std::unordered_set<TensorView*> selected)
: max_pos_calc(mode),
selected_(std::move(selected)),
reference_(reference),
@@ -213,6 +213,8 @@ void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) {
to_pos >= 0,
"Unable to propagate CA position from consumer ",
from,
" at ",
from_pos,
" to producer ",
to,
" because this would require replay.");
@@ -240,6 +242,8 @@ void InlinePropagator::propagateP2C(TensorView* from, TensorView* to) {
to_pos >= 0,
"Unable to propagate CA position from producer ",
from,
" at ",
from_pos,
" to consumer ",
to,
" because this would require replay.");
27 changes: 21 additions & 6 deletions torch/csrc/jit/codegen/cuda/inline_propagator.h
@@ -14,7 +14,8 @@ namespace cuda {
// Simple selector that only propagates across tensor views in the provided
// unordered_set. Will also propagate to all consumers of those tensors, and the
// siblings of those tensors.
class InlinePropagatorSelector : public MaxInfoSpanningTree::Selector {
class TORCH_CUDA_CU_API InlinePropagatorSelector
: public MaxInfoSpanningTree::Selector {
std::unordered_set<TensorView*> selected_;

public:
@@ -29,7 +30,7 @@ class InlinePropagatorSelector : public MaxInfoSpanningTree::Selector {
}
};

class MaxPosCalculator {
class TORCH_CUDA_CU_API MaxPosCalculator {
ComputeAtMode mode_ = ComputeAtMode::Standard;

// Root domains in producer that's unmappable to any of its consumers
@@ -67,7 +68,10 @@ class MaxPosCalculator {
MaxPosCalculator(ComputeAtMode mode);
};

class InlinePropagator : public MaxInfoSpanningTree::Propagator {
// Propagate inline position to the `selected` tensors in the DAG. If `selected`
// is not specified or empty, then propagate to the entire DAG.
class TORCH_CUDA_CU_API InlinePropagator
: public MaxInfoSpanningTree::Propagator {
// Checks producers and consumers to see what the maximum position in tv is
// that can be shared across both directions.
size_t getMaxPosAll(TensorView* tv, bool check_siblings = true);
@@ -94,10 +98,20 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator {

public:
InlinePropagator(
std::unordered_set<TensorView*> selected,
TensorView* reference,
int64_t reference_pos,
ComputeAtMode mode);
ComputeAtMode mode = ComputeAtMode::Standard,
std::unordered_set<TensorView*> selected = {});

InlinePropagator(
TensorView* reference,
int64_t reference_pos,
std::unordered_set<TensorView*> selected)
: InlinePropagator(
reference,
reference_pos,
ComputeAtMode::Standard,
selected) {}

~InlinePropagator() = default;

@@ -112,7 +126,8 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator {
// the tensors, and it is not needed to compute the max producer position in a
// specific order. But MaxInfoSpanningTree provides a very convenient API to
// visit the tensors, so I just use it for cleaner code.
class MaxProducerPosUpdater : public MaxInfoSpanningTree::Propagator {
class TORCH_CUDA_CU_API MaxProducerPosUpdater
: public MaxInfoSpanningTree::Propagator {
std::unordered_set<TensorView*> updated_;
void handle(TensorView* tv);

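For context, a minimal sketch of how the reworked InlinePropagator is meant to be driven, based on the constructor above and the call sites in compute_at.cpp and pointwise.cpp in this PR. The reference tensor, position, and selected tensors below are placeholders, and unrelated setup is elided:

#include <torch/csrc/jit/codegen/cuda/inline_propagator.h>

// Inline the entire DAG at position `pos` of `reference_tv`: per the class
// comment above, omitting `selected` (or passing an empty set) propagates
// to every tensor reached by the spanning tree.
InlinePropagator inline_all(reference_tv, pos, ComputeAtMode::BestEffort);
MaxRootDomainInfoSpanningTree tree(reference_tv);
tree.traverse(&inline_all);

// Restrict propagation to a chosen subset of tensors instead.
std::unordered_set<TensorView*> selected{tv_a, tv_b};  // placeholder tensors
InlinePropagator inline_some(
    reference_tv, pos, ComputeAtMode::BestEffort, selected);
tree.traverse(&inline_some);

// Refresh max producer positions once inlining is done, as the callers in
// this PR do.
MaxProducerPosUpdater updater;
tree.traverse(&updater);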
108 changes: 35 additions & 73 deletions torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
@@ -1,6 +1,7 @@
#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise.h>

#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
#include <torch/csrc/jit/codegen/cuda/inline_propagator.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
@@ -13,6 +14,9 @@

#include <ATen/cuda/CUDAContext.h>

#include <algorithm>
#include <unordered_map>

namespace torch {
namespace jit {
namespace fuser {
@@ -671,6 +675,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
}
}

int64_t unswitch_pos;
if (params.break_point) {
// 2D parallelization scheme
TORCH_INTERNAL_ASSERT(rhs_i >= 0 && lhs_i >= 0);
@@ -723,8 +728,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
// [i-remainder, BIDy{65535} | BIDx, TIDy | Unswitch, Unroll, TIDx]
reference_tv->split(0, 65535);
reference_tv->axis(1)->parallelize(ParallelType::BIDy);
unswitch_pos = 5;
} else {
reference_tv->axis(0)->parallelize(ParallelType::BIDy);
unswitch_pos = 4;
}
} else {
// [BIDx | BIDy TIDy | Unswitch, Unroll, TIDx]
@@ -734,8 +741,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
// [BIDx | i-remainder, BIDy{65535}, TIDy | Unswitch, Unroll, TIDx]
reference_tv->split(1, 65535);
reference_tv->axis(2)->parallelize(ParallelType::BIDy);
unswitch_pos = 5;
} else {
reference_tv->axis(1)->parallelize(ParallelType::BIDy);
unswitch_pos = 4;
}
}
} else {
@@ -747,8 +756,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
// [i-remainder, BIDy{65535} | BIDx | Unswitch, Unroll, TIDx]
reference_tv->split(0, 65535);
reference_tv->axis(1)->parallelize(ParallelType::BIDy);
unswitch_pos = 4;
} else {
reference_tv->axis(0)->parallelize(ParallelType::BIDy);
unswitch_pos = 3;
}
} else {
// [BIDx | BIDy | Unswitch, Unroll, TIDx]
@@ -757,8 +768,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
// [BIDx | i-remainder, BIDy{65535} | Unswitch, Unroll, TIDx]
reference_tv->split(1, 65535);
reference_tv->axis(2)->parallelize(ParallelType::BIDy);
unswitch_pos = 4;
} else {
reference_tv->axis(1)->parallelize(ParallelType::BIDy);
unswitch_pos = 3;
}
}
}
@@ -803,10 +816,12 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
reference_tv->axis(3)->parallelize(ParallelType::TIDx);
}
unswitch_pos = 2;
}

TransformPropagator propagator(reference_tv);
MaxRootDomainInfoSpanningTree(reference_tv).traverse(&propagator);
MaxRootDomainInfoSpanningTree spanning_tree(reference_tv);
spanning_tree.traverse(&propagator);
scheduler_utils::parallelizeAllLike(reference_tv, all_tvs);

if (params.vectorize) {
@@ -841,84 +856,31 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
}
}

// Compute at into cached inputs
std::vector<TensorView*> consumers_of_cached_inputs;
// Cache of input, and one of its consumers
std::vector<std::pair<TensorView*, TensorView*>> input_cache_and_consumer;
{
// Avoid duplicate additions, so track what we add
std::unordered_set<TensorView*> added;
for (auto cached_input : cached_inputs) {
auto consumer_tvs = ir_utils::consumerTvsOf(cached_input);
TORCH_INTERNAL_ASSERT(
consumer_tvs.size(),
"Input was not succesfully filtered out for scheduling but wasn't used.");

// Grab a consumer which will be used for computeAt structure of cached
// input into a consumer
input_cache_and_consumer.emplace_back(
std::make_pair(cached_input, consumer_tvs[0]));

// Grab all consumers which will be used for inlining computeAt for the
// body of the computation (excluding caching inputs/outputs)
for (auto consumer_tv : consumer_tvs) {
// Don't duplicate
if (added.insert(consumer_tv).second) {
consumers_of_cached_inputs.emplace_back(consumer_tv);
}
}
}
}

for (auto entry : input_cache_and_consumer) {
// Compute at inside unswitch position:
auto input_cache = entry.first;
auto input_cache_consumer = entry.second;

auto unswitch_it = std::find_if(
input_cache_consumer->domain()->domain().begin(),
input_cache_consumer->domain()->domain().end(),
[](IterDomain* id) {
return id->getParallelType() == ParallelType::Unswitch;
});
auto unswitch_pos =
unswitch_it == input_cache_consumer->domain()->domain().end()
? -1
: std::distance(
input_cache_consumer->domain()->domain().begin(), unswitch_it) +
1;
// Begin by inlining at the unswitch position for the entire DAG. The cached
// inputs, and outputs will keep this inline position, but other tensors will
// get a higher position in later inline propagation.
InlinePropagator inline_unswitch(
reference_tv, unswitch_pos, ComputeAtMode::BestEffort);
spanning_tree.traverse(&inline_unswitch);

input_cache->computeAt(
input_cache_consumer, unswitch_pos, ComputeAtMode::BestEffort);
// Inline at the inner most position. The CA position of all tensors except
// inputs, cached inputs and outputs will be updated.
std::unordered_set<TensorView*> inner_most_tensors(
all_tvs.begin(), all_tvs.end());
for (auto cached_input : cached_inputs) {
inner_most_tensors.erase(cached_input);
}

// Producers for inlined computeAt
std::vector<TensorView*> compute_from = consumers_of_cached_inputs;

// Consumers for inlined computeAt
std::vector<TensorView*> compute_to;
// Compute at cached outputs
//[BIDx, Unswitch, Vectorization, TIDx]
for (auto entry : cached_outputs) {
auto cached_output = entry.first;
auto output = entry.second;

auto unswitch_it = std::find_if(
output->domain()->domain().begin(),
output->domain()->domain().end(),
[](IterDomain* id) {
return id->getParallelType() == ParallelType::Unswitch;
});
auto unswitch_pos = unswitch_it == output->domain()->domain().end()
? -1
: std::distance(output->domain()->domain().begin(), unswitch_it) + 1;

cached_output->computeAt(output, unswitch_pos, ComputeAtMode::BestEffort);
compute_to.push_back(cached_output);
inner_most_tensors.erase(output);
}
InlinePropagator inline_inner_most(
reference_tv, -1, ComputeAtMode::BestEffort, inner_most_tensors);
spanning_tree.traverse(&inline_inner_most);

scheduler_utils::computeAtBetween(
compute_from, compute_to, -1, ComputeAtMode::BestEffort);
// Fix max producer position
MaxProducerPosUpdater updater;
Review thread on `MaxProducerPosUpdater updater;`:

csarofeen (Owner): Should we call MaxProducerPosUpdater at the end of InlinePropagator::traverse?

Author (Collaborator): InlinePropagator does not have a traverse; MaxRootDomainInfoSpanningTree does. So probably not?

Author (Collaborator): Probably we should move the functionality of MaxProducerPosUpdater into InlinePropagator::setCAPos? I can do that in a separate PR, but doing so will slow down compilation, since we would be updating the max producer position of a tensor multiple times.

Author (Collaborator): Or... is the max producer position only used in expression sorting? If so, we could leave it unset during scheduling and add a lowering pass that sets it.

csarofeen (Owner), Jul 14, 2022: The max producer position is important for multiple transformation passes: it ensures we aren't modifying tensors that other tensors rely on. It is more of a rule to enforce so that the schedule stays legal as we modify it.

csarofeen (Owner): I think adding it to the tail end of the inlining pass is worth the extra cost, as it will help catch user errors in scheduling.

csarofeen (Owner): Agreed, it can be done in a separate PR.
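For illustration only, one hypothetical shape of the "fold the updater into setCAPos" idea discussed above. None of this is part of the PR; updateMaxProducerPosition is an assumed helper name, and only ir_utils::consumerTvsOf comes from the existing code:

void InlinePropagator::setCAPos(TensorView* tv) {
  // ... existing logic that computes and sets the CA position of tv ...
  // Eagerly refresh the max producer position of every consumer. Doing this
  // on every setCAPos call is what would update some tensors multiple times,
  // hence the compile-time concern raised above.
  for (auto consumer : ir_utils::consumerTvsOf(tv)) {
    updateMaxProducerPosition(consumer);  // assumed helper, not in this PR
  }
}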

spanning_tree.traverse(&updater);
}

} // namespace cuda
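Condensed, the new inlining flow in schedulePointwise reads roughly as below; identifiers match the diff above, and the transform propagation, parallelization, and vectorization steps in between are elided:

// Inline everything at the unswitch position, best effort. Cached inputs
// and the fusion outputs keep this computeAt position.
InlinePropagator inline_unswitch(
    reference_tv, unswitch_pos, ComputeAtMode::BestEffort);
spanning_tree.traverse(&inline_unswitch);

// Inline everything except cached inputs and fusion outputs at the
// innermost position (-1), raising their computeAt positions further.
std::unordered_set<TensorView*> inner_most_tensors(
    all_tvs.begin(), all_tvs.end());
for (auto cached_input : cached_inputs) {
  inner_most_tensors.erase(cached_input);
}
for (auto entry : cached_outputs) {
  inner_most_tensors.erase(entry.second);  // entry.second is the fusion output
}
InlinePropagator inline_inner_most(
    reference_tv, -1, ComputeAtMode::BestEffort, inner_most_tensors);
spanning_tree.traverse(&inline_inner_most);

// Finally, recompute max producer positions across the DAG.
MaxProducerPosUpdater updater;
spanning_tree.traverse(&updater);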
26 changes: 13 additions & 13 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -1362,26 +1362,26 @@ TEST_F(NVFuserTest, FusionParser_CUDA) {
// 2. use a fuzzy compare (ignore non-significant whitespaces for example)
const std::string expected_kernel = R"(
__global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T1, Tensor<float, 1> T3) {
int64_t i51;
i51 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
if ((i51 < T0.size[0])) {
int64_t i50;
i50 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
if ((i50 < T0.size[0])) {
float T5[1];
T5[0] = 0;
T5[0]
= T1[i51];
= T1[i50];
float T4[1];
T4[0] = 0;
T4[0]
= T0[i51];
float T6[1];
= T0[i50];
float T2[1];
T2[0]
= T4[0]
* T5[0];
float T6[1];
T6[0]
= T2[0]
* T4[0];
T3[i51]
T3[i50]
= T6[0];
}
}
@@ -19086,18 +19086,17 @@ TEST_F(NVFuserTest, FusionChannelsLastParser_CUDA) {
// 2. use a fuzzy compare (ignore non-significant whitespaces for example)
const std::string expected_kernel = R"(
__global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, Tensor<__half, 4> T7) {
int64_t i172;
i172 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
if ((i172 < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) {
int64_t i171;
i171 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
if ((i171 < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) {
__half T9[1];
T9[0] = 0;
T9[0]
= T2[((((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) / (T0.size[1] * (T0.size[2] * T0.size[3]))) * ((T0.size[2] * T0.size[1]) * T0.size[3])) + ((((((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) % (T0.size[2] * T0.size[3])) % T0.size[3]) * (T0.size[2] * T0.size[1])) + (((((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) / (T0.size[2] * T0.size[3])) * T0.size[2]) + (((((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) % (T0.size[2] * T0.size[3])) / T0.size[3])];
__half T8[1];
T8[0] = 0;
T8[0]
= T0[i172];
__half T10[1];
= T0[i171];
float T3[1];
T3[0]
= __half2float(T9[0]);
@@ -19114,9 +19113,10 @@
float T6[1];
T6[0]
= relu(T5[0]);
__half T10[1];
T10[0]
= __float2half(T6[0]);
T7[i172]
T7[i171]
= T10[0];
}
}