diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 44f699deb295a..23de7b18d169f 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -701,7 +701,7 @@ def context_to_airflow_vars(context: Mapping[str, Any], in_env_var_format: bool (task, "owner", "AIRFLOW_CONTEXT_DAG_OWNER"), (task_instance, "dag_id", "AIRFLOW_CONTEXT_DAG_ID"), (task_instance, "task_id", "AIRFLOW_CONTEXT_TASK_ID"), - (task_instance, "logical_date", "AIRFLOW_CONTEXT_LOGICAL_DATE"), + (dag_run, "logical_date", "AIRFLOW_CONTEXT_LOGICAL_DATE"), (task_instance, "try_number", "AIRFLOW_CONTEXT_TRY_NUMBER"), (dag_run, "run_id", "AIRFLOW_CONTEXT_DAG_RUN_ID"), ] diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index eeca659533829..59266da758892 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -17,13 +17,12 @@ from __future__ import annotations -from datetime import datetime from unittest import mock from unittest.mock import MagicMock, patch import pytest -from airflow.sdk import get_current_context +from airflow.sdk import BaseOperator, get_current_context from airflow.sdk.api.datamodels._generated import AssetEventResponse, AssetResponse from airflow.sdk.definitions.asset import ( Asset, @@ -117,75 +116,65 @@ def test_convert_variable_result_to_variable_with_deserialize_json(): class TestAirflowContextHelpers: - def setup_method(self): - self.dag_id = "dag_id" - self.task_id = "task_id" - self.try_number = 1 - self.logical_date = "2017-05-21T00:00:00" - self.dag_run_id = "dag_run_id" - self.owner = ["owner1", "owner2"] - self.email = ["email1@test.com"] - self.context = { - "dag_run": mock.MagicMock( - name="dag_run", - run_id=self.dag_run_id, - logical_date=datetime.strptime(self.logical_date, "%Y-%m-%dT%H:%M:%S"), - ), - "task_instance": mock.MagicMock( - name="task_instance", - task_id=self.task_id, - dag_id=self.dag_id, - try_number=self.try_number, - logical_date=datetime.strptime(self.logical_date, "%Y-%m-%dT%H:%M:%S"), - ), - "task": mock.MagicMock(name="task", owner=self.owner, email=self.email), - } - def test_context_to_airflow_vars_empty_context(self): assert context_to_airflow_vars({}) == {} - def test_context_to_airflow_vars_all_context(self): - assert context_to_airflow_vars(self.context) == { - "airflow.ctx.dag_id": self.dag_id, - "airflow.ctx.logical_date": self.logical_date, - "airflow.ctx.task_id": self.task_id, - "airflow.ctx.dag_run_id": self.dag_run_id, - "airflow.ctx.try_number": str(self.try_number), + def test_context_to_airflow_vars_all_context(self, create_runtime_ti): + task = BaseOperator( + task_id="test_context_vars", + owner=["owner1", "owner2"], + email="email1@test.com", + ) + + rti = create_runtime_ti( + task=task, + dag_id="dag_id", + run_id="dag_run_id", + logical_date="2017-05-21T00:00:00Z", + try_number=1, + ) + context = rti.get_template_context() + assert context_to_airflow_vars(context) == { + "airflow.ctx.dag_id": "dag_id", + "airflow.ctx.logical_date": "2017-05-21T00:00:00+00:00", + "airflow.ctx.task_id": "test_context_vars", + "airflow.ctx.dag_run_id": "dag_run_id", + "airflow.ctx.try_number": "1", "airflow.ctx.dag_owner": "owner1,owner2", "airflow.ctx.dag_email": "email1@test.com", } - assert context_to_airflow_vars(self.context, in_env_var_format=True) == { - "AIRFLOW_CTX_DAG_ID": self.dag_id, - "AIRFLOW_CTX_LOGICAL_DATE": self.logical_date, - "AIRFLOW_CTX_TASK_ID": self.task_id, - "AIRFLOW_CTX_TRY_NUMBER": str(self.try_number), - "AIRFLOW_CTX_DAG_RUN_ID": self.dag_run_id, + assert context_to_airflow_vars(context, in_env_var_format=True) == { + "AIRFLOW_CTX_DAG_ID": "dag_id", + "AIRFLOW_CTX_LOGICAL_DATE": "2017-05-21T00:00:00+00:00", + "AIRFLOW_CTX_TASK_ID": "test_context_vars", + "AIRFLOW_CTX_TRY_NUMBER": "1", + "AIRFLOW_CTX_DAG_RUN_ID": "dag_run_id", "AIRFLOW_CTX_DAG_OWNER": "owner1,owner2", "AIRFLOW_CTX_DAG_EMAIL": "email1@test.com", } - def test_context_to_airflow_vars_with_default_context_vars(self): + def test_context_to_airflow_vars_from_policy(self): with mock.patch("airflow.settings.get_airflow_context_vars") as mock_method: airflow_cluster = "cluster-a" mock_method.return_value = {"airflow_cluster": airflow_cluster} - context_vars = context_to_airflow_vars(self.context) + context_vars = context_to_airflow_vars({}) assert context_vars["airflow.ctx.airflow_cluster"] == airflow_cluster - context_vars = context_to_airflow_vars(self.context, in_env_var_format=True) + context_vars = context_to_airflow_vars({}, in_env_var_format=True) assert context_vars["AIRFLOW_CTX_AIRFLOW_CLUSTER"] == airflow_cluster with mock.patch("airflow.settings.get_airflow_context_vars") as mock_method: mock_method.return_value = {"airflow_cluster": [1, 2]} with pytest.raises(TypeError) as error: - context_to_airflow_vars(self.context) + context_to_airflow_vars({}) assert str(error.value) == "value of key must be string, not " with mock.patch("airflow.settings.get_airflow_context_vars") as mock_method: mock_method.return_value = {1: "value"} with pytest.raises(TypeError) as error: - context_to_airflow_vars(self.context) + context_to_airflow_vars({}) assert str(error.value) == "key <1> must be string"