Skip to content

Commit

Permalink
[C++] Implement CommandGetImportedKeys and CommandGetExportedKeys (ap…
Browse files Browse the repository at this point in the history
…ache#163)

* Implement CommandGetImportedKeys and CommandGetExportedKeys on Flight SQL Server example

* Refactor DoGet methods to reduce code duplication
  • Loading branch information
rafael-telles committed Oct 19, 2021
1 parent 0dbacf9 commit dab5f35
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 18 deletions.
111 changes: 93 additions & 18 deletions cpp/src/arrow/flight/flight-sql/example/sqlite_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,21 @@ void SQLiteFlightSqlServer::ExecuteSql(const std::string& sql) {
}
}

Status DoGetSQLiteQuery(sqlite3* db, const std::string& query,
const std::shared_ptr<Schema>& schema,
std::unique_ptr<FlightDataStream>* result) {
std::shared_ptr<SqliteStatement> statement;
ARROW_RETURN_NOT_OK(SqliteStatement::Create(db, query, &statement));

std::shared_ptr<SqliteStatementBatchReader> reader;
ARROW_RETURN_NOT_OK(SqliteStatementBatchReader::Create(
statement, schema, &reader));

*result = std::unique_ptr<FlightDataStream>(new RecordBatchStream(reader));

return Status::OK();
}

Status GetFlightInfoForCommand(const FlightDescriptor& descriptor,
std::unique_ptr<FlightInfo>* info,
const google::protobuf::Message& command,
Expand Down Expand Up @@ -299,16 +314,7 @@ Status SQLiteFlightSqlServer::DoGetTableTypes(const ServerCallContext& context,
std::unique_ptr<FlightDataStream>* result) {
std::string query = "SELECT DISTINCT type as table_type FROM sqlite_master";

std::shared_ptr<SqliteStatement> statement;
ARROW_RETURN_NOT_OK(SqliteStatement::Create(db_, query, &statement));

std::shared_ptr<SqliteStatementBatchReader> reader;
ARROW_RETURN_NOT_OK(SqliteStatementBatchReader::Create(
statement, SqlSchema::GetTableTypesSchema(), &reader));

*result = std::unique_ptr<FlightDataStream>(new RecordBatchStream(reader));

return Status::OK();
return DoGetSQLiteQuery(db_, query, SqlSchema::GetTableTypesSchema(), result);
}

Status SQLiteFlightSqlServer::GetFlightInfoPrimaryKeys(
Expand Down Expand Up @@ -343,17 +349,86 @@ SQLiteFlightSqlServer::DoGetPrimaryKeys(const pb::sql::CommandGetPrimaryKeys &co

table_query << " and table_name LIKE '" << command.table() << "'";

std::shared_ptr<SqliteStatement> statement;
ARROW_RETURN_NOT_OK(SqliteStatement::Create(db_, table_query.str(), &statement));
return DoGetSQLiteQuery(db_, table_query.str(), SqlSchema::GetPrimaryKeysSchema(),
result);
}

std::shared_ptr<SqliteStatementBatchReader> reader;
ARROW_RETURN_NOT_OK(SqliteStatementBatchReader::Create(
statement, SqlSchema::GetPrimaryKeysSchema(), &reader));
std::string PrepareQueryForGetImportedOrExportedKeys(const std::string& filter) {
return R"(SELECT * FROM (SELECT NULL AS pk_catalog_name,
NULL AS pk_schema_name,
p."table" AS pk_table_name,
p."to" AS pk_column_name,
NULL AS fk_catalog_name,
NULL AS fk_schema_name,
m.name AS fk_table_name,
p."from" AS fk_column_name,
p.seq AS key_sequence,
NULL AS pk_key_name,
NULL AS fk_key_name,
CASE
WHEN p.on_update = 'CASCADE' THEN 0
WHEN p.on_update = 'RESTRICT' THEN 1
WHEN p.on_update = 'SET NULL' THEN 2
WHEN p.on_update = 'NO ACTION' THEN 3
WHEN p.on_update = 'SET DEFAULT' THEN 4
END AS update_rule,
CASE
WHEN p.on_delete = 'CASCADE' THEN 0
WHEN p.on_delete = 'RESTRICT' THEN 1
WHEN p.on_delete = 'SET NULL' THEN 2
WHEN p.on_delete = 'NO ACTION' THEN 3
WHEN p.on_delete = 'SET DEFAULT' THEN 4
END AS delete_rule
FROM sqlite_master m
JOIN pragma_foreign_key_list(m.name) p ON m.name != p."table"
WHERE m.type = 'table') WHERE )" + filter + R"( ORDER BY
pk_catalog_name, pk_schema_name, pk_table_name, pk_key_name, key_sequence)";
}

