From 3e4265feef61925e0765525164bfba993acd346f Mon Sep 17 00:00:00 2001 From: mseitzer <16725193+mseitzer@users.noreply.github.com> Date: Sat, 16 Mar 2024 18:53:51 +0100 Subject: [PATCH 1/3] Switch to pyproject.toml from setup.cfg --- .flake8 | 4 ++++ pyproject.toml | 6 ++++++ setup.cfg | 8 -------- 3 files changed, 10 insertions(+), 8 deletions(-) create mode 100644 .flake8 create mode 100644 pyproject.toml delete mode 100644 setup.cfg diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..19a5aa9 --- /dev/null +++ b/.flake8 @@ -0,0 +1,4 @@ +[flake8] +select = F,W,E,I,B,B9 +ignore = W503,B950 +max-line-length = 88 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..55beee0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,6 @@ +[tool.isort] +profile = "black" +line_length = 88 +force_single_line = true +multi_line_output = 3 +lines_after_imports = 2 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index a1b91ae..0000000 --- a/setup.cfg +++ /dev/null @@ -1,8 +0,0 @@ -[flake8] -select=F,W,E,I,B,B9 -ignore=W503,B950 -max-line-length=79 - -[isort] -multi_line_output=1 -line_length=79 From 9ec5d3cca0d71f54e1cb2ba84ed36ff760aa80dc Mon Sep 17 00:00:00 2001 From: mseitzer <16725193+mseitzer@users.noreply.github.com> Date: Sat, 16 Mar 2024 18:56:03 +0100 Subject: [PATCH 2/3] Add auto-formatting with black --- .flake8 | 2 +- noxfile.py | 2 ++ pyproject.toml | 5 +++-- setup.py | 1 + 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.flake8 b/.flake8 index 19a5aa9..706e2b1 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,4 @@ [flake8] select = F,W,E,I,B,B9 -ignore = W503,B950 +ignore = W503,E203,B950 max-line-length = 88 diff --git a/noxfile.py b/noxfile.py index 8779177..7338107 100644 --- a/noxfile.py +++ b/noxfile.py @@ -8,9 +8,11 @@ def lint(session): session.install('flake8') session.install('flake8-bugbear') session.install('flake8-isort') + session.install('black==24.3.0') args = session.posargs or LOCATIONS session.run('flake8', *args) + session.run('black', '--check', '--diff', *args) @nox.session(python=["3.8", "3.9", "3.10", "3.11", "3.12"]) diff --git a/pyproject.toml b/pyproject.toml index 55beee0..fcffd43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,7 @@ +[tool.black] +target-version = ["py311"] + [tool.isort] profile = "black" line_length = 88 -force_single_line = true multi_line_output = 3 -lines_after_imports = 2 diff --git a/setup.py b/setup.py index 1e22363..9e18722 100644 --- a/setup.py +++ b/setup.py @@ -51,5 +51,6 @@ def get_version(rel_path): extras_require={'dev': ['flake8', 'flake8-bugbear', 'flake8-isort', + 'black==24.3.0', 'nox']}, ) From af953a98e1e9c50a5e51d9a5294a282bfaee5e07 Mon Sep 17 00:00:00 2001 From: mseitzer <16725193+mseitzer@users.noreply.github.com> Date: Sat, 16 Mar 2024 19:03:10 +0100 Subject: [PATCH 3/3] Format with black --- noxfile.py | 29 +++--- setup.py | 57 ++++++------ src/pytorch_fid/__init__.py | 2 +- src/pytorch_fid/fid_score.py | 174 ++++++++++++++++++++--------------- src/pytorch_fid/inception.py | 79 ++++++++-------- tests/test_fid_score.py | 58 ++++++------ 6 files changed, 215 insertions(+), 184 deletions(-) diff --git a/noxfile.py b/noxfile.py index 7338107..8dc5c82 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,28 +1,29 @@ import nox -LOCATIONS = ('src/', 'tests/', 'noxfile.py', 'setup.py') +LOCATIONS = ("src/", "tests/", "noxfile.py", "setup.py") @nox.session def lint(session): - session.install('flake8') - session.install('flake8-bugbear') - session.install('flake8-isort') - session.install('black==24.3.0') + session.install("flake8") + session.install("flake8-bugbear") + session.install("flake8-isort") + session.install("black==24.3.0") args = session.posargs or LOCATIONS - session.run('flake8', *args) - session.run('black', '--check', '--diff', *args) + session.run("flake8", *args) + session.run("black", "--check", "--diff", *args) @nox.session(python=["3.8", "3.9", "3.10", "3.11", "3.12"]) def tests(session): session.install( - 'torch==2.2.1', - 'torchvision', - '--index-url', 'https://download.pytorch.org/whl/cpu' + "torch==2.2.1", + "torchvision", + "--index-url", + "https://download.pytorch.org/whl/cpu", ) - session.install('.') - session.install('pytest') - session.install('pytest-mock') - session.run('pytest', *session.posargs) + session.install(".") + session.install("pytest") + session.install("pytest-mock") + session.run("pytest", *session.posargs) diff --git a/setup.py b/setup.py index 9e18722..8db0f25 100644 --- a/setup.py +++ b/setup.py @@ -5,52 +5,51 @@ def read(rel_path): base_path = os.path.abspath(os.path.dirname(__file__)) - with open(os.path.join(base_path, rel_path), 'r') as f: + with open(os.path.join(base_path, rel_path), "r") as f: return f.read() def get_version(rel_path): for line in read(rel_path).splitlines(): - if line.startswith('__version__'): + if line.startswith("__version__"): # __version__ = "0.9" delim = '"' if '"' in line else "'" return line.split(delim)[1] - raise RuntimeError('Unable to find version string.') + raise RuntimeError("Unable to find version string.") -if __name__ == '__main__': +if __name__ == "__main__": setuptools.setup( - name='pytorch-fid', - version=get_version(os.path.join('src', 'pytorch_fid', '__init__.py')), - author='Max Seitzer', - description=('Package for calculating Frechet Inception Distance (FID)' - ' using PyTorch'), - long_description=read('README.md'), - long_description_content_type='text/markdown', - url='https://github.com/mseitzer/pytorch-fid', - package_dir={'': 'src'}, - packages=setuptools.find_packages(where='src'), + name="pytorch-fid", + version=get_version(os.path.join("src", "pytorch_fid", "__init__.py")), + author="Max Seitzer", + description=( + "Package for calculating Frechet Inception Distance (FID)" " using PyTorch" + ), + long_description=read("README.md"), + long_description_content_type="text/markdown", + url="https://github.com/mseitzer/pytorch-fid", + package_dir={"": "src"}, + packages=setuptools.find_packages(where="src"), classifiers=[ - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: Apache Software License', + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", ], - python_requires='>=3.5', + python_requires=">=3.5", entry_points={ - 'console_scripts': [ - 'pytorch-fid = pytorch_fid.fid_score:main', + "console_scripts": [ + "pytorch-fid = pytorch_fid.fid_score:main", ], }, install_requires=[ - 'numpy', - 'pillow', - 'scipy', - 'torch>=1.0.1', - 'torchvision>=0.2.2' + "numpy", + "pillow", + "scipy", + "torch>=1.0.1", + "torchvision>=0.2.2", ], - extras_require={'dev': ['flake8', - 'flake8-bugbear', - 'flake8-isort', - 'black==24.3.0', - 'nox']}, + extras_require={ + "dev": ["flake8", "flake8-bugbear", "flake8-isort", "black==24.3.0", "nox"] + }, ) diff --git a/src/pytorch_fid/__init__.py b/src/pytorch_fid/__init__.py index 0404d81..493f741 100644 --- a/src/pytorch_fid/__init__.py +++ b/src/pytorch_fid/__init__.py @@ -1 +1 @@ -__version__ = '0.3.0' +__version__ = "0.3.0" diff --git a/src/pytorch_fid/fid_score.py b/src/pytorch_fid/fid_score.py index 5102a4b..9c8acb2 100755 --- a/src/pytorch_fid/fid_score.py +++ b/src/pytorch_fid/fid_score.py @@ -31,6 +31,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os import pathlib from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser @@ -49,30 +50,48 @@ def tqdm(x): return x + from pytorch_fid.inception import InceptionV3 parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) -parser.add_argument('--batch-size', type=int, default=50, - help='Batch size to use') -parser.add_argument('--num-workers', type=int, - help=('Number of processes to use for data loading. ' - 'Defaults to `min(8, num_cpus)`')) -parser.add_argument('--device', type=str, default=None, - help='Device to use. Like cuda, cuda:0 or cpu') -parser.add_argument('--dims', type=int, default=2048, - choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), - help=('Dimensionality of Inception features to use. ' - 'By default, uses pool3 features')) -parser.add_argument('--save-stats', action='store_true', - help=('Generate an npz archive from a directory of ' - 'samples. The first path is used as input and the ' - 'second as output.')) -parser.add_argument('path', type=str, nargs=2, - help=('Paths to the generated images or ' - 'to .npz statistic files')) - -IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', - 'tif', 'tiff', 'webp'} +parser.add_argument("--batch-size", type=int, default=50, help="Batch size to use") +parser.add_argument( + "--num-workers", + type=int, + help=( + "Number of processes to use for data loading. " "Defaults to `min(8, num_cpus)`" + ), +) +parser.add_argument( + "--device", type=str, default=None, help="Device to use. Like cuda, cuda:0 or cpu" +) +parser.add_argument( + "--dims", + type=int, + default=2048, + choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), + help=( + "Dimensionality of Inception features to use. " + "By default, uses pool3 features" + ), +) +parser.add_argument( + "--save-stats", + action="store_true", + help=( + "Generate an npz archive from a directory of " + "samples. The first path is used as input and the " + "second as output." + ), +) +parser.add_argument( + "path", + type=str, + nargs=2, + help=("Paths to the generated images or " "to .npz statistic files"), +) + +IMAGE_EXTENSIONS = {"bmp", "jpg", "jpeg", "pgm", "png", "ppm", "tif", "tiff", "webp"} class ImagePathDataset(torch.utils.data.Dataset): @@ -85,14 +104,15 @@ def __len__(self): def __getitem__(self, i): path = self.files[i] - img = Image.open(path).convert('RGB') + img = Image.open(path).convert("RGB") if self.transforms is not None: img = self.transforms(img) return img -def get_activations(files, model, batch_size=50, dims=2048, device='cpu', - num_workers=1): +def get_activations( + files, model, batch_size=50, dims=2048, device="cpu", num_workers=1 +): """Calculates the activations of the pool_3 layer for all images. Params: @@ -115,16 +135,22 @@ def get_activations(files, model, batch_size=50, dims=2048, device='cpu', model.eval() if batch_size > len(files): - print(('Warning: batch size is bigger than the data size. ' - 'Setting batch size to data size')) + print( + ( + "Warning: batch size is bigger than the data size. " + "Setting batch size to data size" + ) + ) batch_size = len(files) dataset = ImagePathDataset(files, transforms=TF.ToTensor()) - dataloader = torch.utils.data.DataLoader(dataset, - batch_size=batch_size, - shuffle=False, - drop_last=False, - num_workers=num_workers) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + drop_last=False, + num_workers=num_workers, + ) pred_arr = np.empty((len(files), dims)) @@ -143,7 +169,7 @@ def get_activations(files, model, batch_size=50, dims=2048, device='cpu', pred = pred.squeeze(3).squeeze(2).cpu().numpy() - pred_arr[start_idx:start_idx + pred.shape[0]] = pred + pred_arr[start_idx : start_idx + pred.shape[0]] = pred start_idx = start_idx + pred.shape[0] @@ -178,18 +204,22 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): sigma1 = np.atleast_2d(sigma1) sigma2 = np.atleast_2d(sigma2) - assert mu1.shape == mu2.shape, \ - 'Training and test mean vectors have different lengths' - assert sigma1.shape == sigma2.shape, \ - 'Training and test covariances have different dimensions' + assert ( + mu1.shape == mu2.shape + ), "Training and test mean vectors have different lengths" + assert ( + sigma1.shape == sigma2.shape + ), "Training and test covariances have different dimensions" diff = mu1 - mu2 # Product might be almost singular covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) if not np.isfinite(covmean).all(): - msg = ('fid calculation produces singular product; ' - 'adding %s to diagonal of cov estimates') % eps + msg = ( + "fid calculation produces singular product; " + "adding %s to diagonal of cov estimates" + ) % eps print(msg) offset = np.eye(sigma1.shape[0]) * eps covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) @@ -198,17 +228,17 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): if np.iscomplexobj(covmean): if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): m = np.max(np.abs(covmean.imag)) - raise ValueError('Imaginary component {}'.format(m)) + raise ValueError("Imaginary component {}".format(m)) covmean = covmean.real tr_covmean = np.trace(covmean) - return (diff.dot(diff) + np.trace(sigma1) - + np.trace(sigma2) - 2 * tr_covmean) + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean -def calculate_activation_statistics(files, model, batch_size=50, dims=2048, - device='cpu', num_workers=1): +def calculate_activation_statistics( + files, model, batch_size=50, dims=2048, device="cpu", num_workers=1 +): """Calculation of the statistics used by the FID. Params: -- files : List of image files paths @@ -232,17 +262,18 @@ def calculate_activation_statistics(files, model, batch_size=50, dims=2048, return mu, sigma -def compute_statistics_of_path(path, model, batch_size, dims, device, - num_workers=1): - if path.endswith('.npz'): +def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=1): + if path.endswith(".npz"): with np.load(path) as f: - m, s = f['mu'][:], f['sigma'][:] + m, s = f["mu"][:], f["sigma"][:] else: path = pathlib.Path(path) - files = sorted([file for ext in IMAGE_EXTENSIONS - for file in path.glob('*.{}'.format(ext))]) - m, s = calculate_activation_statistics(files, model, batch_size, - dims, device, num_workers) + files = sorted( + [file for ext in IMAGE_EXTENSIONS for file in path.glob("*.{}".format(ext))] + ) + m, s = calculate_activation_statistics( + files, model, batch_size, dims, device, num_workers + ) return m, s @@ -251,16 +282,18 @@ def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1): """Calculates the FID of two paths""" for p in paths: if not os.path.exists(p): - raise RuntimeError('Invalid path: %s' % p) + raise RuntimeError("Invalid path: %s" % p) block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] model = InceptionV3([block_idx]).to(device) - m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, - dims, device, num_workers) - m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, - dims, device, num_workers) + m1, s1 = compute_statistics_of_path( + paths[0], model, batch_size, dims, device, num_workers + ) + m2, s2 = compute_statistics_of_path( + paths[1], model, batch_size, dims, device, num_workers + ) fid_value = calculate_frechet_distance(m1, s1, m2, s2) return fid_value @@ -269,10 +302,10 @@ def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1): def save_fid_stats(paths, batch_size, device, dims, num_workers=1): """Saves FID statistics of one path""" if not os.path.exists(paths[0]): - raise RuntimeError('Invalid path: %s' % paths[0]) + raise RuntimeError("Invalid path: %s" % paths[0]) if os.path.exists(paths[1]): - raise RuntimeError('Existing output file: %s' % paths[1]) + raise RuntimeError("Existing output file: %s" % paths[1]) block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] @@ -280,8 +313,9 @@ def save_fid_stats(paths, batch_size, device, dims, num_workers=1): print(f"Saving statistics for {paths[0]}") - m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, - dims, device, num_workers) + m1, s1 = compute_statistics_of_path( + paths[0], model, batch_size, dims, device, num_workers + ) np.savez_compressed(paths[1], mu=m1, sigma=s1) @@ -290,7 +324,7 @@ def main(): args = parser.parse_args() if args.device is None: - device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') + device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu") else: device = torch.device(args.device) @@ -308,20 +342,14 @@ def main(): num_workers = args.num_workers if args.save_stats: - save_fid_stats(args.path, - args.batch_size, - device, - args.dims, - num_workers) + save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers) return - fid_value = calculate_fid_given_paths(args.path, - args.batch_size, - device, - args.dims, - num_workers) - print('FID: ', fid_value) + fid_value = calculate_fid_given_paths( + args.path, args.batch_size, device, args.dims, num_workers + ) + print("FID: ", fid_value) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/src/pytorch_fid/inception.py b/src/pytorch_fid/inception.py index 8898a20..a6fb465 100644 --- a/src/pytorch_fid/inception.py +++ b/src/pytorch_fid/inception.py @@ -10,7 +10,7 @@ # Inception weights ported to Pytorch from # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz -FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 +FID_WEIGHTS_URL = "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" # noqa: E501 class InceptionV3(nn.Module): @@ -22,18 +22,20 @@ class InceptionV3(nn.Module): # Maps feature dimensionality to their output blocks indices BLOCK_INDEX_BY_DIM = { - 64: 0, # First max pooling features + 64: 0, # First max pooling features 192: 1, # Second max pooling featurs 768: 2, # Pre-aux classifier features - 2048: 3 # Final average pooling features + 2048: 3, # Final average pooling features } - def __init__(self, - output_blocks=(DEFAULT_BLOCK_INDEX,), - resize_input=True, - normalize_input=True, - requires_grad=False, - use_fid_inception=True): + def __init__( + self, + output_blocks=(DEFAULT_BLOCK_INDEX,), + resize_input=True, + normalize_input=True, + requires_grad=False, + use_fid_inception=True, + ): """Build pretrained InceptionV3 Parameters @@ -71,22 +73,21 @@ def __init__(self, self.output_blocks = sorted(output_blocks) self.last_needed_block = max(output_blocks) - assert self.last_needed_block <= 3, \ - 'Last possible output block index is 3' + assert self.last_needed_block <= 3, "Last possible output block index is 3" self.blocks = nn.ModuleList() if use_fid_inception: inception = fid_inception_v3() else: - inception = _inception_v3(weights='DEFAULT') + inception = _inception_v3(weights="DEFAULT") # Block 0: input to maxpool1 block0 = [ inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, inception.Conv2d_2b_3x3, - nn.MaxPool2d(kernel_size=3, stride=2) + nn.MaxPool2d(kernel_size=3, stride=2), ] self.blocks.append(nn.Sequential(*block0)) @@ -95,7 +96,7 @@ def __init__(self, block1 = [ inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, - nn.MaxPool2d(kernel_size=3, stride=2) + nn.MaxPool2d(kernel_size=3, stride=2), ] self.blocks.append(nn.Sequential(*block1)) @@ -119,7 +120,7 @@ def __init__(self, inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c, - nn.AdaptiveAvgPool2d(output_size=(1, 1)) + nn.AdaptiveAvgPool2d(output_size=(1, 1)), ] self.blocks.append(nn.Sequential(*block3)) @@ -144,10 +145,7 @@ def forward(self, inp): x = inp if self.resize_input: - x = F.interpolate(x, - size=(299, 299), - mode='bilinear', - align_corners=False) + x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False) if self.normalize_input: x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) @@ -166,7 +164,7 @@ def forward(self, inp): def _inception_v3(*args, **kwargs): """Wraps `torchvision.models.inception_v3`""" try: - version = tuple(map(int, torchvision.__version__.split('.')[:2])) + version = tuple(map(int, torchvision.__version__.split(".")[:2])) except ValueError: # Just a caution against weird version strings version = (0,) @@ -174,22 +172,22 @@ def _inception_v3(*args, **kwargs): # Skips default weight inititialization if supported by torchvision # version. See https://github.com/mseitzer/pytorch-fid/issues/28. if version >= (0, 6): - kwargs['init_weights'] = False + kwargs["init_weights"] = False # Backwards compatibility: `weights` argument was handled by `pretrained` # argument prior to version 0.13. - if version < (0, 13) and 'weights' in kwargs: - if kwargs['weights'] == 'DEFAULT': - kwargs['pretrained'] = True - elif kwargs['weights'] is None: - kwargs['pretrained'] = False + if version < (0, 13) and "weights" in kwargs: + if kwargs["weights"] == "DEFAULT": + kwargs["pretrained"] = True + elif kwargs["weights"] is None: + kwargs["pretrained"] = False else: raise ValueError( - 'weights=={} not supported in torchvision {}'.format( - kwargs['weights'], torchvision.__version__ + "weights=={} not supported in torchvision {}".format( + kwargs["weights"], torchvision.__version__ ) ) - del kwargs['weights'] + del kwargs["weights"] return torchvision.models.inception_v3(*args, **kwargs) @@ -203,9 +201,7 @@ def fid_inception_v3(): This method first constructs torchvision's Inception and then patches the necessary parts that are different in the FID Inception model. """ - inception = _inception_v3(num_classes=1008, - aux_logits=False, - weights=None) + inception = _inception_v3(num_classes=1008, aux_logits=False, weights=None) inception.Mixed_5b = FIDInceptionA(192, pool_features=32) inception.Mixed_5c = FIDInceptionA(256, pool_features=64) inception.Mixed_5d = FIDInceptionA(288, pool_features=64) @@ -223,6 +219,7 @@ def fid_inception_v3(): class FIDInceptionA(torchvision.models.inception.InceptionA): """InceptionA block patched for FID computation""" + def __init__(self, in_channels, pool_features): super(FIDInceptionA, self).__init__(in_channels, pool_features) @@ -238,8 +235,9 @@ def forward(self, x): # Patch: Tensorflow's average pool does not use the padded zero's in # its average calculation - branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, - count_include_pad=False) + branch_pool = F.avg_pool2d( + x, kernel_size=3, stride=1, padding=1, count_include_pad=False + ) branch_pool = self.branch_pool(branch_pool) outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] @@ -248,6 +246,7 @@ def forward(self, x): class FIDInceptionC(torchvision.models.inception.InceptionC): """InceptionC block patched for FID computation""" + def __init__(self, in_channels, channels_7x7): super(FIDInceptionC, self).__init__(in_channels, channels_7x7) @@ -266,8 +265,9 @@ def forward(self, x): # Patch: Tensorflow's average pool does not use the padded zero's in # its average calculation - branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, - count_include_pad=False) + branch_pool = F.avg_pool2d( + x, kernel_size=3, stride=1, padding=1, count_include_pad=False + ) branch_pool = self.branch_pool(branch_pool) outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] @@ -276,6 +276,7 @@ def forward(self, x): class FIDInceptionE_1(torchvision.models.inception.InceptionE): """First InceptionE block patched for FID computation""" + def __init__(self, in_channels): super(FIDInceptionE_1, self).__init__(in_channels) @@ -299,8 +300,9 @@ def forward(self, x): # Patch: Tensorflow's average pool does not use the padded zero's in # its average calculation - branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, - count_include_pad=False) + branch_pool = F.avg_pool2d( + x, kernel_size=3, stride=1, padding=1, count_include_pad=False + ) branch_pool = self.branch_pool(branch_pool) outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] @@ -309,6 +311,7 @@ def forward(self, x): class FIDInceptionE_2(torchvision.models.inception.InceptionE): """Second InceptionE block patched for FID computation""" + def __init__(self, in_channels): super(FIDInceptionE_2, self).__init__(in_channels) diff --git a/tests/test_fid_score.py b/tests/test_fid_score.py index 6728b39..4e4b6a1 100644 --- a/tests/test_fid_score.py +++ b/tests/test_fid_score.py @@ -8,7 +8,7 @@ @pytest.fixture def device(): - return torch.device('cpu') + return torch.device("cpu") def test_calculate_fid_given_statistics(mocker, tmp_path, device): @@ -17,31 +17,30 @@ def test_calculate_fid_given_statistics(mocker, tmp_path, device): sigma = np.eye(dim) def dummy_statistics(path, model, batch_size, dims, device, num_workers): - if path.endswith('1'): + if path.endswith("1"): return m1, sigma - elif path.endswith('2'): + elif path.endswith("2"): return m2, sigma else: raise ValueError - mocker.patch('pytorch_fid.fid_score.compute_statistics_of_path', - side_effect=dummy_statistics) + mocker.patch( + "pytorch_fid.fid_score.compute_statistics_of_path", side_effect=dummy_statistics + ) - dir_names = ['1', '2'] + dir_names = ["1", "2"] paths = [] for name in dir_names: path = tmp_path / name path.mkdir() paths.append(str(path)) - fid_value = fid_score.calculate_fid_given_paths(paths, - batch_size=dim, - device=device, - dims=dim, - num_workers=0) + fid_value = fid_score.calculate_fid_given_paths( + paths, batch_size=dim, device=device, dims=dim, num_workers=0 + ) # Given equal covariance, FID is just the squared norm of difference - assert fid_value == np.sum((m1 - m2)**2) + assert fid_value == np.sum((m1 - m2) ** 2) def test_compute_statistics_of_path(mocker, tmp_path, device): @@ -54,14 +53,17 @@ def test_compute_statistics_of_path(mocker, tmp_path, device): paths = [] for idx, image in enumerate(images): - paths.append(str(tmp_path / '{}.png'.format(idx))) - Image.fromarray(image, mode='RGB').save(paths[-1]) - - stats = fid_score.compute_statistics_of_path(str(tmp_path), model, - batch_size=len(images), - dims=3, - device=device, - num_workers=0) + paths.append(str(tmp_path / "{}.png".format(idx))) + Image.fromarray(image, mode="RGB").save(paths[-1]) + + stats = fid_score.compute_statistics_of_path( + str(tmp_path), + model, + batch_size=len(images), + dims=3, + device=device, + num_workers=0, + ) assert np.allclose(stats[0], np.ones((3,)) * 0.5, atol=1e-3) assert np.allclose(stats[1], np.ones((3, 3)) * 0.25) @@ -73,15 +75,13 @@ def test_compute_statistics_of_path_from_file(mocker, tmp_path, device): mu = np.random.randn(5) sigma = np.random.randn(5, 5) - path = tmp_path / 'stats.npz' - with path.open('wb') as f: + path = tmp_path / "stats.npz" + with path.open("wb") as f: np.savez(f, mu=mu, sigma=sigma) - stats = fid_score.compute_statistics_of_path(str(path), model, - batch_size=1, - dims=5, - device=device, - num_workers=0) + stats = fid_score.compute_statistics_of_path( + str(path), model, batch_size=1, dims=5, device=device, num_workers=0 + ) assert np.allclose(stats[0], mu) assert np.allclose(stats[1], sigma) @@ -89,11 +89,11 @@ def test_compute_statistics_of_path_from_file(mocker, tmp_path, device): def test_image_types(tmp_path): in_arr = np.ones((24, 24, 3), dtype=np.uint8) * 255 - in_image = Image.fromarray(in_arr, mode='RGB') + in_image = Image.fromarray(in_arr, mode="RGB") paths = [] for ext in fid_score.IMAGE_EXTENSIONS: - paths.append(str(tmp_path / 'img.{}'.format(ext))) + paths.append(str(tmp_path / "img.{}".format(ext))) in_image.save(paths[-1]) dataset = fid_score.ImagePathDataset(paths)