Skip to content

Commit

Permalink
feat(c/driver/postgresql): Implement consuming a PGresult via the cop…
Browse files Browse the repository at this point in the history
…y reader (#2029)

I started this PR wanting to get queries with parameters able to return
their results; however, this turned into a PR leaning in to the
`PqResultHelper` because it was helpful to export arrays from the
`PGresult*` but wasn't quite general enough. I did a second bit of
shuffling to make it (possibly, or maybe just for me) easier to
understand what path gets taken on `ExecuteQuery()`.

Some side effects of these changes are that we can now support multiple
statements in the same query (by using `PQexec()` instead of
`PQexecParams()` when there is no output requested) and that we can
`ExecuteSchema()` for all parameterized queries.

The actual feature is that a user can set `adbc.postgresql.use_copy =
FALSE` to force a non-COPY path for queries that aren't supported there.
Because we request binary data, we can use all the same infrastructure
for converting the results! I have only one test for this although I did
run the whole test suite in C++ and Python...there are still a few
missing features (batch size hint, large string overflow, error detail,
cancel) but most tests pass using either path.

I'm happy to split this up if that is easier! I'm also planning to
document the helper (but wanted a first round of review before
documenting the behaviour to make sure it's behaviour we actually want).

Closes #855, Closes #2035.

``` r
library(adbcdrivermanager)
#> Warning: package 'adbcdrivermanager' was built under R version 4.3.3

con <- adbc_database_init(
  adbcpostgresql::adbcpostgresql(),
  uri = "postgresql://localhost:5432/postgres?user=postgres&password=password"
) |> 
  adbc_connection_init()

nycflights13::flights |> 
  write_adbc(con, "flights")

stream <- nanoarrow::nanoarrow_allocate_array_stream()
rows <- con |> 
  adbc_statement_init(adbc.postgresql.use_copy = FALSE) |> 
  adbc_statement_set_sql_query(
    "SELECT * from flights where month = 1 AND day = 1"
  ) |> 
  adbc_statement_prepare() |>
  adbc_statement_execute_query(stream)

rows
#> [1] 842

tibble::as_tibble(stream)
#> # A tibble: 842 × 19
#>     year month   day dep_time sched_dep_time dep_delay arr_time sched_arr_time
#>    <int> <int> <int>    <int>          <int>     <dbl>    <int>          <int>
#>  1  2013     1     1      517            515         2      830            819
#>  2  2013     1     1      533            529         4      850            830
#>  3  2013     1     1      542            540         2      923            850
#>  4  2013     1     1      544            545        -1     1004           1022
#>  5  2013     1     1      554            600        -6      812            837
#>  6  2013     1     1      554            558        -4      740            728
#>  7  2013     1     1      555            600        -5      913            854
#>  8  2013     1     1      557            600        -3      709            723
#>  9  2013     1     1      557            600        -3      838            846
#> 10  2013     1     1      558            600        -2      753            745
#> # ℹ 832 more rows
#> # ℹ 11 more variables: arr_delay <dbl>, carrier <chr>, flight <int>,
#> #   tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
#> #   hour <dbl>, minute <dbl>, time_hour <dttm>

con |> 
  execute_adbc("DROP TABLE flights")
```

<sup>Created on 2024-07-25 with [reprex
v2.1.0](https://reprex.tidyverse.org)</sup>
  • Loading branch information
paleolimbot authored Aug 7, 2024
1 parent 45cd9be commit 05fa60d
Show file tree
Hide file tree
Showing 9 changed files with 707 additions and 265 deletions.
55 changes: 20 additions & 35 deletions c/driver/postgresql/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,10 @@ class PqGetObjectsHelper {
params.push_back(db_schema_);
}

auto result_helper =
PqResultHelper{conn_, std::string(query.buffer), params, error_};
auto result_helper = PqResultHelper{conn_, std::string(query.buffer)};
StringBuilderReset(&query);

RAISE_ADBC(result_helper.Prepare());
RAISE_ADBC(result_helper.Execute());
RAISE_ADBC(result_helper.Execute(error_, params));

for (PqResultRow row : result_helper) {
const char* schema_name = row[0].data;
Expand Down Expand Up @@ -188,12 +186,10 @@ class PqGetObjectsHelper {
params.push_back(catalog_);
}

PqResultHelper result_helper =
PqResultHelper{conn_, std::string(query.buffer), params, error_};
PqResultHelper result_helper = PqResultHelper{conn_, std::string(query.buffer)};
StringBuilderReset(&query);

RAISE_ADBC(result_helper.Prepare());
RAISE_ADBC(result_helper.Execute());
RAISE_ADBC(result_helper.Execute(error_, params));

for (PqResultRow row : result_helper) {
const char* db_name = row[0].data;
Expand Down Expand Up @@ -280,11 +276,10 @@ class PqGetObjectsHelper {
}
}

auto result_helper = PqResultHelper{conn_, query.buffer, params, error_};
auto result_helper = PqResultHelper{conn_, query.buffer};
StringBuilderReset(&query);

RAISE_ADBC(result_helper.Prepare());
RAISE_ADBC(result_helper.Execute());
RAISE_ADBC(result_helper.Execute(error_, params));
for (PqResultRow row : result_helper) {
const char* table_name = row[0].data;
const char* table_type = row[1].data;
Expand Down Expand Up @@ -341,11 +336,10 @@ class PqGetObjectsHelper {
params.push_back(std::string(column_name_));
}

auto result_helper = PqResultHelper{conn_, query.buffer, params, error_};
auto result_helper = PqResultHelper{conn_, query.buffer};
StringBuilderReset(&query);

RAISE_ADBC(result_helper.Prepare());
RAISE_ADBC(result_helper.Execute());
RAISE_ADBC(result_helper.Execute(error_, params));

for (PqResultRow row : result_helper) {
const char* column_name = row[0].data;
Expand Down Expand Up @@ -493,11 +487,10 @@ class PqGetObjectsHelper {
params.push_back(std::string(column_name_));
}

auto result_helper = PqResultHelper{conn_, query.buffer, params, error_};
auto result_helper = PqResultHelper{conn_, query.buffer};
StringBuilderReset(&query);

RAISE_ADBC(result_helper.Prepare());
RAISE_ADBC(result_helper.Execute());
RAISE_ADBC(result_helper.Execute(error_, params));

for (PqResultRow row : result_helper) {
const char* constraint_name = row[0].data;
Expand Down Expand Up @@ -655,9 +648,8 @@ AdbcStatusCode PostgresConnection::PostgresConnectionGetInfoImpl(
break;
case ADBC_INFO_VENDOR_VERSION: {
const char* stmt = "SHOW server_version_num";
auto result_helper = PqResultHelper{conn_, std::string(stmt), error};
RAISE_ADBC(result_helper.Prepare());
RAISE_ADBC(result_helper.Execute());
auto result_helper = PqResultHelper{conn_, std::string(stmt)};
RAISE_ADBC(result_helper.Execute(error));
auto it = result_helper.begin();
if (it == result_helper.end()) {
SetError(error, "[libpq] PostgreSQL returned no rows for '%s'", stmt);
Expand Down Expand Up @@ -760,9 +752,8 @@ AdbcStatusCode PostgresConnection::GetOption(const char* option, char* value,
if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_CATALOG) == 0) {
output = PQdb(conn_);
} else if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA) == 0) {
PqResultHelper result_helper{conn_, "SELECT CURRENT_SCHEMA", {}, error};
RAISE_ADBC(result_helper.Prepare());
RAISE_ADBC(result_helper.Execute());
PqResultHelper result_helper{conn_, "SELECT CURRENT_SCHEMA"};
RAISE_ADBC(result_helper.Execute(error));
auto it = result_helper.begin();
if (it == result_helper.end()) {
SetError(error, "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA'");
Expand Down Expand Up @@ -931,10 +922,8 @@ AdbcStatusCode PostgresConnectionGetStatisticsImpl(PGconn* conn, const char* db_
std::string prev_table;

{
PqResultHelper result_helper{
conn, query, {db_schema, table_name ? table_name : "%"}, error};
RAISE_ADBC(result_helper.Prepare());
RAISE_ADBC(result_helper.Execute());
PqResultHelper result_helper{conn, query};
RAISE_ADBC(result_helper.Execute(error, {db_schema, table_name ? table_name : "%"}));

for (PqResultRow row : result_helper) {
auto reltuples = row[5].ParseDouble();
Expand Down Expand Up @@ -1166,11 +1155,9 @@ AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog,

std::vector<std::string> params = {table_name_str};

PqResultHelper result_helper =
PqResultHelper{conn_, std::string(query.c_str()), params, error};
PqResultHelper result_helper = PqResultHelper{conn_, std::string(query.c_str())};

RAISE_ADBC(result_helper.Prepare());
auto result = result_helper.Execute();
auto result = result_helper.Execute(error, params);
if (result != ADBC_STATUS_OK) {
auto error_code = std::string(error->sqlstate, 5);
if ((error_code == "42P01") || (error_code == "42602")) {
Expand Down Expand Up @@ -1337,10 +1324,8 @@ AdbcStatusCode PostgresConnection::SetOption(const char* key, const char* value,
return ADBC_STATUS_OK;
} else if (std::strcmp(key, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA) == 0) {
// PostgreSQL doesn't accept a parameter here
PqResultHelper result_helper{
conn_, std::string("SET search_path TO ") + value, {}, error};
RAISE_ADBC(result_helper.Prepare());
RAISE_ADBC(result_helper.Execute());
PqResultHelper result_helper{conn_, std::string("SET search_path TO ") + value};
RAISE_ADBC(result_helper.Execute(error));
return ADBC_STATUS_OK;
}
SetError(error, "%s%s", "[libpq] Unknown option ", key);
Expand Down
130 changes: 119 additions & 11 deletions c/driver/postgresql/postgresql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1247,7 +1247,7 @@ TEST_F(PostgresStatementTest, UpdateInExecuteQuery) {
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
&reader.rows_affected, &error),
IsOkStatus(&error));
ASSERT_EQ(reader.rows_affected, -1);
ASSERT_EQ(reader.rows_affected, 2);
ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
ASSERT_NO_FATAL_FAILURE(reader.Next());
ASSERT_EQ(reader.array->release, nullptr);
Expand Down Expand Up @@ -1276,6 +1276,32 @@ TEST_F(PostgresStatementTest, UpdateInExecuteQuery) {
}
}

TEST_F(PostgresStatementTest, ExecuteSchemaParameterizedQuery) {
nanoarrow::UniqueSchema schema_bind;
ArrowSchemaInit(schema_bind.get());
ASSERT_THAT(ArrowSchemaSetTypeStruct(schema_bind.get(), 1),
adbc_validation::IsOkErrno());
ASSERT_THAT(ArrowSchemaSetType(schema_bind->children[0], NANOARROW_TYPE_STRING),
adbc_validation::IsOkErrno());

nanoarrow::UniqueArrayStream bind;
nanoarrow::EmptyArrayStream(schema_bind.get()).ToArrayStream(bind.get());

ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT $1", &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementBindStream(&statement, bind.get(), &error), IsOkStatus());

nanoarrow::UniqueSchema schema;
ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error),
IsOkStatus(&error));

ASSERT_EQ(1, schema->n_children);
ASSERT_STREQ("u", schema->children[0]->format);

ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
}

TEST_F(PostgresStatementTest, BatchSizeHint) {
ASSERT_THAT(quirks()->EnsureSampleTable(&connection, "batch_size_hint_test", &error),
IsOkStatus(&error));
Expand Down Expand Up @@ -1345,16 +1371,13 @@ TEST_F(PostgresStatementTest, AdbcErrorBackwardsCompatibility) {
TEST_F(PostgresStatementTest, Cancel) {
ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));

for (const char* query : {
"DROP TABLE IF EXISTS test_cancel",
"CREATE TABLE test_cancel (ints INT)",
R"(INSERT INTO test_cancel (ints)
SELECT g :: INT FROM GENERATE_SERIES(1, 65536) temp(g))",
}) {
ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, query, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
IsOkStatus(&error));
}
const char* query = R"(DROP TABLE IF EXISTS test_cancel;
CREATE TABLE test_cancel (ints INT);
INSERT INTO test_cancel (ints)
SELECT g :: INT FROM GENERATE_SERIES(1, 65536) temp(g);)";
ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, query, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
IsOkStatus(&error));

ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM test_cancel", &error),
IsOkStatus(&error));
Expand All @@ -1381,6 +1404,91 @@ TEST_F(PostgresStatementTest, Cancel) {
ASSERT_NE(0, AdbcErrorGetDetailCount(detail));
}

TEST_F(PostgresStatementTest, MultipleStatementsSingleQuery) {
ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));

const char* query = R"(DROP TABLE IF EXISTS test_query_statements;
CREATE TABLE test_query_statements (ints INT);
INSERT INTO test_query_statements VALUES((1));
INSERT INTO test_query_statements VALUES((2));
INSERT INTO test_query_statements VALUES((3));)";
ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, query, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
IsOkStatus(&error));

ASSERT_THAT(
AdbcStatementSetSqlQuery(&statement, "SELECT * FROM test_query_statements", &error),
IsOkStatus(&error));

adbc_validation::StreamReader reader;
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
&reader.rows_affected, &error),
IsOkStatus(&error));
reader.GetSchema();
ASSERT_THAT(reader.MaybeNext(), adbc_validation::IsOkErrno());
ASSERT_EQ(reader.array->length, 3);
}

