Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Strict mypy typing #3

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions chessnet/events.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from collections.abc import Iterable
from collections.abc import Awaitable, Callable, Iterable
from enum import Enum
from typing import Literal, Union
from typing import Dict, Literal, Union
import uuid

from pydantic.dataclasses import dataclass
Expand Down Expand Up @@ -41,48 +41,52 @@ class MakeMoveEvent:
Event = Union[
StartGameEvent,
EndGameEvent,
MakeMoveEvent,
]


class Events:
@staticmethod
def start_game(game: Game):
def start_game(game: Game) -> StartGameEvent:
return StartGameEvent(EventType.START_GAME, game)

@staticmethod
def end_game(game_id: str, outcome: str):
def end_game(game_id: str, outcome: str) -> EndGameEvent:
return EndGameEvent(EventType.END_GAME, game_id, outcome)

@staticmethod
def make_move(game_id: str, move: Move, fen_before: str, engine_id: str):
def make_move(game_id: str, move: Move, fen_before: str, engine_id: str) -> MakeMoveEvent:
return MakeMoveEvent(EventType.MAKE_MOVE, game_id, move, fen_before, engine_id)


EventCallback = Callable[[Event], Awaitable[None]]


class Subscription:
def __init__(self, channel: str, typs: Iterable[EventType], callback):
def __init__(self, channel: str, typs: Iterable[EventType], callback: EventCallback):
self.channel = channel
self.typs = set(typs)
self.callback = callback

async def accept(self, event: Event):
async def accept(self, event: Event) -> None:
if event.typ in self.typs:
await self.callback(event)


class Broker:
def __init__(self):
def __init__(self) -> None:
# Subscriptions per channel.
# Special key "*" for all.
self.subscriptions = {}
self.subscriptions: Dict[str, Dict[str, Subscription]] = {}

def publish(self, channel: str, event: Event):
def publish(self, channel: str, event: Event) -> None:
for sub in self.subscriptions.get(channel, {}).values():
asyncio.create_task(sub.accept(event))

for sub in self.subscriptions.get("*", {}).values():
asyncio.create_task(sub.accept(event))

def subscribe(self, channel, typs: Iterable[EventType], callback) -> str:
def subscribe(self, channel: str, typs: Iterable[EventType], callback: EventCallback) -> str:
handle = str(uuid.uuid4())
sub = Subscription(channel, typs, callback)
if channel not in self.subscriptions:
Expand All @@ -91,7 +95,7 @@ def subscribe(self, channel, typs: Iterable[EventType], callback) -> str:
self.subscriptions[channel][handle] = sub
return handle

def unsubscribe(self, handle: str):
def unsubscribe(self, handle: str) -> None:
channels = list(self.subscriptions.keys())
for channel in channels:
if handle in self.subscriptions[channel]:
Expand Down
55 changes: 32 additions & 23 deletions chessnet/fargate.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
import functools
import logging
import os
import time
from typing import cast, Any, Dict, Optional, TypeVar

import boto3
from mypy_boto3_ec2.client import EC2Client
from mypy_boto3_ecs.client import ECSClient
from mypy_boto3_ecs.type_defs import AttachmentTypeDef
import chess

from chessnet.storage import Engine
Expand All @@ -23,9 +28,11 @@
log = logging.getLogger(__name__)


def run_in_executor(f):
ReturnType = TypeVar("ReturnType")

def run_in_executor(f: Callable[..., ReturnType]) -> Callable[..., Awaitable[ReturnType]]:
@functools.wraps(f)
async def _async_f(*args, **kwargs):
async def _async_f(*args: Any, **kwargs: Any) -> ReturnType:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, lambda: f(*args, **kwargs))
return _async_f
Expand All @@ -39,13 +46,13 @@ class RunningEngine:


class FargateRunner(EngineRunner):
def __init__(self, manager, engine):
def __init__(self, manager: FargateEngineManager, engine: Engine):
self.manager = manager
self._engine = engine
self.protocol = None
self.running_engine = None
self.protocol: Optional[chess.engine.UciProtocol] = None
self.running_engine: Optional[RunningEngine] = None

async def run(self):
async def run(self) -> None:
log.info("Starting container...")
self.running_engine = await self.manager.run_engine(self._engine)
try:
Expand All @@ -54,18 +61,20 @@ async def run(self):
lambda: ProtocolAdapter(chess.engine.UciProtocol()),
host=self.running_engine.ip_addr,
port=self.running_engine.port)
self.protocol = adapter.protocol
self.protocol = cast(ProtocolAdapter, adapter).protocol

log.info("Initializing engine...")
await self.protocol.initialize()
except:
await self.shutdown()
except Exception as e:
await self.shutdown(f"Error during initialization: {type(e).__name__}: {e}")
raise

async def play(self, board, limit):
async def play(self, board: chess.Board, limit: chess.engine.Limit) -> chess.engine.PlayResult:
if self.protocol is None:
raise Exception("Engine is not running")
return await self.protocol.play(board, limit)

