Skip to content

Commit

Permalink
Avoid possible infinite recursion when writing pyc files in assert re…
Browse files Browse the repository at this point in the history
…write

What happens is that atomic_write on Python 2.7 on Windows will try
to convert the paths to unicode, but this triggers the import of
the encoding module for the file system codec, which in turn triggers
the rewrite, which in turn again tries to import the module, and so on.

This short-circuits the cases where we try to import another file when
writing a pyc file; I don't expect this to affect anything because
the only modules that could be affected are those imported by
atomic_writes.

Fix pytest-dev#3506
  • Loading branch information
nicoddemus committed Aug 28, 2018
1 parent 9620b16 commit 14aa4d2
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 1 deletion.
1 change: 1 addition & 0 deletions changelog/3506.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix possible infinite recursion when writing ``.pyc`` files.
11 changes: 10 additions & 1 deletion src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,16 @@ def __init__(self, config):
self._rewritten_names = set()
self._register_with_pkg_resources()
self._must_rewrite = set()
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
# which might result in infinite recursion (#3506)
self._writing_pyc = False

def set_session(self, session):
self.session = session

def find_module(self, name, path=None):
if self._writing_pyc:
return None
state = self.config._assertstate
state.trace("find_module called for: %s" % name)
names = name.rsplit(".", 1)
Expand Down Expand Up @@ -151,7 +156,11 @@ def find_module(self, name, path=None):
# Probably a SyntaxError in the test.
return None
if write:
_write_pyc(state, co, source_stat, pyc)
self._writing_pyc = True
try:
_write_pyc(state, co, source_stat, pyc)
finally:
self._writing_pyc = False
else:
state.trace("found cached rewritten pyc for %r" % (fn,))
self.modules[name] = co, pyc
Expand Down
28 changes: 28 additions & 0 deletions testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,3 +1124,31 @@ def test_simple_failure():

result = testdir.runpytest()
result.stdout.fnmatch_lines("*E*assert (1 + 1) == 3")


def test_rewrite_infinite_recursion(testdir, pytestconfig, monkeypatch):
"""Fix infinite recursion when writing pyc files: if an import happens to be triggered when writing the pyc
file, this would cause another call to the hook, which would trigger another pyc writing, which could
trigger another import, and so on. (#3506)"""
from _pytest.assertion import rewrite

testdir.syspathinsert()
testdir.makepyfile(test_foo="def test(): pass")

original_write_pyc = rewrite._write_pyc

write_pyc_called = []

def spy_write_pyc(*args, **kwargs):
# make a note that we have called _write_pyc
write_pyc_called.append(True)
# try to import a module at this point: we should not try to rewrite this module
assert hook.find_module("test_foo") is None
return original_write_pyc(*args, **kwargs)

monkeypatch.setattr(rewrite, "_write_pyc", spy_write_pyc)
monkeypatch.setattr(sys, "dont_write_bytecode", False)

hook = AssertionRewritingHook(pytestconfig)
assert hook.find_module("test_foo") is not None
assert len(write_pyc_called) == 1

0 comments on commit 14aa4d2

Please sign in to comment.