From b33dc0d795e6a56ae0707cae0fb197b99b1836af Mon Sep 17 00:00:00 2001 From: Rusty Bower Date: Fri, 6 Dec 2019 02:06:06 -0600 Subject: [PATCH] tests: migrating db tests to sqlalchemy --- test/test_db.py | 138 +++++++++++++++++++++++++++--------------------- 1 file changed, 78 insertions(+), 60 deletions(-) diff --git a/test/test_db.py b/test/test_db.py index 78bc8bfd1c..1311331c6f 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -8,13 +8,12 @@ import json import os -import sqlite3 import sys import tempfile import pytest -from sopel.db import SopelDB +from sopel.db import ChannelValues, PluginValues, Nicknames, NickValues, SopelDB from sopel.test_tools import MockConfig from sopel.tools import Identifier @@ -35,6 +34,7 @@ def db(): config = MockConfig() config.core.db_filename = db_filename + config.core.db_type = 'sqlite' db = SopelDB(config) # TODO add tests to ensure db creation works properly, too. return db @@ -45,7 +45,7 @@ def teardown_function(function): def test_get_nick_id(db): - conn = sqlite3.connect(db_filename) + session = db.ssession() tests = [ [None, 'embolalia', Identifier('Embolalia')], # Ensures case conversion is handled properly @@ -57,13 +57,11 @@ def test_get_nick_id(db): for test in tests: test[0] = db.get_nick_id(test[2]) nick_id, slug, nick = test - with conn: - cursor = conn.cursor() - registered = cursor.execute( - 'SELECT nick_id, slug, canonical FROM nicknames WHERE canonical IS ?', [nick] - ).fetchall() - assert len(registered) == 1 - assert registered[0][1] == slug and registered[0][2] == nick + registered = session.query(Nicknames) \ + .filter(Nicknames.canonical == nick) \ + .all() + assert len(registered) == 1 + assert registered[0].slug == slug and registered[0].canonical == nick # Check that each nick ended up with a different id assert len(set(test[0] for test in tests)) == len(tests) @@ -79,6 +77,7 @@ def test_get_nick_id(db): nick_id = test[0] new_id = db.get_nick_id(Identifier(test[2].upper())) assert nick_id == new_id + session.close() def test_alias_nick(db): @@ -102,8 +101,7 @@ def test_alias_nick(db): def test_set_nick_value(db): - conn = sqlite3.connect(db_filename) - cursor = conn.cursor() + session = db.ssession() nick = 'Embolalia' nick_id = db.get_nick_id(nick) data = { @@ -117,10 +115,10 @@ def check(): db.set_nick_value(nick, key, value) for key, value in iteritems(data): - found_value = cursor.execute( - 'SELECT value FROM nick_values WHERE nick_id = ? AND key = ?', - [nick_id, key] - ).fetchone()[0] + found_value = session.query(NickValues.value) \ + .filter(NickValues.nick_id == nick_id) \ + .filter(NickValues.key == key) \ + .scalar() assert json.loads(unicode(found_value)) == value check() @@ -128,11 +126,11 @@ def check(): data['number_key'] = 'not a number anymore!' data['unicode'] = 'This is different toö!' check() + session.close() def test_get_nick_value(db): - conn = sqlite3.connect(db_filename) - cursor = conn.cursor() + session = db.ssession() nick = 'Embolalia' nick_id = db.get_nick_id(nick) data = { @@ -142,13 +140,14 @@ def test_get_nick_value(db): } for key, value in iteritems(data): - cursor.execute('INSERT INTO nick_values VALUES (?, ?, ?)', - [nick_id, key, json.dumps(value, ensure_ascii=False)]) - conn.commit() + nv = NickValues(nick_id=nick_id, key=key, value=json.dumps(value, ensure_ascii=False)) + session.add(nv) + session.commit() for key, value in iteritems(data): found_value = db.get_nick_value(nick, key) assert found_value == value + session.close() def test_get_nick_value_default(db): @@ -165,35 +164,39 @@ def test_delete_nick_value(db): def test_unalias_nick(db): - conn = sqlite3.connect(db_filename) + session = db.ssession() nick = 'Embolalia' nick_id = 42 - conn.execute('INSERT INTO nicknames VALUES (?, ?, ?)', - [nick_id, Identifier(nick).lower(), nick]) + + nn = Nicknames(nick_id=nick_id, slug=Identifier(nick).lower(), canonical=nick) + session.add(nn) + session.commit() + aliases = ['EmbölaliÅ', 'Embo`work', 'Embo'] for alias in aliases: - conn.execute('INSERT INTO nicknames VALUES (?, ?, ?)', - [nick_id, Identifier(alias).lower(), alias]) - conn.commit() + nn = Nicknames(nick_id=nick_id, slug=Identifier(alias).lower(), canonical=alias) + session.add(nn) + session.commit() for alias in aliases: db.unalias_nick(alias) for alias in aliases: - found = conn.execute( - 'SELECT * FROM nicknames WHERE nick_id = ?', - [nick_id]).fetchall() + found = session.query(Nicknames) \ + .filter(Nicknames.nick_id == nick_id) \ + .all() assert len(found) == 1 + session.close() def test_delete_nick_group(db): - conn = sqlite3.connect(db_filename) + session = db.ssession() aliases = ['Embolalia', 'Embo'] nick_id = 42 for alias in aliases: - conn.execute('INSERT INTO nicknames VALUES (?, ?, ?)', - [nick_id, Identifier(alias).lower(), alias]) - conn.commit() + nn = Nicknames(nick_id=nick_id, slug=Identifier(alias).lower(), canonical=alias) + session.add(nn) + session.commit() db.set_nick_value(aliases[0], 'foo', 'bar') db.set_nick_value(aliases[1], 'spam', 'eggs') @@ -201,19 +204,20 @@ def test_delete_nick_group(db): db.delete_nick_group(aliases[0]) # Nothing else has created values, so we know the tables are empty - nicks = conn.execute('SELECT * FROM nicknames').fetchall() + nicks = session.query(Nicknames).all() assert len(nicks) == 0 - data = conn.execute('SELECT * FROM nick_values').fetchone() + data = session.query(NickValues).first() assert data is None + session.close() def test_merge_nick_groups(db): - conn = sqlite3.connect(db_filename) + session = db.ssession() aliases = ['Embolalia', 'Embo'] for nick_id, alias in enumerate(aliases): - conn.execute('INSERT INTO nicknames VALUES (?, ?, ?)', - [nick_id, Identifier(alias).lower(), alias]) - conn.commit() + nn = Nicknames(nick_id=nick_id, slug=Identifier(alias).lower(), canonical=alias) + session.add(nn) + session.commit() finals = (('foo', 'bar'), ('bar', 'blue'), ('spam', 'eggs')) @@ -224,25 +228,29 @@ def test_merge_nick_groups(db): db.merge_nick_groups(aliases[0], aliases[1]) - nick_ids = conn.execute('SELECT nick_id FROM nicknames') - nick_id = nick_ids.fetchone()[0] - alias_id = nick_ids.fetchone()[0] + nick_ids = session.query(Nicknames.nick_id).all() + nick_id = nick_ids[0][0] + alias_id = nick_ids[1][0] assert nick_id == alias_id for key, value in finals: - found = conn.execute( - 'SELECT value FROM nick_values WHERE nick_id = ? AND key = ?', - [nick_id, key]).fetchone()[0] + found = session.query(NickValues.value) \ + .filter(NickValues.nick_id == nick_id) \ + .filter(NickValues.key == key) \ + .scalar() assert json.loads(unicode(found)) == value + session.close() def test_set_channel_value(db): - conn = sqlite3.connect(db_filename) + session = db.ssession() db.set_channel_value('#asdf', 'qwer', 'zxcv') - result = conn.execute( - 'SELECT value FROM channel_values WHERE channel = ? and key = ?', - ['#asdf', 'qwer']).fetchone()[0] + result = session.query(ChannelValues.value) \ + .filter(ChannelValues.channel == '#asdf') \ + .filter(ChannelValues.key == 'qwer') \ + .scalar() assert result == '"zxcv"' + session.close() def test_delete_channel_value(db): @@ -253,11 +261,15 @@ def test_delete_channel_value(db): def test_get_channel_value(db): - conn = sqlite3.connect(db_filename) - conn.execute("INSERT INTO channel_values VALUES ('#asdf', 'qwer', '\"zxcv\"')") - conn.commit() + session = db.ssession() + + cv = ChannelValues(channel='#asdf', key='qwer', value='\"zxcv\"') + session.add(cv) + session.commit() + result = db.get_channel_value('#asdf', 'qwer') assert result == 'zxcv' + session.close() def test_get_channel_value_default(db): @@ -287,20 +299,26 @@ def test_get_preferred_value(db): def test_set_plugin_value(db): - conn = sqlite3.connect(db_filename) + session = db.ssession() db.set_plugin_value('plugname', 'qwer', 'zxcv') - result = conn.execute( - 'SELECT value FROM plugin_values WHERE plugin = ? and key = ?', - ['plugname', 'qwer']).fetchone()[0] + result = session.query(PluginValues.value) \ + .filter(PluginValues.plugin == 'plugname') \ + .filter(PluginValues.key == 'qwer') \ + .scalar() assert result == '"zxcv"' + session.close() def test_get_plugin_value(db): - conn = sqlite3.connect(db_filename) - conn.execute("INSERT INTO plugin_values VALUES ('plugname', 'qwer', '\"zxcv\"')") - conn.commit() + session = db.ssession() + + pv = PluginValues(plugin='plugname', key='qwer', value='\"zxcv\"') + session.add(pv) + session.commit() + result = db.get_plugin_value('plugname', 'qwer') assert result == 'zxcv' + session.close() def test_get_plugin_value_default(db):