Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ def get_primary_keys(self, table: str, schema: str | None = None) -> list[str] |
if schema is None:
table, schema = self.extract_schema_from_table(table)
pk_columns = [
row["column_name"]
row[0]
for row in self.get_records(
"""
select kcu.column_name as column_name
select kcu.column_name
from information_schema.table_constraints tco
join information_schema.key_column_usage kcu
on kcu.constraint_name = tco.constraint_name
Expand All @@ -61,29 +61,45 @@ def get_primary_keys(self, table: str, schema: str | None = None) -> list[str] |
]
return pk_columns or None

@staticmethod
def _to_row(row):
return {
"name": row[0],
"type": row[1],
"nullable": row[2].casefold() == "yes",
"default": row[3],
"autoincrement": row[4].casefold() == "always",
"identity": row[5].casefold() == "yes",
}

@lru_cache(maxsize=None)
def get_column_names(
self, table: str, schema: str | None = None, predicate: Callable[[T], bool] = lambda column: True
) -> list[str] | None:
if schema is None:
table, schema = self.extract_schema_from_table(table)

column_names = list(
row["column_name"]
row["name"]
for row in filter(
predicate,
self.get_records(
"""
map(
self._to_row,
self.get_records(
"""
select column_name,
data_type,
is_nullable,
column_default,
ordinal_position
is_generated,
is_identity
from information_schema.columns
where table_schema = %s
and table_name = %s
order by ordinal_position
""",
(self.unescape_word(schema), self.unescape_word(table)),
(self.unescape_word(schema), self.unescape_word(table)),
),
),
)
)
Expand Down
10 changes: 5 additions & 5 deletions providers/postgres/tests/unit/postgres/dialects/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ def get_records(sql, parameters):
assert "hollywood" in parameters, "Missing 'schema' in parameters"
assert "actors" in parameters, "Missing 'table' in parameters"
if "kcu." in sql:
return [{"column_name": "id"}]
return [("id",)]
return [
{"column_name": "id", "identity": True},
{"column_name": "name"},
{"column_name": "firstname"},
{"column_name": "age"},
("id", None, "NO", None, "ALWAYS", "YES"),
("name", None, "YES", None, "NEVER", "NO"),
("firstname", None, "YES", None, "NEVER", "NO"),
("age", None, "YES", None, "NEVER", "NO"),
]

self.test_db_hook = MagicMock(placeholder="?", spec=DbApiHook)
Expand Down
Loading