Skip to content

Commit

Permalink
address #282
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Apr 26, 2024
1 parent 0588213 commit 88c87e1
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 29 deletions.
49 changes: 26 additions & 23 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,38 +8,41 @@
name = 'stylegan2_pytorch',
packages = find_packages(),
entry_points={
'console_scripts': [
'stylegan2_pytorch = stylegan2_pytorch.cli:main',
],
'console_scripts': [
'stylegan2_pytorch = stylegan2_pytorch.cli:main',
],
},
version = __version__,
license='GPLv3+',
license = 'MIT',
description = 'StyleGan2 in Pytorch',
long_description_content_type = 'text/markdown',
author = 'Phil Wang',
author_email = 'lucidrains@gmail.com',
url = 'https://github.com/lucidrains/stylegan2-pytorch',
download_url = 'https://github.com/lucidrains/stylegan2-pytorch/archive/v_036.tar.gz',
keywords = ['generative adversarial networks', 'artificial intelligence'],
keywords = [
'generative adversarial networks',
'artificial intelligence'
],
install_requires=[
'aim',
'einops',
'contrastive_learner>=0.1.0',
'fire',
'kornia>=0.5.4',
'numpy',
'retry',
'tqdm',
'torch',
'torchvision',
'pillow',
'vector-quantize-pytorch==0.1.0'
'aim',
'einops>=0.7.0',
'contrastive_learner>=0.1.0',
'fire',
'kornia>=0.5.4',
'numpy',
'retry',
'tqdm',
'torch',
'torchvision',
'pillow',
'vector-quantize-pytorch==0.1.0'
],
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
)
10 changes: 5 additions & 5 deletions stylegan2_pytorch/stylegan2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1251,8 +1251,8 @@ def calculate_fid(self, num_batches):
return fid_score.calculate_fid_given_paths([str(real_path), str(fake_path)], 256, noise.device, 2048)

@torch.no_grad()
def truncate_style(self, tensor, trunc_psi = 0.75):
S = self.GAN.S
def truncate_style(self, tensor, S = None, trunc_psi = 0.75):
S = default(S, self.GAN.S)
batch_size = self.batch_size
latent_dim = self.GAN.G.latent_dim

Expand All @@ -1267,17 +1267,17 @@ def truncate_style(self, tensor, trunc_psi = 0.75):
return tensor

@torch.no_grad()
def truncate_style_defs(self, w, trunc_psi = 0.75):
def truncate_style_defs(self, w, S = None, trunc_psi = 0.75):
w_space = []
for tensor, num_layers in w:
tensor = self.truncate_style(tensor, trunc_psi = trunc_psi)
tensor = self.truncate_style(tensor, S = S, trunc_psi = trunc_psi)
w_space.append((tensor, num_layers))
return w_space

@torch.no_grad()
def generate_truncated(self, S, G, style, noi, trunc_psi = 0.75, num_image_tiles = 8):
w = map(lambda t: (S(t[0]), t[1]), style)
w_truncated = self.truncate_style_defs(w, trunc_psi = trunc_psi)
w_truncated = self.truncate_style_defs(w, S = S, trunc_psi = trunc_psi)
w_styles = styles_def_to_tensor(w_truncated)
generated_images = evaluate_in_chunks(self.batch_size, G, w_styles, noi)
return generated_images.clamp_(0., 1.)
Expand Down
2 changes: 1 addition & 1 deletion stylegan2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.8.9'
__version__ = '1.8.10'

0 comments on commit 88c87e1

Please sign in to comment.