-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathwiderface2yolo.py
144 lines (111 loc) · 4.81 KB
/
widerface2yolo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# -*- coding: utf-8 -*-
"""
@Time : 2024/6/16 20:09
@File : widerface2yolo.py
@Author : zj
@Description:
Download the WIDERFACE dataset: http://shuoyang1213.me/WIDERFACE/
Download face and keypoint annotations: https://drive.google.com/file/d/1tU_IjyOwGQfGNUvZGwWWM4SwxKp2PUQ8/view?usp=sharing
Usage - convert one WIDERFACE split (images + RetinaFace-style label.txt) to YOLO format, written to DST/images/{train,val} and DST/labels/{train,val}:
$ python3 widerface2yolo.py ../datasets/widerface/WIDER_train/images ../datasets/widerface/retinaface_gt_v1.1/train/label.txt ../datasets/widerface
$ python3 widerface2yolo.py ../datasets/widerface/WIDER_val/images ../datasets/widerface/retinaface_gt_v1.1/val/label.txt ../datasets/widerface
"""
import os
import cv2
import shutil
import argparse
import numpy as np
from tqdm import tqdm
def parse_args():
    """Parse the command line: WIDER FACE image root, label file, output root.

    The parsed namespace is echoed to stdout before being returned.
    """
    cli = argparse.ArgumentParser(description="WiderFace2YOLO")
    positional_specs = (
        ('image', 'IMAGE', 'WiderFace image root.'),
        ('label', 'LABEL', 'WiderFace label path.'),
        ('dst', 'DST', 'YOLOLike data root.'),
    )
    for dest, meta, help_text in positional_specs:
        cli.add_argument(dest, metavar=meta, type=str, help=help_text)
    parsed = cli.parse_args()
    print("args:", parsed)
    return parsed
def load_label(file_path, is_train=True):
    """Parse a RetinaFace-style WIDER FACE ``label.txt`` file.

    The file is a sequence of records. Each record is a header line
    ``# <relative image path>`` followed by zero or more annotation lines,
    one per face. Every annotation line starts with the integer box
    ``x1 y1 w h``. When ``is_train`` is True the box is followed by five
    keypoints stored as three numbers each (only the first two, x and y, are
    kept) and a final confidence value.

    Args:
        file_path: Path to the label file.
        is_train: If True, keypoints and confidence are parsed from each
            annotation line. If False, only the box is read, and every face
            gets five ``(-1.0, -1.0)`` keypoints and confidence ``0.``.

    Returns:
        A list of ``{'image_path': str, 'annotations': list}`` dicts in file
        order. Each annotation is ``{'bbox': [x1, y1, w, h],
        'keypoints': [(x, y), ...] (5 entries), 'confidence': float}``.

    Raises:
        ValueError: if an annotation line appears before any ``#`` header,
            or a training line does not yield exactly five keypoints.
    """
    data = []
    current_image_data = None
    with open(file_path, 'r', encoding='utf-8') as file:
        for line_no, line in enumerate(file, start=1):
            if line.startswith('#'):
                # A header line starts a new image record; flush the previous one.
                if current_image_data is not None:
                    data.append(current_image_data)
                image_path = line[1:].strip()
                current_image_data = {'image_path': image_path, 'annotations': []}
                continue

            # split() (not split(' ')) also absorbs trailing spaces and the newline,
            # which would otherwise become an extra bogus token.
            parts = line.split()
            if not parts:
                # Tolerate blank lines, e.g. a trailing empty line at end of file.
                continue
            if current_image_data is None:
                raise ValueError(
                    f"{file_path}:{line_no}: annotation line before any '#' image header")

            bbox = list(map(int, parts[:4]))
            if is_train:
                # After the 4 box values come keypoints stored as 3 values each;
                # keep (x, y) and skip the third. The last value is the confidence.
                keypoints = [(float(parts[i]), float(parts[i + 1]))
                             for i in range(4, len(parts) - 1, 3)]
                if len(keypoints) != 5:
                    raise ValueError(
                        f"{file_path}:{line_no}: expected 5 keypoints, got "
                        f"{len(keypoints)}: {keypoints}")
                confidence = float(parts[-1])
            else:
                # Placeholder keypoints/confidence when they are not parsed.
                keypoints = [(-1.0, -1.0) for _ in range(5)]
                confidence = 0.
            current_image_data['annotations'].append({
                'bbox': bbox,
                'keypoints': keypoints,
                'confidence': confidence,
            })

    # Flush the last image record.
    if current_image_data is not None:
        data.append(current_image_data)
    return data
