Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(go/adbc/drivermgr): Implement Remaining CGO Wrapper Methods that are Supported by SQLite Driver #1304

Merged
merged 9 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 86 additions & 6 deletions go/adbc/drivermgr/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ package drivermgr
// return (struct ArrowArray*)malloc(sizeof(struct ArrowArray));
// }
//
// struct ArrowArrayStream* allocArrStream() {
// return (struct ArrowArrayStream*)malloc(sizeof(struct ArrowArrayStream));
// }
//
import "C"
import (
"context"
Expand Down Expand Up @@ -186,6 +190,15 @@ func getRdr(out *C.struct_ArrowArrayStream) (array.RecordReader, error) {
return rdr.(array.RecordReader), nil
}

func getSchema(out *C.struct_ArrowSchema) (*arrow.Schema, error) {
// Maybe: ImportCArrowSchema should perform this check?
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is needed for when catalog or dbSchema is specified for GetTableSchema. The SQLite driver does not return a schema when either of these are set. As a result, we attempt to import a schema through the cdata interface with all nil fields including format which panics when it reaches this line.

It does seem like it would be helpful to perform these nil checks in the cdata package, but I just wanted to check first as that would require cross-repo changes and add to the scope of these changes.

if out.format == nil {
return nil, nil
}

return cdata.ImportCArrowSchema((*cdata.CArrowSchema)(unsafe.Pointer(out)))
}

type cnxn struct {
conn *C.struct_AdbcConnection
}
Expand Down Expand Up @@ -255,19 +268,68 @@ func (c *cnxn) GetObjects(_ context.Context, depth adbc.ObjectDepth, catalog, db
}

func (c *cnxn) GetTableSchema(_ context.Context, catalog, dbSchema *string, tableName string) (*arrow.Schema, error) {
return nil, &adbc.Error{Code: adbc.StatusNotImplemented}
var (
schema C.struct_ArrowSchema
err C.struct_AdbcError
catalog_ *C.char
dbSchema_ *C.char
tableName_ *C.char
)

if catalog != nil {
catalog_ = C.CString(*catalog)
defer C.free(unsafe.Pointer(catalog_))
}

if dbSchema != nil {
dbSchema_ = C.CString(*dbSchema)
defer C.free(unsafe.Pointer(dbSchema_))
}

tableName_ = C.CString(tableName)
defer C.free(unsafe.Pointer(tableName_))

if code := adbc.Status(C.AdbcConnectionGetTableSchema(c.conn, catalog_, dbSchema_, tableName_, &schema, &err)); code != adbc.StatusOK {
return nil, toAdbcError(code, &err)
}

return getSchema(&schema)
}

func (c *cnxn) GetTableTypes(context.Context) (array.RecordReader, error) {
return nil, &adbc.Error{Code: adbc.StatusNotImplemented}
var (
out C.struct_ArrowArrayStream
err C.struct_AdbcError
)

if code := adbc.Status(C.AdbcConnectionGetTableTypes(c.conn, &out, &err)); code != adbc.StatusOK {
return nil, toAdbcError(code, &err)
}
return getRdr(&out)
}

func (c *cnxn) Commit(context.Context) error {
return &adbc.Error{Code: adbc.StatusNotImplemented}
var (
err C.struct_AdbcError
)

if code := adbc.Status(C.AdbcConnectionCommit(c.conn, &err)); code != adbc.StatusOK {
return toAdbcError(code, &err)
}

return nil
}

func (c *cnxn) Rollback(context.Context) error {
return &adbc.Error{Code: adbc.StatusNotImplemented}
var (
err C.struct_AdbcError
)

if code := adbc.Status(C.AdbcConnectionRollback(c.conn, &err)); code != adbc.StatusOK {
return toAdbcError(code, &err)
}

return nil
}

func (c *cnxn) NewStatement() (adbc.Statement, error) {
Expand Down Expand Up @@ -405,11 +467,29 @@ func (s *stmt) Bind(_ context.Context, values arrow.Record) error {
}

func (s *stmt) BindStream(_ context.Context, stream array.RecordReader) error {
return &adbc.Error{Code: adbc.StatusNotImplemented}
var (
arrStream = C.allocArrStream()
cdArrStream = (*cdata.CArrowArrayStream)(unsafe.Pointer(arrStream))
err C.struct_AdbcError
)
cdata.ExportRecordReader(stream, cdArrStream)
if code := adbc.Status(C.AdbcStatementBindStream(s.st, arrStream, &err)); code != adbc.StatusOK {
return toAdbcError(code, &err)
}
return nil
}

func (s *stmt) GetParameterSchema() (*arrow.Schema, error) {
return nil, &adbc.Error{Code: adbc.StatusNotImplemented}
var (
schema C.struct_ArrowSchema
err C.struct_AdbcError
)

if code := adbc.Status(C.AdbcStatementGetParameterSchema(s.st, &schema, &err)); code != adbc.StatusOK {
return nil, toAdbcError(code, &err)
}

return getSchema(&schema)
}

func (s *stmt) ExecutePartitions(context.Context) (*arrow.Schema, adbc.Partitions, int64, error) {
Expand Down
163 changes: 158 additions & 5 deletions go/adbc/drivermgr/wrapper_sqlite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,25 @@ func (dm *DriverMgrSuite) SetupSuite() {
})
dm.NoError(err)

db, err := dm.db.Open(dm.ctx)
cnxn, err := dm.db.Open(dm.ctx)
dm.NoError(err)
defer db.Close()
defer cnxn.Close()

stmt, err := db.NewStatement()
stmt, err := cnxn.NewStatement()
dm.NoError(err)
defer stmt.Close()

err = stmt.SetSqlQuery("CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)")
dm.NoError(err)
dm.NoError(stmt.SetSqlQuery("CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)"))

nrows, err := stmt.ExecuteUpdate(dm.ctx)
dm.NoError(err)
dm.Equal(int64(0), nrows)

dm.NoError(stmt.SetSqlQuery("INSERT INTO test_table (id, name) VALUES (1, 'test')"))

nrows, err = stmt.ExecuteUpdate(dm.ctx)
dm.NoError(err)
dm.Equal(int64(1), nrows)
}

func (dm *DriverMgrSuite) SetupTest() {
Expand Down Expand Up @@ -334,6 +339,83 @@ func (dm *DriverMgrSuite) TestGetObjectsTableType() {
dm.False(rdr.Next())
}

func (dm *DriverMgrSuite) TestGetTableSchema() {
schema, err := dm.conn.GetTableSchema(dm.ctx, nil, nil, "test_table")
dm.NoError(err)

expSchema := arrow.NewSchema(
[]arrow.Field{
{Name: "id", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
{Name: "name", Type: arrow.BinaryTypes.String, Nullable: true},
}, nil)
dm.True(expSchema.Equal(schema))
}

func (dm *DriverMgrSuite) TestGetTableSchemaInvalidTable() {
_, err := dm.conn.GetTableSchema(dm.ctx, nil, nil, "unknown_table")
dm.Error(err)
}

func (dm *DriverMgrSuite) TestGetTableSchemaCatalog() {
catalog := "does_not_exist"
schema, err := dm.conn.GetTableSchema(dm.ctx, &catalog, nil, "test_table")
dm.NoError(err)
dm.Nil(schema)
}

func (dm *DriverMgrSuite) TestGetTableSchemaDBSchema() {
dbSchema := "does_not_exist"
schema, err := dm.conn.GetTableSchema(dm.ctx, nil, &dbSchema, "test_table")
dm.NoError(err)
dm.Nil(schema)
}

func (dm *DriverMgrSuite) TestGetTableTypes() {
rdr, err := dm.conn.GetTableTypes(dm.ctx)
dm.NoError(err)
defer rdr.Release()

expSchema := adbc.TableTypesSchema
dm.True(expSchema.Equal(rdr.Schema()))
dm.True(rdr.Next())

rec := rdr.Record()
dm.Equal(int64(2), rec.NumRows())

expTableTypes := []string{"table", "view"}
dm.Contains(expTableTypes, rec.Column(0).ValueStr(0))
dm.Contains(expTableTypes, rec.Column(0).ValueStr(1))
dm.False(rdr.Next())
}

func (dm *DriverMgrSuite) TestCommit() {
err := dm.conn.Commit(dm.ctx)
dm.Error(err)
dm.ErrorContains(err, "No active transaction, cannot commit")
}

func (dm *DriverMgrSuite) TestCommitAutocommitDisabled() {
cnxnopt, ok := dm.conn.(adbc.PostInitOptions)
dm.True(ok)

dm.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueDisabled))
dm.NoError(dm.conn.Commit(dm.ctx))
}

func (dm *DriverMgrSuite) TestRollback() {
err := dm.conn.Rollback(dm.ctx)
dm.Error(err)
dm.ErrorContains(err, "No active transaction, cannot rollback")
}

func (dm *DriverMgrSuite) TestRollbackAutocommitDisabled() {
cnxnopt, ok := dm.conn.(adbc.PostInitOptions)
dm.True(ok)

dm.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueDisabled))
dm.NoError(dm.conn.Rollback(dm.ctx))
}

