Skip to content

Commit

Permalink
feat(go/adbc/driver/snowflake): support parameter binding (apache#1808)
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm authored May 7, 2024
1 parent 346b012 commit 8f0bdb3
Show file tree
Hide file tree
Showing 7 changed files with 321 additions and 11 deletions.
2 changes: 1 addition & 1 deletion c/driver/snowflake/snowflake_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks {
bool supports_metadata_current_catalog() const override { return false; }
bool supports_metadata_current_db_schema() const override { return false; }
bool supports_partitioned_data() const override { return false; }
bool supports_dynamic_parameter_binding() const override { return false; }
bool supports_dynamic_parameter_binding() const override { return true; }
bool supports_error_on_incompatible_schema() const override { return false; }
bool ddl_implicit_commit_txn() const override { return true; }
std::string db_schema() const override { return schema_; }
Expand Down
153 changes: 153 additions & 0 deletions go/adbc/driver/snowflake/binding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package snowflake

import (
"database/sql"
"database/sql/driver"
"fmt"
"io"

"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow/go/v17/arrow"
"github.com/apache/arrow/go/v17/arrow/array"
)

// convertArrowToNamedValue turns row `index` of batch into one positional
// bind parameter per column. Ordinals are 1-based and follow column order.
//
// Each supported Arrow column type is mapped to a database/sql Null* wrapper
// whose Valid flag mirrors the Arrow validity of the slot. An unsupported
// column type yields an adbc.Error with StatusNotImplemented and a nil slice.
func convertArrowToNamedValue(batch arrow.Record, index int) ([]driver.NamedValue, error) {
	// The target types mirror goTypeToSnowflake in gosnowflake.
	//
	// Snowflake is technically able to bind an array of values in one go,
	// but only for INSERT; exploiting that would mean analyzing the query
	// ourselves, so rows are bound one at a time.
	asBool := func(v, valid bool) sql.NullBool {
		return sql.NullBool{Bool: v, Valid: valid}
	}
	// Snowflake only recognizes float64, so float32 is widened by callers.
	asFloat := func(v float64, valid bool) sql.NullFloat64 {
		return sql.NullFloat64{Float64: v, Valid: valid}
	}
	// Snowflake only recognizes int64, so narrower ints are widened by callers.
	asInt := func(v int64, valid bool) sql.NullInt64 {
		return sql.NullInt64{Int64: v, Valid: valid}
	}
	asString := func(v string, valid bool) sql.NullString {
		return sql.NullString{String: v, Valid: valid}
	}

	fields := batch.Schema().Fields()
	out := make([]driver.NamedValue, batch.NumCols())
	for pos := range fields {
		out[pos].Ordinal = pos + 1
		switch col := batch.Column(pos).(type) {
		case *array.Boolean:
			out[pos].Value = asBool(col.Value(index), col.IsValid(index))
		case *array.Float32:
			out[pos].Value = asFloat(float64(col.Value(index)), col.IsValid(index))
		case *array.Float64:
			out[pos].Value = asFloat(col.Value(index), col.IsValid(index))
		case *array.Int8:
			out[pos].Value = asInt(int64(col.Value(index)), col.IsValid(index))
		case *array.Int16:
			out[pos].Value = asInt(int64(col.Value(index)), col.IsValid(index))
		case *array.Int32:
			out[pos].Value = asInt(int64(col.Value(index)), col.IsValid(index))
		case *array.Int64:
			out[pos].Value = asInt(col.Value(index), col.IsValid(index))
		case *array.String:
			out[pos].Value = asString(col.Value(index), col.IsValid(index))
		case *array.LargeString:
			out[pos].Value = asString(col.Value(index), col.IsValid(index))
		default:
			unsupported := fields[pos]
			return nil, adbc.Error{
				Code: adbc.StatusNotImplemented,
				Msg:  fmt.Sprintf("[Snowflake] Unsupported bind param '%s' type %s", unsupported.Name, unsupported.Type.String()),
			}
		}
	}
	return out, nil
}

// snowflakeBindReader walks the rows of the bound parameter data (a single
// batch and/or a stream of batches) and exposes each row as a set of bind
// parameters. Next additionally runs doQuery with those parameters, which
// lets the type serve as a readerIter.
type snowflakeBindReader struct {
// doQuery executes the statement with one row of parameters and returns
// its result reader. It is only called by Next; NextParams never uses it.
doQuery func([]driver.NamedValue) (array.RecordReader, error)
// currentBatch is the batch rows are currently being read from. It is
// released once exhausted and by Release.
currentBatch arrow.Record
// nextIndex is the row of currentBatch that NextParams will convert next.
nextIndex int64
// may be nil if we bound only a batch
stream array.RecordReader
}

// Release drops the references held on the current batch and on the
// parameter stream, if any, and clears both fields so that a repeated call
// does nothing.
func (r *snowflakeBindReader) Release() {
	if batch := r.currentBatch; batch != nil {
		batch.Release()
		r.currentBatch = nil
	}
	if params := r.stream; params != nil {
		params.Release()
		r.stream = nil
	}
}

// Next fetches the next row of bind parameters and executes the query with
// it via doQuery, returning that query's result reader.
//
// Any error from NextParams is passed through unchanged; this includes
// io.EOF, which signals that every parameter row has been consumed.
func (r *snowflakeBindReader) Next() (array.RecordReader, error) {
	bound, err := r.NextParams()
	if err == nil {
		return r.doQuery(bound)
	}
	return nil, err
}

// NextParams returns the bind parameters for the next unread row, pulling a
// new batch from the stream whenever the current one is exhausted.
//
// It returns io.EOF once there is no current batch left and the stream is
// nil or finished, and returns the stream's own error if the stream stopped
// because of a failure. Conversion errors from convertArrowToNamedValue are
// returned as-is; the row cursor still advances past the offending row.
func (r *snowflakeBindReader) NextParams() ([]driver.NamedValue, error) {
	for {
		if r.currentBatch != nil && r.nextIndex < r.currentBatch.NumRows() {
			break
		}

		// The exhausted batch is always released here, whether it was bound
		// directly or pulled from the stream. Batches taken from the stream
		// get an extra Retain below so that this Release is balanced.
		if r.currentBatch != nil {
			r.currentBatch.Release()
			r.currentBatch = nil
		}

		if r.stream == nil {
			// only a single batch was bound, and it is used up
			return nil, io.EOF
		}
		if !r.stream.Next() {
			if streamErr := r.stream.Err(); streamErr != nil {
				return nil, streamErr
			}
			// no more params
			return nil, io.EOF
		}

		r.currentBatch = r.stream.Record()
		r.currentBatch.Retain()
		r.nextIndex = 0
	}

	row, err := convertArrowToNamedValue(r.currentBatch, int(r.nextIndex))
	r.nextIndex++
	return row, err
}
107 changes: 107 additions & 0 deletions go/adbc/driver/snowflake/concat_reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package snowflake

import (
"io"
"sync/atomic"

"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow/go/v17/arrow"
"github.com/apache/arrow/go/v17/arrow/array"
)

// readerIter is a source of successive RecordReaders consumed by
// concatReader.
type readerIter interface {
// Release frees whatever the iterator holds; concatReader calls it when
// its own reference count drops to zero.
Release()

// Next returns the next reader to concatenate. concatReader treats an
// io.EOF error as "no more readers" and records any other error; a nil
// reader with a nil error is also accepted and ends iteration.
Next() (array.RecordReader, error)
}

// concatReader presents the readers produced by a readerIter as one
// continuous record stream. It must be set up with Init before use.
type concatReader struct {
// refCount is the reference count; Init sets it to 1 and Release frees
// the underlying readers when it reaches zero.
refCount atomic.Int64
// readers supplies the readers to concatenate.
readers readerIter
// currentReader is the reader records are currently being taken from;
// nil once the iterator is exhausted or has failed.
currentReader array.RecordReader
// schema is captured from the first reader during Init.
schema *arrow.Schema
// err is the error reported by Err.
err error
}

// nextReader releases the current reader, if any, and replaces it with the
// next one from the iterator. At the end of iteration (io.EOF) currentReader
// is left nil; on any other error it is also left nil and the error is
// stored in r.err.
func (r *concatReader) nextReader() {
	if prev := r.currentReader; prev != nil {
		prev.Release()
		r.currentReader = nil
	}

	next, err := r.readers.Next()
	switch {
	case err == io.EOF:
		r.currentReader = nil
	case err != nil:
		r.err = err
	default:
		// May be nil
		r.currentReader = next
	}
}
// Init attaches the reader iterator, sets the reference count to one and
// eagerly opens the first reader so that its schema can be captured.
//
// If the first reader cannot be obtained, the concatReader releases itself
// (and with it the iterator) and returns the error. An iterator that yields
// no reader at all is reported as an adbc.StatusInternal error.
func (r *concatReader) Init(readers readerIter) error {
	r.readers = readers
	r.refCount.Store(1)
	r.nextReader()

	switch {
	case r.err != nil:
		r.Release()
		return r.err
	case r.currentReader == nil:
		r.Release()
		r.err = adbc.Error{
			Code: adbc.StatusInternal,
			Msg:  "[Snowflake] No data in this stream",
		}
		return r.err
	}

	r.schema = r.currentReader.Schema()
	return nil
}
// Retain adds one to the reference count. Each call should be paired with a
// later call to Release.
func (r *concatReader) Retain() {
	r.refCount.Add(1)
}
// Release subtracts one from the reference count. When the count reaches
// zero, the reader currently in use (if any) and the reader iterator are
// released.
func (r *concatReader) Release() {
	if r.refCount.Add(-1) != 0 {
		return
	}
	if active := r.currentReader; active != nil {
		active.Release()
	}
	r.readers.Release()
}
// Schema returns the schema captured from the first reader by Init. It
// panics if no schema has been set, i.e. Init was not called or did not
// succeed.
func (r *concatReader) Schema() *arrow.Schema {
	if r.schema != nil {
		return r.schema
	}
	panic("did not call concatReader.Init")
}
// Next advances to the next record, moving on to the following reader
// whenever the current one is exhausted. It returns false when there are no
// more records or when an error occurred; callers should check Err to tell
// the two apart.
//
// A reader whose Next returns false may have stopped because of a failure
// rather than because it ran out of data. That error is recorded and
// iteration stops, instead of the failure being dropped and the stream
// silently continuing with the next reader.
func (r *concatReader) Next() bool {
	// Errors are sticky: once one is recorded, never resume iteration.
	if r.err != nil {
		return false
	}
	for r.currentReader != nil && !r.currentReader.Next() {
		if err := r.currentReader.Err(); err != nil {
			// Keep currentReader in place so Release still frees it.
			r.err = err
			return false
		}
		r.nextReader()
	}
	if r.currentReader == nil || r.err != nil {
		return false
	}
	return true
}
// Record returns the record the current reader is positioned on. It returns
// nil when there is no current reader (the stream is exhausted or obtaining
// the next reader failed) rather than panicking on a nil reader.
func (r *concatReader) Record() arrow.Record {
	if r.currentReader == nil {
		return nil
	}
	return r.currentReader.Record()
}
// Err returns the error recorded by this reader, or nil if none has been
// recorded.
func (r *concatReader) Err() error {
	return r.err
}
2 changes: 1 addition & 1 deletion go/adbc/driver/snowflake/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ package snowflake

import (
"errors"
"maps"
"runtime/debug"
"strings"

"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
"github.com/apache/arrow/go/v17/arrow/memory"
"github.com/snowflakedb/gosnowflake"
"golang.org/x/exp/maps"
)

const (
Expand Down
56 changes: 53 additions & 3 deletions go/adbc/driver/snowflake/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ package snowflake

import (
"context"
"database/sql/driver"
"fmt"
"io"
"strconv"
"strings"

Expand Down Expand Up @@ -463,10 +465,26 @@ func (st *statement) ExecuteQuery(ctx context.Context) (array.RecordReader, int6
// concatenate RecordReaders which doesn't exist yet. let's put
// that off for now.
if st.streamBind != nil || st.bound != nil {
return nil, -1, adbc.Error{
Msg: "executing non-bulk ingest with bound params not yet implemented",
Code: adbc.StatusNotImplemented,
bind := snowflakeBindReader{
doQuery: func(params []driver.NamedValue) (array.RecordReader, error) {
loader, err := st.cnxn.cn.QueryArrowStream(ctx, st.query, params...)
if err != nil {
return nil, errToAdbcErr(adbc.StatusInternal, err)
}
return newRecordReader(ctx, st.alloc, loader, st.queueSize, st.prefetchConcurrency, st.useHighPrecision)
},
currentBatch: st.bound,
stream: st.streamBind,
}
st.bound = nil
st.streamBind = nil

rdr := concatReader{}
err := rdr.Init(&bind)
if err != nil {
return nil, -1, err
}
return &rdr, -1, nil
}

loader, err := st.cnxn.cn.QueryArrowStream(ctx, st.query)
Expand All @@ -493,6 +511,38 @@ func (st *statement) ExecuteUpdate(ctx context.Context) (int64, error) {
}
}

if st.streamBind != nil || st.bound != nil {
numRows := int64(0)
bind := snowflakeBindReader{
currentBatch: st.bound,
stream: st.streamBind,
}
st.bound = nil
st.streamBind = nil

defer bind.Release()
for {
params, err := bind.NextParams()
if err == io.EOF {
break
} else if err != nil {
return -1, err
}

r, err := st.cnxn.cn.ExecContext(ctx, st.query, params)
if err != nil {
return -1, errToAdbcErr(adbc.StatusInternal, err)
}
n, err := r.RowsAffected()
if err != nil {
numRows = -1
} else if numRows >= 0 {
numRows += n
}
}
return numRows, nil
}

r, err := st.cnxn.cn.ExecContext(ctx, st.query, nil)
if err != nil {
return -1, errToAdbcErr(adbc.StatusIO, err)
Expand Down
4 changes: 2 additions & 2 deletions go/adbc/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ module github.com/apache/arrow-adbc/go/adbc
go 1.21

require (
github.com/apache/arrow/go/v17 v17.0.0-20240430043840-e4f31462dbd6
github.com/apache/arrow/go/v17 v17.0.0-20240503231747-7cd9c6fbd313
github.com/bluele/gcache v0.0.2
github.com/golang/protobuf v1.5.4
github.com/google/uuid v1.6.0
Expand All @@ -31,7 +31,7 @@ require (
golang.org/x/sync v0.7.0
golang.org/x/tools v0.21.0
google.golang.org/grpc v1.63.2
google.golang.org/protobuf v1.33.0
google.golang.org/protobuf v1.34.0
)

require (
Expand Down
Loading

0 comments on commit 8f0bdb3

Please sign in to comment.