Define the public API by what's documented #287

Merged · 6 commits · Jul 11, 2024
Empty file.
7 changes: 4 additions & 3 deletions libs/spandrel/spandrel/architectures/ATD/__arch/atd_arch.py
@@ -881,9 +881,10 @@ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):

@store_hyperparameters()
class ATD(nn.Module):
r"""ATD
A PyTorch impl of : `Transcending the Limit of Local Window: Advanced Super-Resolution Transformer
with Adaptive Token Dictionary`.
r"""
ATD

A PyTorch impl of : `Transcending the Limit of Local Window: Advanced Super-Resolution Transformer with Adaptive Token Dictionary`.

Args:
img_size (int | tuple(int)): Input image size. Default 64
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/ATD/__init__.py
@@ -168,3 +168,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[ATD]:
output_channels=in_chans,
size_requirements=SizeRequirements(minimum=8),
)


__all__ = ["ATDArch", "ATD"]
4 changes: 3 additions & 1 deletion libs/spandrel/spandrel/architectures/CRAFT/__arch/CRAFT.py
@@ -103,8 +103,10 @@ def forward(self, biases):


class Attention_regular(nn.Module):
"""Regular Rectangle-Window (regular-Rwin) self-attention with dynamic relative position bias.
"""
Regular Rectangle-Window (regular-Rwin) self-attention with dynamic relative position bias.
It supports both of shifted and non-shifted window.

Args:
dim (int): Number of input channels.
resolution (int): Input resolution.
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/CRAFT/__init__.py
@@ -126,3 +126,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[CRAFT]:
output_channels=in_chans,
size_requirements=SizeRequirements(minimum=16, multiple_of=16),
)


__all__ = ["CRAFTArch", "CRAFT"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/Compact/__init__.py
@@ -52,3 +52,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[Compact]:
input_channels=in_nc,
output_channels=out_nc,
)


__all__ = ["CompactArch", "Compact"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/DAT/__init__.py
@@ -177,3 +177,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DAT]:
output_channels=in_chans,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["DATArch", "DAT"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/DCTLSA/__init__.py
@@ -82,3 +82,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DCTLSA]:
output_channels=out_nc,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["DCTLSAArch", "DCTLSA"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/DITN/__init__.py
@@ -85,3 +85,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DITN]:
output_channels=3, # hard-coded in the architecture
size_requirements=SizeRequirements(multiple_of=patch_size),
)


__all__ = ["DITNArch", "DITN"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/DRCT/__init__.py
@@ -155,3 +155,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DRCT]:
output_channels=in_chans,
size_requirements=SizeRequirements(multiple_of=16),
)


__all__ = ["DRCTArch", "DRCT"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/DRUNet/__init__.py
@@ -99,3 +99,6 @@ def call(model: DRUNet, image: torch.Tensor) -> torch.Tensor:
size_requirements=SizeRequirements(multiple_of=8),
call_fn=call,
)


__all__ = ["DRUNetArch", "DRUNet"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/DnCNN/__init__.py
@@ -125,3 +125,6 @@ def call(model: DnCNN, image: torch.Tensor) -> torch.Tensor:
size_requirements=SizeRequirements(),
call_fn=call,
)


__all__ = ["DnCNNArch", "DnCNN"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/ESRGAN/__init__.py
@@ -233,3 +233,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[ESRGAN]:
multiple_of=4 if shuffle_factor else 1,
),
)


__all__ = ["ESRGANArch", "ESRGAN"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/FBCNN/__init__.py
@@ -79,3 +79,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[FBCNN]:
output_channels=out_nc,
call_fn=lambda model, image: model(image)[0],
)


__all__ = ["FBCNNArch", "FBCNN"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/FFTformer/__init__.py
@@ -98,3 +98,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[FFTformer]:
output_channels=out_channels,
size_requirements=SizeRequirements(multiple_of=32),
)


__all__ = ["FFTformerArch", "FFTformer"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/GFPGAN/__init__.py
@@ -61,3 +61,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[GFPGAN]:
size_requirements=SizeRequirements(minimum=512),
call_fn=lambda model, image: model(image)[0],
)


__all__ = ["GFPGANArch", "GFPGAN"]
Empty file.
29 changes: 18 additions & 11 deletions libs/spandrel/spandrel/architectures/GRL/__arch/grl.py
@@ -34,7 +34,9 @@


class TransformerStage(nn.Module):
"""Transformer stage.
"""
Transformer stage.

Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
@@ -58,11 +60,13 @@ class TransformerStage(nn.Module):
pretrained_stripe_size (list[int]): pretrained stripe size. This is actually not used. Default: [0, 0].
conv_type: The convolutional block before residual connection.
init_method: initialization method of the weight parameters used to train large scale models.
Choices: n, normal -- Swin V1 init method.
l, layernorm -- Swin V2 init method. Zero the weight and bias in the post layer normalization layer.
r, res_rescale -- EDSR rescale method. Rescale the residual blocks with a scaling factor 0.1
w, weight_rescale -- MSRResNet rescale method. Rescale the weight parameter in residual blocks with a scaling factor 0.1
t, trunc_normal_ -- nn.Linear, trunc_normal; nn.Conv2d, weight_rescale

Choices:
* n, normal -- Swin V1 init method.
* l, layernorm -- Swin V2 init method. Zero the weight and bias in the post layer normalization layer.
* r, res_rescale -- EDSR rescale method. Rescale the residual blocks with a scaling factor 0.1
* w, weight_rescale -- MSRResNet rescale method. Rescale the weight parameter in residual blocks with a scaling factor 0.1
* t, `trunc_normal_` -- nn.Linear, trunc_normal, nn.Conv2d, weight_rescale
fairscale_checkpoint (bool): Whether to use fairscale checkpoint.
offload_to_cpu (bool): used by fairscale_checkpoint
args:
@@ -185,6 +189,7 @@ def flops(self):
@store_hyperparameters()
class GRL(nn.Module):
r"""Image restoration transformer with global, non-local, and local connections

Args:
img_size (int | list[int]): Input image size. Default 64
in_channels (int): Number of input image channels. Default: 3
@@ -216,11 +221,13 @@ class GRL(nn.Module):
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
conv_type (str): The convolutional block before residual connection. Default: 1conv. Choices: 1conv, 3conv, 1conv1x1, linear
init_method: initialization method of the weight parameters used to train large scale models.
Choices: n, normal -- Swin V1 init method.
l, layernorm -- Swin V2 init method. Zero the weight and bias in the post layer normalization layer.
r, res_rescale -- EDSR rescale method. Rescale the residual blocks with a scaling factor 0.1
w, weight_rescale -- MSRResNet rescale method. Rescale the weight parameter in residual blocks with a scaling factor 0.1
t, trunc_normal_ -- nn.Linear, trunc_normal; nn.Conv2d, weight_rescale

Choices:
* n, normal -- Swin V1 init method.
* l, layernorm -- Swin V2 init method. Zero the weight and bias in the post layer normalization layer.
* r, res_rescale -- EDSR rescale method. Rescale the residual blocks with a scaling factor 0.1
* w, weight_rescale -- MSRResNet rescale method. Rescale the weight parameter in residual blocks with a scaling factor 0.1
* t, `trunc_normal_` -- nn.Linear, trunc_normal, nn.Conv2d, weight_rescale
fairscale_checkpoint (bool): Whether to use fairscale checkpoint.
offload_to_cpu (bool): used by fairscale_checkpoint
euclidean_dist (bool): use Euclidean distance or inner product as the similarity metric. An ablation study.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/GRL/__init__.py
@@ -359,3 +359,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[GRL]:
input_channels=in_channels,
output_channels=out_channels,
)


__all__ = ["GRLArch", "GRL"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/HAT/__init__.py
@@ -225,3 +225,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[HAT]:
output_channels=in_chans,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["HATArch", "HAT"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/HVICIDNet/__init__.py
@@ -92,3 +92,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[HVICIDNet]:
size_requirements=SizeRequirements(multiple_of=8),
tiling=ModelTiling.DISCOURAGED,
)


__all__ = ["HVICIDNetArch", "HVICIDNet"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/IPT/__init__.py
@@ -152,3 +152,6 @@ def call(model: IPT, x: torch.Tensor):
size_requirements=SizeRequirements(minimum=patch_size),
call_fn=call,
)


__all__ = ["IPTArch", "IPT"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/KBNet/__init__.py
@@ -166,3 +166,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_KBNet]:
return self._load_l(state_dict)
else:
return self._load_s(state_dict)


__all__ = ["KBNetArch", "KBNet_s", "KBNet_l"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/LaMa/__init__.py
@@ -53,3 +53,6 @@ def load(self, state_dict: StateDict) -> MaskedImageModelDescriptor[LaMa]:
output_channels=out_nc,
size_requirements=SizeRequirements(minimum=16, multiple_of=8),
)


__all__ = ["LaMaArch", "LaMa"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/MMRealSR/__init__.py
@@ -193,3 +193,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[MMRealSR]:
size_requirements=SizeRequirements(minimum=16),
call_fn=lambda model, image: model(image)[0],
)


__all__ = ["MMRealSRArch", "MMRealSR"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/MixDehazeNet/__init__.py
@@ -112,3 +112,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[MixDehazeNet]:
tiling=ModelTiling.DISCOURAGED,
call_fn=lambda model, image: model(image) * 0.5 + 0.5,
)


__all__ = ["MixDehazeNetArch", "MixDehazeNet"]
20 changes: 10 additions & 10 deletions libs/spandrel/spandrel/architectures/NAFNet/__arch/NAFNet_arch.py
@@ -2,16 +2,16 @@
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------

"""
Simple Baselines for Image Restoration

@article{chen2022simple,
title={Simple Baselines for Image Restoration},
author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
journal={arXiv preprint arXiv:2204.04676},
year={2022}
}
"""
# """
# Simple Baselines for Image Restoration

# @article{chen2022simple,
# title={Simple Baselines for Image Restoration},
# author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
# journal={arXiv preprint arXiv:2204.04676},
# year={2022}
# }
# """

from __future__ import annotations

Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/NAFNet/__init__.py
@@ -71,3 +71,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[NAFNet]:
input_channels=img_channel,
output_channels=img_channel,
)


__all__ = ["NAFNetArch", "NAFNet"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/OmniSR/__init__.py
@@ -96,3 +96,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[OmniSR]:
output_channels=num_out_ch,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["OmniSRArch", "OmniSR"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/PLKSR/__init__.py
@@ -143,3 +143,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]:
input_channels=3,
output_channels=3,
)


__all__ = ["PLKSRArch", "PLKSR", "RealPLKSR"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/RGT/__init__.py
@@ -169,3 +169,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[RGT]:
output_channels=in_chans,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["RGTArch", "RGT"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/RealCUGAN/__init__.py
@@ -104,3 +104,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_RealCUGAN]:
output_channels=out_channels,
size_requirements=size_requirements,
)


__all__ = ["RealCUGANArch", "UpCunet2x", "UpCunet3x", "UpCunet4x", "UpCunet2x_fast"]
Empty file.
@@ -10,14 +10,14 @@

class VectorQuantizer(nn.Module):
"""
see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
____________________________________________
Discretization bottleneck part of the VQ-VAE.
Inputs:
- n_e : number of embeddings
- e_dim : dimension of embedding
- beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
_____________________________________________

see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py

Args:
n_e : number of embeddings
e_dim : dimension of embedding
beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
"""

def __init__(self, n_e, e_dim, beta):
@@ -101,3 +101,6 @@ def call(model: RestoreFormer, x: torch.Tensor) -> torch.Tensor:
size_requirements=SizeRequirements(multiple_of=32),
call_fn=call,
)


__all__ = ["RestoreFormerArch", "RestoreFormer"]
Empty file.
@@ -105,3 +105,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[RetinexFormer]:
tiling=ModelTiling.DISCOURAGED,
call_fn=_call_fn,
)


__all__ = ["RetinexFormerArch", "RetinexFormer"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/SAFMN/__init__.py
@@ -67,3 +67,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SAFMN]:
output_channels=3, # hard-coded in the arch
size_requirements=SizeRequirements(multiple_of=8),
)


__all__ = ["SAFMNArch", "SAFMN"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/SAFMNBCIE/__init__.py
@@ -81,3 +81,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SAFMNBCIE]:
output_channels=3, # hard-coded in the arch
size_requirements=SizeRequirements(multiple_of=16),
)


__all__ = ["SAFMNBCIEArch", "SAFMNBCIE"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/SCUNet/__init__.py
@@ -63,3 +63,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SCUNet]:
size_requirements=SizeRequirements(minimum=40),
tiling=ModelTiling.DISCOURAGED,
)


__all__ = ["SCUNetArch", "SCUNet"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/SPAN/__init__.py
@@ -71,3 +71,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SPAN]:
input_channels=num_in_ch,
output_channels=num_out_ch,
)


__all__ = ["SPANArch", "SPAN"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/SwiftSRGAN/__init__.py
@@ -52,3 +52,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SwiftSRGAN]:
input_channels=in_channels,
output_channels=in_channels,
)


__all__ = ["SwiftSRGANArch", "SwiftSRGAN"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/Swin2SR/__init__.py
@@ -184,3 +184,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[Swin2SR]:
output_channels=in_chans,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["Swin2SRArch", "Swin2SR"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/SwinIR/__init__.py
@@ -189,3 +189,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SwinIR]:
output_channels=out_nc,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["SwinIRArch", "SwinIR"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/Uformer/__init__.py
@@ -141,3 +141,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]:
output_channels=dd_in,
size_requirements=SizeRequirements(multiple_of=128, square=True),
)


__all__ = ["UformerArch", "Uformer"]
2 changes: 2 additions & 0 deletions libs/spandrel/spandrel/architectures/__init__.py
@@ -1,3 +1,5 @@
"""
The package containing the implementations of all supported architectures. Not necessary for most user code.
"""

__docformat__ = "google"
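The module-level __docformat__ tells documentation tooling that honors it (pdoc, for example, reads this attribute) to parse docstrings as Google style, which is the layout the reflowed docstrings earlier in this diff follow. A hedged illustration of that shape, not taken from the diff:

# Illustration only: the Google-style docstring layout that __docformat__ = "google" announces.
def upscale(image, scale=4):
    """Upscale an image tensor by an integer factor.

    Args:
        image: Input tensor of shape (N, C, H, W).
        scale (int): Upscaling factor. Default: 4.

    Returns:
        A tensor of shape (N, C, H * scale, W * scale).
    """
    raise NotImplementedError  # placeholder body for the illustration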