
Commit

Merge branch 'aadi/to_slp' into aadi/track-local-queues
aaprasad authored Apr 22, 2024
2 parents 4fed899 + d7d2d09 commit 33b97ad
Showing 23 changed files with 630 additions and 137 deletions.
34 changes: 28 additions & 6 deletions biogtr/config.py
@@ -10,6 +10,8 @@
 from omegaconf import DictConfig, OmegaConf
 from pprint import pprint
 from typing import Union, Iterable
+from pathlib import Path
+import os
 import pytorch_lightning as pl
 import torch

@@ -102,9 +104,9 @@ def get_gtr_runner(self):
             model = GTRRunner.load_from_checkpoint(
                 self.cfg.model.ckpt_path,
                 tracker_cfg=tracker_params,
-                train_metrics=self.cfg.runner.train_metrics,
-                val_metrics=self.cfg.runner.val_metrics,
-                test_metrics=self.cfg.runner.test_metrics,
+                train_metrics=self.cfg.runner.metrics.train,
+                val_metrics=self.cfg.runner.metrics.val,
+                test_metrics=self.cfg.runner.metrics.test,
             )
 
         else:
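
For context: this rename implies the runner section of the config now nests its metric lists under a single `metrics` key instead of three top-level `*_metrics` keys. A minimal sketch of the assumed new shape (key names inferred from this diff; the metric values are placeholders, not taken from the repo):

    from omegaconf import OmegaConf

    # Hypothetical runner config illustrating the nested layout read above.
    cfg = OmegaConf.create(
        {
            "runner": {
                "metrics": {
                    "train": ["num_switches"],  # placeholder metric names
                    "val": ["num_switches"],
                    "test": ["num_switches"],
                }
            }
        }
    )
    print(cfg.runner.metrics.train)  # the access path used by get_gtr_runner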
@@ -239,7 +241,9 @@ def get_logger(self):
             A Logger with specified params
         """
         logger_params = OmegaConf.to_container(self.cfg.logging, resolve=True)
-        return init_logger(logger_params)
+        return init_logger(
+            logger_params, OmegaConf.to_container(self.cfg, resolve=True)
+        )
 
     def get_early_stopping(self) -> pl.callbacks.EarlyStopping:
         """Getter for lightning early stopping callback.
@@ -267,12 +271,25 @@ def get_checkpointing(self) -> pl.callbacks.ModelCheckpoint:
 
         else:
             dirpath = checkpoint_params["dirpath"]
+
+        dirpath = Path(dirpath).resolve()
+        if not Path(dirpath).exists():
+            try:
+                Path(dirpath).mkdir(parents=True, exist_ok=True)
+            except OSError as e:
+                print(
+                    f"Cannot create a new folder. Check the permissions to the given Checkpoint directory. \n {e}"
+                )
+
         _ = checkpoint_params.pop("dirpath")
         checkpointers = []
         monitor = checkpoint_params.pop("monitor")
         for metric in monitor:
             checkpointer = pl.callbacks.ModelCheckpoint(
-                monitor=metric, dirpath=dirpath, **checkpoint_params
+                monitor=metric,
+                dirpath=dirpath,
+                filename=f"{{epoch}}-{{{metric}}}",
+                **checkpoint_params,
             )
             checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-best-{{{metric}}}"
             checkpointers.append(checkpointer)
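
The doubled braces in the new `filename` argument are the non-obvious part: in an f-string, `{{` and `}}` escape to literal braces while `{metric}` is substituted immediately, so the string handed to Lightning still contains template fields. `ModelCheckpoint` then fills `{epoch}` and the metric field each time it writes a checkpoint. A quick illustration with an assumed metric name:

    metric = "val_loss"  # assumed; the real names come from the monitor list
    template = f"{{epoch}}-{{{metric}}}"
    print(template)  # -> "{epoch}-{val_loss}"
    # Lightning later resolves this to e.g. "epoch=4-val_loss=0.123.ckpt"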
@@ -303,9 +320,14 @@ def get_trainer(
         self.set_hparams({"trainer.devices": devices})
 
         trainer_params = self.cfg.trainer
+
+        if "profiler" in trainer_params:
+            profiler = pl.profilers.AdvancedProfiler(filename="profile.txt")
+            trainer_params.pop("profiler")
+        else:
+            profiler = None
         return pl.Trainer(
             callbacks=callbacks,
             logger=logger,
+            profiler=profiler,
             **trainer_params,
         )
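
The effect of this block: when the trainer config contains a `profiler` entry, it is swapped for an `AdvancedProfiler` (a cProfile wrapper) that writes per-function timings to `profile.txt` under the trainer's log directory; otherwise profiling stays off. A self-contained sketch with an assumed minimal config dict standing in for `cfg.trainer`:

    import pytorch_lightning as pl

    trainer_params = {"max_epochs": 1, "profiler": True}  # assumed config dict

    if "profiler" in trainer_params:
        # AdvancedProfiler wraps cProfile; filename="profile.txt" sends the
        # report to a file instead of stdout.
        profiler = pl.profilers.AdvancedProfiler(filename="profile.txt")
        trainer_params.pop("profiler")
    else:
        profiler = None

    trainer = pl.Trainer(profiler=profiler, **trainer_params)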