[WIP] Use DotOp layout for UpcastMXFPOp Lowering #3057

Draft: wants to merge 3 commits into base: main
2 changes: 1 addition & 1 deletion python/src/ir.cc
@@ -1683,7 +1683,7 @@ void init_triton_ir(py::module &&m) {
if (haveDump) {
auto printingFlags = OpPrintingFlags();
printingFlags.elideLargeElementsAttrs(16);
printingFlags.enableDebugInfo();
// printingFlags.enableDebugInfo();
auto printAlways = [funcToDump](Pass *, Operation *op) -> bool {
if (funcToDump.empty())
return true;
6 changes: 2 additions & 4 deletions python/test/unit/language/test_core.py
@@ -3486,10 +3486,8 @@ def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, nu
if mma == 16 and K == 64:
pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
if is_xpu():
if M == 128 and N == 128 and K == 64 and not col_a and not col_b and rhs_scale and normal_type == "e4m3" and mxfp_type == "bf16":
pytest.skip(
f"FIXME: {M}x{N}x{K} col_a={col_a} col_b={col_b} rhs_scale={rhs_scale} normal_type={normal_type} mxfp_type={mxfp_type}"
)
if 'e2m1' in (normal_type, mxfp_type):
pytest.skip("e2m1 dot-layout not supported on XPU")

@triton.jit
def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out,
2 changes: 1 addition & 1 deletion third_party/intel/backend/arch_parser.c
@@ -31,7 +31,7 @@ static PyObject *parseDeviceArch(PyObject *self, PyObject *args) {
arch = "lnl";
break;
default:
printf("sycl_arch = %d", sycl_arch);
printf("sycl_arch = %d\n", sycl_arch);
}

return Py_BuildValue("s", arch.c_str());
88 changes: 83 additions & 5 deletions third_party/intel/lib/TritonIntelGPUToLLVM/UpcastMXFPToLLVM.cpp
@@ -1,4 +1,5 @@
#include "PatternTritonGPUOpToLLVM.h"
#include <iostream>

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/IR/BuiltinOps.h"
@@ -29,14 +30,15 @@ static Value mxfpScaleBf16(ConversionPatternRewriter &rewriter, Location loc,
auto undefRounding = static_cast<mlir::triton::RoundingMode>(-1);
Value scaledBf16 = mlir::triton::intel::convertFp32ToBf16(
loc, rewriter, result, undefRounding);
// Value scaledBf16 = fmul(vBf16, scaleBf16);
// Account for NaN in the scale as per the mxfp specification.
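  // Assumption (per the OCP MX spec): an E8M0 scale with all bits set (0xFF)
  // encodes NaN, which is what scaleIsNan tests for.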
return select(scaleIsNan, nanBf16, scaledBf16);
};

class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
private:
const TargetInfoBase &targetInfo;
const bool upcastMXFPUseDotOpEnc =
mlir::triton::tools::getBoolEnv("TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING");

public:
UpcastMXFPOpPattern(LLVMTypeConverter &typeConverter,
@@ -60,13 +62,89 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
Value warpId = udiv(tid, warpSize);
Value laneId = urem(tid, warpSize);

    // TODO: check that the Dot/DPAS term mapping is correct.
auto xType = cast<RankedTensorType>(op->getOperandTypes()[0]);
auto dotEnc = cast<DotOperandEncodingAttr>(xType.getEncoding());
auto dpasEnc = cast<DpasEncodingAttr>(dotEnc.getParent());

std::cout << "xVals size before fp4 -> fp8: " << xVals.size() << std::endl;
if (fpType == ScaleDotElemType::E2M1)
xVals = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals);
std::cout << "xVals size after fp4 -> fp8: " << xVals.size() << std::endl;
std::cout << "scaledVals size: " << scaleVals.size() << std::endl;

    // TODO: refactor this logic; the DPAS concepts do not map onto the loops
    // correctly yet.
if (upcastMXFPUseDotOpEnc) {
assert(dotEnc.getOpIdx() == 0 && "NYI: rhs scale with dot encoding");
      // FIXME: This doc is not completely correct for the Intel DPAS layout.
      // For Intel GPU PVC, each thread owns elements of 16 mxfp vectors, so
      // we need 16 scales. Since we go from a threadShape of 2x16 to 32x1,
      // we let c = tid / 16.
unsigned instShapeM = dpasEnc.getDPASInstShapeA()[0]; // 8
unsigned instShapeK =
dpasEnc.getDPASInstShapeA()[1]; // 16 for bf16, 32 for e2m1
unsigned scalingBlockSize = 32;
unsigned repSize =
scalingBlockSize / instShapeK; // 2 for bf16, 1 for e2m1
unsigned subTileSize = instShapeM;
      unsigned stepSize =
          dpasEnc.getOpsPerChannel() / 2; // TODO: check this definition
      unsigned numMxfp =
          TritonGPUDialect::getThreadsPerWarp(mod) / instShapeM;
unsigned mxfpSize = repSize * subTileSize * stepSize;
unsigned numScales = 16;
// 2 fp4 are packed in one i8
if (fpType == ScaleDotElemType::E2M1) {
numMxfp /= 2;
numScales /= 2;
stepSize *= 2;
}
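      // Worked example (a sketch, assuming a 16-lane warp and PVC-like DPAS
      // shapes, i.e. getDPASInstShapeA() == {8, 16} for bf16 and {8, 32} for
      // e2m1):
      //   bf16: instShapeK = 16, repSize = 2, stepSize = 1, numMxfp = 2,
      //         mxfpSize = 16, numScales = 16
      //   e2m1: instShapeK = 32, repSize = 1, stepSize = 2 -> 4,
      //         numMxfp = 2 -> 1, numScales = 16 -> 8, mxfpSize = 16
      // In both cases each scale then covers one 32-element scaling block of
      // xVals, matching scalingBlockSize.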

std::cout << "ci: ";
Value c = udiv(laneId, i32_val(numScales));
SmallVector<Value, 16> ci;
for (int row = 0; row < numMxfp; ++row)
for (int col = 0; col < subTileSize; ++col) {
ci.emplace_back(add(c, i32_val(row + 2 * col)));
std::cout << " " << row + 2 * col;
}
std::cout << std::endl;

std::cout << "repSize: " << repSize << " subTileSize: " << subTileSize
<< " stepSize: " << stepSize << " numMxfp: " << numMxfp
<< " mxfpSize: " << mxfpSize << std::endl;
for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) {
for (int mxfp = 0; mxfp < numMxfp; ++mxfp) {
SmallVector<Value, 8> si;
std::cout << "si: ";
for (int subTile = 0; subTile < 8; ++subTile) {
si.emplace_back(targetInfo.shuffleIdx(rewriter, loc, scaleVal,
ci[8 * mxfp + subTile]));
std::cout << " " << 8 * mxfp + subTile;
}
std::cout << "\nIdx: ";
for (int rep = 0; rep < repSize; ++rep)
for (int subTile = 0; subTile < subTileSize; ++subTile) {
std::cout << " Subtile(" << subTile << ") ";
for (int k = 0; k < stepSize; ++k) {
unsigned idx = i * scalingBlockSize + mxfp * mxfpSize +
rep * subTileSize + subTile * stepSize + k;
std::cout << " " << idx;
xVals[idx] =
mxfpScaleBf16(rewriter, loc, xVals[idx], si[subTile]);
}
}
std::cout << std::endl;
}
}
} else {
for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) {
for (int j = 0; j < 32; ++j) {
xVals[32 * i + j] =
mxfpScaleBf16(rewriter, loc, xVals[32 * i + j], scaleVal);
}
}
}

152 changes: 151 additions & 1 deletion third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
@@ -312,9 +312,14 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
createArg(opDesc.op, opDesc.elemType, newOpEncoding, rewriter);

unsigned warpSize = ttg::TritonGPUDialect::getThreadsPerWarp(mod);
unsigned instrShapeM = dpasEnc.getDPASInstShapeA()[1];
unsigned repeatCount = dpasEnc.getRepeatCount();
unsigned instrShapeM = dpasEnc.getDPASInstShapeA()[0];
SmallVector<unsigned, 2> threadsPerWarp{instrShapeM,
warpSize / instrShapeM};
// auto scaleTy = cast<RankedTensorType>(opDesc.scale.getType());
// unsigned scalingBlocks = scaleTy.getShape()[1];
// SmallVector<unsigned, 2> threadsPerWarp = {repeatCount, warpSize /
// repeatCount};
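      // E.g. with a 16-lane warp and instrShapeM == 8 (a hypothetical
      // PVC-like configuration), this yields an 8x2 thread shape per warp.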
SmallVector<unsigned, 2> warpsPerCTA(rank, 1);
warpsPerCTA[0] = numWarps;
auto CTALayout = ttg::getCTALayout(retType.getEncoding());
@@ -531,6 +536,149 @@ static void decomposeMixedModeDotOp(ModuleOp mod) {
});
}

static void updateValueType(Value v, Attribute encoding,
ArrayRef<int64_t> shape) {
auto tensorType = cast<RankedTensorType>(v.getType());
auto newType =
RankedTensorType::get(shape, tensorType.getElementType(), encoding);
v.setType(newType);
}

static tt::TransOp updateUsers(Value result,
const SetVector<Operation *> &slice) {
tt::TransOp transOp;
if (llvm::any_of(result.getUsers(),
[&](Operation *user) { return slice.count(user) == 0; })) {
OpBuilder builder(result.getContext());
builder.setInsertionPointAfterValue(result);
transOp =
builder.create<tt::TransOp>(result.getLoc(), result, ArrayRef({1, 0}));
result.replaceUsesWithIf(transOp.getResult(), [&](OpOperand &operand) {
return operand.getOwner() != transOp.getOperation() &&
slice.count(operand.getOwner()) == 0;
});
}
return transOp;
}

// Sink the transpose in the IR. This is done to avoid generating a layout
// conversion when we have a transpose right after a dot, as the MMA layout
// cannot be propagated through a transpose op. Once we have layouts that can
// represent a transposed MMA we can remove this transformation.
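// For example (a sketch): given IR like
//   %t = tt.trans %dot
//   %e = arith.negf %t
// the elementwise op is rewired to consume %dot directly, its result type is
// updated to the untransposed layout, and a new tt.trans is inserted only
// where a user outside the forward slice still needs the transposed value.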
static void sinkTransposeOp(tt::TransOp input) {
SmallVector<tt::TransOp> queue = {input};
while (!queue.empty()) {
tt::TransOp transOp = queue.back();
Value currentValue = transOp.getResult();
queue.pop_back();
mlir::ForwardSliceOptions options;
options.filter = [](Operation *op) {
if (op->hasTrait<OpTrait::Elementwise>() && op->getNumOperands() == 1)
return true;
if (isa<scf::YieldOp>(op))
return isa<scf::ForOp>(op->getParentOp());
if (isa<ttg::ConvertLayoutOp>(op))
return true;
return false;
};
SetVector<Operation *> slice;
mlir::getForwardSlice(currentValue, &slice, options);
for (Operation *op : slice) {
if (op->hasTrait<OpTrait::Elementwise>()) {
// Update users of transpose op.
if (op->getOperand(0) == transOp.getResult())
op->setOperand(0, transOp.getOperand());
// Update the type of the result.
for (Value result : op->getResults()) {
auto srcType = cast<RankedTensorType>(op->getOperand(0).getType());
updateValueType(result, srcType.getEncoding(), srcType.getShape());
updateUsers(result, slice);
}
continue;
}
if (auto cvtOp = dyn_cast<ttg::ConvertLayoutOp>(op)) {
// Update users of transpose op.
if (op->getOperand(0) == transOp.getResult())
op->setOperand(0, transOp.getOperand());
auto resultEncoding = cvtOp.getType().getEncoding();
auto newDstEncoding = ttgi::inferSrcEncoding(transOp, resultEncoding);
assert(newDstEncoding);
auto srcType = cast<RankedTensorType>(cvtOp.getOperand().getType());
updateValueType(cvtOp.getResult(), newDstEncoding, srcType.getShape());
updateUsers(cvtOp.getResult(), slice);
continue;
}
assert(isa<scf::YieldOp>(op));
auto forOp = dyn_cast<scf::ForOp>(op->getParentOp());
assert(forOp);
for (OpOperand &operand : op->getOpOperands()) {
Operation *def = operand.get().getDefiningOp();
      if ((def && slice.count(def)) || def == transOp.getOperation()) {
if (def == transOp.getOperation())
operand.set(transOp.getOperand());
Type newType = operand.get().getType();
forOp.getResult(operand.getOperandNumber()).setType(newType);
tt::TransOp retTrans =
updateUsers(forOp.getResult(operand.getOperandNumber()), slice);
// Recursively try to propagate the new transpose inserted.
if (retTrans)
queue.push_back(retTrans);
forOp.getRegionIterArg(operand.getOperandNumber()).setType(newType);
tt::TransOp argTrans = updateUsers(
forOp.getRegionIterArg(operand.getOperandNumber()), slice);
if (argTrans)
queue.push_back(argTrans);
OpBuilder builder(forOp);
OpOperand &init = forOp.getInitsMutable()[operand.getOperandNumber()];
Value initTranspose = builder.create<tt::TransOp>(
forOp.getLoc(), init.get(), ArrayRef({1, 0}));
init.set(initTranspose);
}
}
}
}
}

// Transpose scaled_dot ops that have a scale on lhs.
static Operation *transposeDotOp(tt::DotScaledOp dotOp) {
OpBuilder builder(dotOp);
Value lhs = dotOp.getLhs();
std::array<int, 2> transOrder = {1, 0};
Value lhsTransposed =
builder.create<tt::TransOp>(lhs.getLoc(), lhs, transOrder);
Value rhs = dotOp.getRhs();
Value rhsTransposed =
builder.create<tt::TransOp>(rhs.getLoc(), rhs, transOrder);
Value c = dotOp.getC();
Value cTransposed = builder.create<tt::TransOp>(c.getLoc(), c, transOrder);
Value result = builder.create<tt::DotScaledOp>(
dotOp.getLoc(), cTransposed.getType(), rhsTransposed, lhsTransposed,
cTransposed, dotOp.getRhsScale(), dotOp.getLhsScale(), dotOp.getRhsType(),
dotOp.getLhsType());
Operation *transposedResult =
builder.create<tt::TransOp>(result.getLoc(), result, transOrder);
dotOp.replaceAllUsesWith(transposedResult);
dotOp.erase();
return transposedResult;
}
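// The rewrite above relies on the identity C = A * B <=> C^T = B^T * A^T:
// swapping and transposing the operands (and their scales) turns a dot with
// an rhs scale into one with an lhs scale, and the trailing transpose
// restores the original orientation for the op's users.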

static void transposeDots(ModuleOp m) {
SmallVector<tt::DotScaledOp> toTranspose;
m.walk([&](tt::DotScaledOp dotOp) -> void {
if (dotOp.getLhsScale() == nullptr && dotOp.getRhsScale() != nullptr)
toTranspose.push_back(dotOp);
});
SmallVector<Operation *> transposes;
for (tt::DotScaledOp dotOp : toTranspose) {
Operation *transpose = transposeDotOp(dotOp);
transposes.push_back(transpose);
}

for (Operation *transpose : transposes) {
sinkTransposeOp(cast<tt::TransOp>(transpose));
}
}

class TritonIntelGPUAccelerateMatmulPass
: public triton::gpu::intel::impl::TritonIntelGPUAccelerateMatmulBase<
TritonIntelGPUAccelerateMatmulPass> {
@@ -543,6 +691,8 @@ class TritonIntelGPUAccelerateMatmulPass
ModuleOp m = getOperation();
auto &dpasAnalysis = getAnalysis<ttg::intel::DPASAnalysis>();

transposeDots(m);

RewritePatternSet patterns(context);
patterns.add<BlockedToDPAS, DecomposeScaledBlocked>(context, dpasAnalysis);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())