[Lambda Cloud] Multinode support (#1718)
* Multinode working

* Change lambda default to gpu_1x_a10

* Add internal ip to local machine

* Don't make SSHCommandRunner stream logs

* Add tag file refresh and update smoke tests

* Update cluster name limit

* Revert node timeout

* Increase timeout again and nit

* Nits

* Make Lambda smoke tests use A10

* Use get and set instead of __getitem__ and __setitem__

* Handle no ip case

* Update tests

* Format

* Update optimizer dryruns
ewzeng authored Mar 17, 2023
1 parent 75775d3 commit bb6429b
Showing 8 changed files with 148 additions and 197 deletions.
8 changes: 5 additions & 3 deletions sky/backends/backend_utils.py
@@ -1692,11 +1692,13 @@ def _query_status_lambda(
'terminated': None,
}
# TODO(ewzeng): filter by hash_filter_string to be safe
status_list = []
vms = lambda_utils.LambdaCloudClient().list_instances()
possible_names = [f'{cluster}-head', f'{cluster}-worker']
for node in vms:
if node['name'] == cluster:
return [status_map[node['status']]]
return []
if node.get('name') in possible_names:
status_list.append(status_map[node['status']])
return status_list


_QUERY_STATUS_FUNCS = {
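For reference, a minimal standalone sketch (not from the commit) of the new aggregation behavior in _query_status_lambda: instead of returning after the first name match, one status is collected per head/worker node. The status strings below are placeholders for the real sky status values, and the vms list is made up.

status_map = {'booting': 'INIT', 'active': 'UP', 'terminated': None}  # placeholder values

def query_status(cluster, vms):
    # One entry per matching node, rather than returning on the first match.
    possible_names = [f'{cluster}-head', f'{cluster}-worker']
    return [status_map[node['status']]
            for node in vms
            if node.get('name') in possible_names]

vms = [{'name': 'mycluster-head', 'status': 'active'},
       {'name': 'mycluster-worker', 'status': 'booting'}]
print(query_status('mycluster', vms))  # ['UP', 'INIT']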
11 changes: 9 additions & 2 deletions sky/backends/cloud_vm_ray_backend.py
@@ -65,7 +65,13 @@

# Timeout (seconds) for provision progress: if in this duration no new nodes
# are launched, abort and failover.
_NODES_LAUNCHING_PROGRESS_TIMEOUT = 90
_NODES_LAUNCHING_PROGRESS_TIMEOUT = {
clouds.AWS: 90,
clouds.Azure: 90,
clouds.GCP: 90,
clouds.Lambda: 120,
clouds.Local: 90,
}

# Time gap between retries after failing to provision in all possible places.
# Used only if --retry-until-up is set.
@@ -1599,7 +1605,8 @@ def need_ray_up(
cluster_config_file,
cluster_handle.launched_nodes,
log_path=log_abs_path,
nodes_launching_progress_timeout=_NODES_LAUNCHING_PROGRESS_TIMEOUT,
nodes_launching_progress_timeout=_NODES_LAUNCHING_PROGRESS_TIMEOUT[
type(to_provision_cloud)],
is_local_cloud=isinstance(to_provision_cloud, clouds.Local))
if cluster_ready:
cluster_status = self.GangSchedulingStatus.CLUSTER_READY
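A short sketch, outside the diff, of the per-cloud timeout lookup that need_ray_up now performs; the cloud classes here are stand-ins for the real sky.clouds classes.

class AWS: ...     # stand-in for clouds.AWS
class Lambda: ...  # stand-in for clouds.Lambda

_NODES_LAUNCHING_PROGRESS_TIMEOUT = {AWS: 90, Lambda: 120}

def timeout_for(cloud) -> int:
    # The backend indexes the dict by the concrete class of the cloud instance.
    return _NODES_LAUNCHING_PROGRESS_TIMEOUT[type(cloud)]

print(timeout_for(Lambda()))  # 120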
5 changes: 2 additions & 3 deletions sky/clouds/lambda_cloud.py
@@ -25,15 +25,14 @@ class Lambda(clouds.Cloud):

# Lambda has a 64 char limit for cluster name.
# Reference: https://cloud.lambdalabs.com/api/v1/docs#operation/launchInstance # pylint: disable=line-too-long
_MAX_CLUSTER_NAME_LEN_LIMIT = 64
# However, we need to account for the suffixes '-head' and '-worker'
_MAX_CLUSTER_NAME_LEN_LIMIT = 57
# Currently, none of clouds.CloudImplementationFeatures are implemented
# for Lambda Cloud.
# STOP/AUTOSTOP: The Lambda cloud provider does not support stopping VMs.
# MULTI_NODE: Multi-node is not supported by the implementation yet.
_CLOUD_UNSUPPORTED_FEATURES = {
clouds.CloudImplementationFeatures.STOP: 'Lambda cloud does not support stopping VMs.',
clouds.CloudImplementationFeatures.AUTOSTOP: 'Lambda cloud does not support stopping VMs.',
clouds.CloudImplementationFeatures.MULTI_NODE: 'Multi-node is not supported by the Lambda Cloud implementation yet.',
}

@classmethod
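The new limit of 57 leaves room for the longest node-name suffix that SkyPilot appends ('-worker', 7 characters) under Lambda's 64-character instance-name cap, as the quick check below illustrates.

LAMBDA_NAME_CAP = 64
suffixes = ['-head', '-worker']
assert LAMBDA_NAME_CAP - max(len(s) for s in suffixes) == 57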
14 changes: 3 additions & 11 deletions sky/clouds/service_catalog/lambda_catalog.py
@@ -14,7 +14,7 @@

_df = common.read_catalog('lambda/vms.csv')

# Number of vCPUS for gpu_1x_a100_sxm4
# Number of vCPUS for gpu_1x_a10
_DEFAULT_NUM_VCPUS = 30
_DEFAULT_MEMORY_CPU_RATIO = 4

@@ -69,16 +69,8 @@ def get_default_instance_type(cpus: Optional[str] = None,
memory_gb_or_ratio = f'{_DEFAULT_MEMORY_CPU_RATIO}x'
else:
memory_gb_or_ratio = memory

# Set to gpu_1x_a100_sxm4 to be the default instance type if match vCPU
# requirement.
df = _df[_df['InstanceType'].eq('gpu_1x_a100_sxm4')]
instance = common.get_instance_type_for_cpus_mem_impl(
df, cpus, memory_gb_or_ratio)
if not instance:
instance = common.get_instance_type_for_cpus_mem_impl(
_df, cpus, memory_gb_or_ratio)
return instance
return common.get_instance_type_for_cpus_mem_impl(_df, cpus,
memory_gb_or_ratio)


def get_accelerators_from_instance_type(
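To illustrate the simplified selection (the special-casing of gpu_1x_a100_sxm4 is gone, and the whole catalog is matched directly), here is a rough standalone sketch. The mini catalog and the 'vCPUs'/'MemoryGiB' column names are assumptions, and the real get_instance_type_for_cpus_mem_impl ranks candidates differently (e.g. by price); a smallest-sufficient heuristic is used here only for illustration.

import pandas as pd

catalog = pd.DataFrame({
    'InstanceType': ['gpu_1x_a10', 'gpu_1x_a6000', 'gpu_1x_a100_sxm4'],
    'vCPUs': [30, 14, 30],
    'MemoryGiB': [200, 100, 200],
})

def smallest_match(df, min_cpus, min_mem):
    # Keep rows that satisfy the request, then pick the smallest qualifying type.
    ok = df[(df['vCPUs'] >= min_cpus) & (df['MemoryGiB'] >= min_mem)]
    if ok.empty:
        return None
    return ok.sort_values(['vCPUs', 'MemoryGiB']).iloc[0]['InstanceType']

print(smallest_match(catalog, 8, 64))  # gpu_1x_a6000 with this toy catalog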
9 changes: 5 additions & 4 deletions sky/skylet/providers/lambda_cloud/lambda_utils.py
@@ -2,7 +2,7 @@
import os
import json
import requests
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

CREDENTIALS_PATH = '~/.lambda_cloud/lambda_keys'
API_ENDPOINT = 'https://cloud.lambdalabs.com/api/v1'
@@ -24,13 +24,14 @@ def __init__(self, path_prefix: str, cluster_name: str) -> None:
# In case parent directory does not exist
os.makedirs(os.path.dirname(self.path), exist_ok=True)

def __getitem__(self, instance_id: str) -> Dict[str, Any]:
assert os.path.exists(self.path), 'Metadata file not found'
def get(self, instance_id: str) -> Optional[Dict[str, Any]]:
if not os.path.exists(self.path):
return None
with open(self.path, 'r') as f:
metadata = json.load(f)
return metadata.get(instance_id)

def __setitem__(self, instance_id: str, value: Dict[str, Any]) -> None:
def set(self, instance_id: str, value: Optional[Dict[str, Any]]) -> None:
# Read from metadata file
if os.path.exists(self.path):
with open(self.path, 'r') as f:
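A hedged usage sketch of the new get/set contract: the metadata is a JSON dict keyed by instance id, a missing file or key now yields None instead of an assertion failure, and setting None deletes the entry. This is a standalone re-implementation for illustration, not the lambda_utils.Metadata class itself; the path and ids are made up.

import json, os

class MetadataSketch:
    def __init__(self, path: str) -> None:
        self.path = os.path.expanduser(path)
        os.makedirs(os.path.dirname(self.path), exist_ok=True)

    def get(self, instance_id):
        if not os.path.exists(self.path):
            return None  # Missing file simply means "no metadata yet".
        with open(self.path, 'r') as f:
            return json.load(f).get(instance_id)

    def set(self, instance_id, value) -> None:
        data = {}
        if os.path.exists(self.path):
            with open(self.path, 'r') as f:
                data = json.load(f)
        if value is None:
            data.pop(instance_id, None)  # Setting None removes the entry.
        else:
            data[instance_id] = value
        with open(self.path, 'w') as f:
            json.dump(data, f)

meta = MetadataSketch('/tmp/lambda_metadata_example.json')
meta.set('inst-123', {'tags': {'example-tag': 'head'}})
print(meta.get('inst-123'))  # {'tags': {'example-tag': 'head'}}
print(meta.get('inst-999'))  # None rather than an assertion error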
195 changes: 113 additions & 82 deletions sky/skylet/providers/lambda_cloud/node_provider.py
@@ -9,19 +9,21 @@
TAG_RAY_CLUSTER_NAME,
TAG_RAY_USER_NODE_TYPE,
TAG_RAY_NODE_NAME,
TAG_RAY_LAUNCH_CONFIG,
TAG_RAY_NODE_STATUS,
STATUS_UP_TO_DATE,
TAG_RAY_NODE_KIND,
NODE_KIND_WORKER,
NODE_KIND_HEAD,
)
from ray.autoscaler._private.util import hash_launch_conf
from sky.skylet.providers.lambda_cloud import lambda_utils
from sky.utils import common_utils
from sky import authentication as auth
from sky.utils import command_runner
from sky.utils import subprocess_utils
from sky.utils import ux_utils

TAG_PATH_PREFIX = '~/.sky/generated/lambda_cloud/metadata'
REMOTE_RAY_YAML = '~/ray_bootstrap_config.yaml'
_TAG_PATH_PREFIX = '~/.sky/generated/lambda_cloud/metadata'
_REMOTE_RAY_SSH_KEY = '~/ray_bootstrap_key.pem'
_GET_INTERNAL_IP_CMD = 'ip -4 -br addr show | grep -Eo "10\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"'

logger = logging.getLogger(__name__)

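The _GET_INTERNAL_IP_CMD pipeline pulls the first 10.x.x.x address out of `ip -4 -br addr show`, which assumes the node's private address sits in the 10.0.0.0/8 range. Below is the same pattern applied in Python to a made-up command output, purely to show what the regex selects.

import re

sample = (
    'lo        UNKNOWN  127.0.0.1/8\n'
    'enp5s0    UP       10.19.60.24/20\n'
    'docker0   DOWN     172.17.0.1/16\n'
)
octet = r'(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)'
pattern = rf'10\.{octet}\.{octet}\.{octet}'
match = re.search(pattern, sample)
print(match.group(0) if match else None)  # 10.19.60.24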
@@ -50,81 +52,101 @@ def __init__(self, provider_config: Dict[str, Any],
self.lock = RLock()
self.lambda_client = lambda_utils.LambdaCloudClient()
self.cached_nodes = {}
self.metadata = lambda_utils.Metadata(TAG_PATH_PREFIX, cluster_name)
vms = self._list_instances_in_cluster()

# The tag file for autodowned clusters is not autoremoved. Hence, if
# a previous cluster was autodowned and has the same name as the
# current cluster, then self.metadata might load the old tag file.
# We prevent this by removing any old vms in the tag file.
self.metadata.refresh([node['id'] for node in vms])

# If tag file does not exist on head, create it and add basic tags.
# This is a hack to make sure that ray on head can access some
# important tags.
# TODO(ewzeng): change when Lambda Cloud adds tag support.
ray_yaml_path = os.path.expanduser(REMOTE_RAY_YAML)
if os.path.exists(ray_yaml_path) and not os.path.exists(
self.metadata.path):
config = common_utils.read_yaml(ray_yaml_path)
# Ensure correct cluster so sky launch on head node works correctly
if config['cluster_name'] != cluster_name:
return
# Compute launch hash
head_node_config = config.get('head_node', {})
head_node_type = config.get('head_node_type')
if head_node_type:
head_config = config['available_node_types'][head_node_type]
head_node_config.update(head_config["node_config"])
launch_hash = hash_launch_conf(head_node_config, config['auth'])
# Populate tags
for node in vms:
self.metadata[node['id']] = {
'tags': {
TAG_RAY_CLUSTER_NAME: cluster_name,
TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
TAG_RAY_NODE_KIND: NODE_KIND_HEAD,
TAG_RAY_USER_NODE_TYPE: 'ray_head_default',
TAG_RAY_NODE_NAME: f'ray-{cluster_name}-head',
TAG_RAY_LAUNCH_CONFIG: launch_hash,
}
}
self.metadata = lambda_utils.Metadata(_TAG_PATH_PREFIX, cluster_name)
self.ssh_key_path = os.path.expanduser(auth.PRIVATE_SSH_KEY_PATH)
remote_ssh_key = os.path.expanduser(_REMOTE_RAY_SSH_KEY)
if os.path.exists(remote_ssh_key):
self.ssh_key_path = remote_ssh_key

def _guess_and_add_missing_tags(self, vms: Dict[str, Any]) -> None:
"""Adds missing vms to local tag file and guesses their tags."""
for node in vms:
if self.metadata.get(node['id']) is not None:
pass
elif node['name'] == f'{self.cluster_name}-head':
self.metadata.set(
node['id'], {
'tags': {
TAG_RAY_CLUSTER_NAME: self.cluster_name,
TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
TAG_RAY_NODE_KIND: NODE_KIND_HEAD,
TAG_RAY_USER_NODE_TYPE: 'ray_head_default',
TAG_RAY_NODE_NAME: f'ray-{self.cluster_name}-head',
}
})
elif node['name'] == f'{self.cluster_name}-worker':
self.metadata.set(
node['id'], {
'tags': {
TAG_RAY_CLUSTER_NAME: self.cluster_name,
TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
TAG_RAY_NODE_KIND: NODE_KIND_WORKER,
TAG_RAY_USER_NODE_TYPE: 'ray_worker_default',
TAG_RAY_NODE_NAME: f'ray-{self.cluster_name}-worker',
}
})

def _list_instances_in_cluster(self) -> Dict[str, Any]:
"""List running instances in cluster."""
vms = self.lambda_client.list_instances()
return [node for node in vms if node.get('name') == self.cluster_name]
possible_names = [
f'{self.cluster_name}-head', f'{self.cluster_name}-worker'
]
return [node for node in vms if node.get('name') in possible_names]

@synchronized
def _get_filtered_nodes(self, tag_filters: Dict[str,
str]) -> Dict[str, Any]:

def match_tags(vm):
vm_info = self.metadata[vm['id']]
def _extract_metadata(vm: Dict[str, Any]) -> Dict[str, Any]:
metadata = {'id': vm['id'], 'status': vm['status'], 'tags': {}}
instance_info = self.metadata.get(vm['id'])
if instance_info is not None:
metadata['tags'] = instance_info['tags']
with ux_utils.print_exception_no_traceback():
if 'ip' not in vm:
raise lambda_utils.LambdaCloudError(
'A node ip address was not found. Either '
'(1) Lambda Cloud has internally errored, or '
'(2) the cluster is still booting. '
'You can manually terminate the cluster on the '
'Lambda Cloud console or (in case 2) wait for '
'booting to finish (~2 minutes).')
metadata['external_ip'] = vm['ip']
return metadata

def _match_tags(vm: Dict[str, Any]):
vm_info = self.metadata.get(vm['id'])
tags = {} if vm_info is None else vm_info['tags']
for k, v in tag_filters.items():
if tags.get(k) != v:
return False
return True

def _get_internal_ip(node: Dict[str, Any]):
# TODO(ewzeng): cache internal ips in metadata file to reduce
# ssh overhead.
runner = command_runner.SSHCommandRunner(node['external_ip'],
'ubuntu',
self.ssh_key_path)
rc, stdout, stderr = runner.run(_GET_INTERNAL_IP_CMD,
require_outputs=True,
stream_logs=False)
subprocess_utils.handle_returncode(
rc,
_GET_INTERNAL_IP_CMD,
'Failed to obtain private IP from node',
stderr=stdout + stderr)
node['internal_ip'] = stdout.strip()

vms = self._list_instances_in_cluster()
nodes = [self._extract_metadata(vm) for vm in filter(match_tags, vms)]
self.metadata.refresh([node['id'] for node in vms])
self._guess_and_add_missing_tags(vms)
nodes = [_extract_metadata(vm) for vm in filter(_match_tags, vms)]
subprocess_utils.run_in_parallel(_get_internal_ip, nodes)
self.cached_nodes = {node['id']: node for node in nodes}
return self.cached_nodes

def _extract_metadata(self, vm: Dict[str, Any]) -> Dict[str, Any]:
metadata = {'id': vm['id'], 'status': vm['status'], 'tags': {}}
instance_info = self.metadata[vm['id']]
if instance_info is not None:
metadata['tags'] = instance_info['tags']
ip = vm['ip']
metadata['external_ip'] = ip
# TODO(ewzeng): The internal ip is hard to get, so set it to the
# external ip as a hack. This should be changed in the future.
# https://docs.lambdalabs.com/cloud/learn-private-ip-address/
metadata['internal_ip'] = ip
return metadata

def non_terminated_nodes(self, tag_filters: Dict[str, str]) -> List[str]:
"""Return a list of node ids filtered by the specified tags dict.
Expand Down Expand Up @@ -164,44 +186,53 @@ def internal_ip(self, node_id: str) -> str:
def create_node(self, node_config: Dict[str, Any], tags: Dict[str, str],
count: int) -> None:
"""Creates a number of nodes within the namespace."""
assert count == 1, count # Only support 1-node clusters for now

# get the tags
# Get tags
config_tags = node_config.get('tags', {}).copy()
config_tags.update(tags)
config_tags[TAG_RAY_CLUSTER_NAME] = self.cluster_name

# create the node
ttype = node_config['InstanceType']
# Create nodes
instance_type = node_config['InstanceType']
region = self.provider_config['region']
vm_list = self.lambda_client.create_instances(instance_type=ttype,
region=region,
quantity=1,
name=self.cluster_name)
assert len(vm_list) == 1, len(vm_list)
vm_id = vm_list[0]
self.metadata[vm_id] = {'tags': config_tags}

# Wait for booting to finish
# TODO(ewzeng): For multi-node, launch all vms first and then wait.
if config_tags[TAG_RAY_NODE_KIND] == NODE_KIND_HEAD:
name = f'{self.cluster_name}-head'
else:
name = f'{self.cluster_name}-worker'
# Lambda launch api only supports launching one node at a time,
# so we do a loop. Remove loop when launch api allows quantity > 1
booting_list = []
for _ in range(count):
vm_id = self.lambda_client.create_instances(
instance_type=instance_type,
region=region,
quantity=1,
name=name)[0]
self.metadata.set(vm_id, {'tags': config_tags})
booting_list.append(vm_id)
time.sleep(10) # Avoid api rate limits

# Wait for nodes to finish booting
while True:
vms = self.lambda_client.list_instances()
for vm in vms:
if vm['id'] == vm_id and vm['status'] == 'active':
return
vms = self._list_instances_in_cluster()
for vm_id in booting_list.copy():
for vm in vms:
if vm['id'] == vm_id and vm['status'] == 'active':
booting_list.remove(vm_id)
if len(booting_list) == 0:
return
time.sleep(10)

@synchronized
def set_node_tags(self, node_id: str, tags: Dict[str, str]) -> None:
"""Sets the tag values (string dict) for the specified node."""
node = self._get_node(node_id)
node['tags'].update(tags)
self.metadata[node_id] = {'tags': node['tags']}
self.metadata.set(node_id, {'tags': node['tags']})

def terminate_node(self, node_id: str) -> None:
"""Terminates the specified node."""
self.lambda_client.remove_instances(node_id)
self.metadata[node_id] = None
self.metadata.set(node_id, None)

def _get_node(self, node_id: str) -> Optional[Dict[str, Any]]:
self._get_filtered_nodes({}) # Side effect: updates cache
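For context, a hedged sketch of how Ray's autoscaler would exercise this provider for a one-head, one-worker Lambda cluster. It assumes the LambdaNodeProvider class defined in this file, a valid Lambda Cloud API key, and an illustrative region name; it is meant only to show the call sequence, not to be a definitive driver.

from ray.autoscaler.tags import (TAG_RAY_NODE_KIND, NODE_KIND_HEAD,
                                 NODE_KIND_WORKER)
from sky.skylet.providers.lambda_cloud.node_provider import LambdaNodeProvider

provider = LambdaNodeProvider(
    provider_config={'region': 'us-east-1'},  # region name is an assumption
    cluster_name='mycluster',
)

# Head first, then the worker; create_node loops internally because the Lambda
# launch API only accepts one instance per call.
provider.create_node(node_config={'InstanceType': 'gpu_1x_a10'},
                     tags={TAG_RAY_NODE_KIND: NODE_KIND_HEAD},
                     count=1)
provider.create_node(node_config={'InstanceType': 'gpu_1x_a10'},
                     tags={TAG_RAY_NODE_KIND: NODE_KIND_WORKER},
                     count=1)

for node_id in provider.non_terminated_nodes({}):
    print(node_id, provider.external_ip(node_id), provider.internal_ip(node_id))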
4 changes: 2 additions & 2 deletions tests/test_optimizer_dryruns.py
@@ -270,7 +270,7 @@ def test_instance_type_from_cpu_memory(monkeypatch, capfd):
assert 'r6i.2xlarge' in stdout # AWS, 8 vCPUs, 64 GB memory
assert 'Standard_E8_v5' in stdout # Azure, 8 vCPUs, 64 GB memory
assert 'n2-highmem-8' in stdout # GCP, 8 vCPUs, 64 GB memory
assert 'gpu_1x_a100_sxm4' in stdout # Lambda, 30 vCPUs, 200 GB memory
assert 'gpu_1x_a6000' in stdout # Lambda, 14 vCPUs, 100 GB memory

_test_resources_launch(monkeypatch, cpus='4+', memory='4+')
stdout, _ = capfd.readouterr()
@@ -280,7 +280,7 @@ def test_instance_type_from_cpu_memory(monkeypatch, capfd):
assert 'n2-highcpu-4' in stdout # GCP, 4 vCPUs, 4 GB memory
assert 'c6i.xlarge' in stdout # AWS, 4 vCPUs, 8 GB memory
assert 'Standard_F4s_v2' in stdout # Azure, 4 vCPUs, 8 GB memory
assert 'gpu_1x_a100_sxm4' in stdout # Lambda, 30 vCPUs, 200 GB memory
assert 'gpu_1x_rtx6000' in stdout # Lambda, 14 vCPUs, 46 GB memory

_test_resources_launch(monkeypatch, accelerators='T4')
stdout, _ = capfd.readouterr()