diff --git a/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py b/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py index 85f8323225daa..d69cce098996c 100644 --- a/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py +++ b/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py @@ -61,6 +61,7 @@ class CloudComposerDAGRunSensor(BaseSensorOperator): Or [datetime(2024,3,22,0,0,0)] in this case sensor will check for states from specific time in the past till current time execution. Default value datetime.timedelta(days=1). + :param composer_dag_run_id: The run ID of the Composer DAG run whose state is checked. If both this and 'execution_range' are specified, 'execution_range' is ignored. :param gcp_conn_id: The connection ID to use when fetching connection info. :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token @@ -91,6 +92,7 @@ def __init__( composer_dag_id: str, allowed_states: Iterable[str] | None = None, execution_range: timedelta | list[datetime] | None = None, + composer_dag_run_id: str | None = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), @@ -104,11 +106,17 @@ def __init__( self.composer_dag_id = composer_dag_id self.allowed_states = list(allowed_states) if allowed_states else [TaskInstanceState.SUCCESS.value] self.execution_range = execution_range + self.composer_dag_run_id = composer_dag_run_id self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain self.deferrable = deferrable self.poll_interval = poll_interval + if self.composer_dag_run_id and self.execution_range: + self.log.warning( + "The composer_dag_run_id parameter and execution_range parameter do not work together. 
This run will ignore execution_range parameter and count only specified composer_dag_run_id parameter." + ) + def _get_logical_dates(self, context) -> tuple[datetime, datetime]: if isinstance(self.execution_range, timedelta): if self.execution_range < timedelta(0): @@ -132,6 +140,16 @@ def poke(self, context: Context) -> bool: self.log.info("Dag runs are empty. Sensor waits for dag runs...") return False + if self.composer_dag_run_id: + self.log.info( + "Sensor waits for allowed states %s for specified RunID: %s", + self.allowed_states, + self.composer_dag_run_id, + ) + composer_dag_run_id_status = self._check_composer_dag_run_id_states( + dag_runs=dag_runs, + ) + return composer_dag_run_id_status self.log.info("Sensor waits for allowed states: %s", self.allowed_states) allowed_states_status = self._check_dag_runs_states( dag_runs=dag_runs, @@ -193,6 +211,12 @@ def _get_composer_airflow_version(self) -> int: image_version = environment_config["config"]["software_config"]["image_version"] return int(image_version.split("airflow-")[1].split(".")[0]) + def _check_composer_dag_run_id_states(self, dag_runs: list[dict]) -> bool: + for dag_run in dag_runs: + if dag_run["run_id"] == self.composer_dag_run_id and dag_run["state"] in self.allowed_states: + return True + return False + def execute(self, context: Context) -> None: self._composer_airflow_version = self._get_composer_airflow_version() if self.deferrable: @@ -204,6 +228,7 @@ def execute(self, context: Context) -> None: region=self.region, environment_id=self.environment_id, composer_dag_id=self.composer_dag_id, + composer_dag_run_id=self.composer_dag_run_id, start_date=start_date, end_date=end_date, allowed_states=self.allowed_states, diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py index 8b10d2048c5f2..4a2fbd2ab0de8 100644 --- 
a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py @@ -166,6 +166,7 @@ def __init__( start_date: datetime, end_date: datetime, allowed_states: list[str], + composer_dag_run_id: str | None = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, poll_interval: int = 10, @@ -179,6 +180,7 @@ def __init__( self.start_date = start_date self.end_date = end_date self.allowed_states = allowed_states + self.composer_dag_run_id = composer_dag_run_id self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain self.poll_interval = poll_interval @@ -200,6 +202,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "start_date": self.start_date, "end_date": self.end_date, "allowed_states": self.allowed_states, + "composer_dag_run_id": self.composer_dag_run_id, "gcp_conn_id": self.gcp_conn_id, "impersonation_chain": self.impersonation_chain, "poll_interval": self.poll_interval, @@ -248,6 +251,12 @@ def _check_dag_runs_states( return False return True + def _check_composer_dag_run_id_states(self, dag_runs: list[dict]) -> bool: + for dag_run in dag_runs: + if dag_run["run_id"] == self.composer_dag_run_id and dag_run["state"] in self.allowed_states: + return True + return False + async def run(self): try: while True: @@ -260,14 +269,24 @@ async def run(self): await asyncio.sleep(self.poll_interval) continue - self.log.info("Sensor waits for allowed states: %s", self.allowed_states) - if self._check_dag_runs_states( - dag_runs=dag_runs, - start_date=self.start_date, - end_date=self.end_date, - ): - yield TriggerEvent({"status": "success"}) - return + if self.composer_dag_run_id: + self.log.info( + "Sensor waits for allowed states %s for specified RunID: %s", + self.allowed_states, + self.composer_dag_run_id, + ) + if self._check_composer_dag_run_id_states(dag_runs=dag_runs): + yield 
TriggerEvent({"status": "success"}) + return + else: + self.log.info("Sensor waits for allowed states: %s", self.allowed_states) + if self._check_dag_runs_states( + dag_runs=dag_runs, + start_date=self.start_date, + end_date=self.end_date, + ): + yield TriggerEvent({"status": "success"}) + return self.log.info("Sleeping for %s seconds.", self.poll_interval) await asyncio.sleep(self.poll_interval) except AirflowException as ex: diff --git a/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py b/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py index ccd034a500839..c2da2daa00b22 100644 --- a/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py +++ b/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py @@ -29,11 +29,12 @@ TEST_OPERATION_NAME = "test_operation_name" TEST_REGION = "region" TEST_ENVIRONMENT_ID = "test_env_id" +TEST_COMPOSER_DAG_RUN_ID = "scheduled__2024-05-22T11:10:00+00:00" TEST_JSON_RESULT = lambda state, date_key: json.dumps( [ { "dag_id": "test_dag_id", - "run_id": "scheduled__2024-05-22T11:10:00+00:00", + "run_id": TEST_COMPOSER_DAG_RUN_ID, "state": state, date_key: "2024-05-22T11:10:00+00:00", "start_date": "2024-05-22T11:20:01.531988+00:00", @@ -110,3 +111,45 @@ def test_dag_runs_empty(self, mock_hook, to_dict_mode, composer_airflow_version) task._composer_airflow_version = composer_airflow_version assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)}) + + @pytest.mark.parametrize("composer_airflow_version", [2, 3]) + @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict") + @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook") + def test_composer_dag_run_id_wait_ready(self, mock_hook, to_dict_mode, composer_airflow_version): + mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT( + "success", "execution_date" if 
composer_airflow_version < 3 else "logical_date" + ) + + task = CloudComposerDAGRunSensor( + task_id="task-id", + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + environment_id=TEST_ENVIRONMENT_ID, + composer_dag_id="test_dag_id", + composer_dag_run_id=TEST_COMPOSER_DAG_RUN_ID, + allowed_states=["success"], + ) + task._composer_airflow_version = composer_airflow_version + + assert task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)}) + + @pytest.mark.parametrize("composer_airflow_version", [2, 3]) + @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict") + @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook") + def test_composer_dag_run_id_wait_not_ready(self, mock_hook, to_dict_mode, composer_airflow_version): + mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT( + "running", "execution_date" if composer_airflow_version < 3 else "logical_date" + ) + + task = CloudComposerDAGRunSensor( + task_id="task-id", + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + environment_id=TEST_ENVIRONMENT_ID, + composer_dag_id="test_dag_id", + composer_dag_run_id=TEST_COMPOSER_DAG_RUN_ID, + allowed_states=["success"], + ) + task._composer_airflow_version = composer_airflow_version + + assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)}) diff --git a/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py b/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py index 8716805fa133d..cde313785d6ec 100644 --- a/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py +++ b/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py @@ -39,6 +39,7 @@ "error": "test_error", } TEST_COMPOSER_DAG_ID = "test_dag_id" +TEST_COMPOSER_DAG_RUN_ID = "scheduled__2024-05-22T11:10:00+00:00" TEST_START_DATE = datetime(2024, 3, 22, 11, 0, 0) TEST_END_DATE = datetime(2024, 3, 22, 
12, 0, 0) TEST_STATES = ["success"] @@ -81,6 +82,7 @@ def dag_run_trigger(mock_conn): region=TEST_LOCATION, environment_id=TEST_ENVIRONMENT_ID, composer_dag_id=TEST_COMPOSER_DAG_ID, + composer_dag_run_id=TEST_COMPOSER_DAG_RUN_ID, start_date=TEST_START_DATE, end_date=TEST_END_DATE, allowed_states=TEST_STATES, @@ -136,6 +138,7 @@ def test_serialize(self, dag_run_trigger): "region": TEST_LOCATION, "environment_id": TEST_ENVIRONMENT_ID, "composer_dag_id": TEST_COMPOSER_DAG_ID, + "composer_dag_run_id": TEST_COMPOSER_DAG_RUN_ID, "start_date": TEST_START_DATE, "end_date": TEST_END_DATE, "allowed_states": TEST_STATES,