diff --git a/elasticsearch/Cargo.toml b/elasticsearch/Cargo.toml index 445d8af9..4182c6cc 100644 --- a/elasticsearch/Cargo.toml +++ b/elasticsearch/Cargo.toml @@ -27,10 +27,13 @@ bytes = "^0.5" dyn-clone = "~1" percent-encoding = "2.1.0" reqwest = { version = "~0.10", default-features = false, features = ["gzip", "json"] } +lazy_static = "^1.4" url = "^2.1" +regex = "1.3" serde = { version = "~1", features = ["derive"] } serde_json = "~1" serde_with = "~1" +tokio = { version = "0.2.0", default-features = false, features = ["macros", "tcp", "time"] } void = "1.0.2" [dev-dependencies] diff --git a/elasticsearch/src/http/transport.rs b/elasticsearch/src/http/transport.rs index 14a2cdea..d935b6a2 100644 --- a/elasticsearch/src/http/transport.rs +++ b/elasticsearch/src/http/transport.rs @@ -37,12 +37,19 @@ use crate::{ }; use base64::write::EncoderWriter as Base64Encoder; use bytes::BytesMut; +use regex::Regex; use serde::Serialize; +use serde_json::Value; use std::{ error, fmt, fmt::Debug, io::{self, Write}, - time::Duration, + sync::{ + atomic::{AtomicBool, AtomicUsize, Ordering}, + Arc, RwLock, + }, + thread::spawn, + time::{Duration, Instant}, }; use url::Url; @@ -96,6 +103,10 @@ impl fmt::Display for BuildError { /// Default address to Elasticsearch running on `http://localhost:9200` pub static DEFAULT_ADDRESS: &str = "http://localhost:9200"; +lazy_static! { + static ref ADDRESS_REGEX: Regex = + Regex::new(r"((?P[^/]+)/)?(?P[^:]+|\[[\da-fA-F:\.]+\]):(?P\d+)$").unwrap(); +} /// Builds a HTTP transport to make API calls to Elasticsearch pub struct TransportBuilder { @@ -268,7 +279,7 @@ impl TransportBuilder { let client = client_builder.build()?; Ok(Transport { client, - conn_pool: self.conn_pool, + conn_pool: Arc::new(self.conn_pool), credentials: self.credentials, }) } @@ -284,7 +295,7 @@ impl Default for TransportBuilder { /// A connection to an Elasticsearch node, used to send an API request #[derive(Debug, Clone)] pub struct Connection { - url: Url, + url: Arc, } impl Connection { @@ -298,8 +309,14 @@ impl Connection { url.set_path(&format!("{}/", url.path())); } + let url = Arc::new(url); + Self { url } } + + pub fn url(&self) -> Arc { + self.url.clone() + } } /// A HTTP transport responsible for making the API requests to Elasticsearch, @@ -308,7 +325,7 @@ impl Connection { pub struct Transport { client: reqwest::Client, credentials: Option, - conn_pool: Box, + conn_pool: Arc>, } impl Transport { @@ -336,6 +353,35 @@ impl Transport { Ok(transport) } + /// Creates a new instance of a [Transport] configured with a + /// [MultiNodeConnectionPool] that does not refresh + pub fn static_node_list(urls: Vec<&str>) -> Result { + let urls: Vec = urls + .iter() + .map(|url| Url::parse(url)) + .collect::, _>>()?; + let conn_pool = MultiNodeConnectionPool::round_robin(urls, None); + let transport = TransportBuilder::new(conn_pool).build()?; + Ok(transport) + } + + /// Creates a new instance of a [Transport] configured with a + /// [MultiNodeConnectionPool] + /// + /// * `reseed_frequency` - frequency at which connections should be refreshed in seconds + pub fn sniffing_node_list( + urls: Vec<&str>, + reseed_frequency: Duration, + ) -> Result { + let urls: Vec = urls + .iter() + .map(|url| Url::parse(url)) + .collect::, _>>()?; + let conn_pool = MultiNodeConnectionPool::round_robin(urls, Some(reseed_frequency)); + let transport = TransportBuilder::new(conn_pool).build()?; + Ok(transport) + } + /// Creates a new instance of a [Transport] configured for use with /// [Elasticsearch service in Elastic Cloud](https://www.elastic.co/cloud/). /// @@ -348,23 +394,22 @@ impl Transport { Ok(transport) } - /// Creates an asynchronous request that can be awaited - pub async fn send( + fn request_builder( &self, + connection: &Connection, method: Method, path: &str, headers: HeaderMap, query_string: Option<&Q>, body: Option, timeout: Option, - ) -> Result + ) -> Result where B: Body, Q: Serialize + ?Sized, { - let connection = self.conn_pool.next(); - let url = connection.url.join(path.trim_start_matches('/'))?; let reqwest_method = self.method(method); + let url = connection.url.join(path.trim_start_matches('/'))?; let mut request_builder = self.client.request(reqwest_method, url); if let Some(t) = timeout { @@ -421,6 +466,102 @@ impl Transport { if let Some(q) = query_string { request_builder = request_builder.query(q); } + Ok(request_builder) + } + + fn parse_to_url(address: &str, scheme: &str) -> Result { + if address.is_empty() { + return Err(crate::error::lib("Bound Address is empty")); + } + + let matches = ADDRESS_REGEX + .captures(address) + .ok_or_else(|| crate::lib(format!("error parsing address into url: {}", address)))?; + + let host = matches + .name("fqdn") + .or_else(|| Some(matches.name("ip").unwrap())) + .unwrap() + .as_str() + .trim(); + let port = matches.name("port").unwrap().as_str().trim(); + + Ok(Url::parse( + format!("{}://{}:{}", scheme, host, port).as_str(), + )?) + } + + /// Creates an asynchronous request that can be awaited + pub async fn send( + &self, + method: Method, + path: &str, + headers: HeaderMap, + query_string: Option<&Q>, + body: Option, + timeout: Option, + ) -> Result + where + B: Body, + Q: Serialize + ?Sized, + { + // Threads will execute against old connection pool during reseed + if self.conn_pool.reseedable() { + let local_conn_pool = self.conn_pool.clone(); + let connection = local_conn_pool.next(); + + // Build node info request + let node_request = self.request_builder( + &connection, + Method::Get, + "_nodes/http?filter_path=nodes.*.http", + headers.clone(), + None::<&Q>, + None::, + timeout, + )?; + + spawn(move || { + // TODO: Log reseed failures + let mut rt = tokio::runtime::Runtime::new().expect("Cannot create tokio runtime"); + rt.block_on(async { + let scheme = connection.url.scheme(); + let resp = node_request.send().await.unwrap(); + let json: Value = resp.json().await.unwrap(); + let connections: Vec = json["nodes"] + .as_object() + .unwrap() + .iter() + .map(|h| { + let address = h.1["http"]["publish_address"] + .as_str() + .or_else(|| { + Some( + h.1["http"]["bound_address"].as_array().unwrap()[0] + .as_str() + .unwrap(), + ) + }) + .unwrap(); + let url = Self::parse_to_url(address, scheme).unwrap(); + Connection::new(url) + }) + .collect(); + local_conn_pool.reseed(connections); + }) + }); + } + + let connection = self.conn_pool.next(); + let request_builder = self.request_builder( + &connection, + method, + path, + headers, + query_string, + body, + timeout, + )?; let response = request_builder.send().await; match response { @@ -444,7 +585,14 @@ impl Default for Transport { /// dynamically at runtime, based upon the response to API calls. pub trait ConnectionPool: Debug + dyn_clone::DynClone + Sync + Send { /// Gets a reference to the next [Connection] - fn next(&self) -> &Connection; + fn next(&self) -> Connection; + + fn reseedable(&self) -> bool { + false + } + + // NOOP by default + fn reseed(&self, _connection: Vec) {} } clone_trait_object!(ConnectionPool); @@ -473,8 +621,8 @@ impl Default for SingleNodeConnectionPool { impl ConnectionPool for SingleNodeConnectionPool { /// Gets a reference to the next [Connection] - fn next(&self) -> &Connection { - &self.connection + fn next(&self) -> Connection { + self.connection.clone() } } @@ -594,8 +742,119 @@ impl CloudConnectionPool { impl ConnectionPool for CloudConnectionPool { /// Gets a reference to the next [Connection] - fn next(&self) -> &Connection { - &self.connection + fn next(&self) -> Connection { + self.connection.clone() + } +} + +/// A Connection Pool that manages a static connection of nodes +#[derive(Debug, Clone)] +pub struct MultiNodeConnectionPool { + inner: Arc>, + reseed_frequency: Option, + connection_selector: ConnSelector, + reseeding: Arc, +} + +#[derive(Debug, Clone)] +pub struct MultiNodeConnectionPoolInner { + last_update: Option, + connections: Vec, +} + +impl ConnectionPool for MultiNodeConnectionPool +where + ConnSelector: ConnectionSelector + Clone, +{ + fn next(&self) -> Connection { + let inner = self.inner.read().expect("lock poisoned"); + self.connection_selector + .try_next(&inner.connections) + .unwrap() + } + + fn reseedable(&self) -> bool { + let inner = self.inner.read().expect("lock poisoned"); + let reseed_frequency = match self.reseed_frequency { + Some(wait) => wait, + None => return false, + }; + let last_update_is_stale = inner + .last_update + .as_ref() + .map(|last_update| last_update.elapsed() > reseed_frequency); + let reseedable = last_update_is_stale.unwrap_or(true); + + return if !reseedable { + false + } else { + // Check if refreshing is false if so, sets to true atomically and returns old value (false) meaning refreshable is true + // If refreshing is set to true, do nothing and return true, meaning refreshable is false + !self + .reseeding + .compare_and_swap(false, true, Ordering::Relaxed) + }; + } + + fn reseed(&self, mut connection: Vec) { + let mut inner = self.inner.write().expect("lock poisoned"); + inner.last_update = Some(Instant::now()); + inner.connections.clear(); + inner.connections.append(&mut connection); + self.reseeding.store(false, Ordering::Relaxed); + } +} + +impl MultiNodeConnectionPool { + /** Use a round-robin strategy for balancing traffic over the given set of nodes. */ + pub fn round_robin(urls: Vec, reseed_frequency: Option) -> Self { + let connections = urls.into_iter().map(Connection::new).collect(); + + let inner: Arc> = + Arc::new(RwLock::new(MultiNodeConnectionPoolInner { + last_update: None, + connections, + })); + let reseeding = Arc::new(AtomicBool::new(false)); + + let connection_selector = RoundRobin::default(); + Self { + inner, + connection_selector, + reseed_frequency, + reseeding, + } + } +} + +/** The strategy selects an address from a given collection. */ +pub trait ConnectionSelector: Send + Sync + Debug { + /** Try get the next connection. */ + fn try_next(&self, connections: &[Connection]) -> Result; +} + +/** A round-robin strategy cycles through nodes sequentially. */ +#[derive(Clone, Debug)] +pub struct RoundRobin { + index: Arc, +} + +impl Default for RoundRobin { + fn default() -> Self { + RoundRobin { + index: Arc::new(AtomicUsize::new(0)), + } + } +} + +impl ConnectionSelector for RoundRobin { + fn try_next(&self, connections: &[Connection]) -> Result { + if connections.is_empty() { + Err(crate::error::lib("Connection list empty")) + } else { + let i = self.index.fetch_add(1, Ordering::Relaxed) % connections.len(); + Ok(connections[i].clone()) + } } } @@ -603,7 +862,14 @@ impl ConnectionPool for CloudConnectionPool { pub mod tests { #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] use crate::auth::ClientCertificate; - use crate::http::transport::{CloudId, Connection, SingleNodeConnectionPool, TransportBuilder}; + use crate::http::transport::{ + CloudId, Connection, ConnectionPool, MultiNodeConnectionPool, SingleNodeConnectionPool, + Transport, TransportBuilder, + }; + use std::{ + sync::atomic::Ordering, + time::{Duration, Instant}, + }; use url::Url; #[test] @@ -628,6 +894,24 @@ pub mod tests { assert!(res.is_err()); } + #[test] + fn test_url_parsing_where_hostname_and_ip_present() { + let url = Transport::parse_to_url("localhost/127.0.0.1:9200", "http").unwrap(); + assert_eq!(url.into_string(), "http://localhost:9200/"); + } + + #[test] + fn test_url_parsing_where_only_ip_present() { + let url = Transport::parse_to_url("127.0.0.1:9200", "http").unwrap(); + assert_eq!(url.into_string(), "http://127.0.0.1:9200/"); + } + + #[test] + fn test_url_parsing_where_only_hostname_present() { + let url = Transport::parse_to_url("localhost:9200", "http").unwrap(); + assert_eq!(url.into_string(), "http://localhost:9200/"); + } + #[test] fn can_parse_cloud_id_with_kibana_uuid() { let base64 = base64::encode("cloud-endpoint.example$3dadf823f05388497ea684236d918a1a$3f26e1609cf54a0f80137a80de560da4"); @@ -742,4 +1026,87 @@ pub mod tests { let conn = Connection::new(url); assert_eq!(conn.url.as_str(), "http://10.1.2.3/"); } + + fn expected_addresses() -> Vec { + vec!["http://a:9200/", "http://b:9200/", "http://c:9200/"] + .iter() + .map(|addr| Url::parse(addr).unwrap()) + .collect() + } + + #[test] + fn test_reseedable_false_on_no_duration() { + let connections = MultiNodeConnectionPool::round_robin(expected_addresses(), None); + assert!(!connections.reseedable()); + } + + #[test] + fn test_reseed() { + let connection_pool = + MultiNodeConnectionPool::round_robin(vec![], Some(Duration::from_secs(28800))); + + let connections = expected_addresses() + .into_iter() + .map(Connection::new) + .collect(); + connection_pool.reseed(connections); + for _ in 0..10 { + for expected in expected_addresses() { + let actual = connection_pool.next(); + + assert_eq!(expected.as_str(), actual.url.as_str()); + } + } + // Check connection pool not reseedable after reseed + assert!(!connection_pool.reseedable()); + assert!(!connection_pool.reseeding.load(Ordering::Relaxed)); + } + + #[test] + fn test_reseedable_after_duration() { + let connection_pool = MultiNodeConnectionPool::round_robin( + expected_addresses(), + Some(Duration::from_secs(30)), + ); + + // Set internal last_update to a minute ago + let mut inner = connection_pool.inner.write().expect("lock poisoned"); + inner.last_update = Some(Instant::now() - Duration::from_secs(60)); + drop(inner); + + assert!(connection_pool.reseedable()); + assert!(connection_pool.reseeding.load(Ordering::Relaxed)); + } + + #[test] + fn round_robin_next_multi() { + let connections = MultiNodeConnectionPool::round_robin(expected_addresses(), None); + + for _ in 0..10 { + for expected in expected_addresses() { + let actual = connections.next(); + + assert_eq!(expected.as_str(), actual.url.as_str()); + } + } + } + + #[test] + fn round_robin_next_single() { + let expected = Url::parse("http://a:9200/").unwrap(); + let connections = MultiNodeConnectionPool::round_robin(vec![expected.clone()], None); + + for _ in 0..10 { + let actual = connections.next(); + + assert_eq!(expected.as_str(), actual.url.as_str()); + } + } + + #[test] + #[should_panic] + fn round_robin_next_empty_fails() { + let connections = MultiNodeConnectionPool::round_robin(vec![], None); + connections.next(); + } } diff --git a/elasticsearch/src/lib.rs b/elasticsearch/src/lib.rs index cd114877..b8b5d4f5 100644 --- a/elasticsearch/src/lib.rs +++ b/elasticsearch/src/lib.rs @@ -353,6 +353,9 @@ type _DoctestReadme = (); #[macro_use] extern crate dyn_clone; +#[macro_use] +extern crate lazy_static; + pub mod auth; pub mod cert; pub mod http;