From d8c22562296ca8b20802b0c1da73a3e004958cbc Mon Sep 17 00:00:00 2001
From: Michael Schmidt
Date: Mon, 8 Jul 2024 17:38:44 +0200
Subject: [PATCH 1/4] Define the public API by what's documented

---
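Note on the mechanism used throughout this patch: a module-level `__all__` list makes the documented names the only ones that `from ... import *` re-exports (and the ones most API-doc tooling picks up). A minimal sketch of the intended effect, using the SwinIR module touched below (assumes spandrel with this patch applied is installed):

    # After this patch, the star import only pulls in the documented names:
    from spandrel.architectures.SwinIR import *

    arch = SwinIRArch()  # exported via __all__ = ["SwinIRArch", "SwinIR"]
    # Module-level helpers not listed in __all__ are no longer
    # re-exported by the star import.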
 .../architectures/ATD/__arch/__init__.py           |   0
 .../spandrel/architectures/ATD/__init__.py         |   3 +
 .../architectures/CRAFT/__arch/__init__.py         |   0
 .../spandrel/architectures/CRAFT/__init__.py       |   3 +
 .../architectures/Compact/__arch/__init__.py       |   0
 .../architectures/DAT/__arch/__init__.py           |   0
 .../spandrel/architectures/DAT/__init__.py         |   3 +
 .../architectures/DCTLSA/__arch/__init__.py        |   0
 .../spandrel/architectures/DCTLSA/__init__.py      |   3 +
 .../architectures/DITN/__arch/__init__.py          |   0
 .../spandrel/architectures/DITN/__init__.py        |   3 +
 .../architectures/DRCT/__arch/__init__.py          |   0
 .../spandrel/architectures/DRCT/__init__.py        |   3 +
 .../architectures/DRUNet/__arch/__init__.py        |   0
 .../spandrel/architectures/DRUNet/__init__.py      |   3 +
 .../architectures/DnCNN/__arch/__init__.py         |   0
 .../spandrel/architectures/DnCNN/__init__.py       |   3 +
 .../architectures/ESRGAN/__arch/__init__.py        |   0
 .../architectures/FBCNN/__arch/__init__.py         |   0
 .../spandrel/architectures/FBCNN/__init__.py       |   3 +
 .../FFTformer/__arch/__init__.py                   |   0
 .../architectures/FFTformer/__init__.py            |   3 +
 .../architectures/GFPGAN/__arch/__init__.py        |   0
 .../architectures/GRL/__arch/__init__.py           |   0
 .../spandrel/architectures/GRL/__init__.py         |   3 +
 .../architectures/HAT/__arch/__init__.py           |   0
 .../spandrel/architectures/HAT/__init__.py         |   3 +
 .../HVICIDNet/__arch/__init__.py                   |   0
 .../architectures/HVICIDNet/__init__.py            |   3 +
 .../architectures/IPT/__arch/__init__.py           |   0
 .../spandrel/architectures/IPT/__init__.py         |   3 +
 .../architectures/KBNet/__arch/__init__.py         |   0
 .../spandrel/architectures/KBNet/__init__.py       |   3 +
 .../architectures/LaMa/__arch/__init__.py          |   0
 .../spandrel/architectures/LaMa/__init__.py        |   3 +
 .../architectures/MMRealSR/__arch/__init__.py      |   0
 .../architectures/MMRealSR/__init__.py             |   3 +
 .../MixDehazeNet/__arch/__init__.py                |   0
 .../architectures/MixDehazeNet/__init__.py         |   3 +
 .../architectures/NAFNet/__arch/__init__.py        |   0
 .../spandrel/architectures/NAFNet/__init__.py      |   3 +
 .../architectures/OmniSR/__arch/__init__.py        |   0
 .../spandrel/architectures/OmniSR/__init__.py      |   3 +
 .../architectures/PLKSR/__arch/__init__.py         |   0
 .../spandrel/architectures/PLKSR/__init__.py       |   3 +
 .../architectures/RGT/__arch/__init__.py           |   0
 .../spandrel/architectures/RGT/__init__.py         |   3 +
 .../RealCUGAN/__arch/__init__.py                   |   0
 .../architectures/RealCUGAN/__init__.py            |   3 +
 .../RestoreFormer/__arch/__init__.py               |   0
 .../architectures/RestoreFormer/__init__.py        |   3 +
 .../RetinexFormer/__arch/__init__.py               |   0
 .../architectures/RetinexFormer/__init__.py        |   3 +
 .../architectures/SAFMN/__arch/__init__.py         |   0
 .../spandrel/architectures/SAFMN/__init__.py       |   3 +
 .../SAFMNBCIE/__arch/__init__.py                   |   0
 .../architectures/SCUNet/__arch/__init__.py        |   0
 .../spandrel/architectures/SCUNet/__init__.py      |   3 +
 .../architectures/SPAN/__arch/__init__.py          |   0
 .../spandrel/architectures/SPAN/__init__.py        |   3 +
 .../SwiftSRGAN/__arch/__init__.py                  |   0
 .../architectures/SwiftSRGAN/__init__.py           |   3 +
 .../architectures/Swin2SR/__arch/__init__.py       |   0
 .../architectures/Swin2SR/__init__.py              |   3 +
 .../architectures/SwinIR/__arch/__init__.py        |   0
 .../spandrel/architectures/SwinIR/__init__.py      |   3 +
 .../architectures/Uformer/__arch/__init__.py       |   0
 .../architectures/Uformer/__init__.py              |   3 +
 .../spandrel_extra_arches/__helper.py              |   5 +
 .../spandrel_extra_arches/__init__.py              |   6 +
 .../architectures/AdaCode/__arch/__init__.py       | 146 ++++++++++++++++++
 .../architectures/AdaCode/__init__.py              |   3 +
 .../CodeFormer/__arch/__init__.py                  | 146 ++++++++++++++++++
 .../architectures/CodeFormer/__init__.py           |   3 +
 .../architectures/DDColor/__arch/__init__.py       | 146 ++++++++++++++++++
 .../architectures/DDColor/__init__.py              |   3 +
 .../architectures/FeMaSR/__arch/__init__.py        | 146 ++++++++++++++++++
 .../architectures/FeMaSR/__init__.py               |   3 +
 .../architectures/M3SNet/__arch/__init__.py        | 146 ++++++++++++++++++
 .../architectures/M3SNet/__init__.py               |   3 +
 .../architectures/MAT/__arch/__init__.py           | 146 ++++++++++++++++++
 .../architectures/MAT/__init__.py                  |   3 +
 .../architectures/MIRNet2/__arch/__init__.py       | 146 ++++++++++++++++++
 .../architectures/MIRNet2/__init__.py              |   3 +
 .../architectures/MPRNet/__arch/__init__.py        | 146 ++++++++++++++++++
 .../architectures/MPRNet/__init__.py               |   3 +
 .../Restormer/__arch/__init__.py                   | 146 ++++++++++++++++++
 .../architectures/Restormer/__init__.py            |   3 +
 .../architectures/SRFormer/__arch/__init__.py      | 146 ++++++++++++++++++
 .../architectures/SRFormer/__init__.py             |   3 +
 .../architectures/__init__.py                      |   3 +
 pyproject.toml                                     |   6 +-
 92 files changed, 1605 insertions(+), 1 deletion(-)
 create mode 100644 libs/spandrel/spandrel/architectures/ATD/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/CRAFT/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/Compact/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/DAT/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/DCTLSA/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/DITN/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/DRCT/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/DRUNet/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/DnCNN/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/ESRGAN/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/FBCNN/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/FFTformer/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/GFPGAN/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/GRL/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/HAT/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/HVICIDNet/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/IPT/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/KBNet/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/LaMa/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/MMRealSR/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/MixDehazeNet/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/NAFNet/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/OmniSR/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/PLKSR/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/RGT/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/RealCUGAN/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/RestoreFormer/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/RetinexFormer/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/SAFMN/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/SAFMNBCIE/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/SCUNet/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/SPAN/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/SwiftSRGAN/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/Swin2SR/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/SwinIR/__arch/__init__.py
 create mode 100644 libs/spandrel/spandrel/architectures/Uformer/__arch/__init__.py
 create mode 100644 libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__arch/__init__.py
 create mode 100644 libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__arch/__init__.py
 create mode 100644 libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__arch/__init__.py
 create mode 100644 libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__arch/__init__.py
 create mode 100644 libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/__arch/__init__.py
 create mode 100644 libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MAT/__arch/__init__.py
 create mode 100644 libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/__arch/__init__.py
 create mode 100644 libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__arch/__init__.py
 create mode 100644 libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/__arch/__init__.py
 create mode 100644 libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__arch/__init__.py
 create mode 100644 libs/spandrel_extra_arches/spandrel_extra_arches/architectures/__init__.py

diff --git a/libs/spandrel/spandrel/architectures/ATD/__arch/__init__.py b/libs/spandrel/spandrel/architectures/ATD/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/ATD/__init__.py b/libs/spandrel/spandrel/architectures/ATD/__init__.py
index 1603368d..3ce04198 100644
--- a/libs/spandrel/spandrel/architectures/ATD/__init__.py
+++ b/libs/spandrel/spandrel/architectures/ATD/__init__.py
@@ -168,3 +168,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[ATD]:
             output_channels=in_chans,
             size_requirements=SizeRequirements(minimum=8),
         )
+
+
+__all__ = ["ATDArch", "ATD"]
diff --git a/libs/spandrel/spandrel/architectures/CRAFT/__arch/__init__.py b/libs/spandrel/spandrel/architectures/CRAFT/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/CRAFT/__init__.py b/libs/spandrel/spandrel/architectures/CRAFT/__init__.py
index afb2f061..51fef69e 100644
--- a/libs/spandrel/spandrel/architectures/CRAFT/__init__.py
+++ b/libs/spandrel/spandrel/architectures/CRAFT/__init__.py
@@ -126,3 +126,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[CRAFT]:
             output_channels=in_chans,
             size_requirements=SizeRequirements(minimum=16, multiple_of=16),
         )
+
+
+__all__ = ["CRAFTArch", "CRAFT"]
diff --git a/libs/spandrel/spandrel/architectures/Compact/__arch/__init__.py b/libs/spandrel/spandrel/architectures/Compact/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/DAT/__arch/__init__.py b/libs/spandrel/spandrel/architectures/DAT/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/DAT/__init__.py b/libs/spandrel/spandrel/architectures/DAT/__init__.py
index 6ed57566..2a545fec 100644
--- a/libs/spandrel/spandrel/architectures/DAT/__init__.py
+++ b/libs/spandrel/spandrel/architectures/DAT/__init__.py
@@ -177,3 +177,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DAT]:
             output_channels=in_chans,
             size_requirements=SizeRequirements(minimum=16),
         )
+
+
+__all__ = ["DATArch", "DAT"]
diff --git a/libs/spandrel/spandrel/architectures/DCTLSA/__arch/__init__.py b/libs/spandrel/spandrel/architectures/DCTLSA/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/DCTLSA/__init__.py b/libs/spandrel/spandrel/architectures/DCTLSA/__init__.py
index b420227e..f6266c81 100644
--- a/libs/spandrel/spandrel/architectures/DCTLSA/__init__.py
+++ b/libs/spandrel/spandrel/architectures/DCTLSA/__init__.py
@@ -82,3 +82,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DCTLSA]:
             output_channels=out_nc,
             size_requirements=SizeRequirements(minimum=16),
         )
+
+
+__all__ = ["DCTLSAArch", "DCTLSA"]
diff --git a/libs/spandrel/spandrel/architectures/DITN/__arch/__init__.py b/libs/spandrel/spandrel/architectures/DITN/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/DITN/__init__.py b/libs/spandrel/spandrel/architectures/DITN/__init__.py
index d03c1d03..b343eda6 100644
--- a/libs/spandrel/spandrel/architectures/DITN/__init__.py
+++ b/libs/spandrel/spandrel/architectures/DITN/__init__.py
@@ -85,3 +85,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DITN]:
             output_channels=3,  # hard-coded in the architecture
             size_requirements=SizeRequirements(multiple_of=patch_size),
         )
+
+
+__all__ = ["DITNArch", "DITN"]
diff --git a/libs/spandrel/spandrel/architectures/DRCT/__arch/__init__.py b/libs/spandrel/spandrel/architectures/DRCT/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/DRCT/__init__.py b/libs/spandrel/spandrel/architectures/DRCT/__init__.py
index 64d7bae1..7d59e489 100644
--- a/libs/spandrel/spandrel/architectures/DRCT/__init__.py
+++ b/libs/spandrel/spandrel/architectures/DRCT/__init__.py
@@ -155,3 +155,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DRCT]:
             output_channels=in_chans,
             size_requirements=SizeRequirements(multiple_of=16),
         )
+
+
+__all__ = ["DRCTArch", "DRCT"]
diff --git a/libs/spandrel/spandrel/architectures/DRUNet/__arch/__init__.py b/libs/spandrel/spandrel/architectures/DRUNet/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/DRUNet/__init__.py b/libs/spandrel/spandrel/architectures/DRUNet/__init__.py
index 44d47c6e..5ee074ab 100644
--- a/libs/spandrel/spandrel/architectures/DRUNet/__init__.py
+++ b/libs/spandrel/spandrel/architectures/DRUNet/__init__.py
@@ -99,3 +99,6 @@ def call(model: DRUNet, image: torch.Tensor) -> torch.Tensor:
             size_requirements=SizeRequirements(multiple_of=8),
             call_fn=call,
         )
+
+
+__all__ = ["DRUNetArch", "DRUNet"]
diff --git a/libs/spandrel/spandrel/architectures/DnCNN/__arch/__init__.py b/libs/spandrel/spandrel/architectures/DnCNN/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/DnCNN/__init__.py b/libs/spandrel/spandrel/architectures/DnCNN/__init__.py
index fa62bd10..e383ba7d 100644
--- a/libs/spandrel/spandrel/architectures/DnCNN/__init__.py
+++ b/libs/spandrel/spandrel/architectures/DnCNN/__init__.py
@@ -125,3 +125,6 @@ def call(model: DnCNN, image: torch.Tensor) -> torch.Tensor:
             size_requirements=SizeRequirements(),
             call_fn=call,
         )
+
+
+__all__ = ["DnCNNArch", "DnCNN"]
diff --git a/libs/spandrel/spandrel/architectures/ESRGAN/__arch/__init__.py b/libs/spandrel/spandrel/architectures/ESRGAN/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/FBCNN/__arch/__init__.py b/libs/spandrel/spandrel/architectures/FBCNN/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/FBCNN/__init__.py b/libs/spandrel/spandrel/architectures/FBCNN/__init__.py
index 46f2cff4..dca55516 100644
--- a/libs/spandrel/spandrel/architectures/FBCNN/__init__.py
+++ b/libs/spandrel/spandrel/architectures/FBCNN/__init__.py
@@ -79,3 +79,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[FBCNN]:
             output_channels=out_nc,
             call_fn=lambda model, image: model(image)[0],
         )
+
+
+__all__ = ["FBCNNArch", "FBCNN"]
diff --git a/libs/spandrel/spandrel/architectures/FFTformer/__arch/__init__.py b/libs/spandrel/spandrel/architectures/FFTformer/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/FFTformer/__init__.py b/libs/spandrel/spandrel/architectures/FFTformer/__init__.py
index eb69bda7..b7109f73 100644
--- a/libs/spandrel/spandrel/architectures/FFTformer/__init__.py
+++ b/libs/spandrel/spandrel/architectures/FFTformer/__init__.py
@@ -98,3 +98,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[FFTformer]:
             output_channels=out_channels,
             size_requirements=SizeRequirements(multiple_of=32),
         )
+
+
+__all__ = ["FFTformerArch", "FFTformer"]
diff --git a/libs/spandrel/spandrel/architectures/GFPGAN/__arch/__init__.py b/libs/spandrel/spandrel/architectures/GFPGAN/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/GRL/__arch/__init__.py b/libs/spandrel/spandrel/architectures/GRL/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/GRL/__init__.py b/libs/spandrel/spandrel/architectures/GRL/__init__.py
index 851be55e..a115d610 100644
--- a/libs/spandrel/spandrel/architectures/GRL/__init__.py
+++ b/libs/spandrel/spandrel/architectures/GRL/__init__.py
@@ -359,3 +359,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[GRL]:
             input_channels=in_channels,
             output_channels=out_channels,
         )
+
+
+__all__ = ["GRLArch", "GRL"]
diff --git a/libs/spandrel/spandrel/architectures/HAT/__arch/__init__.py b/libs/spandrel/spandrel/architectures/HAT/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/HAT/__init__.py b/libs/spandrel/spandrel/architectures/HAT/__init__.py
index 2a219d26..686d91d6 100644
--- a/libs/spandrel/spandrel/architectures/HAT/__init__.py
+++ b/libs/spandrel/spandrel/architectures/HAT/__init__.py
@@ -225,3 +225,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[HAT]:
             output_channels=in_chans,
             size_requirements=SizeRequirements(minimum=16),
         )
+
+
+__all__ = ["HATArch", "HAT"]
diff --git a/libs/spandrel/spandrel/architectures/HVICIDNet/__arch/__init__.py b/libs/spandrel/spandrel/architectures/HVICIDNet/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/HVICIDNet/__init__.py b/libs/spandrel/spandrel/architectures/HVICIDNet/__init__.py
index 55abcd13..b30bb5e5 100644
--- a/libs/spandrel/spandrel/architectures/HVICIDNet/__init__.py
+++ b/libs/spandrel/spandrel/architectures/HVICIDNet/__init__.py
@@ -92,3 +92,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[HVICIDNet]:
             size_requirements=SizeRequirements(multiple_of=8),
             tiling=ModelTiling.DISCOURAGED,
         )
+
+
+__all__ = ["HVICIDNetArch", "HVICIDNet"]
diff --git a/libs/spandrel/spandrel/architectures/IPT/__arch/__init__.py b/libs/spandrel/spandrel/architectures/IPT/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/IPT/__init__.py b/libs/spandrel/spandrel/architectures/IPT/__init__.py
index aa511863..0e2f5eed 100644
--- a/libs/spandrel/spandrel/architectures/IPT/__init__.py
+++ b/libs/spandrel/spandrel/architectures/IPT/__init__.py
@@ -152,3 +152,6 @@ def call(model: IPT, x: torch.Tensor):
             size_requirements=SizeRequirements(minimum=patch_size),
             call_fn=call,
         )
+
+
+__all__ = ["IPTArch", "IPT"]
diff --git a/libs/spandrel/spandrel/architectures/KBNet/__arch/__init__.py b/libs/spandrel/spandrel/architectures/KBNet/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/KBNet/__init__.py b/libs/spandrel/spandrel/architectures/KBNet/__init__.py
index 9a574a5f..81f3b646 100644
--- a/libs/spandrel/spandrel/architectures/KBNet/__init__.py
+++ b/libs/spandrel/spandrel/architectures/KBNet/__init__.py
@@ -166,3 +166,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_KBNet]:
             return self._load_l(state_dict)
         else:
             return self._load_s(state_dict)
+
+
+__all__ = ["KBNetArch", "KBNet_s", "KBNet_l"]
diff --git a/libs/spandrel/spandrel/architectures/LaMa/__arch/__init__.py b/libs/spandrel/spandrel/architectures/LaMa/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/LaMa/__init__.py b/libs/spandrel/spandrel/architectures/LaMa/__init__.py
index b3008936..c03ba393 100644
--- a/libs/spandrel/spandrel/architectures/LaMa/__init__.py
+++ b/libs/spandrel/spandrel/architectures/LaMa/__init__.py
@@ -53,3 +53,6 @@ def load(self, state_dict: StateDict) -> MaskedImageModelDescriptor[LaMa]:
             output_channels=out_nc,
             size_requirements=SizeRequirements(minimum=16),
         )
+
+
+__all__ = ["LaMaArch", "LaMa"]
diff --git a/libs/spandrel/spandrel/architectures/MMRealSR/__arch/__init__.py b/libs/spandrel/spandrel/architectures/MMRealSR/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/MMRealSR/__init__.py b/libs/spandrel/spandrel/architectures/MMRealSR/__init__.py
index ca519a75..161146f2 100644
--- a/libs/spandrel/spandrel/architectures/MMRealSR/__init__.py
+++ b/libs/spandrel/spandrel/architectures/MMRealSR/__init__.py
@@ -193,3 +193,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[MMRealSR]:
             size_requirements=SizeRequirements(minimum=16),
             call_fn=lambda model, image: model(image)[0],
         )
+
+
+__all__ = ["MMRealSRArch", "MMRealSR"]
diff --git a/libs/spandrel/spandrel/architectures/MixDehazeNet/__arch/__init__.py b/libs/spandrel/spandrel/architectures/MixDehazeNet/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/MixDehazeNet/__init__.py b/libs/spandrel/spandrel/architectures/MixDehazeNet/__init__.py
index 601314d3..d6b98858 100644
--- a/libs/spandrel/spandrel/architectures/MixDehazeNet/__init__.py
+++ b/libs/spandrel/spandrel/architectures/MixDehazeNet/__init__.py
@@ -112,3 +112,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[MixDehazeNet]:
             tiling=ModelTiling.DISCOURAGED,
             call_fn=lambda model, image: model(image) * 0.5 + 0.5,
         )
+
+
+__all__ = ["MixDehazeNetArch", "MixDehazeNet"]
diff --git a/libs/spandrel/spandrel/architectures/NAFNet/__arch/__init__.py b/libs/spandrel/spandrel/architectures/NAFNet/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/NAFNet/__init__.py b/libs/spandrel/spandrel/architectures/NAFNet/__init__.py
index c9aae804..7f1e3cbe 100644
--- a/libs/spandrel/spandrel/architectures/NAFNet/__init__.py
+++ b/libs/spandrel/spandrel/architectures/NAFNet/__init__.py
@@ -71,3 +71,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[NAFNet]:
             input_channels=img_channel,
             output_channels=img_channel,
         )
+
+
+__all__ = ["NAFNetArch", "NAFNet"]
diff --git a/libs/spandrel/spandrel/architectures/OmniSR/__arch/__init__.py b/libs/spandrel/spandrel/architectures/OmniSR/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/OmniSR/__init__.py b/libs/spandrel/spandrel/architectures/OmniSR/__init__.py
index 97904a86..808d02a7 100644
--- a/libs/spandrel/spandrel/architectures/OmniSR/__init__.py
+++ b/libs/spandrel/spandrel/architectures/OmniSR/__init__.py
@@ -96,3 +96,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[OmniSR]:
             output_channels=num_out_ch,
             size_requirements=SizeRequirements(minimum=16),
         )
+
+
+__all__ = ["OmniSRArch", "OmniSR"]
diff --git a/libs/spandrel/spandrel/architectures/PLKSR/__arch/__init__.py b/libs/spandrel/spandrel/architectures/PLKSR/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/PLKSR/__init__.py b/libs/spandrel/spandrel/architectures/PLKSR/__init__.py
index c8f3cd1b..5fd6f3ab 100644
--- a/libs/spandrel/spandrel/architectures/PLKSR/__init__.py
+++ b/libs/spandrel/spandrel/architectures/PLKSR/__init__.py
@@ -143,3 +143,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]:
             input_channels=3,
             output_channels=3,
         )
+
+
+__all__ = ["PLKSRArch", "PLKSR", "RealPLKSR"]
diff --git a/libs/spandrel/spandrel/architectures/RGT/__arch/__init__.py b/libs/spandrel/spandrel/architectures/RGT/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/RGT/__init__.py b/libs/spandrel/spandrel/architectures/RGT/__init__.py
index df6981be..3fc8a357 100644
--- a/libs/spandrel/spandrel/architectures/RGT/__init__.py
+++ b/libs/spandrel/spandrel/architectures/RGT/__init__.py
@@ -169,3 +169,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[RGT]:
             output_channels=in_chans,
             size_requirements=SizeRequirements(minimum=16),
         )
+
+
+__all__ = ["RGTArch", "RGT"]
diff --git a/libs/spandrel/spandrel/architectures/RealCUGAN/__arch/__init__.py b/libs/spandrel/spandrel/architectures/RealCUGAN/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/RealCUGAN/__init__.py b/libs/spandrel/spandrel/architectures/RealCUGAN/__init__.py
index 0255b81a..5d978482 100644
--- a/libs/spandrel/spandrel/architectures/RealCUGAN/__init__.py
+++ b/libs/spandrel/spandrel/architectures/RealCUGAN/__init__.py
@@ -104,3 +104,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_RealCUGAN]:
             output_channels=out_channels,
             size_requirements=size_requirements,
         )
+
+
+__all__ = ["RealCUGANArch", "UpCunet2x", "UpCunet3x", "UpCunet4x", "UpCunet2x_fast"]
"UpCunet3x", "UpCunet4x", "UpCunet2x_fast"] diff --git a/libs/spandrel/spandrel/architectures/RestoreFormer/__arch/__init__.py b/libs/spandrel/spandrel/architectures/RestoreFormer/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/RestoreFormer/__init__.py b/libs/spandrel/spandrel/architectures/RestoreFormer/__init__.py index ee572d38..5952a926 100644 --- a/libs/spandrel/spandrel/architectures/RestoreFormer/__init__.py +++ b/libs/spandrel/spandrel/architectures/RestoreFormer/__init__.py @@ -101,3 +101,6 @@ def call(model: RestoreFormer, x: torch.Tensor) -> torch.Tensor: size_requirements=SizeRequirements(multiple_of=32), call_fn=call, ) + + +__all__ = ["RestoreFormerArch", "RestoreFormer"] diff --git a/libs/spandrel/spandrel/architectures/RetinexFormer/__arch/__init__.py b/libs/spandrel/spandrel/architectures/RetinexFormer/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/RetinexFormer/__init__.py b/libs/spandrel/spandrel/architectures/RetinexFormer/__init__.py index afe22e18..048efda6 100644 --- a/libs/spandrel/spandrel/architectures/RetinexFormer/__init__.py +++ b/libs/spandrel/spandrel/architectures/RetinexFormer/__init__.py @@ -105,3 +105,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[RetinexFormer]: tiling=ModelTiling.DISCOURAGED, call_fn=_call_fn, ) + + +__all__ = ["RetinexFormerArch", "RetinexFormer"] diff --git a/libs/spandrel/spandrel/architectures/SAFMN/__arch/__init__.py b/libs/spandrel/spandrel/architectures/SAFMN/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/SAFMN/__init__.py b/libs/spandrel/spandrel/architectures/SAFMN/__init__.py index c629399a..77ee625f 100644 --- a/libs/spandrel/spandrel/architectures/SAFMN/__init__.py +++ b/libs/spandrel/spandrel/architectures/SAFMN/__init__.py @@ -67,3 +67,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SAFMN]: output_channels=3, # hard-coded in the arch size_requirements=SizeRequirements(multiple_of=8), ) + + +__all__ = ["SAFMNArch", "SAFMN"] diff --git a/libs/spandrel/spandrel/architectures/SAFMNBCIE/__arch/__init__.py b/libs/spandrel/spandrel/architectures/SAFMNBCIE/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/SCUNet/__arch/__init__.py b/libs/spandrel/spandrel/architectures/SCUNet/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/SCUNet/__init__.py b/libs/spandrel/spandrel/architectures/SCUNet/__init__.py index 3e9d3778..9753ddad 100644 --- a/libs/spandrel/spandrel/architectures/SCUNet/__init__.py +++ b/libs/spandrel/spandrel/architectures/SCUNet/__init__.py @@ -63,3 +63,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SCUNet]: size_requirements=SizeRequirements(minimum=40), tiling=ModelTiling.DISCOURAGED, ) + + +__all__ = ["SCUNetArch", "SCUNet"] diff --git a/libs/spandrel/spandrel/architectures/SPAN/__arch/__init__.py b/libs/spandrel/spandrel/architectures/SPAN/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/SPAN/__init__.py b/libs/spandrel/spandrel/architectures/SPAN/__init__.py index ccce5590..2eabec00 100644 --- a/libs/spandrel/spandrel/architectures/SPAN/__init__.py +++ b/libs/spandrel/spandrel/architectures/SPAN/__init__.py @@ -71,3 +71,6 @@ def load(self, state_dict: StateDict) 
diff --git a/libs/spandrel/spandrel/architectures/SwiftSRGAN/__arch/__init__.py b/libs/spandrel/spandrel/architectures/SwiftSRGAN/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/SwiftSRGAN/__init__.py b/libs/spandrel/spandrel/architectures/SwiftSRGAN/__init__.py
index f0ae8255..c55618f0 100644
--- a/libs/spandrel/spandrel/architectures/SwiftSRGAN/__init__.py
+++ b/libs/spandrel/spandrel/architectures/SwiftSRGAN/__init__.py
@@ -52,3 +52,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SwiftSRGAN]:
             input_channels=in_channels,
             output_channels=in_channels,
         )
+
+
+__all__ = ["SwiftSRGANArch", "SwiftSRGAN"]
diff --git a/libs/spandrel/spandrel/architectures/Swin2SR/__arch/__init__.py b/libs/spandrel/spandrel/architectures/Swin2SR/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/Swin2SR/__init__.py b/libs/spandrel/spandrel/architectures/Swin2SR/__init__.py
index f755d938..cbe37dad 100644
--- a/libs/spandrel/spandrel/architectures/Swin2SR/__init__.py
+++ b/libs/spandrel/spandrel/architectures/Swin2SR/__init__.py
@@ -184,3 +184,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[Swin2SR]:
             output_channels=in_chans,
             size_requirements=SizeRequirements(minimum=16),
         )
+
+
+__all__ = ["Swin2SRArch", "Swin2SR"]
diff --git a/libs/spandrel/spandrel/architectures/SwinIR/__arch/__init__.py b/libs/spandrel/spandrel/architectures/SwinIR/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/SwinIR/__init__.py b/libs/spandrel/spandrel/architectures/SwinIR/__init__.py
index e60f0f10..94fc03f0 100644
--- a/libs/spandrel/spandrel/architectures/SwinIR/__init__.py
+++ b/libs/spandrel/spandrel/architectures/SwinIR/__init__.py
@@ -189,3 +189,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SwinIR]:
             output_channels=out_nc,
             size_requirements=SizeRequirements(minimum=16),
         )
+
+
+__all__ = ["SwinIRArch", "SwinIR"]
diff --git a/libs/spandrel/spandrel/architectures/Uformer/__arch/__init__.py b/libs/spandrel/spandrel/architectures/Uformer/__arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/spandrel/spandrel/architectures/Uformer/__init__.py b/libs/spandrel/spandrel/architectures/Uformer/__init__.py
index 2b4b0c2a..6540556d 100644
--- a/libs/spandrel/spandrel/architectures/Uformer/__init__.py
+++ b/libs/spandrel/spandrel/architectures/Uformer/__init__.py
@@ -141,3 +141,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]:
             output_channels=dd_in,
             size_requirements=SizeRequirements(multiple_of=128, square=True),
         )
+
+
+__all__ = ["UformerArch", "Uformer"]
diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py b/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py
index d759ec6b..f56b814b 100644
--- a/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py
+++ b/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py
@@ -14,6 +14,11 @@
 )
 
 EXTRA_REGISTRY = ArchRegistry()
+"""
+The registry of all architectures in this library.
+
+Use ``MAIN_REGISTRY.add(*EXTRA_REGISTRY)`` to add all architectures to the main registry of `spandrel`.
+"""
 
 EXTRA_REGISTRY.add(
     ArchSupport.from_architecture(SRFormer.SRFormerArch()),
+""" EXTRA_REGISTRY.add( ArchSupport.from_architecture(SRFormer.SRFormerArch()), diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/__init__.py index 2a639290..9ebb2503 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/__init__.py @@ -1,3 +1,9 @@ +""" +Spandrel extra arches contains more architectures for `spandrel`. + +All architectures in this library are registered in the `EXTRA_REGISTRY` dictionary. +""" + from .__helper import EXTRA_REGISTRY __version__ = "0.1.1" diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__arch/__init__.py new file mode 100644 index 00000000..6540556d --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__arch/__init__.py @@ -0,0 +1,146 @@ +import math + +from typing_extensions import override + +from spandrel.util import KeyCondition, get_seq_len + +from ...__helpers.model_descriptor import ( + Architecture, + ImageModelDescriptor, + SizeRequirements, + StateDict, +) +from .__arch.Uformer import Uformer + + +class UformerArch(Architecture[Uformer]): + def __init__(self) -> None: + super().__init__( + id="Uformer", + detect=KeyCondition.has_all( + "input_proj.proj.0.weight", + "output_proj.proj.0.weight", + "encoderlayer_0.blocks.0.norm1.weight", + "encoderlayer_2.blocks.0.norm1.weight", + "conv.blocks.0.norm1.weight", + "decoderlayer_0.blocks.0.norm1.weight", + "decoderlayer_2.blocks.0.norm1.weight", + ), + ) + + @override + def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: + img_size = 256 # cannot be deduced from state_dict + in_chans = 3 + dd_in = 3 + embed_dim = 32 + depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] + num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] + win_size = 8 + mlp_ratio = 4.0 + qkv_bias = True + drop_rate = 0.0 # cannot be deduced from state_dict + attn_drop_rate = 0.0 # cannot be deduced from state_dict + drop_path_rate = 0.1 # cannot be deduced from state_dict + token_projection = "linear" + token_mlp = "leff" + shift_flag = True # cannot be deduced from state_dict + modulator = False + cross_modulator = False + + embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] + dd_in = state_dict["input_proj.proj.0.weight"].shape[1] + in_chans = state_dict["output_proj.proj.0.weight"].shape[0] + + depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") + depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") + depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") + depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") + depths[4] = get_seq_len(state_dict, "conv.blocks") + depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") + depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") + depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") + depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") + + num_heads_suffix = "blocks.0.attn.relative_position_bias_table" + num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] + num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] + num_heads[5] = 
state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] + + if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: + token_projection = "conv" + qkv_bias = True # cannot be deduced from state_dict + else: + token_projection = "linear" + qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict + + modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict + cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict + + # size_temp = (2 * win_size - 1) ** 2 + size_temp = state_dict[ + "encoderlayer_0.blocks.0.attn.relative_position_bias_table" + ].shape[0] + win_size = (int(math.sqrt(size_temp)) + 1) // 2 + + if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: + token_mlp = "mlp" # or "ffn", doesn't matter + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] + / embed_dim + ) + elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: + token_mlp = "leff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + else: + token_mlp = "fastleff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + + model = Uformer( + img_size=img_size, + in_chans=in_chans, + dd_in=dd_in, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + purpose="Restoration", + tags=[], + supports_half=False, # Too much weirdness to support this at the moment + supports_bfloat16=True, + scale=1, + input_channels=dd_in, + output_channels=dd_in, + size_requirements=SizeRequirements(multiple_of=128, square=True), + ) + + +__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__init__.py index 557a2a59..91decbe1 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__init__.py @@ -148,3 +148,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[AdaCode]: output_channels=in_channel, size_requirements=SizeRequirements(multiple_of=multiple_of), ) + + +__all__ = ["AdaCodeArch", "AdaCode"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__arch/__init__.py new file mode 100644 index 00000000..6540556d --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__arch/__init__.py @@ -0,0 +1,146 @@ +import math + +from typing_extensions import override + +from spandrel.util import KeyCondition, get_seq_len + +from ...__helpers.model_descriptor import ( + Architecture, + ImageModelDescriptor, + SizeRequirements, + StateDict, +) +from .__arch.Uformer import Uformer + + 
diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__arch/__init__.py
new file mode 100644
index 00000000..6540556d
--- /dev/null
+++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__arch/__init__.py
@@ -0,0 +1,146 @@
+import math
+
+from typing_extensions import override
+
+from spandrel.util import KeyCondition, get_seq_len
+
+from ...__helpers.model_descriptor import (
+    Architecture,
+    ImageModelDescriptor,
+    SizeRequirements,
+    StateDict,
+)
+from .__arch.Uformer import Uformer
+
+
+class UformerArch(Architecture[Uformer]):
+    def __init__(self) -> None:
+        super().__init__(
+            id="Uformer",
+            detect=KeyCondition.has_all(
+                "input_proj.proj.0.weight",
+                "output_proj.proj.0.weight",
+                "encoderlayer_0.blocks.0.norm1.weight",
+                "encoderlayer_2.blocks.0.norm1.weight",
+                "conv.blocks.0.norm1.weight",
+                "decoderlayer_0.blocks.0.norm1.weight",
+                "decoderlayer_2.blocks.0.norm1.weight",
+            ),
+        )
+
+    @override
+    def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]:
+        img_size = 256  # cannot be deduced from state_dict
+        in_chans = 3
+        dd_in = 3
+        embed_dim = 32
+        depths = [2, 2, 2, 2, 2, 2, 2, 2, 2]
+        num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2]
+        win_size = 8
+        mlp_ratio = 4.0
+        qkv_bias = True
+        drop_rate = 0.0  # cannot be deduced from state_dict
+        attn_drop_rate = 0.0  # cannot be deduced from state_dict
+        drop_path_rate = 0.1  # cannot be deduced from state_dict
+        token_projection = "linear"
+        token_mlp = "leff"
+        shift_flag = True  # cannot be deduced from state_dict
+        modulator = False
+        cross_modulator = False
+
+        embed_dim = state_dict["input_proj.proj.0.weight"].shape[0]
+        dd_in = state_dict["input_proj.proj.0.weight"].shape[1]
+        in_chans = state_dict["output_proj.proj.0.weight"].shape[0]
+
+        depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks")
+        depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks")
+        depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks")
+        depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks")
+        depths[4] = get_seq_len(state_dict, "conv.blocks")
+        depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks")
+        depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks")
+        depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks")
+        depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks")
+
+        num_heads_suffix = "blocks.0.attn.relative_position_bias_table"
+        num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1]
+        num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1]
+        num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1]
+        num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1]
+        num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1]
+        num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1]
+        num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1]
+        num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1]
+        num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1]
+
+        if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict:
+            token_projection = "conv"
+            qkv_bias = True  # cannot be deduced from state_dict
+        else:
+            token_projection = "linear"
+            qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict
+
+        modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict
+        cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict
+
+        # size_temp = (2 * win_size - 1) ** 2
+        size_temp = state_dict[
+            "encoderlayer_0.blocks.0.attn.relative_position_bias_table"
+        ].shape[0]
+        win_size = (int(math.sqrt(size_temp)) + 1) // 2
+
+        if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict:
+            token_mlp = "mlp"  # or "ffn", doesn't matter
+            mlp_ratio = (
+                state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0]
+                / embed_dim
+            )
+        elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1:
+            token_mlp = "leff"
+            mlp_ratio = (
+                state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0]
+                / embed_dim
+            )
+        else:
+            token_mlp = "fastleff"
+            mlp_ratio = (
+                state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0]
+                / embed_dim
+            )
+
+        model = Uformer(
+            img_size=img_size,
+            in_chans=in_chans,
+            dd_in=dd_in,
+            embed_dim=embed_dim,
+            depths=depths,
+            num_heads=num_heads,
+            win_size=win_size,
+            mlp_ratio=mlp_ratio,
+            qkv_bias=qkv_bias,
+            drop_rate=drop_rate,
+            attn_drop_rate=attn_drop_rate,
+            drop_path_rate=drop_path_rate,
+            token_projection=token_projection,
+            token_mlp=token_mlp,
+            shift_flag=shift_flag,
+            modulator=modulator,
+            cross_modulator=cross_modulator,
+        )
+
+        return ImageModelDescriptor(
+            model,
+            state_dict,
+            architecture=self,
+            purpose="Restoration",
+            tags=[],
+            supports_half=False,  # Too much weirdness to support this at the moment
+            supports_bfloat16=True,
+            scale=1,
+            input_channels=dd_in,
+            output_channels=dd_in,
+            size_requirements=SizeRequirements(multiple_of=128, square=True),
+        )
+
+
+__all__ = ["UformerArch", "Uformer"]
diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__init__.py
index e7fbf21b..cfcbe067 100644
--- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__init__.py
+++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__init__.py
@@ -73,3 +73,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[CodeFormer]:
             size_requirements=SizeRequirements(multiple_of=512, square=True),
             call_fn=lambda model, image: model(image)[0],
         )
+
+
+__all__ = ["CodeFormerArch", "CodeFormer"]
diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__arch/__init__.py
new file mode 100644
index 00000000..6540556d
--- /dev/null
+++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__arch/__init__.py
@@ -0,0 +1,146 @@
+import math
+
+from typing_extensions import override
+
+from spandrel.util import KeyCondition, get_seq_len
+
+from ...__helpers.model_descriptor import (
+    Architecture,
+    ImageModelDescriptor,
+    SizeRequirements,
+    StateDict,
+)
+from .__arch.Uformer import Uformer
+
+
+class UformerArch(Architecture[Uformer]):
+    def __init__(self) -> None:
+        super().__init__(
+            id="Uformer",
+            detect=KeyCondition.has_all(
+                "input_proj.proj.0.weight",
+                "output_proj.proj.0.weight",
+                "encoderlayer_0.blocks.0.norm1.weight",
+                "encoderlayer_2.blocks.0.norm1.weight",
+                "conv.blocks.0.norm1.weight",
+                "decoderlayer_0.blocks.0.norm1.weight",
+                "decoderlayer_2.blocks.0.norm1.weight",
+            ),
+        )
+
+    @override
+    def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]:
+        img_size = 256  # cannot be deduced from state_dict
+        in_chans = 3
+        dd_in = 3
+        embed_dim = 32
+        depths = [2, 2, 2, 2, 2, 2, 2, 2, 2]
+        num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2]
+        win_size = 8
+        mlp_ratio = 4.0
+        qkv_bias = True
+        drop_rate = 0.0  # cannot be deduced from state_dict
+        attn_drop_rate = 0.0  # cannot be deduced from state_dict
+        drop_path_rate = 0.1  # cannot be deduced from state_dict
+        token_projection = "linear"
+        token_mlp = "leff"
+        shift_flag = True  # cannot be deduced from state_dict
+        modulator = False
+        cross_modulator = False
+
+        embed_dim = state_dict["input_proj.proj.0.weight"].shape[0]
+        dd_in = state_dict["input_proj.proj.0.weight"].shape[1]
+        in_chans = state_dict["output_proj.proj.0.weight"].shape[0]
+
+        depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks")
+        depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks")
+        depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks")
+        depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks")
+        depths[4] = get_seq_len(state_dict, "conv.blocks")
+        depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks")
+        depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks")
+        depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks")
+        depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks")
+
+        num_heads_suffix = "blocks.0.attn.relative_position_bias_table"
+        num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1]
+        num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1]
+        num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1]
+        num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1]
+        num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1]
+        num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1]
+        num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1]
+        num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1]
+        num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1]
+
+        if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict:
+            token_projection = "conv"
+            qkv_bias = True  # cannot be deduced from state_dict
+        else:
+            token_projection = "linear"
+            qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict
+
+        modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict
+        cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict
+
+        # size_temp = (2 * win_size - 1) ** 2
+        size_temp = state_dict[
+            "encoderlayer_0.blocks.0.attn.relative_position_bias_table"
+        ].shape[0]
+        win_size = (int(math.sqrt(size_temp)) + 1) // 2
+
+        if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict:
+            token_mlp = "mlp"  # or "ffn", doesn't matter
+            mlp_ratio = (
+                state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0]
+                / embed_dim
+            )
+        elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1:
+            token_mlp = "leff"
+            mlp_ratio = (
+                state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0]
+                / embed_dim
+            )
+        else:
+            token_mlp = "fastleff"
+            mlp_ratio = (
+                state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0]
+                / embed_dim
+            )
+
+        model = Uformer(
+            img_size=img_size,
+            in_chans=in_chans,
+            dd_in=dd_in,
+            embed_dim=embed_dim,
+            depths=depths,
+            num_heads=num_heads,
+            win_size=win_size,
+            mlp_ratio=mlp_ratio,
+            qkv_bias=qkv_bias,
+            drop_rate=drop_rate,
+            attn_drop_rate=attn_drop_rate,
+            drop_path_rate=drop_path_rate,
+            token_projection=token_projection,
+            token_mlp=token_mlp,
+            shift_flag=shift_flag,
+            modulator=modulator,
+            cross_modulator=cross_modulator,
+        )
+
+        return ImageModelDescriptor(
+            model,
+            state_dict,
+            architecture=self,
+            purpose="Restoration",
+            tags=[],
+            supports_half=False,  # Too much weirdness to support this at the moment
+            supports_bfloat16=True,
+            scale=1,
+            input_channels=dd_in,
+            output_channels=dd_in,
+            size_requirements=SizeRequirements(multiple_of=128, square=True),
+        )
+
+
+__all__ = ["UformerArch", "Uformer"]
diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__init__.py
index 90fb28f9..8f9aabef 100644
--- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__init__.py
+++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__init__.py
@@ -191,3 +191,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DDColor]:
             tiling=ModelTiling.INTERNAL,
             call_fn=_call,
         )
+
+
+__all__ = ["DDColorArch", "DDColor"]
diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__arch/__init__.py
new file mode 100644
index 00000000..6540556d
--- /dev/null
+++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__arch/__init__.py
@@ -0,0 +1,146 @@
+import math
+
+from typing_extensions import override
+
+from spandrel.util import KeyCondition, get_seq_len
+
+from ...__helpers.model_descriptor import (
+    Architecture,
+    ImageModelDescriptor,
+    SizeRequirements,
+    StateDict,
+)
+from .__arch.Uformer import Uformer
+
+
+class UformerArch(Architecture[Uformer]):
+    def __init__(self) -> None:
+        super().__init__(
+            id="Uformer",
+            detect=KeyCondition.has_all(
+                "input_proj.proj.0.weight",
+                "output_proj.proj.0.weight",
+                "encoderlayer_0.blocks.0.norm1.weight",
+                "encoderlayer_2.blocks.0.norm1.weight",
+                "conv.blocks.0.norm1.weight",
+                "decoderlayer_0.blocks.0.norm1.weight",
+                "decoderlayer_2.blocks.0.norm1.weight",
+            ),
+        )
+
+    @override
+    def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]:
+        img_size = 256  # cannot be deduced from state_dict
+        in_chans = 3
+        dd_in = 3
+        embed_dim = 32
+        depths = [2, 2, 2, 2, 2, 2, 2, 2, 2]
+        num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2]
+        win_size = 8
+        mlp_ratio = 4.0
+        qkv_bias = True
+        drop_rate = 0.0  # cannot be deduced from state_dict
+        attn_drop_rate = 0.0  # cannot be deduced from state_dict
+        drop_path_rate = 0.1  # cannot be deduced from state_dict
+        token_projection = "linear"
+        token_mlp = "leff"
+        shift_flag = True  # cannot be deduced from state_dict
+        modulator = False
+        cross_modulator = False
+
+        embed_dim = state_dict["input_proj.proj.0.weight"].shape[0]
+        dd_in = state_dict["input_proj.proj.0.weight"].shape[1]
+        in_chans = state_dict["output_proj.proj.0.weight"].shape[0]
+
+        depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks")
+        depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks")
+        depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks")
+        depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks")
+        depths[4] = get_seq_len(state_dict, "conv.blocks")
+        depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks")
+        depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks")
+        depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks")
+        depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks")
+
+        num_heads_suffix = "blocks.0.attn.relative_position_bias_table"
+        num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1]
+        num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1]
+        num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1]
+        num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1]
+        num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1]
+        num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1]
+        num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1]
+        num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1]
+        num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1]
+
+        if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict:
+            token_projection = "conv"
+            qkv_bias = True  # cannot be deduced from state_dict
+        else:
+            token_projection = "linear"
+            qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict
+
+        modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict
+        cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict
+
+        # size_temp = (2 * win_size - 1) ** 2
+        size_temp = state_dict[
+            "encoderlayer_0.blocks.0.attn.relative_position_bias_table"
+        ].shape[0]
+        win_size = (int(math.sqrt(size_temp)) + 1) // 2
+
+        if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict:
+            token_mlp = "mlp"  # or "ffn", doesn't matter
+            mlp_ratio = (
+                state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0]
+                / embed_dim
+            )
+        elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1:
+            token_mlp = "leff"
+            mlp_ratio = (
+                state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0]
+                / embed_dim
+            )
+        else:
+            token_mlp = "fastleff"
+            mlp_ratio = (
+                state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0]
+                / embed_dim
+            )
+
+        model = Uformer(
+            img_size=img_size,
+            in_chans=in_chans,
+            dd_in=dd_in,
+            embed_dim=embed_dim,
+            depths=depths,
+            num_heads=num_heads,
+            win_size=win_size,
+            mlp_ratio=mlp_ratio,
+            qkv_bias=qkv_bias,
+            drop_rate=drop_rate,
+            attn_drop_rate=attn_drop_rate,
+            drop_path_rate=drop_path_rate,
+            token_projection=token_projection,
+            token_mlp=token_mlp,
+            shift_flag=shift_flag,
+            modulator=modulator,
+            cross_modulator=cross_modulator,
+        )
+
+        return ImageModelDescriptor(
+            model,
+            state_dict,
+            architecture=self,
+            purpose="Restoration",
+            tags=[],
+            supports_half=False,  # Too much weirdness to support this at the moment
+            supports_bfloat16=True,
+            scale=1,
+            input_channels=dd_in,
+            output_channels=dd_in,
+            size_requirements=SizeRequirements(multiple_of=128, square=True),
+        )
+
+
+__all__ = ["UformerArch", "Uformer"]
"encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: + token_projection = "conv" + qkv_bias = True # cannot be deduced from state_dict + else: + token_projection = "linear" + qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict + + modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict + cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict + + # size_temp = (2 * win_size - 1) ** 2 + size_temp = state_dict[ + "encoderlayer_0.blocks.0.attn.relative_position_bias_table" + ].shape[0] + win_size = (int(math.sqrt(size_temp)) + 1) // 2 + + if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: + token_mlp = "mlp" # or "ffn", doesn't matter + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] + / embed_dim + ) + elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: + token_mlp = "leff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + else: + token_mlp = "fastleff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + + model = Uformer( + img_size=img_size, + in_chans=in_chans, + dd_in=dd_in, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + purpose="Restoration", + tags=[], + supports_half=False, # Too much weirdness to support this at the moment + supports_bfloat16=True, + scale=1, + input_channels=dd_in, + output_channels=dd_in, + size_requirements=SizeRequirements(multiple_of=128, square=True), + ) + + +__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__init__.py index a51c82dd..75ce1950 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__init__.py @@ -152,3 +152,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[FeMaSR]: output_channels=in_channel, size_requirements=SizeRequirements(multiple_of=multiple_of), ) + + +__all__ = ["FeMaSRArch", "FeMaSR"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/__arch/__init__.py new file mode 100644 index 00000000..6540556d --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/__arch/__init__.py @@ -0,0 +1,146 @@ +import math + +from typing_extensions import override + +from spandrel.util import KeyCondition, get_seq_len + +from ...__helpers.model_descriptor import ( + Architecture, + ImageModelDescriptor, + SizeRequirements, + StateDict, +) +from .__arch.Uformer import Uformer + + +class UformerArch(Architecture[Uformer]): + def __init__(self) -> None: + super().__init__( + id="Uformer", + detect=KeyCondition.has_all( + "input_proj.proj.0.weight", + "output_proj.proj.0.weight", + "encoderlayer_0.blocks.0.norm1.weight", + "encoderlayer_2.blocks.0.norm1.weight", + 
"conv.blocks.0.norm1.weight", + "decoderlayer_0.blocks.0.norm1.weight", + "decoderlayer_2.blocks.0.norm1.weight", + ), + ) + + @override + def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: + img_size = 256 # cannot be deduced from state_dict + in_chans = 3 + dd_in = 3 + embed_dim = 32 + depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] + num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] + win_size = 8 + mlp_ratio = 4.0 + qkv_bias = True + drop_rate = 0.0 # cannot be deduced from state_dict + attn_drop_rate = 0.0 # cannot be deduced from state_dict + drop_path_rate = 0.1 # cannot be deduced from state_dict + token_projection = "linear" + token_mlp = "leff" + shift_flag = True # cannot be deduced from state_dict + modulator = False + cross_modulator = False + + embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] + dd_in = state_dict["input_proj.proj.0.weight"].shape[1] + in_chans = state_dict["output_proj.proj.0.weight"].shape[0] + + depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") + depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") + depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") + depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") + depths[4] = get_seq_len(state_dict, "conv.blocks") + depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") + depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") + depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") + depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") + + num_heads_suffix = "blocks.0.attn.relative_position_bias_table" + num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] + num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] + num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] + + if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: + token_projection = "conv" + qkv_bias = True # cannot be deduced from state_dict + else: + token_projection = "linear" + qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict + + modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict + cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict + + # size_temp = (2 * win_size - 1) ** 2 + size_temp = state_dict[ + "encoderlayer_0.blocks.0.attn.relative_position_bias_table" + ].shape[0] + win_size = (int(math.sqrt(size_temp)) + 1) // 2 + + if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: + token_mlp = "mlp" # or "ffn", doesn't matter + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] + / embed_dim + ) + elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: + token_mlp = "leff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + else: + token_mlp = "fastleff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + + model = Uformer( + img_size=img_size, + in_chans=in_chans, + dd_in=dd_in, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, 
+ win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + purpose="Restoration", + tags=[], + supports_half=False, # Too much weirdness to support this at the moment + supports_bfloat16=True, + scale=1, + input_channels=dd_in, + output_channels=dd_in, + size_requirements=SizeRequirements(multiple_of=128, square=True), + ) + + +__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/__init__.py index fa2c37da..de5006ea 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/__init__.py @@ -97,3 +97,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[M3SNet]: output_channels=img_channel, size_requirements=SizeRequirements(multiple_of=16), ) + + +__all__ = ["M3SNetArch", "M3SNet"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MAT/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MAT/__arch/__init__.py new file mode 100644 index 00000000..6540556d --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MAT/__arch/__init__.py @@ -0,0 +1,146 @@ +import math + +from typing_extensions import override + +from spandrel.util import KeyCondition, get_seq_len + +from ...__helpers.model_descriptor import ( + Architecture, + ImageModelDescriptor, + SizeRequirements, + StateDict, +) +from .__arch.Uformer import Uformer + + +class UformerArch(Architecture[Uformer]): + def __init__(self) -> None: + super().__init__( + id="Uformer", + detect=KeyCondition.has_all( + "input_proj.proj.0.weight", + "output_proj.proj.0.weight", + "encoderlayer_0.blocks.0.norm1.weight", + "encoderlayer_2.blocks.0.norm1.weight", + "conv.blocks.0.norm1.weight", + "decoderlayer_0.blocks.0.norm1.weight", + "decoderlayer_2.blocks.0.norm1.weight", + ), + ) + + @override + def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: + img_size = 256 # cannot be deduced from state_dict + in_chans = 3 + dd_in = 3 + embed_dim = 32 + depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] + num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] + win_size = 8 + mlp_ratio = 4.0 + qkv_bias = True + drop_rate = 0.0 # cannot be deduced from state_dict + attn_drop_rate = 0.0 # cannot be deduced from state_dict + drop_path_rate = 0.1 # cannot be deduced from state_dict + token_projection = "linear" + token_mlp = "leff" + shift_flag = True # cannot be deduced from state_dict + modulator = False + cross_modulator = False + + embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] + dd_in = state_dict["input_proj.proj.0.weight"].shape[1] + in_chans = state_dict["output_proj.proj.0.weight"].shape[0] + + depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") + depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") + depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") + depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") + depths[4] = get_seq_len(state_dict, "conv.blocks") + depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") + depths[6] = 
get_seq_len(state_dict, "decoderlayer_1.blocks") + depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") + depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") + + num_heads_suffix = "blocks.0.attn.relative_position_bias_table" + num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] + num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] + num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] + + if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: + token_projection = "conv" + qkv_bias = True # cannot be deduced from state_dict + else: + token_projection = "linear" + qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict + + modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict + cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict + + # size_temp = (2 * win_size - 1) ** 2 + size_temp = state_dict[ + "encoderlayer_0.blocks.0.attn.relative_position_bias_table" + ].shape[0] + win_size = (int(math.sqrt(size_temp)) + 1) // 2 + + if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: + token_mlp = "mlp" # or "ffn", doesn't matter + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] + / embed_dim + ) + elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: + token_mlp = "leff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + else: + token_mlp = "fastleff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + + model = Uformer( + img_size=img_size, + in_chans=in_chans, + dd_in=dd_in, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + purpose="Restoration", + tags=[], + supports_half=False, # Too much weirdness to support this at the moment + supports_bfloat16=True, + scale=1, + input_channels=dd_in, + output_channels=dd_in, + size_requirements=SizeRequirements(multiple_of=128, square=True), + ) + + +__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MAT/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MAT/__init__.py index ca416fd6..bb4cd2b9 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MAT/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MAT/__init__.py @@ -48,3 +48,6 @@ def load(self, state_dict: StateDict) -> MaskedImageModelDescriptor[MAT]: minimum=512, multiple_of=512, square=True ), ) + + +__all__ = ["MATArch", "MAT"] diff --git 
a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/__arch/__init__.py new file mode 100644 index 00000000..6540556d --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/__arch/__init__.py @@ -0,0 +1,146 @@ +import math + +from typing_extensions import override + +from spandrel.util import KeyCondition, get_seq_len + +from ...__helpers.model_descriptor import ( + Architecture, + ImageModelDescriptor, + SizeRequirements, + StateDict, +) +from .__arch.Uformer import Uformer + + +class UformerArch(Architecture[Uformer]): + def __init__(self) -> None: + super().__init__( + id="Uformer", + detect=KeyCondition.has_all( + "input_proj.proj.0.weight", + "output_proj.proj.0.weight", + "encoderlayer_0.blocks.0.norm1.weight", + "encoderlayer_2.blocks.0.norm1.weight", + "conv.blocks.0.norm1.weight", + "decoderlayer_0.blocks.0.norm1.weight", + "decoderlayer_2.blocks.0.norm1.weight", + ), + ) + + @override + def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: + img_size = 256 # cannot be deduced from state_dict + in_chans = 3 + dd_in = 3 + embed_dim = 32 + depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] + num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] + win_size = 8 + mlp_ratio = 4.0 + qkv_bias = True + drop_rate = 0.0 # cannot be deduced from state_dict + attn_drop_rate = 0.0 # cannot be deduced from state_dict + drop_path_rate = 0.1 # cannot be deduced from state_dict + token_projection = "linear" + token_mlp = "leff" + shift_flag = True # cannot be deduced from state_dict + modulator = False + cross_modulator = False + + embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] + dd_in = state_dict["input_proj.proj.0.weight"].shape[1] + in_chans = state_dict["output_proj.proj.0.weight"].shape[0] + + depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") + depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") + depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") + depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") + depths[4] = get_seq_len(state_dict, "conv.blocks") + depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") + depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") + depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") + depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") + + num_heads_suffix = "blocks.0.attn.relative_position_bias_table" + num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] + num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] + num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] + + if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: + token_projection = "conv" + qkv_bias = True # cannot be deduced from state_dict + else: + token_projection = "linear" + qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict + + modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict + cross_modulator = 
"decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict + + # size_temp = (2 * win_size - 1) ** 2 + size_temp = state_dict[ + "encoderlayer_0.blocks.0.attn.relative_position_bias_table" + ].shape[0] + win_size = (int(math.sqrt(size_temp)) + 1) // 2 + + if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: + token_mlp = "mlp" # or "ffn", doesn't matter + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] + / embed_dim + ) + elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: + token_mlp = "leff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + else: + token_mlp = "fastleff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + + model = Uformer( + img_size=img_size, + in_chans=in_chans, + dd_in=dd_in, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + purpose="Restoration", + tags=[], + supports_half=False, # Too much weirdness to support this at the moment + supports_bfloat16=True, + scale=1, + input_channels=dd_in, + output_channels=dd_in, + size_requirements=SizeRequirements(multiple_of=128, square=True), + ) + + +__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/__init__.py index e7ea5f42..6bcc8e30 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/__init__.py @@ -95,3 +95,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[MIRNet2]: output_channels=out_channels, size_requirements=SizeRequirements(multiple_of=4), ) + + +__all__ = ["MIRNet2Arch", "MIRNet2"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__arch/__init__.py new file mode 100644 index 00000000..6540556d --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__arch/__init__.py @@ -0,0 +1,146 @@ +import math + +from typing_extensions import override + +from spandrel.util import KeyCondition, get_seq_len + +from ...__helpers.model_descriptor import ( + Architecture, + ImageModelDescriptor, + SizeRequirements, + StateDict, +) +from .__arch.Uformer import Uformer + + +class UformerArch(Architecture[Uformer]): + def __init__(self) -> None: + super().__init__( + id="Uformer", + detect=KeyCondition.has_all( + "input_proj.proj.0.weight", + "output_proj.proj.0.weight", + "encoderlayer_0.blocks.0.norm1.weight", + "encoderlayer_2.blocks.0.norm1.weight", + "conv.blocks.0.norm1.weight", + "decoderlayer_0.blocks.0.norm1.weight", + "decoderlayer_2.blocks.0.norm1.weight", + ), + ) + + @override + def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: + img_size = 256 # cannot be deduced from state_dict + in_chans = 3 + dd_in = 3 + embed_dim = 32 + depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] + num_heads = [1, 2, 4, 8, 16, 16, 8, 
4, 2] + win_size = 8 + mlp_ratio = 4.0 + qkv_bias = True + drop_rate = 0.0 # cannot be deduced from state_dict + attn_drop_rate = 0.0 # cannot be deduced from state_dict + drop_path_rate = 0.1 # cannot be deduced from state_dict + token_projection = "linear" + token_mlp = "leff" + shift_flag = True # cannot be deduced from state_dict + modulator = False + cross_modulator = False + + embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] + dd_in = state_dict["input_proj.proj.0.weight"].shape[1] + in_chans = state_dict["output_proj.proj.0.weight"].shape[0] + + depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") + depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") + depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") + depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") + depths[4] = get_seq_len(state_dict, "conv.blocks") + depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") + depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") + depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") + depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") + + num_heads_suffix = "blocks.0.attn.relative_position_bias_table" + num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] + num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] + num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] + + if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: + token_projection = "conv" + qkv_bias = True # cannot be deduced from state_dict + else: + token_projection = "linear" + qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict + + modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict + cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict + + # size_temp = (2 * win_size - 1) ** 2 + size_temp = state_dict[ + "encoderlayer_0.blocks.0.attn.relative_position_bias_table" + ].shape[0] + win_size = (int(math.sqrt(size_temp)) + 1) // 2 + + if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: + token_mlp = "mlp" # or "ffn", doesn't matter + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] + / embed_dim + ) + elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: + token_mlp = "leff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + else: + token_mlp = "fastleff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + + model = Uformer( + img_size=img_size, + in_chans=in_chans, + dd_in=dd_in, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + 
purpose="Restoration", + tags=[], + supports_half=False, # Too much weirdness to support this at the moment + supports_bfloat16=True, + scale=1, + input_channels=dd_in, + output_channels=dd_in, + size_requirements=SizeRequirements(multiple_of=128, square=True), + ) + + +__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__init__.py index b9e01552..abefdf6e 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__init__.py @@ -106,3 +106,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[MPRNet]: size_requirements=SizeRequirements(multiple_of=8), call_fn=lambda model, x: model(x)[0], ) + + +__all__ = ["MPRNetArch", "MPRNet"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/__arch/__init__.py new file mode 100644 index 00000000..6540556d --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/__arch/__init__.py @@ -0,0 +1,146 @@ +import math + +from typing_extensions import override + +from spandrel.util import KeyCondition, get_seq_len + +from ...__helpers.model_descriptor import ( + Architecture, + ImageModelDescriptor, + SizeRequirements, + StateDict, +) +from .__arch.Uformer import Uformer + + +class UformerArch(Architecture[Uformer]): + def __init__(self) -> None: + super().__init__( + id="Uformer", + detect=KeyCondition.has_all( + "input_proj.proj.0.weight", + "output_proj.proj.0.weight", + "encoderlayer_0.blocks.0.norm1.weight", + "encoderlayer_2.blocks.0.norm1.weight", + "conv.blocks.0.norm1.weight", + "decoderlayer_0.blocks.0.norm1.weight", + "decoderlayer_2.blocks.0.norm1.weight", + ), + ) + + @override + def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: + img_size = 256 # cannot be deduced from state_dict + in_chans = 3 + dd_in = 3 + embed_dim = 32 + depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] + num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] + win_size = 8 + mlp_ratio = 4.0 + qkv_bias = True + drop_rate = 0.0 # cannot be deduced from state_dict + attn_drop_rate = 0.0 # cannot be deduced from state_dict + drop_path_rate = 0.1 # cannot be deduced from state_dict + token_projection = "linear" + token_mlp = "leff" + shift_flag = True # cannot be deduced from state_dict + modulator = False + cross_modulator = False + + embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] + dd_in = state_dict["input_proj.proj.0.weight"].shape[1] + in_chans = state_dict["output_proj.proj.0.weight"].shape[0] + + depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") + depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") + depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") + depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") + depths[4] = get_seq_len(state_dict, "conv.blocks") + depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") + depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") + depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") + depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") + + num_heads_suffix = "blocks.0.attn.relative_position_bias_table" + num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[1] = 
state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] + num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] + num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] + + if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: + token_projection = "conv" + qkv_bias = True # cannot be deduced from state_dict + else: + token_projection = "linear" + qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict + + modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict + cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict + + # size_temp = (2 * win_size - 1) ** 2 + size_temp = state_dict[ + "encoderlayer_0.blocks.0.attn.relative_position_bias_table" + ].shape[0] + win_size = (int(math.sqrt(size_temp)) + 1) // 2 + + if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: + token_mlp = "mlp" # or "ffn", doesn't matter + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] + / embed_dim + ) + elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: + token_mlp = "leff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + else: + token_mlp = "fastleff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + + model = Uformer( + img_size=img_size, + in_chans=in_chans, + dd_in=dd_in, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + purpose="Restoration", + tags=[], + supports_half=False, # Too much weirdness to support this at the moment + supports_bfloat16=True, + scale=1, + input_channels=dd_in, + output_channels=dd_in, + size_requirements=SizeRequirements(multiple_of=128, square=True), + ) + + +__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/__init__.py index 8ccf0974..7eb38a49 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/__init__.py @@ -120,3 +120,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[Restormer]: output_channels=out_channels, size_requirements=SizeRequirements(multiple_of=8), ) + + +__all__ = ["RestormerArch", "Restormer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__arch/__init__.py new file mode 100644 index 00000000..6540556d --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__arch/__init__.py @@ 
-0,0 +1,146 @@ +import math + +from typing_extensions import override + +from spandrel.util import KeyCondition, get_seq_len + +from ...__helpers.model_descriptor import ( + Architecture, + ImageModelDescriptor, + SizeRequirements, + StateDict, +) +from .__arch.Uformer import Uformer + + +class UformerArch(Architecture[Uformer]): + def __init__(self) -> None: + super().__init__( + id="Uformer", + detect=KeyCondition.has_all( + "input_proj.proj.0.weight", + "output_proj.proj.0.weight", + "encoderlayer_0.blocks.0.norm1.weight", + "encoderlayer_2.blocks.0.norm1.weight", + "conv.blocks.0.norm1.weight", + "decoderlayer_0.blocks.0.norm1.weight", + "decoderlayer_2.blocks.0.norm1.weight", + ), + ) + + @override + def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: + img_size = 256 # cannot be deduced from state_dict + in_chans = 3 + dd_in = 3 + embed_dim = 32 + depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] + num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] + win_size = 8 + mlp_ratio = 4.0 + qkv_bias = True + drop_rate = 0.0 # cannot be deduced from state_dict + attn_drop_rate = 0.0 # cannot be deduced from state_dict + drop_path_rate = 0.1 # cannot be deduced from state_dict + token_projection = "linear" + token_mlp = "leff" + shift_flag = True # cannot be deduced from state_dict + modulator = False + cross_modulator = False + + embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] + dd_in = state_dict["input_proj.proj.0.weight"].shape[1] + in_chans = state_dict["output_proj.proj.0.weight"].shape[0] + + depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") + depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") + depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") + depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") + depths[4] = get_seq_len(state_dict, "conv.blocks") + depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") + depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") + depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") + depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") + + num_heads_suffix = "blocks.0.attn.relative_position_bias_table" + num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] + num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] + num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] + + if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: + token_projection = "conv" + qkv_bias = True # cannot be deduced from state_dict + else: + token_projection = "linear" + qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict + + modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict + cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict + + # size_temp = (2 * win_size - 1) ** 2 + size_temp = state_dict[ + "encoderlayer_0.blocks.0.attn.relative_position_bias_table" + ].shape[0] + win_size = (int(math.sqrt(size_temp)) + 1) // 2 + + if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: + token_mlp = "mlp" # or "ffn", doesn't matter + 
mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] + / embed_dim + ) + elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: + token_mlp = "leff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + else: + token_mlp = "fastleff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + + model = Uformer( + img_size=img_size, + in_chans=in_chans, + dd_in=dd_in, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + purpose="Restoration", + tags=[], + supports_half=False, # Too much weirdness to support this at the moment + supports_bfloat16=True, + scale=1, + input_channels=dd_in, + output_channels=dd_in, + size_requirements=SizeRequirements(multiple_of=128, square=True), + ) + + +__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__init__.py index 2572da28..73e6f702 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__init__.py @@ -182,3 +182,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SRFormer]: output_channels=in_chans, size_requirements=SizeRequirements(minimum=16), ) + + +__all__ = ["SRFormerArch", "SRFormer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/__init__.py new file mode 100644 index 00000000..7e6c7cbe --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/__init__.py @@ -0,0 +1,3 @@ +""" +The package containing the implementations of all supported architectures. Not necessary for most user code. 
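+
+A minimal, illustrative sketch of direct use (the ``SRFormerArch`` name comes
+from this package; most user code should instead register these architectures
+with spandrel, e.g. via ``spandrel_extra_arches.install()``, and load models
+through ``spandrel.ModelLoader``)::
+
+    from spandrel_extra_arches.architectures.SRFormer import SRFormerArch
+
+    arch = SRFormerArch()
+    # descriptor = arch.load(state_dict)  # returns an ImageModelDescriptor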
+""" diff --git a/pyproject.toml b/pyproject.toml index fc7b8b62..00766939 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ pythonpath = ["libs/spandrel", "libs/spandrel_extra_arches"] [tool.pydoctor] project-name = "spandrel" -add-package = ["libs/spandrel/spandrel"] +add-package = ["libs/spandrel/spandrel", "libs/spandrel_extra_arches/spandrel_extra_arches"] project-url = "https://github.com/chaiNNer-org/spandrel" docformat = "restructuredtext" warnings-as-errors = false @@ -51,5 +51,9 @@ theme = "readthedocs" privacy = [ "HIDDEN:spandrel.__version__", "HIDDEN:spandrel.__helpers", + "HIDDEN:spandrel.architectures.*.__arch", "PRIVATE:spandrel.canonicalize_state_dict", + "HIDDEN:spandrel_extra_arches.__version__", + "HIDDEN:spandrel_extra_arches.__helper", + "HIDDEN:spandrel_extra_arches.architectures.*.__arch", ] From b7578c4114e8b922c091ab83ca19c3d18ddffa09 Mon Sep 17 00:00:00 2001 From: Michael Schmidt Date: Mon, 8 Jul 2024 17:44:04 +0200 Subject: [PATCH 2/4] Fix inits --- .../architectures/AdaCode/__arch/__init__.py | 146 ------------------ .../CodeFormer/__arch/__init__.py | 146 ------------------ .../architectures/DDColor/__arch/__init__.py | 146 ------------------ .../architectures/FeMaSR/__arch/__init__.py | 146 ------------------ .../architectures/M3SNet/__arch/__init__.py | 146 ------------------ .../architectures/MAT/__arch/__init__.py | 146 ------------------ .../architectures/MIRNet2/__arch/__init__.py | 146 ------------------ .../architectures/MPRNet/__arch/__init__.py | 146 ------------------ .../Restormer/__arch/__init__.py | 146 ------------------ .../architectures/SRFormer/__arch/__init__.py | 146 ------------------ 10 files changed, 1460 deletions(-) diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__arch/__init__.py index 6540556d..e69de29b 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__arch/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__arch/__init__.py @@ -1,146 +0,0 @@ -import math - -from typing_extensions import override - -from spandrel.util import KeyCondition, get_seq_len - -from ...__helpers.model_descriptor import ( - Architecture, - ImageModelDescriptor, - SizeRequirements, - StateDict, -) -from .__arch.Uformer import Uformer - - -class UformerArch(Architecture[Uformer]): - def __init__(self) -> None: - super().__init__( - id="Uformer", - detect=KeyCondition.has_all( - "input_proj.proj.0.weight", - "output_proj.proj.0.weight", - "encoderlayer_0.blocks.0.norm1.weight", - "encoderlayer_2.blocks.0.norm1.weight", - "conv.blocks.0.norm1.weight", - "decoderlayer_0.blocks.0.norm1.weight", - "decoderlayer_2.blocks.0.norm1.weight", - ), - ) - - @override - def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: - img_size = 256 # cannot be deduced from state_dict - in_chans = 3 - dd_in = 3 - embed_dim = 32 - depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] - num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] - win_size = 8 - mlp_ratio = 4.0 - qkv_bias = True - drop_rate = 0.0 # cannot be deduced from state_dict - attn_drop_rate = 0.0 # cannot be deduced from state_dict - drop_path_rate = 0.1 # cannot be deduced from state_dict - token_projection = "linear" - token_mlp = "leff" - shift_flag = True # cannot be deduced from state_dict - modulator = False - cross_modulator = False - - embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] - dd_in = 
state_dict["input_proj.proj.0.weight"].shape[1] - in_chans = state_dict["output_proj.proj.0.weight"].shape[0] - - depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") - depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") - depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") - depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") - depths[4] = get_seq_len(state_dict, "conv.blocks") - depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") - depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") - depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") - depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") - - num_heads_suffix = "blocks.0.attn.relative_position_bias_table" - num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] - num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] - num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] - num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] - num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] - num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] - num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] - num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] - num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] - - if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: - token_projection = "conv" - qkv_bias = True # cannot be deduced from state_dict - else: - token_projection = "linear" - qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict - - modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict - cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict - - # size_temp = (2 * win_size - 1) ** 2 - size_temp = state_dict[ - "encoderlayer_0.blocks.0.attn.relative_position_bias_table" - ].shape[0] - win_size = (int(math.sqrt(size_temp)) + 1) // 2 - - if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: - token_mlp = "mlp" # or "ffn", doesn't matter - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] - / embed_dim - ) - elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: - token_mlp = "leff" - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] - / embed_dim - ) - else: - token_mlp = "fastleff" - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] - / embed_dim - ) - - model = Uformer( - img_size=img_size, - in_chans=in_chans, - dd_in=dd_in, - embed_dim=embed_dim, - depths=depths, - num_heads=num_heads, - win_size=win_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - drop_path_rate=drop_path_rate, - token_projection=token_projection, - token_mlp=token_mlp, - shift_flag=shift_flag, - modulator=modulator, - cross_modulator=cross_modulator, - ) - - return ImageModelDescriptor( - model, - state_dict, - architecture=self, - purpose="Restoration", - tags=[], - supports_half=False, # Too much weirdness to support this at the moment - supports_bfloat16=True, - scale=1, - input_channels=dd_in, - output_channels=dd_in, - size_requirements=SizeRequirements(multiple_of=128, square=True), - ) - - -__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__arch/__init__.py 
b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__arch/__init__.py index 6540556d..e69de29b 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__arch/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__arch/__init__.py @@ -1,146 +0,0 @@ -import math - -from typing_extensions import override - -from spandrel.util import KeyCondition, get_seq_len - -from ...__helpers.model_descriptor import ( - Architecture, - ImageModelDescriptor, - SizeRequirements, - StateDict, -) -from .__arch.Uformer import Uformer - - -class UformerArch(Architecture[Uformer]): - def __init__(self) -> None: - super().__init__( - id="Uformer", - detect=KeyCondition.has_all( - "input_proj.proj.0.weight", - "output_proj.proj.0.weight", - "encoderlayer_0.blocks.0.norm1.weight", - "encoderlayer_2.blocks.0.norm1.weight", - "conv.blocks.0.norm1.weight", - "decoderlayer_0.blocks.0.norm1.weight", - "decoderlayer_2.blocks.0.norm1.weight", - ), - ) - - @override - def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: - img_size = 256 # cannot be deduced from state_dict - in_chans = 3 - dd_in = 3 - embed_dim = 32 - depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] - num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] - win_size = 8 - mlp_ratio = 4.0 - qkv_bias = True - drop_rate = 0.0 # cannot be deduced from state_dict - attn_drop_rate = 0.0 # cannot be deduced from state_dict - drop_path_rate = 0.1 # cannot be deduced from state_dict - token_projection = "linear" - token_mlp = "leff" - shift_flag = True # cannot be deduced from state_dict - modulator = False - cross_modulator = False - - embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] - dd_in = state_dict["input_proj.proj.0.weight"].shape[1] - in_chans = state_dict["output_proj.proj.0.weight"].shape[0] - - depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") - depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") - depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") - depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") - depths[4] = get_seq_len(state_dict, "conv.blocks") - depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") - depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") - depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") - depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") - - num_heads_suffix = "blocks.0.attn.relative_position_bias_table" - num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] - num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] - num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] - num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] - num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] - num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] - num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] - num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] - num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] - - if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: - token_projection = "conv" - qkv_bias = True # cannot be deduced from state_dict - else: - token_projection = "linear" - qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict - - modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict - cross_modulator = 
"decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict - - # size_temp = (2 * win_size - 1) ** 2 - size_temp = state_dict[ - "encoderlayer_0.blocks.0.attn.relative_position_bias_table" - ].shape[0] - win_size = (int(math.sqrt(size_temp)) + 1) // 2 - - if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: - token_mlp = "mlp" # or "ffn", doesn't matter - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] - / embed_dim - ) - elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: - token_mlp = "leff" - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] - / embed_dim - ) - else: - token_mlp = "fastleff" - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] - / embed_dim - ) - - model = Uformer( - img_size=img_size, - in_chans=in_chans, - dd_in=dd_in, - embed_dim=embed_dim, - depths=depths, - num_heads=num_heads, - win_size=win_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - drop_path_rate=drop_path_rate, - token_projection=token_projection, - token_mlp=token_mlp, - shift_flag=shift_flag, - modulator=modulator, - cross_modulator=cross_modulator, - ) - - return ImageModelDescriptor( - model, - state_dict, - architecture=self, - purpose="Restoration", - tags=[], - supports_half=False, # Too much weirdness to support this at the moment - supports_bfloat16=True, - scale=1, - input_channels=dd_in, - output_channels=dd_in, - size_requirements=SizeRequirements(multiple_of=128, square=True), - ) - - -__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__arch/__init__.py index 6540556d..e69de29b 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__arch/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__arch/__init__.py @@ -1,146 +0,0 @@ -import math - -from typing_extensions import override - -from spandrel.util import KeyCondition, get_seq_len - -from ...__helpers.model_descriptor import ( - Architecture, - ImageModelDescriptor, - SizeRequirements, - StateDict, -) -from .__arch.Uformer import Uformer - - -class UformerArch(Architecture[Uformer]): - def __init__(self) -> None: - super().__init__( - id="Uformer", - detect=KeyCondition.has_all( - "input_proj.proj.0.weight", - "output_proj.proj.0.weight", - "encoderlayer_0.blocks.0.norm1.weight", - "encoderlayer_2.blocks.0.norm1.weight", - "conv.blocks.0.norm1.weight", - "decoderlayer_0.blocks.0.norm1.weight", - "decoderlayer_2.blocks.0.norm1.weight", - ), - ) - - @override - def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: - img_size = 256 # cannot be deduced from state_dict - in_chans = 3 - dd_in = 3 - embed_dim = 32 - depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] - num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] - win_size = 8 - mlp_ratio = 4.0 - qkv_bias = True - drop_rate = 0.0 # cannot be deduced from state_dict - attn_drop_rate = 0.0 # cannot be deduced from state_dict - drop_path_rate = 0.1 # cannot be deduced from state_dict - token_projection = "linear" - token_mlp = "leff" - shift_flag = True # cannot be deduced from state_dict - modulator = False - cross_modulator = False - - embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] - dd_in = state_dict["input_proj.proj.0.weight"].shape[1] - in_chans = 
state_dict["output_proj.proj.0.weight"].shape[0] - - depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") - depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") - depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") - depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") - depths[4] = get_seq_len(state_dict, "conv.blocks") - depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") - depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") - depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") - depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") - - num_heads_suffix = "blocks.0.attn.relative_position_bias_table" - num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] - num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] - num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] - num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] - num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] - num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] - num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] - num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] - num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] - - if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: - token_projection = "conv" - qkv_bias = True # cannot be deduced from state_dict - else: - token_projection = "linear" - qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict - - modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict - cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict - - # size_temp = (2 * win_size - 1) ** 2 - size_temp = state_dict[ - "encoderlayer_0.blocks.0.attn.relative_position_bias_table" - ].shape[0] - win_size = (int(math.sqrt(size_temp)) + 1) // 2 - - if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: - token_mlp = "mlp" # or "ffn", doesn't matter - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] - / embed_dim - ) - elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: - token_mlp = "leff" - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] - / embed_dim - ) - else: - token_mlp = "fastleff" - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] - / embed_dim - ) - - model = Uformer( - img_size=img_size, - in_chans=in_chans, - dd_in=dd_in, - embed_dim=embed_dim, - depths=depths, - num_heads=num_heads, - win_size=win_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - drop_path_rate=drop_path_rate, - token_projection=token_projection, - token_mlp=token_mlp, - shift_flag=shift_flag, - modulator=modulator, - cross_modulator=cross_modulator, - ) - - return ImageModelDescriptor( - model, - state_dict, - architecture=self, - purpose="Restoration", - tags=[], - supports_half=False, # Too much weirdness to support this at the moment - supports_bfloat16=True, - scale=1, - input_channels=dd_in, - output_channels=dd_in, - size_requirements=SizeRequirements(multiple_of=128, square=True), - ) - - -__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__arch/__init__.py index 
6540556d..e69de29b 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__arch/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__arch/__init__.py @@ -1,146 +0,0 @@ -import math - -from typing_extensions import override - -from spandrel.util import KeyCondition, get_seq_len - -from ...__helpers.model_descriptor import ( - Architecture, - ImageModelDescriptor, - SizeRequirements, - StateDict, -) -from .__arch.Uformer import Uformer - - -class UformerArch(Architecture[Uformer]): - def __init__(self) -> None: - super().__init__( - id="Uformer", - detect=KeyCondition.has_all( - "input_proj.proj.0.weight", - "output_proj.proj.0.weight", - "encoderlayer_0.blocks.0.norm1.weight", - "encoderlayer_2.blocks.0.norm1.weight", - "conv.blocks.0.norm1.weight", - "decoderlayer_0.blocks.0.norm1.weight", - "decoderlayer_2.blocks.0.norm1.weight", - ), - ) - - @override - def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: - img_size = 256 # cannot be deduced from state_dict - in_chans = 3 - dd_in = 3 - embed_dim = 32 - depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] - num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] - win_size = 8 - mlp_ratio = 4.0 - qkv_bias = True - drop_rate = 0.0 # cannot be deduced from state_dict - attn_drop_rate = 0.0 # cannot be deduced from state_dict - drop_path_rate = 0.1 # cannot be deduced from state_dict - token_projection = "linear" - token_mlp = "leff" - shift_flag = True # cannot be deduced from state_dict - modulator = False - cross_modulator = False - - embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] - dd_in = state_dict["input_proj.proj.0.weight"].shape[1] - in_chans = state_dict["output_proj.proj.0.weight"].shape[0] - - depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") - depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") - depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") - depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") - depths[4] = get_seq_len(state_dict, "conv.blocks") - depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") - depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") - depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") - depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") - - num_heads_suffix = "blocks.0.attn.relative_position_bias_table" - num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] - num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] - num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] - num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] - num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] - num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] - num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] - num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] - num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] - - if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: - token_projection = "conv" - qkv_bias = True # cannot be deduced from state_dict - else: - token_projection = "linear" - qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict - - modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict - cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict - - # size_temp = (2 * win_size - 1) ** 2 - size_temp = state_dict[ - 
"encoderlayer_0.blocks.0.attn.relative_position_bias_table" - ].shape[0] - win_size = (int(math.sqrt(size_temp)) + 1) // 2 - - if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: - token_mlp = "mlp" # or "ffn", doesn't matter - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] - / embed_dim - ) - elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: - token_mlp = "leff" - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] - / embed_dim - ) - else: - token_mlp = "fastleff" - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] - / embed_dim - ) - - model = Uformer( - img_size=img_size, - in_chans=in_chans, - dd_in=dd_in, - embed_dim=embed_dim, - depths=depths, - num_heads=num_heads, - win_size=win_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - drop_path_rate=drop_path_rate, - token_projection=token_projection, - token_mlp=token_mlp, - shift_flag=shift_flag, - modulator=modulator, - cross_modulator=cross_modulator, - ) - - return ImageModelDescriptor( - model, - state_dict, - architecture=self, - purpose="Restoration", - tags=[], - supports_half=False, # Too much weirdness to support this at the moment - supports_bfloat16=True, - scale=1, - input_channels=dd_in, - output_channels=dd_in, - size_requirements=SizeRequirements(multiple_of=128, square=True), - ) - - -__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/__arch/__init__.py index 6540556d..e69de29b 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/__arch/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/__arch/__init__.py @@ -1,146 +0,0 @@ -import math - -from typing_extensions import override - -from spandrel.util import KeyCondition, get_seq_len - -from ...__helpers.model_descriptor import ( - Architecture, - ImageModelDescriptor, - SizeRequirements, - StateDict, -) -from .__arch.Uformer import Uformer - - -class UformerArch(Architecture[Uformer]): - def __init__(self) -> None: - super().__init__( - id="Uformer", - detect=KeyCondition.has_all( - "input_proj.proj.0.weight", - "output_proj.proj.0.weight", - "encoderlayer_0.blocks.0.norm1.weight", - "encoderlayer_2.blocks.0.norm1.weight", - "conv.blocks.0.norm1.weight", - "decoderlayer_0.blocks.0.norm1.weight", - "decoderlayer_2.blocks.0.norm1.weight", - ), - ) - - @override - def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: - img_size = 256 # cannot be deduced from state_dict - in_chans = 3 - dd_in = 3 - embed_dim = 32 - depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] - num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] - win_size = 8 - mlp_ratio = 4.0 - qkv_bias = True - drop_rate = 0.0 # cannot be deduced from state_dict - attn_drop_rate = 0.0 # cannot be deduced from state_dict - drop_path_rate = 0.1 # cannot be deduced from state_dict - token_projection = "linear" - token_mlp = "leff" - shift_flag = True # cannot be deduced from state_dict - modulator = False - cross_modulator = False - - embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] - dd_in = state_dict["input_proj.proj.0.weight"].shape[1] - in_chans = state_dict["output_proj.proj.0.weight"].shape[0] - - depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") - depths[1] = get_seq_len(state_dict, 
"encoderlayer_1.blocks") - depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") - depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") - depths[4] = get_seq_len(state_dict, "conv.blocks") - depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") - depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") - depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") - depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") - - num_heads_suffix = "blocks.0.attn.relative_position_bias_table" - num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] - num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] - num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] - num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] - num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] - num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] - num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] - num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] - num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] - - if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: - token_projection = "conv" - qkv_bias = True # cannot be deduced from state_dict - else: - token_projection = "linear" - qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict - - modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict - cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict - - # size_temp = (2 * win_size - 1) ** 2 - size_temp = state_dict[ - "encoderlayer_0.blocks.0.attn.relative_position_bias_table" - ].shape[0] - win_size = (int(math.sqrt(size_temp)) + 1) // 2 - - if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: - token_mlp = "mlp" # or "ffn", doesn't matter - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] - / embed_dim - ) - elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: - token_mlp = "leff" - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] - / embed_dim - ) - else: - token_mlp = "fastleff" - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] - / embed_dim - ) - - model = Uformer( - img_size=img_size, - in_chans=in_chans, - dd_in=dd_in, - embed_dim=embed_dim, - depths=depths, - num_heads=num_heads, - win_size=win_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - drop_path_rate=drop_path_rate, - token_projection=token_projection, - token_mlp=token_mlp, - shift_flag=shift_flag, - modulator=modulator, - cross_modulator=cross_modulator, - ) - - return ImageModelDescriptor( - model, - state_dict, - architecture=self, - purpose="Restoration", - tags=[], - supports_half=False, # Too much weirdness to support this at the moment - supports_bfloat16=True, - scale=1, - input_channels=dd_in, - output_channels=dd_in, - size_requirements=SizeRequirements(multiple_of=128, square=True), - ) - - -__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MAT/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MAT/__arch/__init__.py index 6540556d..e69de29b 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MAT/__arch/__init__.py +++ 
b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MAT/__arch/__init__.py @@ -1,146 +0,0 @@ -import math - -from typing_extensions import override - -from spandrel.util import KeyCondition, get_seq_len - -from ...__helpers.model_descriptor import ( - Architecture, - ImageModelDescriptor, - SizeRequirements, - StateDict, -) -from .__arch.Uformer import Uformer - - -class UformerArch(Architecture[Uformer]): - def __init__(self) -> None: - super().__init__( - id="Uformer", - detect=KeyCondition.has_all( - "input_proj.proj.0.weight", - "output_proj.proj.0.weight", - "encoderlayer_0.blocks.0.norm1.weight", - "encoderlayer_2.blocks.0.norm1.weight", - "conv.blocks.0.norm1.weight", - "decoderlayer_0.blocks.0.norm1.weight", - "decoderlayer_2.blocks.0.norm1.weight", - ), - ) - - @override - def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: - img_size = 256 # cannot be deduced from state_dict - in_chans = 3 - dd_in = 3 - embed_dim = 32 - depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] - num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] - win_size = 8 - mlp_ratio = 4.0 - qkv_bias = True - drop_rate = 0.0 # cannot be deduced from state_dict - attn_drop_rate = 0.0 # cannot be deduced from state_dict - drop_path_rate = 0.1 # cannot be deduced from state_dict - token_projection = "linear" - token_mlp = "leff" - shift_flag = True # cannot be deduced from state_dict - modulator = False - cross_modulator = False - - embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] - dd_in = state_dict["input_proj.proj.0.weight"].shape[1] - in_chans = state_dict["output_proj.proj.0.weight"].shape[0] - - depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") - depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") - depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") - depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") - depths[4] = get_seq_len(state_dict, "conv.blocks") - depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") - depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") - depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") - depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") - - num_heads_suffix = "blocks.0.attn.relative_position_bias_table" - num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] - num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] - num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] - num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] - num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] - num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] - num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] - num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] - num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] - - if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: - token_projection = "conv" - qkv_bias = True # cannot be deduced from state_dict - else: - token_projection = "linear" - qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict - - modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict - cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict - - # size_temp = (2 * win_size - 1) ** 2 - size_temp = state_dict[ - "encoderlayer_0.blocks.0.attn.relative_position_bias_table" - ].shape[0] - win_size = (int(math.sqrt(size_temp)) + 1) // 2 - - if 
"encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: - token_mlp = "mlp" # or "ffn", doesn't matter - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] - / embed_dim - ) - elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: - token_mlp = "leff" - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] - / embed_dim - ) - else: - token_mlp = "fastleff" - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] - / embed_dim - ) - - model = Uformer( - img_size=img_size, - in_chans=in_chans, - dd_in=dd_in, - embed_dim=embed_dim, - depths=depths, - num_heads=num_heads, - win_size=win_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - drop_path_rate=drop_path_rate, - token_projection=token_projection, - token_mlp=token_mlp, - shift_flag=shift_flag, - modulator=modulator, - cross_modulator=cross_modulator, - ) - - return ImageModelDescriptor( - model, - state_dict, - architecture=self, - purpose="Restoration", - tags=[], - supports_half=False, # Too much weirdness to support this at the moment - supports_bfloat16=True, - scale=1, - input_channels=dd_in, - output_channels=dd_in, - size_requirements=SizeRequirements(multiple_of=128, square=True), - ) - - -__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/__arch/__init__.py index 6540556d..e69de29b 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/__arch/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/__arch/__init__.py @@ -1,146 +0,0 @@ -import math - -from typing_extensions import override - -from spandrel.util import KeyCondition, get_seq_len - -from ...__helpers.model_descriptor import ( - Architecture, - ImageModelDescriptor, - SizeRequirements, - StateDict, -) -from .__arch.Uformer import Uformer - - -class UformerArch(Architecture[Uformer]): - def __init__(self) -> None: - super().__init__( - id="Uformer", - detect=KeyCondition.has_all( - "input_proj.proj.0.weight", - "output_proj.proj.0.weight", - "encoderlayer_0.blocks.0.norm1.weight", - "encoderlayer_2.blocks.0.norm1.weight", - "conv.blocks.0.norm1.weight", - "decoderlayer_0.blocks.0.norm1.weight", - "decoderlayer_2.blocks.0.norm1.weight", - ), - ) - - @override - def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: - img_size = 256 # cannot be deduced from state_dict - in_chans = 3 - dd_in = 3 - embed_dim = 32 - depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] - num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] - win_size = 8 - mlp_ratio = 4.0 - qkv_bias = True - drop_rate = 0.0 # cannot be deduced from state_dict - attn_drop_rate = 0.0 # cannot be deduced from state_dict - drop_path_rate = 0.1 # cannot be deduced from state_dict - token_projection = "linear" - token_mlp = "leff" - shift_flag = True # cannot be deduced from state_dict - modulator = False - cross_modulator = False - - embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] - dd_in = state_dict["input_proj.proj.0.weight"].shape[1] - in_chans = state_dict["output_proj.proj.0.weight"].shape[0] - - depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") - depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") - depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") - depths[3] = get_seq_len(state_dict, 
"encoderlayer_3.blocks") - depths[4] = get_seq_len(state_dict, "conv.blocks") - depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") - depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") - depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") - depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") - - num_heads_suffix = "blocks.0.attn.relative_position_bias_table" - num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] - num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] - num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] - num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] - num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] - num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] - num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] - num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] - num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] - - if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: - token_projection = "conv" - qkv_bias = True # cannot be deduced from state_dict - else: - token_projection = "linear" - qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict - - modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict - cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict - - # size_temp = (2 * win_size - 1) ** 2 - size_temp = state_dict[ - "encoderlayer_0.blocks.0.attn.relative_position_bias_table" - ].shape[0] - win_size = (int(math.sqrt(size_temp)) + 1) // 2 - - if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: - token_mlp = "mlp" # or "ffn", doesn't matter - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] - / embed_dim - ) - elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: - token_mlp = "leff" - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] - / embed_dim - ) - else: - token_mlp = "fastleff" - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] - / embed_dim - ) - - model = Uformer( - img_size=img_size, - in_chans=in_chans, - dd_in=dd_in, - embed_dim=embed_dim, - depths=depths, - num_heads=num_heads, - win_size=win_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - drop_path_rate=drop_path_rate, - token_projection=token_projection, - token_mlp=token_mlp, - shift_flag=shift_flag, - modulator=modulator, - cross_modulator=cross_modulator, - ) - - return ImageModelDescriptor( - model, - state_dict, - architecture=self, - purpose="Restoration", - tags=[], - supports_half=False, # Too much weirdness to support this at the moment - supports_bfloat16=True, - scale=1, - input_channels=dd_in, - output_channels=dd_in, - size_requirements=SizeRequirements(multiple_of=128, square=True), - ) - - -__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__arch/__init__.py index 6540556d..e69de29b 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__arch/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__arch/__init__.py @@ -1,146 +0,0 @@ -import math - -from typing_extensions import override - 
-from spandrel.util import KeyCondition, get_seq_len - -from ...__helpers.model_descriptor import ( - Architecture, - ImageModelDescriptor, - SizeRequirements, - StateDict, -) -from .__arch.Uformer import Uformer - - -class UformerArch(Architecture[Uformer]): - def __init__(self) -> None: - super().__init__( - id="Uformer", - detect=KeyCondition.has_all( - "input_proj.proj.0.weight", - "output_proj.proj.0.weight", - "encoderlayer_0.blocks.0.norm1.weight", - "encoderlayer_2.blocks.0.norm1.weight", - "conv.blocks.0.norm1.weight", - "decoderlayer_0.blocks.0.norm1.weight", - "decoderlayer_2.blocks.0.norm1.weight", - ), - ) - - @override - def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: - img_size = 256 # cannot be deduced from state_dict - in_chans = 3 - dd_in = 3 - embed_dim = 32 - depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] - num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] - win_size = 8 - mlp_ratio = 4.0 - qkv_bias = True - drop_rate = 0.0 # cannot be deduced from state_dict - attn_drop_rate = 0.0 # cannot be deduced from state_dict - drop_path_rate = 0.1 # cannot be deduced from state_dict - token_projection = "linear" - token_mlp = "leff" - shift_flag = True # cannot be deduced from state_dict - modulator = False - cross_modulator = False - - embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] - dd_in = state_dict["input_proj.proj.0.weight"].shape[1] - in_chans = state_dict["output_proj.proj.0.weight"].shape[0] - - depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") - depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") - depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") - depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") - depths[4] = get_seq_len(state_dict, "conv.blocks") - depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") - depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") - depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") - depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") - - num_heads_suffix = "blocks.0.attn.relative_position_bias_table" - num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] - num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] - num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] - num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] - num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] - num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] - num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] - num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] - num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] - - if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: - token_projection = "conv" - qkv_bias = True # cannot be deduced from state_dict - else: - token_projection = "linear" - qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict - - modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict - cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict - - # size_temp = (2 * win_size - 1) ** 2 - size_temp = state_dict[ - "encoderlayer_0.blocks.0.attn.relative_position_bias_table" - ].shape[0] - win_size = (int(math.sqrt(size_temp)) + 1) // 2 - - if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: - token_mlp = "mlp" # or "ffn", doesn't matter - mlp_ratio = ( - 
state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] - / embed_dim - ) - elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: - token_mlp = "leff" - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] - / embed_dim - ) - else: - token_mlp = "fastleff" - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] - / embed_dim - ) - - model = Uformer( - img_size=img_size, - in_chans=in_chans, - dd_in=dd_in, - embed_dim=embed_dim, - depths=depths, - num_heads=num_heads, - win_size=win_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - drop_path_rate=drop_path_rate, - token_projection=token_projection, - token_mlp=token_mlp, - shift_flag=shift_flag, - modulator=modulator, - cross_modulator=cross_modulator, - ) - - return ImageModelDescriptor( - model, - state_dict, - architecture=self, - purpose="Restoration", - tags=[], - supports_half=False, # Too much weirdness to support this at the moment - supports_bfloat16=True, - scale=1, - input_channels=dd_in, - output_channels=dd_in, - size_requirements=SizeRequirements(multiple_of=128, square=True), - ) - - -__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/__arch/__init__.py index 6540556d..e69de29b 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/__arch/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/__arch/__init__.py @@ -1,146 +0,0 @@ -import math - -from typing_extensions import override - -from spandrel.util import KeyCondition, get_seq_len - -from ...__helpers.model_descriptor import ( - Architecture, - ImageModelDescriptor, - SizeRequirements, - StateDict, -) -from .__arch.Uformer import Uformer - - -class UformerArch(Architecture[Uformer]): - def __init__(self) -> None: - super().__init__( - id="Uformer", - detect=KeyCondition.has_all( - "input_proj.proj.0.weight", - "output_proj.proj.0.weight", - "encoderlayer_0.blocks.0.norm1.weight", - "encoderlayer_2.blocks.0.norm1.weight", - "conv.blocks.0.norm1.weight", - "decoderlayer_0.blocks.0.norm1.weight", - "decoderlayer_2.blocks.0.norm1.weight", - ), - ) - - @override - def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: - img_size = 256 # cannot be deduced from state_dict - in_chans = 3 - dd_in = 3 - embed_dim = 32 - depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] - num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] - win_size = 8 - mlp_ratio = 4.0 - qkv_bias = True - drop_rate = 0.0 # cannot be deduced from state_dict - attn_drop_rate = 0.0 # cannot be deduced from state_dict - drop_path_rate = 0.1 # cannot be deduced from state_dict - token_projection = "linear" - token_mlp = "leff" - shift_flag = True # cannot be deduced from state_dict - modulator = False - cross_modulator = False - - embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] - dd_in = state_dict["input_proj.proj.0.weight"].shape[1] - in_chans = state_dict["output_proj.proj.0.weight"].shape[0] - - depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") - depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") - depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") - depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") - depths[4] = get_seq_len(state_dict, "conv.blocks") - depths[5] = get_seq_len(state_dict, 
"decoderlayer_0.blocks") - depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") - depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") - depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") - - num_heads_suffix = "blocks.0.attn.relative_position_bias_table" - num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] - num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] - num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] - num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] - num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] - num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] - num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] - num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] - num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] - - if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: - token_projection = "conv" - qkv_bias = True # cannot be deduced from state_dict - else: - token_projection = "linear" - qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict - - modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict - cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict - - # size_temp = (2 * win_size - 1) ** 2 - size_temp = state_dict[ - "encoderlayer_0.blocks.0.attn.relative_position_bias_table" - ].shape[0] - win_size = (int(math.sqrt(size_temp)) + 1) // 2 - - if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: - token_mlp = "mlp" # or "ffn", doesn't matter - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] - / embed_dim - ) - elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: - token_mlp = "leff" - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] - / embed_dim - ) - else: - token_mlp = "fastleff" - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] - / embed_dim - ) - - model = Uformer( - img_size=img_size, - in_chans=in_chans, - dd_in=dd_in, - embed_dim=embed_dim, - depths=depths, - num_heads=num_heads, - win_size=win_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - drop_path_rate=drop_path_rate, - token_projection=token_projection, - token_mlp=token_mlp, - shift_flag=shift_flag, - modulator=modulator, - cross_modulator=cross_modulator, - ) - - return ImageModelDescriptor( - model, - state_dict, - architecture=self, - purpose="Restoration", - tags=[], - supports_half=False, # Too much weirdness to support this at the moment - supports_bfloat16=True, - scale=1, - input_channels=dd_in, - output_channels=dd_in, - size_requirements=SizeRequirements(multiple_of=128, square=True), - ) - - -__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__arch/__init__.py index 6540556d..e69de29b 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__arch/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__arch/__init__.py @@ -1,146 +0,0 @@ -import math - -from typing_extensions import override - -from spandrel.util import KeyCondition, get_seq_len - -from ...__helpers.model_descriptor import ( - 
Architecture, - ImageModelDescriptor, - SizeRequirements, - StateDict, -) -from .__arch.Uformer import Uformer - - -class UformerArch(Architecture[Uformer]): - def __init__(self) -> None: - super().__init__( - id="Uformer", - detect=KeyCondition.has_all( - "input_proj.proj.0.weight", - "output_proj.proj.0.weight", - "encoderlayer_0.blocks.0.norm1.weight", - "encoderlayer_2.blocks.0.norm1.weight", - "conv.blocks.0.norm1.weight", - "decoderlayer_0.blocks.0.norm1.weight", - "decoderlayer_2.blocks.0.norm1.weight", - ), - ) - - @override - def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: - img_size = 256 # cannot be deduced from state_dict - in_chans = 3 - dd_in = 3 - embed_dim = 32 - depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] - num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] - win_size = 8 - mlp_ratio = 4.0 - qkv_bias = True - drop_rate = 0.0 # cannot be deduced from state_dict - attn_drop_rate = 0.0 # cannot be deduced from state_dict - drop_path_rate = 0.1 # cannot be deduced from state_dict - token_projection = "linear" - token_mlp = "leff" - shift_flag = True # cannot be deduced from state_dict - modulator = False - cross_modulator = False - - embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] - dd_in = state_dict["input_proj.proj.0.weight"].shape[1] - in_chans = state_dict["output_proj.proj.0.weight"].shape[0] - - depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") - depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") - depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") - depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") - depths[4] = get_seq_len(state_dict, "conv.blocks") - depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") - depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") - depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") - depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") - - num_heads_suffix = "blocks.0.attn.relative_position_bias_table" - num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] - num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] - num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] - num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] - num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] - num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] - num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] - num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] - num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] - - if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: - token_projection = "conv" - qkv_bias = True # cannot be deduced from state_dict - else: - token_projection = "linear" - qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict - - modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict - cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict - - # size_temp = (2 * win_size - 1) ** 2 - size_temp = state_dict[ - "encoderlayer_0.blocks.0.attn.relative_position_bias_table" - ].shape[0] - win_size = (int(math.sqrt(size_temp)) + 1) // 2 - - if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: - token_mlp = "mlp" # or "ffn", doesn't matter - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] - / embed_dim - ) - elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: - 
token_mlp = "leff" - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] - / embed_dim - ) - else: - token_mlp = "fastleff" - mlp_ratio = ( - state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] - / embed_dim - ) - - model = Uformer( - img_size=img_size, - in_chans=in_chans, - dd_in=dd_in, - embed_dim=embed_dim, - depths=depths, - num_heads=num_heads, - win_size=win_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - drop_path_rate=drop_path_rate, - token_projection=token_projection, - token_mlp=token_mlp, - shift_flag=shift_flag, - modulator=modulator, - cross_modulator=cross_modulator, - ) - - return ImageModelDescriptor( - model, - state_dict, - architecture=self, - purpose="Restoration", - tags=[], - supports_half=False, # Too much weirdness to support this at the moment - supports_bfloat16=True, - scale=1, - input_channels=dd_in, - output_channels=dd_in, - size_requirements=SizeRequirements(multiple_of=128, square=True), - ) - - -__all__ = ["UformerArch", "Uformer"] From 785434490ea96e33ac656685fea9f09ae27ca5bf Mon Sep 17 00:00:00 2001 From: Michael Schmidt Date: Tue, 9 Jul 2024 14:03:07 +0200 Subject: [PATCH 3/4] Fixed arch doc strings --- .../architectures/ATD/__arch/atd_arch.py | 7 +- .../architectures/CRAFT/__arch/CRAFT.py | 4 +- .../spandrel/architectures/GRL/__arch/grl.py | 29 +++++--- .../NAFNet/__arch/NAFNet_arch.py | 20 +++--- .../__arch/restoreformer_arch.py | 14 ++-- .../spandrel/architectures/__init__.py | 2 + .../AdaCode/__arch/adacode_contrast_arch.py | 68 ++----------------- .../DDColor/__arch/transformer.py | 7 +- .../architectures/FeMaSR/__arch/fema_utils.py | 14 ++-- .../architectures/FeMaSR/__arch/femasr.py | 16 ++--- .../architectures/__init__.py | 2 + 11 files changed, 71 insertions(+), 112 deletions(-) diff --git a/libs/spandrel/spandrel/architectures/ATD/__arch/atd_arch.py b/libs/spandrel/spandrel/architectures/ATD/__arch/atd_arch.py index 8628a9ec..6e8cd7a1 100644 --- a/libs/spandrel/spandrel/architectures/ATD/__arch/atd_arch.py +++ b/libs/spandrel/spandrel/architectures/ATD/__arch/atd_arch.py @@ -881,9 +881,10 @@ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): @store_hyperparameters() class ATD(nn.Module): - r"""ATD - A PyTorch impl of : `Transcending the Limit of Local Window: Advanced Super-Resolution Transformer - with Adaptive Token Dictionary`. + r""" + ATD + + A PyTorch impl of : `Transcending the Limit of Local Window: Advanced Super-Resolution Transformer with Adaptive Token Dictionary`. Args: img_size (int | tuple(int)): Input image size. Default 64 diff --git a/libs/spandrel/spandrel/architectures/CRAFT/__arch/CRAFT.py b/libs/spandrel/spandrel/architectures/CRAFT/__arch/CRAFT.py index 94a23da4..08b3fa9e 100644 --- a/libs/spandrel/spandrel/architectures/CRAFT/__arch/CRAFT.py +++ b/libs/spandrel/spandrel/architectures/CRAFT/__arch/CRAFT.py @@ -103,8 +103,10 @@ def forward(self, biases): class Attention_regular(nn.Module): - """Regular Rectangle-Window (regular-Rwin) self-attention with dynamic relative position bias. + """ + Regular Rectangle-Window (regular-Rwin) self-attention with dynamic relative position bias. It supports both of shifted and non-shifted window. + Args: dim (int): Number of input channels. resolution (int): Input resolution. 
diff --git a/libs/spandrel/spandrel/architectures/GRL/__arch/grl.py b/libs/spandrel/spandrel/architectures/GRL/__arch/grl.py index c427f9ae..7be693c0 100644 --- a/libs/spandrel/spandrel/architectures/GRL/__arch/grl.py +++ b/libs/spandrel/spandrel/architectures/GRL/__arch/grl.py @@ -34,7 +34,9 @@ class TransformerStage(nn.Module): - """Transformer stage. + """ + Transformer stage. + Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resolution. @@ -58,11 +60,13 @@ class TransformerStage(nn.Module): pretrained_stripe_size (list[int]): pretrained stripe size. This is actually not used. Default: [0, 0]. conv_type: The convolutional block before residual connection. init_method: initialization method of the weight parameters used to train large scale models. - Choices: n, normal -- Swin V1 init method. - l, layernorm -- Swin V2 init method. Zero the weight and bias in the post layer normalization layer. - r, res_rescale -- EDSR rescale method. Rescale the residual blocks with a scaling factor 0.1 - w, weight_rescale -- MSRResNet rescale method. Rescale the weight parameter in residual blocks with a scaling factor 0.1 - t, trunc_normal_ -- nn.Linear, trunc_normal; nn.Conv2d, weight_rescale + + Choices: + * n, normal -- Swin V1 init method. + * l, layernorm -- Swin V2 init method. Zero the weight and bias in the post layer normalization layer. + * r, res_rescale -- EDSR rescale method. Rescale the residual blocks with a scaling factor 0.1 + * w, weight_rescale -- MSRResNet rescale method. Rescale the weight parameter in residual blocks with a scaling factor 0.1 + * t, `trunc_normal_` -- nn.Linear, trunc_normal, nn.Conv2d, weight_rescale fairscale_checkpoint (bool): Whether to use fairscale checkpoint. offload_to_cpu (bool): used by fairscale_checkpoint args: @@ -185,6 +189,7 @@ def flops(self): @store_hyperparameters() class GRL(nn.Module): r"""Image restoration transformer with global, non-local, and local connections + Args: img_size (int | list[int]): Input image size. Default 64 in_channels (int): Number of input image channels. Default: 3 @@ -216,11 +221,13 @@ class GRL(nn.Module): norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. conv_type (str): The convolutional block before residual connection. Default: 1conv. Choices: 1conv, 3conv, 1conv1x1, linear init_method: initialization method of the weight parameters used to train large scale models. - Choices: n, normal -- Swin V1 init method. - l, layernorm -- Swin V2 init method. Zero the weight and bias in the post layer normalization layer. - r, res_rescale -- EDSR rescale method. Rescale the residual blocks with a scaling factor 0.1 - w, weight_rescale -- MSRResNet rescale method. Rescale the weight parameter in residual blocks with a scaling factor 0.1 - t, trunc_normal_ -- nn.Linear, trunc_normal; nn.Conv2d, weight_rescale + + Choices: + * n, normal -- Swin V1 init method. + * l, layernorm -- Swin V2 init method. Zero the weight and bias in the post layer normalization layer. + * r, res_rescale -- EDSR rescale method. Rescale the residual blocks with a scaling factor 0.1 + * w, weight_rescale -- MSRResNet rescale method. Rescale the weight parameter in residual blocks with a scaling factor 0.1 + * t, `trunc_normal_` -- nn.Linear, trunc_normal, nn.Conv2d, weight_rescale fairscale_checkpoint (bool): Whether to use fairscale checkpoint. offload_to_cpu (bool): used by fairscale_checkpoint euclidean_dist (bool): use Euclidean distance or inner product as the similarity metric. An ablation study. 
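For reference, a minimal sketch (with illustrative names, not taken from the patch) of the Google-style docstring layout that these fixes converge on and that the __docformat__ = "google" declaration later in this patch announces: a one-line summary, a blank line, then indented Args:/Returns: sections.

    def table_entries(win_size: int = 8) -> int:
        """Count the entries in a relative position bias table.

        Args:
            win_size: Side length of the square attention window.

        Returns:
            The table size used by window attention, (2 * win_size - 1) ** 2;
            for the default win_size of 8 this is 225.
        """
        return (2 * win_size - 1) ** 2

Inverting the returned expression is exactly how the Uformer loader deleted earlier in this series recovers the window size from a checkpoint: (int(math.sqrt(225)) + 1) // 2 == 8.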
diff --git a/libs/spandrel/spandrel/architectures/NAFNet/__arch/NAFNet_arch.py b/libs/spandrel/spandrel/architectures/NAFNet/__arch/NAFNet_arch.py index c688d03f..eda1ecdd 100644 --- a/libs/spandrel/spandrel/architectures/NAFNet/__arch/NAFNet_arch.py +++ b/libs/spandrel/spandrel/architectures/NAFNet/__arch/NAFNet_arch.py @@ -2,16 +2,16 @@ # Copyright (c) 2022 megvii-model. All Rights Reserved. # ------------------------------------------------------------------------ -""" -Simple Baselines for Image Restoration - -@article{chen2022simple, - title={Simple Baselines for Image Restoration}, - author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian}, - journal={arXiv preprint arXiv:2204.04676}, - year={2022} -} -""" +# """ +# Simple Baselines for Image Restoration + +# @article{chen2022simple, +# title={Simple Baselines for Image Restoration}, +# author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian}, +# journal={arXiv preprint arXiv:2204.04676}, +# year={2022} +# } +# """ from __future__ import annotations diff --git a/libs/spandrel/spandrel/architectures/RestoreFormer/__arch/restoreformer_arch.py b/libs/spandrel/spandrel/architectures/RestoreFormer/__arch/restoreformer_arch.py index b1034297..8c6fcd42 100644 --- a/libs/spandrel/spandrel/architectures/RestoreFormer/__arch/restoreformer_arch.py +++ b/libs/spandrel/spandrel/architectures/RestoreFormer/__arch/restoreformer_arch.py @@ -10,14 +10,14 @@ class VectorQuantizer(nn.Module): """ - see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py - ____________________________________________ Discretization bottleneck part of the VQ-VAE. - Inputs: - - n_e : number of embeddings - - e_dim : dimension of embedding - - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 - _____________________________________________ + + see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py + + Args: + n_e : number of embeddings + e_dim : dimension of embedding + beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 """ def __init__(self, n_e, e_dim, beta): diff --git a/libs/spandrel/spandrel/architectures/__init__.py b/libs/spandrel/spandrel/architectures/__init__.py index 7e6c7cbe..c6e17bdf 100644 --- a/libs/spandrel/spandrel/architectures/__init__.py +++ b/libs/spandrel/spandrel/architectures/__init__.py @@ -1,3 +1,5 @@ """ The package containing the implementations of all supported architectures. Not necessary for most user code. """ + +__docformat__ = "google" diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__arch/adacode_contrast_arch.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__arch/adacode_contrast_arch.py index 00e4f6cc..57281557 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__arch/adacode_contrast_arch.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__arch/adacode_contrast_arch.py @@ -6,68 +6,12 @@ from spandrel.util import store_hyperparameters -from ...FeMaSR.__arch.femasr import DecoderBlock, MultiScaleEncoder, SwinLayers - - -class VectorQuantizer(nn.Module): - """ - see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py - ____________________________________________ - Discretization bottleneck part of the VQ-VAE. 
- Inputs: - - n_e : number of embeddings - - e_dim : dimension of embedding - - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 - _____________________________________________ - """ - - def __init__(self, n_e, e_dim): - super().__init__() - self.n_e = int(n_e) - self.e_dim = int(e_dim) - self.embedding = nn.Embedding(self.n_e, self.e_dim) - self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) - - def dist(self, x, y): - return ( - torch.sum(x**2, dim=1, keepdim=True) - + torch.sum(y**2, dim=1) - - 2 * torch.matmul(x, y.t()) - ) - - def forward(self, z: torch.Tensor): - """ - Args: - z: input features to be quantized, z (continuous) -> z_q (discrete) - z.shape = (batch, channel, height, width) - gt_indices: feature map of given indices, used for visualization. - """ - # reshape z -> (batch, height, width, channel) and flatten - z = z.permute(0, 2, 3, 1).contiguous() - z_flattened = z.view(-1, self.e_dim) - - codebook = self.embedding.weight - - d = self.dist(z_flattened, codebook) - - # find closest encodings - min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) - min_encodings = torch.zeros( - min_encoding_indices.shape[0], codebook.shape[0] - ).to(z) - min_encodings.scatter_(1, min_encoding_indices, 1) - - # get quantized latent vectors - z_q = torch.matmul(min_encodings, codebook) - z_q = z_q.view(z.shape) - - # preserve gradients - z_q = z + (z_q - z).detach() - - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q +from ...FeMaSR.__arch.femasr import ( + DecoderBlock, + MultiScaleEncoder, + SwinLayers, + VectorQuantizer, +) class WeightPredictor(nn.Module): diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__arch/transformer.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__arch/transformer.py index 15e62698..4f0e4bd6 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__arch/transformer.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__arch/transformer.py @@ -2,10 +2,11 @@ # Modified from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py """ Transformer class. + Copy-paste from torch.nn.Transformer with modifications: - * positional encodings are passed in MHattention - * extra LN at the end of encoder is removed - * decoder returns a stack of activations from all decoding layers +- positional encodings are passed in MHattention +- extra LN at the end of encoder is removed +- decoder returns a stack of activations from all decoding layers """ from __future__ import annotations diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__arch/fema_utils.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__arch/fema_utils.py index 02fdeefa..eb5f6d2f 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__arch/fema_utils.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__arch/fema_utils.py @@ -5,10 +5,10 @@ class NormLayer(nn.Module): """Normalization Layers. - ------------ - # Arguments - - channels: input channels, for batch norm and instance norm. - - input_size: input shape without batch size, for layer norm. + + Args: + channels: input channels, for batch norm and instance norm. + input_size: input shape without batch size, for layer norm. 
""" def __init__(self, channels, norm_type="bn"): @@ -35,9 +35,9 @@ def forward(self, x): class ActLayer(nn.Module): """activation layer. - ------------ - # Arguments - - relu type: type of relu layer, candidates are + + Args: + relu type: type of relu layer, candidates are - ReLU - LeakyReLU: default relu slope 0.2 - PRelu diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__arch/femasr.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__arch/femasr.py index 0a62650f..62c0b65f 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__arch/femasr.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__arch/femasr.py @@ -12,14 +12,14 @@ class VectorQuantizer(nn.Module): """ - see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py - ____________________________________________ Discretization bottleneck part of the VQ-VAE. - Inputs: - - n_e : number of embeddings - - e_dim : dimension of embedding - - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 - _____________________________________________ + + see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py + + Args: + n_e : number of embeddings + e_dim : dimension of embedding + beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 """ def __init__(self, n_e, e_dim): @@ -36,7 +36,7 @@ def dist(self, x, y): - 2 * torch.matmul(x, y.t()) ) - def forward(self, z): + def forward(self, z: torch.Tensor): """ Args: z: input features to be quantized, z (continuous) -> z_q (discrete) diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/__init__.py index 7e6c7cbe..c6e17bdf 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/__init__.py @@ -1,3 +1,5 @@ """ The package containing the implementations of all supported architectures. Not necessary for most user code. 
""" + +__docformat__ = "google" From e1fea747d7fac44cc95068ccbfc15b8ee269bcd5 Mon Sep 17 00:00:00 2001 From: Michael Schmidt Date: Thu, 11 Jul 2024 15:42:24 +0200 Subject: [PATCH 4/4] Missing four --- libs/spandrel/spandrel/architectures/Compact/__init__.py | 3 +++ libs/spandrel/spandrel/architectures/ESRGAN/__init__.py | 3 +++ libs/spandrel/spandrel/architectures/GFPGAN/__init__.py | 3 +++ libs/spandrel/spandrel/architectures/SAFMNBCIE/__init__.py | 3 +++ 4 files changed, 12 insertions(+) diff --git a/libs/spandrel/spandrel/architectures/Compact/__init__.py b/libs/spandrel/spandrel/architectures/Compact/__init__.py index be599c37..03bd204d 100644 --- a/libs/spandrel/spandrel/architectures/Compact/__init__.py +++ b/libs/spandrel/spandrel/architectures/Compact/__init__.py @@ -52,3 +52,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[Compact]: input_channels=in_nc, output_channels=out_nc, ) + + +__all__ = ["CompactArch", "Compact"] diff --git a/libs/spandrel/spandrel/architectures/ESRGAN/__init__.py b/libs/spandrel/spandrel/architectures/ESRGAN/__init__.py index 352bc0f4..73ac6cfb 100644 --- a/libs/spandrel/spandrel/architectures/ESRGAN/__init__.py +++ b/libs/spandrel/spandrel/architectures/ESRGAN/__init__.py @@ -233,3 +233,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[ESRGAN]: multiple_of=4 if shuffle_factor else 1, ), ) + + +__all__ = ["ESRGANArch", "ESRGAN"] diff --git a/libs/spandrel/spandrel/architectures/GFPGAN/__init__.py b/libs/spandrel/spandrel/architectures/GFPGAN/__init__.py index 29dd1af8..ad855da1 100644 --- a/libs/spandrel/spandrel/architectures/GFPGAN/__init__.py +++ b/libs/spandrel/spandrel/architectures/GFPGAN/__init__.py @@ -61,3 +61,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[GFPGAN]: size_requirements=SizeRequirements(minimum=512), call_fn=lambda model, image: model(image)[0], ) + + +__all__ = ["GFPGANArch", "GFPGAN"] diff --git a/libs/spandrel/spandrel/architectures/SAFMNBCIE/__init__.py b/libs/spandrel/spandrel/architectures/SAFMNBCIE/__init__.py index c811974e..000b16cb 100644 --- a/libs/spandrel/spandrel/architectures/SAFMNBCIE/__init__.py +++ b/libs/spandrel/spandrel/architectures/SAFMNBCIE/__init__.py @@ -81,3 +81,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SAFMNBCIE]: output_channels=3, # hard-coded in the arch size_requirements=SizeRequirements(multiple_of=16), ) + + +__all__ = ["SAFMNBCIEArch", "SAFMNBCIE"]