Improved parameter detection for Swin2SR #59

Merged · 2 commits · Nov 28, 2023
190 changes: 76 additions & 114 deletions src/spandrel/architectures/Swin2SR/__init__.py
@@ -1,15 +1,13 @@
import math
import re

from torch import nn

from ...__helpers.model_descriptor import SizeRequirements, SRModelDescriptor, StateDict
from ..__arch_helpers.state import get_seq_len
from .arch.Swin2SR import Swin2SR


def load(state_dict: StateDict) -> SRModelDescriptor[Swin2SR]:
# Defaults
img_size = 128
img_size = 64
patch_size = 1
in_chans = 3
embed_dim = 96
@@ -18,141 +16,107 @@ def load(state_dict: StateDict) -> SRModelDescriptor[Swin2SR]:
window_size = 7
mlp_ratio = 4.0
qkv_bias = True
drop_rate = 0.0
attn_drop_rate = 0.0
drop_path_rate = 0.1
norm_layer = nn.LayerNorm
drop_rate = 0.0 # cannot be deduced from state_dict
attn_drop_rate = 0.0 # cannot be deduced from state_dict
drop_path_rate = 0.1 # cannot be deduced from state_dict
ape = False
patch_norm = True
use_checkpoint = False
use_checkpoint = False # cannot be deduced from state_dict
upscale = 2
img_range = 1.0
upsampler = ""
resi_connection = "1conv"
num_in_ch = in_chans
num_out_ch = in_chans
num_feat = 64

state = state_dict

state_keys = state.keys()

if "conv_before_upsample.0.weight" in state_keys:
if "conv_aux.weight" in state_keys:
upsampler = "pixelshuffle_aux"
elif "conv_up1.weight" in state_keys:
upsampler = "nearest+conv"
else:
upsampler = "pixelshuffle"
elif "upsample.0.weight" in state_keys:
upsampler = "pixelshuffledirect"
else:
upsampler = ""

num_feat_pre_layer = state_dict.get("conv_before_upsample.weight", None)
num_feat_layer = state_dict.get("conv_before_upsample.0.weight", None)
num_feat = (
num_feat_layer.shape[1]
if num_feat_layer is not None and num_feat_pre_layer is not None
else 64
)
in_chans = state_dict["conv_first.weight"].shape[1]
embed_dim = state_dict["conv_first.weight"].shape[0]
patch_size = state_dict["patch_embed.proj.weight"].shape[2]

num_in_ch = state["conv_first.weight"].shape[1]
in_chans = num_in_ch
if "conv_last.weight" in state_keys:
num_out_ch = state["conv_last.weight"].shape[0]
else:
num_out_ch = num_in_ch

upscale = 1
if upsampler == "nearest+conv":
upsample_keys = [x for x in state_keys if "conv_up" in x and "bias" not in x]

for upsample_key in upsample_keys:
upscale *= 2
elif upsampler == "pixelshuffle" or upsampler == "pixelshuffle_aux":
upsample_keys = [
x
for x in state_keys
if "upsample" in x and "conv" not in x and "bias" not in x
]
for upsample_key in upsample_keys:
shape = state[upsample_key].shape[0]
upscale *= math.sqrt(shape // num_feat)
upscale = int(upscale)
elif upsampler == "pixelshuffledirect":
upscale = int(math.sqrt(state["upsample.0.bias"].shape[0] // num_out_ch))

max_layer_num = 0
max_block_num = 0
for key in state_keys:
result = re.match(r"layers.(\d*).residual_group.blocks.(\d*).norm1.weight", key)
if result:
layer_num, block_num = result.groups()
max_layer_num = max(max_layer_num, int(layer_num))
max_block_num = max(max_block_num, int(block_num))

depths = [max_block_num + 1 for _ in range(max_layer_num + 1)]

if (
"layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
in state_keys
):
num_heads_num = state[
"layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
].shape[-1]
num_heads = [num_heads_num for _ in range(max_layer_num + 1)]
else:
num_heads = depths
ape = "absolute_pos_embed" in state_dict
patch_norm = "patch_embed.norm.weight" in state_dict
qkv_bias = "layers.0.residual_group.blocks.0.attn.q_bias" in state_dict

embed_dim = state["conv_first.weight"].shape[0]
# depths & num_heads
num_layers = get_seq_len(state_dict, "layers")
depths = [6] * num_layers
num_heads = [6] * num_layers
for i in range(num_layers):
depths[i] = get_seq_len(state_dict, f"layers.{i}.residual_group.blocks")
num_heads[i] = state_dict[
f"layers.{i}.residual_group.blocks.0.attn.logit_scale"
].shape[0]

mlp_ratio = float(
state["layers.0.residual_group.blocks.0.mlp.fc1.bias"].shape[0] / embed_dim
state_dict["layers.0.residual_group.blocks.0.mlp.fc1.weight"].shape[0]
/ embed_dim
)

# TODO: could actually count the layers, but this should do
if "layers.0.conv.4.weight" in state_keys:
if "conv_after_body.0.weight" in state_dict:
resi_connection = "3conv"
else:
elif "conv_after_body.weight" in state_dict:
resi_connection = "1conv"
else:
raise ValueError("Unknown residual connection type")

# upsampler
if "conv_bicubic.weight" in state_dict:
upsampler = "pixelshuffle_aux"
elif "conv_hr.weight" in state_dict:
upsampler = "nearest+conv"
elif "conv_after_body_hf.weight" in state_dict:
upsampler = "pixelshuffle_hf"
elif "conv_before_upsample.0.weight" in state_dict:
upsampler = "pixelshuffle"
elif "upsample.0.weight" in state_dict:
upsampler = "pixelshuffledirect"
else:
upsampler = ""

if upsampler == "":
upscale = 1
elif upsampler == "nearest+conv":
upscale = 4 # only supports 4x
elif upsampler == "pixelshuffledirect":
upscale = int(math.sqrt(state_dict["upsample.0.weight"].shape[0] // in_chans))
else:
num_feat = 64 # hard-coded constant
upscale = 1
for i in range(0, get_seq_len(state_dict, "upsample"), 2):
shape = state_dict[f"upsample.{i}.weight"].shape[0]
upscale *= int(math.sqrt(shape // num_feat))

window_size = int(
math.sqrt(
state[
state_dict[
"layers.0.residual_group.blocks.0.attn.relative_position_index"
].shape[0]
)
)

if "layers.0.residual_group.blocks.1.attn_mask" in state_keys:
img_size = int(
math.sqrt(state["layers.0.residual_group.blocks.1.attn_mask"].shape[0])
* window_size
)
# Now for img_size... What we know:
# patches_resolution = img_size // patch_size
# if window_size > patches_resolution:
# attn_mask[0] = patches_resolution**2 // window_size**2
if "layers.0.residual_group.blocks.1.attn_mask" in state_dict:
attn_mask_0 = state_dict["layers.0.residual_group.blocks.1.attn_mask"].shape[0]
patches_resolution = int(math.sqrt(attn_mask_0 * window_size * window_size))
img_size = patches_resolution * patch_size
else:
# we only know that window_size <= patches_resolution
# assume window_size == patches_resolution
img_size = patch_size * window_size

# if APE is enabled, we know that absolute_pos_embed.shape[1] == patches_resolution**2
if ape:
patches_resolution = int(math.sqrt(state_dict["absolute_pos_embed"].shape[1]))
img_size = patches_resolution * patch_size

# The JPEG models are the only ones with window-size 7, and they also use this range
img_range = 255.0 if window_size == 7 else 1.0

in_nc = num_in_ch
out_nc = num_out_ch
num_feat = num_feat
embed_dim = embed_dim
num_heads = num_heads
depths = depths
window_size = window_size
mlp_ratio = mlp_ratio
scale = upscale
upsampler = upsampler
img_size = img_size
img_range = img_range
resi_connection = resi_connection

model = Swin2SR(
img_size=img_size,
patch_size=patch_size,
in_chans=in_nc,
num_feat=num_feat,
in_chans=in_chans,
embed_dim=embed_dim,
depths=depths,
num_heads=num_heads,
@@ -162,7 +126,6 @@ def load(state_dict: StateDict) -> SRModelDescriptor[Swin2SR]:
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
norm_layer=norm_layer,
ape=ape,
patch_norm=patch_norm,
use_checkpoint=use_checkpoint,
@@ -182,7 +145,6 @@ def load(state_dict: StateDict) -> SRModelDescriptor[Swin2SR]:
tags = [
size_tag,
f"s{img_size}w{window_size}",
f"{num_feat}nf",
f"{embed_dim}dim",
f"{resi_connection}",
]
@@ -194,8 +156,8 @@ def load(state_dict: StateDict) -> SRModelDescriptor[Swin2SR]:
tags=tags,
supports_half=False, # Too much weirdness to support this at the moment
supports_bfloat16=True,
scale=scale,
input_channels=in_nc,
output_channels=out_nc,
scale=upscale,
input_channels=in_chans,
output_channels=in_chans,
size_requirements=SizeRequirements(minimum=16),
)
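
With this change the loader derives every Swin2SR hyperparameter from the state dict alone, so callers never need to supply a config. A minimal usage sketch through spandrel's `ModelLoader` (the checkpoint filename is a placeholder and the printed values are examples only):

```python
from spandrel import ModelLoader

# Any Swin2SR checkpoint; the filename here is hypothetical.
descriptor = ModelLoader().load_from_file("Swin2SR_ClassicalSR_X4_64.pth")

# The descriptor exposes the parameters detected by load() above.
print(descriptor.scale)            # e.g. 4
print(descriptor.input_channels)   # e.g. 3
print(descriptor.output_channels)  # e.g. 3
print(descriptor.tags)             # e.g. ['medium', 's64w8', '180dim', '1conv']
```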
1 change: 0 additions & 1 deletion src/spandrel/architectures/Swin2SR/arch/Swin2SR.py
@@ -925,7 +925,6 @@ def __init__(
img_range=1.0,
upsampler="",
resi_connection="1conv",
**kwargs,
):
super().__init__()
num_in_ch = in_chans
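
Dropping `**kwargs` from `Swin2SR.__init__` means a mistyped or unsupported keyword argument now raises a `TypeError` instead of being silently discarded. A generic illustration of the difference (toy classes, not the actual architecture):

```python
class Permissive:
    def __init__(self, embed_dim=96, **kwargs):  # typos vanish into kwargs
        self.embed_dim = embed_dim


class Strict:
    def __init__(self, embed_dim=96):
        self.embed_dim = embed_dim


Permissive(embeddim=180)  # accepted silently; embed_dim stays 96
Strict(embeddim=180)      # TypeError: unexpected keyword argument 'embeddim'
```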
69 changes: 68 additions & 1 deletion tests/__snapshots__/test_Swin2SR.ambr
@@ -1,4 +1,21 @@
# serializer version: 1
# name: test_Swin2SR_2x
SRModelDescriptor(
architecture='Swin2SR',
input_channels=3,
output_channels=3,
scale=2,
size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
tags=list([
'medium',
's64w8',
'180dim',
'1conv',
]),
)
# ---
# name: test_Swin2SR_4x
SRModelDescriptor(
architecture='Swin2SR',
@@ -11,9 +28,59 @@
tags=list([
'medium',
's64w8',
'64nf',
'180dim',
'1conv',
]),
)
# ---
# name: test_Swin2SR_compressed
SRModelDescriptor(
architecture='Swin2SR',
input_channels=3,
output_channels=3,
scale=4,
size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
tags=list([
'medium',
's48w8',
'180dim',
'1conv',
]),
)
# ---
# name: test_Swin2SR_jpeg
SRModelDescriptor(
architecture='Swin2SR',
input_channels=1,
output_channels=1,
scale=1,
size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
tags=list([
'medium',
's126w7',
'180dim',
'1conv',
]),
)
# ---
# name: test_Swin2SR_lightweight_2x
SRModelDescriptor(
architecture='Swin2SR',
input_channels=3,
output_channels=3,
scale=2,
size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
tags=list([
'small',
's64w8',
'60dim',
'1conv',
]),
)
# ---
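
The tags in these snapshots follow directly from the detection formulas in the loader. As a worked example, the `s126w7` tag of the JPEG snapshot can be reproduced from the tensor shapes such a checkpoint would carry (the concrete numbers below are inferred from the tag, not read from a real file):

```python
import math

# Shapes a grayscale JPEG-compression checkpoint would carry:
relative_position_index_rows = 49  # window_size**2 rows  -> window_size = 7
attn_mask_rows = 324               # (patches_resolution // window_size)**2 windows
patch_size = 1                     # patch_embed.proj.weight.shape[2]

window_size = int(math.sqrt(relative_position_index_rows))            # 7
patches_resolution = int(math.sqrt(attn_mask_rows * window_size**2))  # 126
img_size = patches_resolution * patch_size                            # 126
img_range = 255.0 if window_size == 7 else 1.0                        # JPEG range

assert f"s{img_size}w{window_size}" == "s126w7"
```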