Skip to content

Commit

Permalink
Update python version to python 3.11 (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaprasad authored Jul 22, 2024
1 parent 417b0dc commit 6ea1060
Show file tree
Hide file tree
Showing 30 changed files with 140 additions and 150 deletions.
15 changes: 7 additions & 8 deletions dreem/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from dreem.datasets import data_utils
from dreem.io import Frame
from torch.utils.data import Dataset
from typing import List, Union
import numpy as np
import torch

Expand All @@ -20,10 +19,10 @@ def __init__(
chunk: bool,
clip_length: int,
mode: str,
augmentations: dict = None,
n_chunks: Union[int, float] = 1.0,
seed: int = None,
gt_list: str = None,
augmentations: dict | None = None,
n_chunks: int | float = 1.0,
seed: int | None = None,
gt_list: str | None = None,
):
"""Initialize Dataset.
Expand Down Expand Up @@ -122,7 +121,7 @@ def __len__(self) -> int:
"""
return len(self.chunked_frame_idx)

def no_batching_fn(self, batch: list[Frame]) -> List[Frame]:
def no_batching_fn(self, batch: list[Frame]) -> list[Frame]:
"""Collate function used to overwrite dataloader batching function.
Args:
Expand All @@ -133,7 +132,7 @@ def no_batching_fn(self, batch: list[Frame]) -> List[Frame]:
"""
return batch

def __getitem__(self, idx: int) -> List[Frame]:
def __getitem__(self, idx: int) -> list[Frame]:
"""Get an element of the dataset.
Args:
Expand All @@ -160,7 +159,7 @@ def get_indices(self, idx: int):
"""
raise NotImplementedError("Must be implemented in subclass")

def get_instances(self, label_idx: List[int], frame_idx: List[int]):
def get_instances(self, label_idx: list[int], frame_idx: list[int]):
"""Build chunk of frames.
This method should be implemented in any subclass of the BaseDataset.
Expand Down
11 changes: 5 additions & 6 deletions dreem/datasets/cell_tracking_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from dreem.datasets import data_utils, BaseDataset
from dreem.io import Frame, Instance
from scipy.ndimage import measurements
from typing import List, Optional, Union
import albumentations as A
import numpy as np
import pandas as pd
Expand All @@ -24,10 +23,10 @@ def __init__(
chunk: bool = False,
clip_length: int = 10,
mode: str = "train",
augmentations: Optional[dict] = None,
n_chunks: Union[int, float] = 1.0,
seed: int = None,
gt_list: list[str] = None,
augmentations: dict | None = None,
n_chunks: int | float = 1.0,
seed: int | None = None,
gt_list: list[str] | None = None,
):
"""Initialize CellTrackingDataset.
Expand Down Expand Up @@ -116,7 +115,7 @@ def get_indices(self, idx: int) -> tuple:
"""
return self.label_idx[idx], self.chunked_frame_idx[idx]

def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> List[Frame]:
def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Frame]:
"""Get an element of the dataset.
Args:
Expand Down
7 changes: 3 additions & 4 deletions dreem/datasets/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from PIL import Image
from numpy.typing import ArrayLike
from torchvision.transforms import functional as tvf
from typing import List, Dict, Union
from xml.etree import cElementTree as et
import albumentations as A
import math
Expand Down Expand Up @@ -53,7 +52,7 @@ def crop_bbox(img: torch.Tensor, bbox: ArrayLike) -> torch.Tensor:
return crop


def get_bbox(center: ArrayLike, size: Union[int, tuple[int]]) -> torch.Tensor:
def get_bbox(center: ArrayLike, size: int | tuple[int]) -> torch.Tensor:
"""Get a square bbox around a centroid coordinates.
Args:
Expand Down Expand Up @@ -109,7 +108,7 @@ def centroid_bbox(points: ArrayLike, anchors: list, crop_size: int) -> torch.Ten
return bbox


def pose_bbox(points: np.ndarray, bbox_size: Union[tuple[int], int]) -> torch.Tensor:
def pose_bbox(points: np.ndarray, bbox_size: tuple[int] | int) -> torch.Tensor:
"""Calculate bbox around instance pose.
Args:
Expand Down Expand Up @@ -496,7 +495,7 @@ def get_max_padding(height: int, width: int) -> tuple:


def view_training_batch(
instances: List[Dict[str, List[np.ndarray]]], num_frames: int = 1, cmap=None
instances: list[dict[str, list[np.ndarray]]], num_frames: int = 1, cmap=None
) -> None:
"""Display a grid of images from a batch of training instances.
Expand Down
3 changes: 1 addition & 2 deletions dreem/datasets/eval_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from torch.utils.data import Dataset
from dreem.io import Instance, Frame
from typing import List


class EvalDataset(Dataset):
Expand All @@ -26,7 +25,7 @@ def __len__(self) -> int:
"""
return len(self.gt_dataset)

def __getitem__(self, idx: int) -> List[Frame]:
def __getitem__(self, idx: int) -> list[Frame]:
"""Get an element of the dataset.
Args:
Expand Down
7 changes: 3 additions & 4 deletions dreem/datasets/microscopy_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from PIL import Image
from dreem.datasets import data_utils, BaseDataset
from dreem.io import Instance, Frame
from typing import Union
import albumentations as A
import numpy as np
import random
Expand All @@ -23,9 +22,9 @@ def __init__(
chunk: bool = False,
clip_length: int = 10,
mode: str = "Train",
augmentations: dict = None,
n_chunks: Union[int, float] = 1.0,
seed: int = None,
augmentations: dict | None = None,
n_chunks: int | float = 1.0,
seed: int | None = None,
):
"""Initialize MicroscopyDataset.
Expand Down
11 changes: 5 additions & 6 deletions dreem/datasets/sleap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from dreem.io import Instance, Frame
from dreem.datasets import data_utils, BaseDataset
from torchvision.transforms import functional as tvf
from typing import List, Union


class SleapDataset(BaseDataset):
Expand All @@ -22,14 +21,14 @@ def __init__(
video_files: list[str],
padding: int = 5,
crop_size: int = 128,
anchors: Union[int, list[str], str] = "",
anchors: int | list[str] | str = "",
chunk: bool = True,
clip_length: int = 500,
mode: str = "train",
handle_missing: str = "centroid",
augmentations: dict = None,
n_chunks: Union[int, float] = 1.0,
seed: int = None,
augmentations: dict | None = None,
n_chunks: int | float = 1.0,
seed: int | None = None,
verbose: bool = False,
):
"""Initialize SleapDataset.
Expand Down Expand Up @@ -124,7 +123,7 @@ def get_indices(self, idx: int) -> tuple:
"""
return self.label_idx[idx], self.chunked_frame_idx[idx]

def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[Frame]:
def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Frame]:
"""Get an element of the dataset.
Args:
Expand Down
19 changes: 6 additions & 13 deletions dreem/datasets/tracking_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from dreem.datasets.sleap_dataset import SleapDataset
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from typing import Union
import torch


Expand All @@ -22,18 +21,12 @@ class TrackingDataset(LightningDataModule):

def __init__(
self,
train_ds: Union[
SleapDataset, MicroscopyDataset, CellTrackingDataset, None
] = None,
train_dl: DataLoader = None,
val_ds: Union[
SleapDataset, MicroscopyDataset, CellTrackingDataset, None
] = None,
val_dl: DataLoader = None,
test_ds: Union[
SleapDataset, MicroscopyDataset, CellTrackingDataset, None
] = None,
test_dl: DataLoader = None,
train_ds: SleapDataset | MicroscopyDataset | CellTrackingDataset | None = None,
train_dl: DataLoader | None = None,
val_ds: SleapDataset | MicroscopyDataset | CellTrackingDataset | None = None,
val_dl: DataLoader | None = None,
test_ds: SleapDataset | MicroscopyDataset | CellTrackingDataset | None = None,
test_dl: DataLoader | None = None,
):
"""Initialize tracking dataset.
Expand Down
14 changes: 7 additions & 7 deletions dreem/inference/boxes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Module containing Boxes class."""

from typing import List, Tuple, Union
import torch
from typing import Self


class Boxes:
Expand Down Expand Up @@ -37,15 +37,15 @@ def __init__(self, tensor: torch.Tensor):

self.tensor = tensor

def clone(self) -> "Boxes":
def clone(self) -> Self:
"""Clone the Boxes.
Returns:
Boxes
"""
return Boxes(self.tensor.clone())

def to(self, device: torch.device) -> "Boxes":
def to(self, device: torch.device) -> Self:
"""Load boxes to gpu/cpu.
Args:
Expand All @@ -66,7 +66,7 @@ def area(self) -> torch.Tensor:
area = (box[:, :, 2] - box[:, :, 0]) * (box[:, :, 3] - box[:, :, 1])
return area

