diff --git a/tensorboard/uploader/uploader.py b/tensorboard/uploader/uploader.py index 8dcd401a26..05c1344390 100644 --- a/tensorboard/uploader/uploader.py +++ b/tensorboard/uploader/uploader.py @@ -26,6 +26,7 @@ import six from tensorboard.uploader.proto import write_service_pb2 +from tensorboard.uploader.proto import experiment_pb2 from tensorboard.uploader import logdir_loader from tensorboard.uploader import util from tensorboard import data_compat @@ -64,7 +65,14 @@ class TensorBoardUploader(object): """Uploads a TensorBoard logdir to TensorBoard.dev.""" - def __init__(self, writer_client, logdir, rpc_rate_limiter=None): + def __init__( + self, + writer_client, + logdir, + rpc_rate_limiter=None, + name=None, + description=None, + ): """Constructs a TensorBoardUploader. Args: @@ -77,9 +85,13 @@ def __init__(self, writer_client, logdir, rpc_rate_limiter=None): of chunks. Note the chunk stream is internally rate-limited by backpressure from the server, so it is not a concern that we do not explicitly rate-limit within the stream here. + name: String name to assign to the experiment. + description: String description to assign to the experiment. """ self._api = writer_client self._logdir = logdir + self._name = name + self._description = description self._request_sender = None if rpc_rate_limiter is None: self._rpc_rate_limiter = util.RateLimiter( @@ -103,7 +115,9 @@ def __init__(self, writer_client, logdir, rpc_rate_limiter=None): def create_experiment(self): """Creates an Experiment for this upload session and returns the ID.""" logger.info("Creating experiment") - request = write_service_pb2.CreateExperimentRequest() + request = write_service_pb2.CreateExperimentRequest( + name=self._name, description=self._description + ) response = grpc_util.call_with_retries( self._api.CreateExperiment, request ) @@ -140,6 +154,50 @@ def _upload_once(self): self._request_sender.send_requests(run_to_events) +def update_experiment_metadata( + writer_client, experiment_id, name=None, description=None +): + """Modifies user data associated with an experiment. + + Args: + writer_client: a TensorBoardWriterService stub instance + experiment_id: string ID of the experiment to modify + name: If provided, modifies name of experiment to this value. + description: If provided, modifies the description of the experiment to + this value + + Raises: + ExperimentNotFoundError: If no such experiment exists. + PermissionDeniedError: If the user is not authorized to modify this + experiment. + InvalidArgumentError: If the server rejected the name or description, if, + for instance, the size limits have changed on the server. + """ + logger.info("Modifying experiment %r", experiment_id) + request = write_service_pb2.UpdateExperimentRequest() + request.experiment.experiment_id = experiment_id + if name is not None: + logger.info("Setting exp %r name to %r", experiment_id, name) + request.experiment.name = name + request.experiment_mask.name = True + if description is not None: + logger.info( + "Setting exp %r description to %r", experiment_id, description + ) + request.experiment.description = description + request.experiment_mask.description = True + try: + grpc_util.call_with_retries(writer_client.UpdateExperiment, request) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.NOT_FOUND: + raise ExperimentNotFoundError() + if e.code() == grpc.StatusCode.PERMISSION_DENIED: + raise PermissionDeniedError() + if e.code() == grpc.StatusCode.INVALID_ARGUMENT: + raise InvalidArgumentError(e.details()) + raise + + def delete_experiment(writer_client, experiment_id): """Permanently deletes an experiment and all of its contents. @@ -166,6 +224,10 @@ def delete_experiment(writer_client, experiment_id): raise +class InvalidArgumentError(RuntimeError): + pass + + class ExperimentNotFoundError(RuntimeError): pass diff --git a/tensorboard/uploader/uploader_main.py b/tensorboard/uploader/uploader_main.py index 57c406837a..2693691dda 100644 --- a/tensorboard/uploader/uploader_main.py +++ b/tensorboard/uploader/uploader_main.py @@ -65,6 +65,7 @@ _SUBCOMMAND_KEY_DELETE = "DELETE" _SUBCOMMAND_KEY_LIST = "LIST" _SUBCOMMAND_KEY_EXPORT = "EXPORT" +_SUBCOMMAND_KEY_UPDATE_METADATA = "UPDATEMETADATA" _SUBCOMMAND_KEY_AUTH = "AUTH" _AUTH_SUBCOMMAND_FLAG = "_uploader__subcommand_auth" _AUTH_SUBCOMMAND_KEY_REVOKE = "REVOKE" @@ -72,6 +73,12 @@ _DEFAULT_ORIGIN = "https://tensorboard.dev" +# Size limits for input fields not bounded at a wire level. "Chars" in this +# context refers to Unicode code points as stipulated by https://aip.dev/210. +_EXPERIMENT_NAME_MAX_CHARS = 100 +_EXPERIMENT_DESCRIPTION_MAX_CHARS = 600 + + def _prompt_for_user_ack(intent): """Prompts for user consent, exiting the program if they decline.""" body = intent.get_ack_message_body() @@ -139,6 +146,46 @@ def _define_flags(parser): default=None, help="Directory containing the logs to process", ) + upload.add_argument( + "--name", + type=str, + default=None, + help="Title of the experiment. Max 100 characters.", + ) + upload.add_argument( + "--description", + type=str, + default=None, + help="Experiment description. Markdown format. Max 600 characters.", + ) + + update_metadata = subparsers.add_parser( + "update-metadata", + help="change the name, description, or other user " + "metadata associated with an experiment.", + ) + update_metadata.set_defaults( + **{_SUBCOMMAND_FLAG: _SUBCOMMAND_KEY_UPDATE_METADATA} + ) + update_metadata.add_argument( + "--experiment_id", + metavar="EXPERIMENT_ID", + type=str, + default=None, + help="ID of the experiment on which to modify the metadata.", + ) + update_metadata.add_argument( + "--name", + type=str, + default=None, + help="Title of the experiment. Max 100 characters.", + ) + update_metadata.add_argument( + "--description", + type=str, + default=None, + help="Experiment description. Markdown format. Max 600 characters.", + ) delete = subparsers.add_parser( "delete", @@ -372,6 +419,72 @@ def execute(self, server_info, channel): print("Deleted experiment %s." % experiment_id) +class _UpdateMetadataIntent(_Intent): + """The user intends to update the metadata for an experiment.""" + + _MESSAGE_TEMPLATE = textwrap.dedent( + u"""\ + This will modify the metadata associated with the experiment on + https://tensorboard.dev with the following experiment ID: + + {experiment_id} + + You have chosen to modify an experiment. All experiments uploaded + to TensorBoard.dev are publicly visible. Do not upload sensitive + data. + """ + ) + + def __init__(self, experiment_id, name=None, description=None): + self.experiment_id = experiment_id + self.name = name + self.description = description + + def get_ack_message_body(self): + return self._MESSAGE_TEMPLATE.format(experiment_id=self.experiment_id) + + def execute(self, server_info, channel): + api_client = write_service_pb2_grpc.TensorBoardWriterServiceStub( + channel + ) + experiment_id = self.experiment_id + _die_if_bad_experiment_name(self.name) + _die_if_bad_experiment_description(self.description) + if not experiment_id: + raise base_plugin.FlagsError( + "Must specify a non-empty experiment ID to modify." + ) + try: + uploader_lib.update_experiment_metadata( + api_client, + experiment_id, + name=self.name, + description=self.description, + ) + except uploader_lib.ExperimentNotFoundError: + _die( + "No such experiment %s. Either it never existed or it has " + "already been deleted." % experiment_id + ) + except uploader_lib.PermissionDeniedError: + _die( + "Cannot modify experiment %s because it is owned by a " + "different user." % experiment_id + ) + except uploader_lib.InvalidArgumentError as cm: + _die( + "Server cannot modify experiment as requested.\n" + "Server responded: %s" % cm.description() + ) + except grpc.RpcError as e: + _die("Internal error modifying experiment: %s" % e) + logging.info("Modified experiment %s.", experiment_id) + if self.name is not None: + logging.info("Set name to %r", self.name) + if self.description is not None: + logging.info(f"Set description to %r", repr(self.description)) + + class _ListIntent(_Intent): """The user intends to list all their experiments.""" @@ -409,6 +522,8 @@ def execute(self, server_info, channel): url = server_info_lib.experiment_url(server_info, experiment_id) print(url) data = [ + ("Name", experiment.name or "[No Name]"), + ("Description", experiment.description or "[No Description]"), ("Id", experiment.experiment_id), ("Created", util.format_time(experiment.create_time)), ("Updated", util.format_time(experiment.update_time)), @@ -417,7 +532,7 @@ def execute(self, server_info, channel): ("Tags", str(experiment.num_tags)), ] for (name, value) in data: - print("\t%s %s" % (name.ljust(10), value)) + print("\t%s %s" % (name.ljust(12), value)) sys.stdout.flush() if not count: sys.stderr.write( @@ -428,6 +543,24 @@ def execute(self, server_info, channel): sys.stderr.flush() +def _die_if_bad_experiment_name(name): + if name and len(name) > _EXPERIMENT_NAME_MAX_CHARS: + _die( + "Experiment name is too long. Limit is " + "%s characters.\n" + "%r was provided." % (_EXPERIMENT_NAME_MAX_CHARS, name) + ) + + +def _die_if_bad_experiment_description(description): + if description and len(description) > _EXPERIMENT_DESCRIPTION_MAX_CHARS: + _die( + "Experiment description is too long. Limit is %s characters.\n" + "%r was provided." + % (_EXPERIMENT_DESCRIPTION_MAX_CHARS, description) + ) + + class _UploadIntent(_Intent): """The user intends to upload an experiment from the given logdir.""" @@ -443,8 +576,10 @@ class _UploadIntent(_Intent): """ ) - def __init__(self, logdir): + def __init__(self, logdir, name=None, description=None): self.logdir = logdir + self.name = name + self.description = description def get_ack_message_body(self): return self._MESSAGE_TEMPLATE.format(logdir=self.logdir) @@ -453,7 +588,14 @@ def execute(self, server_info, channel): api_client = write_service_pb2_grpc.TensorBoardWriterServiceStub( channel ) - uploader = uploader_lib.TensorBoardUploader(api_client, self.logdir) + _die_if_bad_experiment_name(self.name) + _die_if_bad_experiment_description(self.description) + uploader = uploader_lib.TensorBoardUploader( + api_client, + self.logdir, + name=self.name, + description=self.description, + ) experiment_id = uploader.create_experiment() url = server_info_lib.experiment_url(server_info, experiment_id) print( @@ -541,11 +683,31 @@ def _get_intent(flags): raise base_plugin.FlagsError("Must specify subcommand (try --help).") if cmd == _SUBCOMMAND_KEY_UPLOAD: if flags.logdir: - return _UploadIntent(os.path.expanduser(flags.logdir)) + return _UploadIntent( + os.path.expanduser(flags.logdir), + name=flags.name, + description=flags.description, + ) else: raise base_plugin.FlagsError( "Must specify directory to upload via `--logdir`." ) + if cmd == _SUBCOMMAND_KEY_UPDATE_METADATA: + if flags.experiment_id: + if flags.name is not None or flags.description is not None: + return _UpdateMetadataIntent( + flags.experiment_id, + name=flags.name, + description=flags.description, + ) + else: + raise base_plugin.FlagsError( + "Must specify either `--name` or `--description`." + ) + else: + raise base_plugin.FlagsError( + "Must specify experiment to modify via `--experiment_id`." + ) elif cmd == _SUBCOMMAND_KEY_DELETE: if flags.experiment_id: return _DeleteExperimentIntent(flags.experiment_id) diff --git a/tensorboard/uploader/uploader_test.py b/tensorboard/uploader/uploader_test.py index a714370419..7f42cf9c86 100644 --- a/tensorboard/uploader/uploader_test.py +++ b/tensorboard/uploader/uploader_test.py @@ -32,6 +32,7 @@ import tensorflow as tf +from tensorboard.uploader.proto import experiment_pb2 from tensorboard.uploader.proto import scalar_pb2 from tensorboard.uploader.proto import write_service_pb2 from tensorboard.uploader.proto import write_service_pb2_grpc @@ -75,6 +76,64 @@ def test_create_experiment(self): eid = uploader.create_experiment() self.assertEqual(eid, "123") + def test_create_experiment_with_name(self): + logdir = "/logs/foo" + mock_client = _create_mock_client() + new_name = "This is the new name" + uploader = uploader_lib.TensorBoardUploader( + mock_client, logdir, name=new_name + ) + eid = uploader.create_experiment() + self.assertEqual(eid, "123") + mock_client.CreateExperiment.assert_called_once() + (args, _) = mock_client.CreateExperiment.call_args + + expected_request = write_service_pb2.CreateExperimentRequest( + name=new_name, + ) + self.assertEqual(args[0], expected_request) + + def test_create_experiment_with_description(self): + logdir = "/logs/foo" + mock_client = _create_mock_client() + new_description = """ + **description**" + may have "strange" unicode chars 🌴 \\/<> + """ + uploader = uploader_lib.TensorBoardUploader( + mock_client, logdir, description=new_description + ) + eid = uploader.create_experiment() + self.assertEqual(eid, "123") + mock_client.CreateExperiment.assert_called_once() + (args, _) = mock_client.CreateExperiment.call_args + + expected_request = write_service_pb2.CreateExperimentRequest( + description=new_description, + ) + self.assertEqual(args[0], expected_request) + + def test_create_experiment_with_all_metadata(self): + logdir = "/logs/foo" + mock_client = _create_mock_client() + new_description = """ + **description**" + may have "strange" unicode chars 🌴 \/<> + """ + new_name = "This is a cool name." + uploader = uploader_lib.TensorBoardUploader( + mock_client, logdir, name=new_name, description=new_description + ) + eid = uploader.create_experiment() + self.assertEqual(eid, "123") + mock_client.CreateExperiment.assert_called_once() + (args, _) = mock_client.CreateExperiment.call_args + + expected_request = write_service_pb2.CreateExperimentRequest( + name=new_name, description=new_description, + ) + self.assertEqual(args[0], expected_request) + def test_start_uploading_without_create_experiment_fails(self): mock_client = _create_mock_client() uploader = uploader_lib.TensorBoardUploader(mock_client, "/logs/foo") @@ -762,6 +821,77 @@ def test_internal_error(self): self.assertIn("travesty", msg) +class UpdateExperimentMetadataTest(tf.test.TestCase): + def _create_mock_client(self): + # Create a stub instance (using a test channel) in order to derive a mock + # from it with autospec enabled. Mocking TensorBoardWriterServiceStub itself + # doesn't work with autospec because grpc constructs stubs via metaclassing. + test_channel = grpc_testing.channel( + service_descriptors=[], time=grpc_testing.strict_real_time() + ) + stub = write_service_pb2_grpc.TensorBoardWriterServiceStub(test_channel) + mock_client = mock.create_autospec(stub) + return mock_client + + def test_success(self): + mock_client = _create_mock_client() + new_name = "a new name" + response = write_service_pb2.UpdateExperimentResponse() + mock_client.UpdateExperiment.return_value = response + + uploader_lib.update_experiment_metadata( + mock_client, "123", name=new_name + ) + + expected_request = write_service_pb2.UpdateExperimentRequest( + experiment=experiment_pb2.Experiment( + experiment_id="123", name=new_name + ), + experiment_mask=experiment_pb2.ExperimentMask(name=True), + ) + mock_client.UpdateExperiment.assert_called_once() + (args, _) = mock_client.UpdateExperiment.call_args + self.assertEqual(args[0], expected_request) + + def test_not_found(self): + mock_client = _create_mock_client() + error = test_util.grpc_error(grpc.StatusCode.NOT_FOUND, "nope") + mock_client.UpdateExperiment.side_effect = error + + with self.assertRaises(uploader_lib.ExperimentNotFoundError): + uploader_lib.update_experiment_metadata(mock_client, "123", name="") + + def test_unauthorized(self): + mock_client = _create_mock_client() + error = test_util.grpc_error(grpc.StatusCode.PERMISSION_DENIED, "nope") + mock_client.UpdateExperiment.side_effect = error + + with self.assertRaises(uploader_lib.PermissionDeniedError): + uploader_lib.update_experiment_metadata(mock_client, "123", name="") + + def test_invalid_argument(self): + mock_client = _create_mock_client() + error = test_util.grpc_error( + grpc.StatusCode.INVALID_ARGUMENT, "too many" + ) + mock_client.UpdateExperiment.side_effect = error + + with self.assertRaises(uploader_lib.InvalidArgumentError) as cm: + uploader_lib.update_experiment_metadata(mock_client, "123", name="") + msg = str(cm.exception) + self.assertIn("too many", msg) + + def test_internal_error(self): + mock_client = _create_mock_client() + error = test_util.grpc_error(grpc.StatusCode.INTERNAL, "travesty") + mock_client.UpdateExperiment.side_effect = error + + with self.assertRaises(grpc.RpcError) as cm: + uploader_lib.update_experiment_metadata(mock_client, "123", name="") + msg = str(cm.exception) + self.assertIn("travesty", msg) + + class VarintCostTest(tf.test.TestCase): def test_varint_cost(self): self.assertEqual(uploader_lib._varint_cost(0), 1)