Skip to content

Commit

Permalink
Re-enable AWS tags support (#2096)
Browse files Browse the repository at this point in the history
  • Loading branch information
iameskild authored Nov 14, 2023
1 parent b9158e4 commit f28436a
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 57 deletions.
98 changes: 47 additions & 51 deletions src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import re
import sys
import tempfile
import typing
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import pydantic

Expand Down Expand Up @@ -52,9 +51,9 @@ class DigitalOceanInputVars(schema.Base):
name: str
environment: str
region: str
tags: typing.List[str]
tags: List[str]
kubernetes_version: str
node_groups: typing.Dict[str, DigitalOceanNodeGroup]
node_groups: Dict[str, DigitalOceanNodeGroup]
kubeconfig_filename: str = get_kubeconfig_filename()


Expand Down Expand Up @@ -143,6 +142,7 @@ class AWSInputVars(schema.Base):
vpc_cidr_block: str
permissions_boundary: Optional[str] = None
kubeconfig_filename: str = get_kubeconfig_filename()
tags: Dict[str, str] = {}


def _calculate_node_groups(config: schema.Main):
Expand Down Expand Up @@ -216,7 +216,7 @@ class DigitalOceanProvider(schema.Base):
region: str
kubernetes_version: str
# Digital Ocean image slugs are listed here https://slugs.do-api.dev/
node_groups: typing.Dict[str, DigitalOceanNodeGroup] = {
node_groups: Dict[str, DigitalOceanNodeGroup] = {
"general": DigitalOceanNodeGroup(
instance="g-8vcpu-32gb", min_nodes=1, max_nodes=1
),
Expand All @@ -227,7 +227,7 @@ class DigitalOceanProvider(schema.Base):
instance="g-4vcpu-16gb", min_nodes=1, max_nodes=5
),
}
tags: typing.Optional[typing.List[str]] = []
tags: Optional[List[str]] = []

@pydantic.validator("region")
def _validate_region(cls, value):
Expand Down Expand Up @@ -289,7 +289,7 @@ class GCPCIDRBlock(schema.Base):


class GCPMasterAuthorizedNetworksConfig(schema.Base):
cidr_blocks: typing.List[GCPCIDRBlock]
cidr_blocks: List[GCPCIDRBlock]


class GCPPrivateClusterConfig(schema.Base):
Expand All @@ -314,34 +314,28 @@ class GCPNodeGroup(schema.Base):
min_nodes: pydantic.conint(ge=0) = 0
max_nodes: pydantic.conint(ge=1) = 1
preemptible: bool = False
labels: typing.Dict[str, str] = {}
guest_accelerators: typing.List[GCPGuestAccelerator] = []
labels: Dict[str, str] = {}
guest_accelerators: List[GCPGuestAccelerator] = []


class GoogleCloudPlatformProvider(schema.Base):
region: str
project: str
kubernetes_version: str
availability_zones: typing.Optional[typing.List[str]] = []
availability_zones: Optional[List[str]] = []
release_channel: str = constants.DEFAULT_GKE_RELEASE_CHANNEL
node_groups: typing.Dict[str, GCPNodeGroup] = {
node_groups: Dict[str, GCPNodeGroup] = {
"general": GCPNodeGroup(instance="n1-standard-8", min_nodes=1, max_nodes=1),
"user": GCPNodeGroup(instance="n1-standard-4", min_nodes=0, max_nodes=5),
"worker": GCPNodeGroup(instance="n1-standard-4", min_nodes=0, max_nodes=5),
}
tags: typing.Optional[typing.List[str]] = []
tags: Optional[List[str]] = []
networking_mode: str = "ROUTE"
network: str = "default"
subnetwork: typing.Optional[typing.Union[str, None]] = None
ip_allocation_policy: typing.Optional[
typing.Union[GCPIPAllocationPolicy, None]
] = None
master_authorized_networks_config: typing.Optional[
typing.Union[GCPCIDRBlock, None]
] = None
private_cluster_config: typing.Optional[
typing.Union[GCPPrivateClusterConfig, None]
] = None
subnetwork: Optional[Union[str, None]] = None
ip_allocation_policy: Optional[Union[GCPIPAllocationPolicy, None]] = None
master_authorized_networks_config: Optional[Union[GCPCIDRBlock, None]] = None
private_cluster_config: Optional[Union[GCPPrivateClusterConfig, None]] = None

@pydantic.root_validator
def validate_all(cls, values):
Expand Down Expand Up @@ -381,18 +375,18 @@ class AzureProvider(schema.Base):
kubernetes_version: str
storage_account_postfix: str
resource_group_name: str = None
node_groups: typing.Dict[str, AzureNodeGroup] = {
node_groups: Dict[str, AzureNodeGroup] = {
"general": AzureNodeGroup(instance="Standard_D8_v3", min_nodes=1, max_nodes=1),
"user": AzureNodeGroup(instance="Standard_D4_v3", min_nodes=0, max_nodes=5),
"worker": AzureNodeGroup(instance="Standard_D4_v3", min_nodes=0, max_nodes=5),
}
storage_account_postfix: str
vnet_subnet_id: typing.Optional[typing.Union[str, None]] = None
vnet_subnet_id: Optional[Union[str, None]] = None
private_cluster_enabled: bool = False
resource_group_name: typing.Optional[str] = None
tags: typing.Optional[typing.Dict[str, str]] = {}
network_profile: typing.Optional[typing.Dict[str, str]] = None
max_pods: typing.Optional[int] = None
resource_group_name: Optional[str] = None
tags: Optional[Dict[str, str]] = {}
network_profile: Optional[Dict[str, str]] = None
max_pods: Optional[int] = None

@pydantic.validator("kubernetes_version")
def _validate_kubernetes_version(cls, value):
Expand Down Expand Up @@ -440,8 +434,8 @@ class AWSNodeGroup(schema.Base):
class AmazonWebServicesProvider(schema.Base):
region: str
kubernetes_version: str
availability_zones: typing.Optional[typing.List[str]]
node_groups: typing.Dict[str, AWSNodeGroup] = {
availability_zones: Optional[List[str]]
node_groups: Dict[str, AWSNodeGroup] = {
"general": AWSNodeGroup(instance="m5.2xlarge", min_nodes=1, max_nodes=1),
"user": AWSNodeGroup(
instance="m5.xlarge", min_nodes=1, max_nodes=5, single_subnet=False
Expand All @@ -450,10 +444,11 @@ class AmazonWebServicesProvider(schema.Base):
instance="m5.xlarge", min_nodes=1, max_nodes=5, single_subnet=False
),
}
existing_subnet_ids: typing.List[str] = None
existing_security_group_ids: str = None
existing_subnet_ids: List[str] = None
existing_security_group_id: str = None
vpc_cidr_block: str = "10.10.0.0/16"
permissions_boundary: Optional[str] = None
tags: Optional[Dict[str, str]] = {}

@pydantic.root_validator
def validate_all(cls, values):
Expand Down Expand Up @@ -491,17 +486,17 @@ def validate_all(cls, values):


class LocalProvider(schema.Base):
kube_context: typing.Optional[str]
node_selectors: typing.Dict[str, KeyValueDict] = {
kube_context: Optional[str]
node_selectors: Dict[str, KeyValueDict] = {
"general": KeyValueDict(key="kubernetes.io/os", value="linux"),
"user": KeyValueDict(key="kubernetes.io/os", value="linux"),
"worker": KeyValueDict(key="kubernetes.io/os", value="linux"),
}


class ExistingProvider(schema.Base):
kube_context: typing.Optional[str]
node_selectors: typing.Dict[str, KeyValueDict] = {
kube_context: Optional[str]
node_selectors: Dict[str, KeyValueDict] = {
"general": KeyValueDict(key="kubernetes.io/os", value="linux"),
"user": KeyValueDict(key="kubernetes.io/os", value="linux"),
"worker": KeyValueDict(key="kubernetes.io/os", value="linux"),
Expand Down Expand Up @@ -532,12 +527,12 @@ class ExistingProvider(schema.Base):


class InputSchema(schema.Base):
local: typing.Optional[LocalProvider]
existing: typing.Optional[ExistingProvider]
google_cloud_platform: typing.Optional[GoogleCloudPlatformProvider]
amazon_web_services: typing.Optional[AmazonWebServicesProvider]
azure: typing.Optional[AzureProvider]
digital_ocean: typing.Optional[DigitalOceanProvider]
local: Optional[LocalProvider]
existing: Optional[ExistingProvider]
google_cloud_platform: Optional[GoogleCloudPlatformProvider]
amazon_web_services: Optional[AmazonWebServicesProvider]
azure: Optional[AzureProvider]
digital_ocean: Optional[DigitalOceanProvider]

@pydantic.root_validator(pre=True)
def check_provider(cls, values):
Expand Down Expand Up @@ -580,20 +575,20 @@ class NodeSelectorKeyValue(schema.Base):
class KubernetesCredentials(schema.Base):
host: str
cluster_ca_certifiate: str
token: typing.Optional[str]
username: typing.Optional[str]
password: typing.Optional[str]
client_certificate: typing.Optional[str]
client_key: typing.Optional[str]
config_path: typing.Optional[str]
config_context: typing.Optional[str]
token: Optional[str]
username: Optional[str]
password: Optional[str]
client_certificate: Optional[str]
client_key: Optional[str]
config_path: Optional[str]
config_context: Optional[str]


class OutputSchema(schema.Base):
node_selectors: Dict[str, NodeSelectorKeyValue]
kubernetes_credentials: KubernetesCredentials
kubeconfig_filename: str
nfs_endpoint: typing.Optional[str]
nfs_endpoint: Optional[str]


class KubernetesInfrastructureStage(NebariTerraformStage):
Expand Down Expand Up @@ -760,7 +755,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
name=self.config.escaped_project_name,
environment=self.config.namespace,
existing_subnet_ids=self.config.amazon_web_services.existing_subnet_ids,
existing_security_group_id=self.config.amazon_web_services.existing_security_group_ids,
existing_security_group_id=self.config.amazon_web_services.existing_security_group_id,
region=self.config.amazon_web_services.region,
kubernetes_version=self.config.amazon_web_services.kubernetes_version,
node_groups=[
Expand All @@ -779,6 +774,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
availability_zones=self.config.amazon_web_services.availability_zones,
vpc_cidr_block=self.config.amazon_web_services.vpc_cidr_block,
permissions_boundary=self.config.amazon_web_services.permissions_boundary,
tags=self.config.amazon_web_services.tags,
).dict()
else:
raise ValueError(f"Unknown provider: {self.config.provider}")
Expand Down
14 changes: 8 additions & 6 deletions src/_nebari/stages/infrastructure/template/aws/locals.tf
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
locals {
additional_tags = {
Project = var.name
Owner = "terraform"
Environment = var.environment
}

additional_tags = merge(
{
Project = var.name
Owner = "terraform"
Environment = var.environment
},
var.tags,
)
cluster_name = "${var.name}-${var.environment}"
}
6 changes: 6 additions & 0 deletions src/_nebari/stages/infrastructure/template/aws/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,9 @@ variable "permissions_boundary" {
type = string
default = null
}

variable "tags" {
description = "Additional tags to add to resources"
type = map(string)
default = {}
}

0 comments on commit f28436a

Please sign in to comment.