diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py index a6c42d8a83bf7..3c03aa3f5010d 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py @@ -1056,6 +1056,7 @@ def __init__( in_cluster=self.in_cluster, namespace=self.namespace, name=self.pod_name, + trigger_kwargs={"eks_cluster_name": cluster_name}, **kwargs, ) # There is no need to manage the kube_config file, as it will be generated automatically. @@ -1072,3 +1073,15 @@ def execute(self, context: Context): eks_cluster_name=self.cluster_name, pod_namespace=self.namespace ) as self.config_file: return super().execute(context) + + def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any: + eks_hook = EksHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region, + ) + eks_cluster_name = event["eks_cluster_name"] + pod_namespace = event["namespace"] + with eks_hook.generate_config_file( + eks_cluster_name=eks_cluster_name, pod_namespace=pod_namespace + ) as self.config_file: + return super().trigger_reentry(context, event) diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_eks.py b/providers/amazon/tests/unit/amazon/aws/operators/test_eks.py index 826ed9b99718b..66cda6292ccef 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_eks.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_eks.py @@ -823,3 +823,27 @@ def test_template_fields(self): ) validate_template_fields(op) + + @mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.trigger_reentry") + @mock.patch("airflow.providers.amazon.aws.hooks.eks.EksHook.generate_config_file") + def test_trigger_reentry(self, mock_generate_config_file, mock_k8s_pod_operator_trigger_reentry): + ti_context = mock.MagicMock(name="ti_context") + event = {"eks_cluster_name": "eks_cluster_name", "namespace": "namespace"} + + op = EksPodOperator( + task_id="run_pod", + pod_name="run_pod", + cluster_name=CLUSTER_NAME, + image="amazon/aws-cli:latest", + cmds=["sh", "-c", "ls"], + labels={"demo": "hello_world"}, + get_logs=True, + # Delete the pod when it reaches its final state, or the execution is interrupted. + on_finish_action="delete_pod", + ) + op.trigger_reentry(ti_context, event) + mock_k8s_pod_operator_trigger_reentry.assert_called_once_with(ti_context, event) + mock_generate_config_file.assert_called_once_with( + eks_cluster_name="eks_cluster_name", pod_namespace="namespace" + ) + assert mock_generate_config_file.return_value.__enter__.return_value == op.config_file diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py index 8706e444c4bab..a8ea2fcd4f10e 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py @@ -233,6 +233,7 @@ class KubernetesPodOperator(BaseOperator): :param logging_interval: max time in seconds that task should be in deferred state before resuming to fetch the latest logs. If ``None``, then the task will remain in deferred state until pod is done, and no logs will be visible until that time. + :param trigger_kwargs: additional keyword parameters passed to the trigger """ # !!! Changes in KubernetesPodOperator's arguments should be also reflected in !!! @@ -266,6 +267,7 @@ class KubernetesPodOperator(BaseOperator): "node_selector", "kubernetes_conn_id", "base_container_name", + "trigger_kwargs", ) template_fields_renderers = {"env_vars": "py"} @@ -339,6 +341,7 @@ def __init__( ) = None, progress_callback: Callable[[str], None] | None = None, logging_interval: int | None = None, + trigger_kwargs: dict | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -428,6 +431,7 @@ def __init__( self.termination_message_policy = termination_message_policy self.active_deadline_seconds = active_deadline_seconds self.logging_interval = logging_interval + self.trigger_kwargs = trigger_kwargs self._config_dict: dict | None = None # TODO: remove it when removing convert_config_file_to_dict self._progress_callback = progress_callback @@ -812,6 +816,7 @@ def invoke_defer_method(self, last_log_time: DateTime | None = None) -> None: on_finish_action=self.on_finish_action.value, last_log_time=last_log_time, logging_interval=self.logging_interval, + trigger_kwargs=self.trigger_kwargs, ), method_name="trigger_reentry", ) diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py index 6ec0c99932b22..719e1f6f34441 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py @@ -75,6 +75,7 @@ class KubernetesPodTrigger(BaseTrigger): :param logging_interval: number of seconds to wait before kicking it back to the operator to print latest logs. If ``None`` will wait until container done. :param last_log_time: where to resume logs from + :param trigger_kwargs: additional keyword parameters to send in the event """ def __init__( @@ -94,6 +95,7 @@ def __init__( on_finish_action: str = "delete_pod", last_log_time: DateTime | None = None, logging_interval: int | None = None, + trigger_kwargs: dict | None = None, ): super().__init__() self.pod_name = pod_name @@ -111,6 +113,7 @@ def __init__( self.last_log_time = last_log_time self.logging_interval = logging_interval self.on_finish_action = OnFinishAction(on_finish_action) + self.trigger_kwargs = trigger_kwargs or {} self._since_time = None @@ -134,6 +137,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "on_finish_action": self.on_finish_action.value, "last_log_time": self.last_log_time, "logging_interval": self.logging_interval, + "trigger_kwargs": self.trigger_kwargs, }, ) @@ -149,6 +153,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] "namespace": self.pod_namespace, "name": self.pod_name, "message": "All containers inside pod have started successfully.", + **self.trigger_kwargs, } ) elif state == ContainerState.FAILED: @@ -158,6 +163,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] "namespace": self.pod_namespace, "name": self.pod_name, "message": "pod failed", + **self.trigger_kwargs, } ) else: @@ -172,6 +178,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] "namespace": self.pod_namespace, "status": "timeout", "message": message, + **self.trigger_kwargs, } ) return @@ -183,6 +190,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] "status": "error", "message": str(e), "stack_trace": traceback.format_exc(), + **self.trigger_kwargs, } ) return @@ -234,6 +242,7 @@ async def _wait_for_container_completion(self) -> TriggerEvent: "namespace": self.pod_namespace, "name": self.pod_name, "last_log_time": self.last_log_time, + **self.trigger_kwargs, } ) if container_state == ContainerState.FAILED: @@ -244,6 +253,7 @@ async def _wait_for_container_completion(self) -> TriggerEvent: "name": self.pod_name, "message": "Container state failed", "last_log_time": self.last_log_time, + **self.trigger_kwargs, } ) self.log.debug("Container is not completed and still working.") @@ -254,6 +264,7 @@ async def _wait_for_container_completion(self) -> TriggerEvent: "last_log_time": self.last_log_time, "namespace": self.pod_namespace, "name": self.pod_name, + **self.trigger_kwargs, } ) self.log.debug("Sleeping for %s seconds.", self.poll_interval) diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py index 2afb34aa3dfff..66fae2524d618 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py @@ -111,6 +111,7 @@ def test_serialize(self, trigger): "on_finish_action": ON_FINISH_ACTION, "last_log_time": None, "logging_interval": None, + "trigger_kwargs": {}, } @pytest.mark.asyncio