9 changes: 9 additions & 0 deletions tensorboard/data/server/commit.rs
@@ -56,6 +56,9 @@ pub struct RunData {

/// Scalar time series for this run.
pub scalars: TagStore<ScalarValue>,

/// Blob sequence time series for this run.
pub blob_sequences: TagStore<BlobSequenceValue>,
}

pub type TagStore<V> = HashMap<Tag, TimeSeries<V>>;
@@ -105,6 +108,12 @@ pub struct DataLoss;
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct ScalarValue(pub f32);

/// The value of a blob sequence time series at a single point.
///
/// This value is a sequence of zero or more blobs, stored in memory.
#[derive(Debug, Clone, PartialEq)]
pub struct BlobSequenceValue(pub Vec<Vec<u8>>);

#[cfg(test)]
mod tests {
use super::*;
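A note on the new commit type: a `BlobSequenceValue` is deliberately minimal, just an owned list of byte buffers, and `blob_sequences` reuses the same `TagStore` map shape as `scalars`. The following is a small sketch (not part of the diff) of constructing and comparing such values, assuming only the tuple struct shown above:

// Sketch only: depends on nothing beyond the `BlobSequenceValue(pub Vec<Vec<u8>>)`
// tuple struct introduced in this file.
use crate::commit::BlobSequenceValue;

fn blob_sequence_value_demo() {
    // A point holding two blobs, each an arbitrary byte buffer.
    let point = BlobSequenceValue(vec![b"first blob".to_vec(), b"second blob".to_vec()]);
    // "Zero or more blobs" means the empty sequence is also a valid point.
    let empty = BlobSequenceValue(vec![]);
    assert_eq!(point.0.len(), 2);
    // `PartialEq` is derived, so whole-sequence comparison just works.
    assert_ne!(point, empty);
}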
28 changes: 26 additions & 2 deletions tensorboard/data/server/data_compat.rs
@@ -18,7 +18,7 @@ limitations under the License.
use std::convert::TryInto;
use std::fmt::Debug;

use crate::commit::{DataLoss, ScalarValue};
use crate::commit::{BlobSequenceValue, DataLoss, ScalarValue};
use crate::proto::tensorboard as pb;
use pb::summary_metadata::PluginData;

@@ -67,6 +67,16 @@ impl EventValue {
_ => Err(DataLoss),
}
}

/// Consumes this event value and enriches it into a blob sequence.
///
/// For now, this succeeds only for graphs.
pub fn into_blob_sequence(self) -> Result<BlobSequenceValue, DataLoss> {
match self {
EventValue::Summary(_) => Err(DataLoss),
EventValue::GraphDef(GraphDefValue(blob)) => Ok(BlobSequenceValue(vec![blob])),
}
}
}

fn tensor_proto_to_scalar(tp: &pb::TensorProto) -> Option<f32> {
@@ -118,6 +128,11 @@ pub struct GraphDefValue(pub Vec<u8>);
pub struct SummaryValue(pub Box<pb::summary::value::Value>);

impl GraphDefValue {
/// Tag name used for run-level graphs.
///
/// This must match `tensorboard.plugins.graph.metadata.RUN_GRAPH_NAME`.
pub const TAG_NAME: &'static str = "__run_graph__";

/// Determines the metadata for a time series whose first event is a
/// [`GraphDef`][`EventValue::GraphDef`].
pub fn initial_metadata() -> Box<pb::SummaryMetadata> {
@@ -450,11 +465,20 @@ mod tests {
use super::*;

#[test]
fn test() {
fn test_metadata() {
let md = GraphDefValue::initial_metadata();
assert_eq!(&md.plugin_data.unwrap().plugin_name, GRAPHS_PLUGIN_NAME);
assert_eq!(md.data_class, i32::from(pb::DataClass::BlobSequence));
}

#[test]
fn test_enrich_graph_def() {
let v = EventValue::GraphDef(GraphDefValue(vec![1, 2, 3, 4]));
assert_eq!(
v.into_blob_sequence(),
Ok(BlobSequenceValue(vec![vec![1, 2, 3, 4]]))
);
}
}

mod unknown {
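The new `test_enrich_graph_def` exercises only the `Ok` arm; the `Err(DataLoss)` arm for summary-backed events is worth keeping in mind when reading `run.rs`. A hedged sketch of that case (the `SimpleValue` payload is an arbitrary stand-in for any summary value):

// Sketch only: a summary-backed event cannot be enriched into a blob sequence
// yet, so it reports data loss instead of producing a value.
use crate::commit::DataLoss;
use crate::data_compat::{EventValue, SummaryValue};
use crate::proto::tensorboard as pb;

fn summary_is_not_a_blob_sequence() {
    let v = EventValue::Summary(SummaryValue(Box::new(
        pb::summary::value::Value::SimpleValue(0.125),
    )));
    assert_eq!(v.into_blob_sequence(), Err(DataLoss));
}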
66 changes: 49 additions & 17 deletions tensorboard/data/server/run.rs
@@ -23,7 +23,7 @@ use std::path::PathBuf;
use std::sync::RwLock;

use crate::commit;
use crate::data_compat::{EventValue, SummaryValue};
use crate::data_compat::{EventValue, GraphDefValue, SummaryValue};
use crate::event_file::EventFileReader;
use crate::proto::tensorboard as pb;
use crate::reservoir::StageReservoir;
@@ -117,15 +117,7 @@ impl StageTimeSeries {
);
}
DataClass::BlobSequence => {
warn!(
"Blob sequence time series not yet supported (tag: {}, plugin: {})",
tag.0,
self.metadata
.plugin_data
.as_ref()
.map(|p| p.plugin_name.as_str())
.unwrap_or("")
);
self.commit_to(tag, &mut run.blob_sequences, EventValue::into_blob_sequence)
}
_ => (),
};
@@ -270,9 +262,19 @@ fn read_event(
*start_time = Some(wall_time);
}
match e.what {
Some(pb::event::What::GraphDef(_)) => {
// TODO(@wchargin): Handle run graphs.
warn!("`graph_def` events not yet handled");
Some(pb::event::What::GraphDef(graph_bytes)) => {
let sv = StageValue {
wall_time,
payload: EventValue::GraphDef(GraphDefValue(graph_bytes)),
};
use std::collections::hash_map::Entry;
let ts = match time_series.entry(Tag(GraphDefValue::TAG_NAME.to_string())) {
Entry::Occupied(o) => o.into_mut(),
Entry::Vacant(v) => {
v.insert(StageTimeSeries::new(GraphDefValue::initial_metadata()))
}
};
ts.rsv.offer(step, sv);
}
Some(pb::event::What::Summary(sum)) => {
for mut summary_pb_value in sum.value {
@@ -338,6 +340,11 @@ mod test {
// Write some data points across both files.
let run = Run("train".to_string());
let tag = Tag("accuracy".to_string());
f1.write_graph(
Step(0),
WallTime::new(1235.0).unwrap(),
b"<sample model graph>".to_vec(),
)?;
f1.write_scalar(&tag, Step(0), WallTime::new(1235.0).unwrap(), 0.25)?;
f1.write_scalar(&tag, Step(1), WallTime::new(1236.0).unwrap(), 0.50)?;
f1.write_scalar(&tag, Step(2), WallTime::new(1237.0).unwrap(), 0.75)?;
@@ -371,9 +378,9 @@ mod test {
.expect("read-locking run data map");

assert_eq!(run_data.scalars.keys().collect::<Vec<_>>(), vec![&tag]);
let ts = run_data.scalars.get(&tag).unwrap();
let scalar_ts = run_data.scalars.get(&tag).unwrap();
assert_eq!(
*ts.metadata,
*scalar_ts.metadata,
pb::SummaryMetadata {
plugin_data: Some(pb::summary_metadata::PluginData {
plugin_name: crate::data_compat::SCALARS_PLUGIN_NAME.to_string(),
@@ -383,11 +390,10 @@
..Default::default()
}
);

// Points should be as expected (no downsampling at these sizes).
let scalar = commit::ScalarValue;
assert_eq!(
ts.valid_values().collect::<Vec<_>>(),
scalar_ts.valid_values().collect::<Vec<_>>(),
vec![
(Step(0), WallTime::new(1235.0).unwrap(), &scalar(0.25)),
(Step(1), WallTime::new(1236.0).unwrap(), &scalar(0.50)),
@@ -397,6 +403,32 @@
]
);

let run_graph_tag = Tag(GraphDefValue::TAG_NAME.to_string());
assert_eq!(
run_data.blob_sequences.keys().collect::<Vec<_>>(),
vec![&run_graph_tag]
);
let graph_ts = run_data.blob_sequences.get(&run_graph_tag).unwrap();
assert_eq!(
*graph_ts.metadata,
pb::SummaryMetadata {
plugin_data: Some(pb::summary_metadata::PluginData {
plugin_name: crate::data_compat::GRAPHS_PLUGIN_NAME.to_string(),
..Default::default()
}),
data_class: pb::DataClass::BlobSequence.into(),
..Default::default()
}
);
assert_eq!(
graph_ts.valid_values().collect::<Vec<_>>(),
vec![(
Step(0),
WallTime::new(1235.0).unwrap(),
&commit::BlobSequenceValue(vec![b"<sample model graph>".to_vec()])
)]
);

Ok(())
}
}
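The key move in `read_event` above is that run-level graphs arrive without a tag, so they are keyed under the fixed `GraphDefValue::TAG_NAME` and their staged time series is created lazily with the `HashMap` entry API. A simplified illustration of that get-or-create step using plain standard-library types (the crate's `StageTimeSeries` and reservoir are replaced by a `Vec` here):

use std::collections::hash_map::{Entry, HashMap};

// Simplified stand-in for the staged state: tag name -> buffered graph blobs.
fn offer_run_graph(time_series: &mut HashMap<String, Vec<Vec<u8>>>, graph_bytes: Vec<u8>) {
    // Reuse the run-graph series if it exists, or create it on first use,
    // mirroring the `Entry::Occupied`/`Entry::Vacant` match in `read_event`.
    let ts = match time_series.entry("__run_graph__".to_string()) {
        Entry::Occupied(o) => o.into_mut(),
        Entry::Vacant(v) => v.insert(Vec::new()),
    };
    ts.push(graph_bytes);
}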
35 changes: 35 additions & 0 deletions tensorboard/data/server/writer.rs
@@ -54,6 +54,17 @@ pub trait SummaryWriteExt: Write {
};
self.write_event(&event)
}

/// Writes a TFRecord containing a TF 1.x `graph_def` event.
fn write_graph(&mut self, step: Step, wt: WallTime, bytes: Vec<u8>) -> std::io::Result<()> {
let event = pb::Event {
step: step.0,
wall_time: wt.into(),
what: Some(pb::event::What::GraphDef(bytes)),
..Default::default()
};
self.write_event(&event)
}
}

impl<W: Write> SummaryWriteExt for W {}
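Since the blanket impl on the line above applies to every `Write`, `write_graph` works equally against the in-memory cursor used in the tests below and against a real event file. A hypothetical on-disk usage sketch (the file name, payload, and the `crate::types`/`crate::writer` import paths are assumptions for illustration):

use std::fs::File;
use std::io::{BufWriter, Write};

// Assumed module paths; adjust to wherever Step/WallTime and this trait live.
use crate::types::{Step, WallTime};
use crate::writer::SummaryWriteExt;

// Sketch only: append a single TF 1.x `graph_def` event record to a new file.
fn write_graph_to_disk() -> std::io::Result<()> {
    let file = File::create("events.out.tfevents.example")?;
    let mut writer = BufWriter::new(file);
    writer.write_graph(
        Step(0),
        WallTime::new(1234.5).unwrap(),
        b"<serialized GraphDef>".to_vec(),
    )?;
    writer.flush()?;
    Ok(())
}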
@@ -123,4 +134,28 @@ mod tests {
};
assert_eq!(event, &expected);
}

#[test]
fn test_graph_roundtrip() {
let mut cursor = Cursor::new(Vec::<u8>::new());
cursor
.write_graph(
Step(777),
WallTime::new(1234.5).unwrap(),
b"my graph".to_vec(),
)
.unwrap();
cursor.set_position(0);
let events = read_all_events(cursor).unwrap();
assert_eq!(events.len(), 1);

let event = &events[0];
let expected = pb::Event {
step: 777,
wall_time: 1234.5,
what: Some(pb::event::What::GraphDef(b"my graph".to_vec())),
..Default::default()
};
assert_eq!(event, &expected);
}
}