diff --git a/src/spandrel/__init__.py b/src/spandrel/__init__.py index 909abb60..f1a4d943 100644 --- a/src/spandrel/__init__.py +++ b/src/spandrel/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.0.3" +__version__ = "0.0.4" from .__helpers.canonicalize import canonicalize_state_dict from .__helpers.loader import ModelLoader diff --git a/src/spandrel/architectures/OmniSR/__init__.py b/src/spandrel/architectures/OmniSR/__init__.py index e181204b..15720ba5 100644 --- a/src/spandrel/architectures/OmniSR/__init__.py +++ b/src/spandrel/architectures/OmniSR/__init__.py @@ -9,6 +9,12 @@ def load(state_dict: StateDict) -> SRModelDescriptor[OmniSR]: + # Remove junk from the state dict + state_dict_keys = set(state_dict.keys()) + for key in state_dict_keys: + if key.endswith(("total_ops", "total_params")): + del state_dict[key] + num_in_ch = 3 num_out_ch = 3 num_feat = 64 diff --git a/tests/__snapshots__/test_OmniSR.ambr b/tests/__snapshots__/test_OmniSR.ambr index 6a873b23..35f7975c 100644 --- a/tests/__snapshots__/test_OmniSR.ambr +++ b/tests/__snapshots__/test_OmniSR.ambr @@ -1,33 +1,81 @@ -# serializer version: 1 -# name: test_OmniSR_community1 - SRModelDescriptor( - architecture='OmniSR', - input_channels=3, - output_channels=3, - scale=2, - size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False), - supports_bfloat16=True, - supports_half=True, - tags=list([ - '64nf', - 'w8', - '5nr', - ]), - ) -# --- -# name: test_OmniSR_community2 - SRModelDescriptor( - architecture='OmniSR', - input_channels=3, - output_channels=3, - scale=4, - size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False), - supports_bfloat16=True, - supports_half=True, - tags=list([ - '64nf', - 'w8', - '5nr', - ]), - ) -# --- +# serializer version: 1 +# name: test_OmniSR_community1 + SRModelDescriptor( + architecture='OmniSR', + input_channels=3, + output_channels=3, + scale=2, + size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False), + supports_bfloat16=True, + supports_half=True, + tags=list([ + '64nf', + 'w8', + '5nr', + ]), + ) +# --- +# name: test_OmniSR_community2 + SRModelDescriptor( + architecture='OmniSR', + input_channels=3, + output_channels=3, + scale=4, + size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False), + supports_bfloat16=True, + supports_half=True, + tags=list([ + '64nf', + 'w8', + '5nr', + ]), + ) +# --- +# name: test_OmniSR_official_x2 + SRModelDescriptor( + architecture='OmniSR', + input_channels=3, + output_channels=3, + scale=2, + size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False), + supports_bfloat16=True, + supports_half=True, + tags=list([ + '64nf', + 'w8', + '5nr', + ]), + ) +# --- +# name: test_OmniSR_official_x3 + SRModelDescriptor( + architecture='OmniSR', + input_channels=3, + output_channels=3, + scale=3, + size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False), + supports_bfloat16=True, + supports_half=True, + tags=list([ + '64nf', + 'w8', + '5nr', + ]), + ) +# --- +# name: test_OmniSR_official_x4 + SRModelDescriptor( + architecture='OmniSR', + input_channels=3, + output_channels=3, + scale=4, + size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False), + supports_bfloat16=True, + supports_half=True, + tags=list([ + '64nf', + 'w8', + '5nr', + ]), + ) +# --- diff --git a/tests/test_GRLIR.py b/tests/test_GRLIR.py index a10aaf15..8da09538 100644 --- a/tests/test_GRLIR.py +++ b/tests/test_GRLIR.py @@ -115,6 +115,7 @@ def test_GRLIR_load(): # upscale is only defined if we have an upsampler and (not a.upsampler or a.upscale == b.upscale) and a.input_resolution == b.input_resolution + # those aren't supported right now # and a.pad_size == b.pad_size # and a.window_size == b.window_size # and a.stripe_size == b.stripe_size diff --git a/tests/test_OmniSR.py b/tests/test_OmniSR.py index ff94a00e..58203335 100644 --- a/tests/test_OmniSR.py +++ b/tests/test_OmniSR.py @@ -1,3 +1,5 @@ +from pathlib import Path + from spandrel.architectures.OmniSR import OmniSR, load from .util import ( @@ -32,6 +34,39 @@ def test_OmniSR_load(): ) +def test_OmniSR_official_x4(snapshot): + file = ModelFile.from_url_zip( + "https://drive.google.com/file/d/17rJXJHBYt4Su8cMDMh-NOWMBdE6ki5em/view", + rel_model_path=Path("OmniSR_X4_DF2K/checkpoints/epoch994_OmniSR.pth"), + name="epoch994_OmniSR_x4.pth", + ) + model = file.load_model() + assert model == snapshot(exclude=disallowed_props) + assert isinstance(model.model, OmniSR) + + +def test_OmniSR_official_x3(snapshot): + file = ModelFile.from_url_zip( + "https://drive.google.com/file/d/1Rwg6o-RGC-TEiyVSVT9FS1iHjx5n948h/view", + rel_model_path=Path("OmniSR_X3_DIV2K/checkpoints/epoch919_OmniSR.pth"), + name="epoch919_OmniSR_x3.pth", + ) + model = file.load_model() + assert model == snapshot(exclude=disallowed_props) + assert isinstance(model.model, OmniSR) + + +def test_OmniSR_official_x2(snapshot): + file = ModelFile.from_url_zip( + "https://drive.google.com/file/d/18lSvJq9CGCwDomkas2gh8K6UOq8qRLIw/view", + rel_model_path=Path("OmniSR_X2_DIV2K/checkpoints/epoch896_OmniSR.pth"), + name="epoch896_OmniSR_x2.pth", + ) + model = file.load_model() + assert model == snapshot(exclude=disallowed_props) + assert isinstance(model.model, OmniSR) + + def test_OmniSR_community1(snapshot): file = ModelFile.from_url( "https://github.com/Phhofm/models/raw/main/2xHFA2kAVCOmniSR/2xHFA2kAVCOmniSR.pth" diff --git a/tests/util.py b/tests/util.py index 4450c8da..374a42d3 100644 --- a/tests/util.py +++ b/tests/util.py @@ -1,7 +1,10 @@ from __future__ import annotations +import os import re import sys +import tempfile +import zipfile from contextlib import contextmanager from dataclasses import dataclass from enum import Enum @@ -48,6 +51,24 @@ def download_model(url: str, name: str | None = None) -> str: return path +def extract_zip(path: str, rel_model_path: Path | str, name: str): + if not zipfile.is_zipfile(path): + print(f"Skipping {path} because it is not a zip file.") + return + + if (MODEL_DIR / name).exists(): + print(f"Skipping {path} because {name} already exists.") + return + + with zipfile.ZipFile(path, "r") as zip_ref: + with tempfile.TemporaryDirectory() as tmpdirname: + zip_ref.extractall(tmpdirname) + model_path = Path(tmpdirname) / rel_model_path + assert model_path.exists(), f"Expected {model_path} to exist." + model_path.rename(MODEL_DIR / name) + return model_path + + @dataclass class ModelFile: name: str @@ -73,6 +94,17 @@ def from_url(url: str, name: str | None = None): name = name or get_url_file_name(url) return ModelFile(name).download(url) + @staticmethod + def from_url_zip(url: str, rel_model_path: Path | str, name: str | None = None): + name = os.path.basename(rel_model_path) if name is None else name + if (MODEL_DIR / name).exists(): + return ModelFile(name) + path = download_model(url, "temp.zip") + print(f"Extracting {path}...") + extract_zip(path, rel_model_path or name, name) + os.remove(path) + return ModelFile(name or get_url_file_name(url)) + disallowed_props = props("model", "state_dict")