From b6b237717f51713fb08b2f35c0d2688f1e63c74d Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Fri, 27 Sep 2024 15:18:48 -0700 Subject: [PATCH] feat(csharp/src/Client): Additional parameter support for DbCommand (#2195) Implements support for mapping DbType.Time and DbType.Decimal. Uses System.Convert to support a larger number of source types. --- csharp/src/Client/AdbcCommand.cs | 207 +++++++++++++----- csharp/src/Client/AdbcParameter.cs | 8 +- .../Drivers/Interop/Snowflake/ClientTests.cs | 19 +- 3 files changed, 172 insertions(+), 62 deletions(-) diff --git a/csharp/src/Client/AdbcCommand.cs b/csharp/src/Client/AdbcCommand.cs index f76c246ccc..2dac9a95d2 100644 --- a/csharp/src/Client/AdbcCommand.cs +++ b/csharp/src/Client/AdbcCommand.cs @@ -20,6 +20,7 @@ using System.Collections.Generic; using System.Data; using System.Data.Common; +using System.Data.SqlTypes; using System.Linq; using System.Threading.Tasks; using Apache.Arrow.Types; @@ -231,111 +232,195 @@ private void BindParameters() for (int i = 0; i < fields.Length; i++) { AdbcParameter param = (AdbcParameter)_dbParameterCollection[i]; - ArrowType type; switch (param.DbType) { case DbType.Binary: - type = BinaryType.Default; var binaryBuilder = new BinaryArray.Builder(); - if (param.Value == null) + switch (param.Value) { - binaryBuilder.AppendNull(); - } - else - { - binaryBuilder.Append(((byte[])param.Value).AsSpan()); + case null: binaryBuilder.AppendNull(); break; + case byte[] array: binaryBuilder.Append(array.AsSpan()); break; + default: throw new NotSupportedException($"Values of type {param.Value.GetType().Name} cannot be bound as binary"); } parameters[i] = binaryBuilder.Build(); break; case DbType.Boolean: - type = BooleanType.Default; var boolBuilder = new BooleanArray.Builder(); - if (param.Value == null) + switch (param.Value) { - boolBuilder.AppendNull(); - } - else - { - boolBuilder.Append((bool)param.Value); + case null: boolBuilder.AppendNull(); break; + case bool boolValue: boolBuilder.Append(boolValue); break; + default: boolBuilder.Append(ConvertValue(param.Value, Convert.ToBoolean, DbType.Boolean)); break; } parameters[i] = boolBuilder.Build(); break; case DbType.Byte: - type = UInt8Type.Default; - parameters[i] = new UInt8Array.Builder().Append((byte?)param.Value).Build(); + var uint8Builder = new UInt8Array.Builder(); + switch (param.Value) + { + case null: uint8Builder.AppendNull(); break; + case byte byteValue: uint8Builder.Append(byteValue); break; + default: uint8Builder.Append(ConvertValue(param.Value, Convert.ToByte, DbType.Byte)); break; + } + parameters[i] = uint8Builder.Build(); break; case DbType.Date: - type = Date32Type.Default; var dateBuilder = new Date32Array.Builder(); - if (param.Value == null) + switch (param.Value) { - dateBuilder.AppendNull(); - } + case null: dateBuilder.AppendNull(); break; + case DateTime datetime: dateBuilder.Append(datetime); break; #if NET5_0_OR_GREATER - else if (param.Value is DateOnly) - { - dateBuilder.Append((DateOnly)param.Value); - } + case DateOnly dateonly: dateBuilder.Append(dateonly); break; #endif - else - { - dateBuilder.Append((DateTime)param.Value); + default: dateBuilder.Append(ConvertValue(param.Value, Convert.ToDateTime, DbType.Date)); break; } parameters[i] = dateBuilder.Build(); break; case DbType.DateTime: - type = TimestampType.Default; var timestampBuilder = new TimestampArray.Builder(); - if (param.Value == null) + switch (param.Value) { - timestampBuilder.AppendNull(); + case null: timestampBuilder.AppendNull(); break; + case DateTime datetime: timestampBuilder.Append(datetime); break; + default: timestampBuilder.Append(ConvertValue(param.Value, Convert.ToDateTime, DbType.DateTime)); break; + } + parameters[i] = timestampBuilder.Build(); + break; + case DbType.Decimal: + var value = param.Value switch + { + null => (SqlDecimal?)null, + SqlDecimal sqlDecimal => sqlDecimal, + decimal d => new SqlDecimal(d), + _ => new SqlDecimal(ConvertValue(param.Value, Convert.ToDecimal, DbType.Decimal)), + }; + var decimalBuilder = new Decimal128Array.Builder(new Decimal128Type(value?.Precision ?? 10, value?.Scale ?? 0)); + if (value is null) + { + decimalBuilder.AppendNull(); } else { - timestampBuilder.Append((DateTime)param.Value); + decimalBuilder.Append(value.Value); } + parameters[i] = decimalBuilder.Build(); break; - // TODO: case DbType.Decimal: case DbType.Double: - type = DoubleType.Default; - parameters[i] = new DoubleArray.Builder().Append((double?)param.Value).Build(); + var doubleBuilder = new DoubleArray.Builder(); + switch (param.Value) + { + case null: doubleBuilder.AppendNull(); break; + case double dbl: doubleBuilder.Append(dbl); break; + default: doubleBuilder.Append(ConvertValue(param.Value, Convert.ToDouble, DbType.Double)); break; + } + parameters[i] = doubleBuilder.Build(); break; case DbType.Int16: - type = Int16Type.Default; - parameters[i] = new Int16Array.Builder().Append((short?)param.Value).Build(); + var int16Builder = new Int16Array.Builder(); + switch (param.Value) + { + case null: int16Builder.AppendNull(); break; + case short shortValue: int16Builder.Append(shortValue); break; + default: int16Builder.Append(ConvertValue(param.Value, Convert.ToInt16, DbType.Int16)); break; + } + parameters[i] = int16Builder.Build(); break; case DbType.Int32: - type = Int32Type.Default; - parameters[i] = new Int32Array.Builder().Append((int?)param.Value).Build(); + var int32Builder = new Int32Array.Builder(); + switch (param.Value) + { + case null: int32Builder.AppendNull(); break; + case int intValue: int32Builder.Append(intValue); break; + default: int32Builder.Append(ConvertValue(param.Value, Convert.ToInt32, DbType.Int32)); break; + } + parameters[i] = int32Builder.Build(); break; case DbType.Int64: - type = Int64Type.Default; - parameters[i] = new Int64Array.Builder().Append((long?)param.Value).Build(); + var int64Builder = new Int64Array.Builder(); + switch (param.Value) + { + case null: int64Builder.AppendNull(); break; + case long longValue: int64Builder.Append(longValue); break; + default: int64Builder.Append(ConvertValue(param.Value, Convert.ToInt64, DbType.Int64)); break; + } + parameters[i] = int64Builder.Build(); break; case DbType.SByte: - type = Int8Type.Default; - parameters[i] = new Int8Array.Builder().Append((sbyte?)param.Value).Build(); + var int8Builder = new Int8Array.Builder(); + switch (param.Value) + { + case null: int8Builder.AppendNull(); break; + case sbyte sbyteValue: int8Builder.Append(sbyteValue); break; + default: int8Builder.Append(ConvertValue(param.Value, Convert.ToSByte, DbType.SByte)); break; + } + parameters[i] = int8Builder.Build(); break; case DbType.Single: - type = FloatType.Default; - parameters[i] = new FloatArray.Builder().Append((float?)param.Value).Build(); + var floatBuilder = new FloatArray.Builder(); + switch (param.Value) + { + case null: floatBuilder.AppendNull(); break; + case float floatValue: floatBuilder.Append(floatValue); break; + default: floatBuilder.Append(ConvertValue(param.Value, Convert.ToSingle, DbType.Single)); break; + } + parameters[i] = floatBuilder.Build(); break; case DbType.String: - type = StringType.Default; - parameters[i] = new StringArray.Builder().Append((string)param.Value!).Build(); + var stringBuilder = new StringArray.Builder(); + switch (param.Value) + { + case null: stringBuilder.AppendNull(); break; + case string stringValue: stringBuilder.Append(stringValue); break; + default: stringBuilder.Append(ConvertValue(param.Value, Convert.ToString, DbType.String)); break; + } + parameters[i] = stringBuilder.Build(); + break; + case DbType.Time: + var timeBuilder = new Time32Array.Builder(); + switch (param.Value) + { + case null: timeBuilder.AppendNull(); break; + case DateTime datetime: timeBuilder.Append((int)(datetime.TimeOfDay.Ticks / TimeSpan.TicksPerMillisecond)); break; +#if NET5_0_OR_GREATER + case TimeOnly timeonly: timeBuilder.Append(timeonly); break; +#endif + default: + DateTime convertedDateTime = ConvertValue(param.Value, Convert.ToDateTime, DbType.Time); + timeBuilder.Append((int)(convertedDateTime.TimeOfDay.Ticks / TimeSpan.TicksPerMillisecond)); + break; + } + parameters[i] = timeBuilder.Build(); break; - // TODO: case DbType.Time: case DbType.UInt16: - type = UInt16Type.Default; - parameters[i] = new UInt16Array.Builder().Append((ushort?)param.Value).Build(); + var uint16Builder = new UInt16Array.Builder(); + switch (param.Value) + { + case null: uint16Builder.AppendNull(); break; + case ushort ushortValue: uint16Builder.Append(ushortValue); break; + default: uint16Builder.Append(ConvertValue(param.Value, Convert.ToUInt16, DbType.UInt16)); break; + } + parameters[i] = uint16Builder.Build(); break; case DbType.UInt32: - type = UInt32Type.Default; - parameters[i] = new UInt32Array.Builder().Append((uint?)param.Value).Build(); + var uint32Builder = new UInt32Array.Builder(); + switch (param.Value) + { + case null: uint32Builder.AppendNull(); break; + case uint uintValue: uint32Builder.Append(uintValue); break; + default: uint32Builder.Append(ConvertValue(param.Value, Convert.ToUInt32, DbType.UInt32)); break; + } + parameters[i] = uint32Builder.Build(); break; case DbType.UInt64: - type = UInt64Type.Default; - parameters[i] = new UInt64Array.Builder().Append((ulong?)param.Value).Build(); + var uint64Builder = new UInt64Array.Builder(); + switch (param.Value) + { + case null: uint64Builder.AppendNull(); break; + case ulong ulongValue: uint64Builder.Append(ulongValue); break; + default: uint64Builder.Append(ConvertValue(param.Value, Convert.ToUInt64, DbType.UInt64)); break; + } + parameters[i] = uint64Builder.Build(); break; default: throw new NotSupportedException($"Parameters of type {param.DbType} are not supported"); @@ -343,7 +428,7 @@ private void BindParameters() fields[i] = new Field( string.IsNullOrWhiteSpace(param.ParameterName) ? Guid.NewGuid().ToString() : param.ParameterName, - type, + parameters[i].Data.DataType, param.IsNullable || param.Value == null); } @@ -352,6 +437,18 @@ private void BindParameters() } } + private static T ConvertValue(object value, Func converter, DbType type) + { + try + { + return converter(value); + } + catch (Exception) + { + throw new NotSupportedException($"Values of type {value.GetType().Name} cannot be bound as {type}."); + } + } + #if NET5_0_OR_GREATER public override ValueTask DisposeAsync() { diff --git a/csharp/src/Client/AdbcParameter.cs b/csharp/src/Client/AdbcParameter.cs index 620b921c56..c816b1a0b8 100644 --- a/csharp/src/Client/AdbcParameter.cs +++ b/csharp/src/Client/AdbcParameter.cs @@ -25,13 +25,17 @@ namespace Apache.Arrow.Adbc.Client sealed public class AdbcParameter : DbParameter { public override DbType DbType { get; set; } - public override ParameterDirection Direction { get => ParameterDirection.Input; set => throw new NotImplementedException(); } + public override ParameterDirection Direction + { + get => ParameterDirection.Input; + set { if (value != ParameterDirection.Input) { throw new NotSupportedException(); } } + } public override bool IsNullable { get; set; } = true; #if NET5_0_OR_GREATER [AllowNull] #endif public override string ParameterName { get; set; } = string.Empty; - public override int Size { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } + public override int Size { get; set; } #if NET5_0_OR_GREATER [AllowNull] #endif diff --git a/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs b/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs index 942a78fe75..b5f0457401 100644 --- a/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs +++ b/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs @@ -149,17 +149,26 @@ public void CanClientExecuteQueryWithNoResults() public void CanClientExecuteParameterizedQuery() { SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); - testConfiguration.Query = "SELECT * FROM (SELECT column1 FROM (VALUES (1), (2), (3))) WHERE column1 < ?"; + testConfiguration.Query = "SELECT ? as A, ? as B, ? as C, * FROM (SELECT column1 FROM (VALUES (1), (2), (3))) WHERE column1 < ?"; testConfiguration.ExpectedResultsCount = 1; using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnectionUsingConnectionString(testConfiguration)) { Tests.ClientTests.CanClientExecuteQuery(adbcConnection, testConfiguration, command => { - DbParameter parameter1 = command.CreateParameter(); - parameter1.Value = 2; - parameter1.DbType = DbType.Int32; - command.Parameters.Add(parameter1); + DbParameter CreateParameter(DbType dbType, object value) + { + DbParameter result = command.CreateParameter(); + result.DbType = dbType; + result.Value = value; + return result; + } + + // TODO: Add tests for decimal and time once supported by the driver or gosnowflake + command.Parameters.Add(CreateParameter(DbType.Int32, 2)); + command.Parameters.Add(CreateParameter(DbType.String, "text")); + command.Parameters.Add(CreateParameter(DbType.Double, 2.5)); + command.Parameters.Add(CreateParameter(DbType.Int32, 2)); }); } }