Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mfma prune #1577

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,8 @@ class BasePopulateParams {
// Succeed if `params` should be included in a "full" tuning space that
// excludes those known to not yield good performance on the problem described
// in `info`. This function uses hardcoded heuristics.
virtual LogicalResult couldBePerformant(const PopulateParamsInfo &info,
virtual LogicalResult couldBePerformant(OpBuilder &b,
const PopulateParamsInfo &info,
const InitParamType &params) = 0;

// Convert the provided InitParamType into an MLIR `Attribute`.
Expand Down Expand Up @@ -316,7 +317,8 @@ class PopulateParams : public BasePopulateParams<InitParamsNonAccel> {
const PopulateParamsInfo &info,
const InitParamsNonAccel &params) override;

LogicalResult couldBePerformant(const PopulateParamsInfo &info,
LogicalResult couldBePerformant(OpBuilder &b,
const PopulateParamsInfo &info,
const InitParamsNonAccel &params) override;

int64_t calculatePaddingAmount(const InitParamsNonAccel &params,
Expand Down Expand Up @@ -357,7 +359,8 @@ class PopulateParamsAccel : public BasePopulateParams<InitParamsAccel> {
const PopulateParamsInfo &info,
const InitParamsAccel &params) override;

LogicalResult couldBePerformant(const PopulateParamsInfo &info,
LogicalResult couldBePerformant(OpBuilder &b,
const PopulateParamsInfo &info,
const InitParamsAccel &params) override;

virtual LogicalResult
Expand All @@ -376,9 +379,10 @@ class PopulateParamsAccel : public BasePopulateParams<InitParamsAccel> {
/// The actual implementation of couldBePerformant(), which shouldn't exist
/// once we merge gridwise_gemm and gridwise_gemm_accel and thus flatten
/// out the class hierarchy in this file.
virtual LogicalResult specificCouldBePerformant(const InitParamsAccel &params,
Type dataTypeA,
Type dataTypeB) = 0;
virtual LogicalResult specificCouldBePerformant(OpBuilder &b,
const PopulateParamsInfo &info,
const InitParamsAccel &params) = 0;

};

//
Expand Down Expand Up @@ -413,9 +417,9 @@ class PopulateParamsXDL : public PopulateParamsAccel {
bool enableDPerWaveFiltering = true) override;

protected:
LogicalResult specificCouldBePerformant(const InitParamsAccel &params,
Type dataTypeA,
Type dataTypeB) override;
LogicalResult specificCouldBePerformant(OpBuilder &b,
const PopulateParamsInfo &info,
const InitParamsAccel &params) override;
};

//
Expand Down Expand Up @@ -448,9 +452,9 @@ class PopulateParamsWmma : public PopulateParamsAccel {
bool enableDPerWaveFiltering = true) override;

protected:
LogicalResult specificCouldBePerformant(const InitParamsAccel &params,
Type dataTypeA,
Type dataTypeB) override;
LogicalResult specificCouldBePerformant(OpBuilder &b,
const PopulateParamsInfo &info,
const InitParamsAccel &params) override;
};

} // namespace rock
Expand Down
51 changes: 38 additions & 13 deletions mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "mlir/Dialect/Rock/utility/AmdArchDb.h"
#include "mlir/Dialect/Rock/utility/loweringUtils.h"
#include "mlir/Dialect/Rock/utility/math.h"
#include "mlir/Dialect/Rock/IR/AccelEmitter.h"
#include "mlir/Support/LogicalResult.h"

#include "llvm/Support/Debug.h"
Expand Down Expand Up @@ -212,9 +213,11 @@ PopulateParams::paramsProbablyValid(OpBuilder &b,
}

LogicalResult
PopulateParams::couldBePerformant(const PopulateParamsInfo &info,
PopulateParams::couldBePerformant(OpBuilder &b,
const PopulateParamsInfo &info,
const InitParamsNonAccel &params) {
// Implement this if needed.
(void)b;
(void)info;
(void)params;
return success();
Expand Down Expand Up @@ -336,9 +339,9 @@ PopulateParamsAccel::paramsProbablyValid(OpBuilder &b,
}

LogicalResult
PopulateParamsAccel::couldBePerformant(const PopulateParamsInfo &info,
PopulateParamsAccel::couldBePerformant(OpBuilder &b, const PopulateParamsInfo &info,
const InitParamsAccel &params) {
return specificCouldBePerformant(params, info.gemmAType, info.gemmBType);
return specificCouldBePerformant(b, info, params);
}

LogicalResult PopulateParamsAccel::obtainTuningParameters(
Expand Down Expand Up @@ -693,12 +696,33 @@ PopulateParamsXDL::getTuningParameters(KernelType opType, Type dataTypeA,
}

LogicalResult
PopulateParamsXDL::specificCouldBePerformant(const InitParamsAccel &params,
Type dataTypeA, Type dataTypeB) {
// Implement this if needed.
(void)params;
(void)dataTypeA;
(void)dataTypeB;
PopulateParamsXDL::specificCouldBePerformant(OpBuilder &b,
const PopulateParamsInfo &info,
const InitParamsAccel &params) {
Attribute params0 = getGemmParamsAttr(b, params);
RockAccelTuningParamAttrInterface accelParams0;
if (auto xdlopsParams0 = dyn_cast<XdlopsGemmParamsAttr>(params0)) {
auto xdlopsDerivedParams0 = XdlopsGemmDerivedParamsAttr::get(xdlopsParams0);
accelParams0 = xdlopsDerivedParams0;
} else {
accelParams0 = cast<RockAccelTuningParamAttrInterface>(params0);
}
auto accelEmitterPtr = accel::AccelEmitter::select(
info.gemmFeatures, info.gemmAType, info.gemmBType, StringRef(info.arch), accelParams0);

if (!accelEmitterPtr)
return failure();

rock::accel::AccelEmitterParams accelParams = accelEmitterPtr->getParams();

int64_t numOutputVectorElements = accelParams.numOutputVectorElements();

// It would be best to make the register count a field of the arch info; it is not necessarily equal to totalVGPRPerEu.
if(numOutputVectorElements > 256) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's go ahead and make that an arch field, then

return failure();
}


return success();
}

Expand Down Expand Up @@ -913,12 +937,13 @@ PopulateParamsWmma::getTuningParameters(KernelType opType, Type dataTypeA,
}

LogicalResult
PopulateParamsWmma::specificCouldBePerformant(const InitParamsAccel &params,
Type dataTypeA, Type dataTypeB) {
PopulateParamsWmma::specificCouldBePerformant(OpBuilder &b,
const PopulateParamsInfo &info,
const InitParamsAccel &params) {
// Implement this if needed.
(void)b;
(void)info;
(void)params;
(void)dataTypeA;
(void)dataTypeB;
return success();
}

Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ void createGemmTuningRangeBF(TuningParamSet *newSpace,
b, info, gemmParams)) &&
(kind == TuningParamSetKind::Exhaustive ||
succeeded(
tuningInfo.couldBePerformant(info, gemmParams))))
tuningInfo.couldBePerformant(b, info, gemmParams))))
newSpace->tuningRange.push_back(
cast<RockTuningParamAttrInterface>(
tuningInfo.getGemmParamsAttr(b, gemmParams)));
Expand Down Expand Up @@ -309,7 +309,7 @@ void createGemmTuningRangeBF(TuningParamSet *newSpace,
gemmParams)) &&
(kind == TuningParamSetKind::Exhaustive ||
succeeded(
tuningInfo.couldBePerformant(info, gemmParams))))
tuningInfo.couldBePerformant(b, info, gemmParams))))
newSpace->tuningRange.push_back(
cast<RockTuningParamAttrInterface>(
tuningInfo.getGemmParamsAttr(b, gemmParams)));
Expand Down Expand Up @@ -340,7 +340,7 @@ void createGemmTuningRangeBF(TuningParamSet *newSpace,
gemmParams)) &&
(kind == TuningParamSetKind::Exhaustive ||
succeeded(
tuningInfo.couldBePerformant(info, gemmParams))))
tuningInfo.couldBePerformant(b, info, gemmParams))))
newSpace->tuningRange.push_back(
cast<RockTuningParamAttrInterface>(
tuningInfo.getGemmParamsAttr(b, gemmParams)));
Expand Down