Skip to content

Commit

Permalink
make create_session async
Browse files Browse the repository at this point in the history
  • Loading branch information
ariesdevil committed Mar 8, 2022
1 parent 5d1f1f6 commit 8f24cc4
Show file tree
Hide file tree
Showing 99 changed files with 214 additions and 187 deletions.
2 changes: 1 addition & 1 deletion query/benches/suites/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub mod bench_sort_query_sql;

pub async fn select_executor(sql: &str) -> Result<()> {
let sessions = SessionManager::from_conf(Config::default()).await?;
let executor_session = sessions.create_session("Benches")?;
let executor_session = sessions.create_session("Benches").await?;
let ctx = executor_session.create_query_context().await?;

if let PlanNode::Select(plan) = PlanParser::parse(ctx.clone(), sql).await? {
Expand Down
2 changes: 1 addition & 1 deletion query/src/api/http/v1/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub async fn cluster_list_handler(
}

async fn list_nodes(sessions: &Arc<SessionManager>) -> Result<Vec<Arc<NodeInfo>>> {
let watch_cluster_session = sessions.create_session("WatchCluster")?;
let watch_cluster_session = sessions.create_session("WatchCluster").await?;
let watch_cluster_context = watch_cluster_session.create_query_context().await?;
Ok(watch_cluster_context.get_cluster().get_nodes())
}
2 changes: 1 addition & 1 deletion query/src/api/http/v1/logs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub async fn logs_handler(
}

async fn select_table(sessions: &Arc<SessionManager>) -> Result<Body> {
let session = sessions.create_session("WatchLogs")?;
let session = sessions.create_session("WatchLogs").await?;
let query_context = session.create_query_context().await?;
let mut tracing_table_stream = execute_query(query_context).await?;

Expand Down
10 changes: 8 additions & 2 deletions query/src/api/rpc/flight_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,10 @@ impl FlightService for DatabendQueryFlightService {
FlightAction::BroadcastAction(action) => {
let session_id = action.query_id.clone();
let is_aborted = self.dispatcher.is_aborted();
let session = self.sessions.create_rpc_session(session_id, is_aborted)?;
let session = self
.sessions
.create_rpc_session(session_id, is_aborted)
.await?;

self.dispatcher
.broadcast_action(session, flight_action)
Expand All @@ -171,7 +174,10 @@ impl FlightService for DatabendQueryFlightService {
FlightAction::PrepareShuffleAction(action) => {
let session_id = action.query_id.clone();
let is_aborted = self.dispatcher.is_aborted();
let session = self.sessions.create_rpc_session(session_id, is_aborted)?;
let session = self
.sessions
.create_rpc_session(session_id, is_aborted)
.await?;

self.dispatcher
.shuffle_action(session, flight_action)
Expand Down
32 changes: 16 additions & 16 deletions query/src/servers/clickhouse/clickhouse_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,27 +77,27 @@ impl ClickHouseHandler {
})
}

fn reject_connection(stream: TcpStream, executor: Arc<Runtime>, error: ErrorCode) {
executor.spawn(async move {
if let Err(error) = RejectCHConnection::reject(stream, error).await {
tracing::error!(
"Unexpected error occurred during reject connection: {:?}",
error
);
}
});
async fn reject_connection(stream: TcpStream, error: ErrorCode) {
if let Err(error) = RejectCHConnection::reject(stream, error).await {
tracing::error!(
"Unexpected error occurred during reject connection: {:?}",
error
);
}
}

fn accept_socket(sessions: Arc<SessionManager>, executor: Arc<Runtime>, socket: TcpStream) {
match sessions.create_session("ClickHouseSession") {
Err(error) => Self::reject_connection(socket, executor, error),
Ok(session) => {
tracing::info!("ClickHouse connection coming: {:?}", socket.peer_addr());
if let Err(error) = ClickHouseConnection::run_on_stream(session, socket) {
tracing::error!("Unexpected error occurred during query: {:?}", error);
executor.spawn(async move {
match sessions.create_session("ClickHouseSession").await {
Err(error) => Self::reject_connection(socket, error).await,
Ok(session) => {
tracing::info!("ClickHouse connection coming: {:?}", socket.peer_addr());
if let Err(error) = ClickHouseConnection::run_on_stream(session, socket) {
tracing::error!("Unexpected error occurred during query: {:?}", error);
}
}
}
}
});
}
}

Expand Down
1 change: 1 addition & 0 deletions query/src/servers/http/v1/load.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ pub async fn streaming_load(
let session_manager = sessions_extension.0;
let session = session_manager
.create_session("Streaming load")
.await
.map_err(InternalServerError)?;

// TODO: list user's grant list and check client address
Expand Down
2 changes: 1 addition & 1 deletion query/src/servers/http/v1/query/execute_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ impl ExecuteState {
) -> Result<(Arc<RwLock<Executor>>, DataSchemaRef)> {
let sql = &request.sql;
let start_time = Instant::now();
let session = session_manager.create_session("http-statement")?;
let session = session_manager.create_session("http-statement").await?;
let ctx = session.create_query_context().await?;
if let Some(db) = &request.session.database {
ctx.set_current_database(db.clone()).await?;
Expand Down
44 changes: 21 additions & 23 deletions query/src/servers/mysql/mysql_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,33 +80,31 @@ impl MySQLHandler {
}

fn accept_socket(sessions: Arc<SessionManager>, executor: Arc<Runtime>, socket: TcpStream) {
match sessions.create_session("MySQL") {
Err(error) => Self::reject_session(socket, executor, error),
Ok(session) => {
tracing::info!("MySQL connection coming: {:?}", socket.peer_addr());
if let Err(error) = MySQLConnection::run_on_stream(session, socket) {
tracing::error!("Unexpected error occurred during query: {:?}", error);
};
executor.spawn(async move {
match sessions.create_session("MySQL").await {
Err(error) => Self::reject_session(socket, error).await,
Ok(session) => {
tracing::info!("MySQL connection coming: {:?}", socket.peer_addr());
if let Err(error) = MySQLConnection::run_on_stream(session, socket) {
tracing::error!("Unexpected error occurred during query: {:?}", error);
};
}
}
}
});
}

fn reject_session(stream: TcpStream, executor: Arc<Runtime>, error: ErrorCode) {
executor.spawn(async move {
let (kind, message) = match error.code() {
41 => (ErrorKind::ER_TOO_MANY_USER_CONNECTIONS, error.message()),
_ => (ErrorKind::ER_INTERNAL_ERROR, error.message()),
};
async fn reject_session(stream: TcpStream, error: ErrorCode) {
let (kind, message) = match error.code() {
41 => (ErrorKind::ER_TOO_MANY_USER_CONNECTIONS, error.message()),
_ => (ErrorKind::ER_INTERNAL_ERROR, error.message()),
};

if let Err(error) =
RejectConnection::reject_mysql_connection(stream, kind, message).await
{
tracing::error!(
"Unexpected error occurred during reject connection: {:?}",
error
);
}
});
if let Err(error) = RejectConnection::reject_mysql_connection(stream, kind, message).await {
tracing::error!(
"Unexpected error occurred during reject connection: {:?}",
error
);
}
}
}

Expand Down
10 changes: 3 additions & 7 deletions query/src/sessions/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,14 @@ pub struct Session {
}

impl Session {
pub fn try_create(
pub async fn try_create(
conf: Config,
id: String,
typ: String,
session_mgr: Arc<SessionManager>,
) -> Result<Arc<Session>> {
let user_manager =
futures::executor::block_on(UserApiProvider::create_global(conf.clone()))?;
let auth_manager = Arc::new(futures::executor::block_on(AuthMgr::create(
conf.clone(),
user_manager.clone(),
))?);
let user_manager = UserApiProvider::create_global(conf.clone()).await?;
let auth_manager = Arc::new(AuthMgr::create(conf.clone(), user_manager.clone()).await?);
let role_cache_manager = Arc::new(RoleCacheMgr::new(user_manager.clone()));
let session_ctx = Arc::new(SessionContext::try_create(conf.clone())?);
let session_settings =
Expand Down
42 changes: 28 additions & 14 deletions query/src/sessions/session_mgr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ use std::sync::Arc;
use std::time::Duration;

use common_base::tokio;
use common_base::tokio::sync::RwLock;
use common_base::SignalStream;
use common_exception::ErrorCode;
use common_exception::Result;
use common_infallible::RwLock;
use common_metrics::label_counter;
use common_tracing::tracing;
use futures::future::Either;
Expand Down Expand Up @@ -124,9 +124,12 @@ impl SessionManager {
self.storage_cache_manager.as_ref()
}

pub fn create_session(self: &Arc<Self>, typ: impl Into<String>) -> Result<SessionRef> {
let mut sessions = self.active_sessions.write();
match sessions.len() == self.max_sessions {
pub async fn create_session(self: &Arc<Self>, typ: impl Into<String>) -> Result<SessionRef> {
let total_sessions = {
let sessions = self.active_sessions.read().await;
sessions.len()
};
match total_sessions == self.max_sessions {
true => Err(ErrorCode::TooManyUserConnections(
"The current accept connection has exceeded mysql_handler_thread_num config",
)),
Expand All @@ -136,22 +139,30 @@ impl SessionManager {
uuid::Uuid::new_v4().to_string(),
typ.into(),
self.clone(),
)?;
)
.await?;

label_counter(
super::metrics::METRIC_SESSION_CONNECT_NUMBERS,
&self.conf.query.tenant_id,
&self.conf.query.cluster_id,
);

sessions.insert(session.get_id(), session.clone());
{
let mut sessions = self.active_sessions.write().await;
sessions.insert(session.get_id(), session.clone());
}
Ok(SessionRef::create(session))
}
}
}

pub fn create_rpc_session(self: &Arc<Self>, id: String, aborted: bool) -> Result<SessionRef> {
let mut sessions = self.active_sessions.write();
pub async fn create_rpc_session(
self: &Arc<Self>,
id: String,
aborted: bool,
) -> Result<SessionRef> {
let mut sessions = self.active_sessions.write().await;

let session = match sessions.entry(id) {
Occupied(entry) => entry.get().clone(),
Expand All @@ -162,7 +173,8 @@ impl SessionManager {
entry.key().clone(),
String::from("RPCSession"),
self.clone(),
)?;
)
.await?;

label_counter(
super::metrics::METRIC_SESSION_CONNECT_NUMBERS,
Expand All @@ -179,7 +191,7 @@ impl SessionManager {

#[allow(clippy::ptr_arg)]
pub fn get_session_by_id(self: &Arc<Self>, id: &str) -> Option<SessionRef> {
let sessions = self.active_sessions.read();
let sessions = futures::executor::block_on(self.active_sessions.read());
sessions
.get(id)
.map(|session| SessionRef::create(session.clone()))
Expand All @@ -193,7 +205,8 @@ impl SessionManager {
&self.conf.query.cluster_id,
);

self.active_sessions.write().remove(session_id);
let mut sessions = futures::executor::block_on(self.active_sessions.write());
sessions.remove(session_id);
}

pub fn graceful_shutdown(
Expand Down Expand Up @@ -224,14 +237,15 @@ impl SessionManager {
tracing::info!("Will shutdown forcefully.");
active_sessions
.read()
.await
.values()
.for_each(Session::force_kill_session);
}
}

pub fn processes_info(self: &Arc<Self>) -> Vec<ProcessInfo> {
self.active_sessions
.read()
let sessions = futures::executor::block_on(self.active_sessions.read());
sessions
.values()
.map(Session::process_info)
.collect::<Vec<_>>()
Expand All @@ -240,7 +254,7 @@ impl SessionManager {
fn destroy_idle_sessions(sessions: &Arc<RwLock<HashMap<String, Arc<Session>>>>) -> bool {
// Read lock does not support reentrant
// https://github.com/Amanieu/parking_lot/blob/lock_api-0.4.4/lock_api/src/rwlock.rs#L422
let active_sessions_read_guard = sessions.read();
let active_sessions_read_guard = futures::executor::block_on(sessions.read());

// First try to kill the idle session
active_sessions_read_guard.values().for_each(Session::kill);
Expand Down
2 changes: 1 addition & 1 deletion query/tests/it/api/rpc/flight_actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::tests::create_query_context;

#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn test_shuffle_action_try_into() -> Result<()> {
let ctx = create_query_context()?;
let ctx = create_query_context().await?;

let shuffle_action = ShuffleAction {
query_id: String::from("query_id"),
Expand Down
8 changes: 4 additions & 4 deletions query/tests/it/api/rpc/flight_dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ async fn test_get_stream_with_non_exists_stream() -> Result<()> {
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn test_run_shuffle_action_with_no_scatters() -> Result<()> {
if let (Some(query_id), Some(stage_id), Some(stream_id)) = generate_uuids(3) {
let ctx = create_query_context()?;
let ctx = create_query_context().await?;
let flight_dispatcher = DatabendQueryFlightDispatcher::create();

let sessions = SessionManagerBuilder::create().build()?;
let rpc_session = sessions.create_rpc_session(query_id.clone(), false)?;
let rpc_session = sessions.create_rpc_session(query_id.clone(), false).await?;

flight_dispatcher
.shuffle_action(
Expand Down Expand Up @@ -94,11 +94,11 @@ async fn test_run_shuffle_action_with_no_scatters() -> Result<()> {
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn test_run_shuffle_action_with_scatter() -> Result<()> {
if let (Some(query_id), Some(stage_id), None) = generate_uuids(2) {
let ctx = create_query_context()?;
let ctx = create_query_context().await?;
let flight_dispatcher = DatabendQueryFlightDispatcher::create();

let sessions = SessionManagerBuilder::create().build()?;
let rpc_session = sessions.create_rpc_session(query_id.clone(), false)?;
let rpc_session = sessions.create_rpc_session(query_id.clone(), false).await?;

flight_dispatcher
.shuffle_action(
Expand Down
2 changes: 1 addition & 1 deletion query/tests/it/api/rpc/flight_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ fn do_get_request(query_id: &str, stage_id: &str) -> Result<Request<Ticket>> {
}

async fn do_action_request(query_id: &str, stage_id: &str) -> Result<Request<Action>> {
let ctx = create_query_context()?;
let ctx = create_query_context().await?;
let flight_action = FlightAction::PrepareShuffleAction(ShuffleAction {
query_id: String::from(query_id),
stage_id: String::from(stage_id),
Expand Down
7 changes: 4 additions & 3 deletions query/tests/it/functions/context_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use common_base::tokio;
use common_exception::Result;
use databend_query::functions::ContextFunction;

#[test]
fn test_context_function_build_arg_from_ctx() -> Result<()> {
#[tokio::test]
async fn test_context_function_build_arg_from_ctx() -> Result<()> {
use pretty_assertions::assert_eq;
let ctx = crate::tests::create_query_context()?;
let ctx = crate::tests::create_query_context().await?;

// Ok.
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ async fn test_management_mode_access() -> Result<()> {
let conf = crate::tests::ConfigBuilder::create()
.with_management_mode()
.config();
let ctx = crate::tests::create_query_context_with_config(conf.clone())?;
let ctx = crate::tests::create_query_context_with_config(conf.clone()).await?;
// First to set tenant.
{
let plan = PlanParser::parse(ctx.clone(), "SUDO USE TENANT 'test'").await?;
Expand Down
4 changes: 2 additions & 2 deletions query/tests/it/interpreters/interpreter_admin_use_tenant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async fn test_use_tenant_interpreter() -> Result<()> {
let conf = crate::tests::ConfigBuilder::create()
.with_management_mode()
.config();
let ctx = crate::tests::create_query_context_with_config(conf.clone())?;
let ctx = crate::tests::create_query_context_with_config(conf.clone()).await?;

let plan = PlanParser::parse(ctx.clone(), "SUDO USE TENANT 't1'").await?;
let interpreter = InterpreterFactory::get(ctx.clone(), plan)?;
Expand All @@ -41,7 +41,7 @@ async fn test_use_tenant_interpreter() -> Result<()> {

#[tokio::test]
async fn test_use_tenant_interpreter_error() -> Result<()> {
let ctx = crate::tests::create_query_context()?;
let ctx = crate::tests::create_query_context().await?;

let plan = PlanParser::parse(ctx.clone(), "SUDO USE TENANT 't1'").await?;
let interpreter = InterpreterFactory::get(ctx, plan)?;
Expand Down
Loading

0 comments on commit 8f24cc4

Please sign in to comment.