|
4 | 4 | import random
|
5 | 5 | import time
|
6 | 6 | from datetime import datetime, timezone
|
7 |
| -from enum import Enum |
8 |
| -from typing import Callable, TypeVar, Union |
| 7 | +from typing import Callable, TypeVar |
9 | 8 | from unittest import mock
|
10 | 9 | from unittest.mock import Mock
|
11 | 10 | from urllib.parse import urlparse
|
|
17 | 16 | from redis import Sentinel
|
18 | 17 | from redis.auth.idp import IdentityProviderInterface
|
19 | 18 | from redis.auth.token import JWToken
|
20 |
| -from redis.auth.token_manager import RetryPolicy, TokenManagerConfig |
21 | 19 | from redis.backoff import NoBackoff
|
22 | 20 | from redis.cache import (
|
23 | 21 | CacheConfig,
|
|
30 | 28 | from redis.credentials import CredentialProvider
|
31 | 29 | from redis.exceptions import RedisClusterException
|
32 | 30 | from redis.retry import Retry
|
33 |
| -from redis_entraid.cred_provider import ( |
34 |
| - DEFAULT_DELAY_IN_MS, |
35 |
| - DEFAULT_EXPIRATION_REFRESH_RATIO, |
36 |
| - DEFAULT_LOWER_REFRESH_BOUND_MILLIS, |
37 |
| - DEFAULT_MAX_ATTEMPTS, |
38 |
| - DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, |
39 |
| - EntraIdCredentialsProvider, |
40 |
| -) |
41 |
| -from redis_entraid.identity_provider import ( |
42 |
| - ManagedIdentityIdType, |
43 |
| - ManagedIdentityProviderConfig, |
44 |
| - ManagedIdentityType, |
45 |
| - ServicePrincipalIdentityProviderConfig, |
46 |
| - _create_provider_from_managed_identity, |
47 |
| - _create_provider_from_service_principal, |
48 |
| -) |
49 | 31 | from tests.ssl_utils import get_tls_certificates
|
50 | 32 |
|
51 | 33 | REDIS_INFO = {}
|
|
61 | 43 | _TestDecorator = Callable[[_DecoratedTest], _DecoratedTest]
|
62 | 44 |
|
63 | 45 |
|
64 |
| -class AuthType(Enum): |
65 |
| - MANAGED_IDENTITY = "managed_identity" |
66 |
| - SERVICE_PRINCIPAL = "service_principal" |
67 |
| - |
68 |
| - |
69 | 46 | # Taken from python3.9
|
70 | 47 | class BooleanOptionalAction(argparse.Action):
|
71 | 48 | def __init__(
|
@@ -623,124 +600,21 @@ def mock_identity_provider() -> IdentityProviderInterface:
|
623 | 600 | return mock_provider
|
624 | 601 |
|
625 | 602 |
|
626 |
| -def identity_provider(request) -> IdentityProviderInterface: |
627 |
| - if hasattr(request, "param"): |
628 |
| - kwargs = request.param.get("idp_kwargs", {}) |
629 |
| - else: |
630 |
| - kwargs = {} |
631 |
| - |
632 |
| - if request.param.get("mock_idp", None) is not None: |
633 |
| - return mock_identity_provider() |
634 |
| - |
635 |
| - auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) |
636 |
| - config = get_identity_provider_config(request=request) |
637 |
| - |
638 |
| - if auth_type == "MANAGED_IDENTITY": |
639 |
| - return _create_provider_from_managed_identity(config) |
640 |
| - |
641 |
| - return _create_provider_from_service_principal(config) |
642 |
| - |
643 |
| - |
644 |
| -def get_identity_provider_config( |
645 |
| - request, |
646 |
| -) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: |
647 |
| - if hasattr(request, "param"): |
648 |
| - kwargs = request.param.get("idp_kwargs", {}) |
649 |
| - else: |
650 |
| - kwargs = {} |
651 |
| - |
652 |
| - auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) |
653 |
| - |
654 |
| - if auth_type == AuthType.MANAGED_IDENTITY: |
655 |
| - return _get_managed_identity_provider_config(request) |
656 |
| - |
657 |
| - return _get_service_principal_provider_config(request) |
658 |
| - |
659 |
| - |
660 |
| -def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig: |
661 |
| - resource = os.getenv("AZURE_RESOURCE") |
662 |
| - id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None) |
663 |
| - |
664 |
| - if hasattr(request, "param"): |
665 |
| - kwargs = request.param.get("idp_kwargs", {}) |
666 |
| - else: |
667 |
| - kwargs = {} |
668 |
| - |
669 |
| - identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED) |
670 |
| - id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID) |
671 |
| - |
672 |
| - return ManagedIdentityProviderConfig( |
673 |
| - identity_type=identity_type, |
674 |
| - resource=resource, |
675 |
| - id_type=id_type, |
676 |
| - id_value=id_value, |
677 |
| - kwargs=kwargs, |
678 |
| - ) |
679 |
| - |
680 |
| - |
681 |
| -def _get_service_principal_provider_config( |
682 |
| - request, |
683 |
| -) -> ServicePrincipalIdentityProviderConfig: |
684 |
| - client_id = os.getenv("AZURE_CLIENT_ID") |
685 |
| - client_credential = os.getenv("AZURE_CLIENT_SECRET") |
686 |
| - tenant_id = os.getenv("AZURE_TENANT_ID") |
687 |
| - scopes = os.getenv("AZURE_REDIS_SCOPES", None) |
688 |
| - |
689 |
| - if hasattr(request, "param"): |
690 |
| - kwargs = request.param.get("idp_kwargs", {}) |
691 |
| - token_kwargs = request.param.get("token_kwargs", {}) |
692 |
| - timeout = request.param.get("timeout", None) |
693 |
| - else: |
694 |
| - kwargs = {} |
695 |
| - token_kwargs = {} |
696 |
| - timeout = None |
697 |
| - |
698 |
| - if isinstance(scopes, str): |
699 |
| - scopes = scopes.split(",") |
700 |
| - |
701 |
| - return ServicePrincipalIdentityProviderConfig( |
702 |
| - client_id=client_id, |
703 |
| - client_credential=client_credential, |
704 |
| - scopes=scopes, |
705 |
| - timeout=timeout, |
706 |
| - token_kwargs=token_kwargs, |
707 |
| - tenant_id=tenant_id, |
708 |
| - app_kwargs=kwargs, |
709 |
| - ) |
710 |
| - |
711 |
| - |
712 | 603 | def get_credential_provider(request) -> CredentialProvider:
|
713 | 604 | cred_provider_class = request.param.get("cred_provider_class")
|
714 | 605 | cred_provider_kwargs = request.param.get("cred_provider_kwargs", {})
|
715 | 606 |
|
716 |
| - if cred_provider_class != EntraIdCredentialsProvider: |
| 607 | + if not cred_provider_class: |
| 608 | + pytest.skip("No credential provider class specified in the test") |
| 609 | + |
| 610 | + # Since we can't import EntraIdCredentialsProvider in this module, |
| 611 | + # we'll just check the class name. |
| 612 | + if cred_provider_class.__name__ != "EntraIdCredentialsProvider": |
717 | 613 | return cred_provider_class(**cred_provider_kwargs)
|
718 | 614 |
|
719 |
| - idp = identity_provider(request) |
720 |
| - expiration_refresh_ratio = cred_provider_kwargs.get( |
721 |
| - "expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO |
722 |
| - ) |
723 |
| - lower_refresh_bound_millis = cred_provider_kwargs.get( |
724 |
| - "lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS |
725 |
| - ) |
726 |
| - max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS) |
727 |
| - delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS) |
728 |
| - |
729 |
| - token_mgr_config = TokenManagerConfig( |
730 |
| - expiration_refresh_ratio=expiration_refresh_ratio, |
731 |
| - lower_refresh_bound_millis=lower_refresh_bound_millis, |
732 |
| - token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa |
733 |
| - retry_policy=RetryPolicy( |
734 |
| - max_attempts=max_attempts, |
735 |
| - delay_in_ms=delay_in_ms, |
736 |
| - ), |
737 |
| - ) |
| 615 | + from tests.entraid_utils import get_entra_id_credentials_provider |
738 | 616 |
|
739 |
| - return EntraIdCredentialsProvider( |
740 |
| - identity_provider=idp, |
741 |
| - token_manager_config=token_mgr_config, |
742 |
| - initial_delay_in_ms=delay_in_ms, |
743 |
| - ) |
| 617 | + return get_entra_id_credentials_provider(request, cred_provider_kwargs) |
744 | 618 |
|
745 | 619 |
|
746 | 620 | @pytest.fixture()
|
|
0 commit comments