diff --git a/biogtr/datasets/cell_tracking_dataset.py b/biogtr/datasets/cell_tracking_dataset.py index 6e7b2076..cc872842 100644 --- a/biogtr/datasets/cell_tracking_dataset.py +++ b/biogtr/datasets/cell_tracking_dataset.py @@ -28,7 +28,7 @@ def __init__( clip_length: int = 10, mode: str = "train", augmentations: Optional[dict] = None, - gt_list: str = None, + gt_list: Optional[str] = None, ): """Initialize CellTrackingDataset. diff --git a/biogtr/datasets/microscopy_dataset.py b/biogtr/datasets/microscopy_dataset.py index 5afd582d..b93d5de8 100644 --- a/biogtr/datasets/microscopy_dataset.py +++ b/biogtr/datasets/microscopy_dataset.py @@ -4,6 +4,7 @@ from biogtr.datasets.base_dataset import BaseDataset from torch.utils.data import Dataset from torchvision.transforms import functional as tvf +from typing import Optional import albumentations as A import numpy as np import random @@ -23,7 +24,7 @@ def __init__( chunk: bool = False, clip_length: int = 10, mode: str = "Train", - augmentations: dict = None, + augmentations: Optional[dict] = None, ): """Initialize MicroscopyDataset. diff --git a/biogtr/datasets/sleap_dataset.py b/biogtr/datasets/sleap_dataset.py index c395f235..cdddd009 100644 --- a/biogtr/datasets/sleap_dataset.py +++ b/biogtr/datasets/sleap_dataset.py @@ -1,14 +1,14 @@ """Module containing logic for loading sleap datasets.""" +from biogtr.datasets import data_utils +from biogtr.datasets.base_dataset import BaseDataset +from torchvision.transforms import functional as tvf +from typing import List, Optional import albumentations as A -import torch import imageio import numpy as np -import sleap_io as sio import random -from biogtr.datasets import data_utils -from biogtr.datasets.base_dataset import BaseDataset -from torchvision.transforms import functional as tvf -from typing import List +import sleap_io as sio +import torch class SleapDataset(BaseDataset): @@ -23,7 +23,7 @@ def __init__( chunk: bool = True, clip_length: int = 500, mode: str = "train", - augmentations: dict = None, + augmentations: Optional[dict] = None, ): """Initialize SleapDataset. @@ -137,6 +137,7 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict gt_track_ids, bboxes, crops, poses, shown_poses = [], [], [], [], [] i = int(i) + print(i) lf = video[i] img = vid_reader.get_data(i)