Fix Dictionary save_as_text method #56 + fix lint errors
save_as_text now writes num_docs on the first line.
load_from_text loads it in a backward-compatible way.
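
A minimal round-trip sketch of the new behavior (illustrative only; the toy
corpus and the temp path are made up, not taken from the commit):

    from gensim.corpora import Dictionary

    # Three toy documents -> num_docs == 3.
    texts = [["prvé", "slovo"], ["slovo", "druhé"], ["druhé", "slovo"]]
    d = Dictionary(texts)

    # The saved file now starts with "3\n", followed by one
    # "id<TAB>word<TAB>document frequency" line per token.
    d.save_as_text('/tmp/dict.txt')

    # num_docs is read back; files written before this change (no leading
    # count) still load, with num_docs left at 0 and a warning logged.
    d2 = Dictionary.load_from_text('/tmp/dict.txt')
    assert d2.num_docs == 3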
vlejd committed Jun 8, 2017
1 parent f98e012 commit 2351a90
Showing 2 changed files with 93 additions and 33 deletions.
28 changes: 19 additions & 9 deletions gensim/corpora/dictionary.py
@@ -24,13 +24,13 @@

from gensim import utils

from six import PY3, iteritems, iterkeys, itervalues, string_types
from six.moves import xrange
from six.moves import zip as izip

if sys.version_info[0] >= 3:
    unicode = str


logger = logging.getLogger('gensim.corpora.dictionary')

@@ -180,7 +180,7 @@ def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000, keep_tokens=None):
        2. more than `no_above` documents (fraction of total corpus size, *not*
           absolute number).
        3. if tokens are given in keep_tokens (list of strings), they will be kept regardless of
           the `no_below` and `no_above` settings
        4. after (1), (2) and (3), keep only the first `keep_n` most frequent tokens (or
           keep all if `None`).
@@ -194,9 +194,9 @@ def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000, keep_tokens=None):
        # determine which tokens to keep
        if keep_tokens:
            keep_ids = [self.token2id[v] for v in keep_tokens if v in self.token2id]
            good_ids = (v for v in itervalues(self.token2id)
                        if no_below <= self.dfs.get(v, 0) <= no_above_abs or
                        v in keep_ids)
        else:
            good_ids = (
                v for v in itervalues(self.token2id)
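        # Illustration (made-up toy corpus, not part of this patch):
        #   d = Dictionary([["rare", "common"], ["common"], ["common", "other"]])
        #   d.filter_extremes(no_below=2, no_above=1.0, keep_tokens=['rare'])
        # 'rare' appears in only one document but survives via keep_tokens,
        # while 'other' (also in one document) is dropped by no_below=2.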
@@ -230,7 +230,7 @@ def filter_n_most_frequent(self, remove_n):
        # do the actual filtering, then rebuild dictionary to remove gaps in ids
        most_frequent_words = [(self[id], self.dfs.get(id, 0)) for id in most_frequent_ids]
        logger.info("discarding %i tokens: %s...", len(most_frequent_ids), most_frequent_words[:10])

        self.filter_tokens(bad_ids=most_frequent_ids)
        logger.info("resulting dictionary: %s" % self)

@@ -280,6 +280,7 @@ def compactify(self):
    def save_as_text(self, fname, sort_by_word=True):
        """
        Save this Dictionary to a text file, in format:
        `num_docs`
        `id[TAB]word_utf8[TAB]document frequency[NEWLINE]`. Sorted by word,
        or by decreasing word frequency.
@@ -288,12 +289,14 @@ def save_as_text(self, fname, sort_by_word=True):
"""
        logger.info("saving dictionary mapping to %s", fname)
        with utils.smart_open(fname, 'wb') as fout:
            numdocs_line = "%d\n" % self.num_docs
            fout.write(utils.to_utf8(numdocs_line))
            if sort_by_word:
                for token, tokenid in sorted(iteritems(self.token2id)):
                    line = "%i\t%s\t%i\n" % (tokenid, token, self.dfs.get(tokenid, 0))
                    fout.write(utils.to_utf8(line))
            else:
                for tokenid, freq in sorted(iteritems(self.dfs), key=lambda item: -item[1]):
                    line = "%i\t%s\t%i\n" % (tokenid, self[tokenid], freq)
                    fout.write(utils.to_utf8(line))
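            # Illustrative output (made-up tokens, not from the commit): with
            # num_docs == 10, tokens {0: "foo", 1: "bar"} and dfs == {0: 7, 1: 5},
            #   sort_by_word=True  writes "10\n1\tbar\t5\n0\tfoo\t7\n" (alphabetical)
            #   sort_by_word=False writes "10\n0\tfoo\t7\n1\tbar\t5\n" (decreasing frequency)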

@@ -352,6 +355,13 @@ def load_from_text(fname):
        with utils.smart_open(fname) as f:
            for lineno, line in enumerate(f):
                line = utils.to_unicode(line)
                if lineno == 0:
                    if line.strip().isdigit():
                        # Older versions of save_as_text may not write num_docs on first line.
                        result.num_docs = int(line.strip())
                        continue
                    else:
                        logging.warning("Text does not contain num_docs on the first line.")
                try:
                    wordid, word, docfreq = line[:-1].split('\t')
                except Exception:
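In short, load_from_text now accepts both layouts (an illustrative sketch;
the words and counts are made up):

    legacy = "1\tprvé\t1\n2\tslovo\t2\n"      # old format, no header: num_docs stays 0, a warning is logged
    current = "2\n1\tprvé\t1\n2\tslovo\t2\n"  # new format: num_docs == 2 on the first line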
98 changes: 74 additions & 24 deletions gensim/test/test_corpora_dictionary.py
@@ -120,35 +120,34 @@ def testFilter(self):
        d.filter_extremes(no_below=2, no_above=1.0, keep_n=4)
        expected = {0: 3, 1: 3, 2: 3, 3: 3}
        self.assertEqual(d.dfs, expected)

    def testFilterKeepTokens_keepTokens(self):
        # provide keep_tokens argument, keep the tokens given
        d = Dictionary(self.texts)
        d.filter_extremes(no_below=3, no_above=1.0, keep_tokens=['human', 'survey'])
        expected = set(['graph', 'trees', 'human', 'system', 'user', 'survey'])
        self.assertEqual(set(d.token2id.keys()), expected)

    def testFilterKeepTokens_unchangedFunctionality(self):
        # do not provide keep_tokens argument, filter_extremes functionality is unchanged
        d = Dictionary(self.texts)
        d.filter_extremes(no_below=3, no_above=1.0)
        expected = set(['graph', 'trees', 'system', 'user'])
        self.assertEqual(set(d.token2id.keys()), expected)

    def testFilterKeepTokens_unseenToken(self):
        # do provide keep_tokens argument with unseen tokens, filter_extremes functionality is unchanged
        d = Dictionary(self.texts)
        d.filter_extremes(no_below=3, no_above=1.0, keep_tokens=['unknown_token'])
        expected = set(['graph', 'trees', 'system', 'user'])
        self.assertEqual(set(d.token2id.keys()), expected)

    def testFilterMostFrequent(self):
        d = Dictionary(self.texts)
        d.filter_n_most_frequent(4)
        expected = {0: 2, 1: 2, 2: 2, 3: 2, 4: 2, 5: 2, 6: 2, 7: 2}
        self.assertEqual(d.dfs, expected)

    def testFilterTokens(self):
        self.maxDiff = 10000
        d = Dictionary(self.texts)
@@ -157,16 +156,15 @@ def testFilterTokens(self):
        d.filter_tokens([0])

        expected = {'computer': 0, 'eps': 8, 'graph': 10, 'human': 1,
                    'interface': 2, 'minors': 11, 'response': 3, 'survey': 4,
                    'system': 5, 'time': 6, 'trees': 9, 'user': 7}
        del expected[removed_word]
        self.assertEqual(sorted(d.token2id.keys()), sorted(expected.keys()))

        expected[removed_word] = len(expected)
        d.add_documents([[removed_word]])
        self.assertEqual(sorted(d.token2id.keys()), sorted(expected.keys()))

    def test_doc2bow(self):
        d = Dictionary([["žluťoučký"], ["žluťoučký"]])

@@ -179,6 +177,58 @@ def test_doc2bow(self):
        # unicode must be converted to utf8
        self.assertEqual(d.doc2bow([u'\u017elu\u0165ou\u010dk\xfd']), [(0, 1)])

    def test_saveAsText(self):
        """`Dictionary` can be saved as a textfile."""
        tmpf = get_tmpfile('save_dict_test.txt')
        small_text = [["prvé", "slovo"],
                      ["slovo", "druhé"],
                      ["druhé", "slovo"]]

        d = Dictionary(small_text)

        d.save_as_text(tmpf)
        with open(tmpf) as file:
            serialized_lines = file.readlines()
            self.assertEqual(serialized_lines[0], "3\n")
            self.assertEqual(len(serialized_lines), 4)
            # We do not know which word will get which id, so compare from the tab onward.
            self.assertEqual(serialized_lines[1][1:], "\tdruhé\t2\n")
            self.assertEqual(serialized_lines[2][1:], "\tprvé\t1\n")
            self.assertEqual(serialized_lines[3][1:], "\tslovo\t3\n")

        d.save_as_text(tmpf, sort_by_word=False)
        with open(tmpf) as file:
            serialized_lines = file.readlines()
            self.assertEqual(serialized_lines[0], "3\n")
            self.assertEqual(len(serialized_lines), 4)
            # sorted by decreasing document frequency
            self.assertEqual(serialized_lines[1][1:], "\tslovo\t3\n")
            self.assertEqual(serialized_lines[2][1:], "\tdruhé\t2\n")
            self.assertEqual(serialized_lines[3][1:], "\tprvé\t1\n")

    def test_loadFromText(self):
        tmpf = get_tmpfile('load_dict_test.txt')
        # legacy format: no num_docs on the first line
        no_num_docs_serialization = "1\tprvé\t1\n2\tslovo\t2\n"
        with open(tmpf, "w") as file:
            file.write(no_num_docs_serialization)

        d = Dictionary.load_from_text(tmpf)
        self.assertEqual(d.token2id["prvé"], 1)
        self.assertEqual(d.token2id["slovo"], 2)
        self.assertEqual(d.dfs[1], 1)
        self.assertEqual(d.dfs[2], 2)
        self.assertEqual(d.num_docs, 0)

        # new format: num_docs on the first line
        num_docs_serialization = "2\n1\tprvé\t1\n2\tslovo\t2\n"
        with open(tmpf, "w") as file:
            file.write(num_docs_serialization)

        d = Dictionary.load_from_text(tmpf)
        self.assertEqual(d.token2id["prvé"], 1)
        self.assertEqual(d.token2id["slovo"], 2)
        self.assertEqual(d.dfs[1], 1)
        self.assertEqual(d.dfs[2], 2)
        self.assertEqual(d.num_docs, 2)

    def test_saveAsText_and_loadFromText(self):
        """`Dictionary` can be saved as a textfile and loaded again from that textfile."""
        tmpf = get_tmpfile('dict_test.txt')
@@ -195,23 +245,23 @@ def test_from_corpus(self):
"""build `Dictionary` from an existing corpus"""

        documents = ["Human machine interface for lab abc computer applications",
                     "A survey of user opinion of computer system response time",
                     "The EPS user interface management system",
                     "System and human system engineering testing of EPS",
                     "Relation of user perceived response time to error measurement",
                     "The generation of random binary unordered trees",
                     "The intersection graph of paths in trees",
                     "Graph minors IV Widths of trees and well quasi ordering",
                     "Graph minors A survey"]
        stoplist = set('for a of the and to in'.split())
        texts = [[word for word in document.lower().split() if word not in stoplist]
                 for document in documents]

        # remove words that appear only once
        all_tokens = sum(texts, [])
        tokens_once = set(word for word in set(all_tokens) if all_tokens.count(word) == 1)
        texts = [[word for word in text if word not in tokens_once]
                 for text in texts]

        dictionary = Dictionary(texts)
        corpus = [dictionary.doc2bow(text) for text in texts]
@@ -260,7 +310,7 @@ def test_dict_interface(self):
        self.assertTrue(isinstance(d.keys(), list))
        self.assertTrue(isinstance(d.values(), list))

# endclass TestDictionary


if __name__ == '__main__':
