Skip to content

Commit

Permalink
Merge pull request ROCm#158 from CongMa13/contraction_f16_bf16
Browse files Browse the repository at this point in the history
Contraction f16, bf16, f32_f16, f32_bf16, f64_f32
  • Loading branch information
CongMa13 authored Dec 11, 2023
2 parents 852992e + b21fe0b commit 8c11d59
Show file tree
Hide file tree
Showing 108 changed files with 6,308 additions and 1,036 deletions.
File renamed without changes.
1 change: 1 addition & 0 deletions library/include/hiptensor/internal/hiptensor_utility.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <iostream>

#include "../hiptensor_types.hpp"
#include "types_ext.hpp"

#ifndef CHECK_HIP_ERROR
#define CHECK_HIP_ERROR(expression) \
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
48 changes: 25 additions & 23 deletions library/src/contraction/contraction_cpu_reference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,31 +28,33 @@
#include "contraction_cpu_reference_impl.hpp"
#include "contraction_cpu_reference_instances.hpp"

hiptensorStatus_t hiptensorContractionReference(void const* alpha,
void const* A,
void const* B,
void const* beta,
void const* C,
void* D,
std::vector<size_t> const& a_ms_ks_lengths,
std::vector<size_t> const& a_ms_ks_strides,
std::vector<size_t> const& b_ns_ks_lengths,
std::vector<size_t> const& b_ns_ks_strides,
std::vector<size_t> const& c_ms_ns_lengths,
std::vector<size_t> const& c_ms_ns_strides,
std::vector<size_t> const& d_ms_ns_lengths,
std::vector<size_t> const& d_ms_ns_strides,
hipDataType typeA,
hipDataType typeB,
hipDataType typeC,
hipDataType typeD,
void* workspace)
hiptensorStatus_t hiptensorContractionReference(const hiptensorContractionPlan_t* plan,
void const* alpha,
void const* A,
void const* B,
void const* beta,
void const* C,
void* D,
std::vector<size_t> const& a_ms_ks_lengths,
std::vector<size_t> const& a_ms_ks_strides,
std::vector<size_t> const& b_ns_ks_lengths,
std::vector<size_t> const& b_ns_ks_strides,
std::vector<size_t> const& c_ms_ns_lengths,
std::vector<size_t> const& c_ms_ns_strides,
std::vector<size_t> const& d_ms_ns_lengths,
std::vector<size_t> const& d_ms_ns_strides,
hipDataType typeA,
hipDataType typeB,
hipDataType typeC,
hipDataType typeD,
void* workspace)
{
auto& instances = hiptensor::ContractionCpuReferenceInstances::instance();
auto& instances = hiptensor::ContractionCpuReferenceInstances::instance();
auto computeType = plan->mContractionDesc.mComputeType;
auto candidates
= (C == nullptr)
? instances->allSolutions().query(typeA, typeB, hiptensor::NONE_TYPE, typeD)
: instances->allSolutions().query(typeA, typeB, typeC, typeD);
= (C == nullptr) ? instances->allSolutions().query(
typeA, typeB, hiptensor::NONE_TYPE, typeD, computeType)
: instances->allSolutions().query(typeA, typeB, typeC, typeD, computeType);

auto toCKVec
= [](auto& inputVec) { return std::vector<ck::index_t>(inputVec.begin(), inputVec.end()); };
Expand Down
39 changes: 20 additions & 19 deletions library/src/contraction/contraction_cpu_reference.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,25 @@

#include <hiptensor/hiptensor.hpp>

hiptensorStatus_t hiptensorContractionReference(void const* alpha,
void const* A,
void const* B,
void const* beta,
void const* C,
void* D,
std::vector<size_t> const& a_ms_ks_lengths,
std::vector<size_t> const& a_ms_ks_strides,
std::vector<size_t> const& b_ks_ns_lengths,
std::vector<size_t> const& b_ks_ns_strides,
std::vector<size_t> const& c_ms_ns_lengths,
std::vector<size_t> const& c_ms_ns_strides,
std::vector<size_t> const& d_ms_ns_lengths,
std::vector<size_t> const& d_ms_ns_strides,
hipDataType typeA,
hipDataType typeB,
hipDataType typeC,
hipDataType typeD,
void* workspace);
hiptensorStatus_t hiptensorContractionReference(const hiptensorContractionPlan_t* plan,
void const* alpha,
void const* A,
void const* B,
void const* beta,
void const* C,
void* D,
std::vector<size_t> const& a_ms_ks_lengths,
std::vector<size_t> const& a_ms_ks_strides,
std::vector<size_t> const& b_ks_ns_lengths,
std::vector<size_t> const& b_ks_ns_strides,
std::vector<size_t> const& c_ms_ns_lengths,
std::vector<size_t> const& c_ms_ns_strides,
std::vector<size_t> const& d_ms_ns_lengths,
std::vector<size_t> const& d_ms_ns_strides,
hipDataType typeA,
hipDataType typeB,
hipDataType typeC,
hipDataType typeD,
void* workspace);

#endif // HIPTENSOR_CONTRACTION_CPU_REFERENCE_HPP
60 changes: 39 additions & 21 deletions library/src/contraction/contraction_cpu_reference_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,25 @@
namespace hiptensor
{
// hardcoded for NumDimM == NumDimN == NumDimK == 2
//
// ck::bhalf_t is ushort, cannot perform bhalf_t * bhalf_t
// CK does not use ck::bhalf_t as AccDataType. But we still
// add this guard here
template <
ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename DsDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2 && DsDataType::Size() <= 1,
typename ComputeDataType = ADataType,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2 && DsDataType::Size() <= 1
&& !std::is_same_v<AccDataType, ck::bhalf_t>,
bool>
= false>
struct ReferenceContraction_M2_N2_K2
Expand All @@ -70,7 +76,8 @@ namespace hiptensor
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
CDEElementwiseOperation,
ComputeDataType>
{
using BaseArgument = ck::tensor_operation::device::BaseArgument;
using BaseInvoker = ck::tensor_operation::device::BaseInvoker;
Expand Down Expand Up @@ -150,7 +157,7 @@ namespace hiptensor
};

auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
auto accum = static_cast<AccDataType>(0);
AccDataType accum = 0;

auto K0 = arg.mA_ms_ks_lengths[2];
auto K1 = arg.mA_ms_ks_lengths[3];
Expand All @@ -164,16 +171,19 @@ namespace hiptensor
auto indexB
= offset(std::vector<size_t>{n0, n1, k0, k1}, arg.mB_ns_ks_strides);

ADataType valA;
BDataType valB;
AccDataType valA;
AccDataType valB;

// Element-wise ops
arg.mOpA(valA, ((ADataType*)arg.mA)[indexA]);
arg.mOpB(valB, ((BDataType*)arg.mB)[indexB]);
arg.mOpA(
valA,
ck::type_convert<ComputeDataType>(((ADataType*)arg.mA)[indexA]));
arg.mOpB(
valB,
ck::type_convert<ComputeDataType>(((BDataType*)arg.mB)[indexB]));

// Mult / accum
accum
+= static_cast<AccDataType>(valA) * static_cast<AccDataType>(valB);
accum += valA * valB;
}
}