TEST_F(PostgresStatementTest, SetUseCopyFalse) {
ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));

const char* query = R"(DROP TABLE IF EXISTS test_query_set_copy_false;
CREATE TABLE test_query_set_copy_false (ints INT);
INSERT INTO test_query_set_copy_false VALUES((1));
INSERT INTO test_query_set_copy_false VALUES((NULL));
INSERT INTO test_query_set_copy_false VALUES((3));)";
ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, query, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
IsOkStatus(&error));

// Check option setting/getting
ASSERT_EQ(
adbc_validation::StatementGetOption(&statement, "adbc.postgresql.use_copy", &error),
"true");

ASSERT_THAT(AdbcStatementSetOption(&statement, "adbc.postgresql.use_copy",
"not true or false", &error),
IsStatus(ADBC_STATUS_INVALID_ARGUMENT));

ASSERT_THAT(AdbcStatementSetOption(&statement, "adbc.postgresql.use_copy",
ADBC_OPTION_VALUE_ENABLED, &error),
IsOkStatus(&error));
ASSERT_EQ(
adbc_validation::StatementGetOption(&statement, "adbc.postgresql.use_copy", &error),
"true");

ASSERT_THAT(AdbcStatementSetOption(&statement, "adbc.postgresql.use_copy",
ADBC_OPTION_VALUE_DISABLED, &error),
IsOkStatus(&error));
ASSERT_EQ(
adbc_validation::StatementGetOption(&statement, "adbc.postgresql.use_copy", &error),
"false");

ASSERT_THAT(AdbcStatementSetSqlQuery(&statement,
"SELECT * FROM test_query_set_copy_false", &error),
IsOkStatus(&error));

adbc_validation::StreamReader reader;
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
&reader.rows_affected, &error),
IsOkStatus(&error));

ASSERT_EQ(reader.rows_affected, 3);

reader.GetSchema();
ASSERT_EQ(reader.schema->n_children, 1);
ASSERT_STREQ(reader.schema->children[0]->format, "i");
ASSERT_STREQ(reader.schema->children[0]->name, "ints");

ASSERT_THAT(reader.MaybeNext(), adbc_validation::IsOkErrno());
ASSERT_EQ(reader.array->length, 3);
ASSERT_EQ(reader.array->n_children, 1);
ASSERT_EQ(reader.array->children[0]->null_count, 1);

ASSERT_THAT(reader.MaybeNext(), adbc_validation::IsOkErrno());
ASSERT_EQ(reader.array->release, nullptr);
}

struct TypeTestCase {
std::string name;
std::string sql_type;
Expand Down
Loading

0 comments on commit 05fa60d

Please sign in to comment.