Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add taint to user nodes #2605

Draft
wants to merge 10 commits into
base: develop
Choose a base branch
from
82 changes: 61 additions & 21 deletions src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,33 @@ class ExistingInputVars(schema.Base):
kube_context: str


class DigitalOceanNodeGroup(schema.Base):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate class

class NodeGroup(schema.Base):
    """Base node-group model shared by all cloud providers.

    Holds the instance type, autoscaling bounds, and optional Kubernetes
    taints applied to the group's nodes.
    """

    instance: str
    min_nodes: Annotated[int, Field(ge=0)] = 0
    max_nodes: Annotated[int, Field(ge=1)] = 1
    taints: Optional[List[schema.Taint]] = []

    @field_validator("taints", mode="before")
    @classmethod
    def validate_taint_strings(cls, value: Optional[List[str | schema.Taint]]):
        """Coerce entries into ``schema.Taint`` objects.

        Accepts ``schema.Taint`` instances as-is and parses strings of the
        form ``key=value:effect`` (e.g. ``dedicated=user:NoSchedule``).

        Raises:
            ValueError: if an entry is neither a string nor a Taint, or a
                string does not match ``key=value:effect`` exactly.
        """
        if value is None:
            # Field is Optional; leave None for pydantic to handle.
            return value
        taint_str_regex = re.compile(r"(\w+)=(\w+):(\w+)")
        parsed_taints = []
        for taint in value:
            if isinstance(taint, schema.Taint):
                parsed_taints.append(taint)
                continue
            if not isinstance(taint, str):
                raise ValueError(
                    f"Unable to parse type: {type(taint)} as taint. Must be a string or Taint object."
                )
            # fullmatch (not match) so trailing garbage such as
            # "key=value:NoSchedule:extra" is rejected instead of being
            # silently truncated.
            match = taint_str_regex.fullmatch(taint)
            if not match:
                raise ValueError(f"Invalid taint string: {taint}")
            # Use a distinct local name; the previous code rebound the
            # ``value`` parameter inside the loop.
            key, taint_value, effect = match.groups()
            parsed_taints.append(
                schema.Taint(key=key, value=taint_value, effect=effect)
            )
        return parsed_taints


class DigitalOceanInputVars(schema.Base):
Expand All @@ -55,7 +78,7 @@ class DigitalOceanInputVars(schema.Base):
region: str
tags: List[str]
kubernetes_version: str
node_groups: Dict[str, DigitalOceanNodeGroup]
node_groups: Dict[str, "DigitalOceanNodeGroup"]
kubeconfig_filename: str = get_kubeconfig_filename()


Expand All @@ -64,10 +87,26 @@ class GCPNodeGroupInputVars(schema.Base):
instance_type: str
min_size: int
max_size: int
node_taints: List[dict]
labels: Dict[str, str]
preemptible: bool
guest_accelerators: List["GCPGuestAccelerator"]

@field_validator("node_taints", mode="before")
@classmethod
def convert_taints(cls, value: Optional[List[schema.Taint]]):
    """Convert ``schema.Taint`` objects into the dict shape the GCP
    Terraform module expects.

    GCP spells taint effects in SCREAMING_SNAKE_CASE, unlike Kubernetes.
    """
    # The parameter is Optional: tolerate None instead of raising a
    # TypeError when iterating (previous behavior).
    if value is None:
        return []
    effect_mapping = {
        schema.TaintEffectEnum.NoSchedule: "NO_SCHEDULE",
        schema.TaintEffectEnum.PreferNoSchedule: "PREFER_NO_SCHEDULE",
        schema.TaintEffectEnum.NoExecute: "NO_EXECUTE",
    }
    return [
        dict(key=taint.key, value=taint.value, effect=effect_mapping[taint.effect])
        for taint in value
    ]


class GCPPrivateClusterConfig(schema.Base):
enable_private_nodes: bool
Expand Down Expand Up @@ -225,16 +264,14 @@ class KeyValueDict(schema.Base):
value: str


class DigitalOceanNodeGroup(NodeGroup):
    """Representation of a node group with Digital Ocean

    - Kubernetes limits: https://docs.digitalocean.com/products/kubernetes/details/limits/
    - Available instance types: https://slugs.do-api.dev/
    """

    # NOTE(review): instance/min_nodes/max_nodes are now inherited from
    # NodeGroup; the stale duplicate field declarations left over from the
    # pre-refactor class are removed here.


