Skip to content

Commit

Permalink
Fix cluster list release tests for k8s
Browse files (browse the repository at this point in the history)
  • Loading branch information
Alexandra Belousov authored and Alexandra Belousov committed Nov 13, 2024
1 parent 9ff5428 commit 043f692
Showing 1 changed file with 65 additions and 61 deletions.
126 changes: 65 additions & 61 deletions runhouse/resources/hardware/cluster_factory.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -262,15 +262,17 @@ def kubernetes_cluster(
except subprocess.CalledProcessError as e:
logger.info(f"Error setting context: {e}")

c = OnDemandCluster(
name=name,
instance_type=instance_type,
provider="kubernetes",
launcher_type=launcher_type,
server_connection_type=server_connection_type,
**kwargs,
)
c.set_connection_defaults()
try:
c = Cluster.from_name(name=name)
except ValueError:
c = OnDemandCluster(
name=name,
instance_type=instance_type,
provider="kubernetes",
launcher_type=launcher_type,
server_connection_type=server_connection_type,
**kwargs,
)

return c

Expand Down Expand Up @@ -402,7 +404,7 @@ def ondemand_cluster(
server_connection_type = kwargs.pop("server_connection_type", None)
default_env = kwargs.pop("default_env", None)

return kubernetes_cluster(
c = kubernetes_cluster(
name=name,
instance_type=instance_type,
namespace=namespace,
Expand Down Expand Up @@ -431,72 +433,74 @@ def ondemand_cluster(
**kwargs,
)

if name:
alt_options = dict(
else:
if name:
alt_options = dict(
instance_type=instance_type,
num_nodes=num_nodes,
provider=provider,
region=region,
image_id=image_id,
memory=memory,
disk_size=disk_size,
open_ports=open_ports,
server_host=server_host,
server_port=server_port,
server_connection_type=server_connection_type,
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
domain=domain,
den_auth=den_auth,
default_env=default_env,
)
# Filter out None/default values
alt_options = {k: v for k, v in alt_options.items() if v is not None}
try:
c = Cluster.from_name(
name,
load_from_den=load_from_den,
dryrun=dryrun,
_alt_options=alt_options,
)
if c:
c.set_connection_defaults()
if den_auth:
c.save()
return c
except ValueError as e:
if launcher_type == LauncherType.LOCAL:
import sky

state = sky.status(cluster_names=[name], refresh=False)
if len(state) == 0 and not alt_options:
raise e

c = OnDemandCluster(
instance_type=instance_type,
num_nodes=num_nodes,
provider=provider,
region=region,
num_nodes=num_nodes,
autostop_mins=autostop_mins,
use_spot=use_spot,
image_id=image_id,
region=region,
memory=memory,
disk_size=disk_size,
open_ports=open_ports,
sky_kwargs=sky_kwargs,
server_host=server_host,
server_port=server_port,
server_connection_type=server_connection_type,
launcher_type=launcher_type,
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
domain=domain,
den_auth=den_auth,
default_env=default_env,
name=name,
dryrun=dryrun,
**kwargs,
)
# Filter out None/default values
alt_options = {k: v for k, v in alt_options.items() if v is not None}
try:
c = Cluster.from_name(
name,
load_from_den=load_from_den,
dryrun=dryrun,
_alt_options=alt_options,
)
if c:
c.set_connection_defaults()
if den_auth:
c.save()
return c
except ValueError as e:
if launcher_type == LauncherType.LOCAL:
import sky

state = sky.status(cluster_names=[name], refresh=False)
if len(state) == 0 and not alt_options:
raise e

c = OnDemandCluster(
instance_type=instance_type,
provider=provider,
num_nodes=num_nodes,
autostop_mins=autostop_mins,
use_spot=use_spot,
image_id=image_id,
region=region,
memory=memory,
disk_size=disk_size,
open_ports=open_ports,
sky_kwargs=sky_kwargs,
server_host=server_host,
server_port=server_port,
server_connection_type=server_connection_type,
launcher_type=launcher_type,
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
domain=domain,
den_auth=den_auth,
default_env=default_env,
name=name,
dryrun=dryrun,
**kwargs,
)

c.set_connection_defaults()

if den_auth or rns_client.autosave_resources():
Expand Down

0 comments on commit 043f692

Please sign in to comment.