
Commit b82fde8

full topk function

1 parent aa41a17 commit b82fde8

File tree: 4 files changed, +28 -8 lines

train.py (+3 -3)

@@ -250,9 +250,9 @@ def setup_training_loop_kwargs(
 
     if topk is not None:
         assert isinstance(topk, float)
-        args.loss_args.G_top_k = True
-        args.loss_args.G_top_k_gamma = topk
-        args.loss_args.G_top_k_frac = 0.5
+        args.loss_kwargs.G_top_k = True
+        args.loss_kwargs.G_top_k_gamma = topk
+        args.loss_kwargs.G_top_k_frac = 0.5
 
     # ---------------------------------------------------
     # Discriminator augmentation: aug, p, target, augpipe
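In short, the --topk value (asserted to be a float) becomes G_top_k_gamma, the per-"epoch" decay base of the fraction of fake samples kept for the generator loss, and G_top_k_frac = 0.5 floors that fraction at half the batch. An illustrative calculation of how the two interact (the 0.99 here is a hypothetical flag value, not a default from this commit):

    # Illustrative only: kept fraction = max(gamma ** epochs, frac).
    gamma, frac = 0.99, 0.5
    for epochs in (0, 25, 50, 100):
        print(epochs, max(gamma ** epochs, frac))   # -> 1.0, ~0.78, ~0.61, 0.5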

training/loss.py (+17)

@@ -34,6 +34,10 @@ def __init__(self, device, G_mapping, G_synthesis, D, augment_pipe=None, style_m
         self.pl_decay = pl_decay
         self.pl_weight = pl_weight
         self.pl_mean = torch.zeros([], device=device)
+        self.G_top_k = G_top_k
+        self.G_top_k_gamma = G_top_k_gamma
+        self.G_top_k_frac = G_top_k_frac
+
 
     def run_G(self, z, c, sync):
         with misc.ddp_sync(self.G_mapping, sync):

@@ -68,6 +72,19 @@ def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain
                 gen_logits = self.run_D(gen_img, gen_c, sync=False)
                 training_stats.report('Loss/scores/fake', gen_logits)
                 training_stats.report('Loss/signs/fake', gen_logits.sign())
+
+                # top-k function based on: https://github.com/dvschultz/stylegan2-ada/blob/main/training/loss.py#L102
+                if G_top_k:
+                    D_fake_scores = gen_logits
+                    k_frac = torch.max(self.G_top_k_gamma ** self.G_mapping.epochs, self.G_top_k_frac)
+                    print(k_frac)
+                    k = (torch.ceil(minibatch_size.type(torch.float) * k_frac)).type(torch.int)
+                    print(k)
+                    lowest_k_scores, _ = torch.topk(-torch.squeeze(D_fake_scores), k=k) # want smallest probabilities not largest
+                    print(lowest_k_scores)
+                    gen_logits = torch.expand(-lowest_k_scores, axis=1)
+                    print(gen_logits)
+
                 loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
                 training_stats.report('Loss/G/loss', loss_Gmain)
             with torch.autograd.profiler.record_function('Gmain_backward'):
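As committed, the new block has a few rough edges: it tests G_top_k rather than self.G_top_k, minibatch_size is not among accumulate_gradients' parameters (see the hunk header above; gen_z.shape[0] would give the batch size), torch.max is applied to two plain floats, and torch.expand is not a PyTorch function (unsqueeze would restore the dropped dimension). A minimal, self-contained sketch of the intended selection step, keeping the k smallest discriminator scores as the in-line comment asks for (some top-k implementations keep the highest-scoring fakes instead):

    import math
    import torch

    def topk_fake_logits(gen_logits, epochs, gamma, min_frac):
        """Sketch only: keep the k smallest fake logits for the generator loss.
        Assumes gen_logits of shape [batch_size, 1]; epochs, gamma (G_top_k_gamma)
        and min_frac (G_top_k_frac) are plain floats."""
        batch_size = gen_logits.shape[0]
        k_frac = max(gamma ** epochs, min_frac)         # anneal kept fraction toward min_frac
        k = int(math.ceil(batch_size * k_frac))
        # largest=False selects the k smallest scores directly, replacing the
        # negate / top-k / negate-again dance; unsqueeze restores shape [k, 1].
        kept, _ = torch.topk(gen_logits.squeeze(1), k=k, largest=False)
        return kept.unsqueeze(1)

Inside accumulate_gradients this would sit behind `if self.G_top_k:` and feed the reduced gen_logits into the existing softplus line unchanged.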

training/networks.py (+4 -3)

@@ -484,8 +484,6 @@ def __init__(self,
         mapping_kwargs = {}, # Arguments for MappingNetwork.
         synthesis_kwargs = {}, # Arguments for SynthesisNetwork.
         epochs = 0., # Track epoch count for top-k
-        nimg = 0,
-        total_kimg = 25000,
     ):
         super().__init__()
         self.z_dim = z_dim

@@ -496,13 +494,16 @@
         self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
         self.num_ws = self.synthesis.num_ws
         self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
-        self.epochs = float(100 * nimg / (total_kimg * 1000))
+        self.epochs = 0.
 
     def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs):
         ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
         img = self.synthesis(ws, **synthesis_kwargs)
         return img
 
+    def update_epochs(self, epoch):
+        self.epochs = epoch
+
 #----------------------------------------------------------------------------
 
 @persistence.persistent_class
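One thing worth flagging: the loss block above reads self.G_top_k_gamma ** self.G_mapping.epochs, whereas this commit stores the counter on the Generator wrapper. If, as in the upstream training loop, G_mapping is G.mapping (the MappingNetwork), that attribute lookup will not see G.epochs unless the mapping network tracks its own counter elsewhere in this fork. A hypothetical variant of the new method that would also cover that path (not part of this commit):

    # Hypothetical (not in this commit): mirror the counter onto the mapping
    # sub-network so a self.G_mapping.epochs lookup in the loss sees it too.
    def update_epochs(self, epoch):
        self.epochs = epoch
        self.mapping.epochs = epoch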

training/training_loop.py (+4 -2)

@@ -148,11 +148,11 @@ def training_loop(
     if rank == 0:
         print('Constructing networks...')
     common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels)
-    G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs, nimg, total_kimg).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
+    G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
     D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
     G_ema = copy.deepcopy(G).eval()
 
-    G.epochs = float(100 * nimg / (total_kimg * 1000)) # 100 total top k "epochs" in total_kimg
+    G.update_epochs( float(100 * nimg / (total_kimg * 1000)) ) # 100 total top k "epochs" in total_kimg
     print('starting G epochs: ',G.epochs)
 
     # Resume from existing pickle.

@@ -275,6 +275,8 @@ def training_loop(
         if batch_idx % phase.interval != 0:
             continue
 
+        G.update_epochs( float(100 * nimg / (total_kimg * 1000)) ) # 100 total top k "epochs" in total_kimg
+
         # Initialize gradient accumulation.
         if phase.start_event is not None:
             phase.start_event.record(torch.cuda.current_stream(device))
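The formula used in both hunks maps images seen so far onto a 0-100 "epoch" scale spanning the whole run, so the top-k anneal finishes exactly at total_kimg. A worked example, with total_kimg = 25000 used purely as an illustrative setting:

    # Worked example of the epoch counter: epochs runs linearly from 0 to 100.
    total_kimg = 25000
    for nimg in (0, 6_250_000, 12_500_000, 25_000_000):
        print(nimg, float(100 * nimg / (total_kimg * 1000)))   # 0.0, 25.0, 50.0, 100.0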
