Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions tensorboard/data/server/commit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,45 @@ pub mod test_data {
self
}

/// Adds a blob sequence time series, creating the run if it doesn't exist, and setting its
/// start time if unset.
///
/// # Examples
///
/// ```
/// use rustboard_core::commit::{test_data::CommitBuilder, BlobSequenceValue, Commit};
///
/// let my_commit: Commit = CommitBuilder::new()
/// .blob_sequences("train", "input_image", |mut b| {
/// b.plugin_name("images")
/// .values(vec![
/// BlobSequenceValue(vec![b"step0img0".to_vec()]),
/// BlobSequenceValue(vec![b"step1img0".to_vec(), b"step1img1".to_vec()]),
/// ])
/// .build()
/// })
/// .build();
/// ```
pub fn blob_sequences(
self,
run: &str,
tag: &str,
build: impl FnOnce(BlobSequenceTimeSeriesBuilder) -> TimeSeries<BlobSequenceValue>,
) -> Self {
self.with_run_data(Run(run.to_string()), |run_data| {
let time_series = build(BlobSequenceTimeSeriesBuilder::default());
if let (None, Some((_step, wall_time, _value))) =
(run_data.start_time, time_series.valid_values().next())
{
run_data.start_time = Some(wall_time);
}
run_data
.blob_sequences
.insert(Tag(tag.to_string()), time_series);
});
self
}

/// Ensures that a run is present and sets its start time.
///
/// If you don't care about the start time and the run is going to have data, anyway, you
Expand Down Expand Up @@ -307,4 +346,88 @@ pub mod test_data {
time_series
}
}

    /// Builder for a blob sequence [`TimeSeries`], as consumed by
    /// [`CommitBuilder::blob_sequences`]. Points are laid out consecutively from
    /// `step_start`/`wall_time_start`, one per element of `values`.
    pub struct BlobSequenceTimeSeriesBuilder {
        /// Initial step. Increments by `1` for each point.
        step_start: Step,
        /// Initial wall time. Increments by `1.0` for each point.
        wall_time_start: WallTime,
        /// Raw data for blob sequences in this time series. Defaults to
        /// `vec![BlobSequenceValue(vec![])]`: i.e., one blob sequence, with one blob, which is
        /// empty.
        values: Vec<BlobSequenceValue>,
        /// Custom summary metadata. Leave `None` to use default.
        metadata: Option<Box<pb::SummaryMetadata>>,
    }

impl Default for BlobSequenceTimeSeriesBuilder {
fn default() -> Self {
BlobSequenceTimeSeriesBuilder {
step_start: Step(0),
wall_time_start: WallTime::new(0.0).unwrap(),
values: vec![BlobSequenceValue(vec![])],
metadata: None,
}
}
}

/// Creates a summary metadata value with plugin name and data class, but no other contents.
fn blank(plugin_name: &str, data_class: pb::DataClass) -> Box<pb::SummaryMetadata> {
Box::new(pb::SummaryMetadata {
plugin_data: Some(pb::summary_metadata::PluginData {
plugin_name: plugin_name.to_string(),
..Default::default()
}),
data_class: data_class.into(),
..Default::default()
})
}

impl BlobSequenceTimeSeriesBuilder {
pub fn step_start(&mut self, raw_step: i64) -> &mut Self {
self.step_start = Step(raw_step);
self
}
pub fn wall_time_start(&mut self, raw_wall_time: f64) -> &mut Self {
self.wall_time_start = WallTime::new(raw_wall_time).unwrap();
self
}
pub fn values(&mut self, values: Vec<BlobSequenceValue>) -> &mut Self {
self.values = values;
self
}
pub fn metadata(&mut self, metadata: Option<Box<pb::SummaryMetadata>>) -> &mut Self {
self.metadata = metadata;
self
}
/// Sets the metadata to a blank, blob-sequence-class metadata value with the given plugin
/// name. Overwrites any existing call to [`metadata`][Self::metadata].
pub fn plugin_name(&mut self, plugin_name: &str) -> &mut Self {
self.metadata(Some(blank(plugin_name, pb::DataClass::BlobSequence)))
}

/// Constructs a scalar time series from the state of this builder.
///
/// # Panics
///
/// If the wall time of a point would overflow to be infinite.
pub fn build(&self) -> TimeSeries<BlobSequenceValue> {
let metadata = self
.metadata
.clone()
.unwrap_or_else(|| blank("blobs", pb::DataClass::BlobSequence));
let mut time_series = TimeSeries::new(metadata);

let mut rsv = StageReservoir::new(self.values.len());
for (i, value) in self.values.iter().enumerate() {
let step = Step(self.step_start.0 + i as i64);
let wall_time =
WallTime::new(f64::from(self.wall_time_start) + (i as f64)).unwrap();
rsv.offer(step, (wall_time, Ok(value.clone())));
}
rsv.commit(&mut time_series.basin);

time_series
}
}
}
19 changes: 14 additions & 5 deletions tensorboard/data/server/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use tonic::{Request, Response, Status};

use crate::commit::{self, Commit};
use crate::downsample;
use crate::proto::tensorboard as pb;
use crate::proto::tensorboard::data;
use crate::types::{Run, Tag, WallTime};
use data::tensor_board_data_provider_server::TensorBoardDataProvider;
Expand Down Expand Up @@ -59,8 +58,9 @@ impl TensorBoardDataProvider for DataProviderHandler {
let data = data
.read()
.map_err(|_| Status::internal(format!("failed to read run data for {:?}", run)))?;
for time_series in data.scalars.values() {
let metadata: &pb::SummaryMetadata = time_series.metadata.as_ref();
for metadata in (data.scalars.values().map(|ts| ts.metadata.as_ref()))
.chain(data.blob_sequences.values().map(|ts| ts.metadata.as_ref()))
{
let plugin_name = match &metadata.plugin_data {
Some(d) => d.plugin_name.clone(),
None => String::new(),
Expand Down Expand Up @@ -348,6 +348,7 @@ mod tests {
use tonic::Code;

use crate::commit::test_data::CommitBuilder;
use crate::proto::tensorboard as pb;
use crate::types::{Run, Step, Tag};

fn sample_handler(commit: Commit) -> DataProviderHandler {
Expand All @@ -361,15 +362,23 @@ mod tests {
async fn test_list_plugins() {
let commit = CommitBuilder::new()
.scalars("train", "xent", |b| b.build())
.blob_sequences("train", "input_image", |mut b| {
b.plugin_name("images").build()
})
.build();
let handler = sample_handler(commit);
let req = Request::new(data::ListPluginsRequest {
experiment_id: "123".to_string(),
});
let res = handler.list_plugins(req).await.unwrap().into_inner();
assert_eq!(
res.plugins.into_iter().map(|p| p.name).collect::<Vec<_>>(),
vec!["scalars"]
res.plugins
.iter()
.map(|p| p.name.as_str())
.collect::<HashSet<&str>>(),
vec!["scalars", "images"]
.into_iter()
.collect::<HashSet<&str>>(),
);
}

Expand Down