diff --git a/query/src/datasources/database/system/processes_table.rs b/query/src/datasources/database/system/processes_table.rs index 208209df6ee16..6a70d2985c418 100644 --- a/query/src/datasources/database/system/processes_table.rs +++ b/query/src/datasources/database/system/processes_table.rs @@ -42,6 +42,7 @@ impl ProcessesTable { DataField::new("id", DataType::String, false), DataField::new("type", DataType::String, false), DataField::new("host", DataType::String, true), + DataField::new("user", DataType::String, true), DataField::new("state", DataType::String, false), DataField::new("database", DataType::String, false), DataField::new("extra_info", DataType::String, true), @@ -96,6 +97,7 @@ impl Table for ProcessesTable { let mut processes_id = Vec::with_capacity(processes_info.len()); let mut processes_type = Vec::with_capacity(processes_info.len()); let mut processes_host = Vec::with_capacity(processes_info.len()); + let mut processes_user = Vec::with_capacity(processes_info.len()); let mut processes_state = Vec::with_capacity(processes_info.len()); let mut processes_database = Vec::with_capacity(processes_info.len()); let mut processes_extra_info = Vec::with_capacity(processes_info.len()); @@ -107,6 +109,7 @@ impl Table for ProcessesTable { processes_state.push(process_info.state.clone().into_bytes()); processes_database.push(process_info.database.clone().into_bytes()); processes_host.push(ProcessesTable::process_host(process_info)); + processes_user.push(process_info.user.clone().into_bytes()); processes_extra_info.push(ProcessesTable::process_extra_info(process_info)); processes_memory_usage.push(process_info.memory_usage); } @@ -116,6 +119,7 @@ impl Table for ProcessesTable { Series::new(processes_id), Series::new(processes_type), Series::new(processes_host), + Series::new(processes_user), Series::new(processes_state), Series::new(processes_database), Series::new(processes_extra_info), diff --git a/query/src/servers/clickhouse/interactive_worker.rs b/query/src/servers/clickhouse/interactive_worker.rs index f2d33d883b171..5bdc8ade032c4 100644 --- a/query/src/servers/clickhouse/interactive_worker.rs +++ b/query/src/servers/clickhouse/interactive_worker.rs @@ -109,7 +109,10 @@ impl ClickHouseSession for InteractiveWorker { Err(err) => Err(err), }; match res { - Ok(res) => res, + Ok(res) => { + self.session.set_current_user(user.to_string()); + res + } Err(failure) => { log::error!( "ClickHouse handler authenticate failed, \ diff --git a/query/src/servers/mysql/mysql_interactive_worker.rs b/query/src/servers/mysql/mysql_interactive_worker.rs index 1fd22f9e5164a..085b3f6960f5e 100644 --- a/query/src/servers/mysql/mysql_interactive_worker.rs +++ b/query/src/servers/mysql/mysql_interactive_worker.rs @@ -208,19 +208,24 @@ impl InteractiveWorkerBase { let address = &info.user_client_address; let user_manager = self.session.get_user_manager(); - // TODO: use get_users and check client address + // TODO: list user's grant list and check client address let user_info = user_manager.get_user(user_name, "%").await?; let input = &info.user_password; let saved = &user_info.password; let encode_password = Self::encoding_password(auth_plugin, salt, input, saved)?; - user_manager + let authed = user_manager .auth_user( user_info, CertifiedInfo::create(user_name, encode_password, address), ) - .await + .await?; + if authed { + self.session.set_current_user(user_name.clone()); + } + + Ok(authed) } fn encoding_password( diff --git a/query/src/sessions/session.rs b/query/src/sessions/session.rs index dfa8e47eb9085..7aea91cf0812f 100644 --- a/query/src/sessions/session.rs +++ b/query/src/sessions/session.rs @@ -16,6 +16,7 @@ use std::net::SocketAddr; use std::sync::atomic::AtomicUsize; use std::sync::Arc; +use common_exception::ErrorCode; use common_exception::Result; use common_macros::MallocSizeOf; use common_mem_allocator::malloc_size; @@ -149,6 +150,16 @@ impl Session { self.mutable_state.get_current_database() } + pub fn get_current_user(self: &Arc) -> Result { + self.mutable_state + .get_current_user() + .ok_or_else(|| ErrorCode::AuthenticateFailure("unauthenticated")) + } + + pub fn set_current_user(self: &Arc, user: String) { + self.mutable_state.set_current_user(user) + } + pub fn get_settings(self: &Arc) -> Arc { self.mutable_state.get_settings() } diff --git a/query/src/sessions/session_info.rs b/query/src/sessions/session_info.rs index f33be6f00e639..cbd389123e9ae 100644 --- a/query/src/sessions/session_info.rs +++ b/query/src/sessions/session_info.rs @@ -24,6 +24,7 @@ pub struct ProcessInfo { pub typ: String, pub state: String, pub database: String, + pub user: String, #[allow(unused)] pub settings: Arc, pub client_address: Option, @@ -53,6 +54,7 @@ impl Session { typ: self.typ.clone(), state: self.process_state(status), database: status.get_current_database(), + user: status.get_current_user().unwrap_or_else(|| "".into()), settings: status.get_settings(), client_address: status.get_client_host(), session_extra_info: self.process_extra_info(status), diff --git a/query/src/sessions/session_status.rs b/query/src/sessions/session_status.rs index 5b837fe8208c9..688b9a91f7ede 100644 --- a/query/src/sessions/session_status.rs +++ b/query/src/sessions/session_status.rs @@ -30,6 +30,7 @@ pub struct MutableStatus { abort: AtomicBool, current_database: RwLock, session_settings: RwLock, + current_user: RwLock>, #[ignore_malloc_size_of = "insignificant"] client_host: RwLock>, #[ignore_malloc_size_of = "insignificant"] @@ -42,9 +43,10 @@ impl MutableStatus { pub fn try_create() -> Result { Ok(MutableStatus { abort: Default::default(), + current_user: Default::default(), + client_host: Default::default(), current_database: RwLock::new("default".to_string()), session_settings: RwLock::new(Settings::try_create()?.as_ref().clone()), - client_host: Default::default(), io_shutdown_tx: Default::default(), context_shared: Default::default(), }) @@ -72,6 +74,18 @@ impl MutableStatus { *lock = db } + // Set the current user after authentication + pub fn set_current_user(&self, user: String) { + let mut lock = self.current_user.write(); + *lock = Some(user); + } + + // Get current user + pub fn get_current_user(&self) -> Option { + let lock = self.current_user.read(); + lock.clone() + } + pub fn get_settings(&self) -> Arc { let lock = self.session_settings.read(); Arc::new(lock.clone()) diff --git a/query/src/sessions/session_status_test.rs b/query/src/sessions/session_status_test.rs index a411b9a037d00..a466539b51962 100644 --- a/query/src/sessions/session_status_test.rs +++ b/query/src/sessions/session_status_test.rs @@ -57,6 +57,14 @@ fn test_session_status() -> Result<()> { assert_eq!(Some(server), val); } + // Current user. + { + mutable_status.set_current_user("user1".to_string()); + + let val = mutable_status.get_current_user(); + assert_eq!(Some("user1".to_string()), val); + } + // io shutdown tx. { let (tx, _) = futures::channel::oneshot::channel(); diff --git a/query/src/sql/plan_parser_test.rs b/query/src/sql/plan_parser_test.rs index a27c82913e154..915c2fa6b8f6d 100644 --- a/query/src/sql/plan_parser_test.rs +++ b/query/src/sql/plan_parser_test.rs @@ -213,8 +213,8 @@ fn test_plan_parser() -> Result<()> { name: "show-processlist", sql: "show processlist", expect: "\ - Projection: id:String, type:String, host:String, state:String, database:String, extra_info:String, memory_usage:UInt64\ - \n ReadDataSource: scan partitions: [1], scan schema: [id:String, type:String, host:String;N, state:String, database:String, extra_info:String;N, memory_usage:UInt64;N], statistics: [read_rows: 0, read_bytes: 0]", + Projection: id:String, type:String, host:String, user:String, state:String, database:String, extra_info:String, memory_usage:UInt64\ + \n ReadDataSource: scan partitions: [1], scan schema: [id:String, type:String, host:String;N, user:String;N, state:String, database:String, extra_info:String;N, memory_usage:UInt64;N], statistics: [read_rows: 0, read_bytes: 0]", error: "", }, ];