diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index ff1db9b5e6..5a9d974af6 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -709,6 +709,7 @@ + True True diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/GenericConverter.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/GenericConverter.cs new file mode 100644 index 0000000000..d4298afa2d --- /dev/null +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/GenericConverter.cs @@ -0,0 +1,17 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.Data.SqlClient +{ + /// + /// Serves to convert generic to out type by casting to object first. Relies on JIT to optimize out unneccessary casts and prevent double boxing. + /// + internal static class GenericConverter + { + public static TOut Convert(TIn value) + { + return (TOut)(object)value; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index 8b8f689d0c..4e87235780 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -136,8 +136,8 @@ private enum ValueSourceType DbDataReader } - // Enum for specifying SqlDataReader.Get method used - private enum ValueMethod : byte + // Enum for specifying SqlDataReader.Get / IDataReader Get method used + private enum ValueMethod { GetValue, SqlTypeSqlDecimal, @@ -145,7 +145,19 @@ private enum ValueMethod : byte SqlTypeSqlSingle, DataFeedStream, DataFeedText, - DataFeedXml + DataFeedXml, + GetInt32, + GetString, + GetDouble, + GetDecimal, + GetInt16, + GetInt64, + GetChar, + GetByte, + GetBoolean, + GetDateTime, + GetGuid, + GetFloat } // Used to hold column metadata for SqlDataReader case @@ -943,12 +955,19 @@ private void Dispose(bool disposing) } } - // Unified method to read a value from the current row - private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out bool isDataFeed, out bool isNull) + // Reads a cell and then writes it. + // Read may block at this moment since there is no getValueAsync or DownStream async at this moment. + // When _isAsyncBulkCopy == true: Write will return Task (when async method runs asynchronously) or Null (when async call actually ran synchronously) for performance. + // When _isAsyncBulkCopy == false: Writes are purely sync. This method return null at the end. + private Task ReadWriteColumnValueAsync(int destRowIndex) { _SqlMetaData metadata = _sortedColumnMappings[destRowIndex]._metadata; int sourceOrdinal = _sortedColumnMappings[destRowIndex]._sourceColumnOrdinal; + bool isSqlType = false; + bool isDataFeed = false; + bool isNull = false; + switch (_rowSourceType) { case ValueSourceType.IDataReader: @@ -958,34 +977,39 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b { if (_DbDataReaderRowSource.IsDBNull(sourceOrdinal)) { - isSqlType = false; - isDataFeed = false; isNull = true; - return DBNull.Value; + return WriteValueAsync(DBNull.Value, destRowIndex, isSqlType, isDataFeed, isNull); } else { - isSqlType = false; isDataFeed = true; - isNull = false; + + object feedColumnValue; + switch (_currentRowMetadata[destRowIndex].Method) { case ValueMethod.DataFeedStream: - return new StreamDataFeed(_DbDataReaderRowSource.GetStream(sourceOrdinal)); + feedColumnValue = new StreamDataFeed(_DbDataReaderRowSource.GetStream(sourceOrdinal)); + break; case ValueMethod.DataFeedText: - return new TextDataFeed(_DbDataReaderRowSource.GetTextReader(sourceOrdinal)); + feedColumnValue = new TextDataFeed(_DbDataReaderRowSource.GetTextReader(sourceOrdinal)); + break; case ValueMethod.DataFeedXml: // Only SqlDataReader supports an XmlReader // There is no GetXmlReader on DbDataReader, however if GetValue returns XmlReader we will read it as stream if it is assigned to XML field Debug.Assert(_SqlDataReaderRowSource != null, "Should not be reading row as an XmlReader if bulk copy source is not a SqlDataReader"); - return new XmlDataFeed(_SqlDataReaderRowSource.GetXmlReader(sourceOrdinal)); + feedColumnValue = new XmlDataFeed(_SqlDataReaderRowSource.GetXmlReader(sourceOrdinal)); + break; default: Debug.Fail($"Current column is marked as being a DataFeed, but no DataFeed compatible method was provided. Method: {_currentRowMetadata[destRowIndex].Method}"); isDataFeed = false; - object columnValue = _DbDataReaderRowSource.GetValue(sourceOrdinal); - ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType); - return columnValue; + feedColumnValue = _DbDataReaderRowSource.GetValue(sourceOrdinal); + ADP.IsNullOrSqlType(feedColumnValue, out isNull, out isSqlType); + break; } + + //specifically choosing to use the object overload here to simplify TdsParser logic for the XmlReader scenario + return WriteValueAsync(feedColumnValue, destRowIndex, isSqlType, isDataFeed, isNull); } } // SqlDataReader-specific logic @@ -993,36 +1017,28 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b { if (_currentRowMetadata[destRowIndex].IsSqlType) { - INullable value; isSqlType = true; - isDataFeed = false; switch (_currentRowMetadata[destRowIndex].Method) { case ValueMethod.SqlTypeSqlDecimal: - value = _SqlDataReaderRowSource.GetSqlDecimal(sourceOrdinal); - break; + var value = _SqlDataReaderRowSource.GetSqlDecimal(sourceOrdinal); + return WriteValueAsync(value, destRowIndex, isSqlType, isDataFeed, value.IsNull); case ValueMethod.SqlTypeSqlDouble: // use cast to handle IsNull correctly because no public constructor allows it - value = (SqlDecimal)_SqlDataReaderRowSource.GetSqlDouble(sourceOrdinal); - break; + var dblValue = (SqlDecimal)_SqlDataReaderRowSource.GetSqlDouble(sourceOrdinal); + return WriteValueAsync(dblValue, destRowIndex, isSqlType, isDataFeed, dblValue.IsNull); case ValueMethod.SqlTypeSqlSingle: - // use cast to handle IsNull correctly because no public constructor allows it - value = (SqlDecimal)_SqlDataReaderRowSource.GetSqlSingle(sourceOrdinal); - break; + // use cast to handle value.IsNull correctly because no public constructor allows it + var singleValue = (SqlDecimal)_SqlDataReaderRowSource.GetSqlSingle(sourceOrdinal); + return WriteValueAsync(singleValue, destRowIndex, isSqlType, isDataFeed, singleValue.IsNull); default: Debug.Fail($"Current column is marked as being a SqlType, but no SqlType compatible method was provided. Method: {_currentRowMetadata[destRowIndex].Method}"); - value = (INullable)_SqlDataReaderRowSource.GetSqlValue(sourceOrdinal); - break; + var sqlValue = (INullable)_SqlDataReaderRowSource.GetSqlValue(sourceOrdinal); + return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, sqlValue.IsNull); } - - isNull = value.IsNull; - return value; } else { - isSqlType = false; - isDataFeed = false; - object value = _SqlDataReaderRowSource.GetValue(sourceOrdinal); isNull = ((value == null) || (value == DBNull.Value)); if ((!isNull) && (metadata.type == SqlDbType.Udt)) @@ -1036,27 +1052,63 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b Debug.Assert(!(value is INullable) || !((INullable)value).IsNull, "IsDBNull returned false, but GetValue returned a null INullable"); } #endif - return value; + return WriteValueAsync(value, destRowIndex, isSqlType, isDataFeed, isNull); } } else { - isDataFeed = false; - IDataReader rowSourceAsIDataReader = (IDataReader)_rowSource; // Only use IsDbNull when streaming is enabled and only for non-SqlDataReader if ((_enableStreaming) && (_SqlDataReaderRowSource == null) && (rowSourceAsIDataReader.IsDBNull(sourceOrdinal))) { - isSqlType = false; isNull = true; - return DBNull.Value; + return WriteValueAsync(DBNull.Value, destRowIndex, isSqlType, isDataFeed, isNull); } else { - object columnValue = rowSourceAsIDataReader.GetValue(sourceOrdinal); - ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType); - return columnValue; + if (_currentRowMetadata[destRowIndex].Method == ValueMethod.GetValue || rowSourceAsIDataReader.IsDBNull(sourceOrdinal)) + { + object columnValue = rowSourceAsIDataReader.GetValue(sourceOrdinal); + ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType); + return WriteValueAsync(columnValue, destRowIndex, isSqlType, isDataFeed, isNull); + } + else + { + switch (_currentRowMetadata[sourceOrdinal].Method) + { + case ValueMethod.GetInt32: + return WriteValueAsync(rowSourceAsIDataReader.GetInt32(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, false); + case ValueMethod.GetString: + var strValue = rowSourceAsIDataReader.GetString(sourceOrdinal); + isNull = strValue == null; + return WriteValueAsync(strValue, destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetDouble: + return WriteValueAsync(rowSourceAsIDataReader.GetDouble(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetDecimal: + return WriteValueAsync(rowSourceAsIDataReader.GetDecimal(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetInt16: + return WriteValueAsync(rowSourceAsIDataReader.GetInt16(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetInt64: + return WriteValueAsync(rowSourceAsIDataReader.GetInt64(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetChar: + return WriteValueAsync(rowSourceAsIDataReader.GetChar(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetByte: + return WriteValueAsync(rowSourceAsIDataReader.GetByte(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetBoolean: + return WriteValueAsync(rowSourceAsIDataReader.GetBoolean(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetDateTime: + return WriteValueAsync(rowSourceAsIDataReader.GetDateTime(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetGuid: + return WriteValueAsync(rowSourceAsIDataReader.GetGuid(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetFloat: + return WriteValueAsync(rowSourceAsIDataReader.GetFloat(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + default: + object columnValue = rowSourceAsIDataReader.GetValue(sourceOrdinal); + ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType); + return WriteValueAsync(columnValue, destRowIndex, isSqlType, isDataFeed, isNull); + } + } } } case ValueSourceType.DataTable: @@ -1066,6 +1118,7 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b Debug.Assert(sourceOrdinal < _currentRowLength, "inconsistency of length of rows from rowsource!"); isDataFeed = false; + // unfortunately this has to be boxed due to DataRow's API. object currentRowValue = _currentRow[sourceOrdinal]; ADP.IsNullOrSqlType(currentRowValue, out isNull, out isSqlType); @@ -1078,7 +1131,8 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b { if (isSqlType) { - return new SqlDecimal(((SqlSingle)currentRowValue).Value); + var sqlDec = new SqlDecimal(((SqlSingle)currentRowValue).Value); + return WriteValueAsync(sqlDec, destRowIndex, isSqlType, isDataFeed, isNull); } else { @@ -1086,16 +1140,20 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b if (!float.IsNaN(f)) { isSqlType = true; - return new SqlDecimal(f); + return WriteValueAsync(new SqlDecimal(f), destRowIndex, isSqlType, isDataFeed, isNull); + } + else + { + return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull); } - break; } } case ValueMethod.SqlTypeSqlDouble: { if (isSqlType) { - return new SqlDecimal(((SqlDouble)currentRowValue).Value); + var sqlValue = new SqlDecimal(((SqlDouble)currentRowValue).Value); + return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, isNull); } else { @@ -1103,33 +1161,40 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b if (!double.IsNaN(d)) { isSqlType = true; - return new SqlDecimal(d); + return WriteValueAsync(new SqlDecimal(d), destRowIndex, isSqlType, isDataFeed, isNull); + } + else + { + return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull); } - break; } } case ValueMethod.SqlTypeSqlDecimal: { if (isSqlType) { - return (SqlDecimal)currentRowValue; + var sqlValue = (SqlDecimal)currentRowValue; + return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, isNull); } else { isSqlType = true; - return new SqlDecimal((decimal)currentRowValue); + var sqlValue = new SqlDecimal((decimal)currentRowValue); + return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, isNull); } } default: { Debug.Fail($"Current column is marked as being a SqlType, but no SqlType compatible method was provided. Method: {_currentRowMetadata[destRowIndex].Method}"); - break; + // If we are here then either the value is null, there was no special storage type for this column or the special storage type wasn't handled (e.g. if the currentRowValue is NaN) + return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull); } } } - - // If we are here then either the value is null, there was no special storage type for this column or the special storage type wasn't handled (e.g. if the currentRowValue is NaN) - return currentRowValue; + else + { + return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull); + } } default: { @@ -1320,6 +1385,66 @@ private SourceColumnMetadata GetColumnMetadata(int ordinal) method = ValueMethod.GetValue; } } + else if (_rowSourceType == ValueSourceType.IDataReader) + { + isSqlType = false; + isDataFeed = false; + + Type t = ((IDataReader)_rowSource).GetFieldType(ordinal); + + if (t == typeof(bool)) + { + method = ValueMethod.GetBoolean; + } + else if (t == typeof(byte)) + { + method = ValueMethod.GetByte; + } + else if (t == typeof(char)) + { + method = ValueMethod.GetChar; + } + else if (t == typeof(DateTime)) + { + method = ValueMethod.GetDateTime; + } + else if (t == typeof(decimal)) + { + method = ValueMethod.GetDecimal; + } + else if (t == typeof(double)) + { + method = ValueMethod.GetDouble; + } + else if (t == typeof(float)) + { + method = ValueMethod.GetFloat; + } + else if (t == typeof(Guid)) + { + method = ValueMethod.GetGuid; + } + else if (t == typeof(short)) + { + method = ValueMethod.GetInt16; + } + else if (t == typeof(int)) + { + method = ValueMethod.GetInt32; + } + else if (t == typeof(long)) + { + method = ValueMethod.GetInt64; + } + else if (t == typeof(string)) + { + method = ValueMethod.GetString; + } + else + { + method = ValueMethod.GetValue; + } + } else { isSqlType = false; @@ -1454,8 +1579,10 @@ private string UnquotedName(string name) return name; } - private object ValidateBulkCopyVariant(object value) + private bool ValidateBulkCopyVariantIfNeeded(T value, out object variantValue) { + variantValue = null; + // From the spec: // "The only acceptable types are ..." // GUID, BIGVARBINARY, BIGBINARY, BIGVARCHAR, BIGCHAR, NVARCHAR, NCHAR, BIT, INT1, INT2, INT4, INT8, @@ -1483,20 +1610,21 @@ private object ValidateBulkCopyVariant(object value) case TdsEnums.SQLDATETIMEOFFSET: if (value is INullable) { // Current limitation in the SqlBulkCopy Variant code limits BulkCopy to CLR/COM Types. - return MetaType.GetComValueFromSqlVariant(value); + variantValue = MetaType.GetComValueFromSqlVariant(value); + return true; } else { - return value; + return false; } default: throw SQL.BulkLoadInvalidVariantValue(); } } - private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, ref bool isSqlType, out bool coercedToDataFeed) + private Task ConvertWriteValueAsync(T value, int col, _SqlMetaData metadata, bool isNull, bool isSqlType) { - coercedToDataFeed = false; + bool coercedToDataFeed = false; if (isNull) { @@ -1504,11 +1632,13 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re { throw SQL.BulkLoadBulkLoadNotAllowDBNull(metadata.column); } - return value; + + return DoWriteValueAsync(value, col, isSqlType, coercedToDataFeed, isNull, metadata); } MetaType type = metadata.metaType; bool typeChanged = false; + object objValue = null; // If the column is encrypted then we are going to transparently encrypt this column // (based on connection string setting)- Use the metaType for the underlying @@ -1533,46 +1663,50 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re { case TdsEnums.SQLNUMERICN: case TdsEnums.SQLDECIMALN: - mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); - value = SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false); - - // Convert Source Decimal Precision and Scale to Destination Precision and Scale - // Sql decimal data could get corrupted on insert if the scale of - // the source and destination weren't the same. The BCP protocol, specifies the - // scale of the incoming data in the insert statement, we just tell the server we - // are inserting the same scale back. - SqlDecimal sqlValue; - if ((isSqlType) && (!typeChanged)) + SqlDecimal decValue; + if (typeof(T) == typeof(decimal)) + { + decValue = new SqlDecimal(GenericConverter.Convert(value)); + } + else if (typeof(T) == typeof(SqlDecimal)) { - sqlValue = (SqlDecimal)value; + decValue = GenericConverter.Convert(value); } else { - sqlValue = new SqlDecimal((decimal)value); + mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); + decValue = new SqlDecimal((decimal)SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false)); } - if (sqlValue.Scale != scale) + // Convert Source Decimal Precision and Scale to Destination Precision and Scale + // Sql decimal data could get corrupted on insert if the scale of + // the source and destination weren't the same. The BCP protocol, specifies the + // scale of the incoming data in the insert statement, we just tell the server we + // are inserting the same scale back. + if (decValue.Scale != scale) { - sqlValue = TdsParser.AdjustSqlDecimalScale(sqlValue, scale); + decValue = TdsParser.AdjustSqlDecimalScale(decValue, scale); } - if (sqlValue.Precision > precision) + if (decValue.Precision > precision) { try { - sqlValue = SqlDecimal.ConvertToPrecScale(sqlValue, precision, sqlValue.Scale); + decValue = SqlDecimal.ConvertToPrecScale(decValue, precision, decValue.Scale); } catch (SqlTruncateException) { - throw SQL.BulkLoadCannotConvertValue(value.GetType(), mt, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), ADP.ParameterValueOutOfRange(sqlValue)); + mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); + throw SQL.BulkLoadCannotConvertValue(value.GetType(), mt, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), ADP.ParameterValueOutOfRange(decValue)); } } // Perf: It is more efficient to write a SqlDecimal than a decimal since we need to break it into its 'bits' when writing - value = sqlValue; isSqlType = true; typeChanged = false; // Setting this to false as SqlParameter.CoerceValue will only set it to true when converting to a CLR type - break; + + // returning here to avoid unnecessary decValue initialization for all types + return WriteConvertedValue(decValue, col, isSqlType, isNull, coercedToDataFeed, metadata); case TdsEnums.SQLINTN: case TdsEnums.SQLFLTN: @@ -1596,16 +1730,22 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re case TdsEnums.SQLDATETIME2: case TdsEnums.SQLDATETIMEOFFSET: mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); - value = SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false); + typeChanged = SqlParameter.CoerceValueIfNeeded(value, mt, out objValue, out coercedToDataFeed); break; case TdsEnums.SQLNCHAR: case TdsEnums.SQLNVARCHAR: case TdsEnums.SQLNTEXT: mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); - value = SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false); + typeChanged = SqlParameter.CoerceValueIfNeeded(value, mt, out objValue, out coercedToDataFeed, false); if (!coercedToDataFeed) { // We do not need to test for TextDataFeed as it is only assigned to (N)VARCHAR(MAX) - string str = ((isSqlType) && (!typeChanged)) ? ((SqlString)value).Value : ((string)value); + string str = typeChanged + ? (string)objValue + : isSqlType + ? GenericConverter.Convert(value).Value + : GenericConverter.Convert(value) + ; + int maxStringLength = length / 2; if (str.Length > maxStringLength) { @@ -1622,10 +1762,10 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re throw SQL.BulkLoadStringTooLong(_destinationTableName, metadata.column, str); } } + break; case TdsEnums.SQLVARIANT: - value = ValidateBulkCopyVariant(value); - typeChanged = true; + typeChanged = ValidateBulkCopyVariantIfNeeded(value, out objValue); break; case TdsEnums.SQLUDT: // UDTs are sent as varbinary so we need to get the raw bytes @@ -1636,16 +1776,16 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re // in byte[] form. if (!(value is byte[])) { - value = _connection.GetBytes(value); + objValue = _connection.GetBytes(value); typeChanged = true; } break; case TdsEnums.SQLXMLTYPE: // Could be either string, SqlCachedBuffer, XmlReader or XmlDataFeed Debug.Assert((value is XmlReader) || (value is SqlCachedBuffer) || (value is string) || (value is SqlString) || (value is XmlDataFeed), "Invalid value type of Xml datatype"); - if (value is XmlReader) + if (value is XmlReader xmlReader) { - value = new XmlDataFeed((XmlReader)value); + objValue = new XmlDataFeed(xmlReader); typeChanged = true; coercedToDataFeed = true; } @@ -1655,14 +1795,6 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re Debug.Fail("Unknown TdsType!" + type.NullableType.ToString("x2", (IFormatProvider)null)); throw SQL.BulkLoadCannotConvertValue(value.GetType(), type, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), null); } - - if (typeChanged) - { - // All type changes change to CLR types - isSqlType = false; - } - - return value; } catch (Exception e) { @@ -1672,6 +1804,17 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re } throw SQL.BulkLoadCannotConvertValue(value.GetType(), type, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), e); } + + if (typeChanged) + { + // All type changes change to CLR types + isSqlType = false; + return WriteConvertedValue(objValue, col, isSqlType, isNull, coercedToDataFeed, metadata); + } + else + { + return WriteConvertedValue(value, col, isSqlType, isNull, coercedToDataFeed, metadata); + } } /// @@ -2205,33 +2348,40 @@ private bool FireRowsCopiedEvent(long rowsCopied) return eventArgs.Abort; } - // Reads a cell and then writes it. - // Read may block at this moment since there is no getValueAsync or DownStream async at this moment. - // When _isAsyncBulkCopy == true: Write will return Task (when async method runs asynchronously) or Null (when async call actually ran synchronously) for performance. - // When _isAsyncBulkCopy == false: Writes are purely sync. This method return null at the end. - private Task ReadWriteColumnValueAsync(int col) + private Task WriteValueAsync(T value, int col, bool isSqlType, bool isDataFeed, bool isNull) { - bool isSqlType; - bool isDataFeed; - bool isNull; - object value = GetValueFromSourceRow(col, out isSqlType, out isDataFeed, out isNull); //this will return Task/null in future: as rTask - _SqlMetaData metadata = _sortedColumnMappings[col]._metadata; - if (!isDataFeed) + if (isDataFeed) + { + //nothing to convert, skip straight to write + return DoWriteValueAsync(value, col, isSqlType, isDataFeed, isNull, metadata); + } + else { - value = ConvertValue(value, metadata, isNull, ref isSqlType, out isDataFeed); + return ConvertWriteValueAsync(value, col, metadata, isNull, isSqlType); + } + } - // If column encryption is requested via connection string option, perform encryption here - if (!isNull && // if value is not NULL - metadata.isEncrypted) - { // If we are transparently encrypting - Debug.Assert(_parser.ShouldEncryptValuesForBulkCopy()); - value = _parser.EncryptColumnValue(value, metadata, metadata.column, _stateObj, isDataFeed, isSqlType); - isSqlType = false; // Its not a sql type anymore - } + private Task WriteConvertedValue(T value, int col, bool isSqlType, bool isNull, bool isDatafeed, _SqlMetaData metadata) + { + // If column encryption is requested via connection string option, perform encryption here + if (!isNull && // if value is not NULL + metadata.isEncrypted) + { // If we are transparently encrypting + Debug.Assert(_parser.ShouldEncryptValuesForBulkCopy()); + var bytesValue = _parser.EncryptColumnValue(value, metadata, metadata.column, _stateObj, isDatafeed, isSqlType); + isSqlType = false; // Its not a sql type anymore + + return DoWriteValueAsync(bytesValue, col, isSqlType, isDatafeed, isNull, metadata); } + else + { + return DoWriteValueAsync(value, col, isSqlType, isDatafeed, isNull, metadata); + } + } - //write part + private Task DoWriteValueAsync(T value, int col, bool isSqlType, bool isDataFeed, bool isNull, _SqlMetaData metadata) + { Task writeTask = null; if (metadata.type != SqlDbType.Variant) { @@ -2250,15 +2400,15 @@ private Task ReadWriteColumnValueAsync(int col) if (variantInternalType == SqlBuffer.StorageType.DateTime2) { - _parser.WriteSqlVariantDateTime2(((DateTime)value), _stateObj); + _parser.WriteSqlVariantDateTime2(GenericConverter.Convert(value), _stateObj); } else if (variantInternalType == SqlBuffer.StorageType.Date) { - _parser.WriteSqlVariantDate(((DateTime)value), _stateObj); + _parser.WriteSqlVariantDate(GenericConverter.Convert(value), _stateObj); } else { - writeTask = _parser.WriteSqlVariantDataRowValue(value, _stateObj); //returns Task/Null + writeTask = _parser.WriteSqlVariantDataRowValue(value, isNull, _stateObj); //returns Task/Null } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlParameter.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlParameter.cs index 08029b8a7a..2a2fb2a1a9 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlParameter.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlParameter.cs @@ -107,7 +107,7 @@ public sealed partial class SqlParameter : DbParameter, IDbDataParameter, IClone /// /// Indicates if the parameter encryption metadata received by sp_describe_parameter_encryption. - /// For unencrypted parameters, the encryption metadata should still be sent (and will indicate + /// For unencrypted parameters, the encryption metadata should still be sent (and will indicate /// that no encryption is needed). /// internal bool HasReceivedMetadata { get; set; } @@ -407,7 +407,7 @@ internal SmiParameterMetaData MetaDataForSmi(out ParameterPeekAheadValue peekAhe long actualLen = GetActualSize(); long maxLen = this.Size; - // GetActualSize returns bytes length, but smi expects char length for + // GetActualSize returns bytes length, but smi expects char length for // character types, so adjust if (!mt.IsLong) { @@ -995,15 +995,27 @@ object ICloneable.Clone() } // Coerced Value is also used in SqlBulkCopy.ConvertValue(object value, _SqlMetaData metadata) + internal static object CoerceValue(object value, MetaType destinationType, out bool coercedToDataFeed, out bool typeChanged, bool allowStreaming = true) + { + typeChanged = CoerceValueIfNeeded(value, destinationType, out var objValue, out coercedToDataFeed, allowStreaming); + + return typeChanged ? objValue : value; + } + + internal static bool CoerceValueIfNeeded(T value, MetaType destinationType, out object objValue, out bool coercedToDataFeed, bool allowStreaming = true) { Debug.Assert(!(value is DataFeed), "Value provided should not already be a data feed"); Debug.Assert(!ADP.IsNull(value), "Value provided should not be null"); Debug.Assert(null != destinationType, "null destinationType"); coercedToDataFeed = false; - typeChanged = false; - Type currentType = value.GetType(); + objValue = null; + Type currentType = typeof(T) == typeof(object) + ? value.GetType() // only call GetType if we know boxing has already occurred. + : typeof(T); + + var typeChanged = false; if ((typeof(object) != destinationType.ClassType) && (currentType != destinationType.ClassType) && @@ -1018,45 +1030,45 @@ internal static object CoerceValue(object value, MetaType destinationType, out b // For Xml data, destination Type is always string if (typeof(SqlXml) == currentType) { - value = MetaType.GetStringFromXml((XmlReader)(((SqlXml)value).CreateReader())); + objValue = MetaType.GetStringFromXml(GenericConverter.Convert(value).CreateReader()); } else if (typeof(SqlString) == currentType) { typeChanged = false; // Do nothing } - else if (typeof(XmlReader).IsAssignableFrom(currentType)) + else if (value is XmlReader xmlReader) { if (allowStreaming) { coercedToDataFeed = true; - value = new XmlDataFeed((XmlReader)value); + objValue = new XmlDataFeed(xmlReader); } else { - value = MetaType.GetStringFromXml((XmlReader)value); + objValue = MetaType.GetStringFromXml(xmlReader); } } else if (typeof(char[]) == currentType) { - value = new string((char[])value); + objValue = new string(GenericConverter.Convert(value)); } else if (typeof(SqlChars) == currentType) { - value = new string(((SqlChars)value).Value); + objValue = new string(GenericConverter.Convert(value).Value); } - else if (value is TextReader && allowStreaming) + else if (value is TextReader tr && allowStreaming) { coercedToDataFeed = true; - value = new TextDataFeed((TextReader)value); + objValue = new TextDataFeed(tr); } else { - value = Convert.ChangeType(value, destinationType.ClassType, (IFormatProvider)null); + objValue = Convert.ChangeType(value, destinationType.ClassType, (IFormatProvider)null); } } else if ((DbType.Currency == destinationType.DbType) && (typeof(string) == currentType)) { - value = decimal.Parse((string)value, NumberStyles.Currency, (IFormatProvider)null); + objValue = decimal.Parse(GenericConverter.Convert(value), NumberStyles.Currency, (IFormatProvider)null); } else if ((typeof(SqlBytes) == currentType) && (typeof(byte[]) == destinationType.ClassType)) { @@ -1064,15 +1076,15 @@ internal static object CoerceValue(object value, MetaType destinationType, out b } else if ((typeof(string) == currentType) && (SqlDbType.Time == destinationType.SqlDbType)) { - value = TimeSpan.Parse((string)value); + objValue = TimeSpan.Parse(GenericConverter.Convert(value)); } else if ((typeof(string) == currentType) && (SqlDbType.DateTimeOffset == destinationType.SqlDbType)) { - value = DateTimeOffset.Parse((string)value, (IFormatProvider)null); + objValue = DateTimeOffset.Parse(GenericConverter.Convert(value), (IFormatProvider)null); } else if ((typeof(DateTime) == currentType) && (SqlDbType.DateTimeOffset == destinationType.SqlDbType)) { - value = new DateTimeOffset((DateTime)value); + objValue = new DateTimeOffset(GenericConverter.Convert(value)); } else if (TdsEnums.SQLTABLE == destinationType.TDSType && ( value is DataTable || @@ -1082,14 +1094,14 @@ value is DbDataReader || // no conversion for TVPs. typeChanged = false; } - else if (destinationType.ClassType == typeof(byte[]) && value is Stream && allowStreaming) + else if (destinationType.ClassType == typeof(byte[]) && allowStreaming && value is Stream stream) { coercedToDataFeed = true; - value = new StreamDataFeed((Stream)value); + objValue = new StreamDataFeed(stream); } else { - value = Convert.ChangeType(value, destinationType.ClassType, (IFormatProvider)null); + objValue = Convert.ChangeType(value, destinationType.ClassType, (IFormatProvider)null); } } catch (Exception e) @@ -1104,8 +1116,8 @@ value is DbDataReader || } Debug.Assert(allowStreaming || !coercedToDataFeed, "Streaming is not allowed, but type was coerced into a data feed"); - Debug.Assert(value.GetType() == currentType ^ typeChanged, "Incorrect value for typeChanged"); - return value; + Debug.Assert(objValue == null || objValue.GetType() == currentType ^ typeChanged, "Incorrect value for typeChanged"); + return typeChanged; } internal void FixStreamDataForNonPLP() diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index b0b25d2e09..51194adb8c 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -2907,7 +2907,7 @@ private bool TryProcessDone(SqlCommand cmd, SqlDataReader reader, ref RunBehavio int count; // This is added back since removing it from here introduces regressions in Managed SNI. - // It forces SqlDataReader.ReadAsync() method to run synchronously, + // It forces SqlDataReader.ReadAsync() method to run synchronously, // and will block the calling thread until data is fed from SQL Server. // TODO Investigate better solution to support non-blocking ReadAsync(). stateObj._syncOverAsync = true; @@ -6630,6 +6630,7 @@ internal Task WriteSqlVariantValue(object value, int length, int offset, TdsPars stateObj.WriteByte(mt.Precision); //propbytes: precision stateObj.WriteByte((byte)((decimal.GetBits((decimal)value)[3] & 0x00ff0000) >> 0x10)); // propbytes: scale WriteDecimal((decimal)value, stateObj); + WriteDecimal((decimal)value, stateObj); break; } @@ -6664,10 +6665,10 @@ internal Task WriteSqlVariantValue(object value, int length, int offset, TdsPars // Therefore the sql_variant value must not include the MaxLength. This is the major difference // between this method and WriteSqlVariantValue above. // - internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject stateObj, bool canAccumulate = true) + internal Task WriteSqlVariantDataRowValue(T value, bool isNull, TdsParserStateObject stateObj, bool canAccumulate = true) { // handle null values - if ((null == value) || (DBNull.Value == value)) + if (isNull) { WriteInt(TdsEnums.FIXEDNULL, stateObj); return null; @@ -6678,44 +6679,44 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta if (metatype.IsAnsiType) { - length = GetEncodingCharLength((string)value, length, 0, _defaultEncoding); + length = GetEncodingCharLength(GenericConverter.Convert(value), length, 0, _defaultEncoding); } switch (metatype.TDSType) { case TdsEnums.SQLFLT4: WriteSqlVariantHeader(6, metatype.TDSType, metatype.PropBytes, stateObj); - WriteFloat((float)value, stateObj); + WriteFloat(GenericConverter.Convert(value), stateObj); break; case TdsEnums.SQLFLT8: WriteSqlVariantHeader(10, metatype.TDSType, metatype.PropBytes, stateObj); - WriteDouble((double)value, stateObj); + WriteDouble(GenericConverter.Convert(value), stateObj); break; case TdsEnums.SQLINT8: WriteSqlVariantHeader(10, metatype.TDSType, metatype.PropBytes, stateObj); - WriteLong((long)value, stateObj); + WriteLong(GenericConverter.Convert(value), stateObj); break; case TdsEnums.SQLINT4: WriteSqlVariantHeader(6, metatype.TDSType, metatype.PropBytes, stateObj); - WriteInt((int)value, stateObj); + WriteInt(GenericConverter.Convert(value), stateObj); break; case TdsEnums.SQLINT2: WriteSqlVariantHeader(4, metatype.TDSType, metatype.PropBytes, stateObj); - WriteShort((short)value, stateObj); + WriteShort(GenericConverter.Convert(value), stateObj); break; case TdsEnums.SQLINT1: WriteSqlVariantHeader(3, metatype.TDSType, metatype.PropBytes, stateObj); - stateObj.WriteByte((byte)value); + stateObj.WriteByte(GenericConverter.Convert(value)); break; case TdsEnums.SQLBIT: WriteSqlVariantHeader(3, metatype.TDSType, metatype.PropBytes, stateObj); - if ((bool)value == true) + if (GenericConverter.Convert(value)) stateObj.WriteByte(1); else stateObj.WriteByte(0); @@ -6724,7 +6725,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLBIGVARBINARY: { - byte[] b = (byte[])value; + byte[] b = GenericConverter.Convert(value); length = b.Length; WriteSqlVariantHeader(4 + length, metatype.TDSType, metatype.PropBytes, stateObj); @@ -6734,7 +6735,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLBIGVARCHAR: { - string s = (string)value; + string s = GenericConverter.Convert(value); length = s.Length; WriteSqlVariantHeader(9 + length, metatype.TDSType, metatype.PropBytes, stateObj); @@ -6746,7 +6747,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLUNIQUEID: { - System.Guid guid = (System.Guid)value; + Guid guid = GenericConverter.Convert(value); Span b = stackalloc byte[16]; FillGuidBytes(guid, b); @@ -6759,7 +6760,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLNVARCHAR: { - string s = (string)value; + string s = GenericConverter.Convert(value); length = s.Length * 2; WriteSqlVariantHeader(9 + length, metatype.TDSType, metatype.PropBytes, stateObj); @@ -6774,7 +6775,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLDATETIME: { - TdsDateTime dt = MetaType.FromDateTime((DateTime)value, 8); + TdsDateTime dt = MetaType.FromDateTime(GenericConverter.Convert(value), 8); WriteSqlVariantHeader(10, metatype.TDSType, metatype.PropBytes, stateObj); WriteInt(dt.days, stateObj); @@ -6785,7 +6786,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLMONEY: { WriteSqlVariantHeader(10, metatype.TDSType, metatype.PropBytes, stateObj); - WriteCurrency((decimal)value, 8, stateObj); + WriteCurrency(GenericConverter.Convert(value), 8, stateObj); break; } @@ -6793,21 +6794,22 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta { WriteSqlVariantHeader(21, metatype.TDSType, metatype.PropBytes, stateObj); stateObj.WriteByte(metatype.Precision); //propbytes: precision - stateObj.WriteByte((byte)((decimal.GetBits((decimal)value)[3] & 0x00ff0000) >> 0x10)); // propbytes: scale - WriteDecimal((decimal)value, stateObj); + var decValue = GenericConverter.Convert(value); + stateObj.WriteByte((byte)((decimal.GetBits(decValue)[3] & 0x00ff0000) >> 0x10)); // propbytes: scale + WriteDecimal(decValue, stateObj); break; } case TdsEnums.SQLTIME: WriteSqlVariantHeader(8, metatype.TDSType, metatype.PropBytes, stateObj); stateObj.WriteByte(metatype.Scale); //propbytes: scale - WriteTime((TimeSpan)value, metatype.Scale, 5, stateObj); + WriteTime(GenericConverter.Convert(value), metatype.Scale, 5, stateObj); break; case TdsEnums.SQLDATETIMEOFFSET: WriteSqlVariantHeader(13, metatype.TDSType, metatype.PropBytes, stateObj); stateObj.WriteByte(metatype.Scale); //propbytes: scale - WriteDateTimeOffset((DateTimeOffset)value, metatype.Scale, 10, stateObj); + WriteDateTimeOffset(GenericConverter.Convert(value), metatype.Scale, 10, stateObj); break; default: @@ -10409,7 +10411,7 @@ internal bool ShouldEncryptValuesForBulkCopy() /// Encrypts a column value (for SqlBulkCopy) /// /// - internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, string column, TdsParserStateObject stateObj, bool isDataFeed, bool isSqlType) + internal byte[] EncryptColumnValue(T value, SqlMetaDataPriv metadata, string column, TdsParserStateObject stateObj, bool isDataFeed, bool isSqlType) { Debug.Assert(IsColumnEncryptionSupported, "Server doesn't support encryption, yet we received encryption metadata"); Debug.Assert(ShouldEncryptValuesForBulkCopy(), "Encryption attempted when not requested"); @@ -10434,7 +10436,10 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin // when we normalize and serialize the data buffers. The serialization routine expects us // to report the size of data to be copied out (for serialization). If we underreport the // size, truncation will happen for us! - actualLengthInBytes = (isSqlType) ? ((SqlBinary)value).Length : ((byte[])value).Length; + actualLengthInBytes = (isSqlType) + ? GenericConverter.Convert(value).Length + : GenericConverter.Convert(value).Length; + if (metadata.baseTI.length > 0 && actualLengthInBytes > metadata.baseTI.length) { @@ -10454,7 +10459,10 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin ThrowUnsupportedCollationEncountered(null); // stateObject only when reading } - string stringValue = (isSqlType) ? ((SqlString)value).Value : (string)value; + string stringValue = (isSqlType) + ? GenericConverter.Convert(value).Value + : GenericConverter.Convert(value); + actualLengthInBytes = _defaultEncoding.GetByteCount(stringValue); // If the string length is > max length, then use the max length (see comments above) @@ -10468,7 +10476,10 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin case TdsEnums.SQLNCHAR: case TdsEnums.SQLNVARCHAR: case TdsEnums.SQLNTEXT: - actualLengthInBytes = ((isSqlType) ? ((SqlString)value).Value.Length : ((string)value).Length) * 2; + actualLengthInBytes = (isSqlType + ? GenericConverter.Convert(value).Value.Length + : GenericConverter.Convert(value).Length) + * 2; if (metadata.baseTI.length > 0 && actualLengthInBytes > metadata.baseTI.length) @@ -10513,7 +10524,7 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin _connHandler.ConnectionOptions.DataSource); } - internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsParserStateObject stateObj, bool isSqlType, bool isDataFeed, bool isNull) + internal Task WriteBulkCopyValue(T value, SqlMetaDataPriv metadata, TdsParserStateObject stateObj, bool isSqlType, bool isDataFeed, bool isNull) { Debug.Assert(!isSqlType || value is INullable, "isSqlType is true, but value can not be type cast to an INullable"); Debug.Assert(!isDataFeed ^ value is DataFeed, "Incorrect value for isDataFeed"); @@ -10578,7 +10589,9 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars case TdsEnums.SQLBIGVARBINARY: case TdsEnums.SQLIMAGE: case TdsEnums.SQLUDT: - ccb = (isSqlType) ? ((SqlBinary)value).Length : ((byte[])value).Length; + ccb = (isSqlType) + ? GenericConverter.Convert(value).Length + : GenericConverter.Convert(value).Length; break; case TdsEnums.SQLUNIQUEID: ccb = GUID_SIZE; @@ -10594,11 +10607,11 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars string stringValue = null; if (isSqlType) { - stringValue = ((SqlString)value).Value; + stringValue = GenericConverter.Convert(value).Value; } else { - stringValue = (string)value; + stringValue = GenericConverter.Convert(value); } ccb = stringValue.Length; @@ -10607,15 +10620,22 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars case TdsEnums.SQLNCHAR: case TdsEnums.SQLNVARCHAR: case TdsEnums.SQLNTEXT: - ccb = ((isSqlType) ? ((SqlString)value).Value.Length : ((string)value).Length) * 2; + ccb = (isSqlType + ? GenericConverter.Convert(value).Value.Length + : GenericConverter.Convert(value).Length + ) * 2; break; case TdsEnums.SQLXMLTYPE: // Value here could be string or XmlReader - if (value is XmlReader) + // the XmlReader scenario can only occur when T is object (enforced during SqlBulkCopy.ReadWriteColumnValueAsync) + if (typeof(T) == typeof(object) && value is XmlReader xr) { - value = MetaType.GetStringFromXml((XmlReader)value); + value = GenericConverter.Convert(MetaType.GetStringFromXml(xr)); } - ccb = ((isSqlType) ? ((SqlString)value).Value.Length : ((string)value).Length) * 2; + ccb = (isSqlType + ? GenericConverter.Convert(value).Value.Length + : GenericConverter.Convert(value).Length + ) * 2; break; default: @@ -10667,7 +10687,9 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars } else if (metatype.SqlDbType != SqlDbType.Udt || metatype.IsLong) { + // we only have to consider a conversion from above in this case. internalWriteTask = WriteValue(value, metatype, metadata.scale, ccb, ccbStringBytes, 0, stateObj, metadata.length, isDataFeed); + if ((internalWriteTask == null) && (_asyncWrite)) { internalWriteTask = stateObj.WaitForAccumulatedWrites(); @@ -10677,7 +10699,7 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars else { WriteShort(ccb, stateObj); - internalWriteTask = stateObj.WriteByteArray((byte[])value, ccb, 0); + internalWriteTask = stateObj.WriteByteArray(GenericConverter.Convert(value), ccb, 0); } #if DEBUG @@ -10986,7 +11008,7 @@ private bool IsBOMNeeded(MetaType type, object value) return false; } - private Task GetTerminationTask(Task unterminatedWriteTask, object value, MetaType type, int actualLength, TdsParserStateObject stateObj, bool isDataFeed) + private Task GetTerminationTask(Task unterminatedWriteTask, MetaType type, int actualLength, TdsParserStateObject stateObj, bool isDataFeed) { if (type.IsPlp && ((actualLength > 0) || isDataFeed)) { @@ -11007,16 +11029,16 @@ private Task GetTerminationTask(Task unterminatedWriteTask, object value, MetaTy } - private Task WriteSqlValue(object value, MetaType type, int actualLength, int codePageByteSize, int offset, TdsParserStateObject stateObj) + private Task WriteSqlValue(T value, MetaType type, int actualLength, int codePageByteSize, int offset, TdsParserStateObject stateObj) { return GetTerminationTask( WriteUnterminatedSqlValue(value, type, actualLength, codePageByteSize, offset, stateObj), - value, type, actualLength, stateObj, false); + type, actualLength, stateObj, false); } // For MAX types, this method can only write everything in one big chunk. If multiple // chunk writes needed, please use WritePlpBytes/WritePlpChars - private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLength, int codePageByteSize, int offset, TdsParserStateObject stateObj) + private Task WriteUnterminatedSqlValue(T value, MetaType type, int actualLength, int codePageByteSize, int offset, TdsParserStateObject stateObj) { Debug.Assert(((type.NullableType == TdsEnums.SQLXMLTYPE) || (value is INullable && !((INullable)value).IsNull)), @@ -11027,11 +11049,11 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe { case TdsEnums.SQLFLTN: if (type.FixedLength == 4) - WriteFloat(((SqlSingle)value).Value, stateObj); + WriteFloat(GenericConverter.Convert(value).Value, stateObj); else { Debug.Assert(type.FixedLength == 8, "Invalid length for SqlDouble type!"); - WriteDouble(((SqlDouble)value).Value, stateObj); + WriteDouble(GenericConverter.Convert(value).Value, stateObj); } break; @@ -11047,12 +11069,12 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe if (value is SqlBinary) { - return stateObj.WriteByteArray(((SqlBinary)value).Value, actualLength, offset, canAccumulate: false); + return stateObj.WriteByteArray(GenericConverter.Convert(value).Value, actualLength, offset, canAccumulate: false); } else { Debug.Assert(value is SqlBytes); - return stateObj.WriteByteArray(((SqlBytes)value).Value, actualLength, offset, canAccumulate: false); + return stateObj.WriteByteArray(GenericConverter.Convert(value).Value, actualLength, offset, canAccumulate: false); } } @@ -11060,7 +11082,7 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe { Debug.Assert(actualLength == 16, "Invalid length for guid type in com+ object"); Span b = stackalloc byte[16]; - SqlGuid sqlGuid = (SqlGuid)value; + SqlGuid sqlGuid = GenericConverter.Convert(value); if (sqlGuid.IsNull) { @@ -11078,7 +11100,7 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe case TdsEnums.SQLBITN: { Debug.Assert(type.FixedLength == 1, "Invalid length for SqlBoolean type"); - if (((SqlBoolean)value).Value == true) + if (GenericConverter.Convert(value).Value == true) stateObj.WriteByte(1); else stateObj.WriteByte(0); @@ -11088,17 +11110,17 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe case TdsEnums.SQLINTN: if (type.FixedLength == 1) - stateObj.WriteByte(((SqlByte)value).Value); + stateObj.WriteByte(GenericConverter.Convert(value).Value); else if (type.FixedLength == 2) - WriteShort(((SqlInt16)value).Value, stateObj); + WriteShort(GenericConverter.Convert(value).Value, stateObj); else if (type.FixedLength == 4) - WriteInt(((SqlInt32)value).Value, stateObj); + WriteInt(GenericConverter.Convert(value).Value, stateObj); else { Debug.Assert(type.FixedLength == 8, "invalid length for SqlIntN type: " + type.FixedLength.ToString(CultureInfo.InvariantCulture)); - WriteLong(((SqlInt64)value).Value, stateObj); + WriteLong(GenericConverter.Convert(value).Value, stateObj); } break; @@ -11112,14 +11134,14 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe } if (value is SqlChars) { - string sch = new string(((SqlChars)value).Value); + string sch = new string(GenericConverter.Convert(value).Value); return WriteEncodingChar(sch, actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false); } else { Debug.Assert(value is SqlString); - return WriteEncodingChar(((SqlString)value).Value, actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false); + return WriteEncodingChar(GenericConverter.Convert(value).Value, actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false); } @@ -11148,21 +11170,21 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe if (value is SqlChars) { - return WriteCharArray(((SqlChars)value).Value, actualLength, offset, stateObj, canAccumulate: false); + return WriteCharArray(GenericConverter.Convert(value).Value, actualLength, offset, stateObj, canAccumulate: false); } else { Debug.Assert(value is SqlString); - return WriteString(((SqlString)value).Value, actualLength, offset, stateObj, canAccumulate: false); + return WriteString(GenericConverter.Convert(value).Value, actualLength, offset, stateObj, canAccumulate: false); } case TdsEnums.SQLNUMERICN: Debug.Assert(type.FixedLength <= 17, "Decimal length cannot be greater than 17 bytes"); - WriteSqlDecimal((SqlDecimal)value, stateObj); + WriteSqlDecimal(GenericConverter.Convert(value), stateObj); break; case TdsEnums.SQLDATETIMN: - SqlDateTime dt = (SqlDateTime)value; + SqlDateTime dt = GenericConverter.Convert(value); if (type.FixedLength == 4) { @@ -11182,7 +11204,7 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe case TdsEnums.SQLMONEYN: { - WriteSqlMoney((SqlMoney)value, type.FixedLength, stateObj); + WriteSqlMoney(GenericConverter.Convert(value), type.FixedLength, stateObj); break; } @@ -11652,28 +11674,28 @@ private Task NullIfCompletedWriteTask(Task task) } } - private Task WriteValue(object value, MetaType type, byte scale, int actualLength, int encodingByteSize, int offset, TdsParserStateObject stateObj, int paramSize, bool isDataFeed) + private Task WriteValue(T value, MetaType type, byte scale, int actualLength, int encodingByteSize, int offset, TdsParserStateObject stateObj, int paramSize, bool isDataFeed) { return GetTerminationTask(WriteUnterminatedValue(value, type, scale, actualLength, encodingByteSize, offset, stateObj, paramSize, isDataFeed), - value, type, actualLength, stateObj, isDataFeed); + type, actualLength, stateObj, isDataFeed); } // For MAX types, this method can only write everything in one big chunk. If multiple // chunk writes needed, please use WritePlpBytes/WritePlpChars - private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int actualLength, int encodingByteSize, int offset, TdsParserStateObject stateObj, int paramSize, bool isDataFeed) + private Task WriteUnterminatedValue(T value, MetaType type, byte scale, int actualLength, int encodingByteSize, int offset, TdsParserStateObject stateObj, int paramSize, bool isDataFeed) { - Debug.Assert((null != value) && (DBNull.Value != value), "unexpected missing or empty object"); + Debug.Assert((null != value) && !(value is DBNull), "unexpected missing or empty object"); // parameters are always sent over as BIG or N types switch (type.NullableType) { case TdsEnums.SQLFLTN: if (type.FixedLength == 4) - WriteFloat((float)value, stateObj); + WriteFloat(GenericConverter.Convert(value), stateObj); else { Debug.Assert(type.FixedLength == 8, "Invalid length for SqlDouble type!"); - WriteDouble((double)value, stateObj); + WriteDouble(GenericConverter.Convert(value), stateObj); } break; @@ -11690,7 +11712,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int if (isDataFeed) { Debug.Assert(type.IsPlp, "Stream assigned to non-PLP was not converted!"); - return NullIfCompletedWriteTask(WriteStreamFeed((StreamDataFeed)value, stateObj, paramSize)); + return NullIfCompletedWriteTask(WriteStreamFeed(GenericConverter.Convert(value), stateObj, paramSize)); } else { @@ -11698,7 +11720,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int { WriteInt(actualLength, stateObj); // chunk length } - return stateObj.WriteByteArray((byte[])value, actualLength, offset, canAccumulate: false); + return stateObj.WriteByteArray(GenericConverter.Convert(value), actualLength, offset, canAccumulate: false); } } @@ -11706,7 +11728,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int { Debug.Assert(actualLength == 16, "Invalid length for guid type in com+ object"); Span b = stackalloc byte[16]; - FillGuidBytes((System.Guid)value, b); + FillGuidBytes(GenericConverter.Convert(value), b); stateObj.WriteByteSpan(b); break; } @@ -11714,7 +11736,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int case TdsEnums.SQLBITN: { Debug.Assert(type.FixedLength == 1, "Invalid length for SqlBoolean type"); - if ((bool)value == true) + if (GenericConverter.Convert(value) == true) stateObj.WriteByte(1); else stateObj.WriteByte(0); @@ -11724,15 +11746,15 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int case TdsEnums.SQLINTN: if (type.FixedLength == 1) - stateObj.WriteByte((byte)value); + stateObj.WriteByte(GenericConverter.Convert(value)); else if (type.FixedLength == 2) - WriteShort((short)value, stateObj); + WriteShort(GenericConverter.Convert(value), stateObj); else if (type.FixedLength == 4) - WriteInt((int)value, stateObj); + WriteInt(GenericConverter.Convert(value), stateObj); else { Debug.Assert(type.FixedLength == 8, "invalid length for SqlIntN type: " + type.FixedLength.ToString(CultureInfo.InvariantCulture)); - WriteLong((long)value, stateObj); + WriteLong(GenericConverter.Convert(value), stateObj); } break; @@ -11750,7 +11772,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int TextDataFeed tdf = value as TextDataFeed; if (tdf == null) { - return NullIfCompletedWriteTask(WriteXmlFeed((XmlDataFeed)value, stateObj, needBom: true, encoding: _defaultEncoding, size: paramSize)); + return NullIfCompletedWriteTask(WriteXmlFeed(GenericConverter.Convert(value), stateObj, needBom: true, encoding: _defaultEncoding, size: paramSize)); } else { @@ -11765,11 +11787,11 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int } if (value is byte[]) { // If LazyMat non-filled blob, send cookie rather than value - return stateObj.WriteByteArray((byte[])value, actualLength, 0, canAccumulate: false); + return stateObj.WriteByteArray(GenericConverter.Convert(value), actualLength, 0, canAccumulate: false); } else { - return WriteEncodingChar((string)value, actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false); + return WriteEncodingChar(GenericConverter.Convert(value), actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false); } } } @@ -11787,7 +11809,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int TextDataFeed tdf = value as TextDataFeed; if (tdf == null) { - return NullIfCompletedWriteTask(WriteXmlFeed((XmlDataFeed)value, stateObj, IsBOMNeeded(type, value), Encoding.Unicode, paramSize)); + return NullIfCompletedWriteTask(WriteXmlFeed(GenericConverter.Convert(value), stateObj, IsBOMNeeded(type, value), Encoding.Unicode, paramSize)); } else { @@ -11810,25 +11832,25 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int } if (value is byte[]) { // If LazyMat non-filled blob, send cookie rather than value - return stateObj.WriteByteArray((byte[])value, actualLength, 0, canAccumulate: false); + return stateObj.WriteByteArray(GenericConverter.Convert(value), actualLength, 0, canAccumulate: false); } else { // convert to cchars instead of cbytes actualLength >>= 1; - return WriteString((string)value, actualLength, offset, stateObj, canAccumulate: false); + return WriteString(GenericConverter.Convert(value), actualLength, offset, stateObj, canAccumulate: false); } } } case TdsEnums.SQLNUMERICN: Debug.Assert(type.FixedLength <= 17, "Decimal length cannot be greater than 17 bytes"); - WriteDecimal((decimal)value, stateObj); + WriteDecimal(GenericConverter.Convert(value), stateObj); break; case TdsEnums.SQLDATETIMN: Debug.Assert(type.FixedLength <= 0xff, "Invalid Fixed Length"); - TdsDateTime dt = MetaType.FromDateTime((DateTime)value, (byte)type.FixedLength); + TdsDateTime dt = MetaType.FromDateTime(GenericConverter.Convert(value), (byte)type.FixedLength); if (type.FixedLength == 4) { @@ -11848,13 +11870,13 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int case TdsEnums.SQLMONEYN: { - WriteCurrency((decimal)value, type.FixedLength, stateObj); + WriteCurrency(GenericConverter.Convert(value), type.FixedLength, stateObj); break; } case TdsEnums.SQLDATE: { - WriteDate((DateTime)value, stateObj); + WriteDate(GenericConverter.Convert(value), stateObj); break; } @@ -11863,7 +11885,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int { throw SQL.TimeScaleValueOutOfRange(scale); } - WriteTime((TimeSpan)value, scale, actualLength, stateObj); + WriteTime(GenericConverter.Convert(value), scale, actualLength, stateObj); break; case TdsEnums.SQLDATETIME2: @@ -11871,11 +11893,11 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int { throw SQL.TimeScaleValueOutOfRange(scale); } - WriteDateTime2((DateTime)value, scale, actualLength, stateObj); + WriteDateTime2(GenericConverter.Convert(value), scale, actualLength, stateObj); break; case TdsEnums.SQLDATETIMEOFFSET: - WriteDateTimeOffset((DateTimeOffset)value, scale, actualLength, stateObj); + WriteDateTimeOffset(GenericConverter.Convert(value), scale, actualLength, stateObj); break; default: @@ -12119,7 +12141,7 @@ private byte[] SerializeUnencryptedValue(object value, MetaType type, byte scale // For MAX types, this method can only write everything in one big chunk. If multiple // chunk writes needed, please use WritePlpBytes/WritePlpChars - private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int actualLength, int offset, byte normalizationVersion, TdsParserStateObject stateObj) + private byte[] SerializeUnencryptedSqlValue(T value, MetaType type, int actualLength, int offset, byte normalizationVersion, TdsParserStateObject stateObj) { Debug.Assert(((type.NullableType == TdsEnums.SQLXMLTYPE) || (value is INullable && !((INullable)value).IsNull)), @@ -12135,11 +12157,13 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act { case TdsEnums.SQLFLTN: if (type.FixedLength == 4) - return SerializeFloat(((SqlSingle)value).Value); + { + return SerializeFloat(GenericConverter.Convert(value).Value); + } else { Debug.Assert(type.FixedLength == 8, "Invalid length for SqlDouble type!"); - return SerializeDouble(((SqlDouble)value).Value); + return SerializeDouble(GenericConverter.Convert(value).Value); } case TdsEnums.SQLBIGBINARY: @@ -12150,19 +12174,20 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act if (value is SqlBinary) { - Buffer.BlockCopy(((SqlBinary)value).Value, offset, b, 0, actualLength); + Buffer.BlockCopy(GenericConverter.Convert(value).Value, offset, b, 0, actualLength); } else { Debug.Assert(value is SqlBytes); - Buffer.BlockCopy(((SqlBytes)value).Value, offset, b, 0, actualLength); + Buffer.BlockCopy(GenericConverter.Convert(value).Value, offset, b, 0, actualLength); } + return b; } case TdsEnums.SQLUNIQUEID: { - byte[] b = ((SqlGuid)value).ToByteArray(); + byte[] b = GenericConverter.Convert(value).ToByteArray(); Debug.Assert((actualLength == b.Length) && (actualLength == 16), "Invalid length for guid type in com+ object"); return b; @@ -12173,23 +12198,23 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act Debug.Assert(type.FixedLength == 1, "Invalid length for SqlBoolean type"); // We normalize to allow conversion across data types. BIT is serialized into a BIGINT. - return SerializeLong(((SqlBoolean)value).Value == true ? 1 : 0, stateObj); + return SerializeLong(GenericConverter.Convert(value).Value == true ? 1 : 0, stateObj); } case TdsEnums.SQLINTN: // We normalize to allow conversion across data types. All data types below are serialized into a BIGINT. if (type.FixedLength == 1) - return SerializeLong(((SqlByte)value).Value, stateObj); + return SerializeLong(GenericConverter.Convert(value).Value, stateObj); if (type.FixedLength == 2) - return SerializeLong(((SqlInt16)value).Value, stateObj); + return SerializeLong(GenericConverter.Convert(value).Value, stateObj); if (type.FixedLength == 4) - return SerializeLong(((SqlInt32)value).Value, stateObj); + return SerializeLong(GenericConverter.Convert(value).Value, stateObj); else { Debug.Assert(type.FixedLength == 8, "invalid length for SqlIntN type: " + type.FixedLength.ToString(CultureInfo.InvariantCulture)); - return SerializeLong(((SqlInt64)value).Value, stateObj); + return SerializeLong(GenericConverter.Convert(value).Value, stateObj); } case TdsEnums.SQLBIGCHAR: @@ -12197,13 +12222,13 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act case TdsEnums.SQLTEXT: if (value is SqlChars) { - String sch = new String(((SqlChars)value).Value); + String sch = new String(GenericConverter.Convert(value).Value); return SerializeEncodingChar(sch, actualLength, offset, _defaultEncoding); } else { Debug.Assert(value is SqlString); - return SerializeEncodingChar(((SqlString)value).Value, actualLength, offset, _defaultEncoding); + return SerializeEncodingChar(GenericConverter.Convert(value).Value, actualLength, offset, _defaultEncoding); } @@ -12218,20 +12243,20 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act if (value is SqlChars) { - return SerializeCharArray(((SqlChars)value).Value, actualLength, offset); + return SerializeCharArray(GenericConverter.Convert(value).Value, actualLength, offset); } else { Debug.Assert(value is SqlString); - return SerializeString(((SqlString)value).Value, actualLength, offset); + return SerializeString(GenericConverter.Convert(value).Value, actualLength, offset); } case TdsEnums.SQLNUMERICN: Debug.Assert(type.FixedLength <= 17, "Decimal length cannot be greater than 17 bytes"); - return SerializeSqlDecimal((SqlDecimal)value, stateObj); + return SerializeSqlDecimal(GenericConverter.Convert(value), stateObj); case TdsEnums.SQLDATETIMN: - SqlDateTime dt = (SqlDateTime)value; + SqlDateTime dt = GenericConverter.Convert(value); if (type.FixedLength == 4) { @@ -12277,7 +12302,7 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act case TdsEnums.SQLMONEYN: { - return SerializeSqlMoney((SqlMoney)value, type.FixedLength, stateObj); + return SerializeSqlMoney(GenericConverter.Convert(value), type.FixedLength, stateObj); } default: diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index a63ed87bce..51698ddd73 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -287,6 +287,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/GenericConverter.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/GenericConverter.cs new file mode 100644 index 0000000000..2d6eb82fba --- /dev/null +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/GenericConverter.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.Data.SqlClient +{ + // This leverages the same assumptions in SqlBuffer that the JIT will optimize out the boxing / unboxing when TIn == TOut + // This behavior is proven out in the NoBoxingValueTypes BulkCopy unit test that benchmarks and measures the allocations + internal static class GenericConverter + { + public static TOut Convert(TIn value) + { + return (TOut)(object)value; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index fdfbc1a71b..c4beb7d98e 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -186,8 +186,8 @@ private enum ValueSourceType DbDataReader } - // Enum for specifying SqlDataReader.Get method used - private enum ValueMethod : byte + // Enum for specifying SqlDataReader.Get / IDataReader Get method used + private enum ValueMethod { GetValue, SqlTypeSqlDecimal, @@ -195,7 +195,19 @@ private enum ValueMethod : byte SqlTypeSqlSingle, DataFeedStream, DataFeedText, - DataFeedXml + DataFeedXml, + GetInt32, + GetString, + GetDouble, + GetDecimal, + GetInt16, + GetInt64, + GetChar, + GetByte, + GetBoolean, + GetDateTime, + GetGuid, + GetFloat } // Used to hold column metadata for SqlDataReader case @@ -309,7 +321,7 @@ private int RowNumber // for debug purpose only. // TODO: I will make this internal to use Reflection. #if DEBUG - internal static bool _setAlwaysTaskOnWrite = false; //when set and in DEBUG mode, TdsParser::WriteBulkCopyValue will always return a task + internal static bool _setAlwaysTaskOnWrite = false; //when set and in DEBUG mode, TdsParser::WriteBulkCopyValue will always return a task internal static bool SetAlwaysTaskOnWrite { set @@ -541,8 +553,8 @@ private bool IsCopyOption(SqlBulkCopyOptions copyOption) return (_copyOptions & copyOption) == copyOption; } - //Creates the initial query string, but does not execute it. - // + //Creates the initial query string, but does not execute it. + // private string CreateInitialQuery() { string[] parts; @@ -563,7 +575,7 @@ private string CreateInitialQuery() TDSCommand = "select @@trancount; SET FMTONLY ON select * from " + ADP.BuildMultiPartName(parts) + " SET FMTONLY OFF "; if (_connection.IsShiloh) { - // If its a temp DB then try to connect + // If its a temp DB then try to connect string TableCollationsStoredProc; if (_connection.IsKatmaiOrNewer) @@ -626,9 +638,9 @@ private string CreateInitialQuery() } // Creates and then executes initial query to get information about the targettable - // When __isAsyncBulkCopy == false (i.e. it is Sync copy): out result contains the resulset. Returns null. - // When __isAsyncBulkCopy == true (i.e. it is Async copy): This still uses the _parser.Run method synchronously and return Task. - // We need to have a _parser.RunAsync to make it real async. + // When __isAsyncBulkCopy == false (i.e. it is Sync copy): out result contains the resulset. Returns null. + // When __isAsyncBulkCopy == true (i.e. it is Async copy): This still uses the _parser.Run method synchronously and return Task. + // We need to have a _parser.RunAsync to make it real async. private Task CreateAndExecuteInitialQueryAsync(out BulkCopySimpleResultSet result) { string TDSCommand = CreateInitialQuery(); @@ -1055,13 +1067,19 @@ private void Dispose(bool disposing) // free unmanaged objects } - // unified method to read a value from the current row - // - private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out bool isDataFeed, out bool isNull) + // Reads a cell and then writes it. + // Read may block at this moment since there is no getValueAsync or DownStream async at this moment. + // When _isAsyncBulkCopy == true: Write will return Task (when async method runs asynchronously) or Null (when async call actually ran synchronously) for performance. + // When _isAsyncBulkCopy == false: Writes are purely sync. This method return null at the end. + private Task ReadWriteColumnValueAsync(int destRowIndex) { _SqlMetaData metadata = _sortedColumnMappings[destRowIndex]._metadata; int sourceOrdinal = _sortedColumnMappings[destRowIndex]._sourceColumnOrdinal; + bool isSqlType = false; + bool isDataFeed = false; + bool isNull = false; + switch (_rowSourceType) { case ValueSourceType.IDataReader: @@ -1071,34 +1089,39 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b { if (_DbDataReaderRowSource.IsDBNull(sourceOrdinal)) { - isSqlType = false; - isDataFeed = false; isNull = true; - return DBNull.Value; + return WriteValueAsync(DBNull.Value, destRowIndex, isSqlType, isDataFeed, isNull); } else { - isSqlType = false; isDataFeed = true; - isNull = false; + + object feedColumnValue; + switch (_currentRowMetadata[destRowIndex].Method) { case ValueMethod.DataFeedStream: - return new StreamDataFeed(_DbDataReaderRowSource.GetStream(sourceOrdinal)); + feedColumnValue = new StreamDataFeed(_DbDataReaderRowSource.GetStream(sourceOrdinal)); + break; case ValueMethod.DataFeedText: - return new TextDataFeed(_DbDataReaderRowSource.GetTextReader(sourceOrdinal)); + feedColumnValue = new TextDataFeed(_DbDataReaderRowSource.GetTextReader(sourceOrdinal)); + break; case ValueMethod.DataFeedXml: // Only SqlDataReader supports an XmlReader - // There is no GetXmlReader on DbDataReader, however if GetValue returns XmlReader we will read it as stream if it is assigned to XML field + // There is no GetXmlReader on DbDataReader, however if GetValue returns XmlReader we will read it as stream if it is assigned to XML field Debug.Assert(_SqlDataReaderRowSource != null, "Should not be reading row as an XmlReader if bulk copy source is not a SqlDataReader"); - return new XmlDataFeed(_SqlDataReaderRowSource.GetXmlReader(sourceOrdinal)); + feedColumnValue = new XmlDataFeed(_SqlDataReaderRowSource.GetXmlReader(sourceOrdinal)); + break; default: - Debug.Assert(false, string.Format("Current column is marked as being a DataFeed, but no DataFeed compatible method was provided. Method: {0}", _currentRowMetadata[destRowIndex].Method)); + Debug.Fail($"Current column is marked as being a DataFeed, but no DataFeed compatible method was provided. Method: {_currentRowMetadata[destRowIndex].Method}"); isDataFeed = false; - object columnValue = _DbDataReaderRowSource.GetValue(sourceOrdinal); - ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType); - return columnValue; + feedColumnValue = _DbDataReaderRowSource.GetValue(sourceOrdinal); + ADP.IsNullOrSqlType(feedColumnValue, out isNull, out isSqlType); + break; } + + //specifically choosing to use the object overload here to simplify TdsParser logic for the XmlReader scenario + return WriteValueAsync(feedColumnValue, destRowIndex, isSqlType, isDataFeed, isNull); } } // SqlDataReader-specific logic @@ -1106,36 +1129,28 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b { if (_currentRowMetadata[destRowIndex].IsSqlType) { - INullable value; isSqlType = true; - isDataFeed = false; switch (_currentRowMetadata[destRowIndex].Method) { case ValueMethod.SqlTypeSqlDecimal: - value = _SqlDataReaderRowSource.GetSqlDecimal(sourceOrdinal); - break; + var value = _SqlDataReaderRowSource.GetSqlDecimal(sourceOrdinal); + return WriteValueAsync(value, destRowIndex, isSqlType, isDataFeed, value.IsNull); case ValueMethod.SqlTypeSqlDouble: // use cast to handle IsNull correctly because no public constructor allows it - value = (SqlDecimal)_SqlDataReaderRowSource.GetSqlDouble(sourceOrdinal); - break; + var dblValue = (SqlDecimal)_SqlDataReaderRowSource.GetSqlDouble(sourceOrdinal); + return WriteValueAsync(dblValue, destRowIndex, isSqlType, isDataFeed, dblValue.IsNull); case ValueMethod.SqlTypeSqlSingle: - // use cast to handle IsNull correctly because no public constructor allows it - value = (SqlDecimal)_SqlDataReaderRowSource.GetSqlSingle(sourceOrdinal); - break; + // use cast to handle value.IsNull correctly because no public constructor allows it + var singleValue = (SqlDecimal)_SqlDataReaderRowSource.GetSqlSingle(sourceOrdinal); + return WriteValueAsync(singleValue, destRowIndex, isSqlType, isDataFeed, singleValue.IsNull); default: - Debug.Assert(false, string.Format("Current column is marked as being a SqlType, but no SqlType compatible method was provided. Method: {0}", _currentRowMetadata[destRowIndex].Method)); - value = (INullable)_SqlDataReaderRowSource.GetSqlValue(sourceOrdinal); - break; + Debug.Fail($"Current column is marked as being a SqlType, but no SqlType compatible method was provided. Method: {_currentRowMetadata[destRowIndex].Method}"); + var sqlValue = (INullable)_SqlDataReaderRowSource.GetSqlValue(sourceOrdinal); + return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, sqlValue.IsNull); } - - isNull = value.IsNull; - return value; } else { - isSqlType = false; - isDataFeed = false; - object value = _SqlDataReaderRowSource.GetValue(sourceOrdinal); isNull = ((value == null) || (value == DBNull.Value)); if ((!isNull) && (metadata.type == SqlDbType.Udt)) @@ -1148,31 +1163,66 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b { Debug.Assert(!(value is INullable) || !((INullable)value).IsNull, "IsDBNull returned false, but GetValue returned a null INullable"); } -#endif - return value; +#endif + return WriteValueAsync(value, destRowIndex, isSqlType, isDataFeed, isNull); } } else { - isDataFeed = false; - IDataReader rowSourceAsIDataReader = (IDataReader)_rowSource; - // Back-compat with 4.0 and 4.5 - only use IsDbNull when streaming is enabled and only for non-SqlDataReader + // Only use IsDbNull when streaming is enabled and only for non-SqlDataReader if ((_enableStreaming) && (_SqlDataReaderRowSource == null) && (rowSourceAsIDataReader.IsDBNull(sourceOrdinal))) { - isSqlType = false; isNull = true; - return DBNull.Value; + return WriteValueAsync(DBNull.Value, destRowIndex, isSqlType, isDataFeed, isNull); } else { - object columnValue = rowSourceAsIDataReader.GetValue(sourceOrdinal); - ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType); - return columnValue; + if (_currentRowMetadata[destRowIndex].Method == ValueMethod.GetValue || rowSourceAsIDataReader.IsDBNull(sourceOrdinal)) + { + object columnValue = rowSourceAsIDataReader.GetValue(sourceOrdinal); + ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType); + return WriteValueAsync(columnValue, destRowIndex, isSqlType, isDataFeed, isNull); + } + else + { + switch (_currentRowMetadata[sourceOrdinal].Method) + { + case ValueMethod.GetInt32: + return WriteValueAsync(rowSourceAsIDataReader.GetInt32(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, false); + case ValueMethod.GetString: + var strValue = rowSourceAsIDataReader.GetString(sourceOrdinal); + isNull = strValue == null; + return WriteValueAsync(strValue, destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetDouble: + return WriteValueAsync(rowSourceAsIDataReader.GetDouble(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetDecimal: + return WriteValueAsync(rowSourceAsIDataReader.GetDecimal(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetInt16: + return WriteValueAsync(rowSourceAsIDataReader.GetInt16(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetInt64: + return WriteValueAsync(rowSourceAsIDataReader.GetInt64(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetChar: + return WriteValueAsync(rowSourceAsIDataReader.GetChar(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetByte: + return WriteValueAsync(rowSourceAsIDataReader.GetByte(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetBoolean: + return WriteValueAsync(rowSourceAsIDataReader.GetBoolean(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetDateTime: + return WriteValueAsync(rowSourceAsIDataReader.GetDateTime(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetGuid: + return WriteValueAsync(rowSourceAsIDataReader.GetGuid(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetFloat: + return WriteValueAsync(rowSourceAsIDataReader.GetFloat(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + default: + object columnValue = rowSourceAsIDataReader.GetValue(sourceOrdinal); + ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType); + return WriteValueAsync(columnValue, destRowIndex, isSqlType, isDataFeed, isNull); + } + } } } - case ValueSourceType.DataTable: case ValueSourceType.RowArray: { @@ -1180,6 +1230,7 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b Debug.Assert(sourceOrdinal < _currentRowLength, "inconsistency of length of rows from rowsource!"); isDataFeed = false; + // unfortunately this has to be boxed due to DataRow's API. object currentRowValue = _currentRow[sourceOrdinal]; ADP.IsNullOrSqlType(currentRowValue, out isNull, out isSqlType); @@ -1192,7 +1243,8 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b { if (isSqlType) { - return new SqlDecimal(((SqlSingle)currentRowValue).Value); + var sqlDec = new SqlDecimal(((SqlSingle)currentRowValue).Value); + return WriteValueAsync(sqlDec, destRowIndex, isSqlType, isDataFeed, isNull); } else { @@ -1200,16 +1252,20 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b if (!float.IsNaN(f)) { isSqlType = true; - return new SqlDecimal(f); + return WriteValueAsync(new SqlDecimal(f), destRowIndex, isSqlType, isDataFeed, isNull); + } + else + { + return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull); } - break; } } case ValueMethod.SqlTypeSqlDouble: { if (isSqlType) { - return new SqlDecimal(((SqlDouble)currentRowValue).Value); + var sqlValue = new SqlDecimal(((SqlDouble)currentRowValue).Value); + return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, isNull); } else { @@ -1217,37 +1273,44 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b if (!double.IsNaN(d)) { isSqlType = true; - return new SqlDecimal(d); + return WriteValueAsync(new SqlDecimal(d), destRowIndex, isSqlType, isDataFeed, isNull); + } + else + { + return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull); } - break; } } case ValueMethod.SqlTypeSqlDecimal: { if (isSqlType) { - return (SqlDecimal)currentRowValue; + var sqlValue = (SqlDecimal)currentRowValue; + return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, isNull); } else { isSqlType = true; - return new SqlDecimal((Decimal)currentRowValue); + var sqlValue = new SqlDecimal((decimal)currentRowValue); + return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, isNull); } } default: { - Debug.Assert(false, string.Format("Current column is marked as being a SqlType, but no SqlType compatible method was provided. Method: {0}", _currentRowMetadata[destRowIndex].Method)); - break; + Debug.Fail($"Current column is marked as being a SqlType, but no SqlType compatible method was provided. Method: {_currentRowMetadata[destRowIndex].Method}"); + // If we are here then either the value is null, there was no special storage type for this column or the special storage type wasn't handled (e.g. if the currentRowValue is NaN) + return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull); } } } - - // If we are here then either the value is null, there was no special storage type for this column or the special storage type wasn't handled (e.g. if the currentRowValue is NaN) - return currentRowValue; + else + { + return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull); + } } default: { - Debug.Assert(false, "ValueSourcType unspecified"); + Debug.Fail("ValueSourcType unspecified"); throw ADP.NotSupported(); } } @@ -1261,7 +1324,7 @@ private Task ReadFromRowSourceAsync(CancellationToken cts) { if (_isAsyncBulkCopy && (_DbDataReaderRowSource != null)) { - //This will call ReadAsync for DbDataReader (for SqlDataReader it will be truely async read; for non-SqlDataReader it may block.) + //This will call ReadAsync for DbDataReader (for SqlDataReader it will be truely async read; for non-SqlDataReader it may block.) return _DbDataReaderRowSource.ReadAsync(cts).ContinueWith((t) => { if (t.Status == TaskStatus.RanToCompletion) @@ -1374,7 +1437,7 @@ private SourceColumnMetadata GetColumnMetadata(int ordinal) else if (typeof(SqlSingle) == t || typeof(float) == t) { isSqlType = true; - method = ValueMethod.SqlTypeSqlSingle; // Source Type SqlSingle + method = ValueMethod.SqlTypeSqlSingle; // Source Type SqlSingle } else { @@ -1439,6 +1502,66 @@ private SourceColumnMetadata GetColumnMetadata(int ordinal) method = ValueMethod.GetValue; } } + else if (_rowSourceType == ValueSourceType.IDataReader) + { + isSqlType = false; + isDataFeed = false; + + Type t = ((IDataReader)_rowSource).GetFieldType(ordinal); + + if (t == typeof(bool)) + { + method = ValueMethod.GetBoolean; + } + else if (t == typeof(byte)) + { + method = ValueMethod.GetByte; + } + else if (t == typeof(char)) + { + method = ValueMethod.GetChar; + } + else if (t == typeof(DateTime)) + { + method = ValueMethod.GetDateTime; + } + else if (t == typeof(decimal)) + { + method = ValueMethod.GetDecimal; + } + else if (t == typeof(double)) + { + method = ValueMethod.GetDouble; + } + else if (t == typeof(float)) + { + method = ValueMethod.GetFloat; + } + else if (t == typeof(Guid)) + { + method = ValueMethod.GetGuid; + } + else if (t == typeof(short)) + { + method = ValueMethod.GetInt16; + } + else if (t == typeof(int)) + { + method = ValueMethod.GetInt32; + } + else if (t == typeof(long)) + { + method = ValueMethod.GetInt64; + } + else if (t == typeof(string)) + { + method = ValueMethod.GetString; + } + else + { + method = ValueMethod.GetValue; + } + } else { isSqlType = false; @@ -1449,8 +1572,6 @@ private SourceColumnMetadata GetColumnMetadata(int ordinal) return new SourceColumnMetadata(method, isSqlType, isDataFeed); } - // - // private void CreateOrValidateConnection(string method) { if (null == _connection) @@ -1581,13 +1702,14 @@ private string UnquotedName(string name) return name; } - private object ValidateBulkCopyVariant(object value) + private bool ValidateBulkCopyVariantIfNeeded(T value, out object variantValue) { - // from the spec: + variantValue = null; + + // From the spec: // "The only acceptable types are ..." // GUID, BIGVARBINARY, BIGBINARY, BIGVARCHAR, BIGCHAR, NVARCHAR, NCHAR, BIT, INT1, INT2, INT4, INT8, // MONEY4, MONEY, DECIMALN, NUMERICN, FTL4, FLT8, DATETIME4 and DATETIME - // MetaType metatype = MetaType.GetMetaTypeFromValue(value); switch (metatype.TDSType) { @@ -1610,21 +1732,22 @@ private object ValidateBulkCopyVariant(object value) case TdsEnums.SQLDATETIME2: case TdsEnums.SQLDATETIMEOFFSET: if (value is INullable) - { // Current limitation in the SqlBulkCopy Variant code limits BulkCopy to CLR/COM Types. - return MetaType.GetComValueFromSqlVariant(value); + { // Current limitation in the SqlBulkCopy Variant code limits BulkCopy to CLR/COM Types. + variantValue = MetaType.GetComValueFromSqlVariant(value); + return true; } else { - return value; + return false; } default: throw SQL.BulkLoadInvalidVariantValue(); } } - private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, ref bool isSqlType, out bool coercedToDataFeed) + private Task ConvertWriteValueAsync(T value, int col, _SqlMetaData metadata, bool isNull, bool isSqlType) { - coercedToDataFeed = false; + bool coercedToDataFeed = false; if (isNull) { @@ -1632,11 +1755,13 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re { throw SQL.BulkLoadBulkLoadNotAllowDBNull(metadata.column); } - return value; + + return DoWriteValueAsync(value, col, isSqlType, coercedToDataFeed, isNull, metadata); } MetaType type = metadata.metaType; bool typeChanged = false; + object objValue = null; // If the column is encrypted then we are going to transparently encrypt this column // (based on connection string setting)- Use the metaType for the underlying @@ -1661,56 +1786,55 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re { case TdsEnums.SQLNUMERICN: case TdsEnums.SQLDECIMALN: - mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); - value = SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false); - - // Convert Source Decimal Precision and Scale to Destination Precision and Scale - // Fix Bug: 385971 sql decimal data could get corrupted on insert if the scale of - // the source and destination weren't the same. The BCP protocol, specifies the - // scale of the incoming data in the insert statement, we just tell the server we - // are inserting the same scale back. This then created a bug inside the BCP operation - // if the scales didn't match. The fix is to do the same thing that SQL Parameter does, - // and adjust the scale before writing. In Orcas is scale adjustment should be removed from - // SqlParameter and SqlBulkCopy and Isolated inside SqlParameter.CoerceValue, but because of - // where we are in the cycle, the changes must be kept at minimum, so I'm just bringing the - // code over to SqlBulkCopy. - - SqlDecimal sqlValue; - if ((isSqlType) && (!typeChanged)) + SqlDecimal decValue; + if (typeof(T) == typeof(decimal)) + { + decValue = new SqlDecimal(GenericConverter.Convert(value)); + } + else if (typeof(T) == typeof(SqlDecimal)) { - sqlValue = (SqlDecimal)value; + decValue = GenericConverter.Convert(value); } else { - sqlValue = new SqlDecimal((Decimal)value); + mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); + decValue = new SqlDecimal((decimal)SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false)); } - if (sqlValue.Scale != scale) + // Convert Source Decimal Precision and Scale to Destination Precision and Scale + // Sql decimal data could get corrupted on insert if the scale of + // the source and destination weren't the same. The BCP protocol, specifies the + // scale of the incoming data in the insert statement, we just tell the server we + // are inserting the same scale back. + if (decValue.Scale != scale) { - sqlValue = TdsParser.AdjustSqlDecimalScale(sqlValue, scale); + decValue = TdsParser.AdjustSqlDecimalScale(decValue, scale); } - if (sqlValue.Precision > precision) + if (decValue.Precision > precision) { try { - sqlValue = SqlDecimal.ConvertToPrecScale(sqlValue, precision, sqlValue.Scale); + decValue = SqlDecimal.ConvertToPrecScale(decValue, precision, decValue.Scale); } catch (SqlTruncateException) { - throw SQL.BulkLoadCannotConvertValue(value.GetType(), mt, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), ADP.ParameterValueOutOfRange(sqlValue)); + mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); + throw SQL.BulkLoadCannotConvertValue(value.GetType(), mt, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), ADP.ParameterValueOutOfRange(decValue)); } catch (Exception e) { + mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); throw SQL.BulkLoadCannotConvertValue(value.GetType(), mt, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), e); } } // Perf: It is more efficient to write a SqlDecimal than a decimal since we need to break it into its 'bits' when writing - value = sqlValue; isSqlType = true; - typeChanged = false; // Setting this to false as SqlParameter.CoerceValue will only set it to true when converting to a CLR type - break; + typeChanged = false; // Setting this to false as SqlParameter.CoerceValue will only set it to true when converting to a CLR type + + // returning here to avoid unnecessary decValue initialization for all types + return WriteConvertedValue(decValue, col, isSqlType, isNull, coercedToDataFeed, metadata); case TdsEnums.SQLINTN: case TdsEnums.SQLFLTN: @@ -1734,16 +1858,22 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re case TdsEnums.SQLDATETIME2: case TdsEnums.SQLDATETIMEOFFSET: mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); - value = SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false); + typeChanged = SqlParameter.CoerceValueIfNeeded(value, mt, out objValue, out coercedToDataFeed); break; case TdsEnums.SQLNCHAR: case TdsEnums.SQLNVARCHAR: case TdsEnums.SQLNTEXT: mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); - value = SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false); + typeChanged = SqlParameter.CoerceValueIfNeeded(value, mt, out objValue, out coercedToDataFeed, false); if (!coercedToDataFeed) { // We do not need to test for TextDataFeed as it is only assigned to (N)VARCHAR(MAX) - string str = ((isSqlType) && (!typeChanged)) ? ((SqlString)value).Value : ((string)value); + string str = typeChanged + ? (string)objValue + : isSqlType + ? GenericConverter.Convert(value).Value + : GenericConverter.Convert(value) + ; + int maxStringLength = length / 2; if (str.Length > maxStringLength) { @@ -1762,8 +1892,7 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re } break; case TdsEnums.SQLVARIANT: - value = ValidateBulkCopyVariant(value); - typeChanged = true; + typeChanged = ValidateBulkCopyVariantIfNeeded(value, out objValue); break; case TdsEnums.SQLUDT: // UDTs are sent as varbinary so we need to get the raw bytes @@ -1774,33 +1903,25 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re // in byte[] form. if (!(value is byte[])) { - value = _connection.GetBytes(value); + objValue = _connection.GetBytes(value); typeChanged = true; } break; case TdsEnums.SQLXMLTYPE: - // Could be either string, SqlCachedBuffer, XmlReader or XmlDataFeed + // Could be either string, SqlCachedBuffer, XmlReader or XmlDataFeed Debug.Assert((value is XmlReader) || (value is SqlCachedBuffer) || (value is string) || (value is SqlString) || (value is XmlDataFeed), "Invalid value type of Xml datatype"); - if (value is XmlReader) + if (value is XmlReader xmlReader) { - value = new XmlDataFeed((XmlReader)value); + objValue = new XmlDataFeed(xmlReader); typeChanged = true; coercedToDataFeed = true; } break; default: - Debug.Assert(false, "Unknown TdsType!" + type.NullableType.ToString("x2", (IFormatProvider)null)); + Debug.Fail("Unknown TdsType!" + type.NullableType.ToString("x2", (IFormatProvider)null)); throw SQL.BulkLoadCannotConvertValue(value.GetType(), type, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), null); } - - if (typeChanged) - { - // All type changes change to CLR types - isSqlType = false; - } - - return value; } catch (Exception e) { @@ -1810,6 +1931,17 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re } throw SQL.BulkLoadCannotConvertValue(value.GetType(), type, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), e); } + + if (typeChanged) + { + // All type changes change to CLR types + isSqlType = false; + return WriteConvertedValue(objValue, col, isSqlType, isNull, coercedToDataFeed, metadata); + } + else + { + return WriteConvertedValue(value, col, isSqlType, isNull, coercedToDataFeed, metadata); + } } /// @@ -1842,7 +1974,7 @@ public void WriteToServer(DbDataReader reader) _dataTableSource = null; _rowSourceType = ValueSourceType.DbDataReader; _isAsyncBulkCopy = false; - WriteRowSourceToServerAsync(reader.FieldCount, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false; + WriteRowSourceToServerAsync(reader.FieldCount, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false; } finally { @@ -1879,7 +2011,7 @@ public void WriteToServer(IDataReader reader) _dataTableSource = null; _rowSourceType = ValueSourceType.IDataReader; _isAsyncBulkCopy = false; - WriteRowSourceToServerAsync(reader.FieldCount, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false; + WriteRowSourceToServerAsync(reader.FieldCount, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false; } finally { @@ -1920,7 +2052,7 @@ public void WriteToServer(DataTable table, DataRowState rowState) _rowEnumerator = table.Rows.GetEnumerator(); _isAsyncBulkCopy = false; - WriteRowSourceToServerAsync(table.Columns.Count, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false; + WriteRowSourceToServerAsync(table.Columns.Count, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false; } finally { @@ -1964,7 +2096,7 @@ public void WriteToServer(DataRow[] rows) _rowEnumerator = rows.GetEnumerator(); _isAsyncBulkCopy = false; - WriteRowSourceToServerAsync(table.Columns.Count, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false; + WriteRowSourceToServerAsync(table.Columns.Count, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false; } finally { @@ -2024,7 +2156,7 @@ public Task WriteToServerAsync(DataRow[] rows, CancellationToken cancellationTok _rowSourceType = ValueSourceType.RowArray; _rowEnumerator = rows.GetEnumerator(); _isAsyncBulkCopy = true; - resultTask = WriteRowSourceToServerAsync(table.Columns.Count, cancellationToken); //It returns Task since _isAsyncBulkCopy = true; + resultTask = WriteRowSourceToServerAsync(table.Columns.Count, cancellationToken); //It returns Task since _isAsyncBulkCopy = true; } finally { @@ -2065,7 +2197,7 @@ public Task WriteToServerAsync(DbDataReader reader, CancellationToken cancellati _dataTableSource = null; _rowSourceType = ValueSourceType.DbDataReader; _isAsyncBulkCopy = true; - resultTask = WriteRowSourceToServerAsync(reader.FieldCount, cancellationToken); //It returns Task since _isAsyncBulkCopy = true; + resultTask = WriteRowSourceToServerAsync(reader.FieldCount, cancellationToken); //It returns Task since _isAsyncBulkCopy = true; } finally { @@ -2106,7 +2238,7 @@ public Task WriteToServerAsync(IDataReader reader, CancellationToken cancellatio _dataTableSource = null; _rowSourceType = ValueSourceType.IDataReader; _isAsyncBulkCopy = true; - resultTask = WriteRowSourceToServerAsync(reader.FieldCount, cancellationToken); //It returns Task since _isAsyncBulkCopy = true; + resultTask = WriteRowSourceToServerAsync(reader.FieldCount, cancellationToken); //It returns Task since _isAsyncBulkCopy = true; } finally { @@ -2160,7 +2292,7 @@ public Task WriteToServerAsync(DataTable table, DataRowState rowState, Cancellat _rowSourceType = ValueSourceType.DataTable; _rowEnumerator = table.Rows.GetEnumerator(); _isAsyncBulkCopy = true; - resultTask = WriteRowSourceToServerAsync(table.Columns.Count, cancellationToken); //It returns Task since _isAsyncBulkCopy = true; + resultTask = WriteRowSourceToServerAsync(table.Columns.Count, cancellationToken); //It returns Task since _isAsyncBulkCopy = true; } finally { @@ -2169,7 +2301,7 @@ public Task WriteToServerAsync(DataTable table, DataRowState rowState, Cancellat return resultTask; } - // Writes row source. + // Writes row source. // private Task WriteRowSourceToServerAsync(int columnCount, CancellationToken ctoken) { @@ -2427,34 +2559,40 @@ private bool FireRowsCopiedEvent(long rowsCopied) return eventArgs.Abort; } - // Reads a cell and then writes it. - // Read may block at this moment since there is no getValueAsync or DownStream async at this moment. - // When _isAsyncBulkCopy == true: Write will return Task (when async method runs asynchronously) or Null (when async call actually ran synchronously) for performance. - // When _isAsyncBulkCopy == false: Writes are purely sync. This method reutrn null at the end. - // - private Task ReadWriteColumnValueAsync(int col) + private Task WriteValueAsync(T value, int col, bool isSqlType, bool isDataFeed, bool isNull) { - bool isSqlType; - bool isDataFeed; - bool isNull; - Object value = GetValueFromSourceRow(col, out isSqlType, out isDataFeed, out isNull); //this will return Task/null in future: as rTask - _SqlMetaData metadata = _sortedColumnMappings[col]._metadata; - if (!isDataFeed) + if (isDataFeed) { - value = ConvertValue(value, metadata, isNull, ref isSqlType, out isDataFeed); + //nothing to convert, skip straight to write + return DoWriteValueAsync(value, col, isSqlType, isDataFeed, isNull, metadata); + } + else + { + return ConvertWriteValueAsync(value, col, metadata, isNull, isSqlType); + } + } - // If column encryption is requested via connection string option, perform encryption here - if (!isNull && // if value is not NULL - metadata.isEncrypted) - { // If we are transparently encrypting - Debug.Assert(_parser.ShouldEncryptValuesForBulkCopy()); - value = _parser.EncryptColumnValue(value, metadata, metadata.column, _stateObj, isDataFeed, isSqlType); - isSqlType = false; // Its not a sql type anymore - } + private Task WriteConvertedValue(T value, int col, bool isSqlType, bool isNull, bool isDatafeed, _SqlMetaData metadata) + { + // If column encryption is requested via connection string option, perform encryption here + if (!isNull && // if value is not NULL + metadata.isEncrypted) + { // If we are transparently encrypting + Debug.Assert(_parser.ShouldEncryptValuesForBulkCopy()); + var bytesValue = _parser.EncryptColumnValue(value, metadata, metadata.column, _stateObj, isDatafeed, isSqlType); + isSqlType = false; // Its not a sql type anymore + + return DoWriteValueAsync(bytesValue, col, isSqlType, isDatafeed, isNull, metadata); } + else + { + return DoWriteValueAsync(value, col, isSqlType, isDatafeed, isNull, metadata); + } + } - //write part + private Task DoWriteValueAsync(T value, int col, bool isSqlType, bool isDataFeed, bool isNull, _SqlMetaData metadata) + { Task writeTask = null; if (metadata.type != SqlDbType.Variant) { @@ -2473,15 +2611,15 @@ private Task ReadWriteColumnValueAsync(int col) if (variantInternalType == SqlBuffer.StorageType.DateTime2) { - _parser.WriteSqlVariantDateTime2(((DateTime)value), _stateObj); + _parser.WriteSqlVariantDateTime2(GenericConverter.Convert(value), _stateObj); } else if (variantInternalType == SqlBuffer.StorageType.Date) { - _parser.WriteSqlVariantDate(((DateTime)value), _stateObj); + _parser.WriteSqlVariantDate(GenericConverter.Convert(value), _stateObj); } else { - writeTask = _parser.WriteSqlVariantDataRowValue(value, _stateObj); //returns Task/Null + writeTask = _parser.WriteSqlVariantDataRowValue(value, isNull, _stateObj); //returns Task/Null } } @@ -2563,7 +2701,7 @@ private void CopyColumnsAsyncSetupContinuation(TaskCompletionSource sour } - // The notification logic. + // The notification logic. // private void CheckAndRaiseNotification() { @@ -2636,11 +2774,11 @@ private void CheckAndRaiseNotification() Debug.Assert(writeTask == null, "Task should not pend while doing sync bulk copy"); RunParser(); AbortTransaction(); - throw exception; //this will be caught and put inside the Task's exception. + throw exception; //this will be caught and put inside the Task's exception. } } - // Checks for cancellation. If cancel requested, cancels the task and returns the cancelled task + // Checks for cancellation. If cancel requested, cancels the task and returns the cancelled task Task CheckForCancellation(CancellationToken cts, TaskCompletionSource tcs) { if (cts.IsCancellationRequested) @@ -2688,7 +2826,7 @@ private Task CopyRowsAsync(int rowsSoFar, int totalRows, CancellationToken cts, int i; try { - //totalRows is batchsize which is 0 by default. In that case, we keep copying till the end (until _hasMoreRowToCopy == false). + //totalRows is batchsize which is 0 by default. In that case, we keep copying till the end (until _hasMoreRowToCopy == false). for (i = rowsSoFar; (totalRows <= 0 || i < totalRows) && _hasMoreRowToCopy == true; i++) { if (_isAsyncBulkCopy == true) @@ -2705,10 +2843,10 @@ private Task CopyRowsAsync(int rowsSoFar, int totalRows, CancellationToken cts, task = CopyColumnsAsync(0); //copy 1 row if (task == null) - { //tsk is done. + { //tsk is done. CheckAndRaiseNotification(); //check notification logic after copying the row - //now we will read the next row. + //now we will read the next row. Task readTask = ReadFromRowSourceAsync(cts); // read the next row. Caution: more is only valid if the task returns null. Otherwise, we wait for Task.Result if (readTask != null) { @@ -3046,7 +3184,7 @@ private void CleanUpStateObject(bool isCancelRequested = true) // The continuation part of WriteToServerInternalRest. Executes when the initial query task is completed. (see, WriteToServerInternalRest). // It carries on the source which is passed from the WriteToServerInternalRest and performs SetResult when the entire copy is done. // The carried on source may be null in case of Sync copy. So no need to SetResult at that time. - // It launches the copy operation. + // It launches the copy operation. // private void WriteToServerInternalRestContinuedAsync(BulkCopySimpleResultSet internalResults, CancellationToken cts, TaskCompletionSource source) { @@ -3081,7 +3219,7 @@ private void WriteToServerInternalRestContinuedAsync(BulkCopySimpleResultSet int } AsyncHelper.ContinueTask(task, source, () => { - //Bulk copy task is completed at this moment. + //Bulk copy task is completed at this moment. //Todo: The cases may be combined for code reuse. if (task.IsCanceled) { @@ -3167,7 +3305,7 @@ private void WriteToServerInternalRestContinuedAsync(BulkCopySimpleResultSet int } } - // Rest of the WriteToServerInternalAsync method. + // Rest of the WriteToServerInternalAsync method. // It carries on the source from its caller WriteToServerInternal. // source is null in case of Sync bcp. But valid in case of Async bcp. // It calls the WriteToServerInternalRestContinuedAsync as a continuation of the initial query task. @@ -3213,7 +3351,7 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio regReconnectCancel = cts.Register(() => cancellableReconnectTS.TrySetCanceled()); } AsyncHelper.ContinueTask(reconnectTask, cancellableReconnectTS, () => { cancellableReconnectTS.SetResult(null); }); - // no need to cancel timer since SqlBulkCopy creates specific task source for reconnection + // no need to cancel timer since SqlBulkCopy creates specific task source for reconnection AsyncHelper.SetTimeoutException(cancellableReconnectTS, BulkCopyTimeout, () => { return SQL.BulkLoadInvalidDestinationTable(_destinationTableName, SQL.CR_ReconnectTimeout()); }, CancellationToken.None); AsyncHelper.ContinueTask(cancellableReconnectTS.Task, source, @@ -3256,7 +3394,7 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio _connection.AddWeakReference(this, SqlReferenceCollection.BulkCopyTag); } - internalConnection.ThreadHasParserLockForClose = true; // In case of error, let the connection know that we already have the parser lock + internalConnection.ThreadHasParserLockForClose = true; // In case of error, let the connection know that we already have the parser lock try { diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlParameter.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlParameter.cs index 12a9257fbd..78ca6f2363 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlParameter.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlParameter.cs @@ -16,10 +16,8 @@ using Microsoft.Data.Common; using Microsoft.Data.SqlClient.Server; - namespace Microsoft.Data.SqlClient { - internal abstract class DataFeed { } @@ -109,7 +107,7 @@ internal SqlCipherMetadata CipherMetadata /// /// Indicates if the parameter encryption metadata received by sp_describe_parameter_encryption. - /// For unencrypted parameters, the encryption metadata should still be sent (and will indicate + /// For unencrypted parameters, the encryption metadata should still be sent (and will indicate /// that no encryption is needed). /// internal bool HasReceivedMetadata { get; set; } @@ -461,7 +459,7 @@ internal SmiParameterMetaData MetaDataForSmi(out ParameterPeekAheadValue peekAhe long actualLen = GetActualSize(); long maxLen = this.Size; - // GetActualSize returns bytes length, but smi expects char length for + // GetActualSize returns bytes length, but smi expects char length for // character types, so adjust if (!mt.IsLong) { @@ -782,10 +780,10 @@ public SqlDbType SqlDbType { MetaType metatype = _metaType; // HACK!!! - // We didn't want to expose SmallVarBinary on SqlDbType so we - // stuck it at the end of SqlDbType in v1.0, except that now + // We didn't want to expose SmallVarBinary on SqlDbType so we + // stuck it at the end of SqlDbType in v1.0, except that now // we have new data types after that and it's smack dab in the - // middle of the valid range. To prevent folks from setting + // middle of the valid range. To prevent folks from setting // this invalid value we have to have this code here until we // can take the time to fix it later. if ((SqlDbType)TdsEnums.SmallVarBinary == value) @@ -1088,14 +1086,25 @@ object ICloneable.Clone() // Coerced Value is also used in SqlBulkCopy.ConvertValue(object value, _SqlMetaData metadata) internal static object CoerceValue(object value, MetaType destinationType, out bool coercedToDataFeed, out bool typeChanged, bool allowStreaming = true) + { + typeChanged = CoerceValueIfNeeded(value, destinationType, out var objValue, out coercedToDataFeed, allowStreaming); + + return typeChanged ? objValue : value; + } + + internal static bool CoerceValueIfNeeded(T value, MetaType destinationType, out object objValue, out bool coercedToDataFeed, bool allowStreaming = true) { Debug.Assert(!(value is DataFeed), "Value provided should not already be a data feed"); Debug.Assert(!ADP.IsNull(value), "Value provided should not be null"); Debug.Assert(null != destinationType, "null destinationType"); coercedToDataFeed = false; - typeChanged = false; - Type currentType = value.GetType(); + objValue = null; + Type currentType = typeof(T) == typeof(object) + ? value.GetType() // only call GetType if we know boxing has already occurred. + : typeof(T); + + var typeChanged = false; if ((typeof(object) != destinationType.ClassType) && (currentType != destinationType.ClassType) && @@ -1110,45 +1119,45 @@ internal static object CoerceValue(object value, MetaType destinationType, out b // For Xml data, destination Type is always string if (typeof(SqlXml) == currentType) { - value = MetaType.GetStringFromXml((XmlReader)(((SqlXml)value).CreateReader())); + objValue = MetaType.GetStringFromXml(GenericConverter.Convert(value).CreateReader()); } else if (typeof(SqlString) == currentType) { typeChanged = false; // Do nothing } - else if (typeof(XmlReader).IsAssignableFrom(currentType)) + else if (value is XmlReader xmlReader) { if (allowStreaming) { coercedToDataFeed = true; - value = new XmlDataFeed((XmlReader)value); + objValue = new XmlDataFeed(xmlReader); } else { - value = MetaType.GetStringFromXml((XmlReader)value); + objValue = MetaType.GetStringFromXml(xmlReader); } } else if (typeof(char[]) == currentType) { - value = new string((char[])value); + objValue = new string(GenericConverter.Convert(value)); } else if (typeof(SqlChars) == currentType) { - value = new string(((SqlChars)value).Value); + objValue = new string(GenericConverter.Convert(value).Value); } - else if (value is TextReader && allowStreaming) + else if (value is TextReader tr && allowStreaming) { coercedToDataFeed = true; - value = new TextDataFeed((TextReader)value); + objValue = new TextDataFeed(tr); } else { - value = Convert.ChangeType(value, destinationType.ClassType, (IFormatProvider)null); + objValue = Convert.ChangeType(value, destinationType.ClassType, (IFormatProvider)null); } } else if ((DbType.Currency == destinationType.DbType) && (typeof(string) == currentType)) { - value = Decimal.Parse((string)value, NumberStyles.Currency, (IFormatProvider)null); // WebData 99376 + objValue = decimal.Parse(GenericConverter.Convert(value), NumberStyles.Currency, (IFormatProvider)null); } else if ((typeof(SqlBytes) == currentType) && (typeof(byte[]) == destinationType.ClassType)) { @@ -1156,32 +1165,32 @@ internal static object CoerceValue(object value, MetaType destinationType, out b } else if ((typeof(string) == currentType) && (SqlDbType.Time == destinationType.SqlDbType)) { - value = TimeSpan.Parse((string)value); + objValue = TimeSpan.Parse(GenericConverter.Convert(value)); } else if ((typeof(string) == currentType) && (SqlDbType.DateTimeOffset == destinationType.SqlDbType)) { - value = DateTimeOffset.Parse((string)value, (IFormatProvider)null); + objValue = DateTimeOffset.Parse(GenericConverter.Convert(value), (IFormatProvider)null); } else if ((typeof(DateTime) == currentType) && (SqlDbType.DateTimeOffset == destinationType.SqlDbType)) { - value = new DateTimeOffset((DateTime)value); + objValue = new DateTimeOffset(GenericConverter.Convert(value)); } - else if (TdsEnums.SQLTABLE == destinationType.TDSType && - (value is DataTable || + else if (TdsEnums.SQLTABLE == destinationType.TDSType && ( + value is DataTable || value is DbDataReader || value is System.Collections.Generic.IEnumerable)) { // no conversion for TVPs. typeChanged = false; } - else if (destinationType.ClassType == typeof(byte[]) && value is Stream && allowStreaming) + else if (destinationType.ClassType == typeof(byte[]) && allowStreaming && value is Stream stream) { coercedToDataFeed = true; - value = new StreamDataFeed((Stream)value); + objValue = new StreamDataFeed(stream); } else { - value = Convert.ChangeType(value, destinationType.ClassType, (IFormatProvider)null); + objValue = Convert.ChangeType(value, destinationType.ClassType, (IFormatProvider)null); } } catch (Exception e) @@ -1197,8 +1206,8 @@ value is DbDataReader || } Debug.Assert(allowStreaming || !coercedToDataFeed, "Streaming is not allowed, but type was coerced into a data feed"); - Debug.Assert(value.GetType() == currentType ^ typeChanged, "Incorrect value for typeChanged"); - return value; + Debug.Assert(objValue == null || objValue.GetType() == currentType ^ typeChanged, "Incorrect value for typeChanged"); + return typeChanged; } internal void FixStreamDataForNonPLP() @@ -1791,7 +1800,7 @@ internal void Validate(int index, bool isCommandProc) MetaType metaType = GetMetaTypeOnly(); _internalMetaType = metaType; - // NOTE: (General Criteria): SqlParameter does a Size Validation check and would fail if the size is 0. + // NOTE: (General Criteria): SqlParameter does a Size Validation check and would fail if the size is 0. // This condition filters all scenarios where we view a valid size 0. if (ADP.IsDirection(this, ParameterDirection.Output) && !ADP.IsDirection(this, ParameterDirection.ReturnValue) && // SQL BU DT 372370 @@ -1864,15 +1873,15 @@ internal MetaType ValidateTypeLengths(bool yukonOrNewer) // Bug: VSTFDevDiv #636867 // Notes: - // 'actualSizeInBytes' is the size of value passed; + // 'actualSizeInBytes' is the size of value passed; // 'sizeInCharacters' is the parameter size; - // 'actualSizeInBytes' is in bytes; - // 'this.Size' is in charaters; - // 'sizeInCharacters' is in characters; + // 'actualSizeInBytes' is in bytes; + // 'this.Size' is in charaters; + // 'sizeInCharacters' is in characters; // 'TdsEnums.TYPE_SIZE_LIMIT' is in bytes; // For Non-NCharType and for non-Yukon or greater variables, size should be maintained; // Reverting changes from bug VSTFDevDiv # 479739 as it caused an regression; - // Modifed variable names from 'size' to 'sizeInCharacters', 'actualSize' to 'actualSizeInBytes', and + // Modifed variable names from 'size' to 'sizeInCharacters', 'actualSize' to 'actualSizeInBytes', and // 'maxSize' to 'maxSizeInBytes' // The idea is to // 1) revert the regression from bug 479739 diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index 9c7bdee4c4..4b404fd00f 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -695,7 +695,7 @@ internal void Connect(ServerInfo serverInfo, } // Retrieve the IP and port number from native SNI for TCP protocol. The IP information is stored temporarily in the - // pendingSQLDNSObject but not in the DNS Cache at this point. We only add items to the DNS Cache after we receive the + // pendingSQLDNSObject but not in the DNS Cache at this point. We only add items to the DNS Cache after we receive the // IsSupported flag as true in the feature ext ack from server. internal void AssignPendingDNSInfo(string userProtocol, string DNSCacheKey) { @@ -7422,12 +7422,12 @@ internal Task WriteSqlVariantValue(object value, int length, int offset, TdsPars // Therefore the sql_variant value must not include the MaxLength. This is the major difference // between this method and WriteSqlVariantValue above. // - internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject stateObj, bool canAccumulate = true) + internal Task WriteSqlVariantDataRowValue(T value, bool isNull, TdsParserStateObject stateObj, bool canAccumulate = true) { Debug.Assert(_isShiloh == true, "Shouldn't be dealing with sql_variant in pre-SQL2000 server!"); // handle null values - if ((null == value) || (DBNull.Value == value)) + if (isNull) { WriteInt(TdsEnums.FIXEDNULL, stateObj); return null; @@ -7438,44 +7438,44 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta if (metatype.IsAnsiType) { - length = GetEncodingCharLength((string)value, length, 0, _defaultEncoding); + length = GetEncodingCharLength(GenericConverter.Convert(value), length, 0, _defaultEncoding); } switch (metatype.TDSType) { case TdsEnums.SQLFLT4: WriteSqlVariantHeader(6, metatype.TDSType, metatype.PropBytes, stateObj); - WriteFloat((Single)value, stateObj); + WriteFloat(GenericConverter.Convert(value), stateObj); break; case TdsEnums.SQLFLT8: WriteSqlVariantHeader(10, metatype.TDSType, metatype.PropBytes, stateObj); - WriteDouble((Double)value, stateObj); + WriteDouble(GenericConverter.Convert(value), stateObj); break; case TdsEnums.SQLINT8: WriteSqlVariantHeader(10, metatype.TDSType, metatype.PropBytes, stateObj); - WriteLong((Int64)value, stateObj); + WriteLong(GenericConverter.Convert(value), stateObj); break; case TdsEnums.SQLINT4: WriteSqlVariantHeader(6, metatype.TDSType, metatype.PropBytes, stateObj); - WriteInt((Int32)value, stateObj); + WriteInt(GenericConverter.Convert(value), stateObj); break; case TdsEnums.SQLINT2: WriteSqlVariantHeader(4, metatype.TDSType, metatype.PropBytes, stateObj); - WriteShort((Int16)value, stateObj); + WriteShort(GenericConverter.Convert(value), stateObj); break; case TdsEnums.SQLINT1: WriteSqlVariantHeader(3, metatype.TDSType, metatype.PropBytes, stateObj); - stateObj.WriteByte((byte)value); + stateObj.WriteByte(GenericConverter.Convert(value)); break; case TdsEnums.SQLBIT: WriteSqlVariantHeader(3, metatype.TDSType, metatype.PropBytes, stateObj); - if ((bool)value == true) + if (GenericConverter.Convert(value)) stateObj.WriteByte(1); else stateObj.WriteByte(0); @@ -7484,7 +7484,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLBIGVARBINARY: { - byte[] b = (byte[])value; + byte[] b = GenericConverter.Convert(value); length = b.Length; WriteSqlVariantHeader(4 + length, metatype.TDSType, metatype.PropBytes, stateObj); @@ -7494,7 +7494,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLBIGVARCHAR: { - string s = (string)value; + string s = GenericConverter.Convert(value); length = s.Length; WriteSqlVariantHeader(9 + length, metatype.TDSType, metatype.PropBytes, stateObj); @@ -7506,7 +7506,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLUNIQUEID: { - System.Guid guid = (System.Guid)value; + Guid guid = GenericConverter.Convert(value); byte[] b = guid.ToByteArray(); length = b.Length; @@ -7518,7 +7518,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLNVARCHAR: { - string s = (string)value; + string s = GenericConverter.Convert(value); length = s.Length * 2; WriteSqlVariantHeader(9 + length, metatype.TDSType, metatype.PropBytes, stateObj); @@ -7533,7 +7533,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLDATETIME: { - TdsDateTime dt = MetaType.FromDateTime((DateTime)value, 8); + TdsDateTime dt = MetaType.FromDateTime(GenericConverter.Convert(value), 8); WriteSqlVariantHeader(10, metatype.TDSType, metatype.PropBytes, stateObj); WriteInt(dt.days, stateObj); @@ -7544,7 +7544,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLMONEY: { WriteSqlVariantHeader(10, metatype.TDSType, metatype.PropBytes, stateObj); - WriteCurrency((Decimal)value, 8, stateObj); + WriteCurrency(GenericConverter.Convert(value), 8, stateObj); break; } @@ -7552,21 +7552,22 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta { WriteSqlVariantHeader(21, metatype.TDSType, metatype.PropBytes, stateObj); stateObj.WriteByte(metatype.Precision); //propbytes: precision - stateObj.WriteByte((byte)((Decimal.GetBits((Decimal)value)[3] & 0x00ff0000) >> 0x10)); // propbytes: scale - WriteDecimal((Decimal)value, stateObj); + var decValue = GenericConverter.Convert(value); + stateObj.WriteByte((byte)((decimal.GetBits(decValue)[3] & 0x00ff0000) >> 0x10)); // propbytes: scale + WriteDecimal(decValue, stateObj); break; } case TdsEnums.SQLTIME: WriteSqlVariantHeader(8, metatype.TDSType, metatype.PropBytes, stateObj); stateObj.WriteByte(metatype.Scale); //propbytes: scale - WriteTime((TimeSpan)value, metatype.Scale, 5, stateObj); + WriteTime(GenericConverter.Convert(value), metatype.Scale, 5, stateObj); break; case TdsEnums.SQLDATETIMEOFFSET: WriteSqlVariantHeader(13, metatype.TDSType, metatype.PropBytes, stateObj); stateObj.WriteByte(metatype.Scale); //propbytes: scale - WriteDateTimeOffset((DateTimeOffset)value, metatype.Scale, 10, stateObj); + WriteDateTimeOffset(GenericConverter.Convert(value), metatype.Scale, 10, stateObj); break; default: @@ -11321,7 +11322,7 @@ internal bool ShouldEncryptValuesForBulkCopy() /// Encrypts a column value (for SqlBulkCopy) /// /// - internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, string column, TdsParserStateObject stateObj, bool isDataFeed, bool isSqlType) + internal byte[] EncryptColumnValue(T value, SqlMetaDataPriv metadata, string column, TdsParserStateObject stateObj, bool isDataFeed, bool isSqlType) { Debug.Assert(_serverSupportsColumnEncryption, "Server doesn't support encryption, yet we received encryption metadata"); Debug.Assert(ShouldEncryptValuesForBulkCopy(), "Encryption attempted when not requested"); @@ -11346,10 +11347,14 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin // when we normalize and serialize the data buffers. The serialization routine expects us // to report the size of data to be copied out (for serialization). If we underreport the // size, truncation will happen for us! - actualLengthInBytes = (isSqlType) ? ((SqlBinary)value).Length : ((byte[])value).Length; + actualLengthInBytes = (isSqlType) + ? GenericConverter.Convert(value).Length + : GenericConverter.Convert(value).Length; + if (metadata.baseTI.length > 0 && actualLengthInBytes > metadata.baseTI.length) - { // see comments agove + { + // see comments above actualLengthInBytes = metadata.baseTI.length; } break; @@ -11365,7 +11370,10 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin ThrowUnsupportedCollationEncountered(null); // stateObject only when reading } - string stringValue = (isSqlType) ? ((SqlString)value).Value : (string)value; + string stringValue = (isSqlType) + ? GenericConverter.Convert(value).Value + : GenericConverter.Convert(value); + actualLengthInBytes = _defaultEncoding.GetByteCount(stringValue); // If the string length is > max length, then use the max length (see comments above) @@ -11379,7 +11387,10 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin case TdsEnums.SQLNCHAR: case TdsEnums.SQLNVARCHAR: case TdsEnums.SQLNTEXT: - actualLengthInBytes = ((isSqlType) ? ((SqlString)value).Value.Length : ((string)value).Length) * 2; + actualLengthInBytes = (isSqlType + ? GenericConverter.Convert(value).Value.Length + : GenericConverter.Convert(value).Length) + * 2; if (metadata.baseTI.length > 0 && actualLengthInBytes > metadata.baseTI.length) @@ -11424,7 +11435,7 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin _connHandler.ConnectionOptions.DataSource); } - internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsParserStateObject stateObj, bool isSqlType, bool isDataFeed, bool isNull) + internal Task WriteBulkCopyValue(T value, SqlMetaDataPriv metadata, TdsParserStateObject stateObj, bool isSqlType, bool isDataFeed, bool isNull) { Debug.Assert(!isSqlType || value is INullable, "isSqlType is true, but value can not be type cast to an INullable"); Debug.Assert(!isDataFeed ^ value is DataFeed, "Incorrect value for isDataFeed"); @@ -11489,10 +11500,12 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars case TdsEnums.SQLBIGVARBINARY: case TdsEnums.SQLIMAGE: case TdsEnums.SQLUDT: - ccb = (isSqlType) ? ((SqlBinary)value).Length : ((byte[])value).Length; + ccb = (isSqlType) + ? GenericConverter.Convert(value).Length + : GenericConverter.Convert(value).Length; break; case TdsEnums.SQLUNIQUEID: - ccb = GUID_SIZE; // that's a constant for guid + ccb = GUID_SIZE; break; case TdsEnums.SQLBIGCHAR: case TdsEnums.SQLBIGVARCHAR: @@ -11505,11 +11518,11 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars string stringValue = null; if (isSqlType) { - stringValue = ((SqlString)value).Value; + stringValue = GenericConverter.Convert(value).Value; } else { - stringValue = (string)value; + stringValue = GenericConverter.Convert(value); } ccb = stringValue.Length; @@ -11518,15 +11531,22 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars case TdsEnums.SQLNCHAR: case TdsEnums.SQLNVARCHAR: case TdsEnums.SQLNTEXT: - ccb = ((isSqlType) ? ((SqlString)value).Value.Length : ((string)value).Length) * 2; + ccb = (isSqlType + ? GenericConverter.Convert(value).Value.Length + : GenericConverter.Convert(value).Length + ) * 2; break; case TdsEnums.SQLXMLTYPE: // Value here could be string or XmlReader - if (value is XmlReader) + // the XmlReader scenario can only occur when T is object (enforced during SqlBulkCopy.ReadWriteColumnValueAsync) + if (typeof(T) == typeof(object) && value is XmlReader xr) { - value = MetaType.GetStringFromXml((XmlReader)value); + value = GenericConverter.Convert(MetaType.GetStringFromXml(xr)); } - ccb = ((isSqlType) ? ((SqlString)value).Value.Length : ((string)value).Length) * 2; + ccb = (isSqlType + ? GenericConverter.Convert(value).Value.Length + : GenericConverter.Convert(value).Length + ) * 2; break; default: @@ -11578,7 +11598,9 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars } else if (metatype.SqlDbType != SqlDbType.Udt || metatype.IsLong) { + // we only have to consider a conversion from above in this case. internalWriteTask = WriteValue(value, metatype, metadata.scale, ccb, ccbStringBytes, 0, stateObj, metadata.length, isDataFeed); + if ((internalWriteTask == null) && (_asyncWrite)) { internalWriteTask = stateObj.WaitForAccumulatedWrites(); @@ -11588,7 +11610,7 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars else { WriteShort(ccb, stateObj); - internalWriteTask = stateObj.WriteByteArray((byte[])value, ccb, 0); + internalWriteTask = stateObj.WriteByteArray(GenericConverter.Convert(value), ccb, 0); } #if DEBUG @@ -11924,7 +11946,7 @@ private bool IsBOMNeeded(MetaType type, object value) return false; } - private Task GetTerminationTask(Task unterminatedWriteTask, object value, MetaType type, int actualLength, TdsParserStateObject stateObj, bool isDataFeed) + private Task GetTerminationTask(Task unterminatedWriteTask, MetaType type, int actualLength, TdsParserStateObject stateObj, bool isDataFeed) { if (type.IsPlp && ((actualLength > 0) || isDataFeed)) { @@ -11947,16 +11969,16 @@ private Task GetTerminationTask(Task unterminatedWriteTask, object value, MetaTy } - private Task WriteSqlValue(object value, MetaType type, int actualLength, int codePageByteSize, int offset, TdsParserStateObject stateObj) + private Task WriteSqlValue(T value, MetaType type, int actualLength, int codePageByteSize, int offset, TdsParserStateObject stateObj) { return GetTerminationTask( WriteUnterminatedSqlValue(value, type, actualLength, codePageByteSize, offset, stateObj), - value, type, actualLength, stateObj, false); + type, actualLength, stateObj, false); } // For MAX types, this method can only write everything in one big chunk. If multiple // chunk writes needed, please use WritePlpBytes/WritePlpChars - private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLength, int codePageByteSize, int offset, TdsParserStateObject stateObj) + private Task WriteUnterminatedSqlValue(T value, MetaType type, int actualLength, int codePageByteSize, int offset, TdsParserStateObject stateObj) { Debug.Assert(((type.NullableType == TdsEnums.SQLXMLTYPE) || (value is INullable && !((INullable)value).IsNull)), @@ -11967,11 +11989,11 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe { case TdsEnums.SQLFLTN: if (type.FixedLength == 4) - WriteFloat(((SqlSingle)value).Value, stateObj); + WriteFloat(GenericConverter.Convert(value).Value, stateObj); else { Debug.Assert(type.FixedLength == 8, "Invalid length for SqlDouble type!"); - WriteDouble(((SqlDouble)value).Value, stateObj); + WriteDouble(GenericConverter.Convert(value).Value, stateObj); } break; @@ -11987,18 +12009,18 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe if (value is SqlBinary) { - return stateObj.WriteByteArray(((SqlBinary)value).Value, actualLength, offset, canAccumulate: false); + return stateObj.WriteByteArray(GenericConverter.Convert(value).Value, actualLength, offset, canAccumulate: false); } else { Debug.Assert(value is SqlBytes); - return stateObj.WriteByteArray(((SqlBytes)value).Value, actualLength, offset, canAccumulate: false); + return stateObj.WriteByteArray(GenericConverter.Convert(value).Value, actualLength, offset, canAccumulate: false); } } case TdsEnums.SQLUNIQUEID: { - byte[] b = ((SqlGuid)value).ToByteArray(); + byte[] b = GenericConverter.Convert(value).ToByteArray(); Debug.Assert((actualLength == b.Length) && (actualLength == 16), "Invalid length for guid type in com+ object"); stateObj.WriteByteArray(b, actualLength, 0); @@ -12008,7 +12030,7 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe case TdsEnums.SQLBITN: { Debug.Assert(type.FixedLength == 1, "Invalid length for SqlBoolean type"); - if (((SqlBoolean)value).Value == true) + if (GenericConverter.Convert(value).Value == true) stateObj.WriteByte(1); else stateObj.WriteByte(0); @@ -12018,17 +12040,17 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe case TdsEnums.SQLINTN: if (type.FixedLength == 1) - stateObj.WriteByte(((SqlByte)value).Value); + stateObj.WriteByte(GenericConverter.Convert(value).Value); else if (type.FixedLength == 2) - WriteShort(((SqlInt16)value).Value, stateObj); + WriteShort(GenericConverter.Convert(value).Value, stateObj); else if (type.FixedLength == 4) - WriteInt(((SqlInt32)value).Value, stateObj); + WriteInt(GenericConverter.Convert(value).Value, stateObj); else { Debug.Assert(type.FixedLength == 8, "invalid length for SqlIntN type: " + type.FixedLength.ToString(CultureInfo.InvariantCulture)); - WriteLong(((SqlInt64)value).Value, stateObj); + WriteLong(GenericConverter.Convert(value).Value, stateObj); } break; @@ -12040,16 +12062,16 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe { WriteInt(codePageByteSize, stateObj); // chunk length } - if (value is System.Data.SqlTypes.SqlChars) + if (value is SqlChars) { - String sch = new String(((System.Data.SqlTypes.SqlChars)value).Value); + string sch = new string(GenericConverter.Convert(value).Value); return WriteEncodingChar(sch, actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false); } else { Debug.Assert(value is SqlString); - return WriteEncodingChar(((SqlString)value).Value, actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false); + return WriteEncodingChar(GenericConverter.Convert(value).Value, actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false); } @@ -12076,27 +12098,27 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe if (actualLength != 0) actualLength >>= 1; - if (value is System.Data.SqlTypes.SqlChars) + if (value is SqlChars) { - return WriteCharArray(((System.Data.SqlTypes.SqlChars)value).Value, actualLength, offset, stateObj, canAccumulate: false); + return WriteCharArray(GenericConverter.Convert(value).Value, actualLength, offset, stateObj, canAccumulate: false); } else { Debug.Assert(value is SqlString); - return WriteString(((SqlString)value).Value, actualLength, offset, stateObj, canAccumulate: false); + return WriteString(GenericConverter.Convert(value).Value, actualLength, offset, stateObj, canAccumulate: false); } case TdsEnums.SQLNUMERICN: Debug.Assert(type.FixedLength <= 17, "Decimal length cannot be greater than 17 bytes"); - WriteSqlDecimal((SqlDecimal)value, stateObj); + WriteSqlDecimal(GenericConverter.Convert(value), stateObj); break; case TdsEnums.SQLDATETIMN: - SqlDateTime dt = (SqlDateTime)value; + SqlDateTime dt = GenericConverter.Convert(value); if (type.FixedLength == 4) { - if (0 > dt.DayTicks || dt.DayTicks > UInt16.MaxValue) + if (0 > dt.DayTicks || dt.DayTicks > ushort.MaxValue) throw SQL.SmallDateTimeOverflow(dt.ToString()); WriteShort(dt.DayTicks, stateObj); @@ -12112,12 +12134,12 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe case TdsEnums.SQLMONEYN: { - WriteSqlMoney((SqlMoney)value, type.FixedLength, stateObj); + WriteSqlMoney(GenericConverter.Convert(value), type.FixedLength, stateObj); break; } case TdsEnums.SQLUDT: - Debug.Assert(false, "Called WriteSqlValue on UDT param.Should have already been handled"); + Debug.Fail("Called WriteSqlValue on UDT param.Should have already been handled"); throw SQL.UDTUnexpectedResult(value.GetType().AssemblyQualifiedName); default: @@ -12622,28 +12644,28 @@ private Task NullIfCompletedWriteTask(Task task) } } - private Task WriteValue(object value, MetaType type, byte scale, int actualLength, int encodingByteSize, int offset, TdsParserStateObject stateObj, int paramSize, bool isDataFeed) + private Task WriteValue(T value, MetaType type, byte scale, int actualLength, int encodingByteSize, int offset, TdsParserStateObject stateObj, int paramSize, bool isDataFeed) { return GetTerminationTask(WriteUnterminatedValue(value, type, scale, actualLength, encodingByteSize, offset, stateObj, paramSize, isDataFeed), - value, type, actualLength, stateObj, isDataFeed); + type, actualLength, stateObj, isDataFeed); } // For MAX types, this method can only write everything in one big chunk. If multiple // chunk writes needed, please use WritePlpBytes/WritePlpChars - private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int actualLength, int encodingByteSize, int offset, TdsParserStateObject stateObj, int paramSize, bool isDataFeed) + private Task WriteUnterminatedValue(T value, MetaType type, byte scale, int actualLength, int encodingByteSize, int offset, TdsParserStateObject stateObj, int paramSize, bool isDataFeed) { - Debug.Assert((null != value) && (DBNull.Value != value), "unexpected missing or empty object"); + Debug.Assert((null != value) && !(value is DBNull), "unexpected missing or empty object"); // parameters are always sent over as BIG or N types switch (type.NullableType) { case TdsEnums.SQLFLTN: if (type.FixedLength == 4) - WriteFloat((Single)value, stateObj); + WriteFloat(GenericConverter.Convert(value), stateObj); else { Debug.Assert(type.FixedLength == 8, "Invalid length for SqlDouble type!"); - WriteDouble((Double)value, stateObj); + WriteDouble(GenericConverter.Convert(value), stateObj); } break; @@ -12660,7 +12682,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int if (isDataFeed) { Debug.Assert(type.IsPlp, "Stream assigned to non-PLP was not converted!"); - return NullIfCompletedWriteTask(WriteStreamFeed((StreamDataFeed)value, stateObj, paramSize)); + return NullIfCompletedWriteTask(WriteStreamFeed(GenericConverter.Convert(value), stateObj, paramSize)); } else { @@ -12668,14 +12690,13 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int { WriteInt(actualLength, stateObj); // chunk length } - - return stateObj.WriteByteArray((byte[])value, actualLength, offset, canAccumulate: false); + return stateObj.WriteByteArray(GenericConverter.Convert(value), actualLength, offset, canAccumulate: false); } } case TdsEnums.SQLUNIQUEID: { - System.Guid guid = (System.Guid)value; + System.Guid guid = GenericConverter.Convert(value); byte[] b = guid.ToByteArray(); Debug.Assert((actualLength == b.Length) && (actualLength == 16), "Invalid length for guid type in com+ object"); @@ -12686,7 +12707,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int case TdsEnums.SQLBITN: { Debug.Assert(type.FixedLength == 1, "Invalid length for SqlBoolean type"); - if ((bool)value == true) + if (GenericConverter.Convert(value) == true) stateObj.WriteByte(1); else stateObj.WriteByte(0); @@ -12696,15 +12717,15 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int case TdsEnums.SQLINTN: if (type.FixedLength == 1) - stateObj.WriteByte((byte)value); + stateObj.WriteByte(GenericConverter.Convert(value)); else if (type.FixedLength == 2) - WriteShort((Int16)value, stateObj); + WriteShort(GenericConverter.Convert(value), stateObj); else if (type.FixedLength == 4) - WriteInt((Int32)value, stateObj); + WriteInt(GenericConverter.Convert(value), stateObj); else { Debug.Assert(type.FixedLength == 8, "invalid length for SqlIntN type: " + type.FixedLength.ToString(CultureInfo.InvariantCulture)); - WriteLong((Int64)value, stateObj); + WriteLong(GenericConverter.Convert(value), stateObj); } break; @@ -12722,7 +12743,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int TextDataFeed tdf = value as TextDataFeed; if (tdf == null) { - return NullIfCompletedWriteTask(WriteXmlFeed((XmlDataFeed)value, stateObj, needBom: true, encoding: _defaultEncoding, size: paramSize)); + return NullIfCompletedWriteTask(WriteXmlFeed(GenericConverter.Convert(value), stateObj, needBom: true, encoding: _defaultEncoding, size: paramSize)); } else { @@ -12737,11 +12758,11 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int } if (value is byte[]) { // If LazyMat non-filled blob, send cookie rather than value - return stateObj.WriteByteArray((byte[])value, actualLength, 0, canAccumulate: false); + return stateObj.WriteByteArray(GenericConverter.Convert(value), actualLength, 0, canAccumulate: false); } else { - return WriteEncodingChar((string)value, actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false); + return WriteEncodingChar(GenericConverter.Convert(value), actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false); } } } @@ -12759,7 +12780,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int TextDataFeed tdf = value as TextDataFeed; if (tdf == null) { - return NullIfCompletedWriteTask(WriteXmlFeed((XmlDataFeed)value, stateObj, IsBOMNeeded(type, value), Encoding.Unicode, paramSize)); + return NullIfCompletedWriteTask(WriteXmlFeed(GenericConverter.Convert(value), stateObj, IsBOMNeeded(type, value), Encoding.Unicode, paramSize)); } else { @@ -12782,29 +12803,29 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int } if (value is byte[]) { // If LazyMat non-filled blob, send cookie rather than value - return stateObj.WriteByteArray((byte[])value, actualLength, 0, canAccumulate: false); + return stateObj.WriteByteArray(GenericConverter.Convert(value), actualLength, 0, canAccumulate: false); } else { // convert to cchars instead of cbytes actualLength >>= 1; - return WriteString((string)value, actualLength, offset, stateObj, canAccumulate: false); + return WriteString(GenericConverter.Convert(value), actualLength, offset, stateObj, canAccumulate: false); } } } case TdsEnums.SQLNUMERICN: Debug.Assert(type.FixedLength <= 17, "Decimal length cannot be greater than 17 bytes"); - WriteDecimal((Decimal)value, stateObj); + WriteDecimal(GenericConverter.Convert(value), stateObj); break; case TdsEnums.SQLDATETIMN: Debug.Assert(type.FixedLength <= 0xff, "Invalid Fixed Length"); - TdsDateTime dt = MetaType.FromDateTime((DateTime)value, (byte)type.FixedLength); + TdsDateTime dt = MetaType.FromDateTime(GenericConverter.Convert(value), (byte)type.FixedLength); if (type.FixedLength == 4) { - if (0 > dt.days || dt.days > UInt16.MaxValue) + if (0 > dt.days || dt.days > ushort.MaxValue) throw SQL.SmallDateTimeOverflow(MetaType.ToDateTime(dt.days, dt.time, 4).ToString(CultureInfo.InvariantCulture)); WriteShort(dt.days, stateObj); @@ -12820,13 +12841,13 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int case TdsEnums.SQLMONEYN: { - WriteCurrency((Decimal)value, type.FixedLength, stateObj); + WriteCurrency(GenericConverter.Convert(value), type.FixedLength, stateObj); break; } case TdsEnums.SQLDATE: { - WriteDate((DateTime)value, stateObj); + WriteDate(GenericConverter.Convert(value), stateObj); break; } @@ -12835,7 +12856,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int { throw SQL.TimeScaleValueOutOfRange(scale); } - WriteTime((TimeSpan)value, scale, actualLength, stateObj); + WriteTime(GenericConverter.Convert(value), scale, actualLength, stateObj); break; case TdsEnums.SQLDATETIME2: @@ -12843,11 +12864,11 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int { throw SQL.TimeScaleValueOutOfRange(scale); } - WriteDateTime2((DateTime)value, scale, actualLength, stateObj); + WriteDateTime2(GenericConverter.Convert(value), scale, actualLength, stateObj); break; case TdsEnums.SQLDATETIMEOFFSET: - WriteDateTimeOffset((DateTimeOffset)value, scale, actualLength, stateObj); + WriteDateTimeOffset(GenericConverter.Convert(value), scale, actualLength, stateObj); break; default: @@ -13092,7 +13113,7 @@ private byte[] SerializeUnencryptedValue(object value, MetaType type, byte scale // For MAX types, this method can only write everything in one big chunk. If multiple // chunk writes needed, please use WritePlpBytes/WritePlpChars - private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int actualLength, int offset, byte normalizationVersion, TdsParserStateObject stateObj) + private byte[] SerializeUnencryptedSqlValue(T value, MetaType type, int actualLength, int offset, byte normalizationVersion, TdsParserStateObject stateObj) { Debug.Assert(((type.NullableType == TdsEnums.SQLXMLTYPE) || (value is INullable && !((INullable)value).IsNull)), @@ -13108,11 +13129,13 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act { case TdsEnums.SQLFLTN: if (type.FixedLength == 4) - return SerializeFloat(((SqlSingle)value).Value); + { + return SerializeFloat(GenericConverter.Convert(value).Value); + } else { Debug.Assert(type.FixedLength == 8, "Invalid length for SqlDouble type!"); - return SerializeDouble(((SqlDouble)value).Value); + return SerializeDouble(GenericConverter.Convert(value).Value); } case TdsEnums.SQLBIGBINARY: @@ -13123,19 +13146,20 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act if (value is SqlBinary) { - Buffer.BlockCopy(((SqlBinary)value).Value, offset, b, 0, actualLength); + Buffer.BlockCopy(GenericConverter.Convert(value).Value, offset, b, 0, actualLength); } else { Debug.Assert(value is SqlBytes); - Buffer.BlockCopy(((SqlBytes)value).Value, offset, b, 0, actualLength); + Buffer.BlockCopy(GenericConverter.Convert(value).Value, offset, b, 0, actualLength); } + return b; } case TdsEnums.SQLUNIQUEID: { - byte[] b = ((SqlGuid)value).ToByteArray(); + byte[] b = GenericConverter.Convert(value).ToByteArray(); Debug.Assert((actualLength == b.Length) && (actualLength == 16), "Invalid length for guid type in com+ object"); return b; @@ -13146,37 +13170,37 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act Debug.Assert(type.FixedLength == 1, "Invalid length for SqlBoolean type"); // We normalize to allow conversion across data types. BIT is serialized into a BIGINT. - return SerializeLong(((SqlBoolean)value).Value == true ? 1 : 0, stateObj); + return SerializeLong(GenericConverter.Convert(value).Value == true ? 1 : 0, stateObj); } case TdsEnums.SQLINTN: // We normalize to allow conversion across data types. All data types below are serialized into a BIGINT. if (type.FixedLength == 1) - return SerializeLong(((SqlByte)value).Value, stateObj); + return SerializeLong(GenericConverter.Convert(value).Value, stateObj); if (type.FixedLength == 2) - return SerializeLong(((SqlInt16)value).Value, stateObj); + return SerializeLong(GenericConverter.Convert(value).Value, stateObj); if (type.FixedLength == 4) - return SerializeLong(((SqlInt32)value).Value, stateObj); + return SerializeLong(GenericConverter.Convert(value).Value, stateObj); else { Debug.Assert(type.FixedLength == 8, "invalid length for SqlIntN type: " + type.FixedLength.ToString(CultureInfo.InvariantCulture)); - return SerializeLong(((SqlInt64)value).Value, stateObj); + return SerializeLong(GenericConverter.Convert(value).Value, stateObj); } case TdsEnums.SQLBIGCHAR: case TdsEnums.SQLBIGVARCHAR: case TdsEnums.SQLTEXT: - if (value is System.Data.SqlTypes.SqlChars) + if (value is SqlChars) { - String sch = new String(((System.Data.SqlTypes.SqlChars)value).Value); + String sch = new String(GenericConverter.Convert(value).Value); return SerializeEncodingChar(sch, actualLength, offset, _defaultEncoding); } else { Debug.Assert(value is SqlString); - return SerializeEncodingChar(((SqlString)value).Value, actualLength, offset, _defaultEncoding); + return SerializeEncodingChar(GenericConverter.Convert(value).Value, actualLength, offset, _defaultEncoding); } @@ -13189,22 +13213,22 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act if (actualLength != 0) actualLength >>= 1; - if (value is System.Data.SqlTypes.SqlChars) + if (value is SqlChars) { - return SerializeCharArray(((System.Data.SqlTypes.SqlChars)value).Value, actualLength, offset); + return SerializeCharArray(GenericConverter.Convert(value).Value, actualLength, offset); } else { Debug.Assert(value is SqlString); - return SerializeString(((SqlString)value).Value, actualLength, offset); + return SerializeString(GenericConverter.Convert(value).Value, actualLength, offset); } case TdsEnums.SQLNUMERICN: Debug.Assert(type.FixedLength <= 17, "Decimal length cannot be greater than 17 bytes"); - return SerializeSqlDecimal((SqlDecimal)value, stateObj); + return SerializeSqlDecimal(GenericConverter.Convert(value), stateObj); case TdsEnums.SQLDATETIMN: - SqlDateTime dt = (SqlDateTime)value; + SqlDateTime dt = GenericConverter.Convert(value); if (type.FixedLength == 4) { @@ -13250,7 +13274,7 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act case TdsEnums.SQLMONEYN: { - return SerializeSqlMoney((SqlMoney)value, type.FixedLength, stateObj); + return SerializeSqlMoney(GenericConverter.Convert(value), type.FixedLength, stateObj); } default: diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj index 877791abe3..d9618adf94 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj @@ -54,7 +54,7 @@ - + Common\System\Collections\DictionaryExtensions.cs @@ -66,6 +66,7 @@ + @@ -274,6 +275,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/DataConversionErrorMessageTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/DataConversionErrorMessageTest.cs index 9b56183ccf..1bcc1fa83d 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/DataConversionErrorMessageTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/DataConversionErrorMessageTest.cs @@ -159,7 +159,7 @@ private bool StringToIntTest(SqlConnection cnn, string targetTable, SourceType s string expectedErrorMsg = string.Format(pattern, args); - Assert.True(ex.Message.Contains(expectedErrorMsg), "Unexpected error message: " + ex.Message); + Assert.True(ex.Message.Contains(expectedErrorMsg), $"Unexpected error message: {ex}"); // write out stack trace for unexpected messages hitException = true; } return hitException; diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValueTypes.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValueTypes.cs new file mode 100644 index 0000000000..8b1f619f15 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValueTypes.cs @@ -0,0 +1,429 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Data; +using System.Linq; +using System.Linq.Expressions; +using BenchmarkDotNet.Attributes; +using BenchmarkDotNet.Configs; +using BenchmarkDotNet.Diagnosers; +using BenchmarkDotNet.Jobs; +using BenchmarkDotNet.Running; +using BenchmarkDotNet.Validators; +using Xunit; + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests +{ + public class NoBoxingValueTypes : IDisposable + { + private static readonly string _table = DataTestUtility.GetUniqueNameForSqlServer(nameof(NoBoxingValueTypes)); + private const int _count = 5000; + private static readonly ItemToCopy _item; + private static readonly IEnumerable _items; + private static readonly IDataReader _reader; + + private static readonly string _connString = DataTestUtility.TCPConnectionString; + + private class ItemToCopy + { + // keeping this data static so the performance of the benchmark is not varied by the data size & shape + public int IntColumn { get; } = 123456; + public bool BoolColumn { get; } = true; + } + + static NoBoxingValueTypes() + { + _item = new ItemToCopy(); + + _items = Enumerable.Range(0, _count).Select(x => _item).ToArray(); + + _reader = new EnumerableDataReaderFactoryBuilder(_table) + .Add("IntColumn", i => i.IntColumn) + .Add("BoolColumn", i => i.BoolColumn) + .BuildFactory() + .CreateReader(_items) + ; + } + + public NoBoxingValueTypes() + { + using (var conn = new SqlConnection(_connString)) + using (var cmd = conn.CreateCommand()) + { + conn.Open(); + Helpers.TryExecute(cmd, $@" + CREATE TABLE {_table} ( + IntColumn INT NOT NULL, + BoolColumn BIT NOT NULL + ) + "); + } + } + + private class RunOnceConfig : ManualConfig + { + public RunOnceConfig() + { + Add(Job.InProcess.WithLaunchCount(1).WithIterationCount(1).WithWarmupCount(0)); + Add(MemoryDiagnoser.Default); + + Add(JitOptimizationsValidator.DontFailOnError); + } + } + + + + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureServer))] + public void Should_Not_Box() + { // in debug mode, the double boxing DOES occur as the JIT optimizes less code, which causes the test to fail +#if DEBUG + return; +#else + //cannot figure out an easy way to get this to work on all platforms + + var config = new RunOnceConfig(); // cannot use fluent syntax to still support net461 + + var summary = BenchmarkRunner.Run(config); + + var numValueTypeColumns = 2; + var totalBytesWhenBoxed = IntPtr.Size * _count * numValueTypeColumns; + + var report = summary.Reports.First(); + + Assert.Equal(1, report.AllMeasurements.Count); + Assert.True(report.GcStats.BytesAllocatedPerOperation < totalBytesWhenBoxed); +#endif + } + + public class NoBoxingValueTypesBenchmark + { + [Benchmark] + public void BulkCopy() + { + _reader.Close(); // this resets the reader + + using (var bc = new SqlBulkCopy(DataTestUtility.TCPConnectionString, SqlBulkCopyOptions.TableLock)) + { + bc.BatchSize = _count; + bc.DestinationTableName = _table; + bc.BulkCopyTimeout = 60; + + bc.WriteToServer(_reader); + } + } + } + + public void Dispose() + { + using (var conn = new SqlConnection(_connString)) + using (var cmd = conn.CreateCommand()) + { + conn.Open(); + Helpers.TryExecute(cmd, $@" + DROP TABLE IF EXISTS {_table} + "); + } + } + + //all code here and below is a custom data reader implementation to support the benchmark + private class EnumerableDataReaderFactoryBuilder + { + private readonly List _expressions = new List(); + private readonly List> _objExpressions = new List>(); + private readonly DataTable _schemaTable; + + public EnumerableDataReaderFactoryBuilder(string tableName) + { + Name = tableName; + _schemaTable = new DataTable(); + } + + private static readonly HashSet _validTypes = new HashSet + { + typeof(decimal), + typeof(decimal?), + typeof(string), + typeof(int), + typeof(int?), + typeof(double), + typeof(bool), + typeof(bool?), + typeof(Guid), + typeof(DateTime), + }; + + public EnumerableDataReaderFactoryBuilder Add(string column, Expression> expression) + { + var t = typeof(TColumn); + + var func = expression.Compile(); + + // don't do any optimizations for boxing bools here to detect boxing occurring properly. + Expression> objExpression= o => func(o); + + _objExpressions.Add(objExpression.Compile()); + + if (_validTypes.Contains(t)) + { + t = Nullable.GetUnderlyingType(t) ?? t; // data table doesn't accept nullable. + _schemaTable.Columns.Add(column, t); + _expressions.Add(expression); + } + else + { + Console.WriteLine($"Could not matching return type for {Name}.{column} of: {t.Name}"); + _schemaTable.Columns.Add(column); //add w/o type to force using GetValue + + _expressions.Add(objExpression); + } + + return this; + } + + public EnumerableDataReaderFactory BuildFactory() => new EnumerableDataReaderFactory(_schemaTable, _expressions, _objExpressions); + + public string Name { get; } + } + + public class EnumerableDataReaderFactory + { + public DataTable SchemaTable { get; } + public Func[] ObjectGetters { get; } + public Func[] DecimalGetters { get; } + public Func[] NullableDecimalGetters { get; } + public Func[] StringGetters { get; } + public Func[] DoubleGetters { get; } + public Func[] IntGetters { get; } + public Func[] NullableIntGetters { get; } + public Func[] BoolGetters { get; } + + public Func[] NullableBoolGetters { get; } + + public Func[] GuidGetters { get; } + public Func[] DateTimeGetters { get; } + public bool[] NullableIndexes { get; } + + public EnumerableDataReaderFactory(DataTable schemaTable, List expressions, List> objectGetters) + { + SchemaTable = schemaTable; + DecimalGetters = new Func[expressions.Count]; + NullableDecimalGetters = new Func[expressions.Count]; + StringGetters = new Func[expressions.Count]; + DoubleGetters = new Func[expressions.Count]; + IntGetters = new Func[expressions.Count]; + NullableIntGetters = new Func[expressions.Count]; + BoolGetters = new Func[expressions.Count]; + NullableBoolGetters = new Func[expressions.Count]; + GuidGetters = new Func[expressions.Count]; + DateTimeGetters = new Func[expressions.Count]; + NullableIndexes = new bool[expressions.Count]; + + ObjectGetters = objectGetters.ToArray(); + + for (int i = 0; i < expressions.Count; i++) + { + var expression = expressions[i]; + + NullableIndexes[i] = !expression.ReturnType.IsValueType || Nullable.GetUnderlyingType(expression.ReturnType) != null; + + switch (expression) + { + case Expression> e: + break; // do nothing + case Expression> e: + DecimalGetters[i] = e.Compile(); + break; + case Expression> e: + NullableDecimalGetters[i] = e.Compile(); + break; + case Expression> e: + StringGetters[i] = e.Compile(); + break; + case Expression> e: + DoubleGetters[i] = e.Compile(); + break; + case Expression> e: + IntGetters[i] = e.Compile(); + break; + case Expression> e: + NullableIntGetters[i] = e.Compile(); + break; + case Expression> e: + BoolGetters[i] = e.Compile(); + break; + case Expression> e: + NullableBoolGetters[i] = e.Compile(); + break; + case Expression> e: + GuidGetters[i] = e.Compile(); + break; + case Expression> e: + DateTimeGetters[i] = e.Compile(); + break; + default: + throw new Exception($"Type missing: {expression.GetType().FullName}"); + } + } + } + + public IDataReader CreateReader(IEnumerable items) => new EnumerableDataReader(this, items.GetEnumerator()); + } + + public class EnumerableDataReader : IDataReader + { + private readonly IEnumerator _source; + private readonly EnumerableDataReaderFactory _context; + + public EnumerableDataReader(EnumerableDataReaderFactory context, IEnumerator source) + { + _source = source; + _context = context; + } + + public object GetValue(int i) + { + var v = _context.ObjectGetters[i](_source.Current); + return v; + } + + public int FieldCount => _context.ObjectGetters.Length; + + public bool Read() => _source.MoveNext(); + + public void Close() => _source.Reset(); + + public void Dispose() => this.Close(); + + public bool NextResult() => throw new NotImplementedException(); + + public int Depth => 0; + + public bool IsClosed => false; + + public int RecordsAffected => -1; + + public DataTable GetSchemaTable() => _context.SchemaTable; + + public object this[string name] => throw new NotImplementedException(); + + public object this[int i] => GetValue(i); + + public bool GetBoolean(int i) + { + var g = _context.BoolGetters[i]; + + if (g != null) + return g(_source.Current); + + return _context.NullableBoolGetters[i](_source.Current).Value; + } + + public byte GetByte(int i) => throw new NotImplementedException(); + + public long GetBytes(int i, long fieldOffset, byte[] buffer, int bufferoffset, int length) => throw new NotImplementedException(); + + public char GetChar(int i) => throw new NotImplementedException(); + public long GetChars(int i, long fieldoffset, char[] buffer, int bufferoffset, int length) => -1; + + public IDataReader GetData(int i) => throw new NotImplementedException(); + + public string GetDataTypeName(int i) => throw new NotImplementedException(); + + public DateTime GetDateTime(int i) => _context.DateTimeGetters[i](_source.Current); + + public decimal GetDecimal(int i) + { + var g = _context.DecimalGetters[i]; + + if (g != null) + return g(_source.Current); + + return _context.NullableDecimalGetters[i](_source.Current).Value; + } + + public double GetDouble(int i) => _context.DoubleGetters[i](_source.Current); + + public Type GetFieldType(int i) => _context.SchemaTable.Columns[i].DataType; + + public float GetFloat(int i) => throw new NotImplementedException(); + + public Guid GetGuid(int i) => _context.GuidGetters[i](_source.Current); + + public short GetInt16(int i) => throw new NotImplementedException(); + + public int GetInt32(int i) + { + var g = _context.IntGetters[i]; + + if (g != null) + return g(_source.Current); + + return _context.NullableIntGetters[i](_source.Current).Value; + } + + public long GetInt64(int i) => throw new NotImplementedException(); + + public string GetName(int i) + { + if (_context.SchemaTable.Columns.Count > i) + { + return _context.SchemaTable.Columns[i].ColumnName; + } + throw new IndexOutOfRangeException($"No column for index {i}"); + } + + public int GetOrdinal(string name) + { + if (_context.SchemaTable.Columns.Count == 0) + { + throw new Exception("Schema table is empty"); + } + return _context.SchemaTable.Columns.IndexOf(name); + } + + public string GetString(int i) => _context.StringGetters[i](_source.Current); + + public int GetValues(object[] values) => throw new NotImplementedException(); + + public bool IsDBNull(int i) + { + // short circuit for non-nullable types + if (!_context.NullableIndexes[i]) + { + return false; + } + + // otherwise find the first one -- starting w/ most occurring to least + + var ig = _context.NullableIntGetters[i]; + if (ig != null) + { + return ig(_source.Current) == null; + } + + var sg = _context.StringGetters[i]; + if (sg != null) + { + return sg(_source.Current) == null; + } + + var bg = _context.NullableBoolGetters[i]; + if (bg != null) + { + return bg(_source.Current) == null; + } + + var dg = _context.NullableDecimalGetters[i]; + if (dg != null) + { + return dg(_source.Current) == null; + } + + return false; + } + } + } +} diff --git a/tools/props/Versions.props b/tools/props/Versions.props index 044d79f9d2..e9f4b9c0b9 100644 --- a/tools/props/Versions.props +++ b/tools/props/Versions.props @@ -12,7 +12,7 @@ - 2.1.0 + 2.1.0 4.3.1 4.3.0 @@ -26,6 +26,7 @@ + 0.11.3 4.7.0 2.1.0 4.7.0