diff --git a/tensorboard/uploader/BUILD b/tensorboard/uploader/BUILD
index 2d33b334e7..360efc6bfc 100644
--- a/tensorboard/uploader/BUILD
+++ b/tensorboard/uploader/BUILD
@@ -17,6 +17,7 @@ py_library(
     srcs = ["exporter.py"],
     deps = [
         ":util",
+        "//tensorboard:expect_grpc_installed",
         "//tensorboard/uploader/proto:protos_all_py_pb2",
         "//tensorboard/util:grpc_util",
     ],
@@ -28,6 +29,7 @@ py_test(
     deps = [
         ":exporter_lib",
         ":test_util",
+        "//tensorboard:expect_grpc_installed",
         "//tensorboard:expect_grpc_testing_installed",
         "//tensorboard:test",
         "//tensorboard/compat/proto:protos_all_py_pb2",
diff --git a/tensorboard/uploader/exporter.py b/tensorboard/uploader/exporter.py
index 123dd6c4c0..04da87e696 100644
--- a/tensorboard/uploader/exporter.py
+++ b/tensorboard/uploader/exporter.py
@@ -20,6 +20,7 @@
 
 import base64
 import errno
+import grpc
 import json
 import os
 import string
@@ -40,7 +41,6 @@
 
 # Maximum value of a signed 64-bit integer.
 _MAX_INT64 = 2**63 - 1
 
-
 class TensorBoardExporter(object):
   """Exports all of the user's experiment data from a hosted service.
@@ -110,13 +110,19 @@ def export(self, read_time=None):
       read_time = time.time()
     for experiment_id in self._request_experiment_ids(read_time):
       filepath = _scalars_filepath(self._outdir, experiment_id)
-      with _open_excl(filepath) as outfile:
-        data = self._request_scalar_data(experiment_id, read_time)
-        for block in data:
-          json.dump(block, outfile, sort_keys=True)
-          outfile.write("\n")
-          outfile.flush()
-      yield experiment_id
+      try:
+        with _open_excl(filepath) as outfile:
+          data = self._request_scalar_data(experiment_id, read_time)
+          for block in data:
+            json.dump(block, outfile, sort_keys=True)
+            outfile.write("\n")
+            outfile.flush()
+        yield experiment_id
+      except grpc.RpcError as e:
+        if e.code() == grpc.StatusCode.CANCELLED:
+          raise GrpcTimeoutException(experiment_id)
+        else:
+          raise
 
   def _request_experiment_ids(self, read_time):
     """Yields all of the calling user's experiment IDs, as strings."""
@@ -165,6 +171,10 @@ class OutputFileExistsError(ValueError):
   # Like Python 3's `__builtins__.FileExistsError`.
   pass
 
+class GrpcTimeoutException(Exception):
+  def __init__(self, experiment_id):
+    super(GrpcTimeoutException, self).__init__(experiment_id)
+    self.experiment_id = experiment_id
 
 def _scalars_filepath(base_dir, experiment_id):
   """Gets file path in which to store scalars for the given experiment."""
diff --git a/tensorboard/uploader/exporter_test.py b/tensorboard/uploader/exporter_test.py
index 9072458cc2..f4e5213a86 100644
--- a/tensorboard/uploader/exporter_test.py
+++ b/tensorboard/uploader/exporter_test.py
@@ -23,6 +23,7 @@
 import json
 import os
 
+import grpc
 import grpc_testing
 
 try:
@@ -199,6 +200,63 @@ def stream_experiments(request, **kwargs):
     self.assertIn("../authorized_keys", msg)
     mock_api_client.StreamExperimentData.assert_not_called()
 
+  def test_fails_nicely_on_stream_experiment_data_timeout(self):
+    # Setup: Client where:
+    # 1. stream_experiments will say there is one experiment_id.
+    # 2. stream_experiment_data will raise a grpc CANCELLED, as per
+    #    a timeout.
+    mock_api_client = self._create_mock_api_client()
+    experiment_id = "123"
+
+    def stream_experiments(request, **kwargs):
+      del request  # unused
+      yield export_service_pb2.StreamExperimentsResponse(
+          experiment_ids=[experiment_id])
+
+    def stream_experiment_data(request, **kwargs):
+      raise test_util.grpc_error(grpc.StatusCode.CANCELLED, "details string")
+
+    mock_api_client.StreamExperiments = stream_experiments
+    mock_api_client.StreamExperimentData = stream_experiment_data
+
+    outdir = os.path.join(self.get_temp_dir(), "outdir")
+    # Execute: exporter.export().
+    exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
+    generator = exporter.export()
+    # Expect: A nice exception of the right type and carrying the right
+    # experiment_id.
+    with self.assertRaises(exporter_lib.GrpcTimeoutException) as cm:
+      next(generator)
+    self.assertEqual(cm.exception.experiment_id, experiment_id)
+
+  def test_stream_experiment_data_passes_through_unexpected_exception(self):
+    # Setup: Client where:
+    # 1. stream_experiments will say there is one experiment_id.
+    # 2. stream_experiment_data will throw an internal error.
+    mock_api_client = self._create_mock_api_client()
+    experiment_id = "123"
+
+    def stream_experiments(request, **kwargs):
+      del request  # unused
+      yield export_service_pb2.StreamExperimentsResponse(
+          experiment_ids=[experiment_id])
+
+    def stream_experiment_data(request, **kwargs):
+      del request  # unused
+      raise test_util.grpc_error(grpc.StatusCode.INTERNAL, "details string")
+
+    mock_api_client.StreamExperiments = stream_experiments
+    mock_api_client.StreamExperimentData = stream_experiment_data
+
+    outdir = os.path.join(self.get_temp_dir(), "outdir")
+    # Execute: exporter.export().
+    exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
+    generator = exporter.export()
+    # Expect: The internal error is passed through.
+    with self.assertRaises(grpc.RpcError) as cm:
+      next(generator)
+    self.assertEqual(cm.exception.details(), "details string")
+
   def test_handles_outdir_with_no_slash(self):
     oldcwd = os.getcwd()
     try:
diff --git a/tensorboard/uploader/uploader_main.py b/tensorboard/uploader/uploader_main.py
index e30b2bd655..32f001ba13 100644
--- a/tensorboard/uploader/uploader_main.py
+++ b/tensorboard/uploader/uploader_main.py
@@ -367,9 +367,15 @@ def execute(self, channel):
       msg = 'Output directory already exists: %r' % outdir
       raise base_plugin.FlagsError(msg)
     num_experiments = 0
-    for experiment_id in exporter.export():
-      num_experiments += 1
-      print('Downloaded experiment %s' % experiment_id)
+    try:
+      for experiment_id in exporter.export():
+        num_experiments += 1
+        print('Downloaded experiment %s' % experiment_id)
+    except exporter_lib.GrpcTimeoutException as e:
+      print(
+          '\nUploader has failed because of a timeout error. Please reach '
+          'out via e-mail to tensorboard.dev-support@google.com to get help '
+          'completing your export of experiment %s.' % e.experiment_id)
     print('Done. Downloaded %d experiments to: %s' %
           (num_experiments, outdir))
 
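Both new tests fabricate the failing RPC with test_util.grpc_error, a helper whose implementation is not part of this diff. Below is a minimal sketch of how such a helper could construct a synthetic grpc.RpcError exposing the code() and details() accessors that the exporter and the tests read back; the message format and the monkey-patched accessors are assumptions, not necessarily the repository's actual helper.

import grpc


def grpc_error(code, details):
  # Sketch (assumed, not the real test_util helper): grpc.RpcError is a plain
  # Exception subclass, so a fake error can be instantiated directly and the
  # code()/details() accessors attached as instance attributes.
  error = grpc.RpcError("RPC error %r: %s" % (code, details))
  error.code = lambda: code
  error.details = lambda: details
  return error

An error built this way satisfies the e.code() == grpc.StatusCode.CANCELLED check in export(), so the first test observes GrpcTimeoutException carrying the expected experiment_id, while the INTERNAL variant falls through the bare raise and surfaces unchanged for the second test.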