Skip to content

Commit aa4b9af

Browse files
authored
data: add tests for blob sequence handling (#3435)
Summary: Follow-up to #2991. Fixes #3434. Test Plan: Tests pass as written. wchargin-branch: data-blob-sequence-tests
1 parent 68b33e9 commit aa4b9af

File tree

2 files changed

+113
-0
lines changed

2 files changed

+113
-0
lines changed

tensorboard/backend/event_processing/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ py_test(
5757
"//tensorboard/plugins/graph:metadata",
5858
"//tensorboard/plugins/histogram:metadata",
5959
"//tensorboard/plugins/histogram:summary_v2",
60+
"//tensorboard/plugins/image:metadata",
61+
"//tensorboard/plugins/image:summary_v2",
6062
"//tensorboard/plugins/scalar:metadata",
6163
"//tensorboard/plugins/scalar:summary_v2",
6264
"//tensorboard/util:tensor_util",

tensorboard/backend/event_processing/data_provider_test.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
from tensorboard.plugins.histogram import summary_v2 as histogram_summary
3636
from tensorboard.plugins.scalar import metadata as scalar_metadata
3737
from tensorboard.plugins.scalar import summary_v2 as scalar_summary
38+
from tensorboard.plugins.image import metadata as image_metadata
39+
from tensorboard.plugins.image import summary_v2 as image_summary
3840
from tensorboard.util import tensor_util
3941
import tensorflow.compat.v1 as tf1
4042
import tensorflow.compat.v2 as tf
@@ -91,6 +93,29 @@ def setUp(self):
9193
name, tensor * i, step=i, description=description
9294
)
9395

96+
logdir = os.path.join(self.logdir, "mondrian")
97+
with tf.summary.create_file_writer(logdir).as_default():
98+
data = [
99+
("red", (221, 28, 38), "top-right"),
100+
("blue", (1, 91, 158), "bottom-left"),
101+
("yellow", (239, 220, 111), "bottom-right"),
102+
]
103+
for (name, color, description) in data:
104+
image_1x1 = tf.constant([[[color]]], dtype=tf.uint8)
105+
for i in xrange(1, 11):
106+
# Use a non-monotonic sequence of sample sizes to
107+
# test `max_length` calculation.
108+
k = 6 - abs(6 - i) # 1, .., 6, .., 2
109+
# a `k`-sample image summary of `i`-by-`i` images
110+
image = tf.tile(image_1x1, [k, i, i, 1])
111+
image_summary.image(
112+
name,
113+
image,
114+
step=i,
115+
description=description,
116+
max_outputs=99,
117+
)
118+
94119
def create_multiplexer(self):
95120
multiplexer = event_multiplexer.EventMultiplexer()
96121
multiplexer.AddRunsFromDirectory(self.logdir)
@@ -115,6 +140,7 @@ def test_list_plugins_with_no_graph(self):
115140
"greetings",
116141
"marigraphs",
117142
histogram_metadata.PLUGIN_NAME,
143+
image_metadata.PLUGIN_NAME,
118144
scalar_metadata.PLUGIN_NAME,
119145
],
120146
)
@@ -134,6 +160,7 @@ def test_list_plugins_with_graph(self):
134160
"marigraphs",
135161
graph_metadata.PLUGIN_NAME,
136162
histogram_metadata.PLUGIN_NAME,
163+
image_metadata.PLUGIN_NAME,
137164
scalar_metadata.PLUGIN_NAME,
138165
],
139166
)
@@ -371,6 +398,90 @@ def test_read_tensors_downsamples(self):
371398
)
372399
self.assertLen(result["lebesgue"]["uniform"], 3)
373400

401+
def test_list_blob_sequences(self):
402+
provider = self.create_provider()
403+
404+
with self.subTest("finds all time series for a plugin"):
405+
result = provider.list_blob_sequences(
406+
experiment_id="unused", plugin_name=image_metadata.PLUGIN_NAME
407+
)
408+
self.assertItemsEqual(result.keys(), ["mondrian"])
409+
self.assertItemsEqual(
410+
result["mondrian"].keys(), ["red", "blue", "yellow"]
411+
)
412+
sample = result["mondrian"]["blue"]
413+
self.assertIsInstance(sample, base_provider.BlobSequenceTimeSeries)
414+
self.assertEqual(sample.max_step, 10)
415+
# nothing to test for wall time, as it can't be mocked out
416+
self.assertEqual(sample.plugin_content, b"")
417+
self.assertEqual(sample.max_length, 6 + 2)
418+
self.assertEqual(sample.description, "bottom-left")
419+
self.assertEqual(sample.display_name, "")
420+
421+
with self.subTest("filters by run/tag"):
422+
result = provider.list_blob_sequences(
423+
experiment_id="unused",
424+
plugin_name=image_metadata.PLUGIN_NAME,
425+
run_tag_filter=base_provider.RunTagFilter(
426+
runs=["mondrian", "picasso"], tags=["yellow", "green't"]
427+
),
428+
)
429+
self.assertItemsEqual(result.keys(), ["mondrian"])
430+
self.assertItemsEqual(result["mondrian"].keys(), ["yellow"])
431+
self.assertIsInstance(
432+
result["mondrian"]["yellow"],
433+
base_provider.BlobSequenceTimeSeries,
434+
)
435+
436+
def test_read_blob_sequences_and_read_blob(self):
437+
provider = self.create_provider()
438+
439+
with self.subTest("reads all time series for a plugin"):
440+
result = provider.read_blob_sequences(
441+
experiment_id="unused",
442+
plugin_name=image_metadata.PLUGIN_NAME,
443+
downsample=4,
444+
)
445+
self.assertItemsEqual(result.keys(), ["mondrian"])
446+
self.assertItemsEqual(
447+
result["mondrian"].keys(), ["red", "blue", "yellow"]
448+
)
449+
sample = result["mondrian"]["blue"]
450+
self.assertLen(sample, 4) # downsampled from 10
451+
last = sample[-1]
452+
self.assertIsInstance(last, base_provider.BlobSequenceDatum)
453+
self.assertEqual(last.step, 10)
454+
self.assertLen(last.values, 2 + 2)
455+
blobs = [provider.read_blob(v.blob_key) for v in last.values]
456+
self.assertEqual(blobs[0], b"10")
457+
self.assertEqual(blobs[1], b"10")
458+
self.assertStartsWith(blobs[2], b"\x89PNG")
459+
self.assertStartsWith(blobs[3], b"\x89PNG")
460+
461+
blue1 = blobs[2]
462+
blue2 = blobs[3]
463+
red1 = provider.read_blob(
464+
result["mondrian"]["red"][-1].values[2].blob_key
465+
)
466+
self.assertEqual(blue1, blue2)
467+
self.assertNotEqual(blue1, red1)
468+
469+
with self.subTest("filters by run/tag"):
470+
result = provider.read_blob_sequences(
471+
experiment_id="unused",
472+
plugin_name=image_metadata.PLUGIN_NAME,
473+
run_tag_filter=base_provider.RunTagFilter(
474+
runs=["mondrian", "picasso"], tags=["yellow", "green't"]
475+
),
476+
downsample=1,
477+
)
478+
self.assertItemsEqual(result.keys(), ["mondrian"])
479+
self.assertItemsEqual(result["mondrian"].keys(), ["yellow"])
480+
self.assertIsInstance(
481+
result["mondrian"]["yellow"][0],
482+
base_provider.BlobSequenceDatum,
483+
)
484+
374485

375486
class DownsampleTest(tf.test.TestCase):
376487
"""Tests for the `_downsample` private helper function."""

0 commit comments

Comments
 (0)