Skip to content

Commit

Permalink
Fixed upscale
Browse files Browse the repository at this point in the history
  • Loading branch information
RunDevelopment committed Nov 23, 2023
1 parent 8152764 commit 7fbc3ad
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/spandrel/architectures/SwiftSRGAN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def load(state: StateDict) -> SRModelDescriptor[SwiftSRGAN]:
in_channels = state["initial.cnn.depthwise.weight"].shape[0]
num_channels = state["initial.cnn.pointwise.weight"].shape[0]
num_blocks = get_seq_len(state, "residual")
upscale_factor = get_seq_len(state, "upsampler") * 2
upscale_factor = 2 ** get_seq_len(state, "upsampler")

model = SwiftSRGAN(
in_channels=in_channels,
Expand Down
4 changes: 3 additions & 1 deletion src/spandrel/architectures/SwiftSRGAN/arch/SwiftSRGAN.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# type: ignore
# From https://github.com/Koushik0901/Swift-SRGAN/blob/master/swift-srgan/models.py

import math

import torch
from torch import nn

Expand Down Expand Up @@ -124,7 +126,7 @@ def __init__(
self.upsampler = nn.Sequential(
*[
UpsampleBlock(num_channels, scale_factor=2)
for _ in range(upscale_factor // 2)
for _ in range(int(math.log2(upscale_factor)))
]
)
self.final_conv = SeperableConv2d(
Expand Down

0 comments on commit 7fbc3ad

Please sign in to comment.