Skip to content

Commit

Permalink
Address issue with AWS instance type schema (#2787)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
viniciusdc and pre-commit-ci[bot] authored Oct 29, 2024
1 parent 55094c3 commit c384b06
Showing 1 changed file with 16 additions and 20 deletions.
36 changes: 16 additions & 20 deletions src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class AzureInputVars(schema.Base):
workload_identity_enabled: bool = False


class AWSAmiTypes(enum.Enum):
class AWSAmiTypes(str, enum.Enum):
AL2_x86_64 = "AL2_x86_64"
AL2_x86_64_GPU = "AL2_x86_64_GPU"
CUSTOM = "CUSTOM"
Expand All @@ -151,25 +151,17 @@ class AWSNodeGroupInputVars(schema.Base):
ami_type: Optional[AWSAmiTypes] = None
launch_template: Optional[AWSNodeLaunchTemplate] = None

@field_validator("ami_type", mode="before")
@classmethod
def _infer_and_validate_ami_type(cls, value, values) -> str:
gpu_enabled = values.get("gpu", False)

# Auto-set ami_type if not provided
if not value:
if values.get("launch_template") and values["launch_template"].ami_id:
return "CUSTOM"
if gpu_enabled:
return "AL2_x86_64_GPU"
return "AL2_x86_64"

# Explicit validation
if value == "AL2_x86_64" and gpu_enabled:
raise ValueError(
"ami_type 'AL2_x86_64' cannot be used with GPU enabled (gpu=True)."
)
return value

def construct_aws_ami_type(gpu_enabled: bool, launch_template: AWSNodeLaunchTemplate):
"""Construct the AWS AMI type based on the provided parameters."""

if launch_template and launch_template.ami_id:
return "CUSTOM"

if gpu_enabled:
return "AL2_x86_64_GPU"

return "AL2_x86_64"


class AWSInputVars(schema.Base):
Expand Down Expand Up @@ -858,6 +850,10 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
single_subnet=node_group.single_subnet,
permissions_boundary=node_group.permissions_boundary,
launch_template=node_group.launch_template,
ami_type=construct_aws_ami_type(
gpu_enabled=node_group.gpu,
launch_template=node_group.launch_template,
),
)
for name, node_group in self.config.amazon_web_services.node_groups.items()
],
Expand Down

0 comments on commit c384b06

Please sign in to comment.