Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pysogs zmq bot api #7

Open
wants to merge 8 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 160 additions & 0 deletions sogs/events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from collections import defaultdict

from oxenmq import AuthLevel

from . import model
from .omq import omq
from .web import app

from types import Iterable

from oxenc import bt_serialize

from .routes.subrequest import make_subrequest

from flask import g

# pools for event propagation
_pools = defaultdict(set)

status_OK = 'OK'
status_ERR = 'ERROR'

# the events we are able to subscribe to
EVENTS = ('message', 'joined', 'parted', 'banned', 'unbanned', 'deleted', 'uploaded')


def event_name_valid(eventname):
""" return True if this event name is something well formed """
return eventname in EVENTS


def _user_from_conn(conn):
"""
make a model.User from a connection using its curve pubkey as the session id.
"""
# TODO: blinding?
return model.User(session_id='05' + conn.pubkey.hex())


def _maybe_serialize(data):
"""maybe bt encode data, if data is a bytes dont encode,
if data is a string turn it into bytes and dont encode, otherwise bt encode"""
if isinstance(data, bytes):
return data
if isinstance(data, str):
return data.encode()
return bt_serialize(data)


def _propagate_event(eventname, *args):
""" propagate an event to everyone who cares about it """
assert event_name_valid(eventname)
global omq, _pools
sent = 0
for conn in _pools[eventname]:
omq.send(conn, f'sogs.event.{eventname}', *(_maybe_serialize(a) for a in args))
sent += 1
if sent:
app.logger.info(f"sent {eventname} to {sent} subscribers")


_category = omq.add_category('sogs', AuthLevel.basic)


def api(f, *, name=None, minargs=None):
""" set up a request handler for zmq for a function with name of the endpoint """
if name is None:
raise ValueError('api endpoint name cannot be none')

def _handle_request(msg):
try:
if minargs and len(msg.data) < minargs:
raise ValueError(f"Not enough arguments, got {len(msg.data)} expected 2 or more")
app.logger.debug(f"zmq request: {name} for {msg.conn}")
g.user = _user_from_conn(msg.conn)
retval = f(*msg.data, conn=msg.conn)
if retval is None:
msg.reply(status_OK)
elif isinstance(retval, tuple):
msg.reply(status_OK, *retval)
else:
msg.reply(status_OK, bt_serialize(retval))
except Exception as ex:
app.logger.error(f"{f.__name__} raised exception: {ex}")
msg.reply(status_ERR, f'{ex}')
finally:
g.user = None

global _category
_category.add_request_command(name, _handle_request)
app.logger.info(f"register zmq api handler: sogs.{name}")
return f


def _collect_bytes(iterable: Iterable[bytes]):
""" collect all bytes from an iterable of bytes and put it into one big bytes instance """
data = bytes()
for part in iterable:
data += part
return data


@api(name='sub', minargs=1)
def subscribe(*events, conn=None):
""" subscribe connection to many events """
sub = set()
for ev in events:
name = ev.decode('ascii')
if not event_name_valid(name):
raise Exception(f"invalid event type: {name}")
sub += name

global _pools
for name in sub:
_pools[name].add(conn)
app.logger.debug(f"sub {conn} to {len(sub)} events")


@api(name='unsub', minargs=1)
def unsubscribe(*events, conn=None):
""" unsub connection to many events """
unsub = set()
for ev in events:
name = ev.decode('ascii')
if not event_name_valid(name):
raise Exception(f"invalid event type: {name}")
unsub += name

global _pools
for name in unsub:
if conn in _pools[name]:
_pools[name].remove(conn)
app.logger.debug(f"unsub {conn} to {len(unsub)} events")


@api(name="request", minargs=2)
def request(method, path, body=None, *, conn=None):
""" make an rpc request via zmq """
ctype = None
# guess content type
if body:
if body[0] in (b'{', b'['):
ctype = 'application/json'
else:
ctype = 'application/octet-stream'
resp = make_subrequest(
method.decode('ascii'), path.decode('ascii'), content_type=ctype, body=body
)
return resp.status_code, _collect_bytes(resp.response)


class _Notify:
""" Holder type for all event notification functions """


notify = _Notify()

# set up event notifiers
for ev in EVENTS:
setattr(notify, ev, lambda *args: _propagate_event(ev, *args))
36 changes: 31 additions & 5 deletions sogs/mule.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import traceback
import oxenmq
from oxenc import bt_deserialize
import time
from datetime import timedelta
import functools
Expand All @@ -9,6 +8,7 @@
from . import cleanup
from . import config
from . import omq as o
from .events import notify

# This is the uwsgi "mule" that handles things not related to serving HTTP requests:
# - it holds the oxenmq instance (with its own interface into sogs)
Expand Down Expand Up @@ -52,6 +52,10 @@ def setup_omq():
for addr in listen:
omq.listen(addr, curve=True, allow_connection=allow_conn)
app.logger.info(f"OxenMQ listening on {addr}")
if not listen:
app.logger.warn(
"OxenMQ did not listen on any curve addresses, the bot API is not accessable anywhere."
)

# Internal socket for workers to talk to us:
omq.listen(config.OMQ_INTERNAL, curve=False, allow_connection=admin_conn)
Expand All @@ -64,6 +68,10 @@ def setup_omq():
worker.add_command("message_posted", message_posted)
worker.add_command("messages_deleted", messages_deleted)
worker.add_command("message_edited", message_edited)
worker.add_command("user_joined", user_joined)
worker.add_command("user_banned", user_banned)
worker.add_command("user_unbanned", user_unbanned)
worker.add_command("file_uploaded", file_uploaded)

app.logger.debug("Mule starting omq")
omq.start()
Expand All @@ -88,14 +96,32 @@ def wrapper(*args, **kwargs):

@log_exceptions
def message_posted(m: oxenmq.Message):
id = bt_deserialize(m.data()[0])
app.logger.warning(f"FIXME: mule -- message posted stub, id={id}")
notify.message(*m.data())


@log_exceptions
def messages_deleted(m: oxenmq.Message):
ids = bt_deserialize(m.data()[0])
app.logger.warning(f"FIXME: mule -- message delete stub, deleted messages: {ids}")
notify.deleted(*m.data())


@log_exceptions
def user_banned(m: oxenmq.Message):
notify.banned(*m.data())


@log_exceptions
def user_unbanned(m: oxenmq.Message):
notify.unbannd(*m.data())


@log_exceptions
def user_joined(m: oxenmq.Message):
notify.joined(*m.data())


@log_exceptions
def file_uploaded(m: oxenmq.Message):
notify.uploaded(*m.data())


@log_exceptions
Expand Down