diff --git a/tensorboard/data/server/data_compat.rs b/tensorboard/data/server/data_compat.rs index 1b9ccf0eda..cc7c3f69d6 100644 --- a/tensorboard/data/server/data_compat.rs +++ b/tensorboard/data/server/data_compat.rs @@ -24,6 +24,7 @@ use pb::summary_metadata::PluginData; pub(crate) const SCALARS_PLUGIN_NAME: &str = "scalars"; pub(crate) const IMAGES_PLUGIN_NAME: &str = "images"; +pub(crate) const AUDIO_PLUGIN_NAME: &str = "audio"; pub(crate) const GRAPHS_PLUGIN_NAME: &str = "graphs"; /// The inner contents of a single value from an event. @@ -71,11 +72,13 @@ impl EventValue { /// Consumes this event value and enriches it into a blob sequence. /// - /// For now, this supports `GraphDef`s, summaries with `image`, or summaries with `tensor` set - /// to a rank-1 tensor of type `DT_STRING`. + /// For now, this supports `GraphDef`s, summaries with `image` or `audio`, or summaries with + /// `tensor` set to a rank-1 tensor of type `DT_STRING`. If the summary metadata indicates that + /// this is audio data, `tensor` may also be a string tensor of shape `[k, 2]`, in which case + /// the second axis is assumed to represent string labels and is dropped entirely. pub fn into_blob_sequence( self, - _metadata: &pb::SummaryMetadata, + metadata: &pb::SummaryMetadata, ) -> Result { match self { EventValue::GraphDef(GraphDefValue(blob)) => Ok(BlobSequenceValue(vec![blob])), @@ -86,11 +89,29 @@ impl EventValue { let buf = im.encoded_image_string; Ok(BlobSequenceValue(vec![w, h, buf])) } - pb::summary::value::Value::Tensor(tp) => { - if tp.dtype == i32::from(pb::DataType::DtString) - && tp.tensor_shape.map_or(false, |shape| shape.dim.len() == 1) - { + pb::summary::value::Value::Audio(au) => { + Ok(BlobSequenceValue(vec![au.encoded_audio_string])) + } + pb::summary::value::Value::Tensor(mut tp) + if tp.dtype == i32::from(pb::DataType::DtString) => + { + let shape = tp.tensor_shape.unwrap_or_default(); + if shape.dim.len() == 1 { Ok(BlobSequenceValue(tp.string_val)) + } else if shape.dim.len() == 2 + && shape.dim[1].size == 2 + && (metadata + .plugin_data + .as_ref() + .map_or(false, |pd| pd.plugin_name == AUDIO_PLUGIN_NAME)) + { + // Extract just the actual audio clips along the first axis. + let audio: Vec> = tp + .string_val + .chunks_exact_mut(2) + .map(|chunk| std::mem::take(&mut chunk[0])) + .collect(); + Ok(BlobSequenceValue(audio)) } else { Err(DataLoss) } @@ -189,13 +210,14 @@ impl SummaryValue { (Some(md), _) if md.data_class != i32::from(pb::DataClass::Unknown) => Box::new(md), (_, Value::SimpleValue(_)) => blank(SCALARS_PLUGIN_NAME, pb::DataClass::Scalar), (_, Value::Image(_)) => blank(IMAGES_PLUGIN_NAME, pb::DataClass::BlobSequence), + (_, Value::Audio(_)) => blank(AUDIO_PLUGIN_NAME, pb::DataClass::BlobSequence), (Some(mut md), _) => { // Use given metadata, but first set data class based on plugin name, if known. match md.plugin_data.as_ref().map(|pd| pd.plugin_name.as_str()) { Some(SCALARS_PLUGIN_NAME) => { md.data_class = pb::DataClass::Scalar.into(); } - Some(IMAGES_PLUGIN_NAME) => { + Some(IMAGES_PLUGIN_NAME) | Some(AUDIO_PLUGIN_NAME) => { md.data_class = pb::DataClass::BlobSequence.into(); } _ => {} @@ -552,6 +574,60 @@ mod tests { ); } + #[test] + fn test_metadata_tf1x_audio() { + let v = SummaryValue(Box::new(Value::Audio(pb::summary::Audio { + sample_rate: 44100.0, + encoded_audio_string: b"RIFFabcd".to_vec(), + ..Default::default() + }))); + let result = v.initial_metadata(None); + + assert_eq!( + *result, + pb::SummaryMetadata { + plugin_data: Some(PluginData { + plugin_name: AUDIO_PLUGIN_NAME.to_string(), + ..Default::default() + }), + data_class: pb::DataClass::BlobSequence.into(), + ..Default::default() + } + ); + } + + #[test] + fn test_metadata_tf2x_audio_without_dataclass() { + let md = pb::SummaryMetadata { + plugin_data: Some(PluginData { + plugin_name: AUDIO_PLUGIN_NAME.to_string(), + content: b"preserved!".to_vec(), + ..Default::default() + }), + ..Default::default() + }; + let v = SummaryValue(Box::new(Value::Tensor(pb::TensorProto { + dtype: pb::DataType::DtString.into(), + tensor_shape: Some(tensor_shape(&[1, 2])), + string_val: vec![b"\x89PNGabc".to_vec(), b"label".to_vec()], + ..Default::default() + }))); + let result = v.initial_metadata(Some(md)); + + assert_eq!( + *result, + pb::SummaryMetadata { + plugin_data: Some(PluginData { + plugin_name: AUDIO_PLUGIN_NAME.to_string(), + content: b"preserved!".to_vec(), + ..Default::default() + }), + data_class: pb::DataClass::BlobSequence.into(), + ..Default::default() + } + ); + } + #[test] fn test_enrich_graph_def() { let v = EventValue::GraphDef(GraphDefValue(vec![1, 2, 3, 4])); @@ -672,6 +748,70 @@ mod tests { Err(DataLoss) ); } + + #[test] + fn test_enrich_tf1x_audio() { + let v = SummaryValue(Box::new(Value::Audio(pb::summary::Audio { + sample_rate: 44100.0, + encoded_audio_string: b"RIFFabcd".to_vec(), + ..Default::default() + }))); + let md = v.initial_metadata(None); + let expected = BlobSequenceValue(vec![b"RIFFabcd".to_vec()]); + assert_eq!( + EventValue::Summary(v).into_blob_sequence(md.as_ref()), + Ok(expected) + ); + } + + #[test] + fn test_enrich_audio_without_labels() { + let v = EventValue::Summary(SummaryValue(Box::new(Value::Tensor(pb::TensorProto { + dtype: pb::DataType::DtString.into(), + tensor_shape: Some(tensor_shape(&[3])), + string_val: vec![ + b"RIFFwav0".to_vec(), + b"RIFFwav1".to_vec(), + b"RIFFwav2".to_vec(), + ], + ..Default::default() + })))); + let expected = BlobSequenceValue(vec![ + b"RIFFwav0".to_vec(), + b"RIFFwav1".to_vec(), + b"RIFFwav2".to_vec(), + ]); + assert_eq!( + v.into_blob_sequence(&blank(AUDIO_PLUGIN_NAME, pb::DataClass::BlobSequence)), + Ok(expected) + ); + } + + #[test] + fn test_enrich_audio_with_labels() { + let v = EventValue::Summary(SummaryValue(Box::new(Value::Tensor(pb::TensorProto { + dtype: pb::DataType::DtString.into(), + tensor_shape: Some(tensor_shape(&[3, 2])), + string_val: vec![ + b"RIFFwav0".to_vec(), + b"label 0".to_vec(), + b"RIFFwav1".to_vec(), + b"label 1".to_vec(), + b"RIFFwav2".to_vec(), + b"label 2".to_vec(), + ], + ..Default::default() + })))); + let expected = BlobSequenceValue(vec![ + b"RIFFwav0".to_vec(), + b"RIFFwav1".to_vec(), + b"RIFFwav2".to_vec(), + ]); + assert_eq!( + v.into_blob_sequence(&blank(AUDIO_PLUGIN_NAME, pb::DataClass::BlobSequence)), + Ok(expected) + ); + } } mod unknown {