Abstract dataset logic to override base dataset class (#12)
sheridana committed Jun 26, 2023
1 parent d2aae82 commit e30a6b5
Showing 30 changed files with 628 additions and 105 deletions.
12 changes: 10 additions & 2 deletions biogtr/config.py
@@ -2,6 +2,7 @@
"""Data structures for handling config parsing."""
from biogtr.datasets.microscopy_dataset import MicroscopyDataset
from biogtr.datasets.sleap_dataset import SleapDataset
from biogtr.datasets.cell_tracking_dataset import CellTrackingDataset
from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer
from biogtr.models.gtr_runner import GTRRunner
from biogtr.models.model_utils import init_optimizer, init_scheduler, init_logger
@@ -106,7 +107,9 @@ def get_gtr_runner(self):
**gtr_runner_params,
)

def get_dataset(self, mode: str) -> Union[SleapDataset, MicroscopyDataset]:
def get_dataset(
self, mode: str
) -> Union[SleapDataset, MicroscopyDataset, CellTrackingDataset]:
"""Getter for datasets.
Args:
@@ -127,18 +130,23 @@ def get_dataset(self, mode: str) -> Union[SleapDataset, MicroscopyDataset]:
"`mode` must be one of ['train', 'val','test'], not '{mode}'"
)

# todo: handle this better
if "slp_files" in dataset_params:
return SleapDataset(**dataset_params)
elif "tracks" in dataset_params or "source" in dataset_params:
return MicroscopyDataset(**dataset_params)
elif "raw_images" in dataset_params:
return CellTrackingDataset(**dataset_params)
else:
raise ValueError(
"Could not resolve dataset type from Config! Please include "
"either `slp_files`, `tracks`/`source`, or `raw_images`"
)
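
# A sketch of how the key-based dispatch above resolves (hypothetical
# `dataset_params` contents; keys are illustrative, not a full config):
#   {"slp_files": [...], ...}                 -> SleapDataset
#   {"tracks": ..., "source": ...}            -> MicroscopyDataset
#   {"raw_images": [...], "gt_images": [...]} -> CellTrackingDataset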

def get_dataloader(
self, dataset: Union[SleapDataset, MicroscopyDataset], mode: str
self,
dataset: Union[SleapDataset, MicroscopyDataset, CellTrackingDataset],
mode: str,
) -> torch.utils.data.DataLoader:
"""Getter for dataloader.
139 changes: 139 additions & 0 deletions biogtr/datasets/base_dataset.py
@@ -0,0 +1,139 @@
"""Module containing logic for loading datasets."""
from biogtr.datasets import data_utils
from torch.utils.data import Dataset
from typing import List, Optional
import torch


class BaseDataset(Dataset):
"""Base Dataset for microscopy and sleap datasets to override."""

def __init__(
self,
files: list[str],
padding: int,
crop_size: int,
chunk: bool,
clip_length: int,
mode: str,
augmentations: Optional[dict] = None,
gt_list: Optional[str] = None,
):
"""Initialize Dataset.
Args:
files: a list of files, file types are combined in subclasses
padding: amount of padding around object crops
crop_size: the size of the object crops
chunk: whether or not to chunk the dataset into batches
clip_length: the number of frames in each chunk
mode: `train` or `val`. Determines whether this dataset is used for
training or validation. Currently doesn't affect dataset logic
augmentations: An optional dict mapping augmentations to parameters.
See subclasses for details.
gt_list: An optional path to .txt file containing ground truth for
cell tracking challenge datasets.
"""
self.files = files
self.padding = padding
self.crop_size = crop_size
self.chunk = chunk
self.clip_length = clip_length
self.mode = mode

self.augmentations = (
data_utils.build_augmentations(augmentations) if augmentations else None
)

# Initialize in subclasses
self.frame_idx = None
self.labels = None
self.gt_list = None
self.chunks = None

def create_chunks(self):
"""Get indexing for data.
Creates both indexes for selecting dataset (label_idx) and frame in
dataset (chunked_frame_idx). If chunking is false, we index directly
using the frame ids. Setting chunking to true creates a list of lists
containing chunk frames for indexing. This is useful for computational
efficiency and data shuffling. To be called by subclass __init__()
"""
if self.chunk:
self.chunks = [
[i * self.clip_length for i in range(len(label) // self.clip_length)]
for label in self.labels
]

self.chunked_frame_idx, self.label_idx = [], []
for i, (split, frame_idx) in enumerate(zip(self.chunks, self.frame_idx)):
frame_idx_split = torch.split(frame_idx, self.clip_length)
self.chunked_frame_idx.extend(frame_idx_split)
self.label_idx.extend(len(frame_idx_split) * [i])
else:
self.chunked_frame_idx = self.frame_idx
self.label_idx = [i for i in range(len(self.labels))]
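
# Worked example of the chunking above (assumed sizes, for illustration):
# with chunk=True, clip_length=3, and one video of 7 frames,
# torch.split(torch.arange(7), 3) -> ([0, 1, 2], [3, 4, 5], [6]), so
# chunked_frame_idx holds 3 chunks and label_idx == [0, 0, 0].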

def __len__(self):
"""Get the size of the dataset.
Returns:
the size or the number of chunks in the dataset
"""
return len(self.chunked_frame_idx)

def no_batching_fn(self, batch):
"""Collate function used to overwrite dataloader batching function.
Args:
batch: the chunk of frames to be returned
Returns:
The batch
"""
return batch

def __getitem__(self, idx: int) -> List[dict]:
"""Get an element of the dataset.
Args:
idx: the index of the batch. Note this is not the index of the video
or the frame.
Returns:
A list of dicts where each dict corresponds a frame in the chunk and
each value is a `torch.Tensor`. Dict elements can be seen in
subclasses
"""
label_idx, frame_idx = self.get_indices(idx)

return self.get_instances(label_idx, frame_idx)

def get_indices(self, idx: int):
"""Retrieves label and frame indices given batch index.
This method should be implemented in any subclass of the BaseDataset.
Args:
idx: the index of the batch.
Raises:
NotImplementedError: If this method is not overridden in a subclass.
"""
raise NotImplementedError("Must be implemented in subclass")

def get_instances(self, label_idx: List[int], frame_idx: List[int]):
"""Builds instances dict given label and frame indices.
This method should be implemented in any subclass of the BaseDataset.
Args:
label_idx: The index of the labels.
frame_idx: The index of the frames.
Raises:
NotImplementedError: If this method is not overridden in a subclass.
"""
raise NotImplementedError("Must be implemented in subclass")
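
A minimal sketch of the override contract this base class defines (a hypothetical ToyDataset for illustration, not part of this commit): subclasses set self.labels and self.frame_idx, call create_chunks(), and implement the two abstract methods.

    from biogtr.datasets.base_dataset import BaseDataset
    import torch

    class ToyDataset(BaseDataset):
        def __init__(self, files: list[str]):
            super().__init__(files, padding=0, crop_size=32, chunk=True,
                             clip_length=4, mode="train")
            # one "video" of 16 frames; real subclasses parse `files` here
            self.labels = [list(range(16))]
            self.frame_idx = [torch.arange(16)]
            self.create_chunks()  # builds label_idx / chunked_frame_idx

        def get_indices(self, idx: int):
            return self.label_idx[idx], self.chunked_frame_idx[idx]

        def get_instances(self, label_idx, frame_idx):
            return [{"frame_id": torch.tensor([int(i)])} for i in frame_idx]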
213 changes: 213 additions & 0 deletions biogtr/datasets/cell_tracking_dataset.py
@@ -0,0 +1,213 @@
"""Module containing cell tracking challenge dataset."""
from PIL import Image
from biogtr.datasets import data_utils
from biogtr.datasets.base_dataset import BaseDataset
from scipy.ndimage import measurements
from torch.utils.data import Dataset
from torchvision.transforms import functional as tvf
from typing import List, Optional
import albumentations as A
import glob
import numpy as np
import os
import pandas as pd
import random
import torch


class CellTrackingDataset(BaseDataset):
"""Dataset for loading cell tracking challenge data."""

def __init__(
self,
raw_images: list[str],
gt_images: list[str],
padding: int = 5,
crop_size: int = 20,
chunk: bool = False,
clip_length: int = 10,
mode: str = "train",
augmentations: Optional[dict] = None,
gt_list: Optional[str] = None,
):
"""Initialize CellTrackingDataset.
Args:
raw_images: paths to raw microscopy images
gt_images: paths to gt label images
padding: amount of padding around object crops
crop_size: the size of the object crops
chunk: whether or not to chunk the dataset into batches
clip_length: the number of frames in each chunk
mode: `train` or `val`. Determines whether this dataset is used for
training or validation. Currently doesn't affect dataset logic
augmentations: An optional dict mapping augmentations to parameters. The keys
should map directly to augmentation classes in albumentations. Example:
augs = {
'Rotate': {'limit': [-90, 90]},
'GaussianBlur': {'blur_limit': (3, 7), 'sigma_limit': 0},
'RandomContrast': {'limit': 0.2}
}
gt_list: An optional path to .txt file containing gt ids stored in cell
tracking challenge format: "track_id", "start_frame",
"end_frame", "parent_id"
"""
super().__init__(
raw_images + gt_images,
padding,
crop_size,
chunk,
clip_length,
mode,
augmentations,
gt_list,
)

self.videos = raw_images
self.labels = gt_images

if gt_list is not None:
self.gt_list = pd.read_csv(
gt_list,
delimiter=" ",
header=None,
names=["track_id", "start_frame", "end_frame", "parent_id"],
)
else:
self.gt_list = None
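
# The expected gt_list file follows the cell tracking challenge track file
# convention, one space-delimited record per track (values illustrative):
#   "1 0 49 0" -> track 1 spans frames 0-49 and has no parent (parent_id 0).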

self.frame_idx = [torch.arange(len(image)) for image in self.labels]

# Method in BaseDataset. Creates label_idx and chunked_frame_idx to be
# used in call to get_instances()
self.create_chunks()

def get_indices(self, idx):
"""Retrieves label and frame indices given batch index.
Args:
idx: the index of the batch.
"""
return self.label_idx[idx], self.chunked_frame_idx[idx]

def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict]:
"""Get an element of the dataset.
Args:
label_idx: index of the labels
frame_idx: index of the frames
Returns:
a list of dicts where each dict corresponds a frame in the chunk
and each value is a `torch.Tensor`.
Dict Elements:
{
"video_id": The video being passed through the transformer,
"img_shape": the shape of each frame,
"frame_id": the specific frame in the entire video being used,
"num_detected": The number of objects in the frame,
"gt_track_ids": The ground truth labels,
"bboxes": The bounding boxes of each object,
"crops": The raw pixel crops,
"features": The feature vectors for each crop outputed by the
CNN encoder,
"pred_track_ids": The predicted trajectory labels from the
tracker,
"asso_output": the association matrix preprocessing,
"matches": the true positives from the model,
"traj_score": the association matrix post processing,
}
"""
image = self.videos[label_idx]
gt = self.labels[label_idx]

instances = []

for i in frame_idx:
gt_track_ids, centroids, bboxes, crops = [], [], [], []

i = int(i)

img = image[i]
gt_sec = gt[i]

img = np.array(Image.open(img))
gt_sec = np.array(Image.open(gt_sec))

if img.dtype == np.uint16:
img = ((img - img.min()) * (1 / (img.max() - img.min()) * 255)).astype(
np.uint8
)
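# min-max normalization: img.min() maps to 0 and img.max() maps to 255,
# bringing 16-bit microscopy frames into uint8 range for cropping/augments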

if self.gt_list is None:
unique_instances = np.unique(gt_sec)
else:
unique_instances = self.gt_list["track_id"].unique()

for instance in unique_instances:
# not all instances are in the frame, and they also label the
# background instance as zero
if instance in gt_sec and instance != 0:
mask = gt_sec == instance
center_of_mass = measurements.center_of_mass(mask)

# scipy returns yx
x, y = center_of_mass[::-1]

bbox = data_utils.pad_bbox(
data_utils.get_bbox([int(x), int(y)], self.crop_size),
padding=self.padding,
)

gt_track_ids.append(int(instance))
centroids.append([x, y])
bboxes.append(bbox)

# albumentations wants (spatial, channels), ensure correct dims
if self.augmentations is not None:
for transform in self.augmentations:
# for occlusion simulation; randomize the fill value (can be removed if unwanted)
if isinstance(transform, A.CoarseDropout):
transform.fill_value = random.randint(0, 255)

augmented = self.augmentations(
image=img,
keypoints=np.vstack(centroids),
)

img, centroids = augmented["image"], augmented["keypoints"]

img = torch.Tensor(img).unsqueeze(0)

for bbox in bboxes:
crop = data_utils.crop_bbox(img, bbox)
crops.append(crop)

instances.append(
{
"video_id": torch.tensor([label_idx]),
"img_shape": torch.tensor([img.shape]),
"frame_id": torch.tensor([i]),
"num_detected": torch.tensor([len(bboxes)]),
"gt_track_ids": torch.tensor(gt_track_ids).type(torch.int64),
"bboxes": torch.stack(bboxes),
"crops": torch.stack(crops),
"features": torch.tensor([]),
"pred_track_ids": torch.tensor([-1 for _ in range(len(bboxes))]),
"asso_output": torch.tensor([]),
"matches": torch.tensor([]),
"traj_score": torch.tensor([]),
}
)

return instances
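
A usage sketch for the new dataset (hypothetical nested file lists and a DataLoader wired with the base class's no_batching_fn collate; paths are illustrative):

    from biogtr.datasets.cell_tracking_dataset import CellTrackingDataset
    from torch.utils.data import DataLoader

    dataset = CellTrackingDataset(
        raw_images=[["t000.tif", "t001.tif"]],  # frames for one video
        gt_images=[["man_track000.tif", "man_track001.tif"]],
        chunk=True,
        clip_length=2,
    )
    loader = DataLoader(dataset, batch_size=1, collate_fn=dataset.no_batching_fn)
    frames = next(iter(loader))  # a list of per-frame instance dicts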