Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Michaelvll committed Sep 24, 2024
1 parent 1a8c79b commit c35f176
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 49 deletions.
6 changes: 3 additions & 3 deletions sky/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,11 +378,11 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
public_key_path = os.path.expanduser(PUBLIC_SSH_KEY_PATH)
secret_name = clouds.Kubernetes.SKY_SSH_KEY_SECRET_NAME
secret_field_name = clouds.Kubernetes().ssh_key_secret_field_name
namespace = config['provider'].get(
'namespace',
kubernetes_utils.get_current_kube_config_context_namespace())
context = config['provider'].get(
'context', kubernetes_utils.get_current_kube_config_context_name())
namespace = config['provider'].get(
'namespace',
kubernetes_utils.get_kube_config_context_namespace(context))
k8s = kubernetes.kubernetes
with open(public_key_path, 'r', encoding='utf-8') as f:
public_key = f.read()
Expand Down
15 changes: 14 additions & 1 deletion sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2081,7 +2081,7 @@ class CloudVmRayResourceHandle(backends.backend.ResourceHandle):
"""
# Bump if any fields get added/removed/changed, and add backward
# compatibility logic in __setstate__.
_VERSION = 8
_VERSION = 9

def __init__(
self,
Expand Down Expand Up @@ -2515,6 +2515,19 @@ def __setstate__(self, state):
if version < 8:
self.cached_cluster_info = None

if version < 9:
# For backward compatibility, we should update the region of a
# SkyPilot cluster on Kubernetes to the actual context it is using.
# pylint: disable=import-outside-toplevel
from sky.provision.kubernetes import utils as kubernetes_utils
launched_resources = state['launched_resources']
if isinstance(launched_resources.cloud, clouds.Kubernetes):
yaml_config = common_utils.read_yaml(state['_cluster_yaml'])
context = kubernetes_utils.get_context_from_config(
yaml_config['provider'])
state['launched_resources'] = launched_resources.copy(
region=context)

self.__dict__.update(state)

# Because the update_cluster_ips and update_ssh_ports
Expand Down
44 changes: 25 additions & 19 deletions sky/clouds/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ class Kubernetes(clouds.Cloud):
_DEFAULT_MEMORY_CPU_RATIO = 1
_DEFAULT_MEMORY_CPU_RATIO_WITH_GPU = 4 # Allocate more memory for GPU tasks
_REPR = 'Kubernetes'
_SINGLETON_REGION = 'kubernetes'
_regions: List[clouds.Region] = [clouds.Region(_SINGLETON_REGION)]
_LEGACY_SINGLETON_REGION = 'kubernetes'
_CLOUD_UNSUPPORTED_FEATURES = {
# TODO(romilb): Stopping might be possible to implement with
# container checkpointing introduced in Kubernetes v1.25. See:
Expand Down Expand Up @@ -117,7 +116,10 @@ def regions_with_offering(cls, instance_type: Optional[str],
allowed_contexts = skypilot_config.get_nested(
('kubernetes', 'allowed_contexts'), None)
if allowed_contexts is None:
return cls._regions
return [
clouds.Region(
kubernetes_utils.get_current_kube_config_context_name())
]
regions = [clouds.Region(context) for context in allowed_contexts]
if region is not None:
regions = [r for r in regions if r.name == region]
Expand Down Expand Up @@ -227,7 +229,9 @@ def make_deploy_resources_variables(
dryrun: bool = False) -> Dict[str, Optional[str]]:
del cluster_name, zones, dryrun # Unused.
if region is None:
region = self._regions[0]
context = kubernetes_utils.get_current_kube_config_context_name()
else:
context = region.name

r = resources
acc_dict = self.get_accelerators_from_instance_type(r.instance_type)
Expand Down Expand Up @@ -311,13 +315,10 @@ def make_deploy_resources_variables(
deploy_vars = {
'instance_type': resources.instance_type,
'custom_resources': custom_resources,
'region': region.name,
'cpus': str(cpus),
'memory': str(mem),
'accelerator_count': str(acc_count),
'timeout': str(timeout),
'k8s_namespace':
kubernetes_utils.get_current_kube_config_context_namespace(),
'k8s_port_mode': port_mode.value,
'k8s_networking_mode': network_utils.get_networking_mode().value,
'k8s_ssh_key_secret_name': self.SKY_SSH_KEY_SECRET_NAME,
Expand All @@ -337,20 +338,19 @@ def make_deploy_resources_variables(

# Add kubecontext if it is set. It may be None if SkyPilot is running
# inside a pod with in-cluster auth.
resource_context = None
if region.name != self._SINGLETON_REGION:
resource_context = region.name
else:
resource_context = (
kubernetes_utils.get_current_kube_config_context_name())
if resource_context is not None:
deploy_vars['k8s_context'] = resource_context
if context is not None:
deploy_vars['k8s_context'] = context

namespace = kubernetes_utils.get_kube_config_context_namespace(context)
deploy_vars['k8s_namespace'] = namespace

return deploy_vars

def _get_feasible_launchable_resources(
self, resources: 'resources_lib.Resources'
) -> 'resources_utils.FeasibleResources':
# TODO(zhwu): This needs to be updated to return the correct region
# (context) that has enough resources.
fuzzy_candidate_list: List[str] = []
if resources.instance_type is not None:
assert resources.is_launchable(), resources
Expand Down Expand Up @@ -440,14 +440,20 @@ def instance_type_exists(self, instance_type: str) -> bool:
instance_type)

def validate_region_zone(self, region: Optional[str], zone: Optional[str]):
# TODO(zhwu): or not in allowed_contexts
if region == self._LEGACY_SINGLETON_REGION:
# For backward compatibility, we allow the region to be set to the
# legacy singleton region.
# TODO: Remove this after 0.9.0.
return region, zone

all_contexts = kubernetes_utils.get_all_kube_config_context_names()
if all_contexts is None:
all_contexts = []
if region != self._SINGLETON_REGION and region not in all_contexts:
if region not in all_contexts:
raise ValueError(
'Kubernetes only supports context names as regions. '
f'Allowed contexts: {all_contexts}')
f'Context {region} not found in kubeconfig. Kubernetes only '
'supports context names as regions. Available '
f'contexts: {all_contexts}')
if zone is not None:
raise ValueError('Kubernetes support does not support setting zone.'
' Cluster used is determined by the kubeconfig.')
Expand Down
48 changes: 34 additions & 14 deletions sky/provision/kubernetes/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,14 @@ def _open_ports_using_ingress(
)

# Prepare service names, ports, for template rendering
service_details = [(f'{cluster_name_on_cloud}--skypilot-svc--{port}', port,
_PATH_PREFIX.format(
cluster_name_on_cloud=cluster_name_on_cloud,
port=port,
namespace=kubernetes_utils.
get_current_kube_config_context_namespace()).rstrip(
'/').lstrip('/')) for port in ports]
service_details = [
(f'{cluster_name_on_cloud}--skypilot-svc--{port}', port,
_PATH_PREFIX.format(
cluster_name_on_cloud=cluster_name_on_cloud,
port=port,
namespace=kubernetes_utils.get_kube_config_context_namespace(
context)).rstrip('/').lstrip('/')) for port in ports
]

# Generate ingress and services specs
# We batch ingress rule creation because each rule triggers a hot reload of
Expand Down Expand Up @@ -171,7 +172,8 @@ def _cleanup_ports_for_ingress(
for port in ports:
service_name = f'{cluster_name_on_cloud}--skypilot-svc--{port}'
network_utils.delete_namespaced_service(
namespace=provider_config.get('namespace', 'default'),
namespace=provider_config.get('namespace',
kubernetes_utils.DEFAULT_NAMESPACE),
service_name=service_name,
)

Expand Down Expand Up @@ -208,11 +210,13 @@ def query_ports(
return _query_ports_for_ingress(
cluster_name_on_cloud=cluster_name_on_cloud,
ports=ports,
provider_config=provider_config,
)
elif port_mode == kubernetes_enums.KubernetesPortMode.PODIP:
return _query_ports_for_podip(
cluster_name_on_cloud=cluster_name_on_cloud,
ports=ports,
provider_config=provider_config,
)
else:
return {}
Expand All @@ -231,8 +235,14 @@ def _query_ports_for_loadbalancer(
result: Dict[int, List[common.Endpoint]] = {}
service_name = _LOADBALANCER_SERVICE_NAME.format(
cluster_name_on_cloud=cluster_name_on_cloud)
context = provider_config.get(
'context', kubernetes_utils.get_current_kube_config_context_name())
namespace = provider_config.get(
'namespace',
kubernetes_utils.get_kube_config_context_namespace(context))
external_ip = network_utils.get_loadbalancer_ip(
namespace=provider_config.get('namespace', 'default'),
context=context,
namespace=namespace,
service_name=service_name,
# Timeout is set so that we can retry the query when the
# cluster is first created and the load balancer is not ready yet.
Expand All @@ -251,19 +261,24 @@ def _query_ports_for_loadbalancer(
def _query_ports_for_ingress(
cluster_name_on_cloud: str,
ports: List[int],
provider_config: Dict[str, Any],
) -> Dict[int, List[common.Endpoint]]:
ingress_details = network_utils.get_ingress_external_ip_and_ports()
context = provider_config.get(
'context', kubernetes_utils.get_current_kube_config_context_name())
ingress_details = network_utils.get_ingress_external_ip_and_ports(context)
external_ip, external_ports = ingress_details
if external_ip is None:
return {}

namespace = provider_config.get(
'namespace',
kubernetes_utils.get_kube_config_context_namespace(context))
result: Dict[int, List[common.Endpoint]] = {}
for port in ports:
path_prefix = _PATH_PREFIX.format(
cluster_name_on_cloud=cluster_name_on_cloud,
port=port,
namespace=kubernetes_utils.
get_current_kube_config_context_namespace())
namespace=namespace)

http_port, https_port = external_ports \
if external_ports is not None else (None, None)
Expand All @@ -282,10 +297,15 @@ def _query_ports_for_ingress(
def _query_ports_for_podip(
cluster_name_on_cloud: str,
ports: List[int],
provider_config: Dict[str, Any],
) -> Dict[int, List[common.Endpoint]]:
namespace = kubernetes_utils.get_current_kube_config_context_namespace()
context = provider_config.get(
'context', kubernetes_utils.get_current_kube_config_context_name())
namespace = provider_config.get(
'namespace',
kubernetes_utils.get_kube_config_context_namespace(context))
pod_name = kubernetes_utils.get_head_pod_name(cluster_name_on_cloud)
pod_ip = network_utils.get_pod_ip(namespace, pod_name)
pod_ip = network_utils.get_pod_ip(context, namespace, pod_name)

result: Dict[int, List[common.Endpoint]] = {}
if pod_ip is None:
Expand Down
12 changes: 7 additions & 5 deletions sky/provision/kubernetes/network_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,11 @@ def ingress_controller_exists(context: str,


def get_ingress_external_ip_and_ports(
context: str,
namespace: str = 'ingress-nginx'
) -> Tuple[Optional[str], Optional[Tuple[int, int]]]:
"""Returns external ip and ports for the ingress controller."""
core_api = kubernetes.core_api()
core_api = kubernetes.core_api(context)
ingress_services = [
item for item in core_api.list_namespaced_service(
namespace, _request_timeout=kubernetes.API_TIMEOUT).items
Expand Down Expand Up @@ -257,11 +258,12 @@ def get_ingress_external_ip_and_ports(
return external_ip, None


def get_loadbalancer_ip(namespace: str,
def get_loadbalancer_ip(context: str,
namespace: str,
service_name: str,
timeout: int = 0) -> Optional[str]:
"""Returns the IP address of the load balancer."""
core_api = kubernetes.core_api()
core_api = kubernetes.core_api(context)

ip = None

Expand All @@ -282,9 +284,9 @@ def get_loadbalancer_ip(namespace: str,
return ip


def get_pod_ip(namespace: str, pod_name: str) -> Optional[str]:
def get_pod_ip(context: str, namespace: str, pod_name: str) -> Optional[str]:
"""Returns the IP address of the pod."""
core_api = kubernetes.core_api()
core_api = kubernetes.core_api(context)
pod = core_api.read_namespaced_pod(pod_name,
namespace,
_request_timeout=kubernetes.API_TIMEOUT)
Expand Down
31 changes: 24 additions & 7 deletions sky/provision/kubernetes/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Kubernetes utilities for SkyPilot."""
import dataclasses
import functools
import json
import math
import os
Expand Down Expand Up @@ -364,7 +365,7 @@ def get_kubernetes_pods() -> List[Any]:
Used for computing cluster resource usage.
"""
try:
ns = get_current_kube_config_context_namespace()
ns = get_kube_config_context_namespace()
pods = kubernetes.core_api().list_namespaced_pod(
ns, _request_timeout=kubernetes.API_TIMEOUT).items
except kubernetes.max_retry_error():
Expand All @@ -390,6 +391,9 @@ def check_instance_fits(instance: str) -> Tuple[bool, Optional[str]]:
Optional[str]: Error message if the instance does not fit.
"""

