diff --git a/.gitignore b/.gitignore index 0bd7eefeae..ad32dd218f 100644 --- a/.gitignore +++ b/.gitignore @@ -70,6 +70,9 @@ docker_cache site/ +# Python +dist/ + # R files **/.Rproj.user **/*.Rcheck/ diff --git a/adbc.h b/adbc.h index b201a405e9..26f66b3059 100644 --- a/adbc.h +++ b/adbc.h @@ -358,6 +358,61 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, /// /// @{ +/// \brief Get metadata about the database/driver. +/// +/// The result is an Arrow dataset with the following schema: +/// +/// Field Name | Field Type +/// ----------------------------|------------------------ +/// info_name | uint32 not null +/// info_value | INFO_SCHEMA +/// +/// INFO_SCHEMA is a dense union with members: +/// +/// Field Name (Type Code) | Field Type +/// ----------------------------|------------------------ +/// string_value (0) | utf8 +/// bool_value (1) | bool +/// int64_value (2) | int64 +/// int32_bitmask (3) | int32 +/// string_list (4) | list<utf8> +/// int32_to_int32_list_map (5) | map<int32, list<int32>> +/// +/// Each metadatum is identified by an integer code. The recognized +/// codes are defined as constants. Codes [0, 10_000) are reserved +/// for ADBC usage. Drivers/vendors will ignore requests for +/// unrecognized codes (the row will be omitted from the result). +/// +/// \param[in] connection The connection to query. +/// \param[in] info_codes A list of metadata codes to fetch, or NULL +/// to fetch all. +/// \param[in] info_codes_length The length of the info_codes +/// parameter. Ignored if info_codes is NULL. +/// \param[out] statement The result set. AdbcStatementGetStream can +/// be called immediately; do not call Execute or Prepare. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, + uint32_t* info_codes, size_t info_codes_length, + struct AdbcStatement* statement, + struct AdbcError* error); + +/// \brief The database vendor/product name (e.g. the server name). 
+/// (type: utf8). +#define ADBC_INFO_VENDOR_NAME 0 +/// \brief The database vendor/product version (type: utf8). +#define ADBC_INFO_VENDOR_VERSION 1 +/// \brief The database vendor/product Arrow library version (type: +/// utf8). +#define ADBC_INFO_VENDOR_ARROW_VERSION 2 + +/// \brief The driver name (type: utf8). +#define ADBC_INFO_DRIVER_NAME 100 +/// \brief The driver version (type: utf8). +#define ADBC_INFO_DRIVER_VERSION 101 +/// \brief The driver Arrow library version (type: utf8). +#define ADBC_INFO_DRIVER_ARROW_VERSION 102 + /// \brief Get a hierarchical view of all catalogs, database schemas, /// tables, and columns. /// @@ -819,6 +874,8 @@ struct ADBC_EXPORT AdbcDriver { AdbcStatusCode (*DatabaseInit)(struct AdbcDatabase*, struct AdbcError*); AdbcStatusCode (*DatabaseRelease)(struct AdbcDatabase*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetInfo)(struct AdbcConnection*, uint32_t*, size_t, + struct AdbcStatement*, struct AdbcError*); AdbcStatusCode (*ConnectionNew)(struct AdbcConnection*, struct AdbcError*); AdbcStatusCode (*ConnectionSetOption)(struct AdbcConnection*, const char*, const char*, struct AdbcError*); diff --git a/c/driver_manager/adbc_driver_manager.cc b/c/driver_manager/adbc_driver_manager.cc index a3f8bf934f..a6c4efb28c 100644 --- a/c/driver_manager/adbc_driver_manager.cc +++ b/c/driver_manager/adbc_driver_manager.cc @@ -320,6 +320,17 @@ AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, return connection->private_driver->ConnectionCommit(connection, error); } +AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, + uint32_t* info_codes, size_t info_codes_length, + struct AdbcStatement* statement, + struct AdbcError* error) { + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + return connection->private_driver->ConnectionGetInfo( + connection, info_codes, info_codes_length, statement, error); +} + AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* 
connection, int depth, const char* catalog, const char* db_schema, const char* table_name, const char** table_types, @@ -681,6 +692,7 @@ AdbcStatusCode AdbcLoadDriver(const char* driver_name, const char* entrypoint, CHECK_REQUIRED(driver, DatabaseInit); CHECK_REQUIRED(driver, DatabaseRelease); + CHECK_REQUIRED(driver, ConnectionGetInfo); CHECK_REQUIRED(driver, ConnectionNew); CHECK_REQUIRED(driver, ConnectionInit); CHECK_REQUIRED(driver, ConnectionRelease); diff --git a/c/driver_manager/adbc_driver_manager_test.cc b/c/driver_manager/adbc_driver_manager_test.cc index 2adcdf210c..3f59b825e9 100644 --- a/c/driver_manager/adbc_driver_manager_test.cc +++ b/c/driver_manager/adbc_driver_manager_test.cc @@ -107,6 +107,51 @@ TEST_F(DriverManager, ConnectionInitRelease) { ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionRelease(&connection, &error)); } +TEST_F(DriverManager, MetadataGetInfo) { + static std::shared_ptr kInfoSchema = arrow::schema({ + arrow::field("info_name", arrow::uint32(), /*nullable=*/false), + arrow::field( + "info_value", + arrow::dense_union({ + arrow::field("string_value", arrow::utf8()), + arrow::field("bool_value", arrow::boolean()), + arrow::field("int64_value", arrow::int64()), + arrow::field("int32_bitmask", arrow::int32()), + arrow::field("string_list", arrow::list(arrow::utf8())), + arrow::field("int32_to_int32_list_map", + arrow::map(arrow::int32(), arrow::list(arrow::int32()))), + })), + }); + + AdbcStatement statement; + std::memset(&statement, 0, sizeof(statement)); + ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error)); + ADBC_ASSERT_OK_WITH_ERROR( + error, AdbcConnectionGetInfo(&connection, nullptr, 0, &statement, &error)); + + std::shared_ptr schema; + arrow::RecordBatchVector batches; + ReadStatement(&statement, &schema, &batches); + ASSERT_SCHEMA_EQ(*schema, *kInfoSchema); + ASSERT_EQ(1, batches.size()); + + std::vector info = { + ADBC_INFO_DRIVER_NAME, + ADBC_INFO_DRIVER_VERSION, + ADBC_INFO_VENDOR_NAME, + 
ADBC_INFO_VENDOR_VERSION, + }; + ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error)); + ADBC_ASSERT_OK_WITH_ERROR( + error, + AdbcConnectionGetInfo(&connection, info.data(), info.size(), &statement, &error)); + batches.clear(); + ReadStatement(&statement, &schema, &batches); + ASSERT_SCHEMA_EQ(*schema, *kInfoSchema); + ASSERT_EQ(1, batches.size()); + ASSERT_EQ(4, batches[0]->num_rows()); +} + TEST_F(DriverManager, SqlExecute) { std::string query = "SELECT 1"; AdbcStatement statement; diff --git a/c/drivers/sqlite/sqlite.cc b/c/drivers/sqlite/sqlite.cc index ab2a2c78b9..81372d7dda 100644 --- a/c/drivers/sqlite/sqlite.cc +++ b/c/drivers/sqlite/sqlite.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -621,6 +622,102 @@ class SqliteStatementImpl { return ADBC_STATUS_INVALID_STATE; } + AdbcStatusCode GetInfo(const std::shared_ptr& self, + uint32_t* info_codes, size_t info_codes_length, + struct AdbcError* error) { + static std::shared_ptr kInfoSchema = arrow::schema({ + arrow::field("info_name", arrow::uint32(), /*nullable=*/false), + arrow::field( + "info_value", + arrow::dense_union({ + arrow::field("string_value", arrow::utf8()), + arrow::field("bool_value", arrow::boolean()), + arrow::field("int64_value", arrow::int64()), + arrow::field("int32_bitmask", arrow::int32()), + arrow::field("string_list", arrow::list(arrow::utf8())), + arrow::field("int32_to_int32_list_map", + arrow::map(arrow::int32(), arrow::list(arrow::int32()))), + })), + }); + static int kStringValueCode = 0; + + static std::vector kSupported = { + ADBC_INFO_VENDOR_NAME, ADBC_INFO_VENDOR_VERSION, ADBC_INFO_DRIVER_NAME, + ADBC_INFO_DRIVER_VERSION, ADBC_INFO_DRIVER_ARROW_VERSION, + }; + + if (!info_codes) { + info_codes = kSupported.data(); + info_codes_length = kSupported.size(); + } + + arrow::UInt32Builder info_name; + std::unique_ptr info_value_builder; + ADBC_RETURN_NOT_OK( + FromArrowStatus(MakeBuilder(arrow::default_memory_pool(), + 
kInfoSchema->field(1)->type(), &info_value_builder), + error)); + arrow::DenseUnionBuilder* info_value = + static_cast(info_value_builder.get()); + arrow::StringBuilder* info_string = + static_cast(info_value->child_builder(0).get()); + + for (size_t i = 0; i < info_codes_length; i++) { + switch (info_codes[i]) { + case ADBC_INFO_VENDOR_NAME: + ADBC_RETURN_NOT_OK(FromArrowStatus(info_name.Append(info_codes[i]), error)); + ADBC_RETURN_NOT_OK( + FromArrowStatus(info_value->Append(kStringValueCode), error)); + ADBC_RETURN_NOT_OK(FromArrowStatus(info_string->Append("SQLite3"), error)); + break; + case ADBC_INFO_VENDOR_VERSION: + ADBC_RETURN_NOT_OK(FromArrowStatus(info_name.Append(info_codes[i]), error)); + ADBC_RETURN_NOT_OK( + FromArrowStatus(info_value->Append(kStringValueCode), error)); + ADBC_RETURN_NOT_OK( + FromArrowStatus(info_string->Append(sqlite3_libversion()), error)); + break; + case ADBC_INFO_DRIVER_NAME: + ADBC_RETURN_NOT_OK(FromArrowStatus(info_name.Append(info_codes[i]), error)); + ADBC_RETURN_NOT_OK( + FromArrowStatus(info_value->Append(kStringValueCode), error)); + ADBC_RETURN_NOT_OK( + FromArrowStatus(info_string->Append("ADBC C SQLite3"), error)); + break; + case ADBC_INFO_DRIVER_VERSION: + // TODO: set up CMake to embed version info + ADBC_RETURN_NOT_OK(FromArrowStatus(info_name.Append(info_codes[i]), error)); + ADBC_RETURN_NOT_OK( + FromArrowStatus(info_value->Append(kStringValueCode), error)); + ADBC_RETURN_NOT_OK(FromArrowStatus(info_string->Append("0.0.1"), error)); + break; + case ADBC_INFO_DRIVER_ARROW_VERSION: + ADBC_RETURN_NOT_OK(FromArrowStatus(info_name.Append(info_codes[i]), error)); + ADBC_RETURN_NOT_OK( + FromArrowStatus(info_value->Append(kStringValueCode), error)); + ADBC_RETURN_NOT_OK(FromArrowStatus( + info_string->Append("Arrow/C++ " ARROW_VERSION_STRING), error)); + break; + default: + // Unrecognized + break; + } + } + + arrow::ArrayVector arrays(2); + ADBC_RETURN_NOT_OK(FromArrowStatus(info_name.Finish(&arrays[0]), error)); + 
ADBC_RETURN_NOT_OK(FromArrowStatus(info_value->Finish(&arrays[1]), error)); + const int64_t rows = arrays[0]->length(); + auto status = arrow::RecordBatchReader::Make( + { + arrow::RecordBatch::Make(kInfoSchema, rows, std::move(arrays)), + }, + kInfoSchema) + .Value(&result_reader_); + ADBC_RETURN_NOT_OK(FromArrowStatus(status, error)); + return ADBC_STATUS_OK; + } + AdbcStatusCode GetObjects(const std::shared_ptr& self, int depth, const char* catalog, const char* db_schema, const char* table_name, const char** table_type, @@ -1234,6 +1331,16 @@ AdbcStatusCode SqliteConnectionCommit(struct AdbcConnection* connection, return (*ptr)->Commit(error); } +AdbcStatusCode SqliteConnectionGetInfo(struct AdbcConnection* connection, + uint32_t* info_codes, size_t info_codes_length, + struct AdbcStatement* statement, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->GetInfo(*ptr, info_codes, info_codes_length, error); +} + AdbcStatusCode SqliteConnectionGetObjects( struct AdbcConnection* connection, int depth, const char* catalog, const char* db_schema, const char* table_name, const char** table_types, @@ -1441,6 +1548,14 @@ AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, return SqliteConnectionCommit(connection, error); } +AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, + uint32_t* info_codes, size_t info_codes_length, + struct AdbcStatement* statement, + struct AdbcError* error) { + return SqliteConnectionGetInfo(connection, info_codes, info_codes_length, statement, + error); +} + AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth, const char* catalog, const char* db_schema, const char* table_name, const char** table_types, @@ -1566,6 +1681,7 @@ AdbcStatusCode AdbcSqliteDriverInit(size_t count, struct AdbcDriver* driver, driver->DatabaseSetOption = SqliteDatabaseSetOption; 
driver->ConnectionCommit = SqliteConnectionCommit; + driver->ConnectionGetInfo = SqliteConnectionGetInfo; driver->ConnectionGetObjects = SqliteConnectionGetObjects; driver->ConnectionGetTableSchema = SqliteConnectionGetTableSchema; driver->ConnectionGetTableTypes = SqliteConnectionGetTableTypes; diff --git a/c/drivers/sqlite/sqlite_test.cc b/c/drivers/sqlite/sqlite_test.cc index 9f5035a252..7958c6450c 100644 --- a/c/drivers/sqlite/sqlite_test.cc +++ b/c/drivers/sqlite/sqlite_test.cc @@ -421,6 +421,51 @@ TEST_F(Sqlite, MultipleConnections) { ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionRelease(&connection2, &error)); } +TEST_F(Sqlite, MetadataGetInfo) { + static std::shared_ptr kInfoSchema = arrow::schema({ + arrow::field("info_name", arrow::uint32(), /*nullable=*/false), + arrow::field( + "info_value", + arrow::dense_union({ + arrow::field("string_value", arrow::utf8()), + arrow::field("bool_value", arrow::boolean()), + arrow::field("int64_value", arrow::int64()), + arrow::field("int32_bitmask", arrow::int32()), + arrow::field("string_list", arrow::list(arrow::utf8())), + arrow::field("int32_to_int32_list_map", + arrow::map(arrow::int32(), arrow::list(arrow::int32()))), + })), + }); + + AdbcStatement statement; + std::memset(&statement, 0, sizeof(statement)); + ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error)); + ADBC_ASSERT_OK_WITH_ERROR( + error, AdbcConnectionGetInfo(&connection, nullptr, 0, &statement, &error)); + + std::shared_ptr schema; + arrow::RecordBatchVector batches; + ReadStatement(&statement, &schema, &batches); + ASSERT_SCHEMA_EQ(*schema, *kInfoSchema); + ASSERT_EQ(1, batches.size()); + + std::vector info = { + ADBC_INFO_DRIVER_NAME, + ADBC_INFO_DRIVER_VERSION, + ADBC_INFO_VENDOR_NAME, + ADBC_INFO_VENDOR_VERSION, + }; + ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error)); + ADBC_ASSERT_OK_WITH_ERROR( + error, + AdbcConnectionGetInfo(&connection, info.data(), info.size(), &statement, 
&error)); + batches.clear(); + ReadStatement(&statement, &schema, &batches); + ASSERT_SCHEMA_EQ(*schema, *kInfoSchema); + ASSERT_EQ(1, batches.size()); + ASSERT_EQ(4, batches[0]->num_rows()); +} + TEST_F(Sqlite, MetadataGetTableTypes) { AdbcStatement statement; std::memset(&statement, 0, sizeof(statement)); diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcConnection.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcConnection.java index 00002a8ab9..3e3fe6fa08 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcConnection.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcConnection.java @@ -54,6 +54,20 @@ default AdbcStatement deserializePartitionDescriptor(ByteBuffer descriptor) thro "Connection does not support deserializePartitionDescriptor(ByteBuffer)"); } + AdbcStatement getInfo(int[] infoCodes) throws AdbcException; + + default AdbcStatement getInfo(AdbcInfoCode[] infoCodes) throws AdbcException { + int[] codes = new int[infoCodes.length]; + for (int i = 0; i < infoCodes.length; i++) { + codes[i] = infoCodes[i].getValue(); + } + return getInfo(codes); + } + + default AdbcStatement getInfo() throws AdbcException { + return getInfo((int[]) null); + } + /** * Get a hierarchical view of all catalogs, database schemas, tables, and columns. * diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcInfoCode.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcInfoCode.java new file mode 100644 index 0000000000..52c0956564 --- /dev/null +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcInfoCode.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.adbc.core; + +/** Integer IDs used for requesting information about the database/driver. */ +public enum AdbcInfoCode { + /** The database vendor/product name (e.g. the server name) (type: utf8). */ + VENDOR_NAME(0), + /** The database vendor/product version (type: utf8). */ + VENDOR_VERSION(1), + /** The database vendor/product Arrow library version (type: utf8). */ + VENDOR_ARROW_VERSION(2), + + /** The driver name (type: utf8). */ + DRIVER_NAME(100), + /** The driver version (type: utf8). */ + DRIVER_VERSION(101), + /** The driver Arrow library version (type: utf8). 
*/ + DRIVER_ARROW_VERSION(102), + ; + + private final int value; + + AdbcInfoCode(int value) { + this.value = value; + } + + public int getValue() { + return value; + } +} diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/StandardSchemas.java b/java/core/src/main/java/org/apache/arrow/adbc/core/StandardSchemas.java index af35cbb206..a14c04c700 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/StandardSchemas.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/StandardSchemas.java @@ -19,7 +19,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; -import org.apache.arrow.vector.complex.BaseRepeatedValueVector; +import org.apache.arrow.vector.types.UnionMode; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; @@ -32,6 +32,38 @@ private StandardSchemas() { private static final ArrowType INT16 = new ArrowType.Int(16, true); private static final ArrowType INT32 = new ArrowType.Int(32, true); + private static final ArrowType INT64 = new ArrowType.Int(64, true); + private static final ArrowType UINT32 = new ArrowType.Int(32, false); + + /** The schema of the result set of {@link AdbcConnection#getInfo(int[])}. 
*/ + public static final Schema GET_INFO_SCHEMA = + new Schema( + Arrays.asList( + Field.notNullable("info_name", UINT32), + new Field( + "info_value", + FieldType.nullable( + new ArrowType.Union(UnionMode.Dense, new int[] {0, 1, 2, 3, 4, 5})), + Arrays.asList( + Field.nullable("string_value", ArrowType.Utf8.INSTANCE), + Field.nullable("bool_value", ArrowType.Bool.INSTANCE), + Field.nullable("int64_value", INT64), + Field.nullable("int32_bitmask", INT32), + new Field( + "string_list", + FieldType.nullable(ArrowType.List.INSTANCE), + Collections.singletonList( + Field.nullable("item", ArrowType.Utf8.INSTANCE))), + new Field( + "int32_to_int32_list_map", + FieldType.nullable(new ArrowType.Map(/*keysSorted*/ false)), + Collections.singletonList( + new Field( + "entries", + FieldType.notNullable(ArrowType.Struct.INSTANCE), + Arrays.asList( + Field.notNullable("key", INT32), + Field.nullable("value", INT32))))))))); /** The schema of the result set of {@link AdbcConnection#getTableTypes()}. */ public static final Schema TABLE_TYPES_SCHEMA = @@ -52,16 +84,12 @@ private StandardSchemas() { new Field( "constraint_column_names", FieldType.notNullable(ArrowType.List.INSTANCE), - Collections.singletonList( - Field.nullable(BaseRepeatedValueVector.DATA_VECTOR_NAME, new ArrowType.Utf8()))), + Collections.singletonList(Field.nullable("item", new ArrowType.Utf8()))), new Field( "constraint_column_usage", FieldType.notNullable(ArrowType.List.INSTANCE), Collections.singletonList( - new Field( - BaseRepeatedValueVector.DATA_VECTOR_NAME, - FieldType.nullable(ArrowType.Struct.INSTANCE), - USAGE_SCHEMA)))); + new Field("item", FieldType.nullable(ArrowType.Struct.INSTANCE), USAGE_SCHEMA)))); public static final List COLUMN_SCHEMA = Arrays.asList( @@ -93,18 +121,13 @@ private StandardSchemas() { "table_columns", FieldType.notNullable(ArrowType.List.INSTANCE), Collections.singletonList( - new Field( - BaseRepeatedValueVector.DATA_VECTOR_NAME, - 
FieldType.nullable(ArrowType.Struct.INSTANCE), - COLUMN_SCHEMA))), + new Field("item", FieldType.nullable(ArrowType.Struct.INSTANCE), COLUMN_SCHEMA))), new Field( "table_constraints", FieldType.notNullable(ArrowType.List.INSTANCE), Collections.singletonList( new Field( - BaseRepeatedValueVector.DATA_VECTOR_NAME, - FieldType.nullable(ArrowType.Struct.INSTANCE), - CONSTRAINT_SCHEMA)))); + "item", FieldType.nullable(ArrowType.Struct.INSTANCE), CONSTRAINT_SCHEMA)))); public static final List DB_SCHEMA_SCHEMA = Arrays.asList( @@ -113,10 +136,7 @@ private StandardSchemas() { "db_schema_tables", FieldType.notNullable(ArrowType.List.INSTANCE), Collections.singletonList( - new Field( - BaseRepeatedValueVector.DATA_VECTOR_NAME, - FieldType.nullable(ArrowType.Struct.INSTANCE), - TABLE_SCHEMA)))); + new Field("item", FieldType.nullable(ArrowType.Struct.INSTANCE), TABLE_SCHEMA)))); public static final Schema GET_OBJECTS_SCHEMA = new Schema( @@ -127,7 +147,7 @@ private StandardSchemas() { FieldType.notNullable(ArrowType.List.INSTANCE), Collections.singletonList( new Field( - BaseRepeatedValueVector.DATA_VECTOR_NAME, + "item", FieldType.nullable(ArrowType.Struct.INSTANCE), DB_SCHEMA_SCHEMA))))); } diff --git a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/InfoMetadataBuilder.java b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/InfoMetadataBuilder.java new file mode 100644 index 0000000000..02c2ccac22 --- /dev/null +++ b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/InfoMetadataBuilder.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.adbc.driver.jdbc; + +import java.nio.charset.StandardCharsets; +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.SQLException; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.apache.arrow.adbc.core.AdbcInfoCode; +import org.apache.arrow.adbc.core.StandardSchemas; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.DenseUnionVector; + +/** Helper class to track state needed to build up the info structure. 
*/ +final class InfoMetadataBuilder implements AutoCloseable { + private static final byte STRING_VALUE_TYPE_ID = (byte) 0; + private static final Map SUPPORTED_CODES = new HashMap<>(); + private final Collection requestedCodes; + private final DatabaseMetaData dbmd; + private VectorSchemaRoot root; + + final UInt4Vector infoCodes; + final DenseUnionVector infoValues; + final VarCharVector stringValues; + + @FunctionalInterface + interface AddInfo { + void accept(InfoMetadataBuilder builder, int rowIndex) throws SQLException; + } + + static { + SUPPORTED_CODES.put( + AdbcInfoCode.VENDOR_NAME.getValue(), + (b, idx) -> { + b.setStringValue(idx, b.dbmd.getDatabaseProductName()); + }); + SUPPORTED_CODES.put( + AdbcInfoCode.VENDOR_VERSION.getValue(), + (b, idx) -> { + b.setStringValue(idx, b.dbmd.getDatabaseProductVersion()); + }); + SUPPORTED_CODES.put( + AdbcInfoCode.DRIVER_NAME.getValue(), + (b, idx) -> { + final String driverName = "ADBC JDBC Driver (" + b.dbmd.getDriverName() + ")"; + b.setStringValue(idx, driverName); + }); + SUPPORTED_CODES.put( + AdbcInfoCode.DRIVER_VERSION.getValue(), + (b, idx) -> { + final String driverVersion = b.dbmd.getDriverVersion() + " (ADBC Driver Version 0.0.1)"; + b.setStringValue(idx, driverVersion); + }); + } + + InfoMetadataBuilder(BufferAllocator allocator, Connection connection, int[] infoCodes) + throws SQLException { + this.requestedCodes = + infoCodes == null + ? 
SUPPORTED_CODES.keySet() + : IntStream.of(infoCodes).boxed().collect(Collectors.toList()); + this.root = VectorSchemaRoot.create(StandardSchemas.GET_INFO_SCHEMA, allocator); + this.dbmd = connection.getMetaData(); + this.infoCodes = (UInt4Vector) root.getVector(0); + this.infoValues = (DenseUnionVector) root.getVector(1); + this.stringValues = this.infoValues.getVarCharVector((byte) 0); + } + + void setStringValue(int index, final String value) { + infoValues.setValueCount(index + 1); + infoValues.setTypeId(index, STRING_VALUE_TYPE_ID); + stringValues.setSafe(index, value.getBytes(StandardCharsets.UTF_8)); + infoValues + .getOffsetBuffer() + .setInt((long) index * DenseUnionVector.OFFSET_WIDTH, stringValues.getLastSet()); + } + + VectorSchemaRoot build() throws SQLException { + int rowIndex = 0; + for (final Integer code : requestedCodes) { + final AddInfo metadata = SUPPORTED_CODES.get(code); + if (metadata == null) { + continue; + } + infoCodes.setSafe(rowIndex, code); + metadata.accept(this, rowIndex++); + } + root.setRowCount(rowIndex); + VectorSchemaRoot result = root; + root = null; + return result; + } + + @Override + public void close() throws Exception { + AutoCloseables.close(root); + } +} diff --git a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcConnection.java b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcConnection.java index 8824b12cd4..5bbb3b0d1a 100644 --- a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcConnection.java +++ b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcConnection.java @@ -68,6 +68,17 @@ public AdbcStatement bulkIngest(String targetTableName, BulkIngestMode mode) return JdbcStatement.ingestRoot(allocator, connection, quirks, targetTableName, mode); } + @Override + public AdbcStatement getInfo(int[] infoCodes) throws AdbcException { + try { + final VectorSchemaRoot root = + new InfoMetadataBuilder(allocator, connection, infoCodes).build(); 
+ return new FixedJdbcStatement(allocator, root); + } catch (SQLException e) { + throw JdbcDriverUtil.fromSqlException(e); + } + } + @Override public AdbcStatement getObjects( final GetObjectsDepth depth, @@ -80,7 +91,7 @@ public AdbcStatement getObjects( // Build up the metadata in-memory and then return a constant reader. try { final VectorSchemaRoot root = - new JdbcMetadataBuilder( + new ObjectMetadataBuilder( allocator, connection, depth, diff --git a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcMetadataBuilder.java b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/ObjectMetadataBuilder.java similarity index 98% rename from java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcMetadataBuilder.java rename to java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/ObjectMetadataBuilder.java index 1d2984bd40..abc7cf9d85 100644 --- a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcMetadataBuilder.java +++ b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/ObjectMetadataBuilder.java @@ -35,8 +35,8 @@ import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.StructVector; -/** Helper class to track state needed to build up the metadata structure. */ -final class JdbcMetadataBuilder implements AutoCloseable { +/** Helper class to track state needed to build up the object metadata structure. 
*/ +final class ObjectMetadataBuilder implements AutoCloseable { private final AdbcConnection.GetObjectsDepth depth; private final String catalogPattern; private final String dbSchemaPattern; @@ -73,7 +73,7 @@ final class JdbcMetadataBuilder implements AutoCloseable { final VarCharVector columnUsageFkTables; final VarCharVector columnUsageFkColumns; - JdbcMetadataBuilder( + ObjectMetadataBuilder( BufferAllocator allocator, Connection connection, final AdbcConnection.GetObjectsDepth depth, diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionMetadataTest.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionMetadataTest.java index 6d9b3fc5e3..ecb0f4ea13 100644 --- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionMetadataTest.java +++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionMetadataTest.java @@ -27,6 +27,7 @@ import java.util.stream.IntStream; import org.apache.arrow.adbc.core.AdbcConnection; import org.apache.arrow.adbc.core.AdbcDatabase; +import org.apache.arrow.adbc.core.AdbcInfoCode; import org.apache.arrow.adbc.core.AdbcStatement; import org.apache.arrow.adbc.core.BulkIngestMode; import org.apache.arrow.adbc.core.StandardSchemas; @@ -35,8 +36,10 @@ import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.UInt4Vector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.DenseUnionVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.ipc.ArrowReader; @@ -75,6 +78,39 @@ public void afterEach() throws Exception { AutoCloseables.close(connection, database, allocator); } + @Test + void 
getInfo() throws Exception { + try (final AdbcStatement stmt = connection.getInfo()) { + try (final ArrowReader reader = stmt.getArrowReader()) { + assertThat(reader.getVectorSchemaRoot().getSchema()) + .isEqualTo(StandardSchemas.GET_INFO_SCHEMA); + assertThat(reader.loadNextBatch()).isTrue(); + assertThat(reader.getVectorSchemaRoot().getRowCount()).isGreaterThan(0); + } + } + } + + @Test + void getInfoByCode() throws Exception { + try (final AdbcStatement stmt = + connection.getInfo(new AdbcInfoCode[] {AdbcInfoCode.DRIVER_NAME})) { + try (final ArrowReader reader = stmt.getArrowReader()) { + final VectorSchemaRoot root = reader.getVectorSchemaRoot(); + assertThat(root.getSchema()).isEqualTo(StandardSchemas.GET_INFO_SCHEMA); + assertThat(reader.loadNextBatch()).isTrue(); + assertThat(root.getRowCount()).isEqualTo(1); + assertThat(((UInt4Vector) root.getVector(0)).getObject(0)) + .isEqualTo(AdbcInfoCode.DRIVER_NAME.getValue()); + assertThat( + ((DenseUnionVector) root.getVector(1)) + .getVarCharVector((byte) 0) + .getObject(0) + .toString()) + .isNotEmpty(); + } + } + } + @Test void getObjectsColumns() throws Exception { final Schema schema = util.ingestTableIntsStrs(allocator, connection, tableName); diff --git a/python/adbc_driver_manager/.gitignore b/python/adbc_driver_manager/.gitignore index 92bc5078cd..d87abbfd5c 100644 --- a/python/adbc_driver_manager/.gitignore +++ b/python/adbc_driver_manager/.gitignore @@ -16,4 +16,5 @@ # under the License. 
adbc_driver_manager/*.c +adbc_driver_manager/*.cpp build/ diff --git a/python/adbc_driver_manager/adbc_driver_manager/__init__.py b/python/adbc_driver_manager/adbc_driver_manager/__init__.py index c07a38bc99..4d9222169d 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/__init__.py +++ b/python/adbc_driver_manager/adbc_driver_manager/__init__.py @@ -19,6 +19,7 @@ INGEST_OPTION_TARGET_TABLE, AdbcConnection, AdbcDatabase, + AdbcInfoCode, AdbcStatement, AdbcStatusCode, ArrowArrayHandle, diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx index 6dbcd403f3..97b1c9b096 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx +++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx @@ -24,8 +24,9 @@ import typing from typing import List import cython -from libc.stdint cimport int32_t, uint8_t, uintptr_t +from libc.stdint cimport int32_t, uint8_t, uint32_t, uintptr_t from libc.string cimport memset +from libcpp.vector cimport vector as c_vector if typing.TYPE_CHECKING: from typing import Self @@ -70,6 +71,13 @@ cdef extern from "adbc.h" nogil: cdef int ADBC_OBJECT_DEPTH_TABLES cdef int ADBC_OBJECT_DEPTH_COLUMNS + cdef uint32_t ADBC_INFO_VENDOR_NAME + cdef uint32_t ADBC_INFO_VENDOR_VERSION + cdef uint32_t ADBC_INFO_VENDOR_ARROW_VERSION + cdef uint32_t ADBC_INFO_DRIVER_NAME + cdef uint32_t ADBC_INFO_DRIVER_VERSION + cdef uint32_t ADBC_INFO_DRIVER_ARROW_VERSION + ctypedef void (*CAdbcErrorRelease)(CAdbcError*) cdef struct CAdbcError"AdbcError": @@ -108,6 +116,12 @@ cdef extern from "adbc.h" nogil: size_t serialized_length, CAdbcStatement* statement, CAdbcError* error) + CAdbcStatusCode AdbcConnectionGetInfo( + CAdbcConnection* connection, + uint32_t* info_codes, + size_t info_codes_length, + CAdbcStatement* statement, + CAdbcError* error) CAdbcStatusCode AdbcConnectionGetObjects( CAdbcConnection* connection, int depth, @@ -217,6 +231,15 @@ class AdbcStatusCode(enum.IntEnum): 
UNAUTHORIZED = ADBC_STATUS_UNAUTHORIZED +class AdbcInfoCode(enum.IntEnum): + VENDOR_NAME = ADBC_INFO_VENDOR_NAME + VENDOR_VERSION = ADBC_INFO_VENDOR_VERSION + VENDOR_ARROW_VERSION = ADBC_INFO_VENDOR_ARROW_VERSION + DRIVER_NAME = ADBC_INFO_DRIVER_NAME + DRIVER_VERSION = ADBC_INFO_DRIVER_VERSION + DRIVER_ARROW_VERSION = ADBC_INFO_DRIVER_ARROW_VERSION + + class Error(Exception): """PEP-249 compliant base exception class. @@ -474,6 +497,39 @@ cdef class AdbcConnection(_AdbcHandle): cdef CAdbcError c_error = empty_error() check_error(AdbcConnectionCommit(&self.connection, &c_error), &c_error) + def get_info(self, info_codes=None): + """ + Get metadata about the database/driver. + """ + cdef CAdbcError c_error = empty_error() + cdef CAdbcStatusCode status + cdef AdbcStatement statement = AdbcStatement(self) + cdef c_vector[uint32_t] c_info_codes + + if info_codes: + for info_code in info_codes: + if isinstance(info_code, int): + c_info_codes.push_back(info_code) + else: + c_info_codes.push_back(info_code.value) + + status = AdbcConnectionGetInfo( + &self.connection, + c_info_codes.data(), + c_info_codes.size(), + &statement.statement, + &c_error) + else: + status = AdbcConnectionGetInfo( + &self.connection, + NULL, + 0, + &statement.statement, + &c_error) + + check_error(status, &c_error) + return statement + def get_objects(self, depth, catalog=None, db_schema=None, table_name=None, table_types=None, column_name=None) -> AdbcStatement: """ diff --git a/python/adbc_driver_manager/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/adbc_driver_manager/tests/test_lowlevel.py index 5f716e9f51..e1123537b6 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/tests/test_lowlevel.py +++ b/python/adbc_driver_manager/adbc_driver_manager/tests/test_lowlevel.py @@ -56,6 +56,28 @@ def test_database_init(): pass +def test_connection_get_info(sqlite): + _, conn = sqlite + codes = [ + adbc_driver_manager.AdbcInfoCode.VENDOR_NAME, + 
adbc_driver_manager.AdbcInfoCode.VENDOR_VERSION.value, + adbc_driver_manager.AdbcInfoCode.DRIVER_NAME, + adbc_driver_manager.AdbcInfoCode.DRIVER_VERSION.value, + ] + with conn.get_info() as stmt: + table = _import(stmt.get_stream()).read_all() + assert table.num_rows > 0 + data = dict(zip(table[0].to_pylist(), table[1].to_pylist())) + for code in codes: + assert code in data + assert data[code] + + with conn.get_info(codes) as stmt: + table = _import(stmt.get_stream()).read_all() + assert table.num_rows > 0 + assert set(codes) == set(table[0].to_pylist()) + + def test_connection_get_objects(sqlite): _, conn = sqlite data = pyarrow.record_batch( diff --git a/python/adbc_driver_manager/setup.py b/python/adbc_driver_manager/setup.py index 74f1f7c07e..422fca8dd2 100644 --- a/python/adbc_driver_manager/setup.py +++ b/python/adbc_driver_manager/setup.py @@ -25,12 +25,13 @@ ext_modules=cythonize( Extension( name="adbc_driver_manager._lib", + extra_compile_args=["-ggdb", "-Og"], + include_dirs=["../../", "../../c/driver_manager"], + language="c++", sources=[ "adbc_driver_manager/_lib.pyx", "../../c/driver_manager/adbc_driver_manager.cc", ], - include_dirs=["../../", "../../c/driver_manager"], - extra_compile_args=["-ggdb", "-Og"], ), ), packages=["adbc_driver_manager"],