From 2cdd1b11da53ba8082fac457df65fc59407983a5 Mon Sep 17 00:00:00 2001 From: tarsil Date: Tue, 28 Mar 2023 11:47:23 +0100 Subject: [PATCH] Add missing connection timeout --- .github/workflows/test-suite.yml | 8 ++++---- databases/backends/mssql.py | 8 +++++++- tests/test_integration.py | 16 +++++++++++++--- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index 25fbd2d3..0b426e89 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -47,7 +47,7 @@ jobs: ACCEPT_EULA: "Y" MSSQL_PID: Express ports: - - 1433/tcp + - 1433:1433 options: >- --health-cmd "/opt/mssql-tools/bin/sqlcmd -U sa -P $MSSQL_SA_PASSWORD -Q 'select 1' -b -o /dev/null" --health-interval 60s @@ -82,7 +82,7 @@ jobs: postgresql://username:password@localhost:5432/testsuite, postgresql+aiopg://username:password@127.0.0.1:5432/testsuite, postgresql+asyncpg://username:password@localhost:5432/testsuite, - mssql://sa:Mssql123$#mssql@127.0.0.1:1433/master?driver=ODBC+Driver+17+for+SQL+Server, - mssql+pyodbc://sa:Mssql123$#mssql@127.0.0.1:1433/master?driver=ODBC+Driver+17+for+SQL+Server, - mssql+aioodbc://sa:Mssql123$#mssql@127.0.0.1:1433/master?driver=ODBC+Driver+17+for+SQL+Server + mssql://sa:Mssql123$#mssql@localhost:1433/master?driver=ODBC+Driver+17+for+SQL+Server, + mssql+pyodbc://sa:Mssql123$#mssql@localhost:1433/master?driver=ODBC+Driver+17+for+SQL+Server, + mssql+aioodbc://sa:Mssql123$#mssql@localhost:1433/master?driver=ODBC+Driver+17+for+SQL+Server run: "scripts/test" diff --git a/databases/backends/mssql.py b/databases/backends/mssql.py index a6ad80ca..0764e774 100644 --- a/databases/backends/mssql.py +++ b/databases/backends/mssql.py @@ -41,6 +41,7 @@ def _get_connection_kwargs(self) -> dict: pool_recycle = url_options.get("pool_recycle") ssl = url_options.get("ssl") driver = url_options.get("driver") + timeout = url_options.get("connection_timeout", 30) trusted_connection = url_options.get("trusted_connection", "no") assert driver is not None, "The driver must be specified" @@ -56,6 +57,7 @@ def _get_connection_kwargs(self) -> dict: kwargs["trusted_connection"] = trusted_connection.lower() kwargs["driver"] = driver + kwargs["timeout"] = timeout for key, value in self._options.items(): # Coerce 'min_size' and 'max_size' for consistency. @@ -77,8 +79,12 @@ async def connect(self) -> None: port = self._database_url.port or 1433 user = self._database_url.username or getpass.getuser() password = self._database_url.password + timeout = kwargs.pop("timeout") - dsn = f"Driver={driver};Database={database};Server={hostname},{port};UID={user};PWD={password};" + if port: + dsn = f"Driver={driver};Database={database};Server={hostname},{port};UID={user};PWD={password};Connection+Timeout={timeout}" + else: + dsn = f"Driver={driver};Database={database};Server={hostname},{port};UID={user};PWD={password};Connection+Timeout={timeout}" self._pool = await aioodbc.create_pool( dsn=dsn, diff --git a/tests/test_integration.py b/tests/test_integration.py index 73c23305..25a47287 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -29,10 +29,15 @@ def create_test_database(): "postgresql+aiopg", "sqlite+aiosqlite", "postgresql+asyncpg", + ]: + url = str(database_url.replace(driver=None)) + elif database_url.scheme in [ + "mssql", "mssql+pyodbc", "mssql+aioodbc", + "mssql+pymssql", ]: - url = str(database_url.replace(driver=None)) + url = str(database_url.replace(driver="pyodbc")) engine = sqlalchemy.create_engine(url) metadata.create_all(engine) @@ -43,14 +48,19 @@ def create_test_database(): database_url = DatabaseURL(url) if database_url.scheme in ["mysql", "mysql+aiomysql", "mysql+asyncmy"]: url = str(database_url.replace(driver="pymysql")) - if database_url.scheme in ["mssql", "mssql+aioodbc"]: - url = str(database_url.replace(driver="pyodbc")) elif database_url.scheme in [ "postgresql+aiopg", "sqlite+aiosqlite", "postgresql+asyncpg", ]: url = str(database_url.replace(driver=None)) + elif database_url.scheme in [ + "mssql", + "mssql+pyodbc", + "mssql+aioodbc", + "mssql+pymssql", + ]: + url = str(database_url.replace(driver="pyodbc")) engine = sqlalchemy.create_engine(url) metadata.drop_all(engine)