Skip to content

Commit

Permalink
Implement 'read_timeout' parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
Aliaksandr Akulchyk committed Jan 5, 2024
1 parent 83aa96e commit 4d64620
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 6 deletions.
23 changes: 20 additions & 3 deletions aiomysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def connect(host="localhost", user=None, password="",
read_default_file=None, conv=decoders, use_unicode=None,
client_flag=0, cursorclass=Cursor, init_command=None,
connect_timeout=None, read_default_group=None,
read_timeout=None,
autocommit=False, echo=False,
local_infile=False, loop=None, ssl=None, auth_plugin='',
program_name='', server_public_key=None):
Expand All @@ -64,6 +65,7 @@ def connect(host="localhost", user=None, password="",
init_command=init_command,
connect_timeout=connect_timeout,
read_default_group=read_default_group,
read_timeout=read_timeout,
autocommit=autocommit, echo=echo,
local_infile=local_infile, loop=loop, ssl=ssl,
auth_plugin=auth_plugin, program_name=program_name)
Expand Down Expand Up @@ -139,7 +141,7 @@ def __init__(self, host="localhost", user=None, password="",
charset='', sql_mode=None,
read_default_file=None, conv=decoders, use_unicode=None,
client_flag=0, cursorclass=Cursor, init_command=None,
connect_timeout=None, read_default_group=None,
connect_timeout=None, read_default_group=None, read_timeout=None,
autocommit=False, echo=False,
local_infile=False, loop=None, ssl=None, auth_plugin='',
program_name='', server_public_key=None):
Expand Down Expand Up @@ -171,6 +173,7 @@ def __init__(self, host="localhost", user=None, password="",
when connecting.
:param read_default_group: Group to read from in the configuration
file.
:param read_timeout: The timeout for reading from the connection in seconds (default: None - no timeout)
:param autocommit: Autocommit mode. None means use server default.
(default: False)
:param local_infile: boolean to enable the use of LOAD DATA LOCAL
Expand Down Expand Up @@ -257,6 +260,7 @@ def __init__(self, host="localhost", user=None, password="",

self.cursorclass = cursorclass
self.connect_timeout = connect_timeout
self.read_timeout = read_timeout

self._result = None
self._affected_rows = 0
Expand Down Expand Up @@ -654,12 +658,25 @@ async def _read_packet(self, packet_type=MysqlPacket):

async def _read_bytes(self, num_bytes):
try:
data = await self._reader.readexactly(num_bytes)
if self.read_timeout:
try:
data = await asyncio.wait_for(
self._reader.readexactly(num_bytes),
self.read_timeout
)
except asyncio.TimeoutError as e:
raise asyncio.TimeoutError("Read timeout exceeded") from e
else:
data = await self._reader.readexactly(num_bytes)
except asyncio.IncompleteReadError as e:
msg = "Lost connection to MySQL server during query"
self.close()
raise OperationalError(CR.CR_SERVER_LOST, msg) from e
except OSError as e:
except (OSError, asyncio.TimeoutError) as e:
msg = f"Lost connection to MySQL server during query ({e})"
self.close()
raise OperationalError(CR.CR_SERVER_LOST, msg) from e
except Exception as e:
msg = f"Lost connection to MySQL server during query ({e})"
self.close()
raise OperationalError(CR.CR_SERVER_LOST, msg) from e
Expand Down
13 changes: 10 additions & 3 deletions tests/sa/test_sa_connection.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from unittest import mock

import aiomysql
from aiomysql import sa, Cursor

import pytest
from sqlalchemy import MetaData, Table, Column, Integer, String, func, select
from sqlalchemy.schema import DropTable, CreateTable
from sqlalchemy.sql.expression import bindparam

import aiomysql
from aiomysql import sa, Cursor

meta = MetaData()
tbl = Table('sa_tbl', meta,
Column('id', Integer, nullable=False,
Expand Down Expand Up @@ -35,6 +35,13 @@ async def connect(**kwargs):
return connect


@pytest.mark.run_loop
async def test_read_timeout(sa_connect):
conn = await sa_connect(read_timeout=0.01)
with pytest.raises(aiomysql.OperationalError):
await conn.execute("DO SLEEP(1)")


@pytest.mark.run_loop
async def test_execute_text_select(sa_connect):
conn = await sa_connect()
Expand Down
8 changes: 8 additions & 0 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ async def test_connect_timeout(connection_creator):
await connection_creator(connect_timeout=0.000000000001)


@pytest.mark.run_loop
async def test_read_timeout(connection_creator):
with pytest.raises(aiomysql.OperationalError):
con = await connection_creator(read_timeout=0.01)
cur = await con.cursor()
await cur.execute("DO SLEEP(1)")


@pytest.mark.run_loop
async def test_config_file(fill_my_cnf, connection_creator, mysql_params):
tests_root = os.path.abspath(os.path.dirname(__file__))
Expand Down

0 comments on commit 4d64620

Please sign in to comment.