Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tensorboard/uploader/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ py_library(
"//tensorboard:expect_grpc_installed",
"//tensorboard/uploader/proto:protos_all_py_pb2",
"//tensorboard/util:grpc_util",
"//tensorboard/util:tb_logging",
"@org_pythonhosted_six",
],
)
Expand Down
37 changes: 20 additions & 17 deletions tensorboard/uploader/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from tensorboard.uploader.proto import export_service_pb2
from tensorboard.uploader import util
from tensorboard.util import grpc_util
from tensorboard.util import tb_logging

# Characters that are assumed to be safe in filenames. Note that the
# server's experiment IDs are base64 encodings of 16-byte blobs, so they
Expand All @@ -47,6 +48,8 @@
# Output filename for scalar data within an experiment directory.
_FILENAME_SCALARS = "scalars.json"

logger = tb_logging.get_logger()


class TensorBoardExporter(object):
"""Exports all of the user's experiment data from TensorBoard.dev.
Expand Down Expand Up @@ -115,7 +118,8 @@ def export(self, read_time=None):
"""
if read_time is None:
read_time = time.time()
for experiment_id in self._request_experiment_ids(read_time):
for experiment in list_experiments(self._api, read_time=read_time):
experiment_id = experiment.experiment_id
experiment_dir = _experiment_directory(self._outdir, experiment_id)
os.mkdir(experiment_dir)

Expand All @@ -134,18 +138,6 @@ def export(self, read_time=None):
else:
raise

def _request_experiment_ids(self, read_time):
"""Yields all of the calling user's experiment IDs, as strings."""
for experiment in list_experiments(self._api, read_time=read_time):
if isinstance(experiment, experiment_pb2.Experiment):
yield experiment.experiment_id
elif isinstance(experiment, six.string_types):
yield experiment
else:
raise AssertionError(
"Unexpected experiment type: %r" % (experiment,)
)

def _request_scalar_data(self, experiment_id, read_time):
"""Yields JSON-serializable blocks of scalar data."""
request = export_service_pb2.StreamExperimentDataRequest()
Expand Down Expand Up @@ -191,7 +183,11 @@ def list_experiments(api_client, fieldmask=None, read_time=None):

Yields:
For each experiment owned by the user, an `experiment_pb2.Experiment`
value, or a simple string experiment ID for older servers.
value.

Raises:
RuntimeError: If the server returns experiment IDs but no experiments,
as in an old, unsupported version of the protocol.
"""
if read_time is None:
read_time = time.time()
Expand All @@ -206,10 +202,17 @@ def list_experiments(api_client, fieldmask=None, read_time=None):
if response.experiments:
for experiment in response.experiments:
yield experiment
elif response.experiment_ids:
raise RuntimeError(
"Server sent experiment_ids without experiments: <%r>"
% (list(response.experiment_ids),)
)
else:
# Old servers.
for experiment_id in response.experiment_ids:
yield experiment_id
# No data: not technically a problem, but not expected.
logger.warn(
"StreamExperiments RPC returned response with no experiments: <%r>",
response,
)


class OutputDirectoryExistsError(ValueError):
Expand Down
83 changes: 37 additions & 46 deletions tensorboard/uploader/exporter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,32 +43,29 @@
from tensorboard.compat.proto import summary_pb2


def _make_experiments_response(eids):
"""Make a `StreamExperimentsResponse` with experiments with only IDs."""
response = export_service_pb2.StreamExperimentsResponse()
for eid in eids:
response.experiments.add(experiment_id=eid)
return response


class TensorBoardExporterTest(tb_test.TestCase):
def _create_mock_api_client(self):
return _create_mock_api_client()

def _make_experiments_response(self, eids):
return export_service_pb2.StreamExperimentsResponse(experiment_ids=eids)

def test_e2e_success_case(self):
mock_api_client = self._create_mock_api_client()
mock_api_client.StreamExperiments.return_value = iter(
[
export_service_pb2.StreamExperimentsResponse(
experiment_ids=["789"]
),
]
[_make_experiments_response(["789"])]
)

def stream_experiments(request, **kwargs):
del request # unused
self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
yield export_service_pb2.StreamExperimentsResponse(
experiment_ids=["123", "456"]
)
yield export_service_pb2.StreamExperimentsResponse(
experiment_ids=["789"]
)
yield _make_experiments_response(["123", "456"])
yield _make_experiments_response(["789"])

def stream_experiment_data(request, **kwargs):
self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
Expand Down Expand Up @@ -200,9 +197,7 @@ def test_rejects_dangerous_experiment_ids(self):

def stream_experiments(request, **kwargs):
del request # unused
yield export_service_pb2.StreamExperimentsResponse(
experiment_ids=["../authorized_keys"]
)
yield _make_experiments_response(["../authorized_keys"])

mock_api_client.StreamExperiments = stream_experiments

Expand All @@ -229,9 +224,7 @@ def test_fails_nicely_on_stream_experiment_data_timeout(self):

def stream_experiments(request, **kwargs):
del request # unused
yield export_service_pb2.StreamExperimentsResponse(
experiment_ids=[experiment_id]
)
yield _make_experiments_response([experiment_id])

def stream_experiment_data(request, **kwargs):
raise test_util.grpc_error(
Expand Down Expand Up @@ -260,9 +253,7 @@ def test_stream_experiment_data_passes_through_unexpected_exception(self):

def stream_experiments(request, **kwargs):
del request # unused
yield export_service_pb2.StreamExperimentsResponse(
experiment_ids=[experiment_id]
)
yield _make_experiments_response([experiment_id])

def stream_experiment_data(request, **kwargs):
del request # unused
Expand All @@ -288,11 +279,7 @@ def test_handles_outdir_with_no_slash(self):
os.chdir(self.get_temp_dir())
mock_api_client = self._create_mock_api_client()
mock_api_client.StreamExperiments.return_value = iter(
[
export_service_pb2.StreamExperimentsResponse(
experiment_ids=["123"]
),
]
[_make_experiments_response(["123"])]
)
mock_api_client.StreamExperimentData.return_value = iter(
[export_service_pb2.StreamExperimentDataResponse()]
Expand Down Expand Up @@ -335,6 +322,7 @@ def test_propagates_mkdir_errors(self):

class ListExperimentsTest(tb_test.TestCase):
def test_experiment_ids_only(self):
# Legacy server behavior; should raise an error.
mock_api_client = _create_mock_api_client()

def stream_experiments(request, **kwargs):
Expand All @@ -347,45 +335,48 @@ def stream_experiments(request, **kwargs):
)

mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments)
gen = exporter_lib.list_experiments(mock_api_client)
mock_api_client.StreamExperiments.assert_not_called()
self.assertEqual(list(gen), ["123", "456", "789"])
with self.assertRaises(RuntimeError) as cm:
list(exporter_lib.list_experiments(mock_api_client))
self.assertIn(repr(["123", "456"]), str(cm.exception))

def test_mixed_experiments_and_ids(self):
mock_api_client = _create_mock_api_client()

def stream_experiments(request, **kwargs):
del request # unused

# Should include `experiment_ids` when no `experiments` given.
response = export_service_pb2.StreamExperimentsResponse()
response.experiment_ids.append("123")
response.experiment_ids.append("456")
yield response

# Should ignore `experiment_ids` in the presence of `experiments`.
response = export_service_pb2.StreamExperimentsResponse()
response.experiment_ids.append("999") # will be omitted
response.experiments.add(experiment_id="789")
response.experiments.add(experiment_id="012")
yield response

# Should include `experiments` even when no `experiment_ids` are given.
mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments)
gen = exporter_lib.list_experiments(mock_api_client)
mock_api_client.StreamExperiments.assert_not_called()
expected = [
experiment_pb2.Experiment(experiment_id="789"),
experiment_pb2.Experiment(experiment_id="012"),
]
self.assertEqual(list(gen), expected)

def test_experiments_only(self):
mock_api_client = _create_mock_api_client()

def stream_experiments(request, **kwargs):
del request # unused
response = export_service_pb2.StreamExperimentsResponse()
response.experiments.add(experiment_id="345")
response.experiments.add(experiment_id="678")
response.experiments.add(experiment_id="789", name="one")
response.experiments.add(experiment_id="012", description="two")
yield response

mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments)
gen = exporter_lib.list_experiments(mock_api_client)
mock_api_client.StreamExperiments.assert_not_called()
expected = [
"123",
"456",
experiment_pb2.Experiment(experiment_id="789"),
experiment_pb2.Experiment(experiment_id="012"),
experiment_pb2.Experiment(experiment_id="345"),
experiment_pb2.Experiment(experiment_id="678"),
experiment_pb2.Experiment(experiment_id="789", name="one"),
experiment_pb2.Experiment(experiment_id="012", description="two"),
]
self.assertEqual(list(gen), expected)

Expand Down
4 changes: 0 additions & 4 deletions tensorboard/uploader/uploader_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,10 +515,6 @@ def execute(self, server_info, channel):
count = 0
for experiment in gen:
count += 1
if not isinstance(experiment, experiment_pb2.Experiment):
url = server_info_lib.experiment_url(server_info, experiment)
print(url)
continue
experiment_id = experiment.experiment_id
url = server_info_lib.experiment_url(server_info, experiment_id)
print(url)
Expand Down