@@ -34,7 +34,7 @@
)
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.providers.common.compat.sdk import AirflowException, Stats, conf, timezone
from airflow.providers.common.compat.sdk import AirflowException, Stats, timezone
from airflow.utils.helpers import merge_dicts

if TYPE_CHECKING:
@@ -87,12 +87,7 @@ class AwsBatchExecutor(BaseExecutor):
Airflow TaskInstance's executor_config.
"""

# Maximum number of retries to submit a Batch Job.
MAX_SUBMIT_JOB_ATTEMPTS = conf.get(
CONFIG_GROUP_NAME,
AllBatchConfigKeys.MAX_SUBMIT_JOB_ATTEMPTS,
fallback=CONFIG_DEFAULTS[AllBatchConfigKeys.MAX_SUBMIT_JOB_ATTEMPTS],
)
supports_multi_team: bool = True

# AWS only allows a maximum number of JOBs in the describe_jobs function
DESCRIBE_JOBS_BATCH_SIZE = 99
@@ -106,11 +101,29 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.active_workers = BatchJobCollection()
self.pending_jobs: deque = deque()

# Check whether core Airflow has already set an ExecutorConf on the self.conf attribute; if not,
# fall back to the global configuration object. This keeps the change backwards compatible with
# older versions of Airflow.
# Can be removed once the provider's minimum supported Airflow version is the one that introduces
# multi-team configuration.
if not hasattr(self, "conf"):
from airflow.providers.common.compat.sdk import conf

self.conf = conf

self.attempts_since_last_successful_connection = 0
self.load_batch_connection(check_connection=False)
self.IS_BOTO_CONNECTION_HEALTHY = False
self.submit_job_kwargs = self._load_submit_kwargs()

# Maximum number of retries to submit a Batch job.
self.max_submit_job_attempts = self.conf.get(
CONFIG_GROUP_NAME,
AllBatchConfigKeys.MAX_SUBMIT_JOB_ATTEMPTS,
fallback=CONFIG_DEFAULTS[AllBatchConfigKeys.MAX_SUBMIT_JOB_ATTEMPTS],
)

def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
from airflow.executors import workloads

@@ -164,7 +177,7 @@ def check_health(self):

def start(self):
"""Call this when the Executor is run for the first time by the scheduler."""
check_health = conf.getboolean(
check_health = self.conf.getboolean(
CONFIG_GROUP_NAME, AllBatchConfigKeys.CHECK_HEALTH_ON_STARTUP, fallback=False
)

@@ -180,12 +193,12 @@ def start(self):

def load_batch_connection(self, check_connection: bool = True):
self.log.info("Loading Connection information")
aws_conn_id = conf.get(
aws_conn_id = self.conf.get(
CONFIG_GROUP_NAME,
AllBatchConfigKeys.AWS_CONN_ID,
fallback=CONFIG_DEFAULTS[AllBatchConfigKeys.AWS_CONN_ID],
)
region_name = conf.get(CONFIG_GROUP_NAME, AllBatchConfigKeys.REGION_NAME, fallback=None)
region_name = self.conf.get(CONFIG_GROUP_NAME, AllBatchConfigKeys.REGION_NAME, fallback=None)
self.batch = BatchClientHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
self.attempts_since_last_successful_connection += 1
self.last_connection_reload = timezone.utcnow()
@@ -255,13 +268,13 @@ def _handle_failed_job(self, job):
queue = job_info.queue
exec_info = job_info.config
failure_count = self.active_workers.failure_count_by_id(job_id=job.job_id)
if int(failure_count) < int(self.__class__.MAX_SUBMIT_JOB_ATTEMPTS):
if int(failure_count) < int(self.max_submit_job_attempts):
self.log.warning(
"Airflow task %s failed due to %s. Failure %s out of %s occurred on %s. Rescheduling.",
task_key,
job.status_reason,
failure_count,
self.__class__.MAX_SUBMIT_JOB_ATTEMPTS,
self.max_submit_job_attempts,
job.job_id,
)
self.active_workers.increment_failure_count(job_id=job.job_id)
@@ -320,7 +333,7 @@ def attempt_submit_jobs(self):
failure_reason = str(e)

if failure_reason:
if attempt_number >= int(self.__class__.MAX_SUBMIT_JOB_ATTEMPTS):
if attempt_number >= int(self.max_submit_job_attempts):
self.log.error(
(
"This job has been unsuccessfully attempted too many times (%s). "
@@ -459,11 +472,10 @@ def terminate(self):
# up and kill the scheduler process.
self.log.exception("Failed to terminate %s", self.__class__.__name__)

@staticmethod
def _load_submit_kwargs() -> dict:
def _load_submit_kwargs(self) -> dict:
from airflow.providers.amazon.aws.executors.batch.batch_executor_config import build_submit_kwargs

submit_kwargs = build_submit_kwargs()
submit_kwargs = build_submit_kwargs(self.conf)

if "containerOverrides" not in submit_kwargs or "command" not in submit_kwargs["containerOverrides"]:
raise KeyError(
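Taken together, the executor changes above route every configuration lookup through `self.conf`, which is either the team-specific config object provided by core Airflow or the global configuration as a fallback. A minimal behavioural sketch, assuming the Batch executor options (job queue, job definition, region) are already configured as in the test further below; the `team_name` constructor argument is taken from that test and requires Airflow 3.1+, not a released API:

```python
from airflow.providers.amazon.aws.executors.batch.batch_executor import AwsBatchExecutor

# Default executor: __init__ finds no pre-set `conf` attribute and falls back to
# the global Airflow configuration object.
executor = AwsBatchExecutor()
print(executor.max_submit_job_attempts)  # now an instance attribute, read from executor.conf
print(executor.submit_job_kwargs)        # built via build_submit_kwargs(executor.conf)

# Team-scoped executor (Airflow 3.1+ multi-team): core Airflow is expected to set a
# team-specific config object on the instance, which all subsequent lookups use.
team_executor = AwsBatchExecutor(team_name="team_a")
```

The hunks that follow are from the `batch_executor_config` module, which the executor imports in `_load_submit_kwargs` above.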
@@ -39,26 +39,25 @@
BatchSubmitJobKwargsConfigKeys,
)
from airflow.providers.amazon.aws.executors.ecs.utils import camelize_dict_keys
from airflow.providers.common.compat.sdk import conf
from airflow.utils.helpers import prune_dict


def _fetch_templated_kwargs() -> dict[str, str]:
def _fetch_templated_kwargs(conf) -> dict[str, str]:
submit_job_kwargs_value = conf.get(
CONFIG_GROUP_NAME, AllBatchConfigKeys.SUBMIT_JOB_KWARGS, fallback=dict()
)
return json.loads(str(submit_job_kwargs_value))


def _fetch_config_values() -> dict[str, str]:
def _fetch_config_values(conf) -> dict[str, str]:
return prune_dict(
{key: conf.get(CONFIG_GROUP_NAME, key, fallback=None) for key in BatchSubmitJobKwargsConfigKeys()}
)


def build_submit_kwargs() -> dict:
job_kwargs = _fetch_config_values()
job_kwargs.update(_fetch_templated_kwargs())
def build_submit_kwargs(conf) -> dict:
job_kwargs = _fetch_config_values(conf)
job_kwargs.update(_fetch_templated_kwargs(conf))

if "containerOverrides" not in job_kwargs:
job_kwargs["containerOverrides"] = {} # type: ignore
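With the global `conf` import removed from this module, `build_submit_kwargs` and its helpers take the configuration object as a parameter. A minimal usage sketch for a non-team deployment, mirroring the tests below; it assumes the relevant Batch executor options are already set, otherwise keys such as `jobQueue` will simply be absent:

```python
from airflow.configuration import conf  # global configuration, as used in the tests below
from airflow.providers.amazon.aws.executors.batch.batch_executor_config import build_submit_kwargs

# Build the submit_job kwargs from the global config; a team-scoped executor
# passes its own `self.conf` object instead.
submit_kwargs = build_submit_kwargs(conf)
print(submit_kwargs.get("jobQueue"), submit_kwargs.get("jobDefinition"))
```

The remaining hunks are from the executor's test module.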
@@ -21,12 +21,14 @@
import logging
import os
from unittest import mock
from unittest.mock import patch

import pytest
import yaml
from botocore.exceptions import ClientError, NoCredentialsError
from semver import VersionInfo

from airflow.configuration import conf
from airflow.executors.base_executor import BaseExecutor
from airflow.models import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
@@ -48,7 +50,7 @@

from tests_common import RUNNING_TESTS_AGAINST_AIRFLOW_PACKAGES
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS

airflow_version = VersionInfo(*map(int, airflow_version_str.split(".")[:3]))
ARN1 = "arn1"
@@ -444,7 +446,7 @@ def test_task_retry_on_api_failure(self, _, mock_executor, caplog):
mock_executor.sync_running_jobs()
for i in range(2):
assert (
f"Airflow task {airflow_keys[i]} failed due to {jobs[i]['statusReason']}. Failure 1 out of {mock_executor.MAX_SUBMIT_JOB_ATTEMPTS} occurred on {jobs[i]['jobId']}. Rescheduling."
f"Airflow task {airflow_keys[i]} failed due to {jobs[i]['statusReason']}. Failure 1 out of {mock_executor.max_submit_job_attempts} occurred on {jobs[i]['jobId']}. Rescheduling."
in caplog.messages[i]
)

@@ -453,7 +455,7 @@ def test_task_retry_on_api_failure(self, _, mock_executor, caplog):
mock_executor.sync_running_jobs()
for i in range(2):
assert (
f"Airflow task {airflow_keys[i]} failed due to {jobs[i]['statusReason']}. Failure 2 out of {mock_executor.MAX_SUBMIT_JOB_ATTEMPTS} occurred on {jobs[i]['jobId']}. Rescheduling."
f"Airflow task {airflow_keys[i]} failed due to {jobs[i]['statusReason']}. Failure 2 out of {mock_executor.max_submit_job_attempts} occurred on {jobs[i]['jobId']}. Rescheduling."
in caplog.messages[i]
)

@@ -462,7 +464,7 @@ def test_task_retry_on_api_failure(self, _, mock_executor, caplog):
mock_executor.sync_running_jobs()
for i in range(2):
assert (
f"Airflow task {airflow_keys[i]} has failed a maximum of {mock_executor.MAX_SUBMIT_JOB_ATTEMPTS} times. Marking as failed"
f"Airflow task {airflow_keys[i]} has failed a maximum of {mock_executor.max_submit_job_attempts} times. Marking as failed"
in caplog.text
)

@@ -708,6 +710,47 @@ def test_try_adopt_task_instances(self, mock_executor):
# The remaining one task is unable to be adopted.
assert len(not_adopted_tasks) == 1

@pytest.mark.skipif(not AIRFLOW_V_3_1_PLUS, reason="Multi-team support requires Airflow 3.1+")
def test_team_config(self):
"""Test that the executor uses team-specific configuration when provided via self.conf."""
# Team name to be used throughout
team_name = "team_a"
# Patch the environment to include two sets of configs for the Batch executor: one tied to a team
# and one that is not. Then create two executors (one with a team and one without) and verify
# that each uses the correct configs.
config_overrides = [
(f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllBatchConfigKeys.JOB_QUEUE}", "some-job-queue"),
(f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllBatchConfigKeys.JOB_DEFINITION}", "some-job-def"),
(f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllBatchConfigKeys.JOB_NAME}", "some-job-name"),
(f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllBatchConfigKeys.REGION_NAME}", "us-west-1"),
# Team config
(
f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllBatchConfigKeys.JOB_QUEUE}",
"team_a_job_queue",
),
(
f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllBatchConfigKeys.JOB_DEFINITION}",
"team_a_job_def",
),
(f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllBatchConfigKeys.JOB_NAME}", "team_a_job_name"),
(f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllBatchConfigKeys.REGION_NAME}", "us-west-2"),
]
with patch("os.environ", {key.upper(): value for key, value in config_overrides}):
team_executor = AwsBatchExecutor(team_name=team_name)
submit_kwargs = batch_executor_config.build_submit_kwargs(team_executor.conf)

assert submit_kwargs["jobQueue"] == "team_a_job_queue"
assert submit_kwargs["jobDefinition"] == "team_a_job_def"
assert submit_kwargs["jobName"] == "team_a_job_name"

# Now create an executor without a team and ensure the non-team configs are used.
non_team_executor = AwsBatchExecutor()
submit_kwargs = batch_executor_config.build_submit_kwargs(non_team_executor.conf)

assert submit_kwargs["jobQueue"] == "some-job-queue"
assert submit_kwargs["jobDefinition"] == "some-job-def"
assert submit_kwargs["jobName"] == "some-job-name"


class TestBatchExecutorConfig:
@staticmethod
@@ -751,7 +794,7 @@ def test_executor_config_exceptions(self, bad_config, mock_executor):
mock_executor.execute_async(mock_airflow_key, mock_cmd, executor_config=bad_config)

def test_config_defaults_are_applied(self):
submit_kwargs = batch_executor_config.build_submit_kwargs()
submit_kwargs = batch_executor_config.build_submit_kwargs(conf)
found_keys = {convert_camel_to_snake(key): key for key in submit_kwargs.keys()}

for expected_key, expected_value in CONFIG_DEFAULTS.items():
@@ -781,7 +824,7 @@ def test_verify_tags_are_used_as_provided(self):
f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllBatchConfigKeys.SUBMIT_JOB_KWARGS}".upper()
)
os.environ[run_submit_kwargs_env_key] = json.dumps(provided_run_submit_kwargs)
submit_kwargs = batch_executor_config.build_submit_kwargs()
submit_kwargs = batch_executor_config.build_submit_kwargs(conf)

# Verify that tag names are exempt from the camel-case conversion.
assert submit_kwargs["tags"] == templated_tags