Skip to content

Commit

Permalink
Added activation to from_rgb
Browse files Browse the repository at this point in the history
  • Loading branch information
rosinality committed Aug 21, 2019
1 parent 093503a commit b160dbc
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
27 changes: 17 additions & 10 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ def mean_style(self, input):


class Discriminator(nn.Module):
def __init__(self, fused=True):
def __init__(self, fused=True, from_rgb_activate=False):
super().__init__()

self.progression = nn.ModuleList(
Expand All @@ -524,17 +524,24 @@ def __init__(self, fused=True):
]
)

def make_from_rgb(out_channel):
if from_rgb_activate:
return nn.Sequential(EqualConv2d(3, out_channel, 1), nn.LeakyReLU(0.2))

else:
return EqualConv2d(3, out_channel, 1)

self.from_rgb = nn.ModuleList(
[
EqualConv2d(3, 16, 1),
EqualConv2d(3, 32, 1),
EqualConv2d(3, 64, 1),
EqualConv2d(3, 128, 1),
EqualConv2d(3, 256, 1),
EqualConv2d(3, 512, 1),
EqualConv2d(3, 512, 1),
EqualConv2d(3, 512, 1),
EqualConv2d(3, 512, 1),
make_from_rgb(16),
make_from_rgb(32),
make_from_rgb(64),
make_from_rgb(128),
make_from_rgb(256),
make_from_rgb(512),
make_from_rgb(512),
make_from_rgb(512),
make_from_rgb(512),
]
)

Expand Down
11 changes: 8 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def train(args, dataset, generator, discriminator):

alpha = 0
used_sample = 0

max_step = int(math.log2(args.max_size)) - 2
final_progress = False

Expand Down Expand Up @@ -101,7 +101,7 @@ def train(args, dataset, generator, discriminator):
'discriminator': discriminator.module.state_dict(),
'g_optimizer': g_optimizer.state_dict(),
'd_optimizer': d_optimizer.state_dict(),
'g_running': g_running.state_dict()
'g_running': g_running.state_dict(),
},
f'checkpoint/train_step-{step}.model',
)
Expand Down Expand Up @@ -262,6 +262,11 @@ def train(args, dataset, generator, discriminator):
parser.add_argument('--sched', action='store_true', help='use lr scheduling')
parser.add_argument('--init_size', default=8, type=int, help='initial image size')
parser.add_argument('--max_size', default=1024, type=int, help='max image size')
parser.add_argument(
'--from_rgb_activate',
action='store_true',
help='use activate in from_rgb (original implementation)',
)
parser.add_argument(
'--mixing', action='store_true', help='use mixing regularization'
)
Expand All @@ -276,7 +281,7 @@ def train(args, dataset, generator, discriminator):
args = parser.parse_args()

generator = nn.DataParallel(StyledGenerator(code_size)).cuda()
discriminator = nn.DataParallel(Discriminator()).cuda()
discriminator = nn.DataParallel(Discriminator(from_rgb_activate=args.from_rgb_activate)).cuda()
g_running = StyledGenerator(code_size).cuda()
g_running.train(False)

Expand Down

0 comments on commit b160dbc

Please sign in to comment.