Skip to content

Commit d348cad

Browse files
authored
Make mesh plugin tests TF 2.x–compatible (#2560)
Summary: These just needed to drop down to graph mode in a few places. The metadata test actually already worked in TF 2.x! :-) Makes progress toward #1705. Test Plan: Tests pass in both TF 1.x and TF 2.x, and the mesh plugin no longer uses `run_v1_only`: ``` $ git grep run_v1_only '*mesh*' | wc -l 0 ``` wchargin-branch: mesh-tests-tf2x
1 parent 2dce496 commit d348cad

File tree

4 files changed

+93
-96
lines changed

4 files changed

+93
-96
lines changed

tensorboard/plugins/mesh/mesh_plugin_test.py

Lines changed: 65 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
import mock # pylint: disable=g-import-not-at-top,unused-import
4545

4646

47-
@tensorboard_test_util.run_v1_only('requires tf.Session')
4847
class MeshPluginTest(tf.test.TestCase):
4948
"""Tests for mesh plugin server."""
5049

@@ -57,73 +56,71 @@ def setUp(self):
5756
self.log_dir = self.get_temp_dir()
5857

5958
# Create mesh summary.
60-
tf.compat.v1.reset_default_graph()
61-
sess = tf.compat.v1.Session()
62-
point_cloud = test_utils.get_random_mesh(1000)
63-
point_cloud_vertices = tf.compat.v1.placeholder(tf.float32,
64-
point_cloud.vertices.shape)
65-
66-
mesh_no_color = test_utils.get_random_mesh(2000, add_faces=True)
67-
mesh_no_color_extended = test_utils.get_random_mesh(2500, add_faces=True)
68-
mesh_no_color_vertices = tf.compat.v1.placeholder(
69-
tf.float32, [1, None, 3])
70-
mesh_no_color_faces = tf.compat.v1.placeholder(tf.int32,
71-
[1, None, 3])
72-
73-
mesh_color = test_utils.get_random_mesh(
74-
3000, add_faces=True, add_colors=True)
75-
mesh_color_vertices = tf.compat.v1.placeholder(tf.float32,
76-
mesh_color.vertices.shape)
77-
mesh_color_faces = tf.compat.v1.placeholder(tf.int32,
78-
mesh_color.faces.shape)
79-
mesh_color_colors = tf.compat.v1.placeholder(tf.uint8,
80-
mesh_color.colors.shape)
81-
self.data = [
82-
point_cloud, mesh_no_color, mesh_no_color_extended, mesh_color]
83-
84-
# In case when name is present and display_name is not, we will reuse name
85-
# as display_name. Summaries below intended to test both cases.
86-
self.names = ["point_cloud", "mesh_no_color", "mesh_color"]
87-
summary.op(
88-
self.names[0],
89-
point_cloud_vertices,
90-
description="just point cloud")
91-
summary.op(
92-
self.names[1],
93-
mesh_no_color_vertices,
94-
faces=mesh_no_color_faces,
95-
display_name="name_to_display_in_ui",
96-
description="beautiful mesh in grayscale")
97-
summary.op(
98-
self.names[2],
99-
mesh_color_vertices,
100-
faces=mesh_color_faces,
101-
colors=mesh_color_colors,
102-
description="mesh with random colors")
103-
104-
merged_summary_op = tf.compat.v1.summary.merge_all()
105-
self.runs = ["bar"]
106-
self.steps = 20
107-
bar_directory = os.path.join(self.log_dir, self.runs[0])
108-
with tensorboard_test_util.FileWriterCache.get(bar_directory) as writer:
109-
writer.add_graph(sess.graph)
110-
for step in range(self.steps):
111-
# Alternate between two random meshes with different number of
112-
# vertices.
113-
no_color = mesh_no_color if step % 2 == 0 else mesh_no_color_extended
114-
with patch.object(time, 'time', return_value=step):
115-
writer.add_summary(
116-
sess.run(
117-
merged_summary_op,
118-
feed_dict={
119-
point_cloud_vertices: point_cloud.vertices,
120-
mesh_no_color_vertices: no_color.vertices,
121-
mesh_no_color_faces: no_color.faces,
122-
mesh_color_vertices: mesh_color.vertices,
123-
mesh_color_faces: mesh_color.faces,
124-
mesh_color_colors: mesh_color.colors,
125-
}),
126-
global_step=step)
59+
with tf.compat.v1.Graph().as_default():
60+
tf_placeholder = tf.compat.v1.placeholder
61+
sess = tf.compat.v1.Session()
62+
point_cloud = test_utils.get_random_mesh(1000)
63+
point_cloud_vertices = tf_placeholder(
64+
tf.float32, point_cloud.vertices.shape
65+
)
66+
67+
mesh_no_color = test_utils.get_random_mesh(2000, add_faces=True)
68+
mesh_no_color_extended = test_utils.get_random_mesh(2500, add_faces=True)
69+
mesh_no_color_vertices = tf_placeholder(tf.float32, [1, None, 3])
70+
mesh_no_color_faces = tf_placeholder(tf.int32, [1, None, 3])
71+
72+
mesh_color = test_utils.get_random_mesh(
73+
3000, add_faces=True, add_colors=True)
74+
mesh_color_vertices = tf_placeholder(tf.float32, mesh_color.vertices.shape)
75+
mesh_color_faces = tf_placeholder(tf.int32, mesh_color.faces.shape)
76+
mesh_color_colors = tf_placeholder(tf.uint8, mesh_color.colors.shape)
77+
78+
self.data = [
79+
point_cloud, mesh_no_color, mesh_no_color_extended, mesh_color]
80+
81+
# In case when name is present and display_name is not, we will reuse name
82+
# as display_name. Summaries below intended to test both cases.
83+
self.names = ["point_cloud", "mesh_no_color", "mesh_color"]
84+
summary.op(
85+
self.names[0],
86+
point_cloud_vertices,
87+
description="just point cloud")
88+
summary.op(
89+
self.names[1],
90+
mesh_no_color_vertices,
91+
faces=mesh_no_color_faces,
92+
display_name="name_to_display_in_ui",
93+
description="beautiful mesh in grayscale")
94+
summary.op(
95+
self.names[2],
96+
mesh_color_vertices,
97+
faces=mesh_color_faces,
98+
colors=mesh_color_colors,
99+
description="mesh with random colors")
100+
101+
merged_summary_op = tf.compat.v1.summary.merge_all()
102+
self.runs = ["bar"]
103+
self.steps = 20
104+
bar_directory = os.path.join(self.log_dir, self.runs[0])
105+
with tensorboard_test_util.FileWriterCache.get(bar_directory) as writer:
106+
writer.add_graph(sess.graph)
107+
for step in range(self.steps):
108+
# Alternate between two random meshes with different number of
109+
# vertices.
110+
no_color = mesh_no_color if step % 2 == 0 else mesh_no_color_extended
111+
with patch.object(time, 'time', return_value=step):
112+
writer.add_summary(
113+
sess.run(
114+
merged_summary_op,
115+
feed_dict={
116+
point_cloud_vertices: point_cloud.vertices,
117+
mesh_no_color_vertices: no_color.vertices,
118+
mesh_no_color_faces: no_color.faces,
119+
mesh_color_vertices: mesh_color.vertices,
120+
mesh_color_faces: mesh_color.faces,
121+
mesh_color_colors: mesh_color.colors,
122+
}),
123+
global_step=step)
127124

