Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,7 @@ def __init__(
in_cluster=self.in_cluster,
namespace=self.namespace,
name=self.pod_name,
trigger_kwargs={"eks_cluster_name": cluster_name},
**kwargs,
)
# There is no need to manage the kube_config file, as it will be generated automatically.
Expand All @@ -1072,3 +1073,15 @@ def execute(self, context: Context):
eks_cluster_name=self.cluster_name, pod_namespace=self.namespace
) as self.config_file:
return super().execute(context)

def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
eks_hook = EksHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region,
)
eks_cluster_name = event["eks_cluster_name"]
pod_namespace = event["namespace"]
with eks_hook.generate_config_file(
eks_cluster_name=eks_cluster_name, pod_namespace=pod_namespace
) as self.config_file:
return super().trigger_reentry(context, event)
24 changes: 24 additions & 0 deletions providers/amazon/tests/unit/amazon/aws/operators/test_eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,3 +823,27 @@ def test_template_fields(self):
)

validate_template_fields(op)

@mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.trigger_reentry")
@mock.patch("airflow.providers.amazon.aws.hooks.eks.EksHook.generate_config_file")
def test_trigger_reentry(self, mock_generate_config_file, mock_k8s_pod_operator_trigger_reentry):
ti_context = mock.MagicMock(name="ti_context")
event = {"eks_cluster_name": "eks_cluster_name", "namespace": "namespace"}

op = EksPodOperator(
task_id="run_pod",
pod_name="run_pod",
cluster_name=CLUSTER_NAME,
image="amazon/aws-cli:latest",
cmds=["sh", "-c", "ls"],
labels={"demo": "hello_world"},
get_logs=True,
# Delete the pod when it reaches its final state, or the execution is interrupted.
on_finish_action="delete_pod",
)
op.trigger_reentry(ti_context, event)
mock_k8s_pod_operator_trigger_reentry.assert_called_once_with(ti_context, event)
mock_generate_config_file.assert_called_once_with(
eks_cluster_name="eks_cluster_name", pod_namespace="namespace"
)
assert mock_generate_config_file.return_value.__enter__.return_value == op.config_file
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ class KubernetesPodOperator(BaseOperator):
:param logging_interval: max time in seconds that task should be in deferred state before
resuming to fetch the latest logs. If ``None``, then the task will remain in deferred state until pod
is done, and no logs will be visible until that time.
:param trigger_kwargs: additional keyword parameters passed to the trigger
"""

# !!! Changes in KubernetesPodOperator's arguments should be also reflected in !!!
Expand Down Expand Up @@ -266,6 +267,7 @@ class KubernetesPodOperator(BaseOperator):
"node_selector",
"kubernetes_conn_id",
"base_container_name",
"trigger_kwargs",
)
template_fields_renderers = {"env_vars": "py"}

Expand Down Expand Up @@ -339,6 +341,7 @@ def __init__(
) = None,
progress_callback: Callable[[str], None] | None = None,
logging_interval: int | None = None,
trigger_kwargs: dict | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -428,6 +431,7 @@ def __init__(
self.termination_message_policy = termination_message_policy
self.active_deadline_seconds = active_deadline_seconds
self.logging_interval = logging_interval
self.trigger_kwargs = trigger_kwargs

self._config_dict: dict | None = None # TODO: remove it when removing convert_config_file_to_dict
self._progress_callback = progress_callback
Expand Down Expand Up @@ -812,6 +816,7 @@ def invoke_defer_method(self, last_log_time: DateTime | None = None) -> None:
on_finish_action=self.on_finish_action.value,
last_log_time=last_log_time,
logging_interval=self.logging_interval,
trigger_kwargs=self.trigger_kwargs,
),
method_name="trigger_reentry",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class KubernetesPodTrigger(BaseTrigger):
:param logging_interval: number of seconds to wait before kicking it back to
the operator to print latest logs. If ``None`` will wait until container done.
:param last_log_time: where to resume logs from
:param trigger_kwargs: additional keyword parameters to send in the event
"""

def __init__(
Expand All @@ -94,6 +95,7 @@ def __init__(
on_finish_action: str = "delete_pod",
last_log_time: DateTime | None = None,
logging_interval: int | None = None,
trigger_kwargs: dict | None = None,
):
super().__init__()
self.pod_name = pod_name
Expand All @@ -111,6 +113,7 @@ def __init__(
self.last_log_time = last_log_time
self.logging_interval = logging_interval
self.on_finish_action = OnFinishAction(on_finish_action)
self.trigger_kwargs = trigger_kwargs or {}

self._since_time = None

Expand All @@ -134,6 +137,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"on_finish_action": self.on_finish_action.value,
"last_log_time": self.last_log_time,
"logging_interval": self.logging_interval,
"trigger_kwargs": self.trigger_kwargs,
},
)

Expand All @@ -149,6 +153,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"namespace": self.pod_namespace,
"name": self.pod_name,
"message": "All containers inside pod have started successfully.",
**self.trigger_kwargs,
}
)
elif state == ContainerState.FAILED:
Expand All @@ -158,6 +163,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"namespace": self.pod_namespace,
"name": self.pod_name,
"message": "pod failed",
**self.trigger_kwargs,
}
)
else:
Expand All @@ -172,6 +178,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"namespace": self.pod_namespace,
"status": "timeout",
"message": message,
**self.trigger_kwargs,
}
)
return
Expand All @@ -183,6 +190,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"status": "error",
"message": str(e),
"stack_trace": traceback.format_exc(),
**self.trigger_kwargs,
}
)
return
Expand Down Expand Up @@ -234,6 +242,7 @@ async def _wait_for_container_completion(self) -> TriggerEvent:
"namespace": self.pod_namespace,
"name": self.pod_name,
"last_log_time": self.last_log_time,
**self.trigger_kwargs,
}
)
if container_state == ContainerState.FAILED:
Expand All @@ -244,6 +253,7 @@ async def _wait_for_container_completion(self) -> TriggerEvent:
"name": self.pod_name,
"message": "Container state failed",
"last_log_time": self.last_log_time,
**self.trigger_kwargs,
}
)
self.log.debug("Container is not completed and still working.")
Expand All @@ -254,6 +264,7 @@ async def _wait_for_container_completion(self) -> TriggerEvent:
"last_log_time": self.last_log_time,
"namespace": self.pod_namespace,
"name": self.pod_name,
**self.trigger_kwargs,
}
)
self.log.debug("Sleeping for %s seconds.", self.poll_interval)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def test_serialize(self, trigger):
"on_finish_action": ON_FINISH_ACTION,
"last_log_time": None,
"logging_interval": None,
"trigger_kwargs": {},
}

@pytest.mark.asyncio
Expand Down
Loading