*result = std::unique_ptr<FlightDataStream>(
new RecordBatchStream(reader));
Status SQLiteFlightSqlServer::GetFlightInfoImportedKeys(
const pb::sql::CommandGetImportedKeys& command, const ServerCallContext& context,
const FlightDescriptor& descriptor, std::unique_ptr<FlightInfo>* info) {
return GetFlightInfoForCommand(descriptor, info, command,
SqlSchema::GetImportedAndExportedKeysSchema());
}

return Status::OK();
Status SQLiteFlightSqlServer::DoGetImportedKeys(
const pb::sql::CommandGetImportedKeys& command, const ServerCallContext& context,
std::unique_ptr<FlightDataStream>* result) {
std::string filter = "fk_table_name = '" + command.table() + "'";
if (command.has_catalog()) {
filter += " AND fk_catalog_name = '" + command.catalog() + "'";
}
if (command.has_schema()) {
filter += " AND fk_schema_name = '" + command.schema() + "'";
}
std::string query = PrepareQueryForGetImportedOrExportedKeys(filter);

return DoGetSQLiteQuery(db_, query, SqlSchema::GetImportedAndExportedKeysSchema(),
result);
}

Status SQLiteFlightSqlServer::GetFlightInfoExportedKeys(
const pb::sql::CommandGetExportedKeys& command, const ServerCallContext& context,
const FlightDescriptor& descriptor, std::unique_ptr<FlightInfo>* info) {
return GetFlightInfoForCommand(descriptor, info, command,
SqlSchema::GetImportedAndExportedKeysSchema());
}

Status SQLiteFlightSqlServer::DoGetExportedKeys(
const pb::sql::CommandGetExportedKeys& command, const ServerCallContext& context,
std::unique_ptr<FlightDataStream>* result) {
std::string filter = "pk_table_name = '" + command.table() + "'";
if (command.has_catalog()) {
filter += " AND pk_catalog_name = '" + command.catalog() + "'";
}
if (command.has_schema()) {
filter += " AND pk_schema_name = '" + command.schema() + "'";
}
std::string query = PrepareQueryForGetImportedOrExportedKeys(filter);

return DoGetSQLiteQuery(db_, query, SqlSchema::GetImportedAndExportedKeysSchema(),
result);
}

} // namespace example
Expand Down
14 changes: 14 additions & 0 deletions cpp/src/arrow/flight/flight-sql/example/sqlite_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,20 @@ class SQLiteFlightSqlServer : public FlightSqlServerBase {
std::unique_ptr<FlightInfo> *info) override;
Status DoGetTableTypes(const ServerCallContext &context,
std::unique_ptr<FlightDataStream> *result) override;
Status GetFlightInfoImportedKeys(const pb::sql::CommandGetImportedKeys &command,
const ServerCallContext &context,
const FlightDescriptor &descriptor,
std::unique_ptr<FlightInfo> *info) override;
Status DoGetImportedKeys(const pb::sql::CommandGetImportedKeys &command,
const ServerCallContext &context,
std::unique_ptr<FlightDataStream> *result) override;
Status GetFlightInfoExportedKeys(const pb::sql::CommandGetExportedKeys &command,
const ServerCallContext &context,
const FlightDescriptor &descriptor,
std::unique_ptr<FlightInfo> *info) override;
Status DoGetExportedKeys(const pb::sql::CommandGetExportedKeys &command,
const ServerCallContext &context,
std::unique_ptr<FlightDataStream> *result) override;

Status GetFlightInfoPrimaryKeys(const pb::sql::CommandGetPrimaryKeys &command,
const ServerCallContext &context,
Expand Down
18 changes: 18 additions & 0 deletions cpp/src/arrow/flight/flight-sql/sql_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,24 @@ return arrow::schema({field("catalog_name", utf8()), field("schema_name", utf8()
field("key_sequence", int64()), field("key_name", utf8())});
}

std::shared_ptr<Schema> SqlSchema::GetImportedAndExportedKeysSchema() {
return arrow::schema({
field("pk_catalog_name", utf8(), true),
field("pk_schema_name", utf8(), true),
field("pk_table_name", utf8(), false),
field("pk_column_name", utf8(), false),
field("fk_catalog_name", utf8(), true),
field("fk_schema_name", utf8(), true),
field("fk_table_name", utf8(), false),
field("fk_column_name", utf8(), false),
field("key_sequence", int32(), false),
field("fk_key_name", utf8(), true),
field("pk_key_name", utf8(), true),
field("update_rule", uint8(), false),
field("delete_rule", uint8(), false)
});
}

} // namespace sql
} // namespace flight
} // namespace arrow
5 changes: 5 additions & 0 deletions cpp/src/arrow/flight/flight-sql/sql_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,11 @@ class SqlSchema {
/// flags is set to true.
/// \return The default schema template.
static std::shared_ptr<Schema> GetPrimaryKeysSchema();

/// \brief Gets the Schema used on CommandGetImportedKeys and CommandGetExportedKeys
/// response.
/// \return The default schema template.
static std::shared_ptr<Schema> GetImportedAndExportedKeysSchema();
};
} // namespace sql
} // namespace flight
Expand Down
65 changes: 65 additions & 0 deletions cpp/src/arrow/flight/flight-sql/sql_server_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,71 @@ TEST(TestFlightSqlServer, TestCommandGetPrimaryKeys) {
ASSERT_TRUE(expected_table->Equals(*table));
}

