diff --git a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py index 60d27a175f8b0..f8f22578923f5 100644 --- a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py +++ b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py @@ -23,9 +23,8 @@ from copy import deepcopy from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeAlias, cast, overload -import psycopg2 -import psycopg2.extras from more_itertools import chunked +from psycopg2 import connect as ppg2_connect from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor, execute_batch from airflow.providers.common.compat.sdk import ( @@ -65,8 +64,8 @@ if USE_PSYCOPG3: from psycopg.errors import Diagnostic -CursorType: TypeAlias = DictCursor | RealDictCursor | NamedTupleCursor -CursorRow: TypeAlias = dict[str, Any] | tuple[Any, ...] + CursorType: TypeAlias = DictCursor | RealDictCursor | NamedTupleCursor + CursorRow: TypeAlias = dict[str, Any] | tuple[Any, ...] class CompatConnection(Protocol): @@ -221,9 +220,9 @@ def _get_cursor(self, raw_cursor: str) -> CursorType: raise ValueError(f"Invalid cursor passed {_cursor}. Valid options are: {valid_cursors}") cursor_types = { - "dictcursor": psycopg2.extras.DictCursor, - "realdictcursor": psycopg2.extras.RealDictCursor, - "namedtuplecursor": psycopg2.extras.NamedTupleCursor, + "dictcursor": DictCursor, + "realdictcursor": RealDictCursor, + "namedtuplecursor": NamedTupleCursor, } if _cursor in cursor_types: return cursor_types[_cursor] @@ -285,7 +284,7 @@ def get_conn(self) -> CompatConnection: if raw_cursor: conn_args["cursor_factory"] = self._get_cursor(raw_cursor) - self.conn = cast("CompatConnection", psycopg2.connect(**conn_args)) + self.conn = cast("CompatConnection", ppg2_connect(**conn_args)) return self.conn diff --git a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py index a095ed3a34f5b..a1c8d757bcbff 100644 --- a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py +++ b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py @@ -94,7 +94,7 @@ def mock_connect(mocker): """Mock the connection object according to the correct psycopg version.""" if USE_PSYCOPG3: return mocker.patch("airflow.providers.postgres.hooks.postgres.psycopg.connection.Connection.connect") - return mocker.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") + return mocker.patch("airflow.providers.postgres.hooks.postgres.ppg2_connect") class TestPostgresHookConn: