Skip to content

Commit

Permalink
Add custom dataset of grounding dino (#11012)
Browse files Browse the repository at this point in the history
  • Loading branch information
hhaAndroid authored Oct 9, 2023
1 parent d84ea9b commit f14353d
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 0 deletions.
96 changes: 96 additions & 0 deletions configs/grounding_dino/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,99 @@ Note:
1. The weights corresponding to the zero-shot model are adopted from the official weights and converted using the [script](../../tools/model_converters/groundingdino_to_mmdet.py). We have not retrained the model for the time being.
2. Finetune refers to fine-tuning on the COCO 2017 dataset. The R50 model is trained using 8 NVIDIA GeForce 3090 GPUs, while the remaining models are trained using 16 NVIDIA GeForce 3090 GPUs. The GPU memory usage is approximately 8.5GB.
3. Our performance is higher than the official model due to two reasons: we modified the initialization strategy and introduced a log scaler.

## Custom Dataset

To facilitate fine-tuning on custom datasets, we use a simple cat dataset as an example, as shown in the following steps.

### 1. Dataset Preparation

```shell
cd mmdetection
wget https://download.openmmlab.com/mmyolo/data/cat_dataset.zip
unzip cat_dataset.zip -d data/cat/
```

cat dataset is a single-category dataset with 144 images, which has been converted to coco format.

<div align=center>
<img src="https://user-images.githubusercontent.com/25873202/205423220-c4b8f2fd-22ba-4937-8e47-1b3f6a8facd8.png" alt="cat dataset"/>
</div>

### 2. Config Preparation

Due to the simplicity and small number of cat datasets, we use 8 cards to train 20 epochs, scale the learning rate accordingly, and do not train the language model, only the visual model.

The Details of the configuration can be found in [grounding_dino_swin-t_finetune_8xb2_20e_cat](grounding_dino_swin-t_finetune_8xb2_20e_cat.py)

### 3. Visualization and Evaluation

Due to the Grounding DINO is an open detection model, so it can be detected and evaluated even if it is not trained on the cat dataset.

The single image visualization is as follows:

```shell
cd mmdetection
python demo/image_demo.py data/cat/images/IMG_20211205_120756.jpg configs/grounding_dino/grounding_dino_swin-t_finetune_8xb2_20e_cat.py --weights https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swint_ogc_mmdet-822d7e9d.pth --texts cat.
```

<div align=center>
<img src="https://github.com/open-mmlab/mmdetection/assets/17425982/89173261-16f1-4fd9-ac63-8dc2dcda6616" alt="cat dataset"/>
</div>

The test dataset evaluation on single card is as follows:

```shell
python tools/test.py configs/grounding_dino/grounding_dino_swin-t_finetune_8xb2_20e_cat.py https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swint_ogc_mmdet-822d7e9d.pth
```

```text
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.867
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=1000 ] = 1.000
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=1000 ] = 0.931
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.867
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.903
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=300 ] = 0.907
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=1000 ] = 0.907
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.907
```

### 4. Model Training and Visualization

```shell
./tools/dist_train.sh configs/grounding_dino/grounding_dino_swin-t_finetune_8xb2_20e_cat.py 8 --work-dir cat_work_dir
```

The model will be saved based on the best performance on the test set. The performance of the best model (at epoch 16) is as follows:

```text
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.905
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=1000 ] = 1.000
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=1000 ] = 0.923
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.905
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.927
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=300 ] = 0.937
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=1000 ] = 0.937
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.937
```

We can find that after fine-tuning training, the training of the cat dataset is increased from 86.7 to 90.5.

If we do single image inference visualization again, the result is as follows:

```shell
cd mmdetection
python demo/image_demo.py data/cat/images/IMG_20211205_120756.jpg configs/grounding_dino/grounding_dino_swin-t_finetune_8xb2_20e_cat.py --weights cat_work_dir/best_coco_bbox_mAP_epoch_16.pth --texts cat.
```

<div align=center>
<img src="https://github.com/open-mmlab/mmdetection/assets/17425982/5a027b00-8adb-4283-a47b-2f7a0a2c96d4" alt="cat dataset"/>
</div>
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
_base_ = 'grounding_dino_swin-t_finetune_16xb2_1x_coco.py'

data_root = 'data/cat/'
class_name = ('cat', )
num_classes = len(class_name)
metainfo = dict(classes=class_name, palette=[(220, 20, 60)])

model = dict(bbox_head=dict(num_classes=num_classes))

train_dataloader = dict(
dataset=dict(
data_root=data_root,
metainfo=metainfo,
ann_file='annotations/trainval.json',
data_prefix=dict(img='images/')))

val_dataloader = dict(
dataset=dict(
metainfo=metainfo,
data_root=data_root,
ann_file='annotations/test.json',
data_prefix=dict(img='images/')))

test_dataloader = val_dataloader

val_evaluator = dict(ann_file=data_root + 'annotations/test.json')
test_evaluator = val_evaluator

max_epoch = 20

default_hooks = dict(
checkpoint=dict(interval=1, max_keep_ckpts=1, save_best='auto'),
logger=dict(type='LoggerHook', interval=5))
train_cfg = dict(max_epochs=max_epoch, val_interval=1)

param_scheduler = [
dict(type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=30),
dict(
type='MultiStepLR',
begin=0,
end=max_epoch,
by_epoch=True,
milestones=[15],
gamma=0.1)
]

optim_wrapper = dict(
optimizer=dict(lr=0.00005),
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'backbone': dict(lr_mult=0.1),
'language_model': dict(lr_mult=0),
}))

auto_scale_lr = dict(base_batch_size=16)

0 comments on commit f14353d

Please sign in to comment.