Aadi/sleap-anchors #21

Merged
merged 14 commits on Apr 22, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -76,6 +76,7 @@ target/

# Jupyter Notebook
.ipynb_checkpoints
notebooks/

# IPython
profile_default/
77 changes: 54 additions & 23 deletions biogtr/config.py
@@ -10,6 +10,8 @@
from omegaconf import DictConfig, OmegaConf
from pprint import pprint
from typing import Union, Iterable
from pathlib import Path
import os
import pytorch_lightning as pl
import torch

@@ -43,7 +45,7 @@ def __repr__(self):
return f"Config({self.cfg})"

def __str__(self):
"""String representation of config class."""
"""Return a string representation of config class."""
return f"Config({self.cfg})"

def set_hparams(self, hparams: dict) -> bool:
@@ -92,20 +94,33 @@ def get_tracker_cfg(self) -> dict:

def get_gtr_runner(self):
"""Get lightning module for training, validation, and inference."""
model_params = self.cfg.model
tracker_params = self.cfg.tracker
optimizer_params = self.cfg.optimizer
scheduler_params = self.cfg.scheduler
loss_params = self.cfg.loss
gtr_runner_params = self.cfg.runner
return GTRRunner(
model_params,
tracker_params,
loss_params,
optimizer_params,
scheduler_params,
**gtr_runner_params,
)

if self.cfg.model.ckpt_path is not None and self.cfg.model.ckpt_path != "":
model = GTRRunner.load_from_checkpoint(
self.cfg.model.ckpt_path,
tracker_cfg=tracker_params,
train_metrics=self.cfg.runner.metrics.train,
val_metrics=self.cfg.runner.metrics.val,
test_metrics=self.cfg.runner.metrics.test,
)

else:
model_params = self.cfg.model
model = GTRRunner(
model_params,
tracker_params,
loss_params,
optimizer_params,
scheduler_params,
**gtr_runner_params,
)

return model

def get_dataset(
self, mode: str
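A note on the `get_gtr_runner` change above: when `model.ckpt_path` is set, the runner is restored with `load_from_checkpoint` and only the inference-time settings (tracker config, train/val/test metrics) are overridden; otherwise a fresh `GTRRunner` is built from the config. Below is a minimal sketch of the same load-or-construct pattern using a toy LightningModule instead of `GTRRunner`; the names are illustrative and not part of this PR.

```python
from typing import Optional

import torch
import pytorch_lightning as pl


class ToyRunner(pl.LightningModule):
    """Stand-in for GTRRunner; hyperparameters are saved so a checkpoint can rebuild it."""

    def __init__(self, hidden: int = 8):
        super().__init__()
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(4, hidden)


def get_runner(ckpt_path: Optional[str]) -> ToyRunner:
    # Resume from a checkpoint when a non-empty path is configured,
    # otherwise construct the module from config defaults.
    if ckpt_path:  # covers both None and ""
        return ToyRunner.load_from_checkpoint(ckpt_path)
    return ToyRunner()
```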
@@ -174,13 +189,13 @@ def get_dataloader(
torch.multiprocessing.set_sharing_strategy("file_system")
else:
pin_memory = False

return torch.utils.data.DataLoader(
dataset=dataset,
batch_size=1,
pin_memory=pin_memory,
collate_fn=dataset.no_batching_fn,
**dataloader_params
**dataloader_params,
)

def get_optimizer(self, params: Iterable) -> torch.optim.Optimizer:
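The dataloader hunk keeps `batch_size=1`, uses the dataset's `no_batching_fn` to pass samples through without stacking, and only pins memory when a GPU is in play. A small, self-contained sketch of that combination follows; the dataset and collate function are stand-ins, not the repo's classes.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyFrameDataset(Dataset):
    """Stand-in dataset returning variable-sized items that should not be stacked."""

    def __len__(self) -> int:
        return 4

    def __getitem__(self, idx: int):
        return {"frame_id": idx, "instances": torch.zeros(idx + 1, 2)}

    def no_batching_fn(self, batch):
        # Return the list of samples unchanged instead of collating into tensors.
        return batch


dataset = ToyFrameDataset()
pin_memory = torch.cuda.is_available()
if pin_memory:
    # The "file_system" sharing strategy is commonly used to avoid
    # "too many open files" errors when workers pass many tensors back.
    torch.multiprocessing.set_sharing_strategy("file_system")

loader = DataLoader(
    dataset=dataset,
    batch_size=1,
    pin_memory=pin_memory,
    collate_fn=dataset.no_batching_fn,
)
```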
@@ -225,8 +240,10 @@ def get_logger(self):
Returns:
A Logger with specified params
"""
logger_params = self.cfg.logging
return init_logger(logger_params)
logger_params = OmegaConf.to_container(self.cfg.logging, resolve=True)
return init_logger(
logger_params, OmegaConf.to_container(self.cfg, resolve=True)
)

def get_early_stopping(self) -> pl.callbacks.EarlyStopping:
"""Getter for lightning early stopping callback.
@@ -254,12 +271,25 @@ def get_checkpointing(self) -> pl.callbacks.ModelCheckpoint:

else:
dirpath = checkpoint_params["dirpath"]

dirpath = Path(dirpath).resolve()
if not Path(dirpath).exists():
try:
Path(dirpath).mkdir(parents=True, exist_ok=True)
except OSError as e:
print(
f"Cannot create a new folder. Check the permissions to the given Checkpoint directory. \n {e}"
)

_ = checkpoint_params.pop("dirpath")
checkpointers = []
monitor = checkpoint_params.pop("monitor")
for metric in monitor:
checkpointer = pl.callbacks.ModelCheckpoint(
monitor=metric, dirpath=dirpath, **checkpoint_params
monitor=metric,
dirpath=dirpath,
filename=f"{{epoch}}-{{{metric}}}",
**checkpoint_params,
)
checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-best-{{{metric}}}"
checkpointers.append(checkpointer)
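The checkpointing hunk does two things: it creates the checkpoint directory up front (so a missing folder fails early with a readable message rather than at save time), and it builds one `ModelCheckpoint` per monitored metric with a filename template containing that metric. A sketch of the same pattern with example metric names; the actual metrics come from the runner config.

```python
from pathlib import Path

from pytorch_lightning.callbacks import ModelCheckpoint

dirpath = Path("./models/example-run").resolve()
dirpath.mkdir(parents=True, exist_ok=True)  # create the directory if it does not exist

checkpointers = []
for metric in ["val_loss", "val_num_switches"]:  # example metric names
    ckpt = ModelCheckpoint(
        monitor=metric,
        dirpath=dirpath,
        filename=f"{{epoch}}-{{{metric}}}",  # e.g. "epoch=3-val_loss=0.12.ckpt"
    )
    # "last" checkpoints get a matching, metric-specific name.
    ckpt.CHECKPOINT_NAME_LAST = f"{{epoch}}-best-{{{metric}}}"
    checkpointers.append(ckpt)
```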
@@ -270,7 +300,7 @@ def get_trainer(
callbacks: list[pl.callbacks.Callback],
logger: pl.loggers.WandbLogger,
devices: int = 1,
accelerator: str = None
accelerator: str = None,
) -> pl.Trainer:
"""Getter for the lightning trainer.

@@ -285,18 +315,19 @@
A lightning Trainer with specified params
"""
if "accelerator" not in self.cfg.trainer:
self.set_hparams({'trainer.accelerator': accelerator})
self.set_hparams({"trainer.accelerator": accelerator})
if "devices" not in self.cfg.trainer:
self.set_hparams({'trainer.devices': devices})
self.set_hparams({"trainer.devices": devices})

trainer_params = self.cfg.trainer

if "profiler" in trainer_params:
profiler = pl.profilers.AdvancedProfiler(filename="profile.txt")
trainer_params.pop("profiler")
else:
profiler = None
return pl.Trainer(
callbacks=callbacks,
logger=logger,
profiler=profiler,
**trainer_params,
)

def get_ckpt_path(self):
"""Get model ckpt path for loading."""
return self.cfg.model.ckpt_path
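In `get_trainer`, the profiler key in the trainer config acts as a switch: if present it is popped, and an `AdvancedProfiler` is constructed in code so the report lands in `profile.txt`; the remaining keys are forwarded to `pl.Trainer` unchanged. A hedged sketch of that flow with a made-up trainer config dict.

```python
import pytorch_lightning as pl
from pytorch_lightning.profilers import AdvancedProfiler

# Hypothetical trainer section of the config; only "profiler" is treated specially.
trainer_params = {"max_epochs": 1, "limit_train_batches": 2, "profiler": True}

if "profiler" in trainer_params:
    trainer_params.pop("profiler")
    profiler = AdvancedProfiler(filename="profile.txt")
else:
    profiler = None

trainer = pl.Trainer(profiler=profiler, **trainer_params)
```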