Skip to content

Commit

Permalink
Improve reliability of reload (#631)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dreamsorcerer authored Dec 7, 2023
1 parent e3a340f commit 04eac2a
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 7 deletions.
7 changes: 6 additions & 1 deletion aiohttp_devtools/runserver/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import json
import mimetypes
import sys
import time
import warnings
from errno import EADDRINUSE
from pathlib import Path
from typing import Any, Iterator, NoReturn, Optional, Set, Tuple
from typing import Any, Iterator, List, NoReturn, Optional, Set, Tuple

from aiohttp import WSMsgType, web
from aiohttp.hdrs import LAST_MODIFIED, CONTENT_LENGTH
Expand All @@ -33,6 +34,7 @@
LIVE_RELOAD_LOCAL_SNIPPET = b'\n<script src="/livereload.js"></script>\n'
HOST = '0.0.0.0'

LAST_RELOAD = web.AppKey("LAST_RELOAD", List[float])
LIVERELOAD_SCRIPT = web.AppKey("LIVERELOAD_SCRIPT", bytes)
STATIC_PATH = web.AppKey("STATIC_PATH", str)
STATIC_URL = web.AppKey("STATIC_URL", str)
Expand Down Expand Up @@ -240,6 +242,8 @@ async def src_reload(app: web.Application, path: Optional[str] = None) -> int:
else:
reloads += 1

app[LAST_RELOAD][0] = len(app[WS])
app[LAST_RELOAD][1] = time.time()
if reloads:
s = '' if reloads == 1 else 's'
aux_logger.info('prompted reload of %s on %d client%s', path or 'page', reloads, s)
Expand All @@ -256,6 +260,7 @@ def create_auxiliary_app(
browser_cache: bool = False) -> web.Application:
app = web.Application()
ws: Set[Tuple[web.WebSocketResponse, str]] = set()
app[LAST_RELOAD] = [0, 0.]
app[STATIC_PATH] = static_path or ""
app[STATIC_URL] = static_url
app[WS] = ws
Expand Down
17 changes: 15 additions & 2 deletions aiohttp_devtools/runserver/watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import signal
import sys
import time
from contextlib import suppress
from multiprocessing import Process
from pathlib import Path
Expand All @@ -14,7 +15,7 @@
from ..exceptions import AiohttpDevException
from ..logs import rs_dft_logger as logger
from .config import Config
from .serve import STATIC_PATH, WS, serve_main_app, src_reload
from .serve import LAST_RELOAD, STATIC_PATH, WS, serve_main_app, src_reload


class WatchTask:
Expand All @@ -27,7 +28,7 @@ def __init__(self, path: Union[Path, str]):
async def start(self, app: web.Application) -> None:
self._app = app
self.stopper = asyncio.Event()
self._awatch = awatch(self._path, stop_event=self.stopper)
self._awatch = awatch(self._path, stop_event=self.stopper, step=250)
self._task = asyncio.create_task(self._run())

async def _run(self) -> None:
Expand Down Expand Up @@ -71,8 +72,20 @@ def is_static(changes: Iterable[Tuple[object, str]]) -> bool:

async for changes in self._awatch:
self._reloads += 1
logger.debug("file changes: %s", changes)
if any(f.endswith('.py') for _, f in changes):
logger.debug('%d changes, restarting server', len(changes))

count, t = self._app[LAST_RELOAD]
if len(self._app[WS]) < count:
wait_delay = max(t + 5 - time.time(), 0)
logger.debug("waiting upto %s seconds before restarting", wait_delay)

for i in range(int(wait_delay / 0.1)):
await asyncio.sleep(0.1)
if len(self._app[WS]) >= count:
break

await self._stop_dev_server()
self._start_dev_server()
await self._src_reload_when_live(live_checks)
Expand Down
6 changes: 5 additions & 1 deletion tests/test_runserver_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from aiohttp_devtools.runserver.config import Config
from aiohttp_devtools.runserver.log_handlers import fmt_size
from aiohttp_devtools.runserver.serve import (
STATIC_PATH, STATIC_URL, WS, check_port_open, cleanup_aux_app,
LAST_RELOAD, STATIC_PATH, STATIC_URL, WS, check_port_open, cleanup_aux_app,
modify_main_app, src_reload)

from .conftest import SIMPLE_APP, create_future
Expand All @@ -36,6 +36,7 @@ async def test_aux_reload(smart_caplog):
aux_app = Application()
ws = MagicMock()
ws.send_str = MagicMock(return_value=create_future())
aux_app[LAST_RELOAD] = [0, 0.]
aux_app[STATIC_PATH] = "/path/to/static_files/"
aux_app[STATIC_URL] = "/static/"
aux_app[WS] = set(((ws, "/foo/bar"),)) # type: ignore[misc]
Expand All @@ -56,6 +57,7 @@ async def test_aux_reload_no_path():
aux_app = Application()
ws = MagicMock()
ws.send_str = MagicMock(return_value=create_future())
aux_app[LAST_RELOAD] = [0, 0.]
aux_app[STATIC_PATH] = "/path/to/static_files/"
aux_app[STATIC_URL] = "/static/"
aux_app[WS] = set(((ws, "/foo/bar"),)) # type: ignore[misc]
Expand All @@ -74,6 +76,7 @@ async def test_aux_reload_html_different():
aux_app = Application()
ws = MagicMock()
ws.send_str = MagicMock(return_value=create_future())
aux_app[LAST_RELOAD] = [0, 0.]
aux_app[STATIC_PATH] = "/path/to/static_files/"
aux_app[STATIC_URL] = "/static/"
aux_app[WS] = set(((ws, "/foo/bar"),)) # type: ignore[misc]
Expand All @@ -86,6 +89,7 @@ async def test_aux_reload_runtime_error(smart_caplog):
ws = MagicMock()
ws.send_str = MagicMock(return_value=create_future())
ws.send_str = MagicMock(side_effect=RuntimeError('foobar'))
aux_app[LAST_RELOAD] = [0, 0.]
aux_app[STATIC_PATH] = "/path/to/static_files/"
aux_app[STATIC_URL] = "/static/"
aux_app[WS] = set(((ws, "/foo/bar"),)) # type: ignore[misc]
Expand Down
40 changes: 37 additions & 3 deletions tests/test_runserver_watch.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import asyncio
import time
from functools import partial
from typing import Set, Tuple
from unittest.mock import MagicMock, call
from typing import Any, Set, Tuple
from unittest.mock import AsyncMock, MagicMock, call

from aiohttp import ClientSession
from aiohttp.web import Application, WebSocketResponse

from aiohttp_devtools.runserver.serve import STATIC_PATH, WS
from aiohttp_devtools.runserver.serve import LAST_RELOAD, STATIC_PATH, WS
from aiohttp_devtools.runserver.watch import AppTask, LiveReloadTask

from .conftest import create_future
Expand Down Expand Up @@ -81,6 +82,7 @@ async def test_python_no_server(event_loop, mocker):
stop_mock = mocker.patch.object(app_task, "_stop_dev_server", autospec=True)
mocker.patch.object(app_task, "_run", partial(app_task._run, live_checks=2))
app = Application()
app[LAST_RELOAD] = [0, 0.]
app[STATIC_PATH] = "/path/to/"
app.src_reload = MagicMock()
mock_ws = MagicMock()
Expand Down Expand Up @@ -192,3 +194,35 @@ async def test_stop_process_dirty(mocker):
await app_task._stop_dev_server()
assert mock_kill.call_args_list == [call(321, 2)]
assert process_mock.kill.called_once()


async def test_restart_after_connection_loss(mocker):
mocked_awatch = mocker.patch("aiohttp_devtools.runserver.watch.awatch", autospec=True, spec_set=True)
mocked_awatch.side_effect = create_awatch_mock({("x", "/path/to/file.py")})
app_task = AppTask(MagicMock())
start_mock = mocker.patch.object(app_task, "_start_dev_server", autospec=True, spec_set=True)
mock_reload = mocker.patch.object(app_task, "_src_reload_when_live", autospec=True, spec_set=True)
mocker.patch.object(app_task, "_stop_dev_server", autospec=True, spec_set=True)

app = mocker.create_autospec(Application, spec_set=True, instance=True)
# Simulate connection lost from recent restart.
ws: Set[Any] = set()
d = {WS: ws, LAST_RELOAD: [1, time.time()]}
app.__getitem__.side_effect = lambda k: d.get(k, MagicMock())

def update_ws(i):
ws.add(MagicMock(spec_set=()))
return AsyncMock()

sleep_mock = mocker.patch("asyncio.sleep", autospec=True, spec_set=True)
sleep_mock.side_effect = update_ws

await app_task.start(app)
assert app_task._task is not None
await app_task._task
assert sleep_mock.call_count < 5
assert call(0.1) in sleep_mock.call_args_list
mock_reload.assert_called_once()
assert start_mock.call_count == 2
assert app_task._session is not None
await app_task._session.close()

0 comments on commit 04eac2a

Please sign in to comment.