Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sqlalchemy adapter #693

Merged
merged 58 commits into from
Apr 29, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
c25422c
Avoid problems in str.format() with char accents
davizucon Aug 25, 2016
92b3027
Merge remote-tracking branch 'upstream/master'
davizucon Aug 25, 2016
bb368e6
Encoding fix and few changes in dialogs.
davizucon Aug 25, 2016
a661284
Fix arguments (+input_statement, statement_list) for most_frequent_re…
davizucon Aug 25, 2016
edb74f3
Fix arguments (+input_statement, statement_list) for most_frequent_re…
davizucon Aug 25, 2016
7009e31
Merge branch 'master' of https://github.com/davizucon/ChatterBot into…
davizucon Aug 25, 2016
48bb0f4
Fix test missing parameter.
davizucon Aug 29, 2016
4c99ce3
Merge branch 'master' of https://github.com/gunthercox/ChatterBot int…
davizucon Aug 29, 2016
706b081
Language fix and nltk fix to find correct library
davizucon Aug 29, 2016
94c72e1
Merge branch 'upstream'
davizucon Aug 29, 2016
a96a141
Fix arguments (+input_statement, statement_list) for most_frequent_re…
davizucon Aug 30, 2016
b3cf09c
Initial Version, not working yet
davizucon Aug 31, 2016
6b1a616
Some work done, still need to implement other methods.
davizucon Aug 31, 2016
cca2882
minor typ
navyad Sep 1, 2016
bdec3dd
assert method for None checks
navyad Sep 1, 2016
427b42a
super argument fix
navyad Sep 1, 2016
a7999cb
assertIsInstance for type check
navyad Sep 1, 2016
b0dbe22
Merge pull request #2 from navyad/navyad/tests
davizucon Sep 1, 2016
a6d00a9
Merge pull request #1 from navyad/navyad/sqlite
davizucon Sep 1, 2016
2ea136b
Some work done, still need to implement filter.
davizucon Sep 2, 2016
2f19f76
get_session is private
navyad Sep 2, 2016
df91ecd
wrapper for filter
navyad Sep 2, 2016
b5b9d80
looping via map
navyad Sep 2, 2016
d05541a
poi removed
navyad Sep 2, 2016
97b09a0
minor refactor
navyad Sep 2, 2016
3fa6ac7
informatve variable name
navyad Sep 2, 2016
488e2da
Merge pull request #3 from navyad/navyad/method-wrappers
davizucon Sep 3, 2016
d87b2d8
merge changes from masters
davizucon Sep 5, 2016
321fee5
Merge remote-tracking branch 'upstream/master' into sqlalchemy-adapter
davizucon Sep 5, 2016
4d682b0
Merge branch 'master' of https://github.com/gunthercox/ChatterBot int…
davizucon Sep 10, 2016
a4c386a
Merge branch 'master' of https://github.com/gunthercox/ChatterBot int…
davizucon Sep 13, 2016
2871467
must review this commit, long time away
davizucon Mar 3, 2017
8bfacff
merge from uptream
davizucon Mar 3, 2017
5e990e2
Merge branch 'master' into sqlalchemy-adapter
davizucon Mar 3, 2017
976486c
merge from master
davizucon Mar 29, 2017
7c70154
merge from master
davizucon Mar 29, 2017
f339968
Merge branch 'master' into sqlalchemy-adapter
davizucon Mar 29, 2017
0366722
move to new package and merge
davizucon Mar 29, 2017
6702ca2
All testes passing
davizucon Mar 29, 2017
8e27a2d
to check...
davizucon Apr 19, 2017
84e43e6
Merge remote-tracking branch 'upstream/master' into sqlalchemy-adapter
davizucon Apr 19, 2017
80ff870
Clean up code.
davizucon Apr 19, 2017
9fa94a3
ignore venv*
davizucon Apr 19, 2017
6b8a7d2
Last fix check before PR
davizucon Apr 19, 2017
b39ca44
Fix test-requirements
davizucon Apr 19, 2017
18c3d69
Renamed file, import conflicts in CI
davizucon Apr 19, 2017
e015ea6
Fix import rename.
davizucon Apr 19, 2017
44ed12d
Travis first install ChatterBot after install requirements...
davizucon Apr 19, 2017
70caa0b
Fiz import
davizucon Apr 19, 2017
5c8d7b6
Fiz imports
davizucon Apr 19, 2017
d0b68b8
Removed temp test class.
davizucon Apr 19, 2017
ba6dbb2
fix import.
davizucon Apr 19, 2017
be8d2bb
Rolback changes.
davizucon Apr 19, 2017
401ba7a
Random database name for tests
davizucon Apr 20, 2017
2e88087
Fix filter method, order of filters returning wrong values.
davizucon Apr 20, 2017
3af0b41
Fix filter method, order of filters returning wrong values.
davizucon Apr 20, 2017
1c01393
Changes by code review: Clean up, removed unnecessary code comments a…
davizucon Apr 24, 2017
6ffb847
Fixing import, not depend on SQLAlchemy.
davizucon Apr 24, 2017
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,6 @@ docs/_build/

