Skip to content

Commit

Permalink
Cleanup trivial reduction workarounds (#2006)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored Oct 1, 2022
1 parent e4b6585 commit bca20c1
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 154 deletions.
26 changes: 10 additions & 16 deletions torch/csrc/jit/codegen/cuda/inlining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,29 +153,25 @@ size_t MaxPosCalculator::getMaxPosAll(
return max_pos;
}

void inlineMost(const std::unordered_set<IterDomain*>& uninlinable_ids) {
inlineMost(ir_utils::allTvs(FusionGuard::getCurFusion()), uninlinable_ids);
void inlineMost() {
inlineMost(ir_utils::allTvs(FusionGuard::getCurFusion()));
}

void inlineMost(
const std::vector<TensorView*>& tvs,
const std::unordered_set<IterDomain*>& uninlinable_ids) {
void inlineMost(const std::vector<TensorView*>& tvs) {
if (tvs.empty()) {
return;
}
MaxPosCalculator calc(uninlinable_ids);
MaxPosCalculator calc;
for (auto tv : tvs) {
tv->inlineAt(-1, true, &calc);
}
}

void inlineMost(
const std::unordered_set<TensorView*>& tvs,
const std::unordered_set<IterDomain*>& uninlinable_ids) {
void inlineMost(const std::unordered_set<TensorView*>& tvs) {
if (tvs.empty()) {
return;
}
MaxPosCalculator calc(uninlinable_ids);
MaxPosCalculator calc;
for (auto tv : tvs) {
tv->inlineAt(-1, true, &calc);
}
Expand Down Expand Up @@ -276,10 +272,9 @@ std::unordered_map<TensorView*, size_t> getPositionsMappedTo(
void inlineAllAt(
TensorView* reference_tv,
int64_t reference_pos,
bool best_effort,
const std::unordered_set<IterDomain*>& uninlinable_ids) {
bool best_effort) {
auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos);
MaxPosCalculator calc(uninlinable_ids);
MaxPosCalculator calc;
for (auto pair : mapped_positions) {
pair.first->inlineAt(pair.second, best_effort, &calc);
}
Expand All @@ -289,10 +284,9 @@ void inlineSelectedAt(
const std::unordered_set<TensorView*>& selected,
TensorView* reference_tv,
int64_t reference_pos,
bool best_effort,
const std::unordered_set<IterDomain*>& uninlinable_ids) {
bool best_effort) {
auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos);
MaxPosCalculator calc(uninlinable_ids);
MaxPosCalculator calc;
for (auto pair : mapped_positions) {
if (selected.count(pair.first) > 0) {
pair.first->inlineAt(pair.second, best_effort, &calc);
Expand Down
17 changes: 5 additions & 12 deletions torch/csrc/jit/codegen/cuda/inlining.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,35 +64,28 @@ class MaxPosCalculator {

// Inline to the right most allowed position for all tensors in the current
// fusion.
TORCH_CUDA_CU_API void inlineMost(
const std::unordered_set<IterDomain*>& uninlinable_ids = {});
TORCH_CUDA_CU_API void inlineMost();
// Inline to the right most allowed position for the selected tensors in the
// current fusion.
TORCH_CUDA_CU_API void inlineMost(
const std::vector<TensorView*>& tvs,
const std::unordered_set<IterDomain*>& uninlinable_ids = {});
TORCH_CUDA_CU_API void inlineMost(const std::vector<TensorView*>& tvs);
// Inline to the right most allowed position for the selected tensors in the
// current fusion.
TORCH_CUDA_CU_API void inlineMost(
const std::unordered_set<TensorView*>& tvs,
const std::unordered_set<IterDomain*>& uninlinable_ids = {});
TORCH_CUDA_CU_API void inlineMost(const std::unordered_set<TensorView*>& tvs);

// Inline to the position corresponding to the reference position in the
// reference tensor for all tensors in the current fusion.
TORCH_CUDA_CU_API void inlineAllAt(
TensorView* reference_tv,
int64_t reference_pos,
bool best_effort = false,
const std::unordered_set<IterDomain*>& uninlinable_ids = {});
bool best_effort = false);

// Inline to the position corresponding to the reference position in the
// reference tensor for selected tensors in the current fusion.
TORCH_CUDA_CU_API void inlineSelectedAt(
const std::unordered_set<TensorView*>& selected,
TensorView* reference_tv,
int64_t reference_pos,
bool best_effort = false,
const std::unordered_set<IterDomain*>& uninlinable_ids = {});
bool best_effort = false);

} // namespace cuda
} // namespace fuser
Expand Down
77 changes: 38 additions & 39 deletions torch/csrc/jit/codegen/cuda/ir_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1597,6 +1597,43 @@ std::vector<IterDomain*> IterDomain::clone(
return cloned_domains;
}

// Determine the IterType of the IterDomain produced by merging i1 and i2.
//
// The inference is a pattern match against the rules below, tried in order:
//
//   X + X                 = X
//   trivial reduction + X = X
//   X + trivial reduction = X
//   broadcasting + X      = X
//   X + broadcasting      = X
//   otherwise             -> error
//
// For each rule we test whether (i1, i2) matches the pattern; the first match
// wins and we stop. Reaching the end without a match is a user/scheduler
// mistake and is reported via TORCH_CHECK.
//
// Note that, as a consequence of the ordering above:
//   broadcasting + (non-trivial) reduction = reduction
//   broadcasting + trivial reduction       = broadcasting
IterType inferIterType(IterDomain* i1, IterDomain* i2) {
  const auto type1 = i1->getIterType();
  const auto type2 = i2->getIterType();
  // Rule: X + X = X
  if (type1 == type2) {
    return type1;
  }
  // Rules: a trivial reduction is absorbed by the other side.
  if (i1->isTrivialReduction()) {
    return type2;
  }
  if (i2->isTrivialReduction()) {
    return type1;
  }
  // Rules: a broadcast is absorbed by the other side.
  if (i1->isBroadcast()) {
    return type2;
  }
  if (i2->isBroadcast()) {
    return type1;
  }
  TORCH_CHECK(
      false, "Merging IterDomains requires that their iteration types match.");
}

// Merging does not propagate the start and stop values of the input
// domains to the merged output domain. The actual range of the
// domains is enforced by predicates. Note that since only root
Expand All @@ -1606,48 +1643,10 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
TORCH_CHECK(
!outer->extent()->isZeroInt() && !inner->extent()->isZeroInt(),
"Merging IterDomains with ending values that are 0 is not supported at this time.");
TORCH_CHECK(
outer->isReduction() == inner->isReduction() ||
(!outer->isReduction() && inner->isTrivialReduction()) ||
(outer->isTrivialReduction() && !inner->isReduction()),
"Merging IterDomains requires that their iteration types match.");
TORCH_CHECK(
(outer->isGather() && inner->isGather()) ||
(!outer->isGather() && !inner->isGather()),
"Merging gather and non-gather domains is not supported.");

TORCH_CHECK(
!outer->isStride() && !inner->isStride(),
"No support for merging stride domains");

Val* merged_id_size = mul(outer->extent(), inner->extent());

IterType itype = outer->getIterType();

if (outer->isBroadcast() && inner->isBroadcast()) {
itype = IterType::Broadcast;
}

if ((outer->isBroadcast() || inner->isBroadcast()) &&
(outer->getIterType() == IterType::Iteration ||
inner->getIterType() == IterType::Iteration)) {
itype = IterType::Iteration;
}

// Merging trivial reduction with iter domain, that's fine, just make it an
// iter domain.
if ((outer->isTrivialReduction() || inner->isTrivialReduction()) &&
(outer->getIterType() == IterType::Iteration ||
inner->getIterType() == IterType::Iteration)) {
itype = IterType::Iteration;
}

// Merging trivial reduction with broadcasting, that's fine, just make it a
// broadcasting.
if ((outer->isTrivialReduction() || inner->isTrivialReduction()) &&
(outer->isBroadcast() || inner->isBroadcast())) {
itype = IterType::Broadcast;
}
IterType itype = inferIterType(outer, inner);

Val* expanded_extent = nullptr;
if (outer->hasExpandedExtent() || inner->hasExpandedExtent()) {
Expand Down
7 changes: 1 addition & 6 deletions torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,13 +330,8 @@ void multiReductionInliner(
}
}

// Find iter domains that are mapped to a trivial reduction, these should
// never be inlined.
std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
scheduler_utils::getTrivialReductionMap(fusion);

// Inline the schedule
inlineMost(mapped_to_trivial_reduction);
inlineMost();
}

namespace {
Expand Down
86 changes: 13 additions & 73 deletions torch/csrc/jit/codegen/cuda/scheduler/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,20 @@ namespace scheduler_utils {

// Returns number of "valid" dimensions. e.g. if tv has
// [I1, R2, I3, I4, R3{1}]
// where R3{1} is in dont_merge, resulting domain should be:
// [I1, I3*I4, R2, R3{1}] with return value 3
// resulting domain should be:
// [I1, I3*I4, R2*R3{1}] with return value 3
//
// if tv has
// [R1, I2, R3, I4, R4, R5{1}, R6{1}]
// where R5{1} and R6{1} are in dont_merge, resulting domain should be:
// [I2*I4, R1*R3, R4, R5{1}, R6{1}]
// resulting domain should be:
// [I2*I4, R1*R3, R4*R5{1}*R6{1}]
// with return value 3
size_t merge_3d(
TensorView* tv,
const std::unordered_set<IterDomain*>& dont_merge) {
size_t merge_3d(TensorView* tv) {
bool active_is_reduction = false;
bool first_dim = true;
int prev_i = -1;

for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
if (dont_merge.count(tv->axis(i))) {
continue;
}

if (first_dim) {
active_is_reduction = tv->axis(i)->isReduction();
prev_i = i;
Expand All @@ -67,10 +61,6 @@ size_t merge_3d(

for (int i = static_cast<int>(tv->nDims()) - 2; i >= 0; i--) {
auto id = tv->axis(i);
if (dont_merge.count(id)) {
continue;
}

if (first_dim) {
active_is_reduction = id->isReduction();
prev_i = i;
Expand All @@ -96,10 +86,6 @@ size_t merge_3d(
prev_i = -1;

for (int i = static_cast<int>(tv->nDims()) - 3; i >= 0; i--) {
if (dont_merge.count(tv->axis(i))) {
continue;
}

if (first_dim) {
active_is_reduction = tv->axis(i)->isReduction();
prev_i = i;
Expand All @@ -114,7 +100,7 @@ size_t merge_3d(
if (prev_i == -1) {
// Two dimensional, put merged dimensions first
tv->reorder({{-1, 0}, {-2, 1}});
// [outer, inner, dont_merge...]
// [outer, inner]
if (tv->axis(0)->isReduction()) {
// put reductions as second axis
tv->reorder({{0, 1}, {1, 0}});
Expand Down Expand Up @@ -195,13 +181,11 @@ c10::optional<size_t> mergeDims(
return left;
}

size_t mergeReduction(
TensorView* tv,
const std::unordered_set<IterDomain*>& dont_merge) {
size_t mergeReduction(TensorView* tv) {
int prev_i = -1;
size_t num_merged = 0;
for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
if (!tv->axis(i)->isReduction() || dont_merge.count(tv->axis(i))) {
if (!tv->axis(i)->isReduction()) {
continue;
}
if (prev_i == -1) {
Expand All @@ -219,16 +203,14 @@ size_t mergeReduction(
return prev_i == -1 ? 0 : num_merged + 1;
}

size_t mergeNonReduction(
TensorView* tv,
const std::unordered_set<IterDomain*>& dont_merge) {
size_t mergeNonReduction(TensorView* tv) {
int prev_i = -1;
size_t num_merged = 0;
if (tv->nDims() == 0) {
return 0;
}
for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
if (tv->axis(i)->isReduction() || dont_merge.count(tv->axis(i))) {
if (tv->axis(i)->isReduction()) {
continue;
}
if (prev_i == -1) {
Expand Down Expand Up @@ -905,63 +887,21 @@ PersistentBufferSizeReturn persistentBufferSize(
return persistent_buffer_size;
}

// Collects every root IterDomain in the fusion that is a trivial reduction,
// then augments that set with any extent-1 root domain that the permissive
// ComputeAt map relates to one of them. Callers use the returned set to keep
// such ids out of inlining/merging decisions.
//
// NOTE(review): relies on being called before any scheduling transformations
// are applied — see the root-domain comment below.
std::unordered_set<IterDomain*> getTrivialReductionMap(Fusion* fusion) {
  auto all_tvs = ir_utils::allTvs(fusion);
  std::unordered_set<IterDomain*> mapped_to_trivial_reduction;
  // Pass 1: directly collect trivially-reduced root domains.
  for (auto tv : all_tvs) {
    // root domain vs domain shouldn't matter as at this point we shouldn't have
    // any transformations.
    for (auto id : tv->getRootDomain()) {
      if (id->isTrivialReduction()) {
        mapped_to_trivial_reduction.emplace(id);
      }
    }
  }

  // Pass 2: pull in extent-1 ids that map to any id found in pass 1.
  if (!mapped_to_trivial_reduction.empty()) {
    // Use the loop map as that is the most permissive
    auto ca_map = ComputeAtMap(fusion);
    // Snapshot the pass-1 set: we insert into mapped_to_trivial_reduction
    // inside the loop below, so iterate over a stable copy instead.
    auto trivial_ids = mapped_to_trivial_reduction;
    for (auto tv : all_tvs) {
      for (auto id : tv->getRootDomain()) {
        // Only extent-1 domains can be equivalent to a trivial reduction.
        if (!id->extent()->isOneInt()) {
          continue;
        }
        if (std::any_of(
                trivial_ids.begin(),
                trivial_ids.end(),
                [&ca_map, &id](IterDomain* trivial_id) {
                  return ca_map.areMapped(
                      id, trivial_id, IdMappingMode::PERMISSIVE);
                })) {
          mapped_to_trivial_reduction.emplace(id);
        }
      }
    }
  }
  return mapped_to_trivial_reduction;
}

std::pair<bool, bool> canonicalDimReduction(
Fusion* fusion,
TensorView* tv,
bool schedule_3D) {
std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
getTrivialReductionMap(fusion);

TORCH_INTERNAL_ASSERT(tv != nullptr);

if (!schedule_3D) {
// We coalesce all reduction axes to the right;
bool has_red_axis = mergeReduction(tv, mapped_to_trivial_reduction) > 0;
bool has_red_axis = mergeReduction(tv) > 0;

bool has_iter_axis = mergeNonReduction(tv, mapped_to_trivial_reduction) > 0;
bool has_iter_axis = mergeNonReduction(tv) > 0;
return {has_iter_axis, has_red_axis};
} else {
TORCH_INTERNAL_ASSERT(
merge_3d(tv, mapped_to_trivial_reduction) == 3,
"Tried 3D merge, but result is not 3D.");
merge_3d(tv) == 3, "Tried 3D merge, but result is not 3D.");
return {true, true};
}
}
Expand Down
12 changes: 4 additions & 8 deletions torch/csrc/jit/codegen/cuda/scheduler/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,12 @@ TORCH_CUDA_CU_API inline c10::optional<size_t> mergeDims(
}

// Merge all reduction to the right side and returns total number of
// reduction axes. Don't merge is typically used for trivial reductions.
size_t mergeReduction(
TensorView* tv,
const std::unordered_set<IterDomain*>& dont_merge = {});
// reduction axes.
size_t mergeReduction(TensorView* tv);

// merge all non-reduction axes to the left side and returns total number of
// iteration axes. Don't merge is typically used for trivial reductions.
size_t mergeNonReduction(
TensorView* tv,
const std::unordered_set<IterDomain*>& dont_merge = {});
// iteration axes.
size_t mergeNonReduction(TensorView* tv);

// Propagate the parallelization from the selected dimensions of the reference
// tensor to their corresponding dimensions in all selected tensors in the DAG.
Expand Down
Loading

0 comments on commit bca20c1

Please sign in to comment.