Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(go/adbc/driver/snowflake): Removing SQL injection to get table name with special character for getObjectsTables #1338

84 changes: 77 additions & 7 deletions csharp/test/Drivers/Snowflake/DriverTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/*
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
Expand All @@ -16,6 +16,7 @@
*/

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Apache.Arrow.Adbc.Tests.Metadata;
Expand Down Expand Up @@ -110,10 +111,7 @@ public void CanExecuteUpdate()
for (int i = 0; i < queries.Length; i++)
{
string query = queries[i];
using AdbcStatement statement = _connection.CreateStatement();
statement.SqlQuery = query;

UpdateResult updateResult = statement.ExecuteUpdate();
UpdateResult updateResult = ExecuteUpdateStatement(query);

Assert.Equal(expectedResults[i], updateResult.AffectedRows);
}
Expand Down Expand Up @@ -279,17 +277,64 @@ public void CanGetObjectsAll()
{
IEnumerable<AdbcColumn> highPrecisionColumns = columns.Where(c => c.XdbcTypeName == "NUMBER");

if(highPrecisionColumns.Count() > 0)
if (highPrecisionColumns.Count() > 0)
{
// ensure they all are coming back as XdbcDataType_XDBC_DECIMAL because they are Decimal128
short XdbcDataType_XDBC_DECIMAL = 3;
IEnumerable<AdbcColumn> invalidHighPrecisionColumns = highPrecisionColumns.Where(c => c.XdbcSqlDataType != XdbcDataType_XDBC_DECIMAL);
IEnumerable<AdbcColumn> invalidHighPrecisionColumns = highPrecisionColumns.Where(c => c.XdbcSqlDataType != XdbcDataType_XDBC_DECIMAL);
int count = invalidHighPrecisionColumns.Count();
Assert.True(count == 0, $"There are {count} columns that do not map to the correct XdbcSqlDataType when UseHighPrecision=true");
}
}
}

/// <summary>
/// Validates if the driver can call GetObjects with GetObjectsDepth as Tables with TableName as a Special Character.
/// </summary>
[SkippableTheory, Order(3)]
[InlineData(@"DEMO_DB",@"PUBLIC","MyIdentifier")]
[InlineData(@"DEMO_DB", @"PUBLIC_SCHEMA","my.identifier")]
[InlineData(@"DEMO_DB", @"PUBLIC", "my identifier")]
[InlineData(@"DEMO_DB", @"PUBLIC", "My 'Identifier'")]
[InlineData(@"DEMO_DB", @"PUBLIC", "3rd_identifier")]
[InlineData(@"DEMO_DB", @"PUBLIC", "$Identifier")]
[InlineData(@"DEMO_DB", @"PUBLIC", "My ^Identifier")]
[InlineData(@"DEMO_DB", @"PUBLIC", "My ^Ident~ifier")]
[InlineData(@"DEMO_DB", @"PUBLIC", @"My\^Ident~ifier")]
[InlineData(@"DEMO_DB", @"PUBLIC", "идентификатор")]
[InlineData(@"DEMO_DB", @"PUBLIC", @"ADBCTest_""ALL""TYPES")]
[InlineData(@"DEMO_DB", @"PUBLIC", @"ADBC\TEST""\TAB_""LE")]
[InlineData(@"DEMO_DB", @"PUBLIC", "ONE")]
public void CanGetObjectsTablesWithSpecialCharacter(string databaseName, string schemaName, string tableName)
{
CreateTemporaryTable(databaseName, schemaName, tableName);

using IArrowArrayStream stream = _connection.GetObjects(
depth: AdbcConnection.GetObjectsDepth.Tables,
catalogPattern: databaseName,
dbSchemaPattern: schemaName,
tableNamePattern: tableName,
tableTypes: new List<string> { "BASE TABLE", "VIEW" },
columnNamePattern: null);

using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result;

List<AdbcCatalog> catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName);

List<AdbcTable> tables = catalogs
.Where(c => string.Equals(c.Name, databaseName))
.Select(c => c.DbSchemas)
.FirstOrDefault()
.Where(s => string.Equals(s.Name, schemaName))
.Select(s => s.Tables)
.FirstOrDefault();

AdbcTable table = tables.FirstOrDefault();

Assert.True(table != null, "table should not be null");
Assert.Equal(tableName, table.Name, true);
}

/// <summary>
/// Validates if the driver can call GetTableSchema.
/// </summary>
Expand Down Expand Up @@ -354,6 +399,31 @@ public void CanExecuteQuery()
Tests.DriverTests.CanExecuteQuery(queryResult, _testConfiguration.ExpectedResultsCount);
}

