Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,23 @@ class XmlSerializer : public ov::util::XmlSerializer {
false,
ov::element::dynamic,
false),
m_base_constant_writer(std::ref(constant_write_handler)),
m_weightless_constant_writer(weightless_constant_writer
? weightless_constant_writer
: std::make_shared<WeightlessWriter>(constant_write_handler)) {}

private:
/**
* @brief Toggles between the two writers.
*/
ov::util::ConstantWriter& get_constant_write_handler() override;

/**
* @brief Overriden in order to choose which weights writer will be used based on the occurrence of the
* "WeightsPointerAttribute".
*/
bool append_node_attributes(ov::Node& node) override;

std::unique_ptr<ov::util::XmlSerializer> make_visitor(pugi::xml_node& data,
const std::string& node_type_name,
ov::util::ConstantWriter& constant_write_handler,
Expand All @@ -56,7 +66,15 @@ class XmlSerializer : public ov::util::XmlSerializer {
ov::element::Type,
bool) const override;

/**
* @brief The base OV writer, copies the weights in a dedicated buffer.
*
* @note Ideally, we would not require this writer at all. The current algorithm does not handle subgraphs properly,
* so falling back to copying a part of the weights is a temporary fix.
*/
std::reference_wrapper<ov::util::ConstantWriter> m_base_constant_writer;
std::shared_ptr<WeightlessWriter> m_weightless_constant_writer = nullptr;
bool m_use_weightless_writer = false;
};

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,20 @@
namespace intel_npu {

ov::util::ConstantWriter& XmlSerializer::get_constant_write_handler() {
return *m_weightless_constant_writer;
if (m_use_weightless_writer) {
return *m_weightless_constant_writer;
} else {
return m_base_constant_writer;
}
}

bool XmlSerializer::append_node_attributes(ov::Node& node) {
// If the "WeightsPointerAttribute" is found, then we have the metadata required to avoid copying the weights
// corresponding to this node.
m_use_weightless_writer = node.get_rt_info().count(WeightsPointerAttribute::get_type_info_static()) != 0;
auto result = ov::util::XmlSerializer::append_node_attributes(node);
m_use_weightless_writer = false;
return result;
}

std::unique_ptr<ov::util::XmlSerializer> XmlSerializer::make_visitor(pugi::xml_node& data,
Expand Down
Loading