Skip to content

Commit

Permalink
fix(go/adbc/pkg): follow CGO rules properly
Browse files Browse the repository at this point in the history
Fixes apache#729.
  • Loading branch information
lidavidm committed Jul 13, 2023
1 parent e0bc951 commit 46fd728
Show file tree
Hide file tree
Showing 13 changed files with 176 additions and 7 deletions.
103 changes: 103 additions & 0 deletions c/validation/adbc_validation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@

#include <algorithm>
#include <cerrno>
#include <chrono>
#include <cstring>
#include <limits>
#include <optional>
#include <random>
#include <string>
#include <string_view>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -145,6 +148,106 @@ void DatabaseTest::TestRelease() {
ASSERT_EQ(nullptr, database.private_data);
}

void DatabaseTest::TestGarbageInput() {
// Regression test for https://github.com/apache/arrow-adbc/issues/729

// 0xc000000000 is the base of the Go heap. Go's write barriers ask
// the GC to mark both the pointer being written, and the pointer
// being *overwritten*. So if Go overwrites a value in a C
// structure that looks like a Go pointer, the GC may get confused
// and error.
void* bad_pointer = reinterpret_cast<void*>(uintptr_t(0xc000000240));

// ADBC functions are expected not to blindly overwrite an
// already-allocated value/callers are expected to zero-initialize.
database.private_data = bad_pointer;
database.private_driver = reinterpret_cast<struct AdbcDriver*>(bad_pointer);
ASSERT_THAT(AdbcDatabaseNew(&database, &error), ::testing::Not(IsOkStatus(&error)));

std::memset(&database, 0, sizeof(database));
ASSERT_THAT(AdbcDatabaseNew(&database, &error), IsOkStatus(&error));
ASSERT_THAT(quirks()->SetupDatabase(&database, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcDatabaseInit(&database, &error), IsOkStatus(&error));

struct AdbcConnection connection;
connection.private_data = bad_pointer;
connection.private_driver = reinterpret_cast<struct AdbcDriver*>(bad_pointer);
ASSERT_THAT(AdbcConnectionNew(&connection, &error), ::testing::Not(IsOkStatus(&error)));

std::memset(&connection, 0, sizeof(connection));
ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error));

struct AdbcStatement statement;
statement.private_data = bad_pointer;
statement.private_driver = reinterpret_cast<struct AdbcDriver*>(bad_pointer);
ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error),
::testing::Not(IsOkStatus(&error)));

// This needs to happen in parallel since we need to trigger the
// write barrier buffer, which means we need to trigger a GC. The
// Go FFI bridge deterministically triggers GC on Release calls.

auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(5);
while (std::chrono::steady_clock::now() < deadline) {
std::vector<std::thread> threads;
std::random_device rd;
for (int i = 0; i < 23; i++) {
auto seed = rd();
threads.emplace_back([&, seed]() {
std::mt19937 gen(seed);
std::uniform_int_distribution<long> dist(0xc000000000L, 0xc000002000L);
for (int i = 0; i < 23; i++) {
void* bad_pointer = reinterpret_cast<void*>(uintptr_t(dist(gen)));

struct AdbcStatement statement;
std::memset(&statement, 0, sizeof(statement));
ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error),
IsOkStatus(&error));

ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 1", &error),
IsOkStatus(&error));
// This is not expected to be zero-initialized
struct ArrowArrayStream stream;
stream.private_data = bad_pointer;
stream.release =
reinterpret_cast<void (*)(struct ArrowArrayStream*)>(bad_pointer);
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &stream, nullptr, &error),
IsOkStatus(&error));

struct ArrowSchema schema;
std::memset(&schema, 0, sizeof(schema));
schema.name = reinterpret_cast<const char*>(bad_pointer);
schema.format = reinterpret_cast<const char*>(bad_pointer);
schema.private_data = bad_pointer;
ASSERT_THAT(stream.get_schema(&stream, &schema), IsOkErrno());

