Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: avoid grpc forwarding twice #991

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions proxy/src/forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ use tonic::{
transport::{self, Channel},
};

use crate::FORWARDED_FROM;

#[derive(Debug, Snafu)]
pub enum Error {
#[snafu(display(
Expand Down Expand Up @@ -68,6 +70,9 @@ pub enum Error {
source: tonic::transport::Error,
backtrace: Backtrace,
},

#[snafu(display("Request should not be forwarded twice, forward from:{}", endpoint))]
ForwardedErr { endpoint: String },
}

define_result!(Error);
Expand Down Expand Up @@ -184,6 +189,7 @@ pub struct ForwardRequest<Req> {
pub schema: String,
pub table: String,
pub req: tonic::Request<Req>,
pub forwarded_from: Option<String>,
}

impl Forwarder<DefaultClientBuilder> {
Expand Down Expand Up @@ -256,7 +262,12 @@ impl<B: ClientBuilder> Forwarder<B> {
F: ForwarderRpc<Req, Resp, Err>,
Req: std::fmt::Debug + Clone,
{
let ForwardRequest { schema, table, req } = forward_req;
let ForwardRequest {
schema,
table,
req,
forwarded_from,
} = forward_req;

let route_req = RouteRequest {
context: Some(RequestContext { database: schema }),
Expand All @@ -281,13 +292,15 @@ impl<B: ClientBuilder> Forwarder<B> {
}
};

self.forward_with_endpoint(endpoint, req, do_rpc).await
self.forward_with_endpoint(endpoint, req, forwarded_from, do_rpc)
.await
}

pub async fn forward_with_endpoint<Req, Resp, Err, F>(
&self,
endpoint: Endpoint,
mut req: tonic::Request<Req>,
forwarded_from: Option<String>,
do_rpc: F,
) -> Result<ForwardResult<Resp, Err>>
where
Expand All @@ -310,6 +323,17 @@ impl<B: ClientBuilder> Forwarder<B> {
"Try to forward request to {:?}, request:{:?}",
endpoint, req,
);

if let Some(endpoint) = forwarded_from {
return ForwardedErr { endpoint }.fail();
}

// mark forwarded
req.metadata_mut().insert(
FORWARDED_FROM,
self.local_endpoint.to_string().parse().unwrap(),
);

let client = self.get_or_create_client(&endpoint).await?;
match do_rpc(client, req, &endpoint).await {
Err(e) => {
Expand Down Expand Up @@ -461,6 +485,7 @@ mod tests {
schema: DEFAULT_SCHEMA.to_string(),
table: table.to_string(),
req: query_request.into_request(),
forwarded_from: None,
}
};

Expand Down
8 changes: 7 additions & 1 deletion proxy/src/grpc/sql_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,11 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {

let req_context = req.context.as_ref().unwrap();
let schema = req_context.database.clone();
let req = match self.clone().maybe_forward_stream_sql_query(&req).await {
let req = match self
.clone()
.maybe_forward_stream_sql_query(ctx.clone(), &req)
.await
{
Some(resp) => match resp {
ForwardResult::Forwarded(resp) => return resp,
ForwardResult::Local => req,
Expand Down Expand Up @@ -150,6 +154,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {

async fn maybe_forward_stream_sql_query(
self: Arc<Self>,
ctx: Context,
req: &SqlQueryRequest,
) -> Option<ForwardResult<BoxStream<'static, SqlQueryResponse>, Error>> {
if req.tables.len() != 1 {
Expand All @@ -163,6 +168,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
schema: req_ctx.database.clone(),
table: req.tables[0].clone(),
req: req.clone().into_request(),
forwarded_from: ctx.forwarded_from,
};
let do_query = |mut client: StorageServiceClient<Channel>,
request: tonic::Request<SqlQueryRequest>,
Expand Down
1 change: 1 addition & 0 deletions proxy/src/http/prom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
runtime: self.engine_runtimes.write_runtime.clone(),
timeout: ctx.timeout,
enable_partition_table_access: false,
forwarded_from: None,
};

let result = self.handle_write_internal(ctx, table_request).await?;
Expand Down
1 change: 1 addition & 0 deletions proxy/src/http/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
timeout: ctx.timeout,
runtime: self.engine_runtimes.read_runtime.clone(),
enable_partition_table_access: true,
forwarded_from: None,
};

match self.handle_sql(context, &ctx.schema, &req.query).await? {
Expand Down
1 change: 1 addition & 0 deletions proxy/src/influxdb/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
timeout: ctx.timeout,
runtime: self.engine_runtimes.write_runtime.clone(),
enable_partition_table_access: false,
forwarded_from: None,
};
let result = self
.handle_write_internal(proxy_context, table_request)
Expand Down
4 changes: 4 additions & 0 deletions proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ pub mod schema_config_provider;
mod util;
mod write;

pub const FORWARDED_FROM: &str = "forwarded-from";

use std::{
sync::Arc,
time::{Duration, Instant},
Expand Down Expand Up @@ -131,6 +133,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
schema: req_ctx.database.clone(),
table: metric,
req: req.into_request(),
forwarded_from: None,
};
let do_query = |mut client: StorageServiceClient<Channel>,
request: tonic::Request<PrometheusRemoteQueryRequest>,
Expand Down Expand Up @@ -452,4 +455,5 @@ pub struct Context {
pub timeout: Option<Duration>,
pub runtime: Arc<Runtime>,
pub enable_partition_table_access: bool,
pub forwarded_from: Option<String>,
}
7 changes: 6 additions & 1 deletion proxy/src/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
schema: &str,
sql: &str,
) -> Result<SqlResponse> {
if let Some(resp) = self.maybe_forward_sql_query(schema, sql).await? {
if let Some(resp) = self
.maybe_forward_sql_query(ctx.clone(), schema, sql)
.await?
{
match resp {
ForwardResult::Forwarded(resp) => return Ok(SqlResponse::Forwarded(resp?)),
ForwardResult::Local => (),
Expand Down Expand Up @@ -149,6 +152,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {

async fn maybe_forward_sql_query(
&self,
ctx: Context,
schema: &str,
sql: &str,
) -> Result<Option<ForwardResult<SqlQueryResponse, Error>>> {
Expand All @@ -174,6 +178,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
schema: schema.to_string(),
table: table_name.unwrap(),
req: sql_request.into_request(),
forwarded_from: ctx.forwarded_from,
};
let do_query = |mut client: StorageServiceClient<Channel>,
request: tonic::Request<SqlQueryRequest>,
Expand Down
16 changes: 12 additions & 4 deletions proxy/src/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
let mut futures = Vec::with_capacity(write_requests_to_forward.len() + 1);

// Write to remote.
self.collect_write_to_remote_future(&mut futures, write_requests_to_forward)
self.collect_write_to_remote_future(&mut futures, ctx.clone(), write_requests_to_forward)
.await;

// Write to local.
Expand Down Expand Up @@ -139,7 +139,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
let mut futures = Vec::with_capacity(write_requests_to_forward.len() + 1);

// Write to remote.
self.collect_write_to_remote_future(&mut futures, write_requests_to_forward)
self.collect_write_to_remote_future(&mut futures, ctx.clone(), write_requests_to_forward)
.await;

// Create table.
Expand Down Expand Up @@ -358,12 +358,14 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
async fn collect_write_to_remote_future(
&self,
futures: &mut WriteResponseFutures<'_>,
ctx: Context,
write_request: HashMap<Endpoint, WriteRequest>,
) {
for (endpoint, table_write_request) in write_request {
let forwarder = self.forwarder.clone();
let ctx = ctx.clone();
let write_handle = self.engine_runtimes.io_runtime.spawn(async move {
Self::write_to_remote(forwarder, endpoint, table_write_request).await
Self::write_to_remote(ctx, forwarder, endpoint, table_write_request).await
});

futures.push(write_handle.boxed());
Expand Down Expand Up @@ -408,6 +410,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
}

async fn write_to_remote(
ctx: Context,
forwarder: ForwarderRef,
endpoint: Endpoint,
table_write_request: WriteRequest,
Expand All @@ -432,7 +435,12 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
};

let forward_result = forwarder
.forward_with_endpoint(endpoint, tonic::Request::new(table_write_request), do_write)
.forward_with_endpoint(
endpoint,
tonic::Request::new(table_write_request),
ctx.forwarded_from,
do_write,
)
.await;
let forward_res = forward_result
.map_err(|e| {
Expand Down
48 changes: 37 additions & 11 deletions server/src/grpc/storage_service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use ceresdbproto::{
use common_util::time::InstantExt;
use futures::{stream, stream::BoxStream, StreamExt};
use http::StatusCode;
use proxy::{Context, Proxy};
use proxy::{Context, Proxy, FORWARDED_FROM};
use query_engine::executor::Executor as QueryExecutor;
use table_engine::engine::EngineRuntimes;

Expand Down Expand Up @@ -138,6 +138,10 @@ impl<Q: QueryExecutor + 'static> StorageService for StorageServiceImpl<Q> {
runtime: self.runtimes.read_runtime.clone(),
timeout: self.timeout,
enable_partition_table_access: false,
forwarded_from: req
.metadata()
.get(FORWARDED_FROM)
.map(|value| value.to_str().unwrap().to_string()),
};
let stream = Self::stream_sql_query_internal(ctx, proxy, req).await;

Expand All @@ -155,13 +159,17 @@ impl<Q: QueryExecutor + 'static> StorageServiceImpl<Q> {
&self,
req: tonic::Request<RouteRequest>,
) -> Result<tonic::Response<RouteResponse>, tonic::Status> {
let req = req.into_inner();
let proxy = self.proxy.clone();
let ctx = Context {
runtime: self.runtimes.read_runtime.clone(),
timeout: self.timeout,
enable_partition_table_access: false,
forwarded_from: req
.metadata()
.get(FORWARDED_FROM)
.map(|value| value.to_str().unwrap().to_string()),
};
let req = req.into_inner();
let proxy = self.proxy.clone();

let join_handle = self
.runtimes
Expand All @@ -186,13 +194,17 @@ impl<Q: QueryExecutor + 'static> StorageServiceImpl<Q> {
&self,
req: tonic::Request<WriteRequest>,
) -> Result<tonic::Response<WriteResponse>, tonic::Status> {
let req = req.into_inner();
let proxy = self.proxy.clone();
let ctx = Context {
runtime: self.runtimes.write_runtime.clone(),
timeout: self.timeout,
enable_partition_table_access: false,
forwarded_from: req
.metadata()
.get(FORWARDED_FROM)
.map(|value| value.to_str().unwrap().to_string()),
};
let req = req.into_inner();
let proxy = self.proxy.clone();

let join_handle = self.runtimes.write_runtime.spawn(async move {
if req.context.is_none() {
Expand Down Expand Up @@ -226,13 +238,18 @@ impl<Q: QueryExecutor + 'static> StorageServiceImpl<Q> {
&self,
req: tonic::Request<SqlQueryRequest>,
) -> Result<tonic::Response<SqlQueryResponse>, tonic::Status> {
let req = req.into_inner();
let proxy = self.proxy.clone();
let ctx = Context {
runtime: self.runtimes.read_runtime.clone(),
timeout: self.timeout,
enable_partition_table_access: false,
forwarded_from: req
.metadata()
.get(FORWARDED_FROM)
.map(|value| value.to_str().unwrap().to_string()),
};
let req = req.into_inner();
let proxy = self.proxy.clone();

let join_handle = self
.runtimes
.read_runtime
Expand Down Expand Up @@ -289,13 +306,18 @@ impl<Q: QueryExecutor + 'static> StorageServiceImpl<Q> {
&self,
req: tonic::Request<PrometheusQueryRequest>,
) -> Result<tonic::Response<PrometheusQueryResponse>, tonic::Status> {
let req = req.into_inner();
let proxy = self.proxy.clone();
let ctx = Context {
runtime: self.runtimes.read_runtime.clone(),
timeout: self.timeout,
enable_partition_table_access: false,
forwarded_from: req
.metadata()
.get(FORWARDED_FROM)
.map(|value| value.to_str().unwrap().to_string()),
};
let req = req.into_inner();
let proxy = self.proxy.clone();

let join_handle = self.runtimes.read_runtime.spawn(async move {
if req.context.is_none() {
return PrometheusQueryResponse {
Expand Down Expand Up @@ -329,13 +351,17 @@ impl<Q: QueryExecutor + 'static> StorageServiceImpl<Q> {
) -> Result<tonic::Response<WriteResponse>, tonic::Status> {
let mut total_success = 0;

let mut stream = req.into_inner();
let proxy = self.proxy.clone();
let ctx = Context {
runtime: self.runtimes.write_runtime.clone(),
timeout: self.timeout,
enable_partition_table_access: false,
forwarded_from: req
.metadata()
.get(FORWARDED_FROM)
.map(|value| value.to_str().unwrap().to_string()),
};
let mut stream = req.into_inner();
let proxy = self.proxy.clone();

let join_handle = self.runtimes.write_runtime.spawn(async move {
let mut resp = WriteResponse::default();
Expand Down