Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tensorboard/uploader/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ py_library(
srcs = ["exporter.py"],
deps = [
":util",
"//tensorboard:expect_grpc_installed",
"//tensorboard/uploader/proto:protos_all_py_pb2",
"//tensorboard/util:grpc_util",
],
Expand All @@ -28,6 +29,7 @@ py_test(
deps = [
":exporter_lib",
":test_util",
"//tensorboard:expect_grpc_installed",
"//tensorboard:expect_grpc_testing_installed",
"//tensorboard:test",
"//tensorboard/compat/proto:protos_all_py_pb2",
Expand Down
26 changes: 18 additions & 8 deletions tensorboard/uploader/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import base64
import errno
import grpc
import json
import os
import string
Expand All @@ -40,7 +41,6 @@
# Maximum value of a signed 64-bit integer.
_MAX_INT64 = 2**63 - 1


class TensorBoardExporter(object):
"""Exports all of the user's experiment data from a hosted service.

Expand Down Expand Up @@ -110,13 +110,19 @@ def export(self, read_time=None):
read_time = time.time()
for experiment_id in self._request_experiment_ids(read_time):
filepath = _scalars_filepath(self._outdir, experiment_id)
with _open_excl(filepath) as outfile:
data = self._request_scalar_data(experiment_id, read_time)
for block in data:
json.dump(block, outfile, sort_keys=True)
outfile.write("\n")
outfile.flush()
yield experiment_id
try:
with _open_excl(filepath) as outfile:
data = self._request_scalar_data(experiment_id, read_time)
for block in data:
json.dump(block, outfile, sort_keys=True)
outfile.write("\n")
outfile.flush()
yield experiment_id
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.CANCELLED:
raise GrpcTimeoutException(experiment_id)
else:
raise

def _request_experiment_ids(self, read_time):
"""Yields all of the calling user's experiment IDs, as strings."""
Expand Down Expand Up @@ -165,6 +171,10 @@ class OutputFileExistsError(ValueError):
# Like Python 3's `__builtins__.FileExistsError`.
pass

class GrpcTimeoutException(Exception):
    """Raised when a streaming RPC is cancelled, as happens on a timeout.

    Attributes:
      experiment_id: ID (str) of the experiment whose export was interrupted,
        so callers can tell the user which download did not complete.
    """

    def __init__(self, experiment_id):
        # Record the ID first, then delegate to `Exception` so that
        # `args` / `str(e)` carry the same payload as before.
        self.experiment_id = experiment_id
        super(GrpcTimeoutException, self).__init__(experiment_id)

def _scalars_filepath(base_dir, experiment_id):
"""Gets file path in which to store scalars for the given experiment."""
Expand Down
58 changes: 58 additions & 0 deletions tensorboard/uploader/exporter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import json
import os

import grpc
import grpc_testing

try:
Expand Down Expand Up @@ -199,6 +200,63 @@ def stream_experiments(request, **kwargs):
self.assertIn("../authorized_keys", msg)
mock_api_client.StreamExperimentData.assert_not_called()

def test_fails_nicely_on_stream_experiment_data_timeout(self):
    """Export surfaces a CANCELLED data stream as GrpcTimeoutException."""
    # Setup: Client where:
    #   1. stream_experiments will say there is one experiment_id.
    #   2. stream_experiment_data will raise a grpc CANCELLED, as per
    #      a timeout.
    mock_api_client = self._create_mock_api_client()
    experiment_id = "123"

    def stream_experiments(request, **kwargs):
        del request  # unused
        yield export_service_pb2.StreamExperimentsResponse(
            experiment_ids=[experiment_id])

    def stream_experiment_data(request, **kwargs):
        del request  # unused
        raise test_util.grpc_error(grpc.StatusCode.CANCELLED, "details string")

    mock_api_client.StreamExperiments = stream_experiments
    mock_api_client.StreamExperimentData = stream_experiment_data

    outdir = os.path.join(self.get_temp_dir(), "outdir")
    # Execute: exporter.export()
    exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
    generator = exporter.export()
    # Expect: A nice exception of the right type and carrying the right
    # experiment_id.
    with self.assertRaises(exporter_lib.GrpcTimeoutException) as cm:
        next(generator)
    # `assertEquals` is a deprecated alias (removed in Python 3.12);
    # use `assertEqual` instead.
    self.assertEqual(cm.exception.experiment_id, experiment_id)

def test_stream_experiment_data_passes_through_unexpected_exception(self):
    """Export re-raises non-CANCELLED RpcErrors unchanged."""
    # Setup: Client where:
    #   1. stream_experiments will say there is one experiment_id.
    #   2. stream_experiment_data will throw an internal error.
    mock_api_client = self._create_mock_api_client()
    experiment_id = "123"

    def stream_experiments(request, **kwargs):
        del request  # unused
        yield export_service_pb2.StreamExperimentsResponse(
            experiment_ids=[experiment_id])

    def stream_experiment_data(request, **kwargs):
        del request  # unused
        raise test_util.grpc_error(grpc.StatusCode.INTERNAL, "details string")

    mock_api_client.StreamExperiments = stream_experiments
    mock_api_client.StreamExperimentData = stream_experiment_data

    outdir = os.path.join(self.get_temp_dir(), "outdir")
    # Execute: exporter.export().
    exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
    generator = exporter.export()
    # Expect: The internal error is passed through.
    with self.assertRaises(grpc.RpcError) as cm:
        next(generator)
    # `assertEquals` is a deprecated alias (removed in Python 3.12);
    # use `assertEqual` instead.
    self.assertEqual(cm.exception.details(), "details string")

def test_handles_outdir_with_no_slash(self):
oldcwd = os.getcwd()
try:
Expand Down
12 changes: 9 additions & 3 deletions tensorboard/uploader/uploader_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,15 @@ def execute(self, channel):
msg = 'Output directory already exists: %r' % outdir
raise base_plugin.FlagsError(msg)
num_experiments = 0
for experiment_id in exporter.export():
num_experiments += 1
print('Downloaded experiment %s' % experiment_id)
try:
for experiment_id in exporter.export():
num_experiments += 1
print('Downloaded experiment %s' % experiment_id)
except exporter_lib.GrpcTimeoutException as e:
print(
'\nUploader has failed because of a timeout error. Please reach '
'out via e-mail to tensorboard.dev-support@google.com to get help '
'completing your export of experiment %s.' % e.experiment_id)
print('Done. Downloaded %d experiments to: %s' % (num_experiments, outdir))


Expand Down