Skip to content

Commit

Permalink
[NPUW] Add Slice before last MatMul (#27229)
Browse files Browse the repository at this point in the history
  • Loading branch information
smirnov-alexey authored Oct 24, 2024
1 parent 9b377f4 commit 433e44e
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ DEFINE_OPT(NPUW_FOLD, bool, false, npuw::partitioning::fold, CompileTime);
DEFINE_OPT(NPUW_CWAI, bool, false, npuw::partitioning::cwai, CompileTime);
DEFINE_OPT(NPUW_DQ, bool, false, npuw::partitioning::dyn_quant, CompileTime);
DEFINE_OPT(NPUW_PMM, std::string, "2", npuw::partitioning::par_matmul_merge_dims, CompileTime);
// Opt-in switch (off by default): add a Slice before the last MatMul to reduce the output's dimension.
DEFINE_OPT(NPUW_SLICE_OUT, bool, false, npuw::partitioning::slice_out, CompileTime);
DEFINE_OPT(NPUW_HOST_GATHER, bool, true, npuw::partitioning::host_gather, CompileTime);
DEFINE_OPT(NPUW_SPATIAL, bool, false, npuw::partitioning::spatial, CompileTime);
DEFINE_OPT(NPUW_SPATIAL_NWAY, std::size_t, 128, npuw::partitioning::spatial_nway, CompileTime);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,14 @@ static constexpr ov::Property<bool> dyn_quant{"NPUW_DQ"};
*/
static constexpr ov::Property<std::string> par_matmul_merge_dims{"NPUW_PMM"};

/**
 * @brief
 * Type: bool.
 * Add a Slice before the last MatMul, reducing the output's dimension.
 * Default value: false.
 */
static constexpr ov::Property<bool> slice_out{"NPUW_SLICE_OUT"};

/**
* @brief
* Type: boolean.
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_npu/src/al/src/config/npuw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ void intel_npu::registerNPUWOptions(OptionsDesc& desc) {
desc.add<NPUW_CWAI>();
desc.add<NPUW_DQ>();
desc.add<NPUW_PMM>();
desc.add<NPUW_SLICE_OUT>();
desc.add<NPUW_SPATIAL>();
desc.add<NPUW_SPATIAL_NWAY>();
desc.add<NPUW_SPATIAL_DYN>();
Expand Down
11 changes: 11 additions & 0 deletions src/plugins/intel_npu/src/plugin/npuw/compiled_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,16 @@ ov::npuw::CompiledModel::CompiledModel(const std::shared_ptr<ov::Model>& model,
rewr.run_on_model(model);
}

if (m_cfg.get<::intel_npu::NPUW_SLICE_OUT>()) {
// Add Slice before last MatMul for the prefill model
ov::pass::GraphRewrite rewr;
rewr.add_matcher<ov::npuw::patterns::opt::SliceLastMatmul>();
rewr.add_matcher<ov::npuw::patterns::opt::SliceLastMatmulAdd>();
rewr.add_matcher<ov::npuw::patterns::opt::SliceLastMatmulTranspose>();
rewr.add_matcher<ov::npuw::patterns::opt::SliceLastMatmulMultiply>();
rewr.run_on_model(model);
}

auto partitioning = getPartitioning(model, m_cfg);
m_total_stat.gflops = partitioning.total_gflops;
m_total_stat.ops = partitioning.total_ops;
Expand Down Expand Up @@ -906,6 +916,7 @@ void ov::npuw::CompiledModel::implement_properties() {
BIND(npuw::partitioning::cwai, NPUW_CWAI),
BIND(npuw::partitioning::dyn_quant, NPUW_DQ),
BIND(npuw::partitioning::par_matmul_merge_dims, NPUW_PMM),
BIND(npuw::partitioning::slice_out, NPUW_SLICE_OUT),
BIND(npuw::partitioning::spatial, NPUW_SPATIAL),
BIND(npuw::partitioning::spatial_nway, NPUW_SPATIAL_NWAY),
BIND(npuw::partitioning::spatial_dyn, NPUW_SPATIAL_DYN),
Expand Down
158 changes: 146 additions & 12 deletions src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,7 @@

#include "../../logging.hpp"
#include "../../util.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/reduce_sum.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/split.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/ops.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/pass/pattern/op/label.hpp" // any_input
#include "openvino/pass/pattern/op/optional.hpp"
Expand Down Expand Up @@ -1296,6 +1285,151 @@ CompressDictMatMulf32::CompressDictMatMulf32(Context::Ref ctx) {
register_matcher(std::make_shared<opp::Matcher>(res, "OptCompressDictMatMulf32"), std::move(callback));
}

// Pattern: MatMul -> Result.
// When the MatMul's first input is a static rank-3 tensor with more than one
// element along axis 1, a Slice is inserted in front of that input so that only
// batch index 0 and the last position along axis 1 reach the MatMul.
SliceLastMatmul::SliceLastMatmul() {
    auto matmul = opp::wrap_type<ov::op::v0::MatMul>({opp::any_input(), opp::any_input()});
    auto res = opp::wrap_type<ov::op::v0::Result>({matmul});

    // Note: Use [=] to make sure the above objects stay alive in the callback
    auto callback = [=](ov::pass::pattern::Matcher& m) {
        auto& pattern_map = m.get_pattern_value_map();

        const auto mm_output = pattern_map.at(matmul);
        auto* mm_node = mm_output.get_node();

        // NOTE(review): get_shape() presumably requires a static shape here - confirm for dynamic models.
        const ov::Shape in_shape = mm_node->input(0).get_shape();

        // Only rank-3 inputs with something to cut along axis 1 are rewritten.
        if (in_shape.size() != 3 || in_shape[1] <= 1) {
            return false;  // root hasn't changed
        }

        // Builds a 3-element i32 Constant holding {a, b, c}.
        const auto make_i32x3 = [](int32_t a, int32_t b, int32_t c) {
            return std::make_shared<ov::op::v0::Constant>(ov::element::i32,
                                                          ov::Shape{3},
                                                          std::vector<int32_t>{a, b, c});
        };

        // Slice window: [0:1, last:last+1, 0:hidden] with unit steps.
        auto start = make_i32x3(0, static_cast<int32_t>(in_shape[1] - 1), 0);
        auto stop = make_i32x3(1, static_cast<int32_t>(in_shape[1]), static_cast<int32_t>(in_shape[2]));
        auto step = make_i32x3(1, 1, 1);

        auto slice = std::make_shared<ov::op::v8::Slice>(mm_node->input_value(0), start, stop, step);

        // Re-route the MatMul's first input through the new Slice.
        mm_node->input(0).replace_source_output(slice);

        return true;  // root was changed
    };
    register_matcher(std::make_shared<opp::Matcher>(res, "SliceLastMatmul"), std::move(callback));
}

// Pattern: MatMul -> Add -> Result.
// When the MatMul's first input is a static rank-3 tensor with more than one
// element along axis 1, a Slice is inserted in front of that input so that only
// batch index 0 and the last position along axis 1 reach the MatMul.
SliceLastMatmulAdd::SliceLastMatmulAdd() {
    auto matmul = opp::wrap_type<ov::op::v0::MatMul>({opp::any_input(), opp::any_input()});
    auto add = opp::wrap_type<ov::op::v1::Add>({matmul, opp::any_input()});
    auto res = opp::wrap_type<ov::op::v0::Result>({add});

    // Note: Use [=] to make sure the above objects stay alive in the callback
    auto callback = [=](ov::pass::pattern::Matcher& m) {
        auto& pattern_map = m.get_pattern_value_map();

        const auto mm_output = pattern_map.at(matmul);
        auto* mm_node = mm_output.get_node();

        // NOTE(review): get_shape() presumably requires a static shape here - confirm for dynamic models.
        const ov::Shape in_shape = mm_node->input(0).get_shape();

        // Only rank-3 inputs with something to cut along axis 1 are rewritten.
        if (in_shape.size() != 3 || in_shape[1] <= 1) {
            return false;  // root hasn't changed
        }

        // Builds a 3-element i32 Constant holding {a, b, c}.
        const auto make_i32x3 = [](int32_t a, int32_t b, int32_t c) {
            return std::make_shared<ov::op::v0::Constant>(ov::element::i32,
                                                          ov::Shape{3},
                                                          std::vector<int32_t>{a, b, c});
        };

        // Slice window: [0:1, last:last+1, 0:hidden] with unit steps.
        auto start = make_i32x3(0, static_cast<int32_t>(in_shape[1] - 1), 0);
        auto stop = make_i32x3(1, static_cast<int32_t>(in_shape[1]), static_cast<int32_t>(in_shape[2]));
        auto step = make_i32x3(1, 1, 1);

        auto slice = std::make_shared<ov::op::v8::Slice>(mm_node->input_value(0), start, stop, step);

        // Re-route the MatMul's first input through the new Slice.
        mm_node->input(0).replace_source_output(slice);

        return true;  // root was changed
    };
    register_matcher(std::make_shared<opp::Matcher>(res, "SliceLastMatmulAdd"), std::move(callback));
}

// Pattern: MatMul -> Transpose -> Result.
// When the MatMul's first input is a static rank-3 tensor with more than one
// element along axis 1, a Slice is inserted in front of that input so that only
// batch index 0 and the last position along axis 1 reach the MatMul.
//
// Fix: the Result pattern is now rooted on the Transpose node. Previously the
// Transpose pattern was created but unused and the Result was attached directly
// to the MatMul, which made this matcher a duplicate of SliceLastMatmul and left
// MatMul -> Transpose -> Result graphs unhandled.
SliceLastMatmulTranspose::SliceLastMatmulTranspose() {
    auto matmul = opp::wrap_type<ov::op::v0::MatMul>({opp::any_input(), opp::any_input()});
    auto transpose = opp::wrap_type<ov::op::v1::Transpose>({matmul, opp::any_input()});
    auto res = opp::wrap_type<ov::op::v0::Result>({transpose});

    // Note: Use [=] to make sure the above objects stay alive in the callback
    auto callback = [=](ov::pass::pattern::Matcher& m) {
        auto& node_to_output = m.get_pattern_value_map();

        auto matched_out_matmul = node_to_output.at(matmul);
        auto* matched_matmul = matched_out_matmul.get_node();

        // NOTE(review): get_shape() presumably requires a static shape here - confirm for dynamic models.
        const ov::Shape shape = matched_matmul->input(0).get_shape();

        // Only rank-3 inputs with something to cut along axis 1 are rewritten.
        if (shape.size() != 3 || shape[1] <= 1) {
            return false;  // root hasn't changed
        }

        // Slice window: [0:1, last:last+1, 0:hidden] with unit steps.
        auto start = std::make_shared<ov::op::v0::Constant>(
            ov::element::i32,
            ov::Shape{3},
            std::vector<int32_t>{0, static_cast<int32_t>(shape[1] - 1), 0});
        auto stop = std::make_shared<ov::op::v0::Constant>(
            ov::element::i32,
            ov::Shape{3},
            std::vector<int32_t>{1, static_cast<int32_t>(shape[1]), static_cast<int32_t>(shape[2])});
        auto step =
            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, std::vector<int32_t>{1, 1, 1});

        auto slice = std::make_shared<ov::op::v8::Slice>(matched_matmul->input_value(0), start, stop, step);

        // Re-route the MatMul's first input through the new Slice.
        matched_matmul->input(0).replace_source_output(slice);

        return true;  // root was changed
    };
    register_matcher(std::make_shared<opp::Matcher>(res, "SliceLastMatmulTranspose"), std::move(callback));
}

// Pattern: MatMul -> Divide -> Tanh -> Multiply -> Result.
// When the MatMul's first input is a static rank-3 tensor with more than one
// element along axis 1, a Slice is inserted in front of that input so that only
// batch index 0 and the last position along axis 1 reach the MatMul.
SliceLastMatmulMultiply::SliceLastMatmulMultiply() {
    auto matmul = opp::wrap_type<ov::op::v0::MatMul>({opp::any_input(), opp::any_input()});
    auto div = opp::wrap_type<ov::op::v1::Divide>({matmul, opp::any_input()});
    auto tanh = opp::wrap_type<ov::op::v0::Tanh>({div});
    auto multiply = opp::wrap_type<ov::op::v1::Multiply>({tanh, opp::any_input()});
    auto res = opp::wrap_type<ov::op::v0::Result>({multiply});

    // Note: Use [=] to make sure the above objects stay alive in the callback
    auto callback = [=](ov::pass::pattern::Matcher& m) {
        auto& pattern_map = m.get_pattern_value_map();

        const auto mm_output = pattern_map.at(matmul);
        auto* mm_node = mm_output.get_node();

        // NOTE(review): get_shape() presumably requires a static shape here - confirm for dynamic models.
        const ov::Shape in_shape = mm_node->input(0).get_shape();

        // Only rank-3 inputs with something to cut along axis 1 are rewritten.
        if (in_shape.size() != 3 || in_shape[1] <= 1) {
            return false;  // root hasn't changed
        }

        // Builds a 3-element i32 Constant holding {a, b, c}.
        const auto make_i32x3 = [](int32_t a, int32_t b, int32_t c) {
            return std::make_shared<ov::op::v0::Constant>(ov::element::i32,
                                                          ov::Shape{3},
                                                          std::vector<int32_t>{a, b, c});
        };

        // Slice window: [0:1, last:last+1, 0:hidden] with unit steps.
        auto start = make_i32x3(0, static_cast<int32_t>(in_shape[1] - 1), 0);
        auto stop = make_i32x3(1, static_cast<int32_t>(in_shape[1]), static_cast<int32_t>(in_shape[2]));
        auto step = make_i32x3(1, 1, 1);

        auto slice = std::make_shared<ov::op::v8::Slice>(mm_node->input_value(0), start, stop, step);

        // Re-route the MatMul's first input through the new Slice.
        mm_node->input(0).replace_source_output(slice);

        return true;  // root was changed
    };
    register_matcher(std::make_shared<opp::Matcher>(res, "SliceLastMatmulMultiply"), std::move(callback));
}

} // namespace opt
} // namespace patterns
} // namespace npuw
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,27 @@ class CompressDictMatMulf32 : public ov::pass::MatcherPass {
CompressDictMatMulf32(Context::Ref ctx);
};

// Slice last Matmul
// Matcher pass for a MatMul that feeds a Result directly: inserts a Slice on the
// MatMul's first input (see the constructor for the exact pattern and conditions).
class SliceLastMatmul : public ov::pass::MatcherPass {
public:
SliceLastMatmul();
};

// Slice-last-MatMul variant for a MatMul followed by an Add before the Result
// (see the constructor for the exact pattern and conditions).
class SliceLastMatmulAdd : public ov::pass::MatcherPass {
public:
SliceLastMatmulAdd();
};

// Slice-last-MatMul variant presumably meant for a MatMul followed by a Transpose
// before the Result - verify against the pattern registered in the constructor.
class SliceLastMatmulTranspose : public ov::pass::MatcherPass {
public:
SliceLastMatmulTranspose();
};

// Slice-last-MatMul variant for a MatMul followed by Divide -> Tanh -> Multiply
// before the Result (see the constructor for the exact pattern and conditions).
class SliceLastMatmulMultiply : public ov::pass::MatcherPass {
public:
SliceLastMatmulMultiply();
};

} // namespace opt
} // namespace patterns
} // namespace npuw
Expand Down

0 comments on commit 433e44e

Please sign in to comment.