diff --git a/scripts/dump_state_dict.py b/scripts/dump_state_dict.py index 7fab5e6e..728b843e 100644 --- a/scripts/dump_state_dict.py +++ b/scripts/dump_state_dict.py @@ -57,15 +57,7 @@ def load_state(file: str) -> State: - state_dict = ModelLoader().load_state_dict_from_file(file) - - unwrap_keys = ["params_ema", "params-ema", "params", "model", "net"] - for unwrap_key in unwrap_keys: - if unwrap_key in state_dict and isinstance(state_dict[unwrap_key], dict): - state_dict = state_dict[unwrap_key] - break - - return state_dict + return ModelLoader().load_state_dict_from_file(file) def indent(lines: list[str], indentation: str = " "): diff --git a/src/spandrel/__helpers/canonicalize.py b/src/spandrel/__helpers/canonicalize.py new file mode 100644 index 00000000..f9523e88 --- /dev/null +++ b/src/spandrel/__helpers/canonicalize.py @@ -0,0 +1,27 @@ +from .model_descriptor import StateDict + + +def canonicalize_state_dict(state_dict: StateDict) -> StateDict: + """ + Canonicalize a state dict. + + This function is used to canonicalize a state dict, so that it can be + used for architecture detection and loading. + + This function is not intended to be used in production code. + """ + + # the real state dict might be inside a dict with a known key + unwrap_keys = ["state_dict", "params_ema", "params-ema", "params", "model", "net"] + for unwrap_key in unwrap_keys: + if unwrap_key in state_dict and isinstance(state_dict[unwrap_key], dict): + state_dict = state_dict[unwrap_key] + break + + # remove known common prefixes + if len(state_dict) > 0: + for prefix in ["module.", "netG."]: + if all(i.startswith(prefix) for i in state_dict.keys()): + state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()} + + return state_dict diff --git a/src/spandrel/__helpers/loader.py b/src/spandrel/__helpers/loader.py index 97a8bf25..67c0f010 100644 --- a/src/spandrel/__helpers/loader.py +++ b/src/spandrel/__helpers/loader.py @@ -6,6 +6,7 @@ import torch from safetensors.torch import load_file +from .canonicalize import canonicalize_state_dict from .main_registry import MAIN_REGISTRY from .model_descriptor import ModelDescriptor, StateDict from .registry import ArchRegistry @@ -52,19 +53,22 @@ def load_state_dict_from_file(self, path: str | Path) -> StateDict: extension = os.path.splitext(path)[1].lower() + state_dict: StateDict if extension == ".pt": - return self._load_torchscript(path) + state_dict = self._load_torchscript(path) elif extension == ".pth": - return self._load_pth(path) + state_dict = self._load_pth(path) elif extension == ".ckpt": - return self._load_ckpt(path) + state_dict = self._load_ckpt(path) elif extension == ".safetensors": - return self._load_safetensors(path) + state_dict = self._load_safetensors(path) else: raise ValueError( f"Unsupported model file extension {extension}. Please try a supported model type." ) + return canonicalize_state_dict(state_dict) + def load_from_state_dict(self, state_dict: StateDict) -> ModelDescriptor: """ Load a model from the given state dict. @@ -90,19 +94,8 @@ def _load_safetensors(self, path: str | Path) -> StateDict: return load_file(path, device=str(self.device)) def _load_ckpt(self, path: str | Path) -> StateDict: - checkpoint = torch.load( + return torch.load( path, map_location=self.device, pickle_module=RestrictedUnpickle, # type: ignore ) - if "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - state_dict = {} - for i, j in checkpoint.items(): - if "netG." in i: - key = i.replace("netG.", "") - state_dict[key] = j - elif "module." in i: - key = i.replace("module.", "") - state_dict[key] = j - return state_dict diff --git a/src/spandrel/__helpers/main_registry.py b/src/spandrel/__helpers/main_registry.py index 59f48516..448da938 100644 --- a/src/spandrel/__helpers/main_registry.py +++ b/src/spandrel/__helpers/main_registry.py @@ -181,14 +181,6 @@ def _detect(state_dict: StateDict) -> bool: "decoders.0.0.attgamma", "ending.weight", )(state) - # some KBNet_s models are prefixed with "module." for some reason - or _has_keys( - "module.intro.weight", - "module.encoders.0.0.attgamma", - "module.middle_blks.0.w", - "module.decoders.0.0.attgamma", - "module.ending.weight", - )(state) # KBNet_l or _has_keys( "patch_embed.proj.weight", diff --git a/src/spandrel/__helpers/registry.py b/src/spandrel/__helpers/registry.py index 3f095d62..0db55a75 100644 --- a/src/spandrel/__helpers/registry.py +++ b/src/spandrel/__helpers/registry.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from typing import Callable, Literal +from .canonicalize import canonicalize_state_dict from .model_descriptor import ModelDescriptor, StateDict @@ -146,14 +147,12 @@ def load(self, state_dict: StateDict) -> ModelDescriptor: """ Detects the architecture of the given state dict and loads it. + This will canonicalize the state dict if it isn't already. + Throws an `UnsupportedModelError` if the model architecture is not supported. """ - unwrap_keys = ["params_ema", "params-ema", "params", "model", "net"] - for unwrap_key in unwrap_keys: - if unwrap_key in state_dict and isinstance(state_dict[unwrap_key], dict): - state_dict = state_dict[unwrap_key] - break + state_dict = canonicalize_state_dict(state_dict) for arch in self._ordered: if arch.detect(state_dict): diff --git a/src/spandrel/__init__.py b/src/spandrel/__init__.py index 1d52f059..909abb60 100644 --- a/src/spandrel/__init__.py +++ b/src/spandrel/__init__.py @@ -1,5 +1,6 @@ __version__ = "0.0.3" +from .__helpers.canonicalize import canonicalize_state_dict from .__helpers.loader import ModelLoader from .__helpers.main_registry import MAIN_REGISTRY from .__helpers.model_descriptor import ( @@ -17,13 +18,14 @@ __all__ = [ "ArchRegistry", "ArchSupport", - "RestorationModelDescriptor", + "canonicalize_state_dict", "FaceSRModelDescriptor", "InpaintModelDescriptor", "MAIN_REGISTRY", "ModelBase", "ModelDescriptor", "ModelLoader", + "RestorationModelDescriptor", "SizeRequirements", "SRModelDescriptor", "StateDict", diff --git a/src/spandrel/architectures/KBNet/__init__.py b/src/spandrel/architectures/KBNet/__init__.py index d9e7b2a9..11f9ff06 100644 --- a/src/spandrel/architectures/KBNet/__init__.py +++ b/src/spandrel/architectures/KBNet/__init__.py @@ -70,10 +70,6 @@ def load_l(state_dict: StateDict) -> RestorationModelDescriptor[KBNet_l]: def load_s(state_dict: StateDict) -> RestorationModelDescriptor[KBNet_s]: - # remove module. prefix - if "module.intro.weight" in state_dict: - state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()} - img_channel = 3 width = 64 middle_blk_num = 12