-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add yolo * improve doc and name * rename * add more explicit default extensions * upadte extension list * update strings * rename annotation to label * add lena nd iter * add s
- Loading branch information
1 parent
5920672
commit 7d9351f
Showing
6 changed files
with
383 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Built-in Datasets | ||
|
||
DataGradients offers a few basic datasets that help you load your data without writing any additional code. | ||
These datasets provide only basic functionality and are not recommended for training. | ||
|
||
## Object Detection | ||
|
||
|
||
### Yolo Format Dataset | ||
|
||
The Yolo format Detection Dataset supports any dataset stored in the YOLO format. | ||
|
||
#### Expected folder structure | ||
Any structure is supported, as long as it includes at least one sub-directory for images and one for labels. Both can be the same directory. | ||
|
||
Example 1: Separate directories for images and labels | ||
``` | ||
dataset_root/ | ||
├── images/ | ||
│ ├── train/ | ||
│ │ ├── 1.jpg | ||
│ │ ├── 2.jpg | ||
│ │ └── ... | ||
│ ├── test/ | ||
│ │ ├── ... | ||
│ └── validation/ | ||
│ ├── ... | ||
└── labels/ | ||
├── train/ | ||
│ ├── 1.txt | ||
│ ├── 2.txt | ||
│ └── ... | ||
├── test/ | ||
│ ├── ... | ||
└── validation/ | ||
├── ... | ||
``` | ||
|
||
Example 2: Same directory for images and labels | ||
``` | ||
dataset_root/ | ||
├── train/ | ||
│ ├── 1.jpg | ||
│ ├── 1.txt | ||
│ ├── 2.jpg | ||
│ ├── 2.txt | ||
│ └── ... | ||
└── validation/ | ||
├── ... | ||
``` | ||
|
||
#### Expected label files structure | ||
The label files must be structured such that each row represents one bounding box annotation. | ||
Each bounding box is represented by 5 elements separated by whitespace, in the following order: `class_id`, `cx`, `cy`, `w`, `h`. | ||
|
||
#### Instantiation | ||
``` | ||
dataset_root/ | ||
├── images/ | ||
│ ├── train/ | ||
│ │ ├── 1.jpg | ||
│ │ ├── 2.jpg | ||
│ │ └── ... | ||
│ ├── test/ | ||
│ │ ├── ... | ||
│ └── validation/ | ||
│ ├── ... | ||
└── labels/ | ||
├── train/ | ||
│ ├── 1.txt | ||
│ ├── 2.txt | ||
│ └── ... | ||
├── test/ | ||
│ ├── ... | ||
└── validation/ | ||
├── ... | ||
``` | ||
|
||
```python | ||
from data_gradients.datasets.detection import YoloFormatDetectionDataset | ||
|
||
train_loader = YoloFormatDetectionDataset(root_dir="<path/to/dataset_root>", images_dir="images/train", labels_dir="labels/train") | ||
val_loader = YoloFormatDetectionDataset(root_dir="<path/to/dataset_root>", images_dir="images/validation", labels_dir="labels/validation") | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import os | ||
import logging | ||
from typing import List, Tuple, Sequence | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
# Image file extensions accepted by default (lower case, without the leading dot).
# The list mirrors the formats OpenCV is able to read, see:
# https://docs.opencv.org/3.4.3/d4/da8/group__imgcodecs.html#ga288b8b3da0892bd651fce07b3bbd3a56
DEFAULT_IMG_EXTENSIONS = (
    "bmp", "dib", "exr", "hdr", "jp2",
    "jpe", "jpeg", "jpg", "pbm", "pgm",
    "pic", "png", "pnm", "ppm", "pxm",
    "ras", "sr", "tif", "tiff", "webp",
)
|
||
|
||
class ImageLabelFilesIterator:
    """Iterate over all matching (image, label) file pairs found in the provided directories.

    An image file and a label file are paired when they share the same base name, i.e. the file name
    without its extension (e.g. `1.jpg` and `1.txt`). Only the files located directly inside
    `images_dir` and `labels_dir` are considered; sub-directories are not scanned.
    The resulting pairs are sorted by base name, so the iteration order is the same on every run.
    """

    def __init__(
        self,
        images_dir: str,
        labels_dir: str,
        label_extensions: Sequence[str],
        image_extensions: Sequence[str] = DEFAULT_IMG_EXTENSIONS,
        verbose: bool = True,
    ):
        """
        :param images_dir:          The directory containing the images.
        :param labels_dir:          The directory containing the labels.
        :param label_extensions:    The extensions of the labels. Only the labels with these extensions will be considered.
        :param image_extensions:    The extensions of the images. Only the images with these extensions will be considered.
                                    Falls back to `DEFAULT_IMG_EXTENSIONS` when empty or None.
        :param verbose:             Whether to print extra messages.
        :raises FileNotFoundError:  If `images_dir` or `labels_dir` does not exist.
        """
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.verbose = verbose
        self.image_extensions = self._normalize_extension(image_extensions or DEFAULT_IMG_EXTENSIONS)
        self.label_extensions = self._normalize_extension(label_extensions)
        self.images_with_labels_files = self.get_image_and_label_file_names(images_dir=images_dir, labels_dir=labels_dir)

    def _normalize_extension(self, extensions: Sequence[str]) -> List[str]:
        """Ensure that all extensions are lower case and don't include the '.'

        :param extensions: Extensions to normalize. A single string is treated as one extension.
        :return:           List of normalized extensions.
        """
        if isinstance(extensions, str):
            # A bare string would otherwise be iterated character by character ("txt" -> ["t", "x", "t"]).
            extensions = (extensions,)
        return [ext.replace(".", "").lower() for ext in extensions]

    def get_image_and_label_file_names(self, images_dir: str, labels_dir: str) -> List[Tuple[str, str]]:
        """Gather all the (image, label) file pairs from the provided directories.

        :param images_dir:          The directory containing the images.
        :param labels_dir:          The directory containing the labels.
        :return:                    List of (image path, label path) tuples, sorted by base name.
        :raises FileNotFoundError:  If one of the directories does not exist.
        """
        if not os.path.exists(images_dir):
            raise FileNotFoundError(f"The image directory `images_dir={images_dir}` does not exist.")
        if not os.path.exists(labels_dir):
            raise FileNotFoundError(f"The label directory `labels_dir={labels_dir}` does not exist.")

        images_files, labels_files = self._get_file_names_in_folder(images_dir, labels_dir)
        return self._match_file_names(images_files, labels_files)

    def _get_file_names_in_folder(self, images_dir: str, labels_dir: str) -> Tuple[List[str], List[str]]:
        """Extract the absolute paths of all image and label files in the provided folders."""
        image_files = [os.path.abspath(os.path.join(images_dir, f)) for f in os.listdir(images_dir) if self.is_image(filename=f)]
        label_files = [os.path.abspath(os.path.join(labels_dir, f)) for f in os.listdir(labels_dir) if self.is_label(filename=f)]
        return image_files, label_files

    def _match_file_names(self, all_images_file_names: List[str], all_labels_file_names: List[str]) -> List[Tuple[str, str]]:
        """Pair image and label files that share the same base name.

        :param all_images_file_names:   Paths of the candidate image files.
        :param all_labels_file_names:   Paths of the candidate label files.
        :return:                        List of (image path, label path) tuples, sorted by base name.
        """

        def _base_name(file_name: str) -> str:
            return os.path.splitext(os.path.basename(file_name))[0]

        # NOTE: if several files share a base name (e.g. `1.jpg` and `1.png`), only the last one listed is kept.
        images_by_base_name = {_base_name(file_name): file_name for file_name in all_images_file_names}
        labels_by_base_name = {_base_name(file_name): file_name for file_name in all_labels_file_names}

        common_base_names = images_by_base_name.keys() & labels_by_base_name.keys()

        if self.verbose:
            for name in sorted(images_by_base_name.keys() - labels_by_base_name.keys()):
                logger.warning(f"Image file {images_by_base_name[name]} does not have a matching label file. Hide this message by setting `verbose=False`.")
            for name in sorted(labels_by_base_name.keys() - images_by_base_name.keys()):
                logger.warning(f"Label file {labels_by_base_name[name]} does not have a matching image file. Hide this message by setting `verbose=False`.")

        # Sets of strings have no stable order across runs; sort so that indexing is reproducible.
        return [(images_by_base_name[name], labels_by_base_name[name]) for name in sorted(common_base_names)]

    @staticmethod
    def _get_extension(filename: str) -> str:
        """Return the lower-cased extension of `filename` without the '.', or '' if it has none."""
        # `splitext` (unlike `split(".")[-1]`) does not mistake extension-less names such as "jpg"
        # or hidden files such as ".png" for files with that extension.
        return os.path.splitext(filename)[1][1:].lower()

    def is_image(self, filename: str) -> bool:
        """Check if the given file name refers to an image."""
        return self._get_extension(filename) in self.image_extensions

    def is_label(self, filename: str) -> bool:
        """Check if the given file name refers to a label."""
        return self._get_extension(filename) in self.label_extensions

    def __len__(self):
        """Number of matched (image, label) pairs."""
        return len(self.images_with_labels_files)

    def __getitem__(self, index):
        """Return the (image path, label path) tuple at `index`."""
        return self.images_with_labels_files[index]

    def __iter__(self):
        """Yield every (image path, label path) tuple."""
        yield from self.images_with_labels_files
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from data_gradients.datasets.detection import YoloFormatDetectionDataset | ||
from data_gradients.datasets.bdd_dataset import BDDDataset | ||
|
||
__all__ = ["YoloFormatDetectionDataset", "BDDDataset"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from data_gradients.datasets.detection.yolo_format_detection_dataset import YoloFormatDetectionDataset | ||
|
||
__all__ = ["YoloFormatDetectionDataset"] |
157 changes: 157 additions & 0 deletions
157
src/data_gradients/datasets/detection/yolo_format_detection_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
import os
import numpy as np
import logging
from typing import Iterator, Sequence, Tuple

from data_gradients.datasets.FolderProcessor import ImageLabelFilesIterator, DEFAULT_IMG_EXTENSIONS
from data_gradients.datasets.utils import load_image, ImageChannelFormat
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class YoloFormatDetectionDataset:
    """The Yolo format Detection Dataset supports any dataset stored in the YOLO format.

    #### Expected folder structure
    Any structure including at least one sub-directory for images and one for labels. They can be the same.

    Example 1: Separate directories for images and labels
    ```
        dataset_root/
            ├── images/
            │   ├── train/
            │   │   ├── 1.jpg
            │   │   ├── 2.jpg
            │   │   └── ...
            │   ├── test/
            │   │   ├── ...
            │   └── validation/
            │       ├── ...
            └── labels/
                ├── train/
                │   ├── 1.txt
                │   ├── 2.txt
                │   └── ...
                ├── test/
                │   ├── ...
                └── validation/
                    ├── ...
    ```

    Example 2: Same directory for images and labels
    ```
        dataset_root/
            ├── train/
            │   ├── 1.jpg
            │   ├── 1.txt
            │   ├── 2.jpg
            │   ├── 2.txt
            │   └── ...
            └── validation/
                ├── ...
    ```

    #### Expected label files structure
    The label files must be structured such that each row represents a bounding box label.
    Each bounding box is represented by 5 elements: `class_id, cx, cy, w, h`.

    #### Instantiation
    ```
        dataset_root/
            ├── images/
            │   ├── train/
            │   │   ├── 1.jpg
            │   │   ├── 2.jpg
            │   │   └── ...
            │   ├── test/
            │   │   ├── ...
            │   └── validation/
            │       ├── ...
            └── labels/
                ├── train/
                │   ├── 1.txt
                │   ├── 2.txt
                │   └── ...
                ├── test/
                │   ├── ...
                └── validation/
                    ├── ...
    ```

    ```python
    from data_gradients.datasets.detection import YoloFormatDetectionDataset

    train_loader = YoloFormatDetectionDataset(root_dir="<path/to/dataset_root>", images_dir="images/train", labels_dir="labels/train")
    val_loader = YoloFormatDetectionDataset(root_dir="<path/to/dataset_root>", images_dir="images/validation", labels_dir="labels/validation")
    ```

    This class does NOT support dataset formats such as Pascal VOC or COCO.
    """

    def __init__(
        self,
        root_dir: str,
        images_dir: str,
        labels_dir: str,
        ignore_invalid_labels: bool = True,
        verbose: bool = False,
        image_extensions: Sequence[str] = DEFAULT_IMG_EXTENSIONS,
        label_extensions: Sequence[str] = ("txt",),
    ):
        """
        :param root_dir:                Where the data is stored.
        :param images_dir:              Local path to directory that includes all the images. Path relative to `root_dir`. Can be the same as `labels_dir`.
        :param labels_dir:              Local path to directory that includes all the labels. Path relative to `root_dir`. Can be the same as `images_dir`.
        :param ignore_invalid_labels:   Whether to ignore labels that fail to be parsed. If True ignores and logs a warning, otherwise raise an error.
        :param verbose:                 Whether to show extra information during loading.
        :param image_extensions:        The extensions of the images. Only the files with these extensions are considered as images.
        :param label_extensions:        The extensions of the labels. Only the files with these extensions are considered as labels.
        """
        self.image_label_tuples = ImageLabelFilesIterator(
            images_dir=os.path.join(root_dir, images_dir),
            labels_dir=os.path.join(root_dir, labels_dir),
            image_extensions=image_extensions,
            label_extensions=label_extensions,
            verbose=verbose,
        )
        self.ignore_invalid_labels = ignore_invalid_labels
        self.verbose = verbose

    def load_image(self, index: int) -> np.ndarray:
        """Load the image at `index` in RGB channel format."""
        img_file, _ = self.image_label_tuples[index]
        # Resolves to the module-level `load_image` helper, not to this method.
        return load_image(path=img_file, channel_format=ImageChannelFormat.RGB)

    def load_labels(self, index: int) -> np.ndarray:
        """Parse the label file at `index`.

        Empty and whitespace-only lines are skipped. A line is invalid when it does not hold exactly 5 elements
        or when one of its elements cannot be converted to float. Invalid lines are skipped with a warning when
        `ignore_invalid_labels=True`, and raise otherwise.

        :param index:           Index of the sample.
        :return:                Array of shape (N, 5), one row `(class_id, cx, cy, w, h)` per valid label. N may be 0.
        :raises RuntimeError:   If a line does not include exactly 5 elements and `ignore_invalid_labels=False`.
        :raises ValueError:     If a line includes non-numeric elements and `ignore_invalid_labels=False`.
        """
        _, label_path = self.image_label_tuples[index]

        with open(label_path, "r") as file:
            lines = file.readlines()

        labels = []
        for line in lines:
            line_elements = line.split()
            if not line_elements:
                continue  # Blank line

            if len(line_elements) != 5:
                error = f"invalid label: {line} from {label_path}.\n Expected 5 elements (class_id, cx, cy, w, h), got {len(line_elements)}."
                if not self.ignore_invalid_labels:
                    # Upper-case only the first letter: `str.capitalize()` would lower-case the rest, including the path.
                    raise RuntimeError(error[:1].upper() + error[1:])
                logger.warning(f"Ignoring {error}")
                continue

            try:
                labels.append([float(element) for element in line_elements])
            except ValueError as e:
                error = f"invalid label: {line} from {label_path}.\n Expected 5 numeric elements (class_id, cx, cy, w, h)."
                if not self.ignore_invalid_labels:
                    raise ValueError(error[:1].upper() + error[1:]) from e
                logger.warning(f"Ignoring {error}")

        # Keep a consistent (N, 5) shape even when no valid label was found.
        return np.array(labels) if labels else np.zeros((0, 5))

    def __len__(self) -> int:
        """Number of (image, label) pairs in the dataset."""
        return len(self.image_label_tuples)

    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
        """Return the `(image, labels)` tuple at `index`."""
        image = self.load_image(index)
        labels = self.load_labels(index)
        return image, labels

    def __iter__(self) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
        """Yield every `(image, labels)` tuple of the dataset, in index order."""
        for i in range(len(self)):
            yield self[i]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import cv2 | ||
import numpy as np | ||
from data_gradients.utils.data_classes.data_samples import ImageChannelFormat | ||
|
||
|
||
def load_image(path: str, channel_format: ImageChannelFormat = ImageChannelFormat.BGR) -> np.ndarray:
    """Load an image from a path in a specified format.

    :param path:                    Path of the image file to load.
    :param channel_format:          Channel format of the returned image (RGB, BGR or GRAYSCALE).
    :return:                        The loaded image, converted to `channel_format`.
    :raises OSError:                If the image could not be read from `path`.
    :raises NotImplementedError:    If `channel_format` is not supported.
    """
    bgr_image = cv2.imread(path)
    if bgr_image is None:
        # `cv2.imread` does not raise on failure, it returns None. Fail here with an explicit message instead of
        # returning None (BGR) or crashing inside `cv2.cvtColor` with an unrelated error (RGB/GRAYSCALE).
        raise OSError(f"Could not load image from `{path}`: the file does not exist or is not a readable image.")

    if channel_format == ImageChannelFormat.RGB:
        return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    elif channel_format == ImageChannelFormat.BGR:
        return bgr_image
    elif channel_format == ImageChannelFormat.GRAYSCALE:
        return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2GRAY)
    else:
        raise NotImplementedError(f"Channel format {channel_format} is not supported for loading image")