Skip to content

Commit ec3f003

Browse files
committed
Isolate redis-entraid dependency for tests
1 parent 4418907 commit ec3f003

File tree

5 files changed

+160
-290
lines changed

5 files changed

+160
-290
lines changed

tests/conftest.py

+9-135
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
import random
55
import time
66
from datetime import datetime, timezone
7-
from enum import Enum
8-
from typing import Callable, TypeVar, Union
7+
from typing import Callable, TypeVar
98
from unittest import mock
109
from unittest.mock import Mock
1110
from urllib.parse import urlparse
@@ -17,7 +16,6 @@
1716
from redis import Sentinel
1817
from redis.auth.idp import IdentityProviderInterface
1918
from redis.auth.token import JWToken
20-
from redis.auth.token_manager import RetryPolicy, TokenManagerConfig
2119
from redis.backoff import NoBackoff
2220
from redis.cache import (
2321
CacheConfig,
@@ -30,22 +28,6 @@
3028
from redis.credentials import CredentialProvider
3129
from redis.exceptions import RedisClusterException
3230
from redis.retry import Retry
33-
from redis_entraid.cred_provider import (
34-
DEFAULT_DELAY_IN_MS,
35-
DEFAULT_EXPIRATION_REFRESH_RATIO,
36-
DEFAULT_LOWER_REFRESH_BOUND_MILLIS,
37-
DEFAULT_MAX_ATTEMPTS,
38-
DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS,
39-
EntraIdCredentialsProvider,
40-
)
41-
from redis_entraid.identity_provider import (
42-
ManagedIdentityIdType,
43-
ManagedIdentityProviderConfig,
44-
ManagedIdentityType,
45-
ServicePrincipalIdentityProviderConfig,
46-
_create_provider_from_managed_identity,
47-
_create_provider_from_service_principal,
48-
)
4931
from tests.ssl_utils import get_tls_certificates
5032

5133
REDIS_INFO = {}
@@ -61,11 +43,6 @@
6143
_TestDecorator = Callable[[_DecoratedTest], _DecoratedTest]
6244

6345

64-
class AuthType(Enum):
65-
MANAGED_IDENTITY = "managed_identity"
66-
SERVICE_PRINCIPAL = "service_principal"
67-
68-
6946
# Taken from python3.9
7047
class BooleanOptionalAction(argparse.Action):
7148
def __init__(
@@ -623,124 +600,21 @@ def mock_identity_provider() -> IdentityProviderInterface:
623600
return mock_provider
624601

625602

626-
def identity_provider(request) -> IdentityProviderInterface:
627-
if hasattr(request, "param"):
628-
kwargs = request.param.get("idp_kwargs", {})
629-
else:
630-
kwargs = {}
631-
632-
if request.param.get("mock_idp", None) is not None:
633-
return mock_identity_provider()
634-
635-
auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
636-
config = get_identity_provider_config(request=request)
637-
638-
if auth_type == "MANAGED_IDENTITY":
639-
return _create_provider_from_managed_identity(config)
640-
641-
return _create_provider_from_service_principal(config)
642-
643-
644-
def get_identity_provider_config(
645-
request,
646-
) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]:
647-
if hasattr(request, "param"):
648-
kwargs = request.param.get("idp_kwargs", {})
649-
else:
650-
kwargs = {}
651-
652-
auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
653-
654-
if auth_type == AuthType.MANAGED_IDENTITY:
655-
return _get_managed_identity_provider_config(request)
656-
657-
return _get_service_principal_provider_config(request)
658-
659-
660-
def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig:
661-
resource = os.getenv("AZURE_RESOURCE")
662-
id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None)
663-
664-
if hasattr(request, "param"):
665-
kwargs = request.param.get("idp_kwargs", {})
666-
else:
667-
kwargs = {}
668-
669-
identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED)
670-
id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID)
671-
672-
return ManagedIdentityProviderConfig(
673-
identity_type=identity_type,
674-
resource=resource,
675-
id_type=id_type,
676-
id_value=id_value,
677-
kwargs=kwargs,
678-
)
679-
680-
681-
def _get_service_principal_provider_config(
682-
request,
683-
) -> ServicePrincipalIdentityProviderConfig:
684-
client_id = os.getenv("AZURE_CLIENT_ID")
685-
client_credential = os.getenv("AZURE_CLIENT_SECRET")
686-
tenant_id = os.getenv("AZURE_TENANT_ID")
687-
scopes = os.getenv("AZURE_REDIS_SCOPES", None)
688-
689-
if hasattr(request, "param"):
690-
kwargs = request.param.get("idp_kwargs", {})
691-
token_kwargs = request.param.get("token_kwargs", {})
692-
timeout = request.param.get("timeout", None)
693-
else:
694-
kwargs = {}
695-
token_kwargs = {}
696-
timeout = None
697-
698-
if isinstance(scopes, str):
699-
scopes = scopes.split(",")
700-
701-
return ServicePrincipalIdentityProviderConfig(
702-
client_id=client_id,
703-
client_credential=client_credential,
704-
scopes=scopes,
705-
timeout=timeout,
706-
token_kwargs=token_kwargs,
707-
tenant_id=tenant_id,
708-
app_kwargs=kwargs,
709-
)
710-
711-
712603
def get_credential_provider(request) -> CredentialProvider:
713604
cred_provider_class = request.param.get("cred_provider_class")
714605
cred_provider_kwargs = request.param.get("cred_provider_kwargs", {})
715606

