Skip to content

Commit

Permalink
fix(c/driver/postgresql): support catalog arg of GetTableSchema (#1387)
Browse files Browse the repository at this point in the history
Fixes #1339.
  • Loading branch information
lidavidm authored Dec 20, 2023
1 parent 9589754 commit 5ca9c29
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 38 deletions.
37 changes: 11 additions & 26 deletions c/driver/postgresql/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1147,38 +1147,23 @@ AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog,
struct ArrowSchema* schema,
struct AdbcError* error) {
AdbcStatusCode final_status = ADBC_STATUS_OK;
struct StringBuilder query;
std::memset(&query, 0, sizeof(query));
std::vector<std::string> params;
if (StringBuilderInit(&query, /*initial_size=*/256) != 0) return ADBC_STATUS_INTERNAL;

if (StringBuilderAppend(
&query, "%s",
"SELECT attname, atttypid "
"FROM pg_catalog.pg_class AS cls "
"INNER JOIN pg_catalog.pg_attribute AS attr ON cls.oid = attr.attrelid "
"INNER JOIN pg_catalog.pg_type AS typ ON attr.atttypid = typ.oid "
"WHERE attr.attnum >= 0 AND cls.oid = ") != 0)
return ADBC_STATUS_INTERNAL;
std::string query =
"SELECT attname, atttypid "
"FROM pg_catalog.pg_class AS cls "
"INNER JOIN pg_catalog.pg_attribute AS attr ON cls.oid = attr.attrelid "
"INNER JOIN pg_catalog.pg_type AS typ ON attr.atttypid = typ.oid "
"WHERE attr.attnum >= 0 AND cls.oid = $1::regclass::oid";

std::vector<std::string> params;
if (db_schema != nullptr) {
if (StringBuilderAppend(&query, "%s", "$1.")) {
StringBuilderReset(&query);
return ADBC_STATUS_INTERNAL;
}
params.push_back(db_schema);
}

if (StringBuilderAppend(&query, "%s%" PRIu64 "%s", "$",
static_cast<uint64_t>(params.size() + 1), "::regclass::oid")) {
StringBuilderReset(&query);
return ADBC_STATUS_INTERNAL;
params.push_back(std::string(db_schema) + "." + table_name);
} else {
params.push_back(table_name);
}
params.push_back(table_name);

PqResultHelper result_helper =
PqResultHelper{conn_, std::string(query.buffer), params, error};
StringBuilderReset(&query);
PqResultHelper{conn_, std::string(query.c_str()), params, error};

RAISE_ADBC(result_helper.Prepare());
auto result = result_helper.Execute();
Expand Down
34 changes: 29 additions & 5 deletions c/driver/postgresql/postgresql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ class PostgresQuirks : public adbc_validation::DriverQuirks {
return AdbcStatementRelease(&statement.value, error);
}

AdbcStatusCode DropTable(struct AdbcConnection* connection, const std::string& name,
const std::string& db_schema,
struct AdbcError* error) const override {
Handle<struct AdbcStatement> statement;
RAISE_ADBC(AdbcStatementNew(connection, &statement.value, error));

std::string query = "DROP TABLE IF EXISTS \"" + db_schema + "\".\"" + name + "\"";
RAISE_ADBC(AdbcStatementSetSqlQuery(&statement.value, query.c_str(), error));
RAISE_ADBC(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, error));
return AdbcStatementRelease(&statement.value, error);
}

AdbcStatusCode DropTempTable(struct AdbcConnection* connection, const std::string& name,
struct AdbcError* error) const override {
Handle<struct AdbcStatement> statement;
Expand All @@ -83,6 +95,18 @@ class PostgresQuirks : public adbc_validation::DriverQuirks {
return AdbcStatementRelease(&statement.value, error);
}

AdbcStatusCode EnsureDbSchema(struct AdbcConnection* connection,
const std::string& name,
struct AdbcError* error) const override {
Handle<struct AdbcStatement> statement;
RAISE_ADBC(AdbcStatementNew(connection, &statement.value, error));

std::string query = "CREATE SCHEMA IF NOT EXISTS \"" + name + "\"";
RAISE_ADBC(AdbcStatementSetSqlQuery(&statement.value, query.c_str(), error));
RAISE_ADBC(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, error));
return AdbcStatementRelease(&statement.value, error);
}

