From 742af7128586021b97cf216216ed977f80f6ad88 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 23 May 2023 15:42:52 -0400 Subject: [PATCH] feat(go/adbc): implement 1.1.0 features - ADBC_INFO_DRIVER_ADBC_VERSION - StatementExecuteSchema (#318) - ADBC_CONNECTION_OPTION_CURRENT_{CATALOG, DB_SCHEMA} (#319) - Get/SetOption - error_details (#755) --- .github/workflows/native-unix.yml | 1 + go/adbc/adbc.go | 32 +- go/adbc/driver/flightsql/flightsql_adbc.go | 381 +++++++++++++++--- .../flightsql/flightsql_adbc_server_test.go | 260 ++++++++++++ .../driver/flightsql/flightsql_adbc_test.go | 14 +- .../driver/flightsql/flightsql_statement.go | 181 +++++++-- go/adbc/driver/flightsql/record_reader.go | 6 +- go/adbc/driver/flightsql/utils.go | 6 +- go/adbc/driver/snowflake/connection.go | 115 +++++- go/adbc/driver/snowflake/driver.go | 128 ++++++ go/adbc/driver/snowflake/driver_test.go | 23 +- go/adbc/driver/snowflake/statement.go | 63 ++- go/adbc/validation/validation.go | 155 ++++++- 13 files changed, 1234 insertions(+), 131 deletions(-) diff --git a/.github/workflows/native-unix.yml b/.github/workflows/native-unix.yml index 32d24a07ef..09adb7c321 100644 --- a/.github/workflows/native-unix.yml +++ b/.github/workflows/native-unix.yml @@ -315,6 +315,7 @@ jobs: popd - name: Go Test env: + SNOWFLAKE_DATABASE: ADBC_TESTING SNOWFLAKE_URI: ${{ secrets.SNOWFLAKE_URI }} run: | ./ci/scripts/go_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local" diff --git a/go/adbc/adbc.go b/go/adbc/adbc.go index 99a4f81b75..c936fe5334 100644 --- a/go/adbc/adbc.go +++ b/go/adbc/adbc.go @@ -58,13 +58,15 @@ type Error struct { // SqlState is a SQLSTATE error code, if provided, as defined // by the SQL:2003 standard. If not set, it will be "\0\0\0\0\0" SqlState [5]byte - // Details is an array of additional driver-specific binary error details. + // Details is an array of additional driver-specific error details. // // This allows drivers to return custom, structured error information (for // example, JSON or Protocol Buffers) that can be optionally parsed by // clients, beyond the standard Error fields, without having to encode it in - // the error message. The encoding of the data is driver-defined. - Details [][]byte + // the error message. The encoding of the data is driver-defined. It is + // suggested to use proto.Message for Protocol Buffers and error for wrapped + // errors. + Details []interface{} } func (e Error) Error() string { @@ -621,23 +623,6 @@ type Statement interface { ExecutePartitions(context.Context) (*arrow.Schema, Partitions, int64, error) } -// Cancellable is a Connection or Statement that also supports Cancel. -// -// Since ADBC API revision 1.1.0. -type Cancellable interface { - // Cancel stops execution of an in-progress query. - // - // This can be called during ExecuteQuery, GetObjects, or other - // methods that produce result sets, or while consuming a - // RecordReader returned from such. Calling this function should - // make the other functions return an error with a StatusCancelled - // code. - // - // This must always be thread-safe (other operations are not - // necessarily thread-safe). - Cancel() error -} - // ConnectionGetStatistics is a Connection that supports getting // statistics on data in the database. // @@ -719,7 +704,10 @@ type StatementExecuteSchema interface { ExecuteSchema(context.Context) (*arrow.Schema, error) } -// GetSetOptions is a PostInitOptions that also supports getting and setting property values of different types. +// GetSetOptions is a PostInitOptions that also supports getting and setting option values of different types. +// +// GetOption functions should return an error with StatusNotFound for unsupported options. +// SetOption functions should return an error with StatusNotImplemented for unsupported options. // // Since ADBC API revision 1.1.0. type GetSetOptions interface { @@ -728,7 +716,7 @@ type GetSetOptions interface { SetOptionBytes(key string, value []byte) error SetOptionInt(key string, value int64) error SetOptionDouble(key string, value float64) error - GetOption(key, value string) (string, error) + GetOption(key string) (string, error) GetOptionBytes(key string) ([]byte, error) GetOptionInt(key string) (int64, error) GetOptionDouble(key string) (float64, error) diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go b/go/adbc/driver/flightsql/flightsql_adbc.go index e038354cc5..00d123b322 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc.go +++ b/go/adbc/driver/flightsql/flightsql_adbc.go @@ -36,7 +36,6 @@ import ( "context" "crypto/tls" "crypto/x509" - "errors" "fmt" "io" "math" @@ -119,20 +118,13 @@ func init() { adbc.InfoDriverName, adbc.InfoDriverVersion, adbc.InfoDriverArrowVersion, + adbc.InfoDriverADBCVersion, adbc.InfoVendorName, adbc.InfoVendorVersion, adbc.InfoVendorArrowVersion, } } -func getTimeoutOptionValue(v string) (time.Duration, error) { - timeout, err := strconv.ParseFloat(v, 64) - if math.IsNaN(timeout) || math.IsInf(timeout, 0) || timeout < 0 { - return 0, errors.New("timeout must be positive and finite") - } - return time.Duration(timeout * float64(time.Second)), err -} - type Driver struct { Alloc memory.Allocator } @@ -164,6 +156,8 @@ func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) { db.dialOpts.block = false db.dialOpts.maxMsgSize = 16 * 1024 * 1024 + db.options = make(map[string]string) + return db, db.SetOptions(opts) } @@ -192,6 +186,7 @@ type database struct { timeout timeoutOption dialOpts dbDialOpts enableCookies bool + options map[string]string alloc memory.Allocator } @@ -199,6 +194,10 @@ type database struct { func (d *database) SetOptions(cnOptions map[string]string) error { var tlsConfig tls.Config + for k, v := range cnOptions { + d.options[k] = v + } + mtlsCert := cnOptions[OptionMTLSCertChain] mtlsKey := cnOptions[OptionMTLSPrivateKey] switch { @@ -287,33 +286,24 @@ func (d *database) SetOptions(cnOptions map[string]string) error { var err error if tv, ok := cnOptions[OptionTimeoutFetch]; ok { - if d.timeout.fetchTimeout, err = getTimeoutOptionValue(tv); err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, tv, err.Error()), - Code: adbc.StatusInvalidArgument, - } + if err = d.timeout.setTimeoutString(OptionTimeoutFetch, tv); err != nil { + return err } + delete(cnOptions, OptionTimeoutFetch) } if tv, ok := cnOptions[OptionTimeoutQuery]; ok { - if d.timeout.queryTimeout, err = getTimeoutOptionValue(tv); err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutQuery, tv, err.Error()), - Code: adbc.StatusInvalidArgument, - } + if err = d.timeout.setTimeoutString(OptionTimeoutQuery, tv); err != nil { + return err } + delete(cnOptions, OptionTimeoutQuery) } if tv, ok := cnOptions[OptionTimeoutUpdate]; ok { - if d.timeout.updateTimeout, err = getTimeoutOptionValue(tv); err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutUpdate, tv, err.Error()), - Code: adbc.StatusInvalidArgument, - } + if err = d.timeout.setTimeoutString(OptionTimeoutUpdate, tv); err != nil { + return err } + delete(cnOptions, OptionTimeoutUpdate) } if val, ok := cnOptions[OptionWithBlock]; ok { @@ -369,7 +359,7 @@ func (d *database) SetOptions(cnOptions map[string]string) error { continue } return adbc.Error{ - Msg: fmt.Sprintf("Unknown database option '%s'", key), + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), Code: adbc.StatusInvalidArgument, } } @@ -377,6 +367,118 @@ func (d *database) SetOptions(cnOptions map[string]string) error { return nil } +func (d *database) GetOption(key string) (string, error) { + switch key { + case OptionTimeoutFetch: + return d.timeout.fetchTimeout.String(), nil + case OptionTimeoutQuery: + return d.timeout.queryTimeout.String(), nil + case OptionTimeoutUpdate: + return d.timeout.updateTimeout.String(), nil + } + if val, ok := d.options[key]; ok { + return val, nil + } + return "", adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionInt(key string) (int64, error) { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + val, err := d.GetOptionDouble(key) + if err != nil { + return 0, err + } + return int64(val), nil + } + + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionDouble(key string) (float64, error) { + switch key { + case OptionTimeoutFetch: + return d.timeout.fetchTimeout.Seconds(), nil + case OptionTimeoutQuery: + return d.timeout.queryTimeout.Seconds(), nil + case OptionTimeoutUpdate: + return d.timeout.updateTimeout.Seconds(), nil + } + + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) SetOption(key, value string) error { + // We can't change most options post-init + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + return d.timeout.setTimeoutString(key, value) + } + if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { + d.hdrs.Set(strings.TrimPrefix(key, OptionRPCCallHeaderPrefix), value) + } + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} +func (d *database) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} +func (d *database) SetOptionInt(key string, value int64) error { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + return d.timeout.setTimeout(key, float64(value)) + } + + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} +func (d *database) SetOptionDouble(key string, value float64) error { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + return d.timeout.setTimeout(key, value) + } + + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + type timeoutOption struct { grpc.EmptyCallOption @@ -388,6 +490,45 @@ type timeoutOption struct { updateTimeout time.Duration } +func (t *timeoutOption) setTimeout(key string, value float64) error { + if math.IsNaN(value) || math.IsInf(value, 0) || value < 0 { + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] invalid timeout option value %s = %f: timeouts must be non-negative and finite", + key, value), + Code: adbc.StatusInvalidArgument, + } + } + + timeout := time.Duration(value * float64(time.Second)) + + switch key { + case OptionTimeoutFetch: + t.fetchTimeout = timeout + case OptionTimeoutQuery: + t.queryTimeout = timeout + case OptionTimeoutUpdate: + t.updateTimeout = timeout + default: + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown timeout option '%s'", key), + Code: adbc.StatusNotImplemented, + } + } + return nil +} + +func (t *timeoutOption) setTimeoutString(key string, value string) error { + timeout, err := strconv.ParseFloat(value, 64) + if err != nil { + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] invalid timeout option value %s = %s: %s", + key, value, err.Error()), + Code: adbc.StatusInvalidArgument, + } + } + return t.setTimeout(key, timeout) +} + func getTimeout(method string, callOptions []grpc.CallOption) (time.Duration, bool) { for _, opt := range callOptions { if to, ok := opt.(timeoutOption); ok { @@ -729,6 +870,96 @@ func doGet(ctx context.Context, cl *flightsql.Client, endpoint *flight.FlightEnd return nil, err } +func (c *cnxn) GetOption(key string) (string, error) { + if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { + name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) + headers := c.hdrs.Get(name) + if len(headers) > 0 { + return headers[0], nil + } + return "", adbc.Error{ + Msg: "[Flight SQL] unknown header", + Code: adbc.StatusNotFound, + } + } + + switch key { + case OptionTimeoutFetch: + return c.timeouts.fetchTimeout.String(), nil + case OptionTimeoutQuery: + return c.timeouts.queryTimeout.String(), nil + case OptionTimeoutUpdate: + return c.timeouts.updateTimeout.String(), nil + case adbc.OptionKeyAutoCommit: + if c.txn != nil { + // No autocommit + return adbc.OptionValueDisabled, nil + } else { + // Autocommit + return adbc.OptionValueEnabled, nil + } + case adbc.OptionKeyCurrentCatalog: + return "", adbc.Error{ + Msg: "[Flight SQL] current catalog not supported", + Code: adbc.StatusNotFound, + } + + case adbc.OptionKeyCurrentDbSchema: + return "", adbc.Error{ + Msg: "[Flight SQL] current schema not supported", + Code: adbc.StatusNotFound, + } + } + + return "", adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) GetOptionInt(key string) (int64, error) { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + val, err := c.GetOptionDouble(key) + if err != nil { + return 0, err + } + return int64(val), nil + } + + return 0, adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) GetOptionDouble(key string) (float64, error) { + switch key { + case OptionTimeoutFetch: + return c.timeouts.fetchTimeout.Seconds(), nil + case OptionTimeoutQuery: + return c.timeouts.queryTimeout.Seconds(), nil + case OptionTimeoutUpdate: + return c.timeouts.updateTimeout.Seconds(), nil + } + + return 0.0, adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotFound, + } +} + func (c *cnxn) SetOption(key, value string) error { if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) @@ -742,39 +973,16 @@ func (c *cnxn) SetOption(key, value string) error { switch key { case OptionTimeoutFetch: - timeout, err := getTimeoutOptionValue(value) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, value, err.Error()), - Code: adbc.StatusInvalidArgument, - } - } - c.timeouts.fetchTimeout = timeout + fallthrough case OptionTimeoutQuery: - timeout, err := getTimeoutOptionValue(value) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, value, err.Error()), - Code: adbc.StatusInvalidArgument, - } - } - c.timeouts.queryTimeout = timeout + fallthrough case OptionTimeoutUpdate: - timeout, err := getTimeoutOptionValue(value) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, value, err.Error()), - Code: adbc.StatusInvalidArgument, - } - } - c.timeouts.updateTimeout = timeout + return c.timeouts.setTimeoutString(key, value) case adbc.OptionKeyAutoCommit: autocommit := true switch value { case adbc.OptionValueEnabled: + autocommit = true case adbc.OptionValueDisabled: autocommit = false default: @@ -827,6 +1035,45 @@ func (c *cnxn) SetOption(key, value string) error { return nil } +func (c *cnxn) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} + +func (c *cnxn) SetOptionInt(key string, value int64) error { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + return c.timeouts.setTimeout(key, float64(value)) + } + + return adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} + +func (c *cnxn) SetOptionDouble(key string, value float64) error { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + return c.timeouts.setTimeout(key, value) + } + + return adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} + // GetInfo returns metadata about the database/driver. // // The result is an Arrow dataset with the following schema: @@ -853,6 +1100,7 @@ func (c *cnxn) SetOption(key, value string) error { // codes (the row will be omitted from the result). func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.RecordReader, error) { const strValTypeID arrow.UnionTypeCode = 0 + const intValTypeID arrow.UnionTypeCode = 2 if len(infoCodes) == 0 { infoCodes = infoSupportedCodes @@ -864,7 +1112,8 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re infoNameBldr := bldr.Field(0).(*array.Uint32Builder) infoValueBldr := bldr.Field(1).(*array.DenseUnionBuilder) - strInfoBldr := infoValueBldr.Child(0).(*array.StringBuilder) + strInfoBldr := infoValueBldr.Child(int(strValTypeID)).(*array.StringBuilder) + intInfoBldr := infoValueBldr.Child(int(intValTypeID)).(*array.Int64Builder) translated := make([]flightsql.SqlInfo, 0, len(infoCodes)) for _, code := range infoCodes { @@ -886,6 +1135,10 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re infoNameBldr.Append(uint32(code)) infoValueBldr.Append(strValTypeID) strInfoBldr.Append(infoDriverArrowVersion) + case adbc.InfoDriverADBCVersion: + infoNameBldr.Append(uint32(code)) + infoValueBldr.Append(intValTypeID) + intInfoBldr.Append(adbc.AdbcVersion1_1_0) } } @@ -1350,6 +1603,14 @@ func (c *cnxn) execute(ctx context.Context, query string, opts ...grpc.CallOptio return c.cl.Execute(ctx, query, opts...) } +func (c *cnxn) executeSchema(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.SchemaResult, error) { + if c.txn != nil { + return c.txn.GetExecuteSchema(ctx, query, opts...) + } + + return c.cl.GetExecuteSchema(ctx, query, opts...) +} + func (c *cnxn) executeSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.FlightInfo, error) { if c.txn != nil { return c.txn.ExecuteSubstrait(ctx, plan, opts...) @@ -1358,6 +1619,14 @@ func (c *cnxn) executeSubstrait(ctx context.Context, plan flightsql.SubstraitPla return c.cl.ExecuteSubstrait(ctx, plan, opts...) } +func (c *cnxn) executeSubstraitSchema(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.SchemaResult, error) { + if c.txn != nil { + return c.txn.GetExecuteSubstraitSchema(ctx, plan, opts...) + } + + return c.cl.GetExecuteSubstraitSchema(ctx, plan, opts...) +} + func (c *cnxn) executeUpdate(ctx context.Context, query string, opts ...grpc.CallOption) (n int64, err error) { if c.txn != nil { return c.txn.ExecuteUpdate(ctx, query, opts...) diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go index 9d959ac4c6..50f1d9b1f0 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go @@ -42,6 +42,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/wrapperspb" ) // ---- Common Infra -------------------- @@ -95,6 +96,14 @@ func TestAuthn(t *testing.T) { suite.Run(t, &AuthnTests{}) } +func TestErrorDetails(t *testing.T) { + suite.Run(t, &ErrorDetailsTests{}) +} + +func TestExecuteSchema(t *testing.T) { + suite.Run(t, &ExecuteSchemaTests{}) +} + func TestTimeout(t *testing.T) { suite.Run(t, &TimeoutTests{}) } @@ -202,6 +211,196 @@ func (suite *AuthnTests) TestBearerTokenUpdated() { defer reader.Release() } +// ---- Error Details Tests -------------------- + +type ErrorDetailsTestServer struct { + flightsql.BaseServer +} + +func (srv *ErrorDetailsTestServer) GetFlightInfoStatement(ctx context.Context, query flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + if query.GetQuery() == "details" { + detail := wrapperspb.Int32Value{Value: 42} + st, err := status.New(codes.Unknown, "details").WithDetails(&detail) + if err != nil { + return nil, err + } + return nil, st.Err() + } else if query.GetQuery() == "query" { + tkt, err := flightsql.CreateStatementQueryTicket([]byte("fetch")) + if err != nil { + panic(err) + } + return &flight.FlightInfo{Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: tkt}}}}, nil + } + return nil, status.Errorf(codes.Unimplemented, "GetSchemaStatement not implemented") +} + +func (ts *ErrorDetailsTestServer) DoGetStatement(ctx context.Context, tkt flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) { + sc := arrow.NewSchema([]arrow.Field{}, nil) + detail := wrapperspb.Int32Value{Value: 42} + st, err := status.New(codes.Unknown, "details").WithDetails(&detail) + if err != nil { + return nil, nil, err + } + + ch := make(chan flight.StreamChunk) + go func() { + defer close(ch) + ch <- flight.StreamChunk{ + Data: nil, + Desc: nil, + Err: st.Err(), + } + }() + return sc, ch, nil +} + +type ErrorDetailsTests struct { + ServerBasedTests +} + +func (suite *ErrorDetailsTests) SetupSuite() { + srv := ErrorDetailsTestServer{} + srv.Alloc = memory.DefaultAllocator + suite.DoSetupSuite(&srv, nil, nil) +} + +func (ts *ErrorDetailsTests) TestGetFlightInfo() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + ts.NoError(stmt.SetSqlQuery("details")) + + _, _, err = stmt.ExecuteQuery(context.Background()) + var adbcErr adbc.Error + ts.ErrorAs(err, &adbcErr) + + ts.Equal(1, len(adbcErr.Details)) + + message, ok := adbcErr.Details[0].(*wrapperspb.Int32Value) + ts.True(ok, "Got message: %#v", message) + ts.Equal(int32(42), message.Value) +} + +func (ts *ErrorDetailsTests) TestDoGet() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + ts.NoError(stmt.SetSqlQuery("query")) + + reader, _, err := stmt.ExecuteQuery(context.Background()) + ts.NoError(err) + + defer reader.Release() + + for reader.Next() { + } + err = reader.Err() + + ts.Error(err) + + var adbcErr adbc.Error + ts.ErrorAs(err, &adbcErr, "Error was: %#v", err) + + ts.Equal(1, len(adbcErr.Details)) + + message, ok := adbcErr.Details[0].(*wrapperspb.Int32Value) + ts.True(ok, "Got message: %#v", message) + ts.Equal(int32(42), message.Value) +} + +// ---- ExecuteSchema Tests -------------------- + +type ExecuteSchemaTestServer struct { + flightsql.BaseServer +} + +func (srv *ExecuteSchemaTestServer) GetSchemaStatement(ctx context.Context, query flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.SchemaResult, error) { + if query.GetQuery() == "sample query" { + return &flight.SchemaResult{ + Schema: flight.SerializeSchema(arrow.NewSchema([]arrow.Field{ + {Name: "ints", Type: arrow.PrimitiveTypes.Int32}, + }, nil), srv.Alloc), + }, nil + } + return nil, status.Errorf(codes.Unimplemented, "GetSchemaStatement not implemented") +} + +func (srv *ExecuteSchemaTestServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (res flightsql.ActionCreatePreparedStatementResult, err error) { + if req.GetQuery() == "sample query" { + return flightsql.ActionCreatePreparedStatementResult{ + DatasetSchema: arrow.NewSchema([]arrow.Field{ + {Name: "ints", Type: arrow.PrimitiveTypes.Int32}, + }, nil), + }, nil + } + return flightsql.ActionCreatePreparedStatementResult{}, status.Error(codes.Unimplemented, "CreatePreparedStatement not implemented") +} + +type ExecuteSchemaTests struct { + ServerBasedTests +} + +func (suite *ExecuteSchemaTests) SetupSuite() { + srv := ExecuteSchemaTestServer{} + srv.Alloc = memory.DefaultAllocator + suite.DoSetupSuite(&srv, nil, nil) +} + +func (ts *ExecuteSchemaTests) TestNoQuery() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + es := stmt.(adbc.StatementExecuteSchema) + _, err = es.ExecuteSchema(context.Background()) + + var adbcErr adbc.Error + ts.ErrorAs(err, &adbcErr) + ts.Equal(adbc.StatusInvalidState, adbcErr.Code, adbcErr.Error()) +} + +func (ts *ExecuteSchemaTests) TestPreparedQuery() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + ts.NoError(stmt.SetSqlQuery("sample query")) + ts.NoError(stmt.Prepare(context.Background())) + + es := stmt.(adbc.StatementExecuteSchema) + schema, err := es.ExecuteSchema(context.Background()) + ts.NoError(err) + ts.NotNil(schema) + + expectedSchema := arrow.NewSchema([]arrow.Field{ + {Name: "ints", Type: arrow.PrimitiveTypes.Int32}, + }, nil) + + ts.True(expectedSchema.Equal(schema), schema.String()) +} + +func (ts *ExecuteSchemaTests) TestQuery() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + ts.NoError(stmt.SetSqlQuery("sample query")) + + es := stmt.(adbc.StatementExecuteSchema) + schema, err := es.ExecuteSchema(context.Background()) + ts.NoError(err) + ts.NotNil(schema) + + expectedSchema := arrow.NewSchema([]arrow.Field{ + {Name: "ints", Type: arrow.PrimitiveTypes.Int32}, + }, nil) + + ts.True(expectedSchema.Equal(schema), schema.String()) +} + // ---- Timeout Tests -------------------- type TimeoutTestServer struct { @@ -321,6 +520,67 @@ func (ts *TimeoutTests) TestRemoveTimeout() { } } +func (ts *TimeoutTests) TestGetSet() { + keys := []string{ + "adbc.flight.sql.rpc.timeout_seconds.fetch", + "adbc.flight.sql.rpc.timeout_seconds.query", + "adbc.flight.sql.rpc.timeout_seconds.update", + } + stmt, err := ts.cnxn.NewStatement() + ts.Require().NoError(err) + defer stmt.Close() + + for _, v := range []interface{}{ts.db, ts.cnxn, stmt} { + getset := v.(adbc.GetSetOptions) + + for _, k := range keys { + strval, err := getset.GetOption(k) + ts.NoError(err) + ts.Equal("0s", strval) + + intval, err := getset.GetOptionInt(k) + ts.NoError(err) + ts.Equal(int64(0), intval) + + floatval, err := getset.GetOptionDouble(k) + ts.NoError(err) + ts.Equal(0.0, floatval) + + err = getset.SetOptionInt(k, 1) + ts.NoError(err) + + strval, err = getset.GetOption(k) + ts.NoError(err) + ts.Equal("1s", strval) + + intval, err = getset.GetOptionInt(k) + ts.NoError(err) + ts.Equal(int64(1), intval) + + floatval, err = getset.GetOptionDouble(k) + ts.NoError(err) + ts.Equal(1.0, floatval) + + err = getset.SetOptionDouble(k, 0.1) + ts.NoError(err) + + strval, err = getset.GetOption(k) + ts.NoError(err) + ts.Equal("100ms", strval) + + intval, err = getset.GetOptionInt(k) + ts.NoError(err) + // truncated + ts.Equal(int64(0), intval) + + floatval, err = getset.GetOptionDouble(k) + ts.NoError(err) + ts.Equal(0.1, floatval) + } + } + +} + func (ts *TimeoutTests) TestDoActionTimeout() { ts.NoError(ts.cnxn.(adbc.PostInitOptions). SetOption("adbc.flight.sql.rpc.timeout_seconds.update", "0.1")) diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go b/go/adbc/driver/flightsql/flightsql_adbc_test.go index 53dbac2412..381ad0c67c 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go @@ -229,9 +229,14 @@ func (s *FlightSQLQuirks) DropTable(cnxn adbc.Connection, tblname string) error return err } -func (s *FlightSQLQuirks) Alloc() memory.Allocator { return s.mem } -func (s *FlightSQLQuirks) BindParameter(_ int) string { return "?" } -func (s *FlightSQLQuirks) SupportsConcurrentStatements() bool { return true } +func (s *FlightSQLQuirks) Alloc() memory.Allocator { return s.mem } +func (s *FlightSQLQuirks) BindParameter(_ int) string { return "?" } +func (s *FlightSQLQuirks) SupportsConcurrentStatements() bool { return true } +func (s *FlightSQLQuirks) SupportsCurrentCatalogSchema() bool { return false } + +// The driver supports it, but the server we use for testing does not. +func (s *FlightSQLQuirks) SupportsExecuteSchema() bool { return false } +func (s *FlightSQLQuirks) SupportsGetSetOptions() bool { return true } func (s *FlightSQLQuirks) SupportsPartitionedData() bool { return true } func (s *FlightSQLQuirks) SupportsTransactions() bool { return true } func (s *FlightSQLQuirks) SupportsGetParameterSchema() bool { return false } @@ -247,6 +252,8 @@ func (s *FlightSQLQuirks) GetMetadata(code adbc.InfoCode) interface{} { return "(unknown or development build)" case adbc.InfoDriverArrowVersion: return "(unknown or development build)" + case adbc.InfoDriverADBCVersion: + return adbc.AdbcVersion1_1_0 case adbc.InfoVendorName: return "db_name" case adbc.InfoVendorVersion: @@ -273,6 +280,7 @@ func (s *FlightSQLQuirks) SampleTableSchemaMetadata(tblName string, dt arrow.Dat } } +func (s *FlightSQLQuirks) Catalog() string { return "" } func (s *FlightSQLQuirks) DBSchema() string { return "" } func TestADBCFlightSQL(t *testing.T) { diff --git a/go/adbc/driver/flightsql/flightsql_statement.go b/go/adbc/driver/flightsql/flightsql_statement.go index c7f074a800..04f46498f4 100644 --- a/go/adbc/driver/flightsql/flightsql_statement.go +++ b/go/adbc/driver/flightsql/flightsql_statement.go @@ -73,6 +73,29 @@ func (s *sqlOrSubstrait) execute(ctx context.Context, cnxn *cnxn, opts ...grpc.C } } +func (s *sqlOrSubstrait) executeSchema(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (*arrow.Schema, error) { + var ( + res *flight.SchemaResult + err error + ) + if s.sqlQuery != "" { + res, err = cnxn.executeSchema(ctx, s.sqlQuery, opts...) + } else if s.substraitPlan != nil { + res, err = cnxn.executeSubstraitSchema(ctx, flightsql.SubstraitPlan{Plan: s.substraitPlan, Version: s.substraitVersion}, opts...) + } else { + return nil, adbc.Error{ + Code: adbc.StatusInvalidState, + Msg: "[Flight SQL Statement] cannot call ExecuteQuery without a query or prepared statement", + } + } + + if err != nil { + return nil, err + } + + return flight.DeserializeSchema(res.Schema, cnxn.cl.Alloc) +} + func (s *sqlOrSubstrait) executeUpdate(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (int64, error) { if s.sqlQuery != "" { return cnxn.executeUpdate(ctx, s.sqlQuery, opts...) @@ -138,6 +161,72 @@ func (s *statement) Close() (err error) { return err } +func (s *statement) GetOption(key string) (string, error) { + switch key { + case OptionStatementSubstraitVersion: + return s.query.substraitVersion, nil + case OptionTimeoutFetch: + return s.timeouts.fetchTimeout.String(), nil + case OptionTimeoutQuery: + return s.timeouts.queryTimeout.String(), nil + case OptionTimeoutUpdate: + return s.timeouts.updateTimeout.String(), nil + } + + if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { + name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) + values := s.hdrs.Get(name) + if len(values) > 0 { + return values[0], nil + } + } + + return "", adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (s *statement) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (s *statement) GetOptionInt(key string) (int64, error) { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + val, err := s.GetOptionDouble(key) + if err != nil { + return 0, err + } + return int64(val), nil + } + + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (s *statement) GetOptionDouble(key string) (float64, error) { + switch key { + case OptionTimeoutFetch: + return s.timeouts.fetchTimeout.Seconds(), nil + case OptionTimeoutQuery: + return s.timeouts.queryTimeout.Seconds(), nil + case OptionTimeoutUpdate: + return s.timeouts.updateTimeout.Seconds(), nil + } + + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} + // SetOption sets a string option on this statement func (s *statement) SetOption(key string, val string) error { if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { @@ -152,35 +241,11 @@ func (s *statement) SetOption(key string, val string) error { switch key { case OptionTimeoutFetch: - timeout, err := getTimeoutOptionValue(val) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, val, err.Error()), - Code: adbc.StatusInvalidArgument, - } - } - s.timeouts.fetchTimeout = timeout + fallthrough case OptionTimeoutQuery: - timeout, err := getTimeoutOptionValue(val) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, val, err.Error()), - Code: adbc.StatusInvalidArgument, - } - } - s.timeouts.queryTimeout = timeout + fallthrough case OptionTimeoutUpdate: - timeout, err := getTimeoutOptionValue(val) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, val, err.Error()), - Code: adbc.StatusInvalidArgument, - } - } - s.timeouts.updateTimeout = timeout + return s.timeouts.setTimeoutString(key, val) case OptionStatementQueueSize: var err error var size int @@ -189,13 +254,8 @@ func (s *statement) SetOption(key string, val string) error { Msg: fmt.Sprintf("Invalid value for statement option '%s': '%s' is not a positive integer", OptionStatementQueueSize, val), Code: adbc.StatusInvalidArgument, } - } else if size <= 0 { - return adbc.Error{ - Msg: fmt.Sprintf("Invalid value for statement option '%s': '%s' is not a positive integer", OptionStatementQueueSize, val), - Code: adbc.StatusInvalidArgument, - } } - s.queueSize = size + return s.SetOptionInt(key, int64(size)) case OptionStatementSubstraitVersion: s.query.substraitVersion = val default: @@ -207,6 +267,43 @@ func (s *statement) SetOption(key string, val string) error { return nil } +func (s *statement) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (s *statement) SetOptionInt(key string, value int64) error { + switch key { + case OptionStatementQueueSize: + if value <= 0 { + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Invalid value for statement option '%s': '%d' is not a positive integer", OptionStatementQueueSize, value), + Code: adbc.StatusInvalidArgument, + } + } + s.queueSize = int(value) + return nil + } + return s.SetOptionDouble(key, float64(value)) +} + +func (s *statement) SetOptionDouble(key string, value float64) error { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + return s.timeouts.setTimeout(key, value) + } + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + // SetSqlQuery sets the query string to be executed. // // The query can then be executed with any of the Execute methods. @@ -422,3 +519,21 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc. return sc, out, info.TotalRecords, nil } + +// ExecuteSchema gets the schema of the result set of a query without executing it. +func (s *statement) ExecuteSchema(ctx context.Context) (schema *arrow.Schema, err error) { + ctx = metadata.NewOutgoingContext(ctx, s.hdrs) + + if s.prepared != nil { + schema = s.prepared.DatasetSchema() + if schema == nil { + err = adbc.Error{ + Msg: "[Flight SQL Statement] Database server did not provide schema for prepared statement", + Code: adbc.StatusNotImplemented, + } + } + return + } + + return s.query.executeSchema(ctx, s.cnxn, s.timeouts) +} diff --git a/go/adbc/driver/flightsql/record_reader.go b/go/adbc/driver/flightsql/record_reader.go index 409ce58e61..297d35f8dc 100644 --- a/go/adbc/driver/flightsql/record_reader.go +++ b/go/adbc/driver/flightsql/record_reader.go @@ -104,7 +104,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql. rec.Retain() ch <- rec } - return rdr.Err() + return adbcFromFlightStatus(rdr.Err()) }) endpoints = endpoints[1:] @@ -135,7 +135,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql. rdr, err := doGet(ctx, cl, endpoint, clCache, opts...) if err != nil { - return err + return adbcFromFlightStatus(err) } defer rdr.Release() @@ -150,7 +150,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql. chs[endpointIndex] <- rec } - return rdr.Err() + return adbcFromFlightStatus(rdr.Err()) }) } diff --git a/go/adbc/driver/flightsql/utils.go b/go/adbc/driver/flightsql/utils.go index 4f1f165c6b..fb7f323150 100644 --- a/go/adbc/driver/flightsql/utils.go +++ b/go/adbc/driver/flightsql/utils.go @@ -29,7 +29,9 @@ func adbcFromFlightStatus(err error) error { } var adbcCode adbc.Status - switch status.Code(err) { + // If not a status.Status, will return codes.Unknown + grpcStatus := status.Convert(err) + switch grpcStatus.Code() { case codes.OK: return nil case codes.Canceled: @@ -71,5 +73,7 @@ func adbcFromFlightStatus(err error) error { return adbc.Error{ Msg: err.Error(), Code: adbcCode, + // slice of proto.Message or error + Details: grpcStatus.Details(), } } diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go index 8f965597c5..c321e77a60 100644 --- a/go/adbc/driver/snowflake/connection.go +++ b/go/adbc/driver/snowflake/connection.go @@ -22,6 +22,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + "io" "strconv" "strings" "time" @@ -95,6 +96,7 @@ type cnxn struct { // codes (the row will be omitted from the result). func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.RecordReader, error) { const strValTypeID arrow.UnionTypeCode = 0 + const intValTypeID arrow.UnionTypeCode = 2 if len(infoCodes) == 0 { infoCodes = infoSupportedCodes @@ -106,7 +108,8 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re infoNameBldr := bldr.Field(0).(*array.Uint32Builder) infoValueBldr := bldr.Field(1).(*array.DenseUnionBuilder) - strInfoBldr := infoValueBldr.Child(0).(*array.StringBuilder) + strInfoBldr := infoValueBldr.Child(int(strValTypeID)).(*array.StringBuilder) + intInfoBldr := infoValueBldr.Child(int(intValTypeID)).(*array.Int64Builder) for _, code := range infoCodes { switch code { @@ -122,6 +125,10 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re infoNameBldr.Append(uint32(code)) infoValueBldr.Append(strValTypeID) strInfoBldr.Append(infoDriverArrowVersion) + case adbc.InfoDriverADBCVersion: + infoNameBldr.Append(uint32(code)) + infoValueBldr.Append(intValTypeID) + intInfoBldr.Append(adbc.AdbcVersion1_1_0) case adbc.InfoVendorName: infoNameBldr.Append(uint32(code)) infoValueBldr.Append(strValTypeID) @@ -674,6 +681,85 @@ func descToField(name, typ, isnull, primary string, comment sql.NullString) (fie return } +func (c *cnxn) GetOption(key string) (string, error) { + switch key { + case adbc.OptionKeyAutoCommit: + if c.activeTransaction { + // No autocommit + return adbc.OptionValueDisabled, nil + } else { + // Autocommit + return adbc.OptionValueEnabled, nil + } + case adbc.OptionKeyCurrentCatalog: + return c.getStringQuery("SELECT CURRENT_DATABASE()") + case adbc.OptionKeyCurrentDbSchema: + return c.getStringQuery("SELECT CURRENT_SCHEMA()") + } + + return "", adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) getStringQuery(query string) (string, error) { + result, err := c.cn.QueryContext(context.Background(), query, nil) + if err != nil { + return "", errToAdbcErr(adbc.StatusInternal, err) + } + defer result.Close() + + if len(result.Columns()) != 1 { + return "", adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Internal query returned wrong number of columns: %s", result.Columns()), + Code: adbc.StatusInternal, + } + } + + dest := make([]driver.Value, 1) + err = result.Next(dest) + if err == io.EOF { + return "", adbc.Error{ + Msg: "[Snowflake] Internal query returned no rows", + Code: adbc.StatusInternal, + } + } else if err != nil { + return "", errToAdbcErr(adbc.StatusInternal, err) + } + + value, ok := dest[0].(string) + if !ok { + return "", adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Internal query returned wrong type of value: %s", dest[0]), + Code: adbc.StatusInternal, + } + } + + return value, nil +} + +func (c *cnxn) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) GetOptionInt(key string) (int64, error) { + return 0, adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) GetOptionDouble(key string) (float64, error) { + return 0.0, adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotFound, + } +} + func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (*arrow.Schema, error) { tblParts := make([]string, 0, 3) if catalog != nil { @@ -840,6 +926,12 @@ func (c *cnxn) SetOption(key, value string) error { Code: adbc.StatusInvalidArgument, } } + case adbc.OptionKeyCurrentCatalog: + _, err := c.cn.ExecContext(context.Background(), "USE DATABASE ?", []driver.NamedValue{{Value: value}}) + return err + case adbc.OptionKeyCurrentDbSchema: + _, err := c.cn.ExecContext(context.Background(), "USE SCHEMA ?", []driver.NamedValue{{Value: value}}) + return err default: return adbc.Error{ Msg: "[Snowflake] unknown connection option " + key + ": " + value, @@ -847,3 +939,24 @@ func (c *cnxn) SetOption(key, value string) error { } } } + +func (c *cnxn) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} + +func (c *cnxn) SetOptionInt(key string, value int64) error { + return adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} + +func (c *cnxn) SetOptionDouble(key string, value float64) error { + return adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} diff --git a/go/adbc/driver/snowflake/driver.go b/go/adbc/driver/snowflake/driver.go index c02b58ddec..a00513817b 100644 --- a/go/adbc/driver/snowflake/driver.go +++ b/go/adbc/driver/snowflake/driver.go @@ -209,6 +209,105 @@ type database struct { alloc memory.Allocator } +func (d *database) GetOption(key string) (string, error) { + switch key { + case adbc.OptionKeyUsername: + return d.cfg.User, nil + case adbc.OptionKeyPassword: + return d.cfg.Password, nil + case OptionDatabase: + return d.cfg.Database, nil + case OptionSchema: + return d.cfg.Schema, nil + case OptionWarehouse: + return d.cfg.Warehouse, nil + case OptionRole: + return d.cfg.Role, nil + case OptionRegion: + return d.cfg.Region, nil + case OptionAccount: + return d.cfg.Account, nil + case OptionProtocol: + return d.cfg.Protocol, nil + case OptionHost: + return d.cfg.Host, nil + case OptionPort: + return strconv.Itoa(d.cfg.Port), nil + case OptionAuthType: + return d.cfg.Authenticator.String(), nil + case OptionLoginTimeout: + return strconv.FormatFloat(d.cfg.LoginTimeout.Seconds(), 'f', -1, 64), nil + case OptionRequestTimeout: + return strconv.FormatFloat(d.cfg.RequestTimeout.Seconds(), 'f', -1, 64), nil + case OptionJwtExpireTimeout: + return strconv.FormatFloat(d.cfg.JWTExpireTimeout.Seconds(), 'f', -1, 64), nil + case OptionClientTimeout: + return strconv.FormatFloat(d.cfg.ClientTimeout.Seconds(), 'f', -1, 64), nil + case OptionApplicationName: + return d.cfg.Application, nil + case OptionSSLSkipVerify: + if d.cfg.InsecureMode { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + case OptionOCSPFailOpenMode: + return strconv.FormatUint(uint64(d.cfg.OCSPFailOpen), 10), nil + case OptionAuthToken: + return d.cfg.Token, nil + case OptionAuthOktaUrl: + return d.cfg.OktaURL.String(), nil + case OptionKeepSessionAlive: + if d.cfg.KeepSessionAlive { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + case OptionDisableTelemetry: + if d.cfg.DisableTelemetry { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + case OptionClientRequestMFAToken: + if d.cfg.ClientRequestMfaToken == gosnowflake.ConfigBoolTrue { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + case OptionClientStoreTempCred: + if d.cfg.ClientStoreTemporaryCredential == gosnowflake.ConfigBoolTrue { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + case OptionLogTracing: + return d.cfg.Tracing, nil + default: + val, ok := d.cfg.Params[key] + if ok { + return *val, nil + } + } + return "", adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionInt(key string) (int64, error) { + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionDouble(key string) (float64, error) { + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} + func (d *database) SetOptions(cnOptions map[string]string) error { uri, ok := cnOptions[adbc.OptionKeyURI] if ok { @@ -421,6 +520,35 @@ func (d *database) SetOptions(cnOptions map[string]string) error { return nil } +func (d *database) SetOption(key string, val string) error { + // Can't set options after init + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (d *database) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (d *database) SetOptionInt(key string, value int64) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (d *database) SetOptionDouble(key string, value float64) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + func (d *database) Open(ctx context.Context) (adbc.Connection, error) { connector := gosnowflake.NewConnector(drv, *d.cfg) diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go index 7ac1f27f84..aa72772936 100644 --- a/go/adbc/driver/snowflake/driver_test.go +++ b/go/adbc/driver/snowflake/driver_test.go @@ -38,10 +38,11 @@ import ( ) type SnowflakeQuirks struct { - dsn string - mem *memory.CheckedAllocator - connector gosnowflake.Connector - schemaName string + dsn string + mem *memory.CheckedAllocator + connector gosnowflake.Connector + catalogName string + schemaName string } func (s *SnowflakeQuirks) SetupDriver(t *testing.T) adbc.Driver { @@ -181,11 +182,15 @@ func (s *SnowflakeQuirks) DropTable(cnxn adbc.Connection, tblname string) error func (s *SnowflakeQuirks) Alloc() memory.Allocator { return s.mem } func (s *SnowflakeQuirks) BindParameter(_ int) string { return "?" } func (s *SnowflakeQuirks) SupportsConcurrentStatements() bool { return true } +func (s *SnowflakeQuirks) SupportsCurrentCatalogSchema() bool { return true } +func (s *SnowflakeQuirks) SupportsExecuteSchema() bool { return false } +func (s *SnowflakeQuirks) SupportsGetSetOptions() bool { return true } func (s *SnowflakeQuirks) SupportsPartitionedData() bool { return false } func (s *SnowflakeQuirks) SupportsTransactions() bool { return true } func (s *SnowflakeQuirks) SupportsGetParameterSchema() bool { return false } func (s *SnowflakeQuirks) SupportsDynamicParameterBinding() bool { return false } func (s *SnowflakeQuirks) SupportsBulkIngest() bool { return true } +func (s *SnowflakeQuirks) Catalog() string { return s.catalogName } func (s *SnowflakeQuirks) DBSchema() string { return s.schemaName } func (s *SnowflakeQuirks) GetMetadata(code adbc.InfoCode) interface{} { switch code { @@ -197,6 +202,8 @@ func (s *SnowflakeQuirks) GetMetadata(code adbc.InfoCode) interface{} { return "(unknown or development build)" case adbc.InfoDriverArrowVersion: return "(unknown or development build)" + case adbc.InfoDriverADBCVersion: + return adbc.AdbcVersion1_1_0 case adbc.InfoVendorName: return "Snowflake" } @@ -225,7 +232,7 @@ func createTempSchema(uri string) string { } defer db.Close() - schemaName := "ADBC_TESTING_" + strings.ReplaceAll(uuid.New().String(), "-", "_") + schemaName := strings.ToUpper("ADBC_TESTING_" + strings.ReplaceAll(uuid.New().String(), "-", "_")) _, err = db.Exec(`CREATE SCHEMA ADBC_TESTING.` + schemaName) if err != nil { panic(err) @@ -249,14 +256,16 @@ func dropTempSchema(uri, schema string) { func TestADBCSnowflake(t *testing.T) { uri := os.Getenv("SNOWFLAKE_URI") - + database := os.Getenv("SNOWFLAKE_DATABASE") if uri == "" { t.Skip("no SNOWFLAKE_URI defined, skip snowflake driver tests") + } else if database == "" { + t.Skip("no SNOWFLAKE_DATABASE defined, skip snowflake driver tests") } // avoid multiple runs clashing by operating in a fresh schema and then // dropping that schema when we're done. - q := &SnowflakeQuirks{dsn: uri, schemaName: createTempSchema(uri)} + q := &SnowflakeQuirks{dsn: uri, catalogName: database, schemaName: createTempSchema(uri)} defer dropTempSchema(uri, q.schemaName) suite.Run(t, &validation.DatabaseTests{Quirks: q}) suite.Run(t, &validation.ConnectionTests{Quirks: q}) diff --git a/go/adbc/driver/snowflake/statement.go b/go/adbc/driver/snowflake/statement.go index 481e7f7cec..ddb81e5000 100644 --- a/go/adbc/driver/snowflake/statement.go +++ b/go/adbc/driver/snowflake/statement.go @@ -71,6 +71,35 @@ func (st *statement) Close() error { return nil } +func (st *statement) GetOption(key string) (string, error) { + return "", adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (st *statement) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (st *statement) GetOptionInt(key string) (int64, error) { + switch key { + case OptionStatementQueueSize: + return int64(st.queueSize), nil + } + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (st *statement) GetOptionDouble(key string) (float64, error) { + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} + // SetOption sets a string option on this statement func (st *statement) SetOption(key string, val string) error { switch key { @@ -97,7 +126,7 @@ func (st *statement) SetOption(key string, val string) error { Code: adbc.StatusInvalidArgument, } } - st.queueSize = sz + return st.SetOptionInt(key, int64(sz)) default: return adbc.Error{ Msg: fmt.Sprintf("invalid statement option %s=%s", key, val), @@ -107,6 +136,38 @@ func (st *statement) SetOption(key string, val string) error { return nil } +func (st *statement) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (st *statement) SetOptionInt(key string, value int64) error { + switch key { + case OptionStatementQueueSize: + if value <= 0 { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Invalid value for statement option '%s': '%d' is not a positive integer", OptionStatementQueueSize, value), + Code: adbc.StatusInvalidArgument, + } + } + st.queueSize = int(value) + return nil + } + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (st *statement) SetOptionDouble(key string, value float64) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + // SetSqlQuery sets the query string to be executed. // // The query can then be executed with any of the Execute methods. diff --git a/go/adbc/validation/validation.go b/go/adbc/validation/validation.go index ffc9e93dc8..8fc08c4c69 100644 --- a/go/adbc/validation/validation.go +++ b/go/adbc/validation/validation.go @@ -46,6 +46,12 @@ type DriverQuirks interface { BindParameter(index int) string // Whether two statements can be used at the same time on a single connection SupportsConcurrentStatements() bool + // Whether current catalog/schema are supported + SupportsCurrentCatalogSchema() bool + // Whether GetSetOptions is supported + SupportsGetSetOptions() bool + // Whether AdbcStatementExecuteSchema should work + SupportsExecuteSchema() bool // Whether AdbcStatementExecutePartitions should work SupportsPartitionedData() bool // Whether transactions are supported (Commit/Rollback on connection) @@ -65,6 +71,7 @@ type DriverQuirks interface { // have the driver drop a table with the correct SQL syntax DropTable(adbc.Connection, string) error + Catalog() string DBSchema() string Alloc() memory.Allocator @@ -115,6 +122,30 @@ func (c *ConnectionTests) TearDownTest() { c.DB = nil } +func (c *ConnectionTests) TestGetSetOptions() { + cnxn, err := c.DB.Open(context.Background()) + c.NoError(err) + c.NotNil(cnxn) + + stmt, err := cnxn.NewStatement() + c.NoError(err) + c.NotNil(stmt) + + expected := c.Quirks.SupportsGetSetOptions() + + _, ok := c.DB.(adbc.GetSetOptions) + c.Equal(expected, ok) + + _, ok = cnxn.(adbc.GetSetOptions) + c.Equal(expected, ok) + + _, ok = stmt.(adbc.GetSetOptions) + c.Equal(expected, ok) + + c.NoError(stmt.Close()) + c.NoError(cnxn.Close()) +} + func (c *ConnectionTests) TestNewConn() { cnxn, err := c.DB.Open(context.Background()) c.NoError(err) @@ -152,6 +183,12 @@ func (c *ConnectionTests) TestAutocommitDefault() { cnxn, _ := c.DB.Open(ctx) defer cnxn.Close() + if getset, ok := cnxn.(adbc.GetSetOptions); ok { + value, err := getset.GetOption(adbc.OptionKeyAutoCommit) + c.NoError(err) + c.Equal(adbc.OptionValueEnabled, value) + } + expectedCode := adbc.StatusInvalidState var adbcError adbc.Error err := cnxn.Commit(ctx) @@ -188,8 +225,60 @@ func (c *ConnectionTests) TestAutocommitToggle() { c.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueEnabled)) c.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueDisabled)) + if getset, ok := cnxn.(adbc.GetSetOptions); ok { + value, err := getset.GetOption(adbc.OptionKeyAutoCommit) + c.NoError(err) + c.Equal(adbc.OptionValueDisabled, value) + } + // it is ok to disable autocommit when it isn't enabled c.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueDisabled)) + + if getset, ok := cnxn.(adbc.GetSetOptions); ok { + value, err := getset.GetOption(adbc.OptionKeyAutoCommit) + c.NoError(err) + c.Equal(adbc.OptionValueDisabled, value) + } +} + +func (c *ConnectionTests) TestMetadataCurrentCatalog() { + ctx := context.Background() + cnxn, _ := c.DB.Open(ctx) + defer cnxn.Close() + getset, ok := cnxn.(adbc.GetSetOptions) + + if !c.Quirks.SupportsGetSetOptions() { + c.False(ok) + return + } + c.True(ok) + value, err := getset.GetOption(adbc.OptionKeyCurrentCatalog) + if c.Quirks.SupportsCurrentCatalogSchema() { + c.NoError(err) + c.Equal(c.Quirks.Catalog(), value) + } else { + c.Error(err) + } +} + +func (c *ConnectionTests) TestMetadataCurrentDbSchema() { + ctx := context.Background() + cnxn, _ := c.DB.Open(ctx) + defer cnxn.Close() + getset, ok := cnxn.(adbc.GetSetOptions) + + if !c.Quirks.SupportsGetSetOptions() { + c.False(ok) + return + } + c.True(ok) + value, err := getset.GetOption(adbc.OptionKeyCurrentDbSchema) + if c.Quirks.SupportsCurrentCatalogSchema() { + c.NoError(err) + c.Equal(c.Quirks.DBSchema(), value) + } else { + c.Error(err) + } } func (c *ConnectionTests) TestMetadataGetInfo() { @@ -201,6 +290,7 @@ func (c *ConnectionTests) TestMetadataGetInfo() { adbc.InfoDriverName, adbc.InfoDriverVersion, adbc.InfoDriverArrowVersion, + adbc.InfoDriverADBCVersion, adbc.InfoVendorName, adbc.InfoVendorVersion, adbc.InfoVendorArrowVersion, @@ -219,14 +309,28 @@ func (c *ConnectionTests) TestMetadataGetInfo() { valUnion := rec.Column(1).(*array.DenseUnion) for i := 0; i < int(rec.NumRows()); i++ { code := codeCol.Value(i) - child := valUnion.Field(valUnion.ChildID(i)) - if child.IsNull(i) { + offset := int(valUnion.ValueOffset(i)) + valUnion.GetOneForMarshal(i) + if child.IsNull(offset) { exp := c.Quirks.GetMetadata(adbc.InfoCode(code)) c.Nilf(exp, "got nil for info %s, expected: %s", adbc.InfoCode(code), exp) } else { - // currently we only define utf8 values for metadata - c.Equal(c.Quirks.GetMetadata(adbc.InfoCode(code)), child.(*array.String).Value(i), adbc.InfoCode(code).String()) + expected := c.Quirks.GetMetadata(adbc.InfoCode(code)) + var actual interface{} + + switch valUnion.ChildID(i) { + case 0: + // String + actual = child.(*array.String).Value(offset) + case 2: + // int64 + actual = child.(*array.Int64).Value(offset) + default: + c.FailNow("Unknown union type code", valUnion.ChildID(i)) + } + + c.Equal(expected, actual, adbc.InfoCode(code).String()) } } } @@ -407,6 +511,49 @@ func (s *StatementTests) TestNewStatement() { s.Equal(adbc.StatusInvalidState, adbcError.Code) } +func (s *StatementTests) TestSqlExecuteSchema() { + if !s.Quirks.SupportsExecuteSchema() { + s.T().SkipNow() + } + + stmt, err := s.Cnxn.NewStatement() + s.Require().NoError(err) + defer stmt.Close() + + es, ok := stmt.(adbc.StatementExecuteSchema) + s.Require().True(ok, "%#v does not support ExecuteSchema", es) + + s.Run("no query", func() { + var adbcErr adbc.Error + + schema, err := es.ExecuteSchema(s.ctx) + s.ErrorAs(err, &adbcErr) + s.Equal(adbc.StatusInvalidState, adbcErr.Code) + s.Nil(schema) + }) + + s.Run("query", func() { + s.NoError(stmt.SetSqlQuery("SELECT 1, 'string'")) + + schema, err := es.ExecuteSchema(s.ctx) + s.NoError(err) + s.Equal(2, len(schema.Fields())) + s.True(schema.Field(0).Type.ID() == arrow.INT32 || schema.Field(0).Type.ID() == arrow.INT64) + s.Equal(arrow.STRING, schema.Field(1).Type.ID()) + }) + + s.Run("prepared", func() { + s.NoError(stmt.SetSqlQuery("SELECT 1, 'string'")) + s.NoError(stmt.Prepare(s.ctx)) + + schema, err := es.ExecuteSchema(s.ctx) + s.NoError(err) + s.Equal(2, len(schema.Fields())) + s.True(schema.Field(0).Type.ID() == arrow.INT32 || schema.Field(0).Type.ID() == arrow.INT64) + s.Equal(arrow.STRING, schema.Field(1).Type.ID()) + }) +} + func (s *StatementTests) TestSqlPartitionedInts() { stmt, err := s.Cnxn.NewStatement() s.Require().NoError(err)