diff --git a/airflow-core/src/airflow/settings.py b/airflow-core/src/airflow/settings.py index db5ce959b1faf..394b33323138f 100644 --- a/airflow-core/src/airflow/settings.py +++ b/airflow-core/src/airflow/settings.py @@ -43,6 +43,16 @@ from airflow.utils.orm_event_handlers import setup_event_handlers from airflow.utils.sqlalchemy import is_sqlalchemy_v1 +USE_PSYCOPG3: bool +try: + from importlib.util import find_spec + + is_psycopg3 = find_spec("psycopg") is not None + + USE_PSYCOPG3 = is_psycopg3 and not is_sqlalchemy_v1() +except (ImportError, ModuleNotFoundError): + USE_PSYCOPG3 = False + if TYPE_CHECKING: from sqlalchemy.engine import Engine @@ -426,12 +436,17 @@ def clean_in_fork(): register_at_fork(after_in_child=clean_in_fork) -DEFAULT_ENGINE_ARGS = { - "postgresql": { - "executemany_mode": "values_plus_batch", - "executemany_values_page_size" if is_sqlalchemy_v1() else "insertmanyvalues_page_size": 10000, - "executemany_batch_page_size": 2000, - }, +DEFAULT_ENGINE_ARGS: dict[str, dict[str, Any]] = { + "postgresql": ( + { + "executemany_values_page_size" if is_sqlalchemy_v1() else "insertmanyvalues_page_size": 10000, + } + | ( + {} + if USE_PSYCOPG3 + else {"executemany_mode": "values_plus_batch", "executemany_batch_page_size": 2000} + ) + ) } diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py index 6dd01d9ba8e01..e8409c86618fd 100644 --- a/airflow-core/src/airflow/utils/db.py +++ b/airflow-core/src/airflow/utils/db.py @@ -61,6 +61,21 @@ from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.task_instance_session import get_current_task_instance_session +USE_PSYCOPG3: bool +try: + from importlib.util import find_spec + + import sqlalchemy + from packaging.version import Version + + is_psycopg3 = find_spec("psycopg") is not None + sqlalchemy_version = Version(sqlalchemy.__version__) + is_sqla2 = (sqlalchemy_version.major, sqlalchemy_version.minor, sqlalchemy_version.micro) >= (2, 0, 0) + + USE_PSYCOPG3 = is_psycopg3 and is_sqla2 +except (ImportError, ModuleNotFoundError): + USE_PSYCOPG3 = False + if TYPE_CHECKING: from alembic.runtime.environment import EnvironmentContext from alembic.script import ScriptDirectory @@ -1284,15 +1299,28 @@ def create_global_lock( dialect = conn.dialect try: if dialect.name == "postgresql": - conn.execute(text("SET LOCK_TIMEOUT to :timeout"), {"timeout": lock_timeout}) - conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value}) + if USE_PSYCOPG3: + # psycopg3 doesn't support parameters for `SET`. Use `set_config` instead. + # The timeout value must be passed as a string of milliseconds. + conn.execute( + text("SELECT set_config('lock_timeout', :timeout, false)"), + {"timeout": str(lock_timeout)}, + ) + conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value}) + else: + conn.execute(text("SET LOCK_TIMEOUT to :timeout"), {"timeout": lock_timeout}) + conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value}) elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6): conn.execute(text("SELECT GET_LOCK(:id, :timeout)"), {"id": str(lock), "timeout": lock_timeout}) yield finally: if dialect.name == "postgresql": - conn.execute(text("SET LOCK_TIMEOUT TO DEFAULT")) + if USE_PSYCOPG3: + # Use set_config() to reset the timeout to its default (0 = off/wait forever). 
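+                # set_config(name, value, is_local): is_local=false gives the value
+                # session-wide scope, mirroring the SET statement used for psycopg2.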
+ conn.execute(text("SELECT set_config('lock_timeout', '0', false)")) + else: + conn.execute(text("SET LOCK_TIMEOUT TO DEFAULT")) (unlocked,) = conn.execute(text("SELECT pg_advisory_unlock(:id)"), {"id": lock.value}).fetchone() if not unlocked: raise RuntimeError("Error releasing DB lock!") diff --git a/airflow-core/tests/unit/always/test_connection.py b/airflow-core/tests/unit/always/test_connection.py index ef6fa2213c270..2ace72179e3ad 100644 --- a/airflow-core/tests/unit/always/test_connection.py +++ b/airflow-core/tests/unit/always/test_connection.py @@ -32,7 +32,7 @@ from airflow.models import Connection, crypto from airflow.sdk import BaseHook -from tests_common.test_utils.version_compat import SQLALCHEMY_V_1_4 +from tests_common.test_utils.version_compat import SQLALCHEMY_V_1_4, SQLALCHEMY_V_2_0 sqlite = pytest.importorskip("airflow.providers.sqlite.hooks.sqlite") @@ -679,10 +679,20 @@ def test_env_var_priority(self, mock_supervisor_comms): def test_dbapi_get_uri(self): conn = BaseHook.get_connection(conn_id="test_uri") hook = conn.get_hook() - assert hook.get_uri() == "postgresql://username:password@ec2.compute.com:5432/the_database" + + ppg3_mode: bool = SQLALCHEMY_V_2_0 and "psycopg" in hook.get_uri() + if ppg3_mode: + assert ( + hook.get_uri() == "postgresql+psycopg://username:password@ec2.compute.com:5432/the_database" + ) + else: + assert hook.get_uri() == "postgresql://username:password@ec2.compute.com:5432/the_database" conn2 = BaseHook.get_connection(conn_id="test_uri_no_creds") hook2 = conn2.get_hook() - assert hook2.get_uri() == "postgresql://ec2.compute.com/the_database" + if ppg3_mode: + assert hook2.get_uri() == "postgresql+psycopg://ec2.compute.com/the_database" + else: + assert hook2.get_uri() == "postgresql://ec2.compute.com/the_database" @mock.patch.dict( "os.environ", @@ -695,7 +705,12 @@ def test_dbapi_get_sqlalchemy_engine(self): conn = BaseHook.get_connection(conn_id="test_uri") hook = conn.get_hook() engine = hook.get_sqlalchemy_engine() - expected = "postgresql://username:password@ec2.compute.com:5432/the_database" + + if SQLALCHEMY_V_2_0 and "psycopg" in hook.get_uri(): + expected = "postgresql+psycopg://username:password@ec2.compute.com:5432/the_database" + else: + expected = "postgresql://username:password@ec2.compute.com:5432/the_database" + assert isinstance(engine, sqlalchemy.engine.Engine) if SQLALCHEMY_V_1_4: assert str(engine.url) == expected diff --git a/airflow-core/tests/unit/cli/commands/test_db_command.py b/airflow-core/tests/unit/cli/commands/test_db_command.py index eeac6aa16885d..e4cac8089aeeb 100644 --- a/airflow-core/tests/unit/cli/commands/test_db_command.py +++ b/airflow-core/tests/unit/cli/commands/test_db_command.py @@ -238,6 +238,27 @@ def test_cli_shell_postgres(self, mock_execute_interactive): "PGUSER": "postgres", } + @mock.patch("airflow.cli.commands.db_command.execute_interactive") + @mock.patch( + "airflow.cli.commands.db_command.settings.engine.url", + make_url("postgresql+psycopg://postgres:airflow@postgres:5432/airflow"), + ) + def test_cli_shell_postgres_ppg3(self, mock_execute_interactive): + pytest.importorskip("psycopg", reason="Test only runs when psycopg v3 is installed.") + + db_command.shell(self.parser.parse_args(["db", "shell"])) + mock_execute_interactive.assert_called_once_with(["psql"], env=mock.ANY) + _, kwargs = mock_execute_interactive.call_args + env = kwargs["env"] + postgres_env = {k: v for k, v in env.items() if k.startswith("PG")} + assert postgres_env == { + "PGDATABASE": "airflow", + "PGHOST": 
"postgres", + "PGPASSWORD": "airflow", + "PGPORT": "5432", + "PGUSER": "postgres", + } + @mock.patch("airflow.cli.commands.db_command.execute_interactive") @mock.patch( "airflow.cli.commands.db_command.settings.engine.url", @@ -257,6 +278,27 @@ def test_cli_shell_postgres_without_port(self, mock_execute_interactive): "PGUSER": "postgres", } + @mock.patch("airflow.cli.commands.db_command.execute_interactive") + @mock.patch( + "airflow.cli.commands.db_command.settings.engine.url", + make_url("postgresql+psycopg://postgres:airflow@postgres/airflow"), + ) + def test_cli_shell_postgres_without_port_ppg3(self, mock_execute_interactive): + pytest.importorskip("psycopg", reason="Test only runs when psycopg v3 is installed.") + + db_command.shell(self.parser.parse_args(["db", "shell"])) + mock_execute_interactive.assert_called_once_with(["psql"], env=mock.ANY) + _, kwargs = mock_execute_interactive.call_args + env = kwargs["env"] + postgres_env = {k: v for k, v in env.items() if k.startswith("PG")} + assert postgres_env == { + "PGDATABASE": "airflow", + "PGHOST": "postgres", + "PGPASSWORD": "airflow", + "PGPORT": "5432", + "PGUSER": "postgres", + } + @mock.patch( "airflow.cli.commands.db_command.settings.engine.url", make_url("invalid+psycopg2://postgres:airflow@postgres/airflow"), @@ -265,6 +307,16 @@ def test_cli_shell_invalid(self): with pytest.raises(AirflowException, match=r"Unknown driver: invalid\+psycopg2"): db_command.shell(self.parser.parse_args(["db", "shell"])) + @mock.patch( + "airflow.cli.commands.db_command.settings.engine.url", + make_url("invalid+psycopg://postgres:airflow@postgres/airflow"), + ) + def test_cli_shell_invalid_ppg3(self): + pytest.importorskip("psycopg", reason="Test only runs when psycopg v3 is installed.") + + with pytest.raises(AirflowException, match=r"Unknown driver: invalid\+psycopg"): + db_command.shell(self.parser.parse_args(["db", "shell"])) + @pytest.mark.parametrize( "args, match", [ diff --git a/devel-common/src/docs/utils/conf_constants.py b/devel-common/src/docs/utils/conf_constants.py index 89bd07b61e11f..5a5b0e7f672bb 100644 --- a/devel-common/src/docs/utils/conf_constants.py +++ b/devel-common/src/docs/utils/conf_constants.py @@ -248,6 +248,7 @@ def get_autodoc_mock_imports() -> list[str]: "pandas_gbq", "paramiko", "pinotdb", + "psycopg", "psycopg2", "pydruid", "pyhive", diff --git a/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py b/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py index 21c6a165ca743..8abda6bc014ba 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py +++ b/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py @@ -810,6 +810,9 @@ def _run_command(self, cur, sql_statement, parameters): self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters) if parameters: + # If we're using psycopg3, we might need to handle parameters differently + if hasattr(cur, "__module__") and "psycopg" in cur.__module__ and isinstance(parameters, list): + parameters = tuple(parameters) cur.execute(sql_statement, parameters) else: cur.execute(sql_statement) diff --git a/providers/common/sql/tests/unit/common/sql/sensors/test_sql.py b/providers/common/sql/tests/unit/common/sql/sensors/test_sql.py index b1d16ed1bef2c..93c1483ffdf7a 100644 --- a/providers/common/sql/tests/unit/common/sql/sensors/test_sql.py +++ b/providers/common/sql/tests/unit/common/sql/sensors/test_sql.py @@ -80,8 +80,8 @@ def test_sql_sensor_postgres(self): op2 = SqlSensor( 
task_id="sql_sensor_check_2", conn_id="postgres_default", - sql="SELECT count(%s) FROM INFORMATION_SCHEMA.TABLES", - parameters=["table_name"], + sql="SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = %s", + parameters=["information_schema"], dag=self.dag, ) op2.execute({}) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py index 72a526e8c3133..32b86cd9c0203 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py @@ -1175,9 +1175,9 @@ def cleanup_database_hook(self) -> None: raise ValueError("The db_hook should be set") if not isinstance(self.db_hook, PostgresHook): raise ValueError(f"The db_hook should be PostgresHook and is {type(self.db_hook)}") - conn = getattr(self.db_hook, "conn") - if conn and conn.notices: - for output in self.db_hook.conn.notices: + conn = getattr(self.db_hook, "conn", None) + if conn and hasattr(conn, "notices") and conn.notices: + for output in conn.notices: self.log.info(output) def reserve_free_tcp_port(self) -> None: diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/postgres_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/postgres_to_gcs.py index 9839552cd0dc6..d45d7937ff000 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/postgres_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/postgres_to_gcs.py @@ -31,7 +31,7 @@ from slugify import slugify from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator -from airflow.providers.postgres.hooks.postgres import PostgresHook +from airflow.providers.postgres.hooks.postgres import USE_PSYCOPG3, PostgresHook if TYPE_CHECKING: from airflow.providers.openlineage.extractors import OperatorLineage @@ -52,9 +52,20 @@ def __init__(self, cursor): self.initialized = False def __iter__(self): + """Make the cursor iterable.""" return self def __next__(self): + """Fetch next row from the cursor.""" + if USE_PSYCOPG3: + if self.rows: + return self.rows.pop() + self.initialized = True + row = self.cursor.fetchone() + if row is None: + raise StopIteration + return row + # psycopg2 if self.rows: return self.rows.pop() self.initialized = True @@ -141,13 +152,29 @@ def db_hook(self) -> PostgresHook: return PostgresHook(postgres_conn_id=self.postgres_conn_id) def query(self): - """Query Postgres and returns a cursor to the results.""" + """Execute the query and return a cursor.""" conn = self.db_hook.get_conn() - cursor = conn.cursor(name=self._unique_name()) - cursor.execute(self.sql, self.parameters) - if self.use_server_side_cursor: - cursor.itersize = self.cursor_itersize - return _PostgresServerSideCursorDecorator(cursor) + + if USE_PSYCOPG3: + from psycopg.types.json import register_default_adapters + + # Register JSON handlers for this connection if not already done + register_default_adapters(conn) + + if self.use_server_side_cursor: + cursor_name = f"airflow_{self.task_id.replace('-', '_')}_{uuid.uuid4().hex}"[:63] + cursor = conn.cursor(name=cursor_name) + cursor.itersize = self.cursor_itersize + cursor.execute(self.sql, self.parameters) + return _PostgresServerSideCursorDecorator(cursor) + cursor = conn.cursor() + cursor.execute(self.sql, self.parameters) + else: + cursor = conn.cursor(name=self._unique_name()) + cursor.execute(self.sql, self.parameters) + if 
self.use_server_side_cursor: + cursor.itersize = self.cursor_itersize + return _PostgresServerSideCursorDecorator(cursor) return cursor def field_to_bigquery(self, field) -> dict[str, str]: @@ -182,8 +209,14 @@ def convert_type(self, value, schema_type, stringify_dict=True): hours=formatted_time.tm_hour, minutes=formatted_time.tm_min, seconds=formatted_time.tm_sec ) return str(time_delta) - if stringify_dict and isinstance(value, dict): - return json.dumps(value) + if stringify_dict: + if USE_PSYCOPG3: + from psycopg.types.json import Json + + if isinstance(value, (dict, Json)): + return json.dumps(value) + elif isinstance(value, dict): + return json.dumps(value) if isinstance(value, Decimal): return float(value) return value diff --git a/providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py b/providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py index b7cd12d6ff9c8..cfd8255922c59 100644 --- a/providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py +++ b/providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py @@ -383,6 +383,17 @@ def call_get_conn(): "postgresql+psycopg2://login:password@host:1234/schema", id="sqlalchemy-scheme-with-driver", ), + pytest.param( + { + "conn_params": { + "extra": json.dumps( + {"sqlalchemy_scheme": "postgresql", "sqlalchemy_driver": "psycopg"} + ) + } + }, + "postgresql+psycopg://login:password@host:1234/schema", + id="sqlalchemy-scheme-with-driver-ppg3", + ), pytest.param( { "login": "user@domain", diff --git a/providers/postgres/docs/connections/postgres.rst b/providers/postgres/docs/connections/postgres.rst index 4daec3062490c..539620ad08cc9 100644 --- a/providers/postgres/docs/connections/postgres.rst +++ b/providers/postgres/docs/connections/postgres.rst @@ -75,6 +75,7 @@ Extra (optional) - ``namedtuplecursor``: Returns query results as named tuples using ``psycopg2.extras.NamedTupleCursor``. For more information, refer to the psycopg2 documentation on `connection and cursor subclasses `_. + If using psycopg (v3), refer to the documentation on `connection classes `_. More details on all Postgres parameters supported can be found in `Postgres documentation `_. 
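+
+For example, a dict-style cursor is requested the same way with either driver:
+the hook maps ``{"cursor": "dictcursor"}`` to ``psycopg2.extras.DictCursor`` under
+psycopg2, and to the ``psycopg.rows.dict_row`` row factory under psycopg (v3).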
diff --git a/providers/postgres/pyproject.toml b/providers/postgres/pyproject.toml index 56bc5c2002465..33ad216595bc1 100644 --- a/providers/postgres/pyproject.toml +++ b/providers/postgres/pyproject.toml @@ -80,6 +80,9 @@ dependencies = [ "polars" = [ "polars>=1.26.0" ] +"psycopg" = [ + "psycopg[binary]>=3.2.9", +] [dependency-groups] dev = [ @@ -92,6 +95,7 @@ dev = [ # Additional devel dependencies (do not remove this line and add extra development dependencies) "apache-airflow-providers-common-sql[pandas]", "apache-airflow-providers-common-sql[polars]", + "psycopg[binary]>=3.2.9", ] # To build docs: diff --git a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py index 2c3d0c61ce9a2..9c159deba7478 100644 --- a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py +++ b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py @@ -21,7 +21,7 @@ from collections.abc import Mapping from contextlib import closing from copy import deepcopy -from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast, overload +from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeAlias, cast, overload import psycopg2 import psycopg2.extensions @@ -36,20 +36,66 @@ from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.postgres.dialects.postgres import PostgresDialect +USE_PSYCOPG3: bool +try: + import psycopg as psycopg # needed for patching in unit tests + import sqlalchemy + from packaging.version import Version + + sqlalchemy_version = Version(sqlalchemy.__version__) + is_sqla2 = (sqlalchemy_version.major, sqlalchemy_version.minor, sqlalchemy_version.micro) >= (2, 0, 0) + + USE_PSYCOPG3 = is_sqla2 # implicitly includes `and bool(psycopg)` since the import above succeeded +except (ImportError, ModuleNotFoundError): + USE_PSYCOPG3 = False + +if USE_PSYCOPG3: + from psycopg.rows import dict_row, namedtuple_row + from psycopg.types.json import register_default_adapters + if TYPE_CHECKING: from pandas import DataFrame as PandasDataFrame from polars import DataFrame as PolarsDataFrame - from psycopg2.extensions import connection from airflow.providers.common.sql.dialects.dialect import Dialect from airflow.providers.openlineage.sqlparser import DatabaseInfo + if USE_PSYCOPG3: + from psycopg.errors import Diagnostic + try: from airflow.sdk import Connection except ImportError: from airflow.models.connection import Connection # type: ignore[assignment] CursorType: TypeAlias = DictCursor | RealDictCursor | NamedTupleCursor +CursorRow: TypeAlias = dict[str, Any] | tuple[Any, ...] + + +class CompatConnection(Protocol): + """Protocol for type hinting psycopg2 and psycopg3 connection objects.""" + + def cursor(self, *args, **kwargs) -> Any: ... + def commit(self) -> None: ... + def close(self) -> None: ... + + # Context manager support + def __enter__(self) -> CompatConnection: ... + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ... + + # Common properties + @property + def notices(self) -> list[Any]: ... + + # psycopg3 specific (optional) + @property + def adapters(self) -> Any: ... + + @property + def row_factory(self) -> Any: ... + + # Optional method for psycopg3 + def add_notice_handler(self, handler: Any) -> None: ... class PostgresHook(DbApiHook): @@ -58,8 +104,8 @@ class PostgresHook(DbApiHook): You can specify ssl parameters in the extra field of your connection as ``{"sslmode": "require", "sslcert": "/path/to/cert.pem", etc}``. 
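+    When psycopg (v3) is installed alongside SQLAlchemy 2.x the hook uses it
+    automatically, and falls back to psycopg2 otherwise.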
- Also you can choose cursor as ``{"cursor": "dictcursor"}``. Refer to the - psycopg2.extras for more details. + Also, you can choose cursor as ``{"cursor": "dictcursor"}``. Refer to the + psycopg2.extras or psycopg.rows for more details. Note: For Redshift, use keepalives_idle in the extra connection parameters and set it to less than 300 seconds. @@ -93,6 +139,8 @@ class PostgresHook(DbApiHook): conn_name_attr = "postgres_conn_id" default_conn_name = "postgres_default" + default_client_log_level = "warning" + default_connector_version: int = 2 conn_type = "postgres" hook_name = "Postgres" supports_autocommit = True @@ -113,11 +161,15 @@ def __init__( self, *args, options: str | None = None, enable_log_db_messages: bool = False, **kwargs ) -> None: super().__init__(*args, **kwargs) - self.conn: connection = None + self.conn: CompatConnection | None = None self.database: str | None = kwargs.pop("database", None) self.options = options self.enable_log_db_messages = enable_log_db_messages + @staticmethod + def __cast_nullable(value, dst_type: type) -> Any: + return dst_type(value) if value is not None else None + @property def sqlalchemy_url(self) -> URL: conn = self.connection @@ -125,12 +177,12 @@ def sqlalchemy_url(self) -> URL: if not isinstance(query, dict): raise AirflowException("The parameter 'sqlalchemy_query' must be of type dict!") return URL.create( - drivername="postgresql", - username=conn.login, - password=conn.password, - host=conn.host, - port=conn.port, - database=self.database or conn.schema, + drivername="postgresql+psycopg" if USE_PSYCOPG3 else "postgresql", + username=self.__cast_nullable(conn.login, str), + password=self.__cast_nullable(conn.password, str), + host=self.__cast_nullable(conn.host, str), + port=self.__cast_nullable(conn.port, int), + database=self.__cast_nullable(self.database, str) or self.__cast_nullable(conn.schema, str), query=query, ) @@ -142,8 +194,24 @@ def dialect_name(self) -> str: def dialect(self) -> Dialect: return PostgresDialect(self) + def _notice_handler(self, notice: Diagnostic): + """Handle notices from the database and log them.""" + self.log.info(str(notice.message_primary).strip()) + def _get_cursor(self, raw_cursor: str) -> CursorType: _cursor = raw_cursor.lower() + if USE_PSYCOPG3: + if _cursor == "dictcursor": + return dict_row + if _cursor == "namedtuplecursor": + return namedtuple_row + if _cursor == "realdictcursor": + raise AirflowException( + "realdictcursor is not supported with psycopg3. Use dictcursor instead." + ) + valid_cursors = "dictcursor, namedtuplecursor" + raise ValueError(f"Invalid cursor passed {_cursor}. Valid options are: {valid_cursors}") + cursor_types = { "dictcursor": psycopg2.extras.DictCursor, "realdictcursor": psycopg2.extras.RealDictCursor, @@ -154,33 +222,63 @@ def _get_cursor(self, raw_cursor: str) -> CursorType: valid_cursors = ", ".join(cursor_types.keys()) raise ValueError(f"Invalid cursor passed {_cursor}. 
Valid options are: {valid_cursors}") - def get_conn(self) -> connection: + def _generate_cursor_name(self): + """Generate a unique name for server-side cursor.""" + import uuid + + return f"airflow_cursor_{uuid.uuid4().hex}" + + def get_conn(self) -> CompatConnection: """Establish a connection to a postgres database.""" conn = deepcopy(self.connection) - # check for authentication via AWS IAM if conn.extra_dejson.get("iam", False): - conn.login, conn.password, conn.port = self.get_iam_token(conn) + login, password, port = self.get_iam_token(conn) + conn.login = cast("Any", login) + conn.password = cast("Any", password) + conn.port = cast("Any", port) - conn_args = { + conn_args: dict[str, Any] = { "host": conn.host, "user": conn.login, "password": conn.password, "dbname": self.database or conn.schema, "port": conn.port, } - raw_cursor = conn.extra_dejson.get("cursor", False) - if raw_cursor: - conn_args["cursor_factory"] = self._get_cursor(raw_cursor) if self.options: conn_args["options"] = self.options + # Add extra connection arguments for arg_name, arg_val in conn.extra_dejson.items(): if arg_name not in self.ignored_extra_options: conn_args[arg_name] = arg_val - self.conn = psycopg2.connect(**conn_args) + if USE_PSYCOPG3: + from psycopg.connection import Connection as pgConnection + + raw_cursor = conn.extra_dejson.get("cursor") + if raw_cursor: + conn_args["row_factory"] = self._get_cursor(raw_cursor) + + # Use Any type for the connection args to avoid type conflicts + connection = pgConnection.connect(**cast("Any", conn_args)) + self.conn = cast("CompatConnection", connection) + + # Register JSON handlers for both json and jsonb types + # This ensures JSON data is properly decoded from bytes to Python objects + register_default_adapters(connection) + + # Add the notice handler AFTER the connection is established + if self.enable_log_db_messages and hasattr(self.conn, "add_notice_handler"): + self.conn.add_notice_handler(self._notice_handler) + else: # psycopg2 + raw_cursor = conn.extra_dejson.get("cursor", False) + if raw_cursor: + conn_args["cursor_factory"] = self._get_cursor(raw_cursor) + + self.conn = cast("CompatConnection", psycopg2.connect(**conn_args)) + return self.conn @overload @@ -231,7 +329,9 @@ def get_df( engine = self.get_sqlalchemy_engine() with engine.connect() as conn: - return psql.read_sql(sql, con=conn, params=parameters, **kwargs) + if isinstance(sql, list): + sql = "; ".join(sql) # Or handle multiple queries differently + return cast("PandasDataFrame", psql.read_sql(sql, con=conn, params=parameters, **kwargs)) elif df_type == "polars": return self._get_polars_df(sql, parameters, **kwargs) @@ -241,7 +341,7 @@ def get_df( def copy_expert(self, sql: str, filename: str) -> None: """ - Execute SQL using psycopg2's ``copy_expert`` method. + Execute SQL using psycopg's ``copy_expert`` method. Necessary to execute COPY command without access to a superuser. @@ -252,14 +352,38 @@ def copy_expert(self, sql: str, filename: str) -> None: they have to check its existence by themselves. """ self.log.info("Running copy expert: %s, filename: %s", sql, filename) - if not os.path.isfile(filename): - with open(filename, "w"): - pass - - with open(filename, "r+") as file, closing(self.get_conn()) as conn, closing(conn.cursor()) as cur: - cur.copy_expert(sql, file) - file.truncate(file.tell()) - conn.commit() + if USE_PSYCOPG3: + if " from stdin" in sql.lower(): + # Handle COPY FROM STDIN: read from the file and write to the database. 
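+                # psycopg (v3) replaces psycopg2's copy_expert() with cursor.copy(),
+                # a context manager whose write() streams raw bytes to the server.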
+ if not os.path.isfile(filename): + with open(filename, "w"): + pass # Create an empty file to prevent errors. + + with open(filename, "rb") as file, self.get_conn() as conn, conn.cursor() as cur: + with cur.copy(sql) as copy: + while data := file.read(8192): + copy.write(data) + conn.commit() + else: + # Handle COPY TO STDOUT: read from the database and write to the file. + with open(filename, "wb") as file, self.get_conn() as conn, conn.cursor() as cur: + with cur.copy(sql) as copy: + for data in copy: + file.write(data) + conn.commit() + else: + if not os.path.isfile(filename): + with open(filename, "w"): + pass + + with ( + open(filename, "r+") as file, + closing(self.get_conn()) as conn, + closing(conn.cursor()) as cur, + ): + cur.copy_expert(sql, file) + file.truncate(file.tell()) + conn.commit() def get_uri(self) -> str: """ @@ -278,9 +402,9 @@ def bulk_dump(self, table: str, tmp_file: str) -> None: self.copy_expert(f"COPY {table} TO STDOUT", tmp_file) @staticmethod - def _serialize_cell(cell: object, conn: connection | None = None) -> Any: + def _serialize_cell_ppg2(cell: object, conn: CompatConnection | None = None) -> Any: """ - Serialize a cell. + Serialize a cell using psycopg2. Psycopg2 adapts all arguments to the ``execute()`` method internally, hence we return the cell without any conversion. @@ -297,6 +421,24 @@ def _serialize_cell(cell: object, conn: connection | None = None) -> Any: """ return cell + @staticmethod + def _serialize_cell_ppg3(cell: object, conn: CompatConnection | None = None) -> Any: + """Serialize a cell using psycopg3.""" + if isinstance(cell, (dict, list)): + try: + from psycopg.types.json import Json + + return Json(cell) + except ImportError: + return cell + return cell + + @staticmethod + def _serialize_cell(cell: object, conn: Any | None = None) -> Any: + if USE_PSYCOPG3: + return PostgresHook._serialize_cell_ppg3(cell, conn) + return PostgresHook._serialize_cell_ppg2(cell, conn) + def get_iam_token(self, conn: Connection) -> tuple[str, str, int]: """ Get the IAM token. @@ -420,11 +562,15 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]: } def get_db_log_messages(self, conn) -> None: - """ - Log all database messages sent to the client during the session. + """Log database messages.""" + if not self.enable_log_db_messages: + return - :param conn: Connection object - """ - if self.enable_log_db_messages: - for output in conn.notices: - self.log.info(output) + if USE_PSYCOPG3: + self.log.debug( + "With psycopg3, database notices are logged upon creation (via self._notice_handler)." 
+ ) + return + + for output in conn.notices: + self.log.info(output) diff --git a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py index be332a7bda507..5a842666a5a2c 100644 --- a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py +++ b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py @@ -18,13 +18,12 @@ from __future__ import annotations import json -import logging import os +from types import SimpleNamespace from unittest import mock import pandas as pd import polars as pl -import psycopg2.extras import pytest import sqlalchemy @@ -40,7 +39,68 @@ INSERT_SQL_STATEMENT = "INSERT INTO connection (id, conn_id, conn_type, description, host, {}, login, password, port, is_encrypted, is_extra_encrypted, extra) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)" +USE_PSYCOPG3: bool +try: + from importlib.util import find_spec + + import sqlalchemy + from packaging.version import Version + + is_psycopg3 = find_spec("psycopg") is not None + sqlalchemy_version = Version(sqlalchemy.__version__) + is_sqla2 = (sqlalchemy_version.major, sqlalchemy_version.minor, sqlalchemy_version.micro) >= (2, 0, 0) + + USE_PSYCOPG3 = is_psycopg3 and is_sqla2 +except (ImportError, ModuleNotFoundError): + USE_PSYCOPG3 = False + +if USE_PSYCOPG3: + import psycopg.rows +else: + import psycopg2.extras + + +@pytest.fixture +def postgres_hook_setup(): + """Set up mock PostgresHook for testing.""" + table = "test_postgres_hook_table" + cur = mock.MagicMock(rowcount=0) + conn = mock.MagicMock() + conn.cursor.return_value = cur + + class UnitTestPostgresHook(PostgresHook): + conn_name_attr = "test_conn_id" + + def get_conn(self): + return conn + + db_hook = UnitTestPostgresHook() + + # Return a namespace with all the objects + setup = SimpleNamespace(table=table, cur=cur, conn=conn, db_hook=db_hook) + + yield setup + + # Teardown - only for real database tests + try: + with PostgresHook().get_conn() as real_conn: + with real_conn.cursor() as real_cur: + real_cur.execute(f"DROP TABLE IF EXISTS {table}") + except Exception: + pass # Ignore cleanup errors for unit tests + + +@pytest.fixture +def mock_connect(mocker): + """Mock the connection object according to the correct psycopg version.""" + if USE_PSYCOPG3: + return mocker.patch("airflow.providers.postgres.hooks.postgres.psycopg.connection.Connection.connect") + return mocker.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") + + class TestPostgresHookConn: + """PostgresHookConn tests that are common to psycopg2 and psycopg3.""" + def setup_method(self): self.connection = Connection(login="login", password="password", host="host", schema="database") @@ -51,55 +111,6 @@ class UnitTestPostgresHook(PostgresHook): self.db_hook.get_connection = mock.Mock() self.db_hook.get_connection.return_value = self.connection - @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") - def test_get_conn_non_default_id(self, mock_connect): - self.db_hook.test_conn_id = "non_default" - self.db_hook.get_conn() - mock_connect.assert_called_once_with( - user="login", password="password", host="host", dbname="database", port=None - ) - self.db_hook.get_connection.assert_called_once_with("non_default") - - @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") - def test_get_conn(self, mock_connect): - self.db_hook.get_conn() - mock_connect.assert_called_once_with( - user="login", password="password", host="host", dbname="database", port=None - ) - - 
@mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") - def test_get_uri(self, mock_connect): - self.connection.conn_type = "postgres" - self.connection.port = 5432 - self.db_hook.get_conn() - assert mock_connect.call_count == 1 - assert self.db_hook.get_uri() == "postgresql://login:password@host:5432/database" - - def test_sqlalchemy_url(self): - conn = Connection(login="login-conn", password="password-conn", host="host", schema="database") - hook = PostgresHook(connection=conn) - expected = "postgresql://login-conn:password-conn@host/database" - if SQLALCHEMY_V_1_4: - assert str(hook.sqlalchemy_url) == expected - else: - assert hook.sqlalchemy_url.render_as_string(hide_password=False) == expected - - def test_sqlalchemy_url_with_sqlalchemy_query(self): - conn = Connection( - login="login-conn", - password="password-conn", - host="host", - schema="database", - extra=dict(sqlalchemy_query={"gssencmode": "disable"}), - ) - hook = PostgresHook(connection=conn) - - expected = "postgresql://login-conn:password-conn@host/database?gssencmode=disable" - if SQLALCHEMY_V_1_4: - assert str(hook.sqlalchemy_url) == expected - else: - assert hook.sqlalchemy_url.render_as_string(hide_password=False) == expected - def test_sqlalchemy_url_with_wrong_sqlalchemy_query_value(self): conn = Connection( login="login-conn", @@ -113,12 +124,71 @@ def test_sqlalchemy_url_with_wrong_sqlalchemy_query_value(self): with pytest.raises(AirflowException): hook.sqlalchemy_url - @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") - def test_get_conn_cursor(self, mock_connect): - self.connection.extra = '{"cursor": "dictcursor", "sqlalchemy_query": {"gssencmode": "disable"}}' + @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"]) + @pytest.mark.parametrize("port", [5432, 5439, None]) + @pytest.mark.parametrize( + "host,conn_cluster_identifier,expected_host", + [ + ( + "cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com", + NOTSET, + "cluster-identifier.us-east-1", + ), + ( + "cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com", + "different-identifier", + "different-identifier.us-east-1", + ), + ], + ) + def test_openlineage_methods_with_redshift( + self, + mocker, + aws_conn_id, + port, + host, + conn_cluster_identifier, + expected_host, + ): + mock_aws_hook_class = mocker.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook") + + mock_conn_extra = { + "iam": True, + "redshift": True, + } + if aws_conn_id is not NOTSET: + mock_conn_extra["aws_conn_id"] = aws_conn_id + if conn_cluster_identifier is not NOTSET: + mock_conn_extra["cluster-identifier"] = conn_cluster_identifier + + self.connection.extra = json.dumps(mock_conn_extra) + self.connection.host = host + self.connection.port = port + + # Mock AWS Connection + mock_aws_hook_instance = mock_aws_hook_class.return_value + mock_aws_hook_instance.region_name = "us-east-1" + + assert ( + self.db_hook._get_openlineage_redshift_authority_part(self.connection) + == f"{expected_host}:{port or 5439}" + ) + + def test_get_conn_non_default_id(self, mock_connect): + self.db_hook.test_conn_id = "non_default" + self.db_hook.get_conn() + mock_connect.assert_called_once_with( + user="login", + password="password", + host="host", + dbname="database", + port=None, + ) + self.db_hook.get_connection.assert_called_once_with("non_default") + + def test_get_conn(self, mock_connect): self.db_hook.get_conn() mock_connect.assert_called_once_with( - cursor_factory=psycopg2.extras.DictCursor, 
user="login", password="password", host="host", @@ -126,31 +196,46 @@ def test_get_conn_cursor(self, mock_connect): port=None, ) - @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") - def test_get_conn_with_invalid_cursor(self, mock_connect): + def test_get_uri(self, mock_connect): + self.connection.conn_type = "postgres" + self.connection.port = 5432 + self.db_hook.get_conn() + assert mock_connect.call_count == 1 + assert ( + self.db_hook.get_uri() + == f"postgresql{'+psycopg' if USE_PSYCOPG3 else ''}://login:password@host:5432/database" + ) + + @pytest.mark.usefixtures("mock_connect") + def test_get_conn_with_invalid_cursor(self): self.connection.extra = '{"cursor": "mycursor"}' with pytest.raises(ValueError): self.db_hook.get_conn() - @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") def test_get_conn_from_connection(self, mock_connect): conn = Connection(login="login-conn", password="password-conn", host="host", schema="database") hook = PostgresHook(connection=conn) hook.get_conn() mock_connect.assert_called_once_with( - user="login-conn", password="password-conn", host="host", dbname="database", port=None + user="login-conn", + password="password-conn", + host="host", + dbname="database", + port=None, ) - @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") def test_get_conn_from_connection_with_database(self, mock_connect): conn = Connection(login="login-conn", password="password-conn", host="host", schema="database") hook = PostgresHook(connection=conn, database="database-override") hook.get_conn() mock_connect.assert_called_once_with( - user="login-conn", password="password-conn", host="host", dbname="database-override", port=None + user="login-conn", + password="password-conn", + host="host", + dbname="database-override", + port=None, ) - @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") def test_get_conn_from_connection_with_options(self, mock_connect): conn = Connection(login="login-conn", password="password-conn", host="host", schema="database") hook = PostgresHook(connection=conn, options="-c statement_timeout=3000ms") @@ -164,11 +249,11 @@ def test_get_conn_from_connection_with_options(self, mock_connect): options="-c statement_timeout=3000ms", ) - @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") - @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook") @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"]) @pytest.mark.parametrize("port", [65432, 5432, None]) - def test_get_conn_rds_iam_postgres(self, mock_aws_hook_class, mock_connect, aws_conn_id, port): + def test_get_conn_rds_iam_postgres(self, mocker, mock_connect, aws_conn_id, port): + mock_aws_hook_class = mocker.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook") + mock_conn_extra = {"iam": True} if aws_conn_id is not NOTSET: mock_conn_extra["aws_conn_id"] = aws_conn_id @@ -178,9 +263,9 @@ def test_get_conn_rds_iam_postgres(self, mock_aws_hook_class, mock_connect, aws_ # Mock AWS Connection mock_aws_hook_instance = mock_aws_hook_class.return_value - mock_client = mock.MagicMock() + mock_client = mocker.MagicMock() mock_client.generate_db_auth_token.return_value = mock_db_token - type(mock_aws_hook_instance).conn = mock.PropertyMock(return_value=mock_client) + type(mock_aws_hook_instance).conn = mocker.PropertyMock(return_value=mock_client) self.db_hook.get_conn() # Check AwsHook initialization @@ -202,7 +287,6 @@ def test_get_conn_rds_iam_postgres(self, 
mock_aws_hook_class, mock_connect, aws_ port=(port or 5432), ) - @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") def test_get_conn_extra(self, mock_connect): self.connection.extra = '{"connect_timeout": 3}' self.db_hook.get_conn() @@ -210,8 +294,6 @@ def test_get_conn_extra(self, mock_connect): user="login", password="password", host="host", dbname="database", port=None, connect_timeout=3 ) - @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") - @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook") @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"]) @pytest.mark.parametrize("port", [5432, 5439, None]) @pytest.mark.parametrize( @@ -231,7 +313,7 @@ def test_get_conn_extra(self, mock_connect): ) def test_get_conn_rds_iam_redshift( self, - mock_aws_hook_class, + mocker, mock_connect, aws_conn_id, port, @@ -239,6 +321,8 @@ def test_get_conn_rds_iam_redshift( conn_cluster_identifier, expected_cluster_identifier, ): + mock_aws_hook_class = mocker.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook") + mock_conn_extra = { "iam": True, "redshift": True, @@ -256,12 +340,12 @@ def test_get_conn_rds_iam_redshift( # Mock AWS Connection mock_aws_hook_instance = mock_aws_hook_class.return_value - mock_client = mock.MagicMock() + mock_client = mocker.MagicMock() mock_client.get_cluster_credentials.return_value = { "DbPassword": mock_db_pass, "DbUser": mock_db_user, } - type(mock_aws_hook_instance).conn = mock.PropertyMock(return_value=mock_client) + type(mock_aws_hook_instance).conn = mocker.PropertyMock(return_value=mock_client) self.db_hook.get_conn() # Check AwsHook initialization @@ -286,8 +370,6 @@ def test_get_conn_rds_iam_redshift( port=(port or 5439), ) - @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") - @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook") @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"]) @pytest.mark.parametrize("port", [5432, 5439, None]) @pytest.mark.parametrize( @@ -307,7 +389,7 @@ def test_get_conn_rds_iam_redshift( ) def test_get_conn_rds_iam_redshift_serverless( self, - mock_aws_hook_class, + mocker, mock_connect, aws_conn_id, port, @@ -315,6 +397,8 @@ def test_get_conn_rds_iam_redshift_serverless( conn_workgroup_name, expected_workgroup_name, ): + mock_aws_hook_class = mocker.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook") + mock_conn_extra = { "iam": True, "redshift-serverless": True, @@ -332,12 +416,12 @@ def test_get_conn_rds_iam_redshift_serverless( # Mock AWS Connection mock_aws_hook_instance = mock_aws_hook_class.return_value - mock_client = mock.MagicMock() + mock_client = mocker.MagicMock() mock_client.get_credentials.return_value = { "dbPassword": mock_db_pass, "dbUser": mock_db_user, } - type(mock_aws_hook_instance).conn = mock.PropertyMock(return_value=mock_client) + type(mock_aws_hook_instance).conn = mocker.PropertyMock(return_value=mock_client) self.db_hook.get_conn() # Check AwsHook initialization @@ -360,8 +444,9 @@ def test_get_conn_rds_iam_redshift_serverless( port=(port or 5439), ) - def test_get_uri_from_connection_without_database_override(self): - self.db_hook.get_connection = mock.MagicMock( + def test_get_uri_from_connection_without_database_override(self, mocker): + expected: str = f"postgresql{'+psycopg' if USE_PSYCOPG3 else ''}://login:password@host:1/database" + self.db_hook.get_connection = mocker.MagicMock( return_value=Connection( conn_type="postgres", host="host", 
@@ -371,11 +456,14 @@ def test_get_uri_from_connection_without_database_override(self): port=1, ) ) - assert self.db_hook.get_uri() == "postgresql://login:password@host:1/database" + assert self.db_hook.get_uri() == expected - def test_get_uri_from_connection_with_database_override(self): + def test_get_uri_from_connection_with_database_override(self, mocker): + expected: str = ( + f"postgresql{'+psycopg' if USE_PSYCOPG3 else ''}://login:password@host:1/database-override" + ) hook = PostgresHook(database="database-override") - hook.get_connection = mock.MagicMock( + hook.get_connection = mocker.MagicMock( return_value=Connection( conn_type="postgres", host="host", @@ -385,109 +473,151 @@ def test_get_uri_from_connection_with_database_override(self): port=1, ) ) - assert hook.get_uri() == "postgresql://login:password@host:1/database-override" - - @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook") - @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"]) - @pytest.mark.parametrize("port", [5432, 5439, None]) - @pytest.mark.parametrize( - "host,conn_cluster_identifier,expected_host", - [ - ( - "cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com", - NOTSET, - "cluster-identifier.us-east-1", - ), - ( - "cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com", - "different-identifier", - "different-identifier.us-east-1", - ), - ], - ) - def test_openlineage_methods_with_redshift( - self, - mock_aws_hook_class, - aws_conn_id, - port, - host, - conn_cluster_identifier, - expected_host, - ): - mock_conn_extra = { - "iam": True, - "redshift": True, - } - if aws_conn_id is not NOTSET: - mock_conn_extra["aws_conn_id"] = aws_conn_id - if conn_cluster_identifier is not NOTSET: - mock_conn_extra["cluster-identifier"] = conn_cluster_identifier - - self.connection.extra = json.dumps(mock_conn_extra) - self.connection.host = host - self.connection.port = port - - # Mock AWS Connection - mock_aws_hook_instance = mock_aws_hook_class.return_value - mock_aws_hook_instance.region_name = "us-east-1" - - assert ( - self.db_hook._get_openlineage_redshift_authority_part(self.connection) - == f"{expected_host}:{port or 5439}" - ) + assert hook.get_uri() == expected -@pytest.mark.backend("postgres") -class TestPostgresHook: - table = "test_postgres_hook_table" +@pytest.mark.skipif(USE_PSYCOPG3, reason="psycopg v3 is available") +class TestPostgresHookConnPPG2: + """PostgresHookConn tests that are specific to psycopg2.""" def setup_method(self): - self.cur = mock.MagicMock(rowcount=0) - self.conn = conn = mock.MagicMock() - self.conn.cursor.return_value = self.cur + self.connection = Connection(login="login", password="password", host="host", schema="database") class UnitTestPostgresHook(PostgresHook): conn_name_attr = "test_conn_id" - def get_conn(self): - return conn - self.db_hook = UnitTestPostgresHook() + self.db_hook.get_connection = mock.Mock() + self.db_hook.get_connection.return_value = self.connection - def teardown_method(self): - with PostgresHook().get_conn() as conn: - with conn.cursor() as cur: - cur.execute(f"DROP TABLE IF EXISTS {self.table}") + def test_sqlalchemy_url(self): + conn = Connection(login="login-conn", password="password-conn", host="host", schema="database") + hook = PostgresHook(connection=conn) + expected = "postgresql://login-conn:password-conn@host/database" + if SQLALCHEMY_V_1_4: + assert str(hook.sqlalchemy_url) == expected + else: + assert hook.sqlalchemy_url.render_as_string(hide_password=False) == expected - def 
test_copy_expert(self): - open_mock = mock.mock_open(read_data='{"some": "json"}') - with mock.patch("airflow.providers.postgres.hooks.postgres.open", open_mock): - statement = "SQL" - filename = "filename" + def test_sqlalchemy_url_with_sqlalchemy_query(self): + conn = Connection( + login="login-conn", + password="password-conn", + host="host", + schema="database", + extra=dict(sqlalchemy_query={"gssencmode": "disable"}), + ) + hook = PostgresHook(connection=conn) - self.cur.fetchall.return_value = None + expected = "postgresql://login-conn:password-conn@host/database?gssencmode=disable" + if SQLALCHEMY_V_1_4: + assert str(hook.sqlalchemy_url) == expected + else: + assert hook.sqlalchemy_url.render_as_string(hide_password=False) == expected - assert self.db_hook.copy_expert(statement, filename) is None + def test_get_conn_cursor(self, mock_connect): + self.connection.extra = '{"cursor": "dictcursor", "sqlalchemy_query": {"gssencmode": "disable"}}' + self.db_hook.get_conn() + mock_connect.assert_called_once_with( + cursor_factory=psycopg2.extras.DictCursor, + user="login", + password="password", + host="host", + dbname="database", + port=None, + ) - assert self.conn.close.call_count == 1 - assert self.cur.close.call_count == 1 - assert self.conn.commit.call_count == 1 - self.cur.copy_expert.assert_called_once_with(statement, open_mock.return_value) - assert open_mock.call_args.args == (filename, "r+") - def test_bulk_load(self, tmp_path): - hook = PostgresHook() - input_data = ["foo", "bar", "baz"] +@pytest.mark.skipif(not USE_PSYCOPG3, reason="psycopg v3 or sqlalchemy v2 not available") +class TestPostgresHookConnPPG3: + """PostgresHookConn tests that are specific to psycopg3.""" - with hook.get_conn() as conn, conn.cursor() as cur: - cur.execute(f"CREATE TABLE {self.table} (c VARCHAR)") - conn.commit() + def setup_method(self): + self.connection = Connection(login="login", password="password", host="host", schema="database") - path = tmp_path / "testfile" - path.write_text("\n".join(input_data)) - hook.bulk_load(self.table, os.fspath(path)) + class UnitTestPostgresHook(PostgresHook): + conn_name_attr = "test_conn_id" - cur.execute(f"SELECT * FROM {self.table}") + self.db_hook = UnitTestPostgresHook() + self.db_hook.get_connection = mock.Mock() + self.db_hook.get_connection.return_value = self.connection + + def test_sqlalchemy_url(self): + conn = Connection(login="login-conn", password="password-conn", host="host", schema="database") + hook = PostgresHook(connection=conn) + expected = "postgresql+psycopg://login-conn:password-conn@host/database" + if SQLALCHEMY_V_1_4: + assert str(hook.sqlalchemy_url) == expected + else: + assert hook.sqlalchemy_url.render_as_string(hide_password=False) == expected + + def test_sqlalchemy_url_with_sqlalchemy_query(self): + conn = Connection( + login="login-conn", + password="password-conn", + host="host", + schema="database", + extra=dict(sqlalchemy_query={"gssencmode": "disable"}), + ) + hook = PostgresHook(connection=conn) + + expected = "postgresql+psycopg://login-conn:password-conn@host/database?gssencmode=disable" + if SQLALCHEMY_V_1_4: + assert str(hook.sqlalchemy_url) == expected + else: + assert hook.sqlalchemy_url.render_as_string(hide_password=False) == expected + + def test_get_conn_cursor(self, mocker): + mock_connect = mocker.patch("psycopg.connection.Connection.connect") + self.connection.extra = '{"cursor": "dictcursor", "sqlalchemy_query": {"gssencmode": "disable"}}' + self.db_hook.get_conn() + mock_connect.assert_called_once_with( + 
row_factory=psycopg.rows.dict_row, + user="login", + password="password", + host="host", + dbname="database", + port=None, + ) + + +@pytest.mark.backend("postgres") +class TestPostgresHook: + """Tests that are identical between psycopg2 and psycopg3.""" + + table = "test_postgres_hook_table" + + def setup_method(self): + self.cur = mock.MagicMock(rowcount=0) + self.conn = conn = mock.MagicMock() + self.conn.cursor.return_value = self.cur + + class UnitTestPostgresHook(PostgresHook): + conn_name_attr = "test_conn_id" + + def get_conn(self): + return conn + + self.db_hook = UnitTestPostgresHook() + + def teardown_method(self): + with PostgresHook().get_conn() as conn: + with conn.cursor() as cur: + cur.execute(f"DROP TABLE IF EXISTS {self.table}") + + def test_bulk_load(self, tmp_path): + hook = PostgresHook() + input_data = ["foo", "bar", "baz"] + + with hook.get_conn() as conn, conn.cursor() as cur: + cur.execute(f"CREATE TABLE {self.table} (c VARCHAR)") + conn.commit() + + path = tmp_path / "testfile" + path.write_text("\n".join(input_data)) + hook.bulk_load(self.table, os.fspath(path)) + + cur.execute(f"SELECT * FROM {self.table}") results = [row[0] for row in cur.fetchall()] assert sorted(input_data) == sorted(results) @@ -515,19 +645,20 @@ def test_bulk_dump(self, tmp_path): ("polars", pl.DataFrame), ], ) - @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook._get_polars_df") - @mock.patch("pandas.io.sql.read_sql") - @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.get_sqlalchemy_engine") - def test_get_df_with_df_type( - self, mock_get_engine, mock_read_sql, mock_polars_df, df_type, expected_type - ): + def test_get_df_with_df_type(self, mocker, df_type, expected_type): + mock_polars_df = mocker.patch("airflow.providers.postgres.hooks.postgres.PostgresHook._get_polars_df") + mock_read_sql = mocker.patch("pandas.io.sql.read_sql") + mock_get_engine = mocker.patch( + "airflow.providers.postgres.hooks.postgres.PostgresHook.get_sqlalchemy_engine" + ) + hook = mock_db_hook(PostgresHook) mock_read_sql.return_value = pd.DataFrame() mock_polars_df.return_value = pl.DataFrame() sql = "SELECT * FROM table" if df_type == "pandas": - mock_conn = mock.MagicMock() - mock_engine = mock.MagicMock() + mock_conn = mocker.MagicMock() + mock_engine = mocker.MagicMock() mock_engine.connect.return_value.__enter__.return_value = mock_conn mock_get_engine.return_value = mock_engine df = hook.get_df(sql, df_type="pandas") @@ -538,6 +669,402 @@ def test_get_df_with_df_type( mock_polars_df.assert_called_once_with(sql, None) assert isinstance(df, expected_type) + def test_rowcount(self): + hook = PostgresHook() + input_data = ["foo", "bar", "baz"] + + with hook.get_conn() as conn: + with conn.cursor() as cur: + cur.execute(f"CREATE TABLE {self.table} (c VARCHAR)") + values = ",".join(f"('{data}')" for data in input_data) + cur.execute(f"INSERT INTO {self.table} VALUES {values}") + conn.commit() + assert cur.rowcount == len(input_data) + + def test_reserved_words(self): + hook = PostgresHook() + assert hook.reserved_words == sqlalchemy.dialects.postgresql.base.RESERVED_WORDS + + def test_generate_insert_sql_without_already_escaped_column_name(self): + values = [ + "1", + "mssql_conn", + "mssql", + "MSSQL connection", + "localhost", + "airflow", + "admin", + "admin", + 1433, + False, + False, + {}, + ] + target_fields = [ + "id", + "conn_id", + "conn_type", + "description", + "host", + "schema", + "login", + "password", + "port", + "is_encrypted", + "is_extra_encrypted", + "extra", 
+        ]
+        hook = PostgresHook()
+        assert hook._generate_insert_sql(
+            table="connection", values=values, target_fields=target_fields
+        ) == INSERT_SQL_STATEMENT.format("schema")
+
+    def test_generate_insert_sql_with_already_escaped_column_name(self):
+        values = [
+            "1",
+            "mssql_conn",
+            "mssql",
+            "MSSQL connection",
+            "localhost",
+            "airflow",
+            "admin",
+            "admin",
+            1433,
+            False,
+            False,
+            {},
+        ]
+        target_fields = [
+            "id",
+            "conn_id",
+            "conn_type",
+            "description",
+            "host",
+            '"schema"',
+            "login",
+            "password",
+            "port",
+            "is_encrypted",
+            "is_extra_encrypted",
+            "extra",
+        ]
+        hook = PostgresHook()
+        assert hook._generate_insert_sql(
+            table="connection", values=values, target_fields=target_fields
+        ) == INSERT_SQL_STATEMENT.format('"schema"')
+
+
+@pytest.mark.backend("postgres")
+@pytest.mark.skipif(USE_PSYCOPG3, reason="psycopg v3 is available")
+class TestPostgresHookPPG2:
+    """PostgresHook tests that are specific to psycopg2."""
+
+    table = "test_postgres_hook_table"
+
+    def setup_method(self):
+        self.cur = mock.MagicMock(rowcount=0)
+        self.conn = conn = mock.MagicMock()
+        self.conn.cursor.return_value = self.cur
+
+        class UnitTestPostgresHook(PostgresHook):
+            conn_name_attr = "test_conn_id"
+
+            def get_conn(self):
+                return conn
+
+        self.db_hook = UnitTestPostgresHook()
+
+    def teardown_method(self):
+        with PostgresHook().get_conn() as conn:
+            with conn.cursor() as cur:
+                cur.execute(f"DROP TABLE IF EXISTS {self.table}")
+
+    def test_copy_expert(self, mocker):
+        open_mock = mocker.mock_open(read_data='{"some": "json"}')
+        mocker.patch("airflow.providers.postgres.hooks.postgres.open", open_mock)
+
+        statement = "SQL"
+        filename = "filename"
+
+        self.cur.fetchall.return_value = None
+
+        assert self.db_hook.copy_expert(statement, filename) is None
+
+        assert self.conn.close.call_count == 1
+        assert self.cur.close.call_count == 1
+        assert self.conn.commit.call_count == 1
+        self.cur.copy_expert.assert_called_once_with(statement, open_mock.return_value)
+        assert open_mock.call_args.args == (filename, "r+")
+
+    def test_insert_rows(self, postgres_hook_setup):
+        setup = postgres_hook_setup
+        table = "table"
+        rows = [("hello",), ("world",)]
+
+        setup.db_hook.insert_rows(table, rows)
+
+        assert setup.conn.close.call_count == 1
+        assert setup.cur.close.call_count == 1
+
+        commit_count = 2  # The first and last commit
+        assert commit_count == setup.conn.commit.call_count
+
+        sql = f"INSERT INTO {table} VALUES (%s)"
+        setup.cur.executemany.assert_any_call(sql, rows)
+
+    def test_insert_rows_replace(self, postgres_hook_setup):
+        setup = postgres_hook_setup
+        table = "table"
+        rows = [
+            (
+                1,
+                "hello",
+            ),
+            (
+                2,
+                "world",
+            ),
+        ]
+        fields = ("id", "value")
+
+        setup.db_hook.insert_rows(table, rows, fields, replace=True, replace_index=fields[0])
+
+        assert setup.conn.close.call_count == 1
+        assert setup.cur.close.call_count == 1
+
+        commit_count = 2  # The first and last commit
+        assert commit_count == setup.conn.commit.call_count
+
+        sql = (
+            f"INSERT INTO {table} ({fields[0]}, {fields[1]}) VALUES (%s,%s) "
+            f"ON CONFLICT ({fields[0]}) DO UPDATE SET {fields[1]} = excluded.{fields[1]}"
+        )
+        setup.cur.executemany.assert_any_call(sql, rows)
+
+    def test_insert_rows_replace_missing_target_field_arg(self, postgres_hook_setup):
+        setup = postgres_hook_setup
+        table = "table"
+        rows = [
+            (
+                1,
+                "hello",
+            ),
+            (
+                2,
+                "world",
+            ),
+        ]
+        fields = ("id", "value")
+        with pytest.raises(ValueError) as ctx:
+            setup.db_hook.insert_rows(table, rows, replace=True, replace_index=fields[0])
+
+        assert str(ctx.value) == "PostgreSQL ON CONFLICT upsert syntax requires column names"
+
+    def test_insert_rows_replace_missing_replace_index_arg(self, postgres_hook_setup):
+        setup = postgres_hook_setup
+        table = "table"
+        rows = [
+            (
+                1,
+                "hello",
+            ),
+            (
+                2,
+                "world",
+            ),
+        ]
+        fields = ("id", "value")
+        with pytest.raises(ValueError) as ctx:
+            setup.db_hook.insert_rows(table, rows, fields, replace=True)
+
+        assert str(ctx.value) == "PostgreSQL ON CONFLICT upsert syntax requires an unique index"
+
+    def test_insert_rows_replace_all_index(self, postgres_hook_setup):
+        setup = postgres_hook_setup
+        table = "table"
+        rows = [
+            (
+                1,
+                "hello",
+            ),
+            (
+                2,
+                "world",
+            ),
+        ]
+        fields = ("id", "value")
+
+        setup.db_hook.insert_rows(table, rows, fields, replace=True, replace_index=fields)
+
+        assert setup.conn.close.call_count == 1
+        assert setup.cur.close.call_count == 1
+
+        commit_count = 2  # The first and last commit
+        assert commit_count == setup.conn.commit.call_count
+
+        sql = (
+            f"INSERT INTO {table} ({', '.join(fields)}) VALUES (%s,%s) "
+            f"ON CONFLICT ({', '.join(fields)}) DO NOTHING"
+        )
+        setup.cur.executemany.assert_any_call(sql, rows)
+
+    @pytest.mark.usefixtures("reset_logging_config")
+    def test_get_all_db_log_messages(self, mocker):
+        messages = ["a", "b", "c"]
+
+        class FakeLogger:
+            notices = messages
+
+        # Mock the logger
+        mock_logger = mocker.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.log")
+
+        hook = PostgresHook(enable_log_db_messages=True)
+        hook.get_db_log_messages(FakeLogger)
+
+        # Verify that info was called for each message
+        for msg in messages:
+            assert mocker.call(msg) in mock_logger.info.mock_calls
+
+    @pytest.mark.usefixtures("reset_logging_config")
+    def test_log_db_messages_by_db_proc(self, mocker):
+        proc_name = "raise_notice"
+        notice_proc = f"""
+        CREATE PROCEDURE {proc_name} (s text) LANGUAGE PLPGSQL AS
+        $$
+        BEGIN
+            raise notice 'Message from db: %', s;
+        END;
+        $$;
+        """
+
+        # Mock the logger
+        mock_logger = mocker.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.log")
+
+        hook = PostgresHook(enable_log_db_messages=True)
+        try:
+            hook.run(sql=notice_proc)
+            hook.run(sql=f"call {proc_name}('42')")
+
+            # Check if the notice message was logged
+            assert mocker.call("NOTICE: Message from db: 42\n") in mock_logger.info.mock_calls
+        finally:
+            hook.run(sql=f"DROP PROCEDURE {proc_name} (s text)")
+
+    def test_dialect_name(self, postgres_hook_setup):
+        setup = postgres_hook_setup
+        assert setup.db_hook.dialect_name == "postgresql"
+
+    def test_dialect(self, postgres_hook_setup):
+        setup = postgres_hook_setup
+        assert isinstance(setup.db_hook.dialect, PostgresDialect)
+
+
+@pytest.mark.backend("postgres")
+@pytest.mark.skipif(not USE_PSYCOPG3, reason="psycopg v3 or SQLAlchemy v2 is not available")
+class TestPostgresHookPPG3:
+    """PostgresHook tests that are specific to psycopg3."""
+
+    table = "test_postgres_hook_table"
+
+    def setup_method(self):
+        self.cur = mock.MagicMock(rowcount=0)
+        self.conn = conn = mock.MagicMock()
+        self.conn.cursor.return_value = self.cur
+
+        class UnitTestPostgresHook(PostgresHook):
+            conn_name_attr = "test_conn_id"
+
+            def get_conn(self):
+                return conn
+
+        self.db_hook = UnitTestPostgresHook()
+
+    def teardown_method(self):
+        with PostgresHook().get_conn() as conn:
+            with conn.cursor() as cur:
+                cur.execute(f"DROP TABLE IF EXISTS {self.table}")
+
+    def test_copy_expert_from(self, mocker):
+        """Tests copy_expert with a 'COPY FROM STDIN' operation."""
+        statement = "COPY test_table FROM STDIN"
+        filename = "filename"
+
+        m_open = mocker.mock_open()
+        mocker.patch("airflow.providers.postgres.hooks.postgres.open", m_open)
+        mocker.patch("os.path.isfile", return_value=True)  # Mock file exists
+
+        # Configure the file handle for reading
+        m_open.return_value.read.side_effect = [b'{"some": "json"}', b""]
+
+        # Set up the context manager chain properly
+        # self.conn needs to support the context manager protocol
+        self.conn.__enter__ = mocker.Mock(return_value=self.conn)
+        self.conn.__exit__ = mocker.Mock(return_value=None)
+
+        # cursor() returns an object that also supports the context manager protocol
+        mock_cursor = mocker.MagicMock()
+        mock_cursor.__enter__ = mocker.Mock(return_value=mock_cursor)
+        mock_cursor.__exit__ = mocker.Mock(return_value=None)
+        self.conn.cursor.return_value = mock_cursor
+
+        # cursor.copy() returns a context manager
+        mock_copy = mocker.MagicMock()
+        mock_copy.__enter__ = mocker.Mock(return_value=mock_copy)
+        mock_copy.__exit__ = mocker.Mock(return_value=None)
+        mock_cursor.copy.return_value = mock_copy
+
+        # Call the method under test
+        self.db_hook.copy_expert(statement, filename)
+
+        # Assert that the file was opened for reading in binary mode
+        m_open.assert_any_call(filename, "rb")
+
+        # Assert write was called with the data
+        mock_copy.write.assert_called_once_with(b'{"some": "json"}')
+        self.conn.commit.assert_called_once()
+
+    def test_copy_expert_to(self, mocker):
+        """Tests copy_expert with a 'COPY TO STDOUT' operation."""
+        statement = "COPY test_table TO STDOUT"
+        filename = "filename"
+
+        m_open = mocker.mock_open()
+        mocker.patch("airflow.providers.postgres.hooks.postgres.open", m_open)
+
+        # Set up the context manager chain properly
+        # self.conn needs to support the context manager protocol
+        self.conn.__enter__ = mocker.Mock(return_value=self.conn)
+        self.conn.__exit__ = mocker.Mock(return_value=None)
+
+        # cursor() returns an object that also supports the context manager protocol
+        mock_cursor = mocker.MagicMock()
+        mock_cursor.__enter__ = mocker.Mock(return_value=mock_cursor)
+        mock_cursor.__exit__ = mocker.Mock(return_value=None)
+        self.conn.cursor.return_value = mock_cursor
+
+        # cursor.copy() returns a context manager that is iterable
+        mock_copy = mocker.MagicMock()
+        mock_copy.__enter__ = mocker.Mock(return_value=mock_copy)
+        mock_copy.__exit__ = mocker.Mock(return_value=None)
+        mock_copy.__iter__ = mocker.Mock(return_value=iter([b"db_data_1", b"db_data_2"]))
+        mock_cursor.copy.return_value = mock_copy
+
+        # Call the method under test
+        self.db_hook.copy_expert(statement, filename)
+
+        # Assert that the file was opened for writing in binary mode
+        m_open.assert_any_call(filename, "wb")
+
+        # Assert that the data from the DB was written to the file
+        handle = m_open.return_value
+        handle.write.assert_has_calls(
+            [
+                mocker.call(b"db_data_1"),
+                mocker.call(b"db_data_2"),
+            ]
+        )
+        self.conn.commit.assert_called_once()
+
     def test_insert_rows(self):
         table = "table"
         rows = [("hello",), ("world",)]
@@ -645,33 +1172,12 @@ def test_insert_rows_replace_all_index(self):
         )
         self.cur.executemany.assert_any_call(sql, rows)
 
-    def test_rowcount(self):
-        hook = PostgresHook()
-        input_data = ["foo", "bar", "baz"]
-
-        with hook.get_conn() as conn:
-            with conn.cursor() as cur:
-                cur.execute(f"CREATE TABLE {self.table} (c VARCHAR)")
-                values = ",".join(f"('{data}')" for data in input_data)
-                cur.execute(f"INSERT INTO {self.table} VALUES {values}")
-                conn.commit()
-                assert cur.rowcount == len(input_data)
-
-    @pytest.mark.usefixtures("reset_logging_config")
-    def test_get_all_db_log_messages(self, caplog):
-        messages = ["a", "b", "c"]
-
-        class FakeLogger:
-            notices = messages
-
-        with caplog.at_level(logging.INFO):
-            hook = PostgresHook(enable_log_db_messages=True)
-            hook.get_db_log_messages(FakeLogger)
-            for msg in messages:
-                assert msg in caplog.text
+    @pytest.mark.skip(reason="Notice handling is callback-based in psycopg3 and cannot be tested this way.")
+    def test_get_all_db_log_messages(self, mocker):
+        pass
 
     @pytest.mark.usefixtures("reset_logging_config")
-    def test_log_db_messages_by_db_proc(self, caplog):
+    def test_log_db_messages_by_db_proc(self, mocker):
         proc_name = "raise_notice"
         notice_proc = f"""
         CREATE PROCEDURE {proc_name} (s text) LANGUAGE PLPGSQL AS
@@ -681,89 +1187,22 @@ def test_log_db_messages_by_db_proc(self, caplog):
         END;
         $$;
         """
-        with caplog.at_level(logging.INFO):
-            hook = PostgresHook(enable_log_db_messages=True)
-            try:
-                hook.run(sql=notice_proc)
-                hook.run(sql=f"call {proc_name}('42')")
-                assert "NOTICE: Message from db: 42" in caplog.text
-            finally:
-                hook.run(sql=f"DROP PROCEDURE {proc_name} (s text)")
+
+        # Mock the logger
+        mock_logger = mocker.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.log")
+
+        hook = PostgresHook(enable_log_db_messages=True)
+        try:
+            hook.run(sql=notice_proc)
+            hook.run(sql=f"call {proc_name}('42')")
+
+            # Check if the notice message was logged
+            mock_logger.info.assert_any_call("Message from db: 42")
+        finally:
+            hook.run(sql=f"DROP PROCEDURE {proc_name} (s text)")
 
     def test_dialect_name(self):
         assert self.db_hook.dialect_name == "postgresql"
 
     def test_dialect(self):
         assert isinstance(self.db_hook.dialect, PostgresDialect)
-
-    def test_reserved_words(self):
-        hook = PostgresHook()
-        assert hook.reserved_words == sqlalchemy.dialects.postgresql.base.RESERVED_WORDS
-
-    def test_generate_insert_sql_without_already_escaped_column_name(self):
-        values = [
-            "1",
-            "mssql_conn",
-            "mssql",
-            "MSSQL connection",
-            "localhost",
-            "airflow",
-            "admin",
-            "admin",
-            1433,
-            False,
-            False,
-            {},
-        ]
-        target_fields = [
-            "id",
-            "conn_id",
-            "conn_type",
-            "description",
-            "host",
-            "schema",
-            "login",
-            "password",
-            "port",
-            "is_encrypted",
-            "is_extra_encrypted",
-            "extra",
-        ]
-        hook = PostgresHook()
-        assert hook._generate_insert_sql(
-            table="connection", values=values, target_fields=target_fields
-        ) == INSERT_SQL_STATEMENT.format("schema")
-
-    def test_generate_insert_sql_with_already_escaped_column_name(self):
-        values = [
-            "1",
-            "mssql_conn",
-            "mssql",
-            "MSSQL connection",
-            "localhost",
-            "airflow",
-            "admin",
-            "admin",
-            1433,
-            False,
-            False,
-            {},
-        ]
-        target_fields = [
-            "id",
-            "conn_id",
-            "conn_type",
-            "description",
-            "host",
-            '"schema"',
-            "login",
-            "password",
-            "port",
-            "is_encrypted",
-            "is_extra_encrypted",
-            "extra",
-        ]
-        hook = PostgresHook()
-        assert hook._generate_insert_sql(
-            table="connection", values=values, target_fields=target_fields
-        ) == INSERT_SQL_STATEMENT.format('"schema"')
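
Note on the copy tests: the mock chain in test_copy_expert_from/test_copy_expert_to mirrors the psycopg3 COPY API. psycopg2 exposes server-side COPY through cursor.copy_expert(sql, file), while psycopg3 replaces it with a cursor.copy(sql) context manager that is written to for COPY ... FROM STDIN and iterated for COPY ... TO STDOUT. A minimal sketch of the bridging logic those mocks emulate, assuming a plain psycopg3 connection (copy_via_psycopg3 and the 8 KiB chunk size are illustrative, not the hook's actual implementation):

    import psycopg  # psycopg v3

    def copy_via_psycopg3(conn: psycopg.Connection, statement: str, filename: str) -> None:
        with conn.cursor() as cur:
            if "FROM STDIN" in statement.upper():
                # COPY ... FROM STDIN: stream the local file into the server.
                with open(filename, "rb") as f, cur.copy(statement) as copy:
                    while data := f.read(8192):
                        copy.write(data)
            else:
                # COPY ... TO STDOUT: iterate the Copy object and write each
                # bytes chunk the server sends to the local file.
                with open(filename, "wb") as f, cur.copy(statement) as copy:
                    for data in copy:
                        f.write(data)
        conn.commit()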
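
The diverging notice assertions between TestPostgresHookPPG2 and TestPostgresHookPPG3 follow from the drivers' notice APIs: psycopg2 accumulates pre-formatted strings (e.g. "NOTICE: Message from db: 42\n") in the connection.notices list, whereas psycopg3 delivers each notice to a registered callback as a psycopg.errors.Diagnostic whose message_primary carries the bare text. That is also why the list-polling test_get_all_db_log_messages has no psycopg3 equivalent and is skipped. A hedged sketch of both styles (function names are illustrative):

    import logging

    log = logging.getLogger(__name__)

    def log_notices_psycopg2(conn) -> None:
        # psycopg2: the driver appends formatted strings to conn.notices;
        # poll the list after executing statements.
        for notice in conn.notices:  # e.g. "NOTICE:  Message from db: 42\n"
            log.info(notice)

    def install_notice_handler_psycopg3(conn) -> None:
        # psycopg3: notices arrive via a callback as Diagnostic objects.
        def handler(diag) -> None:
            log.info(diag.message_primary)  # e.g. "Message from db: 42"

        conn.add_notice_handler(handler)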