Improve view support on pointwise and transpose scheduler #1906

Merged: 13 commits, Aug 16, 2022
4 changes: 2 additions & 2 deletions torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp
@@ -878,7 +878,7 @@ TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getPersistentHeuristics(
data_cache, [&first_red_tv]() {
return std::make_unique<std::vector<TensorView*>>(
scheduler_utils::getInputsOutputsWithInnerDim(
- first_red_tv, true));
+ first_red_tv, true, true));
});

auto& vectorizable_inputs_outputs = vectorizable_inputs_outputs_entry.get();
@@ -888,7 +888,7 @@ TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getPersistentHeuristics(
data_cache, [&first_red_tv]() {
return std::make_unique<std::vector<TensorView*>>(
scheduler_utils::getInputsOutputsWithInnerDim(
- first_red_tv, false));
+ first_red_tv, false, false));
});

auto& unrollable_inputs_outputs = unrollable_inputs_outputs_entry.get();
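Across the normalization, pointwise, reduction, and transpose schedulers, this PR extends scheduler_utils::getInputsOutputsWithInnerDim from two arguments to three. Below is a minimal sketch of how the heuristics call it after the change. The parameter names and their semantics are assumptions for illustration (suggested by the new comment in transpose.cpp later in this diff, which passes false for the third flag because it "doesn't want to check contiguity"); they are not taken from this diff.

// Assumed shape of the extended helper (names are placeholders, not from this PR):
// std::vector<TensorView*> getInputsOutputsWithInnerDim(
//     TensorView* reference_tv,
//     bool inner_only,       // only tensors whose inner-most dim maps to the reference's
//     bool vectorize_pass);  // additionally require the contiguity needed for vectorization

// Vectorization candidates: inner-most dim must match and be contiguous.
auto vectorizable_inputs_outputs =
    scheduler_utils::getInputsOutputsWithInnerDim(first_red_tv, true, true);

// Unroll candidates: neither restriction applies.
auto unrollable_inputs_outputs =
    scheduler_utils::getInputsOutputsWithInnerDim(first_red_tv, false, false);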
4 changes: 2 additions & 2 deletions torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
@@ -157,7 +157,7 @@ std::shared_ptr<PointwiseParams> getPointwiseHeuristics(
data_cache, [&largest_out]() {
return std::make_unique<std::vector<TensorView*>>(
scheduler_utils::getInputsOutputsWithInnerDim(
- largest_out, true));
+ largest_out, true, true));
});

constexpr int64_t kSixteen = 16; // clang tidy
@@ -691,7 +691,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
if (params.vectorize) {
// Grab all tensor views that should be vectorized
auto inputs_outputs =
- scheduler_utils::getInputsOutputsWithInnerDim(reference_tv, true);
+ scheduler_utils::getInputsOutputsWithInnerDim(reference_tv, true, true);
std::vector<TensorView*> vectorized_tvs;
bool should_vectorize_reference_tv = false;
for (auto tv : inputs_outputs) {
4 changes: 2 additions & 2 deletions torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp
@@ -924,7 +924,7 @@ TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getReductionHeuristics(
data_cache, [&reduction_tv]() {
return std::make_unique<std::vector<TensorView*>>(
scheduler_utils::getInputsOutputsWithInnerDim(
- reduction_tv, true));
+ reduction_tv, true, true));
});

auto& vectorizable_inputs_outputs = vectorizable_inputs_outputs_entry.get();
@@ -934,7 +934,7 @@ TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getReductionHeuristics(
data_cache, [&reduction_tv]() {
return std::make_unique<std::vector<TensorView*>>(
scheduler_utils::getInputsOutputsWithInnerDim(
- reduction_tv, false));
+ reduction_tv, false, false));
});

auto& unrollable_inputs_outputs = unrollable_inputs_outputs_entry.get();
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp
@@ -266,7 +266,7 @@ void multiReductionInliner(

// Grab all tensor views that should be vectorized
auto vectorizable_inputs_outputs =
- scheduler_utils::getInputsOutputsWithInnerDim(reference_tv, true);
+ scheduler_utils::getInputsOutputsWithInnerDim(reference_tv, true, true);

auto vectorizable_expr = [](Expr* e) {
return e->isA<UnaryOp>() &&
63 changes: 46 additions & 17 deletions torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp
@@ -86,28 +86,53 @@ class DomainMap : public pointwise_utils::DomainMap {
// The order here must be deterministic, because in transpose heuristics, we
// have `vectorize_factor1` and `vectorize_factor2` and we need to be sure
// that `1` and `2` are assigned to the same group across runs.
+ //
+ // When a view is present in the graph, there are two cases: if the view
+ // doesn't touch the inner dimension of any group, then supporting it is
+ // trivial. If the view does touch an inner-most dim, we keep track of the
+ // inner-most dimension of the view's splits and merges.
+ //
+ // For example, if you have:
+ // T0 [2, 3, 5] <-- input
+ // T1 [2, 5, 3] <-- input
+ // T2 [2, 5, 3] = transpose(T0) + T1
+ // T3 [2, 15] = view(T2)
+ // output <-- T3
+ //
+ // Then T3 should be in the same group as T1, and T0 should be in a
+ // different group from T1 and T3.
std::vector<std::vector<TensorView*>> groupInputsOutputsByInnerDim() const {
std::vector<std::vector<TensorView*>> groups;
auto output_tvs = ir_utils::filterByType<TensorView>(fusion_->outputs());
auto input_tvs = ir_utils::filterByType<TensorView>(fusion_->inputs());
- std::unordered_map<size_t, IterDomain*> group_to_inner_dim_map;
- decltype(input_tvs)* tv_filtered_group[2] = {&output_tvs, &input_tvs};
- for (auto view : tv_filtered_group) {
- for (auto tv : *view) {
- auto inner_most_id = scheduler_utils::innerMostRootDim(tv);
- bool found = false;
- for (auto gi : c10::irange(groups.size())) {
- auto& g = groups[gi];
- auto group_inner_dim = group_to_inner_dim_map.at(gi);
- if (areExactMapped(inner_most_id, group_inner_dim)) {
- g.emplace_back(tv);
- found = true;
- break;
- }
+ std::unordered_set<TensorView*> grouped;
+ decltype(input_tvs)* tv_filtered_groups[2] = {&output_tvs, &input_tvs};
+ for (auto tv_filtered_group : tv_filtered_groups) {
+ for (auto tv : *tv_filtered_group) {
+ if (grouped.count(tv) > 0) {
+ continue;
+ }
- if (!found) {
- group_to_inner_dim_map[groups.size()] = inner_most_id;
- groups.push_back({tv});
+ groups.emplace_back(std::vector<TensorView*>{tv});
+ grouped.emplace(tv);
+ // We only want to grab the inner-most dimension, because we don't want
+ // tensors with different inner-most dimensions to be put in the same
+ // group. For example, if we have:
+ // T2[i1, i3*i2] = relu(view(transpose(T1[i1, i2, i3])))
+ // then we don't want T1 and T2 to be in the same group.
+ //
+ // But we don't want to check contiguity. For example, if we have:
+ // T1[i1, i2, i3] (contiguous) + T2[i1, i2, i3] (discontiguous)
+ // then we still want T1 and T2 to be grouped together.
+ auto group =
+ scheduler_utils::getInputsOutputsWithInnerDim(tv, true, false);
+ for (auto member_tv : group) {
+ TORCH_INTERNAL_ASSERT(
+ grouped.count(member_tv) == 0 || member_tv == tv,
+ "The group of ",
+ member_tv->toString(),
+ " is ambiguous. This is likely a bug.");
+ grouped.emplace(member_tv);
+ groups.back().emplace_back(member_tv);
+ }
}
}
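The grouping rule described in the new comment above (the T0/T1/T2/T3 example) can be exercised with a small fusion. The sketch below is illustrative only: it assumes the nvfuser C++ definition API of this branch (transpose(tv, dim0, dim1), view(tv, original_sizes, new_sizes)) and the makeConcreteTensor helper used in the C++ tests, none of which appear in this diff.

#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>

using namespace torch::jit::fuser::cuda;

// Build the T0/T1/T2/T3 example from the comment above.
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({2, 3, 5}); // T0 [2, 3, 5], test helper (assumed available)
auto tv1 = makeConcreteTensor({2, 5, 3}); // T1 [2, 5, 3]
fusion.addInput(tv0);
fusion.addInput(tv1);

auto tv2 = add(transpose(tv0, 1, 2), tv1); // T2 [2, 5, 3]
auto tv3 = view(tv2, {2, 5, 3}, {2, 15});  // T3 [2, 15]
fusion.addOutput(tv3);

// Expected grouping per the comment: {T1, T3} share an inner dimension,
// while {T0} lands in a separate group.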
@@ -263,6 +288,10 @@ std::shared_ptr<TransposeParams> getTransposeHeuristics(
runtime_info.getInnerDimVectorizableWidth(tv);
vectorize_factor1 = std::min(vectorize_factor1, tv_vectorize_factor);
}
+ // TODO: Since group2 only has global->shared and shared->global set ops, we
+ // can have fine-grained control of unroll/vectorization at a per-tensor
+ // level. We should not be using a single global vectorize factor for the
+ // entire group 2.
Comment on lines +291 to +294 (Collaborator, Author):
Note for future work: This also means group2 does not need a reference tensor; every tensor is its own reference tensor.

Reply (Owner):
Yeah, we could consider doing this for all our schedulers, but it hasn't seemed like a pragmatic concern. We should revisit (more generically) if we see any cases where this would be beneficial.

for (auto tv : grouped_inputs_outputs[1]) {
const auto tv_vectorize_factor =
runtime_info.getInnerDimVectorizableWidth(tv);
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/transpose.h
@@ -93,7 +93,7 @@ TORCH_CUDA_CU_API LaunchParams scheduleTranspose(

//! Utility for canSchedule interface to check if this fusion has at least two
//! groups, each with a fully broadcasted reference tensor.
- bool hasAtLeastTwoValidGroups(Fusion* fusion);
+ TORCH_CUDA_CU_API bool hasAtLeastTwoValidGroups(Fusion* fusion);

} // namespace cuda
} // namespace fuser
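Adding TORCH_CUDA_CU_API exports hasAtLeastTwoValidGroups from the library, presumably so C++ tests or other callers outside the translation unit can query transpose-scheduler eligibility directly. A minimal usage sketch, assuming a fully defined Fusion object named fusion:

#include <torch/csrc/jit/codegen/cuda/scheduler/transpose.h>

using namespace torch::jit::fuser::cuda;

// `fusion` is assumed to already contain inputs, ops, and outputs.
// The transpose scheduler only applies when the fusion has at least two
// groups, each with a valid reference tensor (per the header comment above).
bool eligible = hasAtLeastTwoValidGroups(&fusion);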