Skip to content

Commit

Permalink
Restructure cluster.run to be more clear.
Browse files Browse the repository at this point in the history
  • Loading branch information
rohinb2 committed Jun 25, 2024
1 parent f465886 commit 9060472
Showing 1 changed file with 31 additions and 43 deletions.
74 changes: 31 additions & 43 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

import yaml

from runhouse.resources.envs.utils import run_with_logs

from runhouse.rns.utils.api import ResourceAccess, ResourceVisibility
from runhouse.servers.http.certs import TLSCertConfig
from runhouse.utils import locate_working_dir
Expand Down Expand Up @@ -1258,25 +1256,10 @@ def run(
if isinstance(env, Env) and not env.name:
env = self._default_env
env = env or self.default_env
env = _get_env_from(env)

if node == "all":
res_list = []
for node in self.ips:
res = self.run(
commands=commands,
env=env,
stream_logs=stream_logs,
port_forward=port_forward,
require_outputs=require_outputs,
node=node,
_ssh_mode=_ssh_mode,
)
res_list.append(res)
return res_list

# TODO [DG] suspend autostop while running

if env and not port_forward and not node:
# If neither a node nor port forwarding is specified, we just use the normal logic, knowing that we are likely on the head node
if not node and not port_forward:
env_name = (
env
if isinstance(env, str)
Expand All @@ -1296,33 +1279,38 @@ def run(
return_codes.append(ret_code)
return return_codes

env = _get_env_from(env)
# A node is specified or port forwarding is requested, so we do everything via ssh
else:
if node == "all":
res_list = []
for node in self.ips:
res = self.run(
commands=commands,
env=env,
stream_logs=stream_logs,
port_forward=port_forward,
require_outputs=require_outputs,
node=node,
_ssh_mode=_ssh_mode,
)
res_list.append(res)
return res_list

if self.on_this_cluster():
return_codes = []
location_str = "locally" if not self.name else f"on {self.name}"
for command in commands:
command = env._full_command(command)
logger.info(f"Running command {location_str}: {command}")
ret_code = run_with_logs(
command, stream_logs=stream_logs, require_outputs=require_outputs
)
return_codes.append(ret_code)
return return_codes
else:

full_commands = [env._full_command(cmd) for cmd in commands]
full_commands = [env._full_command(cmd) for cmd in commands]

# Create and save the Run locally
return_codes = self._run_commands_with_ssh(
full_commands,
cmd_prefix="",
stream_logs=stream_logs,
node=node,
port_forward=port_forward,
require_outputs=require_outputs,
)
# Create and save the Run locally
return_codes = self._run_commands_with_ssh(
full_commands,
cmd_prefix="",
stream_logs=stream_logs,
node=node,
port_forward=port_forward,
require_outputs=require_outputs,
)

return return_codes
return return_codes

def _run_commands_with_ssh(
self,
Expand Down

0 comments on commit 9060472

Please sign in to comment.