55import time
66from datetime import datetime , timezone
77from enum import Enum
8- from typing import Callable , TypeVar
8+ from typing import Callable , TypeVar , Union
99from unittest import mock
1010from unittest .mock import Mock
1111from urllib .parse import urlparse
1717from redis import Sentinel
1818from redis .auth .idp import IdentityProviderInterface
1919from redis .auth .token import JWToken
20+ from redis .auth .token_manager import RetryPolicy , TokenManagerConfig
2021from redis .backoff import NoBackoff
2122from redis .cache import (
2223 CacheConfig ,
2930from redis .credentials import CredentialProvider
3031from redis .exceptions import RedisClusterException
3132from redis .retry import Retry
32- from redis_entraid .cred_provider import EntraIdCredentialsProvider , TokenAuthConfig
33+ from redis_entraid .cred_provider import (
34+ DEFAULT_DELAY_IN_MS ,
35+ DEFAULT_EXPIRATION_REFRESH_RATIO ,
36+ DEFAULT_LOWER_REFRESH_BOUND_MILLIS ,
37+ DEFAULT_MAX_ATTEMPTS ,
38+ DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS ,
39+ EntraIdCredentialsProvider ,
40+ )
3341from redis_entraid .identity_provider import (
3442 ManagedIdentityIdType ,
43+ ManagedIdentityProviderConfig ,
3544 ManagedIdentityType ,
36- create_provider_from_managed_identity ,
37- create_provider_from_service_principal ,
45+ ServicePrincipalIdentityProviderConfig ,
46+ _create_provider_from_managed_identity ,
47+ _create_provider_from_service_principal ,
3848)
3949from tests .ssl_utils import get_tls_certificates
4050
@@ -623,41 +633,58 @@ def identity_provider(request) -> IdentityProviderInterface:
623633 return mock_identity_provider ()
624634
625635 auth_type = kwargs .pop ("auth_type" , AuthType .SERVICE_PRINCIPAL )
636+ config = get_identity_provider_config (request = request )
626637
627638 if auth_type == "MANAGED_IDENTITY" :
628- return _get_managed_identity_provider (request )
639+ return _create_provider_from_managed_identity (config )
640+
641+ return _create_provider_from_service_principal (config )
629642
630- return _get_service_principal_provider (request )
631643
644+ def get_identity_provider_config (
645+ request ,
646+ ) -> Union [ManagedIdentityProviderConfig , ServicePrincipalIdentityProviderConfig ]:
647+ if hasattr (request , "param" ):
648+ kwargs = request .param .get ("idp_kwargs" , {})
649+ else :
650+ kwargs = {}
651+
652+ auth_type = kwargs .pop ("auth_type" , AuthType .SERVICE_PRINCIPAL )
632653
633- def _get_managed_identity_provider (request ):
634- authority = os .getenv ("AZURE_AUTHORITY" )
654+ if auth_type == AuthType .MANAGED_IDENTITY :
655+ return _get_managed_identity_provider_config (request )
656+
657+ return _get_service_principal_provider_config (request )
658+
659+
660+ def _get_managed_identity_provider_config (request ) -> ManagedIdentityProviderConfig :
635661 resource = os .getenv ("AZURE_RESOURCE" )
636- id_value = os .getenv ("AZURE_ID_VALUE " , None )
662+ id_value = os .getenv ("AZURE_USER_ASSIGNED_MANAGED_ID " , None )
637663
638664 if hasattr (request , "param" ):
639665 kwargs = request .param .get ("idp_kwargs" , {})
640666 else :
641667 kwargs = {}
642668
643669 identity_type = kwargs .pop ("identity_type" , ManagedIdentityType .SYSTEM_ASSIGNED )
644- id_type = kwargs .pop ("id_type" , ManagedIdentityIdType .CLIENT_ID )
670+ id_type = kwargs .pop ("id_type" , ManagedIdentityIdType .OBJECT_ID )
645671
646- return create_provider_from_managed_identity (
672+ return ManagedIdentityProviderConfig (
647673 identity_type = identity_type ,
648674 resource = resource ,
649675 id_type = id_type ,
650676 id_value = id_value ,
651- authority = authority ,
652- ** kwargs ,
677+ kwargs = kwargs ,
653678 )
654679
655680
656- def _get_service_principal_provider (request ):
681+ def _get_service_principal_provider_config (
682+ request ,
683+ ) -> ServicePrincipalIdentityProviderConfig :
657684 client_id = os .getenv ("AZURE_CLIENT_ID" )
658685 client_credential = os .getenv ("AZURE_CLIENT_SECRET" )
659- authority = os .getenv ("AZURE_AUTHORITY " )
660- scopes = os .getenv ("AZURE_REDIS_SCOPES" , [] )
686+ tenant_id = os .getenv ("AZURE_TENANT_ID " )
687+ scopes = os .getenv ("AZURE_REDIS_SCOPES" , None )
661688
662689 if hasattr (request , "param" ):
663690 kwargs = request .param .get ("idp_kwargs" , {})
@@ -671,14 +698,14 @@ def _get_service_principal_provider(request):
671698 if isinstance (scopes , str ):
672699 scopes = scopes .split ("," )
673700
674- return create_provider_from_service_principal (
701+ return ServicePrincipalIdentityProviderConfig (
675702 client_id = client_id ,
676703 client_credential = client_credential ,
677704 scopes = scopes ,
678705 timeout = timeout ,
679706 token_kwargs = token_kwargs ,
680- authority = authority ,
681- ** kwargs ,
707+ tenant_id = tenant_id ,
708+ app_kwargs = kwargs ,
682709 )
683710
684711
@@ -690,31 +717,29 @@ def get_credential_provider(request) -> CredentialProvider:
690717 return cred_provider_class (** cred_provider_kwargs )
691718
692719 idp = identity_provider (request )
693- initial_delay_in_ms = cred_provider_kwargs .get ("initial_delay_in_ms" , 0 )
694- block_for_initial = cred_provider_kwargs .get ("block_for_initial" , False )
695720 expiration_refresh_ratio = cred_provider_kwargs .get (
696- "expiration_refresh_ratio" , TokenAuthConfig . DEFAULT_EXPIRATION_REFRESH_RATIO
721+ "expiration_refresh_ratio" , DEFAULT_EXPIRATION_REFRESH_RATIO
697722 )
698723 lower_refresh_bound_millis = cred_provider_kwargs .get (
699- "lower_refresh_bound_millis" , TokenAuthConfig .DEFAULT_LOWER_REFRESH_BOUND_MILLIS
700- )
701- max_attempts = cred_provider_kwargs .get (
702- "max_attempts" , TokenAuthConfig .DEFAULT_MAX_ATTEMPTS
724+ "lower_refresh_bound_millis" , DEFAULT_LOWER_REFRESH_BOUND_MILLIS
703725 )
704- delay_in_ms = cred_provider_kwargs .get (
705- "delay_in_ms" , TokenAuthConfig .DEFAULT_DELAY_IN_MS
726+ max_attempts = cred_provider_kwargs .get ("max_attempts" , DEFAULT_MAX_ATTEMPTS )
727+ delay_in_ms = cred_provider_kwargs .get ("delay_in_ms" , DEFAULT_DELAY_IN_MS )
728+
729+ token_mgr_config = TokenManagerConfig (
730+ expiration_refresh_ratio = expiration_refresh_ratio ,
731+ lower_refresh_bound_millis = lower_refresh_bound_millis ,
732+ token_request_execution_timeout_in_ms = DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS , # noqa
733+ retry_policy = RetryPolicy (
734+ max_attempts = max_attempts ,
735+ delay_in_ms = delay_in_ms ,
736+ ),
706737 )
707738
708- auth_config = TokenAuthConfig (idp )
709- auth_config .expiration_refresh_ratio = expiration_refresh_ratio
710- auth_config .lower_refresh_bound_millis = lower_refresh_bound_millis
711- auth_config .max_attempts = max_attempts
712- auth_config .delay_in_ms = delay_in_ms
713-
714739 return EntraIdCredentialsProvider (
715- config = auth_config ,
716- initial_delay_in_ms = initial_delay_in_ms ,
717- block_for_initial = block_for_initial ,
740+ identity_provider = idp ,
741+ token_manager_config = token_mgr_config ,
742+ initial_delay_in_ms = delay_in_ms ,
718743 )
719744
720745
0 commit comments