diff --git a/src/spandrel/architectures/Swin2SR/__init__.py b/src/spandrel/architectures/Swin2SR/__init__.py index a6c56774..f3554d59 100644 --- a/src/spandrel/architectures/Swin2SR/__init__.py +++ b/src/spandrel/architectures/Swin2SR/__init__.py @@ -1,15 +1,13 @@ import math -import re - -from torch import nn from ...__helpers.model_descriptor import SizeRequirements, SRModelDescriptor, StateDict +from ..__arch_helpers.state import get_seq_len from .arch.Swin2SR import Swin2SR def load(state_dict: StateDict) -> SRModelDescriptor[Swin2SR]: # Defaults - img_size = 128 + img_size = 64 patch_size = 1 in_chans = 3 embed_dim = 96 @@ -18,141 +16,107 @@ def load(state_dict: StateDict) -> SRModelDescriptor[Swin2SR]: window_size = 7 mlp_ratio = 4.0 qkv_bias = True - drop_rate = 0.0 - attn_drop_rate = 0.0 - drop_path_rate = 0.1 - norm_layer = nn.LayerNorm + drop_rate = 0.0 # cannot be deduced from state_dict + attn_drop_rate = 0.0 # cannot be deduced from state_dict + drop_path_rate = 0.1 # cannot be deduced from state_dict ape = False patch_norm = True - use_checkpoint = False + use_checkpoint = False # cannot be deduced from state_dict upscale = 2 img_range = 1.0 upsampler = "" resi_connection = "1conv" - num_in_ch = in_chans - num_out_ch = in_chans - num_feat = 64 - - state = state_dict - - state_keys = state.keys() - - if "conv_before_upsample.0.weight" in state_keys: - if "conv_aux.weight" in state_keys: - upsampler = "pixelshuffle_aux" - elif "conv_up1.weight" in state_keys: - upsampler = "nearest+conv" - else: - upsampler = "pixelshuffle" - elif "upsample.0.weight" in state_keys: - upsampler = "pixelshuffledirect" - else: - upsampler = "" - num_feat_pre_layer = state_dict.get("conv_before_upsample.weight", None) - num_feat_layer = state_dict.get("conv_before_upsample.0.weight", None) - num_feat = ( - num_feat_layer.shape[1] - if num_feat_layer is not None and num_feat_pre_layer is not None - else 64 - ) + in_chans = state_dict["conv_first.weight"].shape[1] + embed_dim 
= state_dict["conv_first.weight"].shape[0] + patch_size = state_dict["patch_embed.proj.weight"].shape[2] - num_in_ch = state["conv_first.weight"].shape[1] - in_chans = num_in_ch - if "conv_last.weight" in state_keys: - num_out_ch = state["conv_last.weight"].shape[0] - else: - num_out_ch = num_in_ch - - upscale = 1 - if upsampler == "nearest+conv": - upsample_keys = [x for x in state_keys if "conv_up" in x and "bias" not in x] - - for upsample_key in upsample_keys: - upscale *= 2 - elif upsampler == "pixelshuffle" or upsampler == "pixelshuffle_aux": - upsample_keys = [ - x - for x in state_keys - if "upsample" in x and "conv" not in x and "bias" not in x - ] - for upsample_key in upsample_keys: - shape = state[upsample_key].shape[0] - upscale *= math.sqrt(shape // num_feat) - upscale = int(upscale) - elif upsampler == "pixelshuffledirect": - upscale = int(math.sqrt(state["upsample.0.bias"].shape[0] // num_out_ch)) - - max_layer_num = 0 - max_block_num = 0 - for key in state_keys: - result = re.match(r"layers.(\d*).residual_group.blocks.(\d*).norm1.weight", key) - if result: - layer_num, block_num = result.groups() - max_layer_num = max(max_layer_num, int(layer_num)) - max_block_num = max(max_block_num, int(block_num)) - - depths = [max_block_num + 1 for _ in range(max_layer_num + 1)] - - if ( - "layers.0.residual_group.blocks.0.attn.relative_position_bias_table" - in state_keys - ): - num_heads_num = state[ - "layers.0.residual_group.blocks.0.attn.relative_position_bias_table" - ].shape[-1] - num_heads = [num_heads_num for _ in range(max_layer_num + 1)] - else: - num_heads = depths + ape = "absolute_pos_embed" in state_dict + patch_norm = "patch_embed.norm.weight" in state_dict + qkv_bias = "layers.0.residual_group.blocks.0.attn.q_bias" in state_dict - embed_dim = state["conv_first.weight"].shape[0] + # depths & num_heads + num_layers = get_seq_len(state_dict, "layers") + depths = [6] * num_layers + num_heads = [6] * num_layers + for i in range(num_layers): + 
depths[i] = get_seq_len(state_dict, f"layers.{i}.residual_group.blocks") + num_heads[i] = state_dict[ + f"layers.{i}.residual_group.blocks.0.attn.logit_scale" + ].shape[0] mlp_ratio = float( - state["layers.0.residual_group.blocks.0.mlp.fc1.bias"].shape[0] / embed_dim + state_dict["layers.0.residual_group.blocks.0.mlp.fc1.weight"].shape[0] + / embed_dim ) - # TODO: could actually count the layers, but this should do - if "layers.0.conv.4.weight" in state_keys: + if "conv_after_body.0.weight" in state_dict: resi_connection = "3conv" - else: + elif "conv_after_body.weight" in state_dict: resi_connection = "1conv" + else: + raise ValueError("Unknown residual connection type") + + # upsampler + if "conv_bicubic.weight" in state_dict: + upsampler = "pixelshuffle_aux" + elif "conv_hr.weight" in state_dict: + upsampler = "nearest+conv" + elif "conv_after_body_hf.weight" in state_dict: + upsampler = "pixelshuffle_hf" + elif "conv_before_upsample.0.weight" in state_dict: + upsampler = "pixelshuffle" + elif "upsample.0.weight" in state_dict: + upsampler = "pixelshuffledirect" + else: + upsampler = "" + + if upsampler == "": + upscale = 1 + elif upsampler == "nearest+conv": + upscale = 4 # only supports 4x + elif upsampler == "pixelshuffledirect": + upscale = int(math.sqrt(state_dict["upsample.0.weight"].shape[0] // in_chans)) + else: + num_feat = 64 # hard-coded constant + upscale = 1 + for i in range(0, get_seq_len(state_dict, "upsample"), 2): + shape = state_dict[f"upsample.{i}.weight"].shape[0] + upscale *= int(math.sqrt(shape // num_feat)) window_size = int( math.sqrt( - state[ + state_dict[ "layers.0.residual_group.blocks.0.attn.relative_position_index" ].shape[0] ) ) - if "layers.0.residual_group.blocks.1.attn_mask" in state_keys: - img_size = int( - math.sqrt(state["layers.0.residual_group.blocks.1.attn_mask"].shape[0]) - * window_size - ) + # Now for img_size... 
What we know: + # patches_resolution = img_size // patch_size + # if window_size > patches_resolution: + # attn_mask[0] = patches_resolution**2 // window_size**2 + if "layers.0.residual_group.blocks.1.attn_mask" in state_dict: + attn_mask_0 = state_dict["layers.0.residual_group.blocks.1.attn_mask"].shape[0] + patches_resolution = int(math.sqrt(attn_mask_0 * window_size * window_size)) + img_size = patches_resolution * patch_size + else: + # we only know that window_size <= patches_resolution + # assume window_size == patches_resolution + img_size = patch_size * window_size + + # if APE is enabled, we know that absolute_pos_embed.shape[1] == patches_resolution**2 + if ape: + patches_resolution = int(math.sqrt(state_dict["absolute_pos_embed"].shape[1])) + img_size = patches_resolution * patch_size # The JPEG models are the only ones with window-size 7, and they also use this range img_range = 255.0 if window_size == 7 else 1.0 - in_nc = num_in_ch - out_nc = num_out_ch - num_feat = num_feat - embed_dim = embed_dim - num_heads = num_heads - depths = depths - window_size = window_size - mlp_ratio = mlp_ratio - scale = upscale - upsampler = upsampler - img_size = img_size - img_range = img_range - resi_connection = resi_connection - model = Swin2SR( img_size=img_size, patch_size=patch_size, - in_chans=in_nc, - num_feat=num_feat, + in_chans=in_chans, embed_dim=embed_dim, depths=depths, num_heads=num_heads, @@ -162,7 +126,6 @@ def load(state_dict: StateDict) -> SRModelDescriptor[Swin2SR]: drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate, - norm_layer=norm_layer, ape=ape, patch_norm=patch_norm, use_checkpoint=use_checkpoint, @@ -182,7 +145,6 @@ def load(state_dict: StateDict) -> SRModelDescriptor[Swin2SR]: tags = [ size_tag, f"s{img_size}w{window_size}", - f"{num_feat}nf", f"{embed_dim}dim", f"{resi_connection}", ] @@ -194,8 +156,8 @@ def load(state_dict: StateDict) -> SRModelDescriptor[Swin2SR]: tags=tags, supports_half=False, # Too much weirdness to
support this at the moment supports_bfloat16=True, - scale=scale, - input_channels=in_nc, - output_channels=out_nc, + scale=upscale, + input_channels=in_chans, + output_channels=in_chans, size_requirements=SizeRequirements(minimum=16), ) diff --git a/src/spandrel/architectures/Swin2SR/arch/Swin2SR.py b/src/spandrel/architectures/Swin2SR/arch/Swin2SR.py index 975ebcb3..94ca8a38 100644 --- a/src/spandrel/architectures/Swin2SR/arch/Swin2SR.py +++ b/src/spandrel/architectures/Swin2SR/arch/Swin2SR.py @@ -925,7 +925,6 @@ def __init__( img_range=1.0, upsampler="", resi_connection="1conv", - **kwargs, ): super().__init__() num_in_ch = in_chans diff --git a/tests/__snapshots__/test_Swin2SR.ambr b/tests/__snapshots__/test_Swin2SR.ambr index a021f072..660db125 100644 --- a/tests/__snapshots__/test_Swin2SR.ambr +++ b/tests/__snapshots__/test_Swin2SR.ambr @@ -1,4 +1,21 @@ # serializer version: 1 +# name: test_Swin2SR_2x + SRModelDescriptor( + architecture='Swin2SR', + input_channels=3, + output_channels=3, + scale=2, + size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False), + supports_bfloat16=True, + supports_half=False, + tags=list([ + 'medium', + 's64w8', + '180dim', + '1conv', + ]), + ) +# --- # name: test_Swin2SR_4x SRModelDescriptor( architecture='Swin2SR', @@ -11,9 +28,59 @@ tags=list([ 'medium', 's64w8', - '64nf', '180dim', '1conv', ]), ) # --- +# name: test_Swin2SR_compressed + SRModelDescriptor( + architecture='Swin2SR', + input_channels=3, + output_channels=3, + scale=4, + size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False), + supports_bfloat16=True, + supports_half=False, + tags=list([ + 'medium', + 's48w8', + '180dim', + '1conv', + ]), + ) +# --- +# name: test_Swin2SR_jpeg + SRModelDescriptor( + architecture='Swin2SR', + input_channels=1, + output_channels=1, + scale=1, + size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False), + supports_bfloat16=True, + supports_half=False, + tags=list([ + 'medium', 
+ 's126w7', + '180dim', + '1conv', + ]), + ) +# --- +# name: test_Swin2SR_lightweight_2x + SRModelDescriptor( + architecture='Swin2SR', + input_channels=3, + output_channels=3, + scale=2, + size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False), + supports_bfloat16=True, + supports_half=False, + tags=list([ + 'small', + 's64w8', + '60dim', + '1conv', + ]), + ) +# --- diff --git a/tests/images/outputs/16x16/Swin2SR_ClassicalSR_X2_64.png b/tests/images/outputs/16x16/Swin2SR_ClassicalSR_X2_64.png new file mode 100644 index 00000000..02e61a66 Binary files /dev/null and b/tests/images/outputs/16x16/Swin2SR_ClassicalSR_X2_64.png differ diff --git a/tests/images/outputs/16x16/Swin2SR_CompressedSR_X4_48.png b/tests/images/outputs/16x16/Swin2SR_CompressedSR_X4_48.png new file mode 100644 index 00000000..2fb9ed95 Binary files /dev/null and b/tests/images/outputs/16x16/Swin2SR_CompressedSR_X4_48.png differ diff --git a/tests/images/outputs/16x16/Swin2SR_Lightweight_X2_64.png b/tests/images/outputs/16x16/Swin2SR_Lightweight_X2_64.png new file mode 100644 index 00000000..dc2e1261 Binary files /dev/null and b/tests/images/outputs/16x16/Swin2SR_Lightweight_X2_64.png differ diff --git a/tests/images/outputs/32x32/Swin2SR_ClassicalSR_X2_64.png b/tests/images/outputs/32x32/Swin2SR_ClassicalSR_X2_64.png new file mode 100644 index 00000000..af0cc295 Binary files /dev/null and b/tests/images/outputs/32x32/Swin2SR_ClassicalSR_X2_64.png differ diff --git a/tests/images/outputs/32x32/Swin2SR_CompressedSR_X4_48.png b/tests/images/outputs/32x32/Swin2SR_CompressedSR_X4_48.png new file mode 100644 index 00000000..c2602324 Binary files /dev/null and b/tests/images/outputs/32x32/Swin2SR_CompressedSR_X4_48.png differ diff --git a/tests/images/outputs/32x32/Swin2SR_Lightweight_X2_64.png b/tests/images/outputs/32x32/Swin2SR_Lightweight_X2_64.png new file mode 100644 index 00000000..6d50233e Binary files /dev/null and b/tests/images/outputs/32x32/Swin2SR_Lightweight_X2_64.png 
differ diff --git a/tests/images/outputs/64x64/Swin2SR_ClassicalSR_X2_64.png b/tests/images/outputs/64x64/Swin2SR_ClassicalSR_X2_64.png new file mode 100644 index 00000000..9ea5841f Binary files /dev/null and b/tests/images/outputs/64x64/Swin2SR_ClassicalSR_X2_64.png differ diff --git a/tests/images/outputs/64x64/Swin2SR_CompressedSR_X4_48.png b/tests/images/outputs/64x64/Swin2SR_CompressedSR_X4_48.png new file mode 100644 index 00000000..f2cde0bb Binary files /dev/null and b/tests/images/outputs/64x64/Swin2SR_CompressedSR_X4_48.png differ diff --git a/tests/images/outputs/64x64/Swin2SR_Lightweight_X2_64.png b/tests/images/outputs/64x64/Swin2SR_Lightweight_X2_64.png new file mode 100644 index 00000000..6f99cf84 Binary files /dev/null and b/tests/images/outputs/64x64/Swin2SR_Lightweight_X2_64.png differ diff --git a/tests/test_Swin2SR.py b/tests/test_Swin2SR.py index ebc6dcab..84846d83 100644 --- a/tests/test_Swin2SR.py +++ b/tests/test_Swin2SR.py @@ -1,6 +1,64 @@ -from spandrel.architectures.Swin2SR import Swin2SR +from spandrel.architectures.Swin2SR import Swin2SR, load -from .util import ModelFile, TestImage, assert_image_inference, disallowed_props +from .util import ( + ModelFile, + TestImage, + assert_image_inference, + assert_loads_correctly, + disallowed_props, +) + + +def test_Swin2SR_load(): + assert_loads_correctly( + load, + lambda: Swin2SR(window_size=8, upsampler="pixelshuffledirect"), + lambda: Swin2SR(window_size=8, upsampler="pixelshuffledirect", ape=True), + lambda: Swin2SR( + window_size=8, + upsampler="pixelshuffledirect", + depths=[6, 7, 5, 3, 4], + num_heads=[5, 2, 9, 1, 2], + ), + lambda: Swin2SR(window_size=8, upsampler="pixelshuffledirect", qkv_bias=False), + lambda: Swin2SR( + window_size=8, upsampler="pixelshuffledirect", patch_norm=False + ), + lambda: Swin2SR( + window_size=8, upsampler="pixelshuffledirect", resi_connection="1conv" + ), + lambda: Swin2SR( + window_size=8, upsampler="pixelshuffledirect", resi_connection="3conv" + ), + 
lambda: Swin2SR(window_size=8, upsampler="pixelshuffledirect", patch_size=2), + condition=lambda a, b: ( + a.img_range == b.img_range + and a.upscale == b.upscale + and a.upsampler == b.upsampler + and a.window_size == b.window_size + and a.num_layers == b.num_layers + and a.embed_dim == b.embed_dim + and a.ape == b.ape + and a.patch_norm == b.patch_norm + and a.num_features == b.num_features + and a.mlp_ratio == b.mlp_ratio + and a.patches_resolution == b.patches_resolution + ), + ) + + +def test_Swin2SR_2x(snapshot): + file = ModelFile.from_url( + "https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/Swin2SR_ClassicalSR_X2_64.pth" + ) + model = file.load_model() + assert model == snapshot(exclude=disallowed_props) + assert isinstance(model.model, Swin2SR) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) def test_Swin2SR_4x(snapshot): @@ -15,3 +73,46 @@ def test_Swin2SR_4x(snapshot): model, [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], ) + + +def test_Swin2SR_compressed(snapshot): + file = ModelFile.from_url( + "https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/Swin2SR_CompressedSR_X4_48.pth" + ) + model = file.load_model() + assert model == snapshot(exclude=disallowed_props) + assert isinstance(model.model, Swin2SR) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) + + +def test_Swin2SR_jpeg(snapshot): + file = ModelFile.from_url( + "https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/Swin2SR_Jpeg_dynamic.pth" + ) + model = file.load_model() + assert model == snapshot(exclude=disallowed_props) + assert isinstance(model.model, Swin2SR) + # This is a grayscale model for some reason... 
+ # assert_image_inference( + # file, + # model, + # [TestImage.SR_64, TestImage.JPEG_15], + # ) + + +def test_Swin2SR_lightweight_2x(snapshot): + file = ModelFile.from_url( + "https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/Swin2SR_Lightweight_X2_64.pth" + ) + model = file.load_model() + assert model == snapshot(exclude=disallowed_props) + assert isinstance(model.model, Swin2SR) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + )