Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@

from typing import TYPE_CHECKING

from jinja2 import TemplateAssertionError, UndefinedError
from kubernetes.client.api_client import ApiClient

from airflow.exceptions import AirflowException
from airflow.providers.cncf.kubernetes.kube_config import KubeConfig
from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import create_unique_id
from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstance
Expand Down Expand Up @@ -58,3 +61,17 @@ def render_k8s_pod_yaml(task_instance: TaskInstance) -> dict | None:
)
sanitized_pod = ApiClient().sanitize_for_serialization(pod)
return sanitized_pod


@provide_session
def get_rendered_k8s_spec(task_instance: TaskInstance, session=NEW_SESSION) -> dict | None:
"""Fetch rendered template fields from DB."""
from airflow.models.renderedtifields import RenderedTaskInstanceFields

rendered_k8s_spec = RenderedTaskInstanceFields.get_k8s_pod_yaml(task_instance, session=session)
if not rendered_k8s_spec:
try:
rendered_k8s_spec = render_k8s_pod_yaml(task_instance)
except (TemplateAssertionError, UndefinedError) as e:
raise AirflowException(f"Unable to render a k8s spec for this taskinstance: {e}") from e
return rendered_k8s_spec
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
from kubernetes.client import models as k8s
from sqlalchemy.orm import make_transient

from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF
from airflow.providers.cncf.kubernetes.template_rendering import render_k8s_pod_yaml
from airflow.models.renderedtifields import RenderedTaskInstanceFields, RenderedTaskInstanceFields as RTIF
from airflow.providers.cncf.kubernetes.template_rendering import get_rendered_k8s_spec, render_k8s_pod_yaml
from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.version import version
Expand Down Expand Up @@ -149,6 +150,39 @@ def test_render_k8s_pod_yaml_with_custom_pod_template_and_pod_override(
assert ti_pod_yaml["metadata"]["annotations"]["test"] == "annotation"


@pytest.mark.skipif(
AIRFLOW_V_3_0_PLUS,
reason="This test is only needed for Airflow 2 - we can remove it after "
"only Airflow 3 is supported in providers",
)
@mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"})
@mock.patch.object(RenderedTaskInstanceFields, "get_k8s_pod_yaml")
@mock.patch("airflow.providers.cncf.kubernetes.template_rendering.render_k8s_pod_yaml")
def test_get_rendered_k8s_spec(render_k8s_pod_yaml, rtif_get_k8s_pod_yaml, create_task_instance):
# Create new TI for the same Task
ti = create_task_instance()

mock.patch.object(ti, "render_k8s_pod_yaml", autospec=True)

fake_spec = {"ermagawds": "pods"}

session = mock.Mock()

rtif_get_k8s_pod_yaml.return_value = fake_spec
assert get_rendered_k8s_spec(ti, session=session) == fake_spec

rtif_get_k8s_pod_yaml.assert_called_once_with(ti, session=session)
render_k8s_pod_yaml.assert_not_called()

# Now test that when we _dont_ find it in the DB, it calls render_k8s_pod_yaml
rtif_get_k8s_pod_yaml.return_value = None
render_k8s_pod_yaml.return_value = fake_spec

assert get_rendered_k8s_spec(session) == fake_spec

render_k8s_pod_yaml.assert_called_once()


@mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"})
@mock.patch("airflow.providers.cncf.kubernetes.template_rendering.render_k8s_pod_yaml")
def test_get_k8s_pod_yaml(render_k8s_pod_yaml, dag_maker, session):
Expand Down