diff --git a/setup.py b/setup.py index f171253..f04c9ac 100644 --- a/setup.py +++ b/setup.py @@ -8,38 +8,41 @@ name = 'stylegan2_pytorch', packages = find_packages(), entry_points={ - 'console_scripts': [ - 'stylegan2_pytorch = stylegan2_pytorch.cli:main', - ], + 'console_scripts': [ + 'stylegan2_pytorch = stylegan2_pytorch.cli:main', + ], }, version = __version__, - license='GPLv3+', + license = 'MIT', description = 'StyleGan2 in Pytorch', long_description_content_type = 'text/markdown', author = 'Phil Wang', author_email = 'lucidrains@gmail.com', url = 'https://github.com/lucidrains/stylegan2-pytorch', download_url = 'https://github.com/lucidrains/stylegan2-pytorch/archive/v_036.tar.gz', - keywords = ['generative adversarial networks', 'artificial intelligence'], + keywords = [ + 'generative adversarial networks', + 'artificial intelligence' + ], install_requires=[ - 'aim', - 'einops', - 'contrastive_learner>=0.1.0', - 'fire', - 'kornia>=0.5.4', - 'numpy', - 'retry', - 'tqdm', - 'torch', - 'torchvision', - 'pillow', - 'vector-quantize-pytorch==0.1.0' + 'aim', + 'einops>=0.7.0', + 'contrastive_learner>=0.1.0', + 'fire', + 'kornia>=0.5.4', + 'numpy', + 'retry', + 'tqdm', + 'torch', + 'torchvision', + 'pillow', + 'vector-quantize-pytorch==0.1.0' ], classifiers=[ - 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3.6', + 'Development Status :: 4 - Beta', + 'Intended Audience :: Developers', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'License :: OSI Approved :: MIT License', + 'Programming Language :: Python :: 3.6', ], -) \ No newline at end of file +) diff --git a/stylegan2_pytorch/stylegan2_pytorch.py b/stylegan2_pytorch/stylegan2_pytorch.py index d3f50dc..ada96d6 100644 --- a/stylegan2_pytorch/stylegan2_pytorch.py +++ b/stylegan2_pytorch/stylegan2_pytorch.py @@ -1251,8 +1251,8 @@ def calculate_fid(self, num_batches): return fid_score.calculate_fid_given_paths([str(real_path), str(fake_path)], 256, noise.device, 2048) @torch.no_grad() - def truncate_style(self, tensor, trunc_psi = 0.75): - S = self.GAN.S + def truncate_style(self, tensor, S = None, trunc_psi = 0.75): + S = default(S, self.GAN.S) batch_size = self.batch_size latent_dim = self.GAN.G.latent_dim @@ -1267,17 +1267,17 @@ def truncate_style(self, tensor, trunc_psi = 0.75): return tensor @torch.no_grad() - def truncate_style_defs(self, w, trunc_psi = 0.75): + def truncate_style_defs(self, w, S = None, trunc_psi = 0.75): w_space = [] for tensor, num_layers in w: - tensor = self.truncate_style(tensor, trunc_psi = trunc_psi) + tensor = self.truncate_style(tensor, S = S, trunc_psi = trunc_psi) w_space.append((tensor, num_layers)) return w_space @torch.no_grad() def generate_truncated(self, S, G, style, noi, trunc_psi = 0.75, num_image_tiles = 8): w = map(lambda t: (S(t[0]), t[1]), style) - w_truncated = self.truncate_style_defs(w, trunc_psi = trunc_psi) + w_truncated = self.truncate_style_defs(w, S = S, trunc_psi = trunc_psi) w_styles = styles_def_to_tensor(w_truncated) generated_images = evaluate_in_chunks(self.batch_size, G, w_styles, noi) return generated_images.clamp_(0., 1.) diff --git a/stylegan2_pytorch/version.py b/stylegan2_pytorch/version.py index 20947b6..8d94f63 100644 --- a/stylegan2_pytorch/version.py +++ b/stylegan2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.8.9' +__version__ = '1.8.10'