Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(go/adbc/driver/snowflake): support parameter binding #1808

Merged
merged 3 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion c/driver/snowflake/snowflake_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks {
bool supports_metadata_current_catalog() const override { return false; }
bool supports_metadata_current_db_schema() const override { return false; }
bool supports_partitioned_data() const override { return false; }
bool supports_dynamic_parameter_binding() const override { return false; }
bool supports_dynamic_parameter_binding() const override { return true; }
bool supports_error_on_incompatible_schema() const override { return false; }
bool ddl_implicit_commit_txn() const override { return true; }
std::string db_schema() const override { return schema_; }
Expand Down
153 changes: 153 additions & 0 deletions go/adbc/driver/snowflake/binding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package snowflake

import (
"database/sql"
"database/sql/driver"
"fmt"
"io"

"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow/go/v17/arrow"
"github.com/apache/arrow/go/v17/arrow/array"
)

// convertArrowToNamedValue builds the bind parameters for one row of batch.
//
// It returns one driver.NamedValue per column of batch, in column order,
// with Ordinal set to the 1-based column position. Each value is read at
// row `index` and wrapped in the matching database/sql Null* type, whose
// Valid flag comes from the column's IsValid(index), so Arrow nulls are
// passed through as SQL NULLs.
//
// Handled column types are Boolean, Float32, Float64, Int8, Int16, Int32,
// Int64, String and LargeString. Any other column type aborts the
// conversion and returns an adbc.Error with StatusNotImplemented that names
// the offending field and its Arrow type.
func convertArrowToNamedValue(batch arrow.Record, index int) ([]driver.NamedValue, error) {
// see goTypeToSnowflake in gosnowflake
// technically, snowflake can bind an array of values at once, but
// only for INSERT, so we can't take advantage of that without
// analyzing the query ourselves
Comment on lines +33 to +35
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hate this, but nothing we can do right now i guess

// One parameter slot per column; slots are filled positionally below.
params := make([]driver.NamedValue, batch.NumCols())
for i, field := range batch.Schema().Fields() {
rawColumn := batch.Column(i)
// Ordinals are 1-based, column indices are 0-based.
params[i].Ordinal = i + 1
// Dispatch on the concrete Arrow array type of the column.
switch column := rawColumn.(type) {
case *array.Boolean:
params[i].Value = sql.NullBool{
Bool: column.Value(index),
Valid: column.IsValid(index),
}
Comment on lines +41 to +45
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something for us to think about is that go1.22 added sql.Null[T] which might be useful to allow us to create a more generic implementation of this later on (since we officially still support go 1.21, we wouldn't be able to use it yet)

case *array.Float32:
// Snowflake only recognizes float64
params[i].Value = sql.NullFloat64{
Float64: float64(column.Value(index)),
Valid: column.IsValid(index),
}
case *array.Float64:
params[i].Value = sql.NullFloat64{
Float64: column.Value(index),
Valid: column.IsValid(index),
}
case *array.Int8:
// Snowflake only recognizes int64
params[i].Value = sql.NullInt64{
Int64: int64(column.Value(index)),
Valid: column.IsValid(index),
}
case *array.Int16:
// Widened to int64, same as Int8 above.
params[i].Value = sql.NullInt64{
Int64: int64(column.Value(index)),
Valid: column.IsValid(index),
}
case *array.Int32:
// Widened to int64, same as Int8 above.
params[i].Value = sql.NullInt64{
Int64: int64(column.Value(index)),
Valid: column.IsValid(index),
}
case *array.Int64:
params[i].Value = sql.NullInt64{
Int64: column.Value(index),
Valid: column.IsValid(index),
}
case *array.String:
params[i].Value = sql.NullString{
String: column.Value(index),
Valid: column.IsValid(index),
}
case *array.LargeString:
// Bound the same way as String.
params[i].Value = sql.NullString{
String: column.Value(index),
Valid: column.IsValid(index),
}
default:
// Any column type not listed above is rejected; no partial
// parameter list is returned.
return nil, adbc.Error{
Code: adbc.StatusNotImplemented,
Msg: fmt.Sprintf("[Snowflake] Unsupported bind param '%s' type %s", field.Name, field.Type.String()),
zeroshade marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
return params, nil
}

// snowflakeBindReader walks the rows of bound Arrow data (a single batch
// and/or a stream of batches) and turns each row into one set of bind
// parameters, optionally running a query per row via doQuery.
type snowflakeBindReader struct {
	// doQuery executes the statement with one row's parameters and returns
	// the resulting reader. It is only invoked by Next.
	doQuery func([]driver.NamedValue) (array.RecordReader, error)
	// currentBatch is the batch whose rows are currently being consumed;
	// nil when no batch is loaded.
	currentBatch arrow.Record
	// nextIndex is the row of currentBatch that NextParams will read next.
	nextIndex int64
	// stream supplies further batches once currentBatch is exhausted.
	// may be nil if we bound only a batch
	stream array.RecordReader
}

// Release drops the reader's references to the current batch and to the
// bound stream, if present, and clears both fields so that a repeated call
// is a no-op.
func (r *snowflakeBindReader) Release() {
	if batch := r.currentBatch; batch != nil {
		batch.Release()
		r.currentBatch = nil
	}
	if stream := r.stream; stream != nil {
		stream.Release()
		r.stream = nil
	}
}

// Next fetches the next row of bind parameters and runs doQuery with them,
// returning the reader that doQuery produces. Any error from NextParams,
// including io.EOF once the parameters are exhausted, is returned as-is
// without running a query.
func (r *snowflakeBindReader) Next() (array.RecordReader, error) {
	bindParams, bindErr := r.NextParams()
	if bindErr != nil {
		// includes EOF
		return nil, bindErr
	}
	return r.doQuery(bindParams)
}

func (r *snowflakeBindReader) NextParams() ([]driver.NamedValue, error) {
for r.currentBatch == nil || r.nextIndex >= r.currentBatch.NumRows() {
// We can be used both by binding a stream or by binding a
// batch. In the latter case, we have to release the batch,
// but not in the former case. Unify the cases by always
// releasing the batch, adding an "extra" retain so that the
// release does not cause issues.
if r.currentBatch != nil {
r.currentBatch.Release()
}
zeroshade marked this conversation as resolved.
Show resolved Hide resolved
r.currentBatch = nil
if r.stream != nil && r.stream.Next() {
r.currentBatch = r.stream.Record()
r.currentBatch.Retain()
r.nextIndex = 0
continue
} else if r.stream != nil && r.stream.Err() != nil {
return nil, r.stream.Err()
} else {
// no more params
return nil, io.EOF
}
}

params, err := convertArrowToNamedValue(r.currentBatch, int(r.nextIndex))
r.nextIndex++
return params, err
}
107 changes: 107 additions & 0 deletions go/adbc/driver/snowflake/concat_reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package snowflake

import (
"io"
"sync/atomic"

"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow/go/v17/arrow"
"github.com/apache/arrow/go/v17/arrow/array"
)

// readerIter is a source of successive record readers that concatReader
// chains together.
type readerIter interface {
	// Release frees whatever resources the iterator holds.
	Release()

	// Next yields the following reader. concatReader treats an io.EOF
	// error as the end of the sequence.
	Next() (array.RecordReader, error)
}

// concatReader presents the readers produced by a readerIter as a single
// reference-counted record stream. Init must be called before use.
type concatReader struct {
	// refCount is the number of outstanding references; resources are
	// released when it reaches zero.
	refCount atomic.Int64
	// readers supplies the underlying readers, one at a time.
	readers readerIter
	// currentReader is the reader being drained; nil once exhausted or
	// after an error.
	currentReader array.RecordReader
	// schema is captured from the first reader during Init.
	schema *arrow.Schema
	// err is the first error encountered, reported by Err.
	err error
}

// nextReader releases the reader currently in use, if any, and advances to
// the next one from the iterator. On io.EOF the current reader stays nil;
// on any other error the error is recorded in r.err and the current reader
// also stays nil.
func (r *concatReader) nextReader() {
	if prev := r.currentReader; prev != nil {
		prev.Release()
		r.currentReader = nil
	}

	next, err := r.readers.Next()
	switch {
	case err == io.EOF:
		r.currentReader = nil
	case err != nil:
		r.err = err
	default:
		// May be nil
		r.currentReader = next
	}
}
// Init attaches the reader iterator, sets the reference count to one and
// loads the first reader so that the schema is available.
//
// If loading the first reader fails, or the iterator yields no reader at
// all, the concatReader releases itself and returns the error (an
// adbc.Error with StatusInternal in the "no reader" case).
func (r *concatReader) Init(readers readerIter) error {
	r.readers = readers
	r.refCount.Store(1)
	r.nextReader()

	switch {
	case r.err != nil:
		r.Release()
		return r.err
	case r.currentReader == nil:
		r.Release()
		r.err = adbc.Error{
			Code: adbc.StatusInternal,
			Msg:  "[Snowflake] No data in this stream",
		}
		return r.err
	}

	r.schema = r.currentReader.Schema()
	return nil
}
// Retain adds one reference to the reader.
func (r *concatReader) Retain() {
	const delta = 1
	r.refCount.Add(delta)
}
// Release drops one reference. When the count reaches exactly zero, the
// reader in use (if any) and the underlying iterator are released.
func (r *concatReader) Release() {
	if remaining := r.refCount.Add(-1); remaining != 0 {
		return
	}
	if cur := r.currentReader; cur != nil {
		cur.Release()
	}
	r.readers.Release()
}
// Schema returns the schema captured by Init. It panics if no schema has
// been recorded, i.e. Init was not called or did not succeed.
func (r *concatReader) Schema() *arrow.Schema {
	if schema := r.schema; schema != nil {
		return schema
	}
	panic("did not call concatReader.Init")
}
// Next advances to the next record, moving on to the following underlying
// reader whenever the current one is exhausted. It returns false when all
// readers are drained or an error has occurred; call Err to tell the two
// apart.
//
// A reader whose Next returns false may have stopped because of an error
// rather than end-of-data. That error is recorded and iteration stops,
// instead of silently skipping ahead to the next reader and hiding the
// failure from the caller.
func (r *concatReader) Next() bool {
	if r.err != nil {
		// Once failed, stay failed; do not touch the readers again.
		return false
	}
	for r.currentReader != nil {
		if r.currentReader.Next() {
			return true
		}
		if err := r.currentReader.Err(); err != nil {
			// The current reader is left in place so Release frees it.
			r.err = err
			return false
		}
		// Clean end of this reader: move on (may set r.err or leave
		// currentReader nil, both of which end the loop).
		r.nextReader()
	}
	return false
}
// Record returns the record the current underlying reader is positioned
// on. It returns nil when there is no current reader (the stream is
// exhausted or failed) rather than dereferencing a nil reader.
func (r *concatReader) Record() arrow.Record {
	if r.currentReader == nil {
		return nil
	}
	return r.currentReader.Record()
}
// Err reports the error recorded while initializing or advancing the
// reader, or nil if none has been recorded.
func (r *concatReader) Err() error {
	recorded := r.err
	return recorded
}
2 changes: 1 addition & 1 deletion go/adbc/driver/snowflake/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ package snowflake

import (
"errors"
"maps"
"runtime/debug"
"strings"

"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
"github.com/apache/arrow/go/v17/arrow/memory"
"github.com/snowflakedb/gosnowflake"
"golang.org/x/exp/maps"
)

const (
Expand Down
56 changes: 53 additions & 3 deletions go/adbc/driver/snowflake/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ package snowflake

import (
"context"
"database/sql/driver"
"fmt"
"io"
"strconv"
"strings"

Expand Down Expand Up @@ -463,10 +465,26 @@ func (st *statement) ExecuteQuery(ctx context.Context) (array.RecordReader, int6
// concatenate RecordReaders which doesn't exist yet. let's put
// that off for now.
if st.streamBind != nil || st.bound != nil {
return nil, -1, adbc.Error{
Msg: "executing non-bulk ingest with bound params not yet implemented",
Code: adbc.StatusNotImplemented,
bind := snowflakeBindReader{
doQuery: func(params []driver.NamedValue) (array.RecordReader, error) {
loader, err := st.cnxn.cn.QueryArrowStream(ctx, st.query, params...)
if err != nil {
return nil, errToAdbcErr(adbc.StatusInternal, err)
}
return newRecordReader(ctx, st.alloc, loader, st.queueSize, st.prefetchConcurrency, st.useHighPrecision)
},
currentBatch: st.bound,
stream: st.streamBind,
}
st.bound = nil
st.streamBind = nil

rdr := concatReader{}
err := rdr.Init(&bind)
if err != nil {
return nil, -1, err
}
return &rdr, -1, nil
}

loader, err := st.cnxn.cn.QueryArrowStream(ctx, st.query)
Expand All @@ -493,6 +511,38 @@ func (st *statement) ExecuteUpdate(ctx context.Context) (int64, error) {
}
}

if st.streamBind != nil || st.bound != nil {
numRows := int64(0)
bind := snowflakeBindReader{
currentBatch: st.bound,
stream: st.streamBind,
}
st.bound = nil
st.streamBind = nil

defer bind.Release()
for {
params, err := bind.NextParams()
if err == io.EOF {
break
} else if err != nil {
return -1, err
}

r, err := st.cnxn.cn.ExecContext(ctx, st.query, params)
if err != nil {
return -1, errToAdbcErr(adbc.StatusInternal, err)
}
n, err := r.RowsAffected()
if err != nil {
numRows = -1
} else if numRows >= 0 {
numRows += n
}
}
return numRows, nil
}

r, err := st.cnxn.cn.ExecContext(ctx, st.query, nil)
if err != nil {
return -1, errToAdbcErr(adbc.StatusIO, err)
Expand Down
4 changes: 2 additions & 2 deletions go/adbc/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ module github.com/apache/arrow-adbc/go/adbc
go 1.21

require (
github.com/apache/arrow/go/v17 v17.0.0-20240430043840-e4f31462dbd6
github.com/apache/arrow/go/v17 v17.0.0-20240503231747-7cd9c6fbd313
github.com/bluele/gcache v0.0.2
github.com/golang/protobuf v1.5.4
github.com/google/uuid v1.6.0
Expand All @@ -31,7 +31,7 @@ require (
golang.org/x/sync v0.7.0
golang.org/x/tools v0.20.0
google.golang.org/grpc v1.63.2
google.golang.org/protobuf v1.33.0
google.golang.org/protobuf v1.34.0
)

require (
Expand Down
Loading
Loading