std::string BindParameter(int index) const override {
return "$" + std::to_string(index + 1);
}
Expand Down Expand Up @@ -343,7 +367,7 @@ TEST_F(PostgresConnectionTest, GetObjectsGetAllFindsPrimaryKey) {
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
&reader.rows_affected, &error),
IsOkStatus(&error));
ASSERT_EQ(reader.rows_affected, 0);
ASSERT_EQ(reader.rows_affected, -1);
ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
ASSERT_NO_FATAL_FAILURE(reader.Next());
ASSERT_EQ(reader.array->release, nullptr);
Expand Down Expand Up @@ -416,7 +440,7 @@ TEST_F(PostgresConnectionTest, GetObjectsGetAllFindsForeignKey) {
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
&reader.rows_affected, &error),
IsOkStatus(&error));
ASSERT_EQ(reader.rows_affected, 0);
ASSERT_EQ(reader.rows_affected, -1);
ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
ASSERT_NO_FATAL_FAILURE(reader.Next());
ASSERT_EQ(reader.array->release, nullptr);
Expand All @@ -435,7 +459,7 @@ TEST_F(PostgresConnectionTest, GetObjectsGetAllFindsForeignKey) {
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
&reader.rows_affected, &error),
IsOkStatus(&error));
ASSERT_EQ(reader.rows_affected, 0);
ASSERT_EQ(reader.rows_affected, -1);
ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
ASSERT_NO_FATAL_FAILURE(reader.Next());
ASSERT_EQ(reader.array->release, nullptr);
Expand Down Expand Up @@ -1162,7 +1186,7 @@ TEST_F(PostgresStatementTest, UpdateInExecuteQuery) {
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
&reader.rows_affected, &error),
IsOkStatus(&error));
ASSERT_EQ(reader.rows_affected, 0);
ASSERT_EQ(reader.rows_affected, -1);
ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
ASSERT_NO_FATAL_FAILURE(reader.Next());
ASSERT_EQ(reader.array->release, nullptr);
Expand All @@ -1177,7 +1201,7 @@ TEST_F(PostgresStatementTest, UpdateInExecuteQuery) {
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
&reader.rows_affected, &error),
IsOkStatus(&error));
ASSERT_EQ(reader.rows_affected, 0);
ASSERT_EQ(reader.rows_affected, -1);
ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
ASSERT_NO_FATAL_FAILURE(reader.Next());
ASSERT_EQ(reader.array->release, nullptr);
Expand Down
51 changes: 48 additions & 3 deletions c/validation/adbc_validation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ bool iequals(std::string_view s1, std::string_view s2) {
// DriverQuirks

AdbcStatusCode DoIngestSampleTable(struct AdbcConnection* connection,
const std::string& name, struct AdbcError* error) {
const std::string& name,
std::optional<std::string> db_schema,
struct AdbcError* error) {
Handle<struct ArrowSchema> schema;
Handle<struct ArrowArray> array;
struct ArrowError na_error;
Expand All @@ -84,14 +86,19 @@ AdbcStatusCode DoIngestSampleTable(struct AdbcConnection* connection,
CHECK_OK(AdbcStatementNew(connection, &statement.value, error));
CHECK_OK(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE,
name.c_str(), error));
if (db_schema.has_value()) {
CHECK_OK(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_DB_SCHEMA,
db_schema->c_str(), error));
}
CHECK_OK(AdbcStatementBind(&statement.value, &array.value, &schema.value, error));
CHECK_OK(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, error));
CHECK_OK(AdbcStatementRelease(&statement.value, error));
return ADBC_STATUS_OK;
}

