|
29 | 29 | #include <arrow/record_batch.h> |
30 | 30 | #include <arrow/status.h> |
31 | 31 | #include <arrow/table.h> |
| 32 | +#include <arrow/util/config.h> |
32 | 33 | #include <arrow/util/logging.h> |
33 | 34 | #include <arrow/util/string_builder.h> |
34 | 35 |
|
@@ -621,6 +622,102 @@ class SqliteStatementImpl { |
621 | 622 | return ADBC_STATUS_INVALID_STATE; |
622 | 623 | } |
623 | 624 |
|
| 625 | + AdbcStatusCode GetInfo(const std::shared_ptr<SqliteStatementImpl>& self, |
| 626 | + uint32_t* info_codes, size_t info_codes_length, |
| 627 | + struct AdbcError* error) { |
| 628 | + static std::shared_ptr<arrow::Schema> kInfoSchema = arrow::schema({ |
| 629 | + arrow::field("info_name", arrow::uint32(), /*nullable=*/false), |
| 630 | + arrow::field( |
| 631 | + "info_value", |
| 632 | + arrow::dense_union({ |
| 633 | + arrow::field("string_value", arrow::utf8()), |
| 634 | + arrow::field("bool_value", arrow::boolean()), |
| 635 | + arrow::field("int64_value", arrow::int64()), |
| 636 | + arrow::field("int32_bitmask", arrow::int32()), |
| 637 | + arrow::field("string_list", arrow::list(arrow::utf8())), |
| 638 | + arrow::field("int32_to_int32_list_map", |
| 639 | + arrow::map(arrow::int32(), arrow::list(arrow::int32()))), |
| 640 | + })), |
| 641 | + }); |
| 642 | + static int kStringValueCode = 0; |
| 643 | + |
| 644 | + static std::vector<uint32_t> kSupported = { |
| 645 | + ADBC_INFO_VENDOR_NAME, ADBC_INFO_VENDOR_VERSION, ADBC_INFO_DRIVER_NAME, |
| 646 | + ADBC_INFO_DRIVER_VERSION, ADBC_INFO_DRIVER_ARROW_VERSION, |
| 647 | + }; |
| 648 | + |
| 649 | + if (!info_codes) { |
| 650 | + info_codes = kSupported.data(); |
| 651 | + info_codes_length = kSupported.size(); |
| 652 | + } |
| 653 | + |
| 654 | + arrow::UInt32Builder info_name; |
| 655 | + std::unique_ptr<arrow::ArrayBuilder> info_value_builder; |
| 656 | + ADBC_RETURN_NOT_OK( |
| 657 | + FromArrowStatus(MakeBuilder(arrow::default_memory_pool(), |
| 658 | + kInfoSchema->field(1)->type(), &info_value_builder), |
| 659 | + error)); |
| 660 | + arrow::DenseUnionBuilder* info_value = |
| 661 | + static_cast<arrow::DenseUnionBuilder*>(info_value_builder.get()); |
| 662 | + arrow::StringBuilder* info_string = |
| 663 | + static_cast<arrow::StringBuilder*>(info_value->child_builder(0).get()); |
| 664 | + |
| 665 | + for (size_t i = 0; i < info_codes_length; i++) { |
| 666 | + switch (info_codes[i]) { |
| 667 | + case ADBC_INFO_VENDOR_NAME: |
| 668 | + ADBC_RETURN_NOT_OK(FromArrowStatus(info_name.Append(info_codes[i]), error)); |
| 669 | + ADBC_RETURN_NOT_OK( |
| 670 | + FromArrowStatus(info_value->Append(kStringValueCode), error)); |
| 671 | + ADBC_RETURN_NOT_OK(FromArrowStatus(info_string->Append("SQLite3"), error)); |
| 672 | + break; |
| 673 | + case ADBC_INFO_VENDOR_VERSION: |
| 674 | + ADBC_RETURN_NOT_OK(FromArrowStatus(info_name.Append(info_codes[i]), error)); |
| 675 | + ADBC_RETURN_NOT_OK( |
| 676 | + FromArrowStatus(info_value->Append(kStringValueCode), error)); |
| 677 | + ADBC_RETURN_NOT_OK( |
| 678 | + FromArrowStatus(info_string->Append(sqlite3_libversion()), error)); |
| 679 | + break; |
| 680 | + case ADBC_INFO_DRIVER_NAME: |
| 681 | + ADBC_RETURN_NOT_OK(FromArrowStatus(info_name.Append(info_codes[i]), error)); |
| 682 | + ADBC_RETURN_NOT_OK( |
| 683 | + FromArrowStatus(info_value->Append(kStringValueCode), error)); |
| 684 | + ADBC_RETURN_NOT_OK( |
| 685 | + FromArrowStatus(info_string->Append("ADBC C SQLite3"), error)); |
| 686 | + break; |
| 687 | + case ADBC_INFO_DRIVER_VERSION: |
| 688 | + // TODO: set up CMake to embed version info |
| 689 | + ADBC_RETURN_NOT_OK(FromArrowStatus(info_name.Append(info_codes[i]), error)); |
| 690 | + ADBC_RETURN_NOT_OK( |
| 691 | + FromArrowStatus(info_value->Append(kStringValueCode), error)); |
| 692 | + ADBC_RETURN_NOT_OK(FromArrowStatus(info_string->Append("0.0.1"), error)); |
| 693 | + break; |
| 694 | + case ADBC_INFO_DRIVER_ARROW_VERSION: |
| 695 | + ADBC_RETURN_NOT_OK(FromArrowStatus(info_name.Append(info_codes[i]), error)); |
| 696 | + ADBC_RETURN_NOT_OK( |
| 697 | + FromArrowStatus(info_value->Append(kStringValueCode), error)); |
| 698 | + ADBC_RETURN_NOT_OK(FromArrowStatus( |
| 699 | + info_string->Append("Arrow/C++ " ARROW_VERSION_STRING), error)); |
| 700 | + break; |
| 701 | + default: |
| 702 | + // Unrecognized |
| 703 | + break; |
| 704 | + } |
| 705 | + } |
| 706 | + |
| 707 | + arrow::ArrayVector arrays(2); |
| 708 | + ADBC_RETURN_NOT_OK(FromArrowStatus(info_name.Finish(&arrays[0]), error)); |
| 709 | + ADBC_RETURN_NOT_OK(FromArrowStatus(info_value->Finish(&arrays[1]), error)); |
| 710 | + const int64_t rows = arrays[0]->length(); |
| 711 | + auto status = arrow::RecordBatchReader::Make( |
| 712 | + { |
| 713 | + arrow::RecordBatch::Make(kInfoSchema, rows, std::move(arrays)), |
| 714 | + }, |
| 715 | + kInfoSchema) |
| 716 | + .Value(&result_reader_); |
| 717 | + ADBC_RETURN_NOT_OK(FromArrowStatus(status, error)); |
| 718 | + return ADBC_STATUS_OK; |
| 719 | + } |
| 720 | + |
624 | 721 | AdbcStatusCode GetObjects(const std::shared_ptr<SqliteStatementImpl>& self, int depth, |
625 | 722 | const char* catalog, const char* db_schema, |
626 | 723 | const char* table_name, const char** table_type, |
@@ -1234,6 +1331,16 @@ AdbcStatusCode SqliteConnectionCommit(struct AdbcConnection* connection, |
1234 | 1331 | return (*ptr)->Commit(error); |
1235 | 1332 | } |
1236 | 1333 |
|
| 1334 | +AdbcStatusCode SqliteConnectionGetInfo(struct AdbcConnection* connection, |
| 1335 | + uint32_t* info_codes, size_t info_codes_length, |
| 1336 | + struct AdbcStatement* statement, |
| 1337 | + struct AdbcError* error) { |
| 1338 | + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; |
| 1339 | + auto ptr = |
| 1340 | + reinterpret_cast<std::shared_ptr<SqliteStatementImpl>*>(statement->private_data); |
| 1341 | + return (*ptr)->GetInfo(*ptr, info_codes, info_codes_length, error); |
| 1342 | +} |
| 1343 | + |
1237 | 1344 | AdbcStatusCode SqliteConnectionGetObjects( |
1238 | 1345 | struct AdbcConnection* connection, int depth, const char* catalog, |
1239 | 1346 | const char* db_schema, const char* table_name, const char** table_types, |
@@ -1441,6 +1548,14 @@ AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, |
1441 | 1548 | return SqliteConnectionCommit(connection, error); |
1442 | 1549 | } |
1443 | 1550 |
|
| 1551 | +AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, |
| 1552 | + uint32_t* info_codes, size_t info_codes_length, |
| 1553 | + struct AdbcStatement* statement, |
| 1554 | + struct AdbcError* error) { |
| 1555 | + return SqliteConnectionGetInfo(connection, info_codes, info_codes_length, statement, |
| 1556 | + error); |
| 1557 | +} |
| 1558 | + |
1444 | 1559 | AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth, |
1445 | 1560 | const char* catalog, const char* db_schema, |
1446 | 1561 | const char* table_name, const char** table_types, |
@@ -1566,6 +1681,7 @@ AdbcStatusCode AdbcSqliteDriverInit(size_t count, struct AdbcDriver* driver, |
1566 | 1681 | driver->DatabaseSetOption = SqliteDatabaseSetOption; |
1567 | 1682 |
|
1568 | 1683 | driver->ConnectionCommit = SqliteConnectionCommit; |
| 1684 | + driver->ConnectionGetInfo = SqliteConnectionGetInfo; |
1569 | 1685 | driver->ConnectionGetObjects = SqliteConnectionGetObjects; |
1570 | 1686 | driver->ConnectionGetTableSchema = SqliteConnectionGetTableSchema; |
1571 | 1687 | driver->ConnectionGetTableTypes = SqliteConnectionGetTableTypes; |
|
0 commit comments