Commit 2e1a538: Improved HAT parameter detection (#45)

RunDevelopment authored Nov 23, 2023
1 parent c7724a3 commit 2e1a538

Showing 7 changed files with 242 additions and 142 deletions.
src/spandrel/architectures/HAT/__init__.py
238 changes: 128 additions & 110 deletions

@@ -1,166 +1,184 @@ (updated file shown, without diff markers)

import math

from torch import nn

from ...__helpers.model_descriptor import SizeRequirements, SRModelDescriptor, StateDict
from ..__arch_helpers.state import get_seq_len
from .arch.HAT import HAT


def _get_overlap_ratio(window_size: int, with_overlap: int) -> float:
    # What we know:
    #   with_overlap = int(window_size + window_size * overlap_ratio)
    #
    # The issue is that this relationship doesn't uniquely define overlap_ratio.
    # E.g. for window_size=7, overlap_ratio=0.5 and overlap_ratio=0.51 both result
    # in with_overlap=10. So to get "nice" ratios, we will first try out "nice"
    # numbers before falling back to the general formula.

    nice_numbers = [0, 1, 0.5, 0.25, 0.75, 0.1, 0.2, 0.3, 0.4, 0.6, 0.7, 0.8, 0.9]
    for ratio in nice_numbers:
        if int(window_size + window_size * ratio) == with_overlap:
            return ratio

    # Calculate the ratio and add a little something to account for rounding errors.
    return (with_overlap - window_size) / window_size + 0.01
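
For illustration (not part of the commit), the ambiguity the comment describes: with window_size=7, the ratios 0.5, 0.51, and 0.52 all round to the same overlapping window size, and the nice-number search settles on 0.5.

assert [int(7 + 7 * r) for r in (0.5, 0.51, 0.52)] == [10, 10, 10]
assert _get_overlap_ratio(7, with_overlap=10) == 0.5  # the "nice" candidate wins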


def _inv_int_div(a: int, c: int) -> float:
    """
    Returns a number `b` such that `a // b == c`.
    """
    b_float = a / c

    if b_float.is_integer():
        return int(b_float)
    if c == a // math.ceil(b_float):
        return math.ceil(b_float)
    if c == a // math.floor(b_float):
        return math.floor(b_float)

    # account for rounding errors
    if c == a // b_float:
        return b_float
    if c == a // (b_float - 0.01):
        return b_float - 0.01
    if c == a // (b_float + 0.01):
        return b_float + 0.01

    raise ValueError(f"Could not find a number b such that a // b == c. a={a}, c={c}")
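
A couple of illustrative checks (not part of the commit). Note that such a `b` is generally not unique; the function returns one valid divisor, not necessarily the exact value the checkpoint was trained with.

assert _inv_int_div(96, 32) == 3   # exact division: 96 // 3 == 32
assert _inv_int_div(100, 33) == 3  # 100 / 33 is about 3.03; floor works: 100 // 3 == 33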


def load(state_dict: StateDict) -> SRModelDescriptor[HAT]:
    # Defaults
    img_size = 64
    patch_size = 1
    in_chans = 3
    embed_dim = 96
    depths = (6, 6, 6, 6)
    num_heads = (6, 6, 6, 6)
    window_size = 7
    compress_ratio = 3
    squeeze_factor = 30
    conv_scale = 0.01  # cannot be deduced from state dict
    overlap_ratio = 0.5
    mlp_ratio = 4.0
    qkv_bias = True
    qk_scale = None  # cannot be deduced from state dict
    drop_rate = 0.0  # cannot be deduced from state dict
    attn_drop_rate = 0.0  # cannot be deduced from state dict
    drop_path_rate = 0.1  # cannot be deduced from state dict
    norm_layer = nn.LayerNorm
    ape = False
    patch_norm = True
    use_checkpoint = False
    upscale = 2
    img_range = 1.0  # cannot be deduced from state dict
    upsampler = ""
    resi_connection = "1conv"
    num_feat = 64

    in_chans = state_dict["conv_first.weight"].shape[1]
    embed_dim = state_dict["conv_first.weight"].shape[0]

if "conv_before_upsample.0.weight" in state_keys:
if "conv_up1.weight" in state_keys:
upsampler = "nearest+conv"
else:
upsampler = "pixelshuffle"
elif "upsample.0.weight" in state_keys:
upsampler = "pixelshuffledirect"
if "conv_last.weight" in state_dict:
# upscaling model
upsampler = "pixelshuffle"
num_feat = state_dict["conv_last.weight"].shape[1]

upscale = 1
for i in range(0, get_seq_len(state_dict, "upsample"), 2):
shape = state_dict[f"upsample.{i}.weight"].shape[0]
upscale *= int(math.sqrt(shape // num_feat))
else:
# 1x model
upsampler = ""
upscale = 1
if upsampler == "nearest+conv":
upsample_keys = [x for x in state_keys if "conv_up" in x and "bias" not in x]

for upsample_key in upsample_keys:
upscale *= 2
elif upsampler == "pixelshuffle":
upsample_keys = [
x
for x in state_keys
if "upsample" in x and "conv" not in x and "bias" not in x
]
for upsample_key in upsample_keys:
shape = state[upsample_key].shape[0]
upscale *= math.sqrt(shape // num_feat)
upscale = int(upscale)
elif upsampler == "pixelshuffledirect":
upscale = int(math.sqrt(state["upsample.0.bias"].shape[0] // num_out_ch))

max_layer_num = 0
max_block_num = 0
for key in state_keys:
result = re.match(
r"layers.(\d*).residual_group.blocks.(\d*).conv_block.cab.0.weight", key
)
if result:
layer_num, block_num = result.groups()
max_layer_num = max(max_layer_num, int(layer_num))
max_block_num = max(max_block_num, int(block_num))

depths = [max_block_num + 1 for _ in range(max_layer_num + 1)]

if (
"layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
in state_keys
):
num_heads_num = state[
"layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
].shape[-1]
num_heads = [num_heads_num for _ in range(max_layer_num + 1)]
else:
num_heads = depths
num_feat = 0 # only used for upscaling

upscale = 1

mlp_ratio = float(
state["layers.0.residual_group.blocks.0.mlp.fc1.bias"].shape[0] / embed_dim
    window_size = int(math.sqrt(state_dict["relative_position_index_SA"].shape[0]))
    overlap_ratio = _get_overlap_ratio(
        window_size,
        with_overlap=int(math.sqrt(state_dict["relative_position_index_OCA"].shape[1])),
    )
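
Illustrative values (not part of the commit), assuming the usual HAT buffers where relative_position_index_SA has window_size**2 rows and relative_position_index_OCA has with_overlap**2 columns: a checkpoint with window_size=16 and overlap_ratio=0.5 stores 16**2 = 256 rows and (16 + 16 * 0.5)**2 = 576 columns, from which both parameters are recovered.

import math

assert int(math.sqrt(256)) == 16  # window_size
assert int(math.sqrt(576)) == 24  # with_overlap
assert _get_overlap_ratio(16, with_overlap=24) == 0.5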

    # num_layers = len(depths)
    num_layers = get_seq_len(state_dict, "layers")
    depths = [
        get_seq_len(state_dict, f"layers.{i}.residual_group.blocks")
        for i in range(num_layers)
    ]
    num_heads = [
        state_dict[
            f"layers.{i}.residual_group.overlap_attn.relative_position_bias_table"
        ].shape[1]
        for i in range(num_layers)
    ]

if "conv_after_body.weight" in state_dict:
resi_connection = "1conv"
else:
# There is no way to decide whether it's "identity" or something else.
# So we just assume it's identity.
resi_connection = "identity"

    compress_ratio = _inv_int_div(
        embed_dim,
        state_dict["layers.0.residual_group.blocks.0.conv_block.cab.0.weight"].shape[0],
    )
    squeeze_factor = _inv_int_div(
        embed_dim,
        state_dict[
            "layers.0.residual_group.blocks.0.conv_block.cab.3.attention.1.weight"
        ].shape[0],
    )

    qkv_bias = "layers.0.residual_group.blocks.0.attn.qkv.bias" in state_dict
    patch_norm = "patch_embed.norm.weight" in state_dict
    ape = "absolute_pos_embed" in state_dict

    # mlp_hidden_dim = int(embed_dim * mlp_ratio)
    mlp_hidden_dim = int(
        state_dict["layers.0.residual_group.blocks.0.mlp.fc1.weight"].shape[0]
    )
    mlp_ratio = mlp_hidden_dim / embed_dim
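
For example (hypothetical shapes, not part of the commit): an fc1.weight of shape (360, 180) gives mlp_hidden_dim = 360, so mlp_ratio = 360 / 180 = 2.0.

embed_dim = 180
mlp_hidden_dim = 360  # fc1.weight.shape[0] (hypothetical)
assert mlp_hidden_dim / embed_dim == 2.0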

    # img_size and patch_size are linked to each other and not always stored in
    # the state dict. If it isn't stored, then there is no way to deduce it.
    if "absolute_pos_embed" in state_dict:
        # patches_resolution = img_size // patch_size
        # num_patches = patches_resolution ** 2
        num_patches = state_dict["absolute_pos_embed"].shape[1]
        patches_resolution = int(math.sqrt(num_patches))
        # we'll just assume that the patch size is 1
        patch_size = 1
        img_size = patches_resolution
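
Illustrative (not part of the commit): a model trained with ape=True at img_size=64 and patch_size=1 stores absolute_pos_embed with shape (1, 4096, embed_dim), and the square root of 4096 recovers the 64x64 training resolution.

import math

num_patches = 4096  # absolute_pos_embed.shape[1] (hypothetical)
assert int(math.sqrt(num_patches)) == 64  # img_size, assuming patch_size == 1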

    model = HAT(
        img_size=img_size,
        patch_size=patch_size,
        in_chans=in_chans,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=window_size,
        compress_ratio=compress_ratio,
        squeeze_factor=squeeze_factor,
        conv_scale=conv_scale,
        overlap_ratio=overlap_ratio,
        mlp_ratio=mlp_ratio,
        qkv_bias=qkv_bias,
        qk_scale=qk_scale,
        drop_rate=drop_rate,
        attn_drop_rate=attn_drop_rate,
        drop_path_rate=drop_path_rate,
        norm_layer=norm_layer,
        ape=ape,
        patch_norm=patch_norm,
        use_checkpoint=use_checkpoint,
        upscale=upscale,
        img_range=img_range,
        upsampler=upsampler,
        resi_connection=resi_connection,
        num_feat=num_feat,
    )

    if len(depths) < 9:
        size_tag = "small" if compress_ratio > 4 else "medium"
    else:
        size_tag = "large"
    tags = [

@@ -178,8 +196,8 @@ def load(state_dict: StateDict) -> SRModelDescriptor[HAT]:

        tags=tags,
        supports_half=False,
        supports_bfloat16=True,
        scale=upscale,
        input_channels=in_chans,
        output_channels=in_chans,
        size_requirements=SizeRequirements(minimum=16),
    )