Skip to content

Commit

Permalink
[DataStore] closes #7
Browse files Browse the repository at this point in the history
Make Firebase calls asynchronous by
running them in a ThreadPoolExecutor:
firebase/firebase-admin-python#104 (comment)

This touches a lot of code, because a new formatter was
introduced as well.
  • Loading branch information
noahkw committed Feb 7, 2020
1 parent 90615f6 commit b46838d
Show file tree
Hide file tree
Showing 7 changed files with 246 additions and 120 deletions.
95 changes: 59 additions & 36 deletions DataStore.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,93 @@
from abc import ABC, abstractmethod
from functools import partial
from firebase_admin import credentials, firestore

import firebase_admin
from firebase_admin import credentials, firestore
import concurrent
import asyncio


class DataStore(ABC):
def __init__(self):
pass

@abstractmethod
def set(self, collection, document, val):
async def set(self, collection, document, val):
pass

@abstractmethod
async def set_get_id(self, collection, val):
pass

@abstractmethod
def update(self, collection, document, val):
async def update(self, collection, document, val):
pass

@abstractmethod
def add(self, collection, val):
async def add(self, collection, val):
pass

@abstractmethod
def get(self, collection, document):
async def get(self, collection, document):
pass

@abstractmethod
def delete(self, collection, document):
async def delete(self, collection, document):
pass


class FirebaseDataStore(DataStore):
def __init__(self, key_file, db_name):
def __init__(self, key_file, db_name, loop):
super().__init__()
cred = credentials.Certificate(key_file)

firebase_admin.initialize_app(cred, {
'databaseURL': f'https://{db_name}.firebaseio.com'
})
firebase_admin.initialize_app(
cred, {'databaseURL': f'https://{db_name}.firebaseio.com'})

self.db = firestore.client()
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
self.loop = loop

async def set(self, collection, document, val):
ref = self._get_doc_ref(collection, document)
await self.loop.run_in_executor(self.executor, partial(ref.set, val))

async def set_get_id(self, collection, val):
ref = self._get_doc_ref(collection, None)
await self.loop.run_in_executor(self.executor, partial(ref.set, val))
return ref.id

async def update(self, collection, document, val):
ref = self._get_doc_ref(collection, document)
await self.loop.run_in_executor(self.executor,
partial(ref.update, val))

async def add(self, collection, val):
ref = self._get_collection(collection)
await self.loop.run_in_executor(self.executor, partial(ref.add, val))

async def get(self, collection, document=None):
if document is None:
ref = self._get_collection(collection)
return await self.loop.run_in_executor(self.executor, ref.stream)
else:
ref = self._get_doc_ref(collection, document)
return await self.loop.run_in_executor(self.executor, ref.get)

def set(self, collection, document, val):
self._get_doc_ref(collection, document).set(val)

def set_get_id(self, collection, val):
doc = self._get_doc_ref(collection, None)
doc.set(val)
return doc.id

def update(self, collection, document, val):
self._get_doc_ref(collection, document).update(val)

def add(self, collection, val):
self._get_collection(collection).add(val)

def get(self, collection, document=None):
return self._get_collection(collection).stream() if document is None else self._get_doc_ref(collection,
document).get()

def delete(self, collection, document=None):
async def delete(self, collection, document=None):
if document is not None:
self._get_doc_ref(collection, document).delete()
ref = self._get_doc_ref(collection, document)
await self.loop.run_in_executor(self.executor, ref.delete)
else:
# implement batching later
docs = self._get_collection(collection).stream()
for doc in docs:
doc.reference.delete()
await self.loop.run_in_executor(self.executor,
doc.reference.delete)

def query(self, collection, *query):
return self._get_collection(collection).where(*query).stream()
async def query(self, collection, *query):
ref = self._get_collection(collection).where(*query)
return await self.loop.run_in_executor(self.executor, ref.stream)

def _get_doc_ref(self, collection, document):
return self._get_collection(collection).document(document)
Expand All @@ -83,6 +102,10 @@ def _get_collection(self, collection):
config = configparser.ConfigParser()
config.read('conf.ini')

firebase_ds = FirebaseDataStore(
config['firebase']['key_file'], config['firebase']['db_name'])
firebase_ds.add('jobs', {'func': 'somefunc', 'time': 234903284, 'args': ['arg1', 'arg2']})
firebase_ds = FirebaseDataStore(config['firebase']['key_file'],
config['firebase']['db_name'])
firebase_ds.add('jobs', {
'func': 'somefunc',
'time': 234903284,
'args': ['arg1', 'arg2']
})
29 changes: 15 additions & 14 deletions botw-bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,29 @@
config.read('conf.ini')
client = commands.Bot(command_prefix=config['discord']['command_prefix'])
client.config = config
client.database = FirebaseDataStore(
config['firebase']['key_file'], config['firebase']['db_name'])
client.db = FirebaseDataStore(config['firebase']['key_file'],
config['firebase']['db_name'], client.loop)
logger = logging.getLogger('discord')
logger.setLevel(logging.INFO)
handler = logging.FileHandler(
filename='botw-bot.log', encoding='utf-8', mode='w')
handler.setFormatter(logging.Formatter(
'%(asctime)s:%(levelname)s:%(name)s: %(message)s'))
handler = logging.FileHandler(filename='botw-bot.log',
encoding='utf-8',
mode='w')
handler.setFormatter(
logging.Formatter('%(asctime)s:%(levelname)s:%(name)s: %(message)s'))
logger.addHandler(handler)

