6 changes: 3 additions & 3 deletions airflow/cli/commands/dag_command.py
@@ -385,7 +385,7 @@ def generate_pod_yaml(args):
"""Generates yaml files for each task in the DAG. Used for testing output of KubernetesExecutor"""
from kubernetes.client.api_client import ApiClient

from airflow.executors.kubernetes_executor import AirflowKubernetesScheduler, KubeConfig
from airflow.executors.kubernetes_executor import KubeConfig, create_pod_id
from airflow.kubernetes import pod_generator
from airflow.kubernetes.pod_generator import PodGenerator
from airflow.settings import pod_mutation_hook
@@ -399,14 +399,14 @@ def generate_pod_yaml(args):
pod = PodGenerator.construct_pod(
dag_id=args.dag_id,
task_id=ti.task_id,
pod_id=AirflowKubernetesScheduler._create_pod_id( # pylint: disable=W0212
pod_id=create_pod_id(
args.dag_id, ti.task_id),
try_number=ti.try_number,
kube_image=kube_config.kube_image,
date=ti.execution_date,
command=ti.command_as_list(),
pod_override_object=PodGenerator.from_obj(ti.executor_config),
worker_uuid="worker-config",
scheduler_job_id="worker-config",
namespace=kube_config.executor_namespace,
base_worker_pod=PodGenerator.deserialize_model_file(kube_config.pod_template_file)
)
2 changes: 2 additions & 0 deletions airflow/executors/base_executor.py
@@ -56,6 +56,8 @@ class BaseExecutor(LoggingMixin):
``0`` for infinity
"""

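# Set externally by the job that runs this executor (e.g. the scheduler job) before start() is called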
job_id: Optional[str] = None

def __init__(self, parallelism: int = PARALLELISM):
super().__init__()
self.parallelism: int = parallelism
192 changes: 140 additions & 52 deletions airflow/executors/kubernetes_executor.py
@@ -27,7 +27,7 @@
import multiprocessing
import time
from queue import Empty, Queue # pylint: disable=unused-import
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import kubernetes
from dateutil import parser
@@ -44,7 +44,7 @@
from airflow.kubernetes.kube_client import get_kube_client
from airflow.kubernetes.pod_generator import MAX_POD_ID_LEN, PodGenerator
from airflow.kubernetes.pod_launcher import PodLauncher
from airflow.models import KubeResourceVersion, KubeWorkerIdentifier, TaskInstance
from airflow.models import TaskInstance
from airflow.models.taskinstance import TaskInstanceKey
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import provide_session
@@ -60,6 +60,18 @@
KubernetesWatchType = Tuple[str, str, Optional[str], Dict[str, str], str]


class ResourceVersion:
"""Singleton for tracking resourceVersion from Kubernetes"""

_instance = None
resource_version = "0"

def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance

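A quick sanity check of the singleton above (illustrative sketch, not part of the diff):

    rv1 = ResourceVersion()
    rv2 = ResourceVersion()
    assert rv1 is rv2                      # __new__ hands back the one shared instance
    rv1.resource_version = "42"
    assert rv2.resource_version == "42"    # updates through either handle are shared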

class KubeConfig: # pylint: disable=too-many-instance-attributes
"""Configuration for Kubernetes"""

@@ -134,25 +146,25 @@ def __init__(self,
multi_namespace_mode: bool,
watcher_queue: 'Queue[KubernetesWatchType]',
resource_version: Optional[str],
worker_uuid: Optional[str],
scheduler_job_id: Optional[str],
kube_config: Configuration):
super().__init__()
self.namespace = namespace
self.multi_namespace_mode = multi_namespace_mode
self.worker_uuid = worker_uuid
self.scheduler_job_id = scheduler_job_id
self.watcher_queue = watcher_queue
self.resource_version = resource_version
self.kube_config = kube_config

def run(self) -> None:
"""Performs watching"""
kube_client: client.CoreV1Api = get_kube_client()
if not self.worker_uuid:
if not self.scheduler_job_id:
raise AirflowException(NOT_STARTED_MESSAGE)
while True:
try:
self.resource_version = self._run(kube_client, self.resource_version,
self.worker_uuid, self.kube_config)
self.scheduler_job_id, self.kube_config)
except ReadTimeoutError:
self.log.warning("There was a timeout error accessing the Kube API. "
"Retrying request.", exc_info=True)
@@ -167,15 +179,15 @@ def run(self) -> None:
def _run(self,
kube_client: client.CoreV1Api,
resource_version: Optional[str],
worker_uuid: str,
scheduler_job_id: str,
kube_config: Any) -> Optional[str]:
self.log.info(
'Event: and now my watch begins starting at resource_version: %s',
resource_version
)
watcher = watch.Watch()

kwargs = {'label_selector': 'airflow-worker={}'.format(worker_uuid)}
kwargs = {'label_selector': 'airflow-worker={}'.format(scheduler_job_id)}
if resource_version:
kwargs['resource_version'] = resource_version
if kube_config.kube_client_request_args:
@@ -277,7 +289,7 @@ def __init__(self,
task_queue: 'Queue[KubernetesJobType]',
result_queue: 'Queue[KubernetesResultsType]',
kube_client: client.CoreV1Api,
worker_uuid: str):
scheduler_job_id: str):
super().__init__()
self.log.debug("Creating Kubernetes executor")
self.kube_config = kube_config
@@ -289,16 +301,16 @@ def __init__(self,
self.launcher = PodLauncher(kube_client=self.kube_client)
self._manager = multiprocessing.Manager()
self.watcher_queue = self._manager.Queue()
self.worker_uuid = worker_uuid
self.scheduler_job_id = scheduler_job_id
self.kube_watcher = self._make_kube_watcher()

def _make_kube_watcher(self) -> KubernetesJobWatcher:
resource_version = KubeResourceVersion.get_current_resource_version()
resource_version = ResourceVersion().resource_version
watcher = KubernetesJobWatcher(watcher_queue=self.watcher_queue,
namespace=self.kube_config.kube_namespace,
multi_namespace_mode=self.kube_config.multi_namespace_mode,
resource_version=resource_version,
worker_uuid=self.worker_uuid,
scheduler_job_id=self.scheduler_job_id,
kube_config=self.kube_config)
watcher.start()
return watcher
@@ -333,8 +345,8 @@ def run_next(self, next_job: KubernetesJobType) -> None:

pod = PodGenerator.construct_pod(
namespace=self.namespace,
worker_uuid=self.worker_uuid,
pod_id=self._create_pod_id(dag_id, task_id),
scheduler_job_id=self.scheduler_job_id,
pod_id=create_pod_id(dag_id, task_id),
dag_id=dag_id,
task_id=task_id,
kube_image=self.kube_config.kube_image,
@@ -404,21 +416,6 @@ def _annotations_to_key(self, annotations: Dict[str, str]) -> Optional[TaskInstanceKey]:

return TaskInstanceKey(dag_id, task_id, execution_date, try_number)

@staticmethod
def _strip_unsafe_kubernetes_special_chars(string: str) -> str:
"""
Kubernetes only supports lowercase alphanumeric characters and "-" and "." in
the pod name
However, there are special rules about how "-" and "." can be used so let's
only keep
alphanumeric chars see here for detail:
https://kubernetes.io/docs/concepts/overview/working-with-objects/names/

:param string: The requested Pod name
:return: ``str`` Pod name stripped of any unsafe characters
"""
return ''.join(ch.lower() for ind, ch in enumerate(string) if ch.isalnum())

@staticmethod
def _make_safe_pod_id(safe_dag_id: str, safe_task_id: str, safe_uuid: str) -> str:
r"""
@@ -437,14 +434,6 @@ def _make_safe_pod_id(safe_dag_id: str, safe_task_id: str, safe_uuid: str) -> str:

return safe_pod_id

@staticmethod
def _create_pod_id(dag_id: str, task_id: str) -> str:
safe_dag_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
dag_id)
safe_task_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
task_id)
return safe_dag_id + safe_task_id

def _flush_watcher_queue(self) -> None:
self.log.debug('Executor shutting down, watcher_queue approx. size=%d', self.watcher_queue.qsize())
while True:
@@ -470,6 +459,36 @@ def terminate(self) -> None:
self._manager.shutdown()


def _strip_unsafe_kubernetes_special_chars(string: str) -> str:
"""
Kubernetes only supports lowercase alphanumeric characters, "-" and "." in
pod names. However, there are special rules about how "-" and "." can be
used, so we keep only the alphanumeric characters. See here for details:
https://kubernetes.io/docs/concepts/overview/working-with-objects/names/

:param string: The requested Pod name
:return: ``str`` Pod name stripped of any unsafe characters
"""
return ''.join(ch.lower() for ch in string if ch.isalnum())


def create_pod_id(dag_id: str, task_id: str) -> str:
"""
Generates the Kubernetes-safe pod_id. Note that this is
NOT the full ID that will be launched to k8s. We will add a uuid
to ensure uniqueness.

:param dag_id: DAG ID
:param task_id: Task ID
:return: The non-unique pod_id for this task/DAG pairing
"""
safe_dag_id = _strip_unsafe_kubernetes_special_chars(dag_id)
safe_task_id = _strip_unsafe_kubernetes_special_chars(task_id)
return safe_dag_id + safe_task_id

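Illustrative behavior (not part of the diff): everything outside lowercase alphanumerics is dropped, so distinct DAG/task pairs can map to the same pod_id; the uuid mentioned in the docstring is what restores uniqueness.

    assert create_pod_id("My_DAG", "task.1") == "mydagtask1"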

class KubernetesExecutor(BaseExecutor, LoggingMixin):
"""Executor for Kubernetes"""

@@ -480,7 +499,7 @@ def __init__(self):
self.result_queue: 'Queue[KubernetesResultsType]' = self._manager.Queue()
self.kube_scheduler: Optional[AirflowKubernetesScheduler] = None
self.kube_client: Optional[client.CoreV1Api] = None
self.worker_uuid: Optional[str] = None
self.scheduler_job_id: Optional[str] = None
super().__init__(parallelism=self.kube_config.parallelism)

@provide_session
@@ -519,7 +538,7 @@ def clear_not_launched_queued_tasks(self, session=None) -> None:
pod_generator.datetime_to_label_safe_datestring(
task.execution_date
),
self.worker_uuid
self.scheduler_job_id
)
)
# pylint: enable=protected-access
@@ -568,19 +587,14 @@ def _create_or_update_secret(secret_name, secret_path):
def start(self) -> None:
"""Starts the executor"""
self.log.info('Start Kubernetes executor')
self.worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid()
if not self.worker_uuid:
raise AirflowException("Could not get worker uuid")
self.log.debug('Start with worker_uuid: %s', self.worker_uuid)
# always need to reset resource version since we don't know
# when we last started, note for behavior below
# https://github.com/kubernetes-client/python/blob/master/kubernetes/docs
# /CoreV1Api.md#list_namespaced_pod
KubeResourceVersion.reset_resource_version()
if not self.job_id:
raise AirflowException("Could not get scheduler_job_id")
self.scheduler_job_id = self.job_id
self.log.debug('Start with scheduler_job_id: %s', self.scheduler_job_id)
self.kube_client = get_kube_client()
self.kube_scheduler = AirflowKubernetesScheduler(
self.kube_config, self.task_queue, self.result_queue,
self.kube_client, self.worker_uuid
self.kube_client, self.scheduler_job_id
)
self._inject_secrets()
self.clear_not_launched_queued_tasks()
@@ -595,10 +609,10 @@ def execute_async(self,
'Add task %s with command %s with executor_config %s',
key, command, executor_config
)

kube_executor_config = PodGenerator.from_obj(executor_config)
if not self.task_queue:
raise AirflowException(NOT_STARTED_MESSAGE)
self.event_buffer[key] = (State.QUEUED, self.scheduler_job_id)
self.task_queue.put((key, command, kube_executor_config))

def sync(self) -> None:
@@ -607,7 +621,7 @@ def sync(self) -> None:
self.log.debug('self.running: %s', self.running)
if self.queued_tasks:
self.log.debug('self.queued: %s', self.queued_tasks)
if not self.worker_uuid:
if not self.scheduler_job_id:
raise AirflowException(NOT_STARTED_MESSAGE)
if not self.kube_scheduler:
raise AirflowException(NOT_STARTED_MESSAGE)
@@ -640,7 +654,8 @@ def sync(self) -> None:
except Empty:
break

KubeResourceVersion.checkpoint_resource_version(last_resource_version)
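# Keep the previously checkpointed version when the watcher did not report a newer one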
resource_instance = ResourceVersion()
resource_instance.resource_version = last_resource_version or resource_instance.resource_version

# pylint: disable=too-many-nested-blocks
for _ in range(self.kube_config.worker_pods_creation_batch_size):
@@ -681,6 +696,79 @@ def _change_state(self,
self.log.debug('Could not find key: %s', str(key))
self.event_buffer[key] = state, None

def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
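"""
Re-label pods that still carry a previous scheduler's ``airflow-worker`` id so this
executor's watcher can monitor them, and return the task instances that could not
be adopted so the caller can reset them.
"""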
tis_to_flush = [ti for ti in tis if not ti.external_executor_id]
# Deduplicate, and skip task instances that have no external_executor_id
scheduler_job_ids = {ti.external_executor_id for ti in tis if ti.external_executor_id}
pod_ids = {
create_pod_id(dag_id=ti.dag_id, task_id=ti.task_id): ti
for ti in tis if ti.external_executor_id
}
kube_client: client.CoreV1Api = self.kube_client
for scheduler_job_id in scheduler_job_ids:
kwargs = {
'label_selector': f'airflow-worker={scheduler_job_id}'
}
pod_list = kube_client.list_namespaced_pod(
namespace=self.kube_config.kube_namespace,
**kwargs
)
for pod in pod_list.items:
self.adopt_launched_task(kube_client, pod, pod_ids)
self._adopt_completed_pods(kube_client)
tis_to_flush.extend(pod_ids.values())
return tis_to_flush
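
How a caller might consume the return value (hypothetical sketch; the scheduler-side handling is not part of this diff):

    for ti in executor.try_adopt_task_instances(orphaned_tis):
        ti.state = State.NONE  # whatever could not be adopted is handed back for rescheduling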

def adopt_launched_task(self, kube_client, pod, pod_ids: dict):
"""
Patch existing pod so that the current KubernetesJobWatcher can monitor it via label selectors

:param kube_client: kubernetes client for speaking to kube API
:param pod: V1Pod spec that we will patch with new label
:param pod_ids: pod_ids we expect to patch.
"""
self.log.info("attempting to adopt pod %s", pod.metadata.name)
pod.metadata.labels['airflow-worker'] = str(self.scheduler_job_id)
dag_id = pod.metadata.labels['dag_id']
task_id = pod.metadata.labels['task_id']
pod_id = create_pod_id(dag_id=dag_id, task_id=task_id)
if pod_id not in pod_ids:
self.log.error("attempting to adopt task %s in dag %s"
" which was not specified by database", task_id, dag_id)
else:
try:
kube_client.patch_namespaced_pod(
name=pod.metadata.name,
namespace=pod.metadata.namespace,
body=PodGenerator.serialize_pod(pod),
)
pod_ids.pop(pod_id)
except ApiException as e:
self.log.info("Failed to adopt pod %s. Reason: %s", pod.metadata.name, e)

def _adopt_completed_pods(self, kube_client: kubernetes.client.CoreV1Api):
"""

Patch completed pod so that the KubernetesJobWatcher can delete it.

:param kube_client: kubernetes client for speaking to kube API
"""
kwargs = {
'field_selector': "status.phase=Succeeded",
'label_selector': 'kubernetes_executor=True',
}
pod_list = kube_client.list_namespaced_pod(namespace=self.kube_config.kube_namespace, **kwargs)
for pod in pod_list.items:
self.log.info("Attempting to adopt pod %s", pod.metadata.name)
pod.metadata.labels['airflow-worker'] = str(self.scheduler_job_id)
try:
kube_client.patch_namespaced_pod(
name=pod.metadata.name,
namespace=pod.metadata.namespace,
body=PodGenerator.serialize_pod(pod),
)
except ApiException as e:
self.log.info("Failed to adopt pod %s. Reason: %s", pod.metadata.name, e)

def _flush_task_queue(self) -> None:
if not self.task_queue:
raise AirflowException(NOT_STARTED_MESSAGE)