Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deduplicate tiled inference code from SwinIR/ScuNET #14476

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 11 additions & 44 deletions extensions-builtin/ScuNET/scripts/scunet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
import PIL.Image
import numpy as np
import torch
from tqdm import tqdm

import modules.upscaler
from modules import devices, modelloader, script_callbacks, errors

from modules.shared import opts
from modules.upscaler_utils import tiled_upscale_2


class UpscalerScuNET(modules.upscaler.Upscaler):
Expand Down Expand Up @@ -40,47 +39,6 @@ def __init__(self, dirname):
scalers.append(scaler_data2)
self.scalers = scalers

@staticmethod
@torch.no_grad()
def tiled_inference(img, model):
# test the image tile by tile
h, w = img.shape[2:]
tile = opts.SCUNET_tile
tile_overlap = opts.SCUNET_tile_overlap
if tile == 0:
return model(img)

device = devices.get_device_for('scunet')
assert tile % 8 == 0, "tile size should be a multiple of window_size"
sf = 1

stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
W = torch.zeros_like(E, dtype=devices.dtype, device=device)

with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
for h_idx in h_idx_list:

for w_idx in w_idx_list:

in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]

out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)

E[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch)
W[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch_mask)
pbar.update(1)
output = E.div_(W)

return output

def do_upscale(self, img: PIL.Image.Image, selected_file):

devices.torch_gc()
Expand All @@ -104,7 +62,16 @@ def do_upscale(self, img: PIL.Image.Image, selected_file):
_img[:, :, :h, :w] = torch_img # pad image
torch_img = _img

torch_output = self.tiled_inference(torch_img, model).squeeze(0)
with torch.no_grad():
torch_output = tiled_upscale_2(
torch_img,
model,
tile_size=opts.SCUNET_tile,
tile_overlap=opts.SCUNET_tile_overlap,
scale=1,
device=devices.get_device_for('scunet'),
desc="ScuNET tiles",
).squeeze(0)
torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
del torch_img, torch_output
Expand Down
57 changes: 5 additions & 52 deletions extensions-builtin/SwinIR/scripts/swinir_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from modules import modelloader, devices, script_callbacks, shared
from modules.shared import opts, state
from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import tiled_upscale_2

SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"

Expand Down Expand Up @@ -110,14 +110,14 @@ def upscale(
w_pad = (w_old // window_size + 1) * window_size - w_old
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
output = inference(
output = tiled_upscale_2(
img,
model,
tile=tile,
tile_size=tile,
tile_overlap=tile_overlap,
window_size=window_size,
scale=scale,
device=device,
desc="SwinIR tiles",
)
output = output[..., : h_old * scale, : w_old * scale]
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
Expand All @@ -129,53 +129,6 @@ def upscale(
return Image.fromarray(output, "RGB")


def inference(
img,
model,
*,
tile: int,
tile_overlap: int,
window_size: int,
scale: int,
device,
):
# test the image tile by tile
b, c, h, w = img.size()
tile = min(tile, h, w)
assert tile % window_size == 0, "tile size should be a multiple of window_size"
sf = scale

stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device).type_as(img)
W = torch.zeros_like(E, dtype=devices.dtype, device=device)

with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list:
if state.interrupted or state.skipped:
break

for w_idx in w_idx_list:
if state.interrupted or state.skipped:
break

in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)

E[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch)
W[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch_mask)
pbar.update(1)
output = E.div_(W)

return output


def on_ui_settings():
import gradio as gr

Expand Down
72 changes: 71 additions & 1 deletion modules/upscaler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import tqdm
from PIL import Image

from modules import images
from modules import images, shared

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -68,3 +68,73 @@ def upscale_with_model(
overlap=grid.overlap * scale_factor,
)
return images.combine_grid(newgrid)


def tiled_upscale_2(
img,
model,
*,
tile_size: int,
tile_overlap: int,
scale: int,
device,
desc="Tiled upscale",
):
# Alternative implementation of `upscale_with_model` originally used by
# SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and
# weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
# Pillow space without weighting.
b, c, h, w = img.size()
tile_size = min(tile_size, h, w)

if tile_size <= 0:
logger.debug("Upscaling %s without tiling", img.shape)
return model(img)

stride = tile_size - tile_overlap
h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size]
w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size]
result = torch.zeros(
b,
c,
h * scale,
w * scale,
device=device,
).type_as(img)
weights = torch.zeros_like(result)
logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)
with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar:
for h_idx in h_idx_list:
if shared.state.interrupted or shared.state.skipped:
break

for w_idx in w_idx_list:
if shared.state.interrupted or shared.state.skipped:
break

in_patch = img[
...,
h_idx : h_idx + tile_size,
w_idx : w_idx + tile_size,
]
out_patch = model(in_patch)

result[
...,
h_idx * scale : (h_idx + tile_size) * scale,
w_idx * scale : (w_idx + tile_size) * scale,
].add_(out_patch)

out_patch_mask = torch.ones_like(out_patch)

weights[
...,
h_idx * scale : (h_idx + tile_size) * scale,
w_idx * scale : (w_idx + tile_size) * scale,
].add_(out_patch_mask)

pbar.update(1)

output = result.div_(weights)

return output