Skip to content

Commit 7195a4f

Browse files
committed
data: add tests for blob sequence handling
Summary: Follow-up to #2991. Fixes #3434. Test Plan: Tests pass as written. wchargin-branch: data-blob-sequence-tests wchargin-source: fbd3302933cb0c50609df970edf137202723c769
1 parent 644a7b3 commit 7195a4f

File tree

2 files changed

+111
-0
lines changed

2 files changed

+111
-0
lines changed

tensorboard/backend/event_processing/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ py_test(
5959
"//tensorboard/plugins/histogram:summary_v2",
6060
"//tensorboard/plugins/scalar:metadata",
6161
"//tensorboard/plugins/scalar:summary_v2",
62+
"//tensorboard/plugins/image:metadata",
63+
"//tensorboard/plugins/image:summary_v2",
6264
"//tensorboard/util:tensor_util",
6365
"@org_pythonhosted_six",
6466
],

tensorboard/backend/event_processing/data_provider_test.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
from tensorboard.plugins.histogram import summary_v2 as histogram_summary
3636
from tensorboard.plugins.scalar import metadata as scalar_metadata
3737
from tensorboard.plugins.scalar import summary_v2 as scalar_summary
38+
from tensorboard.plugins.image import metadata as image_metadata
39+
from tensorboard.plugins.image import summary_v2 as image_summary
3840
from tensorboard.util import tensor_util
3941
import tensorflow.compat.v1 as tf1
4042
import tensorflow.compat.v2 as tf
@@ -91,6 +93,27 @@ def setUp(self):
9193
name, tensor * i, step=i, description=description
9294
)
9395

96+
logdir = os.path.join(self.logdir, "mondrian")
97+
with tf.summary.create_file_writer(logdir).as_default():
98+
data = [
99+
("red", (221, 28, 38), "top-right"),
100+
("blue", (1, 91, 158), "bottom-left"),
101+
("yellow", (239, 220, 111), "bottom-right"),
102+
]
103+
for (name, color, description) in data:
104+
image_1x1 = tf.constant([[[color]]], dtype=tf.uint8)
105+
for i in xrange(1, 11):
106+
k = 6 - abs(6 - i) # 1, .., 6, .., 2
107+
# a `k`-sample image summary of `i`-by-`i` images
108+
image = tf.tile(image_1x1, [k, i, i, 1])
109+
image_summary.image(
110+
name,
111+
image,
112+
step=i,
113+
description=description,
114+
max_outputs=99,
115+
)
116+
94117
def create_multiplexer(self):
95118
multiplexer = event_multiplexer.EventMultiplexer()
96119
multiplexer.AddRunsFromDirectory(self.logdir)
@@ -115,6 +138,7 @@ def test_list_plugins_with_no_graph(self):
115138
"greetings",
116139
"marigraphs",
117140
histogram_metadata.PLUGIN_NAME,
141+
image_metadata.PLUGIN_NAME,
118142
scalar_metadata.PLUGIN_NAME,
119143
],
120144
)
@@ -134,6 +158,7 @@ def test_list_plugins_with_graph(self):
134158
"marigraphs",
135159
graph_metadata.PLUGIN_NAME,
136160
histogram_metadata.PLUGIN_NAME,
161+
image_metadata.PLUGIN_NAME,
137162
scalar_metadata.PLUGIN_NAME,
138163
],
139164
)
@@ -371,6 +396,90 @@ def test_read_tensors_downsamples(self):
371396
)
372397
self.assertLen(result["lebesgue"]["uniform"], 3)
373398

399+
def test_list_blob_sequences(self):
400+
provider = self.create_provider()
401+
402+
with self.subTest("finds all time series for a plugin"):
403+
result = provider.list_blob_sequences(
404+
experiment_id="unused", plugin_name=image_metadata.PLUGIN_NAME
405+
)
406+
self.assertItemsEqual(result.keys(), ["mondrian"])
407+
self.assertItemsEqual(
408+
result["mondrian"].keys(), ["red", "blue", "yellow"]
409+
)
410+
sample = result["mondrian"]["blue"]
411+
self.assertIsInstance(sample, base_provider.BlobSequenceTimeSeries)
412+
self.assertEqual(sample.max_step, 10)
413+
# nothing to test for wall time, as it can't be mocked out
414+
self.assertEqual(sample.plugin_content, b"")
415+
self.assertEqual(sample.max_length, 6 + 2)
416+
self.assertEqual(sample.description, "bottom-left")
417+
self.assertEqual(sample.display_name, "")
418+
419+
with self.subTest("filters by run/tag"):
420+
result = provider.list_blob_sequences(
421+
experiment_id="unused",
422+
plugin_name=image_metadata.PLUGIN_NAME,
423+
run_tag_filter=base_provider.RunTagFilter(
424+
runs=["mondrian", "picasso"], tags=["yellow", "green't"]
425+
),
426+
)
427+
self.assertItemsEqual(result.keys(), ["mondrian"])
428+
self.assertItemsEqual(result["mondrian"].keys(), ["yellow"])
429+
self.assertIsInstance(
430+
result["mondrian"]["yellow"],
431+
base_provider.BlobSequenceTimeSeries,
432+
)
433+
434+
def test_read_blob_sequences_and_read_blob(self):
435+
provider = self.create_provider()
436+
437+
with self.subTest("reads all time series for a plugin"):
438+
result = provider.read_blob_sequences(
439+
experiment_id="unused",
440+
plugin_name=image_metadata.PLUGIN_NAME,
441+
downsample=4,
442+
)
443+
self.assertItemsEqual(result.keys(), ["mondrian"])
444+
self.assertItemsEqual(
445+
result["mondrian"].keys(), ["red", "blue", "yellow"]
446+
)
447+
sample = result["mondrian"]["blue"]
448+
self.assertLen(sample, 4) # downsampled from 10
449+
last = sample[-1]
450+
self.assertIsInstance(last, base_provider.BlobSequenceDatum)
451+
self.assertEqual(last.step, 10)
452+
self.assertLen(last.values, 2 + 2)
453+
blobs = [provider.read_blob(v.blob_key) for v in last.values]
454+
self.assertEqual(blobs[0], b"10")
455+
self.assertEqual(blobs[1], b"10")
456+
self.assertStartsWith(blobs[2], b"\x89PNG")
457+
self.assertStartsWith(blobs[3], b"\x89PNG")
458+
459+
blue1 = blobs[2]
460+
blue2 = blobs[3]
461+
red1 = provider.read_blob(
462+
result["mondrian"]["red"][-1].values[2].blob_key
463+
)
464+
self.assertEqual(blue1, blue2)
465+
self.assertNotEqual(blue1, red1)
466+
467+
with self.subTest("filters by run/tag"):
468+
result = provider.read_blob_sequences(
469+
experiment_id="unused",
470+
plugin_name=image_metadata.PLUGIN_NAME,
471+
run_tag_filter=base_provider.RunTagFilter(
472+
runs=["mondrian", "picasso"], tags=["yellow", "green't"]
473+
),
474+
downsample=1,
475+
)
476+
self.assertItemsEqual(result.keys(), ["mondrian"])
477+
self.assertItemsEqual(result["mondrian"].keys(), ["yellow"])
478+
self.assertIsInstance(
479+
result["mondrian"]["yellow"][0],
480+
base_provider.BlobSequenceDatum,
481+
)
482+
374483

375484
class DownsampleTest(tf.test.TestCase):
376485
"""Tests for the `_downsample` private helper function."""

0 commit comments

Comments
 (0)