def _to_yolo_label(anno, width, height, cls_id):
    """Convert one parsed annotation dict into a YOLO label row.

    The row is ``[cls_id, x_c, y_c, w, h, kx1, ky1, ..., kx5, ky5]``. Box
    centre and size are divided by the image width/height. Each keypoint
    coordinate is divided only when it is positive, so non-positive values
    (such as the ``(-1.0, -1.0)`` placeholder from ``load_label``) are kept
    unchanged.
    """
    x1, y1, box_w, box_h = anno['bbox']
    row = [
        cls_id,
        (x1 + box_w / 2) / width,
        (y1 + box_h / 2) / height,
        box_w / width,
        box_h / height,
    ]
    for x, y in anno['keypoints']:
        row.append(x / width if x > 0 else x)
        row.append(y / height if y > 0 else y)
    return row


def main():
    """Convert one WIDER FACE split to a YOLO-style directory layout.

    Reads the image root and label file from the command line. Each listed
    image is copied to ``<dst>/images/<split>``, and a same-named ``.txt``
    label file is written to ``<dst>/labels/<split>`` with one row per face
    (see ``_to_yolo_label``).

    Raises:
        FileNotFoundError: if the image root, label file, or a listed image
            does not exist.
        ValueError: if OpenCV cannot decode a listed image.
    """
    args = parse_args()
    dst_root = args.dst
    img_root = args.image
    label_path = args.label

    # Validate inputs before creating any output directories. Use real
    # exceptions rather than assert, which is stripped under `python -O`.
    if not os.path.exists(img_root):
        raise FileNotFoundError(img_root)
    if not os.path.exists(label_path):
        raise FileNotFoundError(label_path)

    # NOTE(review): the split is inferred by a substring test on the whole
    # path, so "val" anywhere in it (e.g. a parent directory name) selects
    # the val split.
    is_train = 'val' not in label_path
    split = "train" if is_train else "val"
    dst_img_root = os.path.join(dst_root, "images", split)
    dst_label_root = os.path.join(dst_root, "labels", split)
    os.makedirs(dst_img_root, exist_ok=True)
    os.makedirs(dst_label_root, exist_ok=True)

    cls_id = 0  # single class: face
    print(f"Parse {label_path}")
    results = load_label(label_path, is_train=is_train)
    print(f"Processing {len(results)} images")
    for result in tqdm(results):
        image_path = os.path.join(img_root, result["image_path"])
        if not os.path.isfile(image_path):
            raise FileNotFoundError(image_path)
        # The image is decoded only to obtain its size for normalization.
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports unreadable/corrupt files by returning None.
            raise ValueError(f"Failed to decode image: {image_path}")
        height, width = image.shape[:2]

        labels = [_to_yolo_label(anno, width, height, cls_id)
                  for anno in result['annotations']]

        image_name = os.path.basename(image_path)
        shutil.copy(image_path, os.path.join(dst_img_root, image_name))
        name = os.path.splitext(image_name)[0]
        dst_label_path = os.path.join(dst_label_root, f"{name}.txt")
        # NOTE: np.array() promotes the int class id to float, so with
        # fmt='%s' the class column is written as e.g. "0.0". An image with no
        # annotations yields an empty label file.
        np.savetxt(dst_label_path, np.array(labels), delimiter=' ', fmt='%s')
if __name__ == "__main__":
    # Script entry point: run the conversion for the split given on the CLI.
    main()