Some further cleanup for the new computeAt interface #1793

Merged · 114 commits · Jul 1, 2022
1338146
Compute at refactor
zasdfgbnm Jun 2, 2022
28348d0
save
zasdfgbnm Jun 2, 2022
1a1e609
fix
zasdfgbnm Jun 2, 2022
9f421d2
cleanup
zasdfgbnm Jun 2, 2022
b84d0c4
cleanup
zasdfgbnm Jun 2, 2022
e8545a7
fix
zasdfgbnm Jun 2, 2022
a441a90
fix
zasdfgbnm Jun 2, 2022
5c7e846
cleanup
zasdfgbnm Jun 2, 2022
938de6f
save
zasdfgbnm Jun 2, 2022
fd16198
fix
zasdfgbnm Jun 2, 2022
af1a560
fix
zasdfgbnm Jun 2, 2022
62852ff
note
zasdfgbnm Jun 3, 2022
fd9259d
split PasC and CasP
zasdfgbnm Jun 3, 2022
2af1b1a
fix
zasdfgbnm Jun 3, 2022
8b4c8a0
fix
zasdfgbnm Jun 3, 2022
89a2a75
don't set ca pos for fusion input
zasdfgbnm Jun 3, 2022
b9a9239
fix
zasdfgbnm Jun 3, 2022
49d19a4
siblings
zasdfgbnm Jun 3, 2022
789532e
cleanup
zasdfgbnm Jun 3, 2022
1d94372
unmappable_dims
zasdfgbnm Jun 3, 2022
436f7a0
tmp test
zasdfgbnm Jun 13, 2022
3022d5a
Merge branch 'devel' of github.com:csarofeen/pytorch into compute-at-…
zasdfgbnm Jun 13, 2022
aa60279
Merge branch 'devel' of github.com:csarofeen/pytorch into compute-at-…
zasdfgbnm Jun 23, 2022
a49313f
pull new TransformPropagator
zasdfgbnm Jun 23, 2022
c5ad239
save
zasdfgbnm Jun 23, 2022
c9e12ab
Update ir_interface_nodes.h
zasdfgbnm Jun 23, 2022
5ad1f8f
Update ir_interface_nodes.h
zasdfgbnm Jun 23, 2022
ffef502
save
zasdfgbnm Jun 23, 2022
587c2c4
save
zasdfgbnm Jun 23, 2022
8c4cb21
save
zasdfgbnm Jun 23, 2022
ebd943d
save
zasdfgbnm Jun 23, 2022
1ae4e70
fix
zasdfgbnm Jun 23, 2022
58df2cb
save
zasdfgbnm Jun 23, 2022
680a520
save
zasdfgbnm Jun 23, 2022
889bb27
short cut
zasdfgbnm Jun 24, 2022
db54f57
cleanup
zasdfgbnm Jun 24, 2022
270d83c
Merge branch 'devel' of github.com:csarofeen/pytorch into compute-at-…
zasdfgbnm Jun 26, 2022
ccd4a00
resolve
zasdfgbnm Jun 26, 2022
83c3d0a
Adding sibling path for MaxInfoSpanningTree
zasdfgbnm Jun 27, 2022
8210fce
save
zasdfgbnm Jun 27, 2022
04525db
remove check in fullSelfReplay
zasdfgbnm Jun 27, 2022
b18fa5b
save
zasdfgbnm Jun 27, 2022
26c1d4e
save
zasdfgbnm Jun 27, 2022
32312ce
save?
zasdfgbnm Jun 27, 2022
c270fd4
Merge branch 'spanning-tree-siblings' of github.com:csarofeen/pytorch…
zasdfgbnm Jun 27, 2022
96ae406
save
zasdfgbnm Jun 27, 2022
a42daf3
save
zasdfgbnm Jun 27, 2022
3e277e2
no hoistInnermostBroadcast
zasdfgbnm Jun 27, 2022
a0757d1
Merge branch 'devel' of github.com:csarofeen/pytorch into compute-at-…
zasdfgbnm Jun 27, 2022
4c3342c
save
zasdfgbnm Jun 28, 2022
81d9200
Merge branch 'devel' of github.com:csarofeen/pytorch into spanning-tr…
zasdfgbnm Jun 28, 2022
702f2b0
save
zasdfgbnm Jun 28, 2022
ba0e8af
resolve review
zasdfgbnm Jun 28, 2022
bfe66c7
save
zasdfgbnm Jun 28, 2022
de2f3be
save
zasdfgbnm Jun 28, 2022
195bc34
move skipReplay to TransformReplay
zasdfgbnm Jun 28, 2022
91b2f9b
TransformPropagator skip replay if possible
zasdfgbnm Jun 28, 2022
2a499a8
test
zasdfgbnm Jun 28, 2022
1fb655e
Merge branch 'transform-propagator-skip-replay' of github.com:csarofe…
zasdfgbnm Jun 28, 2022
f4925b3
save
zasdfgbnm Jun 28, 2022
ca8ef16
cleanup debugging print
zasdfgbnm Jun 28, 2022
c2dadf6
minor cleanup
zasdfgbnm Jun 28, 2022
a86f517
more cleanup
zasdfgbnm Jun 28, 2022
d8ed318
more cleanup
zasdfgbnm Jun 28, 2022
f8d6b8a
save
zasdfgbnm Jun 28, 2022
220129b
save
zasdfgbnm Jun 28, 2022
1421a19
Merge branch 'devel' of github.com:csarofeen/pytorch into compute-at-…
zasdfgbnm Jun 28, 2022
7756485
save
zasdfgbnm Jun 28, 2022
ad55204
cleanup getConsumerPosAlignedToProducerCA
zasdfgbnm Jun 28, 2022
6e497f6
more cleanup of getConsumerPosAlignedToProducerCA
zasdfgbnm Jun 28, 2022
194c7df
more cleanup on getConsumerPosAlignedToProducerCA
zasdfgbnm Jun 28, 2022
b4fe1fe
cleanup ComputeAtSubgraphSelector
zasdfgbnm Jun 28, 2022
0f8c52e
save
zasdfgbnm Jun 28, 2022
a873c17
save
zasdfgbnm Jun 28, 2022
3f13eb2
save
zasdfgbnm Jun 28, 2022
59ba327
cleanup
zasdfgbnm Jun 28, 2022
679867c
cleanup max pos logic
zasdfgbnm Jun 29, 2022
1839158
getMaxPos* cleanup, step 1
zasdfgbnm Jun 29, 2022
c1738d8
getMaxPos* cleanup, step 2
zasdfgbnm Jun 29, 2022
9f765a4
getMaxPos* cleanup, step 3
zasdfgbnm Jun 29, 2022
b3ef17e
renaming
zasdfgbnm Jun 29, 2022
a28f7aa
file renaming
zasdfgbnm Jun 29, 2022
6bf74e6
validate domain
zasdfgbnm Jun 29, 2022
2a84c76
split out MaxPosCalculator
zasdfgbnm Jun 29, 2022
4953b6c
cleanup computeAt
zasdfgbnm Jun 29, 2022
9100c3d
siblingTvsOf
zasdfgbnm Jun 29, 2022
0830c0a
Move functions around to be consistent with header order, add comment…
csarofeen Jun 29, 2022
d0b0fc1
fix
zasdfgbnm Jun 29, 2022
ea0d9cb
minor cleanup on variable names and comments
zasdfgbnm Jun 29, 2022
aa5f602
Add SpanningTreePrinter
zasdfgbnm Jun 30, 2022
0f53f0a
Merge branch 'printer' of github.com:csarofeen/pytorch into compute-a…
zasdfgbnm Jun 30, 2022
e4d0aac
no check producer
zasdfgbnm Jun 30, 2022
e234c9b
Merge branch 'devel' of github.com:csarofeen/pytorch into compute-at-…
zasdfgbnm Jun 30, 2022
5d40935
revert getMaxPosCasP to restore previous behavior
zasdfgbnm Jun 30, 2022
29e1474
Merge branch 'devel' of github.com:csarofeen/pytorch into compute-at-…
zasdfgbnm Jun 30, 2022
1281d34
revert #1786
zasdfgbnm Jun 30, 2022
befa9dc
fix TransformReplay::getMatchedLeafPosWithoutReplay
zasdfgbnm Jun 30, 2022
11d7ff9
save
zasdfgbnm Jun 30, 2022
2d3c39f
Merge branch 'skip-fix' of github.com:csarofeen/pytorch into compute-…
zasdfgbnm Jun 30, 2022
ee4cd26
save
zasdfgbnm Jun 30, 2022
b18a85a
fix
zasdfgbnm Jul 1, 2022
c07fd52
symmetric
zasdfgbnm Jul 1, 2022
1bb4e02
save
zasdfgbnm Jul 1, 2022
f2d17f3
cleanup
zasdfgbnm Jul 1, 2022
4521302
skip both from and to ids
zasdfgbnm Jul 1, 2022
6793f27
save
zasdfgbnm Jul 1, 2022
8de84bb
fix
zasdfgbnm Jul 1, 2022
6ba20e2
Merge branch 'skip-fix' of github.com:csarofeen/pytorch into compute-…
zasdfgbnm Jul 1, 2022
18c8923
save
zasdfgbnm Jul 1, 2022
2987b48
fix most inlined compute at
zasdfgbnm Jul 1, 2022
30a621b
fix
zasdfgbnm Jul 1, 2022
0f2249d
cleanup
zasdfgbnm Jul 1, 2022
4c3d9a7
save
zasdfgbnm Jul 1, 2022
a00848e
Merge branch 'devel' of github.com:csarofeen/pytorch into compute-at-…
zasdfgbnm Jul 1, 2022
10 changes: 9 additions & 1 deletion torch/csrc/jit/codegen/cuda/compute_at.cpp
@@ -165,7 +165,7 @@ void ComputeAt::runAt(
TensorView* consumer,
unsigned int consumer_position,
ComputeAtMode mode) {
FUSER_PERF_SCOPE("ComputeAt::run");
FUSER_PERF_SCOPE("ComputeAt::runAt");

// Make sure the correct fusion is setup between this and consumer.
TORCH_CHECK(
@@ -175,6 +175,10 @@ void ComputeAt::runAt(
consumer,
" are not in the same fusion.");

if (mode == ComputeAtMode::MostInlined) {
consumer_position = consumer->nDims();
}

FusionGuard fg(producer->fusion());

auto selected = getPropagationSubgraph(producer, consumer);
@@ -206,6 +210,10 @@ void ComputeAt::runWith(
consumer,
" are not in the same fusion.");

if (mode == ComputeAtMode::MostInlined) {
producer_position = producer->nDims();
}

FusionGuard fg(producer->fusion());

auto selected = getPropagationSubgraph(producer, consumer);
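The two hunks above apply the same rule in `runAt` and `runWith`: in `MostInlined` mode the caller-supplied position is overridden with the tensor's full dimensionality before propagation. A minimal sketch of that clamping, using a hypothetical `resolvePosition` helper and a plain `n_dims` value in place of nvfuser's `TensorView::nDims()`:

```cpp
#include <cassert>

enum class ComputeAtMode { Standard, BestEffort, MostInlined };

// Hypothetical helper illustrating the MostInlined override. In the real
// code the dimensionality comes from the consumer (runAt) or producer
// (runWith) TensorView itself.
unsigned int resolvePosition(
    ComputeAtMode mode,
    unsigned int requested_position,
    unsigned int n_dims) {
  // MostInlined ignores the requested position and inlines as deep as
  // possible, i.e. at the full dimensionality of the reference tensor.
  if (mode == ComputeAtMode::MostInlined) {
    return n_dims;
  }
  return requested_position;
}
```

This keeps the special-casing at the entry points, so the propagation machinery below never needs to know about `MostInlined` positions.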
94 changes: 77 additions & 17 deletions torch/csrc/jit/codegen/cuda/inline_propagator.cpp
@@ -309,9 +309,15 @@ void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) {
int pos = getReplayPosPasC(to, from);
auto to_pos =
TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos);
// TODO: Can we make TransformPropagator do the transformation, and
// InlinePropagator only set the CA positions?
// TORCH_CHECK(to_pos >= 0);
if (mode_ != ComputeAtMode::MostInlined) {
TORCH_CHECK(
to_pos >= 0,
"Unable to propagate CA position from consumer ",
from,
" to producer ",
to,
" because this would require replay.");
}
if (to_pos < 0) {
auto replay = TransformReplay::replayPasC(to, from, pos);
TORCH_INTERNAL_ASSERT(
@@ -335,9 +341,15 @@ void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) {
int pos = getReplayPosCasP(to, from);
auto to_pos =
TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos);
// TODO: Can we make TransformPropagator do the transformation, and
// InlinePropagator only set the CA positions?
// TORCH_CHECK(to_pos >= 0);
if (mode_ != ComputeAtMode::MostInlined) {
TORCH_CHECK(
to_pos >= 0,
"Unable to propagate CA position from producer ",
from,
" to consumer ",
to,
" because this would require replay.");
}
if (to_pos < 0) {
auto replay = TransformReplay::replayCasP(to, from, pos);
TORCH_INTERNAL_ASSERT(
@@ -373,23 +385,71 @@ void InlinePropagator::propagateTvSibling(TensorView* from, TensorView* to) {
recordReplayedPos(to, from_pos);
}
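The PasC and CasP hunks above share one pattern: outside `MostInlined` mode, a failed position match (`to_pos < 0`) is now a hard error, while `MostInlined` mode silently falls back to a replay. A hedged sketch of that control flow, with `matched_pos`/`replayed_pos` as hypothetical stand-ins for the results of `getMatchedLeafPosWithoutReplay*` and `replay*`:

```cpp
#include <cassert>
#include <stdexcept>

enum class ComputeAtMode { Standard, BestEffort, MostInlined };

// Hypothetical condensation of propagateTvPasC/propagateTvCasP:
// matched_pos >= 0 means the position transfers without replay; -1 means a
// replay would be required to align the domains.
int propagatePosition(ComputeAtMode mode, int matched_pos, int replayed_pos) {
  if (mode != ComputeAtMode::MostInlined && matched_pos < 0) {
    // Mirrors the new TORCH_CHECK: refuse to replay implicitly.
    throw std::runtime_error(
        "Unable to propagate CA position because this would require replay.");
  }
  if (matched_pos < 0) {
    // MostInlined mode only: fall back to an actual replay.
    return replayed_pos;
  }
  return matched_pos;
}
```

The design choice here is that replay during inline propagation is treated as a bug outside `MostInlined` mode, surfacing cases the old `// TORCH_CHECK(to_pos >= 0);` comment left silent.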

namespace {

// Try to find the aligned position on consumer's domain corresponding to the
// compute at position of producer domain.
void MaxProducerPosUpdater::handle(TensorView* consumer) {
// compute at position of producer domain. Used in computeAt pass only. No
// checking on actual producer-consumer relationship.
unsigned int getConsumerPosAlignedToProducerCA(
TensorView* consumer,
TensorView* producer) {
// Locate the consumer's position that aligns with the producer's new
// compute at axis. We need broadcast axes forwarded, so we need to replay
// PasC, as CasP will not forward broadcast dims. For example, if we have:
// T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 )
// produce_pos( 1 ) ), CasP will have the mapping iS1{3} -> iS2{3} and PasC
// will have the mapping iS22{( 3 * 1 )} <- iS1{3}. We need the latter.
// Refer to NVFuserTest.FusionComplexBCast1_CUDA.

auto c2p_map =
BestEffortReplay::replayPasC(
producer,
consumer,
-1,
// Compute at root domain may not be valid here, as all
// producers don't have to be able to map into consumer at
// max producer position. Since computeAt should be valid
// and this mechanism is only intended to lower produce
// position of consumer, we can simply use the pairwise map.
PairwiseRootDomainMap(producer, consumer))
.getReplay();

// Find the innermost position of consumer that has
// been mapped within the producer ca axis.
unsigned int consumer_pos = consumer->nDims();
while (consumer_pos > 0) {
for (auto producer : ir_utils::producerTvsOf(consumer)) {
auto producer_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC(
producer, consumer, consumer_pos);
if (producer_pos >= 0 &&
producer_pos <= producer->getComputeAtPosition()) {
goto finished;
}
Review comment on lines -381 to -387 (PR author):

I reverted this part back to the copy-pasted code from the previous computeAt. It looks like I was doing something wrong in my new code, and after the getMatchedLeafPosWithoutReplayPasC change, this started to fail.

auto consumer_id = consumer->axis((int)consumer_pos - 1);
auto p_dom = producer->domain()->domain();
if (std::any_of(
p_dom.begin(),
p_dom.begin() + producer->getComputeAtPosition(),
[&consumer_id, &c2p_map](IterDomain* p_id) {
auto c_id_it = c2p_map.find(consumer_id);
if (c_id_it != c2p_map.end()) {
return c_id_it->second == p_id;
}
return false;
})) {
break;
}
consumer_pos--;
}
finished:
consumer->setMaxProducer(consumer_pos, true);

return consumer_pos;
}
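The new body of `getConsumerPosAlignedToProducerCA` above walks the consumer's positions from innermost to outermost and stops at the first position whose axis maps, via the PasC replay map, to a producer axis inside the producer's compute-at range. A simplified sketch with integer axis ids and a plain `std::map` standing in for the `BestEffortReplay` c2p map (all names here are hypothetical):

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <vector>

// Axes are identified by ints; c2p maps a consumer axis id to the producer
// axis id it replays to, standing in for BestEffortReplay::replayPasC.
unsigned int alignedConsumerPos(
    const std::vector<int>& consumer_axes,
    const std::vector<int>& producer_axes,
    unsigned int producer_ca_pos,
    const std::map<int, int>& c2p) {
  unsigned int consumer_pos = consumer_axes.size();
  while (consumer_pos > 0) {
    int consumer_id = consumer_axes[consumer_pos - 1];
    auto it = c2p.find(consumer_id);
    // Keep this position if the consumer axis maps into the producer's
    // compute-at range [0, producer_ca_pos).
    auto ca_end = producer_axes.begin() + producer_ca_pos;
    if (it != c2p.end() &&
        std::find(producer_axes.begin(), ca_end, it->second) != ca_end) {
      break;
    }
    consumer_pos--;
  }
  return consumer_pos;
}
```

As in the real code, unmapped positions are peeled off one at a time from the inside out, so the returned position is the innermost consumer position still covered by the producer's compute-at axes (0 if none map).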

} // namespace

// Try to find the aligned position on consumer's domain corresponding to the
// compute at position of producer domain.
void MaxProducerPosUpdater::handle(TensorView* consumer) {
unsigned int consumer_pos = 0;
for (auto producer : ir_utils::producerTvsOf(consumer)) {
consumer_pos = std::max(
consumer_pos, getConsumerPosAlignedToProducerCA(consumer, producer));
}
consumer->setMaxProducer(consumer_pos);
}
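`MaxProducerPosUpdater::handle` above aggregates over all producers of a consumer, taking the maximum aligned position and recording it as the consumer's max-producer position. That reduction can be sketched independently of nvfuser's types, with each element standing in for one `getConsumerPosAlignedToProducerCA(consumer, producer)` result (names hypothetical):

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical per-producer aligned positions for a single consumer.
unsigned int maxProducerPos(const std::vector<unsigned int>& aligned_positions) {
  // The consumer's max-producer position is the innermost (largest)
  // position any producer reaches into it, or 0 with no producers.
  unsigned int pos = 0;
  for (unsigned int p : aligned_positions) {
    pos = std::max(pos, p);
  }
  return pos;
}
```

Starting from 0 rather than the consumer's dimensionality makes the no-producer case (a fusion input) naturally yield a max-producer position of 0.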

void MaxProducerPosUpdater::propagateTvPasC(TensorView* from, TensorView* to) {