diff --git a/src/_nebari/deploy.py b/src/_nebari/deploy.py index 394fe0cce2..bc12856868 100644 --- a/src/_nebari/deploy.py +++ b/src/_nebari/deploy.py @@ -13,11 +13,8 @@ def deploy_configuration( config: schema.Main, stages: List[hookspecs.NebariStage], - dns_provider, - dns_auto_provision, disable_prompt: bool = False, disable_checks: bool = False, - skip_remote_state_provision: bool = False, ): if config.prevent_deploy: raise ValueError( @@ -53,7 +50,7 @@ def deploy_configuration( with contextlib.ExitStack() as stack: for stage in stages: s = stage(output_directory=pathlib.Path.cwd(), config=config) - stack.enter_context(s.deploy(stage_outputs)) + stack.enter_context(s.deploy(stage_outputs, disable_prompt)) if not disable_checks: s.check(stage_outputs) diff --git a/src/_nebari/provider/cloud/amazon_web_services.py b/src/_nebari/provider/cloud/amazon_web_services.py index ad5d0b3a39..5719f1699b 100644 --- a/src/_nebari/provider/cloud/amazon_web_services.py +++ b/src/_nebari/provider/cloud/amazon_web_services.py @@ -1,6 +1,8 @@ import functools import os +import re import time +from typing import List import boto3 from botocore.exceptions import ClientError @@ -8,6 +10,9 @@ from _nebari import constants from _nebari.provider.cloud.commons import filter_by_highest_supported_k8s_version +MAX_RETRIES = 5 +DELAY = 5 + def check_credentials(): for variable in { @@ -22,16 +27,40 @@ def check_credentials(): ) +def aws_session(digitalocean_region: str = None): + if digitalocean_region: + aws_access_key_id = os.environ["SPACES_ACCESS_KEY_ID"] + aws_secret_access_key = os.environ["SPACES_SECRET_ACCESS_KEY"] + region = digitalocean_region + aws_session_token = None + else: + check_credentials() + aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"] + aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"] + aws_session_token = os.environ.get("AWS_SESSION_TOKEN") + region = os.environ["AWS_DEFAULT_REGION"] + + return boto3.Session( + region_name=region, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + ) + + @functools.lru_cache() def regions(): - client = boto3.client("ec2") - response = client.describe_regions() - return {_["RegionName"]: _["RegionName"] for _ in response["Regions"]} + session = aws_session() + ec2_client = session.client("ec2") + regions = ec2_client.describe_regions()["Regions"] + return {_["RegionName"]: _["RegionName"] for _ in regions} @functools.lru_cache() def zones(): - client = boto3.client("ec2") + session = aws_session() + client = session.client("ec2") + response = client.describe_availability_zones() return {_["ZoneName"]: _["ZoneName"] for _ in response["AvailabilityZones"]} @@ -40,8 +69,9 @@ def zones(): def kubernetes_versions(): """Return list of available kubernetes supported by cloud provider. Sorted from oldest to latest.""" # AWS SDK (boto3) currently doesn't offer an intuitive way to list available kubernetes version. This implementation grabs kubernetes versions for specific EKS addons. It will therefore always be (at the very least) a subset of all kubernetes versions still supported by AWS. 
+ session = aws_session() + client = session.client("eks") - client = boto3.client("eks") supported_kubernetes_versions = list() available_addons = client.describe_addon_versions() for addon in available_addons.get("addons", None): @@ -57,7 +87,8 @@ def kubernetes_versions(): @functools.lru_cache() def instances(): - client = boto3.client("ec2") + session = aws_session() + client = session.client("ec2") paginator = client.get_paginator("describe_instance_types") instance_types = sorted( [j["InstanceType"] for i in paginator.paginate() for j in i["InstanceTypes"]] @@ -65,57 +96,798 @@ def instances(): return {t: t for t in instance_types} -def aws_session(region: str, digitalocean: bool = False): - if digitalocean: - aws_access_key_id = os.environ["SPACES_ACCESS_KEY_ID"] - aws_secret_access_key = os.environ["SPACES_SECRET_ACCESS_KEY"] - else: - aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"] - aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"] +def aws_get_vpc_id(name: str, namespace: str) -> str: + cluster_name = f"{name}-{namespace}" + session = aws_session() + client = session.client("ec2") + response = client.describe_vpcs() - return boto3.session.Session( - region_name=region, - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - ) + for vpc in response["Vpcs"]: + tags = vpc.get("Tags", []) + for tag in tags: + if tag["Key"] == "Name" and tag["Value"] == cluster_name: + return vpc["VpcId"] + + +def aws_get_subnet_ids(name: str, namespace: str) -> List[str]: + session = aws_session() + client = session.client("ec2") + response = client.describe_subnets() + + subnet_ids = [] + required_tags = 0 + for subnet in response["Subnets"]: + tags = subnet.get("Tags", []) + for tag in tags: + if ( + tag["Key"] == "Project" + and tag["Value"] == name + or tag["Key"] == "Environment" + and tag["Value"] == namespace + ): + required_tags += 1 + if required_tags == 2: + subnet_ids.append(subnet["SubnetId"]) + required_tags = 0 + + return subnet_ids + + +def aws_get_route_table_ids(name: str, namespace: str) -> List[str]: + cluster_name = f"{name}-{namespace}" + session = aws_session() + client = session.client("ec2") + response = client.describe_route_tables() + + routing_table_ids = [] + for routing_table in response["RouteTables"]: + tags = routing_table.get("Tags", []) + for tag in tags: + if tag["Key"] == "Name" and tag["Value"] == cluster_name: + routing_table_ids.append(routing_table["RouteTableId"]) + + return routing_table_ids + + +def aws_get_internet_gateway_ids(name: str, namespace: str) -> List[str]: + cluster_name = f"{name}-{namespace}" + session = aws_session() + client = session.client("ec2") + response = client.describe_internet_gateways() + + internet_gateways = [] + for internet_gateway in response["InternetGateways"]: + tags = internet_gateway.get("Tags", []) + for tag in tags: + if tag["Key"] == "Name" and tag["Value"] == cluster_name: + internet_gateways.append(internet_gateway["InternetGatewayId"]) + + return internet_gateways + + +def aws_get_security_group_ids(name: str, namespace: str) -> List[str]: + cluster_name = f"{name}-{namespace}" + session = aws_session() + client = session.client("ec2") + response = client.describe_security_groups() + + security_group_ids = [] + for security_group in response["SecurityGroups"]: + tags = security_group.get("Tags", []) + for tag in tags: + if tag["Key"] == "Name" and tag["Value"] == cluster_name: + security_group_ids.append(security_group["GroupId"]) + + return security_group_ids + + +def 
aws_get_load_balancer_name(vpc_id: str) -> str:
+    if not vpc_id:
+        print("No VPC ID provided. Exiting...")
+        return
+
+    session = aws_session()
+    client = session.client("elb")
+    response = client.describe_load_balancers()["LoadBalancerDescriptions"]
+
+    for load_balancer in response:
+        if load_balancer["VPCId"] == vpc_id:
+            return load_balancer["LoadBalancerName"]
+
+
+def aws_get_efs_ids(name: str, namespace: str) -> List[str]:
+    session = aws_session()
+    client = session.client("efs")
+    response = client.describe_file_systems()
+
+    efs_ids = []
+    required_tags = 0
+    for efs in response["FileSystems"]:
+        tags = efs.get("Tags", [])
+        for tag in tags:
+            if (
+                tag["Key"] == "Project"
+                and tag["Value"] == name
+                or tag["Key"] == "Environment"
+                and tag["Value"] == namespace
+            ):
+                required_tags += 1
+        if required_tags == 2:
+            efs_ids.append(efs["FileSystemId"])
+        required_tags = 0
+
+    return efs_ids
+
+
+def aws_get_efs_mount_target_ids(efs_id: str) -> List[str]:
+    if not efs_id:
+        print("No EFS ID provided. Exiting...")
+        return
+
+    session = aws_session()
+    client = session.client("efs")
+    response = client.describe_mount_targets(FileSystemId=efs_id)
+
+    mount_target_ids = []
+    for mount_target in response["MountTargets"]:
+        mount_target_ids.append(mount_target["MountTargetId"])
+
+    return mount_target_ids
+
+
+def aws_get_ec2_volume_ids(name: str, namespace: str) -> List[str]:
+    cluster_name = f"{name}-{namespace}"
+    session = aws_session()
+    client = session.client("ec2")
+    response = client.describe_volumes()
+
+    volume_ids = []
+    for volume in response["Volumes"]:
+        tags = volume.get("Tags", [])
+        for tag in tags:
+            if tag["Key"] == "KubernetesCluster" and tag["Value"] == cluster_name:
+                volume_ids.append(volume["VolumeId"])
+
+    return volume_ids
+
+
+def aws_get_iam_policy(name: str = None, pattern: str = None) -> str:
+    session = aws_session()
+    client = session.client("iam")
+    response = client.list_policies(Scope="Local")
+
+    for policy in response["Policies"]:
+        if (name and policy["PolicyName"] == name) or (
+            pattern and re.match(pattern, policy["PolicyName"])
+        ):
+            return policy["Arn"]
+
+
+def aws_delete_load_balancer(name: str, namespace: str):
+    vpc_id = aws_get_vpc_id(name, namespace)
+    if not vpc_id:
+        print("No VPC ID provided. Exiting...")
+        return
+
+    load_balancer_name = aws_get_load_balancer_name(vpc_id)
+
+    session = aws_session()
+    client = session.client("elb")
+
+    try:
+        # deletion is asynchronous; completion is confirmed in the retry loop below
+        client.delete_load_balancer(LoadBalancerName=load_balancer_name)
+        print(f"Initiated deletion for load balancer {load_balancer_name}")
+    except ClientError as e:
+        if "ResourceNotFoundException" in str(e):
+            print(f"Load balancer {load_balancer_name} not found. Exiting...")
+            return
+        else:
+            raise e
+
+    retries = 0
+    while retries < MAX_RETRIES:
+        try:
+            client.describe_load_balancers(LoadBalancerNames=[load_balancer_name])
+            print(f"Waiting for load balancer {load_balancer_name} to be deleted...")
+            sleep_time = DELAY * (2**retries)
+            time.sleep(sleep_time)
+        except ClientError as e:
+            if "ResourceNotFoundException" in str(e):
+                print(f"Load balancer {load_balancer_name} deleted successfully")
+                return
+            else:
+                raise e
+        retries += 1
+
+
+def aws_delete_efs_mount_targets(efs_id: str):
+    if not efs_id:
+        print("No EFS provided. 
Exiting...") + return + + session = aws_session() + client = session.client("efs") + + mount_target_ids = aws_get_efs_mount_target_ids(efs_id) + for mount_target_id in mount_target_ids: + try: + client.delete_mount_target(MountTargetId=mount_target_id) + print(f"Initiated deletion for mount target {mount_target_id}") + except ClientError as e: + if "MountTargetNotFound" in str(e): + print(f"Mount target {mount_target_id} not found. Exiting...") + else: + raise e + + retries = 0 + while retries < MAX_RETRIES: + mount_target_ids = aws_get_efs_mount_target_ids(efs_id) + if len(mount_target_ids) == 0: + print(f"All mount targets for EFS {efs_id} deleted successfully") + return + else: + print(f"Waiting for mount targets for EFS {efs_id} to be deleted...") + sleep_time = DELAY * (2**retries) + time.sleep(sleep_time) + retries += 1 + + +def aws_delete_efs_file_system(efs_id: str): + if not efs_id: + print("No EFS provided. Exiting...") + return + + session = aws_session() + client = session.client("efs") + + try: + client.delete_file_system(FileSystemId=efs_id) + print(f"Initiated deletion for EFS {efs_id}") + except ClientError as e: + if "FileSystemNotFound" in str(e): + print(f"EFS {efs_id} not found. Exiting...") + return + else: + raise e + + retries = 0 + while retries < MAX_RETRIES: + try: + client.describe_file_systems(FileSystemId=efs_id) + print(f"Waiting for EFS {efs_id} to be deleted...") + sleep_time = DELAY * (2**retries) + time.sleep(sleep_time) + except ClientError as e: + if "FileSystemNotFound" in str(e): + print(f"EFS {efs_id} deleted successfully") + return + else: + raise e + retries += 1 + + +def aws_delete_efs(name: str, namespace: str): + efs_ids = aws_get_efs_ids(name, namespace) + for efs_id in efs_ids: + aws_delete_efs_mount_targets(efs_id) + aws_delete_efs_file_system(efs_id) + + +def aws_delete_subnets(name: str, namespace: str): + session = aws_session() + client = session.client("ec2") + + vpc_id = aws_get_vpc_id(name, namespace) + subnet_ids = aws_get_subnet_ids(name, namespace) + for subnet_id in subnet_ids: + try: + client.delete_subnet(SubnetId=subnet_id) + print(f"Initiated deletion for subnet {subnet_id}") + except ClientError as e: + if "InvalidSubnetID.NotFound" in str(e): + print(f"Subnet {subnet_id} not found. Exiting...") + else: + raise e + + retries = 0 + while retries < MAX_RETRIES: + subnet_ids = aws_get_subnet_ids(name, namespace) + if len(subnet_ids) == 0: + print(f"All subnets for VPC {vpc_id} deleted successfully") + return + else: + print(f"Waiting for subnets for VPC {vpc_id} to be deleted...") + sleep_time = DELAY * (2**retries) + time.sleep(sleep_time) + retries += 1 + + +def aws_delete_route_tables(name: str, namespace: str): + session = aws_session() + client = session.client("ec2") + + vpc_id = aws_get_vpc_id(name, namespace) + route_table_ids = aws_get_route_table_ids(name, namespace) + for route_table_id in route_table_ids: + try: + client.delete_route_table(RouteTableId=route_table_id) + print(f"Initiated deletion for route table {route_table_id}") + except ClientError as e: + if "InvalidRouteTableID.NotFound" in str(e): + print(f"Route table {route_table_id} not found. 
Exiting...") + else: + raise e + + retries = 0 + while retries < MAX_RETRIES: + route_table_ids = aws_get_route_table_ids(name, namespace) + if len(route_table_ids) == 0: + print(f"All route tables for VPC {vpc_id} deleted successfully") + return + else: + print(f"Waiting for route tables for VPC {vpc_id} to be deleted...") + sleep_time = DELAY * (2**retries) + time.sleep(sleep_time) + retries += 1 + + +def aws_delete_internet_gateways(name: str, namespace: str): + session = aws_session() + client = session.client("ec2") + + vpc_id = aws_get_vpc_id(name, namespace) + internet_gateway_ids = aws_get_internet_gateway_ids(name, namespace) + for internet_gateway_id in internet_gateway_ids: + try: + client.detach_internet_gateway( + InternetGatewayId=internet_gateway_id, VpcId=vpc_id + ) + client.delete_internet_gateway(InternetGatewayId=internet_gateway_id) + print( + f"Initiated deletion for internet gateway {internet_gateway_id} from VPC {vpc_id}" + ) + except ClientError as e: + if "InvalidInternetGatewayID.NotFound" in str(e): + print(f"Internet gateway {internet_gateway_id} not found. Exiting...") + else: + raise e + + retries = 0 + while retries < MAX_RETRIES: + internet_gateway_ids = aws_get_internet_gateway_ids(name, namespace) + if len(internet_gateway_ids) == 0: + print(f"All internet gateways for VPC {vpc_id} deleted successfully") + return + else: + print(f"Waiting for internet gateways for VPC {vpc_id} to be detached...") + sleep_time = DELAY * (2**retries) + time.sleep(sleep_time) + retries += 1 + + +def aws_delete_security_groups(name: str, namespace: str): + session = aws_session() + client = session.client("ec2") + + vpc_id = aws_get_vpc_id(name, namespace) + security_group_ids = aws_get_security_group_ids(name, namespace) + for security_group_id in security_group_ids: + try: + client.delete_security_group(GroupId=security_group_id) + print(f"Initiated deletion for security group {security_group_id}") + except ClientError as e: + if "InvalidGroupID.NotFound" in str(e): + print(f"Security group {security_group_id} not found. Exiting...") + else: + raise e + retries = 0 + while retries < MAX_RETRIES: + security_group_ids = aws_get_security_group_ids(name, namespace) + if len(security_group_ids) == 0: + print(f"All security groups for VPC {vpc_id} deleted successfully") + return + else: + print(f"Waiting for security groups for VPC {vpc_id} to be deleted...") + sleep_time = DELAY * (2**retries) + time.sleep(sleep_time) + retries += 1 + + +def aws_delete_vpc(name: str, namespace: str): + session = aws_session() + client = session.client("ec2") + + vpc_id = aws_get_vpc_id(name, namespace) + if vpc_id is None: + print(f"No VPC {vpc_id} provided. Exiting...") + return + + try: + client.delete_vpc(VpcId=vpc_id) + print(f"Initiated deletion for VPC {vpc_id}") + except ClientError as e: + if "InvalidVpcID.NotFound" in str(e): + print(f"VPC {vpc_id} not found. 
Exiting...") + else: + raise e + + retries = 0 + while retries < MAX_RETRIES: + vpc_id = aws_get_vpc_id(name, namespace) + if vpc_id is None: + print(f"VPC {vpc_id} deleted successfully") + return + else: + print(f"Waiting for VPC {vpc_id} to be deleted...") + sleep_time = DELAY * (2**retries) + time.sleep(sleep_time) + retries += 1 + + +def aws_delete_dynamodb_table(name: str): + session = aws_session() + client = session.client("dynamodb") + + try: + client.delete_table(TableName=name) + print(f"Initiated deletion for DynamoDB table {name}") + except ClientError as e: + if "ResourceNotFoundException" in str(e): + print(f"DynamoDB table {name} not found. Exiting...") + else: + raise e + + retries = 0 + while retries < MAX_RETRIES: + try: + client.describe_table(TableName=name) + print(f"Waiting for DynamoDB table {name} to be deleted...") + sleep_time = DELAY * (2**retries) + time.sleep(sleep_time) + except ClientError as e: + if "ResourceNotFoundException" in str(e): + print(f"DynamoDB table {name} deleted successfully") + return + else: + raise e + retries += 1 + + +def aws_delete_ec2_volumes(name: str, namespace: str): + session = aws_session() + client = session.client("ec2") + + volume_ids = aws_get_ec2_volume_ids(name, namespace) + for volume_id in volume_ids: + try: + client.delete_volume(VolumeId=volume_id) + print(f"Initiated deletion for volume {volume_id}") + except ClientError as e: + if "InvalidVolume.NotFound" in str(e): + print(f"Volume {volume_id} not found. Exiting...") + else: + raise e + + retries = 0 + while retries < MAX_RETRIES: + volume_ids = aws_get_ec2_volume_ids(name, namespace) + if len(volume_ids) == 0: + print("All volumes deleted successfully") + return + else: + print("Waiting for volumes to be deleted...") + sleep_time = DELAY * (2**retries) + time.sleep(sleep_time) + retries += 1 -def delete_aws_s3_bucket( + +def aws_delete_s3_objects( bucket_name: str, - region: str, endpoint: str = None, - digitalocean: bool = False, + digitalocean_region: str = None, ): - MAX_RETRIES = 5 - DELAY = 5 + session = aws_session(digitalocean_region=digitalocean_region) + s3 = session.client("s3", endpoint_url=endpoint) + + try: + s3_objects = s3.list_objects(Bucket=bucket_name) + s3_objects = s3_objects.get("Contents") + if s3_objects: + for obj in s3_objects: + s3.delete_object(Bucket=bucket_name, Key=obj["Key"]) + + except ClientError as e: + if e.response["Error"]["Code"] == "NoSuchBucket": + print(f"Bucket {bucket_name} not found. Exiting...") + else: + raise e + + try: + versioned_objects = s3.list_object_versions(Bucket=bucket_name) + for version in versioned_objects.get("DeleteMarkers", []): + print(version) + s3.delete_object( + Bucket=bucket_name, Key=version["Key"], VersionId=version["VersionId"] + ) + except ClientError as e: + if e.response["Error"]["Code"] == "NoSuchBucket": + print(f"Bucket {bucket_name} not found. 
Exiting...") + else: + raise e + + retries = 0 + while retries < MAX_RETRIES: + try: + objs = s3.list_objects(Bucket=bucket_name)["ResponseMetadata"].get( + "Contents" + ) + if objs is None: + print("Bucket objects all deleted successfully") + return + sleep_time = DELAY * (2**retries) + time.sleep(sleep_time) + except ClientError as e: + if e.response["Error"]["Code"] == "NoSuchBucket": + print(f"Bucket {bucket_name} deleted successfully") + return + else: + raise e + retries += 1 + + +def aws_delete_s3_bucket( + bucket_name: str, + endpoint: str = None, + digitalocean_region: str = None, +): + aws_delete_s3_objects(bucket_name, endpoint, digitalocean_region) + + session = aws_session(digitalocean_region=digitalocean_region) + s3 = session.client("s3", endpoint_url=endpoint) + + try: + s3.delete_bucket(Bucket=bucket_name) + print(f"Initiated deletion for bucket {bucket_name}") + except ClientError as e: + if e.response["Error"]["Code"] == "NoSuchBucket": + print(f"Bucket {bucket_name} not found. Exiting...") + return + else: + raise e + + retries = 0 + while retries < MAX_RETRIES: + try: + s3.head_bucket(Bucket=bucket_name) + print(f"Waiting for bucket {bucket_name} to be deleted...") + sleep_time = DELAY * (2**retries) + time.sleep(sleep_time) + except ClientError as e: + if (e.response["Error"]["Code"] == "NoSuchBucket") or ( + e.response["Error"]["Code"] == "NotFound" + ): + print(f"Bucket {bucket_name} deleted successfully") + return + else: + raise e + retries += 1 + + +def aws_delete_iam_role_policies(role_name: str): + session = aws_session() + iam = session.client("iam") - session = aws_session(region=region, digitalocean=digitalocean) - s3 = session.resource("s3", endpoint_url=endpoint) try: - bucket = s3.Bucket(bucket_name) + response = iam.list_attached_role_policies(RoleName=role_name) + for policy in response["AttachedPolicies"]: + iam.delete_role_policy(RoleName=role_name, PolicyName=policy["PolicyName"]) + print(f"Delete IAM policy {policy['PolicyName']} from IAM role {role_name}") + except ClientError as e: + if "NoSuchEntity" in str(e): + print(f"IAM role {role_name} not found. Exiting...") + else: + raise e - for obj in bucket.objects.all(): - obj.delete() - for obj_version in bucket.object_versions.all(): - obj_version.delete() +def aws_delete_iam_policy(name: str): + session = aws_session() + iam = session.client("iam") + try: + iam.delete_policy(PolicyArn=name) + print(f"Initiated deletion for IAM policy {name}") except ClientError as e: - if "NoSuchBucket" in str(e): - print(f"Bucket {bucket_name} does not exist. Skipping...") + if "NoSuchEntity" in str(e): + print(f"IAM policy {name} not found. Exiting...") + else: + raise e + + retries = 0 + while retries < MAX_RETRIES: + try: + iam.get_policy(PolicyArn=name) + print(f"Waiting for IAM policy {name} to be deleted...") + sleep_time = DELAY * (2**retries) + time.sleep(sleep_time) + except ClientError as e: + if "NoSuchEntity" in str(e): + print(f"IAM policy {name} deleted successfully") + return + else: + raise e + retries += 1 + + +def aws_delete_iam_role(role_name: str): + session = aws_session() + iam = session.client("iam") + + try: + attached_policies = iam.list_attached_role_policies(RoleName=role_name) + except ClientError as e: + if e.response["Error"]["Code"] == "NoSuchEntity": + print(f"IAM role {role_name} not found. 
Exiting...") return else: raise e + for policy in attached_policies["AttachedPolicies"]: + iam.detach_role_policy(RoleName=role_name, PolicyArn=policy["PolicyArn"]) + print(f"Detached policy {policy['PolicyName']} from role {role_name}") + + if policy["PolicyArn"].startswith("arn:aws:iam::aws:policy"): + continue + + policy_versions = iam.list_policy_versions(PolicyArn=policy["PolicyArn"]) + + for version in policy_versions["Versions"]: + if not version["IsDefaultVersion"]: + iam.delete_policy_version( + PolicyArn=policy["PolicyArn"], VersionId=version["VersionId"] + ) + print( + f"Deleted version {version['VersionId']} of policy {policy['PolicyName']}" + ) + + iam.delete_policy(PolicyArn=policy["PolicyArn"]) + print(f"Deleted policy {policy['PolicyName']}") + + iam.delete_role(RoleName=role_name) + print(f"Deleted role {role_name}") + - for i in range(MAX_RETRIES): +def aws_delete_node_groups(name: str, namespace: str): + cluster_name = f"{name}-{namespace}" + session = aws_session() + eks = session.client("eks") + try: + response = eks.list_nodegroups(clusterName=cluster_name) + node_groups = response.get("nodegroups", []) + except ClientError as e: + if "ResourceNotFoundException" in str(e): + print(f"Cluster {cluster_name} not found. Exiting...") + return + else: + raise e + + for node_group in node_groups: try: - bucket.delete() - print(f"Successfully deleted bucket {bucket_name}") + eks.delete_nodegroup(clusterName=cluster_name, nodegroupName=node_group) + print( + f"Initiated deletion for node group {node_group} belonging to cluster {cluster_name}" + ) + except ClientError as e: + if "ResourceNotFoundException" not in str(e): + raise e + + retries = 0 + while retries < MAX_RETRIES: + pending_deletion = [] + + for node_group in node_groups: + try: + response = eks.describe_nodegroup( + clusterName=cluster_name, nodegroupName=node_group + ) + if response["nodegroup"]["status"] == "DELETING": + pending_deletion.append(node_group) + except ClientError as e: + if "ResourceNotFoundException" in str(e): + pass + else: + raise e + + if not pending_deletion: + print("All node groups have been deleted successfully.") + return + + if retries < MAX_RETRIES - 1: + sleep_time = DELAY * (2**retries) + print( + f"{len(pending_deletion)} node groups still pending deletion. Retrying in {sleep_time} seconds..." + ) + time.sleep(sleep_time) + + retries += 1 + pending_deletion.clear() + + print(f"Failed to confirm deletion of all node groups after {MAX_RETRIES} retries.") + + +def aws_delete_cluster(name: str, namespace: str): + cluster_name = f"{name}-{namespace}" + session = aws_session() + eks = session.client("eks") + + try: + eks.delete_cluster(name=cluster_name) + print(f"Initiated deletion for cluster {cluster_name}") + except ClientError as e: + if "ResourceNotFoundException" in str(e): + print(f"Cluster {cluster_name} not found. Exiting...") return + else: + raise e + + retries = 0 + while retries < MAX_RETRIES: + try: + response = eks.describe_cluster(name=cluster_name) + if response["cluster"]["status"] == "DELETING": + sleep_time = DELAY * (2**retries) + print( + f"Cluster {cluster_name} still pending deletion. Retrying in {sleep_time} seconds..." + ) + time.sleep(sleep_time) + else: + raise ClientError( + f"Unexpected status for cluster {cluster_name}: {response['cluster']['status']}" + ) except ClientError as e: - if "BucketNotEmpty" in str(e): - print(f"Bucket is not yet empty. 
Retrying in {DELAY} seconds...") - time.sleep(DELAY) + if "ResourceNotFoundException" in str(e): + print(f"Cluster {cluster_name} has been deleted successfully.") + return else: raise e - print(f"Failed to delete bucket {bucket_name} after {MAX_RETRIES} retries.") + + retries += 1 + + print( + f"Failed to confirm deletion of cluster {cluster_name} after {MAX_RETRIES} retries." + ) + + +def aws_cleanup(name: str, namespace: str): + aws_delete_node_groups(name, namespace) + aws_delete_cluster(name, namespace) + + aws_delete_load_balancer(name, namespace) + + aws_delete_efs(name, namespace) + + aws_delete_subnets(name, namespace) + aws_delete_route_tables(name, namespace) + aws_delete_internet_gateways(name, namespace) + aws_delete_security_groups(name, namespace) + aws_delete_vpc(name, namespace) + + aws_delete_ec2_volumes(name, namespace) + + dynamodb_table_name = f"{name}-{namespace}-terraform-state-lock" + aws_delete_dynamodb_table(dynamodb_table_name) + + s3_bucket_name = f"{name}-{namespace}-terraform-state" + aws_delete_s3_bucket(s3_bucket_name) + + iam_role_name = f"{name}-{namespace}-eks-cluster-role" + iam_role_node_group_name = f"{name}-{namespace}-eks-node-group-role" + iam_policy_name_regex = "^eks-worker-autoscaling-{name}-{namespace}(\\d+)$".format( + name=name, namespace=namespace + ) + iam_policy = aws_get_iam_policy(pattern=iam_policy_name_regex) + if iam_policy: + aws_delete_iam_role_policies(iam_role_node_group_name) + aws_delete_iam_policy(iam_policy) + aws_delete_iam_role(iam_role_name) + aws_delete_iam_role(iam_role_node_group_name) diff --git a/src/_nebari/provider/cloud/digital_ocean.py b/src/_nebari/provider/cloud/digital_ocean.py index 1cfcef1d81..0da8d8daff 100644 --- a/src/_nebari/provider/cloud/digital_ocean.py +++ b/src/_nebari/provider/cloud/digital_ocean.py @@ -7,7 +7,7 @@ from kubernetes import client, config from _nebari import constants -from _nebari.provider.cloud.amazon_web_services import delete_aws_s3_bucket +from _nebari.provider.cloud.amazon_web_services import aws_delete_s3_bucket from _nebari.provider.cloud.commons import filter_by_highest_supported_k8s_version from _nebari.utils import set_do_environment @@ -107,7 +107,7 @@ def digital_ocean_delete_kubernetes_cluster(cluster_name: str): digital_ocean_request(f"kubernetes/clusters/{cluster_id}", method="DELETE") -def digital_ocean_cleanup(name: str, namespace: str, region: str): +def digital_ocean_cleanup(name: str, namespace: str): cluster_name = f"{name}-{namespace}" tf_state_bucket = f"{cluster_name}-terraform-state" do_spaces_endpoint = "https://nyc3.digitaloceanspaces.com" @@ -123,7 +123,7 @@ def digital_ocean_cleanup(name: str, namespace: str, region: str): ) set_do_environment() - delete_aws_s3_bucket( - tf_state_bucket, region=region, digitalocean=True, endpoint=do_spaces_endpoint + aws_delete_s3_bucket( + tf_state_bucket, digitalocean=True, endpoint=do_spaces_endpoint ) digital_ocean_delete_kubernetes_cluster(cluster_name) diff --git a/src/_nebari/stages/base.py b/src/_nebari/stages/base.py index d15e67d218..60a4821a24 100644 --- a/src/_nebari/stages/base.py +++ b/src/_nebari/stages/base.py @@ -57,7 +57,9 @@ def set_outputs( stage_outputs[stage_key].update(outputs) @contextlib.contextmanager - def deploy(self, stage_outputs: Dict[str, Dict[str, Any]]): + def deploy( + self, stage_outputs: Dict[str, Dict[str, Any]], disable_prompt: bool = False + ): deploy_config = dict( directory=str(self.output_directory / self.stage_prefix), input_vars=self.input_vars(stage_outputs), @@ -68,13 +70,17 @@ 
def deploy(self, stage_outputs: Dict[str, Dict[str, Any]]): deploy_config["state_imports"] = state_imports self.set_outputs(stage_outputs, terraform.deploy(**deploy_config)) - self.post_deploy(stage_outputs) + self.post_deploy(stage_outputs, disable_prompt) yield - def post_deploy(self, stage_outputs: Dict[str, Dict[str, Any]]): + def post_deploy( + self, stage_outputs: Dict[str, Dict[str, Any]], disable_prompt: bool = False + ): pass - def check(self, stage_outputs: Dict[str, Dict[str, Any]]): + def check( + self, stage_outputs: Dict[str, Dict[str, Any]], disable_prompt: bool = False + ): pass @contextlib.contextmanager diff --git a/src/_nebari/stages/infrastructure/__init__.py b/src/_nebari/stages/infrastructure/__init__.py index 81e3bf86f6..0507f7a9ca 100644 --- a/src/_nebari/stages/infrastructure/__init__.py +++ b/src/_nebari/stages/infrastructure/__init__.py @@ -712,7 +712,9 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): else: raise ValueError(f"Unknown provider: {self.config.provider}") - def check(self, stage_outputs: Dict[str, Dict[str, Any]]): + def check( + self, stage_outputs: Dict[str, Dict[str, Any]], disable_prompt: bool = False + ): from kubernetes import client, config from kubernetes.client.rest import ApiException @@ -746,8 +748,10 @@ def set_outputs( super().set_outputs(stage_outputs, outputs) @contextlib.contextmanager - def deploy(self, stage_outputs: Dict[str, Dict[str, Any]]): - with super().deploy(stage_outputs): + def deploy( + self, stage_outputs: Dict[str, Dict[str, Any]], disable_prompt: bool = False + ): + with super().deploy(stage_outputs, disable_prompt): with kubernetes_provider_context( stage_outputs["stages/" + self.name]["kubernetes_credentials"]["value"] ): diff --git a/src/_nebari/stages/kubernetes_ingress/__init__.py b/src/_nebari/stages/kubernetes_ingress/__init__.py index 28e5679c64..ea1d27c5e3 100644 --- a/src/_nebari/stages/kubernetes_ingress/__init__.py +++ b/src/_nebari/stages/kubernetes_ingress/__init__.py @@ -36,12 +36,11 @@ def add_clearml_dns(zone_name, record_name, record_type, ip_or_hostname): def provision_ingress_dns( - stage_outputs, - config, + stage_outputs: Dict[str, Dict[str, Any]], + config: schema.Main, dns_provider: str, dns_auto_provision: bool, disable_prompt: bool = True, - disable_checks: bool = False, ): directory = "stages/04-kubernetes-ingress" @@ -78,11 +77,8 @@ def provision_ingress_dns( f'"{config.domain}" [Press Enter when Complete]' ) - if not disable_checks: - check_ingress_dns(stage_outputs, config, disable_prompt) - -def check_ingress_dns(stage_outputs, config, disable_prompt): +def check_ingress_dns(stage_outputs: Dict[str, Dict[str, Any]], disable_prompt: bool): directory = "stages/04-kubernetes-ingress" ip_or_name = stage_outputs[directory]["load_balancer_address"]["value"] @@ -155,6 +151,7 @@ class Certificate(schema.Base): class DnsProvider(schema.Base): provider: typing.Optional[str] + auto_provision: typing.Optional[bool] = False class Ingress(schema.Base): @@ -233,18 +230,21 @@ def set_outputs( super().set_outputs(stage_outputs, outputs) - def post_deploy(self, stage_outputs: Dict[str, Dict[str, Any]]): + def post_deploy( + self, stage_outputs: Dict[str, Dict[str, Any]], disable_prompt: bool = False + ): if self.config.dns and self.config.dns.provider: provision_ingress_dns( stage_outputs, self.config, dns_provider=self.config.dns.provider, - dns_auto_provision=True, - disable_prompt=True, - disable_checks=False, + dns_auto_provision=self.config.dns.auto_provision, + 
disable_prompt=disable_prompt, ) - def check(self, stage_outputs: Dict[str, Dict[str, Any]]): + def check( + self, stage_outputs: Dict[str, Dict[str, Any]], disable_prompt: bool = False + ): def _attempt_tcp_connect( host, port, num_attempts=NUM_ATTEMPTS, timeout=TIMEOUT ): @@ -295,7 +295,7 @@ def _attempt_tcp_connect( f"After stage={self.name} kubernetes ingress available on tcp ports={tcp_ports}" ) - check_ingress_dns(stage_outputs, self.config, disable_prompt=False) + check_ingress_dns(stage_outputs, disable_prompt=disable_prompt) @hookimpl diff --git a/src/_nebari/stages/kubernetes_initialize/__init__.py b/src/_nebari/stages/kubernetes_initialize/__init__.py index 02f8df6f9c..918e996a08 100644 --- a/src/_nebari/stages/kubernetes_initialize/__init__.py +++ b/src/_nebari/stages/kubernetes_initialize/__init__.py @@ -99,7 +99,9 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): return input_vars.dict() - def check(self, stage_outputs: Dict[str, Dict[str, Any]]): + def check( + self, stage_outputs: Dict[str, Dict[str, Any]], disable_prompt: bool = False + ): from kubernetes import client, config from kubernetes.client.rest import ApiException diff --git a/src/_nebari/stages/kubernetes_keycloak/__init__.py b/src/_nebari/stages/kubernetes_keycloak/__init__.py index ac8882df23..eb9e5a062d 100644 --- a/src/_nebari/stages/kubernetes_keycloak/__init__.py +++ b/src/_nebari/stages/kubernetes_keycloak/__init__.py @@ -183,7 +183,9 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): ], ).dict() - def check(self, stage_outputs: Dict[str, Dict[str, Any]]): + def check( + self, stage_outputs: Dict[str, Dict[str, Any]], disable_check: bool = False + ): from keycloak import KeycloakAdmin from keycloak.exceptions import KeycloakError @@ -242,8 +244,10 @@ def _attempt_keycloak_connection( print("Keycloak service successfully started") @contextlib.contextmanager - def deploy(self, stage_outputs: Dict[str, Dict[str, Any]]): - with super().deploy(stage_outputs): + def deploy( + self, stage_outputs: Dict[str, Dict[str, Any]], disable_prompt: bool = False + ): + with super().deploy(stage_outputs, disable_prompt): with keycloak_provider_context( stage_outputs["stages/" + self.name]["keycloak_credentials"]["value"] ): diff --git a/src/_nebari/stages/kubernetes_keycloak_configuration/__init__.py b/src/_nebari/stages/kubernetes_keycloak_configuration/__init__.py index 39ca07a599..39f7b8ae8e 100644 --- a/src/_nebari/stages/kubernetes_keycloak_configuration/__init__.py +++ b/src/_nebari/stages/kubernetes_keycloak_configuration/__init__.py @@ -41,7 +41,9 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): return input_vars.dict() - def check(self, stage_outputs: Dict[str, Dict[str, Any]]): + def check( + self, stage_outputs: Dict[str, Dict[str, Any]], disable_prompt: bool = False + ): directory = "stages/05-kubernetes-keycloak" from keycloak import KeycloakAdmin diff --git a/src/_nebari/stages/kubernetes_services/__init__.py b/src/_nebari/stages/kubernetes_services/__init__.py index 087bac4642..206ab07a31 100644 --- a/src/_nebari/stages/kubernetes_services/__init__.py +++ b/src/_nebari/stages/kubernetes_services/__init__.py @@ -574,7 +574,9 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): **clearml_vars.dict(by_alias=True), } - def check(self, stage_outputs: Dict[str, Dict[str, Any]]): + def check( + self, stage_outputs: Dict[str, Dict[str, Any]], disable_prompt: bool = False + ): directory = "stages/07-kubernetes-services" import requests diff --git 
a/src/_nebari/stages/terraform_state/__init__.py b/src/_nebari/stages/terraform_state/__init__.py
index ed01f6eb56..b8c11ec803 100644
--- a/src/_nebari/stages/terraform_state/__init__.py
+++ b/src/_nebari/stages/terraform_state/__init__.py
@@ -169,8 +169,10 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
             ValueError(f"Unknown provider: {self.config.provider}")
 
     @contextlib.contextmanager
