Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
pxq0312 committed Jul 5, 2019
1 parent 2f6b1ac commit 38f46c3
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 8 deletions.
20 changes: 19 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,24 @@ Pytorch 1.1.0

`train.py` To train the model.

`eval.py` To test the model.

## Train & Test

For training, run

`python train.py --dataset="SHA" --data_path="path to dataset" --save_path="path to save checkpoint"`

For testing, run

`python eval.py --dataset="SHA" --data_path="path to dataset" --save_path="path to checkpoint"`

## Result

The network is still training; I will upload the remaining results and checkpoints soon.
ShanghaiTech part A: epoch367 MAE 60.43 MSE 98.24

![](./logs/A.png)

ShanghaiTech part B: epoch432 MAE 6.38 MSE 10.99

![](./logs/B.png)
2 changes: 1 addition & 1 deletion density_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
if int(gt[i][1]) < img.shape[0] and int(gt[i][1]) >= 0 and int(gt[i][0]) < img.shape[1] and int(gt[i][0]) >= 0:
k[int(gt[i][1]), int(gt[i][0])] = 1
count += 1
print('Ignore {} wrong annotation'.format(len(gt) - count))
print('Ignore {} wrong annotation.'.format(len(gt) - count))
k = gaussian_filter(k, 5)
att = k > 0.001
att = att.astype(np.float32)
Expand Down
41 changes: 41 additions & 0 deletions eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Evaluate a trained crowd-counting model on the test split.

Loads the best checkpoint from --save_path, runs the model over the test
set one image at a time, and prints the final MAE and (root) MSE.
"""
import torch
from torch.utils import data
from dataset import Dataset
from models import Model
import os
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='SHA', type=str, help='dataset')
parser.add_argument('--data_path', default=r'D:\dataset', type=str, help='path to dataset')
parser.add_argument('--save_path', default=r'D:\checkpoint\SFANet', type=str, help='path to save checkpoint')
parser.add_argument('--gpu', default=0, type=int, help='gpu id')

args = parser.parse_args()

# Test split; batch_size=1 because test images are evaluated individually.
test_dataset = Dataset(args.data_path, args.dataset, False)
test_loader = data.DataLoader(test_dataset, batch_size=1, shuffle=False)

device = torch.device('cuda:' + str(args.gpu))

model = Model().to(device)

# map_location remaps the stored tensors onto the device selected by --gpu;
# without it, torch.load tries to restore onto the GPU the checkpoint was
# saved from and fails (or lands on the wrong device) on a different machine.
checkpoint = torch.load(os.path.join(args.save_path, 'checkpoint_best.pth'),
                        map_location=device)
model.load_state_dict(checkpoint['model'])

model.eval()
with torch.no_grad():
    mae, mse = 0.0, 0.0
    for i, (images, gt) in enumerate(test_loader):
        images = images.to(device)

        # Model returns (density map, attention map); only the density map
        # is needed for counting — the predicted count is its sum.
        predict, _ = model(images)

        print('predict:{:.2f} label:{:.2f}'.format(predict.sum().item(), gt.item()))
        # NOTE(review): gt is never moved to `device` while predict is on the
        # GPU — confirm this mixed-device arithmetic works on the target
        # torch version (otherwise add gt = gt.to(device) above).
        mae += torch.abs(predict.sum() - gt).item()
        mse += ((predict.sum() - gt) ** 2).item()

    mae /= len(test_loader)
    mse /= len(test_loader)
    # Reported "MSE" is actually the root of the mean squared error (RMSE),
    # matching the convention used in the README results.
    mse = mse ** 0.5
    print('MAE:', mae, 'MSE:', mse)
Binary file added logs/A.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added logs/B.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
12 changes: 6 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
parser = argparse.ArgumentParser()
parser.add_argument('--bs', default=8, type=int, help='batch size')
parser.add_argument('--epoch', default=500, type=int, help='train epochs')
parser.add_argument('--dataset', default='SHB', type=str, help='dataset')
parser.add_argument('--dataset', default='SHA', type=str, help='dataset')
parser.add_argument('--data_path', default=r'D:\dataset', type=str, help='path to dataset')
parser.add_argument('--lr', default=1e-4, type=float, help='initial learning rate')
parser.add_argument('--load', default=False, action='store_true', help='load checkpoint')
parser.add_argument('--save_path', default=r'D:\checkpoint\SFANet', type=str, help='path to save checkpoint')
parser.add_argument('--gpu', default=0, type=int, help='gpu id')
parser.add_argument('--log', default='./logs', type=str, help='path to log')
parser.add_argument('--log_path', default='./logs', type=str, help='path to log')

args = parser.parse_args()

Expand All @@ -30,10 +30,10 @@

model = Model().to(device)

writer = SummaryWriter(args.log)
writer = SummaryWriter(args.log_path)

mseloss = nn.MSELoss(size_average=False).to(device)
bceloss = nn.BCELoss(size_average=False).to(device)
mseloss = nn.MSELoss(reduction='sum').to(device)
bceloss = nn.BCELoss(reduction='sum').to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

if args.load:
Expand All @@ -47,7 +47,7 @@
best_mae = 999999
start_epoch = 0

for epoch in range(start_epoch, args.epoch):
for epoch in range(start_epoch, start_epoch + args.epoch):
loss_avg, loss_att_avg = 0.0, 0.0

for i, (images, density, att) in enumerate(train_loader):
Expand Down

0 comments on commit 38f46c3

Please sign in to comment.