Skip to content

Commit

Permalink
[Identity] Fix issue with cache option usage
Browse files Browse the repository at this point in the history
If a user supplies `TokenCachePersistenceOptions` to
a `SharedTokenCacheCredential`, these options are not
used when the cache is loaded. This can lead to issues
when users are trying to use caches with custom names
since the default name is used instead.

This commit ensures that user-provided cache options
are propagated when the cache is loaded.

Ref: Azure#26982

Signed-off-by: Paul Van Eck <paulvaneck@microsoft.com>
  • Loading branch information
pvaneck committed Oct 25, 2022
1 parent cdf13d2 commit 41946bb
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self, authentication_record, **kwargs):
self._tenant_id = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id
validate_tenant_id(self._tenant_id)
self._cache = kwargs.pop("_cache", None)
self._cache_persistence_options = kwargs.pop("cache_persistence_options", None)
self._client_applications = {} # type: Dict[str, PublicClientApplication]
self._additionally_allowed_tenants = kwargs.pop("additionally_allowed_tenants", [])
self._client = MsalClient(**kwargs)
Expand Down Expand Up @@ -64,10 +65,13 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
def _initialize(self):
if not self._cache and platform.system() in {"Darwin", "Linux", "Windows"}:
try:
# This credential accepts the user's default cache regardless of whether it's encrypted. It doesn't
# create a new cache. If the default cache exists, the user must have created it earlier. If it's
# unencrypted, the user must have allowed that.
self._cache = _load_persistent_cache(TokenCachePersistenceOptions(allow_unencrypted_storage=True))
# If no cache options were provided, the default cache will be used. This credential accepts the
# user's default cache regardless of whether it's encrypted. It doesn't create a new cache. If the
# default cache exists, the user must have created it earlier. If it's unencrypted, the user must
# have allowed that.
options = self._cache_persistence_options or \
TokenCachePersistenceOptions(allow_unencrypted_storage=True)
self._cache = _load_persistent_cache(options)
except Exception: # pylint:disable=broad-except
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument
self._username = username
self._tenant_id = kwargs.pop("tenant_id", None)
self._cache = kwargs.pop("_cache", None)
self._cache_persistence_options = kwargs.pop("cache_persistence_options", None)
self._client = None # type: Optional[AadClientBase]
self._client_kwargs = kwargs
self._client_kwargs["tenant_id"] = "organizations"
Expand All @@ -116,10 +117,13 @@ def _initialize(self):
def _load_cache(self):
if not self._cache and self.supported():
try:
# This credential accepts the user's default cache regardless of whether it's encrypted. It doesn't
# create a new cache. If the default cache exists, the user must have created it earlier. If it's
# unencrypted, the user must have allowed that.
self._cache = _load_persistent_cache(TokenCachePersistenceOptions(allow_unencrypted_storage=True))
# If no cache options were provided, the default cache will be used. This credential accepts the
# user's default cache regardless of whether it's encrypted. It doesn't create a new cache. If the
# default cache exists, the user must have created it earlier. If it's unencrypted, the user must
# have allowed that.
options = self._cache_persistence_options or \
TokenCachePersistenceOptions(allow_unencrypted_storage=True)
self._cache = _load_persistent_cache(options)
except Exception: # pylint:disable=broad-except
pass

Expand Down
13 changes: 13 additions & 0 deletions sdk/identity/azure-identity/tests/test_shared_cache_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
AzureAuthorityHosts,
CredentialUnavailableError,
SharedTokenCacheCredential,
TokenCachePersistenceOptions,
)
from azure.identity._constants import DEVELOPER_SIGN_ON_CLIENT_ID, EnvironmentVariables
from azure.identity._internal.shared_token_cache import (
Expand Down Expand Up @@ -764,6 +765,18 @@ def test_initialization():
assert mock_cache_loader.call_count == 1


def test_initialization_with_cache_options():
"""the credential should use user-supplied persistence options"""

with patch("azure.identity._internal.shared_token_cache._load_persistent_cache") as mock_cache_loader:
options = TokenCachePersistenceOptions(name="foo.cache")
credential = SharedTokenCacheCredential(cache_persistence_options=options)

with pytest.raises(CredentialUnavailableError):
credential.get_token("scope")
mock_cache_loader.assert_called_once_with(options)


def test_authentication_record_authenticating_tenant():
"""when given a record and 'tenant_id', the credential should authenticate in the latter"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline.policies import SansIOHTTPPolicy
from azure.identity import CredentialUnavailableError
from azure.identity import CredentialUnavailableError, TokenCachePersistenceOptions
from azure.identity.aio import SharedTokenCacheCredential
from azure.identity._constants import EnvironmentVariables
from azure.identity._internal.shared_token_cache import (
Expand Down Expand Up @@ -621,6 +621,19 @@ async def test_initialization():
assert mock_cache_loader.call_count == 1


@pytest.mark.asyncio
async def test_initialization_with_cache_options():
"""the credential should use user-supplied persistence options"""

with patch("azure.identity._internal.shared_token_cache._load_persistent_cache") as mock_cache_loader:
options = TokenCachePersistenceOptions(name="foo.cache")
credential = SharedTokenCacheCredential(cache_persistence_options=options)

with pytest.raises(CredentialUnavailableError):
await credential.get_token("scope")
mock_cache_loader.assert_called_once_with(options)


@pytest.mark.asyncio
async def test_multitenant_authentication():
first_token = "***"
Expand Down

0 comments on commit 41946bb

Please sign in to comment.