Skip to content

Commit

Permalink
default ssh provider updates
Browse files Browse the repository at this point in the history
  • Loading branch information
jlewitt1 committed Nov 3, 2024
1 parent 7d21e90 commit 0bd96b0
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
19 changes: 16 additions & 3 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,23 +275,36 @@ def delete_configs(self, delete_creds: bool = False):
super().delete_configs()

def _setup_creds(self, ssh_creds: Union[Dict, "Secret", str]):
"""Setup cluster credentials from user provided ssh_creds"""
"""Setup cluster credentials from user provided ssh_creds. If no creds are provided, try using
the default SSH creds saved in Den."""
from runhouse.resources.secrets import Secret
from runhouse.resources.secrets.provider_secrets.sky_secret import SkySecret
from runhouse.resources.secrets.provider_secrets.ssh_secret import SSHSecret

if not hasattr(self, "_creds"):
self._creds = None

if not ssh_creds:
return
elif isinstance(ssh_creds, Secret):
self._creds = ssh_creds
return

elif isinstance(ssh_creds, str):
self._creds = Secret.from_name(ssh_creds)
return

if not ssh_creds:
from runhouse import ProviderSecret

try:
# Use the default ssh creds if saved in Den
ssh_secret = ProviderSecret.from_name("ssh")
self._creds = ssh_secret
return
except ValueError:
pass

return

creds = (
copy.copy(ssh_creds) if isinstance(ssh_creds, Dict) else (ssh_creds or {})
)
Expand Down
3 changes: 2 additions & 1 deletion runhouse/resources/secrets/provider_secrets/ssh_secret.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def save(
if name:
self.name = name
elif not self.name:
self.name = f"ssh-{self.key}"
# If name not provided treat as the "default" SSH secret
self.name = self._PROVIDER
return super().save(
save_values=save_values,
headers=headers or rns_client.request_headers(),
Expand Down
12 changes: 9 additions & 3 deletions runhouse/resources/secrets/secret_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def secret(
name (str, optional): Name to assign the secret resource.
values (Dict, optional): Dictionary of secret key-value pairs.
load_from_den (bool): Whether to try loading the secret from Den. (Default: ``True``)
provider (str): Provider corresponding to the secret.
dryrun (bool, optional): Whether to create in dryrun mode. (Default: False)
Returns:
Secret: The resulting Secret object.
Expand Down Expand Up @@ -81,13 +82,18 @@ def provider_secret(
return Secret.from_name(name, load_from_den=load_from_den)

elif not any([values, path, env_vars]):
# try reloading by name or provider
try:
return ProviderSecret.from_name(
name or provider, load_from_den=load_from_den
)
except ValueError:
pass
if provider in Secret.builtin_providers(as_str=True):
secret_class = _get_provider_class(provider)
return secret_class(name=name, provider=provider, dryrun=dryrun)
else:
return ProviderSecret.from_name(
name or provider, load_from_den=load_from_den
)
raise ValueError(f"Provider {provider} not recognized.")

elif sum([bool(x) for x in [values, path, env_vars]]) == 1:
secret_class = _get_provider_class(provider)
Expand Down

0 comments on commit 0bd96b0

Please sign in to comment.