Skip to content

Commit

Permalink
Fix: enable fetching schema for models querying INFORMATION_SCHEMA (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Nov 4, 2024
1 parent 6669d29 commit c866f83
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 15 deletions.
40 changes: 29 additions & 11 deletions sqlmesh/core/engine_adapter/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,17 +223,35 @@ def dtype_to_sql(dtype: t.Optional[StandardSqlDataType]) -> str:
return "JSON"
return kind.name

table = self._get_table(table_name)
columns = {
field.name: exp.DataType.build(
dtype_to_sql(field.to_standard_sql().type), dialect=self.dialect
)
for field in table.schema
}
if include_pseudo_columns and table.time_partitioning and not table.time_partitioning.field:
columns["_PARTITIONTIME"] = exp.DataType.build("TIMESTAMP")
if table.time_partitioning.type_ == "DAY":
columns["_PARTITIONDATE"] = exp.DataType.build("DATE")
def create_mapping_schema(
    schema: t.Sequence[bigquery.SchemaField],
) -> t.Dict[str, exp.DataType]:
    """Map each BigQuery schema field's name to a sqlglot DataType.

    The field's type is first converted to its StandardSQL form and rendered
    as a SQL string by `dtype_to_sql` (defined in the enclosing scope), then
    parsed back into a sqlglot DataType using this adapter's dialect.
    """
    return {
        field.name: exp.DataType.build(
            dtype_to_sql(field.to_standard_sql().type), dialect=self.dialect
        )
        for field in schema
    }

table = exp.to_table(table_name)
if len(table.parts) > 3:
# The client's `get_table` method can't handle paths with >3 identifiers
self.execute(exp.select("*").from_(table).limit(1))
query_results = self._query_job._query_results
columns = create_mapping_schema(query_results.schema)
else:
bq_table = self._get_table(table)
columns = create_mapping_schema(bq_table.schema)

if (
include_pseudo_columns
and bq_table.time_partitioning
and not bq_table.time_partitioning.field
):
columns["_PARTITIONTIME"] = exp.DataType.build("TIMESTAMP")
if bq_table.time_partitioning.type_ == "DAY":
columns["_PARTITIONDATE"] = exp.DataType.build("DATE")

return columns

def alter_table(
Expand Down
9 changes: 7 additions & 2 deletions tests/core/engine_adapter/integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tests.utils.pandas import compare_dataframes

if t.TYPE_CHECKING:
from sqlmesh.core._typing import TableName
from sqlmesh.core.engine_adapter._typing import Query

TEST_SCHEMA = "test_schema"
Expand Down Expand Up @@ -212,12 +213,16 @@ def input_data(
def output_data(self, data: pd.DataFrame) -> pd.DataFrame:
    # Normalize *data* through the context's dataframe formatter so it can be
    # compared against query results (same formatting path as input data —
    # presumably to align dtypes/column order; confirm against _format_df).
    return self._format_df(data)

def table(self, table_name: str, schema: str = TEST_SCHEMA) -> exp.Table:
def table(self, table_name: TableName, schema: str = TEST_SCHEMA) -> exp.Table:
schema = self.add_test_suffix(schema)
self._schemas.append(schema)

table = exp.to_table(table_name)
table.set("db", exp.parse_identifier(schema, dialect=self.dialect))

return exp.to_table(
normalize_model_name(
".".join([schema, table_name]),
table,
default_catalog=self.engine_adapter.default_catalog,
dialect=self.dialect,
)
Expand Down
40 changes: 40 additions & 0 deletions tests/core/engine_adapter/integration/test_integration_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,43 @@ def _get_data_object(table: exp.Table) -> DataObject:

metadata = _get_data_object(target_table_1)
assert not metadata.is_clustered


def test_fetch_schema_of_information_schema_tables(
    ctx: TestContext, engine_adapter: BigQueryEngineAdapter
):
    """Integration test: `columns()` can fetch the schema of a >3-part
    INFORMATION_SCHEMA path, which the BigQuery client's `get_table` cannot
    handle (exercised via the SELECT-based fallback in the adapter).
    """
    # We produce Table(this=Dot(this=INFORMATION_SCHEMA, expression=TABLES)) here,
    # otherwise `db` or `catalog` would be set, which is not the right representation
    information_schema_tables = exp.to_table("_._.INFORMATION_SCHEMA.TABLES")
    information_schema_tables.set("db", None)
    information_schema_tables.set("catalog", None)

    # ctx.table fills in the test schema/catalog parts, yielding a 4-part path.
    source = ctx.table(information_schema_tables)

    # Expected column types of INFORMATION_SCHEMA.TABLES as reported by the
    # live BigQuery service at the time this test was written; may need
    # updating if Google adds/changes columns in this view.
    expected_columns_to_types = {
        "table_catalog": exp.DataType.build("TEXT"),
        "table_schema": exp.DataType.build("TEXT"),
        "table_name": exp.DataType.build("TEXT"),
        "table_type": exp.DataType.build("TEXT"),
        "is_insertable_into": exp.DataType.build("TEXT"),
        "is_typed": exp.DataType.build("TEXT"),
        "creation_time": exp.DataType.build("TIMESTAMPTZ"),
        "base_table_catalog": exp.DataType.build("TEXT"),
        "base_table_schema": exp.DataType.build("TEXT"),
        "base_table_name": exp.DataType.build("TEXT"),
        "snapshot_time_ms": exp.DataType.build("TIMESTAMPTZ"),
        "ddl": exp.DataType.build("TEXT"),
        "default_collation_name": exp.DataType.build("TEXT"),
        "upsert_stream_apply_watermark": exp.DataType.build("TIMESTAMPTZ"),
        "replica_source_catalog": exp.DataType.build("TEXT"),
        "replica_source_schema": exp.DataType.build("TEXT"),
        "replica_source_name": exp.DataType.build("TEXT"),
        "replication_status": exp.DataType.build("TEXT"),
        "replication_error": exp.DataType.build("TEXT"),
        "is_change_history_enabled": exp.DataType.build("TEXT"),
        "sync_status": exp.DataType.build(
            "STRUCT<last_completion_time TIMESTAMPTZ, error_time TIMESTAMPTZ, error STRUCT<reason TEXT, location TEXT, message TEXT>>"
        ),
    }

    assert expected_columns_to_types == engine_adapter.columns(source.sql())
4 changes: 2 additions & 2 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5945,12 +5945,12 @@ def custom_macro(evaluator, arg1, arg2):
dialect snowflake,
);
SELECT * FROM (@custom_macro(@foo, @bar)) AS q
SELECT * FROM (@custom_macro(@a, @b)) AS q
""")

config = Config(
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
variables={"foo": "foo", "bar": "boo"},
variables={"a": "a", "b": "b"},
)
context = Context(paths=tmp_path, config=config)

Expand Down

0 comments on commit c866f83

Please sign in to comment.