
[BACKEND] Propagate mma layout to following elementwise operations. (triton-lang#3973)

For a matmul followed by arithmetic operations, such as `acc +=
tl.dot(a, b)`, the mma layout of the `dot` result is currently not
propagated to the subsequent `add`. As a result, when the dot is inside
a loop, a layout conversion from mma to blocked is repeated on every
iteration. This change allows the mma layout to be propagated so that
it can be reused.
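
For reference, below is a minimal Triton kernel sketch (hypothetical, not part of this commit) of the pattern described above: an `add` consuming the `dot` result inside the K loop. All names, strides, and block sizes are illustrative assumptions.

```python
import triton
import triton.language as tl


@triton.jit
def matmul_add_kernel(a_ptr, b_ptr, c_ptr, K,
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                      BLOCK_K: tl.constexpr):
    # Assumes row-major A (M x K) and B (K x N) with N == BLOCK_N, and a
    # float32 output C, to keep the sketch short.
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_ptr + offs_m[:, None] * K + (k + offs_k)[None, :])
        b = tl.load(b_ptr + (k + offs_k)[:, None] * BLOCK_N + offs_n[None, :])
        # The add below consumes the dot result; without mma layout
        # propagation, its blocked layout forces an mma -> blocked
        # conversion on every loop iteration.
        acc += tl.dot(a, b)
    tl.store(c_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], acc)
```

With this change, the convert from the mma layout to a blocked layout is hoisted out of the loop, as exercised by the new `matmul_add` test added to `test/TritonGPU/combine.mlir` below.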
htyu authored Oct 22, 2024
1 parent ed39cb0 commit 1064b59
Showing 3 changed files with 45 additions and 107 deletions.
95 changes: 3 additions & 92 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
@@ -163,85 +163,6 @@ void LayoutRematerialization::cleanup() {
op->erase();
}

// Look ahead at the transitive uses and see if there is a convert to an mma
// operation.
bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
SmallVector<Value> queue = {op->getResult(0)};
SetVector<Operation *> forwardSlice;
llvm::SmallDenseSet<Value> seen;
while (!queue.empty()) {
Value currentValue = queue.back();
queue.pop_back();
getForwardSlice(currentValue, &forwardSlice);
for (Operation *op : forwardSlice) {
// HACK: Stop propagation if the ReduceOp is using mma layout but is
// producing a tensor smaller than the layout we would like to propagate.
// This is to avoid stepping into a known bug.
if (isa<mlir::triton::ReduceOp>(op)) {
auto tensorType =
dyn_cast<RankedTensorType>(op->getOperand(0).getType());
if (tensorType &&
isa<NvidiaMmaEncodingAttr>(tensorType.getEncoding())) {
auto mmaInstrShape =
cast<NvidiaMmaEncodingAttr>(encoding).getInstrShape();
if (tensorType.getShape()[tensorType.getRank() - 2] <
mmaInstrShape[0] ||
tensorType.getShape()[tensorType.getRank() - 1] <
mmaInstrShape[1]) {
return false;
}
}
}

if (auto convertOp = dyn_cast<ConvertLayoutOp>(op)) {
Attribute dstEncoding = convertOp.getType().getEncoding();
if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(dstEncoding))
return (mmaLayout.getVersionMajor() > 1) ? true
: mmaLayout == encoding;
if (isa<triton::gpu::AMDMfmaEncodingAttr,
triton::gpu::AMDWmmaEncodingAttr>(dstEncoding))
return true;
if (isa<triton::gpu::DotOperandEncodingAttr>(dstEncoding)) {
if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(encoding)) {
return mmaLayout.getVersionMajor() > 1;
} else {
assert((mlir::isa<triton::gpu::AMDMfmaEncodingAttr,
triton::gpu::AMDWmmaEncodingAttr>(encoding)));
return true;
}
}
}
bool isMMAV3 =
isa<NvidiaMmaEncodingAttr>(encoding) &&
cast<NvidiaMmaEncodingAttr>(encoding).getVersionMajor() == 3;
if (isMMAV3 && (isa<LocalAllocOp>(op) || isa<LocalStoreOp>(op)))
return true;
auto yield = dyn_cast<scf::YieldOp>(op);
if (!yield)
continue;
if (auto ifOp = dyn_cast<scf::IfOp>(yield->getParentOp())) {
for (OpOperand &operand : yield->getOpOperands()) {
Operation *def = operand.get().getDefiningOp();
if (def &&
(forwardSlice.count(def) || operand.get() == currentValue) &&
(seen.insert(operand.get()).second == true))
queue.push_back(ifOp.getResult(operand.getOperandNumber()));
}
}
auto forOp = dyn_cast<scf::ForOp>(yield.getOperation()->getParentOp());
if (!forOp)
continue;
for (OpOperand &operand : yield->getOpOperands()) {
Operation *def = operand.get().getDefiningOp();
if (def && (forwardSlice.count(def) || operand.get() == currentValue) &&
(seen.insert(operand.get()).second == true))
queue.push_back(forOp.getRegionIterArg(operand.getOperandNumber()));
}
}
}
return false;
}

// Return true if the op is an op with a layout we don't want to change. We will
// propagate the layout starting from anchor ops.
bool isLayoutAnchor(Operation *op) {
@@ -262,18 +183,8 @@ bool isLayoutAnchor(Operation *op) {
}

void LayoutPropagation::initAnchorLayout() {
auto maybeAddAnchor = [&](Value v) {
auto addAnchor = [&](Value v) {
if (auto tensorType = dyn_cast<RankedTensorType>(v.getType())) {
// Workaround, don't propagate MMA layout unless there is a convert
// back to mma further down to avoid generating reduction with MMA
// layout that may have lower performance.
// This can be improved with more aggressive backward propagation.
if (isa<MmaEncodingTrait>(tensorType.getEncoding()) &&
v.getDefiningOp() &&
!hasConvertToMMATransisitiveUse(v.getDefiningOp(),
tensorType.getEncoding())) {
return;
}
layouts.insert({v, LayoutInfo(tensorType.getEncoding())});
}
};
@@ -282,13 +193,13 @@ void LayoutPropagation::initAnchorLayout() {
// you can pass a tensor with an encoding as an arg, instead of explicitly
// calling tt.load.
for (auto arg : funcOp.getArguments()) {
maybeAddAnchor(arg);
addAnchor(arg);
}

funcOp.walk([&](Operation *op) {
if (isLayoutAnchor(op)) {
for (auto result : op->getResults()) {
maybeAddAnchor(result);
addAnchor(result);
}
}
});
15 changes: 0 additions & 15 deletions python/test/unit/language/test_core.py
@@ -3222,21 +3222,6 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), y_tri, y_tri.stride(0), y_tri.stride(1), w_tri,
w_tri.stride(0), w_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), **kern_kwargs)

if epilogue == 'softmax' and (in_dtype != 'float32' or input_precision == "tf32"):
if not is_cuda():
pass
else:
ptx = pgm.asm["ptx"]
start = ptx.find("shfl.sync.bfly")
end = ptx.find("cvt.rn.f16.f32")
red_code = ptx[start:end]
assert len(red_code) > 0

# skip this check on hopper because there are some functions whose names contain "shared" in ptx.
# TODO: we should eliminate these unused functions in ptx code.
if not (capability[0] >= 9):
assert "shared" not in red_code
assert "bar.sync" not in red_code
# torch result
if in_dtype == 'int8':
z_ref = np.matmul(x.astype(np.float32), y.astype(np.float32())).astype(np.int32)
42 changes: 42 additions & 0 deletions test/TritonGPU/combine.mlir
@@ -2607,3 +2607,45 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
tt.return %outLHS : tensor<128x64xf32, #blocked1>
}
}

// -----

#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#CL = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>

module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} {
// CHECK-LABEL: matmul_add
tt.func @matmul_add(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %C : !tt.ptr<f32>) {
%a_ptr_init = tt.splat %A : !tt.ptr<f16> -> tensor<128x32x!tt.ptr<f16>, #AL>
%b_ptr_init = tt.splat %B : !tt.ptr<f16> -> tensor<32x128x!tt.ptr<f16>, #BL>
%c_ptr_init = tt.splat %C : !tt.ptr<f32> -> tensor<128x128x!tt.ptr<f32>, #CL>
%c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #CL>
%cst = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
%a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
%b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>

%100:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #CL>) {
%a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr<f16>, #AL>
%a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT>
%b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr<f16>, #BL>
%b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT>
%c = tt.dot %a, %b, %cst : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
%t = triton_gpu.convert_layout %c : tensor<128x128xf32, #C> -> tensor<128x128xf32, #CL>
// CHECK: %[[T0:.*]] = tt.dot
// CHECK: arith.addf %{{.*}}, %[[T0]] : tensor<128x128xf32, #mma>
%t2 = arith.addf %prev_c, %t : tensor<128x128xf32, #CL>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
// CHECK: scf.yield
scf.yield %next_a_ptr, %next_b_ptr, %t2 : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #CL>
}

// CHECK: triton_gpu.convert_layout {{.*}} : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked
tt.store %c_ptr_init, %100#2 : tensor<128x128x!tt.ptr<f32>, #CL>
tt.return
}
}
