NOTE(review): line structure and indentation of this patch were reconstructed from a
whitespace-collapsed copy. If context does not match, apply with
`git apply --ignore-whitespace`. The blob ids on the first `index` line predate the
edits made to the added lines of postgres.py.

diff --git a/providers/postgres/src/airflow/providers/postgres/dialects/postgres.py b/providers/postgres/src/airflow/providers/postgres/dialects/postgres.py
index e8bef1d2b5a48..6c1d3104d10c0 100644
--- a/providers/postgres/src/airflow/providers/postgres/dialects/postgres.py
+++ b/providers/postgres/src/airflow/providers/postgres/dialects/postgres.py
@@ -16,9 +16,11 @@
 # under the License.
 from __future__ import annotations
 
+from collections.abc import Callable
+
 from methodtools import lru_cache
 
-from airflow.providers.common.sql.dialects.dialect import Dialect
+from airflow.providers.common.sql.dialects.dialect import Dialect, T
 
 
 class PostgresDialect(Dialect):
@@ -39,22 +41,63 @@ def get_primary_keys(self, table: str, schema: str | None = None) -> list[str] |
         """
         if schema is None:
             table, schema = self.extract_schema_from_table(table)
-        sql = """
-            select kcu.column_name
-            from information_schema.table_constraints tco
-                    join information_schema.key_column_usage kcu
-                        on kcu.constraint_name = tco.constraint_name
-                            and kcu.constraint_schema = tco.constraint_schema
-                            and kcu.constraint_name = tco.constraint_name
-            where tco.constraint_type = 'PRIMARY KEY'
-            and kcu.table_schema = %s
-            and kcu.table_name = %s
-        """
         pk_columns = [
-            row[0] for row in self.get_records(sql, (self.unescape_word(schema), self.unescape_word(table)))
+            row["column_name"]
+            for row in self.get_records(
+                """
+                select kcu.column_name as column_name
+                from information_schema.table_constraints tco
+                        join information_schema.key_column_usage kcu
+                            on kcu.constraint_name = tco.constraint_name
+                                and kcu.constraint_schema = tco.constraint_schema
+                                and kcu.table_name = tco.table_name
+                where tco.constraint_type = 'PRIMARY KEY'
+                and kcu.table_schema = %s
+                and kcu.table_name = %s
+                order by kcu.ordinal_position
+                """,
+                (self.unescape_word(schema), self.unescape_word(table)),
+            )
         ]
         return pk_columns or None
 
+    @lru_cache(maxsize=None)
+    def get_column_names(
+        self, table: str, schema: str | None = None, predicate: Callable[[T], bool] = lambda column: True
+    ) -> list[str] | None:
+        """
+        Get the column names of a table, ordered by ordinal position.
+
+        :param table: Name of the target table
+        :param schema: Name of the target schema, extracted from the table name when omitted
+        :param predicate: Filter applied to each column record; only matching columns are returned
+        :return: List of column names
+        """
+        if schema is None:
+            table, schema = self.extract_schema_from_table(table)
+        column_names = [
+            row["column_name"]
+            for row in filter(
+                predicate,
+                self.get_records(
+                    """
+                    select column_name,
+                           data_type,
+                           is_nullable,
+                           column_default,
+                           ordinal_position
+                    from information_schema.columns
+                    where table_schema = %s
+                      and table_name = %s
+                    order by ordinal_position
+                    """,
+                    (self.unescape_word(schema), self.unescape_word(table)),
+                ),
+            )
+        ]
+        self.log.debug("Column names for table '%s': %s", table, column_names)
+        return column_names
+
     def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str:
         """
         Generate the REPLACE SQL statement.
diff --git a/providers/postgres/tests/unit/postgres/dialects/test_postgres.py b/providers/postgres/tests/unit/postgres/dialects/test_postgres.py
index c7593723325ac..5406805420b85 100644
--- a/providers/postgres/tests/unit/postgres/dialects/test_postgres.py
+++ b/providers/postgres/tests/unit/postgres/dialects/test_postgres.py
@@ -19,29 +19,26 @@
 
 from unittest.mock import MagicMock
 
-from sqlalchemy.engine import Inspector
-
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.postgres.dialects.postgres import PostgresDialect
 
 
 class TestPostgresDialect:
     def setup_method(self):
-        inspector = MagicMock(spc=Inspector)
-        inspector.get_columns.side_effect = lambda table_name, schema: [
-            {"name": "id", "identity": True},
-            {"name": "name"},
-            {"name": "firstname"},
-            {"name": "age"},
-        ]
-
         def get_records(sql, parameters):
             assert isinstance(sql, str)
             assert "hollywood" in parameters, "Missing 'schema' in parameters"
             assert "actors" in parameters, "Missing 'table' in parameters"
-            return [("id",)]
+            if "kcu." in sql:
+                return [{"column_name": "id"}]
+            return [
+                {"column_name": "id", "identity": True},
+                {"column_name": "name"},
+                {"column_name": "firstname"},
+                {"column_name": "age"},
+            ]
 
-        self.test_db_hook = MagicMock(placeholder="?", inspector=inspector, spec=DbApiHook)
+        self.test_db_hook = MagicMock(placeholder="?", spec=DbApiHook)
         self.test_db_hook.get_records.side_effect = get_records
         self.test_db_hook.insert_statement_format = "INSERT INTO {} {} VALUES ({})"
         self.test_db_hook.escape_word_format = '"{}"'