diff --git a/src/spandrel/architectures/DAT/__init__.py b/src/spandrel/architectures/DAT/__init__.py index 75a192a5..7b8cca4b 100644 --- a/src/spandrel/architectures/DAT/__init__.py +++ b/src/spandrel/architectures/DAT/__init__.py @@ -1,135 +1,89 @@ import math -import re from ...__helpers.model_descriptor import SizeRequirements, SRModelDescriptor, StateDict +from ..__arch_helpers.state import get_seq_len from .arch.DAT import DAT def load(state_dict: StateDict) -> SRModelDescriptor[DAT]: # defaults - img_size = 64 + img_size = 64 # cannot be deduced from state dict in general in_chans = 3 embed_dim = 180 split_size = [2, 4] depth = [2, 2, 2, 2] num_heads = [2, 2, 2, 2] expansion_factor = 4.0 + qkv_bias = True upscale = 2 img_range = 1.0 resi_connection = "1conv" upsampler = "pixelshuffle" + num_feat = 64 - state_keys = state_dict.keys() - if "conv_before_upsample.0.weight" in state_keys: - if "conv_up1.weight" in state_keys: - upsampler = "nearest+conv" + in_chans = state_dict["conv_first.weight"].shape[1] + embed_dim = state_dict["conv_first.weight"].shape[0] + + # num_layers = len(depth) + num_layers = get_seq_len(state_dict, "layers") + depth = [get_seq_len(state_dict, f"layers.{i}.blocks") for i in range(num_layers)] + + # num_heads is linked to depth + num_heads = [2] * num_layers + for i in range(num_layers): + if depth[i] >= 2: + # that's the easy path, we can directly read the head count + num_heads[i] = state_dict[f"layers.{i}.blocks.1.attn.temperature"].shape[0] else: - upsampler = "pixelshuffle" - elif "upsample.0.weight" in state_keys: - upsampler = "pixelshuffledirect" - else: - upsampler = "" - - num_feat_pre_layer = state_dict.get("conv_before_upsample.weight", None) - num_feat_layer = state_dict.get("conv_before_upsample.0.weight", None) - num_feat = ( - num_feat_layer.shape[1] - if num_feat_layer is not None and num_feat_pre_layer is not None - else 64 - ) + # because of a head_num // 2, we can only reconstruct even head counts + key = f"layers.{i}.blocks.0.attn.attns.0.pos.pos3.2.weight" + num_heads[i] = state_dict[key].shape[0] * 2 - num_in_ch = state_dict["conv_first.weight"].shape[1] - in_chans = num_in_ch - if "conv_last.weight" in state_keys: - num_out_ch = state_dict["conv_last.weight"].shape[0] - else: - num_out_ch = num_in_ch - - upscale = 1 - if upsampler == "nearest+conv": - upsample_keys = [x for x in state_keys if "conv_up" in x and "bias" not in x] - - for upsample_key in upsample_keys: - upscale *= 2 - elif upsampler == "pixelshuffle": - upsample_keys = [ - x - for x in state_keys - if "upsample" in x and "conv" not in x and "bias" not in x - ] - for upsample_key in upsample_keys: - shape = state_dict[upsample_key].shape[0] - upscale *= math.sqrt(shape // num_feat) - upscale = int(upscale) + upsampler = ( + "pixelshuffle" if "conv_last.weight" in state_dict else "pixelshuffledirect" + ) + resi_connection = "1conv" if "conv_after_body.weight" in state_dict else "3conv" + + if upsampler == "pixelshuffle": + upscale = 1 + for i in range(0, get_seq_len(state_dict, "upsample"), 2): + num_feat = state_dict[f"upsample.{i}.weight"].shape[1] + shape = state_dict[f"upsample.{i}.weight"].shape[0] + upscale *= int(math.sqrt(shape // num_feat)) elif upsampler == "pixelshuffledirect": - upscale = int(math.sqrt(state_dict["upsample.0.bias"].shape[0] // num_out_ch)) - - max_layer_num = 0 - max_block_num = 0 - for key in state_keys: - result = re.match(r"layers.(\d*).blocks.(\d*).norm1.weight", key) - if result: - layer_num, block_num = result.groups() - max_layer_num = max(max_layer_num, int(layer_num)) - max_block_num = max(max_block_num, int(block_num)) - - depth = [max_block_num + 1 for _ in range(max_layer_num + 1)] - - if "layers.0.blocks.1.attn.temperature" in state_keys: - num_heads_num = state_dict["layers.0.blocks.1.attn.temperature"].shape[0] - num_heads = [num_heads_num for _ in range(max_layer_num + 1)] - else: - num_heads = depth + num_feat = state_dict["upsample.0.weight"].shape[1] + upscale = int(math.sqrt(state_dict["upsample.0.weight"].shape[0] // in_chans)) + + qkv_bias = "layers.0.blocks.0.attn.qkv.bias" in state_dict - embed_dim = state_dict["conv_first.weight"].shape[0] expansion_factor = float( state_dict["layers.0.blocks.0.ffn.fc1.weight"].shape[0] / embed_dim ) - # TODO: could actually count the layers, but this should do - if "layers.0.conv.4.weight" in state_keys: - resi_connection = "3conv" - else: - resi_connection = "1conv" - - if "layers.0.blocks.2.attn.attn_mask_0" in state_keys: + if "layers.0.blocks.2.attn.attn_mask_0" in state_dict: attn_mask_0_x, attn_mask_0_y, _attn_mask_0_z = state_dict[ "layers.0.blocks.2.attn.attn_mask_0" ].shape img_size = int(math.sqrt(attn_mask_0_x * attn_mask_0_y)) - if "layers.0.blocks.0.attn.attns.0.rpe_biases" in state_keys: + if "layers.0.blocks.0.attn.attns.0.rpe_biases" in state_dict: split_sizes = state_dict["layers.0.blocks.0.attn.attns.0.rpe_biases"][-1] + 1 split_size = [int(x) for x in split_sizes] - in_nc = num_in_ch - out_nc = num_out_ch - num_feat = num_feat - embed_dim = embed_dim - num_heads = num_heads - depth = depth - scale = upscale - upsampler = upsampler - img_size = img_size - img_range = img_range - expansion_factor = expansion_factor - resi_connection = resi_connection - split_size = split_size - model = DAT( img_size=img_size, in_chans=in_chans, - num_feat=num_feat, embed_dim=embed_dim, + split_size=split_size, depth=depth, num_heads=num_heads, expansion_factor=expansion_factor, - split_size=split_size, - scale=scale, - upsampler=upsampler, + qkv_bias=qkv_bias, + upscale=upscale, img_range=img_range, resi_connection=resi_connection, + upsampler=upsampler, ) head_length = len(depth) @@ -139,6 +93,7 @@ def load(state_dict: StateDict) -> SRModelDescriptor[DAT]: size_tag = "medium" else: size_tag = "large" + tags = [ size_tag, f"s{img_size}|{split_size[0]}x{split_size[1]}", @@ -154,8 +109,8 @@ def load(state_dict: StateDict) -> SRModelDescriptor[DAT]: tags=tags, supports_half=False, # Too much weirdness to support this at the moment supports_bfloat16=True, - scale=scale, - input_channels=in_nc, - output_channels=out_nc, + scale=upscale, + input_channels=in_chans, + output_channels=in_chans, size_requirements=SizeRequirements(minimum=16), ) diff --git a/src/spandrel/architectures/DAT/arch/DAT.py b/src/spandrel/architectures/DAT/arch/DAT.py index 81ace883..edcc0b6a 100644 --- a/src/spandrel/architectures/DAT/arch/DAT.py +++ b/src/spandrel/architectures/DAT/arch/DAT.py @@ -1,4 +1,4 @@ -# type: ignore +from __future__ import annotations import math @@ -55,7 +55,7 @@ def __init__(self, dim): def forward(self, x, H, W): # Split x1, x2 = x.chunk(2, dim=-1) - B, N, C = x.shape + B, _N, C = x.shape x2 = ( self.conv(self.norm(x2).transpose(1, 2).contiguous().view(B, C // 2, H, W)) .flatten(2) @@ -225,7 +225,7 @@ def __init__( self.attn_drop = nn.Dropout(attn_drop) def im2win(self, x, H, W): - B, N, C = x.shape + B, _N, C = x.shape x = x.transpose(-2, -1).contiguous().view(B, C, H, W) x = img2windows(x, self.H_sp, self.W_sp) x = ( @@ -805,7 +805,7 @@ def __init__( qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, - drop_path=drop_paths[i], + drop_path=drop_paths[i], # type: ignore act_layer=act_layer, norm_layer=norm_layer, rg_idx=rg_idx, @@ -879,7 +879,7 @@ class UpsampleOneStep(nn.Sequential): """ - def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + def __init__(self, scale, num_feat, num_out_ch, input_resolution): self.num_feat = num_feat self.input_resolution = input_resolution m = [] @@ -926,7 +926,7 @@ def __init__( num_heads=[2, 2, 2, 2], expansion_factor=4.0, qkv_bias=True, - qk_scale=None, + qk_scale: float | None = None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, @@ -937,7 +937,6 @@ def __init__( img_range=1.0, resi_connection="1conv", upsampler="pixelshuffle", - **kwargs, ): super().__init__() @@ -1028,7 +1027,7 @@ def __init__( def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) - if isinstance(m, nn.Linear) and m.bias is not None: + if isinstance(m, nn.Linear) and m.bias is not None: # type: ignore nn.init.constant_(m.bias, 0) elif isinstance( m, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm, nn.InstanceNorm2d) diff --git a/tests/test_DAT.py b/tests/test_DAT.py new file mode 100644 index 00000000..1f8dcf2f --- /dev/null +++ b/tests/test_DAT.py @@ -0,0 +1,37 @@ +from spandrel.architectures.DAT import DAT, load + +from .util import assert_loads_correctly + + +def test_DAT_load(): + assert_loads_correctly( + load, + lambda: DAT(), + lambda: DAT(embed_dim=60), + lambda: DAT(in_chans=1), + lambda: DAT(in_chans=4), + lambda: DAT(depth=[2, 3], num_heads=[2, 5]), + lambda: DAT(depth=[2, 3, 4, 2], num_heads=[2, 3, 2, 2]), + lambda: DAT(depth=[2, 3, 4, 2, 5], num_heads=[2, 3, 2, 2, 3]), + lambda: DAT(upsampler="pixelshuffle", upscale=1), + lambda: DAT(upsampler="pixelshuffle", upscale=2), + lambda: DAT(upsampler="pixelshuffle", upscale=3), + lambda: DAT(upsampler="pixelshuffle", upscale=4), + lambda: DAT(upsampler="pixelshuffle", upscale=8), + lambda: DAT(upsampler="pixelshuffledirect", upscale=1), + lambda: DAT(upsampler="pixelshuffledirect", upscale=2), + lambda: DAT(upsampler="pixelshuffledirect", upscale=3), + lambda: DAT(upsampler="pixelshuffledirect", upscale=4), + lambda: DAT(upsampler="pixelshuffledirect", upscale=8), + lambda: DAT(resi_connection="3conv"), + lambda: DAT(qkv_bias=False), + lambda: DAT(split_size=[4, 4]), + lambda: DAT(split_size=[2, 8]), + condition=lambda a, b: ( + a.num_layers == b.num_layers + and a.upscale == b.upscale + and a.upsampler == b.upsampler + and a.embed_dim == b.embed_dim + and a.num_features == b.num_features + ), + )