diff --git a/src/_pytest/capture.py b/src/_pytest/capture.py index 25650f3e5f6..df5802cabdf 100644 --- a/src/_pytest/capture.py +++ b/src/_pytest/capture.py @@ -10,6 +10,8 @@ import sys from io import UnsupportedOperation from tempfile import TemporaryFile +from typing import Callable +from typing import List import pytest from _pytest.compat import CaptureIO @@ -77,19 +79,28 @@ def __init__(self, method): self._method = method self._global_capturing = None self._current_item = None + self._atexit_funcs: List[Callable] = [] + atexit.register(self._atexit_run) def __repr__(self): return "".format( self._method, self._global_capturing, self._current_item ) + def _atexit_register(self, func): + self._atexit_funcs.append(func) + + def _atexit_run(self): + for func in self._atexit_funcs: + func() + def _getcapture(self, method): if method == "fd": - return MultiCapture(out=True, err=True, Capture=FDCapture) + return MultiCapture(out=True, err=True, Capture=FDCapture, capman=self) elif method == "sys": - return MultiCapture(out=True, err=True, Capture=SysCapture) + return MultiCapture(out=True, err=True, Capture=SysCapture, capman=self) elif method == "no": - return MultiCapture(out=False, err=False, in_=False) + return MultiCapture(out=False, err=False, in_=False, capman=self) raise ValueError("unknown capturing method: %r" % method) # pragma: no cover def is_capturing(self): @@ -451,13 +462,13 @@ class MultiCapture: out = err = in_ = None _state = None - def __init__(self, out=True, err=True, in_=True, Capture=None): + def __init__(self, out=True, err=True, in_=True, Capture=None, capman: CaptureManager = None): if in_: - self.in_ = Capture(0) + self.in_ = Capture(0, capman=capman) if out: - self.out = Capture(1) + self.out = Capture(1, capman=capman) if err: - self.err = Capture(2) + self.err = Capture(2, capman=capman) def __repr__(self): return "".format( @@ -540,8 +551,9 @@ class FDCaptureBinary: EMPTY_BUFFER = b"" _state = None - def __init__(self, targetfd, tmpfile=None): + def __init__(self, targetfd, tmpfile=None, capman: CaptureManager = None): self.targetfd = targetfd + self._capman = capman try: self.targetfd_save = os.dup(self.targetfd) except OSError: @@ -553,14 +565,14 @@ def __init__(self, targetfd, tmpfile=None): if targetfd == 0: assert not tmpfile, "cannot set tmpfile with stdin" tmpfile = open(os.devnull, "r") - self.syscapture = SysCapture(targetfd) + self.syscapture = SysCapture(targetfd, capman=self._capman) else: if tmpfile is None: f = TemporaryFile() with f: tmpfile = safe_text_dupfile(f, mode="wb+") if targetfd in patchsysdict: - self.syscapture = SysCapture(targetfd, tmpfile) + self.syscapture = SysCapture(targetfd, tmpfile, capman) else: self.syscapture = NoCapture() self.tmpfile = tmpfile @@ -595,9 +607,12 @@ def _done(self): os.dup2(targetfd_save, self.targetfd) os.close(targetfd_save) self.syscapture.done() - # Redirect any remaining output. - os.dup2(self.targetfd, self.tmpfile_fd) - atexit.register(self.tmpfile.close) + if self._capman: + # Redirect any remaining output. + os.dup2(self.targetfd, self.tmpfile_fd) + self._capman._atexit_register(self.tmpfile.close) + else: + self.tmpfile.close() self._state = "done" def suspend(self): @@ -639,8 +654,9 @@ class SysCapture: EMPTY_BUFFER = str() _state = None - def __init__(self, fd, tmpfile=None): + def __init__(self, fd, tmpfile=None, capman: CaptureManager = None): name = patchsysdict[fd] + self._capman = capman self._old = getattr(sys, name) self.name = name if tmpfile is None: @@ -668,7 +684,10 @@ def snap(self): def done(self): setattr(sys, self.name, self._old) del self._old - atexit.register(self.tmpfile.close) + if self._capman: + self._capman._atexit_register(self.tmpfile.close) + else: + self.tmpfile.close() self._state = "done" def suspend(self):