diff --git a/task_processing/plugins/kubernetes/kubernetes_pod_executor.py b/task_processing/plugins/kubernetes/kubernetes_pod_executor.py
index 46ea343..e965880 100644
--- a/task_processing/plugins/kubernetes/kubernetes_pod_executor.py
+++ b/task_processing/plugins/kubernetes/kubernetes_pod_executor.py
@@ -47,6 +47,7 @@
 )
 from task_processing.plugins.kubernetes.utils import get_pod_volumes
 from task_processing.plugins.kubernetes.utils import get_sanitised_kubernetes_name
+from task_processing.plugins.kubernetes.utils import get_topology_spread_constraints
 
 logger = logging.getLogger(__name__)
 
@@ -561,6 +562,9 @@ def run(self, task_config: KubernetesTaskConfig) -> Optional[str]:
                 affinity=V1Affinity(
                     node_affinity=get_node_affinity(task_config.node_affinities),
                 ),
+                topology_spread_constraints=get_topology_spread_constraints(
+                    task_config.topology_spread_constraints
+                ),
                 # we're hardcoding this as Default as this is what we generally use
                 # internally - until we have a usecase for something that needs one
                 # of the other DNS policies, we can probably punt on plumbing all the
diff --git a/task_processing/plugins/kubernetes/task_config.py b/task_processing/plugins/kubernetes/task_config.py
index 4eceada..118bccc 100644
--- a/task_processing/plugins/kubernetes/task_config.py
+++ b/task_processing/plugins/kubernetes/task_config.py
@@ -8,6 +8,7 @@
 from typing import Sequence
 from typing import Tuple
 from typing import TYPE_CHECKING
+from typing import Union
 
 from pyrsistent import field
 from pyrsistent import m
@@ -25,6 +26,7 @@
 from task_processing.plugins.kubernetes.types import ProjectedSAVolume
 from task_processing.plugins.kubernetes.types import SecretVolume
 from task_processing.plugins.kubernetes.types import SecretVolumeItem
+from task_processing.plugins.kubernetes.types import TopologySpreadConstraint
 from task_processing.plugins.kubernetes.utils import (
     DEFAULT_PROJECTED_SA_TOKEN_EXPIRATION_SECONDS,
 )
@@ -473,6 +475,11 @@ def __invariant__(self) -> Tuple[Tuple[bool, str], ...]:
         factory=pvector,
         invariant=_valid_node_affinities,
     )
+    topology_spread_constraints = field(
+        type=PVector if not TYPE_CHECKING else PVector["TopologySpreadConstraint"],
+        initial=v(),
+        factory=pvector,
+    )
     labels = field(
         type=PMap if not TYPE_CHECKING else PMap[str, str],
         initial=m(),
diff --git a/task_processing/plugins/kubernetes/types.py b/task_processing/plugins/kubernetes/types.py
index 8437a7e..8b724e4 100644
--- a/task_processing/plugins/kubernetes/types.py
+++ b/task_processing/plugins/kubernetes/types.py
@@ -88,3 +88,10 @@ class PodEvent(TypedDict):
     object: V1Pod
     # this is just the dict-ified version of object - but it's too big to type here
     raw_object: Dict[str, Any]
+
+
+class TopologySpreadConstraint(TypedDict):
+    max_skew: int
+    topology_key: str
+    when_unsatisfiable: str
+    label_selector: Dict[str, str]
diff --git a/task_processing/plugins/kubernetes/utils.py b/task_processing/plugins/kubernetes/utils.py
index 89a6117..c54e039 100644
--- a/task_processing/plugins/kubernetes/utils.py
+++ b/task_processing/plugins/kubernetes/utils.py
@@ -11,6 +11,7 @@
 from kubernetes.client import V1EnvVarSource
 from kubernetes.client import V1HostPathVolumeSource
 from kubernetes.client import V1KeyToPath
+from kubernetes.client import V1LabelSelector
 from kubernetes.client import V1NodeAffinity
 from kubernetes.client import V1NodeSelector
 from kubernetes.client import V1NodeSelectorRequirement
@@ -20,6 +21,7 @@
 from kubernetes.client import V1SecretKeySelector
 from kubernetes.client import V1SecretVolumeSource
 from kubernetes.client import V1ServiceAccountTokenProjection
+from kubernetes.client import V1TopologySpreadConstraint
 from kubernetes.client import V1Volume
 from kubernetes.client import V1VolumeMount
 from kubernetes.client import V1VolumeProjection
@@ -27,6 +29,7 @@
 from pyrsistent.typing import PVector
 
 from task_processing.plugins.kubernetes.types import NodeAffinityOperator
+from task_processing.plugins.kubernetes.types import TopologySpreadConstraint
 
 if TYPE_CHECKING:
     from task_processing.plugins.kubernetes.types import EmptyVolume
@@ -417,3 +420,22 @@ def get_kubernetes_service_account_token_volume_mounts(
         )
         for volume in sa_volumes
     ]
+
+
+def get_topology_spread_constraints(
+    constraints: PVector[TopologySpreadConstraint],
+) -> List[V1TopologySpreadConstraint]:
+    """Build topology spread constraints for a pod
+
+    :param PVector["TopologySpreadConstraint"] constraints: list of topology spread constraint configs
+    :return: list of kubernetes topology spread constraint objects
+    """
+    return [
+        V1TopologySpreadConstraint(
+            label_selector=V1LabelSelector(match_labels=constraint["label_selector"]),
+            max_skew=constraint["max_skew"],
+            topology_key=constraint["topology_key"],
+            when_unsatisfiable=constraint["when_unsatisfiable"],
+        )
+        for constraint in constraints
+    ]
diff --git a/tests/unit/plugins/kubernetes/kubernetes_pod_executor_test.py b/tests/unit/plugins/kubernetes/kubernetes_pod_executor_test.py
index 96356c5..e3352dc 100644
--- a/tests/unit/plugins/kubernetes/kubernetes_pod_executor_test.py
+++ b/tests/unit/plugins/kubernetes/kubernetes_pod_executor_test.py
@@ -8,6 +8,7 @@
 from kubernetes.client import V1Container
 from kubernetes.client import V1ContainerPort
 from kubernetes.client import V1HostPathVolumeSource
+from kubernetes.client import V1LabelSelector
 from kubernetes.client import V1ObjectMeta
 from kubernetes.client import V1Pod
 from kubernetes.client import V1PodSecurityContext
@@ -16,6 +17,7 @@
 from kubernetes.client import V1ResourceRequirements
 from kubernetes.client import V1SecurityContext
 from kubernetes.client import V1ServiceAccountTokenProjection
+from kubernetes.client import V1TopologySpreadConstraint
 from kubernetes.client import V1Volume
 from kubernetes.client import V1VolumeMount
 from kubernetes.client import V1VolumeProjection
@@ -220,6 +222,7 @@ def test_run_single_request_memory(mock_get_node_affinity, k8s_executor):
             ),
             node_selector={"hello": "world"},
             affinity=V1Affinity(node_affinity=mock_get_node_affinity.return_value),
+            topology_spread_constraints=[],
             dns_policy="Default",
             service_account_name=task_config.service_account_name,
         ),
@@ -321,6 +324,7 @@ def test_run_single_request_cpu(mock_get_node_affinity, k8s_executor):
             ),
             node_selector={"hello": "world"},
             affinity=V1Affinity(node_affinity=mock_get_node_affinity.return_value),
+            topology_spread_constraints=[],
             dns_policy="Default",
             service_account_name=task_config.service_account_name,
         ),
@@ -426,6 +430,7 @@ def test_run_both_requests(mock_get_node_affinity, k8s_executor):
             ),
             node_selector={"hello": "world"},
             affinity=V1Affinity(node_affinity=mock_get_node_affinity.return_value),
+            topology_spread_constraints=[],
             dns_policy="Default",
             service_account_name=task_config.service_account_name,
         ),
@@ -526,6 +531,7 @@ def test_run_no_requests(mock_get_node_affinity, k8s_executor):
             ),
             node_selector={"hello": "world"},
             affinity=V1Affinity(node_affinity=mock_get_node_affinity.return_value),
+            topology_spread_constraints=[],
             dns_policy="Default",
             service_account_name=task_config.service_account_name,
         ),
@@ -677,6 +683,7 @@ def test_run_authentication_token(mock_get_node_affinity, k8s_executor):
             ),
             node_selector={"hello": "world"},
             affinity=V1Affinity(node_affinity=mock_get_node_affinity.return_value),
+            topology_spread_constraints=[],
             dns_policy="Default",
             service_account_name=task_config.service_account_name,
         ),
@@ -692,6 +699,111 @@ def test_run_authentication_token(mock_get_node_affinity, k8s_executor):
     ]
 
 
+@mock.patch(
+    "task_processing.plugins.kubernetes.kubernetes_pod_executor.get_node_affinity",
+    autospec=True,
+)
+def test_run_topology_spread_constraint(mock_get_node_affinity, k8s_executor):
+    task_config = KubernetesTaskConfig(
+        name="fake_task_name",
+        uuid="fake_id",
+        image="fake_docker_image",
+        command="fake_command",
+        cpus=1,
+        cpus_request=0.5,
+        memory=1024,
+        disk=1024,
+        volumes=[],
+        projected_sa_volumes=[],
+        node_selectors={"hello": "world"},
+        node_affinities=[dict(key="a_label", operator="In", value=[])],
+        topology_spread_constraints=[
+            {
+                "max_skew": 1,
+                "topology_key": "topology.kubernetes.io/zone",
+                "when_unsatisfiable": "ScheduleAnyway",
+                "label_selector": {
+                    "app.kubernetes.io/managed-by": "task_proc",
+                },
+            },
+        ],
+        labels={
+            "some_label": "some_label_value",
+        },
+        annotations={
+            "paasta.yelp.com/some_annotation": "some_value",
+        },
+        service_account_name="testsa",
+        ports=[8888],
+        stdin=True,
+        stdin_once=True,
+        tty=True,
+    )
+    expected_container = V1Container(
+        image=task_config.image,
+        name="main",
+        command=["/bin/sh", "-c"],
+        args=[task_config.command],
+        security_context=V1SecurityContext(
+            capabilities=V1Capabilities(drop=list(task_config.cap_drop)),
+        ),
+        resources=V1ResourceRequirements(
+            limits={
+                "cpu": 1.0,
+                "memory": "1024.0Mi",
+                "ephemeral-storage": "1024.0Mi",
+            },
+            requests={"cpu": 0.5},
+        ),
+        env=[],
+        volume_mounts=[],
+        ports=[V1ContainerPort(container_port=8888)],
+        stdin=True,
+        stdin_once=True,
+        tty=True,
+    )
+    expected_pod = V1Pod(
+        metadata=V1ObjectMeta(
+            name=task_config.pod_name,
+            namespace="task_processing_tests",
+            labels={
+                "some_label": "some_label_value",
+            },
+            annotations={
+                "paasta.yelp.com/some_annotation": "some_value",
+            },
+        ),
+        spec=V1PodSpec(
+            restart_policy=task_config.restart_policy,
+            containers=[expected_container],
+            volumes=[],
+            share_process_namespace=True,
+            security_context=V1PodSecurityContext(
+                fs_group=task_config.fs_group,
+            ),
+            node_selector={"hello": "world"},
+            affinity=V1Affinity(node_affinity=mock_get_node_affinity.return_value),
+            topology_spread_constraints=[
+                V1TopologySpreadConstraint(
+                    max_skew=1,
+                    topology_key="topology.kubernetes.io/zone",
+                    when_unsatisfiable="ScheduleAnyway",
+                    label_selector=V1LabelSelector(
+                        match_labels={"app.kubernetes.io/managed-by": "task_proc"}
+                    ),
+                ),
+            ],
+            dns_policy="Default",
+            service_account_name=task_config.service_account_name,
+        ),
+    )
+
+    assert k8s_executor.run(task_config) == task_config.pod_name
+    assert k8s_executor.kube_client.core.create_namespaced_pod.call_args_list == [
+        mock.call(body=expected_pod, namespace="task_processing_tests")
+    ]
+
+
 def test_process_event_enqueues_task_processing_events_pending_to_running(k8s_executor):
     mock_pod = mock.Mock(spec=V1Pod)
     mock_pod.metadata.name = "test.1234"
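For reviewers, a minimal usage sketch of the new field (not part of the diff; the task name, image, and command values below are illustrative assumptions): callers pass plain dicts with max_skew, topology_key, when_unsatisfiable, and label_selector keys on KubernetesTaskConfig, and the executor's pod spec picks them up via the new get_topology_spread_constraints helper.

    # Usage sketch only - field values other than the constraint keys are made up.
    from task_processing.plugins.kubernetes.task_config import KubernetesTaskConfig
    from task_processing.plugins.kubernetes.utils import get_topology_spread_constraints

    task_config = KubernetesTaskConfig(
        name="example_task",
        uuid="example_id",
        image="example_docker_image",
        command="sleep 60",
        topology_spread_constraints=[
            {
                "max_skew": 1,
                "topology_key": "topology.kubernetes.io/zone",
                "when_unsatisfiable": "ScheduleAnyway",
                "label_selector": {"app.kubernetes.io/managed-by": "task_proc"},
            },
        ],
    )

    # The executor calls this helper when building the V1PodSpec; it returns a list
    # of V1TopologySpreadConstraint objects built from the dicts above.
    print(get_topology_spread_constraints(task_config.topology_spread_constraints))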