diff --git a/tensorboard/uploader/uploader_subcommand.py b/tensorboard/uploader/uploader_subcommand.py index d580dfc3d7..23dbdd3a22 100644 --- a/tensorboard/uploader/uploader_subcommand.py +++ b/tensorboard/uploader/uploader_subcommand.py @@ -311,6 +311,8 @@ def execute(self, server_info, channel): channel ) fieldmask = experiment_pb2.ExperimentMask( + name=True, + description=True, create_time=True, update_time=True, num_runs=True, diff --git a/tensorboard/uploader/uploader_subcommand_test.py b/tensorboard/uploader/uploader_subcommand_test.py index ccd5537986..4b0eb43d4c 100644 --- a/tensorboard/uploader/uploader_subcommand_test.py +++ b/tensorboard/uploader/uploader_subcommand_test.py @@ -20,10 +20,12 @@ import tensorflow as tf +from tensorboard.uploader.proto import experiment_pb2 from tensorboard.uploader.proto import server_info_pb2 from tensorboard.uploader.proto import write_service_pb2 from tensorboard.uploader.proto import write_service_pb2_grpc from tensorboard.uploader import dry_run_stubs +from tensorboard.uploader import exporter as exporter_lib from tensorboard.uploader import uploader as uploader_lib from tensorboard.uploader import uploader_subcommand from tensorboard.plugins.histogram import metadata as histograms_metadata @@ -202,6 +204,31 @@ def testUploadIntentNonDryRunNonOneShotInterrupted(self): mock_stdout_write.call_args_list[-1][0][0], ) + def testListIntentSetsExperimentMask(self): + mock_server_info = mock.MagicMock() + mock_channel = mock.MagicMock() + expected_mask = experiment_pb2.ExperimentMask( + name=True, + description=True, + create_time=True, + update_time=True, + num_runs=True, + num_tags=True, + num_scalars=True, + total_tensor_bytes=True, + total_blob_bytes=True, + ) + with mock.patch.object( + exporter_lib, + "list_experiments", + ): + intent = uploader_subcommand._ListIntent() + intent.execute(mock_server_info, mock_channel) + actual_mask = exporter_lib.list_experiments.call_args[1][ + "fieldmask" + ] + self.assertEquals(actual_mask, expected_mask) + if __name__ == "__main__": tf.test.main()