[WIP] Implement tiling intefrace for unpack op #10823
Closed
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Idea
The main issue is about incomplete tile. Since all the dimensions are orthogonal, discussing 1-d unpack case is enough. The core idea is to make the input slice have complete tiles. In this case, a larger unpacked tile will be created. We'll need an extract_slice op to shift and truncate the output.
Example
Let's take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The coordinates of second tile (i.e.,
result[15..31]) are[(1, 7), (2, 0,), (2, 1) ... (3, 6), (3, 7)]. The first row and the last row are incomplete in terms of inputs. It's impossible to represent an unpack op using the coordinates. Because the input has higher rank and the math computation of coordinate is using mod and ceilDiv. That's very tricky.To represent the unpack op, we have to complete the rows. I.e., the input coordinates would start with
(1, 0); end with(3, 7). In this context, the tiled unpack produces a (3 * n) elements because there are 3 rows in total. Follow by a tensor.extract_slice op, we can get the actual result.What assumptions are broken in the approach?
Tow operations are returned. It breaks an assumption that tiling algorithm expects it to return only one operation. IMO, it can be relaxed. A sequence of operations are generated during tiling sounds no issue to me. The PR also prototypes it in tile + distribution pass, and it works e2e. It can't be addressed by using
getResultTilePositionbecause of the mechanism of tiling interface. That would introduce a cast op, which turns into shape mismatch for some cases.How do we handle extra memory usage?
The approach takes larger input and produces larger output. We'll need a temp space to store the larger output, and write the result back to output buffer. It means that a stack buffer allocation is needed. However, the size is bounded by tiling sizes and inner tile sizes. The size of additional memory is less than
2 * inner_tile_sizeper dimension. Furthermore, if we vectorize the op, all the results will be stored in register. In this context, we won't need any alloca ops.