Extended permutation support requested from upstream #1601

Open
jjsjann123 opened this issue Apr 20, 2022 · 3 comments

@jjsjann123 (Collaborator) commented Apr 20, 2022

Feature

Upstream has been requesting us to support/respect permutation propagation from input to output.

A permuted tensor indicates the intended memory format for storage. In PyTorch eager, a permuted input tensor usually dictates the memory format of the output tensor ([Note 0]).
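As a quick illustration of the eager behavior described above (my own minimal sketch, not taken from the issue): a pointwise op on a permuted input produces an output with the same stride order.

```
import torch

# Logical shape (2, 3, 4, 8) with a channels-last-like memory layout (memory order b, h, w, c).
t = torch.randn(2, 4, 8, 3).cuda().permute([0, 3, 1, 2])
out = t + 1.0  # TensorIterator derives the output layout from the input

print(t.stride())    # (96, 1, 24, 3)
print(out.stride())  # same stride order as the input
```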

The combined ask from upstream is to respect their permutation propagation rule, i.e. nvfuser should produce outputs with the same stride order as PyTorch eager ([Note 1]).

Motivation

  • Upstream's concern mostly comes from perf regressions. If we break permutation propagation, we may impact downstream operations that run through eager (kernel selection / slow permutes).
  • In one of the TIMM scripts, we also observed a failure where a view op is used: enabling nvfuser breaks the script because of the permutation change, since view cannot merge non-contiguous dimensions (see the sketch after this list).
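A minimal sketch of that failure mode (assumed shapes, not taken from the TIMM script): view cannot merge dimensions that are not adjacent in memory, so changing the output permutation can turn a working view into a runtime error.

```
import torch

x = torch.randn(2, 4, 8, 3)
x_permuted = x.permute([0, 3, 1, 2])  # logical (2, 3, 4, 8), non-contiguous strides (96, 1, 24, 3)

flat = x.view(2, -1)  # fine: contiguous dimensions can be merged without a copy
try:
    x_permuted.view(2, -1)  # the permuted dimensions cannot be merged without a copy
except RuntimeError as e:
    print(e)  # the error suggests using .reshape() instead
```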

Pitch

Our current support for permutation is implemented in the integration layer and adopted by individual ops. We also limit the scope to tensors of identical rank to keep things simple.

Fully supporting upstream's ask would be a heavy workload on codegen.

Breaking their concern into smaller pieces:

  1. The PyTorch Core Team wants reasonable permutation propagation in the short term (a hard requirement for PyTorch 1.12). Their main concern is our limited support: we can't propagate permutation across tensors with mismatched ranks (implicitly broadcast for binary ops). ([Note 2])
  2. The PyTorch JIT Team has a stricter requirement on respecting the stride order seen in profiled runs, but has relaxed the ask from arbitrary permutations to propagating only contiguous/channels_last. My read is that this is NOT a hard requirement for 1.12.
  3. Combining both asks, we should have arbitrary permutation support that complies with PyTorch eager (TensorIterator) behavior.

Note 0: a. The current implementation of format propagation is per individual op, so there could be holes in the support; b. cases where this propagation fails should be considered bugs (quoting upstream); c. TensorIterator is the de facto framework support for permutation, even though there are other special-cased paths (BN/conv, etc.).

Note 1: Upstream is not requiring us to match eager's exact strides, just the stride order. We haven't cleared up the details of what to do with broadcast dimensions.

Note 2: Reasonable in the sense that we shouldn't randomly drop permutation, but we don't have to match eager exactly for 1.12. In the long term, though, we should have consistent behavior.

Additional context

Past issues that might be relevant here: #1254 #1144 #194 #262

@jjsjann123 (Collaborator, Author)

To support 1, we can do a simple update to our current permutation support logic to handle implicit broadcast. That shouldn't take more than a day's work (famous last words 😃).
No changes would be needed to the current API. The drawback is that it'll be tricky to match the exact propagation rule in aten; we can try to match TensorIterator to some degree.

To support 2 (and 3): with the current TorchScript PE integration, we can look at the profiled outputs and figure out what the stride order should be. The profiling information should tell us which permutations we need on inputs and outputs, so we don't need to do any propagation, but we might want to do some permutation folding to cancel them out. (NNC is currently doing this, but without the folding part.)
This will be tricky, as I don't see how it could be done cleanly in the integration. My naive understanding is that we'd need to: 1. add new APIs that let us mark contiguity in the permuted order; 2. insert transposes at inputs/outputs; 3. add an optimization pass to fold permutations that cancel each other out (a rough sketch of that folding step follows).
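As a rough sketch of step 3 only (my own illustration under assumed data structures, not the integration code or NNC's pass), folding amounts to composing adjacent permutations and dropping pairs that compose to the identity:

```
from typing import List

def compose(outer: List[int], inner: List[int]) -> List[int]:
    """Permutation equivalent to applying `inner` first, then `outer`."""
    return [inner[i] for i in outer]

def invert(perm: List[int]) -> List[int]:
    inv = [0] * len(perm)
    for pos, axis in enumerate(perm):
        inv[axis] = pos
    return inv

def fold(perms: List[List[int]]) -> List[List[int]]:
    """Drop adjacent permute pairs that cancel out to the identity."""
    folded: List[List[int]] = []
    for p in perms:
        if folded and compose(p, folded[-1]) == list(range(len(p))):
            folded.pop()  # the two permutes cancel; neither needs to be materialized
        else:
            folded.append(p)
    return folded

# e.g. an NHWC->NCHW permute at a fusion input followed by its inverse at the output folds away:
to_nchw = [0, 3, 1, 2]
print(fold([to_nchw, invert(to_nchw)]))  # []
```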

@jjsjann123 (Collaborator, Author)

A quick note: we should confirm the requirements with upstream as well. I'll ping the upstream stakeholders once we've finished the first round of discussion on timeline/workload within the team.

@jjsjann123 (Collaborator, Author)

Very briefly, our rough brainstorming ideas:

  1. We would want to mark a tensor's stride order in physical (memory) space while keeping its semantics in the virtual (logical) space. This is similar to what PyTorch offers, except that we only respect stride order and the user can't provide explicit strides.
    How that API would impact the computation definition is still under discussion. We'll likely stick with producing memory-dense tensors.
    It would then be the user's responsibility to specify the I/O memory format.

  2. The current workaround to extend permutation to mismatched ranks can't cover the cases we're interested in with an unsqueezed bias. The limitation comes from the stride_indices computed by profiling, and I don't have a good way to solve it properly. I'm thinking about hacking something up that specializes the broadcast 1-d tensor case to achieve the goal within our limited time frame (see the sketch after this list).
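A hedged illustration of why that case is awkward (my own sketch; `stride_order` below is a stand-in for the stride_indices the profiler records, not the actual implementation): the size-1 dimensions of an unsqueezed 1-d bias carry meaningless strides, so several stride orders describe the same memory layout and the recorded order is ambiguous.

```
import torch

def stride_order(t: torch.Tensor):
    # dims sorted from largest to smallest stride (outermost to innermost in memory)
    return sorted(range(t.dim()), key=lambda d: t.stride(d), reverse=True)

x = torch.randn(2, 4, 8, 3).cuda().permute([0, 3, 1, 2])  # strides (96, 1, 24, 3)
bias = torch.randn(3).cuda().unsqueeze(-1).unsqueeze(-1)  # shape (3, 1, 1), strides (1, 1, 1)

print(stride_order(x))     # [0, 2, 3, 1] -> a channels-last-like order
print(stride_order(bias))  # [0, 1, 2], but any order is equally valid for the size-1 dims
```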

jjsjann123 added a commit that referenced this issue Apr 28, 2022
Extended permutation support in integration (see more details in #1601). This update allows us to better support permutation propagation on tensors, specifically for binary ops with inputs of different ranks. Our goal is to avoid permuting tensors unless absolutely necessary. We try to preserve the permutation propagation rule in aten, with some known limitations at this time.

The idea in this implementation is the same as in our existing code: permute input/output tensors outside of codegen. For a simplified binary-op scenario `output = binaryOp(input0, input1)`:

1. In the simple case where `input0` and `input1` come with the same rank & permutation order, our output preserves that permutation;
2. When `input0` and `input1` come with different ranks but **compatible** permutations, the tensor with the higher rank dictates the permutation of the output;
3. When `input0` and `input1` come with different ranks and **incompatible** permutations, permutation propagation fails and the output tensor will be contiguous.

By a **compatible** permutation, we mean that we can permute the higher-rank tensor to contiguous format and then apply a second permutation to the lower-rank tensor to match their axes. This check is implemented in `MemoryFormat::broadcastToRank(int lower_rank)`.
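As a rough sketch of that check (my own reading inferred from the examples below, not the actual `MemoryFormat::broadcastToRank` code): the lower-rank operand aligns with the trailing logical dimensions of the higher-rank operand, so those trailing dimensions must also be the innermost dimensions in memory for the permutation to carry over.

```
import torch

def broadcast_to_rank(high: torch.Tensor, lower_rank: int):
    """Return the permutation to apply to the lower-rank operand, or None if incompatible."""
    # dims of the higher-rank tensor sorted from largest to smallest stride (outermost first)
    order = sorted(range(high.dim()), key=lambda d: high.stride(d), reverse=True)
    innermost = order[-lower_rank:]                              # innermost dims in memory
    trailing = list(range(high.dim() - lower_rank, high.dim()))  # trailing logical dims
    if sorted(innermost) != trailing:
        return None                                              # incompatible permutation
    # re-express the memory order of those dims in the lower-rank operand's own axis numbering
    return [d - (high.dim() - lower_rank) for d in innermost]

b, h, w, c = 2, 4, 8, 3
t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # logical (b, c, h, w), memory order b, h, w, c
print(broadcast_to_rank(t0, 3))  # [1, 2, 0]: permute a (c, h, w) operand to (h, w, c) -> compatible (case 2)
print(broadcast_to_rank(t0, 2))  # None: an (h, w) operand can't cover the innermost w, c dims (case 4)
```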

Some concrete examples (note that we comply with eager propagation in cases 1-3, but diverge in behavior for cases 4 and 5):
1. different rank & same permutation
```
    t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
    t1 = torch.randn(h, w, c).cuda().permute([2, 0, 1])  # stride (1, wc, c)
    out = scripted_add(t0, t1)  # stride (hwc, 1, wc, c) preserving memory format of t0
```
2. different rank & compatible permutation
```
    t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
    t1 = torch.randn(c, h, w).cuda()  # stride (hw, w, 1)
    out = scripted_add(t0, t1)  # stride (hwc, 1, wc, c) preserving memory format of t0
```
3. different rank & compatible permutation with broadcasting
```
    t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
    t1 = torch.randn(c).cuda().unsqueeze(-1).unsqueeze(-1)  # stride (1, 1, 1)
    out = scripted_add(t0, t1)  # stride (hwc, 1, wc, c) preserving memory format of t0
```
4. different rank & in-compatible permutation
```
    t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
    t1 = torch.randn(h, w).cuda()  # stride (w, 1)
    jit_out = scripted_add(t0, t1)  # stride (hwc, wc, c, 1)  # nvfuser outputs a contiguous tensor
    eager_out = eager_add(t0, t1)  # stride (hwc, 1, wc, c)  # TI preserves memory format of the LHS operand
```
5. different rank & in-compatible permutation
```
    t0 = torch.randn(c, h, w).cuda()  # stride (hw, w, 1)
    t1 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
    jit_out = scripted_add(t0, t1)  # stride (hwc, 1, wc, c)  # nvfuser preserves memory format of the highest-rank tensor
    eager_out = eager_add(t0, t1)  # stride (hwc, hw, w, 1)  # TensorIterator preserves memory format of the LHS operand
```
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue May 2, 2022
(Commit message identical to the one above.)
Pull Request resolved: #76563
Approved by: https://github.com/kevinstephano, https://github.com/ngimel
facebook-github-bot pushed a commit to pytorch/pytorch that referenced this issue May 4, 2022
Summary: (commit message identical to the one above)

Pull Request resolved: #76563
Approved by: https://github.com/kevinstephano, https://github.com/ngimel

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/d23619b030444e2a77daab6aaa60988b765ba471

Reviewed By: malfet

Differential Revision: D36101858

Pulled By: malfet

fbshipit-source-id: 17662c68d7f1b448d72b270d6cfa6b8aea463df6
jjsjann123 added a commit to jjsjann123/nvfuser that referenced this issue Oct 29, 2022
(Commit message identical to the one above.)
Pull Request resolved: pytorch/pytorch#76563
Approved by: https://github.com/kevinstephano, https://github.com/ngimel
jjsjann123 added a commit to jjsjann123/nvfuser that referenced this issue Nov 10, 2022
(Commit message identical to the one above.)
Pull Request resolved: pytorch/pytorch#76563
Approved by: https://github.com/kevinstephano, https://github.com/ngimel