Skip to content

Commit

Permalink
load_spandrel_model: make half prefer_half
Browse files Browse the repository at this point in the history
As discussed with the Spandrel folks, it's good to heed Spandrel's
"supports half precision" flag to avoid e.g. black blotches and what-not.
  • Loading branch information
akx committed Jan 2, 2024
1 parent 51f1cca commit 2cacbc1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
20 changes: 14 additions & 6 deletions modules/modelloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,23 +139,31 @@ def load_upscalers():


def load_spandrel_model(
path: str,
path: str | os.PathLike,
*,
device: str | torch.device | None,
half: bool = False,
prefer_half: bool = False,
dtype: str | torch.dtype | None = None,
expected_architecture: str | None = None,
) -> spandrel.ModelDescriptor:
import spandrel
model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path)
model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))
if expected_architecture and model_descriptor.architecture != expected_architecture:
logger.warning(
f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
)
if half:
model_descriptor.model.half()
half = False
if prefer_half:
if model_descriptor.supports_half:
model_descriptor.model.half()
half = True
else:
logger.info("Model %s does not support half precision, ignoring --half", path)
if dtype:
model_descriptor.model.to(dtype=dtype)
model_descriptor.model.eval()
logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype)
logger.debug(
"Loaded %s from %s (device=%s, half=%s, dtype=%s)",
model_descriptor, path, device, half, dtype,
)
return model_descriptor
2 changes: 1 addition & 1 deletion modules/realesrgan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def do_upscale(self, img, path):
model_descriptor = modelloader.load_spandrel_model(
info.local_data_path,
device=self.device,
half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel
)
return upscale_with_model(
Expand Down

0 comments on commit 2cacbc1

Please sign in to comment.