Improve robustness of MMA layout propagation to tt.store with block pointer (triton-lang#1272)

Adds handling for a new pattern in MMA layout propagation: when the
ConvertLayoutOp is still inside the loop, the layout is retrieved from the
layout map instead of from the ConvertLayoutOp.

Addresses Issue: triton-lang#1271
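
To illustrate the pattern being handled, here is a distilled sketch modeled on Case 4 of the test below (the function name, SSA value names, and simplified loop body are illustrative, not from the commit): the conversion out of the DPAS layout still sits inside the scf.for loop, so the value reaching tt.store is a loop result rather than the result of a ConvertLayoutOp, and the pass must recover the DPAS encoding from its map of processed layouts.

  #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
  #dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], A = [8, 16], B = [16, 16], C = [8, 16]}>
  module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
    tt.func public @store_after_loop(%out: !tt.ptr<f16>, %ub: i32) {
      %c0_i32 = arith.constant 0 : i32
      %c1_i32 = arith.constant 1 : i32
      %c0_i64 = arith.constant 0 : i64
      %c1_i64 = arith.constant 1 : i64
      %cst = arith.constant dense<0.000000e+00> : tensor<64x256xf32, #blocked1>
      // The conversion out of #dpas is still inside the loop ...
      %res = scf.for %iv = %c0_i32 to %ub step %c1_i32 iter_args(%acc = %cst) -> (tensor<64x256xf32, #blocked1>) : i32 {
        %0 = triton_gpu.convert_layout %acc : tensor<64x256xf32, #blocked1> -> tensor<64x256xf32, #dpas>
        %1 = triton_gpu.convert_layout %0 : tensor<64x256xf32, #dpas> -> tensor<64x256xf32, #blocked1>
        scf.yield %1 : tensor<64x256xf32, #blocked1>
      }
      // ... so the stored value is defined by the loop, not by a ConvertLayoutOp;
      // rewriteStoreOp now looks the layout up in the map of processed layouts.
      %f16 = arith.truncf %res : tensor<64x256xf32, #blocked1> to tensor<64x256xf16, #blocked1>
      %ptr = tt.make_tensor_ptr %out, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #blocked1>>
      tt.store %ptr, %f16 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #blocked1>>
      tt.return
    }
  }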

---------

Signed-off-by: Maxime France-Pillois <maxime.francepillois@codeplay.com>
mfrancepillois authored Jun 12, 2024
1 parent 93d168c commit c674beb
Showing 2 changed files with 107 additions and 39 deletions.
test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir (45 additions, 3 deletions)
@@ -12,7 +12,7 @@
 #dot0 = #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth=2}>
 #dot1 = #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
 module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
-  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i64) {
+  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i64, %arg7: i32, %arg8: i64) {
     %c8_i32 = arith.constant 8 : i32
     %c64_i32 = arith.constant 64 : i32
     %c1_i64 = arith.constant 1 : i64
@@ -39,9 +39,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
     %14 = arith.muli %11, %c64_i32 : i32
     %15 = arith.extsi %arg3 : i32 to i64
     %16 = arith.extsi %arg5 : i32 to i64
-    %17 = arith.extsi %arg6 : i32 to i64
     // CHECK: %[[VAL_36:.*]] = tt.make_tensor_ptr %{{.*}}, {{\[}}%{{.*}}, %{{.*}}], {{\[}}%{{.*}}, %{{.*}}], {{\[}}%{{.*}}, %{{.*}}] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
-    %18 = tt.make_tensor_ptr %arg0, [%15, %16], [%17, %c1_i64], [%14, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
+    %18 = tt.make_tensor_ptr %arg0, [%15, %16], [%arg6, %c1_i64], [%14, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
     %19 = arith.muli %13, %c256_i32 : i32
     %20 = arith.extsi %arg4 : i32 to i64
     %21 = arith.extsi %arg7 : i32 to i64
@@ -221,3 +220,46 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return
   }
 }
+
+
+// -----
+
+// COM: Case 4:
+// COM: Checks that DPAS encoding has been forwarded to the store op
+// COM: and the triton_gpu.convert_layout operation in the loop has been removed
+// CHECK: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], A = [8, 16], B = [16, 16], C = [8, 16]}>
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
+#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], A = [8, 16], B = [16, 16], C = [8, 16]}>
+#dot0 = #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth=2}>
+#dot1 = #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %c0_i64 = arith.constant 0 : i64
+    %c32_i32 = arith.constant 32 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<64x256xf32, #blocked1>
+    %18 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
+    %22 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
+    %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #blocked1>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>) : i32 {
+      %28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>
+      %29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #blocked1>>
+      %36 = triton_gpu.convert_layout %arg10 : tensor<64x256xf32, #blocked1> -> tensor<64x256xf32, #dpas>
+      %30 = triton_gpu.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
+      %31 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
+      %32 = tt.dot %30, %31, %36, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
+      %33 = tt.advance %arg11, [%c0_i32, %c32_i32] : <tensor<64x32xf16, #blocked>>
+      %34 = tt.advance %arg12, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #blocked1>>
+      // CHECK-NOT: triton_gpu.convert_layout
+      %35 = triton_gpu.convert_layout %32 : tensor<64x256xf32, #dpas> -> tensor<64x256xf32, #blocked1>
+      scf.yield %35, %33, %34 : tensor<64x256xf32, #blocked1>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>
+    }
+    %24 = arith.truncf %23#0 : tensor<64x256xf32, #blocked1> to tensor<64x256xf16, #blocked1>
+    // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[DPAS]]>>
+    %27 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #blocked1>>
+    // CHECK: tt.store {{.*}}, {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #[[DPAS]]>>
+    tt.store %27, %24 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #blocked1>>
+    tt.return
+  }
+}
LayoutPropagation source file (62 additions, 36 deletions)
@@ -805,54 +805,80 @@ bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) {
 
   // 2D block stores are preceded by a MakeTensorPtrOp.
   auto makeTensorPtrOp = ptr.getDefiningOp<MakeTensorPtrOp>();
+  if (!makeTensorPtrOp)
+    return false;
 
   // The DPAS encoding has to be propagated if a conversion from DPAS to
   // another layout has been done before.
   auto convertOp = storeOp.getValue().getDefiningOp<ConvertLayoutOp>();
-  if (!convertOp || !makeTensorPtrOp)
-    return false;
+  PointerType newPtrType;
+  Attribute encoding;
+  Value value;
+  if (!convertOp) {
+    // If the defining op is not a ConvertLayoutOp, the conversion has not
+    // been hoisted out of the loop yet. We then try to find the layout in
+    // the map of the processed layouts.
+    value = storeOp.getValue();
+    auto it = layouts.find(value);
+    if (it == layouts.end())
+      return false;
+
+    encoding = *(it->second.encodings.begin());
+
+    if (!isa<ttgi::DpasEncodingAttr>(encoding))
+      return false;
+
+    auto ptrType = cast<PointerType>(makeTensorPtrOp.getType());
+    auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
+
+    auto tmpType = RankedTensorType::get(tensorType.getShape(),
+                                         tensorType.getElementType(), encoding);
+    newPtrType = PointerType::get(tmpType, ptrType.getAddressSpace());
+  } else {
+    Attribute convertOpDstEncoding = convertOp.getType().getEncoding();
+    RankedTensorType convertOpSrcType = convertOp.getSrc().getType();
+    if (((!convertOpDstEncoding) ||
+         isa<ttgi::DpasEncodingAttr>(convertOpDstEncoding)) ||
+        (!convertOpSrcType ||
+         !isa<ttgi::DpasEncodingAttr>(convertOpSrcType.getEncoding())))
+      return false;
 
-  Attribute convertOpDstEncoding = convertOp.getType().getEncoding();
-  RankedTensorType convertOpSrcType = convertOp.getSrc().getType();
-  if ((convertOpDstEncoding &&
-       !isa<ttgi::DpasEncodingAttr>(convertOpDstEncoding)) &&
-      (convertOpSrcType &&
-       isa<ttgi::DpasEncodingAttr>(convertOpSrcType.getEncoding()))) {
     auto ptrType = cast<PointerType>(makeTensorPtrOp.getType());
     auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
     // If the output type of the MakeTensorPtrOp already has a
     // DPAS encoding, we do not forward the previous DPAS encoding.
     if (isa<ttgi::DpasEncodingAttr>(tensorType.getEncoding()))
       return false;
 
-    auto newPtrType =
-        PointerType::get(convertOpSrcType, ptrType.getAddressSpace());
-
-    // We create a new MakeTensorPtrOp with the new data type.
-    OpBuilder rewriter(makeTensorPtrOp);
-    Value newStorePtr = rewriter.create<MakeTensorPtrOp>(
-        makeTensorPtrOp.getLoc(), newPtrType, makeTensorPtrOp.getBase(),
-        makeTensorPtrOp.getShape(), makeTensorPtrOp.getStrides(),
-        makeTensorPtrOp.getOffsets(), rewriter.getDenseI32ArrayAttr({1, 0}));
-
-    // The encoding of the StoreOp is updated with the new
-    // operands:
-    // - the Ptr created by the MakeTensorPtrOp with the new data
-    // type
-    // - the forwarded DPAS encoding.
-    Value newOperand =
-        getValueAs(convertOp.getSrc(), convertOpSrcType.getEncoding());
-    storeOp.setOperand(0, newStorePtr);
-    storeOp.setOperand(1, newOperand);
-
-    // If the DPAS encoding is forwarded, we do not need the
-    // convertOp anymore if the convertOp was only used by the
-    // storeOp. Same for the initial MakeTensorPtrOp, if it was
-    // only used by the storeOp. If this is the case, these
-    // instructions are removed by the clean-up step performed at
-    // the end of this pass (step 4).
-    return true;
+    newPtrType = PointerType::get(convertOpSrcType, ptrType.getAddressSpace());
+
+    value = convertOp.getSrc();
+    encoding = convertOpSrcType.getEncoding();
   }
-  return false;
+
+  // We create a new MakeTensorPtrOp with the new data type.
+  OpBuilder rewriter(makeTensorPtrOp);
+  Value newStorePtr = rewriter.create<MakeTensorPtrOp>(
+      makeTensorPtrOp.getLoc(), newPtrType, makeTensorPtrOp.getBase(),
+      makeTensorPtrOp.getShape(), makeTensorPtrOp.getStrides(),
+      makeTensorPtrOp.getOffsets(), rewriter.getDenseI32ArrayAttr({1, 0}));
+
+  // The encoding of the StoreOp is updated with the new
+  // operands:
+  // - the Ptr created by the MakeTensorPtrOp with the new data
+  // type
+  // - the forwarded DPAS encoding.
+  Value newOperand = getValueAs(value, encoding);
+  storeOp.setOperand(0, newStorePtr);
+  storeOp.setOperand(1, newOperand);
+
+  // If the DPAS encoding is forwarded, we do not need the
+  // convertOp anymore if the convertOp was only used by the
+  // storeOp. Same for the initial MakeTensorPtrOp, if it was
+  // only used by the storeOp. If this is the case, these
+  // instructions are removed by the clean-up step performed at
+  // the end of this pass (step 4).
+  return true;
 }
 
 Operation *LayoutPropagation::rewriteOp(Operation *op) {
