From eb4c0f043bff30c7f3bb58704b41c73faef9d63b Mon Sep 17 00:00:00 2001 From: Malcolm Daigle Date: Mon, 9 Sep 2024 11:04:55 -0700 Subject: [PATCH] Advance column index to avoid double clean. (#2825) --- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 34 +++- .../ColumnDecryptErrorTests.cs | 170 ++++++++++++++++++ .../TestFixtures/SQLSetupStrategy.cs | 4 + .../Setup/ColumnDecryptErrorTestTable.cs | 40 +++++ ....Data.SqlClient.ManualTesting.Tests.csproj | 2 + 5 files changed, 243 insertions(+), 7 deletions(-) create mode 100644 src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ColumnDecryptErrorTests.cs create mode 100644 src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/Setup/ColumnDecryptErrorTestTable.cs diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index edeecb6c9f..2ec05846cf 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -6388,16 +6388,36 @@ internal TdsOperationStatus TryReadSqlValue(SqlBuffer value, SqlMetaDataPriv md, { if (stateObj is not null) { - // call to decrypt column keys has failed. The data wont be decrypted. - // Not setting the value to false, forces the driver to look for column value. - // Packet received from Key Vault will throws invalid token header. - if (stateObj.HasPendingData) + // Throwing an exception here circumvents the normal pending data checks and cleanup processes, + // so we need to ensure the appropriate state. Increment the _nextColumnDataToRead index because + // we already read the encrypted column data; Otherwise we'll double count and attempt to drain a + // corresponding number of bytes a second time. We don't want the rest of the pending data to + // interfere with future operations, so we must drain it. Set HasPendingData to false to indicate + // that we successfully drained the data. + + // The SqlDataReader also maintains a state called dataReady. We need to set that to false if we've + // drained the data off the connection. Otherwise, a consumer that catches the exception may + // continue to use the reader and will timeout waiting to read data that doesn't exist. + + // Order matters here. Must increment column before draining data. + // Update state objects after draining data. + + + + if (stateObj._readerState != null) { - // Drain the pending data now if setting the HasPendingData to false. - // SqlDataReader.TryCloseInternal can not drain if HasPendingData = false. - DrainData(stateObj); + stateObj._readerState._nextColumnDataToRead++; } + + DrainData(stateObj); + + if (stateObj._readerState != null) + { + stateObj._readerState._dataReady = false; + } + stateObj.HasPendingData = false; + } throw SQL.ColumnDecryptionFailed(columnName, null, e); } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ColumnDecryptErrorTests.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ColumnDecryptErrorTests.cs new file mode 100644 index 0000000000..954d94ee7f --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ColumnDecryptErrorTests.cs @@ -0,0 +1,170 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; +using System.Collections.Generic; +using Microsoft.Data.SqlClient.ManualTesting.Tests.AlwaysEncrypted.Setup; +using Microsoft.Data.SqlClient.ManualTesting.Tests.SystemDataInternals; +using Xunit; + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests.AlwaysEncrypted +{ + public sealed class ColumnDecryptErrorTests : IClassFixture, IDisposable + { + private SQLSetupStrategyAzureKeyVault fixture; + + private readonly string tableName; + + public ColumnDecryptErrorTests(SQLSetupStrategyAzureKeyVault context) + { + fixture = context; + tableName = fixture.ColumnDecryptErrorTestTable.Name; + } + + /* + * This test ensures that column decryption errors and connection pooling play nicely together. + * When a decryption error is encountered, we expect the connection to be drained of data and + * properly reset before being returned to the pool. If this doesn't happen, then random bytes + * may be left in the connection's state. These can interfere with the next operation that utilizes + * the connection. + * + * We test that state is properly reset by triggering the same error condition twice. Routing column key discovery + * away from AKV toward a dummy key store achieves this. Each connection pulls from a pool of max + * size one to ensure we are using the same internal connection/socket both times. We expect to + * receive the "Failed to decrypt column" exception twice. If the state were not cleaned properly, + * the second error would be different because the TDS stream would be unintelligible. + * + * Finally, we assert that restoring the connection to AKV allows a successful query. + */ + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsTargetReadyForAeWithKeyStore), nameof(DataTestUtility.IsAKVSetupAvailable))] + [ClassData(typeof(TestQueries))] + public void TestCleanConnectionAfterDecryptFail(string connString, string selectQuery, int totalColumnsInSelect, string[] types) + { + // Arrange + Assert.False(string.IsNullOrWhiteSpace(selectQuery), "FAILED: select query should not be null or empty."); + Assert.True(totalColumnsInSelect <= 3, "FAILED: totalColumnsInSelect should <= 3."); + + using (SqlConnection sqlConnection = new SqlConnection(connString)) + { + sqlConnection.Open(); + + Table.DeleteData(tableName, sqlConnection); + + Customer customer = new Customer( + 45, + "Microsoft", + "Corporation"); + + DatabaseHelper.InsertCustomerData(sqlConnection, null, tableName, customer); + } + + + // Act - Trigger a column decrypt error on the connection + Dictionary keyStoreProviders = new() + { + { "AZURE_KEY_VAULT", new DummyKeyStoreProvider() } + }; + + String poolEnabledConnString = new SqlConnectionStringBuilder(connString) { Pooling = true, MaxPoolSize = 1 }.ToString(); + + using (SqlConnection sqlConnection = new SqlConnection(poolEnabledConnString)) + { + sqlConnection.Open(); + sqlConnection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(keyStoreProviders); + + using SqlCommand sqlCommand = new SqlCommand(string.Format(selectQuery, tableName), + sqlConnection, null, SqlCommandColumnEncryptionSetting.Enabled); + + using SqlDataReader sqlDataReader = sqlCommand.ExecuteReader(); + + Assert.True(sqlDataReader.HasRows, "FAILED: Select statement did not return any rows."); + + while (sqlDataReader.Read()) + { + var error = Assert.Throws(() => DatabaseHelper.CompareResults(sqlDataReader, types, totalColumnsInSelect)); + Assert.Contains("Failed to decrypt column", error.Message); + } + } + + + // Assert + using (SqlConnection sqlConnection = new SqlConnection(poolEnabledConnString)) + { + sqlConnection.Open(); + sqlConnection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(keyStoreProviders); + + using SqlCommand sqlCommand = new SqlCommand(string.Format(selectQuery, tableName), + sqlConnection, null, SqlCommandColumnEncryptionSetting.Enabled); + using SqlDataReader sqlDataReader = sqlCommand.ExecuteReader(); + + Assert.True(sqlDataReader.HasRows, "FAILED: Select statement did not return any rows."); + + while (sqlDataReader.Read()) + { + var error = Assert.Throws(() => DatabaseHelper.CompareResults(sqlDataReader, types, totalColumnsInSelect)); + Assert.Contains("Failed to decrypt column", error.Message); + } + } + + using (SqlConnection sqlConnection = new SqlConnection(poolEnabledConnString)) + { + sqlConnection.Open(); + + using SqlCommand sqlCommand = new SqlCommand(string.Format(selectQuery, tableName), + sqlConnection, null, SqlCommandColumnEncryptionSetting.Enabled); + using SqlDataReader sqlDataReader = sqlCommand.ExecuteReader(); + + Assert.True(sqlDataReader.HasRows, "FAILED: Select statement did not return any rows."); + + while (sqlDataReader.Read()) + { + DatabaseHelper.CompareResults(sqlDataReader, types, totalColumnsInSelect); + } + } + } + + + public void Dispose() + { + foreach (string connStrAE in DataTestUtility.AEConnStringsSetup) + { + using (SqlConnection sqlConnection = new SqlConnection(connStrAE)) + { + sqlConnection.Open(); + Table.DeleteData(fixture.ColumnDecryptErrorTestTable.Name, sqlConnection); + } + } + } + + private sealed class DummyKeyStoreProvider : SqlColumnEncryptionKeyStoreProvider + { + public override byte[] DecryptColumnEncryptionKey(string masterKeyPath, string encryptionAlgorithm, byte[] encryptedColumnEncryptionKey) + { + // Must be 32 to match the key length expected for the 'AEAD_AES_256_CBC_HMAC_SHA256' algorithm + return new byte[32]; + } + + public override byte[] EncryptColumnEncryptionKey(string masterKeyPath, string encryptionAlgorithm, byte[] columnEncryptionKey) + { + return new byte[32]; + } + } + } + + public class TestQueries : IEnumerable + { + public IEnumerator GetEnumerator() + { + foreach (string connStrAE in DataTestUtility.AEConnStrings) + { + yield return new object[] { connStrAE, @"select CustomerId, FirstName, LastName from [{0}] ", 3, new string[] { @"int", @"string", @"string" } }; + yield return new object[] { connStrAE, @"select CustomerId, FirstName from [{0}] ", 2, new string[] { @"int", @"string" } }; + yield return new object[] { connStrAE, @"select LastName from [{0}] ", 1, new string[] { @"string" } }; + } + } + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } +} + diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/SQLSetupStrategy.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/SQLSetupStrategy.cs index 4d3c635684..2bc38d9930 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/SQLSetupStrategy.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/SQLSetupStrategy.cs @@ -21,6 +21,7 @@ public class SQLSetupStrategy : IDisposable public Table ApiTestTable { get; private set; } public Table BulkCopyAEErrorMessageTestTable { get; private set; } public Table BulkCopyAETestTable { get; private set; } + public Table ColumnDecryptErrorTestTable { get; private set; } public Table SqlParameterPropertiesTable { get; private set; } public Table DateOnlyTestTable { get; private set; } public Table End2EndSmokeTable { get; private set; } @@ -127,6 +128,9 @@ protected List CreateTables(IList columnEncryptionKe BulkCopyAETestTable = new BulkCopyAETestTable(GenerateUniqueName("BulkCopyAETestTable"), columnEncryptionKeys[0], columnEncryptionKeys[1]); tables.Add(BulkCopyAETestTable); + ColumnDecryptErrorTestTable = new ColumnDecryptErrorTestTable(GenerateUniqueName("ColumnDecryptErrorTestTable"), columnEncryptionKeys[0], columnEncryptionKeys[1]); + tables.Add(ColumnDecryptErrorTestTable); + SqlParameterPropertiesTable = new SqlParameterPropertiesTable(GenerateUniqueName("SqlParameterPropertiesTable")); tables.Add(SqlParameterPropertiesTable); diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/Setup/ColumnDecryptErrorTestTable.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/Setup/ColumnDecryptErrorTestTable.cs new file mode 100644 index 0000000000..0a16e0c3aa --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/Setup/ColumnDecryptErrorTestTable.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests.AlwaysEncrypted.Setup +{ + public class ColumnDecryptErrorTestTable : Table + { + private const string ColumnEncryptionAlgorithmName = @"AEAD_AES_256_CBC_HMAC_SHA_256"; + public ColumnEncryptionKey columnEncryptionKey1; + public ColumnEncryptionKey columnEncryptionKey2; + private bool useDeterministicEncryption; + + public ColumnDecryptErrorTestTable(string tableName, ColumnEncryptionKey columnEncryptionKey1, ColumnEncryptionKey columnEncryptionKey2, bool useDeterministicEncryption = false) : base(tableName) + { + this.columnEncryptionKey1 = columnEncryptionKey1; + this.columnEncryptionKey2 = columnEncryptionKey2; + this.useDeterministicEncryption = useDeterministicEncryption; + } + + public override void Create(SqlConnection sqlConnection) + { + string encryptionType = useDeterministicEncryption ? "DETERMINISTIC" : DataTestUtility.EnclaveEnabled ? "RANDOMIZED" : "DETERMINISTIC"; + string sql = + $@"CREATE TABLE [dbo].[{Name}] + ( + [CustomerId] [int] ENCRYPTED WITH (COLUMN_ENCRYPTION_KEY = [{columnEncryptionKey1.Name}], ENCRYPTION_TYPE = {encryptionType}, ALGORITHM = '{ColumnEncryptionAlgorithmName}'), + [FirstName] [nvarchar](50) COLLATE Latin1_General_BIN2 ENCRYPTED WITH (COLUMN_ENCRYPTION_KEY = [{columnEncryptionKey2.Name}], ENCRYPTION_TYPE = DETERMINISTIC, ALGORITHM = '{ColumnEncryptionAlgorithmName}'), + [LastName] [nvarchar](50) COLLATE Latin1_General_BIN2 ENCRYPTED WITH (COLUMN_ENCRYPTION_KEY = [{columnEncryptionKey2.Name}], ENCRYPTION_TYPE = DETERMINISTIC, ALGORITHM = '{ColumnEncryptionAlgorithmName}') + )"; + + using (SqlCommand command = sqlConnection.CreateCommand()) + { + command.CommandText = sql; + command.ExecuteNonQuery(); + } + } + + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj index 0ec61b5420..6de3141afa 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj @@ -57,6 +57,7 @@ + @@ -73,6 +74,7 @@ +