void IngestSampleTable(struct AdbcConnection* connection, struct AdbcError* error) {
ASSERT_THAT(DoIngestSampleTable(connection, "bulk_ingest", error), IsOkStatus(error));
ASSERT_THAT(DoIngestSampleTable(connection, "bulk_ingest", std::nullopt, error),
IsOkStatus(error));
}

AdbcStatusCode DriverQuirks::EnsureSampleTable(struct AdbcConnection* connection,
Expand All @@ -107,7 +114,17 @@ AdbcStatusCode DriverQuirks::CreateSampleTable(struct AdbcConnection* connection
if (!supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) {
return ADBC_STATUS_NOT_IMPLEMENTED;
}
return DoIngestSampleTable(connection, name, error);
return DoIngestSampleTable(connection, name, std::nullopt, error);
}

AdbcStatusCode DriverQuirks::CreateSampleTable(struct AdbcConnection* connection,
const std::string& name,
const std::string& schema,
struct AdbcError* error) const {
if (!supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) {
return ADBC_STATUS_NOT_IMPLEMENTED;
}
return DoIngestSampleTable(connection, name, schema, error);
}

//------------------------------------------------------------
Expand Down Expand Up @@ -431,6 +448,34 @@ void ConnectionTest::TestMetadataGetTableSchema() {
{"strings", NANOARROW_TYPE_STRING, NULLABLE}}));
}

void ConnectionTest::TestMetadataGetTableSchemaDbSchema() {
ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error));

auto status = quirks()->EnsureDbSchema(&connection, "otherschema", &error);
if (status == ADBC_STATUS_NOT_IMPLEMENTED) {
GTEST_SKIP() << "Schema not supported";
return;
}
ASSERT_THAT(status, IsOkStatus(&error));

ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", "otherschema", &error),
IsOkStatus(&error));
ASSERT_THAT(
quirks()->CreateSampleTable(&connection, "bulk_ingest", "otherschema", &error),
IsOkStatus(&error));

Handle<ArrowSchema> schema;
ASSERT_THAT(AdbcConnectionGetTableSchema(&connection, /*catalog=*/nullptr,
/*db_schema=*/"otherschema", "bulk_ingest",
&schema.value, &error),
IsOkStatus(&error));

ASSERT_NO_FATAL_FAILURE(
CompareSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64, NULLABLE},
{"strings", NANOARROW_TYPE_STRING, NULLABLE}}));
}

