diff --git a/tensorboard/backend/event_processing/event_file_loader.py b/tensorboard/backend/event_processing/event_file_loader.py index 49e01b25bc..4485935806 100644 --- a/tensorboard/backend/event_processing/event_file_loader.py +++ b/tensorboard/backend/event_processing/event_file_loader.py @@ -173,10 +173,27 @@ class EventFileLoader(LegacyEventFileLoader): Specifically, this includes `data_compat` and `dataclass_compat`. """ + def __init__(self, file_path): + super(EventFileLoader, self).__init__(file_path) + # Track initial metadata for each tag, for `dataclass_compat`. + # This is meant to be tracked per run, not per event file, so + # there is a potential failure case when the second event file + # in a single run has no summary metadata. This only occurs when + # all of the following hold: (a) the events were written with + # the TensorFlow 1.x (not 2.x) writer, (b) the summaries were + # created by `tensorboard.summary.v1` ops and so do not undergo + # `data_compat` transformation, and (c) the file writer was + # reopened by calling `.reopen()` on it, which creates a new + # file but does not clear the tag cache. This is considered + # sufficiently improbable that we don't take extra mitigations. + self._initial_metadata = {} # from tag name to `SummaryMetadata` + def Load(self): for event in super(EventFileLoader, self).Load(): event = data_compat.migrate_event(event) - events = dataclass_compat.migrate_event(event) + events = dataclass_compat.migrate_event( + event, self._initial_metadata + ) for event in events: yield event diff --git a/tensorboard/backend/event_processing/plugin_event_accumulator_test.py b/tensorboard/backend/event_processing/plugin_event_accumulator_test.py index 08cdb67ae8..9942a6b624 100644 --- a/tensorboard/backend/event_processing/plugin_event_accumulator_test.py +++ b/tensorboard/backend/event_processing/plugin_event_accumulator_test.py @@ -59,12 +59,15 @@ def __init__(self, testcase, zero_out_timestamps=False): self._testcase = testcase self.items = [] self.zero_out_timestamps = zero_out_timestamps + self._initial_metadata = {} def Load(self): while self.items: event = self.items.pop(0) event = data_compat.migrate_event(event) - events = dataclass_compat.migrate_event(event) + events = dataclass_compat.migrate_event( + event, self._initial_metadata + ) for event in events: yield event diff --git a/tensorboard/dataclass_compat.py b/tensorboard/dataclass_compat.py index c21bd13daf..e42e288b6d 100644 --- a/tensorboard/dataclass_compat.py +++ b/tensorboard/dataclass_compat.py @@ -37,13 +37,18 @@ from tensorboard.util import tensor_util -def migrate_event(event): +def migrate_event(event, initial_metadata): """Migrate an event to a sequence of events. Args: event: An `event_pb2.Event`. The caller transfers ownership of the event to this method; the event may be mutated, and may or may not appear in the returned sequence. + initial_metadata: Map from tag name (string) to `SummaryMetadata` + proto for the initial occurrence of the given tag within the + enclosing run. While loading a given run, the caller should + always pass the same dictionary here, initially `{}`; this + function will mutate it and reuse it for future calls. Returns: A sequence of `event_pb2.Event`s to use instead of `event`. @@ -51,7 +56,7 @@ def migrate_event(event): if event.HasField("graph_def"): return _migrate_graph_event(event) if event.HasField("summary"): - return _migrate_summary_event(event) + return _migrate_summary_event(event, initial_metadata) return (event,) @@ -70,9 +75,11 @@ def _migrate_graph_event(old_event): return (old_event, result) -def _migrate_summary_event(event): +def _migrate_summary_event(event, initial_metadata): values = event.summary.value - new_values = [new for old in values for new in _migrate_value(old)] + new_values = [ + new for old in values for new in _migrate_value(old, initial_metadata) + ] # Optimization: Don't create a new event if there were no shallow # changes (there may still have been in-place changes). if len(values) == len(new_values) and all( @@ -84,11 +91,21 @@ def _migrate_summary_event(event): return (event,) -def _migrate_value(value): +def _migrate_value(value, initial_metadata): """Convert an old value to a stream of new values. May mutate.""" - if value.metadata.data_class != summary_pb2.DATA_CLASS_UNKNOWN: + metadata = initial_metadata.get(value.tag) + initial = False + if metadata is None: + initial = True + # Retain a copy of the initial metadata, so that even after we + # update its data class we know whether to also transform later + # events in this time series. + metadata = summary_pb2.SummaryMetadata() + metadata.CopyFrom(value.metadata) + initial_metadata[value.tag] = metadata + if metadata.data_class != summary_pb2.DATA_CLASS_UNKNOWN: return (value,) - plugin_name = value.metadata.plugin_data.plugin_name + plugin_name = metadata.plugin_data.plugin_name if plugin_name == histograms_metadata.PLUGIN_NAME: return _migrate_histogram_value(value) if plugin_name == images_metadata.PLUGIN_NAME: @@ -103,27 +120,32 @@ def _migrate_value(value): def _migrate_scalar_value(value): - value.metadata.data_class = summary_pb2.DATA_CLASS_SCALAR + if value.HasField("metadata"): + value.metadata.data_class = summary_pb2.DATA_CLASS_SCALAR return (value,) def _migrate_histogram_value(value): - value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR + if value.HasField("metadata"): + value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR return (value,) def _migrate_image_value(value): - value.metadata.data_class = summary_pb2.DATA_CLASS_BLOB_SEQUENCE + if value.HasField("metadata"): + value.metadata.data_class = summary_pb2.DATA_CLASS_BLOB_SEQUENCE return (value,) def _migrate_text_value(value): - value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR + if value.HasField("metadata"): + value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR return (value,) def _migrate_hparams_value(value): - value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR + if value.HasField("metadata"): + value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR if not value.HasField("tensor"): value.tensor.CopyFrom(hparams_metadata.NULL_TENSOR) return (value,) diff --git a/tensorboard/dataclass_compat_test.py b/tensorboard/dataclass_compat_test.py index b704600e57..34e51b38b9 100644 --- a/tensorboard/dataclass_compat_test.py +++ b/tensorboard/dataclass_compat_test.py @@ -51,11 +51,15 @@ class MigrateEventTest(tf.test.TestCase): """Tests for `migrate_event`.""" - def _migrate_event(self, old_event): + def _migrate_event(self, old_event, initial_metadata=None): """Like `migrate_event`, but performs some sanity checks.""" + if initial_metadata is None: + initial_metadata = {} old_event_copy = event_pb2.Event() old_event_copy.CopyFrom(old_event) - new_events = dataclass_compat.migrate_event(old_event) + new_events = dataclass_compat.migrate_event( + old_event, initial_metadata=initial_metadata + ) for event in new_events: # ensure that wall time and step are preserved self.assertEqual(event.wall_time, old_event.wall_time) self.assertEqual(event.step, old_event.step) @@ -95,6 +99,35 @@ def test_already_newstyle_summary_passes_through(self): self.assertLen(new_events, 1) self.assertIs(new_events[0], old_event) + def test_doesnt_add_metadata_to_later_steps(self): + old_events = [] + for step in range(3): + e = event_pb2.Event() + e.step = step + summary = scalar_summary.pb("foo", 0.125) + if step > 0: + for v in summary.value: + v.ClearField("metadata") + e.summary.ParseFromString(summary.SerializeToString()) + old_events.append(e) + + initial_metadata = {} + new_events = [] + for e in old_events: + migrated = self._migrate_event(e, initial_metadata=initial_metadata) + new_events.extend(migrated) + + self.assertLen(new_events, len(old_events)) + self.assertEqual( + { + e.step + for e in new_events + for v in e.summary.value + if v.HasField("metadata") + }, + {0}, + ) + def test_scalar(self): old_event = event_pb2.Event() old_event.step = 123 diff --git a/tensorboard/uploader/uploader_test.py b/tensorboard/uploader/uploader_test.py index 047a25fec1..d91415547c 100644 --- a/tensorboard/uploader/uploader_test.py +++ b/tensorboard/uploader/uploader_test.py @@ -1278,9 +1278,12 @@ def _clear_wall_times(request): def _apply_compat(events): + initial_metadata = {} for event in events: event = data_compat.migrate_event(event) - events = dataclass_compat.migrate_event(event) + events = dataclass_compat.migrate_event( + event, initial_metadata=initial_metadata + ) for event in events: yield event