Improved parameter detection for DAT (#35)
* Improved parameter detection for DAT

* Review comments
RunDevelopment authored Nov 22, 2023
1 parent 80ce019 commit 71fc7e2
Showing 3 changed files with 89 additions and 98 deletions.
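The heart of the change: instead of hard-coded defaults, the loader now reads DAT's hyperparameters back out of the checkpoint itself, from tensor shapes and from how many indexed keys exist under a prefix. A minimal sketch of that idea, using a toy state dict and a hypothetical seq_len helper standing in for spandrel's get_seq_len:

import torch

# Toy state dict with made-up shapes (not a real DAT checkpoint).
state_dict = {
    "conv_first.weight": torch.zeros(180, 3, 3, 3),  # [embed_dim, in_chans, k, k]
    "layers.0.blocks.0.norm1.weight": torch.zeros(180),
    "layers.0.blocks.1.norm1.weight": torch.zeros(180),
    "layers.1.blocks.0.norm1.weight": torch.zeros(180),
}

# Channel-like parameters fall out of tensor shapes directly.
embed_dim = state_dict["conv_first.weight"].shape[0]  # 180
in_chans = state_dict["conv_first.weight"].shape[1]   # 3

def seq_len(prefix: str) -> int:
    # Count how many indices i occur as "<prefix>.<i>." among the keys.
    indices = {
        int(key[len(prefix) + 1 :].split(".")[0])
        for key in state_dict
        if key.startswith(prefix + ".")
    }
    return max(indices) + 1 if indices else 0

num_layers = seq_len("layers")                                       # 2
depth = [seq_len(f"layers.{i}.blocks") for i in range(num_layers)]   # [2, 1]
print(embed_dim, in_chans, num_layers, depth)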
135 changes: 45 additions & 90 deletions src/spandrel/architectures/DAT/__init__.py
@@ -1,135 +1,89 @@
 import math
-import re
 
 from ...__helpers.model_descriptor import SizeRequirements, SRModelDescriptor, StateDict
+from ..__arch_helpers.state import get_seq_len
 from .arch.DAT import DAT
 
 
 def load(state_dict: StateDict) -> SRModelDescriptor[DAT]:
     # defaults
-    img_size = 64
+    img_size = 64  # cannot be deduced from state dict in general
-    in_chans = 3
-    embed_dim = 180
     split_size = [2, 4]
-    depth = [2, 2, 2, 2]
-    num_heads = [2, 2, 2, 2]
     expansion_factor = 4.0
-    qkv_bias = True
-    upscale = 2
     img_range = 1.0
-    resi_connection = "1conv"
-    upsampler = "pixelshuffle"
-    num_feat = 64
 
-    state_keys = state_dict.keys()
-    if "conv_before_upsample.0.weight" in state_keys:
-        if "conv_up1.weight" in state_keys:
-            upsampler = "nearest+conv"
+    in_chans = state_dict["conv_first.weight"].shape[1]
+    embed_dim = state_dict["conv_first.weight"].shape[0]
+
+    # num_layers = len(depth)
+    num_layers = get_seq_len(state_dict, "layers")
+    depth = [get_seq_len(state_dict, f"layers.{i}.blocks") for i in range(num_layers)]
+
+    # num_heads is linked to depth
+    num_heads = [2] * num_layers
+    for i in range(num_layers):
+        if depth[i] >= 2:
+            # that's the easy path, we can directly read the head count
+            num_heads[i] = state_dict[f"layers.{i}.blocks.1.attn.temperature"].shape[0]
         else:
-            upsampler = "pixelshuffle"
-    elif "upsample.0.weight" in state_keys:
-        upsampler = "pixelshuffledirect"
-    else:
-        upsampler = ""
-
-    num_feat_pre_layer = state_dict.get("conv_before_upsample.weight", None)
-    num_feat_layer = state_dict.get("conv_before_upsample.0.weight", None)
-    num_feat = (
-        num_feat_layer.shape[1]
-        if num_feat_layer is not None and num_feat_pre_layer is not None
-        else 64
-    )
+            # because of a head_num // 2, we can only reconstruct even head counts
+            key = f"layers.{i}.blocks.0.attn.attns.0.pos.pos3.2.weight"
+            num_heads[i] = state_dict[key].shape[0] * 2
 
-    num_in_ch = state_dict["conv_first.weight"].shape[1]
-    in_chans = num_in_ch
-    if "conv_last.weight" in state_keys:
-        num_out_ch = state_dict["conv_last.weight"].shape[0]
-    else:
-        num_out_ch = num_in_ch
-
-    upscale = 1
-    if upsampler == "nearest+conv":
-        upsample_keys = [x for x in state_keys if "conv_up" in x and "bias" not in x]
-
-        for upsample_key in upsample_keys:
-            upscale *= 2
-    elif upsampler == "pixelshuffle":
-        upsample_keys = [
-            x
-            for x in state_keys
-            if "upsample" in x and "conv" not in x and "bias" not in x
-        ]
-        for upsample_key in upsample_keys:
-            shape = state_dict[upsample_key].shape[0]
-            upscale *= math.sqrt(shape // num_feat)
-        upscale = int(upscale)
+    upsampler = (
+        "pixelshuffle" if "conv_last.weight" in state_dict else "pixelshuffledirect"
+    )
+    resi_connection = "1conv" if "conv_after_body.weight" in state_dict else "3conv"
+
+    if upsampler == "pixelshuffle":
+        upscale = 1
+        for i in range(0, get_seq_len(state_dict, "upsample"), 2):
+            num_feat = state_dict[f"upsample.{i}.weight"].shape[1]
+            shape = state_dict[f"upsample.{i}.weight"].shape[0]
+            upscale *= int(math.sqrt(shape // num_feat))
     elif upsampler == "pixelshuffledirect":
-        upscale = int(math.sqrt(state_dict["upsample.0.bias"].shape[0] // num_out_ch))
-
-    max_layer_num = 0
-    max_block_num = 0
-    for key in state_keys:
-        result = re.match(r"layers.(\d*).blocks.(\d*).norm1.weight", key)
-        if result:
-            layer_num, block_num = result.groups()
-            max_layer_num = max(max_layer_num, int(layer_num))
-            max_block_num = max(max_block_num, int(block_num))
-
-    depth = [max_block_num + 1 for _ in range(max_layer_num + 1)]
-
-    if "layers.0.blocks.1.attn.temperature" in state_keys:
-        num_heads_num = state_dict["layers.0.blocks.1.attn.temperature"].shape[0]
-        num_heads = [num_heads_num for _ in range(max_layer_num + 1)]
-    else:
-        num_heads = depth
+        num_feat = state_dict["upsample.0.weight"].shape[1]
+        upscale = int(math.sqrt(state_dict["upsample.0.weight"].shape[0] // in_chans))
 
     qkv_bias = "layers.0.blocks.0.attn.qkv.bias" in state_dict
 
-    embed_dim = state_dict["conv_first.weight"].shape[0]
     expansion_factor = float(
         state_dict["layers.0.blocks.0.ffn.fc1.weight"].shape[0] / embed_dim
     )
 
-    # TODO: could actually count the layers, but this should do
-    if "layers.0.conv.4.weight" in state_keys:
-        resi_connection = "3conv"
-    else:
-        resi_connection = "1conv"
-
-    if "layers.0.blocks.2.attn.attn_mask_0" in state_keys:
+    if "layers.0.blocks.2.attn.attn_mask_0" in state_dict:
         attn_mask_0_x, attn_mask_0_y, _attn_mask_0_z = state_dict[
             "layers.0.blocks.2.attn.attn_mask_0"
         ].shape
 
         img_size = int(math.sqrt(attn_mask_0_x * attn_mask_0_y))
 
-    if "layers.0.blocks.0.attn.attns.0.rpe_biases" in state_keys:
+    if "layers.0.blocks.0.attn.attns.0.rpe_biases" in state_dict:
         split_sizes = state_dict["layers.0.blocks.0.attn.attns.0.rpe_biases"][-1] + 1
         split_size = [int(x) for x in split_sizes]
 
-    in_nc = num_in_ch
-    out_nc = num_out_ch
-    num_feat = num_feat
-    embed_dim = embed_dim
-    num_heads = num_heads
-    depth = depth
-    scale = upscale
-    upsampler = upsampler
-    img_size = img_size
-    img_range = img_range
-    expansion_factor = expansion_factor
-    resi_connection = resi_connection
-    split_size = split_size
-
     model = DAT(
         img_size=img_size,
         in_chans=in_chans,
-        num_feat=num_feat,
         embed_dim=embed_dim,
+        split_size=split_size,
         depth=depth,
         num_heads=num_heads,
         expansion_factor=expansion_factor,
-        split_size=split_size,
-        scale=scale,
-        upsampler=upsampler,
         qkv_bias=qkv_bias,
         upscale=upscale,
         img_range=img_range,
         resi_connection=resi_connection,
+        upsampler=upsampler,
     )
 
     head_length = len(depth)
@@ -139,6 +93,7 @@ def load(state_dict: StateDict) -> SRModelDescriptor[DAT]:
size_tag = "medium"
else:
size_tag = "large"

tags = [
size_tag,
f"s{img_size}|{split_size[0]}x{split_size[1]}",
@@ -154,8 +109,8 @@ def load(state_dict: StateDict) -> SRModelDescriptor[DAT]:
         tags=tags,
         supports_half=False,  # Too much weirdness to support this at the moment
         supports_bfloat16=True,
-        scale=scale,
-        input_channels=in_nc,
-        output_channels=out_nc,
+        scale=upscale,
+        input_channels=in_chans,
+        output_channels=in_chans,
         size_requirements=SizeRequirements(minimum=16),
     )
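Worth noting how the pixelshuffle upscale falls out of the weights: each upsample stage is a conv whose output/input channel ratio is the square of that stage's factor, and the convs sit at the even indices of the upsample sequence, so the total scale is the product of sqrt(out_ch // num_feat) over upsample.{i}.weight for even i. A rough illustration with made-up shapes for an assumed 4x head built from two 2x stages:

import math
import torch

# Hypothetical weights for two 2x pixelshuffle stages; the odd indices would be
# the PixelShuffle modules, which hold no parameters.
weights = {
    "upsample.0.weight": torch.zeros(256, 64, 3, 3),  # 4 * 64 output channels -> 2x
    "upsample.2.weight": torch.zeros(256, 64, 3, 3),  # another 2x stage
}

upscale = 1
for key in ("upsample.0.weight", "upsample.2.weight"):
    out_ch, num_feat = weights[key].shape[:2]
    upscale *= int(math.sqrt(out_ch // num_feat))  # each stage contributes sqrt(out/in)
print(upscale)  # 4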
15 changes: 7 additions & 8 deletions src/spandrel/architectures/DAT/arch/DAT.py
@@ -1,4 +1,4 @@
-# type: ignore
+from __future__ import annotations
 
 import math
 
@@ -55,7 +55,7 @@ def __init__(self, dim):
     def forward(self, x, H, W):
         # Split
         x1, x2 = x.chunk(2, dim=-1)
-        B, N, C = x.shape
+        B, _N, C = x.shape
         x2 = (
             self.conv(self.norm(x2).transpose(1, 2).contiguous().view(B, C // 2, H, W))
             .flatten(2)
@@ -225,7 +225,7 @@ def __init__(
         self.attn_drop = nn.Dropout(attn_drop)
 
     def im2win(self, x, H, W):
-        B, N, C = x.shape
+        B, _N, C = x.shape
         x = x.transpose(-2, -1).contiguous().view(B, C, H, W)
         x = img2windows(x, self.H_sp, self.W_sp)
         x = (
@@ -805,7 +805,7 @@ def __init__(
                 qk_scale=qk_scale,
                 drop=drop,
                 attn_drop=attn_drop,
-                drop_path=drop_paths[i],
+                drop_path=drop_paths[i],  # type: ignore
                 act_layer=act_layer,
                 norm_layer=norm_layer,
                 rg_idx=rg_idx,
@@ -879,7 +879,7 @@ class UpsampleOneStep(nn.Sequential):
     """
 
-    def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
+    def __init__(self, scale, num_feat, num_out_ch, input_resolution):
         self.num_feat = num_feat
         self.input_resolution = input_resolution
         m = []
@@ -926,7 +926,7 @@ def __init__(
         num_heads=[2, 2, 2, 2],
         expansion_factor=4.0,
         qkv_bias=True,
-        qk_scale=None,
+        qk_scale: float | None = None,
         drop_rate=0.0,
         attn_drop_rate=0.0,
         drop_path_rate=0.1,
@@ -937,7 +937,6 @@ def __init__(
         img_range=1.0,
         resi_connection="1conv",
         upsampler="pixelshuffle",
-        **kwargs,
     ):
         super().__init__()
 
@@ -1028,7 +1027,7 @@ def __init__(
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight, std=0.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
+            if isinstance(m, nn.Linear) and m.bias is not None:  # type: ignore
                 nn.init.constant_(m.bias, 0)
         elif isinstance(
             m, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm, nn.InstanceNorm2d)
37 changes: 37 additions & 0 deletions tests/test_DAT.py
@@ -0,0 +1,37 @@
+from spandrel.architectures.DAT import DAT, load
+
+from .util import assert_loads_correctly
+
+
+def test_DAT_load():
+    assert_loads_correctly(
+        load,
+        lambda: DAT(),
+        lambda: DAT(embed_dim=60),
+        lambda: DAT(in_chans=1),
+        lambda: DAT(in_chans=4),
+        lambda: DAT(depth=[2, 3], num_heads=[2, 5]),
+        lambda: DAT(depth=[2, 3, 4, 2], num_heads=[2, 3, 2, 2]),
+        lambda: DAT(depth=[2, 3, 4, 2, 5], num_heads=[2, 3, 2, 2, 3]),
+        lambda: DAT(upsampler="pixelshuffle", upscale=1),
+        lambda: DAT(upsampler="pixelshuffle", upscale=2),
+        lambda: DAT(upsampler="pixelshuffle", upscale=3),
+        lambda: DAT(upsampler="pixelshuffle", upscale=4),
+        lambda: DAT(upsampler="pixelshuffle", upscale=8),
+        lambda: DAT(upsampler="pixelshuffledirect", upscale=1),
+        lambda: DAT(upsampler="pixelshuffledirect", upscale=2),
+        lambda: DAT(upsampler="pixelshuffledirect", upscale=3),
+        lambda: DAT(upsampler="pixelshuffledirect", upscale=4),
+        lambda: DAT(upsampler="pixelshuffledirect", upscale=8),
+        lambda: DAT(resi_connection="3conv"),
+        lambda: DAT(qkv_bias=False),
+        lambda: DAT(split_size=[4, 4]),
+        lambda: DAT(split_size=[2, 8]),
+        condition=lambda a, b: (
+            a.num_layers == b.num_layers
+            and a.upscale == b.upscale
+            and a.upsampler == b.upsampler
+            and a.embed_dim == b.embed_dim
+            and a.num_features == b.num_features
+        ),
+    )
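For context, exercising the new detection from user code would look roughly like the sketch below; the checkpoint path is made up and the ModelLoader entry point is assumed from spandrel's public API rather than from this diff:

import torch
from spandrel import ModelLoader

state_dict = torch.load("4x_dat.pth", map_location="cpu")  # hypothetical checkpoint
descriptor = ModelLoader().load_from_state_dict(state_dict)

# These fields are filled in by the detection logic in DAT/__init__.py.
print(descriptor.scale, descriptor.input_channels, descriptor.output_channels)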
