Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions tensorboard/uploader/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,7 @@ def export(self, read_time=None):

def _request_experiment_ids(self, read_time):
"""Yields all of the calling user's experiment IDs, as strings."""
request = export_service_pb2.StreamExperimentsRequest(limit=_MAX_INT64)
util.set_timestamp(request.read_timestamp, read_time)
stream = self._api.StreamExperiments(
request, metadata=grpc_util.version_metadata())
for response in stream:
for experiment_id in response.experiment_ids:
yield experiment_id
return list_experiments(self._api, read_time=read_time)

def _request_scalar_data(self, experiment_id, read_time):
"""Yields JSON-serializable blocks of scalar data."""
Expand Down Expand Up @@ -163,6 +157,30 @@ def _request_scalar_data(self, experiment_id, read_time):
}


def list_experiments(api_client, read_time=None):
"""Yields all of the calling user's experiment IDs.

Args:
api_client: A TensorBoardExporterService stub instance.
read_time: A fixed timestamp from which to export data, as float seconds
since epoch (like `time.time()`). Optional; defaults to the current
time.

Yields:
One string for each experiment owned by the calling user, in arbitrary
order.
"""
if read_time is None:
read_time = time.time()
request = export_service_pb2.StreamExperimentsRequest(limit=_MAX_INT64)
util.set_timestamp(request.read_timestamp, read_time)
stream = api_client.StreamExperiments(
request, metadata=grpc_util.version_metadata())
for response in stream:
for experiment_id in response.experiment_ids:
yield experiment_id


class OutputDirectoryExistsError(ValueError):
pass

Expand Down
40 changes: 31 additions & 9 deletions tensorboard/uploader/exporter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,7 @@
class TensorBoardExporterTest(tb_test.TestCase):

def _create_mock_api_client(self):
# Create a stub instance (using a test channel) in order to derive a mock
# from it with autospec enabled. Mocking TensorBoardExporterServiceStub
# itself doesn't work with autospec because grpc constructs stubs via
# metaclassing.
test_channel = grpc_testing.channel(
service_descriptors=[], time=grpc_testing.strict_real_time())
stub = export_service_pb2_grpc.TensorBoardExporterServiceStub(test_channel)
mock_api_client = mock.create_autospec(stub)
return mock_api_client
return _create_mock_api_client()

def _make_experiments_response(self, eids):
return export_service_pb2.StreamExperimentsResponse(experiment_ids=eids)
Expand Down Expand Up @@ -323,6 +315,24 @@ def test_propagates_mkdir_errors(self):
mock_api_client.StreamExperimentData.assert_not_called()


class ListExperimentsTest(tb_test.TestCase):

def test(self):
mock_api_client = _create_mock_api_client()

def stream_experiments(request, **kwargs):
del request # unused
yield export_service_pb2.StreamExperimentsResponse(
experiment_ids=["123", "456"])
yield export_service_pb2.StreamExperimentsResponse(
experiment_ids=["789"])

mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments)
gen = exporter_lib.list_experiments(mock_api_client)
mock_api_client.StreamExperiments.assert_not_called()
self.assertEqual(list(gen), ["123", "456", "789"])


class MkdirPTest(tb_test.TestCase):

def test_makes_full_chain(self):
Expand Down Expand Up @@ -384,5 +394,17 @@ def test_propagates_other_errors(self):
self.assertEqual(cm.exception.errno, errno.ENOENT)


def _create_mock_api_client():
# Create a stub instance (using a test channel) in order to derive a mock
# from it with autospec enabled. Mocking TensorBoardExporterServiceStub
# itself doesn't work with autospec because grpc constructs stubs via
# metaclassing.
test_channel = grpc_testing.channel(
service_descriptors=[], time=grpc_testing.strict_real_time())
stub = export_service_pb2_grpc.TensorBoardExporterServiceStub(test_channel)
mock_api_client = mock.create_autospec(stub)
return mock_api_client


if __name__ == "__main__":
tb_test.main()
37 changes: 37 additions & 0 deletions tensorboard/uploader/uploader_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
_SUBCOMMAND_FLAG = '_uploader__subcommand'
_SUBCOMMAND_KEY_UPLOAD = 'UPLOAD'
_SUBCOMMAND_KEY_DELETE = 'DELETE'
_SUBCOMMAND_KEY_LIST = 'LIST'
_SUBCOMMAND_KEY_EXPORT = 'EXPORT'
_SUBCOMMAND_KEY_AUTH = 'AUTH'
_AUTH_SUBCOMMAND_FLAG = '_uploader__subcommand_auth'
Expand Down Expand Up @@ -135,6 +136,10 @@ def _define_flags(parser):
default=None,
help='ID of an experiment to delete permanently')

list_parser = subparsers.add_parser(
'list', help='list previously uploaded experiments')
list_parser.set_defaults(**{_SUBCOMMAND_FLAG: _SUBCOMMAND_KEY_LIST})

export = subparsers.add_parser(
'export', help='download all your experiment data')
export.set_defaults(**{_SUBCOMMAND_FLAG: _SUBCOMMAND_KEY_EXPORT})
Expand Down Expand Up @@ -312,6 +317,36 @@ def execute(self, channel):
print('Deleted experiment %s.' % experiment_id)


class _ListIntent(_Intent):
"""The user intends to list all their experiments."""

_MESSAGE = textwrap.dedent(u"""\
This will list all experiments that you've uploaded to
https://tensorboard.dev. TensorBoard.dev experiments are visible
to everyone. Do not upload sensitive data.
""")

def get_ack_message_body(self):
return self._MESSAGE

def execute(self, channel):
api_client = export_service_pb2_grpc.TensorBoardExporterServiceStub(channel)
gen = exporter_lib.list_experiments(api_client)
count = 0
for experiment_id in gen:
count += 1
# TODO(@wchargin): Once #2879 is in, remove this hard-coded URL pattern.
url = 'https://tensorboard.dev/experiment/%s/' % experiment_id
print(url)
sys.stdout.flush()
if not count:
sys.stderr.write(
'No experiments. Use `tensorboard dev upload` to get started.\n')
else:
sys.stderr.write('Total: %d experiment(s)\n' % count)
sys.stderr.flush()


class _UploadIntent(_Intent):
"""The user intends to upload an experiment from the given logdir."""

Expand Down Expand Up @@ -421,6 +456,8 @@ def _get_intent(flags):
else:
raise base_plugin.FlagsError(
'Must specify experiment to delete via `--experiment_id`.')
elif cmd == _SUBCOMMAND_KEY_LIST:
return _ListIntent()
elif cmd == _SUBCOMMAND_KEY_EXPORT:
if flags.outdir:
return _ExportIntent(flags.outdir)
Expand Down