diff --git a/tensorboard/plugins/scalar/scalars_plugin_test.py b/tensorboard/plugins/scalar/scalars_plugin_test.py
index c523fc1294..a7b0990a6a 100644
--- a/tensorboard/plugins/scalar/scalars_plugin_test.py
+++ b/tensorboard/plugins/scalar/scalars_plugin_test.py
@@ -37,7 +37,9 @@
 from tensorboard.util import test_util
 
 
-@test_util.run_v1_only('Requires contrib for db writer or uses op.Placeholder')
+tf.compat.v1.enable_eager_execution()
+
+
 class ScalarsPluginTest(tf.test.TestCase):
 
   _STEPS = 9
@@ -85,27 +87,26 @@ def set_up_db(self):
     self.plugin = scalars_plugin.ScalarsPlugin(context)
 
   def generate_run_to_db(self, experiment_name, run_name):
+    # This method uses `tf.contrib.summary`, and so must only be invoked
+    # when TensorFlow 1.x is installed.
     tf.compat.v1.reset_default_graph()
-
-    global_step = tf.compat.v1.placeholder(tf.int64)
-    db_writer = tf.contrib.summary.create_db_writer(
-        db_uri=self.db_path,
-        experiment_name=experiment_name,
-        run_name=run_name,
-        user_name='user')
-
-    scalar_ops = None
-    with db_writer.as_default(), tf.contrib.summary.always_record_summaries():
-      tf.contrib.summary.scalar(self._SCALAR_TAG, 42, step=global_step)
-    flush_op = tf.contrib.summary.flush(db_writer._resource)
-
-    with tf.compat.v1.Session() as sess:
-      sess.run(tf.contrib.summary.summary_writer_initializer_op())
-      summaries = tf.contrib.summary.all_summary_ops()
-      for step in xrange(self._STEPS):
-        feed_dict = {global_step: step}
-        sess.run(summaries, feed_dict=feed_dict)
-      sess.run(flush_op)
+    with tf.compat.v1.Graph().as_default():
+      global_step = tf.compat.v1.placeholder(tf.int64)
+      db_writer = tf.contrib.summary.create_db_writer(
+          db_uri=self.db_path,
+          experiment_name=experiment_name,
+          run_name=run_name,
+          user_name='user')
+      with db_writer.as_default(), tf.contrib.summary.always_record_summaries():
+        tf.contrib.summary.scalar(self._SCALAR_TAG, 42, step=global_step)
+      flush_op = tf.contrib.summary.flush(db_writer._resource)
+      with tf.compat.v1.Session() as sess:
+        sess.run(tf.contrib.summary.summary_writer_initializer_op())
+        summaries = tf.contrib.summary.all_summary_ops()
+        for step in xrange(self._STEPS):
+          feed_dict = {global_step: step}
+          sess.run(summaries, feed_dict=feed_dict)
+        sess.run(flush_op)
 
   def testRoutesProvided(self):
     """Tests that the plugin offers the correct routes."""
@@ -115,29 +116,28 @@ def testRoutesProvided(self):
     self.assertIsInstance(routes['/tags'], collections.Callable)
 
   def generate_run(self, run_name):
-    tf.compat.v1.reset_default_graph()
-    sess = tf.compat.v1.Session()
-    placeholder = tf.compat.v1.placeholder(tf.float32, shape=[3])
-
-    if run_name == self._RUN_WITH_LEGACY_SCALARS:
-      tf.compat.v1.summary.scalar(self._LEGACY_SCALAR_TAG, tf.reduce_mean(input_tensor=placeholder))
-    elif run_name == self._RUN_WITH_SCALARS:
-      summary.op(self._SCALAR_TAG, tf.reduce_sum(input_tensor=placeholder),
-                 display_name=self._DISPLAY_NAME,
-                 description=self._DESCRIPTION)
-    elif run_name == self._RUN_WITH_HISTOGRAM:
-      tf.compat.v1.summary.histogram(self._HISTOGRAM_TAG, placeholder)
-    else:
-      assert False, 'Invalid run name: %r' % run_name
-    summ = tf.compat.v1.summary.merge_all()
-
     subdir = os.path.join(self.logdir, run_name)
     with test_util.FileWriterCache.get(subdir) as writer:
-      writer.add_graph(sess.graph)
       for step in xrange(self._STEPS):
-        feed_dict = {placeholder: [1 + step, 2 + step, 3 + step]}
-        s = sess.run(summ, feed_dict=feed_dict)
-        writer.add_summary(s, global_step=step)
+        data = [1 + step, 2 + step, 3 + step]
+        if run_name == self._RUN_WITH_LEGACY_SCALARS:
+          summ = tf.compat.v1.summary.scalar(
+              self._LEGACY_SCALAR_TAG, tf.reduce_mean(data),
+          ).numpy()
+        elif run_name == self._RUN_WITH_SCALARS:
+          summ = summary.op(
+              self._SCALAR_TAG,
+              tf.reduce_sum(data),
+              display_name=self._DISPLAY_NAME,
+              description=self._DESCRIPTION,
+          ).numpy()
+        elif run_name == self._RUN_WITH_HISTOGRAM:
+          summ = tf.compat.v1.summary.histogram(
+              self._HISTOGRAM_TAG, data
+          ).numpy()
+        else:
+          assert False, 'Invalid run name: %r' % run_name
+        writer.add_summary(summ, global_step=step)
 
   def test_index(self):
     self.set_up_with_runs([self._RUN_WITH_LEGACY_SCALARS,
@@ -232,6 +232,7 @@ def test_active_with_all(self):
                            self._RUN_WITH_HISTOGRAM])
     self.assertTrue(self.plugin.is_active())
 
+  @test_util.run_v1_only('Requires contrib for db writer')
   def test_scalars_db_without_exp(self):
     self.set_up_db()
     self.generate_run_to_db('exp1', self._RUN_WITH_SCALARS)
@@ -245,6 +246,7 @@ def test_scalars_db_without_exp(self):
     # raw SQL queries though.
     self.assertEqual(len(data), 0)
 
+  @test_util.run_v1_only('Requires contrib for db writer')
  def test_scalars_db_filter_by_experiment(self):
     self.set_up_db()
     self.generate_run_to_db('exp1', self._RUN_WITH_SCALARS)
@@ -257,6 +259,7 @@ def test_scalars_db_filter_by_experiment(self):
     self.assertEqual('application/json', mime_type)
     self.assertEqual(len(data), self._STEPS)
 
+  @test_util.run_v1_only('Requires contrib for db writer')
   def test_scalars_db_no_match(self):
     self.set_up_db()
     self.generate_run_to_db('exp1', self._RUN_WITH_SCALARS)