diff --git a/metaflow/extension_support/plugins.py b/metaflow/extension_support/plugins.py
index 15d4c90bdcd..10c6f129ee7 100644
--- a/metaflow/extension_support/plugins.py
+++ b/metaflow/extension_support/plugins.py
@@ -179,6 +179,7 @@ def resolve_plugins(category):
     "metadata_provider": lambda x: x.TYPE,
     "datastore": lambda x: x.TYPE,
     "secrets_provider": lambda x: x.TYPE,
+    "azure_client_provider": lambda x: x.name,
     "sidecar": None,
     "logging_sidecar": None,
     "monitor_sidecar": None,
diff --git a/metaflow/plugins/__init__.py b/metaflow/plugins/__init__.py
index f36e04bf96e..a743d0a7073 100644
--- a/metaflow/plugins/__init__.py
+++ b/metaflow/plugins/__init__.py
@@ -123,6 +123,11 @@
     ),
 ]
 
+AZURE_CLIENT_PROVIDERS_DESC = [
+    ("azure-default", ".azure.azure_credential.AzureDefaultClientProvider")
+]
+
+
 process_plugins(globals())
 
 
@@ -144,6 +149,7 @@ def get_plugin_cli():
 
 AWS_CLIENT_PROVIDERS = resolve_plugins("aws_client_provider")
 SECRETS_PROVIDERS = resolve_plugins("secrets_provider")
+AZURE_CLIENT_PROVIDERS = resolve_plugins("azure_client_provider")
 
 from .cards.card_modules import MF_EXTERNAL_CARDS
 
diff --git a/metaflow/plugins/azure/__init__.py b/metaflow/plugins/azure/__init__.py
index e69de29bb2d..422dde8e7c1 100644
--- a/metaflow/plugins/azure/__init__.py
+++ b/metaflow/plugins/azure/__init__.py
@@ -0,0 +1,3 @@
+from .azure_credential import (
+    create_cacheable_azure_credential as create_azure_credential,
+)
diff --git a/metaflow/plugins/azure/azure_credential.py b/metaflow/plugins/azure/azure_credential.py
new file mode 100644
index 00000000000..e0bccb6423d
--- /dev/null
+++ b/metaflow/plugins/azure/azure_credential.py
@@ -0,0 +1,70 @@
+class AzureDefaultClientProvider(object):
+    """Azure client provider backed by azure.identity.DefaultAzureCredential."""
+
+    name = "azure-default"
+
+    @staticmethod
+    def create_cacheable_azure_credential(*args, **kwargs):
+        """azure.identity.DefaultAzureCredential is not readily cacheable in a dictionary
+        because it does not have a content based hash and equality implementations.
+
+        We implement a subclass CacheableDefaultAzureCredential to add them.
+
+        We need this because credentials will be part of the cache key in _ClientCache.
+        """
+        from azure.identity import DefaultAzureCredential
+
+        class CacheableDefaultAzureCredential(DefaultAzureCredential):
+            def __init__(self, *args, **kwargs):
+                super(CacheableDefaultAzureCredential, self).__init__(*args, **kwargs)
+                # Just hashing all the kwargs works because they are all individually
+                # hashable as of 7/15/2022.
+                #
+                # What if Azure adds unhashable things to kwargs?
+                # - We will have CI to catch this (it will always install the latest Azure SDKs)
+                # - In Metaflow usage today we never specify any kwargs anyway. (see last line
+                #   of the outer function.)
+                self._hash_code = hash((args, tuple(sorted(kwargs.items()))))
+
+            def __hash__(self):
+                return self._hash_code
+
+            def __eq__(self, other):
+                # This class is re-created on every call of the outer function, so
+                # compare the stored hash codes rather than using isinstance() on it.
+                # Calling hash(other) here would raise TypeError for unhashable
+                # objects and would treat any object with a colliding hash as equal.
+                if not isinstance(other, DefaultAzureCredential):
+                    return NotImplemented
+                return self._hash_code == getattr(other, "_hash_code", None)
+
+        return CacheableDefaultAzureCredential(*args, **kwargs)
+
+
+# Provider class resolved on first use by create_cacheable_azure_credential().
+cached_provider_class = None
+
+
+def create_cacheable_azure_credential(*args, **kwargs):
+    """Create a credential with the configured Azure client provider.
+
+    The provider is looked up by name (DEFAULT_AZURE_CLIENT_PROVIDER) in
+    AZURE_CLIENT_PROVIDERS on the first call and remembered for later calls.
+    Positional and keyword arguments are forwarded to the provider.
+
+    Raises ValueError if no registered provider has the configured name.
+    """
+    global cached_provider_class
+    if cached_provider_class is None:
+        from metaflow.metaflow_config import DEFAULT_AZURE_CLIENT_PROVIDER
+        from metaflow.plugins import AZURE_CLIENT_PROVIDERS
+
+        for p in AZURE_CLIENT_PROVIDERS:
+            if p.name == DEFAULT_AZURE_CLIENT_PROVIDER:
+                cached_provider_class = p
+                break
+        else:
+            raise ValueError(
+                "Cannot find Azure Client provider %s" % DEFAULT_AZURE_CLIENT_PROVIDER
+            )
+    return cached_provider_class.create_cacheable_azure_credential(*args, **kwargs)
diff --git a/metaflow/plugins/azure/azure_utils.py b/metaflow/plugins/azure/azure_utils.py
index 0f3f465a171..633804d1df7 100644
--- a/metaflow/plugins/azure/azure_utils.py
+++ b/metaflow/plugins/azure/azure_utils.py
@@ -7,6 +7,7 @@
     MetaflowAzurePackageError,
 )
 from metaflow.exception import MetaflowInternalError, MetaflowException
