@@ -19,46 +19,102 @@ internal class ManagedIdentityClient
1919 {
2020 private const string WindowsHimdsFilePath = "%Programfiles%\\ AzureConnectedMachineAgent\\ himds.exe" ;
2121 private const string LinuxHimdsFilePath = "/opt/azcmagent/bin/himds" ;
22- private static RequestContext _requestContext ;
23- private static AbstractManagedIdentity _identitySource ;
22+ private AbstractManagedIdentity s_identitySource ;
23+ public static ManagedIdentitySource s_sourceName = ManagedIdentitySource . None ;
2424
25- public ManagedIdentityClient ( RequestContext requestContext )
25+ internal async Task < ManagedIdentityResponse > SendTokenRequestForManagedIdentityAsync (
26+ RequestContext requestContext ,
27+ AcquireTokenForManagedIdentityParameters parameters ,
28+ CancellationToken cancellationToken )
2629 {
27- _requestContext = requestContext ;
28- }
29-
30- internal async Task < ManagedIdentityResponse > SendTokenRequestForManagedIdentityAsync ( AcquireTokenForManagedIdentityParameters parameters , CancellationToken cancellationToken )
31- {
32- if ( _identitySource == null )
30+ if ( s_identitySource == null )
3331 {
34- using ( _requestContext . Logger . LogMethodDuration ( ) )
32+ using ( requestContext . Logger . LogMethodDuration ( ) )
3533 {
36- _identitySource = await SelectManagedIdentitySourceAsync ( ) . ConfigureAwait ( false ) ;
34+ s_identitySource = await SelectManagedIdentitySourceAsync ( requestContext ) . ConfigureAwait ( false ) ;
3735 }
3836 }
3937
40- return await _identitySource . AuthenticateAsync ( parameters , cancellationToken ) . ConfigureAwait ( false ) ;
38+ return await s_identitySource . AuthenticateAsync ( parameters , cancellationToken ) . ConfigureAwait ( false ) ;
4139 }
4240
4341 // This method tries to create managed identity source for different sources, if none is created then defaults to IMDS.
44- private static async Task < AbstractManagedIdentity > SelectManagedIdentitySourceAsync ( )
42+ private static async Task < AbstractManagedIdentity > SelectManagedIdentitySourceAsync ( RequestContext requestContext )
4543 {
46- return await GetManagedIdentitySourceAsync ( _requestContext . Logger ) . ConfigureAwait ( false ) switch
47- {
48- ManagedIdentitySource . ServiceFabric => ServiceFabricManagedIdentitySource . Create ( _requestContext ) ,
49- ManagedIdentitySource . AppService => AppServiceManagedIdentitySource . Create ( _requestContext ) ,
50- ManagedIdentitySource . MachineLearning => MachineLearningManagedIdentitySource . Create ( _requestContext ) ,
51- ManagedIdentitySource . CloudShell => CloudShellManagedIdentitySource . Create ( _requestContext ) ,
52- ManagedIdentitySource . AzureArc => AzureArcManagedIdentitySource . Create ( _requestContext ) ,
53- ManagedIdentitySource . ImdsV2 => ImdsV2ManagedIdentitySource . Create ( _requestContext ) ,
54- _ => new ImdsManagedIdentitySource ( _requestContext )
44+ var source = ( s_sourceName != ManagedIdentitySource . None ) ? s_sourceName : await GetManagedIdentitySourceAsync ( requestContext ) . ConfigureAwait ( false ) ;
45+ return source switch
46+ {
47+ ManagedIdentitySource . ServiceFabric => ServiceFabricManagedIdentitySource . Create ( requestContext ) ,
48+ ManagedIdentitySource . AppService => AppServiceManagedIdentitySource . Create ( requestContext ) ,
49+ ManagedIdentitySource . MachineLearning => MachineLearningManagedIdentitySource . Create ( requestContext ) ,
50+ ManagedIdentitySource . CloudShell => CloudShellManagedIdentitySource . Create ( requestContext ) ,
51+ ManagedIdentitySource . AzureArc => AzureArcManagedIdentitySource . Create ( requestContext ) ,
52+ ManagedIdentitySource . ImdsV2 => ImdsV2ManagedIdentitySource . Create ( requestContext ) ,
53+ _ => new ImdsManagedIdentitySource ( requestContext )
5554 } ;
5655 }
5756
57+ // Detect managed identity source based on the availability of environment variables and csr metadata probe request.
58+ // This method is perf sensitive any changes should be benchmarked.
59+ internal static async Task < ManagedIdentitySource > GetManagedIdentitySourceAsync ( RequestContext requestContext )
60+ {
61+ string identityEndpoint = EnvironmentVariables . IdentityEndpoint ;
62+ string identityHeader = EnvironmentVariables . IdentityHeader ;
63+ string identityServerThumbprint = EnvironmentVariables . IdentityServerThumbprint ;
64+ string msiSecret = EnvironmentVariables . IdentityHeader ;
65+ string msiEndpoint = EnvironmentVariables . MsiEndpoint ;
66+ string msiSecretMachineLearning = EnvironmentVariables . MsiSecret ;
67+ string imdsEndpoint = EnvironmentVariables . ImdsEndpoint ;
68+
69+ var logger = requestContext ? . ServiceBundle ? . ApplicationLogger ;
70+ logger ? . Info ( "[Managed Identity] Detecting managed identity source..." ) ;
71+
72+ if ( ! string . IsNullOrEmpty ( identityEndpoint ) && ! string . IsNullOrEmpty ( identityHeader ) )
73+ {
74+ if ( ! string . IsNullOrEmpty ( identityServerThumbprint ) )
75+ {
76+ logger ? . Info ( "[Managed Identity] Service Fabric detected." ) ;
77+ s_sourceName = ManagedIdentitySource . ServiceFabric ;
78+ }
79+ else
80+ {
81+ logger ? . Info ( "[Managed Identity] App Service detected." ) ;
82+ s_sourceName = ManagedIdentitySource . AppService ;
83+ }
84+ }
85+ else if ( ! string . IsNullOrEmpty ( msiSecretMachineLearning ) && ! string . IsNullOrEmpty ( msiEndpoint ) )
86+ {
87+ logger ? . Info ( "[Managed Identity] Machine Learning detected." ) ;
88+ s_sourceName = ManagedIdentitySource . MachineLearning ;
89+ }
90+ else if ( ! string . IsNullOrEmpty ( msiEndpoint ) )
91+ {
92+ logger ? . Info ( "[Managed Identity] Cloud Shell detected." ) ;
93+ s_sourceName = ManagedIdentitySource . CloudShell ;
94+ }
95+ else if ( ValidateAzureArcEnvironment ( identityEndpoint , imdsEndpoint , logger ) )
96+ {
97+ logger ? . Info ( "[Managed Identity] Azure Arc detected." ) ;
98+ s_sourceName = ManagedIdentitySource . AzureArc ;
99+ }
100+ else if ( await ImdsV2ManagedIdentitySource . GetCsrMetadataAsync ( requestContext ) . ConfigureAwait ( false ) )
101+ {
102+ logger ? . Info ( "[Managed Identity] ImdsV2 detected." ) ;
103+ s_sourceName = ManagedIdentitySource . ImdsV2 ;
104+ }
105+ else
106+ {
107+ s_sourceName = ManagedIdentitySource . DefaultToImds ;
108+ }
109+
110+ return s_sourceName ;
111+ }
112+
58113 // Detect managed identity source based on the availability of environment variables.
59114 // The result of this method is not cached because reading environment variables is cheap.
60115 // This method is perf sensitive any changes should be benchmarked.
61- internal static async Task < ManagedIdentitySource > GetManagedIdentitySourceAsync ( ILoggerAdapter logger = null )
116+ [ Obsolete ( "Use GetManagedIdentitySourceAsync(RequestContext) instead." ) ]
117+ internal static ManagedIdentitySource GetManagedIdentitySource ( ILoggerAdapter logger = null )
62118 {
63119 string identityEndpoint = EnvironmentVariables . IdentityEndpoint ;
64120 string identityHeader = EnvironmentVariables . IdentityHeader ;
@@ -98,11 +154,6 @@ internal static async Task<ManagedIdentitySource> GetManagedIdentitySourceAsync(
98154 logger ? . Info ( "[Managed Identity] Azure Arc detected." ) ;
99155 return ManagedIdentitySource . AzureArc ;
100156 }
101- else if ( await ImdsV2ManagedIdentitySource . GetCsrMetadataAsync ( _requestContext ) . ConfigureAwait ( false ) )
102- {
103- logger ? . Info ( "[Managed Identity] ImdsV2 detected." ) ;
104- return ManagedIdentitySource . ImdsV2 ;
105- }
106157 else
107158 {
108159 return ManagedIdentitySource . DefaultToImds ;
0 commit comments