|
38 | 38 | #include "openvino/core/type/element_type.hpp" |
39 | 39 | #include "post_ops.hpp" |
40 | 40 | #include "thread_pool_imp.hpp" |
41 | | -#include "utils/cpu_utils.hpp" |
42 | 41 | #include "utils/debug_capabilities.h" |
43 | 42 | #include "utils/general_utils.h" |
44 | 43 |
|
@@ -154,24 +153,44 @@ std::shared_ptr<DnnlMatMulPrimitive> DnnlMatMulPrimitive::create(const MemoryArg |
154 | 153 | DnnlMemoryDescPtr DnnlMatMulPrimitive::makeTransposedWeightDescriptor(const DnnlMemoryDescPtr& srcDesc, |
155 | 154 | const DnnlMemoryDescPtr& dstDesc, |
156 | 155 | const MatMulAttrs& attrs) { |
157 | | - if (!attrs.fcSemantic) { |
158 | | - return dstDesc; |
159 | | - } |
| 156 | + OPENVINO_ASSERT(attrs.constantWeights, "DnnlMatmulExecutor: constant weights are expected"); |
| 157 | + |
| 158 | + auto getDims = [](const dnnl::memory::desc& desc, const bool transpose) { |
| 159 | + auto dims = desc.get_dims(); |
| 160 | + if (transpose) { |
| 161 | + std::swap(dims[dims.size() - 1], dims[dims.size() - 2]); |
| 162 | + return dims; |
| 163 | + } |
| 164 | + |
| 165 | + return dims; |
| 166 | + }; |
| 167 | + |
| 168 | + auto getFormat = [](const size_t rank, const bool transpose) { |
| 169 | + switch (rank) { |
| 170 | + case 2: |
| 171 | + return transpose ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::ba; |
| 172 | + case 3: |
| 173 | + return transpose ? dnnl::memory::format_tag::abc : dnnl::memory::format_tag::acb; |
| 174 | + default: |
| 175 | + OPENVINO_THROW("DnnlMatmulExecutor: unsupported weights rank: ", rank); |
| 176 | + } |
| 177 | + }; |
160 | 178 |
|
161 | | - const bool weightsNonTransposed = attrs.weightsNonTransposed; |
162 | 179 | const auto& weiDesc = srcDesc->getDnnlDesc(); |
163 | | - auto wDims = weiDesc.get_dims(); |
164 | | - std::swap(wDims[wDims.size() - 1], wDims[wDims.size() - 2]); |
165 | 180 | const auto wDataType = weiDesc.get_data_type(); |
166 | | - if (wDims.size() == 3 && !weightsNonTransposed) { |
167 | | - const auto format3D = dnnl::memory::format_tag::acb; |
168 | | - const auto transposed3DWeiDesc = dnnl::memory::desc{wDims, wDataType, format3D}; |
169 | | - return DnnlExtensionUtils::makeDescriptor(transposed3DWeiDesc); |
170 | | - } |
171 | | - |
172 | | - const dnnl::memory::dims wDims2D = reshapeDownToRank<2>(wDims); |
173 | | - const auto format = weightsNonTransposed ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::ba; |
174 | | - const auto transposedWeiDesc = dnnl::memory::desc{wDims2D, wDataType, format}; |
| 181 | + /* In case the second input is constant, we can transpose the weights beforehand |
| 182 | + * and avoid a transpose at execution time */ |
| 183 | + const auto wDims = getDims(weiDesc, attrs.transposeB); |
| 184 | + /** |
| 185 | + * We need to transpose weights in two scenarios: |
| 186 | + * - dnnl matmul is used as FullyConnected executor and weights are not transposed yet (optimization to pack weights |
| 187 | + * in one step) |
| 188 | + * - dnnl matmul is used as MatMul executor with transposeB equal to true. This case is mostly theoretical, |
| 189 | + * since currently a MatMul with a constant second input is always converted to FullyConnected. |
| 190 | + */ |
| 191 | + const bool transpose = attrs.fcSemantic ? attrs.weightsNonTransposed : attrs.transposeB; |
| 192 | + const auto format = getFormat(weiDesc.get_ndims(), transpose); |
| 193 | + const auto transposedWeiDesc = dnnl::memory::desc{wDims, wDataType, format}; |
175 | 194 |
|
176 | 195 | const auto reshapedWeiDesc = transposedWeiDesc.reshape(dstDesc->getDnnlDesc().get_dims()); |
177 | 196 |
|
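The hunk above replaces the per-rank branching with two small helpers and leans on oneDNN format tags instead of physically shuffling data: swapped dims plus `ba` (rank 2) or `acb` (rank 3) describe the transposed tensor over the same buffer. A minimal standalone sketch of that trick, not part of this patch, assuming the oneDNN 3.x C++ API (the helper name `transposed_view_2d` is hypothetical):

```cpp
#include <oneapi/dnnl/dnnl.hpp>
#include <cassert>
#include <utility>

// Descriptor for the transpose of a plain 2-D weight tensor, built without
// touching the buffer: {N, K} + ba addresses element (n, k) at offset
// k * N + n, i.e. the same byte as (k, n) in the original {K, N} + ab layout.
dnnl::memory::desc transposed_view_2d(const dnnl::memory::desc& src) {
    auto dims = src.get_dims();                               // {K, N}
    std::swap(dims[dims.size() - 1], dims[dims.size() - 2]);  // {N, K}
    return {dims, src.get_data_type(), dnnl::memory::format_tag::ba};
}

int main() {
    using dt = dnnl::memory::data_type;
    using tag = dnnl::memory::format_tag;
    const dnnl::memory::desc plain{{4, 8}, dt::f32, tag::ab};  // {K, N}
    const auto view = transposed_view_2d(plain);               // {N, K} + ba
    assert(view.get_size() == plain.get_size());  // same 4*8 f32 buffer
    return 0;
}
```

The rank-3 case works the same way: `acb` keeps the batch dimension outermost and swaps only the two innermost axes.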
@@ -424,8 +443,6 @@ static std::pair<VectorDims, VectorDims> makeDummyInputDims(const Shape& in0, |
424 | 443 | } else { |
425 | 444 | inDims1[idx1] = inDims0[idx0]; |
426 | 445 | } |
427 | | - } else if (inDims0[idx0] != Shape::UNDEFINED_DIM && inDims1[idx1] != Shape::UNDEFINED_DIM) { |
428 | | - inDims1[idx1] = inDims0[idx0]; |
429 | 446 | } |
430 | 447 | } |
431 | 448 | }; |
@@ -530,9 +547,6 @@ DnnlShapeAgnosticDataPtr DnnlMatMulPrimitive::createShapeAgnosticData(const MatM |
530 | 547 | attrs.transposeA, |
531 | 548 | attrs.transposeB, |
532 | 549 | dstDesc->getShape().getRank()); |
533 | | - if (attrs.fcSemantic && weiDymmyDims.size() == 3) { |
534 | | - std::swap(weiDymmyDims[weiDymmyDims.size() - 1], weiDymmyDims[weiDymmyDims.size() - 2]); |
535 | | - } |
536 | 550 | srcDesc = std::make_shared<DnnlBlockedMemoryDesc>(srcDesc->getPrecision(), Shape(inDymmyDims)); |
537 | 551 | weiDesc = std::make_shared<DnnlBlockedMemoryDesc>(weiDesc->getPrecision(), Shape(weiDymmyDims)); |
538 | 552 | dstDesc = std::make_shared<DnnlBlockedMemoryDesc>(dstDesc->getPrecision(), Shape(outDymmyDims)); |
|
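Note on the last hunk: the manual `std::swap` over `weiDymmyDims` becomes redundant once `makeTransposedWeightDescriptor` owns the transposition, because for rank-3 weights the `acb` tag already maps the swapped dims onto the original buffer (strides `{N*K, 1, N}` for `{B, N, K} + acb` versus `{K*N, N, 1}` for `{B, K, N} + abc`).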