From 1e2a1ae53a34541fceff622a9820393e5326eb22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Thu, 1 Jun 2023 14:08:50 -0700 Subject: [PATCH 1/4] introducing a default method for deferrable continuation aims: - reduce the amount of code we need to write when making operators deferrable - reduce code sensibility to copy-paste mistakes, both: - in the text of the messages logged (I made that mistake already) - in the structure of the trigger events, by making them constants/built in one place - a unified way to pass events -> makes the code less surprising to read, i.e. easier to understand it doesn't prevent from using a custom way of doing things if needed, but abstracts away the boilerplate code when nothing interesting happens. I converted the glue operators to this as a demo. If it's pushed, I can do a round on existing deferrable implems to convert them to this. --- airflow/models/baseoperator.py | 13 +++++++++++-- airflow/providers/amazon/aws/operators/glue.py | 7 ------- .../providers/amazon/aws/operators/glue_crawler.py | 7 ------- airflow/providers/amazon/aws/triggers/glue.py | 4 ++-- .../providers/amazon/aws/triggers/glue_crawler.py | 7 ++----- airflow/triggers/base.py | 11 +++++++++++ 6 files changed, 26 insertions(+), 23 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 6392ef9697647..5610ffacb128b 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -81,7 +81,7 @@ from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep -from airflow.triggers.base import BaseTrigger +from airflow.triggers.base import BaseTrigger, TriggerEvent from airflow.utils import timezone from airflow.utils.context import Context from airflow.utils.decorators import fixup_decorator_warning_stack @@ -1574,7 +1574,7 @@ def defer( 
self, *, trigger: BaseTrigger, - method_name: str, + method_name: str = "execute_complete", kwargs: dict[str, Any] | None = None, timeout: timedelta | None = None, ): @@ -1587,6 +1587,15 @@ def defer( """ raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) + def execute_complete(self, context, event=None): + """The default method for handling the event returned after the deferred operation completes.""" + op_name = type(self).__name__ + if event is None or event["status"] != TriggerEvent.STATUS_SUCCESS: + raise AirflowException(f"{op_name}'s deferred operation was not completed successfully: {event}") + else: + self.log.info("%s completed successfully", op_name) + return event.get("value") + def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator: """Get the "normal" operator from the current operator. diff --git a/airflow/providers/amazon/aws/operators/glue.py b/airflow/providers/amazon/aws/operators/glue.py index 053e530c72674..b9844519ed9bb 100644 --- a/airflow/providers/amazon/aws/operators/glue.py +++ b/airflow/providers/amazon/aws/operators/glue.py @@ -21,7 +21,6 @@ import urllib.parse from typing import TYPE_CHECKING, Sequence -from airflow import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.glue import GlueJobHook from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -183,7 +182,6 @@ def execute(self, context: Context): verbose=self.verbose, aws_conn_id=self.aws_conn_id, ), - method_name="execute_complete", ) elif self.wait_for_completion: glue_job_run = glue_job.job_completion(self.job_name, glue_job_run["JobRunId"], self.verbose) @@ -196,8 +194,3 @@ def execute(self, context: Context): else: self.log.info("AWS Glue Job: %s. 
Run Id: %s", self.job_name, glue_job_run["JobRunId"]) return glue_job_run["JobRunId"] - - def execute_complete(self, context, event=None): - if event["status"] != "success": - raise AirflowException(f"Error in glue job: {event}") - return diff --git a/airflow/providers/amazon/aws/operators/glue_crawler.py b/airflow/providers/amazon/aws/operators/glue_crawler.py index c7ac25f1f2e30..b208737d8e70f 100644 --- a/airflow/providers/amazon/aws/operators/glue_crawler.py +++ b/airflow/providers/amazon/aws/operators/glue_crawler.py @@ -20,7 +20,6 @@ from functools import cached_property from typing import TYPE_CHECKING, Sequence -from airflow import AirflowException from airflow.providers.amazon.aws.triggers.glue_crawler import GlueCrawlerCompleteTrigger if TYPE_CHECKING: @@ -96,15 +95,9 @@ def execute(self, context: Context): poll_interval=self.poll_interval, aws_conn_id=self.aws_conn_id, ), - method_name="execute_complete", ) elif self.wait_for_completion: self.log.info("Waiting for AWS Glue Crawler") self.hook.wait_for_crawler_completion(crawler_name=crawler_name, poll_interval=self.poll_interval) return crawler_name - - def execute_complete(self, context, event=None): - if event["status"] != "success": - raise AirflowException(f"Error in glue crawl: {event}") - return diff --git a/airflow/providers/amazon/aws/triggers/glue.py b/airflow/providers/amazon/aws/triggers/glue.py index 42219a993ad25..6004bc1afde22 100644 --- a/airflow/providers/amazon/aws/triggers/glue.py +++ b/airflow/providers/amazon/aws/triggers/glue.py @@ -59,5 +59,5 @@ def serialize(self) -> tuple[str, dict[str, Any]]: async def run(self) -> AsyncIterator[TriggerEvent]: hook = GlueJobHook(aws_conn_id=self.aws_conn_id) - await hook.async_job_completion(self.job_name, self.run_id, self.verbose) - yield TriggerEvent({"status": "success", "message": "Job done"}) + glue_job_run = await hook.async_job_completion(self.job_name, self.run_id, self.verbose) + yield TriggerEvent.success(glue_job_run["JobRunId"]) diff 
--git a/airflow/providers/amazon/aws/triggers/glue_crawler.py b/airflow/providers/amazon/aws/triggers/glue_crawler.py index 10ab45dda73d3..02d686b5b963a 100644 --- a/airflow/providers/amazon/aws/triggers/glue_crawler.py +++ b/airflow/providers/amazon/aws/triggers/glue_crawler.py @@ -68,11 +68,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]: break # we reach this point only if the waiter met a success criteria except WaiterError as error: if "terminal failure" in str(error): - yield TriggerEvent( - {"status": "failure", "message": f"Glue Crawler creation Failed: {error}"} - ) - break + raise self.log.info("Status of glue crawl is %s", error.last_response["Crawler"]["State"]) await asyncio.sleep(int(self.poll_interval)) - yield TriggerEvent({"status": "success", "message": "Crawl Complete"}) + yield TriggerEvent.success(self.crawler_name) diff --git a/airflow/triggers/base.py b/airflow/triggers/base.py index 314d97b0ee919..8ef81506c2366 100644 --- a/airflow/triggers/base.py +++ b/airflow/triggers/base.py @@ -109,9 +109,20 @@ class TriggerEvent: events. 
""" + STATUS_SUCCESS = "success" + def __init__(self, payload: Any): self.payload = payload + @classmethod + def success(cls, value: Any = None) -> TriggerEvent: + """ + Creates a TriggerEvent to be returned by a deferred operation that completed successfully + + :param value: the value to be returned by the operator on completion + """ + return TriggerEvent({"status": cls.STATUS_SUCCESS, "value": value}) + def __repr__(self) -> str: return f"TriggerEvent<{self.payload!r}>" From 55f4c2aae58d9603209914a81f3f8e90b8510927 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Fri, 2 Jun 2023 13:09:48 -0700 Subject: [PATCH 2/4] add type annotations --- airflow/models/baseoperator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 5610ffacb128b..765fbe33a617e 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1587,10 +1587,10 @@ def defer( """ raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) - def execute_complete(self, context, event=None): + def execute_complete(self, context: Context, event: dict[str, Any]): """The default method for handling the event returned after the deferred operation completes.""" op_name = type(self).__name__ - if event is None or event["status"] != TriggerEvent.STATUS_SUCCESS: + if event["status"] != TriggerEvent.STATUS_SUCCESS: raise AirflowException(f"{op_name}'s deferred operation was not completed successfully: {event}") else: self.log.info("% completed successfully", op_name) From 43ac35165c217bac9890f0924fe3c7c918b556d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Fri, 2 Jun 2023 15:06:46 -0700 Subject: [PATCH 3/4] remove provider changes --- airflow/providers/amazon/aws/operators/glue.py | 7 +++++++ airflow/providers/amazon/aws/operators/glue_crawler.py | 7 +++++++ airflow/providers/amazon/aws/triggers/glue.py | 4 ++-- 
airflow/providers/amazon/aws/triggers/glue_crawler.py | 7 +++++-- 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/glue.py b/airflow/providers/amazon/aws/operators/glue.py index b9844519ed9bb..053e530c72674 100644 --- a/airflow/providers/amazon/aws/operators/glue.py +++ b/airflow/providers/amazon/aws/operators/glue.py @@ -21,6 +21,7 @@ import urllib.parse from typing import TYPE_CHECKING, Sequence +from airflow import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.glue import GlueJobHook from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -182,6 +183,7 @@ def execute(self, context: Context): verbose=self.verbose, aws_conn_id=self.aws_conn_id, ), + method_name="execute_complete", ) elif self.wait_for_completion: glue_job_run = glue_job.job_completion(self.job_name, glue_job_run["JobRunId"], self.verbose) @@ -194,3 +196,8 @@ def execute(self, context: Context): else: self.log.info("AWS Glue Job: %s. 
Run Id: %s", self.job_name, glue_job_run["JobRunId"]) return glue_job_run["JobRunId"] + + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error in glue job: {event}") + return diff --git a/airflow/providers/amazon/aws/operators/glue_crawler.py b/airflow/providers/amazon/aws/operators/glue_crawler.py index b208737d8e70f..c7ac25f1f2e30 100644 --- a/airflow/providers/amazon/aws/operators/glue_crawler.py +++ b/airflow/providers/amazon/aws/operators/glue_crawler.py @@ -20,6 +20,7 @@ from functools import cached_property from typing import TYPE_CHECKING, Sequence +from airflow import AirflowException from airflow.providers.amazon.aws.triggers.glue_crawler import GlueCrawlerCompleteTrigger if TYPE_CHECKING: @@ -95,9 +96,15 @@ def execute(self, context: Context): poll_interval=self.poll_interval, aws_conn_id=self.aws_conn_id, ), + method_name="execute_complete", ) elif self.wait_for_completion: self.log.info("Waiting for AWS Glue Crawler") self.hook.wait_for_crawler_completion(crawler_name=crawler_name, poll_interval=self.poll_interval) return crawler_name + + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error in glue crawl: {event}") + return diff --git a/airflow/providers/amazon/aws/triggers/glue.py b/airflow/providers/amazon/aws/triggers/glue.py index 6004bc1afde22..42219a993ad25 100644 --- a/airflow/providers/amazon/aws/triggers/glue.py +++ b/airflow/providers/amazon/aws/triggers/glue.py @@ -59,5 +59,5 @@ def serialize(self) -> tuple[str, dict[str, Any]]: async def run(self) -> AsyncIterator[TriggerEvent]: hook = GlueJobHook(aws_conn_id=self.aws_conn_id) - glue_job_run = await hook.async_job_completion(self.job_name, self.run_id, self.verbose) - yield TriggerEvent.success(glue_job_run["JobRunId"]) + await hook.async_job_completion(self.job_name, self.run_id, self.verbose) + yield TriggerEvent({"status": "success", "message": "Job done"}) diff 
--git a/airflow/providers/amazon/aws/triggers/glue_crawler.py b/airflow/providers/amazon/aws/triggers/glue_crawler.py index 02d686b5b963a..10ab45dda73d3 100644 --- a/airflow/providers/amazon/aws/triggers/glue_crawler.py +++ b/airflow/providers/amazon/aws/triggers/glue_crawler.py @@ -68,8 +68,11 @@ async def run(self) -> AsyncIterator[TriggerEvent]: break # we reach this point only if the waiter met a success criteria except WaiterError as error: if "terminal failure" in str(error): - raise + yield TriggerEvent( + {"status": "failure", "message": f"Glue Crawler creation Failed: {error}"} + ) + break self.log.info("Status of glue crawl is %s", error.last_response["Crawler"]["State"]) await asyncio.sleep(int(self.poll_interval)) - yield TriggerEvent.success(self.crawler_name) + yield TriggerEvent({"status": "success", "message": "Crawl Complete"}) From 6c5ec82fe7b865118574249e9862b1d64a7869d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Fri, 2 Jun 2023 15:13:38 -0700 Subject: [PATCH 4/4] use a different method name to prevent issues with typing/arguments --- airflow/models/baseoperator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 765fbe33a617e..b855d0e69c7ff 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1574,7 +1574,7 @@ def defer( self, *, trigger: BaseTrigger, - method_name: str = "execute_complete", + method_name: str = "execute_complete_default", kwargs: dict[str, Any] | None = None, timeout: timedelta | None = None, ): @@ -1587,7 +1587,7 @@ def defer( """ raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) - def execute_complete(self, context: Context, event: dict[str, Any]): + def execute_complete_default(self, context: Context, event: dict[str, Any]): """The default method for handling the event returned after the deferred operation completes.""" op_name = 
type(self).__name__ if event["status"] != TriggerEvent.STATUS_SUCCESS: