diff --git a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/utils/openlineage.py b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/utils/openlineage.py index 4c353256a6ec0..a5bbab86f412c 100644 --- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/utils/openlineage.py +++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/utils/openlineage.py @@ -163,7 +163,7 @@ async def get_artifacts_for_steps(steps, artifacts): root_parent_job_name=task_instance.dag_id, root_parent_job_namespace=namespace(), ) - client = get_openlineage_listener().adapter.get_or_create_openlineage_client() + adapter = get_openlineage_listener().adapter # process each step in loop, sending generated events in the same order as steps for counter, artifacts in enumerate(step_artifacts, 1): @@ -193,7 +193,7 @@ async def get_artifacts_for_steps(steps, artifacts): log.debug("Found %s OpenLineage events for artifact no. %s.", len(events), counter) for event in events: - client.emit(event=event) + adapter.emit(event=event) log.debug("Emitted all OpenLineage events for artifact no. %s.", counter) log.info("OpenLineage has successfully finished processing information about DBT job run.") diff --git a/providers/dbt/cloud/tests/unit/dbt/cloud/utils/test_openlineage.py b/providers/dbt/cloud/tests/unit/dbt/cloud/utils/test_openlineage.py index b76855f44b118..4ac961ecedaab 100644 --- a/providers/dbt/cloud/tests/unit/dbt/cloud/utils/test_openlineage.py +++ b/providers/dbt/cloud/tests/unit/dbt/cloud/utils/test_openlineage.py @@ -177,17 +177,15 @@ def test_generate_events( mock_task_instance.dag_id = DAG_ID mock_task_instance.dag_run.clear_number = 0 - mock_client = MagicMock() + mock_adapter = MagicMock() - mock_client.emit.side_effect = emit_event - mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value = ( - mock_client - ) + mock_adapter.emit.side_effect = emit_event + mock_get_openlineage_listener.return_value.adapter = mock_adapter mock_build_task_instance_run_id.return_value = TASK_UUID mock_build_dag_run_id.return_value = DAG_UUID generate_openlineage_events_from_dbt_cloud_run(mock_operator, task_instance=mock_task_instance) - assert mock_client.emit.call_count == 4 + assert mock_adapter.emit.call_count == 4 def test_do_not_raise_error_if_runid_not_set_on_operator(self): operator = DbtCloudRunJobOperator(task_id="dbt-job-runid-taskid", job_id=1500) diff --git a/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py b/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py index 87afcb11752c4..bae4fa5bb0ab1 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py +++ b/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py @@ -280,6 +280,7 @@ def emit_openlineage_events_for_snowflake_queries( from airflow.providers.common.compat.openlineage.facet import ( ErrorMessageRunFacet, ExternalQueryRunFacet, + RunFacet, SQLJobFacet, ) from airflow.providers.openlineage.conf import namespace @@ -303,7 +304,6 @@ def emit_openlineage_events_for_snowflake_queries( # If no query metadata is provided, we use task_instance's state when checking for success default_state = task_instance.state.value if hasattr(task_instance, "state") else "" - log.debug("Generating OpenLineage facets") common_run_facets = {"parent": _get_parent_run_facet(task_instance)} common_job_facets: dict[str, JobFacet] = { "jobType": job_type_job.JobTypeJobFacet( @@ -325,12 +325,11 @@ def emit_openlineage_events_for_snowflake_queries( query_metadata if query_metadata else "not found", ) - # TODO(potiuk): likely typing here needs to be fixed - query_specific_run_facets = { # type : ignore[assignment] + query_specific_run_facets: dict[str, RunFacet] = { "externalQuery": ExternalQueryRunFacet(externalQueryId=query_id, source=query_source_namespace) } if query_metadata.get("ERROR_MESSAGE"): - query_specific_run_facets["error"] = ErrorMessageRunFacet( # type: ignore[assignment] + query_specific_run_facets["error"] = ErrorMessageRunFacet( message=f"{query_metadata.get('ERROR_CODE')} : {query_metadata['ERROR_MESSAGE']}", programmingLanguage="SQL", ) @@ -353,9 +352,9 @@ def emit_openlineage_events_for_snowflake_queries( events.extend(event_batch) log.debug("Generated %s OpenLineage events; emitting now.", len(events)) - client = get_openlineage_listener().adapter.get_or_create_openlineage_client() + adapter = get_openlineage_listener().adapter for event in events: - client.emit(event) + adapter.emit(event) log.info("OpenLineage has successfully finished processing information about Snowflake queries.") return diff --git a/providers/snowflake/tests/unit/snowflake/utils/test_openlineage.py b/providers/snowflake/tests/unit/snowflake/utils/test_openlineage.py index c0a5524073ee8..2127948b3b969 100644 --- a/providers/snowflake/tests/unit/snowflake/utils/test_openlineage.py +++ b/providers/snowflake/tests/unit/snowflake/utils/test_openlineage.py @@ -351,10 +351,10 @@ def test_emit_openlineage_events_for_snowflake_queries_with_hook(mock_now, mock_ additional_run_facets = {"custom_run": "value_run"} additional_job_facets = {"custom_job": "value_job"} - fake_client = mock.MagicMock() - fake_client.emit = mock.MagicMock() + fake_adapter = mock.MagicMock() + fake_adapter.emit = mock.MagicMock() fake_listener = mock.MagicMock() - fake_listener.adapter.get_or_create_openlineage_client.return_value = fake_client + fake_listener.adapter = fake_adapter with ( mock.patch( @@ -376,7 +376,7 @@ def test_emit_openlineage_events_for_snowflake_queries_with_hook(mock_now, mock_ ) assert query_ids == original_query_ids # Verify that the input query_ids list is unchanged. - assert fake_client.emit.call_count == 6 # Expect two events per query. + assert fake_adapter.emit.call_count == 6 # Expect two events per query. expected_common_job_facets = { "jobType": job_type_job.JobTypeJobFacet( @@ -539,7 +539,7 @@ def test_emit_openlineage_events_for_snowflake_queries_with_hook(mock_now, mock_ ), ] - assert fake_client.emit.call_args_list == expected_calls + assert fake_adapter.emit.call_args_list == expected_calls @mock.patch("importlib.metadata.version", return_value="2.3.0") @@ -573,10 +573,10 @@ def test_emit_openlineage_events_for_snowflake_queries_without_hook( additional_run_facets = {"custom_run": "value_run"} additional_job_facets = {"custom_job": "value_job"} - fake_client = mock.MagicMock() - fake_client.emit = mock.MagicMock() + fake_adapter = mock.MagicMock() + fake_adapter.emit = mock.MagicMock() fake_listener = mock.MagicMock() - fake_listener.adapter.get_or_create_openlineage_client.return_value = fake_client + fake_listener.adapter = fake_adapter with mock.patch( "airflow.providers.openlineage.plugins.listener.get_openlineage_listener", @@ -592,7 +592,7 @@ def test_emit_openlineage_events_for_snowflake_queries_without_hook( ) assert query_ids == original_query_ids # Verify that the input query_ids list is unchanged. - assert fake_client.emit.call_count == 2 # Expect two events per query. + assert fake_adapter.emit.call_count == 2 # Expect two events per query. expected_common_job_facets = { "jobType": job_type_job.JobTypeJobFacet( @@ -657,7 +657,7 @@ def test_emit_openlineage_events_for_snowflake_queries_without_hook( ), ] - assert fake_client.emit.call_args_list == expected_calls + assert fake_adapter.emit.call_args_list == expected_calls @mock.patch("importlib.metadata.version", return_value="2.3.0") @@ -665,10 +665,10 @@ def test_emit_openlineage_events_for_snowflake_queries_without_query_ids(mock_ve query_ids = [] original_query_ids = copy.deepcopy(query_ids) - fake_client = mock.MagicMock() - fake_client.emit = mock.MagicMock() + fake_adapter = mock.MagicMock() + fake_adapter.emit = mock.MagicMock() fake_listener = mock.MagicMock() - fake_listener.adapter.get_or_create_openlineage_client.return_value = fake_client + fake_listener.adapter = fake_adapter with mock.patch( "airflow.providers.openlineage.plugins.listener.get_openlineage_listener", @@ -681,7 +681,7 @@ def test_emit_openlineage_events_for_snowflake_queries_without_query_ids(mock_ve ) assert query_ids == original_query_ids # Verify that the input query_ids list is unchanged. - fake_client.emit.assert_not_called() # No events should be emitted + fake_adapter.emit.assert_not_called() # No events should be emitted # emit_openlineage_events_for_snowflake_queries requires OL provider 2.3.0 @@ -690,10 +690,10 @@ def test_emit_openlineage_events_with_old_openlineage_provider(mock_version): query_ids = ["q1", "q2"] original_query_ids = copy.deepcopy(query_ids) - fake_client = mock.MagicMock() - fake_client.emit = mock.MagicMock() + fake_adapter = mock.MagicMock() + fake_adapter.emit = mock.MagicMock() fake_listener = mock.MagicMock() - fake_listener.adapter.get_or_create_openlineage_client.return_value = fake_client + fake_listener.adapter = fake_adapter with mock.patch( "airflow.providers.openlineage.plugins.listener.get_openlineage_listener", @@ -711,4 +711,4 @@ def test_emit_openlineage_events_with_old_openlineage_provider(mock_version): task_instance=None, ) assert query_ids == original_query_ids # Verify that the input query_ids list is unchanged. - fake_client.emit.assert_not_called() # No events should be emitted + fake_adapter.emit.assert_not_called() # No events should be emitted