Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tensorboard/plugins/projector/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ py_test(
":projector",
"//tensorboard:expect_tensorflow_installed",
"//tensorboard/util:test_util",
"@org_pythonhosted_six",
],
)

Expand Down
11 changes: 7 additions & 4 deletions tensorboard/plugins/projector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@
from tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig


def visualize_embeddings(summary_writer, config):
def visualize_embeddings(logdir, config):
"""Stores a config file used by the embedding projector.

Args:
summary_writer: The summary writer used for writing events.
logdir: Directory into which to store the config file, as a `str`.
For compatibility, can also be a `tf.compat.v1.summary.FileWriter`
object open at the desired logdir.
config: `tf.contrib.tensorboard.plugins.projector.ProjectorConfig`
proto that holds the configuration for the projector such as paths to
checkpoint files and metadata files for the embeddings. If
Expand All @@ -49,11 +51,12 @@ def visualize_embeddings(summary_writer, config):
Raises:
ValueError: If the summary writer does not have a `logdir`.
"""
logdir = summary_writer.get_logdir()
# Convert from `tf.compat.v1.summary.FileWriter` if necessary.
logdir = getattr(logdir, 'get_logdir', lambda: logdir)()

# Sanity checks.
if logdir is None:
raise ValueError('Summary writer must have a logdir')
raise ValueError('Expected logdir to be a path, but got None')

# Saving the config file in the logdir.
config_pbtxt = _text_format.MessageToString(config)
Expand Down
58 changes: 41 additions & 17 deletions tensorboard/plugins/projector/projector_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,39 +19,63 @@
from __future__ import print_function

import os
import shutil

import six
import tensorflow as tf

from google.protobuf import text_format

from tensorboard.plugins import projector
from tensorboard.util import test_util

tf.compat.v1.disable_v2_behavior()

def create_dummy_config():
return projector.ProjectorConfig(
model_checkpoint_path='test',
embeddings = [
projector.EmbeddingInfo(
tensor_name='tensor1',
metadata_path='metadata1',
),
],
)

class ProjectorApiTest(tf.test.TestCase):

def testVisualizeEmbeddings(self):
# Create a dummy configuration.
config = projector.ProjectorConfig()
config.model_checkpoint_path = 'test'
emb1 = config.embeddings.add()
emb1.tensor_name = 'tensor1'
emb1.metadata_path = 'metadata1'
def test_visualize_embeddings_with_logdir(self):
logdir = self.get_temp_dir()
config = create_dummy_config()
projector.visualize_embeddings(logdir, config)

# Read the configurations from disk and make sure it matches the original.
with tf.io.gfile.GFile(os.path.join(logdir, 'projector_config.pbtxt')) as f:
config2 = projector.ProjectorConfig()
text_format.Parse(f.read(), config2)

# Call the API method to save the configuration to a temporary dir.
temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir)
with test_util.FileWriterCache.get(temp_dir) as writer:
projector.visualize_embeddings(writer, config)
self.assertEqual(config, config2)

def test_visualize_embeddings_with_file_writer(self):
if tf.__version__ == "stub":
self.skipTest("Requires TensorFlow for FileWriter")
logdir = self.get_temp_dir()
config = create_dummy_config()

with tf.compat.v1.Graph().as_default():
with test_util.FileWriterCache.get(logdir) as writer:
projector.visualize_embeddings(writer, config)

# Read the configurations from disk and make sure it matches the original.
with tf.io.gfile.GFile(os.path.join(temp_dir, 'projector_config.pbtxt')) as f:
with tf.io.gfile.GFile(os.path.join(logdir, 'projector_config.pbtxt')) as f:
config2 = projector.ProjectorConfig()
text_format.Parse(f.read(), config2)
self.assertEqual(config, config2)

self.assertEqual(config, config2)

def test_visualize_embeddings_no_logdir(self):
with six.assertRaisesRegex(
self,
ValueError,
"Expected logdir to be a path, but got None"):
projector.visualize_embeddings(None, create_dummy_config())


if __name__ == '__main__':
Expand Down