# main.py — 133 lines (117 loc), 4.03 KB
import os
from typing import Dict, Any, Optional, Sequence, Union, List
import fire
import pandas as pd
from dataset import Dataset, DATASET_NAMES, download_dataset
from benchmark import BENCHMARKS
from helpers import sensitivity_gen
def main():
    """Entry point: expose the CLI sub-commands through python-fire."""
    commands = {
        "benchmark": benchmark,
        "sensitivity": sensitivity,
        "download": download,
        "list": ls,
    }
    fire.Fire(commands)
def benchmark(
    parser: str,
    datasets: Optional[Union[str, Sequence[str]]] = None,
    results_path: str = "results.csv",
):
    """
    Benchmark a parser on a given dataset. If no dataset is passed, the parser
    is benchmarked on all available datasets. If any dataset does not exist in
    the data directory, it is downloaded automatically.

    :param parser: The parser algorithm name (a key of ``BENCHMARKS``).
    :param datasets: A dataset name, a sequence of names, or None for all datasets.
    :param results_path: The csv file path where the benchmark results will be stored.
    :raises ValueError: If the parser, any dataset name, or the results path is invalid.
    :return: None; results are written to ``results_path``.
    """
    # Validate with explicit exceptions rather than assert, which is stripped
    # when Python runs with -O.
    if parser not in BENCHMARKS:
        raise ValueError(f"Unknown parser {parser!r}; choose one of {sorted(BENCHMARKS)}")
    if not results_path:
        raise ValueError("results_path must be a non-empty path")
    # Normalize to a list of names first so any Sequence (tuple, etc.) allowed
    # by the type hint is validated, not just list.
    if datasets is None:
        names: List[str] = list(DATASET_NAMES)
    elif isinstance(datasets, str):
        names = [datasets]
    else:
        names = list(datasets)
    unknown = [ds for ds in names if ds not in DATASET_NAMES]
    if unknown:
        raise ValueError(f"Unknown dataset(s) {unknown}; choose from {sorted(DATASET_NAMES)}")
    print(f"Benchmarking {parser} on {names}")
    results = [__benchmark(ds, parser) for ds in names]
    df = pd.DataFrame.from_records(results)
    df.to_csv(results_path, index=False)
    print(f"Results stored in {results_path}")
def sensitivity(
    parser: str,
    datasets: Optional[Union[str, Sequence[str]]] = None,
    step: float = 0.05,
    results_path: str = "results_sensitivity.csv",
):
    """
    Benchmark a parser across a sweep of sensitivity values (generated by
    ``sensitivity_gen(step)``) on the given dataset(s). If no dataset is
    passed, all available datasets are used.

    :param parser: The parser algorithm name (a key of ``BENCHMARKS``).
    :param datasets: A dataset name, a sequence of names, or None for all datasets.
    :param step: Spacing between successive sensitivity values in the sweep.
    :param results_path: The csv file path where the benchmark results will be stored.
    :raises ValueError: If the parser, any dataset name, or the results path is invalid.
    :return: None; results are written to ``results_path``.
    """
    # Validate with explicit exceptions rather than assert, which is stripped
    # when Python runs with -O.
    if parser not in BENCHMARKS:
        raise ValueError(f"Unknown parser {parser!r}; choose one of {sorted(BENCHMARKS)}")
    if not results_path:
        raise ValueError("results_path must be a non-empty path")
    # Normalize to a list of names first so any Sequence (tuple, etc.) allowed
    # by the type hint is validated, not just list.
    if datasets is None:
        names: List[str] = list(DATASET_NAMES)
    elif isinstance(datasets, str):
        names = [datasets]
    else:
        names = list(datasets)
    unknown = [ds for ds in names if ds not in DATASET_NAMES]
    if unknown:
        raise ValueError(f"Unknown dataset(s) {unknown}; choose from {sorted(DATASET_NAMES)}")
    print(f"Benchmarking {parser} on {names}")
    # Cartesian product: every dataset at every sensitivity value.
    args = [(ds, parser, sen) for ds in names for sen in sensitivity_gen(step)]
    results = [__benchmark(*a) for a in args]
    df = pd.DataFrame.from_records(results)
    df.to_csv(results_path, index=False)
    print(f"Results stored in {results_path}")
def download(dataset_name: str):
    """
    Download the benchmark csv file from Zenodo.org

    :param dataset_name: the name of the benchmark dataset all in lower case letters.
    :return:
    """
    # Only names registered in DATASET_NAMES are downloadable.
    is_known = dataset_name in DATASET_NAMES
    assert is_known
    download_dataset(dataset_name)
def ls():
    """
    List available parsers and datasets

    :return:
    """
    # Print each section header followed by its tab-indented entries.
    sections = (("Parsers", BENCHMARKS), ("Datasets", DATASET_NAMES))
    for title, names in sections:
        print(f"\n{title}:")
        for name in names:
            print(f"\t{name}")
def __benchmark(
    ds_name: str, parser: str, sensitivity: Optional[float] = None
) -> Dict[str, Any]:
    """Run one (dataset, parser) benchmark and return a flat result record.

    :param ds_name: Dataset name to load via ``Dataset``.
    :param parser: Key into ``BENCHMARKS`` selecting the parser class.
    :param sensitivity: Optional sensitivity passed to the parser constructor.
    :return: Dict with "parser", "dataset", optional "sensitivity", plus the
        metrics produced by the parser's ``benchmark`` method.
    """
    pid = os.getpid()
    if sensitivity is None:
        print(f"Process {pid} is benchmarking {parser} on {ds_name}")
    else:
        print(
            f"Process {pid} is benchmarking {parser} on {ds_name} with sensitivity {sensitivity:.2f}"
        )
    ds = Dataset(ds_name)
    parser_cls = BENCHMARKS[parser]
    # The parser only receives the sensitivity argument when one was given.
    b = parser_cls() if sensitivity is None else parser_cls(sensitivity)
    res = b.benchmark(ds)
    metrics = ", ".join(f"{k}: {v:.2f}" for k, v in res.items())
    print(f"({parser}, {ds_name}) -> {metrics}")
    record: Dict[str, Any] = {"parser": parser, "dataset": ds_name}
    if sensitivity is not None:
        record["sensitivity"] = sensitivity
    return record | res
# Script entry point: dispatch the CLI sub-commands.
if __name__ == "__main__":
    main()