Skip to content

Commit

Permalink
fix(go/adbc/pkg): follow CGO rules properly
Browse files Browse the repository at this point in the history
Fixes apache#729.
  • Loading branch information
lidavidm committed Jul 13, 2023
1 parent e0bc951 commit dad4bfb
Show file tree
Hide file tree
Showing 11 changed files with 173 additions and 5 deletions.
106 changes: 106 additions & 0 deletions c/driver/flightsql/sqlite_flightsql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,21 @@
// specific language governing permissions and limitations
// under the License.

#include <chrono>
#include <random>
#include <thread>

#include <adbc.h>
#include <gmock/gmock-matchers.h>
#include <gtest/gtest-matchers.h>
#include <gtest/gtest-param-test.h>
#include <gtest/gtest.h>
#include <nanoarrow/nanoarrow.h>

#include "validation/adbc_validation.h"
#include "validation/adbc_validation_util.h"

using adbc_validation::IsOkErrno;
using adbc_validation::IsOkStatus;

#define CHECK_OK(EXPR) \
Expand Down Expand Up @@ -103,6 +109,106 @@ class SqliteFlightSqlTest : public ::testing::Test, public adbc_validation::Data
};
ADBCV_TEST_DATABASE(SqliteFlightSqlTest)

TEST_F(SqliteFlightSqlTest, TestGarbageInput) {
// Regression test for https://github.com/apache/arrow-adbc/issues/729

// 0xc000000000 is the base of the Go heap. Go's write barriers ask
// the GC to mark both the pointer being written, and the pointer
// being *overwritten*. So if Go overwrites a value in a C
// structure that looks like a Go pointer, the GC may get confused
// and error.
void* bad_pointer = reinterpret_cast<void*>(uintptr_t(0xc000000240));

// ADBC functions are expected not to blindly overwrite an
// already-allocated value/callers are expected to zero-initialize.
database.private_data = bad_pointer;
database.private_driver = reinterpret_cast<struct AdbcDriver*>(bad_pointer);
ASSERT_THAT(AdbcDatabaseNew(&database, &error), ::testing::Not(IsOkStatus(&error)));

std::memset(&database, 0, sizeof(database));
ASSERT_THAT(AdbcDatabaseNew(&database, &error), IsOkStatus(&error));
ASSERT_THAT(quirks()->SetupDatabase(&database, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcDatabaseInit(&database, &error), IsOkStatus(&error));

struct AdbcConnection connection;
connection.private_data = bad_pointer;
connection.private_driver = reinterpret_cast<struct AdbcDriver*>(bad_pointer);
ASSERT_THAT(AdbcConnectionNew(&connection, &error), ::testing::Not(IsOkStatus(&error)));

std::memset(&connection, 0, sizeof(connection));
ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error));

struct AdbcStatement statement;
statement.private_data = bad_pointer;
statement.private_driver = reinterpret_cast<struct AdbcDriver*>(bad_pointer);
ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error),
::testing::Not(IsOkStatus(&error)));

// This needs to happen in parallel since we need to trigger the
// write barrier buffer, which means we need to trigger a GC. The
// Go FFI bridge deterministically triggers GC on Release calls.

auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(5);
while (std::chrono::steady_clock::now() < deadline) {
std::vector<std::thread> threads;
std::random_device rd;
for (int i = 0; i < 23; i++) {
auto seed = rd();
threads.emplace_back([&, seed]() {
std::mt19937 gen(seed);
std::uniform_int_distribution<int64_t> dist(0xc000000000L, 0xc000002000L);
for (int i = 0; i < 23; i++) {
void* bad_pointer = reinterpret_cast<void*>(uintptr_t(dist(gen)));

struct AdbcStatement statement;
std::memset(&statement, 0, sizeof(statement));
ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error),
IsOkStatus(&error));

ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 1", &error),
IsOkStatus(&error));
// This is not expected to be zero-initialized
struct ArrowArrayStream stream;
stream.private_data = bad_pointer;
stream.release =
reinterpret_cast<void (*)(struct ArrowArrayStream*)>(bad_pointer);
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &stream, nullptr, &error),
IsOkStatus(&error));

struct ArrowSchema schema;
std::memset(&schema, 0, sizeof(schema));
schema.name = reinterpret_cast<const char*>(bad_pointer);
schema.format = reinterpret_cast<const char*>(bad_pointer);
schema.private_data = bad_pointer;
ASSERT_THAT(stream.get_schema(&stream, &schema), IsOkErrno());

while (true) {
struct ArrowArray array;
array.private_data = bad_pointer;
ASSERT_THAT(stream.get_next(&stream, &array), IsOkErrno());
if (array.release) {
array.release(&array);
} else {
break;
}
}