Expand All @@ -182,15 +192,17 @@ namespace hiptensor
if constexpr(std::is_same_v<CDEElementwiseOperation,
ck::tensor_operation::element_wise::Scale>)
{
arg.mOpCDE(((EDataType*)arg.mE)[indexE], accum);
arg.mOpCDE(((EDataType*)arg.mE)[indexE],
ck::type_convert<EDataType>(accum));
}
else // bilinear
{
// NumDTensor will be 1 due to SFINAE of this class
auto indexD
= offset(std::vector<size_t>{m0, m1, n0, n1}, arg.mD_ms_ns_strides[0]);
arg.mOpCDE(
((EDataType*)arg.mE)[indexE], accum, ((EDataType*)(arg.mD[0]))[indexD]);
arg.mOpCDE(((EDataType*)arg.mE)[indexE],
ck::type_convert<EDataType>(accum),
((EDataType*)(arg.mD[0]))[indexD]);
}
};

Expand Down Expand Up @@ -319,23 +331,25 @@ namespace hiptensor
ck::index_t NumDimsK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename DsDataType,
typename EDataType,
typename AccumDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
typename CDEElementwiseOperation,
typename ComputeDataType>
struct MetaTraits<ReferenceContraction_M2_N2_K2<NumDimsM,
NumDimsN,
NumDimsK,
ADataType,
BDataType,
AccDataType,
DsDataType,
EDataType,
AccumDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>>
CDEElementwiseOperation,
ComputeDataType>>
: public MetaTraits<
ck::tensor_operation::device::DeviceContractionMultipleD<NumDimsM,
NumDimsN,
Expand All @@ -346,7 +360,8 @@ namespace hiptensor
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>>
CDEElementwiseOperation,
ComputeDataType>>
{
};

Expand All @@ -355,24 +370,27 @@ namespace hiptensor
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
typename CDEElementwiseOperation,
typename ComputeDataType = ADataType>
auto enumerateReferenceSolutions()
{
using ReferenceOp = ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
DsDataType,
EDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>;
CDEElementwiseOperation,
ComputeDataType>;

auto solution = std::make_unique<ContractionSolutionImpl<ReferenceOp>>(
std::make_unique<ReferenceOp>());
Expand Down
Loading

0 comments on commit 8c11d59

Please sign in to comment.