Skip to content

Commit

Permalink
Add get model methods to each adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
gunthercox committed Sep 29, 2017
1 parent 55c206a commit 2cef1ce
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 22 deletions.
12 changes: 12 additions & 0 deletions chatterbot/storage/jsonfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@ def __init__(self, **kwargs):

self.adapter_supports_queries = False

def get_statement_model(self):
"""
Return the class for the statement model.
"""
from chatterbot.conversation.statement import Statement

# Create a storage-aware statement
statement = Statement
statement.storage = self

return statement

def _keys(self):
# The value has to be cast as a list for Python 3 compatibility
return list(self.database[0].keys())
Expand Down
26 changes: 25 additions & 1 deletion chatterbot/storage/mongodb.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from chatterbot.storage import StorageAdapter
from chatterbot.conversation import Response


class Query(object):
Expand Down Expand Up @@ -113,6 +112,30 @@ def __init__(self, **kwargs):

self.base_query = Query()

def get_statement_model(self):
"""
Return the class for the statement model.
"""
from chatterbot.conversation.statement import Statement

# Create a storage-aware statement
statement = Statement
statement.storage = self

return statement

def get_response_model(self):
"""
Return the class for the response model.
"""
from chatterbot.conversation.response import Response

# Create a storage-aware response
response = Response
response.storage = self

return response

def count(self):
return self.statements.count()

Expand Down Expand Up @@ -140,6 +163,7 @@ def deserialize_responses(self, response_list):
the list converted to Response objects.
"""
Statement = self.get_model('statement')
Response = self.get_model('response')
proxy_statement = Statement('')

for response in response_list:
Expand Down
43 changes: 34 additions & 9 deletions chatterbot/storage/sql_storage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import random
from chatterbot.storage import StorageAdapter


Expand Down Expand Up @@ -79,11 +78,32 @@ def set_sqlite_pragma(dbapi_connection, connection_record):
# ChatterBot's internal query builder is not yet supported for this adapter
self.adapter_supports_queries = False

def get_statement_model(self):
"""
Return the statement model.
"""
from chatterbot.ext.sqlalchemy_app.models import Statement
return Statement

def get_response_model(self):
"""
Return the response model.
"""
from chatterbot.ext.sqlalchemy_app.models import Response
return Response

def get_conversation_model(self):
"""
Return the conversation model.
"""
from chatterbot.ext.sqlalchemy_app.models import Conversation
return Conversation

def count(self):
"""
Return the number of entries in the database.
"""
from chatterbot.ext.sqlalchemy_app.models import Statement
Statement = self.get_model('statement')

session = self.Session()
statement_count = session.query(Statement).count()
Expand All @@ -96,7 +116,7 @@ def __statement_filter(self, session, **kwargs):
rtype: query
"""
from chatterbot.ext.sqlalchemy_app.models import Statement
Statement = self.get_model('statement')

_query = session.query(Statement)
return _query.filter_by(**kwargs)
Expand Down Expand Up @@ -138,7 +158,8 @@ def filter(self, **kwargs):
all listed attributes and in which all values
match for all listed attributes will be returned.
"""
from chatterbot.ext.sqlalchemy_app.models import Statement, Response
Statement = self.get_model('statement')
Response = self.get_model('response')

session = self.Session()

Expand Down Expand Up @@ -199,7 +220,8 @@ def update(self, statement):
Modifies an entry in the database.
Creates an entry if one does not exist.
"""
from chatterbot.ext.sqlalchemy_app.models import Statement, Response
Statement = self.get_model('statement')
Response = self.get_model('response')

if statement:
session = self.Session()
Expand Down Expand Up @@ -240,7 +262,7 @@ def create_conversation(self):
"""
Create a new conversation.
"""
from chatterbot.ext.sqlalchemy_app.models import Conversation
Conversation = self.get_model('conversation')

session = self.Session()
conversation = Conversation()
Expand All @@ -260,7 +282,8 @@ def add_to_conversation(self, conversation_id, statement, response):
"""
Add the statement and response to the conversation.
"""
from chatterbot.ext.sqlalchemy_app.models import Conversation, Statement
Statement = self.get_model('statement')
Conversation = self.get_model('conversation')

session = self.Session()
conversation = session.query(Conversation).get(conversation_id)
Expand Down Expand Up @@ -296,7 +319,7 @@ def get_latest_response(self, conversation_id):
Returns the latest response in a conversation if it exists.
Returns None if a matching conversation cannot be found.
"""
from chatterbot.ext.sqlalchemy_app.models import Statement
Statement = self.get_model('statement')

session = self.Session()
statement = None
Expand All @@ -318,7 +341,9 @@ def get_random(self):
"""
Returns a random statement from the database
"""
from chatterbot.ext.sqlalchemy_app.models import Statement
import random

Statement = self.get_model('statement')

session = self.Session()
count = self.count()
Expand Down
12 changes: 0 additions & 12 deletions chatterbot/storage/storage_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,6 @@ def get_model(self, model_name):

return get_model_method()

def get_statement_model(self):
"""
Return the class for the statement model.
"""
from chatterbot.conversation.statement import Statement

# Create a storage-aware statement
statement = Statement
statement.storage = self

return statement

def generate_base_query(self, chatterbot, session_id):
"""
Create a base query for the storage adapter.
Expand Down

0 comments on commit 2cef1ce

Please sign in to comment.