From a2db9518930f30d1fcbe008e9f69fdc7d8754f46 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Wed, 31 Jul 2024 23:44:31 -0500 Subject: [PATCH 1/3] Enable ruff's `flake8-return` rule --- .../how-does-windows-so-reuseaddr-work.py | 5 +- pyproject.toml | 1 + src/trio/_core/_io_kqueue.py | 5 +- src/trio/_core/_ki.py | 15 ++-- src/trio/_core/_mock_clock.py | 16 ++--- src/trio/_core/_run.py | 71 +++++++++---------- src/trio/_core/_tests/test_guest_mode.py | 3 +- src/trio/_core/_tests/test_run.py | 5 +- src/trio/_core/_tests/test_thread_cache.py | 5 +- src/trio/_core/_tests/test_windows.py | 6 +- src/trio/_core/_unbounded_queue.py | 9 ++- src/trio/_dtls.py | 27 +++---- src/trio/_file_io.py | 6 +- src/trio/_highlevel_generic.py | 3 +- src/trio/_highlevel_open_tcp_listeners.py | 3 +- src/trio/_highlevel_open_tcp_stream.py | 10 ++- src/trio/_highlevel_socket.py | 3 +- src/trio/_highlevel_ssl_helpers.py | 3 +- src/trio/_repl.py | 23 +++--- src/trio/_socket.py | 52 +++++++------- src/trio/_ssl.py | 16 ++--- src/trio/_subprocess.py | 5 +- src/trio/_sync.py | 2 +- src/trio/_tests/check_type_completeness.py | 11 ++- src/trio/_tests/test_dtls.py | 5 +- src/trio/_tests/test_highlevel_socket.py | 3 +- src/trio/_tests/test_socket.py | 5 +- src/trio/_tests/test_ssl.py | 5 +- src/trio/_tests/test_subprocess.py | 3 +- src/trio/_tests/test_threads.py | 12 ++-- src/trio/_threads.py | 5 +- src/trio/_tools/mypy_annotate.py | 3 +- src/trio/_unix_pipes.py | 9 +-- src/trio/_util.py | 3 +- src/trio/testing/_fake_net.py | 13 ++-- src/trio/testing/_memory_streams.py | 8 +-- 36 files changed, 160 insertions(+), 219 deletions(-) diff --git a/notes-to-self/how-does-windows-so-reuseaddr-work.py b/notes-to-self/how-does-windows-so-reuseaddr-work.py index 70dd75e39f..ae2495486b 100644 --- a/notes-to-self/how-does-windows-so-reuseaddr-work.py +++ b/notes-to-self/how-does-windows-so-reuseaddr-work.py @@ -38,11 +38,10 @@ def table_entry(mode1, bind_type1, mode2, bind_type2): except OSError as exc: if exc.winerror == errno.WSAEADDRINUSE: return "INUSE" - elif exc.winerror == errno.WSAEACCES: + if exc.winerror == errno.WSAEACCES: return "ACCESS" raise - else: - return "Success" + return "Success" print( diff --git a/pyproject.toml b/pyproject.toml index 0e26fea83a..3fa48d2476 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,6 +116,7 @@ select = [ "PERF", # Perflint "PT", # flake8-pytest-style "PYI", # flake8-pyi + "RET", # flake8-return "RUF", # Ruff-specific rules "SIM", # flake8-simplify "TCH", # flake8-type-checking diff --git a/src/trio/_core/_io_kqueue.py b/src/trio/_core/_io_kqueue.py index 3d0aed7d35..a1f3e72c86 100644 --- a/src/trio/_core/_io_kqueue.py +++ b/src/trio/_core/_io_kqueue.py @@ -77,9 +77,8 @@ def get_events(self, timeout: float) -> EventResult: events += batch if len(batch) < max_events: break - else: - timeout = 0 - # and loop back to the start + timeout = 0 + # and loop back to the start return events def process_events(self, events: EventResult) -> None: diff --git a/src/trio/_core/_ki.py b/src/trio/_core/_ki.py index a8431f89db..225c1cc2bb 100644 --- a/src/trio/_core/_ki.py +++ b/src/trio/_core/_ki.py @@ -148,7 +148,7 @@ def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[ return coro # type: ignore[return-value] return wrapper - elif inspect.isgeneratorfunction(fn): + if inspect.isgeneratorfunction(fn): @wraps(fn) def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc] @@ -165,7 +165,7 @@ def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[ return gen # type: ignore[return-value] return wrapper - elif inspect.isasyncgenfunction(fn) or legacy_isasyncgenfunction(fn): + if inspect.isasyncgenfunction(fn) or legacy_isasyncgenfunction(fn): @wraps(fn) # type: ignore[arg-type] def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc] @@ -175,14 +175,13 @@ def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[ return agen # type: ignore[return-value] return wrapper - else: - @wraps(fn) - def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled - return fn(*args, **kwargs) + @wraps(fn) + def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled + return fn(*args, **kwargs) - return wrapper + return wrapper return decorator diff --git a/src/trio/_core/_mock_clock.py b/src/trio/_core/_mock_clock.py index 70c4e58a2d..913c435695 100644 --- a/src/trio/_core/_mock_clock.py +++ b/src/trio/_core/_mock_clock.py @@ -89,12 +89,11 @@ def rate(self) -> float: def rate(self, new_rate: float) -> None: if new_rate < 0: raise ValueError("rate must be >= 0") - else: - real = self._real_clock() - virtual = self._real_to_virtual(real) - self._virtual_base = virtual - self._real_base = real - self._rate = float(new_rate) + real = self._real_clock() + virtual = self._real_to_virtual(real) + self._virtual_base = virtual + self._real_base = real + self._rate = float(new_rate) @property def autojump_threshold(self) -> float: @@ -144,10 +143,9 @@ def deadline_to_sleep_time(self, deadline: float) -> float: virtual_timeout = deadline - self.current_time() if virtual_timeout <= 0: return 0 - elif self._rate > 0: + if self._rate > 0: return virtual_timeout / self._rate - else: - return 999999999 + return 999999999 def jump(self, seconds: float) -> None: """Manually advance the clock by the given number of seconds. diff --git a/src/trio/_core/_run.py b/src/trio/_core/_run.py index 5453c3602e..4f3704722d 100644 --- a/src/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -222,10 +222,9 @@ def collapse_exception_group( excgroup.__traceback__, exceptions[0].__traceback__ ) return exceptions[0] - elif modified: + if modified: return excgroup.derive(exceptions) - else: - return excgroup + return excgroup @attrs.define(eq=False) @@ -254,9 +253,8 @@ def next_deadline(self) -> float: deadline, _, cancel_scope = self._heap[0] if deadline == cancel_scope._registered_deadline: return deadline - else: - # This entry is stale; discard it and try again - heappop(self._heap) + # This entry is stale; discard it and try again + heappop(self._heap) return inf def _prune(self) -> None: @@ -642,22 +640,21 @@ def __exit__( remaining_error_after_cancel_scope = self._close(exc) if remaining_error_after_cancel_scope is None: return True - elif remaining_error_after_cancel_scope is exc: + if remaining_error_after_cancel_scope is exc: return False - else: - # Copied verbatim from the old MultiErrorCatcher. Python doesn't - # allow us to encapsulate this __context__ fixup. - old_context = remaining_error_after_cancel_scope.__context__ - try: - raise remaining_error_after_cancel_scope - finally: - _, value, _ = sys.exc_info() - assert value is remaining_error_after_cancel_scope - value.__context__ = old_context - # delete references from locals to avoid creating cycles - # see test_cancel_scope_exit_doesnt_create_cyclic_garbage - # Note: still relevant - del remaining_error_after_cancel_scope, value, _, exc + # Copied verbatim from the old MultiErrorCatcher. Python doesn't + # allow us to encapsulate this __context__ fixup. + old_context = remaining_error_after_cancel_scope.__context__ + try: + raise remaining_error_after_cancel_scope + finally: + _, value, _ = sys.exc_info() + assert value is remaining_error_after_cancel_scope + value.__context__ = old_context + # delete references from locals to avoid creating cycles + # see test_cancel_scope_exit_doesnt_create_cyclic_garbage + # Note: still relevant + del remaining_error_after_cancel_scope, value, _, exc def __repr__(self) -> str: if self._cancel_status is not None: @@ -949,21 +946,20 @@ async def __aexit__( combined_error_from_nursery = self._scope._close(new_exc) if combined_error_from_nursery is None: return True - elif combined_error_from_nursery is exc: + if combined_error_from_nursery is exc: return False - else: - # Copied verbatim from the old MultiErrorCatcher. Python doesn't - # allow us to encapsulate this __context__ fixup. - old_context = combined_error_from_nursery.__context__ - try: - raise combined_error_from_nursery - finally: - _, value, _ = sys.exc_info() - assert value is combined_error_from_nursery - value.__context__ = old_context - # delete references from locals to avoid creating cycles - # see test_cancel_scope_exit_doesnt_create_cyclic_garbage - del _, combined_error_from_nursery, value, new_exc + # Copied verbatim from the old MultiErrorCatcher. Python doesn't + # allow us to encapsulate this __context__ fixup. + old_context = combined_error_from_nursery.__context__ + try: + raise combined_error_from_nursery + finally: + _, value, _ = sys.exc_info() + assert value is combined_error_from_nursery + value.__context__ = old_context + # delete references from locals to avoid creating cycles + # see test_cancel_scope_exit_doesnt_create_cyclic_garbage + del _, combined_error_from_nursery, value, new_exc # make sure these raise errors in static analysis if called if not TYPE_CHECKING: @@ -2302,10 +2298,9 @@ def run( # cluttering every single Trio traceback with an extra frame. if isinstance(runner.main_task_outcome, Value): return cast(RetT, runner.main_task_outcome.value) - elif isinstance(runner.main_task_outcome, Error): + if isinstance(runner.main_task_outcome, Error): raise runner.main_task_outcome.error - else: # pragma: no cover - raise AssertionError(runner.main_task_outcome) + raise AssertionError(runner.main_task_outcome) # pragma: no cover def start_guest_run( diff --git a/src/trio/_core/_tests/test_guest_mode.py b/src/trio/_core/_tests/test_guest_mode.py index 8972ec735a..282b95890a 100644 --- a/src/trio/_core/_tests/test_guest_mode.py +++ b/src/trio/_core/_tests/test_guest_mode.py @@ -554,8 +554,7 @@ async def crash_in_worker_thread_io(in_host: InHost) -> None: def bad_get_events(*args: Any) -> object: if threading.current_thread() is not t: raise ValueError("oh no!") - else: - return old_get_events(*args) + return old_get_events(*args) m.setattr("trio._core._run.TheIOManager.get_events", bad_get_events) diff --git a/src/trio/_core/_tests/test_run.py b/src/trio/_core/_tests/test_run.py index ee823cb81a..f7d2958b86 100644 --- a/src/trio/_core/_tests/test_run.py +++ b/src/trio/_core/_tests/test_run.py @@ -2039,8 +2039,9 @@ async def __anext__(self) -> list[int]: e.exceptions[0], StopAsyncIteration ): raise e.exceptions[0] from None - else: # pragma: no cover - raise AssertionError("unknown error in _accumulate") from e + raise AssertionError( # pragma: no cover + "unknown error in _accumulate" + ) from e return items diff --git a/src/trio/_core/_tests/test_thread_cache.py b/src/trio/_core/_tests/test_thread_cache.py index ee301d17fd..b59d5b26fe 100644 --- a/src/trio/_core/_tests/test_thread_cache.py +++ b/src/trio/_core/_tests/test_thread_cache.py @@ -147,14 +147,13 @@ def acquire(self, timeout: int = -1) -> bool: got_it = self._lock.acquire(timeout=timeout) if timeout == -1: return True - elif got_it: + if got_it: if self._counter > 0: self._counter -= 1 self._lock.release() return False return True - else: - return False + return False def release(self) -> None: self._lock.release() diff --git a/src/trio/_core/_tests/test_windows.py b/src/trio/_core/_tests/test_windows.py index 486a405590..905676a034 100644 --- a/src/trio/_core/_tests/test_windows.py +++ b/src/trio/_core/_tests/test_windows.py @@ -247,8 +247,7 @@ def patched_get_underlying( sock = sock.fileno() if which == WSAIoctls.SIO_BSP_HANDLE_SELECT: return _handle(sock + 1) - else: - return _handle(sock) + return _handle(sock) monkeypatch.setattr(_io_windows, "_get_underlying_socket", patched_get_underlying) with pytest.raises( @@ -275,8 +274,7 @@ def patched_get_underlying( sock = sock.fileno() if which == WSAIoctls.SIO_BASE_HANDLE: raise OSError("nope") - else: - return _handle(sock) + return _handle(sock) monkeypatch.setattr(_io_windows, "_get_underlying_socket", patched_get_underlying) with pytest.raises( diff --git a/src/trio/_core/_unbounded_queue.py b/src/trio/_core/_unbounded_queue.py index b9ebe484d7..fd7f886f78 100644 --- a/src/trio/_core/_unbounded_queue.py +++ b/src/trio/_core/_unbounded_queue.py @@ -143,11 +143,10 @@ async def get_batch(self) -> list[T]: if not self._can_get: await self._lot.park() return self._get_batch_protected() - else: - try: - return self._get_batch_protected() - finally: - await _core.cancel_shielded_checkpoint() + try: + return self._get_batch_protected() + finally: + await _core.cancel_shielded_checkpoint() def statistics(self) -> UnboundedQueueStatistics: """Return an :class:`UnboundedQueueStatistics` object containing debugging information.""" diff --git a/src/trio/_dtls.py b/src/trio/_dtls.py index 31f7817e1c..626040ad32 100644 --- a/src/trio/_dtls.py +++ b/src/trio/_dtls.py @@ -53,15 +53,13 @@ def packet_header_overhead(sock: SocketType) -> int: if sock.family == trio.socket.AF_INET: return 28 - else: - return 48 + return 48 def worst_case_mtu(sock: SocketType) -> int: if sock.family == trio.socket.AF_INET: return 576 - packet_header_overhead(sock) - else: - return 1280 - packet_header_overhead(sock) + return 1280 - packet_header_overhead(sock) def best_guess_mtu(sock: SocketType) -> int: @@ -600,8 +598,7 @@ def valid_cookie( return hmac.compare_digest(cookie, cur_cookie) | hmac.compare_digest( cookie, old_cookie ) - else: - return False + return False def challenge_for( @@ -640,10 +637,9 @@ def challenge_for( ) payload = encode_handshake_fragment(hs) - packet = encode_record( + return encode_record( Record(ContentType.handshake, ProtocolVersion.DTLS10, epoch_seqno, payload) ) - return packet _T = TypeVar("_T") @@ -737,9 +733,8 @@ async def handle_client_hello_untrusted( if old_stream._client_hello == (cookie, bits): # ...This was just a duplicate of the last ClientHello, so never mind. return - else: - # Ok, this *really is* a new handshake; the old stream should go away. - old_stream._set_replaced() + # Ok, this *really is* a new handshake; the old stream should go away. + old_stream._set_replaced() stream._client_hello = (cookie, bits) endpoint._streams[address] = stream endpoint._incoming_connections_q.s.send_nowait(stream) @@ -761,8 +756,7 @@ async def dtls_receive_loop( # This is totally useless -- there's nothing we can do with this # information. So we just ignore it and retry the recv. continue - else: - raise + raise endpoint = endpoint_ref() try: if endpoint is None: @@ -798,9 +792,7 @@ async def dtls_receive_loop( if exc.errno in (errno.EBADF, errno.ENOTSOCK): # socket was closed return - else: # pragma: no cover - # ??? shouldn't happen - raise + raise # ??? shouldn't happen # pragma: no cover @attrs.frozen @@ -989,8 +981,7 @@ def read_volley() -> list[_AnyHandshakeMessage]: # openssl decided to retransmit; discard because we handle # retransmits ourselves return [] - else: - return new_volley_messages + return new_volley_messages # If we're a client, we send the initial volley. If we're a server, then # the initial ClientHello has already been inserted into self._ssl's diff --git a/src/trio/_file_io.py b/src/trio/_file_io.py index ef867243f0..d9f897f665 100644 --- a/src/trio/_file_io.py +++ b/src/trio/_file_io.py @@ -265,8 +265,7 @@ async def __anext__(self: AsyncIOWrapper[_CanReadLine[AnyStr]]) -> AnyStr: line = await self.readline() if line: return line - else: - raise StopAsyncIteration + raise StopAsyncIteration async def detach(self: AsyncIOWrapper[_CanDetach[T]]) -> AsyncIOWrapper[T]: """Like :meth:`io.BufferedIOBase.detach`, but async. @@ -463,12 +462,11 @@ async def open_file( :func:`trio.Path.open` """ - _file = wrap_file( + return wrap_file( await trio.to_thread.run_sync( io.open, file, mode, buffering, encoding, errors, newline, closefd, opener ) ) - return _file def wrap_file(file: FileT) -> AsyncIOWrapper[FileT]: diff --git a/src/trio/_highlevel_generic.py b/src/trio/_highlevel_generic.py index 88a86318a3..8f1467054f 100644 --- a/src/trio/_highlevel_generic.py +++ b/src/trio/_highlevel_generic.py @@ -113,8 +113,7 @@ async def send_eof(self) -> None: stream = self.send_stream if _is_halfclosable(stream): return await stream.send_eof() - else: - return await stream.aclose() + return await stream.aclose() # we intentionally accept more types from the caller than we support returning async def receive_some(self, max_bytes: int | None = None) -> bytes: diff --git a/src/trio/_highlevel_open_tcp_listeners.py b/src/trio/_highlevel_open_tcp_listeners.py index 80555be33e..6856fe498d 100644 --- a/src/trio/_highlevel_open_tcp_listeners.py +++ b/src/trio/_highlevel_open_tcp_listeners.py @@ -131,8 +131,7 @@ async def open_tcp_listeners( # failure to create the other. unsupported_address_families.append(ex) continue - else: - raise + raise try: # See https://github.com/python-trio/trio/issues/39 if sys.platform != "win32": diff --git a/src/trio/_highlevel_open_tcp_stream.py b/src/trio/_highlevel_open_tcp_stream.py index d5c83da7c0..a58a8ab955 100644 --- a/src/trio/_highlevel_open_tcp_stream.py +++ b/src/trio/_highlevel_open_tcp_stream.py @@ -164,8 +164,7 @@ def format_host_port(host: str | bytes, port: int | str) -> str: host = host.decode("ascii") if isinstance(host, bytes) else host if ":" in host: return f"[{host}]:{port}" - else: - return f"{host}:{port}" + return f"{host}:{port}" # Twisted's HostnameEndpoint has a good set of configurables: @@ -401,7 +400,6 @@ async def attempt_connect( assert len(oserrors) == len(targets) msg = f"all attempts to connect to {format_host_port(host, port)} failed" raise OSError(msg) from ExceptionGroup(msg, oserrors) - else: - stream = trio.SocketStream(winning_socket) - open_sockets.remove(winning_socket) - return stream + stream = trio.SocketStream(winning_socket) + open_sockets.remove(winning_socket) + return stream diff --git a/src/trio/_highlevel_socket.py b/src/trio/_highlevel_socket.py index 901e22f345..49bff76b27 100644 --- a/src/trio/_highlevel_socket.py +++ b/src/trio/_highlevel_socket.py @@ -193,8 +193,7 @@ def getsockopt(self, level: int, option: int, buffersize: int = 0) -> int | byte # We should be able to drop it when the next PyPy3 beta is released. if buffersize == 0: return self.socket.getsockopt(level, option) - else: - return self.socket.getsockopt(level, option, buffersize) + return self.socket.getsockopt(level, option, buffersize) ################################################################ diff --git a/src/trio/_highlevel_ssl_helpers.py b/src/trio/_highlevel_ssl_helpers.py index 03562c9edb..275f69c826 100644 --- a/src/trio/_highlevel_ssl_helpers.py +++ b/src/trio/_highlevel_ssl_helpers.py @@ -93,11 +93,10 @@ async def open_ssl_over_tcp_listeners( """ tcp_listeners = await trio.open_tcp_listeners(port, host=host, backlog=backlog) - ssl_listeners = [ + return [ trio.SSLListener(tcp_listener, ssl_context, https_compatible=https_compatible) for tcp_listener in tcp_listeners ] - return ssl_listeners async def serve_ssl_over_tcp( diff --git a/src/trio/_repl.py b/src/trio/_repl.py index 73f050140e..9da7e1ecde 100644 --- a/src/trio/_repl.py +++ b/src/trio/_repl.py @@ -40,19 +40,18 @@ def runcode(self, code: types.CodeType) -> None: # return to the REPL. if isinstance(result.error, SystemExit): raise result.error - else: - # Inline our own version of self.showtraceback that can use - # outcome.Error.error directly to print clean tracebacks. - # This also means overriding self.showtraceback does nothing. - sys.last_type, sys.last_value = type(result.error), result.error - sys.last_traceback = result.error.__traceback__ - # see https://docs.python.org/3/library/sys.html#sys.last_exc - if sys.version_info >= (3, 12): - sys.last_exc = result.error + # Inline our own version of self.showtraceback that can use + # outcome.Error.error directly to print clean tracebacks. + # This also means overriding self.showtraceback does nothing. + sys.last_type, sys.last_value = type(result.error), result.error + sys.last_traceback = result.error.__traceback__ + # see https://docs.python.org/3/library/sys.html#sys.last_exc + if sys.version_info >= (3, 12): + sys.last_exc = result.error - # We always use sys.excepthook, unlike other implementations. - # This means that overriding self.write also does nothing to tbs. - sys.excepthook(sys.last_type, sys.last_value, sys.last_traceback) + # We always use sys.excepthook, unlike other implementations. + # This means that overriding self.write also does nothing to tbs. + sys.excepthook(sys.last_type, sys.last_value, sys.last_traceback) async def run_repl(console: TrioInteractiveConsole) -> None: diff --git a/src/trio/_socket.py b/src/trio/_socket.py index 0a3bd1cba1..8f69f6cdcd 100644 --- a/src/trio/_socket.py +++ b/src/trio/_socket.py @@ -69,8 +69,7 @@ def __init__( def _is_blocking_io_error(self, exc: BaseException) -> bool: if self._blocking_exc_override is None: return isinstance(exc, BlockingIOError) - else: - return self._blocking_exc_override(exc) + return self._blocking_exc_override(exc) async def __aenter__(self) -> None: await trio.lowlevel.checkpoint_if_cancelled() @@ -85,10 +84,9 @@ async def __aexit__( # Discard the exception and fall through to the code below the # block return True - else: - await trio.lowlevel.cancel_shielded_checkpoint() - # Let the return or exception propagate - return False + await trio.lowlevel.cancel_shielded_checkpoint() + # Let the return or exception propagate + return False ################################################################ @@ -231,17 +229,16 @@ def numeric_only_failure(exc: BaseException) -> bool: hr = _resolver.get(None) if hr is not None: return await hr.getaddrinfo(host, port, family, type, proto, flags) - else: - return await trio.to_thread.run_sync( - _stdlib_socket.getaddrinfo, - host, - port, - family, - type, - proto, - flags, - abandon_on_cancel=True, - ) + return await trio.to_thread.run_sync( + _stdlib_socket.getaddrinfo, + host, + port, + family, + type, + proto, + flags, + abandon_on_cancel=True, + ) async def getnameinfo( @@ -259,10 +256,9 @@ async def getnameinfo( hr = _resolver.get(None) if hr is not None: return await hr.getnameinfo(sockaddr, flags) - else: - return await trio.to_thread.run_sync( - _stdlib_socket.getnameinfo, sockaddr, flags, abandon_on_cancel=True - ) + return await trio.to_thread.run_sync( + _stdlib_socket.getnameinfo, sockaddr, flags, abandon_on_cancel=True + ) async def getprotobyname(name: str) -> int: @@ -880,13 +876,12 @@ async def bind(self, address: AddressFormat) -> None: # Use a thread for the filesystem traversal (unless it's an # abstract domain socket) return await trio.to_thread.run_sync(self._sock.bind, address) - else: - # POSIX actually says that bind can return EWOULDBLOCK and - # complete asynchronously, like connect. But in practice AFAICT - # there aren't yet any real systems that do this, so we'll worry - # about it when it happens. - await trio.lowlevel.checkpoint() - return self._sock.bind(address) + # POSIX actually says that bind can return EWOULDBLOCK and + # complete asynchronously, like connect. But in practice AFAICT + # there aren't yet any real systems that do this, so we'll worry + # about it when it happens. + await trio.lowlevel.checkpoint() + return self._sock.bind(address) def shutdown(self, flag: int) -> None: # no need to worry about return value b/c always returns None: @@ -1058,6 +1053,7 @@ async def connect(self, address: AddressFormat) -> None: err = self._sock.getsockopt(_stdlib_socket.SOL_SOCKET, _stdlib_socket.SO_ERROR) if err != 0: raise OSError(err, f"Error connecting to {address!r}: {os.strerror(err)}") + return None ################################################################ # recv diff --git a/src/trio/_ssl.py b/src/trio/_ssl.py index 5bc37cf7dc..bc9b7657a2 100644 --- a/src/trio/_ssl.py +++ b/src/trio/_ssl.py @@ -419,8 +419,7 @@ def __getattr__(self, name: str) -> Any: raise NeedHandshakeError(f"call do_handshake() before calling {name!r}") return getattr(self._ssl_object, name) - else: - raise AttributeError(name) + raise AttributeError(name) def __setattr__(self, name: str, value: object) -> None: if name in self._forwarded: @@ -434,12 +433,11 @@ def __dir__(self) -> list[str]: def _check_status(self) -> None: if self._state is _State.OK: return - elif self._state is _State.BROKEN: + if self._state is _State.BROKEN: raise trio.BrokenResourceError - elif self._state is _State.CLOSED: + if self._state is _State.CLOSED: raise trio.ClosedResourceError - else: # pragma: no cover - raise AssertionError() + raise AssertionError() # pragma: no cover # This is probably the single trickiest function in Trio. It has lots of # comments, though, just make sure to think carefully if you ever have to @@ -692,8 +690,7 @@ async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: ): await trio.lowlevel.checkpoint() return b"" - else: - raise + raise if max_bytes is None: # If we somehow have more data already in our pending buffer # than the estimate receive size, bump up our size a bit for @@ -717,8 +714,7 @@ async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: if self._https_compatible and _is_eof(exc.__cause__): await trio.lowlevel.checkpoint() return b"" - else: - raise + raise async def send_all(self, data: bytes | bytearray | memoryview) -> None: """Encrypt some data and then send it on the underlying transport. diff --git a/src/trio/_subprocess.py b/src/trio/_subprocess.py index 553e3d4885..2edc1baa45 100644 --- a/src/trio/_subprocess.py +++ b/src/trio/_subprocess.py @@ -770,9 +770,8 @@ async def killer() -> None: raise subprocess.CalledProcessError( proc.returncode, proc.args, output=stdout, stderr=stderr ) - else: - assert proc.returncode is not None - return subprocess.CompletedProcess(proc.args, proc.returncode, stdout, stderr) + assert proc.returncode is not None + return subprocess.CompletedProcess(proc.args, proc.returncode, stdout, stderr) # There's a lot of duplication here because type checkers don't diff --git a/src/trio/_sync.py b/src/trio/_sync.py index 6e62eceeff..555a72e59f 100644 --- a/src/trio/_sync.py +++ b/src/trio/_sync.py @@ -573,7 +573,7 @@ def acquire_nowait(self) -> None: task = trio.lowlevel.current_task() if self._owner is task: raise RuntimeError("attempt to re-acquire an already held Lock") - elif self._owner is None and not self._lot: + if self._owner is None and not self._lot: # No-one owns it self._owner = task else: diff --git a/src/trio/_tests/check_type_completeness.py b/src/trio/_tests/check_type_completeness.py index fa6ace074f..8e4002e661 100755 --- a/src/trio/_tests/check_type_completeness.py +++ b/src/trio/_tests/check_type_completeness.py @@ -101,12 +101,11 @@ def has_docstring_at_runtime(name: str) -> bool: ): return True - else: - print( - f"Pyright sees {name} at runtime, but unable to getattr({obj.__name__}, {obj_name}).", - file=sys.stderr, - ) - return False + print( + f"Pyright sees {name} at runtime, but unable to getattr({obj.__name__}, {obj_name}).", + file=sys.stderr, + ) + return False return bool(obj.__doc__) diff --git a/src/trio/_tests/test_dtls.py b/src/trio/_tests/test_dtls.py index d14edae25c..d359b210cf 100644 --- a/src/trio/_tests/test_dtls.py +++ b/src/trio/_tests/test_dtls.py @@ -134,7 +134,7 @@ async def route_packet(packet: UDPPacket) -> None: print(f"{packet.source} -> {packet.destination}: {op}") if op == "drop": return - elif op == "dupe": + if op == "dupe": fn.send_packet(packet) elif op == "delay": await trio.sleep(r.random() * 3) @@ -735,8 +735,7 @@ async def start_and_forget_endpoint() -> int: await trio.testing.wait_all_tasks_blocked() nursery.cancel_scope.cancel() - during_tasks = trio.lowlevel.current_statistics().tasks_living - return during_tasks + return trio.lowlevel.current_statistics().tasks_living with pytest.warns(ResourceWarning): during_tasks = await start_and_forget_endpoint() diff --git a/src/trio/_tests/test_highlevel_socket.py b/src/trio/_tests/test_highlevel_socket.py index 976a3b5e04..4f91debd9b 100644 --- a/src/trio/_tests/test_highlevel_socket.py +++ b/src/trio/_tests/test_highlevel_socket.py @@ -261,8 +261,7 @@ async def accept(self) -> tuple[SocketType, object]: event = next(self._events) if isinstance(event, BaseException): raise event - else: - return event, None + return event, None fake_server_sock = FakeSocket([]) diff --git a/src/trio/_tests/test_socket.py b/src/trio/_tests/test_socket.py index b98b3246e9..11d7ee70df 100644 --- a/src/trio/_tests/test_socket.py +++ b/src/trio/_tests/test_socket.py @@ -64,10 +64,9 @@ def getaddrinfo(self, *args: Any, **kwargs: Any) -> GetAddrInfoResponse | str: self.record.append(bound) if bound in self._responses: return self._responses[bound] - elif bound[-1] & stdlib_socket.AI_NUMERICHOST: + if bound[-1] & stdlib_socket.AI_NUMERICHOST: return self._orig_getaddrinfo(*args, **kwargs) - else: - raise RuntimeError(f"gai called with unexpected arguments {bound}") + raise RuntimeError(f"gai called with unexpected arguments {bound}") @pytest.fixture diff --git a/src/trio/_tests/test_ssl.py b/src/trio/_tests/test_ssl.py index 8e780b2f9c..d1e56d5f30 100644 --- a/src/trio/_tests/test_ssl.py +++ b/src/trio/_tests/test_ssl.py @@ -107,11 +107,10 @@ def client_ctx(request: pytest.FixtureRequest) -> ssl.SSLContext: TRIO_TEST_CA.configure_trust(ctx) if request.param in ["default", "tls13"]: return ctx - elif request.param == "tls12": + if request.param == "tls12": ctx.maximum_version = ssl.TLSVersion.TLSv1_2 return ctx - else: # pragma: no cover - raise AssertionError() + raise AssertionError() # pragma: no cover # The blocking socket server. diff --git a/src/trio/_tests/test_subprocess.py b/src/trio/_tests/test_subprocess.py index 0a70e7a974..27da1bc582 100644 --- a/src/trio/_tests/test_subprocess.py +++ b/src/trio/_tests/test_subprocess.py @@ -86,8 +86,7 @@ def SLEEP(seconds: int) -> list[str]: def got_signal(proc: Process, sig: SignalType) -> bool: if (not TYPE_CHECKING and posix) or sys.platform != "win32": return proc.returncode == -sig - else: - return proc.returncode != 0 + return proc.returncode != 0 @asynccontextmanager # type: ignore[misc] # Any in decorator diff --git a/src/trio/_tests/test_threads.py b/src/trio/_tests/test_threads.py index b4a5842ff0..fe267ce7c4 100644 --- a/src/trio/_tests/test_threads.py +++ b/src/trio/_tests/test_threads.py @@ -629,8 +629,7 @@ async def test_trio_to_thread_run_sync_token() -> None: # Test that to_thread_run_sync automatically injects the current trio token # into a spawned thread def thread_fn() -> _core.TrioToken: - callee_token = from_thread_run_sync(_core.current_trio_token) - return callee_token + return from_thread_run_sync(_core.current_trio_token) caller_token = _core.current_trio_token() callee_token = await to_thread_run_sync(thread_fn) @@ -692,8 +691,7 @@ async def test_trio_from_thread_run_sync() -> None: # Test that to_thread_run_sync correctly "hands off" the trio token to # trio.from_thread.run_sync() def thread_fn_1() -> float: - trio_time = from_thread_run_sync(_core.current_time) - return trio_time + return from_thread_run_sync(_core.current_time) trio_time = await to_thread_run_sync(thread_fn_1) assert isinstance(trio_time, float) @@ -737,8 +735,7 @@ async def test_trio_from_thread_token() -> None: # Test that to_thread_run_sync and spawned trio.from_thread.run_sync() # share the same Trio token def thread_fn() -> _core.TrioToken: - callee_token = from_thread_run_sync(_core.current_trio_token) - return callee_token + return from_thread_run_sync(_core.current_trio_token) caller_token = _core.current_trio_token() callee_token = await to_thread_run_sync(thread_fn) @@ -749,8 +746,7 @@ async def test_trio_from_thread_token_kwarg() -> None: # Test that to_thread_run_sync and spawned trio.from_thread.run_sync() can # use an explicitly defined token def thread_fn(token: _core.TrioToken) -> _core.TrioToken: - callee_token = from_thread_run_sync(_core.current_trio_token, trio_token=token) - return callee_token + return from_thread_run_sync(_core.current_trio_token, trio_token=token) caller_token = _core.current_trio_token() callee_token = await to_thread_run_sync(thread_fn, caller_token) diff --git a/src/trio/_threads.py b/src/trio/_threads.py index a04b737292..fe5e578964 100644 --- a/src/trio/_threads.py +++ b/src/trio/_threads.py @@ -425,8 +425,7 @@ def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort: # empty so report_back_in_trio_thread_fn cannot reschedule task_register[0] = None return trio.lowlevel.Abort.SUCCEEDED - else: - return trio.lowlevel.Abort.FAILED + return trio.lowlevel.Abort.FAILED while True: # wait_task_rescheduled return value cannot be typed @@ -435,7 +434,7 @@ def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort: ) if isinstance(msg_from_thread, outcome.Outcome): return msg_from_thread.unwrap() - elif isinstance(msg_from_thread, Run): + if isinstance(msg_from_thread, Run): await msg_from_thread.run() elif isinstance(msg_from_thread, RunSync): msg_from_thread.run_sync() diff --git a/src/trio/_tools/mypy_annotate.py b/src/trio/_tools/mypy_annotate.py index 6bd20f401c..c297e847a6 100644 --- a/src/trio/_tools/mypy_annotate.py +++ b/src/trio/_tools/mypy_annotate.py @@ -62,8 +62,7 @@ def process_line(line: str) -> Result | None: kind=mypy_to_github[kind], message=message, ) - else: - return None + return None def export(results: dict[Result, list[str]]) -> None: diff --git a/src/trio/_unix_pipes.py b/src/trio/_unix_pipes.py index 34340d2b36..e11110ef3c 100644 --- a/src/trio/_unix_pipes.py +++ b/src/trio/_unix_pipes.py @@ -149,8 +149,7 @@ async def send_all(self, data: bytes) -> None: raise trio.ClosedResourceError( "file was already closed" ) from None - else: - raise trio.BrokenResourceError from e + raise trio.BrokenResourceError from e async def wait_send_all_might_not_block(self) -> None: with self._send_conflict_detector: @@ -184,10 +183,8 @@ async def receive_some(self, max_bytes: int | None = None) -> bytes: raise trio.ClosedResourceError( "file was already closed" ) from None - else: - raise trio.BrokenResourceError from exc - else: - break + raise trio.BrokenResourceError from exc + break return data diff --git a/src/trio/_util.py b/src/trio/_util.py index 7c9e194d19..ed7cec4f4b 100644 --- a/src/trio/_util.py +++ b/src/trio/_util.py @@ -222,8 +222,7 @@ def __init__(self, msg: str) -> None: def __enter__(self) -> None: if self._held: raise trio.BusyResourceError(self._msg) - else: - self._held = True + self._held = True def __exit__( self, diff --git a/src/trio/testing/_fake_net.py b/src/trio/testing/_fake_net.py index f8589f3a9c..f527d11fe5 100644 --- a/src/trio/testing/_fake_net.py +++ b/src/trio/testing/_fake_net.py @@ -42,7 +42,7 @@ def _family_for(ip: IPAddress) -> int: if isinstance(ip, ipaddress.IPv4Address): return trio.socket.AF_INET - elif isinstance(ip, ipaddress.IPv6Address): + if isinstance(ip, ipaddress.IPv6Address): return trio.socket.AF_INET6 raise NotImplementedError("Unhandled IPAddress instance type") # pragma: no cover @@ -50,7 +50,7 @@ def _family_for(ip: IPAddress) -> int: def _wildcard_ip_for(family: int) -> IPAddress: if family == trio.socket.AF_INET: return ipaddress.ip_address("0.0.0.0") - elif family == trio.socket.AF_INET6: + if family == trio.socket.AF_INET6: return ipaddress.ip_address("::") raise NotImplementedError("Unhandled ip address family") # pragma: no cover @@ -59,7 +59,7 @@ def _wildcard_ip_for(family: int) -> IPAddress: def _localhost_ip_for(family: int) -> IPAddress: # pragma: no cover if family == trio.socket.AF_INET: return ipaddress.ip_address("127.0.0.1") - elif family == trio.socket.AF_INET6: + if family == trio.socket.AF_INET6: return ipaddress.ip_address("::1") raise NotImplementedError("Unhandled ip address family") @@ -388,11 +388,10 @@ def getsockname(self) -> tuple[str, int] | tuple[str, int, int, int]: self._check_closed() if self._binding is not None: return self._binding.local.as_python_sockaddr() - elif self.family == trio.socket.AF_INET: + if self.family == trio.socket.AF_INET: return ("0.0.0.0", 0) - else: - assert self.family == trio.socket.AF_INET6 - return ("::", 0) + assert self.family == trio.socket.AF_INET6 + return ("::", 0) # TODO: This method is not tested, and seems to make incorrect assumptions. It should maybe raise NotImplementedError. def getpeername(self) -> tuple[str, int] | tuple[str, int, int, int]: diff --git a/src/trio/testing/_memory_streams.py b/src/trio/testing/_memory_streams.py index c9d430a9e6..fe7bbb6b6d 100644 --- a/src/trio/testing/_memory_streams.py +++ b/src/trio/testing/_memory_streams.py @@ -66,8 +66,7 @@ def _get_impl(self, max_bytes: int | None) -> bytearray: del self._data[:max_bytes] assert chunk return chunk - else: - return bytearray() + return bytearray() def get_nowait(self, max_bytes: int | None = None) -> bytearray: with self._fetch_lock: @@ -542,9 +541,8 @@ async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: del self._data[:max_bytes] self._something_happened() return got - else: - assert self._sender_closed - return b"" + assert self._sender_closed + return b"" class _LockstepSendStream(SendStream): From 5670c4c8d0dd336acc115a9e1e658c5af65ea24e Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Thu, 1 Aug 2024 00:30:43 -0500 Subject: [PATCH 2/3] Fix issue with `data` not being defined --- src/trio/_unix_pipes.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/trio/_unix_pipes.py b/src/trio/_unix_pipes.py index e11110ef3c..bbb28b445c 100644 --- a/src/trio/_unix_pipes.py +++ b/src/trio/_unix_pipes.py @@ -175,7 +175,7 @@ async def receive_some(self, max_bytes: int | None = None) -> bytes: await trio.lowlevel.checkpoint() while True: try: - data = os.read(self._fd_holder.fd, max_bytes) + return os.read(self._fd_holder.fd, max_bytes) except BlockingIOError: await trio.lowlevel.wait_readable(self._fd_holder.fd) except OSError as exc: @@ -184,9 +184,6 @@ async def receive_some(self, max_bytes: int | None = None) -> bytes: "file was already closed" ) from None raise trio.BrokenResourceError from exc - break - - return data def close(self) -> None: self._fd_holder.close() From a28ed50485b9c3fdf4f342cf25ae74f65150756d Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Thu, 1 Aug 2024 00:34:41 -0500 Subject: [PATCH 3/3] Fix type issue (mypy things names conflict) --- src/trio/_core/_ki.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/trio/_core/_ki.py b/src/trio/_core/_ki.py index 225c1cc2bb..51e8a871e2 100644 --- a/src/trio/_core/_ki.py +++ b/src/trio/_core/_ki.py @@ -177,11 +177,11 @@ def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[ return wrapper @wraps(fn) - def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: + def wrapper_(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled return fn(*args, **kwargs) - return wrapper + return wrapper_ return decorator