diff --git a/tensorboard/data/grpc_provider.py b/tensorboard/data/grpc_provider.py
index 329e09a8d8..656123d894 100644
--- a/tensorboard/data/grpc_provider.py
+++ b/tensorboard/data/grpc_provider.py
@@ -87,13 +87,13 @@ def list_scalars(
                 tags = {}
                 result[run_entry.run_name] = tags
                 for tag_entry in run_entry.tags:
-                    ts = tag_entry.metadata
+                    time_series = tag_entry.metadata
                     tags[tag_entry.tag_name] = provider.ScalarTimeSeries(
-                        max_step=ts.max_step,
-                        max_wall_time=ts.max_wall_time,
-                        plugin_content=ts.summary_metadata.plugin_data.content,
-                        description=ts.summary_metadata.summary_description,
-                        display_name=ts.summary_metadata.display_name,
+                        max_step=time_series.max_step,
+                        max_wall_time=time_series.max_wall_time,
+                        plugin_content=time_series.summary_metadata.plugin_data.content,
+                        description=time_series.summary_metadata.summary_description,
+                        display_name=time_series.summary_metadata.display_name,
                     )
             return result
 
@@ -126,14 +126,97 @@ def read_scalars(
                 tags[tag_entry.tag_name] = series
                 d = tag_entry.data
                 for (step, wt, value) in zip(d.step, d.wall_time, d.value):
-                    pt = provider.ScalarDatum(
+                    point = provider.ScalarDatum(
                         step=step,
                         wall_time=wt,
                         value=value,
                     )
-                    series.append(pt)
+                    series.append(point)
             return result
 
+    @timing.log_latency
+    def list_blob_sequences(
+        self, ctx, experiment_id, plugin_name, run_tag_filter=None
+    ):
+        with timing.log_latency("build request"):
+            req = data_provider_pb2.ListBlobSequencesRequest()
+            req.experiment_id = experiment_id
+            req.plugin_filter.plugin_name = plugin_name
+            _populate_rtf(run_tag_filter, req.run_tag_filter)
+        with timing.log_latency("_stub.ListBlobSequences"):
+            with _translate_grpc_error():
+                res = self._stub.ListBlobSequences(req)
+        with timing.log_latency("build result"):
+            result = {}
+            for run_entry in res.runs:
+                tags = {}
+                result[run_entry.run_name] = tags
+                for tag_entry in run_entry.tags:
+                    time_series = tag_entry.metadata
+                    tags[tag_entry.tag_name] = provider.BlobSequenceTimeSeries(
+                        max_step=time_series.max_step,
+                        max_wall_time=time_series.max_wall_time,
+                        max_length=time_series.max_length,
+                        plugin_content=time_series.summary_metadata.plugin_data.content,
+                        description=time_series.summary_metadata.summary_description,
+                        display_name=time_series.summary_metadata.display_name,
+                    )
+            return result
+
+    @timing.log_latency
+    def read_blob_sequences(
+        self,
+        ctx,
+        experiment_id,
+        plugin_name,
+        downsample=None,
+        run_tag_filter=None,
+    ):
+        with timing.log_latency("build request"):
+            req = data_provider_pb2.ReadBlobSequencesRequest()
+            req.experiment_id = experiment_id
+            req.plugin_filter.plugin_name = plugin_name
+            _populate_rtf(run_tag_filter, req.run_tag_filter)
+            req.downsample.num_points = downsample
+        with timing.log_latency("_stub.ReadBlobSequences"):
+            with _translate_grpc_error():
+                res = self._stub.ReadBlobSequences(req)
+        with timing.log_latency("build result"):
+            result = {}
+            for run_entry in res.runs:
+                tags = {}
+                result[run_entry.run_name] = tags
+                for tag_entry in run_entry.tags:
+                    series = []
+                    tags[tag_entry.tag_name] = series
+                    d = tag_entry.data
+                    for (step, wt, blob_sequence) in zip(
+                        d.step, d.wall_time, d.values
+                    ):
+                        values = []
+                        for ref in blob_sequence.blob_refs:
+                            values.append(
+                                provider.BlobReference(
+                                    blob_key=ref.blob_key, url=ref.url or None
+                                )
+                            )
+                        point = provider.BlobSequenceDatum(
+                            step=step, wall_time=wt, values=tuple(values)
+                        )
+                        series.append(point)
+            return result
+
+    @timing.log_latency
+    def read_blob(self, ctx, blob_key):
+        with timing.log_latency("build request"):
+            req = data_provider_pb2.ReadBlobRequest()
+            req.blob_key = blob_key
+        with timing.log_latency("list(_stub.ReadBlob)"):
+            with _translate_grpc_error():
+                responses = list(self._stub.ReadBlob(req))
+        with timing.log_latency("build result"):
+            return b"".join(res.data for res in responses)
+
 
 @contextlib.contextmanager
 def _translate_grpc_error():
diff --git a/tensorboard/data/grpc_provider_test.py b/tensorboard/data/grpc_provider_test.py
index 3cd715f223..dd7a66a10b 100644
--- a/tensorboard/data/grpc_provider_test.py
+++ b/tensorboard/data/grpc_provider_test.py
@@ -174,6 +174,118 @@ def test_read_scalars(self):
         req.downsample.num_points = 4
         self.stub.ReadScalars.assert_called_once_with(req)
 
+    def test_list_blob_sequences(self):
+        res = data_provider_pb2.ListBlobSequencesResponse()
+        run1 = res.runs.add(run_name="train")
+        tag11 = run1.tags.add(tag_name="input_image")
+        tag11.metadata.max_step = 7
+        tag11.metadata.max_wall_time = 7.77
+        tag11.metadata.max_length = 3
+        tag11.metadata.summary_metadata.plugin_data.content = b"PNG"
+        tag11.metadata.summary_metadata.display_name = "Input image"
+        tag11.metadata.summary_metadata.summary_description = "img"
+        self.stub.ListBlobSequences.return_value = res
+
+        actual = self.provider.list_blob_sequences(
+            self.ctx,
+            experiment_id="123",
+            plugin_name="images",
+            run_tag_filter=provider.RunTagFilter(runs=["val", "train"]),
+        )
+        expected = {
+            "train": {
+                "input_image": provider.BlobSequenceTimeSeries(
+                    max_step=7,
+                    max_wall_time=7.77,
+                    max_length=3,
+                    plugin_content=b"PNG",
+                    description="img",
+                    display_name="Input image",
+                ),
+            },
+        }
+        self.assertEqual(actual, expected)
+
+        req = data_provider_pb2.ListBlobSequencesRequest()
+        req.experiment_id = "123"
+        req.plugin_filter.plugin_name = "images"
+        req.run_tag_filter.runs.names.extend(["train", "val"])  # sorted
+        self.stub.ListBlobSequences.assert_called_once_with(req)
+
+    def test_read_blob_sequences(self):
+        res = data_provider_pb2.ReadBlobSequencesResponse()
+        run = res.runs.add(run_name="test")
+        tag = run.tags.add(tag_name="input_image")
+        tag.data.step.extend([0, 1])
+        tag.data.wall_time.extend([1234.0, 1235.0])
+        seq0 = tag.data.values.add()
+        seq0.blob_refs.add(blob_key="step0img0")
+        seq0.blob_refs.add(blob_key="step0img1")
+        seq1 = tag.data.values.add()
+        seq1.blob_refs.add(blob_key="step1img0")
+        self.stub.ReadBlobSequences.return_value = res
+
+        actual = self.provider.read_blob_sequences(
+            self.ctx,
+            experiment_id="123",
+            plugin_name="images",
+            run_tag_filter=provider.RunTagFilter(runs=["test", "nope"]),
+            downsample=4,
+        )
+        expected = {
+            "test": {
+                "input_image": [
+                    provider.BlobSequenceDatum(
+                        step=0,
+                        wall_time=1234.0,
+                        values=(
+                            provider.BlobReference(blob_key="step0img0"),
+                            provider.BlobReference(blob_key="step0img1"),
+                        ),
+                    ),
+                    provider.BlobSequenceDatum(
+                        step=1,
+                        wall_time=1235.0,
+                        values=(provider.BlobReference(blob_key="step1img0"),),
+                    ),
+                ],
+            },
+        }
+        self.assertEqual(actual, expected)
+
+        req = data_provider_pb2.ReadBlobSequencesRequest()
+        req.experiment_id = "123"
+        req.plugin_filter.plugin_name = "images"
+        req.run_tag_filter.runs.names.extend(["nope", "test"])  # sorted
+        req.downsample.num_points = 4
+        self.stub.ReadBlobSequences.assert_called_once_with(req)
+
+    def test_read_blob(self):
+        responses = [
+            data_provider_pb2.ReadBlobResponse(data=b"hello wo"),
+            data_provider_pb2.ReadBlobResponse(data=b"rld"),
+        ]
+        self.stub.ReadBlob.return_value = responses
+
+        actual = self.provider.read_blob(self.ctx, blob_key="myblob")
+        expected = b"hello world"
+        self.assertEqual(actual, expected)
+
+        req = data_provider_pb2.ReadBlobRequest()
+        req.blob_key = "myblob"
+        self.stub.ReadBlob.assert_called_once_with(req)
+
+    def test_read_blob_error(self):
+        def fake_handler(req):
+            del req  # unused
+            yield data_provider_pb2.ReadBlobResponse(data=b"hello wo")
+            raise _grpc_error(grpc.StatusCode.NOT_FOUND, "it ran away!")
+
+        self.stub.ReadBlob.side_effect = fake_handler
+
+        with self.assertRaisesRegex(errors.NotFoundError, "it ran away!"):
+            self.provider.read_blob(self.ctx, blob_key="myblob")
+
     def test_rpc_error(self):
         # This error handling is implemented with a context manager used
         # for all the methods, so take `list_plugins` as representative.
diff --git a/tensorboard/data/server/BUILD b/tensorboard/data/server/BUILD
index e79f6ecbcd..16c1b626cc 100644
--- a/tensorboard/data/server/BUILD
+++ b/tensorboard/data/server/BUILD
@@ -45,6 +45,7 @@ rust_library(
     ] + _checked_in_proto_files,
     edition = "2018",
     deps = [
+        "//third_party/rust:async_stream",
         "//third_party/rust:base64",
         "//third_party/rust:byteorder",
         "//third_party/rust:clap",
@@ -59,6 +60,7 @@ rust_library(
         "//third_party/rust:serde_json",
         "//third_party/rust:thiserror",
         "//third_party/rust:tokio",
+        "//third_party/rust:tokio_stream",
         "//third_party/rust:tonic",
         "//third_party/rust:walkdir",
     ],
diff --git a/tensorboard/data/server/server.rs b/tensorboard/data/server/server.rs
index 30d72d8e03..f05bd20b57 100644
--- a/tensorboard/data/server/server.rs
+++ b/tensorboard/data/server/server.rs
@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+use async_stream::try_stream;
 use futures_core::Stream;
-use std::borrow::Borrow;
+use std::borrow::{Borrow, Cow};
 use std::collections::HashMap;
 use std::collections::HashSet;
 use std::convert::TryInto;
@@ -23,8 +24,10 @@ use std::pin::Pin;
 use std::sync::{RwLock, RwLockReadGuard};
 use tonic::{Request, Response, Status};
 
-use crate::commit::{self, Commit};
+use crate::blob_key::BlobKey;
+use crate::commit::{self, BlobSequenceValue, Commit};
 use crate::downsample;
+use crate::proto::tensorboard as pb;
 use crate::proto::tensorboard::data;
 use crate::types::{Run, Tag, WallTime};
 use data::tensor_board_data_provider_server::TensorBoardDataProvider;
@@ -45,6 +48,13 @@ impl DataProviderHandler {
     }
 }
 
+/// Maximum size (in bytes) of the `data` field of any single [`data::ReadBlobResponse`].
+const BLOB_CHUNK_SIZE: usize = 1024 * 1024 * 8;
+
+fn plugin_name(md: &pb::SummaryMetadata) -> Option<&str> {
+    md.plugin_data.as_ref().map(|pd| pd.plugin_name.as_str())
+}
+
 #[tonic::async_trait]
 impl TensorBoardDataProvider for DataProviderHandler {
     async fn list_plugins(
         &self,
@@ -132,12 +142,7 @@ impl TensorBoardDataProvider for DataProviderHandler {
                 if !tag_filter.want(tag) {
                     continue;
                 }
-                let plugin_name = ts
-                    .metadata
-                    .plugin_data
-                    .as_ref()
-                    .map(|pd| pd.plugin_name.as_str());
-                if plugin_name != Some(&want_plugin) {
+                if plugin_name(&ts.metadata) != Some(&want_plugin) {
                     continue;
                 }
                 let max_step = match ts.valid_values().last() {
@@ -192,12 +197,7 @@ impl TensorBoardDataProvider for DataProviderHandler {
                 if !tag_filter.want(tag) {
                     continue;
                 }
-                let plugin_name = ts
-                    .metadata
-                    .plugin_data
-                    .as_ref()
-                    .map(|pd| pd.plugin_name.as_str());
-                if plugin_name != Some(&want_plugin) {
+                if plugin_name(&ts.metadata) != Some(&want_plugin) {
                     continue;
                 }
 
@@ -247,16 +247,140 @@ impl TensorBoardDataProvider for DataProviderHandler {
 
     async fn list_blob_sequences(
         &self,
-        _request: Request<data::ListBlobSequencesRequest>,
+        req: Request<data::ListBlobSequencesRequest>,
     ) -> Result<Response<data::ListBlobSequencesResponse>, Status> {
-        Err(Status::unimplemented("not yet implemented"))
+        let req = req.into_inner();
+        let want_plugin = parse_plugin_filter(req.plugin_filter)?;
+        let (run_filter, tag_filter) = parse_rtf(req.run_tag_filter);
+        let runs = self.read_runs()?;
+
+        let mut res: data::ListBlobSequencesResponse = Default::default();
+        for (run, data) in runs.iter() {
+            if !run_filter.want(run) {
+                continue;
+            }
+            let data = data
+                .read()
+                .map_err(|_| Status::internal(format!("failed to read run data for {:?}", run)))?;
+            let mut run_res: data::list_blob_sequences_response::RunEntry = Default::default();
+            for (tag, ts) in &data.blob_sequences {
+                if !tag_filter.want(tag) {
+                    continue;
+                }
+                if plugin_name(&ts.metadata) != Some(&want_plugin) {
+                    continue;
+                }
+                let (mut max_step, mut max_wall_time, mut max_length) = (None, None, None);
+                for (step, wall_time, value) in ts.valid_values() {
+                    if max_step.map_or(true, |s| s < step) {
+                        max_step = Some(step);
+                    }
+                    if max_wall_time.map_or(true, |wt| wt < wall_time) {
+                        max_wall_time = Some(wall_time);
+                    }
+                    if max_length.map_or(true, |len| len < value.0.len()) {
+                        max_length = Some(value.0.len());
+                    }
+                }
+                let (max_step, max_wall_time, max_length) =
+                    match (max_step, max_wall_time, max_length) {
+                        (Some(s), Some(wt), Some(len)) => (s, wt, len),
+                        _ => continue,
+                    };
+                run_res
+                    .tags
+                    .push(data::list_blob_sequences_response::TagEntry {
+                        tag_name: tag.0.clone(),
+                        metadata: Some(data::BlobSequenceMetadata {
+                            max_step: max_step.into(),
+                            max_wall_time: max_wall_time.into(),
+                            max_length: max_length as i64,
+                            summary_metadata: Some(*ts.metadata.clone()),
+                        }),
+                    });
+            }
+            if !run_res.tags.is_empty() {
+                run_res.run_name = run.0.clone();
+                res.runs.push(run_res);
+            }
+        }
+
+        Ok(Response::new(res))
     }
 
     async fn read_blob_sequences(
         &self,
-        _request: Request<data::ReadBlobSequencesRequest>,
+        req: Request<data::ReadBlobSequencesRequest>,
     ) -> Result<Response<data::ReadBlobSequencesResponse>, Status> {
-        Err(Status::unimplemented("not yet implemented"))
+        let req = req.into_inner();
+        let want_plugin = parse_plugin_filter(req.plugin_filter)?;
+        let (run_filter, tag_filter) = parse_rtf(req.run_tag_filter);
+        let num_points = parse_downsample(req.downsample)?;
+        let runs = self.read_runs()?;
+
+        let mut res: data::ReadBlobSequencesResponse = Default::default();
+        for (run, data) in runs.iter() {
+            if !run_filter.want(run) {
+                continue;
+            }
+            let data = data
+                .read()
+                .map_err(|_| Status::internal(format!("failed to read run data for {:?}", run)))?;
+            let mut run_res: data::read_blob_sequences_response::RunEntry = Default::default();
+            for (tag, ts) in &data.blob_sequences {
+                if !tag_filter.want(tag) {
+                    continue;
+                }
+                if plugin_name(&ts.metadata) != Some(&want_plugin) {
+                    continue;
+                }
+
+                let mut points = ts.valid_values().collect::<Vec<_>>();
+                downsample::downsample(&mut points, num_points);
+                let n = points.len();
+                let mut steps = Vec::with_capacity(n);
+                let mut wall_times = Vec::with_capacity(n);
+                let mut values = Vec::with_capacity(n);
+                for (step, wall_time, &BlobSequenceValue(ref value)) in points {
+                    steps.push(step.into());
+                    wall_times.push(wall_time.into());
+                    let eid = req.experiment_id.as_str();
+                    let blob_refs = (0..value.len())
+                        .map(|i| {
+                            let bk = BlobKey {
+                                experiment_id: Cow::Borrowed(eid),
+                                run: Cow::Borrowed(run.0.as_str()),
+                                tag: Cow::Borrowed(tag.0.as_str()),
+                                step,
+                                index: i,
+                            };
+                            data::BlobReference {
+                                blob_key: bk.to_string(),
+                                url: String::new(),
+                            }
+                        })
+                        .collect::<Vec<_>>();
+                    values.push(data::BlobReferenceSequence { blob_refs });
+                }
+
+                run_res
+                    .tags
+                    .push(data::read_blob_sequences_response::TagEntry {
+                        tag_name: tag.0.clone(),
+                        data: Some(data::BlobSequenceData {
+                            step: steps,
+                            wall_time: wall_times,
+                            values,
+                        }),
+                    });
+            }
+            if !run_res.tags.is_empty() {
+                run_res.run_name = run.0.clone();
+                res.runs.push(run_res);
+            }
+        }
+
+        Ok(Response::new(res))
     }
 
     type ReadBlobStream =
@@ -264,9 +388,67 @@ impl TensorBoardDataProvider for DataProviderHandler {
         Pin<Box<dyn Stream<Item = Result<data::ReadBlobResponse, Status>> + Send + Sync + 'static>>;
 
     async fn read_blob(
         &self,
-        _request: Request<data::ReadBlobRequest>,
+        req: Request<data::ReadBlobRequest>,
     ) -> Result<Response<Self::ReadBlobStream>, Status> {
-        Err(Status::unimplemented("not yet implemented"))
+        let req = req.into_inner();
+        let bk: BlobKey = req
+            .blob_key
+            .parse()
+            .map_err(|e| Status::invalid_argument(format!("failed to parse blob key: {:?}", e)))?;
+
+        let runs = self.read_runs()?;
+        let run_data = runs
+            .get(bk.run.as_ref())
+            .ok_or_else(|| Status::not_found(format!("no such run: {:?}", bk.run)))?
+            .read()
+            .map_err(|_| Status::internal(format!("failed to read run data for {:?}", bk.run)))?;
+        let ts = run_data
+            .blob_sequences
+            .get(bk.tag.as_ref())
+            .ok_or_else(|| {
+                Status::not_found(format!("run {:?} has no such tag: {:?}", bk.run, bk.tag))
+            })?;
+        let datum = ts
+            .valid_values()
+            .find_map(
+                |(step, _, value)| {
+                    if step == bk.step {
+                        Some(value)
+                    } else {
+                        None
+                    }
+                },
+            )
+            .ok_or_else(|| {
+                Status::not_found(format!(
+                    "run {:?}, tag {:?} has no step {}; may have been evicted",
+                    bk.run, bk.tag, bk.step.0
+                ))
+            })?;
+        let blobs = &datum.0;
+        let blob = blobs.get(bk.index).ok_or_else(|| {
+            Status::not_found(format!(
+                "blob sequence at run {:?}, tag {:?}, step {:?} has no index {} (length: {})",
+                bk.run,
+                bk.tag,
+                bk.step,
+                bk.index,
+                blobs.len()
+            ))
+        })?;
+        // Clone blob so that we can send it down to the client after dropping the lock.
+        // TODO(@wchargin): Consider replacing this with an `Arc<[u8]>`.
+        let blob = blob.clone();
+        drop(run_data);
+        drop(runs);
+
+        let stream = try_stream! {
+            for chunk in blob.chunks(BLOB_CHUNK_SIZE) {
+                yield data::ReadBlobResponse { data: chunk.to_vec() };
+            }
+        };
+
+        Ok(Response::new(Box::pin(stream) as Self::ReadBlobStream))
     }
 }
 
@@ -345,10 +527,10 @@ impl Filter {
 #[allow(clippy::float_cmp)]
 mod tests {
     use super::*;
+    use tokio_stream::StreamExt;
     use tonic::Code;
 
     use crate::commit::test_data::CommitBuilder;
-    use crate::proto::tensorboard as pb;
     use crate::types::{Run, Step, Tag};
 
     fn sample_handler(commit: Commit) -> DataProviderHandler {
@@ -637,4 +819,115 @@ mod tests {
         let xent_data = &train_run[&Tag("xent".to_string())].data.as_ref().unwrap();
         assert_eq!(xent_data.value, Vec::<f64>::new());
     }
+
+    #[tokio::test]
+    async fn test_blob_sequences() {
+        let commit = CommitBuilder::new()
+            .scalars("train", "accuracy", |b| b.build())
+            .blob_sequences("train", "input", |mut b| {
+                b.plugin_name("images")
+                    .wall_time_start(1234.0)
+                    .values(vec![
+                        BlobSequenceValue(vec![b"step0img0".to_vec(), b"step0img1".to_vec()]),
+                        BlobSequenceValue(vec![b"z".repeat(BLOB_CHUNK_SIZE * 3 / 2)]),
+                    ])
+                    .build()
+            })
+            .blob_sequences("another_run", "input", |mut b| {
+                b.plugin_name("not_images").build()
+            })
+            .build();
+        let handler = sample_handler(commit);
+
+        // List blob sequences and check the response exactly. It doesn't have any blob keys, so
+        // the exact value is well defined.
+        let list_req = Request::new(data::ListBlobSequencesRequest {
+            experiment_id: "123".to_string(),
+            plugin_filter: Some(data::PluginFilter {
+                plugin_name: "images".to_string(),
+            }),
+            run_tag_filter: None,
+        });
+        let list_res = handler
+            .list_blob_sequences(list_req)
+            .await
+            .expect("ListBlobSequences")
+            .into_inner();
+        assert_eq!(
+            list_res,
+            data::ListBlobSequencesResponse {
+                runs: vec![data::list_blob_sequences_response::RunEntry {
+                    run_name: "train".to_string(),
+                    tags: vec![data::list_blob_sequences_response::TagEntry {
+                        tag_name: "input".to_string(),
+                        metadata: Some(data::BlobSequenceMetadata {
+                            max_step: 1,
+                            max_wall_time: 1235.0,
+                            max_length: 2,
+                            summary_metadata: Some(pb::SummaryMetadata {
+                                plugin_data: Some(pb::summary_metadata::PluginData {
+                                    plugin_name: "images".to_string(),
+                                    ..Default::default()
+                                }),
+                                data_class: pb::DataClass::BlobSequence.into(),
+                                ..Default::default()
+                            }),
+                        }),
+                    }],
+                }],
+            }
+        );
+
+        // Read blob sequences and check that its structure is right. The actual blob keys are
+        // opaque, so we don't expect any specific values.
+        let read_req = Request::new(data::ReadBlobSequencesRequest {
+            experiment_id: "123".to_string(),
+            plugin_filter: Some(data::PluginFilter {
+                plugin_name: "images".to_string(),
+            }),
+            downsample: Some(data::Downsample { num_points: 1000 }),
+            run_tag_filter: Some(data::RunTagFilter {
+                runs: Some(data::RunFilter {
+                    names: vec!["train".to_string()],
+                }),
+                tags: Some(data::TagFilter {
+                    names: vec!["input".to_string()],
+                }),
+            }),
+        });
+        let read_res = handler
+            .read_blob_sequences(read_req)
+            .await
+            .expect("ReadBlobSequences")
+            .into_inner();
+        assert_eq!(read_res.runs.len(), 1);
+        assert_eq!(read_res.runs[0].tags.len(), 1);
+        let data = (read_res.runs[0].tags[0].data.as_ref()).expect("blob sequence data");
+
+        assert_eq!(data.step, vec![0, 1]);
+        assert_eq!(data.wall_time, vec![1234.0, 1235.0]);
+        assert_eq!(data.values.len(), 2);
+        assert_eq!(data.values[0].blob_refs.len(), 2);
+        assert_eq!(data.values[1].blob_refs.len(), 1);
+
+        // Read the blob that's supposed to take multiple chunks.
+        let blob_req = Request::new(data::ReadBlobRequest {
+            blob_key: data.values[1].blob_refs[0].blob_key.clone(),
+        });
+        let mut blob_res = handler
+            .read_blob(blob_req)
+            .await
+            .expect("ReadBlob")
+            .into_inner();
+        let mut chunks = Vec::new();
+        while let Some(chunk) = blob_res.next().await {
+            let chunk = chunk.unwrap_or_else(|_| panic!("chunk {}", chunks.len()));
+            chunks.push(chunk.data);
+        }
+        let expected_chunks = vec![
+            b"z".repeat(BLOB_CHUNK_SIZE),
+            b"z".repeat(BLOB_CHUNK_SIZE / 2),
+        ];
+        assert_eq!(chunks, expected_chunks);
+    }
 }