feat: add Databricks Serverless support #3001

Merged (3 commits) on Aug 14, 2024
31 changes: 17 additions & 14 deletions docs/integrations/engines/databricks.md
@@ -16,6 +16,8 @@ The SQL Connector is bundled with SQLMesh and automatically installed when you i

The SQL Connector has all the functionality needed for SQLMesh to execute SQL models on Databricks and Python models locally (the default SQLMesh approach).

The SQL Connector does not support Databricks Serverless Compute. If you require Serverless Compute, you must use the Databricks Connect library.

### Databricks Connect

If you want Databricks to process PySpark DataFrames in SQLMesh Python models, then SQLMesh must use the [Databricks Connect](https://docs.databricks.com/dev-tools/databricks-connect.html) library to connect to Databricks (instead of the Databricks SQL Connector library).
@@ -242,21 +244,22 @@ The only relevant SQLMesh configuration parameter is the optional `catalog` para

### Connection options

| Option | Description | Type | Required |
|--------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------:|:--------:|
| `type` | Engine type name - must be `databricks` | string | Y |
| `server_hostname` | Databricks instance host name | string | N |
| `http_path` | HTTP path, either to a DBSQL endpoint (such as `/sql/1.0/endpoints/1234567890abcdef`) or to an All-Purpose cluster (such as `/sql/protocolv1/o/1234567890123456/1234-123456-slid123`) | string | N |
| `access_token` | HTTP Bearer access token, such as Databricks Personal Access Token | string | N |
| `catalog` | The name of the catalog to use for the connection. [Defaults to use Databricks cluster default](https://docs.databricks.com/en/data-governance/unity-catalog/create-catalogs.html#the-default-catalog-configuration-when-unity-catalog-is-enabled). | string | N |
| `http_headers` | SQL Connector Only: An optional dictionary of HTTP headers that will be set on every request | dict | N |
| `session_configuration` | SQL Connector Only: An optional dictionary of Spark session parameters. Execute the SQL command `SET -v` to get a full list of available commands. | dict | N |
| `databricks_connect_server_hostname` | Databricks Connect Only: Databricks Connect server hostname. Uses `server_hostname` if not set. | string | N |
| `databricks_connect_access_token` | Databricks Connect Only: Databricks Connect access token. Uses `access_token` if not set. | string | N |
| `databricks_connect_cluster_id` | Databricks Connect Only: Databricks Connect cluster ID. Uses `http_path` if not set. Cannot be a Databricks SQL Warehouse. | string | N |
| `databricks_connect_use_serverless` | Databricks Connect Only: Use a serverless cluster for Databricks Connect. If serverless is enabled, the SQL Connector is disabled, since Serverless Compute is not supported by the SQL Connector. | bool | N |
| `force_databricks_connect` | When running locally, force the use of Databricks Connect for all model operations (so don't use SQL Connector for SQL models) | bool | N |
| `disable_databricks_connect` | When running locally, disable the use of Databricks Connect for all model operations (so use SQL Connector for all models) | bool | N |
| `disable_spark_session` | Do not use SparkSession if it is available (like when running in a notebook). | bool | N |
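
For example, a serverless connection needs only a workspace hostname, a token, and the new flag. A minimal sketch using placeholder values (the hostname, token, and catalog below are illustrative, not real):

```python
from sqlmesh.core.config.connection import DatabricksConnectionConfig

# Placeholder workspace, token, and catalog; serverless needs no `http_path` or cluster ID.
config = DatabricksConnectionConfig(
    server_hostname="my-workspace.cloud.databricks.com",
    access_token="dapi-example-token",
    databricks_connect_use_serverless=True,
    catalog="main",
)
```

Because `databricks_connect_use_serverless` implies Databricks Connect, the connection validator (see `connection.py` below) forces `force_databricks_connect` on and leaves the SQL Connector unused.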

## Airflow Scheduler
**Engine Name:** `databricks` / `databricks-submit` / `databricks-sql`.
51 changes: 41 additions & 10 deletions sqlmesh/core/config/connection.py
@@ -528,6 +528,7 @@ class DatabricksConnectionConfig(ConnectionConfig):
databricks_connect_server_hostname: t.Optional[str] = None
databricks_connect_access_token: t.Optional[str] = None
databricks_connect_cluster_id: t.Optional[str] = None
databricks_connect_use_serverless: bool = False
force_databricks_connect: bool = False
disable_databricks_connect: bool = False
disable_spark_session: bool = False
@@ -550,24 +551,41 @@ def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str
bool(values.get("disable_spark_session"))
):
return values
databricks_connect_use_serverless = values.get("databricks_connect_use_serverless")
server_hostname, http_path, access_token = (
values.get("server_hostname"),
values.get("http_path"),
values.get("access_token"),
)
if databricks_connect_use_serverless:
values["force_databricks_connect"] = True
values["disable_databricks_connect"] = False
if (
not server_hostname or not http_path or not access_token
) and not databricks_connect_use_serverless:
raise ValueError(
"`server_hostname`, `http_path`, and `access_token` are required for Databricks connections when not running in a notebook"
)
if (
databricks_connect_use_serverless
and not server_hostname
and not values.get("databricks_connect_server_hostname")
):
raise ValueError(
"`server_hostname` or `databricks_connect_server_hostname` is required when `databricks_connect_use_serverless` is set"
)
if DatabricksEngineAdapter.can_access_databricks_connect(
bool(values.get("disable_databricks_connect"))
):
if not values.get("databricks_connect_server_hostname"):
values["databricks_connect_server_hostname"] = f"https://{server_hostname}"
if not values.get("databricks_connect_access_token"):
values["databricks_connect_access_token"] = access_token
if not values.get("databricks_connect_cluster_id"):
values["databricks_connect_cluster_id"] = http_path.split("/")[-1]
if not values.get("databricks_connect_server_hostname"):
values["databricks_connect_server_hostname"] = f"https://{server_hostname}"
if not databricks_connect_use_serverless:
if not values.get("databricks_connect_cluster_id"):
if t.TYPE_CHECKING:
assert http_path is not None
values["databricks_connect_cluster_id"] = http_path.split("/")[-1]
return values
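
As a hedged illustration of the normalization above (placeholder credentials, run outside a Databricks notebook), enabling serverless forces Databricks Connect on and skips deriving a cluster ID:

```python
from sqlmesh.core.config.connection import DatabricksConnectionConfig

config = DatabricksConnectionConfig(
    server_hostname="my-workspace.cloud.databricks.com",  # placeholder workspace
    access_token="dapi-example-token",                    # placeholder token
    databricks_connect_use_serverless=True,
)
print(config.force_databricks_connect)       # expected: True
print(config.disable_databricks_connect)     # expected: False
print(config.databricks_connect_cluster_id)  # expected: None, serverless has no cluster
```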

@property
@@ -612,7 +630,7 @@ def _connection_factory(self) -> t.Callable:

return connection

from databricks import sql # type: ignore

return sql.connect

@@ -635,14 +653,27 @@ def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:

from databricks.connect import DatabricksSession

if t.TYPE_CHECKING:
assert self.databricks_connect_server_hostname is not None
assert self.databricks_connect_access_token is not None

if self.databricks_connect_use_serverless:
builder = DatabricksSession.builder.remote(
host=self.databricks_connect_server_hostname,
token=self.databricks_connect_access_token,
serverless=True,
)
else:
if t.TYPE_CHECKING:
assert self.databricks_connect_cluster_id is not None
builder = DatabricksSession.builder.remote(
host=self.databricks_connect_server_hostname,
token=self.databricks_connect_access_token,
cluster_id=self.databricks_connect_cluster_id,
)
.userAgent("sqlmesh")
.getOrCreate(),

return dict(
spark=builder.userAgent("sqlmesh").getOrCreate(),
catalog=self.catalog,
)
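
For reference, the serverless branch reduces to the following Databricks Connect call; this is a sketch with placeholder credentials and assumes `databricks-connect` is installed and the token is valid:

```python
from databricks.connect import DatabricksSession

spark = (
    DatabricksSession.builder.remote(
        host="https://my-workspace.cloud.databricks.com",  # placeholder workspace URL
        token="dapi-example-token",                        # placeholder token
        serverless=True,                                   # replaces cluster_id on serverless compute
    )
    .userAgent("sqlmesh")
    .getOrCreate()
)
spark.sql("SELECT 1").show()
```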

66 changes: 58 additions & 8 deletions sqlmesh/core/engine_adapter/databricks.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import os
import typing as t

import pandas as pd
@@ -11,14 +12,15 @@
DataObject,
InsertOverwriteStrategy,
set_catalog,
SourceQuery,
)
from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter
from sqlmesh.core.schema_diff import SchemaDiffer
from sqlmesh.utils.errors import SQLMeshError

if t.TYPE_CHECKING:
from sqlmesh.core._typing import SchemaName, TableName
from sqlmesh.core.engine_adapter._typing import DF, PySparkSession, Query

logger = logging.getLogger(__name__)

@@ -74,13 +76,32 @@ def can_access_databricks_connect(cls, disable_databricks_connect: bool) -> bool
def _use_spark_session(self) -> bool:
if self.can_access_spark_session(bool(self._extra_config.get("disable_spark_session"))):
return True
return (
self.can_access_databricks_connect(
bool(self._extra_config.get("disable_databricks_connect"))
)
and (
{
"databricks_connect_server_hostname",
"databricks_connect_access_token",
}.issubset(self._extra_config)
)
and (
"databricks_connect_cluster_id" in self._extra_config
or "databricks_connect_use_serverless" in self._extra_config
)
)

@property
def use_serverless(self) -> bool:
from sqlmesh import RuntimeEnv
from sqlmesh.utils import str_to_bool

if not self._use_spark_session:
return False
return (
RuntimeEnv.get().is_databricks and str_to_bool(os.environ["IS_SERVERLESS"])
) or bool(self._extra_config["databricks_connect_use_serverless"])
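
In plain terms, the property treats a run as serverless either when the Databricks runtime itself reports it through the `IS_SERVERLESS` environment variable or when the connection config opts in via `databricks_connect_use_serverless`. A hypothetical standalone equivalent (the helper name and lenient defaults are illustrative; the property above uses direct key access):

```python
import os
import typing as t

def looks_serverless(extra_config: t.Dict[str, t.Any], on_databricks: bool) -> bool:
    # Serverless if the Databricks runtime says so, or if the config explicitly opts in.
    runtime_serverless = on_databricks and os.environ.get("IS_SERVERLESS", "false").lower() == "true"
    return runtime_serverless or bool(extra_config.get("databricks_connect_use_serverless", False))
```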

@property
def is_spark_session_cursor(self) -> bool:
@@ -117,6 +138,35 @@ def spark(self) -> PySparkSession:
self.set_current_catalog(catalog)
return self._spark

def _df_to_source_queries(
self,
df: DF,
columns_to_types: t.Dict[str, exp.DataType],
batch_size: int,
target_table: TableName,
) -> t.List[SourceQuery]:
if not self._use_spark_session:
return super(SparkEngineAdapter, self)._df_to_source_queries(
df, columns_to_types, batch_size, target_table
)
df = self._ensure_pyspark_df(df, columns_to_types)

def query_factory() -> Query:
temp_table = self._get_temp_table(target_table or "spark", table_only=True)
if self.use_serverless:
# Global temp views are not supported on Databricks Serverless
# This also means we can't mix Python SQL Connection and DB Connect since they wouldn't
# share the same temp objects.
df.createOrReplaceTempView(temp_table.sql(dialect=self.dialect)) # type: ignore
else:
df.createOrReplaceGlobalTempView(temp_table.sql(dialect=self.dialect)) # type: ignore
temp_table.set("db", "global_temp")
return exp.select(*self._casted_columns(columns_to_types)).from_(temp_table)

if self._use_spark_session:
return [SourceQuery(query_factory=query_factory)]
return super()._df_to_source_queries(df, columns_to_types, batch_size, target_table)
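
The branch above matters because Serverless Compute only supports session-scoped temp views, while classic clusters can share a global temp view across sessions. A minimal PySpark sketch of the two registration paths (illustrative view names; assumes a local or remote Spark session is available):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a")], ["id", "name"])

# Session-scoped temp view: the only option on Serverless, queried as `my_view`.
df.createOrReplaceTempView("my_view")
spark.sql("SELECT * FROM my_view").show()

# Global temp view: used on classic clusters, queried via the `global_temp` database.
df.createOrReplaceGlobalTempView("my_view_global")
spark.sql("SELECT * FROM global_temp.my_view_global").show()
```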

def _fetch_native_df(
self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
) -> DF: