Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
# under the License.
from __future__ import annotations

from collections.abc import Callable

from methodtools import lru_cache

from airflow.providers.common.sql.dialects.dialect import Dialect
from airflow.providers.common.sql.dialects.dialect import Dialect, T


class PostgresDialect(Dialect):
Expand All @@ -39,22 +41,55 @@ def get_primary_keys(self, table: str, schema: str | None = None) -> list[str] |
"""
if schema is None:
table, schema = self.extract_schema_from_table(table)
sql = """
select kcu.column_name
from information_schema.table_constraints tco
join information_schema.key_column_usage kcu
on kcu.constraint_name = tco.constraint_name
and kcu.constraint_schema = tco.constraint_schema
and kcu.constraint_name = tco.constraint_name
where tco.constraint_type = 'PRIMARY KEY'
and kcu.table_schema = %s
and kcu.table_name = %s
"""
pk_columns = [
row[0] for row in self.get_records(sql, (self.unescape_word(schema), self.unescape_word(table)))
row["column_name"]
for row in self.get_records(
"""
select kcu.column_name as column_name
from information_schema.table_constraints tco
join information_schema.key_column_usage kcu
on kcu.constraint_name = tco.constraint_name
and kcu.constraint_schema = tco.constraint_schema
and kcu.constraint_name = tco.constraint_name
where tco.constraint_type = 'PRIMARY KEY'
and kcu.table_schema = %s
and kcu.table_name = %s
order by kcu.ordinal_position
""",
(self.unescape_word(schema), self.unescape_word(table)),
)
]
return pk_columns or None

@lru_cache(maxsize=None)
def get_column_names(
    self, table: str, schema: str | None = None, predicate: Callable[[T], bool] = lambda column: True
) -> list[str] | None:
    """
    Return the column names of a table, ordered by ordinal position.

    The names are read from ``information_schema.columns``; the call is wrapped
    in ``lru_cache`` so repeated lookups with the same arguments are memoized.

    :param table: Name of the table. When ``schema`` is None, the schema is split
        off this value with ``extract_schema_from_table``.
    :param schema: Schema holding the table, or None to derive it from ``table``.
    :param predicate: Called with every record fetched for the table; only records
        for which it returns a truthy value contribute their column name. The
        default keeps every record.
    :return: The column names of the records accepted by ``predicate``.
    """
    if schema is None:
        table, schema = self.extract_schema_from_table(table)

    # The SQL text is reproduced exactly as it is sent to the database.
    query = """
select column_name,
data_type,
is_nullable,
column_default,
ordinal_position
from information_schema.columns
where table_schema = %s
and table_name = %s
order by ordinal_position
"""
    # Schema first, then table: the order matches the two placeholders above.
    parameters = (self.unescape_word(schema), self.unescape_word(table))
    records = self.get_records(query, parameters)

    column_names = [record["column_name"] for record in records if predicate(record)]
    self.log.debug("Column names for table '%s': %s", table, column_names)
    return column_names

def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str:
"""
Generate the REPLACE SQL statement.
Expand Down
21 changes: 9 additions & 12 deletions providers/postgres/tests/unit/postgres/dialects/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,26 @@

from unittest.mock import MagicMock

from sqlalchemy.engine import Inspector

from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.postgres.dialects.postgres import PostgresDialect


class TestPostgresDialect:
def setup_method(self):
inspector = MagicMock(spc=Inspector)
inspector.get_columns.side_effect = lambda table_name, schema: [
{"name": "id", "identity": True},
{"name": "name"},
{"name": "firstname"},
{"name": "age"},
]

def get_records(sql, parameters):
assert isinstance(sql, str)
assert "hollywood" in parameters, "Missing 'schema' in parameters"
assert "actors" in parameters, "Missing 'table' in parameters"
return [("id",)]
if "kcu." in sql:
return [{"column_name": "id"}]
return [
{"column_name": "id", "identity": True},
{"column_name": "name"},
{"column_name": "firstname"},
{"column_name": "age"},
]

self.test_db_hook = MagicMock(placeholder="?", inspector=inspector, spec=DbApiHook)
self.test_db_hook = MagicMock(placeholder="?", spec=DbApiHook)
self.test_db_hook.get_records.side_effect = get_records
self.test_db_hook.insert_statement_format = "INSERT INTO {} {} VALUES ({})"
self.test_db_hook.escape_word_format = '"{}"'
Expand Down
Loading