Re-configure ORM in spawned OpenLineage process in scheduler. (#39735)
Log exceptions that occur within ProcessPoolExecutor.

Signed-off-by: Jakub Dardzinski <kuba0221@gmail.com>
JDarDagran committed May 21, 2024
1 parent a81504e commit b7671ef
Showing 3 changed files with 68 additions and 41 deletions.
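For context on the second line of the commit message: ProcessPoolExecutor stores a task's exception on the returned Future and only re-raises it when .result() is called. The listener submits fire-and-forget callables and never inspects the futures, so any exception raised inside a worker vanishes silently. A minimal standalone sketch of that failure mode and of the logging pattern this commit adopts (the function names here are illustrative, not from the provider):

from concurrent.futures import ProcessPoolExecutor
import traceback


def emit_event():
    # The exception is stored on the Future; since nobody calls
    # .result(), it is never re-raised or printed anywhere.
    raise RuntimeError("transport error while emitting OpenLineage event")


def emit_event_logged():
    # Pattern adopted in adapter.py: catch everything inside the worker
    # and log the traceback before it can be swallowed.
    try:
        raise RuntimeError("transport error while emitting OpenLineage event")
    except BaseException:
        print("Failed to emit event: \n %s" % traceback.format_exc())


if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=1) as pool:
        pool.submit(emit_event)         # silently swallowed
        pool.submit(emit_event_logged)  # traceback reaches stdout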
95 changes: 57 additions & 38 deletions airflow/providers/openlineage/plugins/adapter.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import traceback
 import uuid
 from contextlib import ExitStack
 from typing import TYPE_CHECKING
@@ -299,48 +300,66 @@ def dag_started(
         nominal_start_time: str,
         nominal_end_time: str,
     ):
-        event = RunEvent(
-            eventType=RunState.START,
-            eventTime=dag_run.start_date.isoformat(),
-            job=self._build_job(job_name=dag_run.dag_id, job_type=_JOB_TYPE_DAG),
-            run=self._build_run(
-                run_id=self.build_dag_run_id(dag_run.dag_id, dag_run.run_id),
-                job_name=dag_run.dag_id,
-                nominal_start_time=nominal_start_time,
-                nominal_end_time=nominal_end_time,
-            ),
-            inputs=[],
-            outputs=[],
-            producer=_PRODUCER,
-        )
-        self.emit(event)
+        try:
+            event = RunEvent(
+                eventType=RunState.START,
+                eventTime=dag_run.start_date.isoformat(),
+                job=self._build_job(job_name=dag_run.dag_id, job_type=_JOB_TYPE_DAG),
+                run=self._build_run(
+                    run_id=self.build_dag_run_id(dag_run.dag_id, dag_run.run_id),
+                    job_name=dag_run.dag_id,
+                    nominal_start_time=nominal_start_time,
+                    nominal_end_time=nominal_end_time,
+                ),
+                inputs=[],
+                outputs=[],
+                producer=_PRODUCER,
+            )
+            self.emit(event)
+        except BaseException:
+            # Catch all exceptions to prevent ProcessPoolExecutor from silently swallowing them.
+            # This ensures that any unexpected exceptions are logged for debugging purposes.
+            # This part cannot be wrapped to deduplicate code, otherwise the method cannot be pickled in multiprocessing.
+            self.log.warning("Failed to emit DAG started event: \n %s", traceback.format_exc())

     def dag_success(self, dag_run: DagRun, msg: str):
-        event = RunEvent(
-            eventType=RunState.COMPLETE,
-            eventTime=dag_run.end_date.isoformat(),
-            job=self._build_job(job_name=dag_run.dag_id, job_type=_JOB_TYPE_DAG),
-            run=Run(runId=self.build_dag_run_id(dag_run.dag_id, dag_run.run_id)),
-            inputs=[],
-            outputs=[],
-            producer=_PRODUCER,
-        )
-        self.emit(event)
+        try:
+            event = RunEvent(
+                eventType=RunState.COMPLETE,
+                eventTime=dag_run.end_date.isoformat(),
+                job=self._build_job(job_name=dag_run.dag_id, job_type=_JOB_TYPE_DAG),
+                run=Run(runId=self.build_dag_run_id(dag_run.dag_id, dag_run.run_id)),
+                inputs=[],
+                outputs=[],
+                producer=_PRODUCER,
+            )
+            self.emit(event)
+        except BaseException:
+            # Catch all exceptions to prevent ProcessPoolExecutor from silently swallowing them.
+            # This ensures that any unexpected exceptions are logged for debugging purposes.
+            # This part cannot be wrapped to deduplicate code, otherwise the method cannot be pickled in multiprocessing.
+            self.log.warning("Failed to emit DAG success event: \n %s", traceback.format_exc())

     def dag_failed(self, dag_run: DagRun, msg: str):
-        event = RunEvent(
-            eventType=RunState.FAIL,
-            eventTime=dag_run.end_date.isoformat(),
-            job=self._build_job(job_name=dag_run.dag_id, job_type=_JOB_TYPE_DAG),
-            run=Run(
-                runId=self.build_dag_run_id(dag_run.dag_id, dag_run.run_id),
-                facets={"errorMessage": ErrorMessageRunFacet(message=msg, programmingLanguage="python")},
-            ),
-            inputs=[],
-            outputs=[],
-            producer=_PRODUCER,
-        )
-        self.emit(event)
+        try:
+            event = RunEvent(
+                eventType=RunState.FAIL,
+                eventTime=dag_run.end_date.isoformat(),
+                job=self._build_job(job_name=dag_run.dag_id, job_type=_JOB_TYPE_DAG),
+                run=Run(
+                    runId=self.build_dag_run_id(dag_run.dag_id, dag_run.run_id),
+                    facets={"errorMessage": ErrorMessageRunFacet(message=msg, programmingLanguage="python")},
+                ),
+                inputs=[],
+                outputs=[],
+                producer=_PRODUCER,
+            )
+            self.emit(event)
+        except BaseException:
+            # Catch all exceptions to prevent ProcessPoolExecutor from silently swallowing them.
+            # This ensures that any unexpected exceptions are logged for debugging purposes.
+            # This part cannot be wrapped to deduplicate code, otherwise the method cannot be pickled in multiprocessing.
+            self.log.warning("Failed to emit DAG failed event: \n %s", traceback.format_exc())

     @staticmethod
     def _build_run(
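The identical try/except in all three methods looks like an obvious candidate for a decorator, but the inline comment explains why that would break: the bound methods are pickled when submitted to the ProcessPoolExecutor, and a closure returned by a decorator cannot be pickled by reference. A hypothetical illustration of that constraint (log_exceptions and emit are made-up names for this sketch):

import pickle


def log_exceptions(func):
    # pickle serializes plain functions by qualified name; the name
    # "log_exceptions.<locals>.wrapper" is not importable, so anything
    # wrapped this way can no longer be sent to a worker process.
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except BaseException:
            print("logged")
    return wrapper


def emit():
    pass


pickle.dumps(emit)  # fine: a module-level function
try:
    pickle.dumps(log_exceptions(emit))
except Exception as exc:
    print(type(exc).__name__, exc)  # Can't pickle local object ...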
12 changes: 10 additions & 2 deletions airflow/providers/openlineage/plugins/listener.py
@@ -23,7 +23,7 @@
 
 from openlineage.client.serde import Serde
 
-from airflow import __version__ as airflow_version
+from airflow import __version__ as airflow_version, settings
 from airflow.listeners import hookimpl
 from airflow.providers.openlineage import conf
 from airflow.providers.openlineage.extractors import ExtractorManager
@@ -281,8 +281,16 @@ def on_failure():
 
     @property
     def executor(self):
+        def initializer():
+            # Re-configure the ORM engine as there are issues with multiple processes
+            # if process calls Airflow DB.
+            settings.configure_orm()
+
         if not self._executor:
-            self._executor = ProcessPoolExecutor(max_workers=conf.dag_state_change_process_pool_size())
+            self._executor = ProcessPoolExecutor(
+                max_workers=conf.dag_state_change_process_pool_size(),
+                initializer=initializer,
+            )
         return self._executor
 
     @hookimpl
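The initializer callable runs once in every worker process as it starts, which makes it the natural hook for rebuilding state that must not cross process boundaries; SQLAlchemy engines and their pooled connections are the classic case, and that is what settings.configure_orm() re-creates here. A hedged sketch of the general pattern, with a throwaway SQLite engine standing in for Airflow's ORM setup:

from concurrent.futures import ProcessPoolExecutor

from sqlalchemy import create_engine, text

engine = None  # per-process global, rebuilt by the initializer


def init_worker():
    # Sharing a parent's pooled connections across processes corrupts
    # them, so each worker builds a fresh engine of its own.
    global engine
    engine = create_engine("sqlite:///:memory:")


def query():
    with engine.connect() as conn:
        return conn.execute(text("SELECT 1")).scalar()


if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=2, initializer=init_worker) as pool:
        print(pool.submit(query).result())  # 1, from a worker-local engine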
2 changes: 1 addition & 1 deletion tests/providers/openlineage/plugins/test_listener.py
@@ -561,7 +561,7 @@ def test_listener_on_dag_run_state_changes_configure_process_pool_size(mock_exec
     try:
         with conf_vars({("openlineage", "dag_state_change_process_pool_size"): max_workers}):
             listener.on_dag_run_running(mock.MagicMock(), None)
-        mock_executor.assert_called_once_with(max_workers=expected)
+        mock_executor.assert_called_once_with(max_workers=expected, initializer=mock.ANY)
         mock_executor.return_value.submit.assert_called_once()
     finally:
         conf.dag_state_change_process_pool_size.cache_clear()
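The test cannot name the real initializer, since it is defined inside the executor property, so mock.ANY stands in as a wildcard that compares equal to any value in the call assertion. A tiny self-contained example of the same trick:

from unittest import mock

executor_factory = mock.MagicMock()
executor_factory(max_workers=4, initializer=lambda: None)

# mock.ANY compares equal to anything, so the assertion passes even
# though the locally defined initializer cannot be referenced here.
executor_factory.assert_called_once_with(max_workers=4, initializer=mock.ANY)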
