Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion task-sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ def context_to_airflow_vars(context: Mapping[str, Any], in_env_var_format: bool
(task, "owner", "AIRFLOW_CONTEXT_DAG_OWNER"),
(task_instance, "dag_id", "AIRFLOW_CONTEXT_DAG_ID"),
(task_instance, "task_id", "AIRFLOW_CONTEXT_TASK_ID"),
(task_instance, "logical_date", "AIRFLOW_CONTEXT_LOGICAL_DATE"),
(dag_run, "logical_date", "AIRFLOW_CONTEXT_LOGICAL_DATE"),
(task_instance, "try_number", "AIRFLOW_CONTEXT_TRY_NUMBER"),
(dag_run, "run_id", "AIRFLOW_CONTEXT_DAG_RUN_ID"),
]
Expand Down
77 changes: 33 additions & 44 deletions task-sdk/tests/task_sdk/execution_time/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@

from __future__ import annotations

from datetime import datetime
from unittest import mock
from unittest.mock import MagicMock, patch

import pytest

from airflow.sdk import get_current_context
from airflow.sdk import BaseOperator, get_current_context
from airflow.sdk.api.datamodels._generated import AssetEventResponse, AssetResponse
from airflow.sdk.definitions.asset import (
Asset,
Expand Down Expand Up @@ -117,75 +116,65 @@ def test_convert_variable_result_to_variable_with_deserialize_json():


class TestAirflowContextHelpers:
def setup_method(self):
self.dag_id = "dag_id"
self.task_id = "task_id"
self.try_number = 1
self.logical_date = "2017-05-21T00:00:00"
self.dag_run_id = "dag_run_id"
self.owner = ["owner1", "owner2"]
self.email = ["email1@test.com"]
self.context = {
"dag_run": mock.MagicMock(
name="dag_run",
run_id=self.dag_run_id,
logical_date=datetime.strptime(self.logical_date, "%Y-%m-%dT%H:%M:%S"),
),
"task_instance": mock.MagicMock(
name="task_instance",
task_id=self.task_id,
dag_id=self.dag_id,
try_number=self.try_number,
logical_date=datetime.strptime(self.logical_date, "%Y-%m-%dT%H:%M:%S"),
),
"task": mock.MagicMock(name="task", owner=self.owner, email=self.email),
}

def test_context_to_airflow_vars_empty_context(self):
assert context_to_airflow_vars({}) == {}

def test_context_to_airflow_vars_all_context(self):
assert context_to_airflow_vars(self.context) == {
"airflow.ctx.dag_id": self.dag_id,
"airflow.ctx.logical_date": self.logical_date,
"airflow.ctx.task_id": self.task_id,
"airflow.ctx.dag_run_id": self.dag_run_id,
"airflow.ctx.try_number": str(self.try_number),
def test_context_to_airflow_vars_all_context(self, create_runtime_ti):
task = BaseOperator(
task_id="test_context_vars",
owner=["owner1", "owner2"],
email="email1@test.com",
)

rti = create_runtime_ti(
task=task,
dag_id="dag_id",
run_id="dag_run_id",
logical_date="2017-05-21T00:00:00Z",
try_number=1,
)
context = rti.get_template_context()
assert context_to_airflow_vars(context) == {
"airflow.ctx.dag_id": "dag_id",
"airflow.ctx.logical_date": "2017-05-21T00:00:00+00:00",
"airflow.ctx.task_id": "test_context_vars",
"airflow.ctx.dag_run_id": "dag_run_id",
"airflow.ctx.try_number": "1",
"airflow.ctx.dag_owner": "owner1,owner2",
"airflow.ctx.dag_email": "email1@test.com",
}

assert context_to_airflow_vars(self.context, in_env_var_format=True) == {
"AIRFLOW_CTX_DAG_ID": self.dag_id,
"AIRFLOW_CTX_LOGICAL_DATE": self.logical_date,
"AIRFLOW_CTX_TASK_ID": self.task_id,
"AIRFLOW_CTX_TRY_NUMBER": str(self.try_number),
"AIRFLOW_CTX_DAG_RUN_ID": self.dag_run_id,
assert context_to_airflow_vars(context, in_env_var_format=True) == {
"AIRFLOW_CTX_DAG_ID": "dag_id",
"AIRFLOW_CTX_LOGICAL_DATE": "2017-05-21T00:00:00+00:00",
"AIRFLOW_CTX_TASK_ID": "test_context_vars",
"AIRFLOW_CTX_TRY_NUMBER": "1",
"AIRFLOW_CTX_DAG_RUN_ID": "dag_run_id",
"AIRFLOW_CTX_DAG_OWNER": "owner1,owner2",
"AIRFLOW_CTX_DAG_EMAIL": "email1@test.com",
}

def test_context_to_airflow_vars_with_default_context_vars(self):
def test_context_to_airflow_vars_from_policy(self):
with mock.patch("airflow.settings.get_airflow_context_vars") as mock_method:
airflow_cluster = "cluster-a"
mock_method.return_value = {"airflow_cluster": airflow_cluster}

context_vars = context_to_airflow_vars(self.context)
context_vars = context_to_airflow_vars({})
assert context_vars["airflow.ctx.airflow_cluster"] == airflow_cluster

context_vars = context_to_airflow_vars(self.context, in_env_var_format=True)
context_vars = context_to_airflow_vars({}, in_env_var_format=True)
assert context_vars["AIRFLOW_CTX_AIRFLOW_CLUSTER"] == airflow_cluster

with mock.patch("airflow.settings.get_airflow_context_vars") as mock_method:
mock_method.return_value = {"airflow_cluster": [1, 2]}
with pytest.raises(TypeError) as error:
context_to_airflow_vars(self.context)
context_to_airflow_vars({})
assert str(error.value) == "value of key <airflow_cluster> must be string, not <class 'list'>"

with mock.patch("airflow.settings.get_airflow_context_vars") as mock_method:
mock_method.return_value = {1: "value"}
with pytest.raises(TypeError) as error:
context_to_airflow_vars(self.context)
context_to_airflow_vars({})
assert str(error.value) == "key <1> must be string"


Expand Down