Skip to content

Commit

Permalink
Fix mutator and sameAs for expanded IterDomain (#1902)
Browse files: browse the repository at this point in the history
  • Loading branch information
zasdfgbnm committed Aug 11, 2022
1 parent b7435af commit 3381793
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 1 deletion.
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,10 @@ bool IterDomain::sameAs(const Statement* other) const {
is_same = is_same && ScalarCheck::sameAs(start(), other_id->start());
is_same =
is_same && ScalarCheck::sameAs(stopOffset(), other_id->stopOffset());
is_same = is_same && (hasExpandedExtent() == other_id->hasExpandedExtent());
if (is_same && hasExpandedExtent()) {
is_same = ScalarCheck::sameAs(expandedExtent(), other_id->expandedExtent());
}

return is_same;
}
Expand Down
7 changes: 7 additions & 0 deletions torch/csrc/jit/codegen/cuda/mutator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,14 @@ void OptOutMutator::mutate(NamedScalar* ns) {}
void OptOutMutator::mutate(IterDomain* id) {
Val* start = maybeMutated(id->start());
Val* extent = maybeMutated(id->extent());
Val* expanded_extent = nullptr;
if (id->hasExpandedExtent()) {
expanded_extent = maybeMutated(id->expandedExtent());
}
Val* stop_offset = maybeMutated(id->stopOffset());
if (start->sameAs(id->start()) && extent->sameAs(id->extent()) &&
(!id->hasExpandedExtent() ||
expanded_extent->sameAs(id->expandedExtent())) &&
stop_offset->sameAs(id->stopOffset())) {
return;
}
Expand All @@ -69,6 +75,7 @@ void OptOutMutator::mutate(IterDomain* id) {
.start(start)
.extent(extent)
.stop_offset(stop_offset)
.expanded_extent(expanded_extent)
.build());
}

Expand Down
1 change: 0 additions & 1 deletion torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23815,7 +23815,6 @@ TEST_F(NVFuserTest, FusionLoopSwizzleCheck1_CUDA) {
// Make tv2 swizzled and half-inlined (unsupported).
tv0->computeAt(tv3, -2);

fusion.print();
FusionExecutor fe;
ASSERT_ANY_THROW(fe.compileFusion(&fusion));
}
Expand Down

0 comments on commit 3381793

Please sign in to comment.