DEFAULT_DO_NODE_GROUPS = {
Expand Down Expand Up @@ -319,19 +356,26 @@ class GCPGuestAccelerator(schema.Base):
count: Annotated[int, Field(ge=1)] = 1


class GCPNodeGroup(schema.Base):
instance: str
min_nodes: Annotated[int, Field(ge=0)] = 0
max_nodes: Annotated[int, Field(ge=1)] = 1
class GCPNodeGroup(NodeGroup):
    """GCP-specific node group settings layered on the common NodeGroup fields."""

    preemptible: bool = False
    labels: Dict[str, str] = {}
    guest_accelerators: List[GCPGuestAccelerator] = []


DEFAULT_GCP_NODE_GROUPS = {
"general": GCPNodeGroup(instance="e2-standard-8", min_nodes=1, max_nodes=1),
"user": GCPNodeGroup(instance="e2-standard-4", min_nodes=0, max_nodes=5),
"worker": GCPNodeGroup(instance="e2-standard-4", min_nodes=0, max_nodes=5),
"user": GCPNodeGroup(
instance="e2-standard-4",
min_nodes=0,
max_nodes=5,
taints=[schema.Taint(key="dedicated", value="user", effect="NoSchedule")],
),
"worker": GCPNodeGroup(
instance="e2-standard-4",
min_nodes=0,
max_nodes=5,
taints=[schema.Taint(key="dedicated", value="worker", effect="NoSchedule")],
),
}


Expand Down Expand Up @@ -369,10 +413,8 @@ def _check_input(cls, data: Any) -> Any:
return data


class AzureNodeGroup(schema.Base):
instance: str
min_nodes: int
max_nodes: int
class AzureNodeGroup(NodeGroup):
    # Azure adds nothing beyond the shared NodeGroup fields
    # (instance, min_nodes, max_nodes, taints).
    pass


DEFAULT_AZURE_NODE_GROUPS = {
Expand Down Expand Up @@ -440,10 +482,7 @@ def _validate_tags(cls, value: Optional[Dict[str, str]]) -> Dict[str, str]:
return value if value is None else azure_cloud.validate_tags(value)


class AWSNodeGroup(schema.Base):
instance: str
min_nodes: int = 0
max_nodes: int
class AWSNodeGroup(NodeGroup):
gpu: bool = False
single_subnet: bool = False
permissions_boundary: Optional[str] = None
Expand Down Expand Up @@ -752,6 +791,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
instance_type=node_group.instance,
min_size=node_group.min_nodes,
max_size=node_group.max_nodes,
node_taints=node_group.taints,
preemptible=node_group.preemptible,
guest_accelerators=node_group.guest_accelerators,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ resource "google_container_node_pool" "main" {

oauth_scopes = local.node_group_oauth_scopes

dynamic "taint" {
for_each = local.merged_node_groups[count.index].node_taints
content {
key = taint.value.key
value = taint.value.value
effect = taint.value.effect
}
}

metadata = {
disable-legacy-endpoints = "true"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,23 @@ variable "node_groups" {
min_size = 1
max_size = 1
labels = {}
node_taints = []
},
{
name = "user"
instance_type = "n1-standard-2"
min_size = 0
max_size = 2
labels = {}
node_taints = [] # TODO: Do this for other cloud providers
},
{
name = "worker"
instance_type = "n1-standard-2"
min_size = 0
max_size = 5
labels = {}
node_taints = []
}
]
}
Expand Down
22 changes: 22 additions & 0 deletions src/_nebari/stages/kubernetes_services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,19 @@ def handle_units(cls, value: Optional[str]) -> float:
return byte_unit_conversion(value, "GiB")


class TolerationOperatorEnum(str, enum.Enum):
    """Operators accepted in a Kubernetes toleration spec."""

    Equal = "Equal"
    Exists = "Exists"

    @classmethod
    def to_yaml(cls, representer, node):
        # Emit the member's plain string value when dumped to YAML.
        return representer.represent_str(node.value)


class Toleration(schema.Taint):
    """A taint plus the matching operator, forming a pod toleration."""

    operator: TolerationOperatorEnum = TolerationOperatorEnum.Equal


class JupyterhubInputVars(schema.Base):
jupyterhub_theme: Dict[str, Any] = Field(alias="jupyterhub-theme")
jupyterlab_image: ImageNameTag = Field(alias="jupyterlab-image")
Expand All @@ -467,6 +480,9 @@ class JupyterhubInputVars(schema.Base):
cloud_provider: str = Field(alias="cloud-provider")
jupyterlab_preferred_dir: Optional[str] = Field(alias="jupyterlab-preferred-dir")
shared_fs_type: SharedFsEnum
node_taint_tolerations: Optional[List[Toleration]] = Field(
alias="node-taint-tolerations"
)

@field_validator("jupyterhub_shared_storage", mode="before")
@classmethod
Expand Down Expand Up @@ -634,6 +650,12 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
jupyterlab_default_settings=self.config.jupyterlab.default_settings,
jupyterlab_gallery_settings=self.config.jupyterlab.gallery_settings,
jupyterlab_preferred_dir=self.config.jupyterlab.preferred_dir,
node_taint_tolerations=[
Toleration(**taint.model_dump())
for taint in self.config.google_cloud_platform.node_groups[
"user"
].taints
], # TODO: support other cloud providers
shared_fs_type=(
# efs is equivalent to nfs in these modules
SharedFsEnum.nfs
Expand Down
11 changes: 11 additions & 0 deletions src/_nebari/stages/kubernetes_services/template/jupyterhub.tf
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@ variable "idle-culler-settings" {
type = any
}

# Tolerations passed through to the jupyterhub module so user pods can be
# scheduled onto tainted node groups; shape mirrors a Kubernetes toleration.
variable "node-taint-tolerations" {
  description = "Node taint toleration"
  type = list(object({
    key      = string
    operator = string
    value    = string
    effect   = string
  }))
}

variable "shared_fs_type" {
type = string
description = "Use NFS or Ceph"
Expand Down Expand Up @@ -175,6 +185,7 @@ module "jupyterhub" {
conda-store-service-name = module.kubernetes-conda-store-server.service_name
conda-store-jhub-apps-token = module.kubernetes-conda-store-server.service-tokens.jhub-apps
jhub-apps-enabled = var.jhub-apps-enabled
node-taint-tolerations = var.node-taint-tolerations

extra-mounts = {
"/etc/dask" = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,25 @@ def base_profile_extra_mounts():
}


def node_taint_tolerations():
    """Build kubespawner toleration overrides from the configured node taints.

    Reads ``custom.node-taint-tolerations`` from the z2jh config and returns
    a ``{"tolerations": [...]}`` override dict, or an empty dict when no
    tolerations are configured.
    """
    configured = z2jh.get_config("custom.node-taint-tolerations")
    if not configured:
        return {}

    # Copy only the toleration fields kubespawner expects, in this order.
    fields = ("key", "operator", "value", "effect")
    tolerations = [{name: entry[name] for name in fields} for entry in configured]
    return {"tolerations": tolerations}


def configure_user_provisioned_repositories(username):
# Define paths and configurations
pvc_home_mount_path = f"home/{username}"
Expand Down Expand Up @@ -519,6 +538,7 @@ def render_profile(profile, username, groups, keycloak_profilenames):
configure_user(username, groups),
configure_user_provisioned_repositories(username),
profile_kubespawner_override,
node_taint_tolerations(),
],
{},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ resource "helm_release" "jupyterhub" {
conda-store-jhub-apps-token = var.conda-store-jhub-apps-token
jhub-apps-enabled = var.jhub-apps-enabled
initial-repositories = var.initial-repositories
node-taint-tolerations = var.node-taint-tolerations
skel-mount = {
name = kubernetes_config_map.etc-skel.metadata.0.name
namespace = kubernetes_config_map.etc-skel.metadata.0.namespace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,13 @@ variable "initial-repositories" {
type = string
default = "[]"
}

# Tolerations applied to JupyterHub pods; shape mirrors a Kubernetes
# toleration (key/operator/value/effect).
variable "node-taint-tolerations" {
  description = "Node taint toleration"
  type = list(object({
    key      = string
    operator = string
    value    = string
    effect   = string
  }))
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,22 @@ resource "helm_release" "grafana-promtail" {
values = concat([
file("${path.module}/values_promtail.yaml"),
jsonencode({
tolerations = [
{
key = "node-role.kubernetes.io/master"
operator = "Exists"
effect = "NoSchedule"
},
{
key = "node-role.kubernetes.io/control-plane"
operator = "Exists"
effect = "NoSchedule"
},
Comment on lines +100 to +109
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These top 2 are the default value for this helm chart.

{
operator = "Exists"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

runs promtail on all nodes, even those with NoSchedule taints. Doesn't run on nodes with NoExecute taints. This is what the nebari-prometheus-node-exporter daemonset does so I copied it here. Promtail is what exports logs from the node so we still want it to run on the user and worker nodes.

effect = "NoSchedule"
},
]
})
], var.grafana-promtail-overrides)

Expand Down
7 changes: 7 additions & 0 deletions src/_nebari/stages/kubernetes_services/template/rook-ceph.tf
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ resource "helm_release" "rook-ceph" {
},
csi = {
enableRbdDriver = false, # necessary to provision block storage, but saves some cpu and memory if not needed
provisionerReplicas : 1, # default is 2 on different nodes
pluginTolerations = [
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

runs csi-driver on all nodes, even those with NoSchedule taints. Doesn't run on nodes with NoExecute taints. This is what the nebari-prometheus-node-exporter daemonset does so I copied it here.

{
operator = "Exists"
effect = "NoSchedule"
}
],
},
})
],
Expand Down
13 changes: 13 additions & 0 deletions src/nebari/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,16 @@ def is_version_accepted(v):
for deployment with the current Nebari package.
"""
return Main.is_version_accepted(v)


# TODO: Make sure the taint is actually applied to the nodes for each provider
class TaintEffectEnum(str, enum.Enum):
    """Valid Kubernetes taint effects."""

    # Enum members must not carry type annotations; the previous
    # ``NoSchedule: str = "NoSchedule"`` form is flagged by type checkers
    # even though runtime behavior is identical.
    NoSchedule = "NoSchedule"
    PreferNoSchedule = "PreferNoSchedule"
    NoExecute = "NoExecute"


class Taint(Base):
    """A Kubernetes node taint: key, value, and scheduling effect."""

    key: str
    value: str
    effect: TaintEffectEnum
Loading