Add osx-arm64 env (#13)
Co-authored-by: sheridana <asheridan@salk.edu>
Co-authored-by: Aaditya Prasad <78439225+aaprasad@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
4 people authored Apr 22, 2024
1 parent e30a6b5 commit 16add88
Showing 43 changed files with 3,341 additions and 1,006 deletions.
15 changes: 10 additions & 5 deletions .github/workflows/ci.yml
@@ -13,7 +13,6 @@ on:
push:
branches:
- main
- talmo/updated-ci
paths:
- "biogtr/**"
- "tests/**"
@@ -76,19 +75,25 @@ jobs:
uses: actions/checkout@v3

- name: Setup Micromamba
# https://github.com/mamba-org/provision-with-micromamba
uses: mamba-org/provision-with-micromamba@main
# https://github.com/mamba-org/setup-micromamba
uses: mamba-org/setup-micromamba@v1
with:
environment-file: environment_cpu.yml
cache-env: true
cache-environment: true
cache-environment-key: environment-${{ hashFiles('environment_cpu.yml') }}-${{ hashFiles('pyproject.toml') }}
init-shell: >-
bash
powershell
post-cleanup: all

- name: Print environment info
shell: bash -l {0}
run: |
which python
micromamba info
micromamba list
pip freeze
- name: Test with pytest
if: ${{ !(startsWith(matrix.os, 'ubuntu') && matrix.python == 3.9) }}
1 change: 1 addition & 0 deletions .gitignore
@@ -76,6 +76,7 @@ target/

# Jupyter Notebook
.ipynb_checkpoints
notebooks/

# IPython
profile_default/
89 changes: 59 additions & 30 deletions biogtr/config.py
@@ -1,15 +1,17 @@
# to implement - config class that handles getters/setters
"""Data structures for handling config parsing."""

from biogtr.datasets.microscopy_dataset import MicroscopyDataset
from biogtr.datasets.sleap_dataset import SleapDataset
from biogtr.models.model_utils import init_optimizer, init_scheduler, init_logger
from biogtr.datasets.cell_tracking_dataset import CellTrackingDataset
from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer
from biogtr.models.gtr_runner import GTRRunner
from biogtr.models.model_utils import init_optimizer, init_scheduler, init_logger
from biogtr.training.losses import AssoLoss
from omegaconf import DictConfig, OmegaConf
from pprint import pprint
from typing import Union, Iterable
from pathlib import Path
import pytorch_lightning as pl
import torch

@@ -43,7 +45,7 @@ def __repr__(self):
return f"Config({self.cfg})"

def __str__(self):
"""String representation of config class."""
"""Return a string representation of config class."""
return f"Config({self.cfg})"

def set_hparams(self, hparams: dict) -> bool:
@@ -92,20 +94,33 @@ def get_tracker_cfg(self) -> dict:

def get_gtr_runner(self):
"""Get lightning module for training, validation, and inference."""
model_params = self.cfg.model
tracker_params = self.cfg.tracker
optimizer_params = self.cfg.optimizer
scheduler_params = self.cfg.scheduler
loss_params = self.cfg.loss
gtr_runner_params = self.cfg.runner
return GTRRunner(
model_params,
tracker_params,
loss_params,
optimizer_params,
scheduler_params,
**gtr_runner_params,
)

if self.cfg.model.ckpt_path is not None and self.cfg.model.ckpt_path != "":
model = GTRRunner.load_from_checkpoint(
self.cfg.model.ckpt_path,
tracker_cfg=tracker_params,
train_metrics=self.cfg.runner.metrics.train,
val_metrics=self.cfg.runner.metrics.val,
test_metrics=self.cfg.runner.metrics.test,
)

else:
model_params = self.cfg.model
model = GTRRunner(
model_params,
tracker_params,
loss_params,
optimizer_params,
scheduler_params,
**gtr_runner_params,
)

return model

def get_dataset(
self, mode: str
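For context, a minimal sketch of how the new checkpoint-aware get_gtr_runner can be exercised. The config file path, checkpoint path, and any keys other than model.ckpt_path, tracker, and runner.metrics are illustrative assumptions, not part of this commit.

# Usage sketch (not part of this commit); paths and most config keys are hypothetical.
from omegaconf import OmegaConf
from biogtr.config import Config

cfg = OmegaConf.load("configs/base.yaml")  # hypothetical config file

# Train from scratch: with an empty ckpt_path, GTRRunner is built from cfg.model.
cfg.model.ckpt_path = ""
runner = Config(cfg).get_gtr_runner()

# Resume or fine-tune: a non-empty ckpt_path routes through
# GTRRunner.load_from_checkpoint, overriding the tracker config and the
# per-split metrics taken from cfg.runner.metrics.
cfg.model.ckpt_path = "models/example.ckpt"  # hypothetical checkpoint
runner = Config(cfg).get_gtr_runner()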
@@ -174,17 +189,11 @@ def get_dataloader(
torch.multiprocessing.set_sharing_strategy("file_system")
else:
pin_memory = False
if dataloader_params.shuffle:
generator = (
torch.Generator(device="cuda") if torch.cuda.is_available() else None
)
else:
generator = None

return torch.utils.data.DataLoader(
dataset=dataset,
batch_size=1,
pin_memory=pin_memory,
generator=generator,
collate_fn=dataset.no_batching_fn,
**dataloader_params,
)
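The explicit CUDA torch.Generator previously created for shuffling is dropped here; when shuffle=True and no generator is passed, DataLoader falls back to a default CPU generator, so none is needed. A standalone sketch with a toy dataset and an illustrative params dict:

# Standalone sketch: shuffling works without an explicit generator argument.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10).unsqueeze(1))
dataloader_params = {"shuffle": True, "num_workers": 0}  # stand-in for cfg.dataloader

loader = DataLoader(dataset, batch_size=1, **dataloader_params)
print([batch[0].item() for batch in loader])  # dataset indices in shuffled order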
@@ -231,8 +240,10 @@ def get_logger(self):
Returns:
A Logger with specified params
"""
logger_params = self.cfg.logging
return init_logger(logger_params)
logger_params = OmegaConf.to_container(self.cfg.logging, resolve=True)
return init_logger(
logger_params, OmegaConf.to_container(self.cfg, resolve=True)
)
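get_logger now hands init_logger plain dictionaries rather than DictConfig nodes. A minimal sketch of what OmegaConf.to_container(..., resolve=True) does; the keys shown are made up:

# Standalone sketch: to_container converts a DictConfig into a plain dict and
# resolves ${...} interpolations, which is what most logger backends expect.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {"model": {"name": "gtr"}, "logging": {"project": "tracking", "name": "${model.name}-run"}}
)
logger_params = OmegaConf.to_container(cfg.logging, resolve=True)
print(type(logger_params))  # <class 'dict'>
print(logger_params)        # {'project': 'tracking', 'name': 'gtr-run'}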

def get_early_stopping(self) -> pl.callbacks.EarlyStopping:
"""Getter for lightning early stopping callback.
@@ -260,12 +271,25 @@ def get_checkpointing(self) -> pl.callbacks.ModelCheckpoint:

else:
dirpath = checkpoint_params["dirpath"]

dirpath = Path(dirpath).resolve()
if not Path(dirpath).exists():
try:
Path(dirpath).mkdir(parents=True, exist_ok=True)
except OSError as e:
print(
f"Cannot create a new folder. Check the permissions to the given Checkpoint directory. \n {e}"
)

_ = checkpoint_params.pop("dirpath")
checkpointers = []
monitor = checkpoint_params.pop("monitor")
for metric in monitor:
checkpointer = pl.callbacks.ModelCheckpoint(
monitor=metric, dirpath=dirpath, **checkpoint_params
monitor=metric,
dirpath=dirpath,
filename=f"{{epoch}}-{{{metric}}}",
**checkpoint_params,
)
checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-best-{{{metric}}}"
checkpointers.append(checkpointer)
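Each monitored metric now gets its own ModelCheckpoint with the metric name baked into the filename template; Lightning expands a template like "{epoch}-{val_loss}" into filenames such as epoch=4-val_loss=0.123.ckpt. A minimal sketch with made-up metric names, directory, and remaining params:

# Standalone sketch: one checkpointer per monitored metric.
import pytorch_lightning as pl

dirpath = "./models/example-run"        # hypothetical checkpoint directory
monitor = ["val_loss", "train_loss"]    # hypothetical metric names
checkpoint_params = {"save_top_k": 1, "mode": "min"}  # stand-in for remaining cfg keys

checkpointers = []
for metric in monitor:
    checkpointer = pl.callbacks.ModelCheckpoint(
        monitor=metric,
        dirpath=dirpath,
        filename=f"{{epoch}}-{{{metric}}}",  # expands to e.g. "epoch=4-val_loss=0.123"
        **checkpoint_params,
    )
    # Name used for the checkpoint written when save_last=True.
    checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-best-{{{metric}}}"
    checkpointers.append(checkpointer)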
@@ -275,30 +299,35 @@ def get_trainer(
self,
callbacks: list[pl.callbacks.Callback],
logger: pl.loggers.WandbLogger,
accelerator: str,
devices: int,
devices: int = 1,
accelerator: str = None,
) -> pl.Trainer:
"""Getter for the lightning trainer.
Args:
callbacks: a list of lightning callbacks preconfigured to be used
for training
logger: the Wandb logger used for logging during training
accelerator: either "gpu" or "cpu" specifies which device to use
devices: The number of gpus to be used. 0 means cpu
accelerator: either "gpu" or "cpu" specifies which device to use
Returns:
A lightning Trainer with specified params
"""
if "accelerator" not in self.cfg.trainer:
self.set_hparams({"trainer.accelerator": accelerator})
if "devices" not in self.cfg.trainer:
self.set_hparams({"trainer.devices": devices})

trainer_params = self.cfg.trainer
if "profiler" in trainer_params:
profiler = pl.profilers.AdvancedProfiler(filename="profile.txt")
trainer_params.pop("profiler")
else:
profiler = None
return pl.Trainer(
callbacks=callbacks,
logger=logger,
accelerator=accelerator,
devices=devices,
profiler=profiler,
**trainer_params,
)

def get_ckpt_path(self):
"""Get model ckpt path for loading."""
return self.cfg.model.ckpt_path
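get_trainer now defaults to devices=1 and accelerator=None and only injects those values when cfg.trainer does not already define them, so settings in the config win over call-site arguments; a "profiler" key in cfg.trainer is swapped for an AdvancedProfiler instance. A minimal end-to-end sketch of the intended call order; the config path is an assumption, and get_checkpointing is assumed to return the list of checkpointers built above:

# Usage sketch (not part of this commit).
from omegaconf import OmegaConf
from biogtr.config import Config

config = Config(OmegaConf.load("configs/base.yaml"))  # hypothetical config file

callbacks = [config.get_early_stopping(), *config.get_checkpointing()]
logger = config.get_logger()

# Values already present under cfg.trainer take precedence; the keyword
# arguments below only fill in "devices"/"accelerator" when they are missing.
trainer = config.get_trainer(callbacks, logger, devices=1, accelerator="cpu")
# The training script would then call trainer.fit(config.get_gtr_runner(), ...).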
(Diffs for the remaining 40 changed files not shown.)
