[BACKEND] Fix transpose optimization missed during refactor (#5226)
ThomasRaoux authored Nov 22, 2024
1 parent 82b8f0f commit 4ae95e7
Showing 2 changed files with 22 additions and 2 deletions.
lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp: 2 additions & 2 deletions
@@ -326,13 +326,13 @@ class FuseTransHopper : public OpRewritePattern<LocalAllocOp> {
       return failure();

     // Match outerCvt(trans(innerCvt(x))).
-    auto trans = allocOp.getSrc().getDefiningOp<MemDescTransOp>();
+    auto trans = allocOp.getSrc().getDefiningOp<TransOp>();
     if (!trans || trans.getOrder() != ArrayRef<int32_t>({1, 0}))
       return failure();

     MemDescType allocType = allocOp.getType();
     auto allocEncoding = cast<SharedEncodingAttr>(allocType.getEncoding());
-    MemDescType srcTy = trans.getSrc().getType();
+    RankedTensorType srcTy = trans.getSrc().getType();

     // MMAv3 with transpose only supports f16 and bf16. Fall back to MMAv3
     // without transpose for other data types.
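The effect of the two-line change: after the earlier refactor, FuseTransHopper was looking for a memdesc-level transpose (MemDescTransOp), but at this point in the pipeline the transpose still sits on the tensor feeding the local_alloc, so the pattern never matched and the optimization was silently skipped. The sketch below is assembled from the test added in this commit: the "before" IR is taken from the test body, the op sequence in the "after" IR comes from its CHECK lines, and #shared_t plus the exact result types after the rewrite are illustrative assumptions, not taken from the commit.

// Before the pass: a tensor-level tt.trans feeds the shared-memory allocation.
%a     = tt.trans %t {order = array<i32: 1, 0>} : tensor<64x128xf16, #blocked1> -> tensor<128x64xf16, #blocked>
%dota  = triton_gpu.local_alloc %a : (tensor<128x64xf16, #blocked>) -> !triton_gpu.memdesc<128x64xf16, #shared1>
%r     = triton_nvidia_gpu.warp_group_dot %dota, %dotb, %dotc : ...

// After the pass (op names per the CHECK lines; #shared_t and the types are illustrative):
// the untransposed tensor is allocated directly and the transpose is folded into a
// memdesc_trans on the shared-memory view consumed by the MMA.
%alloc = triton_gpu.local_alloc %t : (tensor<64x128xf16, #blocked1>) -> !triton_gpu.memdesc<64x128xf16, #shared_t>
%dota  = triton_gpu.memdesc_trans %alloc {order = array<i32: 1, 0>} : !triton_gpu.memdesc<64x128xf16, #shared_t> -> !triton_gpu.memdesc<128x64xf16, #shared1>
%r     = triton_nvidia_gpu.warp_group_dot %dota, %dotb, %dotc : ...

As the surrounding comment notes, this transposed MMAv3 path only applies to f16 and bf16 operands; other element types fall back to MMAv3 without the fused transpose.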
test/TritonGPU/dot-operands.mlir: 20 additions & 0 deletions
@@ -256,3 +256,23 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
     tt.return %r : tensor<128x64xf32, #mma>
   }
 }
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
+#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
+#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
+module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+// CHECK-LABEL: mma_reorder_transpose
+// CHECK: triton_gpu.local_alloc
+// CHECK: triton_gpu.memdesc_trans
+// CHECK: triton_nvidia_gpu.warp_group_dot
+  tt.func @mma_reorder_transpose(%t: tensor<64x128xf16, #blocked1>, %dotb: !triton_gpu.memdesc<64x64xf16, #shared>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{
+    %a = tt.trans %t {order = array<i32: 1, 0>} : tensor<64x128xf16, #blocked1> -> tensor<128x64xf16, #blocked>
+    %dota = triton_gpu.local_alloc %a: (tensor<128x64xf16, #blocked>) -> !triton_gpu.memdesc<128x64xf16, #shared1>
+    %r = triton_nvidia_gpu.warp_group_dot %dota, %dotb, %dotc : !triton_gpu.memdesc<128x64xf16, #shared1> * !triton_gpu.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma>
+    tt.return %r : tensor<128x64xf32, #mma>
+  }
+}
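For local verification, the tests in this file are lit/FileCheck driven. The RUN line sits at the top of dot-operands.mlir and is not part of this hunk, so the invocation below is an assumption, with the pass name taken from the pass that hosts FuseTransHopper:

// Assumed RUN line; see the header of test/TritonGPU/dot-operands.mlir for the real one.
// RUN: triton-opt %s -split-input-file -tritongpu-optimize-dot-operands | FileCheck %s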
