diff --git a/peewee_async.py b/peewee_async.py index 6ea8da9..5998f03 100644 --- a/peewee_async.py +++ b/peewee_async.py @@ -16,7 +16,6 @@ import abc import asyncio import contextlib -import functools import logging import uuid import warnings @@ -298,15 +297,6 @@ def execute_sql(self, *args, **kwargs): (str(args), str(kwargs))) return super().execute_sql(*args, **kwargs) - async def fetch_results(self, query, cursor): - # TODO: Probably we don't need this method at all? - # We might get here if we use older `Manager` interface. - if isinstance(query, peewee.BaseModelSelect): - return await AsyncQueryWrapper.make_for_all_rows(cursor, query) - if isinstance(query, peewee.RawQuery): - return await AsyncQueryWrapper.make_for_all_rows(cursor, query) - assert False, "Unsupported type of query '%s', use AioModel instead" % type(query) - def connection(self) -> ConnectionContext: return ConnectionContext(self.aio_pool, self._task_data) @@ -324,16 +314,12 @@ async def aio_execute(self, query, fetch_results=None): :param query: peewee query instance created with ``Model.select()``, ``Model.update()`` etc. - :param fetch_results: function with cursor param. It let you get data manually and don't need to close cursor - It will be closed automatically - :return: result depends on query type, it's the same as for sync - ``query.execute()`` + :param fetch_results: function with cursor param. It let you get data manually and + don't need to close cursor It will be closed automatically. + :return: result depends on query type, it's the same as for sync `query.execute()` """ sql, params = query.sql() - if fetch_results is None: - query_fetch_results = getattr(query, 'fetch_results', None) - database_fetch_results = functools.partial(self.fetch_results, query) - fetch_results = query_fetch_results or database_fetch_results + fetch_results = fetch_results or getattr(query, 'fetch_results', None) return await self.aio_execute_sql(sql, params, fetch_results=fetch_results) diff --git a/peewee_async_compat.py b/peewee_async_compat.py index 4fa3b64..1e6abff 100644 --- a/peewee_async_compat.py +++ b/peewee_async_compat.py @@ -49,7 +49,23 @@ def _patch_query_with_compat_methods(query, async_query_cls): - aio_get (for SELECT) - aio_scalar (for SELECT) """ - from peewee_async import AioModelSelect + from peewee_async import AioModelSelect, AioModelUpdate, AioModelDelete, AioModelInsert + + if getattr(query, 'aio_execute', None): + # No need to patch + return + + if async_query_cls is None: + if isinstance(query, peewee.RawQuery): + async_query_cls = AioModelSelect + if isinstance(query, peewee.SelectBase): + async_query_cls = AioModelSelect + elif isinstance(query, peewee.Update): + async_query_cls = AioModelUpdate + elif isinstance(query, peewee.Delete): + async_query_cls = AioModelDelete + elif isinstance(query, peewee.Insert): + async_query_cls = AioModelInsert query.aio_execute = partial(async_query_cls.aio_execute, query) query.fetch_results = partial(async_query_cls.fetch_results, query) @@ -345,6 +361,7 @@ async def create_or_get(self, model_, **kwargs): async def execute(self, query): """Execute query asyncronously.""" + _patch_query_with_compat_methods(query, None) return await self.database.aio_execute(query) async def prefetch(self, query, *subqueries, prefetch_type=peewee.PREFETCH_TYPE.JOIN): diff --git a/tests/aio_model/test_selecting.py b/tests/aio_model/test_selecting.py index 385c92a..84b8359 100644 --- a/tests/aio_model/test_selecting.py +++ b/tests/aio_model/test_selecting.py @@ -1,9 +1,11 @@ +import peewee +import pytest from tests.conftest import all_dbs -from tests.models import TestModelAlpha, TestModelBeta +from tests.models import TestModel, TestModelAlpha, TestModelBeta @all_dbs -async def test__select__w_join(manager): +async def test_select_w_join(manager): alpha = await TestModelAlpha.aio_create(text="Test 1") beta = await TestModelBeta.aio_create(alpha_id=alpha.id, text="text") @@ -13,4 +15,23 @@ async def test__select__w_join(manager): ).aio_execute())[0] assert result.id == beta.id - assert result.joined_alpha.id == alpha.id \ No newline at end of file + assert result.joined_alpha.id == alpha.id + + +# @pytest.mark.skip +@all_dbs +async def test_select_compound(manager): + obj1 = await manager.create(TestModel, text="Test 1") + obj2 = await manager.create(TestModel, text="Test 2") + query = ( + TestModel.select().where(TestModel.id == obj1.id) | + TestModel.select().where(TestModel.id == obj2.id) + ) + assert isinstance(query, peewee.ModelCompoundSelectQuery) + # NOTE: Two `AioModelSelect` when joining via `|` produce `ModelCompoundSelectQuery` + # without `aio_execute()` method, so only compat mode is available for now. + # result = await query.aio_execute() + result = await manager.execute(query) + assert len(list(result)) == 2 + assert obj1 in list(result) + assert obj2 in list(result) diff --git a/tests/db_config.py b/tests/db_config.py index 1b74bcf..7e1f0a3 100644 --- a/tests/db_config.py +++ b/tests/db_config.py @@ -1,3 +1,4 @@ +import os import peewee_async import peewee_asyncext @@ -5,19 +6,19 @@ PG_DEFAULTS = { 'database': 'postgres', 'host': '127.0.0.1', - 'port': 5432, + 'port': int(os.environ.get('POSTGRES_PORT', 5432)), 'password': 'postgres', 'user': 'postgres', - "connect_timeout": 30 + 'connect_timeout': 30 } MYSQL_DEFAULTS = { 'database': 'mysql', 'host': '127.0.0.1', - 'port': 3306, + 'port': int(os.environ.get('MYSQL_PORT', 3306)), 'user': 'root', 'password': 'mysql', - "connect_timeout": 30 + 'connect_timeout': 30 } DB_DEFAULTS = { @@ -28,6 +29,7 @@ 'mysql': MYSQL_DEFAULTS, 'mysql-pool': MYSQL_DEFAULTS } + DB_CLASSES = { 'postgres': peewee_async.PostgresqlDatabase, 'postgres-ext': peewee_asyncext.PostgresqlExtDatabase, diff --git a/tests/test_compat.py b/tests/test_compat.py index 93457a6..4290d2f 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -1,5 +1,6 @@ import uuid +import peewee from tests.conftest import all_dbs from tests.models import CompatTestModel @@ -9,6 +10,34 @@ async def test_create_select_compat_mode(manager): obj1 = await manager.create(CompatTestModel, text="Test 1") obj2 = await manager.create(CompatTestModel, text="Test 2") query = CompatTestModel.select().order_by(CompatTestModel.text) + assert isinstance(query, peewee.ModelSelect) + result = await manager.execute(query) + assert list(result) == [obj1, obj2] + + +@all_dbs +async def test_compound_select_compat_mode(manager): + obj1 = await manager.create(CompatTestModel, text="Test 1") + obj2 = await manager.create(CompatTestModel, text="Test 2") + query = ( + CompatTestModel.select().where(CompatTestModel.id == obj1.id) | + CompatTestModel.select().where(CompatTestModel.id == obj2.id) + ) + assert isinstance(query, peewee.ModelCompoundSelectQuery) + result = await manager.execute(query) + assert len(list(result)) == 2 + assert obj1 in list(result) + assert obj2 in list(result) + + +@all_dbs +async def test_raw_select_compat_mode(manager): + obj1 = await manager.create(CompatTestModel, text="Test 1") + obj2 = await manager.create(CompatTestModel, text="Test 2") + query = CompatTestModel.raw( + 'SELECT id, text, data FROM compattestmodel m ORDER BY m.text' + ) + assert isinstance(query, peewee.ModelRaw) result = await manager.execute(query) assert list(result) == [obj1, obj2]