def clip(self, box_size: Tuple[int, int]) -> None:
def clip(self, box_size: list[int, int]) -> None:
"""Clip (in place) the boxes.
Limits x coordinates to the range [0, width]
Expand Down Expand Up @@ -102,7 +102,7 @@ def nonempty(self, threshold: float = 0.0) -> torch.Tensor:
keep = (widths > threshold) & (heights > threshold)
return keep

def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Boxes":
def __getitem__(self, item: int | slice | torch.BoolTensor) -> "Boxes":
"""Getter for boxes.
Args:
Expand Down Expand Up @@ -146,7 +146,7 @@ def __repr__(self) -> str:
return "Boxes(" + str(self.tensor) + ")"

def inside_box(
self, box_size: Tuple[int, int], boundary_threshold: int = 0
self, box_size: tuple[int, int], boundary_threshold: int = 0
) -> torch.Tensor:
"""Check if box is inside reference box.
Expand Down Expand Up @@ -181,7 +181,7 @@ def scale(self, scale_x: float, scale_y: float) -> None:
self.tensor[:, :, 1::2] *= scale_y

@classmethod
def cat(cls, boxes_list: List["Boxes"]) -> "Boxes":
def cat(cls, boxes_list: list["Boxes"]) -> "Boxes":
"""Concatenates a list of Boxes into a single Boxes.
Arguments:
Expand Down
6 changes: 3 additions & 3 deletions dreem/inference/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import motmetrics as mm
import torch
from typing import Union, Iterable
from typing import Iterable
import pandas as pd

