Skip to content

Commit

Permalink
Merge pull request ROCm#215 from CongMa13/fix_yaml
Browse files Browse the repository at this point in the history
Add mode test
  • Loading branch information
CongMa13 authored Apr 30, 2024
2 parents 0ee375e + 194e42e commit a2c3e10
Show file tree
Hide file tree
Showing 15 changed files with 428 additions and 173 deletions.
12 changes: 4 additions & 8 deletions library/src/contraction/contraction_selection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,14 @@ namespace hiptensor
const uint64_t workspaceSize)
{
// Make sure that we calculate full element space incase strides are not packed.
auto sizeA = elementSpaceFromLengthsAndStrides(a_ms_ks_lengths, a_ms_ks_strides)
* hipDataTypeSize(typeA);
auto sizeB = elementSpaceFromLengthsAndStrides(b_ns_ks_lengths, b_ns_ks_strides)
* hipDataTypeSize(typeB);
auto sizeA = elementsFromLengths(a_ms_ks_lengths) * hipDataTypeSize(typeA);
auto sizeB = elementsFromLengths(b_ns_ks_lengths) * hipDataTypeSize(typeB);
auto sizeD = 0;
if(typeD != NONE_TYPE)
{
sizeD = elementSpaceFromLengthsAndStrides(d_ms_ns_lengths, d_ms_ns_strides)
* hipDataTypeSize(typeD);
sizeD = elementsFromLengths(d_ms_ns_lengths) * hipDataTypeSize(typeD);
}
auto sizeE = elementSpaceFromLengthsAndStrides(e_ms_ns_lengths, e_ms_ns_strides)
* hipDataTypeSize(typeE);
auto sizeE = elementsFromLengths(e_ms_ns_lengths) * hipDataTypeSize(typeE);

void *A_d, *B_d, *D_d, *E_d, *wspace;

Expand Down
130 changes: 130 additions & 0 deletions library/src/contraction/contraction_solution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,140 @@
*
*******************************************************************************/

#include <set>

#include "contraction_solution.hpp"
#include "util.hpp"

namespace hiptensor
{
std::array<std::vector<std::size_t>, 8>
normalizeTensorModes(std::vector<std::size_t> const& a_ms_ks_lengths,
std::vector<std::size_t> const& a_ms_ks_strides,
std::vector<int32_t> const& a_ms_ks_modes,
std::vector<std::size_t> const& b_ns_ks_lengths,
std::vector<std::size_t> const& b_ns_ks_strides,
std::vector<int32_t> const& b_ns_ks_modes,
std::vector<std::size_t> const& e_ms_ns_lengths,
std::vector<std::size_t> const& e_ms_ns_strides,
std::vector<int32_t> const& e_ms_ns_modes)
{
std::vector<std::size_t> normal_a_ms_ks_lengths(MaxNumDimsM + MaxNumDimsK, 1);
std::vector<std::size_t> normal_a_ms_ks_strides(MaxNumDimsM + MaxNumDimsK, 1);
std::vector<int32_t> normal_a_ms_ks_modes(MaxNumDimsM + MaxNumDimsK, -1);
std::vector<std::size_t> normal_b_ns_ks_lengths(MaxNumDimsK + MaxNumDimsN, 1);
std::vector<std::size_t> normal_b_ns_ks_strides(MaxNumDimsK + MaxNumDimsN, 1);
std::vector<int32_t> normal_b_ns_ks_modes(MaxNumDimsK + MaxNumDimsN, -1);
std::vector<std::size_t> normal_e_ms_ns_lengths(MaxNumDimsM + MaxNumDimsN, 1);
std::vector<std::size_t> normal_e_ms_ns_strides(MaxNumDimsM + MaxNumDimsN, 1);
std::vector<int32_t> normal_e_ms_ns_modes(MaxNumDimsM + MaxNumDimsN, -1);
int mOffset = 0;
int nOffset = 0;

// reorder m, n in A, B
for(int i = 0; i < e_ms_ns_modes.size(); i++)
{
if(auto aIt = std::find(a_ms_ks_modes.cbegin(), a_ms_ks_modes.cend(), e_ms_ns_modes[i]);
aIt != a_ms_ks_modes.cend())
{
auto offset = std::distance(a_ms_ks_modes.cbegin(), aIt);
normal_a_ms_ks_modes[mOffset] = a_ms_ks_modes[offset];
normal_a_ms_ks_lengths[mOffset] = a_ms_ks_lengths[offset];
normal_a_ms_ks_strides[mOffset] = a_ms_ks_strides[offset];
mOffset++;
}
else
{
auto bIt
= std::find(b_ns_ks_modes.cbegin(), b_ns_ks_modes.cend(), e_ms_ns_modes[i]);
auto offset = std::distance(b_ns_ks_modes.cbegin(), bIt);
normal_b_ns_ks_modes[nOffset] = b_ns_ks_modes[offset];
normal_b_ns_ks_lengths[nOffset] = b_ns_ks_lengths[offset];
normal_b_ns_ks_strides[nOffset] = b_ns_ks_strides[offset];
nOffset++;
}
}

assert(mOffset > 0 && nOffset > 0);
for(; mOffset < MaxNumDimsM; mOffset++)
{
normal_a_ms_ks_lengths[mOffset] = 1;
normal_a_ms_ks_strides[mOffset] = normal_a_ms_ks_strides[mOffset - 1];
}
for(; nOffset < MaxNumDimsN; nOffset++)
{
normal_b_ns_ks_lengths[nOffset] = 1;
normal_b_ns_ks_strides[nOffset] = normal_b_ns_ks_strides[nOffset - 1];
}

// reorder k in A, B - Do not check if A and B have same k here.
for(int i = 0; i < a_ms_ks_modes.size(); i++)
{
if(auto it = std::find(b_ns_ks_modes.cbegin(), b_ns_ks_modes.cend(), a_ms_ks_modes[i]);
it != b_ns_ks_modes.cend())
{
normal_a_ms_ks_modes[mOffset] = a_ms_ks_modes[i];
normal_a_ms_ks_lengths[mOffset] = a_ms_ks_lengths[i];
normal_a_ms_ks_strides[mOffset] = a_ms_ks_strides[i];
mOffset++;

auto bOffset = std::distance(b_ns_ks_modes.cbegin(), it);
normal_b_ns_ks_modes[nOffset] = b_ns_ks_modes[bOffset];
normal_b_ns_ks_lengths[nOffset] = b_ns_ks_lengths[bOffset];
normal_b_ns_ks_strides[nOffset] = b_ns_ks_strides[bOffset];
nOffset++;
}
}

assert(mOffset > 0 && nOffset > 0);
for(; mOffset < MaxNumDimsM + MaxNumDimsK; mOffset++)
{
normal_a_ms_ks_lengths[mOffset] = 1;
normal_a_ms_ks_strides[mOffset] = normal_a_ms_ks_strides[mOffset - 1];
}
for(; nOffset < MaxNumDimsN + MaxNumDimsK; nOffset++)
{
normal_b_ns_ks_lengths[nOffset] = 1;
normal_b_ns_ks_strides[nOffset] = normal_b_ns_ks_strides[nOffset - 1];
}

// reorder m, n in D, E
std::vector<int32_t> contraction_result_modes(MaxNumDimsM + MaxNumDimsN, -1);
std::copy(normal_a_ms_ks_modes.cbegin(),
normal_a_ms_ks_modes.cbegin() + MaxNumDimsM,
contraction_result_modes.begin());
std::copy(normal_b_ns_ks_modes.cbegin(),
normal_b_ns_ks_modes.cbegin() + MaxNumDimsN,
contraction_result_modes.begin() + MaxNumDimsM);

for(int i = 0; i < contraction_result_modes.size(); i++)
{
auto it = std::find(
e_ms_ns_modes.cbegin(), e_ms_ns_modes.cend(), contraction_result_modes[i]);
if(it != e_ms_ns_modes.cend())
{
auto offset = std::distance(e_ms_ns_modes.cbegin(), it);
normal_e_ms_ns_lengths[i] = e_ms_ns_lengths[offset];
normal_e_ms_ns_strides[i] = e_ms_ns_strides[offset];
}
else
{
normal_e_ms_ns_lengths[i] = 1;
normal_e_ms_ns_strides[i] = normal_e_ms_ns_strides[i - 1];
}
}

return {
normal_a_ms_ks_lengths,
normal_a_ms_ks_strides,
normal_b_ns_ks_lengths,
normal_b_ns_ks_strides,
normal_e_ms_ns_lengths,
normal_e_ms_ns_strides,
normal_e_ms_ns_lengths,
normal_e_ms_ns_strides,
};
}

ContractionSolution::ContractionSolution(
std::unique_ptr<ck::tensor_operation::device::BaseOperator>&& deviceOp,
Expand Down
Loading

0 comments on commit a2c3e10

Please sign in to comment.