diff --git a/trio/_core/_exceptions.py b/trio/_core/_exceptions.py index 20f5235eb0..68961fc711 100644 --- a/trio/_core/_exceptions.py +++ b/trio/_core/_exceptions.py @@ -6,6 +6,14 @@ "ResourceBusyError", ] +# Exceptions often get printed as module.Classname. We pretend these are in +# the trio namespace (where they'll eventually end up) so that users get +# better messages. +def pretend_module_is_trio(cls): + cls.__module__ = "trio" + return cls + +@pretend_module_is_trio class TrioInternalError(Exception): """Raised by :func:`run` if we encounter a bug in trio, or (possibly) a misuse of one of the low-level :mod:`trio.hazmat` APIs. @@ -20,9 +28,8 @@ class TrioInternalError(Exception): """ pass -TrioInternalError.__module__ = "trio" - +@pretend_module_is_trio class RunFinishedError(RuntimeError): """Raised by ``run_in_trio_thread`` and similar functions if the corresponding call to :func:`trio.run` has already finished. @@ -30,18 +37,16 @@ class RunFinishedError(RuntimeError): """ pass -RunFinishedError.__module__ = "trio" - +@pretend_module_is_trio class WouldBlock(Exception): """Raised by ``X_nowait`` functions if ``X`` would block. """ pass -WouldBlock.__module__ = "trio" - +@pretend_module_is_trio class Cancelled(BaseException): """Raised by blocking calls if the surrounding scope has been cancelled. @@ -72,9 +77,8 @@ class Cancelled(BaseException): """ _scope = None -Cancelled.__module__ = "trio" - +@pretend_module_is_trio class ResourceBusyError(Exception): """Raised when a task attempts to use a resource that some other task is already using, and this would lead to bugs and nonsense. diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index 8725a08769..cc57dc825c 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -131,7 +131,8 @@ def __init__(self): # except that wakeup socket is mapped to None self._socket_waiters = {"read": {}, "write": {}} self._main_thread_waker = WakeupSocketpair() - self._socket_waiters["read"][self._main_thread_waker.wakeup_sock] = None + wakeup_sock = self._main_thread_waker.wakeup_sock + self._socket_waiters["read"][wakeup_sock] = None # This is necessary to allow control-C to interrupt select(). # https://github.com/python-trio/trio/issues/42 @@ -237,11 +238,12 @@ def do_select(): else: # dispatch on lpCompletionKey queue = self._completion_key_queues[entry.lpCompletionKey] + overlapped = int(ffi.cast("uintptr_t", entry.lpOverlapped)) + transferred = entry.dwNumberOfBytesTransferred info = CompletionKeyEventInfo( - lpOverlapped= - int(ffi.cast("uintptr_t", entry.lpOverlapped)), - dwNumberOfBytesTransferred= - entry.dwNumberOfBytesTransferred) + lpOverlapped=overlapped, + dwNumberOfBytesTransferred=transferred, + ) queue.put_nowait(info) def _iocp_thread_fn(self): diff --git a/trio/_core/_ki.py b/trio/_core/_ki.py index 3d2761bf25..7d4c380623 100644 --- a/trio/_core/_ki.py +++ b/trio/_core/_ki.py @@ -10,7 +10,9 @@ from . import _hazmat -__all__ = ["enable_ki_protection", "disable_ki_protection", "currently_ki_protected"] +__all__ = [ + "enable_ki_protection", "disable_ki_protection", "currently_ki_protected", +] # In ordinary single-threaded Python code, when you hit control-C, it raises # an exception and automatically does all the regular unwinding stuff. diff --git a/trio/_core/tests/test_run.py b/trio/_core/tests/test_run.py index db0c130c41..63a9f16be6 100644 --- a/trio/_core/tests/test_run.py +++ b/trio/_core/tests/test_run.py @@ -460,17 +460,22 @@ async def main(): ("before_run",), ("schedule", tasks["t1"]), ("schedule", tasks["t2"]), - {("before", tasks["t1"]), - ("after", tasks["t1"]), - ("before", tasks["t2"]), - ("after", tasks["t2"])}, - {("schedule", tasks["t1"]), - ("before", tasks["t1"]), - ("after", tasks["t1"]), - ("schedule", tasks["t2"]), - ("before", tasks["t2"]), - ("after", tasks["t2"])}, - ("after_run",)] + { + ("before", tasks["t1"]), + ("after", tasks["t1"]), + ("before", tasks["t2"]), + ("after", tasks["t2"]) + }, + { + ("schedule", tasks["t1"]), + ("before", tasks["t1"]), + ("after", tasks["t1"]), + ("schedule", tasks["t2"]), + ("before", tasks["t2"]), + ("after", tasks["t2"]) + }, + ("after_run",), + ] # yapf: disable print(list(r.filter_tasks(tasks.values()))) check_sequence_matches(list(r.filter_tasks(tasks.values())), expected) diff --git a/trio/socket.py b/trio/socket.py index cee5c6180b..a01888d5c0 100644 --- a/trio/socket.py +++ b/trio/socket.py @@ -22,30 +22,35 @@ def _reexport(name): # Usage: +# # async with _try_sync(): # return sync_call_that_might_fail_with_exception() -# # we only get here if the sync call in fact did fail with an exception -# # that passed the blocking_exc_check +# # we only get here if the sync call in fact did fail with a +# # BlockingIOError # return await do_it_properly_with_a_check_point() - -def _is_blocking_io_error(exc): - return isinstance(exc, BlockingIOError) - +# class _try_sync: - def __init__(self, blocking_exc_check=_is_blocking_io_error): - self._blocking_exc_check = blocking_exc_check + def __init__(self, blocking_exc_override=None): + self._blocking_exc_override = blocking_exc_override + + def _is_blocking_io_error(self, exc): + if self._blocking_exc_override is None: + return isinstance(exc, BlockingIOError) + else: + return self._blocking_exc_override(exc) async def __aenter__(self): await _core.yield_if_cancelled() async def __aexit__(self, etype, value, tb): - if value is not None and self._blocking_exc_check(value): - # discard the exception and fall through to the code below the + if value is not None and self._is_blocking_io_error(value): + # Discard the exception and fall through to the code below the # block return True else: await _core.yield_briefly_no_cancel() # Let the return or exception propagate + return False ################################################################ @@ -163,8 +168,15 @@ def set_custom_socket_factory(socket_factory): # getaddrinfo and friends ################################################################ +def _add_to_all(obj): + __all__.append(obj.__name__) + return obj + + _NUMERIC_ONLY = _stdlib_socket.AI_NUMERICHOST | _stdlib_socket.AI_NUMERICSERV + +@_add_to_all async def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): """Look up a numeric address given a name. @@ -215,9 +227,8 @@ def numeric_only_failure(exc): _stdlib_socket.getaddrinfo, host, port, family, type, proto, flags, cancellable=True) -__all__.append("getaddrinfo") - +@_add_to_all async def getnameinfo(sockaddr, flags): """Look up a name given a numeric address. @@ -235,9 +246,9 @@ async def getnameinfo(sockaddr, flags): return await _run_in_worker_thread( _stdlib_socket.getnameinfo, sockaddr, flags, cancellable=True) -__all__.append("getnameinfo") +@_add_to_all async def getprotobyname(name): """Look up a protocol number by name. (Rarely used.) @@ -247,7 +258,6 @@ async def getprotobyname(name): return await _run_in_worker_thread( _stdlib_socket.getprotobyname, name, cancellable=True) -__all__.append("getprotobyname") # obsolete gethostbyname etc. intentionally omitted @@ -256,32 +266,33 @@ async def getprotobyname(name): # Socket "constructors" ################################################################ +@_add_to_all def from_stdlib_socket(sock): """Convert a standard library :func:`socket.socket` object into a trio socket object. """ return _SocketType(sock) -__all__.append("from_stdlib_socket") @_wraps(_stdlib_socket.fromfd, assigned=(), updated=()) +@_add_to_all def fromfd(*args, **kwargs): """Like :func:`socket.fromfd`, but returns a trio socket object. """ return from_stdlib_socket(_stdlib_socket.fromfd(*args, **kwargs)) -__all__.append("fromfd") if hasattr(_stdlib_socket, "fromshare"): @_wraps(_stdlib_socket.fromshare, assigned=(), updated=()) + @_add_to_all def fromshare(*args, **kwargs): return from_stdlib_socket(_stdlib_socket.fromshare(*args, **kwargs)) - __all__.append("fromshare") @_wraps(_stdlib_socket.socketpair, assigned=(), updated=()) +@_add_to_all def socketpair(*args, **kwargs): """Like :func:`socket.socketpair`, but returns a pair of trio socket objects. @@ -289,10 +300,10 @@ def socketpair(*args, **kwargs): """ left, right = _stdlib_socket.socketpair(*args, **kwargs) return (from_stdlib_socket(left), from_stdlib_socket(right)) -__all__.append("socketpair") @_wraps(_stdlib_socket.socket, assigned=(), updated=()) +@_add_to_all def socket(family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None): """Create a new trio socket, like :func:`socket.socket`. @@ -306,13 +317,13 @@ def socket(family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None): return sf.socket(family, type, proto) stdlib_socket = _stdlib_socket.socket(family, type, proto, fileno) return from_stdlib_socket(stdlib_socket) -__all__.append("socket") ################################################################ # Type checking ################################################################ +@_add_to_all def is_trio_socket(obj): """Check whether the given object is a trio socket. @@ -325,7 +336,6 @@ def is_trio_socket(obj): return True return isinstance(obj, _SocketType) -__all__.append("is_trio_socket") ################################################################ # _SocketType @@ -349,6 +359,7 @@ def is_trio_socket(obj): def _real_type(type_num): return type_num & _SOCK_TYPE_MASK +@_add_to_all class _SocketType: def __init__(self, sock): if type(sock) is not _stdlib_socket.socket: diff --git a/trio/ssl.py b/trio/ssl.py index 322abe57bf..df7e284d13 100644 --- a/trio/ssl.py +++ b/trio/ssl.py @@ -174,6 +174,9 @@ def _reexport(name): globals()[name] = value __all__.append(name) + +# Intentionally not re-exported: +# SSLContext for _name in [ "SSLError", "SSLZeroReturnError", "SSLSyscallError", "SSLEOFError", "CertificateError", "create_default_context", "match_hostname", @@ -181,7 +184,6 @@ def _reexport(name): "get_default_verify_paths", "Purpose", "enum_certificates", "enum_crls", "SSLSession", "VerifyMode", "VerifyFlags", "Options", "AlertDescription", "SSLErrorNumber", - # Intentionally not re-exported: SSLContext ]: _reexport(_name) diff --git a/trio/tests/test_socket.py b/trio/tests/test_socket.py index 16404570d6..3c3749a85d 100644 --- a/trio/tests/test_socket.py +++ b/trio/tests/test_socket.py @@ -90,24 +90,27 @@ def test_socket_has_some_reexports(): ################################################################ async def test_getaddrinfo(monkeygai): - # Simple non-blocking non-error cases, ipv4 and ipv6: - with assert_yields(): - res = await tsocket.getaddrinfo( - "127.0.0.1", "12345", type=tsocket.SOCK_STREAM) def check(got, expected): # win32 returns 0 for the proto field def without_proto(gai_tup): return gai_tup[:2] + (0,) + gai_tup[3:] + expected2 = [without_proto(gt) for gt in expected] assert got == expected or got == expected2 + # Simple non-blocking non-error cases, ipv4 and ipv6: + with assert_yields(): + res = await tsocket.getaddrinfo( + "127.0.0.1", "12345", type=tsocket.SOCK_STREAM + ) + check(res, [ (tsocket.AF_INET, # 127.0.0.1 is ipv4 tsocket.SOCK_STREAM, tsocket.IPPROTO_TCP, "", ("127.0.0.1", 12345)), - ]) + ]) # yapf: disable with assert_yields(): res = await tsocket.getaddrinfo( @@ -118,7 +121,7 @@ def without_proto(gai_tup): tsocket.IPPROTO_UDP, "", ("::1", 12345, 0, 0)), - ]) + ]) # yapf: disable monkeygai.set("x", b"host", "port", family=0, type=0, proto=0, flags=0) with assert_yields(): @@ -149,8 +152,8 @@ async def test_getnameinfo(): # Trivial test: ni_numeric = stdlib_socket.NI_NUMERICHOST | stdlib_socket.NI_NUMERICSERV with assert_yields(): - assert (await tsocket.getnameinfo(("127.0.0.1", 1234), ni_numeric) - == ("127.0.0.1", "1234")) + got = await tsocket.getnameinfo(("127.0.0.1", 1234), ni_numeric) + assert got == ("127.0.0.1", "1234") # getnameinfo requires a numeric address as input: with assert_yields(): @@ -165,14 +168,14 @@ async def test_getnameinfo(): host, service = stdlib_socket.getnameinfo(("127.0.0.1", 80), 0) # Some working calls: - assert (await tsocket.getnameinfo(("127.0.0.1", 80), 0) - == (host, service)) + got = await tsocket.getnameinfo(("127.0.0.1", 80), 0) + assert got == (host, service) - assert (await tsocket.getnameinfo(("127.0.0.1", 80), tsocket.NI_NUMERICHOST) - == ("127.0.0.1", service)) + got = await tsocket.getnameinfo(("127.0.0.1", 80), tsocket.NI_NUMERICHOST) + assert got == ("127.0.0.1", service) - assert (await tsocket.getnameinfo(("127.0.0.1", 80), tsocket.NI_NUMERICSERV) - == (host, "80")) + got = await tsocket.getnameinfo(("127.0.0.1", 80), tsocket.NI_NUMERICSERV) + assert got == (host, "80") ################################################################ @@ -348,7 +351,9 @@ async def test_SocketType_shutdown(): @pytest.mark.parametrize("address, socket_type", [('127.0.0.1', tsocket.AF_INET), ('::1', tsocket.AF_INET6)]) async def test_SocketType_simple_server(address, socket_type): # listen, bind, accept, connect, getpeername, getsockname - with tsocket.socket(socket_type) as listener, tsocket.socket(socket_type) as client: + listener = tsocket.socket(socket_type) + client = tsocket.socket(socket_type) + with listener, client: listener.bind((address, 0)) listener.listen(20) addr = listener.getsockname()[:2] @@ -365,18 +370,20 @@ async def test_SocketType_simple_server(address, socket_type): async def test_SocketType_resolve(): sock4 = tsocket.socket(family=tsocket.AF_INET) with assert_yields(): - assert await sock4.resolve_local_address((None, 80)) == ("0.0.0.0", 80) + got = await sock4.resolve_local_address((None, 80)) + assert got == ("0.0.0.0", 80) with assert_yields(): - assert (await sock4.resolve_remote_address((None, 80)) - == ("127.0.0.1", 80)) + got = await sock4.resolve_remote_address((None, 80)) + assert got == ("127.0.0.1", 80) sock6 = tsocket.socket(family=tsocket.AF_INET6) with assert_yields(): - assert (await sock6.resolve_local_address((None, 80)) - == ("::", 80, 0, 0)) + got = await sock6.resolve_local_address((None, 80)) + assert got == ("::", 80, 0, 0) + with assert_yields(): - assert (await sock6.resolve_remote_address((None, 80)) - == ("::1", 80, 0, 0)) + got = await sock6.resolve_remote_address((None, 80)) + assert got == ("::1", 80, 0, 0) # AI_PASSIVE only affects the wildcard address, so for everything else # resolve_local_address and resolve_remote_address should work the same: