Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace the extra_data attribute with tags #1402

Merged
merged 2 commits into from
Sep 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion chatterbot/chatterbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def learn_response(self, statement, previous_statement):
text=statement.text,
in_response_to=previous_statement_text,
conversation=statement.conversation,
extra_data=statement.extra_data,
tags=statement.tags
)

Expand Down
32 changes: 1 addition & 31 deletions chatterbot/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,13 @@ def serialize(self):
:returns: A dictionary representation of the statement object.
:rtype: dict
"""
extra_data = self.extra_data

if not extra_data:
extra_data = '{}'

if type(extra_data) == str:
import json
extra_data = json.loads(extra_data)

return {
'id': self.id,
'text': self.text,
'created_at': self.created_at.isoformat().split('+', 1)[0],
'conversation': self.conversation,
'in_response_to': self.in_response_to,
'tags': self.get_tags(),
'extra_data': extra_data
'tags': self.get_tags()
}


Expand Down Expand Up @@ -70,8 +60,6 @@ def __init__(self, text, **kwargs):
if not isinstance(self.created_at, datetime):
self.created_at = date_parser.parse(self.created_at)

self.extra_data = kwargs.pop('extra_data', {})

# This is the confidence with which the chat bot believes
# this is an accurate response. This value is set when the
# statement is returned by the chat bot.
Expand Down Expand Up @@ -103,24 +91,6 @@ def save(self):
"""
self.storage.update(self)

def add_extra_data(self, key, value):
"""
This method allows additional data to be stored on the statement object.

Typically this data is something that pertains just to this statement.
For example, a value stored here might be the tagged parts of speech for
each word in the statement text.

- key = 'pos_tags'
- value = [('Now', 'RB'), ('for', 'IN'), ('something', 'NN'), ('different', 'JJ')]

:param key: The key to use in the dictionary of extra data.
:type key: str

:param value: The value to set for the specified key.
"""
self.extra_data[key] = value

class InvalidTypeException(Exception):

def __init__(self, value='Received an unexpected value type.'):
Expand Down
19 changes: 0 additions & 19 deletions chatterbot/ext/django_chatterbot/abstract_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,6 @@ class AbstractBaseStatement(models.Model, StatementMixin):
null=True
)

extra_data = models.CharField(
max_length=500,
blank=True
)

# This is the confidence with which the chat bot believes
# this is an accurate response. This value is set when the
# statement is returned by the chat bot.
Expand All @@ -70,20 +65,6 @@ def __str__(self):
return self.text
return '<empty>'

def add_extra_data(self, key, value):
"""
Add extra data to the extra_data field.
"""
import json

if not self.extra_data:
self.extra_data = '{}'

extra_data = json.loads(self.extra_data)
extra_data[key] = value

self.extra_data = json.dumps(extra_data)

def get_tags(self):
"""
Return the list of tags for this statement.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
from django.db import migrations


class Migration(migrations.Migration):

dependencies = [
('django_chatterbot', '0013_change_conversations'),
]

operations = [
migrations.RemoveField(
model_name='statement',
name='extra_data',
),
]
5 changes: 1 addition & 4 deletions chatterbot/ext/sqlalchemy_app/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sqlalchemy import Table, Column, Integer, String, DateTime, ForeignKey, PickleType
from sqlalchemy import Table, Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
from sqlalchemy.ext.declarative import declared_attr, declarative_base
Expand Down Expand Up @@ -71,8 +71,6 @@ class Statement(Base, StatementMixin):
backref='statements'
)

extra_data = Column(PickleType)

in_response_to = Column(
String(constants.STATEMENT_TEXT_MAX_LENGTH),
nullable=True
Expand All @@ -93,6 +91,5 @@ def get_statement(self):
conversation=self.conversation,
created_at=self.created_at,
tags=self.get_tags(),
extra_data=self.extra_data,
in_response_to=self.in_response_to
)
13 changes: 9 additions & 4 deletions chatterbot/input/hipchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,14 @@ def process_input(self, statement, conversation):
response_statement = conversation[-1] if conversation else None

if response_statement:
last_message_id = response_statement.extra_data.get(
'hipchat_message_id', None
)
tags = response_statement.get_tags()
last_message_id = None

for tag in tags:
if tag.startswith('hipchat_message_id:'):
last_message_id = tag.split('hipchat_message_id:')[-1]
break

if last_message_id:
self.recent_message_ids.add(last_message_id)

Expand All @@ -109,6 +114,6 @@ def process_input(self, statement, conversation):
text = data['message']

statement = Statement(text)
statement.add_extra_data('hipchat_message_id', data['id'])
statement.add_tags('hipchat_message_id' + data['id'])

return statement
2 changes: 1 addition & 1 deletion chatterbot/output/hipchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def process_response(self, statement, session_id=None):

# Update the output statement with the message id
self.chatbot.storage.update(
statement.add_extra_data('hipchat_message_id', data['id'])
statement.add_tags('hipchat_message_id:' + data['id'])
)

return statement
3 changes: 1 addition & 2 deletions chatterbot/storage/django_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ def update(self, statement):
text=statement.text,
conversation=statement.conversation,
in_response_to=statement.in_response_to,
created_at=statement.created_at,
extra_data=getattr(statement, 'extra_data', '')
created_at=statement.created_at
)

for tag in statement.tags.all():
Expand Down
5 changes: 0 additions & 5 deletions chatterbot/storage/sql_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,6 @@ def update(self, statement):

record.created_at = statement.created_at

if statement.extra_data is None:
statement.extra_data = {}

record.extra_data = dict(statement.extra_data)

for _tag in statement.tags:
tag = session.query(Tag).filter_by(name=_tag).first()

Expand Down
6 changes: 3 additions & 3 deletions chatterbot/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,11 +409,11 @@ def train(self):
)
print(text, len(row))

statement.add_extra_data('datetime', row[0])
statement.add_extra_data('speaker', row[1])
statement.add_tags('datetime:' + row[0])
statement.add_tags('speaker:' + row[1])

if row[2].strip():
statement.add_extra_data('addressing_speaker', row[2])
statement.add_tags('addressing_speaker:', row[2])

previous_statement_text = statement.text

Expand Down
47 changes: 19 additions & 28 deletions examples/django_app/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@ def setUp(self):
super(ApiTestCase, self).setUp()
self.api_url = reverse('chatterbot')

def _get_json(self, response):
from django.utils.encoding import force_text
return json.loads(force_text(response.content))

def test_invalid_text(self):
response = self.client.post(
self.api_url,
Expand All @@ -24,11 +20,9 @@ def test_invalid_text(self):
format='json'
)

content = json.loads(response.content.decode('utf-8'))

self.assertEqual(response.status_code, 400)
self.assertIn('text', content)
self.assertEqual(['The attribute "text" is required.'], content['text'])
self.assertIn('text', response.json())
self.assertEqual(['The attribute "text" is required.'], response.json()['text'])

def test_post(self):
"""
Expand All @@ -43,12 +37,10 @@ def test_post(self):
format='json'
)

