1 change: 0 additions & 1 deletion tensorboard/BUILD
@@ -497,7 +497,6 @@ py_library(
     srcs = ["dataclass_compat.py"],
     srcs_version = "PY2AND3",
     deps = [
-        "//tensorboard/backend:process_graph",
         "//tensorboard/compat/proto:protos_all_py_pb2",
         "//tensorboard/plugins/graph:metadata",
         "//tensorboard/plugins/histogram:metadata",
38 changes: 3 additions & 35 deletions tensorboard/dataclass_compat.py
@@ -25,11 +25,7 @@
 from __future__ import division
 from __future__ import print_function
 
-
-from google.protobuf import message
-from tensorboard.backend import process_graph
 from tensorboard.compat.proto import event_pb2
-from tensorboard.compat.proto import graph_pb2
 from tensorboard.compat.proto import summary_pb2
 from tensorboard.compat.proto import types_pb2
 from tensorboard.plugins.graph import metadata as graphs_metadata
@@ -39,60 +35,32 @@
 from tensorboard.plugins.scalar import metadata as scalars_metadata
 from tensorboard.plugins.text import metadata as text_metadata
 from tensorboard.util import tensor_util
-from tensorboard.util import tb_logging
 
-logger = tb_logging.get_logger()
-
 
-def migrate_event(event, experimental_filter_graph=False):
+def migrate_event(event):
     """Migrate an event to a sequence of events.
 
     Args:
       event: An `event_pb2.Event`. The caller transfers ownership of the
        event to this method; the event may be mutated, and may or may
        not appear in the returned sequence.
-      experimental_filter_graph: When a graph event is encountered, process the
-        GraphDef to filter out attributes that are too large to be shown in the
-        graph UI.
 
     Returns:
       A sequence of `event_pb2.Event`s to use instead of `event`.
     """
     if event.HasField("graph_def"):
-        return _migrate_graph_event(
-            event, experimental_filter_graph=experimental_filter_graph
-        )
+        return _migrate_graph_event(event)
     if event.HasField("summary"):
         return _migrate_summary_event(event)
     return (event,)
 
 
-def _migrate_graph_event(old_event, experimental_filter_graph=False):
+def _migrate_graph_event(old_event):
     result = event_pb2.Event()
     result.wall_time = old_event.wall_time
     result.step = old_event.step
     value = result.summary.value.add(tag=graphs_metadata.RUN_GRAPH_NAME)
     graph_bytes = old_event.graph_def
 
-    # TODO(@davidsoergel): Move this stopgap to a more appropriate place.
-    if experimental_filter_graph:
-        try:
-            graph_def = graph_pb2.GraphDef().FromString(graph_bytes)
-        # The reason for the RuntimeWarning catch here is b/27494216, whereby
-        # some proto parsers incorrectly raise that instead of DecodeError
-        # on certain kinds of malformed input. Triggering this seems to require
-        # a combination of mysterious circumstances.
-        except (message.DecodeError, RuntimeWarning):
-            logger.warning(
-                "Could not parse GraphDef of size %d. Skipping.",
-                len(graph_bytes),
-            )
-            return (old_event,)
-        # Use the default filter parameters:
-        # limit_attr_size=1024, large_attrs_key="_too_large_attrs"
-        process_graph.prepare_graph_for_ui(graph_def)
-        graph_bytes = graph_def.SerializeToString()
-
     value.tensor.CopyFrom(tensor_util.make_tensor_proto([graph_bytes]))
     value.metadata.plugin_data.plugin_name = graphs_metadata.PLUGIN_NAME
     # `value.metadata.plugin_data.content` left as the empty proto
108 changes: 2 additions & 106 deletions tensorboard/dataclass_compat_test.py
@@ -51,13 +51,11 @@
 class MigrateEventTest(tf.test.TestCase):
     """Tests for `migrate_event`."""
 
-    def _migrate_event(self, old_event, experimental_filter_graph=False):
+    def _migrate_event(self, old_event):
         """Like `migrate_event`, but performs some sanity checks."""
         old_event_copy = event_pb2.Event()
         old_event_copy.CopyFrom(old_event)
-        new_events = dataclass_compat.migrate_event(
-            old_event, experimental_filter_graph
-        )
+        new_events = dataclass_compat.migrate_event(old_event)
         for event in new_events:  # ensure that wall time and step are preserved
             self.assertEqual(event.wall_time, old_event.wall_time)
             self.assertEqual(event.step, old_event.step)
@@ -223,108 +221,6 @@ def test_graph_def(self):

         self.assertProtoEquals(graph_def, new_graph_def)
 
-    def test_graph_def_experimental_filter_graph(self):
-        # Create a `GraphDef`
-        graph_def = graph_pb2.GraphDef()
-        graph_def.node.add(name="alice", op="Person")
-        graph_def.node.add(name="bob", op="Person")
-
-        graph_def.node[1].attr["small"].s = b"small_attr_value"
-        graph_def.node[1].attr["large"].s = (
-            b"large_attr_value" * 100  # 1600 bytes > 1024 limit
-        )
-        graph_def.node.add(
-            name="friendship", op="Friendship", input=["alice", "bob"]
-        )
-
-        # Simulate legacy graph event
-        old_event = event_pb2.Event()
-        old_event.step = 0
-        old_event.wall_time = 456.75
-        old_event.graph_def = graph_def.SerializeToString()
-
-        new_events = self._migrate_event(
-            old_event, experimental_filter_graph=True
-        )
-
-        new_event = new_events[1]
-        tensor = tensor_util.make_ndarray(new_event.summary.value[0].tensor)
-        new_graph_def_bytes = tensor[0]
-        new_graph_def = graph_pb2.GraphDef.FromString(new_graph_def_bytes)
-
-        expected_graph_def = graph_pb2.GraphDef()
-        expected_graph_def.CopyFrom(graph_def)
-        del expected_graph_def.node[1].attr["large"]
-        expected_graph_def.node[1].attr["_too_large_attrs"].list.s.append(
-            b"large"
-        )
-
-        self.assertProtoEquals(expected_graph_def, new_graph_def)
-
-    def test_graph_def_experimental_filter_graph_corrupt(self):
-        # Simulate legacy graph event with an unparseable graph.
-        # We can't be sure whether this will produce `DecodeError` or
-        # `RuntimeWarning`, so we also check both cases below.
-        old_event = event_pb2.Event()
-        old_event.step = 0
-        old_event.wall_time = 456.75
-        # Careful: some proto parsers choke on byte arrays filled with 0, but
-        # others don't (silently producing an empty proto, I guess).
-        # Thus `old_event.graph_def = bytes(1024)` is an unreliable example.
-        old_event.graph_def = b"<malformed>"
-
-        new_events = self._migrate_event(
-            old_event, experimental_filter_graph=True
-        )
-        # _migrate_event emits both the original event and the migrated event,
-        # but here there is no migrated event because the graph was unparseable.
-        self.assertLen(new_events, 1)
-        self.assertProtoEquals(new_events[0], old_event)
-
-    def test_graph_def_experimental_filter_graph_DecodeError(self):
-        # Simulate raising DecodeError when parsing a graph event
-        old_event = event_pb2.Event()
-        old_event.step = 0
-        old_event.wall_time = 456.75
-        old_event.graph_def = b"<malformed>"
-
-        with mock.patch(
-            "tensorboard.compat.proto.graph_pb2.GraphDef"
-        ) as mockGraphDef:
-            instance = mockGraphDef.return_value
-            instance.FromString.side_effect = message.DecodeError
-
-            new_events = self._migrate_event(
-                old_event, experimental_filter_graph=True
-            )
-
-        # _migrate_event emits both the original event and the migrated event,
-        # but here there is no migrated event because the graph was unparseable.
-        self.assertLen(new_events, 1)
-        self.assertProtoEquals(new_events[0], old_event)
-
-    def test_graph_def_experimental_filter_graph_RuntimeWarning(self):
-        # Simulate raising RuntimeWarning when parsing a graph event
-        old_event = event_pb2.Event()
-        old_event.step = 0
-        old_event.wall_time = 456.75
-        old_event.graph_def = b"<malformed>"
-
-        with mock.patch(
-            "tensorboard.compat.proto.graph_pb2.GraphDef"
-        ) as mockGraphDef:
-            instance = mockGraphDef.return_value
-            instance.FromString.side_effect = RuntimeWarning
-
-            new_events = self._migrate_event(
-                old_event, experimental_filter_graph=True
-            )
-
-        # _migrate_event emits both the original event and the migrated event,
-        # but here there is no migrated event because the graph was unparseable.
-        self.assertLen(new_events, 1)
-        self.assertProtoEquals(new_events[0], old_event)
-
 
 if __name__ == "__main__":
     tf.test.main()
4 changes: 4 additions & 0 deletions tensorboard/uploader/BUILD
@@ -99,6 +99,7 @@ py_library(
         "//tensorboard:data_compat",
         "//tensorboard:dataclass_compat",
         "//tensorboard:expect_grpc_installed",
+        "//tensorboard/backend:process_graph",
         "//tensorboard/backend/event_processing:directory_loader",
         "//tensorboard/backend/event_processing:event_file_loader",
         "//tensorboard/backend/event_processing:io_wrapper",
@@ -109,6 +110,7 @@ py_library(
         "//tensorboard/util:grpc_util",
         "//tensorboard/util:tb_logging",
         "//tensorboard/util:tensor_util",
+        "@com_google_protobuf//:protobuf_python",
         "@org_pythonhosted_six",
     ],
 )
@@ -125,13 +127,15 @@ py_test(
         "//tensorboard:expect_grpc_testing_installed",
         "//tensorboard:expect_tensorflow_installed",
         "//tensorboard/compat/proto:protos_all_py_pb2",
+        "//tensorboard/plugins/graph:metadata",
         "//tensorboard/plugins/histogram:summary_v2",
         "//tensorboard/plugins/scalar:metadata",
         "//tensorboard/plugins/scalar:summary_v2",
         "//tensorboard/summary:summary_v1",
         "//tensorboard/uploader/proto:protos_all_py_pb2",
         "//tensorboard/uploader/proto:protos_all_py_pb2_grpc",
         "//tensorboard/util:test_util",
+        "@com_google_protobuf//:protobuf_python",
         "@org_pythonhosted_mock",
     ],
 )
54 changes: 48 additions & 6 deletions tensorboard/uploader/uploader.py
@@ -25,16 +25,21 @@
 import grpc
 import six
 
+from google.protobuf import message
+from tensorboard.compat.proto import graph_pb2
 from tensorboard.compat.proto import summary_pb2
+from tensorboard.compat.proto import types_pb2
 from tensorboard.uploader.proto import write_service_pb2
 from tensorboard.uploader.proto import experiment_pb2
 from tensorboard.uploader import logdir_loader
 from tensorboard.uploader import util
 from tensorboard import data_compat
 from tensorboard import dataclass_compat
+from tensorboard.backend import process_graph
 from tensorboard.backend.event_processing import directory_loader
 from tensorboard.backend.event_processing import event_file_loader
 from tensorboard.backend.event_processing import io_wrapper
+from tensorboard.plugins.graph import metadata as graphs_metadata
 from tensorboard.plugins.scalar import metadata as scalar_metadata
 from tensorboard.util import grpc_util
 from tensorboard.util import tb_logging
@@ -425,12 +430,11 @@ def _run_values(self, run_to_events):
         for (run_name, events) in six.iteritems(run_to_events):
             for event in events:
                 v2_event = data_compat.migrate_event(event)
-                dataclass_events = dataclass_compat.migrate_event(
-                    v2_event, experimental_filter_graph=True
-                )
-                for dataclass_event in dataclass_events:
-                    if dataclass_event.summary:
-                        for value in dataclass_event.summary.value:
+                events = dataclass_compat.migrate_event(v2_event)
Comment (Member):
Please don't shadow `events` (and `event` below) -- that is confusing and error-prone. `dataclass_events` was fine imho.

Reply (Contributor Author):
In my experience, this form of shadowing is less error-prone than the alternative, because it is never possible to accidentally refer to the old value. Case in point: prior to this change, the code was subtly incorrect, because the yield expression referred to `event` rather than `dataclass_event`. (This would be observable if a dataclass transformation were to affect the wall time or step—certainly allowed, and perfectly conceivable for conceptually run-level summaries like graphs or hparams.)

Shadowing makes that kind of error structurally impossible, and, as you correctly note, this shadowing is entirely lexical.

Reply (Contributor):
Just for what it's worth, all of this goes away in the subsequent PR anyway (https://github.com/tensorflow/tensorboard/pull/3511/files#diff-64ad30888691c0bf32cae63247f4ca5c).

Reply (Member):
I remain opposed to any shadowing whatsoever, but don't need to argue this case since it's going away in #3511.

(Eep, the former bug does suck, though; thanks for the fix).
+                events = _filter_graph_defs(events)
+                for event in events:
+                    if event.summary:
+                        for value in event.summary.value:
                             yield (run_name, event, value)


@@ -833,3 +837,41 @@ def _varint_cost(n):
         result += 1
         n >>= 7
     return result


+def _filter_graph_defs(events):
+    for e in events:
+        for v in e.summary.value:
+            if (
+                v.metadata.plugin_data.plugin_name
+                != graphs_metadata.PLUGIN_NAME
+            ):
+                continue
+            if v.tag == graphs_metadata.RUN_GRAPH_NAME:
Comment (Member):
This seems unnecessarily restrictive. In practice, for now, we have only run-level graphs anyway. But IIRC we expect to allow "__op_graph__/foo" and "__conceptual_graph__/bar" also, and those should also be filtered. NBD for now but maybe add a comment or TODO to address this later.

(I know this is a straight refactor of code that already ignored those cases, but still)

Reply (Contributor Author):
As you note, I'm porting existing code. Op graphs and conceptual graphs do not actually have plugin name `graphs`; instead, they use plugin names `graph_run_metadata_graph` and `graph_keras_model`, so clearly at least some special considerations will need to be taken, even before we consider whether the semantics should be the same. Attempting to anticipate those now seems premature to me.

Reply (Member):
Oh right, I forgot about the different plugin names. So this condition is protecting us from data that is associated with the graph plugin, but under a different tag. There is currently no such thing; if something arises in the future, we can't predict whether it should be filtered or not-- so now I agree the condition is correct.
+                data = list(v.tensor.string_val)
+                filtered_data = [_filtered_graph_bytes(x) for x in data]
+                filtered_data = [x for x in filtered_data if x is not None]
+                if filtered_data != data:
+                    new_tensor = tensor_util.make_tensor_proto(
+                        filtered_data, dtype=types_pb2.DT_STRING
+                    )
+                    v.tensor.CopyFrom(new_tensor)
+        yield e
+
+
+def _filtered_graph_bytes(graph_bytes):
+    try:
+        graph_def = graph_pb2.GraphDef().FromString(graph_bytes)
+    # The reason for the RuntimeWarning catch here is b/27494216, whereby
+    # some proto parsers incorrectly raise that instead of DecodeError
+    # on certain kinds of malformed input. Triggering this seems to require
+    # a combination of mysterious circumstances.
+    except (message.DecodeError, RuntimeWarning):
+        logger.warning(
+            "Could not parse GraphDef of size %d. Skipping.", len(graph_bytes),
+        )
+        return None
+    # Use the default filter parameters:
+    # limit_attr_size=1024, large_attrs_key="_too_large_attrs"
+    process_graph.prepare_graph_for_ui(graph_def)
+    return graph_def.SerializeToString()
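
For reference, the intended round-trip of `_filtered_graph_bytes` looks like this (a sketch, not part of the diff; it assumes the function and imports above are in scope, with the default filter parameters limit_attr_size=1024 and large_attrs_key="_too_large_attrs" noted in the comment):

# Build a graph with one oversized attribute.
graph_def = graph_pb2.GraphDef()
node = graph_def.node.add(name="n", op="Op")
node.attr["large"].s = b"x" * 2048  # exceeds the default 1024-byte limit

filtered = graph_pb2.GraphDef.FromString(
    _filtered_graph_bytes(graph_def.SerializeToString())
)
# The oversized attr is dropped and its key is recorded instead.
assert "large" not in filtered.node[0].attr
assert list(filtered.node[0].attr["_too_large_attrs"].list.s) == [b"large"]

# Unparseable input is skipped (returns None) rather than raising.
assert _filtered_graph_bytes(b"\x0a\x7fbogus") is None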
62 changes: 62 additions & 0 deletions tensorboard/uploader/uploader_test.py
@@ -33,6 +33,7 @@
 
 import tensorflow as tf
 
+from google.protobuf import message
 from tensorboard.uploader.proto import experiment_pb2
 from tensorboard.uploader.proto import scalar_pb2
 from tensorboard.uploader.proto import write_service_pb2
@@ -359,6 +360,67 @@ def test_upload_skip_large_blob(self):
         self.assertEqual(0, mock_rate_limiter.tick.call_count)
         self.assertEqual(1, mock_blob_rate_limiter.tick.call_count)
 
+    def test_filter_graphs(self):
Comment (Member):
Thanks for cleaning up and simplifying these tests! However, I am a little skeptical of this design because the tests are no longer independent. Can you factor out some common setup but keep the tests separate (while retaining most of the simplicity)?

Also, why did you remove the explicit tests for RuntimeWarning vs. DecodeError that @caisq had just requested in #3508?

Finally FYI #3509 is outstanding. NP, I'll merge it, assuming you submit first. The behavior of the empty graph case may actually change, depending on how you address the 'corrupt graph' case above.

Reply (Contributor Author):
> Thanks for cleaning up and simplifying these tests! However, I am a
> little skeptical of this design because the tests are no longer
> independent.

The tests do fail independently. Python `subTest` is roughly like the JS pattern of having a `describe` with setup followed by multiple `it`s. The setup runs once, and the cases run independently.

Thus:

import unittest


class Test(unittest.TestCase):
    def test(self):
        x = 2
        with self.subTest("should be one"):
            self.assertEqual(x, 1)
        with self.subTest("but also two"):
            self.assertEqual(x, 2)
        with self.subTest("and somehow three"):
            self.assertEqual(x, 3)


if __name__ == "__main__":
    unittest.main()

======================================================================
FAIL: test (__main__.Test) [should be one]
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/tmp/test.py", line 8, in test
    self.assertEqual(x, 1)
AssertionError: 2 != 1

======================================================================
FAIL: test (__main__.Test) [and somehow three]
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/tmp/test.py", line 12, in test
    self.assertEqual(x, 3)
AssertionError: 2 != 3

----------------------------------------------------------------------
Ran 1 test in 0.001s

FAILED (failures=2)

It's true that any side effects in one sub-test would be visible in the others, but all the potentially effectful logic clearly happens in the setup, not the assertions.

I'm not sure that I fully understand your concern; is this satisfactory?

> Also, why did you remove the explicit tests for RuntimeWarning vs.
> DecodeError that @caisq had just requested in #3508?

Mock-based tests are fragile, and the one proposed in #3508 is no exception. The code under test could well be refactored to use ParseFromString or MergeFromString instead of FromString, in which case the test would spuriously fail. I have spent far too much sanity over the past day, month, and year debugging and working around brittle mocks to be comfortable adding more of them without clear need.

The proposed test covers that the code does what we expect it to if a RuntimeWarning is raised from a certain place. That's not really a contract-level spec. As far as I have been able to tell, we have no idea how to actually trigger the offending case! (The internal bug has no comments from anyone on the TensorBoard team, and I haven't been able to find a reference to why we're even talking about this, even with the resources available to me as a Googler.) The code as written obviously satisfies the proposed test case—it's right there in the `except`—so the test serves only to protect against people accidentally deleting the `except` clause, which is already protected by the attached comment.

It's always tradeoffs in the end. It is my opinion that the test does not justify the substantial additional complexity that it would incur.

> Finally FYI #3509 is outstanding. NP, I'll merge it, assuming you
> submit first. The behavior of the empty graph case may actually
> change, depending on how you address the 'corrupt graph' case above.

Thanks.

Reply (Member):
OK for now for expediency. For the record, my concern is that the setup code is not generic; it includes setup specific to the tests (e.g., the values and order of the bytes examples). If I wanted to add another subtest--perhaps the one from #3509--I'd have to change the setup code here. I can see this both ways (i.e., I could also add more tests that use the same examples, without changing the setup). The fact that currently parts of the setup correspond 1:1 with specific tests makes it look more entangled than it necessarily is.
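
As an aside, a real DecodeError (unlike the RuntimeWarning case) is easy to trigger without mocks; a quick sketch using the same truncated bytes as `bytes_2` in the test below:

from google.protobuf import message
from tensorboard.compat.proto import graph_pb2

try:
    # Field 1, wire type 2, declared length 0x7f, but only 5 payload bytes.
    graph_pb2.GraphDef.FromString(b"\x0a\x7fbogus")
except message.DecodeError as e:
    print("DecodeError:", e)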

+        # Three graphs: one short, one long, one corrupt.
+        bytes_0 = _create_example_graph_bytes(123)
+        bytes_1 = _create_example_graph_bytes(9999)
+        # invalid (truncated) proto: length-delimited field 1 (0x0a) of
+        # length 0x7f specified, but only len("bogus") = 5 bytes given
+        # <https://developers.google.com/protocol-buffers/docs/encoding>
+        bytes_2 = b"\x0a\x7fbogus"
+
+        logdir = self.get_temp_dir()
+        for (i, b) in enumerate([bytes_0, bytes_1, bytes_2]):
+            run_dir = os.path.join(logdir, "run_%04d" % i)
+            event = event_pb2.Event(step=0, wall_time=123 * i, graph_def=b)
+            with tb_test_util.FileWriter(run_dir) as writer:
+                writer.add_event(event)
+
+        limiter = mock.create_autospec(util.RateLimiter)
+        limiter.tick.side_effect = [None, AbortUploadError]
+        mock_client = _create_mock_client()
+        uploader = _create_uploader(
+            mock_client,
+            logdir,
+            logdir_poll_rate_limiter=limiter,
+            allowed_plugins=[
+                scalars_metadata.PLUGIN_NAME,
+                graphs_metadata.PLUGIN_NAME,
+            ],
+        )
+        uploader.create_experiment()
+
+        with self.assertRaises(AbortUploadError):
+            uploader.start_uploading()
+
+        actual_blobs = []
+        for call in mock_client.WriteBlob.call_args_list:
+            requests = call[0][0]
+            actual_blobs.append(b"".join(r.data for r in requests))
+
+        actual_graph_defs = []
+        for blob in actual_blobs:
+            try:
+                actual_graph_defs.append(graph_pb2.GraphDef.FromString(blob))
+            except message.DecodeError:
+                actual_graph_defs.append(None)
+
+        with self.subTest("graphs with small attr values should be unchanged"):
+            expected_graph_def_0 = graph_pb2.GraphDef.FromString(bytes_0)
+            self.assertEqual(actual_graph_defs[0], expected_graph_def_0)
+
+        with self.subTest("large attr values should be filtered out"):
+            expected_graph_def_1 = graph_pb2.GraphDef.FromString(bytes_1)
+            del expected_graph_def_1.node[1].attr["large"]
+            expected_graph_def_1.node[1].attr["_too_large_attrs"].list.s.append(
+                b"large"
+            )
+            requests = list(mock_client.WriteBlob.call_args[0][0])
+            self.assertEqual(actual_graph_defs[1], expected_graph_def_1)
+
+        with self.subTest("corrupt graphs should be skipped"):
+            self.assertLen(actual_blobs, 2)
+
     def test_upload_server_error(self):
         mock_client = _create_mock_client()
         mock_rate_limiter = mock.create_autospec(util.RateLimiter)