Skip to content

Commit

Permalink
Correctly handle missing poses (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
sheridana authored Jun 15, 2023
1 parent da79d10 commit d2aae82
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 9 deletions.
8 changes: 7 additions & 1 deletion biogtr/datasets/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,9 +458,15 @@ def view_training_batch(
for i in range(num_frames):
for j, data in enumerate(instances[i]["crops"]):
try:
ax = axes[j] if num_frames == 1 else axes[i, j]
ax = (
axes[j]
if num_frames == 1
else (axes[i] if num_crops == 1 else axes[i, j])
)

ax.imshow(data.T)
ax.axis("off")

except Exception as e:
print(e)
pass
Expand Down
24 changes: 17 additions & 7 deletions biogtr/datasets/sleap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,15 @@ def __getitem__(self, idx) -> list[dict]:
anchors = [
video.skeletons[0].node_names.index(anchor_name)
for anchor_name in self.anchor_names[label_idx]
] # get the nodes from the skeleton
]

video_name = self.video_files[label_idx]

vid_reader = imageio.get_reader(video_name, "ffmpeg")

img = vid_reader.get_data(0)
crop_shape = (img.shape[-1], *(self.crop_size + 2 * self.padding,) * 2)

instances = []

for i in frame_idx:
Expand Down Expand Up @@ -179,10 +182,13 @@ def __getitem__(self, idx) -> list[dict]:
if isinstance(transform, A.CoarseDropout):
transform.fill_value = random.randint(0, 255)

augmented = self.augmentations(
image=img,
keypoints=np.vstack([list(s.values()) for s in shown_poses]),
)
if shown_poses:
keypoints = np.vstack([list(s.values()) for s in shown_poses])

else:
keypoints = []

augmented = self.augmentations(image=img, keypoints=keypoints)

img, aug_poses = augmented["image"], augmented["keypoints"]

Expand Down Expand Up @@ -217,15 +223,19 @@ def __getitem__(self, idx) -> list[dict]:
bboxes.append(bbox)
crops.append(crop)

stacked_crops = (
torch.stack(crops) if crops else torch.empty((0, *crop_shape))
)

instances.append(
{
"video_id": torch.tensor([label_idx]),
"img_shape": torch.tensor([img.shape]),
"frame_id": torch.tensor([i]),
"num_detected": torch.tensor([len(bboxes)]),
"gt_track_ids": torch.tensor(gt_track_ids),
"bboxes": torch.stack(bboxes),
"crops": torch.stack(crops),
"bboxes": torch.stack(bboxes) if bboxes else torch.empty((0, 4)),
"crops": stacked_crops,
"features": torch.tensor([]),
"pred_track_ids": torch.tensor([-1 for _ in range(len(bboxes))]),
"asso_output": torch.tensor([]),
Expand Down
2 changes: 1 addition & 1 deletion biogtr/inference/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def inference(
save_dict["Frame"] = frame_ids
save_dict["X"] = X
save_dict["Y"] = Y
save_dict["Predicted Track ID"] = pred_track_ids
save_dict["Pred_track_id"] = pred_track_ids
save_df = pd.DataFrame(save_dict)
saved.append(save_df)

Expand Down

0 comments on commit d2aae82

Please sign in to comment.