func (dm *DriverMgrSuite) TestSqlExecute() {
query := "SELECT 1"
st, err := dm.conn.NewStatement()
Expand Down Expand Up @@ -429,6 +511,77 @@ func (dm *DriverMgrSuite) TestSqlPrepareMultipleParams() {
dm.False(rdr.Next())
}

func (dm *DriverMgrSuite) TestGetParameterSchema() {
query := "SELECT ?1, ?2"
st, err := dm.conn.NewStatement()
dm.Require().NoError(err)
dm.Require().NoError(st.SetSqlQuery(query))
defer st.Close()

expSchema := arrow.NewSchema([]arrow.Field{
{Name: "?1", Type: arrow.Null, Nullable: true},
{Name: "?2", Type: arrow.Null, Nullable: true},
}, nil)

schema, err := st.GetParameterSchema()
dm.NoError(err)

dm.True(expSchema.Equal(schema))
}

func (dm *DriverMgrSuite) TestBindStream() {
query := "SELECT ?1, ?2"
st, err := dm.conn.NewStatement()
dm.Require().NoError(err)
dm.Require().NoError(st.SetSqlQuery(query))
defer st.Close()

schema := arrow.NewSchema([]arrow.Field{
{Name: "1", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
{Name: "2", Type: arrow.BinaryTypes.String, Nullable: true},
}, nil)

bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
defer bldr.Release()

bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2, 3}, nil)
bldr.Field(1).(*array.StringBuilder).AppendValues([]string{"one", "two", "three"}, nil)

rec1 := bldr.NewRecord()
defer rec1.Release()

bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{4, 5, 6}, nil)
bldr.Field(1).(*array.StringBuilder).AppendValues([]string{"four", "five", "six"}, nil)

rec2 := bldr.NewRecord()
defer rec2.Release()

recsIn := []arrow.Record{rec1, rec2}
rdrIn, err := array.NewRecordReader(schema, recsIn)
dm.NoError(err)

dm.NoError(st.BindStream(dm.ctx, rdrIn))

rdrOut, _, err := st.ExecuteQuery(dm.ctx)
dm.NoError(err)
defer rdrOut.Release()

recsOut := make([]arrow.Record, 0)
for rdrOut.Next() {
rec := rdrOut.Record()
rec.Retain()
defer rec.Release()
recsOut = append(recsOut, rec)
}

tableIn := array.NewTableFromRecords(schema, recsIn)
defer tableIn.Release()
tableOut := array.NewTableFromRecords(schema, recsOut)
defer tableOut.Release()

dm.Truef(array.TableEqual(tableIn, tableOut), "expected: %s\ngot: %s", tableIn, tableOut)
}

func TestDriverMgr(t *testing.T) {
suite.Run(t, new(DriverMgrSuite))
}
Expand Down
Loading