Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat : added support for keypoints dataset #1477

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
300 changes: 300 additions & 0 deletions supervision/dataset/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from supervision.dataset.formats.yolo import (
load_yolo_annotations,
load_yolo_keypoint_annotations,
save_data_yaml,
save_yolo_annotations,
)
Expand All @@ -32,6 +33,7 @@
train_test_split,
)
from supervision.detection.core import Detections
from supervision.keypoint.core import KeyPoints
from supervision.utils.internal import deprecated, warn_deprecated
from supervision.utils.iterables import find_duplicates

Expand Down Expand Up @@ -917,3 +919,301 @@ def from_folder_structure(cls, root_directory_path: str) -> ClassificationDatase
images=image_paths,
annotations=annotations,
)


class KeyPointDataset(BaseDataset):
    """
    Contains information about a keypoint dataset. Handles lazy image loading
    and annotation retrieval, dataset splitting, merging, and conversion
    to and from the YOLO format.

    Attributes:
        classes (List[str]): List containing dataset class names.
        image_paths (List[str]): Image paths of the dataset, de-duplicated
            while keeping the order in which they were passed. Images are
            read from disk only when an item is accessed, which is much more
            memory-efficient than holding them all in memory.
        annotations (Dict[str, KeyPoints]): Dictionary mapping
            image path to annotations. The dictionary keys match
            the entries of `image_paths`.
    """

    def __init__(
        self,
        classes: List[str],
        images: List[str],
        annotations: Dict[str, KeyPoints],
    ) -> None:
        """
        Args:
            classes (List[str]): List containing dataset class names.
            images (List[str]): List of image paths. Duplicated paths are
                collapsed into a single entry.
            annotations (Dict[str, KeyPoints]): Dictionary mapping each
                image path to its annotations.

        Raises:
            ValueError: If the set of image paths differs from the set of
                keys of `annotations`.
        """
        self.classes = classes

        if set(images) != set(annotations):
            raise ValueError(
                "The keys of the images and annotations dictionaries must match."
            )
        self.annotations = annotations

        # Eliminate duplicates while preserving order
        self.image_paths = list(dict.fromkeys(images))

    def _get_image(self, image_path: str) -> np.ndarray:
        """
        Read the image stored at `image_path` with `cv2.imread`.

        Assumes that the image is in the dataset; no check is performed here.
        NOTE(review): `cv2.imread` reports an unreadable file by returning
        `None` rather than raising - confirm callers can cope with that.
        """
        return cv2.imread(image_path)

    def __len__(self) -> int:
        """Return the number of (unique) images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, i: int) -> Tuple[str, np.ndarray, KeyPoints]:
        """
        Returns:
            Tuple[str, np.ndarray, KeyPoints]: The image path, image data,
                and its corresponding annotation at index i.
        """
        image_path = self.image_paths[i]
        image = self._get_image(image_path)
        annotation = self.annotations[image_path]
        return image_path, image, annotation

    def __iter__(self) -> Iterator[Tuple[str, np.ndarray, KeyPoints]]:
        """
        Iterate over the images and annotations in the dataset.

        Yields:
            Iterator[Tuple[str, np.ndarray, KeyPoints]]:
                An iterator that yields tuples containing the image path,
                the image data, and its corresponding KeyPoint annotation.
        """
        # Go through __getitem__ so that indexing and iteration always agree.
        for index in range(len(self)):
            yield self[index]

    def __eq__(self, other: object) -> bool:
        """
        Two datasets are equal when they hold the same set of class names
        (order-insensitive), the same image paths in the same order, and
        equal annotations.
        """
        if not isinstance(other, KeyPointDataset):
            return False

        return (
            set(self.classes) == set(other.classes)
            and self.image_paths == other.image_paths
            and self.annotations == other.annotations
        )

    def split(
        self,
        split_ratio: float = 0.8,
        random_state: Optional[int] = None,
        shuffle: bool = True,
    ) -> Tuple[KeyPointDataset, KeyPointDataset]:
        """
        Splits the dataset into two parts (training and testing)
        using the provided split_ratio.

        Args:
            split_ratio (float, optional): The ratio of the training
                set to the entire dataset.
            random_state (int, optional): The seed for the random number generator.
                This is used for reproducibility.
            shuffle (bool, optional): Whether to shuffle the data before splitting.

        Returns:
            Tuple[KeyPointDataset, KeyPointDataset]: A tuple containing
                the training and testing datasets.

        Examples:
            ```python
            import supervision as sv

            ds = sv.KeyPointDataset(...)
            train_ds, test_ds = ds.split(split_ratio=0.7, random_state=42, shuffle=True)
            len(train_ds), len(test_ds)
            # (700, 300)
            ```
        """

        train_paths, test_paths = train_test_split(
            data=self.image_paths,
            train_ratio=split_ratio,
            random_state=random_state,
            shuffle=shuffle,
        )

        # Both halves share the class list and the annotation objects of the
        # source dataset; only the path -> annotation mappings are rebuilt.
        train_annotations = {path: self.annotations[path] for path in train_paths}
        test_annotations = {path: self.annotations[path] for path in test_paths}

        train_dataset = KeyPointDataset(
            classes=self.classes,
            images=train_paths,
            annotations=train_annotations,
        )
        test_dataset = KeyPointDataset(
            classes=self.classes,
            images=test_paths,
            annotations=test_annotations,
        )
        return train_dataset, test_dataset

    @classmethod
    def merge(cls, dataset_list: List[KeyPointDataset]) -> KeyPointDataset:
        """
        Merge a list of `KeyPointDataset` objects into a single
        `KeyPointDataset` object.

        This method takes a list of `KeyPointDataset` objects and combines
        their respective fields (`classes`, `image_paths`,
        `annotations`) into a single `KeyPointDataset` object.

        Args:
            dataset_list (List[KeyPointDataset]): A list of `KeyPointDataset`
                objects to merge.

        Returns:
            (KeyPointDataset): A single `KeyPointDataset` object containing
                the merged data from the input list.

        Raises:
            ValueError: If the same image path appears in more than one of
                the datasets.

        Examples:
            ```python
            import supervision as sv

            ds_1 = sv.KeyPointDataset(...)
            len(ds_1)
            # 100
            ds_1.classes
            # ['dog', 'person']

            ds_2 = sv.KeyPointDataset(...)
            len(ds_2)
            # 200
            ds_2.classes
            # ['cat']

            ds_merged = sv.KeyPointDataset.merge([ds_1, ds_2])
            len(ds_merged)
            # 300
            ds_merged.classes
            # ['cat', 'dog', 'person']
            ```
        """

        image_paths = list(
            chain.from_iterable(dataset.image_paths for dataset in dataset_list)
        )
        image_paths_unique = list(dict.fromkeys(image_paths))
        if len(image_paths) != len(image_paths_unique):
            duplicates = find_duplicates(image_paths)
            raise ValueError(
                f"Image paths {duplicates} are not unique across datasets."
            )
        image_paths = image_paths_unique

        classes = merge_class_lists(
            class_lists=[dataset.classes for dataset in dataset_list]
        )

        # First gather every annotation, then rewrite the class ids of each
        # dataset's annotations so they index into the merged class list.
        annotations: Dict[str, KeyPoints] = {}
        for dataset in dataset_list:
            annotations.update(dataset.annotations)
        for dataset in dataset_list:
            class_index_mapping = build_class_index_mapping(
                source_classes=dataset.classes, target_classes=classes
            )
            for image_path in dataset.image_paths:
                # NOTE(review): this helper is named for detections; confirm
                # that it handles `KeyPoints` objects as intended.
                annotations[image_path] = map_detections_class_id(
                    source_to_target_mapping=class_index_mapping,
                    detections=annotations[image_path],
                )

        return cls(
            classes=classes,
            images=image_paths,
            annotations=annotations,
        )

    @classmethod
    def from_yolo(
        cls,
        images_directory_path: str,
        annotations_directory_path: str,
        data_yaml_path: str,
        force_masks: bool = False,
    ) -> KeyPointDataset:
        """
        Creates a KeyPointDataset instance from YOLO formatted data.

        Args:
            images_directory_path (str): The path to the
                directory containing the images.
            annotations_directory_path (str): The path to the directory
                containing the YOLO annotation files.
            data_yaml_path (str): The path to the data
                YAML file containing class information.
            force_masks (bool, optional): Not used by this method; it is
                kept in the signature so existing callers keep working.

        Returns:
            KeyPointDataset: A KeyPointDataset instance
                containing the loaded images and annotations.

        Examples:
            ```python
            import roboflow
            from roboflow import Roboflow
            import supervision as sv

            roboflow.login()
            rf = Roboflow()

            project = rf.workspace(WORKSPACE_ID).project(PROJECT_ID)
            dataset = project.version(PROJECT_VERSION).download("yolov8")

            ds = sv.KeyPointDataset.from_yolo(
                images_directory_path=f"{dataset.location}/train/images",
                annotations_directory_path=f"{dataset.location}/train/labels",
                data_yaml_path=f"{dataset.location}/data.yaml"
            )

            ds.classes
            # ['dog', 'person']
            ```
        """
        classes, image_paths, annotations = load_yolo_keypoint_annotations(
            images_directory_path=images_directory_path,
            annotations_directory_path=annotations_directory_path,
            data_yaml_path=data_yaml_path,
        )
        # Build through `cls` (as `merge` does) so that subclasses get an
        # instance of their own type from this alternate constructor.
        return cls(classes=classes, images=image_paths, annotations=annotations)

    def as_yolo(
        self,
        images_directory_path: Optional[str] = None,
        annotations_directory_path: Optional[str] = None,
        data_yaml_path: Optional[str] = None,
    ) -> None:
        """
        Exports the dataset to YOLO format. This method saves the
        images and their corresponding annotations in YOLO format.

        Args:
            images_directory_path (Optional[str]): The path to the
                directory where the images should be saved.
                If not provided, images will not be saved.
            annotations_directory_path (Optional[str]): The path to the
                directory where the annotations in
                YOLO format should be saved. If not provided,
                annotations will not be saved.
            data_yaml_path (Optional[str]): The path where the data.yaml
                file should be saved.
                If not provided, the file will not be saved.
        """
        if images_directory_path is not None:
            save_dataset_images(
                dataset=self, images_directory_path=images_directory_path
            )
        if annotations_directory_path is not None:
            # NOTE(review): loading uses a keypoint-specific helper, while
            # saving reuses the generic `save_yolo_annotations`; confirm that
            # it serializes `KeyPoints` annotations correctly.
            save_yolo_annotations(
                dataset=self, annotations_directory_path=annotations_directory_path
            )
        if data_yaml_path is not None:
            save_data_yaml(data_yaml_path=data_yaml_path, classes=self.classes)
Loading