Skip to content

Commit f87e787

Browse files
committed
Skip tests that require an OpenGL context if DISABLE_MUJOCO_RENDERING is set
Also fixed a Python3 incompatibility in `engine.py` and added a missing test dependency on scipy. PiperOrigin-RevId: 183157102
1 parent 3b4552d commit f87e787

10 files changed

+48
-22
lines changed

dm_control/mujoco/engine.py

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252

5353
import numpy as np
5454
import six
55+
from six.moves import xrange # pylint: disable=redefined-builtin
5556

5657
from dm_control.rl import specs
5758

dm_control/mujoco/engine_test.py

+11
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,14 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21+
import unittest
22+
2123
# Internal dependencies.
2224

2325
from absl.testing import absltest
2426
from absl.testing import parameterized
2527

28+
from dm_control import render
2629
from dm_control.mujoco import engine
2730
from dm_control.mujoco import wrapper
2831
from dm_control.mujoco.testing import assets
@@ -63,12 +66,14 @@ def _assert_attributes_equal(self, actual_obj, expected_obj, attr_to_compare):
6366
raise AssertionError("Attribute '{}' differs from expected value. {}"
6467
"".format(name, e.message))
6568

69+
@unittest.skipIf(render.DISABLED, reason=render.DISABLED_MESSAGE)
6670
@parameterized.parameters(0, 'cart', u'cart')
6771
def testCameraIndexing(self, camera_id):
6872
height, width = 480, 640
6973
_ = engine.Camera(
7074
self._physics, height, width, camera_id=camera_id)
7175

