diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go index 4434d902f3..6107cd4741 100644 --- a/go/adbc/driver/snowflake/connection.go +++ b/go/adbc/driver/snowflake/connection.go @@ -61,8 +61,9 @@ type connectionImpl struct { ctor gosnowflake.Connector sqldb *sql.DB - activeTransaction bool - useHighPrecision bool + activeTransaction bool + useHighPrecision bool + quotedIdentIgnoreCase bool } // Uniquely identify a constraint based on the dbName, schema, and tblName @@ -1337,6 +1338,25 @@ func (c *connectionImpl) SetOption(key, value string) error { } } return nil + case OptionQuotedIdentifiersIgnoreCase: + switch value { + case adbc.OptionValueEnabled, adbc.OptionValueDisabled: + c.quotedIdentIgnoreCase = value == adbc.OptionValueEnabled + q := "ALTER SESSION SET QUOTED_IDENTIFIERS_IGNORE_CASE = " + value + if _, err := c.cn.ExecContext(context.Background(), q, nil); err != nil { + return errToAdbcErr(adbc.StatusInternal, err) + } + + if _, err := c.sqldb.ExecContext(context.Background(), q); err != nil { + return errToAdbcErr(adbc.StatusInternal, err) + } + default: + return adbc.Error{ + Msg: "[Snowflake] invalid value for option " + key + ": " + value, + Code: adbc.StatusInvalidArgument, + } + } + return nil default: return adbc.Error{ Msg: "[Snowflake] unknown connection option " + key + ": " + value, @@ -1344,3 +1364,20 @@ func (c *connectionImpl) SetOption(key, value string) error { } } } + +func (c *connectionImpl) GetOption(key string) (string, error) { + switch key { + case OptionUseHighPrecision: + if c.useHighPrecision { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + case OptionQuotedIdentifiersIgnoreCase: + if c.quotedIdentIgnoreCase { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + default: + return c.db.GetOption(key) + } +} diff --git a/go/adbc/driver/snowflake/driver.go b/go/adbc/driver/snowflake/driver.go index da49a6097d..1d817ab9cb 100644 --- a/go/adbc/driver/snowflake/driver.go +++ b/go/adbc/driver/snowflake/driver.go @@ -72,9 +72,17 @@ const ( // scale will return a Float64 column. OptionUseHighPrecision = "adbc.snowflake.sql.client_option.use_high_precision" - OptionApplicationName = "adbc.snowflake.sql.client_option.app_name" - OptionSSLSkipVerify = "adbc.snowflake.sql.client_option.tls_skip_verify" - OptionOCSPFailOpenMode = "adbc.snowflake.sql.client_option.ocsp_fail_open_mode" + // OptionQuotedIdentifiersIgnoreCase refers to the corresponding snowflake session + // parameter (https://docs.snowflake.com/en/sql-reference/parameters#label-quoted-identifiers-ignore-case) + // which controls whether or not the case of quoted identifiers will be preserved (default) + // or will be ignored (storing and resolving as uppercase). + // Because functionality such as bulk ingest and other options will automatically add quotes + // to identifiers by default, this option can be set to TRUE to ensure that the casing will + // be ignored for that functionality despite the fact that we wrap it in quotes. + OptionQuotedIdentifiersIgnoreCase = "adbc.snowflake.sql.quoted_identifiers_ignore_case" + OptionApplicationName = "adbc.snowflake.sql.client_option.app_name" + OptionSSLSkipVerify = "adbc.snowflake.sql.client_option.tls_skip_verify" + OptionOCSPFailOpenMode = "adbc.snowflake.sql.client_option.ocsp_fail_open_mode" // specify the token to use for OAuth or other forms of authentication OptionAuthToken = "adbc.snowflake.sql.client_option.auth_token" // specify the OKTAUrl to use for OKTA Authentication diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go index 4c6f299941..1c30c88357 100644 --- a/go/adbc/driver/snowflake/driver_test.go +++ b/go/adbc/driver/snowflake/driver_test.go @@ -44,6 +44,7 @@ import ( "github.com/apache/arrow/go/v17/arrow/memory" "github.com/google/uuid" "github.com/snowflakedb/gosnowflake" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -2031,3 +2032,122 @@ func (suite *SnowflakeTests) TestMetadataOnlyQuery() { // all the rows from each record in the stream. suite.Equal(n, recv) } + +func TestSnowflakeQuotedIdentIgnoreCase(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + sc := arrow.NewSchema([]arrow.Field{ + { + Name: "col_int64", Type: arrow.PrimitiveTypes.Int64, + Nullable: true, + }, + { + Name: "col_list", Type: arrow.ListOf(arrow.BinaryTypes.String), + Nullable: true, + }, + }, nil) + + bldr := array.NewRecordBuilder(mem, sc) + defer bldr.Release() + + bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2, 3}, nil) + + listbldr := bldr.Field(1).(*array.ListBuilder) + listvalbldr := listbldr.ValueBuilder().(*array.StringBuilder) + listbldr.Append(true) + listvalbldr.Append("one") + listbldr.Append(true) + listvalbldr.Append("two") + listbldr.Append(true) + listvalbldr.Append("three") + + rec := bldr.NewRecord() + defer rec.Release() + + expectedSchema := arrow.NewSchema([]arrow.Field{ + { + Name: "col_int64", Type: arrow.PrimitiveTypes.Int64, + Nullable: true, + }, + { + Name: "col_list", Type: arrow.BinaryTypes.String, + Nullable: true, + }, + }, nil) + + expectedRecord, _, err := array.RecordFromJSON(mem, expectedSchema, bytes.NewReader([]byte(` + [ + { + "col_int64": 1, + "col_list": "[\n \"one\"\n]" + }, + { + "col_int64": 2, + "col_list": "[\n \"two\"\n]" + }, + { + "col_int64": 3, + "col_list": "[\n \"three\"\n]" + } + ] + `))) + require.NoError(t, err) + defer expectedRecord.Release() + + withQuirks(t, func(q *SnowflakeQuirks) { + drv := q.SetupDriver(t) + opts := q.DatabaseOptions() + // initialize connection with this session parameter set so that ingest + // and DropTable will both ignore the casing for the quoted identifiers + opts[driver.OptionQuotedIdentifiersIgnoreCase] = adbc.OptionValueEnabled + + db, err := drv.NewDatabase(opts) + require.NoError(t, err) + defer db.Close() + + ctx := context.Background() + cnxn, err := db.Open(ctx) + require.NoError(t, err) + defer cnxn.Close() + + require.NoError(t, q.DropTable(cnxn, "bulk_ingest_list")) + + stmt, err := cnxn.NewStatement() + require.NoError(t, err) + defer stmt.Close() + + require.NoError(t, stmt.Bind(ctx, rec)) + require.NoError(t, stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest_list")) + n, err := stmt.ExecuteUpdate(ctx) + require.NoError(t, err) + assert.EqualValues(t, 3, n) + + // disable the quoted identifiers option to get the default behavior back + require.NoError(t, cnxn.(adbc.GetSetOptions). + SetOption(driver.OptionQuotedIdentifiersIgnoreCase, adbc.OptionValueDisabled)) + // with the option disabled this query should error because wrapping with quotes + // would preserve the case and the table wouldn't exist + require.NoError(t, stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_list" order by col_int64 ASC`)) + _, _, err = stmt.ExecuteQuery(ctx) + assert.Error(t, err) + + // confirm that our ingested table is using uppercase because of our option usage + require.NoError(t, stmt.SetSqlQuery("SELECT * FROM BULK_INGEST_LIST order by col_int64 ASC")) + + rdr, n, err := stmt.ExecuteQuery(ctx) + require.NoError(t, err) + defer rdr.Release() + + assert.EqualValues(t, 3, n) + assert.True(t, rdr.Next()) + result := rdr.Record() + assert.Truef(t, array.RecordEqual(expectedRecord, result), "expected: %s\ngot: %s", expectedRecord, result) + logicalTypeList, ok := result.Schema().Field(1).Metadata.GetValue("logicalType") + assert.True(t, ok) + assert.Equal(t, "ARRAY", logicalTypeList) + + assert.False(t, rdr.Next()) + require.NoError(t, rdr.Err()) + }) +} diff --git a/go/adbc/driver/snowflake/snowflake_database.go b/go/adbc/driver/snowflake/snowflake_database.go index 581d9733e4..ae11fcbfa1 100644 --- a/go/adbc/driver/snowflake/snowflake_database.go +++ b/go/adbc/driver/snowflake/snowflake_database.go @@ -49,6 +49,8 @@ var ( } ) +const quotedIdentifiersIgnoreCase = "QUOTED_IDENTIFIERS_IGNORE_CASE" + type databaseImpl struct { driverbase.DatabaseImplBase cfg *gosnowflake.Config @@ -130,6 +132,12 @@ func (d *databaseImpl) GetOption(key string) (string, error) { return adbc.OptionValueEnabled, nil } return adbc.OptionValueDisabled, nil + case OptionQuotedIdentifiersIgnoreCase: + v, exists := d.cfg.Params[quotedIdentifiersIgnoreCase] + if !exists { + return adbc.OptionValueDisabled, nil + } + return *v, nil default: val, ok := d.cfg.Params[key] if ok { @@ -427,6 +435,17 @@ func (d *databaseImpl) SetOptions(cnOptions map[string]string) error { Code: adbc.StatusInvalidArgument, } } + case OptionQuotedIdentifiersIgnoreCase: + switch v { + case adbc.OptionValueEnabled, adbc.OptionValueDisabled: + d.cfg.Params[quotedIdentifiersIgnoreCase] = &v + default: + return adbc.Error{ + Msg: fmt.Sprintf("Invalid value for database option '%s': '%s'", + OptionQuotedIdentifiersIgnoreCase, v), + Code: adbc.StatusInvalidArgument, + } + } default: d.cfg.Params[k] = &v } @@ -445,6 +464,13 @@ func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) { return nil, errToAdbcErr(adbc.StatusIO, err) } + var quoteIgnoreCase bool + + v, exists := d.cfg.Params[quotedIdentifiersIgnoreCase] + if exists { + quoteIgnoreCase = *v == adbc.OptionValueEnabled + } + conn := &connectionImpl{ cn: cn.(snowflakeConn), db: d, ctor: connector, @@ -452,8 +478,9 @@ func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) { // default enable high precision // SetOption(OptionUseHighPrecision, adbc.OptionValueDisabled) to // get Int64/Float64 instead - useHighPrecision: d.useHighPrecision, - ConnectionImplBase: driverbase.NewConnectionImplBase(&d.DatabaseImplBase), + useHighPrecision: d.useHighPrecision, + ConnectionImplBase: driverbase.NewConnectionImplBase(&d.DatabaseImplBase), + quotedIdentIgnoreCase: quoteIgnoreCase, } return driverbase.NewConnectionBuilder(conn). diff --git a/python/adbc_driver_snowflake/adbc_driver_snowflake/__init__.py b/python/adbc_driver_snowflake/adbc_driver_snowflake/__init__.py index b939c0b24b..aa0c1ebeae 100644 --- a/python/adbc_driver_snowflake/adbc_driver_snowflake/__init__.py +++ b/python/adbc_driver_snowflake/adbc_driver_snowflake/__init__.py @@ -84,6 +84,14 @@ class DatabaseOptions(enum.Enum): OCSP_FAIL_OPEN_MODE = "adbc.snowflake.sql.client_option.ocsp_fail_open_mode" PORT = "adbc.snowflake.sql.uri.port" PROTOCOL = "adbc.snowflake.sql.uri.protocol" + #: Control the QUOTED_IDENTIFIERS_IGNORE_CASE snowflake session parameter as + #: described by snowflake parameter docs for #label-quoted-identifiers-ignore-case + #: This defaults to false as per the Snowflake documentation. This is + #: important for managing the table names created when using bulk_ingest + #: since we will wrap any identifiers in quotes by default. Behavior is not + #: defined when mixing this with manually running ALTER SESSION queries to + #: set the variable. + QUOTED_IDENTIFIERS_IGNORE_CASE = "adbc.snowflake.sql.quoted_identifiers_ignore_case" REGION = "adbc.snowflake.sql.region" #: request retry timeout EXCLUDING network roundtrip and reading http response #: use format like http://pkg.go.dev/time#ParseDuration such as