Skip to content

Commit

Permalink
wip: atexit
Browse files Browse the repository at this point in the history
  • Loading branch information
blueyed committed Oct 23, 2019
1 parent 60c365a commit b9d7d91
Showing 1 changed file with 34 additions and 15 deletions.
49 changes: 34 additions & 15 deletions src/_pytest/capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import sys
from io import UnsupportedOperation
from tempfile import TemporaryFile
from typing import Callable
from typing import List

import pytest
from _pytest.compat import CaptureIO
Expand Down Expand Up @@ -77,19 +79,28 @@ def __init__(self, method):
self._method = method
self._global_capturing = None
self._current_item = None
self._atexit_funcs: List[Callable] = []
atexit.register(self._atexit_run)

def __repr__(self):
return "<CaptureManager _method={!r} _global_capturing={!r} _current_item={!r}>".format(
self._method, self._global_capturing, self._current_item
)

def _atexit_register(self, func):
self._atexit_funcs.append(func)

def _atexit_run(self):
for func in self._atexit_funcs:
func()

def _getcapture(self, method):
if method == "fd":
return MultiCapture(out=True, err=True, Capture=FDCapture)
return MultiCapture(out=True, err=True, Capture=FDCapture, capman=self)
elif method == "sys":
return MultiCapture(out=True, err=True, Capture=SysCapture)
return MultiCapture(out=True, err=True, Capture=SysCapture, capman=self)
elif method == "no":
return MultiCapture(out=False, err=False, in_=False)
return MultiCapture(out=False, err=False, in_=False, capman=self)
raise ValueError("unknown capturing method: %r" % method) # pragma: no cover

def is_capturing(self):
Expand Down Expand Up @@ -451,13 +462,13 @@ class MultiCapture:
out = err = in_ = None
_state = None

def __init__(self, out=True, err=True, in_=True, Capture=None):
def __init__(self, out=True, err=True, in_=True, Capture=None, capman: CaptureManager = None):
if in_:
self.in_ = Capture(0)
self.in_ = Capture(0, capman=capman)
if out:
self.out = Capture(1)
self.out = Capture(1, capman=capman)
if err:
self.err = Capture(2)
self.err = Capture(2, capman=capman)

def __repr__(self):
return "<MultiCapture out={!r} err={!r} in_={!r} _state={!r} _in_suspended={!r}>".format(
Expand Down Expand Up @@ -540,8 +551,9 @@ class FDCaptureBinary:
EMPTY_BUFFER = b""
_state = None

def __init__(self, targetfd, tmpfile=None):
def __init__(self, targetfd, tmpfile=None, capman: CaptureManager = None):
self.targetfd = targetfd
self._capman = capman
try:
self.targetfd_save = os.dup(self.targetfd)
except OSError:
Expand All @@ -553,14 +565,14 @@ def __init__(self, targetfd, tmpfile=None):
if targetfd == 0:
assert not tmpfile, "cannot set tmpfile with stdin"
tmpfile = open(os.devnull, "r")
self.syscapture = SysCapture(targetfd)
self.syscapture = SysCapture(targetfd, capman=self._capman)
else:
if tmpfile is None:
f = TemporaryFile()
with f:
tmpfile = safe_text_dupfile(f, mode="wb+")
if targetfd in patchsysdict:
self.syscapture = SysCapture(targetfd, tmpfile)
self.syscapture = SysCapture(targetfd, tmpfile, capman)
else:
self.syscapture = NoCapture()
self.tmpfile = tmpfile
Expand Down Expand Up @@ -595,9 +607,12 @@ def _done(self):
os.dup2(targetfd_save, self.targetfd)
os.close(targetfd_save)
self.syscapture.done()
# Redirect any remaining output.
os.dup2(self.targetfd, self.tmpfile_fd)
atexit.register(self.tmpfile.close)
if self._capman:
# Redirect any remaining output.
os.dup2(self.targetfd, self.tmpfile_fd)
self._capman._atexit_register(self.tmpfile.close)
else:
self.tmpfile.close()
self._state = "done"

def suspend(self):
Expand Down Expand Up @@ -639,8 +654,9 @@ class SysCapture:
EMPTY_BUFFER = str()
_state = None

def __init__(self, fd, tmpfile=None):
def __init__(self, fd, tmpfile=None, capman: CaptureManager = None):
name = patchsysdict[fd]
self._capman = capman
self._old = getattr(sys, name)
self.name = name
if tmpfile is None:
Expand Down Expand Up @@ -668,7 +684,10 @@ def snap(self):
def done(self):
setattr(sys, self.name, self._old)
del self._old
atexit.register(self.tmpfile.close)
if self._capman:
self._capman._atexit_register(self.tmpfile.close)
else:
self.tmpfile.close()
self._state = "done"

def suspend(self):
Expand Down

0 comments on commit b9d7d91

Please sign in to comment.