From 75fbc7524f33dd7c5bd9aea93c2907d2738de930 Mon Sep 17 00:00:00 2001 From: marios8543 Date: Mon, 18 Sep 2023 00:31:54 +0300 Subject: [PATCH] type hints on main,plugin,updater,utilites.localsocket --- backend/localsocket.py | 37 ++++++++++++--------- backend/main.py | 15 +++++---- backend/plugin.py | 2 +- backend/updater.py | 27 ++++++++++----- backend/utilities.py | 74 ++++++++++++++++++++++++------------------ 5 files changed, 92 insertions(+), 63 deletions(-) diff --git a/backend/localsocket.py b/backend/localsocket.py index ef0e3933a..3659da038 100644 --- a/backend/localsocket.py +++ b/backend/localsocket.py @@ -1,10 +1,13 @@ -import asyncio, time, random +import asyncio, time +from typing import Awaitable, Callable +import random + from localplatform import ON_WINDOWS BUFFER_LIMIT = 2 ** 20 # 1 MiB class UnixSocket: - def __init__(self, on_new_message): + def __init__(self, on_new_message: Callable[[str], Awaitable[str|None]]): ''' on_new_message takes 1 string argument. It's return value gets used, if not None, to write data to the socket. @@ -46,28 +49,32 @@ async def close_socket_connection(self): self.reader = None async def read_single_line(self) -> str|None: - reader, writer = await self.get_socket_connection() + reader, _ = await self.get_socket_connection() - if self.reader == None: - return None + try: + assert reader + except AssertionError: + return return await self._read_single_line(reader) async def write_single_line(self, message : str): - reader, writer = await self.get_socket_connection() + _, writer = await self.get_socket_connection() - if self.writer == None: - return; + try: + assert writer + except AssertionError: + return await self._write_single_line(writer, message) - async def _read_single_line(self, reader) -> str: + async def _read_single_line(self, reader: asyncio.StreamReader) -> str: line = bytearray() while True: try: line.extend(await reader.readuntil()) except asyncio.LimitOverrunError: - line.extend(await reader.read(reader._limit)) + line.extend(await reader.read(reader._limit)) # type: ignore continue except asyncio.IncompleteReadError as err: line.extend(err.partial) @@ -77,27 +84,27 @@ async def _read_single_line(self, reader) -> str: return line.decode("utf-8") - async def _write_single_line(self, writer, message : str): + async def _write_single_line(self, writer: asyncio.StreamWriter, message : str): if not message.endswith("\n"): message += "\n" writer.write(message.encode("utf-8")) await writer.drain() - async def _listen_for_method_call(self, reader, writer): + async def _listen_for_method_call(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): while True: line = await self._read_single_line(reader) try: res = await self.on_new_message(line) - except Exception as e: + except Exception: return if res != None: await self._write_single_line(writer, res) class PortSocket (UnixSocket): - def __init__(self, on_new_message): + def __init__(self, on_new_message: Callable[[str], Awaitable[str|None]]): ''' on_new_message takes 1 string argument. It's return value gets used, if not None, to write data to the socket. @@ -125,7 +132,7 @@ async def _open_socket_if_not_exists(self): return True if ON_WINDOWS: - class LocalSocket (PortSocket): + class LocalSocket (PortSocket): # type: ignore pass else: class LocalSocket (UnixSocket): diff --git a/backend/main.py b/backend/main.py index 433b202fa..8857fb225 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,5 +1,6 @@ # Change PyInstaller files permissions import sys +from typing import Dict from localplatform import (chmod, chown, service_stop, service_start, ON_WINDOWS, get_log_level, get_live_reload, get_server_port, get_server_host, get_chown_plugin_path, @@ -16,7 +17,7 @@ import aiohttp_cors # type: ignore # Partial imports from aiohttp import client_exceptions -from aiohttp.web import Application, Response, get, run_app, static # type: ignore +from aiohttp.web import Application, Response, Request, get, run_app, static # type: ignore from aiohttp_jinja2 import setup as jinja_setup # local modules @@ -70,7 +71,7 @@ def __init__(self, loop: AbstractEventLoop) -> None: jinja_setup(self.web_app) - async def startup(_): + async def startup(_: Application): if self.settings.getSetting("cef_forward", False): self.loop.create_task(service_start(REMOTE_DEBUGGER_UNIT)) else: @@ -84,16 +85,16 @@ async def startup(_): self.web_app.add_routes([get("/auth/token", self.get_auth_token)]) for route in list(self.web_app.router.routes()): - self.cors.add(route) + self.cors.add(route) # type: ignore self.web_app.add_routes([static("/static", path.join(path.dirname(__file__), 'static'))]) self.web_app.add_routes([static("/legacy", path.join(path.dirname(__file__), 'legacy'))]) - def exception_handler(self, loop, context): + def exception_handler(self, loop: AbstractEventLoop, context: Dict[str, str]): if context["message"] == "Unclosed connection": return loop.default_exception_handler(context) - async def get_auth_token(self, request): + async def get_auth_token(self, request: Request): return Response(text=get_csrf_token()) async def load_plugins(self): @@ -144,7 +145,7 @@ async def loader_reinjector(self): # This is because of https://github.com/aio-libs/aiohttp/blob/3ee7091b40a1bc58a8d7846e7878a77640e96996/aiohttp/client_ws.py#L321 logger.info("CEF has disconnected...") # At this point the loop starts again and we connect to the freshly started Steam client once it is ready. - except Exception as e: + except Exception: logger.error("Exception while reading page events " + format_exc()) await tab.close_websocket() pass @@ -154,7 +155,7 @@ async def loader_reinjector(self): # logger.info("Plugin loader isn't present in Steam anymore, reinjecting...") # await self.inject_javascript(tab) - async def inject_javascript(self, tab: Tab, first=False, request=None): + async def inject_javascript(self, tab: Tab, first: bool=False, request: Request|None=None): logger.info("Loading Decky frontend!") try: if first: diff --git a/backend/plugin.py b/backend/plugin.py index 781d9f7b3..5c1e099fa 100644 --- a/backend/plugin.py +++ b/backend/plugin.py @@ -20,7 +20,7 @@ def __init__(self, file: str, plugin_directory: str, plugin_path: str) -> None: self.plugin_path = plugin_path self.plugin_directory = plugin_directory self.method_call_lock = Lock() - self.socket = LocalSocket(self._on_new_message) + self.socket: LocalSocket = LocalSocket(self._on_new_message) self.version = None diff --git a/backend/updater.py b/backend/updater.py index 6b38dd25d..d7a3d7128 100644 --- a/backend/updater.py +++ b/backend/updater.py @@ -1,23 +1,31 @@ import os import shutil -import uuid from asyncio import sleep -from ensurepip import version from json.decoder import JSONDecodeError from logging import getLogger from os import getcwd, path, remove +from typing import List, TypedDict +from backend.main import PluginManager from localplatform import chmod, service_restart, ON_LINUX, get_keep_systemd_service, get_selinux from aiohttp import ClientSession, web import helpers -from injector import get_gamepadui_tab, inject_to_tab +from injector import get_gamepadui_tab from settings import SettingsManager logger = getLogger("Updater") +class RemoteVerAsset(TypedDict): + name: str + browser_download_url: str +class RemoteVer(TypedDict): + tag_name: str + prerelease: bool + assets: List[RemoteVerAsset] + class Updater: - def __init__(self, context) -> None: + def __init__(self, context: PluginManager) -> None: self.context = context self.settings = self.context.settings # Exposes updater methods to frontend @@ -28,8 +36,8 @@ def __init__(self, context) -> None: "do_restart": self.do_restart, "check_for_updates": self.check_for_updates } - self.remoteVer = None - self.allRemoteVers = None + self.remoteVer: RemoteVer | None = None + self.allRemoteVers: List[RemoteVer] = [] self.localVer = helpers.get_loader_version() try: @@ -44,7 +52,7 @@ def __init__(self, context) -> None: ]) context.loop.create_task(self.version_reloader()) - async def _handle_server_method_call(self, request): + async def _handle_server_method_call(self, request: web.Request): method_name = request.match_info["method_name"] try: args = await request.json() @@ -52,7 +60,7 @@ async def _handle_server_method_call(self, request): args = {} res = {} try: - r = await self.updater_methods[method_name](**args) + r = await self.updater_methods[method_name](**args) # type: ignore res["result"] = r res["success"] = True except Exception as e: @@ -105,7 +113,7 @@ async def check_for_updates(self): selectedBranch = self.get_branch(self.context.settings) async with ClientSession() as web: async with web.request("GET", "https://api.github.com/repos/SteamDeckHomebrew/decky-loader/releases", ssl=helpers.get_ssl_context()) as res: - remoteVersions = await res.json() + remoteVersions: List[RemoteVer] = await res.json() if selectedBranch == 0: logger.debug("release type: release") remoteVersions = list(filter(lambda ver: ver["tag_name"].startswith("v") and not ver["prerelease"] and not ver["tag_name"].find("-pre") > 0 and ver["tag_name"], remoteVersions)) @@ -142,6 +150,7 @@ async def version_reloader(self): async def do_update(self): logger.debug("Starting update.") + assert self.remoteVer version = self.remoteVer["tag_name"] download_url = None download_filename = "PluginLoader" if ON_LINUX else "PluginLoader.exe" diff --git a/backend/utilities.py b/backend/utilities.py index 72b6f008a..d45bec9bb 100644 --- a/backend/utilities.py +++ b/backend/utilities.py @@ -1,3 +1,4 @@ +from os import stat_result import uuid from json.decoder import JSONDecodeError from os.path import splitext @@ -5,12 +6,12 @@ from traceback import format_exc from stat import FILE_ATTRIBUTE_HIDDEN # type: ignore -from asyncio import start_server, gather, open_connection +from asyncio import StreamReader, StreamWriter, start_server, gather, open_connection from aiohttp import ClientSession, web -from typing import Dict +from typing import Callable, Coroutine, Dict, Any, List, TypedDict from logging import getLogger -from backend.browser import PluginInstallType +from backend.browser import PluginInstallRequest, PluginInstallType from backend.main import PluginManager from injector import inject_to_tab, get_gamepadui_tab, close_old_tabs, get_tab from pathlib import Path @@ -18,10 +19,15 @@ import helpers from localplatform import service_stop, service_start, get_home_path, get_username +class FilePickerObj(TypedDict): + file: Path + filest: stat_result + is_dir: bool + class Utilities: def __init__(self, context: PluginManager) -> None: self.context = context - self.util_methods: Dict[] = { + self.util_methods: Dict[str, Callable[..., Coroutine[Any, Any, Any]]] = { "ping": self.ping, "http_request": self.http_request, "install_plugin": self.install_plugin, @@ -54,7 +60,7 @@ def __init__(self, context: PluginManager) -> None: web.post("/methods/{method_name}", self._handle_server_method_call) ]) - async def _handle_server_method_call(self, request): + async def _handle_server_method_call(self, request: web.Request): method_name = request.match_info["method_name"] try: args = await request.json() @@ -70,7 +76,7 @@ async def _handle_server_method_call(self, request): res["success"] = False return web.json_response(res) - async def install_plugin(self, artifact="", name="No name", version="dev", hash=False, install_type=PluginInstallType.INSTALL): + async def install_plugin(self, artifact: str="", name: str="No name", version: str="dev", hash: str="", install_type: PluginInstallType=PluginInstallType.INSTALL): return await self.context.plugin_browser.request_plugin_install( artifact=artifact, name=name, @@ -79,21 +85,21 @@ async def install_plugin(self, artifact="", name="No name", version="dev", hash= install_type=install_type ) - async def install_plugins(self, requests): + async def install_plugins(self, requests: List[PluginInstallRequest]): return await self.context.plugin_browser.request_multiple_plugin_installs( requests=requests ) - async def confirm_plugin_install(self, request_id): + async def confirm_plugin_install(self, request_id: str): return await self.context.plugin_browser.confirm_plugin_install(request_id) - def cancel_plugin_install(self, request_id): + async def cancel_plugin_install(self, request_id: str): return self.context.plugin_browser.cancel_plugin_install(request_id) - async def uninstall_plugin(self, name): + async def uninstall_plugin(self, name: str): return await self.context.plugin_browser.uninstall_plugin(name) - async def http_request(self, method="", url="", **kwargs): + async def http_request(self, method: str="", url: str="", **kwargs: Any): async with ClientSession() as web: res = await web.request(method, url, ssl=helpers.get_ssl_context(), **kwargs) text = await res.text() @@ -103,12 +109,13 @@ async def http_request(self, method="", url="", **kwargs): "body": text } - async def ping(self, **kwargs): + async def ping(self, **kwargs: Any): return "pong" - async def execute_in_tab(self, tab, run_async, code): + async def execute_in_tab(self, tab: str, run_async: bool, code: str): try: result = await inject_to_tab(tab, code, run_async) + assert result if "exceptionDetails" in result["result"]: return { "success": False, @@ -125,7 +132,7 @@ async def execute_in_tab(self, tab, run_async, code): "result": e } - async def inject_css_into_tab(self, tab, style): + async def inject_css_into_tab(self, tab: str, style: str): try: css_id = str(uuid.uuid4()) @@ -139,7 +146,7 @@ async def inject_css_into_tab(self, tab, style): }})() """, False) - if "exceptionDetails" in result["result"]: + if result and "exceptionDetails" in result["result"]: return { "success": False, "result": result["result"] @@ -155,7 +162,7 @@ async def inject_css_into_tab(self, tab, style): "result": e } - async def remove_css_from_tab(self, tab, css_id): + async def remove_css_from_tab(self, tab: str, css_id: str): try: result = await inject_to_tab(tab, f""" @@ -167,7 +174,7 @@ async def remove_css_from_tab(self, tab, css_id): }})() """, False) - if "exceptionDetails" in result["result"]: + if result and "exceptionDetails" in result["result"]: return { "success": False, "result": result @@ -182,10 +189,10 @@ async def remove_css_from_tab(self, tab, css_id): "result": e } - async def get_setting(self, key, default): + async def get_setting(self, key: str, default: Any): return self.context.settings.getSetting(key, default) - async def set_setting(self, key, value): + async def set_setting(self, key: str, value: Any): return self.context.settings.setSetting(key, value) async def allow_remote_debugging(self): @@ -210,17 +217,18 @@ async def filepicker_ls(self, if path == None: path = get_home_path() - path = Path(path).resolve() + path_obj = Path(path).resolve() - files, folders = [], [] + files: List[FilePickerObj] = [] + folders: List[FilePickerObj] = [] #Resolving all files/folders in the requested directory - for file in path.iterdir(): + for file in path_obj.iterdir(): if file.exists(): filest = file.stat() is_hidden = file.name.startswith('.') if ON_WINDOWS and not is_hidden: - is_hidden = bool(filest.st_file_attributes & FILE_ATTRIBUTE_HIDDEN) + is_hidden = bool(filest.st_file_attributes & FILE_ATTRIBUTE_HIDDEN) # type: ignore if include_folders and file.is_dir(): if (is_hidden and include_hidden) or not is_hidden: folders.append({"file": file, "filest": filest, "is_dir": True}) @@ -234,9 +242,9 @@ async def filepicker_ls(self, if filter_for is not None: try: if re.compile(filter_for): - files = filter(lambda file: re.search(filter_for, file.name) != None, files) + files = list(filter(lambda file: re.search(filter_for, file["file"].name) != None, files)) except re.error: - files = filter(lambda file: file.name.find(filter_for) != -1, files) + files = list(filter(lambda file: file["file"].name.find(filter_for) != -1, files)) # Ordering logic ord_arg = order_by.split("_") @@ -256,6 +264,9 @@ async def filepicker_ls(self, files.sort(key=lambda x: x['filest'].st_size, reverse = not rev) # Folders has no file size, order by name instead folders.sort(key=lambda x: x['file'].name.casefold()) + case _: + files.sort(key=lambda x: x['file'].name.casefold(), reverse = rev) + folders.sort(key=lambda x: x['file'].name.casefold(), reverse = rev) #Constructing the final file list, folders first all = [{ @@ -275,14 +286,14 @@ async def filepicker_ls(self, # Based on https://stackoverflow.com/a/46422554/13174603 - def start_rdt_proxy(self, ip, port): - async def pipe(reader, writer): + def start_rdt_proxy(self, ip: str, port: int): + async def pipe(reader: StreamReader, writer: StreamWriter): try: while not reader.at_eof(): writer.write(await reader.read(2048)) finally: writer.close() - async def handle_client(local_reader, local_writer): + async def handle_client(local_reader: StreamReader, local_writer: StreamWriter): try: remote_reader, remote_writer = await open_connection( ip, port) @@ -298,7 +309,8 @@ async def handle_client(local_reader, local_writer): def stop_rdt_proxy(self): if self.rdt_proxy_server: self.rdt_proxy_server.close() - self.rdt_proxy_task.cancel() + if self.rdt_proxy_task: + self.rdt_proxy_task.cancel() async def _enable_rdt(self): # TODO un-hardcode port @@ -348,11 +360,11 @@ async def disable_rdt(self): await tab.evaluate_js("location.reload();", False, True, False) self.logger.info("React DevTools disabled") - async def get_user_info(self) -> dict: + async def get_user_info(self) -> Dict[str, str]: return { "username": get_username(), "path": get_home_path() } - async def get_tab_id(self, name): + async def get_tab_id(self, name: str): return (await get_tab(name)).id