while (true) {
struct ArrowArray array;
array.private_data = bad_pointer;
ASSERT_THAT(stream.get_next(&stream, &array), IsOkErrno());
if (array.release) {
array.release(&array);
} else {
break;
}
}

schema.release(&schema);
stream.release(&stream);
ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
}
});
}
for (auto& thread : threads) {
thread.join();
}
}

ASSERT_THAT(AdbcConnectionRelease(&connection, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcDatabaseRelease(&database, &error), IsOkStatus(&error));
}

//------------------------------------------------------------
// Tests of AdbcConnection

Expand Down
4 changes: 3 additions & 1 deletion c/validation/adbc_validation.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class DatabaseTest {
// Test methods
void TestNewInit();
void TestRelease();
void TestGarbageInput();

protected:
struct AdbcError error;
Expand All @@ -140,7 +141,8 @@ class DatabaseTest {
static_assert(std::is_base_of<adbc_validation::DatabaseTest, FIXTURE>::value, \
ADBCV_STRINGIFY(FIXTURE) " must inherit from DatabaseTest"); \
TEST_F(FIXTURE, NewInit) { TestNewInit(); } \
TEST_F(FIXTURE, Release) { TestRelease(); }
TEST_F(FIXTURE, Release) { TestRelease(); } \
TEST_F(FIXTURE, GarbageInput) { TestGarbageInput(); }

class ConnectionTest {
public:
Expand Down
5 changes: 5 additions & 0 deletions go/adbc/pkg/_tmpl/driver.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,11 @@ func {{.Prefix}}StatementNew(cnxn *C.struct_AdbcConnection, stmt *C.struct_AdbcS
setErr(err, "AdbcStatementNew: Go panicked, driver is in unknown state")
return C.ADBC_STATUS_INTERNAL
}
if stmt.private_data != nil {
setErr(err, "AdbcStatementNew: statement already allocated")
return C.ADBC_STATUS_INVALID_STATE
}

conn := checkConnInit(cnxn, err, "AdbcStatementNew")
if conn == nil {
return C.ADBC_STATUS_INVALID_STATE
Expand Down
11 changes: 11 additions & 0 deletions go/adbc/pkg/_tmpl/utils.c.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

#include "utils.h"

#include <string.h>

#ifdef __cplusplus
extern "C" {
#endif
Expand Down Expand Up @@ -74,6 +76,7 @@ AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection,
uint32_t* info_codes, size_t info_codes_length,
struct ArrowArrayStream* out,
struct AdbcError* error) {
if (out) memset(out, 0, sizeof(*out));
return {{.Prefix}}ConnectionGetInfo(connection, info_codes, info_codes_length, out, error);
}

Expand All @@ -83,6 +86,7 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d
const char* column_name,
struct ArrowArrayStream* out,
struct AdbcError* error) {
if (out) memset(out, 0, sizeof(*out));
return {{.Prefix}}ConnectionGetObjects(connection, depth, catalog, db_schema, table_name,
table_type, column_name, out, error);
}
Expand All @@ -92,13 +96,15 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection,
const char* table_name,
struct ArrowSchema* schema,
struct AdbcError* error) {
if (schema) memset(schema, 0, sizeof(*schema));
return {{.Prefix}}ConnectionGetTableSchema(connection, catalog, db_schema, table_name,
schema, error);
}

AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
struct ArrowArrayStream* out,
struct AdbcError* error) {
if (out) memset(out, 0, sizeof(*out));
return {{.Prefix}}ConnectionGetTableTypes(connection, out, error);
}

Expand All @@ -107,6 +113,7 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection,
size_t serialized_length,
struct ArrowArrayStream* out,
struct AdbcError* error) {
if (out) memset(out, 0, sizeof(*out));
return {{.Prefix}}ConnectionReadPartition(connection, serialized_partition,
serialized_length, out, error);
}
Expand Down Expand Up @@ -136,6 +143,7 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement,
struct ArrowArrayStream* out,
int64_t* rows_affected,
struct AdbcError* error) {
if (out) memset(out, 0, sizeof(*out));
return {{.Prefix}}StatementExecuteQuery(statement, out, rows_affected, error);
}

