Skip to content

Commit

Permalink
Benchmarks (#165)
Browse files Browse the repository at this point in the history
* add benchmarks

* fix readme

* fix readme

* bolt font for arguments
  • Loading branch information
RodionovDenis authored Oct 9, 2023
1 parent 2e634ae commit 11febf7
Show file tree
Hide file tree
Showing 15 changed files with 977 additions and 0 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,8 @@ dmypy.json
# Pyre type checker
.pyre/

# vs code
.vscode

# datasets
benchmarks/data/datasets
35 changes: 35 additions & 0 deletions benchmarks/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Reproduction of results

Install modules:

pip install -U -r requirements.txt

Downloading datasets:

python data/loader.py

Running the experiment:

python runner.py --dataset {dataset name} --method {method name} --max-iter {number of iterations} --dir {directory for results} --trials {number of trials} --n-jobs {the number of worker processes to use}

`runner.py` script parameters:

1. **--dataset** – one or more from the list:

(`balance`, `bank-marketing`, `banknote`, `breast-cancer`, `car-evaluation`, `cnae9`, `credit-approval`,
`digits`, `ecoli`, `parkinsons`, `semeion`, `statlog-segmentation`, `wilt`, `zoo`)

2. **--method** – either `svc`, or `xgb`, or `mlp`
3. **--max-iter** – number of iterations
4. **--dir** – directory in which tables with results will be saved (by default this will be the `result` folder)
5. **--trials** – the number of trials in non-deterministic algorithms (`hyperopt`, `optuna`)
6. **--n-jobs** – the number of worker processes to use


## Launch example

We run the `svc` method on the `breast-cancer` and `zoo` datasets with a maximum of `200` iterations, `10` trials for the non-deterministic algorithms, and `12` worker processes.

python runner.py --dataset breast-cancer zoo --method svc --max-iter 200 --trials 10 --n-jobs 12

Once completed, the script will create two tables with the resulting metrics (`result/metrics.csv`) and times (`result/times.csv`). If the algorithm is non-deterministic, the table contains the mean with standard deviation.
122 changes: 122 additions & 0 deletions benchmarks/argparser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import data

from argparse import ArgumentParser
from dataclasses import dataclass, field

from hyperparams import Hyperparameter, Numerical, Categorial
from sklearn.svm import SVC
from xgboost import XGBClassifier
from sklearn.neural_network import MLPClassifier
from functools import partial


# Hyperparameter search spaces, one per supported estimator.
# NOTE(review): Numerical arguments look like (value type, lower bound,
# upper bound) and Categorial arguments like the allowed choices -- confirm
# against the `hyperparams` module.
_SVC_SEARCH_SPACE = {
    'gamma': Numerical('float', 1e-9, 1e-6, is_log_scale=True),
    'C': Numerical('int', 1, 1e10, is_log_scale=True),
    'kernel': Categorial('poly', 'rbf', 'sigmoid'),
}

_XGB_SEARCH_SPACE = {
    'n_estimators': Numerical('int', 10, 200),
    'max_depth': Numerical('int', 5, 20),
    'min_child_weight': Numerical('int', 1, 10),
    'gamma': Numerical('float', 0.01, 0.6),
    'subsample': Numerical('float', 0.05, 0.95),
    'colsample_bytree': Numerical('float', 0.05, 0.95),
    'learning_rate': Numerical('float', 0.001, 0.1, is_log_scale=True),
}

_MLP_SEARCH_SPACE = {
    'hidden_layer_sizes': Numerical('int', 2, 150),
    'activation': Categorial('identity', 'logistic', 'tanh', 'relu'),
    'solver': Categorial('lbfgs', 'sgd', 'adam'),
    'alpha': Numerical('float', 1e-9, 1e-1, is_log_scale=True),
}

# Keyed by the estimator *class*; ConsoleArgument.__post_init__ performs the
# lookup after unwrapping any functools.partial around the class.
METHOD_TO_HYPERPARAMS = {
    SVC: _SVC_SEARCH_SPACE,
    XGBClassifier: _XGB_SEARCH_SPACE,
    MLPClassifier: _MLP_SEARCH_SPACE,
}


# (command-line name, dataset object) pairs accepted by the `--dataset` option.
_DATASET_TABLE = (
    ('balance', data.Balance),
    ('bank-marketing', data.BankMarketing),
    ('banknote', data.Banknote),
    ('breast-cancer', data.BreastCancer),
    ('car-evaluation', data.CarEvaluation),
    ('cnae9', data.CNAE9),
    ('credit-approval', data.CreditApproval),
    ('digits', data.Digits),
    ('ecoli', data.Ecoli),
    ('parkinsons', data.Parkinsons),
    ('semeion', data.Semeion),
    ('statlog-segmentation', data.StatlogSegmentation),
    ('wilt', data.Wilt),
    ('zoo', data.Zoo),
)

# Lookup used by get_datasets to translate names into entries of `data`.
NAME_TO_DATASET = dict(_DATASET_TABLE)


@dataclass
class ConsoleArgument:
    """Settings of one benchmark run, as assembled by ``parse_arguments``.

    ``hyperparams`` is excluded from the constructor (``init=False``); it is
    filled in by ``__post_init__`` from ``METHOD_TO_HYPERPARAMS`` based on
    ``estimator``.
    """

    # Value of --max-iter.
    max_iter: int
    # Result of get_estimator: an estimator class, or a functools.partial
    # wrapping one with preset keyword arguments.
    estimator: SVC | XGBClassifier | MLPClassifier
    # Result of get_datasets: a list of NAME_TO_DATASET entries.
    dataset: data.Dataset
    # Search space for the chosen estimator; set in __post_init__.
    hyperparams: Hyperparameter = field(init=False)
    # Value of --dir.
    dir: str
    # Value of --trials.
    trials: int
    # Value of --n-jobs.
    n_jobs: int

    def __post_init__(self):
        """Look up the hyperparameter search space of the chosen estimator."""
        chosen = self.estimator
        # METHOD_TO_HYPERPARAMS is keyed by the bare class, so a partial has
        # to be unwrapped to the callable it wraps before the lookup.
        lookup_key = chosen.func if isinstance(chosen, partial) else chosen
        self.hyperparams = METHOD_TO_HYPERPARAMS[lookup_key]


def get_estimator(name: str) -> SVC | XGBClassifier | MLPClassifier:
if name == 'svc':
return partial(SVC, max_iter=1000)
elif name == 'xgb':
return partial(XGBClassifier, n_jobs=1)
elif name == 'mlp':
return MLPClassifier
raise ValueError(f'Estimator "{name}" do not support')


def get_datasets(names: str) -> data.Dataset:
try:
result = []
for x in names:
result.append(NAME_TO_DATASET[x])
return result
except KeyError:
raise ValueError(f' Dataset "{x}" do not support')


def parse_arguments(argv=None):
    """Parse the benchmark command line into a ``ConsoleArgument``.

    Options:
        --max-iter:
            int, positive (required)
        --dataset:
            names of datasets, see all names in NAME_TO_DATASET dict (required)
        --method:
            must be svc, xgb or mlp (required)
        --dir:
            name of the dir to save the results (``result`` by default)
        --trials:
            int, positive (1 by default)
        --n-jobs (also accepted as --n_jobs):
            int, positive (1 by default)

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        A ``ConsoleArgument`` built from the parsed options.

    Raises:
        SystemExit: Raised by argparse (exit status 2) when a required option
            is missing, a value is not an integer, or a numeric option is not
            positive.
        ValueError: Propagated from ``get_estimator`` / ``get_datasets`` for
            an unknown method or dataset name.
    """
    parser = ArgumentParser()
    parser.add_argument('--max-iter', type=int, required=True)
    parser.add_argument('--dataset', nargs='*', required=True)
    parser.add_argument('--method', required=True)
    parser.add_argument('--dir', default='result')
    parser.add_argument('--trials', type=int, default=1)
    parser.add_argument('--n-jobs', '--n_jobs', dest='n_jobs', type=int, default=1)

    args = parser.parse_args(argv)

    # Explicit checks instead of `assert`: assertions are stripped when
    # Python runs with -O, which would silently disable the validation.
    positivity_checks = (
        (args.max_iter, 'Max iter must be positive'),
        (args.trials, 'Trials must be positive'),
        (args.n_jobs, 'n_jobs must be positive'),
    )
    for value, message in positivity_checks:
        if value <= 0:
            parser.error(message)

    return ConsoleArgument(args.max_iter,
                           get_estimator(args.method),
                           get_datasets(args.dataset),
                           args.dir, args.trials, args.n_jobs)
18 changes: 18 additions & 0 deletions benchmarks/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from .loader import (Dataset,
BreastCancer,
Digits,
BankMarketing,
CNAE9,
StatlogSegmentation,
Semeion,
Ecoli,
CreditApproval,
Balance,
Parkinsons,
Zoo,
Banknote,
CarEvaluation,
Wilt)

# __all__ must list the exported names as strings; putting the objects
# themselves here makes `from data import *` fail with a TypeError.
__all__ = [
    'Dataset',
    'BreastCancer',
    'Digits',
    'BankMarketing',
    'CNAE9',
    'StatlogSegmentation',
    'Semeion',
    'Ecoli',
    'CreditApproval',
    'Balance',
    'Parkinsons',
    'Zoo',
    'Banknote',
    'CarEvaluation',
    'Wilt',
]
Loading

0 comments on commit 11febf7

Please sign in to comment.