Skip to content

Commit

Permalink
add train and test script
Browse files Browse the repository at this point in the history
  • Loading branch information
wangke committed Dec 7, 2019
1 parent 9841c1b commit 4d80355
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 36 deletions.
4 changes: 3 additions & 1 deletion tools/.gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
*.pth
*.log
*.log
*.json
imagesWithBox/
118 changes: 99 additions & 19 deletions tools/getTrainDataSet.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,108 @@
import cv2
import tkinter as tk
import json
import json, glob
import os

window = tk.Tk()
person = {0: 'wk', 1: 'dln', 2: 'kml', 3: 'zsy', 4: 'hh', 5: 'wly', 6: 'lch', 7: 'wzh', 8: 'ys'}

window.title('Make DataSet')
window.geometry('900x300')

person = {0: 'wk', 1: 'dln', 2: 'kml', 3: 'zsy', 4: 'hh', 5: 'wly', 6: 'lch', 7: 'wzh', 8: 'ys'}
def main():
window = tk.Tk()

window.title('Make DataSet')
window.geometry('900x300')

imagesWithBox = './imagesWithBox'
imagesList = os.listdir(imagesWithBox)

labelsJson = './labels.json'

if not os.path.exists(labelsJson):
labels = {}
else:
f = open(labelsJson, 'r')
labels = json.load(f)

needHandleList = []

for path in imagesList:
name = path.replace('.png', '')
if name not in labels.keys():
needHandleList.append(path)

if len(needHandleList) == 0:
exit(0)

index = 0

l = [0] * 12

img = cv2.imread(imagesWithBox + '/' + needHandleList[index])
cv2.imshow('Image', img)

content = tk.StringVar()
content.set("There are %d images need to handle" % len(needHandleList))
w = tk.Label(window, textvariable=content)
w.pack()


def onClick(data):
    """Handle a button press from the labelling window.

    ``data`` is one of the action names ``'come in'``, ``'come out'`` or
    ``'stay'``, the command ``'save'``, or a key of ``person`` (an int).

    The working label ``l`` has 12 slots: slots 0-2 hold a one-hot action
    (come in / come out / stay) and the slots from 3 onward are per-person
    flags that each person button toggles.  ``'save'`` stores the working
    label for the current image in ``labels``, dumps ``labels`` to
    ``labelsJson`` and moves on to the next entry of ``needHandleList``.
    """
    # NOTE(review): ``index`` and ``l`` are rebound below.  ``global`` only
    # works when both live at module level; if this function is nested inside
    # ``main()`` (indentation is not visible here) it has to be ``nonlocal``
    # instead -- confirm against the real file.
    global index, l
    if data == 'save':
        if index >= len(needHandleList):
            content.set("Finished!")
            return
        # A label is only complete with an action and at least one person.
        if l[:3] == [0, 0, 0] or 1 not in l[3:]:
            return
        labels[needHandleList[index].replace('.png', '')] = l
        # Persist after every image so an interrupted session loses nothing.
        with open(labelsJson, 'w') as file:
            json.dump(labels, file)

        index += 1
        if index >= len(needHandleList):
            # Last image saved: report completion right away instead of
            # leaving the stale "1 left" status until save is pressed again.
            content.set("Finished!")
            return
        # Fresh list: the saved one is now referenced from ``labels``.
        l = [0] * 12
        img = cv2.imread(imagesWithBox + '/' + needHandleList[index])
        cv2.imshow('Image', img)
    elif data == 'come in':
        l[:3] = [1, 0, 0]
    elif data == 'come out':
        l[:3] = [0, 1, 0]
    elif data == 'stay':
        l[:3] = [0, 0, 1]
    else:
        # Person buttons toggle their own flag.
        slot = 3 + int(data)
        l[slot] = 1 if l[slot] == 0 else 0
    if index < len(needHandleList):
        content.set(
            needHandleList[index].replace('.png', '') + " " + str(l) + '%d left' % (len(needHandleList) - index))

def onClick(data):
print(data)

b = tk.Button(window, text='come in', font=('Arial', 12), width=10, height=2, command=lambda: onClick('come in'))
b.pack()
b = tk.Button(window, text='come out', font=('Arial', 12), width=10, height=2, command=lambda: onClick('come out'))
b.pack()
b = tk.Button(window, text='stay', font=('Arial', 12), width=10, height=2, command=lambda: onClick('stay'))
b.pack()
b = tk.Button(window, text='save', font=('Arial', 12), width=10, height=2, command=lambda: onClick('save'))
b.pack()
for p in person:
b = tk.Button(window, text=person[p], font=('Arial', 12), width=10, height=2, command=lambda p=p: onClick(p))
b.pack(side='left')

