From cac63143960c4d80b10e54d478a48275e5f95595 Mon Sep 17 00:00:00 2001 From: vlejd Date: Thu, 8 Jun 2017 21:01:23 +0200 Subject: [PATCH] Fix Dictionary save_as_text method #56 + fix lint errors save_as_text now writes num_docs on the first line. load_from_text loads it in a backward-compatible way. --- gensim/corpora/dictionary.py | 26 ++++++---- gensim/test/test_corpora_dictionary.py | 72 +++++++++++++++++--------- 2 files changed, 65 insertions(+), 33 deletions(-) diff --git a/gensim/corpora/dictionary.py b/gensim/corpora/dictionary.py index 484684c26d..644dac3c42 100644 --- a/gensim/corpora/dictionary.py +++ b/gensim/corpora/dictionary.py @@ -24,13 +24,13 @@ from gensim import utils -if sys.version_info[0] >= 3: - unicode = str - from six import PY3, iteritems, iterkeys, itervalues, string_types from six.moves import xrange from six.moves import zip as izip +if sys.version_info[0] >= 3: + unicode = str + logger = logging.getLogger('gensim.corpora.dictionary') @@ -180,7 +180,7 @@ def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000, keep_tokens=N 2. more than `no_above` documents (fraction of total corpus size, *not* absolute number). 3. if tokens are given in keep_tokens (list of strings), they will be kept regardless of - the `no_below` and `no_above` settings + the `no_below` and `no_above` settings 4. after (1), (2) and (3), keep only the first `keep_n` most frequent tokens (or keep all if `None`). 
@@ -194,9 +194,9 @@ def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000, keep_tokens=N # determine which tokens to keep if keep_tokens: keep_ids = [self.token2id[v] for v in keep_tokens if v in self.token2id] - good_ids = (v for v in itervalues(self.token2id) - if no_below <= self.dfs.get(v, 0) <= no_above_abs - or v in keep_ids) + good_ids = (v for v in itervalues(self.token2id) + if no_below <= self.dfs.get(v, 0) <= no_above_abs or + v in keep_ids) else: good_ids = ( v for v in itervalues(self.token2id) @@ -230,7 +230,7 @@ def filter_n_most_frequent(self, remove_n): # do the actual filtering, then rebuild dictionary to remove gaps in ids most_frequent_words = [(self[id], self.dfs.get(id, 0)) for id in most_frequent_ids] logger.info("discarding %i tokens: %s...", len(most_frequent_ids), most_frequent_words[:10]) - + self.filter_tokens(bad_ids=most_frequent_ids) logger.info("resulting dictionary: %s" % self) @@ -280,6 +280,7 @@ def compactify(self): def save_as_text(self, fname, sort_by_word=True): """ Save this Dictionary to a text file, in format: + `num_docs` `id[TAB]word_utf8[TAB]document frequency[NEWLINE]`. Sorted by word, or by decreasing word frequency. 
@@ -288,12 +289,14 @@ def save_as_text(self, fname, sort_by_word=True): """ logger.info("saving dictionary mapping to %s", fname) with utils.smart_open(fname, 'wb') as fout: + numdocs_line = "%d\n" % self.num_docs + fout.write(utils.to_utf8(numdocs_line)) if sort_by_word: for token, tokenid in sorted(iteritems(self.token2id)): line = "%i\t%s\t%i\n" % (tokenid, token, self.dfs.get(tokenid, 0)) fout.write(utils.to_utf8(line)) else: - for tokenid, freq in sorted(iteritems(self.dfs), key=lambda item: -item[1]): + for tokenid, freq in sorted(iteritems(self.dfs), key=lambda item: item[1]): line = "%i\t%s\t%i\n" % (tokenid, self[tokenid], freq) fout.write(utils.to_utf8(line)) @@ -352,6 +355,11 @@ def load_from_text(fname): with utils.smart_open(fname) as f: for lineno, line in enumerate(f): line = utils.to_unicode(line) + if lineno == 0: + if line.strip().isdigit(): + # Older versions of save_as_text may not write num_docs on first line. + result.num_docs = int(line.strip()) + continue try: wordid, word, docfreq = line[:-1].split('\t') except Exception: diff --git a/gensim/test/test_corpora_dictionary.py b/gensim/test/test_corpora_dictionary.py index 16c499b245..576525a728 100644 --- a/gensim/test/test_corpora_dictionary.py +++ b/gensim/test/test_corpora_dictionary.py @@ -120,35 +120,34 @@ def testFilter(self): d.filter_extremes(no_below=2, no_above=1.0, keep_n=4) expected = {0: 3, 1: 3, 2: 3, 3: 3} self.assertEqual(d.dfs, expected) - + def testFilterKeepTokens_keepTokens(self): # provide keep_tokens argument, keep the tokens given d = Dictionary(self.texts) d.filter_extremes(no_below=3, no_above=1.0, keep_tokens=['human', 'survey']) expected = set(['graph', 'trees', 'human', 'system', 'user', 'survey']) self.assertEqual(set(d.token2id.keys()), expected) - + def testFilterKeepTokens_unchangedFunctionality(self): # do not provide keep_tokens argument, filter_extremes functionality is unchanged d = Dictionary(self.texts) d.filter_extremes(no_below=3, no_above=1.0) expected 
= set(['graph', 'trees', 'system', 'user']) self.assertEqual(set(d.token2id.keys()), expected) - + def testFilterKeepTokens_unseenToken(self): # do provide keep_tokens argument with unseen tokens, filter_extremes functionality is unchanged d = Dictionary(self.texts) d.filter_extremes(no_below=3, no_above=1.0, keep_tokens=['unknown_token']) expected = set(['graph', 'trees', 'system', 'user']) - self.assertEqual(set(d.token2id.keys()), expected) + self.assertEqual(set(d.token2id.keys()), expected) def testFilterMostFrequent(self): - d = Dictionary(self.texts) - d.filter_n_most_frequent(4) - expected = {0: 2, 1: 2, 2: 2, 3: 2, 4: 2, 5: 2, 6: 2, 7: 2} - self.assertEqual(d.dfs, expected) - - + d = Dictionary(self.texts) + d.filter_n_most_frequent(4) + expected = {0: 2, 1: 2, 2: 2, 3: 2, 4: 2, 5: 2, 6: 2, 7: 2} + self.assertEqual(d.dfs, expected) + def testFilterTokens(self): self.maxDiff = 10000 d = Dictionary(self.texts) @@ -157,8 +156,8 @@ def testFilterTokens(self): d.filter_tokens([0]) expected = {'computer': 0, 'eps': 8, 'graph': 10, 'human': 1, - 'interface': 2, 'minors': 11, 'response': 3, 'survey': 4, - 'system': 5, 'time': 6, 'trees': 9, 'user': 7} + 'interface': 2, 'minors': 11, 'response': 3, 'survey': 4, + 'system': 5, 'time': 6, 'trees': 9, 'user': 7} del expected[removed_word] self.assertEqual(sorted(d.token2id.keys()), sorted(expected.keys())) @@ -166,7 +165,6 @@ def testFilterTokens(self): d.add_documents([[removed_word]]) self.assertEqual(sorted(d.token2id.keys()), sorted(expected.keys())) - def test_doc2bow(self): d = Dictionary([["žluťoučký"], ["žluťoučký"]]) @@ -179,6 +177,32 @@ def test_doc2bow(self): # unicode must be converted to utf8 self.assertEqual(d.doc2bow([u'\u017elu\u0165ou\u010dk\xfd']), [(0, 1)]) + def test_saveAsText(self): + """`Dictionary` can be saved as textfile. 
""" + tmpf = get_tmpfile('save_dict_test.txt') + small_text = [["prvé", "slovo"], + ["slovo", "druhé"], + ["druhé", "slovo"]] + + d = Dictionary(small_text) + + d.save_as_text(tmpf) + serialized_lines = open(tmpf).readlines() + self.assertEqual(serialized_lines[0], "3\n") + self.assertEqual(len(serialized_lines), 4) + # We do not know, which word will have which index + self.assertEqual(serialized_lines[1][1:], "\tdruhé\t2\n") + self.assertEqual(serialized_lines[2][1:], "\tprvé\t1\n") + self.assertEqual(serialized_lines[3][1:], "\tslovo\t3\n") + + d.save_as_text(tmpf, sort_by_word=False) + serialized_lines = open(tmpf).readlines() + self.assertEqual(serialized_lines[0], "3\n") + self.assertEqual(len(serialized_lines), 4) + self.assertEqual(serialized_lines[1][1:], "\tprvé\t1\n") + self.assertEqual(serialized_lines[2][1:], "\tdruhé\t2\n") + self.assertEqual(serialized_lines[3][1:], "\tslovo\t3\n") + def test_saveAsText_and_loadFromText(self): """`Dictionary` can be saved as textfile and loaded again from textfile. 
""" tmpf = get_tmpfile('dict_test.txt') @@ -195,23 +219,23 @@ def test_from_corpus(self): """build `Dictionary` from an existing corpus""" documents = ["Human machine interface for lab abc computer applications", - "A survey of user opinion of computer system response time", - "The EPS user interface management system", - "System and human system engineering testing of EPS", - "Relation of user perceived response time to error measurement", - "The generation of random binary unordered trees", - "The intersection graph of paths in trees", - "Graph minors IV Widths of trees and well quasi ordering", - "Graph minors A survey"] + "A survey of user opinion of computer system response time", + "The EPS user interface management system", + "System and human system engineering testing of EPS", + "Relation of user perceived response time to error measurement", + "The generation of random binary unordered trees", + "The intersection graph of paths in trees", + "Graph minors IV Widths of trees and well quasi ordering", + "Graph minors A survey"] stoplist = set('for a of the and to in'.split()) texts = [[word for word in document.lower().split() if word not in stoplist] - for document in documents] + for document in documents] # remove words that appear only once all_tokens = sum(texts, []) tokens_once = set(word for word in set(all_tokens) if all_tokens.count(word) == 1) texts = [[word for word in text if word not in tokens_once] - for text in texts] + for text in texts] dictionary = Dictionary(texts) corpus = [dictionary.doc2bow(text) for text in texts] @@ -260,7 +284,7 @@ def test_dict_interface(self): self.assertTrue(isinstance(d.keys(), list)) self.assertTrue(isinstance(d.values(), list)) -#endclass TestDictionary +# endclass TestDictionary if __name__ == '__main__':