Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 38 additions & 7 deletions providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@

from __future__ import annotations

import warnings
from typing import Literal

import requests
from botocore.exceptions import ClientError

Expand Down Expand Up @@ -55,6 +58,7 @@ def invoke_rest_api(
body: dict | None = None,
query_params: dict | None = None,
generate_local_token: bool = False,
airflow_version: Literal[2, 3] | None = None,
) -> dict:
"""
Invoke the REST API on the Airflow webserver with the specified inputs.
Expand All @@ -70,6 +74,8 @@ def invoke_rest_api(
:param generate_local_token: If True, only the local web token method is used without trying boto's
`invoke_rest_api` first. If False, the local web token method is used as a fallback after trying
boto's `invoke_rest_api`
:param airflow_version: The Airflow major version the MWAA environment runs.
This parameter is only used if the local web token method is used to call Airflow API.
"""
# Filter out keys with None values because Airflow REST API doesn't accept requests otherwise
body = {k: v for k, v in body.items() if v is not None} if body else {}
Expand All @@ -83,7 +89,7 @@ def invoke_rest_api(
}

if generate_local_token:
return self._invoke_rest_api_using_local_session_token(**api_kwargs)
return self._invoke_rest_api_using_local_session_token(airflow_version, **api_kwargs)

try:
response = self.conn.invoke_rest_api(**api_kwargs)
Expand All @@ -100,7 +106,7 @@ def invoke_rest_api(
self.log.info(
"Access Denied due to missing airflow:InvokeRestApi in IAM policy. Trying again by generating local token..."
)
return self._invoke_rest_api_using_local_session_token(**api_kwargs)
return self._invoke_rest_api_using_local_session_token(airflow_version, **api_kwargs)
to_log = e.response
# ResponseMetadata is removed because it contains data that is either very unlikely to be
# useful in XComs and logs, or redundant given the data already included in the response
Expand All @@ -110,14 +116,35 @@ def invoke_rest_api(

def _invoke_rest_api_using_local_session_token(
self,
airflow_version: Literal[2, 3] | None = None,
**api_kwargs,
) -> dict:
if not airflow_version:
warnings.warn(
"The parameter ``airflow_version`` in ``MwaaHook.invoke_rest_api`` is not "
"specified and the local web token method is being used. "
"The default Airflow version being used is 2 but this value will change in the future. "
"To avoid any unexpected behavior, please explicitly specify the Airflow version.",
FutureWarning,
stacklevel=3,
)
airflow_version = 2

try:
session, hostname = self._get_session_conn(api_kwargs["Name"])
session, hostname, login_response = self._get_session_conn(api_kwargs["Name"], airflow_version)

headers = {}
if airflow_version == 3:
headers = {
"Authorization": f"Bearer {login_response.cookies['_token']}",
"Content-Type": "application/json",
}

api_version = "v1" if airflow_version == 2 else "v2"
response = session.request(
method=api_kwargs["Method"],
url=f"https://{hostname}/api/v1{api_kwargs['Path']}",
url=f"https://{hostname}/api/{api_version}{api_kwargs['Path']}",
headers=headers,
params=api_kwargs["QueryParameters"],
json=api_kwargs["Body"],
timeout=10,
Expand All @@ -134,15 +161,19 @@ def _invoke_rest_api_using_local_session_token(
}

# Based on: https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#create-web-server-session-token
def _get_session_conn(self, env_name: str) -> tuple:
def _get_session_conn(self, env_name: str, airflow_version: Literal[2, 3]) -> tuple:
create_token_response = self.conn.create_web_login_token(Name=env_name)
web_server_hostname = create_token_response["WebServerHostname"]
web_token = create_token_response["WebToken"]

login_url = f"https://{web_server_hostname}/aws_mwaa/login"
login_url = (
f"https://{web_server_hostname}/aws_mwaa/login"
if airflow_version == 2
else f"https://{web_server_hostname}/pluginsv2/aws_mwaa/login"
)
login_payload = {"token": web_token}
session = requests.Session()
login_response = session.post(login_url, data=login_payload, timeout=10)
login_response.raise_for_status()

return session, web_server_hostname
return session, web_server_hostname, login_response
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

from airflow.configuration import conf
from airflow.exceptions import AirflowException
Expand All @@ -46,12 +46,14 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]):
:param trigger_run_id: The Run ID. This together with trigger_dag_id are a unique key. (templated)
:param logical_date: The logical date (previously called execution date). This is the time or interval
covered by this DAG run, according to the DAG definition. This together with trigger_dag_id are a
unique key. (templated)
unique key. This field is required if your environment is running with Airflow 3. (templated)
:param data_interval_start: The beginning of the interval the DAG run covers
:param data_interval_end: The end of the interval the DAG run covers
:param conf: Additional configuration parameters. The value of this field can be set only when creating
the object. (templated)
:param note: Contains manually entered notes by the user about the DagRun. (templated)
:param airflow_version: The Airflow major version the MWAA environment runs.
This parameter is only used if the local web token method is used to call Airflow API. (templated)

:param wait_for_completion: Whether to wait for DAG run to stop. (default: False)
:param waiter_delay: Time in seconds to wait between status checks. (default: 120)
Expand Down Expand Up @@ -81,6 +83,7 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]):
"data_interval_end",
"conf",
"note",
"airflow_version",
)
template_fields_renderers = {"conf": "json"}

