diff --git a/providers/postgres/src/airflow/providers/postgres/dialects/postgres.py b/providers/postgres/src/airflow/providers/postgres/dialects/postgres.py index 6c1d3104d10c0..9e31d39b5e877 100644 --- a/providers/postgres/src/airflow/providers/postgres/dialects/postgres.py +++ b/providers/postgres/src/airflow/providers/postgres/dialects/postgres.py @@ -42,10 +42,10 @@ def get_primary_keys(self, table: str, schema: str | None = None) -> list[str] | if schema is None: table, schema = self.extract_schema_from_table(table) pk_columns = [ - row["column_name"] + row[0] for row in self.get_records( """ - select kcu.column_name as column_name + select kcu.column_name from information_schema.table_constraints tco join information_schema.key_column_usage kcu on kcu.constraint_name = tco.constraint_name @@ -61,29 +61,45 @@ def get_primary_keys(self, table: str, schema: str | None = None) -> list[str] | ] return pk_columns or None + @staticmethod + def _to_row(row): + return { + "name": row[0], + "type": row[1], + "nullable": row[2].casefold() == "yes", + "default": row[3], + "autoincrement": row[4].casefold() == "always", + "identity": row[5].casefold() == "yes", + } + @lru_cache(maxsize=None) def get_column_names( self, table: str, schema: str | None = None, predicate: Callable[[T], bool] = lambda column: True ) -> list[str] | None: if schema is None: table, schema = self.extract_schema_from_table(table) + column_names = list( - row["column_name"] + row["name"] for row in filter( predicate, - self.get_records( - """ + map( + self._to_row, + self.get_records( + """ select column_name, data_type, is_nullable, column_default, - ordinal_position + is_generated, + is_identity from information_schema.columns where table_schema = %s and table_name = %s order by ordinal_position """, - (self.unescape_word(schema), self.unescape_word(table)), + (self.unescape_word(schema), self.unescape_word(table)), + ), ), ) ) diff --git a/providers/postgres/tests/unit/postgres/dialects/test_postgres.py b/providers/postgres/tests/unit/postgres/dialects/test_postgres.py index 5406805420b85..999386baac12e 100644 --- a/providers/postgres/tests/unit/postgres/dialects/test_postgres.py +++ b/providers/postgres/tests/unit/postgres/dialects/test_postgres.py @@ -30,12 +30,12 @@ def get_records(sql, parameters): assert "hollywood" in parameters, "Missing 'schema' in parameters" assert "actors" in parameters, "Missing 'table' in parameters" if "kcu." in sql: - return [{"column_name": "id"}] + return [("id",)] return [ - {"column_name": "id", "identity": True}, - {"column_name": "name"}, - {"column_name": "firstname"}, - {"column_name": "age"}, + ("id", None, "NO", None, "ALWAYS", "YES"), + ("name", None, "YES", None, "NEVER", "NO"), + ("firstname", None, "YES", None, "NEVER", "NO"), + ("age", None, "YES", None, "NEVER", "NO"), ] self.test_db_hook = MagicMock(placeholder="?", spec=DbApiHook)