Skip to content

Commit

Permalink
Add lowercase tagger and rename indexing method to be more accurate
Browse files Browse the repository at this point in the history
  • Loading branch information
gunthercox committed May 17, 2019
1 parent 215e5f1 commit 79221d4
Show file tree
Hide file tree
Showing 13 changed files with 76 additions and 60 deletions.
4 changes: 2 additions & 2 deletions chatterbot/chatterbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ def get_response(self, statement=None, **kwargs):
# Make sure the input statement has its search text saved

if not input_statement.search_text:
input_statement.search_text = self.storage.tagger.get_bigram_pair_string(input_statement.text)
input_statement.search_text = self.storage.tagger.get_text_index_string(input_statement.text)

if not input_statement.search_in_response_to and input_statement.in_response_to:
input_statement.search_in_response_to = self.storage.tagger.get_bigram_pair_string(input_statement.in_response_to)
input_statement.search_in_response_to = self.storage.tagger.get_text_index_string(input_statement.in_response_to)

response = self.generate_response(input_statement, additional_response_selection_parameters)

Expand Down
1 change: 1 addition & 0 deletions chatterbot/comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class JaccardSimilarity(Comparator):

def __init__(self, language):
super().__init__(language)
import spacy

self.nlp = spacy.load(self.language.ISO_639_1)

Expand Down
2 changes: 1 addition & 1 deletion chatterbot/logic/best_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def process(self, input_statement, additional_response_selection_parameters=None
}

