Update databricks provider to use TriggerOperator #18999

Closed
1 of 2 tasks
chinwobble opened this issue Oct 15, 2021 · 10 comments · Fixed by #19736
Labels: kind:feature Feature Requests

Comments

@chinwobble
Contributor

Description

The Databricks provider can be updated to use the deferrable tasks introduced in Airflow 2.2.0. This will significantly reduce CPU and memory usage on LocalExecutor and improve reliability, since state is stored in the Airflow metastore.

Use case/motivation

Currently the Databricks operators work by calling the Databricks REST API to submit a Spark job and then polling it every 20 seconds to check when it is done. This creates the following inefficiencies (a simplified sketch of this blocking polling pattern follows the list):

  • If the Airflow executor process crashes, a duplicate job can be created in Databricks, since Airflow doesn't save the Databricks job run id.
  • Each task instance of the Databricks operators runs in its own process; however, 95% of the time the process is idle, waiting to re-poll the Databricks API. If you want to run 50 jobs in parallel, you will use a non-trivial amount of memory / CPU just to poll a REST API.
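
For illustration, a minimal sketch of the blocking poll loop the synchronous operators effectively run (the helper function is hypothetical; the polling call corresponds to the existing DatabricksHook.get_run_state method):

import time

from airflow.exceptions import AirflowException
from airflow.providers.databricks.hooks.databricks import DatabricksHook


def poll_run_until_done(databricks_conn_id: str, run_id, poll_interval: int = 20) -> None:
    """Roughly what the synchronous operator does after submitting a run."""
    hook = DatabricksHook(databricks_conn_id)
    while True:
        run_state = hook.get_run_state(run_id)  # blocking HTTP call
        if run_state.is_terminal:
            if not run_state.is_successful:
                raise AirflowException(f"Run {run_id} failed: {run_state.state_message}")
            return
        time.sleep(poll_interval)  # the worker process sits idle between polls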

Related issues

No response

Are you willing to submit a PR?

  • Yes I am willing to submit a PR!

chinwobble added the kind:feature Feature Requests label Oct 15, 2021
@boring-cyborg

boring-cyborg bot commented Oct 15, 2021

Thanks for opening your first issue here! Be sure to follow the issue template!

@eskarimov
Contributor

It would be very interesting to work on this issue, using the new Airflow functionality, i.e. implementing the polling asynchronously. I'd be happy to contribute to that.

@chinwobble
Contributor Author

chinwobble commented Oct 27, 2021

@eskarimov I have implemented a prototype like this:

One issue I have with the approach above is that the task duration does not count the time the task was deferred.

I'm sure many improvements could be made but this should work.

# Imports assumed for this prototype (module paths per Airflow 2.2 and the
# Databricks provider at the time; adjust if they have moved).
import asyncio
import typing

import aiohttp
from aiohttp import ClientResponseError, ClientSession

from airflow.exceptions import AirflowException
from airflow.providers.databricks.hooks.databricks import (
    GET_RUN_ENDPOINT,
    DatabricksHook,
    RunState,
)
from airflow.triggers.base import BaseTrigger, TriggerEvent


# pylint: disable=abstract-method
class DatabricksHookAsync(DatabricksHook):
    """Async version of the Databricks hook."""

    async def get_run_state_async(
        self, run_id: str, session: ClientSession
    ) -> RunState:
        json = {"run_id": run_id}
        response = await self._do_api_call_async(GET_RUN_ENDPOINT, json, session)
        state = response["state"]
        life_cycle_state = state["life_cycle_state"]
        # result_state may not be in the state if not terminal
        result_state = state.get("result_state", None)
        state_message = state["state_message"]
        return RunState(life_cycle_state, result_state, state_message)

    async def _do_api_call_async(self, endpoint_info, json, session: ClientSession):
        """
        Utility function to perform an API call with retries
        :param endpoint_info: Tuple of method and endpoint
        :type endpoint_info: tuple[string, string]
        :param json: Parameters for this API call.
        :type json: dict
        :return: If the api call returns a OK status code,
            this function returns the response in JSON. Otherwise,
            we throw an AirflowException.
        :rtype: dict
        """
        method, endpoint = endpoint_info

        self.databricks_conn = self.get_connection(self.databricks_conn_id)

        if "token" in self.databricks_conn.extra_dejson:
            self.log.info("Using token auth. ")
            auth = {
                "Authorization": "Bearer " + self.databricks_conn.extra_dejson["token"]
            }
            if "host" in self.databricks_conn.extra_dejson:
                host = self._parse_host(self.databricks_conn.extra_dejson["host"])
            else:
                host = self.databricks_conn.host
        else:
            raise AirflowException("DatabricksHookAsync only supports token Auth")

        url = f"https://{self._parse_host(host)}/{endpoint}"  # type: ignore

        if method == "GET":
            request_func = session.get
        elif method == "POST":
            request_func = session.post
        elif method == "PATCH":
            request_func = session.patch
        else:
            raise AirflowException("Unexpected HTTP Method: " + method)

        attempt_num = 1
        while True:
            try:
                response = await request_func(
                    url,
                    json=json if method in ("POST", "PATCH") else None,
                    params=json if method == "GET" else None,
                    headers=auth,
                    timeout=self.timeout_seconds,
                )
                response.raise_for_status()
                return await response.json()
            except ClientResponseError as err:
                if err.status < 500:
                    # In this case, the user probably made a mistake.
                    # Don't retry.
                    # pylint: disable=raise-missing-from
                    raise AirflowException(
                        f"Response: {err.message}, Status Code: {err.status}"
                    )

            if attempt_num == self.retry_limit:
                raise AirflowException(
                    (
                        "API requests to Databricks failed {} times. " + "Giving up."
                    ).format(self.retry_limit)
                )

            attempt_num += 1
            await asyncio.sleep(self.retry_delay)


class DatabricksJobTrigger(BaseTrigger):
    """A trigger that checks every 15 seconds whether the databricks job is finished"""

    def __init__(self, run_id: str, databricks_conn_id):
        super().__init__()
        self.run_id = run_id
        self.databricks_conn_id = databricks_conn_id

    def serialize(self) -> typing.Tuple[str, typing.Dict[str, typing.Any]]:
        return (
            "operators.submit_to_databricks_operator.DatabricksJobTrigger",
            {
                "run_id": self.run_id,
                "databricks_conn_id": self.databricks_conn_id,
            },
        )

    async def run(self):
        hook = DatabricksHookAsync(self.databricks_conn_id)
        async with aiohttp.ClientSession() as session:
            while True:
                run_state = await hook.get_run_state_async(self.run_id, session)
                if run_state.is_terminal:
                    if run_state.is_successful:
                        self.log.info("Run id: %s completed successfully.", self.run_id)
                    else:
                        self.log.info("Run id: %s completed and failed.", self.run_id)
                    yield TriggerEvent((self.run_id, run_state.result_state))
                    # Exit once the run is terminal; otherwise the loop would keep
                    # polling (and re-firing) a finished run.
                    return
                await asyncio.sleep(15)

@eskarimov
Contributor

@chinwobble thank you so much, it's a good base to start working on! 👍

Regarding the task duration - I haven't checked yet, but I thought it was implemented so that execution time is counted against the total time, so an operator might fail even while it's deferred.

@chinwobble
Contributor Author

chinwobble commented Oct 31, 2021

> @chinwobble thank you so much, it's a good base to start working on! 👍
>
> Regarding the task duration - I haven't checked yet, but I thought it was implemented so that execution time is counted against the total time, so an operator might fail even while it's deferred.

@eskarimov I've checked the task duration.
If your task defers, that deferral time is not considered as part of the task duration.

@eskarimov
Contributor

eskarimov commented Nov 1, 2021

@chinwobble you're right, I see the same behaviour... I've noticed, though, that when I don't pass a timeout argument to self.defer(), a deferred task will fail if its runtime is more than the execution_timeout set for the Operator calling it.
In other words, the timeout for a deferred task is equal to the execution timeout of the Operator in this case. Here is the code behind it.

Also, we might have found a bug, because I see a code path in TaskInstance that tries to check for the timeout and re-calculate it when a task is executed after the deferred state. The issue is that self.start_date changes every time a task is resumed after the deferred state.
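
For reference, a minimal sketch of deferring with an explicit timeout (the trigger class and execute_complete name come from the prototype above; the timedelta value is just an example):

from datetime import timedelta

self.defer(
    trigger=DatabricksJobTrigger(run_id=run_id, databricks_conn_id=self.databricks_conn_id),
    method_name="execute_complete",
    timeout=timedelta(hours=2),  # upper bound on how long the task may stay deferred
)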

@chinwobble
Contributor Author

chinwobble commented Nov 2, 2021

We are using LocalExecutor and it seems like deferral consumes an unnecessary number of processes.
When implementing deferral, it looks like this is what is happening (a sketch of the operator side follows the list):

  1. SchedulerJob decides a task is ready to be queued.
  2. Queued tasks run on the LocalExecutor and start a new process, consuming 1 worker slot.
  3. The execute method is started; it takes about 1 second to submit the job to Databricks.
  4. self.defer() is run and queues a trigger.
  5. This raises an exception that causes the task process to terminate, and the worker slot is returned to the pool.
  6. The triggerer is a separate process that asynchronously processes triggers and determines when to resume a task.
  7. The trigger yields a result, causing the scheduler to requeue the original Databricks task, possibly resuming it in a different method.
  8. A new executor worker process (a second process start) is started, consuming 1 worker slot; when complete, it marks the task as succeeded or failed.
  9. The task is complete and all workers are returned to the pool.
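
To make the flow above concrete, here is a minimal sketch of the operator side of this lifecycle (the operator subclass and execute_complete name are illustrative and build on the prototype trigger above; submit_run is the existing DatabricksHook call):

from airflow.exceptions import AirflowException
from airflow.providers.databricks.hooks.databricks import DatabricksHook
from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator


class DatabricksSubmitRunDeferrableOperator(DatabricksSubmitRunOperator):
    """Illustrative deferrable variant: submit, defer, then finish in execute_complete."""

    def execute(self, context):
        hook = DatabricksHook(self.databricks_conn_id)
        run_id = hook.submit_run(self.json)  # step 3: submit the job (~1 second)
        # Steps 4-5: hand the polling over to the triggerer and free the worker slot.
        self.defer(
            trigger=DatabricksJobTrigger(run_id=run_id, databricks_conn_id=self.databricks_conn_id),
            method_name="execute_complete",
        )

    def execute_complete(self, context, event):
        # Step 8: a fresh worker process resumes here with the TriggerEvent payload.
        run_id, result_state = event
        if result_state != "SUCCESS":
            raise AirflowException(f"Databricks run {run_id} finished with state {result_state}")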

@eskarimov
Contributor

FYI, just wanted to share that I'm still willing to introduce deferrable operators for Databricks and am working on the code.
Hope to create a PR later this week.
There are a couple of points I'd like to discuss with someone more experienced than me. @chinwobble, I'd be very happy if you reviewed the code; I'll add you to the reviewers list if you don't mind.

Regarding the timeout - I've created a separate issue #19382 to investigate and fix it, and shared the findings there.
@chinwobble could I add your comment above into the issue description? It's a nice summary of how it works internally.

@chinwobble
Contributor Author

@eskarimov Happy to review your code. I'm not the best python dev, so I think I could learn a thing or two from you.

#18999 (comment)

These deferrable operators are a good first step in making Airflow more scalable.
However, there are a few issues that make me feel it's not ready yet:

  • Task duration during deferral is not counted.
  • For the Databricks use case, an unnecessary number of processes are started / killed.
  • More polling delay is introduced: for our pipeline that takes 2.5 hours / day, an extra 20 minutes was added after implementing deferrable operators.

@chinwobble
Contributor Author

@eskarimov nice work getting this one merged!
