Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat/session_per_bus_connection #50

Merged
merged 15 commits into from
Sep 29, 2023
26 changes: 23 additions & 3 deletions ovos_bus_client/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class MessageBusClient(_MessageBusClientBase):
_config_cache = None

def __init__(self, host=None, port=None, route=None, ssl=None,
emitter=None, cache=False):
emitter=None, cache=False, session=None):
config_overrides = dict(host=host, port=port, route=route, ssl=ssl)
if cache and self._config_cache:
config = self._config_cache
Expand All @@ -59,6 +59,14 @@ def __init__(self, host=None, port=None, route=None, ssl=None,
self.connected_event = Event()
self.started_running = False
self.wrapped_funcs = {}
if session:
SessionManager.update(session)
else:
session = SessionManager.default_session

self.session_id = session.session_id
self.on("ovos.session.update_default",
self.on_default_session_update)

@staticmethod
def build_url(host: str, port: int, route: str, ssl: bool) -> str:
Expand Down Expand Up @@ -88,6 +96,7 @@ def on_open(self, *args):
self.emitter.emit("open")
# Restore reconnect timer to 5 seconds on sucessful connect
self.retry = 5
self.emit(Message("ovos.session.sync")) # request default session update

def on_close(self, *args):
"""
Expand Down Expand Up @@ -140,10 +149,19 @@ def on_message(self, *args):
else:
message = args[1]
parsed_message = Message.deserialize(message)
SessionManager.update(Session.from_message(parsed_message))
sess = Session.from_message(parsed_message)
if sess.session_id != "default":
# 'default' can only be updated by core
SessionManager.update(sess)
self.emitter.emit('message', message)
self.emitter.emit(parsed_message.msg_type, parsed_message)

def on_default_session_update(self, message):
new_session = message.data["session_data"]
sess = Session.deserialize(new_session)
SessionManager.update(sess, make_default=True)
LOG.debug("synced default_session")

def emit(self, message: Message):
"""
Send a message onto the message bus.
Expand All @@ -155,9 +173,11 @@ def emit(self, message: Message):
message (Message): Message to send
"""
if "session" not in message.context:
sess = SessionManager.get(message)
sess = SessionManager.sessions.get(self.session_id) or \
SessionManager.default_session
message.context["session"] = sess.serialize()
sess.update_history(message)
sess.touch()

if not self.connected_event.wait(10):
if not self.started_running:
Expand Down
91 changes: 38 additions & 53 deletions ovos_bus_client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
from typing import Optional, List, Tuple, Union, Iterable
from uuid import uuid4

from ovos_bus_client.message import dig_for_message, Message
from ovos_config.config import Configuration
from ovos_config.locale import get_default_lang
from ovos_utils.log import LOG, log_deprecation

from ovos_bus_client.message import dig_for_message, Message


class UtteranceState(str, enum.Enum):
INTENT = "intent" # includes converse
Expand Down Expand Up @@ -484,47 +483,46 @@ def from_message(message: Message = None):
@return: Session object
"""
message = message or dig_for_message()
if message:
if message and "session" in message.context:
lang = message.context.get("lang") or \
message.data.get("lang")
sid = None
if "session_id" in message.context:
sid = message.context["session_id"]
if "session" in message.context:
sess = message.context["session"]
if sid and "session_id" not in sess:
sess["session_id"] = sid
if "lang" not in sess:
sess["lang"] = lang
sess = Session.deserialize(sess)
elif sid:
sess = SessionManager.sessions.get(sid) or \
Session(sid)
if lang:
sess.lang = lang
else:
sess = SessionManager.default_session
if not sess:
LOG.debug(f"Creating default session on reference")
sess = SessionManager.reset_default_session()
if sess and lang and sess.lang != lang:
sess.lang = lang
LOG.info(f"Updated default session lang to: {lang}")
sess = message.context["session"]
if "lang" not in sess:
sess["lang"] = lang
sess = Session.deserialize(sess)
else:
# new session
LOG.warning(f"No message found, using default session")
sess = SessionManager.default_session
if sess and sess.expired():
LOG.debug(f"Resolved session expired {sess.session_id}")
sess.touch()
LOG.debug(f"unexpiring session {sess.session_id}")
return sess


class SessionManager:
""" Keeps track of the current active session. """
default_session: Session = None
default_session: Session = Session("default")
__lock = Lock()
sessions = {}
sessions = {"default": default_session}
bus = None

@classmethod
def sync(cls, message=None):
if cls.bus:
message = message or Message("ovos.session.sync")
cls.bus.emit(message.reply("ovos.session.update_default",
{"session_data": cls.default_session.serialize()}))

@classmethod
def connect_to_bus(cls, bus):
cls.bus = bus
cls.bus.on("ovos.session.sync",
cls.handle_default_session_request)
cls.sync()

@classmethod
def handle_default_session_request(cls, message=None):
cls.sync(message)

@staticmethod
def prune_sessions():
Expand All @@ -545,17 +543,10 @@ def reset_default_session() -> Session:
Define and return a new default_session
"""
with SessionManager.__lock:
sess = Session()
LOG.info(f"New Default Session Start: {sess.session_id}")
if not SessionManager.default_session:
SessionManager.default_session = sess
if SessionManager.default_session.session_id in \
SessionManager.sessions:
LOG.debug(f"Removing expired default session from sessions")
SessionManager.sessions.pop(
SessionManager.default_session.session_id)
SessionManager.default_session = sess
SessionManager.sessions[sess.session_id] = sess
sess = Session("default")
LOG.info(f"Default Session reset")
SessionManager.default_session = SessionManager.sessions["default"] = sess
SessionManager.sync()
return SessionManager.default_session

@staticmethod
Expand All @@ -568,9 +559,13 @@ def update(sess: Session, make_default: bool = False):
if not sess:
raise ValueError(f"Expected Session and got None")
sess.touch()
SessionManager.sessions[sess.session_id] = sess
if make_default:
sess.session_id = "default"
LOG.debug(f"replacing default session with: {sess.serialize()}")
SessionManager.default_session = sess
else:
LOG.debug(f"session updated: {sess.session_id}")
SessionManager.sessions[sess.session_id] = sess

@staticmethod
def get(message: Optional[Message] = None) -> Session:
Expand All @@ -590,20 +585,10 @@ def get(message: Optional[Message] = None) -> Session:
SessionManager.sessions[msg_sess.session_id] = msg_sess
return msg_sess
else:
LOG.debug(f"No session from message.")
LOG.debug(f"No session from message, use default session")
else:
LOG.debug(f"No message, use default session")

# Default session, check if it needs to be (re)-created
if not sess or sess.expired():
if sess is not None and sess.session_id in SessionManager.sessions:
LOG.debug(f"Removing expired default: {sess.session_id}")
SessionManager.sessions.pop(sess.session_id)
sess = SessionManager.reset_default_session()
else:
# Existing default, make sure lang is in sync with Configuration
sess.lang = Configuration().get('lang') or sess.lang

return sess

@staticmethod
Expand Down
16 changes: 3 additions & 13 deletions test/unittests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,19 +262,9 @@ def test_update(self):
# TODO
pass

@patch("ovos_bus_client.session.Configuration")
def test_get(self, config):
config.return_value = {'lang': 'en-us'}
self.assertEqual(config(), {'lang': 'en-us'})
from ovos_bus_client.session import Session
session = self.SessionManager.get()
self.assertIsInstance(session, Session)
self.assertEqual(session.lang, 'en-us')
config.return_value = {'lang': 'es-es'}

session = self.SessionManager.get()
self.assertIsInstance(session, Session)
self.assertEqual(session.lang, 'es-es')
def test_get(self):
# TODO - rewrite test, .get has no side effects now, lang update happens in ovos-core
pass

def test_touch(self):
# TODO
Expand Down
Loading