Expand Down Expand Up @@ -170,6 +178,7 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement,
AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement,
struct ArrowSchema* schema,
struct AdbcError* error) {
if (schema) memset(schema, 0, sizeof(*schema));
return {{.Prefix}}StatementGetParameterSchema(statement, schema, error);
}

Expand All @@ -183,6 +192,8 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement,
struct AdbcPartitions* partitions,
int64_t* rows_affected,
struct AdbcError* error) {
if (schema) memset(schema, 0, sizeof(*schema));
if (partitions) memset(partitions, 0, sizeof(*partitions));
return {{.Prefix}}StatementExecutePartitions(statement, schema, partitions, rows_affected,
error);
}
Expand Down
5 changes: 5 additions & 0 deletions go/adbc/pkg/flightsql/driver.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 12 additions & 1 deletion go/adbc/pkg/flightsql/utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@

// clang-format off
//go:build driverlib
// clang-format on
// clang-format on

#include "utils.h"

#include <string.h>

#ifdef __cplusplus
extern "C" {
#endif
Expand Down Expand Up @@ -76,6 +78,7 @@ AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection,
uint32_t* info_codes, size_t info_codes_length,
struct ArrowArrayStream* out,
struct AdbcError* error) {
if (out) memset(out, 0, sizeof(*out));
return FlightSQLConnectionGetInfo(connection, info_codes, info_codes_length, out,
error);
}
Expand All @@ -86,6 +89,7 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d
const char* column_name,
struct ArrowArrayStream* out,
struct AdbcError* error) {
if (out) memset(out, 0, sizeof(*out));
return FlightSQLConnectionGetObjects(connection, depth, catalog, db_schema, table_name,
table_type, column_name, out, error);
}
Expand All @@ -95,13 +99,15 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection,
const char* table_name,
struct ArrowSchema* schema,
struct AdbcError* error) {
if (schema) memset(schema, 0, sizeof(*schema));
return FlightSQLConnectionGetTableSchema(connection, catalog, db_schema, table_name,
schema, error);
}

AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
struct ArrowArrayStream* out,
struct AdbcError* error) {
if (out) memset(out, 0, sizeof(*out));
return FlightSQLConnectionGetTableTypes(connection, out, error);
}

Expand All @@ -110,6 +116,7 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection,
size_t serialized_length,
struct ArrowArrayStream* out,
struct AdbcError* error) {
if (out) memset(out, 0, sizeof(*out));
return FlightSQLConnectionReadPartition(connection, serialized_partition,
serialized_length, out, error);
}
Expand Down Expand Up @@ -139,6 +146,7 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement,
struct ArrowArrayStream* out,
int64_t* rows_affected,
struct AdbcError* error) {
if (out) memset(out, 0, sizeof(*out));
return FlightSQLStatementExecuteQuery(statement, out, rows_affected, error);
}

Expand Down Expand Up @@ -173,6 +181,7 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement,
AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement,
struct ArrowSchema* schema,
struct AdbcError* error) {
if (schema) memset(schema, 0, sizeof(*schema));
return FlightSQLStatementGetParameterSchema(statement, schema, error);
}

Expand All @@ -186,6 +195,8 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement,
struct AdbcPartitions* partitions,
int64_t* rows_affected,
struct AdbcError* error) {
if (schema) memset(schema, 0, sizeof(*schema));
if (partitions) memset(partitions, 0, sizeof(*partitions));
return FlightSQLStatementExecutePartitions(statement, schema, partitions, rows_affected,
error);
}
Expand Down
2 changes: 1 addition & 1 deletion go/adbc/pkg/flightsql/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

// clang-format off
//go:build driverlib
// clang-format on
// clang-format on

#pragma once

Expand Down
5 changes: 5 additions & 0 deletions go/adbc/pkg/panicdummy/driver.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 46fd728

Please sign in to comment.