Skip to content

Commit

Permalink
Use env vars in default env creation (#798)
Browse files Browse the repository at this point in the history
  • Loading branch information
carolineechen authored May 19, 2024
1 parent ec9e9a0 commit 2529b79
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 25 deletions.
4 changes: 3 additions & 1 deletion runhouse/resources/envs/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,9 @@ def _run_setup_cmds(self, cluster: Cluster = None, setup_cmds: List = None):

for cmd in setup_cmds:
cmd = f"{self._run_cmd} {cmd}" if self._run_cmd else cmd
run_setup_command(cmd, cluster=cluster)
run_setup_command(
cmd, cluster=cluster, env_vars=_process_env_vars(self.env_vars)
)

def install(self, force: bool = False, cluster: Cluster = None):
"""Locally install packages and run setup commands."""
Expand Down
11 changes: 9 additions & 2 deletions runhouse/resources/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,12 @@ def run_with_logs(cmd: str, **kwargs):
return p.returncode


def run_setup_command(cmd: str, cluster: "Cluster" = None, stream_logs: bool = False):
def run_setup_command(
cmd: str,
cluster: "Cluster" = None,
env_vars: Dict = None,
stream_logs: bool = False,
):
"""
Helper function to run a command during possibly the cluster default env setup. If a cluster is provided,
run command on the cluster using SSH. If the cluster is not provided, run locally, as if already on the
Expand All @@ -188,7 +193,9 @@ def run_setup_command(cmd: str, cluster: "Cluster" = None, stream_logs: bool = F
"""
if not cluster:
return run_with_logs(cmd, stream_logs=stream_logs, require_outputs=True)[:2]
return cluster._run_commands_with_ssh([cmd], stream_logs=stream_logs)[0]
return cluster._run_commands_with_ssh(
[cmd], stream_logs=stream_logs, env_vars=env_vars
)[0]


def install_conda(cluster: "Cluster" = None):
Expand Down
61 changes: 39 additions & 22 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,9 +848,18 @@ def restart_server(
)
)

status_codes = self.run(
commands=[cmd], env=self._default_env, node=self.address
)
if self.on_this_cluster():
status_codes = self.run(
commands=[cmd], env=self._default_env, node=self.address
)
else:
status_codes = self._run_commands_with_ssh(
commands=[cmd],
cmd_prefix=self._default_env._run_cmd if self._default_env else "",
env_vars=self._default_env.env_vars if self._default_env else {},
node=self.address,
)

if not status_codes[0][0] == 0:
raise ValueError(f"Failed to restart server {self.name}")

Expand Down Expand Up @@ -1261,11 +1270,11 @@ def run(
# If not creating a Run then just run the commands via SSH and return
return self._run_commands_with_ssh(
commands,
cmd_prefix,
stream_logs,
node,
port_forward,
require_outputs,
cmd_prefix=cmd_prefix,
stream_logs=stream_logs,
node=node,
port_forward=port_forward,
require_outputs=require_outputs,
_ssh_mode=_ssh_mode,
)

Expand All @@ -1275,11 +1284,11 @@ def run(
with run(name=run_name, cmds=commands, overwrite=True) as r:
return_codes = self._run_commands_with_ssh(
commands,
cmd_prefix,
stream_logs,
node,
port_forward,
require_outputs,
cmd_prefix=cmd_prefix,
stream_logs=stream_logs,
node=node,
port_forward=port_forward,
require_outputs=require_outputs,
)

# Register the completed Run
Expand All @@ -1290,6 +1299,7 @@ def run(
def _run_commands_with_ssh(
self,
commands: list,
env_vars: Dict = {},
cmd_prefix: str = "",
stream_logs: bool = True,
node: str = None,
Expand Down Expand Up @@ -1317,10 +1327,20 @@ def _run_commands_with_ssh(
port=self.ssh_port,
)

if not pwd:
for command in commands:
command = f"{cmd_prefix} {command}" if cmd_prefix else command
logger.info(f"Running command on {self.name}: {command}")
env_var_prefix = (
" ".join(f"{key}={val}" for key, val in env_vars.items())
if env_vars
else ""
)

for command in commands:
command = f"{cmd_prefix} {command}" if cmd_prefix else command
logger.info(f"Running command on {self.name}: {command}")

# set env vars after log statement
command = f"{env_var_prefix} {command}" if env_var_prefix else command

if not pwd:
ssh_mode = (
SshMode.INTERACTIVE
if _ssh_mode == "interactive"
Expand All @@ -1340,12 +1360,9 @@ def _run_commands_with_ssh(
ssh_mode=ssh_mode,
)
return_codes.append(ret_code)
else:
import pexpect
else:
import pexpect

for command in commands:
command = f"{cmd_prefix} {command}" if cmd_prefix else command
logger.info(f"Running command on {self.name}: {command}")
# We need to quiet the SSH output here or it will print
# "Shared connection to ____ closed." at the end, which messes with the output.
ssh_command = runner.run(
Expand Down

0 comments on commit 2529b79

Please sign in to comment.