-    def deploy(self, stage_outputs: Dict[str, Dict[str, Any]]):
-        with super().deploy(stage_outputs):
+    def deploy(
+        self, stage_outputs: Dict[str, Dict[str, Any]], disable_prompt: bool = False
+    ):
+        with super().deploy(stage_outputs, disable_prompt):
             env_mapping = {}
             # DigitalOcean terraform remote state using Spaces Bucket
             # assumes aws credentials thus we set them to match spaces credentials
diff --git a/src/_nebari/subcommands/deploy.py b/src/_nebari/subcommands/deploy.py
index 6c0564a502..ab7b3b6551 100644
--- a/src/_nebari/subcommands/deploy.py
+++ b/src/_nebari/subcommands/deploy.py
@@ -7,6 +7,8 @@
 from _nebari.render import render_template
 from nebari.hookspecs import hookimpl
 
+TERRAFORM_STATE_STAGE_NAME = "01-terraform-state"
+
 
 @hookimpl
 def nebari_subcommand(cli: typer.Typer):
@@ -28,12 +30,12 @@ def deploy(
     dns_provider: str = typer.Option(
         False,
         "--dns-provider",
-        help="dns provider to use for registering domain name mapping",
+        help="dns provider to use for registering domain name mapping ⚠️ moved to `dns.provider` in nebari-config.yaml",
     ),
     dns_auto_provision: bool = typer.Option(
        False,
         "--dns-auto-provision",
-        help="Attempt to automatically provision DNS, currently only available for `cloudflare`",
+        help="Attempt to automatically provision DNS, currently only available for `cloudflare` ⚠️ moved to `dns.auto_provision` in nebari-config.yaml",
     ),
     disable_prompt: bool = typer.Option(
         False,
@@ -69,12 +71,22 @@ def deploy(
     if not disable_render:
         render_template(output_directory, config, stages)
 
+    if skip_remote_state_provision:
+        for stage in stages:
+            if stage.name == TERRAFORM_STATE_STAGE_NAME:
+                stages.remove(stage)
+        print("Skipping remote state provision")
+
+    if dns_provider and dns_auto_provision:
+        # TODO: Add deprecation warning and update docs on how to configure DNS via nebari-config.yaml
+        print(
+            "Please add `dns.provider` and `dns.auto_provision` to your nebari-config.yaml file to enable DNS auto-provisioning."
+ ) + exit(1) + deploy_configuration( config, stages, - dns_provider=dns_provider, - dns_auto_provision=dns_auto_provision, disable_prompt=disable_prompt, disable_checks=disable_checks, - skip_remote_state_provision=skip_remote_state_provision, ) diff --git a/src/nebari/hookspecs.py b/src/nebari/hookspecs.py index 789dfe2d78..7dc6c8e3a4 100644 --- a/src/nebari/hookspecs.py +++ b/src/nebari/hookspecs.py @@ -27,10 +27,14 @@ def render(self) -> Dict[str, str]: return {} @contextlib.contextmanager - def deploy(self, stage_outputs: Dict[str, Dict[str, Any]]): + def deploy( + self, stage_outputs: Dict[str, Dict[str, Any]], disable_prompt: bool = False + ): yield - def check(self, stage_outputs: Dict[str, Dict[str, Any]]) -> bool: + def check( + self, stage_outputs: Dict[str, Dict[str, Any]], disable_prompt: bool = False + ) -> bool: pass @contextlib.contextmanager diff --git a/tests/common/config_mod_utils.py b/tests/common/config_mod_utils.py index dce0b4fdeb..67a897c3b9 100644 --- a/tests/common/config_mod_utils.py +++ b/tests/common/config_mod_utils.py @@ -1,6 +1,9 @@ import dataclasses import typing +from _nebari.stages.infrastructure import AWSNodeGroup, GCPNodeGroup +from _nebari.stages.kubernetes_services import JupyterLabProfile, KubeSpawner + PREEMPTIBLE_NODE_GROUP_NAME = "preemptible-node-group" @@ -70,48 +73,71 @@ def _create_gpu_environment(): def add_gpu_config(config, cloud="aws"): + # TODO: do we still need GPU_CONFIG here? gpu_config = GPU_CONFIG.get(cloud) if not gpu_config: raise ValueError(f"GPU not supported/tested on {cloud}") - gpu_node = gpu_config.node() - gpu_docker_image = gpu_config.docker_image - jupyterlab_profile = { - "display_name": "GPU Instance", - "description": "4 CPU / 16GB RAM / 1 NVIDIA T4 GPU (16 GB GPU RAM)", - "groups": ["gpu-access"], - "kubespawner_override": { - "image": gpu_docker_image, - "cpu_limit": 4, - "cpu_guarantee": 3, - "mem_limit": "16G", - "mem_guarantee": "10G", - "extra_resource_limits": {"nvidia.com/gpu": 1}, - "node_selector": { + if cloud == "aws": + gpu_node_group = AWSNodeGroup( + instance=gpu_config.gpu_name, + min_nodes=gpu_config.min_nodes, + max_nodes=gpu_config.max_nodes, + single_subnet=gpu_config.extra_config["single_subnet"], + gpu=gpu_config.extra_config["gpu"], + ) + kubespawner_overrides = KubeSpawner( + image=gpu_config.docker_image, + cpu_limit=4, + cpu_guarantee=3, + mem_limit="16G", + mem_guarantee="10G", + extra_resource_limits={"nvidia.com/gpu": 1}, + node_selector={ gpu_config.node_selector: gpu_config.node_selector_val, }, - }, - } - config[gpu_config.cloud]["node_groups"][gpu_config.node_group_name] = gpu_node - config["profiles"]["jupyterlab"].append(jupyterlab_profile) - config["environments"]["environment-gpu.yaml"] = _create_gpu_environment() + ) + + jupyterlab_profile = JupyterLabProfile( + display_name="GPU Instance", + description="4 CPU / 16GB RAM / 1 NVIDIA T4 GPU (16 GB GPU RAM)", + access="yaml", + groups=["gpu-access"], + kubespawner_override=kubespawner_overrides, + ) + + cloud_section = getattr(config, gpu_config.cloud, None) + cloud_section.node_groups[gpu_config.node_group_name] = gpu_node_group + config.profiles.jupyterlab.append(jupyterlab_profile) + config.environments["environment-gpu.yaml"] = _create_gpu_environment() + return config def add_preemptible_node_group(config, cloud="aws"): + node_group = None if cloud == "aws": cloud_name = "amazon_web_services" - instance_name = "m5.xlarge" + # TODO: how to make preemptible? 
+        node_group = AWSNodeGroup(
+            instance="m5.xlarge",
+            min_nodes=1,
+            max_nodes=5,
+            single_subnet=False,
+        )
     elif cloud == "gcp":
         cloud_name = "google_cloud_platform"
-        instance_name = "n1-standard-8"
+        node_group = GCPNodeGroup(
+            instance="n1-standard-8",
+            min_nodes=1,
+            max_nodes=5,
+            preemptible=True,
+        )
     else:
         raise ValueError("Invalid cloud for preemptible config")
-    config[cloud_name]["node_groups"][PREEMPTIBLE_NODE_GROUP_NAME] = {
-        "instance": instance_name,
-        "min_nodes": 1,
-        "max_nodes": 5,
-        "single_subnet": False,
-        "preemptible": True,
-    }
+
+    cloud_section = getattr(config, cloud_name, None)
+    if node_group:
+        cloud_section.node_groups[PREEMPTIBLE_NODE_GROUP_NAME] = node_group
+    return config
diff --git a/tests/tests_integration/README.md b/tests/tests_integration/README.md
index 6735ae4eaa..2d82593881 100644
--- a/tests/tests_integration/README.md
+++ b/tests/tests_integration/README.md
@@ -1,9 +1,10 @@
 # Integration Testing via Pytest
 
 These tests are designed to test things on Nebari deployed
-on cloud. At the moment it only deploys on DigitalOcean.
+on cloud.
 
-You need the following environment variables to run these.
+
+## DigitalOcean
 
 ```bash
 DIGITALOCEAN_TOKEN
@@ -13,14 +14,29 @@ SPACES_SECRET_ACCESS_KEY
 CLOUDFLARE_TOKEN
 ```
 
-For instructions on how to get these variables check the documentation
-for DigitalOcean deployment.
+Once those are set, you can run:
+
+```bash
+pytest tests_integration -vvv -s -m do
+```
+
+This will deploy Nebari on DigitalOcean, run tests on the deployment
+and then teardown the cluster.
+
+## Amazon Web Services
+
+```bash
+AWS_ACCESS_KEY_ID
+AWS_SECRET_ACCESS_KEY
+AWS_DEFAULT_REGION
+CLOUDFLARE_TOKEN
+```
 
-Running Tests:
+Once those are set, you can run:
 
 ```bash
-pytest tests_integration -vvv -s
+pytest tests_integration -vvv -s -m aws
+```
 
-This would deploy on digitalocean, run tests on the deployment
+This will deploy Nebari on Amazon Web Services, run tests on the deployment
 and then teardown the cluster.
diff --git a/tests/tests_integration/deployment_fixtures.py b/tests/tests_integration/deployment_fixtures.py index 1d2f30bdf3..7dc3d1487e 100644 --- a/tests/tests_integration/deployment_fixtures.py +++ b/tests/tests_integration/deployment_fixtures.py @@ -12,9 +12,10 @@ from _nebari.config import read_configuration, write_configuration from _nebari.deploy import deploy_configuration from _nebari.destroy import destroy_configuration +from _nebari.provider.cloud.amazon_web_services import aws_cleanup from _nebari.provider.cloud.digital_ocean import digital_ocean_cleanup from _nebari.render import render_template -from _nebari.utils import set_do_environment, yaml +from _nebari.utils import set_do_environment from tests.common.config_mod_utils import add_gpu_config, add_preemptible_node_group from tests.tests_unit.utils import render_config_partial @@ -80,6 +81,8 @@ def _create_nebari_user(config): def deploy(request): """Deploy Nebari on the given cloud, currently only DigitalOcean""" ignore_warnings() + + # initialize cloud = request.param logger.info(f"Deploying: {cloud}") if cloud == "do": @@ -93,29 +96,12 @@ def deploy(request): ci_provider="github-actions", auth_provider="password", ) - # Generate certificate as well - config["certificate"] = { - "type": "lets-encrypt", - "acme_email": "internal-devops@quansight.com", - "acme_server": "https://acme-v02.api.letsencrypt.org/directory", - } - if cloud in ["aws", "gcp"]: - config = add_gpu_config(config, cloud=cloud) - config = add_preemptible_node_group(config, cloud=cloud) deployment_dir_abs = deployment_dir.absolute() os.chdir(deployment_dir) logger.info(f"Temporary directory: {deployment_dir}") config_path = Path("nebari-config.yaml") - if config_path.exists(): - with open(config_path) as f: - current_config = yaml.load(f) - - config["security"]["keycloak"]["initial_root_password"] = current_config[ - "security" - ]["keycloak"]["initial_root_password"] - write_configuration(config_path, config) from nebari.plugins import nebari_plugin_manager @@ -124,8 +110,23 @@ def deploy(request): config_schema = nebari_plugin_manager.config_schema config = read_configuration(config_path, config_schema) + + # Modify config + config.certificate.type = "lets-encrypt" + config.certificate.acme_email = "internal-devops@quansight.com" + config.certificate.acme_server = "https://acme-v02.api.letsencrypt.org/directory" + config.dns.provider = "cloudflare" + config.dns.auto_provision = True + + if cloud in ["aws", "gcp"]: + config = add_gpu_config(config, cloud=cloud) + config = add_preemptible_node_group(config, cloud=cloud) + + # render render_template(deployment_dir_abs, config, stages) + print(config) + failed = False # deploy @@ -136,11 +137,8 @@ def deploy(request): deploy_config = deploy_configuration( config=config, stages=stages, - dns_provider="cloudflare", - dns_auto_provision=True, disable_prompt=True, disable_checks=False, - skip_remote_state_provision=False, ) _create_nebari_user(config) _set_nebari_creds_in_environment(config) @@ -184,9 +182,17 @@ def _create_pytest_param(cloud): def _cleanup_nebari(config): - cloud_provider = config["provider"] + cloud_provider = config.provider + project_name = config.name + namespace = config.namespace if cloud_provider == "do": digital_ocean_cleanup( - name=config["name"], namespace=config["namespace"], region=config["region"] + name=project_name, + namespace=namespace, + ) + elif cloud_provider == "aws": + aws_cleanup( + name=project_name, + namespace=namespace, )