Skip to content

Commit

Permalink
feat(c/driver): support target catalog/schema for ingestion (#1056)
Browse files Browse the repository at this point in the history
- Support target catalog/schema options for ingestion
- Fix escaping in SQLite, PostgreSQL

Fixes #1000.
  • Loading branch information
lidavidm committed Sep 12, 2023
1 parent 70a741e commit db39236
Show file tree
Hide file tree
Showing 18 changed files with 376 additions and 38 deletions.
39 changes: 39 additions & 0 deletions c/driver/common/options.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

/// Common options that haven't yet been formally standardized.
/// https://github.com/apache/arrow-adbc/issues/1055

#pragma once

#ifdef __cplusplus
extern "C" {
#endif

/// \brief The catalog of the table for bulk insert.
///
/// This macro expands to the option key string "adbc.ingest.target_catalog".
/// The option value is the name of the catalog that contains the target
/// table of a bulk insert.
///
/// The type is char*.
#define ADBC_INGEST_OPTION_TARGET_CATALOG "adbc.ingest.target_catalog"

/// \brief The schema of the table for bulk insert.
///
/// This macro expands to the option key string "adbc.ingest.target_db_schema".
/// The option value is the name of the database schema that contains the
/// target table of a bulk insert.
///
/// The type is char*.
#define ADBC_INGEST_OPTION_TARGET_DB_SCHEMA "adbc.ingest.target_db_schema"

#ifdef __cplusplus
}
#endif
3 changes: 3 additions & 0 deletions c/driver/flightsql/dremio_flightsql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ class DremioFlightSqlStatementTest : public ::testing::Test,

// Unconditionally skips the result-invalidation test for this fixture; the
// skip message records the reason (Dremio generates a CANCELLED).
// NOTE(review): presumably replaces a shared validation test of the same
// name - confirm against the base class.
void TestResultInvalidation() { GTEST_SKIP() << "Dremio generates a CANCELLED"; }
// Unconditionally skips the table-name escaping ingest test for this fixture;
// the skip message states that table escaping is not implemented.
void TestSqlIngestTableEscaping() { GTEST_SKIP() << "Table escaping not implemented"; }
// Unconditionally skips the column-name escaping ingest test for this fixture;
// the skip message states that column escaping is not implemented.
void TestSqlIngestColumnEscaping() {
GTEST_SKIP() << "Column escaping not implemented";
}

protected:
DremioFlightSqlQuirks quirks_;
Expand Down
3 changes: 3 additions & 0 deletions c/driver/flightsql/sqlite_flightsql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,9 @@ class SqliteFlightSqlStatementTest : public ::testing::Test,
void TearDown() override { ASSERT_NO_FATAL_FAILURE(TearDownTest()); }

// Unconditionally skips the table-name escaping ingest test for this fixture;
// the skip message states that table escaping is not implemented.
void TestSqlIngestTableEscaping() { GTEST_SKIP() << "Table escaping not implemented"; }
// Unconditionally skips the column-name escaping ingest test for this fixture;
// the skip message states that column escaping is not implemented.
void TestSqlIngestColumnEscaping() {
GTEST_SKIP() << "Column escaping not implemented";
}
// Unconditionally skips the Interval ingest test for this fixture; the skip
// message states that ingesting Interval values is not implemented.
void TestSqlIngestInterval() {
GTEST_SKIP() << "Cannot ingest Interval (not implemented)";
}
Expand Down
4 changes: 3 additions & 1 deletion c/driver/postgresql/postgresql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class PostgresQuirks : public adbc_validation::DriverQuirks {
AdbcStatusCode status = AdbcStatementNew(connection, &statement, error);
if (status != ADBC_STATUS_OK) return status;

std::string query = "DROP TABLE IF EXISTS " + name;
std::string query = "DROP TABLE IF EXISTS \"" + name + "\"";
status = AdbcStatementSetSqlQuery(&statement, query.c_str(), error);
if (status != ADBC_STATUS_OK) {
std::ignore = AdbcStatementRelease(&statement, error);
Expand Down Expand Up @@ -111,6 +111,8 @@ class PostgresQuirks : public adbc_validation::DriverQuirks {
// Catalog name reported by these quirks: "postgres".
std::string catalog() const override { return "postgres"; }
// Database schema name reported by these quirks: "public".
std::string db_schema() const override { return "public"; }

// Bulk ingestion into an explicitly named target catalog is reported as
// unsupported.
bool supports_bulk_ingest_catalog() const override { return false; }
// Bulk ingestion into an explicitly named target schema is reported as
// supported.
bool supports_bulk_ingest_db_schema() const override { return true; }
// Statement cancellation is reported as supported.
bool supports_cancel() const override { return true; }
// Executing a statement to obtain only its schema is reported as supported.
// NOTE(review): exact meaning is defined by the validation suite - confirm there.
bool supports_execute_schema() const override { return true; }
std::optional<adbc_validation::SqlInfoValue> supports_get_sql_info(
Expand Down
68 changes: 58 additions & 10 deletions c/driver/postgresql/statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <libpq-fe.h>
#include <nanoarrow/nanoarrow.hpp>

#include "common/options.h"
#include "common/utils.h"
#include "connection.h"
#include "error.h"
Expand Down Expand Up @@ -831,7 +832,36 @@ AdbcStatusCode PostgresStatement::Cancel(struct AdbcError* error) {
AdbcStatusCode PostgresStatement::CreateBulkTable(
const struct ArrowSchema& source_schema,
const std::vector<struct ArrowSchemaView>& source_schema_fields,
struct AdbcError* error) {
std::string* escaped_table, struct AdbcError* error) {
PGconn* conn = connection_->conn();

{
if (!ingest_.db_schema.empty()) {
char* escaped =
PQescapeIdentifier(conn, ingest_.db_schema.c_str(), ingest_.db_schema.size());
if (escaped == nullptr) {
SetError(error, "[libpq] Failed to escape target schema %s for ingestion: %s",
ingest_.db_schema.c_str(), PQerrorMessage(conn));
return ADBC_STATUS_INTERNAL;
}
*escaped_table += escaped;
*escaped_table += " . ";
PQfreemem(escaped);
}

if (!ingest_.target.empty()) {
char* escaped =
PQescapeIdentifier(conn, ingest_.target.c_str(), ingest_.target.size());
if (escaped == nullptr) {
SetError(error, "[libpq] Failed to escape target table %s for ingestion: %s",
ingest_.target.c_str(), PQerrorMessage(conn));
return ADBC_STATUS_INTERNAL;
}
*escaped_table += escaped;
PQfreemem(escaped);
}
}

std::string create = "CREATE TABLE ";
switch (ingest_.mode) {
case IngestMode::kCreate:
Expand All @@ -840,15 +870,15 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
case IngestMode::kAppend:
return ADBC_STATUS_OK;
case IngestMode::kReplace: {
std::string drop = "DROP TABLE IF EXISTS " + ingest_.target;
PGresult* result = PQexecParams(connection_->conn(), drop.c_str(), /*nParams=*/0,
std::string drop = "DROP TABLE IF EXISTS " + *escaped_table;
PGresult* result = PQexecParams(conn, drop.c_str(), /*nParams=*/0,
/*paramTypes=*/nullptr, /*paramValues=*/nullptr,
/*paramLengths=*/nullptr, /*paramFormats=*/nullptr,
/*resultFormat=*/1 /*(binary)*/);
if (PQresultStatus(result) != PGRES_COMMAND_OK) {
AdbcStatusCode code =
SetError(error, result, "[libpq] Failed to drop table: %s\nQuery was: %s",
PQerrorMessage(connection_->conn()), drop.c_str());
PQerrorMessage(conn), drop.c_str());
PQclear(result);
return code;
}
Expand All @@ -859,12 +889,22 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
create += "IF NOT EXISTS ";
break;
}
create += ingest_.target;
create += *escaped_table;
create += " (";

for (size_t i = 0; i < source_schema_fields.size(); i++) {
if (i > 0) create += ", ";
create += source_schema.children[i]->name;

const char* unescaped = source_schema.children[i]->name;
char* escaped = PQescapeIdentifier(conn, unescaped, std::strlen(unescaped));
if (escaped == nullptr) {
SetError(error, "[libpq] Failed to escape column %s for ingestion: %s", unescaped,
PQerrorMessage(conn));
return ADBC_STATUS_INTERNAL;
}
create += escaped;
PQfreemem(escaped);

switch (source_schema_fields[i].type) {
case ArrowType::NANOARROW_TYPE_INT8:
case ArrowType::NANOARROW_TYPE_INT16:
Expand Down Expand Up @@ -914,14 +954,14 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(

create += ")";
SetError(error, "%s%s", "[libpq] ", create.c_str());
PGresult* result = PQexecParams(connection_->conn(), create.c_str(), /*nParams=*/0,
PGresult* result = PQexecParams(conn, create.c_str(), /*nParams=*/0,
/*paramTypes=*/nullptr, /*paramValues=*/nullptr,
/*paramLengths=*/nullptr, /*paramFormats=*/nullptr,
/*resultFormat=*/1 /*(binary)*/);
if (PQresultStatus(result) != PGRES_COMMAND_OK) {
AdbcStatusCode code =
SetError(error, result, "[libpq] Failed to create table: %s\nQuery was: %s",
PQerrorMessage(connection_->conn()), create.c_str());
PQerrorMessage(conn), create.c_str());
PQclear(result);
return code;
}
Expand Down Expand Up @@ -1060,16 +1100,17 @@ AdbcStatusCode PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected,

BindStream bind_stream(std::move(bind_));
std::memset(&bind_, 0, sizeof(bind_));
std::string escaped_table;
RAISE_ADBC(bind_stream.Begin(
[&]() -> AdbcStatusCode {
return CreateBulkTable(bind_stream.bind_schema.value,
bind_stream.bind_schema_fields, error);
bind_stream.bind_schema_fields, &escaped_table, error);
},
error));
RAISE_ADBC(bind_stream.SetParamTypes(*type_resolver_, error));

std::string insert = "INSERT INTO ";
insert += ingest_.target;
insert += escaped_table;
insert += " VALUES (";
for (size_t i = 0; i < bind_stream.bind_schema_fields.size(); i++) {
if (i > 0) insert += ", ";
Expand Down Expand Up @@ -1109,6 +1150,8 @@ AdbcStatusCode PostgresStatement::GetOption(const char* key, char* value, size_t
std::string result;
if (std::strcmp(key, ADBC_INGEST_OPTION_TARGET_TABLE) == 0) {
result = ingest_.target;
} else if (std::strcmp(key, ADBC_INGEST_OPTION_TARGET_DB_SCHEMA) == 0) {
result = ingest_.db_schema;
} else if (std::strcmp(key, ADBC_INGEST_OPTION_MODE) == 0) {
switch (ingest_.mode) {
case IngestMode::kCreate:
Expand Down Expand Up @@ -1190,6 +1233,7 @@ AdbcStatusCode PostgresStatement::Release(struct AdbcError* error) {
AdbcStatusCode PostgresStatement::SetSqlQuery(const char* query,
struct AdbcError* error) {
ingest_.target.clear();
ingest_.db_schema.clear();
query_ = query;
prepared_ = false;
return ADBC_STATUS_OK;
Expand All @@ -1201,6 +1245,10 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value,
query_.clear();
ingest_.target = value;
prepared_ = false;
} else if (std::strcmp(key, ADBC_INGEST_OPTION_TARGET_DB_SCHEMA) == 0) {
query_.clear();
ingest_.db_schema = value;
prepared_ = false;
} else if (std::strcmp(key, ADBC_INGEST_OPTION_MODE) == 0) {
if (std::strcmp(value, ADBC_INGEST_OPTION_MODE_CREATE) == 0) {
ingest_.mode = IngestMode::kCreate;
Expand Down
3 changes: 2 additions & 1 deletion c/driver/postgresql/statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class PostgresStatement {
AdbcStatusCode CreateBulkTable(
const struct ArrowSchema& source_schema,
const std::vector<struct ArrowSchemaView>& source_schema_fields,
struct AdbcError* error);
std::string* escaped_table, struct AdbcError* error);
AdbcStatusCode ExecuteUpdateBulk(int64_t* rows_affected, struct AdbcError* error);
AdbcStatusCode ExecuteUpdateQuery(int64_t* rows_affected, struct AdbcError* error);
AdbcStatusCode ExecutePreparedStatement(struct ArrowArrayStream* stream,
Expand All @@ -154,6 +154,7 @@ class PostgresStatement {
};

// Bulk-ingestion settings for this statement.
struct {
// Target schema name, set through ADBC_INGEST_OPTION_TARGET_DB_SCHEMA.
// When empty, CreateBulkTable adds no schema qualifier to the table name.
std::string db_schema;
// Target table name, reported for ADBC_INGEST_OPTION_TARGET_TABLE.
std::string target;
// Ingest mode; defaults to IngestMode::kCreate.
IngestMode mode = IngestMode::kCreate;
} ingest_;
Expand Down
50 changes: 40 additions & 10 deletions c/driver/sqlite/sqlite.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <nanoarrow/nanoarrow.h>
#include <sqlite3.h>

#include "common/options.h"
#include "common/utils.h"
#include "statement_reader.h"
#include "types.h"
Expand Down Expand Up @@ -1025,6 +1026,7 @@ AdbcStatusCode SqliteStatementRelease(struct AdbcStatement* statement,
}
if (stmt->query) free(stmt->query);
AdbcSqliteBinderRelease(&stmt->binder);
if (stmt->target_catalog) free(stmt->target_catalog);
if (stmt->target_table) free(stmt->target_table);
if (rc != SQLITE_OK) {
SetError(error,
Expand Down Expand Up @@ -1079,29 +1081,38 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
AdbcStatusCode code = ADBC_STATUS_OK;

// Create statements for CREATE TABLE / INSERT
sqlite3_str* create_query = sqlite3_str_new(NULL);
sqlite3_str* create_query = NULL;
sqlite3_str* insert_query = NULL;
char* table = NULL;

create_query = sqlite3_str_new(NULL);
if (sqlite3_str_errcode(create_query)) {
SetError(error, "[SQLite] %s", sqlite3_errmsg(stmt->conn));
sqlite3_free(sqlite3_str_finish(create_query));
return ADBC_STATUS_INTERNAL;
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}

sqlite3_str* insert_query = sqlite3_str_new(NULL);
insert_query = sqlite3_str_new(NULL);
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] %s", sqlite3_errmsg(stmt->conn));
sqlite3_free(sqlite3_str_finish(create_query));
sqlite3_free(sqlite3_str_finish(insert_query));
return ADBC_STATUS_INTERNAL;
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}

if (stmt->target_catalog != NULL) {
table = sqlite3_mprintf("\"%w\" . \"%w\"", stmt->target_catalog, stmt->target_table);
} else {
table = sqlite3_mprintf("\"%w\"", stmt->target_table);
}

sqlite3_str_appendf(create_query, "CREATE TABLE %Q (", stmt->target_table);
sqlite3_str_appendf(create_query, "CREATE TABLE %s (", table);
if (sqlite3_str_errcode(create_query)) {
SetError(error, "[SQLite] Failed to build CREATE: %s", sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}

sqlite3_str_appendf(insert_query, "INSERT INTO %Q VALUES (", stmt->target_table);
sqlite3_str_appendf(insert_query, "INSERT INTO %s VALUES (", table);
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
Expand All @@ -1121,7 +1132,7 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
}
}

sqlite3_str_appendf(create_query, "%Q", stmt->binder.schema.children[i]->name);
sqlite3_str_appendf(create_query, "\"%w\"", stmt->binder.schema.children[i]->name);
if (sqlite3_str_errcode(create_query)) {
SetError(error, "[SQLite] Failed to build CREATE: %s", sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
Expand Down Expand Up @@ -1221,6 +1232,7 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
cleanup:
sqlite3_free(sqlite3_str_finish(create_query));
sqlite3_free(sqlite3_str_finish(insert_query));
if (table != NULL) sqlite3_free(table);
return code;
}

Expand Down Expand Up @@ -1347,6 +1359,10 @@ AdbcStatusCode SqliteStatementSetSqlQuery(struct AdbcStatement* statement,
free(stmt->query);
stmt->query = NULL;
}
if (stmt->target_catalog) {
free(stmt->target_catalog);
stmt->target_catalog = NULL;
}
if (stmt->target_table) {
free(stmt->target_table);
stmt->target_table = NULL;
Expand Down Expand Up @@ -1462,6 +1478,20 @@ AdbcStatusCode SqliteStatementSetOption(struct AdbcStatement* statement, const c
stmt->target_table = (char*)malloc(len);
strncpy(stmt->target_table, value, len);
return ADBC_STATUS_OK;
} else if (strcmp(key, ADBC_INGEST_OPTION_TARGET_CATALOG) == 0) {
if (stmt->query) {
free(stmt->query);
stmt->query = NULL;
}
if (stmt->target_catalog) {
free(stmt->target_catalog);
stmt->target_catalog = NULL;
}

size_t len = strlen(value) + 1;
stmt->target_catalog = (char*)malloc(len);
strncpy(stmt->target_catalog, value, len);
return ADBC_STATUS_OK;
} else if (strcmp(key, ADBC_INGEST_OPTION_MODE) == 0) {
if (strcmp(value, ADBC_INGEST_OPTION_MODE_APPEND) == 0) {
stmt->append = 1;
Expand Down
5 changes: 3 additions & 2 deletions c/driver/sqlite/sqlite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class SqliteQuirks : public adbc_validation::DriverQuirks {
AdbcStatusCode status = AdbcStatementNew(connection, &statement, error);
if (status != ADBC_STATUS_OK) return status;

std::string query = "DROP TABLE IF EXISTS " + name;
std::string query = "DROP TABLE IF EXISTS \"" + name + "\"";
status = AdbcStatementSetSqlQuery(&statement, query.c_str(), error);
if (status != ADBC_STATUS_OK) {
std::ignore = AdbcStatementRelease(&statement, error);
Expand Down Expand Up @@ -97,6 +97,7 @@ class SqliteQuirks : public adbc_validation::DriverQuirks {
return std::strcmp(mode, ADBC_INGEST_OPTION_MODE_APPEND) == 0 ||
std::strcmp(mode, ADBC_INGEST_OPTION_MODE_CREATE) == 0;
}
// Bulk ingestion into an explicitly named target catalog is reported as
// supported.
bool supports_bulk_ingest_catalog() const override { return true; }
// Concurrent statements are reported as supported.
// NOTE(review): exact meaning is defined by the validation suite - confirm there.
bool supports_concurrent_statements() const override { return true; }
// Reading options back (GetOption) is reported as unsupported.
bool supports_get_option() const override { return false; }
std::optional<adbc_validation::SqlInfoValue> supports_get_sql_info(
Expand Down Expand Up @@ -268,7 +269,7 @@ class SqliteStatementTest : public ::testing::Test,
ADBCV_TEST_STATEMENT(SqliteStatementTest)

TEST_F(SqliteStatementTest, SqlIngestNameEscaping) {
ASSERT_THAT(quirks()->DropTable(&connection, "\"test-table\"", &error),
ASSERT_THAT(quirks()->DropTable(&connection, "test-table", &error),
adbc_validation::IsOkStatus(&error));

std::string table = "test-table";
Expand Down
1 change: 1 addition & 0 deletions c/driver/sqlite/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ struct SqliteStatement {
struct AdbcSqliteBinder binder;

// -- Ingest state ----------------------------------------
char* target_catalog;
char* target_table;
char append;

Expand Down
Loading

0 comments on commit db39236

Please sign in to comment.