Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion tensorboard/backend/event_processing/event_file_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,27 @@ class EventFileLoader(LegacyEventFileLoader):
Specifically, this includes `data_compat` and `dataclass_compat`.
"""

def __init__(self, file_path):
    """Initializes the loader and the per-tag metadata cache.

    Args:
      file_path: Path to the event file to load.
    """
    super(EventFileLoader, self).__init__(file_path)
    # Maps tag name (string) to the `SummaryMetadata` proto seen on the
    # first occurrence of that tag, for use by `dataclass_compat`.
    #
    # NOTE(review): this cache is conceptually per *run*, not per event
    # file, so there is a potential failure case when the second event
    # file in a single run has no summary metadata. That only occurs
    # when all of the following hold: (a) the events were written with
    # the TensorFlow 1.x (not 2.x) writer, (b) the summaries were
    # created by `tensorboard.summary.v1` ops and so do not undergo
    # `data_compat` transformation, and (c) the file writer was reopened
    # via `.reopen()`, which creates a new file but does not clear the
    # tag cache. This is considered sufficiently improbable that no
    # extra mitigation is taken.
    self._initial_metadata = {}

def Load(self):
    """Loads all events written since the last call, fully migrated.

    Yields:
      All `event_pb2.Event` protos in the file that have not been
      yielded yet, after `data_compat` and `dataclass_compat`
      transformations.
    """
    for old_event in super(EventFileLoader, self).Load():
        compat_event = data_compat.migrate_event(old_event)
        migrated = dataclass_compat.migrate_event(
            compat_event, self._initial_metadata
        )
        for new_event in migrated:
            yield new_event

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,15 @@ def __init__(self, testcase, zero_out_timestamps=False):
self._testcase = testcase
self.items = []
self.zero_out_timestamps = zero_out_timestamps
self._initial_metadata = {}

def Load(self):
    """Pops queued events and yields them after compat migrations."""
    while self.items:
        raw_event = self.items.pop(0)
        compat_event = data_compat.migrate_event(raw_event)
        migrated = dataclass_compat.migrate_event(
            compat_event, self._initial_metadata
        )
        for new_event in migrated:
            yield new_event

Expand Down
46 changes: 34 additions & 12 deletions tensorboard/dataclass_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,26 @@
from tensorboard.util import tensor_util


def migrate_event(event, initial_metadata):
    """Migrate an event to a sequence of events.

    Args:
      event: An `event_pb2.Event`. The caller transfers ownership of the
        event to this method; the event may be mutated, and may or may
        not appear in the returned sequence.
      initial_metadata: Map from tag name (string) to `SummaryMetadata`
        proto for the initial occurrence of the given tag within the
        enclosing run. While loading a given run, the caller should
        always pass the same dictionary here, initially `{}`; this
        function will mutate it and reuse it for future calls.

    Returns:
      A sequence of `event_pb2.Event`s to use instead of `event`.
    """
    if event.HasField("graph_def"):
        return _migrate_graph_event(event)
    elif event.HasField("summary"):
        return _migrate_summary_event(event, initial_metadata)
    else:
        # Nothing to migrate; pass the event through untouched.
        return (event,)


Expand All @@ -70,9 +75,11 @@ def _migrate_graph_event(old_event):
return (old_event, result)


def _migrate_summary_event(event):
def _migrate_summary_event(event, initial_metadata):
values = event.summary.value
new_values = [new for old in values for new in _migrate_value(old)]
new_values = [
new for old in values for new in _migrate_value(old, initial_metadata)
]
# Optimization: Don't create a new event if there were no shallow
# changes (there may still have been in-place changes).
if len(values) == len(new_values) and all(
Expand All @@ -84,11 +91,21 @@ def _migrate_summary_event(event):
return (event,)


def _migrate_value(value):
def _migrate_value(value, initial_metadata):
"""Convert an old value to a stream of new values. May mutate."""
if value.metadata.data_class != summary_pb2.DATA_CLASS_UNKNOWN:
metadata = initial_metadata.get(value.tag)
initial = False
if metadata is None:
initial = True
# Retain a copy of the initial metadata, so that even after we
# update its data class we know whether to also transform later
# events in this time series.
metadata = summary_pb2.SummaryMetadata()
metadata.CopyFrom(value.metadata)
initial_metadata[value.tag] = metadata
if metadata.data_class != summary_pb2.DATA_CLASS_UNKNOWN:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to check my understanding: our assumption is that any summary that natively sets the data class on the initial metadata will already be using the dataclass-aware format (so no later summaries with that tag will require value conversions)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. In some sense there’s a bit of awkward coupling: one could imagine
factoring the migrations more finely as

  1. data_compat (v1 → v2 summaries)
  2. audio summary projection (label removal)
  3. hparams tensor fix
  4. dataclass_compat (add metadata.data_class only)

where there are some dependencies (4 depends on 2, 2 depends on 1) but
the dependency graph is not (an orientation of) a complete graph. Then
if we needed to add more read-time migrations later even after summaries
natively write data_class we’d still be able to do so.

For now, at least, I thought that it was simpler to keep everything in
this layer so that there are fewer moving pieces.

return (value,)
plugin_name = value.metadata.plugin_data.plugin_name
plugin_name = metadata.plugin_data.plugin_name
if plugin_name == histograms_metadata.PLUGIN_NAME:
return _migrate_histogram_value(value)
if plugin_name == images_metadata.PLUGIN_NAME:
Expand All @@ -103,27 +120,32 @@ def _migrate_value(value):


def _migrate_scalar_value(value):
value.metadata.data_class = summary_pb2.DATA_CLASS_SCALAR
if value.HasField("metadata"):
value.metadata.data_class = summary_pb2.DATA_CLASS_SCALAR
return (value,)


def _migrate_histogram_value(value):
value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR
if value.HasField("metadata"):
value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR
return (value,)


def _migrate_image_value(value):
value.metadata.data_class = summary_pb2.DATA_CLASS_BLOB_SEQUENCE
if value.HasField("metadata"):
value.metadata.data_class = summary_pb2.DATA_CLASS_BLOB_SEQUENCE
return (value,)


def _migrate_text_value(value):
value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR
if value.HasField("metadata"):
value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR
return (value,)


def _migrate_hparams_value(value):
value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR
if value.HasField("metadata"):
value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR
if not value.HasField("tensor"):
value.tensor.CopyFrom(hparams_metadata.NULL_TENSOR)
return (value,)
37 changes: 35 additions & 2 deletions tensorboard/dataclass_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,15 @@
class MigrateEventTest(tf.test.TestCase):
"""Tests for `migrate_event`."""

def _migrate_event(self, old_event):
def _migrate_event(self, old_event, initial_metadata=None):
"""Like `migrate_event`, but performs some sanity checks."""
if initial_metadata is None:
initial_metadata = {}
old_event_copy = event_pb2.Event()
old_event_copy.CopyFrom(old_event)
new_events = dataclass_compat.migrate_event(old_event)
new_events = dataclass_compat.migrate_event(
old_event, initial_metadata=initial_metadata
)
for event in new_events: # ensure that wall time and step are preserved
self.assertEqual(event.wall_time, old_event.wall_time)
self.assertEqual(event.step, old_event.step)
Expand Down Expand Up @@ -95,6 +99,35 @@ def test_already_newstyle_summary_passes_through(self):
self.assertLen(new_events, 1)
self.assertIs(new_events[0], old_event)

def test_doesnt_add_metadata_to_later_steps(self):
    """Metadata absent on later steps must not be fabricated by migration."""
    old_events = []
    for step in range(3):
        event = event_pb2.Event()
        event.step = step
        summary = scalar_summary.pb("foo", 0.125)
        if step > 0:
            # Simulate a TF 1.x-style writer, which omits summary
            # metadata after a tag's first occurrence.
            for value in summary.value:
                value.ClearField("metadata")
        event.summary.ParseFromString(summary.SerializeToString())
        old_events.append(event)

    initial_metadata = {}
    new_events = []
    for event in old_events:
        new_events.extend(
            self._migrate_event(event, initial_metadata=initial_metadata)
        )

    self.assertLen(new_events, len(old_events))
    steps_with_metadata = {
        event.step
        for event in new_events
        for value in event.summary.value
        if value.HasField("metadata")
    }
    self.assertEqual(steps_with_metadata, {0})

def test_scalar(self):
old_event = event_pb2.Event()
old_event.step = 123
Expand Down
5 changes: 4 additions & 1 deletion tensorboard/uploader/uploader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1278,9 +1278,12 @@ def _clear_wall_times(request):


def _apply_compat(events):
initial_metadata = {}
for event in events:
event = data_compat.migrate_event(event)
events = dataclass_compat.migrate_event(event)
events = dataclass_compat.migrate_event(
event, initial_metadata=initial_metadata
)
for event in events:
yield event

Expand Down