forked from Megvii-BaseDetection/YOLOX
-
Notifications
You must be signed in to change notification settings - Fork 25
/
yolov3.py
89 lines (75 loc) · 2.94 KB
/
yolov3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import os
import torch
import torch.nn as nn
from yolox.exp import Exp as MyExp
class Exp(MyExp):
    """Experiment definition pairing a YOLOFPN backbone with a YOLOXHead.

    Extends ``MyExp`` with fixed depth/width multipliers, a lazily built
    model, and a training data loader built from ``COCODataset`` wrapped in
    ``MosaicDetection``.

    Attributes such as ``num_classes``, ``input_size``, ``train_ann``,
    ``seed``, ``data_num_workers`` and the augmentation settings are read
    here but not defined in this class; presumably they are provided by
    ``MyExp`` -- confirm against the base class.
    """

    def __init__(self):
        """Initialise the base experiment, then set this experiment's fields."""
        super().__init__()
        self.depth = 1.0
        self.width = 1.0
        # Name the experiment after this file: basename up to the first dot.
        script_name = os.path.basename(os.path.realpath(__file__))
        self.exp_name = script_name.split(".")[0]

    def get_model(self, sublinear=False):
        """Return the model, constructing it on first use.

        Args:
            sublinear: Unused in this method.

        Returns:
            The ``YOLOX`` instance stored on ``self.model``.
        """

        def retune_batchnorm(root):
            # Override eps/momentum on every BatchNorm2d found under ``root``.
            batchnorm_layers = (
                layer
                for layer in root.modules()
                if isinstance(layer, nn.BatchNorm2d)
            )
            for layer in batchnorm_layers:
                layer.eps = 1e-3
                layer.momentum = 0.03

        # Build only when no ``model`` attribute exists on this instance yet.
        if "model" not in vars(self):
            from yolox.models import YOLOX, YOLOFPN, YOLOXHead

            # Backbone is constructed before the head, then both are wrapped.
            self.model = YOLOX(
                YOLOFPN(),
                YOLOXHead(
                    self.num_classes,
                    self.width,
                    in_channels=[128, 256, 512],
                    act="lrelu",
                ),
            )

        # These two steps run on every call, not only right after construction.
        self.model.apply(retune_batchnorm)
        self.model.head.initialize_biases(1e-2)
        return self.model

    def get_data_loader(self, batch_size, is_distributed, no_aug=False):
        """Build the training data loader and store its dataset on ``self``.

        Args:
            batch_size: Requested batch size; when ``is_distributed`` is true
                it is floor-divided by the distributed world size.
            is_distributed: When true, use ``InfiniteSampler`` and the
                per-process batch size; otherwise use a ``RandomSampler``.
            no_aug: When true, mosaic is switched off in both the dataset
                wrapper and the batch sampler.

        Returns:
            A ``DataLoader`` over ``self.dataset`` driven by a
            ``YoloBatchSampler``.
        """
        from data.datasets.cocodataset import COCODataset
        from data.datasets.mosaicdetection import MosaicDetection
        from data.datasets.data_augment import TrainTransform
        from data.datasets.dataloading import YoloBatchSampler, DataLoader, InfiniteSampler
        import torch.distributed as dist

        use_mosaic = not no_aug

        def make_transform(label_cap):
            # Both transforms share the same mean/std; only max_labels differs.
            return TrainTransform(
                rgb_means=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                max_labels=label_cap,
            )

        # Base dataset uses a label cap of 50; the mosaic wrapper uses 120.
        coco = COCODataset(
            data_dir="data/COCO/",
            json_file=self.train_ann,
            img_size=self.input_size,
            preproc=make_transform(50),
        )
        self.dataset = MosaicDetection(
            coco,
            mosaic=use_mosaic,
            img_size=self.input_size,
            preproc=make_transform(120),
            degrees=self.degrees,
            translate=self.translate,
            scale=self.scale,
            shear=self.shear,
            perspective=self.perspective,
        )

        if not is_distributed:
            per_process_batch = batch_size
            sampler = torch.utils.data.RandomSampler(self.dataset)
        else:
            # Split the requested batch across processes; seed falls back to 0.
            per_process_batch = batch_size // dist.get_world_size()
            sampler = InfiniteSampler(len(self.dataset), seed=self.seed or 0)

        batch_sampler = YoloBatchSampler(
            sampler=sampler,
            batch_size=per_process_batch,
            drop_last=False,
            input_dimension=self.input_size,
            mosaic=use_mosaic,
        )

        return DataLoader(
            self.dataset,
            num_workers=self.data_num_workers,
            pin_memory=True,
            batch_sampler=batch_sampler,
        )