Expand All @@ -95,6 +98,7 @@ def __init__(
data_interval_end: str | None = None,
conf: dict | None = None,
note: str | None = None,
airflow_version: Literal[2, 3] | None = None,
wait_for_completion: bool = False,
waiter_delay: int = 60,
waiter_max_attempts: int = 20,
Expand All @@ -110,6 +114,7 @@ def __init__(
self.data_interval_end = data_interval_end
self.conf = conf if conf else {}
self.note = note
self.airflow_version = airflow_version
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
Expand All @@ -123,7 +128,10 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
dag_run_id = validated_event["dag_run_id"]
self.log.info("DAG run %s of DAG %s completed", dag_run_id, self.trigger_dag_id)
return self.hook.invoke_rest_api(
env_name=self.env_name, path=f"/dags/{self.trigger_dag_id}/dagRuns/{dag_run_id}", method="GET"
env_name=self.env_name,
path=f"/dags/{self.trigger_dag_id}/dagRuns/{dag_run_id}",
method="GET",
airflow_version=self.airflow_version,
)

def execute(self, context: Context) -> dict:
Expand All @@ -146,6 +154,7 @@ def execute(self, context: Context) -> dict:
"conf": self.conf,
"note": self.note,
},
airflow_version=self.airflow_version,
)

