diff --git a/src/plugins/intel_cpu/src/cpu_types.cpp b/src/plugins/intel_cpu/src/cpu_types.cpp index 8b4ffaefcabfd3..fad6613f36b6cb 100644 --- a/src/plugins/intel_cpu/src/cpu_types.cpp +++ b/src/plugins/intel_cpu/src/cpu_types.cpp @@ -245,6 +245,7 @@ static const TypeToNameMap& get_type_to_name_tbl() { {"Ngram", Type::Ngram}, {"ScaledDotProductAttention", Type::ScaledDotProductAttention}, {"ScaledDotProductAttentionWithKVCache", Type::ScaledDotProductAttention}, + {"SDPAWithTransposeReshape", Type::ScaledDotProductAttention}, {"PagedAttentionExtension", Type::PagedAttention}, {"RoPE", Type::RoPE}, {"GatherCompressed", Type::Gather}, diff --git a/src/plugins/intel_cpu/src/extension.cpp b/src/plugins/intel_cpu/src/extension.cpp index f2256d9d03df15..a29282d4af3101 100644 --- a/src/plugins/intel_cpu/src/extension.cpp +++ b/src/plugins/intel_cpu/src/extension.cpp @@ -75,6 +75,7 @@ class TypeRelaxedExtension : public ov::OpExtension> { OP_EXTENSION(ov::intel_cpu::PowerStaticNode) \ OP_EXTENSION(ov::intel_cpu::CausalMaskPreprocessNode) \ OP_EXTENSION(ov::intel_cpu::SwishNode) \ + OP_EXTENSION(ov::intel_cpu::SDPAWithTransposeReshape) \ OP_EXTENSION(ov::intel_cpu::NgramNode) \ OP_EXTENSION(ov::op::internal::GatherCompressed) \ OP_EXTENSION(ov::op::internal::NonMaxSuppressionIEInternal) \ diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp index e70a3932b11b1e..e229ff4bb72c57 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp @@ -866,6 +866,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt void execute(dnnl::stream strm, const Config& config, const std::vector& inputs, const MemoryPtr output, const MemoryPtr presentk_input, const MemoryPtr presentv_input, const MemoryPtr beam_input, const PlainTensor& k_scale_zp, const PlainTensor& v_scale_zp) override { + bool has_in_reshape = config.config.input_BLHxS; bool has_out_transpose = config.config.output_BLHxS; bool fuse_causal_attn = config.config.fuse_causal_attn; bool is_causal = config.config.is_causal; @@ -881,11 +882,28 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt float scale_input = 0.0f; size_t B, L1, L0, S, SV; + // B,L,H*S->B,L,H,S + auto get_reshape_shape = [&config](const PlainTensor& input) { + // [B,L,H*S] + auto inp_shape = input.shape(); + // [B,L,H,S] + return VectorDims{inp_shape[0], inp_shape[1], config.config.order_HS[0], config.config.order_HS[1]}; + }; + q_input.reset(inputs[0]); k_input.reset(inputs[1]); v_input.reset(inputs[2]); present_key.reset(presentk_input); present_value.reset(presentv_input); + if (has_in_reshape) { + q_input = q_input.reshape(get_reshape_shape(q_input)); + auto kv_shape = get_reshape_shape(k_input); + k_input = k_input.reshape(kv_shape); + v_input = v_input.reshape(kv_shape); + present_key = present_key.reshape(kv_shape); + present_value = present_value.reshape(kv_shape); + } + if (beam_input) beam_table.reset(beam_input); if (input_num > 3) { @@ -985,11 +1003,11 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptr(op); - if (node) { + if (const auto node = std::dynamic_pointer_cast(op)) { m_config.config.is_causal = node->get_causal(); - } else { - const auto node = std::dynamic_pointer_cast(op); + } else if (const auto node = std::dynamic_pointer_cast(op)) { + m_config.config = node->get_config(); + } else if (const auto node = std::dynamic_pointer_cast(op)) { m_config.config = node->get_config(); } } @@ -1142,17 +1160,28 @@ void ScaledDotProductAttention::execute(dnnl::stream strm) { bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { try { + auto sdpaWithTransposeReshapeOp = std::dynamic_pointer_cast(op); if (!std::dynamic_pointer_cast(op) && - !std::dynamic_pointer_cast(op)) { - errorMessage = "Only ScaledDotProductAttention or ScaledDotProductAttentionWithKVCache operation are supported"; + !std::dynamic_pointer_cast(op) && !sdpaWithTransposeReshapeOp) { + errorMessage = "Only ScaledDotProductAttention, ScaledDotProductAttentionWithKVCache or " + "SDPAWithTransposeReshape operation are supported"; return false; } - // expect shape of q: [B, H, L, S] auto inRank = op->get_input_partial_shape(0).size(); - if (inRank != 4u) { - errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(inRank); - return false; + if (sdpaWithTransposeReshapeOp) { + // inRank expect shape of q: [B, L, H*S] + if (inRank != 3u) { + errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(inRank); + return false; + } + } else { + // inRank expect shape of q: [B, H, L, S] + if (inRank != 4u) { + errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(inRank); + return false; + } } + int orgSDPAInput = static_cast(op->get_input_size()); const auto node = std::dynamic_pointer_cast(op); if (node) { diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.cpp index 4421499d10204d..bea56e2b8c833f 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.cpp @@ -99,4 +99,46 @@ bool ov::intel_cpu::ScaledDotProductAttentionWithKVCache::visit_attributes(ov::A visitor.on_attribute("permute_axes", m_config.permute_axes); visitor.finish_structure(); return true; +} + +ov::intel_cpu::SDPAWithTransposeReshape::SDPAWithTransposeReshape(const OutputVector& args, const Config& cfg) + : Op(args), + m_config(cfg) {} + +std::shared_ptr ov::intel_cpu::SDPAWithTransposeReshape::clone_with_new_inputs( + const ov::OutputVector& new_args) const { + INTERNAL_OP_SCOPE(SDPAWithTransposeReshape_with_new_inputs); + check_new_args_count(this, new_args); + return std::make_shared(new_args, m_config); +} + +void ov::intel_cpu::SDPAWithTransposeReshape::validate_and_infer_types() { + INTERNAL_OP_SCOPE(SDPAWithTransposeReshape_validate_and_infer_types); + // [B,L,H*S] + auto q_ps = get_input_partial_shape(0); + auto output_ps = q_ps; + NODE_VALIDATION_CHECK(this, m_config.output_BLHxS == true); + NODE_VALIDATION_CHECK(this, m_config.input_BLHxS == true); + NODE_VALIDATION_CHECK(this, q_ps.size() == 3u); + + // permute_axes should be [B, H, L, S] + const auto& permute_axes = this->m_config.permute_axes; + NODE_VALIDATION_CHECK(this, permute_axes.size() == 4u); + + // order_HS should be [H,S] + const auto& order_HS = this->m_config.order_HS; + NODE_VALIDATION_CHECK(this, order_HS.size() == 2u); + + set_output_type(0, get_input_element_type(0), output_ps); +} + +bool ov::intel_cpu::SDPAWithTransposeReshape::visit_attributes(ov::AttributeVisitor& visitor) { + INTERNAL_OP_SCOPE(SDPAWithTransposeReshape_visit_attributes); + visitor.start_structure("config"); + visitor.on_attribute("input_BLHxS", m_config.input_BLHxS); + visitor.on_attribute("output_BLHxS", m_config.output_BLHxS); + visitor.on_attribute("permute_axes", m_config.permute_axes); + visitor.on_attribute("order_HS", m_config.order_HS); + visitor.finish_structure(); + return true; } \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.hpp index 8fe1c9ce4ffa19..8c811f16262734 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.hpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.hpp @@ -21,13 +21,15 @@ class ScaledDotProductAttentionWithKVCache : public ov::op::Op { ScaledDotProductAttentionWithKVCache() = default; struct Config { - bool output_BLHxS = false; // true implies that output is [B,L,H*S] + bool input_BLHxS = false; // true implies that input is [B,L,H*S] + bool output_BLHxS = false; // true implies that output is [B,L,H*S] - bool fuse_causal_attn = false; // fuse causal mask and attn mask into attn_mask - bool is_causal = false; // apply causal mask internally - bool fuse_concat = false; // fuse (concat->sdp) ==> sdp - std::vector permute_axes; // not empty means input has transpose. output of permutation is [B,H,L,S] - // e.g. [L,B,H,S] -> permute[1, 2, 0, 3] ->[B, H, L, S] + bool fuse_causal_attn = false; // fuse causal mask and attn mask into attn_mask + bool is_causal = false; // apply causal mask internally + bool fuse_concat = false; // fuse (concat->sdp) ==> sdp + std::vector permute_axes; // not empty means input has transpose. output of permutation is [B,H,L,S] + // e.g. [L,B,H,S] -> permute[1, 2, 0, 3] ->[B, H, L, S] + std::vector order_HS; // Reshape[B,L,H*S]->B,L,H,S], H,S are fixed value, when input_BLHxS is true. }; ScaledDotProductAttentionWithKVCache(const OutputVector& args, const Config& cfg); @@ -48,5 +50,30 @@ class ScaledDotProductAttentionWithKVCache : public ov::op::Op { Config m_config; }; +class SDPAWithTransposeReshape : public ov::op::Op { +public: + OPENVINO_OP("SDPAWithTransposeReshape", "cpu_plugin_opset"); + using Config = ScaledDotProductAttentionWithKVCache::Config; + + SDPAWithTransposeReshape() = default; + + SDPAWithTransposeReshape(const OutputVector& args, const Config& cfg); + + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; + bool visit_attributes(AttributeVisitor& visitor) override; + void validate_and_infer_types() override; + + const Config& get_config() const { + return m_config; + } + + Config& get_config() { + return m_config; + } + +private: + Config m_config; +}; + } // namespace intel_cpu } // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/sdpa_fuse_transpose_reshape.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/sdpa_fuse_transpose_reshape.cpp new file mode 100644 index 00000000000000..3aa0fd0d08e69b --- /dev/null +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/sdpa_fuse_transpose_reshape.cpp @@ -0,0 +1,188 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "sdpa_fuse_transpose_reshape.hpp" + +#include + +#include "itt.hpp" +#include "openvino/core/rt_info.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/scaled_dot_product_attention.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/cpu_opset/common/op/sdpa.hpp" + +/* + * Description: SDPA fuse transpose and reshape. + * Original pattern Fused pattern + * + * input1 input2 input3 + * | | | + * q_reshape k_reshape v_reshap + * | | | (qkv transpose and reshape's orders) + * q_transpose k_transpose v_transpose | + * \ | / input1 input2 input3 | + * \ | / \ | / / + * ScaledDotProductAttention ---------> SDPAWithTransposeReshape + * | | + * out_transpose | + * | output + * out_reshpae + * | + * output + */ + +using namespace ov; +using namespace ov::pass::pattern; + +intel_cpu::SDPAFuseTransposeReshape::SDPAFuseTransposeReshape() { + MATCHER_SCOPE(SDPAFuseTransposeReshape); + + auto q_reshape_node = wrap_type({any_input(), any_input()}); + auto k_reshape_node = wrap_type({any_input(), any_input()}); + auto v_reshape_node = wrap_type({any_input(), any_input()}); + + auto q_transpose_order_node = wrap_type(); + auto k_transpose_order_node = wrap_type(); + auto v_transpose_order_node = wrap_type(); + auto q_transpose_node = wrap_type({q_reshape_node, q_transpose_order_node}); + auto k_transpose_node = wrap_type({k_reshape_node, k_transpose_order_node}); + auto v_transpose_node = wrap_type({v_reshape_node, v_transpose_order_node}); + + auto sdpa_node = + wrap_type({q_transpose_node, k_transpose_node, v_transpose_node}); + + auto out_transpose_order_node = wrap_type(); + auto out_transpose_node = wrap_type({sdpa_node, out_transpose_order_node}); + auto out_reshape_node = wrap_type({out_transpose_node, wrap_type()}); + + matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](pass::pattern::Matcher& m) { + auto& pattern_map = m.get_pattern_value_map(); + auto sdpa = as_type_ptr(pattern_map.at(sdpa_node).get_node_shared_ptr()); + if (sdpa == nullptr || transformation_callback(sdpa)) { + return false; + } + + // Order=[0, 2, 1, 3] + auto is_expected_transpose = [&](std::shared_ptr& transpose) { + if (transpose) { + const auto orders = as_type_ptr(transpose->get_input_node_shared_ptr(1)); + return orders && (std::vector({0, 2, 1, 3}) == orders->cast_vector()); + } + return false; + }; + + // Reshape [B,L,H*S] -> [B,L,H,S] + auto is_expected_reshape = [&](std::shared_ptr& reshape_node, bool reverse = false) { + if (reshape_node) { + auto inp_shape = reshape_node->get_input_partial_shape(0); + auto outp_shape = reshape_node->get_output_partial_shape(0); + // Expect shape: [?, ?, val] + auto check_dim_3 = [](ov::PartialShape shape) { + return shape.rank().is_static() && shape.rank() == 3 && shape[2].is_static(); + }; + // Expect shape: [?, ?, val, val] + auto check_dim_4 = [](ov::PartialShape shape) { + return shape.rank().is_static() && shape.rank() == 4 && shape[2].is_static() && + shape[3].is_static(); + }; + + if (reverse) { + return check_dim_4(inp_shape) && check_dim_3(outp_shape) && + (outp_shape[2] == inp_shape[2] * inp_shape[3]); + } else { + return check_dim_3(inp_shape) && check_dim_4(outp_shape) && + (inp_shape[2] == outp_shape[2] * outp_shape[3]); + } + } + return false; + }; + + // Pattern: Reshape->Transpose->SDPA + auto q_reshape = as_type_ptr(pattern_map.at(q_reshape_node).get_node_shared_ptr()); + auto k_reshape = as_type_ptr(pattern_map.at(k_reshape_node).get_node_shared_ptr()); + auto v_reshape = as_type_ptr(pattern_map.at(v_reshape_node).get_node_shared_ptr()); + + if (!(is_expected_reshape(q_reshape) && is_expected_reshape(k_reshape) && is_expected_reshape(v_reshape))) { + return false; + } + // K,V Reshape's order should be same node. + auto k_reshape_order = as_type_ptr(k_reshape->get_input_node_shared_ptr(1)); + auto v_reshape_order = as_type_ptr(v_reshape->get_input_node_shared_ptr(1)); + if (k_reshape_order && v_reshape_order) { + if (k_reshape_order->cast_vector() != v_reshape_order->cast_vector()) { + return false; + } + } else if (k_reshape->get_input_node_shared_ptr(1) != v_reshape->get_input_node_shared_ptr(1)) { + return false; + } + + std::shared_ptr qkv_transpose[3] = {}; + std::shared_ptr qkv_transpose_order[3] = {}; + qkv_transpose[0] = as_type_ptr(pattern_map.at(q_transpose_node).get_node_shared_ptr()); + qkv_transpose[1] = as_type_ptr(pattern_map.at(k_transpose_node).get_node_shared_ptr()); + qkv_transpose[2] = as_type_ptr(pattern_map.at(v_transpose_node).get_node_shared_ptr()); + qkv_transpose_order[0] = as_type_ptr(pattern_map.at(q_transpose_order_node).get_node_shared_ptr()); + qkv_transpose_order[1] = as_type_ptr(pattern_map.at(k_transpose_order_node).get_node_shared_ptr()); + qkv_transpose_order[2] = as_type_ptr(pattern_map.at(v_transpose_order_node).get_node_shared_ptr()); + auto out_tranpose = as_type_ptr(pattern_map.at(out_transpose_node).get_node_shared_ptr()); + auto out_transpose_order = as_type_ptr(pattern_map.at(out_transpose_order_node).get_node_shared_ptr()); + + if (!(is_expected_transpose(qkv_transpose[0]) && is_expected_transpose(qkv_transpose[1]) && + is_expected_transpose(qkv_transpose[2]))) { + return false; + } + if (!is_expected_transpose(out_tranpose)) { + return false; + } + + auto out_reshape = as_type_ptr(pattern_map.at(out_reshape_node).get_node_shared_ptr()); + if (!is_expected_reshape(out_reshape, true)) { + return false; + } + + OutputVector args = {q_reshape->get_input_node_shared_ptr(0), + k_reshape->get_input_node_shared_ptr(0), + v_reshape->get_input_node_shared_ptr(0)}; + + // Config + intel_cpu::SDPAWithTransposeReshape::Config config; + config.is_causal = sdpa->get_causal(); + config.fuse_concat = false; + config.output_BLHxS = true; + + // Config::permute_axes + const auto& permute_q = qkv_transpose_order[0]->cast_vector(); + config.permute_axes.resize(permute_q.size()); + for (size_t i = 0; i < permute_q.size(); i++) { + config.permute_axes[i] = static_cast(permute_q[i]); + } + + // Config::order_HS + config.order_HS.resize(2); + auto reshape_out_shape = q_reshape->get_output_partial_shape(0).get_min_shape(); // [?,?,H,S] + config.order_HS[0] = reshape_out_shape[2]; + config.order_HS[1] = reshape_out_shape[3]; + config.input_BLHxS = true; + + auto new_sdpa = std::make_shared(args, config); + new_sdpa->set_friendly_name(sdpa->get_friendly_name() + "/fused_reshape_transpose"); + NodeVector replaced_nodes = {q_reshape, + k_reshape, + v_reshape, + qkv_transpose[0], + qkv_transpose[1], + qkv_transpose[2], + sdpa, + out_tranpose, + out_reshape}; + copy_runtime_info(replaced_nodes, new_sdpa); + ov::replace_node(out_reshape, new_sdpa); + return true; + }; + + auto m = std::make_shared(out_reshape_node, matcher_name); + register_matcher(m, callback); +} diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/sdpa_fuse_transpose_reshape.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/sdpa_fuse_transpose_reshape.hpp new file mode 100644 index 00000000000000..74ba6ec6221d1e --- /dev/null +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/sdpa_fuse_transpose_reshape.hpp @@ -0,0 +1,18 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace ov { +namespace intel_cpu { +class SDPAFuseTransposeReshape : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("SDPAFuseTransposeReshape", "0"); + SDPAFuseTransposeReshape(); +}; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 04808baaebec54..e45b6379d1e968 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -139,6 +139,7 @@ #include "transformations/cpu_opset/common/pass/swap_convert_transpose.hpp" #include "transformations/cpu_opset/common/pass/causal_mask_preprocess_fusion.hpp" #include "transformations/cpu_opset/common/pass/stateful_sdpa_fusion.hpp" +#include "transformations/cpu_opset/x64/pass/sdpa_fuse_transpose_reshape.hpp" // Snippets #include "snippets/pass/tokenization.hpp" @@ -864,6 +865,7 @@ void Transformations::PostLpt() { CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::transpose_sinking::TSShapeOfForward); CPU_REGISTER_PASS_COMMON(postLPTPassManager, StatefulSDPAFusion); + CPU_REGISTER_PASS_X64(postLPTPassManager, ov::intel_cpu::SDPAFuseTransposeReshape); CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RMSFusion, false); CPU_REGISTER_PASS_X64(postLPTPassManager, ov::intel_cpu::DecomposeRMSNorm); CPU_SET_CALLBACK_X64(postLPTPassManager, diff --git a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/fuse_reshape_transpose_to_sdpa.cpp b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/fuse_reshape_transpose_to_sdpa.cpp new file mode 100644 index 00000000000000..a75156c0f69fcb --- /dev/null +++ b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/fuse_reshape_transpose_to_sdpa.cpp @@ -0,0 +1,245 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "common_test_utils/include/common_test_utils/ov_tensor_utils.hpp" +#include "openvino/pass/manager.hpp" +#include "shared_test_classes/base/ov_subgraph.hpp" +#include "transformations/op_conversions/scaled_dot_product_attention_decomposition.hpp" +#include "utils/cpu_test_utils.hpp" + +using namespace ov::test; +using namespace CPUTestUtils; + +namespace ov { +namespace test { + +// Subgraph: +/* + * Parameter Parameter + * | | + * Parameter ReadValue ReadValue + * | | \ | \ + * Reshape Reshape Assign Reshape Assign + * | | | + * Transpose Transpoe Transpose + * \ | / + * ScaledDotProductAttention + * | + * Tranpose + * | + * Reshape + * | + * Result + */ + +// +using InputShapeAndReshapeOrder = std::pair, std::vector>; +using FuseSDPAReshapeTransposeTestParams = std::tuple; +class FuseSDPAReshapeTransposeTest : virtual public ov::test::SubgraphBaseTest, + public testing::WithParamInterface, + public CPUTestsBase { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + ElementType inType; + InputShapeAndReshapeOrder inputShapeAndOrders; + std::tie(inType, inputShapeAndOrders) = obj.param; + std::ostringstream result; + std::vector& inputShapes = inputShapeAndOrders.first; + auto& reshapeOrderHS = inputShapeAndOrders.second; + result << "IS="; + for (const auto& shape : inputShapes) { + result << ov::test::utils::partialShape2str({shape.first}) << "_"; + } + result << "TS="; + for (const auto& shape : inputShapes) { + result << "("; + if (!shape.second.empty()) { + for (const auto& itr : shape.second) { + result << ov::test::utils::vec2str(itr); + } + } + result << ")_"; + } + result << "Prc=" << inType << "_"; + result << "ReshapeOrderHS="; + result << "("; + for (const auto& itr : reshapeOrderHS) { + result << itr << ","; + } + result << ")"; + + return result.str(); + } + + void SetUp() override { + ElementType inType; + InputShapeAndReshapeOrder inputShapeAndOrders; + std::tie(inType, inputShapeAndOrders) = this->GetParam(); + std::vector& inputShapes = inputShapeAndOrders.first; + auto& reshapeOrderHS = inputShapeAndOrders.second; + targetDevice = ov::test::utils::DEVICE_CPU; + rel_threshold = 1e-2f; + configuration[ov::hint::inference_precision.name()] = ov::element::f32; + if (inType == ElementType::bf16) { + configuration[ov::hint::inference_precision.name()] = ov::element::bf16; + rel_threshold = 0.01f; + } + init_input_shapes(inputShapes); + + // pre SDPA reshape->transpose + ov::ParameterVector inputParams(3); + ov::SinkVector sinkNodes; + OutputVector transposes(3); + for (size_t i = 0; i < 3u; i++) { + inputParams[i] = std::make_shared(inType, inputDynamicShapes[0]); + + auto reshape_axis = + ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 0, reshapeOrderHS[0], reshapeOrderHS[1]}); + + std::shared_ptr reshape_input_1 = inputParams[i]; + if (i > 0) { + auto var = std::make_shared( + ov::op::util::VariableInfo{inputDynamicShapes[0], inType, "var_" + std::to_string(i)}); + auto readvalue = std::make_shared(inputParams[i], var); + auto assign = std::make_shared(readvalue, var); + sinkNodes.emplace_back(assign); + reshape_input_1 = readvalue; + } + + auto reshape = std::make_shared(reshape_input_1, reshape_axis, true); + auto transposeOrder = ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}); + transposes[i] = std::make_shared(reshape, transposeOrder); + } + + auto sdpa = std::make_shared(transposes, false); + sdpa->set_friendly_name("mha"); + + // post SDPA transpose + reshape + auto postOrder = + ov::op::v0::Constant::create(ov::element::i64, {4}, std::vector{0, 2, 1, 3}); // BHLS -> BLHS + auto transposeSDPA = std::make_shared(sdpa, postOrder); + + auto constReshape = + ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 0, reshapeOrderHS[0] * reshapeOrderHS[1]}); + auto reshapeSDPA = std::make_shared(transposeSDPA, constReshape, true); // BLHS -> B,L,HxS + + function = std::make_shared(ov::OutputVector{reshapeSDPA}, + sinkNodes, + inputParams, + "FuseSDPAReshapeTranspose"); + targetDevice = ov::test::utils::DEVICE_CPU; + functionRefs = function->clone(); + pass::Manager manager; + // decompose ScaledDotProductAttention + manager.register_pass(); + manager.run_passes(functionRefs); + } + + template + static void strided_iota(IT first, size_t n, T value, T stride) { + for (size_t i = 0; i < n; i++) { + *first++ = value; + value += stride; + } + } + void generate(int idx, const std::vector& targetInputStaticShapes) { + inputs.clear(); + auto create_input = [this] (std::shared_ptr param, ov::Shape shape, float val) { + if (param->get_element_type() == ov::element::i32) { + ov::Tensor t{ov::element::i32, shape}; + auto size = ov::shape_size(shape); + auto* p = static_cast(t.data()); + auto start = static_cast(val); + for (size_t i = 0; i < size; i++) { + p[i] = (start + i) % size; + } + inputs.insert({param, t}); + } else if (param->get_element_type() == ov::element::f32) { + ov::Tensor t{ov::element::f32, shape}; + strided_iota(static_cast(t.data()), t.get_size(), val, 0.1f); + inputs.insert({param, t}); + } else { + ASSERT_TRUE(param->get_element_type() == ov::element::bf16); + ov::Tensor t{ov::element::bf16, shape}; + strided_iota(static_cast(t.data()), t.get_size(), val, 0.1f); + inputs.insert({param, t}); + } + }; + // q, k, v + create_input(function->get_parameters()[0], targetInputStaticShapes[0], idx + 1.0f); + create_input(function->get_parameters()[1], targetInputStaticShapes[0], idx + 2.0f); + create_input(function->get_parameters()[2], targetInputStaticShapes[0], idx + 3.0f); + } + void prepare() { + compile_model(); + inferRequest = compiledModel.create_infer_request(); + ASSERT_TRUE(inferRequest); + } + void reset() { + for (auto&& state : inferRequest.query_state()) { + state.reset(); + } + } + + std::vector run_test(std::shared_ptr model) { + function = model; + prepare(); + std::vector outputs; + int idx = 0; + for (auto&& shapes : targetStaticShapes) { + generate(idx++, shapes); + for (const auto& input : inputs) { + inferRequest.set_tensor(input.first, input.second); + } + inferRequest.infer(); + auto outputTensor = inferRequest.get_output_tensor(0); + ov::Tensor copy{outputTensor.get_element_type(), outputTensor.get_shape()}; + outputTensor.copy_to(copy); + outputs.push_back(copy); + reset(); + } + return outputs; + } +}; + +TEST_P(FuseSDPAReshapeTransposeTest, CompareWithRefs) { + SKIP_IF_CURRENT_TEST_IS_DISABLED(); + bool reshape_transpose_fused = false; + auto actualOutputs = run_test(function); + CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 1); + CheckNumberOfNodesWithType(compiledModel, "Reshape", 0); + CheckNumberOfNodesWithType(compiledModel, "Transpose", 0); + for (const auto& n : compiledModel.get_runtime_model()->get_ordered_ops()) { + if (n->get_friendly_name() == "mha/fused_reshape_transpose") { + reshape_transpose_fused = true; + } + } + ASSERT_TRUE(reshape_transpose_fused); + + auto expectedOutputs = run_test(functionRefs); + for (size_t i = 0; i < actualOutputs.size(); i++) { + ov::test::utils::compare(expectedOutputs[i], actualOutputs[i], abs_threshold, rel_threshold); + } +} + +namespace { +const std::vector inputShapeAndReshapeOrders = { + // + { + {{ + // Q,K,V:[B, L, H*S] + {{-1, -1, 4 * 16}, {{1, 1, 4 * 16}, {1, 2, 4 * 16}, {2, 2, 4 * 16}}}, + }, + // reshapeOrderHS + {4, 16}}, + }}; + +INSTANTIATE_TEST_SUITE_P(smoke_FuseSDPAReshapeTransposeTest, + FuseSDPAReshapeTransposeTest, + ::testing::Combine(::testing::Values(ElementType::f32), + ::testing::ValuesIn(inputShapeAndReshapeOrders)), + FuseSDPAReshapeTransposeTest::getTestCaseName); +} // namespace +} // namespace test +} // namespace ov