
Commit 926e544

[CPU] Simplify matmul weights transposition logic
1 parent 5cda7b6 commit 926e544

2 files changed: +48 −27 lines


src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_matmul_primitive.cpp

Lines changed: 35 additions & 21 deletions
@@ -38,7 +38,6 @@
 #include "openvino/core/type/element_type.hpp"
 #include "post_ops.hpp"
 #include "thread_pool_imp.hpp"
-#include "utils/cpu_utils.hpp"
 #include "utils/debug_capabilities.h"
 #include "utils/general_utils.h"

@@ -154,24 +153,44 @@ std::shared_ptr<DnnlMatMulPrimitive> DnnlMatMulPrimitive::create(const MemoryArg
 DnnlMemoryDescPtr DnnlMatMulPrimitive::makeTransposedWeightDescriptor(const DnnlMemoryDescPtr& srcDesc,
                                                                       const DnnlMemoryDescPtr& dstDesc,
                                                                       const MatMulAttrs& attrs) {
-    if (!attrs.fcSemantic) {
-        return dstDesc;
-    }
+    OPENVINO_ASSERT(attrs.constantWeights, "DnnlMatmulExecutor: constant weights are expected");
+
+    auto getDims = [](const dnnl::memory::desc& desc, const bool transpose) {
+        auto dims = desc.get_dims();
+        if (transpose) {
+            std::swap(dims[dims.size() - 1], dims[dims.size() - 2]);
+            return dims;
+        }
+
+        return desc.get_dims();
+    };
+
+    auto getFormat = [](const size_t rank, const bool transpose) {
+        switch (rank) {
+        case 2:
+            return transpose ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::ba;
+        case 3:
+            return transpose ? dnnl::memory::format_tag::abc : dnnl::memory::format_tag::acb;
+        default:
+            OPENVINO_THROW("DnnlMatmulExecutor: unsupported weights rank: ", rank);
+        }
+    };
 
-    const bool weightsNonTransposed = attrs.weightsNonTransposed;
     const auto& weiDesc = srcDesc->getDnnlDesc();
-    auto wDims = weiDesc.get_dims();
-    std::swap(wDims[wDims.size() - 1], wDims[wDims.size() - 2]);
     const auto wDataType = weiDesc.get_data_type();
-    if (wDims.size() == 3 && !weightsNonTransposed) {
-        const auto format3D = dnnl::memory::format_tag::acb;
-        const auto transposed3DWeiDesc = dnnl::memory::desc{wDims, wDataType, format3D};
-        return DnnlExtensionUtils::makeDescriptor(transposed3DWeiDesc);
-    }
-
-    const dnnl::memory::dims wDims2D = reshapeDownToRank<2>(wDims);
-    const auto format = weightsNonTransposed ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::ba;
-    const auto transposedWeiDesc = dnnl::memory::desc{wDims2D, wDataType, format};
+    /* in case the second input is constant we can transpose weights beforehand
+     * and avoid transposition during execution */
+    const auto wDims = getDims(weiDesc, attrs.transposeB);
+    /**
+     * We need to transpose weights in two scenarios:
+     * - dnnl matmul is used as FullyConnected executor and weights are not transposed yet (optimization to pack
+     *   weights in one step)
+     * - dnnl matmul is used as MatMul executor with transposeB equal to true. This case is just theoretical since
+     *   currently we always convert MatMul operation with constant second input to FullyConnected operation.
+     */
+    const bool transpose = attrs.fcSemantic ? attrs.weightsNonTransposed : attrs.transposeB;
+    const auto format = getFormat(weiDesc.get_ndims(), transpose);
+    const auto transposedWeiDesc = dnnl::memory::desc{wDims, wDataType, format};
 
     const auto reshapedWeiDesc = transposedWeiDesc.reshape(dstDesc->getDnnlDesc().get_dims());
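The new getDims/getFormat helpers lean on oneDNN format tags rather than physically reordering data: a row-major N x K weight buffer can be described as its K x N transpose simply by swapping the two innermost dims and choosing format_tag::ba. Below is a minimal standalone sketch of that idea, not part of this commit; the include path, the 64 x 128 sizes, and the f32 data type are illustrative assumptions based on the oneDNN 3.x C++ API.

#include <iostream>
#include <utility>
#include <oneapi/dnnl/dnnl.hpp>

int main() {
    // Weights stored row-major as [N, K] = [64, 128] (format_tag::ab).
    const dnnl::memory::dims nk = {64, 128};

    // Describe the same buffer as its transpose [K, N]: swap the dims and
    // use format_tag::ba, so the strides still follow the original [N, K]
    // layout and no data is moved.
    dnnl::memory::dims kn = nk;
    std::swap(kn[0], kn[1]);
    const dnnl::memory::desc transposedView{kn,
                                            dnnl::memory::data_type::f32,
                                            dnnl::memory::format_tag::ba};

    const auto strides = transposedView.get_strides();
    std::cout << "logical dims: " << kn[0] << " x " << kn[1]
              << ", strides: " << strides[0] << ", " << strides[1] << '\n';
    // Expected: logical dims 128 x 64 with strides 1, 128.
    return 0;
}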

@@ -424,8 +443,6 @@ static std::pair<VectorDims, VectorDims> makeDummyInputDims(const Shape& in0,
                 } else {
                     inDims1[idx1] = inDims0[idx0];
                 }
-            } else if (inDims0[idx0] != Shape::UNDEFINED_DIM && inDims1[idx1] != Shape::UNDEFINED_DIM) {
-                inDims1[idx1] = inDims0[idx0];
             }
         }
     };
@@ -530,9 +547,6 @@ DnnlShapeAgnosticDataPtr DnnlMatMulPrimitive::createShapeAgnosticData(const MatM
                                                           attrs.transposeA,
                                                           attrs.transposeB,
                                                           dstDesc->getShape().getRank());
-    if (attrs.fcSemantic && weiDymmyDims.size() == 3) {
-        std::swap(weiDymmyDims[weiDymmyDims.size() - 1], weiDymmyDims[weiDymmyDims.size() - 2]);
-    }
     srcDesc = std::make_shared<DnnlBlockedMemoryDesc>(srcDesc->getPrecision(), Shape(inDymmyDims));
     weiDesc = std::make_shared<DnnlBlockedMemoryDesc>(weiDesc->getPrecision(), Shape(weiDymmyDims));
     dstDesc = std::make_shared<DnnlBlockedMemoryDesc>(dstDesc->getPrecision(), Shape(outDymmyDims));

src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp

Lines changed: 13 additions & 6 deletions
@@ -438,12 +438,19 @@ const std::vector<ExecutorImplementation<FCAttrs>>& getImplementations() {
         [](const FCAttrs& attrs,
            const MemoryArgs& memory,
            const ExecutorContext::CPtr& context) -> ExecutorPtr {
-            MatMulAttrs matMulAttrs{false,
-                                    false};
-            matMulAttrs.postOps = attrs.postOps;
-            matMulAttrs.weightsNonTransposed = attrs.weightsNonTransposed;
-            matMulAttrs.constantWeights = true;
-            matMulAttrs.fcSemantic = true;
+            const bool hasBias = !memory.at(ARG_BIAS)->getDesc().empty();
+            MatMulAttrs matMulAttrs {
+                false,
+                true,
+                hasBias,
+                attrs.weightsNonTransposed,
+                false,
+                true,
+                true,
+                0,
+                {},
+                attrs.postOps
+            };
 
             return std::make_shared<
                 DnnlExecutor<DnnlMatMulPrimitive, MatMulAttrs, DnnlShapeAgnosticData,
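The replacement switches MatMulAttrs from field-by-field assignment to positional aggregate initialization, so each value must line up with the member declaration order of MatMulAttrs. A minimal sketch of that trade-off follows, using a hypothetical, reduced stand-in struct (the field names, order, and defaults below are illustrative, not the real MatMulAttrs); C++20 designated initializers would name each field explicitly while defaulting the rest.

struct MatMulAttrsSketch {      // hypothetical stand-in, not the real MatMulAttrs
    bool transposeA = false;
    bool transposeB = false;
    bool withBias = false;
    bool weightsNonTransposed = false;
    bool constantWeights = false;
    bool fcSemantic = false;
};

int main() {
    const bool hasBias = true;

    // Positional form: readers must know the declaration order to tell
    // which 'true' configures which attribute.
    const MatMulAttrsSketch positional{false, true, hasBias, false, true, true};

    // Designated initializers (C++20): the mapping is explicit and
    // unmentioned members keep their defaults.
    const MatMulAttrsSketch designated{.transposeB = true,
                                       .withBias = hasBias,
                                       .constantWeights = true,
                                       .fcSemantic = true};

    return (positional.fcSemantic && designated.fcSemantic) ? 0 : 1;
}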
