Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,32 @@ def get_operator_class(task: BaseOperator) -> type:
return task.__class__


def get_operator_provider_version(operator: BaseOperator | MappedOperator) -> str | None:
"""Get the provider package version for the given operator."""
try:
class_path = get_fully_qualified_class_name(operator)

if not class_path.startswith("airflow.providers."):
return None

from airflow.providers_manager import ProvidersManager

providers_manager = ProvidersManager()

for package_name, provider_info in providers_manager.providers.items():
if package_name.startswith("apache-airflow-providers-"):
provider_module_path = package_name.replace(
"apache-airflow-providers-", "airflow.providers."
).replace("-", ".")
if class_path.startswith(provider_module_path + "."):
return provider_info.version

return None

except Exception:
return None


def get_job_name(task: TaskInstance | RuntimeTaskInstance) -> str:
return f"{task.dag_id}.{task.task_id}"

Expand Down Expand Up @@ -511,6 +537,7 @@ class TaskInfo(InfoJsonEncodable):
),
"inlets": lambda task: [AssetInfo(i) for i in task.inlets if isinstance(i, Asset)],
"outlets": lambda task: [AssetInfo(o) for o in task.outlets if isinstance(o, Asset)],
"operator_provider_version": lambda task: get_operator_provider_version(task),
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
get_fully_qualified_class_name,
get_job_name,
get_operator_class,
get_operator_provider_version,
get_user_provided_run_facets,
)
from airflow.providers.standard.operators.empty import EmptyOperator
Expand Down Expand Up @@ -1600,6 +1601,7 @@ def __init__(self, *args, **kwargs):
"multiple_outputs": False,
"operator_class": "CustomOperator",
"operator_class_path": get_fully_qualified_class_name(task_10),
"operator_provider_version": None, # Custom operator doesn't have provider version
"outlets": "[{'uri': 'uri2', 'extra': {'b': 2}}, {'uri': 'uri3', 'extra': {'c': 3}}]",
"owner": "airflow",
"priority_weight": 1,
Expand Down Expand Up @@ -1677,6 +1679,7 @@ def __init__(self, *args, **kwargs):
"multiple_outputs": False,
"operator_class": "CustomOperator",
"operator_class_path": get_fully_qualified_class_name(task_10),
"operator_provider_version": None, # Custom operator doesn't have provider version
"outlets": "[{'uri': 'uri2', 'extra': {'b': 2}}, {'uri': 'uri3', 'extra': {'c': 3}}]",
"owner": "airflow",
"priority_weight": 1,
Expand All @@ -1698,3 +1701,85 @@ def test_task_info_complete():
task_0 = BashOperator(task_id="task_0", bash_command="exit 0;")
result = TaskInfoComplete(task_0)
assert "'bash_command': 'exit 0;'" in str(result)


@patch("airflow.providers.openlineage.utils.utils.get_fully_qualified_class_name")
def test_get_operator_provider_version_exception_handling(mock_class_name):
mock_class_name.side_effect = Exception("Test exception")
operator = MagicMock()
assert get_operator_provider_version(operator) is None


def test_get_operator_provider_version_for_core_operator():
"""Test that get_operator_provider_version returns None for core operators."""
operator = BaseOperator(task_id="test_task")
result = get_operator_provider_version(operator)
assert result is None


@patch("airflow.providers_manager.ProvidersManager")
def test_get_operator_provider_version_for_provider_operator(mock_providers_manager):
"""Test that get_operator_provider_version returns version for provider operators."""
# Mock ProvidersManager
mock_manager_instance = MagicMock()
mock_providers_manager.return_value = mock_manager_instance

# Mock providers data
mock_manager_instance.providers = {
"apache-airflow-providers-standard": MagicMock(version="1.2.0"),
"apache-airflow-providers-amazon": MagicMock(version="8.12.0"),
"apache-airflow-providers-google": MagicMock(version="10.5.0"),
}

# Test with BashOperator (standard provider)
operator = BashOperator(task_id="test_task", bash_command="echo test")
result = get_operator_provider_version(operator)
assert result == "1.2.0"


@patch("airflow.providers_manager.ProvidersManager")
def test_get_operator_provider_version_provider_not_found(mock_providers_manager):
"""Test that get_operator_provider_version returns None when provider is not found."""
# Mock ProvidersManager with no matching provider
mock_manager_instance = MagicMock()
mock_providers_manager.return_value = mock_manager_instance
mock_manager_instance.providers = {
"apache-airflow-providers-amazon": MagicMock(version="8.12.0"),
"apache-airflow-providers-google": MagicMock(version="10.5.0"),
}

operator = BashOperator(task_id="test_task", bash_command="echo test")
result = get_operator_provider_version(operator)
assert result is None


def test_get_operator_provider_version_for_custom_operator():
"""Test that get_operator_provider_version returns None for custom operators."""

# Create a custom operator that doesn't belong to any provider
class CustomOperator(BaseOperator):
def execute(self, context):
pass

operator = CustomOperator(task_id="test_task")
result = get_operator_provider_version(operator)
assert result is None


@patch("airflow.providers_manager.ProvidersManager")
def test_get_operator_provider_version_for_mapped_operator(mock_providers_manager):
"""Test that get_operator_provider_version works with mapped operators."""
# Mock ProvidersManager
mock_manager_instance = MagicMock()
mock_providers_manager.return_value = mock_manager_instance

# Mock providers data
mock_manager_instance.providers = {
"apache-airflow-providers-standard": MagicMock(version="1.2.0"),
"apache-airflow-providers-amazon": MagicMock(version="8.12.0"),
}

# Test with mapped BashOperator (standard provider)
mapped_operator = BashOperator.partial(task_id="test_task").expand(bash_command=["echo 1", "echo 2"])
result = get_operator_provider_version(mapped_operator)
assert result == "1.2.0"