Skip to content

Commit

Permalink
Use heuristic for undetectable GRLIR params in some cases (#53)
Browse files Browse the repository at this point in the history
  • Loading branch information
RunDevelopment authored Nov 26, 2023
1 parent 5970791 commit 774dc98
Show file tree
Hide file tree
Showing 27 changed files with 227 additions and 418 deletions.
37 changes: 23 additions & 14 deletions src/spandrel/__helpers/main_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,29 @@ def _detect(state_dict: StateDict) -> bool:
),
ArchSupport(
id="GRLIR",
detect=lambda state: _has_keys(
"conv_first.weight",
"norm_start.weight",
"norm_end.weight",
"layers.0.blocks.0.attn.window_attn.attn_transform.logit_scale",
"layers.0.blocks.0.attn.stripe_attn.attn_transform1.logit_scale",
)(state)
or _has_keys(
"model.conv_first.weight",
"model.norm_start.weight",
"model.norm_end.weight",
"model.layers.0.blocks.0.attn.window_attn.attn_transform.logit_scale",
"model.layers.0.blocks.0.attn.stripe_attn.attn_transform1.logit_scale",
)(state),
detect=lambda state: (
_has_keys(
"conv_first.weight",
"norm_start.weight",
"norm_end.weight",
"layers.0.blocks.0.attn.window_attn.attn_transform.logit_scale",
"layers.0.blocks.0.attn.stripe_attn.attn_transform1.logit_scale",
)(state)
or _has_keys(
"model.conv_first.weight",
"model.norm_start.weight",
"model.norm_end.weight",
"model.layers.0.blocks.0.attn.window_attn.attn_transform.logit_scale",
"model.layers.0.blocks.0.attn.stripe_attn.attn_transform1.logit_scale",
)(state)
or _has_keys(
"model_g.conv_first.weight",
"model_g.norm_start.weight",
"model_g.norm_end.weight",
"model_g.layers.0.blocks.0.attn.window_attn.attn_transform.logit_scale",
"model_g.layers.0.blocks.0.attn.stripe_attn.attn_transform1.logit_scale",
)(state)
),
load=GRLIR.load,
),
ArchSupport(
Expand Down
64 changes: 55 additions & 9 deletions src/spandrel/architectures/GRLIR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,23 @@
from .arch.grl import GRL as GRLIR


def _get_output_params(state_dict: StateDict, in_channels: int) -> tuple[int, str, int]:
def _clean_up_checkpoint(state_dict: StateDict) -> StateDict:
    """Normalize the key layout of a GRL checkpoint before parameter detection.

    The official checkpoints do not share a single key layout. Two variants
    are handled here:

    1. Checkpoints whose keys are all prefixed with ``"model."``. The prefix
       is handed to ``remove_common_prefix`` (presumably stripping it when it
       is shared by the keys; confirm against that helper).
    2. Checkpoints (e.g. ``bsr_grl_base.ckpt``) that carry many irrelevant
       entries and store the relevant weights under a ``"model_g."`` prefix.
       Only the ``"model_g."`` entries are kept, then that prefix is passed to
       ``remove_common_prefix`` as well.

    Args:
        state_dict: The raw checkpoint mapping.

    Returns:
        The cleaned-up mapping.
    """
    # Variant 1: everything lives under "model.".
    cleaned = remove_common_prefix(state_dict, ["model."])

    # Variant 2 is recognized by the presence of the prefixed first conv layer.
    if "model_g.conv_first.weight" not in cleaned:
        return cleaned

    # Drop every entry that does not belong to "model_g.".
    generator_only = {
        key: value for key, value in cleaned.items() if key.startswith("model_g.")
    }
    return remove_common_prefix(generator_only, ["model_g."])


def _get_output_params(state_dict: StateDict, in_channels: int):
out_channels: int
upsampler: str
upscale: int
Expand Down Expand Up @@ -43,7 +59,9 @@ def _get_output_params(state_dict: StateDict, in_channels: int) -> tuple[int, st
return out_channels, upsampler, upscale


def _get_anchor_params(state_dict: StateDict) -> tuple[bool, str, int]:
def _get_anchor_params(
state_dict: StateDict, default_down_factor: int
) -> tuple[bool, str, int]:
anchor_one_stage: bool
anchor_proj_type: str
anchor_window_down_factor: int
Expand All @@ -57,11 +75,11 @@ def _get_anchor_params(state_dict: StateDict) -> tuple[bool, str, int]:
# We can deduce neither proj_type nor window_down_factor.
# So we'll just assume the values the official configs use.
anchor_proj_type = "avgpool" # or "maxpool", who knows?
anchor_window_down_factor = 2
anchor_window_down_factor = default_down_factor
else:
anchor_proj_type = "patchmerging"
# window_down_factor is technically undefined here, but 2 makes sense
anchor_window_down_factor = 2
# window_down_factor is undefined here
anchor_window_down_factor = default_down_factor
elif "layers.0.blocks.0.attn.anchor.body.0.weight" in state_dict:
anchor_proj_type = "conv2d"
anchor_window_down_factor = (
Expand All @@ -87,8 +105,7 @@ def _get_anchor_params(state_dict: StateDict) -> tuple[bool, str, int]:


def load(state_dict: StateDict) -> SRModelDescriptor[GRLIR]:
# Delightfully, we have to remove a common prefix from the state_dict.
state_dict = remove_common_prefix(state_dict, ["model."])
state_dict = _clean_up_checkpoint(state_dict)

img_size: int = 64
# in_channels: int = 3
Expand Down Expand Up @@ -161,13 +178,33 @@ def load(state_dict: StateDict) -> SRModelDescriptor[GRLIR]:

# anchor
anchor_one_stage, anchor_proj_type, anchor_window_down_factor = _get_anchor_params(
state_dict
state_dict,
# We use 4 as the default value (if the value cannot be detected), because
# that's what all the official models use.
default_down_factor=4,
)

# other
local_connection = "layers.0.blocks.0.conv.cab.0.weight" in state_dict
mlp_ratio = state_dict["layers.0.blocks.0.mlp.fc1.weight"].shape[0] / embed_dim

# Set undetectable parameters.
# These parameters are a huge pain, because they vary widely between models, so we'll
# just use some heuristics to support the official models, and call it a day.
if upscale == 1:
# denoise (dn), deblur (db), demosaic (dm), or jpeg
pass
else:
# sr or bsr
if upsampler == "nearest+conv":
# bsr
window_size = 16
stripe_size = [32, 64]
else:
# sr
window_size = 32
stripe_size = [64, 64]

model = GRLIR(
img_size=img_size,
in_channels=in_channels,
Expand Down Expand Up @@ -200,11 +237,20 @@ def load(state_dict: StateDict) -> SRModelDescriptor[GRLIR]:
euclidean_dist=euclidean_dist,
)

size_tag = "base"
if len(depths) < 6:
size_tag = "small" if embed_dim >= 96 else "tiny"

return SRModelDescriptor(
model,
state_dict,
architecture="GRLIR",
tags=[],
tags=[
size_tag,
f"{embed_dim}dim",
f"w{window_size}df{anchor_window_down_factor}",
f"s{stripe_size[0]}x{stripe_size[1]}",
],
supports_half=False,
supports_bfloat16=True,
scale=upscale,
Expand Down
207 changes: 4 additions & 203 deletions src/spandrel/architectures/GRLIR/arch/grl.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,13 +391,15 @@ def __init__(
for layer in self.layers:
layer._init_weights()

def set_table_index_mask(self, x_size):
def set_table_index_mask(self, x_size: tuple[int, int]):
"""
Two use cases:
1) At initialization: set the shared buffers.
2) During forward pass: get the new buffers if the resolution of the input changes
"""
# ss - stripe_size, sss - stripe_shift_size
# ss ~= self.stripe_size
# sss ~= self.stripe_size / 2
ss, sss = _get_stripe_info(self.stripe_size, self.stripe_groups, True, x_size)
df = self.anchor_window_down_factor

Expand Down Expand Up @@ -436,7 +438,7 @@ def set_table_index_mask(self, x_size):
"mask_sv_w2a": mask_sv_w2a,
}

def get_table_index_mask(self, device=None, input_resolution=None):
def get_table_index_mask(self, device, input_resolution: tuple[int, int]):
# Used during forward pass
if input_resolution == self.input_resolution:
return {
Expand Down Expand Up @@ -575,204 +577,3 @@ def convert_checkpoint(self, state_dict):
state_dict.pop(k)
print(k)
return state_dict


if __name__ == "__main__":
    # Manual smoke test: build one GRL configuration, print the module, push a
    # random 64x64 RGB batch through it, and report the output shape and the
    # number of trainable parameters.
    window_size = 8

    # Settings shared by every configuration listed in the table below.
    shared_settings = dict(
        upscale=4,
        img_size=64,
        window_size=window_size,
        qkv_proj_type="linear",
        anchor_proj_type="avgpool",
        anchor_window_down_factor=2,
        out_proj_type="linear",
        conv_type="1conv",
    )

    # Reference configurations (documentation only; just the last one is built).
    # Every entry uses the shared settings above. Unless noted otherwise,
    # mlp_ratio=2 and upsampler="pixelshuffle". Both num_heads_window and
    # num_heads_stripe are [heads] * len(depths).
    #
    #   name / size                              depths       embed_dim  heads
    #   ---------------------------------------------------------------------
    #   Tiny, 0.33 M                             [4] * 4      32         2   (upsampler="pixelshuffledirect")
    #   Small, 3.49 M                            [4] * 4      128        2
    #   Base, 13.84 M                            [4] * 8      192        4
    #   Large, 24.29 M                           [8] * 8      192        4
    #   Huge, 47.83 M                            [8] * 16     192        4
    #   Giant, MLP4 - 117.16 M, MLP2 - 83.54 M   [8] * 16     256        4   (mlp_ratio=4)
    #   Compare with HAT Large, 43.22 M          [12] * 10    192        4
    #
    #   Final version
    #   Tiny-final, 0.91M                        [4] * 4      64         2   (upsampler="pixelshuffledirect")
    #   Small-final, 3.49M                       [4] * 4      128        2
    #   Large, 20.13 M (built below)             [4, 4, 8, 8, 8, 4, 4]  180  3  (local_connection=True)

    # Large, 20.13 M
    depths = [4, 4, 8, 8, 8, 4, 4]
    heads_per_stage = [3] * len(depths)
    model = GRL(
        depths=depths,
        embed_dim=180,
        num_heads_window=list(heads_per_stage),
        num_heads_stripe=list(heads_per_stage),
        mlp_ratio=2,
        upsampler="pixelshuffle",
        local_connection=True,
        **shared_settings,
    )

    print(model)
    # print(height, width, model.flops() / 1e9)

    # Forward a single random image and show the resulting tensor shape.
    dummy_input = torch.randn((1, 3, 64, 64))
    output = model(dummy_input)
    print(output.shape)

    # Count only the parameters that would receive gradients.
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of parameters {num_params / 10 ** 6: 0.2f}")
Loading

0 comments on commit 774dc98

Please sign in to comment.