forked from antoniorv6/SMT
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy patheval_functions.py
113 lines (92 loc) · 3.25 KB
/
eval_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from utils import levenshtein
def parse_krn_content(krn, ler_parsing=False, cer_parsing=False):
    """Tokenise a transcription string at the granularity a metric needs.

    Line breaks and tabs in ``krn`` are represented by the marker tokens
    ``<b>`` and ``<t>`` respectively.

    Args:
        krn: The transcription as a single string.
        ler_parsing: When true (and ``cer_parsing`` is false), return one
            entry per line.
        cer_parsing: When true, return individual characters. Takes
            precedence over ``ler_parsing`` if both are set.

    Returns:
        A list of strings:

        * character mode: every character of every space-separated token,
          with the ``<b>`` / ``<t>`` markers kept as single items;
        * line mode: the lines of ``krn`` with each tab replaced by
          ``" <t> "`` (no ``<b>`` marker is produced in this mode);
        * default mode: the space-separated tokens, markers included.
          Consecutive separators yield empty-string tokens.
    """
    if ler_parsing and not cer_parsing:
        # Splitting on "\n" already removes every newline, so only tabs
        # are left to substitute inside each line.
        return [line.replace("\t", " <t> ") for line in krn.split("\n")]

    tokens = krn.replace("\n", " <b> ").replace("\t", " <t> ").split(" ")
    if not cer_parsing:
        return tokens

    characters = []
    for token in tokens:
        if token in ("<b>", "<t>"):
            # Structural markers count as one unit, not as three characters.
            characters.append(token)
        else:
            characters.extend(token)
    return characters
def compute_metric(a1, a2):
    """Return the accumulated edit distance normalised by reference length.

    Sequences are paired positionally with ``zip``; surplus items in the
    longer argument are ignored.

    Args:
        a1: Iterable of hypothesis sequences.
        a2: Iterable of reference sequences. Their summed ``len`` is the
            normaliser.

    Returns:
        ``100 * sum(levenshtein(h, g)) / sum(len(g))`` as a float.
        When the references have zero total length the ratio is undefined:
        ``0.0`` is returned if no edits were counted either (nothing to
        compare), otherwise ``float("inf")``.
    """
    total_distance = 0
    total_length = 0
    for hyp, ref in zip(a1, a2):
        total_distance += levenshtein(hyp, ref)
        total_length += len(ref)

    if total_length == 0:
        # Avoid ZeroDivisionError on empty batches / all-empty references.
        return 0.0 if total_distance == 0 else float("inf")
    return 100. * total_distance / total_length
def compute_poliphony_metrics(hyp_array, gt_array):
    """Compute character-, symbol- and line-level error rates for a batch.

    Each hypothesis/ground-truth string pair is tokenised three ways with
    ``parse_krn_content`` and every tokenisation is scored with
    ``compute_metric``.

    Args:
        hyp_array: Iterable of predicted transcription strings.
        gt_array: Iterable of ground-truth transcription strings, paired
            positionally with ``hyp_array``.

    Returns:
        A ``(cer, ser, ler)`` tuple: the metric over characters, over
        space-separated tokens, and over lines.
    """
    hyp_cer, gt_cer = [], []
    hyp_ser, gt_ser = [], []
    hyp_ler, gt_ler = [], []

    for hyp_string, gt_string in zip(hyp_array, gt_array):
        hyp_ler.append(parse_krn_content(hyp_string, ler_parsing=True, cer_parsing=False))
        gt_ler.append(parse_krn_content(gt_string, ler_parsing=True, cer_parsing=False))

        hyp_ser.append(parse_krn_content(hyp_string, ler_parsing=False, cer_parsing=False))
        gt_ser.append(parse_krn_content(gt_string, ler_parsing=False, cer_parsing=False))

        hyp_cer.append(parse_krn_content(hyp_string, ler_parsing=False, cer_parsing=True))
        gt_cer.append(parse_krn_content(gt_string, ler_parsing=False, cer_parsing=True))

    # Each level is scored exactly once; compute_metric does its own
    # accumulation, so no separate edit-distance pass is needed here.
    cer = compute_metric(hyp_cer, gt_cer)
    ser = compute_metric(hyp_ser, gt_ser)
    ler = compute_metric(hyp_ler, gt_ler)

    return cer, ser, ler
def extract_music_text(array):
    """Separate the first two tab-delimited columns of a transcription.

    The first line is always skipped (presumably the column header; confirm
    against callers). Rows that consist solely of two ``.`` fields separated
    by a tab are skipped as well. Trailing whitespace is stripped from each
    line before it is split, and lines with fewer than two columns are
    ignored.

    Args:
        array: The transcription as a single newline-separated string
            (despite the parameter name, not a list).

    Returns:
        A ``(lyrics, symbols, text)`` tuple: the list of second-column
        values, the list of first-column values, and the second-column
        values joined with single spaces.
    """
    lyrics = []
    symbols = []
    for idx, raw_line in enumerate(array.split("\n")):
        if idx == 0:
            continue
        row = raw_line.rstrip()
        # Lines produced by split("\n") never contain a newline, so the
        # null-row filter has to compare against the stripped line itself.
        if row == ".\t.":
            continue
        fields = row.split("\t")
        if len(fields) > 1:
            symbols.append(fields[0])
            lyrics.append(fields[1])
    return lyrics, symbols, " ".join(lyrics)
def extract_music_textllevel(array):
    """Tokenise a two-column transcription line by line.

    For every line with at least two tab-separated columns the output holds
    the first column as one token, a ``<t>`` marker, and then the characters
    of the second column one by one. Reading of the second column stops at
    the first ``<`` character; in that case a ``<b>`` marker closes the
    per-line token list (it is not added to the flat list).

    NOTE(review): because the input is split on newlines, a ``<b>`` marker
    only appears when the second column itself contains ``<``. If a marker
    was meant to terminate every line, that needs a separate change --
    confirm with callers.

    Args:
        array: The transcription as a single newline-separated string.

    Returns:
        A ``(lines, completecontent)`` tuple: a list with one token list per
        retained line, and one flat list with the tokens of all retained
        lines (without ``<b>`` markers).
    """
    lines = []
    completecontent = []
    for row in array.split("\n"):
        fields = row.split("\t")
        if len(fields) < 2:
            continue

        text = fields[1]
        cut = text.find("<")
        visible = text if cut < 0 else text[:cut]
        tokens = [fields[0], "<t>"] + list(visible)

        completecontent.extend(tokens)
        if cut >= 0:
            # Only the per-line view records the break marker.
            tokens = tokens + ["<b>"]
        lines.append(tokens)

    return lines, completecontent