Skip to content

Commit

Permalink
type hints on main,plugin,updater,utilites.localsocket
Browse files Browse the repository at this point in the history
  • Loading branch information
marios8543 authored and AAGaming00 committed Sep 25, 2023
1 parent ecc5f5c commit 75fbc75
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 63 deletions.
37 changes: 22 additions & 15 deletions backend/localsocket.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import asyncio, time, random
import asyncio, time
from typing import Awaitable, Callable
import random

from localplatform import ON_WINDOWS

BUFFER_LIMIT = 2 ** 20 # 1 MiB

class UnixSocket:
def __init__(self, on_new_message):
def __init__(self, on_new_message: Callable[[str], Awaitable[str|None]]):
'''
on_new_message takes 1 string argument.
It's return value gets used, if not None, to write data to the socket.
Expand Down Expand Up @@ -46,28 +49,32 @@ async def close_socket_connection(self):
self.reader = None

async def read_single_line(self) -> str|None:
reader, writer = await self.get_socket_connection()
reader, _ = await self.get_socket_connection()

if self.reader == None:
return None
try:
assert reader
except AssertionError:
return

return await self._read_single_line(reader)

async def write_single_line(self, message : str):
reader, writer = await self.get_socket_connection()
_, writer = await self.get_socket_connection()

if self.writer == None:
return;
try:
assert writer
except AssertionError:
return

await self._write_single_line(writer, message)

async def _read_single_line(self, reader) -> str:
async def _read_single_line(self, reader: asyncio.StreamReader) -> str:
line = bytearray()
while True:
try:
line.extend(await reader.readuntil())
except asyncio.LimitOverrunError:
line.extend(await reader.read(reader._limit))
line.extend(await reader.read(reader._limit)) # type: ignore
continue
except asyncio.IncompleteReadError as err:
line.extend(err.partial)
Expand All @@ -77,27 +84,27 @@ async def _read_single_line(self, reader) -> str:

return line.decode("utf-8")

async def _write_single_line(self, writer, message : str):
async def _write_single_line(self, writer: asyncio.StreamWriter, message : str):
if not message.endswith("\n"):
message += "\n"

writer.write(message.encode("utf-8"))
await writer.drain()

async def _listen_for_method_call(self, reader, writer):
async def _listen_for_method_call(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
while True:
line = await self._read_single_line(reader)

try:
res = await self.on_new_message(line)
except Exception as e:
except Exception:
return

if res != None:
await self._write_single_line(writer, res)

class PortSocket (UnixSocket):
def __init__(self, on_new_message):
def __init__(self, on_new_message: Callable[[str], Awaitable[str|None]]):
'''
on_new_message takes 1 string argument.
It's return value gets used, if not None, to write data to the socket.
Expand Down Expand Up @@ -125,7 +132,7 @@ async def _open_socket_if_not_exists(self):
return True

if ON_WINDOWS:
class LocalSocket (PortSocket):
class LocalSocket (PortSocket): # type: ignore
pass
else:
class LocalSocket (UnixSocket):
Expand Down
15 changes: 8 additions & 7 deletions backend/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Change PyInstaller files permissions
import sys
from typing import Dict
from localplatform import (chmod, chown, service_stop, service_start,
ON_WINDOWS, get_log_level, get_live_reload,
get_server_port, get_server_host, get_chown_plugin_path,
Expand All @@ -16,7 +17,7 @@
import aiohttp_cors # type: ignore
# Partial imports
from aiohttp import client_exceptions
from aiohttp.web import Application, Response, get, run_app, static # type: ignore
from aiohttp.web import Application, Response, Request, get, run_app, static # type: ignore
from aiohttp_jinja2 import setup as jinja_setup

# local modules
Expand Down Expand Up @@ -70,7 +71,7 @@ def __init__(self, loop: AbstractEventLoop) -> None:

jinja_setup(self.web_app)

async def startup(_):
async def startup(_: Application):
if self.settings.getSetting("cef_forward", False):
self.loop.create_task(service_start(REMOTE_DEBUGGER_UNIT))
else:
Expand All @@ -84,16 +85,16 @@ async def startup(_):
self.web_app.add_routes([get("/auth/token", self.get_auth_token)])

for route in list(self.web_app.router.routes()):
self.cors.add(route)
self.cors.add(route) # type: ignore
self.web_app.add_routes([static("/static", path.join(path.dirname(__file__), 'static'))])
self.web_app.add_routes([static("/legacy", path.join(path.dirname(__file__), 'legacy'))])

def exception_handler(self, loop, context):
def exception_handler(self, loop: AbstractEventLoop, context: Dict[str, str]):
if context["message"] == "Unclosed connection":
return
loop.default_exception_handler(context)

async def get_auth_token(self, request):
async def get_auth_token(self, request: Request):
return Response(text=get_csrf_token())

async def load_plugins(self):
Expand Down Expand Up @@ -144,7 +145,7 @@ async def loader_reinjector(self):
# This is because of https://github.com/aio-libs/aiohttp/blob/3ee7091b40a1bc58a8d7846e7878a77640e96996/aiohttp/client_ws.py#L321
logger.info("CEF has disconnected...")
# At this point the loop starts again and we connect to the freshly started Steam client once it is ready.
except Exception as e:
except Exception:
logger.error("Exception while reading page events " + format_exc())
await tab.close_websocket()
pass
Expand All @@ -154,7 +155,7 @@ async def loader_reinjector(self):
# logger.info("Plugin loader isn't present in Steam anymore, reinjecting...")
# await self.inject_javascript(tab)

async def inject_javascript(self, tab: Tab, first=False, request=None):
async def inject_javascript(self, tab: Tab, first: bool=False, request: Request|None=None):
logger.info("Loading Decky frontend!")
try:
if first:
Expand Down
2 changes: 1 addition & 1 deletion backend/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, file: str, plugin_directory: str, plugin_path: str) -> None:
self.plugin_path = plugin_path
self.plugin_directory = plugin_directory
self.method_call_lock = Lock()
self.socket = LocalSocket(self._on_new_message)
self.socket: LocalSocket = LocalSocket(self._on_new_message)

