Commit dd51119

Aadi/refactor-tracker (#23)
Co-authored-by: aaprasad <aaprasad.ucsd.edu>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
aaprasad and coderabbitai[bot] authored Apr 22, 2024
1 parent e7ca49f commit dd51119
Showing 40 changed files with 3,138 additions and 977 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -76,6 +76,7 @@ target/

# Jupyter Notebook
.ipynb_checkpoints
notebooks/

# IPython
profile_default/
83 changes: 59 additions & 24 deletions biogtr/config.py
@@ -1,5 +1,6 @@
# to implement - config class that handles getters/setters
"""Data structures for handling config parsing."""

from biogtr.datasets.microscopy_dataset import MicroscopyDataset
from biogtr.datasets.sleap_dataset import SleapDataset
from biogtr.datasets.cell_tracking_dataset import CellTrackingDataset
@@ -10,6 +11,7 @@
from omegaconf import DictConfig, OmegaConf
from pprint import pprint
from typing import Union, Iterable
from pathlib import Path
import pytorch_lightning as pl
import torch

@@ -43,7 +45,7 @@ def __repr__(self):
return f"Config({self.cfg})"

def __str__(self):
"""String representation of config class."""
"""Return a string representation of config class."""
return f"Config({self.cfg})"

def set_hparams(self, hparams: dict) -> bool:
@@ -92,20 +94,33 @@ def get_tracker_cfg(self) -> dict:

def get_gtr_runner(self):
"""Get lightning module for training, validation, and inference."""
model_params = self.cfg.model
tracker_params = self.cfg.tracker
optimizer_params = self.cfg.optimizer
scheduler_params = self.cfg.scheduler
loss_params = self.cfg.loss
gtr_runner_params = self.cfg.runner
return GTRRunner(
model_params,
tracker_params,
loss_params,
optimizer_params,
scheduler_params,
**gtr_runner_params,
)

if self.cfg.model.ckpt_path is not None and self.cfg.model.ckpt_path != "":
model = GTRRunner.load_from_checkpoint(
self.cfg.model.ckpt_path,
tracker_cfg=tracker_params,
train_metrics=self.cfg.runner.metrics.train,
val_metrics=self.cfg.runner.metrics.val,
test_metrics=self.cfg.runner.metrics.test,
)

else:
model_params = self.cfg.model
model = GTRRunner(
model_params,
tracker_params,
loss_params,
optimizer_params,
scheduler_params,
**gtr_runner_params,
)

return model

def get_dataset(
self, mode: str
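The new branch in get_gtr_runner either restores a GTRRunner from model.ckpt_path or constructs a fresh one from the model, loss, optimizer, and scheduler sections. Below is a minimal usage sketch, assuming Config wraps a loaded DictConfig as it does elsewhere in this file; the file paths and values are placeholders, not part of this commit.

from omegaconf import OmegaConf
from biogtr.config import Config

base_cfg = OmegaConf.load("configs/base.yaml")  # hypothetical config file

# Empty checkpoint path: a fresh GTRRunner is built from the config sections.
base_cfg.model.ckpt_path = ""
runner = Config(base_cfg).get_gtr_runner()

# Checkpoint path set: the runner is restored via GTRRunner.load_from_checkpoint,
# with tracker_cfg and the train/val/test metric lists taken from the current config.
base_cfg.model.ckpt_path = "models/example.ckpt"  # hypothetical checkpoint
runner = Config(base_cfg).get_gtr_runner()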
@@ -174,13 +189,13 @@ def get_dataloader(
torch.multiprocessing.set_sharing_strategy("file_system")
else:
pin_memory = False

return torch.utils.data.DataLoader(
dataset=dataset,
batch_size=1,
pin_memory=pin_memory,
collate_fn=dataset.no_batching_fn,
**dataloader_params
**dataloader_params,
)

def get_optimizer(self, params: Iterable) -> torch.optim.Optimizer:
@@ -225,8 +240,10 @@ def get_logger(self):
Returns:
A Logger with specified params
"""
logger_params = self.cfg.logging
return init_logger(logger_params)
logger_params = OmegaConf.to_container(self.cfg.logging, resolve=True)
return init_logger(
logger_params, OmegaConf.to_container(self.cfg, resolve=True)
)

def get_early_stopping(self) -> pl.callbacks.EarlyStopping:
"""Getter for lightning early stopping callback.
@@ -254,12 +271,25 @@ def get_checkpointing(self) -> pl.callbacks.ModelCheckpoint:

else:
dirpath = checkpoint_params["dirpath"]

dirpath = Path(dirpath).resolve()
if not Path(dirpath).exists():
try:
Path(dirpath).mkdir(parents=True, exist_ok=True)
except OSError as e:
print(
f"Cannot create a new folder. Check the permissions to the given Checkpoint directory. \n {e}"
)

_ = checkpoint_params.pop("dirpath")
checkpointers = []
monitor = checkpoint_params.pop("monitor")
for metric in monitor:
checkpointer = pl.callbacks.ModelCheckpoint(
monitor=metric, dirpath=dirpath, **checkpoint_params
monitor=metric,
dirpath=dirpath,
filename=f"{{epoch}}-{{{metric}}}",
**checkpoint_params,
)
checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-best-{{{metric}}}"
checkpointers.append(checkpointer)
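The checkpointing hunk resolves dirpath, creates it if needed, and then builds one ModelCheckpoint per monitored metric with an explicit filename template. A standalone sketch of the same pattern, with a hypothetical directory and placeholder metric names:

from pathlib import Path
import pytorch_lightning as pl

dirpath = Path("./models/example-run").resolve()  # hypothetical directory
dirpath.mkdir(parents=True, exist_ok=True)

checkpointers = []
for metric in ["val_loss", "val_acc"]:  # placeholder metric keys
    ckpt = pl.callbacks.ModelCheckpoint(
        monitor=metric,
        dirpath=str(dirpath),
        filename=f"{{epoch}}-{{{metric}}}",  # renders as e.g. "epoch=3-val_loss=0.12"
        save_last=True,  # save_last assumed here so CHECKPOINT_NAME_LAST takes effect
    )
    ckpt.CHECKPOINT_NAME_LAST = f"{{epoch}}-best-{{{metric}}}"
    checkpointers.append(ckpt)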
@@ -269,30 +299,35 @@ def get_trainer(
self,
callbacks: list[pl.callbacks.Callback],
logger: pl.loggers.WandbLogger,
accelerator: str,
devices: int,
devices: int = 1,
accelerator: str = None,
) -> pl.Trainer:
"""Getter for the lightning trainer.
Args:
callbacks: a list of lightning callbacks preconfigured to be used
for training
logger: the Wandb logger used for logging during training
accelerator: either "gpu" or "cpu" specifies which device to use
devices: The number of gpus to be used. 0 means cpu
accelerator: either "gpu" or "cpu" specifies which device to use
Returns:
A lightning Trainer with specified params
"""
if "accelerator" not in self.cfg.trainer:
self.set_hparams({"trainer.accelerator": accelerator})
if "devices" not in self.cfg.trainer:
self.set_hparams({"trainer.devices": devices})

trainer_params = self.cfg.trainer
if "profiler" in trainer_params:
profiler = pl.profilers.AdvancedProfiler(filename="profile.txt")
trainer_params.pop("profiler")
else:
profiler = None
return pl.Trainer(
callbacks=callbacks,
logger=logger,
accelerator=accelerator,
devices=devices,
profiler=profiler,
**trainer_params,
)

def get_ckpt_path(self):
"""Get model ckpt path for loading."""
return self.cfg.model.ckpt_path
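With the reordered get_trainer signature, accelerator and devices are optional; they are only written into cfg.trainer when the config does not already define them, and a "profiler" entry is swapped for an AdvancedProfiler instance. A hedged sketch of how the getters in this file compose, assuming Config wraps a loaded DictConfig and that get_checkpointing returns the list of checkpointers it builds; names and paths are illustrative, not from the commit.

from omegaconf import OmegaConf
from biogtr.config import Config

cfg = Config(OmegaConf.load("configs/train.yaml"))  # hypothetical config file

runner = cfg.get_gtr_runner()
logger = cfg.get_logger()
callbacks = [cfg.get_early_stopping(), *cfg.get_checkpointing()]  # assumes a list is returned

# accelerator/devices omitted here: values come from cfg.trainer when present,
# otherwise the new defaults (devices=1, accelerator=None) are injected.
trainer = cfg.get_trainer(callbacks=callbacks, logger=logger)
# dataloaders built via get_dataset/get_dataloader would then be passed to trainer.fit(...)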