[TIR] Implement API for padded layout transformations #12720
@@ -2443,6 +2443,7 @@ def transform_layout(
         block: Union[BlockRV, str],
         buffer: Union[Tuple[str, int], str, Buffer],
         index_map: Union[IndexMap, Callable],
+        pad_value: Optional[Union[int, float, IndexMap, Callable]] = None,
     ) -> None:
         """Apply a transformation represented by IndexMap to buffer
@@ -2479,6 +2480,31 @@ def transform_layout(
             primitive will be called in addition to the
             TransformLayout primitive.
 
+        pad_value: Optional[Union[int, float, PrimExpr, IndexMap, Callable]]
Review comment: Document the assumption when pad_value is an IndexMap. I remember in the RFC we assumed it should contain no BufferLoad from any buffer except the current buffer.

Reply: Thank you, and the docstring has been updated. I've also added two unit tests: one that validates that an error is raised when the pad value loads from a different buffer, and one that specifies the intended behavior for a pad value that loads from the transformed buffer. The latter is currently marked with …
+
+            The value to be used for any padding introduced by the
+            transformation. If the schedule contains a producer block
+            for the specified buffer, the pad value will be written as
+            part of the producer block if possible, or after the producer
+            block otherwise. Otherwise, if the buffer is an input, an
+            annotation block will be inserted to state that the padding
+            contains the known value.
+
+            Note: If applied to an input buffer, the calling scope is
+            responsible for ensuring that the pad_value is present.
+            Algebraic simplifications, branch elimination, and other
+            optimizations may assume that this precondition is met, and
+            may produce incorrect results if it is not.
+
+            If None, the transformation may not introduce padding.
+
+            If an int, float, or PrimExpr, the pad value is the specified
+            value to be present in the padding.
+
+            If an IndexMap or Callable, the pad value is expressed in
+            terms of the transformed indices.
Comment on lines +2509 to +2511:

Review comment: cpp side only accepts …

Reply: Good point. I had been thinking of it as the …

Reply: Updates made to pass …
+
         Examples
         --------
         Before transform_layout, in TensorIR, the IR is:
@@ -2536,9 +2562,18 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->
         else:
             axis_separators = []
 
+        if pad_value is None:
+            pass
+        elif callable(pad_value):
+            pad_value = IndexMap.from_func(pad_value, ndim=len(index_map.final_indices))
+        elif not isinstance(pad_value, IndexMap):
+            pad_value = IndexMap.from_func(
+                lambda *indices: pad_value, ndim=len(index_map.final_indices)
+            )
+
         buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
         _ffi_api.ScheduleTransformLayout(  # type: ignore # pylint: disable=no-member
-            self, block, buffer_index, buffer_index_type_enum, index_map
+            self, block, buffer_index, buffer_index_type_enum, index_map, pad_value
         )
         if axis_separators:
             _ffi_api.ScheduleSetAxisSeparator(  # type: ignore # pylint: disable=no-member
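Taken together, the changes above let a scalar or callable pad_value flow from the Python API down to the FFI as an IndexMap. The following is a minimal usage sketch, not taken from the PR itself; the function, block, and buffer names, the shapes, and the pad values are illustrative assumptions, and the TVMScript syntax matches the TVM version of this PR's era.

import tvm
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def elementwise(A: T.Buffer[(14,), "float32"], B: T.Buffer[(14,), "float32"]) -> None:
    for i in T.serial(14):
        with T.block("compute"):
            vi = T.axis.remap("S", [i])
            B[vi] = A[vi] * 2.0


sch = tir.Schedule(elementwise)

# Reshape the 14-element output buffer to 4x4, which introduces two padded
# elements.  Because "compute" produces B, the scalar pad value is written
# as part of the producer block, as described in the docstring above.
sch.transform_layout(
    block="compute",
    buffer="B",
    index_map=lambda i: [i // 4, i % 4],
    pad_value=0.0,
)

# A callable pad_value is instead interpreted in terms of the transformed
# indices (here io, ii) and is wrapped into an IndexMap by the Python-side
# normalization shown in the diff above, e.g.:
#     pad_value=lambda io, ii: T.float32(0)

print(sch.mod.script())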
@@ -36,7 +36,7 @@ def shared_16x32_to_ldmatrix_32x16_layout(i, j):
 
 
 def shared_32x16_to_ldmatrix_32x16_layout(i, j):
-    thread_id = (i % 4) + 4 * (j % 8)
+    thread_id = (i % 16) // 4 + 4 * (j % 8)
     return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4

Review comment: cc @masahi

Reply: Thank you for tagging @masahi, I had forgotten to do so. I think I have it set up correctly, based on the Nvidia documentation and its similarity to the (16, 32) shape, but I couldn't verify it definitively.

Reply: Hmm, I think the original mapping is correct; this is from p34 of the slides at https://developer.download.nvidia.com/video/gputechconf/gtc/2020/presentations/s21745-developing-cuda-kernels-to-push-tensor-cores-to-the-absolute-limit-on-nvidia-a100.pdf. Sorry, I don't remember the details.

Reply: Ah sorry, I was talking about …

Reply: Even if the index map is incorrect, it doesn't affect the correctness of the tensorized MMA, since the index map is only used for pattern-matching purposes...

Reply: Thank you for looking into it! I wasn't able to find any tests that explicitly validate the transform (e.g. use the transform to generate data in a specific layout, then pass it through the MMA), as all the tests either started with transformed data, only used the 16x16 shape, or replaced everything with the tensor intrinsic. I had put together this standalone test to convince myself of it. The main issue with the current index map is that it doesn't map to unique locations (512 input indices map to 128 output indices). It only arose as an issue in this PR, because the PR generates the inverse of the map in order to determine whether/where padding is required.
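The collision claim in that last reply can be checked with a few lines of standalone Python. This is an illustration written for this writeup, not the test referenced in the comment; it enumerates the full 32x16 domain and counts distinct outputs of the old and new mappings.

# Old mapping: ignores the (i % 16) // 4 component of i, so distinct inputs collide.
def old_layout(i, j):
    thread_id = (i % 4) + 4 * (j % 8)
    return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4

# New mapping, as in the diff above.
def new_layout(i, j):
    thread_id = (i % 16) // 4 + 4 * (j % 8)
    return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4

domain = [(i, j) for i in range(32) for j in range(16)]
print(len({old_layout(i, j) for i, j in domain}))  # 128 distinct outputs for 512 inputs
print(len({new_layout(i, j) for i, j in domain}))  # 512 distinct outputs, so the map is invertible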
Review comment: What's the use case for this? According to the docs, the mapping function should return a list; the docs might also need an update.

Reply: This was to allow the mapping function to return a single PrimExpr, or something that the FFI can convert into a PrimExpr. Since it wouldn't make sense for the pad value to provide multiple outputs, I found myself frequently writing lambda i, j: i + j instead of lambda i, j: [i + j]. I figured that since I was frequently making that mistake, later users would likely make it as well, so it would be best to support that functionality. Good call on the documentation; I'll update the documentation for from_func and from_func_with_separators accordingly.
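As a hedged illustration of the relaxed behavior described in this reply (assuming the PR's handling of a bare return value), the two calls below would construct the same single-output IndexMap after this change, whereas previously only the list-returning form was accepted.

from tvm.tir import IndexMap

# Mapping function returning a bare expression, convenient for pad values
# that produce a single output.
pad_a = IndexMap.from_func(lambda i, j: i + j, ndim=2)

# Equivalent mapping function returning a single-element list, as the
# from_func documentation originally described.
pad_b = IndexMap.from_func(lambda i, j: [i + j], ndim=2)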