diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py index 49d03a503c66d..181364235cd3a 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py @@ -18,6 +18,9 @@ from __future__ import annotations +import warnings +from typing import Literal + import requests from botocore.exceptions import ClientError @@ -55,6 +58,7 @@ def invoke_rest_api( body: dict | None = None, query_params: dict | None = None, generate_local_token: bool = False, + airflow_version: Literal[2, 3] | None = None, ) -> dict: """ Invoke the REST API on the Airflow webserver with the specified inputs. @@ -70,6 +74,8 @@ def invoke_rest_api( :param generate_local_token: If True, only the local web token method is used without trying boto's `invoke_rest_api` first. If False, the local web token method is used as a fallback after trying boto's `invoke_rest_api` + :param airflow_version: The Airflow major version the MWAA environment runs. + This parameter is only used if the local web token method is used to call Airflow API. """ # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise body = {k: v for k, v in body.items() if v is not None} if body else {} @@ -83,7 +89,7 @@ def invoke_rest_api( } if generate_local_token: - return self._invoke_rest_api_using_local_session_token(**api_kwargs) + return self._invoke_rest_api_using_local_session_token(airflow_version, **api_kwargs) try: response = self.conn.invoke_rest_api(**api_kwargs) @@ -100,7 +106,7 @@ def invoke_rest_api( self.log.info( "Access Denied due to missing airflow:InvokeRestApi in IAM policy. Trying again by generating local token..." ) - return self._invoke_rest_api_using_local_session_token(**api_kwargs) + return self._invoke_rest_api_using_local_session_token(airflow_version, **api_kwargs) to_log = e.response # ResponseMetadata is removed because it contains data that is either very unlikely to be # useful in XComs and logs, or redundant given the data already included in the response @@ -110,14 +116,35 @@ def invoke_rest_api( def _invoke_rest_api_using_local_session_token( self, + airflow_version: Literal[2, 3] | None = None, **api_kwargs, ) -> dict: + if not airflow_version: + warnings.warn( + "The parameter ``airflow_version`` in ``MwaaHook.invoke_rest_api`` is not " + "specified and the local web token method is being used. " + "The default Airflow version being used is 2 but this value will change in the future. " + "To avoid any unexpected behavior, please explicitly specify the Airflow version.", + FutureWarning, + stacklevel=3, + ) + airflow_version = 2 + try: - session, hostname = self._get_session_conn(api_kwargs["Name"]) + session, hostname, login_response = self._get_session_conn(api_kwargs["Name"], airflow_version) + + headers = {} + if airflow_version == 3: + headers = { + "Authorization": f"Bearer {login_response.cookies['_token']}", + "Content-Type": "application/json", + } + api_version = "v1" if airflow_version == 2 else "v2" response = session.request( method=api_kwargs["Method"], - url=f"https://{hostname}/api/v1{api_kwargs['Path']}", + url=f"https://{hostname}/api/{api_version}{api_kwargs['Path']}", + headers=headers, params=api_kwargs["QueryParameters"], json=api_kwargs["Body"], timeout=10, @@ -134,15 +161,19 @@ def _invoke_rest_api_using_local_session_token( } # Based on: https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#create-web-server-session-token - def _get_session_conn(self, env_name: str) -> tuple: + def _get_session_conn(self, env_name: str, airflow_version: Literal[2, 3]) -> tuple: create_token_response = self.conn.create_web_login_token(Name=env_name) web_server_hostname = create_token_response["WebServerHostname"] web_token = create_token_response["WebToken"] - login_url = f"https://{web_server_hostname}/aws_mwaa/login" + login_url = ( + f"https://{web_server_hostname}/aws_mwaa/login" + if airflow_version == 2 + else f"https://{web_server_hostname}/pluginsv2/aws_mwaa/login" + ) login_payload = {"token": web_token} session = requests.Session() login_response = session.post(login_url, data=login_payload, timeout=10) login_response.raise_for_status() - return session, web_server_hostname + return session, web_server_hostname, login_response diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/mwaa.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/mwaa.py index 6b65d03e63408..92eb513c76cd6 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/mwaa.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/mwaa.py @@ -19,7 +19,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal from airflow.configuration import conf from airflow.exceptions import AirflowException @@ -46,12 +46,14 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]): :param trigger_run_id: The Run ID. This together with trigger_dag_id are a unique key. (templated) :param logical_date: The logical date (previously called execution date). This is the time or interval covered by this DAG run, according to the DAG definition. This together with trigger_dag_id are a - unique key. (templated) + unique key. This field is required if your environment is running with Airflow 3. (templated) :param data_interval_start: The beginning of the interval the DAG run covers :param data_interval_end: The end of the interval the DAG run covers :param conf: Additional configuration parameters. The value of this field can be set only when creating the object. (templated) :param note: Contains manually entered notes by the user about the DagRun. (templated) + :param airflow_version: The Airflow major version the MWAA environment runs. + This parameter is only used if the local web token method is used to call Airflow API. (templated) :param wait_for_completion: Whether to wait for DAG run to stop. (default: False) :param waiter_delay: Time in seconds to wait between status checks. (default: 120) @@ -81,6 +83,7 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]): "data_interval_end", "conf", "note", + "airflow_version", ) template_fields_renderers = {"conf": "json"} @@ -95,6 +98,7 @@ def __init__( data_interval_end: str | None = None, conf: dict | None = None, note: str | None = None, + airflow_version: Literal[2, 3] | None = None, wait_for_completion: bool = False, waiter_delay: int = 60, waiter_max_attempts: int = 20, @@ -110,6 +114,7 @@ def __init__( self.data_interval_end = data_interval_end self.conf = conf if conf else {} self.note = note + self.airflow_version = airflow_version self.wait_for_completion = wait_for_completion self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts @@ -123,7 +128,10 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None dag_run_id = validated_event["dag_run_id"] self.log.info("DAG run %s of DAG %s completed", dag_run_id, self.trigger_dag_id) return self.hook.invoke_rest_api( - env_name=self.env_name, path=f"/dags/{self.trigger_dag_id}/dagRuns/{dag_run_id}", method="GET" + env_name=self.env_name, + path=f"/dags/{self.trigger_dag_id}/dagRuns/{dag_run_id}", + method="GET", + airflow_version=self.airflow_version, ) def execute(self, context: Context) -> dict: @@ -146,6 +154,7 @@ def execute(self, context: Context) -> dict: "conf": self.conf, "note": self.note, }, + airflow_version=self.airflow_version, ) dag_run_id = response["RestApiResponse"]["dag_run_id"] diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py index dcf6328a8afb7..53f4e36e5741f 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py @@ -18,7 +18,7 @@ from __future__ import annotations from collections.abc import Collection, Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal from airflow.configuration import conf from airflow.exceptions import AirflowException @@ -51,6 +51,8 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]): ``{airflow.utils.state.DagRunState.SUCCESS}`` (templated) :param failure_states: Collection of DAG Run states that would make this task marked as failed and raise an AirflowException, default is ``{airflow.utils.state.DagRunState.FAILED}`` (templated) + :param airflow_version: The Airflow major version the MWAA environment runs. + This parameter is only used if the local web token method is used to call Airflow API. (templated) :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore module to be installed. (default: False, but can be overridden in config file by setting default_deferrable to True) @@ -75,6 +77,7 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]): "external_dag_run_id", "success_states", "failure_states", + "airflow_version", "deferrable", "max_retries", "poke_interval", @@ -88,6 +91,7 @@ def __init__( external_dag_run_id: str, success_states: Collection[str] | None = None, failure_states: Collection[str] | None = None, + airflow_version: Literal[2, 3] | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), poke_interval: int = 60, max_retries: int = 720, @@ -104,6 +108,7 @@ def __init__( self.external_env_name = external_env_name self.external_dag_id = external_dag_id self.external_dag_run_id = external_dag_run_id + self.airflow_version = airflow_version self.deferrable = deferrable self.poke_interval = poke_interval self.max_retries = max_retries @@ -119,6 +124,7 @@ def poke(self, context: Context) -> bool: env_name=self.external_env_name, path=f"/dags/{self.external_dag_id}/dagRuns/{self.external_dag_run_id}", method="GET", + airflow_version=self.airflow_version, ) # If RestApiStatusCode == 200, the RestApiResponse must have the "state" key, otherwise something terrible has @@ -179,6 +185,8 @@ class MwaaTaskSensor(AwsBaseSensor[MwaaHook]): ``{airflow.utils.state.TaskInstanceState.SUCCESS}`` (templated) :param failure_states: Collection of task instance states that would make this task marked as failed and raise an AirflowException, default is ``{airflow.utils.state.TaskInstanceState.FAILED}`` (templated) + :param airflow_version: The Airflow major version the MWAA environment runs. + This parameter is only used if the local web token method is used to call Airflow API. (templated) :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore module to be installed. (default: False, but can be overridden in config file by setting default_deferrable to True) @@ -204,6 +212,7 @@ class MwaaTaskSensor(AwsBaseSensor[MwaaHook]): "external_task_id", "success_states", "failure_states", + "airflow_version", "deferrable", "max_retries", "poke_interval", @@ -218,6 +227,7 @@ def __init__( external_task_id: str, success_states: Collection[str] | None = None, failure_states: Collection[str] | None = None, + airflow_version: Literal[2, 3] | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), poke_interval: int = 60, max_retries: int = 720, @@ -235,6 +245,7 @@ def __init__( self.external_dag_id = external_dag_id self.external_dag_run_id = external_dag_run_id self.external_task_id = external_task_id + self.airflow_version = airflow_version self.deferrable = deferrable self.poke_interval = poke_interval self.max_retries = max_retries @@ -252,6 +263,7 @@ def poke(self, context: Context) -> bool: env_name=self.external_env_name, path=f"/dags/{self.external_dag_id}/dagRuns/{self.external_dag_run_id}/taskInstances/{self.external_task_id}", method="GET", + airflow_version=self.airflow_version, ) # If RestApiStatusCode == 200, the RestApiResponse must have the "state" key, otherwise something terrible has # happened in the API and KeyError would be raised @@ -278,6 +290,7 @@ def execute(self, context: Context): env_name=self.external_env_name, path=f"/dags/{self.external_dag_id}/dagRuns", method="GET", + airflow_version=self.airflow_version, ) self.external_dag_run_id = response["RestApiResponse"]["dag_runs"][-1]["dag_run_id"] @@ -290,6 +303,7 @@ def execute(self, context: Context): external_task_id=self.external_task_id, success_states=self.success_states, failure_states=self.failure_states, + airflow_version=self.airflow_version, waiter_delay=int(self.poke_interval), waiter_max_attempts=self.max_retries, aws_conn_id=self.aws_conn_id, diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/mwaa.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/mwaa.py index b31f5a018b9a5..8b5538446c307 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/mwaa.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/mwaa.py @@ -18,7 +18,7 @@ from __future__ import annotations from collections.abc import Collection -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger @@ -121,6 +121,8 @@ class MwaaTaskCompletedTrigger(AwsBaseWaiterTrigger): ``{airflow.utils.state.TaskInstanceState.SUCCESS}`` (templated) :param failure_states: Collection of task instance states that would make this task marked as failed and raise an AirflowException, default is ``{airflow.utils.state.TaskInstanceState.FAILED}`` (templated) + :param airflow_version: The Airflow major version the MWAA environment runs. + This parameter is only used if the local web token method is used to call Airflow API. (templated) :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60) :param waiter_max_attempts: The maximum number of attempts to be made. (default: 720) :param aws_conn_id: The Airflow connection used for AWS credentials. @@ -135,6 +137,7 @@ def __init__( external_task_id: str, success_states: Collection[str] | None = None, failure_states: Collection[str] | None = None, + airflow_version: Literal[2, 3] | None = None, waiter_delay: int = 60, waiter_max_attempts: int = 720, **kwargs, @@ -165,6 +168,7 @@ def __init__( "Name": external_env_name, "Path": f"/dags/{external_dag_id}/dagRuns/{external_dag_run_id}/taskInstances/{external_task_id}", "Method": "GET", + "airflow_version": airflow_version, }, failure_message=f"The task {external_task_id} of DAG run {external_dag_run_id} of DAG {external_dag_id} in MWAA environment {external_env_name} failed with state", status_message="State of DAG run", diff --git a/providers/amazon/tests/system/amazon/aws/example_mwaa.py b/providers/amazon/tests/system/amazon/aws/example_mwaa.py index c5f2f980153ac..22385eea7dd04 100644 --- a/providers/amazon/tests/system/amazon/aws/example_mwaa.py +++ b/providers/amazon/tests/system/amazon/aws/example_mwaa.py @@ -28,7 +28,7 @@ from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import DAG, chain, task + from airflow.sdk import DAG, chain, task, timezone else: # Airflow 2 path from airflow.decorators import task # type: ignore[attr-defined,no-redef] @@ -95,7 +95,9 @@ def test_iam_fallback(role_to_assume_arn, mwaa_env_name): mwaa_hook = MwaaHook() mwaa_hook.conn = session.client("mwaa") - response = mwaa_hook.invoke_rest_api(env_name=mwaa_env_name, path="/dags", method="GET") + response = mwaa_hook.invoke_rest_api( + env_name=mwaa_env_name, path="/dags", method="GET", airflow_version=3 + ) return "dags" in response["RestApiResponse"] @@ -116,8 +118,10 @@ def test_iam_fallback(role_to_assume_arn, mwaa_env_name): trigger_dag_run = MwaaTriggerDagRunOperator( task_id="trigger_dag_run", env_name=env_name, + logical_date=datetime.now(timezone.utc).isoformat(), trigger_dag_id=trigger_dag_id, wait_for_completion=True, + airflow_version=3, ) # [END howto_operator_mwaa_trigger_dag_run] @@ -144,8 +148,10 @@ def test_iam_fallback(role_to_assume_arn, mwaa_env_name): trigger_dag_run_dont_wait = MwaaTriggerDagRunOperator( task_id="trigger_dag_run_dont_wait", env_name=env_name, + logical_date=datetime.now(timezone.utc).isoformat(), trigger_dag_id=trigger_dag_id, wait_for_completion=False, + airflow_version=3, ) wait_for_task_concurrent = MwaaTaskSensor( diff --git a/providers/amazon/tests/system/amazon/aws/example_mwaa_airflow2.py b/providers/amazon/tests/system/amazon/aws/example_mwaa_airflow2.py new file mode 100644 index 0000000000000..ccbf02536fa84 --- /dev/null +++ b/providers/amazon/tests/system/amazon/aws/example_mwaa_airflow2.py @@ -0,0 +1,177 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime + +import boto3 + +from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook +from airflow.providers.amazon.aws.hooks.sts import StsHook +from airflow.providers.amazon.aws.operators.mwaa import MwaaTriggerDagRunOperator +from airflow.providers.amazon.aws.sensors.mwaa import MwaaDagRunSensor, MwaaTaskSensor + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import DAG, chain, task +else: + # Airflow 2 path + from airflow.decorators import task # type: ignore[attr-defined,no-redef] + from airflow.models.baseoperator import chain # type: ignore[attr-defined,no-redef] + from airflow.models.dag import DAG # type: ignore[attr-defined,no-redef,assignment] + + +from system.amazon.aws.utils import SystemTestContextBuilder + +DAG_ID = "example_mwaa_airflow2" + +# Externally fetched variables: +EXISTING_ENVIRONMENT_NAME_KEY = "ENVIRONMENT_NAME" +EXISTING_DAG_ID_KEY = "DAG_ID" +EXISTING_TASK_ID_KEY = "TASK_ID" +ROLE_WITHOUT_INVOKE_REST_API_ARN_KEY = "ROLE_WITHOUT_INVOKE_REST_API_ARN" + +sys_test_context_task = ( + SystemTestContextBuilder() + # NOTE: Creating a functional MWAA environment is time-consuming and requires + # manually creating and configuring an S3 bucket for DAG storage and a VPC with + # private subnets which is out of scope for this demo. To simplify this demo and + # make it run in a reasonable time, an existing MWAA environment already + # containing a DAG is required. + # Here's a quick start guide to create an MWAA environment using AWS CloudFormation: + # https://docs.aws.amazon.com/mwaa/latest/userguide/quick-start.html + # If creating the environment using the AWS console, make sure to have a VPC with + # at least 1 private subnet to be able to select the VPC while going through the + # environment creation steps in the console wizard. + # Make sure to set the environment variables with appropriate values + .add_variable(EXISTING_ENVIRONMENT_NAME_KEY) + .add_variable(EXISTING_DAG_ID_KEY) + .add_variable(ROLE_WITHOUT_INVOKE_REST_API_ARN_KEY) + .add_variable(EXISTING_TASK_ID_KEY) + .build() +) + + +@task +def unpause_dag(env_name: str, dag_id: str): + mwaa_hook = MwaaHook() + response = mwaa_hook.invoke_rest_api( + env_name=env_name, path=f"/dags/{dag_id}", method="PATCH", body={"is_paused": False} + ) + return not response["RestApiResponse"]["is_paused"] + + +# This task in the system test verifies that the MwaaHook's IAM fallback mechanism continues to work with +# the live MWAA API. This fallback depends on parsing a specific error message from the MWAA API, so we +# want to ensure we find out if the API response format ever changes. Unit tests cover this with mocked +# responses, but this system test validates against the real API. +@task +def test_iam_fallback(role_to_assume_arn, mwaa_env_name): + assumed_role = StsHook().conn.assume_role( + RoleArn=role_to_assume_arn, RoleSessionName="MwaaSysTestIamFallback" + ) + + credentials = assumed_role["Credentials"] + session = boto3.Session( + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + + mwaa_hook = MwaaHook() + mwaa_hook.conn = session.client("mwaa") + response = mwaa_hook.invoke_rest_api(env_name=mwaa_env_name, path="/dags", method="GET") + return "dags" in response["RestApiResponse"] + + +with DAG( + dag_id=DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + tags=["example"], + catchup=False, +) as dag: + test_context = sys_test_context_task() + env_name = test_context[EXISTING_ENVIRONMENT_NAME_KEY] + trigger_dag_id = test_context[EXISTING_DAG_ID_KEY] + task_id = test_context[EXISTING_TASK_ID_KEY] + restricted_role_arn = test_context[ROLE_WITHOUT_INVOKE_REST_API_ARN_KEY] + + # [START howto_operator_mwaa_trigger_dag_run] + trigger_dag_run = MwaaTriggerDagRunOperator( + task_id="trigger_dag_run", + env_name=env_name, + trigger_dag_id=trigger_dag_id, + wait_for_completion=True, + ) + # [END howto_operator_mwaa_trigger_dag_run] + + # [START howto_sensor_mwaa_task] + wait_for_task = MwaaTaskSensor( + task_id="wait_for_task", + external_env_name=env_name, + external_dag_id=trigger_dag_id, + external_task_id=task_id, + poke_interval=5, + ) + # [END howto_sensor_mwaa_task] + + # [START howto_sensor_mwaa_dag_run] + wait_for_dag_run = MwaaDagRunSensor( + task_id="wait_for_dag_run", + external_env_name=env_name, + external_dag_id=trigger_dag_id, + external_dag_run_id="{{ task_instance.xcom_pull(task_ids='trigger_dag_run')['RestApiResponse']['dag_run_id'] }}", + poke_interval=5, + ) + # [END howto_sensor_mwaa_dag_run] + + trigger_dag_run_dont_wait = MwaaTriggerDagRunOperator( + task_id="trigger_dag_run_dont_wait", + env_name=env_name, + trigger_dag_id=trigger_dag_id, + wait_for_completion=False, + ) + + wait_for_task_concurrent = MwaaTaskSensor( + task_id="wait_for_task_concurrent", + external_env_name=env_name, + external_dag_id=trigger_dag_id, + external_task_id=task_id, + poke_interval=5, + ) + + test_context >> [ + unpause_dag(env_name, trigger_dag_id), + test_iam_fallback(restricted_role_arn, env_name), + trigger_dag_run, + trigger_dag_run_dont_wait, + ] + chain(trigger_dag_run, wait_for_task, wait_for_dag_run) + chain(trigger_dag_run_dont_wait, wait_for_task_concurrent) + + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py index d8046db33a847..41dd43a8650ca 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py @@ -105,7 +105,7 @@ def test_invoke_rest_api_local_token_parameter( mock_conn.invoke_rest_api.assert_called_once() @mock.patch.object(MwaaHook, "_get_session_conn") - def test_invoke_rest_api_fallback_success_when_iam_fails( + def test_invoke_rest_api_fallback_success_when_iam_fails_with_airflow2( self, mock_get_session_conn, mock_conn, example_responses ): boto_invoke_error = ClientError( @@ -117,6 +117,7 @@ def test_invoke_rest_api_fallback_success_when_iam_fails( "method": METHOD, "url": f"https://{HOSTNAME}/api/v1{PATH}", "params": QUERY_PARAMS, + "headers": {}, "json": BODY, "timeout": 10, } @@ -127,7 +128,7 @@ def test_invoke_rest_api_fallback_success_when_iam_fails( mock_session = mock.MagicMock() mock_session.request.return_value = mock_response - mock_get_session_conn.return_value = (mock_session, HOSTNAME) + mock_get_session_conn.return_value = (mock_session, HOSTNAME, None) retval = self.hook.invoke_rest_api( env_name=ENV_NAME, path=PATH, method=METHOD, body=BODY, query_params=QUERY_PARAMS @@ -137,6 +138,50 @@ def test_invoke_rest_api_fallback_success_when_iam_fails( mock_response.raise_for_status.assert_called_once() assert retval == {k: v for k, v in example_responses["success"].items() if k != "ResponseMetadata"} + @mock.patch.object(MwaaHook, "_get_session_conn") + def test_invoke_rest_api_fallback_success_when_iam_fails_with_airflow3( + self, mock_get_session_conn, mock_conn, example_responses + ): + boto_invoke_error = ClientError( + error_response=example_responses["missingIamRole"], operation_name="invoke_rest_api" + ) + mock_conn.invoke_rest_api = mock.MagicMock(side_effect=boto_invoke_error) + + kwargs_to_assert = { + "method": METHOD, + "url": f"https://{HOSTNAME}/api/v2{PATH}", + "params": QUERY_PARAMS, + "headers": { + "Authorization": "Bearer token", + "Content-Type": "application/json", + }, + "json": BODY, + "timeout": 10, + } + + mock_response = mock.MagicMock() + mock_response.status_code = example_responses["success"]["RestApiStatusCode"] + mock_response.json.return_value = example_responses["success"]["RestApiResponse"] + mock_session = mock.MagicMock() + mock_session.request.return_value = mock_response + mock_login_response = mock.MagicMock() + mock_login_response.cookies = {"_token": "token"} + + mock_get_session_conn.return_value = (mock_session, HOSTNAME, mock_login_response) + + retval = self.hook.invoke_rest_api( + env_name=ENV_NAME, + path=PATH, + method=METHOD, + body=BODY, + query_params=QUERY_PARAMS, + airflow_version=3, + ) + + mock_session.request.assert_called_once_with(**kwargs_to_assert) + mock_response.raise_for_status.assert_called_once() + assert retval == {k: v for k, v in example_responses["success"].items() if k != "ResponseMetadata"} + @mock.patch.object(MwaaHook, "_get_session_conn") def test_invoke_rest_api_using_local_session_token_failure( self, mock_get_session_conn, example_responses @@ -149,7 +194,7 @@ def test_invoke_rest_api_using_local_session_token_failure( mock_session = mock.MagicMock() mock_session.request.return_value = mock_response - mock_get_session_conn.return_value = (mock_session, HOSTNAME) + mock_get_session_conn.return_value = (mock_session, HOSTNAME, None) mock_error_log = mock.MagicMock() self.hook.log.error = mock_error_log @@ -161,23 +206,46 @@ def test_invoke_rest_api_using_local_session_token_failure( mock_error_log.assert_called_once_with(example_responses["failure"]["RestApiResponse"]) @mock.patch("airflow.providers.amazon.aws.hooks.mwaa.requests.Session") - def test_get_session_conn(self, mock_create_session, mock_conn): + def test_get_session_conn_airflow2(self, mock_create_session, mock_conn): token = "token" mock_conn.create_web_login_token.return_value = {"WebServerHostname": HOSTNAME, "WebToken": token} login_url = f"https://{HOSTNAME}/aws_mwaa/login" login_payload = {"token": token} mock_session = mock.MagicMock() + mock_login_response = mock.MagicMock() + mock_session.post.return_value = mock_login_response + mock_create_session.return_value = mock_session + + retval = self.hook._get_session_conn(env_name=ENV_NAME, airflow_version=2) + + mock_conn.create_web_login_token.assert_called_once_with(Name=ENV_NAME) + mock_create_session.assert_called_once_with() + mock_session.post.assert_called_once_with(login_url, data=login_payload, timeout=10) + mock_session.post.return_value.raise_for_status.assert_called_once() + + assert retval == (mock_session, HOSTNAME, mock_login_response) + + @mock.patch("airflow.providers.amazon.aws.hooks.mwaa.requests.Session") + def test_get_session_conn_airflow3(self, mock_create_session, mock_conn): + token = "token" + mock_conn.create_web_login_token.return_value = {"WebServerHostname": HOSTNAME, "WebToken": token} + login_url = f"https://{HOSTNAME}/pluginsv2/aws_mwaa/login" + login_payload = {"token": token} + + mock_session = mock.MagicMock() + mock_login_response = mock.MagicMock() + mock_session.post.return_value = mock_login_response mock_create_session.return_value = mock_session - retval = self.hook._get_session_conn(env_name=ENV_NAME) + retval = self.hook._get_session_conn(env_name=ENV_NAME, airflow_version=3) mock_conn.create_web_login_token.assert_called_once_with(Name=ENV_NAME) mock_create_session.assert_called_once_with() mock_session.post.assert_called_once_with(login_url, data=login_payload, timeout=10) mock_session.post.return_value.raise_for_status.assert_called_once() - assert retval == (mock_session, HOSTNAME) + assert retval == (mock_session, HOSTNAME, mock_login_response) @pytest.fixture def example_responses(self): diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_mwaa.py b/providers/amazon/tests/unit/amazon/aws/operators/test_mwaa.py index 9c1b9527fbe15..560f0eff91ab9 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_mwaa.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_mwaa.py @@ -85,6 +85,7 @@ def test_execute(self, mock_hook): "conf": OP_KWARGS["conf"], "note": OP_KWARGS["note"], }, + airflow_version=None, ) assert op_ret_val == HOOK_RETURN_VALUE