
Merge branch 'main' into aadi/documentation
aaprasad committed Jun 6, 2024
2 parents a0ee366 + f04eb9f commit 7320578
Showing 63 changed files with 349 additions and 297 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/ci.yml
@@ -5,7 +5,7 @@ on:
pull_request:
types: [opened, reopened, synchronize]
paths:
- "biogtr/**"
- "dreem/**"
- "tests/**"
- ".github/workflows/ci.yml"
- "environment_cpu.yml"
@@ -14,7 +14,7 @@ on:
branches:
- main
paths:
- "biogtr/**"
- "dreem/**"
- "tests/**"
- ".github/workflows/ci.yml"
- "environment_cpu.yml"
@@ -53,11 +53,11 @@ jobs:
- name: Run Black
run: |
black --check biogtr tests
black --check dreem tests
- name: Run pydocstyle
run: |
pydocstyle --convention=google biogtr/
pydocstyle --convention=google dreem/
# Tests with pytest
tests:
@@ -105,7 +105,7 @@ jobs:
if: ${{ startsWith(matrix.os, 'ubuntu') && matrix.python == 3.9 }}
shell: bash -l {0}
run: |
pytest --cov=biogtr --cov-report=xml tests/
pytest --cov=dreem --cov-report=xml tests/
- name: Upload coverage
uses: codecov/codecov-action@v3
4 changes: 2 additions & 2 deletions .gitignore
@@ -137,5 +137,5 @@ logs

# vscode
.vscode
biogtr/training/.hydra/*
biogtr/training/models/*
dreem/training/.hydra/*
dreem/training/models/*
8 changes: 4 additions & 4 deletions README.md
@@ -1,10 +1,10 @@
# BioGTR
# DREEM Reconstructs Every Entities' Motion
Global Tracking Transformers for biological multi-object tracking.

## Installation
<!-- ### Basic
```
pip install biogtr
pip install dreem
``` -->
### Development
#### Clone the repository:
@@ -27,7 +27,7 @@ conda env create -y -f environment_osx-arm.yml && conda activate biogtr
```
### Uninstalling
```
conda env remove -n biogtr
conda env remove -n dreem
```

## Usage
@@ -275,4 +275,4 @@ e.g if I want to set the window size of the tracker to 32 instead of 8 through `
python /home/aaprasad/biogtr/inference/inference.py --config-base=/home/aaprasad/biogtr_configs --config-name=track ckpt_path="/home/aaprasad/models/new_best.ckpt" tracker.window_size=32`
```
#### Output
This will run inference on the videos/detections you specified in the `dataset.test_dataset` section of the config and save the tracks to individual `[VID_NAME].biogtr_inference.slp` files. If an `outdir` is specified in the config it will save to `./[OUTDIR]/[VID_NAME].biogtr_inference.slp`, otherwise it will just save to `./results/[VID_NAME].biogtr_inference.slp`. Now you can load the file with `sleap-io` and do what you please!
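For reference, here is a minimal sketch (not part of this commit) of loading one of those output files with `sleap-io` and inspecting the predicted tracks; the file path below is hypothetical.

```python
import sleap_io as sio

# Hypothetical output path; substitute the [OUTDIR]/[VID_NAME] produced by your run.
labels = sio.load_slp("results/my_video.dreem_inference.slp")

# Each labeled frame holds the tracked instances predicted for that frame.
for lf in labels.labeled_frames[:5]:
    for inst in lf.instances:
        track_name = inst.track.name if inst.track is not None else "untracked"
        print(lf.frame_idx, track_name)
```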
18 changes: 0 additions & 18 deletions biogtr/__init__.py

This file was deleted.

1 change: 0 additions & 1 deletion biogtr/cli.py

This file was deleted.

9 changes: 0 additions & 9 deletions biogtr/io/__init__.py

This file was deleted.

18 changes: 18 additions & 0 deletions dreem/__init__.py
@@ -0,0 +1,18 @@
"""Top-level package for dreem."""

from dreem.version import __version__

from dreem.models.global_tracking_transformer import GlobalTrackingTransformer
from dreem.models.gtr_runner import GTRRunner
from dreem.models.transformer import Transformer
from dreem.models.visual_encoder import VisualEncoder

from dreem.io.frame import Frame
from dreem.io.instance import Instance
from dreem.io.association_matrix import AssociationMatrix
from dreem.io.config import Config
from dreem.io.visualize import annotate_video

# from .training import run

from dreem.inference.tracker import Tracker
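As a quick sanity check (not part of this commit), the new `dreem/__init__.py` makes the core classes importable straight from the package root; a minimal sketch, assuming the package is installed as `dreem`:

```python
# Minimal sketch: confirm the package-root re-exports in dreem/__init__.py resolve.
from dreem import (
    AssociationMatrix,
    Config,
    Frame,
    GlobalTrackingTransformer,
    GTRRunner,
    Instance,
    Tracker,
    Transformer,
    VisualEncoder,
    annotate_video,
)

print(GTRRunner, Tracker)  # classes come straight from the top-level namespace
```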
1 change: 1 addition & 0 deletions dreem/cli.py
@@ -0,0 +1 @@
"""This module contains the command line interfaces for the dreem package."""
File renamed without changes.
@@ -1,7 +1,7 @@
"""Module containing logic for loading datasets."""

from biogtr.datasets import data_utils
from biogtr.io import Frame
from dreem.datasets import data_utils
from dreem.io import Frame
from torch.utils.data import Dataset
from typing import List, Union
import numpy as np
@@ -29,14 +29,14 @@ def __init__(
Args:
label_files: a list of paths to label files.
should at least contain detections for inference, detections + tracks for training.
vid_files: list of paths to video files.
padding: amount of padding around object crops
crop_size: the size of the object crops
chunk: whether or not to chunk the dataset into batches
clip_length: the number of frames in each chunk
mode: `train` or `val`. Determines whether this dataset is used for
training or validation.
training or validation. Currently doesn't affect dataset logic
augmentations: An optional dict mapping augmentations to parameters.
See subclasses for details.
n_chunks: Number of chunks to subsample from.
@@ -1,8 +1,8 @@
"""Module containing cell tracking challenge dataset."""

from PIL import Image
from biogtr.datasets import data_utils, BaseDataset
from biogtr.io import Frame, Instance
from dreem.datasets import data_utils, BaseDataset
from dreem.io import Frame, Instance
from scipy.ndimage import measurements
from typing import List, Optional, Union
import albumentations as A
@@ -122,7 +122,7 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> List[Fram
Returns:
a list of Frame objects containing frame metadata and Instance Objects.
See `biogtr.io.data_structures` for more info.
See `dreem.io.data_structures` for more info.
"""
image = self.videos[label_idx]
gt = self.labels[label_idx]
File renamed without changes.
@@ -1,7 +1,7 @@
"""Module containing wrapper for merging gt and pred datasets for evaluation."""

from torch.utils.data import Dataset
from biogtr.io import Instance, Frame
from dreem.io import Instance, Frame
from typing import List


@@ -1,8 +1,8 @@
"""Module containing microscopy dataset."""

from PIL import Image
from biogtr.datasets import data_utils, BaseDataset
from biogtr.io import Instance, Frame
from dreem.datasets import data_utils, BaseDataset
from dreem.io import Instance, Frame
from typing import Union
import albumentations as A
import numpy as np
@@ -122,7 +122,7 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram
frame_idx: index of the frames
Returns:
A list of Frames containing Instances to be tracked (See `biogtr.io.data_structures for more info`)
A list of Frames containing Instances to be tracked (See `dreem.io.data_structures for more info`)
"""
labels = self.labels[label_idx]
labels = labels.dropna(how="all")
@@ -7,8 +7,8 @@
import sleap_io as sio
import random
import warnings
from biogtr.io import Instance, Frame
from biogtr.datasets import data_utils, BaseDataset
from dreem.io import Instance, Frame
from dreem.datasets import data_utils, BaseDataset
from torchvision.transforms import functional as tvf
from typing import List, Union

@@ -124,30 +124,15 @@ def get_indices(self, idx):
"""
return self.label_idx[idx], self.chunked_frame_idx[idx]

def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict]:
def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[Frame]:
"""Get an element of the dataset.
Args:
label_idx: index of the labels
frame_idx: index of the frames
Returns:
A list of dicts where each dict corresponds a frame in the chunk and each value is a `torch.Tensor`
Dict Elements:
{
"video_id": The video being passed through the transformer,
"img_shape": the shape of each frame,
"frame_id": the specific frame in the entire video being used,
"num_detected": The number of objects in the frame,
"gt_track_ids": The ground truth labels,
"bboxes": The bounding boxes of each object,
"crops": The raw pixel crops,
"features": The feature vectors for each crop outputed by the CNN encoder,
"pred_track_ids": The predicted trajectory labels from the tracker,
"asso_output": the association matrix preprocessing,
"matches": the true positives from the model,
"traj_score": the association matrix post processing,
}
A list of `dreem.io.Frame` objects containing metadata and instance data for the batch/clip.
"""
video = self.labels[label_idx]
@@ -1,8 +1,8 @@
"""Module containing Lightning module wrapper around all other datasets."""

from biogtr.datasets.cell_tracking_dataset import CellTrackingDataset
from biogtr.datasets.microscopy_dataset import MicroscopyDataset
from biogtr.datasets.sleap_dataset import SleapDataset
from dreem.datasets.cell_tracking_dataset import CellTrackingDataset
from dreem.datasets.microscopy_dataset import MicroscopyDataset
from dreem.datasets.sleap_dataset import SleapDataset
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from typing import Union
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -1,5 +1,4 @@
ckpt_path: "../training/models/example/example_train/epoch=0-best-val_sw_cnt=31.06133270263672.ckpt"
out_dir: "./results"
tracker:
overlap_thresh: 0.01
decay_time: 0.9
10 changes: 5 additions & 5 deletions biogtr/inference/metrics.py → dreem/inference/metrics.py
@@ -5,11 +5,11 @@
import torch
from typing import Union, Iterable

# from biogtr.inference.post_processing import _pairwise_iou
# from biogtr.inference.boxes import Boxes
# from dreem.inference.post_processing import _pairwise_iou
# from dreem.inference.boxes import Boxes


def get_matches(frames: list["biogtr.io.Frame"]) -> tuple[dict, list, int]:
def get_matches(frames: list["dreem.io.Frame"]) -> tuple[dict, list, int]:
"""Get comparison between predicted and gt trajectory labels.
Args:
@@ -100,11 +100,11 @@ def get_switch_count(switches: dict) -> int:
return sw_cnt


def to_track_eval(frames: list["biogtr.io.Frame"]) -> dict:
def to_track_eval(frames: list["dreem.io.Frame"]) -> dict:
"""Reformats frames the output from `sliding_inference` to be used by `TrackEval`.
Args:
frames: A list of Frames. `See biogtr.io.data_structures for more info`.
frames: A list of Frames. `See dreem.io.data_structures for more info`.
Returns:
data: A dictionary. Example provided below.
@@ -1,7 +1,7 @@
"""Helper functions for post-processing association matrix pre-tracking."""

import torch
from biogtr.inference.boxes import Boxes
from dreem.inference.boxes import Boxes


def weight_decay_time(
8 changes: 4 additions & 4 deletions biogtr/inference/track.py → dreem/inference/track.py
@@ -1,7 +1,7 @@
"""Script to run inference and get out tracks."""

from biogtr.io import Config
from biogtr.models import GTRRunner
from dreem.io import Config
from dreem.models import GTRRunner
from omegaconf import DictConfig
from pathlib import Path
from pprint import pprint
@@ -14,7 +14,7 @@
import sleap_io as sio


def export_trajectories(frames_pred: list["biogtr.io.Frame"], save_path: str = None):
def export_trajectories(frames_pred: list["dreem.io.Frame"], save_path: str = None):
"""Convert trajectories to data frame and save as .csv.
Args:
@@ -132,7 +132,7 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
for i, pred in preds.items():
outpath = os.path.join(
outdir,
f"{Path(dataloader.dataset.label_files[i]).stem}.biogtr_inference.v{run_num}.slp",
f"{Path(dataloader.dataset.label_files[i]).stem}.dreem_inference.v{run_num}.slp",
)
if os.path.exists(outpath):
run_num += 1
@@ -1,7 +1,7 @@
"""Module handling sliding window tracking."""

import warnings
from biogtr.io import Frame
from dreem.io import Frame
from collections import deque
import numpy as np

@@ -14,7 +14,9 @@ class TrackQueue:
and will be compared against later frames for assignment.
"""

def __init__(self, window_size: int, max_gap: int = np.inf, verbose: bool = False):
def __init__(
self, window_size: int, max_gap: int = np.inf, verbose: bool = False
) -> None:
"""Initialize track queue.
Args:
@@ -33,15 +35,15 @@ def __init__(self, window_size: int, max_gap: int = np.inf, verbose: bool = Fals
self._curr_track = -1
self._verbose = verbose

def __len__(self):
def __len__(self) -> int:
"""Get length of the queue.
Returns:
The total number of instances in every sub-queue.
"""
return sum([len(queue) for queue in self._queues.values()])

def __repr__(self):
def __repr__(self) -> str:
"""Return the string representation of the TrackQueue.
Returns: