Rework HaloInfo interface #1987

Merged 1 commit on Sep 16, 2022
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/compute_at_map.cpp
@@ -665,7 +665,7 @@ void ComputeAtMap::allocateIndexVariables() {
// Halo extended parallel loops currently are handled
// differently and an index variable would still
// be allocated in this case.
-        (GpuLower::current()->haloInfo().getExtent(id) == nullptr)) {
+        (GpuLower::current()->haloInfo()->getExtent(id) == nullptr)) {
ptype = id->getParallelType();
return true;
}
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/contiguity.cpp
@@ -51,7 +51,7 @@ ContigIDs::ContigIDs(
(ignore_halo_constraint ||
!GpuLower::current()
->haloInfo()
-             .getRootAxisInfo(root_domain_i)
+             ->getRootAxisInfo(root_domain_i)
.hasHalo())) {
contig_ids_.emplace(root_domain_i);
is_contig_root_[root_domain_i] = true;
23 changes: 12 additions & 11 deletions torch/csrc/jit/codegen/cuda/index_compute.cpp
@@ -51,8 +51,8 @@ int getProducerHaloOffset(
IterDomain* consumer_id = it->second;

const auto& halo_map = GpuLower::current()->haloInfo();
- const auto p_pad = halo_map.getRootAxisInfo(producer_id).width(0);
- const auto c_pad = halo_map.getRootAxisInfo(consumer_id).width(0);
+ const auto p_pad = halo_map->getRootAxisInfo(producer_id).width(0);
+ const auto c_pad = halo_map->getRootAxisInfo(consumer_id).width(0);

auto offset = p_pad - c_pad;

@@ -985,7 +985,7 @@ Val* getHaloExtentOfRootAxis(IterDomain* id, Val* normal_extent = nullptr) {
normal_extent = id->extent();
}

- const auto& halo = GpuLower::current()->haloInfo().getRootAxisInfo(id);
+ const auto& halo = GpuLower::current()->haloInfo()->getRootAxisInfo(id);
if (halo.hasHalo()) {
auto halo_extent = SimplifyingIrBuilder::addExpr(
normal_extent, SimplifyingIrBuilder::create<Int>(halo.width()));
@@ -2351,7 +2351,7 @@ std::vector<PredicateDomainInfo> getPredicateContigIds(
std::unordered_set<IterDomain*> excluded_ids;

for (auto consumer_root_id : consumer_root_domain) {
-   if (gpu_lower->haloInfo().getRootAxisInfo(consumer_root_id).hasHalo()) {
+   if (gpu_lower->haloInfo()->getRootAxisInfo(consumer_root_id).hasHalo()) {
excluded_ids.insert(consumer_root_id);
continue;
}
@@ -2487,7 +2487,7 @@ int getUnswitchStopOffset(
const auto gpu_lower = GpuLower::current();

AxisHaloInfo halo_info =
-     gpu_lower->haloInfo().getRootAxisInfo(consumer_root_id);
+     gpu_lower->haloInfo()->getRootAxisInfo(consumer_root_id);

// If the consumer root domain to predicate does not have halo, no
// adjustment is required.
@@ -2511,7 +2511,7 @@
unswitch_it,
consumer_tv->domain()->domain().end(),
[&gpu_lower, &consumer_root_id](auto leaf_id) {
-           return gpu_lower->haloInfo().isHaloInherited(
+           return gpu_lower->haloInfo()->isHaloInherited(
consumer_root_id, leaf_id);
})) {
return halo_info.width();
@@ -2669,7 +2669,8 @@ std::pair<Val*, Val*> getStartAndStopLimitOffsets(
Val* stop_limit = SimplifyingIrBuilder::negExpr(consumer_id->stopOffset());

if (!non_divisible_pred) {
-   AxisHaloInfo halo_info = gpu_lower->haloInfo().getRootAxisInfo(consumer_id);
+   AxisHaloInfo halo_info =
+       gpu_lower->haloInfo()->getRootAxisInfo(consumer_id);

// Below, "left" and "right" halo mean halo at offset zero and
// axis extent, respectively.
@@ -2693,8 +2694,8 @@
// that it is less than the extent of the predicated ID +
// halo. Note that getRootAxisInfo doesn't work since consumer_id
// isn't a root domain.
-     if (gpu_lower->haloInfo().hasHaloWidth(consumer_id)) {
-       auto halo = gpu_lower->haloInfo().getHaloWidth(consumer_id);
+     if (gpu_lower->haloInfo()->hasHaloWidth(consumer_id)) {
+       auto halo = gpu_lower->haloInfo()->getHaloWidth(consumer_id);
stop_limit = SimplifyingIrBuilder::addExpr(stop_limit, halo);
}
}
@@ -2841,8 +2842,8 @@ bool canOmitStopPredicate(
// to be predicated, not its merged contig id even if it exists. So,
// if contig_id does not have root axis info, contig_id is
// guaranteed to have no halo.
-   auto halo_ext = gpu_lower->haloInfo().hasRootAxisInfo(contig_id)
-       ? gpu_lower->haloInfo().getRootAxisInfo(contig_id).width()
+   auto halo_ext = gpu_lower->haloInfo()->hasRootAxisInfo(contig_id)
+       ? gpu_lower->haloInfo()->getRootAxisInfo(contig_id).width()
: 0;

if (halo_ext + stop_offset_val.value() > 0) {
4 changes: 2 additions & 2 deletions torch/csrc/jit/codegen/cuda/lower2device.cpp
@@ -248,7 +248,7 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) {
// mappings of all iteration domains across the fusion. There are three types
// of mappings Permissive, Exact, and Loop, see compute_at_map.h/cpp for more
// information.
-   compute_at_map_ = std::make_unique<ComputeAtMap>(fusion_);
+   compute_at_map_ = std::make_shared<ComputeAtMap>(fusion_);

if (isDebugDumpEnabled(DebugDumpOption::ComputeAtMap)) {
std::cout << compute_at_map_->toString() << std::endl;
@@ -281,7 +281,7 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) {

// Scan the whole fusion and build mappings about halo extensions of
// all IterDomains
-   haloInfo().build(fusion_);
+   halo_info_ = std::make_shared<HaloInfo>(fusion_, compute_at_map_);

// Want to run this after parallel map and halo info map are
// created. vectorized_accesses_ and vectorized_set_info_ are filled.
16 changes: 6 additions & 10 deletions torch/csrc/jit/codegen/cuda/lower2device.h
@@ -76,20 +76,16 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable {
return thread_pred_map_;
}

-   const std::unique_ptr<ComputeAtMap>& caMap() const {
-     return compute_at_map_;
+   std::shared_ptr<const ComputeAtMap> caMap() const {
+     return std::const_pointer_cast<const ComputeAtMap>(compute_at_map_);
}

const TrivialReductionInfo& trivialReductionInfo() const {
return trivial_reduction_info_;
}

-   const HaloInfo& haloInfo() const {
-     return halo_info_;
-   }
-
-   HaloInfo& haloInfo() {
-     return halo_info_;
+   std::shared_ptr<const HaloInfo> haloInfo() const {
+     return std::const_pointer_cast<const HaloInfo>(halo_info_);
}

const ParallelDimensionMap& parallelDimensionMap() const {
@@ -201,9 +197,9 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable {
ConcretizedBroadcastDomains concretized_broadcast_domains_;
ThreadPredicateMap thread_pred_map_;
PredicateElimination pred_elimination_;
-   std::unique_ptr<ComputeAtMap> compute_at_map_;
+   std::shared_ptr<ComputeAtMap> compute_at_map_;
TrivialReductionInfo trivial_reduction_info_;
-   HaloInfo halo_info_;
+   std::shared_ptr<HaloInfo> halo_info_;
LocalAllocationInfoMap local_allocation_info_map_;
WarpPaddedParallelInfo warp_pad_info_;
ParallelDimensionMap parallel_dimension_map_;
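The accessor change above follows a common pattern: the class keeps a mutable `std::shared_ptr` member internally but hands out a `shared_ptr`-to-const, so callers share ownership without being able to mutate the analysis. A minimal sketch of the pattern with hypothetical names (`Analysis`, `Lower` are illustrative, not the real API):

```cpp
#include <memory>

// Stand-in for an analysis object such as HaloInfo or ComputeAtMap.
struct Analysis {
  int width = 3;
  int getWidth() const { return width; }
};

class Lower {
 public:
  Lower() : analysis_(std::make_shared<Analysis>()) {}

  // Mirrors the new haloInfo()/caMap() signatures: the returned pointer
  // keeps the analysis alive even past the Lower object's lifetime, but
  // exposes only the const interface. The const_pointer_cast matches the
  // PR; the implicit shared_ptr<T> -> shared_ptr<const T> conversion
  // would also work here.
  std::shared_ptr<const Analysis> analysis() const {
    return std::const_pointer_cast<const Analysis>(analysis_);
  }

 private:
  std::shared_ptr<Analysis> analysis_;  // stays mutable internally
};
```

This is why every call site in the diff changes from `haloInfo().foo()` to `haloInfo()->foo()`: the accessor now returns a pointer rather than a reference.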
8 changes: 4 additions & 4 deletions torch/csrc/jit/codegen/cuda/lower_allocation.cpp
@@ -131,7 +131,7 @@ class AllocationInserter : public kir::ExprMutator {
++init_loop_it) {
auto id = *init_loop_it;
kir::ForLoop* new_loop = nullptr;
-       auto extent_with_halo = gpu_lower->haloInfo().getExtent(id);
+       auto extent_with_halo = gpu_lower->haloInfo()->getExtent(id);
if (extent_with_halo) {
new_loop = IrBuilder::create<kir::ForLoop>(
id,
@@ -166,7 +166,7 @@
}
auto extent = id->extent();
// Use halo-extended extent if found
-     auto halo_extent = gpu_lower->haloInfo().getRootAxisInfo(id);
+     auto halo_extent = gpu_lower->haloInfo()->getRootAxisInfo(id);
if (halo_extent.hasHalo()) {
extent = IrBuilder::addExpr(
extent, IrBuilder::create<Int>(halo_extent.width()));
@@ -213,7 +213,7 @@

// Get the halo extent if found
auto getExtent = [this](IterDomain* id) {
-     auto extent = gpu_lower->haloInfo().getExtent(id);
+     auto extent = gpu_lower->haloInfo()->getExtent(id);
if (extent == nullptr) {
extent = id->extent();
}
@@ -368,7 +368,7 @@

auto extent = concrete_id->extent();

-     if (gpu_lower->haloInfo().getExtent(info.buffer->axis(axis_i)) !=
+     if (gpu_lower->haloInfo()->getExtent(info.buffer->axis(axis_i)) !=
nullptr) {
has_halo = true;
}
9 changes: 6 additions & 3 deletions torch/csrc/jit/codegen/cuda/lower_index_compute.cpp
@@ -125,7 +125,8 @@ IndexingParameters getLinearIndexParameters(

// Derive the halo extents from the loop indexing result.
index_parameters.concrete_id_to_halo_extent =
-     GpuLower::current()->haloInfo().buildConcreteHaloExtentMap(loop_indexing);
+     GpuLower::current()->haloInfo()->buildConcreteHaloExtentMap(
+         loop_indexing);

protectNonPredicateIndexWithMagicZero(
loops,
@@ -233,7 +234,8 @@ IndexingParameters getNonGlobalInitialIndexParameters(

// Derive the halo extents from the loop indexing result.
index_parameters.concrete_id_to_halo_extent =
-     GpuLower::current()->haloInfo().buildConcreteHaloExtentMap(loop_indexing);
+     GpuLower::current()->haloInfo()->buildConcreteHaloExtentMap(
+         loop_indexing);

return index_parameters;
}
@@ -408,7 +410,8 @@ IndexingParameters getPredicateInitialIndexParameters(

// Derive the halo extents from the loop indexing result.
index_parameters.concrete_id_to_halo_extent =
-     GpuLower::current()->haloInfo().buildConcreteHaloExtentMap(loop_indexing);
+     GpuLower::current()->haloInfo()->buildConcreteHaloExtentMap(
+         loop_indexing);

return index_parameters;
}
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/lower_loops.cpp
@@ -33,7 +33,7 @@ LoopNestGenerator::LoopNestGenerator(const std::vector<Expr*>& exprs) {
namespace {

kir::ForLoop* openForHelper(kir::ForLoop* scope, IterDomain* id) {
- auto extent_with_halo = GpuLower::current()->haloInfo().getExtent(id);
+ auto extent_with_halo = GpuLower::current()->haloInfo()->getExtent(id);
kir::ForLoop* new_scope = nullptr;
if (extent_with_halo) {
// When an axis is extended with halo, unrolling and vectorization
6 changes: 3 additions & 3 deletions torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp
@@ -303,12 +303,12 @@ class PredicateChcker : public IterVisitor {

// Shift is not supported yet.
bool predicateShift(Expr* expr) const {
-     auto& halo_info = GpuLower::current()->haloInfo();
+     auto halo_info = GpuLower::current()->haloInfo();
auto input_tvs = ir_utils::filterByType<TensorView>(expr->inputs());
-     return halo_info.needsShiftPredicate(expr) ||
+     return halo_info->needsShiftPredicate(expr) ||
std::any_of(input_tvs.begin(), input_tvs.end(), [&](auto input_tv) {
return input_tv->definition() != nullptr &&
-               halo_info.needsShiftPredicate(input_tv->definition());
+               halo_info->needsShiftPredicate(input_tv->definition());
});
}

54 changes: 25 additions & 29 deletions torch/csrc/jit/codegen/cuda/lower_shift.cpp
@@ -28,7 +28,7 @@ void ShiftPredicateInserter::insert(
TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing TensorView output");

const bool needs_shift_predicate =
-     gpu_lower->haloInfo().needsShiftPredicate(out_tv->definition());
+     gpu_lower->haloInfo()->needsShiftPredicate(out_tv->definition());
if (!needs_shift_predicate) {
return;
}
@@ -145,13 +145,6 @@ const AxisHaloInfo& HaloInfo::getRootAxisInfo(IterDomain* id) const {
return it->second;
}

- AxisHaloInfo& HaloInfo::getRootAxisInfo(IterDomain* id) {
-   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-   return const_cast<AxisHaloInfo&>(
-       // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-       const_cast<const HaloInfo*>(this)->getRootAxisInfo(id));
- }

void HaloInfo::setRootAxisInfo(
IterDomain* id,
const AxisHaloInfo& root_axis_info) {
@@ -161,7 +154,9 @@ void HaloInfo::setRootAxisInfo(
return;
}

- void HaloInfo::build(Fusion* fusion) {
+ HaloInfo::HaloInfo(Fusion* fusion, std::shared_ptr<const ComputeAtMap> ca_map)
+     // Make a copy of the permissive map for extent comparators
+     : permissive_map_(ca_map->idGraph().permissiveNodes()) {
Collaborator:

Are you trying to decouple the states held by GpuLower from each other? I like it as that seems like a better and more robust design.

Owner Author:

Yes, I'm trying to pull the dependencies apart a bit as I was thinking of building a HaloInfo structure in contiguity which is also used outside GpuLower.
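The decoupling discussed in this thread amounts to constructor injection: rather than reaching into the `GpuLower` singleton, the analysis receives the mapping structure it depends on when it is built. A minimal sketch under that reading, with invented names (`IdMap`, `HaloAnalysis` are illustrative stand-ins, not the real classes):

```cpp
#include <memory>
#include <string>
#include <utility>

// Trivial stand-in for ComputeAtMap's mapped-ID query.
struct IdMap {
  bool areMapped(const std::string& a, const std::string& b) const {
    return a == b;  // placeholder logic for the sketch
  }
};

class HaloAnalysis {
 public:
  // The dependency is injected and captured with shared ownership, so the
  // analysis can be constructed and used outside any GpuLower instance.
  explicit HaloAnalysis(std::shared_ptr<const IdMap> map)
      : map_(std::move(map)) {}

  bool sameAxis(const std::string& a, const std::string& b) const {
    return map_->areMapped(a, b);
  }

 private:
  std::shared_ptr<const IdMap> map_;
};
```

This mirrors the new `HaloInfo(Fusion*, std::shared_ptr<const ComputeAtMap>)` constructor: the shared pointer keeps the map alive for as long as the analysis needs it, independent of the lowering object's lifetime.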

const auto vals = fusion->usedMathVals();
auto tvs = ir_utils::filterByType<TensorView>(vals);

@@ -202,7 +197,7 @@ void HaloInfo::build(Fusion* fusion) {

// Note that validation requires consumer halo info
for (auto tv : tvs) {
-     validate(tv);
+     validate(tv, ca_map);
}
}

@@ -474,12 +469,13 @@ void HaloInfo::build(TensorDomain* td) {
//! Other types of parallelization should be supported except for
//! vectorization. Vectorization should be eventually supported but
//! needs further work.
- void HaloInfo::validate(TensorView* tv) const {
+ void HaloInfo::validate(
+     TensorView* tv,
+     std::shared_ptr<const ComputeAtMap> ca_map) const {
const auto mem_type = tv->getMemoryType();

for (auto axis : tv->domain()->domain()) {
-     auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
-         axis, IdMappingMode::LOOP);
+     auto concrete_id = ca_map->getConcreteMappedID(axis, IdMappingMode::LOOP);

// The extent is assumed to be the same
TORCH_INTERNAL_ASSERT(
@@ -526,7 +522,7 @@ void HaloInfo::validate(
consumer->domain()->domain().begin(),
consumer->domain()->domain().end(),
[&](IterDomain* consumer_axis) {
-           return GpuLower::current()->caMap()->areMapped(
+           return ca_map->areMapped(
axis, consumer_axis, IdMappingMode::PERMISSIVE);
});
if (it == consumer->domain()->domain().end()) {
@@ -626,11 +622,10 @@ bool extentCompare(
const HaloInfo& halo_map,
IterDomain* id1,
IterDomain* id2,
-     Cmp cmp) {
-   auto gpu_lower = GpuLower::current();
+     Cmp cmp,
+     const DisjointSets<IterDomain*>& permissive_map) {
  TORCH_INTERNAL_ASSERT(
-       gpu_lower->caMap()->areMapped(id1, id2, IdMappingMode::PERMISSIVE),
-       "Invalid axes to compare");
+       permissive_map.strictAreMapped(id1, id2), "Invalid axes to compare");

// It's invalid to compare two axes and when only either of them has
// halo.
@@ -652,10 +647,10 @@ bool extentCompare(
auto merge2 = dynamic_cast<Merge*>(id2->definition());
TORCH_INTERNAL_ASSERT(
merge2 != nullptr, "Invalid comparison: ", id1, " and ", id2);
-     auto inner_le =
-         extentCompare(halo_map, merge1->inner(), merge2->inner(), cmp);
-     auto outer_le =
-         extentCompare(halo_map, merge1->outer(), merge2->outer(), cmp);
+     auto inner_le = extentCompare(
+         halo_map, merge1->inner(), merge2->inner(), cmp, permissive_map);
+     auto outer_le = extentCompare(
+         halo_map, merge1->outer(), merge2->outer(), cmp, permissive_map);
return inner_le && outer_le;
} else {
// This is not considered. Should never reach here.
@@ -667,11 +662,11 @@ bool extentCompare(
} // namespace

bool HaloInfo::extentLessEqual(IterDomain* id1, IterDomain* id2) const {
-   return extentCompare(*this, id1, id2, std::less_equal<>());
+   return extentCompare(*this, id1, id2, std::less_equal<>(), permissive_map_);
}

bool HaloInfo::extentEqual(IterDomain* id1, IterDomain* id2) const {
-   return extentCompare(*this, id1, id2, std::equal_to<>());
+   return extentCompare(*this, id1, id2, std::equal_to<>(), permissive_map_);
}

std::string HaloInfo::toString() const {
@@ -722,19 +717,19 @@ bool HaloInfo::needsShiftPredicate(Expr* expr) const {
}

std::unordered_map<IterDomain*, Val*> HaloInfo::buildConcreteHaloExtentMap(
-     const LoopIndexing& loop_indexing) {
+     const LoopIndexing& loop_indexing) const {
// Use a local workspace to avoid re-defining halo info.
-   HaloInfo local_halo_info;
+   HaloInfo local_halo_info = *GpuLower::current()->haloInfo();

-   auto& global_halo_info = GpuLower::current()->haloInfo();
+   auto global_halo_info = GpuLower::current()->haloInfo();

// Setup root:
for (auto consumer_root_id : loop_indexing.consumerTv()->getRootDomain()) {
auto consumer_index_concrete_id =
ir_utils::caMapExactConcreteId(consumer_root_id);
local_halo_info.setRootAxisInfo(
consumer_index_concrete_id,
-         global_halo_info.getRootAxisInfo(consumer_root_id));
+         global_halo_info->getRootAxisInfo(consumer_root_id));
}

// Track IDs that are generated by merging halo-extended IDs
@@ -801,7 +796,8 @@ std::unordered_map<IterDomain*, Val*> HaloInfo::buildConcreteHaloExtentMap(
merged_shifted_ids.insert(ir_utils::caMapExactConcreteId(merge->out()));
// Note that halo_width_map_ is not updated
} else {
-         setHaloWidth(ir_utils::caMapExactConcreteId(merge->out()), 0);
+         local_halo_info.setHaloWidth(
+             ir_utils::caMapExactConcreteId(merge->out()), 0);
Collaborator:

Regarding the decoupling, here we still use the CA map held by GpuLower. Also, I just realized it's a little counter-intuitive that ir_utils implicitly uses GpuLower. It doesn't seem trivial to remove this interface as it's used everywhere.

Owner Author:

Is making this local_halo_info right?

Owner Author:

Will look again at the inter-class dependencies.

Owner Author:

Oh, this class is really hard set to the lowering process as it's taking indexing in. That's why I didn't bother looking more closely.

Owner Author:

ir_utils::caMapExactConcreteId seems like a weak alias, will get rid of that either before this is merged or in another PR.

Collaborator:

Creating local_halo_info is right. It was part of Shiming's indexing refactoring. Previously, reference domains were added to the singleton instance of HaloInfo, but as we don't create reference IterDomains anymore, the properties generated for reference IterDomains are instead generated for concrete IterDomains. A complexity about halo is that even exactly mapped IterDomains may have different halo sizes, so it's not possible to associate a unique halo attribute with a group of mapped IterDomains. As the halo size information depends on the loop nest where an expression appears, what Shiming did was to clone the pre-built HaloInfo and modify the halo information for the concrete IterDomains used at that loop nest.

That's my understanding from my reading of his refactoring PR. Tagging @shmsong as he may want to chime in.

Owner Author:

Yeah, I was assuming it was just a typo.
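The clone-and-modify workspace pattern described in this thread can be sketched in miniature: copy the globally built table by value, then override entries for the current loop nest without touching the shared state. All names here (`HaloTable`, `makeLocalTable`) are hypothetical stand-ins for `HaloInfo` and `buildConcreteHaloExtentMap`:

```cpp
#include <string>
#include <unordered_map>

// Stand-in for the globally built halo-width table.
struct HaloTable {
  std::unordered_map<std::string, int> halo_width;

  int widthOf(const std::string& id) const {
    auto it = halo_width.find(id);
    return it == halo_width.end() ? 0 : it->second;
  }
};

// Per-context workspace: a value copy of the global table with a
// context-specific override, analogous to `local_halo_info` above.
HaloTable makeLocalTable(
    const HaloTable& global, const std::string& id, int local_width) {
  HaloTable local = global;            // value copy; global stays untouched
  local.halo_width[id] = local_width;  // override for this loop nest
  return local;
}
```

The value copy is the key point: because exactly mapped IterDomains can carry different halo sizes per loop nest, the per-nest overrides must not leak back into the shared table.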

}
} else if (auto swizzle_2d = dynamic_cast<Swizzle2D*>(expr)) {
// Swizzle with halo not yet supported, just set the width