forked from pytorch/pytorch
Improve view support on pointwise and transpose scheduler #1906
Merged
Commits (13 total; the diff below shows changes from 8 commits):
18d5bc1  test (zasdfgbnm)
47cd850  fix (zasdfgbnm)
747c808  fix view grouping (zasdfgbnm)
9d0451f  transpose test (zasdfgbnm)
88ce6d1  fix (zasdfgbnm)
0ea00e8  split inner_only and vectorize_pass (zasdfgbnm)
788cb1d  TODO (zasdfgbnm)
55e2117  update (zasdfgbnm)
cd379df  Support expand->flatten (csarofeen)
0fdf41a  Trivial reduce before flatten. (csarofeen)
b8631ad  fix reduction+flatten (zasdfgbnm)
784cb42  spanning tree (zasdfgbnm)
1602327  fix (zasdfgbnm)
@@ -86,28 +86,53 @@ class DomainMap : public pointwise_utils::DomainMap {
   // The order here must be deterministic, because in transpose heuristics, we
   // have `vectorize_factor1` and `vectorize_factor2` and we need to be sure
   // that `1` and `2` are assigned to the same group across runs.
+  //
+  // In the case where view is present in the graph, there are two cases: if the
+  // view doesn't touch any inner dimension of any group, then the support of it
+  // is trivial. In the case where view actually touches an inner-most dim, we
+  // keep track of the inner-most dimension of view's split and merges.
+  //
+  // For example, if you have:
+  // T0 [2, 3, 5] <-- input
+  // T1 [2, 5, 3] <-- input
+  // T2 [2, 5, 3] = transpose(T0) + T1
+  // T3 [2, 15] = view(T2)
+  // output <-- T3
+  //
+  // Then T3 should be in the same group with T1, and T0 should have
+  // different group with T1 and T3.
   std::vector<std::vector<TensorView*>> groupInputsOutputsByInnerDim() const {
     std::vector<std::vector<TensorView*>> groups;
     auto output_tvs = ir_utils::filterByType<TensorView>(fusion_->outputs());
     auto input_tvs = ir_utils::filterByType<TensorView>(fusion_->inputs());
-    std::unordered_map<size_t, IterDomain*> group_to_inner_dim_map;
-    decltype(input_tvs)* tv_filtered_group[2] = {&output_tvs, &input_tvs};
-    for (auto view : tv_filtered_group) {
-      for (auto tv : *view) {
-        auto inner_most_id = scheduler_utils::innerMostRootDim(tv);
-        bool found = false;
-        for (auto gi : c10::irange(groups.size())) {
-          auto& g = groups[gi];
-          auto group_inner_dim = group_to_inner_dim_map.at(gi);
-          if (areExactMapped(inner_most_id, group_inner_dim)) {
-            g.emplace_back(tv);
-            found = true;
-            break;
-          }
-        }
-        if (!found) {
-          group_to_inner_dim_map[groups.size()] = inner_most_id;
-          groups.push_back({tv});
-        }
+    std::unordered_set<TensorView*> grouped;
+    decltype(input_tvs)* tv_filtered_groups[2] = {&output_tvs, &input_tvs};
+    for (auto tv_filtered_group : tv_filtered_groups) {
+      for (auto tv : *tv_filtered_group) {
+        if (grouped.count(tv) > 0) {
+          continue;
+        }
+        groups.emplace_back(std::vector<TensorView*>{tv});
+        grouped.emplace(tv);
+        // We only want to grab the inner-most dimension, because we don't want
+        // tensors with different inner-most dimension to be put in the same
+        // group. For example, if we have:
+        // T2[i1, i3*i2] = relu(view(transpose(T1[i1, i2, i3])))
+        // then we don't want T1 and T2 to be in the same group.
+        //
+        // But we don't want to check contiguity. For example, if we have:
+        // T1[i1, i2, i3] (contiguous) + T2[i1, i2, i3] (discontiguous)
+        // Then we still want to T1 and T2 to be grouped together.
+        auto group =
+            scheduler_utils::getInputsOutputsWithInnerDim(tv, true, false);
+        for (auto member_tv : group) {
+          TORCH_INTERNAL_ASSERT(
+              grouped.count(member_tv) == 0 || member_tv == tv,
+              "The group of ",
+              member_tv->toString(),
+              " is ambiguous. This is likely a bug.");
+          grouped.emplace(member_tv);
+          groups.back().emplace_back(member_tv);
+        }
       }
     }
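To make the T0/T1/T2/T3 example in the new comment concrete, the same fusion could be written roughly as below, following the style of the repo's C++ tests. This is a sketch, not code from this PR: the helper makeConcreteTensor and the exact transpose/view signatures are assumed from the test conventions. Per the comment, groupInputsOutputsByInnerDim() should then group tv1 with tv3 (the view's merged inner dimension maps back to tv1's inner-most dim) and leave tv0 in its own group.

// Sketch only (assumed test-style helpers, not part of this PR's diff).
#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>

using namespace torch::jit::fuser::cuda;

void viewGroupingExample() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeConcreteTensor({2, 3, 5}); // T0 [2, 3, 5] <-- input
  auto tv1 = makeConcreteTensor({2, 5, 3}); // T1 [2, 5, 3] <-- input
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  auto tv2 = add(transpose(tv0, 1, 2), tv1); // T2 [2, 5, 3] = transpose(T0) + T1
  auto tv3 = view(tv2, {2, 5, 3}, {2, 15});  // T3 [2, 15] = view(T2)
  fusion.addOutput(tv3);                     // output <-- T3

  // Expected grouping per the comment above: {tv1, tv3} share an exact-mapped
  // inner-most dimension, while tv0's inner-most dimension (extent 5) puts it
  // in a different group.
}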
@@ -263,6 +288,10 @@ std::shared_ptr<TransposeParams> getTransposeHeuristics(
         runtime_info.getInnerDimVectorizableWidth(tv);
     vectorize_factor1 = std::min(vectorize_factor1, tv_vectorize_factor);
   }
+  // TODO: Since group2 only has global->shared and shared->global set op, we
+  // can have fine-grained control of unroll/vectorization at per tensor level.
+  // We should not be using a single global vectorize factor for the entire
+  // group 2
   for (auto tv : grouped_inputs_outputs[1]) {
     const auto tv_vectorize_factor =
         runtime_info.getInnerDimVectorizableWidth(tv);

Review comment on the added TODO: Yeah, we could consider doing this for all our schedulers, but it hasn't seemed like a pragmatic concern. We should revisit (more generically) if we see any cases where this would be beneficial.
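As a minimal sketch of what the TODO and the reviewer's point could look like, group 2's widths could be recorded per tensor instead of being folded into a single shared factor. The snippet below reuses names from the surrounding diff and is an illustration, not code from the PR:

// Hypothetical per-tensor bookkeeping instead of one shared vectorize_factor2.
std::unordered_map<TensorView*, size_t> group2_vectorize_factor;
for (auto tv : grouped_inputs_outputs[1]) {
  // Group 2 tensors only feed global<->shared copies, so each copy could pick
  // its own unroll/vectorization width.
  group2_vectorize_factor[tv] = runtime_info.getInnerDimVectorizableWidth(tv);
}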
Review comment (note for future work): This also means group2 does not need a reference tensor; every tensor is its own reference tensor.