Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add taint to user and worker nodes #2605

Open
wants to merge 32 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
5000f06
save progress
Adam-D-Lewis Jun 26, 2024
7ce8555
Merge branch 'develop' into node-taint
Adam-D-Lewis Aug 16, 2024
a661514
fix node taint check
Adam-D-Lewis Aug 16, 2024
fb55fab
Merge branch 'develop' into node-taint
Adam-D-Lewis Aug 19, 2024
7f1800d
fix node taints on gcp
Adam-D-Lewis Aug 19, 2024
40940f6
add latest changes
Adam-D-Lewis Aug 19, 2024
cdac5c6
merge develop
Adam-D-Lewis Aug 21, 2024
6382c7b
allow daemonsets to run on user node group
Adam-D-Lewis Aug 21, 2024
e9d9dd9
recreate node groups when taints change
Adam-D-Lewis Aug 21, 2024
c55cd5f
quick attempt to get scheduler running on tanted worker node group
Adam-D-Lewis Aug 21, 2024
57e6e09
Merge branch 'main' into node-taint
Adam-D-Lewis Oct 25, 2024
a1370c9
add default options to options_handler
Adam-D-Lewis Oct 25, 2024
0e7e11c
add comments
Adam-D-Lewis Oct 28, 2024
adb9d74
rename variable
Adam-D-Lewis Oct 31, 2024
7944071
add comment
Adam-D-Lewis Oct 31, 2024
fa81fb9
make work for all providers
Adam-D-Lewis Oct 31, 2024
da9fd82
move var back
Adam-D-Lewis Oct 31, 2024
6a1f81d
move var back
Adam-D-Lewis Oct 31, 2024
b4c08f3
move var back
Adam-D-Lewis Oct 31, 2024
9bae2a1
move var back
Adam-D-Lewis Oct 31, 2024
b3dbeda
add reference
Adam-D-Lewis Oct 31, 2024
97858d0
refactor
Adam-D-Lewis Nov 1, 2024
4ac7b9c
various fixes for aws and azure providers
Adam-D-Lewis Nov 1, 2024
480647b
Merge branch 'main' into node-taint
Adam-D-Lewis Nov 1, 2024
f6b9a4f
add taint conversion for AWS
Adam-D-Lewis Nov 4, 2024
e752a3a
add DEFAULT_.*_TAINT vars
Adam-D-Lewis Nov 4, 2024
59daa0c
clean up fixed TODOs
Adam-D-Lewis Nov 4, 2024
e05f143
more clean up
Adam-D-Lewis Nov 4, 2024
3a4ae6b
Merge branch 'main' into node-taint
Adam-D-Lewis Nov 4, 2024
f3cb2e9
fix test
Adam-D-Lewis Nov 4, 2024
b125e8c
fix test error
Adam-D-Lewis Nov 4, 2024
8f9f846
Merge branch 'main' into node-taint
dcmcand Nov 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 61 additions & 21 deletions src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,33 @@ class ExistingInputVars(schema.Base):
kube_context: str


class DigitalOceanNodeGroup(schema.Base):
Copy link
Member Author

@Adam-D-Lewis Adam-D-Lewis Aug 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate class, so I deleted it

class NodeGroup(schema.Base):
instance: str
min_nodes: int
max_nodes: int
min_nodes: Annotated[int, Field(ge=0)] = 0
max_nodes: Annotated[int, Field(ge=1)] = 1
taints: Optional[List[schema.Taint]] = []

@field_validator("taints", mode="before")
def validate_taint_strings(cls, value: List[str | schema.Taint]):
TAINT_STR_REGEX = re.compile(r"(\w+)=(\w+):(\w+)")
parsed_taints = []
for taint in value:
if not isinstance(taint, (str, schema.Taint)):
raise ValueError(
f"Unable to parse type: {type(taint)} as taint. Must be a string or Taint object."
)

if isinstance(taint, schema.Taint):
parsed_taint = taint
elif isinstance(taint, str):
match = TAINT_STR_REGEX.match(taint)
if not match:
raise ValueError(f"Invalid taint string: {taint}")
key, value, effect = match.groups()
parsed_taint = schema.Taint(key=key, value=value, effect=effect)
parsed_taints.append(parsed_taint)

return parsed_taints


class DigitalOceanInputVars(schema.Base):
Expand All @@ -53,7 +76,7 @@ class DigitalOceanInputVars(schema.Base):
region: str
tags: List[str]
kubernetes_version: str
node_groups: Dict[str, DigitalOceanNodeGroup]
node_groups: Dict[str, "DigitalOceanNodeGroup"]
kubeconfig_filename: str = get_kubeconfig_filename()