content = json.loads(response.content.decode('utf-8'))

self.assertEqual(response.status_code, 200)
self.assertIn('text', content)
self.assertGreater(len(content['text']), 1)
self.assertIn('in_response_to', content)
self.assertIn('text', response.json())
self.assertGreater(len(response.json()['text']), 1)
self.assertIn('in_response_to', response.json())

def test_post_unicode(self):
"""
Expand All @@ -63,16 +55,14 @@ def test_post_unicode(self):
format='json'
)

content = json.loads(response.content.decode('utf-8'))

self.assertEqual(response.status_code, 200)
self.assertIn('text', content)
self.assertGreater(len(content['text']), 1)
self.assertIn('in_response_to', content)
self.assertIn('text', response.json())
self.assertGreater(len(response.json()['text']), 1)
self.assertIn('in_response_to', response.json())

def test_escaped_unicode_post(self):
"""
Test that unicode reponse
Test that unicode reponce
"""
response = self.client.post(
self.api_url,
Expand All @@ -84,15 +74,15 @@ def test_escaped_unicode_post(self):
)

self.assertEqual(response.status_code, 200)
self.assertIn('text', str(response.content))
self.assertIn('in_response_to', str(response.content))
self.assertIn('text', response.json())
self.assertIn('in_response_to', response.json())

def test_post_extra_data(self):
def test_post_tags(self):
post_data = {
'text': 'Good morning.',
'extra_data': {
'user': 'jen@example.com'
}
'tags': [
'user:jen@example.com'
]
}
response = self.client.post(
self.api_url,
Expand All @@ -102,9 +92,10 @@ def test_post_extra_data(self):
)

self.assertEqual(response.status_code, 200)
self.assertIn('text', str(response.content))
self.assertIn('extra_data', str(response.content))
self.assertIn('in_response_to', str(response.content))
self.assertIn('text', response.json())
self.assertIn('in_response_to', response.json())
self.assertIn('tags', response.json())
self.assertIn('user:jen@example.com', response.json()['tags'])

def test_get(self):
response = self.client.get(self.api_url)
Expand Down
26 changes: 11 additions & 15 deletions examples/django_app/tests/test_example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
from django.test import TestCase
from django.urls import reverse
from django.utils.encoding import force_text


class ViewTestCase(TestCase):
Expand Down Expand Up @@ -43,15 +42,15 @@ def test_post(self):
)

self.assertEqual(response.status_code, 200)
self.assertIn('text', str(response.content))
self.assertIn('in_response_to', str(response.content))
self.assertIn('text', response.json())
self.assertIn('in_response_to', response.json())

def test_post_extra_data(self):
def test_post_tags(self):
post_data = {
'text': 'Good morning.',
'extra_data': {
'user': 'jen@example.com'
}
'tags': [
'user:jen@example.com'
]
}
response = self.client.post(
self.api_url,
Expand All @@ -61,9 +60,10 @@ def test_post_extra_data(self):
)

self.assertEqual(response.status_code, 200)
self.assertIn('text', str(response.content))
self.assertIn('extra_data', str(response.content))
self.assertIn('in_response_to', str(response.content))
self.assertIn('text', response.json())
self.assertIn('in_response_to', response.json())
self.assertIn('tags', response.json())
self.assertIn('user:jen@example.com', response.json()['tags'])


class ApiIntegrationTestCase(TestCase):
Expand All @@ -76,11 +76,7 @@ def setUp(self):
super(ApiIntegrationTestCase, self).setUp()
self.api_url = reverse('chatterbot')

def _get_json(self, response):
return json.loads(force_text(response.content))

def test_get(self):
response = self.client.get(self.api_url)
data = self._get_json(response)

self.assertIn('name', data)
self.assertIn('name', response.json())
9 changes: 2 additions & 7 deletions tests/logic_adapter_tests/test_data_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class DummyMutatorLogicAdapter(LogicAdapter):
"""

def process(self, statement):
statement.add_extra_data('pos_tags', 'NN')
statement.add_tags('pos_tags:NN')

self.chatbot.storage.update(statement)
statement.confidence = 1
Expand Down Expand Up @@ -45,9 +45,4 @@ def test_additional_attributes_saved(self):
)

self.assertEqual(len(results), 1)

data = results[0].serialize()

self.assertIn('extra_data', data)
self.assertIn('pos_tags', data['extra_data'])
self.assertEqual('NN', data['extra_data']['pos_tags'])
self.assertIn('pos_tags:NN', results[0].get_tags())
Loading