-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for 3-5D TMA to allow loading non-matmul operands #5207
base: main
Are you sure you want to change the base?
Changes from all commits
1640404
8332845
517ea83
028000c
04f48c5
d228adf
d5b5b1b
38d6cd4
a9f4e30
e19f439
591ad5d
e3f0282
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -279,20 +279,12 @@ static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) { | |
|
||
// Simple helper to experiment creating TMA descriptors on the host. | ||
// This is a useful to test TMA operations independently. | ||
static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { | ||
unsigned long long global_address; | ||
uint64_t dim; | ||
uint32_t tensorDim; | ||
int elementSize; | ||
unsigned long long desc_address; | ||
if (!PyArg_ParseTuple(args, "KKiiK", &global_address, &dim, &tensorDim, | ||
&elementSize, &desc_address)) { | ||
return NULL; | ||
} | ||
uint64_t dims[1] = {dim}; | ||
uint64_t globalStrides[1] = {dim * elementSize}; | ||
uint32_t boxDim[1] = {tensorDim}; | ||
uint32_t elementStrides[1] = {1}; | ||
static PyObject *fillTMADescriptior(unsigned long long global_address, | ||
uint64_t *dims, uint32_t *tensorDims, | ||
int elementSize, | ||
unsigned long long desc_address, | ||
uint64_t *globalStrides, | ||
uint32_t *elementStrides, int rank) { | ||
CUtensorMapDataType type; | ||
switch (elementSize) { | ||
case 1: | ||
|
@@ -306,24 +298,68 @@ static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { | |
break; | ||
default: | ||
PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4"); | ||
return NULL; | ||
} | ||
assert((elementSize * tensorDim) >= 32 && "block size too small."); | ||
int rank = 1; | ||
|
||
// Swizzling should be picked in codegen but since we need to set it on the | ||
// descriptor we rely on a convention between this function and codegen. | ||
CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_NONE; | ||
CUtensorMapL2promotion l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE; | ||
|
||
if (rank == 2) { | ||
// For now, we do not swizzle for higher ranks. Enabling swizzling in TMA | ||
// implies hasLeadingOffset = true in SMEM encoding, which is currently not | ||
// supported for higher rank TMA copies. This convention needs to be in sync | ||
// with the TMA lowering pass in codegen. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ThomasRaoux This is my takeaway from our discussion yesterday. Let me know if this is ok. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is correct but this convention makes me bit nervous as we won't be able to handle the case where we 3D inputs for a batch matmul kind of cases There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With or without this PR, the underlying limitation that prevents 3D TMA with swizzling for matmul inputs continues to exist. So I would say this changes just make the limitation explicit in the API temporarily until the issue is fixed, at which point we can remove this convention. For batched matmul kind of workloads, the flattening of a higher dim tensor into 2D is straightforward. So there is still an escape hatch. Overall, I think this PR won't make the situation any worse. Maybe the new TMA representation you mentioned would solve all of those issues. But while we wait for that, it would be good to enable more use cases for TMA within Triton - it is an "experimental" feature, after all. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. well the case that I think is interesting is to write a batch matmul case where the global tensor is 3D but each block loads a 2D tensor and compute matmul on it.
even then you wouldn't want swizzling for this case right?
That breaks the bound checking part right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Oh sorry if there was misunderstanding. I believe 2D-5D TMA should be treated equally, and I do hope that we can remove this restriction. Whether or not swizzing would be beneficial for my use case is a separate question I need to investigate in the future. Right now my inner-most axis size is 16B so no swizzling would be applied anyway. But I can tweak the sizes of the inner-most two dims, to make the inner-most axis wider and apply swizzling if I want to.
hmm I haven't thought about that but indeed I don't see how OOB check can work if some dims are flattened (if possible at all). Maybe the device-side tensor-map creation can be used? After we get the batch id (or group id for grouped gemm), we can use 2D TMA. I'm not sure if that's supported now. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I want to get clarified on your concern. I thought this PR would not have any negative implications, since higher-rank TMA with swizzling doesn't work anyway. But are you saying that, one important special case of 3D TMA, where the actual load is 2D (since one of copy dim sizes is always one) is supposed to work with the current impl, but my change would disallow the swizzling for that case as well? If that's the case, the only workaround, without a proper fix, would be to make swizzling a parameter for higher-rank TMA that the user provide. By default, we don't swizzle. We also need to pass the same swizzling param to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another idea would be to base the decision to enable swizzling not on the rank of the global tensor but the "effective rank" of the box, where by "effective rank" I mean a rank after removing size-1 dims. |
||
l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; | ||
uint32_t contigDimSizeInByte = elementSize * tensorDims[0]; | ||
|
||
if (contigDimSizeInByte >= 128) { | ||
swizzle = CU_TENSOR_MAP_SWIZZLE_128B; | ||
} else if (contigDimSizeInByte >= 64) { | ||
swizzle = CU_TENSOR_MAP_SWIZZLE_64B; | ||
} else if (contigDimSizeInByte >= 32) { | ||
swizzle = CU_TENSOR_MAP_SWIZZLE_32B; | ||
} | ||
|
||
// The bounding box inner dimension must be less than or equal to the | ||
// swizzle size. | ||
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 | ||
// We clamp the block size and the codegen will emit multiple copy | ||
// operations. | ||
if (contigDimSizeInByte > 128) { | ||
tensorDims[0] = 128 / elementSize; | ||
} | ||
} | ||
|
||
static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; | ||
INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, | ||
getCuTensorMapEncodeTiledHandle); | ||
CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled( | ||
(CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, | ||
globalStrides, boxDim, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, | ||
CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, | ||
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); | ||
globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, | ||
swizzle, l2Promotion, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); | ||
Py_INCREF(Py_None); | ||
return Py_None; | ||
} | ||
|
||
// Simple helper to experiment creating TMA descriptors on the host. | ||
// This is a useful to test TMA operations independently. | ||
static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { | ||
unsigned long long global_address; | ||
uint64_t dim; | ||
uint32_t tensorDim; | ||
int elementSize; | ||
unsigned long long desc_address; | ||
if (!PyArg_ParseTuple(args, "KKiiK", &global_address, &dim, &tensorDim, | ||
&elementSize, &desc_address)) { | ||
return NULL; | ||
} | ||
uint64_t globalStrides[1] = {dim * elementSize}; | ||
uint32_t elementStrides[1] = {1}; | ||
int rank = 1; | ||
|
||
return fillTMADescriptior(global_address, &dim, &tensorDim, elementSize, | ||
desc_address, globalStrides, elementStrides, rank); | ||
} | ||
|
||
static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { | ||
unsigned long long global_address; | ||
uint64_t dims[2]; | ||
|
@@ -335,54 +371,77 @@ static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { | |
&desc_address)) { | ||
return NULL; | ||
} | ||
uint64_t globalStrides[2] = {dims[0] * elementSize, | ||
dims[0] * dims[1] * elementSize}; | ||
uint64_t globalStrides[1] = {dims[0] * elementSize}; | ||
uint32_t elementStrides[2] = {1, 1}; | ||
CUtensorMapDataType type; | ||
switch (elementSize) { | ||
case 1: | ||
type = CU_TENSOR_MAP_DATA_TYPE_UINT8; | ||
break; | ||
case 2: | ||
type = CU_TENSOR_MAP_DATA_TYPE_UINT16; | ||
break; | ||
case 4: | ||
type = CU_TENSOR_MAP_DATA_TYPE_UINT32; | ||
break; | ||
default: | ||
PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4"); | ||
} | ||
int rank = 2; | ||
// Swizzling should be picked in codegen but since we need to set it on the | ||
// descriptor we rely on a convention between this function and codegen. | ||
CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B; | ||
uint32_t contigDimSizeInByte = elementSize * tensorDims[0]; | ||
if (contigDimSizeInByte >= 128) { | ||
swizzle = CU_TENSOR_MAP_SWIZZLE_128B; | ||
} else if (contigDimSizeInByte >= 64) { | ||
swizzle = CU_TENSOR_MAP_SWIZZLE_64B; | ||
} else if (contigDimSizeInByte >= 32) { | ||
swizzle = CU_TENSOR_MAP_SWIZZLE_32B; | ||
} else { | ||
assert(false && "block size too small."); | ||
|
||
return fillTMADescriptior(global_address, dims, tensorDims, elementSize, | ||
desc_address, globalStrides, elementStrides, rank); | ||
} | ||
|
||
static PyObject *fill3DTMADescriptor(PyObject *self, PyObject *args) { | ||
unsigned long long global_address; | ||
uint64_t dims[3]; | ||
uint32_t tensorDims[3]; | ||
int elementSize; | ||
unsigned long long desc_address; | ||
if (!PyArg_ParseTuple(args, "KKKKiiiiK", &global_address, &dims[2], &dims[1], | ||
&dims[0], &tensorDims[2], &tensorDims[1], | ||
&tensorDims[0], &elementSize, &desc_address)) { | ||
return NULL; | ||
} | ||
uint64_t globalStrides[2] = {dims[0] * elementSize, | ||
dims[0] * dims[1] * elementSize}; | ||
uint32_t elementStrides[3] = {1, 1, 1}; | ||
int rank = 3; | ||
|
||
return fillTMADescriptior(global_address, dims, tensorDims, elementSize, | ||
desc_address, globalStrides, elementStrides, rank); | ||
} | ||
|
||
static PyObject *fill4DTMADescriptor(PyObject *self, PyObject *args) { | ||
unsigned long long global_address; | ||
uint64_t dims[4]; | ||
uint32_t tensorDims[4]; | ||
int elementSize; | ||
unsigned long long desc_address; | ||
if (!PyArg_ParseTuple(args, "KKKKKiiiiiK", &global_address, &dims[3], | ||
&dims[2], &dims[1], &dims[0], &tensorDims[3], | ||
&tensorDims[2], &tensorDims[1], &tensorDims[0], | ||
&elementSize, &desc_address)) { | ||
return NULL; | ||
} | ||
// The bounding box inner dimension must be less than or equal to the swizzle | ||
// size. | ||
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 | ||
// We clamp the block size and the codegen will emit multiple copy operations. | ||
if (contigDimSizeInByte > 128) { | ||
tensorDims[0] = 128 / elementSize; | ||
uint64_t globalStrides[3] = {dims[0] * elementSize, | ||
dims[0] * dims[1] * elementSize, | ||
dims[0] * dims[1] * dims[2] * elementSize}; | ||
uint32_t elementStrides[4] = {1, 1, 1, 1}; | ||
int rank = 4; | ||
|
||
return fillTMADescriptior(global_address, dims, tensorDims, elementSize, | ||
desc_address, globalStrides, elementStrides, rank); | ||
} | ||
|
||
static PyObject *fill5DTMADescriptor(PyObject *self, PyObject *args) { | ||
unsigned long long global_address; | ||
uint64_t dims[5]; | ||
uint32_t tensorDims[5]; | ||
int elementSize; | ||
unsigned long long desc_address; | ||
if (!PyArg_ParseTuple(args, "KKKKKKiiiiiiK", &global_address, &dims[4], | ||
&dims[3], &dims[2], &dims[1], &dims[0], &tensorDims[4], | ||
&tensorDims[3], &tensorDims[2], &tensorDims[1], | ||
&tensorDims[0], &elementSize, &desc_address)) { | ||
return NULL; | ||
} | ||
static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; | ||
INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, | ||
getCuTensorMapEncodeTiledHandle); | ||
CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled( | ||
(CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, | ||
globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, | ||
swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, | ||
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); | ||
Py_INCREF(Py_None); | ||
return Py_None; | ||
uint64_t globalStrides[4] = { | ||
dims[0] * elementSize, dims[0] * dims[1] * elementSize, | ||
dims[0] * dims[1] * dims[2] * elementSize, | ||
dims[0] * dims[1] * dims[2] * dims[3] * elementSize}; | ||
uint32_t elementStrides[5] = {1, 1, 1, 1, 1}; | ||
int rank = 5; | ||
|
||
return fillTMADescriptior(global_address, dims, tensorDims, elementSize, | ||
desc_address, globalStrides, elementStrides, rank); | ||
} | ||
|
||
static PyMethodDef ModuleMethods[] = { | ||
|
@@ -400,6 +459,9 @@ static PyMethodDef ModuleMethods[] = { | |
"that calls printf()."}, | ||
{"fill_1d_tma_descriptor", fill1DTMADescriptor, METH_VARARGS, "doc"}, | ||
{"fill_2d_tma_descriptor", fill2DTMADescriptor, METH_VARARGS, "doc"}, | ||
{"fill_3d_tma_descriptor", fill3DTMADescriptor, METH_VARARGS, "doc"}, | ||
{"fill_4d_tma_descriptor", fill4DTMADescriptor, METH_VARARGS, "doc"}, | ||
{"fill_5d_tma_descriptor", fill5DTMADescriptor, METH_VARARGS, "doc"}, | ||
|
||
{NULL, NULL, 0, NULL} // sentinel | ||
}; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm removing this assert to unblock higher rank SMEM load, which seems to work fine.
I don't know what assumption this code has, so please let me know if there is a more reasonable relaxation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks fine, we might want to test this more heavily but doesn't have to be part of this.