create _df_to_source_queries override in databricks
eakmanrq committed Aug 14, 2024
1 parent e4cebcd · commit 4417423
Showing 3 changed files with 34 additions and 57 deletions.
sqlmesh/core/engine_adapter/databricks.py (32 changes: 31 additions & 1 deletion)
@@ -12,14 +12,15 @@
     DataObject,
     InsertOverwriteStrategy,
     set_catalog,
+    SourceQuery,
 )
 from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter
 from sqlmesh.core.schema_diff import SchemaDiffer
 from sqlmesh.utils.errors import SQLMeshError
 
 if t.TYPE_CHECKING:
     from sqlmesh.core._typing import SchemaName, TableName
-    from sqlmesh.core.engine_adapter._typing import DF, PySparkSession
+    from sqlmesh.core.engine_adapter._typing import DF, PySparkSession, Query
 
 logger = logging.getLogger(__name__)

@@ -137,6 +138,35 @@ def spark(self) -> PySparkSession:
         self.set_current_catalog(catalog)
         return self._spark
 
+    def _df_to_source_queries(
+        self,
+        df: DF,
+        columns_to_types: t.Dict[str, exp.DataType],
+        batch_size: int,
+        target_table: TableName,
+    ) -> t.List[SourceQuery]:
+        if not self._use_spark_session:
+            return super(SparkEngineAdapter, self)._df_to_source_queries(
+                df, columns_to_types, batch_size, target_table
+            )
+        df = self._ensure_pyspark_df(df, columns_to_types)
+
+        def query_factory() -> Query:
+            temp_table = self._get_temp_table(target_table or "spark", table_only=True)
+            if self.use_serverless:
+                # Global temp views are not supported on Databricks Serverless
+                # This also means we can't mix Python SQL Connection and DB Connect since they wouldn't
+                # share the same temp objects.
+                df.createOrReplaceTempView(temp_table.sql(dialect=self.dialect))  # type: ignore
+            else:
+                df.createOrReplaceGlobalTempView(temp_table.sql(dialect=self.dialect))  # type: ignore
+                temp_table.set("db", "global_temp")
+            return exp.select(*self._casted_columns(columns_to_types)).from_(temp_table)
+
+        if self._use_spark_session:
+            return [SourceQuery(query_factory=query_factory)]
+        return super()._df_to_source_queries(df, columns_to_types, batch_size, target_table)
+
     def _fetch_native_df(
         self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
     ) -> DF:
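
The override defers materialization: the DataFrame is only registered as a temp view when the returned SourceQuery builds its query, and Serverless falls back to a session-scoped view because global temp views are unavailable there. A minimal standalone sketch of the non-serverless pattern, assuming a plain local SparkSession (the name demo_table and the two-column schema are illustrative, not from the commit):

    # Sketch only: the temp-view + casted-SELECT pattern above, outside SQLMesh.
    from pyspark.sql import SparkSession
    from sqlglot import exp

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "name"])

    # Register the DataFrame under a temporary name, then address it through
    # the reserved global_temp database, as the non-serverless branch does.
    temp_table = exp.to_table("demo_table")
    df.createOrReplaceGlobalTempView(temp_table.sql(dialect="spark"))
    temp_table.set("db", "global_temp")

    # Wrap the view in a SELECT that casts every column to its target type.
    query = exp.select("CAST(id AS BIGINT) AS id", "CAST(name AS STRING) AS name").from_(temp_table)
    print(query.sql(dialect="spark"))
    # SELECT CAST(id AS BIGINT) AS id, CAST(name AS STRING) AS name FROM global_temp.demo_table
    spark.sql(query.sql(dialect="spark")).show()
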
sqlmesh/core/engine_adapter/spark.py (16 changes: 3 additions & 13 deletions)
@@ -258,25 +258,15 @@ def _df_to_source_queries(
         batch_size: int,
         target_table: TableName,
     ) -> t.List[SourceQuery]:
-        if not self._use_spark_session:
-            return super()._df_to_source_queries(df, columns_to_types, batch_size, target_table)
         df = self._ensure_pyspark_df(df, columns_to_types)
 
         def query_factory() -> Query:
             temp_table = self._get_temp_table(target_table or "spark", table_only=True)
-            if self.use_serverless:
-                # Global temp views are not supported on Databricks Serverless
-                # This also means we can't mix Python SQL Connection and DB Connect since they wouldn't
-                # share the same temp objects.
-                df.createOrReplaceTempView(temp_table.sql(dialect=self.dialect))  # type: ignore
-            else:
-                df.createOrReplaceGlobalTempView(temp_table.sql(dialect=self.dialect))  # type: ignore
-                temp_table.set("db", "global_temp")
+            df.createOrReplaceGlobalTempView(temp_table.sql(dialect=self.dialect))  # type: ignore
+            temp_table.set("db", "global_temp")
             return exp.select(*self._casted_columns(columns_to_types)).from_(temp_table)
 
-        if self._use_spark_session:
-            return [SourceQuery(query_factory=query_factory)]
-        return super()._df_to_source_queries(df, columns_to_types, batch_size, target_table)
+        return [SourceQuery(query_factory=query_factory)]
 
     def _ensure_pyspark_df(
         self, generic_df: DF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None
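
With the serverless branch moved out, the Spark adapter unconditionally takes the global-temp-view path. A hedged sketch of the scoping difference that motivates the split, on a plain local SparkSession (not Databricks-specific):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.range(3)

    df.createOrReplaceTempView("v")         # session-scoped: visible only here
    df.createOrReplaceGlobalTempView("gv")  # app-scoped: visible to sibling sessions

    spark.sql("SELECT * FROM v").show()
    spark.sql("SELECT * FROM global_temp.gv").show()  # needs the global_temp qualifier

    other = spark.newSession()
    other.sql("SELECT * FROM global_temp.gv").show()  # still resolvable here
    # other.sql("SELECT * FROM v")  # would raise: `v` does not exist in this session
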
tests/core/engine_adapter/test_spark.py (43 changes: 0 additions & 43 deletions)
@@ -3,7 +3,6 @@
 from datetime import datetime
 from unittest.mock import call
 
-import pandas as pd
 import pytest
 from pyspark.sql import types as spark_types
 from pytest_mock.plugin import MockerFixture
@@ -195,48 +194,6 @@ def test_replace_query_exists(mocker: MockerFixture, make_mocked_engine_adapter:
     ]
 
 
-def test_replace_query_pandas_not_exists(
-    make_mocked_engine_adapter: t.Callable, mocker: MockerFixture
-):
-    mocker.patch(
-        "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.table_exists",
-        return_value=False,
-    )
-    mocker.patch("sqlmesh.core.engine_adapter.spark.SparkEngineAdapter._use_spark_session", False)
-    mocker.patch(
-        "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter._ensure_fqn", side_effect=lambda x: x
-    )
-    adapter = make_mocked_engine_adapter(SparkEngineAdapter)
-    df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
-    adapter.replace_query(
-        "test_table", df, {"a": exp.DataType.build("int"), "b": exp.DataType.build("int")}
-    )
-
-    assert to_sql_calls(adapter) == [
-        "CREATE TABLE IF NOT EXISTS `test_table` AS SELECT CAST(`a` AS INT) AS `a`, CAST(`b` AS INT) AS `b` FROM (SELECT CAST(`a` AS INT) AS `a`, CAST(`b` AS INT) AS `b` FROM VALUES (1, 4), (2, 5), (3, 6) AS `t`(`a`, `b`)) AS `_subquery`",
-    ]
-
-
-def test_replace_query_pandas_exists(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture):
-    mocker.patch(
-        "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.table_exists",
-        return_value=True,
-    )
-    mocker.patch("sqlmesh.core.engine_adapter.spark.SparkEngineAdapter._use_spark_session", False)
-    mocker.patch(
-        "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter._ensure_fqn", side_effect=lambda x: x
-    )
-    adapter = make_mocked_engine_adapter(SparkEngineAdapter)
-    df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
-    adapter.replace_query(
-        "test_table", df, {"a": exp.DataType.build("int"), "b": exp.DataType.build("int")}
-    )
-
-    assert to_sql_calls(adapter) == [
-        "INSERT OVERWRITE TABLE `test_table` (`a`, `b`) SELECT CAST(`a` AS INT) AS `a`, CAST(`b` AS INT) AS `b` FROM VALUES (1, 4), (2, 5), (3, 6) AS `t`(`a`, `b`)",
-    ]
-
-
 def test_replace_query_self_ref_not_exists(
     make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable
 ):
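
The two deleted tests covered the non-Spark-session pandas path, where rows are inlined as a VALUES expression instead of going through a temp view. A rough reconstruction of that rendering with sqlglot directly (this is not the adapter's actual code path, just an illustration of the SQL the tests asserted):

    import pandas as pd
    from sqlglot import exp

    df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

    # Inline the rows as a VALUES derived table and cast each column.
    values = exp.values(
        [tuple(row) for row in df.itertuples(index=False)],
        alias="t",
        columns=list(df.columns),
    )
    query = exp.select(
        *(exp.cast(exp.column(c), "int").as_(c) for c in df.columns)
    ).from_(values)
    print(query.sql(dialect="spark", identify=True))
    # Roughly: SELECT CAST(`a` AS INT) AS `a`, CAST(`b` AS INT) AS `b`
    #          FROM VALUES (1, 4), (2, 5), (3, 6) AS `t`(`a`, `b`)
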
