Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check matmul types and error at compile-time if the backend doesn't support them #540

Merged
merged 1 commit into from
Dec 18, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions include/matx/transforms/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,45 @@ union MatMulScaleType_t {
double cf64[2];
};

/**
 * @brief Compile-time check of whether the scalar types of the A/B/C operands
 *        form a GEMM combination supported by the selected provider.
 *
 * @tparam OpA  Operator/tensor type of the A input (must expose `scalar_type`)
 * @tparam OpB  Operator/tensor type of the B input (must expose `scalar_type`)
 * @tparam OpC  Operator/tensor type of the C output (must expose `scalar_type`)
 * @tparam PROV Backend provider to validate against (defaults to cuBLASLt)
 * @return true if the type combination is supported, false otherwise
 */
template <typename OpA, typename OpB, typename OpC, MatXMatMulProvider_t PROV = PROVIDER_TYPE_CUBLASLT>
constexpr bool CompatibleGemmTypes() {
  // If no two of the three operand types match, no provider supports the combination
  if constexpr (!std::is_same_v<typename OpA::scalar_type, typename OpB::scalar_type> &&
              !std::is_same_v<typename OpB::scalar_type, typename OpC::scalar_type> &&
              !std::is_same_v<typename OpA::scalar_type, typename OpC::scalar_type>) {
    return false;
  }
  else if constexpr (PROV == PROVIDER_TYPE_CUBLASLT) {
    if constexpr (std::is_same_v<typename OpA::scalar_type, typename OpB::scalar_type> &&
                  std::is_same_v<typename OpB::scalar_type, typename OpC::scalar_type>) {
      // List of accepted types when A/B/C match
      return std::is_same_v<typename OpA::scalar_type, matxFp16> ||
            std::is_same_v<typename OpA::scalar_type, matxBf16> ||
            std::is_same_v<typename OpA::scalar_type, float> ||
            std::is_same_v<typename OpA::scalar_type, double> ||
            std::is_same_v<typename OpA::scalar_type, cuda::std::complex<float>> ||
            std::is_same_v<typename OpA::scalar_type, cuda::std::complex<double>> ||
            std::is_same_v<typename OpA::scalar_type, int8_t> ||
            std::is_same_v<typename OpA::scalar_type, matxFp16Complex> ||
            std::is_same_v<typename OpA::scalar_type, matxBf16Complex>;
    }
    // Accumulator type different from A/B
    else if constexpr ( std::is_same_v<typename OpA::scalar_type, typename OpB::scalar_type> &&
                       !std::is_same_v<typename OpB::scalar_type, typename OpC::scalar_type>) {
      // Note: int8_t A/B with either int32_t or float accumulator is accepted
      return (std::is_same_v<typename OpA::scalar_type, int8_t>   && std::is_same_v<typename OpC::scalar_type, int32_t>) ||
             (std::is_same_v<typename OpA::scalar_type, int8_t>   && std::is_same_v<typename OpC::scalar_type, float>) ||
             (std::is_same_v<typename OpA::scalar_type, matxBf16> && std::is_same_v<typename OpC::scalar_type, float>) ||
             (std::is_same_v<typename OpA::scalar_type, matxFp16> && std::is_same_v<typename OpC::scalar_type, float>);
    }
    else {
      // A and B differ (but two of the three types matched above): not supported.
      // This branch also guarantees the constexpr function never flows off the
      // end, which would be ill-formed in a constant expression.
      return false;
    }
  }
  else {
    // For now return true for other providers until we support more
    return true;
  }
}

/**
* Parameters needed to execute a GEMM. For the most part, these are very
* similar to that of a standard GEMM call
Expand Down Expand Up @@ -834,7 +873,7 @@ class matxMatMulHandle_t {
static_cast<int>(
params_.ldc)}, // Tensor-ref for destination matrix D (may be
// different memory than source C matrix)
{alpha, beta}); // Scalars used in the Epilogue
{static_cast<T1>(alpha), static_cast<T1>(beta)}); // Scalars used in the Epilogue

CutlassGemm gemm_operator;
cutlass::Status status = gemm_operator(args, nullptr, stream);
Expand Down Expand Up @@ -895,7 +934,7 @@ class matxMatMulHandle_t {
params_.ldc)}, // Tensor-ref for destination matrix D (may
// be different memory than source C matrix)
c_adj.Stride(RANK - 3), // Batch Stride C
{alpha, beta},
{static_cast<T1>(alpha), static_cast<T1>(beta)},
params_.batch // Batch Dimension
); // Scalars used in the Epilogue

Expand Down Expand Up @@ -1118,6 +1157,10 @@ void matmul_impl(TensorTypeC C, const TensorTypeA A,
auto A_ = as_type<typename TensorTypeC::scalar_type>(A);
auto B_ = as_type<typename TensorTypeC::scalar_type>(B);

static_assert(detail::CompatibleGemmTypes<decltype(A_), decltype(B_), TensorTypeC, PROV>(),
"Combination of A/B/C types are not supported");


// CublasLt does not support operators and certain transpose modes.
// Grab a supported tensor here and copy in if necessary.
auto c = getCublasSupportedTensor(C, stream);
Expand Down