Skip to content

Commit

Permalink
Add Yolo Format Dataset (#114)
Browse files Browse the repository at this point in the history
* add yolo

* improve doc and name

* rename

* add more explicit default extensions

* upadte extension list

* update strings

* rename annotation to label

* add lena nd iter

* add s
  • Loading branch information
Louis-Dupont authored Jun 29, 2023
1 parent 5920672 commit 7d9351f
Show file tree
Hide file tree
Showing 6 changed files with 383 additions and 0 deletions.
84 changes: 84 additions & 0 deletions documentation/datasets.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Built-in Datasets

DataGradients offer a few basic datasets which can help you load your data without needing to provide any additional code.
These datasets contain only the very basic functionalities and are not recommended for training.

## Object Detection


### Yolo Format Dataset

The Yolo format Detection Dataset supports any dataset stored in the YOLO format.

#### Expected folder structure
Any structure including at least one sub-directory for images and one for labels. They can be the same.

Example 1: Separate directories for images and labels
```
dataset_root/
├── images/
│ ├── train/
│ │ ├── 1.jpg
│ │ ├── 2.jpg
│ │ └── ...
│ ├── test/
│ │ ├── ...
│ └── validation/
│ ├── ...
└── labels/
├── train/
│ ├── 1.txt
│ ├── 2.txt
│ └── ...
├── test/
│ ├── ...
└── validation/
├── ...
```

Example 2: Same directory for images and labels
```
dataset_root/
├── train/
│ ├── 1.jpg
│ ├── 1.txt
│ ├── 2.jpg
│ ├── 2.txt
│ └── ...
└── validation/
├── ...
```

#### Expected label files structure
The label files must be structured such that each row represents a bounding box annotation.
Each bounding box is represented by 5 elements: `class_id, cx, cy, w, h`.

#### Instantiation
```
dataset_root/
├── images/
│ ├── train/
│ │ ├── 1.jpg
│ │ ├── 2.jpg
│ │ └── ...
│ ├── test/
│ │ ├── ...
│ └── validation/
│ ├── ...
└── labels/
├── train/
│ ├── 1.txt
│ ├── 2.txt
│ └── ...
├── test/
│ ├── ...
└── validation/
├── ...
```

```python
from data_gradients.datasets.detection import YoloFormatDetectionDataset

train_loader = YoloFormatDetectionDataset(root_dir="<path/to/dataset_root>", images_dir="images/train", labels_dir="labels/train")
val_loader = YoloFormatDetectionDataset(root_dir="<path/to/dataset_root>", images_dir="images/validation", labels_dir="labels/validation")
```
119 changes: 119 additions & 0 deletions src/data_gradients/datasets/FolderProcessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import os
import logging
from typing import List, Tuple, Sequence

logger = logging.getLogger(__name__)

# Supported image extensions for opencv: https://docs.opencv.org/3.4.3/d4/da8/group__imgcodecs.html#ga288b8b3da0892bd651fce07b3bbd3a56
DEFAULT_IMG_EXTENSIONS = (
"bmp",
"dib",
"exr",
"hdr",
"jp2",
"jpe",
"jpeg",
"jpg",
"pbm",
"pgm",
"pic",
"png",
"pnm",
"ppm",
"pxm",
"ras",
"sr",
"tif",
"tiff",
"webp",
)


class ImageLabelFilesIterator:
"""Iterate over all image and label files in the provided directories."""

def __init__(
self,
images_dir: str,
labels_dir: str,
label_extensions: Sequence[str],
image_extensions: Sequence[str] = DEFAULT_IMG_EXTENSIONS,
verbose: bool = True,
):
"""
:param images_dir: The directory containing the images.
:param labels_dir: The directory containing the labels.
:param label_extensions: The extensions of the labels. Only the labels with these extensions will be considered.
:param image_extensions: The extensions of the images. Only the images with these extensions will be considered.
:param verbose: Whether to print extra messages.
"""

self.images_dir = images_dir
self.labels_dir = labels_dir
self.verbose = verbose
self.image_extensions = self._normalize_extension(image_extensions or DEFAULT_IMG_EXTENSIONS)
self.label_extensions = self._normalize_extension(label_extensions)
self.images_with_labels_files = self.get_image_and_label_file_names(images_dir=images_dir, labels_dir=labels_dir)

def _normalize_extension(self, extensions: List[str]) -> List[str]:
"""Ensure that all extensions are lower case and don't include the '.'"""
return [ext.replace(".", "").lower() for ext in extensions]

def get_image_and_label_file_names(self, images_dir: str, labels_dir: str) -> List[Tuple[str, str]]:
"""Gather all image and label files from the provided sub_dirs."""
images_with_labels_files = []

if not os.path.exists(images_dir):
raise FileNotFoundError(f"The image directory `images_dir={images_dir}` does not exist.")
if not os.path.exists(labels_dir):
raise FileNotFoundError(f"The label directory `labels_dir={labels_dir}` does not exist.")

images_files, labels_files = self._get_file_names_in_folder(images_dir, labels_dir)
matched_images_with_labels_files = self._match_file_names(images_files, labels_files)

images_with_labels_files.extend(matched_images_with_labels_files)

return images_with_labels_files

def _get_file_names_in_folder(self, images_dir: str, labels_dir: str) -> Tuple[List[str], List[str]]:
"""Extracts the names of all image and label files in the provided folders."""
image_files = [os.path.abspath(os.path.join(images_dir, f)) for f in os.listdir(images_dir) if self.is_image(filename=f)]
label_files = [os.path.abspath(os.path.join(labels_dir, f)) for f in os.listdir(labels_dir) if self.is_label(filename=f)]
return image_files, label_files

def _match_file_names(self, all_images_file_names: List[str], all_labels_file_names: List[str]) -> List[Tuple[str, str]]:
"""Matches the names of image and label files."""
base_name = lambda file_name: os.path.splitext(os.path.basename(file_name))[0]

image_file_base_names = {base_name(file_name): file_name for file_name in all_images_file_names}
label_file_base_names = {base_name(file_name): file_name for file_name in all_labels_file_names}

common_base_names = set(image_file_base_names.keys()) & set(label_file_base_names.keys())
unmatched_image_files = set(image_file_base_names.keys()) - set(label_file_base_names.keys())
unmatched_label_files = set(label_file_base_names.keys()) - set(image_file_base_names.keys())

if self.verbose:
for imagefile in unmatched_image_files:
logger.warning(f"Image file {imagefile} does not have a matching label file. Hide this message by setting `verbose=False`.")
for label_file in unmatched_label_files:
logger.warning(f"Label file {label_file} does not have a matching image file. Hide this message by setting `verbose=False`.")

return [(image_file_base_names[name], label_file_base_names[name]) for name in common_base_names]

def is_image(self, filename: str) -> bool:
"""Check if the given file name refers to image."""
return filename.split(".")[-1].lower() in self.image_extensions

def is_label(self, filename: str) -> bool:
"""Check if the given file name refers to image."""
return filename.split(".")[-1].lower() in self.label_extensions

def __len__(self):
return len(self.images_with_labels_files)

def __getitem__(self, index):
return self.images_with_labels_files[index]

def __iter__(self):
for image_label_file in self.images_with_labels_files:
yield image_label_file
4 changes: 4 additions & 0 deletions src/data_gradients/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from data_gradients.datasets.detection import YoloFormatDetectionDataset
from data_gradients.datasets.bdd_dataset import BDDDataset

__all__ = ["YoloFormatDetectionDataset", "BDDDataset"]
3 changes: 3 additions & 0 deletions src/data_gradients/datasets/detection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from data_gradients.datasets.detection.yolo_format_detection_dataset import YoloFormatDetectionDataset

__all__ = ["YoloFormatDetectionDataset"]
157 changes: 157 additions & 0 deletions src/data_gradients/datasets/detection/yolo_format_detection_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import os
import numpy as np
import logging
from typing import Tuple, Sequence

from data_gradients.datasets.FolderProcessor import ImageLabelFilesIterator, DEFAULT_IMG_EXTENSIONS
from data_gradients.datasets.utils import load_image, ImageChannelFormat


logger = logging.getLogger(__name__)


class YoloFormatDetectionDataset:
"""The Yolo format Detection Dataset supports any dataset stored in the YOLO format.
#### Expected folder structure
Any structure including at least one sub-directory for images and one for labels. They can be the same.
Example 1: Separate directories for images and labels
```
dataset_root/
├── images/
│ ├── train/
│ │ ├── 1.jpg
│ │ ├── 2.jpg
│ │ └── ...
│ ├── test/
│ │ ├── ...
│ └── validation/
│ ├── ...
└── labels/
├── train/
│ ├── 1.txt
│ ├── 2.txt
│ └── ...
├── test/
│ ├── ...
└── validation/
├── ...
```
Example 2: Same directory for images and labels
```
dataset_root/
├── train/
│ ├── 1.jpg
│ ├── 1.txt
│ ├── 2.jpg
│ ├── 2.txt
│ └── ...
└── validation/
├── ...
```
#### Expected label files structure
The label files must be structured such that each row represents a bounding box label.
Each bounding box is represented by 5 elements: `class_id, cx, cy, w, h`.
#### Instantiation
```
dataset_root/
├── images/
│ ├── train/
│ │ ├── 1.jpg
│ │ ├── 2.jpg
│ │ └── ...
│ ├── test/
│ │ ├── ...
│ └── validation/
│ ├── ...
└── labels/
├── train/
│ ├── 1.txt
│ ├── 2.txt
│ └── ...
├── test/
│ ├── ...
└── validation/
├── ...
```
```python
from data_gradients.datasets.detection import YoloFormatDetectionDataset
train_loader = YoloFormatDetectionDataset(root_dir="<path/to/dataset_root>", images_dir="images/train", labels_dir="labels/train")
val_loader = YoloFormatDetectionDataset(root_dir="<path/to/dataset_root>", images_dir="images/validation", labels_dir="labels/validation")
```
This class does NOT support dataset formats such as Pascal VOC or COCO.
"""

def __init__(
self,
root_dir: str,
images_dir: str,
labels_dir: str,
ignore_invalid_labels: bool = True,
verbose: bool = False,
image_extensions: Sequence[str] = DEFAULT_IMG_EXTENSIONS,
label_extensions: Sequence[str] = ("txt",),
):
"""
:param root_dir: Where the data is stored.
:param images_dir: Local path to directory that includes all the images. Path relative to `root_dir`. Can be the same as `labels_dir`.
:param labels_dir: Local path to directory that includes all the labels. Path relative to `root_dir`. Can be the same as `images_dir`.
:param ignore_invalid_labels: Whether to ignore labels that fail to be parsed. If True ignores and logs a warning, otherwise raise an error.
:param verbose: Whether to show extra information during loading.
"""
self.image_label_tuples = ImageLabelFilesIterator(
images_dir=os.path.join(root_dir, images_dir),
labels_dir=os.path.join(root_dir, labels_dir),
image_extensions=image_extensions,
label_extensions=label_extensions,
verbose=verbose,
)
self.ignore_invalid_labels = ignore_invalid_labels
self.verbose = verbose

def load_image(self, index: int) -> np.ndarray:
img_file, _ = self.image_label_tuples[index]
return load_image(path=img_file, channel_format=ImageChannelFormat.RGB)

def load_labels(self, index: int) -> np.ndarray:
_, label_path = self.image_label_tuples[index]

with open(label_path, "r") as file:
lines = file.readlines()

labels = []
for line in filter(lambda x: x != "\n", lines):
lines_elements = line.split()
if len(lines_elements) == 5:
try:
labels.append(list(map(float, lines_elements)))
except ValueError as e:
raise ValueError(
f"Invalid label: {line} from {label_path}.\nExpected 5 elements (class_id, cx, cy, w, h), got {len(lines_elements)}."
) from e
else:
error = f"invalid label: {line} from {label_path}.\n Expected 5 elements (class_id, cx, cy, w, h), got {len(lines_elements)}."
if self.ignore_invalid_labels:
logger.warning(f"Ignoring {error}")
else:
raise RuntimeError(error.capitalize())
return np.array(labels) if labels else np.zeros((0, 5))

def __len__(self) -> int:
return len(self.image_label_tuples)

def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
image = self.load_image(index)
labels = self.load_labels(index)
return image, labels

def __iter__(self) -> Tuple[np.ndarray, np.ndarray]:
for i in range(len(self)):
yield self[i]
16 changes: 16 additions & 0 deletions src/data_gradients/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import cv2
import numpy as np
from data_gradients.utils.data_classes.data_samples import ImageChannelFormat


def load_image(path: str, channel_format: ImageChannelFormat = ImageChannelFormat.BGR) -> np.ndarray:
"""Load an image from a path in a specified format."""
bgr_image = cv2.imread(path)
if channel_format == ImageChannelFormat.RGB:
return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
elif channel_format == ImageChannelFormat.BGR:
return bgr_image
elif channel_format == ImageChannelFormat.GRAYSCALE:
return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2GRAY)
else:
raise NotImplementedError(f"Channel format {channel_format} is not supported for loading image")

0 comments on commit 7d9351f

Please sign in to comment.