-
Notifications
You must be signed in to change notification settings - Fork 1.7k
rust: support graph sub-plugins #4569
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b53e1a1
d2ceae4
a1c4362
61a5e01
ede7691
22dd050
a4fe0c2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,6 +27,9 @@ pub(crate) const IMAGES_PLUGIN_NAME: &str = "images"; | |
| pub(crate) const AUDIO_PLUGIN_NAME: &str = "audio"; | ||
| pub(crate) const GRAPHS_PLUGIN_NAME: &str = "graphs"; | ||
| pub(crate) const GRAPH_TAGGED_RUN_METADATA_PLUGIN_NAME: &str = "graph_tagged_run_metadata"; | ||
| pub(crate) const GRAPH_RUN_METADATA_PLUGIN_NAME: &str = "graph_run_metadata"; | ||
| pub(crate) const GRAPH_RUN_METADATA_WITH_GRAPH_PLUGIN_NAME: &str = "graph_run_metadata_graph"; | ||
| pub(crate) const GRAPH_KERAS_MODEL_PLUGIN_NAME: &str = "graph_keras_model"; | ||
|
|
||
| /// The inner contents of a single value from an event. | ||
| /// | ||
|
|
@@ -76,11 +79,17 @@ impl EventValue { | |
|
|
||
| /// Consumes this event value and enriches it into a blob sequence. | ||
| /// | ||
| /// For now, this supports `GraphDef`s, tagged run metadata protos, summaries with `image` or | ||
| /// `audio`, or summaries with `tensor` set to a rank-1 tensor of type `DT_STRING`. If the | ||
| /// summary metadata indicates that this is audio data, `tensor` may also be a string tensor of | ||
| /// shape `[k, 2]`, in which case the second axis is assumed to represent string labels and is | ||
| /// dropped entirely. | ||
| /// This supports: | ||
| /// | ||
| /// - `GraphDef`s; | ||
| /// - tagged run metadata protos; | ||
| /// - summaries with TensorFlow 1.x `image` or `audio`; | ||
| /// - summaries with `tensor` set to a rank-1 tensor of type `DT_STRING`; | ||
| /// - for audio metadata, summaries with `tensor` set to a shape-`[k, 2]` tensor of type | ||
| /// `DT_STRING`, in which case the second axis is assumed to represent string labels and is | ||
| /// dropped entirely; | ||
| /// - for graph sub-plugin metadata, summaries with `tensor` set to a rank-0 tensor of type | ||
| /// `DT_STRING`, which is converted to a shape-`[1]` tensor. | ||
| pub fn into_blob_sequence( | ||
| self, | ||
| metadata: &pb::SummaryMetadata, | ||
|
|
@@ -108,10 +117,7 @@ impl EventValue { | |
| Ok(BlobSequenceValue(tp.string_val)) | ||
| } else if shape.dim.len() == 2 | ||
| && shape.dim[1].size == 2 | ||
| && (metadata | ||
| .plugin_data | ||
| .as_ref() | ||
| .map_or(false, |pd| pd.plugin_name == AUDIO_PLUGIN_NAME)) | ||
| && is_plugin(&metadata, AUDIO_PLUGIN_NAME) | ||
| { | ||
| // Extract just the actual audio clips along the first axis. | ||
| let audio: Vec<Vec<u8>> = tp | ||
|
|
@@ -120,6 +126,14 @@ impl EventValue { | |
| .map(|chunk| std::mem::take(&mut chunk[0])) | ||
| .collect(); | ||
| Ok(BlobSequenceValue(audio)) | ||
| } else if shape.dim.is_empty() | ||
| && tp.string_val.len() == 1 | ||
| && (is_plugin(&metadata, GRAPH_RUN_METADATA_PLUGIN_NAME) | ||
| || is_plugin(&metadata, GRAPH_RUN_METADATA_WITH_GRAPH_PLUGIN_NAME) | ||
| || is_plugin(&metadata, GRAPH_KERAS_MODEL_PLUGIN_NAME)) | ||
| { | ||
| let data = tp.string_val.into_iter().next().unwrap(); | ||
| Ok(BlobSequenceValue(vec![data])) | ||
| } else { | ||
| Err(DataLoss) | ||
| } | ||
|
|
@@ -157,6 +171,13 @@ fn tensor_proto_to_scalar(tp: &pb::TensorProto) -> Option<f32> { | |
| } | ||
| } | ||
|
|
||
| /// Tests whether `md` has plugin name `plugin_name`. | ||
| fn is_plugin(md: &pb::SummaryMetadata, plugin_name: &str) -> bool { | ||
| md.plugin_data | ||
| .as_ref() | ||
| .map_or(false, |pd| pd.plugin_name == plugin_name) | ||
| } | ||
|
|
||
| /// A value from an `Event` whose `graph_def` field is set. | ||
| /// | ||
| /// This contains the raw bytes of a serialized `GraphDef` proto. It implies a fixed tag name and | ||
|
|
@@ -242,7 +263,11 @@ impl SummaryValue { | |
| Some(SCALARS_PLUGIN_NAME) => { | ||
| md.data_class = pb::DataClass::Scalar.into(); | ||
| } | ||
| Some(IMAGES_PLUGIN_NAME) | Some(AUDIO_PLUGIN_NAME) => { | ||
| Some(IMAGES_PLUGIN_NAME) | ||
| | Some(AUDIO_PLUGIN_NAME) | ||
| | Some(GRAPH_RUN_METADATA_PLUGIN_NAME) | ||
| | Some(GRAPH_RUN_METADATA_WITH_GRAPH_PLUGIN_NAME) | ||
| | Some(GRAPH_KERAS_MODEL_PLUGIN_NAME) => { | ||
| md.data_class = pb::DataClass::BlobSequence.into(); | ||
| } | ||
| _ => {} | ||
|
|
@@ -671,6 +696,48 @@ mod tests { | |
| ); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_graph_subplugins() { | ||
| for &plugin_name in &[ | ||
| GRAPH_RUN_METADATA_PLUGIN_NAME, | ||
| GRAPH_RUN_METADATA_WITH_GRAPH_PLUGIN_NAME, | ||
| GRAPH_KERAS_MODEL_PLUGIN_NAME, | ||
| ] { | ||
| let md = pb::SummaryMetadata { | ||
| plugin_data: Some(PluginData { | ||
| plugin_name: plugin_name.to_string(), | ||
| content: b"1".to_vec(), | ||
| ..Default::default() | ||
| }), | ||
| ..Default::default() | ||
| }; | ||
| let v = SummaryValue(Box::new(Value::Tensor(pb::TensorProto { | ||
| dtype: pb::DataType::DtString.into(), | ||
| tensor_shape: Some(tensor_shape(&[])), | ||
| string_val: vec![b"some-graph-proto".to_vec()], | ||
| ..Default::default() | ||
| }))); | ||
|
|
||
| // Test both metadata and enrichment here, for convenience. | ||
| let initial_metadata = v.initial_metadata(Some(md)); | ||
| assert_eq!( | ||
| *initial_metadata, | ||
| pb::SummaryMetadata { | ||
| plugin_data: Some(PluginData { | ||
| plugin_name: plugin_name.to_string(), | ||
| content: b"1".to_vec(), | ||
| ..Default::default() | ||
| }), | ||
| data_class: pb::DataClass::BlobSequence.into(), | ||
| ..Default::default() | ||
| }, | ||
| ); | ||
|
Comment on lines
+722
to
+734
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No AI required: this test feels too mechanical :\
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Like, too much of “assert that Or maybe you meant that the test implementation spends a lot of lines on Noted that you said “no AI required”, so feel free to reply or not, as |
||
| let expected_enriched = BlobSequenceValue(vec![b"some-graph-proto".to_vec()]); | ||
| let actual_enriched = EventValue::Summary(v).into_blob_sequence(&initial_metadata); | ||
| assert_eq!(actual_enriched, Ok(expected_enriched)); | ||
| } | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_enrich_graph_def() { | ||
| let v = EventValue::GraphDef(GraphDefValue(vec![1, 2, 3, 4])); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Future plan question: are we going to trust the
data_classproperty on the proto in the future? I believe the graph_run* and graph_keras* data are all annotated correctly with the data_class.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We already trust the data class: the first clause of this
matchis(Likewise for Python TensorBoard.)
But I’m not sure if the graph data does set data classes? The data
written by
:graphs_demodoesn’t seem to have them, and I don’t seedata classes in the writing code. In any case, data classes were
introduced more recently than the
graph_*writing code, so old datamust not have them, I think.
So eventually I would probably like plugins to write summaries with data
class explicitly set, but I expect to maintain this compatibility layer
indefinitely, and that seems okay to me. WDYT?