Commit 4f6f088

Merge commit 'cc89dac07b7acf3af9962d83250a8bc015fc5a91'
whitneywhtsang committed Nov 29, 2024
2 parents 2350d5a + cc89dac
Showing 9 changed files with 198 additions and 26 deletions.
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -374,9 +374,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// TODO (Keren): Currently, we handle general mma/blocked/slice/dot(ampere)
// -> mma/blocked/slice/dot(ampere) conversions. The following tasks must be
// completed before we can remove the layoutIsOK check:
- // 1. Support for AMD's WMMA
+ // 1. Support for AMD's WMMA dot operand
std::function<bool(Attribute)> layoutIsOK = [&](Attribute layout) {
- if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(layout)) {
+ if (isa<MmaEncodingTrait>(layout)) {
return !useLegacyMMAConversion;
}
if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
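For context: switching from a list of concrete encodings to the shared MmaEncodingTrait interface means every MMA-style layout (NVIDIA MMA, AMD MFMA, and AMD WMMA alike) now takes this branch, which is why only the WMMA dot-operand case remains on the TODO. A rough Python analogy of the old versus new check (class and function names here are illustrative stand-ins, not Triton's API):

# Illustrative sketch only; the names are stand-ins, not Triton's classes.
class MmaEncodingTrait: ...                      # marker shared by MMA-style layouts
class NvidiaMmaEncoding(MmaEncodingTrait): ...
class AMDMfmaEncoding(MmaEncodingTrait): ...
class AMDWmmaEncoding(MmaEncodingTrait): ...

def layout_is_ok(layout, use_legacy_mma_conversion=False):
    # Before: isinstance(layout, (NvidiaMmaEncoding, AMDMfmaEncoding)), so WMMA fell through.
    # After: a single trait check covers every MMA-style encoding, WMMA included.
    if isinstance(layout, MmaEncodingTrait):
        return not use_legacy_mma_conversion
    return True  # dot-operand / blocked / slice handling omitted in this sketch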
10 changes: 5 additions & 5 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -23,9 +23,6 @@ void lowerDistributedToShared(
auto srcTy = cast<RankedTensorType>(src.getType());
auto dstTy = cast<MemDescType>(dst.getType());
auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding()).getOrder();
- assert(srcTy.getShape().size() <= 2 ||
- (srcTy.getShape().size() == 3 && outOrd[2] == 0) &&
- "Unexpected rank of ConvertLayout(blocked->shared)");
auto elemTy = typeConverter->convertType(srcTy.getElementType());

auto smemBase = smemObj.getBase();
@@ -163,7 +160,9 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
srcTy.getShape()[0] >= 8 && srcTy.getShape()[1] >= 4 * kWidth;
// To be removed in https://github.com/triton-lang/triton/pull/5154
bool legacyLoweringIsBuggy =
- (kWidth >= 8 || (kWidth == 4 && bitwidth == 32)) && mma.isAmpere();
+ (kWidth >= 8 || (kWidth == 4 && bitwidth == 32) ||
+ dstTy.getRank() == 3) &&
+ mma.isAmpere();
return (mma.isHopper() && !canUseLdmatrix) ||
(mma.isAmpere() && legacyLoweringIsBuggy);
}
@@ -220,7 +219,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
auto dstTy = op.getResult().getType();
auto dstShape = dstTy.getShape();
auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
- assert((dstShape.size() <= 2 || isSupportedDotOpLayout(srcTy, dstTy)) &&
+ assert((!isa<DotOperandEncodingAttr>(dstTy.getEncoding()) ||
+ isSupportedDotOpLayout(srcTy, dstTy)) &&
"Unexpected rank of ConvertLayout(shared->distributed)");

auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
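In short, the first hunk also routes rank-3 destination tensors away from the legacy dot-operand lowering (slated for removal in triton-lang/triton#5154), and the second hunk relaxes the rank assertion so such shared-to-dot-operand loads are accepted whenever isSupportedDotOpLayout allows them. A standalone Python restatement of the predicate as a sketch; the function name is a paraphrase of what the tail of isSupportedDotOpLayout computes, and the parameter names mirror the C++ variables above:

def take_generic_dot_operand_path(is_hopper, is_ampere, can_use_ldmatrix,
                                  k_width, bitwidth, dst_rank):
    # Mirrors the condition in the hunk above; the whole check is due to be
    # retired once the legacy Ampere lowering is removed (triton-lang/triton#5154).
    legacy_lowering_is_buggy = (
        k_width >= 8
        or (k_width == 4 and bitwidth == 32)
        or dst_rank == 3  # newly added case: rank-3 destinations
    ) and is_ampere
    return (is_hopper and not can_use_ldmatrix) or (is_ampere and legacy_lowering_is_buggy)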
8 changes: 6 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -113,8 +113,12 @@ warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
const SmallVector<unsigned, 3> &instrShape) {
SetVector<Operation *> slices;
mlir::getForwardSlice(dotOp.getResult(), &slices);
- if (llvm::find_if(slices, [](Operation *op) { return isa<DotOp>(op); }) !=
- slices.end())
+ // Contains a chained dot. We prefer to assign warps to one axis
+ // to facilitate use cases like flash attention, allowing reductions within
+ // the same warp.
+ if (llvm::find_if(slices, [](Operation *op) {
+ return op->hasTrait<OpTrait::DotLike>();
+ }) != slices.end())
return {(unsigned)numWarps, 1};

// For MMAv3, the smallest indivisible unit of warp shape is (4, 1).
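The new comment states the intent: when a dot feeds another dot-like op (the flash-attention pattern of QK^T followed by PV), all warps are assigned along one axis so the intermediate reduction stays inside a warp. A minimal Python sketch of that branch; the non-chained case, which the real pass sizes from the tile shape, is stubbed out:

def warps_per_tile(num_warps, forward_slice_has_dot_like):
    # Chained dot (e.g. flash attention): keep every warp on the M axis so the
    # reduction between the two dots never has to cross warps.
    if forward_slice_has_dot_like:
        return [num_warps, 1]
    # Otherwise the pass balances warps across M and N based on the tile shape;
    # that search is omitted in this sketch.
    raise NotImplementedError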
2 changes: 1 addition & 1 deletion python/setup.py
@@ -493,7 +493,7 @@ def build_extension(self, ext):
"-DCMAKE_EXPORT_COMPILE_COMMANDS=ON", "-DLLVM_ENABLE_WERROR=ON",
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DTRITON_BUILD_TUTORIALS=OFF",
"-DTRITON_BUILD_PYTHON_MODULE=ON", "-DPython3_EXECUTABLE:FILEPATH=" + sys.executable,
"-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON", "-DPython3_INCLUDE_DIR=" + python_include_dir,
"-DPython3_INCLUDE_DIR=" + python_include_dir,
"-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]),
"-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external])
]
91 changes: 91 additions & 0 deletions python/test/unit/language/test_core.py
@@ -5433,6 +5433,97 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t
assert torch.equal(z, x)


layouts_3d = [
BlockedLayout([4, 4, 1], [1, 8, THREADS_PER_WARP // 8], [2, 2, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
BlockedLayout([1, 1, 4], [8, THREADS_PER_WARP // 8, 1], [2, 1, 2], [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
DotOperandLayout(parent=MmaLayout([2, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [1, 16, 8]), op_idx=0,
k_width=1),
]

shared_layout_3d = [
SharedLayout(1, 1, 1, [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
SharedLayout(4, 2, 4, [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
SharedLayout(8, 2, 4, [0, 2, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
SharedLayout(4, 2, 1, [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
]


@pytest.mark.parametrize("M, N, K", [[8, 16, 32]])
@pytest.mark.parametrize("shared_layout", shared_layout_3d)
@pytest.mark.parametrize("dist_layout", layouts_3d)
def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path: pathlib.Path):
layouts = f"""
#dist = {dist_layout}
#shared = {shared_layout}
"""
ir = layouts + f"""
module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
tt.func public @kernel(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{
%cst = arith.constant dense<{K}> : tensor<1x{N}x1xi32, #dist>
%cst_0 = arith.constant dense<{K*N}> : tensor<{M}x1x1xi32, #dist>
%cst_1 = arith.constant dense<{K*N}> : tensor<{M}x1x1xi32, #dist>
%cst_2 = arith.constant dense<{K}> : tensor<1x{N}x1xi32, #dist>
%0 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>>
%1 = tt.expand_dims %0 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>>
%2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist>
%3 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1x1x{K}x!tt.ptr<i32>, #dist>
%4 = tt.addptr %3, %2 : tensor<1x1x{K}x!tt.ptr<i32>, #dist>, tensor<1x1x{K}xi32, #dist>
%5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>>
%6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>>
%7 = tt.expand_dims %6 {{axis = 2 : i32}} : tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<1x{N}x1xi32, #dist>
%8 = arith.muli %7, %cst_2 : tensor<1x{N}x1xi32, #dist>
%9 = tt.broadcast %4 : tensor<1x1x{K}x!tt.ptr<i32>, #dist> -> tensor<1x{N}x{K}x!tt.ptr<i32>, #dist>
%10 = tt.broadcast %8 : tensor<1x{N}x1xi32, #dist> -> tensor<1x{N}x{K}xi32, #dist>
%11 = tt.addptr %9, %10 : tensor<1x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<1x{N}x{K}xi32, #dist>
%12 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>>
%13 = tt.expand_dims %12 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>>
%14 = tt.expand_dims %13 {{axis = 2 : i32}} : tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<{M}x1x1xi32, #dist>
%15 = arith.muli %14, %cst_1 : tensor<{M}x1x1xi32, #dist>
%16 = tt.broadcast %11 : tensor<1x{N}x{K}x!tt.ptr<i32>, #dist> -> tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
%17 = tt.broadcast %15 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist>
%18 = tt.addptr %16, %17 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<{M}x{N}x{K}xi32, #dist>
%19 = tt.load %18 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
%20 = ttg.local_alloc %19 : (tensor<{M}x{N}x{K}xi32, #dist>) -> !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #ttg.shared_memory>
%21 = ttg.local_load %20 : !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #ttg.shared_memory> -> tensor<{M}x{N}x{K}xi32, #dist>
%22 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>>
%23 = tt.expand_dims %22 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>>
%24 = tt.expand_dims %23 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist>
%25 = tt.splat %arg1 : !tt.ptr<i32> -> tensor<1x1x{K}x!tt.ptr<i32>, #dist>
%26 = tt.addptr %25, %24 : tensor<1x1x{K}x!tt.ptr<i32>, #dist>, tensor<1x1x{K}xi32, #dist>
%27 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>>
%28 = tt.expand_dims %27 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>>
%29 = tt.expand_dims %28 {{axis = 2 : i32}} : tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<1x{N}x1xi32, #dist>
%30 = arith.muli %29, %cst : tensor<1x{N}x1xi32, #dist>
%31 = tt.broadcast %26 : tensor<1x1x{K}x!tt.ptr<i32>, #dist> -> tensor<1x{N}x{K}x!tt.ptr<i32>, #dist>
%32 = tt.broadcast %30 : tensor<1x{N}x1xi32, #dist> -> tensor<1x{N}x{K}xi32, #dist>
%33 = tt.addptr %31, %32 : tensor<1x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<1x{N}x{K}xi32, #dist>
%34 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>>
%35 = tt.expand_dims %34 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>>
%36 = tt.expand_dims %35 {{axis = 2 : i32}} : tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<{M}x1x1xi32, #dist>
%37 = arith.muli %36, %cst_0 : tensor<{M}x1x1xi32, #dist>
%38 = tt.broadcast %33 : tensor<1x{N}x{K}x!tt.ptr<i32>, #dist> -> tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
%39 = tt.broadcast %37 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist>
%40 = tt.addptr %38, %39 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<{M}x{N}x{K}xi32, #dist>
tt.store %40, %21 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
tt.return
}}
}}
"""

if is_xpu() and isinstance(dist_layout, DotOperandLayout) and isinstance(dist_layout.parent, MmaLayout):
pytest.xfail("DotOperandLayout with MmaLayout is not supported in XPU")

x = torch.arange(0, M * N * K, device=device, dtype=torch.int32).reshape(M, N, K)
z = torch.empty_like(x, device=device)

temp_file = tmp_path / "test_local_load_store.ttgir"
temp_file.write_text(ir)
kernel = triton.compile(str(temp_file))

kernel[(1, 1, 1)](x, z)
assert torch.equal(z, x)
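
The TTGIR above simply computes row-major addresses for a contiguous MxNxK int32 tensor (offset = i*N*K + j*K + k), loads it, round-trips it through shared memory via ttg.local_alloc / ttg.local_load, and stores it back, so the final assertion checks the copy is bit-exact for each 3-D distributed/shared layout pair. The address arithmetic restated in plain Python for the parametrized sizes:

# Row-major offset built by the expand_dims/broadcast/addptr chain in the IR above.
M, N, K = 8, 16, 32

def offset(i, j, k):
    return i * (N * K) + j * K + k

# torch.arange(M * N * K).reshape(M, N, K) stores each element's own offset,
# so element [2, 3, 4] should hold 2*512 + 3*32 + 4 = 1124.
assert offset(2, 3, 4) == 1124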


mma_pairs = [
[
MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]),
12 changes: 0 additions & 12 deletions python/triton/compiler/code_generator.py
@@ -1,7 +1,6 @@
import ast
import inspect
import re
- import sys
import warnings
import os
import textwrap
@@ -1176,17 +1175,6 @@ def visit_BoolOp(self, node: ast.BoolOp):

_method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}

- if sys.version_info < (3, 8):
-
- def visit_NameConstant(self, node):
- return constexpr(node.value)
-
- def visit_Num(self, node):
- return constexpr(node.n)
-
- def visit_Str(self, node):
- return constexpr(ast.literal_eval(node))
-
def visit_Attribute(self, node):
lhs = self.visit(node.value)
if _is_triton_tensor(lhs) and node.attr == "T":
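The deleted branch was a Python < 3.8 fallback: on those interpreters string, number, and None/True/False literals parsed to ast.Str, ast.Num, and ast.NameConstant. Since Python 3.8 the parser emits ast.Constant for all of them, which the generator's remaining constant handling already covers, so the fallback and the now-unused sys import can both go. A quick check on any currently supported interpreter:

import ast
import sys

assert sys.version_info >= (3, 8)
values = ast.parse("('hi', 42, None)").body[0].value.elts
# Prints {'Constant'}: Str, Num and NameConstant nodes are no longer produced.
print({type(v).__name__ for v in values})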
65 changes: 65 additions & 0 deletions test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir
@@ -1,5 +1,6 @@
// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx1100 | FileCheck %s

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
#mma1 = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}>
#mma2 = #ttg.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}>
@@ -97,6 +98,70 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
// CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)>
tt.return
}

// CHECK-LABEL: blocked_to_wmma1
tt.func @blocked_to_wmma1(%arg0: tensor<128x16xi32, #blocked>) {
// CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK-COUNT-32: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
%0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #blocked> -> tensor<128x16xi32, #mma1>
tt.return
}

// CHECK-LABEL: slice_blocked_to_wmma1
tt.func @slice_blocked_to_wmma1(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) {
// CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK-COUNT-1: llvm.insertvalue {{.*}} : !llvm.struct<(i32)>
%0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma1}>>
tt.return
}

// CHECK-LABEL: wmma1_to_blocked
tt.func @wmma1_to_blocked(%arg0: tensor<128x16xi32, #mma1>) {
// CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
%0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #mma1> -> tensor<128x16xi32, #blocked>
tt.return
}

// CHECK-LABEL: slice_wmma1_to_blocked
tt.func @slice_wmma1_to_blocked(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma1}>>) {
// CHECK-COUNT-1: llvm.extractvalue {{.*}} : !llvm.struct<(i32)>
// CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
%0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
tt.return
}

// CHECK-LABEL: blocked_to_wmma2
tt.func @blocked_to_wmma2(%arg0: tensor<128x16xi32, #blocked>) {
// CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK-COUNT-32: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
%0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #blocked> -> tensor<128x16xi32, #mma2>
tt.return
}

// CHECK-LABEL: slice_blocked_to_wmma2
tt.func @slice_blocked_to_wmma2(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) {
// CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK-COUNT-1: llvm.insertvalue {{.*}} : !llvm.struct<(i32)>
%0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma2}>>
tt.return
}

// CHECK-LABEL: wmma2_to_blocked
tt.func @wmma2_to_blocked(%arg0: tensor<128x16xi32, #mma2>) {
// CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
%0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #mma2> -> tensor<128x16xi32, #blocked>
tt.return
}

// CHECK-LABEL: slice_wmma2_to_blocked
tt.func @slice_wmma2_to_blocked(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma2}>>) {
// CHECK-COUNT-1: llvm.extractvalue {{.*}} : !llvm.struct<(i32)>
// CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
%0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma2}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
tt.return
}
}

// -----
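As a sanity check on the CHECK-COUNT struct sizes in the new tests: they count the values each thread holds. For the 128x16 i32 tensor under #blocked (sizePerThread [1,1], threadsPerWarp [32,1], warpsPerCTA [4,1]) that is simply total elements over total threads, 128*16 / (4*32) = 16, matching the 16-element structs. The 32 on the WMMA side is an inference from the layout rather than something stated in the test: a 16x16 WMMA tile leaves 8 accumulator values per lane of a 32-thread wave, and with warpsPerCTA = [2,2] the 128-row tensor repeats that 4 times per warp along M, while the 16-wide N dimension is duplicated across the second warp column, giving 8 * 4 = 32. The arithmetic in Python:

# Arithmetic behind the CHECK-COUNT struct sizes above (illustrative only).
M, N = 128, 16
warps, threads_per_warp = 4, 32

blocked_vals_per_thread = (M * N) // (warps * threads_per_warp)
print(blocked_vals_per_thread)  # 16, matching the 16-element structs

# Assumed WMMA v1 fragment: 8 values per lane per 16x16 tile, repeated for
# M // (16 * 2) tile rows per warp with warpsPerCTA = [2, 2]; N is duplicated.
wmma_vals_per_thread = 8 * (M // (16 * 2))
print(wmma_vals_per_thread)     # 32, matching the 32-element structs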
27 changes: 27 additions & 0 deletions test/TritonGPU/accelerate-matmul.mlir
@@ -73,6 +73,33 @@ module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-

// -----

// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 32, 16]}>
// CHECK: #mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 64, 16]}>
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: chained_dot
tt.func public @chained_dot_wgmma(
%arg0: tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>,
%arg1: tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>,
%arg2: tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x128xf32, #blocked1> {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #blocked1>
// CHECK: ttng.warp_group_dot {{.*}} -> tensor<64x64xf32, #mma>
%d = tt.dot %arg0, %arg1, %cst_0 :
tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked>
%t = arith.truncf %d : tensor<64x64xf32, #blocked> to tensor<64x64xf16, #blocked>
%c = ttg.convert_layout %t : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>>
// CHECK: ttng.warp_group_dot {{.*}} -> tensor<64x128xf32, #mma1>
%r = tt.dot %c, %arg2, %cst_1 :
tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<64x128xf32, #blocked1>
tt.return %r : tensor<64x128xf32, #blocked1>
}
}

// -----

// CHECK: #[[$MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>