diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go index 496629ebda..6f8707a15c 100644 --- a/go/adbc/driver/snowflake/driver_test.go +++ b/go/adbc/driver/snowflake/driver_test.go @@ -26,6 +26,7 @@ import ( "encoding/base64" "encoding/pem" "fmt" + "math" "os" "runtime" "strconv" @@ -38,6 +39,7 @@ import ( "github.com/apache/arrow-adbc/go/adbc/validation" "github.com/apache/arrow/go/v14/arrow" "github.com/apache/arrow/go/v14/arrow/array" + "github.com/apache/arrow/go/v14/arrow/decimal128" "github.com/apache/arrow/go/v14/arrow/memory" "github.com/google/uuid" "github.com/snowflakedb/gosnowflake" @@ -679,6 +681,98 @@ func (suite *SnowflakeTests) TestUseHighPrecision() { suite.Equal(9876543210.99, rec.Column(1).(*array.Float64).Value(1)) } +func (suite *SnowflakeTests) TestDecimalHighPrecision() { + for sign := 0; sign <= 1; sign++ { + for scale := 0; scale <= 2; scale++ { + for precision := 3; precision <= 38; precision++ { + numberString := strings.Repeat("9", precision-scale) + "." + strings.Repeat("9", scale) + if sign == 1 { + numberString = "-" + numberString + } + query := "SELECT CAST('" + numberString + fmt.Sprintf("' AS NUMBER(%d, %d)) AS RESULT", precision, scale) + number, err := decimal128.FromString(numberString, int32(precision), int32(scale)) + suite.NoError(err) + + suite.Require().NoError(suite.stmt.SetOption(driver.OptionUseHighPrecision, adbc.OptionValueEnabled)) + suite.Require().NoError(suite.stmt.SetSqlQuery(query)) + rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx) + suite.Require().NoError(err) + defer rdr.Release() + + suite.EqualValues(1, n) + suite.Truef(arrow.TypeEqual(&arrow.Decimal128Type{Precision: int32(precision), Scale: int32(scale)}, rdr.Schema().Field(0).Type), "expected decimal(%d, %d), got %s", precision, scale, rdr.Schema().Field(0).Type) + suite.True(rdr.Next()) + rec := rdr.Record() + + suite.Equal(number, rec.Column(0).(*array.Decimal128).Value(0)) + } + } + } +} + +func (suite *SnowflakeTests) TestNonIntDecimalLowPrecision() { + for sign := 0; sign <= 1; sign++ { + for precision := 3; precision <= 38; precision++ { + scale := 2 + numberString := strings.Repeat("9", precision-scale) + ".99" + if sign == 1 { + numberString = "-" + numberString + } + query := "SELECT CAST('" + numberString + fmt.Sprintf("' AS NUMBER(%d, %d)) AS RESULT", precision, scale) + decimalNumber, err := decimal128.FromString(numberString, int32(precision), int32(scale)) + suite.NoError(err) + number := decimalNumber.ToFloat64(int32(scale)) + + suite.Require().NoError(suite.stmt.SetOption(driver.OptionUseHighPrecision, adbc.OptionValueDisabled)) + suite.Require().NoError(suite.stmt.SetSqlQuery(query)) + rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx) + suite.Require().NoError(err) + defer rdr.Release() + + suite.EqualValues(1, n) + suite.Truef(arrow.TypeEqual(arrow.PrimitiveTypes.Float64, rdr.Schema().Field(0).Type), "expected float64, got %s", rdr.Schema().Field(0).Type) + suite.True(rdr.Next()) + rec := rdr.Record() + + value := rec.Column(0).(*array.Float64).Value(0) + difference := math.Abs(number - value) + suite.Truef(difference < 1e-13, "expected %f, got %f", number, value) + } + } +} + +func (suite *SnowflakeTests) TestIntDecimalLowPrecision() { + for sign := 0; sign <= 1; sign++ { + for precision := 3; precision <= 38; precision++ { + scale := 0 + numberString := strings.Repeat("9", precision-scale) + if sign == 1 { + numberString = "-" + numberString + } + query := "SELECT CAST('" + numberString + fmt.Sprintf("' AS NUMBER(%d, %d)) AS RESULT", precision, scale) + decimalNumber, err := decimal128.FromString(numberString, int32(precision), int32(scale)) + suite.NoError(err) + // The current behavior of the driver for decimal128 values too large to fit into 64 bits is to simply + // return the low 64 bits of the value. + number := int64(decimalNumber.LowBits()) + + suite.Require().NoError(suite.stmt.SetOption(driver.OptionUseHighPrecision, adbc.OptionValueDisabled)) + suite.Require().NoError(suite.stmt.SetSqlQuery(query)) + rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx) + suite.Require().NoError(err) + defer rdr.Release() + + suite.EqualValues(1, n) + suite.Truef(arrow.TypeEqual(arrow.PrimitiveTypes.Int64, rdr.Schema().Field(0).Type), "expected int64, got %s", rdr.Schema().Field(0).Type) + suite.True(rdr.Next()) + rec := rdr.Record() + + value := rec.Column(0).(*array.Int64).Value(0) + suite.Equal(number, value) + } + } +} + func (suite *SnowflakeTests) TestDescribeOnly() { suite.Require().NoError(suite.stmt.SetOption(driver.OptionUseHighPrecision, adbc.OptionValueEnabled)) suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT CAST('9999.99' AS NUMBER(6, 2)) AS RESULT")) diff --git a/go/adbc/driver/snowflake/record_reader.go b/go/adbc/driver/snowflake/record_reader.go index 9bd5193537..b2213a2242 100644 --- a/go/adbc/driver/snowflake/record_reader.go +++ b/go/adbc/driver/snowflake/record_reader.go @@ -101,20 +101,36 @@ func getTransformer(sc *arrow.Schema, ld gosnowflake.ArrowStreamLoader, useHighP } f.Type = dt transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) { - return compute.CastArray(ctx, a, compute.SafeCastOptions(dt)) + return integerToDecimal128(ctx, a, dt) } } else { if srcMeta.Scale != 0 { f.Type = arrow.PrimitiveTypes.Float64 - transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) { - result, err := compute.Divide(ctx, compute.ArithmeticOptions{NoCheckOverflow: true}, - &compute.ArrayDatum{Value: a.Data()}, - compute.NewDatum(math.Pow10(int(srcMeta.Scale)))) - if err != nil { - return nil, err + // For precisions of 16, 17 and 18, a conversion from int64 to float64 fails with an error + // So for these precisions, we instead convert first to a decimal128 and then to a float64. + if srcMeta.Precision > 15 && srcMeta.Precision < 19 { + transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) { + result, err := integerToDecimal128(ctx, a, &arrow.Decimal128Type{ + Precision: int32(srcMeta.Precision), + Scale: int32(srcMeta.Scale), + }) + if err != nil { + return nil, err + } + return compute.CastArray(ctx, result, compute.UnsafeCastOptions(f.Type)) + } + } else { + // For precisions less than 16, we can simply scale the integer value appropriately + transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) { + result, err := compute.Divide(ctx, compute.ArithmeticOptions{NoCheckOverflow: true}, + &compute.ArrayDatum{Value: a.Data()}, + compute.NewDatum(math.Pow10(int(srcMeta.Scale)))) + if err != nil { + return nil, err + } + defer result.Release() + return result.(*compute.ArrayDatum).MakeArray(), nil } - defer result.Release() - return result.(*compute.ArrayDatum).MakeArray(), nil } } else { f.Type = arrow.PrimitiveTypes.Int64 @@ -266,6 +282,27 @@ func getTransformer(sc *arrow.Schema, ld gosnowflake.ArrowStreamLoader, useHighP return out, getRecTransformer(out, transformers) } +func integerToDecimal128(ctx context.Context, a arrow.Array, dt *arrow.Decimal128Type) (arrow.Array, error) { + // We can't do a cast directly into the destination type because the numbers we get from Snowflake + // are scaled integers. So not only would the cast produce the wrong value, it also risks producing + // an error of precisions which e.g. can't hold every int64. To work around these problems, we instead + // cast into a decimal type of a precision and scale which we know will hold all values and won't + // require scaling, We then substitute the type on this array with the actual return type. + + dt0 := &arrow.Decimal128Type{ + Precision: int32(20), + Scale: int32(0), + } + result, err := compute.CastArray(ctx, a, compute.SafeCastOptions(dt0)) + if err != nil { + return nil, err + } + + data := result.Data() + result.Data().Reset(dt, data.Len(), data.Buffers(), data.Children(), data.NullN(), data.Offset()) + return result, err +} + func rowTypesToArrowSchema(ctx context.Context, ld gosnowflake.ArrowStreamLoader, useHighPrecision bool) (*arrow.Schema, error) { var loc *time.Location