-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest.py
55 lines (39 loc) · 1.71 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
from torch.utils.data import DataLoader
import argparse
from scipy.optimize import minimize_scalar
import numpy as np
from sklearn.metrics import average_precision_score, f1_score, precision_recall_fscore_support
from musicnet import MusicNet_song
# Command-line interface of the evaluation script.
parser = argparse.ArgumentParser()
# --root is handed to MusicNet_song as its first argument.
parser.add_argument(
    '--root',
    default='/media/ycy/Shared/Datasets/musicnet',
    type=str,
)
# --infile is the path given to torch.load to obtain the network.
parser.add_argument(
    '--infile',
    default='pre-trained.pth',
    type=str,
)
if __name__ == '__main__':
    # Evaluate a saved network on a fixed list of MusicNet recording IDs:
    # gather the network outputs and the targets of every batch, print the
    # average precision, then search for the score threshold that maximises
    # F1 and print precision / recall / F1 at that threshold.
    args = parser.parse_args()
    test_ids = [2303, 1819, 2382]

    # Choose the device before loading so the checkpoint can be remapped onto
    # it. Without map_location, a checkpoint saved from a GPU cannot be
    # restored on a CPU-only machine.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # NOTE(review): torch.load unpickles the file, so only point --infile at
    # checkpoints you trust. Recent PyTorch releases may default to
    # weights_only=True, which can reject a fully pickled model object --
    # confirm against the installed version.
    net = torch.load(args.infile, map_location=device)
    net = net.to(device)
    net.eval()

    y_true = []
    y_score = []
    with torch.no_grad():
        # `song_id` rather than `id`, which would shadow the builtin.
        for song_id in test_ids:
            print('==> Loading ID', song_id)
            # NOTE(review): 44100 is presumably a sample rate -- confirm
            # against the MusicNet_song signature.
            test_song = MusicNet_song(args.root, song_id, 44100)
            test_loader = DataLoader(test_song, batch_size=10, num_workers=1)
            for inputs, targets in test_loader:
                y_true.append(targets.detach().cpu().numpy())
                outputs = net(inputs.to(device))
                y_score.append(outputs.detach().cpu().numpy())

    # Stack the per-batch arrays and flatten them so every individual
    # prediction is scored as one binary decision.
    y_score = np.vstack(y_score).flatten()
    y_true = np.vstack(y_true).flatten()
    print("average precision on testset:", average_precision_score(y_true, y_score))

    def threshold(x):
        """Return ``1 - F1`` when scores strictly above *x* count as positive."""
        y2 = y_score > x
        return 1 - f1_score(y_true, y2)

    # Minimising 1 - F1 over [0, 1] yields the F1-maximising threshold found
    # by the bounded scalar search.
    res = minimize_scalar(threshold, bounds=(0, 1), method='bounded')
    # TODO: need to change to mir_eval in the future
    print('threshold is', res.x)
    print(precision_recall_fscore_support(y_true, y_score > res.x, average='binary')[:3])