Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
Or [datetime(2024,3,22,0,0,0)] in this case sensor will check for states from specific time in the
past till current time execution.
Default value datetime.timedelta(days=1).
:param composer_dag_run_id: The Run ID of executable task. The 'execution_range' param is ignored, if both specified.
:param gcp_conn_id: The connection ID to use when fetching connection info.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
Expand Down Expand Up @@ -91,6 +92,7 @@ def __init__(
composer_dag_id: str,
allowed_states: Iterable[str] | None = None,
execution_range: timedelta | list[datetime] | None = None,
composer_dag_run_id: str | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
Expand All @@ -104,11 +106,17 @@ def __init__(
self.composer_dag_id = composer_dag_id
self.allowed_states = list(allowed_states) if allowed_states else [TaskInstanceState.SUCCESS.value]
self.execution_range = execution_range
self.composer_dag_run_id = composer_dag_run_id
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.poll_interval = poll_interval

if self.composer_dag_run_id and self.execution_range:
self.log.warning(
"The composer_dag_run_id parameter and execution_range parameter do not work together. This run will ignore execution_range parameter and count only specified composer_dag_run_id parameter."
)

def _get_logical_dates(self, context) -> tuple[datetime, datetime]:
if isinstance(self.execution_range, timedelta):
if self.execution_range < timedelta(0):
Expand All @@ -132,6 +140,16 @@ def poke(self, context: Context) -> bool:
self.log.info("Dag runs are empty. Sensor waits for dag runs...")
return False

if self.composer_dag_run_id:
self.log.info(
"Sensor waits for allowed states %s for specified RunID: %s",
self.allowed_states,
self.composer_dag_run_id,
)
composer_dag_run_id_status = self._check_composer_dag_run_id_states(
dag_runs=dag_runs,
)
return composer_dag_run_id_status
self.log.info("Sensor waits for allowed states: %s", self.allowed_states)
allowed_states_status = self._check_dag_runs_states(
dag_runs=dag_runs,
Expand Down Expand Up @@ -193,6 +211,12 @@ def _get_composer_airflow_version(self) -> int:
image_version = environment_config["config"]["software_config"]["image_version"]
return int(image_version.split("airflow-")[1].split(".")[0])

def _check_composer_dag_run_id_states(self, dag_runs: list[dict]) -> bool:
for dag_run in dag_runs:
if dag_run["run_id"] == self.composer_dag_run_id and dag_run["state"] in self.allowed_states:
return True
return False

def execute(self, context: Context) -> None:
self._composer_airflow_version = self._get_composer_airflow_version()
if self.deferrable:
Expand All @@ -204,6 +228,7 @@ def execute(self, context: Context) -> None:
region=self.region,
environment_id=self.environment_id,
composer_dag_id=self.composer_dag_id,
composer_dag_run_id=self.composer_dag_run_id,
start_date=start_date,
end_date=end_date,
allowed_states=self.allowed_states,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def __init__(
start_date: datetime,
end_date: datetime,
allowed_states: list[str],
composer_dag_run_id: str | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
poll_interval: int = 10,
Expand All @@ -179,6 +180,7 @@ def __init__(
self.start_date = start_date
self.end_date = end_date
self.allowed_states = allowed_states
self.composer_dag_run_id = composer_dag_run_id
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.poll_interval = poll_interval
Expand All @@ -200,6 +202,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"start_date": self.start_date,
"end_date": self.end_date,
"allowed_states": self.allowed_states,
"composer_dag_run_id": self.composer_dag_run_id,
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
"poll_interval": self.poll_interval,
Expand Down Expand Up @@ -248,6 +251,12 @@ def _check_dag_runs_states(
return False
return True

def _check_composer_dag_run_id_states(self, dag_runs: list[dict]) -> bool:
for dag_run in dag_runs:
if dag_run["run_id"] == self.composer_dag_run_id and dag_run["state"] in self.allowed_states:
return True
return False

async def run(self):
try:
while True:
Expand All @@ -260,14 +269,24 @@ async def run(self):
await asyncio.sleep(self.poll_interval)
continue

self.log.info("Sensor waits for allowed states: %s", self.allowed_states)
if self._check_dag_runs_states(
dag_runs=dag_runs,
start_date=self.start_date,
end_date=self.end_date,
):
yield TriggerEvent({"status": "success"})
return
if self.composer_dag_run_id:
self.log.info(
"Sensor waits for allowed states %s for specified RunID: %s",
self.allowed_states,
self.composer_dag_run_id,
)
if self._check_composer_dag_run_id_states(dag_runs=dag_runs):
yield TriggerEvent({"status": "success"})
return
else:
self.log.info("Sensor waits for allowed states: %s", self.allowed_states)
if self._check_dag_runs_states(
dag_runs=dag_runs,
start_date=self.start_date,
end_date=self.end_date,
):
yield TriggerEvent({"status": "success"})
return
self.log.info("Sleeping for %s seconds.", self.poll_interval)
await asyncio.sleep(self.poll_interval)
except AirflowException as ex:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@
TEST_OPERATION_NAME = "test_operation_name"
TEST_REGION = "region"
TEST_ENVIRONMENT_ID = "test_env_id"
TEST_COMPOSER_DAG_RUN_ID = "scheduled__2024-05-22T11:10:00+00:00"
TEST_JSON_RESULT = lambda state, date_key: json.dumps(
[
{
"dag_id": "test_dag_id",
"run_id": "scheduled__2024-05-22T11:10:00+00:00",
"run_id": TEST_COMPOSER_DAG_RUN_ID,
"state": state,
date_key: "2024-05-22T11:10:00+00:00",
"start_date": "2024-05-22T11:20:01.531988+00:00",
Expand Down Expand Up @@ -110,3 +111,45 @@ def test_dag_runs_empty(self, mock_hook, to_dict_mode, composer_airflow_version)
task._composer_airflow_version = composer_airflow_version

assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)})

@pytest.mark.parametrize("composer_airflow_version", [2, 3])
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
def test_composer_dag_run_id_wait_ready(self, mock_hook, to_dict_mode, composer_airflow_version):
mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT(
"success", "execution_date" if composer_airflow_version < 3 else "logical_date"
)

task = CloudComposerDAGRunSensor(
task_id="task-id",
project_id=TEST_PROJECT_ID,
region=TEST_REGION,
environment_id=TEST_ENVIRONMENT_ID,
composer_dag_id="test_dag_id",
composer_dag_run_id=TEST_COMPOSER_DAG_RUN_ID,
allowed_states=["success"],
)
task._composer_airflow_version = composer_airflow_version

assert task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)})

@pytest.mark.parametrize("composer_airflow_version", [2, 3])
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
def test_composer_dag_run_id_wait_not_ready(self, mock_hook, to_dict_mode, composer_airflow_version):
mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT(
"running", "execution_date" if composer_airflow_version < 3 else "logical_date"
)

task = CloudComposerDAGRunSensor(
task_id="task-id",
project_id=TEST_PROJECT_ID,
region=TEST_REGION,
environment_id=TEST_ENVIRONMENT_ID,
composer_dag_id="test_dag_id",
composer_dag_run_id=TEST_COMPOSER_DAG_RUN_ID,
allowed_states=["success"],
)
task._composer_airflow_version = composer_airflow_version

assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)})
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"error": "test_error",
}
TEST_COMPOSER_DAG_ID = "test_dag_id"
TEST_COMPOSER_DAG_RUN_ID = "scheduled__2024-05-22T11:10:00+00:00"
TEST_START_DATE = datetime(2024, 3, 22, 11, 0, 0)
TEST_END_DATE = datetime(2024, 3, 22, 12, 0, 0)
TEST_STATES = ["success"]
Expand Down Expand Up @@ -81,6 +82,7 @@ def dag_run_trigger(mock_conn):
region=TEST_LOCATION,
environment_id=TEST_ENVIRONMENT_ID,
composer_dag_id=TEST_COMPOSER_DAG_ID,
composer_dag_run_id=TEST_COMPOSER_DAG_RUN_ID,
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
allowed_states=TEST_STATES,
Expand Down Expand Up @@ -136,6 +138,7 @@ def test_serialize(self, dag_run_trigger):
"region": TEST_LOCATION,
"environment_id": TEST_ENVIRONMENT_ID,
"composer_dag_id": TEST_COMPOSER_DAG_ID,
"composer_dag_run_id": TEST_COMPOSER_DAG_RUN_ID,
"start_date": TEST_START_DATE,
"end_date": TEST_END_DATE,
"allowed_states": TEST_STATES,
Expand Down