
Commit 3e5597f

Fixes to allow DDP
Had to roll back lightning due to this issue: Lightning-AI/pytorch-lightning#18803. The best model checkpoint is now saved using the `exp-id` (this also fixes the DDP issues with saving checkpoints to wandb). Removed `device` from the model call.
1 parent 8fcd577
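
The core DDP fix in ultrafast/model.py is passing `sync_dist` to every `self.log` call so that logged scalars are reduced across ranks when more than one device is in use. A minimal sketch of the pattern, not code from this repo (the class and metric names are illustrative):

import lightning.pytorch as pl
import torch

class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        loss = torch.tensor(0.0)  # stand-in for the real loss computation
        # sync_dist=True makes Lightning reduce the logged value across
        # DDP ranks; single-device runs are unaffected when it is False.
        self.log("train/loss", loss, sync_dist=self.trainer.num_devices > 1)
        return loss

Note that `sync_dist=True if self.trainer.num_devices > 1 else False`, as written throughout the diff below, is equivalent to the plain boolean expression above.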

File tree: 3 files changed, +17 −21 lines

pyproject.toml (+1 −1)

@@ -11,7 +11,7 @@ authors = [{ name = "Caleb Ellington", email = "cellingt@andrew.cmu.edu" },
     { name = "Abhinav Adduri"},
     { name = "Monica Dayao"}]
 dependencies = [
-    "lightning==2.4.0",
+    "lightning==2.0.8",
     "torch==2.4.0",
     "pandas==2.2.2",
     "wandb==0.17.8",

ultrafast/model.py (+10 −12)

@@ -169,7 +169,6 @@ def __init__(
         dropout=0,
         lr=1e-4,
         contrastive=False,
-        device='cpu',
         args=None,
     ):
         super().__init__()
@@ -183,7 +182,6 @@ def __init__(
         self.classify = classify
         self.contrastive = contrastive
         self.args = args
-        self.device_ = device

         if args.drug_layers == 1:
             self.drug_projector = nn.Sequential(
@@ -313,7 +311,7 @@ def training_step(self, batch, batch_idx):
             loss = self.contrastive_step(batch)
             self.manual_backward(loss)
             con_opt.step()
-            self.log("train/contrastive_loss", loss)
+            self.log("train/contrastive_loss", loss, sync_dist=True if self.trainer.num_devices > 1 else False)
         else:
             if self.contrastive:
                 opt, _ = self.optimizers()
@@ -323,7 +321,7 @@ def training_step(self, batch, batch_idx):
             loss = self.non_contrastive_step(batch)
             self.manual_backward(loss)
             opt.step()
-            self.log("train/loss", loss)
+            self.log("train/loss", loss, sync_dist=True if self.trainer.num_devices > 1 else False)

         return loss

@@ -332,18 +330,18 @@ def on_train_epoch_end(self):
         if self.contrastive:
             if self.current_epoch % 2 == 0: # supervised learning epoch
                 sch[0].step()
-                self.log("train/lr", sch[0].get_lr()[0])
+                self.log("train/lr", sch[0].get_lr()[0], sync_dist=True if self.trainer.num_devices > 1 else False)
             else: # contrastive learning epoch
                 sch[1].step()
                 self.contrastive_loss_fct.step()
-                self.log("train/triplet_margin", self.contrastive_loss_fct.margin)
-                self.log("train/contrastive_lr", sch[1].get_lr()[0])
+                self.log("train/triplet_margin", self.contrastive_loss_fct.margin, sync_dist=True if self.trainer.num_devices > 1 else False)
+                self.log("train/contrastive_lr", sch[1].get_lr()[0], sync_dist=True if self.trainer.num_devices > 1 else False)
         else:
-            self.log("train/lr", sch.get_lr()[0])
+            self.log("train/lr", sch.get_lr()[0], sync_dist=True if self.trainer.num_devices > 1 else False)
             sch.step()

     def validation_step(self, batch, batch_idx):
-        if self.global_step == 0 and not self.args.no_wandb:
+        if self.global_step == 0 and self.global_rank == 0 and not self.args.no_wandb:
             wandb.define_metric("val/aupr", summary="max")
         drug, protein, label = batch
         similarity = self.forward(drug, protein)
@@ -352,7 +350,7 @@ def validation_step(self, batch, batch_idx):
             similarity = torch.squeeze(F.sigmoid(similarity))

         loss = self.loss_fct(similarity, label)
-        self.log("val/loss", loss)
+        self.log("val/loss", loss, sync_dist=True if self.trainer.num_devices > 1 else False)

         self.val_step_outputs.extend(similarity)
         self.val_step_targets.extend(label)
@@ -365,7 +363,7 @@ def on_validation_epoch_end(self):
                 metric(torch.Tensor(self.val_step_outputs), torch.Tensor(self.val_step_targets).to(torch.int))
             else:
                 metric(torch.Tensor(self.val_step_outputs).cuda(), torch.Tensor(self.val_step_targets).to(torch.float).cuda())
-            self.log(f"val/{name}", metric, on_step=False, on_epoch=True)
+            self.log(f"val/{name}", metric, on_step=False, on_epoch=True, sync_dist=True if self.trainer.num_devices > 1 else False)

         self.val_step_outputs.clear()
         self.val_step_targets.clear()
@@ -388,7 +386,7 @@ def on_test_epoch_end(self):
                 metric(torch.Tensor(self.test_step_outputs), torch.Tensor(self.test_step_targets).to(torch.int))
             else:
                 metric(torch.Tensor(self.test_step_outputs).cuda(), torch.Tensor(self.test_step_targets).to(torch.float).cuda())
-            self.log(f"test/{name}", metric, on_step=False, on_epoch=True)
+            self.log(f"test/{name}", metric, on_step=False, on_epoch=True, sync_dist=True if self.trainer.num_devices > 1 else False)

         self.test_step_outputs.clear()
         self.test_step_targets.clear()
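
Beyond `sync_dist`, the other DDP-safety change above guards a process-wide wandb call so it runs only on rank 0: every rank executes `validation_step` under DDP, but wandb is typically initialized only in the rank-0 process. A minimal sketch of the guard (class name is illustrative):

import wandb
import lightning.pytorch as pl

class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        # global_rank is available on every rank; only rank 0 talks to wandb.
        if self.global_step == 0 and self.global_rank == 0:
            wandb.define_metric("val/aupr", summary="max")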

ultrafast/train.py (+6 −8)

@@ -112,7 +112,6 @@ def train(
     config.update(args_overrides)

     save_dir = f'{config.get("model_save_dir", ".")}/{config.experiment_id}'
-    os.makedirs(save_dir, exist_ok=True)

     # Set CUDA device
     device_no = config.device
@@ -221,16 +220,17 @@ def train(
         contrastive=config.contrastive,
         num_layers_target=config.num_layers_target,
         dropout=config.dropout,
-        device=device,
         args=config
     )

     if not config.no_wandb:
-        wandb_logger = WandbLogger(project=config.wandb_proj, log_model="gradients")
+        wandb_logger = WandbLogger(project=config.wandb_proj, log_model=True)
         wandb_logger.watch(model)
-        wandb_logger.experiment.config.update(OmegaConf.to_container(config, resolve=True, throw_on_missing=True))
+        if hasattr(wandb_logger.experiment.config, 'update'):
+            wandb_logger.experiment.config.update(OmegaConf.to_container(config, resolve=True, throw_on_missing=True))

-    checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor=config.watch_metric, mode="max", filename=config.task, verbose=True)
+    checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor=config.watch_metric, mode="max", filename=config.task,
+                                                       dirpath=save_dir, verbose=True)
     # Train model
     trainer = pl.Trainer(
         accelerator="auto",
@@ -245,11 +245,9 @@ def train(
         datamodule=datamodule,
     )

-    wandb.save(f'{config.task}.ckpt')
-
     # Test model using best weights
     trainer.test(datamodule=datamodule, ckpt_path=checkpoint_callback.best_model_path)


 if __name__ == '__main__':
-    train()
+    train()
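
The checkpoint changes above drop the manual `wandb.save` upload, which under DDP would run on every rank, in favor of `ModelCheckpoint(dirpath=save_dir, ...)` writing under the per-experiment directory while `WandbLogger(log_model=True)` handles the upload from the logger's rank-0 process. A sketch of the resulting wiring (project, metric, and path names are placeholders, not the repo's config values):

import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger

save_dir = "checkpoints/my-exp-id"  # stands in for {model_save_dir}/{experiment_id}
wandb_logger = WandbLogger(project="my-project", log_model=True)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="val/aupr",   # config.watch_metric in the diff
    mode="max",
    filename="my-task",   # config.task in the diff
    dirpath=save_dir,     # per-experiment checkpoint directory
    verbose=True,
)
trainer = pl.Trainer(accelerator="auto", logger=wandb_logger,
                     callbacks=[checkpoint_callback])

ModelCheckpoint creates `dirpath` itself, which is presumably why the explicit `os.makedirs(save_dir, exist_ok=True)` was also removed.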
