Skip to content

Commit

Permalink
Merge pull request #112 from stankudrow/fix-async-for-cursor-infinite…
Browse files Browse the repository at this point in the history
…-loop

Fix the `async for row in cursor:` infinite loop error
  • Loading branch information
long2ice authored Sep 5, 2024
2 parents 74904ef + a11b9fc commit 49764c5
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 40 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### 0.2.5

- Fix infinite iteration case when a cursor object is put in the `async for` loop. By @stankudrow in #112.
- Fix pool connection management (the discussion #108 by @DFilyushin) by @stankudrow in #109:

- add the asynchronous context manager support to the `Pool` class with the pool "startup()" as `__aenter__` and "shutdown()" as `__aexit__` methods;
Expand Down
5 changes: 3 additions & 2 deletions asynch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .connection import connect # noqa:F401
from .pool import create_pool # noqa:F401
from asynch.connection import Connection, connect # noqa: F401
from asynch.cursors import Cursor, DictCursor # noqa: F401
from asynch.pool import Pool, create_async_pool, create_pool # noqa: F401
48 changes: 33 additions & 15 deletions asynch/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ def connected(self) -> Optional[bool]:
"""

warn(
"consider using `connection.opened` attribute",
(
"Please consider using the `connection.opened` property. "
"This property may be removed in the version 0.2.6 or a later release."
),
DeprecationWarning,
)
return self._opened
Expand Down Expand Up @@ -130,7 +133,7 @@ def status(self) -> str:
and the `conn.opened` is False.
:raise ConnectionError: unknown connection state
:return: connection status
:return: the connection status
:rtype: str (ConnectionStatuses StrEnum)
"""

Expand Down Expand Up @@ -167,6 +170,8 @@ def echo(self) -> bool:
return self._echo

async def close(self) -> None:
"""Close the connection."""

if self._opened:
await self._connection.disconnect()
self._opened = False
Expand All @@ -186,14 +191,27 @@ async def connect(self) -> None:
self._closed = False

def cursor(self, cursor: Optional[Cursor] = None, *, echo: bool = False) -> Cursor:
"""Return the cursor object for the connection.
When a parameter is interpreted as True,
it takes precedence over the corresponding default value.
If cursor is None, but echo is True, then an instance
of a default `Cursor` class will be created with echoing
set to True even if the `self.echo` property returns False.
:param cursor None | Cursor: a Cursor factory class
:param echo bool:
:return: the cursor from a connection
:rtype: Cursor
"""

cursor_cls = cursor or self._cursor_cls
return cursor_cls(self, self._echo or echo)
return cursor_cls(self, echo or self._echo)

async def ping(self) -> None:
"""Check the connection liveliness.
:raises ConnectionError: if ping() has failed
:return: None
"""

Expand All @@ -219,17 +237,17 @@ async def connect(
1. conn = Connection(...) # init a Connection instance
2. conn.connect() # connect to a ClickHouse instance
:param dsn: DSN/connection string (if None -> constructed from default dsn parts)
:param user: user string ("default" by default)
:param password: password string ("" by default)
:param host: host string ("127.0.0.1" by default)
:param port: port integer (9000 by default)
:param database: database string ("default" by default)
:param cursor_cls: Cursor class (asynch.Cursor by default)
:param echo: connection echo mode (False by default)
:param kwargs: connection settings
:return: the open connection
:param dsn str: DSN/connection string (if None -> constructed from default dsn parts)
:param user str: user string ("default" by default)
:param password str: password string ("" by default)
:param host str: host string ("127.0.0.1" by default)
:param port int: port integer (9000 by default)
:param database str: database string ("default" by default)
:param cursor_cls Cursor: Cursor class (asynch.Cursor by default)
:param echo bool: echo mode flag (False by default)
:param kwargs dict: connection settings
:return: an opened connection
:rtype: Connection
"""

Expand Down
47 changes: 34 additions & 13 deletions asynch/cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def __aiter__(self):
async def __anext__(self):
while True:
one = await self.fetchone()
if one is None:
if not one:
raise StopAsyncIteration
return one

Expand Down Expand Up @@ -349,23 +349,44 @@ def set_query_id(self, query_id=""):


class DictCursor(Cursor):
async def fetchone(self):
row = await super(DictCursor, self).fetchone()
async def fetchone(self) -> dict:
"""Fetch exactly one row from the last executed query.
:raises AttributeError: columns mismatch
:return: one row from the query
:rtype: dict
"""

row = await super().fetchone()
if self._columns:
return dict(zip(self._columns, row)) if row else {}
else:
raise AttributeError("Invalid columns.")
raise AttributeError("Invalid columns.")

async def fetchmany(self, size: int):
rows = await super(DictCursor, self).fetchmany(size)
async def fetchmany(self, size: int) -> list[dict]:
"""Fetch no more than `size` rows from the last executed query.
:raises AttributeError: columns mismatch
:return: the list of rows from the query
:rtype: list[dict]
"""

rows = await super().fetchmany(size)
if self._columns:
return [dict(zip(self._columns, item)) for item in rows] if rows else []
else:
raise AttributeError("Invalid columns.")
raise AttributeError("Invalid columns.")

async def fetchall(self):
rows = await super(DictCursor, self).fetchall()
async def fetchall(self) -> list[dict]:
"""Fetch all resulting rows from the last executed query.
:raises AttributeError: columns mismatch
:return: the list of all possible rows from the query
:rtype: list[dict]
"""

rows = await super().fetchall()
if self._columns:
return [dict(zip(self._columns, item)) for item in rows] if rows else []
else:
raise AttributeError("Invalid columns.")
raise AttributeError("Invalid columns.")
4 changes: 2 additions & 2 deletions asynch/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ async def __aenter__(self) -> "Pool":
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self.shutdown()

def __repr__(self):
def __repr__(self) -> str:
cls_name = self.__class__.__name__
status = self.status
return (
Expand All @@ -155,7 +155,7 @@ def status(self) -> str:
and the `pool.opened` is False.
:raise PoolError: unresolved pool state.
:return: pool status
:return: the pool status
:rtype: str (PoolStatuses StrEnum)
"""

Expand Down
43 changes: 35 additions & 8 deletions tests/test_cursors.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,38 @@
from typing import Any

import pytest

from asynch.connection import Connection
from asynch.cursors import DictCursor
from asynch.proto import constants


@pytest.mark.asyncio
async def test_fetchone(conn):
@pytest.mark.parametrize(
("stmt", "answer"),
[
("SELECT 42", [{"42": 42}]),
("SELECT -21 WHERE 1 != 1", []),
],
)
async def test_cursor_async_for(
stmt: str,
answer: list[dict[str, Any]],
conn: Connection,
):
result: list[dict[str, Any]] = []

async with conn:
async with conn.cursor(cursor=DictCursor) as cursor:
cursor.set_stream_results(stream_results=True, max_row_buffer=1000)
await cursor.execute(stmt)
result = [row async for row in cursor]

assert result == answer


@pytest.mark.asyncio
async def test_fetchone(conn: Connection):
async with conn.cursor() as cursor:
await cursor.execute("SELECT 1")
ret = await cursor.fetchone()
Expand All @@ -17,23 +44,23 @@ async def test_fetchone(conn):


@pytest.mark.asyncio
async def test_fetchall(conn):
async def test_fetchall(conn: Connection):
async with conn.cursor() as cursor:
await cursor.execute("SELECT 1")
ret = await cursor.fetchall()
assert ret == [(1,)]


@pytest.mark.asyncio
async def test_dict_cursor(conn):
async def test_dict_cursor(conn: Connection):
async with conn.cursor(cursor=DictCursor) as cursor:
await cursor.execute("SELECT 1")
ret = await cursor.fetchall()
assert ret == [{"1": 1}]


@pytest.mark.asyncio
async def test_insert_dict(conn):
async def test_insert_dict(conn: Connection):
async with conn.cursor(cursor=DictCursor) as cursor:
rows = await cursor.execute(
"""INSERT INTO test.asynch(id,decimal,date,datetime,float,uuid,string,ipv4,ipv6,bool) VALUES""",
Expand All @@ -56,7 +83,7 @@ async def test_insert_dict(conn):


@pytest.mark.asyncio
async def test_insert_tuple(conn):
async def test_insert_tuple(conn: Connection):
async with conn.cursor(cursor=DictCursor) as cursor:
rows = await cursor.execute(
"""INSERT INTO test.asynch(id,decimal,date,datetime,float,uuid,string,ipv4,ipv6,bool) VALUES""",
Expand All @@ -79,7 +106,7 @@ async def test_insert_tuple(conn):


@pytest.mark.asyncio
async def test_executemany(conn):
async def test_executemany(conn: Connection):
async with conn.cursor(cursor=DictCursor) as cursor:
rows = await cursor.executemany(
"""INSERT INTO test.asynch(id,decimal,date,datetime,float,uuid,string,ipv4,ipv6,bool) VALUES""",
Expand Down Expand Up @@ -114,7 +141,7 @@ async def test_executemany(conn):


@pytest.mark.asyncio
async def test_table_ddl(conn):
async def test_table_ddl(conn: Connection):
async with conn.cursor() as cursor:
await cursor.execute("drop table if exists test.alter_table")
create_table_sql = """
Expand All @@ -137,7 +164,7 @@ async def test_table_ddl(conn):


@pytest.mark.asyncio
async def test_insert_buffer_overflow(conn):
async def test_insert_buffer_overflow(conn: Connection):
old_buffer_size = constants.BUFFER_SIZE
constants.BUFFER_SIZE = 2**6 + 1

Expand Down

0 comments on commit 49764c5

Please sign in to comment.