@@ -34,6 +34,10 @@ def __init__(self, device, G_mapping, G_synthesis, D, augment_pipe=None, style_m
         self.pl_decay = pl_decay
         self.pl_weight = pl_weight
         self.pl_mean = torch.zeros([], device=device)
+        self.G_top_k = G_top_k
+        self.G_top_k_gamma = G_top_k_gamma
+        self.G_top_k_frac = G_top_k_frac
+

     def run_G(self, z, c, sync):
         with misc.ddp_sync(self.G_mapping, sync):
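For context, a minimal sketch of the constructor after this change, assuming the three values arrive as keyword arguments with defaults; the extra kwargs, their default values, and the pre-existing parameter list shown here are assumptions based on the upstream `StyleGAN2Loss.__init__`, not part of this diff:

```python
# Hypothetical constructor signature (the new kwargs and all defaults are assumed, not from the diff).
def __init__(self, device, G_mapping, G_synthesis, D, augment_pipe=None,
             style_mixing_prob=0.9, r1_gamma=10, pl_batch_shrink=2,
             pl_decay=0.01, pl_weight=2,
             G_top_k=False, G_top_k_gamma=0.99, G_top_k_frac=0.5):
    super().__init__()
    ...
    self.G_top_k = G_top_k              # enable/disable top-k training of G
    self.G_top_k_gamma = G_top_k_gamma  # per-epoch decay factor for the kept fraction
    self.G_top_k_frac = G_top_k_frac    # lower bound on the kept fraction
```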
@@ -68,6 +72,19 @@ def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain
                 gen_logits = self.run_D(gen_img, gen_c, sync=False)
                 training_stats.report('Loss/scores/fake', gen_logits)
                 training_stats.report('Loss/signs/fake', gen_logits.sign())
+
+                # top-k function based on: https://github.com/dvschultz/stylegan2-ada/blob/main/training/loss.py#L102
+                if self.G_top_k:
+                    D_fake_scores = gen_logits
+                    k_frac = max(self.G_top_k_gamma ** self.G_mapping.epochs, self.G_top_k_frac)
+                    print(k_frac)
+                    k = int(np.ceil(gen_z.shape[0] * k_frac))
+                    print(k)
+                    lowest_k_scores, _ = torch.topk(-torch.squeeze(D_fake_scores), k=k) # want smallest probabilities not largest
+                    print(lowest_k_scores)
+                    gen_logits = torch.unsqueeze(-lowest_k_scores, dim=1)
+                    print(gen_logits)
+
                 loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
                 training_stats.report('Loss/G/loss', loss_Gmain)
             with torch.autograd.profiler.record_function('Gmain_backward'):
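Taken out of the training loop, the selection this branch performs can be sketched in isolation. This is a minimal illustration, assuming a [N, 1] logits tensor and example values for the gamma, epoch count, and floor fraction; none of these numbers come from the PR:

```python
import math
import torch

# Illustrative stand-ins (assumed values, not from this PR).
G_top_k_gamma, G_top_k_frac, epochs = 0.99, 0.5, 10.0
gen_logits = torch.randn(16, 1)                  # D scores for a fake minibatch, shape [N, 1]

# Fraction of the batch to keep: decays per epoch, floored at G_top_k_frac.
k_frac = max(G_top_k_gamma ** epochs, G_top_k_frac)
k = math.ceil(gen_logits.shape[0] * k_frac)

# topk of the negated scores picks the k *smallest* logits
# ("want smallest probabilities not largest", per the comment in the diff).
lowest_k_scores, _ = torch.topk(-gen_logits.squeeze(), k=k)
topk_logits = (-lowest_k_scores).unsqueeze(1)    # restore shape [k, 1]

# Non-saturating generator loss computed only on the kept samples.
loss_Gmain = torch.nn.functional.softplus(-topk_logits)
print(loss_Gmain.shape)                          # torch.Size([15, 1]) for these example values
```

With gamma close to 1, k starts near the full batch size and shrinks over training toward `G_top_k_frac * N`, so early epochs use (almost) every fake sample and later epochs use only the selected subset.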