Skip to content

Commit

Permalink
Advance column index to avoid double clean. (#2825)
Browse files Browse the repository at this point in the history
  • Loading branch information
mdaigle authored Sep 9, 2024
1 parent 62af3b5 commit eb4c0f0
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6388,16 +6388,36 @@ internal TdsOperationStatus TryReadSqlValue(SqlBuffer value, SqlMetaDataPriv md,
{
if (stateObj is not null)
{
// call to decrypt column keys has failed. The data wont be decrypted.
// Not setting the value to false, forces the driver to look for column value.
// Packet received from Key Vault will throws invalid token header.
if (stateObj.HasPendingData)
// Throwing an exception here circumvents the normal pending data checks and cleanup processes,
// so we need to ensure the appropriate state. Increment the _nextColumnDataToRead index because
// we already read the encrypted column data; Otherwise we'll double count and attempt to drain a
// corresponding number of bytes a second time. We don't want the rest of the pending data to
// interfere with future operations, so we must drain it. Set HasPendingData to false to indicate
// that we successfully drained the data.

// The SqlDataReader also maintains a state called dataReady. We need to set that to false if we've
// drained the data off the connection. Otherwise, a consumer that catches the exception may
// continue to use the reader and will timeout waiting to read data that doesn't exist.

// Order matters here. Must increment column before draining data.
// Update state objects after draining data.



if (stateObj._readerState != null)
{
// Drain the pending data now if setting the HasPendingData to false.
// SqlDataReader.TryCloseInternal can not drain if HasPendingData = false.
DrainData(stateObj);
stateObj._readerState._nextColumnDataToRead++;
}

DrainData(stateObj);

if (stateObj._readerState != null)
{
stateObj._readerState._dataReady = false;
}

stateObj.HasPendingData = false;

}
throw SQL.ColumnDecryptionFailed(columnName, null, e);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections;
using System.Collections.Generic;
using Microsoft.Data.SqlClient.ManualTesting.Tests.AlwaysEncrypted.Setup;
using Microsoft.Data.SqlClient.ManualTesting.Tests.SystemDataInternals;
using Xunit;

namespace Microsoft.Data.SqlClient.ManualTesting.Tests.AlwaysEncrypted
{
public sealed class ColumnDecryptErrorTests : IClassFixture<SQLSetupStrategyAzureKeyVault>, IDisposable
{
private SQLSetupStrategyAzureKeyVault fixture;

private readonly string tableName;

public ColumnDecryptErrorTests(SQLSetupStrategyAzureKeyVault context)
{
fixture = context;
tableName = fixture.ColumnDecryptErrorTestTable.Name;
}

/*
* This test ensures that column decryption errors and connection pooling play nicely together.
* When a decryption error is encountered, we expect the connection to be drained of data and
* properly reset before being returned to the pool. If this doesn't happen, then random bytes
* may be left in the connection's state. These can interfere with the next operation that utilizes
* the connection.
*
* We test that state is properly reset by triggering the same error condition twice. Routing column key discovery
* away from AKV toward a dummy key store achieves this. Each connection pulls from a pool of max
* size one to ensure we are using the same internal connection/socket both times. We expect to
* receive the "Failed to decrypt column" exception twice. If the state were not cleaned properly,
* the second error would be different because the TDS stream would be unintelligible.
*
* Finally, we assert that restoring the connection to AKV allows a successful query.
*/
[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsTargetReadyForAeWithKeyStore), nameof(DataTestUtility.IsAKVSetupAvailable))]
[ClassData(typeof(TestQueries))]
public void TestCleanConnectionAfterDecryptFail(string connString, string selectQuery, int totalColumnsInSelect, string[] types)
{
// Arrange
Assert.False(string.IsNullOrWhiteSpace(selectQuery), "FAILED: select query should not be null or empty.");
Assert.True(totalColumnsInSelect <= 3, "FAILED: totalColumnsInSelect should <= 3.");

using (SqlConnection sqlConnection = new SqlConnection(connString))
{
sqlConnection.Open();

Table.DeleteData(tableName, sqlConnection);

Customer customer = new Customer(
45,
"Microsoft",
"Corporation");

DatabaseHelper.InsertCustomerData(sqlConnection, null, tableName, customer);
}


// Act - Trigger a column decrypt error on the connection
Dictionary<String, SqlColumnEncryptionKeyStoreProvider> keyStoreProviders = new()
{
{ "AZURE_KEY_VAULT", new DummyKeyStoreProvider() }
};

String poolEnabledConnString = new SqlConnectionStringBuilder(connString) { Pooling = true, MaxPoolSize = 1 }.ToString();

using (SqlConnection sqlConnection = new SqlConnection(poolEnabledConnString))
{
sqlConnection.Open();
sqlConnection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(keyStoreProviders);

using SqlCommand sqlCommand = new SqlCommand(string.Format(selectQuery, tableName),
sqlConnection, null, SqlCommandColumnEncryptionSetting.Enabled);

using SqlDataReader sqlDataReader = sqlCommand.ExecuteReader();

Assert.True(sqlDataReader.HasRows, "FAILED: Select statement did not return any rows.");

while (sqlDataReader.Read())
{
var error = Assert.Throws<SqlException>(() => DatabaseHelper.CompareResults(sqlDataReader, types, totalColumnsInSelect));
Assert.Contains("Failed to decrypt column", error.Message);
}
}


// Assert
using (SqlConnection sqlConnection = new SqlConnection(poolEnabledConnString))
{
sqlConnection.Open();
sqlConnection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(keyStoreProviders);

using SqlCommand sqlCommand = new SqlCommand(string.Format(selectQuery, tableName),
sqlConnection, null, SqlCommandColumnEncryptionSetting.Enabled);
using SqlDataReader sqlDataReader = sqlCommand.ExecuteReader();

Assert.True(sqlDataReader.HasRows, "FAILED: Select statement did not return any rows.");

while (sqlDataReader.Read())
{
var error = Assert.Throws<SqlException>(() => DatabaseHelper.CompareResults(sqlDataReader, types, totalColumnsInSelect));
Assert.Contains("Failed to decrypt column", error.Message);
}
}

using (SqlConnection sqlConnection = new SqlConnection(poolEnabledConnString))
{
sqlConnection.Open();

using SqlCommand sqlCommand = new SqlCommand(string.Format(selectQuery, tableName),
sqlConnection, null, SqlCommandColumnEncryptionSetting.Enabled);
using SqlDataReader sqlDataReader = sqlCommand.ExecuteReader();

Assert.True(sqlDataReader.HasRows, "FAILED: Select statement did not return any rows.");

while (sqlDataReader.Read())
{
DatabaseHelper.CompareResults(sqlDataReader, types, totalColumnsInSelect);
}
}
}


public void Dispose()
{
foreach (string connStrAE in DataTestUtility.AEConnStringsSetup)
{
using (SqlConnection sqlConnection = new SqlConnection(connStrAE))
{
sqlConnection.Open();
Table.DeleteData(fixture.ColumnDecryptErrorTestTable.Name, sqlConnection);
}
}
}

private sealed class DummyKeyStoreProvider : SqlColumnEncryptionKeyStoreProvider
{
public override byte[] DecryptColumnEncryptionKey(string masterKeyPath, string encryptionAlgorithm, byte[] encryptedColumnEncryptionKey)
{
// Must be 32 to match the key length expected for the 'AEAD_AES_256_CBC_HMAC_SHA256' algorithm
return new byte[32];
}

public override byte[] EncryptColumnEncryptionKey(string masterKeyPath, string encryptionAlgorithm, byte[] columnEncryptionKey)
{
return new byte[32];
}
}
}

public class TestQueries : IEnumerable<object[]>
{
public IEnumerator<object[]> GetEnumerator()
{
foreach (string connStrAE in DataTestUtility.AEConnStrings)
{
yield return new object[] { connStrAE, @"select CustomerId, FirstName, LastName from [{0}] ", 3, new string[] { @"int", @"string", @"string" } };
yield return new object[] { connStrAE, @"select CustomerId, FirstName from [{0}] ", 2, new string[] { @"int", @"string" } };
yield return new object[] { connStrAE, @"select LastName from [{0}] ", 1, new string[] { @"string" } };
}
}
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public class SQLSetupStrategy : IDisposable
public Table ApiTestTable { get; private set; }
public Table BulkCopyAEErrorMessageTestTable { get; private set; }
public Table BulkCopyAETestTable { get; private set; }
public Table ColumnDecryptErrorTestTable { get; private set; }
public Table SqlParameterPropertiesTable { get; private set; }
public Table DateOnlyTestTable { get; private set; }
public Table End2EndSmokeTable { get; private set; }
Expand Down Expand Up @@ -127,6 +128,9 @@ protected List<Table> CreateTables(IList<ColumnEncryptionKey> columnEncryptionKe
BulkCopyAETestTable = new BulkCopyAETestTable(GenerateUniqueName("BulkCopyAETestTable"), columnEncryptionKeys[0], columnEncryptionKeys[1]);
tables.Add(BulkCopyAETestTable);

ColumnDecryptErrorTestTable = new ColumnDecryptErrorTestTable(GenerateUniqueName("ColumnDecryptErrorTestTable"), columnEncryptionKeys[0], columnEncryptionKeys[1]);
tables.Add(ColumnDecryptErrorTestTable);

SqlParameterPropertiesTable = new SqlParameterPropertiesTable(GenerateUniqueName("SqlParameterPropertiesTable"));
tables.Add(SqlParameterPropertiesTable);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

namespace Microsoft.Data.SqlClient.ManualTesting.Tests.AlwaysEncrypted.Setup
{
public class ColumnDecryptErrorTestTable : Table
{
private const string ColumnEncryptionAlgorithmName = @"AEAD_AES_256_CBC_HMAC_SHA_256";
public ColumnEncryptionKey columnEncryptionKey1;
public ColumnEncryptionKey columnEncryptionKey2;
private bool useDeterministicEncryption;

public ColumnDecryptErrorTestTable(string tableName, ColumnEncryptionKey columnEncryptionKey1, ColumnEncryptionKey columnEncryptionKey2, bool useDeterministicEncryption = false) : base(tableName)
{
this.columnEncryptionKey1 = columnEncryptionKey1;
this.columnEncryptionKey2 = columnEncryptionKey2;
this.useDeterministicEncryption = useDeterministicEncryption;
}

public override void Create(SqlConnection sqlConnection)
{
string encryptionType = useDeterministicEncryption ? "DETERMINISTIC" : DataTestUtility.EnclaveEnabled ? "RANDOMIZED" : "DETERMINISTIC";
string sql =
$@"CREATE TABLE [dbo].[{Name}]
(
[CustomerId] [int] ENCRYPTED WITH (COLUMN_ENCRYPTION_KEY = [{columnEncryptionKey1.Name}], ENCRYPTION_TYPE = {encryptionType}, ALGORITHM = '{ColumnEncryptionAlgorithmName}'),
[FirstName] [nvarchar](50) COLLATE Latin1_General_BIN2 ENCRYPTED WITH (COLUMN_ENCRYPTION_KEY = [{columnEncryptionKey2.Name}], ENCRYPTION_TYPE = DETERMINISTIC, ALGORITHM = '{ColumnEncryptionAlgorithmName}'),
[LastName] [nvarchar](50) COLLATE Latin1_General_BIN2 ENCRYPTED WITH (COLUMN_ENCRYPTION_KEY = [{columnEncryptionKey2.Name}], ENCRYPTION_TYPE = DETERMINISTIC, ALGORITHM = '{ColumnEncryptionAlgorithmName}')
)";

using (SqlCommand command = sqlConnection.CreateCommand())
{
command.CommandText = sql;
command.ExecuteNonQuery();
}
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
<Compile Include="AlwaysEncrypted\TestFixtures\Setup\BulkCopyAETestTable.cs" />
<Compile Include="AlwaysEncrypted\TestFixtures\Setup\BulkCopyAEErrorMessageTestTable.cs" />
<Compile Include="AlwaysEncrypted\TestFixtures\Setup\BulkCopyTruncationTables.cs" />
<Compile Include="AlwaysEncrypted\TestFixtures\Setup\ColumnDecryptErrorTestTable.cs" />
<Compile Include="AlwaysEncrypted\TestFixtures\Setup\DateOnlyTestTable.cs" />
<Compile Include="AlwaysEncrypted\TestFixtures\Setup\SqlNullValuesTable.cs" />
<Compile Include="AlwaysEncrypted\TestFixtures\Setup\SqlParameterPropertiesTable.cs" />
Expand All @@ -73,6 +74,7 @@
</ItemGroup>
<ItemGroup Condition="$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net6.0')) AND ('$(TestSet)' == '' OR '$(TestSet)' == 'AE')">
<Compile Include="AlwaysEncrypted\DateOnlyReadTests.cs" />
<Compile Include="AlwaysEncrypted\ColumnDecryptErrorTests.cs" />
</ItemGroup>
<ItemGroup Condition="'$(TestSet)' == '' OR '$(TestSet)' == '1'">
<Compile Include="SQL\AsyncTest\AsyncTimeoutTest.cs" />
Expand Down

0 comments on commit eb4c0f0

Please sign in to comment.