diff --git a/docs/source/data_functional.rst b/docs/source/data_functional.rst
index 347f3b433..553bb004d 100644
--- a/docs/source/data_functional.rst
+++ b/docs/source/data_functional.rst
@@ -41,3 +41,9 @@ torchtext.data.functional
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autofunction:: numericalize_tokens_from_iterator
+
+
+:hidden:`filter_wikipedia_xml`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: filter_wikipedia_xml
diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index adead6c4a..4f45d217a 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -27,9 +27,6 @@ The following datasets are available:
 Text Classification
 ^^^^^^^^^^^^^^^^^^^
 
-TextClassificationDataset
-~~~~~~~~~~~~~~~~~~~~~~~~~
-
 AG_NEWS
 ~~~~~~~
 
@@ -126,6 +123,7 @@ CoNLL2000Chunking
 
 .. autofunction:: CoNLL2000Chunking
 
+
 Question Answer
 ^^^^^^^^^^^^^^^
 
@@ -139,3 +137,13 @@ SQuAD 2.0
 ~~~~~~~~~
 
 .. autofunction:: SQuAD2
+
+
+Unsupervised Learning
+^^^^^^^^^^^^^^^^^^^^^
+
+EnWik9
+~~~~~~
+
+.. autofunction:: EnWik9
+
diff --git a/torchtext/data/functional.py b/torchtext/data/functional.py
index 995025e92..48d4d8562 100644
--- a/torchtext/data/functional.py
+++ b/torchtext/data/functional.py
@@ -180,3 +180,76 @@ def numericalize_tokens_from_iterator(vocab, iterator, removed_tokens=None):
         else:
             yield iter(map(lambda x: vocab[x],
                        filter(lambda x: x not in removed_tokens, tokens)))
+
+
+# Substitution patterns ported from fastText's wikifil.pl, applied in order:
+# strip XML/wiki markup, lowercase letters, spell out digits, squash whitespace.
+_patterns = [(r'<.*>', ''),
+             (r'&amp;', '&'),
+             (r'&lt;', '<'),
+             (r'&gt;', '>'),
+             (r'<ref[^<]*<\/ref>', ''),
+             (r'<[^>]*>', ''),
+             (r'\[http:[^] ]*', '['),
+             (r'\|thumb', ''),
+             (r'\|left', ''),
+             (r'\|right', ''),
+             (r'\|\d+px', ''),
+             (r'\[\[image:[^\[\]]*\|', ''),
+             (r'\[\[category:([^|\]]*)[^]]*\]\]', r'[[\1]]'),  # \1, not Perl's $1
+             (r'\[\[[a-z\-]*:[^\]]*\]\]', ''),
+             (r'\[\[[^\|\]]*\|', '[['),
+             (r'\{\{[^\}]*\}\}', ''),
+             (r'\{[^\}]*\}', ''),
+             (r'\[', ''),
+             (r'\]', ''),
+             (r'&[^;]*;', ' '),
+             (r'A', 'a'), (r'B', 'b'), (r'C', 'c'),
+             (r'D', 'd'), (r'E', 'e'), (r'F', 'f'),
+             (r'G', 'g'), (r'H', 'h'), (r'I', 'i'),
+             (r'J', 'j'), (r'K', 'k'), (r'L', 'l'),
+             (r'M', 'm'), (r'N', 'n'), (r'O', 'o'),
+             (r'P', 'p'), (r'Q', 'q'), (r'R', 'r'),
+             (r'S', 's'), (r'T', 't'), (r'U', 'u'),
+             (r'V', 'v'), (r'W', 'w'), (r'X', 'x'),
+             (r'Y', 'y'), (r'Z', 'z'),
+             (r'0', ' zero '), (r'1', ' one '), (r'2', ' two '),
+             (r'3', ' three '), (r'4', ' four '), (r'5', ' five '),
+             (r'6', ' six '), (r'7', ' seven '), (r'8', ' eight '),
+             (r'9', ' nine '),
+             (r'[^a-z\n]+', ' '),
+             (r'\n ', ''),
+             (r'\s+', ' '),
+             (r'\n\s*\n', r'\n')
+             ]
+
+
+def filter_wikipedia_xml(text_iterator):
+    r"""Filter Wikipedia XML lines according to https://github.com/facebookresearch/fastText/blob/master/wikifil.pl
+
+    Args:
+        text_iterator: An iterable that yields strings, e.g. a list of
+            strings, a file object, or a generator.
+
+    Examples:
+        >>> from torchtext.data.functional import filter_wikipedia_xml
+        >>> from torchtext.datasets import EnWik9
+        >>> data_iter = EnWik9(split='train')
+        >>> filter_data_iter = filter_wikipedia_xml(data_iter)
+        >>> file_name = '.data/EnWik9/enwik9'
+        >>> filter_data_iter = filter_wikipedia_xml(open(file_name, 'r'))
+    """
+
+    try:
+        iter(text_iterator)
+    except TypeError:
+        raise TypeError("Input {} must support iterator semantics".format(text_iterator))
+
+    norm_transform = custom_replace(_patterns)
+    for line in text_iterator:
+        # Redirect stubs carry no article text, so drop them outright.
+        if '#redirect' in line or '#REDIRECT' in line:
+            continue
+        line = list(norm_transform([line]))[0].strip()
+        if line:
+            yield line
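
A quick sanity check for the new function (a minimal sketch, not part of the patch: it assumes the patched torchtext is importable, and the two sample lines are made up for illustration rather than taken from EnWik9):

    from torchtext.data.functional import filter_wikipedia_xml

    sample = [
        "'''Anarchism''' is a [[political philosophy]] from 1842.",
        "#REDIRECT [[Anarchism]]",  # redirect stubs are dropped by the filter
    ]
    for line in filter_wikipedia_xml(sample):
        print(line)
    # prints: anarchism is a political philosophy from one eight four two

Markup and punctuation are stripped, letters lowercased, digits spelled out, and the redirect line is skipped entirely, matching wikifil.pl's behavior.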