Improve buffer reuse pass #1792

Merged · 2 commits · Jul 13, 2022

Changes from all commits
120 changes: 70 additions & 50 deletions torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp
@@ -21,43 +21,79 @@ namespace {

//! Checks that the current loop nest is not realizing a serial
//! broadcast so that each index of producer buffer will only
-//! be visited once.
-//! TODO: should refactor this utility now to use loop maps in a
-//! follow up.
+//! be visited once, which is the only case where aggressive
+//! inner sharing is valid.
+//!
bool isSerialBroadcastResolution(TensorView* producer, TensorView* consumer) {
-  auto producer_root =
-      TensorDomain::noReductions(producer->getMaybeRFactorDomain());
-  auto consumer_root =
-      TensorDomain::noReductions(consumer->getMaybeRFactorDomain());
-
-  if (producer_root.size() != consumer_root.size()) {
-    // This case would be a single broadcast or a single reduce
-    // which wouldn't be a broadcast resolution
-    return false;
+  //! Note: see issue #1785:
+  //! serial broadcast resolution doesn't only happen at the
+  //! immediate producer of a broadcast op. For example:
+  //!   T1[I,B] = broadcast(T0[I])
+  //!   T3[I,I] = T1[I,B] + T2[I,I]
+  //!   T4[I,I] = T3[I,I]
+  //! generates the following loops:
+  //!   alloc T0[4]
+  //!   for i in 0..3:
+  //!     T0[...] =
+  //!
+  //!   for j in 0..X:
+  //!     alloc T3[4]
+  //!     for k in 0..3:
+  //!       alloc T1[1]
+  //!       T1[0] = T0[k] // <- This is actually a broadcast resolution
+  //!       T3[k] = T1[0] + T2[...]
+  //!     T4[...] = T3[...]
+  //!
+  //! In this case we are actually visiting each pixel of T0 in each
+  //! iteration of the j loop, while T1 is the broadcast tensor causing
+  //! this reuse.
+  //!
+  //! The current check covers this scenario by examining the root ids of
+  //! the consumer's concrete loop ids. Any time a local tensor like T0
+  //! appears in a reuse scenario like the above, we should see a serial
+  //! loop id derived from some root id that doesn't concretely map to
+  //! T0's domain.
+
+  // Serial concrete loop id's that cover consumer's iter domain.
+  std::vector<Val*> consumer_serial_loop_concrete_ids;
+
+  for (auto consumer_leaf_id : consumer->domain()->domain()) {
+    auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
+        consumer_leaf_id, IdMappingMode::LOOP);
+
+    // Check for any serial loop id with non-trivial extent.
+    if (!concrete_loop_id->isThread() &&
+        !concrete_loop_id->extent()->isOneInt()) {
+      consumer_serial_loop_concrete_ids.push_back(concrete_loop_id);
+    }
+  }

-  std::vector<Val*> serial_ids;
-  std::copy_if(
-      producer->domain()->domain().begin(),
-      producer->domain()->domain().end(),
-      std::back_inserter(serial_ids),
-      [](IterDomain* id) { return !id->isThread(); });
-
-  auto serial_producer_roots =
-      InputsOf::outputs(FusionGuard::getCurFusion(), serial_ids);
-  auto serial_root_id =
-      ir_utils::filterByType<IterDomain>(serial_producer_roots);
-  std::unordered_set<IterDomain*> serial_producer_root_set(
-      serial_root_id.begin(), serial_root_id.end());
-
-  for (const auto idx : c10::irange(producer_root.size())) {
-    if (producer_root[idx]->isBroadcast() &&
-        !consumer_root[idx]->isBroadcast()) {
-      // Check if this broadcast contributed to any serial
-      // scheduled iterdomains:
-      if (serial_producer_root_set.count(producer_root[idx])) {
-        return true;
-      }
+  // Collect the root id's that the serial loop iterdomains
+  // are transformed from.
+  auto serial_loop_roots = InputsOf::outputs(
+      FusionGuard::getCurFusion(), consumer_serial_loop_concrete_ids);
+
+  // Collect exact concrete id's in producer's root domain.
+  std::unordered_set<IterDomain*> producer_exact_concrete_root_ids;
+  auto producer_root =
+      TensorDomain::noReductions(producer->getMaybeRFactorDomain());
+  std::transform(
+      producer_root.begin(),
+      producer_root.end(),
+      std::inserter(
+          producer_exact_concrete_root_ids,
+          producer_exact_concrete_root_ids.begin()),
+      ir_utils::caMapExactConcreteId);
+
+  // Check whether the serial loop roots index any exact root id that is
+  // not within the set of the producer's exact root ids. Such ids imply
+  // that the same producer pixel is accessed in multiple iterations of
+  // the materialized serial loop.
+  for (auto serial_loop_root :
+       ir_utils::filterByType<IterDomain>(serial_loop_roots)) {
+    if (!producer_exact_concrete_root_ids.count(
+            ir_utils::caMapExactConcreteId(serial_loop_root))) {
+      return true;
+    }
+  }
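For readers unfamiliar with the scenario in the note above: the T0..T4 fusion from issue #1785 can be written out with the nvfuser C++ test API. This is a sketch, not code from this PR; makeSymbolicTensor is the helper used in the nvfuser C++ tests, and the broadcast-flag layout is an assumption here.

// Sketch of the fusion discussed in the comment above: T1 broadcasts T0,
// and T3 resolves that broadcast against T2.
Fusion fusion;
FusionGuard fg(&fusion);

auto t0 = makeSymbolicTensor(1); // T0[I]
auto t2 = makeSymbolicTensor(2); // T2[I,I]
fusion.addInput(t0);
fusion.addInput(t2);

auto t1 = broadcast(t0, {false, true}); // T1[I,B]
auto t3 = add(t1, t2);                  // T3[I,I] = T1[I,B] + T2[I,I]
auto t4 = set(t3);                      // T4[I,I] = T3[I,I]
fusion.addOutput(t4);

// With T0 and T1 in local memory and the outer loop serial, T0 is re-read
// across iterations even though no broadcast sits directly between T0 and
// its immediate consumer, which is the scenario the new check catches.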

@@ -998,7 +1034,6 @@ class AllocateReuseModifier {
struct InPlaceSharingInfo {
  bool has_broadcast_between = false;
  bool has_unsupported_op = false;
-  bool has_serial_broadcast_resolution_between = false;
};

//! Careful heavy check on inner sharing candidates,
@@ -1037,13 +1072,6 @@ class AllocateReuseModifier {
    return false;
  }

-  // TODO: blanket disable reuse across broadcast concretization
-  // to unblock issue for now.
-  // Should improve the precision of this analysis in a follow up.
-  if (topo_info.has_serial_broadcast_resolution_between) {
-    return false;
-  }

  // Get information on the allocated domains of the
  // two buffers
  auto& local_alloc_map = GpuLower::current()->localAllocationInfoMap();
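With the blanket disable removed, reuse validity for a candidate pair now rests on the per-pair isSerialBroadcastResolution check above. A hypothetical call-site pattern, mirroring the polarity of the removed has_serial_broadcast_resolution_between guard (the exact placement in this pass is not shown in this hunk):

// Hypothetical guard (placement assumed): reject inner sharing for a
// candidate pair when the producer would be re-read across a serial loop.
if (isSerialBroadcastResolution(producer_tv, consumer_tv)) {
  return false; // the same producer pixel is read in multiple iterations
}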
@@ -1096,14 +1124,6 @@
        info.has_unsupported_op = true;
      }
    }

-    for (auto in_tv :
-         ir_utils::filterByType<TensorView>(tv_def->inputs())) {
-      if (all_used_val_set.count(in_tv) &&
-          isSerialBroadcastResolution(in_tv, tv)) {
-        info.has_serial_broadcast_resolution_between = true;
-      }
-    }
    }
  }
  return info;
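Distilled, the new predicate is a set-membership test over exact concrete root ids. A condensed restatement, with an invented helper name, assuming the two id collections are built as in isSerialBroadcastResolution above:

// Condensed sketch of the core check introduced by this PR: the producer
// buffer is revisited serially iff some root id feeding a serial consumer
// loop is not covered by the producer's exact concrete root set.
bool revisitsProducerAcrossSerialLoop(
    const std::unordered_set<IterDomain*>& producer_exact_concrete_root_ids,
    const std::vector<IterDomain*>& serial_loop_roots) {
  for (auto* serial_loop_root : serial_loop_roots) {
    if (!producer_exact_concrete_root_ids.count(
            ir_utils::caMapExactConcreteId(serial_loop_root))) {
      return true; // a serial loop dimension the producer does not index
    }
  }
  return false;
}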