Add support for 3-5D TMA to allow loading non-matmul operands #5207

Status: Open · wants to merge 12 commits into base: main
4 changes: 0 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -204,10 +204,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
auto loc = op.getLoc();
auto srcTy = op.getSrc().getType();
auto dstTy = op.getResult().getType();
auto dstShape = dstTy.getShape();
auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
"Unexpected rank of ConvertLayout(shared->distributed)");
Author (@masahi), Nov 21, 2024:

I'm removing this assert to unblock higher-rank SMEM loads, which seem to work fine.

I don't know what assumption this code relies on, so please let me know if there is a more reasonable relaxation.

Collaborator:

Looks fine. We might want to test this more heavily, but that doesn't have to be part of this PR.


auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
loc, adaptor.getSrc(),
10 changes: 8 additions & 2 deletions lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp
@@ -36,7 +36,10 @@ class TMALoadLowering : public OpRewritePattern<ExperimentalDescriptorLoadOp> {
auto ctaLayout = getCTALayout(tensorType.getEncoding());
Attribute encoding = SharedEncodingAttr::get(tensorType.getContext(), 1, 1,
1, order, ctaLayout);
if (tensorType.getRank() > 1) {
if (tensorType.getRank() == 2) {
// The following SharedEncodingAttr constructor creates SMEM encoding with
// hasLeadingOffset = true, which is not currently supported for
// higher-rank TMA.
encoding = SharedEncodingAttr::get(
tensorType.getContext(), tensorType.getShape(), order, ctaLayout,
tensorType.getElementType());
@@ -87,7 +90,10 @@ class TMAStoreLowering
auto ctaLayout = getCTALayout(tensorType.getEncoding());
Attribute encoding = SharedEncodingAttr::get(tensorType.getContext(), 1, 1,
1, order, ctaLayout);
if (tensorType.getRank() > 1) {
if (tensorType.getRank() == 2) {
// The following SharedEncodingAttr constructor creates SMEM encoding with
// hasLeadingOffset = true, which is not currently supported for
// higher-rank TMA.
encoding = SharedEncodingAttr::get(
tensorType.getContext(), tensorType.getShape(), order, ctaLayout,
tensorType.getElementType());
29 changes: 28 additions & 1 deletion python/test/unit/hopper/test_experimental_tma.py
@@ -3,7 +3,9 @@

import triton
import triton.language as tl
from triton.tools.experimental_descriptor import (create_1d_tma_descriptor, create_2d_tma_descriptor)
from triton.tools.experimental_descriptor import (create_1d_tma_descriptor, create_2d_tma_descriptor,
TmaDescKernelParam)
from triton._internal_testing import dtypes_with_bfloat16, numpy_random, to_triton, requires_tma
from triton._internal_testing import dtypes_with_bfloat16, numpy_random, to_triton, requires_tma, supports_tma, tma_skip_msg

from typing import Optional
@@ -538,3 +540,28 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
)
torch.testing.assert_close(ref_out, A)
assert "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned" in kernel.asm["ptx"]


@requires_tma
@pytest.mark.parametrize("inner_size", [16, 64])
def test_experimetal_descriptor_load_4d(inner_size):
device = "cuda"

@triton.jit
def kernel(Z, desc, inner_size: tl.constexpr):
off0 = tl.arange(0, 2)
off1 = tl.arange(0, 2)
off2 = tl.arange(0, 32)
off3 = tl.arange(0, inner_size)
x = tl._experimental_descriptor_load(desc, [2, 2, 0, 0], [2, 2, 32, inner_size], tl.dtype("uint8"))
out_ptrs = (Z + 2 * 32 * inner_size * off0[:, None, None, None] + 32 * inner_size * off1[None, :, None, None] +
inner_size * off2[None, None, :, None] + off3[None, None, None, :])
tl.store(out_ptrs, x)

x = torch.randint(size=(4, 8, 32, inner_size), low=0, high=100, dtype=torch.uint8).to(device)
desc = TmaDescKernelParam(x.data_ptr(), [4, 8, 32, inner_size], [2, 2, 32, inner_size], 1)

z_tri = torch.zeros(size=(2, 2, 32, inner_size), dtype=torch.uint8, device=device)
kernel[(1, )](z_tri, desc, inner_size)

assert torch.equal(x[2:4, 2:4, :, :], z_tri)
16 changes: 11 additions & 5 deletions python/triton/tools/experimental_descriptor.py
@@ -9,15 +9,21 @@ class TmaDescKernelParam:
def __init__(self, ptr, dims, block_dims, element_size):
self.desc = torch.empty(self.TMA_DESC_SIZE, dtype=torch.uint8, device="cpu")
assert len(dims) == len(block_dims)
assert 1 <= len(dims) <= 2
assert 1 <= len(dims) <= 5
assert self.desc.data_ptr() % 64 == 0

if len(dims) == 1:
triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dims[0], block_dims[0], element_size,
self.desc.data_ptr())
fill_desc_func = triton.runtime.driver.active.utils.fill_1d_tma_descriptor
elif len(dims) == 2:
fill_desc_func = triton.runtime.driver.active.utils.fill_2d_tma_descriptor
elif len(dims) == 3:
fill_desc_func = triton.runtime.driver.active.utils.fill_3d_tma_descriptor
elif len(dims) == 4:
fill_desc_func = triton.runtime.driver.active.utils.fill_4d_tma_descriptor
else:
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0],
block_dims[1], element_size, self.desc.data_ptr())
fill_desc_func = triton.runtime.driver.active.utils.fill_5d_tma_descriptor

fill_desc_func(ptr, *dims, *block_dims, element_size, self.desc.data_ptr())

# Return a CUtensorMap* pointer in host memory
def tma_desc_cpu_ptr(self):
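A minimal usage sketch (not part of the diff; the tensor shape, block shape, and variable names are illustrative assumptions): with the rank check relaxed to 1-5, a 3D descriptor is built exactly like the 1D/2D cases and can be passed directly as a kernel argument, as in the 4D test above.

import torch
from triton.tools.experimental_descriptor import TmaDescKernelParam

# dims and block_dims are listed outermost-first, as in the 4D test;
# element_size is in bytes (1 -> uint8). This dispatches to
# fill_3d_tma_descriptor under the hood.
x = torch.randint(0, 100, (8, 32, 64), dtype=torch.uint8, device="cuda")
desc = TmaDescKernelParam(x.data_ptr(), [8, 32, 64], [2, 32, 64], 1)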
194 changes: 128 additions & 66 deletions third_party/nvidia/backend/driver.c
@@ -279,20 +279,12 @@ static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) {

// Simple helper to experiment with creating TMA descriptors on the host.
// This is useful for testing TMA operations independently.
static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) {
unsigned long long global_address;
uint64_t dim;
uint32_t tensorDim;
int elementSize;
unsigned long long desc_address;
if (!PyArg_ParseTuple(args, "KKiiK", &global_address, &dim, &tensorDim,
&elementSize, &desc_address)) {
return NULL;
}
uint64_t dims[1] = {dim};
uint64_t globalStrides[1] = {dim * elementSize};
uint32_t boxDim[1] = {tensorDim};
uint32_t elementStrides[1] = {1};
static PyObject *fillTMADescriptior(unsigned long long global_address,
uint64_t *dims, uint32_t *tensorDims,
int elementSize,
unsigned long long desc_address,
uint64_t *globalStrides,
uint32_t *elementStrides, int rank) {
CUtensorMapDataType type;
switch (elementSize) {
case 1:
@@ -306,24 +298,68 @@ static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) {
break;
default:
PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4");
return NULL;
}
assert((elementSize * tensorDim) >= 32 && "block size too small.");
int rank = 1;

// Swizzling should be picked in codegen but since we need to set it on the
// descriptor we rely on a convention between this function and codegen.
CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_NONE;
CUtensorMapL2promotion l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE;

if (rank == 2) {
// For now, we do not swizzle for higher ranks. Enabling swizzling in TMA
// implies hasLeadingOffset = true in SMEM encoding, which is currently not
// supported for higher rank TMA copies. This convention needs to be in sync
// with the TMA lowering pass in codegen.
Author (@masahi):

@ThomasRaoux This is my takeaway from our discussion yesterday. Let me know if this is ok.

Collaborator:

This is correct, but this convention makes me a bit nervous, since we won't be able to handle the case where we have 3D inputs for batch-matmul kinds of cases.

Author (@masahi), Nov 21, 2024:

With or without this PR, the underlying limitation that prevents 3D TMA with swizzling for matmul inputs continues to exist. So I would say this change just makes the limitation explicit in the API temporarily, until the issue is fixed, at which point we can remove this convention.

For batched-matmul kinds of workloads, flattening a higher-dimensional tensor into 2D is straightforward, so there is still an escape hatch.

Overall, I think this PR won't make the situation any worse. Maybe the new TMA representation you mentioned would solve all of these issues, but while we wait for that, it would be good to enable more use cases for TMA within Triton - it is an "experimental" feature, after all.

Collaborator:

Well, the case that I think is interesting is a batch matmul where the global tensor is 3D but each block loads a 2D tile and computes a matmul on it.

> So I would say this change just makes the limitation explicit in the API temporarily, until the issue is fixed, at which point we can remove this convention.

Even then, you wouldn't want swizzling for this case, right?

> For batched-matmul kinds of workloads, flattening a higher-dimensional tensor into 2D is straightforward, so there is still an escape hatch.

That breaks the bounds-checking part, right?

Author (@masahi), Nov 22, 2024:

> Even then, you wouldn't want swizzling for this case, right?

Oh, sorry if there was a misunderstanding. I believe 2D-5D TMA should be treated equally, and I do hope that we can remove this restriction. Whether or not swizzling would be beneficial for my use case is a separate question I need to investigate in the future. Right now my innermost axis size is 16B, so no swizzling would be applied anyway. But I could tweak the sizes of the innermost two dims to make the innermost axis wider and apply swizzling if I wanted to.

> That breaks the bounds-checking part, right?

Hmm, I hadn't thought about that, but indeed I don't see how the OOB check can work if some dims are flattened (if it is possible at all).

Maybe the device-side tensor-map creation can be used? After we get the batch id (or group id for grouped GEMM), we could use 2D TMA. I'm not sure if that's supported now.

Author (@masahi):

I want to clarify your concern. I thought this PR would not have any negative implications, since higher-rank TMA with swizzling doesn't work anyway. But are you saying that one important special case of 3D TMA, where the actual load is 2D (since one of the copy dim sizes is always one), is supposed to work with the current implementation, and that my change would disallow swizzling for that case as well?

If that's the case, the only workaround without a proper fix would be to make swizzling a user-provided parameter for higher-rank TMA. By default, we would not swizzle. We would also need to pass the same swizzling parameter to tl.experimental_descriptor_load(...) to keep the codegen and the runtime code in sync.

Author (@masahi):

Another idea would be to base the decision to enable swizzling not on the rank of the global tensor but on the "effective rank" of the box, where by "effective rank" I mean the rank after removing size-1 dims (for example, a box of shape [1, 64, 128] would have effective rank 2).

l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B;
uint32_t contigDimSizeInByte = elementSize * tensorDims[0];

if (contigDimSizeInByte >= 128) {
swizzle = CU_TENSOR_MAP_SWIZZLE_128B;
} else if (contigDimSizeInByte >= 64) {
swizzle = CU_TENSOR_MAP_SWIZZLE_64B;
} else if (contigDimSizeInByte >= 32) {
swizzle = CU_TENSOR_MAP_SWIZZLE_32B;
}

// The bounding box inner dimension must be less than or equal to the
// swizzle size.
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
// We clamp the block size and the codegen will emit multiple copy
// operations.
if (contigDimSizeInByte > 128) {
tensorDims[0] = 128 / elementSize;
}
}

static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL;
INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled,
getCuTensorMapEncodeTiledHandle);
CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled(
(CUtensorMap *)desc_address, type, rank, (void *)global_address, dims,
globalStrides, boxDim, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE,
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
swizzle, l2Promotion, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
Py_INCREF(Py_None);
return Py_None;
}
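As a reading aid (not part of the diff), here is a hedged Python sketch of the convention the rank-2 branch above implements; the standalone helper and its name are assumptions for illustration only.

# Pick the widest swizzle that fits the contiguous (innermost) box dimension
# and clamp the box to the 128-byte swizzle limit; higher ranks currently get
# no swizzle, matching the convention described in the comments above.
def pick_swizzle_and_clamp(rank, element_size, inner_box_dim):
    swizzle = "NONE"
    if rank == 2:
        contig_bytes = element_size * inner_box_dim
        if contig_bytes >= 128:
            swizzle = "128B"
        elif contig_bytes >= 64:
            swizzle = "64B"
        elif contig_bytes >= 32:
            swizzle = "32B"
        if contig_bytes > 128:
            inner_box_dim = 128 // element_size
    return swizzle, inner_box_dim

assert pick_swizzle_and_clamp(2, 2, 128) == ("128B", 64)  # 256-byte rows get clamped
assert pick_swizzle_and_clamp(3, 1, 64) == ("NONE", 64)   # rank > 2: no swizzle for now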

// Simple helper to experiment with creating TMA descriptors on the host.
// This is useful for testing TMA operations independently.
static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) {
unsigned long long global_address;
uint64_t dim;
uint32_t tensorDim;
int elementSize;
unsigned long long desc_address;
if (!PyArg_ParseTuple(args, "KKiiK", &global_address, &dim, &tensorDim,
&elementSize, &desc_address)) {
return NULL;
}
uint64_t globalStrides[1] = {dim * elementSize};
uint32_t elementStrides[1] = {1};
int rank = 1;

return fillTMADescriptior(global_address, &dim, &tensorDim, elementSize,
desc_address, globalStrides, elementStrides, rank);
}

static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) {
unsigned long long global_address;
uint64_t dims[2];
@@ -335,54 +371,77 @@ static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) {
&desc_address)) {
return NULL;
}
uint64_t globalStrides[2] = {dims[0] * elementSize,
dims[0] * dims[1] * elementSize};
uint64_t globalStrides[1] = {dims[0] * elementSize};
uint32_t elementStrides[2] = {1, 1};
CUtensorMapDataType type;
switch (elementSize) {
case 1:
type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
break;
case 2:
type = CU_TENSOR_MAP_DATA_TYPE_UINT16;
break;
case 4:
type = CU_TENSOR_MAP_DATA_TYPE_UINT32;
break;
default:
PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4");
}
int rank = 2;
// Swizzling should be picked in codegen but since we need to set it on the
// descriptor we rely on a convention between this function and codegen.
CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B;
uint32_t contigDimSizeInByte = elementSize * tensorDims[0];
if (contigDimSizeInByte >= 128) {
swizzle = CU_TENSOR_MAP_SWIZZLE_128B;
} else if (contigDimSizeInByte >= 64) {
swizzle = CU_TENSOR_MAP_SWIZZLE_64B;
} else if (contigDimSizeInByte >= 32) {
swizzle = CU_TENSOR_MAP_SWIZZLE_32B;
} else {
assert(false && "block size too small.");

return fillTMADescriptior(global_address, dims, tensorDims, elementSize,
desc_address, globalStrides, elementStrides, rank);
}

static PyObject *fill3DTMADescriptor(PyObject *self, PyObject *args) {
unsigned long long global_address;
uint64_t dims[3];
uint32_t tensorDims[3];
int elementSize;
unsigned long long desc_address;
if (!PyArg_ParseTuple(args, "KKKKiiiiK", &global_address, &dims[2], &dims[1],
&dims[0], &tensorDims[2], &tensorDims[1],
&tensorDims[0], &elementSize, &desc_address)) {
return NULL;
}
uint64_t globalStrides[2] = {dims[0] * elementSize,
dims[0] * dims[1] * elementSize};
uint32_t elementStrides[3] = {1, 1, 1};
int rank = 3;

return fillTMADescriptior(global_address, dims, tensorDims, elementSize,
desc_address, globalStrides, elementStrides, rank);
}

static PyObject *fill4DTMADescriptor(PyObject *self, PyObject *args) {
unsigned long long global_address;
uint64_t dims[4];
uint32_t tensorDims[4];
int elementSize;
unsigned long long desc_address;
if (!PyArg_ParseTuple(args, "KKKKKiiiiiK", &global_address, &dims[3],
&dims[2], &dims[1], &dims[0], &tensorDims[3],
&tensorDims[2], &tensorDims[1], &tensorDims[0],
&elementSize, &desc_address)) {
return NULL;
}
// The bounding box inner dimension must be less than or equal to the swizzle
// size.
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
// We clamp the block size and the codegen will emit multiple copy operations.
if (contigDimSizeInByte > 128) {
tensorDims[0] = 128 / elementSize;
uint64_t globalStrides[3] = {dims[0] * elementSize,
dims[0] * dims[1] * elementSize,
dims[0] * dims[1] * dims[2] * elementSize};
uint32_t elementStrides[4] = {1, 1, 1, 1};
int rank = 4;

return fillTMADescriptior(global_address, dims, tensorDims, elementSize,
desc_address, globalStrides, elementStrides, rank);
}

static PyObject *fill5DTMADescriptor(PyObject *self, PyObject *args) {
unsigned long long global_address;
uint64_t dims[5];
uint32_t tensorDims[5];
int elementSize;
unsigned long long desc_address;
if (!PyArg_ParseTuple(args, "KKKKKKiiiiiiK", &global_address, &dims[4],
&dims[3], &dims[2], &dims[1], &dims[0], &tensorDims[4],
&tensorDims[3], &tensorDims[2], &tensorDims[1],
&tensorDims[0], &elementSize, &desc_address)) {
return NULL;
}
static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL;
INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled,
getCuTensorMapEncodeTiledHandle);
CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled(
(CUtensorMap *)desc_address, type, rank, (void *)global_address, dims,
globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
Py_INCREF(Py_None);
return Py_None;
uint64_t globalStrides[4] = {
dims[0] * elementSize, dims[0] * dims[1] * elementSize,
dims[0] * dims[1] * dims[2] * elementSize,
dims[0] * dims[1] * dims[2] * dims[3] * elementSize};
uint32_t elementStrides[5] = {1, 1, 1, 1, 1};
int rank = 5;

return fillTMADescriptior(global_address, dims, tensorDims, elementSize,
desc_address, globalStrides, elementStrides, rank);
}
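The rank-specific helpers above all hard-code the same stride pattern, sketched generically below in Python (the helper name is an illustrative assumption, not part of the PR). The dims arrays are ordered innermost-first, and cuTensorMapEncodeTiled takes only rank - 1 strides, in bytes, since the innermost dimension is assumed contiguous.

def global_strides(dims, element_size):
    # dims is innermost-first, matching the C arrays above.
    strides, acc = [], element_size
    for d in dims[:-1]:
        acc *= d
        strides.append(acc)
    return strides

# Matches the 4D case above:
# {dims[0]*es, dims[0]*dims[1]*es, dims[0]*dims[1]*dims[2]*es}
assert global_strides([64, 32, 8, 4], 1) == [64, 64 * 32, 64 * 32 * 8]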

static PyMethodDef ModuleMethods[] = {
@@ -400,6 +459,9 @@ static PyMethodDef ModuleMethods[] = {
"that calls printf()."},
{"fill_1d_tma_descriptor", fill1DTMADescriptor, METH_VARARGS, "doc"},
{"fill_2d_tma_descriptor", fill2DTMADescriptor, METH_VARARGS, "doc"},
{"fill_3d_tma_descriptor", fill3DTMADescriptor, METH_VARARGS, "doc"},
{"fill_4d_tma_descriptor", fill4DTMADescriptor, METH_VARARGS, "doc"},
{"fill_5d_tma_descriptor", fill5DTMADescriptor, METH_VARARGS, "doc"},

{NULL, NULL, 0, NULL} // sentinel
};
3 changes: 3 additions & 0 deletions third_party/nvidia/backend/driver.py
@@ -87,6 +87,9 @@ def __init__(self):
self.set_printf_fifo_size = mod.set_printf_fifo_size
self.fill_1d_tma_descriptor = mod.fill_1d_tma_descriptor
self.fill_2d_tma_descriptor = mod.fill_2d_tma_descriptor
self.fill_3d_tma_descriptor = mod.fill_3d_tma_descriptor
self.fill_4d_tma_descriptor = mod.fill_4d_tma_descriptor
self.fill_5d_tma_descriptor = mod.fill_5d_tma_descriptor


# ------------------------