diff --git a/chatterbot/adapters/storage/__init__.py b/chatterbot/adapters/storage/__init__.py index d417ef541..2fe7e08b9 100644 --- a/chatterbot/adapters/storage/__init__.py +++ b/chatterbot/adapters/storage/__init__.py @@ -2,4 +2,3 @@ from .django_storage import DjangoStorageAdapter from .jsonfile import JsonFileStorageAdapter from .mongodb import MongoDatabaseAdapter -from .twitter_storage import TwitterAdapter diff --git a/chatterbot/adapters/storage/twitter_storage.py b/chatterbot/adapters/storage/twitter_storage.py deleted file mode 100644 index 1b4b4e237..000000000 --- a/chatterbot/adapters/storage/twitter_storage.py +++ /dev/null @@ -1,118 +0,0 @@ -from chatterbot.adapters.storage import StorageAdapter -from chatterbot.conversation import Statement, Response -import random -import twitter - - -class TwitterAdapter(StorageAdapter): - """ - The TwitterAdapter allows ChatterBot to read tweets from twitter. - """ - - def __init__(self, **kwargs): - super(TwitterAdapter, self).__init__(**kwargs) - - self.api = twitter.Api( - consumer_key=kwargs.get('twitter_consumer_key'), - consumer_secret=kwargs.get('twitter_consumer_secret'), - access_token_key=kwargs.get('twitter_access_token_key'), - access_token_secret=kwargs.get('twitter_access_token_secret') - ) - - self.adapter_supports_queries = False - - def count(self): - return 1 - - def find(self, statement_text): - tweets = self.api.GetSearch(term=statement_text, count=1) - - if tweets: - return Statement(tweets[0].text, in_response_to=[ - Response(statement_text) - ]) - - return None - - def filter(self, **kwargs): - """ - Returns a list of statements in the database - that match the parameters specified. - """ - statement_text = kwargs.get('text') - - # if not statement_text: - # statement_text = kwargs.get('in_response_to__contains') - # data['in_reply_to_status_id_str'] - - # If no text parameter was given get a selection of recent tweets - if not statement_text: - statements = self.get_random(number=20) - return statements - - tweets = self.api.GetSearch(term=statement_text) - tweet = random.choice(tweets) - - statement = Statement(tweet.text, in_response_to=[ - Response(statement_text) - ]) - - return [statement] - - def update(self, statement): - return statement - - def choose_word(self, words): - """ - Light weight search for a valid word if one exists. - """ - for word in words: - # If the word contains only letters with a length from 4 to 9 - if word.isalpha() and len(word) > 3 and len(word) <= 9: - return word - - return None - - def get_random(self, number=1): - """ - Returns a random statement from the api. - To generate a random tweet, search twitter for recent tweets - containing the term 'random'. Then randomly select one tweet - from the current set of tweets. Randomly choose one word from - the selected random tweet, and make a second search request. - Return one random tweet selected from the search results. - """ - statements = [] - tweets = self.api.GetSearch(term="random", count=5) - - tweet = random.choice(tweets) - base_response = Response(text=tweet.text) - - words = tweet.text.split() - word = self.choose_word(words) - - # If a valid word is found, make a second search request - # TODO: What if a word is not found? - if word: - tweets = self.api.GetSearch(term=word, count=number) - if tweets: - for tweet in tweets: - # TODO: Handle non-ascii characters properly - cleaned_text = ''.join( - [i if ord(i) < 128 else ' ' for i in tweet.text] - ) - statements.append( - Statement(cleaned_text, in_response_to=[base_response]) - ) - - if number == 1: - return random.choice(statements) - - return statements - - def drop(self): - """ - Twitter is only a simulated data source in - this case so it cannot be removed. - """ - pass diff --git a/chatterbot/chatterbot.py b/chatterbot/chatterbot.py index 02f29d83f..5a0e5a167 100644 --- a/chatterbot/chatterbot.py +++ b/chatterbot/chatterbot.py @@ -70,7 +70,7 @@ def __init__(self, name, **kwargs): # Use specified trainer or fall back to the default trainer = kwargs.get('trainer', 'chatterbot.trainers.Trainer') TrainerClass = import_module(trainer) - self.trainer = TrainerClass(self.storage) + self.trainer = TrainerClass(self.storage, **kwargs) self.logger = kwargs.get('logger', logging.getLogger(__name__)) diff --git a/chatterbot/trainers.py b/chatterbot/trainers.py index efbfb09dd..33ee38a3a 100644 --- a/chatterbot/trainers.py +++ b/chatterbot/trainers.py @@ -1,5 +1,6 @@ from .conversation import Statement, Response from .corpus import Corpus +import logging class Trainer(object): @@ -7,6 +8,7 @@ class Trainer(object): def __init__(self, storage, **kwargs): self.storage = storage self.corpus = Corpus() + self.logger = logging.getLogger(__name__) def train(self, *args, **kwargs): raise self.TrainerInitializationException() @@ -82,3 +84,89 @@ def train(self, *corpora): for data in corpus_data: for pair in data: trainer.train(pair) + + +class TwitterTrainer(Trainer): + + def __init__(self, storage, **kwargs): + super(TwitterTrainer, self).__init__(storage, **kwargs) + from twitter import Api as TwitterApi + + self.api = TwitterApi( + consumer_key=kwargs.get('twitter_consumer_key'), + consumer_secret=kwargs.get('twitter_consumer_secret'), + access_token_key=kwargs.get('twitter_access_token_key'), + access_token_secret=kwargs.get('twitter_access_token_secret') + ) + + def random_word(self, base_word='random'): + """ + Generate a random word using the Twitter API. + + Search twitter for recent tweets containing the term 'random'. + Then randomly select one word from those tweets and do another + search with that word. Return a randomly selected word from the + new set of results. + """ + import random + random_tweets = self.api.GetSearch(term=base_word, count=5) + random_words = self.get_words_from_tweets(random_tweets) + random_word = random.choice(list(random_words)) + tweets = self.api.GetSearch(term=random_word, count=5) + words = self.get_words_from_tweets(tweets) + word = random.choice(list(words)) + return word + + def get_words_from_tweets(self, tweets): + """ + Given a list of tweets, return the set of + words from the tweets. + """ + words = set() + + for tweet in tweets: + # TODO: Handle non-ascii characters properly + cleaned_text = ''.join( + [i if ord(i) < 128 else ' ' for i in tweet.text] + ) + tweet_words = cleaned_text.split() + + for word in tweet_words: + # If the word contains only letters with a length from 4 to 9 + if word.isalpha() and len(word) > 3 and len(word) <= 9: + words.add(word) + + return words + + def get_statements(self): + """ + Returns list of random statements from the API. + """ + from twitter import TwitterError + statements = [] + + # Generate a random word + random_word = self.random_word() + + self.logger.info(u'Requesting 50 random tweets containing the word {}'.format(random_word)) + tweets = self.api.GetSearch(term=random_word, count=50) + for tweet in tweets: + statement = Statement(tweet.text) + + if tweet.in_reply_to_status_id: + try: + status = self.api.GetStatus(tweet.in_reply_to_status_id) + statement.add_response(Response(status.text)) + statements.append(statement) + except TwitterError as e: + self.logger.warning(str(e)) + + self.logger.info('Adding {} tweets with responses'.format(len(statements))) + + return statements + + def train(self): + for i in range(0, 10): + statements = self.get_statements() + for statement in statements: + self.storage.update(statement, force=True) diff --git a/docs/adapters/storage.rst b/docs/adapters/storage.rst index 3b127f90f..d6d025b22 100644 --- a/docs/adapters/storage.rst +++ b/docs/adapters/storage.rst @@ -66,30 +66,3 @@ can set the `database_uri` parameter to the uri of your database. .. code-block:: python database_uri='mongodb://example.com:8100/' - -Twitter Adapter -============================== - -.. autofunction:: chatterbot.adapters.storage.TwitterAdapter - -"chatterbot.adapters.storage.TwitterAdapter" - -Create an app from you twiter acccount, Once created -It will have following app credentails that are required to work with -TwitterAdapter. - -twitter_consumer_key --------------------- -Consumer key of twitter app. - -twitter_consumer_secret ------------------------ -Consumer secret of twitter app. - -twitter_access_token_key ------------------------- -Access token key of twitter app. - -twitter_access_token_secret ---------------------------- -Access token secret of twitter app. diff --git a/docs/training.rst b/docs/training.rst index 78896040c..8acf97ee1 100644 --- a/docs/training.rst +++ b/docs/training.rst @@ -15,6 +15,7 @@ The case that someone wants to create a custom training module typically comes u .. _set_trainer: + Setting the training class ========================== @@ -22,9 +23,12 @@ ChatterBot comes with training classes built in, or you can create your own if needed. To use a training class you must import it and pass it to the `set_trainer()` method before calling `train()`. + Training via list data ====================== +.. autofunction:: chatterbot.trainers.ListTrainer + For the training, process, you will need to pass in a list of statements where the order of each statement is based on it's placement in a given conversation. For example, if you were to run bot of the following training calls, then the resulting chatterbot would respond to both statements of "Hi there!" and "Greetings!" by saying "Hello". @@ -59,9 +63,12 @@ This will establish each item in the list as a possible response to it's predece "You are welcome.", ]) + Training with corpus data ========================= +.. autofunction:: chatterbot.trainers.ChatterBotCorpusTrainer + ChatterBot comes with a corpus data and utility module that makes it easy to quickly train your bot to communicate. To do so, simply specify the corpus data modules you want to use. @@ -91,6 +98,35 @@ conversations corpora then you would simply specify them. "chatterbot.corpus.english.conversations" ) + +Training with the Twitter API +============================= + +.. autofunction:: chatterbot.trainers.TwitterTrainer + +Create an new app using you twiter acccount. Once created, +it will provide you with the following credentails that are +required to work with the Twitter API. + ++-------------------------------------+-------------------------------------+ +| Parameter | Description | ++=====================================+=====================================+ +| :code:`twitter_consumer_key` | Consumer key of twitter app. | ++-------------------------------------+-------------------------------------+ +| :code:`twitter_consumer_secret` | Consumer secret of twitter app. | ++-------------------------------------+-------------------------------------+ +| :code:`twitter_access_token_key` | Access token key of twitter app. | ++-------------------------------------+-------------------------------------+ +| :code:`twitter_access_token_secret` | Access token secret of twitter app. | ++-------------------------------------+-------------------------------------+ + +Twitter training example +------------------------ + +.. literalinclude:: ../examples/twitter_training_example.py + :language: python + + Creating a new training class ============================= @@ -105,6 +141,7 @@ parameters you choose. Take a look at the existing `trainer classes on GitHub`_ for examples. + The ChatterBot Corpus ===================== @@ -122,6 +159,7 @@ To explore what languages and sets of corpora are available, check out the `chat If you are interested in contributing a new language corpus, or adding content to an existing language in the corpus, please feel free to submit a pull request on ChatterBot's GitHub page. Contributions are welcomed! + Exporting your chat bot's database as a training corpus ======================================================= diff --git a/examples/twitter_example.py b/examples/twitter_training_example.py similarity index 61% rename from examples/twitter_example.py rename to examples/twitter_training_example.py index cf8f0c0aa..3613d32dd 100644 --- a/examples/twitter_example.py +++ b/examples/twitter_training_example.py @@ -1,8 +1,12 @@ from chatterbot import ChatBot from settings import TWITTER +import logging ''' +This example demonstrates how you can train your chat bot +using data from Twitter. + To use this example, create a new file called settings.py. In settings.py define the following: @@ -14,25 +18,23 @@ } ''' -chatbot = ChatBot("ChatterBot", - storage_adapter="chatterbot.adapters.storage.TwitterAdapter", +# Comment out the following line to disable verbose logging +logging.basicConfig(level=logging.INFO) + +chatbot = ChatBot("TwitterBot", logic_adapters=[ "chatterbot.adapters.logic.ClosestMatchAdapter" ], input_adapter="chatterbot.adapters.input.TerminalAdapter", output_adapter="chatterbot.adapters.output.TerminalAdapter", - database="../database.db", + database="./twitter-database.db", twitter_consumer_key=TWITTER["CONSUMER_KEY"], twitter_consumer_secret=TWITTER["CONSUMER_SECRET"], twitter_access_token_key=TWITTER["ACCESS_TOKEN"], - twitter_access_token_secret=TWITTER["ACCESS_TOKEN_SECRET"] + twitter_access_token_secret=TWITTER["ACCESS_TOKEN_SECRET"], + trainer="chatterbot.trainers.TwitterTrainer" ) -print("Type something to begin...") - -while True: - try: - bot_input = chatbot.get_response(None) +chatbot.train() - except (KeyboardInterrupt, EOFError, SystemExit): - break +chatbot.logger.info('Trained database generated successfully!') \ No newline at end of file diff --git a/tests/storage_adapter_tests/test_twitter_adapter.py b/tests/storage_adapter_tests/test_twitter_adapter.py deleted file mode 100644 index 3d6278901..000000000 --- a/tests/storage_adapter_tests/test_twitter_adapter.py +++ /dev/null @@ -1,100 +0,0 @@ -from unittest import TestCase -from unittest import SkipTest -from mock import Mock, MagicMock -from chatterbot.adapters.storage import TwitterAdapter -import os -import json - -def side_effect(*args, **kwargs): - from twitter import Status - - # A special case for testing a response with no results - if 'term' in kwargs and kwargs.get('term') == 'Non-existant': - return [] - - current_directory = os.path.dirname(os.path.realpath(__file__)) - data_file = os.path.join( - current_directory, - 'test_data', - 'get_search.json' - ) - tweet_data = open(data_file) - data = json.loads(tweet_data.read()) - tweet_data.close() - - return [Status.NewFromJsonDict(x) for x in data.get('statuses', '')] - - -class TwitterAdapterTestCase(TestCase): - - def setUp(self): - """ - Instantiate the adapter. - """ - self.adapter = TwitterAdapter( - twitter_consumer_key='twitter_consumer_key', - twitter_consumer_secret='twitter_consumer_secret', - twitter_access_token_key='twitter_access_token_key', - twitter_access_token_secret='twitter_access_token_secret' - ) - self.adapter.api = Mock() - - self.adapter.api.GetSearch = MagicMock(side_effect=side_effect) - - def test_count(self): - """ - The count should always be 1. - """ - self.assertEqual(self.adapter.count(), 1) - - def test_count(self): - """ - The update method should return the input statement. - """ - from chatterbot.conversation import Statement - statement = Statement('Hello') - result = self.adapter.update(statement) - self.assertEqual(statement, result) - - def test_choose_word(self): - words = ['G', 'is', 'my', 'favorite', 'letter'] - word = self.adapter.choose_word(words) - self.assertEqual(word, words[3]) - - def test_choose_no_word(self): - words = ['q'] - word = self.adapter.choose_word(words) - self.assertEqual(word, None) - - def test_drop(self): - """ - This drop method should do nothing. - """ - self.adapter.drop() - - def test_filter(self): - statements = self.adapter.filter() - self.assertEqual(len(statements), 1) - - def test_statement_not_found(self): - """ - Test the case that a match is not found. - """ - statement = self.adapter.find('Non-existant') - self.assertEqual(statement, None) - - def test_statement_found(self): - found_statement = self.adapter.find('New statement') - self.assertNotEqual(found_statement, None) - self.assertTrue(len(found_statement.text)) - - def test_filter(self): - statements = self.adapter.filter( - text__contains='a few of my favorite things' - ) - self.assertGreater(len(statements), 0) - - def test_get_random(self): - statement = self.adapter.get_random() - self.assertNotEqual(statement, None) - self.assertTrue(len(statement.text)) diff --git a/tests/storage_adapter_tests/test_data/get_search.json b/tests/training_tests/test_data/get_search.json similarity index 50% rename from tests/storage_adapter_tests/test_data/get_search.json rename to tests/training_tests/test_data/get_search.json index cd5d68c6b..47039f9b2 100644 --- a/tests/storage_adapter_tests/test_data/get_search.json +++ b/tests/training_tests/test_data/get_search.json @@ -1 +1 @@ -{"statuses":[{"metadata":{"result_type":"popular","iso_language_code":"en"},"created_at":"Tue Dec 08 21:40:00 +0000 2015","id":674342688083283970,"id_str":"674342688083283970","text":"\ud83c\udfb6 C++, Java, Python & Ruby. These are a few of my favorite things \ud83c\udfb6 #HourOfCode \ud83d\udd51\ud83d\udcbb\ud83d\udc7e\ud83c\udfae https:\/\/t.co\/GSCmPh9V6j","source":"\u003ca href=\"https:\/\/vine.co\" rel=\"nofollow\"\u003eVine for Android\u003c\/a\u003e","truncated":false,"in_reply_to_status_id":null,"in_reply_to_status_id_str":null,"in_reply_to_user_id":null,"in_reply_to_user_id_str":null,"in_reply_to_screen_name":null,"user":{"id":58309829,"id_str":"58309829","name":"Nickelodeon","screen_name":"NickelodeonTV","location":"USA","description":"The Official Twitter for Nickelodeon, USA!","url":"https:\/\/t.co\/Lz9i6LdC4f","entities":{"url":{"urls":[{"url":"https:\/\/t.co\/Lz9i6LdC4f","expanded_url":"http:\/\/www.nick.com","display_url":"nick.com","indices":[0,23]}]},"description":{"urls":[]}},"protected":false,"followers_count":3914587,"friends_count":2263,"listed_count":3321,"created_at":"Sun Jul 19 22:19:02 +0000 2009","favourites_count":2757,"utc_offset":-18000,"time_zone":"Eastern Time (US & Canada)","geo_enabled":true,"verified":true,"statuses_count":33910,"lang":"en","contributors_enabled":false,"is_translator":false,"is_translation_enabled":true,"profile_background_color":"FA743E","profile_background_image_url":"http:\/\/pbs.twimg.com\/profile_background_images\/450718163508789248\/E26KBqrx.jpeg","profile_background_image_url_https":"https:\/\/pbs.twimg.com\/profile_background_images\/450718163508789248\/E26KBqrx.jpeg","profile_background_tile":false,"profile_image_url":"http:\/\/pbs.twimg.com\/profile_images\/671387650792665088\/sJxvItMD_normal.jpg","profile_image_url_https":"https:\/\/pbs.twimg.com\/profile_images\/671387650792665088\/sJxvItMD_normal.jpg","profile_banner_url":"https:\/\/pbs.twimg.com\/profile_banners\/58309829\/1448906254","profile_link_color":"D1771E","profile_sidebar_border_color":"FFFFFF","profile_sidebar_fill_color":"F0F0F0","profile_text_color":"333333","profile_use_background_image":false,"has_extended_profile":false,"default_profile":false,"default_profile_image":false,"following":false,"follow_request_sent":false,"notifications":false},"geo":null,"coordinates":null,"place":null,"contributors":null,"is_quote_status":false,"retweet_count":28,"favorite_count":126,"entities":{"hashtags":[{"text":"HourOfCode","indices":[72,83]}],"symbols":[],"user_mentions":[],"urls":[{"url":"https:\/\/t.co\/GSCmPh9V6j","expanded_url":"https:\/\/vine.co\/v\/i7QJji9Ldmr","display_url":"vine.co\/v\/i7QJji9Ldmr","indices":[89,112]}]},"favorited":false,"retweeted":false,"possibly_sensitive":false,"lang":"en"}]} +{"statuses":[{"metadata":{"result_type":"popular","iso_language_code":"en"},"created_at":"Tue Dec 08 21:40:00 +0000 2015","id":674342688083283970,"id_str":"674342688083283970","text":"\ud83c\udfb6 C++, Java, Python & Ruby. These are a few of my favorite things \ud83c\udfb6 #HourOfCode \ud83d\udd51\ud83d\udcbb\ud83d\udc7e\ud83c\udfae https:\/\/t.co\/GSCmPh9V6j","source":"\u003ca href=\"https:\/\/vine.co\" rel=\"nofollow\"\u003eVine for Android\u003c\/a\u003e","truncated":false,"in_reply_to_status_id":null,"in_reply_to_status_id_str":null,"in_reply_to_user_id":null,"in_reply_to_user_id_str":null,"in_reply_to_screen_name":null,"user":{"id":58309829,"id_str":"58309829","name":"Nickelodeon","screen_name":"NickelodeonTV","location":"USA","description":"The Official Twitter for Nickelodeon, USA!","url":"https:\/\/t.co\/Lz9i6LdC4f","entities":{"url":{"urls":[{"url":"https:\/\/t.co\/Lz9i6LdC4f","expanded_url":"http:\/\/www.nick.com","display_url":"nick.com","indices":[0,23]}]},"description":{"urls":[]}},"protected":false,"followers_count":3914587,"friends_count":2263,"listed_count":3321,"created_at":"Sun Jul 19 22:19:02 +0000 2009","favourites_count":2757,"utc_offset":-18000,"time_zone":"Eastern Time (US & Canada)","geo_enabled":true,"verified":true,"statuses_count":33910,"lang":"en","contributors_enabled":false,"is_translator":false,"is_translation_enabled":true,"profile_background_color":"FA743E","profile_background_image_url":"http:\/\/pbs.twimg.com\/profile_background_images\/450718163508789248\/E26KBqrx.jpeg","profile_background_image_url_https":"https:\/\/pbs.twimg.com\/profile_background_images\/450718163508789248\/E26KBqrx.jpeg","profile_background_tile":false,"profile_image_url":"http:\/\/pbs.twimg.com\/profile_images\/671387650792665088\/sJxvItMD_normal.jpg","profile_image_url_https":"https:\/\/pbs.twimg.com\/profile_images\/671387650792665088\/sJxvItMD_normal.jpg","profile_banner_url":"https:\/\/pbs.twimg.com\/profile_banners\/58309829\/1448906254","profile_link_color":"D1771E","profile_sidebar_border_color":"FFFFFF","profile_sidebar_fill_color":"F0F0F0","profile_text_color":"333333","profile_use_background_image":false,"has_extended_profile":false,"default_profile":false,"default_profile_image":false,"following":false,"follow_request_sent":false,"notifications":false},"geo":null,"coordinates":null,"place":null,"contributors":null,"is_quote_status":false,"retweet_count":28,"favorite_count":126,"entities":{"hashtags":[{"text":"HourOfCode","indices":[72,83]}],"symbols":[],"user_mentions":[],"urls":[{"url":"https:\/\/t.co\/GSCmPh9V6j","expanded_url":"https:\/\/vine.co\/v\/i7QJji9Ldmr","display_url":"vine.co\/v\/i7QJji9Ldmr","indices":[89,112]}]},"favorited":false,"retweeted":false,"possibly_sensitive":false,"lang":"en"},{"metadata":{"result_type":"popular","iso_language_code":"en"},"created_at":"Tue Dec 08 21:45:00 +0000 2015","id":674342688083283970,"id_str":"674342688083283970","text":"Are you sure about Ruby?","source":"\u003ca href=\"https:\/\/vine.co\" rel=\"nofollow\"\u003eVine for Android\u003c\/a\u003e","truncated":false,"in_reply_to_status_id":674342688083283970,"in_reply_to_status_id_str":"674342688083283970","in_reply_to_user_id":null,"in_reply_to_user_id_str":null,"in_reply_to_screen_name":null,"user":{"id":58309829,"id_str":"58309829","name":"Nickelodeon","screen_name":"NickelodeonTV","location":"USA","description":"The Official Twitter for Nickelodeon, USA!","url":"https:\/\/t.co\/Lz9i6LdC4f","entities":{"url":{"urls":[{"url":"https:\/\/t.co\/Lz9i6LdC4f","expanded_url":"http:\/\/www.nick.com","display_url":"nick.com","indices":[0,23]}]},"description":{"urls":[]}},"protected":false,"followers_count":3914587,"friends_count":2263,"listed_count":3321,"created_at":"Sun Jul 19 22:19:02 +0000 2009","favourites_count":2757,"utc_offset":-18000,"time_zone":"Eastern Time (US & Canada)","geo_enabled":true,"verified":true,"statuses_count":33910,"lang":"en","contributors_enabled":false,"is_translator":false,"is_translation_enabled":true,"profile_background_color":"FA743E","profile_background_image_url":"http:\/\/pbs.twimg.com\/profile_background_images\/450718163508789248\/E26KBqrx.jpeg","profile_background_image_url_https":"https:\/\/pbs.twimg.com\/profile_background_images\/450718163508789248\/E26KBqrx.jpeg","profile_background_tile":false,"profile_image_url":"http:\/\/pbs.twimg.com\/profile_images\/671387650792665088\/sJxvItMD_normal.jpg","profile_image_url_https":"https:\/\/pbs.twimg.com\/profile_images\/671387650792665088\/sJxvItMD_normal.jpg","profile_banner_url":"https:\/\/pbs.twimg.com\/profile_banners\/58309829\/1448906254","profile_link_color":"D1771E","profile_sidebar_border_color":"FFFFFF","profile_sidebar_fill_color":"F0F0F0","profile_text_color":"333333","profile_use_background_image":false,"has_extended_profile":false,"default_profile":false,"default_profile_image":false,"following":false,"follow_request_sent":false,"notifications":false},"geo":null,"coordinates":null,"place":null,"contributors":null,"is_quote_status":false,"retweet_count":28,"favorite_count":126,"entities":{"hashtags":[{"text":"HourOfCode","indices":[72,83]}],"symbols":[],"user_mentions":[],"urls":[{"url":"https:\/\/t.co\/GSCmPh9V6j","expanded_url":"https:\/\/vine.co\/v\/i7QJji9Ldmr","display_url":"vine.co\/v\/i7QJji9Ldmr","indices":[89,112]}]},"favorited":false,"retweeted":false,"possibly_sensitive":false,"lang":"en"}]} \ No newline at end of file diff --git a/tests/training_tests/test_twitter_trainer.py b/tests/training_tests/test_twitter_trainer.py new file mode 100644 index 000000000..83b5dd0e6 --- /dev/null +++ b/tests/training_tests/test_twitter_trainer.py @@ -0,0 +1,83 @@ +from tests.base_case import ChatBotTestCase +from unittest import SkipTest +from mock import Mock, MagicMock +from chatterbot.trainers import TwitterTrainer +import os +import json + + +def get_search_side_effect(*args, **kwargs): + from twitter import Status + + current_directory = os.path.dirname(os.path.realpath(__file__)) + data_file = os.path.join( + current_directory, + 'test_data', + 'get_search.json' + ) + tweet_data = open(data_file) + data = json.loads(tweet_data.read()) + tweet_data.close() + + return [Status.NewFromJsonDict(x) for x in data.get('statuses')] + + +def get_status_side_effect(*args, **kwargs): + from twitter import Status + + current_directory = os.path.dirname(os.path.realpath(__file__)) + data_file = os.path.join( + current_directory, + 'test_data', + 'get_search.json' + ) + tweet_data = open(data_file) + data = json.loads(tweet_data.read()) + tweet_data.close() + + return Status.NewFromJsonDict(data.get('statuses')[1]) + + +class TwitterTrainerTestCase(ChatBotTestCase): + + def setUp(self): + """ + Instantiate the trainer class for testing. + """ + super(TwitterTrainerTestCase, self).setUp() + + self.trainer = TwitterTrainer( + self.chatbot.storage, + twitter_consumer_key='twitter_consumer_key', + twitter_consumer_secret='twitter_consumer_secret', + twitter_access_token_key='twitter_access_token_key', + twitter_access_token_secret='twitter_access_token_secret' + ) + self.trainer.api = Mock() + + self.trainer.api.GetSearch = MagicMock(side_effect=get_search_side_effect) + self.trainer.api.GetStatus = MagicMock(side_effect=get_status_side_effect) + + def test_random_word(self): + word = self.trainer.random_word() + + self.assertTrue(len(word) > 3) + + def test_get_words_from_tweets(self): + tweets = get_search_side_effect() + words = self.trainer.get_words_from_tweets(tweets) + + self.assertIn('about', words) + self.assertIn('favorite', words) + self.assertIn('things', words) + + def test_get_statements(self): + statements = self.trainer.get_statements() + + self.assertEqual(len(statements), 1) + + def test_train(self): + self.trainer.train() + + statement_created = self.trainer.storage.filter() + self.assertTrue(len(statement_created))