diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index 4546d7d1ab..e8035d5990 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -806,6 +806,7 @@ + True True diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/GenericCastExtensions.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/GenericCastExtensions.cs new file mode 100644 index 0000000000..dd2cce4869 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/GenericCastExtensions.cs @@ -0,0 +1,17 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.Data.SqlClient +{ + /// + /// Serves to convert generic to out type by casting to object first. Relies on JIT to optimize out unneccessary casts and prevent double boxing. + /// + internal static class GenericCastExtensions + { + public static TOut GenericCast(this TIn value) + { + return (TOut)(object)value; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index b0dce3a6f9..843967781a 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -9,6 +9,7 @@ using System.Data.Common; using System.Data.SqlTypes; using System.Diagnostics; +using System.Runtime.CompilerServices; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -120,8 +121,8 @@ private enum ValueSourceType DbDataReader } - // Enum for specifying SqlDataReader.Get method used - private enum ValueMethod : byte + // Enum for specifying SqlDataReader.Get / IDataReader method used + private enum ValueMethod { GetValue, SqlTypeSqlDecimal, @@ -129,7 +130,19 @@ private enum ValueMethod : byte SqlTypeSqlSingle, DataFeedStream, DataFeedText, - DataFeedXml + DataFeedXml, + GetInt32, + GetString, + GetDouble, + GetDecimal, + GetInt16, + GetInt64, + GetChar, + GetByte, + GetBoolean, + GetDateTime, + GetGuid, + GetFloat } // Used to hold column metadata for SqlDataReader case @@ -873,11 +886,19 @@ private void Dispose(bool disposing) } // Unified method to read a value from the current row - private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out bool isDataFeed, out bool isNull) + // Reads a cell and then writes it. + // Read may block at this moment since there is no getValueAsync or DownStream async at this moment. + // When _isAsyncBulkCopy == true: Write will return Task (when async method runs asynchronously) or Null (when async call actually ran synchronously) for performance. + // When _isAsyncBulkCopy == false: Writes are purely sync. This method return null at the end. + private Task ReadWriteColumnValueAsync(int destRowIndex) { _SqlMetaData metadata = _sortedColumnMappings[destRowIndex]._metadata; int sourceOrdinal = _sortedColumnMappings[destRowIndex]._sourceColumnOrdinal; + bool isSqlType = false; + bool isDataFeed = false; + bool isNull = false; + switch (_rowSourceType) { case ValueSourceType.IDataReader: @@ -887,34 +908,40 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b { if (_dbDataReaderRowSource.IsDBNull(sourceOrdinal)) { - isSqlType = false; - isDataFeed = false; isNull = true; - return DBNull.Value; + return WriteValueAsync(DBNull.Value, destRowIndex, isSqlType, isDataFeed, isNull); } else { - isSqlType = false; isDataFeed = true; - isNull = false; + + object feedColumnValue; + switch (_currentRowMetadata[destRowIndex].Method) { case ValueMethod.DataFeedStream: - return new StreamDataFeed(_dbDataReaderRowSource.GetStream(sourceOrdinal)); + feedColumnValue = new StreamDataFeed(_dbDataReaderRowSource.GetStream(sourceOrdinal)); + + break; case ValueMethod.DataFeedText: - return new TextDataFeed(_dbDataReaderRowSource.GetTextReader(sourceOrdinal)); + feedColumnValue = new TextDataFeed(_dbDataReaderRowSource.GetTextReader(sourceOrdinal)); + break; case ValueMethod.DataFeedXml: // Only SqlDataReader supports an XmlReader // There is no GetXmlReader on DbDataReader, however if GetValue returns XmlReader we will read it as stream if it is assigned to XML field Debug.Assert(_sqlDataReaderRowSource != null, "Should not be reading row as an XmlReader if bulk copy source is not a SqlDataReader"); - return new XmlDataFeed(_sqlDataReaderRowSource.GetXmlReader(sourceOrdinal)); + feedColumnValue = new XmlDataFeed(_sqlDataReaderRowSource.GetXmlReader(sourceOrdinal)); + break; default: Debug.Fail($"Current column is marked as being a DataFeed, but no DataFeed compatible method was provided. Method: {_currentRowMetadata[destRowIndex].Method}"); isDataFeed = false; - object columnValue = _dbDataReaderRowSource.GetValue(sourceOrdinal); - ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType); - return columnValue; + feedColumnValue = _dbDataReaderRowSource.GetValue(sourceOrdinal); + ADP.IsNullOrSqlType(feedColumnValue, out isNull, out isSqlType); + break; } + + //specifically choosing to use the object overload here to simplify TdsParser logic for the XmlReader scenario + return WriteValueAsync(feedColumnValue, destRowIndex, isSqlType, isDataFeed, isNull); } } // SqlDataReader-specific logic @@ -922,36 +949,30 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b { if (_currentRowMetadata[destRowIndex].IsSqlType) { - INullable value; isSqlType = true; - isDataFeed = false; switch (_currentRowMetadata[destRowIndex].Method) { case ValueMethod.SqlTypeSqlDecimal: - value = _sqlDataReaderRowSource.GetSqlDecimal(sourceOrdinal); - break; + var value = _sqlDataReaderRowSource.GetSqlDecimal(sourceOrdinal); + return WriteValueAsync(value, destRowIndex, isSqlType, isDataFeed, value.IsNull); case ValueMethod.SqlTypeSqlDouble: // use cast to handle IsNull correctly because no public constructor allows it - value = (SqlDecimal)_sqlDataReaderRowSource.GetSqlDouble(sourceOrdinal); - break; + var dblValue = (SqlDecimal)_sqlDataReaderRowSource.GetSqlDouble(sourceOrdinal); + return WriteValueAsync(dblValue, destRowIndex, isSqlType, isDataFeed, dblValue.IsNull); case ValueMethod.SqlTypeSqlSingle: - // use cast to handle IsNull correctly because no public constructor allows it - value = (SqlDecimal)_sqlDataReaderRowSource.GetSqlSingle(sourceOrdinal); - break; + // use cast to handle value.IsNull correctly because no public constructor allows it + var singleValue = (SqlDecimal)_sqlDataReaderRowSource.GetSqlSingle(sourceOrdinal); + return WriteValueAsync(singleValue, destRowIndex, isSqlType, isDataFeed, singleValue.IsNull); default: Debug.Fail($"Current column is marked as being a SqlType, but no SqlType compatible method was provided. Method: {_currentRowMetadata[destRowIndex].Method}"); - value = (INullable)_sqlDataReaderRowSource.GetSqlValue(sourceOrdinal); - break; + var sqlValue = (INullable)_sqlDataReaderRowSource.GetSqlValue(sourceOrdinal); + return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, sqlValue.IsNull); } - isNull = value.IsNull; - return value; + } else { - isSqlType = false; - isDataFeed = false; - object value = _sqlDataReaderRowSource.GetValue(sourceOrdinal); isNull = ((value == null) || (value == DBNull.Value)); if ((!isNull) && (metadata.type == SqlDbType.Udt)) @@ -965,30 +986,58 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b Debug.Assert(!(value is INullable) || !((INullable)value).IsNull, "IsDBNull returned false, but GetValue returned a null INullable"); } #endif - return value; + return WriteValueAsync(value, destRowIndex, isSqlType, isDataFeed, isNull); } } else { - isDataFeed = false; - IDataReader rowSourceAsIDataReader = (IDataReader)_rowSource; - - // Only use IsDbNull when streaming is enabled and only for non-SqlDataReader - if ((_enableStreaming) && (_sqlDataReaderRowSource == null) && (rowSourceAsIDataReader.IsDBNull(sourceOrdinal))) + // previously, IsDbNull was only invoked in a non-streaming scenario with a non-SqlDataReader. + // based on the else if above, the non-SqlDataReader check was superfluous + // the new logic to not rely only on IDataReader.GetValue needs DbNull + // this could potentially be a breaking change to custom IDataReader implementations that incorrectly return IsDbNull + if (rowSourceAsIDataReader.IsDBNull(sourceOrdinal)) { - isSqlType = false; isNull = true; - return DBNull.Value; + return WriteValueAsync(DBNull.Value, destRowIndex, isSqlType, isDataFeed, isNull); } else { - object columnValue = rowSourceAsIDataReader.GetValue(sourceOrdinal); - ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType); - return columnValue; + switch (_currentRowMetadata[destRowIndex].Method) + { + case ValueMethod.GetInt32: + return WriteValueAsync(rowSourceAsIDataReader.GetInt32(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, false); + case ValueMethod.GetString: + var strValue = rowSourceAsIDataReader.GetString(sourceOrdinal); + isNull = strValue == null; + return WriteValueAsync(strValue, destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetDouble: + return WriteValueAsync(rowSourceAsIDataReader.GetDouble(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetDecimal: + return WriteValueAsync(rowSourceAsIDataReader.GetDecimal(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetInt16: + return WriteValueAsync(rowSourceAsIDataReader.GetInt16(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetInt64: + return WriteValueAsync(rowSourceAsIDataReader.GetInt64(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetChar: + return WriteValueAsync(rowSourceAsIDataReader.GetChar(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetByte: + return WriteValueAsync(rowSourceAsIDataReader.GetByte(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetBoolean: + return WriteValueAsync(rowSourceAsIDataReader.GetBoolean(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetDateTime: + return WriteValueAsync(rowSourceAsIDataReader.GetDateTime(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetGuid: + return WriteValueAsync(rowSourceAsIDataReader.GetGuid(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetFloat: + return WriteValueAsync(rowSourceAsIDataReader.GetFloat(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + default: + object columnValue = rowSourceAsIDataReader.GetValue(sourceOrdinal); + ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType); + return WriteValueAsync(columnValue, destRowIndex, isSqlType, isDataFeed, isNull); + } } } - case ValueSourceType.DataTable: case ValueSourceType.RowArray: { @@ -996,6 +1045,7 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b Debug.Assert(sourceOrdinal < _currentRowLength, "inconsistency of length of rows from rowsource!"); isDataFeed = false; + // unfortunately this has to be boxed due to DataRow's API. object currentRowValue = _currentRow[sourceOrdinal]; ADP.IsNullOrSqlType(currentRowValue, out isNull, out isSqlType); @@ -1008,7 +1058,8 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b { if (isSqlType) { - return new SqlDecimal(((SqlSingle)currentRowValue).Value); + var sqlDec = new SqlDecimal(((SqlSingle)currentRowValue).Value); + return WriteValueAsync(sqlDec, destRowIndex, isSqlType, isDataFeed, isNull); } else { @@ -1016,16 +1067,20 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b if (!float.IsNaN(f)) { isSqlType = true; - return new SqlDecimal(f); + return WriteValueAsync(new SqlDecimal(f), destRowIndex, isSqlType, isDataFeed, isNull); + } + else + { + return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull); } - break; } } case ValueMethod.SqlTypeSqlDouble: { if (isSqlType) { - return new SqlDecimal(((SqlDouble)currentRowValue).Value); + var sqlValue = new SqlDecimal(((SqlDouble)currentRowValue).Value); + return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, isNull); } else { @@ -1033,33 +1088,40 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b if (!double.IsNaN(d)) { isSqlType = true; - return new SqlDecimal(d); + return WriteValueAsync(new SqlDecimal(d), destRowIndex, isSqlType, isDataFeed, isNull); + } + else + { + return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull); } - break; } } case ValueMethod.SqlTypeSqlDecimal: { if (isSqlType) { - return (SqlDecimal)currentRowValue; + var sqlValue = (SqlDecimal)currentRowValue; + return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, isNull); } else { isSqlType = true; - return new SqlDecimal((decimal)currentRowValue); + var sqlValue = new SqlDecimal((decimal)currentRowValue); + return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, isNull); } } default: { Debug.Fail($"Current column is marked as being a SqlType, but no SqlType compatible method was provided. Method: {_currentRowMetadata[destRowIndex].Method}"); - break; + // If we are here then either the value is null, there was no special storage type for this column or the special storage type wasn't handled (e.g. if the currentRowValue is NaN) + return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull); } } } - - // If we are here then either the value is null, there was no special storage type for this column or the special storage type wasn't handled (e.g. if the currentRowValue is NaN) - return currentRowValue; + else + { + return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull); + } } default: { @@ -1260,6 +1322,66 @@ private SourceColumnMetadata GetColumnMetadata(int ordinal) method = ValueMethod.GetValue; } } + else if (_rowSourceType == ValueSourceType.IDataReader) + { + isSqlType = false; + isDataFeed = false; + + Type t = ((IDataReader)_rowSource).GetFieldType(ordinal); + + if (t == typeof(bool)) + { + method = ValueMethod.GetBoolean; + } + else if (t == typeof(byte)) + { + method = ValueMethod.GetByte; + } + else if (t == typeof(char)) + { + method = ValueMethod.GetChar; + } + else if (t == typeof(DateTime)) + { + method = ValueMethod.GetDateTime; + } + else if (t == typeof(decimal)) + { + method = ValueMethod.GetDecimal; + } + else if (t == typeof(double)) + { + method = ValueMethod.GetDouble; + } + else if (t == typeof(float)) + { + method = ValueMethod.GetFloat; + } + else if (t == typeof(Guid)) + { + method = ValueMethod.GetGuid; + } + else if (t == typeof(short)) + { + method = ValueMethod.GetInt16; + } + else if (t == typeof(int)) + { + method = ValueMethod.GetInt32; + } + else if (t == typeof(long)) + { + method = ValueMethod.GetInt64; + } + else if (t == typeof(string)) + { + method = ValueMethod.GetString; + } + else + { + method = ValueMethod.GetValue; + } + } else { isSqlType = false; @@ -1394,8 +1516,10 @@ private string UnquotedName(string name) return name; } - private object ValidateBulkCopyVariant(object value) + private bool ValidateBulkCopyVariantIfNeeded(T value, out object variantValue) { + variantValue = null; + // From the spec: // "The only acceptable types are ..." // GUID, BIGVARBINARY, BIGBINARY, BIGVARCHAR, BIGCHAR, NVARCHAR, NCHAR, BIT, INT1, INT2, INT4, INT8, @@ -1423,32 +1547,24 @@ private object ValidateBulkCopyVariant(object value) case TdsEnums.SQLDATETIMEOFFSET: if (value is INullable) { // Current limitation in the SqlBulkCopy Variant code limits BulkCopy to CLR/COM Types. - return MetaType.GetComValueFromSqlVariant(value); + variantValue = MetaType.GetComValueFromSqlVariant(value); + return true; } else { - return value; + return false; } default: throw SQL.BulkLoadInvalidVariantValue(); } } - private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, ref bool isSqlType, out bool coercedToDataFeed) + private bool ConvertValueIfNeeded(T value, _SqlMetaData metadata, ref bool isSqlType, out bool coercedToDataFeed, out object convertedValue) { - coercedToDataFeed = false; - - if (isNull) - { - if (!metadata.IsNullable) - { - throw SQL.BulkLoadBulkLoadNotAllowDBNull(metadata.column); - } - return value; - } - MetaType type = metadata.metaType; bool typeChanged = false; + coercedToDataFeed = false; + convertedValue = null; // If the column is encrypted then we are going to transparently encrypt this column // (based on connection string setting)- Use the metaType for the underlying @@ -1473,24 +1589,26 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re { case TdsEnums.SQLNUMERICN: case TdsEnums.SQLDECIMALN: - mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); - value = SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false); - - // Convert Source Decimal Precision and Scale to Destination Precision and Scale - // Sql decimal data could get corrupted on insert if the scale of - // the source and destination weren't the same. The BCP protocol, specifies the - // scale of the incoming data in the insert statement, we just tell the server we - // are inserting the same scale back. SqlDecimal sqlValue; - if ((isSqlType) && (!typeChanged)) + if (typeof(T) == typeof(decimal)) + { + sqlValue = new SqlDecimal(value.GenericCast()); + } + else if (typeof(T) == typeof(SqlDecimal)) { - sqlValue = (SqlDecimal)value; + sqlValue = value.GenericCast(); } else { - sqlValue = new SqlDecimal((decimal)value); + mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); + sqlValue = new SqlDecimal((decimal)SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false)); } + // Convert Source Decimal Precision and Scale to Destination Precision and Scale + // Sql decimal data could get corrupted on insert if the scale of + // the source and destination weren't the same. The BCP protocol, specifies the + // scale of the incoming data in the insert statement, we just tell the server we + // are inserting the same scale back. if (sqlValue.Scale != scale) { sqlValue = TdsParser.AdjustSqlDecimalScale(sqlValue, scale); @@ -1504,15 +1622,17 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re } catch (SqlTruncateException) { + mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); throw SQL.BulkLoadCannotConvertValue(value.GetType(), mt, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), ADP.ParameterValueOutOfRange(sqlValue)); } } // Perf: It is more efficient to write a SqlDecimal than a decimal since we need to break it into its 'bits' when writing - value = sqlValue; isSqlType = true; typeChanged = false; // Setting this to false as SqlParameter.CoerceValue will only set it to true when converting to a CLR type - break; + + // returning here to avoid unnecessary decValue initialization for all types + return typeChanged; case TdsEnums.SQLINTN: case TdsEnums.SQLFLTN: @@ -1536,16 +1656,22 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re case TdsEnums.SQLDATETIME2: case TdsEnums.SQLDATETIMEOFFSET: mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); - value = SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false); + typeChanged = SqlParameter.CoerceValueIfNeeded(value, mt, out convertedValue, out coercedToDataFeed); break; case TdsEnums.SQLNCHAR: case TdsEnums.SQLNVARCHAR: case TdsEnums.SQLNTEXT: mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); - value = SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false); + typeChanged = SqlParameter.CoerceValueIfNeeded(value, mt, out convertedValue, out coercedToDataFeed, false); if (!coercedToDataFeed) { // We do not need to test for TextDataFeed as it is only assigned to (N)VARCHAR(MAX) - string str = ((isSqlType) && (!typeChanged)) ? ((SqlString)value).Value : ((string)value); + string str = typeChanged + ? (string)convertedValue + : isSqlType + ? value.GenericCast().Value + : value.GenericCast() + ; + int maxStringLength = length / 2; if (str.Length > maxStringLength) { @@ -1564,8 +1690,7 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re } break; case TdsEnums.SQLVARIANT: - value = ValidateBulkCopyVariant(value); - typeChanged = true; + typeChanged = ValidateBulkCopyVariantIfNeeded(value, out convertedValue); break; case TdsEnums.SQLUDT: // UDTs are sent as varbinary so we need to get the raw bytes @@ -1576,16 +1701,16 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re // in byte[] form. if (!(value is byte[])) { - value = _connection.GetBytes(value); + convertedValue = _connection.GetBytes(value); typeChanged = true; } break; case TdsEnums.SQLXMLTYPE: // Could be either string, SqlCachedBuffer, XmlReader or XmlDataFeed Debug.Assert((value is XmlReader) || (value is SqlCachedBuffer) || (value is string) || (value is SqlString) || (value is XmlDataFeed), "Invalid value type of Xml datatype"); - if (value is XmlReader) + if (value is XmlReader xmlReader) { - value = new XmlDataFeed((XmlReader)value); + convertedValue = new XmlDataFeed(xmlReader); typeChanged = true; coercedToDataFeed = true; } @@ -1595,14 +1720,6 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re Debug.Fail("Unknown TdsType!" + type.NullableType.ToString("x2", (IFormatProvider)null)); throw SQL.BulkLoadCannotConvertValue(value.GetType(), type, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), null); } - - if (typeChanged) - { - // All type changes change to CLR types - isSqlType = false; - } - - return value; } catch (Exception e) { @@ -1612,6 +1729,8 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re } throw SQL.BulkLoadCannotConvertValue(value.GetType(), type, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), e); } + + return typeChanged; } /// @@ -2135,33 +2254,51 @@ private bool FireRowsCopiedEvent(long rowsCopied) return eventArgs.Abort; } - // Reads a cell and then writes it. - // Read may block at this moment since there is no getValueAsync or DownStream async at this moment. - // When _isAsyncBulkCopy == true: Write will return Task (when async method runs asynchronously) or Null (when async call actually ran synchronously) for performance. - // When _isAsyncBulkCopy == false: Writes are purely sync. This method return null at the end. - private Task ReadWriteColumnValueAsync(int col) + private Task WriteValueAsync(T value, int col, bool isSqlType, bool isDataFeed, bool isNull) { - bool isSqlType; - bool isDataFeed; - bool isNull; - object value = GetValueFromSourceRow(col, out isSqlType, out isDataFeed, out isNull); //this will return Task/null in future: as rTask - _SqlMetaData metadata = _sortedColumnMappings[col]._metadata; - if (!isDataFeed) + object convertedValue = null; + bool isTypeChanged = false; + if (isDataFeed) { - value = ConvertValue(value, metadata, isNull, ref isSqlType, out isDataFeed); + //nothing to convert, skip straight to write + } + else if (isNull) + { + if (!metadata.IsNullable) + { + throw SQL.BulkLoadBulkLoadNotAllowDBNull(metadata.column); + } + + // don't need to convert nulls + } + else + { + isTypeChanged = ConvertValueIfNeeded(value, metadata, ref isSqlType, out isDataFeed, out convertedValue); // If column encryption is requested via connection string option, perform encryption here - if (!isNull && // if value is not NULL - metadata.isEncrypted) - { // If we are transparently encrypting + if (metadata.isEncrypted) // If we are transparently encrypting + { Debug.Assert(_parser.ShouldEncryptValuesForBulkCopy()); - value = _parser.EncryptColumnValue(value, metadata, metadata.column, _stateObj, isDataFeed, isSqlType); + + convertedValue = isTypeChanged + ? _parser.EncryptColumnValue(convertedValue, metadata, metadata.column, _stateObj, isDataFeed, isSqlType) + : _parser.EncryptColumnValue(value, metadata, metadata.column, _stateObj, isDataFeed, isSqlType) + ; + + isTypeChanged = true; // we should use converted value from here on. isSqlType = false; // Its not a sql type anymore } } - //write part + return isTypeChanged + ? DoWriteValueAsync(convertedValue, col, isSqlType, isDataFeed, isNull, metadata) + : DoWriteValueAsync(value, col, isSqlType, isDataFeed, isNull, metadata) + ; + } + + private Task DoWriteValueAsync(T value, int col, bool isSqlType, bool isDataFeed, bool isNull, _SqlMetaData metadata) + { Task writeTask = null; if (metadata.type != SqlDbType.Variant) { @@ -2180,11 +2317,11 @@ private Task ReadWriteColumnValueAsync(int col) if (variantInternalType == SqlBuffer.StorageType.DateTime2) { - _parser.WriteSqlVariantDateTime2(((DateTime)value), _stateObj); + _parser.WriteSqlVariantDateTime2(value.GenericCast(), _stateObj); } else if (variantInternalType == SqlBuffer.StorageType.Date) { - _parser.WriteSqlVariantDate(((DateTime)value), _stateObj); + _parser.WriteSqlVariantDate(value.GenericCast(), _stateObj); } else { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlParameter.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlParameter.cs index 2bab7e4dbd..d75c27e9b7 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlParameter.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlParameter.cs @@ -2084,14 +2084,24 @@ private int ValueSizeCore(object value) // Coerced Value is also used in SqlBulkCopy.ConvertValue(object value, _SqlMetaData metadata) internal static object CoerceValue(object value, MetaType destinationType, out bool coercedToDataFeed, out bool typeChanged, bool allowStreaming = true) + { + typeChanged = CoerceValueIfNeeded(value, destinationType, out var objValue, out coercedToDataFeed, allowStreaming); + + return typeChanged ? objValue : value; + } + + internal static bool CoerceValueIfNeeded(T value, MetaType destinationType, out object objValue, out bool coercedToDataFeed, bool allowStreaming = true) { Debug.Assert(!(value is DataFeed), "Value provided should not already be a data feed"); Debug.Assert(!ADP.IsNull(value), "Value provided should not be null"); Debug.Assert(null != destinationType, "null destinationType"); + objValue = null; coercedToDataFeed = false; - typeChanged = false; - Type currentType = value.GetType(); + var typeChanged = false; + Type currentType = typeof(T) == typeof(object) + ? value.GetType() + : typeof(T); if ( (destinationType.ClassType != typeof(object)) && @@ -2111,45 +2121,45 @@ internal static object CoerceValue(object value, MetaType destinationType, out b // For Xml data, destination Type is always string if (currentType == typeof(SqlXml)) { - value = MetaType.GetStringFromXml((XmlReader)(((SqlXml)value).CreateReader())); + objValue = MetaType.GetStringFromXml(value.GenericCast().CreateReader()); } else if (currentType == typeof(SqlString)) { typeChanged = false; // Do nothing } - else if (typeof(XmlReader).IsAssignableFrom(currentType)) + else if (value is XmlReader xmlReader) { if (allowStreaming) { coercedToDataFeed = true; - value = new XmlDataFeed((XmlReader)value); + objValue = new XmlDataFeed(xmlReader); } else { - value = MetaType.GetStringFromXml((XmlReader)value); + objValue = MetaType.GetStringFromXml(xmlReader); } } else if (currentType == typeof(char[])) { - value = new string((char[])value); + objValue = new string(value.GenericCast()); } else if (currentType == typeof(SqlChars)) { - value = new string(((SqlChars)value).Value); + objValue = new string(value.GenericCast().Value); } else if (value is TextReader textReader && allowStreaming) { coercedToDataFeed = true; - value = new TextDataFeed(textReader); + objValue = new TextDataFeed(textReader); } else { - value = Convert.ChangeType(value, destinationType.ClassType, null); + objValue = Convert.ChangeType(value, destinationType.ClassType, null); } } else if ((destinationType.DbType == DbType.Currency) && (currentType == typeof(string))) { - value = decimal.Parse((string)value, NumberStyles.Currency, null); + objValue = decimal.Parse(value.GenericCast(), NumberStyles.Currency, null); } else if ((currentType == typeof(SqlBytes)) && (destinationType.ClassType == typeof(byte[]))) { @@ -2157,15 +2167,15 @@ internal static object CoerceValue(object value, MetaType destinationType, out b } else if ((currentType == typeof(string)) && (destinationType.SqlDbType == SqlDbType.Time)) { - value = TimeSpan.Parse((string)value); + objValue = TimeSpan.Parse(value.GenericCast()); } else if ((currentType == typeof(string)) && (destinationType.SqlDbType == SqlDbType.DateTimeOffset)) { - value = DateTimeOffset.Parse((string)value, (IFormatProvider)null); + objValue = DateTimeOffset.Parse(value.GenericCast(), (IFormatProvider)null); } else if ((currentType == typeof(DateTime)) && (destinationType.SqlDbType == SqlDbType.DateTimeOffset)) { - value = new DateTimeOffset((DateTime)value); + objValue = new DateTimeOffset(value.GenericCast()); } else if ( TdsEnums.SQLTABLE == destinationType.TDSType && @@ -2182,11 +2192,11 @@ value is IEnumerable else if (destinationType.ClassType == typeof(byte[]) && allowStreaming && value is Stream stream) { coercedToDataFeed = true; - value = new StreamDataFeed(stream); + objValue = new StreamDataFeed(stream); } else { - value = Convert.ChangeType(value, destinationType.ClassType, null); + objValue = Convert.ChangeType(value, destinationType.ClassType, null); } } catch (Exception e) @@ -2201,8 +2211,8 @@ value is IEnumerable } Debug.Assert(allowStreaming || !coercedToDataFeed, "Streaming is not allowed, but type was coerced into a data feed"); - Debug.Assert(value.GetType() == currentType ^ typeChanged, "Incorrect value for typeChanged"); - return value; + Debug.Assert(objValue == null || objValue.GetType() == currentType ^ typeChanged, "Incorrect value for typeChanged"); + return typeChanged; } private static int StringSize(object value, bool isSqlType) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index e0ebfdf669..78da1a5a03 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -6642,10 +6642,10 @@ internal Task WriteSqlVariantValue(object value, int length, int offset, TdsPars // Therefore the sql_variant value must not include the MaxLength. This is the major difference // between this method and WriteSqlVariantValue above. // - internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject stateObj, bool canAccumulate = true) + internal Task WriteSqlVariantDataRowValue(T value, TdsParserStateObject stateObj, bool canAccumulate = true) { // handle null values - if ((null == value) || (DBNull.Value == value)) + if (null == value || value is DBNull) { WriteInt(TdsEnums.FIXEDNULL, stateObj); return null; @@ -6656,44 +6656,44 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta if (metatype.IsAnsiType) { - length = GetEncodingCharLength((string)value, length, 0, _defaultEncoding); + length = GetEncodingCharLength(value.GenericCast(), length, 0, _defaultEncoding); } switch (metatype.TDSType) { case TdsEnums.SQLFLT4: WriteSqlVariantHeader(6, metatype.TDSType, metatype.PropBytes, stateObj); - WriteFloat((float)value, stateObj); + WriteFloat(value.GenericCast(), stateObj); break; case TdsEnums.SQLFLT8: WriteSqlVariantHeader(10, metatype.TDSType, metatype.PropBytes, stateObj); - WriteDouble((double)value, stateObj); + WriteDouble(value.GenericCast(), stateObj); break; case TdsEnums.SQLINT8: WriteSqlVariantHeader(10, metatype.TDSType, metatype.PropBytes, stateObj); - WriteLong((long)value, stateObj); + WriteLong(value.GenericCast(), stateObj); break; case TdsEnums.SQLINT4: WriteSqlVariantHeader(6, metatype.TDSType, metatype.PropBytes, stateObj); - WriteInt((int)value, stateObj); + WriteInt(value.GenericCast(), stateObj); break; case TdsEnums.SQLINT2: WriteSqlVariantHeader(4, metatype.TDSType, metatype.PropBytes, stateObj); - WriteShort((short)value, stateObj); + WriteShort(value.GenericCast(), stateObj); break; case TdsEnums.SQLINT1: WriteSqlVariantHeader(3, metatype.TDSType, metatype.PropBytes, stateObj); - stateObj.WriteByte((byte)value); + stateObj.WriteByte(value.GenericCast()); break; case TdsEnums.SQLBIT: WriteSqlVariantHeader(3, metatype.TDSType, metatype.PropBytes, stateObj); - if ((bool)value == true) + if (value.GenericCast()) stateObj.WriteByte(1); else stateObj.WriteByte(0); @@ -6702,7 +6702,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLBIGVARBINARY: { - byte[] b = (byte[])value; + byte[] b = value.GenericCast(); length = b.Length; WriteSqlVariantHeader(4 + length, metatype.TDSType, metatype.PropBytes, stateObj); @@ -6712,7 +6712,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLBIGVARCHAR: { - string s = (string)value; + string s = value.GenericCast(); length = s.Length; WriteSqlVariantHeader(9 + length, metatype.TDSType, metatype.PropBytes, stateObj); @@ -6724,7 +6724,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLUNIQUEID: { - System.Guid guid = (System.Guid)value; + Guid guid = value.GenericCast(); Span b = stackalloc byte[16]; FillGuidBytes(guid, b); @@ -6737,7 +6737,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLNVARCHAR: { - string s = (string)value; + string s = value.GenericCast(); length = s.Length * 2; WriteSqlVariantHeader(9 + length, metatype.TDSType, metatype.PropBytes, stateObj); @@ -6752,7 +6752,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLDATETIME: { - TdsDateTime dt = MetaType.FromDateTime((DateTime)value, 8); + TdsDateTime dt = MetaType.FromDateTime(value.GenericCast(), 8); WriteSqlVariantHeader(10, metatype.TDSType, metatype.PropBytes, stateObj); WriteInt(dt.days, stateObj); @@ -6763,7 +6763,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLMONEY: { WriteSqlVariantHeader(10, metatype.TDSType, metatype.PropBytes, stateObj); - WriteCurrency((decimal)value, 8, stateObj); + WriteCurrency(value.GenericCast(), 8, stateObj); break; } @@ -6771,21 +6771,22 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta { WriteSqlVariantHeader(21, metatype.TDSType, metatype.PropBytes, stateObj); stateObj.WriteByte(metatype.Precision); //propbytes: precision - stateObj.WriteByte((byte)((decimal.GetBits((decimal)value)[3] & 0x00ff0000) >> 0x10)); // propbytes: scale - WriteDecimal((decimal)value, stateObj); + var decValue = value.GenericCast(); + stateObj.WriteByte((byte)((decimal.GetBits(decValue)[3] & 0x00ff0000) >> 0x10)); // propbytes: scale + WriteDecimal(decValue, stateObj); break; } case TdsEnums.SQLTIME: WriteSqlVariantHeader(8, metatype.TDSType, metatype.PropBytes, stateObj); stateObj.WriteByte(metatype.Scale); //propbytes: scale - WriteTime((TimeSpan)value, metatype.Scale, 5, stateObj); + WriteTime(value.GenericCast(), metatype.Scale, 5, stateObj); break; case TdsEnums.SQLDATETIMEOFFSET: WriteSqlVariantHeader(13, metatype.TDSType, metatype.PropBytes, stateObj); stateObj.WriteByte(metatype.Scale); //propbytes: scale - WriteDateTimeOffset((DateTimeOffset)value, metatype.Scale, 10, stateObj); + WriteDateTimeOffset(value.GenericCast(), metatype.Scale, 10, stateObj); break; default: @@ -10397,7 +10398,7 @@ internal bool ShouldEncryptValuesForBulkCopy() /// Encrypts a column value (for SqlBulkCopy) /// /// - internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, string column, TdsParserStateObject stateObj, bool isDataFeed, bool isSqlType) + internal byte[] EncryptColumnValue(T value, SqlMetaDataPriv metadata, string column, TdsParserStateObject stateObj, bool isDataFeed, bool isSqlType) { Debug.Assert(IsColumnEncryptionSupported, "Server doesn't support encryption, yet we received encryption metadata"); Debug.Assert(ShouldEncryptValuesForBulkCopy(), "Encryption attempted when not requested"); @@ -10422,7 +10423,10 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin // when we normalize and serialize the data buffers. The serialization routine expects us // to report the size of data to be copied out (for serialization). If we underreport the // size, truncation will happen for us! - actualLengthInBytes = (isSqlType) ? ((SqlBinary)value).Length : ((byte[])value).Length; + actualLengthInBytes = (isSqlType) + ? value.GenericCast().Length + : value.GenericCast().Length; + if (metadata.baseTI.length > 0 && actualLengthInBytes > metadata.baseTI.length) { @@ -10442,7 +10446,10 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin ThrowUnsupportedCollationEncountered(null); // stateObject only when reading } - string stringValue = (isSqlType) ? ((SqlString)value).Value : (string)value; + string stringValue = (isSqlType) + ? value.GenericCast().Value + : value.GenericCast(); + actualLengthInBytes = _defaultEncoding.GetByteCount(stringValue); // If the string length is > max length, then use the max length (see comments above) @@ -10456,7 +10463,10 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin case TdsEnums.SQLNCHAR: case TdsEnums.SQLNVARCHAR: case TdsEnums.SQLNTEXT: - actualLengthInBytes = ((isSqlType) ? ((SqlString)value).Value.Length : ((string)value).Length) * 2; + actualLengthInBytes = (isSqlType + ? value.GenericCast().Value.Length + : value.GenericCast().Length) + * 2; if (metadata.baseTI.length > 0 && actualLengthInBytes > metadata.baseTI.length) @@ -10501,7 +10511,7 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin _connHandler.Connection); } - internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsParserStateObject stateObj, bool isSqlType, bool isDataFeed, bool isNull) + internal Task WriteBulkCopyValue(T value, SqlMetaDataPriv metadata, TdsParserStateObject stateObj, bool isSqlType, bool isDataFeed, bool isNull) { Debug.Assert(!isSqlType || value is INullable, "isSqlType is true, but value can not be type cast to an INullable"); Debug.Assert(!isDataFeed ^ value is DataFeed, "Incorrect value for isDataFeed"); @@ -10558,6 +10568,8 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars return resultTask; } + string stringValue = null; + if (!isDataFeed) { switch (metatype.NullableType) @@ -10566,7 +10578,9 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars case TdsEnums.SQLBIGVARBINARY: case TdsEnums.SQLIMAGE: case TdsEnums.SQLUDT: - ccb = (isSqlType) ? ((SqlBinary)value).Length : ((byte[])value).Length; + ccb = (isSqlType) + ? value.GenericCast().Length + : value.GenericCast().Length; break; case TdsEnums.SQLUNIQUEID: ccb = GUID_SIZE; @@ -10579,15 +10593,9 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars ThrowUnsupportedCollationEncountered(null); // stateObject only when reading } - string stringValue = null; - if (isSqlType) - { - stringValue = ((SqlString)value).Value; - } - else - { - stringValue = (string)value; - } + stringValue = isSqlType + ? value.GenericCast().Value + : value.GenericCast(); ccb = stringValue.Length; ccbStringBytes = _defaultEncoding.GetByteCount(stringValue); @@ -10595,15 +10603,29 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars case TdsEnums.SQLNCHAR: case TdsEnums.SQLNVARCHAR: case TdsEnums.SQLNTEXT: - ccb = ((isSqlType) ? ((SqlString)value).Value.Length : ((string)value).Length) * 2; + stringValue = stringValue = isSqlType + ? value.GenericCast().Value + : value.GenericCast(); + + ccb = stringValue.Length * 2; break; case TdsEnums.SQLXMLTYPE: // Value here could be string or XmlReader - if (value is XmlReader) + + if (value is XmlReader xr) { - value = MetaType.GetStringFromXml((XmlReader)value); + stringValue = MetaType.GetStringFromXml(xr); } - ccb = ((isSqlType) ? ((SqlString)value).Value.Length : ((string)value).Length) * 2; + else if (isSqlType) + { + stringValue = value.GenericCast().Value; + } + else + { + stringValue = value.GenericCast(); + } + + ccb = stringValue.Length * 2; break; default: @@ -10653,19 +10675,18 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars { internalWriteTask = WriteSqlValue(value, metatype, ccb, ccbStringBytes, 0, stateObj); } + else if (stringValue != null) + { + internalWriteTask = WriteValueWithWait(stringValue, metadata, stateObj, isDataFeed, metatype, ccb, ccbStringBytes); + } else if (metatype.SqlDbType != SqlDbType.Udt || metatype.IsLong) { - internalWriteTask = WriteValue(value, metatype, metadata.scale, ccb, ccbStringBytes, 0, stateObj, metadata.length, isDataFeed); - if ((internalWriteTask == null) && (_asyncWrite)) - { - internalWriteTask = stateObj.WaitForAccumulatedWrites(); - } - Debug.Assert(_asyncWrite || stateObj.WaitForAccumulatedWrites() == null, "Should not have accumulated writes when writing sync"); + internalWriteTask = WriteValueWithWait(value, metadata, stateObj, isDataFeed, metatype, ccb, ccbStringBytes); } else { WriteShort(ccb, stateObj); - internalWriteTask = stateObj.WriteByteArray((byte[])value, ccb, 0); + internalWriteTask = stateObj.WriteByteArray(value.GenericCast(), ccb, 0); } #if DEBUG @@ -10693,6 +10714,17 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars return resultTask; } + private Task WriteValueWithWait(T value, SqlMetaDataPriv metadata, TdsParserStateObject stateObj, bool isDataFeed, MetaType metatype, int ccb, int ccbStringBytes) + { + Task internalWriteTask = WriteValue(value, metatype, metadata.scale, ccb, ccbStringBytes, 0, stateObj, metadata.length, isDataFeed); + if ((internalWriteTask == null) && (_asyncWrite)) + { + internalWriteTask = stateObj.WaitForAccumulatedWrites(); + } + Debug.Assert(_asyncWrite || stateObj.WaitForAccumulatedWrites() == null, "Should not have accumulated writes when writing sync"); + return internalWriteTask; + } + // This is in its own method to avoid always allocating the lambda in WriteBulkCopyValue private Task WriteBulkCopyValueSetupContinuation(Task internalWriteTask, Encoding saveEncoding, SqlCollation saveCollation, int saveCodePage, int saveLCID) { @@ -10974,7 +11006,7 @@ private bool IsBOMNeeded(MetaType type, object value) return false; } - private Task GetTerminationTask(Task unterminatedWriteTask, object value, MetaType type, int actualLength, TdsParserStateObject stateObj, bool isDataFeed) + private Task GetTerminationTask(Task unterminatedWriteTask, MetaType type, int actualLength, TdsParserStateObject stateObj, bool isDataFeed) { if (type.IsPlp && ((actualLength > 0) || isDataFeed)) { @@ -10995,16 +11027,16 @@ private Task GetTerminationTask(Task unterminatedWriteTask, object value, MetaTy } - private Task WriteSqlValue(object value, MetaType type, int actualLength, int codePageByteSize, int offset, TdsParserStateObject stateObj) + private Task WriteSqlValue(T value, MetaType type, int actualLength, int codePageByteSize, int offset, TdsParserStateObject stateObj) { return GetTerminationTask( WriteUnterminatedSqlValue(value, type, actualLength, codePageByteSize, offset, stateObj), - value, type, actualLength, stateObj, false); + type, actualLength, stateObj, false); } // For MAX types, this method can only write everything in one big chunk. If multiple // chunk writes needed, please use WritePlpBytes/WritePlpChars - private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLength, int codePageByteSize, int offset, TdsParserStateObject stateObj) + private Task WriteUnterminatedSqlValue(T value, MetaType type, int actualLength, int codePageByteSize, int offset, TdsParserStateObject stateObj) { Debug.Assert(((type.NullableType == TdsEnums.SQLXMLTYPE) || (value is INullable && !((INullable)value).IsNull)), @@ -11015,11 +11047,11 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe { case TdsEnums.SQLFLTN: if (type.FixedLength == 4) - WriteFloat(((SqlSingle)value).Value, stateObj); + WriteFloat(value.GenericCast().Value, stateObj); else { Debug.Assert(type.FixedLength == 8, "Invalid length for SqlDouble type!"); - WriteDouble(((SqlDouble)value).Value, stateObj); + WriteDouble(value.GenericCast().Value, stateObj); } break; @@ -11035,12 +11067,12 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe if (value is SqlBinary) { - return stateObj.WriteByteArray(((SqlBinary)value).Value, actualLength, offset, canAccumulate: false); + return stateObj.WriteByteArray(value.GenericCast().Value, actualLength, offset, canAccumulate: false); } else { Debug.Assert(value is SqlBytes); - return stateObj.WriteByteArray(((SqlBytes)value).Value, actualLength, offset, canAccumulate: false); + return stateObj.WriteByteArray(value.GenericCast().Value, actualLength, offset, canAccumulate: false); } } @@ -11048,7 +11080,7 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe { Debug.Assert(actualLength == 16, "Invalid length for guid type in com+ object"); Span b = stackalloc byte[16]; - SqlGuid sqlGuid = (SqlGuid)value; + SqlGuid sqlGuid = value.GenericCast(); if (sqlGuid.IsNull) { @@ -11066,7 +11098,7 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe case TdsEnums.SQLBITN: { Debug.Assert(type.FixedLength == 1, "Invalid length for SqlBoolean type"); - if (((SqlBoolean)value).Value == true) + if (value.GenericCast().Value == true) stateObj.WriteByte(1); else stateObj.WriteByte(0); @@ -11076,17 +11108,17 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe case TdsEnums.SQLINTN: if (type.FixedLength == 1) - stateObj.WriteByte(((SqlByte)value).Value); + stateObj.WriteByte(value.GenericCast().Value); else if (type.FixedLength == 2) - WriteShort(((SqlInt16)value).Value, stateObj); + WriteShort(value.GenericCast().Value, stateObj); else if (type.FixedLength == 4) - WriteInt(((SqlInt32)value).Value, stateObj); + WriteInt(value.GenericCast().Value, stateObj); else { Debug.Assert(type.FixedLength == 8, "invalid length for SqlIntN type: " + type.FixedLength.ToString(CultureInfo.InvariantCulture)); - WriteLong(((SqlInt64)value).Value, stateObj); + WriteLong(value.GenericCast().Value, stateObj); } break; @@ -11100,14 +11132,14 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe } if (value is SqlChars) { - string sch = new string(((SqlChars)value).Value); + string sch = new string(value.GenericCast().Value); return WriteEncodingChar(sch, actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false); } else { Debug.Assert(value is SqlString); - return WriteEncodingChar(((SqlString)value).Value, actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false); + return WriteEncodingChar(value.GenericCast().Value, actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false); } @@ -11136,21 +11168,21 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe if (value is SqlChars) { - return WriteCharArray(((SqlChars)value).Value, actualLength, offset, stateObj, canAccumulate: false); + return WriteCharArray(value.GenericCast().Value, actualLength, offset, stateObj, canAccumulate: false); } else { Debug.Assert(value is SqlString); - return WriteString(((SqlString)value).Value, actualLength, offset, stateObj, canAccumulate: false); + return WriteString(value.GenericCast().Value, actualLength, offset, stateObj, canAccumulate: false); } case TdsEnums.SQLNUMERICN: Debug.Assert(type.FixedLength <= 17, "Decimal length cannot be greater than 17 bytes"); - WriteSqlDecimal((SqlDecimal)value, stateObj); + WriteSqlDecimal(value.GenericCast(), stateObj); break; case TdsEnums.SQLDATETIMN: - SqlDateTime dt = (SqlDateTime)value; + SqlDateTime dt = value.GenericCast(); if (type.FixedLength == 4) { @@ -11170,7 +11202,7 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe case TdsEnums.SQLMONEYN: { - WriteSqlMoney((SqlMoney)value, type.FixedLength, stateObj); + WriteSqlMoney(value.GenericCast(), type.FixedLength, stateObj); break; } @@ -11640,28 +11672,28 @@ private Task NullIfCompletedWriteTask(Task task) } } - private Task WriteValue(object value, MetaType type, byte scale, int actualLength, int encodingByteSize, int offset, TdsParserStateObject stateObj, int paramSize, bool isDataFeed) + private Task WriteValue(T value, MetaType type, byte scale, int actualLength, int encodingByteSize, int offset, TdsParserStateObject stateObj, int paramSize, bool isDataFeed) { return GetTerminationTask(WriteUnterminatedValue(value, type, scale, actualLength, encodingByteSize, offset, stateObj, paramSize, isDataFeed), - value, type, actualLength, stateObj, isDataFeed); + type, actualLength, stateObj, isDataFeed); } // For MAX types, this method can only write everything in one big chunk. If multiple // chunk writes needed, please use WritePlpBytes/WritePlpChars - private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int actualLength, int encodingByteSize, int offset, TdsParserStateObject stateObj, int paramSize, bool isDataFeed) + private Task WriteUnterminatedValue(T value, MetaType type, byte scale, int actualLength, int encodingByteSize, int offset, TdsParserStateObject stateObj, int paramSize, bool isDataFeed) { - Debug.Assert((null != value) && (DBNull.Value != value), "unexpected missing or empty object"); + Debug.Assert((null != value) && !(value is DBNull), "unexpected missing or empty object"); // parameters are always sent over as BIG or N types switch (type.NullableType) { case TdsEnums.SQLFLTN: if (type.FixedLength == 4) - WriteFloat((float)value, stateObj); + WriteFloat(value.GenericCast(), stateObj); else { Debug.Assert(type.FixedLength == 8, "Invalid length for SqlDouble type!"); - WriteDouble((double)value, stateObj); + WriteDouble(value.GenericCast(), stateObj); } break; @@ -11678,7 +11710,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int if (isDataFeed) { Debug.Assert(type.IsPlp, "Stream assigned to non-PLP was not converted!"); - return NullIfCompletedWriteTask(WriteStreamFeed((StreamDataFeed)value, stateObj, paramSize)); + return NullIfCompletedWriteTask(WriteStreamFeed(value.GenericCast(), stateObj, paramSize)); } else { @@ -11686,7 +11718,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int { WriteInt(actualLength, stateObj); // chunk length } - return stateObj.WriteByteArray((byte[])value, actualLength, offset, canAccumulate: false); + return stateObj.WriteByteArray(value.GenericCast(), actualLength, offset, canAccumulate: false); } } @@ -11694,7 +11726,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int { Debug.Assert(actualLength == 16, "Invalid length for guid type in com+ object"); Span b = stackalloc byte[16]; - FillGuidBytes((System.Guid)value, b); + FillGuidBytes(value.GenericCast(), b); stateObj.WriteByteSpan(b); break; } @@ -11702,7 +11734,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int case TdsEnums.SQLBITN: { Debug.Assert(type.FixedLength == 1, "Invalid length for SqlBoolean type"); - if ((bool)value == true) + if (value.GenericCast() == true) stateObj.WriteByte(1); else stateObj.WriteByte(0); @@ -11712,15 +11744,15 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int case TdsEnums.SQLINTN: if (type.FixedLength == 1) - stateObj.WriteByte((byte)value); + stateObj.WriteByte(value.GenericCast()); else if (type.FixedLength == 2) - WriteShort((short)value, stateObj); + WriteShort(value.GenericCast(), stateObj); else if (type.FixedLength == 4) - WriteInt((int)value, stateObj); + WriteInt(value.GenericCast(), stateObj); else { Debug.Assert(type.FixedLength == 8, "invalid length for SqlIntN type: " + type.FixedLength.ToString(CultureInfo.InvariantCulture)); - WriteLong((long)value, stateObj); + WriteLong(value.GenericCast(), stateObj); } break; @@ -11738,7 +11770,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int TextDataFeed tdf = value as TextDataFeed; if (tdf == null) { - return NullIfCompletedWriteTask(WriteXmlFeed((XmlDataFeed)value, stateObj, needBom: true, encoding: _defaultEncoding, size: paramSize)); + return NullIfCompletedWriteTask(WriteXmlFeed(value.GenericCast(), stateObj, needBom: true, encoding: _defaultEncoding, size: paramSize)); } else { @@ -11753,11 +11785,11 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int } if (value is byte[]) { // If LazyMat non-filled blob, send cookie rather than value - return stateObj.WriteByteArray((byte[])value, actualLength, 0, canAccumulate: false); + return stateObj.WriteByteArray(value.GenericCast(), actualLength, 0, canAccumulate: false); } else { - return WriteEncodingChar((string)value, actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false); + return WriteEncodingChar(value.GenericCast(), actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false); } } } @@ -11775,7 +11807,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int TextDataFeed tdf = value as TextDataFeed; if (tdf == null) { - return NullIfCompletedWriteTask(WriteXmlFeed((XmlDataFeed)value, stateObj, IsBOMNeeded(type, value), Encoding.Unicode, paramSize)); + return NullIfCompletedWriteTask(WriteXmlFeed(value.GenericCast(), stateObj, IsBOMNeeded(type, value), Encoding.Unicode, paramSize)); } else { @@ -11798,25 +11830,25 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int } if (value is byte[]) { // If LazyMat non-filled blob, send cookie rather than value - return stateObj.WriteByteArray((byte[])value, actualLength, 0, canAccumulate: false); + return stateObj.WriteByteArray(value.GenericCast(), actualLength, 0, canAccumulate: false); } else { // convert to cchars instead of cbytes actualLength >>= 1; - return WriteString((string)value, actualLength, offset, stateObj, canAccumulate: false); + return WriteString(value.GenericCast(), actualLength, offset, stateObj, canAccumulate: false); } } } case TdsEnums.SQLNUMERICN: Debug.Assert(type.FixedLength <= 17, "Decimal length cannot be greater than 17 bytes"); - WriteDecimal((decimal)value, stateObj); + WriteDecimal(value.GenericCast(), stateObj); break; case TdsEnums.SQLDATETIMN: Debug.Assert(type.FixedLength <= 0xff, "Invalid Fixed Length"); - TdsDateTime dt = MetaType.FromDateTime((DateTime)value, (byte)type.FixedLength); + TdsDateTime dt = MetaType.FromDateTime(value.GenericCast(), (byte)type.FixedLength); if (type.FixedLength == 4) { @@ -11836,13 +11868,13 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int case TdsEnums.SQLMONEYN: { - WriteCurrency((decimal)value, type.FixedLength, stateObj); + WriteCurrency(value.GenericCast(), type.FixedLength, stateObj); break; } case TdsEnums.SQLDATE: { - WriteDate((DateTime)value, stateObj); + WriteDate(value.GenericCast(), stateObj); break; } @@ -11851,7 +11883,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int { throw SQL.TimeScaleValueOutOfRange(scale); } - WriteTime((TimeSpan)value, scale, actualLength, stateObj); + WriteTime(value.GenericCast(), scale, actualLength, stateObj); break; case TdsEnums.SQLDATETIME2: @@ -11859,11 +11891,11 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int { throw SQL.TimeScaleValueOutOfRange(scale); } - WriteDateTime2((DateTime)value, scale, actualLength, stateObj); + WriteDateTime2(value.GenericCast(), scale, actualLength, stateObj); break; case TdsEnums.SQLDATETIMEOFFSET: - WriteDateTimeOffset((DateTimeOffset)value, scale, actualLength, stateObj); + WriteDateTimeOffset(value.GenericCast(), scale, actualLength, stateObj); break; default: @@ -12107,7 +12139,7 @@ private byte[] SerializeUnencryptedValue(object value, MetaType type, byte scale // For MAX types, this method can only write everything in one big chunk. If multiple // chunk writes needed, please use WritePlpBytes/WritePlpChars - private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int actualLength, int offset, byte normalizationVersion, TdsParserStateObject stateObj) + private byte[] SerializeUnencryptedSqlValue(T value, MetaType type, int actualLength, int offset, byte normalizationVersion, TdsParserStateObject stateObj) { Debug.Assert(((type.NullableType == TdsEnums.SQLXMLTYPE) || (value is INullable && !((INullable)value).IsNull)), @@ -12123,11 +12155,13 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act { case TdsEnums.SQLFLTN: if (type.FixedLength == 4) - return SerializeFloat(((SqlSingle)value).Value); + { + return SerializeFloat(value.GenericCast().Value); + } else { Debug.Assert(type.FixedLength == 8, "Invalid length for SqlDouble type!"); - return SerializeDouble(((SqlDouble)value).Value); + return SerializeDouble(value.GenericCast().Value); } case TdsEnums.SQLBIGBINARY: @@ -12138,19 +12172,19 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act if (value is SqlBinary) { - Buffer.BlockCopy(((SqlBinary)value).Value, offset, b, 0, actualLength); + Buffer.BlockCopy(value.GenericCast().Value, offset, b, 0, actualLength); } else { Debug.Assert(value is SqlBytes); - Buffer.BlockCopy(((SqlBytes)value).Value, offset, b, 0, actualLength); + Buffer.BlockCopy(value.GenericCast().Value, offset, b, 0, actualLength); } return b; } case TdsEnums.SQLUNIQUEID: { - byte[] b = ((SqlGuid)value).ToByteArray(); + byte[] b = value.GenericCast().ToByteArray(); Debug.Assert((actualLength == b.Length) && (actualLength == 16), "Invalid length for guid type in com+ object"); return b; @@ -12161,23 +12195,23 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act Debug.Assert(type.FixedLength == 1, "Invalid length for SqlBoolean type"); // We normalize to allow conversion across data types. BIT is serialized into a BIGINT. - return SerializeLong(((SqlBoolean)value).Value == true ? 1 : 0, stateObj); + return SerializeLong(value.GenericCast().Value == true ? 1 : 0, stateObj); } case TdsEnums.SQLINTN: // We normalize to allow conversion across data types. All data types below are serialized into a BIGINT. if (type.FixedLength == 1) - return SerializeLong(((SqlByte)value).Value, stateObj); + return SerializeLong(value.GenericCast().Value, stateObj); if (type.FixedLength == 2) - return SerializeLong(((SqlInt16)value).Value, stateObj); + return SerializeLong(value.GenericCast().Value, stateObj); if (type.FixedLength == 4) - return SerializeLong(((SqlInt32)value).Value, stateObj); + return SerializeLong(value.GenericCast().Value, stateObj); else { Debug.Assert(type.FixedLength == 8, "invalid length for SqlIntN type: " + type.FixedLength.ToString(CultureInfo.InvariantCulture)); - return SerializeLong(((SqlInt64)value).Value, stateObj); + return SerializeLong(value.GenericCast().Value, stateObj); } case TdsEnums.SQLBIGCHAR: @@ -12185,13 +12219,13 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act case TdsEnums.SQLTEXT: if (value is SqlChars) { - String sch = new String(((SqlChars)value).Value); + String sch = new String(value.GenericCast().Value); return SerializeEncodingChar(sch, actualLength, offset, _defaultEncoding); } else { Debug.Assert(value is SqlString); - return SerializeEncodingChar(((SqlString)value).Value, actualLength, offset, _defaultEncoding); + return SerializeEncodingChar(value.GenericCast().Value, actualLength, offset, _defaultEncoding); } @@ -12206,20 +12240,20 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act if (value is SqlChars) { - return SerializeCharArray(((SqlChars)value).Value, actualLength, offset); + return SerializeCharArray(value.GenericCast().Value, actualLength, offset); } else { Debug.Assert(value is SqlString); - return SerializeString(((SqlString)value).Value, actualLength, offset); + return SerializeString(value.GenericCast().Value, actualLength, offset); } case TdsEnums.SQLNUMERICN: Debug.Assert(type.FixedLength <= 17, "Decimal length cannot be greater than 17 bytes"); - return SerializeSqlDecimal((SqlDecimal)value, stateObj); + return SerializeSqlDecimal(value.GenericCast(), stateObj); case TdsEnums.SQLDATETIMN: - SqlDateTime dt = (SqlDateTime)value; + SqlDateTime dt = value.GenericCast(); if (type.FixedLength == 4) { @@ -12265,7 +12299,7 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act case TdsEnums.SQLMONEYN: { - return SerializeSqlMoney((SqlMoney)value, type.FixedLength, stateObj); + return SerializeSqlMoney(value.GenericCast(), type.FixedLength, stateObj); } default: diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/DataConversionErrorMessageTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/DataConversionErrorMessageTest.cs index 4c3d594ad1..caf5dea633 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/DataConversionErrorMessageTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/DataConversionErrorMessageTest.cs @@ -162,7 +162,7 @@ private bool StringToIntTest(SqlConnection cnn, string targetTable, SourceType s string expectedErrorMsg = string.Format(pattern, args); - Assert.True(ex.Message.Contains(expectedErrorMsg), "Unexpected error message: " + ex.Message); + Assert.True(ex.Message.Contains(expectedErrorMsg), $"Unexpected error message: {ex}"); hitException = true; } return hitException;