Skip to content

Commit

Permalink
Test | Add lock when using ClearSqlConnectionGlobalProvidersk (#1461)
Browse files Browse the repository at this point in the history
  • Loading branch information
Johnny Pham authored Jan 14, 2022
1 parent 330de76 commit d9efa74
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,24 @@ public void TestEmptyProviderName()
[Fact]
public void TestCanSetGlobalProvidersOnlyOnce()
{
Utility.ClearSqlConnectionGlobalProviders();
lock (Utility.ClearSqlConnectionGlobalProvidersLock)
{
Utility.ClearSqlConnectionGlobalProviders();

IDictionary<string, SqlColumnEncryptionKeyStoreProvider> customProviders =
new Dictionary<string, SqlColumnEncryptionKeyStoreProvider>()
{
IDictionary<string, SqlColumnEncryptionKeyStoreProvider> customProviders =
new Dictionary<string, SqlColumnEncryptionKeyStoreProvider>()
{
{ DummyKeyStoreProvider.Name, new DummyKeyStoreProvider() }
};
SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders);
};
SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders);

InvalidOperationException e = Assert.Throws<InvalidOperationException>(
() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
string expectedMessage = SystemDataResourceManager.Instance.TCE_CanOnlyCallOnce;
Assert.Contains(expectedMessage, e.Message);
InvalidOperationException e = Assert.Throws<InvalidOperationException>(
() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
string expectedMessage = SystemDataResourceManager.Instance.TCE_CanOnlyCallOnce;
Assert.Contains(expectedMessage, e.Message);

Utility.ClearSqlConnectionGlobalProviders();
Utility.ClearSqlConnectionGlobalProviders();
}
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public void TestInvalidCipherText()
[PlatformSpecific(TestPlatforms.Windows)]
public void TestInvalidAlgorithmVersion()
{
string expectedMessage = string.Format(SystemDataResourceManager.Instance.TCE_InvalidAlgorithmVersion,
string expectedMessage = string.Format(SystemDataResourceManager.Instance.TCE_InvalidAlgorithmVersion,
40, "01");
byte[] plainText = Encoding.Unicode.GetBytes("Hello World");
byte[] cipherText = EncryptDataUsingAED(plainText, CertFixture.cek, CColumnEncryptionType.Deterministic);
Expand Down Expand Up @@ -112,7 +112,7 @@ public void TestInvalidAuthenticationTag()
[PlatformSpecific(TestPlatforms.Windows)]
public void TestNullColumnEncryptionAlgorithm()
{
string expectedMessage = string.Format(SystemDataResourceManager.Instance.TCE_NullColumnEncryptionAlgorithm,
string expectedMessage = string.Format(SystemDataResourceManager.Instance.TCE_NullColumnEncryptionAlgorithm,
"'AEAD_AES_256_CBC_HMAC_SHA256'");
Object cipherMD = GetSqlCipherMetadata(0, 0, null, 1, 0x01);
AddEncryptionKeyToCipherMD(cipherMD, CertFixture.encryptedCek, 0, 0, 0, new byte[] { 0x01, 0x02, 0x03 }, CertFixture.certificatePath, "MSSQL_CERTIFICATE_STORE", "RSA_OAEP");
Expand Down Expand Up @@ -148,32 +148,35 @@ public void TestUnknownEncryptionAlgorithmId()
[PlatformSpecific(TestPlatforms.Windows)]
public void TestUnknownCustomKeyStoreProvider()
{
// Clear out the existing providers (to ensure test reliability)
ClearSqlConnectionGlobalProviders();

const string invalidProviderName = "Dummy_Provider";
string expectedMessage = string.Format(SystemDataResourceManager.Instance.TCE_UnrecognizedKeyStoreProviderName,
invalidProviderName, "'MSSQL_CERTIFICATE_STORE', 'MSSQL_CNG_STORE', 'MSSQL_CSP_PROVIDER'", "");
Object cipherMD = GetSqlCipherMetadata(0, 1, null, 1, 0x03);
AddEncryptionKeyToCipherMD(cipherMD, CertFixture.encryptedCek, 0, 0, 0, new byte[] { 0x01, 0x02, 0x03 }, CertFixture.certificatePath, invalidProviderName, "RSA_OAEP");
byte[] plainText = Encoding.Unicode.GetBytes("HelloWorld");
byte[] cipherText = EncryptDataUsingAED(plainText, CertFixture.cek, CColumnEncryptionType.Deterministic);
lock (Utility.ClearSqlConnectionGlobalProvidersLock)
{
// Clear out the existing providers (to ensure test reliability)
ClearSqlConnectionGlobalProviders();

Exception decryptEx = Assert.Throws<TargetInvocationException>(() => DecryptWithKey(plainText, cipherMD));
Assert.Contains(expectedMessage, decryptEx.InnerException.Message);
const string invalidProviderName = "Dummy_Provider";
string expectedMessage = string.Format(SystemDataResourceManager.Instance.TCE_UnrecognizedKeyStoreProviderName,
invalidProviderName, "'MSSQL_CERTIFICATE_STORE', 'MSSQL_CNG_STORE', 'MSSQL_CSP_PROVIDER'", "");
Object cipherMD = GetSqlCipherMetadata(0, 1, null, 1, 0x03);
AddEncryptionKeyToCipherMD(cipherMD, CertFixture.encryptedCek, 0, 0, 0, new byte[] { 0x01, 0x02, 0x03 }, CertFixture.certificatePath, invalidProviderName, "RSA_OAEP");
byte[] plainText = Encoding.Unicode.GetBytes("HelloWorld");
byte[] cipherText = EncryptDataUsingAED(plainText, CertFixture.cek, CColumnEncryptionType.Deterministic);

Exception encryptEx = Assert.Throws<TargetInvocationException>(() => EncryptWithKey(plainText, cipherMD));
Assert.Contains(expectedMessage, encryptEx.InnerException.Message);
Exception decryptEx = Assert.Throws<TargetInvocationException>(() => DecryptWithKey(plainText, cipherMD));
Assert.Contains(expectedMessage, decryptEx.InnerException.Message);

Exception encryptEx = Assert.Throws<TargetInvocationException>(() => EncryptWithKey(plainText, cipherMD));
Assert.Contains(expectedMessage, encryptEx.InnerException.Message);

ClearSqlConnectionGlobalProviders();
ClearSqlConnectionGlobalProviders();
}
}

[Fact]
[PlatformSpecific(TestPlatforms.Windows)]
public void TestTceUnknownEncryptionAlgorithm()
{
const string unknownEncryptionAlgorithm = "Dummy";
string expectedMessage = string.Format(SystemDataResourceManager.Instance.TCE_UnknownColumnEncryptionAlgorithm,
string expectedMessage = string.Format(SystemDataResourceManager.Instance.TCE_UnknownColumnEncryptionAlgorithm,
unknownEncryptionAlgorithm, "'AEAD_AES_256_CBC_HMAC_SHA256'");
Object cipherMD = GetSqlCipherMetadata(0, 0, "Dummy", 1, 0x01);
AddEncryptionKeyToCipherMD(cipherMD, CertFixture.encryptedCek, 0, 0, 0, new byte[] { 0x01, 0x02, 0x03 }, CertFixture.certificatePath, "MSSQL_CERTIFICATE_STORE", "RSA_OAEP");
Expand All @@ -193,7 +196,7 @@ public void TestExceptionsFromCertStore()
{
byte[] corruptedCek = GenerateInvalidEncryptedCek(CertFixture.cek, ECEKCorruption.SIGNATURE);

string expectedMessage = string.Format(SystemDataResourceManager.Instance.TCE_KeyDecryptionFailedCertStore,
string expectedMessage = string.Format(SystemDataResourceManager.Instance.TCE_KeyDecryptionFailedCertStore,
"MSSQL_CERTIFICATE_STORE", BitConverter.ToString(corruptedCek, corruptedCek.Length - 10, 10));

Object cipherMD = GetSqlCipherMetadata(0, 1, null, 1, 0x01);
Expand All @@ -209,27 +212,30 @@ public void TestExceptionsFromCertStore()
[PlatformSpecific(TestPlatforms.Windows)]
public void TestExceptionsFromCustomKeyStore()
{
string expectedMessage = "Failed to decrypt a column encryption key";
lock (Utility.ClearSqlConnectionGlobalProvidersLock)
{
string expectedMessage = "Failed to decrypt a column encryption key";

// Clear out the existing providers (to ensure test reliability)
ClearSqlConnectionGlobalProviders();
// Clear out the existing providers (to ensure test reliability)
ClearSqlConnectionGlobalProviders();

IDictionary<string, SqlColumnEncryptionKeyStoreProvider> customProviders = new Dictionary<string, SqlColumnEncryptionKeyStoreProvider>();
customProviders.Add(DummyKeyStoreProvider.Name, new DummyKeyStoreProvider());
SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders);
IDictionary<string, SqlColumnEncryptionKeyStoreProvider> customProviders = new Dictionary<string, SqlColumnEncryptionKeyStoreProvider>();
customProviders.Add(DummyKeyStoreProvider.Name, new DummyKeyStoreProvider());
SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders);

object cipherMD = GetSqlCipherMetadata(0, 1, null, 1, 0x01);
AddEncryptionKeyToCipherMD(cipherMD, CertFixture.encryptedCek, 0, 0, 0, new byte[] { 0x01, 0x02, 0x03 }, CertFixture.certificatePath, "DummyProvider", "DummyAlgo");
byte[] plainText = Encoding.Unicode.GetBytes("HelloWorld");
byte[] cipherText = EncryptDataUsingAED(plainText, CertFixture.cek, CColumnEncryptionType.Deterministic);
object cipherMD = GetSqlCipherMetadata(0, 1, null, 1, 0x01);
AddEncryptionKeyToCipherMD(cipherMD, CertFixture.encryptedCek, 0, 0, 0, new byte[] { 0x01, 0x02, 0x03 }, CertFixture.certificatePath, "DummyProvider", "DummyAlgo");
byte[] plainText = Encoding.Unicode.GetBytes("HelloWorld");
byte[] cipherText = EncryptDataUsingAED(plainText, CertFixture.cek, CColumnEncryptionType.Deterministic);

Exception decryptEx = Assert.Throws<TargetInvocationException>(() => DecryptWithKey(cipherText, cipherMD));
Assert.Contains(expectedMessage, decryptEx.InnerException.Message);
Exception decryptEx = Assert.Throws<TargetInvocationException>(() => DecryptWithKey(cipherText, cipherMD));
Assert.Contains(expectedMessage, decryptEx.InnerException.Message);

Exception encryptEx = Assert.Throws<TargetInvocationException>(() => EncryptWithKey(cipherText, cipherMD));
Assert.Contains(expectedMessage, encryptEx.InnerException.Message);
Exception encryptEx = Assert.Throws<TargetInvocationException>(() => EncryptWithKey(cipherText, cipherMD));
Assert.Contains(expectedMessage, encryptEx.InnerException.Message);

ClearSqlConnectionGlobalProviders();
ClearSqlConnectionGlobalProviders();
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,6 @@ public class SqlColumnEncryptionCertificateStoreProviderWindowsShould : IClassFi
/// </summary>
private const int CipherTextStartIndex = IVStartIndex + IVLengthInBytes;

/// <summary>
/// SetCustomColumnEncryptionKeyStoreProvider can be called only once in a process. To workaround that, we use this flag.
/// </summary>
private static bool s_testCustomEncryptioKeyStoreProviderExecutedOnce = false;

[Theory]
[InvalidDecryptionParameters]
[PlatformSpecific(TestPlatforms.Windows)]
Expand Down Expand Up @@ -326,55 +321,51 @@ public void TestAeadEncryptionReversal(string dataType, object data, Utility.CCo
[PlatformSpecific(TestPlatforms.Windows)]
public void TestCustomKeyProviderListSetter()
{
// SqlConnection.RegisterColumnEncryptionKeyStoreProviders can be called only once in a process.
// This is a workaround to ensure re-runnability of the test.
if (s_testCustomEncryptioKeyStoreProviderExecutedOnce)
lock (Utility.ClearSqlConnectionGlobalProvidersLock)
{
return;
string expectedMessage1 = "Column encryption key store provider dictionary cannot be null. Expecting a non-null value.";
// Verify that we are able to set it to null.
ArgumentException e1 = Assert.Throws<ArgumentNullException>(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(null));
Assert.Contains(expectedMessage1, e1.Message);

// A dictionary holding custom providers.
IDictionary<string, SqlColumnEncryptionKeyStoreProvider> customProviders = new Dictionary<string, SqlColumnEncryptionKeyStoreProvider>();
customProviders.Add(new KeyValuePair<string, SqlColumnEncryptionKeyStoreProvider>(@"DummyProvider", new DummyKeyStoreProvider()));

// Verify that setting a provider in the list with null value throws an exception.
customProviders.Add(new KeyValuePair<string, SqlColumnEncryptionKeyStoreProvider>(@"CustomProvider", null));
string expectedMessage2 = "Null reference specified for key store provider 'CustomProvider'. Expecting a non-null value.";
ArgumentNullException e2 = Assert.Throws<ArgumentNullException>(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
Assert.Contains(expectedMessage2, e2.Message);
customProviders.Remove(@"CustomProvider");

// Verify that setting a provider in the list with an empty provider name throws an exception.
customProviders.Add(new KeyValuePair<string, SqlColumnEncryptionKeyStoreProvider>(@"", new DummyKeyStoreProvider()));
string expectedMessage3 = "Invalid key store provider name specified. Key store provider names cannot be null or empty";
ArgumentNullException e3 = Assert.Throws<ArgumentNullException>(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
Assert.Contains(expectedMessage3, e3.Message);

customProviders.Remove(@"");

// Verify that setting a provider in the list with name that starts with 'MSSQL_' throws an exception.
customProviders.Add(new KeyValuePair<string, SqlColumnEncryptionKeyStoreProvider>(@"MSSQL_MyStore", new SqlColumnEncryptionCertificateStoreProvider()));
string expectedMessage4 = "Invalid key store provider name 'MSSQL_MyStore'. 'MSSQL_' prefix is reserved for system key store providers.";
ArgumentException e4 = Assert.Throws<ArgumentException>(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
Assert.Contains(expectedMessage4, e4.Message);

customProviders.Remove(@"MSSQL_MyStore");

// Verify that setting a provider in the list with name that starts with 'MSSQL_' but different case throws an exception.
customProviders.Add(new KeyValuePair<string, SqlColumnEncryptionKeyStoreProvider>(@"MsSqL_MyStore", new SqlColumnEncryptionCertificateStoreProvider()));
string expectedMessage5 = "Invalid key store provider name 'MsSqL_MyStore'. 'MSSQL_' prefix is reserved for system key store providers.";
ArgumentException e5 = Assert.Throws<ArgumentException>(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
Assert.Contains(expectedMessage5, e5.Message);

customProviders.Remove(@"MsSqL_MyStore");

// Clear any providers set by other tests.
Utility.ClearSqlConnectionGlobalProviders();
}

string expectedMessage1 = "Column encryption key store provider dictionary cannot be null. Expecting a non-null value.";
// Verify that we are able to set it to null.
ArgumentException e1 = Assert.Throws<ArgumentNullException>(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(null));
Assert.Contains(expectedMessage1, e1.Message);

// A dictionary holding custom providers.
IDictionary<string, SqlColumnEncryptionKeyStoreProvider> customProviders = new Dictionary<string, SqlColumnEncryptionKeyStoreProvider>();
customProviders.Add(new KeyValuePair<string, SqlColumnEncryptionKeyStoreProvider>(@"DummyProvider", new DummyKeyStoreProvider()));

// Verify that setting a provider in the list with null value throws an exception.
customProviders.Add(new KeyValuePair<string, SqlColumnEncryptionKeyStoreProvider>(@"CustomProvider", null));
string expectedMessage2 = "Null reference specified for key store provider 'CustomProvider'. Expecting a non-null value.";
ArgumentNullException e2 = Assert.Throws<ArgumentNullException>(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
Assert.Contains(expectedMessage2, e2.Message);
customProviders.Remove(@"CustomProvider");

// Verify that setting a provider in the list with an empty provider name throws an exception.
customProviders.Add(new KeyValuePair<string, SqlColumnEncryptionKeyStoreProvider>(@"", new DummyKeyStoreProvider()));
string expectedMessage3 = "Invalid key store provider name specified. Key store provider names cannot be null or empty";
ArgumentNullException e3 = Assert.Throws<ArgumentNullException>(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
Assert.Contains(expectedMessage3, e3.Message);

customProviders.Remove(@"");

// Verify that setting a provider in the list with name that starts with 'MSSQL_' throws an exception.
customProviders.Add(new KeyValuePair<string, SqlColumnEncryptionKeyStoreProvider>(@"MSSQL_MyStore", new SqlColumnEncryptionCertificateStoreProvider()));
string expectedMessage4 = "Invalid key store provider name 'MSSQL_MyStore'. 'MSSQL_' prefix is reserved for system key store providers.";
ArgumentException e4 = Assert.Throws<ArgumentException>(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
Assert.Contains(expectedMessage4, e4.Message);

customProviders.Remove(@"MSSQL_MyStore");

// Verify that setting a provider in the list with name that starts with 'MSSQL_' but different case throws an exception.
customProviders.Add(new KeyValuePair<string, SqlColumnEncryptionKeyStoreProvider>(@"MsSqL_MyStore", new SqlColumnEncryptionCertificateStoreProvider()));
string expectedMessage5 = "Invalid key store provider name 'MsSqL_MyStore'. 'MSSQL_' prefix is reserved for system key store providers.";
ArgumentException e5 = Assert.Throws<ArgumentException>(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
Assert.Contains(expectedMessage5, e5.Message);

customProviders.Remove(@"MsSqL_MyStore");

// Clear any providers set by other tests.
Utility.ClearSqlConnectionGlobalProviders();
}

[Theory]
Expand Down Expand Up @@ -502,7 +493,7 @@ public class CEKEncryptionReversalParameters : DataAttribute
{
public override IEnumerable<Object[]> GetData(MethodInfo testMethod)
{
yield return new object[2] { StoreLocation.CurrentUser , CurrentUserMyPathPrefix };
yield return new object[2] { StoreLocation.CurrentUser, CurrentUserMyPathPrefix };
// use localmachine cert path only when current user is Admin.
if (CertificateFixture.IsAdmin)
{
Expand Down
Loading

0 comments on commit d9efa74

Please sign in to comment.