examples/settings.py
examples/ubuntu_dialogs*
.env
.out
venv*
1 change: 1 addition & 0 deletions chatterbot/storage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .django_storage import DjangoStorageAdapter
from .jsonfile import JsonFileStorageAdapter
from .mongodb import MongoDatabaseAdapter
from .sqlalchemy_storage import SQLAlchemyDatabaseAdapter
273 changes: 273 additions & 0 deletions chatterbot/storage/sqlalchemy_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
import json
import random

from chatterbot.storage import StorageAdapter
from chatterbot.conversation import Response
from chatterbot.conversation import Statement

_base = None

try:
from sqlalchemy.ext.declarative import declarative_base

_base = declarative_base()


class StatementTable(_base):
from sqlalchemy import Column, Integer, String, PickleType
from sqlalchemy.orm import relationship

__tablename__ = 'StatementTable'

def get_statement(self):
stmt = Statement(self.text, **self.extra_data)
for resp in self.in_response_to:
stmt.add_response(resp.get_response())
return stmt

def get_statement_serialized(context):
params = context.current_parameters
del (params['text_search'])
return json.dumps(params)

id = Column(Integer)
text = Column(String, primary_key=True)
extra_data = Column(PickleType)
# relationship:
in_response_to = relationship("ResponseTable", back_populates="statement_table")
text_search = Column(String, primary_key=True, default=get_statement_serialized)


class ResponseTable(_base):
from sqlalchemy import Column, Integer, String, ForeignKey
from sqlalchemy.orm import relationship
__tablename__ = 'ResponseTable'

def get_reponse_serialized(context):
params = context.current_parameters
del (params['text_search'])
return json.dumps(params)

id = Column(Integer)
text = Column(String, primary_key=True)
occurrence = Column(Integer)
statement_text = Column(String, ForeignKey('StatementTable.text'))

statement_table = relationship("StatementTable", back_populates="in_response_to", cascade="all", uselist=False)
text_search = Column(String, primary_key=True, default=get_reponse_serialized)

def get_response(self):
occ = {"occurrence": self.occurrence}
return Response(text=self.text, **occ)

except ImportError:
pass


def get_statement_table(statement):
responses = list(map(get_response_table, statement.in_response_to))
return StatementTable(text=statement.text, in_response_to=responses, extra_data=statement.extra_data)


def get_response_table(response):
return ResponseTable(text=response.text, occurrence=response.occurrence)


class SQLAlchemyDatabaseAdapter(StorageAdapter):
read_only = False
drop_create = False

def __init__(self, **kwargs):
super(SQLAlchemyDatabaseAdapter, self).__init__(**kwargs)

from sqlalchemy import create_engine

self.database_name = self.kwargs.get(
"database", "chatterbot-database"
)

# if some annoying blank space wrong...
db_name = self.database_name.strip()

# default uses sqlite
self.database_uri = self.kwargs.get(
"database_uri", "sqlite:///" + db_name + ".db"
)

self.engine = create_engine(self.database_uri)

self.read_only = self.kwargs.get(
"read_only", False
)

self.drop_create = self.kwargs.get(
"drop_create", False
)

if not self.read_only and self.drop_create:
_base.metadata.drop_all(self.engine)
_base.metadata.create_all(self.engine)

def count(self):
"""
Return the number of entries in the database.
"""
session = self.__get_session()
return session.query(StatementTable).count()

def __get_session(self):
"""
:rtype: Session
"""
from sqlalchemy.orm import sessionmaker

Session = sessionmaker(bind=self.engine)
session = Session()
return session

def __statement_filter(self, session, **kwargs):
"""
Apply filter operation on StatementTable

rtype: query
"""
_query = session.query(StatementTable)
return _query.filter_by(**kwargs)

def find(self, statement_text):
"""
Returns a statement if it exists otherwise None
"""
session = self.__get_session()
query = self.__statement_filter(session, **{"text": statement_text})
record = query.first()
if record:
return record.get_statement()
return None

