Skip to content

Commit

Permalink
feat(csharp/src/Client): Additional parameter support for DbCommand (a…
Browse files Browse the repository at this point in the history
…pache#2195)

Implements support for mapping DbType.Time and DbType.Decimal. Uses
System.Convert to support a larger number of source types.
  • Loading branch information
CurtHagenlocher committed Sep 27, 2024
1 parent c391168 commit b6b2377
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 62 deletions.
207 changes: 152 additions & 55 deletions csharp/src/Client/AdbcCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Data.SqlTypes;
using System.Linq;
using System.Threading.Tasks;
using Apache.Arrow.Types;
Expand Down Expand Up @@ -231,119 +232,203 @@ private void BindParameters()
for (int i = 0; i < fields.Length; i++)
{
AdbcParameter param = (AdbcParameter)_dbParameterCollection[i];
ArrowType type;
switch (param.DbType)
{
case DbType.Binary:
type = BinaryType.Default;
var binaryBuilder = new BinaryArray.Builder();
if (param.Value == null)
switch (param.Value)
{
binaryBuilder.AppendNull();
}
else
{
binaryBuilder.Append(((byte[])param.Value).AsSpan());
case null: binaryBuilder.AppendNull(); break;
case byte[] array: binaryBuilder.Append(array.AsSpan()); break;
default: throw new NotSupportedException($"Values of type {param.Value.GetType().Name} cannot be bound as binary");
}
parameters[i] = binaryBuilder.Build();
break;
case DbType.Boolean:
type = BooleanType.Default;
var boolBuilder = new BooleanArray.Builder();
if (param.Value == null)
switch (param.Value)
{
boolBuilder.AppendNull();
}
else
{
boolBuilder.Append((bool)param.Value);
case null: boolBuilder.AppendNull(); break;
case bool boolValue: boolBuilder.Append(boolValue); break;
default: boolBuilder.Append(ConvertValue(param.Value, Convert.ToBoolean, DbType.Boolean)); break;
}
parameters[i] = boolBuilder.Build();
break;
case DbType.Byte:
type = UInt8Type.Default;
parameters[i] = new UInt8Array.Builder().Append((byte?)param.Value).Build();
var uint8Builder = new UInt8Array.Builder();
switch (param.Value)
{
case null: uint8Builder.AppendNull(); break;
case byte byteValue: uint8Builder.Append(byteValue); break;
default: uint8Builder.Append(ConvertValue(param.Value, Convert.ToByte, DbType.Byte)); break;
}
parameters[i] = uint8Builder.Build();
break;
case DbType.Date:
type = Date32Type.Default;
var dateBuilder = new Date32Array.Builder();
if (param.Value == null)
switch (param.Value)
{
dateBuilder.AppendNull();
}
case null: dateBuilder.AppendNull(); break;
case DateTime datetime: dateBuilder.Append(datetime); break;
#if NET5_0_OR_GREATER
else if (param.Value is DateOnly)
{
dateBuilder.Append((DateOnly)param.Value);
}
case DateOnly dateonly: dateBuilder.Append(dateonly); break;
#endif
else
{
dateBuilder.Append((DateTime)param.Value);
default: dateBuilder.Append(ConvertValue(param.Value, Convert.ToDateTime, DbType.Date)); break;
}
parameters[i] = dateBuilder.Build();
break;
case DbType.DateTime:
type = TimestampType.Default;
var timestampBuilder = new TimestampArray.Builder();
if (param.Value == null)
switch (param.Value)
{
timestampBuilder.AppendNull();
case null: timestampBuilder.AppendNull(); break;
case DateTime datetime: timestampBuilder.Append(datetime); break;
default: timestampBuilder.Append(ConvertValue(param.Value, Convert.ToDateTime, DbType.DateTime)); break;
}
parameters[i] = timestampBuilder.Build();
break;
case DbType.Decimal:
var value = param.Value switch
{
null => (SqlDecimal?)null,
SqlDecimal sqlDecimal => sqlDecimal,
decimal d => new SqlDecimal(d),
_ => new SqlDecimal(ConvertValue(param.Value, Convert.ToDecimal, DbType.Decimal)),
};
var decimalBuilder = new Decimal128Array.Builder(new Decimal128Type(value?.Precision ?? 10, value?.Scale ?? 0));
if (value is null)
{
decimalBuilder.AppendNull();
}
else
{
timestampBuilder.Append((DateTime)param.Value);
decimalBuilder.Append(value.Value);
}
parameters[i] = decimalBuilder.Build();
break;
// TODO: case DbType.Decimal:
case DbType.Double:
type = DoubleType.Default;
parameters[i] = new DoubleArray.Builder().Append((double?)param.Value).Build();
var doubleBuilder = new DoubleArray.Builder();
switch (param.Value)
{
case null: doubleBuilder.AppendNull(); break;
case double dbl: doubleBuilder.Append(dbl); break;
default: doubleBuilder.Append(ConvertValue(param.Value, Convert.ToDouble, DbType.Double)); break;
}
parameters[i] = doubleBuilder.Build();
break;
case DbType.Int16:
type = Int16Type.Default;
parameters[i] = new Int16Array.Builder().Append((short?)param.Value).Build();
var int16Builder = new Int16Array.Builder();
switch (param.Value)
{
case null: int16Builder.AppendNull(); break;
case short shortValue: int16Builder.Append(shortValue); break;
default: int16Builder.Append(ConvertValue(param.Value, Convert.ToInt16, DbType.Int16)); break;
}
parameters[i] = int16Builder.Build();
break;
case DbType.Int32:
type = Int32Type.Default;
parameters[i] = new Int32Array.Builder().Append((int?)param.Value).Build();
var int32Builder = new Int32Array.Builder();
switch (param.Value)
{
case null: int32Builder.AppendNull(); break;
case int intValue: int32Builder.Append(intValue); break;
default: int32Builder.Append(ConvertValue(param.Value, Convert.ToInt32, DbType.Int32)); break;
}
parameters[i] = int32Builder.Build();
break;
case DbType.Int64:
type = Int64Type.Default;
parameters[i] = new Int64Array.Builder().Append((long?)param.Value).Build();
var int64Builder = new Int64Array.Builder();
switch (param.Value)
{
case null: int64Builder.AppendNull(); break;
case long longValue: int64Builder.Append(longValue); break;
default: int64Builder.Append(ConvertValue(param.Value, Convert.ToInt64, DbType.Int64)); break;
}
parameters[i] = int64Builder.Build();
break;
case DbType.SByte:
type = Int8Type.Default;
parameters[i] = new Int8Array.Builder().Append((sbyte?)param.Value).Build();
var int8Builder = new Int8Array.Builder();
switch (param.Value)
{
case null: int8Builder.AppendNull(); break;
case sbyte sbyteValue: int8Builder.Append(sbyteValue); break;
default: int8Builder.Append(ConvertValue(param.Value, Convert.ToSByte, DbType.SByte)); break;
}
parameters[i] = int8Builder.Build();
break;
case DbType.Single:
type = FloatType.Default;
parameters[i] = new FloatArray.Builder().Append((float?)param.Value).Build();
var floatBuilder = new FloatArray.Builder();
switch (param.Value)
{
case null: floatBuilder.AppendNull(); break;
case float floatValue: floatBuilder.Append(floatValue); break;
default: floatBuilder.Append(ConvertValue(param.Value, Convert.ToSingle, DbType.Single)); break;
}
parameters[i] = floatBuilder.Build();
break;
case DbType.String:
type = StringType.Default;
parameters[i] = new StringArray.Builder().Append((string)param.Value!).Build();
var stringBuilder = new StringArray.Builder();
switch (param.Value)
{
case null: stringBuilder.AppendNull(); break;
case string stringValue: stringBuilder.Append(stringValue); break;
default: stringBuilder.Append(ConvertValue(param.Value, Convert.ToString, DbType.String)); break;
}
parameters[i] = stringBuilder.Build();
break;
case DbType.Time:
var timeBuilder = new Time32Array.Builder();
switch (param.Value)
{
case null: timeBuilder.AppendNull(); break;
case DateTime datetime: timeBuilder.Append((int)(datetime.TimeOfDay.Ticks / TimeSpan.TicksPerMillisecond)); break;
#if NET5_0_OR_GREATER
case TimeOnly timeonly: timeBuilder.Append(timeonly); break;
#endif
default:
DateTime convertedDateTime = ConvertValue(param.Value, Convert.ToDateTime, DbType.Time);
timeBuilder.Append((int)(convertedDateTime.TimeOfDay.Ticks / TimeSpan.TicksPerMillisecond));
break;
}
parameters[i] = timeBuilder.Build();
break;
// TODO: case DbType.Time:
case DbType.UInt16:
type = UInt16Type.Default;
parameters[i] = new UInt16Array.Builder().Append((ushort?)param.Value).Build();
var uint16Builder = new UInt16Array.Builder();
switch (param.Value)
{
case null: uint16Builder.AppendNull(); break;
case ushort ushortValue: uint16Builder.Append(ushortValue); break;
default: uint16Builder.Append(ConvertValue(param.Value, Convert.ToUInt16, DbType.UInt16)); break;
}
parameters[i] = uint16Builder.Build();
break;
case DbType.UInt32:
type = UInt32Type.Default;
parameters[i] = new UInt32Array.Builder().Append((uint?)param.Value).Build();
var uint32Builder = new UInt32Array.Builder();
switch (param.Value)
{
case null: uint32Builder.AppendNull(); break;
case uint uintValue: uint32Builder.Append(uintValue); break;
default: uint32Builder.Append(ConvertValue(param.Value, Convert.ToUInt32, DbType.UInt32)); break;
}
parameters[i] = uint32Builder.Build();
break;
case DbType.UInt64:
type = UInt64Type.Default;
parameters[i] = new UInt64Array.Builder().Append((ulong?)param.Value).Build();
var uint64Builder = new UInt64Array.Builder();
switch (param.Value)
{
case null: uint64Builder.AppendNull(); break;
case ulong ulongValue: uint64Builder.Append(ulongValue); break;
default: uint64Builder.Append(ConvertValue(param.Value, Convert.ToUInt64, DbType.UInt64)); break;
}
parameters[i] = uint64Builder.Build();
break;
default:
throw new NotSupportedException($"Parameters of type {param.DbType} are not supported");
}

fields[i] = new Field(
string.IsNullOrWhiteSpace(param.ParameterName) ? Guid.NewGuid().ToString() : param.ParameterName,
type,
parameters[i].Data.DataType,
param.IsNullable || param.Value == null);
}

Expand All @@ -352,6 +437,18 @@ private void BindParameters()
}
}

