Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Segment self mapping fusions #1954

Merged
merged 8 commits into from
Sep 14, 2022
Merged

Segment self mapping fusions #1954

merged 8 commits into from
Sep 14, 2022

Conversation

zasdfgbnm
Copy link
Collaborator

No description provided.

Comment on lines +820 to +822
#if 0
// silent wrong result
TEST_F(NVFuserTest, FusionTransposeViewSelfMapping_CUDA) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't get caught?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Random thought: we might want to make sure the disjoint view sets, and something similar for axes involved in permutation match. The issue here seems to be that we're transposing the same partial dimensions. We probably want to make sure throughout the fusion the dimensions active in a transpose are dimensions that are disjoint in the view set.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, thinking through the analysis, I don't know if we'd need to build something like a transpose disjoint map, but I think it might be enough if we check all the dimensions active in a transpose are disjoint. I'm wondering if we could somehow get in trouble with a series of transposes that are all individually disjoint in the view map, but could exhibit behavior above where we're effectively transposing "partial" dimensions.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any thoughts on this @zasdfgbnm ?

Copy link
Collaborator Author

@zasdfgbnm zasdfgbnm Sep 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be enough if we check all the dimensions active in a transpose are disjoint

I feel that this is in the wrong direction. For example torch.zeros(6).view(2, 3).transpose(0, 1) should be totally fine to schedule, but the transpose dimensions are not disjoint in the view disjoint map.

In my opinion, using solely disjoint-set is not sufficient for modeling this problem. That is, we can not just come up with a new fancy disjoint-set of IterDomains such that a fusion is schedulable iif its disjoint-set is really disjoint.

I think we should do the following analysis: For the DAG

       T0[I0, I1]
       /        \
     view   transpose
      |          |
T1[I2, I3]   T2[I4, I5]
       \       /
       T3[I6, I7]

we first start from the view operation. What the view operation tells us is:

  1. On the producer side, {I0, I1} must be ordered as (I0, I1) and I0 is contiguous.
  2. On the consumer side, {I2, I3} must be ordered as (I2, I3) and I2 is contiguous.

If we propagate this order and contiguity information, along the DAG, at some point, we will see a conflict. For example, if we propagate T0->T2 and T1->T3->T2, then:

  1. from the first path, we will conclude that {I4, I5} must be ordered as {I5, I4} and I5 is contiguous.
  2. from the second path, we will conclude that {I4, and I5} must be ordered as {I4, I5} and I4 is contiguous.

The above 1 and 2 are in conflict, so we reject this fusion.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we merge this PR, and discuss and fix the skipped test in a follow-up? Currently, transpose is already enabled on TorchScript in TOT devel, and this would prevent the bug from happening on transpose.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, let's merge this, then please move your comment to an issue and we can discuss further there.

build(fusion);

