Transform propagator skip replay when possible #1782
Changes from 15 commits
@@ -644,24 +644,173 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
  return replayCasP(consumer, producer, compute_at_axis, root_map);
}

namespace {

int getMatchedLeafPosWithoutReplay(
    const TensorView* producer,
    const TensorView* consumer,
    int consumer_or_producer_pos,
    bool consumer_pos = true) {
We often use names like

This part of the code has been replaced by #1791, no more such issue 😉
  FUSER_PERF_SCOPE("transform_replay.cpp::getMatchedLeafPosWithoutReplay");

  const auto c2p_root_map =
      PairwiseRootDomainMap(producer, consumer)
          .mapConsumerToProducer(consumer->domain(), producer->domain());

  // IterDomains in consumer root also in producer root
  std::unordered_set<Val*> mapped_consumer_roots;
  for (auto entry : c2p_root_map) {
    mapped_consumer_roots.emplace(entry.first);
  }

  const auto consumer_domain = consumer->domain()->domain();

  auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween(
      mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()});

  std::unordered_set<Val*> mapped_consumer_domain_ids(
      mapped_consumer_domain_ids_vec.begin(),
      mapped_consumer_domain_ids_vec.end());

  const auto producer_domain = producer->domain()->domain();

  auto it_consumer = consumer_domain.begin();
  auto it_producer = producer_domain.begin();

  auto best_effort_PasC = BestEffortReplay::replayPasC(
      producer, consumer, -1, PairwiseRootDomainMap(producer, consumer));

  auto c2p_map = best_effort_PasC.getReplay();

  int mismatched_consumer_pos = 0;
  int mismatched_producer_pos = 0;
  while (it_consumer != consumer_domain.end()) {
    auto consumer_id = *it_consumer;
    if (!mapped_consumer_domain_ids.count(consumer_id)) {
      ++it_consumer;
      mismatched_consumer_pos++;
      continue;
    }

    auto c2p_it = c2p_map.find(consumer_id);
    if (c2p_it == c2p_map.end()) {
      break;
    }

    if (it_producer == producer_domain.end()) {
      break;
    }

    auto producer_id = *it_producer;

    if (c2p_it->second == producer_id) {
      ++mismatched_consumer_pos;
      ++mismatched_producer_pos;
      ++it_consumer;
      ++it_producer;
      if (consumer_pos) {
        if (consumer_or_producer_pos == mismatched_consumer_pos) {
Yeah, this was what was confusing me. We're not returning the actual mismatched location; we're just returning the provided position if the mismatched position is further to the right. So this function can only return the provided position, or -1, meaning there's a mismatch before we hit the provided position.

We should either return this function to just being a bool, or we should actually return the position of the mismatch and compare that position outside of this function.

I think we are not just returning the provided position; instead, if you provide a consumer position, we are returning the corresponding producer position of the given consumer position. In this case, we cannot return the actual mismatch position, because it does not provide the information about the corresponding producer position of the given consumer position? I copy-pasted this function from computeAt, and actually, there is one thing that I don't understand about this function: why are we skipping unmappable dims in the consumer but not in the producer? What breaks the symmetry?

Yeah, it's my fault; this function is really just returning yes or no, it shouldn't be returning an int. It could return an int, but should be rewritten to do so.

If I have:

T1 cannot be fully inlined to T4. So T1's second dimension (I believe) is marked as unmappable to T4. It's not okay to fully inline T2 into T4, but it's fine to fully inline T1 with T2.

@naoyam can tell me if I'm wrong on the details of the example, but I believe the principles stand.

I still don't understand. Are you saying that the algorithm we used to compute the corresponding position is wrong, and we should rewrite it to compute it differently? Or are you saying that the algorithm is OK, but we should split up the testing of "needs replay" and "find the corresponding position" into two things? Or something else?

Hmm, I am not sure if this example is related. I think

TEST_F(NVFuserTest, TMP) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  auto tv0 = makeSymbolicTensor(2);
  fusion->addInput(tv0);
  auto tv1 = set(tv0);
  auto tv2 = sum(tv1, {1});
  auto tv3 = broadcast(tv2, {false, true});
  auto tv4 = add(tv1, tv3);
  fusion->addOutput(tv4);

  auto producer = tv1;
  auto consumer = tv4;

  const auto c2p_root_map =
      PairwiseRootDomainMap(producer, consumer)
          .mapConsumerToProducer(consumer->domain(), producer->domain());

  // IterDomains in consumer root also in producer root
  std::unordered_set<Val*> mapped_consumer_roots;
  for (auto entry : c2p_root_map) {
    mapped_consumer_roots.emplace(entry.first);
  }

  fusion->print();
  std::cout << ir_utils::toString(mapped_consumer_roots) << std::endl;
}

This gives

I think "unmapped IDs in consumer but not in producer" here refers to new broadcasting dims, and we are saying that, if we see a leaf ID that completely comes from new broadcasting dims, then we can ignore it. But why don't we want to symmetrically do the same for reductions in the producer? For non-trivial reduction, I think this is because it could not be inlined, but for trivial reduction, it should be safe to just ignore it as well?

I believe the algorithm is correct, and it should not be symmetric. I was considering trivial reduction, but trivial reduction could likely be ignored as well. I was trying to say the asymmetry is due to the unmappable dims, which are related to particular graph patterns with reduction (the pattern mentioned above).

This is my mistake; I misread the return statements. I thought we were checking that the return value was equal to the provided value to return that value. We're swapping producer/consumer, so they're different.

The algorithm is actually not correct 😉: #1791

The constraint about reductions not being able to get inlined is reflected in
          return mismatched_producer_pos;
        }
      } else {
        if (consumer_or_producer_pos == mismatched_producer_pos) {
          return mismatched_consumer_pos;
        }
      }
    } else {
      break;
    }
  }
  return -1;
}

} // namespace
int TransformReplay::getMatchedLeafPosWithoutReplayPasC(
    const TensorView* producer,
    const TensorView* consumer,
    int consumer_pos) {
  return getMatchedLeafPosWithoutReplay(producer, consumer, consumer_pos, true);
}

int TransformReplay::getMatchedLeafPosWithoutReplayCasP(
    const TensorView* consumer,
    const TensorView* producer,
    int producer_pos) {
  return getMatchedLeafPosWithoutReplay(
      producer, consumer, producer_pos, false);
}
bool TransformReplay::fullyMatching(
    const TensorView* replay,
    const TensorView* target) {
  auto replay_root = replay->getRootDomain();
  auto replay_dom = replay->domain()->domain();
  auto target_root = target->getRootDomain();
  auto target_dom = target->domain()->domain();
  std::unordered_map<IterDomain*, IterDomain*> target2replay_map;
  if (replay_root.size() != target_root.size()) {
    return false;
  }
  target2replay_map.reserve(replay_root.size());
  std::transform(
      target_root.begin(),
      target_root.end(),
      replay_root.begin(),
      std::inserter(target2replay_map, target2replay_map.begin()),
      [](auto a, auto b) { return std::make_pair(a, b); });
  BestEffortReplay replay_(replay_dom, target_dom, target2replay_map);
  auto r = replay_.getReplay();
  for (int64_t i = 0; i < replay_dom.size(); i++) {
    auto target_id = target_dom[i];
    auto replay_it = r.find(target_id);
    if (replay_it == r.end() || replay_it->second != replay_dom[i]) {
      return false;
    }
  }
  return true;
}

void TransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) {
  int pos = replayed_pos_.at(from);
  // Note: [Using multiple TransformPropagators]
  // There are cases where we use multiple TransformPropagators along different
  // spanning trees with different references in the same fusion. Some of these
  // spanning trees could overlap. In cases where there are overlapping nodes,
  // TransformPropagator needs to respect the replay of others, because the
  // current TransformPropagator might not contain the most amount of
  // information on how to do the correct transformation. The logic below tells
  // TransformPropagator to skip the replay when not necessary.
  int new_pos =
      TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos);
  if (new_pos < 0) {
    auto replay = TransformReplay::replayPasC(to, from, pos);
    to->setDomain(replay.first);
    new_pos = replay.second;
  }
  replayed_pos_[to] = new_pos;
}

void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) {
  int pos = replayed_pos_.at(from);
  // See note [Using multiple TransformPropagators]
  int new_pos =
      TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos);
  if (new_pos < 0) {
    auto replay = TransformReplay::replayCasP(to, from, pos);
    to->setDomain(replay.first);
    new_pos = replay.second;
  }
  replayed_pos_[to] = new_pos;
}

I feel like I'm missing something here: why would new_pos be less than 0 here?

This is by protocol.

void TransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) {
  int pos = replayed_pos_.at(from);
  // See note [Using multiple TransformPropagators]
  if (!TransformReplay::fullyMatching(to, from)) {
    auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain());
    to->setDomain(replay);
  }
  replayed_pos_[to] = pos;
}

Please add comments on the parameters and the return value.

I agree that these functions need better docs. I will work on it in a new PR.