schema.release(&schema);
stream.release(&stream);
ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
}
});
}
for (auto& thread : threads) {
thread.join();
}
}

ASSERT_THAT(AdbcConnectionRelease(&connection, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcDatabaseRelease(&database, &error), IsOkStatus(&error));
}

class SqliteFlightSqlConnectionTest : public ::testing::Test,
public adbc_validation::ConnectionTest {
public:
Expand Down
2 changes: 1 addition & 1 deletion go/adbc/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ module github.com/apache/arrow-adbc/go/adbc
go 1.18

require (
github.com/apache/arrow/go/v13 v13.0.0-20230710202504-70f447636553
github.com/apache/arrow/go/v13 v13.0.0-20230713180941-b97597765355
github.com/bluele/gcache v0.0.2
github.com/google/uuid v1.3.0
github.com/snowflakedb/gosnowflake v1.6.21
Expand Down
6 changes: 2 additions & 4 deletions go/adbc/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@ github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/
github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/apache/arrow/go/v12 v12.0.0 h1:xtZE63VWl7qLdB0JObIXvvhGjoVNrQ9ciIHG2OK5cmc=
github.com/apache/arrow/go/v12 v12.0.0/go.mod h1:d+tV/eHZZ7Dz7RPrFKtPK02tpr+c9/PEd/zm8mDS9Vg=
github.com/apache/arrow/go/v13 v13.0.0-20230620164925-94af6c3c9646 h1:hLcsUn9hiiD7jDfJDKOe1tBfOL5v0wgrya5S8XXqzLw=
github.com/apache/arrow/go/v13 v13.0.0-20230620164925-94af6c3c9646/go.mod h1:W69eByFNO0ZR30q1/7Sr9d83zcVZmF2MiP3fFYAWJOc=
github.com/apache/arrow/go/v13 v13.0.0-20230710202504-70f447636553 h1:LV3nIWJ2254APRpYAcMxWbxoQwt66gnrkZ5NaDs1IPI=
github.com/apache/arrow/go/v13 v13.0.0-20230710202504-70f447636553/go.mod h1:W69eByFNO0ZR30q1/7Sr9d83zcVZmF2MiP3fFYAWJOc=
github.com/apache/arrow/go/v13 v13.0.0-20230713180941-b97597765355 h1:QuXqLb2HzL5EjY99fFp+iG9NagAruvQIbU/2++x+2VY=
github.com/apache/arrow/go/v13 v13.0.0-20230713180941-b97597765355/go.mod h1:W69eByFNO0ZR30q1/7Sr9d83zcVZmF2MiP3fFYAWJOc=
github.com/apache/thrift v0.16.0 h1:qEy6UW60iVOlUy+b9ZR0d5WzUWYGOo4HfopoyBaNmoY=
github.com/apache/thrift v0.16.0/go.mod h1:PHK3hniurgQaNMZYaCLEqXKsYK8upmhPbmdP2FXSqgU=
github.com/aws/aws-sdk-go-v2 v1.18.0 h1:882kkTpSFhdgYRKVZ/VCgf7sd0ru57p2JCxz4/oN5RY=
Expand Down
5 changes: 5 additions & 0 deletions go/adbc/pkg/_tmpl/driver.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,11 @@ func {{.Prefix}}StatementNew(cnxn *C.struct_AdbcConnection, stmt *C.struct_AdbcS
setErr(err, "AdbcStatementNew: Go panicked, driver is in unknown state")
return C.ADBC_STATUS_INTERNAL
}
if stmt.private_data != nil {
setErr(err, "AdbcStatementNew: statement already allocated")
return C.ADBC_STATUS_INVALID_STATE
}

conn := checkConnInit(cnxn, err, "AdbcStatementNew")
if conn == nil {
return C.ADBC_STATUS_INVALID_STATE
Expand Down
11 changes: 11 additions & 0 deletions go/adbc/pkg/_tmpl/utils.c.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

#include "utils.h"

#include <string.h>

#ifdef __cplusplus
extern "C" {
#endif
Expand Down Expand Up @@ -74,6 +76,7 @@ AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection,
uint32_t* info_codes, size_t info_codes_length,
struct ArrowArrayStream* out,
struct AdbcError* error) {
if (out) memset(out, 0, sizeof(*out));
return {{.Prefix}}ConnectionGetInfo(connection, info_codes, info_codes_length, out, error);
}

Expand All @@ -83,6 +86,7 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d
const char* column_name,
struct ArrowArrayStream* out,
struct AdbcError* error) {
if (out) memset(out, 0, sizeof(*out));
return {{.Prefix}}ConnectionGetObjects(connection, depth, catalog, db_schema, table_name,
table_type, column_name, out, error);
}
Expand All @@ -92,13 +96,15 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection,
const char* table_name,
struct ArrowSchema* schema,
struct AdbcError* error) {
if (schema) memset(schema, 0, sizeof(*schema));
return {{.Prefix}}ConnectionGetTableSchema(connection, catalog, db_schema, table_name,
schema, error);
}

AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
struct ArrowArrayStream* out,
struct AdbcError* error) {
if (out) memset(out, 0, sizeof(*out));
return {{.Prefix}}ConnectionGetTableTypes(connection, out, error);
}

