From 9edc517c19db8904c30277183baf67d0be3b7da4 Mon Sep 17 00:00:00 2001
From: aaprasad
Date: Tue, 7 May 2024 12:58:46 -0700
Subject: [PATCH] fix errors from merge

---
 biogtr/data_structures.py        |  2 +-
 biogtr/datasets/sleap_dataset.py | 41 ++------------------------------
 2 files changed, 3 insertions(+), 40 deletions(-)

diff --git a/biogtr/data_structures.py b/biogtr/data_structures.py
index 2f98593c..113928f5 100644
--- a/biogtr/data_structures.py
+++ b/biogtr/data_structures.py
@@ -635,7 +635,7 @@ def to(self, map_location: str):
         return self
 
     def to_slp(
-        self, track_lookup: dict[int : sio.Track] = {}
+        self, track_lookup: dict[int, sio.Track] = {}
     ) -> tuple[sio.LabeledFrame, dict[int, sio.Track]]:
         """Convert Frame to sleap_io.LabeledFrame object.
 
diff --git a/biogtr/datasets/sleap_dataset.py b/biogtr/datasets/sleap_dataset.py
index a7e953c9..b6649d1c 100644
--- a/biogtr/datasets/sleap_dataset.py
+++ b/biogtr/datasets/sleap_dataset.py
@@ -24,7 +24,6 @@ def __init__(
         padding: int = 5,
         crop_size: int = 128,
         anchors: Union[int, list[str], str] = "",
-        anchors: Union[int, list[str], str] = "",
         chunk: bool = True,
         clip_length: int = 500,
         mode: str = "train",
@@ -99,7 +98,6 @@ def __init__(
         ) or self.anchors == 0:
             raise ValueError(f"Must provide at least one anchor but got {self.anchors}")
 
-
         if isinstance(anchors, int):
             self.anchors = anchors
         elif isinstance(anchors, str):
@@ -296,11 +294,6 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict
                 else:
                     anchors = self.anchors
 
-                for anchor in anchors:
-                    if anchor == "midpoint" or anchor == "centroid":
-                        centroid = np.nanmean(np.array(list(pose.values())), axis=0)
-                    anchors = self.anchors
-
                 for anchor in anchors:
                     if anchor == "midpoint" or anchor == "centroid":
                         centroid = np.nanmean(np.array(list(pose.values())), axis=0)
@@ -313,6 +306,7 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict
                     elif anchor not in pose and len(anchors) == 1:
                         anchor = "midpoint"
                         centroid = np.nanmean(np.array(list(pose.values())), axis=0)
+
                     elif anchor in pose:
                         centroid = np.array(pose[anchor])
                         if np.isnan(centroid).any():
@@ -325,18 +319,9 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict
                     else:
                         centroid = np.array([np.nan, np.nan])
-                    else:
-                        centroid = np.array([np.nan, np.nan])
-
-                    if np.isnan(centroid).all():
-                        bbox = torch.tensor([np.nan, np.nan, np.nan, np.nan])
-                    else:
-                        bbox = data_utils.pad_bbox(
-                            data_utils.get_bbox(centroid, self.crop_size),
-                            padding=self.padding,
-                        )
                     if np.isnan(centroid).all():
                         bbox = torch.tensor([np.nan, np.nan, np.nan, np.nan])
+                    else:
                         bbox = data_utils.pad_bbox(
                             data_utils.get_bbox(centroid, self.crop_size),
                             padding=self.padding,
                         )
@@ -352,25 +337,6 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict
                         )
                     else:
                         crop = data_utils.crop_bbox(img, bbox)
-                    if bbox.isnan().all():
-                        crop = torch.zeros(
-                            c,
-                            self.crop_size + 2 * self.padding,
-                            self.crop_size + 2 * self.padding,
-                            dtype=img.dtype,
-                        )
-                    else:
-                        crop = data_utils.crop_bbox(img, bbox)
-
-                    crops.append(crop)
-                    centroids[anchor] = centroid
-                    boxes.append(bbox)
-
-                if len(crops) > 0:
-                    crops = torch.concat(crops, dim=0)
-
-                if len(boxes) > 0:
-                    boxes = torch.stack(boxes, dim=0)
 
                     crops.append(crop)
                     centroids[anchor] = centroid
@@ -388,9 +354,6 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict
                         crop=crops,
                         centroid=centroids,
                         bbox=boxes,
-                        crop=crops,
-                        centroid=centroids,
-                        bbox=boxes,
                         skeleton=skeleton,
                         pose=poses[j],
                         point_scores=point_scores[j],
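
Note on the data_structures.py hunk (not part of the patch; a hedged sketch of why the annotation change matters): `dict[int : sio.Track]` parses as a subscript containing a slice object rather than a key/value type pair, so static type checkers reject it, while `dict[int, sio.Track]` is the valid parameterized mapping annotation the patch swaps in. A minimal illustration; `make_lookup` is a hypothetical helper, and only `sio.Track` comes from the patch itself:

    import sleap_io as sio

    # Broken form: `int : sio.Track` inside the brackets is a slice expression,
    # equivalent to dict[slice(int, sio.Track)] -- not a mapping annotation.
    # track_lookup: dict[int : sio.Track]

    # Fixed form: comma-separated key and value types parameterize the mapping,
    # matching the signature the patch gives to_slp().
    def make_lookup() -> dict[int, sio.Track]:
        # Map integer track ids to sleap_io Track objects, as track_lookup does.
        return {0: sio.Track(name="track_0")}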
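
A second note, on the @@ -325 hunk of sleap_dataset.py (again outside the patch): the restored `else:` matters because the merge had left both bbox assignments under `if np.isnan(centroid).all():`, so an all-NaN centroid had its NaN placeholder immediately overwritten and a valid centroid never reached `pad_bbox` at that point. A self-contained sketch of the intended control flow; `get_bbox` and `pad_bbox` here are simplified stand-ins for the `data_utils` functions the patch calls, not their real implementations:

    import numpy as np
    import torch

    def get_bbox(center: np.ndarray, size: int) -> torch.Tensor:
        # Stand-in: a size-by-size box centered on the anchor point,
        # laid out as (y1, x1, y2, x2).
        cx, cy = center
        return torch.tensor([cy - size / 2, cx - size / 2, cy + size / 2, cx + size / 2])

    def pad_bbox(bbox: torch.Tensor, padding: int) -> torch.Tensor:
        # Stand-in: grow the box by `padding` pixels on every side.
        return bbox + torch.tensor([-padding, -padding, padding, padding], dtype=bbox.dtype)

    def centroid_to_bbox(centroid: np.ndarray, crop_size: int, padding: int) -> torch.Tensor:
        # The branch structure the patch restores: a missing (all-NaN) centroid
        # yields a NaN placeholder bbox; only a valid centroid is boxed and padded.
        if np.isnan(centroid).all():
            return torch.tensor([np.nan, np.nan, np.nan, np.nan])
        else:
            return pad_bbox(get_bbox(centroid, crop_size), padding=padding)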