716-
if cred_provider_class != EntraIdCredentialsProvider:
607+
if not cred_provider_class:
608+
pytest.skip("No credential provider class specified in the test")
609+
610+
# Since we can't import EntraIdCredentialsProvider in this module,
611+
# we'll just check the class name.
612+
if cred_provider_class.__name__ != "EntraIdCredentialsProvider":
717613
return cred_provider_class(**cred_provider_kwargs)
718614

719-
idp = identity_provider(request)
720-
expiration_refresh_ratio = cred_provider_kwargs.get(
721-
"expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO
722-
)
723-
lower_refresh_bound_millis = cred_provider_kwargs.get(
724-
"lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS
725-
)
726-
max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS)
727-
delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS)
728-
729-
token_mgr_config = TokenManagerConfig(
730-
expiration_refresh_ratio=expiration_refresh_ratio,
731-
lower_refresh_bound_millis=lower_refresh_bound_millis,
732-
token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa
733-
retry_policy=RetryPolicy(
734-
max_attempts=max_attempts,
735-
delay_in_ms=delay_in_ms,
736-
),
737-
)
615+
from tests.entraid_utils import get_entra_id_credentials_provider
738616

739-
return EntraIdCredentialsProvider(
740-
identity_provider=idp,
741-
token_manager_config=token_mgr_config,
742-
initial_delay_in_ms=delay_in_ms,
743-
)
617+
return get_entra_id_credentials_provider(request, cred_provider_kwargs)
744618

745619

746620
@pytest.fixture()

tests/entraid_utils.py

