Skip to content

Commit

Permalink
fix errors from merge
Browse files Browse the repository at this point in the history
  • Loading branch information
aaprasad committed May 7, 2024
1 parent 390a564 commit 9edc517
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 40 deletions.
2 changes: 1 addition & 1 deletion biogtr/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ def to(self, map_location: str):
return self

def to_slp(
self, track_lookup: dict[int : sio.Track] = {}
self, track_lookup: dict[int, sio.Track] = {}
) -> tuple[sio.LabeledFrame, dict[int, sio.Track]]:
"""Convert Frame to sleap_io.LabeledFrame object.
Expand Down
41 changes: 2 additions & 39 deletions biogtr/datasets/sleap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def __init__(
padding: int = 5,
crop_size: int = 128,
anchors: Union[int, list[str], str] = "",
anchors: Union[int, list[str], str] = "",
chunk: bool = True,
clip_length: int = 500,
mode: str = "train",
Expand Down Expand Up @@ -99,7 +98,6 @@ def __init__(
) or self.anchors == 0:
raise ValueError(f"Must provide at least one anchor but got {self.anchors}")


if isinstance(anchors, int):
self.anchors = anchors
elif isinstance(anchors, str):
Expand Down Expand Up @@ -296,11 +294,6 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict
else:
anchors = self.anchors

for anchor in anchors:
if anchor == "midpoint" or anchor == "centroid":
centroid = np.nanmean(np.array(list(pose.values())), axis=0)
anchors = self.anchors

for anchor in anchors:
if anchor == "midpoint" or anchor == "centroid":
centroid = np.nanmean(np.array(list(pose.values())), axis=0)
Expand All @@ -313,6 +306,7 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict
elif anchor not in pose and len(anchors) == 1:
anchor = "midpoint"
centroid = np.nanmean(np.array(list(pose.values())), axis=0)

elif anchor in pose:
centroid = np.array(pose[anchor])
if np.isnan(centroid).any():
Expand All @@ -325,18 +319,9 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict
else:
centroid = np.array([np.nan, np.nan])

else:
centroid = np.array([np.nan, np.nan])

if np.isnan(centroid).all():
bbox = torch.tensor([np.nan, np.nan, np.nan, np.nan])
else:
bbox = data_utils.pad_bbox(
data_utils.get_bbox(centroid, self.crop_size),
padding=self.padding,
)
if np.isnan(centroid).all():
bbox = torch.tensor([np.nan, np.nan, np.nan, np.nan])

else:
bbox = data_utils.pad_bbox(
data_utils.get_bbox(centroid, self.crop_size),
Expand All @@ -352,25 +337,6 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict
)
else:
crop = data_utils.crop_bbox(img, bbox)
if bbox.isnan().all():
crop = torch.zeros(
c,
self.crop_size + 2 * self.padding,
self.crop_size + 2 * self.padding,
dtype=img.dtype,
)
else:
crop = data_utils.crop_bbox(img, bbox)

crops.append(crop)
centroids[anchor] = centroid
boxes.append(bbox)

if len(crops) > 0:
crops = torch.concat(crops, dim=0)

if len(boxes) > 0:
boxes = torch.stack(boxes, dim=0)

crops.append(crop)
centroids[anchor] = centroid
Expand All @@ -388,9 +354,6 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict
crop=crops,
centroid=centroids,
bbox=boxes,
crop=crops,
centroid=centroids,
bbox=boxes,
skeleton=skeleton,
pose=poses[j],
point_scores=point_scores[j],
Expand Down

0 comments on commit 9edc517

Please sign in to comment.