Update databricks provider to use TriggerOperator #18999
Comments
Thanks for opening your first issue here! Be sure to follow the issue template!
It would be very interesting to work on this issue using the new Airflow functionality, i.e. implementing the polling asynchronously. I'd be happy to contribute to that.
@eskarimov I have implemented a prototype like this:

```python
import asyncio
import typing

import aiohttp
from aiohttp import ClientResponseError, ClientSession

from airflow.exceptions import AirflowException
from airflow.providers.databricks.hooks.databricks import (
    GET_RUN_ENDPOINT,
    DatabricksHook,
    RunState,
)
from airflow.triggers.base import BaseTrigger, TriggerEvent


# pylint: disable=abstract-method
class DatabricksHookAsync(DatabricksHook):
    """Async version of the Databricks hook."""

    async def get_run_state_async(self, run_id: str, session: ClientSession) -> RunState:
        json = {"run_id": run_id}
        response = await self._do_api_call_async(GET_RUN_ENDPOINT, json, session)
        state = response["state"]
        life_cycle_state = state["life_cycle_state"]
        # result_state may not be in the state if the run is not terminal
        result_state = state.get("result_state", None)
        state_message = state["state_message"]
        return RunState(life_cycle_state, result_state, state_message)

    async def _do_api_call_async(self, endpoint_info, json, session: ClientSession):
        """
        Utility function to perform an API call with retries.

        :param endpoint_info: Tuple of method and endpoint
        :type endpoint_info: tuple[string, string]
        :param json: Parameters for this API call.
        :type json: dict
        :return: If the API call returns an OK status code,
            this function returns the response in JSON. Otherwise,
            we throw an AirflowException.
        :rtype: dict
        """
        method, endpoint = endpoint_info
        self.databricks_conn = self.get_connection(self.databricks_conn_id)
        if "token" in self.databricks_conn.extra_dejson:
            self.log.info("Using token auth.")
            auth = {"Authorization": "Bearer " + self.databricks_conn.extra_dejson["token"]}
            if "host" in self.databricks_conn.extra_dejson:
                host = self._parse_host(self.databricks_conn.extra_dejson["host"])
            else:
                host = self.databricks_conn.host
        else:
            raise AirflowException("DatabricksHookAsync only supports token auth")
        url = f"https://{self._parse_host(host)}/{endpoint}"  # type: ignore
        if method == "GET":
            request_func = session.get
        elif method == "POST":
            request_func = session.post
        elif method == "PATCH":
            request_func = session.patch
        else:
            raise AirflowException("Unexpected HTTP Method: " + method)
        attempt_num = 1
        while True:
            try:
                response = await request_func(
                    url,
                    json=json if method in ("POST", "PATCH") else None,
                    params=json if method == "GET" else None,
                    headers=auth,
                    timeout=self.timeout_seconds,
                )
                response.raise_for_status()
                return await response.json()
            except ClientResponseError as err:
                if err.status < 500:
                    # In this case, the user probably made a mistake. Don't retry.
                    # pylint: disable=raise-missing-from
                    raise AirflowException(f"Response: {err.message}, Status Code: {err.status}")
                if attempt_num == self.retry_limit:
                    raise AirflowException(
                        f"API requests to Databricks failed {self.retry_limit} times. Giving up."
                    )
                attempt_num += 1
                await asyncio.sleep(self.retry_delay)


class DatabricksJobTrigger(BaseTrigger):
    """A trigger that checks every 15 seconds whether the Databricks job is finished."""

    def __init__(self, run_id: str, databricks_conn_id):
        super().__init__()
        self.run_id = run_id
        self.databricks_conn_id = databricks_conn_id

    def serialize(self) -> typing.Tuple[str, typing.Dict[str, typing.Any]]:
        return (
            "operators.submit_to_databricks_operator.DatabricksJobTrigger",
            {
                "run_id": self.run_id,
                "databricks_conn_id": self.databricks_conn_id,
            },
        )

    async def run(self):
        hook = DatabricksHookAsync(self.databricks_conn_id)
        async with aiohttp.ClientSession() as session:
            while True:
                run_state = await hook.get_run_state_async(self.run_id, session)
                if run_state.is_terminal:
                    if run_state.is_successful:
                        self.log.info("Run id: %s completed successfully.", self.run_id)
                    else:
                        self.log.info("Run id: %s completed and failed.", self.run_id)
                    # Emit the event and stop; the triggerer shuts the trigger down.
                    yield TriggerEvent((self.run_id, run_state.result_state))
                    return
                await asyncio.sleep(15)
```

One issue I have with the above is that the task duration does not count the time the task was deferred. I'm sure many improvements could be made, but this should work.
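For completeness, here is a rough sketch of how an operator could hand off to this trigger. The operator class and its parameters are illustrative assumptions, not part of the prototype; only `DatabricksHook.submit_run` and `BaseOperator.defer` come from existing APIs:

```python
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.databricks.hooks.databricks import DatabricksHook


class DatabricksSubmitRunDeferrableOperator(BaseOperator):
    """Hypothetical operator: submits a run, then defers to DatabricksJobTrigger."""

    def __init__(self, *, json: dict, databricks_conn_id: str = "databricks_default", **kwargs):
        super().__init__(**kwargs)
        self.json = json
        self.databricks_conn_id = databricks_conn_id

    def execute(self, context):
        # Submit the job synchronously, then release the worker slot;
        # the triggerer process takes over the polling from here.
        run_id = DatabricksHook(databricks_conn_id=self.databricks_conn_id).submit_run(self.json)
        self.defer(
            trigger=DatabricksJobTrigger(run_id=run_id, databricks_conn_id=self.databricks_conn_id),
            method_name="execute_complete",
        )

    def execute_complete(self, context, event=None):
        # The event payload is the (run_id, result_state) tuple yielded by the trigger.
        run_id, result_state = event
        if result_state != "SUCCESS":
            raise AirflowException(f"Databricks run {run_id} finished with state {result_state}")
        self.log.info("Databricks run %s completed successfully.", run_id)
```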
@chinwobble thank you so much, it's a good base to start working on! 👍 Regarding the task duration - I haven't checked yet, but I thought it was implemented so that execution time is counted over the total time, so an operator might fail even while it's deferred.
@eskarimov I've checked the task duration.
@chinwobble you're right, I see the same behaviour... I've spotted something, though, when I don't pass …. Also, we might have found a bug, because I see a code part in ….
We are using LocalExecutor and it seems like the deferral consumes an unnecessary number of task slots.
FYI, just wanted to share that I'm still willing to introduce deferrable operators for Databricks and am working on the code. Regarding the timeout - I've created a separate issue #19382 to investigate and fix it, and shared the findings there.
@eskarimov Happy to review your code. I'm not the best Python dev, so I think I could learn a thing or two from you. These deferrable operators are a good first step in making Airflow more scalable.
@eskarimov nice work getting this one merged!
Description
The Databricks provider can be updated to use the deferrable tasks introduced in Airflow 2.2.0. This will significantly reduce CPU and memory usage on the LocalExecutor and improve reliability, since state is stored in the Airflow metastore.
Use case/motivation
Currently the Databricks operators work by calling the Databricks REST API to submit a Spark job and then polling it every 20 seconds to check whether it is done. This creates several inefficiencies: while the remote job runs, the task does nothing but wait, yet it still occupies a worker slot and consumes CPU and memory for the entire duration of the run, as the sketch below illustrates.
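A condensed sketch of that current pattern using the existing synchronous hook (the connection id and job spec values are placeholders):

```python
import time

from airflow.providers.databricks.hooks.databricks import DatabricksHook

hook = DatabricksHook(databricks_conn_id="databricks_default")
job_spec = {  # placeholder job definition
    "new_cluster": {"spark_version": "9.1.x-scala2.12", "num_workers": 2},
    "notebook_task": {"notebook_path": "/Example"},
}
run_id = hook.submit_run(job_spec)

# The worker slot stays occupied for this entire loop, even though the
# task is doing nothing but waiting on the remote job.
while True:
    run_state = hook.get_run_state(run_id)
    if run_state.is_terminal:
        break
    time.sleep(20)
```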
Related issues
No response
Are you willing to submit a PR?
Code of Conduct