+140
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import os
2+
from enum import Enum
3+
from typing import Union
4+
5+
from redis.auth.idp import IdentityProviderInterface
6+
from redis.auth.token_manager import RetryPolicy, TokenManagerConfig
7+
from redis_entraid.cred_provider import (
8+
DEFAULT_DELAY_IN_MS,
9+
DEFAULT_EXPIRATION_REFRESH_RATIO,
10+
DEFAULT_LOWER_REFRESH_BOUND_MILLIS,
11+
DEFAULT_MAX_ATTEMPTS,
12+
DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS,
13+
EntraIdCredentialsProvider,
14+
)
15+
from redis_entraid.identity_provider import (
16+
ManagedIdentityIdType,
17+
ManagedIdentityProviderConfig,
18+
ManagedIdentityType,
19+
ServicePrincipalIdentityProviderConfig,
20+
_create_provider_from_managed_identity,
21+
_create_provider_from_service_principal,
22+
)
23+
from tests.conftest import mock_identity_provider
24+
25+
26+
class AuthType(Enum):
27+
MANAGED_IDENTITY = "managed_identity"
28+
SERVICE_PRINCIPAL = "service_principal"
29+
30+
31+
def identity_provider(request) -> IdentityProviderInterface:
32+
if hasattr(request, "param"):
33+
kwargs = request.param.get("idp_kwargs", {})
34+
else:
35+
kwargs = {}
36+
37+
if request.param.get("mock_idp", None) is not None:
38+
return mock_identity_provider()
39+
40+
auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
41+
config = get_identity_provider_config(request=request)
42+
43+
if auth_type == "MANAGED_IDENTITY":
44+
return _create_provider_from_managed_identity(config)
45+
46+
return _create_provider_from_service_principal(config)
47+
48+
49+
def get_identity_provider_config(
50+
request,
51+
) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]:
52+
if hasattr(request, "param"):
53+
kwargs = request.param.get("idp_kwargs", {})
54+
else:
55+
kwargs = {}
56+
57+
auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
58+
59+
if auth_type == AuthType.MANAGED_IDENTITY:
60+
return _get_managed_identity_provider_config(request)
61+
62+
return _get_service_principal_provider_config(request)
63+
64+
65+
def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig:
66+
resource = os.getenv("AZURE_RESOURCE")
67+
id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None)
68+
69+
if hasattr(request, "param"):
70+
kwargs = request.param.get("idp_kwargs", {})
71+
else:
72+
kwargs = {}
73+
74+
identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED)
75+
id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID)
76+
77+
return ManagedIdentityProviderConfig(
78+
identity_type=identity_type,
79+
resource=resource,
80+
id_type=id_type,
81+
id_value=id_value,
82+
kwargs=kwargs,
83+
)
84+
85+
86+
def _get_service_principal_provider_config(
87+
request,
88+
) -> ServicePrincipalIdentityProviderConfig:
89+
client_id = os.getenv("AZURE_CLIENT_ID")
90+
client_credential = os.getenv("AZURE_CLIENT_SECRET")
91+
tenant_id = os.getenv("AZURE_TENANT_ID")
92+
scopes = os.getenv("AZURE_REDIS_SCOPES", None)
93+
94+
if hasattr(request, "param"):
95+
kwargs = request.param.get("idp_kwargs", {})
96+
token_kwargs = request.param.get("token_kwargs", {})
97+
timeout = request.param.get("timeout", None)
98+
else:
99+
kwargs = {}
100+
token_kwargs = {}
101+
timeout = None
102+
103+
if isinstance(scopes, str):
104+
scopes = scopes.split(",")
105+
106+
return ServicePrincipalIdentityProviderConfig(
107+
client_id=client_id,
108+
client_credential=client_credential,
109+
scopes=scopes,
110+
timeout=timeout,
111+
token_kwargs=token_kwargs,
112+
tenant_id=tenant_id,
113+
app_kwargs=kwargs,
114+
)
115+
116+
117+
def get_entra_id_credentials_provider(request, cred_provider_kwargs):
118+
idp = identity_provider(request)
119+
expiration_refresh_ratio = cred_provider_kwargs.get(
120+
"expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO
121+
)
122+
lower_refresh_bound_millis = cred_provider_kwargs.get(
123+
"lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS
124+
)
125+
max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS)
126+
delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS)
127+
token_mgr_config = TokenManagerConfig(
128+
expiration_refresh_ratio=expiration_refresh_ratio,
129+
lower_refresh_bound_millis=lower_refresh_bound_millis,
130+
token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa
131+
retry_policy=RetryPolicy(
132+
max_attempts=max_attempts,
133+
delay_in_ms=delay_in_ms,
134+
),
135+
)
136+
return EntraIdCredentialsProvider(
137+
identity_provider=idp,
138+
token_manager_config=token_mgr_config,
139+
initial_delay_in_ms=delay_in_ms,
140+
)

0 commit comments

Comments
 (0)