alternate_response_selection_parameters = {
'search_in_response_to': self.chatbot.storage.tagger.get_bigram_pair_string(
'search_in_response_to': self.chatbot.storage.tagger.get_text_index_string(
input_statement.text
),
'exclude_text': recent_repeated_responses,
Expand Down
2 changes: 1 addition & 1 deletion chatterbot/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def search(self, input_statement, **additional_parameters):
'No value for search_text was available on the provided input'
)

input_search_text = self.chatbot.storage.tagger.get_bigram_pair_string(
input_search_text = self.chatbot.storage.tagger.get_text_index_string(
input_statement.text
)

Expand Down
12 changes: 6 additions & 6 deletions chatterbot/storage/django_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,11 @@ def create(self, **kwargs):
tags = kwargs.pop('tags', [])

if 'search_text' not in kwargs:
kwargs['search_text'] = self.tagger.get_bigram_pair_string(kwargs['text'])
kwargs['search_text'] = self.tagger.get_text_index_string(kwargs['text'])

if 'search_in_response_to' not in kwargs:
if kwargs.get('in_response_to'):
kwargs['search_in_response_to'] = self.tagger.get_bigram_pair_string(kwargs['in_response_to'])
kwargs['search_in_response_to'] = self.tagger.get_text_index_string(kwargs['in_response_to'])

statement = Statement(**kwargs)

Expand Down Expand Up @@ -137,10 +137,10 @@ def create_many(self, statements):
statement_model_object = Statement(**statement_data)

if not statement.search_text:
statement_model_object.search_text = self.tagger.get_bigram_pair_string(statement.text)
statement_model_object.search_text = self.tagger.get_text_index_string(statement.text)

if not statement.search_in_response_to and statement.in_response_to:
statement_model_object.search_in_response_to = self.tagger.get_bigram_pair_string(statement.in_response_to)
statement_model_object.search_in_response_to = self.tagger.get_text_index_string(statement.in_response_to)

statement_model_object.save()

Expand Down Expand Up @@ -168,10 +168,10 @@ def update(self, statement):
else:
statement = Statement.objects.create(
text=statement.text,
search_text=self.tagger.get_bigram_pair_string(statement.text),
search_text=self.tagger.get_text_index_string(statement.text),
conversation=statement.conversation,
in_response_to=statement.in_response_to,
search_in_response_to=self.tagger.get_bigram_pair_string(statement.in_response_to),
search_in_response_to=self.tagger.get_text_index_string(statement.in_response_to),
created_at=statement.created_at
)

Expand Down
12 changes: 6 additions & 6 deletions chatterbot/storage/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ def create(self, **kwargs):
kwargs['tags'] = list(set(kwargs['tags']))

if 'search_text' not in kwargs:
kwargs['search_text'] = self.tagger.get_bigram_pair_string(kwargs['text'])
kwargs['search_text'] = self.tagger.get_text_index_string(kwargs['text'])

if 'search_in_response_to' not in kwargs:
if kwargs.get('in_response_to'):
kwargs['search_in_response_to'] = self.tagger.get_bigram_pair_string(kwargs['in_response_to'])
kwargs['search_in_response_to'] = self.tagger.get_text_index_string(kwargs['in_response_to'])

inserted = self.statements.insert_one(kwargs)

Expand All @@ -183,10 +183,10 @@ def create_many(self, statements):
statement_data['tags'] = tag_data

if not statement.search_text:
statement_data['search_text'] = self.tagger.get_bigram_pair_string(statement.text)
statement_data['search_text'] = self.tagger.get_text_index_string(statement.text)

if not statement.search_in_response_to and statement.in_response_to:
statement_data['search_in_response_to'] = self.tagger.get_bigram_pair_string(statement.in_response_to)
statement_data['search_in_response_to'] = self.tagger.get_text_index_string(statement.in_response_to)

create_statements.append(statement_data)

Expand All @@ -197,10 +197,10 @@ def update(self, statement):
data.pop('id', None)
data.pop('tags', None)

data['search_text'] = self.tagger.get_bigram_pair_string(data['text'])
data['search_text'] = self.tagger.get_text_index_string(data['text'])

if data.get('in_response_to'):
data['search_in_response_to'] = self.tagger.get_bigram_pair_string(data['in_response_to'])
data['search_in_response_to'] = self.tagger.get_text_index_string(data['in_response_to'])

update_data = {
'$set': data
Expand Down
12 changes: 6 additions & 6 deletions chatterbot/storage/sql_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,12 @@ def create(self, **kwargs):
tags = set(kwargs.pop('tags', []))

if 'search_text' not in kwargs:
kwargs['search_text'] = self.tagger.get_bigram_pair_string(kwargs['text'])
kwargs['search_text'] = self.tagger.get_text_index_string(kwargs['text'])

if 'search_in_response_to' not in kwargs:
in_response_to = kwargs.get('in_response_to')
if in_response_to:
kwargs['search_in_response_to'] = self.tagger.get_bigram_pair_string(in_response_to)
kwargs['search_in_response_to'] = self.tagger.get_text_index_string(in_response_to)

statement = Statement(**kwargs)

Expand Down Expand Up @@ -236,10 +236,10 @@ def create_many(self, statements):
statement_model_object = Statement(**statement_data)

if not statement.search_text:
statement_model_object.search_text = self.tagger.get_bigram_pair_string(statement.text)
statement_model_object.search_text = self.tagger.get_text_index_string(statement.text)

if not statement.search_in_response_to and statement.in_response_to:
statement_model_object.search_in_response_to = self.tagger.get_bigram_pair_string(statement.in_response_to)
statement_model_object.search_in_response_to = self.tagger.get_text_index_string(statement.in_response_to)

new_tags = set(tag_data) - set(create_tags.keys())

Expand Down Expand Up @@ -299,10 +299,10 @@ def update(self, statement):

record.created_at = statement.created_at

record.search_text = self.tagger.get_bigram_pair_string(statement.text)
record.search_text = self.tagger.get_text_index_string(statement.text)

if statement.in_response_to:
record.search_in_response_to = self.tagger.get_bigram_pair_string(statement.in_response_to)
record.search_in_response_to = self.tagger.get_text_index_string(statement.in_response_to)

for tag_name in statement.get_tags():
tag = session.query(Tag).filter_by(name=tag_name).first()
Expand Down
4 changes: 3 additions & 1 deletion chatterbot/storage/storage_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ def __init__(self, *args, **kwargs):
"""
self.logger = kwargs.get('logger', logging.getLogger(__name__))

self.tagger = PosLemmaTagger(language=kwargs.get(
Tagger = kwargs.get('tagger', PosLemmaTagger)

self.tagger = Tagger(language=kwargs.get(
'tagger_language', languages.ENG
))

Expand Down
17 changes: 15 additions & 2 deletions chatterbot/tagging.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,31 @@
import string
from chatterbot import languages
import spacy


class LowercaseTagger(object):
"""
Returns the text in lowercase.
"""

def __init__(self, language=None):
pass

def get_text_index_string(self, text):
return text.lower()


class PosLemmaTagger(object):

def __init__(self, language=None):
import spacy

self.language = language or languages.ENG

self.punctuation_table = str.maketrans(dict.fromkeys(string.punctuation))

self.nlp = spacy.load(self.language.ISO_639_1.lower())

def get_bigram_pair_string(self, text):
def get_text_index_string(self, text):
"""
Return a string of text containing part-of-speech, lemma pairs.
"""
Expand Down
6 changes: 3 additions & 3 deletions chatterbot/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def train(self, conversation):
conversation_count + 1, len(conversation)
)

statement_search_text = self.chatbot.storage.tagger.get_bigram_pair_string(text)
statement_search_text = self.chatbot.storage.tagger.get_text_index_string(text)

statement = self.get_preprocessed_statement(
Statement(
Expand Down Expand Up @@ -151,7 +151,7 @@ def train(self, *corpus_paths):

for text in conversation:

statement_search_text = self.chatbot.storage.tagger.get_bigram_pair_string(text)
statement_search_text = self.chatbot.storage.tagger.get_text_index_string(text)

statement = Statement(
text=text,
Expand Down Expand Up @@ -336,7 +336,7 @@ def chunks(items, items_per_chunk):
for preprocessor in self.chatbot.preprocessors:
statement = preprocessor(statement)

statement.search_text = tagger.get_bigram_pair_string(statement.text)
statement.search_text = tagger.get_text_index_string(statement.text)
statement.search_in_response_to = previous_statement_search_text

previous_statement_text = statement.text
Expand Down
2 changes: 1 addition & 1 deletion tests/test_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def test_search_text_results_after_training(self):
])

results = list(self.chatbot.storage.filter(
search_text=self.chatbot.storage.tagger.get_bigram_pair_string(
search_text=self.chatbot.storage.tagger.get_text_index_string(
'Example A for search.'
)
))
Expand Down
58 changes: 29 additions & 29 deletions tests/test_tagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ def setUp(self):
self.tagger = tagging.PosLemmaTagger()

def test_empty_string(self):
tagged_text = self.tagger.get_bigram_pair_string(
tagged_text = self.tagger.get_text_index_string(
''
)

self.assertEqual(tagged_text, '')

def test_tagging(self):
tagged_text = self.tagger.get_bigram_pair_string(
tagged_text = self.tagger.get_text_index_string(
'Hello, how are you doing on this awesome day?'
)

Expand All @@ -27,7 +27,7 @@ def test_tagging_english(self):
language=languages.ENG
)

tagged_text = self.tagger.get_bigram_pair_string(
tagged_text = self.tagger.get_text_index_string(
'Hello, how are you doing on this awesome day?'
)

Expand All @@ -38,99 +38,99 @@ def test_tagging_german(self):
language=languages.GER
)

tagged_text = self.tagger.get_bigram_pair_string(
tagged_text = self.tagger.get_text_index_string(
'Ich spreche nicht viel Deutsch.'
)

self.assertEqual(tagged_text, 'VERB:deutsch')

def test_string_becomes_lowercase(self):
tagged_text = self.tagger.get_bigram_pair_string('THIS IS HOW IT BEGINS!')
tagged_text = self.tagger.get_text_index_string('THIS IS HOW IT BEGINS!')

self.assertEqual(tagged_text, 'DET:be VERB:how ADV:it NOUN:begin')

def test_tagging_medium_sized_words(self):
tagged_text = self.tagger.get_bigram_pair_string('Hello, my name is Gunther.')
tagged_text = self.tagger.get_text_index_string('Hello, my name is Gunther.')

self.assertEqual(tagged_text, 'INTJ:gunther')

def test_tagging_long_words(self):
tagged_text = self.tagger.get_bigram_pair_string('I play several orchestra instruments for pleasure.')
tagged_text = self.tagger.get_text_index_string('I play several orchestra instruments for pleasure.')

self.assertEqual(tagged_text, 'VERB:orchestra ADJ:instrument NOUN:pleasure')

def test_get_bigram_pair_string_punctuation_only(self):
bigram_string = self.tagger.get_bigram_pair_string(
def test_get_text_index_string_punctuation_only(self):
bigram_string = self.tagger.get_text_index_string(
'?'
)

self.assertEqual(bigram_string, '?')

def test_get_bigram_pair_string_single_character(self):
bigram_string = self.tagger.get_bigram_pair_string(
def test_get_text_index_string_single_character(self):
bigram_string = self.tagger.get_text_index_string(
'🙂'
)

self.assertEqual(bigram_string, '🙂')

def test_get_bigram_pair_string_single_character_punctuated(self):
bigram_string = self.tagger.get_bigram_pair_string(
def test_get_text_index_string_single_character_punctuated(self):
bigram_string = self.tagger.get_text_index_string(
'🤷?'
)

self.assertEqual(bigram_string, '🤷')

def test_get_bigram_pair_string_two_characters(self):
bigram_string = self.tagger.get_bigram_pair_string(
def test_get_text_index_string_two_characters(self):
bigram_string = self.tagger.get_text_index_string(
'AB'
)

self.assertEqual(bigram_string, 'ab')

def test_get_bigram_pair_string_three_characters(self):
bigram_string = self.tagger.get_bigram_pair_string(
def test_get_text_index_string_three_characters(self):
bigram_string = self.tagger.get_text_index_string(
'ABC'
)

self.assertEqual(bigram_string, 'abc')

def test_get_bigram_pair_string_four_characters(self):
bigram_string = self.tagger.get_bigram_pair_string(
def test_get_text_index_string_four_characters(self):
bigram_string = self.tagger.get_text_index_string(
'ABCD'
)

self.assertEqual(bigram_string, 'abcd')

def test_get_bigram_pair_string_five_characters(self):
bigram_string = self.tagger.get_bigram_pair_string(
def test_get_text_index_string_five_characters(self):
bigram_string = self.tagger.get_text_index_string(
'ABCDE'
)

self.assertEqual(bigram_string, 'abcde')

def test_get_bigram_pair_string_single_word(self):
bigram_string = self.tagger.get_bigram_pair_string(
def test_get_text_index_string_single_word(self):
bigram_string = self.tagger.get_text_index_string(
'Hello'
)

self.assertEqual(bigram_string, 'hello')

def test_get_bigram_pair_string_multiple_words(self):
bigram_string = self.tagger.get_bigram_pair_string(
def test_get_text_index_string_multiple_words(self):
bigram_string = self.tagger.get_text_index_string(
'Hello Dr. Salazar. How are you today?'
)

self.assertEqual(bigram_string, 'INTJ:salazar PROPN:today')

def test_get_bigram_pair_string_single_character_words(self):
bigram_string = self.tagger.get_bigram_pair_string(
def test_get_text_index_string_single_character_words(self):
bigram_string = self.tagger.get_text_index_string(
'a e i o u'
)

self.assertEqual(bigram_string, 'NOUN:o NOUN:u')

def test_get_bigram_pair_string_two_character_words(self):
bigram_string = self.tagger.get_bigram_pair_string(
def test_get_text_index_string_two_character_words(self):
bigram_string = self.tagger.get_text_index_string(
'Lo my mu it is of us'
)

Expand Down
Loading

0 comments on commit 79221d4

Please sign in to comment.