Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 44 additions & 41 deletions tensorboard/plugins/scalar/scalars_plugin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@
from tensorboard.util import test_util


@test_util.run_v1_only('Requires contrib for db writer or uses op.Placeholder')
tf.compat.v1.enable_eager_execution()


class ScalarsPluginTest(tf.test.TestCase):

_STEPS = 9
Expand Down Expand Up @@ -85,27 +87,26 @@ def set_up_db(self):
self.plugin = scalars_plugin.ScalarsPlugin(context)

def generate_run_to_db(self, experiment_name, run_name):
# This method uses `tf.contrib.summary`, and so must only be invoked
# when TensorFlow 1.x is installed.
tf.compat.v1.reset_default_graph()

global_step = tf.compat.v1.placeholder(tf.int64)
db_writer = tf.contrib.summary.create_db_writer(
db_uri=self.db_path,
experiment_name=experiment_name,
run_name=run_name,
user_name='user')

scalar_ops = None
with db_writer.as_default(), tf.contrib.summary.always_record_summaries():
tf.contrib.summary.scalar(self._SCALAR_TAG, 42, step=global_step)
flush_op = tf.contrib.summary.flush(db_writer._resource)

with tf.compat.v1.Session() as sess:
sess.run(tf.contrib.summary.summary_writer_initializer_op())
summaries = tf.contrib.summary.all_summary_ops()
for step in xrange(self._STEPS):
feed_dict = {global_step: step}
sess.run(summaries, feed_dict=feed_dict)
sess.run(flush_op)
with tf.compat.v1.Graph().as_default():
global_step = tf.compat.v1.placeholder(tf.int64)
db_writer = tf.contrib.summary.create_db_writer(
db_uri=self.db_path,
experiment_name=experiment_name,
run_name=run_name,
user_name='user')
with db_writer.as_default(), tf.contrib.summary.always_record_summaries():
tf.contrib.summary.scalar(self._SCALAR_TAG, 42, step=global_step)
flush_op = tf.contrib.summary.flush(db_writer._resource)
with tf.compat.v1.Session() as sess:
sess.run(tf.contrib.summary.summary_writer_initializer_op())
summaries = tf.contrib.summary.all_summary_ops()
for step in xrange(self._STEPS):
feed_dict = {global_step: step}
sess.run(summaries, feed_dict=feed_dict)
sess.run(flush_op)

def testRoutesProvided(self):
"""Tests that the plugin offers the correct routes."""
Expand All @@ -115,29 +116,28 @@ def testRoutesProvided(self):
self.assertIsInstance(routes['/tags'], collections.Callable)

def generate_run(self, run_name):
tf.compat.v1.reset_default_graph()
sess = tf.compat.v1.Session()
placeholder = tf.compat.v1.placeholder(tf.float32, shape=[3])

if run_name == self._RUN_WITH_LEGACY_SCALARS:
tf.compat.v1.summary.scalar(self._LEGACY_SCALAR_TAG, tf.reduce_mean(input_tensor=placeholder))
elif run_name == self._RUN_WITH_SCALARS:
summary.op(self._SCALAR_TAG, tf.reduce_sum(input_tensor=placeholder),
display_name=self._DISPLAY_NAME,
description=self._DESCRIPTION)
elif run_name == self._RUN_WITH_HISTOGRAM:
tf.compat.v1.summary.histogram(self._HISTOGRAM_TAG, placeholder)
else:
assert False, 'Invalid run name: %r' % run_name
summ = tf.compat.v1.summary.merge_all()

subdir = os.path.join(self.logdir, run_name)
with test_util.FileWriterCache.get(subdir) as writer:
writer.add_graph(sess.graph)
for step in xrange(self._STEPS):
feed_dict = {placeholder: [1 + step, 2 + step, 3 + step]}
s = sess.run(summ, feed_dict=feed_dict)
writer.add_summary(s, global_step=step)
data = [1 + step, 2 + step, 3 + step]
if run_name == self._RUN_WITH_LEGACY_SCALARS:
summ = tf.compat.v1.summary.scalar(
self._LEGACY_SCALAR_TAG, tf.reduce_mean(data),
).numpy()
elif run_name == self._RUN_WITH_SCALARS:
summ = summary.op(
self._SCALAR_TAG,
tf.reduce_sum(data),
display_name=self._DISPLAY_NAME,
description=self._DESCRIPTION,
).numpy()
elif run_name == self._RUN_WITH_HISTOGRAM:
summ = tf.compat.v1.summary.histogram(
self._HISTOGRAM_TAG, data
).numpy()
else:
assert False, 'Invalid run name: %r' % run_name
writer.add_summary(summ, global_step=step)

def test_index(self):
self.set_up_with_runs([self._RUN_WITH_LEGACY_SCALARS,
Expand Down Expand Up @@ -232,6 +232,7 @@ def test_active_with_all(self):
self._RUN_WITH_HISTOGRAM])
self.assertTrue(self.plugin.is_active())

@test_util.run_v1_only('Requires contrib for db writer')
def test_scalars_db_without_exp(self):
self.set_up_db()
self.generate_run_to_db('exp1', self._RUN_WITH_SCALARS)
Expand All @@ -245,6 +246,7 @@ def test_scalars_db_without_exp(self):
# raw SQL queries though.
self.assertEqual(len(data), 0)

@test_util.run_v1_only('Requires contrib for db writer')
def test_scalars_db_filter_by_experiment(self):
self.set_up_db()
self.generate_run_to_db('exp1', self._RUN_WITH_SCALARS)
Expand All @@ -257,6 +259,7 @@ def test_scalars_db_filter_by_experiment(self):
self.assertEqual('application/json', mime_type)
self.assertEqual(len(data), self._STEPS)

@test_util.run_v1_only('Requires contrib for db writer')
def test_scalars_db_no_match(self):
self.set_up_db()
self.generate_run_to_db('exp1', self._RUN_WITH_SCALARS)
Expand Down