Skip to content

Commit

Permalink
fixes on ami_id
Browse files Browse the repository at this point in the history
  • Loading branch information
viniciusdc committed Sep 17, 2024
1 parent 6aafcdc commit c211fa6
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 106 deletions.
43 changes: 19 additions & 24 deletions src/_nebari/provider/cloud/amazon_web_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@
from botocore.exceptions import ClientError, EndpointConnectionError

from _nebari.constants import AWS_ENV_DOCS
from _nebari.provider.cloud.commons import (
filter_amis_by_latest_version,
filter_by_highest_supported_k8s_version,
)
from _nebari.provider.cloud.commons import filter_by_highest_supported_k8s_version
from _nebari.utils import check_environment_variables
from nebari import schema

Expand Down Expand Up @@ -114,34 +111,32 @@ def kubernetes_versions(region: str) -> List[str]:


@functools.lru_cache()
def amis(region: str, k8s_version: str, ami_type: str) -> Dict[str, str]:
def amis(region: str, k8s_version: str, ami_type: str = None) -> Dict[str, str]:
# do an ssm get-parameters-by-path to get the latest AMI for the k8s version
session = aws_session(region=region)
ssm_client = session.client("ssm")
ami_ssm_format = {
"AL2_x86_64": "/aws/service/eks/optimized-ami/{}/amazon-linux-2",
"AL2_x86_64_GPU": "/aws/service/eks/optimized-ami/{}/amazon-linux-2-gpu",
}
ami_specifier = ami_ssm_format.get(ami_type).format(k8s_version)
if ami_specifier is None:
amis = {}

if ami_type and ami_type not in ami_ssm_format:
raise ValueError(f"Unsupported ami_type: {ami_type}")

ssm_client = session.client("ssm")
paginator = ssm_client.get_paginator("get_parameters_by_path")
page_iterator = paginator.paginate(
Path=ami_specifier,
)
ssm_param_name_list = []
for page in page_iterator:
for parameter in page["Parameters"]:
values = json.loads(parameter["Value"])
ssm_param_name_list.append(
{
"Name": values["image_name"],
"Value": values["image_id"],
"LastModifiedDate": parameter["LastModifiedDate"],
}
)
return filter_amis_by_latest_version(ssm_param_name_list)
for type, ssm_path_specifier in ami_ssm_format.items():
if ami_type and ami_type != type:
continue
ami_specifier = ssm_path_specifier.format(k8s_version)
paginator = ssm_client.get_paginator("get_parameters_by_path")
page_iterator = paginator.paginate(
Path=ami_specifier,
)
for page in page_iterator:
for parameter in page["Parameters"]:
values = json.loads(parameter["Value"])
amis[values["image_id"]] = values["image_name"]
return amis


@functools.lru_cache()
Expand Down
15 changes: 0 additions & 15 deletions src/_nebari/provider/cloud/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,3 @@ def filter_by_highest_supported_k8s_version(k8s_versions_list):
if version <= HIGHEST_SUPPORTED_K8S_VERSION:
filtered_k8s_versions_list.append(k8s_version)
return filtered_k8s_versions_list


def filter_amis_by_latest_version(amis_list):
print(amis_list)
latest_amis = {}
for ami in amis_list:
version = tuple(
filter(None, re.search(r"(\d+)\.(\d+)(?:\.(\d+))?", ami["Name"]).groups())
)
if version not in latest_amis:
latest_amis[version] = ami.pop("LastModifiedDate")
else:
if ami["LastModifiedDate"] > latest_amis[version]["LastModifiedDate"]:
latest_amis[version] = ami.pop("LastModifiedDate")
return list(latest_amis.values())
59 changes: 16 additions & 43 deletions src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,53 +547,26 @@ def _check_input(cls, data: Any) -> Any:
raise ValueError(
f"Amazon Web Services availability zone={zone} is not one of {available_zones}"
)

# check if instances are valid
available_instances = amazon_web_services.instances(data["region"])
# check if instances and/or ami_ids are valid
if "node_groups" in data:
# Cache for available AMIs per ami_type
available_amis_cache = {}
available_instances = set(amazon_web_services.instances(data["region"]))
# available_amis = set(
# amazon_web_services.amis(data["region"], data["kubernetes_version"])
# )