void ConnectionTest::TestMetadataGetTableSchemaEscaping() {
if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) {
GTEST_SKIP();
Expand Down
33 changes: 32 additions & 1 deletion c/validation/adbc_validation.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ class DriverQuirks {
return ADBC_STATUS_OK;
}

virtual AdbcStatusCode DropTable(struct AdbcConnection* connection,
const std::string& name,
const std::string& db_schema,
struct AdbcError* error) const {
return ADBC_STATUS_NOT_IMPLEMENTED;
}

/// \brief Drop the given temporary table. Used by tests to reset state.
virtual AdbcStatusCode DropTempTable(struct AdbcConnection* connection,
const std::string& name,
Expand All @@ -68,13 +75,33 @@ class DriverQuirks {
const std::string& name,
struct AdbcError* error) const;

/// \brief Create a schema for testing.
virtual AdbcStatusCode EnsureDbSchema(struct AdbcConnection* connection,
const std::string& name,
struct AdbcError* error) const {
return ADBC_STATUS_NOT_IMPLEMENTED;
}

/// \brief Create a table of sample data with a fixed schema for testing.
///
/// The table should have two columns:
/// - "int64s" with Arrow type int64.
/// - "strings" with Arrow type utf8.
virtual AdbcStatusCode CreateSampleTable(struct AdbcConnection* connection,
const std::string& name,
struct AdbcError* error) const;

/// \brief Create a table of sample data with a fixed schema for testing.
///
/// Create it in the given schema. Specify "" for the default schema.
/// Return NOT_IMPLEMENTED if not supported by this backend.
///
/// The table should have two columns:
/// - "int64s" with Arrow type int64.
/// - "strings" with Arrow type utf8.
virtual AdbcStatusCode CreateSampleTable(struct AdbcConnection* connection,
const std::string& name,
const std::string& schema,
struct AdbcError* error) const;

/// \brief Get the statement to create a table with a primary key, or nullopt if not
Expand Down Expand Up @@ -197,7 +224,7 @@ class DriverQuirks {
/// \brief Default catalog to use for tests
virtual std::string catalog() const { return ""; }

/// \brief Default Schema to use for tests
/// \brief Default database schema to use for tests
virtual std::string db_schema() const { return ""; }
};

Expand Down Expand Up @@ -243,6 +270,7 @@ class ConnectionTest {

void TestMetadataGetInfo();
void TestMetadataGetTableSchema();
void TestMetadataGetTableSchemaDbSchema();
void TestMetadataGetTableSchemaEscaping();
void TestMetadataGetTableSchemaNotFound();
void TestMetadataGetTableTypes();
Expand Down Expand Up @@ -277,6 +305,9 @@ class ConnectionTest {
TEST_F(FIXTURE, MetadataCurrentDbSchema) { TestMetadataCurrentDbSchema(); } \
TEST_F(FIXTURE, MetadataGetInfo) { TestMetadataGetInfo(); } \
TEST_F(FIXTURE, MetadataGetTableSchema) { TestMetadataGetTableSchema(); } \
TEST_F(FIXTURE, MetadataGetTableSchemaDbSchema) { \
TestMetadataGetTableSchemaDbSchema(); \
} \
TEST_F(FIXTURE, MetadataGetTableSchemaEscaping) { \
TestMetadataGetTableSchemaEscaping(); \
} \
Expand Down
4 changes: 2 additions & 2 deletions ci/conda_env_cpp_lint.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@
# specific language governing permissions and limitations
# under the License.

clang=14
clang-tools=14
clang=14.*
clang-tools=14.*
26 changes: 25 additions & 1 deletion docs/source/python/recipe/postgresql_get_table_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,42 @@
uri = os.environ["ADBC_POSTGRESQL_TEST_URI"]
conn = adbc_driver_postgresql.dbapi.connect(uri)

#: We'll create an example table to test.
#: We'll create some example tables to test.
with conn.cursor() as cur:
cur.execute("DROP TABLE IF EXISTS example")
cur.execute("CREATE TABLE example (ints INT, bigints BIGINT)")

cur.execute("CREATE SCHEMA IF NOT EXISTS other_schema")
cur.execute("DROP TABLE IF EXISTS other_schema.example")
cur.execute("CREATE TABLE other_schema.example (strings TEXT, values NUMERIC)")

conn.commit()

#: By default the "active" catalog/schema are assumed.
assert conn.adbc_get_table_schema("example") == pyarrow.schema(
[
("ints", "int32"),
("bigints", "int64"),
]
)

#: We can explicitly specify the PostgreSQL schema to get the Arrow schema of
#: a table in a different namespace.
#:
#: .. note:: In PostgreSQL, you can only query the database (catalog) that you
#: are connected to. So we cannot specify the catalog here (or
#: rather, there is no point in doing so).
#:
#: Note that the NUMERIC column is read as a string, because PostgreSQL
#: decimals do not map onto Arrow decimals.
assert conn.adbc_get_table_schema(
"example",
db_schema_filter="other_schema",
) == pyarrow.schema(
[
("strings", "string"),
("values", "string"),
]
)

conn.close()

0 comments on commit 5ca9c29

Please sign in to comment.