Skip to content

Commit

Permalink
Merge branch 'main' into grl-persistent-params
Browse files Browse the repository at this point in the history
  • Loading branch information
RunDevelopment committed Nov 28, 2023
2 parents a836d84 + 175f51b commit c9d8ad0
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 34 deletions.
2 changes: 1 addition & 1 deletion src/spandrel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.0.3"
__version__ = "0.0.4"

from .__helpers.canonicalize import canonicalize_state_dict
from .__helpers.loader import ModelLoader
Expand Down
6 changes: 6 additions & 0 deletions src/spandrel/architectures/OmniSR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@


def load(state_dict: StateDict) -> SRModelDescriptor[OmniSR]:
# Remove junk from the state dict
state_dict_keys = set(state_dict.keys())
for key in state_dict_keys:
if key.endswith(("total_ops", "total_params")):
del state_dict[key]

num_in_ch = 3
num_out_ch = 3
num_feat = 64
Expand Down
114 changes: 81 additions & 33 deletions tests/__snapshots__/test_OmniSR.ambr
Original file line number Diff line number Diff line change
@@ -1,33 +1,81 @@
# serializer version: 1
# name: test_OmniSR_community1
SRModelDescriptor(
architecture='OmniSR',
input_channels=3,
output_channels=3,
scale=2,
size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
'64nf',
'w8',
'5nr',
]),
)
# ---
# name: test_OmniSR_community2
SRModelDescriptor(
architecture='OmniSR',
input_channels=3,
output_channels=3,
scale=4,
size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
'64nf',
'w8',
'5nr',
]),
)
# ---
# serializer version: 1
# name: test_OmniSR_community1
SRModelDescriptor(
architecture='OmniSR',
input_channels=3,
output_channels=3,
scale=2,
size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
'64nf',
'w8',
'5nr',
]),
)
# ---
# name: test_OmniSR_community2
SRModelDescriptor(
architecture='OmniSR',
input_channels=3,
output_channels=3,
scale=4,
size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
'64nf',
'w8',
'5nr',
]),
)
# ---
# name: test_OmniSR_official_x2
SRModelDescriptor(
architecture='OmniSR',
input_channels=3,
output_channels=3,
scale=2,
size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
'64nf',
'w8',
'5nr',
]),
)
# ---
# name: test_OmniSR_official_x3
SRModelDescriptor(
architecture='OmniSR',
input_channels=3,
output_channels=3,
scale=3,
size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
'64nf',
'w8',
'5nr',
]),
)
# ---
# name: test_OmniSR_official_x4
SRModelDescriptor(
architecture='OmniSR',
input_channels=3,
output_channels=3,
scale=4,
size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
'64nf',
'w8',
'5nr',
]),
)
# ---
1 change: 1 addition & 0 deletions tests/test_GRLIR.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def test_GRLIR_load():
# upscale is only defined if we have an upsampler
and (not a.upsampler or a.upscale == b.upscale)
and a.input_resolution == b.input_resolution
# those aren't supported right now
# and a.pad_size == b.pad_size
# and a.window_size == b.window_size
# and a.stripe_size == b.stripe_size
Expand Down
35 changes: 35 additions & 0 deletions tests/test_OmniSR.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

from spandrel.architectures.OmniSR import OmniSR, load

from .util import (
Expand Down Expand Up @@ -32,6 +34,39 @@ def test_OmniSR_load():
)


def test_OmniSR_official_x4(snapshot):
file = ModelFile.from_url_zip(
"https://drive.google.com/file/d/17rJXJHBYt4Su8cMDMh-NOWMBdE6ki5em/view",
rel_model_path=Path("OmniSR_X4_DF2K/checkpoints/epoch994_OmniSR.pth"),
name="epoch994_OmniSR_x4.pth",
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, OmniSR)


def test_OmniSR_official_x3(snapshot):
file = ModelFile.from_url_zip(
"https://drive.google.com/file/d/1Rwg6o-RGC-TEiyVSVT9FS1iHjx5n948h/view",
rel_model_path=Path("OmniSR_X3_DIV2K/checkpoints/epoch919_OmniSR.pth"),
name="epoch919_OmniSR_x3.pth",
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, OmniSR)


def test_OmniSR_official_x2(snapshot):
file = ModelFile.from_url_zip(
"https://drive.google.com/file/d/18lSvJq9CGCwDomkas2gh8K6UOq8qRLIw/view",
rel_model_path=Path("OmniSR_X2_DIV2K/checkpoints/epoch896_OmniSR.pth"),
name="epoch896_OmniSR_x2.pth",
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, OmniSR)


def test_OmniSR_community1(snapshot):
file = ModelFile.from_url(
"https://github.com/Phhofm/models/raw/main/2xHFA2kAVCOmniSR/2xHFA2kAVCOmniSR.pth"
Expand Down
32 changes: 32 additions & 0 deletions tests/util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import os
import re
import sys
import tempfile
import zipfile
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
Expand Down Expand Up @@ -48,6 +51,24 @@ def download_model(url: str, name: str | None = None) -> str:
return path


def extract_zip(path: str, rel_model_path: Path | str, name: str):
if not zipfile.is_zipfile(path):
print(f"Skipping {path} because it is not a zip file.")
return

if (MODEL_DIR / name).exists():
print(f"Skipping {path} because {name} already exists.")
return

with zipfile.ZipFile(path, "r") as zip_ref:
with tempfile.TemporaryDirectory() as tmpdirname:
zip_ref.extractall(tmpdirname)
model_path = Path(tmpdirname) / rel_model_path
assert model_path.exists(), f"Expected {model_path} to exist."
model_path.rename(MODEL_DIR / name)
return model_path


@dataclass
class ModelFile:
name: str
Expand All @@ -73,6 +94,17 @@ def from_url(url: str, name: str | None = None):
name = name or get_url_file_name(url)
return ModelFile(name).download(url)

@staticmethod
def from_url_zip(url: str, rel_model_path: Path | str, name: str | None = None):
name = os.path.basename(rel_model_path) if name is None else name
if (MODEL_DIR / name).exists():
return ModelFile(name)
path = download_model(url, "temp.zip")
print(f"Extracting {path}...")
extract_zip(path, rel_model_path or name, name)
os.remove(path)
return ModelFile(name or get_url_file_name(url))


disallowed_props = props("model", "state_dict")

Expand Down

0 comments on commit c9d8ad0

Please sign in to comment.