Skip to content

Commit

Permalink
Merge pull request #390 from databricks/dlt
Browse files Browse the repository at this point in the history
Support for materialized views and streaming tables
  • Loading branch information
rcypher-databricks authored Jul 19, 2023
2 parents 2edb61d + d2b1e31 commit 33dca4b
Show file tree
Hide file tree
Showing 31 changed files with 817 additions and 24 deletions.
228 changes: 222 additions & 6 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import sys
import threading
import time
import requests
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -392,16 +393,31 @@ class DatabricksSQLConnectionWrapper:
_conn: DatabricksSQLConnection
_is_cluster: bool
_cursors: List[DatabricksSQLCursor]
_creds: DatabricksCredentials
_user_agent: str

def __init__(self, conn: DatabricksSQLConnection, *, is_cluster: bool):
def __init__(
    self,
    conn: DatabricksSQLConnection,
    *,
    is_cluster: bool,
    creds: DatabricksCredentials,
    user_agent: str,
):
    """Wrap an open Databricks SQL connection.

    Credentials and the user-agent string are retained so that cursors
    created from this wrapper can issue REST calls for the same session.
    """
    self._conn = conn
    self._creds = creds
    self._user_agent = user_agent
    self._is_cluster = is_cluster
    # Track every cursor handed out so cancel() can reach all of them.
    self._cursors = []

def cursor(self) -> "DatabricksSQLCursorWrapper":
    """Open a new cursor on the underlying connection and wrap it."""
    raw_cursor = self._conn.cursor()
    # Remember the raw cursor so a later cancel() can reach it as well.
    self._cursors.append(raw_cursor)
    wrapped = DatabricksSQLCursorWrapper(
        raw_cursor, creds=self._creds, user_agent=self._user_agent
    )
    return wrapped

def cancel(self) -> None:
cursors: List[DatabricksSQLCursor] = self._cursors
Expand Down Expand Up @@ -452,9 +468,13 @@ class DatabricksSQLCursorWrapper:
"""Wrap a Databricks SQL cursor in a way that no-ops transactions"""

_cursor: DatabricksSQLCursor
_user_agent: str
_creds: DatabricksCredentials

def __init__(self, cursor: DatabricksSQLCursor):
def __init__(self, cursor: DatabricksSQLCursor, creds: DatabricksCredentials, user_agent: str):
    """Wrap a raw cursor; keep credentials/user-agent for pipeline REST polling."""
    self._cursor = cursor
    self._user_agent = user_agent
    self._creds = creds

def cancel(self) -> None:
try:
Expand All @@ -477,12 +497,116 @@ def fetchone(self) -> Optional[Tuple]:
return self._cursor.fetchone()

def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> None:
    """Execute *sql* on the wrapped cursor.

    Strips a single trailing semicolon, converts bindings via
    ``_fix_binding``, runs the statement, and — when the statement refreshes
    a materialized view or streaming table — blocks until the backing DLT
    pipeline update finishes (see ``pollRefreshPipeline``).
    """
    # Strip once instead of twice; only reassign sql when a trailing ";" is
    # actually removed, matching the original behavior exactly.
    stripped = sql.strip()
    if stripped.endswith(";"):
        sql = stripped[:-1]
    if bindings is not None:
        bindings = [self._fix_binding(binding) for binding in bindings]
    self._cursor.execute(sql, bindings)

    # If the command refreshed a materialized view / streaming table we must
    # poll the owning pipeline until the refresh is finished.
    self.pollRefreshPipeline(sql)

