diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index c4421a72e278..839d42170517 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -4783,10 +4783,15 @@ class ConvertTritonGPUToLLVM
       decomposed = true;
     });
 
-    // async wait is supported in Ampere and later
     mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
-      if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability) ||
-          decomposed) {
+      if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability)) {
+        // async wait is supported in Ampere and later
+        asyncWaitOp.erase();
+      } else if (decomposed) {
+        // Wait for all previous async ops
+        OpBuilder builder(asyncWaitOp);
+        auto newAsyncWaitOp =
+            builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
         asyncWaitOp.erase();
       }
     });
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
index 9bf9ecb059d9..090c1ea92b91 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
@@ -262,10 +262,10 @@ struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
     // For now, this behaves like generic, but this will evolve when
     // we add support for `can_reorder=False`
     Type retType = this->getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<triton::CatOp>(op, retType, adaptor.getOperands());
+    rewriter.replaceOpWithNewOp<triton::CatOp>(op, retType,
+                                               adaptor.getOperands());
     return success();
   }
-
 };
 
 struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
@@ -450,13 +450,11 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
       TritonGenericPattern, TritonGenericPattern,
       TritonGenericPattern, TritonBroadcastPattern,
-      TritonGenericPattern,
-      TritonCatPattern,
-      TritonReducePattern,
-      TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern,
-      TritonDotPattern, TritonLoadPattern, TritonStorePattern,
-      TritonExtElemwisePattern, TritonPrintfPattern, TritonAtomicRMWPattern>(
-      typeConverter, context);
+      TritonGenericPattern, TritonCatPattern,
+      TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
+      TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
+      TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
+      TritonAtomicRMWPattern>(typeConverter, context);
 }
 
 //
diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
index 5a4d19a46e33..3417e36acc48 100644
--- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
@@ -782,8 +782,8 @@ class BlockedToMMA : public mlir::RewritePattern {
                                      newRetType.getEncoding()));
     a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
     b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
-    auto newDot = rewriter.create<triton::DotOp>(
-        dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32());
+    auto newDot = rewriter.create<triton::DotOp>(dotOp.getLoc(), newRetType, a,
+                                                 b, newAcc, dotOp.allowTF32());
 
     rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
         op, oldRetType, newDot.getResult());
diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
index 8287350de8ac..c6e27eec6c62 100644
--- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
@@ -225,6 +225,7 @@ scf::ForOp Prefetcher::createNewForOp() {
   BlockAndValueMapping mapping;
   for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
     mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
+  mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
 
   for (Operation &op : forOp.getBody()->without_terminator()) {
     Operation *newOp = builder.clone(op, mapping);
diff --git a/python/tests/test_matmul.py b/python/tests/test_matmul.py
new file mode 100644
index 000000000000..c5e2540a47a4
--- /dev/null
+++ b/python/tests/test_matmul.py
@@ -0,0 +1,101 @@
+import itertools
+
+import pytest
+import torch
+
+import triton
+import triton._C.libtriton.triton as _triton
+
+
+@pytest.mark.parametrize(
+    "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
+    itertools.chain(
+        *[
+            [
+                # 1 warp
+                (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+                (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+                (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+                (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+                (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+                (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+                (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+                (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+                (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+                # 2 warp
+                (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
+                (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
+                (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
+                (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
+                (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
+                (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
+                # 4 warp
+                (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
+                (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
+                (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
+                (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
+                (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
+                (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
+                # 8 warp
+                (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
+                (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
+                (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE),
+                # split-k
+                (64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE),
+                (64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE),
+                (64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE),
+                # variable input
+                (128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE),
+                (128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE),
+                (128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE),
+                (128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE),
+            ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
+        ],
+        # n-stage
+        *[
+            [
+                (16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
+                (64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
+                (128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
+                (256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
+                (128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE),
+                # split-k
+                (64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
+                (64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE),
+            ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
+        ]
+    ),
+)
+def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
+    capability = torch.cuda.get_device_capability()
+    if capability[0] < 7:
+        pytest.skip("Only test tl.dot() on devices with sm >= 70")
+    if capability[0] < 8 and DTYPE == "bfloat16":
+        pytest.skip("Only test bfloat16 on devices with sm >= 80")
+    #if DTYPE == "bfloat16" and SPLIT_K != 1:
+    #    pytest.skip("bfloat16 matmuls don't allow split_k for now")
+    if DTYPE == "bfloat16":
+        pytest.skip("bfloat16 matmuls are not supported for now")
+    torch.manual_seed(0)
+    # nuke kernel decorators -- will set meta-parameters manually
+    kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
+    pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_()
+    configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)]
+    kernel = triton.ops._matmul.kernel
+    kernel.configs = configs
+    # kernel.run = kernel.run.run.run
+
+    # get matrix shape
+    M = BLOCK_M if M is None else M
+    N = BLOCK_N if N is None else N
+    K = BLOCK_K * SPLIT_K if K is None else K
+    # allocate/transpose inputs
+    DTYPE = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[DTYPE]
+    a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
+    b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
+    a = a.t() if AT else a
+    b = b.t() if BT else b
+    # run test
+    th_c = torch.matmul(a, b)
+    tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest)
+    triton.testing.assert_almost_equal(th_c, tt_c)
diff --git a/python/tests/test_vecadd.py b/python/tests/test_vecadd.py
index 265c25acd3cf..d032cf14a65a 100644
--- a/python/tests/test_vecadd.py
+++ b/python/tests/test_vecadd.py
@@ -68,7 +68,7 @@ def kernel(x_ptr,
     @num_elements: number of elements
     '''
     pid = tl.program_id(axis=0)
-    for i in range(math.ceil(block_size / iter_size)):
+    for i in range(tl.cdiv(block_size, iter_size)):
         # TODO: a bug here, if put the offset outside the forloop, there will be a GPU mis-aligned error.
         offset = pid * block_size + tl.arange(0, iter_size)
         x_ptrs = x_ptr + offset
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index 7d442b9f03e3..2e760dabb838 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -329,10 +329,6 @@ def visit_Tuple(self, node):
     def visit_BinOp(self, node):
         lhs = self.visit(node.left)
         rhs = self.visit(node.right)
-        if isinstance(lhs, triton.language.constexpr):
-            lhs = lhs.value
-        if isinstance(rhs, triton.language.constexpr):
-            rhs = rhs.value
         fn = {
             ast.Add: '__add__',
             ast.Sub: '__sub__',
@@ -591,8 +587,10 @@ def visit_For(self, node):
                 ast.NodeVisitor.generic_visit(self, stmt)
             return
         # handle negative constant step (not supported by scf.for in MLIR)
+        negative_step = False
         if isinstance(step, triton.language.constexpr) and step.value < 0:
             step = triton.language.constexpr(-step.value)
+            negative_step = True
             lb, ub = ub, lb
         # lb/ub/step might be constexpr, we need to cast them to tensor
         lb = triton.language.core._to_tensor(lb, self.builder).handle
@@ -640,6 +638,9 @@ def visit_For(self, node):
         # update induction variable with actual value, and replace all uses
         self.builder.set_insertion_point_to_start(for_op.get_body(0))
         iv = self.builder.create_index_to_si(for_op.get_induction_var())
+        if negative_step:
+            ub_si = self.builder.create_index_to_si(ub)
+            iv = self.builder.create_sub(ub_si, iv)
         self.lscope[node.target.id].handle.replace_all_uses_with(iv)
         self.set_value(name, triton.language.core.tensor(iv, triton.language.core.int32))
 
@@ -890,9 +891,9 @@ def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
     pm = _triton.ir.pass_manager(mod.context)
     pm.add_convert_triton_to_tritongpu_pass(num_warps)
     pm.enable_debug()
-    # Convert blocked layout to mma layout for dot ops so that pipeline
-    # can get shared memory swizzled correctly.
     pm.add_coalesce_pass()
+    # The combine pass converts blocked layout to mma layout
+    # for dot ops so that pipeline can get shared memory swizzled correctly.
     pm.add_triton_gpu_combine_pass(compute_capability)
     pm.add_tritongpu_pipeline_pass(num_stages)
     # Prefetch must be done after pipeline pass because pipeline pass
@@ -1358,12 +1359,12 @@ def make_hash(fn, **kwargs):
     return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest()
 
 
-# - ^\s*func\s+ : match the start of the string, any leading whitespace, the keyword func, 
+# - ^\s*func\s+ : match the start of the string, any leading whitespace, the keyword func,
 #   and any following whitespace
 # - (public\s+)? : optionally match the keyword public and any following whitespace
-# - (@\w+) : match an @ symbol followed by one or more word characters 
+# - (@\w+) : match an @ symbol followed by one or more word characters
 #   (letters, digits, or underscores), and capture it as group 1 (the function name)
-# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing 
+# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
 #   zero or more arguments separated by commas, and capture it as group 2 (the argument list)
 mlir_prototype_pattern = r'^\s*func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
 ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
@@ -1395,20 +1396,20 @@ def compile(fn, **kwargs):
     extern_libs = kwargs.get("extern_libs", dict())
     device = kwargs.get("device", torch.cuda.current_device())
     capability = torch.cuda.get_device_capability()
-    capability = capability[0]*10 + capability[1]
+    capability = capability[0] * 10 + capability[1]
     # build compilation stages
     stages = {
-        "ast" : (lambda path: fn, None),
-        "ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
-                 lambda src: ast_to_ttir(src, signature, configs[0], constants)),
-        "ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
-                  lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
-        "llir": (lambda path: Path(path).read_bytes(),
-                 lambda src: ttgir_to_llir(src, extern_libs, capability)),
-        "ptx": (lambda path: Path(path).read_text(),
-                lambda src: llir_to_ptx(src, capability)),
-        "cubin": (lambda path: Path(path).read_bytes(),
-                  lambda src: ptx_to_cubin(src, capability))
+        "ast": (lambda path: fn, None),
+        "ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
+                 lambda src: ast_to_ttir(src, signature, configs[0], constants)),
+        "ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
+                  lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
+        "llir": (lambda path: Path(path).read_bytes(),
+                 lambda src: ttgir_to_llir(src, extern_libs, capability)),
+        "ptx": (lambda path: Path(path).read_text(),
+                lambda src: llir_to_ptx(src, capability)),
+        "cubin": (lambda path: Path(path).read_bytes(),
+                  lambda src: ptx_to_cubin(src, capability))
     }
     # find out the signature of the function
     if isinstance(fn, triton.runtime.JITFunction):
@@ -1467,8 +1468,8 @@ def compile(fn, **kwargs):
         if ir == ext:
             next_module = parse(fn)
         elif os.path.exists(path) and\
-            ir in metadata["ctime"] and\
-            os.path.getctime(path) == metadata["ctime"][ir]:
+                ir in metadata["ctime"] and\
+                os.path.getctime(path) == metadata["ctime"][ir]:
             next_module = parse(path)
         else:
             next_module = compile(module)
@@ -1504,8 +1505,7 @@ def __init__(self, so_path, metadata, asm):
         self.asm = asm
         device = torch.cuda.current_device()
         global cuda_utils
-        if cuda_utils is None:
-            cuda_utils = CudaUtils()
+        init_cuda_utils()
         mod, func, n_regs, n_spills = cuda_utils.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
         self.cu_module = mod
         self.cu_function = func
@@ -1562,6 +1562,34 @@ def _generate_src(self):
     #define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); if(PyErr_Occurred()) return NULL; }
 
+    static PyObject* getDeviceProperties(PyObject* self, PyObject* args){
+        int device_id;
+        if(!PyArg_ParseTuple(args, "i", &device_id))
+            return NULL;
+        // Get device handle
+        CUdevice device;
+        cuDeviceGet(&device, device_id);
+
+        // create a struct to hold device properties
+        int max_shared_mem;
+        int multiprocessor_count;
+        int sm_clock_rate;
+        int mem_clock_rate;
+        int mem_bus_width;
+        CUDA_CHECK(cuDeviceGetAttribute(&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK, device));
+        CUDA_CHECK(cuDeviceGetAttribute(&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
+        CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
+        CUDA_CHECK(cuDeviceGetAttribute(&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
+        CUDA_CHECK(cuDeviceGetAttribute(&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
+
+
+        return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", max_shared_mem,
+                             "multiprocessor_count", multiprocessor_count,
+                             "sm_clock_rate", sm_clock_rate,
+                             "mem_clock_rate", mem_clock_rate,
+                             "mem_bus_width", mem_bus_width);
+    }
+
     static PyObject* loadBinary(PyObject* self, PyObject* args) {
         const char* name;
         const char* data;
@@ -1601,6 +1629,7 @@ def _generate_src(self):
 
     static PyMethodDef ModuleMethods[] = {
       {"load_binary", loadBinary, METH_VARARGS, "Load provided cubin into CUDA driver"},
+      {"get_device_properties", getDeviceProperties, METH_VARARGS, "Get the properties for a given device"},
       {NULL, NULL, 0, NULL} // sentinel
     };
 
@@ -1640,6 +1669,13 @@ def __init__(self):
         mod = importlib.util.module_from_spec(spec)
         spec.loader.exec_module(mod)
         self.load_binary = mod.load_binary
+        self.get_device_properties = mod.get_device_properties
+
+
+def init_cuda_utils():
+    global cuda_utils
+    if cuda_utils is None:
+        cuda_utils = CudaUtils()
 
 
 cuda_utils = None
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index a099139fc760..8f89d6838d64 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -9,6 +9,7 @@
 
 T = TypeVar('T')
 
+
 def _to_tensor(x, builder):
     if isinstance(x, bool):
         return tensor(builder.get_int1(x), int1)
@@ -348,6 +349,9 @@ def __rsub__(self, other):
     def __mul__(self, other):
         return constexpr(self.value * other.value)
 
+    def __mod__(self, other):
+        return constexpr(self.value % other.value)
+
     def __rmul__(self, other):
         return constexpr(other.value * self.value)
 
@@ -726,10 +730,12 @@ def broadcast_to(input, shape, _builder=None):
     """
     return semantic.broadcast_impl_shape(input, shape, _builder)
 
+
 @builtin
 def trans(input, _builder=None):
     return semantic.trans(input, _builder)
 
+
 @builtin
 def cat(input, other, can_reorder=False, _builder=None):
     """
@@ -762,6 +768,7 @@ def view(input, shape, _builder=None):
     shape = [x.value for x in shape]
     return semantic.view(input, shape, _builder)
 
+
 @builtin
 def reshape(input, shape, _builder=None):
     # TODO: should be more than just a view
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 9d8e65a50544..6bdc91ebdb67 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -481,7 +481,8 @@ def view(input: tl.tensor,
          dst_shape: List[int],
          builder: ir.builder) -> tl.tensor:
     # TODO: disable when TritonToTritonGPU handles views properly
-    assert len(input.shape) == len(dst_shape)
+
+    # assert len(input.shape) == len(dst_shape)
     numel = 1
     for s in dst_shape:
         numel *= s
diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py
index f1ac78849116..0ffcc167718b 100644
--- a/python/triton/ops/matmul.py
+++ b/python/triton/ops/matmul.py
@@ -26,9 +26,6 @@ def get_configs_io_bound():
     return configs
 
 
-@triton.heuristics({
-    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
-})
 @triton.autotune(
     configs=[
         # basic configs for compute-bound matmuls
@@ -59,6 +56,9 @@ def get_configs_io_bound():
         'top_k': 10
     },
 )
+@triton.heuristics({
+    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
+})
 @triton.jit
 def _kernel(A, B, C, M, N, K,
             stride_am, stride_ak,
diff --git a/python/triton/ops/matmul_perf_model.py b/python/triton/ops/matmul_perf_model.py
index 004f236b968c..53638657a52b 100644
--- a/python/triton/ops/matmul_perf_model.py
+++ b/python/triton/ops/matmul_perf_model.py
@@ -10,7 +10,9 @@ def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
     ''' return compute throughput in TOPS '''
     total_warps = num_ctas * min(num_warps, 4)
-    num_subcores = _triton.runtime.num_sm(backend, device) * 4  # on recent GPUs
+    triton.compiler.init_cuda_utils()
+
+    num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4  # on recent GPUs
     tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device)
     return tflops
 
@@ -18,14 +20,14 @@ def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
 def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
     ''' return compute throughput in TOPS '''
     total_warps = num_ctas * min(num_warps, 4)
-    num_subcores = _triton.runtime.num_sm(backend, device) * 4  # on recent GPUs
+    num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4  # on recent GPUs
     tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device)
     return tflops
 
 
 def get_tflops(backend, device, num_ctas, num_warps, dtype):
-    cc = _triton.runtime.cc(backend, device)
-    if cc < 80 and dtype == torch.float32:
+    capability = torch.cuda.get_device_capability(device)
+    if capability[0] < 8 and dtype == torch.float32:
         return get_simd_tflops(backend, device, num_ctas, num_warps, dtype)
     return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)
 
@@ -59,7 +61,7 @@ def estimate_matmul_time(
     compute_ms = total_ops / tput
 
     # time to load data
-    num_sm = _triton.runtime.num_sm(backend, device)
+    num_sm = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"]
     active_cta_ratio = min(1, num_ctas / num_sm)
     active_cta_ratio_bw1 = min(1, num_ctas / 32)  # 32 active ctas are enough to saturate
     active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0)  # 32-108, remaining 5%
@@ -99,7 +101,7 @@ def estimate_matmul_time(
 def early_config_prune(configs, named_args):
     backend = _triton.runtime.backend.CUDA
     device = torch.cuda.current_device()
-    cc = _triton.runtime.cc(backend, device)
+    capability = torch.cuda.get_device_capability()
     # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
     dtsize = named_args['A'].element_size()
     dtype = named_args['A'].dtype
@@ -110,7 +112,10 @@ def early_config_prune(configs, named_args):
         kw = config.kwargs
         BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
             kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages
-        max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
+
+        # TODO: move to `cuda_utils` submodule
+        triton.compiler.init_cuda_utils()
+        max_shared_memory = triton.compiler.cuda_utils.get_device_properties(device)["max_shared_mem"]
         required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
         if required_shared_memory <= max_shared_memory:
             pruned_configs.append(config)
@@ -136,7 +141,7 @@ def early_config_prune(configs, named_args):
     pruned_configs = []
     for k, v in configs_map.items():
         BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
-        if cc >= 80:
+        if capability[0] >= 8:
             # compute cycles (only works for ampere GPUs)
             mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)
             mma_cycles = mmas / min(4, num_warps) * 8
diff --git a/python/triton/testing.py b/python/triton/testing.py
index 95a05349ca14..2b8ad6d5aa72 100644
--- a/python/triton/testing.py
+++ b/python/triton/testing.py
@@ -16,6 +16,9 @@
 _cutlass = None
 has_cutlass = False
 
+# TODO: move to separate module
+import triton
+
 
 def catch_oor(kernel, pytest_handle=None):
     try:
@@ -330,8 +333,8 @@ def get_dram_gbps(backend=None, device=None):
         backend = _triton.runtime.backend.CUDA
     if not device:
         device = torch.cuda.current_device()
-    mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
-    bus_width = _triton.runtime.global_memory_bus_width(backend, device)
+    mem_clock_khz = triton.compiler.cuda_utils.get_device_properties(device)["mem_clock_rate"]  # in kHz
+    bus_width = triton.compiler.cuda_utils.get_device_properties(device)["mem_bus_width"]
     bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8  # In GB/s
     return bw_gbps
 
@@ -341,11 +344,13 @@ def get_max_tensorcore_tflops(dtype: torch.dtype, backend=None, device=None, clock_rate=None):
         backend = _triton.runtime.backend.CUDA
     if not device:
         device = torch.cuda.current_device()
-    num_subcores = _triton.runtime.num_sm(backend, device) * 4  # on recent GPUs
+
+    triton.compiler.init_cuda_utils()
+    num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4
     if not clock_rate:
-        clock_rate = _triton.runtime.clock_rate(backend, device)  # in kHz
-    cc = _triton.runtime.cc(backend, device)
-    if cc < 80:
+        clock_rate = triton.compiler.cuda_utils.get_device_properties(device)["sm_clock_rate"]  # in kHz
+    capability = torch.cuda.get_device_capability(device)
+    if capability[0] < 8:
         assert dtype == torch.float16
         ops_per_sub_core = 256  # 2 4x4x4 Tensor Cores
     else: