@@ -169,7 +169,6 @@ def __init__(
         dropout=0,
         lr=1e-4,
         contrastive=False,
-        device='cpu',
         args=None,
     ):
         super().__init__()
@@ -183,7 +182,6 @@ def __init__(
         self.classify = classify
         self.contrastive = contrastive
         self.args = args
-        self.device_ = device
 
         if args.drug_layers == 1:
             self.drug_projector = nn.Sequential(
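The two removals above drop the hand-threaded `device` argument: a `LightningModule` already knows where the Trainer has placed it through `self.device`, so caching a device string at construction time is redundant and goes stale once the model is moved. A minimal sketch of that pattern (class and layer names are illustrative, not from this PR):

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    def __init__(self):
        # No device argument: the Trainer handles placement.
        super().__init__()
        self.layer = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # self.device always reflects the module's current placement,
        # so tensors created mid-step land on the right accelerator.
        noise = 0.01 * torch.randn(x.shape, device=self.device)
        return nn.functional.mse_loss(self.layer(x + noise), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```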
@@ -313,7 +311,7 @@ def training_step(self, batch, batch_idx):
             loss = self.contrastive_step(batch)
             self.manual_backward(loss)
             con_opt.step()
-            self.log("train/contrastive_loss", loss)
+            self.log("train/contrastive_loss", loss, sync_dist=True if self.trainer.num_devices > 1 else False)
         else:
             if self.contrastive:
                 opt, _ = self.optimizers()
@@ -323,7 +321,7 @@ def training_step(self, batch, batch_idx):
             loss = self.non_contrastive_step(batch)
             self.manual_backward(loss)
             opt.step()
-            self.log("train/loss", loss)
+            self.log("train/loss", loss, sync_dist=True if self.trainer.num_devices > 1 else False)
 
         return loss
 
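Every `self.log` call in this diff gains the same conditional: under multi-device training (DDP), `sync_dist=True` reduces the logged value across ranks so the recorded metric is the mean over all workers rather than rank 0's local value, while on a single device the flag stays `False` and no collective op is issued. One way to avoid repeating the condition at each call site would be a small helper (hypothetical, not part of this PR):

```python
import pytorch_lightning as pl


class SyncedLoggingMixin(pl.LightningModule):
    """Hypothetical mixin: write the num_devices check once instead of
    at every self.log call site."""

    def log_synced(self, name, value, **kwargs):
        # Equivalent to the PR's `True if ... else False` expression;
        # sync_dist averages the value across DDP ranks when True.
        self.log(name, value, sync_dist=self.trainer.num_devices > 1, **kwargs)
```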
@@ -332,18 +330,18 @@ def on_train_epoch_end(self):
         if self.contrastive:
             if self.current_epoch % 2 == 0:  # supervised learning epoch
                 sch[0].step()
-                self.log("train/lr", sch[0].get_lr()[0])
+                self.log("train/lr", sch[0].get_lr()[0], sync_dist=True if self.trainer.num_devices > 1 else False)
             else:  # contrastive learning epoch
                 sch[1].step()
                 self.contrastive_loss_fct.step()
-                self.log("train/triplet_margin", self.contrastive_loss_fct.margin)
-                self.log("train/contrastive_lr", sch[1].get_lr()[0])
+                self.log("train/triplet_margin", self.contrastive_loss_fct.margin, sync_dist=True if self.trainer.num_devices > 1 else False)
+                self.log("train/contrastive_lr", sch[1].get_lr()[0], sync_dist=True if self.trainer.num_devices > 1 else False)
         else:
-            self.log("train/lr", sch.get_lr()[0])
+            self.log("train/lr", sch.get_lr()[0], sync_dist=True if self.trainer.num_devices > 1 else False)
             sch.step()
 
     def validation_step(self, batch, batch_idx):
-        if self.global_step == 0 and not self.args.no_wandb:
+        if self.global_step == 0 and self.global_rank == 0 and not self.args.no_wandb:
             wandb.define_metric("val/aupr", summary="max")
         drug, protein, label = batch
         similarity = self.forward(drug, protein)
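The `validation_step` change adds `self.global_rank == 0` to the guard: under DDP every rank runs `validation_step`, so without it each worker would call `wandb.define_metric`, even though only rank 0 normally owns the active wandb run. Lightning ships a decorator that expresses the same rank-0 intent, sketched here as an alternative (assumes `wandb`; not code from this PR):

```python
import wandb
from pytorch_lightning.utilities import rank_zero_only


@rank_zero_only
def define_val_metrics():
    # Body executes on global rank 0 only, so the metric is defined
    # once per run instead of once per DDP worker.
    wandb.define_metric("val/aupr", summary="max")
```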
@@ -352,7 +350,7 @@ def validation_step(self, batch, batch_idx):
             similarity = torch.squeeze(F.sigmoid(similarity))
 
         loss = self.loss_fct(similarity, label)
-        self.log("val/loss", loss)
+        self.log("val/loss", loss, sync_dist=True if self.trainer.num_devices > 1 else False)
 
         self.val_step_outputs.extend(similarity)
         self.val_step_targets.extend(label)
@@ -365,7 +363,7 @@ def on_validation_epoch_end(self):
                 metric(torch.Tensor(self.val_step_outputs), torch.Tensor(self.val_step_targets).to(torch.int))
             else:
                 metric(torch.Tensor(self.val_step_outputs).cuda(), torch.Tensor(self.val_step_targets).to(torch.float).cuda())
-            self.log(f"val/{name}", metric, on_step=False, on_epoch=True)
+            self.log(f"val/{name}", metric, on_step=False, on_epoch=True, sync_dist=True if self.trainer.num_devices > 1 else False)
 
         self.val_step_outputs.clear()
         self.val_step_targets.clear()
@@ -388,7 +386,7 @@ def on_test_epoch_end(self):
                 metric(torch.Tensor(self.test_step_outputs), torch.Tensor(self.test_step_targets).to(torch.int))
             else:
                 metric(torch.Tensor(self.test_step_outputs).cuda(), torch.Tensor(self.test_step_targets).to(torch.float).cuda())
-            self.log(f"test/{name}", metric, on_step=False, on_epoch=True)
+            self.log(f"test/{name}", metric, on_step=False, on_epoch=True, sync_dist=True if self.trainer.num_devices > 1 else False)
 
         self.test_step_outputs.clear()
         self.test_step_targets.clear()
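For reference, the `num_devices > 1` condition used throughout this patch only flips to `True` under a multi-device Trainer such as the following (arguments are illustrative):

```python
import pytorch_lightning as pl

# Two GPUs with DDP: trainer.num_devices == 2, so every guarded
# self.log call above passes sync_dist=True.
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp")
# trainer.fit(model)  # model assumed to be defined elsewhere
```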