76+
@unittest.skipIf(render.DISABLED, reason=render.DISABLED_MESSAGE)
7277
def testDepthRender(self):
7378
plane_and_box = """
7479
<mujoco>
@@ -86,6 +91,7 @@ def testDepthRender(self):
8691
# Furthest pixels should be 3m away (depth is orthographic)
8792
np.testing.assert_approx_equal(pixels.max(), 3.0, 3)
8893

94+
@unittest.skipIf(render.DISABLED, reason=render.DISABLED_MESSAGE)
8995
def testTextOverlay(self):
9096
height, width = 480, 640
9197
overlay = engine.TextOverlay(title='Title', body='Body', style='big',
@@ -97,6 +103,7 @@ def testTextOverlay(self):
97103
self.assertFalse(np.all(no_overlay == with_overlay),
98104
msg='Images are identical with and without text overlay.')
99105

106+
@unittest.skipIf(render.DISABLED, reason=render.DISABLED_MESSAGE)
100107
def testSceneOption(self):
101108
height, width = 480, 640
102109
scene_option = wrapper.MjvOption()
@@ -115,6 +122,7 @@ def testSceneOption(self):
115122
((0.5, 0.1), (0, 0)), # ground
116123
((0.9, 0.9), (None, None)), # sky
117124
)
125+
@unittest.skipIf(render.DISABLED, reason=render.DISABLED_MESSAGE)
118126
def testCameraSelection(self, coordinates, expected_selection):
119127
height, width = 480, 640
120128
camera = engine.Camera(self._physics, height, width, camera_id=0)
@@ -127,6 +135,7 @@ def testCameraSelection(self, coordinates, expected_selection):
127135
selected = camera.select(coordinates)
128136
self.assertEqual(expected_selection, selected[:2])
129137

138+
@unittest.skipIf(render.DISABLED, reason=render.DISABLED_MESSAGE)
130139
def testMovableCameraSetGetPose(self):
131140
height, width = 240, 320
132141

@@ -154,6 +163,7 @@ def testMovableCameraSetGetPose(self):
154163

155164
self.assertFalse(np.all(image == camera.render()))
156165

166+
@unittest.skipIf(render.DISABLED, reason=render.DISABLED_MESSAGE)
157167
def testRenderExceptions(self):
158168
max_width = self._physics.model.vis.global_.offwidth
159169
max_height = self._physics.model.vis.global_.offheight
@@ -167,6 +177,7 @@ def testRenderExceptions(self):
167177
with self.assertRaisesRegexp(ValueError, 'camera_id'):
168178
self._physics.render(max_height, max_width, camera_id=-2)
169179

180+
@unittest.skipIf(render.DISABLED, reason=render.DISABLED_MESSAGE)
170181
def testPhysicsRenderMethod(self):
171182
height, width = 240, 320
172183
image = self._physics.render(height=height, width=width)

dm_control/mujoco/render_test.py

+3
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020
from __future__ import print_function
2121

2222
import os
23+
import unittest
2324

2425
# Internal dependencies.
2526
from absl.testing import absltest
2627
from absl.testing import parameterized
28+
from dm_control import render
2729
from dm_control.mujoco.testing import decorators
2830
from dm_control.mujoco.testing import image_utils
2931
from six.moves import zip # pylint: disable=redefined-builtin
@@ -36,6 +38,7 @@
3638
CALLS_PER_THREAD = 1
3739

3840

41+
@unittest.skipIf(render.DISABLED, render.DISABLED_MESSAGE)
3942
class RenderTest(parameterized.TestCase):
4043

4144
@parameterized.named_parameters(image_utils.SEQUENCES.items())

dm_control/mujoco/thread_safety_test.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121

22+
import unittest
23+
2224
# Internal dependencies.
2325

2426
from absl.testing import absltest
25-
27+
from dm_control import render
2628
from dm_control.mujoco import engine
2729
from dm_control.mujoco.testing import assets
2830
from dm_control.mujoco.testing import decorators
@@ -68,6 +70,7 @@ def test_load_and_step_multiple_physics_sequential(self):
6870
for _ in xrange(NUM_STEPS):
6971
physics2.step()
7072

73+
@unittest.skipIf(render.DISABLED, render.DISABLED_MESSAGE)
7174
@decorators.run_threaded(calls_per_thread=5)
7275
def test_load_physics_and_render(self):
7376
physics = engine.Physics.from_xml_string(MODEL)
@@ -83,6 +86,7 @@ def test_load_physics_and_render(self):
8386

8487
self.assertEqual(NUM_STEPS, len(unique_frames))
8588

89+
@unittest.skipIf(render.DISABLED, render.DISABLED_MESSAGE)
8690
@decorators.run_threaded(calls_per_thread=5)
8791
def test_render_multiple_physics_instances_per_thread_parallel(self):
8892
physics1 = engine.Physics.from_xml_string(MODEL)

dm_control/render/glfw_renderer.py

+15-11
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,25 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121

22+
import sys
23+
2224
# Internal dependencies.
2325

2426
from dm_control.render import base
25-
import glfw
26-
27-
_done_init_glfw = False
28-
27+
import six
2928

30-
def _maybe_init_glfw():
31-
global _done_init_glfw
32-
if not _done_init_glfw:
33-
if not glfw.init():
34-
raise OSError('Failed to initialize GLFW.')
35-
_done_init_glfw = True
29+
# Re-raise any exceptions that occur during module import as `ImportError`s.
30+
# This simplifies the conditional imports in `render/__init__.py`.
31+
try:
32+
import glfw # pylint: disable=g-import-not-at-top
33+
except (ImportError, IOError, OSError) as exc:
34+
_, exc, tb = sys.exc_info()
35+
six.reraise(ImportError, ImportError(str(exc)), tb)
36+
try:
37+
glfw.init()
38+
except glfw.GLFWError as exc:
39+
_, exc, tb = sys.exc_info()
40+
six.reraise(ImportError, ImportError(str(exc)), tb)
3641

3742

3843
class GLFWContext(base.ContextBase):
@@ -46,7 +51,6 @@ def __init__(self, max_width, max_height):
4651
max_height: Integer specifying the maximum framebuffer height in pixels.
4752
"""
4853
super(GLFWContext, self).__init__()
49-
_maybe_init_glfw()
5054
glfw.window_hint(glfw.VISIBLE, 0)
5155
glfw.window_hint(glfw.DOUBLEBUFFER, 0)
5256
self._context = glfw.create_window(width=max_width, height=max_height,

dm_control/render/glfw_renderer_test.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,28 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121

22-
# Internal dependencies.
22+
import unittest
2323

24+
# Internal dependencies.
2425
from absl.testing import absltest
25-
26-
from dm_control.render import glfw_renderer
27-
26+
from dm_control import render
2827
import mock
2928

3029
MAX_WIDTH = 1024
3130
MAX_HEIGHT = 1024
32-
CONTEXT_PATH = glfw_renderer.__name__ + ".glfw"
31+
CONTEXT_PATH = render.__name__ + '.glfw_renderer.glfw'
3332

3433

34+
@unittest.skipUnless(render._GLFWRenderer,
35+
reason='GLFW renderer could not be imported.')
3536
@mock.patch(CONTEXT_PATH)
3637
class GLFWContextTest(absltest.TestCase):
3738

3839
def setUp(self):
3940
self.context = mock.MagicMock()
4041

4142
with mock.patch(CONTEXT_PATH):
42-
self.renderer = glfw_renderer.GLFWContext(MAX_WIDTH, MAX_HEIGHT)
43+
self.renderer = render.Renderer(MAX_WIDTH, MAX_HEIGHT)
4344

4445
def tearDown(self):
4546
self.renderer._context = None
@@ -61,5 +62,5 @@ def test_freeing(self, mock_glfw):
6162
self.assertIsNone(self.renderer._previous_context)
6263

6364

64-
if __name__ == "__main__":
65+
if __name__ == '__main__':
6566
absltest.main()

dm_control/suite/utils/__init__.py

-2
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,3 @@
1414
# ============================================================================
1515

1616
"""Utility functions used in the control suite."""
17-
18-
from dm_control.suite.utils import randomizers

dm_control/suite/wrappers/pixels_test.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,14 @@
2020
from __future__ import print_function
2121

2222
import collections
23+
import unittest
2324

2425
# Internal dependencies.
2526

2627
from absl.testing import absltest
2728
from absl.testing import parameterized
2829

29-
30+
from dm_control import render
3031
from dm_control.rl import environment
3132
from dm_control.rl import specs
3233
from dm_control.suite import cartpole
@@ -62,6 +63,7 @@ def observation_spec(self):
6263
return specs.ArraySpec(shape=(2,), dtype=np.float)
6364

6465

66+
@unittest.skipIf(render.DISABLED, reason=render.DISABLED_MESSAGE)
6567
class PixelsTest(parameterized.TestCase):
6668

6769
@parameterized.parameters(True, False)

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ nose==1.3.7
88
numpy==1.13.3
99
pillow==5.0.0
1010
pyparsing==2.2.0
11+
scipy==1.0.0
1112
six==1.11.0

setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ def run(self):
163163
'mock',
164164
'nose',
165165
'pillow',
166+
'scipy',
166167
],
167168
test_suite='nose.collector',
168169
packages=find_packages(),

0 commit comments

Comments
 (0)