Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pluggable azure credentials provider #1756

Merged
merged 1 commit into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions metaflow/extension_support/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def resolve_plugins(category):
"metadata_provider": lambda x: x.TYPE,
"datastore": lambda x: x.TYPE,
"secrets_provider": lambda x: x.TYPE,
"azure_client_provider": lambda x: x.name,
"sidecar": None,
"logging_sidecar": None,
"monitor_sidecar": None,
Expand Down
6 changes: 6 additions & 0 deletions metaflow/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@
),
]

AZURE_CLIENT_PROVIDERS_DESC = [
("azure-default", ".azure.azure_credential.AzureDefaultClientProvider")
]


process_plugins(globals())


Expand All @@ -143,6 +148,7 @@ def get_plugin_cli():

AWS_CLIENT_PROVIDERS = resolve_plugins("aws_client_provider")
SECRETS_PROVIDERS = resolve_plugins("secrets_provider")
AZURE_CLIENT_PROVIDERS = resolve_plugins("azure_client_provider")

from .cards.card_modules import MF_EXTERNAL_CARDS

Expand Down
3 changes: 3 additions & 0 deletions metaflow/plugins/azure/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .azure_credential import (
create_cacheable_azure_credential as create_azure_credential,
)
53 changes: 53 additions & 0 deletions metaflow/plugins/azure/azure_credential.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
class AzureDefaultClientProvider(object):
name = "azure-default"

@staticmethod
def create_cacheable_azure_credential(*args, **kwargs):
"""azure.identity.DefaultAzureCredential is not readily cacheable in a dictionary
because it does not have a content based hash and equality implementations.

We implement a subclass CacheableDefaultAzureCredential to add them.

We need this because credentials will be part of the cache key in _ClientCache.
"""
from azure.identity import DefaultAzureCredential

class CacheableDefaultAzureCredential(DefaultAzureCredential):
def __init__(self, *args, **kwargs):
super(CacheableDefaultAzureCredential, self).__init__(*args, **kwargs)
# Just hashing all the kwargs works because they are all individually
# hashable as of 7/15/2022.
#
# What if Azure adds unhashable things to kwargs?
# - We will have CI to catch this (it will always install the latest Azure SDKs)
# - In Metaflow usage today we never specify any kwargs anyway. (see last line
# of the outer function.
self._hash_code = hash((args, tuple(sorted(kwargs.items()))))

def __hash__(self):
return self._hash_code

def __eq__(self, other):
return hash(self) == hash(other)

return CacheableDefaultAzureCredential(*args, **kwargs)


cached_provider_class = None


def create_cacheable_azure_credential():
global cached_provider_class
if cached_provider_class is None:
from metaflow.metaflow_config import DEFAULT_AZURE_CLIENT_PROVIDER
from metaflow.plugins import AZURE_CLIENT_PROVIDERS

for p in AZURE_CLIENT_PROVIDERS:
if p.name == DEFAULT_AZURE_CLIENT_PROVIDER:
cached_provider_class = p
break
else:
raise ValueError(
"Cannot find Azure Client provider %s" % DEFAULT_AZURE_CLIENT_PROVIDER
)
return cached_provider_class.create_cacheable_azure_credential()
37 changes: 2 additions & 35 deletions metaflow/plugins/azure/azure_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
MetaflowAzurePackageError,
)
from metaflow.exception import MetaflowInternalError, MetaflowException
from metaflow.plugins.azure.azure_credential import create_cacheable_azure_credential


def _check_and_init_azure_deps():
Expand Down Expand Up @@ -138,38 +139,6 @@ def _inner_func(*args, **kwargs):
return _inner_func


@check_azure_deps
def create_cacheable_default_azure_credentials(*args, **kwargs):
"""azure.identity.DefaultAzureCredential is not readily cacheable in a dictionary
because it does not have a content based hash and equality implementations.

We implement a subclass CacheableDefaultAzureCredential to add them.

We need this because credentials will be part of the cache key in _ClientCache.
"""
from azure.identity import DefaultAzureCredential

class CacheableDefaultAzureCredential(DefaultAzureCredential):
def __init__(self, *args, **kwargs):
super(CacheableDefaultAzureCredential, self).__init__(*args, **kwargs)
# Just hashing all the kwargs works because they are all individually
# hashable as of 7/15/2022.
#
# What if Azure adds unhashable things to kwargs?
# - We will have CI to catch this (it will always install the latest Azure SDKs)
# - In Metaflow usage today we never specify any kwargs anyway. (see last line
# of the outer function.
self._hash_code = hash((args, tuple(sorted(kwargs.items()))))

def __hash__(self):
return self._hash_code

def __eq__(self, other):
return hash(self) == hash(other)

return CacheableDefaultAzureCredential(*args, **kwargs)


@check_azure_deps
def create_static_token_credential(token_):
from azure.core.credentials import TokenCredential
Expand Down Expand Up @@ -200,9 +169,7 @@ def __init__(self, token):
def get_token(self, *_scopes, **_kwargs):

if (self._cached_token.expires_on - time.time()) < 300:
from azure.identity import DefaultAzureCredential

self._credential = DefaultAzureCredential()
self._credential = create_cacheable_azure_credential()
if self._credential:
return self._credential.get_token(*_scopes, **_kwargs)
return self._cached_token
Expand Down
6 changes: 4 additions & 2 deletions metaflow/plugins/azure/blob_service_client_factory.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from metaflow.exception import MetaflowException
from metaflow.metaflow_config import AZURE_STORAGE_BLOB_SERVICE_ENDPOINT
from metaflow.plugins.azure.azure_utils import (
create_cacheable_default_azure_credentials,
check_azure_deps,
)
from metaflow.plugins.azure.azure_credential import (
create_cacheable_azure_credential,
)

import os
import threading
Expand Down Expand Up @@ -125,7 +127,7 @@ def get_azure_blob_service_client(
blob_service_endpoint = AZURE_STORAGE_BLOB_SERVICE_ENDPOINT

if not credential:
credential = create_cacheable_default_azure_credentials()
credential = create_cacheable_azure_credential()
credential_is_cacheable = True

if not credential_is_cacheable:
Expand Down
12 changes: 6 additions & 6 deletions metaflow/plugins/datastores/azure_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
handle_executor_exceptions,
)

from metaflow.plugins.azure.azure_credential import create_cacheable_azure_credential

AZURE_STORAGE_DOWNLOAD_MAX_CONCURRENCY = 4
AZURE_STORAGE_UPLOAD_MAX_CONCURRENCY = 16

Expand Down Expand Up @@ -266,12 +268,10 @@ def _get_default_token(self):
if not self._default_scope_token or (
self._default_scope_token.expires_on - time.time() < 300
):
from azure.identity import DefaultAzureCredential

with DefaultAzureCredential() as credential:
self._default_scope_token = credential.get_token(
AZURE_STORAGE_DEFAULT_SCOPE
)
credential = create_cacheable_azure_credential()
self._default_scope_token = credential.get_token(
AZURE_STORAGE_DEFAULT_SCOPE
)
return self._default_scope_token

@property
Expand Down
Loading