Expand All @@ -107,6 +113,7 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection,
size_t serialized_length,
struct ArrowArrayStream* out,
struct AdbcError* error) {
if (out) memset(out, 0, sizeof(*out));
return {{.Prefix}}ConnectionReadPartition(connection, serialized_partition,
serialized_length, out, error);
}
Expand Down Expand Up @@ -136,6 +143,7 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement,
struct ArrowArrayStream* out,
int64_t* rows_affected,
struct AdbcError* error) {
if (out) memset(out, 0, sizeof(*out));
return {{.Prefix}}StatementExecuteQuery(statement, out, rows_affected, error);
}

Expand Down Expand Up @@ -170,6 +178,7 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement,
AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement,
struct ArrowSchema* schema,
struct AdbcError* error) {
if (schema) memset(schema, 0, sizeof(*schema));
return {{.Prefix}}StatementGetParameterSchema(statement, schema, error);
}

Expand All @@ -183,6 +192,8 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement,
struct AdbcPartitions* partitions,
int64_t* rows_affected,
struct AdbcError* error) {
if (schema) memset(schema, 0, sizeof(*schema));
if (partitions) memset(partitions, 0, sizeof(*partitions));
return {{.Prefix}}StatementExecutePartitions(statement, schema, partitions, rows_affected,
error);
}
Expand Down
5 changes: 5 additions & 0 deletions go/adbc/pkg/flightsql/driver.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions go/adbc/pkg/flightsql/utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

#include "utils.h"

#include <string.h>

#ifdef __cplusplus
extern "C" {
#endif
Expand Down Expand Up @@ -76,6 +78,7 @@ AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection,
uint32_t* info_codes, size_t info_codes_length,
struct ArrowArrayStream* out,
struct AdbcError* error) {
if (out) memset(out, 0, sizeof(*out));
return FlightSQLConnectionGetInfo(connection, info_codes, info_codes_length, out,
error);
}
Expand All @@ -86,6 +89,7 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d
const char* column_name,
struct ArrowArrayStream* out,
struct AdbcError* error) {
if (out) memset(out, 0, sizeof(*out));
return FlightSQLConnectionGetObjects(connection, depth, catalog, db_schema, table_name,
table_type, column_name, out, error);
}
Expand All @@ -95,13 +99,15 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection,
const char* table_name,
struct ArrowSchema* schema,
struct AdbcError* error) {
if (schema) memset(schema, 0, sizeof(*schema));
return FlightSQLConnectionGetTableSchema(connection, catalog, db_schema, table_name,
schema, error);
}

AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
struct ArrowArrayStream* out,
struct AdbcError* error) {
if (out) memset(out, 0, sizeof(*out));
return FlightSQLConnectionGetTableTypes(connection, out, error);
}

Expand All @@ -110,6 +116,7 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection,
size_t serialized_length,
struct ArrowArrayStream* out,
struct AdbcError* error) {
if (out) memset(out, 0, sizeof(*out));
return FlightSQLConnectionReadPartition(connection, serialized_partition,
serialized_length, out, error);
}
Expand Down Expand Up @@ -139,6 +146,7 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement,
struct ArrowArrayStream* out,
int64_t* rows_affected,
struct AdbcError* error) {
if (out) memset(out, 0, sizeof(*out));
return FlightSQLStatementExecuteQuery(statement, out, rows_affected, error);
}

Expand Down Expand Up @@ -173,6 +181,7 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement,
AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement,
struct ArrowSchema* schema,
struct AdbcError* error) {
if (schema) memset(schema, 0, sizeof(*schema));
return FlightSQLStatementGetParameterSchema(statement, schema, error);
}

Expand All @@ -186,6 +195,8 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement,
struct AdbcPartitions* partitions,
int64_t* rows_affected,
struct AdbcError* error) {
if (schema) memset(schema, 0, sizeof(*schema));
if (partitions) memset(partitions, 0, sizeof(*partitions));
return FlightSQLStatementExecutePartitions(statement, schema, partitions, rows_affected,
error);
}
Expand Down
5 changes: 5 additions & 0 deletions go/adbc/pkg/panicdummy/driver.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit dad4bfb

Please sign in to comment.