def pollRefreshPipeline(
    self,
    sql: str,
) -> None:
    """Block until the DLT pipeline refresh triggered by *sql* finishes.

    No-op unless *sql* is a ``refresh materialized view`` or a
    ``create or refresh streaming table`` statement (see
    ``_should_poll_refresh``). Raises DbtRuntimeError if the refresh
    fails, is cancelled, or does not finish within the timeout.
    """
    should_poll, model_name = _should_poll_refresh(sql)
    if not should_poll:
        return

    # interval in seconds
    polling_interval = 10

    # timeout in seconds
    timeout = 60 * 60

    # terminal update states; polling stops once one of these is reached
    stopped_states = ("COMPLETED", "FAILED", "CANCELED")
    host: str = self._creds.host or ""
    # reuse the SQL connection's auth headers for the REST calls below,
    # tagging them with our user agent
    headers = self._cursor.connection.thrift_backend._auth_provider._header_factory()
    headers["User-Agent"] = self._user_agent

    pipeline_id = _get_table_view_pipeline_id(host, headers, model_name)
    pipeline = _get_pipeline_state(host, headers, pipeline_id)
    # get the most recently created update for the pipeline
    latest_update = _find_update(pipeline)
    if not latest_update:
        raise dbt.exceptions.DbtRuntimeError(f"No update created for pipeline: {pipeline_id}")

    state = latest_update.get("state")
    # we use update_id to retrieve the update in the polling loop
    update_id = latest_update.get("update_id", "")
    prev_state = state

    logger.info(
        f"refreshing {model_name}, pipeline: {pipeline_id}, update: {update_id} {state}"
    )

    start = time.time()
    exceeded_timeout = False
    while state not in stopped_states:
        if time.time() - start > timeout:
            exceeded_timeout = True
            break

        # should we do exponential backoff?
        time.sleep(polling_interval)

        pipeline = _get_pipeline_state(host, headers, pipeline_id)
        # get the update we are currently polling
        update = _find_update(pipeline, update_id)
        if not update:
            raise dbt.exceptions.DbtRuntimeError(
                f"Error getting pipeline update info: {pipeline_id}, update: {update_id}"
            )

        state = update.get("state")
        # log only on state transitions to avoid flooding the log every poll
        if state != prev_state:
            logger.info(
                f"refreshing {model_name}, pipeline: {pipeline_id}, update: {update_id} {state}"
            )
            prev_state = state

        if state == "FAILED":
            logger.error(f"pipeline {pipeline_id} update {update_id} failed")
            msg = _get_update_error_msg(host, headers, pipeline_id, update_id)
            if msg:
                logger.error(msg)

            # another update may have been created due to retry_on_fail settings
            # get the latest update and see if it is a new one
            latest_update = _find_update(pipeline)
            if not latest_update:
                raise dbt.exceptions.DbtRuntimeError(
                    f"No update created for pipeline: {pipeline_id}"
                )

            latest_update_id = latest_update.get("update_id", "")
            if latest_update_id != update_id:
                # a retry started: switch to polling the new update and
                # clear state so the loop keeps going
                update_id = latest_update_id
                state = None

    if exceeded_timeout:
        raise dbt.exceptions.DbtRuntimeError("timed out waiting for materialized view refresh")

    if state == "FAILED":
        msg = _get_update_error_msg(host, headers, pipeline_id, update_id)
        raise dbt.exceptions.DbtRuntimeError(f"error refreshing model {model_name} {msg}")

    if state == "CANCELED":
        raise dbt.exceptions.DbtRuntimeError(f"refreshing model {model_name} cancelled")

    return

@classmethod
def findUpdate(cls, updates: List, id: str) -> Optional[Dict]:
    """Return the entry in *updates* whose ``update_id`` equals *id*, or None."""
    return next((u for u in updates if u.get("update_id") == id), None)

