forked from vaseem-khan/URLcheck
-
Notifications
You must be signed in to change notification settings - Fork 0
/
trainer.py
82 lines (56 loc) · 2.16 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import pandas
from sklearn import preprocessing
from sklearn.ensemble import RandomForestClassifier
import numpy
from sklearn import svm
from sklearn.metrics import accuracy_score
import matplotlib.pylab as plt
import warnings
# Ignore DeprecationWarning entries attributed to line 570 of a module whose
# name matches "pandas".
# NOTE(review): the hard-coded line number presumably targets one specific
# pandas release -- confirm it still matches the installed version.
warnings.filterwarnings("ignore", category=DeprecationWarning,
module="pandas", lineno=570)
def return_nonstring_col(data_cols):
    """Split column names into kept columns and training feature columns.

    Args:
        data_cols: Iterable of column names (for example ``DataFrame.columns``).

    Returns:
        A two-element list ``[cols_to_keep, train_cols]``.
        ``cols_to_keep`` contains every name except 'URL', 'host' and 'path'.
        ``train_cols`` is ``cols_to_keep`` with 'malicious' and 'result'
        removed as well.  Both lists preserve the input order.
    """
    excluded_cols = ('URL', 'host', 'path')
    label_cols = ('malicious', 'result')
    cols_to_keep = [col for col in data_cols if col not in excluded_cols]
    # Feature columns are derived from the already-filtered list, so the
    # excluded columns can never end up among the training features.
    train_cols = [col for col in cols_to_keep if col not in label_cols]
    return [cols_to_keep, train_cols]
def svm_classifier(train, query, train_cols):
    """Fit a default ``svm.SVC`` and print its predictions for ``query``.

    Args:
        train: Frame containing the ``train_cols`` feature columns and a
            'malicious' label column.
        query: Frame containing the ``train_cols`` feature columns and a
            'URL' column.  It is modified in place: the predictions are
            stored in a 'result' column (created or overwritten).
        train_cols: List of feature column names used for fitting and
            predicting.

    Returns:
        None.  The fitted estimator and the 'URL'/'result' columns of
        ``query`` are printed.
    """
    clf = svm.SVC()
    # NOTE(review): features are passed to the SVM unscaled; consider whether
    # standardising them first would help -- confirm against the data.
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print(clf.fit(train[train_cols], train['malicious']))
    query['result'] = clf.predict(query[train_cols])
    print(query[['URL', 'result']])
def forest_classifier_return(train, query, train_cols, n_estimators=150,
                             random_state=None):
    """Fit a random forest and return its predictions for ``query``.

    Args:
        train: Frame containing the ``train_cols`` feature columns and a
            'malicious' label column.
        query: Frame containing the ``train_cols`` feature columns and a
            'URL' column.  It is modified in place: the predictions are
            stored in a 'result' column (created or overwritten).
        train_cols: List of feature column names used for fitting and
            predicting.
        n_estimators: Number of trees in the forest (default 150, the value
            that used to be hard-coded).
        random_state: Optional seed forwarded to the classifier so runs can
            be made reproducible; ``None`` keeps the unseeded behaviour.

    Returns:
        The 'result' column of ``query`` holding the predictions.
    """
    rf = RandomForestClassifier(n_estimators=n_estimators,
                                random_state=random_state)
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print(rf.fit(train[train_cols], train['malicious']))
    query['result'] = rf.predict(query[train_cols])
    # Only the first two rows are printed, as a quick sanity check.
    print(query[['URL', 'result']].head(2))
    return query['result']
def forest_classifier(train, query, train_cols, n_estimators=150,
                      random_state=None):
    """Fit a random forest and print its predictions for ``query``.

    Args:
        train: Frame containing the ``train_cols`` feature columns and a
            'malicious' label column.
        query: Frame containing the ``train_cols`` feature columns and a
            'URL' column.  It is modified in place: the predictions are
            stored in a 'result' column (created or overwritten).
        train_cols: List of feature column names used for fitting and
            predicting.
        n_estimators: Number of trees in the forest (default 150, the value
            that used to be hard-coded).
        random_state: Optional seed forwarded to the classifier so runs can
            be made reproducible; ``None`` keeps the unseeded behaviour.

    Returns:
        None.  The fitted estimator and the 'URL'/'result' columns of
        ``query`` are printed.
    """
    rf = RandomForestClassifier(n_estimators=n_estimators,
                                random_state=random_state)
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print(rf.fit(train[train_cols], train['malicious']))
    query['result'] = rf.predict(query[train_cols])
    print(query[['URL', 'result']])
def train(db, test_db):
    """Train both classifiers on ``db`` and print predictions for ``test_db``.

    Args:
        db: Path (or file-like object) of the training CSV; it must provide
            the feature columns and a 'malicious' column.
        test_db: Path (or file-like object) of the query CSV; it must provide
            the same feature columns and a 'URL' column.

    Returns:
        None.  Results are printed by ``svm_classifier`` and
        ``forest_classifier``.
    """
    # The query file is read first, matching the original read order.
    query_csv = pandas.read_csv(test_db)
    train_csv = pandas.read_csv(db)
    # Feature names come from the training file's columns; only the second
    # element (the training feature list) is needed here.
    train_cols = return_nonstring_col(train_csv.columns)[1]
    svm_classifier(train_csv, query_csv, train_cols)
    print("done svm")
    # The forest run overwrites the 'result' column that the SVM run stored
    # on query_csv.
    forest_classifier(train_csv, query_csv, train_cols)
def train2(db, test_db):
    """Train a random forest on ``db`` and return predictions for ``test_db``.

    Args:
        db: Path (or file-like object) of the training CSV; it must provide
            the feature columns and a 'malicious' column.
        test_db: Path (or file-like object) of the query CSV; it must provide
            the same feature columns and a 'URL' column.

    Returns:
        Whatever ``forest_classifier_return`` returns: the 'result' column
        holding the predictions for the query rows.
    """
    # The query file is read first, matching the original read order.
    query_csv = pandas.read_csv(test_db)
    train_csv = pandas.read_csv(db)
    # Feature names come from the training file's columns; only the second
    # element (the training feature list) is needed here.
    train_cols = return_nonstring_col(train_csv.columns)[1]
    return forest_classifier_return(train_csv, query_csv, train_cols)