dag_run_id = response["RestApiResponse"]["dag_run_id"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

from collections.abc import Collection, Sequence
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

from airflow.configuration import conf
from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -51,6 +51,8 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
``{airflow.utils.state.DagRunState.SUCCESS}`` (templated)
:param failure_states: Collection of DAG Run states that would make this task marked as failed and raise an
AirflowException, default is ``{airflow.utils.state.DagRunState.FAILED}`` (templated)
:param airflow_version: The Airflow major version the MWAA environment runs.
This parameter is only used if the local web token method is used to call Airflow API. (templated)
:param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
module to be installed.
(default: False, but can be overridden in config file by setting default_deferrable to True)
Expand All @@ -75,6 +77,7 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
"external_dag_run_id",
"success_states",
"failure_states",
"airflow_version",
"deferrable",
"max_retries",
"poke_interval",
Expand All @@ -88,6 +91,7 @@ def __init__(
external_dag_run_id: str,
success_states: Collection[str] | None = None,
failure_states: Collection[str] | None = None,
airflow_version: Literal[2, 3] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
poke_interval: int = 60,
max_retries: int = 720,
Expand All @@ -104,6 +108,7 @@ def __init__(
self.external_env_name = external_env_name
self.external_dag_id = external_dag_id
self.external_dag_run_id = external_dag_run_id
self.airflow_version = airflow_version
self.deferrable = deferrable
self.poke_interval = poke_interval
self.max_retries = max_retries
Expand All @@ -119,6 +124,7 @@ def poke(self, context: Context) -> bool:
env_name=self.external_env_name,
path=f"/dags/{self.external_dag_id}/dagRuns/{self.external_dag_run_id}",
method="GET",
airflow_version=self.airflow_version,
)

# If RestApiStatusCode == 200, the RestApiResponse must have the "state" key, otherwise something terrible has
Expand Down Expand Up @@ -179,6 +185,8 @@ class MwaaTaskSensor(AwsBaseSensor[MwaaHook]):
``{airflow.utils.state.TaskInstanceState.SUCCESS}`` (templated)
:param failure_states: Collection of task instance states that would make this task marked as failed and raise an
AirflowException, default is ``{airflow.utils.state.TaskInstanceState.FAILED}`` (templated)
:param airflow_version: The Airflow major version the MWAA environment runs.
This parameter is only used if the local web token method is used to call Airflow API. (templated)
:param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
module to be installed.
(default: False, but can be overridden in config file by setting default_deferrable to True)
Expand All @@ -204,6 +212,7 @@ class MwaaTaskSensor(AwsBaseSensor[MwaaHook]):
"external_task_id",
"success_states",
"failure_states",
"airflow_version",
"deferrable",
"max_retries",
"poke_interval",
Expand All @@ -218,6 +227,7 @@ def __init__(
external_task_id: str,
success_states: Collection[str] | None = None,
failure_states: Collection[str] | None = None,
airflow_version: Literal[2, 3] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
poke_interval: int = 60,
max_retries: int = 720,
Expand All @@ -235,6 +245,7 @@ def __init__(
self.external_dag_id = external_dag_id
self.external_dag_run_id = external_dag_run_id
self.external_task_id = external_task_id
self.airflow_version = airflow_version
self.deferrable = deferrable
self.poke_interval = poke_interval
self.max_retries = max_retries
Expand All @@ -252,6 +263,7 @@ def poke(self, context: Context) -> bool:
env_name=self.external_env_name,
path=f"/dags/{self.external_dag_id}/dagRuns/{self.external_dag_run_id}/taskInstances/{self.external_task_id}",
method="GET",
airflow_version=self.airflow_version,
)
# If RestApiStatusCode == 200, the RestApiResponse must have the "state" key, otherwise something terrible has
# happened in the API and KeyError would be raised
Expand All @@ -278,6 +290,7 @@ def execute(self, context: Context):
env_name=self.external_env_name,
path=f"/dags/{self.external_dag_id}/dagRuns",
method="GET",
airflow_version=self.airflow_version,
)
self.external_dag_run_id = response["RestApiResponse"]["dag_runs"][-1]["dag_run_id"]

Expand All @@ -290,6 +303,7 @@ def execute(self, context: Context):
external_task_id=self.external_task_id,
success_states=self.success_states,
failure_states=self.failure_states,
airflow_version=self.airflow_version,
waiter_delay=int(self.poke_interval),
waiter_max_attempts=self.max_retries,
aws_conn_id=self.aws_conn_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

from collections.abc import Collection
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal

from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
Expand Down Expand Up @@ -121,6 +121,8 @@ class MwaaTaskCompletedTrigger(AwsBaseWaiterTrigger):
``{airflow.utils.state.TaskInstanceState.SUCCESS}`` (templated)
:param failure_states: Collection of task instance states that would make this task marked as failed and raise an
AirflowException, default is ``{airflow.utils.state.TaskInstanceState.FAILED}`` (templated)
:param airflow_version: The Airflow major version the MWAA environment runs.
This parameter is only used if the local web token method is used to call Airflow API. (templated)
:param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60)
:param waiter_max_attempts: The maximum number of attempts to be made. (default: 720)
:param aws_conn_id: The Airflow connection used for AWS credentials.
Expand All @@ -135,6 +137,7 @@ def __init__(
external_task_id: str,
success_states: Collection[str] | None = None,
failure_states: Collection[str] | None = None,
airflow_version: Literal[2, 3] | None = None,
waiter_delay: int = 60,
waiter_max_attempts: int = 720,
**kwargs,
Expand Down Expand Up @@ -165,6 +168,7 @@ def __init__(
"Name": external_env_name,
"Path": f"/dags/{external_dag_id}/dagRuns/{external_dag_run_id}/taskInstances/{external_task_id}",
"Method": "GET",
"airflow_version": airflow_version,
},
failure_message=f"The task {external_task_id} of DAG run {external_dag_run_id} of DAG {external_dag_id} in MWAA environment {external_env_name} failed with state",
status_message="State of DAG run",
Expand Down
10 changes: 8 additions & 2 deletions providers/amazon/tests/system/amazon/aws/example_mwaa.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import DAG, chain, task
from airflow.sdk import DAG, chain, task, timezone
else:
# Airflow 2 path
from airflow.decorators import task # type: ignore[attr-defined,no-redef]
Expand Down Expand Up @@ -95,7 +95,9 @@ def test_iam_fallback(role_to_assume_arn, mwaa_env_name):

mwaa_hook = MwaaHook()
mwaa_hook.conn = session.client("mwaa")
response = mwaa_hook.invoke_rest_api(env_name=mwaa_env_name, path="/dags", method="GET")
response = mwaa_hook.invoke_rest_api(
env_name=mwaa_env_name, path="/dags", method="GET", airflow_version=3
)
return "dags" in response["RestApiResponse"]


Expand All @@ -116,8 +118,10 @@ def test_iam_fallback(role_to_assume_arn, mwaa_env_name):
trigger_dag_run = MwaaTriggerDagRunOperator(
task_id="trigger_dag_run",
env_name=env_name,
logical_date=datetime.now(timezone.utc).isoformat(),
trigger_dag_id=trigger_dag_id,
wait_for_completion=True,
airflow_version=3,
)
# [END howto_operator_mwaa_trigger_dag_run]

Expand All @@ -144,8 +148,10 @@ def test_iam_fallback(role_to_assume_arn, mwaa_env_name):
trigger_dag_run_dont_wait = MwaaTriggerDagRunOperator(
task_id="trigger_dag_run_dont_wait",
env_name=env_name,
logical_date=datetime.now(timezone.utc).isoformat(),
trigger_dag_id=trigger_dag_id,
wait_for_completion=False,
airflow_version=3,
)

wait_for_task_concurrent = MwaaTaskSensor(
Expand Down
Loading