INITIAL_EXTENSIONS = [
'cogs.BiasOfTheWeek',
'cogs.Utilities',
'cogs.Scheduler',
'cogs.EmojiUtils',
'cogs.Tags',
'jishaku'
'cogs.BiasOfTheWeek', 'cogs.Utilities', 'cogs.Scheduler',
'cogs.EmojiUtils', 'cogs.Tags', 'jishaku'
]


@client.event
async def on_ready():
await client.change_presence(activity=discord.Game('with Bini'))
logger.info(f"Logged in as {client.user}. Whitelisted servers: {config.items('whitelisted_servers')}")
logger.info(
f"Logged in as {client.user}. Whitelisted servers: {config.items('whitelisted_servers')}"
)

for ext in INITIAL_EXTENSIONS:
ext_logger = logging.getLogger(ext)
Expand All @@ -54,7 +53,9 @@ async def globally_block_dms(ctx):

@client.check
async def whitelisted_server(ctx):
server_ids = [int(server) for key, server in config.items('whitelisted_servers')]
server_ids = [
int(server) for key, server in config.items('whitelisted_servers')
]
return ctx.guild.id in server_ids


Expand Down
89 changes: 62 additions & 27 deletions cogs/BiasOfTheWeek.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from cogs.Scheduler import Job


logger = logging.getLogger(__name__)


Expand All @@ -28,13 +27,11 @@ def __str__(self):
def __eq__(self, other):
if not isinstance(other, Idol):
return NotImplemented
return str.lower(self.group) == str.lower(other.group) and str.lower(self.name) == str.lower(other.name)
return str.lower(self.group) == str.lower(other.group) and str.lower(
self.name) == str.lower(other.name)

def to_dict(self):
return {
'group': self.group,
'name': self.name
}
return {'group': self.group, 'name': self.name}

@staticmethod
def from_dict(source):
Expand All @@ -45,47 +42,65 @@ class BiasOfTheWeek(commands.Cog):
def __init__(self, bot):
self.bot = bot
self.nominations = {}
self.nominations_collection = self.bot.config['biasoftheweek']['nominations_collection']
self.nominations_collection = self.bot.config['biasoftheweek'][
'nominations_collection']

if self.bot.loop.is_running():
asyncio.create_task(self._ainit())
else:
self.bot.loop.run_until_complete(self._ainit())

async def _ainit(self):
_nominations = await self.bot.db.get(self.nominations_collection)

_nominations = self.bot.database.get(self.nominations_collection)
for nomination in _nominations:
self.nominations[self.bot.get_user(int(nomination.id))] = Idol.from_dict(nomination.to_dict())
self.nominations[self.bot.get_user(int(
nomination.id))] = Idol.from_dict(nomination.to_dict())

logger.info(f'Initial nominations from database: {self.nominations}')
logger.info(f'Initial nominations from db: {self.nominations}')

@staticmethod
def reaction_check(reaction, user, author, prompt_msg):
return user == author and str(reaction.emoji) in [CHECK_EMOJI, CROSS_EMOJI] and \
reaction.message.id == prompt_msg.id

@commands.command()
async def nominate(self, ctx, group: commands.clean_content, name: commands.clean_content):
async def nominate(self, ctx, group: commands.clean_content,
name: commands.clean_content):
idol = Idol(group, name)

if idol in self.nominations.values():
await ctx.send(f'**{idol}** has already been nominated. Please nominate someone else.')
await ctx.send(
f'**{idol}** has already been nominated. Please nominate someone else.'
)
elif ctx.author in self.nominations.keys():
old_idol = self.nominations[ctx.author]
prompt_msg = await ctx.send(f'Your current nomination is **{old_idol}**. Do you want to override it?')
prompt_msg = await ctx.send(
f'Your current nomination is **{old_idol}**. Do you want to override it?'
)
await prompt_msg.add_reaction(CHECK_EMOJI)
await prompt_msg.add_reaction(CROSS_EMOJI)
try:
reaction, user = await self.bot.wait_for('reaction_add', timeout=60.0,
check=lambda reaction, user: self.reaction_check(reaction,
user,
ctx.author,
prompt_msg))
reaction, user = await self.bot.wait_for(
'reaction_add',
timeout=60.0,
check=lambda reaction, user: self.reaction_check(
reaction, user, ctx.author, prompt_msg))
except asyncio.TimeoutError:
pass
else:
await prompt_msg.delete()
if reaction.emoji == CHECK_EMOJI:
self.nominations[ctx.author] = idol
self.bot.database.set(self.nominations_collection, str(ctx.author.id), idol.to_dict())
await ctx.send(f'{ctx.author} nominates **{idol}** instead of **{old_idol}**.')
await self.bot.db.set(self.nominations_collection,
str(ctx.author.id), idol.to_dict())
await ctx.send(
f'{ctx.author} nominates **{idol}** instead of **{old_idol}**.'
)
else:
self.nominations[ctx.author] = idol
self.bot.database.set(self.nominations_collection, str(ctx.author.id), idol.to_dict())
await self.bot.db.set(self.nominations_collection,
str(ctx.author.id), idol.to_dict())
await ctx.send(f'{ctx.author} nominates **{idol}**.')

