diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py index 0d871779de4bb..2a43a2f502931 100644 --- a/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/airflow/providers/amazon/aws/hooks/base_aws.py @@ -836,7 +836,8 @@ def get_waiter( corresponding value. If a custom waiter has such keys to be expanded, they need to be provided here. :param deferrable: If True, the waiter is going to be an async custom waiter. - + An async client must be provided in that case. + :param client: The client to use for the waiter's operations """ from airflow.providers.amazon.aws.waiters.base_waiter import BaseBotoWaiter diff --git a/airflow/providers/amazon/aws/hooks/glue.py b/airflow/providers/amazon/aws/hooks/glue.py index 6f4ed8342e072..753d43dce6d95 100644 --- a/airflow/providers/amazon/aws/hooks/glue.py +++ b/airflow/providers/amazon/aws/hooks/glue.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +import asyncio import time import boto3 @@ -194,6 +195,12 @@ def get_job_state(self, job_name: str, run_id: str) -> str: job_run = self.conn.get_job_run(JobName=job_name, RunId=run_id, PredecessorsIncluded=True) return job_run["JobRun"]["JobRunState"] + async def async_get_job_state(self, job_name: str, run_id: str) -> str: + """The async version of get_job_state.""" + async with self.async_conn as client: + job_run = await client.get_job_run(JobName=job_name, RunId=run_id) + return job_run["JobRun"]["JobRunState"] + def print_job_logs( self, job_name: str, @@ -264,33 +271,68 @@ def job_completion(self, job_name: str, run_id: str, verbose: bool = False) -> d :param verbose: If True, more Glue Job Run logs show in the Airflow Task Logs. (default: False) :return: Dict of JobRunState and JobRunId """ - failed_states = ["FAILED", "TIMEOUT"] - finished_states = ["SUCCEEDED", "STOPPED"] next_log_tokens = self.LogContinuationTokens() while True: - if verbose: - self.print_job_logs( - job_name=job_name, - run_id=run_id, - continuation_tokens=next_log_tokens, - ) - job_run_state = self.get_job_state(job_name, run_id) - if job_run_state in finished_states: - self.log.info("Exiting Job %s Run State: %s", run_id, job_run_state) - return {"JobRunState": job_run_state, "JobRunId": run_id} - if job_run_state in failed_states: - job_error_message = f"Exiting Job {run_id} Run State: {job_run_state}" - self.log.info(job_error_message) - raise AirflowException(job_error_message) + ret = self._handle_state(job_run_state, job_name, run_id, verbose, next_log_tokens) + if ret: + return ret else: - self.log.info( - "Polling for AWS Glue Job %s current run state with status %s", - job_name, - job_run_state, - ) time.sleep(self.JOB_POLL_INTERVAL) + async def async_job_completion(self, job_name: str, run_id: str, verbose: bool = False) -> dict[str, str]: + """ + Waits until Glue job with job_name completes or fails and return final state if finished. + Raises AirflowException when the job failed. + + :param job_name: unique job name per AWS account + :param run_id: The job-run ID of the predecessor job run + :param verbose: If True, more Glue Job Run logs show in the Airflow Task Logs. 
(default: False) + :return: Dict of JobRunState and JobRunId + """ + next_log_tokens = self.LogContinuationTokens() + while True: + job_run_state = await self.async_get_job_state(job_name, run_id) + ret = self._handle_state(job_run_state, job_name, run_id, verbose, next_log_tokens) + if ret: + return ret + else: + await asyncio.sleep(self.JOB_POLL_INTERVAL) + + def _handle_state( + self, + state: str, + job_name: str, + run_id: str, + verbose: bool, + next_log_tokens: GlueJobHook.LogContinuationTokens, + ) -> dict | None: + """Helper function to process Glue Job state while polling. Used by both sync and async methods.""" + failed_states = ["FAILED", "TIMEOUT"] + finished_states = ["SUCCEEDED", "STOPPED"] + + if verbose: + self.print_job_logs( + job_name=job_name, + run_id=run_id, + continuation_tokens=next_log_tokens, + ) + + if state in finished_states: + self.log.info("Exiting Job %s Run State: %s", run_id, state) + return {"JobRunState": state, "JobRunId": run_id} + if state in failed_states: + job_error_message = f"Exiting Job {run_id} Run State: {state}" + self.log.info(job_error_message) + raise AirflowException(job_error_message) + else: + self.log.info( + "Polling for AWS Glue Job %s current run state with status %s", + job_name, + state, + ) + return None + def has_job(self, job_name) -> bool: """ Checks if the job already exists. diff --git a/airflow/providers/amazon/aws/hooks/glue_crawler.py b/airflow/providers/amazon/aws/hooks/glue_crawler.py index 0393cadc9757a..fcd16dadeb941 100644 --- a/airflow/providers/amazon/aws/hooks/glue_crawler.py +++ b/airflow/providers/amazon/aws/hooks/glue_crawler.py @@ -18,9 +18,7 @@ from __future__ import annotations from functools import cached_property -from time import sleep -from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.hooks.sts import StsHook @@ -179,41 +177,20 @@ def wait_for_crawler_completion(self, crawler_name: str, poll_interval: int = 5) :param poll_interval: Time (in seconds) to wait between two consecutive calls to check crawler status :return: Crawler's status """ - failed_status = ["FAILED", "CANCELLED"] - - while True: - crawler = self.get_crawler(crawler_name) - crawler_state = crawler["State"] - if crawler_state == "READY": - self.log.info("State: %s", crawler_state) - self.log.info("crawler_config: %s", crawler) - crawler_status = crawler["LastCrawl"]["Status"] - if crawler_status in failed_status: - raise AirflowException(f"Status: {crawler_status}") - metrics = self.glue_client.get_crawler_metrics(CrawlerNameList=[crawler_name])[ - "CrawlerMetricsList" - ][0] - self.log.info("Status: %s", crawler_status) - self.log.info("Last Runtime Duration (seconds): %s", metrics["LastRuntimeSeconds"]) - self.log.info("Median Runtime Duration (seconds): %s", metrics["MedianRuntimeSeconds"]) - self.log.info("Tables Created: %s", metrics["TablesCreated"]) - self.log.info("Tables Updated: %s", metrics["TablesUpdated"]) - self.log.info("Tables Deleted: %s", metrics["TablesDeleted"]) - - return crawler_status - - else: - self.log.info("Polling for AWS Glue crawler: %s ", crawler_name) - self.log.info("State: %s", crawler_state) - - metrics = self.glue_client.get_crawler_metrics(CrawlerNameList=[crawler_name])[ - "CrawlerMetricsList" - ][0] - time_left = int(metrics["TimeLeftSeconds"]) - - if time_left > 0: - self.log.info("Estimated Time Left (seconds): %s", time_left) - else: - self.log.info("Crawler should finish soon") - - 
sleep(poll_interval) + self.get_waiter("crawler_ready").wait(Name=crawler_name, WaiterConfig={"Delay": poll_interval}) + + # query one extra time to log some info + crawler = self.get_crawler(crawler_name) + self.log.info("crawler_config: %s", crawler) + crawler_status = crawler["LastCrawl"]["Status"] + + metrics_response = self.glue_client.get_crawler_metrics(CrawlerNameList=[crawler_name]) + metrics = metrics_response["CrawlerMetricsList"][0] + self.log.info("Status: %s", crawler_status) + self.log.info("Last Runtime Duration (seconds): %s", metrics["LastRuntimeSeconds"]) + self.log.info("Median Runtime Duration (seconds): %s", metrics["MedianRuntimeSeconds"]) + self.log.info("Tables Created: %s", metrics["TablesCreated"]) + self.log.info("Tables Updated: %s", metrics["TablesUpdated"]) + self.log.info("Tables Deleted: %s", metrics["TablesDeleted"]) + + return crawler_status diff --git a/airflow/providers/amazon/aws/operators/glue.py b/airflow/providers/amazon/aws/operators/glue.py index 497df84d31418..053e530c72674 100644 --- a/airflow/providers/amazon/aws/operators/glue.py +++ b/airflow/providers/amazon/aws/operators/glue.py @@ -21,10 +21,12 @@ import urllib.parse from typing import TYPE_CHECKING, Sequence +from airflow import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.glue import GlueJobHook from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.links.glue import GlueJobRunDetailsLink +from airflow.providers.amazon.aws.triggers.glue import GlueJobCompleteTrigger if TYPE_CHECKING: from airflow.utils.context import Context @@ -52,7 +54,10 @@ class GlueJobOperator(BaseOperator): :param iam_role_name: AWS IAM Role for Glue Job Execution :param create_job_kwargs: Extra arguments for Glue Job Creation :param run_job_kwargs: Extra arguments for Glue Job Run - :param wait_for_completion: Whether or not wait for job run completion. (default: True) + :param wait_for_completion: Whether to wait for job run completion. (default: True) + :param deferrable: If True, the operator will wait asynchronously for the job to complete. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) :param verbose: If True, Glue Job Run logs show in the Airflow Task Logs. (default: False) :param update_config: If True, Operator will update job configuration. (default: False) """ @@ -91,6 +96,7 @@ def __init__( create_job_kwargs: dict | None = None, run_job_kwargs: dict | None = None, wait_for_completion: bool = True, + deferrable: bool = False, verbose: bool = False, update_config: bool = False, **kwargs, @@ -114,6 +120,7 @@ def __init__( self.wait_for_completion = wait_for_completion self.verbose = verbose self.update_config = update_config + self.deferrable = deferrable def execute(self, context: Context): """ @@ -167,7 +174,18 @@ def execute(self, context: Context): job_run_id=glue_job_run["JobRunId"], ) self.log.info("You can monitor this Glue Job run at: %s", glue_job_run_url) - if self.wait_for_completion: + + if self.deferrable: + self.defer( + trigger=GlueJobCompleteTrigger( + job_name=self.job_name, + run_id=glue_job_run["JobRunId"], + verbose=self.verbose, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + ) + elif self.wait_for_completion: glue_job_run = glue_job.job_completion(self.job_name, glue_job_run["JobRunId"], self.verbose) self.log.info( "AWS Glue Job: %s status: %s. 
Run Id: %s", @@ -178,3 +196,8 @@ def execute(self, context: Context): else: self.log.info("AWS Glue Job: %s. Run Id: %s", self.job_name, glue_job_run["JobRunId"]) return glue_job_run["JobRunId"] + + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error in glue job: {event}") + return diff --git a/airflow/providers/amazon/aws/operators/glue_crawler.py b/airflow/providers/amazon/aws/operators/glue_crawler.py index 426ca2f084d04..c7ac25f1f2e30 100644 --- a/airflow/providers/amazon/aws/operators/glue_crawler.py +++ b/airflow/providers/amazon/aws/operators/glue_crawler.py @@ -20,6 +20,9 @@ from functools import cached_property from typing import TYPE_CHECKING, Sequence +from airflow import AirflowException +from airflow.providers.amazon.aws.triggers.glue_crawler import GlueCrawlerCompleteTrigger + if TYPE_CHECKING: from airflow.utils.context import Context @@ -40,7 +43,10 @@ class GlueCrawlerOperator(BaseOperator): :param config: Configurations for the AWS Glue crawler :param aws_conn_id: aws connection to use :param poll_interval: Time (in seconds) to wait between two consecutive calls to check crawler status - :param wait_for_completion: Whether or not wait for crawl execution completion. (default: True) + :param wait_for_completion: Whether to wait for crawl execution completion. (default: True) + :param deferrable: If True, the operator will wait asynchronously for the crawl to complete. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) """ template_fields: Sequence[str] = ("config",) @@ -53,18 +59,20 @@ def __init__( region_name: str | None = None, poll_interval: int = 5, wait_for_completion: bool = True, + deferrable: bool = False, **kwargs, ): super().__init__(**kwargs) self.aws_conn_id = aws_conn_id self.poll_interval = poll_interval self.wait_for_completion = wait_for_completion + self.deferrable = deferrable self.region_name = region_name self.config = config @cached_property def hook(self) -> GlueCrawlerHook: - """Create and return an GlueCrawlerHook.""" + """Create and return a GlueCrawlerHook.""" return GlueCrawlerHook(self.aws_conn_id, region_name=self.region_name) def execute(self, context: Context): @@ -81,8 +89,22 @@ def execute(self, context: Context): self.log.info("Triggering AWS Glue Crawler") self.hook.start_crawler(crawler_name) - if self.wait_for_completion: + if self.deferrable: + self.defer( + trigger=GlueCrawlerCompleteTrigger( + crawler_name=crawler_name, + poll_interval=self.poll_interval, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + ) + elif self.wait_for_completion: self.log.info("Waiting for AWS Glue Crawler") self.hook.wait_for_crawler_completion(crawler_name=crawler_name, poll_interval=self.poll_interval) return crawler_name + + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error in glue crawl: {event}") + return diff --git a/airflow/providers/amazon/aws/triggers/glue.py b/airflow/providers/amazon/aws/triggers/glue.py new file mode 100644 index 0000000000000..42219a993ad25 --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/glue.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
diff --git a/airflow/providers/amazon/aws/triggers/glue.py b/airflow/providers/amazon/aws/triggers/glue.py
new file mode 100644
index 0000000000000..42219a993ad25
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/glue.py
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import Any, AsyncIterator
+
+from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class GlueJobCompleteTrigger(BaseTrigger):
+    """
+    Watches for a Glue job, triggering when it finishes.
+
+    :param job_name: glue job name
+    :param run_id: the ID of the specific run to watch for that job
+    :param verbose: whether to print the job's logs in airflow logs or not
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        job_name: str,
+        run_id: str,
+        verbose: bool,
+        aws_conn_id: str,
+    ):
+        super().__init__()
+        self.job_name = job_name
+        self.run_id = run_id
+        self.verbose = verbose
+        self.aws_conn_id = aws_conn_id
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            # dynamically generate the fully qualified name of the class
+            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            {
+                "job_name": self.job_name,
+                "run_id": self.run_id,
+                "verbose": self.verbose,
+                "aws_conn_id": self.aws_conn_id,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        hook = GlueJobHook(aws_conn_id=self.aws_conn_id)
+        await hook.async_job_completion(self.job_name, self.run_id, self.verbose)
+        yield TriggerEvent({"status": "success", "message": "Job done"})
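
[review note] serialize() must return JSON-serializable kwargs that feed straight back into
__init__ when the triggerer rehydrates the trigger; that is why verbose is kept as a bool
above (str(self.verbose) would come back as the truthy string "False"). The round trip,
roughly (names taken from this PR):

    from airflow.providers.amazon.aws.triggers.glue import GlueJobCompleteTrigger

    trigger = GlueJobCompleteTrigger(
        job_name="my_job", run_id="jr_123", verbose=False, aws_conn_id="aws_default"
    )
    classpath, kwargs = trigger.serialize()
    assert classpath == "airflow.providers.amazon.aws.triggers.glue.GlueJobCompleteTrigger"
    # After importing classpath, the triggerer effectively does:
    rebuilt = GlueJobCompleteTrigger(**kwargs)  # only works if kwargs mirror __init__
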
diff --git a/airflow/providers/amazon/aws/triggers/glue_crawler.py b/airflow/providers/amazon/aws/triggers/glue_crawler.py
new file mode 100644
index 0000000000000..10ab45dda73d3
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/glue_crawler.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from functools import cached_property
+from typing import AsyncIterator
+
+from botocore.exceptions import WaiterError
+
+from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class GlueCrawlerCompleteTrigger(BaseTrigger):
+    """
+    Watches for a Glue crawl, triggering when it finishes.
+
+    :param crawler_name: name of the crawler to watch
+    :param poll_interval: The amount of time in seconds to wait between attempts.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(self, crawler_name: str, poll_interval: int, aws_conn_id: str):
+        super().__init__()
+        self.crawler_name = crawler_name
+        self.poll_interval = poll_interval
+        self.aws_conn_id = aws_conn_id
+
+    def serialize(self) -> tuple[str, dict]:
+        return (
+            # dynamically generate the fully qualified name of the class
+            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            {
+                "crawler_name": self.crawler_name,
+                "poll_interval": self.poll_interval,
+                "aws_conn_id": self.aws_conn_id,
+            },
+        )
+
+    @cached_property
+    def hook(self) -> GlueCrawlerHook:
+        return GlueCrawlerHook(aws_conn_id=self.aws_conn_id)
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        async with self.hook.async_conn as client:
+            waiter = self.hook.get_waiter("crawler_ready", deferrable=True, client=client)
+            while True:
+                try:
+                    await waiter.wait(
+                        Name=self.crawler_name,
+                        WaiterConfig={"Delay": self.poll_interval, "MaxAttempts": 1},
+                    )
+                    break  # we reach this point only if the waiter met a success criteria
+                except WaiterError as error:
+                    if "terminal failure" in str(error):
+                        yield TriggerEvent(
+                            {"status": "failure", "message": f"Glue crawl failed: {error}"}
+                        )
+                        return  # do not fall through to the success event below
+                    self.log.info("Status of glue crawl is %s", error.last_response["Crawler"]["State"])
+                    await asyncio.sleep(int(self.poll_interval))
+
+        yield TriggerEvent({"status": "success", "message": "Crawl Complete"})
diff --git a/airflow/providers/amazon/aws/waiters/glue.json b/airflow/providers/amazon/aws/waiters/glue.json
new file mode 100644
index 0000000000000..a8dd29572cde7
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/glue.json
@@ -0,0 +1,30 @@
+{
+    "version": 2,
+    "waiters": {
+        "crawler_ready": {
+            "operation": "GetCrawler",
+            "delay": 5,
+            "maxAttempts": 1000,
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "Crawler.State == 'READY' && Crawler.LastCrawl.Status == 'FAILED'",
+                    "expected": true,
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "Crawler.State == 'READY' && Crawler.LastCrawl.Status == 'CANCELLED'",
+                    "expected": true,
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "Crawler.State",
+                    "expected": "READY",
+                    "state": "success"
+                }
+            ]
+        }
+    }
+}
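
[review note] botocore evaluates a waiter's acceptors in order against each GetCrawler
response, so the two failure matchers must stay ahead of the bare READY success matcher;
otherwise a failed crawl would satisfy the success acceptor first. The sync path uses the
waiter like this (mirroring wait_for_crawler_completion above):

    from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook

    hook = GlueCrawlerHook(aws_conn_id="aws_default")
    # Blocks until an acceptor matches; a failure acceptor (or exhausted
    # maxAttempts) surfaces as botocore.exceptions.WaiterError.
    hook.get_waiter("crawler_ready").wait(Name="my_crawler", WaiterConfig={"Delay": 5})
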
diff --git a/tests/providers/amazon/aws/hooks/test_glue.py b/tests/providers/amazon/aws/hooks/test_glue.py
index 9a46abb34fdb9..ca07c7bad9f23 100644
--- a/tests/providers/amazon/aws/hooks/test_glue.py
+++ b/tests/providers/amazon/aws/hooks/test_glue.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import json
+import sys
 from unittest import mock
 from unittest.mock import MagicMock
 
@@ -26,9 +27,15 @@
 from botocore.exceptions import ClientError
 from moto import mock_glue, mock_iam
 
+from airflow import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
 
+if sys.version_info < (3, 8):
+    from asynctest import mock as async_mock
+else:
+    from unittest import mock as async_mock
+
 
 class TestGlueJobHook:
     def setup_method(self):
@@ -349,3 +356,65 @@ def test_print_job_logs_no_stream_yet(self, conn_mock: MagicMock, client_mock: MagicMock):
         assert tokens.output_stream_continuation is None
         assert tokens.error_stream_continuation is None
         assert client_mock().get_paginator().paginate.call_count == 2
+
+    @mock.patch.object(GlueJobHook, "get_job_state")
+    def test_job_completion_success(self, get_state_mock: MagicMock):
+        hook = GlueJobHook()
+        hook.JOB_POLL_INTERVAL = 0
+        get_state_mock.side_effect = [
+            "RUNNING",
+            "RUNNING",
+            "SUCCEEDED",
+        ]
+
+        hook.job_completion("job_name", "run_id")
+
+        assert get_state_mock.call_count == 3
+        get_state_mock.assert_called_with("job_name", "run_id")
+
+    @mock.patch.object(GlueJobHook, "get_job_state")
+    def test_job_completion_failure(self, get_state_mock: MagicMock):
+        hook = GlueJobHook()
+        hook.JOB_POLL_INTERVAL = 0
+        get_state_mock.side_effect = [
+            "RUNNING",
+            "RUNNING",
+            "FAILED",
+        ]
+
+        with pytest.raises(AirflowException):
+            hook.job_completion("job_name", "run_id")
+
+        assert get_state_mock.call_count == 3
+
+    @pytest.mark.asyncio
+    @async_mock.patch.object(GlueJobHook, "async_get_job_state")
+    async def test_async_job_completion_success(self, get_state_mock: MagicMock):
+        hook = GlueJobHook()
+        hook.JOB_POLL_INTERVAL = 0
+        get_state_mock.side_effect = [
+            "RUNNING",
+            "RUNNING",
+            "SUCCEEDED",
+        ]
+
+        await hook.async_job_completion("job_name", "run_id")
+
+        assert get_state_mock.call_count == 3
+        get_state_mock.assert_called_with("job_name", "run_id")
+
+    @pytest.mark.asyncio
+    @async_mock.patch.object(GlueJobHook, "async_get_job_state")
+    async def test_async_job_completion_failure(self, get_state_mock: MagicMock):
+        hook = GlueJobHook()
+        hook.JOB_POLL_INTERVAL = 0
+        get_state_mock.side_effect = [
+            "RUNNING",
+            "RUNNING",
+            "FAILED",
+        ]
+
+        with pytest.raises(AirflowException):
+            await hook.async_job_completion("job_name", "run_id")
+
+        assert get_state_mock.call_count == 3
diff --git a/tests/providers/amazon/aws/hooks/test_glue_crawler.py b/tests/providers/amazon/aws/hooks/test_glue_crawler.py
index 1f1961e0298a9..6f34b789ff724 100644
--- a/tests/providers/amazon/aws/hooks/test_glue_crawler.py
+++ b/tests/providers/amazon/aws/hooks/test_glue_crawler.py
@@ -19,6 +19,7 @@
 from copy import deepcopy
 from unittest import mock
+from unittest.mock import MagicMock
 
 from moto import mock_sts
 from moto.core import DEFAULT_ACCOUNT_ID
 
@@ -198,10 +199,11 @@ def test_start_crawler(self, mock_get_conn):
 
     @mock.patch.object(GlueCrawlerHook, "get_crawler")
     @mock.patch.object(GlueCrawlerHook, "get_conn")
-    def test_wait_for_crawler_completion_instant_ready(self, mock_get_conn, mock_get_crawler):
-        mock_get_crawler.side_effect = [
-            {"State": "READY", "LastCrawl": {"Status": "MOCK_STATUS"}},
-        ]
+    @mock.patch.object(GlueCrawlerHook, "get_waiter")
+    def test_wait_for_crawler_completion_instant_ready(
+        self, _, mock_get_conn: MagicMock, mock_get_crawler: MagicMock
+    ):
+        mock_get_crawler.return_value = {"State": "READY", "LastCrawl": {"Status": "MOCK_STATUS"}}
         mock_get_conn.return_value.get_crawler_metrics.return_value = {
             "CrawlerMetricsList": [
                 {
@@ -220,44 +222,4 @@ def test_wait_for_crawler_completion_instant_ready(self, mock_get_conn, mock_get_crawler):
                 mock.call().get_crawler_metrics(CrawlerNameList=[mock_crawler_name]),
             ]
         )
-        mock_get_crawler.assert_has_calls(
-            [
-                mock.call(mock_crawler_name),
-            ]
-        )
-
-    @mock.patch.object(GlueCrawlerHook, "get_conn")
-    @mock.patch.object(GlueCrawlerHook, "get_crawler")
-    @mock.patch("airflow.providers.amazon.aws.hooks.glue_crawler.sleep")
-    def test_wait_for_crawler_completion_retry_two_times(self, mock_sleep, mock_get_crawler, mock_get_conn):
-        mock_get_crawler.side_effect = [
-            {"State": "RUNNING"},
-            {"State": "READY", "LastCrawl": {"Status": "MOCK_STATUS"}},
-        ]
-        mock_get_conn.return_value.get_crawler_metrics.side_effect = [
-            {"CrawlerMetricsList": [{"TimeLeftSeconds": 12}]},
-            {
-                "CrawlerMetricsList": [
-                    {
-                        "LastRuntimeSeconds": "TEST-A",
"MedianRuntimeSeconds": "TEST-B", - "TablesCreated": "TEST-C", - "TablesUpdated": "TEST-D", - "TablesDeleted": "TEST-E", - } - ] - }, - ] - assert self.hook.wait_for_crawler_completion(mock_crawler_name) == "MOCK_STATUS" - mock_get_conn.assert_has_calls( - [ - mock.call(), - mock.call().get_crawler_metrics(CrawlerNameList=[mock_crawler_name]), - ] - ) - mock_get_crawler.assert_has_calls( - [ - mock.call(mock_crawler_name), - mock.call(mock_crawler_name), - ] - ) + mock_get_crawler.assert_called_once_with(mock_crawler_name) diff --git a/tests/providers/amazon/aws/operators/test_glue.py b/tests/providers/amazon/aws/operators/test_glue.py index db5ff1e6c232f..03b5e154f47e4 100644 --- a/tests/providers/amazon/aws/operators/test_glue.py +++ b/tests/providers/amazon/aws/operators/test_glue.py @@ -21,6 +21,7 @@ import pytest from airflow.configuration import conf +from airflow.exceptions import TaskDeferred from airflow.models import TaskInstance from airflow.providers.amazon.aws.hooks.glue import GlueJobHook from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -98,6 +99,27 @@ def test_execute_without_failure( mock_print_job_logs.assert_not_called() assert glue.job_name == JOB_NAME + @mock.patch.object(GlueJobHook, "initialize_job") + @mock.patch.object(GlueJobHook, "get_conn") + def test_execute_deferrable(self, _, mock_initialize_job): + glue = GlueJobOperator( + task_id=TASK_ID, + job_name=JOB_NAME, + script_location="s3://folder/file", + aws_conn_id="aws_default", + region_name="us-west-2", + s3_bucket="some_bucket", + iam_role_name="my_test_role", + deferrable=True, + ) + mock_initialize_job.return_value = {"JobRunState": "RUNNING", "JobRunId": JOB_RUN_ID} + + with pytest.raises(TaskDeferred) as defer: + glue.execute(mock.MagicMock()) + + assert defer.value.trigger.job_name == JOB_NAME + assert defer.value.trigger.run_id == JOB_RUN_ID + @mock.patch.object(GlueJobHook, "print_job_logs") @mock.patch.object(GlueJobHook, "get_job_state") @mock.patch.object(GlueJobHook, "initialize_job") diff --git a/tests/providers/amazon/aws/triggers/test_glue.py b/tests/providers/amazon/aws/triggers/test_glue.py new file mode 100644 index 0000000000000..c371052b22497 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_glue.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+from __future__ import annotations
+
+import sys
+from unittest.mock import MagicMock
+
+import pytest
+
+from airflow import AirflowException
+from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
+from airflow.providers.amazon.aws.triggers.glue import GlueJobCompleteTrigger
+
+if sys.version_info < (3, 8):
+    from asynctest import mock as async_mock
+else:
+    from unittest import mock as async_mock
+
+
+class TestGlueJobTrigger:
+    @pytest.mark.asyncio
+    @async_mock.patch.object(GlueJobHook, "async_get_job_state")
+    async def test_wait_job(self, get_state_mock: MagicMock):
+        GlueJobHook.JOB_POLL_INTERVAL = 0.1
+        trigger = GlueJobCompleteTrigger(
+            job_name="job_name",
+            run_id="JobRunId",
+            verbose=False,
+            aws_conn_id="aws_conn_id",
+        )
+        get_state_mock.side_effect = [
+            "RUNNING",
+            "RUNNING",
+            "SUCCEEDED",
+        ]
+
+        generator = trigger.run()
+        event = await generator.asend(None)
+
+        assert get_state_mock.call_count == 3
+        assert event.payload["status"] == "success"
+
+    @pytest.mark.asyncio
+    @async_mock.patch.object(GlueJobHook, "async_get_job_state")
+    async def test_wait_job_failed(self, get_state_mock: MagicMock):
+        GlueJobHook.JOB_POLL_INTERVAL = 0.1
+        trigger = GlueJobCompleteTrigger(
+            job_name="job_name",
+            run_id="JobRunId",
+            verbose=False,
+            aws_conn_id="aws_conn_id",
+        )
+        get_state_mock.side_effect = [
+            "RUNNING",
+            "RUNNING",
+            "FAILED",
+        ]
+
+        with pytest.raises(AirflowException):
+            await trigger.run().asend(None)
+
+        assert get_state_mock.call_count == 3
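
[review note] a minimal DAG snippet for exercising the new deferrable path end to end (job,
script, bucket, and role names are placeholders, not part of this PR; requires aiobotocore
and a running triggerer):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.glue import GlueJobOperator

    with DAG("example_glue_deferrable", start_date=datetime(2023, 1, 1), schedule=None) as dag:
        # Frees the worker slot while the triggerer awaits job completion.
        GlueJobOperator(
            task_id="submit_glue_job",
            job_name="my_glue_job",  # placeholder
            script_location="s3://my-bucket/etl.py",  # placeholder
            s3_bucket="my-bucket",  # placeholder
            iam_role_name="my-glue-role",  # placeholder
            deferrable=True,
        )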