Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Import msal_extensions only when needed #20095

Merged
merged 4 commits into from
Aug 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
logging. On Python 3.7+, credentials invoked by these classes now log debug
rather than info messages.
([#18972](https://github.com/Azure/azure-sdk-for-python/issues/18972))
- Persistent cache implementations are now loaded on demand, enabling
workarounds when importing transitive dependencies such as pywin32
fails
([#19989](https://github.com/Azure/azure-sdk-for-python/issues/19989))

## 1.7.0b2 (2021-07-08)
### Features Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import sys
from typing import TYPE_CHECKING

import msal_extensions

if TYPE_CHECKING:
from typing import Any
import msal_extensions


class TokenCachePersistenceOptions(object):
Expand Down Expand Up @@ -49,8 +48,10 @@ def __init__(self, **kwargs):

def _load_persistent_cache(options):
# type: (TokenCachePersistenceOptions) -> msal_extensions.PersistedTokenCache
import msal_extensions

persistence = _get_persistence(
allow_unencrypted=options.allow_unencrypted_storage, account_name="MSALCache", cache_name=options.name,
allow_unencrypted=options.allow_unencrypted_storage, account_name="MSALCache", cache_name=options.name
)
return msal_extensions.PersistedTokenCache(persistence)

Expand All @@ -66,6 +67,7 @@ def _get_persistence(allow_unencrypted, account_name, cache_name):
:param bool allow_unencrypted: when True, the cache will be kept in plaintext should encryption be impossible in the
current environment
"""
import msal_extensions

if sys.platform.startswith("win") and "LOCALAPPDATA" in os.environ:
cache_location = os.path.join(os.environ["LOCALAPPDATA"], ".IdentityService", cache_name)
Expand All @@ -87,8 +89,8 @@ def _get_persistence(allow_unencrypted, account_name, cache_name):
except ImportError:
if not allow_unencrypted:
raise ValueError(
"PyGObject is required to encrypt the persistent cache. Please install that library or ",
"specify 'allow_unencrypted_cache=True' to store the cache without encryption.",
"PyGObject is required to encrypt the persistent cache. Please install that library or "
+ 'specify "allow_unencrypted_cache=True" to store the cache without encryption.'
)
return msal_extensions.FilePersistence(file_path)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,9 @@ def validate_jwt(request, client_id, pem_bytes, expect_x5c=False):
def test_token_cache(cert_path, cert_password):
"""the credential should optionally use a persistent cache, and default to an in memory cache"""

with patch("azure.identity._persistent_cache.msal_extensions") as mock_msal_extensions:
with patch("azure.identity._internal.msal_credentials._load_persistent_cache") as load_persistent_cache:
credential = CertificateCredential("tenant", "client-id", cert_path, password=cert_password)
assert not mock_msal_extensions.PersistedTokenCache.called
assert not load_persistent_cache.called
assert isinstance(credential._cache, TokenCache)

CertificateCredential(
Expand All @@ -284,7 +284,7 @@ def test_token_cache(cert_path, cert_password):
password=cert_password,
cache_persistence_options=TokenCachePersistenceOptions(),
)
assert mock_msal_extensions.PersistedTokenCache.call_count == 1
assert load_persistent_cache.call_count == 1


@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,11 @@ async def mock_send(request, **kwargs):
def test_token_cache(cert_path, cert_password):
"""the credential should optionally use a persistent cache, and default to an in memory cache"""

with patch("azure.identity._persistent_cache.msal_extensions") as mock_msal_extensions:
with patch(CertificateCredential.__module__ + "._load_persistent_cache") as load_persistent_cache:
with patch(CertificateCredential.__module__ + ".msal") as mock_msal:
CertificateCredential("tenant", "client-id", cert_path, password=cert_password)
assert mock_msal.TokenCache.call_count == 1
assert not mock_msal_extensions.PersistedTokenCache.called
assert not load_persistent_cache.called

CertificateCredential(
"tenant",
Expand All @@ -205,7 +205,7 @@ def test_token_cache(cert_path, cert_password):
password=cert_password,
cache_persistence_options=TokenCachePersistenceOptions(),
)
assert mock_msal_extensions.PersistedTokenCache.call_count == 1
assert load_persistent_cache.call_count == 1


@pytest.mark.asyncio
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,15 @@ def test_regional_authority():
def test_token_cache():
"""the credential should default to an in memory cache, and optionally use a persistent cache"""

with patch("azure.identity._persistent_cache.msal_extensions") as mock_msal_extensions:
with patch("azure.identity._internal.msal_credentials._load_persistent_cache") as load_persistent_cache:
credential = ClientSecretCredential("tenant", "client-id", "secret")
assert not mock_msal_extensions.PersistedTokenCache.called
assert not load_persistent_cache.called
assert isinstance(credential._cache, TokenCache)

ClientSecretCredential(
"tenant", "client-id", "secret", cache_persistence_options=TokenCachePersistenceOptions()
)
assert mock_msal_extensions.PersistedTokenCache.call_count == 1
assert load_persistent_cache.call_count == 1


def test_cache_multiple_clients():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,16 +190,16 @@ async def test_cache():
def test_token_cache():
"""the credential should default to an in memory cache, and optionally use a persistent cache"""

with patch("azure.identity._persistent_cache.msal_extensions") as mock_msal_extensions:
with patch(ClientSecretCredential.__module__ + "._load_persistent_cache") as load_persistent_cache:
with patch(ClientSecretCredential.__module__ + ".msal") as mock_msal:
ClientSecretCredential("tenant", "client-id", "secret")
assert mock_msal.TokenCache.call_count == 1
assert not mock_msal_extensions.PersistedTokenCache.called
assert not load_persistent_cache.called

ClientSecretCredential(
"tenant", "client-id", "secret", cache_persistence_options=TokenCachePersistenceOptions()
)
assert mock_msal_extensions.PersistedTokenCache.call_count == 1
assert load_persistent_cache.call_count == 1


@pytest.mark.asyncio
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,14 @@ def __init__(self, **kwargs):
def _request_token(self, *_, **__):
pass

with patch("azure.identity._persistent_cache.msal_extensions") as mock_msal_extensions:
with patch("azure.identity._internal.msal_credentials._load_persistent_cache") as load_persistent_cache:
with patch("azure.identity._internal.msal_credentials.msal") as mock_msal:
TestCredential()
assert not mock_msal_extensions.PersistedTokenCache.called
assert not load_persistent_cache.called
assert mock_msal.TokenCache.call_count == 1

TestCredential(cache_persistence_options=TokenCachePersistenceOptions())
assert mock_msal_extensions.PersistedTokenCache.call_count == 1
assert load_persistent_cache.call_count == 1


def test_home_account_id_client_info():
Expand Down
17 changes: 10 additions & 7 deletions sdk/identity/azure-identity/tests/test_persistent_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
# ------------------------------------
from azure.identity import InteractiveBrowserCredential, TokenCachePersistenceOptions
import pytest
import msal_extensions

from helpers import mock


def test_token_cache_persistence_options():
with mock.patch("azure.identity._persistent_cache.msal_extensions"):
with mock.patch("azure.identity._internal.msal_credentials._load_persistent_cache"):
# [START snippet]
cache_options = TokenCachePersistenceOptions()
credential = InteractiveBrowserCredential(cache_persistence_options=cache_options)
Expand All @@ -23,25 +24,27 @@ def test_token_cache_persistence_options():


@mock.patch("azure.identity._persistent_cache.sys.platform", "linux2")
@mock.patch("azure.identity._persistent_cache.msal_extensions")
def test_persistent_cache_linux(mock_extensions):
def test_persistent_cache_linux(monkeypatch):
"""Credentials should use an unencrypted cache when encryption is unavailable and the user explicitly opts in.

This test was written when Linux was the only platform on which encryption may not be available.
"""
from azure.identity._persistent_cache import _load_persistent_cache

for cls in ("FilePersistence", "LibsecretPersistence", "PersistedTokenCache"):
monkeypatch.setattr(msal_extensions, cls, mock.Mock())

_load_persistent_cache(TokenCachePersistenceOptions())
assert mock_extensions.PersistedTokenCache.called_with(mock_extensions.LibsecretPersistence)
mock_extensions.PersistedTokenCache.reset_mock()
assert msal_extensions.PersistedTokenCache.called_with(msal_extensions.LibsecretPersistence)
msal_extensions.PersistedTokenCache.reset_mock()

# when LibsecretPersistence's dependencies aren't available, constructing it raises ImportError
mock_extensions.LibsecretPersistence = mock.Mock(side_effect=ImportError)
msal_extensions.LibsecretPersistence = mock.Mock(side_effect=ImportError)

# encryption unavailable, no unencrypted storage not allowed
with pytest.raises(ValueError):
_load_persistent_cache(TokenCachePersistenceOptions())

# encryption unavailable, unencrypted storage allowed
_load_persistent_cache(TokenCachePersistenceOptions(allow_unencrypted_storage=True))
mock_extensions.PersistedTokenCache.called_with(mock_extensions.FilePersistence)
msal_extensions.PersistedTokenCache.called_with(msal_extensions.FilePersistence)