-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_test.py
59 lines (48 loc) · 1.81 KB
/
train_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# coding: utf-8
import logging
from src import settings
from src.models import Word2vec
from src.models import RandomForest
from src import DataLoader
from src.preprocessors import tokenizer_spacy_en # noqa: F401
from src.preprocessors import tokenizer_re # noqa: F401
from src.preprocessors import avg_words_vectors
# Make the progress messages visible. With no logging configuration the root
# logger only emits WARNING and above, so every logger.info() call in this
# script (including the final metrics) would be silently dropped.
# basicConfig() does nothing when the root logger already has handlers, so a
# configuration installed on the root logger before this point is untouched.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Location of the dataset file; used for both the download and the load step.
dataset_path = settings.DATASET_FOLDER / settings.DATASET_NAME

logger.info("Dataset Downloading...")
DataLoader.download_from_gdrive(
    url=settings.DATASET_URL,
    destination=dataset_path,
)

logger.info("Dataset Loading...")
dataloader = DataLoader(dataset_path)
dataloader.load_dataframe_from_csv()

# Lower-case every message and tokenize it with the regex tokenizer; the
# resulting token lists are stored in a new 'message_words' column.
# NOTE(review): progress_apply is presumably registered by tqdm's pandas
# integration somewhere inside `src` -- confirm.
logger.info("Dataset Processing...")
dataloader.dataframe['message_words'] = dataloader.dataframe['message'].progress_apply(  # noqa: E501
    lambda message: tokenizer_re(message.lower())
)
print(dataloader.dataframe['message_words'])
# Dimensionality of the word embeddings. The same value is passed to the
# Word2vec training call and to the averaging step below, so it is defined
# once here instead of being repeated as a literal in both places.
# NOTE(review): the two values presumably have to match -- confirm against
# avg_words_vectors.
VECTOR_SIZE = 300

# Train the embedding model on the tokenized messages.
logger.info("train word2vec model")
word2vec = Word2vec()
word2vec.train(
    dataloader.dataframe['message_words'],
    vector_size=VECTOR_SIZE,
    window=5,
    min_count=1,
    workers=4,
)

# Compute one vector per message from its token list and store it in a new
# 'avg_words_vectors' column.
logger.info("Dataset Vectorizing...")
dataloader.dataframe['avg_words_vectors'] = dataloader.dataframe['message_words'].progress_apply(  # noqa: E501
    lambda words: avg_words_vectors(
        word2vec=word2vec,
        words_list=words,
        vector_size=VECTOR_SIZE,
    )
)
# Split the vectorized dataset into train and test partitions.
logger.info("Dataset Splitting...")
split_options = {
    'features_names': ['avg_words_vectors'],
    'target_name': 'label',
    'test_size': 0.2,
    'seed': 42,
}
dataloader.split_dataframe(**split_options)

# Materialize the per-row vectors of each partition as plain Python lists.
train_column = dataloader.x_train['avg_words_vectors']
x_train = list(train_column)
test_column = dataloader.x_test['avg_words_vectors']
x_test = list(test_column)

# Fit the random forest on the training partition, then evaluate it on the
# test partition and log whatever the evaluation returns.
logger.info("Random Forest Training & Test")
forest = RandomForest(n_estimators=100)
forest.train(x_train, dataloader.y_train)
evaluation_options = {'pos_label': 1, 'average': 'binary'}
result = forest.test(x_test, dataloader.y_test, **evaluation_options)
logger.info(result)