Skip to content

Commit

Permalink
Merge branch 'main' into rename-size
Browse files Browse the repository at this point in the history
  • Loading branch information
joeyballentine committed Nov 18, 2023
2 parents 9803f30 + 35dae6f commit 175a566
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/spandrel/__helpers/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def load_state_dict_from_file(self, path: str | Path) -> StateDict:
"""
Load the state dict of a model from the given file path.
State dicts are typically only useful to pass them into the `load` function of a specific architecture.
State dicts are typically only useful to pass them into the `load`
function of a specific architecture.
Throws a `ValueError` if the file extension is not supported.
"""
Expand Down
66 changes: 66 additions & 0 deletions src/spandrel/__helpers/model_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,29 @@ class SizeRequirements:
def none(self) -> bool:
"""
Returns True if no size requirements are specified.
If True, then `check` is guaranteed to always return True.
"""
return self.minimum is None and self.multiple_of is None and not self.square

def check(self, width: int, height: int) -> bool:
"""
Checks if the given width and height satisfy the size requirements.
"""
if self.minimum is not None:
if width < self.minimum or height < self.minimum:
return False

if self.multiple_of is not None:
if width % self.multiple_of != 0 or height % self.multiple_of != 0:
return False

if self.square:
if width != height:
return False

return True


class ModelBase(ABC, Generic[T]):
def __init__(
Expand All @@ -49,19 +69,58 @@ def __init__(
size_requirements: SizeRequirements | None = None,
):
self.model: T = model
"""
The model itself: a `torch.nn.Module` with weights loaded in.
The specific subclass of `torch.nn.Module` depends on the model architecture.
"""
self.state_dict: StateDict = state_dict
"""
The state dict of the model (weights and biases).
"""
self.architecture: str = architecture
"""
The name of the model architecture. E.g. "ESRGAN".
"""
self.tags: list[str] = tags
"""
A list of tags for the model, usually describing the size or model
parameters. E.g. "64nf" or "large".
Tags are specific to the architecture of the model. Some architectures
may not have any tags.
"""
self.supports_half: bool = supports_half
"""
Whether the model supports half precision (fp16).
"""
self.supports_bfloat16: bool = supports_bfloat16
"""
Whether the model supports bfloat16 precision.
"""

self.scale: int = scale
"""
The output scale of super resolution models. E.g. 4x, 2x, 1x.
Models that are not super resolution models (e.g. denoisers) have a
scale of 1.
"""
self.input_channels: int = input_channels
"""
The number of input image channels of the model. E.g. 3 for RGB, 1 for grayscale.
"""
self.output_channels: int = output_channels
"""
The number of output image channels of the model. E.g. 3 for RGB, 1 for grayscale.
"""

self.size_requirements: SizeRequirements = (
size_requirements or SizeRequirements()
)
"""
Size requirements for the input image. E.g. minimum size.
"""

self.model.load_state_dict(state_dict) # type: ignore

Expand Down Expand Up @@ -138,3 +197,10 @@ def __init__(
InpaintModelDescriptor,
RestorationModelDescriptor,
]
"""
A model descriptor is a loaded model with metadata. Metadata includes the
architecture, purpose, tags, and other information about the model.
The purpose of a model is described by the type of the model descriptor. E.g.
a super resolution model has a descriptor of type `SRModelDescriptor`.
"""
2 changes: 2 additions & 0 deletions src/spandrel/__helpers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ def visit(arch: ArchSupport):
def load(self, state_dict: StateDict) -> ModelDescriptor:
"""
Detects the architecture of the given state dict and loads it.
Throws an `UnsupportedModelError` if the model architecture is not supported.
"""

if "params_ema" in state_dict:
Expand Down

0 comments on commit 175a566

Please sign in to comment.