Added DOTA2 dataset #1999

Merged
@@ -511,3 +511,63 @@ train_set = COCOKeypointsDataset(data_dir='.../coco', images_dir='images/train2017', json_file='annotations/instances_train2017.json', ...)
valid_set = COCOKeypointsDataset(data_dir='.../coco', images_dir='images/val2017', json_file='annotations/instances_val2017.json', ...)
```
</details>



### Oriented Box Detection Datasets



<details>
<summary>DOTA 2.0</summary>

1. Download the DOTA dataset: https://captain-whu.github.io/DOTA/dataset.html

2. Unzip and organize it as follows:
```
dota
├── train
│   ├── images
│   │   ├── 000000000001.jpg
│   │   └── ...
│   └── ann
│       └── 000000000001.txt
└── val
    ├── images
    │   ├── 000000000002.jpg
    │   └── ...
    └── ann
        └── 000000000002.txt
```
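You can sanity-check this layout with a few lines of plain Python before moving on. `check_dota_layout` is a hypothetical helper, not part of the repo; it only mirrors the tree above:

```python
from pathlib import Path


def check_dota_layout(root: str) -> None:
    """Verify the dota/ tree matches the expected train/val layout above."""
    root_path = Path(root)
    for split in ("train", "val"):
        images = root_path / split / "images"
        ann = root_path / split / "ann"
        if not images.is_dir() or not ann.is_dir():
            raise FileNotFoundError(f"Expected {images} and {ann} to exist")
        # Every image should have a matching DOTA-format .txt annotation.
        for img in images.glob("*.jpg"):
            if not (ann / f"{img.stem}.txt").exists():
                raise FileNotFoundError(f"No annotation for {img.name}")


check_dota_layout("<path-to>/dota")
```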


3. Run the script to slice the dataset into tiles:

```bash
python src/super_gradients/examples/dota_prepare_dataset/dota_prepare_dataset.py --input_dir <path-to>/dota --output_dir <path-to>/dota_tiles
```
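The script is a thin wrapper around `DOTAOBBDataset.slice_dataset_into_tiles`. If you prefer to slice from Python, here is a minimal sketch; the parameter values mirror the ones the script uses for the train split:

```python
from pathlib import Path

from super_gradients.training.datasets import DOTAOBBDataset

data_dir = Path("<path-to>/dota")
output_dir = Path("<path-to>/dota_tiles")

# Slice the train split into overlapping 1024 px tiles at several scales.
DOTAOBBDataset.slice_dataset_into_tiles(
    data_dir=data_dir / "train",
    output_dir=output_dir / "train",
    input_ann_subdir_name="ann",
    output_ann_subdir_name="ann-obb",
    tile_size=1024,
    tile_step=512,  # 50% overlap between neighboring tiles
    scale_factors=(0.75, 1, 1.25),  # multi-scale slicing for training
    min_visibility=0.4,  # presumably: keep boxes with >= 40% of their area inside the tile
    min_area=8,  # presumably: drop boxes smaller than 8 px after slicing
    num_workers=4,
)
```

For the validation split the script instead uses non-overlapping tiles (`tile_step=1024`) at the original scale only (`scale_factors=(1,)`).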

4. Specify the path to the sliced dataset (CLI):
```bash
python -m super_gradients.train_from_recipe --config-name yolo_nas_r_s_dota dataset_params.data_dir=<path-to>/dota_tiles
```

5. Specify the path to the sliced dataset (YAML):
```yaml
dataset_params:
  train_dataset_params:
    data_dir: <path-to>/dota_tiles/train
  val_dataset_params:
    data_dir: <path-to>/dota_tiles/val
```

6. Specify the path to the sliced dataset (code):

```python
from super_gradients.training.datasets import DOTAOBBDataset

train_set = DOTAOBBDataset(data_dir="<path-to>/dota_tiles/train", ...)
```
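To build a batched loader from this dataset, the recipe pairs it with the `OrientedBoxesCollate` collate function. A minimal sketch, assuming `OrientedBoxesCollate` takes no constructor arguments; the loader values mirror the recipe's `train_dataloader_params`, and the dataset's `transforms` are omitted for brevity:

```python
from torch.utils.data import DataLoader

from super_gradients.training.datasets import DOTAOBBDataset
from super_gradients.training.datasets.datasets_conf import DOTA2_DEFAULT_CLASSES_LIST
from super_gradients.training.datasets.obb import OrientedBoxesCollate

# Constructor arguments mirror train_dataset_params from the recipe.
train_set = DOTAOBBDataset(
    data_dir="<path-to>/dota_tiles/train",
    class_names=DOTA2_DEFAULT_CLASSES_LIST,
    ignore_empty_annotations=True,
)

train_loader = DataLoader(
    train_set,
    batch_size=16,
    shuffle=True,
    num_workers=8,
    drop_last=True,
    pin_memory=True,
    collate_fn=OrientedBoxesCollate(),
)
```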

</details>
2 changes: 1 addition & 1 deletion mkdocs.yml
@@ -19,7 +19,7 @@ nav:
   - Models: ./documentation/source/models.md
   - Dataset:
     - Data: ./documentation/source/Data.md
-    - Computer Vision Datasets: ./src/super_gradients/training/datasets/Dataset_Setup_Instructions.md
+    - Computer Vision Datasets: ./documentation/source/Dataset_Setup_Instructions.md
     - Dataset Adapter: ./documentation/source/dataloader_adapter.md
   - Loss functions: ./documentation/source/Losses.md
   - LR Assignment: ./documentation/source/LRAssignment.md
79 changes: 79 additions & 0 deletions src/super_gradients/examples/dota_prepare_dataset/dota_prepare_dataset.py
@@ -0,0 +1,79 @@
"""
This script slices the DOTA dataset into tiles of a usable size for training a model.
The tiles are saved in the output directory with the same structure as the input directory.

To use this script you should download the DOTA dataset from the official website:
https://captain-whu.github.io/DOTA/dataset.html

The dataset should be organized as follows:
dota
├── train
│   ├── images
│   │   ├── 000000000001.jpg
│   │   └── ...
│   └── ann
│       └── 000000000001.txt
└── val
    ├── images
    │   ├── 000000000002.jpg
    │   └── ...
    └── ann
        └── 000000000002.txt

Example usage:
python dota_prepare_dataset.py --input_dir /path/to/dota --output_dir /path/to/dota-sliced

After running this script you can use /path/to/dota-sliced as the data_dir argument for training a model on the DOTA dataset.
"""

import argparse
from pathlib import Path

import cv2
from super_gradients.training.datasets import DOTAOBBDataset


def main():
    parser = argparse.ArgumentParser(description="Slice DOTA dataset into tiles of usable size for training a model")
    parser.add_argument("--input_dir", help="Where the full DOTA dataset is stored", required=True)
    parser.add_argument("--output_dir", help="Where the resulting data should be stored", required=True)
    parser.add_argument("--ann_subdir_name", default="ann", help="Name of the annotations subdirectory")
    parser.add_argument("--output_ann_subdir_name", default="ann-obb", help="Name of the output annotations subdirectory")
    parser.add_argument("--num_workers", default=cv2.getNumberOfCPUs() // 2, type=int, help="Number of worker processes to use for slicing")
    args = parser.parse_args()

    cv2.setNumThreads(cv2.getNumberOfCPUs() // 4)

    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)
    ann_subdir_name = str(args.ann_subdir_name)
    output_ann_subdir_name = str(args.output_ann_subdir_name)

    # Train split: overlapping tiles at multiple scales for training-time variety.
    DOTAOBBDataset.slice_dataset_into_tiles(
        data_dir=input_dir / "train",
        output_dir=output_dir / "train",
        input_ann_subdir_name=ann_subdir_name,
        output_ann_subdir_name=output_ann_subdir_name,
        tile_size=1024,
        tile_step=512,
        scale_factors=(0.75, 1, 1.25),
        min_visibility=0.4,
        min_area=8,
        num_workers=args.num_workers,
    )

    # Val split: non-overlapping tiles at the original scale only.
    DOTAOBBDataset.slice_dataset_into_tiles(
        data_dir=input_dir / "val",
        output_dir=output_dir / "val",
        input_ann_subdir_name=ann_subdir_name,
        output_ann_subdir_name=output_ann_subdir_name,
        tile_size=1024,
        tile_step=1024,
        scale_factors=(1,),
        min_visibility=0.4,
        min_area=8,
        num_workers=args.num_workers,
    )


if __name__ == "__main__":
    main()
3 changes: 3 additions & 0 deletions src/super_gradients/module_interfaces/__init__.py
@@ -12,6 +12,7 @@
    SemanticSegmentationDecodingModule,
    BinarySegmentationDecodingModule,
)
from .obb_predictions import OBBPredictions, AbstractOBBPostPredictionCallback

__all__ = [
    "HasPredict",
@@ -35,4 +36,6 @@
    "AbstractSegmentationDecodingModule",
    "SemanticSegmentationDecodingModule",
    "BinarySegmentationDecodingModule",
    "OBBPredictions",
    "AbstractOBBPostPredictionCallback",
]
47 changes: 47 additions & 0 deletions src/super_gradients/module_interfaces/obb_predictions.py
@@ -0,0 +1,47 @@
import abc
import dataclasses
from typing import Any, List, Union

import numpy as np
from torch import Tensor

__all__ = ["OBBPredictions", "AbstractOBBPostPredictionCallback"]


@dataclasses.dataclass
class OBBPredictions:
    """
    A data class that encapsulates oriented box predictions for a single image.

    :param scores:         Array of shape [N] with confidence scores for each detection.
    :param labels:         Array of shape [N] with class indices for each detection.
    :param rboxes_cxcywhr: Array of shape [N, 5] with a rotated box for each detection in CXCYWHR format.
    """

    scores: Union[Tensor, np.ndarray]
    labels: Union[Tensor, np.ndarray]
    rboxes_cxcywhr: Union[Tensor, np.ndarray]

    def __init__(self, rboxes_cxcywhr, scores, labels):
        if len(rboxes_cxcywhr) != len(scores) or len(rboxes_cxcywhr) != len(labels):
            raise ValueError(f"rboxes_cxcywhr, scores and labels must have the same length. Got: {len(rboxes_cxcywhr)}, {len(scores)}, {len(labels)}")
        if rboxes_cxcywhr.ndim != 2 or rboxes_cxcywhr.shape[1] != 5:
            raise ValueError(f"rboxes_cxcywhr must have shape [N, 5]. Got: {rboxes_cxcywhr.shape}")

        self.scores = scores
        self.labels = labels
        self.rboxes_cxcywhr = rboxes_cxcywhr

    def __len__(self):
        return len(self.scores)


class AbstractOBBPostPredictionCallback(abc.ABC):
    """
    A protocol interface of a post-prediction callback for oriented box detection models.
    """

    @abc.abstractmethod
    def __call__(self, predictions: Any) -> List[OBBPredictions]:
        ...
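A quick sketch of how these pieces fit together. The callback subclass and tensor values are illustrative only, and the rotation unit in CXCYWHR format is assumed here to be radians:

```python
from typing import Any, List

import torch

from super_gradients.module_interfaces import AbstractOBBPostPredictionCallback, OBBPredictions


class IdentityOBBCallback(AbstractOBBPostPredictionCallback):
    """Toy callback that wraps already-decoded predictions for a single image."""

    def __call__(self, predictions: Any) -> List[OBBPredictions]:
        rboxes, scores, labels = predictions
        return [OBBPredictions(rboxes_cxcywhr=rboxes, scores=scores, labels=labels)]


# Three detections: each row is (cx, cy, w, h, r), with r assumed to be in radians.
rboxes = torch.tensor(
    [
        [120.0, 80.0, 60.0, 20.0, 0.35],
        [300.5, 210.0, 45.0, 45.0, -1.10],
        [512.0, 512.0, 128.0, 32.0, 1.57],
    ]
)
scores = torch.tensor([0.91, 0.78, 0.55])
labels = torch.tensor([0, 9, 1])  # indices into the dataset's class_names

(preds,) = IdentityOBBCallback()((rboxes, scores, labels))
assert len(preds) == 3
```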
@@ -0,0 +1,115 @@
# Configuration file for the dataset parameters of the DOTA2 dataset for the YOLO-NAS-R model.
# A data_dir parameter should be explicitly defined in the config file that includes this file.
# Please check documentation/source/Dataset_Setup_Instructions.md for more information on how to set up the dataset.

num_classes: 18
class_names:
- plane
- ship
- storage-tank
- baseball-diamond
- tennis-court
- basketball-court
- ground-track-field
- harbor
- bridge
- large-vehicle
- small-vehicle
- helicopter
- roundabout
- soccer-ball-field
- swimming-pool
- container-crane
- airport
- helipad

data_dir: ???

mixup_prob: 0.5

train_dataset_params:
  data_dir: ${dataset_params.data_dir}/train
  class_names: ${dataset_params.class_names}
  ignore_empty_annotations: True
  transforms:
    - Albumentations:
        Compose:
          keypoint_params:
          transforms:
            - ShiftScaleRotate:
                shift_limit: 0.1
                scale_limit: 0.75
                rotate_limit: 45
                interpolation: 1
                border_mode: 0
            - RandomBrightnessContrast:
                brightness_limit: 0.2
                contrast_limit: 0.2
                p: 0.5
            - RandomCrop:
                p: 1.0
                height: 640
                width: 640
            - HueSaturationValue:
                hue_shift_limit: 20
                sat_shift_limit: 30
                val_shift_limit: 20
                p: 0.5
            - RandomRotate90:
                p: 1.0
            - HorizontalFlip:
                p: 0.5

    - OBBRemoveSmallObjects:
        min_size: 8
        min_area: 64

    - OBBDetectionMixup:
        prob: ${dataset_params.mixup_prob}

    - OBBDetectionStandardize:
        max_value: 255.

train_dataloader_params:
  dataset: DOTAOBBDataset
  batch_size: 16
  num_workers: 8
  shuffle: True
  drop_last: True
  pin_memory: True
  persistent_workers: True
  collate_fn: OrientedBoxesCollate
  sampler:
    ClassBalancedSampler:
      num_samples: 65536
      oversample_threshold: 0.99
      oversample_aggressiveness: 0.9945267123516118

val_dataset_params:
  data_dir: ${dataset_params.data_dir}/val
  class_names: ${dataset_params.class_names}
  ignore_empty_annotations: True
  transforms:
    - OBBDetectionLongestMaxSize:
        max_height: 1024
        max_width: 1024
    - OBBDetectionPadIfNeeded:
        min_height: 1024
        min_width: 1024
        pad_value: 114
        padding_mode: bottom_right
    - OBBDetectionStandardize:
        max_value: 255.


val_dataloader_params:
  dataset: DOTAOBBDataset
  batch_size: 16
  num_workers: 8
  drop_last: False
  shuffle: False
  pin_memory: True
  persistent_workers: True
  collate_fn: OrientedBoxesCollate

_convert_: all
3 changes: 2 additions & 1 deletion src/super_gradients/training/datasets/__init__.py
@@ -25,7 +25,7 @@
    BaseKeypointsDataset,
    COCOPoseEstimationDataset,
)

from .obb import DOTAOBBDataset

__all__ = [
    "BaseKeypointsDataset",
@@ -50,6 +50,7 @@
    "SuperviselyPersonsDataset",
    "COCOKeypointsDataset",
    "COCOPoseEstimationDataset",
    "DOTAOBBDataset",
]

cv2.setNumThreads(0)
21 changes: 21 additions & 0 deletions src/super_gradients/training/datasets/datasets_conf.py
@@ -1226,3 +1226,24 @@
    "motorcycle",
    "bicycle",
]

DOTA2_DEFAULT_CLASSES_LIST = [
    "plane",
    "ship",
    "storage-tank",
    "baseball-diamond",
    "tennis-court",
    "basketball-court",
    "ground-track-field",
    "harbor",
    "bridge",
    "large-vehicle",
    "small-vehicle",
    "helicopter",
    "roundabout",
    "soccer-ball-field",
    "swimming-pool",
    "container-crane",
    "airport",
    "helipad",
]
7 changes: 7 additions & 0 deletions src/super_gradients/training/datasets/obb/__init__.py
@@ -0,0 +1,7 @@
from .collate import OrientedBoxesCollate
from .dota import DOTAOBBDataset

__all__ = [
    "DOTAOBBDataset",
    "OrientedBoxesCollate",
]