Skip to content

Commit

Permalink
add detr model (#266)
Browse files Browse the repository at this point in the history
* add detr on GPU

* refine detr on gpu

* modify detr code and upload test data on gpu

* update the format of test data and add detr test case

* update detr test metric

* add gpu 1x1 log for detr

* update 1x1 log

* add detr in test_conf.py

---------

Co-authored-by: wangdongyu04 <wangdongyu04@baidu.com>
  • Loading branch information
TWANG07 and wangdongyu04 authored Oct 13, 2023
1 parent d85f7cd commit 4cbf870
Show file tree
Hide file tree
Showing 31 changed files with 3,262 additions and 0 deletions.
38 changes: 38 additions & 0 deletions training/benchmarks/detr/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
### 模型信息

#### 模型介绍
Unlike traditional computer vision techniques, DETR approaches object detection as a direct set prediction problem. It consists of a set-based global loss, which forces unique predictions via bipartite matching, and a Transformer encoder-decoder architecture. Given a fixed small set of learned object queries, DETR reasons about the relations of the objects and the global image context to directly output the final set of predictions in parallel. Due to this parallel nature, DETR is very fast and efficient.
#### 模型代码来源
| repo | commit_id | date |
| ---- | ---- |---- |
| [detr](https://github.com/facebookresearch/detr) |3af9fa878e73b6894ce3596450a8d9b89d918ca9 |2023-2-7 05:12:31|

The code in this repository is licensed under the Apache License, Version 2.0.

Some of the files in this directory were modified by BAAI in 2023 to support FlagPerf.

### 数据集
#### 数据集下载地址
COCO2017数据集
COCO官网地址:https://cocodataset.org/

#### 预处理

这里以下载coco2017数据集为例,主要下载三个文件:
- 2017 Train images [118K/18GB]:训练过程中使用到的所有图像文件
- 2017 Val images [5K/1GB]:验证过程中使用到的所有图像文件
- 2017 Train/Val annotations [241MB]:对应训练集和验证集的标注json文件
都解压到coco2017文件夹下,可得到如下文件夹结构:

```bash
├── coco2017: # 数据集根目录
    ├── train2017: # 所有训练图像文件夹(118287张)
    ├── val2017: # 所有验证图像文件夹(5000张)
    └── annotations: # 对应标注文件夹
        ├── instances_train2017.json: # 对应目标检测、分割任务的训练集标注文件
        ├── instances_val2017.json: # 对应目标检测、分割任务的验证集标注文件
        ├── captions_train2017.json: # 对应图像描述的训练集标注文件
        ├── captions_val2017.json: # 对应图像描述的验证集标注文件
        ├── person_keypoints_train2017.json: # 对应人体关键点检测的训练集标注文件
        └── person_keypoints_val2017.json: # 对应人体关键点检测的验证集标注文件
```
2 changes: 2 additions & 0 deletions training/benchmarks/detr/pytorch/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from ._base import *
from .mutable_params import mutable_params
109 changes: 109 additions & 0 deletions training/benchmarks/detr/pytorch/config/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Base configuration for the DETR benchmark case.
# Every setting is defined exactly once. Earlier revisions assigned
# ``data_dir`` twice and repeated the whole driver section; the values below
# are the ones that were effective after those re-assignments.

# case info
vendor: str = None
# Effective default is the empty string (the later assignment used to win
# over an initial ``None``).
data_dir: str = ''
name: str = "detr"
cudnn_benchmark: bool = False
cudnn_deterministic: bool = True

# =========================================================
# data
# =========================================================
train_data: str = "train"
eval_data: str = "val"


# =========================================================
# Model
# =========================================================
lr = 1e-4
lr_backbone = 1e-5
lr_drop = 200
weight_decay = 1e-4
clip_max_norm = 0.1
model_name = 'transformer'
backbone = 'resnet50'
dilation = False
position_embedding = 'sine'
enc_layers = 6
dec_layers = 6
dim_feedforward = 2048
hidden_dim = 256
dropout = 0.1
nheads = 8
num_queries = 100
pre_norm = False
masks = False


# -----------------------------------------------------------------------------
# Loss
# -----------------------------------------------------------------------------
# Whether to use auxiliary decoding losses (loss at each layer);
# set to False to disable them.
aux_loss = True
# Class coefficient in the matching cost
set_cost_class = 1
# L1 box coefficient in the matching cost
set_cost_bbox = 5
# giou box coefficient in the matching cost
set_cost_giou = 2
mask_loss_coef = 1
dice_loss_coef = 1
bbox_loss_coef = 5
giou_loss_coef = 2
# Relative classification weight of the no-object class
eos_coef = 0.1


# =========================================================
# train && evaluate
# =========================================================
train_batch_size: int = 128
eval_batch_size: int = 128

target_mAP: float = 0.42

start_epoch = 0
epochs = 300

do_train = True
fp16 = False
distributed: bool = True
warmup = 0.1


# =========================================================
# utils
# =========================================================
dataset_file = 'coco'
output_dir = ''
seed: int = 42
dist_backend: str = 'nccl'
num_workers: int = 1
device: str = None
resume = ''

# =========================================================
# for driver / distributed training parameters
# =========================================================
local_rank: int = -1
use_env: bool = True
log_freq: int = 100
print_freq: int = 100
n_device: int = 1
amp: bool = False
sync_bn: bool = False
gradient_accumulation_steps: int = 1
8 changes: 8 additions & 0 deletions training/benchmarks/detr/pytorch/config/mutable_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Names of configuration fields listed as "mutable".
# NOTE(review): presumably the fields a vendor-specific config is allowed to
# override -- confirm against the code that consumes this list.
mutable_params = [
    # case / data
    'vendor',
    'data_dir',
    'train_data',
    'eval_data',
    # optimisation
    'lr',
    'weight_decay',
    'train_batch_size',
    'eval_batch_size',
    # run control
    'do_train',
    'fp16',
    'distributed',
    'warmup',
    'dist_backend',
    'num_workers',
    'device',
    'cudnn_benchmark',
    'cudnn_deterministic',
    'local_rank',
]
24 changes: 24 additions & 0 deletions training/benchmarks/detr/pytorch/dataloaders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch.utils.data
import torchvision

from .coco import build as build_coco


def get_coco_api_from_dataset(dataset):
    """Return the ``coco`` attribute of the dataset underlying ``dataset``.

    ``torch.utils.data.Subset`` wrappers are peeled off one by one, whatever
    the nesting depth, until a non-``Subset`` dataset is reached.

    Args:
        dataset: a dataset, possibly wrapped in one or more ``Subset`` objects.

    Returns:
        ``dataset.coco`` when the unwrapped dataset is a
        ``torchvision.datasets.CocoDetection``; ``None`` otherwise.
    """
    # Unwrap until the real dataset is reached instead of giving up after a
    # fixed number of levels.
    while isinstance(dataset, torch.utils.data.Subset):
        dataset = dataset.dataset
    if isinstance(dataset, torchvision.datasets.CocoDetection):
        return dataset.coco
    return None


def build_dataset(image_set, args):
    """Build the dataset selected by ``args.dataset_file``.

    Args:
        image_set: split name forwarded to the dataset builder.
        args: configuration object; ``dataset_file`` picks the dataset.

    Returns:
        The result of ``build_coco(image_set, args)`` for ``'coco'``.

    Raises:
        ValueError: for every other ``dataset_file`` value, including
            ``'coco_panoptic'``, which is not implemented yet.
    """
    if args.dataset_file != 'coco':
        # 'coco_panoptic' (segmentation task) may be added in the future; for
        # now it is rejected like any other unknown dataset name.
        raise ValueError(f'dataset {args.dataset_file} not supported')
    return build_coco(image_set, args)
158 changes: 158 additions & 0 deletions training/benchmarks/detr/pytorch/dataloaders/coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
COCO dataset which returns image_id for evaluation.
Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
"""
from pathlib import Path

import torch
import torch.utils.data
import torchvision
from pycocotools import mask as coco_mask

import dataloaders.transforms as T


class CocoDetection(torchvision.datasets.CocoDetection):
    """``torchvision`` COCO detection dataset that also reports the image id.

    Each sample is passed through ``ConvertCocoPolysToMask`` and then through
    the optional ``transforms`` callable.
    """

    def __init__(self, img_folder, ann_file, transforms, return_masks):
        """Store the transform pipeline and create the annotation converter.

        Args:
            img_folder: image directory handed to the torchvision base class.
            ann_file: annotation file handed to the torchvision base class.
            transforms: callable ``(img, target) -> (img, target)`` or ``None``.
            return_masks: forwarded to ``ConvertCocoPolysToMask``.
        """
        super().__init__(img_folder, ann_file)
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)

    def __getitem__(self, idx):
        """Return the ``(image, target)`` pair for sample ``idx``."""
        image, raw_annotations = super().__getitem__(idx)
        # Bundle the image id with the raw annotations so the converter can
        # carry it into the final target.
        sample = {'image_id': self.ids[idx], 'annotations': raw_annotations}
        image, sample = self.prepare(image, sample)
        if self._transforms is not None:
            image, sample = self._transforms(image, sample)
        return image, sample


def convert_coco_poly_to_mask(segmentations, height, width):
    """Decode polygon segmentations into a stack of per-object masks.

    Args:
        segmentations: iterable with one polygon collection per object.
        height: image height passed to the decoder.
        width: image width passed to the decoder.

    Returns:
        The per-object masks stacked along dim 0, or an empty ``uint8``
        tensor of shape ``(0, height, width)`` when there is nothing to decode.
    """
    per_object = []
    for polys in segmentations:
        encoded = coco_mask.frPyObjects(polys, height, width)
        bitmap = coco_mask.decode(encoded)
        if len(bitmap.shape) < 3:
            # Add a trailing axis so the reduction over dim 2 below always works.
            bitmap = bitmap[..., None]
        bitmap = torch.as_tensor(bitmap, dtype=torch.uint8)
        # Collapse the last axis: a pixel is set if any decoded part covers it.
        per_object.append(bitmap.any(dim=2))

    if not per_object:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object, dim=0)


class ConvertCocoPolysToMask(object):
    """Convert raw COCO annotations of one image into a dict of tensors.

    Args:
        return_masks: when true, segmentation polygons are decoded as well
            and returned under ``"masks"``.
    """

    def __init__(self, return_masks=False):
        self.return_masks = return_masks

    def __call__(self, image, target):
        """Build the tensor target for ``image``.

        Args:
            image: object whose ``size`` attribute unpacks as ``(width, height)``.
            target: dict with keys ``"image_id"`` and ``"annotations"``.

        Returns:
            ``(image, new_target)``; ``image`` is returned untouched and
            ``new_target`` holds ``boxes``, ``labels``, ``image_id``, ``area``,
            ``iscrowd``, ``orig_size`` and ``size`` (plus ``masks`` /
            ``keypoints`` when applicable).
        """
        width, height = image.size
        img_id = torch.tensor([target["image_id"]])

        # Crowd annotations are dropped; a missing flag counts as "not crowd".
        objs = [o for o in target["annotations"] if o.get('iscrowd', 0) == 0]

        # reshape(-1, 4) yields a (0, 4) tensor when no annotation is left.
        xyxy = torch.as_tensor([o["bbox"] for o in objs], dtype=torch.float32).reshape(-1, 4)
        xyxy[:, 2:] += xyxy[:, :2]  # add the first two columns onto the last two
        xyxy[:, 0::2].clamp_(min=0, max=width)
        xyxy[:, 1::2].clamp_(min=0, max=height)

        labels = torch.tensor([o["category_id"] for o in objs], dtype=torch.int64)

        masks = None
        if self.return_masks:
            polygons = [o["segmentation"] for o in objs]
            masks = convert_coco_poly_to_mask(polygons, height, width)

        kpts = None
        if objs and "keypoints" in objs[0]:
            kpts = torch.as_tensor([o["keypoints"] for o in objs], dtype=torch.float32)
            n_objs = kpts.shape[0]
            if n_objs:
                kpts = kpts.view(n_objs, -1, 3)

        # Keep only boxes whose columns 2/3 are strictly greater than columns 0/1.
        valid = (xyxy[:, 3] > xyxy[:, 1]) & (xyxy[:, 2] > xyxy[:, 0])

        converted = {"boxes": xyxy[valid], "labels": labels[valid]}
        if self.return_masks:
            converted["masks"] = masks[valid]
        converted["image_id"] = img_id
        if kpts is not None:
            converted["keypoints"] = kpts[valid]

        # for conversion to coco api
        areas = torch.tensor([o["area"] for o in objs])
        crowd_flags = torch.tensor([o.get("iscrowd", 0) for o in objs])
        converted["area"] = areas[valid]
        converted["iscrowd"] = crowd_flags[valid]

        # Two separate tensors on purpose: the entries must not share storage.
        converted["orig_size"] = torch.as_tensor([int(height), int(width)])
        converted["size"] = torch.as_tensor([int(height), int(width)])

        return image, converted


def make_coco_transforms(image_set):
    """Create the transform pipeline for the given split.

    Args:
        image_set: ``'train'`` or ``'val'``.

    Returns:
        A ``T.Compose`` pipeline. For ``'train'`` it is a random horizontal
        flip, then a random choice between a plain multi-scale resize and a
        resize / crop / resize sequence, then tensor conversion and
        normalisation. For ``'val'`` it is a single resize to 800 (max side
        1333) followed by the same tensor conversion and normalisation.

    Raises:
        ValueError: if ``image_set`` is neither ``'train'`` nor ``'val'``.
    """
    to_normalized_tensor = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # 480, 512, ..., 800 in steps of 32.
    train_scales = list(range(480, 801, 32))

    if image_set == 'train':
        flip = T.RandomHorizontalFlip()
        plain_resize = T.RandomResize(train_scales, max_size=1333)
        crop_then_resize = T.Compose([
            T.RandomResize([400, 500, 600]),
            T.RandomSizeCrop(384, 600),
            T.RandomResize(train_scales, max_size=1333),
        ])
        return T.Compose([
            flip,
            T.RandomSelect(plain_resize, crop_then_resize),
            to_normalized_tensor,
        ])

    if image_set == 'val':
        return T.Compose([
            T.RandomResize([800], max_size=1333),
            to_normalized_tensor,
        ])

    raise ValueError(f'unknown {image_set}')


def build(image_set, args):
    """Create the ``CocoDetection`` dataset for one split.

    Args:
        image_set: ``"train"`` or ``"val"``; used as the lookup key for the
            image folder / annotation file pair.
        args: configuration object providing ``data_dir`` (dataset root) and
            ``masks`` (forwarded as ``return_masks``).

    Returns:
        A ``CocoDetection`` reading ``<root>/<split>2017`` and
        ``<root>/annotations/instances_<split>2017.json``.

    Raises:
        AssertionError: if ``args.data_dir`` does not exist.
        KeyError: if ``image_set`` is not a known split.
    """
    root = Path(args.data_dir)
    assert root.exists(), f'provided COCO path {root} does not exist'

    mode = 'instances'
    # Map each split to its (image folder, annotation file) pair.
    split_locations = {
        split: (root / f"{split}2017", root / "annotations" / f'{mode}_{split}2017.json')
        for split in ("train", "val")
    }

    img_folder, ann_file = split_locations[image_set]
    return CocoDetection(
        img_folder,
        ann_file,
        transforms=make_coco_transforms(image_set),
        return_masks=args.masks,
    )
Loading

0 comments on commit 4cbf870

Please sign in to comment.