128125
# Start a server that will receive requests.
129126
self.multiplexer = event_multiplexer.EventMultiplexer({

tensorboard/plugins/mesh/metadata_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from tensorboard.util import test_util
2727

2828

29-
@test_util.run_v1_only('requires tf.Session')
3029
class MetadataTest(tf.test.TestCase):
3130

3231
def _create_metadata(self, shape=None):

tensorboard/plugins/mesh/summary.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _get_tensor_summary(
6464
shape,
6565
description,
6666
json_config=json_config)
67-
tensor_summary = tf.summary.tensor_summary(
67+
tensor_summary = tf.compat.v1.summary.tensor_summary(
6868
metadata.get_instance_name(name, content_type),
6969
tensor,
7070
summary_metadata=tensor_metadata,
@@ -135,7 +135,7 @@ def op(name, vertices, faces=None, colors=None, display_name=None,
135135
tensor.content_type, components, json_config,
136136
collections))
137137

138-
all_summaries = tf.summary.merge(
138+
all_summaries = tf.compat.v1.summary.merge(
139139
summaries, collections=collections, name=name)
140140
return all_summaries
141141

@@ -196,9 +196,9 @@ def pb(name,
196196
tag = metadata.get_instance_name(name, tensor.content_type)
197197
summaries.append((tag, summary_metadata, tensor_proto))
198198

199-
summary = tf.Summary()
199+
summary = tf.compat.v1.Summary()
200200
for tag, summary_metadata, tensor_proto in summaries:
201-
tf_summary_metadata = tf.SummaryMetadata.FromString(
201+
tf_summary_metadata = tf.compat.v1.SummaryMetadata.FromString(
202202
summary_metadata.SerializeToString())
203203
summary.value.add(
204204
tag=tag, metadata=tf_summary_metadata, tensor=tensor_proto)

tensorboard/plugins/mesh/summary_test.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from tensorboard.util import test_util
2929

3030

31-
@test_util.run_v1_only('requires tf.Session')
3231
class MeshSummaryTest(tf.test.TestCase):
3332

3433
def pb_via_op(self, summary_op):
@@ -60,35 +59,37 @@ def test_get_tensor_summary(self):
6059
description = "my mesh is the best of meshes"
6160
tensor_data = test_utils.get_random_mesh(100)
6261
components = 14
63-
tensor_summary = summary._get_tensor_summary(
64-
name, display_name, description, tensor_data.vertices,
65-
plugin_data_pb2.MeshPluginData.VERTEX, components, "", None)
66-
with self.test_session():
67-
proto = self.pb_via_op(tensor_summary)
68-
self.assertEqual("%s_VERTEX" % name, proto.value[0].tag)
69-
self.assertEqual(metadata.PLUGIN_NAME,
70-
proto.value[0].metadata.plugin_data.plugin_name)
71-
self.assertEqual(components, self.get_components(proto.value[0]))
62+
with tf.compat.v1.Graph().as_default():
63+
tensor_summary = summary._get_tensor_summary(
64+
name, display_name, description, tensor_data.vertices,
65+
plugin_data_pb2.MeshPluginData.VERTEX, components, "", None)
66+
with self.test_session():
67+
proto = self.pb_via_op(tensor_summary)
68+
self.assertEqual("%s_VERTEX" % name, proto.value[0].tag)
69+
self.assertEqual(metadata.PLUGIN_NAME,
70+
proto.value[0].metadata.plugin_data.plugin_name)
71+
self.assertEqual(components, self.get_components(proto.value[0]))
7272

7373
def test_op(self):
7474
"""Tests merged summary with different types of data."""
7575
name = "my_mesh"
7676
tensor_data = test_utils.get_random_mesh(
7777
100, add_faces=True, add_colors=True)
7878
config_dict = {"foo": 1}
79-
tensor_summary = summary.op(
80-
name,
81-
tensor_data.vertices,
82-
faces=tensor_data.faces,
83-
colors=tensor_data.colors,
84-
config_dict=config_dict)
85-
with self.test_session():
86-
proto = self.pb_via_op(tensor_summary)
87-
self.verify_proto(proto, name)
88-
plugin_metadata = metadata.parse_plugin_metadata(
89-
proto.value[0].metadata.plugin_data.content)
90-
self.assertEqual(
91-
json.dumps(config_dict, sort_keys=True), plugin_metadata.json_config)
79+
with tf.compat.v1.Graph().as_default():
80+
tensor_summary = summary.op(
81+
name,
82+
tensor_data.vertices,
83+
faces=tensor_data.faces,
84+
colors=tensor_data.colors,
85+
config_dict=config_dict)
86+
with self.test_session() as sess:
87+
proto = self.pb_via_op(tensor_summary)
88+
self.verify_proto(proto, name)
89+
plugin_metadata = metadata.parse_plugin_metadata(
90+
proto.value[0].metadata.plugin_data.content)
91+
self.assertEqual(
92+
json.dumps(config_dict, sort_keys=True), plugin_metadata.json_config)
9293

9394
def test_pb(self):
9495
"""Tests merged summary protobuf with different types of data."""

0 commit comments

Comments
 (0)