diff --git a/tensorboard/data/server/commit.rs b/tensorboard/data/server/commit.rs index 24587c68d9..aa088dd069 100644 --- a/tensorboard/data/server/commit.rs +++ b/tensorboard/data/server/commit.rs @@ -211,6 +211,45 @@ pub mod test_data { self } + /// Adds a blob sequence time series, creating the run if it doesn't exist, and setting its + /// start time if unset. + /// + /// # Examples + /// + /// ``` + /// use rustboard_core::commit::{test_data::CommitBuilder, BlobSequenceValue, Commit}; + /// + /// let my_commit: Commit = CommitBuilder::new() + /// .blob_sequences("train", "input_image", |mut b| { + /// b.plugin_name("images") + /// .values(vec![ + /// BlobSequenceValue(vec![b"step0img0".to_vec()]), + /// BlobSequenceValue(vec![b"step1img0".to_vec(), b"step1img1".to_vec()]), + /// ]) + /// .build() + /// }) + /// .build(); + /// ``` + pub fn blob_sequences( + self, + run: &str, + tag: &str, + build: impl FnOnce(BlobSequenceTimeSeriesBuilder) -> TimeSeries, + ) -> Self { + self.with_run_data(Run(run.to_string()), |run_data| { + let time_series = build(BlobSequenceTimeSeriesBuilder::default()); + if let (None, Some((_step, wall_time, _value))) = + (run_data.start_time, time_series.valid_values().next()) + { + run_data.start_time = Some(wall_time); + } + run_data + .blob_sequences + .insert(Tag(tag.to_string()), time_series); + }); + self + } + /// Ensures that a run is present and sets its start time. /// /// If you don't care about the start time and the run is going to have data, anyway, you @@ -307,4 +346,88 @@ pub mod test_data { time_series } } + + pub struct BlobSequenceTimeSeriesBuilder { + /// Initial step. Increments by `1` for each point. + step_start: Step, + /// Initial wall time. Increments by `1.0` for each point. + wall_time_start: WallTime, + /// Raw data for blob sequences in this time series. Defaults to + /// `vec![BlobSequenceValue(vec![])]`: i.e., one blob sequence, with one blob, which is + /// empty. + values: Vec, + /// Custom summary metadata. Leave `None` to use default. + metadata: Option>, + } + + impl Default for BlobSequenceTimeSeriesBuilder { + fn default() -> Self { + BlobSequenceTimeSeriesBuilder { + step_start: Step(0), + wall_time_start: WallTime::new(0.0).unwrap(), + values: vec![BlobSequenceValue(vec![])], + metadata: None, + } + } + } + + /// Creates a summary metadata value with plugin name and data class, but no other contents. + fn blank(plugin_name: &str, data_class: pb::DataClass) -> Box { + Box::new(pb::SummaryMetadata { + plugin_data: Some(pb::summary_metadata::PluginData { + plugin_name: plugin_name.to_string(), + ..Default::default() + }), + data_class: data_class.into(), + ..Default::default() + }) + } + + impl BlobSequenceTimeSeriesBuilder { + pub fn step_start(&mut self, raw_step: i64) -> &mut Self { + self.step_start = Step(raw_step); + self + } + pub fn wall_time_start(&mut self, raw_wall_time: f64) -> &mut Self { + self.wall_time_start = WallTime::new(raw_wall_time).unwrap(); + self + } + pub fn values(&mut self, values: Vec) -> &mut Self { + self.values = values; + self + } + pub fn metadata(&mut self, metadata: Option>) -> &mut Self { + self.metadata = metadata; + self + } + /// Sets the metadata to a blank, blob-sequence-class metadata value with the given plugin + /// name. Overwrites any existing call to [`metadata`][Self::metadata]. + pub fn plugin_name(&mut self, plugin_name: &str) -> &mut Self { + self.metadata(Some(blank(plugin_name, pb::DataClass::BlobSequence))) + } + + /// Constructs a scalar time series from the state of this builder. + /// + /// # Panics + /// + /// If the wall time of a point would overflow to be infinite. + pub fn build(&self) -> TimeSeries { + let metadata = self + .metadata + .clone() + .unwrap_or_else(|| blank("blobs", pb::DataClass::BlobSequence)); + let mut time_series = TimeSeries::new(metadata); + + let mut rsv = StageReservoir::new(self.values.len()); + for (i, value) in self.values.iter().enumerate() { + let step = Step(self.step_start.0 + i as i64); + let wall_time = + WallTime::new(f64::from(self.wall_time_start) + (i as f64)).unwrap(); + rsv.offer(step, (wall_time, Ok(value.clone()))); + } + rsv.commit(&mut time_series.basin); + + time_series + } + } } diff --git a/tensorboard/data/server/server.rs b/tensorboard/data/server/server.rs index a0889edff7..30d72d8e03 100644 --- a/tensorboard/data/server/server.rs +++ b/tensorboard/data/server/server.rs @@ -25,7 +25,6 @@ use tonic::{Request, Response, Status}; use crate::commit::{self, Commit}; use crate::downsample; -use crate::proto::tensorboard as pb; use crate::proto::tensorboard::data; use crate::types::{Run, Tag, WallTime}; use data::tensor_board_data_provider_server::TensorBoardDataProvider; @@ -59,8 +58,9 @@ impl TensorBoardDataProvider for DataProviderHandler { let data = data .read() .map_err(|_| Status::internal(format!("failed to read run data for {:?}", run)))?; - for time_series in data.scalars.values() { - let metadata: &pb::SummaryMetadata = time_series.metadata.as_ref(); + for metadata in (data.scalars.values().map(|ts| ts.metadata.as_ref())) + .chain(data.blob_sequences.values().map(|ts| ts.metadata.as_ref())) + { let plugin_name = match &metadata.plugin_data { Some(d) => d.plugin_name.clone(), None => String::new(), @@ -348,6 +348,7 @@ mod tests { use tonic::Code; use crate::commit::test_data::CommitBuilder; + use crate::proto::tensorboard as pb; use crate::types::{Run, Step, Tag}; fn sample_handler(commit: Commit) -> DataProviderHandler { @@ -361,6 +362,9 @@ mod tests { async fn test_list_plugins() { let commit = CommitBuilder::new() .scalars("train", "xent", |b| b.build()) + .blob_sequences("train", "input_image", |mut b| { + b.plugin_name("images").build() + }) .build(); let handler = sample_handler(commit); let req = Request::new(data::ListPluginsRequest { @@ -368,8 +372,13 @@ mod tests { }); let res = handler.list_plugins(req).await.unwrap().into_inner(); assert_eq!( - res.plugins.into_iter().map(|p| p.name).collect::>(), - vec!["scalars"] + res.plugins + .iter() + .map(|p| p.name.as_str()) + .collect::>(), + vec!["scalars", "images"] + .into_iter() + .collect::>(), ); }