Skip to content

Commit 2fe55c8

Browse files
author
jun
committed
version 0.0.1
1 parent 88ca26b commit 2fe55c8

15 files changed

+85
-38
lines changed

README.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
Simple torch module implementation of [Alias-Free GAN](https://nvlabs.github.io/alias-free-gan/).
44

55
This repository including
6-
- Alias-Free GAN style lowpass sinc filter @[filter.py](/filter.py)
6+
- Alias-Free GAN style lowpass sinc filter @[filter.py](/src/alias_free_torch/filter.py)
77

8-
- Alias-Free GAN style up/downsample @[resample.py](/resample.py)
8+
- Alias-Free GAN style up/downsample @[resample.py](/src/alias_free_torch/resample.py)
99

10-
- Alias-Free activation @[act.py](/act.py)
10+
- Alias-Free activation @[act.py](/src/alias_free_torch/act.py)
1111

12-
- and test codes @[./test](/test)
12+
- and test codes @[./test](/src/alias_free_torch/test)
1313

1414
**Note: Since this repository is unofficial, filter and upsample could be different with [official implementation](https://github.com/NVlabs/stylegan3).**
1515

pyproject.toml

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[build-system]
2+
requires = [
3+
"setuptools>=42",
4+
"wheel"
5+
]
6+
build-backend = "setuptools.build_meta"

setup.py

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from setuptools import setup, find_packages
2+
3+
with open("README.md", "r", encoding="utf-8") as fh:
4+
long_description = fh.read()
5+
6+
7+
setup(
8+
name = 'alias_free_torch',
9+
version = '0.0.1',
10+
description = 'alias free torch',
11+
long_description=long_description,
12+
long_description_content_type="text/markdown",
13+
author = 'junjun3518',
14+
author_email = 'junjun3518@gmail.com',
15+
url = 'https://github.com/junjun3518/alias-free-torch',
16+
install_requires = [],
17+
packages = find_packages(where = "src"),
18+
package_dir = {"": "src"},
19+
keywords = ['alias','torch','pytorch','filter'],
20+
python_requires = '>=3',
21+
zip_safe = False,
22+
classifiers = [
23+
'Programming Language :: Python :: 3',
24+
'Programming Language :: Python :: 3.2',
25+
'Programming Language :: Python :: 3.3',
26+
'Programming Language :: Python :: 3.4',
27+
'Programming Language :: Python :: 3.5',
28+
'Programming Language :: Python :: 3.6',
29+
'Programming Language :: Python :: 3.7',
30+
'Programming Language :: Python :: 3.8',
31+
'Programming Language :: Python :: 3.9',
32+
],
33+
)
34+

src/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
2+

src/alias_free_torch/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .filter import LowPassFilter1d, LowPassFilter2d
2+
from .resample import UpSample1d, UpSample2d
3+
from .act import Activation1d, Activation2d

act.py src/alias_free_torch/act.py

File renamed without changes.

filter.py src/alias_free_torch/filter.py

+25-23
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
if 'sinc' in dir(torch):
77
sinc = torch.sinc
88
else:
9-
# This code is adopted from adefossez's julius.core.sinc
10-
# https://adefossez.github.io/julius/julius/core.html
9+
# This code is adopted from adefossez's julius.core.sinc
10+
# https://adefossez.github.io/julius/julius/core.html
1111
def sinc(x: torch.Tensor):
1212
"""
1313
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
@@ -26,8 +26,9 @@ def __init__(self,
2626
half_width=0.6,
2727
stride: int = 1,
2828
pad: bool = True,
29-
kernel_size=12): # kernel_size should be even number for stylegan3 setup,
30-
# in this implementation, odd number is also possible.
29+
kernel_size=12
30+
): # kernel_size should be even number for stylegan3 setup,
31+
# in this implementation, odd number is also possible.
3132
super().__init__()
3233
if cutoff < -0.:
3334
raise ValueError("Minimum cutoff must be larger than zero.")
@@ -54,7 +55,7 @@ def __init__(self,
5455
if self.even:
5556
time = (torch.arange(-self.half_size, self.half_size) + 0.5)
5657
else:
57-
time = torch.arange(self.kernel_size) - self.half_size
58+
time = torch.arange(self.kernel_size) - self.half_size
5859
if cutoff == 0:
5960
filter_ = torch.zeros_like(time)
6061
else:
@@ -71,11 +72,10 @@ def forward(self, x):
7172
new_shape = shape[:-1] + [-1]
7273
x = x.view(-1, 1, shape[-1])
7374
if self.pad:
74-
x = F.pad(
75-
x,
76-
(self.half_size, self.half_size),
77-
mode='constant', value=0) # empirically, it is better than replicate
78-
#mode='replicate')
75+
x = F.pad(x, (self.half_size, self.half_size),
76+
mode='constant',
77+
value=0) # empirically, it is better than replicate
78+
#mode='replicate')
7979
if self.even:
8080
out = F.conv1d(x, self.filter, stride=self.stride)[..., :-1]
8181
else:
@@ -90,7 +90,7 @@ def __init__(self,
9090
stride: int = 1,
9191
pad: bool = True,
9292
kernel_size=12): # kernel_size should be even number
93-
# in this implementation, odd number is also possible.
93+
# in this implementation, odd number is also possible.
9494
super().__init__()
9595
if cutoff < -0.:
9696
raise ValueError("Minimum cutoff must be larger than zero.")
@@ -115,15 +115,15 @@ def __init__(self,
115115

116116
#rotation equivariant grid
117117
if self.even:
118-
time = (torch.stack(torch.meshgrid(
119-
torch.arange(-self.half_size, self.half_size) + 0.5,
120-
torch.arange(-self.half_size, self.half_size) + 0.5),
121-
dim=-1))
118+
time = torch.stack(torch.meshgrid(
119+
torch.arange(-self.half_size, self.half_size) + 0.5,
120+
torch.arange(-self.half_size, self.half_size) + 0.5),
121+
dim=-1)
122122
else:
123-
time = (torch.stack(torch.meshgrid(
124-
torch.arange(self.kernel_size) - self.half_size,
125-
torch.arange(self.kernel_size) - self.half_size,
126-
dim=-1))
123+
time = torch.stack(torch.meshgrid(
124+
torch.arange(self.kernel_size) - self.half_size,
125+
torch.arange(self.kernel_size) - self.half_size),
126+
dim=-1)
127127

128128
time = torch.norm(time, dim=-1)
129129
#rotation equivariant window
@@ -148,10 +148,12 @@ def forward(self, x):
148148
shape = list(x.shape)
149149
x = x.view(-1, 1, shape[-2], shape[-1])
150150
if self.pad:
151-
x = F.pad(x, (self.half_size, self.half_size, self.half_size,
152-
self.half_size),
153-
mode='constant', value=0) # empirically, it is better than replicate or reflect
154-
#mode='replicate')
151+
x = F.pad(
152+
x, (self.half_size, self.half_size, self.half_size,
153+
self.half_size),
154+
mode='constant',
155+
value=0) # empirically, it is better than replicate or reflect
156+
#mode='replicate')
155157
if self.even:
156158
out = F.conv2d(x, self.filter, stride=self.stride)[..., :-1, :-1]
157159
else:
File renamed without changes.

test/act2d_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import torch
22
import matplotlib.pyplot as plt
3-
from ..act import Activation2d
4-
from ..resample import UpSample2d, DownSample2d
5-
from ..filter import LowPassFilter2d
3+
from alias_free_torch.act import Activation2d
4+
from alias_free_torch.resample import UpSample2d, DownSample2d
5+
from alias_free_torch.filter import LowPassFilter2d
66
import math
77
continuous_ratio = 16
88
ratio = 2

test/down1d_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
import matplotlib.pyplot as plt
3-
from ..resample import DownSample1d
3+
from alias_free_torch.resample import DownSample1d
44

55
ratio = 10
66
t = torch.arange(100) / 100. * 3.141592

test/down2d_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
import matplotlib.pyplot as plt
3-
from ..resample import DownSample2d
3+
from alias_free_torch.resample import DownSample2d
44

55
ratio = 4
66
size = 80

test/filter1d_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
import matplotlib.pyplot as plt
3-
from ..filter import LowPassFilter1d
3+
from alias_free_torch.filter import LowPassFilter1d
44

55
ratio = 2
66
t = torch.arange(400) / 40. * 3.141592

test/filter2d_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
import matplotlib.pyplot as plt
3-
from ..filter import LowPassFilter2d
3+
from alias_free_torch.filter import LowPassFilter2d
44

55
ratio = 2
66

test/up1d_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
import matplotlib.pyplot as plt
3-
from ..resample import UpSample1d
4-
from ..filter import LowPassFilter1d
3+
from alias_free_torch.resample import UpSample1d
4+
from alias_free_torch.filter import LowPassFilter1d
55

66
ratio = 2
77
t = torch.arange(100) / 10. * 3.141592

test/up2d_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
import matplotlib.pyplot as plt
3-
from ..resample import UpSample2d
4-
from ..filter import LowPassFilter2d
3+
from alias_free_torch.resample import UpSample2d
4+
from alias_free_torch.filter import LowPassFilter2d
55

66
ratio = 8
77
size = 40

0 commit comments

Comments
 (0)