From 37c26b0c4139d440b7cb7346003190e8aebcd622 Mon Sep 17 00:00:00 2001
From: Ligong Han
Date: Mon, 11 Jan 2021 22:25:11 -0500
Subject: [PATCH] idinvert encoder, full w

---
 idinvert_pytorch |  2 +-
 model.py         | 37 +++++++++++++++++++++++++++++++------
 train_encoder.py | 36 ++++++++++++++++++++----------------
 3 files changed, 52 insertions(+), 23 deletions(-)

diff --git a/idinvert_pytorch b/idinvert_pytorch
index ce866b05..a018836a 160000
--- a/idinvert_pytorch
+++ b/idinvert_pytorch
@@ -1 +1 @@
-Subproject commit ce866b05f13ff0466e26f1641292b456315a75a0
+Subproject commit a018836a62dc42e416ff1bf3eac270d07a1d468c
diff --git a/model.py b/model.py
index 6c1e194e..6b4382ed 100755
--- a/model.py
+++ b/model.py
@@ -470,7 +470,7 @@ def forward(
         noise=None,
         randomize_noise=True,
     ):
-        if not input_is_latent: # `style' is z, then w = self.style(z)
+        if not input_is_latent: # `styles' holds z codes; map them to w = self.style(z)
             styles = [self.style(s) for s in styles]

         if noise is None:
@@ -497,7 +497,7 @@ def forward(
             if styles[0].ndim < 3:
                 # w is of dim [batch, 512], repeat at dim 1 for each layer
                 latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
-            else: # w is of dim [batch, num_layers, 512]
+            else: # w is of dim [batch, n_latent, 512]
                 latent = styles[0]

         else: # mixing
@@ -509,7 +509,7 @@ def forward(

             latent = torch.cat([latent, latent2], 1)

-        out = self.input(latent)
+        out = self.input(latent) # constant input; only the batch size of latent is used

         out = self.conv1(out, latent[:, 0], noise=noise[0])
         skip = self.to_rgb1(out, latent[:, 1])
@@ -661,7 +661,20 @@ def forward(self, input):


 class Encoder(nn.Module):
-    def __init__(self, size, out_dim=512, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
+    def __init__(
+        self,
+        size,
+        style_dim=512,
+        channel_multiplier=2,
+        blur_kernel=[1, 3, 3, 1],
+        which_latent='w',
+        reshape_latent=True,
+    ):
+        """
+        which_latent: 'w' predicts a different w for each block; 'w_shared' a
+        single w shared by all blocks; 'wb' a w plus a bias b for each block;
+        'wb_shared' a shared w with per-block biases. Only 'w' and 'w_shared' are implemented.
+        """
         super().__init__()

         channels = {
@@ -679,6 +692,11 @@ def __init__(self, size, out_dim=512, channel_multiplier=2, blur_kernel=[1, 3, 3
         convs = [ConvLayer(3, channels[size], 1)]

         log_size = int(math.log(size, 2))
+        self.n_latent = log_size * 2 - 2 # copied from Generator
+        self.n_noises = (log_size - 2) * 2 + 1
+        self.which_latent = which_latent
+        self.reshape_latent = reshape_latent
+        self.style_dim = style_dim

         in_channel = channels[size]
@@ -695,9 +713,15 @@ def __init__(self, size, out_dim=512, channel_multiplier=2, blur_kernel=[1, 3, 3
         self.stddev_feat = 1

         self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
+        if self.which_latent == 'w':
+            out_channel = style_dim * self.n_latent
+        elif self.which_latent == 'w_shared':
+            out_channel = style_dim
+        else:
+            raise NotImplementedError
         self.final_linear = nn.Sequential(
             EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
-            EqualLinear(channels[4], out_dim),
+            EqualLinear(channels[4], out_channel),
         )

     def forward(self, input):
@@ -717,5 +741,6 @@ def forward(self, input):
         out = out.view(batch, -1)
         out = self.final_linear(out)
-
+        if self.which_latent == 'w' and self.reshape_latent:
+            out = out.reshape(batch, self.n_latent, self.style_dim)
         return out
diff --git a/train_encoder.py b/train_encoder.py
index 0fb664a3..5031edb5 100755
--- a/train_encoder.py
+++ b/train_encoder.py
@@ -170,12 +170,10 @@ def train(args, loader, encoder, generator, discriminator, vggnet, e_optim, d_op
     if args.augment and args.augment_p == 0:
         ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device)

-    # sample_z = torch.randn(args.n_sample, args.latent, device=device)
-    sample_x = accumulate_batches(loader, args.n_sample)
-    sample_x = sample_x.to(device)
+    sample_x = accumulate_batches(loader, args.n_sample).to(device)

     requires_grad(generator, False) # always False
-    generator.eval()
+    generator.eval() # generator should be the EMA copy, kept in eval mode

     for idx in pbar:
         i = idx + args.start_iter
@@ -350,6 +348,7 @@ def train(args, loader, encoder, generator, discriminator, vggnet, e_optim, d_op
     parser.add_argument("--output_layer_idx", type=int, default=23)
     parser.add_argument('--vgg_ckpt', type=str, default='pretrained/vgg16.pth')
     parser.add_argument('--which_encoder', type=str, default='idinvert')
+    parser.add_argument('--which_latent', type=str, default='w_shared')
     parser.add_argument(
         "--iter", type=int, default=800000, help="total training iterations"
     )
@@ -458,11 +457,18 @@ def train(args, loader, encoder, generator, discriminator, vggnet, e_optim, d_op
         torch.distributed.init_process_group(backend="nccl", init_method="env://")
         synchronize()

-    args.latent = 512 # fixed, dim of z
+    args.n_latent = int(np.log2(args.size)) * 2 - 2 # used in Generator
+    args.latent = 512 # fixed; dim of both z and w (same size)
+    if args.which_latent == 'w':
+        args.latent_full = args.latent * args.n_latent
+    elif args.which_latent == 'w_shared':
+        args.latent_full = args.latent
+    else:
+        raise NotImplementedError
     args.n_mlp = 8
     args.start_iter = 0
-    args.mixing = 0 # no mixing
+    # args.mixing = 0 # no mixing

     util.set_log_dir(args)
     util.print_args(parser, args)
@@ -484,25 +490,23 @@ def train(args, loader, encoder, generator, discriminator, vggnet, e_optim, d_op
     if args.which_encoder == 'idinvert':
         from idinvert_pytorch.models.stylegan_encoder_network import StyleGANEncoderNet
-        encoder = StyleGANEncoderNet(resolution=args.size, w_space_dim=args.latent).to(device)
-        e_ema = StyleGANEncoderNet(resolution=args.size, w_space_dim=args.latent).to(device)
+        encoder = StyleGANEncoderNet(resolution=args.size, w_space_dim=args.latent,
+                                     which_latent=args.which_latent, reshape_latent=True).to(device)
+        e_ema = StyleGANEncoderNet(resolution=args.size, w_space_dim=args.latent,
+                                   which_latent=args.which_latent, reshape_latent=True).to(device)
     else:
         from model import Encoder
-        encoder = Encoder(args.size, args.latent, channel_multiplier=args.channel_multiplier).to(device)
-        e_ema = Encoder(args.size, args.latent, channel_multiplier=args.channel_multiplier).to(device)
+        encoder = Encoder(args.size, args.latent, channel_multiplier=args.channel_multiplier,
+                          which_latent=args.which_latent, reshape_latent=True).to(device)
+        e_ema = Encoder(args.size, args.latent, channel_multiplier=args.channel_multiplier,
+                        which_latent=args.which_latent, reshape_latent=True).to(device)
     e_ema.eval()
     accumulate(e_ema, encoder, 0) # TODO: what is this used for?

-    # g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
     e_reg_ratio = args.e_reg_every / (args.e_reg_every + 1)
     d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

-    # g_optim = optim.Adam(
-    #     generator.parameters(),
-    #     lr=args.lr * g_reg_ratio,
-    #     betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
-    # )
    e_optim = optim.Adam(
         encoder.parameters(),
         lr=args.lr * e_reg_ratio,
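
---

Net effect of the encoder change above: with which_latent='w' the encoder regresses a full "w+" code, one 512-d style vector per generator injection point (n_latent = 2*log2(size) - 2, i.e. 14 at 256x256), reshaped to [batch, n_latent, 512] so it can be fed to Generator with input_is_latent=True; with 'w_shared' it regresses a single w that Generator.forward broadcasts to every layer. The following is a minimal shape-check sketch, not part of the patch: it assumes model.py as patched above and the fused CUDA ops the codebase depends on; the batch size and resolution are illustrative.

    import math
    import torch
    from model import Encoder, Generator  # model.py as patched above

    size, style_dim = 256, 512
    n_latent = int(math.log(size, 2)) * 2 - 2  # 14 style inputs at 256x256

    enc = Encoder(size, style_dim=style_dim, which_latent='w', reshape_latent=True)
    gen = Generator(size, style_dim, n_mlp=8)

    x = torch.randn(2, 3, size, size)
    w_plus = enc(x)  # 'w': one code per injection point
    assert w_plus.shape == (2, n_latent, style_dim)

    # ndim == 3 takes the "w is of dim [batch, n_latent, 512]" branch in forward
    img, _ = gen([w_plus], input_is_latent=True)
    assert img.shape == x.shape

    enc_shared = Encoder(size, style_dim=style_dim, which_latent='w_shared')
    w = enc_shared(x)  # 'w_shared': a single w
    assert w.shape == (2, style_dim)  # forward unsqueezes/repeats it per layer

On the TODO above: assuming accumulate follows the usual stylegan2-pytorch definition (p1 = decay * p1 + (1 - decay) * p2 over named parameters), accumulate(e_ema, encoder, 0) with decay 0 simply copies the freshly built encoder's weights into e_ema, initializing the exponential moving average that later iterations update with a decay close to 1.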