diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py index afe41bcc31839..9ca1ba427ee60 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py +++ b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py @@ -128,6 +128,32 @@ def get_operator_class(task: BaseOperator) -> type: return task.__class__ +def get_operator_provider_version(operator: BaseOperator | MappedOperator) -> str | None: + """Get the provider package version for the given operator.""" + try: + class_path = get_fully_qualified_class_name(operator) + + if not class_path.startswith("airflow.providers."): + return None + + from airflow.providers_manager import ProvidersManager + + providers_manager = ProvidersManager() + + for package_name, provider_info in providers_manager.providers.items(): + if package_name.startswith("apache-airflow-providers-"): + provider_module_path = package_name.replace( + "apache-airflow-providers-", "airflow.providers." + ).replace("-", ".") + if class_path.startswith(provider_module_path + "."): + return provider_info.version + + return None + + except Exception: + return None + + def get_job_name(task: TaskInstance | RuntimeTaskInstance) -> str: return f"{task.dag_id}.{task.task_id}" @@ -511,6 +537,7 @@ class TaskInfo(InfoJsonEncodable): ), "inlets": lambda task: [AssetInfo(i) for i in task.inlets if isinstance(i, Asset)], "outlets": lambda task: [AssetInfo(o) for o in task.outlets if isinstance(o, Asset)], + "operator_provider_version": lambda task: get_operator_provider_version(task), } diff --git a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py index 6357913e31665..d3aea277e7904 100644 --- a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py +++ b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py @@ -54,6 +54,7 @@ get_fully_qualified_class_name, get_job_name, get_operator_class, + get_operator_provider_version, get_user_provided_run_facets, ) from airflow.providers.standard.operators.empty import EmptyOperator @@ -1600,6 +1601,7 @@ def __init__(self, *args, **kwargs): "multiple_outputs": False, "operator_class": "CustomOperator", "operator_class_path": get_fully_qualified_class_name(task_10), + "operator_provider_version": None, # Custom operator doesn't have provider version "outlets": "[{'uri': 'uri2', 'extra': {'b': 2}}, {'uri': 'uri3', 'extra': {'c': 3}}]", "owner": "airflow", "priority_weight": 1, @@ -1677,6 +1679,7 @@ def __init__(self, *args, **kwargs): "multiple_outputs": False, "operator_class": "CustomOperator", "operator_class_path": get_fully_qualified_class_name(task_10), + "operator_provider_version": None, # Custom operator doesn't have provider version "outlets": "[{'uri': 'uri2', 'extra': {'b': 2}}, {'uri': 'uri3', 'extra': {'c': 3}}]", "owner": "airflow", "priority_weight": 1, @@ -1698,3 +1701,85 @@ def test_task_info_complete(): task_0 = BashOperator(task_id="task_0", bash_command="exit 0;") result = TaskInfoComplete(task_0) assert "'bash_command': 'exit 0;'" in str(result) + + +@patch("airflow.providers.openlineage.utils.utils.get_fully_qualified_class_name") +def test_get_operator_provider_version_exception_handling(mock_class_name): + mock_class_name.side_effect = Exception("Test exception") + operator = MagicMock() + assert get_operator_provider_version(operator) is None + + +def test_get_operator_provider_version_for_core_operator(): + """Test that get_operator_provider_version returns None for core operators.""" + operator = BaseOperator(task_id="test_task") + result = get_operator_provider_version(operator) + assert result is None + + +@patch("airflow.providers_manager.ProvidersManager") +def test_get_operator_provider_version_for_provider_operator(mock_providers_manager): + """Test that get_operator_provider_version returns version for provider operators.""" + # Mock ProvidersManager + mock_manager_instance = MagicMock() + mock_providers_manager.return_value = mock_manager_instance + + # Mock providers data + mock_manager_instance.providers = { + "apache-airflow-providers-standard": MagicMock(version="1.2.0"), + "apache-airflow-providers-amazon": MagicMock(version="8.12.0"), + "apache-airflow-providers-google": MagicMock(version="10.5.0"), + } + + # Test with BashOperator (standard provider) + operator = BashOperator(task_id="test_task", bash_command="echo test") + result = get_operator_provider_version(operator) + assert result == "1.2.0" + + +@patch("airflow.providers_manager.ProvidersManager") +def test_get_operator_provider_version_provider_not_found(mock_providers_manager): + """Test that get_operator_provider_version returns None when provider is not found.""" + # Mock ProvidersManager with no matching provider + mock_manager_instance = MagicMock() + mock_providers_manager.return_value = mock_manager_instance + mock_manager_instance.providers = { + "apache-airflow-providers-amazon": MagicMock(version="8.12.0"), + "apache-airflow-providers-google": MagicMock(version="10.5.0"), + } + + operator = BashOperator(task_id="test_task", bash_command="echo test") + result = get_operator_provider_version(operator) + assert result is None + + +def test_get_operator_provider_version_for_custom_operator(): + """Test that get_operator_provider_version returns None for custom operators.""" + + # Create a custom operator that doesn't belong to any provider + class CustomOperator(BaseOperator): + def execute(self, context): + pass + + operator = CustomOperator(task_id="test_task") + result = get_operator_provider_version(operator) + assert result is None + + +@patch("airflow.providers_manager.ProvidersManager") +def test_get_operator_provider_version_for_mapped_operator(mock_providers_manager): + """Test that get_operator_provider_version works with mapped operators.""" + # Mock ProvidersManager + mock_manager_instance = MagicMock() + mock_providers_manager.return_value = mock_manager_instance + + # Mock providers data + mock_manager_instance.providers = { + "apache-airflow-providers-standard": MagicMock(version="1.2.0"), + "apache-airflow-providers-amazon": MagicMock(version="8.12.0"), + } + + # Test with mapped BashOperator (standard provider) + mapped_operator = BashOperator.partial(task_id="test_task").expand(bash_command=["echo 1", "echo 2"]) + result = get_operator_provider_version(mapped_operator) + assert result == "1.2.0"