b = tk.Button(window, text='come in', font=('Arial', 12), width=10, height=2, command=lambda: onClick('come in'))
b.pack()
b = tk.Button(window, text='come out', font=('Arial', 12), width=10, height=2, command=lambda: onClick('come out'))
b.pack()
b = tk.Button(window, text='stay', font=('Arial', 12), width=10, height=2, command=lambda: onClick('stay'))
b.pack()
b = tk.Button(window, text='save', font=('Arial', 12), width=10, height=2, command=lambda: onClick('save'))
b.pack()
for p in person:
b = tk.Button(window, text=person[p], font=('Arial', 12), width=10, height=2, command=lambda: onClick(p))
b.pack(side='left')
window.mainloop()

window.mainloop()
if __name__ == '__main__':
main()
8 changes: 7 additions & 1 deletion tools/imagesDrawBoxes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import cv2
import sys
import sys, os

sys.path.append('..')
import config
Expand All @@ -16,6 +16,12 @@

imagesPath = '../images'
imagesWithBox = './imagesWithBox'

# Make sure the output directory exists.  An already existing directory is
# fine; any other failure (e.g. missing permissions) is raised here instead of
# being silently swallowed and surfacing later when images are written.
os.makedirs(imagesWithBox, exist_ok=True)

pngList = glob.glob(imagesPath + '/*.png')
savedList = glob.glob(imagesWithBox + '/*.png')
savedListSet = set()
Expand Down
72 changes: 72 additions & 0 deletions tools/testModel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import torch
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision import models
import time, glob
import cv2
import numpy as np
from model import ResNet50
import json
from getTrainDataSet import person

# Number of outputs of the network; matches the 12-slot label lists stored in
# labels.json.
OUT_DIM = 12

# Backbone initialised from the downloaded torchvision ResNet-50 checkpoint.
resnet = models.resnet50()
resnet.load_state_dict(torch.load('./resnet50-19c8e357.pth'))

net = ResNet50(resnet, OUT_DIM)

# Restore the weights written by the training script.
net.load_state_dict(torch.load('models/latestModel.pth'))
# This script only runs inference: switch layers such as BatchNorm and
# Dropout to evaluation mode so single-image predictions are deterministic.
net.eval()

trainDataPath = './imagesWithBox'
allTrainDataList = glob.glob(trainDataPath + '/*.png')

# Ground-truth labels keyed by image name; close the file once it is parsed.
with open('./labels.json', 'r') as f:
    labels = json.load(f)

# Same preprocessing as used for training.
transform = Compose([ToTensor(), Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])


def getStr(data):
    """Translate a label slot index into its display text.

    Slots 0, 1 and 2 map to the action texts ``'come in, '``,
    ``'come out, '`` and ``'stay, '``.  Every other slot is looked up in
    ``person`` after subtracting the three action slots.
    """
    shifted = data - 3
    actionText = {-3: 'come in, ', -2: 'come out, ', -1: 'stay, '}
    if shifted in actionText:
        return actionText[shifted]
    return person[shifted]


# Show every image with its ground-truth (GT) and predicted (PD) labels drawn
# on top; a key press advances to the next image.
for path in allTrainDataList:
    # Label keys are the file names without directory and extension.
    name = path.replace(trainDataPath, '').replace('.png', '')[1:]
    imgO = cv2.imread(path)
    img = cv2.cvtColor(imgO, cv2.COLOR_BGR2RGB)
    img = transform(img).unsqueeze_(0)

    t1 = time.time()
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        predLabel = net(img)
    print('use %.4f' % (time.time() - t1))

    gtLable = labels[name]
    predictions = predLabel.detach().cpu().numpy()[0]
    # Keep the outputs that lie above the 90th percentile of this prediction.
    idx_list = np.where(predictions > np.percentile(predictions, 90))[0]

    gtStr = 'GT ' + ''.join(getStr(i) + ' ' for i, flag in enumerate(gtLable) if flag == 1)
    print(gtStr)
    predStr = 'PD ' + ''.join(getStr(i) + ' ' for i in idx_list)
    print(predStr)

    imgO = cv2.putText(imgO, gtStr, (10, 20), cv2.FONT_HERSHEY_SIMPLEX,
                       0.7, (0, 0, 0), 1, cv2.LINE_AA)
    imgO = cv2.putText(imgO, predStr, (10, 50), cv2.FONT_HERSHEY_SIMPLEX,
                       0.7, (0, 0, 0), 1, cv2.LINE_AA)
    cv2.imshow("image", imgO)
    cv2.waitKey(0)

64 changes: 49 additions & 15 deletions tools/trainModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision import models
import torch.utils.data as data
import time, glob
import time, glob, math
import cv2
from sklearn.metrics import average_precision_score
import numpy as np
import os
from model import ResNet50
import sys, shutil
import sys, shutil, json

sys.path.append('..')
import logger
Expand All @@ -20,7 +22,7 @@
OUT_DIM = 12
BATCH_SIZE = 1
NUM_EPOCHS = 20
# PERCENTILE = 99.7
PERCENTILE = 0
LEARNING_RATE = 0.0001


