Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(c/driver/postgresql): Implement GetTableSchema #577

Merged
merged 12 commits into from
May 5, 2023
100 changes: 99 additions & 1 deletion c/driver/postgresql/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,41 @@

#include "connection.h"

#include <cinttypes>
#include <cstring>
#include <memory>
#include <string>

#include <adbc.h>
#include <libpq-fe.h>

#include "database.h"
#include "utils.h"

namespace {
class PqResultHelper {
lidavidm marked this conversation as resolved.
Show resolved Hide resolved
public:
PqResultHelper(PGconn* conn, const char* query) : conn_(conn) {
query_ = std::string(query);
}
pg_result* Execute() {
result_ = PQexec(conn_, query_.c_str());
return result_;
}

~PqResultHelper() {
if (result_ != nullptr) PQclear(result_);
}

private:
pg_result* result_ = nullptr;
PGconn* conn_;
std::string query_;
};
} // namespace

namespace adbcpq {

AdbcStatusCode PostgresConnection::Commit(struct AdbcError* error) {
if (autocommit_) {
SetError(error, "%s", "[libpq] Cannot commit when autocommit is enabled");
Expand All @@ -47,7 +73,79 @@ AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog,
const char* table_name,
struct ArrowSchema* schema,
struct AdbcError* error) {
return ADBC_STATUS_NOT_IMPLEMENTED;
AdbcStatusCode final_status = ADBC_STATUS_OK;
struct StringBuilder query = {0};
if (StringBuilderInit(&query, /*initial_size=*/256) != 0) return ADBC_STATUS_INTERNAL;

if (StringBuilderAppend(
&query, "%s",
"SELECT attname, atttypid "
"FROM pg_catalog.pg_class AS cls "
"INNER JOIN pg_catalog.pg_attribute AS attr ON cls.oid = attr.attrelid "
"INNER JOIN pg_catalog.pg_type AS typ ON attr.atttypid = typ.oid "
"WHERE attr.attnum >= 0 AND cls.oid = '") != 0)
return ADBC_STATUS_INTERNAL;

if (db_schema != nullptr) {
char* schema = PQescapeIdentifier(conn_, db_schema, strlen(db_schema));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit tangential to this but I don't see the current test suite as validating that we are safe against SQL injection attacks. Probably a reasonable follow up to set those types of test in place

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, do you want to file something?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #643

if (schema == NULL) {
SetError(error, "%s%s", "Faled to escape schema: ", PQerrorMessage(conn_));
return ADBC_STATUS_INVALID_ARGUMENT;
}

int ret = StringBuilderAppend(&query, "%s%s", schema, ".");
PQfreemem(schema);

if (ret != 0) return ADBC_STATUS_INTERNAL;
}

char* table = PQescapeIdentifier(conn_, table_name, strlen(table_name));
if (table == NULL) {
SetError(error, "%s%s", "Failed to escape table: ", PQerrorMessage(conn_));
return ADBC_STATUS_INVALID_ARGUMENT;
}

int ret = StringBuilderAppend(&query, "%s%s", table_name, "'::regclass::oid");
PQfreemem(table);

if (ret != 0) return ADBC_STATUS_INTERNAL;
lidavidm marked this conversation as resolved.
Show resolved Hide resolved

PqResultHelper result_helper = PqResultHelper{conn_, query.buffer};
StringBuilderReset(&query);
pg_result* result = result_helper.Execute();

ExecStatusType pq_status = PQresultStatus(result);
auto uschema = nanoarrow::UniqueSchema();

if (pq_status == PGRES_TUPLES_OK) {
int num_rows = PQntuples(result);
ArrowSchemaInit(uschema.get());
CHECK_NA(INTERNAL, ArrowSchemaSetTypeStruct(uschema.get(), num_rows), error);

ArrowError na_error;
for (int row = 0; row < num_rows; row++) {
const char* colname = PQgetvalue(result, row, 0);
const Oid pg_oid = static_cast<uint32_t>(
std::strtol(PQgetvalue(result, row, 1), /*str_end=*/nullptr, /*base=*/10));

PostgresType pg_type;
if (type_resolver_->Find(pg_oid, &pg_type, &na_error) != NANOARROW_OK) {
lidavidm marked this conversation as resolved.
Show resolved Hide resolved
SetError(error, "%s%d%s%s%s%" PRIu32, "Column #", row + 1, " (\"", colname,
"\") has unknown type code ", pg_oid);
final_status = ADBC_STATUS_NOT_IMPLEMENTED;
break;
}

CHECK_NA(INTERNAL, pg_type.WithFieldName(colname).SetSchema(uschema->children[row]),
error);
}
} else {
SetError(error, "%s%s", "Failed to get table schema: ", PQerrorMessage(conn_));
final_status = ADBC_STATUS_IO;
}

uschema.move(schema);
return final_status;
}

AdbcStatusCode PostgresConnection::Init(struct AdbcDatabase* database,
Expand Down
1 change: 0 additions & 1 deletion c/driver/postgresql/postgresql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ class PostgresConnectionTest : public ::testing::Test,
void TearDown() override { ASSERT_NO_FATAL_FAILURE(TearDownTest()); }

void TestMetadataGetInfo() { GTEST_SKIP() << "Not yet implemented"; }
void TestMetadataGetTableSchema() { GTEST_SKIP() << "Not yet implemented"; }
void TestMetadataGetTableTypes() { GTEST_SKIP() << "Not yet implemented"; }

void TestMetadataGetObjectsCatalogs() { GTEST_SKIP() << "Not yet implemented"; }
Expand Down