Skip to content

Commit

Permalink
Merge pull request #9 from jcralmeida/restructure-public-headers
Browse files Browse the repository at this point in the history
Restructure public headers
  • Loading branch information
Rafael Telles authored Jan 17, 2022
2 parents 7310967 + b04518a commit 98df4a4
Show file tree
Hide file tree
Showing 17 changed files with 150 additions and 119 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
FetchContent_MakeAvailable(googletest)

add_subdirectory(flight_sql)
add_subdirectory(spi)
add_subdirectory(odbcabstraction)
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
cmake_minimum_required(VERSION 3.11)
set(CMAKE_CXX_STANDARD 11)

include_directories(${CMAKE_SOURCE_DIR}/spi)
include_directories(
include
${CMAKE_SOURCE_DIR}/odbcabstraction/include)

SET(Arrow_STATIC ON)

Expand Down Expand Up @@ -107,7 +109,10 @@ set_target_properties(arrow_odbc_spi_impl
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$<CONFIG>/lib
)

target_link_libraries(arrow_odbc_spi_impl spi ${ARROW_ODBC_SPI_THIRDPARTY_LIBS})
target_link_libraries(
arrow_odbc_spi_impl
odbcabstraction
${ARROW_ODBC_SPI_THIRDPARTY_LIBS})
target_include_directories(arrow_odbc_spi_impl PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})

# CLI
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
// under the License.

#include "flight_sql_auth_method.h"
#include "exceptions.h"

#include <odbcabstraction/exceptions.h>
#include "flight_sql_connection.h"

#include <arrow/flight/client.h>
Expand All @@ -33,8 +34,8 @@ using arrow::Result;
using arrow::flight::FlightCallOptions;
using arrow::flight::FlightClient;
using arrow::flight::TimeoutDuration;
using driver::spi::AuthenticationException;
using driver::spi::Connection;
using driver::odbcabstraction::AuthenticationException;
using driver::odbcabstraction::Connection;

