Aadi/track local queues (#20)
Co-authored-by: aaprasad <aaprasad.ucsd.edu>
aaprasad committed Apr 22, 2024
1 parent cd6d61a commit cfc5500
Showing 40 changed files with 2,666 additions and 913 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -76,6 +76,7 @@ target/

 # Jupyter Notebook
 .ipynb_checkpoints
+notebooks/
 
 # IPython
 profile_default/
57 changes: 40 additions & 17 deletions biogtr/config.py
@@ -10,6 +10,8 @@
 from omegaconf import DictConfig, OmegaConf
 from pprint import pprint
 from typing import Union, Iterable
+from pathlib import Path
+import os
 import pytorch_lightning as pl
 import torch

@@ -43,7 +45,7 @@ def __repr__(self):
         return f"Config({self.cfg})"
 
     def __str__(self):
-        """String representation of config class."""
+        """Return a string representation of config class."""
         return f"Config({self.cfg})"
 
     def set_hparams(self, hparams: dict) -> bool:
@@ -92,20 +94,21 @@ def get_tracker_cfg(self) -> dict:

     def get_gtr_runner(self):
         """Get lightning module for training, validation, and inference."""
-
         tracker_params = self.cfg.tracker
         optimizer_params = self.cfg.optimizer
         scheduler_params = self.cfg.scheduler
         loss_params = self.cfg.loss
         gtr_runner_params = self.cfg.runner
 
         if self.cfg.model.ckpt_path is not None and self.cfg.model.ckpt_path != "":
-            model = GTRRunner.load_from_checkpoint(self.cfg.model.ckpt_path,
-                                        tracker_cfg = tracker_params,
-                                        train_metrics=self.cfg.runner.train_metrics,
-                                        val_metrics=self.cfg.runner.val_metrics,
-                                        test_metrics=self.cfg.runner.test_metrics)
-
+            model = GTRRunner.load_from_checkpoint(
+                self.cfg.model.ckpt_path,
+                tracker_cfg=tracker_params,
+                train_metrics=self.cfg.runner.metrics.train,
+                val_metrics=self.cfg.runner.metrics.val,
+                test_metrics=self.cfg.runner.metrics.test,
+            )
+
         else:
             model_params = self.cfg.model
             model = GTRRunner(
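Note on the config shape this branch now expects: metric settings move from flat `runner.train_metrics` / `runner.val_metrics` / `runner.test_metrics` keys to a nested `runner.metrics` group. A minimal sketch of that layout with OmegaConf; the metric names below are placeholders, not taken from the repo's configs:

```python
from omegaconf import OmegaConf

# Hypothetical runner section matching the new nested-metrics access pattern.
cfg = OmegaConf.create(
    {
        "runner": {
            "metrics": {
                "train": ["num_switches"],  # placeholder metric names
                "val": ["num_switches"],
                "test": ["num_switches"],
            }
        }
    }
)

print(cfg.runner.metrics.train)  # -> ['num_switches']
```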
@@ -186,13 +189,13 @@ def get_dataloader(
             torch.multiprocessing.set_sharing_strategy("file_system")
         else:
             pin_memory = False
 
         return torch.utils.data.DataLoader(
             dataset=dataset,
             batch_size=1,
             pin_memory=pin_memory,
             collate_fn=dataset.no_batching_fn,
-            **dataloader_params
+            **dataloader_params,
         )
 
     def get_optimizer(self, params: Iterable) -> torch.optim.Optimizer:
@@ -238,7 +241,9 @@ def get_logger(self):
             A Logger with specified params
         """
         logger_params = OmegaConf.to_container(self.cfg.logging, resolve=True)
-        return init_logger(logger_params)
+        return init_logger(
+            logger_params, OmegaConf.to_container(self.cfg, resolve=True)
+        )
 
     def get_early_stopping(self) -> pl.callbacks.EarlyStopping:
         """Getter for lightning early stopping callback.
@@ -266,12 +271,25 @@ def get_checkpointing(self) -> pl.callbacks.ModelCheckpoint:

         else:
             dirpath = checkpoint_params["dirpath"]
 
+        dirpath = Path(dirpath).resolve()
+        if not Path(dirpath).exists():
+            try:
+                Path(dirpath).mkdir(parents=True, exist_ok=True)
+            except OSError as e:
+                print(
+                    f"Cannot create a new folder. Check the permissions to the given Checkpoint directory. \n {e}"
+                )
+
+        _ = checkpoint_params.pop("dirpath")
         checkpointers = []
         monitor = checkpoint_params.pop("monitor")
         for metric in monitor:
             checkpointer = pl.callbacks.ModelCheckpoint(
-                monitor=metric, dirpath=dirpath, **checkpoint_params
+                monitor=metric,
+                dirpath=dirpath,
+                filename=f"{{epoch}}-{{{metric}}}",
+                **checkpoint_params,
             )
             checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-best-{{{metric}}}"
             checkpointers.append(checkpointer)
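Each monitored metric now gets its own checkpoint filename template. The outer f-string substitutes only the metric name; the doubled braces leave `{epoch}` and `{<metric>}` intact for Lightning's `ModelCheckpoint` to fill in at save time. A quick plain-Python illustration (the metric names are placeholders):

```python
# The doubled braces survive f-string formatting, so only `metric` is substituted.
for metric in ["val_loss", "val_num_switches"]:  # placeholder metric names
    filename = f"{{epoch}}-{{{metric}}}"
    last = f"{{epoch}}-best-{{{metric}}}"
    print(filename, last)

# val_loss         -> "{epoch}-{val_loss}"          "{epoch}-best-{val_loss}"
# val_num_switches -> "{epoch}-{val_num_switches}"  "{epoch}-best-{val_num_switches}"
# ModelCheckpoint later expands these to names like "epoch=9-val_loss=0.123.ckpt".
```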
@@ -282,7 +300,7 @@ def get_trainer(
         callbacks: list[pl.callbacks.Callback],
         logger: pl.loggers.WandbLogger,
         devices: int = 1,
-        accelerator: str = None
+        accelerator: str = None,
     ) -> pl.Trainer:
         """Getter for the lightning trainer.
@@ -297,14 +315,19 @@
             A lightning Trainer with specified params
         """
         if "accelerator" not in self.cfg.trainer:
-            self.set_hparams({'trainer.accelerator': accelerator})
+            self.set_hparams({"trainer.accelerator": accelerator})
         if "devices" not in self.cfg.trainer:
-            self.set_hparams({'trainer.devices': devices})
+            self.set_hparams({"trainer.devices": devices})
 
         trainer_params = self.cfg.trainer
-
+        if "profiler" in trainer_params:
+            profiler = pl.profilers.AdvancedProfiler(filename="profile.txt")
+            trainer_params.pop("profiler")
+        else:
+            profiler = None
         return pl.Trainer(
             callbacks=callbacks,
             logger=logger,
+            profiler=profiler,
             **trainer_params,
         )
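Profiling is now toggled by the mere presence of a `profiler` key in the trainer config; the key's value is popped before the remaining params are forwarded to `pl.Trainer`. A standalone sketch of the same logic, with a plain dict standing in for `self.cfg.trainer` (values are illustrative only):

```python
import pytorch_lightning as pl

trainer_params = {"max_epochs": 5, "profiler": True}  # illustrative values

if "profiler" in trainer_params:
    # AdvancedProfiler records per-function timings and writes a profile.txt
    # report when the run finishes.
    profiler = pl.profilers.AdvancedProfiler(filename="profile.txt")
    trainer_params.pop("profiler")
else:
    profiler = None

trainer = pl.Trainer(profiler=profiler, **trainer_params)
```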