diff --git a/src/spandrel/architectures/Compact/__init__.py b/src/spandrel/architectures/Compact/__init__.py index 5af01c82..1f7c21ce 100644 --- a/src/spandrel/architectures/Compact/__init__.py +++ b/src/spandrel/architectures/Compact/__init__.py @@ -21,22 +21,11 @@ def _get_scale_and_output_channels(x: int, input_channels: int) -> tuple[int, in def is_square(n: int) -> bool: return math.sqrt(n) == int(math.sqrt(n)) - def perfect_square_root(n: int) -> int: - root = int(math.sqrt(n)) - assert root * root == n, f"{n} is not a perfect square" - return root - - if x % 3 == 0 and x % 9 != 0: - # we know that output_channels MUST be a multiple of 3 - # so let's assume that output_channels is exactly 3 - x = x // 3 - return perfect_square_root(x), 3 - # just try out a few candidates and see which ones fulfill the requirements candidates = [input_channels, 3, 4, 1] for c in candidates: if x % c == 0 and is_square(x // c): - return perfect_square_root(x // c), c + return int(math.sqrt(x // c)), c raise AssertionError( f"Expected output channels to be either 1, 3, or 4. Could not find a a pair (s, o) such that s*s*o = {x}"