-
Notifications
You must be signed in to change notification settings - Fork 0
/
dior.py
112 lines (104 loc) · 4.32 KB
/
dior.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# -*-coding:utf-8-*-
import os.path as osp
import warnings
import xml.etree.ElementTree as ET

import mmcv
import numpy as np

from .builder import DATASETS
from .xml_style import XMLDataset
@DATASETS.register_module()
class DIORDataset(XMLDataset):
    """DIOR object-detection dataset with VOC-style XML annotations.

    Files are looked up relative to ``img_prefix``::

        <img_prefix>/JPEGImages/<img_id>.jpg
        <img_prefix>/<ann_subdir>/<img_id>.xml

    Args:
        min_size (float, optional): Boxes whose width or height is below this
            value go to the ``*_ignore`` arrays returned by ``get_ann_info``.
            ``None`` (the default) falls back to 0.5; ``0`` disables the
            filter.
        ann_subdir (str): Annotation directory relative to ``img_prefix``.
        **kwargs: Forwarded unchanged to ``XMLDataset.__init__``.
    """

    # The 20 DIOR categories. The label index of a class is its position in
    # ``self.CLASSES`` (see ``class_to_index`` in ``__init__``).
    CLASSES = ('airplane', 'airport', 'baseballfield', 'basketballcourt',
               'bridge', 'chimney', 'dam', 'Expressway-Service-area',
               'Expressway-toll-station', 'golffield', 'groundtrackfield',
               'harbor', 'overpass', 'ship', 'stadium', 'storagetank',
               'tenniscourt', 'trainstation', 'vehicle', 'windmill')

    def __init__(self, min_size=None,
                 ann_subdir='Annotations/Horizontal_Bounding_Boxes', **kwargs):
        # Stored before calling the base constructor because that constructor
        # presumably invokes load_annotations(), which reads self.ann_subdir
        # through _xml_path_for().
        self.ann_subdir = ann_subdir
        super().__init__(ann_subdir=ann_subdir, **kwargs)
        self.num_classes = len(self.CLASSES)
        self.class_to_index = dict(zip(self.CLASSES, range(self.num_classes)))
        # Honour the caller's threshold; fall back to 0.5 when none is given.
        self.min_size = 0.5 if min_size is None else min_size

    def _xml_path_for(self, img_id):
        """Return the path of the annotation XML file for ``img_id``."""
        return osp.join(self.img_prefix, self.ann_subdir,
                        '{}.xml'.format(img_id))

    def load_annotations(self, ann_file):
        """Build the per-image info list from an image-id list file.

        Args:
            ann_file (str): File with one image id per line, read with
                ``mmcv.list_from_file``.

        Returns:
            list[dict]: One dict per image with keys ``id``, ``filename``
            (``JPEGImages/<id>.jpg``, relative to ``img_prefix``), ``width``
            and ``height`` (taken from the XML ``<size>`` node).
        """
        img_infos = []
        for img_id in mmcv.list_from_file(ann_file):
            # Parse errors propagate; there is no meaningful fallback here.
            root = ET.parse(self._xml_path_for(img_id)).getroot()
            size = root.find('size')
            img_infos.append(
                dict(
                    id=img_id,
                    filename='JPEGImages/{}.jpg'.format(img_id),
                    width=int(size.find('width').text),
                    height=int(size.find('height').text)))
        return img_infos

    @staticmethod
    def _read_bbox(obj):
        """Return ``[xmin, ymin, xmax, ymax]`` for one ``<object>`` node.

        ``<bndbox>`` is preferred and ``<robndbox>`` is the fallback; both
        are read through their ``xmin``/``ymin``/``xmax``/``ymax`` children.
        Returns ``None`` when the object has neither node.
        """
        for tag in ('bndbox', 'robndbox'):
            box = obj.find(tag)
            if box is not None:
                return [float(box.find(key).text)
                        for key in ('xmin', 'ymin', 'xmax', 'ymax')]
        return None

    @staticmethod
    def _pack(bboxes, labels):
        """Convert box/label lists to ``(N, 4)`` float32 / ``(N,)`` int64.

        Coordinates are shifted by -1. This presumably converts 1-based
        VOC-style pixel indices to 0-based ones; confirm against the DIOR
        annotation convention.
        """
        if not bboxes:
            return (np.zeros((0, 4), dtype=np.float32),
                    np.zeros((0,), dtype=np.int64))
        return ((np.array(bboxes, ndmin=2) - 1).astype(np.float32),
                np.array(labels).astype(np.int64))

    def get_ann_info(self, idx):
        """Parse the XML annotation of image ``idx``.

        Args:
            idx (int): Index into ``self.data_infos``.

        Returns:
            dict: ``bboxes`` (N, 4) float32 and ``labels`` (N,) int64, plus
            ``bboxes_ignore`` / ``labels_ignore`` holding the boxes whose
            width or height is below ``self.min_size``.

        Raises:
            KeyError: If an object's ``<name>`` is not in ``self.CLASSES``.
        """
        img_id = self.data_infos[idx]['id']
        root = ET.parse(self._xml_path_for(img_id)).getroot()
        kept_boxes, kept_labels = [], []
        ignored_boxes, ignored_labels = [], []
        for obj in root.findall('object'):
            label = self.class_to_index[obj.find('name').text]
            bbox = self._read_bbox(obj)
            if bbox is None:
                # Skip only this malformed object so the image's remaining
                # objects are still loaded.
                warnings.warn('annotation error: {}.xml (object without '
                              'bndbox/robndbox skipped)'.format(img_id))
                continue
            too_small = bool(self.min_size) and (
                bbox[2] - bbox[0] < self.min_size
                or bbox[3] - bbox[1] < self.min_size)
            if too_small:
                ignored_boxes.append(bbox)
                ignored_labels.append(label)
            else:
                kept_boxes.append(bbox)
                kept_labels.append(label)
        bboxes, labels = self._pack(kept_boxes, kept_labels)
        bboxes_ignore, labels_ignore = self._pack(ignored_boxes,
                                                  ignored_labels)
        return dict(
            bboxes=bboxes,
            labels=labels,
            bboxes_ignore=bboxes_ignore,
            labels_ignore=labels_ignore)