forked from ouening/OD_dataset_conversion_scripts
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvoc2yolo.py
325 lines (273 loc) · 11.7 KB
/
voc2yolo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
'''
PASCAL VOC格式数据集转YOLO格式数据集
适合项目地址:
1. https://github.com/eriklindernoren/PyTorch-YOLOv3
2. https://github.com/ultralytics/yolov3/
3. https://github.com/AlexeyAB
4. https://github.com/ultralytics/yolov5/
该项目对自定义的数据集格式要求图片要有对应的txt格式标注文件,要求图片存放在images文件夹,标签存放在labels文件夹,例如:
data/custom/images/train.jpg
data/custom/labels/train.txt
yolo_classes.names
yolo_classes_ssd.names
trainval.txt
train.txt
val.txt
当然,images文件夹和labels这两个文件夹名称可以更改,但相应的也要在代码中做修改(PyTorch-YOLOV3项目):
```utils/datasets.py: line 65
class ListDataset(Dataset):
def __init__(self, list_path, img_size=416, augment=True, multiscale=True, normalized_labels=True):
with open(list_path, "r") as file:
self.img_files = file.readlines()
self.label_files = [
path.replace("images", "labels").replace(".png", ".txt").replace(".jpg", ".txt")
## ^^^^^^ and ^^^^^^ 修改这两处的值
for path in self.img_files
]
self.img_size = img_size
self.max_objects = 100
self.augment = augment
...
```
labels/train.txt的标注信息格式为:
label_idx x_center y_center width height(归一化数值)
label_idx x_center y_center width height(归一化数值)
...
trainval.txt,val.txt,test.txt文件每一行记录了图像数据所在的全路径,这几个文件和yolo_classes.names
会在U版和A版的YOLOv3/v4系列的*.data配置文件中使用。在U版的yolov5模型中,数据配置文件保存在data/*.yaml文件中,其示例内容如下:
```
# train and val data as
# 1) directory: path/images/,
# 2) file: path/images.txt, or
# 3) list: [path1/images/, path2/images/]
train: /data/custom_yolo/trainval.txt
val: /data/custom_yolo/test.txt
# number of classes
nc: 2
# class names
names: ['person', 'bicycle']
```
'''
import xml.etree.ElementTree as ET
import pickle
import os
from os import listdir, getcwd
from os.path import join
import pandas as pd
import numpy as np
from collections import Counter
import argparse
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import sys
import shutil
from pathlib import Path
def counting_labels(anno_root,yolo_root):
'''
获取pascal voc格式数据集中的所有标签名
anno_root: pascal标注文件路径,一般为Annotations
'''
all_classes = []
for xml_file in os.listdir(anno_root):
xml_file = os.path.join(anno_root, xml_file)
# print(xml_file)
xml = open(xml_file,encoding='utf-8')
tree=ET.parse(xml)
root = tree.getroot()
for obj in root.iter('object'):
class_ = obj.find('name').text.strip()
all_classes.append(class_)
print(Counter(all_classes))
labels = list(set(all_classes))
print('标签数据:', labels)
print('标签长度:', len(labels))
print('写入标签信息...{}'.format(os.path.join(yolo_root,'yolo_classes.names')))
with open( os.path.join(yolo_root,'yolo_classes.names') , 'w') as f:
for k in labels:
f.write(k)
f.write('\n')
with open( os.path.join(yolo_root,'yolo_classes_ssd.names') , 'w') as f:
for k in labels:
f.write("\'"+k+"\'"+',')
f.write('\n')
return labels
def convert(size, box):
dw = 1./(size[0]) # 宽度缩放比例, size[0]为图像宽度width
dh = 1./(size[1])
x = (box[0] + box[1])/2.0 - 1
y = (box[2] + box[3])/2.0 - 1
w = box[1] - box[0]
h = box[3] - box[2]
x = x*dw
w = w*dw
y = y*dh
h = h*dh
return (x,y,w,h) # <x_center> <y_center> <width> <height>
def convert_annotation(anno_root:str, image_id, classes, dest_yolo_dir='YOLOLabels'):
'''
anno_root:pascal格式标注文件路径,一般为Annotations
image_id:文件名(图片名和对应的pascal voc格式标注文件名是一致的)
dest_yolo_dir:yolo格式标注信息目标保存路径,默认为opt.yolo_dir
'''
in_file = open( os.path.join(anno_root, image_id+'.xml'), encoding='utf-8')
out_file = open(os.path.join(dest_yolo_dir, image_id+'.txt'), 'w')
tree=ET.parse(in_file)
root = tree.getroot()
size = root.find('size')
w = int(size.find('width').text)
h = int(size.find('height').text)
for obj in root.iter('object'):
difficult = obj.find('difficult').text
cls = obj.find('name').text
if cls not in classes or int(difficult)==1:
continue
cls_id = classes.index(cls)
xmlbox = obj.find('bndbox')
b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
bb = convert((w,h), b)
out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
def gen_image_ids(jpeg_root):
'''
jpeg_root: JPEGImages文件夹路径
'''
img_ids = []
for k in os.listdir(jpeg_root):
img_ids.append(k) # 图片名,含后缀
return img_ids
def create_dir(ROOT:str):
if not os.path.exists(ROOT):
os.mkdir(ROOT)
else:
shutil.rmtree(ROOT) # 先删除,再创建
os.mkdir(ROOT)
def check_files(ann_root, img_root):
'''检测图像名称和xml标准文件名称是否一致,检查图像后缀'''
if os.path.exists(ann_root):
ann = Path(ann_root)
else:
raise Exception("标注文件路径错误")
if os.path.exists(img_root):
img = Path(img_root)
else:
raise Exception("图像文件路径错误")
ann_files = []
img_files = []
img_exts = []
for an, im in zip(ann.iterdir(),img.iterdir()):
ann_files.append(an.stem)
img_files.append(im.stem)
img_exts.append(im.suffix)
print('图像后缀列表:', np.unique(img_exts))
if len(np.unique(img_exts)) > 1:
# print('数据集包含多种格式图像,请检查!', np.unique(img_exts))
raise Exception('数据集包含多种格式图像,请检查!', np.unique(img_exts))
if set(ann_files)==set(img_files):
print('标注文件和图像文件匹配')
else:
print('标注文件和图像文件不匹配')
return np.unique(img_exts)[0]
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--voc-root', type=str, required=True,
help='VOC格式数据集根目录,该目录下必须包含存储图像和标注文件的两个文件夹')
parser.add_argument('--img_dir', type=str, required=False,
help='VOC格式数据集图像存储路径,如果不指定,默认为JPEGImages')
parser.add_argument('--anno_dir', type=str, required=False,
help='VOC格式数据集标注文件存储路径,如果不指定,默认为Annotations')
parser.add_argument('--yolo-dir',type=str, default='YOLODataset',
help='yolo格式数据集保存路径,默认为VOC数据集相同路径下新建文件夹YOLODataset')
parser.add_argument('--valid-ratio',type=float, default=0.3,
help='验证集比例,默认为0.3')
opt = parser.parse_args()
voc_root = opt.voc_root
print('Pascal VOC格式数据集路径:', voc_root)
if opt.img_dir is None:
img_dir = 'JPEGImages'
else:
img_dir = opt.img_dir
jpeg_root = os.path.join(voc_root, img_dir)
if not os.path.exists(jpeg_root):
raise Exception(f'数据集图像路径{jpeg_root}不存在!')
if opt.anno_dir is None:
anno_dir = 'Annotations'
else:
anno_dir = opt.anno_dir
anno_root = os.path.join(voc_root,anno_dir)
if not os.path.exists(anno_root):
raise Exception(f'数据集图像路径{anno_root}不存在!')
# 确定图像后缀
ext = check_files(anno_root, jpeg_root)
assert ext is not None, "请检查图像后缀是否正确!"
# YOLO数据集存储路径
dest_yolo_dir = os.path.join(str(Path(voc_root).parent), opt.yolo_dir)
#
image_ids = gen_image_ids(jpeg_root)
print('数据集长度:', len(image_ids))
if not os.path.exists(dest_yolo_dir):
os.makedirs(dest_yolo_dir) # 创建labels文件夹,存储yolo格式标注文件
yolo_labels = os.path.join(dest_yolo_dir,'labels')
create_dir(yolo_labels)
yolo_images = os.path.join(dest_yolo_dir,'images')
create_dir(yolo_images)
classes = counting_labels(anno_root,dest_yolo_dir)
images_path = [] # 图片的绝对路径
length = len(image_ids)
for idx, img in enumerate(image_ids):
sys.stdout.write('\r>> Converting image %d/%d' % (
idx + 1, length))
sys.stdout.flush()
# print('图片名称:', os.path.join(pwd, 'JPEGImages', img)) #
images_path.append(os.path.join(voc_root, 'JPEGImages', img))
image_id = img[:-4] # 图像名称
# print('图像名称:', image_id)
# 转换标签
convert_annotation(anno_root, image_id, classes, dest_yolo_dir=yolo_labels)
shutil.copy(os.path.join(voc_root, 'JPEGImages', img), yolo_images)
## 生成用于config/custom.data指定的训练训练集和验证集文件yolo_train.txt和yolo_valid.txt
# 该文件的内容就是每行为图片数据在文件系统中的绝对路径
ratio = opt.valid_ratio # 验证集比例
def write_txt(txt_path, data):
'''写入txt文件'''
with open(txt_path,'w') as f:
for d in data:
f.write(str(d))
f.write('\n')
if os.path.exists(os.path.join(voc_root, 'ImageSets/Main/trainval.txt')):
print('\n使用ImageSet信息分割数据集')
trainval_file = os.path.join(voc_root, 'ImageSets/Main/trainval.txt')
trainval_name = [i.strip() for i in open(trainval_file,'r').readlines()]
trainval = [os.path.join(yolo_images,name+ext) for name in trainval_name]
train_file = os.path.join(voc_root, 'ImageSets/Main/train.txt')
train_name = [i.strip() for i in open(train_file,'r').readlines()]
train = [os.path.join(yolo_images,name+ext) for name in train_name]
val_file = os.path.join(voc_root, 'ImageSets/Main/val.txt')
val_name = [i.strip() for i in open(val_file,'r').readlines()]
val = [os.path.join(yolo_images,name+ext) for name in val_name]
test_file = os.path.join(voc_root, 'ImageSets/Main/test.txt')
test_name = [i.strip() for i in open(test_file,'r').readlines()]
test = [os.path.join(yolo_images,name+ext) for name in test_name]
print('训练集数量: ',len(train_name))
print('训练集验证集数量: ',len(trainval_name))
print('验证集数量: ',len(val_name))
print('测试集数量: ',len(test_name))
else:
print('\n使用YOLO格式图像信息分割数据集')
p = Path(yolo_images)
files = []
for file in p.iterdir():
name,sufix = file.name.split('.')
files.append(str(file))
trainval, test = train_test_split(files, test_size=ratio)
train, val = train_test_split(trainval,test_size=0.2)
print('训练集数量: ',len(train))
print('验证集数量: ',len(val))
print('测试集数量: ',len(test))
# 写入各个txt文件
trainval_txt = os.path.join(dest_yolo_dir,'trainval.txt')
write_txt(trainval_txt, trainval)
train_txt = os.path.join(dest_yolo_dir,'train.txt')
write_txt(train_txt, train)
val_txt = os.path.join(dest_yolo_dir,'val.txt')
write_txt(val_txt, val)
test_txt = os.path.join(dest_yolo_dir,'test.txt')
write_txt(test_txt, test)