Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

308 random sample on stream #1408

Merged
merged 3 commits into from
Jun 15, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions gensim/corpora/textcorpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from __future__ import with_statement

import logging
import random

from gensim import interfaces, utils
from six import string_types
Expand Down Expand Up @@ -97,6 +98,28 @@ def get_texts(self):
else:
yield utils.tokenize(line, lowercase=True)

def sample_texts(self, n):
"""
Yield n random texts from the corpus without replacement.

Given the the number of remaingin elements in stream is remaining and we need
to choose n elements, the probability for current element to be chosen is n/remaining.
If we choose it, we just decreese the n and move to the next element.
"""
length = len(self)
if not n <= length:
raise ValueError("sample larger than population")

if not 0 <= n:
raise ValueError("negative sample size")

for i, sample in enumerate(self.get_texts()):
remaining_in_stream = length - i
chance = random.randint(1, remaining_in_stream)
if chance <= n:
n -= 1
yield sample

def __len__(self):
if not hasattr(self, 'length'):
# cache the corpus length
Expand Down
55 changes: 55 additions & 0 deletions gensim/test/test_textcorpus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2010 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Automated tests for checking the WikiCorpus
"""


import logging
import unittest

from gensim.corpora.textcorpus import TextCorpus


logger = logging.getLogger(__name__)


class TestTextCorpus(unittest.TestCase):
# TODO add tests for other methods

def test_sample_text(self):
class TestTextCorpus(TextCorpus):
def __init__(self):
self.data = [["document1"], ["document2"]]

def get_texts(self):
for document in self.data:
yield document

corpus = TestTextCorpus()

sample1 = list(corpus.sample_texts(1))
self.assertEqual(len(sample1), 1)
document1 = sample1[0] == ["document1"]
document2 = sample1[0] == ["document2"]
self.assertTrue(document1 or document2)

sample2 = list(corpus.sample_texts(2))
self.assertEqual(len(sample2), 2)
self.assertEqual(sample2[0], ["document1"])
self.assertEqual(sample2[1], ["document2"])

with self.assertRaises(ValueError):
list(corpus.sample_texts(3))

with self.assertRaises(ValueError):
list(corpus.sample_texts(-1))


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
unittest.main()
4 changes: 3 additions & 1 deletion gensim/test/test_wikicorpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
from gensim.corpora.wikicorpus import WikiCorpus


module_path = os.path.dirname(__file__) # needed because sample data files are located in the same folder
module_path = os.path.dirname(__file__) # needed because sample data files are located in the same folder
datapath = lambda fname: os.path.join(module_path, 'test_data', fname)
FILENAME = 'enwiki-latest-pages-articles1.xml-p000000010p000030302-shortened.bz2'
FILENAME_U = 'bgwiki-latest-pages-articles-shortened.xml.bz2'

logger = logging.getLogger(__name__)


class TestWikiCorpus(unittest.TestCase):

# #TODO: sporadic failure to be investigated
Expand Down Expand Up @@ -62,6 +63,7 @@ def test_unicode_element(self):
l = wc.get_texts()
self.assertTrue(u'папа' in next(l))


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
unittest.main()