99 changes: 91 additions & 8 deletions tensorboard/data/grpc_provider.py
@@ -87,13 +87,13 @@ def list_scalars(
                 tags = {}
                 result[run_entry.run_name] = tags
                 for tag_entry in run_entry.tags:
-                    ts = tag_entry.metadata
+                    time_series = tag_entry.metadata
                     tags[tag_entry.tag_name] = provider.ScalarTimeSeries(
-                        max_step=ts.max_step,
-                        max_wall_time=ts.max_wall_time,
-                        plugin_content=ts.summary_metadata.plugin_data.content,
-                        description=ts.summary_metadata.summary_description,
-                        display_name=ts.summary_metadata.display_name,
+                        max_step=time_series.max_step,
+                        max_wall_time=time_series.max_wall_time,
+                        plugin_content=time_series.summary_metadata.plugin_data.content,
+                        description=time_series.summary_metadata.summary_description,
+                        display_name=time_series.summary_metadata.display_name,
                     )
             return result
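Note: the mapping built above is keyed by run name, then tag name. A minimal sketch of how a caller might walk the result of list_scalars; the connected `data_provider` instance and request context `ctx` are assumed here and are not part of this diff:

# Sketch: consuming list_scalars output. `data_provider` and `ctx` are
# assumed to be a connected GrpcDataProvider and a request context.
result = data_provider.list_scalars(
    ctx, experiment_id="123", plugin_name="scalars"
)
for run_name, tags in result.items():
    for tag_name, time_series in tags.items():
        # Each leaf is a provider.ScalarTimeSeries built as above.
        print(run_name, tag_name, time_series.max_step)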

@@ -126,14 +126,97 @@ def read_scalars(
                     tags[tag_entry.tag_name] = series
                     d = tag_entry.data
                     for (step, wt, value) in zip(d.step, d.wall_time, d.value):
-                        pt = provider.ScalarDatum(
+                        point = provider.ScalarDatum(
                             step=step,
                             wall_time=wt,
                             value=value,
                         )
-                        series.append(pt)
+                        series.append(point)
             return result

+    @timing.log_latency
+    def list_blob_sequences(
+        self, ctx, experiment_id, plugin_name, run_tag_filter=None
+    ):
+        with timing.log_latency("build request"):
+            req = data_provider_pb2.ListBlobSequencesRequest()
+            req.experiment_id = experiment_id
+            req.plugin_filter.plugin_name = plugin_name
+            _populate_rtf(run_tag_filter, req.run_tag_filter)
+        with timing.log_latency("_stub.ListBlobSequences"):
+            with _translate_grpc_error():
+                res = self._stub.ListBlobSequences(req)
+        with timing.log_latency("build result"):
+            result = {}
+            for run_entry in res.runs:
+                tags = {}
+                result[run_entry.run_name] = tags
+                for tag_entry in run_entry.tags:
+                    time_series = tag_entry.metadata
+                    tags[tag_entry.tag_name] = provider.BlobSequenceTimeSeries(
+                        max_step=time_series.max_step,
+                        max_wall_time=time_series.max_wall_time,
+                        max_length=time_series.max_length,
+                        plugin_content=time_series.summary_metadata.plugin_data.content,
+                        description=time_series.summary_metadata.summary_description,
+                        display_name=time_series.summary_metadata.display_name,
+                    )
+            return result
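Note: list_blob_sequences mirrors list_scalars but returns BlobSequenceTimeSeries values, which carry an extra max_length (the longest blob sequence observed for the tag). A usage sketch, with the same assumed `data_provider` and `ctx` as above:

from tensorboard.data import provider

# Sketch only: restrict the listing to one run via a RunTagFilter.
result = data_provider.list_blob_sequences(
    ctx,
    experiment_id="123",
    plugin_name="images",
    run_tag_filter=provider.RunTagFilter(runs=["train"]),
)
for run_name, tags in result.items():
    for tag_name, ts in tags.items():
        print(run_name, tag_name, ts.max_length)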

+    @timing.log_latency
+    def read_blob_sequences(
+        self,
+        ctx,
+        experiment_id,
+        plugin_name,
+        downsample=None,
+        run_tag_filter=None,
+    ):
+        with timing.log_latency("build request"):
+            req = data_provider_pb2.ReadBlobSequencesRequest()
+            req.experiment_id = experiment_id
+            req.plugin_filter.plugin_name = plugin_name
+            _populate_rtf(run_tag_filter, req.run_tag_filter)
+            req.downsample.num_points = downsample
+        with timing.log_latency("_stub.ReadBlobSequences"):
+            with _translate_grpc_error():
+                res = self._stub.ReadBlobSequences(req)
+        with timing.log_latency("build result"):
+            result = {}
+            for run_entry in res.runs:
+                tags = {}
+                result[run_entry.run_name] = tags
+                for tag_entry in run_entry.tags:
+                    series = []
+                    tags[tag_entry.tag_name] = series
+                    d = tag_entry.data
+                    for (step, wt, blob_sequence) in zip(
+                        d.step, d.wall_time, d.values
+                    ):
+                        values = []
+                        for ref in blob_sequence.blob_refs:
+                            values.append(
+                                provider.BlobReference(
+                                    blob_key=ref.blob_key, url=ref.url or None
+                                )
+                            )
+                        point = provider.BlobSequenceDatum(
+                            step=step, wall_time=wt, values=tuple(values)
+                        )
+                        series.append(point)
+            return result
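Note: each BlobSequenceDatum carries BlobReference values rather than raw bytes; when a reference has no url, the payload is fetched separately through read_blob. A sketch of that two-step pattern (the run and tag names are illustrative, and `data_provider`/`ctx` are assumed as before):

# Sketch: read_blob_sequences yields references; read_blob resolves them.
data = data_provider.read_blob_sequences(
    ctx, experiment_id="123", plugin_name="images", downsample=100
)
for datum in data.get("train", {}).get("input_image", []):
    for ref in datum.values:
        if ref.url is None:
            # No direct URL, so resolve the key via the ReadBlob RPC.
            blob_bytes = data_provider.read_blob(ctx, blob_key=ref.blob_key)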

+    @timing.log_latency
+    def read_blob(self, ctx, blob_key):
+        with timing.log_latency("build request"):
+            req = data_provider_pb2.ReadBlobRequest()
+            req.blob_key = blob_key
+        with timing.log_latency("list(_stub.ReadBlob)"):
+            with _translate_grpc_error():
+                responses = list(self._stub.ReadBlob(req))
+        with timing.log_latency("build result"):
+            return b"".join(res.data for res in responses)
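Note: the stream is drained with list(...) inside _translate_grpc_error because ReadBlob is a server-streaming RPC: a failure surfaces while iterating the responses, not when the call is issued. A toy generator shows the timing (RuntimeError stands in for grpc.RpcError):

def stream():
    yield b"hello wo"
    raise RuntimeError("mid-stream failure")

it = stream()  # creating the iterator raises nothing...
try:
    list(it)  # ...draining it is where the error appears,
except RuntimeError:
    pass  # ...so the drain must happen inside the error translator.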


@contextlib.contextmanager
def _translate_grpc_error():
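Note: the body of _translate_grpc_error is collapsed in this diff view. For orientation, a sketch of what such a translator plausibly contains; the exact status-code mapping below is an assumption, not something shown in this diff:

import contextlib

import grpc

from tensorboard import errors


# Sketch only: assumed mapping from gRPC status codes to TensorBoard
# error types; the real body is collapsed above.
@contextlib.contextmanager
def _translate_grpc_error():
    try:
        yield
    except grpc.RpcError as e:
        if e.code() == grpc.StatusCode.INVALID_ARGUMENT:
            raise errors.InvalidArgumentError(e.details())
        if e.code() == grpc.StatusCode.NOT_FOUND:
            raise errors.NotFoundError(e.details())
        if e.code() == grpc.StatusCode.PERMISSION_DENIED:
            raise errors.PermissionDeniedError(e.details())
        raise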
112 changes: 112 additions & 0 deletions tensorboard/data/grpc_provider_test.py
@@ -174,6 +174,118 @@ def test_read_scalars(self):
         req.downsample.num_points = 4
         self.stub.ReadScalars.assert_called_once_with(req)
 
+    def test_list_blob_sequences(self):
+        res = data_provider_pb2.ListBlobSequencesResponse()
+        run1 = res.runs.add(run_name="train")
+        tag11 = run1.tags.add(tag_name="input_image")
+        tag11.metadata.max_step = 7
+        tag11.metadata.max_wall_time = 7.77
+        tag11.metadata.max_length = 3
+        tag11.metadata.summary_metadata.plugin_data.content = b"PNG"
+        tag11.metadata.summary_metadata.display_name = "Input image"
+        tag11.metadata.summary_metadata.summary_description = "img"
+        self.stub.ListBlobSequences.return_value = res
+
+        actual = self.provider.list_blob_sequences(
+            self.ctx,
+            experiment_id="123",
+            plugin_name="images",
+            run_tag_filter=provider.RunTagFilter(runs=["val", "train"]),
+        )
+        expected = {
+            "train": {
+                "input_image": provider.BlobSequenceTimeSeries(
+                    max_step=7,
+                    max_wall_time=7.77,
+                    max_length=3,
+                    plugin_content=b"PNG",
+                    description="img",
+                    display_name="Input image",
+                ),
+            },
+        }
+        self.assertEqual(actual, expected)
+
+        req = data_provider_pb2.ListBlobSequencesRequest()
+        req.experiment_id = "123"
+        req.plugin_filter.plugin_name = "images"
+        req.run_tag_filter.runs.names.extend(["train", "val"])  # sorted
+        self.stub.ListBlobSequences.assert_called_once_with(req)
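Note: the `# sorted` comments in this test and the next reflect how the request is built: _populate_rtf (not shown in this diff) evidently normalizes filter names into sorted order so that the assembled request is deterministic. A sketch of the assumed helper behavior, inferred from these assertions:

# Sketch only: assumed shape of _populate_rtf, inferred from the
# `# sorted` assertions in these tests.
def _populate_rtf(run_tag_filter, rtf_proto):
    if run_tag_filter is None:
        return
    if run_tag_filter.runs is not None:
        rtf_proto.runs.names[:] = sorted(run_tag_filter.runs)
    if run_tag_filter.tags is not None:
        rtf_proto.tags.names[:] = sorted(run_tag_filter.tags)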

+    def test_read_blob_sequences(self):
+        res = data_provider_pb2.ReadBlobSequencesResponse()
+        run = res.runs.add(run_name="test")
+        tag = run.tags.add(tag_name="input_image")
+        tag.data.step.extend([0, 1])
+        tag.data.wall_time.extend([1234.0, 1235.0])
+        seq0 = tag.data.values.add()
+        seq0.blob_refs.add(blob_key="step0img0")
+        seq0.blob_refs.add(blob_key="step0img1")
+        seq1 = tag.data.values.add()
+        seq1.blob_refs.add(blob_key="step1img0")
+        self.stub.ReadBlobSequences.return_value = res
+
+        actual = self.provider.read_blob_sequences(
+            self.ctx,
+            experiment_id="123",
+            plugin_name="images",
+            run_tag_filter=provider.RunTagFilter(runs=["test", "nope"]),
+            downsample=4,
+        )
+        expected = {
+            "test": {
+                "input_image": [
+                    provider.BlobSequenceDatum(
+                        step=0,
+                        wall_time=1234.0,
+                        values=(
+                            provider.BlobReference(blob_key="step0img0"),
+                            provider.BlobReference(blob_key="step0img1"),
+                        ),
+                    ),
+                    provider.BlobSequenceDatum(
+                        step=1,
+                        wall_time=1235.0,
+                        values=(provider.BlobReference(blob_key="step1img0"),),
+                    ),
+                ],
+            },
+        }
+        self.assertEqual(actual, expected)
+
+        req = data_provider_pb2.ReadBlobSequencesRequest()
+        req.experiment_id = "123"
+        req.plugin_filter.plugin_name = "images"
+        req.run_tag_filter.runs.names.extend(["nope", "test"])  # sorted
+        req.downsample.num_points = 4
+        self.stub.ReadBlobSequences.assert_called_once_with(req)

+    def test_read_blob(self):
+        responses = [
+            data_provider_pb2.ReadBlobResponse(data=b"hello wo"),
+            data_provider_pb2.ReadBlobResponse(data=b"rld"),
+        ]
+        self.stub.ReadBlob.return_value = responses
+
+        actual = self.provider.read_blob(self.ctx, blob_key="myblob")
+        expected = b"hello world"
+        self.assertEqual(actual, expected)
+
+        req = data_provider_pb2.ReadBlobRequest()
+        req.blob_key = "myblob"
+        self.stub.ReadBlob.assert_called_once_with(req)

+    def test_read_blob_error(self):
+        def fake_handler(req):
+            del req  # unused
+            yield data_provider_pb2.ReadBlobResponse(data=b"hello wo")
+            raise _grpc_error(grpc.StatusCode.NOT_FOUND, "it ran away!")
+
+        self.stub.ReadBlob.side_effect = fake_handler
+
+        with self.assertRaisesRegex(errors.NotFoundError, "it ran away!"):
+            self.provider.read_blob(self.ctx, blob_key="myblob")

     def test_rpc_error(self):
         # This error handling is implemented with a context manager used
         # for all the methods, so take `list_plugins` as representative.
2 changes: 2 additions & 0 deletions tensorboard/data/server/BUILD
@@ -45,6 +45,7 @@ rust_library(
     ] + _checked_in_proto_files,
     edition = "2018",
     deps = [
+        "//third_party/rust:async_stream",
         "//third_party/rust:base64",
         "//third_party/rust:byteorder",
         "//third_party/rust:clap",
@@ -59,6 +60,7 @@ rust_library(
         "//third_party/rust:serde_json",
         "//third_party/rust:thiserror",
         "//third_party/rust:tokio",
+        "//third_party/rust:tokio_stream",
         "//third_party/rust:tonic",
         "//third_party/rust:walkdir",
     ],