introduce AMD inThreadTranspose for K-major dot operand
jtang10 committed Nov 21, 2024
1 parent ad28e6c commit e275dcc
Showing 10 changed files with 424 additions and 2 deletions.
1 change: 1 addition & 0 deletions bin/RegisterTritonDialects.h
@@ -66,6 +66,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::registerTritonAMDGPUConvertToBufferOps();
mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();
mlir::registerTritonAMDGPUInThreadTranspose();

// TODO: register Triton & TritonGPU passes
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
@@ -37,12 +37,18 @@ namespace mlir::triton::gpu {
// to compute the linear layout for MMAv3 (i.e. Hopper) shared layouts (i.e.
// shared layouts with hasLeadingOffset == true) but is otherwise unused.
//
// inThreadTranspose is a flag indicating whether the transpose should be
// performed while the data resides in thread-local registers. It is set to
// true on the AMD platform when a non-KContig matrix is about to be written
// into LDS (shared memory) and is otherwise unused. More details are provided
// in the transpose2D() function in LinearLayoutConversions.cpp.
// Returns std::nullopt if the given layout can't be converted to an LL.
// TODO(jlebar): Remove the std::optional once all layouts are supported.
//
std::optional<LinearLayout>
toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
std::optional<int32_t> elemBitWidth = std::nullopt);
std::optional<int32_t> elemBitWidth = std::nullopt,
bool inThreadTranspose = false);

// Given a linear layout where the input dimensions contain a "block" dimension,
// this method sets the "block" dimension to 0 and removes the corresponding
99 changes: 98 additions & 1 deletion lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -73,6 +73,60 @@ LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
return ret;
}

// For sizePerThread = [4, 8], the regular linear layout will express it as
// the following
// - register=1 -> (1, 0)
// register=2 -> (2, 0)
// register=4 -> (4, 0)
// register=8 -> (0, 1)
// register=16 -> (0, 2)
// where out dims are: [dim1 (size 8), dim0 (size 4)]
// In binary form this is an identity matrix. If we instead traverse starting
// from the dim of size 4, the mapping becomes the following
// - register=1 -> (0, 1)
// register=2 -> (0, 2)
// register=4 -> (1, 0)
// register=8 -> (2, 0)
// register=16 -> (4, 0)
// where out dims are: [dim1 (size 8), dim0 (size 4)]
// Only the register layout generation changes in the conversion below: the
// register bases come from the newly introduced transpose2D, while the lane
// and warp bases still come from identityStandardND.
// Note that simply reversing the for-loop in identityStandardND would not
// work, because that changes the most-minor dimension from dim1 to dim0 while
// still producing an identity matrix.
LinearLayout transpose2D(StringAttr inDimName, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order) {
assert(shape.size() == order.size());
assert((order.size() == 2) && "only support dim of 2 now");

MLIRContext *ctx = inDimName.getContext();
StringAttr kRegister = S("register");

std::vector<std::vector<int32_t>> bases;
// traverse 2nd dimension (K-dim in GEMM case)
int dim = order[1];
for (int basis = 1; basis < shape[dim]; basis <<= 1) {
bases.push_back({0, basis});
}
// traverse 1st dimension (N-dim in GEMM non-KContig B-tensor)
// this is the consecutive dimension loaded from global memory
dim = order[0];
for (int basis = 1; basis < shape[dim]; basis <<= 1) {
bases.push_back({basis, 0});
}

auto dimMinor = "dim" + std::to_string(order[0]);
auto dimMajor = "dim" + std::to_string(order[1]);
StringAttr kDimMinor = S(dimMinor);
StringAttr kDimMajor = S(dimMajor);
auto ret = LinearLayout(
{{kRegister, bases}},
{{kDimMinor, shape[order[0]]}, {kDimMajor, shape[order[1]]}}, false);

return ret;
}
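
To make the basis reordering concrete, here is a minimal standalone sketch (not part of this commit) that enumerates the register bases for sizePerThread = [4, 8] under both traversal orders; its output reproduces the two mappings listed in the comment above.

#include <cstdio>

int main() {
  const int dim0 = 4, dim1 = 8; // sizePerThread = [4, 8], order = [1, 0]
  // Regular identity traversal: consecutive registers step along dim1 first.
  std::printf("identity (out dims: [dim1, dim0]):\n");
  for (int b = 1; b < dim1; b <<= 1)
    std::printf("  register=%d -> (%d, 0)\n", b, b);
  for (int b = 1; b < dim0; b <<= 1)
    std::printf("  register=%d -> (0, %d)\n", b * dim1, b);
  // transpose2D traversal: consecutive registers step along dim0 (the K dim
  // in the GEMM case) first, which is what enables K-contiguous packing.
  std::printf("transpose2D (out dims: [dim1, dim0]):\n");
  for (int b = 1; b < dim0; b <<= 1)
    std::printf("  register=%d -> (0, %d)\n", b, b);
  for (int b = 1; b < dim1; b <<= 1)
    std::printf("  register=%d -> (%d, 0)\n", b * dim0, b);
  return 0;
}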

// Make a LinearLayout that maps a block-id to an N-dimensional index.
//
// The tensor is split up into CTAsPerCGA pieces, which are distributed among
@@ -274,6 +328,45 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
return combineCtaCgaWithShape(tileLayout, shared.getCTALayout(), shape);
}

// This function converts a BlockedEncodingAttr to a linear layout in a special
// way. It accompanies the AMDGPUInThreadTranspose pass, which transposes a
// non-KContig tensor into KContig form prior to writing it into LDS (shared
// memory). The conversion treats sizePerThread as a 2D matrix and traverses it
// with a different access pattern.
//
// For example, consider the following blocked layout generated by
// AMDGPUInThreadTranspose: #blocked1 = #triton_gpu.blocked<{sizePerThread =
// [4, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}>.
// Since sizePerThread is 2D, there are two ways to traverse it: along the dim
// of size 8 or the dim of size 4. The regular toLinearLayout() walks it in the
// leading order, i.e. the dim of 8. But since we want to transpose in-thread,
// we iterate along the 2nd order, i.e. the dim of 4, so that 4 elements can be
// packed into a single vector; the AMD backend LLVM compiler then places them
// in consecutive VGPRs, making the data written to LDS contiguous in the K
// dimension. This guarantees a vectorized ds_read, while the ds_write can be
// vectorized to 64 bits or 32 bits depending on the block size and the number
// of warps.
//
// The function is named ThreadRake because each thread rakes through multiple
// rows at the same time, as opposed to each warp raking through a cluster of
// rows, or the regular Triton approach, which iterates through every available
// warp and then tiles that pattern over the entire block.
LinearLayout blockedToLinearLayoutThreadRake(ArrayRef<int64_t> shape,
BlockedEncodingAttr blocked) {
MLIRContext *ctx = blocked.getContext();
int rank = shape.size();
auto outDimNames = standardOutDimNames(ctx, rank);
const auto &order = blocked.getOrder();
auto sizePerThread = blocked.getSizePerThread();

auto ctaLayout =
transpose2D(S("register"), sizePerThread, order) *
identityStandardND(S("lane"), blocked.getThreadsPerWarp(), order) *
identityStandardND(S("warp"), blocked.getWarpsPerCTA(), order);

return combineCtaCgaWithShape(ctaLayout, blocked.getCTALayout(), shape);
}
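
As a quick sanity check of the vectorization claim above, a standalone sketch (not from the commit; values taken from the #blocked1 example with f16 elements) that computes the LDS write width the in-thread transpose enables:

#include <cstdio>

int main() {
  const int elemBits = 16;        // f16 elements
  const int kContigPerThread = 4; // sizePerThread[0] after in-thread transpose
  // 4 consecutive K elements per thread at 16 bits each allow a 64-bit
  // ds_write; with only 2 K-contiguous elements this drops to 32 bits.
  std::printf("ds_write width: b%d\n", elemBits * kContigPerThread);
  return 0;
}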

} // anonymous namespace

std::optional<LinearLayout>
@@ -790,9 +883,13 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

std::optional<LinearLayout>
toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
std::optional<int32_t> elemBitWidth /*= std::nullopt*/) {
std::optional<int32_t> elemBitWidth /*= std::nullopt*/,
bool inThreadTranspose /*= false*/) {
// Layouts are distributed or shared
if (auto distributed = dyn_cast<DistributedEncodingTrait>(layout)) {
auto blocked = dyn_cast<BlockedEncodingAttr>(distributed);
if (blocked && inThreadTranspose)
return blockedToLinearLayoutThreadRake(shape, blocked);
return distributed.toLinearLayout(shape);
} else if (auto shared = dyn_cast<SharedEncodingAttr>(layout)) {
if (shared.getHasLeadingOffset()) {
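
A hypothetical call site (a hedged sketch, not code from this commit) showing how the new flag selects the thread-raked path; the signature matches the declaration added in the header above.

// Given shape (ArrayRef<int64_t>) and blockedAttr (a BlockedEncodingAttr),
// request the in-thread-transposed layout instead of the regular one.
auto ll = mlir::triton::gpu::toLinearLayout(shape, blockedAttr,
                                            /*elemBitWidth=*/std::nullopt,
                                            /*inThreadTranspose=*/true);
// For any non-blocked distributed layout the flag has no effect and the
// regular distributed.toLinearLayout(shape) result is returned.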
54 changes: 54 additions & 0 deletions test/TritonGPU/amd/in-thread-transpose.mlir
@@ -0,0 +1,54 @@
// RUN: triton-opt %s -split-input-file -tritonamdgpu-in-thread-transpose | FileCheck %s

#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} {

// CHECK: [[threadrake_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x256x!tt.ptr<f16>, [[threadrake_layout]]>
// CHECK: {{.*}} = tt.load [[load_ptr]] : tensor<64x256x!tt.ptr<f16>, [[threadrake_layout]]>
tt.func public @threadRake_transpose_b(%arg0: tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<64x256x!tt.ptr<f16>, #blocked1>) {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
%1 = tt.load %arg1 : tensor<64x256x!tt.ptr<f16>, #blocked1>
%2 = triton_gpu.convert_layout %1 : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%6 = tt.dot %arg0, %2, %cst_0 : tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x256xf32, #mma>
tt.return
}
}

// -----

#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [32, 32], isTransposed = true}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} {

// CHECK: [[threadrake_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<32x128x!tt.ptr<f16>, [[threadrake_layout]]>
// CHECK: {{.*}} = tt.load [[load_ptr]] : tensor<32x128x!tt.ptr<f16>, [[threadrake_layout]]>
tt.func public @threadRake_transpose_b_no_change(%arg0: tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x128x!tt.ptr<f16>, #blocked1>) {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma>
%1 = tt.load %arg1 : tensor<32x128x!tt.ptr<f16>, #blocked1>
%2 = triton_gpu.convert_layout %1 : tensor<32x128xf16, #blocked1> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%6 = tt.dot %arg0, %2, %cst_0 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma>
tt.return
}
}


// -----
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 2], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} {

// CHECK-NOT: {{.*}} = triton_gpu.convert_layout {{.*blocked.*}} -> {{.*blocked.*}}
tt.func public @threadRake_no_transpose(%arg0: tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<64x256x!tt.ptr<f16>, #blocked1>) {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
%1 = tt.load %arg1 : tensor<64x256x!tt.ptr<f16>, #blocked1>
%2 = triton_gpu.convert_layout %1 : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%6 = tt.dot %arg0, %2, %cst_0 : tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x256xf32, #mma>
tt.return
}
}
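
A hedged reading of the first two tests above (the pass implementation itself is not shown on this page, so the formula is inferred from the test expectations, not authoritative): the new sizePerThread along K appears to be the K extent divided by the threads and warps already tiling K, while the third test is left untouched because its operand is already K-contiguous (order = [0, 1]).

#include <cstdio>

int main() {
  // test 1: K = 64, threadsPerWarp[0] = 2, warpsPerCTA[0] = 8
  //         64 / (2 * 8) = 4  -> sizePerThread becomes [4, 8]
  std::printf("test 1: %d\n", 64 / (2 * 8));
  // test 2: K = 32, threadsPerWarp[0] = 4, warpsPerCTA[0] = 8
  //         32 / (4 * 8) = 1  -> layout unchanged
  std::printf("test 2: %d\n", 32 / (4 * 8));
  return 0;
}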
2 changes: 2 additions & 0 deletions third_party/amd/include/TritonAMDGPUTransforms/Passes.h
@@ -27,6 +27,8 @@ std::unique_ptr<Pass> createTritonAMDGPUCanonicalizePointersPass();

std::unique_ptr<Pass> createTritonAMDGPUConvertToBufferOpsPass();

std::unique_ptr<Pass> createTritonAMDGPUInThreadTransposePass();

/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "TritonAMDGPUTransforms/Passes.h.inc"
15 changes: 15 additions & 0 deletions third_party/amd/include/TritonAMDGPUTransforms/Passes.td
@@ -124,4 +124,19 @@ def TritonAMDGPUConvertToBufferOps : Pass<"tritonamdgpu-convert-buffer-ops", "mlir::ModuleOp"> {
let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"];
}

def TritonAMDGPUInThreadTranspose: Pass<"tritonamdgpu-in-thread-transpose", "mlir::ModuleOp"> {
let summary = "Transpose the K-outer dot operand while its data is still in registers, right before writing to LDS";

let description = [{
Transpose a non-KContig dot operand (i.e. one that is not contiguous in the K dimension) right before its data
is written into LDS. The transpose happens right after the data has been loaded from global memory into
thread-local registers; it promotes (but does not guarantee) vectorized LDS writes, while SharedEncodingAttr
guarantees vectorized LDS reads, at the cost of a few VALU instructions to perform the in-thread transpose.
}];

let constructor = "mlir::createTritonAMDGPUInThreadTransposePass()";

let dependentDialects = [];
}
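
Putting the pieces together, a minimal sketch of how the pass is wired up. The registration call and the factory are both added by this commit; the PassManager usage is a standard MLIR pattern shown here as an assumption about how a pipeline would invoke it.

// Registration (added to bin/RegisterTritonDialects.h by this commit):
mlir::registerTritonAMDGPUInThreadTranspose();

// Hypothetical pipeline use of the factory declared in Passes.h:
mlir::PassManager pm(&context);
pm.addPass(mlir::createTritonAMDGPUInThreadTransposePass());

// Or from the command line, as in the RUN line of the test above:
//   triton-opt input.mlir -split-input-file -tritonamdgpu-in-thread-transpose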

#endif
1 change: 1 addition & 0 deletions third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_triton_library(TritonAMDGPUTransforms
ReorderInstructions.cpp
StreamPipelineV2.cpp
MfmaGroup.cpp
inThreadTranspose.cpp

DEPENDS
TritonAMDGPUIR
