Skip to content

Commit

Permalink
refactor(c/driver/postgresql): Use copy writer in BindStream for para…
Browse files Browse the repository at this point in the history
…meter binding (apache#2157)

This PR refactors the `BindStream` to use the COPY writer instead of its
own serialization logic. This logic is the same between insertion and
binding, although the API used to send it to the database is slightly
different. This is a helpful consolidation of logic and means that
adding type support on the Arrow or Postgresql side (or fixing bugs in
the inference or serialization) is slightly easier. It also means we can
bind parameters that are lists and is a workaround for inserting into a
table with a schema where Postgres knows how to cast the types (but ADBC
might not yet).

``` r
library(adbcdrivermanager)
library(nanoarrow)

con <- adbc_database_init(
  adbcpostgresql::adbcpostgresql(),
  uri = "postgresql://localhost:5432/postgres?user=postgres&password=password"
) |> 
  adbc_connection_init()

# Create an array with a uint32 column
df <- tibble::tibble(uint32_col = 1:5)
array <- df |> 
  nanoarrow::as_nanoarrow_array(
    schema = na_struct(list(uint32_col = na_uint32()))
  )

# Create a table with an integer column
con |> execute_adbc("DROP TABLE IF EXISTS adbc_test")
con |> execute_adbc("CREATE TABLE adbc_test (uint32_col int4)")

# This will fail (types not identical)
array |> write_adbc(con, "adbc_test", mode = "append")
#> Error in adbc_statement_execute_query(stmt): INVALID_ARGUMENT: [libpq] Failed to execute COPY statement: PGRES_FATAL_ERROR ERROR:  incorrect binary data format
#> CONTEXT:  COPY adbc_test, line 1, column uint32_col

con |> 
  execute_adbc("INSERT INTO adbc_test VALUES ($1)", bind = array)
con |> 
  read_adbc("SELECT * FROM adbc_test") |> 
  tibble::as_tibble()
#> # A tibble: 5 × 1
#>   uint32_col
#>        <int>
#> 1          1
#> 2          2
#> 3          3
#> 4          4
#> 5          5
```

<sup>Created on 2024-09-12 with [reprex
v2.1.1](https://reprex.tidyverse.org)</sup>
  • Loading branch information
paleolimbot authored Sep 20, 2024
1 parent fa8ea10 commit 46dc748
Show file tree
Hide file tree
Showing 9 changed files with 221 additions and 365 deletions.
418 changes: 102 additions & 316 deletions c/driver/postgresql/bind_stream.h

Large diffs are not rendered by default.

18 changes: 10 additions & 8 deletions c/driver/postgresql/postgresql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1139,10 +1139,11 @@ TEST_F(PostgresStatementTest, SqlIngestTimestampOverflow) {
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementPrepare(&statement, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error));
ASSERT_THAT(error.message,
::testing::HasSubstr("Row #1 has value '9223372036854775807' which "
"exceeds PostgreSQL timestamp limits"));
IsStatus(ADBC_STATUS_INTERNAL, &error));
ASSERT_THAT(
error.message,
::testing::HasSubstr(
"Row 0 timestamp value 9223372036854775807 with unit 0 would overflow"));
}

{
Expand All @@ -1169,10 +1170,11 @@ TEST_F(PostgresStatementTest, SqlIngestTimestampOverflow) {
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementPrepare(&statement, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error));
ASSERT_THAT(error.message,
::testing::HasSubstr("Row #1 has value '-9223372036854775808' which "
"exceeds PostgreSQL timestamp limits"));
IsStatus(ADBC_STATUS_INTERNAL, &error));
ASSERT_THAT(
error.message,
::testing::HasSubstr(
"Row 0 timestamp value -9223372036854775808 with unit 0 would overflow"));
}
}

Expand Down
4 changes: 2 additions & 2 deletions c/driver/postgresql/result_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ AdbcStatusCode PqResultArrayReader::Initialize(int64_t* rows_affected,
// there is a result with more than zero rows to populate.
if (bind_stream_) {
RAISE_ADBC(bind_stream_->Begin([] { return ADBC_STATUS_OK; }, error));
RAISE_ADBC(bind_stream_->SetParamTypes(*type_resolver_, error));
RAISE_ADBC(bind_stream_->SetParamTypes(conn_, *type_resolver_, autocommit_, error));
RAISE_ADBC(helper_.Prepare(bind_stream_->param_types, error));

RAISE_ADBC(BindNextAndExecute(nullptr, error));
Expand Down Expand Up @@ -251,7 +251,7 @@ AdbcStatusCode PqResultArrayReader::ExecuteAll(int64_t* affected_rows, AdbcError
// stream (if there is one) or execute the query without binding.
if (bind_stream_) {
RAISE_ADBC(bind_stream_->Begin([] { return ADBC_STATUS_OK; }, error));
RAISE_ADBC(bind_stream_->SetParamTypes(*type_resolver_, error));
RAISE_ADBC(bind_stream_->SetParamTypes(conn_, *type_resolver_, autocommit_, error));
RAISE_ADBC(helper_.Prepare(bind_stream_->param_types, error));

// Reset affected rows to zero before binding and executing any
Expand Down
15 changes: 14 additions & 1 deletion c/driver/postgresql/result_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

#pragma once

#if !defined(NOMINMAX)
#define NOMINMAX
#endif

#include <memory>
#include <string>
#include <utility>
Expand All @@ -34,13 +38,21 @@ class PqResultArrayReader {
public:
PqResultArrayReader(PGconn* conn, std::shared_ptr<PostgresTypeResolver> type_resolver,
std::string query)
: conn_(conn), helper_(conn, std::move(query)), type_resolver_(type_resolver) {
: conn_(conn),
helper_(conn, std::move(query)),
type_resolver_(type_resolver),
autocommit_(false) {
ArrowErrorInit(&na_error_);
error_ = ADBC_ERROR_INIT;
}

~PqResultArrayReader() { ResetErrors(); }

// Ensure the reader knows what the autocommit status was on creation. This is used
// so that the temporary timezone setting required for parameter binding can be wrapped
// in a transaction (or not) accordingly.
void SetAutocommit(bool autocommit) { autocommit_ = autocommit; }

void SetBind(struct ArrowArrayStream* stream) {
bind_stream_ = std::make_unique<BindStream>();
bind_stream_->SetBind(stream);
Expand All @@ -62,6 +74,7 @@ class PqResultArrayReader {
std::shared_ptr<PostgresTypeResolver> type_resolver_;
std::vector<std::unique_ptr<PostgresCopyFieldReader>> field_readers_;
nanoarrow::UniqueSchema schema_;
bool autocommit_;
struct AdbcError error_;
struct ArrowError na_error_;

Expand Down
1 change: 1 addition & 0 deletions c/driver/postgresql/statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ AdbcStatusCode PostgresStatement::ExecuteBind(struct ArrowArrayStream* stream,
int64_t* rows_affected,
struct AdbcError* error) {
PqResultArrayReader reader(connection_->conn(), type_resolver_, query_);
reader.SetAutocommit(connection_->autocommit());
reader.SetBind(&bind_);
RAISE_ADBC(reader.ToArrayStream(rows_affected, stream, error));
return ADBC_STATUS_OK;
Expand Down
14 changes: 10 additions & 4 deletions c/driver/sqlite/sqlite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,8 @@ class SqliteStatement : public driver::Statement<SqliteStatement> {
"parameter count mismatch: expected {} but found {}", expected, actual);
}

int64_t rows = 0;
int64_t output_rows = 0;
int64_t changed_rows = 0;

SqliteMutexGuard guard(conn_);

Expand All @@ -1027,7 +1028,11 @@ class SqliteStatement : public driver::Statement<SqliteStatement> {
}

while (sqlite3_step(stmt_) == SQLITE_ROW) {
rows++;
output_rows++;
}

if (sqlite3_column_count(stmt_) == 0) {
changed_rows += sqlite3_changes(conn_);
}

if (!binder_.schema.release) break;
Expand All @@ -1041,9 +1046,10 @@ class SqliteStatement : public driver::Statement<SqliteStatement> {
}

if (sqlite3_column_count(stmt_) == 0) {
rows = sqlite3_changes(conn_);
return changed_rows;
} else {
return output_rows;
}
return rows;
}

Result<int64_t> ExecuteUpdateImpl(PreparedState& state) { return ExecuteUpdateImpl(); }
Expand Down
5 changes: 4 additions & 1 deletion c/validation/adbc_validation.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ class StatementTest {
void TestNewInit();
void TestRelease();

// ---- Type-specific tests --------------------
// ---- Type-specific ingest tests -------------

void TestSqlIngestBool();

Expand Down Expand Up @@ -427,6 +427,8 @@ class StatementTest {
void TestSqlPrepareErrorNoQuery();
void TestSqlPrepareErrorParamCountMismatch();

void TestSqlBind();

void TestSqlQueryEmpty();
void TestSqlQueryInts();
void TestSqlQueryFloats();
Expand Down Expand Up @@ -533,6 +535,7 @@ class StatementTest {
TEST_F(FIXTURE, SqlPrepareErrorParamCountMismatch) { \
TestSqlPrepareErrorParamCountMismatch(); \
} \
TEST_F(FIXTURE, SqlBind) { TestSqlBind(); } \
TEST_F(FIXTURE, SqlQueryEmpty) { TestSqlQueryEmpty(); } \
TEST_F(FIXTURE, SqlQueryInts) { TestSqlQueryInts(); } \
TEST_F(FIXTURE, SqlQueryFloats) { TestSqlQueryFloats(); } \
Expand Down
65 changes: 65 additions & 0 deletions c/validation/adbc_validation_statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2136,6 +2136,71 @@ void StatementTest::TestSqlPrepareErrorParamCountMismatch() {
::testing::Not(IsOkStatus(&error)));
}

void StatementTest::TestSqlBind() {
if (!quirks()->supports_dynamic_parameter_binding()) {
GTEST_SKIP();
}

ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));

ASSERT_THAT(quirks()->DropTable(&connection, "bindtest", &error), IsOkStatus(&error));

ASSERT_THAT(AdbcStatementSetSqlQuery(
&statement, "CREATE TABLE bindtest (col1 INTEGER, col2 TEXT)", &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
IsOkStatus(&error));

Handle<struct ArrowSchema> schema;
Handle<struct ArrowArray> array;
struct ArrowError na_error;
ASSERT_THAT(MakeSchema(&schema.value,
{{"", NANOARROW_TYPE_INT32}, {"", NANOARROW_TYPE_STRING}}),
IsOkErrno());

std::vector<std::optional<int32_t>> int_values{std::nullopt, -123, 123};
std::vector<std::optional<std::string>> string_values{"abc", std::nullopt, "defg"};

int batch_result = MakeBatch<int32_t, std::string>(
&schema.value, &array.value, &na_error, int_values, string_values);
ASSERT_THAT(batch_result, IsOkErrno());

auto insert_query = std::string("INSERT INTO bindtest VALUES (") +
quirks()->BindParameter(0) + ", " + quirks()->BindParameter(1) +
")";
ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, insert_query.c_str(), &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementPrepare(&statement, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error),
IsOkStatus(&error));
int64_t rows_affected = -10;
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error),
IsOkStatus(&error));
ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(-1), ::testing::Eq(3)));

ASSERT_THAT(
AdbcStatementSetSqlQuery(
&statement, "SELECT * FROM bindtest ORDER BY \"col1\" ASC NULLS FIRST", &error),
IsOkStatus(&error));
{
StreamReader reader;
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
&reader.rows_affected, &error),
IsOkStatus(&error));
ASSERT_THAT(reader.rows_affected,
::testing::AnyOf(::testing::Eq(0), ::testing::Eq(-1)));

ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
ASSERT_NO_FATAL_FAILURE(reader.Next());
ASSERT_EQ(reader.array->length, 3);
CompareArray(reader.array_view->children[0], int_values);
CompareArray(reader.array_view->children[1], string_values);

ASSERT_NO_FATAL_FAILURE(reader.Next());
ASSERT_EQ(reader.array->release, nullptr);
}
}

void StatementTest::TestSqlQueryEmpty() {
ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));