# TODO(zhwu): this should check the node for specific context, instead
# of the default context to make failover fully functional.

def check_cpu_mem_fits(candidate_instance_type: 'KubernetesInstanceType',
node_list: List[Any]) -> Tuple[bool, Optional[str]]:
"""Checks if the instance fits on the cluster based on CPU and memory.
Expand Down Expand Up @@ -629,7 +633,7 @@ def check_credentials(timeout: int = kubernetes.API_TIMEOUT) -> \
str: Error message if credentials are invalid, None otherwise
"""
try:
ns = get_current_kube_config_context_namespace()
ns = get_kube_config_context_namespace()
context = get_current_kube_config_context_name()
kubernetes.core_api(context).list_namespaced_pod(
ns, _request_timeout=timeout)
Expand Down Expand Up @@ -760,6 +764,7 @@ def is_kubeconfig_exec_auth() -> Tuple[bool, Optional[str]]:
return False, None


@functools.lru_cache()
def get_current_kube_config_context_name() -> Optional[str]:
"""Get the current kubernetes context from the kubeconfig file
Expand All @@ -774,6 +779,7 @@ def get_current_kube_config_context_name() -> Optional[str]:
return None


@functools.lru_cache()
def get_all_kube_config_context_names() -> Optional[List[str]]:
"""Get all kubernetes context names from the kubeconfig file
Expand All @@ -789,7 +795,9 @@ def get_all_kube_config_context_names() -> Optional[List[str]]:
return None