private void CreateTemporaryTable(string databaseName, string schemaName, string tableName)
{
databaseName = databaseName.Replace("\"", "\"\"");
string createDatabase = string.Format("CREATE DATABASE IF NOT EXISTS \"{0}\"", databaseName);
ExecuteUpdateStatement(createDatabase);

schemaName = schemaName.Replace("\"", "\"\"");
string createSchema = string.Format("CREATE SCHEMA IF NOT EXISTS \"{0}\".\"{1}\"", databaseName, schemaName);
ExecuteUpdateStatement(createSchema);

tableName = tableName.Replace("\"", "\"\"");
string fullyQualifiedTableName = string.Format("\"{0}\".\"{1}\".\"{2}\"", databaseName, schemaName, tableName);
string createTableStatement = string.Format("CREATE OR REPLACE TABLE {0} (INDEX INT)", fullyQualifiedTableName);
ExecuteUpdateStatement(createTableStatement);

}

private UpdateResult ExecuteUpdateStatement(string query)
{
using AdbcStatement statement = _connection.CreateStatement();
statement.SqlQuery = query;
UpdateResult updateResult = statement.ExecuteUpdate();
return updateResult;
}

private static string GetPartialNameForPatternMatch(string name)
{
if (string.IsNullOrEmpty(name) || name.Length == 1) return name;
Expand Down
26 changes: 17 additions & 9 deletions go/adbc/driver/snowflake/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,14 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth,
}

conditions := make([]string, 0)
queryArgs := make([]interface{}, 0, 2)
AnithaPanduranganMS marked this conversation as resolved.
Show resolved Hide resolved
if catalog != nil && *catalog != "" {
conditions = append(conditions, ` CATALOG_NAME ILIKE '`+*catalog+`'`)
conditions = append(conditions, ` CATALOG_NAME ILIKE ? `)
AnithaPanduranganMS marked this conversation as resolved.
Show resolved Hide resolved
queryArgs = append(queryArgs, *catalog)
}
if dbSchema != nil && *dbSchema != "" {
conditions = append(conditions, ` SCHEMA_NAME ILIKE '`+*dbSchema+`'`)
conditions = append(conditions, ` SCHEMA_NAME ILIKE ? `)
queryArgs = append(queryArgs, *dbSchema)
}

cond := strings.Join(conditions, " AND ")
Expand All @@ -297,7 +300,7 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth,
query += " WHERE " + cond
}
var rows *sql.Rows
rows, err = c.sqldb.QueryContext(ctx, query)
rows, err = c.sqldb.QueryContext(ctx, query, queryArgs...)
if err != nil {
err = errToAdbcErr(adbc.StatusIO, err)
return
Expand Down Expand Up @@ -486,14 +489,18 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat
includeSchema := depth == adbc.ObjectDepthAll || depth == adbc.ObjectDepthColumns

conditions := make([]string, 0)
queryArgs := make([]interface{}, 0, 3)
if catalog != nil && *catalog != "" {
conditions = append(conditions, ` TABLE_CATALOG ILIKE '`+*catalog+`'`)
conditions = append(conditions, ` TABLE_CATALOG ILIKE ? `)
queryArgs = append(queryArgs, *catalog)
}
if dbSchema != nil && *dbSchema != "" {
conditions = append(conditions, ` TABLE_SCHEMA ILIKE '`+*dbSchema+`'`)
conditions = append(conditions, ` TABLE_SCHEMA ILIKE ? `)
queryArgs = append(queryArgs, *dbSchema)
}
if tableName != nil && *tableName != "" {
conditions = append(conditions, ` TABLE_NAME ILIKE '`+*tableName+`'`)
conditions = append(conditions, ` TABLE_NAME ILIKE ? `)
queryArgs = append(queryArgs, *tableName)
}

// first populate the tables and table types
Expand All @@ -510,7 +517,8 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat
if cond != "" {
query += " WHERE " + cond
}
rows, err = c.sqldb.QueryContext(ctx, query)

rows, err = c.sqldb.QueryContext(ctx, query, queryArgs...)
if err != nil {
err = errToAdbcErr(adbc.StatusIO, err)
return
Expand Down Expand Up @@ -828,7 +836,7 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st
if dbSchema != nil {
tblParts = append(tblParts, strconv.Quote(strings.ToUpper(*dbSchema)))
}
tblParts = append(tblParts, strconv.Quote(strings.ToUpper(tableName)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should string.ToUpper conversion be avoided for catalog and schema name too?

tblParts = append(tblParts, strconv.Quote(tableName))
fullyQualifiedTable := strings.Join(tblParts, ".")

rows, err := c.sqldb.QueryContext(ctx, `DESC TABLE `+fullyQualifiedTable)
Expand Down Expand Up @@ -1038,4 +1046,4 @@ func (c *cnxn) SetOptionDouble(key string, value float64) error {
Msg: "[Snowflake] unknown connection option",
Code: adbc.StatusNotImplemented,
}
}
}
AnithaPanduranganMS marked this conversation as resolved.
Show resolved Hide resolved
Loading