diff --git a/tensorboard/data/server/data_compat.rs b/tensorboard/data/server/data_compat.rs index cc7c3f69d6..5f529798b8 100644 --- a/tensorboard/data/server/data_compat.rs +++ b/tensorboard/data/server/data_compat.rs @@ -26,13 +26,15 @@ pub(crate) const SCALARS_PLUGIN_NAME: &str = "scalars"; pub(crate) const IMAGES_PLUGIN_NAME: &str = "images"; pub(crate) const AUDIO_PLUGIN_NAME: &str = "audio"; pub(crate) const GRAPHS_PLUGIN_NAME: &str = "graphs"; +pub(crate) const GRAPH_TAGGED_RUN_METADATA_PLUGIN_NAME: &str = "graph_tagged_run_metadata"; /// The inner contents of a single value from an event. /// /// This does not include associated step, wall time, tag, or summary metadata information. Step /// and wall time are available on every event and just not tracked here. Tag and summary metadata -/// information are materialized on `Event`s whose `oneof what` is `summary`, but implicit for -/// graph defs. See [`GraphDefValue::initial_metadata`] and [`SummaryValue::initial_metadata`] for +/// information are materialized on `Event`s whose `oneof what` is `tagged_run_metadata` or +/// `summary`, but implicit for graph defs. See [`GraphDefValue::initial_metadata`], +/// [`TaggedRunMetadataValue::initial_metadata`], and [`SummaryValue::initial_metadata`] for /// type-specific helpers to determine summary metadata given appropriate information. /// /// This is kept as close as possible to the on-disk event representation, since every record in @@ -46,6 +48,7 @@ pub(crate) const GRAPHS_PLUGIN_NAME: &str = "graphs"; #[derive(Debug)] pub enum EventValue { GraphDef(GraphDefValue), + TaggedRunMetadata(TaggedRunMetadataValue), Summary(SummaryValue), } @@ -53,11 +56,12 @@ impl EventValue { /// Consumes this event value and enriches it into a scalar. /// /// This supports `simple_value` (TF 1.x) summaries as well as rank-0 tensors of type - /// `DT_FLOAT`. Returns `DataLoss` if the value is a `GraphDef`, is an unsupported summary, or - /// is a tensor of the wrong rank. + /// `DT_FLOAT`. Returns `DataLoss` if the value is a `GraphDef`, a tagged run metadata proto, + /// an unsupported summary, or a tensor of the wrong rank. pub fn into_scalar(self) -> Result { let value_box = match self { EventValue::GraphDef(_) => return Err(DataLoss), + EventValue::TaggedRunMetadata(_) => return Err(DataLoss), EventValue::Summary(SummaryValue(v)) => v, }; match *value_box { @@ -72,16 +76,20 @@ impl EventValue { /// Consumes this event value and enriches it into a blob sequence. /// - /// For now, this supports `GraphDef`s, summaries with `image` or `audio`, or summaries with - /// `tensor` set to a rank-1 tensor of type `DT_STRING`. If the summary metadata indicates that - /// this is audio data, `tensor` may also be a string tensor of shape `[k, 2]`, in which case - /// the second axis is assumed to represent string labels and is dropped entirely. + /// For now, this supports `GraphDef`s, tagged run metadata protos, summaries with `image` or + /// `audio`, or summaries with `tensor` set to a rank-1 tensor of type `DT_STRING`. If the + /// summary metadata indicates that this is audio data, `tensor` may also be a string tensor of + /// shape `[k, 2]`, in which case the second axis is assumed to represent string labels and is + /// dropped entirely. pub fn into_blob_sequence( self, metadata: &pb::SummaryMetadata, ) -> Result { match self { EventValue::GraphDef(GraphDefValue(blob)) => Ok(BlobSequenceValue(vec![blob])), + EventValue::TaggedRunMetadata(TaggedRunMetadataValue(run_metadata)) => { + Ok(BlobSequenceValue(vec![run_metadata])) + } EventValue::Summary(SummaryValue(value_box)) => match *value_box { pb::summary::value::Value::Image(im) => { let w = format!("{}", im.width).into_bytes(); @@ -155,6 +163,12 @@ fn tensor_proto_to_scalar(tp: &pb::TensorProto) -> Option { /// plugin metadata, but these are not materialized. pub struct GraphDefValue(pub Vec); +/// A value from an `Event` whose `tagged_run_metadata` field is set. +/// +/// This contains only the `run_metadata` from the event (not the tag). This itself represents the +/// encoding of a `RunMetadata` proto, but that is deserialized at the plugin level. +pub struct TaggedRunMetadataValue(pub Vec); + /// A value from an `Event` whose `summary` field is set. /// /// This contains a [`summary::value::Value`], which represents the underlying `oneof value` field @@ -183,6 +197,17 @@ impl GraphDefValue { } } +impl TaggedRunMetadataValue { + /// Determines the metadata for a time series whose first event is a + /// [`TaggedRunMetadata`][`EventValue::TaggedRunMetadata`]. + pub fn initial_metadata() -> Box { + blank( + GRAPH_TAGGED_RUN_METADATA_PLUGIN_NAME, + pb::DataClass::BlobSequence, + ) + } +} + impl SummaryValue { /// Determines the metadata for a time series given its first event. /// @@ -237,6 +262,14 @@ impl Debug for GraphDefValue { } } +impl Debug for TaggedRunMetadataValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("TaggedRunMetadataValue") + .field(&format_args!("<{} bytes>", self.0.len())) + .finish() + } +} + /// Creates a summary metadata value with plugin name and data class, but no other contents. fn blank(plugin_name: &str, data_class: pb::DataClass) -> Box { Box::new(pb::SummaryMetadata { @@ -518,6 +551,16 @@ mod tests { assert_eq!(md.data_class, i32::from(pb::DataClass::BlobSequence)); } + #[test] + fn test_metadata_tagged_run_metadata() { + let md = TaggedRunMetadataValue::initial_metadata(); + assert_eq!( + &md.plugin_data.unwrap().plugin_name, + GRAPH_TAGGED_RUN_METADATA_PLUGIN_NAME + ); + assert_eq!(md.data_class, i32::from(pb::DataClass::BlobSequence)); + } + #[test] fn test_metadata_tf1x_image() { let v = SummaryValue(Box::new(Value::Image(pb::summary::Image { @@ -637,6 +680,15 @@ mod tests { ); } + #[test] + fn test_enrich_tagged_run_metadata() { + let v = EventValue::TaggedRunMetadata(TaggedRunMetadataValue(vec![1, 2, 3, 4])); + assert_eq!( + v.into_blob_sequence(GraphDefValue::initial_metadata().as_ref()), + Ok(BlobSequenceValue(vec![vec![1, 2, 3, 4]])) + ); + } + #[test] fn test_enrich_tf1x_image() { let v = SummaryValue(Box::new(Value::Image(pb::summary::Image { diff --git a/tensorboard/data/server/run.rs b/tensorboard/data/server/run.rs index 16f2dd07d8..e28a179f9d 100644 --- a/tensorboard/data/server/run.rs +++ b/tensorboard/data/server/run.rs @@ -23,7 +23,7 @@ use std::path::PathBuf; use std::sync::RwLock; use crate::commit; -use crate::data_compat::{EventValue, GraphDefValue, SummaryValue}; +use crate::data_compat::{EventValue, GraphDefValue, SummaryValue, TaggedRunMetadataValue}; use crate::event_file::EventFileReader; use crate::proto::tensorboard as pb; use crate::reservoir::StageReservoir; @@ -277,6 +277,21 @@ fn read_event( }; ts.rsv.offer(step, sv); } + Some(pb::event::What::TaggedRunMetadata(trm_proto)) => { + let sv = StageValue { + wall_time, + payload: EventValue::GraphDef(GraphDefValue(trm_proto.run_metadata)), + }; + use std::collections::hash_map::Entry; + let ts = match time_series.entry(Tag(trm_proto.tag)) { + Entry::Occupied(o) => o.into_mut(), + Entry::Vacant(v) => { + let metadata = TaggedRunMetadataValue::initial_metadata(); + v.insert(StageTimeSeries::new(metadata)) + } + }; + ts.rsv.offer(step, sv); + } Some(pb::event::What::Summary(sum)) => { for mut summary_pb_value in sum.value { let summary_value = match summary_pb_value.value { @@ -346,6 +361,12 @@ mod test { WallTime::new(1235.0).unwrap(), b"".to_vec(), )?; + f1.write_tagged_run_metadata( + &Tag("step0000".to_string()), + Step(0), + WallTime::new(1235.0).unwrap(), + b"".to_vec(), + )?; f1.write_scalar(&tag, Step(0), WallTime::new(1235.0).unwrap(), 0.25)?; f1.write_scalar(&tag, Step(1), WallTime::new(1236.0).unwrap(), 0.50)?; f1.write_scalar(&tag, Step(2), WallTime::new(1237.0).unwrap(), 0.75)?; @@ -404,11 +425,9 @@ mod test { ] ); + assert_eq!(run_data.blob_sequences.len(), 2); + let run_graph_tag = Tag(GraphDefValue::TAG_NAME.to_string()); - assert_eq!( - run_data.blob_sequences.keys().collect::>(), - vec![&run_graph_tag] - ); let graph_ts = run_data.blob_sequences.get(&run_graph_tag).unwrap(); assert_eq!( *graph_ts.metadata, @@ -430,6 +449,29 @@ mod test { )] ); + let run_metadata_tag = Tag("step0000".to_string()); + let run_metadata_ts = run_data.blob_sequences.get(&run_metadata_tag).unwrap(); + assert_eq!( + *run_metadata_ts.metadata, + pb::SummaryMetadata { + plugin_data: Some(pb::summary_metadata::PluginData { + plugin_name: crate::data_compat::GRAPH_TAGGED_RUN_METADATA_PLUGIN_NAME + .to_string(), + ..Default::default() + }), + data_class: pb::DataClass::BlobSequence.into(), + ..Default::default() + } + ); + assert_eq!( + run_metadata_ts.valid_values().collect::>(), + vec![( + Step(0), + WallTime::new(1235.0).unwrap(), + &commit::BlobSequenceValue(vec![b"".to_vec()]) + )] + ); + Ok(()) } } diff --git a/tensorboard/data/server/writer.rs b/tensorboard/data/server/writer.rs index 3f2c5a3e89..019b659a9a 100644 --- a/tensorboard/data/server/writer.rs +++ b/tensorboard/data/server/writer.rs @@ -65,6 +65,27 @@ pub trait SummaryWriteExt: Write { }; self.write_event(&event) } + + /// Writes a TFRecord containing a TF 1.x `tagged_run_metadata` event. + fn write_tagged_run_metadata( + &mut self, + tag: &Tag, + step: Step, + wt: WallTime, + run_metadata: Vec, + ) -> std::io::Result<()> { + let event = pb::Event { + step: step.0, + wall_time: wt.into(), + what: Some(pb::event::What::TaggedRunMetadata(pb::TaggedRunMetadata { + tag: tag.0.clone(), + run_metadata, + ..Default::default() + })), + ..Default::default() + }; + self.write_event(&event) + } } impl SummaryWriteExt for W {} @@ -158,4 +179,32 @@ mod tests { }; assert_eq!(event, &expected); } + + #[test] + fn test_tagged_run_metadata_roundtrip() { + let mut cursor = Cursor::new(Vec::::new()); + cursor + .write_tagged_run_metadata( + &Tag("step0000".to_string()), + Step(777), + WallTime::new(1234.5).unwrap(), + b"my run metadata".to_vec(), + ) + .unwrap(); + cursor.set_position(0); + let events = read_all_events(cursor).unwrap(); + assert_eq!(events.len(), 1); + + let event = &events[0]; + let expected = pb::Event { + step: 777, + wall_time: 1234.5, + what: Some(pb::event::What::TaggedRunMetadata(pb::TaggedRunMetadata { + tag: "step0000".to_string(), + run_metadata: b"my run metadata".to_vec(), + })), + ..Default::default() + }; + assert_eq!(event, &expected); + } }