diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj
index ff1db9b5e6..5a9d974af6 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj
@@ -709,6 +709,7 @@
+
True
True
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/GenericConverter.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/GenericConverter.cs
new file mode 100644
index 0000000000..d4298afa2d
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/GenericConverter.cs
@@ -0,0 +1,17 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.Data.SqlClient
+{
+ ///
+ /// Serves to convert generic to out type by casting to object first. Relies on JIT to optimize out unneccessary casts and prevent double boxing.
+ ///
+ internal static class GenericConverter
+ {
+ public static TOut Convert(TIn value)
+ {
+ return (TOut)(object)value;
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs
index 8b8f689d0c..4e87235780 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs
@@ -136,8 +136,8 @@ private enum ValueSourceType
DbDataReader
}
- // Enum for specifying SqlDataReader.Get method used
- private enum ValueMethod : byte
+ // Enum for specifying SqlDataReader.Get / IDataReader Get method used
+ private enum ValueMethod
{
GetValue,
SqlTypeSqlDecimal,
@@ -145,7 +145,19 @@ private enum ValueMethod : byte
SqlTypeSqlSingle,
DataFeedStream,
DataFeedText,
- DataFeedXml
+ DataFeedXml,
+ GetInt32,
+ GetString,
+ GetDouble,
+ GetDecimal,
+ GetInt16,
+ GetInt64,
+ GetChar,
+ GetByte,
+ GetBoolean,
+ GetDateTime,
+ GetGuid,
+ GetFloat
}
// Used to hold column metadata for SqlDataReader case
@@ -943,12 +955,19 @@ private void Dispose(bool disposing)
}
}
- // Unified method to read a value from the current row
- private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out bool isDataFeed, out bool isNull)
+ // Reads a cell and then writes it.
+ // Read may block at this moment since there is no getValueAsync or DownStream async at this moment.
+ // When _isAsyncBulkCopy == true: Write will return Task (when async method runs asynchronously) or Null (when async call actually ran synchronously) for performance.
+ // When _isAsyncBulkCopy == false: Writes are purely sync. This method return null at the end.
+ private Task ReadWriteColumnValueAsync(int destRowIndex)
{
_SqlMetaData metadata = _sortedColumnMappings[destRowIndex]._metadata;
int sourceOrdinal = _sortedColumnMappings[destRowIndex]._sourceColumnOrdinal;
+ bool isSqlType = false;
+ bool isDataFeed = false;
+ bool isNull = false;
+
switch (_rowSourceType)
{
case ValueSourceType.IDataReader:
@@ -958,34 +977,39 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b
{
if (_DbDataReaderRowSource.IsDBNull(sourceOrdinal))
{
- isSqlType = false;
- isDataFeed = false;
isNull = true;
- return DBNull.Value;
+ return WriteValueAsync(DBNull.Value, destRowIndex, isSqlType, isDataFeed, isNull);
}
else
{
- isSqlType = false;
isDataFeed = true;
- isNull = false;
+
+ object feedColumnValue;
+
switch (_currentRowMetadata[destRowIndex].Method)
{
case ValueMethod.DataFeedStream:
- return new StreamDataFeed(_DbDataReaderRowSource.GetStream(sourceOrdinal));
+ feedColumnValue = new StreamDataFeed(_DbDataReaderRowSource.GetStream(sourceOrdinal));
+ break;
case ValueMethod.DataFeedText:
- return new TextDataFeed(_DbDataReaderRowSource.GetTextReader(sourceOrdinal));
+ feedColumnValue = new TextDataFeed(_DbDataReaderRowSource.GetTextReader(sourceOrdinal));
+ break;
case ValueMethod.DataFeedXml:
// Only SqlDataReader supports an XmlReader
// There is no GetXmlReader on DbDataReader, however if GetValue returns XmlReader we will read it as stream if it is assigned to XML field
Debug.Assert(_SqlDataReaderRowSource != null, "Should not be reading row as an XmlReader if bulk copy source is not a SqlDataReader");
- return new XmlDataFeed(_SqlDataReaderRowSource.GetXmlReader(sourceOrdinal));
+ feedColumnValue = new XmlDataFeed(_SqlDataReaderRowSource.GetXmlReader(sourceOrdinal));
+ break;
default:
Debug.Fail($"Current column is marked as being a DataFeed, but no DataFeed compatible method was provided. Method: {_currentRowMetadata[destRowIndex].Method}");
isDataFeed = false;
- object columnValue = _DbDataReaderRowSource.GetValue(sourceOrdinal);
- ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType);
- return columnValue;
+ feedColumnValue = _DbDataReaderRowSource.GetValue(sourceOrdinal);
+ ADP.IsNullOrSqlType(feedColumnValue, out isNull, out isSqlType);
+ break;
}
+
+ //specifically choosing to use the object overload here to simplify TdsParser logic for the XmlReader scenario
+ return WriteValueAsync(feedColumnValue, destRowIndex, isSqlType, isDataFeed, isNull);
}
}
// SqlDataReader-specific logic
@@ -993,36 +1017,28 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b
{
if (_currentRowMetadata[destRowIndex].IsSqlType)
{
- INullable value;
isSqlType = true;
- isDataFeed = false;
switch (_currentRowMetadata[destRowIndex].Method)
{
case ValueMethod.SqlTypeSqlDecimal:
- value = _SqlDataReaderRowSource.GetSqlDecimal(sourceOrdinal);
- break;
+ var value = _SqlDataReaderRowSource.GetSqlDecimal(sourceOrdinal);
+ return WriteValueAsync(value, destRowIndex, isSqlType, isDataFeed, value.IsNull);
case ValueMethod.SqlTypeSqlDouble:
// use cast to handle IsNull correctly because no public constructor allows it
- value = (SqlDecimal)_SqlDataReaderRowSource.GetSqlDouble(sourceOrdinal);
- break;
+ var dblValue = (SqlDecimal)_SqlDataReaderRowSource.GetSqlDouble(sourceOrdinal);
+ return WriteValueAsync(dblValue, destRowIndex, isSqlType, isDataFeed, dblValue.IsNull);
case ValueMethod.SqlTypeSqlSingle:
- // use cast to handle IsNull correctly because no public constructor allows it
- value = (SqlDecimal)_SqlDataReaderRowSource.GetSqlSingle(sourceOrdinal);
- break;
+ // use cast to handle value.IsNull correctly because no public constructor allows it
+ var singleValue = (SqlDecimal)_SqlDataReaderRowSource.GetSqlSingle(sourceOrdinal);
+ return WriteValueAsync(singleValue, destRowIndex, isSqlType, isDataFeed, singleValue.IsNull);
default:
Debug.Fail($"Current column is marked as being a SqlType, but no SqlType compatible method was provided. Method: {_currentRowMetadata[destRowIndex].Method}");
- value = (INullable)_SqlDataReaderRowSource.GetSqlValue(sourceOrdinal);
- break;
+ var sqlValue = (INullable)_SqlDataReaderRowSource.GetSqlValue(sourceOrdinal);
+ return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, sqlValue.IsNull);
}
-
- isNull = value.IsNull;
- return value;
}
else
{
- isSqlType = false;
- isDataFeed = false;
-
object value = _SqlDataReaderRowSource.GetValue(sourceOrdinal);
isNull = ((value == null) || (value == DBNull.Value));
if ((!isNull) && (metadata.type == SqlDbType.Udt))
@@ -1036,27 +1052,63 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b
Debug.Assert(!(value is INullable) || !((INullable)value).IsNull, "IsDBNull returned false, but GetValue returned a null INullable");
}
#endif
- return value;
+ return WriteValueAsync(value, destRowIndex, isSqlType, isDataFeed, isNull);
}
}
else
{
- isDataFeed = false;
-
IDataReader rowSourceAsIDataReader = (IDataReader)_rowSource;
// Only use IsDbNull when streaming is enabled and only for non-SqlDataReader
if ((_enableStreaming) && (_SqlDataReaderRowSource == null) && (rowSourceAsIDataReader.IsDBNull(sourceOrdinal)))
{
- isSqlType = false;
isNull = true;
- return DBNull.Value;
+ return WriteValueAsync(DBNull.Value, destRowIndex, isSqlType, isDataFeed, isNull);
}
else
{
- object columnValue = rowSourceAsIDataReader.GetValue(sourceOrdinal);
- ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType);
- return columnValue;
+ if (_currentRowMetadata[destRowIndex].Method == ValueMethod.GetValue || rowSourceAsIDataReader.IsDBNull(sourceOrdinal))
+ {
+ object columnValue = rowSourceAsIDataReader.GetValue(sourceOrdinal);
+ ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType);
+ return WriteValueAsync(columnValue, destRowIndex, isSqlType, isDataFeed, isNull);
+ }
+ else
+ {
+ switch (_currentRowMetadata[sourceOrdinal].Method)
+ {
+ case ValueMethod.GetInt32:
+ return WriteValueAsync(rowSourceAsIDataReader.GetInt32(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, false);
+ case ValueMethod.GetString:
+ var strValue = rowSourceAsIDataReader.GetString(sourceOrdinal);
+ isNull = strValue == null;
+ return WriteValueAsync(strValue, destRowIndex, isSqlType, isDataFeed, isNull);
+ case ValueMethod.GetDouble:
+ return WriteValueAsync(rowSourceAsIDataReader.GetDouble(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull);
+ case ValueMethod.GetDecimal:
+ return WriteValueAsync(rowSourceAsIDataReader.GetDecimal(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull);
+ case ValueMethod.GetInt16:
+ return WriteValueAsync(rowSourceAsIDataReader.GetInt16(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull);
+ case ValueMethod.GetInt64:
+ return WriteValueAsync(rowSourceAsIDataReader.GetInt64(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull);
+ case ValueMethod.GetChar:
+ return WriteValueAsync(rowSourceAsIDataReader.GetChar(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull);
+ case ValueMethod.GetByte:
+ return WriteValueAsync(rowSourceAsIDataReader.GetByte(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull);
+ case ValueMethod.GetBoolean:
+ return WriteValueAsync(rowSourceAsIDataReader.GetBoolean(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull);
+ case ValueMethod.GetDateTime:
+ return WriteValueAsync(rowSourceAsIDataReader.GetDateTime(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull);
+ case ValueMethod.GetGuid:
+ return WriteValueAsync(rowSourceAsIDataReader.GetGuid(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull);
+ case ValueMethod.GetFloat:
+ return WriteValueAsync(rowSourceAsIDataReader.GetFloat(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull);
+ default:
+ object columnValue = rowSourceAsIDataReader.GetValue(sourceOrdinal);
+ ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType);
+ return WriteValueAsync(columnValue, destRowIndex, isSqlType, isDataFeed, isNull);
+ }
+ }
}
}
case ValueSourceType.DataTable:
@@ -1066,6 +1118,7 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b
Debug.Assert(sourceOrdinal < _currentRowLength, "inconsistency of length of rows from rowsource!");
isDataFeed = false;
+ // unfortunately this has to be boxed due to DataRow's API.
object currentRowValue = _currentRow[sourceOrdinal];
ADP.IsNullOrSqlType(currentRowValue, out isNull, out isSqlType);
@@ -1078,7 +1131,8 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b
{
if (isSqlType)
{
- return new SqlDecimal(((SqlSingle)currentRowValue).Value);
+ var sqlDec = new SqlDecimal(((SqlSingle)currentRowValue).Value);
+ return WriteValueAsync(sqlDec, destRowIndex, isSqlType, isDataFeed, isNull);
}
else
{
@@ -1086,16 +1140,20 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b
if (!float.IsNaN(f))
{
isSqlType = true;
- return new SqlDecimal(f);
+ return WriteValueAsync(new SqlDecimal(f), destRowIndex, isSqlType, isDataFeed, isNull);
+ }
+ else
+ {
+ return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull);
}
- break;
}
}
case ValueMethod.SqlTypeSqlDouble:
{
if (isSqlType)
{
- return new SqlDecimal(((SqlDouble)currentRowValue).Value);
+ var sqlValue = new SqlDecimal(((SqlDouble)currentRowValue).Value);
+ return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, isNull);
}
else
{
@@ -1103,33 +1161,40 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b
if (!double.IsNaN(d))
{
isSqlType = true;
- return new SqlDecimal(d);
+ return WriteValueAsync(new SqlDecimal(d), destRowIndex, isSqlType, isDataFeed, isNull);
+ }
+ else
+ {
+ return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull);
}
- break;
}
}
case ValueMethod.SqlTypeSqlDecimal:
{
if (isSqlType)
{
- return (SqlDecimal)currentRowValue;
+ var sqlValue = (SqlDecimal)currentRowValue;
+ return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, isNull);
}
else
{
isSqlType = true;
- return new SqlDecimal((decimal)currentRowValue);
+ var sqlValue = new SqlDecimal((decimal)currentRowValue);
+ return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, isNull);
}
}
default:
{
Debug.Fail($"Current column is marked as being a SqlType, but no SqlType compatible method was provided. Method: {_currentRowMetadata[destRowIndex].Method}");
- break;
+ // If we are here then either the value is null, there was no special storage type for this column or the special storage type wasn't handled (e.g. if the currentRowValue is NaN)
+ return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull);
}
}
}
-
- // If we are here then either the value is null, there was no special storage type for this column or the special storage type wasn't handled (e.g. if the currentRowValue is NaN)
- return currentRowValue;
+ else
+ {
+ return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull);
+ }
}
default:
{
@@ -1320,6 +1385,66 @@ private SourceColumnMetadata GetColumnMetadata(int ordinal)
method = ValueMethod.GetValue;
}
}
+ else if (_rowSourceType == ValueSourceType.IDataReader)
+ {
+ isSqlType = false;
+ isDataFeed = false;
+
+ Type t = ((IDataReader)_rowSource).GetFieldType(ordinal);
+
+ if (t == typeof(bool))
+ {
+ method = ValueMethod.GetBoolean;
+ }
+ else if (t == typeof(byte))
+ {
+ method = ValueMethod.GetByte;
+ }
+ else if (t == typeof(char))
+ {
+ method = ValueMethod.GetChar;
+ }
+ else if (t == typeof(DateTime))
+ {
+ method = ValueMethod.GetDateTime;
+ }
+ else if (t == typeof(decimal))
+ {
+ method = ValueMethod.GetDecimal;
+ }
+ else if (t == typeof(double))
+ {
+ method = ValueMethod.GetDouble;
+ }
+ else if (t == typeof(float))
+ {
+ method = ValueMethod.GetFloat;
+ }
+ else if (t == typeof(Guid))
+ {
+ method = ValueMethod.GetGuid;
+ }
+ else if (t == typeof(short))
+ {
+ method = ValueMethod.GetInt16;
+ }
+ else if (t == typeof(int))
+ {
+ method = ValueMethod.GetInt32;
+ }
+ else if (t == typeof(long))
+ {
+ method = ValueMethod.GetInt64;
+ }
+ else if (t == typeof(string))
+ {
+ method = ValueMethod.GetString;
+ }
+ else
+ {
+ method = ValueMethod.GetValue;
+ }
+ }
else
{
isSqlType = false;
@@ -1454,8 +1579,10 @@ private string UnquotedName(string name)
return name;
}
- private object ValidateBulkCopyVariant(object value)
+ private bool ValidateBulkCopyVariantIfNeeded(T value, out object variantValue)
{
+ variantValue = null;
+
// From the spec:
// "The only acceptable types are ..."
// GUID, BIGVARBINARY, BIGBINARY, BIGVARCHAR, BIGCHAR, NVARCHAR, NCHAR, BIT, INT1, INT2, INT4, INT8,
@@ -1483,20 +1610,21 @@ private object ValidateBulkCopyVariant(object value)
case TdsEnums.SQLDATETIMEOFFSET:
if (value is INullable)
{ // Current limitation in the SqlBulkCopy Variant code limits BulkCopy to CLR/COM Types.
- return MetaType.GetComValueFromSqlVariant(value);
+ variantValue = MetaType.GetComValueFromSqlVariant(value);
+ return true;
}
else
{
- return value;
+ return false;
}
default:
throw SQL.BulkLoadInvalidVariantValue();
}
}
- private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, ref bool isSqlType, out bool coercedToDataFeed)
+ private Task ConvertWriteValueAsync(T value, int col, _SqlMetaData metadata, bool isNull, bool isSqlType)
{
- coercedToDataFeed = false;
+ bool coercedToDataFeed = false;
if (isNull)
{
@@ -1504,11 +1632,13 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re
{
throw SQL.BulkLoadBulkLoadNotAllowDBNull(metadata.column);
}
- return value;
+
+ return DoWriteValueAsync(value, col, isSqlType, coercedToDataFeed, isNull, metadata);
}
MetaType type = metadata.metaType;
bool typeChanged = false;
+ object objValue = null;
// If the column is encrypted then we are going to transparently encrypt this column
// (based on connection string setting)- Use the metaType for the underlying
@@ -1533,46 +1663,50 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re
{
case TdsEnums.SQLNUMERICN:
case TdsEnums.SQLDECIMALN:
- mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false);
- value = SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false);
-
- // Convert Source Decimal Precision and Scale to Destination Precision and Scale
- // Sql decimal data could get corrupted on insert if the scale of
- // the source and destination weren't the same. The BCP protocol, specifies the
- // scale of the incoming data in the insert statement, we just tell the server we
- // are inserting the same scale back.
- SqlDecimal sqlValue;
- if ((isSqlType) && (!typeChanged))
+ SqlDecimal decValue;
+ if (typeof(T) == typeof(decimal))
+ {
+ decValue = new SqlDecimal(GenericConverter.Convert(value));
+ }
+ else if (typeof(T) == typeof(SqlDecimal))
{
- sqlValue = (SqlDecimal)value;
+ decValue = GenericConverter.Convert(value);
}
else
{
- sqlValue = new SqlDecimal((decimal)value);
+ mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false);
+ decValue = new SqlDecimal((decimal)SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false));
}
- if (sqlValue.Scale != scale)
+ // Convert Source Decimal Precision and Scale to Destination Precision and Scale
+ // Sql decimal data could get corrupted on insert if the scale of
+ // the source and destination weren't the same. The BCP protocol, specifies the
+ // scale of the incoming data in the insert statement, we just tell the server we
+ // are inserting the same scale back.
+ if (decValue.Scale != scale)
{
- sqlValue = TdsParser.AdjustSqlDecimalScale(sqlValue, scale);
+ decValue = TdsParser.AdjustSqlDecimalScale(decValue, scale);
}
- if (sqlValue.Precision > precision)
+ if (decValue.Precision > precision)
{
try
{
- sqlValue = SqlDecimal.ConvertToPrecScale(sqlValue, precision, sqlValue.Scale);
+ decValue = SqlDecimal.ConvertToPrecScale(decValue, precision, decValue.Scale);
}
catch (SqlTruncateException)
{
- throw SQL.BulkLoadCannotConvertValue(value.GetType(), mt, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), ADP.ParameterValueOutOfRange(sqlValue));
+ mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false);
+ throw SQL.BulkLoadCannotConvertValue(value.GetType(), mt, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), ADP.ParameterValueOutOfRange(decValue));
}
}
// Perf: It is more efficient to write a SqlDecimal than a decimal since we need to break it into its 'bits' when writing
- value = sqlValue;
isSqlType = true;
typeChanged = false; // Setting this to false as SqlParameter.CoerceValue will only set it to true when converting to a CLR type
- break;
+
+ // returning here to avoid unnecessary decValue initialization for all types
+ return WriteConvertedValue(decValue, col, isSqlType, isNull, coercedToDataFeed, metadata);
case TdsEnums.SQLINTN:
case TdsEnums.SQLFLTN:
@@ -1596,16 +1730,22 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re
case TdsEnums.SQLDATETIME2:
case TdsEnums.SQLDATETIMEOFFSET:
mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false);
- value = SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false);
+ typeChanged = SqlParameter.CoerceValueIfNeeded(value, mt, out objValue, out coercedToDataFeed);
break;
case TdsEnums.SQLNCHAR:
case TdsEnums.SQLNVARCHAR:
case TdsEnums.SQLNTEXT:
mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false);
- value = SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false);
+ typeChanged = SqlParameter.CoerceValueIfNeeded(value, mt, out objValue, out coercedToDataFeed, false);
if (!coercedToDataFeed)
{ // We do not need to test for TextDataFeed as it is only assigned to (N)VARCHAR(MAX)
- string str = ((isSqlType) && (!typeChanged)) ? ((SqlString)value).Value : ((string)value);
+ string str = typeChanged
+ ? (string)objValue
+ : isSqlType
+ ? GenericConverter.Convert(value).Value
+ : GenericConverter.Convert(value)
+ ;
+
int maxStringLength = length / 2;
if (str.Length > maxStringLength)
{
@@ -1622,10 +1762,10 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re
throw SQL.BulkLoadStringTooLong(_destinationTableName, metadata.column, str);
}
}
+
break;
case TdsEnums.SQLVARIANT:
- value = ValidateBulkCopyVariant(value);
- typeChanged = true;
+ typeChanged = ValidateBulkCopyVariantIfNeeded(value, out objValue);
break;
case TdsEnums.SQLUDT:
// UDTs are sent as varbinary so we need to get the raw bytes
@@ -1636,16 +1776,16 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re
// in byte[] form.
if (!(value is byte[]))
{
- value = _connection.GetBytes(value);
+ objValue = _connection.GetBytes(value);
typeChanged = true;
}
break;
case TdsEnums.SQLXMLTYPE:
// Could be either string, SqlCachedBuffer, XmlReader or XmlDataFeed
Debug.Assert((value is XmlReader) || (value is SqlCachedBuffer) || (value is string) || (value is SqlString) || (value is XmlDataFeed), "Invalid value type of Xml datatype");
- if (value is XmlReader)
+ if (value is XmlReader xmlReader)
{
- value = new XmlDataFeed((XmlReader)value);
+ objValue = new XmlDataFeed(xmlReader);
typeChanged = true;
coercedToDataFeed = true;
}
@@ -1655,14 +1795,6 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re
Debug.Fail("Unknown TdsType!" + type.NullableType.ToString("x2", (IFormatProvider)null));
throw SQL.BulkLoadCannotConvertValue(value.GetType(), type, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), null);
}
-
- if (typeChanged)
- {
- // All type changes change to CLR types
- isSqlType = false;
- }
-
- return value;
}
catch (Exception e)
{
@@ -1672,6 +1804,17 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re
}
throw SQL.BulkLoadCannotConvertValue(value.GetType(), type, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), e);
}
+
+ if (typeChanged)
+ {
+ // All type changes change to CLR types
+ isSqlType = false;
+ return WriteConvertedValue(objValue, col, isSqlType, isNull, coercedToDataFeed, metadata);
+ }
+ else
+ {
+ return WriteConvertedValue(value, col, isSqlType, isNull, coercedToDataFeed, metadata);
+ }
}
///
@@ -2205,33 +2348,40 @@ private bool FireRowsCopiedEvent(long rowsCopied)
return eventArgs.Abort;
}
- // Reads a cell and then writes it.
- // Read may block at this moment since there is no getValueAsync or DownStream async at this moment.
- // When _isAsyncBulkCopy == true: Write will return Task (when async method runs asynchronously) or Null (when async call actually ran synchronously) for performance.
- // When _isAsyncBulkCopy == false: Writes are purely sync. This method return null at the end.
- private Task ReadWriteColumnValueAsync(int col)
+ private Task WriteValueAsync(T value, int col, bool isSqlType, bool isDataFeed, bool isNull)
{
- bool isSqlType;
- bool isDataFeed;
- bool isNull;
- object value = GetValueFromSourceRow(col, out isSqlType, out isDataFeed, out isNull); //this will return Task/null in future: as rTask
-
_SqlMetaData metadata = _sortedColumnMappings[col]._metadata;
- if (!isDataFeed)
+ if (isDataFeed)
+ {
+ //nothing to convert, skip straight to write
+ return DoWriteValueAsync(value, col, isSqlType, isDataFeed, isNull, metadata);
+ }
+ else
{
- value = ConvertValue(value, metadata, isNull, ref isSqlType, out isDataFeed);
+ return ConvertWriteValueAsync(value, col, metadata, isNull, isSqlType);
+ }
+ }
- // If column encryption is requested via connection string option, perform encryption here
- if (!isNull && // if value is not NULL
- metadata.isEncrypted)
- { // If we are transparently encrypting
- Debug.Assert(_parser.ShouldEncryptValuesForBulkCopy());
- value = _parser.EncryptColumnValue(value, metadata, metadata.column, _stateObj, isDataFeed, isSqlType);
- isSqlType = false; // Its not a sql type anymore
- }
+ private Task WriteConvertedValue(T value, int col, bool isSqlType, bool isNull, bool isDatafeed, _SqlMetaData metadata)
+ {
+ // If column encryption is requested via connection string option, perform encryption here
+ if (!isNull && // if value is not NULL
+ metadata.isEncrypted)
+ { // If we are transparently encrypting
+ Debug.Assert(_parser.ShouldEncryptValuesForBulkCopy());
+ var bytesValue = _parser.EncryptColumnValue(value, metadata, metadata.column, _stateObj, isDatafeed, isSqlType);
+ isSqlType = false; // Its not a sql type anymore
+
+ return DoWriteValueAsync(bytesValue, col, isSqlType, isDatafeed, isNull, metadata);
}
+ else
+ {
+ return DoWriteValueAsync(value, col, isSqlType, isDatafeed, isNull, metadata);
+ }
+ }
- //write part
+ private Task DoWriteValueAsync(T value, int col, bool isSqlType, bool isDataFeed, bool isNull, _SqlMetaData metadata)
+ {
Task writeTask = null;
if (metadata.type != SqlDbType.Variant)
{
@@ -2250,15 +2400,15 @@ private Task ReadWriteColumnValueAsync(int col)
if (variantInternalType == SqlBuffer.StorageType.DateTime2)
{
- _parser.WriteSqlVariantDateTime2(((DateTime)value), _stateObj);
+ _parser.WriteSqlVariantDateTime2(GenericConverter.Convert(value), _stateObj);
}
else if (variantInternalType == SqlBuffer.StorageType.Date)
{
- _parser.WriteSqlVariantDate(((DateTime)value), _stateObj);
+ _parser.WriteSqlVariantDate(GenericConverter.Convert(value), _stateObj);
}
else
{
- writeTask = _parser.WriteSqlVariantDataRowValue(value, _stateObj); //returns Task/Null
+ writeTask = _parser.WriteSqlVariantDataRowValue(value, isNull, _stateObj); //returns Task/Null
}
}
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlParameter.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlParameter.cs
index 08029b8a7a..2a2fb2a1a9 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlParameter.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlParameter.cs
@@ -107,7 +107,7 @@ public sealed partial class SqlParameter : DbParameter, IDbDataParameter, IClone
///
/// Indicates if the parameter encryption metadata received by sp_describe_parameter_encryption.
- /// For unencrypted parameters, the encryption metadata should still be sent (and will indicate
+ /// For unencrypted parameters, the encryption metadata should still be sent (and will indicate
/// that no encryption is needed).
///
internal bool HasReceivedMetadata { get; set; }
@@ -407,7 +407,7 @@ internal SmiParameterMetaData MetaDataForSmi(out ParameterPeekAheadValue peekAhe
long actualLen = GetActualSize();
long maxLen = this.Size;
- // GetActualSize returns bytes length, but smi expects char length for
+ // GetActualSize returns bytes length, but smi expects char length for
// character types, so adjust
if (!mt.IsLong)
{
@@ -995,15 +995,27 @@ object ICloneable.Clone()
}
// Coerced Value is also used in SqlBulkCopy.ConvertValue(object value, _SqlMetaData metadata)
+
internal static object CoerceValue(object value, MetaType destinationType, out bool coercedToDataFeed, out bool typeChanged, bool allowStreaming = true)
+ {
+ typeChanged = CoerceValueIfNeeded(value, destinationType, out var objValue, out coercedToDataFeed, allowStreaming);
+
+ return typeChanged ? objValue : value;
+ }
+
+ internal static bool CoerceValueIfNeeded(T value, MetaType destinationType, out object objValue, out bool coercedToDataFeed, bool allowStreaming = true)
{
Debug.Assert(!(value is DataFeed), "Value provided should not already be a data feed");
Debug.Assert(!ADP.IsNull(value), "Value provided should not be null");
Debug.Assert(null != destinationType, "null destinationType");
coercedToDataFeed = false;
- typeChanged = false;
- Type currentType = value.GetType();
+ objValue = null;
+ Type currentType = typeof(T) == typeof(object)
+ ? value.GetType() // only call GetType if we know boxing has already occurred.
+ : typeof(T);
+
+ var typeChanged = false;
if ((typeof(object) != destinationType.ClassType) &&
(currentType != destinationType.ClassType) &&
@@ -1018,45 +1030,45 @@ internal static object CoerceValue(object value, MetaType destinationType, out b
// For Xml data, destination Type is always string
if (typeof(SqlXml) == currentType)
{
- value = MetaType.GetStringFromXml((XmlReader)(((SqlXml)value).CreateReader()));
+ objValue = MetaType.GetStringFromXml(GenericConverter.Convert(value).CreateReader());
}
else if (typeof(SqlString) == currentType)
{
typeChanged = false; // Do nothing
}
- else if (typeof(XmlReader).IsAssignableFrom(currentType))
+ else if (value is XmlReader xmlReader)
{
if (allowStreaming)
{
coercedToDataFeed = true;
- value = new XmlDataFeed((XmlReader)value);
+ objValue = new XmlDataFeed(xmlReader);
}
else
{
- value = MetaType.GetStringFromXml((XmlReader)value);
+ objValue = MetaType.GetStringFromXml(xmlReader);
}
}
else if (typeof(char[]) == currentType)
{
- value = new string((char[])value);
+ objValue = new string(GenericConverter.Convert(value));
}
else if (typeof(SqlChars) == currentType)
{
- value = new string(((SqlChars)value).Value);
+ objValue = new string(GenericConverter.Convert(value).Value);
}
- else if (value is TextReader && allowStreaming)
+ else if (value is TextReader tr && allowStreaming)
{
coercedToDataFeed = true;
- value = new TextDataFeed((TextReader)value);
+ objValue = new TextDataFeed(tr);
}
else
{
- value = Convert.ChangeType(value, destinationType.ClassType, (IFormatProvider)null);
+ objValue = Convert.ChangeType(value, destinationType.ClassType, (IFormatProvider)null);
}
}
else if ((DbType.Currency == destinationType.DbType) && (typeof(string) == currentType))
{
- value = decimal.Parse((string)value, NumberStyles.Currency, (IFormatProvider)null);
+ objValue = decimal.Parse(GenericConverter.Convert(value), NumberStyles.Currency, (IFormatProvider)null);
}
else if ((typeof(SqlBytes) == currentType) && (typeof(byte[]) == destinationType.ClassType))
{
@@ -1064,15 +1076,15 @@ internal static object CoerceValue(object value, MetaType destinationType, out b
}
else if ((typeof(string) == currentType) && (SqlDbType.Time == destinationType.SqlDbType))
{
- value = TimeSpan.Parse((string)value);
+ objValue = TimeSpan.Parse(GenericConverter.Convert(value));
}
else if ((typeof(string) == currentType) && (SqlDbType.DateTimeOffset == destinationType.SqlDbType))
{
- value = DateTimeOffset.Parse((string)value, (IFormatProvider)null);
+ objValue = DateTimeOffset.Parse(GenericConverter.Convert(value), (IFormatProvider)null);
}
else if ((typeof(DateTime) == currentType) && (SqlDbType.DateTimeOffset == destinationType.SqlDbType))
{
- value = new DateTimeOffset((DateTime)value);
+ objValue = new DateTimeOffset(GenericConverter.Convert(value));
}
else if (TdsEnums.SQLTABLE == destinationType.TDSType && (
value is DataTable ||
@@ -1082,14 +1094,14 @@ value is DbDataReader ||
// no conversion for TVPs.
typeChanged = false;
}
- else if (destinationType.ClassType == typeof(byte[]) && value is Stream && allowStreaming)
+ else if (destinationType.ClassType == typeof(byte[]) && allowStreaming && value is Stream stream)
{
coercedToDataFeed = true;
- value = new StreamDataFeed((Stream)value);
+ objValue = new StreamDataFeed(stream);
}
else
{
- value = Convert.ChangeType(value, destinationType.ClassType, (IFormatProvider)null);
+ objValue = Convert.ChangeType(value, destinationType.ClassType, (IFormatProvider)null);
}
}
catch (Exception e)
@@ -1104,8 +1116,8 @@ value is DbDataReader ||
}
Debug.Assert(allowStreaming || !coercedToDataFeed, "Streaming is not allowed, but type was coerced into a data feed");
- Debug.Assert(value.GetType() == currentType ^ typeChanged, "Incorrect value for typeChanged");
- return value;
+ Debug.Assert(objValue == null || objValue.GetType() == currentType ^ typeChanged, "Incorrect value for typeChanged");
+ return typeChanged;
}
internal void FixStreamDataForNonPLP()
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs
index b0b25d2e09..51194adb8c 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs
@@ -2907,7 +2907,7 @@ private bool TryProcessDone(SqlCommand cmd, SqlDataReader reader, ref RunBehavio
int count;
// This is added back since removing it from here introduces regressions in Managed SNI.
- // It forces SqlDataReader.ReadAsync() method to run synchronously,
+ // It forces SqlDataReader.ReadAsync() method to run synchronously,
// and will block the calling thread until data is fed from SQL Server.
// TODO Investigate better solution to support non-blocking ReadAsync().
stateObj._syncOverAsync = true;
@@ -6630,6 +6630,7 @@ internal Task WriteSqlVariantValue(object value, int length, int offset, TdsPars
stateObj.WriteByte(mt.Precision); //propbytes: precision
stateObj.WriteByte((byte)((decimal.GetBits((decimal)value)[3] & 0x00ff0000) >> 0x10)); // propbytes: scale
WriteDecimal((decimal)value, stateObj);
+ WriteDecimal((decimal)value, stateObj);
break;
}
@@ -6664,10 +6665,10 @@ internal Task WriteSqlVariantValue(object value, int length, int offset, TdsPars
// Therefore the sql_variant value must not include the MaxLength. This is the major difference
// between this method and WriteSqlVariantValue above.
//
- internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject stateObj, bool canAccumulate = true)
+ internal Task WriteSqlVariantDataRowValue(T value, bool isNull, TdsParserStateObject stateObj, bool canAccumulate = true)
{
// handle null values
- if ((null == value) || (DBNull.Value == value))
+ if (isNull)
{
WriteInt(TdsEnums.FIXEDNULL, stateObj);
return null;
@@ -6678,44 +6679,44 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta
if (metatype.IsAnsiType)
{
- length = GetEncodingCharLength((string)value, length, 0, _defaultEncoding);
+ length = GetEncodingCharLength(GenericConverter.Convert(value), length, 0, _defaultEncoding);
}
switch (metatype.TDSType)
{
case TdsEnums.SQLFLT4:
WriteSqlVariantHeader(6, metatype.TDSType, metatype.PropBytes, stateObj);
- WriteFloat((float)value, stateObj);
+ WriteFloat(GenericConverter.Convert(value), stateObj);
break;
case TdsEnums.SQLFLT8:
WriteSqlVariantHeader(10, metatype.TDSType, metatype.PropBytes, stateObj);
- WriteDouble((double)value, stateObj);
+ WriteDouble(GenericConverter.Convert(value), stateObj);
break;
case TdsEnums.SQLINT8:
WriteSqlVariantHeader(10, metatype.TDSType, metatype.PropBytes, stateObj);
- WriteLong((long)value, stateObj);
+ WriteLong(GenericConverter.Convert(value), stateObj);
break;
case TdsEnums.SQLINT4:
WriteSqlVariantHeader(6, metatype.TDSType, metatype.PropBytes, stateObj);
- WriteInt((int)value, stateObj);
+ WriteInt(GenericConverter.Convert(value), stateObj);
break;
case TdsEnums.SQLINT2:
WriteSqlVariantHeader(4, metatype.TDSType, metatype.PropBytes, stateObj);
- WriteShort((short)value, stateObj);
+ WriteShort(GenericConverter.Convert(value), stateObj);
break;
case TdsEnums.SQLINT1:
WriteSqlVariantHeader(3, metatype.TDSType, metatype.PropBytes, stateObj);
- stateObj.WriteByte((byte)value);
+ stateObj.WriteByte(GenericConverter.Convert(value));
break;
case TdsEnums.SQLBIT:
WriteSqlVariantHeader(3, metatype.TDSType, metatype.PropBytes, stateObj);
- if ((bool)value == true)
+ if (GenericConverter.Convert(value))
stateObj.WriteByte(1);
else
stateObj.WriteByte(0);
@@ -6724,7 +6725,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta
case TdsEnums.SQLBIGVARBINARY:
{
- byte[] b = (byte[])value;
+ byte[] b = GenericConverter.Convert(value);
length = b.Length;
WriteSqlVariantHeader(4 + length, metatype.TDSType, metatype.PropBytes, stateObj);
@@ -6734,7 +6735,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta
case TdsEnums.SQLBIGVARCHAR:
{
- string s = (string)value;
+ string s = GenericConverter.Convert(value);
length = s.Length;
WriteSqlVariantHeader(9 + length, metatype.TDSType, metatype.PropBytes, stateObj);
@@ -6746,7 +6747,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta
case TdsEnums.SQLUNIQUEID:
{
- System.Guid guid = (System.Guid)value;
+ Guid guid = GenericConverter.Convert(value);
Span b = stackalloc byte[16];
FillGuidBytes(guid, b);
@@ -6759,7 +6760,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta
case TdsEnums.SQLNVARCHAR:
{
- string s = (string)value;
+ string s = GenericConverter.Convert(value);
length = s.Length * 2;
WriteSqlVariantHeader(9 + length, metatype.TDSType, metatype.PropBytes, stateObj);
@@ -6774,7 +6775,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta
case TdsEnums.SQLDATETIME:
{
- TdsDateTime dt = MetaType.FromDateTime((DateTime)value, 8);
+ TdsDateTime dt = MetaType.FromDateTime(GenericConverter.Convert(value), 8);
WriteSqlVariantHeader(10, metatype.TDSType, metatype.PropBytes, stateObj);
WriteInt(dt.days, stateObj);
@@ -6785,7 +6786,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta
case TdsEnums.SQLMONEY:
{
WriteSqlVariantHeader(10, metatype.TDSType, metatype.PropBytes, stateObj);
- WriteCurrency((decimal)value, 8, stateObj);
+ WriteCurrency(GenericConverter.Convert(value), 8, stateObj);
break;
}
@@ -6793,21 +6794,22 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta
{
WriteSqlVariantHeader(21, metatype.TDSType, metatype.PropBytes, stateObj);
stateObj.WriteByte(metatype.Precision); //propbytes: precision
- stateObj.WriteByte((byte)((decimal.GetBits((decimal)value)[3] & 0x00ff0000) >> 0x10)); // propbytes: scale
- WriteDecimal((decimal)value, stateObj);
+ var decValue = GenericConverter.Convert(value);
+ stateObj.WriteByte((byte)((decimal.GetBits(decValue)[3] & 0x00ff0000) >> 0x10)); // propbytes: scale
+ WriteDecimal(decValue, stateObj);
break;
}
case TdsEnums.SQLTIME:
WriteSqlVariantHeader(8, metatype.TDSType, metatype.PropBytes, stateObj);
stateObj.WriteByte(metatype.Scale); //propbytes: scale
- WriteTime((TimeSpan)value, metatype.Scale, 5, stateObj);
+ WriteTime(GenericConverter.Convert(value), metatype.Scale, 5, stateObj);
break;
case TdsEnums.SQLDATETIMEOFFSET:
WriteSqlVariantHeader(13, metatype.TDSType, metatype.PropBytes, stateObj);
stateObj.WriteByte(metatype.Scale); //propbytes: scale
- WriteDateTimeOffset((DateTimeOffset)value, metatype.Scale, 10, stateObj);
+ WriteDateTimeOffset(GenericConverter.Convert(value), metatype.Scale, 10, stateObj);
break;
default:
@@ -10409,7 +10411,7 @@ internal bool ShouldEncryptValuesForBulkCopy()
/// Encrypts a column value (for SqlBulkCopy)
///
///
- internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, string column, TdsParserStateObject stateObj, bool isDataFeed, bool isSqlType)
+ internal byte[] EncryptColumnValue(T value, SqlMetaDataPriv metadata, string column, TdsParserStateObject stateObj, bool isDataFeed, bool isSqlType)
{
Debug.Assert(IsColumnEncryptionSupported, "Server doesn't support encryption, yet we received encryption metadata");
Debug.Assert(ShouldEncryptValuesForBulkCopy(), "Encryption attempted when not requested");
@@ -10434,7 +10436,10 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin
// when we normalize and serialize the data buffers. The serialization routine expects us
// to report the size of data to be copied out (for serialization). If we underreport the
// size, truncation will happen for us!
- actualLengthInBytes = (isSqlType) ? ((SqlBinary)value).Length : ((byte[])value).Length;
+ actualLengthInBytes = (isSqlType)
+ ? GenericConverter.Convert(value).Length
+ : GenericConverter.Convert(value).Length;
+
if (metadata.baseTI.length > 0 &&
actualLengthInBytes > metadata.baseTI.length)
{
@@ -10454,7 +10459,10 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin
ThrowUnsupportedCollationEncountered(null); // stateObject only when reading
}
- string stringValue = (isSqlType) ? ((SqlString)value).Value : (string)value;
+ string stringValue = (isSqlType)
+ ? GenericConverter.Convert(value).Value
+ : GenericConverter.Convert(value);
+
actualLengthInBytes = _defaultEncoding.GetByteCount(stringValue);
// If the string length is > max length, then use the max length (see comments above)
@@ -10468,7 +10476,10 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin
case TdsEnums.SQLNCHAR:
case TdsEnums.SQLNVARCHAR:
case TdsEnums.SQLNTEXT:
- actualLengthInBytes = ((isSqlType) ? ((SqlString)value).Value.Length : ((string)value).Length) * 2;
+ actualLengthInBytes = (isSqlType
+ ? GenericConverter.Convert(value).Value.Length
+ : GenericConverter.Convert(value).Length)
+ * 2;
if (metadata.baseTI.length > 0 &&
actualLengthInBytes > metadata.baseTI.length)
@@ -10513,7 +10524,7 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin
_connHandler.ConnectionOptions.DataSource);
}
- internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsParserStateObject stateObj, bool isSqlType, bool isDataFeed, bool isNull)
+ internal Task WriteBulkCopyValue(T value, SqlMetaDataPriv metadata, TdsParserStateObject stateObj, bool isSqlType, bool isDataFeed, bool isNull)
{
Debug.Assert(!isSqlType || value is INullable, "isSqlType is true, but value can not be type cast to an INullable");
Debug.Assert(!isDataFeed ^ value is DataFeed, "Incorrect value for isDataFeed");
@@ -10578,7 +10589,9 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars
case TdsEnums.SQLBIGVARBINARY:
case TdsEnums.SQLIMAGE:
case TdsEnums.SQLUDT:
- ccb = (isSqlType) ? ((SqlBinary)value).Length : ((byte[])value).Length;
+ ccb = (isSqlType)
+ ? GenericConverter.Convert(value).Length
+ : GenericConverter.Convert(value).Length;
break;
case TdsEnums.SQLUNIQUEID:
ccb = GUID_SIZE;
@@ -10594,11 +10607,11 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars
string stringValue = null;
if (isSqlType)
{
- stringValue = ((SqlString)value).Value;
+ stringValue = GenericConverter.Convert(value).Value;
}
else
{
- stringValue = (string)value;
+ stringValue = GenericConverter.Convert(value);
}
ccb = stringValue.Length;
@@ -10607,15 +10620,22 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars
case TdsEnums.SQLNCHAR:
case TdsEnums.SQLNVARCHAR:
case TdsEnums.SQLNTEXT:
- ccb = ((isSqlType) ? ((SqlString)value).Value.Length : ((string)value).Length) * 2;
+ ccb = (isSqlType
+ ? GenericConverter.Convert(value).Value.Length
+ : GenericConverter.Convert(value).Length
+ ) * 2;
break;
case TdsEnums.SQLXMLTYPE:
// Value here could be string or XmlReader
- if (value is XmlReader)
+ // the XmlReader scenario can only occur when T is object (enforced during SqlBulkCopy.ReadWriteColumnValueAsync)
+ if (typeof(T) == typeof(object) && value is XmlReader xr)
{
- value = MetaType.GetStringFromXml((XmlReader)value);
+ value = GenericConverter.Convert(MetaType.GetStringFromXml(xr));
}
- ccb = ((isSqlType) ? ((SqlString)value).Value.Length : ((string)value).Length) * 2;
+ ccb = (isSqlType
+ ? GenericConverter.Convert(value).Value.Length
+ : GenericConverter.Convert(value).Length
+ ) * 2;
break;
default:
@@ -10667,7 +10687,9 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars
}
else if (metatype.SqlDbType != SqlDbType.Udt || metatype.IsLong)
{
+ // we only have to consider a conversion from above in this case.
internalWriteTask = WriteValue(value, metatype, metadata.scale, ccb, ccbStringBytes, 0, stateObj, metadata.length, isDataFeed);
+
if ((internalWriteTask == null) && (_asyncWrite))
{
internalWriteTask = stateObj.WaitForAccumulatedWrites();
@@ -10677,7 +10699,7 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars
else
{
WriteShort(ccb, stateObj);
- internalWriteTask = stateObj.WriteByteArray((byte[])value, ccb, 0);
+ internalWriteTask = stateObj.WriteByteArray(GenericConverter.Convert(value), ccb, 0);
}
#if DEBUG
@@ -10986,7 +11008,7 @@ private bool IsBOMNeeded(MetaType type, object value)
return false;
}
- private Task GetTerminationTask(Task unterminatedWriteTask, object value, MetaType type, int actualLength, TdsParserStateObject stateObj, bool isDataFeed)
+ private Task GetTerminationTask(Task unterminatedWriteTask, MetaType type, int actualLength, TdsParserStateObject stateObj, bool isDataFeed)
{
if (type.IsPlp && ((actualLength > 0) || isDataFeed))
{
@@ -11007,16 +11029,16 @@ private Task GetTerminationTask(Task unterminatedWriteTask, object value, MetaTy
}
- private Task WriteSqlValue(object value, MetaType type, int actualLength, int codePageByteSize, int offset, TdsParserStateObject stateObj)
+ private Task WriteSqlValue(T value, MetaType type, int actualLength, int codePageByteSize, int offset, TdsParserStateObject stateObj)
{
return GetTerminationTask(
WriteUnterminatedSqlValue(value, type, actualLength, codePageByteSize, offset, stateObj),
- value, type, actualLength, stateObj, false);
+ type, actualLength, stateObj, false);
}
// For MAX types, this method can only write everything in one big chunk. If multiple
// chunk writes needed, please use WritePlpBytes/WritePlpChars
- private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLength, int codePageByteSize, int offset, TdsParserStateObject stateObj)
+ private Task WriteUnterminatedSqlValue(T value, MetaType type, int actualLength, int codePageByteSize, int offset, TdsParserStateObject stateObj)
{
Debug.Assert(((type.NullableType == TdsEnums.SQLXMLTYPE) ||
(value is INullable && !((INullable)value).IsNull)),
@@ -11027,11 +11049,11 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe
{
case TdsEnums.SQLFLTN:
if (type.FixedLength == 4)
- WriteFloat(((SqlSingle)value).Value, stateObj);
+ WriteFloat(GenericConverter.Convert(value).Value, stateObj);
else
{
Debug.Assert(type.FixedLength == 8, "Invalid length for SqlDouble type!");
- WriteDouble(((SqlDouble)value).Value, stateObj);
+ WriteDouble(GenericConverter.Convert(value).Value, stateObj);
}
break;
@@ -11047,12 +11069,12 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe
if (value is SqlBinary)
{
- return stateObj.WriteByteArray(((SqlBinary)value).Value, actualLength, offset, canAccumulate: false);
+ return stateObj.WriteByteArray(GenericConverter.Convert(value).Value, actualLength, offset, canAccumulate: false);
}
else
{
Debug.Assert(value is SqlBytes);
- return stateObj.WriteByteArray(((SqlBytes)value).Value, actualLength, offset, canAccumulate: false);
+ return stateObj.WriteByteArray(GenericConverter.Convert(value).Value, actualLength, offset, canAccumulate: false);
}
}
@@ -11060,7 +11082,7 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe
{
Debug.Assert(actualLength == 16, "Invalid length for guid type in com+ object");
Span b = stackalloc byte[16];
- SqlGuid sqlGuid = (SqlGuid)value;
+ SqlGuid sqlGuid = GenericConverter.Convert(value);
if (sqlGuid.IsNull)
{
@@ -11078,7 +11100,7 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe
case TdsEnums.SQLBITN:
{
Debug.Assert(type.FixedLength == 1, "Invalid length for SqlBoolean type");
- if (((SqlBoolean)value).Value == true)
+ if (GenericConverter.Convert(value).Value == true)
stateObj.WriteByte(1);
else
stateObj.WriteByte(0);
@@ -11088,17 +11110,17 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe
case TdsEnums.SQLINTN:
if (type.FixedLength == 1)
- stateObj.WriteByte(((SqlByte)value).Value);
+ stateObj.WriteByte(GenericConverter.Convert(value).Value);
else
if (type.FixedLength == 2)
- WriteShort(((SqlInt16)value).Value, stateObj);
+ WriteShort(GenericConverter.Convert(value).Value, stateObj);
else
if (type.FixedLength == 4)
- WriteInt(((SqlInt32)value).Value, stateObj);
+ WriteInt(GenericConverter.Convert(value).Value, stateObj);
else
{
Debug.Assert(type.FixedLength == 8, "invalid length for SqlIntN type: " + type.FixedLength.ToString(CultureInfo.InvariantCulture));
- WriteLong(((SqlInt64)value).Value, stateObj);
+ WriteLong(GenericConverter.Convert(value).Value, stateObj);
}
break;
@@ -11112,14 +11134,14 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe
}
if (value is SqlChars)
{
- string sch = new string(((SqlChars)value).Value);
+ string sch = new string(GenericConverter.Convert(value).Value);
return WriteEncodingChar(sch, actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false);
}
else
{
Debug.Assert(value is SqlString);
- return WriteEncodingChar(((SqlString)value).Value, actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false);
+ return WriteEncodingChar(GenericConverter.Convert(value).Value, actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false);
}
@@ -11148,21 +11170,21 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe
if (value is SqlChars)
{
- return WriteCharArray(((SqlChars)value).Value, actualLength, offset, stateObj, canAccumulate: false);
+ return WriteCharArray(GenericConverter.Convert(value).Value, actualLength, offset, stateObj, canAccumulate: false);
}
else
{
Debug.Assert(value is SqlString);
- return WriteString(((SqlString)value).Value, actualLength, offset, stateObj, canAccumulate: false);
+ return WriteString(GenericConverter.Convert(value).Value, actualLength, offset, stateObj, canAccumulate: false);
}
case TdsEnums.SQLNUMERICN:
Debug.Assert(type.FixedLength <= 17, "Decimal length cannot be greater than 17 bytes");
- WriteSqlDecimal((SqlDecimal)value, stateObj);
+ WriteSqlDecimal(GenericConverter.Convert(value), stateObj);
break;
case TdsEnums.SQLDATETIMN:
- SqlDateTime dt = (SqlDateTime)value;
+ SqlDateTime dt = GenericConverter.Convert(value);
if (type.FixedLength == 4)
{
@@ -11182,7 +11204,7 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe
case TdsEnums.SQLMONEYN:
{
- WriteSqlMoney((SqlMoney)value, type.FixedLength, stateObj);
+ WriteSqlMoney(GenericConverter.Convert(value), type.FixedLength, stateObj);
break;
}
@@ -11652,28 +11674,28 @@ private Task NullIfCompletedWriteTask(Task task)
}
}
- private Task WriteValue(object value, MetaType type, byte scale, int actualLength, int encodingByteSize, int offset, TdsParserStateObject stateObj, int paramSize, bool isDataFeed)
+ private Task WriteValue(T value, MetaType type, byte scale, int actualLength, int encodingByteSize, int offset, TdsParserStateObject stateObj, int paramSize, bool isDataFeed)
{
return GetTerminationTask(WriteUnterminatedValue(value, type, scale, actualLength, encodingByteSize, offset, stateObj, paramSize, isDataFeed),
- value, type, actualLength, stateObj, isDataFeed);
+ type, actualLength, stateObj, isDataFeed);
}
// For MAX types, this method can only write everything in one big chunk. If multiple
// chunk writes needed, please use WritePlpBytes/WritePlpChars
- private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int actualLength, int encodingByteSize, int offset, TdsParserStateObject stateObj, int paramSize, bool isDataFeed)
+ private Task WriteUnterminatedValue(T value, MetaType type, byte scale, int actualLength, int encodingByteSize, int offset, TdsParserStateObject stateObj, int paramSize, bool isDataFeed)
{
- Debug.Assert((null != value) && (DBNull.Value != value), "unexpected missing or empty object");
+ Debug.Assert((null != value) && !(value is DBNull), "unexpected missing or empty object");
// parameters are always sent over as BIG or N types
switch (type.NullableType)
{
case TdsEnums.SQLFLTN:
if (type.FixedLength == 4)
- WriteFloat((float)value, stateObj);
+ WriteFloat(GenericConverter.Convert(value), stateObj);
else
{
Debug.Assert(type.FixedLength == 8, "Invalid length for SqlDouble type!");
- WriteDouble((double)value, stateObj);
+ WriteDouble(GenericConverter.Convert(value), stateObj);
}
break;
@@ -11690,7 +11712,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int
if (isDataFeed)
{
Debug.Assert(type.IsPlp, "Stream assigned to non-PLP was not converted!");
- return NullIfCompletedWriteTask(WriteStreamFeed((StreamDataFeed)value, stateObj, paramSize));
+ return NullIfCompletedWriteTask(WriteStreamFeed(GenericConverter.Convert(value), stateObj, paramSize));
}
else
{
@@ -11698,7 +11720,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int
{
WriteInt(actualLength, stateObj); // chunk length
}
- return stateObj.WriteByteArray((byte[])value, actualLength, offset, canAccumulate: false);
+ return stateObj.WriteByteArray(GenericConverter.Convert(value), actualLength, offset, canAccumulate: false);
}
}
@@ -11706,7 +11728,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int
{
Debug.Assert(actualLength == 16, "Invalid length for guid type in com+ object");
Span b = stackalloc byte[16];
- FillGuidBytes((System.Guid)value, b);
+ FillGuidBytes(GenericConverter.Convert(value), b);
stateObj.WriteByteSpan(b);
break;
}
@@ -11714,7 +11736,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int
case TdsEnums.SQLBITN:
{
Debug.Assert(type.FixedLength == 1, "Invalid length for SqlBoolean type");
- if ((bool)value == true)
+ if (GenericConverter.Convert(value) == true)
stateObj.WriteByte(1);
else
stateObj.WriteByte(0);
@@ -11724,15 +11746,15 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int
case TdsEnums.SQLINTN:
if (type.FixedLength == 1)
- stateObj.WriteByte((byte)value);
+ stateObj.WriteByte(GenericConverter.Convert(value));
else if (type.FixedLength == 2)
- WriteShort((short)value, stateObj);
+ WriteShort(GenericConverter.Convert(value), stateObj);
else if (type.FixedLength == 4)
- WriteInt((int)value, stateObj);
+ WriteInt(GenericConverter.Convert(value), stateObj);
else
{
Debug.Assert(type.FixedLength == 8, "invalid length for SqlIntN type: " + type.FixedLength.ToString(CultureInfo.InvariantCulture));
- WriteLong((long)value, stateObj);
+ WriteLong(GenericConverter.Convert(value), stateObj);
}
break;
@@ -11750,7 +11772,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int
TextDataFeed tdf = value as TextDataFeed;
if (tdf == null)
{
- return NullIfCompletedWriteTask(WriteXmlFeed((XmlDataFeed)value, stateObj, needBom: true, encoding: _defaultEncoding, size: paramSize));
+ return NullIfCompletedWriteTask(WriteXmlFeed(GenericConverter.Convert(value), stateObj, needBom: true, encoding: _defaultEncoding, size: paramSize));
}
else
{
@@ -11765,11 +11787,11 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int
}
if (value is byte[])
{ // If LazyMat non-filled blob, send cookie rather than value
- return stateObj.WriteByteArray((byte[])value, actualLength, 0, canAccumulate: false);
+ return stateObj.WriteByteArray(GenericConverter.Convert(value), actualLength, 0, canAccumulate: false);
}
else
{
- return WriteEncodingChar((string)value, actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false);
+ return WriteEncodingChar(GenericConverter.Convert(value), actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false);
}
}
}
@@ -11787,7 +11809,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int
TextDataFeed tdf = value as TextDataFeed;
if (tdf == null)
{
- return NullIfCompletedWriteTask(WriteXmlFeed((XmlDataFeed)value, stateObj, IsBOMNeeded(type, value), Encoding.Unicode, paramSize));
+ return NullIfCompletedWriteTask(WriteXmlFeed(GenericConverter.Convert(value), stateObj, IsBOMNeeded(type, value), Encoding.Unicode, paramSize));
}
else
{
@@ -11810,25 +11832,25 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int
}
if (value is byte[])
{ // If LazyMat non-filled blob, send cookie rather than value
- return stateObj.WriteByteArray((byte[])value, actualLength, 0, canAccumulate: false);
+ return stateObj.WriteByteArray(GenericConverter.Convert(value), actualLength, 0, canAccumulate: false);
}
else
{
// convert to cchars instead of cbytes
actualLength >>= 1;
- return WriteString((string)value, actualLength, offset, stateObj, canAccumulate: false);
+ return WriteString(GenericConverter.Convert(value), actualLength, offset, stateObj, canAccumulate: false);
}
}
}
case TdsEnums.SQLNUMERICN:
Debug.Assert(type.FixedLength <= 17, "Decimal length cannot be greater than 17 bytes");
- WriteDecimal((decimal)value, stateObj);
+ WriteDecimal(GenericConverter.Convert(value), stateObj);
break;
case TdsEnums.SQLDATETIMN:
Debug.Assert(type.FixedLength <= 0xff, "Invalid Fixed Length");
- TdsDateTime dt = MetaType.FromDateTime((DateTime)value, (byte)type.FixedLength);
+ TdsDateTime dt = MetaType.FromDateTime(GenericConverter.Convert(value), (byte)type.FixedLength);
if (type.FixedLength == 4)
{
@@ -11848,13 +11870,13 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int
case TdsEnums.SQLMONEYN:
{
- WriteCurrency((decimal)value, type.FixedLength, stateObj);
+ WriteCurrency(GenericConverter.Convert(value), type.FixedLength, stateObj);
break;
}
case TdsEnums.SQLDATE:
{
- WriteDate((DateTime)value, stateObj);
+ WriteDate(GenericConverter.Convert(value), stateObj);
break;
}
@@ -11863,7 +11885,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int
{
throw SQL.TimeScaleValueOutOfRange(scale);
}
- WriteTime((TimeSpan)value, scale, actualLength, stateObj);
+ WriteTime(GenericConverter.Convert(value), scale, actualLength, stateObj);
break;
case TdsEnums.SQLDATETIME2:
@@ -11871,11 +11893,11 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int
{
throw SQL.TimeScaleValueOutOfRange(scale);
}
- WriteDateTime2((DateTime)value, scale, actualLength, stateObj);
+ WriteDateTime2(GenericConverter.Convert(value), scale, actualLength, stateObj);
break;
case TdsEnums.SQLDATETIMEOFFSET:
- WriteDateTimeOffset((DateTimeOffset)value, scale, actualLength, stateObj);
+ WriteDateTimeOffset(GenericConverter.Convert(value), scale, actualLength, stateObj);
break;
default:
@@ -12119,7 +12141,7 @@ private byte[] SerializeUnencryptedValue(object value, MetaType type, byte scale
// For MAX types, this method can only write everything in one big chunk. If multiple
// chunk writes needed, please use WritePlpBytes/WritePlpChars
- private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int actualLength, int offset, byte normalizationVersion, TdsParserStateObject stateObj)
+ private byte[] SerializeUnencryptedSqlValue(T value, MetaType type, int actualLength, int offset, byte normalizationVersion, TdsParserStateObject stateObj)
{
Debug.Assert(((type.NullableType == TdsEnums.SQLXMLTYPE) ||
(value is INullable && !((INullable)value).IsNull)),
@@ -12135,11 +12157,13 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act
{
case TdsEnums.SQLFLTN:
if (type.FixedLength == 4)
- return SerializeFloat(((SqlSingle)value).Value);
+ {
+ return SerializeFloat(GenericConverter.Convert(value).Value);
+ }
else
{
Debug.Assert(type.FixedLength == 8, "Invalid length for SqlDouble type!");
- return SerializeDouble(((SqlDouble)value).Value);
+ return SerializeDouble(GenericConverter.Convert(value).Value);
}
case TdsEnums.SQLBIGBINARY:
@@ -12150,19 +12174,20 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act
if (value is SqlBinary)
{
- Buffer.BlockCopy(((SqlBinary)value).Value, offset, b, 0, actualLength);
+ Buffer.BlockCopy(GenericConverter.Convert(value).Value, offset, b, 0, actualLength);
}
else
{
Debug.Assert(value is SqlBytes);
- Buffer.BlockCopy(((SqlBytes)value).Value, offset, b, 0, actualLength);
+ Buffer.BlockCopy(GenericConverter.Convert(value).Value, offset, b, 0, actualLength);
}
+
return b;
}
case TdsEnums.SQLUNIQUEID:
{
- byte[] b = ((SqlGuid)value).ToByteArray();
+ byte[] b = GenericConverter.Convert(value).ToByteArray();
Debug.Assert((actualLength == b.Length) && (actualLength == 16), "Invalid length for guid type in com+ object");
return b;
@@ -12173,23 +12198,23 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act
Debug.Assert(type.FixedLength == 1, "Invalid length for SqlBoolean type");
// We normalize to allow conversion across data types. BIT is serialized into a BIGINT.
- return SerializeLong(((SqlBoolean)value).Value == true ? 1 : 0, stateObj);
+ return SerializeLong(GenericConverter.Convert(value).Value == true ? 1 : 0, stateObj);
}
case TdsEnums.SQLINTN:
// We normalize to allow conversion across data types. All data types below are serialized into a BIGINT.
if (type.FixedLength == 1)
- return SerializeLong(((SqlByte)value).Value, stateObj);
+ return SerializeLong(GenericConverter.Convert(value).Value, stateObj);
if (type.FixedLength == 2)
- return SerializeLong(((SqlInt16)value).Value, stateObj);
+ return SerializeLong(GenericConverter.Convert(value).Value, stateObj);
if (type.FixedLength == 4)
- return SerializeLong(((SqlInt32)value).Value, stateObj);
+ return SerializeLong(GenericConverter.Convert(value).Value, stateObj);
else
{
Debug.Assert(type.FixedLength == 8, "invalid length for SqlIntN type: " + type.FixedLength.ToString(CultureInfo.InvariantCulture));
- return SerializeLong(((SqlInt64)value).Value, stateObj);
+ return SerializeLong(GenericConverter.Convert(value).Value, stateObj);
}
case TdsEnums.SQLBIGCHAR:
@@ -12197,13 +12222,13 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act
case TdsEnums.SQLTEXT:
if (value is SqlChars)
{
- String sch = new String(((SqlChars)value).Value);
+ String sch = new String(GenericConverter.Convert(value).Value);
return SerializeEncodingChar(sch, actualLength, offset, _defaultEncoding);
}
else
{
Debug.Assert(value is SqlString);
- return SerializeEncodingChar(((SqlString)value).Value, actualLength, offset, _defaultEncoding);
+ return SerializeEncodingChar(GenericConverter.Convert(value).Value, actualLength, offset, _defaultEncoding);
}
@@ -12218,20 +12243,20 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act
if (value is SqlChars)
{
- return SerializeCharArray(((SqlChars)value).Value, actualLength, offset);
+ return SerializeCharArray(GenericConverter.Convert(value).Value, actualLength, offset);
}
else
{
Debug.Assert(value is SqlString);
- return SerializeString(((SqlString)value).Value, actualLength, offset);
+ return SerializeString(GenericConverter.Convert(value).Value, actualLength, offset);
}
case TdsEnums.SQLNUMERICN:
Debug.Assert(type.FixedLength <= 17, "Decimal length cannot be greater than 17 bytes");
- return SerializeSqlDecimal((SqlDecimal)value, stateObj);
+ return SerializeSqlDecimal(GenericConverter.Convert(value), stateObj);
case TdsEnums.SQLDATETIMN:
- SqlDateTime dt = (SqlDateTime)value;
+ SqlDateTime dt = GenericConverter.Convert(value);
if (type.FixedLength == 4)
{
@@ -12277,7 +12302,7 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act
case TdsEnums.SQLMONEYN:
{
- return SerializeSqlMoney((SqlMoney)value, type.FixedLength, stateObj);
+ return SerializeSqlMoney(GenericConverter.Convert(value), type.FixedLength, stateObj);
}
default:
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj
index a63ed87bce..51698ddd73 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj
@@ -287,6 +287,7 @@
+
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/GenericConverter.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/GenericConverter.cs
new file mode 100644
index 0000000000..2d6eb82fba
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/GenericConverter.cs
@@ -0,0 +1,16 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.Data.SqlClient
+{
+ // This leverages the same assumptions in SqlBuffer that the JIT will optimize out the boxing / unboxing when TIn == TOut
+ // This behavior is proven out in the NoBoxingValueTypes BulkCopy unit test that benchmarks and measures the allocations
+ internal static class GenericConverter
+ {
+ public static TOut Convert(TIn value)
+ {
+ return (TOut)(object)value;
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs
index fdfbc1a71b..c4beb7d98e 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs
@@ -186,8 +186,8 @@ private enum ValueSourceType
DbDataReader
}
- // Enum for specifying SqlDataReader.Get method used
- private enum ValueMethod : byte
+ // Enum for specifying SqlDataReader.Get / IDataReader Get method used
+ private enum ValueMethod
{
GetValue,
SqlTypeSqlDecimal,
@@ -195,7 +195,19 @@ private enum ValueMethod : byte
SqlTypeSqlSingle,
DataFeedStream,
DataFeedText,
- DataFeedXml
+ DataFeedXml,
+ GetInt32,
+ GetString,
+ GetDouble,
+ GetDecimal,
+ GetInt16,
+ GetInt64,
+ GetChar,
+ GetByte,
+ GetBoolean,
+ GetDateTime,
+ GetGuid,
+ GetFloat
}
// Used to hold column metadata for SqlDataReader case
@@ -309,7 +321,7 @@ private int RowNumber
// for debug purpose only.
// TODO: I will make this internal to use Reflection.
#if DEBUG
- internal static bool _setAlwaysTaskOnWrite = false; //when set and in DEBUG mode, TdsParser::WriteBulkCopyValue will always return a task
+ internal static bool _setAlwaysTaskOnWrite = false; //when set and in DEBUG mode, TdsParser::WriteBulkCopyValue will always return a task
internal static bool SetAlwaysTaskOnWrite
{
set
@@ -541,8 +553,8 @@ private bool IsCopyOption(SqlBulkCopyOptions copyOption)
return (_copyOptions & copyOption) == copyOption;
}
- //Creates the initial query string, but does not execute it.
- //
+ //Creates the initial query string, but does not execute it.
+ //
private string CreateInitialQuery()
{
string[] parts;
@@ -563,7 +575,7 @@ private string CreateInitialQuery()
TDSCommand = "select @@trancount; SET FMTONLY ON select * from " + ADP.BuildMultiPartName(parts) + " SET FMTONLY OFF ";
if (_connection.IsShiloh)
{
- // If its a temp DB then try to connect
+ // If its a temp DB then try to connect
string TableCollationsStoredProc;
if (_connection.IsKatmaiOrNewer)
@@ -626,9 +638,9 @@ private string CreateInitialQuery()
}
// Creates and then executes initial query to get information about the targettable
- // When __isAsyncBulkCopy == false (i.e. it is Sync copy): out result contains the resulset. Returns null.
- // When __isAsyncBulkCopy == true (i.e. it is Async copy): This still uses the _parser.Run method synchronously and return Task.
- // We need to have a _parser.RunAsync to make it real async.
+ // When __isAsyncBulkCopy == false (i.e. it is Sync copy): out result contains the resulset. Returns null.
+ // When __isAsyncBulkCopy == true (i.e. it is Async copy): This still uses the _parser.Run method synchronously and return Task.
+ // We need to have a _parser.RunAsync to make it real async.
private Task CreateAndExecuteInitialQueryAsync(out BulkCopySimpleResultSet result)
{
string TDSCommand = CreateInitialQuery();
@@ -1055,13 +1067,19 @@ private void Dispose(bool disposing)
// free unmanaged objects
}
- // unified method to read a value from the current row
- //
- private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out bool isDataFeed, out bool isNull)
+ // Reads a cell and then writes it.
+ // Read may block at this moment since there is no getValueAsync or DownStream async at this moment.
+ // When _isAsyncBulkCopy == true: Write will return Task (when async method runs asynchronously) or Null (when async call actually ran synchronously) for performance.
+ // When _isAsyncBulkCopy == false: Writes are purely sync. This method return null at the end.
+ private Task ReadWriteColumnValueAsync(int destRowIndex)
{
_SqlMetaData metadata = _sortedColumnMappings[destRowIndex]._metadata;
int sourceOrdinal = _sortedColumnMappings[destRowIndex]._sourceColumnOrdinal;
+ bool isSqlType = false;
+ bool isDataFeed = false;
+ bool isNull = false;
+
switch (_rowSourceType)
{
case ValueSourceType.IDataReader:
@@ -1071,34 +1089,39 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b
{
if (_DbDataReaderRowSource.IsDBNull(sourceOrdinal))
{
- isSqlType = false;
- isDataFeed = false;
isNull = true;
- return DBNull.Value;
+ return WriteValueAsync(DBNull.Value, destRowIndex, isSqlType, isDataFeed, isNull);
}
else
{
- isSqlType = false;
isDataFeed = true;
- isNull = false;
+
+ object feedColumnValue;
+
switch (_currentRowMetadata[destRowIndex].Method)
{
case ValueMethod.DataFeedStream:
- return new StreamDataFeed(_DbDataReaderRowSource.GetStream(sourceOrdinal));
+ feedColumnValue = new StreamDataFeed(_DbDataReaderRowSource.GetStream(sourceOrdinal));
+ break;
case ValueMethod.DataFeedText:
- return new TextDataFeed(_DbDataReaderRowSource.GetTextReader(sourceOrdinal));
+ feedColumnValue = new TextDataFeed(_DbDataReaderRowSource.GetTextReader(sourceOrdinal));
+ break;
case ValueMethod.DataFeedXml:
// Only SqlDataReader supports an XmlReader
- // There is no GetXmlReader on DbDataReader, however if GetValue returns XmlReader we will read it as stream if it is assigned to XML field
+ // There is no GetXmlReader on DbDataReader, however if GetValue returns XmlReader we will read it as stream if it is assigned to XML field
Debug.Assert(_SqlDataReaderRowSource != null, "Should not be reading row as an XmlReader if bulk copy source is not a SqlDataReader");
- return new XmlDataFeed(_SqlDataReaderRowSource.GetXmlReader(sourceOrdinal));
+ feedColumnValue = new XmlDataFeed(_SqlDataReaderRowSource.GetXmlReader(sourceOrdinal));
+ break;
default:
- Debug.Assert(false, string.Format("Current column is marked as being a DataFeed, but no DataFeed compatible method was provided. Method: {0}", _currentRowMetadata[destRowIndex].Method));
+ Debug.Fail($"Current column is marked as being a DataFeed, but no DataFeed compatible method was provided. Method: {_currentRowMetadata[destRowIndex].Method}");
isDataFeed = false;
- object columnValue = _DbDataReaderRowSource.GetValue(sourceOrdinal);
- ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType);
- return columnValue;
+ feedColumnValue = _DbDataReaderRowSource.GetValue(sourceOrdinal);
+ ADP.IsNullOrSqlType(feedColumnValue, out isNull, out isSqlType);
+ break;
}
+
+ //specifically choosing to use the object overload here to simplify TdsParser logic for the XmlReader scenario
+ return WriteValueAsync(feedColumnValue, destRowIndex, isSqlType, isDataFeed, isNull);
}
}
// SqlDataReader-specific logic
@@ -1106,36 +1129,28 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b
{
if (_currentRowMetadata[destRowIndex].IsSqlType)
{
- INullable value;
isSqlType = true;
- isDataFeed = false;
switch (_currentRowMetadata[destRowIndex].Method)
{
case ValueMethod.SqlTypeSqlDecimal:
- value = _SqlDataReaderRowSource.GetSqlDecimal(sourceOrdinal);
- break;
+ var value = _SqlDataReaderRowSource.GetSqlDecimal(sourceOrdinal);
+ return WriteValueAsync(value, destRowIndex, isSqlType, isDataFeed, value.IsNull);
case ValueMethod.SqlTypeSqlDouble:
// use cast to handle IsNull correctly because no public constructor allows it
- value = (SqlDecimal)_SqlDataReaderRowSource.GetSqlDouble(sourceOrdinal);
- break;
+ var dblValue = (SqlDecimal)_SqlDataReaderRowSource.GetSqlDouble(sourceOrdinal);
+ return WriteValueAsync(dblValue, destRowIndex, isSqlType, isDataFeed, dblValue.IsNull);
case ValueMethod.SqlTypeSqlSingle:
- // use cast to handle IsNull correctly because no public constructor allows it
- value = (SqlDecimal)_SqlDataReaderRowSource.GetSqlSingle(sourceOrdinal);
- break;
+ // use cast to handle value.IsNull correctly because no public constructor allows it
+ var singleValue = (SqlDecimal)_SqlDataReaderRowSource.GetSqlSingle(sourceOrdinal);
+ return WriteValueAsync(singleValue, destRowIndex, isSqlType, isDataFeed, singleValue.IsNull);
default:
- Debug.Assert(false, string.Format("Current column is marked as being a SqlType, but no SqlType compatible method was provided. Method: {0}", _currentRowMetadata[destRowIndex].Method));
- value = (INullable)_SqlDataReaderRowSource.GetSqlValue(sourceOrdinal);
- break;
+ Debug.Fail($"Current column is marked as being a SqlType, but no SqlType compatible method was provided. Method: {_currentRowMetadata[destRowIndex].Method}");
+ var sqlValue = (INullable)_SqlDataReaderRowSource.GetSqlValue(sourceOrdinal);
+ return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, sqlValue.IsNull);
}
-
- isNull = value.IsNull;
- return value;
}
else
{
- isSqlType = false;
- isDataFeed = false;
-
object value = _SqlDataReaderRowSource.GetValue(sourceOrdinal);
isNull = ((value == null) || (value == DBNull.Value));
if ((!isNull) && (metadata.type == SqlDbType.Udt))
@@ -1148,31 +1163,66 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b
{
Debug.Assert(!(value is INullable) || !((INullable)value).IsNull, "IsDBNull returned false, but GetValue returned a null INullable");
}
-#endif
- return value;
+#endif
+ return WriteValueAsync(value, destRowIndex, isSqlType, isDataFeed, isNull);
}
}
else
{
- isDataFeed = false;
-
IDataReader rowSourceAsIDataReader = (IDataReader)_rowSource;
- // Back-compat with 4.0 and 4.5 - only use IsDbNull when streaming is enabled and only for non-SqlDataReader
+ // Only use IsDbNull when streaming is enabled and only for non-SqlDataReader
if ((_enableStreaming) && (_SqlDataReaderRowSource == null) && (rowSourceAsIDataReader.IsDBNull(sourceOrdinal)))
{
- isSqlType = false;
isNull = true;
- return DBNull.Value;
+ return WriteValueAsync(DBNull.Value, destRowIndex, isSqlType, isDataFeed, isNull);
}
else
{
- object columnValue = rowSourceAsIDataReader.GetValue(sourceOrdinal);
- ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType);
- return columnValue;
+ if (_currentRowMetadata[destRowIndex].Method == ValueMethod.GetValue || rowSourceAsIDataReader.IsDBNull(sourceOrdinal))
+ {
+ object columnValue = rowSourceAsIDataReader.GetValue(sourceOrdinal);
+ ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType);
+ return WriteValueAsync(columnValue, destRowIndex, isSqlType, isDataFeed, isNull);
+ }
+ else
+ {
+ switch (_currentRowMetadata[sourceOrdinal].Method)
+ {
+ case ValueMethod.GetInt32:
+ return WriteValueAsync(rowSourceAsIDataReader.GetInt32(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, false);
+ case ValueMethod.GetString:
+ var strValue = rowSourceAsIDataReader.GetString(sourceOrdinal);
+ isNull = strValue == null;
+ return WriteValueAsync(strValue, destRowIndex, isSqlType, isDataFeed, isNull);
+ case ValueMethod.GetDouble:
+ return WriteValueAsync(rowSourceAsIDataReader.GetDouble(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull);
+ case ValueMethod.GetDecimal:
+ return WriteValueAsync(rowSourceAsIDataReader.GetDecimal(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull);
+ case ValueMethod.GetInt16:
+ return WriteValueAsync(rowSourceAsIDataReader.GetInt16(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull);
+ case ValueMethod.GetInt64:
+ return WriteValueAsync(rowSourceAsIDataReader.GetInt64(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull);
+ case ValueMethod.GetChar:
+ return WriteValueAsync(rowSourceAsIDataReader.GetChar(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull);
+ case ValueMethod.GetByte:
+ return WriteValueAsync(rowSourceAsIDataReader.GetByte(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull);
+ case ValueMethod.GetBoolean:
+ return WriteValueAsync(rowSourceAsIDataReader.GetBoolean(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull);
+ case ValueMethod.GetDateTime:
+ return WriteValueAsync(rowSourceAsIDataReader.GetDateTime(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull);
+ case ValueMethod.GetGuid:
+ return WriteValueAsync(rowSourceAsIDataReader.GetGuid(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull);
+ case ValueMethod.GetFloat:
+ return WriteValueAsync(rowSourceAsIDataReader.GetFloat(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull);
+ default:
+ object columnValue = rowSourceAsIDataReader.GetValue(sourceOrdinal);
+ ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType);
+ return WriteValueAsync(columnValue, destRowIndex, isSqlType, isDataFeed, isNull);
+ }
+ }
}
}
-
case ValueSourceType.DataTable:
case ValueSourceType.RowArray:
{
@@ -1180,6 +1230,7 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b
Debug.Assert(sourceOrdinal < _currentRowLength, "inconsistency of length of rows from rowsource!");
isDataFeed = false;
+ // unfortunately this has to be boxed due to DataRow's API.
object currentRowValue = _currentRow[sourceOrdinal];
ADP.IsNullOrSqlType(currentRowValue, out isNull, out isSqlType);
@@ -1192,7 +1243,8 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b
{
if (isSqlType)
{
- return new SqlDecimal(((SqlSingle)currentRowValue).Value);
+ var sqlDec = new SqlDecimal(((SqlSingle)currentRowValue).Value);
+ return WriteValueAsync(sqlDec, destRowIndex, isSqlType, isDataFeed, isNull);
}
else
{
@@ -1200,16 +1252,20 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b
if (!float.IsNaN(f))
{
isSqlType = true;
- return new SqlDecimal(f);
+ return WriteValueAsync(new SqlDecimal(f), destRowIndex, isSqlType, isDataFeed, isNull);
+ }
+ else
+ {
+ return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull);
}
- break;
}
}
case ValueMethod.SqlTypeSqlDouble:
{
if (isSqlType)
{
- return new SqlDecimal(((SqlDouble)currentRowValue).Value);
+ var sqlValue = new SqlDecimal(((SqlDouble)currentRowValue).Value);
+ return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, isNull);
}
else
{
@@ -1217,37 +1273,44 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b
if (!double.IsNaN(d))
{
isSqlType = true;
- return new SqlDecimal(d);
+ return WriteValueAsync(new SqlDecimal(d), destRowIndex, isSqlType, isDataFeed, isNull);
+ }
+ else
+ {
+ return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull);
}
- break;
}
}
case ValueMethod.SqlTypeSqlDecimal:
{
if (isSqlType)
{
- return (SqlDecimal)currentRowValue;
+ var sqlValue = (SqlDecimal)currentRowValue;
+ return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, isNull);
}
else
{
isSqlType = true;
- return new SqlDecimal((Decimal)currentRowValue);
+ var sqlValue = new SqlDecimal((decimal)currentRowValue);
+ return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, isNull);
}
}
default:
{
- Debug.Assert(false, string.Format("Current column is marked as being a SqlType, but no SqlType compatible method was provided. Method: {0}", _currentRowMetadata[destRowIndex].Method));
- break;
+ Debug.Fail($"Current column is marked as being a SqlType, but no SqlType compatible method was provided. Method: {_currentRowMetadata[destRowIndex].Method}");
+ // If we are here then either the value is null, there was no special storage type for this column or the special storage type wasn't handled (e.g. if the currentRowValue is NaN)
+ return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull);
}
}
}
-
- // If we are here then either the value is null, there was no special storage type for this column or the special storage type wasn't handled (e.g. if the currentRowValue is NaN)
- return currentRowValue;
+ else
+ {
+ return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull);
+ }
}
default:
{
- Debug.Assert(false, "ValueSourcType unspecified");
+ Debug.Fail("ValueSourcType unspecified");
throw ADP.NotSupported();
}
}
@@ -1261,7 +1324,7 @@ private Task ReadFromRowSourceAsync(CancellationToken cts)
{
if (_isAsyncBulkCopy && (_DbDataReaderRowSource != null))
{
- //This will call ReadAsync for DbDataReader (for SqlDataReader it will be truely async read; for non-SqlDataReader it may block.)
+ //This will call ReadAsync for DbDataReader (for SqlDataReader it will be truely async read; for non-SqlDataReader it may block.)
return _DbDataReaderRowSource.ReadAsync(cts).ContinueWith((t) =>
{
if (t.Status == TaskStatus.RanToCompletion)
@@ -1374,7 +1437,7 @@ private SourceColumnMetadata GetColumnMetadata(int ordinal)
else if (typeof(SqlSingle) == t || typeof(float) == t)
{
isSqlType = true;
- method = ValueMethod.SqlTypeSqlSingle; // Source Type SqlSingle
+ method = ValueMethod.SqlTypeSqlSingle; // Source Type SqlSingle
}
else
{
@@ -1439,6 +1502,66 @@ private SourceColumnMetadata GetColumnMetadata(int ordinal)
method = ValueMethod.GetValue;
}
}
+ else if (_rowSourceType == ValueSourceType.IDataReader)
+ {
+ isSqlType = false;
+ isDataFeed = false;
+
+ Type t = ((IDataReader)_rowSource).GetFieldType(ordinal);
+
+ if (t == typeof(bool))
+ {
+ method = ValueMethod.GetBoolean;
+ }
+ else if (t == typeof(byte))
+ {
+ method = ValueMethod.GetByte;
+ }
+ else if (t == typeof(char))
+ {
+ method = ValueMethod.GetChar;
+ }
+ else if (t == typeof(DateTime))
+ {
+ method = ValueMethod.GetDateTime;
+ }
+ else if (t == typeof(decimal))
+ {
+ method = ValueMethod.GetDecimal;
+ }
+ else if (t == typeof(double))
+ {
+ method = ValueMethod.GetDouble;
+ }
+ else if (t == typeof(float))
+ {
+ method = ValueMethod.GetFloat;
+ }
+ else if (t == typeof(Guid))
+ {
+ method = ValueMethod.GetGuid;
+ }
+ else if (t == typeof(short))
+ {
+ method = ValueMethod.GetInt16;
+ }
+ else if (t == typeof(int))
+ {
+ method = ValueMethod.GetInt32;
+ }
+ else if (t == typeof(long))
+ {
+ method = ValueMethod.GetInt64;
+ }
+ else if (t == typeof(string))
+ {
+ method = ValueMethod.GetString;
+ }
+ else
+ {
+ method = ValueMethod.GetValue;
+ }
+ }
else
{
isSqlType = false;
@@ -1449,8 +1572,6 @@ private SourceColumnMetadata GetColumnMetadata(int ordinal)
return new SourceColumnMetadata(method, isSqlType, isDataFeed);
}
- //
- //
private void CreateOrValidateConnection(string method)
{
if (null == _connection)
@@ -1581,13 +1702,14 @@ private string UnquotedName(string name)
return name;
}
- private object ValidateBulkCopyVariant(object value)
+ private bool ValidateBulkCopyVariantIfNeeded(T value, out object variantValue)
{
- // from the spec:
+ variantValue = null;
+
+ // From the spec:
// "The only acceptable types are ..."
// GUID, BIGVARBINARY, BIGBINARY, BIGVARCHAR, BIGCHAR, NVARCHAR, NCHAR, BIT, INT1, INT2, INT4, INT8,
// MONEY4, MONEY, DECIMALN, NUMERICN, FTL4, FLT8, DATETIME4 and DATETIME
- //
MetaType metatype = MetaType.GetMetaTypeFromValue(value);
switch (metatype.TDSType)
{
@@ -1610,21 +1732,22 @@ private object ValidateBulkCopyVariant(object value)
case TdsEnums.SQLDATETIME2:
case TdsEnums.SQLDATETIMEOFFSET:
if (value is INullable)
- { // Current limitation in the SqlBulkCopy Variant code limits BulkCopy to CLR/COM Types.
- return MetaType.GetComValueFromSqlVariant(value);
+ { // Current limitation in the SqlBulkCopy Variant code limits BulkCopy to CLR/COM Types.
+ variantValue = MetaType.GetComValueFromSqlVariant(value);
+ return true;
}
else
{
- return value;
+ return false;
}
default:
throw SQL.BulkLoadInvalidVariantValue();
}
}
- private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, ref bool isSqlType, out bool coercedToDataFeed)
+ private Task ConvertWriteValueAsync(T value, int col, _SqlMetaData metadata, bool isNull, bool isSqlType)
{
- coercedToDataFeed = false;
+ bool coercedToDataFeed = false;
if (isNull)
{
@@ -1632,11 +1755,13 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re
{
throw SQL.BulkLoadBulkLoadNotAllowDBNull(metadata.column);
}
- return value;
+
+ return DoWriteValueAsync(value, col, isSqlType, coercedToDataFeed, isNull, metadata);
}
MetaType type = metadata.metaType;
bool typeChanged = false;
+ object objValue = null;
// If the column is encrypted then we are going to transparently encrypt this column
// (based on connection string setting)- Use the metaType for the underlying
@@ -1661,56 +1786,55 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re
{
case TdsEnums.SQLNUMERICN:
case TdsEnums.SQLDECIMALN:
- mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false);
- value = SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false);
-
- // Convert Source Decimal Precision and Scale to Destination Precision and Scale
- // Fix Bug: 385971 sql decimal data could get corrupted on insert if the scale of
- // the source and destination weren't the same. The BCP protocol, specifies the
- // scale of the incoming data in the insert statement, we just tell the server we
- // are inserting the same scale back. This then created a bug inside the BCP operation
- // if the scales didn't match. The fix is to do the same thing that SQL Parameter does,
- // and adjust the scale before writing. In Orcas is scale adjustment should be removed from
- // SqlParameter and SqlBulkCopy and Isolated inside SqlParameter.CoerceValue, but because of
- // where we are in the cycle, the changes must be kept at minimum, so I'm just bringing the
- // code over to SqlBulkCopy.
-
- SqlDecimal sqlValue;
- if ((isSqlType) && (!typeChanged))
+ SqlDecimal decValue;
+ if (typeof(T) == typeof(decimal))
+ {
+ decValue = new SqlDecimal(GenericConverter.Convert(value));
+ }
+ else if (typeof(T) == typeof(SqlDecimal))
{
- sqlValue = (SqlDecimal)value;
+ decValue = GenericConverter.Convert(value);
}
else
{
- sqlValue = new SqlDecimal((Decimal)value);
+ mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false);
+ decValue = new SqlDecimal((decimal)SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false));
}
- if (sqlValue.Scale != scale)
+ // Convert Source Decimal Precision and Scale to Destination Precision and Scale
+ // Sql decimal data could get corrupted on insert if the scale of
+ // the source and destination weren't the same. The BCP protocol, specifies the
+ // scale of the incoming data in the insert statement, we just tell the server we
+ // are inserting the same scale back.
+ if (decValue.Scale != scale)
{
- sqlValue = TdsParser.AdjustSqlDecimalScale(sqlValue, scale);
+ decValue = TdsParser.AdjustSqlDecimalScale(decValue, scale);
}
- if (sqlValue.Precision > precision)
+ if (decValue.Precision > precision)
{
try
{
- sqlValue = SqlDecimal.ConvertToPrecScale(sqlValue, precision, sqlValue.Scale);
+ decValue = SqlDecimal.ConvertToPrecScale(decValue, precision, decValue.Scale);
}
catch (SqlTruncateException)
{
- throw SQL.BulkLoadCannotConvertValue(value.GetType(), mt, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), ADP.ParameterValueOutOfRange(sqlValue));
+ mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false);
+ throw SQL.BulkLoadCannotConvertValue(value.GetType(), mt, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), ADP.ParameterValueOutOfRange(decValue));
}
catch (Exception e)
{
+ mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false);
throw SQL.BulkLoadCannotConvertValue(value.GetType(), mt, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), e);
}
}
// Perf: It is more efficient to write a SqlDecimal than a decimal since we need to break it into its 'bits' when writing
- value = sqlValue;
isSqlType = true;
- typeChanged = false; // Setting this to false as SqlParameter.CoerceValue will only set it to true when converting to a CLR type
- break;
+ typeChanged = false; // Setting this to false as SqlParameter.CoerceValue will only set it to true when converting to a CLR type
+
+ // returning here to avoid unnecessary decValue initialization for all types
+ return WriteConvertedValue(decValue, col, isSqlType, isNull, coercedToDataFeed, metadata);
case TdsEnums.SQLINTN:
case TdsEnums.SQLFLTN:
@@ -1734,16 +1858,22 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re
case TdsEnums.SQLDATETIME2:
case TdsEnums.SQLDATETIMEOFFSET:
mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false);
- value = SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false);
+ typeChanged = SqlParameter.CoerceValueIfNeeded(value, mt, out objValue, out coercedToDataFeed);
break;
case TdsEnums.SQLNCHAR:
case TdsEnums.SQLNVARCHAR:
case TdsEnums.SQLNTEXT:
mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false);
- value = SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false);
+ typeChanged = SqlParameter.CoerceValueIfNeeded(value, mt, out objValue, out coercedToDataFeed, false);
if (!coercedToDataFeed)
{ // We do not need to test for TextDataFeed as it is only assigned to (N)VARCHAR(MAX)
- string str = ((isSqlType) && (!typeChanged)) ? ((SqlString)value).Value : ((string)value);
+ string str = typeChanged
+ ? (string)objValue
+ : isSqlType
+ ? GenericConverter.Convert(value).Value
+ : GenericConverter.Convert(value)
+ ;
+
int maxStringLength = length / 2;
if (str.Length > maxStringLength)
{
@@ -1762,8 +1892,7 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re
}
break;
case TdsEnums.SQLVARIANT:
- value = ValidateBulkCopyVariant(value);
- typeChanged = true;
+ typeChanged = ValidateBulkCopyVariantIfNeeded(value, out objValue);
break;
case TdsEnums.SQLUDT:
// UDTs are sent as varbinary so we need to get the raw bytes
@@ -1774,33 +1903,25 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re
// in byte[] form.
if (!(value is byte[]))
{
- value = _connection.GetBytes(value);
+ objValue = _connection.GetBytes(value);
typeChanged = true;
}
break;
case TdsEnums.SQLXMLTYPE:
- // Could be either string, SqlCachedBuffer, XmlReader or XmlDataFeed
+ // Could be either string, SqlCachedBuffer, XmlReader or XmlDataFeed
Debug.Assert((value is XmlReader) || (value is SqlCachedBuffer) || (value is string) || (value is SqlString) || (value is XmlDataFeed), "Invalid value type of Xml datatype");
- if (value is XmlReader)
+ if (value is XmlReader xmlReader)
{
- value = new XmlDataFeed((XmlReader)value);
+ objValue = new XmlDataFeed(xmlReader);
typeChanged = true;
coercedToDataFeed = true;
}
break;
default:
- Debug.Assert(false, "Unknown TdsType!" + type.NullableType.ToString("x2", (IFormatProvider)null));
+ Debug.Fail("Unknown TdsType!" + type.NullableType.ToString("x2", (IFormatProvider)null));
throw SQL.BulkLoadCannotConvertValue(value.GetType(), type, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), null);
}
-
- if (typeChanged)
- {
- // All type changes change to CLR types
- isSqlType = false;
- }
-
- return value;
}
catch (Exception e)
{
@@ -1810,6 +1931,17 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re
}
throw SQL.BulkLoadCannotConvertValue(value.GetType(), type, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), e);
}
+
+ if (typeChanged)
+ {
+ // All type changes change to CLR types
+ isSqlType = false;
+ return WriteConvertedValue(objValue, col, isSqlType, isNull, coercedToDataFeed, metadata);
+ }
+ else
+ {
+ return WriteConvertedValue(value, col, isSqlType, isNull, coercedToDataFeed, metadata);
+ }
}
///
@@ -1842,7 +1974,7 @@ public void WriteToServer(DbDataReader reader)
_dataTableSource = null;
_rowSourceType = ValueSourceType.DbDataReader;
_isAsyncBulkCopy = false;
- WriteRowSourceToServerAsync(reader.FieldCount, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false;
+ WriteRowSourceToServerAsync(reader.FieldCount, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false;
}
finally
{
@@ -1879,7 +2011,7 @@ public void WriteToServer(IDataReader reader)
_dataTableSource = null;
_rowSourceType = ValueSourceType.IDataReader;
_isAsyncBulkCopy = false;
- WriteRowSourceToServerAsync(reader.FieldCount, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false;
+ WriteRowSourceToServerAsync(reader.FieldCount, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false;
}
finally
{
@@ -1920,7 +2052,7 @@ public void WriteToServer(DataTable table, DataRowState rowState)
_rowEnumerator = table.Rows.GetEnumerator();
_isAsyncBulkCopy = false;
- WriteRowSourceToServerAsync(table.Columns.Count, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false;
+ WriteRowSourceToServerAsync(table.Columns.Count, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false;
}
finally
{
@@ -1964,7 +2096,7 @@ public void WriteToServer(DataRow[] rows)
_rowEnumerator = rows.GetEnumerator();
_isAsyncBulkCopy = false;
- WriteRowSourceToServerAsync(table.Columns.Count, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false;
+ WriteRowSourceToServerAsync(table.Columns.Count, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false;
}
finally
{
@@ -2024,7 +2156,7 @@ public Task WriteToServerAsync(DataRow[] rows, CancellationToken cancellationTok
_rowSourceType = ValueSourceType.RowArray;
_rowEnumerator = rows.GetEnumerator();
_isAsyncBulkCopy = true;
- resultTask = WriteRowSourceToServerAsync(table.Columns.Count, cancellationToken); //It returns Task since _isAsyncBulkCopy = true;
+ resultTask = WriteRowSourceToServerAsync(table.Columns.Count, cancellationToken); //It returns Task since _isAsyncBulkCopy = true;
}
finally
{
@@ -2065,7 +2197,7 @@ public Task WriteToServerAsync(DbDataReader reader, CancellationToken cancellati
_dataTableSource = null;
_rowSourceType = ValueSourceType.DbDataReader;
_isAsyncBulkCopy = true;
- resultTask = WriteRowSourceToServerAsync(reader.FieldCount, cancellationToken); //It returns Task since _isAsyncBulkCopy = true;
+ resultTask = WriteRowSourceToServerAsync(reader.FieldCount, cancellationToken); //It returns Task since _isAsyncBulkCopy = true;
}
finally
{
@@ -2106,7 +2238,7 @@ public Task WriteToServerAsync(IDataReader reader, CancellationToken cancellatio
_dataTableSource = null;
_rowSourceType = ValueSourceType.IDataReader;
_isAsyncBulkCopy = true;
- resultTask = WriteRowSourceToServerAsync(reader.FieldCount, cancellationToken); //It returns Task since _isAsyncBulkCopy = true;
+ resultTask = WriteRowSourceToServerAsync(reader.FieldCount, cancellationToken); //It returns Task since _isAsyncBulkCopy = true;
}
finally
{
@@ -2160,7 +2292,7 @@ public Task WriteToServerAsync(DataTable table, DataRowState rowState, Cancellat
_rowSourceType = ValueSourceType.DataTable;
_rowEnumerator = table.Rows.GetEnumerator();
_isAsyncBulkCopy = true;
- resultTask = WriteRowSourceToServerAsync(table.Columns.Count, cancellationToken); //It returns Task since _isAsyncBulkCopy = true;
+ resultTask = WriteRowSourceToServerAsync(table.Columns.Count, cancellationToken); //It returns Task since _isAsyncBulkCopy = true;
}
finally
{
@@ -2169,7 +2301,7 @@ public Task WriteToServerAsync(DataTable table, DataRowState rowState, Cancellat
return resultTask;
}
- // Writes row source.
+ // Writes row source.
//
private Task WriteRowSourceToServerAsync(int columnCount, CancellationToken ctoken)
{
@@ -2427,34 +2559,40 @@ private bool FireRowsCopiedEvent(long rowsCopied)
return eventArgs.Abort;
}
- // Reads a cell and then writes it.
- // Read may block at this moment since there is no getValueAsync or DownStream async at this moment.
- // When _isAsyncBulkCopy == true: Write will return Task (when async method runs asynchronously) or Null (when async call actually ran synchronously) for performance.
- // When _isAsyncBulkCopy == false: Writes are purely sync. This method reutrn null at the end.
- //
- private Task ReadWriteColumnValueAsync(int col)
+ private Task WriteValueAsync(T value, int col, bool isSqlType, bool isDataFeed, bool isNull)
{
- bool isSqlType;
- bool isDataFeed;
- bool isNull;
- Object value = GetValueFromSourceRow(col, out isSqlType, out isDataFeed, out isNull); //this will return Task/null in future: as rTask
-
_SqlMetaData metadata = _sortedColumnMappings[col]._metadata;
- if (!isDataFeed)
+ if (isDataFeed)
{
- value = ConvertValue(value, metadata, isNull, ref isSqlType, out isDataFeed);
+ //nothing to convert, skip straight to write
+ return DoWriteValueAsync(value, col, isSqlType, isDataFeed, isNull, metadata);
+ }
+ else
+ {
+ return ConvertWriteValueAsync(value, col, metadata, isNull, isSqlType);
+ }
+ }
- // If column encryption is requested via connection string option, perform encryption here
- if (!isNull && // if value is not NULL
- metadata.isEncrypted)
- { // If we are transparently encrypting
- Debug.Assert(_parser.ShouldEncryptValuesForBulkCopy());
- value = _parser.EncryptColumnValue(value, metadata, metadata.column, _stateObj, isDataFeed, isSqlType);
- isSqlType = false; // Its not a sql type anymore
- }
+ private Task WriteConvertedValue(T value, int col, bool isSqlType, bool isNull, bool isDatafeed, _SqlMetaData metadata)
+ {
+ // If column encryption is requested via connection string option, perform encryption here
+ if (!isNull && // if value is not NULL
+ metadata.isEncrypted)
+ { // If we are transparently encrypting
+ Debug.Assert(_parser.ShouldEncryptValuesForBulkCopy());
+ var bytesValue = _parser.EncryptColumnValue(value, metadata, metadata.column, _stateObj, isDatafeed, isSqlType);
+ isSqlType = false; // Its not a sql type anymore
+
+ return DoWriteValueAsync(bytesValue, col, isSqlType, isDatafeed, isNull, metadata);
}
+ else
+ {
+ return DoWriteValueAsync(value, col, isSqlType, isDatafeed, isNull, metadata);
+ }
+ }
- //write part
+ private Task DoWriteValueAsync(T value, int col, bool isSqlType, bool isDataFeed, bool isNull, _SqlMetaData metadata)
+ {
Task writeTask = null;
if (metadata.type != SqlDbType.Variant)
{
@@ -2473,15 +2611,15 @@ private Task ReadWriteColumnValueAsync(int col)
if (variantInternalType == SqlBuffer.StorageType.DateTime2)
{
- _parser.WriteSqlVariantDateTime2(((DateTime)value), _stateObj);
+ _parser.WriteSqlVariantDateTime2(GenericConverter.Convert(value), _stateObj);
}
else if (variantInternalType == SqlBuffer.StorageType.Date)
{
- _parser.WriteSqlVariantDate(((DateTime)value), _stateObj);
+ _parser.WriteSqlVariantDate(GenericConverter.Convert(value), _stateObj);
}
else
{
- writeTask = _parser.WriteSqlVariantDataRowValue(value, _stateObj); //returns Task/Null
+ writeTask = _parser.WriteSqlVariantDataRowValue(value, isNull, _stateObj); //returns Task/Null
}
}
@@ -2563,7 +2701,7 @@ private void CopyColumnsAsyncSetupContinuation(TaskCompletionSource
+
diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/DataConversionErrorMessageTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/DataConversionErrorMessageTest.cs
index 9b56183ccf..1bcc1fa83d 100644
--- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/DataConversionErrorMessageTest.cs
+++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/DataConversionErrorMessageTest.cs
@@ -159,7 +159,7 @@ private bool StringToIntTest(SqlConnection cnn, string targetTable, SourceType s
string expectedErrorMsg = string.Format(pattern, args);
- Assert.True(ex.Message.Contains(expectedErrorMsg), "Unexpected error message: " + ex.Message);
+ Assert.True(ex.Message.Contains(expectedErrorMsg), $"Unexpected error message: {ex}"); // write out stack trace for unexpected messages
hitException = true;
}
return hitException;
diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValueTypes.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValueTypes.cs
new file mode 100644
index 0000000000..8b1f619f15
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValueTypes.cs
@@ -0,0 +1,429 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Data;
+using System.Linq;
+using System.Linq.Expressions;
+using BenchmarkDotNet.Attributes;
+using BenchmarkDotNet.Configs;
+using BenchmarkDotNet.Diagnosers;
+using BenchmarkDotNet.Jobs;
+using BenchmarkDotNet.Running;
+using BenchmarkDotNet.Validators;
+using Xunit;
+
+namespace Microsoft.Data.SqlClient.ManualTesting.Tests
+{
+ public class NoBoxingValueTypes : IDisposable
+ {
+ private static readonly string _table = DataTestUtility.GetUniqueNameForSqlServer(nameof(NoBoxingValueTypes));
+ private const int _count = 5000;
+ private static readonly ItemToCopy _item;
+ private static readonly IEnumerable _items;
+ private static readonly IDataReader _reader;
+
+ private static readonly string _connString = DataTestUtility.TCPConnectionString;
+
+ private class ItemToCopy
+ {
+ // keeping this data static so the performance of the benchmark is not varied by the data size & shape
+ public int IntColumn { get; } = 123456;
+ public bool BoolColumn { get; } = true;
+ }
+
+ static NoBoxingValueTypes()
+ {
+ _item = new ItemToCopy();
+
+ _items = Enumerable.Range(0, _count).Select(x => _item).ToArray();
+
+ _reader = new EnumerableDataReaderFactoryBuilder(_table)
+ .Add("IntColumn", i => i.IntColumn)
+ .Add("BoolColumn", i => i.BoolColumn)
+ .BuildFactory()
+ .CreateReader(_items)
+ ;
+ }
+
+ public NoBoxingValueTypes()
+ {
+ using (var conn = new SqlConnection(_connString))
+ using (var cmd = conn.CreateCommand())
+ {
+ conn.Open();
+ Helpers.TryExecute(cmd, $@"
+ CREATE TABLE {_table} (
+ IntColumn INT NOT NULL,
+ BoolColumn BIT NOT NULL
+ )
+ ");
+ }
+ }
+
+ private class RunOnceConfig : ManualConfig
+ {
+ public RunOnceConfig()
+ {
+ Add(Job.InProcess.WithLaunchCount(1).WithIterationCount(1).WithWarmupCount(0));
+ Add(MemoryDiagnoser.Default);
+
+ Add(JitOptimizationsValidator.DontFailOnError);
+ }
+ }
+
+
+
+ [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureServer))]
+ public void Should_Not_Box()
+ { // in debug mode, the double boxing DOES occur as the JIT optimizes less code, which causes the test to fail
+#if DEBUG
+ return;
+#else
+ //cannot figure out an easy way to get this to work on all platforms
+
+ var config = new RunOnceConfig(); // cannot use fluent syntax to still support net461
+
+ var summary = BenchmarkRunner.Run(config);
+
+ var numValueTypeColumns = 2;
+ var totalBytesWhenBoxed = IntPtr.Size * _count * numValueTypeColumns;
+
+ var report = summary.Reports.First();
+
+ Assert.Equal(1, report.AllMeasurements.Count);
+ Assert.True(report.GcStats.BytesAllocatedPerOperation < totalBytesWhenBoxed);
+#endif
+ }
+
+ public class NoBoxingValueTypesBenchmark
+ {
+ [Benchmark]
+ public void BulkCopy()
+ {
+ _reader.Close(); // this resets the reader
+
+ using (var bc = new SqlBulkCopy(DataTestUtility.TCPConnectionString, SqlBulkCopyOptions.TableLock))
+ {
+ bc.BatchSize = _count;
+ bc.DestinationTableName = _table;
+ bc.BulkCopyTimeout = 60;
+
+ bc.WriteToServer(_reader);
+ }
+ }
+ }
+
+ public void Dispose()
+ {
+ using (var conn = new SqlConnection(_connString))
+ using (var cmd = conn.CreateCommand())
+ {
+ conn.Open();
+ Helpers.TryExecute(cmd, $@"
+ DROP TABLE IF EXISTS {_table}
+ ");
+ }
+ }
+
+ //all code here and below is a custom data reader implementation to support the benchmark
+ private class EnumerableDataReaderFactoryBuilder
+ {
+ private readonly List _expressions = new List();
+ private readonly List> _objExpressions = new List>();
+ private readonly DataTable _schemaTable;
+
+ public EnumerableDataReaderFactoryBuilder(string tableName)
+ {
+ Name = tableName;
+ _schemaTable = new DataTable();
+ }
+
+ private static readonly HashSet _validTypes = new HashSet
+ {
+ typeof(decimal),
+ typeof(decimal?),
+ typeof(string),
+ typeof(int),
+ typeof(int?),
+ typeof(double),
+ typeof(bool),
+ typeof(bool?),
+ typeof(Guid),
+ typeof(DateTime),
+ };
+
+ public EnumerableDataReaderFactoryBuilder Add(string column, Expression> expression)
+ {
+ var t = typeof(TColumn);
+
+ var func = expression.Compile();
+
+ // don't do any optimizations for boxing bools here to detect boxing occurring properly.
+ Expression> objExpression= o => func(o);
+
+ _objExpressions.Add(objExpression.Compile());
+
+ if (_validTypes.Contains(t))
+ {
+ t = Nullable.GetUnderlyingType(t) ?? t; // data table doesn't accept nullable.
+ _schemaTable.Columns.Add(column, t);
+ _expressions.Add(expression);
+ }
+ else
+ {
+ Console.WriteLine($"Could not matching return type for {Name}.{column} of: {t.Name}");
+ _schemaTable.Columns.Add(column); //add w/o type to force using GetValue
+
+ _expressions.Add(objExpression);
+ }
+
+ return this;
+ }
+
+ public EnumerableDataReaderFactory BuildFactory() => new EnumerableDataReaderFactory(_schemaTable, _expressions, _objExpressions);
+
+ public string Name { get; }
+ }
+
+ public class EnumerableDataReaderFactory
+ {
+ public DataTable SchemaTable { get; }
+ public Func[] ObjectGetters { get; }
+ public Func[] DecimalGetters { get; }
+ public Func[] NullableDecimalGetters { get; }
+ public Func[] StringGetters { get; }
+ public Func[] DoubleGetters { get; }
+ public Func[] IntGetters { get; }
+ public Func[] NullableIntGetters { get; }
+ public Func[] BoolGetters { get; }
+
+ public Func[] NullableBoolGetters { get; }
+
+ public Func[] GuidGetters { get; }
+ public Func[] DateTimeGetters { get; }
+ public bool[] NullableIndexes { get; }
+
+ public EnumerableDataReaderFactory(DataTable schemaTable, List expressions, List> objectGetters)
+ {
+ SchemaTable = schemaTable;
+ DecimalGetters = new Func[expressions.Count];
+ NullableDecimalGetters = new Func[expressions.Count];
+ StringGetters = new Func[expressions.Count];
+ DoubleGetters = new Func[expressions.Count];
+ IntGetters = new Func[expressions.Count];
+ NullableIntGetters = new Func[expressions.Count];
+ BoolGetters = new Func[expressions.Count];
+ NullableBoolGetters = new Func[expressions.Count];
+ GuidGetters = new Func[expressions.Count];
+ DateTimeGetters = new Func[expressions.Count];
+ NullableIndexes = new bool[expressions.Count];
+
+ ObjectGetters = objectGetters.ToArray();
+
+ for (int i = 0; i < expressions.Count; i++)
+ {
+ var expression = expressions[i];
+
+ NullableIndexes[i] = !expression.ReturnType.IsValueType || Nullable.GetUnderlyingType(expression.ReturnType) != null;
+
+ switch (expression)
+ {
+ case Expression> e:
+ break; // do nothing
+ case Expression> e:
+ DecimalGetters[i] = e.Compile();
+ break;
+ case Expression> e:
+ NullableDecimalGetters[i] = e.Compile();
+ break;
+ case Expression> e:
+ StringGetters[i] = e.Compile();
+ break;
+ case Expression> e:
+ DoubleGetters[i] = e.Compile();
+ break;
+ case Expression> e:
+ IntGetters[i] = e.Compile();
+ break;
+ case Expression> e:
+ NullableIntGetters[i] = e.Compile();
+ break;
+ case Expression> e:
+ BoolGetters[i] = e.Compile();
+ break;
+ case Expression> e:
+ NullableBoolGetters[i] = e.Compile();
+ break;
+ case Expression> e:
+ GuidGetters[i] = e.Compile();
+ break;
+ case Expression> e:
+ DateTimeGetters[i] = e.Compile();
+ break;
+ default:
+ throw new Exception($"Type missing: {expression.GetType().FullName}");
+ }
+ }
+ }
+
+ public IDataReader CreateReader(IEnumerable items) => new EnumerableDataReader(this, items.GetEnumerator());
+ }
+
+ public class EnumerableDataReader : IDataReader
+ {
+ private readonly IEnumerator _source;
+ private readonly EnumerableDataReaderFactory _context;
+
+ public EnumerableDataReader(EnumerableDataReaderFactory context, IEnumerator source)
+ {
+ _source = source;
+ _context = context;
+ }
+
+ public object GetValue(int i)
+ {
+ var v = _context.ObjectGetters[i](_source.Current);
+ return v;
+ }
+
+ public int FieldCount => _context.ObjectGetters.Length;
+
+ public bool Read() => _source.MoveNext();
+
+ public void Close() => _source.Reset();
+
+ public void Dispose() => this.Close();
+
+ public bool NextResult() => throw new NotImplementedException();
+
+ public int Depth => 0;
+
+ public bool IsClosed => false;
+
+ public int RecordsAffected => -1;
+
+ public DataTable GetSchemaTable() => _context.SchemaTable;
+
+ public object this[string name] => throw new NotImplementedException();
+
+ public object this[int i] => GetValue(i);
+
+ public bool GetBoolean(int i)
+ {
+ var g = _context.BoolGetters[i];
+
+ if (g != null)
+ return g(_source.Current);
+
+ return _context.NullableBoolGetters[i](_source.Current).Value;
+ }
+
+ public byte GetByte(int i) => throw new NotImplementedException();
+
+ public long GetBytes(int i, long fieldOffset, byte[] buffer, int bufferoffset, int length) => throw new NotImplementedException();
+
+ public char GetChar(int i) => throw new NotImplementedException();
+ public long GetChars(int i, long fieldoffset, char[] buffer, int bufferoffset, int length) => -1;
+
+ public IDataReader GetData(int i) => throw new NotImplementedException();
+
+ public string GetDataTypeName(int i) => throw new NotImplementedException();
+
+ public DateTime GetDateTime(int i) => _context.DateTimeGetters[i](_source.Current);
+
+ public decimal GetDecimal(int i)
+ {
+ var g = _context.DecimalGetters[i];
+
+ if (g != null)
+ return g(_source.Current);
+
+ return _context.NullableDecimalGetters[i](_source.Current).Value;
+ }
+
+ public double GetDouble(int i) => _context.DoubleGetters[i](_source.Current);
+
+ public Type GetFieldType(int i) => _context.SchemaTable.Columns[i].DataType;
+
+ public float GetFloat(int i) => throw new NotImplementedException();
+
+ public Guid GetGuid(int i) => _context.GuidGetters[i](_source.Current);
+
+ public short GetInt16(int i) => throw new NotImplementedException();
+
+ public int GetInt32(int i)
+ {
+ var g = _context.IntGetters[i];
+
+ if (g != null)
+ return g(_source.Current);
+
+ return _context.NullableIntGetters[i](_source.Current).Value;
+ }
+
+ public long GetInt64(int i) => throw new NotImplementedException();
+
+ public string GetName(int i)
+ {
+ if (_context.SchemaTable.Columns.Count > i)
+ {
+ return _context.SchemaTable.Columns[i].ColumnName;
+ }
+ throw new IndexOutOfRangeException($"No column for index {i}");
+ }
+
+ public int GetOrdinal(string name)
+ {
+ if (_context.SchemaTable.Columns.Count == 0)
+ {
+ throw new Exception("Schema table is empty");
+ }
+ return _context.SchemaTable.Columns.IndexOf(name);
+ }
+
+ public string GetString(int i) => _context.StringGetters[i](_source.Current);
+
+ public int GetValues(object[] values) => throw new NotImplementedException();
+
+ public bool IsDBNull(int i)
+ {
+ // short circuit for non-nullable types
+ if (!_context.NullableIndexes[i])
+ {
+ return false;
+ }
+
+ // otherwise find the first one -- starting w/ most occurring to least
+
+ var ig = _context.NullableIntGetters[i];
+ if (ig != null)
+ {
+ return ig(_source.Current) == null;
+ }
+
+ var sg = _context.StringGetters[i];
+ if (sg != null)
+ {
+ return sg(_source.Current) == null;
+ }
+
+ var bg = _context.NullableBoolGetters[i];
+ if (bg != null)
+ {
+ return bg(_source.Current) == null;
+ }
+
+ var dg = _context.NullableDecimalGetters[i];
+ if (dg != null)
+ {
+ return dg(_source.Current) == null;
+ }
+
+ return false;
+ }
+ }
+ }
+}
diff --git a/tools/props/Versions.props b/tools/props/Versions.props
index 044d79f9d2..e9f4b9c0b9 100644
--- a/tools/props/Versions.props
+++ b/tools/props/Versions.props
@@ -12,7 +12,7 @@
- 2.1.0
+ 2.1.0
4.3.1
4.3.0
@@ -26,6 +26,7 @@
+ 0.11.3
4.7.0
2.1.0
4.7.0