From 9d2f712da11e2f6b23e26fdc374cd68d2d643917 Mon Sep 17 00:00:00 2001 From: Ethan Date: Thu, 18 Jul 2024 13:55:51 +0000 Subject: [PATCH 1/2] Added tentative check for output sizes (currently > 256) that would cause register spills. Current value is hardcoded to 256, but should be generalized. --- .../Dialect/Rock/Tuning/GridwiseGemmParams.h | 28 +++++---- .../Rock/Tuning/GridwiseGemmParams.cpp | 62 +++++++++++++++---- .../Dialect/Rock/Tuning/RockTuningImpl.cpp | 6 +- 3 files changed, 69 insertions(+), 27 deletions(-) diff --git a/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h b/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h index e20f3c66b782..381202c8744d 100644 --- a/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h +++ b/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h @@ -269,7 +269,8 @@ class BasePopulateParams { // Succced if `params` should be included in a "full" tuning space that // excludes those known to not yeild good performance on the problem described // in `info`. This function uses hardcoded heuristics. - virtual LogicalResult couldBePerformant(const PopulateParamsInfo &info, + virtual LogicalResult couldBePerformant(OpBuilder &b, + const PopulateParamsInfo &info, const InitParamType &params) = 0; // Convert the provided InitParamType into an MLIR `Attribute`. 
@@ -316,7 +317,8 @@ class PopulateParams : public BasePopulateParams { const PopulateParamsInfo &info, const InitParamsNonAccel &params) override; - LogicalResult couldBePerformant(const PopulateParamsInfo &info, + LogicalResult couldBePerformant(OpBuilder &b, + const PopulateParamsInfo &info, const InitParamsNonAccel &params) override; int64_t calculatePaddingAmount(const InitParamsNonAccel &params, @@ -357,7 +359,8 @@ class PopulateParamsAccel : public BasePopulateParams { const PopulateParamsInfo &info, const InitParamsAccel &params) override; - LogicalResult couldBePerformant(const PopulateParamsInfo &info, + LogicalResult couldBePerformant(OpBuilder &b, + const PopulateParamsInfo &info, const InitParamsAccel &params) override; virtual LogicalResult @@ -376,9 +379,10 @@ class PopulateParamsAccel : public BasePopulateParams { /// The actual implementation of couldBePerformant(), which shouldn't exist /// once we merge gridwise_gemm and gridwise_gemm_accel and thus flatten /// out the class heirachy in this file. 
- virtual LogicalResult specificCouldBePerformant(const InitParamsAccel &params, - Type dataTypeA, - Type dataTypeB) = 0; + virtual LogicalResult specificCouldBePerformant(OpBuilder &b, + const PopulateParamsInfo &info, + const InitParamsAccel &params) = 0; + }; // @@ -413,9 +417,9 @@ class PopulateParamsXDL : public PopulateParamsAccel { bool enableDPerWaveFiltering = true) override; protected: - LogicalResult specificCouldBePerformant(const InitParamsAccel &params, - Type dataTypeA, - Type dataTypeB) override; + LogicalResult specificCouldBePerformant(OpBuilder &b, + const PopulateParamsInfo &info, + const InitParamsAccel &params) override; }; // @@ -448,9 +452,9 @@ class PopulateParamsWmma : public PopulateParamsAccel { bool enableDPerWaveFiltering = true) override; protected: - LogicalResult specificCouldBePerformant(const InitParamsAccel &params, - Type dataTypeA, - Type dataTypeB) override; + LogicalResult specificCouldBePerformant(OpBuilder &b, + const PopulateParamsInfo &info, + const InitParamsAccel &params) override; }; } // namespace rock diff --git a/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp b/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp index b19d86aa1a7a..f334191ef102 100644 --- a/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Rock/utility/AmdArchDb.h" #include "mlir/Dialect/Rock/utility/loweringUtils.h" #include "mlir/Dialect/Rock/utility/math.h" +#include "mlir/Dialect/Rock/IR/AccelEmitter.h" #include "mlir/Support/LogicalResult.h" #include "llvm/Support/Debug.h" @@ -212,9 +213,11 @@ PopulateParams::paramsProbablyValid(OpBuilder &b, } LogicalResult -PopulateParams::couldBePerformant(const PopulateParamsInfo &info, +PopulateParams::couldBePerformant(OpBuilder &b, + const PopulateParamsInfo &info, const InitParamsNonAccel &params) { // Implement this if needed. 
+ (void)b; (void)info; (void)params; return success(); @@ -336,9 +339,17 @@ PopulateParamsAccel::paramsProbablyValid(OpBuilder &b, } LogicalResult -PopulateParamsAccel::couldBePerformant(const PopulateParamsInfo &info, +PopulateParamsAccel::couldBePerformant(OpBuilder &b, const PopulateParamsInfo &info, const InitParamsAccel &params) { - return specificCouldBePerformant(params, info.gemmAType, info.gemmBType); + //int64_t mRepeats = ; + //int64_t nRepeats = ;e + // look ABBEgg + // OpBuilder b + // info.gemmAType, info.gemmBtype + // info.arch + // tuningParams = op.getParams() + // + return specificCouldBePerformant(b, info, params); } LogicalResult PopulateParamsAccel::obtainTuningParameters( @@ -693,12 +704,38 @@ PopulateParamsXDL::getTuningParameters(KernelType opType, Type dataTypeA, } LogicalResult -PopulateParamsXDL::specificCouldBePerformant(const InitParamsAccel &params, - Type dataTypeA, Type dataTypeB) { +PopulateParamsXDL::specificCouldBePerformant(OpBuilder &b, + const PopulateParamsInfo &info, + const InitParamsAccel &params) { // Implement this if needed. 
- (void)params; - (void)dataTypeA; - (void)dataTypeB; + /* + Attribute params0 = getGemmParamsAttr(b, params); + RockAccelTuningParamAttrInterface accelParams0; + if (auto xdlopsParams0 = dyn_cast(params0)) { + auto xdlopsDerivedParams0 = XdlopsGemmDerivedParamsAttr::get(xdlopsParams0); + accelParams0 = xdlopsDerivedParams0; + } else { + accelParams0 = cast(params0); + } + auto accelEmitterPtr = accel::AccelEmitter::select( + info.gemmFeatures, info.gemmAType, info.gemmBType, info.arch.StringRef(), accelParams0); + + if (!accelEmitterPtr) + return failure(); + + rock::accel::AccelEmitterParams params = accelEmitterPtr->getParams(); + + int64_t numOutputVectorElements = params.numOutputVectorElements(); + + // would be best to have register count be a part of arch, is not necessarily totalVGPRPerEu + if(numOutputVectorElements > 256) { + return failure(); + } + + check output size + int64_t nOutputVectors = nResultVectors * mRepeats * nRepeats; + */ + return success(); } @@ -913,12 +950,13 @@ PopulateParamsWmma::getTuningParameters(KernelType opType, Type dataTypeA, } LogicalResult -PopulateParamsWmma::specificCouldBePerformant(const InitParamsAccel &params, - Type dataTypeA, Type dataTypeB) { +PopulateParamsWmma::specificCouldBePerformant(OpBuilder &b, + const PopulateParamsInfo &info, + const InitParamsAccel &params) { // Implement this if needed. 
+ (void)b; + (void)info; (void)params; - (void)dataTypeA; - (void)dataTypeB; return success(); } diff --git a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp index 113ee996e4ba..c652ce2227c1 100644 --- a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp @@ -272,7 +272,7 @@ void createGemmTuningRangeBF(TuningParamSet *newSpace, b, info, gemmParams)) && (kind == TuningParamSetKind::Exhaustive || succeeded( - tuningInfo.couldBePerformant(info, gemmParams)))) + tuningInfo.couldBePerformant(b, info, gemmParams)))) newSpace->tuningRange.push_back( cast( tuningInfo.getGemmParamsAttr(b, gemmParams))); @@ -309,7 +309,7 @@ void createGemmTuningRangeBF(TuningParamSet *newSpace, gemmParams)) && (kind == TuningParamSetKind::Exhaustive || succeeded( - tuningInfo.couldBePerformant(info, gemmParams)))) + tuningInfo.couldBePerformant(b, info, gemmParams)))) newSpace->tuningRange.push_back( cast( tuningInfo.getGemmParamsAttr(b, gemmParams))); @@ -340,7 +340,7 @@ void createGemmTuningRangeBF(TuningParamSet *newSpace, gemmParams)) && (kind == TuningParamSetKind::Exhaustive || succeeded( - tuningInfo.couldBePerformant(info, gemmParams)))) + tuningInfo.couldBePerformant(b, info, gemmParams)))) newSpace->tuningRange.push_back( cast( tuningInfo.getGemmParamsAttr(b, gemmParams))); From 0d5c030eacbcefefd58aa953317da6ae792da8e8 Mon Sep 17 00:00:00 2001 From: Ethan Date: Thu, 18 Jul 2024 14:31:10 +0000 Subject: [PATCH 2/2] Removed comments --- .../Rock/Tuning/GridwiseGemmParams.cpp | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp b/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp index f334191ef102..923ece9fe647 100644 --- a/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp @@ -341,14 +341,6 @@ 
PopulateParamsAccel::paramsProbablyValid(OpBuilder &b, LogicalResult PopulateParamsAccel::couldBePerformant(OpBuilder &b, const PopulateParamsInfo &info, const InitParamsAccel &params) { - //int64_t mRepeats = ; - //int64_t nRepeats = ;e - // look ABBEgg - // OpBuilder b - // info.gemmAType, info.gemmBtype - // info.arch - // tuningParams = op.getParams() - // return specificCouldBePerformant(b, info, params); } @@ -707,8 +699,6 @@ LogicalResult PopulateParamsXDL::specificCouldBePerformant(OpBuilder &b, const PopulateParamsInfo &info, const InitParamsAccel &params) { - // Implement this if needed. - /* Attribute params0 = getGemmParamsAttr(b, params); RockAccelTuningParamAttrInterface accelParams0; if (auto xdlopsParams0 = dyn_cast(params0)) { @@ -718,23 +708,20 @@ PopulateParamsXDL::specificCouldBePerformant(OpBuilder &b, accelParams0 = cast(params0); } auto accelEmitterPtr = accel::AccelEmitter::select( - info.gemmFeatures, info.gemmAType, info.gemmBType, info.arch.StringRef(), accelParams0); + info.gemmFeatures, info.gemmAType, info.gemmBType, StringRef(info.arch), accelParams0); if (!accelEmitterPtr) return failure(); - rock::accel::AccelEmitterParams params = accelEmitterPtr->getParams(); + rock::accel::AccelEmitterParams accelParams = accelEmitterPtr->getParams(); - int64_t numOutputVectorElements = params.numOutputVectorElements(); + int64_t numOutputVectorElements = accelParams.numOutputVectorElements(); // would be best to have register count be a part of arch, is not necessarily totalVGPRPerEu if(numOutputVectorElements > 256) { return failure(); } - check output size - int64_t nOutputVectors = nResultVectors * mRepeats * nRepeats; - */ return success(); }