@nominate.error
Expand All @@ -96,7 +111,7 @@ async def nominate_error(self, ctx, error):
@commands.has_permissions(administrator=True)
async def clear_nominations(self, ctx):
self.nominations = {}
self.bot.database.delete(self.nominations_collection)
await self.bot.db.delete(self.nominations_collection)
await ctx.message.add_reaction(CHECK_EMOJI)

@commands.command()
Expand All @@ -110,23 +125,43 @@ async def nominations(self, ctx):
else:
await ctx.send('So far, no idols have been nominated.')

@commands.command()
async def db_noms(self, ctx):
embed = discord.Embed(title='Bias of the Week nominations')
nominations = {}
_nominations = await self.bot.db.get(self.nominations_collection)
for nomination in _nominations:
nominations[self.bot.get_user(int(
nomination.id))] = Idol.from_dict(nomination.to_dict())

for key, value in nominations.items():
embed.add_field(name=key, value=value)

await ctx.send(embed=embed)

@commands.command(name='pickwinner')
@commands.has_permissions(administrator=True)
async def pick_winner(self, ctx, silent: bool = False, fast_assign: bool = False):
async def pick_winner(self,
ctx,
silent: bool = False,
fast_assign: bool = False):
member, pick = random.choice(list(self.nominations.items()))

# Assign BotW winner role on next wednesday at 00:00 UTC
now = pendulum.now('Europe/London')
assign_date = now.add(seconds=120) if fast_assign else now.next(
pendulum.WEDNESDAY)
assign_date = now.add(
seconds=120) if fast_assign else now.next(pendulum.WEDNESDAY)

await ctx.send(
f"""Bias of the Week ({now.week_of_year}-{now.year}): {member if silent else member.mention}\'s pick **{pick}**.
You will be assigned the role *{self.bot.config['biasoftheweek']['winner_role_name']}* at {assign_date.to_cookie_string()}.""")
You will be assigned the role *{self.bot.config['biasoftheweek']['winner_role_name']}* at {assign_date.to_cookie_string()}."""
)

scheduler = self.bot.get_cog('Scheduler')
if scheduler is not None:
await scheduler.add_job(Job('assign_winner_role', [ctx.guild.id, member.id], assign_date.float_timestamp))
await scheduler.add_job(
Job('assign_winner_role', [ctx.guild.id, member.id],
assign_date.float_timestamp))

@pick_winner.error
async def pick_winner_error(self, ctx, error):
Expand Down
20 changes: 14 additions & 6 deletions cogs/EmojiUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ class EmojiUtils(commands.Cog):

def __init__(self, bot):
self.bot = bot
self.bot.add_listener(self.on_guild_emojis_update, 'on_guild_emojis_update')
self.emoji_channel = self.bot.get_channel(int(self.bot.config['emojiutils']['emoji_channel']))
self.bot.add_listener(self.on_guild_emojis_update,
'on_guild_emojis_update')
self.emoji_channel = self.bot.get_channel(
int(self.bot.config['emojiutils']['emoji_channel']))

@commands.group(name='emoji')
@commands.has_permissions(administrator=True)
Expand All @@ -41,16 +43,22 @@ async def emoji_list_error(self, ctx, error):
async def on_guild_emojis_update(self, guild, before, after):
# delete old messages containing emoji
# need to use Message.delete to be able to delete messages older than 14 days
async for message in self.emoji_channel.history(limit=EmojiUtils.DELETE_LIMIT):
async for message in self.emoji_channel.history(
limit=EmojiUtils.DELETE_LIMIT):
await message.delete()

# get emoji that were added in the last 10 minutes
recent_emoji = [emoji for emoji in after if (
time.time() - discord.utils.snowflake_time(emoji.id).timestamp()) < EmojiUtils.NEW_EMOTE_THRESHOLD]
recent_emoji = [
emoji for emoji in after
if (time.time() -
discord.utils.snowflake_time(emoji.id).timestamp()
) < EmojiUtils.NEW_EMOTE_THRESHOLD
]

emoji_sorted = sorted(after, key=lambda e: e.name)
for emoji_chunk in chunker(emoji_sorted, EmojiUtils.SPLIT_MSG_AFTER):
await self.emoji_channel.send(''.join(str(e) for e in emoji_chunk))

if len(recent_emoji) > 0:
await self.emoji_channel.send(f"Newly added: {''.join(str(e) for e in recent_emoji)}")
await self.emoji_channel.send(
f"Newly added: {''.join(str(e) for e in recent_emoji)}")
Loading

0 comments on commit b46838d

Please sign in to comment.