Extended permutation support requested from upstream #1601
Comments
To support 1, we can do a simple update to our current permutation-support logic to handle implicit broadcast. That shouldn't take more than a day's work (famous last words 😃). To support 2 (and 3): currently, with the TorchScript PE integration, we can look at profiled outputs and figure out what the stride order should be. The profiling information should tell us what permutations we need on input & output, so we don't need to do any propagation, but we might want to do some permutation folding in order to cancel them out. (NNC is currently doing this, but without the folding part.)
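To illustrate the folding idea mentioned above, here is a tiny sketch of the bookkeeping involved (hypothetical helper names, not NNC's or nvfuser's actual code): an input permutation and an output permutation can both be dropped when they compose to the identity.

```
def compose(p, q):
    # apply p first, then q: the result maps dim i to p[q[i]]
    return [p[i] for i in q]

def folds_away(perm_in, perm_out):
    # the pair cancels out when the composition is the identity permutation
    return compose(perm_in, perm_out) == list(range(len(perm_in)))

# e.g. permuting NCHW -> NHWC on the way in and NHWC -> NCHW on the way out
assert folds_away([0, 2, 3, 1], [0, 3, 1, 2])
```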
A quick note: we should confirm the requirements with upstream as well. I'll ping the upstream stakeholders on this once we've finished the first round of discussion on timeline/workload within the team.
Very briefly describing our rough brainstorming ideas:
Extended permutation support in integration (see more details in #1601). This update allows us to better support permutation propagation on tensors, specifically for binary ops with inputs of different ranks. Our goal is to avoid permuting tensors unless absolutely necessary. We try to preserve the permutation propagation rule in ATen, with some known limitations at this time. The idea in this implementation is the same as in our existing code, which is to permute input/output tensors outside of codegen.

For a simplified binary op scenario, `output = binaryOp(input0, input1)`:

1. In the simple case where `input0` and `input1` come with the same rank & permutation order, the output preserves the same permutation;
2. where `input0` and `input1` come with different ranks but **compatible** permutations, the tensor with the higher rank dictates the permutation of the output;
3. where `input0` and `input1` come with different ranks and **incompatible** permutations, permutation propagation fails and the output tensor will be contiguous.

By **compatible** permutation, we mean that we can permute the higher-rank tensor to contiguous format and then apply a second permutation to the tensor with the lower rank to match their axes. This check is implemented in `MemoryFormat::broadcastToRank(int lower_rank)` (a rough sketch of the idea follows the examples below).

Some concrete examples (note that we comply with eager propagation in cases 1-3, but diverge in behavior for cases 4 and 5):

1. different rank & same permutation

   ```
   t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
   t1 = torch.randn(h, w, c).cuda().permute([2, 0, 1])        # stride (1, wc, c)
   out = scripted_add(t0, t1)                                 # stride (hwc, 1, wc, c), preserving memory format of t0
   ```

2. different rank & compatible permutation

   ```
   t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
   t1 = torch.randn(c, h, w).cuda()                           # stride (hw, w, 1)
   out = scripted_add(t0, t1)                                 # stride (hwc, 1, wc, c), preserving memory format of t0
   ```

3. different rank & compatible permutation with broadcasting

   ```
   t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
   t1 = torch.randn(c).cuda().unsqueeze(-1).unsqueeze(-1)     # stride (1, 1, 1)
   out = scripted_add(t0, t1)                                 # stride (hwc, 1, wc, c), preserving memory format of t0
   ```

4. different rank & incompatible permutation

   ```
   t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
   t1 = torch.randn(h, w).cuda()                              # stride (w, 1)
   jit_out = scripted_add(t0, t1)   # stride (hwc, 1, wc, c)  # stride (hwc, wc, c, 1)   nvfuser outputs a contiguous tensor
   eager_out = eager_add(t0, t1)    # stride (hwc, 1, wc, c)  # stride (hwc, 1, wc, c)   TI preserves memory format of the LHS operand
   ```

5. different rank & incompatible permutation

   ```
   t0 = torch.randn(c, h, w).cuda()                           # stride (hw, w, 1)
   t1 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
   jit_out = scripted_add(t0, t1)   # stride (hwc, 1, wc, c)  # stride (hwc, 1, wc, c)   nvfuser preserves memory format of the highest-rank tensor
   eager_out = eager_add(t0, t1)    # stride (hwc, 1, wc, c)  # stride (hwc, hw, w, 1)   TensorIterator preserves memory format of the LHS operand
   ```
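To make the compatibility check above concrete, below is a minimal Python sketch of one interpretation that reproduces cases 1-5. It is only an illustration: the names `memory_order` and `broadcast_to_rank` are hypothetical, and this is not the actual `MemoryFormat::broadcastToRank(int lower_rank)` implementation.

```
# Hypothetical sketch, not the real MemoryFormat::broadcastToRank code.
# `memory_order` lists the logical dims of the higher-rank tensor from the
# outermost stride to the innermost one, e.g. t0 above (strides (hwc, 1, wc, c))
# has memory_order [0, 2, 3, 1].
def broadcast_to_rank(memory_order, lower_rank):
    dropped = len(memory_order) - lower_rank
    # Compatible only if the logical dims that disappear at the lower rank
    # (the leading ones) are also the outermost dims in memory.
    if sorted(memory_order[:dropped]) != list(range(dropped)):
        return None  # incompatible: propagation falls back to a contiguous output
    # Re-index the surviving dims for the lower-rank tensor.
    return [d - dropped for d in memory_order[dropped:]]

assert broadcast_to_rank([0, 2, 3, 1], 3) == [1, 2, 0]  # cases 1-3: compatible (matches t1 in case 1)
assert broadcast_to_rank([0, 2, 3, 1], 2) is None       # case 4: incompatible
```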
Pull Request resolved: pytorch/pytorch#76563. Approved by: https://github.com/kevinstephano, https://github.com/ngimel
Feature
Upstream has been requesting us to support/respect permutation propagation from input to output.
A permuted tensor indicates the intended memory format for storage. In PyTorch eager, a permuted input tensor usually dictates the memory format of the output tensor ([Note 0]).
The combined ask from upstream is to respect upstream's permutation propagation rule, e.g. nvfuser should produce output in the same stride order as PyTorch eager ([Note 1]).
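As a small runnable illustration of that eager behavior (shapes chosen arbitrarily here; the stride values in the comments follow from those shapes):

```
import torch

t = torch.randn(2, 3, 4, 5).permute(0, 3, 1, 2)  # dense but permuted: strides (60, 1, 20, 5)
out = t.relu()                                   # eager (TensorIterator) keeps the input's stride order
print(t.stride())    # (60, 1, 20, 5)
print(out.stride())  # (60, 1, 20, 5) -- the permuted input dictates the output's memory format
```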
Motivation
A `view` op is used, and enabling nvfuser breaks the script due to the permutation change, since `view` cannot merge non-contiguous dimensions (see the example below).
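As a minimal illustration of the failure mode (hypothetical shapes, not taken from the original report):

```
import torch

t = torch.randn(2, 3, 4).permute(0, 2, 1)  # permuted, hence non-contiguous
try:
    t.view(2, 12)                          # view cannot merge the non-contiguous dimensions
except RuntimeError as e:
    print("view failed:", e)
out = t.reshape(2, 12)                     # reshape falls back to a copy, so it still works
```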
Pitch
Our current support for permutation is implemented in integration and adopted by individual ops. We also limit the scope to tensors with identical rank to keep things simple.
Supporting upstream's ask would be a heavy workload on the codegen side.
Breaking their concern into smaller pieces:
Note 0: a. The current implementation of format propagation is per individual op, so there could be holes in the support; b. in cases where this propagation fails, it should be considered a bug (quoting upstream); c. TensorIterator is the de facto framework support for permutation, even though there are other special-case supports (BN/Conv, etc.).
Note 1: Upstream is not requiring us to match the exact strides of eager, only the stride order. We haven't cleared up the details of what to do with broadcast.
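A small sketch of the distinction between matching the stride order and matching the exact strides, assuming the eager behavior described above (the `stride_order` helper is ours, not a PyTorch API):

```
import torch

def stride_order(t):
    # dims sorted from the outermost (largest) stride to the innermost (smallest)
    return sorted(range(t.dim()), key=lambda d: -t.stride(d))

a = torch.randn(2, 8, 4, 4).to(memory_format=torch.channels_last)
b = a[:, :4]      # slicing keeps the dim order, but the strides are no longer dense
out = b.relu()    # eager produces a dense output laid out in the same order
print(stride_order(b) == stride_order(out))  # True: same stride order
print(b.stride(), out.stride())              # different exact strides
```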
Note 2: reasonable in the sense that we shouldn't arbitrarily drop permutations, but we don't have to match eager exactly for 1.12. In the long term, though, we'd better have consistent behavior.
Additional context
Past issues that might be relevant here: #1254 #1144 #194 #262