-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtuning_svm.py
52 lines (44 loc) · 2.12 KB
/
tuning_svm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#!/usr/bin/env python
# Created by "Thieu" at 13:53, 06/09/2022 ----------%
# Email: nguyenthieu2102@gmail.com %
# Github: https://github.com/thieu1995 %
# --------------------------------------------------%
import pandas as pd
from sklearn.model_selection import GridSearchCV
from sklearn import svm
from config import Config
from utils.data_util import split_dataset
from utils.math_util import get_combinations
from utils.io_util import save_results_to_csv
from permetrics.classification import ClassificationMetric
import warnings
# Silence every warning category for the rest of the run; the grid search
# below already logs verbosely (verbose=3).
warnings.filterwarnings('ignore')

# Script purpose: for each candidate feature subset, grid-search an SVM
# classifier, score the refitted best model on the held-out split and record
# the chosen hyper-parameters together with the classification metrics.

# Load the full dataset; the train/test split is delegated to split_dataset().
data = pd.read_csv('data/input_data/inflow_by_mean.csv')

# Candidate input columns (presumably the current value and its 12 lags --
# TODO confirm against the CSV header).
list_features = [
    'value', 'value-1', 'value-2', 'value-3', 'value-4', 'value-5',
    'value-6', 'value-7', 'value-8', 'value-9', 'value-10', 'value-11',
    'value-12',
]
# Presumably every subset of list_features -- verify in utils.math_util.
list_all_features = get_combinations(list_features)
y_output = 'label+1'

# Hyper-parameter search space. It does not depend on the feature subset, so
# it is built once instead of on every loop iteration.
# NOTE(review): per the scikit-learn SVC docs `gamma` only applies to the
# 'rbf', 'poly' and 'sigmoid' kernels, so the 'linear' candidates are refitted
# once per gamma value with identical results. Kept as-is so the recorded
# best_params keep the same shape.
param_grid = {
    'C': [0.1, 1, 10, 100, 1000],
    'gamma': [1, 0.1, 0.01, 0.001, 0.0001],
    'kernel': ['rbf', 'linear'],
}

# Metric names requested from permetrics; also loop-invariant.
metrics = ["AS", "PS", "RS", "F1S", "F2S", "MCC", "LS"]

for features in list_all_features:
    # Only subsets with at least 4 features are evaluated.
    if len(features) <= 3:
        continue

    # The fifth returned value (the label encoder, judging by the original
    # name `lb_encoder`) is not needed by this script.
    x_train, x_test, y_train, y_test, _lb_encoder = split_dataset(data, features, y_output)

    # refit=True retrains the best parameter combination on the whole
    # training split, so `grid.predict` below uses that refitted model.
    grid = GridSearchCV(svm.SVC(), param_grid, refit=True, verbose=3)
    grid.fit(x_train, y_train)

    mm1 = {
        "features": features,
        "best_params": grid.best_params_,
        "best_estimator": grid.best_estimator_,
    }

    # Score the refitted best estimator on the held-out split.
    predicted = grid.predict(x_test)
    evaluator = ClassificationMetric(y_test, predicted)
    # One parameter dict per metric, all requesting micro averaging. Built per
    # iteration (as in the original) in case the metric library mutates it.
    paras = [{"average": "micro"}, ] * len(metrics)
    mm2 = evaluator.get_metrics_by_list_names(metrics, paras)

    # Merge the search outcome with the metric values into a single record.
    mm = {**mm1, **mm2}
    # Persist one record per evaluated feature subset (save_results_to_csv is
    # assumed to append to the target file -- TODO confirm in utils.io_util).
    save_results_to_csv(mm, "svm-tuning-results_mean_threshold", "data/history")