Skip to content

Commit

Permalink
Add/complete type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
CoolCat467 authored Jan 30, 2025
2 parents e7706f4 + 35a6235 commit f5b2014
Show file tree
Hide file tree
Showing 12 changed files with 673 additions and 352 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ lint:
$(PYTHON) -m pylint trio_websocket/ tests/ autobahn/ examples/

typecheck:
$(PYTHON) -m mypy --explicit-package-bases trio_websocket tests autobahn examples
$(PYTHON) -m mypy

publish:
rm -fr build dist .egg trio_websocket.egg-info
Expand Down
17 changes: 10 additions & 7 deletions autobahn/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@
logger = logging.getLogger('client')


async def get_case_count(url):
async def get_case_count(url: str) -> int:
url = url + '/getCaseCount'
async with open_websocket_url(url) as conn:
case_count = await conn.get_message()
logger.info('Case count=%s', case_count)
return int(case_count)


async def get_case_info(url, case):
async def get_case_info(url: str, case: str) -> object:
url = f'{url}/getCaseInfo?case={case}'
async with open_websocket_url(url) as conn:
return json.loads(await conn.get_message())


async def run_case(url, case):
async def run_case(url: str, case: str) -> None:
url = f'{url}/runCase?case={case}&agent={AGENT}'
try:
async with open_websocket_url(url, max_message_size=MAX_MESSAGE_SIZE) as conn:
Expand All @@ -42,15 +42,15 @@ async def run_case(url, case):
pass


async def update_reports(url):
async def update_reports(url: str) -> None:
url = f'{url}/updateReports?agent={AGENT}'
async with open_websocket_url(url) as conn:
# This command runs as soon as we connect to it, so we don't need to
# send any messages.
pass


async def run_tests(args):
async def run_tests(args: argparse.Namespace) -> None:
logger = logging.getLogger('trio-websocket')
if args.debug_cases:
# Don't fetch case count when debugging a subset of test cases. It adds
Expand All @@ -62,7 +62,10 @@ async def run_tests(args):
test_cases = list(range(1, case_count + 1))
exception_cases = []
for case in test_cases:
case_id = (await get_case_info(args.url, case))['id']
result = await get_case_info(args.url, case)
assert isinstance(result, dict)
case_id = result['id']
assert isinstance(case_id, int)
if case_count:
logger.info("Running test case %s (%d of %d)", case_id, case, case_count)
else:
Expand All @@ -82,7 +85,7 @@ async def run_tests(args):
sys.exit(1)


def parse_args():
def parse_args() -> argparse.Namespace:
''' Parse command line arguments. '''
parser = argparse.ArgumentParser(description='Autobahn client for'
' trio-websocket')
Expand Down
6 changes: 3 additions & 3 deletions autobahn/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
connection_count = 0


async def main():
async def main() -> None:
''' Main entry point. '''
logger.info('Starting websocket server on ws://%s:%d', BIND_IP, BIND_PORT)
await serve_websocket(handler, BIND_IP, BIND_PORT, ssl_context=None,
max_message_size=MAX_MESSAGE_SIZE)


async def handler(request: WebSocketRequest):
async def handler(request: WebSocketRequest) -> None:
''' Reverse incoming websocket messages and send them back. '''
global connection_count # pylint: disable=global-statement
connection_count += 1
Expand All @@ -46,7 +46,7 @@ async def handler(request: WebSocketRequest):
logger.exception(' runtime exception handling connection #%d', connection_count)


def parse_args():
def parse_args() -> argparse.Namespace:
''' Parse command line arguments. '''
parser = argparse.ArgumentParser(description='Autobahn server for'
' trio-websocket')
Expand Down
30 changes: 19 additions & 11 deletions examples/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,23 @@
import ssl
import sys
import urllib.parse
from typing import NoReturn

import trio
from trio_websocket import open_websocket_url, ConnectionClosed, HandshakeError
from trio_websocket import (
open_websocket_url,
ConnectionClosed,
HandshakeError,
WebSocketConnection,
CloseReason,
)


logging.basicConfig(level=logging.DEBUG)
here = pathlib.Path(__file__).parent


def commands():
def commands() -> None:
''' Print the supported commands. '''
print('Commands: ')
print('send <MESSAGE> -> send message')
Expand All @@ -29,7 +36,7 @@ def commands():
print()


