diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlDataReader.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlDataReader.cs index e66d5cd393c9..a509e7174681 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlDataReader.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SqlDataReader.cs @@ -2588,10 +2588,10 @@ private T GetFieldValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData met // If its a SQL Type or Nullable UDT object rawValue = GetSqlValueFromSqlBufferInternal(data, metaData); - // Special case: User wants SqlString, but we have a SqlXml - // SqlXml can not be typecast into a SqlString, but we need to support SqlString on XML Types - so do a manual conversion if (typeofT == s_typeofSqlString) { + // Special case: User wants SqlString, but we have a SqlXml + // SqlXml can not be typecast into a SqlString, but we need to support SqlString on XML Types - so do a manual conversion SqlXml xmlValue = rawValue as SqlXml; if (xmlValue != null) { @@ -2610,22 +2610,58 @@ private T GetFieldValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData met } else { - // Otherwise Its a CLR or non-Nullable UDT - try + if (typeof(XmlReader) == typeofT) { - return (T)GetValueFromSqlBufferInternal(data, metaData); + if (metaData.metaType.SqlDbType != SqlDbType.Xml) + { + throw SQL.XmlReaderNotSupportOnColumnType(metaData.column); + } + else + { + object clrValue = null; + if (!data.IsNull) + { + clrValue = GetValueFromSqlBufferInternal(data, metaData); + } + if (clrValue is null) // covers IsNull and when there is data which is present but is a clr null somehow + { + return (T)(object)SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader( + new MemoryStream(Array.Empty(), writable: false), + closeInput: true + ); + } + else if (clrValue.GetType() == typeof(string)) + { + return (T)(object)SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader( + new StringReader(clrValue as string), + closeInput: true + ); + } + else + { + // try the type cast to throw the invalid cast exception and inform the user what types they're trying to use and that why it is wrong + return (T)clrValue; + } + } } - catch (InvalidCastException) + else { - if (data.IsNull) + try { - // If the value was actually null, then we should throw a SqlNullValue instead - throw SQL.SqlNullValue(); + return (T)GetValueFromSqlBufferInternal(data, metaData); } - else + catch (InvalidCastException) { - // Legitimate InvalidCast, rethrow - throw; + if (data.IsNull) + { + // If the value was actually null, then we should throw a SqlNullValue instead + throw SQL.SqlNullValue(); + } + else + { + // Legitimate InvalidCast, rethrow + throw; + } } } } diff --git a/src/System.Data.SqlClient/src/System/Data/SqlTypes/SqlTypeWorkarounds.cs b/src/System.Data.SqlClient/src/System/Data/SqlTypes/SqlTypeWorkarounds.cs index 175744112b25..0e18a0d29dfd 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlTypes/SqlTypeWorkarounds.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlTypes/SqlTypeWorkarounds.cs @@ -37,6 +37,17 @@ internal static XmlReader SqlXmlCreateSqlXmlReader(Stream stream, bool closeInpu return XmlReader.Create(stream, settingsToUse); } + + internal static XmlReader SqlXmlCreateSqlXmlReader(TextReader textReader, bool closeInput = false, bool async = false) + { + Debug.Assert(closeInput || !async, "Currently we do not have pre-created settings for !closeInput+async"); + + XmlReaderSettings settingsToUse = closeInput ? + (async ? s_defaultXmlReaderSettingsAsyncCloseInput : s_defaultXmlReaderSettingsCloseInput) : + s_defaultXmlReaderSettings; + + return XmlReader.Create(textReader, settingsToUse); + } #endregion #region Work around inability to access SqlDateTime.ToDateTime diff --git a/src/System.Data.SqlClient/tests/ManualTests/SQL/DataStreamTest/DataStreamTest.cs b/src/System.Data.SqlClient/tests/ManualTests/SQL/DataStreamTest/DataStreamTest.cs index add701ba4e6c..39052d2bf9f3 100644 --- a/src/System.Data.SqlClient/tests/ManualTests/SQL/DataStreamTest/DataStreamTest.cs +++ b/src/System.Data.SqlClient/tests/ManualTests/SQL/DataStreamTest/DataStreamTest.cs @@ -294,6 +294,10 @@ private static void GetValueOfTRead(string connectionString) rdr.GetFieldValue(15); rdr.GetFieldValue(14); rdr.GetFieldValue(15); + rdr.GetFieldValue(14); + rdr.GetFieldValue(15); + rdr.GetFieldValueAsync(14); + rdr.GetFieldValueAsync(15); rdr.Read(); Assert.True(rdr.IsDBNullAsync(11).Result, "FAILED: IsDBNull was false for a null value");