-
Notifications
You must be signed in to change notification settings - Fork 0
/
alert-train-classifier.py
executable file
·40 lines (31 loc) · 1.21 KB
/
alert-train-classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import GradientBoostingClassifier
import pickle
import sys
def _build_classifier(model_name):
    """Return an unfitted classifier for *model_name*, or None if unknown.

    Recognised names are "rf" (random forest), "dt" (decision tree),
    "abdt" (AdaBoost over depth-3 decision trees) and "gbdt" (gradient
    boosting). Every estimator is created with random_state=1.
    """
    # Factories are lambdas so that only the requested estimator is built.
    factories = {
        "rf": lambda: RandomForestClassifier(n_estimators=100,
                                             random_state=1),
        "dt": lambda: DecisionTreeClassifier(random_state=1),
        "abdt": lambda: AdaBoostClassifier(DecisionTreeClassifier(max_depth=3),
                                           n_estimators=100, random_state=1),
        "gbdt": lambda: GradientBoostingClassifier(n_estimators=300,
                                                   max_depth=2,
                                                   random_state=1),
    }
    factory = factories.get(model_name)
    return factory() if factory is not None else None


def main(argv=None):
    """Train a classifier on a CSV file and pickle the fitted model.

    Args:
        argv: Command-line vector laid out as
            [program, csv_file, model, model_file]. Defaults to sys.argv.

    Raises:
        SystemExit: with status 1 when the argument count is wrong or the
            model name is not one of rf|dt|abdt|gbdt.
    """
    if argv is None:
        argv = sys.argv

    if len(argv) != 4:
        prog = argv[0] if argv else "alert-train-classifier.py"
        print("Usage:", prog, "<csv_file> <model> <model_file>",
              file=sys.stderr)
        print("<model> = rf|dt|abdt|gbdt", file=sys.stderr)
        # A usage error is a failure: exit non-zero so calling shells and
        # pipelines do not mistake it for a successful training run.
        sys.exit(1)

    csv_path, model_name, model_path = argv[1:]

    # Validate the model name before loading the data so that a typo fails
    # immediately instead of after reading a potentially large CSV file.
    clf = _build_classifier(model_name)
    if clf is None:
        print("Unknown model", model_name, file=sys.stderr)
        sys.exit(1)

    training_set = pd.read_csv(csv_path)

    # 'Label' is the training target; 'Timestamp' and 'SignatureText' are
    # dropped along with it, and every remaining column is used as a feature.
    X = training_set.drop(columns=['Timestamp', 'SignatureText', 'Label'])
    y = training_set['Label']

    model = clf.fit(X, y)

    with open(model_path, 'wb') as model_file:
        pickle.dump(model, model_file)


if __name__ == "__main__":
    main()