77using Azure . Security . KeyVault . Keys . Cryptography ;
88using System ;
99using System . Collections . Concurrent ;
10- using System . Threading . Tasks ;
10+ using System . Threading ;
1111using static Azure . Security . KeyVault . Keys . Cryptography . SignatureAlgorithm ;
1212
1313namespace Microsoft . Data . SqlClient . AlwaysEncrypted . AzureKeyVaultProvider
1414{
15- internal class AzureSqlKeyCryptographer
15+ internal sealed class AzureSqlKeyCryptographer : IDisposable
1616 {
1717 /// <summary>
1818 /// TokenCredential to be used with the KeyClient
@@ -25,16 +25,14 @@ internal class AzureSqlKeyCryptographer
2525 private readonly ConcurrentDictionary < Uri , KeyClient > _keyClientDictionary = new ( ) ;
2626
2727 /// <summary>
28- /// Holds references to the fetch key tasks and maps them to their corresponding Azure Key Vault Key Identifier (URI).
29- /// These tasks will be used for returning the key in the event that the fetch task has not finished depositing the
30- /// key into the key dictionary.
28+ /// Holds references to the Azure Key Vault keys and maps them to their corresponding Azure Key Vault Key Identifier (URI).
3129 /// </summary>
32- private readonly ConcurrentDictionary < string , Task < Azure . Response < KeyVaultKey > > > _keyFetchTaskDictionary = new ( ) ;
30+ private readonly ConcurrentDictionary < string , KeyVaultKey > _keyDictionary = new ( ) ;
3331
3432 /// <summary>
35- /// Holds references to the Azure Key Vault keys and maps them to their corresponding Azure Key Vault Key Identifier (URI) .
33+ /// SemaphoreSlim to ensure thread safety when accessing the key dictionary or making network calls to Azure Key Vault to fetch keys .
3634 /// </summary>
37- private readonly ConcurrentDictionary < string , KeyVaultKey > _keyDictionary = new ( ) ;
35+ private SemaphoreSlim _keyDictionarySemaphore = new ( 1 , 1 ) ;
3836
3937 /// <summary>
4038 /// Holds references to the Azure Key Vault CryptographyClient objects and maps them to their corresponding Azure Key Vault Key Identifier (URI).
@@ -50,20 +48,44 @@ internal AzureSqlKeyCryptographer(TokenCredential tokenCredential)
5048 TokenCredential = tokenCredential ;
5149 }
5250
51+ /// <summary>
52+ /// Disposes the SemaphoreSlim used for thread safety.
53+ /// </summary>
54+ public void Dispose ( )
55+ {
56+ _keyDictionarySemaphore . Dispose ( ) ;
57+ }
58+
5359 /// <summary>
5460 /// Adds the key, specified by the Key Identifier URI, to the cache.
61+ /// Validates the key type and fetches the key from Azure Key Vault if it is not already cached.
5562 /// </summary>
5663 /// <param name="keyIdentifierUri"></param>
5764 internal void AddKey ( string keyIdentifierUri )
5865 {
59- if ( TheKeyHasNotBeenCached ( keyIdentifierUri ) )
66+ // Allow only one thread to proceed to ensure thread safety
67+ // as we will need to fetch key information from Azure Key Vault if the key is not found in cache.
68+ _keyDictionarySemaphore . Wait ( ) ;
69+
70+ try
6071 {
61- ParseAKVPath ( keyIdentifierUri , out Uri vaultUri , out string keyName , out string keyVersion ) ;
62- CreateKeyClient ( vaultUri ) ;
63- FetchKey ( vaultUri , keyName , keyVersion , keyIdentifierUri ) ;
64- }
72+ if ( ! _keyDictionary . ContainsKey ( keyIdentifierUri ) )
73+ {
74+ ParseAKVPath ( keyIdentifierUri , out Uri vaultUri , out string keyName , out string keyVersion ) ;
75+
76+ // Fetch the KeyClient for the Key vault URI.
77+ KeyClient keyClient = GetOrCreateKeyClient ( vaultUri ) ;
78+
79+ // Fetch the key from Azure Key Vault.
80+ KeyVaultKey key = FetchKeyFromKeyVault ( keyClient , keyName , keyVersion ) ;
6581
66- bool TheKeyHasNotBeenCached ( string k ) => ! _keyDictionary . ContainsKey ( k ) && ! _keyFetchTaskDictionary . ContainsKey ( k ) ;
82+ _keyDictionary . AddOrUpdate ( keyIdentifierUri , key , ( k , v ) => key ) ;
83+ }
84+ }
85+ finally
86+ {
87+ _keyDictionarySemaphore . Release ( ) ;
88+ }
6789 }
6890
6991 /// <summary>
@@ -75,18 +97,12 @@ internal KeyVaultKey GetKey(string keyIdentifierUri)
7597 {
7698 if ( _keyDictionary . TryGetValue ( keyIdentifierUri , out KeyVaultKey key ) )
7799 {
78- AKVEventSource . Log . TryTraceEvent ( "Fetched master key from cache" ) ;
100+ AKVEventSource . Log . TryTraceEvent ( "Fetched key name={0} from cache" , key . Name ) ;
79101 return key ;
80102 }
81103
82- if ( _keyFetchTaskDictionary . TryGetValue ( keyIdentifierUri , out Task < Azure . Response < KeyVaultKey > > task ) )
83- {
84- AKVEventSource . Log . TryTraceEvent ( "New Master key fetched." ) ;
85- return Task . Run ( ( ) => task ) . GetAwaiter ( ) . GetResult ( ) ;
86- }
87-
88104 // Not a public exception - not likely to occur.
89- AKVEventSource . Log . TryTraceEvent ( "Master key not found." ) ;
105+ AKVEventSource . Log . TryTraceEvent ( "Key not found; URI={0}" , keyIdentifierUri ) ;
90106 throw ADP . MasterKeyNotFound ( keyIdentifierUri ) ;
91107 }
92108
@@ -95,10 +111,7 @@ internal KeyVaultKey GetKey(string keyIdentifierUri)
95111 /// </summary>
96112 /// <param name="keyIdentifierUri">The key vault key identifier URI</param>
97113 /// <returns></returns>
98- internal int GetKeySize ( string keyIdentifierUri )
99- {
100- return GetKey ( keyIdentifierUri ) . Key . N . Length ;
101- }
114+ internal int GetKeySize ( string keyIdentifierUri ) => GetKey ( keyIdentifierUri ) . Key . N . Length ;
102115
103116 /// <summary>
104117 /// Generates signature based on RSA PKCS#v1.5 scheme using a specified Azure Key Vault Key URL.
@@ -142,49 +155,58 @@ private CryptographyClient GetCryptographyClient(string keyIdentifierUri)
142155
143156 CryptographyClient cryptographyClient = new ( GetKey ( keyIdentifierUri ) . Id , TokenCredential ) ;
144157 _cryptoClientDictionary . TryAdd ( keyIdentifierUri , cryptographyClient ) ;
145-
146158 return cryptographyClient ;
147159 }
148160
149161 /// <summary>
150- ///
162+ /// Fetches the column encryption key from the Azure Key Vault.
151163 /// </summary>
152- /// <param name="vaultUri ">The Azure Key Vault URI </param>
164+ /// <param name="keyClient ">The KeyClient instance </param>
153165 /// <param name="keyName">The name of the Azure Key Vault key</param>
154166 /// <param name="keyVersion">The version of the Azure Key Vault key</param>
155- /// <param name="keyResourceUri">The Azure Key Vault key identifier</param>
156- private void FetchKey ( Uri vaultUri , string keyName , string keyVersion , string keyResourceUri )
167+ private KeyVaultKey FetchKeyFromKeyVault ( KeyClient keyClient , string keyName , string keyVersion )
157168 {
158- Task < Azure . Response < KeyVaultKey > > fetchKeyTask = FetchKeyFromKeyVault ( vaultUri , keyName , keyVersion ) ;
159- _keyFetchTaskDictionary . AddOrUpdate ( keyResourceUri , fetchKeyTask , ( k , v ) => fetchKeyTask ) ;
169+ AKVEventSource . Log . TryTraceEvent ( "Fetching key name={0}" , keyName ) ;
160170
161- fetchKeyTask
162- . ContinueWith ( k => ValidateRsaKey ( k . GetAwaiter ( ) . GetResult ( ) ) )
163- . ContinueWith ( k => _keyDictionary . AddOrUpdate ( keyResourceUri , k . GetAwaiter ( ) . GetResult ( ) , ( key , v ) => k . GetAwaiter ( ) . GetResult ( ) ) ) ;
171+ Azure . Response < KeyVaultKey > keyResponse = keyClient ? . GetKey ( keyName , keyVersion ) ;
164172
165- Task . Run ( ( ) => fetchKeyTask ) ;
173+ // Handle the case where the key response is null or contains an error
174+ // This can happen if the key does not exist or if there is an issue with the KeyClient.
175+ // In such cases, we log the error and throw an exception.
176+ if ( keyResponse == null || keyResponse . Value == null || keyResponse . GetRawResponse ( ) . IsError )
177+ {
178+ AKVEventSource . Log . TryTraceEvent ( "Get Key failed to fetch Key from Azure Key Vault for key {0}, version {1}" , keyName , keyVersion ) ;
179+ if ( keyResponse ? . GetRawResponse ( ) is Azure . Response response )
180+ {
181+ AKVEventSource . Log . TryTraceEvent ( "Response status {0} : {1}" , response . Status , response . ReasonPhrase ) ;
182+ }
183+ throw ADP . GetKeyFailed ( keyName ) ;
184+ }
185+
186+ KeyVaultKey key = keyResponse . Value ;
187+
188+ // Validate that the key is of type RSA
189+ key = ValidateRsaKey ( key ) ;
190+ return key ;
166191 }
167192
168193 /// <summary>
169- /// Looks up the KeyClient object by it's URI and then fetches the key by name .
194+ /// Gets or creates a KeyClient for the specified Azure Key Vault URI .
170195 /// </summary>
171- /// <param name="vaultUri">The Azure Key Vault URI</param>
172- /// <param name="keyName">Then name of the key</param>
173- /// <param name="keyVersion">Then version of the key</param>
196+ /// <param name="vaultUri">Key Identifier URL</param>
174197 /// <returns></returns>
175- private Task < Azure . Response < KeyVaultKey > > FetchKeyFromKeyVault ( Uri vaultUri , string keyName , string keyVersion )
198+ private KeyClient GetOrCreateKeyClient ( Uri vaultUri )
176199 {
177- _keyClientDictionary . TryGetValue ( vaultUri , out KeyClient keyClient ) ;
178- AKVEventSource . Log . TryTraceEvent ( "Fetching requested master key: {0}" , keyName ) ;
179- return keyClient ? . GetKeyAsync ( keyName , keyVersion ) ;
200+ return _keyClientDictionary . GetOrAdd (
201+ vaultUri , ( _ ) => new KeyClient ( vaultUri , TokenCredential ) ) ;
180202 }
181203
182204 /// <summary>
183205 /// Validates that a key is of type RSA
184206 /// </summary>
185207 /// <param name="key"></param>
186208 /// <returns></returns>
187- private KeyVaultKey ValidateRsaKey ( KeyVaultKey key )
209+ private static KeyVaultKey ValidateRsaKey ( KeyVaultKey key )
188210 {
189211 if ( key . KeyType != KeyType . Rsa && key . KeyType != KeyType . RsaHsm )
190212 {
@@ -195,26 +217,14 @@ private KeyVaultKey ValidateRsaKey(KeyVaultKey key)
195217 return key ;
196218 }
197219
198- /// <summary>
199- /// Instantiates and adds a KeyClient to the KeyClient dictionary
200- /// </summary>
201- /// <param name="vaultUri">The Azure Key Vault URI</param>
202- private void CreateKeyClient ( Uri vaultUri )
203- {
204- if ( ! _keyClientDictionary . ContainsKey ( vaultUri ) )
205- {
206- _keyClientDictionary . TryAdd ( vaultUri , new KeyClient ( vaultUri , TokenCredential ) ) ;
207- }
208- }
209-
210220 /// <summary>
211221 /// Validates and parses the Azure Key Vault URI and key name.
212222 /// </summary>
213223 /// <param name="masterKeyPath">The Azure Key Vault key identifier</param>
214224 /// <param name="vaultUri">The Azure Key Vault URI</param>
215225 /// <param name="masterKeyName">The name of the key</param>
216226 /// <param name="masterKeyVersion">The version of the key</param>
217- private void ParseAKVPath ( string masterKeyPath , out Uri vaultUri , out string masterKeyName , out string masterKeyVersion )
227+ private static void ParseAKVPath ( string masterKeyPath , out Uri vaultUri , out string masterKeyName , out string masterKeyVersion )
218228 {
219229 Uri masterKeyPathUri = new ( masterKeyPath ) ;
220230 vaultUri = new Uri ( masterKeyPathUri . GetLeftPart ( UriPartial . Authority ) ) ;
0 commit comments