diff --git a/biogtr/datasets/data_utils.py b/biogtr/datasets/data_utils.py
index 7008b6a..f6b6696 100644
--- a/biogtr/datasets/data_utils.py
+++ b/biogtr/datasets/data_utils.py
@@ -458,9 +458,15 @@ def view_training_batch(
     for i in range(num_frames):
         for j, data in enumerate(instances[i]["crops"]):
             try:
-                ax = axes[j] if num_frames == 1 else axes[i, j]
+                ax = (
+                    axes[j]
+                    if num_frames == 1
+                    else (axes[i] if num_crops == 1 else axes[i, j])
+                )
+
                 ax.imshow(data.T)
                 ax.axis("off")
+
             except Exception as e:
                 print(e)
                 pass
diff --git a/biogtr/datasets/sleap_dataset.py b/biogtr/datasets/sleap_dataset.py
index 0b1d631..4255ad7 100644
--- a/biogtr/datasets/sleap_dataset.py
+++ b/biogtr/datasets/sleap_dataset.py
@@ -135,12 +135,15 @@ def __getitem__(self, idx) -> list[dict]:
        anchors = [
            video.skeletons[0].node_names.index(anchor_name)
            for anchor_name in self.anchor_names[label_idx]
-       ]  # get the nodes from the skeleton
+       ]

        video_name = self.video_files[label_idx]

        vid_reader = imageio.get_reader(video_name, "ffmpeg")

+       img = vid_reader.get_data(0)
+       crop_shape = (img.shape[-1], *(self.crop_size + 2 * self.padding,) * 2)
+
        instances = []

        for i in frame_idx:
@@ -179,10 +182,13 @@ def __getitem__(self, idx) -> list[dict]:
                    if isinstance(transform, A.CoarseDropout):
                        transform.fill_value = random.randint(0, 255)

-               augmented = self.augmentations(
-                   image=img,
-                   keypoints=np.vstack([list(s.values()) for s in shown_poses]),
-               )
+               if shown_poses:
+                   keypoints = np.vstack([list(s.values()) for s in shown_poses])
+
+               else:
+                   keypoints = []
+
+               augmented = self.augmentations(image=img, keypoints=keypoints)

                img, aug_poses = augmented["image"], augmented["keypoints"]
@@ -217,6 +223,10 @@ def __getitem__(self, idx) -> list[dict]:
                bboxes.append(bbox)
                crops.append(crop)

+           stacked_crops = (
+               torch.stack(crops) if crops else torch.empty((0, *crop_shape))
+           )
+
            instances.append(
                {
                    "video_id": torch.tensor([label_idx]),
@@ -224,8 +234,8 @@ def __getitem__(self, idx) -> list[dict]:
                    "frame_id": torch.tensor([i]),
                    "num_detected": torch.tensor([len(bboxes)]),
                    "gt_track_ids": torch.tensor(gt_track_ids),
-                   "bboxes": torch.stack(bboxes),
-                   "crops": torch.stack(crops),
+                   "bboxes": torch.stack(bboxes) if bboxes else torch.empty((0, 4)),
+                   "crops": stacked_crops,
                    "features": torch.tensor([]),
                    "pred_track_ids": torch.tensor([-1 for _ in range(len(bboxes))]),
                    "asso_output": torch.tensor([]),
diff --git a/biogtr/inference/track.py b/biogtr/inference/track.py
index 9e78da7..ae6f6bb 100644
--- a/biogtr/inference/track.py
+++ b/biogtr/inference/track.py
@@ -64,7 +64,7 @@ def inference(
        save_dict["Frame"] = frame_ids
        save_dict["X"] = X
        save_dict["Y"] = Y
-       save_dict["Predicted Track ID"] = pred_track_ids
+       save_dict["Pred_track_id"] = pred_track_ids
        save_df = pd.DataFrame(save_dict)
        saved.append(save_df)
