diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 30a60b49..169c9e23 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,7 +5,7 @@ on: pull_request: types: [opened, reopened, synchronize] paths: - - "biogtr/**" + - "dreem/**" - "tests/**" - ".github/workflows/ci.yml" - "environment_cpu.yml" @@ -14,7 +14,7 @@ on: branches: - main paths: - - "biogtr/**" + - "dreem/**" - "tests/**" - ".github/workflows/ci.yml" - "environment_cpu.yml" @@ -53,11 +53,11 @@ jobs: - name: Run Black run: | - black --check biogtr tests + black --check dreem tests - name: Run pydocstyle run: | - pydocstyle --convention=google biogtr/ + pydocstyle --convention=google dreem/ # Tests with pytest tests: @@ -105,7 +105,7 @@ jobs: if: ${{ startsWith(matrix.os, 'ubuntu') && matrix.python == 3.9 }} shell: bash -l {0} run: | - pytest --cov=biogtr --cov-report=xml tests/ + pytest --cov=dreem --cov-report=xml tests/ - name: Upload coverage uses: codecov/codecov-action@v3 diff --git a/.gitignore b/.gitignore index e9fdce8a..b69e7d35 100644 --- a/.gitignore +++ b/.gitignore @@ -137,5 +137,5 @@ logs # vscode .vscode -biogtr/training/.hydra/* -biogtr/training/models/* +dreem/training/.hydra/* +dreem/training/models/* diff --git a/README.md b/README.md index e8aa69c5..495eedc3 100644 --- a/README.md +++ b/README.md @@ -1,22 +1,22 @@ -# BioGTR +# DREEM Reconstructs Every Entities' Motion Global Tracking Transformers for biological multi-object tracking. ## Installation ### Development 1. Clone the repository: ``` -git clone https://github.com/talmolab/biogtr && cd biogtr +git clone https://github.com/talmolab/dreem && cd dreem ``` 2. Set up in a new conda environment: ``` -conda env create -y -f environment.yml && conda activate biogtr +conda env create -y -f environment.yml && conda activate dreem ``` ### Uninstalling ``` -conda env remove -n biogtr +conda env remove -n dreem ``` \ No newline at end of file diff --git a/biogtr/__init__.py b/biogtr/__init__.py deleted file mode 100644 index e4c823ef..00000000 --- a/biogtr/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Top-level package for BioGTR.""" - -from biogtr.version import __version__ - -from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer -from biogtr.models.gtr_runner import GTRRunner -from biogtr.models.transformer import Transformer -from biogtr.models.visual_encoder import VisualEncoder - -from biogtr.io.frame import Frame -from biogtr.io.instance import Instance -from biogtr.io.association_matrix import AssociationMatrix -from biogtr.io.config import Config -from biogtr.io.visualize import annotate_video - -# from .training import run - -from biogtr.inference.tracker import Tracker diff --git a/biogtr/cli.py b/biogtr/cli.py deleted file mode 100644 index 31db5230..00000000 --- a/biogtr/cli.py +++ /dev/null @@ -1 +0,0 @@ -"""This module contains the command line interfaces for the biogtr package.""" diff --git a/biogtr/io/__init__.py b/biogtr/io/__init__.py deleted file mode 100644 index 0cda02ae..00000000 --- a/biogtr/io/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Module containing input/output data structures for easy storage and manipulation.""" - -from biogtr.io.frame import Frame -from biogtr.io.instance import Instance -from biogtr.io.association_matrix import AssociationMatrix -from biogtr.io.track import Track -from biogtr.io.config import Config diff --git a/dreem/__init__.py b/dreem/__init__.py new file mode 100644 index 00000000..5299cbe3 --- /dev/null +++ b/dreem/__init__.py @@ -0,0 +1,18 @@ +"""Top-level package for dreem.""" + +from dreem.version import __version__ + +from dreem.models.global_tracking_transformer import GlobalTrackingTransformer +from dreem.models.gtr_runner import GTRRunner +from dreem.models.transformer import Transformer +from dreem.models.visual_encoder import VisualEncoder + +from dreem.io.frame import Frame +from dreem.io.instance import Instance +from dreem.io.association_matrix import AssociationMatrix +from dreem.io.config import Config +from dreem.io.visualize import annotate_video + +# from .training import run + +from dreem.inference.tracker import Tracker diff --git a/dreem/cli.py b/dreem/cli.py new file mode 100644 index 00000000..1755475c --- /dev/null +++ b/dreem/cli.py @@ -0,0 +1 @@ +"""This module contains the command line interfaces for the dreem package.""" diff --git a/biogtr/datasets/__init__.py b/dreem/datasets/__init__.py similarity index 100% rename from biogtr/datasets/__init__.py rename to dreem/datasets/__init__.py diff --git a/biogtr/datasets/base_dataset.py b/dreem/datasets/base_dataset.py similarity index 98% rename from biogtr/datasets/base_dataset.py rename to dreem/datasets/base_dataset.py index 15b87d45..8dd6af4b 100644 --- a/biogtr/datasets/base_dataset.py +++ b/dreem/datasets/base_dataset.py @@ -1,7 +1,7 @@ """Module containing logic for loading datasets.""" -from biogtr.datasets import data_utils -from biogtr.io import Frame +from dreem.datasets import data_utils +from dreem.io import Frame from torch.utils.data import Dataset from typing import List, Union import numpy as np diff --git a/biogtr/datasets/cell_tracking_dataset.py b/dreem/datasets/cell_tracking_dataset.py similarity index 97% rename from biogtr/datasets/cell_tracking_dataset.py rename to dreem/datasets/cell_tracking_dataset.py index 9567de46..74e7182a 100644 --- a/biogtr/datasets/cell_tracking_dataset.py +++ b/dreem/datasets/cell_tracking_dataset.py @@ -1,8 +1,8 @@ """Module containing cell tracking challenge dataset.""" from PIL import Image -from biogtr.datasets import data_utils, BaseDataset -from biogtr.io import Frame, Instance +from dreem.datasets import data_utils, BaseDataset +from dreem.io import Frame, Instance from scipy.ndimage import measurements from typing import List, Optional, Union import albumentations as A @@ -122,7 +122,7 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> List[Fram Returns: a list of Frame objects containing frame metadata and Instance Objects. - See `biogtr.io.data_structures` for more info. + See `dreem.io.data_structures` for more info. """ image = self.videos[label_idx] gt = self.labels[label_idx] diff --git a/biogtr/datasets/data_utils.py b/dreem/datasets/data_utils.py similarity index 100% rename from biogtr/datasets/data_utils.py rename to dreem/datasets/data_utils.py diff --git a/biogtr/datasets/eval_dataset.py b/dreem/datasets/eval_dataset.py similarity index 98% rename from biogtr/datasets/eval_dataset.py rename to dreem/datasets/eval_dataset.py index 95836a51..c8c5c63a 100644 --- a/biogtr/datasets/eval_dataset.py +++ b/dreem/datasets/eval_dataset.py @@ -1,7 +1,7 @@ """Module containing wrapper for merging gt and pred datasets for evaluation.""" from torch.utils.data import Dataset -from biogtr.io import Instance, Frame +from dreem.io import Instance, Frame from typing import List diff --git a/biogtr/datasets/microscopy_dataset.py b/dreem/datasets/microscopy_dataset.py similarity index 98% rename from biogtr/datasets/microscopy_dataset.py rename to dreem/datasets/microscopy_dataset.py index 9656d19d..484c453d 100644 --- a/biogtr/datasets/microscopy_dataset.py +++ b/dreem/datasets/microscopy_dataset.py @@ -1,8 +1,8 @@ """Module containing microscopy dataset.""" from PIL import Image -from biogtr.datasets import data_utils, BaseDataset -from biogtr.io import Instance, Frame +from dreem.datasets import data_utils, BaseDataset +from dreem.io import Instance, Frame from typing import Union import albumentations as A import numpy as np @@ -122,7 +122,7 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram frame_idx: index of the frames Returns: - A list of Frames containing Instances to be tracked (See `biogtr.io.data_structures for more info`) + A list of Frames containing Instances to be tracked (See `dreem.io.data_structures for more info`) """ labels = self.labels[label_idx] labels = labels.dropna(how="all") diff --git a/biogtr/datasets/sleap_dataset.py b/dreem/datasets/sleap_dataset.py similarity index 99% rename from biogtr/datasets/sleap_dataset.py rename to dreem/datasets/sleap_dataset.py index b23b4e70..3538dc60 100644 --- a/biogtr/datasets/sleap_dataset.py +++ b/dreem/datasets/sleap_dataset.py @@ -7,8 +7,8 @@ import sleap_io as sio import random import warnings -from biogtr.io import Instance, Frame -from biogtr.datasets import data_utils, BaseDataset +from dreem.io import Instance, Frame +from dreem.datasets import data_utils, BaseDataset from torchvision.transforms import functional as tvf from typing import List, Union diff --git a/biogtr/datasets/tracking_dataset.py b/dreem/datasets/tracking_dataset.py similarity index 95% rename from biogtr/datasets/tracking_dataset.py rename to dreem/datasets/tracking_dataset.py index fdc54cac..960bf2d1 100644 --- a/biogtr/datasets/tracking_dataset.py +++ b/dreem/datasets/tracking_dataset.py @@ -1,8 +1,8 @@ """Module containing Lightning module wrapper around all other datasets.""" -from biogtr.datasets.cell_tracking_dataset import CellTrackingDataset -from biogtr.datasets.microscopy_dataset import MicroscopyDataset -from biogtr.datasets.sleap_dataset import SleapDataset +from dreem.datasets.cell_tracking_dataset import CellTrackingDataset +from dreem.datasets.microscopy_dataset import MicroscopyDataset +from dreem.datasets.sleap_dataset import SleapDataset from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader from typing import Union diff --git a/biogtr/inference/__init__.py b/dreem/inference/__init__.py similarity index 100% rename from biogtr/inference/__init__.py rename to dreem/inference/__init__.py diff --git a/biogtr/inference/boxes.py b/dreem/inference/boxes.py similarity index 100% rename from biogtr/inference/boxes.py rename to dreem/inference/boxes.py diff --git a/biogtr/inference/configs/inference.yaml b/dreem/inference/configs/inference.yaml similarity index 100% rename from biogtr/inference/configs/inference.yaml rename to dreem/inference/configs/inference.yaml diff --git a/biogtr/inference/metrics.py b/dreem/inference/metrics.py similarity index 96% rename from biogtr/inference/metrics.py rename to dreem/inference/metrics.py index c80c15c3..935df49b 100644 --- a/biogtr/inference/metrics.py +++ b/dreem/inference/metrics.py @@ -5,11 +5,11 @@ import torch from typing import Union, Iterable -# from biogtr.inference.post_processing import _pairwise_iou -# from biogtr.inference.boxes import Boxes +# from dreem.inference.post_processing import _pairwise_iou +# from dreem.inference.boxes import Boxes -def get_matches(frames: list["biogtr.io.Frame"]) -> tuple[dict, list, int]: +def get_matches(frames: list["dreem.io.Frame"]) -> tuple[dict, list, int]: """Get comparison between predicted and gt trajectory labels. Args: @@ -100,11 +100,11 @@ def get_switch_count(switches: dict) -> int: return sw_cnt -def to_track_eval(frames: list["biogtr.io.Frame"]) -> dict: +def to_track_eval(frames: list["dreem.io.Frame"]) -> dict: """Reformats frames the output from `sliding_inference` to be used by `TrackEval`. Args: - frames: A list of Frames. `See biogtr.io.data_structures for more info`. + frames: A list of Frames. `See dreem.io.data_structures for more info`. Returns: data: A dictionary. Example provided below. diff --git a/biogtr/inference/post_processing.py b/dreem/inference/post_processing.py similarity index 99% rename from biogtr/inference/post_processing.py rename to dreem/inference/post_processing.py index 1aaf21cc..e4db20bc 100644 --- a/biogtr/inference/post_processing.py +++ b/dreem/inference/post_processing.py @@ -1,7 +1,7 @@ """Helper functions for post-processing association matrix pre-tracking.""" import torch -from biogtr.inference.boxes import Boxes +from dreem.inference.boxes import Boxes def weight_decay_time( diff --git a/biogtr/inference/track.py b/dreem/inference/track.py similarity index 95% rename from biogtr/inference/track.py rename to dreem/inference/track.py index aa766d5d..07b8464b 100644 --- a/biogtr/inference/track.py +++ b/dreem/inference/track.py @@ -1,7 +1,7 @@ """Script to run inference and get out tracks.""" -from biogtr.io import Config -from biogtr.models import GTRRunner +from dreem.io import Config +from dreem.models import GTRRunner from omegaconf import DictConfig from pathlib import Path from pprint import pprint @@ -14,7 +14,7 @@ import sleap_io as sio -def export_trajectories(frames_pred: list["biogtr.io.Frame"], save_path: str = None): +def export_trajectories(frames_pred: list["dreem.io.Frame"], save_path: str = None): """Convert trajectories to data frame and save as .csv. Args: @@ -132,7 +132,7 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]: for i, pred in preds.items(): outpath = os.path.join( outdir, - f"{Path(dataloader.dataset.label_files[i]).stem}.biogtr_inference.v{run_num}.slp", + f"{Path(dataloader.dataset.label_files[i]).stem}.dreem_inference.v{run_num}.slp", ) if os.path.exists(outpath): run_num += 1 diff --git a/biogtr/inference/track_queue.py b/dreem/inference/track_queue.py similarity index 99% rename from biogtr/inference/track_queue.py rename to dreem/inference/track_queue.py index 739869fb..54927ac0 100644 --- a/biogtr/inference/track_queue.py +++ b/dreem/inference/track_queue.py @@ -1,7 +1,7 @@ """Module handling sliding window tracking.""" import warnings -from biogtr.io import Frame +from dreem.io import Frame from collections import deque import numpy as np diff --git a/biogtr/inference/tracker.py b/dreem/inference/tracker.py similarity index 98% rename from biogtr/inference/tracker.py rename to dreem/inference/tracker.py index 4aa36f39..93bcad8e 100644 --- a/biogtr/inference/tracker.py +++ b/dreem/inference/tracker.py @@ -3,11 +3,11 @@ import torch import pandas as pd import warnings -from biogtr.io import Frame -from biogtr.models import model_utils, GlobalTrackingTransformer -from biogtr.inference.track_queue import TrackQueue -from biogtr.inference import post_processing -from biogtr.inference.boxes import Boxes +from dreem.io import Frame +from dreem.models import model_utils, GlobalTrackingTransformer +from dreem.inference.track_queue import TrackQueue +from dreem.inference import post_processing +from dreem.inference.boxes import Boxes from scipy.optimize import linear_sum_assignment from math import inf @@ -127,7 +127,7 @@ def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame Args: model: the pretrained GlobalTrackingTransformer to be used for inference - frames: A list of Frames (See `biogtr.io.Frame` for more info). + frames: A list of Frames (See `dreem.io.Frame` for more info). Returns: Frames: A list of Frames populated with pred_track_ids and asso_matrices @@ -207,7 +207,7 @@ def _run_global_tracker( Args: model: the pretrained GlobalTrackingTransformer to be used for inference - frames: A list of Frames containing reid features. See `biogtr.io.data_structures` for more info. + frames: A list of Frames containing reid features. See `dreem.io.data_structures` for more info. query_ind: An integer for the query frame within the window of instances. Returns: diff --git a/dreem/io/__init__.py b/dreem/io/__init__.py new file mode 100644 index 00000000..5bd23340 --- /dev/null +++ b/dreem/io/__init__.py @@ -0,0 +1,7 @@ +"""Module containing input/output data structures for easy storage and manipulation.""" + +from dreem.io.frame import Frame +from dreem.io.instance import Instance +from dreem.io.association_matrix import AssociationMatrix +from dreem.io.track import Track +from dreem.io.config import Config diff --git a/biogtr/io/association_matrix.py b/dreem/io/association_matrix.py similarity index 99% rename from biogtr/io/association_matrix.py rename to dreem/io/association_matrix.py index 9d6d366e..84aee035 100644 --- a/biogtr/io/association_matrix.py +++ b/dreem/io/association_matrix.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd import attrs -from biogtr.io import Instance +from dreem.io import Instance from typing import Union diff --git a/biogtr/io/config.py b/dreem/io/config.py similarity index 96% rename from biogtr/io/config.py rename to dreem/io/config.py index 7ea8a0ac..7cd1e6de 100644 --- a/biogtr/io/config.py +++ b/dreem/io/config.py @@ -85,7 +85,7 @@ def get_model(self) -> "GlobalTrackingTransformer": Returns: A global tracking transformer with parameters indicated by cfg """ - from biogtr.models import GlobalTrackingTransformer + from dreem.models import GlobalTrackingTransformer model_params = self.cfg.model ckpt_path = model_params.pop("ckpt_path", None) @@ -109,7 +109,7 @@ def get_tracker_cfg(self) -> dict: def get_gtr_runner(self) -> "GTRRunner": """Get lightning module for training, validation, and inference.""" - from biogtr.models import GTRRunner + from dreem.models import GTRRunner tracker_params = self.cfg.tracker optimizer_params = self.cfg.optimizer @@ -174,7 +174,7 @@ def get_dataset( Returns: Either a `SleapDataset` or `MicroscopyDataset` with params indicated by cfg """ - from biogtr.datasets import MicroscopyDataset, SleapDataset, CellTrackingDataset + from dreem.datasets import MicroscopyDataset, SleapDataset, CellTrackingDataset if mode.lower() == "train": dataset_params = self.cfg.dataset.train_dataset @@ -276,7 +276,7 @@ def get_optimizer(self, params: Iterable) -> torch.optim.Optimizer: Returns: A torch Optimizer with specified params """ - from biogtr.models.model_utils import init_optimizer + from dreem.models.model_utils import init_optimizer optimizer_params = self.cfg.optimizer @@ -293,19 +293,19 @@ def get_scheduler( Returns: A torch learning rate scheduler with specified params """ - from biogtr.models.model_utils import init_scheduler + from dreem.models.model_utils import init_scheduler lr_scheduler_params = self.cfg.scheduler return init_scheduler(optimizer, lr_scheduler_params) - def get_loss(self) -> "biogtr.training.losses.AssoLoss": + def get_loss(self) -> "dreem.training.losses.AssoLoss": """Getter for loss functions. Returns: An AssoLoss with specified params """ - from biogtr.training.losses import AssoLoss + from dreem.training.losses import AssoLoss loss_params = self.cfg.loss @@ -317,7 +317,7 @@ def get_logger(self): Returns: A Logger with specified params """ - from biogtr.models.model_utils import init_logger + from dreem.models.model_utils import init_logger logger_params = OmegaConf.to_container(self.cfg.logging, resolve=True) diff --git a/biogtr/io/frame.py b/dreem/io/frame.py similarity index 99% rename from biogtr/io/frame.py rename to dreem/io/frame.py index 5607e832..67f6dc95 100644 --- a/biogtr/io/frame.py +++ b/dreem/io/frame.py @@ -129,15 +129,15 @@ def from_slp( device: str = None, **kwargs, ) -> "Frame": - """Convert `sio.LabeledFrame` to `biogtr.io.Frame`. + """Convert `sio.LabeledFrame` to `dreem.io.Frame`. Args: lf: A sio.LabeledFrame object Returns: - A biogtr.io.Frame object + A dreem.io.Frame object """ - from biogtr.io import Instance + from dreem.io import Instance img_shape = lf.image.shape if len(img_shape) == 2: diff --git a/biogtr/io/instance.py b/dreem/io/instance.py similarity index 99% rename from biogtr/io/instance.py rename to dreem/io/instance.py index 5ffef867..64e1af7d 100644 --- a/biogtr/io/instance.py +++ b/dreem/io/instance.py @@ -159,7 +159,7 @@ def from_slp( crop: ArrayLike = None, device: str = None, ) -> None: - """Convert a slp instance to a biogtr instance. + """Convert a slp instance to a dreem instance. Args: slp_instance: A `sleap_io.Instance` object representing a detection @@ -167,7 +167,7 @@ def from_slp( crop: The corresponding crop of the bbox device: which device to keep the instance on Returns: - A biogtr.Instance object with a pose-centered bbox and no crop. + A dreem.Instance object with a pose-centered bbox and no crop. """ try: track_id = int(slp_instance.track.name) diff --git a/biogtr/io/track.py b/dreem/io/track.py similarity index 100% rename from biogtr/io/track.py rename to dreem/io/track.py diff --git a/biogtr/io/visualize.py b/dreem/io/visualize.py similarity index 100% rename from biogtr/io/visualize.py rename to dreem/io/visualize.py diff --git a/biogtr/models/__init__.py b/dreem/models/__init__.py similarity index 100% rename from biogtr/models/__init__.py rename to dreem/models/__init__.py diff --git a/biogtr/models/attention_head.py b/dreem/models/attention_head.py similarity index 97% rename from biogtr/models/attention_head.py rename to dreem/models/attention_head.py index ed8c6f50..2b160552 100644 --- a/biogtr/models/attention_head.py +++ b/dreem/models/attention_head.py @@ -1,7 +1,7 @@ """Module containing different components of multi-head attention heads.""" import torch -from biogtr.models.mlp import MLP +from dreem.models.mlp import MLP # todo: add named tensors diff --git a/biogtr/models/embedding.py b/dreem/models/embedding.py similarity index 99% rename from biogtr/models/embedding.py rename to dreem/models/embedding.py index 222d8585..cf9a4f6f 100644 --- a/biogtr/models/embedding.py +++ b/dreem/models/embedding.py @@ -3,7 +3,7 @@ from typing import Tuple, Optional import math import torch -from biogtr.models.mlp import MLP +from dreem.models.mlp import MLP # todo: add named tensors, clean variable names diff --git a/biogtr/models/global_tracking_transformer.py b/dreem/models/global_tracking_transformer.py similarity index 95% rename from biogtr/models/global_tracking_transformer.py rename to dreem/models/global_tracking_transformer.py index c746d1aa..6114080e 100644 --- a/biogtr/models/global_tracking_transformer.py +++ b/dreem/models/global_tracking_transformer.py @@ -1,7 +1,7 @@ """Module containing GTR model used for training.""" -from biogtr.models import Transformer -from biogtr.models import VisualEncoder +from dreem.models import Transformer +from dreem.models import VisualEncoder import torch # todo: do we want to handle params with configs already here? @@ -51,8 +51,8 @@ def __init__( that no positional embeddings should be used. To use the positional embeddings pass in a dictionary containing a "pos" and "temp" key with subdictionaries for correct parameters ie: {"pos": {'mode': 'learned', 'emb_num': 16, 'over_boxes: 'True'}, - "temp": {'mode': 'learned', 'emb_num': 16}}. (see `biogtr.models.embeddings.Embedding.EMB_TYPES` - and `biogtr.models.embeddings.Embedding.EMB_MODES` for embedding parameters). + "temp": {'mode': 'learned', 'emb_num': 16}}. (see `dreem.models.embeddings.Embedding.EMB_TYPES` + and `dreem.models.embeddings.Embedding.EMB_MODES` for embedding parameters). """ super().__init__() diff --git a/biogtr/models/gtr_runner.py b/dreem/models/gtr_runner.py similarity index 91% rename from biogtr/models/gtr_runner.py rename to dreem/models/gtr_runner.py index 7dbf4b2b..35965674 100644 --- a/biogtr/models/gtr_runner.py +++ b/dreem/models/gtr_runner.py @@ -2,14 +2,14 @@ import torch import gc -from biogtr.inference import Tracker -from biogtr.inference import metrics -from biogtr.models import GlobalTrackingTransformer -from biogtr.training.losses import AssoLoss -from biogtr.models.model_utils import init_optimizer, init_scheduler +from dreem.inference import Tracker +from dreem.inference import metrics +from dreem.models import GlobalTrackingTransformer +from dreem.training.losses import AssoLoss +from dreem.models.model_utils import init_optimizer, init_scheduler from pytorch_lightning import LightningModule -from biogtr.io.frame import Frame -from biogtr.io.instance import Instance +from dreem.io.frame import Frame +from dreem.io.instance import Instance class GTRRunner(LightningModule): @@ -75,8 +75,8 @@ def __init__( def forward( self, - ref_instances: list["biogtr.io.Instance"], - query_instances: list["biogtr.io.Instance"] = None, + ref_instances: list["dreem.io.Instance"], + query_instances: list["dreem.io.Instance"] = None, ) -> torch.Tensor: """Execute forward pass of the lightning module. @@ -91,7 +91,7 @@ def forward( return asso_preds def training_step( - self, train_batch: list[list["biogtr.io.Frame"]], batch_idx: int + self, train_batch: list[list["dreem.io.Frame"]], batch_idx: int ) -> dict[str, float]: """Execute single training step for model. @@ -109,7 +109,7 @@ def training_step( return result def validation_step( - self, val_batch: list[list["biogtr.io.Frame"]], batch_idx: int + self, val_batch: list[list["dreem.io.Frame"]], batch_idx: int ) -> dict[str, float]: """Execute single val step for model. @@ -127,7 +127,7 @@ def validation_step( return result def test_step( - self, test_batch: list[list["biogtr.io.Frame"]], batch_idx: int + self, test_batch: list[list["dreem.io.Frame"]], batch_idx: int ) -> dict[str, float]: """Execute single test step for model. @@ -145,8 +145,8 @@ def test_step( return result def predict_step( - self, batch: list[list["biogtr.io.Frame"]], batch_idx: int - ) -> list["biogtr.io.Frame"]: + self, batch: list[list["dreem.io.Frame"]], batch_idx: int + ) -> list["dreem.io.Frame"]: """Run inference for model. Computes association + assignment. @@ -163,7 +163,7 @@ def predict_step( return frames_pred def _shared_eval_step( - self, frames: list["biogtr.io.Frame"], mode: str + self, frames: list["dreem.io.Frame"], mode: str ) -> dict[str, float]: """Run evaluation used by train, test, and val steps. diff --git a/biogtr/models/mlp.py b/dreem/models/mlp.py similarity index 100% rename from biogtr/models/mlp.py rename to dreem/models/mlp.py diff --git a/biogtr/models/model_utils.py b/dreem/models/model_utils.py similarity index 97% rename from biogtr/models/model_utils.py rename to dreem/models/model_utils.py index a2885a0f..249170f4 100644 --- a/biogtr/models/model_utils.py +++ b/dreem/models/model_utils.py @@ -5,7 +5,7 @@ import torch -def get_boxes(instances: List["biogtr.io.Instance"]) -> torch.tensor: +def get_boxes(instances: List["dreem.io.Instance"]) -> torch.tensor: """Extract the bounding boxes from the input list of instances. Args: @@ -29,8 +29,8 @@ def get_boxes(instances: List["biogtr.io.Instance"]) -> torch.tensor: def get_times( - ref_instances: list["biogtr.io.Instance"], - query_instances: list["biogtr.io.Instance"] = None, + ref_instances: list["dreem.io.Instance"], + query_instances: list["dreem.io.Instance"] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Extract the time indices of each instance relative to the window length. diff --git a/biogtr/models/transformer.py b/dreem/models/transformer.py similarity index 97% rename from biogtr/models/transformer.py rename to dreem/models/transformer.py index 75579da5..1c4ff019 100644 --- a/biogtr/models/transformer.py +++ b/dreem/models/transformer.py @@ -11,10 +11,10 @@ * added fixed embeddings over boxes """ -from biogtr.io import AssociationMatrix -from biogtr.models.attention_head import ATTWeightHead -from biogtr.models import Embedding -from biogtr.models.model_utils import get_boxes, get_times +from dreem.io import AssociationMatrix +from dreem.models.attention_head import ATTWeightHead +from dreem.models import Embedding +from dreem.models.model_utils import get_boxes, get_times from torch import nn import copy import torch @@ -65,8 +65,8 @@ def __init__( that no positional embeddings should be used. To use the positional embeddings pass in a dictionary containing a "pos" and "temp" key with subdictionaries for correct parameters ie: {"pos": {'mode': 'learned', 'emb_num': 16, 'over_boxes: 'True'}, - "temp": {'mode': 'learned', 'emb_num': 16}}. (see `biogtr.models.embeddings.Embedding.EMB_TYPES` - and `biogtr.models.embeddings.Embedding.EMB_MODES` for embedding parameters). + "temp": {'mode': 'learned', 'emb_num': 16}}. (see `dreem.models.embeddings.Embedding.EMB_TYPES` + and `dreem.models.embeddings.Embedding.EMB_MODES` for embedding parameters). """ super().__init__() @@ -141,13 +141,13 @@ def _reset_parameters(self): def forward( self, - ref_instances: list["biogtr.io.Instance"], - query_instances: list["biogtr.io.Instance"] = None, + ref_instances: list["dreem.io.Instance"], + query_instances: list["dreem.io.Instance"] = None, ) -> list[AssociationMatrix]: """Execute a forward pass through the transformer and attention head. Args: - ref instances: A list of instance objects (See `biogtr.io.Instance` for more info.) + ref instances: A list of instance objects (See `dreem.io.Instance` for more info.) query_instances: An set of instances to be used as decoder queries. Returns: diff --git a/biogtr/models/visual_encoder.py b/dreem/models/visual_encoder.py similarity index 100% rename from biogtr/models/visual_encoder.py rename to dreem/models/visual_encoder.py diff --git a/biogtr/training/__init__.py b/dreem/training/__init__.py similarity index 100% rename from biogtr/training/__init__.py rename to dreem/training/__init__.py diff --git a/biogtr/training/configs/base.yaml b/dreem/training/configs/base.yaml similarity index 100% rename from biogtr/training/configs/base.yaml rename to dreem/training/configs/base.yaml diff --git a/biogtr/training/configs/params.yaml b/dreem/training/configs/params.yaml similarity index 100% rename from biogtr/training/configs/params.yaml rename to dreem/training/configs/params.yaml diff --git a/biogtr/training/configs/test_batch_train.csv b/dreem/training/configs/test_batch_train.csv similarity index 100% rename from biogtr/training/configs/test_batch_train.csv rename to dreem/training/configs/test_batch_train.csv diff --git a/biogtr/training/losses.py b/dreem/training/losses.py similarity index 99% rename from biogtr/training/losses.py rename to dreem/training/losses.py index ff3e6eca..53b4289c 100644 --- a/biogtr/training/losses.py +++ b/dreem/training/losses.py @@ -1,6 +1,6 @@ """Module containing different loss functions to be optimized.""" -from biogtr.models.model_utils import get_boxes, get_times +from dreem.models.model_utils import get_boxes, get_times from torch import nn from typing import List, Tuple import torch diff --git a/biogtr/training/train.py b/dreem/training/train.py similarity index 96% rename from biogtr/training/train.py rename to dreem/training/train.py index e252d3f5..a2522817 100644 --- a/biogtr/training/train.py +++ b/dreem/training/train.py @@ -3,9 +3,9 @@ Used for training a single model or deploying a batch train job on RUNAI CLI """ -from biogtr.io import Config -from biogtr.datasets import TrackingDataset -from biogtr.datasets.data_utils import view_training_batch +from dreem.io import Config +from dreem.datasets import TrackingDataset +from dreem.datasets.data_utils import view_training_batch from multiprocessing import cpu_count from omegaconf import DictConfig from pprint import pprint diff --git a/biogtr/version.py b/dreem/version.py similarity index 100% rename from biogtr/version.py rename to dreem/version.py diff --git a/environment.yml b/environment.yml index da26b24b..0b129889 100644 --- a/environment.yml +++ b/environment.yml @@ -1,4 +1,4 @@ -name: biogtr +name: dreem channels: - pytorch diff --git a/environment_cpu.yml b/environment_cpu.yml index 6714b9c1..bc63a14a 100644 --- a/environment_cpu.yml +++ b/environment_cpu.yml @@ -1,4 +1,4 @@ -name: biogtr +name: dreem channels: - pytorch diff --git a/environment_osx-arm64.yml b/environment_osx-arm64.yml index 124c4576..0393cb66 100644 --- a/environment_osx-arm64.yml +++ b/environment_osx-arm64.yml @@ -1,4 +1,4 @@ -name: biogtr +name: dreem channels: - pytorch diff --git a/pyproject.toml b/pyproject.toml index eeb8b59a..9b13894f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools", "setuptools-scm"] build-backend = "setuptools.build_meta" [project] -name = "biogtr" +name = "dreem" authors = [ {name = "Arlo Sheridan", email = "asheridan@salk.edu"}, {name = "Aaditya Prasad", email = "aprasad@salk.edu"}, @@ -33,7 +33,7 @@ dependencies = [ dynamic = ["version", "readme"] [tool.setuptools.dynamic] -version = {attr = "biogtr.version.__version__"} +version = {attr = "dreem.version.__version__"} readme = {file = ["README.md"]} [project.optional-dependencies] @@ -48,11 +48,11 @@ dev = [ ] [project.scripts] -biogtr = "biogtr.cli:cli" +dreem = "dreem.cli:cli" [project.urls] -Homepage = "https://github.com/talmolab/biogtr" -Repository = "https://github.com/talmolab/biogtr" +Homepage = "https://github.com/talmolab/dreem" +Repository = "https://github.com/talmolab/dreem" [tool.black] line-length = 88 diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index db574099..f28e0950 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -1,4 +1,4 @@ -"""Fixtures for testing biogtr.""" +"""Fixtures for testing dreem.""" import pytest from pathlib import Path diff --git a/tests/test_config.py b/tests/test_config.py index 4f7ebbc7..0b1c8267 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,8 +1,8 @@ """Tests for `config.py`""" from omegaconf import OmegaConf -from biogtr.io import Config -from biogtr.models import GlobalTrackingTransformer, GTRRunner +from dreem.io import Config +from dreem.models import GlobalTrackingTransformer, GTRRunner import torch diff --git a/tests/test_data_model.py b/tests/test_data_model.py index ef5d0320..aeeb09ca 100644 --- a/tests/test_data_model.py +++ b/tests/test_data_model.py @@ -1,6 +1,6 @@ """Tests for Instance, Frame, and AssociationMatrix Objects""" -from biogtr.io import Frame, Instance, AssociationMatrix, Track +from dreem.io import Frame, Instance, AssociationMatrix, Track import torch import pytest import numpy as np diff --git a/tests/test_datasets.py b/tests/test_datasets.py index ab9d7640..2287c4f9 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,14 +1,14 @@ """Test dataset logic.""" -from biogtr.datasets import ( +from dreem.datasets import ( BaseDataset, MicroscopyDataset, SleapDataset, CellTrackingDataset, TrackingDataset, ) -from biogtr.datasets.data_utils import get_max_padding, NodeDropout -from biogtr.models.model_utils import get_device +from dreem.datasets.data_utils import get_max_padding, NodeDropout +from dreem.models.model_utils import get_device from torch.utils.data import DataLoader import pytest import torch diff --git a/tests/test_inference.py b/tests/test_inference.py index a5ef05f9..06b7154b 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -5,11 +5,11 @@ import numpy as np from pytorch_lightning import Trainer from omegaconf import OmegaConf, DictConfig -from biogtr.io import Frame, Instance, Config -from biogtr.models import GTRRunner, GlobalTrackingTransformer -from biogtr.inference import Tracker, post_processing, metrics -from biogtr.inference.track_queue import TrackQueue -from biogtr.inference.track import run +from dreem.io import Frame, Instance, Config +from dreem.models import GTRRunner, GlobalTrackingTransformer +from dreem.inference import Tracker, post_processing, metrics +from dreem.inference.track_queue import TrackQueue +from dreem.inference.track import run def test_track_queue(): diff --git a/tests/test_models.py b/tests/test_models.py index bf4e8e47..3eaf9c22 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -2,16 +2,16 @@ import pytest import torch -from biogtr.io import Frame, Instance -from biogtr.models.mlp import MLP -from biogtr.models.attention_head import ATTWeightHead -from biogtr.models import ( +from dreem.io import Frame, Instance +from dreem.models.mlp import MLP +from dreem.models.attention_head import ATTWeightHead +from dreem.models import ( Embedding, VisualEncoder, Transformer, GlobalTrackingTransformer, ) -from biogtr.models.transformer import ( +from dreem.models.transformer import ( TransformerEncoderLayer, TransformerDecoderLayer, ) diff --git a/tests/test_training.py b/tests/test_training.py index 357ce267..bd8bbe75 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -3,11 +3,11 @@ import os import pytest import torch -from biogtr.io import Frame, Instance, Config -from biogtr.training.losses import AssoLoss -from biogtr.models import GTRRunner +from dreem.io import Frame, Instance, Config +from dreem.training.losses import AssoLoss +from dreem.models import GTRRunner from omegaconf import OmegaConf, DictConfig -from biogtr.training.train import run +from dreem.training.train import run # TODO: add named tensor tests # TODO: use temp dir and cleanup after tests (https://docs.pytest.org/en/7.1.x/how-to/tmp_path.html) diff --git a/tests/test_version.py b/tests/test_version.py index 6bde7e48..2d43f00b 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -1,8 +1,8 @@ """Test version.""" -import biogtr +import dreem def test_version(): """Test version.""" - assert biogtr.__version__ == biogtr.version.__version__ + assert dreem.__version__ == dreem.version.__version__