def parse_args():
def parse_args() -> argparse.Namespace:
''' Parse command line arguments. '''
parser = argparse.ArgumentParser(description='Example trio-websocket client')
parser.add_argument('--heartbeat', action='store_true',
Expand All @@ -38,7 +45,7 @@ def parse_args():
return parser.parse_args()


async def main(args):
async def main(args: argparse.Namespace) -> bool:
''' Main entry point, returning False in the case of logged error. '''
if urllib.parse.urlsplit(args.url).scheme == 'wss':
# Configure SSL context to handle our self-signed certificate. Most
Expand All @@ -59,9 +66,10 @@ async def main(args):
except HandshakeError as e:
logging.error('Connection attempt failed: %s', e)
return False
return True


async def handle_connection(ws, use_heartbeat):
async def handle_connection(ws: WebSocketConnection, use_heartbeat: bool) -> None:
''' Handle the connection. '''
logging.debug('Connected!')
try:
Expand All @@ -71,11 +79,12 @@ async def handle_connection(ws, use_heartbeat):
nursery.start_soon(get_commands, ws)
nursery.start_soon(get_messages, ws)
except ConnectionClosed as cc:
assert isinstance(cc.reason, CloseReason)
reason = '<no reason>' if cc.reason.reason is None else f'"{cc.reason.reason}"'
print(f'Closed: {cc.reason.code}/{cc.reason.name} {reason}')


async def heartbeat(ws, timeout, interval):
async def heartbeat(ws: WebSocketConnection, timeout: float, interval: float) -> NoReturn:
'''
Send periodic pings on WebSocket ``ws``.
Expand All @@ -99,11 +108,10 @@ async def heartbeat(ws, timeout, interval):
await trio.sleep(interval)


async def get_commands(ws):
async def get_commands(ws: WebSocketConnection) -> None:
''' In a loop: get a command from the user and execute it. '''
while True:
cmd = await trio.to_thread.run_sync(input, 'cmd> ',
cancellable=True)
cmd = await trio.to_thread.run_sync(input, 'cmd> ')
if cmd.startswith('ping'):
payload = cmd[5:].encode('utf8') or None
await ws.ping(payload)
Expand All @@ -123,11 +131,11 @@ async def get_commands(ws):
await trio.sleep(0.25)


async def get_messages(ws):
async def get_messages(ws: WebSocketConnection) -> None:
''' In a loop: get a WebSocket message and print it out. '''
while True:
message = await ws.get_message()
print(f'message: {message}')
print(f'message: {message!r}')


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion examples/generate-cert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import trustme

def main():
def main() -> None:
here = pathlib.Path(__file__).parent
ca_path = here / 'fake.ca.pem'
server_path = here / 'fake.server.pem'
Expand Down
8 changes: 4 additions & 4 deletions examples/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
import ssl

import trio
from trio_websocket import serve_websocket, ConnectionClosed
from trio_websocket import serve_websocket, ConnectionClosed, WebSocketRequest


logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
here = pathlib.Path(__file__).parent


def parse_args():
def parse_args() -> argparse.Namespace:
''' Parse command line arguments. '''
parser = argparse.ArgumentParser(description='Example trio-websocket client')
parser.add_argument('--ssl', action='store_true', help='Use SSL')
Expand All @@ -32,7 +32,7 @@ def parse_args():
return parser.parse_args()


async def main(args):
async def main(args: argparse.Namespace) -> None:
''' Main entry point. '''
logging.info('Starting websocket server…')
if args.ssl:
Expand All @@ -48,7 +48,7 @@ async def main(args):
await serve_websocket(handler, host, args.port, ssl_context)


async def handler(request):
async def handler(request: WebSocketRequest) -> None:
''' Reverse incoming websocket messages and send them back. '''
logging.info('Handler starting on path "%s"', request.path)
ws = await request.accept()
Expand Down
13 changes: 13 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[tool.mypy]
explicit_package_bases = true
files = ["trio_websocket", "tests", "autobahn", "examples"]
show_column_numbers = true
show_error_codes = true
show_traceback = true
disallow_any_decorated = true
disallow_any_unimported = true
ignore_missing_imports = true
local_partial_types = true
no_implicit_optional = true
strict = true
warn_unreachable = true
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@
'Programming Language :: Python :: 3.12',
'Programming Language :: Python :: Implementation :: CPython',
'Programming Language :: Python :: Implementation :: PyPy',
'Typing :: Typed',
],
python_requires=">=3.8",
keywords='websocket client server trio',
packages=find_packages(exclude=['docs', 'examples', 'tests']),
package_data={"trio-websocket": ["py.typed"]},
install_requires=[
'exceptiongroup; python_version<"3.11"',
'trio>=0.11',
Expand Down
Loading

0 comments on commit f5b2014

Please sign in to comment.