TEST(TestFlightSqlServer, TestCommandGetImportedKeys) {
std::unique_ptr<FlightInfo> flight_info;
ASSERT_OK(sql_client->GetImportedKeys({}, NULLPTR, NULLPTR, "intTable", &flight_info));

std::unique_ptr<FlightStreamReader> stream;
ASSERT_OK(sql_client->DoGet({}, flight_info->endpoints()[0].ticket, &stream));

std::shared_ptr<Table> table;
ASSERT_OK(stream->ReadAll(&table));

DECLARE_NULL_ARRAY(pk_catalog_name, String, 1);
DECLARE_NULL_ARRAY(pk_schema_name, String, 1);
DECLARE_ARRAY(pk_table_name, String, ({"foreignTable"}));
DECLARE_ARRAY(pk_column_name, String, ({"id"}));
DECLARE_NULL_ARRAY(fk_catalog_name, String, 1);
DECLARE_NULL_ARRAY(fk_schema_name, String, 1);
DECLARE_ARRAY(fk_table_name, String, ({"intTable"}));
DECLARE_ARRAY(fk_column_name, String, ({"foreignId"}));
DECLARE_ARRAY(key_sequence, Int32, ({0}));
DECLARE_NULL_ARRAY(fk_key_name, String, 1);
DECLARE_NULL_ARRAY(pk_key_name, String, 1);
DECLARE_ARRAY(update_rule, UInt8, ({3}));
DECLARE_ARRAY(delete_rule, UInt8, ({3}));

const std::shared_ptr<Table>& expected_table =
Table::Make(SqlSchema::GetImportedAndExportedKeysSchema(),
{pk_catalog_name, pk_schema_name, pk_table_name, pk_column_name,
fk_catalog_name, fk_schema_name, fk_table_name, fk_column_name,
key_sequence, fk_key_name, pk_key_name, update_rule, delete_rule});
ASSERT_TRUE(expected_table->Equals(*table));
}

TEST(TestFlightSqlServer, TestCommandGetExportedKeys) {
std::unique_ptr<FlightInfo> flight_info;
ASSERT_OK(
sql_client->GetExportedKeys({}, NULLPTR, NULLPTR, "foreignTable", &flight_info));

std::unique_ptr<FlightStreamReader> stream;
ASSERT_OK(sql_client->DoGet({}, flight_info->endpoints()[0].ticket, &stream));

std::shared_ptr<Table> table;
ASSERT_OK(stream->ReadAll(&table));

DECLARE_NULL_ARRAY(pk_catalog_name, String, 1);
DECLARE_NULL_ARRAY(pk_schema_name, String, 1);
DECLARE_ARRAY(pk_table_name, String, ({"foreignTable"}));
DECLARE_ARRAY(pk_column_name, String, ({"id"}));
DECLARE_NULL_ARRAY(fk_catalog_name, String, 1);
DECLARE_NULL_ARRAY(fk_schema_name, String, 1);
DECLARE_ARRAY(fk_table_name, String, ({"intTable"}));
DECLARE_ARRAY(fk_column_name, String, ({"foreignId"}));
DECLARE_ARRAY(key_sequence, Int32, ({0}));
DECLARE_NULL_ARRAY(fk_key_name, String, 1);
DECLARE_NULL_ARRAY(pk_key_name, String, 1);
DECLARE_ARRAY(update_rule, UInt8, ({3}));
DECLARE_ARRAY(delete_rule, UInt8, ({3}));

const std::shared_ptr<Table>& expected_table =
Table::Make(SqlSchema::GetImportedAndExportedKeysSchema(),
{pk_catalog_name, pk_schema_name, pk_table_name, pk_column_name,
fk_catalog_name, fk_schema_name, fk_table_name, fk_column_name,
key_sequence, fk_key_name, pk_key_name, update_rule, delete_rule});
ASSERT_TRUE(expected_table->Equals(*table));
}

auto env =
::testing::AddGlobalTestEnvironment(new TestFlightSqlServer);

Expand Down

0 comments on commit dab5f35

Please sign in to comment.