From 21e1b5d53c9b778ec53be473b612aa880005b3be Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Sun, 7 Apr 2019 20:18:43 +0100 Subject: [PATCH] special case bytes in udt parameter write --- .../src/System/Data/SqlClient/TdsParser.cs | 24 ++++++- .../SQL/UdtTest/SqlServerTypesTest.cs | 68 +++++++++++++++++++ 2 files changed, 91 insertions(+), 1 deletion(-) diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParser.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParser.cs index 2c35b0de3a53..44b6760ee767 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParser.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParser.cs @@ -7489,7 +7489,29 @@ private Task TDSExecuteRPCAddParameter(TdsParserStateObject stateObj, SqlParamet if (!isNull) { - udtVal = _connHandler.Connection.GetBytes(value, out format, out maxsize); + if (value is byte[] rawBytes) + { + udtVal = rawBytes; + } + else if (value is SqlBytes sqlBytes) + { + switch (sqlBytes.Storage) + { + case StorageState.Buffer: + // use the buffer directly, the only way to create it is with the correctly sized byte array + udtVal = sqlBytes.Buffer; + break; + case StorageState.Stream: + case StorageState.UnmanagedBuffer: + // allocate a new byte array to store the data + udtVal = sqlBytes.Value; + break; + } + } + else + { + udtVal = _connHandler.Connection.GetBytes(value, out format, out maxsize); + } Debug.Assert(null != udtVal, "GetBytes returned null instance. Make sure that it always returns non-null value"); size = udtVal.Length; diff --git a/src/System.Data.SqlClient/tests/ManualTests/SQL/UdtTest/SqlServerTypesTest.cs b/src/System.Data.SqlClient/tests/ManualTests/SQL/UdtTest/SqlServerTypesTest.cs index 84b88331044b..3307bce387de 100644 --- a/src/System.Data.SqlClient/tests/ManualTests/SQL/UdtTest/SqlServerTypesTest.cs +++ b/src/System.Data.SqlClient/tests/ManualTests/SQL/UdtTest/SqlServerTypesTest.cs @@ -277,6 +277,74 @@ public static void TestUdtSchemaMetadata() } } + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + public static void TestUdtParameterSetSqlByteValue() + { + const string ExpectedPointValue = "POINT (1 1)"; + SqlBytes geometrySqlBytes = null; + string actualtPointValue = null; + + using (SqlConnection connection = new SqlConnection(DataTestUtility.TcpConnStr)) + { + connection.Open(); + + using (var command = connection.CreateCommand()) + { + command.CommandText = $"SELECT geometry::Parse('{ExpectedPointValue}')"; + using (var reader = command.ExecuteReader()) + { + reader.Read(); + geometrySqlBytes = reader.GetSqlBytes(0); + } + } + + using (var command = connection.CreateCommand()) + { + command.CommandText = "SELECT @geometry.STAsText()"; + var parameter = command.Parameters.AddWithValue("@geometry", geometrySqlBytes); + parameter.SqlDbType = SqlDbType.Udt; + parameter.UdtTypeName = "geometry"; + actualtPointValue = Convert.ToString(command.ExecuteScalar()); + } + + Assert.Equal(ExpectedPointValue, actualtPointValue); + } + } + + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + public static void TestUdtParameterSetRawByteValue() + { + const string ExpectedPointValue = "POINT (1 1)"; + byte[] geometryBytes = null; + string actualtPointValue = null; + + using (SqlConnection connection = new SqlConnection(DataTestUtility.TcpConnStr)) + { + connection.Open(); + + using (var command = connection.CreateCommand()) + { + command.CommandText = $"SELECT geometry::Parse('{ExpectedPointValue}')"; + using (var reader = command.ExecuteReader()) + { + reader.Read(); + geometryBytes = reader.GetSqlBytes(0).Buffer; + } + } + + using (var command = connection.CreateCommand()) + { + command.CommandText = "SELECT @geometry.STAsText()"; + var parameter = command.Parameters.AddWithValue("@geometry", geometryBytes); + parameter.SqlDbType = SqlDbType.Udt; + parameter.UdtTypeName = "geometry"; + actualtPointValue = Convert.ToString(command.ExecuteScalar()); + } + + Assert.Equal(ExpectedPointValue, actualtPointValue); + } + } + private static void AssertSqlUdtAssemblyQualifiedName(string assemblyQualifiedName, string expectedType) { List parts = assemblyQualifiedName.Split(',').Select(x => x.Trim()).ToList();