namespace {
class NoOpAuthMethod : public FlightSqlAuthMethod {
Expand Down Expand Up @@ -86,11 +87,11 @@ class UserPasswordAuthMethod : public FlightSqlAuthMethod {

std::unique_ptr<FlightSqlAuthMethod> FlightSqlAuthMethod::FromProperties(
const std::unique_ptr<FlightClient> &client,
const std::map<std::string, Connection::Property> &properties) {
const Connection::ConnPropertyMap &properties) {

// Check if should use user-password authentication
const auto &it_user = properties.find(Connection::USER);
const auto &it_password = properties.find(Connection::PASSWORD);
const auto &it_user = properties.find(FlightSqlConnection::USER);
const auto &it_password = properties.find(FlightSqlConnection::PASSWORD);
if (it_user != properties.end() || it_password != properties.end()) {
const std::string &user = it_user != properties.end()
? boost::get<std::string>(it_user->second)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

#pragma once

#include "connection.h"
#include <odbcabstraction/connection.h>
#include "flight_sql_connection.h"
#include <arrow/flight/client.h>
#include <map>
Expand All @@ -35,7 +35,7 @@ class FlightSqlAuthMethod {

static std::unique_ptr<FlightSqlAuthMethod> FromProperties(
const std::unique_ptr<arrow::flight::FlightClient> &client,
const std::map<std::string, spi::Connection::Property> &properties);
const odbcabstraction::Connection::ConnPropertyMap &properties);

protected:
FlightSqlAuthMethod() = default;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
// under the License.

#include "flight_sql_connection.h"
#include "exceptions.h"

#include <odbcabstraction/exceptions.h>
#include "flight_sql_auth_method.h"
#include <boost/lexical_cast.hpp>
#include <boost/optional.hpp>
#include <boost/algorithm/string/join.hpp>
#include <iostream>

namespace driver {
Expand All @@ -32,27 +35,44 @@ using arrow::flight::FlightClientOptions;
using arrow::flight::Location;
using arrow::flight::TimeoutDuration;
using arrow::flight::sql::FlightSqlClient;
using spi::Connection;
using spi::DriverException;
using spi::OdbcVersion;
using spi::Statement;
using driver::odbcabstraction::Connection;
using driver::odbcabstraction::DriverException;
using driver::odbcabstraction::OdbcVersion;
using driver::odbcabstraction::Statement;

const std::string FlightSqlConnection::HOST = "host";
const std::string FlightSqlConnection::PORT = "port";
const std::string FlightSqlConnection::USER = "user";
const std::string FlightSqlConnection::PASSWORD = "password";
const std::string FlightSqlConnection::USE_TLS = "useTls";

namespace {
// TODO: Add properties for getting the certificates
// TODO: Check if gRPC can use the system truststore, if not copy from Drill

inline void ThrowIfNotOK(const Status &status) {
if (!status.ok()) {
throw DriverException(status.ToString());
}
}

Connection::ConnPropertyMap::const_iterator TrackMissingRequiredProperty(const std::string& property,
const Connection::ConnPropertyMap &properties, std::vector<std::string> &missing_attr) {
Connection::ConnPropertyMap::const_iterator prop_iter = properties.find(property);
if (properties.end() == prop_iter) {
missing_attr.push_back(property);
}
return prop_iter;
}

} // namespace

void FlightSqlConnection::Connect(
const std::map<std::string, Property> &properties,
const ConnPropertyMap &properties,
std::vector<std::string> &missing_attr) {
try {
Location location = BuildLocation(properties);
FlightClientOptions client_options = BuildFlightClientOptions(properties);
Location location = BuildLocation(properties, missing_attr);
FlightClientOptions client_options = BuildFlightClientOptions(properties, missing_attr);

std::unique_ptr<FlightClient> flight_client;
ThrowIfNotOK(
Expand Down Expand Up @@ -88,20 +108,32 @@ FlightCallOptions FlightSqlConnection::BuildCallOptions() {
}

FlightClientOptions FlightSqlConnection::BuildFlightClientOptions(
const std::map<std::string, Property> &properties) {
const ConnPropertyMap &properties, std::vector<std::string> &missing_attr) {
FlightClientOptions options;
// TODO: Set up TLS properties
return options;
}

Location FlightSqlConnection::BuildLocation(
const std::map<std::string, Property> &properties) {
const std::string &host = boost::get<std::string>(properties.at(HOST));
const int &port = boost::get<int>(properties.at(PORT));
const ConnPropertyMap &properties, std::vector<std::string> &missing_attr) {
const auto& host_iter = TrackMissingRequiredProperty(
HOST, properties, missing_attr);

const auto& port_iter = TrackMissingRequiredProperty(
PORT, properties, missing_attr);

if (!missing_attr.empty()) {
std::string missing_attr_str = std::string("Missing required properties: ")
+ boost::algorithm::join(missing_attr, ", ");
throw DriverException(missing_attr_str);
}

const std::string &host = host_iter->second;
const int &port = boost::lexical_cast<int>(port_iter->second);

Location location;
const auto &it_use_tls = properties.find(USE_TLS);
if (it_use_tls != properties.end() && boost::get<bool>(it_use_tls->second)) {
if (it_use_tls != properties.end() && boost::lexical_cast<bool>(it_use_tls->second)) {
ThrowIfNotOK(Location::ForGrpcTls(host, port, &location));
} else {
ThrowIfNotOK(Location::ForGrpcTcp(host, port, &location));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,38 @@

#pragma once

#include "connection.h"
#include <odbcabstraction/connection.h>

#include <arrow/flight/api.h>
#include <arrow/flight/sql/api.h>

namespace driver {
namespace flight_sql {

class FlightSqlConnection : public spi::Connection {
class FlightSqlConnection : public odbcabstraction::Connection {

private:
std::map<AttributeId, Attribute> attribute_;
arrow::flight::FlightCallOptions call_options_;
std::unique_ptr<arrow::flight::sql::FlightSqlClient> sql_client_;
spi::OdbcVersion odbc_version_;
odbcabstraction::OdbcVersion odbc_version_;
bool closed_;

public:
explicit FlightSqlConnection(spi::OdbcVersion odbc_version);
static const std::string HOST;
static const std::string PORT;
static const std::string USER;
static const std::string PASSWORD;
static const std::string USE_TLS;

explicit FlightSqlConnection(odbcabstraction::OdbcVersion odbc_version);

void Connect(const std::map<std::string, Property> &properties,
void Connect(const ConnPropertyMap &properties,
std::vector<std::string> &missing_attr) override;

void Close() override;

std::shared_ptr<spi::Statement> CreateStatement() override;
std::shared_ptr<odbcabstraction::Statement> CreateStatement() override;

void SetAttribute(AttributeId attribute, const Attribute &value) override;

Expand All @@ -52,12 +60,12 @@ class FlightSqlConnection : public spi::Connection {
/// \brief Builds a Location used for FlightClient connection.
/// \note Visible for testing
static arrow::flight::Location
BuildLocation(const std::map<std::string, Property> &properties);
BuildLocation(const ConnPropertyMap &properties, std::vector<std::string> &missing_attr);

/// \brief Builds a FlightClientOptions used for FlightClient connection.
/// \note Visible for testing
static arrow::flight::FlightClientOptions
BuildFlightClientOptions(const std::map<std::string, Property> &properties);
BuildFlightClientOptions(const ConnPropertyMap &properties, std::vector<std::string> &missing_attr);

/// \brief Builds a FlightCallOptions used on gRPC calls.
/// \note Visible for testing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ namespace flight_sql {

using arrow::flight::Location;
using arrow::flight::TimeoutDuration;
using spi::Connection;
using odbcabstraction::Connection;

TEST(AttributeTests, SetAndGetAttribute) {
FlightSqlConnection connection(spi::V_3);
FlightSqlConnection connection(odbcabstraction::V_3);

connection.SetAttribute(Connection::CONNECTION_TIMEOUT, 200);
const boost::optional<Connection::Attribute> firstValue =
Expand All @@ -49,7 +49,7 @@ TEST(AttributeTests, SetAndGetAttribute) {
}

TEST(AttributeTests, GetAttributeWithoutSetting) {
FlightSqlConnection connection(spi::V_3);
FlightSqlConnection connection(odbcabstraction::V_3);

const boost::optional<Connection::Attribute> anOptional =
connection.GetAttribute(Connection::CONNECTION_TIMEOUT);
Expand All @@ -60,14 +60,15 @@ TEST(AttributeTests, GetAttributeWithoutSetting) {
}

TEST(BuildLocationTests, ForTcp) {
std::vector<std::string> missing_attr;
const Location &actual_location1 = FlightSqlConnection::BuildLocation({
{Connection::HOST, std::string("localhost")},
{Connection::PORT, 32010},
});
{FlightSqlConnection::HOST, std::string("localhost")},
{FlightSqlConnection::PORT, std::string("32010")},
}, missing_attr);
const Location &actual_location2 = FlightSqlConnection::BuildLocation({
{Connection::HOST, std::string("localhost")},
{Connection::PORT, 32011},
});
{FlightSqlConnection::HOST, std::string("localhost")},
{FlightSqlConnection::PORT, std::string("32011")},
}, missing_attr);

Location expected_location;
ASSERT_TRUE(
Expand All @@ -77,16 +78,17 @@ TEST(BuildLocationTests, ForTcp) {
}

TEST(BuildLocationTests, ForTls) {
std::vector<std::string> missing_attr;
const Location &actual_location1 = FlightSqlConnection::BuildLocation({
{Connection::HOST, std::string("localhost")},
{Connection::PORT, 32010},
{Connection::USE_TLS, true},
});
{FlightSqlConnection::HOST, std::string("localhost")},
{FlightSqlConnection::PORT, std::string("32010")},
{FlightSqlConnection::USE_TLS, std::string("1")},
}, missing_attr);
const Location &actual_location2 = FlightSqlConnection::BuildLocation({
{Connection::HOST, std::string("localhost")},
{Connection::PORT, 32011},
{Connection::USE_TLS, true},
});
{FlightSqlConnection::HOST, std::string("localhost")},
{FlightSqlConnection::PORT, std::string("32011")},
{FlightSqlConnection::USE_TLS, std::string("1")},
}, missing_attr);

Location expected_location;
ASSERT_TRUE(
Expand All @@ -96,7 +98,7 @@ TEST(BuildLocationTests, ForTls) {
}

TEST(BuildCallOptionsTest, ConnectionTimeout) {
FlightSqlConnection connection(spi::V_3);
FlightSqlConnection connection(odbcabstraction::V_3);

// Expect default timeout to be -1
ASSERT_EQ(TimeoutDuration{-1.0}, connection.BuildCallOptions().timeout);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
// specific language governing permissions and limitations
// under the License.

#include "flight_sql_driver.h"
#include <flight_sql/flight_sql_driver.h>
#include "flight_sql_connection.h"

namespace driver {
namespace flight_sql {

using spi::Connection;
using spi::OdbcVersion;
using odbcabstraction::Connection;
using odbcabstraction::OdbcVersion;

std::shared_ptr<Connection>
FlightSqlDriver::CreateConnection(OdbcVersion odbc_version) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@

#pragma once

#include "driver.h"
#include <odbcabstraction/driver.h>

namespace driver {
namespace flight_sql {

class FlightSqlDriver : public spi::Driver {
class FlightSqlDriver : public odbcabstraction::Driver {
public:
std::shared_ptr<spi::Connection>
CreateConnection(spi::OdbcVersion odbc_version) override;
std::shared_ptr<odbcabstraction::Connection>
CreateConnection(odbcabstraction::OdbcVersion odbc_version) override;
};

}; // namespace flight_sql
Expand Down
Loading

0 comments on commit 98df4a4

Please sign in to comment.