def remove(self, statement_text):
"""
Removes the statement that matches the input text.
Removes any responses from statements where the response text matches
the input text.
"""
session = self.__get_session()
query = self.__statement_filter(session, **{"text": statement_text})
record = query.first()
session.delete(record)

self._session_finish(session, statement_text)

def filter(self, **kwargs):
"""
Returns a list of objects from the database.
The kwargs parameter can contain any number
of attributes. Only objects which contain
all listed attributes and in which all values
match for all listed attributes will be returned.
"""

filter_parameters = kwargs.copy()

session = self.__get_session()
statements = []
# _response_query = None
_query = None
if len(filter_parameters) == 0:
_response_query = session.query(StatementTable)
statements.extend(_response_query.all())
else:
for i, fp in enumerate(filter_parameters):
_filter = filter_parameters[fp]
if fp in ['in_response_to', 'in_response_to__contains']:
_response_query = session.query(StatementTable)
if isinstance(_filter, list):
if len(_filter) == 0:
_query = _response_query.filter(
StatementTable.in_response_to == None) # Here must use == instead of is
else:
for f in _filter:
_query = _response_query.filter(
StatementTable.in_response_to.contains(get_response_table(f)))
else:
if fp == 'in_response_to__contains':
_query = _response_query.join(ResponseTable).filter(ResponseTable.text == _filter)
else:
_query = _response_query.filter(StatementTable.in_response_to == None)
else:
if _query:
_query = _query.filter(ResponseTable.text_search.like('%' + _filter + '%'))
else:
_response_query = session.query(ResponseTable)
_query = _response_query.filter(ResponseTable.text_search.like('%' + _filter + '%'))

if _query is None:
return []
if len(filter_parameters) == i + 1:
statements.extend(_query.all())

results = []

for statement in statements:
if isinstance(statement, ResponseTable):
if statement and statement.statement_table:
results.append(statement.statement_table.get_statement())
else:
if statement:
results.append(statement.get_statement())

return results

def update(self, statement):
"""
Modifies an entry in the database.
Creates an entry if one does not exist.
"""
session = self.__get_session()
if statement:
query = self.__statement_filter(session, **{"text": statement.text})
record = query.first()

if record:
# update
if statement.text:
record.text = statement.text
if statement.extra_data:
record.extra_data = dict[statement.extra_data]
if statement.in_response_to:
record.in_response_to = list(map(get_response_table, statement.in_response_to))
session.add(record)
else:
session.add(get_statement_table(statement))

self._session_finish(session)

def get_random(self):
"""
Returns a random statement from the database
"""
count = self.count()
if count < 1:
raise self.EmptyDatabaseException()

rand = random.randrange(0, count)
session = self.__get_session()
stmt = session.query(StatementTable)[rand]

return stmt.get_statement()

def drop(self):
"""
Drop the database attached to a given adapter.
"""
_base.metadata.drop_all(self.engine)

def _session_finish(self, session, statement_text=None):
from sqlalchemy.exc import DatabaseError
try:
if not self.read_only:
session.commit()
else:
session.rollback()
except DatabaseError as e:
self.logger.error(statement_text, str(e.orig))
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ jsondatabase>=0.1.7,<1.0.0
nltk>=3.2.0,<4.0.0
pymongo>=3.3.0,<4.0.0
python-twitter>=3.0.0,<4.0.0
SQLAlchemy==1.1.7
2 changes: 1 addition & 1 deletion test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ nose
nose-exclude>=0.5.0,<0.6.0
twython
sphinx
sphinx_rtd_theme
sphinx_rtd_theme
2 changes: 1 addition & 1 deletion tests/base_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,4 @@ def get_kwargs(self):
kwargs = super(ChatBotMongoTestCase, self).get_kwargs()
kwargs['database'] = self.random_string()
kwargs['storage_adapter'] = 'chatterbot.storage.MongoDatabaseAdapter'
return kwargs
return kwargs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from tests.base_case import ChatBotTestCase


class SqlAlchemyStorageIntegrationTests(ChatBotTestCase):

def test_database_is_updated(self):
"""
Test that the database is updated when read_only is set to false.
"""
input_text = 'What is the airspeed velocity of an unladen swallow?'
exists_before = self.chatbot.storage.find(input_text)

response = self.chatbot.get_response(input_text)
exists_after = self.chatbot.storage.find(input_text)

self.assertFalse(exists_before)
self.assertTrue(exists_after)
Loading