-
Notifications
You must be signed in to change notification settings - Fork 32
/
teacher.py
82 lines (67 loc) · 2.91 KB
/
teacher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from utils.torch_utils import select_device
from utils.general import non_max_suppression, scale_coords, xyxy2xywh
import numpy as np
import torch
from utils.torch_utils import select_device
class TeacherModel(object):
    """Wrap a pretrained detector checkpoint and turn its predictions into
    soft targets for distillation.

    Call :meth:`init_model` to load the weights, then
    :meth:`generate_batch_targets` to run the model on a batch of images.

    Args:
        conf_thres: Confidence threshold forwarded to ``non_max_suppression``.
        iou_thres: IoU threshold forwarded to ``non_max_suppression``.
        imgsz: Stored on the instance; not read by any method of this class.
        training: If True the model is put in train mode and
            ``generate_batch_targets`` returns the raw model output instead
            of NMS-filtered targets.
    """

    def __init__(self, conf_thres=0.5, iou_thres=0.3, imgsz=640, training=False):
        self.model = None
        self.device = None
        self.conf_thres = conf_thres
        self.iou_thres = iou_thres
        self.imgsz = imgsz
        self.training = training
        # Filled in by init_model(); declared here so the attributes always exist.
        self.stride = None
        self.nc = None

    def init_model(self, weights, device, nc):
        """Load the checkpoint at ``weights`` onto ``device``.

        Args:
            weights: Path readable by ``torch.load``. It may hold either a dict
                with the module under ``"model"`` or the pickled module itself.
            device: Device spec passed to ``select_device`` (e.g. ``'0'``,
                ``'cpu'``).
            nc: Number of classes. ``generate_batch_targets`` reads the last
                ``nc`` columns of every NMS detection as class scores.
        """
        device = select_device(device)
        # SECURITY: torch.load unpickles arbitrary objects; only load trusted
        # checkpoints. NOTE(review): torch>=2.6 defaults to weights_only=True,
        # which rejects pickled modules -- may need weights_only=False there.
        t_model = torch.load(weights, map_location=torch.device('cpu'))
        # nn.Module has no .get(), so only unwrap genuine dict checkpoints.
        if isinstance(t_model, dict) and t_model.get("model") is not None:
            t_model = t_model["model"]
        t_model.to(device)
        # Cast parameters to float32 (presumably checkpoints may be saved fp16).
        t_model.float()
        self.model = t_model
        self.device = device
        if self.training:
            self.model.train()
        else:
            self.model.eval()
        self.stride = int(self.model.stride.max())
        self.nc = nc

    def _stack_targets(self, rows):
        """Stack target rows into a float32 tensor of shape ``(len(rows), 6 + nc)``.

        An empty ``rows`` gives a ``(0, 6 + nc)`` tensor rather than a 1-D
        ``(0,)`` one, so callers can still slice columns (``t[:, 0]``).
        """
        if not rows:
            return torch.zeros((0, 6 + self.nc), dtype=torch.float32)
        return torch.tensor(rows, dtype=torch.float32)

    def generate_batch_targets(self, imgs, tar_size=(640, 640)):
        """Run the teacher on ``imgs`` and build per-box distillation targets.

        Args:
            imgs: Batched image tensor. Only ``imgs.shape[0]`` (batch size) and
                ``imgs.shape[2:]`` (presumably H, W) are read here.
            tar_size: ``(height, width)`` that boxes are rescaled to and then
                normalized by. Lists are accepted as well.

        Returns:
            In eval mode, ``(targets, None)``. ``targets`` is a float32 tensor
            with one row per kept detection:
            ``[img_id, class_id, x, y, w, h, logit_0, ..., logit_{nc-1}]``,
            where xywh is normalized by ``tar_size``. Its shape is
            ``(0, 6 + nc)`` when nothing survives NMS.
            In training mode, ``([], preds)`` with the raw model output.
        """
        with torch.no_grad():
            preds = self.model(imgs)
        if self.training:
            return [], preds
        # NOTE(review): in eval mode the model is assumed to return a tuple
        # whose first item is the inference output (YOLOv5 Detect) -- confirm.
        preds = preds[0]

        img_shape = imgs.shape[2:]
        # [w, h, w, h] gain that maps pixel xywh into the 0-1 range.
        gn = torch.tensor(tar_size)[[1, 0, 1, 0]]
        rows = []
        for img_id in range(imgs.shape[0]):
            # NMS one image at a time so the image index is known below.
            # distill=True is a project extension; the code below assumes it
            # keeps the class column at index 5 and class scores in the last
            # nc columns of each detection -- confirm against utils.general.
            pred = non_max_suppression(
                preds[img_id:img_id + 1], self.conf_thres, self.iou_thres,
                distill=True, agnostic=False)
            for det in pred:  # one entry per image in the size-1 batch
                if not len(det):
                    continue
                # Rescale boxes from the network input size to tar_size.
                det[:, :4] = scale_coords(img_shape, det[:, :4], tar_size).round()
                for value in reversed(det):
                    xyxy, cls_id = value[:4], value[5]
                    # Convert scores to logits (inverse sigmoid).
                    logits = value[-self.nc:].logit().tolist()
                    box = xyxy.detach().cpu().clone().view(1, 4)
                    xywh = (xyxy2xywh(box) / gn).view(-1).tolist()
                    rows.append([img_id, int(cls_id), *xywh, *logits])
        return self._stack_targets(rows), None
if __name__ == '__main__':
    # Smoke test: load a VOC-trained teacher and run it on random input.
    teacher = TeacherModel(conf_thres=0.0001)
    # init_model takes (weights, device, nc) and calls select_device itself,
    # so pass the raw device string. The VOC dataset has 20 classes.
    teacher.init_model('weights/yolov5m-voc.pt', '0', 20)
    imgs = torch.rand((2, 3, 640, 640)).to(teacher.device)
    targets, _ = teacher.generate_batch_targets(imgs)
    print(targets.shape)