-
Notifications
You must be signed in to change notification settings - Fork 1.7k
dataclass_compat: track initial tag metadata #3512
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
54c2da2
2cde07d
f339182
61ab333
45b02e7
8a4247c
aac18b4
8ef2886
f159fa5
6d11b7c
ab52a51
d17e9f9
84fd685
ff15dc7
94a974c
2ac4935
7c33f3e
a9adf93
8f469cc
ac58c6d
61cc397
524d0bc
82310b4
d553605
3d1d73f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,21 +37,26 @@ | |
| from tensorboard.util import tensor_util | ||
|
|
||
|
|
||
| def migrate_event(event): | ||
| def migrate_event(event, initial_metadata): | ||
| """Migrate an event to a sequence of events. | ||
|
|
||
| Args: | ||
| event: An `event_pb2.Event`. The caller transfers ownership of the | ||
| event to this method; the event may be mutated, and may or may | ||
| not appear in the returned sequence. | ||
| initial_metadata: Map from tag name (string) to `SummaryMetadata` | ||
| proto for the initial occurrence of the given tag within the | ||
| enclosing run. While loading a given run, the caller should | ||
| always pass the same dictionary here, initially `{}`; this | ||
| function will mutate it and reuse it for future calls. | ||
|
|
||
| Returns: | ||
| A sequence of `event_pb2.Event`s to use instead of `event`. | ||
| """ | ||
| if event.HasField("graph_def"): | ||
| return _migrate_graph_event(event) | ||
| if event.HasField("summary"): | ||
| return _migrate_summary_event(event) | ||
| return _migrate_summary_event(event, initial_metadata) | ||
| return (event,) | ||
|
|
||
|
|
||
|
|
@@ -70,9 +75,11 @@ def _migrate_graph_event(old_event): | |
| return (old_event, result) | ||
|
|
||
|
|
||
| def _migrate_summary_event(event): | ||
| def _migrate_summary_event(event, initial_metadata): | ||
| values = event.summary.value | ||
| new_values = [new for old in values for new in _migrate_value(old)] | ||
| new_values = [ | ||
| new for old in values for new in _migrate_value(old, initial_metadata) | ||
| ] | ||
| # Optimization: Don't create a new event if there were no shallow | ||
| # changes (there may still have been in-place changes). | ||
| if len(values) == len(new_values) and all( | ||
|
|
@@ -84,11 +91,21 @@ def _migrate_summary_event(event): | |
| return (event,) | ||
|
|
||
|
|
||
| def _migrate_value(value): | ||
| def _migrate_value(value, initial_metadata): | ||
| """Convert an old value to a stream of new values. May mutate.""" | ||
| if value.metadata.data_class != summary_pb2.DATA_CLASS_UNKNOWN: | ||
| metadata = initial_metadata.get(value.tag) | ||
| initial = False | ||
| if metadata is None: | ||
| initial = True | ||
| # Retain a copy of the initial metadata, so that even after we | ||
| # update its data class we know whether to also transform later | ||
| # events in this time series. | ||
| metadata = summary_pb2.SummaryMetadata() | ||
| metadata.CopyFrom(value.metadata) | ||
| initial_metadata[value.tag] = metadata | ||
| if metadata.data_class != summary_pb2.DATA_CLASS_UNKNOWN: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just to check my understanding, our assumption is that any summary that natively sets the data class on the initial metadata will already be using the dataclass-aware format (so to later summaries with that tag will require value conversions)?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah. In some sense there’s a bit of awkward coupling: one could imagine
where there are some dependencies (4 depends on 2, 2 depends on 1) but For now, at least, I thought that it was simpler to keep everything in |
||
| return (value,) | ||
| plugin_name = value.metadata.plugin_data.plugin_name | ||
| plugin_name = metadata.plugin_data.plugin_name | ||
| if plugin_name == histograms_metadata.PLUGIN_NAME: | ||
| return _migrate_histogram_value(value) | ||
| if plugin_name == images_metadata.PLUGIN_NAME: | ||
|
|
@@ -103,27 +120,32 @@ def _migrate_value(value): | |
|
|
||
|
|
||
| def _migrate_scalar_value(value): | ||
| value.metadata.data_class = summary_pb2.DATA_CLASS_SCALAR | ||
| if value.HasField("metadata"): | ||
| value.metadata.data_class = summary_pb2.DATA_CLASS_SCALAR | ||
| return (value,) | ||
|
|
||
|
|
||
| def _migrate_histogram_value(value): | ||
| value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR | ||
| if value.HasField("metadata"): | ||
| value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR | ||
| return (value,) | ||
|
|
||
|
|
||
| def _migrate_image_value(value): | ||
| value.metadata.data_class = summary_pb2.DATA_CLASS_BLOB_SEQUENCE | ||
| if value.HasField("metadata"): | ||
| value.metadata.data_class = summary_pb2.DATA_CLASS_BLOB_SEQUENCE | ||
| return (value,) | ||
|
|
||
|
|
||
| def _migrate_text_value(value): | ||
| value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR | ||
| if value.HasField("metadata"): | ||
| value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR | ||
| return (value,) | ||
|
|
||
|
|
||
| def _migrate_hparams_value(value): | ||
| value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR | ||
| if value.HasField("metadata"): | ||
| value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR | ||
| if not value.HasField("tensor"): | ||
| value.tensor.CopyFrom(hparams_metadata.NULL_TENSOR) | ||
| return (value,) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -51,11 +51,15 @@ | |
| class MigrateEventTest(tf.test.TestCase): | ||
| """Tests for `migrate_event`.""" | ||
|
|
||
| def _migrate_event(self, old_event): | ||
| def _migrate_event(self, old_event, initial_metadata=None): | ||
| """Like `migrate_event`, but performs some sanity checks.""" | ||
| if initial_metadata is None: | ||
| initial_metadata = {} | ||
| old_event_copy = event_pb2.Event() | ||
| old_event_copy.CopyFrom(old_event) | ||
| new_events = dataclass_compat.migrate_event(old_event) | ||
| new_events = dataclass_compat.migrate_event( | ||
| old_event, initial_metadata=initial_metadata | ||
| ) | ||
| for event in new_events: # ensure that wall time and step are preserved | ||
| self.assertEqual(event.wall_time, old_event.wall_time) | ||
| self.assertEqual(event.step, old_event.step) | ||
|
|
@@ -95,6 +99,35 @@ def test_already_newstyle_summary_passes_through(self): | |
| self.assertLen(new_events, 1) | ||
| self.assertIs(new_events[0], old_event) | ||
|
|
||
| def test_doesnt_add_metadata_to_later_steps(self): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FWIW, I guess this behavior is fine, but it would also seem fine to me to populate the dataclass in the metadata if the metadata is present on non-initial steps (i.e. base it just on whether that summary has metadata or not) which seems a little less stateful and thus perhaps slightly simpler to reason about?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, that’s fine with me. It’s important that we not populate the |
||
| old_events = [] | ||
| for step in range(3): | ||
| e = event_pb2.Event() | ||
| e.step = step | ||
| summary = scalar_summary.pb("foo", 0.125) | ||
| if step > 0: | ||
| for v in summary.value: | ||
| v.ClearField("metadata") | ||
| e.summary.ParseFromString(summary.SerializeToString()) | ||
| old_events.append(e) | ||
|
|
||
| initial_metadata = {} | ||
| new_events = [] | ||
| for e in old_events: | ||
| migrated = self._migrate_event(e, initial_metadata=initial_metadata) | ||
| new_events.extend(migrated) | ||
|
|
||
| self.assertLen(new_events, len(old_events)) | ||
| self.assertEqual( | ||
| { | ||
| e.step | ||
| for e in new_events | ||
| for v in e.summary.value | ||
| if v.HasField("metadata") | ||
| }, | ||
| {0}, | ||
| ) | ||
|
|
||
| def test_scalar(self): | ||
| old_event = event_pb2.Event() | ||
| old_event.step = 123 | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for documenting this!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep. Hopefully no one ever has to read it.