-
Notifications
You must be signed in to change notification settings - Fork 186
/
train_word2vec.py
69 lines (55 loc) · 2.05 KB
/
train_word2vec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
""" pretrain a word2vec on the corpus"""
import argparse
import json
import logging
import os
from os.path import join, exists
from time import time
from datetime import timedelta
from cytoolz import concatv
import gensim
from utils import count_data
try:
DATA_DIR = os.environ['DATA']
except KeyError:
print('please use environment variable to specify data directories')
class Sentences(object):
""" needed for gensim word2vec training"""
def __init__(self):
self._path = join(DATA_DIR, 'train')
self._n_data = count_data(self._path)
def __iter__(self):
for i in range(self._n_data):
with open(join(self._path, '{}.json'.format(i))) as f:
data = json.loads(f.read())
for s in concatv(data['article'], data['abstract']):
yield ['<s>'] + s.lower().split() + [r'<\s>']
def main(args):
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s',
level=logging.INFO)
start = time()
save_dir = args.path
if not exists(save_dir):
os.makedirs(save_dir)
sentences = Sentences()
model = gensim.models.Word2Vec(
size=args.dim, min_count=5, workers=16, sg=1)
model.build_vocab(sentences)
print('vocab built in {}'.format(timedelta(seconds=time()-start)))
model.train(sentences,
total_examples=model.corpus_count, epochs=model.iter)
model.save(join(save_dir, 'word2vec.{}d.{}k.bin'.format(
args.dim, len(model.wv.vocab)//1000)))
model.wv.save_word2vec_format(join(
save_dir,
'word2vec.{}d.{}k.w2v'.format(args.dim, len(model.wv.vocab)//1000)
))
print('word2vec trained in {}'.format(timedelta(seconds=time()-start)))
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='train word2vec embedding used for model initialization'
)
parser.add_argument('--path', required=True, help='root of the model')
parser.add_argument('--dim', action='store', type=int, default=128)
args = parser.parse_args()
main(args)