diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
index b090670d955c..7788610eb371 100644
--- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -204,10 +204,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern {
     auto loc = op.getLoc();
     auto srcTy = op.getSrc().getType();
     auto dstTy = op.getResult().getType();
-    auto dstShape = dstTy.getShape();
-    auto srcSharedLayout = cast(srcTy.getEncoding());
-    assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
-           "Unexpected rank of ConvertLayout(shared->distributed)");
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
         loc, adaptor.getSrc(),
diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp
index cb9ae9dd0f3c..af77ed79266d 100644
--- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp
+++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp
@@ -36,7 +36,10 @@ class TMALoadLowering : public OpRewritePattern {
     auto ctaLayout = getCTALayout(tensorType.getEncoding());
     Attribute encoding = SharedEncodingAttr::get(tensorType.getContext(), 1, 1,
                                                  1, order, ctaLayout);
-    if (tensorType.getRank() > 1) {
+    if (tensorType.getRank() == 2) {
+      // The following SharedEncodingAttr constructor creates SMEM encoding with
+      // hasLeadingOffset = true, which is not currently supported for
+      // higher-rank TMA.
       encoding = SharedEncodingAttr::get(
           tensorType.getContext(), tensorType.getShape(), order, ctaLayout,
           tensorType.getElementType());
@@ -87,7 +90,10 @@ class TMAStoreLowering
     auto ctaLayout = getCTALayout(tensorType.getEncoding());
     Attribute encoding = SharedEncodingAttr::get(tensorType.getContext(), 1, 1,
                                                  1, order, ctaLayout);
-    if (tensorType.getRank() > 1) {
+    if (tensorType.getRank() == 2) {
+      // The following SharedEncodingAttr constructor creates SMEM encoding with
+      // hasLeadingOffset = true, which is not currently supported for
+      // higher-rank TMA.
       encoding = SharedEncodingAttr::get(
           tensorType.getContext(), tensorType.getShape(), order, ctaLayout,
           tensorType.getElementType());
diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py
index 23065953d65b..bf1e08fef466 100644
--- a/python/test/unit/hopper/test_experimental_tma.py
+++ b/python/test/unit/hopper/test_experimental_tma.py
@@ -3,7 +3,8 @@
 import triton
 import triton.language as tl
-from triton.tools.experimental_descriptor import (create_1d_tma_descriptor, create_2d_tma_descriptor)
+from triton.tools.experimental_descriptor import (create_1d_tma_descriptor, create_2d_tma_descriptor,
+                                                   TmaDescKernelParam)
 from triton._internal_testing import dtypes_with_bfloat16, numpy_random, to_triton, requires_tma, supports_tma, tma_skip_msg
 
 from typing import Optional
@@ -538,3 +539,28 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
     )
     torch.testing.assert_close(ref_out, A)
     assert "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned" in kernel.asm["ptx"]
+
+
+@requires_tma
+@pytest.mark.parametrize("inner_size", [16, 64])
+def test_experimental_descriptor_load_4d(inner_size):
+    device = "cuda"
+
+    @triton.jit
+    def kernel(Z, desc, inner_size: tl.constexpr):
+        off0 = tl.arange(0, 2)
+        off1 = tl.arange(0, 2)
+        off2 = tl.arange(0, 32)
+        off3 = tl.arange(0, inner_size)
+        x = tl._experimental_descriptor_load(desc, [2, 2, 0, 0], [2, 2, 32, inner_size], tl.dtype("uint8"))
+        out_ptrs = (Z + 2 * 32 * inner_size * off0[:, None, None, None] + 32 * inner_size * off1[None, :, None, None]
+                    + inner_size * off2[None, None, :, None] + off3[None, None, None, :])
+        tl.store(out_ptrs, x)
+
+    x = torch.randint(size=(4, 8, 32, inner_size), low=0, high=100, dtype=torch.uint8).to(device)
+    desc = TmaDescKernelParam(x.data_ptr(), [4, 8, 32, inner_size], [2, 2, 32, inner_size], 1)
+
+    z_tri = torch.zeros(size=(2, 2, 32, inner_size), dtype=torch.uint8, device=device)
+    kernel[(1, )](z_tri, desc, inner_size)
+
+    assert torch.equal(x[2:4, 2:4, :, :], z_tri)
diff --git a/python/triton/tools/experimental_descriptor.py b/python/triton/tools/experimental_descriptor.py
index 6077cab6f5fc..76e1a796ec9e 100644
--- a/python/triton/tools/experimental_descriptor.py
+++ b/python/triton/tools/experimental_descriptor.py
@@ -9,15 +9,21 @@ class TmaDescKernelParam:
     def __init__(self, ptr, dims, block_dims, element_size):
         self.desc = torch.empty(self.TMA_DESC_SIZE, dtype=torch.uint8, device="cpu")
         assert len(dims) == len(block_dims)
-        assert 1 <= len(dims) <= 2
+        assert 1 <= len(dims) <= 5
         assert self.desc.data_ptr() % 64 == 0
 
         if len(dims) == 1:
-            triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dims[0], block_dims[0], element_size,
-                                                                      self.desc.data_ptr())
+            fill_desc_func = triton.runtime.driver.active.utils.fill_1d_tma_descriptor
+        elif len(dims) == 2:
+            fill_desc_func = triton.runtime.driver.active.utils.fill_2d_tma_descriptor
+        elif len(dims) == 3:
+            fill_desc_func = triton.runtime.driver.active.utils.fill_3d_tma_descriptor
+        elif len(dims) == 4:
+            fill_desc_func = triton.runtime.driver.active.utils.fill_4d_tma_descriptor
         else:
-            triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0],
-                                                                      block_dims[1], element_size, self.desc.data_ptr())
+            fill_desc_func = triton.runtime.driver.active.utils.fill_5d_tma_descriptor
+
+        fill_desc_func(ptr, *dims, *block_dims, element_size, self.desc.data_ptr())
 
     # Return a CUtensorMap* pointer in host memory
     def tma_desc_cpu_ptr(self):
diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c
index 12deb0d1e7a3..49cd01e7b327 100644
--- a/third_party/nvidia/backend/driver.c
+++ b/third_party/nvidia/backend/driver.c
@@ -279,20 +279,12 @@ static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) {
 
 // Simple helper to experiment creating TMA descriptors on the host.
 // This is useful to test TMA operations independently.
-static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) {
-  unsigned long long global_address;
-  uint64_t dim;
-  uint32_t tensorDim;
-  int elementSize;
-  unsigned long long desc_address;
-  if (!PyArg_ParseTuple(args, "KKiiK", &global_address, &dim, &tensorDim,
-                        &elementSize, &desc_address)) {
-    return NULL;
-  }
-  uint64_t dims[1] = {dim};
-  uint64_t globalStrides[1] = {dim * elementSize};
-  uint32_t boxDim[1] = {tensorDim};
-  uint32_t elementStrides[1] = {1};
+static PyObject *fillTMADescriptor(unsigned long long global_address,
+                                   uint64_t *dims, uint32_t *tensorDims,
+                                   int elementSize,
+                                   unsigned long long desc_address,
+                                   uint64_t *globalStrides,
+                                   uint32_t *elementStrides, int rank) {
   CUtensorMapDataType type;
   switch (elementSize) {
   case 1:
@@ -306,24 +298,69 @@
     break;
   default:
     PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4");
     return NULL;
   }
-  assert((elementSize * tensorDim) >= 32 && "block size too small.");
-  int rank = 1;
+
+  // Swizzling should be picked in codegen but since we need to set it on the
+  // descriptor we rely on a convention between this function and codegen.
+  CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_NONE;
+  CUtensorMapL2promotion l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE;
+
+  if (rank == 2) {
+    // For now, we do not swizzle for higher ranks. Enabling swizzling in TMA
+    // implies hasLeadingOffset = true in SMEM encoding, which is currently not
+    // supported for higher rank TMA copies. This convention needs to be in sync
+    // with the TMA lowering pass in codegen.
+    l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B;
+    uint32_t contigDimSizeInByte = elementSize * tensorDims[0];
+
+    if (contigDimSizeInByte >= 128) {
+      swizzle = CU_TENSOR_MAP_SWIZZLE_128B;
+    } else if (contigDimSizeInByte >= 64) {
+      swizzle = CU_TENSOR_MAP_SWIZZLE_64B;
+    } else if (contigDimSizeInByte >= 32) {
+      swizzle = CU_TENSOR_MAP_SWIZZLE_32B;
+    }
+
+    // The bounding box inner dimension must be less than or equal to the
+    // swizzle size.
+    // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
+    // We clamp the block size and the codegen will emit multiple copy
+    // operations.
+    if (contigDimSizeInByte > 128) {
+      tensorDims[0] = 128 / elementSize;
+    }
+  }
+
   static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL;
   INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled,
                                       getCuTensorMapEncodeTiledHandle);
   CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled(
       (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims,
-      globalStrides, boxDim, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
-      CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE,
-      CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
+      globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
+      swizzle, l2Promotion, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
   Py_INCREF(Py_None);
   return Py_None;
 }
 
-// Simple helper to experiment creating TMA descriptors on the host.
-// This is a useful to test TMA operations independently.
+static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) {
+  unsigned long long global_address;
+  uint64_t dim;
+  uint32_t tensorDim;
+  int elementSize;
+  unsigned long long desc_address;
+  if (!PyArg_ParseTuple(args, "KKiiK", &global_address, &dim, &tensorDim,
+                        &elementSize, &desc_address)) {
+    return NULL;
+  }
+  uint64_t globalStrides[1] = {dim * elementSize};
+  uint32_t elementStrides[1] = {1};
+  int rank = 1;
+
+  return fillTMADescriptor(global_address, &dim, &tensorDim, elementSize,
+                           desc_address, globalStrides, elementStrides, rank);
+}
+
 static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) {
   unsigned long long global_address;
   uint64_t dims[2];
@@ -335,54 +372,77 @@
                         &desc_address)) {
     return NULL;
   }
-  uint64_t globalStrides[2] = {dims[0] * elementSize,
-                               dims[0] * dims[1] * elementSize};
+  uint64_t globalStrides[1] = {dims[0] * elementSize};
   uint32_t elementStrides[2] = {1, 1};
-  CUtensorMapDataType type;
-  switch (elementSize) {
-  case 1:
-    type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
-    break;
-  case 2:
-    type = CU_TENSOR_MAP_DATA_TYPE_UINT16;
-    break;
-  case 4:
-    type = CU_TENSOR_MAP_DATA_TYPE_UINT32;
-    break;
-  default:
-    PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4");
-  }
   int rank = 2;
-  // Swizzling should be picked in codegen but since we need to set it on the
-  // descriptor we rely on a convention between this function and codegen.
-  CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B;
-  uint32_t contigDimSizeInByte = elementSize * tensorDims[0];
-  if (contigDimSizeInByte >= 128) {
-    swizzle = CU_TENSOR_MAP_SWIZZLE_128B;
-  } else if (contigDimSizeInByte >= 64) {
-    swizzle = CU_TENSOR_MAP_SWIZZLE_64B;
-  } else if (contigDimSizeInByte >= 32) {
-    swizzle = CU_TENSOR_MAP_SWIZZLE_32B;
-  } else {
-    assert(false && "block size too small.");
+
+  return fillTMADescriptor(global_address, dims, tensorDims, elementSize,
+                           desc_address, globalStrides, elementStrides, rank);
+}
+
+static PyObject *fill3DTMADescriptor(PyObject *self, PyObject *args) {
+  unsigned long long global_address;
+  uint64_t dims[3];
+  uint32_t tensorDims[3];
+  int elementSize;
+  unsigned long long desc_address;
+  if (!PyArg_ParseTuple(args, "KKKKiiiiK", &global_address, &dims[2], &dims[1],
+                        &dims[0], &tensorDims[2], &tensorDims[1],
+                        &tensorDims[0], &elementSize, &desc_address)) {
+    return NULL;
+  }
+  uint64_t globalStrides[2] = {dims[0] * elementSize,
+                               dims[0] * dims[1] * elementSize};
+  uint32_t elementStrides[3] = {1, 1, 1};
+  int rank = 3;
+
+  return fillTMADescriptor(global_address, dims, tensorDims, elementSize,
+                           desc_address, globalStrides, elementStrides, rank);
+}
+
+static PyObject *fill4DTMADescriptor(PyObject *self, PyObject *args) {
+  unsigned long long global_address;
+  uint64_t dims[4];
+  uint32_t tensorDims[4];
+  int elementSize;
+  unsigned long long desc_address;
+  if (!PyArg_ParseTuple(args, "KKKKKiiiiiK", &global_address, &dims[3],
+                        &dims[2], &dims[1], &dims[0], &tensorDims[3],
+                        &tensorDims[2], &tensorDims[1], &tensorDims[0],
+                        &elementSize, &desc_address)) {
+    return NULL;
   }
-  // The bounding box inner dimension must be less than or equal to the swizzle
-  // size.
-  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
-  // We clamp the block size and the codegen will emit multiple copy operations.
-  if (contigDimSizeInByte > 128) {
-    tensorDims[0] = 128 / elementSize;
+  uint64_t globalStrides[3] = {dims[0] * elementSize,
+                               dims[0] * dims[1] * elementSize,
+                               dims[0] * dims[1] * dims[2] * elementSize};
+  uint32_t elementStrides[4] = {1, 1, 1, 1};
+  int rank = 4;
+
+  return fillTMADescriptor(global_address, dims, tensorDims, elementSize,
+                           desc_address, globalStrides, elementStrides, rank);
+}
+
+static PyObject *fill5DTMADescriptor(PyObject *self, PyObject *args) {
+  unsigned long long global_address;
+  uint64_t dims[5];
+  uint32_t tensorDims[5];
+  int elementSize;
+  unsigned long long desc_address;
+  if (!PyArg_ParseTuple(args, "KKKKKKiiiiiiK", &global_address, &dims[4],
+                        &dims[3], &dims[2], &dims[1], &dims[0], &tensorDims[4],
+                        &tensorDims[3], &tensorDims[2], &tensorDims[1],
+                        &tensorDims[0], &elementSize, &desc_address)) {
+    return NULL;
   }
-  static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL;
-  INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled,
-                                      getCuTensorMapEncodeTiledHandle);
-  CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled(
-      (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims,
-      globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
-      swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
-      CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
-  Py_INCREF(Py_None);
-  return Py_None;
+  uint64_t globalStrides[4] = {
+      dims[0] * elementSize, dims[0] * dims[1] * elementSize,
+      dims[0] * dims[1] * dims[2] * elementSize,
+      dims[0] * dims[1] * dims[2] * dims[3] * elementSize};
+  uint32_t elementStrides[5] = {1, 1, 1, 1, 1};
+  int rank = 5;
+
+  return fillTMADescriptor(global_address, dims, tensorDims, elementSize,
+                           desc_address, globalStrides, elementStrides, rank);
 }
 
 static PyMethodDef ModuleMethods[] = {
@@ -400,6 +460,9 @@
      "that calls printf()."},
    {"fill_1d_tma_descriptor", fill1DTMADescriptor, METH_VARARGS, "doc"},
    {"fill_2d_tma_descriptor", fill2DTMADescriptor, METH_VARARGS, "doc"},
+   {"fill_3d_tma_descriptor", fill3DTMADescriptor, METH_VARARGS, "doc"},
+   {"fill_4d_tma_descriptor", fill4DTMADescriptor, METH_VARARGS, "doc"},
+   {"fill_5d_tma_descriptor", fill5DTMADescriptor, METH_VARARGS, "doc"},
    {NULL, NULL, 0, NULL} // sentinel
 };
diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py
index 827ce61cbaf2..539f6cb8bc80 100644
--- a/third_party/nvidia/backend/driver.py
+++ b/third_party/nvidia/backend/driver.py
@@ -87,6 +87,9 @@ def __init__(self):
         self.set_printf_fifo_size = mod.set_printf_fifo_size
         self.fill_1d_tma_descriptor = mod.fill_1d_tma_descriptor
         self.fill_2d_tma_descriptor = mod.fill_2d_tma_descriptor
+        self.fill_3d_tma_descriptor = mod.fill_3d_tma_descriptor
+        self.fill_4d_tma_descriptor = mod.fill_4d_tma_descriptor
+        self.fill_5d_tma_descriptor = mod.fill_5d_tma_descriptor
 
 
 # ------------------------
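The sketch below is not part of the patch; it only illustrates how the extended host-side API is meant to be used for ranks above 2, mirroring the new 4D test but at rank 3. It assumes a Hopper GPU with TMA support and the experimental tl._experimental_descriptor_load builtin exercised in the diff; the kernel name, tensor shapes, and block sizes are made up for illustration.

# Minimal rank-3 usage sketch (not part of the patch).
import torch
import triton
import triton.language as tl
from triton.tools.experimental_descriptor import TmaDescKernelParam


@triton.jit
def copy_block_3d(out_ptr, desc, BLOCK0: tl.constexpr, BLOCK1: tl.constexpr, BLOCK2: tl.constexpr):
    # Load one BLOCK0 x BLOCK1 x BLOCK2 tile from the origin of the global
    # tensor described by `desc`, then store it to a dense output buffer.
    block = tl._experimental_descriptor_load(desc, [0, 0, 0], [BLOCK0, BLOCK1, BLOCK2], tl.dtype("uint8"))
    off0 = tl.arange(0, BLOCK0)
    off1 = tl.arange(0, BLOCK1)
    off2 = tl.arange(0, BLOCK2)
    out_ptrs = (out_ptr + BLOCK1 * BLOCK2 * off0[:, None, None] + BLOCK2 * off1[None, :, None] +
                off2[None, None, :])
    tl.store(out_ptrs, block)


x = torch.randint(0, 100, (8, 16, 32), dtype=torch.uint8, device="cuda")
# dims = global tensor shape, block_dims = box copied per TMA operation,
# element_size = 1 byte for uint8; the box's innermost extent is 32 bytes,
# which satisfies the minimum TMA box size.
desc = TmaDescKernelParam(x.data_ptr(), [8, 16, 32], [2, 4, 32], 1)
out = torch.zeros((2, 4, 32), dtype=torch.uint8, device="cuda")
copy_block_3d[(1, )](out, desc, 2, 4, 32)
assert torch.equal(out, x[:2, :4, :])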
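For reviewers, the descriptor-side convention introduced in fillTMADescriptor boils down to: swizzling and 128B L2 promotion are applied only at rank 2, and the innermost box dimension is clamped to the 128-byte swizzle span so codegen can emit multiple copies. The Python snippet below is an illustrative restatement of that logic, not code from the patch; the helper name is hypothetical.

# Illustrative restatement (not part of the patch) of the rank-2-only swizzle
# convention. Higher ranks fall back to no swizzle because the corresponding
# hasLeadingOffset shared-memory encoding is not supported there yet.
def pick_swizzle_and_box(rank, element_size, inner_box_dim):
    swizzle = "NONE"
    l2_promotion = "NONE"
    if rank == 2:
        l2_promotion = "L2_128B"
        contig_bytes = element_size * inner_box_dim
        if contig_bytes >= 128:
            swizzle = "128B"
        elif contig_bytes >= 64:
            swizzle = "64B"
        elif contig_bytes >= 32:
            swizzle = "32B"
        # The bounding box inner dimension may not exceed the swizzle span;
        # clamp it and let codegen issue multiple copy operations.
        if contig_bytes > 128:
            inner_box_dim = 128 // element_size
    return swizzle, l2_promotion, inner_box_dim


assert pick_swizzle_and_box(2, 2, 128) == ("128B", "L2_128B", 64)
assert pick_swizzle_and_box(4, 1, 64) == ("NONE", "NONE", 64)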
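The fill3D/4D/5D wrappers all build their globalStrides arrays the same way: dims is passed innermost-first, and cuTensorMapEncodeTiled takes rank minus one byte strides, each being the running product of the inner dimensions times the element size. A small sketch of that pattern follows (not part of the patch; the helper name is hypothetical).

# Sketch of the globalStrides computation used by the fill*D wrappers.
def global_strides_bytes(dims, element_size):
    strides = []
    stride = element_size
    for d in dims[:-1]:  # the outermost dimension needs no stride entry
        stride *= d
        strides.append(stride)
    return strides


# Matches the rank-4 case in fill4DTMADescriptor:
#   {dims[0] * es, dims[0] * dims[1] * es, dims[0] * dims[1] * dims[2] * es}
assert global_strides_bytes([32, 8, 4, 2], 2) == [64, 512, 2048]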