Skip to content

Commit

Permalink
Central function to canonicalize state dicts (#40)
Browse files Browse the repository at this point in the history
* Central function to canonicalize state dicts

* fixes

* Make public

* Return a canonicalized state dict
  • Loading branch information
RunDevelopment authored Nov 22, 2023
1 parent 6e29e02 commit 49f4494
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 43 deletions.
10 changes: 1 addition & 9 deletions scripts/dump_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,7 @@


def load_state(file: str) -> State:
state_dict = ModelLoader().load_state_dict_from_file(file)

unwrap_keys = ["params_ema", "params-ema", "params", "model", "net"]
for unwrap_key in unwrap_keys:
if unwrap_key in state_dict and isinstance(state_dict[unwrap_key], dict):
state_dict = state_dict[unwrap_key]
break

return state_dict
return ModelLoader().load_state_dict_from_file(file)


def indent(lines: list[str], indentation: str = " "):
Expand Down
27 changes: 27 additions & 0 deletions src/spandrel/__helpers/canonicalize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from .model_descriptor import StateDict


def canonicalize_state_dict(state_dict: StateDict) -> StateDict:
    """
    Canonicalize a state dict.

    This function is used to canonicalize a state dict, so that it can be
    used for architecture detection and loading. Two normalizations are
    applied, in this order:

    1. If the given dict contains one of the known wrapper keys and the value
       stored under it is itself a dict, that inner dict is used instead.
       Keys are tried in a fixed priority order, the first match wins, and
       only a single level is unwrapped.
    2. For each known prefix (``"module."``, then ``"netG."``), the prefix is
       stripped from every key, but only if *all* keys start with it.

    The input dict is never mutated. The returned value is either the given
    dict itself, the nested dict found under a wrapper key, or a new dict
    with the prefixes removed.

    This function is not intended to be used in production code.
    """

    # the real state dict might be inside a dict with a known key
    unwrap_keys = (
        "state_dict",
        "params_ema",
        "params-ema",
        "params",
        "model",
        "net",
    )
    for unwrap_key in unwrap_keys:
        inner = state_dict.get(unwrap_key)
        if isinstance(inner, dict):
            state_dict = inner
            break

    # remove known common prefixes
    for prefix in ("module.", "netG."):
        # The emptiness check keeps `all()` from being vacuously true. The
        # isinstance check keeps dicts with non-string keys from raising
        # AttributeError; such dicts are returned unchanged instead.
        if state_dict and all(
            isinstance(key, str) and key.startswith(prefix) for key in state_dict
        ):
            state_dict = {
                key[len(prefix) :]: value for key, value in state_dict.items()
            }

    return state_dict
25 changes: 9 additions & 16 deletions src/spandrel/__helpers/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
from safetensors.torch import load_file

from .canonicalize import canonicalize_state_dict
from .main_registry import MAIN_REGISTRY
from .model_descriptor import ModelDescriptor, StateDict
from .registry import ArchRegistry
Expand Down Expand Up @@ -52,19 +53,22 @@ def load_state_dict_from_file(self, path: str | Path) -> StateDict:

extension = os.path.splitext(path)[1].lower()

state_dict: StateDict
if extension == ".pt":
return self._load_torchscript(path)
state_dict = self._load_torchscript(path)
elif extension == ".pth":
return self._load_pth(path)
state_dict = self._load_pth(path)
elif extension == ".ckpt":
return self._load_ckpt(path)
state_dict = self._load_ckpt(path)
elif extension == ".safetensors":
return self._load_safetensors(path)
state_dict = self._load_safetensors(path)
else:
raise ValueError(
f"Unsupported model file extension {extension}. Please try a supported model type."
)

return canonicalize_state_dict(state_dict)

def load_from_state_dict(self, state_dict: StateDict) -> ModelDescriptor:
"""
Load a model from the given state dict.
Expand All @@ -90,19 +94,8 @@ def _load_safetensors(self, path: str | Path) -> StateDict:
return load_file(path, device=str(self.device))

def _load_ckpt(self, path: str | Path) -> StateDict:
checkpoint = torch.load(
return torch.load(
path,
map_location=self.device,
pickle_module=RestrictedUnpickle, # type: ignore
)
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
state_dict = {}
for i, j in checkpoint.items():
if "netG." in i:
key = i.replace("netG.", "")
state_dict[key] = j
elif "module." in i:
key = i.replace("module.", "")
state_dict[key] = j
return state_dict
8 changes: 0 additions & 8 deletions src/spandrel/__helpers/main_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,6 @@ def _detect(state_dict: StateDict) -> bool:
"decoders.0.0.attgamma",
"ending.weight",
)(state)
# some KBNet_s models are prefixed with "module." for some reason
or _has_keys(
"module.intro.weight",
"module.encoders.0.0.attgamma",
"module.middle_blks.0.w",
"module.decoders.0.0.attgamma",
"module.ending.weight",
)(state)
# KBNet_l
or _has_keys(
"patch_embed.proj.weight",
Expand Down
9 changes: 4 additions & 5 deletions src/spandrel/__helpers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import dataclass
from typing import Callable, Literal

from .canonicalize import canonicalize_state_dict
from .model_descriptor import ModelDescriptor, StateDict


Expand Down Expand Up @@ -146,14 +147,12 @@ def load(self, state_dict: StateDict) -> ModelDescriptor:
"""
Detects the architecture of the given state dict and loads it.
This will canonicalize the state dict if it isn't already.
Throws an `UnsupportedModelError` if the model architecture is not supported.
"""

unwrap_keys = ["params_ema", "params-ema", "params", "model", "net"]
for unwrap_key in unwrap_keys:
if unwrap_key in state_dict and isinstance(state_dict[unwrap_key], dict):
state_dict = state_dict[unwrap_key]
break
state_dict = canonicalize_state_dict(state_dict)

for arch in self._ordered:
if arch.detect(state_dict):
Expand Down
4 changes: 3 additions & 1 deletion src/spandrel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__version__ = "0.0.3"

from .__helpers.canonicalize import canonicalize_state_dict
from .__helpers.loader import ModelLoader
from .__helpers.main_registry import MAIN_REGISTRY
from .__helpers.model_descriptor import (
Expand All @@ -17,13 +18,14 @@
__all__ = [
"ArchRegistry",
"ArchSupport",
"RestorationModelDescriptor",
"canonicalize_state_dict",
"FaceSRModelDescriptor",
"InpaintModelDescriptor",
"MAIN_REGISTRY",
"ModelBase",
"ModelDescriptor",
"ModelLoader",
"RestorationModelDescriptor",
"SizeRequirements",
"SRModelDescriptor",
"StateDict",
Expand Down
4 changes: 0 additions & 4 deletions src/spandrel/architectures/KBNet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,6 @@ def load_l(state_dict: StateDict) -> RestorationModelDescriptor[KBNet_l]:


def load_s(state_dict: StateDict) -> RestorationModelDescriptor[KBNet_s]:
# remove module. prefix
if "module.intro.weight" in state_dict:
state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}

img_channel = 3
width = 64
middle_blk_num = 12
Expand Down

0 comments on commit 49f4494

Please sign in to comment.