Skip to content

Commit

Permalink
NPUW: Extend HostGather to more cases (openvinotoolkit#27138)
Browse files Browse the repository at this point in the history
### Details:
 - *item1*
 - *...*

### Tickets:
 - *ticket-id*
  • Loading branch information
dmatveev authored and CuriousPanCake committed Nov 6, 2024
1 parent 1ba93ad commit 3398f28
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 13 deletions.
27 changes: 14 additions & 13 deletions src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ Context::PPtr Context::host_gather(Context::PPtr w, Context::PPtr ids) {
}

namespace opp = ov::pass::pattern;
namespace uat = ov::npuw::util::at;

// FROM:
// ???(Act) ----------------------------------->
Expand Down Expand Up @@ -802,7 +803,7 @@ DQLiftGatherAsymCW::DQLiftGatherAsymCW() {
auto qcvtm = opp::wrap_type<ov::op::v0::Convert>({qmuls});

auto pids = opp::wrap_type<ov::op::v0::Parameter>();
auto cvtids = opp::wrap_type<ov::op::v0::Convert>({pids});
auto cvtids = opp::optional<ov::op::v0::Convert>({pids->output(0)});
auto gather = opp::wrap_type<ov::op::v8::Gather>({qcvtm, cvtids, opp::any_input()});

// Note: Use [=] to make sure the above objects stay alive in the callback
Expand All @@ -813,7 +814,7 @@ DQLiftGatherAsymCW::DQLiftGatherAsymCW() {
auto matched_out_w = node_to_output.at(qweight);
auto matched_out_z = node_to_output.at(qzerop);
auto matched_out_s = node_to_output.at(qcoeff);
auto matched_out_ids = node_to_output.at(cvtids);
auto matched_out_ids = uat::_(node_to_output).at_or_at(cvtids, pids);
const auto& matched_out_gather = node_to_output.at(gather);

// Replicate the compute part
Expand Down Expand Up @@ -847,7 +848,7 @@ DQLiftGatherSymCW::DQLiftGatherSymCW() {
auto qcvtm = opp::wrap_type<ov::op::v0::Convert>({qmuls});

auto pids = opp::wrap_type<ov::op::v0::Parameter>();
auto cvtids = opp::wrap_type<ov::op::v0::Convert>({pids});
auto cvtids = opp::optional<ov::op::v0::Convert>({pids->output(0)});
auto gather = opp::wrap_type<ov::op::v8::Gather>({qcvtm, cvtids, opp::any_input()});

// Note: Use [=] to make sure the above objects stay alive in the callback
Expand All @@ -856,7 +857,7 @@ DQLiftGatherSymCW::DQLiftGatherSymCW() {

auto matched_out_w = node_to_output.at(qweight);
auto matched_out_s = node_to_output.at(qcoeff);
auto matched_out_ids = node_to_output.at(cvtids);
auto matched_out_ids = uat::_(node_to_output).at_or_at(cvtids, pids);
const auto& matched_out_gather = node_to_output.at(gather);

// Create new gathers on W and S, connect respectively
Expand Down Expand Up @@ -888,7 +889,7 @@ DQLiftGatherSymGQ::DQLiftGatherSymGQ() {
auto qcvtm = opp::wrap_type<ov::op::v0::Convert>({qreshp});

auto pids = opp::wrap_type<ov::op::v0::Parameter>();
auto cvtids = opp::wrap_type<ov::op::v0::Convert>({pids});
auto cvtids = opp::optional<ov::op::v0::Convert>({pids->output(0)});
auto gather = opp::wrap_type<ov::op::v8::Gather>({qcvtm, cvtids, opp::any_input()});

// Note: Use [=] to make sure the above objects stay alive in the callback
Expand All @@ -898,7 +899,7 @@ DQLiftGatherSymGQ::DQLiftGatherSymGQ() {
// Create new gathers on W and S respectively
auto matched_out_w = node_to_output.at(qweight);
auto matched_out_s = node_to_output.at(qcoeff);
auto matched_out_ids = node_to_output.at(cvtids);
auto matched_out_ids = uat::_(node_to_output).at_or_at(cvtids, pids);
const auto& matched_out_gather = node_to_output.at(gather);

auto matched_gather_shape = matched_out_gather.get_shape();
Expand Down Expand Up @@ -934,7 +935,7 @@ DQLiftGatherSymGQ::DQLiftGatherSymGQ() {
// compile-time converts asymmetric MM to fp16, do the same thing here
DQUnpackDictGatherCWu::DQUnpackDictGatherCWu(Context::Ref ctx) {
auto pids = opp::wrap_type<ov::op::v0::Parameter>();
auto cvtids = opp::wrap_type<ov::op::v0::Convert>({pids});
auto cvtids = opp::optional<ov::op::v0::Convert>({pids->output(0)});

auto qweight = opp::wrap_type<ov::op::v0::Parameter>();
auto qzerop = opp::wrap_type<ov::op::v0::Parameter>();
Expand All @@ -956,7 +957,7 @@ DQUnpackDictGatherCWu::DQUnpackDictGatherCWu(Context::Ref ctx) {
auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr();
auto matched_node_qzerop = node_to_output.at(qzerop).get_node_shared_ptr();
auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
auto matched_out_ids = node_to_output.at(cvtids);
auto matched_out_ids = uat::_(node_to_output).at_or_at(cvtids, pids);
auto matched_node_cvt = node_to_output.at(qcvtm).get_node_shared_ptr();

auto matched_qweight = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_qweight);
Expand All @@ -979,7 +980,7 @@ DQUnpackDictGatherCWu::DQUnpackDictGatherCWu(Context::Ref ctx) {
// block (mainly, a head) was turned a function (e.g. with FUNCALL_FOR_ALL)
DQUnpackDictGatherGQi::DQUnpackDictGatherGQi(Context::Ref ctx) {
auto pids = opp::wrap_type<ov::op::v0::Parameter>();
auto cvtids = opp::wrap_type<ov::op::v0::Convert>({pids});
auto cvtids = opp::optional<ov::op::v0::Convert>({pids->output(0)});

auto qweight = opp::wrap_type<ov::op::v0::Parameter>();
auto qcoeff = opp::wrap_type<ov::op::v0::Parameter>();
Expand All @@ -997,7 +998,7 @@ DQUnpackDictGatherGQi::DQUnpackDictGatherGQi(Context::Ref ctx) {

auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr();
auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
auto matched_out_ids = node_to_output.at(cvtids);
auto matched_out_ids = uat::_(node_to_output).at_or_at(cvtids, pids);
auto matched_node_cvt = node_to_output.at(qcvtm).get_node_shared_ptr();

auto matched_qweight = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_qweight);
Expand All @@ -1022,7 +1023,7 @@ DQUnpackDictGatherGQi::DQUnpackDictGatherGQi(Context::Ref ctx) {
// * - DictGather-related transformations
HostGather::HostGather(Context::Ref ctx) {
auto pids = opp::wrap_type<ov::op::v0::Parameter>();
auto cvtids = opp::wrap_type<ov::op::v0::Convert>({pids});
auto cvtids = opp::optional<ov::op::v0::Convert>({pids->output(0)});

auto qweight = opp::wrap_type<ov::op::v0::Parameter>();
auto qgthrw = opp::wrap_type<ov::op::v8::Gather>({qweight, cvtids, opp::any_input()});
Expand Down Expand Up @@ -1078,7 +1079,7 @@ HostGather::HostGather(Context::Ref ctx) {
// due to i4-to-fp16 conversion.
HostGatherDQ::HostGatherDQ(Context::Ref ctx) {
auto pids = opp::wrap_type<ov::op::v0::Parameter>();
auto cvtids = opp::wrap_type<ov::op::v0::Convert>({pids});
auto cvtids = opp::optional<ov::op::v0::Convert>({pids->output(0)});

auto qweight = opp::wrap_type<ov::op::v0::Parameter>();
auto qcvtw = opp::wrap_type<ov::op::v0::Convert>({qweight});
Expand All @@ -1105,7 +1106,7 @@ HostGatherDQ::HostGatherDQ(Context::Ref ctx) {
const auto& matched_out_qweight = node_to_output.at(qweight);
auto qweight_type = matched_out_qweight.get_element_type();

if (out_len >= 2048 && qweight_type == ov::element::i4) {
if (out_len >= 2048 && (qweight_type == ov::element::i4 || qweight_type == ov::element::i8)) {
auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr();
auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
auto matched_node_ids = node_to_output.at(pids).get_node_shared_ptr();
Expand Down
14 changes: 14 additions & 0 deletions src/plugins/intel_npu/src/plugin/npuw/util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,24 @@ struct Impl {
return iter->second;
}

template <typename K>
V& at_or_at(const K& k1, const K& k2) {
const auto iter = m->find(k1);
if (iter == m->end()) {
return at(k2);
}
return iter->second;
}

template <typename K>
const V& at(const K& k) const {
return const_cast<Impl*>(this)->at(k);
}

template <typename K>
const V& at_or_at(const K& k1, const K& k2) const {
return const_cast<Impl*>(this)->at_or_at(k1, k2);
}
};

template <typename M>
Expand Down

0 comments on commit 3398f28

Please sign in to comment.