Expand All @@ -34,15 +36,22 @@ def __init__(self, imgPath, imgList):
self.imgList = imgList
self.dataLen = len(self.imgList)
self.transform = Compose([ToTensor(), Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
f = open('./labels.json', 'r')
self.labels = json.load(f)

def __len__(self):
return self.dataLen

def __getitem__(self, index):
    """Return the ``(image, label)`` pair for the sample at ``index``.

    The image is read with OpenCV, converted from BGR to RGB and passed
    through ``self.transform``.  The label is the list stored in
    ``self.labels`` for this image, converted to a float32 tensor.

    Raises:
        KeyError: if the image has no entry in ``self.labels``.
    """
    path = self.imgList[index]
    # Labels are keyed by the file name without directory and extension;
    # derive it with os.path rather than string replacement so a directory
    # or '.png' fragment elsewhere in the path cannot corrupt the key.
    name = os.path.splitext(os.path.basename(path))[0]
    target = self.labels[name]

    img = cv2.imread(path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = self.transform(img)
    label = torch.from_numpy(np.array(target).astype(np.float32))

    return img, label.squeeze()


Expand All @@ -60,26 +69,37 @@ def __getitem__(self, index):
except:
log.info('pre trained model isn\'t exist')

# Make sure the checkpoint directory exists.  An already existing directory is
# fine; any other failure (e.g. missing permissions) is raised here instead of
# being silently swallowed and surfacing later in torch.save.
os.makedirs('models', exist_ok=True)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

trainDataPath = './imagesWithBox'
allTrainDataList = glob.glob(trainDataPath + '/*.png')

trainDataSet = ImageDataSet(trainDataPath, allTrainDataList)
trainDataLoader = data.DataLoader(dataset=trainDataSet, batch_size=BATCH_SIZE, shuffle=False)
imgs, labels = next(iter(trainDataLoader))
print(imgs.size(), labels.size())

optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
lossFunc = torch.nn.BCEWithLogitsLoss()
lossFunc = torch.nn.MultiLabelSoftMarginLoss()
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)
bestLoss = np.inf


def compute_mAP(labels, outputs):
    """Return the mean of the per-sample average precision scores.

    Both tensors are moved to the CPU and converted to NumPy arrays; one
    ``average_precision_score`` is computed for each row of ``outputs``
    against the matching row of ``labels``, and the scores are averaged.
    """
    targets = labels.cpu().numpy()
    scores = outputs.cpu().numpy()
    perSample = [average_precision_score(targets[row], scores[row]) for row in range(scores.shape[0])]
    return np.mean(perSample)

for epoch in range(NUM_EPOCHS):
runningLoss = 0.0
epochStartTime = time.time()
net.train()

mAP = []
log.info("set learning rate: %.6f" % optimizer.param_groups[0]['lr'])

for i, (imagesBatch, labelsBatch) in enumerate(trainDataLoader):
Expand All @@ -89,23 +109,37 @@ def __getitem__(self, index):

optimizer.zero_grad()

with torch.set_grad_enabled(True):
predBatch = net(imagesBatch)
loss = lossFunc(predBatch, labelsBatch)
# with torch.set_grad_enabled(True):
predBatch = net(imagesBatch)
loss = lossFunc(predBatch, labelsBatch)

loss.backward()
optimizer.step()

runningLoss += loss.item() * imagesBatch.size(0)
# predictions = predBatch.detach().numpy()[0]
# idx_list = np.where(predictions > np.percentile(predictions, PERCENTILE))
# print(idx_list)
m = compute_mAP(labelsBatch.data, predBatch.data)
mAP.append(m)

elapsedTime = time.time() - startTime
log.info("Epoch[{}]: {}/{} | loss:{:.8f} | Time: {:.4f}s".format(epoch + 1, i + 1, len(trainDataLoader.dataset),
runningLoss / (i + 1), elapsedTime))
log.info(
"Epoch:[{:3d}/{:3d}]: {:3d}/{:3d} | loss:{:.8f} | mAP:{:.2f} | Time: {:.4f}s".format(epoch + 1, NUM_EPOCHS,
i + 1,
math.ceil(len(
trainDataLoader.dataset) / BATCH_SIZE),
runningLoss / (i + 1),
np.mean(mAP[
0 - BATCH_SIZE:]),
elapsedTime))

modelName = 'models/{}_{}_model.pth'.format(time.strftime("%Y-%m-%d_%H_%M_%S", time.localtime()), epoch + 1)
torch.save(net.state_dict(), modelName)
shutil.copy(modelName, 'models/latestModel.pth')
scheduler.step(loss)
epochLoss = runningLoss / len(trainDataLoader.dataset)
epochElapsedTime = time.time() - epochStartTime
log.info("Epoch: {}/{} | loss:{:.8f} | Time: {:.4f}s".format(epoch + 1, NUM_EPOCHS, epochLoss, epochElapsedTime))
log.info(
"Epoch:[{:3d}/{:3d}] | loss:{:.8f} | mAP:{:.2f} | Time: {:.4f}s".format(epoch + 1, NUM_EPOCHS, epochLoss,
np.mean(mAP), epochElapsedTime))

0 comments on commit 4d80355

Please sign in to comment.