Expand All @@ -62,10 +85,26 @@ class GCPNodeGroupInputVars(schema.Base):
instance_type: str
min_size: int
max_size: int
node_taints: List[dict]
labels: Dict[str, str]
preemptible: bool
guest_accelerators: List["GCPGuestAccelerator"]

@field_validator("node_taints", mode="before")
def convert_taints(cls, value: Optional[List[schema.Taint]]):
return [
dict(
key=taint.key,
value=taint.value,
effect={
schema.TaintEffectEnum.NoSchedule: "NO_SCHEDULE",
schema.TaintEffectEnum.PreferNoSchedule: "PREFER_NO_SCHEDULE",
schema.TaintEffectEnum.NoExecute: "NO_EXECUTE",
}[taint.effect],
)
for taint in value
]


class GCPPrivateClusterConfig(schema.Base):
enable_private_nodes: bool
Expand Down Expand Up @@ -211,16 +250,14 @@ class KeyValueDict(schema.Base):
value: str


class DigitalOceanNodeGroup(schema.Base):
class DigitalOceanNodeGroup(NodeGroup):
"""Representation of a node group with Digital Ocean

- Kubernetes limits: https://docs.digitalocean.com/products/kubernetes/details/limits/
- Available instance types: https://slugs.do-api.dev/
"""

instance: str
min_nodes: Annotated[int, Field(ge=1)] = 1
max_nodes: Annotated[int, Field(ge=1)] = 1


DEFAULT_DO_NODE_GROUPS = {
Expand Down Expand Up @@ -305,19 +342,26 @@ class GCPGuestAccelerator(schema.Base):
count: Annotated[int, Field(ge=1)] = 1


class GCPNodeGroup(schema.Base):
instance: str
min_nodes: Annotated[int, Field(ge=0)] = 0
max_nodes: Annotated[int, Field(ge=1)] = 1
class GCPNodeGroup(NodeGroup):
preemptible: bool = False
labels: Dict[str, str] = {}
guest_accelerators: List[GCPGuestAccelerator] = []


DEFAULT_GCP_NODE_GROUPS = {
"general": GCPNodeGroup(instance="e2-highmem-4", min_nodes=1, max_nodes=1),
"user": GCPNodeGroup(instance="e2-standard-4", min_nodes=0, max_nodes=5),
"worker": GCPNodeGroup(instance="e2-standard-4", min_nodes=0, max_nodes=5),
"user": GCPNodeGroup(
instance="e2-standard-4",
min_nodes=0,
max_nodes=5,
taints=[schema.Taint(key="dedicated", value="user", effect="NoSchedule")],
),
"worker": GCPNodeGroup(
instance="e2-standard-4",
min_nodes=0,
max_nodes=5,
taints=[schema.Taint(key="dedicated", value="worker", effect="NoSchedule")],
),
Adam-D-Lewis marked this conversation as resolved.
Show resolved Hide resolved
}


Expand Down Expand Up @@ -355,10 +399,8 @@ def _check_input(cls, data: Any) -> Any:
return data


class AzureNodeGroup(schema.Base):
instance: str
min_nodes: int
max_nodes: int
class AzureNodeGroup(NodeGroup):
pass


DEFAULT_AZURE_NODE_GROUPS = {
Expand Down Expand Up @@ -426,10 +468,7 @@ def _validate_tags(cls, value: Optional[Dict[str, str]]) -> Dict[str, str]:
return value if value is None else azure_cloud.validate_tags(value)


class AWSNodeGroup(schema.Base):
instance: str
min_nodes: int = 0
max_nodes: int
class AWSNodeGroup(NodeGroup):
gpu: bool = False
single_subnet: bool = False
permissions_boundary: Optional[str] = None
Expand Down Expand Up @@ -738,6 +777,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
instance_type=node_group.instance,
min_size=node_group.min_nodes,
max_size=node_group.max_nodes,
node_taints=node_group.taints,
preemptible=node_group.preemptible,
guest_accelerators=node_group.guest_accelerators,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ resource "google_container_node_pool" "main" {

oauth_scopes = local.node_group_oauth_scopes

dynamic "taint" {
for_each = local.merged_node_groups[count.index].node_taints
content {
key = taint.value.key
value = taint.value.value
effect = taint.value.effect
}
}

metadata = {
disable-legacy-endpoints = "true"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,23 @@ variable "node_groups" {
min_size = 1
max_size = 1
labels = {}
node_taints = []
},
{
name = "user"
instance_type = "n1-standard-2"
min_size = 0
max_size = 2
labels = {}
node_taints = [] # TODO: Do this for other cloud providers
},
{
name = "worker"
instance_type = "n1-standard-2"
min_size = 0
max_size = 5
labels = {}
node_taints = []
}
]
}
Expand Down
22 changes: 22 additions & 0 deletions src/_nebari/stages/kubernetes_services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,19 @@ class CondaStoreInputVars(schema.Base):
)


class TolerationOperatorEnum(str, enum.Enum):
Equal = "Equal"
Exists = "Exists"

@classmethod
def to_yaml(cls, representer, node):
return representer.represent_str(node.value)


class Toleration(schema.Taint):
operator: TolerationOperatorEnum = TolerationOperatorEnum.Equal


class JupyterhubInputVars(schema.Base):
jupyterhub_theme: Dict[str, Any] = Field(alias="jupyterhub-theme")
jupyterlab_image: ImageNameTag = Field(alias="jupyterlab-image")
Expand All @@ -405,6 +418,9 @@ class JupyterhubInputVars(schema.Base):
jhub_apps_enabled: bool = Field(alias="jhub-apps-enabled")
cloud_provider: str = Field(alias="cloud-provider")
jupyterlab_preferred_dir: Optional[str] = Field(alias="jupyterlab-preferred-dir")
node_taint_tolerations: Optional[List[Toleration]] = Field(
alias="node-taint-tolerations"
)


class DaskGatewayInputVars(schema.Base):
Expand Down Expand Up @@ -565,6 +581,12 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
jupyterlab_default_settings=self.config.jupyterlab.default_settings,
jupyterlab_gallery_settings=self.config.jupyterlab.gallery_settings,
jupyterlab_preferred_dir=self.config.jupyterlab.preferred_dir,
node_taint_tolerations=[
Toleration(**taint.model_dump())
for taint in self.config.google_cloud_platform.node_groups[
"user"
].taints
], # TODO: support other cloud providers
)

