-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rework HaloInfo interface #1987
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,7 @@ void ShiftPredicateInserter::insert( | |
TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing TensorView output"); | ||
|
||
const bool needs_shift_predicate = | ||
gpu_lower->haloInfo().needsShiftPredicate(out_tv->definition()); | ||
gpu_lower->haloInfo()->needsShiftPredicate(out_tv->definition()); | ||
if (!needs_shift_predicate) { | ||
return; | ||
} | ||
|
@@ -145,13 +145,6 @@ const AxisHaloInfo& HaloInfo::getRootAxisInfo(IterDomain* id) const { | |
return it->second; | ||
} | ||
|
||
AxisHaloInfo& HaloInfo::getRootAxisInfo(IterDomain* id) { | ||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) | ||
return const_cast<AxisHaloInfo&>( | ||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) | ||
const_cast<const HaloInfo*>(this)->getRootAxisInfo(id)); | ||
} | ||
|
||
void HaloInfo::setRootAxisInfo( | ||
IterDomain* id, | ||
const AxisHaloInfo& root_axis_info) { | ||
|
@@ -161,7 +154,9 @@ void HaloInfo::setRootAxisInfo( | |
return; | ||
} | ||
|
||
void HaloInfo::build(Fusion* fusion) { | ||
HaloInfo::HaloInfo(Fusion* fusion, std::shared_ptr<const ComputeAtMap> ca_map) | ||
// Make a copy of the permissive map for extent comparators | ||
: permissive_map_(ca_map->idGraph().permissiveNodes()) { | ||
const auto vals = fusion->usedMathVals(); | ||
auto tvs = ir_utils::filterByType<TensorView>(vals); | ||
|
||
|
@@ -202,7 +197,7 @@ void HaloInfo::build(Fusion* fusion) { | |
|
||
// Note that validation requires consumer halo info | ||
for (auto tv : tvs) { | ||
validate(tv); | ||
validate(tv, ca_map); | ||
} | ||
} | ||
|
||
|
@@ -474,12 +469,13 @@ void HaloInfo::build(TensorDomain* td) { | |
//! Other types of parallelization should be supported except for | ||
//! vectorization. Vectorization should be eventually supported but | ||
//! needs further work. | ||
void HaloInfo::validate(TensorView* tv) const { | ||
void HaloInfo::validate( | ||
TensorView* tv, | ||
std::shared_ptr<const ComputeAtMap> ca_map) const { | ||
const auto mem_type = tv->getMemoryType(); | ||
|
||
for (auto axis : tv->domain()->domain()) { | ||
auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( | ||
axis, IdMappingMode::LOOP); | ||
auto concrete_id = ca_map->getConcreteMappedID(axis, IdMappingMode::LOOP); | ||
|
||
// The extent is assumed to be the same | ||
TORCH_INTERNAL_ASSERT( | ||
|
@@ -526,7 +522,7 @@ void HaloInfo::validate(TensorView* tv) const { | |
consumer->domain()->domain().begin(), | ||
consumer->domain()->domain().end(), | ||
[&](IterDomain* consumer_axis) { | ||
return GpuLower::current()->caMap()->areMapped( | ||
return ca_map->areMapped( | ||
axis, consumer_axis, IdMappingMode::PERMISSIVE); | ||
}); | ||
if (it == consumer->domain()->domain().end()) { | ||
|
@@ -626,11 +622,10 @@ bool extentCompare( | |
const HaloInfo& halo_map, | ||
IterDomain* id1, | ||
IterDomain* id2, | ||
Cmp cmp) { | ||
auto gpu_lower = GpuLower::current(); | ||
Cmp cmp, | ||
const DisjointSets<IterDomain*>& permissive_map) { | ||
TORCH_INTERNAL_ASSERT( | ||
gpu_lower->caMap()->areMapped(id1, id2, IdMappingMode::PERMISSIVE), | ||
"Invalid axes to compare"); | ||
permissive_map.strictAreMapped(id1, id2), "Invalid axes to compare"); | ||
|
||
// It's invalid to compare two axes and when only either of them has | ||
// halo. | ||
|
@@ -652,10 +647,10 @@ bool extentCompare( | |
auto merge2 = dynamic_cast<Merge*>(id2->definition()); | ||
TORCH_INTERNAL_ASSERT( | ||
merge2 != nullptr, "Invalid comparison: ", id1, " and ", id2); | ||
auto inner_le = | ||
extentCompare(halo_map, merge1->inner(), merge2->inner(), cmp); | ||
auto outer_le = | ||
extentCompare(halo_map, merge1->outer(), merge2->outer(), cmp); | ||
auto inner_le = extentCompare( | ||
halo_map, merge1->inner(), merge2->inner(), cmp, permissive_map); | ||
auto outer_le = extentCompare( | ||
halo_map, merge1->outer(), merge2->outer(), cmp, permissive_map); | ||
return inner_le && outer_le; | ||
} else { | ||
// This is not considered. Should never reach here. | ||
|
@@ -667,11 +662,11 @@ bool extentCompare( | |
} // namespace | ||
|
||
bool HaloInfo::extentLessEqual(IterDomain* id1, IterDomain* id2) const { | ||
return extentCompare(*this, id1, id2, std::less_equal<>()); | ||
return extentCompare(*this, id1, id2, std::less_equal<>(), permissive_map_); | ||
} | ||
|
||
bool HaloInfo::extentEqual(IterDomain* id1, IterDomain* id2) const { | ||
return extentCompare(*this, id1, id2, std::equal_to<>()); | ||
return extentCompare(*this, id1, id2, std::equal_to<>(), permissive_map_); | ||
} | ||
|
||
std::string HaloInfo::toString() const { | ||
|
@@ -722,19 +717,19 @@ bool HaloInfo::needsShiftPredicate(Expr* expr) const { | |
} | ||
|
||
std::unordered_map<IterDomain*, Val*> HaloInfo::buildConcreteHaloExtentMap( | ||
const LoopIndexing& loop_indexing) { | ||
const LoopIndexing& loop_indexing) const { | ||
// Use a local workspace to avoid re-defining halo info. | ||
HaloInfo local_halo_info; | ||
HaloInfo local_halo_info = *GpuLower::current()->haloInfo(); | ||
|
||
auto& global_halo_info = GpuLower::current()->haloInfo(); | ||
auto global_halo_info = GpuLower::current()->haloInfo(); | ||
|
||
// Setup root: | ||
for (auto consumer_root_id : loop_indexing.consumerTv()->getRootDomain()) { | ||
auto consumer_index_concrete_id = | ||
ir_utils::caMapExactConcreteId(consumer_root_id); | ||
local_halo_info.setRootAxisInfo( | ||
consumer_index_concrete_id, | ||
global_halo_info.getRootAxisInfo(consumer_root_id)); | ||
global_halo_info->getRootAxisInfo(consumer_root_id)); | ||
} | ||
|
||
// Track IDs that are generated by merging halo-extended IDs | ||
|
@@ -801,7 +796,8 @@ std::unordered_map<IterDomain*, Val*> HaloInfo::buildConcreteHaloExtentMap( | |
merged_shifted_ids.insert(ir_utils::caMapExactConcreteId(merge->out())); | ||
// Note that halo_width_map_ is not updated | ||
} else { | ||
setHaloWidth(ir_utils::caMapExactConcreteId(merge->out()), 0); | ||
local_halo_info.setHaloWidth( | ||
ir_utils::caMapExactConcreteId(merge->out()), 0); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Regarding the decoupling, here we still use the CA map held by GpuLower. And, also just realized it's a little counter-intuitive that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is making this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will look again at the inter-class dependencies. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, this class is really hard set to the lowering process as it's taking indexing in. That's why I didn't bother looking more closely. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Creating That's my understanding from my reading of his refactoring PR. Tagging @shmsong as he may want to chime in. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I was assuming it was just a typo. |
||
} | ||
} else if (auto swizzle_2d = dynamic_cast<Swizzle2D*>(expr)) { | ||
// Swizzle with halo not yet supported, just set the width | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you trying to decouple the states held by
GpuLower
from each other? I like it as that seems like a better and more robust design.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I'm trying to pull the dependencies apart a bit as I was thinking of building a HaloInfo structure in contiguity which is also used outside GpuLower.