diff --git a/gensim/corpora/dictionary.py b/gensim/corpora/dictionary.py index c08d4e31b8..9186b527ae 100644 --- a/gensim/corpora/dictionary.py +++ b/gensim/corpora/dictionary.py @@ -594,6 +594,52 @@ def merge_with(self, other): import gensim.models return gensim.models.VocabTransform(old2new) + def patch_with_special_tokens(self, special_token_dict): + """Patch token2id and id2token using a dictionary of special tokens. + + + **Usecase:** when doing sequence modeling (e.g. named entity recognition), one may want to specify + special tokens that behave differently than others. + One example is the "unknown" token, and another is the padding token. + It is usual to set the padding token to have index `0`, and patching the dictionary with `{'': 0}` + would be one way to specify this. + + Parameters + ---------- + special_token_dict : dict of (str, int) + dict containing the special tokens as keys and their wanted indices as values. + + Examples + -------- + .. sourcecode:: pycon + + >>> from gensim.corpora import Dictionary + >>> + >>> corpus = [["máma", "mele", "maso"], ["ema", "má", "máma"]] + >>> dct = Dictionary(corpus) + >>> + >>> special_tokens = {'pad': 0, 'space': 1} + >>> print(dct.token2id) + {'maso': 0, 'mele': 1, 'máma': 2, 'ema': 3, 'má': 4} + >>> + >>> dct.patch_with_special_tokens(special_tokens) + >>> print(dct.token2id) + {'maso': 6, 'mele': 7, 'máma': 2, 'ema': 3, 'má': 4, 'pad': 0, 'space': 1} + + """ + possible_ids = [] + for token, idx in special_token_dict.items(): + if token in self.token2id and self.token2id[token] == idx: + continue + if token in self.token2id and self.token2id[token] != idx: + possible_ids.append(self.token2id[token]) + del self.token2id[token] + old_token = self[idx] + self.token2id[token] = idx + self.token2id[old_token] = possible_ids.pop() if \ + len(possible_ids) > 0 else len(self.token2id) - 1 + self.id2token = {} # Make sure that id2token is updated according to special tokens. + @staticmethod def load_from_text(fname): """Load a previously stored :class:`~gensim.corpora.dictionary.Dictionary` from a text file. diff --git a/gensim/test/test_corpora_dictionary.py b/gensim/test/test_corpora_dictionary.py index a7fb170253..e5ec3221fd 100644 --- a/gensim/test/test_corpora_dictionary.py +++ b/gensim/test/test_corpora_dictionary.py @@ -324,6 +324,30 @@ def test_dict_interface(self): self.assertTrue(isinstance(d.keys(), list)) self.assertTrue(isinstance(d.values(), list)) + def test_patch_with_special_tokens(self): + special_tokens = {'pad': 0, 'space': 1, 'quake': 3} + corpus = [["máma", "mele", "maso"], ["ema", "má", "máma"]] + d = Dictionary(corpus) + self.assertEqual(len(d.token2id), 5) + d.patch_with_special_tokens(special_tokens) + self.assertEqual(d.token2id['pad'], 0) + self.assertEqual(d.token2id['space'], 1) + self.assertEqual(d.token2id['quake'], 3) + self.assertEqual(len(d.token2id), 8) + self.assertNotIn((0, 1), d.doc2bow(corpus[0])) + self.assertIn((0, 1), d.doc2bow(['pad'] + corpus[0])) + corpus_with_special_tokens = [["máma", "mele", "maso"], ["ema", "má", "máma", "space"]] + d = Dictionary(corpus_with_special_tokens) + self.assertEqual(len(d.token2id), 6) + self.assertNotEqual(d.token2id['space'], 1) + d.patch_with_special_tokens(special_tokens) + self.assertEqual(len(d.token2id), 8) + self.assertEqual(max(d.token2id.values()), 7) + self.assertEqual(d.token2id['space'], 1) + self.assertNotIn((1, 1), d.doc2bow(corpus_with_special_tokens[0])) + self.assertIn((1, 1), d.doc2bow(corpus_with_special_tokens[1])) + + # endclass TestDictionary