Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(go/adbc/driver/snowflake): fix XDBC support when using high precision #1311

Merged
merged 5 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion csharp/test/Drivers/Snowflake/DriverTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,21 @@ public void CanGetObjectsAll()
.FirstOrDefault();

Assert.True(columns != null, "Columns cannot be null");
Assert.Equal(testConfiguration.Metadata.ExpectedColumnCount, columns.Count);

Assert.Equal(_testConfiguration.Metadata.ExpectedColumnCount, columns.Count);
if (testConfiguration.UseHighPrecision)
{
IEnumerable<AdbcColumn> highPrecisionColumns = columns.Where(c => c.XdbcTypeName == "NUMBER");

if(highPrecisionColumns.Count() > 0)
{
// ensure they all are coming back as XdbcDataType_XDBC_DECIMAL because they are Decimal128
short XdbcDataType_XDBC_DECIMAL = 3;
IEnumerable<AdbcColumn> invalidHighPrecisionColumns = highPrecisionColumns.Where(c => c.XdbcSqlDataType != XdbcDataType_XDBC_DECIMAL);
int count = invalidHighPrecisionColumns.Count();
Assert.True(count == 0, $"There are {count} columns that do not map to the correct XdbcSqlDataType when UseHighPrecision=true");
}
}
}

/// <summary>
Expand Down
10 changes: 1 addition & 9 deletions csharp/test/Drivers/Snowflake/ValueTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
using System;
using System.Collections.Generic;
using System.Data.SqlTypes;
using Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake;
using Apache.Arrow.Ipc;
using Apache.Arrow.Types;
using Xunit;

namespace Apache.Arrow.Adbc.Tests
namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
{
// TODO: When supported, use prepared statements instead of SQL string literals
// Which will better test how the driver handles values sent/received
Expand Down Expand Up @@ -239,13 +238,6 @@ private static string ConvertDoubleToString(double value)
return "'-inf'";
case double.NaN:
return "'NaN'";
#if NET472
// Standard Double.ToString() calls round up the max value, resulting in Snowflake storing infinity
case double.MaxValue:
return "1.7976931348623157E+308";
case double.MinValue:
return "-1.7976931348623157E+308";
#endif
default:
return value.ToString();
}
Expand Down
17 changes: 12 additions & 5 deletions go/adbc/driver/snowflake/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,15 +323,22 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth,

var loc = time.Now().Location()

func toField(name string, isnullable bool, dataType string, numPrec, numPrecRadix, numScale sql.NullInt16, isIdent bool, identGen, identInc sql.NullString, charMaxLength, charOctetLength sql.NullInt32, datetimePrec sql.NullInt16, comment sql.NullString, ordinalPos int) (ret arrow.Field) {
func toField(name string, isnullable bool, dataType string, numPrec, numPrecRadix, numScale sql.NullInt16, isIdent, useHighPrecision bool, identGen, identInc sql.NullString, charMaxLength, charOctetLength sql.NullInt32, datetimePrec sql.NullInt16, comment sql.NullString, ordinalPos int) (ret arrow.Field) {
ret.Name, ret.Nullable = name, isnullable

switch dataType {
case "NUMBER":
if !numScale.Valid || numScale.Int16 == 0 {
ret.Type = arrow.PrimitiveTypes.Int64
if useHighPrecision {
ret.Type = &arrow.Decimal128Type{
Precision: int32(numPrec.Int16),
Scale: int32(numScale.Int16),
}
} else {
ret.Type = arrow.PrimitiveTypes.Float64
if !numScale.Valid || numScale.Int16 == 0 {
ret.Type = arrow.PrimitiveTypes.Int64
} else {
ret.Type = arrow.PrimitiveTypes.Float64
}
}
case "FLOAT":
fallthrough
Expand Down Expand Up @@ -639,7 +646,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat
}

prevKey = key
fieldList = append(fieldList, toField(colName, isNullable, dataType, numericPrec, numericPrecRadix, numericScale, isIdent, identGen, identIncrement, charMaxLength, charOctetLength, datetimePrec, comment, ordinalPos))
fieldList = append(fieldList, toField(colName, isNullable, dataType, numericPrec, numericPrecRadix, numericScale, isIdent, c.useHighPrecision, identGen, identIncrement, charMaxLength, charOctetLength, datetimePrec, comment, ordinalPos))
}

if len(fieldList) > 0 && curTableInfo != nil {
Expand Down
Loading