async def shutdown(self, reason):
async def shutdown(self, reason: str) -> None:
await self.manager.stop_engine(self.running_engine, reason)

def engine(self) -> Engine:
Expand All @@ -75,14 +84,14 @@ def engine(self) -> Engine:
class FargateEngineManager():
TASK_DEF_VERSION = 2

def __init__(self, cluster):
self.client = boto3.client(
def __init__(self, cluster: str):
self.client: ECSClient = boto3.client(
'ecs',
region_name=AWS_REGION,
aws_access_key_id=AWS_ACCESS_KEY_ID,
aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
)
self.ec2_client = boto3.client(
self.ec2_client: EC2Client = boto3.client(
'ec2',
region_name=AWS_REGION,
aws_access_key_id=AWS_ACCESS_KEY_ID,
Expand All @@ -91,21 +100,21 @@ def __init__(self, cluster):
self.cluster = cluster

@run_in_executor
def run_engine(self, engine):
def run_engine(self, engine: Engine) -> RunningEngine:
task_def = self._get_or_create_task_definition(engine)
task = self.client.run_task(**self._run_task_configuration(task_def))["tasks"][0]
running_engine = self._wait_for_ready(task["taskArn"])
return running_engine

@run_in_executor
def stop_engine(self, running_engine: RunningEngine, reason: str):
def stop_engine(self, running_engine: RunningEngine, reason: str) -> None:
self.client.stop_task(
cluster=self.cluster,
task=running_engine.task_arn,
reason=reason,
)

def _wait_for_ready(self, task_arn):
def _wait_for_ready(self, task_arn: str) -> RunningEngine:
sleeps = [0, 5, 5, 10, 10, 30, 30, 60, 60]
for sleep in sleeps:
time.sleep(sleep)
Expand Down Expand Up @@ -138,7 +147,7 @@ def _wait_for_ready(self, task_arn):

raise Exception("Task took too long to start")

def _get_or_create_task_definition(self, engine):
def _get_or_create_task_definition(self, engine: Engine) -> str:
task_name = self._safe_name(engine.id())
try:
description = self.client.describe_task_definition(taskDefinition=task_name, include=["TAGS"])
Expand All @@ -155,16 +164,16 @@ def _get_or_create_task_definition(self, engine):
log.info(task_def)
return task_def["family"]

def _safe_name(self, name):
def _safe_name(self, name: str) -> str:
return name.replace("#", "_")

def _eni_detail(self, eni, name):
def _eni_detail(self, eni: AttachmentTypeDef, name: str) -> Any:
values = [d["value"] for d in eni["details"] if d["name"] == name]
if len(values) == 0:
return None
return values[0]

def _task_definition(self, engine):
def _task_definition(self, engine: Engine) -> Dict[str, Any]:
return {
"family": self._safe_name(engine.id()),
"networkMode": "awsvpc",
Expand All @@ -186,7 +195,7 @@ def _task_definition(self, engine):
"tags": [self._version_tag()],
}

def _run_task_configuration(self, task_def):
def _run_task_configuration(self, task_def: str) -> Dict[str, Any]:
return {
"cluster": self.cluster,
"taskDefinition": task_def,
Expand All @@ -205,7 +214,7 @@ def _run_task_configuration(self, task_def):
},
}

def _version_tag(self):
def _version_tag(self) -> Dict[str, str]:
return {
"key": "task_definition_version",
"value": f"v{self.TASK_DEF_VERSION}",
Expand Down
35 changes: 18 additions & 17 deletions chessnet/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Optional

import chess
import docker

from chessnet.events import Broker, Events
from chessnet.runner import EngineRunner
Expand All @@ -27,15 +26,16 @@ async def play_game(broker: Broker, game_id: str, white: EngineRunner, black: En
res = await white.play(board, chess.engine.Limit(time=0.1))
log.info(f"Got move: {res.move}")

broker.publish(game_id, Events.make_move(
game_id,
Move(res.move.uci(), 0),
board.fen(),
white.engine().id(),
))
if res.move is not None:
broker.publish(game_id, Events.make_move(
game_id,
Move(res.move.uci(), 0),
board.fen(),
white.engine().id(),
))

board.push(res.move)
log.info(board)
board.push(res.move)
log.info("\n" + str(board))

outcome = board.outcome()
if outcome is not None:
Expand All @@ -46,14 +46,15 @@ async def play_game(broker: Broker, game_id: str, white: EngineRunner, black: En
res = await black.play(board, chess.engine.Limit(time=0.1))
log.info(f"Got move: {res.move}")

broker.publish(game_id, Events.make_move(
game_id,
Move(res.move.uci(), 0),
board.fen(),
white.engine().id(),
))
board.push(res.move)
log.info("\n" + str(board))
if res.move is not None:
broker.publish(game_id, Events.make_move(
game_id,
Move(res.move.uci(), 0),
board.fen(),
white.engine().id(),
))
board.push(res.move)
log.info("\n" + str(board))

outcome = board.outcome()
if outcome is not None:
Expand Down
Loading