diff --git a/tensorboard/plugins/projector/BUILD b/tensorboard/plugins/projector/BUILD index 1d4d06308c..0ab83b7eeb 100644 --- a/tensorboard/plugins/projector/BUILD +++ b/tensorboard/plugins/projector/BUILD @@ -49,6 +49,7 @@ py_test( ":projector", "//tensorboard:expect_tensorflow_installed", "//tensorboard/util:test_util", + "@org_pythonhosted_six", ], ) diff --git a/tensorboard/plugins/projector/__init__.py b/tensorboard/plugins/projector/__init__.py index 9fce67188c..47e0008f56 100644 --- a/tensorboard/plugins/projector/__init__.py +++ b/tensorboard/plugins/projector/__init__.py @@ -35,11 +35,13 @@ from tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig -def visualize_embeddings(summary_writer, config): +def visualize_embeddings(logdir, config): """Stores a config file used by the embedding projector. Args: - summary_writer: The summary writer used for writing events. + logdir: Directory into which to store the config file, as a `str`. + For compatibility, can also be a `tf.compat.v1.summary.FileWriter` + object open at the desired logdir. config: `tf.contrib.tensorboard.plugins.projector.ProjectorConfig` proto that holds the configuration for the projector such as paths to checkpoint files and metadata files for the embeddings. If @@ -49,11 +51,12 @@ def visualize_embeddings(summary_writer, config): Raises: ValueError: If the summary writer does not have a `logdir`. """ - logdir = summary_writer.get_logdir() + # Convert from `tf.compat.v1.summary.FileWriter` if necessary. + logdir = getattr(logdir, 'get_logdir', lambda: logdir)() # Sanity checks. if logdir is None: - raise ValueError('Summary writer must have a logdir') + raise ValueError('Expected logdir to be a path, but got None') # Saving the config file in the logdir. config_pbtxt = _text_format.MessageToString(config) diff --git a/tensorboard/plugins/projector/projector_api_test.py b/tensorboard/plugins/projector/projector_api_test.py index b89846ccec..417c5839a0 100644 --- a/tensorboard/plugins/projector/projector_api_test.py +++ b/tensorboard/plugins/projector/projector_api_test.py @@ -19,8 +19,8 @@ from __future__ import print_function import os -import shutil +import six import tensorflow as tf from google.protobuf import text_format @@ -28,30 +28,54 @@ from tensorboard.plugins import projector from tensorboard.util import test_util -tf.compat.v1.disable_v2_behavior() - +def create_dummy_config(): + return projector.ProjectorConfig( + model_checkpoint_path='test', + embeddings = [ + projector.EmbeddingInfo( + tensor_name='tensor1', + metadata_path='metadata1', + ), + ], + ) class ProjectorApiTest(tf.test.TestCase): - def testVisualizeEmbeddings(self): - # Create a dummy configuration. - config = projector.ProjectorConfig() - config.model_checkpoint_path = 'test' - emb1 = config.embeddings.add() - emb1.tensor_name = 'tensor1' - emb1.metadata_path = 'metadata1' + def test_visualize_embeddings_with_logdir(self): + logdir = self.get_temp_dir() + config = create_dummy_config() + projector.visualize_embeddings(logdir, config) + + # Read the configurations from disk and make sure it matches the original. + with tf.io.gfile.GFile(os.path.join(logdir, 'projector_config.pbtxt')) as f: + config2 = projector.ProjectorConfig() + text_format.Parse(f.read(), config2) - # Call the API method to save the configuration to a temporary dir. - temp_dir = self.get_temp_dir() - self.addCleanup(shutil.rmtree, temp_dir) - with test_util.FileWriterCache.get(temp_dir) as writer: - projector.visualize_embeddings(writer, config) + self.assertEqual(config, config2) + + def test_visualize_embeddings_with_file_writer(self): + if tf.__version__ == "stub": + self.skipTest("Requires TensorFlow for FileWriter") + logdir = self.get_temp_dir() + config = create_dummy_config() + + with tf.compat.v1.Graph().as_default(): + with test_util.FileWriterCache.get(logdir) as writer: + projector.visualize_embeddings(writer, config) # Read the configurations from disk and make sure it matches the original. - with tf.io.gfile.GFile(os.path.join(temp_dir, 'projector_config.pbtxt')) as f: + with tf.io.gfile.GFile(os.path.join(logdir, 'projector_config.pbtxt')) as f: config2 = projector.ProjectorConfig() text_format.Parse(f.read(), config2) - self.assertEqual(config, config2) + + self.assertEqual(config, config2) + + def test_visualize_embeddings_no_logdir(self): + with six.assertRaisesRegex( + self, + ValueError, + "Expected logdir to be a path, but got None"): + projector.visualize_embeddings(None, create_dummy_config()) if __name__ == '__main__':