From 0ab72c96aa89ede67106fd0005aa72fac7d76bef Mon Sep 17 00:00:00 2001 From: Joey Ballentine Date: Sat, 18 Nov 2023 00:35:11 -0500 Subject: [PATCH 1/3] enable annotation rules --- pyproject.toml | 22 ++++++++++--------- src/spandrel/__helpers/unpickler.py | 2 +- .../architectures/Compact/__init__.py | 8 ++++--- src/spandrel/architectures/ESRGAN/__init__.py | 8 +++---- src/spandrel/architectures/SPSR/__init__.py | 4 ++-- 5 files changed, 24 insertions(+), 20 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 83db3748..7f4e8c15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dynamic = ["version"] [project.optional-dependencies] build = ["setuptools>=46.4.0", "wheel", "build", "twine"] -lint = ["ruff==0.1.4"] +lint = ["ruff==0.1.6"] typecheck = ["pyright==1.1.335"] test = ["pytest==7.4.0", "syrupy==4.6.0", "opencv-python==4.8.1.78"] @@ -53,13 +53,15 @@ src = ["src"] # extend-exclude = ["src/architectures/**"] extend-select = [ - "UP", # pyupgrade - "E", # pycodestyle - "W", # pycodestyle - # "F", # pyflakes - "I", # isort - "FA", # flake8-future-annotations - "N", # pep8-naming + "UP", # pyupgrade + "E", # pycodestyle + "W", # pycodestyle + "F", # pyflakes + "I", # isort + "FA", # flake8-future-annotations + "N", # pep8-naming + "ANN001", + "ANN002", ] ignore = [ "E501", # Line too long @@ -68,8 +70,8 @@ ignore = [ ] [tool.ruff.lint.per-file-ignores] -"**/arch/**/*" = ["N"] -"**/__arch_helpers/**/*" = ["N"] +"**/arch/**/*" = ["N", "ANN"] +"**/__arch_helpers/**/*" = ["N", "ANN"] "**/tests/**/*" = ["N802"] [tool.pytest.ini_options] diff --git a/src/spandrel/__helpers/unpickler.py b/src/spandrel/__helpers/unpickler.py index 279d82e2..90a37668 100644 --- a/src/spandrel/__helpers/unpickler.py +++ b/src/spandrel/__helpers/unpickler.py @@ -16,7 +16,7 @@ class RestrictedUnpickler(pickle.Unpickler): - def find_class(self, module, name): + def find_class(self, module: str, name: str): # Only allow required classes to load state dict if (module, name) not in safe_list: raise pickle.UnpicklingError(f"Global '{module}.{name}' is forbidden") diff --git a/src/spandrel/architectures/Compact/__init__.py b/src/spandrel/architectures/Compact/__init__.py index 53599464..c941cad5 100644 --- a/src/spandrel/architectures/Compact/__init__.py +++ b/src/spandrel/architectures/Compact/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import math from ...__helpers.model_descriptor import SRModelDescriptor, StateDict @@ -8,15 +10,15 @@ def _get_num_conv(highest_num: int) -> int: return (highest_num - 2) // 2 -def _get_num_feats(state, weight_keys) -> int: +def _get_num_feats(state: StateDict, weight_keys: list[str]) -> int: return state[weight_keys[0]].shape[0] -def _get_in_nc(state, weight_keys) -> int: +def _get_in_nc(state: StateDict, weight_keys: list[str]) -> int: return state[weight_keys[0]].shape[1] -def _get_scale(pixelshuffle_shape, out_nc) -> int: +def _get_scale(pixelshuffle_shape: int, out_nc: int) -> int: scale = math.sqrt(pixelshuffle_shape / out_nc) if scale - int(scale) > 0: print( diff --git a/src/spandrel/architectures/ESRGAN/__init__.py b/src/spandrel/architectures/ESRGAN/__init__.py index a4a443a0..8e5a1813 100644 --- a/src/spandrel/architectures/ESRGAN/__init__.py +++ b/src/spandrel/architectures/ESRGAN/__init__.py @@ -7,7 +7,7 @@ from .arch.RRDB import RRDBNet -def _new_to_old_arch(state, state_map, num_blocks): +def _new_to_old_arch(state: StateDict, state_map: dict, num_blocks: int): """Convert a new-arch model state dictionary to an old-arch dictionary.""" if "params_ema" in state: state = state["params_ema"] @@ -56,7 +56,7 @@ def _new_to_old_arch(state, state_map, num_blocks): old_state[f"model.{max_upconv + 4}.bias"] = state[key] # Sort by first numeric value of each layer - def compare(item1, item2): + def compare(item1: str, item2: str): parts1 = item1.split(".") parts2 = item2.split(".") int1 = int(parts1[1]) @@ -71,7 +71,7 @@ def compare(item1, item2): return out_dict -def _get_scale(state, min_part: int = 6) -> int: +def _get_scale(state: StateDict, min_part: int = 6) -> int: n = 0 for part in list(state): parts = part.split(".")[1:] @@ -82,7 +82,7 @@ def _get_scale(state, min_part: int = 6) -> int: return 2**n -def _get_num_blocks(state, state_map) -> int: +def _get_num_blocks(state: StateDict, state_map: dict) -> int: nbs = [] state_keys = state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + ( r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)", diff --git a/src/spandrel/architectures/SPSR/__init__.py b/src/spandrel/architectures/SPSR/__init__.py index 0978f0a5..d5bd9140 100644 --- a/src/spandrel/architectures/SPSR/__init__.py +++ b/src/spandrel/architectures/SPSR/__init__.py @@ -2,7 +2,7 @@ from .arch.SPSR import SPSRNet as SPSR -def get_scale(state, min_part: int = 4) -> int: +def get_scale(state: StateDict, min_part: int = 4) -> int: n = 0 for part in list(state): parts = part.split(".") @@ -13,7 +13,7 @@ def get_scale(state, min_part: int = 4) -> int: return 2**n -def get_num_blocks(state) -> int: +def get_num_blocks(state: StateDict) -> int: nb = 0 for part in list(state): parts = part.split(".") From 3cedfc910b4759f3f760a0775b44219aefcc1647 Mon Sep 17 00:00:00 2001 From: Joey Ballentine Date: Sat, 18 Nov 2023 00:37:53 -0500 Subject: [PATCH 2/3] ignore annotations rules for tests --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7f4e8c15..ce205430 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,7 @@ ignore = [ [tool.ruff.lint.per-file-ignores] "**/arch/**/*" = ["N", "ANN"] "**/__arch_helpers/**/*" = ["N", "ANN"] -"**/tests/**/*" = ["N802"] +"**/tests/**/*" = ["N802", "ANN"] [tool.pytest.ini_options] filterwarnings = ["ignore::DeprecationWarning", "ignore::UserWarning"] From 9a2a4e36c870674821335647251544304d5394eb Mon Sep 17 00:00:00 2001 From: Joey Ballentine Date: Sat, 18 Nov 2023 00:40:55 -0500 Subject: [PATCH 3/3] add missing annotation --- scripts/dump_state_dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/dump_state_dict.py b/scripts/dump_state_dict.py index 31a06fa0..cc6d1aee 100644 --- a/scripts/dump_state_dict.py +++ b/scripts/dump_state_dict.py @@ -69,7 +69,7 @@ def load_state(file: str) -> State: return state_dict -def indent(lines: list[str], indentation=" "): +def indent(lines: list[str], indentation: str = " "): def do(line: str) -> str: return "\n".join(indentation + s for s in line.splitlines())