diff --git a/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py b/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py
index 0d187ebd5144b..57ebf6504c0de 100644
--- a/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py
+++ b/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py
@@ -280,7 +280,7 @@ def paginate(
             if top and odata_count:
                 if len(response.get("value", [])) == top and context:
                     results = operator.pull_xcom(context=context)
-                    skip = sum(map(lambda result: len(result["value"]), results)) + top if results else top
+                    skip = sum([len(result["value"]) for result in results]) + top if results else top  # type: ignore
                     query_parameters["$skip"] = skip
                 return operator.url, query_parameters
         return response.get("@odata.nextLink"), operator.query_parameters
diff --git a/providers/tests/microsoft/azure/base.py b/providers/tests/microsoft/azure/base.py
index 98c0a59867ea4..600e4ce488e08 100644
--- a/providers/tests/microsoft/azure/base.py
+++ b/providers/tests/microsoft/azure/base.py
@@ -68,6 +68,7 @@ async def deferrable_operator(self, context, operator):
         result = None
         triggered_events = []
         try:
+            operator.render_template_fields(context=context)
             result = operator.execute(context=context)
         except TaskDeferred as deferred:
             task = deferred
diff --git a/providers/tests/microsoft/azure/operators/test_msgraph.py b/providers/tests/microsoft/azure/operators/test_msgraph.py
index fe404e48e6f0a..2c9c8129d5d08 100644
--- a/providers/tests/microsoft/azure/operators/test_msgraph.py
+++ b/providers/tests/microsoft/azure/operators/test_msgraph.py
@@ -35,6 +35,7 @@
     mock_json_response,
     mock_response,
 )
+from tests_common.test_utils.compat import AIRFLOW_V_2_10_PLUS
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -143,11 +144,40 @@ def test_execute_when_response_is_bytes(self):
             task_id="drive_item_content",
             conn_id="msgraph_api",
             response_type="bytes",
-            url=f"/drives/{drive_id}/root/content",
+            url="/drives/{drive_id}/root/content",
+            path_parameters={"drive_id": drive_id},
         )
 
         results, events = self.execute_operator(operator)
 
+        assert operator.path_parameters == {"drive_id": drive_id}
+        assert results == base64_encoded_content
+        assert len(events) == 1
+        assert isinstance(events[0], TriggerEvent)
+        assert events[0].payload["status"] == "success"
+        assert events[0].payload["type"] == "builtins.bytes"
+        assert events[0].payload["response"] == base64_encoded_content
+
+    @pytest.mark.db_test
+    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Lambda parameters works in Airflow >= 2.10.0")
+    def test_execute_with_lambda_parameter_when_response_is_bytes(self):
+        content = load_file("resources", "dummy.pdf", mode="rb", encoding=None)
+        base64_encoded_content = b64encode(content).decode(locale.getpreferredencoding())
+        drive_id = "82f9d24d-6891-4790-8b6d-f1b2a1d0ca22"
+        response = mock_response(200, content)
+
+        with self.patch_hook_and_request_adapter(response):
+            operator = MSGraphAsyncOperator(
+                task_id="drive_item_content",
+                conn_id="msgraph_api",
+                response_type="bytes",
+                url="/drives/{drive_id}/root/content",
+                path_parameters=lambda context, jinja_env: {"drive_id": drive_id},
+            )
+
+        results, events = self.execute_operator(operator)
+
+        assert operator.path_parameters == {"drive_id": drive_id}
         assert results == base64_encoded_content
         assert len(events) == 1
         assert isinstance(events[0], TriggerEvent)
diff --git a/providers/tests/microsoft/azure/sensors/test_msgraph.py b/providers/tests/microsoft/azure/sensors/test_msgraph.py
index ba5ba35478861..8b8ec793d65ce 100644
--- a/providers/tests/microsoft/azure/sensors/test_msgraph.py
+++ b/providers/tests/microsoft/azure/sensors/test_msgraph.py
@@ -18,11 +18,14 @@
 
 import json
 
+import pytest
+
 from airflow.providers.microsoft.azure.sensors.msgraph import MSGraphSensor
 from airflow.triggers.base import TriggerEvent
 
 from providers.tests.microsoft.azure.base import Base
 from providers.tests.microsoft.conftest import load_json, mock_json_response
+from tests_common.test_utils.compat import AIRFLOW_V_2_10_PLUS
 
 
 class TestMSGraphSensor(Base):
@@ -42,6 +45,33 @@ def test_execute(self):
 
         results, events = self.execute_operator(sensor)
 
+        assert sensor.path_parameters == {"scanId": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"}
+        assert isinstance(results, str)
+        assert results == "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"
+        assert len(events) == 1
+        assert isinstance(events[0], TriggerEvent)
+        assert events[0].payload["status"] == "success"
+        assert events[0].payload["type"] == "builtins.dict"
+        assert events[0].payload["response"] == json.dumps(status)
+
+    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Lambda parameters works in Airflow >= 2.10.0")
+    def test_execute_with_lambda_parameter(self):
+        status = load_json("resources", "status.json")
+        response = mock_json_response(200, status)
+
+        with self.patch_hook_and_request_adapter(response):
+            sensor = MSGraphSensor(
+                task_id="check_workspaces_status",
+                conn_id="powerbi",
+                url="myorg/admin/workspaces/scanStatus/{scanId}",
+                path_parameters=lambda context, jinja_env: {"scanId": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"},
+                result_processor=lambda context, result: result["id"],
+                timeout=350.0,
+            )
+
+        results, events = self.execute_operator(sensor)
+
+        assert sensor.path_parameters == {"scanId": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"}
         assert isinstance(results, str)
         assert results == "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"
         assert len(events) == 1
diff --git a/providers/tests/microsoft/conftest.py b/providers/tests/microsoft/conftest.py
index 6ac4c0a03a368..bf6a291ee9dbc 100644
--- a/providers/tests/microsoft/conftest.py
+++ b/providers/tests/microsoft/conftest.py
@@ -132,15 +132,15 @@ def __init__(
 
     def xcom_pull(
         self,
-        task_ids: Iterable[str] | str | None = None,
+        task_ids: str | Iterable[str] | None = None,
         dag_id: str | None = None,
         key: str = XCOM_RETURN_KEY,
         include_prior_dates: bool = False,
         session: Session = NEW_SESSION,
-        run_id: str | None = None,
         *,
-        map_indexes: Iterable[int] | int | None = None,
-        default: Any | None = None,
+        map_indexes: int | Iterable[int] | None = None,
+        default: Any = None,
+        run_id: str | None = None,
     ) -> Any:
         if map_indexes:
             return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}")