Cleanup trivial reduction workarounds #2006

Merged · 6 commits · Oct 1, 2022 · Changes from 4 commits
26 changes: 10 additions & 16 deletions torch/csrc/jit/codegen/cuda/inlining.cpp
@@ -153,29 +153,25 @@ size_t MaxPosCalculator::getMaxPosAll(
   return max_pos;
 }
 
-void inlineMost(const std::unordered_set<IterDomain*>& uninlinable_ids) {
-  inlineMost(ir_utils::allTvs(FusionGuard::getCurFusion()), uninlinable_ids);
+void inlineMost() {
Owner (review comment): For now this seems fine to remove, though I'm open-minded that we may want ID-based avoidance of inlining certain dimensions; we'd probably want a better interface for that.

+  inlineMost(ir_utils::allTvs(FusionGuard::getCurFusion()));
 }
 
-void inlineMost(
-    const std::vector<TensorView*>& tvs,
-    const std::unordered_set<IterDomain*>& uninlinable_ids) {
+void inlineMost(const std::vector<TensorView*>& tvs) {
   if (tvs.empty()) {
     return;
   }
-  MaxPosCalculator calc(uninlinable_ids);
+  MaxPosCalculator calc;
   for (auto tv : tvs) {
     tv->inlineAt(-1, true, &calc);
   }
 }
 
-void inlineMost(
-    const std::unordered_set<TensorView*>& tvs,
-    const std::unordered_set<IterDomain*>& uninlinable_ids) {
+void inlineMost(const std::unordered_set<TensorView*>& tvs) {
   if (tvs.empty()) {
     return;
   }
-  MaxPosCalculator calc(uninlinable_ids);
+  MaxPosCalculator calc;
   for (auto tv : tvs) {
     tv->inlineAt(-1, true, &calc);
   }
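For reference, a minimal usage sketch of the simplified API (illustrative caller code, not part of this PR; `tv0`/`tv1` are hypothetical tensors in an already-scheduled fusion):

```cpp
// Illustrative caller of the cleaned-up inlining entry points. Assumes the
// fusion is guarded by a FusionGuard; tv0 and tv1 are hypothetical.
#include <torch/csrc/jit/codegen/cuda/inlining.h>

#include <vector>

using namespace torch::jit::fuser::cuda;

void inlineExample(TensorView* tv0, TensorView* tv1) {
  // Inline every tensor in the current fusion as far right as allowed;
  // no uninlinable_ids set is threaded through anymore.
  inlineMost();

  // Or inline only a selected subset of tensors.
  std::vector<TensorView*> tvs{tv0, tv1};
  inlineMost(tvs);

  // Or inline everything at the positions mapped from a reference tensor.
  inlineAllAt(tv0, /*reference_pos=*/2, /*best_effort=*/true);
}
```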
@@ -276,10 +272,9 @@ std::unordered_map<TensorView*, size_t> getPositionsMappedTo(
 void inlineAllAt(
     TensorView* reference_tv,
     int64_t reference_pos,
-    bool best_effort,
-    const std::unordered_set<IterDomain*>& uninlinable_ids) {
+    bool best_effort) {
   auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos);
-  MaxPosCalculator calc(uninlinable_ids);
+  MaxPosCalculator calc;
   for (auto pair : mapped_positions) {
     pair.first->inlineAt(pair.second, best_effort, &calc);
   }
@@ -289,10 +284,9 @@ void inlineSelectedAt(
     const std::unordered_set<TensorView*>& selected,
     TensorView* reference_tv,
     int64_t reference_pos,
-    bool best_effort,
-    const std::unordered_set<IterDomain*>& uninlinable_ids) {
+    bool best_effort) {
   auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos);
-  MaxPosCalculator calc(uninlinable_ids);
+  MaxPosCalculator calc;
   for (auto pair : mapped_positions) {
     if (selected.count(pair.first) > 0) {
       pair.first->inlineAt(pair.second, best_effort, &calc);
17 changes: 5 additions & 12 deletions torch/csrc/jit/codegen/cuda/inlining.h
@@ -64,35 +64,28 @@ class MaxPosCalculator {

 // Inline to the right most allowed position for all tensors in the current
 // fusion.
-TORCH_CUDA_CU_API void inlineMost(
-    const std::unordered_set<IterDomain*>& uninlinable_ids = {});
+TORCH_CUDA_CU_API void inlineMost();
 // Inline to the right most allowed position for the selected tensors in the
 // current fusion.
-TORCH_CUDA_CU_API void inlineMost(
-    const std::vector<TensorView*>& tvs,
-    const std::unordered_set<IterDomain*>& uninlinable_ids = {});
+TORCH_CUDA_CU_API void inlineMost(const std::vector<TensorView*>& tvs);
 // Inline to the right most allowed position for the selected tensors in the
 // current fusion.
-TORCH_CUDA_CU_API void inlineMost(
-    const std::unordered_set<TensorView*>& tvs,
-    const std::unordered_set<IterDomain*>& uninlinable_ids = {});
+TORCH_CUDA_CU_API void inlineMost(const std::unordered_set<TensorView*>& tvs);
 
 // Inline to the position corresponding to the reference position in the
 // reference tensor for all tensors in the current fusion.
 TORCH_CUDA_CU_API void inlineAllAt(
     TensorView* reference_tv,
     int64_t reference_pos,
-    bool best_effort = false,
-    const std::unordered_set<IterDomain*>& uninlinable_ids = {});
+    bool best_effort = false);
 
 // Inline to the position corresponding to the reference position in the
 // reference tensor for selected tensors in the current fusion.
 TORCH_CUDA_CU_API void inlineSelectedAt(
     const std::unordered_set<TensorView*>& selected,
     TensorView* reference_tv,
     int64_t reference_pos,
-    bool best_effort = false,
-    const std::unordered_set<IterDomain*>& uninlinable_ids = {});
+    bool best_effort = false);
 
 } // namespace cuda
 } // namespace fuser
48 changes: 18 additions & 30 deletions torch/csrc/jit/codegen/cuda/ir_nodes.cpp
@@ -1606,49 +1606,37 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
   TORCH_CHECK(
       !outer->extent()->isZeroInt() && !inner->extent()->isZeroInt(),
       "Merging IterDomains with ending values that are 0 is not supported at this time.");
-  TORCH_CHECK(
-      outer->isReduction() == inner->isReduction() ||
-          (!outer->isReduction() && inner->isTrivialReduction()) ||
-          (outer->isTrivialReduction() && !inner->isReduction()),
-      "Merging IterDomains requires that their iteration types match.");
   TORCH_CHECK(
       (outer->isGather() && inner->isGather()) ||
           (!outer->isGather() && !inner->isGather()),
       "Merging gather and non-gather domains is not supported.");
 
   TORCH_CHECK(
       !outer->isStride() && !inner->isStride(),
       "No support for merging stride domains");
 
   Val* merged_id_size = mul(outer->extent(), inner->extent());
 
   IterType itype = outer->getIterType();
 
-  if (outer->isBroadcast() && inner->isBroadcast()) {
-    itype = IterType::Broadcast;
+  if (inner->getIterType() == itype) {
+    goto itype_infer_finished;
   }
 
-  if ((outer->isBroadcast() || inner->isBroadcast()) &&
-      (outer->getIterType() == IterType::Iteration ||
-       inner->getIterType() == IterType::Iteration)) {
-    itype = IterType::Iteration;
+  if (inner->isTrivialReduction()) {
+    goto itype_infer_finished;
   }
 
-  // Merging trivial reduction with iter domain, that's fine, just make it an
-  // iter domain.
-  if ((outer->isTrivialReduction() || inner->isTrivialReduction()) &&
-      (outer->getIterType() == IterType::Iteration ||
-       inner->getIterType() == IterType::Iteration)) {
-    itype = IterType::Iteration;
+  if (outer->isTrivialReduction()) {
+    itype = inner->getIterType();
+    goto itype_infer_finished;
   }
 
-  // Merging trivial reduction with broadcasting, that's fine, just make it a
-  // broadcasting.
-  if ((outer->isTrivialReduction() || inner->isTrivialReduction()) &&
-      (outer->isBroadcast() || inner->isBroadcast())) {
-    itype = IterType::Broadcast;
+  if (inner->isBroadcast()) {
+    goto itype_infer_finished;
   }
 
+  if (outer->isBroadcast()) {
+    itype = inner->getIterType();
+    goto itype_infer_finished;
+  }
+
+  TORCH_CHECK(
+      false, "Merging IterDomains requires that their iteration types match.");
+
+itype_infer_finished:
   Val* expanded_extent = nullptr;
   if (outer->hasExpandedExtent() || inner->hasExpandedExtent()) {
     if (outer->hasExpandedExtent() && inner->hasExpandedExtent()) {
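The new type-inference control flow above is easier to read outside the IR classes. A standalone sketch of the same rules, assuming a "trivial reduction" is a reduction domain of extent 1 (`Id` and the enum are simplified stand-ins, not the real classes):

```cpp
#include <optional>

enum class IterType { Iteration, Reduction, Broadcast };

// Hypothetical stand-in for IterDomain, reduced to what the rules need.
struct Id {
  IterType itype;
  bool trivial_reduction; // a reduction whose extent is 1
};

// Returns the merged IterType, or std::nullopt where the real code fails with
// "Merging IterDomains requires that their iteration types match."
std::optional<IterType> inferMergedIterType(const Id& outer, const Id& inner) {
  if (inner.itype == outer.itype) {
    return outer.itype; // identical types need no reconciliation
  }
  if (inner.trivial_reduction) {
    return outer.itype; // a trivial reduction folds into the other domain
  }
  if (outer.trivial_reduction) {
    return inner.itype;
  }
  if (inner.itype == IterType::Broadcast) {
    return outer.itype; // a broadcast folds into the other domain
  }
  if (outer.itype == IterType::Broadcast) {
    return inner.itype;
  }
  return std::nullopt; // e.g. Iteration vs. non-trivial Reduction
}
```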
7 changes: 1 addition & 6 deletions torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp
@@ -330,13 +330,8 @@ void multiReductionInliner(
     }
   }
 
-  // Find iter domains that are mapped to a trivial reduction, these should
-  // never be inlined.
-  std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
-      scheduler_utils::getTrivialReductionMap(fusion);
-
   // Inline the schedule
-  inlineMost(mapped_to_trivial_reduction);
+  inlineMost();
 }
 
 namespace {
86 changes: 13 additions & 73 deletions torch/csrc/jit/codegen/cuda/scheduler/utils.cpp
@@ -21,26 +21,20 @@ namespace scheduler_utils {

 // Returns number of "valid" dimensions. e.g. if tv has
 // [I1, R2, I3, I4, R3{1}]
-// where R3{1} is in dont_merge, resulting domain should be:
-// [I1, I3*I4, R2, R3{1}] with return value 3
+// resulting domain should be:
+// [I1, I3*I4, R2*R3{1}] with return value 3
 //
 // if tv has
 // [R1, I2, R3, I4, R4, R5{1}, R6{1}]
-// where R5{1} and R6{1} are in dont_merge, resulting domain should be:
-// [I2*I4, R1*R3, R4, R5{1}, R6{1}]
+// resulting domain should be:
+// [I2*I4, R1*R3, R4*R5{1}*R6{1}]
 // with return value 3
-size_t merge_3d(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge) {
+size_t merge_3d(TensorView* tv) {
   bool active_is_reduction = false;
   bool first_dim = true;
   int prev_i = -1;
 
   for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
-    if (dont_merge.count(tv->axis(i))) {
-      continue;
-    }
-
     if (first_dim) {
       active_is_reduction = tv->axis(i)->isReduction();
       prev_i = i;
@@ -67,10 +61,6 @@

   for (int i = static_cast<int>(tv->nDims()) - 2; i >= 0; i--) {
     auto id = tv->axis(i);
-    if (dont_merge.count(id)) {
-      continue;
-    }
-
     if (first_dim) {
       active_is_reduction = id->isReduction();
       prev_i = i;
@@ -96,10 +86,6 @@
   prev_i = -1;
 
   for (int i = static_cast<int>(tv->nDims()) - 3; i >= 0; i--) {
-    if (dont_merge.count(tv->axis(i))) {
-      continue;
-    }
-
     if (first_dim) {
       active_is_reduction = tv->axis(i)->isReduction();
       prev_i = i;
@@ -114,7 +100,7 @@
   if (prev_i == -1) {
     // Two dimensional, put merged dimensions first
     tv->reorder({{-1, 0}, {-2, 1}});
-    // [outer, inner, dont_merge...]
+    // [outer, inner]
     if (tv->axis(0)->isReduction()) {
       // put reductions as second axis
       tv->reorder({{0, 1}, {1, 0}});
@@ -195,13 +181,11 @@ c10::optional<size_t> mergeDims(
   return left;
 }
 
-size_t mergeReduction(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge) {
+size_t mergeReduction(TensorView* tv) {
   int prev_i = -1;
   size_t num_merged = 0;
   for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
-    if (!tv->axis(i)->isReduction() || dont_merge.count(tv->axis(i))) {
+    if (!tv->axis(i)->isReduction()) {
       continue;
     }
     if (prev_i == -1) {
@@ -219,16 +203,14 @@
   return prev_i == -1 ? 0 : num_merged + 1;
 }
 
-size_t mergeNonReduction(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge) {
+size_t mergeNonReduction(TensorView* tv) {
   int prev_i = -1;
   size_t num_merged = 0;
   if (tv->nDims() == 0) {
     return 0;
   }
   for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
-    if (tv->axis(i)->isReduction() || dont_merge.count(tv->axis(i))) {
+    if (tv->axis(i)->isReduction()) {
       continue;
     }
     if (prev_i == -1) {
@@ -905,63 +887,21 @@ PersistentBufferSizeReturn persistentBufferSize(
   return persistent_buffer_size;
 }
 
-std::unordered_set<IterDomain*> getTrivialReductionMap(Fusion* fusion) {
-  auto all_tvs = ir_utils::allTvs(fusion);
-  std::unordered_set<IterDomain*> mapped_to_trivial_reduction;
-  for (auto tv : all_tvs) {
-    // root domain vs domain shouldn't matter as at this point we shouldn't
-    // have any transformations.
-    for (auto id : tv->getRootDomain()) {
-      if (id->isTrivialReduction()) {
-        mapped_to_trivial_reduction.emplace(id);
-      }
-    }
-  }
-
-  if (!mapped_to_trivial_reduction.empty()) {
-    // Use the loop map as that is the most permissive
-    auto ca_map = ComputeAtMap(fusion);
-    // Make a copy we need to check mappings of all
-    auto trivial_ids = mapped_to_trivial_reduction;
-    for (auto tv : all_tvs) {
-      for (auto id : tv->getRootDomain()) {
-        if (!id->extent()->isOneInt()) {
-          continue;
-        }
-        if (std::any_of(
-                trivial_ids.begin(),
-                trivial_ids.end(),
-                [&ca_map, &id](IterDomain* trivial_id) {
-                  return ca_map.areMapped(
-                      id, trivial_id, IdMappingMode::PERMISSIVE);
-                })) {
-          mapped_to_trivial_reduction.emplace(id);
-        }
-      }
-    }
-  }
-  return mapped_to_trivial_reduction;
-}
-
 std::pair<bool, bool> canonicalDimReduction(
     Fusion* fusion,
     TensorView* tv,
     bool schedule_3D) {
-  std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
-      getTrivialReductionMap(fusion);
-
   TORCH_INTERNAL_ASSERT(tv != nullptr);
 
   if (!schedule_3D) {
     // We coalesce all reduction axes to the right;
-    bool has_red_axis = mergeReduction(tv, mapped_to_trivial_reduction) > 0;
+    bool has_red_axis = mergeReduction(tv) > 0;
 
-    bool has_iter_axis = mergeNonReduction(tv, mapped_to_trivial_reduction) > 0;
+    bool has_iter_axis = mergeNonReduction(tv) > 0;
     return {has_iter_axis, has_red_axis};
   } else {
     TORCH_INTERNAL_ASSERT(
-        merge_3d(tv, mapped_to_trivial_reduction) == 3,
-        "Tried 3D merge, but result is not 3D.");
+        merge_3d(tv) == 3, "Tried 3D merge, but result is not 3D.");
     return {true, true};
   }
 }
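For context, a minimal sketch of calling the simplified canonicalization entry point (the wrapper function is illustrative, not from this PR):

```cpp
// Illustrative caller: trivial reductions no longer need to be collected into
// a dont_merge set before canonicalization. fusion and tv are assumed to come
// from an existing scheduling pass.
bool canonicalizeForReduction(Fusion* fusion, TensorView* tv) {
  auto has_axes = scheduler_utils::canonicalDimReduction(
      fusion, tv, /*schedule_3D=*/false);
  // has_axes.first: an iteration axis remains after merging;
  // has_axes.second: a reduction axis remains.
  return has_axes.first && has_axes.second;
}
```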
12 changes: 4 additions & 8 deletions torch/csrc/jit/codegen/cuda/scheduler/utils.h
@@ -78,16 +78,12 @@ TORCH_CUDA_CU_API inline c10::optional<size_t> mergeDims(
 }
 
 // Merge all reduction to the right side and returns total number of
-// reduction axes. Don't merge is typically used for trivial reductions.
-size_t mergeReduction(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge = {});
+// reduction axes.
+size_t mergeReduction(TensorView* tv);
 
 // merge all non-reduction axes to the left side and returns total number of
-// iteration axes. Don't merge is typically used for trivial reductions.
-size_t mergeNonReduction(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge = {});
+// iteration axes.
+size_t mergeNonReduction(TensorView* tv);
 
 // Propagate the parallelization from the selected dimensions of the reference
 // tensor to their corresponding dimensions in all selected tensors in the DAG.