Skip to content

Commit

Permalink
Feat: add option in schema's find method to ensure types are DataTypes (
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored May 28, 2024
1 parent eae3c51 commit 7c323bd
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
19 changes: 18 additions & 1 deletion sqlglot/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,16 @@ def table_parts(self, table: exp.Table) -> t.List[str]:
return [table.this.name]
return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]

def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
def find(
self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
) -> t.Optional[t.Any]:
"""
Returns the schema of a given table.
Args:
table: the target table.
raise_on_missing: whether to raise in case the schema is not found.
ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
Returns:
The schema of the target table.
Expand Down Expand Up @@ -239,6 +242,20 @@ def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
normalize=mapping_schema.normalize,
)

def find(
self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
) -> t.Optional[t.Any]:
schema = super().find(
table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
)
if ensure_data_types and isinstance(schema, dict) and dict_depth(schema) == 1:
schema = {
col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
for col, dtype in schema.items()
}

return schema

def copy(self, **kwargs) -> MappingSchema:
return MappingSchema(
**{ # type: ignore
Expand Down
7 changes: 7 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,3 +303,10 @@ def test_has_column(self):
schema = MappingSchema({"x": {"c": "int"}})
self.assertTrue(schema.has_column("x", exp.column("c")))
self.assertFalse(schema.has_column("x", exp.column("k")))

def test_find(self):
schema = MappingSchema({"x": {"c": "int"}})
found = schema.find(exp.to_table("x"))
self.assertEqual(found, {"c": "int"})
found = schema.find(exp.to_table("x"), ensure_data_types=True)
self.assertEqual(found, {"c": exp.DataType.build("int")})

0 comments on commit 7c323bd

Please sign in to comment.