# from dreem.inference.post_processing import _pairwise_iou
Expand Down Expand Up @@ -237,9 +237,9 @@ def get_track_evals(data: dict, metrics: dict) -> dict:

def get_pymotmetrics(
data: dict,
metrics: Union[str, tuple] = "all",
metrics: str | tuple = "all",
key: str = "tracker_ids",
save: str = None,
save: str | None = None,
) -> pd.DataFrame:
"""Given data and a key, evaluate the predictions.
Expand Down
14 changes: 7 additions & 7 deletions dreem/inference/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
def weight_decay_time(
asso_output: torch.Tensor,
decay_time: float = 0,
reid_features: torch.Tensor = None,
T: int = None,
k: int = None,
reid_features: torch.Tensor | None = None,
T: int | None = None,
k: int | None = None,
) -> torch.Tensor:
"""Weight association matrix by time.
Expand Down Expand Up @@ -90,7 +90,7 @@ def _pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:


def weight_iou(
asso_output: torch.Tensor, method: str = None, last_ious: torch.Tensor = None
asso_output: torch.Tensor, method: str | None = None, last_ious: torch.Tensor = None
) -> torch.Tensor:
"""Weight the association matrix by the IOU between object bboxes across frames.
Expand Down Expand Up @@ -123,9 +123,9 @@ def weight_iou(
def filter_max_center_dist(
asso_output: torch.Tensor,
max_center_dist: float = 0,
k_boxes: torch.Tensor = None,
nonk_boxes: torch.Tensor = None,
id_inds: torch.Tensor = None,
k_boxes: torch.Tensor | None = None,
nonk_boxes: torch.Tensor | None = None,
id_inds: torch.Tensor | None = None,
) -> torch.Tensor:
"""Filter trajectory score by distances between objects across frames.
Expand Down
2 changes: 1 addition & 1 deletion dreem/inference/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


def export_trajectories(
frames_pred: list["dreem.io.Frame"], save_path: str = None
frames_pred: list["dreem.io.Frame"], save_path: str | None = None
) -> pd.DataFrame:
"""Convert trajectories to data frame and save as .csv.
Expand Down
7 changes: 5 additions & 2 deletions dreem/inference/track_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dreem.io import Frame
from collections import deque
import numpy as np
from torch import device


class TrackQueue:
Expand Down Expand Up @@ -154,7 +155,7 @@ def verbose(self, verbose: bool) -> None:
"""
self._verbose = verbose

def end_tracks(self, track_id: int = None) -> bool:
def end_tracks(self, track_id: int | None = None) -> bool:
"""Terminate tracks and removing them from the queue.
Args:
Expand Down Expand Up @@ -222,7 +223,9 @@ def add_frame(self, frame: Frame) -> None:
) # should this be done in the tracker or the queue?

def collate_tracks(
self, track_ids: list[int] = None, device: str = None
self,
track_ids: list[int] | None = None,
device: str | device | None = None,
) -> list[Frame]:
"""Merge queues into a single list of Frames containing corresponding instances.
Expand Down
6 changes: 3 additions & 3 deletions dreem/inference/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ def __init__(
use_vis_feats: bool = True,
overlap_thresh: float = 0.01,
mult_thresh: bool = True,
decay_time: float = None,
iou: str = None,
max_center_dist: float = None,
decay_time: float | None = None,
iou: str | None = None,
max_center_dist: float | None = None,
persistent_tracking: bool = False,
max_gap: int = inf,
max_tracks: int = inf,
Expand Down
Loading

0 comments on commit 6ea1060

Please sign in to comment.