From 4bc2e8a7248da3a368bbb684c9f7e44632adcd1a Mon Sep 17 00:00:00 2001 From: Caroline Date: Thu, 14 Dec 2023 16:23:18 -0500 Subject: [PATCH] small bugfixes for secrets --- runhouse/resources/secrets/secret.py | 27 ++++++++++---------- runhouse/resources/secrets/secret_factory.py | 7 +++-- tests/test_secret.py | 3 --- 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/runhouse/resources/secrets/secret.py b/runhouse/resources/secrets/secret.py index 17f908084..1954c9d8b 100644 --- a/runhouse/resources/secrets/secret.py +++ b/runhouse/resources/secrets/secret.py @@ -142,8 +142,8 @@ def extract_provider_secrets(cls, names: List[str] = None) -> Dict[str, "Secret" secrets = {} - # locally configured non-ssh provider secrets - for provider in _str_to_provider_class.keys(): + names = names or _str_to_provider_class.keys() + for provider in names: if provider == "ssh": continue try: @@ -153,17 +153,18 @@ def extract_provider_secrets(cls, names: List[str] = None) -> Dict[str, "Secret" continue # locally configured ssh secrets - default_ssh_folder = "~/.ssh" - ssh_files = os.listdir(os.path.expanduser(default_ssh_folder)) - for file in ssh_files: - if file != "sky-key" and f"{file}.pub" in ssh_files: - name = f"ssh-{file}" - secret = provider_secret( - provider="ssh", - name=name, - path=os.path.join(default_ssh_folder, file), - ) - secrets[name] = secret + if "ssh" in names: + default_ssh_folder = "~/.ssh" + ssh_files = os.listdir(os.path.expanduser(default_ssh_folder)) + for file in ssh_files: + if file != "sky-key" and f"{file}.pub" in ssh_files: + name = f"ssh-{file}" + secret = provider_secret( + provider="ssh", + name=name, + path=os.path.join(default_ssh_folder, file), + ) + secrets[name] = secret return secrets diff --git a/runhouse/resources/secrets/secret_factory.py b/runhouse/resources/secrets/secret_factory.py index 2f284a39e..081fcd535 100644 --- a/runhouse/resources/secrets/secret_factory.py +++ b/runhouse/resources/secrets/secret_factory.py @@ -78,8 +78,11 @@ def provider_secret( return Secret.from_name(name) elif not any([values, path, env_vars]): - secret_class = _get_provider_class(provider) - return secret_class(name=name, provider=provider, dryrun=dryrun) + if provider in Secret.builtin_providers(as_str=True): + secret_class = _get_provider_class(provider) + return secret_class(name=name, provider=provider, dryrun=dryrun) + else: + return ProviderSecret.from_name(name or provider) elif sum([bool(x) for x in [values, path, env_vars]]) == 1: secret_class = _get_provider_class(provider) diff --git a/tests/test_secret.py b/tests/test_secret.py index 6a387af26..09327d0e3 100644 --- a/tests/test_secret.py +++ b/tests/test_secret.py @@ -322,9 +322,6 @@ def test_convert_secret_resource(): headers=rns_client.request_headers, ) - with pytest.raises(ValueError): - load_config(name) - _convert_secrets_resource([name]) assert load_config(name)