Skip to content

Commit

Permalink
improve broadcast resolution (#1792)
Browse files Browse the repository at this point in the history
  • Loading branch information
shmsong committed Jul 13, 2022
1 parent bee6c69 commit 03180aa
Showing 1 changed file with 70 additions and 50 deletions.
120 changes: 70 additions & 50 deletions torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,43 +21,79 @@ namespace {

//! Checks that the current loop nest is not realizing a serial
//! broadcast so that each index of producer buffer will only
//! be visited once.
//! TODO: should refactor this utility now to use loop maps in a
//! follow up.
//! be visited once, which is the only case where aggressive
//! inner sharing is valid.
//!
bool isSerialBroadcastResolution(TensorView* producer, TensorView* consumer) {
auto producer_root =
TensorDomain::noReductions(producer->getMaybeRFactorDomain());
auto consumer_root =
TensorDomain::noReductions(consumer->getMaybeRFactorDomain());

if (producer_root.size() != consumer_root.size()) {
// This case would be a single broadcast or a single reduce
// which wouldn't be a broadcast resolution
return false;
//! Note: see issue #1785:
//! serial broadcast resolution doesn't only happen to
//! immediate producers of broadcast ops. For example, we
//! can also have:
//! T1[I,B] = broadcast(T0[I])
//! T3[I,I] = T1[I,B] + T2[I,I]
//! T4[I,I] = T3[I,I]
//! and generates the following loop:
//! alloc T0[4]
//! For i in 0..3
//! T0[...] =
//!
//! For j in 0...X:
//! alloc T3[4]
//! for k in 0..3:
//! alloc T1[1]
//! T1[0] = T0[k] // <- This is actually a broadcast resolution
//! T3[k] = T1[0] + T2[...]
//! T4[...] = T3[...]
//!
//! In this case we are actually visiting each pixel of T0 in each iteration
//! of the j loop while T1 was the broadcasted tensor causing this reuse.
//!
//! The current version of checking covers this scenario by checking the root
//! ids of the consumer concrete loop id's. Any time a local tensor like T0
//! appears in a re-use scenario like above, we should see a serial loop id
//! that was derived from some root id that doesn't concretely map to T0's
//! domain.

// Serial concrete loop id's that cover consumer's iter domain.
std::vector<Val*> consumer_serial_loop_concrete_ids;

for (auto consumer_leaf_id : consumer->domain()->domain()) {
auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
consumer_leaf_id, IdMappingMode::LOOP);

// Check for any serial loop id with non-trivial extent
if (!concrete_loop_id->isThread() &&
!concrete_loop_id->extent()->isOneInt()) {
consumer_serial_loop_concrete_ids.push_back(concrete_loop_id);
}
}

std::vector<Val*> serial_ids;
std::copy_if(
producer->domain()->domain().begin(),
producer->domain()->domain().end(),
std::back_inserter(serial_ids),
[](IterDomain* id) { return !id->isThread(); });

auto serial_producer_roots =
InputsOf::outputs(FusionGuard::getCurFusion(), serial_ids);
auto serial_root_id =
ir_utils::filterByType<IterDomain>(serial_producer_roots);
std::unordered_set<IterDomain*> serial_producer_root_set(
serial_root_id.begin(), serial_root_id.end());

for (const auto idx : c10::irange(producer_root.size())) {
if (producer_root[idx]->isBroadcast() &&
!consumer_root[idx]->isBroadcast()) {
// Check if this broadcast contributed to any serial
// scheduled iterdomains:
if (serial_producer_root_set.count(producer_root[idx])) {
return true;
}
// Collect the root id's that the serial loop iterdomains
// are transformed from.
auto serial_loop_roots = InputsOf::outputs(
FusionGuard::getCurFusion(), consumer_serial_loop_concrete_ids);

// Collect exact concrete id's in producer's root domain
std::unordered_set<IterDomain*> producer_exact_concrete_root_ids;
auto producer_root =
TensorDomain::noReductions(producer->getMaybeRFactorDomain());
std::transform(
producer_root.begin(),
producer_root.end(),
std::inserter(
producer_exact_concrete_root_ids,
producer_exact_concrete_root_ids.begin()),
ir_utils::caMapExactConcreteId);

// Check if the serial loop roots index any exact root id's that
// are not within the set of producer's root exact id's. These
// id's will imply that the same producer pixel is accessed
// in multiple iterations of the materialized serial loop.
for (auto serial_loop_root :
ir_utils::filterByType<IterDomain>(serial_loop_roots)) {
if (!producer_exact_concrete_root_ids.count(
ir_utils::caMapExactConcreteId(serial_loop_root))) {
return true;
}
}

Expand Down Expand Up @@ -998,7 +1034,6 @@ class AllocateReuseModifier {
struct InPlaceSharingInfo {
bool has_broadcast_between = false;
bool has_unsupported_op = false;
bool has_serial_broadcast_resolution_between = false;
};

//! Careful heavy check on inner sharing candidates,
Expand Down Expand Up @@ -1044,13 +1079,6 @@ class AllocateReuseModifier {
return false;
}

// TODO: blanket disable reuse across broadcast concretization
// to unblock issue for now.
// Should improve the precision of this analysis in a follow up.
if (topo_info.has_serial_broadcast_resolution_between) {
return false;
}

// Get information on the allocated domains of the
// two buffers
auto& local_alloc_map = GpuLower::current()->localAllocationInfoMap();
Expand Down Expand Up @@ -1103,14 +1131,6 @@ class AllocateReuseModifier {
info.has_unsupported_op = true;
}
}

for (auto in_tv :
ir_utils::filterByType<TensorView>(tv_def->inputs())) {
if (all_used_val_set.count(in_tv) &&
isSerialBroadcastResolution(in_tv, tv)) {
info.has_serial_broadcast_resolution_between = true;
}
}
}
}
return info;
Expand Down

0 comments on commit 03180aa

Please sign in to comment.