+from metaflow.plugins.azure.azure_credential import create_cacheable_azure_credential
 
 
 def _check_and_init_azure_deps():
@@ -138,38 +139,6 @@ def _inner_func(*args, **kwargs):
     return _inner_func
 
 
-@check_azure_deps
-def create_cacheable_default_azure_credentials(*args, **kwargs):
-    """azure.identity.DefaultAzureCredential is not readily cacheable in a dictionary
-    because it does not have a content based hash and equality implementations.
-
-    We implement a subclass CacheableDefaultAzureCredential to add them.
-
-    We need this because credentials will be part of the cache key in _ClientCache.
-    """
-    from azure.identity import DefaultAzureCredential
-
-    class CacheableDefaultAzureCredential(DefaultAzureCredential):
-        def __init__(self, *args, **kwargs):
-            super(CacheableDefaultAzureCredential, self).__init__(*args, **kwargs)
-            # Just hashing all the kwargs works because they are all individually
-            # hashable as of 7/15/2022.
-            #
-            # What if Azure adds unhashable things to kwargs?
-            # - We will have CI to catch this (it will always install the latest Azure SDKs)
-            # - In Metaflow usage today we never specify any kwargs anyway. (see last line
-            #   of the outer function.
-            self._hash_code = hash((args, tuple(sorted(kwargs.items()))))
-
-        def __hash__(self):
-            return self._hash_code
-
-        def __eq__(self, other):
-            return hash(self) == hash(other)
-
-    return CacheableDefaultAzureCredential(*args, **kwargs)
-
-
 @check_azure_deps
 def create_static_token_credential(token_):
     from azure.core.credentials import TokenCredential
@@ -200,9 +169,7 @@ def __init__(self, token):
 
         def get_token(self, *_scopes, **_kwargs):
             if (self._cached_token.expires_on - time.time()) < 300:
-                from azure.identity import DefaultAzureCredential
-
-                self._credential = DefaultAzureCredential()
+                self._credential = create_cacheable_azure_credential()
             if self._credential:
                 return self._credential.get_token(*_scopes, **_kwargs)
             return self._cached_token
diff --git a/metaflow/plugins/azure/blob_service_client_factory.py b/metaflow/plugins/azure/blob_service_client_factory.py
index 64cd04ebe5d..4897a8cbb05 100644
--- a/metaflow/plugins/azure/blob_service_client_factory.py
+++ b/metaflow/plugins/azure/blob_service_client_factory.py
@@ -1,9 +1,11 @@
 from metaflow.exception import MetaflowException
 from metaflow.metaflow_config import AZURE_STORAGE_BLOB_SERVICE_ENDPOINT
 from metaflow.plugins.azure.azure_utils import (
-    create_cacheable_default_azure_credentials,
     check_azure_deps,
 )
+from metaflow.plugins.azure.azure_credential import (
+    create_cacheable_azure_credential,
+)
 
 import os
 import threading
@@ -125,7 +127,7 @@ def get_azure_blob_service_client(
         blob_service_endpoint = AZURE_STORAGE_BLOB_SERVICE_ENDPOINT
 
     if not credential:
-        credential = create_cacheable_default_azure_credentials()
+        credential = create_cacheable_azure_credential()
         credential_is_cacheable = True
 
     if not credential_is_cacheable:
diff --git a/metaflow/plugins/datastores/azure_storage.py b/metaflow/plugins/datastores/azure_storage.py
index 0b3fe9ee787..7c8c528788a 100644
--- a/metaflow/plugins/datastores/azure_storage.py
+++ b/metaflow/plugins/datastores/azure_storage.py
@@ -32,6 +32,8 @@
     handle_executor_exceptions,
 )
 
+from metaflow.plugins.azure.azure_credential import create_cacheable_azure_credential
+
 AZURE_STORAGE_DOWNLOAD_MAX_CONCURRENCY = 4
 AZURE_STORAGE_UPLOAD_MAX_CONCURRENCY = 16
 
@@ -272,12 +274,19 @@ def _get_default_token(self):
         if not self._default_scope_token or (
             self._default_scope_token.expires_on - time.time() < 300
         ):
-            from azure.identity import DefaultAzureCredential
-
-            with DefaultAzureCredential() as credential:
+            credential = create_cacheable_azure_credential()
+            try:
                 self._default_scope_token = credential.get_token(
                     AZURE_STORAGE_DEFAULT_SCOPE
                 )
+            finally:
+                # The credential is only needed for this one token request;
+                # release its transport sessions like the previous `with`
+                # block did. Custom providers may return credentials
+                # without close(), so this is best effort.
+                close = getattr(credential, "close", None)
+                if close is not None:
+                    close()
         return self._default_scope_token
 
     @property