if (!allow_self_mapping) {
TORCH_CHECK(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you expand the error reporting with information about the domains causing the failure?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1
Could collect this information in selfMappingExists, and print it on failure.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error message added

This was referenced Sep 9, 2022
@zasdfgbnm zasdfgbnm merged commit eabe8d8 into devel Sep 14, 2022
@zasdfgbnm zasdfgbnm deleted the segment-self-mapping branch September 14, 2022 00:51
jjsjann123 added a commit that referenced this pull request Nov 9, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Codegen changes include:

* codegen improvement:
    i. allow non-root trivial reductions, allow empty/no-op fusion
    ii. fixes vectorization checks and size calculation
    iii. bank conflict handle improvement
    iv. enables transpose scheduler

* misc:
    i. CI tests failure fixes
    ii. cpp tests file clean up
    iii. trivial forwarding supports added in codegen runtime
    iv. added factory methods support in codegen

Commits that's in this PR from the devel branch:

```
7117a7e patching nvfuser conv cudnn test numerics mismatch (#2048)
65af1a4 Inserting sync for redundant parallel types is already done at the (#2023)
6ac74d1 Fix sync map (#2047)
f5bca33 Bank conflict checker improvements (#2032)
d2ca7e3 Minor update on cp.async code generation. (#1901)
d36cf61 Test file cleanup (#2040)
0b8e83f Allow non-root trivial reductions (#2037)
a2dfe40 Fix vectorize size calculation (#2035)
e040676 Use withPredicate to replace setPredicate to maintain Exprs immutable (#2025)
197221b removing ci workflow (#2034)
40e2703 Reduction rand like patch (#2031)
bc77266 Add utility for checking bank conflict of shared memory (#2029)
ddd1cf7 Add back FusionReductionWithTrivialReduction_CUDA (#2030)
fbd97e5 Revert "Cleanup trivial reduction workarounds (#2006)" (#2024)
bca20c1 Cleanup trivial reduction workarounds (#2006)
e4b6585 Trivial forwarding (#1995)
1a0e355 Fix contiguity analysis of predicates to match updated contiguity. (#1991)
a4effa6 Enable output allocation cache (#2010)
35440b7 Patching bn inference (#2016)
0f9f0b4 Add matmul benchmark (#2007)
45045cd Enable tests previously disabled due to an aliasing bug (#2005)
967aa77 Contiguous indexing for View operations (#1990)
a43cb20 Make inlining even more modular (#2004)
dc45835 Test util cleanup (#2003)
3ca21eb More strict validation (#2000)
a7a7d57 Fix build problem (#1999)
fc235b0 Just fixes comments (#1998)
482386c cleanup (#1997)
4cbe0db Improve divisible split detection (#1970)
42ccc52 Minor build fix. (#1996)
fcf8c09 Cleanup of lower_utils.cpp: Isolate out GpuLower usage (#1989)
15f2f6d Move ConcretizedBroadcastDomains to shared_ptr in GpuLower. (#1988)
8f1c7f5 Minor cleanup lower_unroll.cpp (#1994)
1d9858c Minor cleanup (#1992)
f262d9c Add support for uniform RNG (#1986)
eb1dad1 Remove non-const functions, remove GpuLower instance on build, pass in ca_map. (#1987)
634820c Add support for some empty fusion (#1981)
eabe8d8 Segment self mapping fusions (#1954)
e96aacf Enable Transpose operation (#1882)
425dce2 Add a null scheduler that helps segmenting away no-op schedules (#1835)
306d4a6 Fix canScheduleCompileTime check of transpose scheduler (#1969)
b1bd32c Minor fix (#1967)
bd93578 Enable transpose scheduler (#1927)
b7a206e Move scheduler vectorize utilities into their own file (#1959)
d9420e4 View scheduling (#1928)
c668e13 Upstream push ci fixes (#1965)
c40202b Fix dump effective bandwidth (#1962)
93505bc WAR on index mapping when exact and permissive maps differ (#1960)
45e95fd Allow splitting inner-most ID to create virtual innermost ID in transpose scheduler (#1930)
a3ecb33 Improve the comments at the beginning of index_compute.h (#1946)
f7bc341 Remove unused variables (#1955)
df3393a Some cleanup (#1957)
7d1d7c8 TVDomainGuard factory (#1953)
357ba22 Fill allocation with nan on tests (#1956)
8eafc54 Fix detection of unmappable root domains (#1952)
90a51f2 Some indexing cleanups, Add eye support (#1940)
ddc01e4 Exclude unsupported data types (#1951)
992e17c test the groups the same order as they are merged (#1949)
208262b Move detection of self mapping IDs to IterDomainGraph from (#1941)
ac4de38 Merge pull request #1945 from csarofeen/master_merge_0828
6310948 Add full, full_like, zeros, zeros_like, ones, ones_like (#1943)
aab10bc Merge remote-tracking branch 'upstream/viable/strict' into HEAD
4c254c0 Fix arange when step is negative (#1942)
89330aa Tensor factories must set the output shape as its input (#1939)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D40869846](https://our.internmc.facebook.com/intern/diff/D40869846)
Pull Request resolved: pytorch#87779
Approved by: https://github.com/davidberard98
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants