-
Notifications
You must be signed in to change notification settings - Fork 54
/
Copy pathsort_samples_by_kenlm.py
74 lines (56 loc) · 2.04 KB
/
sort_samples_by_kenlm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# -*- coding: utf-8 -*-
'''
Post-processing of the sentence lists for the https://github.com/Koziev/NLP_Datasets repository.
Each sentence is normalized (lower-cased, punctuation dropped, integers replaced
with _num_), deduplicated, and then sorted in descending order of the probability
assigned by a pre-trained KenLM language model.
'''
from __future__ import print_function
import sys
import os
import codecs
import itertools
import unicodedata
import kenlm
import nltk
import glob
def tokenize(s):
    """Split *s* into word tokens with NLTK's default word tokenizer.

    This is a generator: ``nltk.word_tokenize`` is only invoked once the
    caller starts iterating, and tokens are then handed out one at a time.
    """
    tokens = nltk.word_tokenize(s)
    for token in tokens:
        yield token
def is_punct(word):
    """Return True if *word* begins with an "other punctuation" character.

    Only the first character is inspected, and only the Unicode general
    category ``Po`` counts (e.g. ``.``, ``,``, ``!``, ``?``). Dashes (``Pd``),
    brackets (``Ps``/``Pe``) and initial/final quotes (``Pi``/``Pf``) are
    not treated as punctuation. An empty string is never punctuation.
    """
    if len(word) == 0:
        return False
    first_char = word[0]
    return unicodedata.category(first_char) == 'Po'
def is_num(s):
    """Return True if *s* is an integer literal with an optional sign.

    Accepts strings such as ``"42"``, ``"-7"`` or ``"+3"``. An empty string,
    a lone sign, decimals and exponents are rejected. Digit detection is
    delegated to ``str.isdigit``.
    """
    if len(s) == 0:
        return False
    # Strip a single leading sign, then require the rest to be all digits.
    body = s[1:] if s[0] in ('-', '+') else s
    return body.isdigit()
def prepare_word(w):
    """Normalize a single token for the language model.

    Tokens recognized by ``is_num`` (optionally signed integers) collapse to
    the placeholder ``u'_num_'``; every other token is lower-cased.
    """
    return u'_num_' if is_num(w) else w.lower()
def prepare4lm(tokens):
    """Build the space-joined sentence string that is fed to KenLM.

    Empty tokens and tokens flagged by ``is_punct`` are dropped; every
    remaining token is normalized with ``prepare_word``.

    :param tokens: iterable of token strings (e.g. the output of ``tokenize``).
    :return: the normalized tokens joined by single spaces, or ``u''`` when
        no token survives the filtering.
    """
    # u' '.join behaves identically on Python 2 and 3, whereas the builtin
    # ``unicode`` no longer exists on Python 3 (NameError).
    return u' '.join(
        [prepare_word(t) for t in tokens if len(t) > 0 and not is_punct(t)]
    )
# -------------------------------------------------------------------
# Script body: score every ./*.txt file with a KenLM model and write its
# deduplicated, normalized sentences to a matching *.sorted file, ordered
# from the highest model score to the lowest.
#
# The model path may be passed as the first command-line argument; without
# it the original hard-coded path is used.
DEFAULT_MODEL_FILEPATH = '/home/eek/polygon/kenlm/ru.text.arpa'
model_filepath = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_MODEL_FILEPATH
print('Loading the language model {}...'.format(model_filepath) )
model = kenlm.Model(model_filepath)

# sorted() makes the processing (and logging) order deterministic.
for filename in sorted(glob.glob('./*.txt')):
    print(u'Processing {}'.format(filename))
    sent_set = set()   # normalized sentences already seen in this file
    sent_list = []     # (normalized sentence, model score) in input order
    with codecs.open(filename, 'r', 'utf-8') as rdr:
        for line in rdr:
            line2 = prepare4lm(tokenize(line.strip()))
            # Blank and punctuation-only lines normalize to u''; skip them
            # rather than emitting an empty "sentence" into the output.
            if not line2 or line2 in sent_set:
                continue
            sent_set.add(line2)
            score = model.score(line2, bos=True, eos=True)
            sent_list.append((line2, score))

    # Replace only the trailing extension: str.replace would also rewrite
    # any '.txt' that occurs earlier in the path.
    new_filename = os.path.splitext(filename)[0] + u'.sorted'
    print(u'Storing {} lines to {}'.format(len(sent_list), new_filename))
    with codecs.open(new_filename, 'w', 'utf-8') as wrt:
        # reverse=True keeps the sort stable, so equally scored sentences
        # keep their input order, exactly as with a negated sort key.
        for sent, _ in sorted(sent_list, key=lambda item: item[1], reverse=True):
            wrt.write(sent)
            wrt.write(u'\n')