[CODE SHARING] Ravil/sched inst #611

Draft
wants to merge 10 commits into base: sjw-pipeline-infra
10 changes: 5 additions & 5 deletions python/tutorials/03-matrix-multiplication.py
@@ -206,19 +206,19 @@ def get_hip_autotune_config():
return [
triton.Config(
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
num_warps=4, num_stages=0),
num_warps=4, num_stages=2),
triton.Config(
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2},
num_warps=8, num_stages=0),
num_warps=8, num_stages=2),
triton.Config(
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
num_warps=8, num_stages=0),
num_warps=8, num_stages=2),
triton.Config(
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3},
num_warps=4, num_stages=0),
num_warps=4, num_stages=2),
triton.Config(
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8},
num_warps=4, num_stages=0),
num_warps=4, num_stages=2),
]
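With this change the tutorial configs ask for an explicit two-stage software pipeline: on AMD matrix-core targets `num_stages` is now forwarded to the stream pipeliner (see `compiler.py` below) instead of `0` acting as the enable switch. A hedged sketch of how these configs are consumed; the kernel signature is abridged from the tutorial and is not part of this diff:

```python
import triton
import triton.language as tl

# Sketch only: the tutorial feeds its autotune configs into the autotuner;
# with num_stages=2 the AMD backend now pipelines the K loop with two stages.
@triton.autotune(configs=get_hip_autotune_config(), key=['M', 'N', 'K'])
@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                  BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    ...  # kernel body elided
```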


2,281 changes: 2,281 additions & 0 deletions test/TritonGPU/amd/amd-reorder-instructions.mlir

Large diffs are not rendered by default.

1,671 changes: 1,632 additions & 39 deletions test/TritonGPU/amd/amd-stream-pipeline.mlir

Large diffs are not rendered by default.

19 changes: 13 additions & 6 deletions third_party/amd/backend/compiler.py
@@ -15,7 +15,7 @@
class HIPOptions:
num_warps: int = 4
waves_per_eu: int = 1
num_stages: int = 0
num_stages: int = 2
num_ctas: int = 1
extern_libs: dict = None
cluster_dims: tuple = (1, 1, 1)
@@ -136,14 +136,13 @@ def make_ttgir(mod, metadata, options):
passes.ttgpuir.add_remove_layout_conversions(pm)
amd.passes.ttgpuir.add_optimize_epilogue(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, True)
if options.num_stages == 0 and amd.has_matrix_core_feature(options.arch):
amd.passes.ttgpuir.add_stream_pipeline(pm)
if amd.has_matrix_core_feature(options.arch):
amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages)
passes.common.add_canonicalizer(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, True)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_reduce_data_duplication(pm)
if options.num_stages != 0:
amd.passes.ttgpuir.add_reorder_instructions(pm)
amd.passes.ttgpuir.add_reorder_instructions(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
pm.run(mod)
@@ -167,8 +166,16 @@ def make_llir(src, metadata, options):
## depends on the value of kernel arg `allow_flush_denorm`.
## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument.
## For now it is used as a controller for developers only.
sched_mode = ""
if "AMD_OPS_SCHED_MODE" in os.environ.keys():
sched_mode = os.environ['AMD_OPS_SCHED_MODE']
allowed = ["iglp-opt-0", "iglp-opt-1", "sched-barriers", ""]
if not sched_mode in allowed:
raise RuntimeError(
f'unknown mode for `AMD_OPS_SCHED_MODE`. Given `{sched_mode}`. Allowed: {", ".join(allowed)}')

__HIP_FTZ = True
amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ)
amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ, sched_mode)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
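
The scheduling variant is selected only through the new `AMD_OPS_SCHED_MODE` environment variable. A minimal usage sketch, assuming the variable is set before the kernel is JIT-compiled (an already-cached binary would not pick it up):

```python
import os

# One of the values accepted above; anything else raises a RuntimeError.
os.environ["AMD_OPS_SCHED_MODE"] = "iglp-opt-0"  # or "iglp-opt-1", "sched-barriers", ""

# ...then import triton and compile/launch the kernel as usual; the string is
# threaded through make_llir into the TritonAMDGPUToLLVM pass as its `sched` option.
```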

3 changes: 2 additions & 1 deletion third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
@@ -25,7 +25,8 @@ createDecomposeUnsupportedConversionsPass(StringRef targetArch);
} // namespace AMD

std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz);
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz,
std::string schedMode);
std::unique_ptr<OperationPass<ModuleOp>> createConvertBuiltinFuncToLLVMPass();

#define GEN_PASS_REGISTRATION
4 changes: 3 additions & 1 deletion third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
@@ -15,7 +15,7 @@ def DecomposeUnsupportedAMDConversions : Pass<"decompose-unsupported-amd-convers

def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::ModuleOp"> {
let summary = "Convert TritonGPU to LLVM";
let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\", /*ftz=*/true)";
let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\", /*ftz=*/true, \"\")";

let dependentDialects = ["mlir::arith::ArithDialect",
"mlir::math::MathDialect",
@@ -32,6 +32,8 @@ def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::Mod
"gfx target device architecture, e.g., gfx942">,
Option<"ftz", "ftz", "bool", /*default*/"true",
"flush denorms for math functions">,
Option<"sched", "sched", "std::string", /*default*/"\"\"",
"scheduling variants">,
];
}

2 changes: 1 addition & 1 deletion third_party/amd/include/TritonAMDGPUTransforms/Passes.h
@@ -6,7 +6,7 @@

namespace mlir {

std::unique_ptr<Pass> createTritonAMDGPUStreamPipelinePass();
std::unique_ptr<Pass> createTritonAMDGPUStreamPipelinePass(int numStages = 2);

std::unique_ptr<Pass>
createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(),
6 changes: 6 additions & 0 deletions third_party/amd/include/TritonAMDGPUTransforms/Passes.td
@@ -14,6 +14,12 @@ def TritonAMDGPUStreamPipeline : Pass<"tritonamdgpu-stream-pipeline", "mlir::Mod
let constructor = "mlir::createTritonAMDGPUStreamPipelinePass()";

let dependentDialects = [];

let options = [
Option<"numStages", "num_stages",
"int32_t", /*default*/"2",
"Number of Pipeline stages">
];
}
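
Since the `num_stages` option declared above is fed from `HIPOptions`, the stage count can also be chosen per kernel from Python. A hedged sketch, assuming a plain `@triton.jit` kernel (the kernel, grid, and arguments here are hypothetical):

```python
# Sketch: num_stages passed at launch becomes HIPOptions.num_stages and, in turn,
# the num_stages option of tritonamdgpu-stream-pipeline.
matmul_kernel[grid](a_ptr, b_ptr, c_ptr, M, N, K, num_warps=8, num_stages=3)
```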

def TritonAMDGPUAccelerateMatmul : Pass<"tritonamdgpu-accelerate-matmul", "mlir::ModuleOp"> {
19 changes: 14 additions & 5 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp
@@ -9,7 +9,8 @@ using ::mlir::triton::gpu::getShapePerCTA;
namespace mlir::triton::AMD {
LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter);
ConversionPatternRewriter &rewriter,
StringRef schedMode);

LogicalResult convertWMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
const LLVMTypeConverter *typeConverter,
@@ -18,7 +19,11 @@ LogicalResult convertWMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,

namespace {
struct DotOpConversion : public ConvertOpToLLVMPattern<triton::DotOp> {
using ConvertOpToLLVMPattern<triton::DotOp>::ConvertOpToLLVMPattern;
// using ConvertOpToLLVMPattern<triton::DotOp>::ConvertOpToLLVMPattern;
DotOpConversion(LLVMTypeConverter &typeConverter, PatternBenefit benefit,
StringRef schedMode)
: ConvertOpToLLVMPattern<triton::DotOp>(typeConverter, benefit),
schedMode(schedMode) {}

LogicalResult
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
@@ -37,7 +42,8 @@ struct DotOpConversion : public ConvertOpToLLVMPattern<triton::DotOp> {
if (!isOuter) {
auto dEncoding = cast<RankedTensorType>(D.getType()).getEncoding();
if (isa<AMDMfmaEncodingAttr>(dEncoding) && supportMFMA(op)) {
return AMD::convertMFMA(op, adaptor, getTypeConverter(), rewriter);
return AMD::convertMFMA(op, adaptor, getTypeConverter(), rewriter,
schedMode);
}
if (isa<AMDWmmaEncodingAttr>(dEncoding)) {
return AMD::convertWMMA(op, adaptor, getTypeConverter(), rewriter);
@@ -51,14 +57,17 @@ struct DotOpConversion : public ConvertOpToLLVMPattern<triton::DotOp> {
llvm::report_fatal_error(
"Unsupported DotOp found when converting TritonGPU to LLVM.");
}

private:
StringRef schedMode;
};
} // namespace

namespace mlir::triton::AMD {
void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
ModuleAxisInfoAnalysis &axisInfoAnalysis,
PatternBenefit benefit) {
patterns.add<DotOpConversion>(typeConverter, benefit);
PatternBenefit benefit, StringRef schedMode) {
patterns.add<DotOpConversion>(typeConverter, benefit, schedMode);
}
} // namespace mlir::triton::AMD
82 changes: 78 additions & 4 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp
@@ -38,20 +38,41 @@ using ::mlir::triton::gpu::SharedEncodingAttr;

using ValueTable = std::map<std::array<int, 3>, Value>;

enum class SchedulingOptionsEnum : int64_t {
IGLP_OPT_0 = 0,
IGLP_OPT_1 = 1,
SCHED_BARRIERS,
NONE_SCHED
};
enum class InstructionMaskEnum : int64_t {
VALU = 0x00000002,
SALU = 0x00000004,
MFMA = 0x00000008,
ALL_VMEM = 0x00000010,
VMEM_READ = 0x00000020,
VMEM_WRITE = 0x00000040,
ALL_DS = 0x00000080,
DS_READ = 0x00000100,
DS_WRITE = 0x00000200
};

struct DotOpMFMAConversionHelper {
AMDMfmaEncodingAttr mfmaLayout;

ConversionPatternRewriter &rewriter;
const LLVMTypeConverter *typeConverter;
SchedulingOptionsEnum schedMode;
Location loc;
MLIRContext *ctx{};

explicit DotOpMFMAConversionHelper(AMDMfmaEncodingAttr mfmaLayout,
ConversionPatternRewriter &rewriter,
const LLVMTypeConverter *typeConverter,
SchedulingOptionsEnum schedMode,
Location loc)
: mfmaLayout(mfmaLayout), rewriter(rewriter),
typeConverter(typeConverter), loc(loc), ctx(mfmaLayout.getContext()) {}
typeConverter(typeConverter), schedMode(schedMode), loc(loc),
ctx(mfmaLayout.getContext()) {}

Value getThreadId() const {
auto llvmIndexTy = typeConverter->getIndexType();
@@ -70,6 +91,45 @@ struct DotOpMFMAConversionHelper {
return rewriter.create(loweredOp)->getResult(0);
}

void generatedIglpIntrinsic() const {
if (!((schedMode == SchedulingOptionsEnum::IGLP_OPT_0) ||
(schedMode == SchedulingOptionsEnum::IGLP_OPT_1))) {
return;
}
auto intrinsicName = StringAttr::get(ctx, "llvm.amdgcn.iglp.opt");
LLVM::FastmathFlagsAttr defaultFlags{};
Type i32 = rewriter.getI32Type();

auto option = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getIntegerAttr(i32, static_cast<int>(schedMode)));
rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
ValueRange{option}, defaultFlags);
}

void buildSchedGroupBarrier(InstructionMaskEnum maskValue, int sizeValue,
int groupIdValue) const {
auto intrinsicName =
StringAttr::get(ctx, "llvm.amdgcn.sched.group.barrier");
LLVM::FastmathFlagsAttr defaultFlags{};
Type i32 = rewriter.getI32Type();
auto mask = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getIntegerAttr(i32, static_cast<int64_t>(maskValue)));
auto size = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getIntegerAttr(i32, sizeValue));
auto groupId = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getIntegerAttr(i32, groupIdValue));

rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
ValueRange{mask, size, groupId},
defaultFlags);
}

void insertSchedBarriers() const {
if (!(schedMode == SchedulingOptionsEnum::SCHED_BARRIERS))
return;
// TODO(ravil)
}

int getNumSubmatrices(Type elementType, int mDim, int nDim) const {
if ((mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64))
return 1;
@@ -171,6 +231,8 @@
assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) ||
(mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64));

generatedIglpIntrinsic();

Value a = op.getA();
Value b = op.getB();
Value d = op.getD();
@@ -263,6 +325,9 @@ struct DotOpMFMAConversionHelper {
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(fc.size(), dstElemTy));
Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy);

insertSchedBarriers();

rewriter.replaceOp(op, res);

return success();
@@ -351,13 +416,13 @@ struct DotOpMFMAConversionHelper {
return dotOpVals;
}
};

} // namespace

namespace mlir::triton::AMD {
LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter) {
ConversionPatternRewriter &rewriter,
StringRef schedMode) {
auto rankedTType = [](Value tensor) {
return cast<RankedTensorType>(tensor.getType());
};
@@ -375,11 +440,20 @@ LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
cTensorTy.getShape()[1] == dTensorTy.getShape()[1] &&
"DotOp's $c operand should pass the same number of values as $d");

static const DenseMap<StringRef, SchedulingOptionsEnum> schedModesToEnum = {
{"iglp-opt-0", SchedulingOptionsEnum::IGLP_OPT_0},
{"iglp-opt-1", SchedulingOptionsEnum::IGLP_OPT_1},
{"sched-barriers", SchedulingOptionsEnum::SCHED_BARRIERS},
{"", SchedulingOptionsEnum::NONE_SCHED}};
assert(schedModesToEnum.contains(schedMode) &&
"sched mode must be in the allowed set");

auto loc = op.getLoc();
auto mfmaLayout = cast<AMDMfmaEncodingAttr>(
cast<RankedTensorType>(op.getResult().getType()).getEncoding());

DotOpMFMAConversionHelper helper(mfmaLayout, rewriter, typeConverter, loc);
DotOpMFMAConversionHelper helper(mfmaLayout, rewriter, typeConverter,
schedModesToEnum.at(schedMode), loc);

return helper.convertDot(op, adaptor);
}
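To summarize the new scheduling hooks in this file, a hedged, Python-style reference of how the mode strings map onto what gets emitted; it mirrors `schedModesToEnum`, `generatedIglpIntrinsic`, and `InstructionMaskEnum` above and is not itself an API:

```python
# Mode string -> effect in DotOpMFMAConversionHelper (sketch, not an API).
SCHED_MODES = {
    "iglp-opt-0": "llvm.amdgcn.iglp.opt(0) emitted at the start of DotOp lowering",
    "iglp-opt-1": "llvm.amdgcn.iglp.opt(1) emitted at the start of DotOp lowering",
    "sched-barriers": "llvm.amdgcn.sched.group.barrier sequence (insertSchedBarriers is still a TODO)",
    "": "no scheduling intrinsics emitted",
}

# The sched.group.barrier mask is a bitmask over instruction classes, so a
# group can combine several of them, e.g. MFMAs together with LDS reads:
MFMA, DS_READ = 0x0008, 0x0100
group_mask = MFMA | DS_READ
```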
@@ -15,7 +15,7 @@ void populateConvertLayoutOpToLLVMPatterns(
void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
ModuleAxisInfoAnalysis &axisInfoAnalysis,
PatternBenefit benefit);
PatternBenefit benefit, StringRef schedMode);
void populateElementwiseOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, bool ftz,
ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,
12 changes: 8 additions & 4 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
@@ -63,9 +63,11 @@ class TritonLLVMConversionTarget : public ConversionTarget {
struct ConvertTritonAMDGPUToLLVM
: public triton::impl::ConvertTritonAMDGPUToLLVMBase<
ConvertTritonAMDGPUToLLVM> {
explicit ConvertTritonAMDGPUToLLVM(StringRef targetArch, bool ftz) {
explicit ConvertTritonAMDGPUToLLVM(StringRef targetArch, bool ftz,
StringRef schedMode) {
this->arch = targetArch.str();
this->ftz = ftz;
this->sched = schedMode.str();
}

void getDependentDialects(DialectRegistry &registry) const override {
@@ -174,7 +176,7 @@
mlir::triton::populateConvertLayoutOpToLLVMPatterns(
typeConverter, targetInfo, patterns, commonBenefit);
AMD::populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, AMDBenefit);
axisInfoAnalysis, AMDBenefit, sched);
AMD::populateElementwiseOpToLLVMPatterns(typeConverter, patterns, ftz,
axisInfoAnalysis, allocation,
targetInfo, AMDBenefit);
@@ -246,8 +248,10 @@ namespace mlir {
namespace triton {

std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz) {
return std::make_unique<ConvertTritonAMDGPUToLLVM>(targetArch, ftz);
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz,
std::string schedMode) {
return std::make_unique<ConvertTritonAMDGPUToLLVM>(targetArch, ftz,
schedMode);
}

} // namespace triton