-
Notifications
You must be signed in to change notification settings - Fork 1
/
evaluate_result.py
71 lines (58 loc) · 2.43 KB
/
evaluate_result.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from calcuate_metric import calculate_oral_metric, calculate_cherrant_metric, calculate_baseline_metric
from pandas import DataFrame
import numpy as np
import os
import argparse
import json
import pickle
import Levenshtein
from pypinyin import pinyin,Style
from transformers import AutoTokenizer
import jieba
from collections import defaultdict
def get_path(folder_path):
    """
    List the entries of a directory in sorted order.

    @param folder_path: directory whose entries should be listed (e.g. the
        folder holding the result files)
    @return: entry names (bare names, not joined with folder_path), sorted
        lexicographically
    """
    return sorted(os.listdir(folder_path))
def write2excel(results, header_row, report_path):
    """
    Write a 2-D table of results to an Excel file (no index column).

    @param results: iterable of rows (lists/tuples); the i-th value of each row
        is written under header_row[i]; values past len(header_row) are dropped
    @param header_row: column names, in order
    @param report_path: destination path of the .xlsx file
    """
    n_cols = len(header_row)
    # Build the frame straight from the rows rather than through np.array:
    # a mixed row such as ['file.json', 0.5, ...] made np.array coerce every
    # cell to a string, so all metrics were written to Excel as text. An empty
    # result list also crashed with IndexError on the 2-D slice.
    rows = [list(row)[:n_cols] for row in results]
    df = DataFrame(rows, columns=list(header_row))
    df.to_excel(report_path, index=False)
def calculate_oral_metic2excel(args, mode = "difflib"):
    """
    Score every file under ``<base_dir>/all_result`` and save an Excel report.

    Each report row is ``[file_name, *metrics]``, where the metrics come from the
    scoring function selected by ``mode``. The report is written to
    ``<base_dir>/mid_result/analyse_metric_<mode>.xlsx``; the ``mid_result``
    directory is created if it does not exist.

    @param args: parsed CLI namespace; reads ``args.base_dir`` (root directory of
        the prediction run) and ``args.header_row`` (report column names)
    @param mode: scoring implementation: "difflib" (calculate_oral_metric),
        "cherrant" (calculate_cherrant_metric) or "baseline"
        (calculate_baseline_metric)
    @raise ValueError: if ``mode`` is not one of the names above
    """
    metric_fns = {
        "difflib": calculate_oral_metric,
        "cherrant": calculate_cherrant_metric,
        "baseline": calculate_baseline_metric,
    }
    # Fail fast: an unknown mode used to leave ``result`` unbound and crash
    # with UnboundLocalError on the first file.
    if mode not in metric_fns:
        raise ValueError(
            f"unknown mode {mode!r}; expected one of {sorted(metric_fns)}"
        )
    metric_fn = metric_fns[mode]

    # os.path.join works whether or not base_dir ends with a separator.
    result_path = os.path.join(args.base_dir, "all_result")
    results = [
        [name, *metric_fn(os.path.join(result_path, name), args.base_dir)]
        for name in get_path(result_path)
    ]

    report_dir = os.path.join(args.base_dir, "mid_result")
    os.makedirs(report_dir, exist_ok=True)
    write2excel(
        results,
        args.header_row,
        os.path.join(report_dir, f"analyse_metric_{mode}.xlsx"),
    )
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Score prediction files and write an Excel metric report."
    )
    # NOTE(review): --pretrained_model_path and --lora_path are not read by the
    # code in this file; kept so existing command lines keep working.
    parser.add_argument("--pretrained_model_path", type=str, default="")
    parser.add_argument("--lora_path", type=str, default="")
    # nargs="+" so `--header_row name S_D_p ...` yields a list of column names;
    # the previous type=list split a single CLI string into its characters.
    parser.add_argument(
        "--header_row",
        type=str,
        nargs="+",
        default=['name', 'S_D_p', 'S_D_r', 'S_D_f1', 'S_C_p', 'S_C_r', 'S_C_f1',
                 'C_D_p', 'C_D_r', 'C_D_f1', 'C_C_p', 'C_C_r', 'C_C_f1'],
    )
    parser.add_argument("--base_dir", type=str, default="./predicts/prediction_qwen_14b/")
    # All modes accepted by calculate_oral_metic2excel; argparse rejects others.
    parser.add_argument(
        "--mode",
        type=str,
        default="cherrant",
        choices=["cherrant", "difflib", "baseline"],
    )
    args = parser.parse_args()
    calculate_oral_metic2excel(args, mode=args.mode)