From 742af7128586021b97cf216216ed977f80f6ad88 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 23 May 2023 15:42:52 -0400
Subject: [PATCH] feat(go/adbc): implement 1.1.0 features
- ADBC_INFO_DRIVER_ADBC_VERSION
- StatementExecuteSchema (#318)
- ADBC_CONNECTION_OPTION_CURRENT_{CATALOG, DB_SCHEMA} (#319)
- Get/SetOption
- error_details (#755)
---
.github/workflows/native-unix.yml | 1 +
go/adbc/adbc.go | 32 +-
go/adbc/driver/flightsql/flightsql_adbc.go | 381 +++++++++++++++---
.../flightsql/flightsql_adbc_server_test.go | 260 ++++++++++++
.../driver/flightsql/flightsql_adbc_test.go | 14 +-
.../driver/flightsql/flightsql_statement.go | 181 +++++++--
go/adbc/driver/flightsql/record_reader.go | 6 +-
go/adbc/driver/flightsql/utils.go | 6 +-
go/adbc/driver/snowflake/connection.go | 115 +++++-
go/adbc/driver/snowflake/driver.go | 128 ++++++
go/adbc/driver/snowflake/driver_test.go | 23 +-
go/adbc/driver/snowflake/statement.go | 63 ++-
go/adbc/validation/validation.go | 155 ++++++-
13 files changed, 1234 insertions(+), 131 deletions(-)
diff --git a/.github/workflows/native-unix.yml b/.github/workflows/native-unix.yml
index 32d24a07ef..09adb7c321 100644
--- a/.github/workflows/native-unix.yml
+++ b/.github/workflows/native-unix.yml
@@ -315,6 +315,7 @@ jobs:
popd
- name: Go Test
env:
+ SNOWFLAKE_DATABASE: ADBC_TESTING
SNOWFLAKE_URI: ${{ secrets.SNOWFLAKE_URI }}
run: |
./ci/scripts/go_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local"
diff --git a/go/adbc/adbc.go b/go/adbc/adbc.go
index 99a4f81b75..c936fe5334 100644
--- a/go/adbc/adbc.go
+++ b/go/adbc/adbc.go
@@ -58,13 +58,15 @@ type Error struct {
// SqlState is a SQLSTATE error code, if provided, as defined
// by the SQL:2003 standard. If not set, it will be "\0\0\0\0\0"
SqlState [5]byte
- // Details is an array of additional driver-specific binary error details.
+ // Details is an array of additional driver-specific error details.
//
// This allows drivers to return custom, structured error information (for
// example, JSON or Protocol Buffers) that can be optionally parsed by
// clients, beyond the standard Error fields, without having to encode it in
- // the error message. The encoding of the data is driver-defined.
- Details [][]byte
+ // the error message. The encoding of the data is driver-defined. It is
+ // suggested to use proto.Message for Protocol Buffers and error for wrapped
+ // errors.
+ Details []interface{}
}
func (e Error) Error() string {
@@ -621,23 +623,6 @@ type Statement interface {
ExecutePartitions(context.Context) (*arrow.Schema, Partitions, int64, error)
}
-// Cancellable is a Connection or Statement that also supports Cancel.
-//
-// Since ADBC API revision 1.1.0.
-type Cancellable interface {
- // Cancel stops execution of an in-progress query.
- //
- // This can be called during ExecuteQuery, GetObjects, or other
- // methods that produce result sets, or while consuming a
- // RecordReader returned from such. Calling this function should
- // make the other functions return an error with a StatusCancelled
- // code.
- //
- // This must always be thread-safe (other operations are not
- // necessarily thread-safe).
- Cancel() error
-}
-
// ConnectionGetStatistics is a Connection that supports getting
// statistics on data in the database.
//
@@ -719,7 +704,10 @@ type StatementExecuteSchema interface {
ExecuteSchema(context.Context) (*arrow.Schema, error)
}
-// GetSetOptions is a PostInitOptions that also supports getting and setting property values of different types.
+// GetSetOptions is a PostInitOptions that also supports getting and setting option values of different types.
+//
+// GetOption functions should return an error with StatusNotFound for unsupported options.
+// SetOption functions should return an error with StatusNotImplemented for unsupported options.
//
// Since ADBC API revision 1.1.0.
type GetSetOptions interface {
@@ -728,7 +716,7 @@ type GetSetOptions interface {
SetOptionBytes(key string, value []byte) error
SetOptionInt(key string, value int64) error
SetOptionDouble(key string, value float64) error
- GetOption(key, value string) (string, error)
+ GetOption(key string) (string, error)
GetOptionBytes(key string) ([]byte, error)
GetOptionInt(key string) (int64, error)
GetOptionDouble(key string) (float64, error)
diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go b/go/adbc/driver/flightsql/flightsql_adbc.go
index e038354cc5..00d123b322 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc.go
@@ -36,7 +36,6 @@ import (
"context"
"crypto/tls"
"crypto/x509"
- "errors"
"fmt"
"io"
"math"
@@ -119,20 +118,13 @@ func init() {
adbc.InfoDriverName,
adbc.InfoDriverVersion,
adbc.InfoDriverArrowVersion,
+ adbc.InfoDriverADBCVersion,
adbc.InfoVendorName,
adbc.InfoVendorVersion,
adbc.InfoVendorArrowVersion,
}
}
-func getTimeoutOptionValue(v string) (time.Duration, error) {
- timeout, err := strconv.ParseFloat(v, 64)
- if math.IsNaN(timeout) || math.IsInf(timeout, 0) || timeout < 0 {
- return 0, errors.New("timeout must be positive and finite")
- }
- return time.Duration(timeout * float64(time.Second)), err
-}
-
type Driver struct {
Alloc memory.Allocator
}
@@ -164,6 +156,8 @@ func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) {
db.dialOpts.block = false
db.dialOpts.maxMsgSize = 16 * 1024 * 1024
+ db.options = make(map[string]string)
+
return db, db.SetOptions(opts)
}
@@ -192,6 +186,7 @@ type database struct {
timeout timeoutOption
dialOpts dbDialOpts
enableCookies bool
+ options map[string]string
alloc memory.Allocator
}
@@ -199,6 +194,10 @@ type database struct {
func (d *database) SetOptions(cnOptions map[string]string) error {
var tlsConfig tls.Config
+ for k, v := range cnOptions {
+ d.options[k] = v
+ }
+
mtlsCert := cnOptions[OptionMTLSCertChain]
mtlsKey := cnOptions[OptionMTLSPrivateKey]
switch {
@@ -287,33 +286,24 @@ func (d *database) SetOptions(cnOptions map[string]string) error {
var err error
if tv, ok := cnOptions[OptionTimeoutFetch]; ok {
- if d.timeout.fetchTimeout, err = getTimeoutOptionValue(tv); err != nil {
- return adbc.Error{
- Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s",
- OptionTimeoutFetch, tv, err.Error()),
- Code: adbc.StatusInvalidArgument,
- }
+ if err = d.timeout.setTimeoutString(OptionTimeoutFetch, tv); err != nil {
+ return err
}
+ delete(cnOptions, OptionTimeoutFetch)
}
if tv, ok := cnOptions[OptionTimeoutQuery]; ok {
- if d.timeout.queryTimeout, err = getTimeoutOptionValue(tv); err != nil {
- return adbc.Error{
- Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s",
- OptionTimeoutQuery, tv, err.Error()),
- Code: adbc.StatusInvalidArgument,
- }
+ if err = d.timeout.setTimeoutString(OptionTimeoutQuery, tv); err != nil {
+ return err
}
+ delete(cnOptions, OptionTimeoutQuery)
}
if tv, ok := cnOptions[OptionTimeoutUpdate]; ok {
- if d.timeout.updateTimeout, err = getTimeoutOptionValue(tv); err != nil {
- return adbc.Error{
- Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s",
- OptionTimeoutUpdate, tv, err.Error()),
- Code: adbc.StatusInvalidArgument,
- }
+ if err = d.timeout.setTimeoutString(OptionTimeoutUpdate, tv); err != nil {
+ return err
}
+ delete(cnOptions, OptionTimeoutUpdate)
}
if val, ok := cnOptions[OptionWithBlock]; ok {
@@ -369,7 +359,7 @@ func (d *database) SetOptions(cnOptions map[string]string) error {
continue
}
return adbc.Error{
- Msg: fmt.Sprintf("Unknown database option '%s'", key),
+ Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key),
Code: adbc.StatusInvalidArgument,
}
}
@@ -377,6 +367,118 @@ func (d *database) SetOptions(cnOptions map[string]string) error {
return nil
}
+func (d *database) GetOption(key string) (string, error) {
+ switch key {
+ case OptionTimeoutFetch:
+ return d.timeout.fetchTimeout.String(), nil
+ case OptionTimeoutQuery:
+ return d.timeout.queryTimeout.String(), nil
+ case OptionTimeoutUpdate:
+ return d.timeout.updateTimeout.String(), nil
+ }
+ if val, ok := d.options[key]; ok {
+ return val, nil
+ }
+ return "", adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key),
+ Code: adbc.StatusNotFound,
+ }
+}
+func (d *database) GetOptionBytes(key string) ([]byte, error) {
+ return nil, adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key),
+ Code: adbc.StatusNotFound,
+ }
+}
+func (d *database) GetOptionInt(key string) (int64, error) {
+ switch key {
+ case OptionTimeoutFetch:
+ fallthrough
+ case OptionTimeoutQuery:
+ fallthrough
+ case OptionTimeoutUpdate:
+ val, err := d.GetOptionDouble(key)
+ if err != nil {
+ return 0, err
+ }
+ return int64(val), nil
+ }
+
+ return 0, adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key),
+ Code: adbc.StatusNotFound,
+ }
+}
+func (d *database) GetOptionDouble(key string) (float64, error) {
+ switch key {
+ case OptionTimeoutFetch:
+ return d.timeout.fetchTimeout.Seconds(), nil
+ case OptionTimeoutQuery:
+ return d.timeout.queryTimeout.Seconds(), nil
+ case OptionTimeoutUpdate:
+ return d.timeout.updateTimeout.Seconds(), nil
+ }
+
+ return 0, adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key),
+ Code: adbc.StatusNotFound,
+ }
+}
+func (d *database) SetOption(key, value string) error {
+ // We can't change most options post-init
+ switch key {
+ case OptionTimeoutFetch:
+ fallthrough
+ case OptionTimeoutQuery:
+ fallthrough
+ case OptionTimeoutUpdate:
+ return d.timeout.setTimeoutString(key, value)
+ }
+ if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) {
+ d.hdrs.Set(strings.TrimPrefix(key, OptionRPCCallHeaderPrefix), value)
+ }
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key),
+ Code: adbc.StatusNotImplemented,
+ }
+}
+func (d *database) SetOptionBytes(key string, value []byte) error {
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key),
+ Code: adbc.StatusNotImplemented,
+ }
+}
+func (d *database) SetOptionInt(key string, value int64) error {
+ switch key {
+ case OptionTimeoutFetch:
+ fallthrough
+ case OptionTimeoutQuery:
+ fallthrough
+ case OptionTimeoutUpdate:
+ return d.timeout.setTimeout(key, float64(value))
+ }
+
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key),
+ Code: adbc.StatusNotImplemented,
+ }
+}
+func (d *database) SetOptionDouble(key string, value float64) error {
+ switch key {
+ case OptionTimeoutFetch:
+ fallthrough
+ case OptionTimeoutQuery:
+ fallthrough
+ case OptionTimeoutUpdate:
+ return d.timeout.setTimeout(key, value)
+ }
+
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key),
+ Code: adbc.StatusNotImplemented,
+ }
+}
+
type timeoutOption struct {
grpc.EmptyCallOption
@@ -388,6 +490,45 @@ type timeoutOption struct {
updateTimeout time.Duration
}
+func (t *timeoutOption) setTimeout(key string, value float64) error {
+ if math.IsNaN(value) || math.IsInf(value, 0) || value < 0 {
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] invalid timeout option value %s = %f: timeouts must be non-negative and finite",
+ key, value),
+ Code: adbc.StatusInvalidArgument,
+ }
+ }
+
+ timeout := time.Duration(value * float64(time.Second))
+
+ switch key {
+ case OptionTimeoutFetch:
+ t.fetchTimeout = timeout
+ case OptionTimeoutQuery:
+ t.queryTimeout = timeout
+ case OptionTimeoutUpdate:
+ t.updateTimeout = timeout
+ default:
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] Unknown timeout option '%s'", key),
+ Code: adbc.StatusNotImplemented,
+ }
+ }
+ return nil
+}
+
+func (t *timeoutOption) setTimeoutString(key string, value string) error {
+ timeout, err := strconv.ParseFloat(value, 64)
+ if err != nil {
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] invalid timeout option value %s = %s: %s",
+ key, value, err.Error()),
+ Code: adbc.StatusInvalidArgument,
+ }
+ }
+ return t.setTimeout(key, timeout)
+}
+
func getTimeout(method string, callOptions []grpc.CallOption) (time.Duration, bool) {
for _, opt := range callOptions {
if to, ok := opt.(timeoutOption); ok {
@@ -729,6 +870,96 @@ func doGet(ctx context.Context, cl *flightsql.Client, endpoint *flight.FlightEnd
return nil, err
}
+func (c *cnxn) GetOption(key string) (string, error) {
+ if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) {
+ name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix)
+ headers := c.hdrs.Get(name)
+ if len(headers) > 0 {
+ return headers[0], nil
+ }
+ return "", adbc.Error{
+ Msg: "[Flight SQL] unknown header",
+ Code: adbc.StatusNotFound,
+ }
+ }
+
+ switch key {
+ case OptionTimeoutFetch:
+ return c.timeouts.fetchTimeout.String(), nil
+ case OptionTimeoutQuery:
+ return c.timeouts.queryTimeout.String(), nil
+ case OptionTimeoutUpdate:
+ return c.timeouts.updateTimeout.String(), nil
+ case adbc.OptionKeyAutoCommit:
+ if c.txn != nil {
+ // No autocommit
+ return adbc.OptionValueDisabled, nil
+ } else {
+ // Autocommit
+ return adbc.OptionValueEnabled, nil
+ }
+ case adbc.OptionKeyCurrentCatalog:
+ return "", adbc.Error{
+ Msg: "[Flight SQL] current catalog not supported",
+ Code: adbc.StatusNotFound,
+ }
+
+ case adbc.OptionKeyCurrentDbSchema:
+ return "", adbc.Error{
+ Msg: "[Flight SQL] current schema not supported",
+ Code: adbc.StatusNotFound,
+ }
+ }
+
+ return "", adbc.Error{
+ Msg: "[Flight SQL] unknown connection option",
+ Code: adbc.StatusNotFound,
+ }
+}
+
+func (c *cnxn) GetOptionBytes(key string) ([]byte, error) {
+ return nil, adbc.Error{
+ Msg: "[Flight SQL] unknown connection option",
+ Code: adbc.StatusNotFound,
+ }
+}
+
+func (c *cnxn) GetOptionInt(key string) (int64, error) {
+ switch key {
+ case OptionTimeoutFetch:
+ fallthrough
+ case OptionTimeoutQuery:
+ fallthrough
+ case OptionTimeoutUpdate:
+ val, err := c.GetOptionDouble(key)
+ if err != nil {
+ return 0, err
+ }
+ return int64(val), nil
+ }
+
+ return 0, adbc.Error{
+ Msg: "[Flight SQL] unknown connection option",
+ Code: adbc.StatusNotFound,
+ }
+}
+
+func (c *cnxn) GetOptionDouble(key string) (float64, error) {
+ switch key {
+ case OptionTimeoutFetch:
+ return c.timeouts.fetchTimeout.Seconds(), nil
+ case OptionTimeoutQuery:
+ return c.timeouts.queryTimeout.Seconds(), nil
+ case OptionTimeoutUpdate:
+ return c.timeouts.updateTimeout.Seconds(), nil
+ }
+
+ return 0.0, adbc.Error{
+ Msg: "[Flight SQL] unknown connection option",
+ Code: adbc.StatusNotFound,
+ }
+}
+
func (c *cnxn) SetOption(key, value string) error {
if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) {
name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix)
@@ -742,39 +973,16 @@ func (c *cnxn) SetOption(key, value string) error {
switch key {
case OptionTimeoutFetch:
- timeout, err := getTimeoutOptionValue(value)
- if err != nil {
- return adbc.Error{
- Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s",
- OptionTimeoutFetch, value, err.Error()),
- Code: adbc.StatusInvalidArgument,
- }
- }
- c.timeouts.fetchTimeout = timeout
+ fallthrough
case OptionTimeoutQuery:
- timeout, err := getTimeoutOptionValue(value)
- if err != nil {
- return adbc.Error{
- Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s",
- OptionTimeoutFetch, value, err.Error()),
- Code: adbc.StatusInvalidArgument,
- }
- }
- c.timeouts.queryTimeout = timeout
+ fallthrough
case OptionTimeoutUpdate:
- timeout, err := getTimeoutOptionValue(value)
- if err != nil {
- return adbc.Error{
- Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s",
- OptionTimeoutFetch, value, err.Error()),
- Code: adbc.StatusInvalidArgument,
- }
- }
- c.timeouts.updateTimeout = timeout
+ return c.timeouts.setTimeoutString(key, value)
case adbc.OptionKeyAutoCommit:
autocommit := true
switch value {
case adbc.OptionValueEnabled:
+ autocommit = true
case adbc.OptionValueDisabled:
autocommit = false
default:
@@ -827,6 +1035,45 @@ func (c *cnxn) SetOption(key, value string) error {
return nil
}
+func (c *cnxn) SetOptionBytes(key string, value []byte) error {
+ return adbc.Error{
+ Msg: "[Flight SQL] unknown connection option",
+ Code: adbc.StatusNotImplemented,
+ }
+}
+
+func (c *cnxn) SetOptionInt(key string, value int64) error {
+ switch key {
+ case OptionTimeoutFetch:
+ fallthrough
+ case OptionTimeoutQuery:
+ fallthrough
+ case OptionTimeoutUpdate:
+ return c.timeouts.setTimeout(key, float64(value))
+ }
+
+ return adbc.Error{
+ Msg: "[Flight SQL] unknown connection option",
+ Code: adbc.StatusNotImplemented,
+ }
+}
+
+func (c *cnxn) SetOptionDouble(key string, value float64) error {
+ switch key {
+ case OptionTimeoutFetch:
+ fallthrough
+ case OptionTimeoutQuery:
+ fallthrough
+ case OptionTimeoutUpdate:
+ return c.timeouts.setTimeout(key, value)
+ }
+
+ return adbc.Error{
+ Msg: "[Flight SQL] unknown connection option",
+ Code: adbc.StatusNotImplemented,
+ }
+}
+
// GetInfo returns metadata about the database/driver.
//
// The result is an Arrow dataset with the following schema:
@@ -853,6 +1100,7 @@ func (c *cnxn) SetOption(key, value string) error {
// codes (the row will be omitted from the result).
func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.RecordReader, error) {
const strValTypeID arrow.UnionTypeCode = 0
+ const intValTypeID arrow.UnionTypeCode = 2
if len(infoCodes) == 0 {
infoCodes = infoSupportedCodes
@@ -864,7 +1112,8 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re
infoNameBldr := bldr.Field(0).(*array.Uint32Builder)
infoValueBldr := bldr.Field(1).(*array.DenseUnionBuilder)
- strInfoBldr := infoValueBldr.Child(0).(*array.StringBuilder)
+ strInfoBldr := infoValueBldr.Child(int(strValTypeID)).(*array.StringBuilder)
+ intInfoBldr := infoValueBldr.Child(int(intValTypeID)).(*array.Int64Builder)
translated := make([]flightsql.SqlInfo, 0, len(infoCodes))
for _, code := range infoCodes {
@@ -886,6 +1135,10 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re
infoNameBldr.Append(uint32(code))
infoValueBldr.Append(strValTypeID)
strInfoBldr.Append(infoDriverArrowVersion)
+ case adbc.InfoDriverADBCVersion:
+ infoNameBldr.Append(uint32(code))
+ infoValueBldr.Append(intValTypeID)
+ intInfoBldr.Append(adbc.AdbcVersion1_1_0)
}
}
@@ -1350,6 +1603,14 @@ func (c *cnxn) execute(ctx context.Context, query string, opts ...grpc.CallOptio
return c.cl.Execute(ctx, query, opts...)
}
+func (c *cnxn) executeSchema(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
+ if c.txn != nil {
+ return c.txn.GetExecuteSchema(ctx, query, opts...)
+ }
+
+ return c.cl.GetExecuteSchema(ctx, query, opts...)
+}
+
func (c *cnxn) executeSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.FlightInfo, error) {
if c.txn != nil {
return c.txn.ExecuteSubstrait(ctx, plan, opts...)
@@ -1358,6 +1619,14 @@ func (c *cnxn) executeSubstrait(ctx context.Context, plan flightsql.SubstraitPla
return c.cl.ExecuteSubstrait(ctx, plan, opts...)
}
+func (c *cnxn) executeSubstraitSchema(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
+ if c.txn != nil {
+ return c.txn.GetExecuteSubstraitSchema(ctx, plan, opts...)
+ }
+
+ return c.cl.GetExecuteSubstraitSchema(ctx, plan, opts...)
+}
+
func (c *cnxn) executeUpdate(ctx context.Context, query string, opts ...grpc.CallOption) (n int64, err error) {
if c.txn != nil {
return c.txn.ExecuteUpdate(ctx, query, opts...)
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index 9d959ac4c6..50f1d9b1f0 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -42,6 +42,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
+ "google.golang.org/protobuf/types/known/wrapperspb"
)
// ---- Common Infra --------------------
@@ -95,6 +96,14 @@ func TestAuthn(t *testing.T) {
suite.Run(t, &AuthnTests{})
}
+func TestErrorDetails(t *testing.T) {
+ suite.Run(t, &ErrorDetailsTests{})
+}
+
+func TestExecuteSchema(t *testing.T) {
+ suite.Run(t, &ExecuteSchemaTests{})
+}
+
func TestTimeout(t *testing.T) {
suite.Run(t, &TimeoutTests{})
}
@@ -202,6 +211,196 @@ func (suite *AuthnTests) TestBearerTokenUpdated() {
defer reader.Release()
}
+// ---- Error Details Tests --------------------
+
+type ErrorDetailsTestServer struct {
+ flightsql.BaseServer
+}
+
+func (srv *ErrorDetailsTestServer) GetFlightInfoStatement(ctx context.Context, query flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
+ if query.GetQuery() == "details" {
+ detail := wrapperspb.Int32Value{Value: 42}
+ st, err := status.New(codes.Unknown, "details").WithDetails(&detail)
+ if err != nil {
+ return nil, err
+ }
+ return nil, st.Err()
+ } else if query.GetQuery() == "query" {
+ tkt, err := flightsql.CreateStatementQueryTicket([]byte("fetch"))
+ if err != nil {
+ panic(err)
+ }
+ return &flight.FlightInfo{Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: tkt}}}}, nil
+ }
+ return nil, status.Errorf(codes.Unimplemented, "GetSchemaStatement not implemented")
+}
+
+func (ts *ErrorDetailsTestServer) DoGetStatement(ctx context.Context, tkt flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) {
+ sc := arrow.NewSchema([]arrow.Field{}, nil)
+ detail := wrapperspb.Int32Value{Value: 42}
+ st, err := status.New(codes.Unknown, "details").WithDetails(&detail)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ ch := make(chan flight.StreamChunk)
+ go func() {
+ defer close(ch)
+ ch <- flight.StreamChunk{
+ Data: nil,
+ Desc: nil,
+ Err: st.Err(),
+ }
+ }()
+ return sc, ch, nil
+}
+
+type ErrorDetailsTests struct {
+ ServerBasedTests
+}
+
+func (suite *ErrorDetailsTests) SetupSuite() {
+ srv := ErrorDetailsTestServer{}
+ srv.Alloc = memory.DefaultAllocator
+ suite.DoSetupSuite(&srv, nil, nil)
+}
+
+func (ts *ErrorDetailsTests) TestGetFlightInfo() {
+ stmt, err := ts.cnxn.NewStatement()
+ ts.NoError(err)
+ defer stmt.Close()
+
+ ts.NoError(stmt.SetSqlQuery("details"))
+
+ _, _, err = stmt.ExecuteQuery(context.Background())
+ var adbcErr adbc.Error
+ ts.ErrorAs(err, &adbcErr)
+
+ ts.Equal(1, len(adbcErr.Details))
+
+ message, ok := adbcErr.Details[0].(*wrapperspb.Int32Value)
+ ts.True(ok, "Got message: %#v", message)
+ ts.Equal(int32(42), message.Value)
+}
+
+func (ts *ErrorDetailsTests) TestDoGet() {
+ stmt, err := ts.cnxn.NewStatement()
+ ts.NoError(err)
+ defer stmt.Close()
+
+ ts.NoError(stmt.SetSqlQuery("query"))
+
+ reader, _, err := stmt.ExecuteQuery(context.Background())
+ ts.NoError(err)
+
+ defer reader.Release()
+
+ for reader.Next() {
+ }
+ err = reader.Err()
+
+ ts.Error(err)
+
+ var adbcErr adbc.Error
+ ts.ErrorAs(err, &adbcErr, "Error was: %#v", err)
+
+ ts.Equal(1, len(adbcErr.Details))
+
+ message, ok := adbcErr.Details[0].(*wrapperspb.Int32Value)
+ ts.True(ok, "Got message: %#v", message)
+ ts.Equal(int32(42), message.Value)
+}
+
+// ---- ExecuteSchema Tests --------------------
+
+type ExecuteSchemaTestServer struct {
+ flightsql.BaseServer
+}
+
+func (srv *ExecuteSchemaTestServer) GetSchemaStatement(ctx context.Context, query flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.SchemaResult, error) {
+ if query.GetQuery() == "sample query" {
+ return &flight.SchemaResult{
+ Schema: flight.SerializeSchema(arrow.NewSchema([]arrow.Field{
+ {Name: "ints", Type: arrow.PrimitiveTypes.Int32},
+ }, nil), srv.Alloc),
+ }, nil
+ }
+ return nil, status.Errorf(codes.Unimplemented, "GetSchemaStatement not implemented")
+}
+
+func (srv *ExecuteSchemaTestServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (res flightsql.ActionCreatePreparedStatementResult, err error) {
+ if req.GetQuery() == "sample query" {
+ return flightsql.ActionCreatePreparedStatementResult{
+ DatasetSchema: arrow.NewSchema([]arrow.Field{
+ {Name: "ints", Type: arrow.PrimitiveTypes.Int32},
+ }, nil),
+ }, nil
+ }
+ return flightsql.ActionCreatePreparedStatementResult{}, status.Error(codes.Unimplemented, "CreatePreparedStatement not implemented")
+}
+
+type ExecuteSchemaTests struct {
+ ServerBasedTests
+}
+
+func (suite *ExecuteSchemaTests) SetupSuite() {
+ srv := ExecuteSchemaTestServer{}
+ srv.Alloc = memory.DefaultAllocator
+ suite.DoSetupSuite(&srv, nil, nil)
+}
+
+func (ts *ExecuteSchemaTests) TestNoQuery() {
+ stmt, err := ts.cnxn.NewStatement()
+ ts.NoError(err)
+ defer stmt.Close()
+
+ es := stmt.(adbc.StatementExecuteSchema)
+ _, err = es.ExecuteSchema(context.Background())
+
+ var adbcErr adbc.Error
+ ts.ErrorAs(err, &adbcErr)
+ ts.Equal(adbc.StatusInvalidState, adbcErr.Code, adbcErr.Error())
+}
+
+func (ts *ExecuteSchemaTests) TestPreparedQuery() {
+ stmt, err := ts.cnxn.NewStatement()
+ ts.NoError(err)
+ defer stmt.Close()
+
+ ts.NoError(stmt.SetSqlQuery("sample query"))
+ ts.NoError(stmt.Prepare(context.Background()))
+
+ es := stmt.(adbc.StatementExecuteSchema)
+ schema, err := es.ExecuteSchema(context.Background())
+ ts.NoError(err)
+ ts.NotNil(schema)
+
+ expectedSchema := arrow.NewSchema([]arrow.Field{
+ {Name: "ints", Type: arrow.PrimitiveTypes.Int32},
+ }, nil)
+
+ ts.True(expectedSchema.Equal(schema), schema.String())
+}
+
+func (ts *ExecuteSchemaTests) TestQuery() {
+ stmt, err := ts.cnxn.NewStatement()
+ ts.NoError(err)
+ defer stmt.Close()
+
+ ts.NoError(stmt.SetSqlQuery("sample query"))
+
+ es := stmt.(adbc.StatementExecuteSchema)
+ schema, err := es.ExecuteSchema(context.Background())
+ ts.NoError(err)
+ ts.NotNil(schema)
+
+ expectedSchema := arrow.NewSchema([]arrow.Field{
+ {Name: "ints", Type: arrow.PrimitiveTypes.Int32},
+ }, nil)
+
+ ts.True(expectedSchema.Equal(schema), schema.String())
+}
+
// ---- Timeout Tests --------------------
type TimeoutTestServer struct {
@@ -321,6 +520,67 @@ func (ts *TimeoutTests) TestRemoveTimeout() {
}
}
+func (ts *TimeoutTests) TestGetSet() {
+ keys := []string{
+ "adbc.flight.sql.rpc.timeout_seconds.fetch",
+ "adbc.flight.sql.rpc.timeout_seconds.query",
+ "adbc.flight.sql.rpc.timeout_seconds.update",
+ }
+ stmt, err := ts.cnxn.NewStatement()
+ ts.Require().NoError(err)
+ defer stmt.Close()
+
+ for _, v := range []interface{}{ts.db, ts.cnxn, stmt} {
+ getset := v.(adbc.GetSetOptions)
+
+ for _, k := range keys {
+ strval, err := getset.GetOption(k)
+ ts.NoError(err)
+ ts.Equal("0s", strval)
+
+ intval, err := getset.GetOptionInt(k)
+ ts.NoError(err)
+ ts.Equal(int64(0), intval)
+
+ floatval, err := getset.GetOptionDouble(k)
+ ts.NoError(err)
+ ts.Equal(0.0, floatval)
+
+ err = getset.SetOptionInt(k, 1)
+ ts.NoError(err)
+
+ strval, err = getset.GetOption(k)
+ ts.NoError(err)
+ ts.Equal("1s", strval)
+
+ intval, err = getset.GetOptionInt(k)
+ ts.NoError(err)
+ ts.Equal(int64(1), intval)
+
+ floatval, err = getset.GetOptionDouble(k)
+ ts.NoError(err)
+ ts.Equal(1.0, floatval)
+
+ err = getset.SetOptionDouble(k, 0.1)
+ ts.NoError(err)
+
+ strval, err = getset.GetOption(k)
+ ts.NoError(err)
+ ts.Equal("100ms", strval)
+
+ intval, err = getset.GetOptionInt(k)
+ ts.NoError(err)
+ // truncated
+ ts.Equal(int64(0), intval)
+
+ floatval, err = getset.GetOptionDouble(k)
+ ts.NoError(err)
+ ts.Equal(0.1, floatval)
+ }
+ }
+
+}
+
func (ts *TimeoutTests) TestDoActionTimeout() {
ts.NoError(ts.cnxn.(adbc.PostInitOptions).
SetOption("adbc.flight.sql.rpc.timeout_seconds.update", "0.1"))
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go b/go/adbc/driver/flightsql/flightsql_adbc_test.go
index 53dbac2412..381ad0c67c 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go
@@ -229,9 +229,14 @@ func (s *FlightSQLQuirks) DropTable(cnxn adbc.Connection, tblname string) error
return err
}
-func (s *FlightSQLQuirks) Alloc() memory.Allocator { return s.mem }
-func (s *FlightSQLQuirks) BindParameter(_ int) string { return "?" }
-func (s *FlightSQLQuirks) SupportsConcurrentStatements() bool { return true }
+func (s *FlightSQLQuirks) Alloc() memory.Allocator { return s.mem }
+func (s *FlightSQLQuirks) BindParameter(_ int) string { return "?" }
+func (s *FlightSQLQuirks) SupportsConcurrentStatements() bool { return true }
+func (s *FlightSQLQuirks) SupportsCurrentCatalogSchema() bool { return false }
+
+// The driver supports it, but the server we use for testing does not.
+func (s *FlightSQLQuirks) SupportsExecuteSchema() bool { return false }
+func (s *FlightSQLQuirks) SupportsGetSetOptions() bool { return true }
func (s *FlightSQLQuirks) SupportsPartitionedData() bool { return true }
func (s *FlightSQLQuirks) SupportsTransactions() bool { return true }
func (s *FlightSQLQuirks) SupportsGetParameterSchema() bool { return false }
@@ -247,6 +252,8 @@ func (s *FlightSQLQuirks) GetMetadata(code adbc.InfoCode) interface{} {
return "(unknown or development build)"
case adbc.InfoDriverArrowVersion:
return "(unknown or development build)"
+ case adbc.InfoDriverADBCVersion:
+ return adbc.AdbcVersion1_1_0
case adbc.InfoVendorName:
return "db_name"
case adbc.InfoVendorVersion:
@@ -273,6 +280,7 @@ func (s *FlightSQLQuirks) SampleTableSchemaMetadata(tblName string, dt arrow.Dat
}
}
+func (s *FlightSQLQuirks) Catalog() string { return "" }
func (s *FlightSQLQuirks) DBSchema() string { return "" }
func TestADBCFlightSQL(t *testing.T) {
diff --git a/go/adbc/driver/flightsql/flightsql_statement.go b/go/adbc/driver/flightsql/flightsql_statement.go
index c7f074a800..04f46498f4 100644
--- a/go/adbc/driver/flightsql/flightsql_statement.go
+++ b/go/adbc/driver/flightsql/flightsql_statement.go
@@ -73,6 +73,29 @@ func (s *sqlOrSubstrait) execute(ctx context.Context, cnxn *cnxn, opts ...grpc.C
}
}
+func (s *sqlOrSubstrait) executeSchema(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (*arrow.Schema, error) {
+ var (
+ res *flight.SchemaResult
+ err error
+ )
+ if s.sqlQuery != "" {
+ res, err = cnxn.executeSchema(ctx, s.sqlQuery, opts...)
+ } else if s.substraitPlan != nil {
+ res, err = cnxn.executeSubstraitSchema(ctx, flightsql.SubstraitPlan{Plan: s.substraitPlan, Version: s.substraitVersion}, opts...)
+ } else {
+ return nil, adbc.Error{
+ Code: adbc.StatusInvalidState,
+ Msg: "[Flight SQL Statement] cannot call ExecuteQuery without a query or prepared statement",
+ }
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ return flight.DeserializeSchema(res.Schema, cnxn.cl.Alloc)
+}
+
func (s *sqlOrSubstrait) executeUpdate(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (int64, error) {
if s.sqlQuery != "" {
return cnxn.executeUpdate(ctx, s.sqlQuery, opts...)
@@ -138,6 +161,72 @@ func (s *statement) Close() (err error) {
return err
}
+func (s *statement) GetOption(key string) (string, error) {
+ switch key {
+ case OptionStatementSubstraitVersion:
+ return s.query.substraitVersion, nil
+ case OptionTimeoutFetch:
+ return s.timeouts.fetchTimeout.String(), nil
+ case OptionTimeoutQuery:
+ return s.timeouts.queryTimeout.String(), nil
+ case OptionTimeoutUpdate:
+ return s.timeouts.updateTimeout.String(), nil
+ }
+
+ if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) {
+ name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix)
+ values := s.hdrs.Get(name)
+ if len(values) > 0 {
+ return values[0], nil
+ }
+ }
+
+ return "", adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key),
+ Code: adbc.StatusNotFound,
+ }
+}
+func (s *statement) GetOptionBytes(key string) ([]byte, error) {
+ return nil, adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key),
+ Code: adbc.StatusNotFound,
+ }
+}
+func (s *statement) GetOptionInt(key string) (int64, error) {
+ switch key {
+ case OptionTimeoutFetch:
+ fallthrough
+ case OptionTimeoutQuery:
+ fallthrough
+ case OptionTimeoutUpdate:
+ val, err := s.GetOptionDouble(key)
+ if err != nil {
+ return 0, err
+ }
+ return int64(val), nil
+ }
+
+ return 0, adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key),
+ Code: adbc.StatusNotFound,
+ }
+}
+func (s *statement) GetOptionDouble(key string) (float64, error) {
+ switch key {
+ case OptionTimeoutFetch:
+ return s.timeouts.fetchTimeout.Seconds(), nil
+ case OptionTimeoutQuery:
+ return s.timeouts.queryTimeout.Seconds(), nil
+ case OptionTimeoutUpdate:
+ return s.timeouts.updateTimeout.Seconds(), nil
+ }
+
+ return 0, adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key),
+ Code: adbc.StatusNotFound,
+ }
+}
+
// SetOption sets a string option on this statement
func (s *statement) SetOption(key string, val string) error {
if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) {
@@ -152,35 +241,11 @@ func (s *statement) SetOption(key string, val string) error {
switch key {
case OptionTimeoutFetch:
- timeout, err := getTimeoutOptionValue(val)
- if err != nil {
- return adbc.Error{
- Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s",
- OptionTimeoutFetch, val, err.Error()),
- Code: adbc.StatusInvalidArgument,
- }
- }
- s.timeouts.fetchTimeout = timeout
+ fallthrough
case OptionTimeoutQuery:
- timeout, err := getTimeoutOptionValue(val)
- if err != nil {
- return adbc.Error{
- Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s",
- OptionTimeoutFetch, val, err.Error()),
- Code: adbc.StatusInvalidArgument,
- }
- }
- s.timeouts.queryTimeout = timeout
+ fallthrough
case OptionTimeoutUpdate:
- timeout, err := getTimeoutOptionValue(val)
- if err != nil {
- return adbc.Error{
- Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s",
- OptionTimeoutFetch, val, err.Error()),
- Code: adbc.StatusInvalidArgument,
- }
- }
- s.timeouts.updateTimeout = timeout
+ return s.timeouts.setTimeoutString(key, val)
case OptionStatementQueueSize:
var err error
var size int
@@ -189,13 +254,8 @@ func (s *statement) SetOption(key string, val string) error {
Msg: fmt.Sprintf("Invalid value for statement option '%s': '%s' is not a positive integer", OptionStatementQueueSize, val),
Code: adbc.StatusInvalidArgument,
}
- } else if size <= 0 {
- return adbc.Error{
- Msg: fmt.Sprintf("Invalid value for statement option '%s': '%s' is not a positive integer", OptionStatementQueueSize, val),
- Code: adbc.StatusInvalidArgument,
- }
}
- s.queueSize = size
+ return s.SetOptionInt(key, int64(size))
case OptionStatementSubstraitVersion:
s.query.substraitVersion = val
default:
@@ -207,6 +267,43 @@ func (s *statement) SetOption(key string, val string) error {
return nil
}
+func (s *statement) SetOptionBytes(key string, value []byte) error {
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key),
+ Code: adbc.StatusNotImplemented,
+ }
+}
+
+func (s *statement) SetOptionInt(key string, value int64) error {
+ switch key {
+ case OptionStatementQueueSize:
+ if value <= 0 {
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] Invalid value for statement option '%s': '%d' is not a positive integer", OptionStatementQueueSize, value),
+ Code: adbc.StatusInvalidArgument,
+ }
+ }
+ s.queueSize = int(value)
+ return nil
+ }
+ return s.SetOptionDouble(key, float64(value))
+}
+
+func (s *statement) SetOptionDouble(key string, value float64) error {
+ switch key {
+ case OptionTimeoutFetch:
+ fallthrough
+ case OptionTimeoutQuery:
+ fallthrough
+ case OptionTimeoutUpdate:
+ return s.timeouts.setTimeout(key, value)
+ }
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key),
+ Code: adbc.StatusNotImplemented,
+ }
+}
+
// SetSqlQuery sets the query string to be executed.
//
// The query can then be executed with any of the Execute methods.
@@ -422,3 +519,21 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
return sc, out, info.TotalRecords, nil
}
+
+// ExecuteSchema gets the schema of the result set of a query without executing it.
+func (s *statement) ExecuteSchema(ctx context.Context) (schema *arrow.Schema, err error) {
+ ctx = metadata.NewOutgoingContext(ctx, s.hdrs)
+
+ if s.prepared != nil {
+ schema = s.prepared.DatasetSchema()
+ if schema == nil {
+ err = adbc.Error{
+ Msg: "[Flight SQL Statement] Database server did not provide schema for prepared statement",
+ Code: adbc.StatusNotImplemented,
+ }
+ }
+ return
+ }
+
+ return s.query.executeSchema(ctx, s.cnxn, s.timeouts)
+}
diff --git a/go/adbc/driver/flightsql/record_reader.go b/go/adbc/driver/flightsql/record_reader.go
index 409ce58e61..297d35f8dc 100644
--- a/go/adbc/driver/flightsql/record_reader.go
+++ b/go/adbc/driver/flightsql/record_reader.go
@@ -104,7 +104,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql.
rec.Retain()
ch <- rec
}
- return rdr.Err()
+ return adbcFromFlightStatus(rdr.Err())
})
endpoints = endpoints[1:]
@@ -135,7 +135,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql.
rdr, err := doGet(ctx, cl, endpoint, clCache, opts...)
if err != nil {
- return err
+ return adbcFromFlightStatus(err)
}
defer rdr.Release()
@@ -150,7 +150,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql.
chs[endpointIndex] <- rec
}
- return rdr.Err()
+ return adbcFromFlightStatus(rdr.Err())
})
}
diff --git a/go/adbc/driver/flightsql/utils.go b/go/adbc/driver/flightsql/utils.go
index 4f1f165c6b..fb7f323150 100644
--- a/go/adbc/driver/flightsql/utils.go
+++ b/go/adbc/driver/flightsql/utils.go
@@ -29,7 +29,9 @@ func adbcFromFlightStatus(err error) error {
}
var adbcCode adbc.Status
- switch status.Code(err) {
+ // If not a status.Status, will return codes.Unknown
+ grpcStatus := status.Convert(err)
+ switch grpcStatus.Code() {
case codes.OK:
return nil
case codes.Canceled:
@@ -71,5 +73,7 @@ func adbcFromFlightStatus(err error) error {
return adbc.Error{
Msg: err.Error(),
Code: adbcCode,
+ // slice of proto.Message or error
+ Details: grpcStatus.Details(),
}
}
diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go
index 8f965597c5..c321e77a60 100644
--- a/go/adbc/driver/snowflake/connection.go
+++ b/go/adbc/driver/snowflake/connection.go
@@ -22,6 +22,7 @@ import (
"database/sql"
"database/sql/driver"
"fmt"
+ "io"
"strconv"
"strings"
"time"
@@ -95,6 +96,7 @@ type cnxn struct {
// codes (the row will be omitted from the result).
func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.RecordReader, error) {
const strValTypeID arrow.UnionTypeCode = 0
+ const intValTypeID arrow.UnionTypeCode = 2
if len(infoCodes) == 0 {
infoCodes = infoSupportedCodes
@@ -106,7 +108,8 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re
infoNameBldr := bldr.Field(0).(*array.Uint32Builder)
infoValueBldr := bldr.Field(1).(*array.DenseUnionBuilder)
- strInfoBldr := infoValueBldr.Child(0).(*array.StringBuilder)
+ strInfoBldr := infoValueBldr.Child(int(strValTypeID)).(*array.StringBuilder)
+ intInfoBldr := infoValueBldr.Child(int(intValTypeID)).(*array.Int64Builder)
for _, code := range infoCodes {
switch code {
@@ -122,6 +125,10 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re
infoNameBldr.Append(uint32(code))
infoValueBldr.Append(strValTypeID)
strInfoBldr.Append(infoDriverArrowVersion)
+ case adbc.InfoDriverADBCVersion:
+ infoNameBldr.Append(uint32(code))
+ infoValueBldr.Append(intValTypeID)
+ intInfoBldr.Append(adbc.AdbcVersion1_1_0)
case adbc.InfoVendorName:
infoNameBldr.Append(uint32(code))
infoValueBldr.Append(strValTypeID)
@@ -674,6 +681,85 @@ func descToField(name, typ, isnull, primary string, comment sql.NullString) (fie
return
}
+func (c *cnxn) GetOption(key string) (string, error) {
+ switch key {
+ case adbc.OptionKeyAutoCommit:
+ if c.activeTransaction {
+ // No autocommit
+ return adbc.OptionValueDisabled, nil
+ } else {
+ // Autocommit
+ return adbc.OptionValueEnabled, nil
+ }
+ case adbc.OptionKeyCurrentCatalog:
+ return c.getStringQuery("SELECT CURRENT_DATABASE()")
+ case adbc.OptionKeyCurrentDbSchema:
+ return c.getStringQuery("SELECT CURRENT_SCHEMA()")
+ }
+
+ return "", adbc.Error{
+ Msg: "[Snowflake] unknown connection option",
+ Code: adbc.StatusNotFound,
+ }
+}
+
+func (c *cnxn) getStringQuery(query string) (string, error) {
+ result, err := c.cn.QueryContext(context.Background(), query, nil)
+ if err != nil {
+ return "", errToAdbcErr(adbc.StatusInternal, err)
+ }
+ defer result.Close()
+
+ if len(result.Columns()) != 1 {
+ return "", adbc.Error{
+ Msg: fmt.Sprintf("[Snowflake] Internal query returned wrong number of columns: %s", result.Columns()),
+ Code: adbc.StatusInternal,
+ }
+ }
+
+ dest := make([]driver.Value, 1)
+ err = result.Next(dest)
+ if err == io.EOF {
+ return "", adbc.Error{
+ Msg: "[Snowflake] Internal query returned no rows",
+ Code: adbc.StatusInternal,
+ }
+ } else if err != nil {
+ return "", errToAdbcErr(adbc.StatusInternal, err)
+ }
+
+ value, ok := dest[0].(string)
+ if !ok {
+ return "", adbc.Error{
+ Msg: fmt.Sprintf("[Snowflake] Internal query returned wrong type of value: %s", dest[0]),
+ Code: adbc.StatusInternal,
+ }
+ }
+
+ return value, nil
+}
+
+func (c *cnxn) GetOptionBytes(key string) ([]byte, error) {
+ return nil, adbc.Error{
+ Msg: "[Snowflake] unknown connection option",
+ Code: adbc.StatusNotFound,
+ }
+}
+
+func (c *cnxn) GetOptionInt(key string) (int64, error) {
+ return 0, adbc.Error{
+ Msg: "[Snowflake] unknown connection option",
+ Code: adbc.StatusNotFound,
+ }
+}
+
+func (c *cnxn) GetOptionDouble(key string) (float64, error) {
+ return 0.0, adbc.Error{
+ Msg: "[Snowflake] unknown connection option",
+ Code: adbc.StatusNotFound,
+ }
+}
+
func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (*arrow.Schema, error) {
tblParts := make([]string, 0, 3)
if catalog != nil {
@@ -840,6 +926,12 @@ func (c *cnxn) SetOption(key, value string) error {
Code: adbc.StatusInvalidArgument,
}
}
+ case adbc.OptionKeyCurrentCatalog:
+ _, err := c.cn.ExecContext(context.Background(), "USE DATABASE ?", []driver.NamedValue{{Value: value}})
+ return err
+ case adbc.OptionKeyCurrentDbSchema:
+ _, err := c.cn.ExecContext(context.Background(), "USE SCHEMA ?", []driver.NamedValue{{Value: value}})
+ return err
default:
return adbc.Error{
Msg: "[Snowflake] unknown connection option " + key + ": " + value,
@@ -847,3 +939,24 @@ func (c *cnxn) SetOption(key, value string) error {
}
}
}
+
+func (c *cnxn) SetOptionBytes(key string, value []byte) error {
+ return adbc.Error{
+ Msg: "[Snowflake] unknown connection option",
+ Code: adbc.StatusNotImplemented,
+ }
+}
+
+func (c *cnxn) SetOptionInt(key string, value int64) error {
+ return adbc.Error{
+ Msg: "[Snowflake] unknown connection option",
+ Code: adbc.StatusNotImplemented,
+ }
+}
+
+func (c *cnxn) SetOptionDouble(key string, value float64) error {
+ return adbc.Error{
+ Msg: "[Snowflake] unknown connection option",
+ Code: adbc.StatusNotImplemented,
+ }
+}
diff --git a/go/adbc/driver/snowflake/driver.go b/go/adbc/driver/snowflake/driver.go
index c02b58ddec..a00513817b 100644
--- a/go/adbc/driver/snowflake/driver.go
+++ b/go/adbc/driver/snowflake/driver.go
@@ -209,6 +209,105 @@ type database struct {
alloc memory.Allocator
}
+func (d *database) GetOption(key string) (string, error) {
+ switch key {
+ case adbc.OptionKeyUsername:
+ return d.cfg.User, nil
+ case adbc.OptionKeyPassword:
+ return d.cfg.Password, nil
+ case OptionDatabase:
+ return d.cfg.Database, nil
+ case OptionSchema:
+ return d.cfg.Schema, nil
+ case OptionWarehouse:
+ return d.cfg.Warehouse, nil
+ case OptionRole:
+ return d.cfg.Role, nil
+ case OptionRegion:
+ return d.cfg.Region, nil
+ case OptionAccount:
+ return d.cfg.Account, nil
+ case OptionProtocol:
+ return d.cfg.Protocol, nil
+ case OptionHost:
+ return d.cfg.Host, nil
+ case OptionPort:
+ return strconv.Itoa(d.cfg.Port), nil
+ case OptionAuthType:
+ return d.cfg.Authenticator.String(), nil
+ case OptionLoginTimeout:
+ return strconv.FormatFloat(d.cfg.LoginTimeout.Seconds(), 'f', -1, 64), nil
+ case OptionRequestTimeout:
+ return strconv.FormatFloat(d.cfg.RequestTimeout.Seconds(), 'f', -1, 64), nil
+ case OptionJwtExpireTimeout:
+ return strconv.FormatFloat(d.cfg.JWTExpireTimeout.Seconds(), 'f', -1, 64), nil
+ case OptionClientTimeout:
+ return strconv.FormatFloat(d.cfg.ClientTimeout.Seconds(), 'f', -1, 64), nil
+ case OptionApplicationName:
+ return d.cfg.Application, nil
+ case OptionSSLSkipVerify:
+ if d.cfg.InsecureMode {
+ return adbc.OptionValueEnabled, nil
+ }
+ return adbc.OptionValueDisabled, nil
+ case OptionOCSPFailOpenMode:
+ return strconv.FormatUint(uint64(d.cfg.OCSPFailOpen), 10), nil
+ case OptionAuthToken:
+ return d.cfg.Token, nil
+ case OptionAuthOktaUrl:
+ return d.cfg.OktaURL.String(), nil
+ case OptionKeepSessionAlive:
+ if d.cfg.KeepSessionAlive {
+ return adbc.OptionValueEnabled, nil
+ }
+ return adbc.OptionValueDisabled, nil
+ case OptionDisableTelemetry:
+ if d.cfg.DisableTelemetry {
+ return adbc.OptionValueEnabled, nil
+ }
+ return adbc.OptionValueDisabled, nil
+ case OptionClientRequestMFAToken:
+ if d.cfg.ClientRequestMfaToken == gosnowflake.ConfigBoolTrue {
+ return adbc.OptionValueEnabled, nil
+ }
+ return adbc.OptionValueDisabled, nil
+ case OptionClientStoreTempCred:
+ if d.cfg.ClientStoreTemporaryCredential == gosnowflake.ConfigBoolTrue {
+ return adbc.OptionValueEnabled, nil
+ }
+ return adbc.OptionValueDisabled, nil
+ case OptionLogTracing:
+ return d.cfg.Tracing, nil
+ default:
+ val, ok := d.cfg.Params[key]
+ if ok {
+ return *val, nil
+ }
+ }
+ return "", adbc.Error{
+ Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key),
+ Code: adbc.StatusNotFound,
+ }
+}
+func (d *database) GetOptionBytes(key string) ([]byte, error) {
+ return nil, adbc.Error{
+ Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key),
+ Code: adbc.StatusNotFound,
+ }
+}
+func (d *database) GetOptionInt(key string) (int64, error) {
+ return 0, adbc.Error{
+ Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key),
+ Code: adbc.StatusNotFound,
+ }
+}
+func (d *database) GetOptionDouble(key string) (float64, error) {
+ return 0, adbc.Error{
+ Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key),
+ Code: adbc.StatusNotFound,
+ }
+}
+
func (d *database) SetOptions(cnOptions map[string]string) error {
uri, ok := cnOptions[adbc.OptionKeyURI]
if ok {
@@ -421,6 +520,35 @@ func (d *database) SetOptions(cnOptions map[string]string) error {
return nil
}
+func (d *database) SetOption(key string, val string) error {
+ // Can't set options after init
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key),
+ Code: adbc.StatusNotImplemented,
+ }
+}
+
+func (d *database) SetOptionBytes(key string, value []byte) error {
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key),
+ Code: adbc.StatusNotImplemented,
+ }
+}
+
+func (d *database) SetOptionInt(key string, value int64) error {
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key),
+ Code: adbc.StatusNotImplemented,
+ }
+}
+
+func (d *database) SetOptionDouble(key string, value float64) error {
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key),
+ Code: adbc.StatusNotImplemented,
+ }
+}
+
func (d *database) Open(ctx context.Context) (adbc.Connection, error) {
connector := gosnowflake.NewConnector(drv, *d.cfg)
diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go
index 7ac1f27f84..aa72772936 100644
--- a/go/adbc/driver/snowflake/driver_test.go
+++ b/go/adbc/driver/snowflake/driver_test.go
@@ -38,10 +38,11 @@ import (
)
type SnowflakeQuirks struct {
- dsn string
- mem *memory.CheckedAllocator
- connector gosnowflake.Connector
- schemaName string
+ dsn string
+ mem *memory.CheckedAllocator
+ connector gosnowflake.Connector
+ catalogName string
+ schemaName string
}
func (s *SnowflakeQuirks) SetupDriver(t *testing.T) adbc.Driver {
@@ -181,11 +182,15 @@ func (s *SnowflakeQuirks) DropTable(cnxn adbc.Connection, tblname string) error
func (s *SnowflakeQuirks) Alloc() memory.Allocator { return s.mem }
func (s *SnowflakeQuirks) BindParameter(_ int) string { return "?" }
func (s *SnowflakeQuirks) SupportsConcurrentStatements() bool { return true }
+func (s *SnowflakeQuirks) SupportsCurrentCatalogSchema() bool { return true }
+func (s *SnowflakeQuirks) SupportsExecuteSchema() bool { return false }
+func (s *SnowflakeQuirks) SupportsGetSetOptions() bool { return true }
func (s *SnowflakeQuirks) SupportsPartitionedData() bool { return false }
func (s *SnowflakeQuirks) SupportsTransactions() bool { return true }
func (s *SnowflakeQuirks) SupportsGetParameterSchema() bool { return false }
func (s *SnowflakeQuirks) SupportsDynamicParameterBinding() bool { return false }
func (s *SnowflakeQuirks) SupportsBulkIngest() bool { return true }
+func (s *SnowflakeQuirks) Catalog() string { return s.catalogName }
func (s *SnowflakeQuirks) DBSchema() string { return s.schemaName }
func (s *SnowflakeQuirks) GetMetadata(code adbc.InfoCode) interface{} {
switch code {
@@ -197,6 +202,8 @@ func (s *SnowflakeQuirks) GetMetadata(code adbc.InfoCode) interface{} {
return "(unknown or development build)"
case adbc.InfoDriverArrowVersion:
return "(unknown or development build)"
+ case adbc.InfoDriverADBCVersion:
+ return adbc.AdbcVersion1_1_0
case adbc.InfoVendorName:
return "Snowflake"
}
@@ -225,7 +232,7 @@ func createTempSchema(uri string) string {
}
defer db.Close()
- schemaName := "ADBC_TESTING_" + strings.ReplaceAll(uuid.New().String(), "-", "_")
+ schemaName := strings.ToUpper("ADBC_TESTING_" + strings.ReplaceAll(uuid.New().String(), "-", "_"))
_, err = db.Exec(`CREATE SCHEMA ADBC_TESTING.` + schemaName)
if err != nil {
panic(err)
@@ -249,14 +256,16 @@ func dropTempSchema(uri, schema string) {
func TestADBCSnowflake(t *testing.T) {
uri := os.Getenv("SNOWFLAKE_URI")
-
+ database := os.Getenv("SNOWFLAKE_DATABASE")
if uri == "" {
t.Skip("no SNOWFLAKE_URI defined, skip snowflake driver tests")
+ } else if database == "" {
+ t.Skip("no SNOWFLAKE_DATABASE defined, skip snowflake driver tests")
}
// avoid multiple runs clashing by operating in a fresh schema and then
// dropping that schema when we're done.
- q := &SnowflakeQuirks{dsn: uri, schemaName: createTempSchema(uri)}
+ q := &SnowflakeQuirks{dsn: uri, catalogName: database, schemaName: createTempSchema(uri)}
defer dropTempSchema(uri, q.schemaName)
suite.Run(t, &validation.DatabaseTests{Quirks: q})
suite.Run(t, &validation.ConnectionTests{Quirks: q})
diff --git a/go/adbc/driver/snowflake/statement.go b/go/adbc/driver/snowflake/statement.go
index 481e7f7cec..ddb81e5000 100644
--- a/go/adbc/driver/snowflake/statement.go
+++ b/go/adbc/driver/snowflake/statement.go
@@ -71,6 +71,35 @@ func (st *statement) Close() error {
return nil
}
+func (st *statement) GetOption(key string) (string, error) {
+ return "", adbc.Error{
+ Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key),
+ Code: adbc.StatusNotFound,
+ }
+}
+func (st *statement) GetOptionBytes(key string) ([]byte, error) {
+ return nil, adbc.Error{
+ Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key),
+ Code: adbc.StatusNotFound,
+ }
+}
+func (st *statement) GetOptionInt(key string) (int64, error) {
+ switch key {
+ case OptionStatementQueueSize:
+ return int64(st.queueSize), nil
+ }
+ return 0, adbc.Error{
+ Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key),
+ Code: adbc.StatusNotFound,
+ }
+}
+func (st *statement) GetOptionDouble(key string) (float64, error) {
+ return 0, adbc.Error{
+ Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key),
+ Code: adbc.StatusNotFound,
+ }
+}
+
// SetOption sets a string option on this statement
func (st *statement) SetOption(key string, val string) error {
switch key {
@@ -97,7 +126,7 @@ func (st *statement) SetOption(key string, val string) error {
Code: adbc.StatusInvalidArgument,
}
}
- st.queueSize = sz
+ return st.SetOptionInt(key, int64(sz))
default:
return adbc.Error{
Msg: fmt.Sprintf("invalid statement option %s=%s", key, val),
@@ -107,6 +136,38 @@ func (st *statement) SetOption(key string, val string) error {
return nil
}
+func (st *statement) SetOptionBytes(key string, value []byte) error {
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key),
+ Code: adbc.StatusNotImplemented,
+ }
+}
+
+func (st *statement) SetOptionInt(key string, value int64) error {
+ switch key {
+ case OptionStatementQueueSize:
+ if value <= 0 {
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Snowflake] Invalid value for statement option '%s': '%d' is not a positive integer", OptionStatementQueueSize, value),
+ Code: adbc.StatusInvalidArgument,
+ }
+ }
+ st.queueSize = int(value)
+ return nil
+ }
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key),
+ Code: adbc.StatusNotImplemented,
+ }
+}
+
+func (st *statement) SetOptionDouble(key string, value float64) error {
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key),
+ Code: adbc.StatusNotImplemented,
+ }
+}
+
// SetSqlQuery sets the query string to be executed.
//
// The query can then be executed with any of the Execute methods.
diff --git a/go/adbc/validation/validation.go b/go/adbc/validation/validation.go
index ffc9e93dc8..8fc08c4c69 100644
--- a/go/adbc/validation/validation.go
+++ b/go/adbc/validation/validation.go
@@ -46,6 +46,12 @@ type DriverQuirks interface {
BindParameter(index int) string
// Whether two statements can be used at the same time on a single connection
SupportsConcurrentStatements() bool
+ // Whether current catalog/schema are supported
+ SupportsCurrentCatalogSchema() bool
+ // Whether GetSetOptions is supported
+ SupportsGetSetOptions() bool
+ // Whether AdbcStatementExecuteSchema should work
+ SupportsExecuteSchema() bool
// Whether AdbcStatementExecutePartitions should work
SupportsPartitionedData() bool
// Whether transactions are supported (Commit/Rollback on connection)
@@ -65,6 +71,7 @@ type DriverQuirks interface {
// have the driver drop a table with the correct SQL syntax
DropTable(adbc.Connection, string) error
+ Catalog() string
DBSchema() string
Alloc() memory.Allocator
@@ -115,6 +122,30 @@ func (c *ConnectionTests) TearDownTest() {
c.DB = nil
}
+func (c *ConnectionTests) TestGetSetOptions() {
+ cnxn, err := c.DB.Open(context.Background())
+ c.NoError(err)
+ c.NotNil(cnxn)
+
+ stmt, err := cnxn.NewStatement()
+ c.NoError(err)
+ c.NotNil(stmt)
+
+ expected := c.Quirks.SupportsGetSetOptions()
+
+ _, ok := c.DB.(adbc.GetSetOptions)
+ c.Equal(expected, ok)
+
+ _, ok = cnxn.(adbc.GetSetOptions)
+ c.Equal(expected, ok)
+
+ _, ok = stmt.(adbc.GetSetOptions)
+ c.Equal(expected, ok)
+
+ c.NoError(stmt.Close())
+ c.NoError(cnxn.Close())
+}
+
func (c *ConnectionTests) TestNewConn() {
cnxn, err := c.DB.Open(context.Background())
c.NoError(err)
@@ -152,6 +183,12 @@ func (c *ConnectionTests) TestAutocommitDefault() {
cnxn, _ := c.DB.Open(ctx)
defer cnxn.Close()
+ if getset, ok := cnxn.(adbc.GetSetOptions); ok {
+ value, err := getset.GetOption(adbc.OptionKeyAutoCommit)
+ c.NoError(err)
+ c.Equal(adbc.OptionValueEnabled, value)
+ }
+
expectedCode := adbc.StatusInvalidState
var adbcError adbc.Error
err := cnxn.Commit(ctx)
@@ -188,8 +225,60 @@ func (c *ConnectionTests) TestAutocommitToggle() {
c.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueEnabled))
c.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueDisabled))
+ if getset, ok := cnxn.(adbc.GetSetOptions); ok {
+ value, err := getset.GetOption(adbc.OptionKeyAutoCommit)
+ c.NoError(err)
+ c.Equal(adbc.OptionValueDisabled, value)
+ }
+
// it is ok to disable autocommit when it isn't enabled
c.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueDisabled))
+
+ if getset, ok := cnxn.(adbc.GetSetOptions); ok {
+ value, err := getset.GetOption(adbc.OptionKeyAutoCommit)
+ c.NoError(err)
+ c.Equal(adbc.OptionValueDisabled, value)
+ }
+}
+
+func (c *ConnectionTests) TestMetadataCurrentCatalog() {
+ ctx := context.Background()
+ cnxn, _ := c.DB.Open(ctx)
+ defer cnxn.Close()
+ getset, ok := cnxn.(adbc.GetSetOptions)
+
+ if !c.Quirks.SupportsGetSetOptions() {
+ c.False(ok)
+ return
+ }
+ c.True(ok)
+ value, err := getset.GetOption(adbc.OptionKeyCurrentCatalog)
+ if c.Quirks.SupportsCurrentCatalogSchema() {
+ c.NoError(err)
+ c.Equal(c.Quirks.Catalog(), value)
+ } else {
+ c.Error(err)
+ }
+}
+
+func (c *ConnectionTests) TestMetadataCurrentDbSchema() {
+ ctx := context.Background()
+ cnxn, _ := c.DB.Open(ctx)
+ defer cnxn.Close()
+ getset, ok := cnxn.(adbc.GetSetOptions)
+
+ if !c.Quirks.SupportsGetSetOptions() {
+ c.False(ok)
+ return
+ }
+ c.True(ok)
+ value, err := getset.GetOption(adbc.OptionKeyCurrentDbSchema)
+ if c.Quirks.SupportsCurrentCatalogSchema() {
+ c.NoError(err)
+ c.Equal(c.Quirks.DBSchema(), value)
+ } else {
+ c.Error(err)
+ }
}
func (c *ConnectionTests) TestMetadataGetInfo() {
@@ -201,6 +290,7 @@ func (c *ConnectionTests) TestMetadataGetInfo() {
adbc.InfoDriverName,
adbc.InfoDriverVersion,
adbc.InfoDriverArrowVersion,
+ adbc.InfoDriverADBCVersion,
adbc.InfoVendorName,
adbc.InfoVendorVersion,
adbc.InfoVendorArrowVersion,
@@ -219,14 +309,28 @@ func (c *ConnectionTests) TestMetadataGetInfo() {
valUnion := rec.Column(1).(*array.DenseUnion)
for i := 0; i < int(rec.NumRows()); i++ {
code := codeCol.Value(i)
-
child := valUnion.Field(valUnion.ChildID(i))
- if child.IsNull(i) {
+ offset := int(valUnion.ValueOffset(i))
+ valUnion.GetOneForMarshal(i)
+ if child.IsNull(offset) {
exp := c.Quirks.GetMetadata(adbc.InfoCode(code))
c.Nilf(exp, "got nil for info %s, expected: %s", adbc.InfoCode(code), exp)
} else {
- // currently we only define utf8 values for metadata
- c.Equal(c.Quirks.GetMetadata(adbc.InfoCode(code)), child.(*array.String).Value(i), adbc.InfoCode(code).String())
+ expected := c.Quirks.GetMetadata(adbc.InfoCode(code))
+ var actual interface{}
+
+ switch valUnion.ChildID(i) {
+ case 0:
+ // String
+ actual = child.(*array.String).Value(offset)
+ case 2:
+ // int64
+ actual = child.(*array.Int64).Value(offset)
+ default:
+ c.FailNow("Unknown union type code", valUnion.ChildID(i))
+ }
+
+ c.Equal(expected, actual, adbc.InfoCode(code).String())
}
}
}
@@ -407,6 +511,49 @@ func (s *StatementTests) TestNewStatement() {
s.Equal(adbc.StatusInvalidState, adbcError.Code)
}
+func (s *StatementTests) TestSqlExecuteSchema() {
+ if !s.Quirks.SupportsExecuteSchema() {
+ s.T().SkipNow()
+ }
+
+ stmt, err := s.Cnxn.NewStatement()
+ s.Require().NoError(err)
+ defer stmt.Close()
+
+ es, ok := stmt.(adbc.StatementExecuteSchema)
+ s.Require().True(ok, "%#v does not support ExecuteSchema", es)
+
+ s.Run("no query", func() {
+ var adbcErr adbc.Error
+
+ schema, err := es.ExecuteSchema(s.ctx)
+ s.ErrorAs(err, &adbcErr)
+ s.Equal(adbc.StatusInvalidState, adbcErr.Code)
+ s.Nil(schema)
+ })
+
+ s.Run("query", func() {
+ s.NoError(stmt.SetSqlQuery("SELECT 1, 'string'"))
+
+ schema, err := es.ExecuteSchema(s.ctx)
+ s.NoError(err)
+ s.Equal(2, len(schema.Fields()))
+ s.True(schema.Field(0).Type.ID() == arrow.INT32 || schema.Field(0).Type.ID() == arrow.INT64)
+ s.Equal(arrow.STRING, schema.Field(1).Type.ID())
+ })
+
+ s.Run("prepared", func() {
+ s.NoError(stmt.SetSqlQuery("SELECT 1, 'string'"))
+ s.NoError(stmt.Prepare(s.ctx))
+
+ schema, err := es.ExecuteSchema(s.ctx)
+ s.NoError(err)
+ s.Equal(2, len(schema.Fields()))
+ s.True(schema.Field(0).Type.ID() == arrow.INT32 || schema.Field(0).Type.ID() == arrow.INT64)
+ s.Equal(arrow.STRING, schema.Field(1).Type.ID())
+ })
+}
+
func (s *StatementTests) TestSqlPartitionedInts() {
stmt, err := s.Cnxn.NewStatement()
s.Require().NoError(err)