From 5f4bf34089d4f2ce57b3f007abd2d249fa21b66b Mon Sep 17 00:00:00 2001 From: Gunther Cox Date: Sun, 31 Mar 2019 15:56:53 -0400 Subject: [PATCH 1/6] Switch from NLTK to Spacy for indexing --- chatterbot/__main__.py | 18 --- chatterbot/storage/storage_adapter.py | 4 +- chatterbot/tagging.py | 145 ----------------- chatterbot/tokenizers.py | 62 -------- chatterbot/trainers.py | 4 +- chatterbot/utils.py | 73 --------- docs/commands.rst | 12 -- docs/utils.rst | 5 - tests/base_case.py | 7 - tests/test_chatbot.py | 18 +-- tests/test_cli.py | 4 - tests/test_search.py | 2 +- tests/test_tagging.py | 150 ------------------ tests/test_tokenizers.py | 37 ----- tests/test_utils.py | 10 -- tests/training/test_list_training.py | 4 +- tests/training/test_ubuntu_corpus_training.py | 4 +- 17 files changed, 13 insertions(+), 546 deletions(-) delete mode 100644 chatterbot/tokenizers.py delete mode 100644 tests/test_tokenizers.py diff --git a/chatterbot/__main__.py b/chatterbot/__main__.py index 1fe45f7ef..d6c03a735 100644 --- a/chatterbot/__main__.py +++ b/chatterbot/__main__.py @@ -1,6 +1,5 @@ import importlib import sys -import os def get_chatterbot_version(): @@ -8,23 +7,6 @@ def get_chatterbot_version(): return chatterbot.__version__ -def get_nltk_data_directories(): - import nltk.data - - data_directories = [] - - # Find each data directory in the NLTK path that has content - for path in nltk.data.path: - if os.path.exists(path): - if os.listdir(path): - data_directories.append(path) - - return os.linesep.join(data_directories) - - if __name__ == '__main__': if '--version' in sys.argv: print(get_chatterbot_version()) - - if 'list_nltk_data' in sys.argv: - print(get_nltk_data_directories()) diff --git a/chatterbot/storage/storage_adapter.py b/chatterbot/storage/storage_adapter.py index cd8854b11..821760a76 100644 --- a/chatterbot/storage/storage_adapter.py +++ b/chatterbot/storage/storage_adapter.py @@ -1,6 +1,6 @@ import logging from chatterbot import languages -from chatterbot.tagging import PosHypernymTagger +from chatterbot.tagging import PosLemmaTagger class StorageAdapter(object): @@ -15,7 +15,7 @@ def __init__(self, *args, **kwargs): """ self.logger = kwargs.get('logger', logging.getLogger(__name__)) - self.tagger = PosHypernymTagger(language=kwargs.get( + self.tagger = PosLemmaTagger(language=kwargs.get( 'tagger_language', languages.ENG )) diff --git a/chatterbot/tagging.py b/chatterbot/tagging.py index cfd755ca2..3d73986af 100644 --- a/chatterbot/tagging.py +++ b/chatterbot/tagging.py @@ -1,10 +1,5 @@ import string from chatterbot import languages -from chatterbot import utils -from chatterbot.tokenizers import get_sentence_tokenizer -from nltk import pos_tag -from nltk.corpus import wordnet, stopwords -from nltk.corpus.reader.wordnet import WordNetError import spacy @@ -56,143 +51,3 @@ def get_bigram_pair_string(self, text): ] return ' '.join(bigram_pairs) - - -class PosHypernymTagger(object): - """ - For each non-stopword in a string, return a string where each word is a - hypernym preceded by the part of speech of the word before it. - """ - - def __init__(self, language=None): - self.language = language or languages.ENG - - self.sentence_tokenizer = None - - self.stopwords = None - - self.initialization_functions = [ - utils.download_nltk_stopwords, - utils.download_nltk_wordnet, - utils.download_nltk_averaged_perceptron_tagger - ] - - def get_stopwords(self): - """ - Get the list of stopwords from the NLTK corpus. - """ - if self.stopwords is None: - self.stopwords = stopwords.words(self.language.ENGLISH_NAME.lower()) - - return self.stopwords - - def tokenize_sentence(self, sentence): - """ - Tokenize the provided sentence. - """ - if self.sentence_tokenizer is None: - self.sentence_tokenizer = get_sentence_tokenizer(self.language) - - return self.sentence_tokenizer.tokenize(sentence) - - def stem_words(self, words): - """ - Return the first character of the word in place of a part-of-speech tag. - """ - return [ - (word, word.lower()[0], ) for word in words - ] - - def get_pos_tags(self, words): - try: - # pos_tag supports eng and rus - tags = pos_tag(words, lang=self.language.ISO_639) - except NotImplementedError: - tags = self.stem_words(words) - except LookupError: - tags = self.stem_words(words) - - return tags - - def get_hypernyms(self, pos_tags): - """ - Return the hypernyms for each word in a list of POS tagged words. - """ - results = [] - - for word, pos in pos_tags: - try: - synsets = wordnet.synsets(word, utils.treebank_to_wordnet(pos), lang=self.language.ISO_639) - except WordNetError: - synsets = None - except LookupError: - # Don't return any synsets if the language is not supported - synsets = None - - if synsets: - synset = synsets[0] - hypernyms = synset.hypernyms() - - if hypernyms: - results.append(hypernyms[0].name().split('.')[0]) - else: - results.append(word) - else: - results.append(word) - - return results - - def get_bigram_pair_string(self, text): - """ - For example: - What a beautiful swamp - - becomes: - - DT:beautiful JJ:wetland - """ - WORD_INDEX = 0 - POS_INDEX = 1 - - pos_tags = [] - - for sentence in self.tokenize_sentence(text.strip()): - - # Remove punctuation - if sentence and sentence[-1] in string.punctuation: - sentence_with_punctuation_removed = sentence[:-1] - - if sentence_with_punctuation_removed: - sentence = sentence_with_punctuation_removed - - words = sentence.split() - - pos_tags.extend(self.get_pos_tags(words)) - - hypernyms = self.get_hypernyms(pos_tags) - - high_quality_bigrams = [] - all_bigrams = [] - - word_count = len(pos_tags) - - if word_count == 1: - all_bigrams.append( - pos_tags[0][WORD_INDEX].lower() - ) - - for index in range(1, word_count): - word = pos_tags[index][WORD_INDEX].lower() - previous_word_pos = pos_tags[index - 1][POS_INDEX] - if word not in self.get_stopwords() and len(word) > 1: - bigram = previous_word_pos + ':' + hypernyms[index].lower() - high_quality_bigrams.append(bigram) - all_bigrams.append(bigram) - else: - bigram = previous_word_pos + ':' + word - all_bigrams.append(bigram) - - if high_quality_bigrams: - all_bigrams = high_quality_bigrams - - return ' '.join(all_bigrams) diff --git a/chatterbot/tokenizers.py b/chatterbot/tokenizers.py deleted file mode 100644 index ab1f27d0e..000000000 --- a/chatterbot/tokenizers.py +++ /dev/null @@ -1,62 +0,0 @@ -from pickle import dump, load -from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktTrainer -from nltk.tokenize import _treebank_word_tokenizer -from chatterbot.corpus import load_corpus, list_corpus_files -from chatterbot import languages - - -def get_sentence_tokenizer(language): - """ - Return the sentence tokenizer callable. - """ - - pickle_path = 'sentence_tokenizer.pickle' - - try: - input_file = open(pickle_path, 'rb') - sentence_tokenizer = load(input_file) - input_file.close() - except FileNotFoundError: - - data_file_paths = [] - - sentences = [] - - try: - # Get the paths to each file the bot will be trained with - corpus_files = list_corpus_files('chatterbot.corpus.{language}'.format( - language=language.ENGLISH_NAME.lower() - )) - except LookupError: - # Fall back to English sentence splitting rules if a language is not supported - corpus_files = list_corpus_files('chatterbot.corpus.{language}'.format( - language=languages.ENG.ENGLISH_NAME.lower() - )) - - data_file_paths.extend(corpus_files) - - for corpus, _categories, _file_path in load_corpus(*data_file_paths): - for conversation in corpus: - for text in conversation: - sentences.append(text.upper()) - sentences.append(text.lower()) - - trainer = PunktTrainer() - trainer.INCLUDE_ALL_COLLOCS = True - trainer.train('\n'.join(sentences)) - - sentence_tokenizer = PunktSentenceTokenizer(trainer.get_params()) - - # Pickle the sentence tokenizer for future use - output_file = open(pickle_path, 'wb') - dump(sentence_tokenizer, output_file, -1) - output_file.close() - - return sentence_tokenizer - - -def get_word_tokenizer(language): - """ - Return the word tokenizer callable. - """ - return _treebank_word_tokenizer diff --git a/chatterbot/trainers.py b/chatterbot/trainers.py index d5cc08a38..3d1b1841f 100644 --- a/chatterbot/trainers.py +++ b/chatterbot/trainers.py @@ -5,7 +5,7 @@ from multiprocessing import Pool, Manager from dateutil import parser as date_parser from chatterbot.conversation import Statement -from chatterbot.tagging import PosHypernymTagger +from chatterbot.tagging import PosLemmaTagger from chatterbot import utils @@ -325,7 +325,7 @@ def track_progress(members): def train(self): import glob - tagger = PosHypernymTagger(language=self.chatbot.storage.tagger.language) + tagger = PosLemmaTagger(language=self.chatbot.storage.tagger.language) # Download and extract the Ubuntu dialog corpus if needed corpus_download_path = self.download(self.data_download_url) diff --git a/chatterbot/utils.py b/chatterbot/utils.py index 95dada105..621b7faae 100644 --- a/chatterbot/utils.py +++ b/chatterbot/utils.py @@ -1,7 +1,6 @@ """ ChatterBot utility functions """ -from nltk.corpus import wordnet def import_module(dotted_path): @@ -84,56 +83,6 @@ def validate_adapter_class(validate_class, adapter_class): ) -def nltk_download_corpus(resource_path): - """ - Download the specified NLTK corpus file - unless it has already been downloaded. - - Returns True if the corpus needed to be downloaded. - """ - from nltk.data import find - from nltk import download - from os.path import split, sep - from zipfile import BadZipfile - - # Download the NLTK data only if it is not already downloaded - _, corpus_name = split(resource_path) - - if not resource_path.endswith(sep): - resource_path = resource_path + sep - - downloaded = False - - try: - find(resource_path) - except LookupError: - download(corpus_name) - downloaded = True - except BadZipfile: - raise BadZipfile( - 'The NLTK corpus file being opened is not a zipfile, ' - 'or it has been corrupted and needs to be manually deleted.' - ) - - return downloaded - - -def treebank_to_wordnet(pos): - """ - Convert Treebank part-of-speech tags to Wordnet part-of-speech tags. - * https://www.ling.upenn.edu/courses/Fall_2003/ling001/penn_treebank_pos.html - * http://www.nltk.org/_modules/nltk/corpus/reader/wordnet.html - """ - data_map = { - 'N': wordnet.NOUN, - 'J': wordnet.ADJ, - 'V': wordnet.VERB, - 'R': wordnet.ADV - } - - return data_map.get(pos[0]) - - def get_response_time(chatbot, statement='Hello'): """ Returns the amount of time taken for a given @@ -181,25 +130,3 @@ def print_progress_bar(description, iteration_counter, total_items, progress_bar sys.stdout.flush() if total_items == iteration_counter: print('\r') - - -def download_nltk_stopwords(): - """ - Download required NLTK stopwords corpus if it has not already been downloaded. - """ - nltk_download_corpus('stopwords') - - -def download_nltk_wordnet(): - """ - Download required NLTK corpora if they have not already been downloaded. - """ - nltk_download_corpus('corpora/wordnet') - - -def download_nltk_averaged_perceptron_tagger(): - """ - Download the NLTK averaged perceptron tagger that is required for this algorithm - to run only if the corpora has not already been downloaded. - """ - nltk_download_corpus('averaged_perceptron_tagger') diff --git a/docs/commands.rst b/docs/commands.rst index 9792848a6..0a2e69dea 100644 --- a/docs/commands.rst +++ b/docs/commands.rst @@ -13,15 +13,3 @@ you have then you can run the following command. .. code-block:: bash python -m chatterbot --version - -Locate NLTK data -================= - -ChatterBot uses the Natural Language Toolkit (NLTK) for various -language processing functions. ChatterBot downloads additional -data that is required by NLTK. The following command can be used -to find all NLTK data directories that contain files. - -.. code-block:: bash - - python -m chatterbot list_nltk_data diff --git a/docs/utils.rst b/docs/utils.rst index 463b7ff4a..e8bd9e965 100644 --- a/docs/utils.rst +++ b/docs/utils.rst @@ -18,11 +18,6 @@ Class initialization .. autofunction:: chatterbot.utils.initialize_class -NLTK corpus downloader - -.. autofunction:: chatterbot.utils.nltk_download_corpus - - ChatBot response time --------------------- diff --git a/tests/base_case.py b/tests/base_case.py index 05137f2ac..f9b7c1d24 100644 --- a/tests/base_case.py +++ b/tests/base_case.py @@ -1,17 +1,10 @@ from unittest import TestCase, SkipTest from chatterbot import ChatBot -from chatterbot import utils class ChatBotTestCase(TestCase): def setUp(self): - - # Make sure that test requirements are downloaded - utils.download_nltk_stopwords() - utils.download_nltk_wordnet() - utils.download_nltk_averaged_perceptron_tagger() - self.chatbot = ChatBot('Test Bot', **self.get_kwargs()) def tearDown(self): diff --git a/tests/test_chatbot.py b/tests/test_chatbot.py index 790c5a487..0774b3ff4 100644 --- a/tests/test_chatbot.py +++ b/tests/test_chatbot.py @@ -43,10 +43,7 @@ def test_get_initialization_functions(self): """ functions = self.chatbot.get_initialization_functions() - self.assertIn('download_nltk_stopwords', str(functions)) - self.assertIn('download_nltk_wordnet', str(functions)) - self.assertIn('download_nltk_averaged_perceptron_tagger', str(functions)) - self.assertIsLength(functions, 3) + self.assertIsLength(functions, 0) def test_get_initialization_functions_spacy_similarity(self): """ @@ -57,10 +54,7 @@ def test_get_initialization_functions_spacy_similarity(self): list(self.chatbot.search_algorithms.values())[0].compare_statements = spacy_similarity functions = self.chatbot.get_initialization_functions() - self.assertIn('download_nltk_stopwords', str(functions)) - self.assertIn('download_nltk_wordnet', str(functions)) - self.assertIn('download_nltk_averaged_perceptron_tagger', str(functions)) - self.assertIsLength(functions, 3) + self.assertIsLength(functions, 0) def test_get_initialization_functions_jaccard_similarity(self): """ @@ -71,10 +65,7 @@ def test_get_initialization_functions_jaccard_similarity(self): list(self.chatbot.search_algorithms.values())[0].compare_statements = jaccard_similarity functions = self.chatbot.get_initialization_functions() - self.assertIn('download_nltk_wordnet', str(functions)) - self.assertIn('download_nltk_stopwords', str(functions)) - self.assertIn('download_nltk_averaged_perceptron_tagger', str(functions)) - self.assertIsLength(functions, 3) + self.assertIsLength(functions, 0) def test_no_statements_known(self): """ @@ -352,9 +343,8 @@ def test_search_text_results_after_training(self): ) )) + self.assertEqual(len(results), 1) self.assertEqual('Example A for search.', results[0].text) - self.assertEqual('Example B for search.', results[1].text) - self.assertIsLength(results, 2) class TestAdapterA(LogicAdapter): diff --git a/tests/test_cli.py b/tests/test_cli.py index 62e19be59..b6720be3e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -11,7 +11,3 @@ class CommandLineInterfaceTests(TestCase): def test_get_chatterbot_version(self): version = main.get_chatterbot_version() self.assertEqual(version, __version__) - - def test_get_nltk_data_directories(self): - directories = main.get_nltk_data_directories() - self.assertIn('/', directories) diff --git a/tests/test_search.py b/tests/test_search.py index f51d18b48..0a800e624 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -78,7 +78,7 @@ def test_get_closest_statement(self): statement = Statement(text='This is a lovely swamp.') results = list(self.search_algorithm.search(statement)) - self.assertIsLength(results, 2) + self.assertIsLength(results, 1) results_text = [result.text for result in results] diff --git a/tests/test_tagging.py b/tests/test_tagging.py index b54bbca09..864dda607 100644 --- a/tests/test_tagging.py +++ b/tests/test_tagging.py @@ -1,7 +1,6 @@ from unittest import TestCase from chatterbot import languages from chatterbot import tagging -from chatterbot import utils class PosLemmaTaggerTests(TestCase): @@ -136,152 +135,3 @@ def test_get_bigram_pair_string_two_character_words(self): ) self.assertEqual(bigram_string, 'VERB:mu') - - -class PosHypernymTaggerTests(TestCase): - - def setUp(self): - self.tagger = tagging.PosHypernymTagger() - - # Make sure the required NLTK data files are downloaded - for function in utils.get_initialization_functions(self, 'tagger'): - function() - - def test_empty_string(self): - tagged_text = self.tagger.get_bigram_pair_string( - '' - ) - - self.assertEqual(tagged_text, '') - - def test_tagging(self): - tagged_text = self.tagger.get_bigram_pair_string( - 'Hello, how are you doing on this awesome day?' - ) - - self.assertEqual(tagged_text, 'DT:awesome JJ:time_unit') - - def test_tagging_english(self): - self.tagger = tagging.PosHypernymTagger( - language=languages.ENG - ) - - tagged_text = self.tagger.get_bigram_pair_string( - 'Hello, how are you doing on this awesome day?' - ) - - self.assertEqual(tagged_text, 'DT:awesome JJ:time_unit') - - def test_tagging_french(self): - self.tagger = tagging.PosHypernymTagger( - language=languages.FRE - ) - - tagged_text = self.tagger.get_bigram_pair_string( - 'Salut comment allez-vous?' - ) - - self.assertEqual(tagged_text, 's:comment c:allez-vous') - - def test_tagging_russian(self): - self.tagger = tagging.PosHypernymTagger( - language=languages.RUS - ) - - tagged_text = self.tagger.get_bigram_pair_string( - 'Привет, как ты?' - ) - - self.assertEqual(tagged_text, 'п:как к:ты') - - def test_string_becomes_lowercase(self): - tagged_text = self.tagger.get_bigram_pair_string('THIS IS HOW IT BEGINS!') - - self.assertEqual(tagged_text, 'NNP:begins') - - def test_tagging_medium_sized_words(self): - tagged_text = self.tagger.get_bigram_pair_string('Hello, my name is Gunther.') - - self.assertEqual(tagged_text, 'PRP$:language_unit VBZ:gunther') - - def test_tagging_long_words(self): - tagged_text = self.tagger.get_bigram_pair_string('I play several orchestra instruments for pleasuer.') - - self.assertEqual(tagged_text, 'PRP:compete VBP:several JJ:orchestra JJ:device IN:pleasuer') - - def test_get_bigram_pair_string_punctuation_only(self): - bigram_string = self.tagger.get_bigram_pair_string( - '?' - ) - - self.assertEqual(bigram_string, '?') - - def test_get_bigram_pair_string_single_character(self): - bigram_string = self.tagger.get_bigram_pair_string( - '🙂' - ) - - self.assertEqual(bigram_string, '🙂') - - def test_get_bigram_pair_string_single_character_punctuated(self): - bigram_string = self.tagger.get_bigram_pair_string( - '🤷?' - ) - - self.assertEqual(bigram_string, '🤷') - - def test_get_bigram_pair_string_two_characters(self): - bigram_string = self.tagger.get_bigram_pair_string( - 'AB' - ) - - self.assertEqual(bigram_string, 'ab') - - def test_get_bigram_pair_string_three_characters(self): - bigram_string = self.tagger.get_bigram_pair_string( - 'ABC' - ) - - self.assertEqual(bigram_string, 'abc') - - def test_get_bigram_pair_string_four_characters(self): - bigram_string = self.tagger.get_bigram_pair_string( - 'ABCD' - ) - - self.assertEqual(bigram_string, 'abcd') - - def test_get_bigram_pair_string_five_characters(self): - bigram_string = self.tagger.get_bigram_pair_string( - 'ABCDE' - ) - - self.assertEqual(bigram_string, 'abcde') - - def test_get_bigram_pair_string_single_word(self): - bigram_string = self.tagger.get_bigram_pair_string( - 'Hello' - ) - - self.assertEqual(bigram_string, 'hello') - - def test_get_bigram_pair_string_multiple_words(self): - bigram_string = self.tagger.get_bigram_pair_string( - 'Hello Dr. Salazar. How are you today?' - ) - - self.assertEqual(bigram_string, 'NNP:scholar NNP:salazar PRP:present') - - def test_get_bigram_pair_string_single_character_words(self): - bigram_string = self.tagger.get_bigram_pair_string( - 'a e i o u' - ) - - self.assertEqual(bigram_string, 'DT:e NN:i NN:o VBP:u') - - def test_get_bigram_pair_string_two_character_words(self): - bigram_string = self.tagger.get_bigram_pair_string( - 'Lo my mu it is of us' - ) - - self.assertEqual(bigram_string, 'PRP$:letter IN:us') diff --git a/tests/test_tokenizers.py b/tests/test_tokenizers.py deleted file mode 100644 index 329ecc7ee..000000000 --- a/tests/test_tokenizers.py +++ /dev/null @@ -1,37 +0,0 @@ -from unittest import TestCase -from chatterbot import languages -from chatterbot.tokenizers import get_sentence_tokenizer, get_word_tokenizer - - -class EnglishSentenceTokenizerTests(TestCase): - - def setUp(self): - super().setUp() - - self.tokenizer = get_sentence_tokenizer(languages.ENG) - - def test_one_sentence(self): - tokens = self.tokenizer.tokenize('Hello, how are you?') - - self.assertEqual(len(tokens), 1) - self.assertIn('Hello, how are you?', tokens) - - def test_two_sentences(self): - tokens = self.tokenizer.tokenize('It is so nice out. Don\'t you think so?') - - self.assertEqual(len(tokens), 2) - self.assertIn('It is so nice out.', tokens) - self.assertIn('Don\'t you think so?', tokens) - - -class EnglishWordTokenizerTests(TestCase): - - def setUp(self): - super().setUp() - - self.tokenizer = get_word_tokenizer(languages.ENG) - - def test_one_sentence(self): - tokens = self.tokenizer.tokenize('Hello, how are you?') - - self.assertEqual(['Hello', ',', 'how', 'are', 'you', '?'], tokens) diff --git a/tests/test_utils.py b/tests/test_utils.py index 2c5909e78..fd127d397 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -9,16 +9,6 @@ def test_import_module(self): datetime = utils.import_module('datetime.datetime') self.assertTrue(hasattr(datetime, 'now')) - def test_nltk_download_corpus(self): - downloaded = utils.nltk_download_corpus('wordnet') - self.assertTrue(downloaded) - - def test_treebank_to_wordnet(self): - self.assertEqual(utils.treebank_to_wordnet('NNS'), 'n') - - def test_treebank_to_wordnet_no_match(self): - self.assertEqual(utils.treebank_to_wordnet('XXX'), None) - class UtilityChatBotTestCase(ChatBotTestCase): diff --git a/tests/training/test_list_training.py b/tests/training/test_list_training.py index aac778827..7a54f392f 100644 --- a/tests/training/test_list_training.py +++ b/tests/training/test_list_training.py @@ -84,7 +84,7 @@ def test_training_sets_search_text(self): )) self.assertIsLength(statements, 1) - self.assertEqual(statements[0].search_text, 'RB:kind PRP$:headdress') + self.assertEqual(statements[0].search_text, 'VERB:hat') def test_training_sets_search_in_response_to(self): @@ -100,7 +100,7 @@ def test_training_sets_search_in_response_to(self): )) self.assertIsLength(statements, 1) - self.assertEqual(statements[0].search_in_response_to, 'PRP:kind PRP$:headdress') + self.assertEqual(statements[0].search_in_response_to, 'VERB:hat') def test_database_has_correct_format(self): """ diff --git a/tests/training/test_ubuntu_corpus_training.py b/tests/training/test_ubuntu_corpus_training.py index d3e4af74f..b09987a07 100644 --- a/tests/training/test_ubuntu_corpus_training.py +++ b/tests/training/test_ubuntu_corpus_training.py @@ -176,7 +176,7 @@ def test_train_sets_search_text(self): results = list(self.chatbot.storage.filter(text='Is anyone there?')) self.assertEqual(len(results), 2) - self.assertEqual(results[0].search_text, 'VBZ:anyone') + self.assertEqual(results[0].search_text, 'VERB:anyone NOUN:there') def test_train_sets_search_in_response_to(self): """ @@ -190,7 +190,7 @@ def test_train_sets_search_in_response_to(self): results = list(self.chatbot.storage.filter(in_response_to='Is anyone there?')) self.assertEqual(len(results), 2) - self.assertEqual(results[0].search_in_response_to, 'VBZ:anyone') + self.assertEqual(results[0].search_in_response_to, 'VERB:anyone NOUN:there') def test_is_extracted(self): """ From 8bb1f299e0abb318c0b928d9067f0145eb18b848 Mon Sep 17 00:00:00 2001 From: Gunther Cox Date: Sun, 31 Mar 2019 15:57:56 -0400 Subject: [PATCH 2/6] Remove NLTK_DATA from passenv --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 1191c34c7..e6eed6dff 100644 --- a/tox.ini +++ b/tox.ini @@ -2,7 +2,7 @@ skipsdist = True [testenv] -passenv = DJANGO_SETTINGS_MODULE NLTK_DATA PYTHONPATH HOME DISPLAY +passenv = DJANGO_SETTINGS_MODULE PYTHONPATH HOME DISPLAY setenv = PYTHONDONTWRITEBYTECODE=1 deps = django111: Django>=1.11,<1.12 From 97879f918be51e60f719ff920ef811609a5e7389 Mon Sep 17 00:00:00 2001 From: Gunther Cox Date: Sun, 31 Mar 2019 20:45:26 -0400 Subject: [PATCH 3/6] Remove initialization methods --- chatterbot/chatterbot.py | 25 ------------------------- chatterbot/utils.py | 13 ------------- tests/test_chatbot.py | 30 ------------------------------ 3 files changed, 68 deletions(-) diff --git a/chatterbot/chatterbot.py b/chatterbot/chatterbot.py index 19b33635c..054309543 100644 --- a/chatterbot/chatterbot.py +++ b/chatterbot/chatterbot.py @@ -54,31 +54,6 @@ def __init__(self, name, **kwargs): # Allow the bot to save input it receives so that it can learn self.read_only = kwargs.get('read_only', False) - if kwargs.get('initialize', True): - self.initialize() - - def get_initialization_functions(self): - initialization_functions = set() - - initialization_functions.update(utils.get_initialization_functions( - self, 'storage.tagger' - )) - - for search_algorithm in self.search_algorithms.values(): - search_algorithm_functions = utils.get_initialization_functions( - search_algorithm, 'compare_statements' - ) - initialization_functions.update(search_algorithm_functions) - - return initialization_functions - - def initialize(self): - """ - Do any work that needs to be done before the chatbot can process responses. - """ - for function in self.get_initialization_functions(): - function() - def get_response(self, statement=None, **kwargs): """ Return the bot's response based on the input. diff --git a/chatterbot/utils.py b/chatterbot/utils.py index 621b7faae..61c27a0f6 100644 --- a/chatterbot/utils.py +++ b/chatterbot/utils.py @@ -17,19 +17,6 @@ def import_module(dotted_path): return getattr(module, module_parts[-1]) -def get_initialization_functions(obj, attribute): - """ - Return all initialization methods for the comparison algorithm. - Initialization methods must start with 'initialize_' and take no parameters. - """ - attribute_parts = attribute.split('.') - outermost_attribute = getattr(obj, attribute_parts.pop(0)) - for next_attribute in attribute_parts: - outermost_attribute = getattr(outermost_attribute, next_attribute) - - return getattr(outermost_attribute, 'initialization_functions', []) - - def initialize_class(data, *args, **kwargs): """ :param data: A string or dictionary containing a import_path attribute. diff --git a/tests/test_chatbot.py b/tests/test_chatbot.py index 0774b3ff4..dd990fdd6 100644 --- a/tests/test_chatbot.py +++ b/tests/test_chatbot.py @@ -37,36 +37,6 @@ def test_in_response_to_provided(self): ) self.assertIsNotNone(statement) - def test_get_initialization_functions(self): - """ - Test that the initialization functions are returned. - """ - functions = self.chatbot.get_initialization_functions() - - self.assertIsLength(functions, 0) - - def test_get_initialization_functions_spacy_similarity(self): - """ - Test that the initialization functions are returned. - """ - from chatterbot.comparisons import spacy_similarity - - list(self.chatbot.search_algorithms.values())[0].compare_statements = spacy_similarity - functions = self.chatbot.get_initialization_functions() - - self.assertIsLength(functions, 0) - - def test_get_initialization_functions_jaccard_similarity(self): - """ - Test that the initialization functions are returned. - """ - from chatterbot.comparisons import jaccard_similarity - - list(self.chatbot.search_algorithms.values())[0].compare_statements = jaccard_similarity - functions = self.chatbot.get_initialization_functions() - - self.assertIsLength(functions, 0) - def test_no_statements_known(self): """ If there is no statements in the database, then the From a9d7cb3010c623849d180999beeee39cc0ba8f62 Mon Sep 17 00:00:00 2001 From: Gunther Cox Date: Sat, 6 Apr 2019 08:37:40 -0400 Subject: [PATCH 4/6] Update tests --- tests/test_search.py | 29 ++++++++++++----------------- tests_django/base_case.py | 4 ++-- tests_django/test_chatbot.py | 23 ++++++++++++++++++++++- 3 files changed, 36 insertions(+), 20 deletions(-) diff --git a/tests/test_search.py b/tests/test_search.py index 0a800e624..6dcf3b2f7 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -79,13 +79,8 @@ def test_get_closest_statement(self): results = list(self.search_algorithm.search(statement)) self.assertIsLength(results, 1) - - results_text = [result.text for result in results] - - self.assertIn('This is a lovely bog.', results_text) - self.assertIn('This is a beautiful swamp.', results_text) + self.assertEqual(results[0].text, 'This is a beautiful swamp.') self.assertGreater(results[0].confidence, 0) - self.assertGreater(results[1].confidence, 0) def test_different_punctuation(self): self.chatbot.storage.create_many([ @@ -97,8 +92,9 @@ def test_different_punctuation(self): statement = Statement(text='Are you good') results = list(self.search_algorithm.search(statement)) - self.assertIsLength(results, 1) - self.assertEqual(results[0].text, 'Are you good?') + self.assertEqual(len(results), 2, msg=[r.search_text for r in results]) + # Note: the last statement in the list always has the highest confidence + self.assertEqual(results[-1].text, 'Are you good?') class SearchComparisonFunctionLevenshteinDistanceComparisonTests(ChatBotTestCase): @@ -117,23 +113,22 @@ def setUp(self): def test_get_closest_statement(self): """ Note, the content of the in_response_to field for each of the - test statements is only required because the logic adapter will - filter out any statements that are not in response to a known statement. + test statements is only required because the search process will + filter out any statements that are not in response to something. """ self.chatbot.storage.create_many([ - Statement(text='Who do you love?', in_response_to='I hear you are going on a quest?'), - Statement(text='What is the meaning of life?', in_response_to='Yuck, black licorice jelly beans.'), - Statement(text='I am Iron Man.', in_response_to='What... is your quest?'), - Statement(text='What... is your quest?', in_response_to='I am Iron Man.'), - Statement(text='Yuck, black licorice jelly beans.', in_response_to='What is the meaning of life?'), - Statement(text='I hear you are going on a quest?', in_response_to='Who do you love?'), + Statement(text='What is the meaning of life?', in_response_to='...'), + Statement(text='I am Iron Man.', in_response_to='...'), + Statement(text='What... is your quest?', in_response_to='...'), + Statement(text='Yuck, black licorice jelly beans.', in_response_to='...'), + Statement(text='I hear you are going on a quest?', in_response_to='...'), ]) statement = Statement(text='What is your quest?') results = list(self.search_algorithm.search(statement)) - self.assertIsLength(results, 1) + self.assertEqual(len(results), 1, msg=[r.text for r in results]) self.assertEqual(results[0].text, 'What... is your quest?') def test_confidence_exact_match(self): diff --git a/tests_django/base_case.py b/tests_django/base_case.py index 0d239ed11..0572de82d 100644 --- a/tests_django/base_case.py +++ b/tests_django/base_case.py @@ -1,9 +1,9 @@ from chatterbot import ChatBot -from django.test import TestCase +from django.test import TransactionTestCase from tests_django import test_settings -class ChatterBotTestCase(TestCase): +class ChatterBotTestCase(TransactionTestCase): def setUp(self): super().setUp() diff --git a/tests_django/test_chatbot.py b/tests_django/test_chatbot.py index c0016f845..3397fa89d 100644 --- a/tests_django/test_chatbot.py +++ b/tests_django/test_chatbot.py @@ -283,6 +283,27 @@ def test_search_text_results_after_training(self): ) )) + self.assertEqual(len(results), 1, msg=[r.text for r in results]) + self.assertEqual('Example A for search.', results[0].text) + + def test_search_text_contains_results_after_training(self): + """ + ChatterBot should return close matches to an input + string when filtering using the search_text parameter. + """ + self.chatbot.storage.create_many([ + Statement('Example A for search.'), + Statement('Another example.'), + Statement('Example B for search.'), + Statement(text='Another statement.'), + ]) + + results = list(self.chatbot.storage.filter( + search_text_contains=self.chatbot.storage.tagger.get_bigram_pair_string( + 'Example A for search.' + ) + )) + + self.assertEqual(len(results), 2, msg=[r.text for r in results]) self.assertEqual('Example A for search.', results[0].text) self.assertEqual('Example B for search.', results[1].text) - self.assertEqual(len(results), 2) From c62f5a94466080aec3d85c7abe649a016706de2b Mon Sep 17 00:00:00 2001 From: Gunther Cox Date: Sat, 6 Apr 2019 12:02:13 -0400 Subject: [PATCH 5/6] Update docstrings --- chatterbot/storage/storage_adapter.py | 2 ++ chatterbot/trainers.py | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/chatterbot/storage/storage_adapter.py b/chatterbot/storage/storage_adapter.py index 821760a76..74750311a 100644 --- a/chatterbot/storage/storage_adapter.py +++ b/chatterbot/storage/storage_adapter.py @@ -12,6 +12,8 @@ class StorageAdapter(object): def __init__(self, *args, **kwargs): """ Initialize common attributes shared by all storage adapters. + + :param str tagger_language: The language that the tagger uses to remove stopwords. """ self.logger = kwargs.get('logger', logging.getLogger(__name__)) diff --git a/chatterbot/trainers.py b/chatterbot/trainers.py index 3d1b1841f..18cdf5ff5 100644 --- a/chatterbot/trainers.py +++ b/chatterbot/trainers.py @@ -17,8 +17,6 @@ class Trainer(object): trainer. The environment variable ``CHATTERBOT_SHOW_TRAINING_PROGRESS`` can also be set to control this. ``show_training_progress`` will override the environment variable if it is set. - - :param str tagger_language: The language that the tagger uses to remove stopwords. """ def __init__(self, chatbot, **kwargs): From f99ec8d93c8294b3dc3b9db8cc264bb661546217 Mon Sep 17 00:00:00 2001 From: Gunther Cox Date: Sat, 6 Apr 2019 12:06:32 -0400 Subject: [PATCH 6/6] Remove multiprocessing to prevent CI errors on Travis --- chatterbot/trainers.py | 104 ++++++++++------------------------------- 1 file changed, 25 insertions(+), 79 deletions(-) diff --git a/chatterbot/trainers.py b/chatterbot/trainers.py index 18cdf5ff5..48d2ab7ce 100644 --- a/chatterbot/trainers.py +++ b/chatterbot/trainers.py @@ -2,7 +2,6 @@ import sys import csv import time -from multiprocessing import Pool, Manager from dateutil import parser as date_parser from chatterbot.conversation import Statement from chatterbot.tagging import PosLemmaTagger @@ -174,41 +173,6 @@ def train(self, *corpus_paths): self.chatbot.storage.create_many(statements_to_create) -def read_file(files, queue, preprocessors, tagger): - - statements_from_file = [] - - for tsv_file in files: - with open(tsv_file, 'r', encoding='utf-8') as tsv: - reader = csv.reader(tsv, delimiter='\t') - - previous_statement_text = None - previous_statement_search_text = '' - - for row in reader: - if len(row) > 0: - statement = Statement( - text=row[3], - in_response_to=previous_statement_text, - conversation='training', - created_at=date_parser.parse(row[0]), - persona=row[1] - ) - - for preprocessor in preprocessors: - statement = preprocessor(statement) - - statement.search_text = tagger.get_bigram_pair_string(statement.text) - statement.search_in_response_to = previous_statement_search_text - - previous_statement_text = statement.text - previous_statement_search_text = statement.search_text - - statements_from_file.append(statement) - - queue.put(tuple(statements_from_file)) - - class UbuntuCorpusTrainer(Trainer): """ Allow chatbots to be trained with the data from the Ubuntu Dialog Corpus. @@ -337,9 +301,6 @@ def train(self): '**', '**', '*.tsv' ) - manager = Manager() - queue = manager.Queue() - def chunks(items, items_per_chunk): for start_index in range(0, len(items), items_per_chunk): end_index = start_index + items_per_chunk @@ -349,55 +310,40 @@ def chunks(items, items_per_chunk): file_groups = tuple(chunks(file_list, 10000)) - argument_groups = tuple( - ( - file_names, - queue, - self.chatbot.preprocessors, - tagger, - ) for file_names in file_groups - ) - - pool_batches = chunks(argument_groups, 9) - - total_batches = len(file_groups) - batch_number = 0 - start_time = time.time() - with Pool() as pool: - for pool_batch in pool_batches: - pool.starmap(read_file, pool_batch) + for tsv_files in file_groups: + + statements_from_file = [] - while True: + for tsv_file in tsv_files: + with open(tsv_file, 'r', encoding='utf-8') as tsv: + reader = csv.reader(tsv, delimiter='\t') - if queue.empty(): - break + previous_statement_text = None + previous_statement_search_text = '' - batch_number += 1 + for row in reader: + if len(row) > 0: + statement = Statement( + text=row[3], + in_response_to=previous_statement_text, + conversation='training', + created_at=date_parser.parse(row[0]), + persona=row[1] + ) - print('Training with batch {} with {} batches remaining...'.format( - batch_number, - total_batches - batch_number - )) + for preprocessor in self.chatbot.preprocessors: + statement = preprocessor(statement) - self.chatbot.storage.create_many(queue.get()) + statement.search_text = tagger.get_bigram_pair_string(statement.text) + statement.search_in_response_to = previous_statement_search_text - elapsed_time = time.time() - start_time - time_per_batch = elapsed_time / batch_number - remaining_time = time_per_batch * (total_batches - batch_number) + previous_statement_text = statement.text + previous_statement_search_text = statement.search_text - print('{:.0f} hours {:.0f} minutes {:.0f} seconds elapsed.'.format( - elapsed_time // 3600 % 24, - elapsed_time // 60 % 60, - elapsed_time % 60 - )) + statements_from_file.append(statement) - print('{:.0f} hours {:.0f} minutes {:.0f} seconds remaining.'.format( - remaining_time // 3600 % 24, - remaining_time // 60 % 60, - remaining_time % 60 - )) - print('---') + self.chatbot.storage.create_many(statements_from_file) print('Training took', time.time() - start_time, 'seconds.')