Expand Down
46 changes: 13 additions & 33 deletions c/validation/adbc_validation_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,42 +401,22 @@ void CompareArray(struct ArrowArrayView* array,
SCOPED_TRACE("Array index " + std::to_string(i));
if (v.has_value()) {
ASSERT_FALSE(ArrowArrayViewIsNull(array, i));
if constexpr (std::is_same<T, float>::value) {
if constexpr (std::is_same<T, float>::value || std::is_same<T, double>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
ASSERT_EQ(*v, array->buffer_views[1].data.as_float[i]);
} else if constexpr (std::is_same<T, double>::value) {
ASSERT_EQ(ArrowArrayViewGetDoubleUnsafe(array, i), *v);
} else if constexpr (std::is_same<T, bool>::value ||
std::is_same<T, int8_t>::value ||
std::is_same<T, int16_t>::value ||
std::is_same<T, int32_t>::value ||
std::is_same<T, int64_t>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
ASSERT_EQ(*v, array->buffer_views[1].data.as_double[i]);
} else if constexpr (std::is_same<T, float>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
ASSERT_EQ(*v, array->buffer_views[1].data.as_float[i]);
} else if constexpr (std::is_same<T, bool>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
ASSERT_EQ(*v, ArrowBitGet(array->buffer_views[1].data.as_uint8, i));
} else if constexpr (std::is_same<T, int8_t>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
ASSERT_EQ(*v, array->buffer_views[1].data.as_int8[i]);
} else if constexpr (std::is_same<T, int16_t>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
ASSERT_EQ(*v, array->buffer_views[1].data.as_int16[i]);
} else if constexpr (std::is_same<T, int32_t>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
ASSERT_EQ(*v, array->buffer_views[1].data.as_int32[i]);
} else if constexpr (std::is_same<T, int64_t>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
ASSERT_EQ(*v, array->buffer_views[1].data.as_int64[i]);
} else if constexpr (std::is_same<T, uint8_t>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
ASSERT_EQ(*v, array->buffer_views[1].data.as_uint8[i]);
} else if constexpr (std::is_same<T, uint16_t>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
ASSERT_EQ(*v, array->buffer_views[1].data.as_uint16[i]);
} else if constexpr (std::is_same<T, uint32_t>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
ASSERT_EQ(*v, array->buffer_views[1].data.as_uint32[i]);
} else if constexpr (std::is_same<T, uint64_t>::value) {
ASSERT_EQ(ArrowArrayViewGetIntUnsafe(array, i), *v);
} else if constexpr (std::is_same<T, uint8_t>::value ||
std::is_same<T, uint16_t>::value ||
std::is_same<T, uint32_t>::value ||
std::is_same<T, uint64_t>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
ASSERT_EQ(*v, array->buffer_views[1].data.as_uint64[i]);
ASSERT_EQ(ArrowArrayViewGetUIntUnsafe(array, i), *v);
} else if constexpr (std::is_same<T, std::string>::value) {
struct ArrowStringView view = ArrowArrayViewGetStringUnsafe(array, i);
std::string str(view.data, view.size_bytes);
Expand Down

0 comments on commit 46dc748

Please sign in to comment.