-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathchord_rec.py
74 lines (57 loc) · 2.05 KB
/
chord_rec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""
This script performs chord recognition on music audio and evaluates the
predictions against ground-truth chord labels.

Author: Sivan Ding (sivan.d@nyu.edu)

References:
    https://github.com/bmcfee/crema
    https://github.com/ejhumphrey/ace-lessons
    https://github.com/bmcfee/ismir2017_chords
"""
import pickle
import crema
from tension_map import *
from utils import *
# Interval ("dyad") qualities that appear in the ground-truth labels. They are
# skipped before recognition, as the original "get rid of dyads" step did.
_DYAD_QUALITIES = frozenset({
    'min2', 'maj2', 'min3', 'maj3', 'perf4', 'tritone', 'perf5',
    'min6', 'maj6', 'aug6', 'maj7_2', 'octave',
})


def _is_dyad(label):
    """Return True if the quality field of *label* names a dyad.

    The quality is the second ':'-separated field of the label. A label with
    no ':' has no quality field; it is treated as a non-dyad instead of
    raising IndexError.
    """
    parts = label.split(':')
    return len(parts) > 1 and parts[1] in _DYAD_QUALITIES


def loader_id(loader, model):
    """Run chord recognition over *loader* and collect ground truth vs. estimates.

    Parameters
    ----------
    loader : iterable of (audio, sr, gt_chord)
        Each item is unpacked as an audio signal, its sample rate, and a
        ground-truth chord label string. Items whose label quality is a dyad
        are skipped.
    model
        Recognition model forwarded unchanged to ``chord_id``.

    Returns
    -------
    tuple of six lists, index-aligned per kept item:
        (chord_gt, chord_est, color_gt, color_est, jam_gt, jam_est)
        - chord_gt / chord_est: ground-truth label after ``match_chord2jam``,
          and the first predicted chord value.
        - color_gt / color_est: ``chord2polar`` applied to each of those.
        - jam_gt / jam_est: ``jam_label(len(audio) / sr, gt_chord)``, and the
          annotation returned by ``chord_id``.
    """
    chord_gt, chord_est = [], []
    color_gt, color_est = [], []
    jam_gt, jam_est = [], []

    for audio, sr, gt_chord in loader:
        # let's first get rid of dyads...
        if _is_dyad(gt_chord):
            continue

        gt_chord = match_chord2jam(gt_chord)
        gt_color = chord2polar(gt_chord)

        preds, pred_jam = chord_id(audio, sr, model)
        # Only the first predicted value is kept.
        # NOTE(review): presumably each clip contains a single chord -- confirm.
        pred_chord = preds['value'][0]
        pred_color = chord2polar(pred_chord)

        chord_gt.append(gt_chord)
        chord_est.append(pred_chord)
        color_gt.append(gt_color)
        color_est.append(pred_color)
        # len(audio) / sr is presumably the clip duration in seconds.
        jam_gt.append(jam_label(len(audio) / sr, gt_chord))
        jam_est.append(pred_jam)

    return chord_gt, chord_est, color_gt, color_est, jam_gt, jam_est
def main():
    """Evaluate the crema chord model on the jazznet test split and print metrics.

    If the pickled test loader at ``data/test_loader`` exists, it is reused.
    Otherwise the loader is built with ``get_loader`` from the configured
    dataset paths and cached there for later runs.
    """
    from pathlib import Path

    # let's consider jazznet format only for now.
    # custom configurations
    database = "/Users/sivanding/database/jazznet/chords"
    metadata = "/Users/sivanding/database/jazznet/metadata/tiny.csv"
    cache_path = Path("data/test_loader")

    # get model
    model = crema.models.chord.ChordModel()

    # load dataset: reuse the cached loader, or build it and cache it.
    if cache_path.exists():
        # NOTE: pickle.load can execute arbitrary code; only load a cache
        # file this script wrote itself, never one from an untrusted source.
        with cache_path.open("rb") as fp:
            test_loader = pickle.load(fp)
    else:
        test_loader = get_loader(database, metadata, split='test')
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        with cache_path.open("wb") as fp:
            pickle.dump(test_loader, fp)

    # get results (the plain chord-label lists are not needed for the metrics)
    _, _, t_gt, t_est, j_gt, j_est = loader_id(test_loader, model)

    # get metrics
    chord_met = chord_acc(j_gt, j_est)
    tension_met = color_acc(t_gt, t_est)
    print(chord_met.mean(axis=1))
    print(tension_met)


if __name__ == '__main__':
    main()