for _, node_group in data["node_groups"].items():
instance = (
node_group["instance"]
if hasattr(node_group, "__getitem__")
else node_group.instance
)
if instance not in available_instances:
instance = node_group.get("instance")
if instance and instance not in available_instances:
raise ValueError(
f"Amazon Web Services instance {node_group.instance} not one of available instance types={available_instances}"
f"Amazon Web Services instance '{instance}' is not among the available instance types for your region or account."
)

# Check if launch_template and ami_id are provided
print(available_amis_cache)
launch_template = getattr(node_group, "launch_template", None)
if (
launch_template
and getattr(node_group, "ami_type", None) != "CUSTOM"
):
if getattr(launch_template, "ami_id", None):
ami_id = launch_template.ami_id
ami_type = getattr(node_group, "ami_type", None)

# Retrieve available AMIs from cache or API
if ami_type not in available_amis_cache:
available_amis_cache[ami_type] = amazon_web_services.amis(
region=data["region"],
k8s_version=data["kubernetes_version"],
ami_type=ami_type,
)

available_amis = available_amis_cache[ami_type]

# Validate AMI ID
if ami_id not in available_amis:
raise ValueError(
f"Amazon Web Services AMI '{ami_id}' is not among the available AMIs: {available_amis} for AMI type '{ami_type}'"
)
else:
raise ValueError(
"Launch template provided without AMI ID. Please provide an AMI ID."
)
# launch_template = node_group.get("launch_template")
# if launch_template:
# ami_id = launch_template.get("ami_id")
# if ami_id and ami_id not in available_amis:
# raise ValueError(
# f"Invalid AMI ID '{ami_id}' specified in launch_template."
# )

return data

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ resource "aws_launch_template" "main" {
if node_group.launch_template != null
}

name = each.value.name
image_id = each.value.launch_template.ami_id
name = each.value.name
image_id = each.value.launch_template.ami_id
instance_type = each.value.instance_type

vpc_security_group_ids = var.cluster_security_groups

Expand Down Expand Up @@ -66,6 +67,7 @@ resource "aws_launch_template" "main" {
)
}


resource "aws_eks_node_group" "main" {
count = length(var.node_groups)

Expand All @@ -74,12 +76,9 @@ resource "aws_eks_node_group" "main" {
node_role_arn = aws_iam_role.node-group.arn
subnet_ids = var.node_groups[count.index].single_subnet ? [element(var.cluster_subnets, 0)] : var.cluster_subnets

instance_types = [var.node_groups[count.index].instance_type]
# ami_type = var.node_groups[count.index].gpu == true ? "AL2_x86_64_GPU" :
# "AL2_x86_64"
ami_type = var.node_groups[count.index].ami_type
# disk_size = var.node_groups[count.index].launch_template == null ? 50 : null
disk_size = 50
instance_types = var.node_groups[count.index].launch_template == null ? [var.node_groups[count.index].instance_type] : null
ami_type = var.node_groups[count.index].ami_type
disk_size = var.node_groups[count.index].launch_template == null ? 50 : null

scaling_config {
min_size = var.node_groups[count.index].min_size
Expand All @@ -88,22 +87,13 @@ resource "aws_eks_node_group" "main" {
}

# Only set launch_template if its node_group counterpart parameter is not null
# dynamic "launch_template" {
# for_each = var.node_groups[count.index].launch_template != null ? [var.node_groups[count.index].launch_template] : []
# content {
# id = aws_launch_template.main[each.key].id
# version = aws_launch_template.main[each.key].latest_version
# }
# }
# The "each" object can be used only in "module" or "resource" blocks, and only when
# the "for_each" argument is set.
# dynamic "launch_template" {
# for_each = var.node_groups[count.index].launch_template != null ? [var.node_groups[count.index].launch_template] : []
# content {
# id = aws_launch_template.main
# version = launch_template.latest_version
# }
# }
dynamic "launch_template" {
for_each = var.node_groups[count.index].launch_template != null ? [var.node_groups[count.index].launch_template] : []
content {
id = aws_launch_template.main[var.node_groups[count.index].name].id
version = aws_launch_template.main[var.node_groups[count.index].name].latest_version
}
}

labels = {
"dedicated" = var.node_groups[count.index].name
Expand Down

0 comments on commit c211fa6

Please sign in to comment.