Skip to content

Commit

Permalink
feat/session_per_bus_connection (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
JarbasAl authored Sep 29, 2023
1 parent b269b20 commit 9b0f422
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 69 deletions.
26 changes: 23 additions & 3 deletions ovos_bus_client/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class MessageBusClient(_MessageBusClientBase):
_config_cache = None

def __init__(self, host=None, port=None, route=None, ssl=None,
emitter=None, cache=False):
emitter=None, cache=False, session=None):
config_overrides = dict(host=host, port=port, route=route, ssl=ssl)
if cache and self._config_cache:
config = self._config_cache
Expand All @@ -59,6 +59,14 @@ def __init__(self, host=None, port=None, route=None, ssl=None,
self.connected_event = Event()
self.started_running = False
self.wrapped_funcs = {}
if session:
SessionManager.update(session)
else:
session = SessionManager.default_session

self.session_id = session.session_id
self.on("ovos.session.update_default",
self.on_default_session_update)

@staticmethod
def build_url(host: str, port: int, route: str, ssl: bool) -> str:
Expand Down Expand Up @@ -88,6 +96,7 @@ def on_open(self, *args):
self.emitter.emit("open")
# Restore reconnect timer to 5 seconds on sucessful connect
self.retry = 5
self.emit(Message("ovos.session.sync")) # request default session update

def on_close(self, *args):
"""
Expand Down Expand Up @@ -140,10 +149,19 @@ def on_message(self, *args):
else:
message = args[1]
parsed_message = Message.deserialize(message)
SessionManager.update(Session.from_message(parsed_message))
sess = Session.from_message(parsed_message)
if sess.session_id != "default":
# 'default' can only be updated by core
SessionManager.update(sess)
self.emitter.emit('message', message)
self.emitter.emit(parsed_message.msg_type, parsed_message)

def on_default_session_update(self, message):
new_session = message.data["session_data"]
sess = Session.deserialize(new_session)
SessionManager.update(sess, make_default=True)
LOG.debug("synced default_session")

def emit(self, message: Message):
"""
Send a message onto the message bus.
Expand All @@ -155,9 +173,11 @@ def emit(self, message: Message):
message (Message): Message to send
"""
if "session" not in message.context:
sess = SessionManager.get(message)
sess = SessionManager.sessions.get(self.session_id) or \
SessionManager.default_session
message.context["session"] = sess.serialize()
sess.update_history(message)
sess.touch()

if not self.connected_event.wait(10):
if not self.started_running:
Expand Down
91 changes: 38 additions & 53 deletions ovos_bus_client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
from typing import Optional, List, Tuple, Union, Iterable
from uuid import uuid4

from ovos_bus_client.message import dig_for_message, Message
from ovos_config.config import Configuration
from ovos_config.locale import get_default_lang
from ovos_utils.log import LOG, log_deprecation

from ovos_bus_client.message import dig_for_message, Message


class UtteranceState(str, enum.Enum):
INTENT = "intent" # includes converse
Expand Down Expand Up @@ -484,47 +483,46 @@ def from_message(message: Message = None):
@return: Session object
"""
message = message or dig_for_message()
if message:
if message and "session" in message.context:
lang = message.context.get("lang") or \
message.data.get("lang")
sid = None
if "session_id" in message.context:
sid = message.context["session_id"]
if "session" in message.context:
sess = message.context["session"]
if sid and "session_id" not in sess:
sess["session_id"] = sid
if "lang" not in sess:
sess["lang"] = lang
sess = Session.deserialize(sess)
elif sid:
sess = SessionManager.sessions.get(sid) or \
Session(sid)
if lang:
sess.lang = lang
else:
sess = SessionManager.default_session
if not sess:
LOG.debug(f"Creating default session on reference")
sess = SessionManager.reset_default_session()
if sess and lang and sess.lang != lang:
sess.lang = lang
LOG.info(f"Updated default session lang to: {lang}")
sess = message.context["session"]
if "lang" not in sess:
sess["lang"] = lang
sess = Session.deserialize(sess)
else:
# new session
LOG.warning(f"No message found, using default session")
sess = SessionManager.default_session
if sess and sess.expired():
LOG.debug(f"Resolved session expired {sess.session_id}")
sess.touch()
LOG.debug(f"unexpiring session {sess.session_id}")
return sess


class SessionManager:
""" Keeps track of the current active session. """
default_session: Session = None
default_session: Session = Session("default")
__lock = Lock()
sessions = {}
sessions = {"default": default_session}
bus = None

@classmethod
def sync(cls, message=None):
if cls.bus:
message = message or Message("ovos.session.sync")
cls.bus.emit(message.reply("ovos.session.update_default",
{"session_data": cls.default_session.serialize()}))

@classmethod
def connect_to_bus(cls, bus):
cls.bus = bus
cls.bus.on("ovos.session.sync",
cls.handle_default_session_request)
cls.sync()

@classmethod
def handle_default_session_request(cls, message=None):
cls.sync(message)

@staticmethod
def prune_sessions():
Expand All @@ -545,17 +543,10 @@ def reset_default_session() -> Session:
Define and return a new default_session
"""
with SessionManager.__lock:
sess = Session()
LOG.info(f"New Default Session Start: {sess.session_id}")
if not SessionManager.default_session:
SessionManager.default_session = sess
if SessionManager.default_session.session_id in \
SessionManager.sessions:
LOG.debug(f"Removing expired default session from sessions")
SessionManager.sessions.pop(
SessionManager.default_session.session_id)
SessionManager.default_session = sess
SessionManager.sessions[sess.session_id] = sess
sess = Session("default")
LOG.info(f"Default Session reset")
SessionManager.default_session = SessionManager.sessions["default"] = sess
SessionManager.sync()
return SessionManager.default_session

@staticmethod
Expand All @@ -568,9 +559,13 @@ def update(sess: Session, make_default: bool = False):
if not sess:
raise ValueError(f"Expected Session and got None")
sess.touch()
SessionManager.sessions[sess.session_id] = sess
if make_default:
sess.session_id = "default"
LOG.debug(f"replacing default session with: {sess.serialize()}")
SessionManager.default_session = sess
else:
LOG.debug(f"session updated: {sess.session_id}")
SessionManager.sessions[sess.session_id] = sess

@staticmethod
def get(message: Optional[Message] = None) -> Session:
Expand All @@ -590,20 +585,10 @@ def get(message: Optional[Message] = None) -> Session:
SessionManager.sessions[msg_sess.session_id] = msg_sess
return msg_sess
else:
LOG.debug(f"No session from message.")
LOG.debug(f"No session from message, use default session")
else:
LOG.debug(f"No message, use default session")

# Default session, check if it needs to be (re)-created
if not sess or sess.expired():
if sess is not None and sess.session_id in SessionManager.sessions:
LOG.debug(f"Removing expired default: {sess.session_id}")
SessionManager.sessions.pop(sess.session_id)
sess = SessionManager.reset_default_session()
else:
# Existing default, make sure lang is in sync with Configuration
sess.lang = Configuration().get('lang') or sess.lang

return sess

@staticmethod
Expand Down
16 changes: 3 additions & 13 deletions test/unittests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,19 +262,9 @@ def test_update(self):
# TODO
pass

@patch("ovos_bus_client.session.Configuration")
def test_get(self, config):
config.return_value = {'lang': 'en-us'}
self.assertEqual(config(), {'lang': 'en-us'})
from ovos_bus_client.session import Session
session = self.SessionManager.get()
self.assertIsInstance(session, Session)
self.assertEqual(session.lang, 'en-us')
config.return_value = {'lang': 'es-es'}

session = self.SessionManager.get()
self.assertIsInstance(session, Session)
self.assertEqual(session.lang, 'es-es')
def test_get(self):
# TODO - rewrite test, .get has no side effects now, lang update happens in ovos-core
pass

def test_touch(self):
# TODO
Expand Down

0 comments on commit 9b0f422

Please sign in to comment.