Skip to content

Commit

Permalink
Merge pull request #136 from wenh06/master
Browse files Browse the repository at this point in the history
put validation set in use
  • Loading branch information
Tianxiaomo authored Jul 3, 2020
2 parents af00822 + 100edcf commit 74347ac
Show file tree
Hide file tree
Showing 12 changed files with 1,751 additions and 23 deletions.
18 changes: 13 additions & 5 deletions cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@
@Detail :
'''
import os
from easydict import EasyDict


_BASE_DIR = os.path.dirname(os.path.abspath(__file__))

Cfg = EasyDict()

Cfg.use_darknet_cfg = True
Cfg.cfgfile = 'cfg/yolov4.cfg'
Cfg.cfgfile = os.path.join(_BASE_DIR, 'cfg', 'yolov4.cfg')

Cfg.batch = 64
Cfg.subdivisions = 16
Expand Down Expand Up @@ -50,8 +54,8 @@
Cfg.gaussian = 0
Cfg.boxes = 60 # box num
Cfg.TRAIN_EPOCHS = 300
Cfg.train_label = 'data/train.txt'
Cfg.val_label = 'data/val.txt'
Cfg.train_label = os.path.join(_BASE_DIR, 'data', 'train.txt')
Cfg.val_label = os.path.join(_BASE_DIR, 'data' ,'val.txt')
Cfg.TRAIN_OPTIMIZER = 'adam'
'''
image_path1 x1,y1,x2,y2,id x1,y1,x2,y2,id x1,y1,x2,y2,id ...
Expand All @@ -66,5 +70,9 @@
elif Cfg.mosaic:
Cfg.mixup = 3

Cfg.checkpoints = 'checkpoints'
Cfg.TRAIN_TENSORBOARD_DIR = 'log'
Cfg.checkpoints = os.path.join(_BASE_DIR, 'checkpoints')
Cfg.TRAIN_TENSORBOARD_DIR = os.path.join(_BASE_DIR, 'log')

Cfg.iou_type = 'iou' # 'giou', 'diou', 'ciou'

Cfg.keep_checkpoint_max = 10
36 changes: 35 additions & 1 deletion dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def draw_box(img, bboxes):


class Yolo_dataset(Dataset):
def __init__(self, lable_path, cfg):
def __init__(self, lable_path, cfg, train=True):
super(Yolo_dataset, self).__init__()
if cfg.mixup == 2:
print("cutmix=1 - isn't supported for Detector")
Expand All @@ -249,6 +249,7 @@ def __init__(self, lable_path, cfg):
raise

self.cfg = cfg
self.train = train

truth = {}
f = open(lable_path, 'r', encoding='utf-8')
Expand All @@ -264,6 +265,8 @@ def __len__(self):
return len(self.truth.keys())

def __getitem__(self, index):
if not self.train:
return self._get_val_item(index)
img_path = list(self.truth.keys())[index]
bboxes = np.array(self.truth.get(img_path), dtype=np.float)
img_path = os.path.join(self.cfg.dataset_dir, img_path)
Expand Down Expand Up @@ -381,6 +384,37 @@ def __getitem__(self, index):
out_bboxes1[:min(out_bboxes.shape[0], self.cfg.boxes)] = out_bboxes[:min(out_bboxes.shape[0], self.cfg.boxes)]
return out_img, out_bboxes1

def _get_val_item(self, index):
"""
"""
img_path = self.imgs[index]
bboxes_with_cls_id = np.array(self.truth.get(img_path), dtype=np.float)
img = cv2.imread(os.path.join(self.cfg.dataset_dir, img_path))
# img_height, img_width = img.shape[:2]
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# img = cv2.resize(img, (self.cfg.w, self.cfg.h))
# img = torch.from_numpy(img.transpose(2, 0, 1)).float().div(255.0).unsqueeze(0)
num_objs = len(bboxes_with_cls_id)
target = {}
# boxes to coco format
boxes = bboxes_with_cls_id[...,:4]
boxes[..., 2:] = boxes[..., 2:] - boxes[..., :2] # box width, box height
target['boxes'] = torch.as_tensor(boxes, dtype=torch.float32)
target['labels'] = torch.as_tensor(bboxes_with_cls_id[...,-1].flatten(), dtype=torch.int64)
target['image_id'] = torch.tensor([get_image_id(img_path)])
target['area'] = (target['boxes'][:,3])*(target['boxes'][:,2])
target['iscrowd'] = torch.zeros((num_objs,), dtype=torch.int64)
return img, target


def get_image_id(filename:str) -> int:
"""Convert a string to a integer."""
raise NotImplementedError("Create your own 'get_image_id' function")
lv, no = os.path.splitext(os.path.basename(filename))[0].split("_")
lv = lv.replace("level", "")
no = f"{int(no):04d}"
return int(lv+no)


if __name__ == "__main__":
from cfg import Cfg
Expand Down
4 changes: 3 additions & 1 deletion tool/darknet2pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,10 @@ def forward(self, x):

# support route shortcut and reorg
class Darknet(nn.Module):
def __init__(self, cfgfile):
def __init__(self, cfgfile, inference=False):
super(Darknet, self).__init__()
self.inference = inference
self.training = not self.inference

self.blocks = parse_cfg(cfgfile)
self.width = int(self.blocks[0]['width'])
Expand Down
45 changes: 45 additions & 0 deletions tool/tv_reference/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Object detection reference training scripts

This folder contains reference training scripts for object detection.
They serve as a log of how to train specific models, to provide baseline
training and evaluation scripts to quickly bootstrap research.

To execute the example commands below you must install the following:

```
cython
pycocotools
matplotlib
```

You must modify the following flags:

`--data-path=/path/to/coco/dataset`

`--nproc_per_node=<number_of_gpus_available>`

Except otherwise noted, all models have been trained on 8x V100 GPUs.

### Faster R-CNN
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--dataset coco --model fasterrcnn_resnet50_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3
```


### Mask R-CNN
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--dataset coco --model maskrcnn_resnet50_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3
```


### Keypoint R-CNN
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--dataset coco_kp --model keypointrcnn_resnet50_fpn --epochs 46\
--lr-steps 36 43 --aspect-ratio-group-factor 3
```

Loading

0 comments on commit 74347ac

Please sign in to comment.