def get_current_kube_config_context_namespace() -> str:
@functools.lru_cache()
def get_kube_config_context_namespace(
context_name: Optional[str] = None) -> str:
"""Get the namespace of a kubernetes context from the kubeconfig file

Uses the current context when context_name is None.
Returns:
Expand All @@ -804,9 +812,17 @@ def get_current_kube_config_context_namespace() -> str:
return f.read().strip()
# If not in-cluster, get the namespace from kubeconfig
try:
_, current_context = k8s.config.list_kube_config_contexts()
if 'namespace' in current_context['context']:
return current_context['context']['namespace']
contexts, current_context = k8s.config.list_kube_config_contexts()
if context_name is None:
context = current_context
else:
context = next((c for c in contexts if c['name'] == context_name),
None)
if context is None:
return DEFAULT_NAMESPACE

if 'namespace' in context['context']:
return context['context']['namespace']
else:
return DEFAULT_NAMESPACE
except k8s.config.config_exception.ConfigException:
Expand Down Expand Up @@ -1763,8 +1779,9 @@ def get_kubernetes_node_info() -> Dict[str, KubernetesNodeInfo]:


def get_namespace_from_config(provider_config: Dict[str, Any]) -> str:
context = get_context_from_config(provider_config)
return provider_config.get('namespace',
get_current_kube_config_context_namespace())
get_kube_config_context_namespace(context))


def get_context_from_config(provider_config: Dict[str, Any]) -> str:
Expand Down

0 comments on commit c35f176

Please sign in to comment.