From 38c4cb3c192c1b5dfe4568d33f51a4548564c05e Mon Sep 17 00:00:00 2001
From: Aaditya Prasad <78439225+aaprasad@users.noreply.github.com>
Date: Thu, 17 Oct 2024 10:51:20 -0400
Subject: [PATCH] Refactor config and patch small typo (#87)

---
 docs/configs/index.md       |   2 +-
 docs/configs/training.md    |   2 +-
 dreem/inference/track.py    |  19 +++-
 dreem/inference/tracker.py  |   2 +-
 dreem/io/config.py          | 202 ++++++++++++++++++++++--------------
 dreem/models/gtr_runner.py  |   1 -
 dreem/models/model_utils.py |  11 +-
 dreem/training/train.py     |   5 +-
 environment_osx-arm64.yml   |   7 ++
 tests/test_config.py        |  32 +++++-
 10 files changed, 197 insertions(+), 86 deletions(-)

diff --git a/docs/configs/index.md b/docs/configs/index.md
index 65f105f..fd676b1 100644
--- a/docs/configs/index.md
+++ b/docs/configs/index.md
@@ -1,3 +1,3 @@
 # DREEM Config API
 
-We utilize `.yaml` based configs with `hydra` and `omegaconf` for config parsing.
\ No newline at end of file
+We utilize `.yaml` based configs with [`hydra`](https://hydra.cc) and [`omegaconf`](https://omegaconf.readthedocs.io/en/2.3_branch/) for config parsing.
\ No newline at end of file
diff --git a/docs/configs/training.md b/docs/configs/training.md
index b5ec3b1..9806e91 100644
--- a/docs/configs/training.md
+++ b/docs/configs/training.md
@@ -2,7 +2,7 @@
 
 Here, we describe the hyperparameters used for setting up training. Please see [here](./training.md#example-config) for an example training config.
 
-> Note: for using defaults, simply leave the field blank or don't include the key. Using `null` will initialize the value to `None` e.g
+> Note: for using defaults, simply leave the field blank or don't include the key. Using `null` will initialize the value to `None` which we use to represent turning off certain features such as logging, early stopping etc. e.g
 > ```YAML
 > model:
 >   d_model: #defaults to 1024
diff --git a/dreem/inference/track.py b/dreem/inference/track.py
index 2c2eac4..b10ebc9 100644
--- a/dreem/inference/track.py
+++ b/dreem/inference/track.py
@@ -5,6 +5,7 @@
 from dreem.models import GTRRunner
 from omegaconf import DictConfig
 from pathlib import Path
+from datetime import datetime
 
 import hydra
 import os
@@ -14,9 +15,21 @@
 import sleap_io as sio
 import logging
 
+
 logger = logging.getLogger("dreem.inference")
 
 
+def get_timestamp() -> str:
+    """Get current timestamp.
+
+    Returns:
+        the current timestamp in /m/d/y-H:M:S format
+    """
+    date_time = datetime.now().strftime("%m-%d-%Y-%H:%M:%S")
+    print(date_time)
+    return date_time
+
+
 def export_trajectories(
     frames_pred: list["dreem.io.Frame"], save_path: str | None = None
 ) -> pd.DataFrame:
@@ -129,7 +142,11 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
             )
             dataloader = pred_cfg.get_dataloader(dataset, mode="test")
             preds = track(model, trainer, dataloader)
-            outpath = os.path.join(outdir, f"{Path(label_file).stem}.dreem_inference.slp")
+            outpath = os.path.join(
+                outdir, f"{Path(label_file).stem}.dreem_inference.{get_timestamp()}.slp"
+            )
+            if os.path.exists(outpath):
+                outpath.replace(".slp", ".")
             preds.save(outpath)
 
     return preds
diff --git a/dreem/inference/tracker.py b/dreem/inference/tracker.py
index f7c29b4..9f2e713 100644
--- a/dreem/inference/tracker.py
+++ b/dreem/inference/tracker.py
@@ -194,7 +194,7 @@ def sliding_inference(
 
         for i, instance in enumerate(frames[batch_idx].instances):
             if instance.pred_track_id == -1:
-                curr_track += 1
+                curr_track_id += 1
                 instance.pred_track_id = curr_track_id
 
             else:
diff --git a/dreem/io/config.py b/dreem/io/config.py
index b018790..53fd733 100644
--- a/dreem/io/config.py
+++ b/dreem/io/config.py
@@ -57,7 +57,7 @@ def from_yaml(cls, base_cfg_path: str, params_cfg_path: str | None = None) -> No
             params_cfg_path: path to override params.
         """
         base_cfg = OmegaConf.load(base_cfg_path)
-        params_cfg = OmegaConf.load(params_cfg_path) if params_cfg else None
+        params_cfg = OmegaConf.load(params_cfg_path) if params_cfg_path else None
         return cls(base_cfg, params_cfg)
 
     def set_hparams(self, hparams: dict) -> bool:
@@ -83,17 +83,35 @@ def set_hparams(self, hparams: dict) -> bool:
                 return False
         return True
 
+    def get(self, key: str, default=None, cfg: dict = None):
+        """Get config item.
+
+        Args:
+            key: key of item to return
+            default: default value to return if key is missing.
+            cfg: the config dict from which to retrieve an item
+        """
+        if cfg is None:
+            cfg = self.cfg
+
+        param = cfg.get(key, default)
+
+        if isinstance(param, DictConfig):
+            param = OmegaConf.to_container(param, resolve=True)
+
+        return param
+
     def get_model(self) -> "GlobalTrackingTransformer":
         """Getter for gtr model.
 
         Returns:
             A global tracking transformer with parameters indicated by cfg
         """
-        from dreem.models import GlobalTrackingTransformer
+        from dreem.models import GlobalTrackingTransformer, GTRRunner
 
-        model_params = self.cfg.model
-        with open_dict(model_params):
-            ckpt_path = model_params.pop("ckpt_path", None)
+        model_params = self.get("model", {})
+
+        ckpt_path = model_params.pop("ckpt_path", None)
 
         if ckpt_path is not None and len(ckpt_path) > 0:
             return GTRRunner.load_from_checkpoint(ckpt_path).model
@@ -106,11 +124,7 @@ def get_tracker_cfg(self) -> dict:
         Returns:
             A dict containing the init params for `Tracker`.
         """
-        tracker_params = self.cfg.tracker
-        tracker_cfg = {}
-        for key, val in tracker_params.items():
-            tracker_cfg[key] = val
-        return tracker_cfg
+        return self.get("tracker", {})
 
     def get_gtr_runner(self, ckpt_path: str | None = None) -> "GTRRunner":
         """Get lightning module for training, validation, and inference.
@@ -123,36 +137,34 @@ def get_gtr_runner(self, ckpt_path: str | None = None) -> "GTRRunner":
         """
         from dreem.models import GTRRunner
 
-        tracker_params = self.cfg.tracker
-        optimizer_params = self.cfg.optimizer
-        scheduler_params = self.cfg.scheduler
-        loss_params = self.cfg.loss
-        gtr_runner_params = self.cfg.runner
-        model_params = self.cfg.model
+        keys = ["tracker", "optimizer", "scheduler", "loss", "runner", "model"]
+        args = [key + "_cfg" if key != "runner" else key for key in keys]
+
+        params = {}
+        for key, arg in zip(keys, args):
+            sub_params = self.get(key, {})
 
-        if ckpt_path is None:
-            with open_dict(model_params):
-                ckpt_path = model_params.pop("ckpt_path", None)
+            if len(sub_params) == 0:
+                logger.warning(
+                    f"`{key}` not found in config or is empty. Using defaults for {arg}!"
+                )
+
+            if key == "runner":
+                runner_params = sub_params
+                for k, v in runner_params.items():
+                    params[k] = v
+            else:
+                params[arg] = sub_params
+
+        ckpt_path = params["model_cfg"].pop("ckpt_path", None)
 
         if ckpt_path is not None and ckpt_path != "":
             model = GTRRunner.load_from_checkpoint(
-                ckpt_path,
-                tracker_cfg=tracker_params,
-                train_metrics=self.cfg.runner.metrics.train,
-                val_metrics=self.cfg.runner.metrics.val,
-                test_metrics=self.cfg.runner.metrics.test,
-                test_save_path=self.cfg.runner.save_path,
+                ckpt_path, tracker_cfg=params["tracker_cfg"], **runner_params
             )
         else:
-            model = GTRRunner(
-                model_params,
-                tracker_params,
-                loss_params,
-                optimizer_params,
-                scheduler_params,
-                **gtr_runner_params,
-            )
+            model = GTRRunner(**params)
 
         return model
 
@@ -165,9 +177,10 @@ def get_data_paths(self, data_cfg: dict) -> tuple[list[str], list[str]]:
         Returns:
             lists of labels file paths and video file paths respectively
         """
-        with open_dict(data_cfg):
-            dir_cfg = data_cfg.pop("dir", None)
+        dir_cfg = data_cfg.pop("dir", None)
+        label_files = vid_files = None
+
         if dir_cfg:
             labels_suff = dir_cfg.labels_suffix
             vid_suff = dir_cfg.vid_suffix
@@ -181,14 +194,14 @@ def get_data_paths(self, data_cfg: dict) -> tuple[list[str], list[str]]:
 
         else:
             if "slp_files" in data_cfg:
-                label_files = data_cfg.slp_files
-                vid_files = data_cfg.video_files
+                label_files = data_cfg["slp_files"]
+                vid_files = data_cfg["video_files"]
             elif "tracks" in data_cfg or "source" in data_cfg:
-                label_files = data_cfg.tracks
-                vid_files = data_cfg.videos
+                label_files = data_cfg["tracks"]
+                vid_files = data_cfg["videos"]
             elif "raw_images" in data_cfg:
-                label_files = data_cfg.gt_images
-                vid_files = data_cfg.raw_images
+                label_files = data_cfg["gt_images"]
+                vid_files = data_cfg["raw_images"]
 
         return label_files, vid_files
 
@@ -211,39 +224,42 @@ def get_dataset(
         """
         from dreem.datasets import MicroscopyDataset, SleapDataset, CellTrackingDataset
 
+        dataset_params = self.get("dataset")
+        if dataset_params is None:
+            raise KeyError("`dataset` key is missing from cfg!")
+
         if mode.lower() == "train":
-            dataset_params = self.cfg.dataset.train_dataset
+            dataset_params = self.get("train_dataset", {}, dataset_params)
         elif mode.lower() == "val":
-            dataset_params = self.cfg.dataset.val_dataset
+            dataset_params = self.get("val_dataset", {}, dataset_params)
        elif mode.lower() == "test":
-            dataset_params = self.cfg.dataset.test_dataset
+            dataset_params = self.get("test_dataset", {}, dataset_params)
         else:
             raise ValueError(
                 "`mode` must be one of ['train', 'val','test'], not '{mode}'"
             )
 
         if label_files is None or vid_files is None:
-            with open_dict(dataset_params):
-                label_files, vid_files = self.get_data_paths(dataset_params)
+            label_files, vid_files = self.get_data_paths(dataset_params)
 
         # todo: handle this better
         if "slp_files" in dataset_params:
             if label_files is not None:
-                dataset_params.slp_files = label_files
+                dataset_params["slp_files"] = label_files
             if vid_files is not None:
-                dataset_params.video_files = vid_files
+                dataset_params["video_files"] = vid_files
             return SleapDataset(**dataset_params)
 
         elif "tracks" in dataset_params or "source" in dataset_params:
             if label_files is not None:
-                dataset_params.tracks = label_files
+                dataset_params["tracks"] = label_files
             if vid_files is not None:
-                dataset_params.videos = vid_files
+                dataset_params["videos"] = vid_files
             return MicroscopyDataset(**dataset_params)
 
         elif "raw_images" in dataset_params:
             if label_files is not None:
-                dataset_params.gt_images = label_files
+                dataset_params["gt_images"] = label_files
             if vid_files is not None:
-                dataset_params.raw_images = vid_files
+                dataset_params["raw_images"] = vid_files
             return CellTrackingDataset(**dataset_params)
 
         else:
@@ -267,17 +283,18 @@ def get_dataloader(
         Returns:
             A torch dataloader for `dataset` with parameters configured as specified
         """
+        dataloader_params = self.get("dataloader", {})
         if mode.lower() == "train":
-            dataloader_params = self.cfg.dataloader.train_dataloader
+            dataloader_params = self.get("train_dataloader", {}, dataloader_params)
         elif mode.lower() == "val":
-            dataloader_params = self.cfg.dataloader.val_dataloader
+            dataloader_params = self.get("val_dataloader", {}, dataloader_params)
         elif mode.lower() == "test":
-            dataloader_params = self.cfg.dataloader.test_dataloader
+            dataloader_params = self.get("test_dataloader", {}, dataloader_params)
         else:
             raise ValueError(
                 "`mode` must be one of ['train', 'val','test'], not '{mode}'"
             )
-        if dataloader_params.num_workers > 0:
+        if dataloader_params.get("num_workers", 0) > 0:
             # prevent too many open files error
             pin_memory = True
             torch.multiprocessing.set_sharing_strategy("file_system")
@@ -304,13 +321,13 @@ def get_optimizer(self, params: Iterable) -> torch.optim.Optimizer:
         """
         from dreem.models.model_utils import init_optimizer
 
-        optimizer_params = self.cfg.optimizer
+        optimizer_params = self.get("optimizer")
 
         return init_optimizer(params, optimizer_params)
 
     def get_scheduler(
         self, optimizer: torch.optim.Optimizer
-    ) -> torch.optim.lr_scheduler.LRScheduler:
+    ) -> torch.optim.lr_scheduler.LRScheduler | None:
         """Getter for lr scheduler.
 
         Args:
@@ -321,8 +338,13 @@ def get_scheduler(
         """
         from dreem.models.model_utils import init_scheduler
 
-        lr_scheduler_params = self.cfg.scheduler
+        lr_scheduler_params = self.get("scheduler")
 
+        if lr_scheduler_params is None:
+            logger.warning(
+                "`scheduler` key not found in cfg or is empty. No scheduler will be returned!"
+            )
+            return None
         return init_scheduler(optimizer, lr_scheduler_params)
 
     def get_loss(self) -> "dreem.training.losses.AssoLoss":
@@ -333,7 +355,12 @@ def get_loss(self) -> "dreem.training.losses.AssoLoss":
         """
         from dreem.training.losses import AssoLoss
 
-        loss_params = self.cfg.loss
+        loss_params = self.get("loss", {})
+
+        if len(loss_params) == 0:
+            logger.warning(
+                "`loss` key not found in cfg. Using default params for `AssoLoss`"
+            )
 
         return AssoLoss(**loss_params)
 
@@ -345,7 +372,11 @@ def get_logger(self) -> pl.loggers.Logger:
         """
         from dreem.models.model_utils import init_logger
 
-        logger_params = OmegaConf.to_container(self.cfg.logging, resolve=True)
+        logger_params = self.get("logging", {})
+        if len(logger_params) == 0:
+            logger.warning(
+                "`logging` key not found in cfg. No logger will be configured!"
+            )
 
         return init_logger(
             logger_params, OmegaConf.to_container(self.cfg, resolve=True)
         )
@@ -357,7 +388,15 @@ def get_early_stopping(self) -> pl.callbacks.EarlyStopping:
         Returns:
             A lightning early stopping callback with specified params
         """
-        early_stopping_params = self.cfg.early_stopping
+        early_stopping_params = self.get("early_stopping", None)
+
+        if early_stopping_params is None:
+            logger.warning(
+                "`early_stopping` was not found in cfg or was `null`. Early stopping will not be used!"
+            )
+            return None
+        elif len(early_stopping_params) == 0:
+            logger.warning("`early_stopping` cfg is empty! Using defaults")
         return pl.callbacks.EarlyStopping(**early_stopping_params)
 
     def get_checkpointing(self) -> pl.callbacks.ModelCheckpoint:
@@ -367,9 +406,11 @@ def get_checkpointing(self) -> pl.callbacks.ModelCheckpoint:
             A lightning checkpointing callback with specified params
         """
         # convert to dict to enable extracting/removing params
-        checkpoint_params = self.cfg.checkpointing
-        logging_params = self.cfg.logging
+        checkpoint_params = self.get("checkpointing", {})
+        logging_params = self.get("logging", {})
+
         dirpath = checkpoint_params.pop("dirpath", None)
+
         if dirpath is None:
             if "group" in logging_params:
                 dirpath = f"./models/{logging_params.group}/{logging_params.name}"
@@ -382,13 +423,22 @@ def get_checkpointing(self) -> pl.callbacks.ModelCheckpoint:
                 Path(dirpath).mkdir(parents=True, exist_ok=True)
             except OSError as e:
                 logger.exception(
-                    f"Cannot create a new folder. Check the permissions to the given Checkpoint directory. \n {e}"
+                    f"Cannot create a new folder!. Check the permissions to {dirpath}. \n {e}"
                 )
-        with open_dict(checkpoint_params):
-            _ = checkpoint_params.pop("dirpath", None)
-            monitor = checkpoint_params.pop("monitor", ["val_loss"])
+
+        _ = checkpoint_params.pop("dirpath", None)
+        monitor = checkpoint_params.pop("monitor", ["val_loss"])
         checkpointers = []
+        logger.info(
+            f"Saving checkpoints to `{dirpath}` based on the following metrics: {monitor}"
+        )
+        if len(checkpoint_params) == 0:
+            logger.warning(
+                """`checkpointing` key was not found in cfg or was empty!
+                Configuring checkpointing to use default params!"""
+            )
+
         for metric in monitor:
             checkpointer = pl.callbacks.ModelCheckpoint(
                 monitor=metric,
@@ -419,22 +469,20 @@ def get_trainer(
         Returns:
             A lightning Trainer with specified params
         """
-        if "trainer" in self.cfg:
-            trainer_params = OmegaConf.to_container(self.cfg.trainer, resolve=True)
-
-        else:
-            trainer_params = {}
-
+        trainer_params = self.get("trainer", {})
         profiler = trainer_params.pop("profiler", None)
+        if len(trainer_params) == 0:
+            print(
+                "`trainer` key was not found in cfg or was empty. Using defaults for `pl.Trainer`!"
+            )
+
         if "accelerator" not in trainer_params:
             trainer_params["accelerator"] = accelerator
         if "devices" not in trainer_params:
             trainer_params["devices"] = devices
 
-        if "profiler":
+        if profiler:
             profiler = pl.profilers.AdvancedProfiler(filename="profile.txt")
-        else:
-            profiler = None
 
         return pl.Trainer(
             callbacks=callbacks,
diff --git a/dreem/models/gtr_runner.py b/dreem/models/gtr_runner.py
index 01c2039..2a09ec3 100644
--- a/dreem/models/gtr_runner.py
+++ b/dreem/models/gtr_runner.py
@@ -65,7 +65,6 @@ def __init__(
         self.loss_cfg = loss_cfg if loss_cfg else {}
         self.tracker_cfg = tracker_cfg if tracker_cfg else {}
 
-        _ = self.model_cfg.pop("ckpt_path", None)
         self.model = GlobalTrackingTransformer(**self.model_cfg)
         self.loss = AssoLoss(**self.loss_cfg)
         self.tracker = Tracker(**self.tracker_cfg)
diff --git a/dreem/models/model_utils.py b/dreem/models/model_utils.py
index fa5c773..b886d4f 100644
--- a/dreem/models/model_utils.py
+++ b/dreem/models/model_utils.py
@@ -106,7 +106,9 @@ def init_optimizer(params: Iterable, config: dict) -> torch.optim.Optimizer:
     Returns:
         optimizer: A torch.Optimizer with specified params
     """
-    optimizer = config["name"]
+    if config is None:
+        config = {"name": "Adam"}
+    optimizer = config.get("name", "Adam")
     optimizer_params = {
         param: val for param, val in config.items() if param.lower() != "name"
     }
@@ -145,7 +147,12 @@ def init_scheduler(
     Returns:
         scheduler: A scheduler with specified params
     """
-    scheduler = config["name"]
+    if config is None:
+        return None
+    scheduler = config.get("name")
+    if scheduler is None:
+        scheduler = "ReduceLROnPlateau"
+
     scheduler_params = {
         param: val for param, val in config.items() if param.lower() != "name"
     }
diff --git a/dreem/training/train.py b/dreem/training/train.py
index 372bfa6..254a714 100644
--- a/dreem/training/train.py
+++ b/dreem/training/train.py
@@ -80,7 +80,10 @@ def run(cfg: DictConfig):
     callbacks = []
     _ = callbacks.extend(train_cfg.get_checkpointing())
     _ = callbacks.append(pl.callbacks.LearningRateMonitor())
-    _ = callbacks.append(train_cfg.get_early_stopping())
+
+    early_stopping = train_cfg.get_early_stopping()
+    if early_stopping is not None:
+        callbacks.append(early_stopping)
 
     accelerator = "gpu" if torch.cuda.is_available() else "cpu"
     devices = torch.cuda.device_count() if torch.cuda.is_available() else cpu_count()
diff --git a/environment_osx-arm64.yml b/environment_osx-arm64.yml
index 31ae80f..9917b13 100644
--- a/environment_osx-arm64.yml
+++ b/environment_osx-arm64.yml
@@ -17,5 +17,12 @@ dependencies:
   - matplotlib
   - pip
   - pip:
+      - matplotlib
+      - sleap-io
       - "--editable=.[dev]"
+      - imageio[ffmpeg]
+      - hydra-core
+      - motmetrics
+      - seaborn
+      - wandb
       - timm
\ No newline at end of file
diff --git a/tests/test_config.py b/tests/test_config.py
index 0b1c826..74f8b36 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -1,6 +1,7 @@
 """Tests for `config.py`"""
 
-from omegaconf import OmegaConf
+from omegaconf import OmegaConf, open_dict
+from copy import deepcopy
 
 from dreem.io import Config
 from dreem.models import GlobalTrackingTransformer, GTRRunner
@@ -96,3 +97,32 @@ def test_getters(base_config):
 
     scheduler = cfg.get_scheduler(optim)
     assert isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
+
+
+def test_missing(base_config):
+    """Test cases when keys are missing from config for expected behavior.
+
+    Args:
+        base_config: the config params to override
+    """
+    cfg = Config.from_yaml(base_config)
+
+    key = "model"
+    with open_dict(cfg.cfg):
+        cfg.cfg.pop(key)
+    assert isinstance(cfg.get_model(), GlobalTrackingTransformer)
+
+    cfg = Config.from_yaml(base_config)
+    key = "tracker"
+    with open_dict(cfg.cfg):
+        cfg.cfg.pop(key)
+    assert (
+        isinstance(cfg.get_tracker_cfg(), dict) and len(cfg.get_tracker_cfg()) == 0
+    )
+
+    cfg = Config.from_yaml(base_config)
+    keys = ["tracker", "optimizer", "scheduler", "loss", "runner", "model"]
+    with open_dict(cfg.cfg):
+        for key in keys:
+            cfg.cfg.pop(key)
+    assert isinstance(cfg.get_gtr_runner(), GTRRunner)
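
The patch's central change is the new `Config.get` accessor in `dreem/io/config.py`, which every getter now routes through instead of attribute access on `self.cfg`. The standalone sketch below (not part of the patch) mirrors that fallback behavior so it can be run without the `dreem` package; the `cfg_get` name and the example keys and values are illustrative only.

    from omegaconf import DictConfig, OmegaConf

    def cfg_get(cfg: DictConfig, key: str, default=None):
        """Fetch `key` from an omegaconf config, falling back to `default`,
        and convert nested DictConfig nodes into plain Python containers."""
        param = cfg.get(key, default)
        if isinstance(param, DictConfig):
            param = OmegaConf.to_container(param, resolve=True)
        return param

    cfg = OmegaConf.create({"tracker": {"window_size": 8}, "scheduler": None})
    print(cfg_get(cfg, "tracker"))    # {'window_size': 8} -- a plain dict, not a DictConfig
    print(cfg_get(cfg, "loss", {}))   # {} -- missing key falls back to the default
    print(cfg_get(cfg, "scheduler"))  # None -- `null` in YAML turns the feature off

Returning plain containers is what lets the refactored getters `.pop(...)` keys and `**`-unpack sub-configs without tripping over omegaconf's struct-mode restrictions, and it is why the missing-key paths exercised in `test_missing` can simply expect `{}` or `None`.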
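
The inference change in `dreem/inference/track.py` stamps each output file with `get_timestamp()` so repeated runs do not overwrite earlier results. The sketch below (again not part of the patch, with made-up `label_file` and `outdir` values) shows how the path is assembled; note that Python's `str.replace` returns a new string, so the `outpath.replace(".slp", ".")` call added in the hunk above leaves `outpath` unchanged -- a collision fallback would have to rebind the name, as in the hypothetical `.copy.slp` suffix used here.

    import os
    from datetime import datetime
    from pathlib import Path

    def get_timestamp() -> str:
        """Current time in the same %m-%d-%Y-%H:%M:%S format the patch uses."""
        return datetime.now().strftime("%m-%d-%Y-%H:%M:%S")

    label_file = "data/session1.slp"  # hypothetical input labels file
    outdir = "./results"

    outpath = os.path.join(
        outdir, f"{Path(label_file).stem}.dreem_inference.{get_timestamp()}.slp"
    )
    if os.path.exists(outpath):
        # str.replace is not in-place; rebind the name to actually change the path
        outpath = outpath.replace(".slp", ".copy.slp")
    print(outpath)  # e.g. ./results/session1.dreem_inference.10-17-2024-10:51:20.slp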
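
Finally, `init_optimizer` and `init_scheduler` in `dreem/models/model_utils.py` now tolerate missing configs: no optimizer config falls back to Adam, and no scheduler config means no scheduler at all (which is why `Config.get_scheduler` may now return `None`). The helper below is a simplified stand-in for that pattern, not the library function itself.

    import torch

    def build_optimizer(params, config: dict | None) -> torch.optim.Optimizer:
        """Simplified mirror of the fallback: a None or nameless config defaults to Adam."""
        if config is None:
            config = {"name": "Adam"}
        name = config.get("name", "Adam")
        kwargs = {k: v for k, v in config.items() if k.lower() != "name"}
        return getattr(torch.optim, name)(params, **kwargs)

    model = torch.nn.Linear(4, 2)
    opt = build_optimizer(model.parameters(), None)                         # Adam, default settings
    opt = build_optimizer(model.parameters(), {"name": "SGD", "lr": 1e-3})  # explicit choice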