diff --git a/sogs/events.py b/sogs/events.py new file mode 100644 index 00000000..37204707 --- /dev/null +++ b/sogs/events.py @@ -0,0 +1,160 @@ +from collections import defaultdict + +from oxenmq import AuthLevel + +from . import model +from .omq import omq +from .web import app + +from types import Iterable + +from oxenc import bt_serialize + +from .routes.subrequest import make_subrequest + +from flask import g + +# pools for event propagation +_pools = defaultdict(set) + +status_OK = 'OK' +status_ERR = 'ERROR' + +# the events we are able to subscribe to +EVENTS = ('message', 'joined', 'parted', 'banned', 'unbanned', 'deleted', 'uploaded') + + +def event_name_valid(eventname): + """ return True if this event name is something well formed """ + return eventname in EVENTS + + +def _user_from_conn(conn): + """ + make a model.User from a connection using its curve pubkey as the session id. + """ + # TODO: blinding? + return model.User(session_id='05' + conn.pubkey.hex()) + + +def _maybe_serialize(data): + """maybe bt encode data, if data is a bytes dont encode, + if data is a string turn it into bytes and dont encode, otherwise bt encode""" + if isinstance(data, bytes): + return data + if isinstance(data, str): + return data.encode() + return bt_serialize(data) + + +def _propagate_event(eventname, *args): + """ propagate an event to everyone who cares about it """ + assert event_name_valid(eventname) + global omq, _pools + sent = 0 + for conn in _pools[eventname]: + omq.send(conn, f'sogs.event.{eventname}', *(_maybe_serialize(a) for a in args)) + sent += 1 + if sent: + app.logger.info(f"sent {eventname} to {sent} subscribers") + + +_category = omq.add_category('sogs', AuthLevel.basic) + + +def api(f, *, name=None, minargs=None): + """ set up a request handler for zmq for a function with name of the endpoint """ + if name is None: + raise ValueError('api endpoint name cannot be none') + + def _handle_request(msg): + try: + if minargs and len(msg.data) < minargs: + raise ValueError(f"Not enough arguments, got {len(msg.data)} expected 2 or more") + app.logger.debug(f"zmq request: {name} for {msg.conn}") + g.user = _user_from_conn(msg.conn) + retval = f(*msg.data, conn=msg.conn) + if retval is None: + msg.reply(status_OK) + elif isinstance(retval, tuple): + msg.reply(status_OK, *retval) + else: + msg.reply(status_OK, bt_serialize(retval)) + except Exception as ex: + app.logger.error(f"{f.__name__} raised exception: {ex}") + msg.reply(status_ERR, f'{ex}') + finally: + g.user = None + + global _category + _category.add_request_command(name, _handle_request) + app.logger.info(f"register zmq api handler: sogs.{name}") + return f + + +def _collect_bytes(iterable: Iterable[bytes]): + """ collect all bytes from an iterable of bytes and put it into one big bytes instance """ + data = bytes() + for part in iterable: + data += part + return data + + +@api(name='sub', minargs=1) +def subscribe(*events, conn=None): + """ subscribe connection to many events """ + sub = set() + for ev in events: + name = ev.decode('ascii') + if not event_name_valid(name): + raise Exception(f"invalid event type: {name}") + sub += name + + global _pools + for name in sub: + _pools[name].add(conn) + app.logger.debug(f"sub {conn} to {len(sub)} events") + + +@api(name='unsub', minargs=1) +def unsubscribe(*events, conn=None): + """ unsub connection to many events """ + unsub = set() + for ev in events: + name = ev.decode('ascii') + if not event_name_valid(name): + raise Exception(f"invalid event type: {name}") + unsub += name + + global _pools + for name in unsub: + if conn in _pools[name]: + _pools[name].remove(conn) + app.logger.debug(f"unsub {conn} to {len(unsub)} events") + + +@api(name="request", minargs=2) +def request(method, path, body=None, *, conn=None): + """ make an rpc request via zmq """ + ctype = None + # guess content type + if body: + if body[0] in (b'{', b'['): + ctype = 'application/json' + else: + ctype = 'application/octet-stream' + resp = make_subrequest( + method.decode('ascii'), path.decode('ascii'), content_type=ctype, body=body + ) + return resp.status_code, _collect_bytes(resp.response) + + +class _Notify: + """ Holder type for all event notification functions """ + + +notify = _Notify() + +# set up event notifiers +for ev in EVENTS: + setattr(notify, ev, lambda *args: _propagate_event(ev, *args)) diff --git a/sogs/mule.py b/sogs/mule.py index 63cdae7b..a1280940 100644 --- a/sogs/mule.py +++ b/sogs/mule.py @@ -1,6 +1,5 @@ import traceback import oxenmq -from oxenc import bt_deserialize import time from datetime import timedelta import functools @@ -9,6 +8,7 @@ from . import cleanup from . import config from . import omq as o +from .events import notify # This is the uwsgi "mule" that handles things not related to serving HTTP requests: # - it holds the oxenmq instance (with its own interface into sogs) @@ -52,6 +52,10 @@ def setup_omq(): for addr in listen: omq.listen(addr, curve=True, allow_connection=allow_conn) app.logger.info(f"OxenMQ listening on {addr}") + if not listen: + app.logger.warn( + "OxenMQ did not listen on any curve addresses, the bot API is not accessable anywhere." + ) # Internal socket for workers to talk to us: omq.listen(config.OMQ_INTERNAL, curve=False, allow_connection=admin_conn) @@ -64,6 +68,10 @@ def setup_omq(): worker.add_command("message_posted", message_posted) worker.add_command("messages_deleted", messages_deleted) worker.add_command("message_edited", message_edited) + worker.add_command("user_joined", user_joined) + worker.add_command("user_banned", user_banned) + worker.add_command("user_unbanned", user_unbanned) + worker.add_command("file_uploaded", file_uploaded) app.logger.debug("Mule starting omq") omq.start() @@ -88,14 +96,32 @@ def wrapper(*args, **kwargs): @log_exceptions def message_posted(m: oxenmq.Message): - id = bt_deserialize(m.data()[0]) - app.logger.warning(f"FIXME: mule -- message posted stub, id={id}") + notify.message(*m.data()) @log_exceptions def messages_deleted(m: oxenmq.Message): - ids = bt_deserialize(m.data()[0]) - app.logger.warning(f"FIXME: mule -- message delete stub, deleted messages: {ids}") + notify.deleted(*m.data()) + + +@log_exceptions +def user_banned(m: oxenmq.Message): + notify.banned(*m.data()) + + +@log_exceptions +def user_unbanned(m: oxenmq.Message): + notify.unbannd(*m.data()) + + +@log_exceptions +def user_joined(m: oxenmq.Message): + notify.joined(*m.data()) + + +@log_exceptions +def file_uploaded(m: oxenmq.Message): + notify.uploaded(*m.data()) @log_exceptions