Skip to content

Commit

Permalink
idinvert encoder, full w
Browse files Browse the repository at this point in the history
  • Loading branch information
phymhan committed Jan 12, 2021
1 parent ab9e89a commit 37c26b0
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 23 deletions.
2 changes: 1 addition & 1 deletion idinvert_pytorch
37 changes: 31 additions & 6 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ def forward(
noise=None,
randomize_noise=True,
):
if not input_is_latent: # `style' is z, then w = self.style(z)
if not input_is_latent: # if `style' is z, then get w = self.style(z)
styles = [self.style(s) for s in styles]

if noise is None:
Expand All @@ -497,7 +497,7 @@ def forward(
if styles[0].ndim < 3: # w is of dim [batch, 512], repeat at dim 1 for each layer
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)

else: # w is of dim [batch, num_layers, 512]
else: # w is of dim [batch, n_latent, 512]
latent = styles[0]

else: # mixing
Expand All @@ -509,7 +509,7 @@ def forward(

latent = torch.cat([latent, latent2], 1)

out = self.input(latent)
out = self.input(latent) # only batch_size of latent is used
out = self.conv1(out, latent[:, 0], noise=noise[0])

skip = self.to_rgb1(out, latent[:, 1])
Expand Down Expand Up @@ -661,7 +661,20 @@ def forward(self, input):


class Encoder(nn.Module):
def __init__(self, size, out_dim=512, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
def __init__(
self,
size,
style_dim=512,
channel_multiplier=2,
blur_kernel=[1, 3, 3, 1],
which_latent='w',
reshape_latent=True,
):
"""
which_latent: 'w' predict different w for all blocks; 'w_shared' predict
a single w for all blocks; 'wb' predict w and b (bias) for all blocks;
'wb_shared' predict shared w and different biases.
"""
super().__init__()

channels = {
Expand All @@ -679,6 +692,11 @@ def __init__(self, size, out_dim=512, channel_multiplier=2, blur_kernel=[1, 3, 3
convs = [ConvLayer(3, channels[size], 1)]

log_size = int(math.log(size, 2))
self.n_latent = log_size * 2 - 2 # copied from Generator
self.n_noises = (log_size - 2) * 2 + 1
self.which_latent = which_latent
self.reshape_latent = reshape_latent
self.style_dim = style_dim

in_channel = channels[size]

Expand All @@ -695,9 +713,15 @@ def __init__(self, size, out_dim=512, channel_multiplier=2, blur_kernel=[1, 3, 3
self.stddev_feat = 1

self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
if self.which_latent == 'w':
out_channel = style_dim * self.n_latent
elif self.which_latent == 'w_shared':
out_channel = style_dim
else:
raise NotImplementedError
self.final_linear = nn.Sequential(
EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
EqualLinear(channels[4], out_dim),
EqualLinear(channels[4], out_channel),
)

def forward(self, input):
Expand All @@ -717,5 +741,6 @@ def forward(self, input):

out = out.view(batch, -1)
out = self.final_linear(out)

if self.which_latent == 'w' and self.reshape_latent:
out = out.reshape(batch, self.n_latent, self.style_dim)
return out
36 changes: 20 additions & 16 deletions train_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,10 @@ def train(args, loader, encoder, generator, discriminator, vggnet, e_optim, d_op
if args.augment and args.augment_p == 0:
ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device)

# sample_z = torch.randn(args.n_sample, args.latent, device=device)
sample_x = accumulate_batches(loader, args.n_sample)
sample_x = sample_x.to(device)
sample_x = accumulate_batches(loader, args.n_sample).to(device)

requires_grad(generator, False) # always False
generator.eval()
generator.eval() # Generator should be ema and in eval mode

for idx in pbar:
i = idx + args.start_iter
Expand Down Expand Up @@ -350,6 +348,7 @@ def train(args, loader, encoder, generator, discriminator, vggnet, e_optim, d_op
parser.add_argument("--output_layer_idx", type=int, default=23)
parser.add_argument('--vgg_ckpt', type=str, default='pretrained/vgg16.pth')
parser.add_argument('--which_encoder', type=str, default='idinvert')
parser.add_argument('--which_latent', type=str, default='w_shared')
parser.add_argument(
"--iter", type=int, default=800000, help="total training iterations"
)
Expand Down Expand Up @@ -458,11 +457,18 @@ def train(args, loader, encoder, generator, discriminator, vggnet, e_optim, d_op
torch.distributed.init_process_group(backend="nccl", init_method="env://")
synchronize()

args.latent = 512 # fixed, dim of z
args.n_latent = int(np.log2(args.size)) * 2 - 2 # used in Generator
args.latent = 512 # fixed, dim of w or z (same size)
if args.which_latent == 'w':
args.latent_full = args.latent * args.n_latent
elif args.which__latent == 'w_shared':
args.latent_full = args.latent
else:
raise NotImplementedError
args.n_mlp = 8

args.start_iter = 0
args.mixing = 0 # no mixing
# args.mixing = 0 # no mixing
util.set_log_dir(args)
util.print_args(parser, args)

Expand All @@ -484,25 +490,23 @@ def train(args, loader, encoder, generator, discriminator, vggnet, e_optim, d_op

if args.which_encoder == 'idinvert':
from idinvert_pytorch.models.stylegan_encoder_network import StyleGANEncoderNet
encoder = StyleGANEncoderNet(resolution=args.size, w_space_dim=args.latent).to(device)
e_ema = StyleGANEncoderNet(resolution=args.size, w_space_dim=args.latent).to(device)
encoder = StyleGANEncoderNet(resolution=args.size, w_space_dim=args.latent,
which_latent=args.which_latent, reshape_latent=True).to(device)
e_ema = StyleGANEncoderNet(resolution=args.size, w_space_dim=args.latent,
which_latent=args.which_latent, reshape_latent=True).to(device)
else:
from model import Encoder
encoder = Encoder(args.size, args.latent, channel_multiplier=args.channel_multiplier).to(device)
e_ema = Encoder(args.size, args.latent, channel_multiplier=args.channel_multiplier).to(device)
encoder = Encoder(args.size, args.latent, channel_multiplier=args.channel_multiplier,
which_latent=args.which_latent, reshape_latent=True).to(device)
e_ema = Encoder(args.size, args.latent, channel_multiplier=args.channel_multiplier,
which_latent=args.which_latent, reshape_latent=True).to(device)
e_ema.eval()
accumulate(e_ema, encoder, 0)

# TODO: what is this used for?
# g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
e_reg_ratio = args.e_reg_every / (args.e_reg_every + 1)
d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

# g_optim = optim.Adam(
# generator.parameters(),
# lr=args.lr * g_reg_ratio,
# betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
# )
e_optim = optim.Adam(
encoder.parameters(),
lr=args.lr * e_reg_ratio,
Expand Down

0 comments on commit 37c26b0

Please sign in to comment.