Skip to content

Commit

Permalink
support sky env vars for private docker registries
Browse files Browse the repository at this point in the history
  • Loading branch information
carolineechen committed Jun 6, 2024
1 parent b21a8eb commit c74bbed
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
5 changes: 5 additions & 0 deletions runhouse/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@

EMPTY_DEFAULT_ENV_NAME = "_cluster_default_env"
DEFAULT_DOCKER_CONTAINER_NAME = "sky_container"
DOCKER_LOGIN_ENV_VARS = {
"SKYPILOT_DOCKER_USERNAME",
"SKYPILOT_DOCKER_PASSWORD",
"SKYPILOT_DOCKER_SERVER",
}

# Constants for the status check
DOUBLE_SPACE_UNICODE = "\u00A0\u00A0"
Expand Down
1 change: 1 addition & 0 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,6 +1400,7 @@ def _run_commands_with_ssh(
stream_logs=stream_logs,
port_forward=port_forward,
ssh_mode=ssh_mode,
quiet_ssh=True,
)
return_codes.append(ret_code)
else:
Expand Down
21 changes: 16 additions & 5 deletions runhouse/resources/hardware/on_demand_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
DEFAULT_HTTP_PORT,
DEFAULT_HTTPS_PORT,
DEFAULT_SERVER_PORT,
DOCKER_LOGIN_ENV_VARS,
LOCAL_HOSTS,
)

Expand Down Expand Up @@ -445,6 +446,15 @@ def up(self):
use_spot=self.use_spot,
)
)
if self.image_id:
import os

docker_env_vars = {}
for env_var in DOCKER_LOGIN_ENV_VARS:
if os.getenv(env_var):
docker_env_vars[env_var] = os.getenv(env_var)
if docker_env_vars:
task.update_envs(docker_env_vars)
sky.launch(
task,
cluster_name=self.name,
Expand Down Expand Up @@ -566,10 +576,11 @@ def ssh(self, node: str = None):
ssh_private_key=str(sky_key),
docker_user=self.docker_user,
)
cmd = runner.run(
cmd="bash --rcfile <(echo '. ~/.bashrc; conda deactivate')",
ssh_mode=SshMode.INTERACTIVE,
port_forward=None,
return_cmd=True,
ssh_command = runner._ssh_base_command(
ssh_mode=SshMode.INTERACTIVE, port_forward=None
)
subprocess.run(
" ".join(ssh_command),
shell=True,
)
subprocess.run(cmd, shell=True)

0 comments on commit c74bbed

Please sign in to comment.