-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Extended permutation support in integration (See more details on csarofeen/pytorch#1601). This update allows us to better support permutation propagation on tensors, specifically for binary ops with inputs of different ranks. Our goal is to avoid permuting tensors unless absolutely necessary. We try to preserve the permutation propagation rule in aten, with some known limitation at the time. The idea in this implementation is the same as with our existing code, which is to permute input/output tensors outside of codegen: For a simplified binary op scenario: `output = binaryOp(input0, input1)` 1. In a simple case where `input0` and `input1` come with the same rank & permutation order, our output would preserve the same permutation; 2. For cases where `input0` and `input1` come with different ranks but with **compatible** permutation, the tensor with the higher rank dictates the permutation of the output; 3. For cases where `input0` and `input1` come with different ranks but with **in-compatible** permutation, this is where permutation propagation fails and the output tensor will be contiguous. By **compatible** permutation, it means that we can permute the higher rank tensor to contiguous format, and then apply a second permutation to the tensor with lower rank to match their axes. This check is implemented in `MemoryFormat::broadcastToRank(int lower_rank)`. Some concrete example (note that we comply with eager propagation on cases 1-3, but diverge in behavior for cases 4, 5): 1. different rank & same permutation ``` t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2]) # stride (hwc, 1, wc, c) t1 = torch.randn(h, w, c).cuda().permute([2, 0, 1]) # stride (1, wc, c) out = scripted_add(t0, t1) # stride (hwc, 1, wc, c) preserving memory format of t0 ``` 2. different rank & compatible permutation ``` t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2]) # stride (hwc, 1, wc, c) t1 = torch.randn(c, h, w).cuda() # stride (hw, w, 1) out = scripted_add(t0, t1) # stride (hwc, 1, wc, c) preserving memory format of t0 ``` 3. different rank & compatible permutation with broadcasting ``` t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2]) # stride (hwc, 1, wc, c) t1 = torch.randn(c).cuda().unsqueeze(-1).unsqueeze(-1) # stride (1, 1, 1) out = scripted_add(t0, t1) # stride (hwc, 1, wc, c) preserving memory format of t0 ``` 4. different rank & in-compatible permutation ``` t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2]) # stride (hwc, 1, wc, c) t1 = torch.randn(h, w).cuda() # stride (w, 1) jit_out = scripted_add(t0, t1) # stride (hwc, 1, wc, c) # stride (hwc, wc, c, 1) # nvfuser outputs contiguous tensor eager_out = eager_add(t0, t1) # stride (hwc, 1, wc, c) # stride (hwc, 1, wc, c) # TI preserves memory format of LHS operand ``` 5. different rank & in-compatible permutation ``` t0 = torch.randn(c, h, w).cuda() # stride (hw, w, 1) t1 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2]) # stride (hwc, 1, wc, c) jit_out = scripted_add(t0, t1) # stride (hwc, 1, wc, c) # stride (hwc, 1, wc, c) # nvfuser preserves memory format of highest rank tensors eager_out = eager_add(t0, t1) # stride (hwc, 1, wc, c) # stride (hwc, hw, w, 1) # TensorIterator preserves memory format of LHS operand ``` Pull Request resolved: pytorch/pytorch#76563 Approved by: https://github.com/kevinstephano, https://github.com/ngimel
- Loading branch information
1 parent
108b113
commit a01b572
Showing
1 changed file
with
218 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters