Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

metadata parsing #1490

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 75 additions & 5 deletions onnx/src/model.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::path::PathBuf;
use std::{fs, path};
use std::{ fs, path};

use std::collections::HashMap;

Expand All @@ -8,7 +8,7 @@ use tract_hir::prelude::tract_itertools::Itertools;

use crate::data_resolver::{self, ModelDataResolver};
use crate::pb::type_proto::Value;
use crate::pb::{self, TensorProto, TypeProto};
use crate::pb::{self, GraphProto, OperatorSetIdProto, TensorProto, TypeProto};
use crate::tensor::{load_tensor, translate_inference_fact};
use prost::Message;

Expand Down Expand Up @@ -213,6 +213,60 @@ impl OnnxOpRegister {
}
}

#[derive(Debug, Clone)]
pub struct OnnxMetadata {
pub ir_version: i64,
pub opset_import: Vec<OperatorSetIdProto>,
pub producer_name: String,
pub producer_version: String,
pub domain: String,
pub model_version: i64,
pub doc_string: String,
pub graph: Option<GraphProto>,
pub metadata_props: HashMap<String,String>
}


impl OnnxMetadata {
pub fn get_metadata(model_proto: &pb::ModelProto ) -> TractResult<OnnxMetadata> {
let parse_metadata_props: HashMap<String,String> = model_proto.to_owned().metadata_props
.into_iter().map(|entry| (entry.key, entry.value)).collect();
Ok(OnnxMetadata {
ir_version: model_proto.ir_version,
opset_import: model_proto.clone().opset_import,
producer_name: model_proto.clone().producer_name,
producer_version: model_proto.clone().producer_version,
domain: model_proto.clone().domain,
model_version: model_proto.model_version,
doc_string: model_proto.clone().doc_string,
graph: match model_proto.graph {
Some(_) => model_proto.clone().graph ,
None => None
},
metadata_props: parse_metadata_props
})
}
}

impl Default for OnnxMetadata {
fn default() -> Self {
let _opset = OperatorSetIdProto::default();
let graph = GraphProto::default();
OnnxMetadata {
ir_version: 0,
opset_import: vec![_opset],
producer_name: String::from(""),
producer_version: String::from(""),
domain: String::from(""),
model_version: 0,
doc_string: String::from(""),
graph: Some(graph),
metadata_props: HashMap::new()
}
}
}


#[derive(Clone)]
pub struct Onnx {
pub op_register: OnnxOpRegister,
Expand All @@ -236,6 +290,23 @@ impl Onnx {
pub fn parse(&self, proto: &pb::ModelProto, path: Option<&str>) -> TractResult<ParseResult> {
self.parse_with_template(proto, path, Default::default())
}

pub fn load_model_with_metadata(&mut self, model_path: impl AsRef<path::Path>) -> TractResult<(InferenceModel, OnnxMetadata)>{
let mut path = PathBuf::new();
path.push(&model_path);
let proto = self.proto_model_for_path(&model_path)?;
let mut dir: Option<&str> = None;
if let Some(dir_opt) = path.parent() {
dir = dir_opt.to_str();
}
let ParseResult { model, unresolved_inputs, .. } = self.parse(&proto, dir)?;
if unresolved_inputs.len() > 0 {
bail!("Could not resolve inputs at top-level: {:?}", unresolved_inputs)
}
let _metadata = OnnxMetadata::get_metadata(&proto)?;
Ok((model, _metadata))
}

pub fn parse_with_template(
&self,
proto: &pb::ModelProto,
Expand All @@ -248,8 +319,8 @@ impl Onnx {
.find(|import| import.domain.is_empty() || import.domain == "ai.onnx")
.map(|op| op.version)
.unwrap_or(0);
let graph =
proto.graph.as_ref().ok_or_else(|| anyhow!("model proto does not contain a graph"))?;
// self.metadata = OnnxMetadata::get_metadata(&proto)?;
let graph = proto.graph.as_ref().ok_or_else(|| anyhow!("model proto does not contain a graph"))?;
debug!("ONNX operator set version: {:?}", onnx_operator_set_version);
if onnx_operator_set_version != 0 && !(9..19).contains(&onnx_operator_set_version) {
warn!("ONNX operator for your model is {}, tract is only tested against \
Expand All @@ -267,7 +338,6 @@ impl Onnx {
trace!("created ParsingContext");
ctx.parse_graph(graph)
}

pub fn with_ignore_output_shapes(self, ignore: bool) -> Onnx {
Self { use_output_shapes: !ignore, ..self }
}
Expand Down
Loading