Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions airflow-core/src/airflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@
from airflow.utils.orm_event_handlers import setup_event_handlers
from airflow.utils.sqlalchemy import is_sqlalchemy_v1

# True only when the psycopg (v3) driver is installed AND SQLAlchemy is not 1.x:
# the ``postgresql+psycopg`` dialect is only available starting with SQLAlchemy 2.0.
USE_PSYCOPG3: bool
try:
    from importlib.util import find_spec

    # ``find_spec`` only locates the package; it does not import psycopg itself.
    is_psycopg3 = find_spec("psycopg") is not None

    USE_PSYCOPG3 = is_psycopg3 and not is_sqlalchemy_v1()
except (ImportError, ValueError):
    # ImportError already covers ModuleNotFoundError (its subclass).
    # ValueError is what ``find_spec`` raises when ``sys.modules["psycopg"]`` exists
    # but has no usable ``__spec__`` (e.g. a stubbed module); treat that as "not available".
    USE_PSYCOPG3 = False

if TYPE_CHECKING:
from sqlalchemy.engine import Engine

Expand Down Expand Up @@ -426,12 +436,17 @@ def clean_in_fork():
register_at_fork(after_in_child=clean_in_fork)


DEFAULT_ENGINE_ARGS = {
"postgresql": {
"executemany_mode": "values_plus_batch",
"executemany_values_page_size" if is_sqlalchemy_v1() else "insertmanyvalues_page_size": 10000,
"executemany_batch_page_size": 2000,
},
DEFAULT_ENGINE_ARGS: dict[str, dict[str, Any]] = {
"postgresql": (
{
"executemany_values_page_size" if is_sqlalchemy_v1() else "insertmanyvalues_page_size": 10000,
}
| (
{}
if USE_PSYCOPG3
else {"executemany_mode": "values_plus_batch", "executemany_batch_page_size": 2000}
)
)
}


Expand Down
34 changes: 31 additions & 3 deletions airflow-core/src/airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,21 @@
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.task_instance_session import get_current_task_instance_session

# True only when the psycopg (v3) driver is installed AND SQLAlchemy 2.x is in use:
# the ``postgresql+psycopg`` dialect is only available starting with SQLAlchemy 2.0.
USE_PSYCOPG3: bool
try:
    from importlib.util import find_spec

    import sqlalchemy
    from packaging.version import Version

    # ``find_spec`` only locates the package; it does not import psycopg itself.
    is_psycopg3 = find_spec("psycopg") is not None
    sqlalchemy_version = Version(sqlalchemy.__version__)
    # (major, minor, micro) >= (2, 0, 0) holds exactly when major >= 2.
    is_sqla2 = sqlalchemy_version.major >= 2

    USE_PSYCOPG3 = is_psycopg3 and is_sqla2
except (ImportError, ValueError):
    # ImportError already covers ModuleNotFoundError (its subclass).
    # ValueError covers ``find_spec`` on a module without a usable ``__spec__`` and
    # ``packaging``'s InvalidVersion (a ValueError subclass) for an unparsable version string.
    USE_PSYCOPG3 = False

if TYPE_CHECKING:
from alembic.runtime.environment import EnvironmentContext
from alembic.script import ScriptDirectory
Expand Down Expand Up @@ -1284,15 +1299,28 @@ def create_global_lock(
dialect = conn.dialect
try:
if dialect.name == "postgresql":
conn.execute(text("SET LOCK_TIMEOUT to :timeout"), {"timeout": lock_timeout})
conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value})
if USE_PSYCOPG3:
# psycopg3 doesn't support parameters for `SET`. Use `set_config` instead.
# The timeout value must be passed as a string of milliseconds.
conn.execute(
text("SELECT set_config('lock_timeout', :timeout, false)"),
{"timeout": str(lock_timeout)},
)
conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value})
else:
conn.execute(text("SET LOCK_TIMEOUT to :timeout"), {"timeout": lock_timeout})
conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value})
elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6):
conn.execute(text("SELECT GET_LOCK(:id, :timeout)"), {"id": str(lock), "timeout": lock_timeout})

yield
finally:
if dialect.name == "postgresql":
conn.execute(text("SET LOCK_TIMEOUT TO DEFAULT"))
if USE_PSYCOPG3:
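# Use set_config() to turn the lock timeout off again ('0' = disabled, wait forever).
# NOTE: unlike `SET LOCK_TIMEOUT TO DEFAULT`, this does not restore a server- or role-level configured value.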
conn.execute(text("SELECT set_config('lock_timeout', '0', false)"))
else:
conn.execute(text("SET LOCK_TIMEOUT TO DEFAULT"))
(unlocked,) = conn.execute(text("SELECT pg_advisory_unlock(:id)"), {"id": lock.value}).fetchone()
if not unlocked:
raise RuntimeError("Error releasing DB lock!")
Expand Down
23 changes: 19 additions & 4 deletions airflow-core/tests/unit/always/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from airflow.models import Connection, crypto
from airflow.sdk import BaseHook

from tests_common.test_utils.version_compat import SQLALCHEMY_V_1_4
from tests_common.test_utils.version_compat import SQLALCHEMY_V_1_4, SQLALCHEMY_V_2_0

sqlite = pytest.importorskip("airflow.providers.sqlite.hooks.sqlite")

Expand Down Expand Up @@ -679,10 +679,20 @@ def test_env_var_priority(self, mock_supervisor_comms):
def test_dbapi_get_uri(self):
conn = BaseHook.get_connection(conn_id="test_uri")
hook = conn.get_hook()
assert hook.get_uri() == "postgresql://username:password@ec2.compute.com:5432/the_database"

ppg3_mode: bool = SQLALCHEMY_V_2_0 and "psycopg" in hook.get_uri()
if ppg3_mode:
assert (
hook.get_uri() == "postgresql+psycopg://username:password@ec2.compute.com:5432/the_database"
)
else:
assert hook.get_uri() == "postgresql://username:password@ec2.compute.com:5432/the_database"
conn2 = BaseHook.get_connection(conn_id="test_uri_no_creds")
hook2 = conn2.get_hook()
assert hook2.get_uri() == "postgresql://ec2.compute.com/the_database"
if ppg3_mode:
assert hook2.get_uri() == "postgresql+psycopg://ec2.compute.com/the_database"
else:
assert hook2.get_uri() == "postgresql://ec2.compute.com/the_database"

@mock.patch.dict(
"os.environ",
Expand All @@ -695,7 +705,12 @@ def test_dbapi_get_sqlalchemy_engine(self):
conn = BaseHook.get_connection(conn_id="test_uri")
hook = conn.get_hook()
engine = hook.get_sqlalchemy_engine()
expected = "postgresql://username:password@ec2.compute.com:5432/the_database"

if SQLALCHEMY_V_2_0 and "psycopg" in hook.get_uri():
expected = "postgresql+psycopg://username:password@ec2.compute.com:5432/the_database"
else:
expected = "postgresql://username:password@ec2.compute.com:5432/the_database"

assert isinstance(engine, sqlalchemy.engine.Engine)
if SQLALCHEMY_V_1_4:
assert str(engine.url) == expected
Expand Down
52 changes: 52 additions & 0 deletions airflow-core/tests/unit/cli/commands/test_db_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,27 @@ def test_cli_shell_postgres(self, mock_execute_interactive):
"PGUSER": "postgres",
}

@mock.patch("airflow.cli.commands.db_command.execute_interactive")
@mock.patch(
    "airflow.cli.commands.db_command.settings.engine.url",
    make_url("postgresql+psycopg://postgres:airflow@postgres:5432/airflow"),
)
def test_cli_shell_postgres_ppg3(self, mock_execute_interactive):
    """Check that ``db shell`` with a ``postgresql+psycopg`` URL runs ``psql`` with the PG* env vars."""
    pytest.importorskip("psycopg", reason="Test only runs when psycopg v3 is installed.")

    shell_args = self.parser.parse_args(["db", "shell"])
    db_command.shell(shell_args)

    # Only the command is pinned here; the environment is inspected separately below.
    mock_execute_interactive.assert_called_once_with(["psql"], env=mock.ANY)
    launched_env = mock_execute_interactive.call_args.kwargs["env"]

    # Ignore everything inherited from the process environment except the PG* variables.
    pg_settings = {}
    for name, value in launched_env.items():
        if name.startswith("PG"):
            pg_settings[name] = value

    expected_env = {
        "PGDATABASE": "airflow",
        "PGHOST": "postgres",
        "PGPASSWORD": "airflow",
        "PGPORT": "5432",
        "PGUSER": "postgres",
    }
    assert pg_settings == expected_env

@mock.patch("airflow.cli.commands.db_command.execute_interactive")
@mock.patch(
"airflow.cli.commands.db_command.settings.engine.url",
Expand All @@ -257,6 +278,27 @@ def test_cli_shell_postgres_without_port(self, mock_execute_interactive):
"PGUSER": "postgres",
}

@mock.patch("airflow.cli.commands.db_command.execute_interactive")
@mock.patch(
    "airflow.cli.commands.db_command.settings.engine.url",
    make_url("postgresql+psycopg://postgres:airflow@postgres/airflow"),
)
def test_cli_shell_postgres_without_port_ppg3(self, mock_execute_interactive):
    """Check that a ``postgresql+psycopg`` URL without a port still yields ``PGPORT`` of ``5432``."""
    pytest.importorskip("psycopg", reason="Test only runs when psycopg v3 is installed.")

    shell_args = self.parser.parse_args(["db", "shell"])
    db_command.shell(shell_args)

    # Only the command is pinned here; the environment is inspected separately below.
    mock_execute_interactive.assert_called_once_with(["psql"], env=mock.ANY)
    launched_env = mock_execute_interactive.call_args.kwargs["env"]

    # Ignore everything inherited from the process environment except the PG* variables.
    pg_settings = {}
    for name, value in launched_env.items():
        if name.startswith("PG"):
            pg_settings[name] = value

    expected_env = {
        "PGDATABASE": "airflow",
        "PGHOST": "postgres",
        "PGPASSWORD": "airflow",
        "PGPORT": "5432",
        "PGUSER": "postgres",
    }
    assert pg_settings == expected_env

@mock.patch(
"airflow.cli.commands.db_command.settings.engine.url",
make_url("invalid+psycopg2://postgres:airflow@postgres/airflow"),
Expand All @@ -265,6 +307,16 @@ def test_cli_shell_invalid(self):
with pytest.raises(AirflowException, match=r"Unknown driver: invalid\+psycopg2"):
db_command.shell(self.parser.parse_args(["db", "shell"]))

@mock.patch(
    "airflow.cli.commands.db_command.settings.engine.url",
    make_url("invalid+psycopg://postgres:airflow@postgres/airflow"),
)
def test_cli_shell_invalid_ppg3(self):
    """Check that ``db shell`` rejects an unknown backend combined with the psycopg driver."""
    pytest.importorskip("psycopg", reason="Test only runs when psycopg v3 is installed.")

    expected_error = r"Unknown driver: invalid\+psycopg"
    with pytest.raises(AirflowException, match=expected_error):
        db_command.shell(self.parser.parse_args(["db", "shell"]))

@pytest.mark.parametrize(
"args, match",
[
Expand Down
1 change: 1 addition & 0 deletions devel-common/src/docs/utils/conf_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def get_autodoc_mock_imports() -> list[str]:
"pandas_gbq",
"paramiko",
"pinotdb",
"psycopg",
"psycopg2",
"pydruid",
"pyhive",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,9 @@ def _run_command(self, cur, sql_statement, parameters):
self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)

if parameters:
# If we're using psycopg3, we might need to handle parameters differently
if hasattr(cur, "__module__") and "psycopg" in cur.__module__ and isinstance(parameters, list):
parameters = tuple(parameters)
cur.execute(sql_statement, parameters)
else:
cur.execute(sql_statement)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def test_sql_sensor_postgres(self):
op2 = SqlSensor(
task_id="sql_sensor_check_2",
conn_id="postgres_default",
sql="SELECT count(%s) FROM INFORMATION_SCHEMA.TABLES",
parameters=["table_name"],
sql="SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = %s",
parameters=["information_schema"],
dag=self.dag,
)
op2.execute({})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1175,9 +1175,9 @@ def cleanup_database_hook(self) -> None:
raise ValueError("The db_hook should be set")
if not isinstance(self.db_hook, PostgresHook):
raise ValueError(f"The db_hook should be PostgresHook and is {type(self.db_hook)}")
conn = getattr(self.db_hook, "conn")
if conn and conn.notices:
for output in self.db_hook.conn.notices:
conn = getattr(self.db_hook, "conn", None)
if conn and hasattr(conn, "notices") and conn.notices:
for output in conn.notices:
self.log.info(output)

def reserve_free_tcp_port(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from slugify import slugify

from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.providers.postgres.hooks.postgres import USE_PSYCOPG3, PostgresHook

if TYPE_CHECKING:
from airflow.providers.openlineage.extractors import OperatorLineage
Expand All @@ -52,9 +52,20 @@ def __init__(self, cursor):
self.initialized = False

def __iter__(self):
    """Return this object as its own iterator; rows are produced by ``__next__``."""
    return self

def __next__(self):
"""Fetch next row from the cursor."""
if USE_PSYCOPG3:
if self.rows:
return self.rows.pop()
self.initialized = True
row = self.cursor.fetchone()
if row is None:
raise StopIteration
return row
# psycopg2
if self.rows:
return self.rows.pop()
self.initialized = True
Expand Down Expand Up @@ -141,13 +152,29 @@ def db_hook(self) -> PostgresHook:
return PostgresHook(postgres_conn_id=self.postgres_conn_id)

def query(self):
"""Query Postgres and returns a cursor to the results."""
"""Execute the query and return a cursor."""
conn = self.db_hook.get_conn()
cursor = conn.cursor(name=self._unique_name())
cursor.execute(self.sql, self.parameters)
if self.use_server_side_cursor:
cursor.itersize = self.cursor_itersize
return _PostgresServerSideCursorDecorator(cursor)

if USE_PSYCOPG3:
from psycopg.types.json import register_default_adapters

# Register JSON handlers for this connection if not already done
register_default_adapters(conn)

if self.use_server_side_cursor:
cursor_name = f"airflow_{self.task_id.replace('-', '_')}_{uuid.uuid4().hex}"[:63]
cursor = conn.cursor(name=cursor_name)
cursor.itersize = self.cursor_itersize
cursor.execute(self.sql, self.parameters)
return _PostgresServerSideCursorDecorator(cursor)
cursor = conn.cursor()
cursor.execute(self.sql, self.parameters)
else:
cursor = conn.cursor(name=self._unique_name())
cursor.execute(self.sql, self.parameters)
if self.use_server_side_cursor:
cursor.itersize = self.cursor_itersize
return _PostgresServerSideCursorDecorator(cursor)
return cursor

def field_to_bigquery(self, field) -> dict[str, str]:
Expand Down Expand Up @@ -182,8 +209,14 @@ def convert_type(self, value, schema_type, stringify_dict=True):
hours=formatted_time.tm_hour, minutes=formatted_time.tm_min, seconds=formatted_time.tm_sec
)
return str(time_delta)
if stringify_dict and isinstance(value, dict):
return json.dumps(value)
if stringify_dict:
if USE_PSYCOPG3:
from psycopg.types.json import Json

if isinstance(value, (dict, Json)):
return json.dumps(value)
elif isinstance(value, dict):
return json.dumps(value)
if isinstance(value, Decimal):
return float(value)
return value
Expand Down
11 changes: 11 additions & 0 deletions providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,17 @@ def call_get_conn():
"postgresql+psycopg2://login:password@host:1234/schema",
id="sqlalchemy-scheme-with-driver",
),
pytest.param(
{
"conn_params": {
"extra": json.dumps(
{"sqlalchemy_scheme": "postgresql", "sqlalchemy_driver": "psycopg"}
)
}
},
"postgresql+psycopg://login:password@host:1234/schema",
id="sqlalchemy-scheme-with-driver-ppg3",
),
pytest.param(
{
"login": "user@domain",
Expand Down
1 change: 1 addition & 0 deletions providers/postgres/docs/connections/postgres.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ Extra (optional)
- ``namedtuplecursor``: Returns query results as named tuples using ``psycopg2.extras.NamedTupleCursor``.

For more information, refer to the psycopg2 documentation on `connection and cursor subclasses <https://www.psycopg.org/docs/extras.html#connection-and-cursor-subclasses>`_.
If using psycopg (v3), refer to the documentation on `connection classes <https://www.psycopg.org/psycopg3/docs/api/connections.html>`_.

More details on all Postgres parameters supported can be found in
`Postgres documentation <https://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING>`_.
Expand Down
4 changes: 4 additions & 0 deletions providers/postgres/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ dependencies = [
"polars" = [
"polars>=1.26.0"
]
"psycopg" = [
"psycopg[binary]>=3.2.9",
]

[dependency-groups]
dev = [
Expand All @@ -92,6 +95,7 @@ dev = [
# Additional devel dependencies (do not remove this line and add extra development dependencies)
"apache-airflow-providers-common-sql[pandas]",
"apache-airflow-providers-common-sql[polars]",
"psycopg[binary]>=3.2.9",
]

# To build docs:
Expand Down
Loading
Loading