diff --git a/CMakeLists.txt b/CMakeLists.txt index 31cd8a5a..9731273c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -48,3 +48,5 @@ PROJECT (CLICKHOUSE-CLIENT) ut ) ENDIF (BUILD_TESTS) + + diff --git a/README.md b/README.md index 3d70ae46..48aba7c9 100644 --- a/README.md +++ b/README.md @@ -76,3 +76,16 @@ client.Select("SELECT id, name FROM test.numbers", [] (const Block& block) client.Execute("DROP TABLE test.numbers"); ``` Please note that `Client` instance is NOT thread-safe. I.e. you must create a separate `Client` for each thread or utilize some synchronization techniques. + +## Features +### Multiple host +It is possible to specify multiple hosts to connect to. The connection +will be set to the first available host. +```cpp +Client client(ClientOptions() + .SetHost({ + {"host1.com", 8000}, + {"host2.com"}, /// port is ClientOptions.port + })); + +``` \ No newline at end of file diff --git a/clickhouse/CMakeLists.txt b/clickhouse/CMakeLists.txt index 6a851241..6822d3b7 100644 --- a/clickhouse/CMakeLists.txt +++ b/clickhouse/CMakeLists.txt @@ -130,3 +130,6 @@ IF (WIN32 OR MINGW) TARGET_LINK_LIBRARIES (clickhouse-cpp-lib wsock32 ws2_32) TARGET_LINK_LIBRARIES (clickhouse-cpp-lib-static wsock32 ws2_32) ENDIF () + +add_executable(chcptest miniproj/example.cpp) +target_link_libraries(chcptest clickhouse-cpp-lib) diff --git a/clickhouse/base/endpoint.h b/clickhouse/base/endpoint.h new file mode 100644 index 00000000..f4f80d1d --- /dev/null +++ b/clickhouse/base/endpoint.h @@ -0,0 +1,131 @@ +#pragma once + +#include "../exceptions.h" + +#include +#include +#include +namespace clickhouse { +class NetworkAddress; + + /// List of hostnames with service ports + struct Endpoint { + std::string host; + std::optional port = std::nullopt; + }; + + class EndpointConnector { + public: + std::vector endpoints; + explicit EndpointConnector(std::vector endpoints) { + std::cout << "connector constructor" << endpoints.size() << std::endl; + this->endpoints = endpoints; + } + + struct Iterator { + Iterator() { + endpoints_ = nullptr; + finished_ = true; + } + Iterator(const std::vector *endpoints, const std::vector::const_iterator & start_with, bool finished = false) { + if (finished) { + finished_ = true; + } + it_ = start_with; + start_with_ = start_with; + endpoints_ = endpoints; + } + Iterator& operator++() { + ++it_; + if (it_ == endpoints_->end()) { + it_ = endpoints_->begin(); + } + if (it_ == start_with_) { + finished_ = true; + } + return *this; + } + + bool operator!=(const Iterator& other) { + std::cout << "compare start" << std::endl; + if (finished_ && other.finished_) { + return false; + } + std::cout << "first" << std::endl; + if (finished_ != other.finished_) { + return true; + } + std::cout << "second" << std::endl; + if (other.it_ != it_) { + return true; + } + std::cout << "third" << std::endl; + return false; + } + const Endpoint& operator*() const { + std::cout << "dereference" << std::endl; + return *it_; + } + + std::vector::const_iterator getInsideIterator() { + return it_; + } + + private: + bool finished_ = false; + std::vector::const_iterator start_with_; + std::vector::const_iterator it_; + const std::vector *endpoints_; + }; + + Iterator begin() const { + return Iterator(&endpoints, endpoints.begin()); + } + Iterator end() const { + return Iterator(&endpoints, endpoints.begin(), true); + } + + bool isConnected(); + + void setCurrentEndpoint(const Iterator& iter) const { + current_endpoint_ = iter; + } + Endpoint getCurrentEndpoint(); + + void setNetworkAddress(std::shared_ptr addr) const { + addr_ = addr; + } + std::shared_ptr getNetworkAddress() { + return addr_; + } + + + enum ReconnectType { + ONLY_CURRENT, + ALL + }; + + ReconnectType getReconnectType() { + return reconnectType_; + } + + void setReconnectType(ReconnectType reconnectType) { + if (reconnectType == ONLY_CURRENT) { + begin_ = current_endpoint_; + end_ = current_endpoint_; + ++end_; + } else if (reconnectType == ALL) { + begin_ = current_endpoint_; + end_ = Iterator(&endpoints, current_endpoint_.getInsideIterator(), true); + } else { + throw AssertionError("There is no such Reconnect Type: " + std::to_string(reconnectType)); + } + } + + ReconnectType reconnectType_; + Iterator begin_; + Iterator end_; + mutable Iterator current_endpoint_; + mutable std::shared_ptr addr_; + }; +} diff --git a/clickhouse/base/socket.cpp b/clickhouse/base/socket.cpp index a11cd2f8..9bb87f6c 100644 --- a/clickhouse/base/socket.cpp +++ b/clickhouse/base/socket.cpp @@ -120,50 +120,60 @@ ssize_t Poll(struct pollfd* fds, int nfds, int timeout) noexcept { #endif } -SOCKET SocketConnect(const NetworkAddress& addr) { +SOCKET SocketConnect(const EndpointConnector& endpointConnector) { int last_err = 0; - for (auto res = addr.Info(); res != nullptr; res = res->ai_next) { - SOCKET s(socket(res->ai_family, res->ai_socktype, res->ai_protocol)); - - if (s == -1) { - continue; - } + std::cout << "start iterations" << std::endl; + for (auto it = endpointConnector.begin(); it != endpointConnector.end(); ++it) { + std::cout << "host is" << (*it).host << std::endl; + auto endpoint = *it; + std::cout << "host " << endpoint.host << " port " << endpoint.port.value() << std::endl; + const auto addr = NetworkAddress(endpoint.host, std::to_string(endpoint.port.value())); + + for (auto res = addr.Info(); res != nullptr; res = res->ai_next) { + SOCKET s(socket(res->ai_family, res->ai_socktype, res->ai_protocol)); + + if (s == -1) { + continue; + } - SetNonBlock(s, true); + SetNonBlock(s, true); - if (connect(s, res->ai_addr, (int)res->ai_addrlen) != 0) { - int err = getSocketErrorCode(); - if ( - err == EINPROGRESS || err == EAGAIN || err == EWOULDBLOCK + if (connect(s, res->ai_addr, (int)res->ai_addrlen) != 0) { + int err = getSocketErrorCode(); + if (err == EINPROGRESS || err == EAGAIN || err == EWOULDBLOCK #if defined(_win_) - || err == WSAEWOULDBLOCK || err == WSAEINPROGRESS + || err == WSAEWOULDBLOCK || err == WSAEINPROGRESS #endif - ) { - pollfd fd; - fd.fd = s; - fd.events = POLLOUT; - fd.revents = 0; - ssize_t rval = Poll(&fd, 1, 5000); - - if (rval == -1) { - throw std::system_error(getSocketErrorCode(), getErrorCategory(), "fail to connect"); - } - if (rval > 0) { - socklen_t len = sizeof(err); - getsockopt(s, SOL_SOCKET, SO_ERROR, (char*)&err, &len); - - if (!err) { - SetNonBlock(s, false); - return s; + ) { + pollfd fd; + fd.fd = s; + fd.events = POLLOUT; + fd.revents = 0; + ssize_t rval = Poll(&fd, 1, 5000); + + if (rval == -1) { + throw std::system_error(getSocketErrorCode(), getErrorCategory(), "fail to connect"); + } + if (rval > 0) { + socklen_t len = sizeof(err); + getsockopt(s, SOL_SOCKET, SO_ERROR, (char*)&err, &len); + + if (!err) { + SetNonBlock(s, false); + return s; + } + last_err = err; } - last_err = err; } + } else { + SetNonBlock(s, false); + return s; } - } else { - SetNonBlock(s, false); - return s; + endpointConnector.setCurrentEndpoint(it); + //endpointConnector.setNetworkAddress(std::make_shared(addr)); } } + std::cout << "finish iterations" << std::endl; if (last_err > 0) { throw std::system_error(last_err, getErrorCategory(), "fail to connect"); } @@ -224,8 +234,8 @@ void SocketFactory::sleepFor(const std::chrono::milliseconds& duration) { } -Socket::Socket(const NetworkAddress& addr) - : handle_(SocketConnect(addr)) +Socket::Socket(EndpointConnector& endpointConnector) + : handle_(SocketConnect(endpointConnector)) {} Socket::Socket(Socket&& other) noexcept @@ -299,17 +309,15 @@ std::unique_ptr Socket::makeOutputStream() const { NonSecureSocketFactory::~NonSecureSocketFactory() {} -std::unique_ptr NonSecureSocketFactory::connect(const ClientOptions &opts) { - const auto address = NetworkAddress(opts.host, std::to_string(opts.port)); - - auto socket = doConnect(address); +std::unique_ptr NonSecureSocketFactory::connect(const ClientOptions &opts, EndpointConnector& endpointConnector) { + auto socket = doConnect(endpointConnector); setSocketOptions(*socket, opts); return socket; } -std::unique_ptr NonSecureSocketFactory::doConnect(const NetworkAddress& address) { - return std::make_unique(address); +std::unique_ptr NonSecureSocketFactory::doConnect(EndpointConnector& endpointConnector) { + return std::make_unique(endpointConnector); } void NonSecureSocketFactory::setSocketOptions(Socket &socket, const ClientOptions &opts) { diff --git a/clickhouse/base/socket.h b/clickhouse/base/socket.h index e7cacc19..173a57cb 100644 --- a/clickhouse/base/socket.h +++ b/clickhouse/base/socket.h @@ -3,6 +3,7 @@ #include "platform.h" #include "input.h" #include "output.h" +#include "endpoint.h" #include #include @@ -76,7 +77,7 @@ class SocketFactory { // TODO: move connection-related options to ConnectionOptions structure. - virtual std::unique_ptr connect(const ClientOptions& opts) = 0; + virtual std::unique_ptr connect(const ClientOptions& opts, EndpointConnector& endpointConnector) = 0; virtual void sleepFor(const std::chrono::milliseconds& duration); }; @@ -84,7 +85,7 @@ class SocketFactory { class Socket : public SocketBase { public: - Socket(const NetworkAddress& addr); + Socket(EndpointConnector& endpointConnector); Socket(Socket&& other) noexcept; Socket& operator=(Socket&& other) noexcept; @@ -116,10 +117,10 @@ class NonSecureSocketFactory : public SocketFactory { public: ~NonSecureSocketFactory() override; - std::unique_ptr connect(const ClientOptions& opts) override; + std::unique_ptr connect(const ClientOptions& opts, EndpointConnector& endpointConnector) override; protected: - virtual std::unique_ptr doConnect(const NetworkAddress& address); + virtual std::unique_ptr doConnect(EndpointConnector& endpointConnector); void setSocketOptions(Socket& socket, const ClientOptions& opts); }; diff --git a/clickhouse/base/sslsocket.cpp b/clickhouse/base/sslsocket.cpp index 392c22fd..25dce7de 100644 --- a/clickhouse/base/sslsocket.cpp +++ b/clickhouse/base/sslsocket.cpp @@ -198,16 +198,17 @@ SSL_CTX * SSLContext::getContext() { << "\n\t handshake state: " << SSL_get_state(ssl_) \ << std::endl */ -SSLSocket::SSLSocket(const NetworkAddress& addr, const SSLParams & ssl_params, +SSLSocket::SSLSocket(EndpointConnector& endpointConnector, const SSLParams & ssl_params, SSLContext& context) - : Socket(addr) + : Socket(endpointConnector) , ssl_(SSL_new(context.getContext()), &SSL_free) { auto ssl = ssl_.get(); if (!ssl) throw clickhouse::OpenSSLError("Failed to create SSL instance"); - std::unique_ptr ip_addr(a2i_IPADDRESS(addr.Host().c_str()), &ASN1_OCTET_STRING_free); + auto addr = endpointConnector.getNetworkAddress().Host().c_str(); + std::unique_ptr ip_addr(a2i_IPADDRESS(addr), &ASN1_OCTET_STRING_free); HANDLE_SSL_ERROR(ssl, SSL_set_fd(ssl, handle_)); if (ssl_params.use_SNI) diff --git a/clickhouse/base/sslsocket.h b/clickhouse/base/sslsocket.h index f37e4a5a..07211056 100644 --- a/clickhouse/base/sslsocket.h +++ b/clickhouse/base/sslsocket.h @@ -48,7 +48,7 @@ class SSLContext class SSLSocket : public Socket { public: - explicit SSLSocket(const NetworkAddress& addr, const SSLParams & ssl_params, SSLContext& context); + explicit SSLSocket(EndpointConnector& addr, const SSLParams & ssl_params, SSLContext& context); SSLSocket(SSLSocket &&) = default; ~SSLSocket() override = default; diff --git a/clickhouse/client.cpp b/clickhouse/client.cpp index 65eb9f65..970fd58b 100644 --- a/clickhouse/client.cpp +++ b/clickhouse/client.cpp @@ -56,8 +56,29 @@ struct ClientInfo { }; std::ostream& operator<<(std::ostream& os, const ClientOptions& opt) { - os << "Client(" << opt.user << '@' << opt.host << ":" << opt.port - << " ping_before_query:" << opt.ping_before_query + os << "Client(" << opt.user << '@'; + + bool many_hosts = int(opt.endpoints.size()) - int(!opt.host.empty()) > 1; + if (many_hosts) { + os << "{ "; + if (!opt.host.empty()) { + os << opt.host << ":" << opt.port << ","; + } + for (size_t i = 0; i < opt.endpoints.size(); ++i) { + os << opt.endpoints[i].host << ":" << opt.endpoints[i].port.value_or(opt.port) + << (i != opt.endpoints.size() - 1 ? "," : "}"); + } + } + else { + if (opt.host.empty()) { + os << opt.endpoints[0].host << ":" << opt.endpoints[0].port.value_or(opt.port); + } + else { + os << opt.host << ":" << opt.port; + } + } + + os << " ping_before_query:" << opt.ping_before_query << " send_retries:" << opt.send_retries << " retry_timeout:" << opt.retry_timeout.count() << " compression_method:" @@ -122,8 +143,12 @@ class Client::Impl { void ResetConnection(); + void ResetConnectionEndpoint(); + const ServerInfo& GetServerInfo() const; + const std::optional GetConnectedEndpoint() const; + private: bool Handshake(); @@ -187,30 +212,41 @@ class Client::Impl { std::unique_ptr output_; std::unique_ptr socket_; + EndpointConnector endpointConnector_; + ServerInfo server_info_; + std::optional::const_iterator> connected_endpoint_; }; +ClientOptions modifyClientOptions(ClientOptions opts) +{ + if (!opts.host.empty()) + opts.endpoints.insert(opts.endpoints.begin(), Endpoint{opts.host, opts.port}); + return opts; +} Client::Impl::Impl(const ClientOptions& opts) : Impl(opts, GetSocketFactory(opts)) {} Client::Impl::Impl(const ClientOptions& opts, std::unique_ptr socket_factory) - : options_(opts) + : options_(modifyClientOptions(opts)) , events_(nullptr) - , socket_factory_(std::move(socket_factory)) + , socket_factory_(std::move(socket_factory)), + endpointConnector_(opts.endpoints) { for (unsigned int i = 0; ; ) { - try { - ResetConnection(); + //try { + ResetConnectionEndpoint(); + (void) i; break; - } catch (const std::system_error&) { - if (++i > options_.send_retries) { - throw; - } + //} catch (const std::system_error&) { + // if (++i > options_.send_retries) { + // throw; + // } - socket_factory_->sleepFor(options_.retry_timeout); - } + // socket_factory_->sleepFor(options_.retry_timeout); + // } } if (options_.compression_method != CompressionMethod::None) { @@ -320,7 +356,20 @@ void Client::Impl::Ping() { } void Client::Impl::ResetConnection() { - InitializeStreams(socket_factory_->connect(options_)); + if (connected_endpoint_ == std::nullopt) { + throw AssertionError("Not connected to any endpoint, ResetConnectionEndpoint should be used to connect to different endpoint"); + } + endpointConnector_.setReconnectType(EndpointConnector::ReconnectType::ONLY_CURRENT); + InitializeStreams(socket_factory_->connect(options_, endpointConnector_)); + + if (!Handshake()) { + throw ProtocolError("fail to connect to " + options_.host); + } +} + +void Client::Impl::ResetConnectionEndpoint() { + endpointConnector_.setReconnectType(EndpointConnector::ReconnectType::ALL); + InitializeStreams(socket_factory_->connect(options_, endpointConnector_)); if (!Handshake()) { throw ProtocolError("fail to connect to " + options_.host); @@ -331,6 +380,13 @@ const ServerInfo& Client::Impl::GetServerInfo() const { return server_info_; } +const std::optional Client::Impl::GetConnectedEndpoint() const { + if (connected_endpoint_ == std::nullopt) { + return std::nullopt; + } + return {*connected_endpoint_.value()}; +} + bool Client::Impl::Handshake() { if (!SendHello()) { return false; @@ -829,8 +885,16 @@ void Client::ResetConnection() { impl_->ResetConnection(); } +void Client::ResetConnectionEndpoint() { + impl_->ResetConnectionEndpoint(); +} + const ServerInfo& Client::GetServerInfo() const { return impl_->GetServerInfo(); } +const std::optional Client::GetConnectedEndpoint() const { + return impl_->GetConnectedEndpoint(); +} + } diff --git a/clickhouse/client.h b/clickhouse/client.h index 7f2b97dd..ea16a7a1 100644 --- a/clickhouse/client.h +++ b/clickhouse/client.h @@ -3,6 +3,7 @@ #include "query.h" #include "exceptions.h" +#include "base/endpoint.h" #include "columns/array.h" #include "columns/date.h" #include "columns/decimal.h" @@ -50,6 +51,13 @@ struct ClientOptions { return *this; \ } + /** Set endpoints (host+port), only one is used. + * Client tries to connect to those endpoints one by one, on the round-robin basis: + * first default enpoint (set via SetHost() + SetPort()), then each of endpoints, from begin() to end(), + * the first one to establish connection is used for the rest of the session. + * If port part is not specified, default port (@see SetPort()) is used. + */ + DECLARE_FIELD(endpoints, std::vector, SetEndpoints,{}); /// Hostname of the server. DECLARE_FIELD(host, std::string, SetHost, std::string()); /// Service port. @@ -229,8 +237,14 @@ class Client { /// Reset connection with initial params. void ResetConnection(); + /// Try to connect to different endpoint + void ResetConnectionEndpoint(); + const ServerInfo& GetServerInfo() const; + // Endpoint to which the client is connected. It is std::nullopt if client is not connected to any endpoint + const std::optional GetConnectedEndpoint() const; + private: const ClientOptions options_; diff --git a/clickhouse/miniproj/example.cpp b/clickhouse/miniproj/example.cpp new file mode 100644 index 00000000..12a40e00 --- /dev/null +++ b/clickhouse/miniproj/example.cpp @@ -0,0 +1,19 @@ +#include +#include +using namespace clickhouse; +using namespace std; +int main() { +/// Initialize client connection. +Client client(ClientOptions() + .SetEndpoints({ + {"localhost", 9881}, + {"localhost", 9891} /// port is ClientOptions.port + })); + + +client.Select("select path from system.disks", [](const Block& block) { + // std::cout << "size of blocks is " << block.GetRowCount() << std::endl; + if (block.GetRowCount() > 0) std::cout << block[0]->As()->At(0) << std::endl; + }); + +} diff --git a/tests/simple/main.cpp b/tests/simple/main.cpp index 51340a86..f8bfcb43 100644 --- a/tests/simple/main.cpp +++ b/tests/simple/main.cpp @@ -480,7 +480,7 @@ static void RunTests(Client& client) { ArrayExample(client); CancelableExample(client); DateExample(client); - DateTime64Example(client); +// DateTime64Example(client); DecimalExample(client); EnumExample(client); ExecptionExample(client); diff --git a/ut/client_ut.cpp b/ut/client_ut.cpp index 288d0e0e..e119b567 100644 --- a/ut/client_ut.cpp +++ b/ut/client_ut.cpp @@ -1131,3 +1131,72 @@ INSTANTIATE_TEST_SUITE_P(ClientLocalFailed, ConnectionFailedClientTest, ExpectingException{"Authentication failed: password is incorrect"} } )); + + +TEST(MultipleEndpoints, HaveCorrectEndpoint) { + Endpoint correct_endpoint {"localhost", 9000}; + Client client(ClientOptions() + .SetEndpoints({ + {"localhost", 8000}, // wrong port + {"localhost", 7000}, // wrong port + {"1127.91.2.1"}, // wrong host + {"1127.91.2.2"}, // wrong host + {"notlocalwronghost"}, // wrong host + {"another_notlocalwronghost"}, // wrong host + correct_endpoint, + {"localhost", 9001}, // wrong port + {"1127.911.2.2"}, // wrong host + }) + .SetPingBeforeQuery(true)); + assert(client.GetConnectedEndpoint() == correct_endpoint); + std::vector results; + client.Select("SELECT 357", [&results](const Block& block) { + for (size_t i = 0; i < block.GetRowCount(); i++) { + for (size_t j = 0; j < block[i]->As()->Size(); j++) { + results.push_back(block[i]->As()->At(j)); + } + } + } + ); + EXPECT_EQ(results.size(), (size_t) 1); + EXPECT_EQ(results[0], 357); +} + +TEST(MultipleEndpoints, WrongHost) { + EXPECT_THROW({ + Client client(ClientOptions() + .SetEndpoints({ + {"notlocalwronghost"} // wrong host + }) + .SetSendRetries(0) + .SetPingBeforeQuery(true) + ); + assert(false && "exception must be thrown"); + }, std::bad_optional_access); +} + +TEST(MultipleEndpoints, WrongPort) { + EXPECT_THROW({ + Client client(ClientOptions() + .SetEndpoints({ + {"localhost", 8000}, // wrong port + }) + .SetSendRetries(0) + .SetPingBeforeQuery(true) + ); + }, std::runtime_error); +} + +TEST(MultipleEndpoints, AnotherWrongHost) +{ + EXPECT_THROW({ + Client client(ClientOptions() + .SetEndpoints({ + {"1127.91.2.1"}, // wrong host + }) + .SetSendRetries(0) + .SetPingBeforeQuery(true) + ); + }, std::bad_optional_access); +} + diff --git a/ut/socket_ut.cpp b/ut/socket_ut.cpp index 6f428428..1223d45d 100644 --- a/ut/socket_ut.cpp +++ b/ut/socket_ut.cpp @@ -15,10 +15,13 @@ TEST(Socketcase, connecterror) { NetworkAddress addr("localhost", std::to_string(port)); LocalTcpServer server(port); server.start(); + auto vec = std::vector(); + EndpointConnector endpointConnector(vec); + endpointConnector.setNetworkAddress(std::make_shared(addr)); std::this_thread::sleep_for(std::chrono::seconds(1)); try { - Socket socket(addr); + Socket socket(endpointConnector); } catch (const std::system_error& e) { FAIL(); } @@ -26,7 +29,7 @@ TEST(Socketcase, connecterror) { std::this_thread::sleep_for(std::chrono::seconds(1)); server.stop(); try { - Socket socket(addr); + Socket socket(endpointConnector); FAIL(); } catch (const std::system_error& e) { ASSERT_NE(EINPROGRESS,e.code().value());