From 1640404d23bd411196a720e3ace3f3e4827f76b5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 18 Nov 2024 05:14:32 +0000 Subject: [PATCH 01/10] Add support for higher dim TMA --- .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 4 +- .../Transforms/TMALowering.cpp | 2 +- .../test/unit/hopper/test_experimental_tma.py | 93 ++++++++- .../triton/tools/experimental_descriptor.py | 16 +- third_party/nvidia/backend/driver.c | 187 +++++++++++------- third_party/nvidia/backend/driver.py | 3 + 6 files changed, 228 insertions(+), 77 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 4cea14f0957f..5d07f4a8924f 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -207,8 +207,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { auto dstShape = dstTy.getShape(); auto srcSharedLayout = cast(srcTy.getEncoding()); auto dstLayout = dstTy.getEncoding(); - assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) && - "Unexpected rank of ConvertLayout(shared->distributed)"); + // assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) && + // "Unexpected rank of ConvertLayout(shared->distributed)"); auto smemObj = LLVM::getSharedMemoryObjectFromStruct( loc, adaptor.getSrc(), diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp index cb9ae9dd0f3c..34b0eced4d94 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -36,7 +36,7 @@ class TMALoadLowering : public OpRewritePattern { auto ctaLayout = getCTALayout(tensorType.getEncoding()); Attribute encoding = SharedEncodingAttr::get(tensorType.getContext(), 1, 1, 1, order, ctaLayout); - if (tensorType.getRank() > 1) { + if (tensorType.getRank() == 2) { encoding = SharedEncodingAttr::get( tensorType.getContext(), tensorType.getShape(), order, ctaLayout, tensorType.getElementType()); diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index 7062093aef6d..4e7b15128f5b 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -1,9 +1,11 @@ +import tempfile import pytest import torch import triton import triton.language as tl -from triton.tools.experimental_descriptor import (create_1d_tma_descriptor, create_2d_tma_descriptor) +from triton.backends.compiler import GPUTarget +from triton.tools.experimental_descriptor import (create_1d_tma_descriptor, create_2d_tma_descriptor, TmaDescKernelParam) from triton._internal_testing import dtypes_with_bfloat16, numpy_random, to_triton, requires_tma from typing import Optional @@ -106,8 +108,8 @@ def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tm num_warps=8, num_stages=num_stages, dtype=tl.float16) ref_out = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(torch.float16) torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) - if BLOCK_M >= 64 and BLOCK_N >= 64: - assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"] + # if BLOCK_M >= 64 and BLOCK_N >= 64: + # assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"] if byval_tma: assert ".param .align 64 .b8" in kernel.asm["ptx"] @@ -534,3 +536,88 @@ def alloc_fn(size: int, align: int, stream: Optional[int]): ) torch.testing.assert_close(ref_out, A) assert 
"tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned" in kernel.asm["ptx"] + + +@requires_tma +def test_experimetal_descriptor_load_3d_no_jit(): + device = "cuda" + + ir = """ +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [2, 2, 1], order = [2, 0, 1]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 1, 0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.nv_tma_desc = 1 : i32}) attributes {noinline = false} { + %true = arith.constant true + %c2_i32 = arith.constant 2 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<32> : tensor<1x2x1xi32, #blocked> + %cst_0 = arith.constant dense<64> : tensor<2x1x1xi32, #blocked> + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<2x2x32xi8, #shared, #triton_gpu.shared_memory, mutable> + %1 = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared1, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.init_barrier %1, 1 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.barrier_expect %1, 128, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c2_i32, %c2_i32, %c0_i32] %0, %1, %true : , <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <2x2x32xi8, #shared,#triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.wait_barrier %1, %c0_i32 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.inval_barrier %1 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + %2 = triton_gpu.local_load %0 : !tt.memdesc<2x2x32xi8, #shared, #triton_gpu.shared_memory, mutable> -> tensor<2x2x32xi8, #blocked> + %3 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.slice<{dim = 2, parent = #blocked}>}>> + %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.slice<{dim = 2, parent = #blocked}>}>> -> tensor<2x1xi32, #triton_gpu.slice<{dim = 2, parent = #blocked}>> + %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<2x1xi32, #triton_gpu.slice<{dim = 2, parent = #blocked}>> -> tensor<2x1x1xi32, #blocked> + %6 = arith.muli %5, %cst_0 : tensor<2x1x1xi32, #blocked> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<2x1x1x!tt.ptr, #blocked> + %8 = tt.addptr %7, %6 : tensor<2x1x1x!tt.ptr, #blocked>, tensor<2x1x1xi32, #blocked> + %9 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked}>}>> + %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked}>}>> -> tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked}>> + %11 = tt.expand_dims %10 {axis = 2 : i32} : tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked}>> -> tensor<1x2x1xi32, #blocked> + %12 = arith.muli %11, %cst : tensor<1x2x1xi32, #blocked> + %13 = tt.broadcast %8 : tensor<2x1x1x!tt.ptr, #blocked> -> tensor<2x2x1x!tt.ptr, #blocked> + %14 = tt.broadcast %12 : tensor<1x2x1xi32, #blocked> -> tensor<2x2x1xi32, #blocked> + %15 = tt.addptr 
%13, %14 : tensor<2x2x1x!tt.ptr, #blocked>, tensor<2x2x1xi32, #blocked> + %16 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 1, parent = #blocked}>}>> + %17 = tt.expand_dims %16 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 1, parent = #blocked}>}>> -> tensor<1x32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %18 = tt.expand_dims %17 {axis = 1 : i32} : tensor<1x32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<1x1x32xi32, #blocked> + %19 = tt.broadcast %15 : tensor<2x2x1x!tt.ptr, #blocked> -> tensor<2x2x32x!tt.ptr, #blocked> + %20 = tt.broadcast %18 : tensor<1x1x32xi32, #blocked> -> tensor<2x2x32xi32, #blocked> + %21 = tt.addptr %19, %20 : tensor<2x2x32x!tt.ptr, #blocked>, tensor<2x2x32xi32, #blocked> + tt.store %21, %2 : tensor<2x2x32x!tt.ptr, #blocked> + tt.return + } +} + """ + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name, target=GPUTarget("cuda", 90, 32)) + + x = torch.randint(size=(4, 8, 32), low=0, high=100, dtype=torch.uint8).to(device) + desc = TmaDescKernelParam(x.data_ptr(), [4, 8, 32], [2, 2, 32], 1) + + z_tri = torch.zeros(size=(2, 2, 32), dtype=torch.uint8, device=device) + kernel[(1,1,1)](z_tri, desc) + + assert torch.equal(x[2:4, 2:4, :], z_tri) + + +@requires_tma +def test_experimetal_descriptor_load_4d(): + device = "cuda" + + @triton.jit + def kernel(Z, desc): + off0 = tl.arange(0, 2) + off1 = tl.arange(0, 2) + off2 = tl.arange(0, 32) + off3 = tl.arange(0, 16) + x = tl._experimental_descriptor_load(desc, [2, 2, 0, 0], [2, 2, 32, 16], tl.dtype("uint8")) + out_ptrs = Z + 2 * 32 * 16 * off0[:, None, None, None] + 32 * 16 * off1[None, :, None, None] + 16 * off2[None, None, :, None] + off3[None, None, None, :] + tl.store(out_ptrs, x) + + x = torch.randint(size=(4, 8, 32, 16), low=0, high=100, dtype=torch.uint8).to(device) + desc = TmaDescKernelParam(x.data_ptr(), [4, 8, 32, 16], [2, 2, 32, 16], 1) + + z_tri = torch.zeros(size=(2, 2, 32, 16), dtype=torch.uint8, device=device) + kernel[(1, )](z_tri, desc, num_warps=4) + + assert torch.equal(x[2:4, 2:4, :, :], z_tri) diff --git a/python/triton/tools/experimental_descriptor.py b/python/triton/tools/experimental_descriptor.py index 6077cab6f5fc..76e1a796ec9e 100644 --- a/python/triton/tools/experimental_descriptor.py +++ b/python/triton/tools/experimental_descriptor.py @@ -9,15 +9,21 @@ class TmaDescKernelParam: def __init__(self, ptr, dims, block_dims, element_size): self.desc = torch.empty(self.TMA_DESC_SIZE, dtype=torch.uint8, device="cpu") assert len(dims) == len(block_dims) - assert 1 <= len(dims) <= 2 + assert 1 <= len(dims) <= 5 assert self.desc.data_ptr() % 64 == 0 if len(dims) == 1: - triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dims[0], block_dims[0], element_size, - self.desc.data_ptr()) + fill_desc_func = triton.runtime.driver.active.utils.fill_1d_tma_descriptor + elif len(dims) == 2: + fill_desc_func = triton.runtime.driver.active.utils.fill_2d_tma_descriptor + elif len(dims) == 3: + fill_desc_func = triton.runtime.driver.active.utils.fill_3d_tma_descriptor + elif len(dims) == 4: + fill_desc_func = triton.runtime.driver.active.utils.fill_4d_tma_descriptor else: - triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0], - block_dims[1], element_size, self.desc.data_ptr()) + fill_desc_func = 
triton.runtime.driver.active.utils.fill_5d_tma_descriptor + + fill_desc_func(ptr, *dims, *block_dims, element_size, self.desc.data_ptr()) # Return a CUtensorMap* pointer in host memory def tma_desc_cpu_ptr(self): diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c index 12deb0d1e7a3..6785fa5a7768 100644 --- a/third_party/nvidia/backend/driver.c +++ b/third_party/nvidia/backend/driver.c @@ -279,20 +279,10 @@ static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) { // Simple helper to experiment creating TMA descriptors on the host. // This is a useful to test TMA operations independently. -static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { - unsigned long long global_address; - uint64_t dim; - uint32_t tensorDim; - int elementSize; - unsigned long long desc_address; - if (!PyArg_ParseTuple(args, "KKiiK", &global_address, &dim, &tensorDim, - &elementSize, &desc_address)) { - return NULL; - } - uint64_t dims[1] = {dim}; - uint64_t globalStrides[1] = {dim * elementSize}; - uint32_t boxDim[1] = {tensorDim}; - uint32_t elementStrides[1] = {1}; +static PyObject *fillTMADescriptior(unsigned long long global_address, uint64_t* dims, + uint32_t* tensorDims, int elementSize, + unsigned long long desc_address, uint64_t* globalStrides, + uint32_t* elementStrides, int rank) { CUtensorMapDataType type; switch (elementSize) { case 1: @@ -306,24 +296,64 @@ static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { break; default: PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4"); - return NULL; } - assert((elementSize * tensorDim) >= 32 && "block size too small."); - int rank = 1; + + // Swizzling should be picked in codegen but since we need to set it on the + // descriptor we rely on a convention between this function and codegen. + CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_NONE; + CUtensorMapL2promotion l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE; + + if (rank > 1) { + l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; + uint32_t contigDimSizeInByte = elementSize * tensorDims[0]; + + if (contigDimSizeInByte >= 128) { + swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + } else if (contigDimSizeInByte >= 64) { + swizzle = CU_TENSOR_MAP_SWIZZLE_64B; + } else if (contigDimSizeInByte >= 32) { + swizzle = CU_TENSOR_MAP_SWIZZLE_32B; + } + + // The bounding box inner dimension must be less than or equal to the + // swizzle size. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 + // We clamp the block size and the codegen will emit multiple copy + // operations. + if (contigDimSizeInByte > 128) { + tensorDims[0] = 128 / elementSize; + } + } + static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, getCuTensorMapEncodeTiledHandle); CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled( (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, - globalStrides, boxDim, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, - CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, - CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, l2Promotion, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); Py_INCREF(Py_None); return Py_None; } -// Simple helper to experiment creating TMA descriptors on the host. -// This is a useful to test TMA operations independently. 
+static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { + unsigned long long global_address; + uint64_t dim; + uint32_t tensorDim; + int elementSize; + unsigned long long desc_address; + if (!PyArg_ParseTuple(args, "KKiiK", &global_address, &dim, &tensorDim, + &elementSize, &desc_address)) { + return NULL; + } + uint64_t globalStrides[1] = {dim * elementSize}; + uint32_t elementStrides[1] = {1}; + int rank = 1; + + return fillTMADescriptior(global_address, &dim, &tensorDim, elementSize, desc_address, + globalStrides, elementStrides, rank); +} + static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { unsigned long long global_address; uint64_t dims[2]; @@ -335,54 +365,76 @@ static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { &desc_address)) { return NULL; } - uint64_t globalStrides[2] = {dims[0] * elementSize, - dims[0] * dims[1] * elementSize}; + uint64_t globalStrides[1] = {dims[0] * elementSize}; uint32_t elementStrides[2] = {1, 1}; - CUtensorMapDataType type; - switch (elementSize) { - case 1: - type = CU_TENSOR_MAP_DATA_TYPE_UINT8; - break; - case 2: - type = CU_TENSOR_MAP_DATA_TYPE_UINT16; - break; - case 4: - type = CU_TENSOR_MAP_DATA_TYPE_UINT32; - break; - default: - PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4"); - } int rank = 2; - // Swizzling should be picked in codegen but since we need to set it on the - // descriptor we rely on a convention between this function and codegen. - CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B; - uint32_t contigDimSizeInByte = elementSize * tensorDims[0]; - if (contigDimSizeInByte >= 128) { - swizzle = CU_TENSOR_MAP_SWIZZLE_128B; - } else if (contigDimSizeInByte >= 64) { - swizzle = CU_TENSOR_MAP_SWIZZLE_64B; - } else if (contigDimSizeInByte >= 32) { - swizzle = CU_TENSOR_MAP_SWIZZLE_32B; - } else { - assert(false && "block size too small."); + + return fillTMADescriptior(global_address, dims, tensorDims, elementSize, desc_address, + globalStrides, elementStrides, rank); +} + +static PyObject *fill3DTMADescriptor(PyObject *self, PyObject *args) { + unsigned long long global_address; + uint64_t dims[3]; + uint32_t tensorDims[3]; + int elementSize; + unsigned long long desc_address; + if (!PyArg_ParseTuple(args, "KKKKiiiiK", &global_address, &dims[2], &dims[1], &dims[0], + &tensorDims[2], &tensorDims[1], &tensorDims[0], &elementSize, + &desc_address)) { + return NULL; + } + uint64_t globalStrides[2] = {dims[0] * elementSize, + dims[0] * dims[1] * elementSize}; + uint32_t elementStrides[3] = {1, 1, 1}; + int rank = 3; + + return fillTMADescriptior(global_address, dims, tensorDims, elementSize, desc_address, + globalStrides, elementStrides, rank); +} + +static PyObject *fill4DTMADescriptor(PyObject *self, PyObject *args) { + unsigned long long global_address; + uint64_t dims[4]; + uint32_t tensorDims[4]; + int elementSize; + unsigned long long desc_address; + if (!PyArg_ParseTuple(args, "KKKKKiiiiiK", &global_address, &dims[3], &dims[2], &dims[1], &dims[0], + &tensorDims[3], &tensorDims[2], &tensorDims[1], &tensorDims[0], &elementSize, + &desc_address)) { + return NULL; } - // The bounding box inner dimension must be less than or equal to the swizzle - // size. - // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 - // We clamp the block size and the codegen will emit multiple copy operations. 
- if (contigDimSizeInByte > 128) { - tensorDims[0] = 128 / elementSize; + uint64_t globalStrides[3] = {dims[0] * elementSize, + dims[0] * dims[1] * elementSize, + dims[0] * dims[1] * dims[2] * elementSize}; + uint32_t elementStrides[4] = {1, 1, 1, 1}; + int rank = 4; + + return fillTMADescriptior(global_address, dims, tensorDims, elementSize, desc_address, + globalStrides, elementStrides, rank); +} + +static PyObject *fill5DTMADescriptor(PyObject *self, PyObject *args) { + unsigned long long global_address; + uint64_t dims[5]; + uint32_t tensorDims[5]; + int elementSize; + unsigned long long desc_address; + if (!PyArg_ParseTuple(args, "KKKKKKiiiiiiK", &global_address, + &dims[4], &dims[3], &dims[2], &dims[1], &dims[0], + &tensorDims[4], &tensorDims[3], &tensorDims[2], &tensorDims[1], &tensorDims[0], + &elementSize, &desc_address)) { + return NULL; } - static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; - INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, - getCuTensorMapEncodeTiledHandle); - CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled( - (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, - globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, - swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, - CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); - Py_INCREF(Py_None); - return Py_None; + uint64_t globalStrides[4] = {dims[0] * elementSize, + dims[0] * dims[1] * elementSize, + dims[0] * dims[1] * dims[2] * elementSize, + dims[0] * dims[1] * dims[2] * dims[3] * elementSize}; + uint32_t elementStrides[5] = {1, 1, 1, 1, 1}; + int rank = 5; + + return fillTMADescriptior(global_address, dims, tensorDims, elementSize, desc_address, + globalStrides, elementStrides, rank); } static PyMethodDef ModuleMethods[] = { @@ -400,6 +452,9 @@ static PyMethodDef ModuleMethods[] = { "that calls printf()."}, {"fill_1d_tma_descriptor", fill1DTMADescriptor, METH_VARARGS, "doc"}, {"fill_2d_tma_descriptor", fill2DTMADescriptor, METH_VARARGS, "doc"}, + {"fill_3d_tma_descriptor", fill3DTMADescriptor, METH_VARARGS, "doc"}, + {"fill_4d_tma_descriptor", fill4DTMADescriptor, METH_VARARGS, "doc"}, + {"fill_5d_tma_descriptor", fill5DTMADescriptor, METH_VARARGS, "doc"}, {NULL, NULL, 0, NULL} // sentinel }; diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index 827ce61cbaf2..539f6cb8bc80 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -87,6 +87,9 @@ def __init__(self): self.set_printf_fifo_size = mod.set_printf_fifo_size self.fill_1d_tma_descriptor = mod.fill_1d_tma_descriptor self.fill_2d_tma_descriptor = mod.fill_2d_tma_descriptor + self.fill_3d_tma_descriptor = mod.fill_3d_tma_descriptor + self.fill_4d_tma_descriptor = mod.fill_4d_tma_descriptor + self.fill_5d_tma_descriptor = mod.fill_5d_tma_descriptor # ------------------------ From 83328452130bc5e29b6ffd3320be2dee8b4adfe7 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 18 Nov 2024 05:20:56 +0000 Subject: [PATCH 02/10] fix --- python/test/unit/hopper/test_experimental_tma.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index 4e7b15128f5b..6dc4a9bea290 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -108,8 +108,8 @@ def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tm num_warps=8, num_stages=num_stages, 
dtype=tl.float16) ref_out = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(torch.float16) torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) - # if BLOCK_M >= 64 and BLOCK_N >= 64: - # assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"] + if BLOCK_M >= 64 and BLOCK_N >= 64: + assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"] if byval_tma: assert ".param .align 64 .b8" in kernel.asm["ptx"] From 517ea83b22e49463185718650ee20d514cf6ab1b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Nov 2024 09:32:26 +0000 Subject: [PATCH 03/10] disable swizzling and hasLeadingOffset for > 2D TMA --- lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp | 3 ++- third_party/nvidia/backend/driver.c | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp index 34b0eced4d94..b3e6ae988d98 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -41,6 +41,7 @@ class TMALoadLowering : public OpRewritePattern { tensorType.getContext(), tensorType.getShape(), order, ctaLayout, tensorType.getElementType()); } + MemDescType memDescType = MemDescType::get(tensorType.getShape(), tensorType.getElementType(), encoding, sharedMemorySpace, /*mutableMemory=*/true); @@ -87,7 +88,7 @@ class TMAStoreLowering auto ctaLayout = getCTALayout(tensorType.getEncoding()); Attribute encoding = SharedEncodingAttr::get(tensorType.getContext(), 1, 1, 1, order, ctaLayout); - if (tensorType.getRank() > 1) { + if (tensorType.getRank() == 2) { encoding = SharedEncodingAttr::get( tensorType.getContext(), tensorType.getShape(), order, ctaLayout, tensorType.getElementType()); diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c index 6785fa5a7768..3705aeb5bb7f 100644 --- a/third_party/nvidia/backend/driver.c +++ b/third_party/nvidia/backend/driver.c @@ -303,7 +303,7 @@ static PyObject *fillTMADescriptior(unsigned long long global_address, uint64_t* CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_NONE; CUtensorMapL2promotion l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE; - if (rank > 1) { + if (rank == 2) { l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; uint32_t contigDimSizeInByte = elementSize * tensorDims[0]; From 028000cc770d94aed3c6e44a014a3b0403853d27 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Nov 2024 20:02:55 +0000 Subject: [PATCH 04/10] parametrize test --- .../test/unit/hopper/test_experimental_tma.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index 6dc4a9bea290..3e62a31806e0 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -601,23 +601,25 @@ def test_experimetal_descriptor_load_3d_no_jit(): @requires_tma -def test_experimetal_descriptor_load_4d(): +@pytest.mark.parametrize("inner_size", [16, 64]) +def test_experimetal_descriptor_load_4d(inner_size): device = "cuda" @triton.jit - def kernel(Z, desc): + def kernel(Z, desc, inner_size: tl.constexpr): off0 = tl.arange(0, 2) off1 = tl.arange(0, 2) off2 = tl.arange(0, 32) - off3 = tl.arange(0, 16) - x = tl._experimental_descriptor_load(desc, [2, 2, 0, 0], [2, 2, 32, 16], tl.dtype("uint8")) - out_ptrs = Z + 2 * 32 * 16 * off0[:, None, None, None] + 32 * 16 
* off1[None, :, None, None] + 16 * off2[None, None, :, None] + off3[None, None, None, :] + off3 = tl.arange(0, inner_size) + x = tl._experimental_descriptor_load(desc, [2, 2, 0, 0], + [2, 2, 32, inner_size], tl.dtype("uint8")) + out_ptrs = Z + 2 * 32 * inner_size * off0[:, None, None, None] + 32 * inner_size * off1[None, :, None, None] + inner_size * off2[None, None, :, None] + off3[None, None, None, :] tl.store(out_ptrs, x) - x = torch.randint(size=(4, 8, 32, 16), low=0, high=100, dtype=torch.uint8).to(device) - desc = TmaDescKernelParam(x.data_ptr(), [4, 8, 32, 16], [2, 2, 32, 16], 1) + x = torch.randint(size=(4, 8, 32, inner_size), low=0, high=100, dtype=torch.uint8).to(device) + desc = TmaDescKernelParam(x.data_ptr(), [4, 8, 32, inner_size], [2, 2, 32, inner_size], 1) - z_tri = torch.zeros(size=(2, 2, 32, 16), dtype=torch.uint8, device=device) - kernel[(1, )](z_tri, desc, num_warps=4) + z_tri = torch.zeros(size=(2, 2, 32, inner_size), dtype=torch.uint8, device=device) + kernel[(1, )](z_tri, desc, inner_size) assert torch.equal(x[2:4, 2:4, :, :], z_tri) From 04f48c52c866bac60e7e951d10e44c82e696b81c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Nov 2024 20:30:15 +0000 Subject: [PATCH 05/10] add comment on disabling swizzling --- lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp | 5 ++++- third_party/nvidia/backend/driver.c | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp index b3e6ae988d98..594921cbaaea 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -37,11 +37,12 @@ class TMALoadLowering : public OpRewritePattern { Attribute encoding = SharedEncodingAttr::get(tensorType.getContext(), 1, 1, 1, order, ctaLayout); if (tensorType.getRank() == 2) { + // The following SharedEncodingAttr constructor creates SMEM encoding with + // hasLeadingOffset = true, which is not currently supported for higher-rank TMA. encoding = SharedEncodingAttr::get( tensorType.getContext(), tensorType.getShape(), order, ctaLayout, tensorType.getElementType()); } - MemDescType memDescType = MemDescType::get(tensorType.getShape(), tensorType.getElementType(), encoding, sharedMemorySpace, /*mutableMemory=*/true); @@ -89,6 +90,8 @@ class TMAStoreLowering Attribute encoding = SharedEncodingAttr::get(tensorType.getContext(), 1, 1, 1, order, ctaLayout); if (tensorType.getRank() == 2) { + // The following SharedEncodingAttr constructor creates SMEM encoding with + // hasLeadingOffset = true, which is not currently supported for higher-rank TMA. encoding = SharedEncodingAttr::get( tensorType.getContext(), tensorType.getShape(), order, ctaLayout, tensorType.getElementType()); diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c index 3705aeb5bb7f..8be08915c4a2 100644 --- a/third_party/nvidia/backend/driver.c +++ b/third_party/nvidia/backend/driver.c @@ -304,6 +304,9 @@ static PyObject *fillTMADescriptior(unsigned long long global_address, uint64_t* CUtensorMapL2promotion l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE; if (rank == 2) { + // For now, we do not swizzle for higher ranks. Enabling swizzling in TMA implies hasLeadingOffset = true + // in SMEM encoding, which is currently not supported for higher rank TMA copies. This convention needs to be + // in sync with the TMA lowering pass in codegen. 
l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; uint32_t contigDimSizeInByte = elementSize * tensorDims[0]; From d228adfa56e0b6fda0c81e219d898b0838ddc4ee Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Nov 2024 20:38:16 +0000 Subject: [PATCH 06/10] run ruff --- .../test/unit/hopper/test_experimental_tma.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index 3e62a31806e0..a220f25547de 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -586,7 +586,7 @@ def test_experimetal_descriptor_load_3d_no_jit(): } """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".ttgir") as f: f.write(ir) f.flush() kernel = triton.compile(f.name, target=GPUTarget("cuda", 90, 32)) @@ -595,7 +595,7 @@ def test_experimetal_descriptor_load_3d_no_jit(): desc = TmaDescKernelParam(x.data_ptr(), [4, 8, 32], [2, 2, 32], 1) z_tri = torch.zeros(size=(2, 2, 32), dtype=torch.uint8, device=device) - kernel[(1,1,1)](z_tri, desc) + kernel[(1, 1, 1)](z_tri, desc) assert torch.equal(x[2:4, 2:4, :], z_tri) @@ -611,15 +611,20 @@ def kernel(Z, desc, inner_size: tl.constexpr): off1 = tl.arange(0, 2) off2 = tl.arange(0, 32) off3 = tl.arange(0, inner_size) - x = tl._experimental_descriptor_load(desc, [2, 2, 0, 0], - [2, 2, 32, inner_size], tl.dtype("uint8")) - out_ptrs = Z + 2 * 32 * inner_size * off0[:, None, None, None] + 32 * inner_size * off1[None, :, None, None] + inner_size * off2[None, None, :, None] + off3[None, None, None, :] + x = tl._experimental_descriptor_load(desc, [2, 2, 0, 0], [2, 2, 32, inner_size], tl.dtype("uint8")) + out_ptrs = ( + Z + + 2 * 32 * inner_size * off0[:, None, None, None] + + 32 * inner_size * off1[None, :, None, None] + + inner_size * off2[None, None, :, None] + + off3[None, None, None, :] + ) tl.store(out_ptrs, x) x = torch.randint(size=(4, 8, 32, inner_size), low=0, high=100, dtype=torch.uint8).to(device) desc = TmaDescKernelParam(x.data_ptr(), [4, 8, 32, inner_size], [2, 2, 32, inner_size], 1) z_tri = torch.zeros(size=(2, 2, 32, inner_size), dtype=torch.uint8, device=device) - kernel[(1, )](z_tri, desc, inner_size) + kernel[(1,)](z_tri, desc, inner_size) assert torch.equal(x[2:4, 2:4, :, :], z_tri) From d5b5b1bfc856fd169ae9ebe9da0f4edefb1cab48 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Nov 2024 23:51:04 +0000 Subject: [PATCH 07/10] remove rank <= 2 assert --- lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 5d07f4a8924f..2682b5591a71 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -204,11 +204,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { auto loc = op.getLoc(); auto srcTy = op.getSrc().getType(); auto dstTy = op.getResult().getType(); - auto dstShape = dstTy.getShape(); - auto srcSharedLayout = cast(srcTy.getEncoding()); - auto dstLayout = dstTy.getEncoding(); - // assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) && - // "Unexpected rank of ConvertLayout(shared->distributed)"); auto smemObj = LLVM::getSharedMemoryObjectFromStruct( loc, adaptor.getSrc(), From 38d6cd470ee4c5b8bbc2847e0842d1f49ea40d48 Mon Sep 17 00:00:00 2001 
From: Masahiro Masuda Date: Thu, 21 Nov 2024 00:13:59 +0000 Subject: [PATCH 08/10] add comment on test --- python/test/unit/hopper/test_experimental_tma.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index a220f25547de..88949751e023 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -540,6 +540,9 @@ def alloc_fn(size: int, align: int, stream: Optional[int]): @requires_tma def test_experimetal_descriptor_load_3d_no_jit(): + """In addition to testing 3D TMA, we also test parsing of TTGIR when TMA descriptors are in arguments. + See https://github.com/triton-lang/triton/pull/4875 + """ device = "cuda" ir = """ From a9f4e30636b457318e4db44d7b00d640bd4ea2e7 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 21 Nov 2024 00:20:06 +0000 Subject: [PATCH 09/10] precommit --- .../Transforms/TMALowering.cpp | 6 +- .../test/unit/hopper/test_experimental_tma.py | 14 ++-- third_party/nvidia/backend/driver.c | 66 ++++++++++--------- 3 files changed, 44 insertions(+), 42 deletions(-) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp index 594921cbaaea..af77ed79266d 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -38,7 +38,8 @@ class TMALoadLowering : public OpRewritePattern { 1, order, ctaLayout); if (tensorType.getRank() == 2) { // The following SharedEncodingAttr constructor creates SMEM encoding with - // hasLeadingOffset = true, which is not currently supported for higher-rank TMA. + // hasLeadingOffset = true, which is not currently supported for + // higher-rank TMA. encoding = SharedEncodingAttr::get( tensorType.getContext(), tensorType.getShape(), order, ctaLayout, tensorType.getElementType()); @@ -91,7 +92,8 @@ class TMAStoreLowering 1, order, ctaLayout); if (tensorType.getRank() == 2) { // The following SharedEncodingAttr constructor creates SMEM encoding with - // hasLeadingOffset = true, which is not currently supported for higher-rank TMA. + // hasLeadingOffset = true, which is not currently supported for + // higher-rank TMA. 
encoding = SharedEncodingAttr::get( tensorType.getContext(), tensorType.getShape(), order, ctaLayout, tensorType.getElementType()); diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index 88949751e023..0f00461d7e35 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -5,7 +5,8 @@ import triton import triton.language as tl from triton.backends.compiler import GPUTarget -from triton.tools.experimental_descriptor import (create_1d_tma_descriptor, create_2d_tma_descriptor, TmaDescKernelParam) +from triton.tools.experimental_descriptor import (create_1d_tma_descriptor, create_2d_tma_descriptor, + TmaDescKernelParam) from triton._internal_testing import dtypes_with_bfloat16, numpy_random, to_triton, requires_tma from typing import Optional @@ -615,19 +616,14 @@ def kernel(Z, desc, inner_size: tl.constexpr): off2 = tl.arange(0, 32) off3 = tl.arange(0, inner_size) x = tl._experimental_descriptor_load(desc, [2, 2, 0, 0], [2, 2, 32, inner_size], tl.dtype("uint8")) - out_ptrs = ( - Z - + 2 * 32 * inner_size * off0[:, None, None, None] - + 32 * inner_size * off1[None, :, None, None] - + inner_size * off2[None, None, :, None] - + off3[None, None, None, :] - ) + out_ptrs = (Z + 2 * 32 * inner_size * off0[:, None, None, None] + 32 * inner_size * off1[None, :, None, None] + + inner_size * off2[None, None, :, None] + off3[None, None, None, :]) tl.store(out_ptrs, x) x = torch.randint(size=(4, 8, 32, inner_size), low=0, high=100, dtype=torch.uint8).to(device) desc = TmaDescKernelParam(x.data_ptr(), [4, 8, 32, inner_size], [2, 2, 32, inner_size], 1) z_tri = torch.zeros(size=(2, 2, 32, inner_size), dtype=torch.uint8, device=device) - kernel[(1,)](z_tri, desc, inner_size) + kernel[(1, )](z_tri, desc, inner_size) assert torch.equal(x[2:4, 2:4, :, :], z_tri) diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c index 8be08915c4a2..49cd01e7b327 100644 --- a/third_party/nvidia/backend/driver.c +++ b/third_party/nvidia/backend/driver.c @@ -279,10 +279,12 @@ static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) { // Simple helper to experiment creating TMA descriptors on the host. // This is a useful to test TMA operations independently. -static PyObject *fillTMADescriptior(unsigned long long global_address, uint64_t* dims, - uint32_t* tensorDims, int elementSize, - unsigned long long desc_address, uint64_t* globalStrides, - uint32_t* elementStrides, int rank) { +static PyObject *fillTMADescriptior(unsigned long long global_address, + uint64_t *dims, uint32_t *tensorDims, + int elementSize, + unsigned long long desc_address, + uint64_t *globalStrides, + uint32_t *elementStrides, int rank) { CUtensorMapDataType type; switch (elementSize) { case 1: @@ -304,9 +306,10 @@ static PyObject *fillTMADescriptior(unsigned long long global_address, uint64_t* CUtensorMapL2promotion l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE; if (rank == 2) { - // For now, we do not swizzle for higher ranks. Enabling swizzling in TMA implies hasLeadingOffset = true - // in SMEM encoding, which is currently not supported for higher rank TMA copies. This convention needs to be - // in sync with the TMA lowering pass in codegen. + // For now, we do not swizzle for higher ranks. Enabling swizzling in TMA + // implies hasLeadingOffset = true in SMEM encoding, which is currently not + // supported for higher rank TMA copies. 
This convention needs to be in sync + // with the TMA lowering pass in codegen. l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; uint32_t contigDimSizeInByte = elementSize * tensorDims[0]; @@ -353,8 +356,8 @@ static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { uint32_t elementStrides[1] = {1}; int rank = 1; - return fillTMADescriptior(global_address, &dim, &tensorDim, elementSize, desc_address, - globalStrides, elementStrides, rank); + return fillTMADescriptior(global_address, &dim, &tensorDim, elementSize, + desc_address, globalStrides, elementStrides, rank); } static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { @@ -372,8 +375,8 @@ static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { uint32_t elementStrides[2] = {1, 1}; int rank = 2; - return fillTMADescriptior(global_address, dims, tensorDims, elementSize, desc_address, - globalStrides, elementStrides, rank); + return fillTMADescriptior(global_address, dims, tensorDims, elementSize, + desc_address, globalStrides, elementStrides, rank); } static PyObject *fill3DTMADescriptor(PyObject *self, PyObject *args) { @@ -382,9 +385,9 @@ static PyObject *fill3DTMADescriptor(PyObject *self, PyObject *args) { uint32_t tensorDims[3]; int elementSize; unsigned long long desc_address; - if (!PyArg_ParseTuple(args, "KKKKiiiiK", &global_address, &dims[2], &dims[1], &dims[0], - &tensorDims[2], &tensorDims[1], &tensorDims[0], &elementSize, - &desc_address)) { + if (!PyArg_ParseTuple(args, "KKKKiiiiK", &global_address, &dims[2], &dims[1], + &dims[0], &tensorDims[2], &tensorDims[1], + &tensorDims[0], &elementSize, &desc_address)) { return NULL; } uint64_t globalStrides[2] = {dims[0] * elementSize, @@ -392,8 +395,8 @@ static PyObject *fill3DTMADescriptor(PyObject *self, PyObject *args) { uint32_t elementStrides[3] = {1, 1, 1}; int rank = 3; - return fillTMADescriptior(global_address, dims, tensorDims, elementSize, desc_address, - globalStrides, elementStrides, rank); + return fillTMADescriptior(global_address, dims, tensorDims, elementSize, + desc_address, globalStrides, elementStrides, rank); } static PyObject *fill4DTMADescriptor(PyObject *self, PyObject *args) { @@ -402,9 +405,10 @@ static PyObject *fill4DTMADescriptor(PyObject *self, PyObject *args) { uint32_t tensorDims[4]; int elementSize; unsigned long long desc_address; - if (!PyArg_ParseTuple(args, "KKKKKiiiiiK", &global_address, &dims[3], &dims[2], &dims[1], &dims[0], - &tensorDims[3], &tensorDims[2], &tensorDims[1], &tensorDims[0], &elementSize, - &desc_address)) { + if (!PyArg_ParseTuple(args, "KKKKKiiiiiK", &global_address, &dims[3], + &dims[2], &dims[1], &dims[0], &tensorDims[3], + &tensorDims[2], &tensorDims[1], &tensorDims[0], + &elementSize, &desc_address)) { return NULL; } uint64_t globalStrides[3] = {dims[0] * elementSize, @@ -413,8 +417,8 @@ static PyObject *fill4DTMADescriptor(PyObject *self, PyObject *args) { uint32_t elementStrides[4] = {1, 1, 1, 1}; int rank = 4; - return fillTMADescriptior(global_address, dims, tensorDims, elementSize, desc_address, - globalStrides, elementStrides, rank); + return fillTMADescriptior(global_address, dims, tensorDims, elementSize, + desc_address, globalStrides, elementStrides, rank); } static PyObject *fill5DTMADescriptor(PyObject *self, PyObject *args) { @@ -423,21 +427,21 @@ static PyObject *fill5DTMADescriptor(PyObject *self, PyObject *args) { uint32_t tensorDims[5]; int elementSize; unsigned long long desc_address; - if (!PyArg_ParseTuple(args, "KKKKKKiiiiiiK", &global_address, - &dims[4], 
&dims[3], &dims[2], &dims[1], &dims[0], - &tensorDims[4], &tensorDims[3], &tensorDims[2], &tensorDims[1], &tensorDims[0], - &elementSize, &desc_address)) { + if (!PyArg_ParseTuple(args, "KKKKKKiiiiiiK", &global_address, &dims[4], + &dims[3], &dims[2], &dims[1], &dims[0], &tensorDims[4], + &tensorDims[3], &tensorDims[2], &tensorDims[1], + &tensorDims[0], &elementSize, &desc_address)) { return NULL; } - uint64_t globalStrides[4] = {dims[0] * elementSize, - dims[0] * dims[1] * elementSize, - dims[0] * dims[1] * dims[2] * elementSize, - dims[0] * dims[1] * dims[2] * dims[3] * elementSize}; + uint64_t globalStrides[4] = { + dims[0] * elementSize, dims[0] * dims[1] * elementSize, + dims[0] * dims[1] * dims[2] * elementSize, + dims[0] * dims[1] * dims[2] * dims[3] * elementSize}; uint32_t elementStrides[5] = {1, 1, 1, 1, 1}; int rank = 5; - return fillTMADescriptior(global_address, dims, tensorDims, elementSize, desc_address, - globalStrides, elementStrides, rank); + return fillTMADescriptior(global_address, dims, tensorDims, elementSize, + desc_address, globalStrides, elementStrides, rank); } static PyMethodDef ModuleMethods[] = { From e19f43936c89fb5b2fce69dd9d506a62479d38e6 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 22 Nov 2024 04:18:58 +0900 Subject: [PATCH 10/10] remove ttgir test --- .../test/unit/hopper/test_experimental_tma.py | 65 ------------------- 1 file changed, 65 deletions(-) diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index 0f00461d7e35..9876e93f0b64 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -539,71 +539,6 @@ def alloc_fn(size: int, align: int, stream: Optional[int]): assert "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned" in kernel.asm["ptx"] -@requires_tma -def test_experimetal_descriptor_load_3d_no_jit(): - """In addition to testing 3D TMA, we also test parsing of TTGIR when TMA descriptors are in arguments. 
- See https://github.com/triton-lang/triton/pull/4875 - """ - device = "cuda" - - ir = """ -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [2, 2, 1], order = [2, 0, 1]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { - tt.func public @kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.nv_tma_desc = 1 : i32}) attributes {noinline = false} { - %true = arith.constant true - %c2_i32 = arith.constant 2 : i32 - %c0_i32 = arith.constant 0 : i32 - %cst = arith.constant dense<32> : tensor<1x2x1xi32, #blocked> - %cst_0 = arith.constant dense<64> : tensor<2x1x1xi32, #blocked> - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<2x2x32xi8, #shared, #triton_gpu.shared_memory, mutable> - %1 = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared1, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.init_barrier %1, 1 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.barrier_expect %1, 128, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c2_i32, %c2_i32, %c0_i32] %0, %1, %true : , <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <2x2x32xi8, #shared,#triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.wait_barrier %1, %c0_i32 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.inval_barrier %1 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> - %2 = triton_gpu.local_load %0 : !tt.memdesc<2x2x32xi8, #shared, #triton_gpu.shared_memory, mutable> -> tensor<2x2x32xi8, #blocked> - %3 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.slice<{dim = 2, parent = #blocked}>}>> - %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.slice<{dim = 2, parent = #blocked}>}>> -> tensor<2x1xi32, #triton_gpu.slice<{dim = 2, parent = #blocked}>> - %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<2x1xi32, #triton_gpu.slice<{dim = 2, parent = #blocked}>> -> tensor<2x1x1xi32, #blocked> - %6 = arith.muli %5, %cst_0 : tensor<2x1x1xi32, #blocked> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<2x1x1x!tt.ptr, #blocked> - %8 = tt.addptr %7, %6 : tensor<2x1x1x!tt.ptr, #blocked>, tensor<2x1x1xi32, #blocked> - %9 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked}>}>> - %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked}>}>> -> tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked}>> - %11 = tt.expand_dims %10 {axis = 2 : i32} : tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked}>> -> tensor<1x2x1xi32, #blocked> - %12 = arith.muli %11, %cst : tensor<1x2x1xi32, #blocked> - %13 = tt.broadcast %8 : tensor<2x1x1x!tt.ptr, #blocked> -> tensor<2x2x1x!tt.ptr, #blocked> - %14 = tt.broadcast %12 : tensor<1x2x1xi32, #blocked> -> tensor<2x2x1xi32, #blocked> - %15 = tt.addptr %13, %14 : tensor<2x2x1x!tt.ptr, #blocked>, tensor<2x2x1xi32, #blocked> - %16 = tt.make_range {end = 32 : i32, start = 0 
: i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 1, parent = #blocked}>}>> - %17 = tt.expand_dims %16 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 1, parent = #blocked}>}>> -> tensor<1x32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %18 = tt.expand_dims %17 {axis = 1 : i32} : tensor<1x32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<1x1x32xi32, #blocked> - %19 = tt.broadcast %15 : tensor<2x2x1x!tt.ptr, #blocked> -> tensor<2x2x32x!tt.ptr, #blocked> - %20 = tt.broadcast %18 : tensor<1x1x32xi32, #blocked> -> tensor<2x2x32xi32, #blocked> - %21 = tt.addptr %19, %20 : tensor<2x2x32x!tt.ptr, #blocked>, tensor<2x2x32xi32, #blocked> - tt.store %21, %2 : tensor<2x2x32x!tt.ptr, #blocked> - tt.return - } -} - """ - - with tempfile.NamedTemporaryFile(mode="w", suffix=".ttgir") as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name, target=GPUTarget("cuda", 90, 32)) - - x = torch.randint(size=(4, 8, 32), low=0, high=100, dtype=torch.uint8).to(device) - desc = TmaDescKernelParam(x.data_ptr(), [4, 8, 32], [2, 2, 32], 1) - - z_tri = torch.zeros(size=(2, 2, 32), dtype=torch.uint8, device=device) - kernel[(1, 1, 1)](z_tri, desc) - - assert torch.equal(x[2:4, 2:4, :], z_tri) - - @requires_tma @pytest.mark.parametrize("inner_size", [16, 64]) def test_experimetal_descriptor_load_4d(inner_size):
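Taken together, these patches extend TmaDescKernelParam and the NVIDIA driver utilities to descriptors of rank 3 through 5, and the remaining 4D test above shows the intended use from a @triton.jit kernel. Below is a minimal host-side sketch of the same flow for the 3D case, modeled on that test. It assumes a Hopper GPU (the requires_tma configuration) and that the 3D JIT path behaves like the 4D one exercised here; the kernel and variable names are illustrative and not part of this series.

import torch
import triton
import triton.language as tl
from triton.tools.experimental_descriptor import TmaDescKernelParam


@triton.jit
def copy_3d_block(Z, desc):
    off0 = tl.arange(0, 2)
    off1 = tl.arange(0, 2)
    off2 = tl.arange(0, 32)
    # Load a [2, 2, 32] block starting at offset [2, 2, 0] through the TMA descriptor.
    x = tl._experimental_descriptor_load(desc, [2, 2, 0], [2, 2, 32], tl.dtype("uint8"))
    # Scatter the block into a contiguous (2, 2, 32) uint8 output buffer.
    out_ptrs = Z + 2 * 32 * off0[:, None, None] + 32 * off1[None, :, None] + off2[None, None, :]
    tl.store(out_ptrs, x)


x = torch.randint(size=(4, 8, 32), low=0, high=100, dtype=torch.uint8, device="cuda")
# Rank-3 dims/block_dims dispatch to the new fill_3d_tma_descriptor driver utility;
# the last argument is the element size in bytes (1 for uint8).
desc = TmaDescKernelParam(x.data_ptr(), [4, 8, 32], [2, 2, 32], 1)
z = torch.zeros((2, 2, 32), dtype=torch.uint8, device="cuda")
copy_3d_block[(1, )](z, desc)
assert torch.equal(x[2:4, 2:4, :], z)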