diff --git a/go/adbc/adbc.go b/go/adbc/adbc.go index f5514626ad..6968faacf5 100644 --- a/go/adbc/adbc.go +++ b/go/adbc/adbc.go @@ -355,6 +355,17 @@ const ( InfoDriverADBCVersion InfoCode = 103 // DriverADBCVersion ) +type InfoValueTypeCode = arrow.UnionTypeCode + +const ( + InfoValueStringType InfoValueTypeCode = 0 + InfoValueBooleanType InfoValueTypeCode = 1 + InfoValueInt64Type InfoValueTypeCode = 2 + InfoValueInt32BitmaskType InfoValueTypeCode = 3 + InfoValueStringListType InfoValueTypeCode = 4 + InfoValueInt32ToInt32ListMapType InfoValueTypeCode = 5 +) + type ObjectDepth int const ( diff --git a/go/adbc/driver/driverbase/driver.go b/go/adbc/driver/driverbase/driver.go deleted file mode 100644 index e4cfb99602..0000000000 --- a/go/adbc/driver/driverbase/driver.go +++ /dev/null @@ -1,66 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// Package driverbase provides a framework for implementing ADBC drivers in -// Go. It intends to reduce boilerplate for common functionality and managing -// state transitions. 
-package driverbase - -import ( - "github.com/apache/arrow-adbc/go/adbc" - "github.com/apache/arrow/go/v16/arrow/memory" -) - -// DriverImpl is an interface that drivers implement to provide -// vendor-specific functionality. -type DriverImpl interface { - Base() *DriverImplBase - NewDatabase(opts map[string]string) (adbc.Database, error) -} - -// DriverImplBase is a struct that provides default implementations of the -// DriverImpl interface. It is meant to be used as a composite struct for a -// driver's DriverImpl implementation. -type DriverImplBase struct { - Alloc memory.Allocator - ErrorHelper ErrorHelper -} - -func NewDriverImplBase(name string, alloc memory.Allocator) DriverImplBase { - if alloc == nil { - alloc = memory.DefaultAllocator - } - return DriverImplBase{Alloc: alloc, ErrorHelper: ErrorHelper{DriverName: name}} -} - -func (base *DriverImplBase) Base() *DriverImplBase { - return base -} - -// driver is the actual implementation of adbc.Driver. -type driver struct { - impl DriverImpl -} - -// NewDriver wraps a DriverImpl to create an adbc.Driver. 
-func NewDriver(impl DriverImpl) adbc.Driver { - return &driver{impl} -} - -func (drv *driver) NewDatabase(opts map[string]string) (adbc.Database, error) { - return drv.impl.NewDatabase(opts) -} diff --git a/go/adbc/driver/flightsql/flightsql_connection.go b/go/adbc/driver/flightsql/flightsql_connection.go index e71ac308df..83807856ec 100644 --- a/go/adbc/driver/flightsql/flightsql_connection.go +++ b/go/adbc/driver/flightsql/flightsql_connection.go @@ -28,6 +28,7 @@ import ( "github.com/apache/arrow-adbc/go/adbc" "github.com/apache/arrow-adbc/go/adbc/driver/internal" + "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" "github.com/apache/arrow/go/v16/arrow" "github.com/apache/arrow/go/v16/arrow/array" "github.com/apache/arrow/go/v16/arrow/flight" @@ -43,7 +44,9 @@ import ( "google.golang.org/protobuf/proto" ) -type cnxn struct { +type connectionImpl struct { + driverbase.ConnectionImplBase + cl *flightsql.Client db *databaseImpl @@ -54,6 +57,82 @@ type cnxn struct { supportInfo support } +// GetCurrentCatalog implements driverbase.CurrentNamespacer. +func (c *connectionImpl) GetCurrentCatalog() (string, error) { + options, err := c.getSessionOptions(context.Background()) + if err != nil { + return "", err + } + if catalog, ok := options["catalog"]; ok { + if val, ok := catalog.(string); ok { + return val, nil + } + return "", c.Base().ErrorHelper.Errorf(adbc.StatusInternal, "server returned non-string catalog %#v", catalog) + } + return "", c.Base().ErrorHelper.Errorf(adbc.StatusNotFound, "current catalog not supported") +} + +// GetCurrentDbSchema implements driverbase.CurrentNamespacer. 
+func (c *connectionImpl) GetCurrentDbSchema() (string, error) { + options, err := c.getSessionOptions(context.Background()) + if err != nil { + return "", err + } + if schema, ok := options["schema"]; ok { + if val, ok := schema.(string); ok { + return val, nil + } + return "", c.Base().ErrorHelper.Errorf(adbc.StatusInternal, "server returned non-string schema %#v", schema) + } + return "", c.Base().ErrorHelper.Errorf(adbc.StatusNotFound, "current schema not supported") +} + +// SetCurrentCatalog implements driverbase.CurrentNamespacer. +func (c *connectionImpl) SetCurrentCatalog(value string) error { + return c.setSessionOptions(context.Background(), "catalog", value) +} + +// SetCurrentDbSchema implements driverbase.CurrentNamespacer. +func (c *connectionImpl) SetCurrentDbSchema(value string) error { + return c.setSessionOptions(context.Background(), "schema", value) +} + +func (c *connectionImpl) SetAutocommit(enabled bool) error { + if enabled && c.txn == nil { + // no-op don't even error if the server didn't support transactions + return nil + } + + if !c.supportInfo.transactions { + return errNoTransactionSupport + } + + ctx := metadata.NewOutgoingContext(context.Background(), c.hdrs) + var err error + if c.txn != nil { + if err = c.txn.Commit(ctx, c.timeouts); err != nil { + return adbc.Error{ + Msg: "[Flight SQL] failed to update autocommit: " + err.Error(), + Code: adbc.StatusIO, + } + } + } + + if enabled { + c.txn = nil + return nil + } + + if c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts); err != nil { + return adbc.Error{ + Msg: "[Flight SQL] failed to update autocommit: " + err.Error(), + Code: adbc.StatusIO, + } + } + + return nil +} + var adbcToFlightSQLInfo = map[adbc.InfoCode]flightsql.SqlInfo{ adbc.InfoVendorName: flightsql.SqlInfoFlightSqlServerName, adbc.InfoVendorVersion: flightsql.SqlInfoFlightSqlServerVersion, @@ -97,7 +176,7 @@ func doGet(ctx context.Context, cl *flightsql.Client, endpoint *flight.FlightEnd return nil, err } -func (c 
*cnxn) getSessionOptions(ctx context.Context) (map[string]interface{}, error) { +func (c *connectionImpl) getSessionOptions(ctx context.Context) (map[string]interface{}, error) { ctx = metadata.NewOutgoingContext(ctx, c.hdrs) var header, trailer metadata.MD rawOptions, err := c.cl.GetSessionOptions(ctx, &flight.GetSessionOptionsRequest{}, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) @@ -140,7 +219,7 @@ func (c *cnxn) getSessionOptions(ctx context.Context) (map[string]interface{}, e return options, nil } -func (c *cnxn) setSessionOptions(ctx context.Context, key string, val interface{}) error { +func (c *connectionImpl) setSessionOptions(ctx context.Context, key string, val interface{}) error { req := flight.SetSessionOptionsRequest{} var err error @@ -206,7 +285,7 @@ func getSessionOption[T any](options map[string]interface{}, key string, default return value, nil } -func (c *cnxn) GetOption(key string) (string, error) { +func (c *connectionImpl) GetOption(key string) (string, error) { if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) headers := c.hdrs.Get(name) @@ -226,51 +305,6 @@ func (c *cnxn) GetOption(key string) (string, error) { return c.timeouts.queryTimeout.String(), nil case OptionTimeoutUpdate: return c.timeouts.updateTimeout.String(), nil - case adbc.OptionKeyAutoCommit: - if c.txn != nil { - // No autocommit - return adbc.OptionValueDisabled, nil - } else { - // Autocommit - return adbc.OptionValueEnabled, nil - } - case adbc.OptionKeyCurrentCatalog: - options, err := c.getSessionOptions(context.Background()) - if err != nil { - return "", err - } - if catalog, ok := options["catalog"]; ok { - if val, ok := catalog.(string); ok { - return val, nil - } - return "", adbc.Error{ - Msg: fmt.Sprintf("[FlightSQL] Server returned non-string catalog %#v", catalog), - Code: adbc.StatusInternal, - } - } - return "", adbc.Error{ - Msg: "[FlightSQL] current catalog not supported", - 
Code: adbc.StatusNotFound, - } - - case adbc.OptionKeyCurrentDbSchema: - options, err := c.getSessionOptions(context.Background()) - if err != nil { - return "", err - } - if schema, ok := options["schema"]; ok { - if val, ok := schema.(string); ok { - return val, nil - } - return "", adbc.Error{ - Msg: fmt.Sprintf("[FlightSQL] Server returned non-string schema %#v", schema), - Code: adbc.StatusInternal, - } - } - return "", adbc.Error{ - Msg: "[FlightSQL] current schema not supported", - Code: adbc.StatusNotFound, - } case OptionSessionOptions: options, err := c.getSessionOptions(context.Background()) if err != nil { @@ -333,7 +367,7 @@ func (c *cnxn) GetOption(key string) (string, error) { } } -func (c *cnxn) GetOptionBytes(key string) ([]byte, error) { +func (c *connectionImpl) GetOptionBytes(key string) ([]byte, error) { switch key { case OptionSessionOptions: options, err := c.getSessionOptions(context.Background()) @@ -356,7 +390,7 @@ func (c *cnxn) GetOptionBytes(key string) ([]byte, error) { } } -func (c *cnxn) GetOptionInt(key string) (int64, error) { +func (c *connectionImpl) GetOptionInt(key string) (int64, error) { switch key { case OptionTimeoutFetch: fallthrough @@ -378,13 +412,10 @@ func (c *cnxn) GetOptionInt(key string) (int64, error) { return getSessionOption(options, name, int64(0), "an integer") } - return 0, adbc.Error{ - Msg: "[Flight SQL] unknown connection option", - Code: adbc.StatusNotFound, - } + return c.ConnectionImplBase.GetOptionInt(key) } -func (c *cnxn) GetOptionDouble(key string) (float64, error) { +func (c *connectionImpl) GetOptionDouble(key string) (float64, error) { switch key { case OptionTimeoutFetch: return c.timeouts.fetchTimeout.Seconds(), nil @@ -402,13 +433,10 @@ func (c *cnxn) GetOptionDouble(key string) (float64, error) { return getSessionOption(options, name, float64(0.0), "a floating-point") } - return 0.0, adbc.Error{ - Msg: "[Flight SQL] unknown connection option", - Code: adbc.StatusNotFound, - } + return 
c.ConnectionImplBase.GetOptionDouble(key) } -func (c *cnxn) SetOption(key, value string) error { +func (c *connectionImpl) SetOption(key, value string) error { if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) if value == "" { @@ -422,56 +450,6 @@ func (c *cnxn) SetOption(key, value string) error { switch key { case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate: return c.timeouts.setTimeoutString(key, value) - case adbc.OptionKeyAutoCommit: - autocommit := true - switch value { - case adbc.OptionValueEnabled: - // Do nothing - case adbc.OptionValueDisabled: - autocommit = false - default: - return adbc.Error{ - Msg: "[Flight SQL] invalid value for option " + key + ": " + value, - Code: adbc.StatusInvalidArgument, - } - } - - if autocommit && c.txn == nil { - // no-op don't even error if the server didn't support transactions - return nil - } - - if !c.supportInfo.transactions { - return errNoTransactionSupport - } - - ctx := metadata.NewOutgoingContext(context.Background(), c.hdrs) - var err error - if c.txn != nil { - if err = c.txn.Commit(ctx, c.timeouts); err != nil { - return adbc.Error{ - Msg: "[Flight SQL] failed to update autocommit: " + err.Error(), - Code: adbc.StatusIO, - } - } - } - - if autocommit { - c.txn = nil - return nil - } - - if c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts); err != nil { - return adbc.Error{ - Msg: "[Flight SQL] failed to update autocommit: " + err.Error(), - Code: adbc.StatusIO, - } - } - return nil - case adbc.OptionKeyCurrentCatalog: - return c.setSessionOptions(context.Background(), "catalog", value) - case adbc.OptionKeyCurrentDbSchema: - return c.setSessionOptions(context.Background(), "schema", value) } switch { @@ -506,20 +484,10 @@ func (c *cnxn) SetOption(key, value string) error { return c.setSessionOptions(context.Background(), name, nil) } - return adbc.Error{ - Msg: "[Flight SQL] unknown connection option", - Code: 
adbc.StatusNotImplemented, - } + return c.ConnectionImplBase.SetOption(key, value) } -func (c *cnxn) SetOptionBytes(key string, value []byte) error { - return adbc.Error{ - Msg: "[Flight SQL] unknown connection option", - Code: adbc.StatusNotImplemented, - } -} - -func (c *cnxn) SetOptionInt(key string, value int64) error { +func (c *connectionImpl) SetOptionInt(key string, value int64) error { switch key { case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate: return c.timeouts.setTimeout(key, float64(value)) @@ -529,13 +497,10 @@ func (c *cnxn) SetOptionInt(key string, value int64) error { return c.setSessionOptions(context.Background(), name, value) } - return adbc.Error{ - Msg: "[Flight SQL] unknown connection option", - Code: adbc.StatusNotImplemented, - } + return c.ConnectionImplBase.SetOptionInt(key, value) } -func (c *cnxn) SetOptionDouble(key string, value float64) error { +func (c *connectionImpl) SetOptionDouble(key string, value float64) error { switch key { case OptionTimeoutFetch: fallthrough @@ -549,231 +514,117 @@ func (c *cnxn) SetOptionDouble(key string, value float64) error { return c.setSessionOptions(context.Background(), name, value) } - return adbc.Error{ - Msg: "[Flight SQL] unknown connection option", - Code: adbc.StatusNotImplemented, - } + return c.ConnectionImplBase.SetOptionDouble(key, value) } -// GetInfo returns metadata about the database/driver. 
-// -// The result is an Arrow dataset with the following schema: -// -// Field Name | Field Type -// ----------------------------|----------------------------- -// info_name | uint32 not null -// info_value | INFO_SCHEMA -// -// INFO_SCHEMA is a dense union with members: -// -// Field Name (Type Code) | Field Type -// ----------------------------|----------------------------- -// string_value (0) | utf8 -// bool_value (1) | bool -// int64_value (2) | int64 -// int32_bitmask (3) | int32 -// string_list (4) | list -// int32_to_int32_list_map (5) | map> -// -// Each metadatum is identified by an integer code. The recognized -// codes are defined as constants. Codes [0, 10_000) are reserved -// for ADBC usage. Drivers/vendors will ignore requests for unrecognized -// codes (the row will be omitted from the result). -func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.RecordReader, error) { - const strValTypeID arrow.UnionTypeCode = 0 - const intValTypeID arrow.UnionTypeCode = 2 +func (c *connectionImpl) PrepareDriverInfo(ctx context.Context, infoCodes []adbc.InfoCode) error { + driverInfo := c.ConnectionImplBase.DriverInfo if len(infoCodes) == 0 { - infoCodes = infoSupportedCodes + infoCodes = driverInfo.InfoSupportedCodes() } - bldr := array.NewRecordBuilder(c.cl.Alloc, adbc.GetInfoSchema) - defer bldr.Release() - bldr.Reserve(len(infoCodes)) - - infoNameBldr := bldr.Field(0).(*array.Uint32Builder) - infoValueBldr := bldr.Field(1).(*array.DenseUnionBuilder) - strInfoBldr := infoValueBldr.Child(int(strValTypeID)).(*array.StringBuilder) - intInfoBldr := infoValueBldr.Child(int(intValTypeID)).(*array.Int64Builder) - translated := make([]flightsql.SqlInfo, 0, len(infoCodes)) for _, code := range infoCodes { if t, ok := adbcToFlightSQLInfo[code]; ok { translated = append(translated, t) - continue } + } - switch code { - case adbc.InfoDriverName: - infoNameBldr.Append(uint32(code)) - infoValueBldr.Append(strValTypeID) - 
strInfoBldr.Append(infoDriverName) - case adbc.InfoDriverVersion: - infoNameBldr.Append(uint32(code)) - infoValueBldr.Append(strValTypeID) - strInfoBldr.Append(infoDriverVersion) - case adbc.InfoDriverArrowVersion: - infoNameBldr.Append(uint32(code)) - infoValueBldr.Append(strValTypeID) - strInfoBldr.Append(infoDriverArrowVersion) - case adbc.InfoDriverADBCVersion: - infoNameBldr.Append(uint32(code)) - infoValueBldr.Append(intValTypeID) - intInfoBldr.Append(adbc.AdbcVersion1_1_0) - } + // None of the requested info codes are available on the server, so just return the local info + if len(translated) == 0 { + return nil } ctx = metadata.NewOutgoingContext(ctx, c.hdrs) var header, trailer metadata.MD info, err := c.cl.GetSqlInfo(ctx, translated, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) - if err == nil { - for i, endpoint := range info.Endpoint { - var header, trailer metadata.MD - rdr, err := doGet(ctx, c.cl, endpoint, c.clientCache, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) - if err != nil { - return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location) - } - for rdr.Next() { - rec := rdr.Record() - field := rec.Column(0).(*array.Uint32) - info := rec.Column(1).(*array.DenseUnion) - - for i := 0; i < int(rec.NumRows()); i++ { - switch flightsql.SqlInfo(field.Value(i)) { - case flightsql.SqlInfoFlightSqlServerName: - infoNameBldr.Append(uint32(adbc.InfoVendorName)) - case flightsql.SqlInfoFlightSqlServerVersion: - infoNameBldr.Append(uint32(adbc.InfoVendorVersion)) - case flightsql.SqlInfoFlightSqlServerArrowVersion: - infoNameBldr.Append(uint32(adbc.InfoVendorArrowVersion)) - default: - continue - } + // Just return local driver info if GetSqlInfo hasn't been implemented on the server + if grpcstatus.Code(err) == grpccodes.Unimplemented { + return nil + } + + if err != nil { + return adbcFromFlightStatus(err, "GetInfo(GetSqlInfo)") + } + + // No error, go get the SqlInfo 
from the server + for i, endpoint := range info.Endpoint { + var header, trailer metadata.MD + rdr, err := doGet(ctx, c.cl, endpoint, c.clientCache, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) + if err != nil { + return adbcFromFlightStatusWithDetails(err, header, trailer, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location) + } - infoValueBldr.Append(info.TypeCode(i)) - // we know we're only doing string fields here right now - v := info.Field(info.ChildID(i)).(*array.String). - Value(int(info.ValueOffset(i))) - strInfoBldr.Append(v) + for rdr.Next() { + rec := rdr.Record() + field := rec.Column(0).(*array.Uint32) + info := rec.Column(1).(*array.DenseUnion) + + var adbcInfoCode adbc.InfoCode + for i := 0; i < int(rec.NumRows()); i++ { + switch flightsql.SqlInfo(field.Value(i)) { + case flightsql.SqlInfoFlightSqlServerName: + adbcInfoCode = adbc.InfoVendorName + case flightsql.SqlInfoFlightSqlServerVersion: + adbcInfoCode = adbc.InfoVendorVersion + case flightsql.SqlInfoFlightSqlServerArrowVersion: + adbcInfoCode = adbc.InfoVendorArrowVersion + default: + continue } - } - if err := checkContext(rdr.Err(), ctx); err != nil { - return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location) + // we know we're only doing string fields here right now + v := info.Field(info.ChildID(i)).(*array.String). 
+ Value(int(info.ValueOffset(i))) + if err := driverInfo.RegisterInfoCode(adbcInfoCode, strings.Clone(v)); err != nil { + return err + } } } - } else if grpcstatus.Code(err) != grpccodes.Unimplemented { - return nil, adbcFromFlightStatus(err, "GetInfo(GetSqlInfo)") + + if err := checkContext(rdr.Err(), ctx); err != nil { + return adbcFromFlightStatusWithDetails(err, header, trailer, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location) + } } - final := bldr.NewRecord() - defer final.Release() - return array.NewRecordReader(adbc.GetInfoSchema, []arrow.Record{final}) + return nil } -// GetObjects gets a hierarchical view of all catalogs, database schemas, -// tables, and columns. -// -// The result is an Arrow Dataset with the following schema: -// -// Field Name | Field Type -// ----------------------------|---------------------------- -// catalog_name | utf8 -// catalog_db_schemas | list -// -// DB_SCHEMA_SCHEMA is a Struct with the fields: -// -// Field Name | Field Type -// ----------------------------|---------------------------- -// db_schema_name | utf8 -// db_schema_tables | list -// -// TABLE_SCHEMA is a Struct with the fields: -// -// Field Name | Field Type -// ----------------------------|---------------------------- -// table_name | utf8 not null -// table_type | utf8 not null -// table_columns | list -// table_constraints | list -// -// COLUMN_SCHEMA is a Struct with the fields: -// -// Field Name | Field Type | Comments -// ----------------------------|---------------------|--------- -// column_name | utf8 not null | -// ordinal_position | int32 | (1) -// remarks | utf8 | (2) -// xdbc_data_type | int16 | (3) -// xdbc_type_name | utf8 | (3) -// xdbc_column_size | int32 | (3) -// xdbc_decimal_digits | int16 | (3) -// xdbc_num_prec_radix | int16 | (3) -// xdbc_nullable | int16 | (3) -// xdbc_column_def | utf8 | (3) -// xdbc_sql_data_type | int16 | (3) -// xdbc_datetime_sub | int16 | (3) -// xdbc_char_octet_length | int32 | (3) -// xdbc_is_nullable | 
utf8 | (3) -// xdbc_scope_catalog | utf8 | (3) -// xdbc_scope_schema | utf8 | (3) -// xdbc_scope_table | utf8 | (3) -// xdbc_is_autoincrement | bool | (3) -// xdbc_is_generatedcolumn | bool | (3) -// -// 1. The column's ordinal position in the table (starting from 1). -// 2. Database-specific description of the column. -// 3. Optional Value. Should be null if not supported by the driver. -// xdbc_values are meant to provide JDBC/ODBC-compatible metadata -// in an agnostic manner. -// -// CONSTRAINT_SCHEMA is a Struct with the fields: -// -// Field Name | Field Type | Comments -// ----------------------------|---------------------|--------- -// constraint_name | utf8 | -// constraint_type | utf8 not null | (1) -// constraint_column_names | list not null | (2) -// constraint_column_usage | list | (3) -// -// 1. One of 'CHECK', 'FOREIGN KEY', 'PRIMARY KEY', or 'UNIQUE'. -// 2. The columns on the current table that are constrained, in order. -// 3. For FOREIGN KEY only, the referenced table and columns. -// -// USAGE_SCHEMA is a Struct with fields: -// -// Field Name | Field Type -// ----------------------------|---------------------------- -// fk_catalog | utf8 -// fk_db_schema | utf8 -// fk_table | utf8 not null -// fk_column_name | utf8 not null -// -// For the parameters: If nil is passed, then that parameter will not -// be filtered by at all. If an empty string, then only objects without -// that property (ie: catalog or db schema) will be returned. -// -// tableName and columnName must be either nil (do not filter by -// table name or column name) or non-empty. -// -// All non-empty, non-nil strings should be a search pattern (as described -// earlier). 
-func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) { - ctx = metadata.NewOutgoingContext(ctx, c.hdrs) - g := internal.GetObjects{Ctx: ctx, Depth: depth, Catalog: catalog, DbSchema: dbSchema, TableName: tableName, ColumnName: columnName, TableType: tableType} - if err := g.Init(c.db.Alloc, c.getObjectsDbSchemas, c.getObjectsTables); err != nil { - return nil, err +// Helper function to read and validate a metadata stream +func (c *connectionImpl) readInfo(ctx context.Context, expectedSchema *arrow.Schema, info *flight.FlightInfo, opts ...grpc.CallOption) (array.RecordReader, error) { + // use a default queueSize for the reader + rdr, err := newRecordReader(ctx, c.db.Alloc, c.cl, info, c.clientCache, 5, opts...) + if err != nil { + return nil, adbcFromFlightStatus(err, "DoGet") } - defer g.Release() - var header, trailer metadata.MD + if !rdr.Schema().Equal(expectedSchema) { + rdr.Release() + return nil, adbc.Error{ + Msg: fmt.Sprintf("Invalid schema returned for: expected %s, got %s", expectedSchema.String(), rdr.Schema().String()), + Code: adbc.StatusInternal, + } + } + return rdr, nil +} + +func (c *connectionImpl) GetObjectsCatalogs(ctx context.Context, catalog *string) ([]string, error) { + var ( + header, trailer metadata.MD + numCatalogs int64 + ) // To avoid an N+1 query problem, we assume result sets here will fit in memory and build up a single response. 
info, err := c.cl.GetCatalogs(ctx, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) if err != nil { return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetCatalogs)") } + if info.TotalRecords > 0 { + numCatalogs = info.TotalRecords + } + header = metadata.MD{} trailer = metadata.MD{} rdr, err := c.readInfo(ctx, schema_ref.Catalogs, info, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) @@ -782,48 +633,25 @@ func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog * } defer rdr.Release() - foundCatalog := false + catalogs := make([]string, 0, numCatalogs) for rdr.Next() { arr := rdr.Record().Column(0).(*array.String) for i := 0; i < arr.Len(); i++ { // XXX: force copy since accessor is unsafe catalogName := string([]byte(arr.Value(i))) - g.AppendCatalog(catalogName) - foundCatalog = true + catalogs = append(catalogs, catalogName) } } - // Implementations like Dremio report no catalogs, but still have schemas - if !foundCatalog && depth != adbc.ObjectDepthCatalogs { - g.AppendCatalog("") - } - if err := checkContext(rdr.Err(), ctx); err != nil { return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetCatalogs)") } - return g.Finish() -} - -// Helper function to read and validate a metadata stream -func (c *cnxn) readInfo(ctx context.Context, expectedSchema *arrow.Schema, info *flight.FlightInfo, opts ...grpc.CallOption) (array.RecordReader, error) { - // use a default queueSize for the reader - rdr, err := newRecordReader(ctx, c.db.Alloc, c.cl, info, c.clientCache, 5, opts...) 
- if err != nil { - return nil, adbcFromFlightStatus(err, "DoGet") - } - if !rdr.Schema().Equal(expectedSchema) { - rdr.Release() - return nil, adbc.Error{ - Msg: fmt.Sprintf("Invalid schema returned for: expected %s, got %s", expectedSchema.String(), rdr.Schema().String()), - Code: adbc.StatusInternal, - } - } - return rdr, nil + return catalogs, nil } // Helper function to build up a map of catalogs to DB schemas -func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, metadataRecords []internal.Metadata) (result map[string][]string, err error) { +func (c *connectionImpl) GetObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, metadataRecords []internal.Metadata) (result map[string][]string, err error) { if depth == adbc.ObjectDepthCatalogs { return } @@ -864,7 +692,7 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, return } -func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string, metadataRecords []internal.Metadata) (result internal.SchemaToTableInfo, err error) { +func (c *connectionImpl) GetObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string, metadataRecords []internal.Metadata) (result internal.SchemaToTableInfo, err error) { if depth == adbc.ObjectDepthCatalogs || depth == adbc.ObjectDepthDBSchemas { return } @@ -944,7 +772,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat return } -func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (*arrow.Schema, error) { +func (c *connectionImpl) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (*arrow.Schema, error) { opts := 
&flightsql.GetTablesOpts{ Catalog: catalog, DbSchemaFilterPattern: dbSchema, @@ -1023,7 +851,7 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st // Field Name | Field Type // ----------------|-------------- // table_type | utf8 not null -func (c *cnxn) GetTableTypes(ctx context.Context) (array.RecordReader, error) { +func (c *connectionImpl) GetTableTypes(ctx context.Context) (array.RecordReader, error) { ctx = metadata.NewOutgoingContext(ctx, c.hdrs) var header, trailer metadata.MD info, err := c.cl.GetTableTypes(ctx, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) @@ -1040,18 +868,7 @@ func (c *cnxn) GetTableTypes(ctx context.Context) (array.RecordReader, error) { // Behavior is undefined if this is mixed with SQL transaction statements. // When not supported, the convention is that it should act as if autocommit // is enabled and return INVALID_STATE errors. -func (c *cnxn) Commit(ctx context.Context) error { - if c.txn == nil { - return adbc.Error{ - Msg: "[Flight SQL] Cannot commit when autocommit is enabled", - Code: adbc.StatusInvalidState, - } - } - - if !c.supportInfo.transactions { - return errNoTransactionSupport - } - +func (c *connectionImpl) Commit(ctx context.Context) error { ctx = metadata.NewOutgoingContext(ctx, c.hdrs) var header, trailer metadata.MD err := c.txn.Commit(ctx, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) @@ -1074,18 +891,7 @@ func (c *cnxn) Commit(ctx context.Context) error { // Behavior is undefined if this is mixed with SQL transaction statements. // When not supported, the convention is that it should act as if autocommit // is enabled and return INVALID_STATE errors. 
-func (c *cnxn) Rollback(ctx context.Context) error { - if c.txn == nil { - return adbc.Error{ - Msg: "[Flight SQL] Cannot rollback when autocommit is enabled", - Code: adbc.StatusInvalidState, - } - } - - if !c.supportInfo.transactions { - return errNoTransactionSupport - } - +func (c *connectionImpl) Rollback(ctx context.Context) error { ctx = metadata.NewOutgoingContext(ctx, c.hdrs) var header, trailer metadata.MD err := c.txn.Rollback(ctx, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) @@ -1103,7 +909,7 @@ func (c *cnxn) Rollback(ctx context.Context) error { } // NewStatement initializes a new statement object tied to this connection -func (c *cnxn) NewStatement() (adbc.Statement, error) { +func (c *connectionImpl) NewStatement() (adbc.Statement, error) { return &statement{ alloc: c.db.Alloc, clientCache: c.clientCache, @@ -1114,7 +920,7 @@ func (c *cnxn) NewStatement() (adbc.Statement, error) { }, nil } -func (c *cnxn) execute(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.FlightInfo, error) { +func (c *connectionImpl) execute(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.FlightInfo, error) { if c.txn != nil { return c.txn.Execute(ctx, query, opts...) } @@ -1122,7 +928,7 @@ func (c *cnxn) execute(ctx context.Context, query string, opts ...grpc.CallOptio return c.cl.Execute(ctx, query, opts...) } -func (c *cnxn) executeSchema(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.SchemaResult, error) { +func (c *connectionImpl) executeSchema(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.SchemaResult, error) { if c.txn != nil { return c.txn.GetExecuteSchema(ctx, query, opts...) } @@ -1130,7 +936,7 @@ func (c *cnxn) executeSchema(ctx context.Context, query string, opts ...grpc.Cal return c.cl.GetExecuteSchema(ctx, query, opts...) 
} -func (c *cnxn) executeSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.FlightInfo, error) { +func (c *connectionImpl) executeSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.FlightInfo, error) { if c.txn != nil { return c.txn.ExecuteSubstrait(ctx, plan, opts...) } @@ -1138,7 +944,7 @@ func (c *cnxn) executeSubstrait(ctx context.Context, plan flightsql.SubstraitPla return c.cl.ExecuteSubstrait(ctx, plan, opts...) } -func (c *cnxn) executeSubstraitSchema(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.SchemaResult, error) { +func (c *connectionImpl) executeSubstraitSchema(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.SchemaResult, error) { if c.txn != nil { return c.txn.GetExecuteSubstraitSchema(ctx, plan, opts...) } @@ -1146,7 +952,7 @@ func (c *cnxn) executeSubstraitSchema(ctx context.Context, plan flightsql.Substr return c.cl.GetExecuteSubstraitSchema(ctx, plan, opts...) } -func (c *cnxn) executeUpdate(ctx context.Context, query string, opts ...grpc.CallOption) (n int64, err error) { +func (c *connectionImpl) executeUpdate(ctx context.Context, query string, opts ...grpc.CallOption) (n int64, err error) { if c.txn != nil { return c.txn.ExecuteUpdate(ctx, query, opts...) } @@ -1154,7 +960,7 @@ func (c *cnxn) executeUpdate(ctx context.Context, query string, opts ...grpc.Cal return c.cl.ExecuteUpdate(ctx, query, opts...) } -func (c *cnxn) executeSubstraitUpdate(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (n int64, err error) { +func (c *connectionImpl) executeSubstraitUpdate(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (n int64, err error) { if c.txn != nil { return c.txn.ExecuteSubstraitUpdate(ctx, plan, opts...) 
} @@ -1162,7 +968,7 @@ func (c *cnxn) executeSubstraitUpdate(ctx context.Context, plan flightsql.Substr return c.cl.ExecuteSubstraitUpdate(ctx, plan, opts...) } -func (c *cnxn) poll(ctx context.Context, query string, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) { +func (c *connectionImpl) poll(ctx context.Context, query string, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) { if c.txn != nil { return c.txn.ExecutePoll(ctx, query, retryDescriptor, opts...) } @@ -1170,7 +976,7 @@ func (c *cnxn) poll(ctx context.Context, query string, retryDescriptor *flight.F return c.cl.ExecutePoll(ctx, query, retryDescriptor, opts...) } -func (c *cnxn) pollSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) { +func (c *connectionImpl) pollSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) { if c.txn != nil { return c.txn.ExecuteSubstraitPoll(ctx, plan, retryDescriptor, opts...) } @@ -1178,7 +984,7 @@ func (c *cnxn) pollSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, return c.cl.ExecuteSubstraitPoll(ctx, plan, retryDescriptor, opts...) } -func (c *cnxn) prepare(ctx context.Context, query string, opts ...grpc.CallOption) (*flightsql.PreparedStatement, error) { +func (c *connectionImpl) prepare(ctx context.Context, query string, opts ...grpc.CallOption) (*flightsql.PreparedStatement, error) { if c.txn != nil { return c.txn.Prepare(ctx, query, opts...) } @@ -1186,7 +992,7 @@ func (c *cnxn) prepare(ctx context.Context, query string, opts ...grpc.CallOptio return c.cl.Prepare(ctx, query, opts...) 
} -func (c *cnxn) prepareSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flightsql.PreparedStatement, error) { +func (c *connectionImpl) prepareSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flightsql.PreparedStatement, error) { if c.txn != nil { return c.txn.PrepareSubstrait(ctx, plan, opts...) } @@ -1195,7 +1001,7 @@ func (c *cnxn) prepareSubstrait(ctx context.Context, plan flightsql.SubstraitPla } // Close closes this connection and releases any associated resources. -func (c *cnxn) Close() error { +func (c *connectionImpl) Close() error { if c.cl == nil { return adbc.Error{ Msg: "[Flight SQL Connection] trying to close already closed connection", @@ -1225,7 +1031,7 @@ func (c *cnxn) Close() error { // results can then be read independently using the returned RecordReader. // // A partition can be retrieved by using ExecutePartitions on a statement. -func (c *cnxn) ReadPartition(ctx context.Context, serializedPartition []byte) (rdr array.RecordReader, err error) { +func (c *connectionImpl) ReadPartition(ctx context.Context, serializedPartition []byte) (rdr array.RecordReader, err error) { var info flight.FlightInfo if err := proto.Unmarshal(serializedPartition, &info); err != nil { return nil, adbc.Error{ @@ -1251,5 +1057,5 @@ func (c *cnxn) ReadPartition(ctx context.Context, serializedPartition []byte) (r } var ( - _ adbc.PostInitOptions = (*cnxn)(nil) + _ adbc.PostInitOptions = (*connectionImpl)(nil) ) diff --git a/go/adbc/driver/flightsql/flightsql_database.go b/go/adbc/driver/flightsql/flightsql_database.go index 5e5e3af978..9f0848c3f9 100644 --- a/go/adbc/driver/flightsql/flightsql_database.go +++ b/go/adbc/driver/flightsql/flightsql_database.go @@ -29,7 +29,7 @@ import ( "time" "github.com/apache/arrow-adbc/go/adbc" - "github.com/apache/arrow-adbc/go/adbc/driver/driverbase" + "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" 
"github.com/apache/arrow/go/v16/arrow/array" "github.com/apache/arrow/go/v16/arrow/flight" "github.com/apache/arrow/go/v16/arrow/flight/flightsql" @@ -51,7 +51,6 @@ func (d *dbDialOpts) rebuild() { d.opts = []grpc.DialOption{ grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(d.maxMsgSize), grpc.MaxCallSendMsgSize(d.maxMsgSize)), - grpc.WithUserAgent("ADBC Flight SQL Driver " + infoDriverVersion), } if d.block { d.opts = append(d.opts, grpc.WithBlock()) @@ -383,7 +382,12 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl creds = insecure.NewCredentials() target = "unix:" + uri.Path } - dialOpts := append(d.dialOpts.opts, grpc.WithConnectParams(d.timeout.connectParams()), grpc.WithTransportCredentials(creds)) + + driverVersion, ok := d.DatabaseImplBase.DriverInfo.GetInfoDriverVersion() + if !ok { + driverVersion = driverbase.UnknownVersion + } + dialOpts := append(d.dialOpts.opts, grpc.WithConnectParams(d.timeout.connectParams()), grpc.WithTransportCredentials(creds), grpc.WithUserAgent("ADBC Flight SQL Driver "+driverVersion)) d.Logger.DebugContext(ctx, "new client", "location", loc) cl, err := flightsql.NewClient(target, nil, middleware, dialOpts...) @@ -503,9 +507,18 @@ func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) { } } - return &cnxn{cl: cl, db: d, clientCache: cache, - hdrs: make(metadata.MD), timeouts: d.timeout, - supportInfo: cnxnSupport}, nil + conn := &connectionImpl{ + cl: cl, db: d, clientCache: cache, + hdrs: make(metadata.MD), timeouts: d.timeout, supportInfo: cnxnSupport, + ConnectionImplBase: driverbase.NewConnectionImplBase(&d.DatabaseImplBase), + } + + return driverbase.NewConnectionBuilder(conn). + WithDriverInfoPreparer(conn). + WithAutocommitSetter(conn). + WithDbObjectsEnumerator(conn). + WithCurrentNamespacer(conn). 
+ Connection(), nil } type bearerAuthMiddleware struct { diff --git a/go/adbc/driver/flightsql/flightsql_driver.go b/go/adbc/driver/flightsql/flightsql_driver.go index d437f0829b..441370a9e2 100644 --- a/go/adbc/driver/flightsql/flightsql_driver.go +++ b/go/adbc/driver/flightsql/flightsql_driver.go @@ -33,12 +33,10 @@ package flightsql import ( "net/url" - "runtime/debug" - "strings" "time" "github.com/apache/arrow-adbc/go/adbc" - "github.com/apache/arrow-adbc/go/adbc/driver/driverbase" + "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" "github.com/apache/arrow/go/v16/arrow/memory" "golang.org/x/exp/maps" "google.golang.org/grpc/metadata" @@ -69,56 +67,19 @@ const ( infoDriverName = "ADBC Flight SQL Driver - Go" ) -var ( - infoDriverVersion string - infoDriverArrowVersion string - infoSupportedCodes []adbc.InfoCode -) - var errNoTransactionSupport = adbc.Error{ Msg: "[Flight SQL] server does not report transaction support", Code: adbc.StatusNotImplemented, } -func init() { - if info, ok := debug.ReadBuildInfo(); ok { - for _, dep := range info.Deps { - switch { - case dep.Path == "github.com/apache/arrow-adbc/go/adbc/driver/flightsql": - infoDriverVersion = dep.Version - case strings.HasPrefix(dep.Path, "github.com/apache/arrow/go/"): - infoDriverArrowVersion = dep.Version - } - } - } - // XXX: Deps not populated in tests - // https://github.com/golang/go/issues/33976 - if infoDriverVersion == "" { - infoDriverVersion = "(unknown or development build)" - } - if infoDriverArrowVersion == "" { - infoDriverArrowVersion = "(unknown or development build)" - } - - infoSupportedCodes = []adbc.InfoCode{ - adbc.InfoDriverName, - adbc.InfoDriverVersion, - adbc.InfoDriverArrowVersion, - adbc.InfoDriverADBCVersion, - adbc.InfoVendorName, - adbc.InfoVendorVersion, - adbc.InfoVendorArrowVersion, - } -} - type driverImpl struct { driverbase.DriverImplBase } // NewDriver creates a new Flight SQL driver using the given Arrow allocator. 
func NewDriver(alloc memory.Allocator) adbc.Driver { - impl := driverImpl{DriverImplBase: driverbase.NewDriverImplBase("Flight SQL", alloc)} - return driverbase.NewDriver(&impl) + info := driverbase.DefaultDriverInfo("Flight SQL") + return driverbase.NewDriver(&driverImpl{DriverImplBase: driverbase.NewDriverImplBase(info, alloc)}) } func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error) { diff --git a/go/adbc/driver/flightsql/flightsql_statement.go b/go/adbc/driver/flightsql/flightsql_statement.go index d78b653c81..c68eba8cdb 100644 --- a/go/adbc/driver/flightsql/flightsql_statement.go +++ b/go/adbc/driver/flightsql/flightsql_statement.go @@ -72,7 +72,7 @@ func (s *sqlOrSubstrait) setSubstraitPlan(plan []byte) { s.substraitPlan = plan } -func (s *sqlOrSubstrait) execute(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (*flight.FlightInfo, error) { +func (s *sqlOrSubstrait) execute(ctx context.Context, cnxn *connectionImpl, opts ...grpc.CallOption) (*flight.FlightInfo, error) { if s.sqlQuery != "" { return cnxn.execute(ctx, s.sqlQuery, opts...) } else if s.substraitPlan != nil { @@ -85,7 +85,7 @@ func (s *sqlOrSubstrait) execute(ctx context.Context, cnxn *cnxn, opts ...grpc.C } } -func (s *sqlOrSubstrait) executeSchema(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (*arrow.Schema, error) { +func (s *sqlOrSubstrait) executeSchema(ctx context.Context, cnxn *connectionImpl, opts ...grpc.CallOption) (*arrow.Schema, error) { var ( res *flight.SchemaResult err error @@ -108,7 +108,7 @@ func (s *sqlOrSubstrait) executeSchema(ctx context.Context, cnxn *cnxn, opts ... 
return flight.DeserializeSchema(res.Schema, cnxn.cl.Alloc) } -func (s *sqlOrSubstrait) executeUpdate(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (int64, error) { +func (s *sqlOrSubstrait) executeUpdate(ctx context.Context, cnxn *connectionImpl, opts ...grpc.CallOption) (int64, error) { if s.sqlQuery != "" { return cnxn.executeUpdate(ctx, s.sqlQuery, opts...) } else if s.substraitPlan != nil { @@ -121,7 +121,7 @@ func (s *sqlOrSubstrait) executeUpdate(ctx context.Context, cnxn *cnxn, opts ... } } -func (s *sqlOrSubstrait) poll(ctx context.Context, cnxn *cnxn, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) { +func (s *sqlOrSubstrait) poll(ctx context.Context, cnxn *connectionImpl, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) { if s.sqlQuery != "" { return cnxn.poll(ctx, s.sqlQuery, retryDescriptor, opts...) } else if s.substraitPlan != nil { @@ -134,7 +134,7 @@ func (s *sqlOrSubstrait) poll(ctx context.Context, cnxn *cnxn, retryDescriptor * } } -func (s *sqlOrSubstrait) prepare(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (*flightsql.PreparedStatement, error) { +func (s *sqlOrSubstrait) prepare(ctx context.Context, cnxn *connectionImpl, opts ...grpc.CallOption) (*flightsql.PreparedStatement, error) { if s.sqlQuery != "" { return cnxn.prepare(ctx, s.sqlQuery, opts...) } else if s.substraitPlan != nil { @@ -156,7 +156,7 @@ type incrementalState struct { type statement struct { alloc memory.Allocator - cnxn *cnxn + cnxn *connectionImpl clientCache gcache.Cache hdrs metadata.MD diff --git a/go/adbc/driver/internal/driverbase/connection.go b/go/adbc/driver/internal/driverbase/connection.go new file mode 100644 index 0000000000..68b0a9bc69 --- /dev/null +++ b/go/adbc/driver/internal/driverbase/connection.go @@ -0,0 +1,497 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package driverbase + +import ( + "context" + "fmt" + + "github.com/apache/arrow-adbc/go/adbc" + "github.com/apache/arrow-adbc/go/adbc/driver/internal" + "github.com/apache/arrow/go/v16/arrow" + "github.com/apache/arrow/go/v16/arrow/array" + "github.com/apache/arrow/go/v16/arrow/memory" + "golang.org/x/exp/slog" +) + +const ( + ConnectionMessageOptionUnknown = "Unknown connection option" + ConnectionMessageOptionUnsupported = "Unsupported connection option" + ConnectionMessageCannotCommit = "Cannot commit when autocommit is enabled" + ConnectionMessageCannotRollback = "Cannot rollback when autocommit is enabled" +) + +// ConnectionImpl is an interface that drivers implement to provide +// vendor-specific functionality. +type ConnectionImpl interface { + adbc.Connection + adbc.GetSetOptions + Base() *ConnectionImplBase +} + +// CurrentNamespacer is an interface that drivers may implement to delegate +// stateful namespacing with DB catalogs and schemas. The appropriate (Get/Set)Options +// implementations will be provided using the results of these methods. 
+type CurrentNamespacer interface { + GetCurrentCatalog() (string, error) + GetCurrentDbSchema() (string, error) + SetCurrentCatalog(string) error + SetCurrentDbSchema(string) error +} + +// DriverInfoPreparer is an interface that drivers may implement to add/update +// DriverInfo values whenever adbc.Connection.GetInfo() is called. +type DriverInfoPreparer interface { + PrepareDriverInfo(ctx context.Context, infoCodes []adbc.InfoCode) error +} + +// TableTypeLister is an interface that drivers may implement to simplify the +// implementation of adbc.Connection.GetTableTypes() for backends that do not natively +// send these values as arrow records. The conversion of the result to a RecordReader +// is handled automatically. +type TableTypeLister interface { + ListTableTypes(ctx context.Context) ([]string, error) +} + +// AutocommitSetter is an interface that drivers may implement to simplify the +// implementation of autocommit state management. There is no need to implement +// this for backends that do not support autocommit, as this is already the default +// behavior. SetAutocommit should only attempt to update the autocommit state in the +// backend. Local driver state is automatically updated if the result of this call +// does not produce an error. (Get/Set)Options implementations are provided automatically +// as well. +type AutocommitSetter interface { + SetAutocommit(enabled bool) error +} + +// DbObjectsEnumerator is an interface that drivers may implement to simplify the +// implementation of adbc.Connection.GetObjects(). By independently implementing lookup +// for catalogs, dbSchemas and tables, the driverbase is able to provide the full +// GetObjects functionality for arbitrary search patterns and lookup depth. 
+type DbObjectsEnumerator interface { + GetObjectsCatalogs(ctx context.Context, catalog *string) ([]string, error) + GetObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string, metadataRecords []internal.Metadata) (map[string][]string, error) + GetObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string, tableName *string, columnName *string, tableType []string, metadataRecords []internal.Metadata) (map[internal.CatalogAndSchema][]internal.TableInfo, error) +} + +// Connection is the interface satisfied by the value returned from ConnectionBuilder.Connection(), +// given that the builder was created from an input satisfying the ConnectionImpl interface. +type Connection interface { + adbc.Connection + adbc.GetSetOptions +} + +// ConnectionImplBase is a struct that provides default implementations of the +// ConnectionImpl interface. It is meant to be used as a composite struct for a +// driver's ConnectionImpl implementation. +type ConnectionImplBase struct { + Alloc memory.Allocator + ErrorHelper ErrorHelper + DriverInfo *DriverInfo + Logger *slog.Logger + + Autocommit bool + Closed bool +} + +// NewConnectionImplBase instantiates ConnectionImplBase. +// +// - database is a DatabaseImplBase containing the common resources from the parent +// database, allowing the Arrow allocator, error handler, and logger to be reused. 
+func NewConnectionImplBase(database *DatabaseImplBase) ConnectionImplBase { + return ConnectionImplBase{ + Alloc: database.Alloc, + ErrorHelper: database.ErrorHelper, + DriverInfo: database.DriverInfo, + Logger: database.Logger, + Autocommit: true, + Closed: false, + } +} + +func (base *ConnectionImplBase) Base() *ConnectionImplBase { + return base +} + +func (base *ConnectionImplBase) Commit(ctx context.Context) error { + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Commit") +} + +func (base *ConnectionImplBase) Rollback(context.Context) error { + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Rollback") +} + +func (base *ConnectionImplBase) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.RecordReader, error) { + + if len(infoCodes) == 0 { + infoCodes = base.DriverInfo.InfoSupportedCodes() + } + + bldr := array.NewRecordBuilder(base.Alloc, adbc.GetInfoSchema) + defer bldr.Release() + bldr.Reserve(len(infoCodes)) + + infoNameBldr := bldr.Field(0).(*array.Uint32Builder) + infoValueBldr := bldr.Field(1).(*array.DenseUnionBuilder) + strInfoBldr := infoValueBldr.Child(int(adbc.InfoValueStringType)).(*array.StringBuilder) + intInfoBldr := infoValueBldr.Child(int(adbc.InfoValueInt64Type)).(*array.Int64Builder) + + for _, code := range infoCodes { + switch code { + case adbc.InfoDriverName: + name, ok := base.DriverInfo.GetInfoDriverName() + if !ok { + continue + } + + infoNameBldr.Append(uint32(code)) + infoValueBldr.Append(adbc.InfoValueStringType) + strInfoBldr.Append(name) + case adbc.InfoDriverVersion: + version, ok := base.DriverInfo.GetInfoDriverVersion() + if !ok { + continue + } + + infoNameBldr.Append(uint32(code)) + infoValueBldr.Append(adbc.InfoValueStringType) + strInfoBldr.Append(version) + case adbc.InfoDriverArrowVersion: + arrowVersion, ok := base.DriverInfo.GetInfoDriverArrowVersion() + if !ok { + continue + } + + infoNameBldr.Append(uint32(code)) + infoValueBldr.Append(adbc.InfoValueStringType) + 
strInfoBldr.Append(arrowVersion) + case adbc.InfoDriverADBCVersion: + adbcVersion, ok := base.DriverInfo.GetInfoDriverADBCVersion() + if !ok { + continue + } + + infoNameBldr.Append(uint32(code)) + infoValueBldr.Append(adbc.InfoValueInt64Type) + intInfoBldr.Append(adbcVersion) + case adbc.InfoVendorName: + name, ok := base.DriverInfo.GetInfoVendorName() + if !ok { + continue + } + + infoNameBldr.Append(uint32(code)) + infoValueBldr.Append(adbc.InfoValueStringType) + strInfoBldr.Append(name) + default: + infoNameBldr.Append(uint32(code)) + value, ok := base.DriverInfo.GetInfoForInfoCode(code) + if !ok { + infoValueBldr.AppendNull() + continue + } + + // TODO: Handle other custom info types + infoValueBldr.Append(adbc.InfoValueStringType) + strInfoBldr.Append(fmt.Sprint(value)) + } + } + + final := bldr.NewRecord() + defer final.Release() + return array.NewRecordReader(adbc.GetInfoSchema, []arrow.Record{final}) +} + +func (base *ConnectionImplBase) Close() error { + return nil +} + +func (base *ConnectionImplBase) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) { + return nil, base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "GetObjects") +} + +func (base *ConnectionImplBase) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (*arrow.Schema, error) { + return nil, base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "GetTableSchema") +} + +func (base *ConnectionImplBase) GetTableTypes(context.Context) (array.RecordReader, error) { + return nil, base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "GetTableTypes") +} + +func (base *ConnectionImplBase) NewStatement() (adbc.Statement, error) { + return nil, base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "NewStatement") +} + +func (base *ConnectionImplBase) ReadPartition(ctx context.Context, serializedPartition []byte) (array.RecordReader, error) { + 
return nil, base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "ReadPartition") +} + +func (base *ConnectionImplBase) GetOption(key string) (string, error) { + return "", base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'", ConnectionMessageOptionUnknown, key) +} + +func (base *ConnectionImplBase) GetOptionBytes(key string) ([]byte, error) { + return nil, base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'", ConnectionMessageOptionUnknown, key) +} + +func (base *ConnectionImplBase) GetOptionDouble(key string) (float64, error) { + return 0, base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'", ConnectionMessageOptionUnknown, key) +} + +func (base *ConnectionImplBase) GetOptionInt(key string) (int64, error) { + return 0, base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'", ConnectionMessageOptionUnknown, key) +} + +func (base *ConnectionImplBase) SetOption(key string, val string) error { + switch key { + case adbc.OptionKeyAutoCommit: + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'", ConnectionMessageOptionUnsupported, key) + } + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'", ConnectionMessageOptionUnknown, key) +} + +func (base *ConnectionImplBase) SetOptionBytes(key string, val []byte) error { + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'", ConnectionMessageOptionUnknown, key) +} + +func (base *ConnectionImplBase) SetOptionDouble(key string, val float64) error { + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'", ConnectionMessageOptionUnknown, key) +} + +func (base *ConnectionImplBase) SetOptionInt(key string, val int64) error { + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'", ConnectionMessageOptionUnknown, key) +} + +type connection struct { + ConnectionImpl + + dbObjectsEnumerator DbObjectsEnumerator + currentNamespacer CurrentNamespacer + driverInfoPreparer DriverInfoPreparer + tableTypeLister TableTypeLister + autocommitSetter AutocommitSetter 
+} + +type ConnectionBuilder struct { + connection *connection +} + +func NewConnectionBuilder(impl ConnectionImpl) *ConnectionBuilder { + return &ConnectionBuilder{connection: &connection{ConnectionImpl: impl}} +} + +func (b *ConnectionBuilder) WithDbObjectsEnumerator(helper DbObjectsEnumerator) *ConnectionBuilder { + if b == nil || b.connection == nil { + panic("nil ConnectionBuilder: cannot reuse after calling Connection()") + } + b.connection.dbObjectsEnumerator = helper + return b +} + +func (b *ConnectionBuilder) WithCurrentNamespacer(helper CurrentNamespacer) *ConnectionBuilder { + if b == nil || b.connection == nil { + panic("nil ConnectionBuilder: cannot reuse after calling Connection()") + } + b.connection.currentNamespacer = helper + return b +} + +func (b *ConnectionBuilder) WithDriverInfoPreparer(helper DriverInfoPreparer) *ConnectionBuilder { + if b == nil || b.connection == nil { + panic("nil ConnectionBuilder: cannot reuse after calling Connection()") + } + b.connection.driverInfoPreparer = helper + return b +} + +func (b *ConnectionBuilder) WithAutocommitSetter(helper AutocommitSetter) *ConnectionBuilder { + if b == nil || b.connection == nil { + panic("nil ConnectionBuilder: cannot reuse after calling Connection()") + } + b.connection.autocommitSetter = helper + return b +} + +func (b *ConnectionBuilder) WithTableTypeLister(helper TableTypeLister) *ConnectionBuilder { + if b == nil || b.connection == nil { + panic("nil ConnectionBuilder: cannot reuse after calling Connection()") + } + b.connection.tableTypeLister = helper + return b +} + +func (b *ConnectionBuilder) Connection() Connection { + conn := b.connection + b.connection = nil + return conn +} + +// GetObjects implements Connection. 
+func (cnxn *connection) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) { + helper := cnxn.dbObjectsEnumerator + + // If the dbObjectsEnumerator has not been set, then the driver implementor has elected to provide their own GetObjects implementation + if helper == nil { + return cnxn.ConnectionImpl.GetObjects(ctx, depth, catalog, dbSchema, tableName, columnName, tableType) + } + + // To avoid an N+1 query problem, we assume result sets here will fit in memory and build up a single response. + g := internal.GetObjects{Ctx: ctx, Depth: depth, Catalog: catalog, DbSchema: dbSchema, TableName: tableName, ColumnName: columnName, TableType: tableType} + if err := g.Init(cnxn.Base().Alloc, helper.GetObjectsDbSchemas, helper.GetObjectsTables); err != nil { + return nil, err + } + defer g.Release() + + catalogs, err := helper.GetObjectsCatalogs(ctx, catalog) + if err != nil { + return nil, err + } + + foundCatalog := false + for _, catalog := range catalogs { + g.AppendCatalog(catalog) + foundCatalog = true + } + + // Implementations like Dremio report no catalogs, but still have schemas + if !foundCatalog && depth != adbc.ObjectDepthCatalogs { + g.AppendCatalog("") + } + return g.Finish() +} + +func (cnxn *connection) GetOption(key string) (string, error) { + switch key { + case adbc.OptionKeyAutoCommit: + if cnxn.Base().Autocommit { + return adbc.OptionValueEnabled, nil + } else { + return adbc.OptionValueDisabled, nil + } + case adbc.OptionKeyCurrentCatalog: + if cnxn.currentNamespacer != nil { + val, err := cnxn.currentNamespacer.GetCurrentCatalog() + if err != nil { + return "", cnxn.Base().ErrorHelper.Errorf(adbc.StatusNotFound, "failed to get current catalog: %s", err) + } + return val, nil + } + case adbc.OptionKeyCurrentDbSchema: + if cnxn.currentNamespacer != nil { + val, err := cnxn.currentNamespacer.GetCurrentDbSchema() + if err != 
nil { + return "", cnxn.Base().ErrorHelper.Errorf(adbc.StatusNotFound, "failed to get current db schema: %s", err) + } + return val, nil + } + } + return cnxn.ConnectionImpl.GetOption(key) +} + +func (cnxn *connection) SetOption(key string, val string) error { + switch key { + case adbc.OptionKeyAutoCommit: + if cnxn.autocommitSetter != nil { + + var autocommit bool + switch val { + case adbc.OptionValueEnabled: + autocommit = true + case adbc.OptionValueDisabled: + autocommit = false + default: + return cnxn.Base().ErrorHelper.Errorf(adbc.StatusInvalidArgument, "cannot set value %s for key %s", val, key) + } + + err := cnxn.autocommitSetter.SetAutocommit(autocommit) + if err == nil { + // Only update the driver state if the action was successful + cnxn.Base().Autocommit = autocommit + } + + return err + } + case adbc.OptionKeyCurrentCatalog: + if cnxn.currentNamespacer != nil { + return cnxn.currentNamespacer.SetCurrentCatalog(val) + } + case adbc.OptionKeyCurrentDbSchema: + if cnxn.currentNamespacer != nil { + return cnxn.currentNamespacer.SetCurrentDbSchema(val) + } + } + return cnxn.ConnectionImpl.SetOption(key, val) +} + +func (cnxn *connection) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.RecordReader, error) { + if cnxn.driverInfoPreparer != nil { + if err := cnxn.driverInfoPreparer.PrepareDriverInfo(ctx, infoCodes); err != nil { + return nil, err + } + } + + return cnxn.Base().GetInfo(ctx, infoCodes) +} + +func (cnxn *connection) GetTableTypes(ctx context.Context) (array.RecordReader, error) { + if cnxn.tableTypeLister == nil { + return cnxn.ConnectionImpl.GetTableTypes(ctx) + } + + tableTypes, err := cnxn.tableTypeLister.ListTableTypes(ctx) + if err != nil { + return nil, err + } + + bldr := array.NewRecordBuilder(cnxn.Base().Alloc, adbc.TableTypesSchema) + defer bldr.Release() + + bldr.Field(0).(*array.StringBuilder).AppendValues(tableTypes, nil) + final := bldr.NewRecord() + defer final.Release() + return 
array.NewRecordReader(adbc.TableTypesSchema, []arrow.Record{final}) +} + +func (cnxn *connection) Commit(ctx context.Context) error { + if cnxn.Base().Autocommit { + return cnxn.Base().ErrorHelper.Errorf(adbc.StatusInvalidState, ConnectionMessageCannotCommit) + } + return cnxn.ConnectionImpl.Commit(ctx) +} + +func (cnxn *connection) Rollback(ctx context.Context) error { + if cnxn.Base().Autocommit { + return cnxn.Base().ErrorHelper.Errorf(adbc.StatusInvalidState, ConnectionMessageCannotRollback) + } + return cnxn.ConnectionImpl.Rollback(ctx) +} + +func (cnxn *connection) Close() error { + if cnxn.Base().Closed { + return cnxn.Base().ErrorHelper.Errorf(adbc.StatusInvalidState, "Trying to close already closed connection") + } + + err := cnxn.ConnectionImpl.Close() + if err == nil { + cnxn.Base().Closed = true + } + + return err +} + +var _ ConnectionImpl = (*ConnectionImplBase)(nil) diff --git a/go/adbc/driver/driverbase/database.go b/go/adbc/driver/internal/driverbase/database.go similarity index 52% rename from go/adbc/driver/driverbase/database.go rename to go/adbc/driver/internal/driverbase/database.go index b08b77fcaa..9ab00967a5 100644 --- a/go/adbc/driver/driverbase/database.go +++ b/go/adbc/driver/internal/driverbase/database.go @@ -25,14 +25,24 @@ import ( "golang.org/x/exp/slog" ) +const ( + DatabaseMessageOptionUnknown = "Unknown database option" +) + // DatabaseImpl is an interface that drivers implement to provide // vendor-specific functionality. type DatabaseImpl interface { + adbc.Database adbc.GetSetOptions Base() *DatabaseImplBase - Open(context.Context) (adbc.Connection, error) - Close() error - SetOptions(map[string]string) error +} + +// Database is the interface satisfied by the result of the NewDatabase constructor, +// given an input is provided satisfying the DatabaseImpl interface. 
+type Database interface { + adbc.Database + adbc.GetSetOptions + adbc.DatabaseLogging } // DatabaseImplBase is a struct that provides default implementations of the @@ -41,14 +51,16 @@ type DatabaseImpl interface { type DatabaseImplBase struct { Alloc memory.Allocator ErrorHelper ErrorHelper + DriverInfo *DriverInfo Logger *slog.Logger } -// NewDatabaseImplBase instantiates DatabaseImplBase. name is the driver's -// name and is used to construct error messages. alloc is an Arrow allocator -// to use. +// NewDatabaseImplBase instantiates DatabaseImplBase. +// +// - driver is a DriverImplBase containing the common resources from the parent +// driver, allowing the Arrow allocator and error handler to be reused. func NewDatabaseImplBase(driver *DriverImplBase) DatabaseImplBase { - return DatabaseImplBase{Alloc: driver.Alloc, ErrorHelper: driver.ErrorHelper, Logger: nilLogger()} + return DatabaseImplBase{Alloc: driver.Alloc, ErrorHelper: driver.ErrorHelper, DriverInfo: driver.DriverInfo, Logger: nilLogger()} } func (base *DatabaseImplBase) Base() *DatabaseImplBase { @@ -56,97 +68,72 @@ func (base *DatabaseImplBase) Base() *DatabaseImplBase { } func (base *DatabaseImplBase) GetOption(key string) (string, error) { - return "", base.ErrorHelper.Errorf(adbc.StatusNotFound, "Unknown database option '%s'", key) + return "", base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'", DatabaseMessageOptionUnknown, key) } func (base *DatabaseImplBase) GetOptionBytes(key string) ([]byte, error) { - return nil, base.ErrorHelper.Errorf(adbc.StatusNotFound, "Unknown database option '%s'", key) + return nil, base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'", DatabaseMessageOptionUnknown, key) } func (base *DatabaseImplBase) GetOptionDouble(key string) (float64, error) { - return 0, base.ErrorHelper.Errorf(adbc.StatusNotFound, "Unknown database option '%s'", key) + return 0, base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'", DatabaseMessageOptionUnknown, key) } func (base 
*DatabaseImplBase) GetOptionInt(key string) (int64, error) { - return 0, base.ErrorHelper.Errorf(adbc.StatusNotFound, "Unknown database option '%s'", key) + return 0, base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'", DatabaseMessageOptionUnknown, key) } func (base *DatabaseImplBase) SetOption(key string, val string) error { - return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Unknown database option '%s'", key) + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'", DatabaseMessageOptionUnknown, key) } func (base *DatabaseImplBase) SetOptionBytes(key string, val []byte) error { - return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Unknown database option '%s'", key) + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'", DatabaseMessageOptionUnknown, key) } func (base *DatabaseImplBase) SetOptionDouble(key string, val float64) error { - return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Unknown database option '%s'", key) + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'", DatabaseMessageOptionUnknown, key) } func (base *DatabaseImplBase) SetOptionInt(key string, val int64) error { - return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Unknown database option '%s'", key) -} - -// database is the implementation of adbc.Database. -type database struct { - impl DatabaseImpl -} - -// NewDatabase wraps a DatabaseImpl to create an adbc.Database. 
-func NewDatabase(impl DatabaseImpl) adbc.Database { - return &database{ - impl: impl, - } -} - -func (db *database) GetOption(key string) (string, error) { - return db.impl.GetOption(key) -} - -func (db *database) GetOptionBytes(key string) ([]byte, error) { - return db.impl.GetOptionBytes(key) -} - -func (db *database) GetOptionDouble(key string) (float64, error) { - return db.impl.GetOptionDouble(key) -} - -func (db *database) GetOptionInt(key string) (int64, error) { - return db.impl.GetOptionInt(key) -} - -func (db *database) SetOption(key string, val string) error { - return db.impl.SetOption(key, val) + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'", DatabaseMessageOptionUnknown, key) } -func (db *database) SetOptionBytes(key string, val []byte) error { - return db.impl.SetOptionBytes(key, val) +func (base *DatabaseImplBase) Close() error { + return nil } -func (db *database) SetOptionDouble(key string, val float64) error { - return db.impl.SetOptionDouble(key, val) +func (base *DatabaseImplBase) Open(ctx context.Context) (adbc.Connection, error) { + return nil, base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Open") } -func (db *database) SetOptionInt(key string, val int64) error { - return db.impl.SetOptionInt(key, val) +func (base *DatabaseImplBase) SetOptions(options map[string]string) error { + for key, val := range options { + if err := base.SetOption(key, val); err != nil { + return err + } + } + return nil } -func (db *database) Open(ctx context.Context) (adbc.Connection, error) { - return db.impl.Open(ctx) +// database is the implementation of adbc.Database. +type database struct { + DatabaseImpl } -func (db *database) Close() error { - return db.impl.Close() +// NewDatabase wraps a DatabaseImpl to create an adbc.Database. 
+func NewDatabase(impl DatabaseImpl) Database { + return &database{ + DatabaseImpl: impl, + } } func (db *database) SetLogger(logger *slog.Logger) { if logger != nil { - db.impl.Base().Logger = logger + db.Base().Logger = logger } else { - db.impl.Base().Logger = nilLogger() + db.Base().Logger = nilLogger() } } -func (db *database) SetOptions(opts map[string]string) error { - return db.impl.SetOptions(opts) -} +var _ DatabaseImpl = (*DatabaseImplBase)(nil) diff --git a/go/adbc/driver/internal/driverbase/driver.go b/go/adbc/driver/internal/driverbase/driver.go new file mode 100644 index 0000000000..bd3e11c086 --- /dev/null +++ b/go/adbc/driver/internal/driverbase/driver.go @@ -0,0 +1,116 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Package driverbase provides a framework for implementing ADBC drivers in +// Go. It intends to reduce boilerplate for common functionality and managing +// state transitions. 
+package driverbase + +import ( + "runtime/debug" + "strings" + + "github.com/apache/arrow-adbc/go/adbc" + "github.com/apache/arrow/go/v16/arrow/memory" +) + +var ( + infoDriverVersion string + infoDriverArrowVersion string +) + +func init() { + if info, ok := debug.ReadBuildInfo(); ok { + for _, dep := range info.Deps { + switch { + case dep.Path == "github.com/apache/arrow-adbc/go/adbc": + infoDriverVersion = dep.Version + case strings.HasPrefix(dep.Path, "github.com/apache/arrow/go/"): + infoDriverArrowVersion = dep.Version + } + } + } +} + +// DriverImpl is an interface that drivers implement to provide +// vendor-specific functionality. +type DriverImpl interface { + adbc.Driver + Base() *DriverImplBase +} + +// Driver is the interface satisfied by the result of the NewDriver constructor, +// given an input is provided satisfying the DriverImpl interface. +type Driver interface { + adbc.Driver +} + +// DriverImplBase is a struct that provides default implementations of the +// DriverImpl interface. It is meant to be used as a composite struct for a +// driver's DriverImpl implementation. +type DriverImplBase struct { + Alloc memory.Allocator + ErrorHelper ErrorHelper + DriverInfo *DriverInfo +} + +func (base *DriverImplBase) NewDatabase(opts map[string]string) (adbc.Database, error) { + return nil, base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "NewDatabase") +} + +// NewDriverImplBase instantiates DriverImplBase. +// +// - info contains build and vendor info, as well as the name to construct error messages. +// - alloc is an Arrow allocator to use. 
+func NewDriverImplBase(info *DriverInfo, alloc memory.Allocator) DriverImplBase { + if alloc == nil { + alloc = memory.DefaultAllocator + } + + if infoDriverVersion != "" { + if err := info.RegisterInfoCode(adbc.InfoDriverVersion, infoDriverVersion); err != nil { + panic(err) + } + } + + if infoDriverArrowVersion != "" { + if err := info.RegisterInfoCode(adbc.InfoDriverArrowVersion, infoDriverArrowVersion); err != nil { + panic(err) + } + } + + return DriverImplBase{ + Alloc: alloc, + ErrorHelper: ErrorHelper{DriverName: info.GetName()}, + DriverInfo: info, + } +} + +func (base *DriverImplBase) Base() *DriverImplBase { + return base +} + +type driver struct { + DriverImpl +} + +// NewDriver wraps a DriverImpl to create a Driver. +func NewDriver(impl DriverImpl) Driver { + return &driver{DriverImpl: impl} +} + +var _ DriverImpl = (*DriverImplBase)(nil) diff --git a/go/adbc/driver/internal/driverbase/driver_info.go b/go/adbc/driver/internal/driverbase/driver_info.go new file mode 100644 index 0000000000..e68aa16c2c --- /dev/null +++ b/go/adbc/driver/internal/driverbase/driver_info.go @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package driverbase + +import ( + "fmt" + "sort" + + "github.com/apache/arrow-adbc/go/adbc" +) + +const ( + UnknownVersion = "(unknown or development build)" + DefaultInfoDriverADBCVersion = adbc.AdbcVersion1_1_0 +) + +func DefaultDriverInfo(name string) *DriverInfo { + defaultInfoVendorName := name + defaultInfoDriverName := fmt.Sprintf("ADBC %s Driver - Go", name) + + return &DriverInfo{ + name: name, + info: map[adbc.InfoCode]any{ + adbc.InfoVendorName: defaultInfoVendorName, + adbc.InfoDriverName: defaultInfoDriverName, + adbc.InfoDriverVersion: UnknownVersion, + adbc.InfoDriverArrowVersion: UnknownVersion, + adbc.InfoVendorVersion: UnknownVersion, + adbc.InfoVendorArrowVersion: UnknownVersion, + adbc.InfoDriverADBCVersion: DefaultInfoDriverADBCVersion, + }, + } +} + +type DriverInfo struct { + name string + info map[adbc.InfoCode]any +} + +func (di *DriverInfo) GetName() string { return di.name } + +func (di *DriverInfo) InfoSupportedCodes() []adbc.InfoCode { + // The keys of the info map are used to determine which info codes are supported. + // This means that any info codes the driver knows about should be set to some default + // at init, even if we don't know the value yet. + codes := make([]adbc.InfoCode, 0, len(di.info)) + for code := range di.info { + codes = append(codes, code) + } + + // Sorting info codes helps present them to the client in a consistent way. + // It also helps add some determinism to internal tests. + // The ordering is in no way part of the API contract and should not be relied upon. 
+ sort.SliceStable(codes, func(i, j int) bool { + return codes[i] < codes[j] + }) + return codes +} + +func (di *DriverInfo) RegisterInfoCode(code adbc.InfoCode, value any) error { + switch code { + case adbc.InfoVendorName: + if err := ensureType[string](value); err != nil { + return fmt.Errorf("info_code %d: %w", code, err) + } + case adbc.InfoVendorVersion: + if err := ensureType[string](value); err != nil { + return fmt.Errorf("info_code %d: %w", code, err) + } + case adbc.InfoVendorArrowVersion: + if err := ensureType[string](value); err != nil { + return fmt.Errorf("info_code %d: %w", code, err) + } + case adbc.InfoDriverName: + if err := ensureType[string](value); err != nil { + return fmt.Errorf("info_code %d: %w", code, err) + } + case adbc.InfoDriverVersion: + if err := ensureType[string](value); err != nil { + return fmt.Errorf("info_code %d: %w", code, err) + } + case adbc.InfoDriverArrowVersion: + if err := ensureType[string](value); err != nil { + return fmt.Errorf("info_code %d: %w", code, err) + } + case adbc.InfoDriverADBCVersion: + if err := ensureType[int64](value); err != nil { + return fmt.Errorf("info_code %d: %w", code, err) + } + } + + di.info[code] = value + return nil +} + +func (di *DriverInfo) GetInfoForInfoCode(code adbc.InfoCode) (any, bool) { + val, ok := di.info[code] + return val, ok +} + +func (di *DriverInfo) GetInfoVendorName() (string, bool) { + return di.getStringInfoCode(adbc.InfoVendorName) +} + +func (di *DriverInfo) GetInfoVendorVersion() (string, bool) { + return di.getStringInfoCode(adbc.InfoVendorVersion) +} + +func (di *DriverInfo) GetInfoVendorArrowVersion() (string, bool) { + return di.getStringInfoCode(adbc.InfoVendorArrowVersion) +} + +func (di *DriverInfo) GetInfoDriverName() (string, bool) { + return di.getStringInfoCode(adbc.InfoDriverName) +} + +func (di *DriverInfo) GetInfoDriverVersion() (string, bool) { + return di.getStringInfoCode(adbc.InfoDriverVersion) +} + +func (di *DriverInfo) 
GetInfoDriverArrowVersion() (string, bool) { + return di.getStringInfoCode(adbc.InfoDriverArrowVersion) +} + +func (di *DriverInfo) GetInfoDriverADBCVersion() (int64, bool) { + return di.getInt64InfoCode(adbc.InfoDriverADBCVersion) +} + +func (di *DriverInfo) getStringInfoCode(code adbc.InfoCode) (string, bool) { + val, ok := di.GetInfoForInfoCode(code) + if !ok { + return "", false + } + + if err := ensureType[string](val); err != nil { + panic(err) + } + + return val.(string), true +} + +func (di *DriverInfo) getInt64InfoCode(code adbc.InfoCode) (int64, bool) { + val, ok := di.GetInfoForInfoCode(code) + if !ok { + return int64(0), false + } + + if err := ensureType[int64](val); err != nil { + panic(err) + } + + return val.(int64), true +} + +func ensureType[T any](value any) error { + typedVal, ok := value.(T) + if !ok { + return fmt.Errorf("expected info_value %v to be of type %T but found %T", value, typedVal, value) + } + return nil +} diff --git a/go/adbc/driver/internal/driverbase/driver_info_test.go b/go/adbc/driver/internal/driverbase/driver_info_test.go new file mode 100644 index 0000000000..2bad25d056 --- /dev/null +++ b/go/adbc/driver/internal/driverbase/driver_info_test.go @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +package driverbase_test + +import ( + "strings" + "testing" + + "github.com/apache/arrow-adbc/go/adbc" + "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" + "github.com/stretchr/testify/require" +) + +func TestDriverInfo(t *testing.T) { + driverInfo := driverbase.DefaultDriverInfo("test") + + // The provided name is used for ErrorHelper, certain info code values, etc + require.Equal(t, "test", driverInfo.GetName()) + + // These are the info codes that are set for every driver + expectedDefaultInfoCodes := []adbc.InfoCode{ + adbc.InfoVendorName, + adbc.InfoVendorVersion, + adbc.InfoVendorArrowVersion, + adbc.InfoDriverName, + adbc.InfoDriverVersion, + adbc.InfoDriverArrowVersion, + adbc.InfoDriverADBCVersion, + } + require.ElementsMatch(t, expectedDefaultInfoCodes, driverInfo.InfoSupportedCodes()) + + // We get some formatted default values out of the box + vendorName, ok := driverInfo.GetInfoVendorName() + require.True(t, ok) + require.Equal(t, "test", vendorName) + + driverName, ok := driverInfo.GetInfoDriverName() + require.True(t, ok) + require.Equal(t, "ADBC test Driver - Go", driverName) + + // We can register a string value to an info code that expects a string + require.NoError(t, driverInfo.RegisterInfoCode(adbc.InfoDriverVersion, "string_value")) + + // We cannot register a non-string value to that same info code + err := driverInfo.RegisterInfoCode(adbc.InfoDriverVersion, 123) + require.Error(t, err) + require.Equal(t, "info_code 101: expected info_value 123 to be of type string but found int", err.Error()) + + // We can also set vendor-specific info codes but they won't get type checked + require.NoError(t, driverInfo.RegisterInfoCode(adbc.InfoCode(10_001), "string_value")) + require.NoError(t, driverInfo.RegisterInfoCode(adbc.InfoCode(10_001), 123)) + + // Retrieving known info codes is type-safe + driverVersion, ok := 
driverInfo.GetInfoDriverName() + require.True(t, ok) + require.NotEmpty(t, strings.Clone(driverVersion)) // do string stuff + + adbcVersion, ok := driverInfo.GetInfoDriverADBCVersion() + require.True(t, ok) + require.NotEmpty(t, adbcVersion+int64(123)) // do int64 stuff + + // We can also retrieve arbitrary info codes, but the result's type must be asserted + arrowVersion, ok := driverInfo.GetInfoForInfoCode(adbc.InfoDriverArrowVersion) + require.True(t, ok) + _, ok = arrowVersion.(string) + require.True(t, ok) + + // We can check if info codes have been set or not + _, ok = driverInfo.GetInfoForInfoCode(adbc.InfoCode(10_001)) + require.True(t, ok) + _, ok = driverInfo.GetInfoForInfoCode(adbc.InfoCode(10_002)) + require.False(t, ok) +} diff --git a/go/adbc/driver/internal/driverbase/driver_test.go b/go/adbc/driver/internal/driverbase/driver_test.go new file mode 100644 index 0000000000..f43a049bbe --- /dev/null +++ b/go/adbc/driver/internal/driverbase/driver_test.go @@ -0,0 +1,595 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package driverbase_test + +import ( + "context" + "fmt" + "testing" + + "golang.org/x/exp/slog" + + "github.com/apache/arrow-adbc/go/adbc" + "github.com/apache/arrow-adbc/go/adbc/driver/internal" + "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" + "github.com/apache/arrow/go/v16/arrow" + "github.com/apache/arrow/go/v16/arrow/array" + "github.com/apache/arrow/go/v16/arrow/memory" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +const ( + OptionKeyRecognized = "recognized" + OptionKeyUnrecognized = "unrecognized" +) + +// NewDriver creates a new adbc.Driver for testing. In addition to a memory.Allocator, it takes +// a slog.Handler to use for all structured logging as well as a useHelpers flag to determine whether +// the test should register helper methods or use the default driverbase implementation. +func NewDriver(alloc memory.Allocator, handler slog.Handler, useHelpers bool) adbc.Driver { + info := driverbase.DefaultDriverInfo("MockDriver") + _ = info.RegisterInfoCode(adbc.InfoCode(10_001), "my custom info") + return driverbase.NewDriver(&driverImpl{DriverImplBase: driverbase.NewDriverImplBase(info, alloc), handler: handler, useHelpers: useHelpers}) +} + +func TestDefaultDriver(t *testing.T) { + var handler MockedHandler + handler.On("Handle", mock.Anything, mock.Anything).Return(nil) + + ctx := context.TODO() + alloc := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer alloc.AssertSize(t, 0) + + drv := NewDriver(alloc, &handler, false) // Do not use helper implementations; only default behavior + + db, err := drv.NewDatabase(nil) + require.NoError(t, err) + defer db.Close() + + require.NoError(t, db.SetOptions(map[string]string{OptionKeyRecognized: "should-pass"})) + + err = db.SetOptions(map[string]string{OptionKeyUnrecognized: "should-fail"}) + require.Error(t, err) + require.Equal(t, "Not Implemented: [MockDriver] Unknown database option 'unrecognized'", err.Error()) + + cnxn, err := 
db.Open(ctx) + require.NoError(t, err) + defer func() { + // Cannot close more than once + require.NoError(t, cnxn.Close()) + require.Error(t, cnxn.Close()) + }() + + err = cnxn.Commit(ctx) + require.Error(t, err) + require.Equal(t, "Invalid State: [MockDriver] Cannot commit when autocommit is enabled", err.Error()) + + err = cnxn.Rollback(ctx) + require.Error(t, err) + require.Equal(t, "Invalid State: [MockDriver] Cannot rollback when autocommit is enabled", err.Error()) + + info, err := cnxn.GetInfo(ctx, nil) + require.NoError(t, err) + getInfoTable := tableFromRecordReader(info) + defer getInfoTable.Release() + + // This is what the driverbase provided GetInfo result should look like out of the box, + // with one custom setting registered at initialization + expectedGetInfoTable, err := array.TableFromJSON(alloc, adbc.GetInfoSchema, []string{`[ + { + "info_name": 0, + "info_value": [0, "MockDriver"] + }, + { + "info_name": 1, + "info_value": [0, "(unknown or development build)"] + }, + { + "info_name": 2, + "info_value": [0, "(unknown or development build)"] + }, + { + "info_name": 100, + "info_value": [0, "ADBC MockDriver Driver - Go"] + }, + { + "info_name": 101, + "info_value": [0, "(unknown or development build)"] + }, + { + "info_name": 102, + "info_value": [0, "(unknown or development build)"] + }, + { + "info_name": 103, + "info_value": [2, 1001000] + }, + { + "info_name": 10001, + "info_value": [0, "my custom info"] + } + ]`}) + require.NoError(t, err) + defer expectedGetInfoTable.Release() + + require.Truef(t, array.TableEqual(expectedGetInfoTable, getInfoTable), "expected: %s\ngot: %s", expectedGetInfoTable, getInfoTable) + + _, err = cnxn.GetObjects(ctx, adbc.ObjectDepthAll, nil, nil, nil, nil, nil) + require.Error(t, err) + require.Equal(t, "Not Implemented: [MockDriver] GetObjects", err.Error()) + + _, err = cnxn.GetTableTypes(ctx) + require.Error(t, err) + require.Equal(t, "Not Implemented: [MockDriver] GetTableTypes", err.Error()) + + autocommit, 
err := cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyAutoCommit) + require.NoError(t, err) + require.Equal(t, adbc.OptionValueEnabled, autocommit) + + err = cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyAutoCommit, "false") + require.Error(t, err) + require.Equal(t, "Not Implemented: [MockDriver] Unsupported connection option 'adbc.connection.autocommit'", err.Error()) + + _, err = cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentCatalog) + require.Error(t, err) + require.Equal(t, "Not Found: [MockDriver] Unknown connection option 'adbc.connection.catalog'", err.Error()) + + err = cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyCurrentCatalog, "test_catalog") + require.Error(t, err) + require.Equal(t, "Not Implemented: [MockDriver] Unknown connection option 'adbc.connection.catalog'", err.Error()) + + // We passed a mock handler into the driver to use for logs, so we can check actual messages logged + expectedLogMessages := []logMessage{ + {Message: "Opening a new connection", Level: "INFO", Attrs: map[string]string{"withHelpers": "false"}}, + } + + logMessages := make([]logMessage, 0, len(handler.Calls)) + for _, call := range handler.Calls { + sr, ok := call.Arguments.Get(1).(slog.Record) + require.True(t, ok) + logMessages = append(logMessages, newLogMessage(sr)) + } + + for _, expected := range expectedLogMessages { + var found bool + for _, message := range logMessages { + if messagesEqual(message, expected) { + found = true + break + } + } + require.Truef(t, found, "expected message was never logged: %v", expected) + } + +} + +func TestCustomizedDriver(t *testing.T) { + var handler MockedHandler + handler.On("Handle", mock.Anything, mock.Anything).Return(nil) + + ctx := context.TODO() + alloc := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer alloc.AssertSize(t, 0) + + drv := NewDriver(alloc, &handler, true) // Use helper implementations + + db, err := drv.NewDatabase(nil) + require.NoError(t, err) + defer db.Close() + + 
require.NoError(t, db.SetOptions(map[string]string{OptionKeyRecognized: "should-pass"})) + + err = db.SetOptions(map[string]string{OptionKeyUnrecognized: "should-fail"}) + require.Error(t, err) + require.Equal(t, "Not Implemented: [MockDriver] Unknown database option 'unrecognized'", err.Error()) + + cnxn, err := db.Open(ctx) + require.NoError(t, err) + defer cnxn.Close() + + err = cnxn.Commit(ctx) + require.Error(t, err) + require.Equal(t, "Invalid State: [MockDriver] Cannot commit when autocommit is enabled", err.Error()) + + err = cnxn.Rollback(ctx) + require.Error(t, err) + require.Equal(t, "Invalid State: [MockDriver] Cannot rollback when autocommit is enabled", err.Error()) + + info, err := cnxn.GetInfo(ctx, nil) + require.NoError(t, err) + getInfoTable := tableFromRecordReader(info) + defer getInfoTable.Release() + + // This is the arrow table representation of GetInfo produced by merging: + // - the default DriverInfo set at initialization + // - the DriverInfo set once in the NewDriver constructor + // - the DriverInfo set dynamically when GetInfo is called by implementing DriverInfoPreparer interface + expectedGetInfoTable, err := array.TableFromJSON(alloc, adbc.GetInfoSchema, []string{`[ + { + "info_name": 0, + "info_value": [0, "MockDriver"] + }, + { + "info_name": 1, + "info_value": [0, "(unknown or development build)"] + }, + { + "info_name": 2, + "info_value": [0, "(unknown or development build)"] + }, + { + "info_name": 100, + "info_value": [0, "ADBC MockDriver Driver - Go"] + }, + { + "info_name": 101, + "info_value": [0, "(unknown or development build)"] + }, + { + "info_name": 102, + "info_value": [0, "(unknown or development build)"] + }, + { + "info_name": 103, + "info_value": [2, 1001000] + }, + { + "info_name": 10001, + "info_value": [0, "my custom info"] + }, + { + "info_name": 10002, + "info_value": [0, "this was fetched dynamically"] + } + ]`}) + require.NoError(t, err) + defer expectedGetInfoTable.Release() + + require.Truef(t, 
array.TableEqual(expectedGetInfoTable, getInfoTable), "expected: %s\ngot: %s", expectedGetInfoTable, getInfoTable) + + dbObjects, err := cnxn.GetObjects(ctx, adbc.ObjectDepthAll, nil, nil, nil, nil, nil) + require.NoError(t, err) + dbObjectsTable := tableFromRecordReader(dbObjects) + defer dbObjectsTable.Release() + + // This is the arrow table representation of the GetObjects output we get by implementing + // the simplified TableTypeLister interface + expectedDbObjectsTable, err := array.TableFromJSON(alloc, adbc.GetObjectsSchema, []string{`[ + { + "catalog_name": "default", + "catalog_db_schemas": [ + { + "db_schema_name": "public", + "db_schema_tables": [ + { + "table_name": "foo", + "table_type": "TABLE", + "table_columns": [], + "table_constraints": [] + } + ] + }, + { + "db_schema_name": "test", + "db_schema_tables": [ + { + "table_name": "bar", + "table_type": "TABLE", + "table_columns": [], + "table_constraints": [] + } + ] + } + ] + }, + { + "catalog_name": "my_db", + "catalog_db_schemas": [ + { + "db_schema_name": "public", + "db_schema_tables": [ + { + "table_name": "baz", + "table_type": "TABLE", + "table_columns": [], + "table_constraints": [] + } + ] + } + ] + } + ]`}) + require.NoError(t, err) + defer expectedDbObjectsTable.Release() + + require.Truef(t, array.TableEqual(expectedDbObjectsTable, dbObjectsTable), "expected: %s\ngot: %s", expectedDbObjectsTable, dbObjectsTable) + + tableTypes, err := cnxn.GetTableTypes(ctx) + require.NoError(t, err) + tableTypeTable := tableFromRecordReader(tableTypes) + defer tableTypeTable.Release() + + // This is the arrow table representation of the GetTableTypes output we get by implementing + // the simplified TableTypeLister interface + expectedTableTypesTable, err := array.TableFromJSON(alloc, adbc.TableTypesSchema, []string{`[ + { "table_type": "TABLE" }, + { "table_type": "VIEW" } + ]`}) + require.NoError(t, err) + defer expectedTableTypesTable.Release() + + require.Truef(t, 
array.TableEqual(expectedTableTypesTable, tableTypeTable), "expected: %s\ngot: %s", expectedTableTypesTable, tableTypeTable) + + autocommit, err := cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyAutoCommit) + require.NoError(t, err) + require.Equal(t, adbc.OptionValueEnabled, autocommit) + + // By implementing AutocommitSetter, we are able to successfully toggle autocommit + err = cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyAutoCommit, "false") + require.NoError(t, err) + + // We haven't implemented Commit, but we get NotImplemented instead of InvalidState because + // Autocommit has been explicitly disabled + err = cnxn.Commit(ctx) + require.Error(t, err) + require.Equal(t, "Not Implemented: [MockDriver] Commit", err.Error()) + + // By implementing CurrentNamespacer, we can now get/set the current catalog/dbschema + // Default current(catalog|dbSchema) is driver-specific, but the stub implementation falls back + // to a 'not found' error instead of 'not implemented' + _, err = cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentCatalog) + require.Error(t, err) + require.Equal(t, "Not Found: [MockDriver] failed to get current catalog: current catalog is not set", err.Error()) + + err = cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyCurrentCatalog, "test_catalog") + require.NoError(t, err) + + currentCatalog, err := cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentCatalog) + require.NoError(t, err) + require.Equal(t, "test_catalog", currentCatalog) + + _, err = cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentDbSchema) + require.Error(t, err) + require.Equal(t, "Not Found: [MockDriver] failed to get current db schema: current db schema is not set", err.Error()) + + err = cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyCurrentDbSchema, "test_schema") + require.NoError(t, err) + + currentDbSchema, err := cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentDbSchema) + require.NoError(t, err) + require.Equal(t, "test_schema", 
currentDbSchema) + + // We passed a mock handler into the driver to use for logs, so we can check actual messages logged + expectedLogMessages := []logMessage{ + {Message: "Opening a new connection", Level: "INFO", Attrs: map[string]string{"withHelpers": "true"}}, + {Message: "SetAutocommit", Level: "DEBUG", Attrs: map[string]string{"enabled": "false"}}, + {Message: "SetCurrentCatalog", Level: "DEBUG", Attrs: map[string]string{"val": "test_catalog"}}, + {Message: "SetCurrentDbSchema", Level: "DEBUG", Attrs: map[string]string{"val": "test_schema"}}, + } + + logMessages := make([]logMessage, 0, len(handler.Calls)) + for _, call := range handler.Calls { + sr, ok := call.Arguments.Get(1).(slog.Record) + require.True(t, ok) + logMessages = append(logMessages, newLogMessage(sr)) + } + + for _, expected := range expectedLogMessages { + var found bool + for _, message := range logMessages { + if messagesEqual(message, expected) { + found = true + break + } + } + require.Truef(t, found, "expected message was never logged: %v", expected) + } +} + +type driverImpl struct { + driverbase.DriverImplBase + + handler slog.Handler + useHelpers bool +} + +func (drv *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error) { + db := driverbase.NewDatabase( + &databaseImpl{DatabaseImplBase: driverbase.NewDatabaseImplBase(&drv.DriverImplBase), + drv: drv, + useHelpers: drv.useHelpers, + }) + db.SetLogger(slog.New(drv.handler)) + return db, nil +} + +type databaseImpl struct { + driverbase.DatabaseImplBase + drv *driverImpl + + useHelpers bool +} + +// SetOptions implements adbc.Database. +func (d *databaseImpl) SetOptions(options map[string]string) error { + for k, v := range options { + if err := d.SetOption(k, v); err != nil { + return err + } + } + return nil +} + +// Only need to implement keys we recognize. +// Any other values will fallthrough to default failure message. 
+func (d *databaseImpl) SetOption(key, value string) error { + switch key { + case OptionKeyRecognized: + _ = value // pretend to recognize the setting + return nil + } + return d.DatabaseImplBase.SetOption(key, value) +} + +func (db *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) { + db.DatabaseImplBase.Logger.Info("Opening a new connection", "withHelpers", db.useHelpers) + cnxn := &connectionImpl{ConnectionImplBase: driverbase.NewConnectionImplBase(&db.DatabaseImplBase), db: db} + bldr := driverbase.NewConnectionBuilder(cnxn) + if db.useHelpers { // this toggles between the TestDefaultDriver and TestCustomizedDriver scenarios + return bldr. + WithAutocommitSetter(cnxn). + WithCurrentNamespacer(cnxn). + WithTableTypeLister(cnxn). + WithDriverInfoPreparer(cnxn). + WithDbObjectsEnumerator(cnxn). + Connection(), nil + } + return bldr.Connection(), nil +} + +type connectionImpl struct { + driverbase.ConnectionImplBase + db *databaseImpl + + currentCatalog string + currentDbSchema string +} + +func (c *connectionImpl) SetAutocommit(enabled bool) error { + c.Base().Logger.Debug("SetAutocommit", "enabled", enabled) + return nil +} + +func (c *connectionImpl) GetCurrentCatalog() (string, error) { + if c.currentCatalog == "" { + return "", fmt.Errorf("current catalog is not set") + } + return c.currentCatalog, nil +} + +func (c *connectionImpl) GetCurrentDbSchema() (string, error) { + if c.currentDbSchema == "" { + return "", fmt.Errorf("current db schema is not set") + } + return c.currentDbSchema, nil +} + +func (c *connectionImpl) SetCurrentCatalog(val string) error { + c.Base().Logger.Debug("SetCurrentCatalog", "val", val) + c.currentCatalog = val + return nil +} + +func (c *connectionImpl) SetCurrentDbSchema(val string) error { + c.Base().Logger.Debug("SetCurrentDbSchema", "val", val) + c.currentDbSchema = val + return nil +} + +func (c *connectionImpl) ListTableTypes(ctx context.Context) ([]string, error) { + return []string{"TABLE", "VIEW"}, nil +} +
+// PrepareDriverInfo registers info code 10_002 with a fixed string on the
+// connection's DriverInfo. The requested infoCodes are not consulted.
+func (c *connectionImpl) PrepareDriverInfo(ctx context.Context, infoCodes []adbc.InfoCode) error {
+	return c.ConnectionImplBase.DriverInfo.RegisterInfoCode(adbc.InfoCode(10_002), "this was fetched dynamically")
+}
+
+// GetObjectsCatalogs returns the two fixed catalog names used by the tests;
+// the catalog argument is ignored.
+func (c *connectionImpl) GetObjectsCatalogs(ctx context.Context, catalog *string) ([]string, error) {
+	return []string{"default", "my_db"}, nil
+}
+
+// GetObjectsDbSchemas returns a fixed catalog -> schemas mapping; none of
+// the arguments are consulted.
+func (c *connectionImpl) GetObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string, metadataRecords []internal.Metadata) (map[string][]string, error) {
+	return map[string][]string{
+		"default": {"public", "test"},
+		"my_db":   {"public"},
+	}, nil
+}
+
+// GetObjectsTables returns one fixed table per (catalog, schema) pair; none
+// of the arguments are consulted.
+func (c *connectionImpl) GetObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string, tableName *string, columnName *string, tableType []string, metadataRecords []internal.Metadata) (map[internal.CatalogAndSchema][]internal.TableInfo, error) {
+	return map[internal.CatalogAndSchema][]internal.TableInfo{
+		{Catalog: "default", Schema: "public"}: {internal.TableInfo{Name: "foo", TableType: "TABLE"}},
+		{Catalog: "default", Schema: "test"}:   {internal.TableInfo{Name: "bar", TableType: "TABLE"}},
+		{Catalog: "my_db", Schema: "public"}:   {internal.TableInfo{Name: "baz", TableType: "TABLE"}},
+	}, nil
+}
+
+// MockedHandler is a mock.Mock that implements the slog.Handler interface.
+// It is used to assert specific behavior for loggers it is injected into.
+type MockedHandler struct {
+	mock.Mock
+}
+
+// Enabled reports true for every level so that no record is filtered out.
+func (h *MockedHandler) Enabled(ctx context.Context, level slog.Level) bool { return true }
+
+// WithAttrs returns the receiver unchanged; attrs are not recorded.
+func (h *MockedHandler) WithAttrs(attrs []slog.Attr) slog.Handler { return h }
+
+// WithGroup returns the receiver unchanged; the group name is not recorded.
+func (h *MockedHandler) WithGroup(name string) slog.Handler { return h }
+
+// Handle forwards the call to the mock and returns the error configured as
+// its first return value.
+func (h *MockedHandler) Handle(ctx context.Context, r slog.Record) error {
+	// We only care to assert the message value, and want to isolate nondeterministic behavior (e.g. timestamp)
+	args := h.Called(ctx, r)
+	return args.Error(0)
+}
+
+// logMessage is a container for log attributes we would like to compare for equality during tests.
+// It intentionally omits timestamps and other sources of nondeterminism.
+type logMessage struct {
+	Message string
+	Level   string
+	Attrs   map[string]string
+}
+
+// newLogMessage constructs a logMessage from a slog.Record, containing only deterministic fields.
+// Attribute values are stored in their string form, keyed by attribute key.
+func newLogMessage(r slog.Record) logMessage {
+	message := logMessage{Message: r.Message, Level: r.Level.String(), Attrs: make(map[string]string)}
+	r.Attrs(func(a slog.Attr) bool {
+		message.Attrs[a.Key] = a.Value.String()
+		return true
+	})
+	return message
+}
+
+// messagesEqual compares two logMessages and returns whether they are equal.
+//
+// Messages and levels must match, and both attribute maps must hold exactly
+// the same keys with the same values.
+func messagesEqual(expected, actual logMessage) bool {
+	if expected.Message != actual.Message {
+		return false
+	}
+	if expected.Level != actual.Level {
+		return false
+	}
+	if len(expected.Attrs) != len(actual.Attrs) {
+		return false
+	}
+	for k, v := range expected.Attrs {
+		// Use the comma-ok form: a plain lookup yields "" for an absent key,
+		// which would wrongly match an expected empty-string value.
+		if got, ok := actual.Attrs[k]; !ok || got != v {
+			return false
+		}
+	}
+	return true
+}
+
+// tableFromRecordReader drains rdr into a table and releases rdr before
+// returning.
+//
+// Every record is retained while it is collected; the deferred releases run
+// when the function returns, i.e. after the table has been built.
+// NOTE(review): an error reported by the reader itself is not surfaced here.
+func tableFromRecordReader(rdr array.RecordReader) arrow.Table {
+	defer rdr.Release()
+
+	recs := make([]arrow.Record, 0)
+	for rdr.Next() {
+		rec := rdr.Record()
+		rec.Retain()
+		defer rec.Release()
+		recs = append(recs, rec)
+	}
+	return array.NewTableFromRecords(rdr.Schema(), recs)
+}
diff --git a/go/adbc/driver/driverbase/error.go b/go/adbc/driver/internal/driverbase/error.go
similarity index 100%
rename from go/adbc/driver/driverbase/error.go
rename to go/adbc/driver/internal/driverbase/error.go
diff --git a/go/adbc/driver/driverbase/logging.go b/go/adbc/driver/internal/driverbase/logging.go
similarity index 100%
rename from go/adbc/driver/driverbase/logging.go
rename to go/adbc/driver/internal/driverbase/logging.go
diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go
index 8252665070..4b023b0505 100644
---
a/go/adbc/driver/snowflake/connection.go
+++ b/go/adbc/driver/snowflake/connection.go
@@ -30,6 +30,7 @@ import (
 
 	"github.com/apache/arrow-adbc/go/adbc"
 	"github.com/apache/arrow-adbc/go/adbc/driver/internal"
+	"github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
 	"github.com/apache/arrow/go/v16/arrow"
 	"github.com/apache/arrow/go/v16/arrow/array"
 	"github.com/snowflakedb/gosnowflake"
@@ -50,7 +51,9 @@ type snowflakeConn interface {
 	QueryArrowStream(context.Context, string, ...driver.NamedValue) (gosnowflake.ArrowStreamLoader, error)
 }
 
-type cnxn struct {
+type connectionImpl struct {
+	driverbase.ConnectionImplBase
+
 	cn snowflakeConn
 	db *databaseImpl
 	ctor gosnowflake.Connector
@@ -60,6 +63,79 @@ type cnxn struct {
 	useHighPrecision bool
 }
 
+// ListTableTypes implements driverbase.TableTypeLister.
+//
+// It returns a fixed list of table type names and never fails.
+func (*connectionImpl) ListTableTypes(ctx context.Context) ([]string, error) {
+	return []string{"BASE TABLE", "TEMPORARY TABLE", "VIEW"}, nil
+}
+
+// GetCurrentCatalog implements driverbase.CurrentNamespacer by running
+// SELECT CURRENT_DATABASE() through getStringQuery.
+func (c *connectionImpl) GetCurrentCatalog() (string, error) {
+	return c.getStringQuery("SELECT CURRENT_DATABASE()")
+}
+
+// GetCurrentDbSchema implements driverbase.CurrentNamespacer by running
+// SELECT CURRENT_SCHEMA() through getStringQuery.
+func (c *connectionImpl) GetCurrentDbSchema() (string, error) {
+	return c.getStringQuery("SELECT CURRENT_SCHEMA()")
+}
+
+// SetCurrentCatalog implements driverbase.CurrentNamespacer.
+//
+// value is passed as the bind argument of a USE DATABASE statement, and the
+// error from executing it is returned as-is.
+func (c *connectionImpl) SetCurrentCatalog(value string) error {
+	_, err := c.cn.ExecContext(context.Background(), "USE DATABASE ?", []driver.NamedValue{{Value: value}})
+	return err
+}
+
+// SetCurrentDbSchema implements driverbase.CurrentNamespacer.
+//
+// value is passed as the bind argument of a USE SCHEMA statement, and the
+// error from executing it is returned as-is.
+func (c *connectionImpl) SetCurrentDbSchema(value string) error {
+	_, err := c.cn.ExecContext(context.Background(), "USE SCHEMA ?", []driver.NamedValue{{Value: value}})
+	return err
+}
+
+// SetAutocommit implements driverbase.AutocommitSetter.
+//
+// Enabling autocommit commits the active transaction, if any, and then sets
+// the session's AUTOCOMMIT parameter to true. Disabling it issues BEGIN
+// unless a transaction is already active and then sets the parameter to
+// false. Any statement failure is returned through errToAdbcErr with
+// adbc.StatusInternal.
+func (c *connectionImpl) SetAutocommit(enabled bool) error {
+	// This method receives no context, hence context.Background().
+	ctx := context.Background()
+
+	if enabled {
+		if c.activeTransaction {
+			if _, err := c.cn.ExecContext(ctx, "COMMIT", nil); err != nil {
+				return errToAdbcErr(adbc.StatusInternal, err)
+			}
+			c.activeTransaction = false
+		}
+		if _, err := c.cn.ExecContext(ctx, "ALTER SESSION SET AUTOCOMMIT = true", nil); err != nil {
+			return errToAdbcErr(adbc.StatusInternal, err)
+		}
+		return nil
+	}
+
+	if !c.activeTransaction {
+		if _, err := c.cn.ExecContext(ctx, "BEGIN", nil); err != nil {
+			return errToAdbcErr(adbc.StatusInternal, err)
+		}
+		c.activeTransaction = true
+	}
+	if _, err := c.cn.ExecContext(ctx, "ALTER SESSION SET AUTOCOMMIT = false", nil); err != nil {
+		return errToAdbcErr(adbc.StatusInternal, err)
+	}
+	return nil
+}
+
 // Metadata methods
 // Generally these methods return an array.RecordReader that
 // can be consumed to retrieve metadata about the database as Arrow
@@ -77,80 +153,6 @@ type cnxn struct {
 // characters, or "_" to match exactly one character. (See the
 // documentation of DatabaseMetaData in JDBC or "Pattern Value Arguments"
 // in the ODBC documentation.) Escaping is not currently supported.
-// GetInfo returns metadata about the database/driver.
-//
-// The result is an Arrow dataset with the following schema:
-//
-// Field Name | Field Type
-// ----------------------------|-----------------------------
-// info_name | uint32 not null
-// info_value | INFO_SCHEMA
-//
-// INFO_SCHEMA is a dense union with members:
-//
-// Field Name (Type Code) | Field Type
-// ----------------------------|-----------------------------
-// string_value (0) | utf8
-// bool_value (1) | bool
-// int64_value (2) | int64
-// int32_bitmask (3) | int32
-// string_list (4) | list
-// int32_to_int32_list_map (5) | map>
-//
-// Each metadatum is identified by an integer code. The recognized
-// codes are defined as constants. Codes [0, 10_000) are reserved
-// for ADBC usage. Drivers/vendors will ignore requests for unrecognized
-// codes (the row will be omitted from the result).
-func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.RecordReader, error) { - const strValTypeID arrow.UnionTypeCode = 0 - const intValTypeID arrow.UnionTypeCode = 2 - - if len(infoCodes) == 0 { - infoCodes = infoSupportedCodes - } - - bldr := array.NewRecordBuilder(c.db.Alloc, adbc.GetInfoSchema) - defer bldr.Release() - bldr.Reserve(len(infoCodes)) - - infoNameBldr := bldr.Field(0).(*array.Uint32Builder) - infoValueBldr := bldr.Field(1).(*array.DenseUnionBuilder) - strInfoBldr := infoValueBldr.Child(int(strValTypeID)).(*array.StringBuilder) - intInfoBldr := infoValueBldr.Child(int(intValTypeID)).(*array.Int64Builder) - - for _, code := range infoCodes { - switch code { - case adbc.InfoDriverName: - infoNameBldr.Append(uint32(code)) - infoValueBldr.Append(strValTypeID) - strInfoBldr.Append(infoDriverName) - case adbc.InfoDriverVersion: - infoNameBldr.Append(uint32(code)) - infoValueBldr.Append(strValTypeID) - strInfoBldr.Append(infoDriverVersion) - case adbc.InfoDriverArrowVersion: - infoNameBldr.Append(uint32(code)) - infoValueBldr.Append(strValTypeID) - strInfoBldr.Append(infoDriverArrowVersion) - case adbc.InfoDriverADBCVersion: - infoNameBldr.Append(uint32(code)) - infoValueBldr.Append(intValTypeID) - intInfoBldr.Append(adbc.AdbcVersion1_1_0) - case adbc.InfoVendorName: - infoNameBldr.Append(uint32(code)) - infoValueBldr.Append(strValTypeID) - strInfoBldr.Append(infoVendorName) - default: - infoNameBldr.Append(uint32(code)) - infoValueBldr.AppendNull() - } - } - - final := bldr.NewRecord() - defer final.Release() - return array.NewRecordReader(adbc.GetInfoSchema, []arrow.Record{final}) -} - // GetObjects gets a hierarchical view of all catalogs, database schemas, // tables, and columns. // @@ -238,7 +220,7 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re // // All non-empty, non-nil strings should be a search pattern (as described // earlier). 
-func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) { +func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) { metadataRecords, err := c.populateMetadata(ctx, depth, catalog, dbSchema, tableName, columnName, tableType) if err != nil { return nil, err @@ -266,7 +248,7 @@ func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog * return g.Finish() } -func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, metadataRecords []internal.Metadata) (result map[string][]string, err error) { +func (c *connectionImpl) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, metadataRecords []internal.Metadata) (result map[string][]string, err error) { if depth == adbc.ObjectDepthCatalogs { return } @@ -452,7 +434,7 @@ func toXdbcDataType(dt arrow.DataType) (xdbcType internal.XdbcDataType) { } } -func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string, metadataRecords []internal.Metadata) (result internal.SchemaToTableInfo, err error) { +func (c *connectionImpl) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string, metadataRecords []internal.Metadata) (result internal.SchemaToTableInfo, err error) { if depth == adbc.ObjectDepthCatalogs || depth == adbc.ObjectDepthDBSchemas { return } @@ -524,7 +506,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat return } -func (c *cnxn) populateMetadata(ctx context.Context, depth 
adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) ([]internal.Metadata, error) { +func (c *connectionImpl) populateMetadata(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) ([]internal.Metadata, error) { var metadataRecords []internal.Metadata catalogMetadataRecords, err := c.getCatalogsMetadata(ctx) if err != nil { @@ -556,7 +538,7 @@ func (c *cnxn) populateMetadata(ctx context.Context, depth adbc.ObjectDepth, cat return metadataRecords, nil } -func (c *cnxn) getCatalogsMetadata(ctx context.Context) ([]internal.Metadata, error) { +func (c *connectionImpl) getCatalogsMetadata(ctx context.Context) ([]internal.Metadata, error) { metadataRecords := make([]internal.Metadata, 0) rows, err := c.sqldb.QueryContext(ctx, prepareCatalogsSQL(), nil) @@ -585,7 +567,7 @@ func (c *cnxn) getCatalogsMetadata(ctx context.Context) ([]internal.Metadata, er return metadataRecords, nil } -func (c *cnxn) getDbSchemasMetadata(ctx context.Context, matchingCatalogNames []string, catalog *string, dbSchema *string) ([]internal.Metadata, error) { +func (c *connectionImpl) getDbSchemasMetadata(ctx context.Context, matchingCatalogNames []string, catalog *string, dbSchema *string) ([]internal.Metadata, error) { var metadataRecords []internal.Metadata query, queryArgs := prepareDbSchemasSQL(matchingCatalogNames, catalog, dbSchema) rows, err := c.sqldb.QueryContext(ctx, query, queryArgs...) 
@@ -604,7 +586,7 @@ func (c *cnxn) getDbSchemasMetadata(ctx context.Context, matchingCatalogNames [] return metadataRecords, nil } -func (c *cnxn) getTablesMetadata(ctx context.Context, matchingCatalogNames []string, catalog *string, dbSchema *string, tableName *string, tableType []string) ([]internal.Metadata, error) { +func (c *connectionImpl) getTablesMetadata(ctx context.Context, matchingCatalogNames []string, catalog *string, dbSchema *string, tableName *string, tableType []string) ([]internal.Metadata, error) { metadataRecords := make([]internal.Metadata, 0) query, queryArgs := prepareTablesSQL(matchingCatalogNames, catalog, dbSchema, tableName, tableType) rows, err := c.sqldb.QueryContext(ctx, query, queryArgs...) @@ -623,7 +605,7 @@ func (c *cnxn) getTablesMetadata(ctx context.Context, matchingCatalogNames []str return metadataRecords, nil } -func (c *cnxn) getColumnsMetadata(ctx context.Context, matchingCatalogNames []string, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) ([]internal.Metadata, error) { +func (c *connectionImpl) getColumnsMetadata(ctx context.Context, matchingCatalogNames []string, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) ([]internal.Metadata, error) { metadataRecords := make([]internal.Metadata, 0) query, queryArgs := prepareColumnsSQL(matchingCatalogNames, catalog, dbSchema, tableName, columnName, tableType) rows, err := c.sqldb.QueryContext(ctx, query, queryArgs...) 
@@ -870,29 +852,7 @@ func descToField(name, typ, isnull, primary string, comment sql.NullString) (fie return } -func (c *cnxn) GetOption(key string) (string, error) { - switch key { - case adbc.OptionKeyAutoCommit: - if c.activeTransaction { - // No autocommit - return adbc.OptionValueDisabled, nil - } else { - // Autocommit - return adbc.OptionValueEnabled, nil - } - case adbc.OptionKeyCurrentCatalog: - return c.getStringQuery("SELECT CURRENT_DATABASE()") - case adbc.OptionKeyCurrentDbSchema: - return c.getStringQuery("SELECT CURRENT_SCHEMA()") - } - - return "", adbc.Error{ - Msg: "[Snowflake] unknown connection option", - Code: adbc.StatusNotFound, - } -} - -func (c *cnxn) getStringQuery(query string) (string, error) { +func (c *connectionImpl) getStringQuery(query string) (string, error) { result, err := c.cn.QueryContext(context.Background(), query, nil) if err != nil { return "", errToAdbcErr(adbc.StatusInternal, err) @@ -928,28 +888,7 @@ func (c *cnxn) getStringQuery(query string) (string, error) { return value, nil } -func (c *cnxn) GetOptionBytes(key string) ([]byte, error) { - return nil, adbc.Error{ - Msg: "[Snowflake] unknown connection option", - Code: adbc.StatusNotFound, - } -} - -func (c *cnxn) GetOptionInt(key string) (int64, error) { - return 0, adbc.Error{ - Msg: "[Snowflake] unknown connection option", - Code: adbc.StatusNotFound, - } -} - -func (c *cnxn) GetOptionDouble(key string) (float64, error) { - return 0.0, adbc.Error{ - Msg: "[Snowflake] unknown connection option", - Code: adbc.StatusNotFound, - } -} - -func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (*arrow.Schema, error) { +func (c *connectionImpl) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (*arrow.Schema, error) { tblParts := make([]string, 0, 3) if catalog != nil { tblParts = append(tblParts, strconv.Quote(*catalog)) @@ -990,35 +929,11 @@ func (c *cnxn) GetTableSchema(ctx 
context.Context, catalog *string, dbSchema *st return sc, nil } -// GetTableTypes returns a list of the table types in the database. -// -// The result is an arrow dataset with the following schema: -// -// Field Name | Field Type -// ----------------|-------------- -// table_type | utf8 not null -func (c *cnxn) GetTableTypes(_ context.Context) (array.RecordReader, error) { - bldr := array.NewRecordBuilder(c.db.Alloc, adbc.TableTypesSchema) - defer bldr.Release() - - bldr.Field(0).(*array.StringBuilder).AppendValues([]string{"BASE TABLE", "TEMPORARY TABLE", "VIEW"}, nil) - final := bldr.NewRecord() - defer final.Release() - return array.NewRecordReader(adbc.TableTypesSchema, []arrow.Record{final}) -} - // Commit commits any pending transactions on this connection, it should // only be used if autocommit is disabled. // // Behavior is undefined if this is mixed with SQL transaction statements. -func (c *cnxn) Commit(_ context.Context) error { - if !c.activeTransaction { - return adbc.Error{ - Msg: "no active transaction, cannot commit", - Code: adbc.StatusInvalidState, - } - } - +func (c *connectionImpl) Commit(_ context.Context) error { _, err := c.cn.ExecContext(context.Background(), "COMMIT", nil) if err != nil { return errToAdbcErr(adbc.StatusInternal, err) @@ -1032,14 +947,7 @@ func (c *cnxn) Commit(_ context.Context) error { // is disabled. // // Behavior is undefined if this is mixed with SQL transaction statements. 
-func (c *cnxn) Rollback(_ context.Context) error { - if !c.activeTransaction { - return adbc.Error{ - Msg: "no active transaction, cannot rollback", - Code: adbc.StatusInvalidState, - } - } - +func (c *connectionImpl) Rollback(_ context.Context) error { _, err := c.cn.ExecContext(context.Background(), "ROLLBACK", nil) if err != nil { return errToAdbcErr(adbc.StatusInternal, err) @@ -1050,7 +958,7 @@ func (c *cnxn) Rollback(_ context.Context) error { } // NewStatement initializes a new statement object tied to this connection -func (c *cnxn) NewStatement() (adbc.Statement, error) { +func (c *connectionImpl) NewStatement() (adbc.Statement, error) { defaultIngestOptions := DefaultIngestOptions() return &statement{ alloc: c.db.Alloc, @@ -1063,7 +971,7 @@ func (c *cnxn) NewStatement() (adbc.Statement, error) { } // Close closes this connection and releases any associated resources. -func (c *cnxn) Close() error { +func (c *connectionImpl) Close() error { if c.sqldb == nil || c.cn == nil { return adbc.Error{Code: adbc.StatusInvalidState} } @@ -1083,49 +991,15 @@ func (c *cnxn) Close() error { // results can then be read independently using the returned RecordReader. // // A partition can be retrieved by using ExecutePartitions on a statement. 
-func (c *cnxn) ReadPartition(ctx context.Context, serializedPartition []byte) (array.RecordReader, error) { +func (c *connectionImpl) ReadPartition(ctx context.Context, serializedPartition []byte) (array.RecordReader, error) { return nil, adbc.Error{ Code: adbc.StatusNotImplemented, Msg: "ReadPartition not yet implemented for snowflake driver", } } -func (c *cnxn) SetOption(key, value string) error { +func (c *connectionImpl) SetOption(key, value string) error { switch key { - case adbc.OptionKeyAutoCommit: - switch value { - case adbc.OptionValueEnabled: - if c.activeTransaction { - _, err := c.cn.ExecContext(context.Background(), "COMMIT", nil) - if err != nil { - return errToAdbcErr(adbc.StatusInternal, err) - } - c.activeTransaction = false - } - _, err := c.cn.ExecContext(context.Background(), "ALTER SESSION SET AUTOCOMMIT = true", nil) - return err - case adbc.OptionValueDisabled: - if !c.activeTransaction { - _, err := c.cn.ExecContext(context.Background(), "BEGIN", nil) - if err != nil { - return errToAdbcErr(adbc.StatusInternal, err) - } - c.activeTransaction = true - } - _, err := c.cn.ExecContext(context.Background(), "ALTER SESSION SET AUTOCOMMIT = false", nil) - return err - default: - return adbc.Error{ - Msg: "[Snowflake] invalid value for option " + key + ": " + value, - Code: adbc.StatusInvalidArgument, - } - } - case adbc.OptionKeyCurrentCatalog: - _, err := c.cn.ExecContext(context.Background(), "USE DATABASE ?", []driver.NamedValue{{Value: value}}) - return err - case adbc.OptionKeyCurrentDbSchema: - _, err := c.cn.ExecContext(context.Background(), "USE SCHEMA ?", []driver.NamedValue{{Value: value}}) - return err case OptionUseHighPrecision: // statements will inherit the value of the OptionUseHighPrecision // from the connection, but the option can be overridden at the @@ -1149,24 +1023,3 @@ func (c *cnxn) SetOption(key, value string) error { } } } - -func (c *cnxn) SetOptionBytes(key string, value []byte) error { - return adbc.Error{ - Msg: 
"[Snowflake] unknown connection option", - Code: adbc.StatusNotImplemented, - } -} - -func (c *cnxn) SetOptionInt(key string, value int64) error { - return adbc.Error{ - Msg: "[Snowflake] unknown connection option", - Code: adbc.StatusNotImplemented, - } -} - -func (c *cnxn) SetOptionDouble(key string, value float64) error { - return adbc.Error{ - Msg: "[Snowflake] unknown connection option", - Code: adbc.StatusNotImplemented, - } -} diff --git a/go/adbc/driver/snowflake/driver.go b/go/adbc/driver/snowflake/driver.go index 77bdcdaac9..124c4d3887 100644 --- a/go/adbc/driver/snowflake/driver.go +++ b/go/adbc/driver/snowflake/driver.go @@ -20,19 +20,15 @@ package snowflake import ( "errors" "runtime/debug" - "strings" "github.com/apache/arrow-adbc/go/adbc" - "github.com/apache/arrow-adbc/go/adbc/driver/driverbase" + "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" "github.com/apache/arrow/go/v16/arrow/memory" "github.com/snowflakedb/gosnowflake" "golang.org/x/exp/maps" ) const ( - infoDriverName = "ADBC Snowflake Driver - Go" - infoVendorName = "Snowflake" - OptionDatabase = "adbc.snowflake.sql.db" OptionSchema = "adbc.snowflake.sql.schema" OptionWarehouse = "adbc.snowflake.sql.warehouse" @@ -119,37 +115,18 @@ const ( ) var ( - infoDriverVersion string - infoDriverArrowVersion string - infoSupportedCodes []adbc.InfoCode + infoVendorVersion string ) func init() { if info, ok := debug.ReadBuildInfo(); ok { for _, dep := range info.Deps { switch { - case dep.Path == "github.com/apache/arrow-adbc/go/adbc/driver/snowflake": - infoDriverVersion = dep.Version - case strings.HasPrefix(dep.Path, "github.com/apache/arrow/go/"): - infoDriverArrowVersion = dep.Version + case dep.Path == "github.com/snowflakedb/gosnowflake": + infoVendorVersion = dep.Version } } } - // XXX: Deps not populated in tests - // https://github.com/golang/go/issues/33976 - if infoDriverVersion == "" { - infoDriverVersion = "(unknown or development build)" - } - if infoDriverArrowVersion 
== "" { - infoDriverArrowVersion = "(unknown or development build)" - } - - infoSupportedCodes = []adbc.InfoCode{ - adbc.InfoDriverName, - adbc.InfoDriverVersion, - adbc.InfoDriverArrowVersion, - adbc.InfoVendorName, - } } func errToAdbcErr(code adbc.Status, err error) error { @@ -192,13 +169,21 @@ type driverImpl struct { // NewDriver creates a new Snowflake driver using the given Arrow allocator. func NewDriver(alloc memory.Allocator) adbc.Driver { - return driverbase.NewDriver(&driverImpl{DriverImplBase: driverbase.NewDriverImplBase("Snowflake", alloc)}) + info := driverbase.DefaultDriverInfo("Snowflake") + if infoVendorVersion != "" { + if err := info.RegisterInfoCode(adbc.InfoVendorVersion, infoVendorVersion); err != nil { + panic(err) + } + } + return driverbase.NewDriver(&driverImpl{DriverImplBase: driverbase.NewDriverImplBase(info, alloc)}) } func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error) { opts = maps.Clone(opts) - db := &databaseImpl{DatabaseImplBase: driverbase.NewDatabaseImplBase(&d.DriverImplBase), - useHighPrecision: true} + db := &databaseImpl{ + DatabaseImplBase: driverbase.NewDatabaseImplBase(&d.DriverImplBase), + useHighPrecision: true, + } if err := db.SetOptions(opts); err != nil { return nil, err } diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go index 5752ae5eec..3f93dbdb58 100644 --- a/go/adbc/driver/snowflake/driver_test.go +++ b/go/adbc/driver/snowflake/driver_test.go @@ -217,6 +217,10 @@ func (s *SnowflakeQuirks) GetMetadata(code adbc.InfoCode) interface{} { return "(unknown or development build)" case adbc.InfoDriverArrowVersion: return "(unknown or development build)" + case adbc.InfoVendorVersion: + return "(unknown or development build)" + case adbc.InfoVendorArrowVersion: + return "(unknown or development build)" case adbc.InfoDriverADBCVersion: return adbc.AdbcVersion1_1_0 case adbc.InfoVendorName: diff --git 
a/go/adbc/driver/snowflake/snowflake_database.go b/go/adbc/driver/snowflake/snowflake_database.go index 76ab4684bf..5c5f32b690 100644 --- a/go/adbc/driver/snowflake/snowflake_database.go +++ b/go/adbc/driver/snowflake/snowflake_database.go @@ -32,7 +32,7 @@ import ( "time" "github.com/apache/arrow-adbc/go/adbc" - "github.com/apache/arrow-adbc/go/adbc/driver/driverbase" + "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" "github.com/snowflakedb/gosnowflake" "github.com/youmark/pkcs8" ) @@ -136,28 +136,7 @@ func (d *databaseImpl) GetOption(key string) (string, error) { return *val, nil } } - return "", adbc.Error{ - Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), - Code: adbc.StatusNotFound, - } -} -func (d *databaseImpl) GetOptionBytes(key string) ([]byte, error) { - return nil, adbc.Error{ - Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), - Code: adbc.StatusNotFound, - } -} -func (d *databaseImpl) GetOptionInt(key string) (int64, error) { - return 0, adbc.Error{ - Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), - Code: adbc.StatusNotFound, - } -} -func (d *databaseImpl) GetOptionDouble(key string) (float64, error) { - return 0, adbc.Error{ - Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), - Code: adbc.StatusNotFound, - } + return d.DatabaseImplBase.GetOption(key) } func (d *databaseImpl) SetOptions(cnOptions map[string]string) error { @@ -176,7 +155,8 @@ func (d *databaseImpl) SetOptions(cnOptions map[string]string) error { } } - defaultAppName := "[ADBC][Go-" + infoDriverVersion + "]" + driverVersion, _ := d.DatabaseImplBase.DriverInfo.GetInfoDriverVersion() + defaultAppName := "[ADBC][Go-" + driverVersion + "]" // set default application name to track // unless user overrides it d.cfg.Application = defaultAppName @@ -464,15 +444,22 @@ func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) { return nil, errToAdbcErr(adbc.StatusIO, err) } - return &cnxn{ + 
conn := &connectionImpl{ cn: cn.(snowflakeConn), db: d, ctor: connector, sqldb: sql.OpenDB(connector), // default enable high precision // SetOption(OptionUseHighPrecision, adbc.OptionValueDisabled) to // get Int64/Float64 instead - useHighPrecision: d.useHighPrecision, - }, nil + useHighPrecision: d.useHighPrecision, + ConnectionImplBase: driverbase.NewConnectionImplBase(&d.DatabaseImplBase), + } + + return driverbase.NewConnectionBuilder(conn). + WithAutocommitSetter(conn). + WithCurrentNamespacer(conn). + WithTableTypeLister(conn). + Connection(), nil } func (d *databaseImpl) Close() error { diff --git a/go/adbc/driver/snowflake/statement.go b/go/adbc/driver/snowflake/statement.go index 8439ddfcd4..3f446662ea 100644 --- a/go/adbc/driver/snowflake/statement.go +++ b/go/adbc/driver/snowflake/statement.go @@ -42,7 +42,7 @@ const ( ) type statement struct { - cnxn *cnxn + cnxn *connectionImpl alloc memory.Allocator queueSize int prefetchConcurrency int diff --git a/go/adbc/go.mod b/go/adbc/go.mod index fe30fd7cb3..35bdc70817 100644 --- a/go/adbc/go.mod +++ b/go/adbc/go.mod @@ -83,6 +83,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/sirupsen/logrus v1.9.3 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect golang.org/x/crypto v0.21.0 // indirect golang.org/x/mod v0.16.0 // indirect diff --git a/go/adbc/go.sum b/go/adbc/go.sum index a5bea3b7f5..cf74b7baa8 100644 --- a/go/adbc/go.sum +++ b/go/adbc/go.sum @@ -130,6 +130,7 @@ github.com/snowflakedb/gosnowflake v1.8.0 h1:4bQj8eAYGMkou/nICiIEb9jSbBLDDp5cB6J github.com/snowflakedb/gosnowflake v1.8.0/go.mod h1:7yyY2MxtDti2eXgtvlZ8QxzCN6KV2B4qb1HuygMI+0U= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod 
h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=