Skip to content

Commit

Permalink
Support extra_data passed to django api view.
Browse files Browse the repository at this point in the history
  • Loading branch information
gunthercox committed Oct 30, 2016
1 parent 9858058 commit 5bde0b7
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 34 deletions.
16 changes: 12 additions & 4 deletions chatterbot/adapters/storage/django_storage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from chatterbot.adapters.storage import StorageAdapter
from chatterbot.conversation import Statement, Response
import json


class DjangoStorageAdapter(StorageAdapter):
Expand All @@ -17,7 +18,10 @@ def model_to_object(self, statement_model):
"""
Convert a Django model object into a ChatterBot Statement object.
"""
statement = Statement(statement_model.text)
statement = Statement(
statement_model.text,
extra_data=json.loads(statement_model.extra_data, encoding='utf8')
)

for response_object in statement_model.in_response_to.all():
statement.add_response(Response(
Expand Down Expand Up @@ -78,7 +82,8 @@ def update(self, statement, **kwargs):
# Do not alter the database unless writing is enabled
if not self.read_only:
django_statement, created = StatementModel.objects.get_or_create(
text=statement.text
text=statement.text,
extra_data=json.dumps(statement.extra_data)
)

for response in statement.in_response_to:
Expand Down Expand Up @@ -124,7 +129,10 @@ def remove(self, statement_text):

def drop(self):
"""
Remove the database.
Remove all data from the database.
"""
pass
from chatterbot.ext.django_chatterbot.models import Statement as StatementModel
from chatterbot.ext.django_chatterbot.models import Response as ResponseModel

StatementModel.objects.all().delete()
ResponseModel.objects.all().delete()
18 changes: 7 additions & 11 deletions chatterbot/conversation/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,16 @@ class Statement(object):

def __init__(self, text, **kwargs):
self.text = text
self.in_response_to = kwargs.get("in_response_to", [])

self.extra_data = {}

if "in_response_to" in kwargs:
del(kwargs["in_response_to"])
self.in_response_to = kwargs.pop('in_response_to', [])
self.extra_data = kwargs.pop('extra_data', {})

self.extra_data.update(kwargs)

def __str__(self):
return self.text

def __repr__(self):
return "<Statement text:%s>" % (self.text)
return '<Statement text:%s>' % (self.text)

def __eq__(self, other):
if not other:
Expand Down Expand Up @@ -89,12 +85,12 @@ def serialize(self):
"""
data = {}

data["text"] = self.text
data["in_response_to"] = []
data.update(self.extra_data)
data['text'] = self.text
data['in_response_to'] = []
data['extra_data'] = self.extra_data

for response in self.in_response_to:
data["in_response_to"].append(response.serialize())
data['in_response_to'].append(response.serialize())

return data

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.10.2 on 2016-10-30 12:13
from __future__ import unicode_literals

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('django_chatterbot', '0001_initial'),
]

operations = [
migrations.AddField(
model_name='statement',
name='extra_data',
field=models.CharField(default='{}', max_length=500),
preserve_default=False,
),
]
9 changes: 7 additions & 2 deletions chatterbot/ext/django_chatterbot/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@


class Statement(models.Model):
"""A short (<255) chat message, tweet, forum post, etc"""
"""
A short (<255) chat message, tweet, forum post, etc.
"""

text = models.CharField(
unique=True,
Expand All @@ -11,6 +13,8 @@ class Statement(models.Model):
max_length=255
)

extra_data = models.CharField(max_length=500)

def __str__(self):
if len(self.text.strip()) > 60:
return '{}...'.format(self.text[:57])
Expand All @@ -20,7 +24,8 @@ def __str__(self):


class Response(models.Model):
"""Connection between a response and the statement that triggered it
"""
Connection between a response and the statement that triggered it.
Comparble to a ManyToMany "through" table, but without the M2M indexing/relations.
Expand Down
16 changes: 12 additions & 4 deletions chatterbot/ext/django_chatterbot/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ class ChatterBotViewMixin(object):

chatterbot = ChatBot(**settings.CHATTERBOT)

def validate(self, data):
from django.core.exceptions import ValidationError

if 'text' not in data:
raise ValidationError('The attribute "text" is required.')


class ChatterBotView(ChatterBotViewMixin, View):

Expand All @@ -24,13 +30,15 @@ def _serialize_recent_statements(self):
return recent_statements

def post(self, request, *args, **kwargs):

if request.is_ajax():
data = json.loads(request.read().decode('utf-8'))
input_statement = data.get('text')
input_data = json.loads(request.read().decode('utf-8'))
else:
input_statement = request.POST.get('text')
input_data = json.loads(request.body.decode('utf-8'))

self.validate(input_data)

response_data = self.chatterbot.get_response(input_statement)
response_data = self.chatterbot.get_response(input_data)

return JsonResponse(response_data, status=200)

Expand Down
27 changes: 26 additions & 1 deletion examples/django_app/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from django.test import TestCase
from django.core.urlresolvers import reverse
import json


class ApiTestCase(TestCase):
Expand All @@ -15,12 +16,36 @@ def test_post(self):
data = {
'text': 'How are you?'
}
response = self.client.post(self.api_url, data, format='json')
response = self.client.post(
self.api_url,
data=json.dumps(data),
content_type='application/json',
format='json'
)

self.assertEqual(response.status_code, 200)
self.assertIn('text', str(response.content))
self.assertIn('in_response_to', str(response.content))

def test_post_extra_data(self):
post_data = {
'text': 'Good morning.',
'extra_data': {
'user': 'jen@example.com'
}
}
response = self.client.post(
self.api_url,
data=json.dumps(post_data),
content_type='application/json',
format='json'
)

self.assertEqual(response.status_code, 200)
self.assertIn('text', str(response.content))
self.assertIn('extra_data', str(response.content))
self.assertIn('in_response_to', str(response.content))

def test_get(self):
response = self.client.get(self.api_url)

Expand Down
3 changes: 2 additions & 1 deletion examples/django_app/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def test_get_recent_statements_empty(self):
def test_get_recent_statements(self):
response = self.client.post(
self.api_url,
{'text': 'How are you?'},
data=json.dumps({'text': 'How are you?'}),
content_type='application/json',
format='json'
)

Expand Down
24 changes: 24 additions & 0 deletions examples/django_app/tests/test_views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from django.test import TestCase
from django.core.exceptions import ValidationError
from chatterbot.ext.django_chatterbot.views import ChatterBotView


class ViewTestCase(TestCase):

def setUp(self):
super(ViewTestCase, self).setUp()
self.view = ChatterBotView()

def test_validate_text(self):
try:
self.view.validate({
'text': 'How are you?'
})
except ValidationError:
self.fail('Test raised ValidationError unexpectedly!')

def test_validate_invalid_text(self):
with self.assertRaises(ValidationError):
self.view.validate({
'type': 'classmethod'
})
20 changes: 9 additions & 11 deletions tests/logic_adapter_tests/test_data_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class DummyMutatorLogicAdapter(LogicAdapter):
"""

def process(self, statement):
statement.add_extra_data("pos_tags", "NN")
statement.add_extra_data('pos_tags', 'NN')

self.context.storage.update(statement)

Expand All @@ -30,22 +30,20 @@ def setUp(self):
self.chatbot.set_trainer(ListTrainer)

self.chatbot.train([
"Hello",
"How are you?"
'Hello',
'How are you?'
])

def test_additional_attributes_saved(self):
"""
Test that an additional data attribute can be added to the statement
and that this attribute is saved.
"""
response = self.chatbot.get_response("Hello")
found_statement = self.chatbot.storage.find("Hello")
response = self.chatbot.get_response('Hello')
found_statement = self.chatbot.storage.find('Hello')
data = found_statement.serialize()

self.assertIsNotNone(found_statement)
self.assertIn("pos_tags", found_statement.serialize())
self.assertEqual(
"NN",
found_statement.serialize()["pos_tags"]
)

self.assertIn('extra_data', data)
self.assertIn('pos_tags', data['extra_data'])
self.assertEqual('NN', data['extra_data']['pos_tags'])

0 comments on commit 5bde0b7

Please sign in to comment.