From 14e79f4e866064532040b9bbe533af17de0be577 Mon Sep 17 00:00:00 2001 From: Rahul Madan Date: Sun, 29 Jun 2025 22:15:06 +0530 Subject: [PATCH 1/2] added another attribute containing the provider package version of the operator being used. Signed-off-by: Rahul Madan --- .../providers/openlineage/utils/utils.py | 25 ++++++ .../unit/openlineage/utils/test_utils.py | 87 +++++++++++++++++++ 2 files changed, 112 insertions(+) diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py index afe41bcc31839..76ba08a5bff45 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py +++ b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py @@ -128,6 +128,30 @@ def get_operator_class(task: BaseOperator) -> type: return task.__class__ +def get_operator_provider_version(operator: BaseOperator | MappedOperator) -> str | None: + """Get the provider package version for the given operator.""" + try: + class_path = get_fully_qualified_class_name(operator) + + if not class_path.startswith('airflow.providers.'): + return None + + from airflow.providers_manager import ProvidersManager + + providers_manager = ProvidersManager() + + for package_name, provider_info in providers_manager.providers.items(): + if package_name.startswith('apache-airflow-providers-'): + provider_module_path = package_name.replace('apache-airflow-providers-', 'airflow.providers.').replace('-', '.') + if class_path.startswith(provider_module_path + '.'): + return provider_info.version + + return None + + except Exception: + return None + + def get_job_name(task: TaskInstance | RuntimeTaskInstance) -> str: return f"{task.dag_id}.{task.task_id}" @@ -511,6 +535,7 @@ class TaskInfo(InfoJsonEncodable): ), "inlets": lambda task: [AssetInfo(i) for i in task.inlets if isinstance(i, Asset)], "outlets": lambda task: [AssetInfo(o) for o in task.outlets if isinstance(o, Asset)], + "operator_provider_version": lambda task: get_operator_provider_version(task), } diff --git a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py index 6357913e31665..727a76a0509b5 100644 --- a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py +++ b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py @@ -54,6 +54,7 @@ get_fully_qualified_class_name, get_job_name, get_operator_class, + get_operator_provider_version, get_user_provided_run_facets, ) from airflow.providers.standard.operators.empty import EmptyOperator @@ -1600,6 +1601,7 @@ def __init__(self, *args, **kwargs): "multiple_outputs": False, "operator_class": "CustomOperator", "operator_class_path": get_fully_qualified_class_name(task_10), + "operator_provider_version": None, # Custom operator doesn't have provider version "outlets": "[{'uri': 'uri2', 'extra': {'b': 2}}, {'uri': 'uri3', 'extra': {'c': 3}}]", "owner": "airflow", "priority_weight": 1, @@ -1677,6 +1679,7 @@ def __init__(self, *args, **kwargs): "multiple_outputs": False, "operator_class": "CustomOperator", "operator_class_path": get_fully_qualified_class_name(task_10), + "operator_provider_version": None, # Custom operator doesn't have provider version "outlets": "[{'uri': 'uri2', 'extra': {'b': 2}}, {'uri': 'uri3', 'extra': {'c': 3}}]", "owner": "airflow", "priority_weight": 1, @@ -1698,3 +1701,87 @@ def test_task_info_complete(): task_0 = BashOperator(task_id="task_0", bash_command="exit 0;") result = TaskInfoComplete(task_0) assert "'bash_command': 'exit 0;'" in str(result) + + +@patch('airflow.providers.openlineage.utils.utils.get_fully_qualified_class_name') +def test_get_operator_provider_version_exception_handling(mock_class_name): + mock_class_name.side_effect = Exception("Test exception") + operator = MagicMock() + assert get_operator_provider_version(operator) is None + + +def test_get_operator_provider_version_for_core_operator(): + """Test that get_operator_provider_version returns None for core operators.""" + operator = BaseOperator(task_id="test_task") + result = get_operator_provider_version(operator) + assert result is None + + +@patch('airflow.providers_manager.ProvidersManager') +def test_get_operator_provider_version_for_provider_operator(mock_providers_manager): + """Test that get_operator_provider_version returns version for provider operators.""" + # Mock ProvidersManager + mock_manager_instance = MagicMock() + mock_providers_manager.return_value = mock_manager_instance + + # Mock providers data + mock_manager_instance.providers = { + 'apache-airflow-providers-standard': MagicMock(version='1.2.0'), + 'apache-airflow-providers-amazon': MagicMock(version='8.12.0'), + 'apache-airflow-providers-google': MagicMock(version='10.5.0'), + } + + # Test with BashOperator (standard provider) + operator = BashOperator(task_id="test_task", bash_command="echo test") + result = get_operator_provider_version(operator) + assert result == '1.2.0' + + +@patch('airflow.providers_manager.ProvidersManager') +def test_get_operator_provider_version_provider_not_found(mock_providers_manager): + """Test that get_operator_provider_version returns None when provider is not found.""" + # Mock ProvidersManager with no matching provider + mock_manager_instance = MagicMock() + mock_providers_manager.return_value = mock_manager_instance + mock_manager_instance.providers = { + 'apache-airflow-providers-amazon': MagicMock(version='8.12.0'), + 'apache-airflow-providers-google': MagicMock(version='10.5.0'), + } + + operator = BashOperator(task_id="test_task", bash_command="echo test") + result = get_operator_provider_version(operator) + assert result is None + + +def test_get_operator_provider_version_for_custom_operator(): + """Test that get_operator_provider_version returns None for custom operators.""" + # Create a custom operator that doesn't belong to any provider + class CustomOperator(BaseOperator): + def execute(self, context): + pass + + operator = CustomOperator(task_id="test_task") + result = get_operator_provider_version(operator) + assert result is None + + +@patch('airflow.providers_manager.ProvidersManager') +def test_get_operator_provider_version_for_mapped_operator(mock_providers_manager): + """Test that get_operator_provider_version works with mapped operators.""" + # Mock ProvidersManager + mock_manager_instance = MagicMock() + mock_providers_manager.return_value = mock_manager_instance + + # Mock providers data + mock_manager_instance.providers = { + 'apache-airflow-providers-standard': MagicMock(version='1.2.0'), + 'apache-airflow-providers-amazon': MagicMock(version='8.12.0'), + } + + # Test with mapped BashOperator (standard provider) + mapped_operator = BashOperator.partial(task_id="test_task").expand(bash_command=["echo 1", "echo 2"]) + result = get_operator_provider_version(mapped_operator) + assert result == '1.2.0' + + + From ceb903df0c3c3eef2af400afb806a5bc34cde6b3 Mon Sep 17 00:00:00 2001 From: Rahul Madan Date: Mon, 30 Jun 2025 00:18:52 +0530 Subject: [PATCH 2/2] precommit run Signed-off-by: Rahul Madan --- .../providers/openlineage/utils/utils.py | 10 +++--- .../unit/openlineage/utils/test_utils.py | 32 +++++++++---------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py index 76ba08a5bff45..9ca1ba427ee60 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py +++ b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py @@ -133,7 +133,7 @@ def get_operator_provider_version(operator: BaseOperator | MappedOperator) -> st try: class_path = get_fully_qualified_class_name(operator) - if not class_path.startswith('airflow.providers.'): + if not class_path.startswith("airflow.providers."): return None from airflow.providers_manager import ProvidersManager @@ -141,9 +141,11 @@ def get_operator_provider_version(operator: BaseOperator | MappedOperator) -> st providers_manager = ProvidersManager() for package_name, provider_info in providers_manager.providers.items(): - if package_name.startswith('apache-airflow-providers-'): - provider_module_path = package_name.replace('apache-airflow-providers-', 'airflow.providers.').replace('-', '.') - if class_path.startswith(provider_module_path + '.'): + if package_name.startswith("apache-airflow-providers-"): + provider_module_path = package_name.replace( + "apache-airflow-providers-", "airflow.providers." + ).replace("-", ".") + if class_path.startswith(provider_module_path + "."): return provider_info.version return None diff --git a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py index 727a76a0509b5..d3aea277e7904 100644 --- a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py +++ b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py @@ -1703,7 +1703,7 @@ def test_task_info_complete(): assert "'bash_command': 'exit 0;'" in str(result) -@patch('airflow.providers.openlineage.utils.utils.get_fully_qualified_class_name') +@patch("airflow.providers.openlineage.utils.utils.get_fully_qualified_class_name") def test_get_operator_provider_version_exception_handling(mock_class_name): mock_class_name.side_effect = Exception("Test exception") operator = MagicMock() @@ -1717,7 +1717,7 @@ def test_get_operator_provider_version_for_core_operator(): assert result is None -@patch('airflow.providers_manager.ProvidersManager') +@patch("airflow.providers_manager.ProvidersManager") def test_get_operator_provider_version_for_provider_operator(mock_providers_manager): """Test that get_operator_provider_version returns version for provider operators.""" # Mock ProvidersManager @@ -1726,27 +1726,27 @@ def test_get_operator_provider_version_for_provider_operator(mock_providers_mana # Mock providers data mock_manager_instance.providers = { - 'apache-airflow-providers-standard': MagicMock(version='1.2.0'), - 'apache-airflow-providers-amazon': MagicMock(version='8.12.0'), - 'apache-airflow-providers-google': MagicMock(version='10.5.0'), + "apache-airflow-providers-standard": MagicMock(version="1.2.0"), + "apache-airflow-providers-amazon": MagicMock(version="8.12.0"), + "apache-airflow-providers-google": MagicMock(version="10.5.0"), } # Test with BashOperator (standard provider) operator = BashOperator(task_id="test_task", bash_command="echo test") result = get_operator_provider_version(operator) - assert result == '1.2.0' + assert result == "1.2.0" -@patch('airflow.providers_manager.ProvidersManager') +@patch("airflow.providers_manager.ProvidersManager") def test_get_operator_provider_version_provider_not_found(mock_providers_manager): """Test that get_operator_provider_version returns None when provider is not found.""" # Mock ProvidersManager with no matching provider mock_manager_instance = MagicMock() mock_providers_manager.return_value = mock_manager_instance mock_manager_instance.providers = { - 'apache-airflow-providers-amazon': MagicMock(version='8.12.0'), - 'apache-airflow-providers-google': MagicMock(version='10.5.0'), - } + "apache-airflow-providers-amazon": MagicMock(version="8.12.0"), + "apache-airflow-providers-google": MagicMock(version="10.5.0"), + } operator = BashOperator(task_id="test_task", bash_command="echo test") result = get_operator_provider_version(operator) @@ -1755,6 +1755,7 @@ def test_get_operator_provider_version_provider_not_found(mock_providers_manager def test_get_operator_provider_version_for_custom_operator(): """Test that get_operator_provider_version returns None for custom operators.""" + # Create a custom operator that doesn't belong to any provider class CustomOperator(BaseOperator): def execute(self, context): @@ -1765,7 +1766,7 @@ def execute(self, context): assert result is None -@patch('airflow.providers_manager.ProvidersManager') +@patch("airflow.providers_manager.ProvidersManager") def test_get_operator_provider_version_for_mapped_operator(mock_providers_manager): """Test that get_operator_provider_version works with mapped operators.""" # Mock ProvidersManager @@ -1774,14 +1775,11 @@ def test_get_operator_provider_version_for_mapped_operator(mock_providers_manage # Mock providers data mock_manager_instance.providers = { - 'apache-airflow-providers-standard': MagicMock(version='1.2.0'), - 'apache-airflow-providers-amazon': MagicMock(version='8.12.0'), + "apache-airflow-providers-standard": MagicMock(version="1.2.0"), + "apache-airflow-providers-amazon": MagicMock(version="8.12.0"), } # Test with mapped BashOperator (standard provider) mapped_operator = BashOperator.partial(task_id="test_task").expand(bash_command=["echo 1", "echo 2"]) result = get_operator_provider_version(mapped_operator) - assert result == '1.2.0' - - - + assert result == "1.2.0"