Skip to content

Commit

Permalink
Allow passing Sky kwargs to ondemand_cluster
Browse files Browse the repository at this point in the history
  • Loading branch information
dongreenberg committed Jul 9, 2024
1 parent f53d4b6 commit a1c95c9
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 29 deletions.
9 changes: 9 additions & 0 deletions runhouse/resources/hardware/cluster_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def ondemand_cluster(
memory: Union[int, str, None] = None,
disk_size: Union[int, str, None] = None,
open_ports: Union[int, str, List[int], None] = None,
sky_kwargs: Dict = None,
server_port: int = None,
server_host: int = None,
server_connection_type: Union[ServerConnectionType, str] = None,
Expand Down Expand Up @@ -345,6 +346,14 @@ def ondemand_cluster(
disk_size (int or str, optional): Amount of disk space to use for the cluster, e.g. "100" or "100+".
open_ports (int or str or List[int], optional): Ports to open in the cluster's security group. Note
that you are responsible for ensuring that the applications listening on these ports are secure.
sky_kwargs (dict, optional): Additional keyword arguments to pass to the SkyPilot `Resource` or
`launch` APIs. Should be a dict of the form
`{"resources": {<resources_kwargs>}, "launch": {<launch_kwargs>}}`, where resources_kwargs and
launch_kwargs will be passed to the SkyPilot Resources API
(See `SkyPilot docs <https://skypilot.readthedocs.io/en/latest/reference/api.html#resources>`_)
and `launch` API (See
`SkyPilot docs <https://skypilot.readthedocs.io/en/latest/reference/api.html#sky-launch>`_), respectively.
Any arguments which duplicate those passed to the `ondemand_cluster` factory method will raise an error.
server_port (bool, optional): Port to use for the server. If not provided will use 80 for a
``server_connection_type`` of ``none``, 443 for ``tls`` and ``32300`` for all other SSH connection types.
server_host (bool, optional): Host from which the server listens for traffic (i.e. the --host argument
Expand Down
75 changes: 46 additions & 29 deletions runhouse/resources/hardware/on_demand_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
domain: str = None,
den_auth: bool = False,
region=None,
sky_kwargs: Dict = None,
**kwargs, # We have this here to ignore extra arguments when calling from from_config
):
"""
Expand Down Expand Up @@ -97,6 +98,7 @@ def __init__(
self.region = region
self.memory = memory
self.disk_size = disk_size
self.sky_kwargs = sky_kwargs or {}

self.stable_internal_external_ips = kwargs.get(
"stable_internal_external_ips", None
Expand Down Expand Up @@ -156,6 +158,9 @@ def config(self, condensed=True):
"image_id",
"region",
"stable_internal_external_ips",
"memory",
"disk_size",
"sky_kwargs",
],
)
config["autostop_mins"] = self._autostop_mins
Expand Down Expand Up @@ -433,36 +438,48 @@ def up(self):
if self.provider != "cheapest"
else None
)
task.set_resources(
sky.Resources(
# TODO: confirm if passing instance type in old way (without --) works when provider is k8s
cloud=cloud_provider,
instance_type=self.get_instance_type(),
accelerators=self.accelerators(),
cpus=self.num_cpus(),
memory=self.memory,
region=self.region or configs.get("default_region"),
disk_size=self.disk_size,
ports=self.open_ports,
image_id=self.image_id,
use_spot=self.use_spot,
try:
task.set_resources(
sky.Resources(
# TODO: confirm if passing instance type in old way (without --) works when provider is k8s
cloud=cloud_provider,
instance_type=self.get_instance_type(),
accelerators=self.accelerators(),
cpus=self.num_cpus(),
memory=self.memory,
region=self.region or configs.get("default_region"),
disk_size=self.disk_size,
ports=self.open_ports,
image_id=self.image_id,
use_spot=self.use_spot,
**self.sky_kwargs.get("resources", {}),
)
)
)
if self.image_id:
import os

docker_env_vars = {}
for env_var in DOCKER_LOGIN_ENV_VARS:
if os.getenv(env_var):
docker_env_vars[env_var] = os.getenv(env_var)
if docker_env_vars:
task.update_envs(docker_env_vars)
sky.launch(
task,
cluster_name=self.name,
idle_minutes_to_autostop=self._autostop_mins,
down=True,
)
if self.image_id:
import os

docker_env_vars = {}
for env_var in DOCKER_LOGIN_ENV_VARS:
if os.getenv(env_var):
docker_env_vars[env_var] = os.getenv(env_var)
if docker_env_vars:
task.update_envs(docker_env_vars)
sky.launch(
task,
cluster_name=self.name,
idle_minutes_to_autostop=self._autostop_mins,
down=True,
**self.sky_kwargs.get("launch", {}),
)
# Make sure no args are passed both in sky_kwargs and as explicit args
except TypeError as e:
if "got multiple values for keyword argument" in str(e):
raise TypeError(
f"{str(e)}. If argument is in `sky_kwargs`, it may need to be passed directly through the "
f"ondemand_cluster constructor (see `ondemand_cluster docs "
f"<https://www.run.house/docs/api/python/cluster#runhouse.ondemand_cluster>`_)."
)
raise e

self._update_from_sky_status()

Expand Down

0 comments on commit a1c95c9

Please sign in to comment.