diff --git a/cmake/AddTritonUnitTest.cmake b/cmake/AddTritonUnitTest.cmake
index 121687d6af12..a9efb9ad1ad8 100644
--- a/cmake/AddTritonUnitTest.cmake
+++ b/cmake/AddTritonUnitTest.cmake
@@ -35,7 +35,7 @@ function(add_triton_ut)
   # Without the TEST_DISCOVERY_TIMEOUT, the tests randomly time out on my mac
   # laptop. I think the issue may be that the very first time you run a program
   # it's a bit slow.
-  gtest_discover_tests(${__NAME} PROPERTIES TEST_DISCOVERY_TIMEOUT 60)
+  gtest_discover_tests(${__NAME} DISCOVERY_TIMEOUT 60)
 
   # Add the unit test to the top-level unit test target.
   add_dependencies(TritonUnitTests ${__NAME})
diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h
index 3a488e65ed03..9d0d6684da88 100644
--- a/include/triton/Analysis/Allocation.h
+++ b/include/triton/Analysis/Allocation.h
@@ -18,6 +18,12 @@ namespace mlir {
 namespace triton {
 class AllocationAnalysis;
 
+/// Callback to allow backends to specify target-specific scratch sizes for
+/// some operations.
+using AllocationAnalysisScratchSizeFn = std::function<unsigned(Operation *)>;
+
+unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op);
+
 // To convert a tensor from one layout to another, we need to allocate a
 // temporary buffer (i.e., scratch buffer) in shared memory. The conversion may
 // require multiple iterations, with each iteration involving multiple
@@ -102,7 +108,8 @@ class Allocation {
   explicit Allocation(Operation *operation) : operation(operation) {}
 
   /// Runs allocation analysis on the given top-level operation.
-  void run(FuncAllocMapT &funcAllocMap);
+  void run(FuncAllocMapT &funcAllocMap,
+           triton::AllocationAnalysisScratchSizeFn scratchSizeGetter);
 
   /// Returns the operation this analysis was constructed from.
   Operation *getOperation() const { return operation; }
@@ -250,7 +257,9 @@ class ModuleAllocation : public CallGraph<Allocation> {
 public:
   using FuncOffsetMapT = DenseMap<FunctionOpInterface, Value>;
 
-  explicit ModuleAllocation(ModuleOp moduleOp)
+  ModuleAllocation(ModuleOp moduleOp,
+                   triton::AllocationAnalysisScratchSizeFn scratchSizeGetter =
+                       triton::defaultAllocationAnalysisScratchSizeFn)
       : CallGraph<Allocation>(moduleOp) {
     walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
         // Pre-order edge walk callback
@@ -259,7 +268,7 @@ class ModuleAllocation : public CallGraph<Allocation> {
         [&](FunctionOpInterface funcOp) {
           auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp);
           if (inserted)
-            iter->second.run(funcMap);
+            iter->second.run(funcMap, scratchSizeGetter);
         });
   }
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 76f806f0aee1..74c64e65c3c4 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -118,13 +118,70 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
   return scratchConfig;
 }
 
+unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
+  if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
+    ReduceOpHelper helper(reduceOp);
+    return helper.getScratchSizeInBytes();
+  }
+  if (auto scanOp = dyn_cast<triton::ScanOp>(op)) {
+    ScanLoweringHelper helper(scanOp);
+    return helper.getScratchSizeInBytes();
+  }
+  if (auto histogram = dyn_cast<triton::HistogramOp>(op)) {
+    auto dstTy = histogram.getType();
+    int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(
+        op->getParentOfType<ModuleOp>());
+    return std::max<int>(dstTy.getNumElements(), threadsPerWarp) *
+           std::max<int>(8, dstTy.getElementTypeBitWidth()) / 8;
+  }
+  if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
+    auto srcTy = cvtLayout.getSrc().getType();
+    auto dstTy = cvtLayout.getType();
+    auto srcEncoding = srcTy.getEncoding();
+    auto dstEncoding = dstTy.getEncoding();
+    if (mlir::isa<triton::gpu::SharedEncodingAttr>(srcEncoding) ||
+        mlir::isa<triton::gpu::SharedEncodingAttr>(dstEncoding)) {
+      // Conversions from/to shared memory do not need scratch memory.
+      return 0;
+    }
+    // ConvertLayoutOp with both input/output non-shared_layout
+    // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
+    // also possible to realize it with other approaches in restricted
+    // conditions, such as warp-shuffle
+    auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
+    auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
+    return isa<triton::PointerType>(srcTy.getElementType())
+               ? elems * kPtrBitWidth / 8
+               : elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
+  }
+  if (isa<triton::AtomicRMWOp, triton::AtomicCASOp>(op)) {
+    auto value = op->getOperand(0);
+    // only scalar requires scratch memory
+    // make it explicit for readability
+    if (dyn_cast<RankedTensorType>(value.getType())) {
+      return 0;
+    }
+    auto smemShape = getRepShapeForAtomic(op->getResult(0));
+    auto elems = getNumScratchElements(smemShape);
+    auto elemTy = cast<triton::PointerType>(value.getType()).getPointeeType();
+    assert(!isa<triton::PointerType>(elemTy) && "unexpected pointer type");
+    return elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
+  }
+  if (auto createTensormap = dyn_cast<ExperimentalTensormapCreateOp>(op)) {
+    constexpr int32_t kTMASize = 128;
+    return kTMASize;
+  }
+  return 0;
+}
+
 class AllocationAnalysis {
 public:
   AllocationAnalysis(Operation *operation,
                      Allocation::FuncAllocMapT *funcAllocMap,
-                     Allocation *allocation)
+                     Allocation *allocation,
+                     AllocationAnalysisScratchSizeFn scratchSizeGetter)
       : operation(operation), funcAllocMap(funcAllocMap),
-        allocation(allocation) {
+        allocation(allocation), scratchSizeGetter(scratchSizeGetter) {
     run();
   }
 
@@ -177,77 +234,19 @@ class AllocationAnalysis {
 
   /// Initializes temporary shared memory for a given operation.
   void getScratchValueSize(Operation *op) {
-    const size_t scratchAlignment = 128;
-    if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
-      ReduceOpHelper helper(reduceOp);
-      unsigned bytes = helper.getScratchSizeInBytes();
-      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
-                                                          scratchAlignment);
-    } else if (auto scanOp = dyn_cast<triton::ScanOp>(op)) {
-      ScanLoweringHelper helper(scanOp);
-      unsigned bytes = helper.getScratchSizeInBytes();
-      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
-                                                          scratchAlignment);
-    } else if (auto histogram = dyn_cast<triton::HistogramOp>(op)) {
-      auto dstTy = histogram.getType();
-      int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(
-          op->getParentOfType<ModuleOp>());
-      auto bytes = std::max<int>(dstTy.getNumElements(), threadsPerWarp) *
-                   std::max<int>(8, dstTy.getElementTypeBitWidth()) / 8;
-      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
-                                                          scratchAlignment);
-    } else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
-      auto srcTy = cvtLayout.getSrc().getType();
-      auto dstTy = cvtLayout.getType();
-      auto srcEncoding = srcTy.getEncoding();
-      auto dstEncoding = dstTy.getEncoding();
-      if (mlir::isa<triton::gpu::SharedEncodingAttr>(srcEncoding) ||
-          mlir::isa<triton::gpu::SharedEncodingAttr>(dstEncoding)) {
-        // Conversions from/to shared memory do not need scratch memory.
-        return;
-      }
-      // ConvertLayoutOp with both input/output non-shared_layout
-      // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
-      // also possible to realize it with other approaches in restricted
-      // conditions, such as warp-shuffle
-      auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
-      auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
-      auto bytes =
-          isa<triton::PointerType>(srcTy.getElementType())
-              ? elems * kPtrBitWidth / 8
-              : elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
-      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
-                                                          scratchAlignment);
-    } else if (isa<triton::AtomicRMWOp, triton::AtomicCASOp>(op)) {
-      auto value = op->getOperand(0);
-      // only scalar requires scratch memory
-      // make it explicit for readability
-      if (dyn_cast<RankedTensorType>(value.getType())) {
-        // nothing to do
-      } else {
-        auto smemShape = getRepShapeForAtomic(op->getResult(0));
-        auto elems = getNumScratchElements(smemShape);
-        auto elemTy = cast<triton::PointerType>(value.getType()).getPointeeType();
-        assert(!isa<triton::PointerType>(elemTy) && "unexpected pointer type");
-        auto bytes =
-            elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
-        maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
-                                                            scratchAlignment);
-      }
-    } else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
+    constexpr size_t scratchAlignment = 128;
+    if (auto callOp = dyn_cast<CallOpInterface>(op)) {
       auto callable = callOp.resolveCallable();
       auto funcOp = dyn_cast<FunctionOpInterface>(callable);
       auto *funcAlloc = &(*funcAllocMap)[funcOp];
       auto bytes = funcAlloc->getSharedMemorySize();
       maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes,
                                                           scratchAlignment);
-    } else if (auto createTensormap =
-                   dyn_cast<ExperimentalTensormapCreateOp>(op)) {
-      constexpr int32_t kTMASize = 128;
-      constexpr int32_t kTMAAlign = 128;
-      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, kTMASize,
-                                                          kTMAAlign);
+      return;
     }
+    unsigned bytes = scratchSizeGetter(op);
+    maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
+                                                        scratchAlignment);
   }
 
   void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) {
@@ -547,12 +546,16 @@ class AllocationAnalysis {
   Allocation::FuncAllocMapT *funcAllocMap;
   Allocation *allocation;
   BufferRangeMapT bufferRange;
+  AllocationAnalysisScratchSizeFn scratchSizeGetter;
 };
 
 } // namespace triton
 
-void Allocation::run(FuncAllocMapT &funcAllocMap) {
-  triton::AllocationAnalysis(getOperation(), &funcAllocMap, this);
+void Allocation::run(
+    FuncAllocMapT &funcAllocMap,
+    triton::AllocationAnalysisScratchSizeFn scratchSizeGetter) {
+  triton::AllocationAnalysis(getOperation(), &funcAllocMap, this,
+                             scratchSizeGetter);
 }
 
 std::map<Operation *, SmallVector<Allocation::BufferId>>
diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
index d29c7c2fdd81..4cea14f0957f 100644
--- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -138,12 +138,17 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
 
   // FIXME [Dot LL]
   // Do for all DotOperandEncodingAttr once we have LLs for all of them
-  static bool isSupportedDotOpLayout(Attribute layout) {
+  static bool isSupportedDotOpLayout(RankedTensorType type) {
+    auto layout = type.getEncoding();
+    auto bitwidth = type.getElementType().getIntOrFloatBitWidth();
     if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
+      auto kWidth = dot.getKWidth();
       // Use when the SharedToDotOperandMMAv2OrV3 is known to be buggy:
       // - kWidth == 8
+      // - kWidth == 4, bitwidth = 32
       if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
-        bool legacyLoweringIsBuggy = dot.getKWidth() >= 8;
+        bool legacyLoweringIsBuggy =
+            kWidth >= 8 || (kWidth == 4 && bitwidth == 32);
         return legacyLoweringIsBuggy && mma.isAmpere();
       }
       if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
@@ -162,7 +167,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     if (isa<SharedEncodingAttr>(srcLayout) &&
        (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
             dstLayout) ||
-        isSupportedDotOpLayout(dstLayout))) {
+        isSupportedDotOpLayout(dstTy))) {
       return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
                                       rewriter);
     }
@@ -202,7 +207,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto dstShape = dstTy.getShape();
     auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
     auto dstLayout = dstTy.getEncoding();
-    assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstLayout)) &&
+    assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
            "Unexpected rank of ConvertLayout(shared->distributed)");
 
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
diff --git a/python/test/regression/test_cast_matmul.py b/python/test/regression/test_cast_matmul.py
index 67c216b4bc08..7fd986f9ea5e 100644
--- a/python/test/regression/test_cast_matmul.py
+++ b/python/test/regression/test_cast_matmul.py
@@ -13,6 +13,11 @@
 import triton.language as tl
 
 input_dtypes = ["float16", "float32", "float64"]
+if triton.runtime.driver.active.get_current_target().backend == "cuda":
+    input_dtypes += ["int8", "float8_e5m2"]
+    cc = torch.cuda.get_device_capability(0)
+    if cc >= (8, 9):
+        input_dtypes += ["float8_e4m3fn"]
 out_dtypes = ["float16", "float32"]
 
 
@@ -63,28 +68,40 @@ def matmul_kernel(A, B, C, M, N, K,  #
     tl.store(C, acc, mask=mask)
 
 
-@pytest.mark.parametrize("M, K, N, w_dtype, x_dtype, out_dtype",
-                         [(M, K, N, w, x, o)  #
-                          for (M, K, N) in [(128, 128, 128), (1280, 768, 1024)]  #
+@pytest.mark.parametrize("M, K, N, BLOCK_K, w_dtype, x_dtype, out_dtype",
+                         [(M, K, N, BLOCK_K, w, x, o)  #
+                          for BLOCK_K in [16, 32]  #
+                          for (M, K, N) in [(128, 128, 128), (768, 768, 1024)]  #
                           for w in input_dtypes
                           for x in input_dtypes  #
                           for o in out_dtypes])
-def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype):
+def test_cast_matmul(M, K, N, BLOCK_K, w_dtype, x_dtype, out_dtype):
     if x_dtype == w_dtype:
         pytest.skip("skip the same input dtype")
     device = torch.cuda.current_device()
-    x_dtype = getattr(torch, x_dtype)
-    w_dtype = getattr(torch, w_dtype)
-    a = torch.randn((M, K), device=device, dtype=x_dtype)
-    b = torch.randn((K, N), device=device, dtype=w_dtype)
+    x_dtype: torch.dtype = getattr(torch, x_dtype)
+    w_dtype: torch.dtype = getattr(torch, w_dtype)
+
+    def init_tensor(dtype, shape):
+        if dtype == torch.int8:
+            return torch.randint(0, 2, shape, device=device, dtype=dtype)
+        elif dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
+            return torch.randn(shape, device=device, dtype=torch.float16).to(dtype)
+        else:
+            return torch.randn(shape, device=device, dtype=dtype)
+
+    torch.manual_seed(42)
+    a = init_tensor(w_dtype, (M, K))
+    b = init_tensor(x_dtype, (K, N))
+
     torch_dtype = getattr(torch, out_dtype)
     triton_dtype = getattr(tl, out_dtype)  # <- here force dot_out_dtype
     out_torch = torch.matmul(a.to(torch_dtype), b.to(torch_dtype))
     out_triton = torch.empty((M, N), device=device, dtype=torch_dtype)
 
     # launch kernel
-    BLOCK_M, BLOCK_N, BLOCK_K = 16, 16, 32
-    grid = ((triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)), 1)
+    block_m, block_n, block_k = 16, 16, BLOCK_K
+    grid = ((triton.cdiv(M, block_m) * triton.cdiv(N, block_n)), 1)
 
     matmul_kernel[grid](
         a, b, out_triton, M, N, K,  #
@@ -92,8 +109,8 @@ def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype):
         b.stride(0), b.stride(1),  #
         out_triton.stride(0), out_triton.stride(1), dot_out_dtype=triton_dtype,  #
         GROUP_M=8,  #
-        BLOCK_M=BLOCK_M,  #
-        BLOCK_N=BLOCK_N,  #
-        BLOCK_K=BLOCK_K)
+        BLOCK_M=block_m,  #
+        BLOCK_N=block_n,  #
+        BLOCK_K=block_k)
 
     torch.testing.assert_close(out_torch, out_triton, atol=0.3, rtol=0.01)
diff --git a/python/test/unit/tools/test_irsource.py b/python/test/unit/tools/test_irsource.py
index a886ebb457f4..fc0a413c0663 100644
--- a/python/test/unit/tools/test_irsource.py
+++ b/python/test/unit/tools/test_irsource.py
@@ -1,4 +1,4 @@
-import tempfile
+import pathlib
 import triton
 from triton.compiler import IRSource
 from triton._C.libtriton import ir
@@ -6,7 +6,7 @@
 target = triton.runtime.driver.active.get_current_target()
 
 
-def test_mlir_attribute_parsing() -> None:
+def test_mlir_attribute_parsing(tmp_path: pathlib.Path) -> None:
     '''
     Tests that MLIR attributes are parsed correctly from input ttir/ttgir.
 
@@ -37,21 +37,20 @@ def test_mlir_attribute_parsing() -> None:
     }
     }
     """
-    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
-        f.write(sample_ttgir)
-        f.flush()
-        context = ir.context()
-        src = IRSource(f.name, context)
+    temp_file = tmp_path / "test_mlir_attribute_parsing0.ttgir"
+    temp_file.write_text(sample_ttgir)
+    context = ir.context()
+    src = IRSource(str(temp_file), context)
 
-        # check name and type signature
-        # should match ty_to_cpp(...)
-        assert src.signature == \
-            {0: "*f32", 1: "*f32", 2: "*f32", 3: "i32", \
-             4: "i32", 5: "i32", 6: "i32", 7: "i32", 8: "nvTmaDesc", 9: "nvTmaDesc"}
-        assert src.name == "@matmul_kernel"
+    # check name and type signature
+    # should match ty_to_cpp(...)
+    assert src.signature == \
+        {0: "*f32", 1: "*f32", 2: "*f32", 3: "i32", \
+         4: "i32", 5: "i32", 6: "i32", 7: "i32", 8: "nvTmaDesc", 9: "nvTmaDesc"}
+    assert src.name == "@matmul_kernel"
 
-        # check num warps
-        assert src.parse_options()['num_warps'] == 8
+    # check num warps
+    assert src.parse_options()['num_warps'] == 8
 
     sample_ttgir_vector_add = r"""
   #blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
@@ -83,11 +82,10 @@ def test_mlir_attribute_parsing() -> None:
     }
     }
     """
-    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
-        f.write(sample_ttgir_vector_add)
-        f.flush()
-        context = ir.context()
-        src = IRSource(f.name, context)
+    temp_file = tmp_path / "test_mlir_attribute_parsing1.ttgir"
+    temp_file.write_text(sample_ttgir_vector_add)
+    context = ir.context()
+    src = IRSource(str(temp_file), context)
 
-        # now test compilation
-        triton.compile(f.name, target=target)
+    # now test compilation
+    triton.compile(str(temp_file), target=target)
diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py
index b967f136a966..573d9d41913d 100644
--- a/python/triton/runtime/autotuner.py
+++ b/python/triton/runtime/autotuner.py
@@ -131,6 +131,10 @@ def _post_hook(kwargs, exception):
     def _bench(self, *args, config, **meta):
         from ..compiler.errors import CompileTimeAssertionFailure
 
+        verbose = os.environ.get("TRITON_PRINT_AUTOTUNING", None) == "1"
+        if verbose:
+            print(f"Autotuning kernel {self.base_fn.__name__} with config {config}")
+
         # check for conflicts, i.e. meta-parameters both provided
         # as kwargs and by the autotuner
         conflicts = meta.keys() & config.kwargs.keys()
@@ -161,7 +165,9 @@ def kernel_call():
 
         try:
             return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
-        except (OutOfResources, CompileTimeAssertionFailure, PTXASError):
+        except (OutOfResources, CompileTimeAssertionFailure, PTXASError) as e:
+            if verbose:
+                print(f"Autotuning failed with {e}")
             return [float("inf"), float("inf"), float("inf")]
 
     def run(self, *args, **kwargs):
diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir
index 208f6b80bfe5..a12e3a026071 100644
--- a/test/Analysis/test-allocation.mlir
+++ b/test/Analysis/test-allocation.mlir
@@ -1,4 +1,11 @@
 // RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-allocation 2>&1 | FileCheck %s
+// RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-allocation="get-scratch-size-function=ValidConstant" 2>&1 | FileCheck %s --check-prefix=CHECK-128
+
+// Check there are no lines with a size different to 128 and we have at least a line with size 128.
+
+// CHECK-128-NOT: scratch offset = {{.*}}, size = {{^(128)}}
+// CHECK-128: scratch offset = {{.*}}, size = 128
+// CHECK-128-NOT: scratch offset = {{.*}}, size = {{^(128)}}
 
 #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}>
diff --git a/test/lib/Analysis/TestAllocation.cpp b/test/lib/Analysis/TestAllocation.cpp
index 772e0258bf78..e7245e75cbbf 100644
--- a/test/lib/Analysis/TestAllocation.cpp
+++ b/test/lib/Analysis/TestAllocation.cpp
@@ -5,21 +5,42 @@ using namespace mlir;
 
 namespace {
 
+unsigned getScratchSize128(Operation *) { return 128; }
+
+enum class GetScratchSizeFunction {
+  None,
+  ValidConstant,
+};
+
 struct TestAllocationPass
     : public PassWrapper<TestAllocationPass, OperationPass<ModuleOp>> {
 
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);
 
+  TestAllocationPass() = default;
+  TestAllocationPass(const TestAllocationPass &other)
+      : PassWrapper<TestAllocationPass, OperationPass<ModuleOp>>(other) {}
+
   StringRef getArgument() const final { return "test-print-allocation"; }
 
   StringRef getDescription() const final {
     return "print the result of the allocation pass";
   }
 
+  ModuleAllocation getModuleAllocation() {
+    switch (getScratchSizeFunction) {
+    case GetScratchSizeFunction::None:
+      return {getOperation()};
+    case GetScratchSizeFunction::ValidConstant:
+      return {getOperation(), getScratchSize128};
+    }
+    llvm_unreachable("Unhandled case");
+  }
+
   void runOnOperation() override {
     auto &os = llvm::errs();
     ModuleOp moduleOp = getOperation();
     // Convert to std::string can remove quotes from opName
-    ModuleAllocation moduleAllocation(moduleOp);
+    ModuleAllocation moduleAllocation = getModuleAllocation();
     moduleOp.walk([&](triton::FuncOp funcOp) {
       auto opName = SymbolTable::getSymbolName(funcOp).getValue().str();
       os << opName << "\n";
@@ -48,6 +69,15 @@ struct TestAllocationPass
       os << "size = " << allocation->getSharedMemorySize() << "\n";
     });
   }
+
+  Option<GetScratchSizeFunction> getScratchSizeFunction{
+      *this, "get-scratch-size-function",
+      llvm::cl::desc("Custom scratch size function to use"),
+      llvm::cl::init(GetScratchSizeFunction::None),
+      llvm::cl::values(
+          clEnumValN(GetScratchSizeFunction::None, "None", "None (default)"),
+          clEnumValN(GetScratchSizeFunction::ValidConstant, "ValidConstant",
+                     "ValidConstant"))};
 };
 
 } // namespace
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp
index 4c99a44dff15..6f2b99dfa13c 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp
@@ -226,11 +226,6 @@ SmallVector<Value> MMA16816SmemLoader::computeLdsMatOffs(Value lane,
                                                          Value cSwizzleOffset) {
   Value warpB = multiDimWarpId[0];
   Value warpOff = kOrder == 2 ? multiDimWarpId[1] : multiDimWarpId[2];
-  int cTileShape = tileShape[order[0]];
-  int sTileShape = tileShape[order[1]];
-  if (!needTrans) {
-    std::swap(cTileShape, sTileShape);
-  }
 
   SmallVector<Value> offs(numPtrs);
 
@@ -239,7 +234,6 @@ SmallVector<Value> MMA16816SmemLoader::computeLdsMatOffs(Value lane,
   int laneHeight = 8;
   int quadWidth = laneWidth * kWidth;
   int quadHeight = laneHeight;
-  int numQuadI = 2;
 
   // outer index base
   Value iBase = udiv(lane, i32_val(laneWidth));
@@ -544,12 +538,15 @@ Value composeValuesToDotOperandLayoutStruct(
   // unpacked into individual elements.
   // `kIters` specifies the number of contiguous int32 elements each thread
   // should load.
-  auto kIters = isHopper ? 1 : kWidth / (32 / bitwidth);
+  // `kSize` specifies the total number of int32 elements each thread should
+  // load.
+  int kIters = isHopper ? 1 : kWidth / (32 / bitwidth);
+  int kSize = repK >= kIters ? repK * 2 : kIters;
   std::vector<Value> elems;
   auto unpackVec = [&](int b, int m, int k) {
-    for (auto kIter = 0; kIter < kIters; ++kIter) {
-      auto val = vals.at({b, m, k + kIter});
+    for (int kIter = 0; kIter < kIters; ++kIter) {
+      auto val = vals.at({b, m, (k + kIter) % kSize});
       auto vec = bitcast(val, vecTy);
       for (auto i = 0; i < numElemsPerVec; ++i) {
         elems.push_back(extract_element(eltTy, vec, i32_val(i)));
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
index 7b7ca7d1e238..622ff40873aa 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
@@ -90,6 +90,7 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(
   // we split the MMA into 4 sub-MMAs, each with a stride 4 x 32-bit along the
   // K dimension.
   llvm::SmallVector<size_t> si;
+  auto kIters = kWidth / (32 / bitwidth);
 
   if (dot.getOpIdx() == 0) {
     // Original register layout:
@@ -106,11 +107,63 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(
     // 2nd MMA: [[2, 3], [10, 11], [18, 19], [26, 27]]
     // 3rd MMA: [[4, 5], [12, 13], [20, 21], [28, 29]]
    // 4th MMA: [[6, 7], [14, 15], [22, 23], [30, 31]]
-    for (size_t kRep = 0; kRep < kWidth / numElemsPerVec; ++kRep)
-      for (size_t tile = 0; tile < 4; ++tile)
-        for (size_t e = 0; e < numElemsPerVec; ++e) {
-          si.push_back(kRep * numElemsPerVec + tile * kWidth + e);
-        }
+    if (kIters <= repK) {
+      for (size_t kRep = 0; kRep < kWidth / numElemsPerVec; ++kRep)
+        for (size_t tile = 0; tile < 4; ++tile)
+          for (size_t e = 0; e < numElemsPerVec; ++e) {
+            si.push_back(kRep * numElemsPerVec + tile * kWidth + e);
+          }
+    } else {
+      // Suppose kWidth=4 and type=fp32, so numElemsPerVec=1.
+      // Each tile of the dot operand layout has a size of 16x32.
+      // However, if the triton tensor size is 16x16, elements along the k
+      // dimension are duplicated. Within each tile, each register
+      // contains 2x8 elements arranged as follows:
+      //
+      //   tile0/0           tile0/1
+      //   |<--kWidth=4-->|  |<--kWidth-->|
+      //   |<-mmaWidth=2->|
+      //   [0, 1, 2, 3]      [0, 1, 2, 3]
+      //   [4, 5, 6, 7]      [4, 5, 6, 7]
+      //
+      // tile0/1 replicates the elements in tile0/0 along the k dimension.
+      // For a tensor size of 32x32, the next tile on the m dimension is as
+      // follows:
+      //
+      //   tile1/0             tile1/1
+      //   |<--kWidth-->|      |<--kWidth-->|
+      //   [8, 9, 10, 11],     [8, 9, 10, 11]
+      //   [12, 13, 14, 15],   [12, 13, 14, 15]
+      //
+      // Within a single tile, we can perform two MMAs, and the
+      // resulting register layout for each MMA is as follows:
+      //
+      //   1st MMA: [0, 4, 1, 5]
+      //   2nd MMA: [2, 6, 3, 7]
+      //   3rd MMA: [8, 12, 9, 13]
+      //   4th MMA: [10, 14, 11, 15]
+      //
+      // Additionally, we should reorder the elements by moving the duplicated
+      // elements to the end. In the example above, we convert the order from
+      // tile0/0, tile0/1, tile1/0, tile1/1 to tile0/0, tile1/0, tile0/1,
+      // tile1/1, so that only the first two tiles will be used in the
+      // computation.
+      size_t elemsPerTile = 2 * 2 * kWidth;
+      size_t elemsPerMma = 2 * 2 * numElemsPerVec;
+      size_t mmaWidth = kWidth / numElemsPerVec / 2;
+      size_t repMma = elemsPerTile / (mmaWidth * elemsPerMma);
+      for (size_t rep = 0; rep < repMma; ++rep)
+        for (size_t tile = 0; tile < elems.size() / elemsPerTile; ++tile)
+          for (size_t mmaKWidth = 0; mmaKWidth < mmaWidth; ++mmaKWidth)
+            for (size_t kTile = 0; kTile < 2; ++kTile)
+              for (size_t mTile = 0; mTile < 2; ++mTile)
+                for (size_t e = 0; e < numElemsPerVec; ++e) {
+                  si.push_back(rep * mmaWidth * elemsPerMma +
+                               mmaKWidth * 2 * numElemsPerVec +
+                               tile * elemsPerTile + mTile * kWidth +
+                               kTile * numElemsPerVec + e);
+                }
+    }
   } else {
     // Original register layout:
     //
@@ -122,11 +175,36 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(
     // 2nd MMA: [[2, 3], [10, 11]]
     // 3rd MMA: [[4, 5], [12, 13]]
     // 4th MMA: [[6, 7], [14, 15]]
-    for (size_t kRep = 0; kRep < kWidth / numElemsPerVec; ++kRep)
-      for (size_t tile = 0; tile < 2; ++tile)
-        for (size_t e = 0; e < numElemsPerVec; ++e) {
-          si.push_back(kRep * numElemsPerVec + tile * kWidth + e);
-        }
+    if (kIters <= repK) {
+      for (size_t kRep = 0; kRep < kWidth / numElemsPerVec; ++kRep)
+        for (size_t tile = 0; tile < 2; ++tile)
+          for (size_t e = 0; e < numElemsPerVec; ++e) {
+            si.push_back(kRep * numElemsPerVec + tile * kWidth + e);
+          }
+    } else {
+      // Suppose kWidth=4 and type=fp32.
+      // Original register layout:
+      //
+      //   tile0/0           tile0/1
+      //   [0, 1, 2, 3]^T,   [0, 1, 2, 3]^T
+      //
+      // Similar to the opIdx=0 situation, we should reorder the elements by
+      // moving the duplicated elements to the end.
+      size_t elemsPerTile = 2 * kWidth;
+      size_t elemsPerMma = 2 * numElemsPerVec;
+      size_t mmaWidth = kWidth / numElemsPerVec / 2;
+      size_t repMma = elemsPerTile / (mmaWidth * elemsPerMma);
+      for (size_t rep = 0; rep < repMma; ++rep)
+        for (size_t tile = 0; tile < elems.size() / elemsPerTile; ++tile)
+          for (size_t mmaKWidth = 0; mmaKWidth < mmaWidth; ++mmaKWidth)
+            for (size_t kTile = 0; kTile < 2; ++kTile)
+              for (size_t e = 0; e < numElemsPerVec; ++e) {
+                si.push_back(rep * mmaWidth * elemsPerMma +
+                             mmaKWidth * 2 * numElemsPerVec +
+                             tile * elemsPerTile + kTile * numElemsPerVec +
+                             e);
+              }
+    }
   }
 
   auto step = si.size();
diff --git a/third_party/proton/csrc/Proton.cpp b/third_party/proton/csrc/Proton.cpp
index bd27bf0842f3..aded5a96b139 100644
--- a/third_party/proton/csrc/Proton.cpp
+++ b/third_party/proton/csrc/Proton.cpp
@@ -26,10 +26,16 @@ void initProton(pybind11::module &&m) {
     SessionManager::instance().activateSession(sessionId);
   });
 
+  m.def("activate_all",
+        []() { SessionManager::instance().activateAllSessions(); });
+
   m.def("deactivate", [](size_t sessionId) {
     SessionManager::instance().deactivateSession(sessionId);
   });
 
+  m.def("deactivate_all",
+        []() { SessionManager::instance().deactivateAllSessions(); });
+
   m.def("finalize", [](size_t sessionId, const std::string &outputFormat) {
     auto outputFormatEnum = parseOutputFormat(outputFormat);
     SessionManager::instance().finalizeSession(sessionId, outputFormatEnum);
diff --git a/third_party/proton/csrc/include/Profiler/Profiler.h b/third_party/proton/csrc/include/Profiler/Profiler.h
index e87e8ccef88f..ed14fc1b685e 100644
--- a/third_party/proton/csrc/include/Profiler/Profiler.h
+++ b/third_party/proton/csrc/include/Profiler/Profiler.h
@@ -27,10 +27,9 @@ class Profiler {
   /// If the profiler is already started, this function does nothing.
   Profiler *start() {
     std::unique_lock<std::shared_mutex> lock(mutex);
-    if (this->isInitialized)
-      return this;
-    this->doStart();
-    this->isInitialized = true;
+    if (this->initializedCount == 0)
+      this->doStart();
+    this->initializedCount++;
     return this;
   }
 
@@ -45,10 +44,11 @@ class Profiler {
   /// Stop the profiler.
   Profiler *stop() {
     std::unique_lock<std::shared_mutex> lock(mutex);
-    if (!this->isInitialized)
+    if (this->initializedCount == 0)
       return this;
-    this->doStop();
-    this->isInitialized = false;
+    this->initializedCount--;
+    if (this->initializedCount == 0)
+      this->doStop();
     return this;
   }
 
@@ -80,7 +80,9 @@ class Profiler {
   mutable std::shared_mutex mutex;
 
   std::set<Data *> dataSet;
-  bool isInitialized{false};
+
+private:
+  int initializedCount{};
 };
 
 } // namespace proton
diff --git a/third_party/proton/csrc/include/Session/Session.h b/third_party/proton/csrc/include/Session/Session.h
index 7e3f9eb54684..28696de30c29 100644
--- a/third_party/proton/csrc/include/Session/Session.h
+++ b/third_party/proton/csrc/include/Session/Session.h
@@ -77,8 +77,12 @@ class SessionManager : public Singleton<SessionManager> {
 
   void activateSession(size_t sessionId);
 
+  void activateAllSessions();
+
   void deactivateSession(size_t sessionId);
 
+  void deactivateAllSessions();
+
   void enterScope(const Scope &scope);
 
   void exitScope(const Scope &scope);
diff --git a/third_party/proton/csrc/lib/Session/Session.cpp b/third_party/proton/csrc/lib/Session/Session.cpp
index ee149acaf533..082f2b6af689 100644
--- a/third_party/proton/csrc/lib/Session/Session.cpp
+++ b/third_party/proton/csrc/lib/Session/Session.cpp
@@ -84,11 +84,25 @@ void SessionManager::activateSession(size_t sessionId) {
   activateSessionImpl(sessionId);
 }
 
+void SessionManager::activateAllSessions() {
+  std::unique_lock<std::shared_mutex> lock(mutex);
+  for (auto iter : sessionActive) {
+    activateSessionImpl(iter.first);
+  }
+}
+
 void SessionManager::deactivateSession(size_t sessionId) {
   std::unique_lock<std::shared_mutex> lock(mutex);
   deActivateSessionImpl(sessionId);
 }
 
+void SessionManager::deactivateAllSessions() {
+  std::unique_lock<std::shared_mutex> lock(mutex);
+  for (auto iter : sessionActive) {
+    deActivateSessionImpl(iter.first);
+  }
+}
+
 void SessionManager::activateSessionImpl(size_t sessionId) {
   throwIfSessionNotInitialized(sessions, sessionId);
   if (sessionActive[sessionId])
@@ -116,6 +130,7 @@ void SessionManager::removeSession(size_t sessionId) {
   }
   auto path = sessions[sessionId]->path;
   sessionPaths.erase(path);
+  sessionActive.erase(sessionId);
   sessions.erase(sessionId);
 }
diff --git a/third_party/proton/proton/profile.py b/third_party/proton/proton/profile.py
index 808a1742a5b8..575c85b0cac8 100644
--- a/third_party/proton/proton/profile.py
+++ b/third_party/proton/proton/profile.py
@@ -85,36 +85,42 @@ def start(
     return libproton.start(name, context, data, backend)
 
 
-def activate(session: Optional[int] = 0) -> None:
+def activate(session: Optional[int] = None) -> None:
     """
     Activate the specified session.
     The profiling session will be active and data will be recorded.
 
     Args:
-        session (int): The session ID of the profiling session. Defaults to 0 (the first session started.)
+        session (int): The session ID of the profiling session. Defaults to None (all sessions)
 
     Returns:
        None
    """
    if is_command_line() and session != 0:
        raise ValueError("Only one session can be activated when running from the command line.")
-    libproton.activate(session)
+    if session is None:
+        libproton.activate_all()
+    else:
+        libproton.activate(session)
 
 
-def deactivate(session: Optional[int] = 0) -> None:
+def deactivate(session: Optional[int] = None) -> None:
     """
     Stop the specified session.
     The profiling session's data will still be in the memory, but no more data will be recorded.
 
     Args:
-        session (int): The session ID of the profiling session. Defaults to 0 (the first session started.)
+        session (int): The session ID of the profiling session. Defaults to None (all sessions)
 
     Returns:
         None
     """
     if is_command_line() and session != 0:
         raise ValueError("Only one session can be deactivated when running from the command line.")
-    libproton.deactivate(session)
+    if session is None:
+        libproton.deactivate_all()
+    else:
+        libproton.deactivate(session)
 
 
 def finalize(session: Optional[int] = None, output_format: str = "hatchet") -> None:
diff --git a/third_party/proton/test/test_api.py b/third_party/proton/test/test_api.py
index dd26ecbbfc7b..d4013ce2b675 100644
--- a/third_party/proton/test/test_api.py
+++ b/third_party/proton/test/test_api.py
@@ -3,7 +3,7 @@
 import pathlib
 
 
-def test_profile(tmp_path: pathlib.Path):
+def test_profile_single_session(tmp_path: pathlib.Path):
     temp_file0 = tmp_path / "test_profile0.hatchet"
     session_id0 = proton.start(str(temp_file0.with_suffix("")))
     proton.activate()
@@ -29,6 +29,28 @@ def test_profile(tmp_path: pathlib.Path):
     pathlib.Path("test.hatchet").unlink()
 
 
+def test_profile_multiple_sessions(tmp_path: pathlib.Path):
+    temp_file0 = tmp_path / "test_profile0.hatchet"
+    proton.start(str(temp_file0.with_suffix("")))
+    temp_file1 = tmp_path / "test_profile1.hatchet"
+    proton.start(str(temp_file1.with_suffix("")))
+    proton.activate()
+    proton.deactivate()
+    proton.finalize()
+    assert temp_file0.exists()
+    assert temp_file1.exists()
+
+    temp_file2 = tmp_path / "test_profile2.hatchet"
+    session_id2 = proton.start(str(temp_file2.with_suffix("")))
+    temp_file3 = tmp_path / "test_profile3.hatchet"
+    session_id3 = proton.start(str(temp_file3.with_suffix("")))
+    proton.deactivate(session_id2)
+    proton.deactivate(session_id3)
+    proton.finalize()
+    assert temp_file2.exists()
+    assert temp_file3.exists()
+
+
 def test_profile_decorator(tmp_path: pathlib.Path):
     temp_file = tmp_path / "test_profile_decorator.hatchet"
diff --git a/third_party/proton/test/test_profile.py b/third_party/proton/test/test_profile.py
index 01bcaf3be0df..0d2204fbfa7b 100644
--- a/third_party/proton/test/test_profile.py
+++ b/third_party/proton/test/test_profile.py
@@ -257,3 +257,23 @@ def test_deactivate(tmp_path: pathlib.Path):
     assert "device_id" not in data[0]["metrics"]
     assert len(data[0]["children"]) == 1
     assert "device_id" in data[0]["children"][0]["metrics"]
+
+
+def test_multiple_sessions(tmp_path: pathlib.Path):
+    temp_file0 = tmp_path / "test_multiple_sessions0.hatchet"
+    temp_file1 = tmp_path / "test_multiple_sessions1.hatchet"
+    session_id0 = proton.start(str(temp_file0.with_suffix("")))
+    session_id1 = proton.start(str(temp_file1.with_suffix("")))
+    torch.randn((10, 10), device="cuda")
+    torch.randn((10, 10), device="cuda")
+    proton.deactivate(session_id0)
+    proton.finalize(session_id0)
+    torch.randn((10, 10), device="cuda")
+    proton.finalize(session_id1)
+    # kernel has been invoked twice in session 0 and three times in session 1
+    with temp_file0.open() as f:
+        data = json.load(f)
+    assert int(data[0]["children"][0]["metrics"]["count"]) == 2
+    with temp_file1.open() as f:
+        data = json.load(f)
+    assert int(data[0]["children"][0]["metrics"]["count"]) == 3
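
The profile.py change above makes proton.activate() and proton.deactivate() apply to every open session when called without an argument, backed by the new activate_all/deactivate_all bindings in Proton.cpp. A minimal usage sketch follows; it assumes the proton Python package is importable as triton.profiler, and the output names and tensor sizes are illustrative only.

import torch
import triton.profiler as proton  # assumed import path, not shown in the patch

session0 = proton.start("profile_a")       # sessions are active once started
session1 = proton.start("profile_b")

proton.deactivate(session0)                # session 0 stops recording, session 1 keeps going
torch.randn((1024, 1024), device="cuda")   # recorded only by session 1
proton.activate()                          # no argument: re-activate *all* sessions
torch.randn((1024, 1024), device="cuda")   # recorded by both sessions

proton.finalize()                          # no argument: finalize all sessions
                                           # -> profile_a.hatchet and profile_b.hatchet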
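
The autotuner.py change adds an opt-in trace controlled by the TRITON_PRINT_AUTOTUNING environment variable: each config is printed before it is benchmarked, and a config that raises OutOfResources, CompileTimeAssertionFailure, or PTXASError prints the exception (it is still scored as inf). A small sketch of turning this on; the kernel and configs below are made up for illustration, only the environment variable comes from the patch.

import os
os.environ["TRITON_PRINT_AUTOTUNING"] = "1"  # read inside Autotuner._bench, so setting it at runtime works

import torch
import triton
import triton.language as tl


@triton.autotune(configs=[triton.Config({"BLOCK": 256}, num_warps=4),
                          triton.Config({"BLOCK": 1024}, num_warps=8)],
                 key=["n_elements"])
@triton.jit
def scale_kernel(x_ptr, n_elements, BLOCK: tl.constexpr):
    # doubles the input in place, one block per program
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(x_ptr + offsets, x * 2.0, mask=mask)


x = torch.randn(4096, device="cuda")
grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK"]),)
scale_kernel[grid](x, x.numel())  # prints one "Autotuning kernel ... with config ..." line per config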