self.version = None

Expand Down
27 changes: 18 additions & 9 deletions backend/updater.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,31 @@
import os
import shutil
import uuid
from asyncio import sleep
from ensurepip import version
from json.decoder import JSONDecodeError
from logging import getLogger
from os import getcwd, path, remove
from typing import List, TypedDict
from backend.main import PluginManager
from localplatform import chmod, service_restart, ON_LINUX, get_keep_systemd_service, get_selinux

from aiohttp import ClientSession, web

import helpers
from injector import get_gamepadui_tab, inject_to_tab
from injector import get_gamepadui_tab
from settings import SettingsManager

logger = getLogger("Updater")

class RemoteVerAsset(TypedDict):
name: str
browser_download_url: str
class RemoteVer(TypedDict):
tag_name: str
prerelease: bool
assets: List[RemoteVerAsset]

class Updater:
def __init__(self, context) -> None:
def __init__(self, context: PluginManager) -> None:
self.context = context
self.settings = self.context.settings
# Exposes updater methods to frontend
Expand All @@ -28,8 +36,8 @@ def __init__(self, context) -> None:
"do_restart": self.do_restart,
"check_for_updates": self.check_for_updates
}
self.remoteVer = None
self.allRemoteVers = None
self.remoteVer: RemoteVer | None = None
self.allRemoteVers: List[RemoteVer] = []
self.localVer = helpers.get_loader_version()

try:
Expand All @@ -44,15 +52,15 @@ def __init__(self, context) -> None:
])
context.loop.create_task(self.version_reloader())

async def _handle_server_method_call(self, request):
async def _handle_server_method_call(self, request: web.Request):
method_name = request.match_info["method_name"]
try:
args = await request.json()
except JSONDecodeError:
args = {}
res = {}
try:
r = await self.updater_methods[method_name](**args)
r = await self.updater_methods[method_name](**args) # type: ignore
res["result"] = r
res["success"] = True
except Exception as e:
Expand Down Expand Up @@ -105,7 +113,7 @@ async def check_for_updates(self):
selectedBranch = self.get_branch(self.context.settings)
async with ClientSession() as web:
async with web.request("GET", "https://api.github.com/repos/SteamDeckHomebrew/decky-loader/releases", ssl=helpers.get_ssl_context()) as res:
remoteVersions = await res.json()
remoteVersions: List[RemoteVer] = await res.json()
if selectedBranch == 0:
logger.debug("release type: release")
remoteVersions = list(filter(lambda ver: ver["tag_name"].startswith("v") and not ver["prerelease"] and not ver["tag_name"].find("-pre") > 0 and ver["tag_name"], remoteVersions))
Expand Down Expand Up @@ -142,6 +150,7 @@ async def version_reloader(self):

async def do_update(self):
logger.debug("Starting update.")
assert self.remoteVer
version = self.remoteVer["tag_name"]
download_url = None
download_filename = "PluginLoader" if ON_LINUX else "PluginLoader.exe"
Expand Down
Loading

0 comments on commit 75fbc75

Please sign in to comment.