diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
index b6d855a05388..cee1ae84ef59 100644
--- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
@@ -163,85 +163,6 @@ void LayoutRematerialization::cleanup() {
     op->erase();
 }
 
-// Look ahead to at the transitive uses and see if there is a convert to mma
-// operations.
-bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
-  SmallVector<Value> queue = {op->getResult(0)};
-  SetVector<Operation *> forwardSlice;
-  llvm::SmallDenseSet<Value> seen;
-  while (!queue.empty()) {
-    Value currentValue = queue.back();
-    queue.pop_back();
-    getForwardSlice(currentValue, &forwardSlice);
-    for (Operation *op : forwardSlice) {
-      // HACK: Stop propagation if the ReduceOp is using mma layout but is
-      // producing tensor smaller than the layout we would like to propagate.
-      // This is to avoid stepping into the known bug.
-      if (isa<triton::ReduceOp>(op)) {
-        auto tensorType =
-            dyn_cast<RankedTensorType>(op->getOperand(0).getType());
-        if (tensorType &&
-            isa<NvidiaMmaEncodingAttr>(tensorType.getEncoding())) {
-          auto mmaInstrShape =
-              cast<NvidiaMmaEncodingAttr>(encoding).getInstrShape();
-          if (tensorType.getShape()[tensorType.getRank() - 2] <
-                  mmaInstrShape[0] ||
-              tensorType.getShape()[tensorType.getRank() - 1] <
-                  mmaInstrShape[1]) {
-            return false;
-          }
-        }
-      }
-
-      if (auto convertOp = dyn_cast<ConvertLayoutOp>(op)) {
-        Attribute dstEncoding = convertOp.getType().getEncoding();
-        if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(dstEncoding))
-          return (mmaLayout.getVersionMajor() > 1) ? true
-                                                   : mmaLayout == encoding;
-        if (isa<triton::gpu::AMDMfmaEncodingAttr,
-                triton::gpu::AMDWmmaEncodingAttr>(dstEncoding))
-          return true;
-        if (isa<triton::gpu::DotOperandEncodingAttr>(dstEncoding)) {
-          if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(encoding)) {
-            return mmaLayout.getVersionMajor() > 1;
-          } else {
-            assert((mlir::isa<triton::gpu::AMDMfmaEncodingAttr,
-                              triton::gpu::AMDWmmaEncodingAttr>(encoding)));
-            return true;
-          }
-        }
-      }
-      bool isMMAV3 =
-          isa<NvidiaMmaEncodingAttr>(encoding) &&
-          cast<NvidiaMmaEncodingAttr>(encoding).getVersionMajor() == 3;
-      if (isMMAV3 && (isa<LocalAllocOp>(op) || isa<LocalStoreOp>(op)))
-        return true;
-      auto yield = dyn_cast<scf::YieldOp>(op);
-      if (!yield)
-        continue;
-      if (auto ifOp = dyn_cast<scf::IfOp>(yield->getParentOp())) {
-        for (OpOperand &operand : yield->getOpOperands()) {
-          Operation *def = operand.get().getDefiningOp();
-          if (def &&
-              (forwardSlice.count(def) || operand.get() == currentValue) &&
-              (seen.insert(operand.get()).second == true))
-            queue.push_back(ifOp.getResult(operand.getOperandNumber()));
-        }
-      }
-      auto forOp = dyn_cast<scf::ForOp>(yield.getOperation()->getParentOp());
-      if (!forOp)
-        continue;
-      for (OpOperand &operand : yield->getOpOperands()) {
-        Operation *def = operand.get().getDefiningOp();
-        if (def && (forwardSlice.count(def) || operand.get() == currentValue) &&
-            (seen.insert(operand.get()).second == true))
-          queue.push_back(forOp.getRegionIterArg(operand.getOperandNumber()));
-      }
-    }
-  }
-  return false;
-}
-
 // Return true if the op is an op with a layout we don't want to change. We will
 // propagate the layout starting from anchor ops.
 bool isLayoutAnchor(Operation *op) {
@@ -262,18 +183,8 @@ bool isLayoutAnchor(Operation *op) {
 }
 
 void LayoutPropagation::initAnchorLayout() {
-  auto maybeAddAnchor = [&](Value v) {
+  auto addAnchor = [&](Value v) {
     if (auto tensorType = dyn_cast<RankedTensorType>(v.getType())) {
-      // Workaround, don't popagate MMA layout unless there is a convert
-      // back to mma further down to avoid generating reduction with MMA
-      // layout that may have lower performance.
-      // This can be improved with more aggressive backward propagation.
-      if (isa<MmaEncodingTrait>(tensorType.getEncoding()) &&
-          v.getDefiningOp() &&
-          !hasConvertToMMATransisitiveUse(v.getDefiningOp(),
-                                          tensorType.getEncoding())) {
-        return;
-      }
       layouts.insert({v, LayoutInfo(tensorType.getEncoding())});
     }
   };
@@ -282,13 +193,13 @@ void LayoutPropagation::initAnchorLayout() {
   // you can pass a tensor with an encoding as an arg, instead of explicitly
   // calling tt.load.
   for (auto arg : funcOp.getArguments()) {
-    maybeAddAnchor(arg);
+    addAnchor(arg);
   }
 
   funcOp.walk([&](Operation *op) {
     if (isLayoutAnchor(op)) {
       for (auto result : op->getResults()) {
-        maybeAddAnchor(result);
+        addAnchor(result);
       }
     }
   });
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index cac75271e19f..3013bbf53177 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -3222,21 +3222,6 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
     pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), y_tri, y_tri.stride(0), y_tri.stride(1), w_tri,
                          w_tri.stride(0), w_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1),
                          **kern_kwargs)
-    if epilogue == 'softmax' and (in_dtype != 'float32' or input_precision == "tf32"):
-        if not is_cuda():
-            pass
-        else:
-            ptx = pgm.asm["ptx"]
-            start = ptx.find("shfl.sync.bfly")
-            end = ptx.find("cvt.rn.f16.f32")
-            red_code = ptx[start:end]
-            assert len(red_code) > 0
-
-            # skip this check on hopper because there are some functions whose name contain "shared" in ptx.
-            # TODO: we should eliminate these unused functions in ptx code.
-            if not (capability[0] >= 9):
-                assert "shared" not in red_code
-            assert "bar.sync" not in red_code
     # torch result
     if in_dtype == 'int8':
         z_ref = np.matmul(x.astype(np.float32), y.astype(np.float32())).astype(np.int32)
diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
index 78c6f68bf612..682c1cb3019d 100644
--- a/test/TritonGPU/combine.mlir
+++ b/test/TritonGPU/combine.mlir
@@ -2607,3 +2607,45 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return %outLHS : tensor<128x64xf32, #blocked1>
   }
 }
