diff --git a/tensorboard/data/server/data_compat.rs b/tensorboard/data/server/data_compat.rs index 0a8ba26d3a..e52db05ab6 100644 --- a/tensorboard/data/server/data_compat.rs +++ b/tensorboard/data/server/data_compat.rs @@ -38,6 +38,10 @@ pub(crate) mod plugin_names { pub const GRAPH_RUN_METADATA: &str = "graph_run_metadata"; pub const GRAPH_RUN_METADATA_WITH_GRAPH: &str = "graph_run_metadata_graph"; pub const GRAPH_KERAS_MODEL: &str = "graph_keras_model"; + pub const HISTOGRAMS: &str = "histograms"; + pub const TEXT: &str = "text"; + pub const PR_CURVES: &str = "pr_curves"; + pub const HPARAMS: &str = "hparams"; } /// The inner contents of a single value from an event. @@ -86,6 +90,23 @@ impl EventValue { } } + /// Consumes this event value and enriches it into a tensor. + /// + /// This supports summaries with `tensor` populated. + // + // TODO(#4422): support conversion of other summary types to tensors. + pub fn into_tensor(self, _metadata: &pb::SummaryMetadata) -> Result { + let value_box = match self { + EventValue::GraphDef(_) => return Err(DataLoss), + EventValue::TaggedRunMetadata(_) => return Err(DataLoss), + EventValue::Summary(SummaryValue(v)) => v, + }; + match *value_box { + pb::summary::value::Value::Tensor(tp) => Ok(tp), + _ => Err(DataLoss), + } + } + /// Consumes this event value and enriches it into a blob sequence. /// /// This supports: @@ -272,6 +293,12 @@ impl SummaryValue { Some(plugin_names::SCALARS) => { md.data_class = pb::DataClass::Scalar.into(); } + Some(plugin_names::HISTOGRAMS) + | Some(plugin_names::TEXT) + | Some(plugin_names::HPARAMS) + | Some(plugin_names::PR_CURVES) => { + md.data_class = pb::DataClass::Tensor.into(); + } Some(plugin_names::IMAGES) | Some(plugin_names::AUDIO) | Some(plugin_names::GRAPH_RUN_METADATA) @@ -446,7 +473,7 @@ mod tests { } #[test] - fn test_enrich_valid_tensors() { + fn test_enrich_rank_0_tensors() { let tensors = vec![ pb::TensorProto { dtype: pb::DataType::DtFloat.into(), @@ -480,7 +507,7 @@ mod tests { } #[test] - fn test_enrich_short_tensors() { + fn test_enrich_rank_0_tensors_corrupted_with_short_data() { let tensors = vec![ pb::TensorProto { dtype: pb::DataType::DtFloat.into(), @@ -508,7 +535,7 @@ mod tests { } #[test] - fn test_enrich_long_tensors() { + fn test_enrich_rank_0_tensors_corrupted_with_long_data() { let tensors = vec![ pb::TensorProto { dtype: pb::DataType::DtFloat.into(), @@ -569,7 +596,7 @@ mod tests { } #[test] - fn test_enrich_non_float_tensors() { + fn test_enrich_non_float_rank_0_tensors() { let tensors = vec![ pb::TensorProto { dtype: pb::DataType::DtString.into(), @@ -619,6 +646,67 @@ mod tests { } } + mod tensors { + use super::*; + + #[test] + fn test_metadata_tensor_with_dataclass() { + let md = blank_with_plugin_content( + "rando", + pb::DataClass::Tensor, + Bytes::from_static(b"preserved!"), + ); + let v = SummaryValue(Box::new(Value::Tensor(pb::TensorProto { + dtype: pb::DataType::DtString.into(), + string_val: vec![Bytes::from_static(b"foo")], + ..Default::default() + }))); + let result = v.initial_metadata(Some(md.as_ref().clone())); + assert_eq!(*result, *md); + } + + #[test] + fn test_metadata_tensor_without_dataclass() { + for plugin_name in &[ + plugin_names::HISTOGRAMS, + plugin_names::TEXT, + plugin_names::PR_CURVES, + plugin_names::HPARAMS, + ] { + let md = blank_with_plugin_content( + plugin_name, + pb::DataClass::Unknown, + Bytes::from_static(b"preserved!"), + ); + let v = SummaryValue(Box::new(Value::Tensor(pb::TensorProto { + dtype: pb::DataType::DtString.into(), + string_val: vec![Bytes::from_static(b"foo")], + ..Default::default() + }))); + let result = v.initial_metadata(Some(md.as_ref().clone())); + let expected = pb::SummaryMetadata { + data_class: pb::DataClass::Tensor.into(), + ..*md + }; + assert_eq!(*result, expected); + } + } + + #[test] + fn test_enrich_tensor() { + let tp = pb::TensorProto { + dtype: pb::DataType::DtString.into(), + string_val: vec![Bytes::from_static(b"foo")], + ..Default::default() + }; + let v = EventValue::Summary(SummaryValue(Box::new(Value::Tensor(tp.clone())))); + assert_eq!( + v.into_tensor(&blank("mytensors", pb::DataClass::Tensor)), + Ok(tp) + ); + } + } + mod blob_sequences { use super::*; @@ -875,7 +963,7 @@ mod tests { } #[test] - fn test_enrich_scalar_tensor() { + fn test_enrich_rank_0_tensor() { let v = EventValue::Summary(SummaryValue(Box::new(Value::Tensor(pb::TensorProto { dtype: pb::DataType::DtString.into(), tensor_shape: Some(tensor_shape(&[])), @@ -990,7 +1078,7 @@ mod tests { use super::*; #[test] - fn test_custom_plugin_with_dataclass() { + fn test_metadata_custom_plugin_with_dataclass() { let md = pb::SummaryMetadata { plugin_data: Some(PluginData { plugin_name: "myplugin".to_string(), @@ -1007,7 +1095,7 @@ mod tests { } #[test] - fn test_unknown_plugin_no_dataclass() { + fn test_metadata_unknown_plugin_no_dataclass() { let md = pb::SummaryMetadata { plugin_data: Some(PluginData { plugin_name: "myplugin".to_string(), @@ -1022,7 +1110,7 @@ mod tests { } #[test] - fn test_empty() { + fn test_metadata_empty() { let v = SummaryValue(Box::new(Value::Tensor(pb::TensorProto::default()))); let result = v.initial_metadata(None); assert_eq!(*result, pb::SummaryMetadata::default()); diff --git a/tensorboard/data/server/run.rs b/tensorboard/data/server/run.rs index 9730511b3c..3cb6dead16 100644 --- a/tensorboard/data/server/run.rs +++ b/tensorboard/data/server/run.rs @@ -147,17 +147,7 @@ impl StageTimeSeries { use pb::DataClass; match self.data_class { DataClass::Scalar => self.commit_to(tag, &mut run.scalars, |ev, _| ev.into_scalar()), - DataClass::Tensor => { - warn!( - "Tensor time series not yet supported (tag: {:?}, plugin: {:?})", - tag.0, - self.metadata - .plugin_data - .as_ref() - .map(|p| p.plugin_name.as_str()) - .unwrap_or("") - ); - } + DataClass::Tensor => self.commit_to(tag, &mut run.tensors, EventValue::into_tensor), DataClass::BlobSequence => { self.commit_to(tag, &mut run.blob_sequences, EventValue::into_blob_sequence) } @@ -450,6 +440,23 @@ mod test { WallTime::new(1235.0).unwrap(), Bytes::from_static(b""), )?; + f1.write_tensor( + &Tag("weights".to_string()), + Step(0), + WallTime::new(1235.0).unwrap(), + pb::TensorProto { + dtype: pb::DataType::DtString.into(), + string_val: vec![Bytes::from_static(b"foo")], + ..Default::default() + }, + pb::SummaryMetadata { + plugin_data: Some(pb::summary_metadata::PluginData { + plugin_name: plugin_names::HISTOGRAMS.to_string(), + ..Default::default() + }), + ..Default::default() + }, + )?; f1.write_scalar(&tag, Step(0), WallTime::new(1235.0).unwrap(), 0.25)?; f1.write_scalar(&tag, Step(1), WallTime::new(1236.0).unwrap(), 0.50)?; f1.write_scalar(&tag, Step(2), WallTime::new(1237.0).unwrap(), 0.75)?; @@ -513,6 +520,32 @@ mod test { ] ); + assert_eq!(run_data.tensors.len(), 1); + let tensor_ts = run_data.tensors.get(&Tag("weights".to_string())).unwrap(); + assert_eq!( + *tensor_ts.metadata, + pb::SummaryMetadata { + plugin_data: Some(pb::summary_metadata::PluginData { + plugin_name: plugin_names::HISTOGRAMS.to_string(), + ..Default::default() + }), + data_class: pb::DataClass::Tensor.into(), + ..Default::default() + } + ); + assert_eq!( + tensor_ts.valid_values().collect::>(), + vec![( + Step(0), + WallTime::new(1235.0).unwrap(), + &pb::TensorProto { + dtype: pb::DataType::DtString.into(), + string_val: vec![Bytes::from_static(b"foo")], + ..Default::default() + } + )] + ); + assert_eq!(run_data.blob_sequences.len(), 2); let run_graph_tag = Tag(GraphDefValue::TAG_NAME.to_string()); diff --git a/tensorboard/data/server/writer.rs b/tensorboard/data/server/writer.rs index e05ce2e1f2..0a356cac3c 100644 --- a/tensorboard/data/server/writer.rs +++ b/tensorboard/data/server/writer.rs @@ -56,6 +56,32 @@ pub trait SummaryWriteExt: Write { self.write_event(&event) } + /// Writes a TFRecord containing a TF 2.x `tensor` summary. + fn write_tensor( + &mut self, + tag: &Tag, + step: Step, + wt: WallTime, + tensor: pb::TensorProto, + metadata: pb::SummaryMetadata, + ) -> std::io::Result<()> { + let event = pb::Event { + step: step.0, + wall_time: wt.into(), + what: Some(pb::event::What::Summary(pb::Summary { + value: vec![pb::summary::Value { + tag: tag.0.clone(), + value: Some(pb::summary::value::Value::Tensor(tensor)), + metadata: Some(metadata), + ..Default::default() + }], + ..Default::default() + })), + ..Default::default() + }; + self.write_event(&event) + } + /// Writes a TFRecord containing a TF 1.x `graph_def` event. fn write_graph(&mut self, step: Step, wt: WallTime, bytes: Bytes) -> std::io::Result<()> { let event = pb::Event { @@ -157,6 +183,52 @@ mod tests { assert_eq!(event, &expected); } + #[test] + fn test_tensor_roundtrip() { + let tensor_proto = pb::TensorProto { + dtype: pb::DataType::DtString.into(), + string_val: vec![Bytes::from_static(b"foo")], + ..Default::default() + }; + let summary_metadata = pb::SummaryMetadata { + plugin_data: Some(pb::summary_metadata::PluginData { + plugin_name: "histograms".to_string(), + ..Default::default() + }), + ..Default::default() + }; + let mut cursor = Cursor::new(Vec::::new()); + cursor + .write_tensor( + &Tag("weights".to_string()), + Step(777), + WallTime::new(1234.5).unwrap(), + tensor_proto.clone(), + summary_metadata.clone(), + ) + .unwrap(); + cursor.set_position(0); + let events = read_all_events(cursor).unwrap(); + assert_eq!(events.len(), 1); + + let event = &events[0]; + let expected = pb::Event { + step: 777, + wall_time: 1234.5, + what: Some(pb::event::What::Summary(pb::Summary { + value: vec![pb::summary::Value { + tag: "weights".to_string(), + value: Some(pb::summary::value::Value::Tensor(tensor_proto)), + metadata: Some(summary_metadata), + ..Default::default() + }], + ..Default::default() + })), + ..Default::default() + }; + assert_eq!(event, &expected); + } + #[test] fn test_graph_roundtrip() { let mut cursor = Cursor::new(Vec::::new());