@property
def hex_query_id(self) -> str:
"""Return the hex GUID for this query
Expand Down Expand Up @@ -639,7 +763,8 @@ def add_query(

fire_event(
SQLQueryStatus(
status=str(self.get_response(cursor)), elapsed=round((time.time() - pre), 2)
status=str(self.get_response(cursor)),
elapsed=round((time.time() - pre), 2),
)
)

Expand Down Expand Up @@ -687,7 +812,8 @@ def _execute_cursor(

fire_event(
SQLQueryStatus(
status=str(self.get_response(cursor)), elapsed=round((time.time() - pre), 2)
status=str(self.get_response(cursor)),
elapsed=round((time.time() - pre), 2),
)
)

Expand Down Expand Up @@ -748,7 +874,12 @@ def connect() -> DatabricksSQLConnectionWrapper:
_user_agent_entry=user_agent_entry,
**connection_parameters,
)
return DatabricksSQLConnectionWrapper(conn, is_cluster=creds.cluster_id is not None)
return DatabricksSQLConnectionWrapper(
conn,
is_cluster=creds.cluster_id is not None,
creds=creds,
user_agent=user_agent_entry,
)
except Error as exc:
_log_dbsql_errors(exc)
raise
Expand Down Expand Up @@ -787,3 +918,88 @@ def _log_dbsql_errors(exc: Exception) -> None:
logger.debug(f"{type(exc)}: {exc}")
for key, value in sorted(exc.context.items()):
logger.debug(f"{key}: {value}")


def _should_poll_refresh(sql: str) -> Tuple[bool, str]:
# if the command was to refresh a materialized view we need to poll
# the pipeline until the refresh is finished.
name = ""
refresh_search = re.search(r"refresh\s+materialized\s+view\s+([`\w.]+)", sql)
if not refresh_search:
refresh_search = re.search(r"create\s+or\s+refresh\s+streaming\s+table\s+([`\w.]+)", sql)

if refresh_search:
name = refresh_search.group(1).replace("`", "")

return refresh_search is not None, name


def _get_table_view_pipeline_id(host: str, headers: dict, name: str) -> str:
    """Look up the DLT pipeline id backing the materialized view / streaming table *name*.

    Queries the Unity Catalog tables API on *host* using the caller's auth
    *headers*. Raises DbtRuntimeError when the lookup fails or the table has
    no ``pipeline_id`` field.
    """
    table_url = f"https://{host}/api/2.1/unity-catalog/tables/{name}"
    # Explicit timeout: requests has no default and would otherwise hang forever.
    response = requests.get(table_url, headers=headers, timeout=60)
    if response.status_code != 200:
        raise dbt.exceptions.DbtRuntimeError(
            f"Error getting info for materialized view/streaming table: {name}"
        )

    pipeline_id = response.json().get("pipeline_id", "")
    if not pipeline_id:
        raise dbt.exceptions.DbtRuntimeError(
            f"Materialized view/streaming table {name} does not have a pipeline id"
        )

    return pipeline_id


def _get_pipeline_state(host: str, headers: dict, pipeline_id: str) -> dict:
    """Fetch the current state of DLT pipeline *pipeline_id* via the Pipelines REST API.

    Returns the decoded JSON body; raises DbtRuntimeError on a non-200 response.
    """
    pipeline_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}"

    # Explicit timeout: requests has no default and would otherwise hang forever.
    response = requests.get(pipeline_url, headers=headers, timeout=60)
    if response.status_code != 200:
        raise dbt.exceptions.DbtRuntimeError(f"Error getting pipeline info: {pipeline_id}")

    return response.json()


def _find_update(pipeline: dict, id: str = "") -> Optional[Dict]:
updates = pipeline.get("latest_updates", [])
if not updates:
raise dbt.exceptions.DbtRuntimeError(
f"No updates for pipeline: {pipeline.get('pipeline_id', '')}"
)

if not id:
return updates[0]

matches = [x for x in updates if x.get("update_id") == id]
if matches:
return matches[0]

return None


def _get_update_error_msg(host: str, headers: dict, pipeline_id: str, update_id: str) -> str:
    """Return the error message for a failed pipeline update, or "" when none is found.

    Scans the pipeline's event log for ``update_progress`` events belonging to
    *update_id* with state FAILED and returns the first such event's message.
    Raises DbtRuntimeError when the events endpoint cannot be read.
    """
    events_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}/events"
    # Explicit timeout: requests has no default and would otherwise hang forever.
    response = requests.get(events_url, headers=headers, timeout=60)
    if response.status_code != 200:
        raise dbt.exceptions.DbtRuntimeError(f"Error getting pipeline event info: {pipeline_id}")

    events = response.json().get("events", [])
    # Single pass instead of two chained filters; order (and therefore the
    # selected first event) is unchanged.
    failed_events = [
        e
        for e in events
        if e.get("event_type", "") == "update_progress"
        and e.get("origin", {}).get("update_id") == update_id
        and e.get("details", {}).get("update_progress", {}).get("state", "") == "FAILED"
    ]

    return failed_events[0].get("message", "") if failed_events else ""
Loading

0 comments on commit 33dca4b

Please sign in to comment.