+
+// -----
+
+#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#CL = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
+#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>
+#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>
+
+module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} {
+  // CHECK-LABEL: matmul_add
+  tt.func @matmul_add(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %C : !tt.ptr<f32>) {
+    %a_ptr_init = tt.splat %A : !tt.ptr<f16> -> tensor<128x32x!tt.ptr<f16>, #AL>
+    %b_ptr_init = tt.splat %B : !tt.ptr<f16> -> tensor<32x128x!tt.ptr<f16>, #BL>
+    %c_ptr_init = tt.splat %C : !tt.ptr<f32> -> tensor<128x128x!tt.ptr<f32>, #CL>
+    %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #CL>
+    %cst = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
+    %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
+    %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
+
+    %100:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #CL>) {
+      %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr<f16>, #AL>
+      %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT>
+      %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr<f16>, #BL>
+      %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT>
+      %c = tt.dot %a, %b, %cst : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
+      %t = triton_gpu.convert_layout %c : tensor<128x128xf32, #C> -> tensor<128x128xf32, #CL>
+      // CHECK: %[[T0:.*]] = tt.dot
+      // CHECK: arith.addf %{{.*}}, %[[T0]] : tensor<128x128xf32, #mma>
+      %t2 = arith.addf %prev_c, %t : tensor<128x128xf32, #CL>
+      %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
+      %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
+      // CHECK: scf.yield
+      scf.yield %next_a_ptr, %next_b_ptr, %t2 : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #CL>
+    }
+
+    // CHECK: triton_gpu.convert_layout {{.*}} : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked
+    tt.store %c_ptr_init, %100#2 : tensor<128x128x!tt.ptr<f32>, #CL>
+    tt.return
+  }
+}