diff --git a/tensorboard/data/server/commit.rs b/tensorboard/data/server/commit.rs index fb8574fed9..24587c68d9 100644 --- a/tensorboard/data/server/commit.rs +++ b/tensorboard/data/server/commit.rs @@ -56,6 +56,9 @@ pub struct RunData { /// Scalar time series for this run. pub scalars: TagStore, + + /// Blob sequence time series for this run. + pub blob_sequences: TagStore, } pub type TagStore = HashMap>; @@ -105,6 +108,12 @@ pub struct DataLoss; #[derive(Debug, Copy, Clone, PartialEq)] pub struct ScalarValue(pub f32); +/// The value of a blob sequence time series at a single point. +/// +/// This value is a sequence of zero or more blobs, stored in memory. +#[derive(Debug, Clone, PartialEq)] +pub struct BlobSequenceValue(pub Vec>); + #[cfg(test)] mod tests { use super::*; diff --git a/tensorboard/data/server/data_compat.rs b/tensorboard/data/server/data_compat.rs index d79d7edbb1..0c951c39ce 100644 --- a/tensorboard/data/server/data_compat.rs +++ b/tensorboard/data/server/data_compat.rs @@ -18,7 +18,7 @@ limitations under the License. use std::convert::TryInto; use std::fmt::Debug; -use crate::commit::{DataLoss, ScalarValue}; +use crate::commit::{BlobSequenceValue, DataLoss, ScalarValue}; use crate::proto::tensorboard as pb; use pb::summary_metadata::PluginData; @@ -67,6 +67,16 @@ impl EventValue { _ => Err(DataLoss), } } + + /// Consumes this event value and enriches it into a blob sequence. + /// + /// For now, this succeeds only for graphs. + pub fn into_blob_sequence(self) -> Result { + match self { + EventValue::Summary(_) => Err(DataLoss), + EventValue::GraphDef(GraphDefValue(blob)) => Ok(BlobSequenceValue(vec![blob])), + } + } } fn tensor_proto_to_scalar(tp: &pb::TensorProto) -> Option { @@ -118,6 +128,11 @@ pub struct GraphDefValue(pub Vec); pub struct SummaryValue(pub Box); impl GraphDefValue { + /// Tag name used for run-level graphs. + /// + /// This must match `tensorboard.plugins.graph.metadata.RUN_GRAPH_NAME`. + pub const TAG_NAME: &'static str = "__run_graph__"; + /// Determines the metadata for a time series whose first event is a /// [`GraphDef`][`EventValue::GraphDef`]. pub fn initial_metadata() -> Box { @@ -450,11 +465,20 @@ mod tests { use super::*; #[test] - fn test() { + fn test_metadata() { let md = GraphDefValue::initial_metadata(); assert_eq!(&md.plugin_data.unwrap().plugin_name, GRAPHS_PLUGIN_NAME); assert_eq!(md.data_class, i32::from(pb::DataClass::BlobSequence)); } + + #[test] + fn test_enrich_graph_def() { + let v = EventValue::GraphDef(GraphDefValue(vec![1, 2, 3, 4])); + assert_eq!( + v.into_blob_sequence(), + Ok(BlobSequenceValue(vec![vec![1, 2, 3, 4]])) + ); + } } mod unknown { diff --git a/tensorboard/data/server/run.rs b/tensorboard/data/server/run.rs index 770c96fec8..9491c95950 100644 --- a/tensorboard/data/server/run.rs +++ b/tensorboard/data/server/run.rs @@ -23,7 +23,7 @@ use std::path::PathBuf; use std::sync::RwLock; use crate::commit; -use crate::data_compat::{EventValue, SummaryValue}; +use crate::data_compat::{EventValue, GraphDefValue, SummaryValue}; use crate::event_file::EventFileReader; use crate::proto::tensorboard as pb; use crate::reservoir::StageReservoir; @@ -117,15 +117,7 @@ impl StageTimeSeries { ); } DataClass::BlobSequence => { - warn!( - "Blob sequence time series not yet supported (tag: {}, plugin: {})", - tag.0, - self.metadata - .plugin_data - .as_ref() - .map(|p| p.plugin_name.as_str()) - .unwrap_or("") - ); + self.commit_to(tag, &mut run.blob_sequences, EventValue::into_blob_sequence) } _ => (), }; @@ -270,9 +262,19 @@ fn read_event( *start_time = Some(wall_time); } match e.what { - Some(pb::event::What::GraphDef(_)) => { - // TODO(@wchargin): Handle run graphs. - warn!("`graph_def` events not yet handled"); + Some(pb::event::What::GraphDef(graph_bytes)) => { + let sv = StageValue { + wall_time, + payload: EventValue::GraphDef(GraphDefValue(graph_bytes)), + }; + use std::collections::hash_map::Entry; + let ts = match time_series.entry(Tag(GraphDefValue::TAG_NAME.to_string())) { + Entry::Occupied(o) => o.into_mut(), + Entry::Vacant(v) => { + v.insert(StageTimeSeries::new(GraphDefValue::initial_metadata())) + } + }; + ts.rsv.offer(step, sv); } Some(pb::event::What::Summary(sum)) => { for mut summary_pb_value in sum.value { @@ -338,6 +340,11 @@ mod test { // Write some data points across both files. let run = Run("train".to_string()); let tag = Tag("accuracy".to_string()); + f1.write_graph( + Step(0), + WallTime::new(1235.0).unwrap(), + b"".to_vec(), + )?; f1.write_scalar(&tag, Step(0), WallTime::new(1235.0).unwrap(), 0.25)?; f1.write_scalar(&tag, Step(1), WallTime::new(1236.0).unwrap(), 0.50)?; f1.write_scalar(&tag, Step(2), WallTime::new(1237.0).unwrap(), 0.75)?; @@ -371,9 +378,9 @@ mod test { .expect("read-locking run data map"); assert_eq!(run_data.scalars.keys().collect::>(), vec![&tag]); - let ts = run_data.scalars.get(&tag).unwrap(); + let scalar_ts = run_data.scalars.get(&tag).unwrap(); assert_eq!( - *ts.metadata, + *scalar_ts.metadata, pb::SummaryMetadata { plugin_data: Some(pb::summary_metadata::PluginData { plugin_name: crate::data_compat::SCALARS_PLUGIN_NAME.to_string(), @@ -383,11 +390,10 @@ mod test { ..Default::default() } ); - // Points should be as expected (no downsampling at these sizes). let scalar = commit::ScalarValue; assert_eq!( - ts.valid_values().collect::>(), + scalar_ts.valid_values().collect::>(), vec![ (Step(0), WallTime::new(1235.0).unwrap(), &scalar(0.25)), (Step(1), WallTime::new(1236.0).unwrap(), &scalar(0.50)), @@ -397,6 +403,32 @@ mod test { ] ); + let run_graph_tag = Tag(GraphDefValue::TAG_NAME.to_string()); + assert_eq!( + run_data.blob_sequences.keys().collect::>(), + vec![&run_graph_tag] + ); + let graph_ts = run_data.blob_sequences.get(&run_graph_tag).unwrap(); + assert_eq!( + *graph_ts.metadata, + pb::SummaryMetadata { + plugin_data: Some(pb::summary_metadata::PluginData { + plugin_name: crate::data_compat::GRAPHS_PLUGIN_NAME.to_string(), + ..Default::default() + }), + data_class: pb::DataClass::BlobSequence.into(), + ..Default::default() + } + ); + assert_eq!( + graph_ts.valid_values().collect::>(), + vec![( + Step(0), + WallTime::new(1235.0).unwrap(), + &commit::BlobSequenceValue(vec![b"".to_vec()]) + )] + ); + Ok(()) } } diff --git a/tensorboard/data/server/writer.rs b/tensorboard/data/server/writer.rs index e0cb303b1d..3f2c5a3e89 100644 --- a/tensorboard/data/server/writer.rs +++ b/tensorboard/data/server/writer.rs @@ -54,6 +54,17 @@ pub trait SummaryWriteExt: Write { }; self.write_event(&event) } + + /// Writes a TFRecord containing a TF 1.x `graph_def` event. + fn write_graph(&mut self, step: Step, wt: WallTime, bytes: Vec) -> std::io::Result<()> { + let event = pb::Event { + step: step.0, + wall_time: wt.into(), + what: Some(pb::event::What::GraphDef(bytes)), + ..Default::default() + }; + self.write_event(&event) + } } impl SummaryWriteExt for W {} @@ -123,4 +134,28 @@ mod tests { }; assert_eq!(event, &expected); } + + #[test] + fn test_graph_roundtrip() { + let mut cursor = Cursor::new(Vec::::new()); + cursor + .write_graph( + Step(777), + WallTime::new(1234.5).unwrap(), + b"my graph".to_vec(), + ) + .unwrap(); + cursor.set_position(0); + let events = read_all_events(cursor).unwrap(); + assert_eq!(events.len(), 1); + + let event = &events[0]; + let expected = pb::Event { + step: 777, + wall_time: 1234.5, + what: Some(pb::event::What::GraphDef(b"my graph".to_vec())), + ..Default::default() + }; + assert_eq!(event, &expected); + } }