private static T ConvertValue<T>(object value, Func<object, T> converter, DbType type)
{
try
{
return converter(value);
}
catch (Exception)
{
throw new NotSupportedException($"Values of type {value.GetType().Name} cannot be bound as {type}.");
}
}

#if NET5_0_OR_GREATER
public override ValueTask DisposeAsync()
{
Expand Down
8 changes: 6 additions & 2 deletions csharp/src/Client/AdbcParameter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@ namespace Apache.Arrow.Adbc.Client
sealed public class AdbcParameter : DbParameter
{
public override DbType DbType { get; set; }
public override ParameterDirection Direction { get => ParameterDirection.Input; set => throw new NotImplementedException(); }
public override ParameterDirection Direction
{
get => ParameterDirection.Input;
set { if (value != ParameterDirection.Input) { throw new NotSupportedException(); } }
}
public override bool IsNullable { get; set; } = true;
#if NET5_0_OR_GREATER
[AllowNull]
#endif
public override string ParameterName { get; set; } = string.Empty;
public override int Size { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
public override int Size { get; set; }
#if NET5_0_OR_GREATER
[AllowNull]
#endif
Expand Down
19 changes: 14 additions & 5 deletions csharp/test/Drivers/Interop/Snowflake/ClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,17 +149,26 @@ public void CanClientExecuteQueryWithNoResults()
public void CanClientExecuteParameterizedQuery()
{
SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration<SnowflakeTestConfiguration>(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE);
testConfiguration.Query = "SELECT * FROM (SELECT column1 FROM (VALUES (1), (2), (3))) WHERE column1 < ?";
testConfiguration.Query = "SELECT ? as A, ? as B, ? as C, * FROM (SELECT column1 FROM (VALUES (1), (2), (3))) WHERE column1 < ?";
testConfiguration.ExpectedResultsCount = 1;

using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnectionUsingConnectionString(testConfiguration))
{
Tests.ClientTests.CanClientExecuteQuery(adbcConnection, testConfiguration, command =>
{
DbParameter parameter1 = command.CreateParameter();
parameter1.Value = 2;
parameter1.DbType = DbType.Int32;
command.Parameters.Add(parameter1);
DbParameter CreateParameter(DbType dbType, object value)
{
DbParameter result = command.CreateParameter();
result.DbType = dbType;
result.Value = value;
return result;
}
// TODO: Add tests for decimal and time once supported by the driver or gosnowflake
command.Parameters.Add(CreateParameter(DbType.Int32, 2));
command.Parameters.Add(CreateParameter(DbType.String, "text"));
command.Parameters.Add(CreateParameter(DbType.Double, 2.5));
command.Parameters.Add(CreateParameter(DbType.Int32, 2));
});
}
}
Expand Down

0 comments on commit b6b2377

Please sign in to comment.