From 3bcf8b9ddce819e81aab1014635b6a34b1942292 Mon Sep 17 00:00:00 2001 From: Yang Xiufeng Date: Fri, 18 Oct 2024 15:55:15 +0800 Subject: [PATCH] feat: http handler support request forwarding. --- src/common/base/src/headers.rs | 1 + src/query/service/src/clusters/cluster.rs | 33 +++++- .../service/src/servers/http/http_services.rs | 72 +----------- .../src/servers/http/middleware/session.rs | 94 +++++++++++++++- .../servers/http/v1/http_query_handlers.rs | 65 ++++++++--- .../it/servers/http/http_query_handlers.rs | 5 +- .../09_http_handler/09_0008_forward.py | 104 ++++++++++++++++++ .../09_http_handler/09_0008_forward.result | 0 8 files changed, 284 insertions(+), 90 deletions(-) create mode 100755 tests/suites/1_stateful/09_http_handler/09_0008_forward.py create mode 100644 tests/suites/1_stateful/09_http_handler/09_0008_forward.result diff --git a/src/common/base/src/headers.rs b/src/common/base/src/headers.rs index 823e28e9e0a1..95d1e56802d9 100644 --- a/src/common/base/src/headers.rs +++ b/src/common/base/src/headers.rs @@ -26,6 +26,7 @@ pub const HEADER_NODE_ID: &str = "X-DATABEND-NODE-ID"; pub const HEADER_QUERY_STATE: &str = "X-DATABEND-QUERY-STATE"; pub const HEADER_QUERY_PAGE_ROWS: &str = "X-DATABEND-QUERY-PAGE-ROWS"; pub const HEADER_VERSION: &str = "X-DATABEND-VERSION"; +pub const HEADER_STICKY: &str = "X-DATABEND-STICKY-NODE"; pub const HEADER_SIGNATURE: &str = "X-DATABEND-SIGNATURE"; pub const HEADER_AUTH_METHOD: &str = "X-DATABEND-AUTH-METHOD"; diff --git a/src/query/service/src/clusters/cluster.rs b/src/query/service/src/clusters/cluster.rs index 92cd2d3b4076..7186bbb82848 100644 --- a/src/query/service/src/clusters/cluster.rs +++ b/src/query/service/src/clusters/cluster.rs @@ -51,6 +51,7 @@ use futures::Future; use futures::StreamExt; use log::error; use log::warn; +use parking_lot::RwLock; use rand::thread_rng; use rand::Rng; use serde::Deserialize; @@ -66,6 +67,7 @@ pub struct ClusterDiscovery { cluster_id: String, tenant_id: String, flight_address: String, + cached_cluster: RwLock>>, } // avoid leak FlightClient to common-xxx @@ -200,6 +202,7 @@ impl ClusterDiscovery { cluster_id: cfg.query.cluster_id.clone(), tenant_id: cfg.query.tenant_id.tenant_name().to_string(), flight_address: cfg.query.flight_api_address.clone(), + cached_cluster: Default::default(), })) } @@ -261,11 +264,39 @@ impl ClusterDiscovery { &self.flight_address, cluster_nodes.len() as f64, ); - Ok(Cluster::create(res, self.local_id.clone())) + let res = Cluster::create(res, self.local_id.clone()); + *self.cached_cluster.write() = Some(res.clone()); + Ok(res) } } } + fn cached_cluster(self: &Arc) -> Option> { + (*self.cached_cluster.read()).clone() + } + + pub async fn find_node_by_id( + self: Arc, + id: &str, + config: &InnerConfig, + ) -> Result>> { + let (mut cluster, mut is_cached) = if let Some(cluster) = self.cached_cluster() { + (cluster, true) + } else { + (self.discover(config).await?, false) + }; + while is_cached { + for node in cluster.get_nodes() { + if node.id == id { + return Ok(Some(node.clone())); + } + } + cluster = self.discover(config).await?; + is_cached = false; + } + Ok(None) + } + #[async_backtrace::framed] async fn drop_invalid_nodes(self: &Arc, node_info: &NodeInfo) -> Result<()> { let current_nodes_info = match self.api_provider.get_nodes().await { diff --git a/src/query/service/src/servers/http/http_services.rs b/src/query/service/src/servers/http/http_services.rs index abd9ec8f58ec..11924c6f1f64 100644 --- a/src/query/service/src/servers/http/http_services.rs +++ b/src/query/service/src/servers/http/http_services.rs @@ -28,27 +28,18 @@ use poem::listener::OpensslTlsConfig; use poem::middleware::CatchPanic; use poem::middleware::NormalizePath; use poem::middleware::TrailingSlash; -use poem::post; -use poem::put; use poem::Endpoint; use poem::EndpointExt; -use poem::IntoEndpoint; use poem::IntoResponse; use poem::Route; -use super::v1::discovery_nodes; -use super::v1::logout_handler; -use super::v1::upload_to_stage; use super::v1::HttpQueryContext; use crate::servers::http::middleware::json_response; use crate::servers::http::middleware::EndpointKind; use crate::servers::http::middleware::HTTPSessionMiddleware; use crate::servers::http::middleware::PanicHandler; use crate::servers::http::v1::clickhouse_router; -use crate::servers::http::v1::list_suggestions; -use crate::servers::http::v1::login_handler; use crate::servers::http::v1::query_route; -use crate::servers::http::v1::refresh_handler; use crate::servers::Server; #[derive(Copy, Clone)] @@ -98,70 +89,9 @@ impl HttpHandler { }) } - pub fn wrap_auth(&self, ep: E, auth_type: EndpointKind) -> impl Endpoint - where - E: IntoEndpoint, - E::Endpoint: 'static, - { - let session_middleware = HTTPSessionMiddleware::create(self.kind, auth_type); - ep.with(session_middleware).boxed() - } - #[allow(clippy::let_with_type_underscore)] #[async_backtrace::framed] async fn build_router(&self, sock: SocketAddr) -> impl Endpoint { - let ep_v1 = Route::new() - .nest("/query", query_route(self.kind)) - .at( - "/session/login", - post(login_handler).with(HTTPSessionMiddleware::create( - self.kind, - EndpointKind::Login, - )), - ) - .at( - "/session/logout", - post(logout_handler).with(HTTPSessionMiddleware::create( - self.kind, - EndpointKind::Logout, - )), - ) - .at( - "/session/refresh", - post(refresh_handler).with(HTTPSessionMiddleware::create( - self.kind, - EndpointKind::Refresh, - )), - ) - .at( - "/auth/verify", - get(verify_handler).with(HTTPSessionMiddleware::create( - self.kind, - EndpointKind::Verify, - )), - ) - .at( - "/upload_to_stage", - put(upload_to_stage).with(HTTPSessionMiddleware::create( - self.kind, - EndpointKind::StartQuery, - )), - ) - .at( - "/suggested_background_tasks", - get(list_suggestions).with(HTTPSessionMiddleware::create( - self.kind, - EndpointKind::StartQuery, - )), - ) - .at( - "/discovery_nodes", - get(discovery_nodes).with(HTTPSessionMiddleware::create( - self.kind, - EndpointKind::NoAuth, - )), - ); - let ep_clickhouse = Route::new() .nest("/", clickhouse_router()) @@ -182,7 +112,7 @@ impl HttpHandler { HttpHandlerKind::Query => Route::new() .at("/", ep_usage) .nest("/health", ep_health) - .nest("/v1", ep_v1) + .nest("/v1", query_route()) .nest("/clickhouse", ep_clickhouse), HttpHandlerKind::Clickhouse => Route::new() .nest("/", ep_clickhouse) diff --git a/src/query/service/src/servers/http/middleware/session.rs b/src/query/service/src/servers/http/middleware/session.rs index 3a95456f4f36..ad29e3702b0d 100644 --- a/src/query/service/src/servers/http/middleware/session.rs +++ b/src/query/service/src/servers/http/middleware/session.rs @@ -19,6 +19,7 @@ use databend_common_base::headers::HEADER_DEDUPLICATE_LABEL; use databend_common_base::headers::HEADER_NODE_ID; use databend_common_base::headers::HEADER_QUERY_ID; use databend_common_base::headers::HEADER_SESSION_ID; +use databend_common_base::headers::HEADER_STICKY; use databend_common_base::headers::HEADER_TENANT; use databend_common_base::headers::HEADER_VERSION; use databend_common_base::runtime::ThreadTracker; @@ -28,6 +29,7 @@ use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_meta_app::principal::user_token::TokenType; use databend_common_meta_app::tenant::Tenant; +use databend_common_meta_types::NodeInfo; use fastrace::func_name; use headers::authorization::Basic; use headers::authorization::Bearer; @@ -47,6 +49,7 @@ use poem::error::Result as PoemResult; use poem::web::Json; use poem::Addr; use poem::Endpoint; +use poem::Error; use poem::IntoResponse; use poem::Middleware; use poem::Request; @@ -55,6 +58,7 @@ use uuid::Uuid; use crate::auth::AuthMgr; use crate::auth::Credential; +use crate::clusters::ClusterDiscovery; use crate::servers::http::error::HttpErrorCode; use crate::servers::http::error::JsonErrorOnly; use crate::servers::http::error::QueryError; @@ -82,6 +86,12 @@ impl EndpointKind { pub fn need_user_info(&self) -> bool { !matches!(self, EndpointKind::NoAuth | EndpointKind::PollQuery) } + pub fn may_need_sticky(&self) -> bool { + matches!( + self, + EndpointKind::StartQuery | EndpointKind::PollQuery | EndpointKind::Logout + ) + } pub fn require_databend_token_type(&self) -> Result> { match self { EndpointKind::Verify => Ok(None), @@ -372,14 +382,96 @@ impl HTTPSessionEndpoint { } } +async fn forward_request(mut req: Request, node: Arc) -> PoemResult { + let addr = node.http_address.clone(); + let config = GlobalConfig::instance(); + let scheme = if config.query.http_handler_tls_server_key.is_empty() + || config.query.http_handler_tls_server_cert.is_empty() + { + "http" + } else { + "https" + }; + let url = format!("{scheme}://{addr}/v1{}", req.uri()); + + let client = reqwest::Client::new(); + let mut request_builder = client.request(req.method().clone(), &url); + for (name, value) in req.headers().iter() { + request_builder = request_builder.header(name, value); + } + let reqwest_request = request_builder + .body(req.take_body().into_bytes().await?) + .build() + .map_err(|e| { + HttpErrorCode::bad_request(ErrorCode::BadArguments(format!( + "fail to build forward request: {e}" + ))) + })?; + + let response = client.execute(reqwest_request).await.map_err(|e| { + HttpErrorCode::server_error(ErrorCode::Internal(format!( + "fail to send forward request: {e}", + ))) + })?; + + let status = StatusCode::from_u16(response.status().as_u16()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + let headers = response.headers().clone(); + let body = response.bytes().await.map_err(|e| { + HttpErrorCode::server_error(ErrorCode::Internal(format!( + "fail to send forward request: {e}", + ))) + })?; + let mut poem_resp = Response::builder().status(status).body(body); + let headers_ref = poem_resp.headers_mut(); + for (key, value) in headers.iter() { + headers_ref.insert(key, value.clone()); + } + Ok(poem_resp) +} + impl Endpoint for HTTPSessionEndpoint { type Output = Response; #[async_backtrace::framed] async fn call(&self, mut req: Request) -> PoemResult { + let headers = req.headers().clone(); + + if self.endpoint_kind.may_need_sticky() + && let Some(sticky_node_id) = headers.get(HEADER_STICKY) + { + let sticky_node_id = sticky_node_id + .to_str() + .map_err(|e| { + HttpErrorCode::bad_request(ErrorCode::BadArguments(format!( + "Invalid Header ({HEADER_STICKY}: {sticky_node_id:?}): {e}" + ))) + })? + .to_string(); + let local_id = GlobalConfig::instance().query.node_id.clone(); + if local_id != sticky_node_id { + let config = GlobalConfig::instance(); + return if let Some(node) = ClusterDiscovery::instance() + .find_node_by_id(&sticky_node_id, &config) + .await + .map_err(HttpErrorCode::server_error)? + { + log::info!( + "forwarding {} from {local_id} to {sticky_node_id}", + req.uri() + ); + forward_request(req, node).await + } else { + let msg = format!("sticky_node_id '{sticky_node_id}' not found in cluster",); + warn!("{}", msg); + Err(Error::from(HttpErrorCode::bad_request( + ErrorCode::BadArguments(msg), + ))) + }; + } + } let method = req.method().clone(); let uri = req.uri().clone(); - let headers = req.headers().clone(); let query_id = req .headers() diff --git a/src/query/service/src/servers/http/v1/http_query_handlers.rs b/src/query/service/src/servers/http/v1/http_query_handlers.rs index d35606725618..1eacdd5c9dd2 100644 --- a/src/query/service/src/servers/http/v1/http_query_handlers.rs +++ b/src/query/service/src/servers/http/v1/http_query_handlers.rs @@ -31,6 +31,7 @@ use poem::error::Error as PoemError; use poem::error::Result as PoemResult; use poem::get; use poem::post; +use poem::put; use poem::web::Json; use poem::web::Path; use poem::EndpointExt; @@ -48,8 +49,14 @@ use crate::servers::http::error::QueryError; use crate::servers::http::middleware::EndpointKind; use crate::servers::http::middleware::HTTPSessionMiddleware; use crate::servers::http::middleware::MetricsMiddleware; +use crate::servers::http::v1::discovery_nodes; +use crate::servers::http::v1::list_suggestions; +use crate::servers::http::v1::login_handler; +use crate::servers::http::v1::logout_handler; use crate::servers::http::v1::query::string_block::StringBlock; use crate::servers::http::v1::query::Progresses; +use crate::servers::http::v1::refresh_handler; +use crate::servers::http::v1::upload_to_stage; use crate::servers::http::v1::HttpQueryContext; use crate::servers::http::v1::HttpQueryManager; use crate::servers::http::v1::HttpSessionConf; @@ -404,34 +411,66 @@ pub(crate) async fn query_handler( .await } -pub fn query_route(http_handler_kind: HttpHandlerKind) -> Route { +pub fn query_route() -> Route { // Note: endpoints except /v1/query may change without notice, use uris in response instead let rules = [ - ("/", post(query_handler)), - ("/:id", get(query_state_handler)), - ("/:id/page/:page_no", get(query_page_handler)), + ("/query", post(query_handler), EndpointKind::StartQuery), ( - "/:id/kill", + "/query/:id", + get(query_state_handler), + EndpointKind::PollQuery, + ), + ( + "/query/:id/page/:page_no", + get(query_page_handler), + EndpointKind::PollQuery, + ), + ( + "/query/:id/kill", get(query_cancel_handler).post(query_cancel_handler), + EndpointKind::PollQuery, ), ( - "/:id/final", + "/query/:id/final", get(query_final_handler).post(query_final_handler), + EndpointKind::PollQuery, + ), + ("/session/login", post(login_handler), EndpointKind::Login), + ( + "/session/logout", + post(logout_handler), + EndpointKind::Logout, + ), + ( + "/session/refresh", + post(refresh_handler), + EndpointKind::Refresh, + ), + ("/auth/verify", post(refresh_handler), EndpointKind::Verify), + ( + "/upload_to_stage", + put(upload_to_stage), + EndpointKind::StartQuery, + ), + ( + "/suggested_background_tasks", + get(list_suggestions), + EndpointKind::StartQuery, + ), + ( + "/discovery_nodes", + get(discovery_nodes), + EndpointKind::StartQuery, ), ]; let mut route = Route::new(); - for (path, endpoint) in rules.into_iter() { - let kind = if path == "/" { - EndpointKind::StartQuery - } else { - EndpointKind::PollQuery - }; + for (path, endpoint, kind) in rules.into_iter() { route = route.at( path, endpoint .with(MetricsMiddleware::new(path)) - .with(HTTPSessionMiddleware::create(http_handler_kind, kind)), + .with(HTTPSessionMiddleware::create(HttpHandlerKind::Query, kind)), ); } route diff --git a/src/query/service/tests/it/servers/http/http_query_handlers.rs b/src/query/service/tests/it/servers/http/http_query_handlers.rs index 2ebd78d7079d..7be6fef340f4 100644 --- a/src/query/service/tests/it/servers/http/http_query_handlers.rs +++ b/src/query/service/tests/it/servers/http/http_query_handlers.rs @@ -860,10 +860,7 @@ async fn post_sql(sql: &str, wait_time_secs: u64) -> Result<(StatusCode, QueryRe } pub fn create_endpoint() -> Result { - Ok(Route::new().nest( - "/v1/query", - query_route(HttpHandlerKind::Query).around(json_response), - )) + Ok(Route::new().nest("/v1", query_route().around(json_response))) } async fn post_json(json: &serde_json::Value) -> Result<(StatusCode, QueryResponse)> { diff --git a/tests/suites/1_stateful/09_http_handler/09_0008_forward.py b/tests/suites/1_stateful/09_http_handler/09_0008_forward.py new file mode 100755 index 000000000000..9aaa399eb9d8 --- /dev/null +++ b/tests/suites/1_stateful/09_http_handler/09_0008_forward.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 + +import requests + +auth = ("root", "") +STICKY_HEADER = "X-DATABEND-STICKY-NODE" + +import logging +logging.basicConfig(level=logging.ERROR, format='%(asctime)s %(levelname)s %(message)s') + + +def do_query(query, port=8000, session=None, node_id=None): + url = f"http://localhost:{port}/v1/query" + query_payload = {"sql": query, "pagination": {"wait_time_secs": 100, "max_rows_per_page": 2}} + if session: + query_payload['session'] = session + headers={ + "Content-Type": "application/json", + } + if node_id: + headers[STICKY_HEADER] = node_id + + response = requests.post( + url, + headers=headers, + json=query_payload, + auth = auth + ) + return response + + +def test_txn(): + resp = do_query("create or replace table t1(a int)").json() + assert not(resp.get('session').get('need_sticky')), resp + + resp = do_query("begin").json() + assert resp.get('session').get('need_sticky'), resp + node_id = resp.get('node_id') + session = resp.get('session') + + # can not find txn state in node 2 + resp = do_query("insert into t1 values (1)", port=8002, session=session).json() + assert resp.get('state') == 'Failed', resp.text + + # forward to node 1 + resp = do_query("insert into t1 values (2)", port=8002, session=session, node_id=node_id).json() + assert resp.get('state') == 'Succeeded', resp + assert resp.get('session').get('need_sticky'), resp + + # return need_sticky = false after commit + resp = do_query("commit").json() + assert not(resp.get('session').get('need_sticky')), resp + + +def test_query(): + """ each query is sticky + """ + # send SQL to node-1 + initial_resp = do_query("select * from numbers(10)").json() + assert(len(initial_resp.get('data')) == 2) + + # get page from node-2 without header + next_uri = initial_resp.get("next_uri") + next_uri = f"http://localhost:8002/{next_uri}?" + resp = requests.get(next_uri, auth=auth) + assert(resp.status_code == 404) + + # get page from node-2 by forward + node_id = initial_resp.get("node_id") + headers = { + STICKY_HEADER: node_id, + } + resp = requests.get(next_uri, auth=auth, headers=headers) + assert resp.status_code == 200, resp.text + assert(len(resp.json().get('data')) == 2) + + # error: query not exists + resp = requests.get("http://localhost:8002/v1/query/an_query_id/page/0", auth=auth, headers=headers) + assert(resp.status_code == 404), resp.text + + # error: node not exists + headers={ + STICKY_HEADER: "xxx", + } + resp = requests.get(next_uri, auth=auth, headers=headers) + assert(resp.status_code == 400), resp.text + + +def main(): + # only test under cluster mode + query_resp = do_query("select count(*) from system.clusters").json() + num_nodes = int(query_resp.get("data")[0][0]) + if num_nodes == 1: + return + + # test_query() + + test_txn() + +if __name__ == "__main__": + try: + main() + except Exception as e: + logging.exception(f"An error occurred: {e}") diff --git a/tests/suites/1_stateful/09_http_handler/09_0008_forward.result b/tests/suites/1_stateful/09_http_handler/09_0008_forward.result new file mode 100644 index 000000000000..e69de29bb2d1