dask_gateway_vars = DaskGatewayInputVars(
Expand Down
11 changes: 11 additions & 0 deletions src/_nebari/stages/kubernetes_services/template/jupyterhub.tf
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@ variable "idle-culler-settings" {
type = any
}

variable "node-taint-tolerations" {
description = "Node taint toleration"
type = list(object({
key = string
operator = string
value = string
effect = string
}))
}

module "kubernetes-nfs-server" {
count = var.jupyterhub-shared-endpoint == null ? 1 : 0

Expand Down Expand Up @@ -137,6 +147,7 @@ module "jupyterhub" {
conda-store-service-name = module.kubernetes-conda-store-server.service_name
conda-store-jhub-apps-token = module.kubernetes-conda-store-server.service-tokens.jhub-apps
jhub-apps-enabled = var.jhub-apps-enabled
node-taint-tolerations = var.node-taint-tolerations

extra-mounts = {
"/etc/dask" = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,25 @@ def base_profile_extra_mounts():
}


def node_taint_tolerations():
tolerations = z2jh.get_config("custom.node-taint-tolerations")

if not tolerations:
return {}

return {
"tolerations": [
{
"key": taint["key"],
"operator": taint["operator"],
"value": taint["value"],
"effect": taint["effect"],
}
for taint in tolerations
]
}


def configure_user_provisioned_repositories(username):
# Define paths and configurations
pvc_home_mount_path = f"home/{username}"
Expand Down Expand Up @@ -519,6 +538,7 @@ def render_profile(profile, username, groups, keycloak_profilenames):
configure_user(username, groups),
configure_user_provisioned_repositories(username),
profile_kubespawner_override,
node_taint_tolerations(),
],
{},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ resource "helm_release" "jupyterhub" {
conda-store-jhub-apps-token = var.conda-store-jhub-apps-token
jhub-apps-enabled = var.jhub-apps-enabled
initial-repositories = var.initial-repositories
node-taint-tolerations = var.node-taint-tolerations
skel-mount = {
name = kubernetes_config_map.etc-skel.metadata.0.name
namespace = kubernetes_config_map.etc-skel.metadata.0.namespace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,13 @@ variable "initial-repositories" {
type = string
default = "[]"
}

variable "node-taint-tolerations" {
description = "Node taint toleration"
type = list(object({
key = string
operator = string
value = string
effect = string
}))
}
13 changes: 13 additions & 0 deletions src/nebari/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,16 @@ def is_version_accepted(v):
for deployment with the current Nebari package.
"""
return Main.is_version_accepted(v)


# TODO: Make sure the taint is actually applied to the nodes for each provider
class TaintEffectEnum(str, enum.Enum):
NoSchedule: str = "NoSchedule"
PreferNoSchedule: str = "PreferNoSchedule"
NoExecute: str = "NoExecute"


class Taint(Base):
key: str
value: str
effect: TaintEffectEnum
Loading