diff --git a/Cargo.lock b/Cargo.lock index 17333b946902b..fb1195b008152 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3171,6 +3171,7 @@ dependencies = [ "semver 1.0.14", "serde", "serde_json", + "serde_urlencoded", "temp-env", "tempfile", "time 0.3.15", diff --git a/src/query/service/Cargo.toml b/src/query/service/Cargo.toml index 6f98c65e7a6fa..47c01a812e6b0 100644 --- a/src/query/service/Cargo.toml +++ b/src/query/service/Cargo.toml @@ -115,6 +115,7 @@ regex = "1.6.0" semver = "1.0.14" serde = { workspace = true } serde_json = { workspace = true } +serde_urlencoded = "0.7.1" tempfile = { version = "3.3.0", optional = true } thrift = { package = "databend-thrift", version = "0.17.0", optional = true } time = "0.3.14" diff --git a/src/query/service/src/servers/http/middleware.rs b/src/query/service/src/servers/http/middleware.rs index 43e2ac3047cfa..cea949ca9c18b 100644 --- a/src/query/service/src/servers/http/middleware.rs +++ b/src/query/service/src/servers/http/middleware.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; use std::sync::Arc; use common_exception::ErrorCode; @@ -20,6 +21,7 @@ use headers::authorization::Basic; use headers::authorization::Bearer; use headers::authorization::Credentials; use http::header::AUTHORIZATION; +use http::HeaderValue; use poem::error::Error as PoemError; use poem::error::Result as PoemResult; use poem::http::StatusCode; @@ -49,8 +51,8 @@ impl HTTPSessionMiddleware { } fn get_credential(req: &Request, kind: HttpHandlerKind) -> Result { - let auth_headers: Vec<_> = req.headers().get_all(AUTHORIZATION).iter().collect(); - if auth_headers.len() > 1 { + let std_auth_headers: Vec<_> = req.headers().get_all(AUTHORIZATION).iter().collect(); + if std_auth_headers.len() > 1 { let msg = &format!("Multiple {} headers detected", AUTHORIZATION); return Err(ErrorCode::AuthenticateFailure(msg)); } @@ -59,26 +61,24 @@ fn get_credential(req: &Request, kind: HttpHandlerKind) -> Result { Addr::Custom(..) => Some("127.0.0.1".to_string()), _ => None, }; - if auth_headers.is_empty() { - if let HttpHandlerKind::Clickhouse = kind { - let (user, key) = ( - req.headers().get("X-CLICKHOUSE-USER"), - req.headers().get("X-CLICKHOUSE-KEY"), - ); - if let (Some(name), Some(password)) = (user, key) { - let c = Credential::Password { - name: String::from_utf8(name.as_bytes().to_vec()).unwrap(), - password: Some(password.as_bytes().to_vec()), - hostname: client_ip, - }; - return Ok(c); - } + if std_auth_headers.is_empty() { + if matches!(kind, HttpHandlerKind::Clickhouse) { + auth_clickhouse_name_password(req, client_ip) + } else { + Err(ErrorCode::AuthenticateFailure( + "No authorization header detected", + )) } - return Err(ErrorCode::AuthenticateFailure( - "No authorization header detected", - )); + } else { + auth_by_header(&std_auth_headers, client_ip) } - let value = auth_headers[0]; +} + +fn auth_by_header( + std_auth_headers: &[&HeaderValue], + client_ip: Option, +) -> Result { + let value = &std_auth_headers[0]; if value.as_bytes().starts_with(b"Basic ") { match Basic::decode(value) { Some(basic) => { @@ -107,6 +107,37 @@ fn get_credential(req: &Request, kind: HttpHandlerKind) -> Result { } } +fn auth_clickhouse_name_password(req: &Request, client_ip: Option) -> Result { + let (user, key) = ( + req.headers().get("X-CLICKHOUSE-USER"), + req.headers().get("X-CLICKHOUSE-KEY"), + ); + if let (Some(name), Some(password)) = (user, key) { + let c = Credential::Password { + name: String::from_utf8(name.as_bytes().to_vec()).unwrap(), + password: Some(password.as_bytes().to_vec()), + hostname: client_ip, + }; + Ok(c) + } else { + let query_str = req.uri().query().unwrap_or_default(); + let query_params = serde_urlencoded::from_str::>(query_str) + .map_err(|e| ErrorCode::BadArguments(format!("{}", e)))?; + let (user, key) = (query_params.get("user"), query_params.get("password")); + if let (Some(name), Some(password)) = (user, key) { + Ok(Credential::Password { + name: name.clone(), + password: Some(password.as_bytes().to_vec()), + hostname: client_ip, + }) + } else { + Err(ErrorCode::AuthenticateFailure( + "No header or query parameters for authorization detected", + )) + } + } +} + impl Middleware for HTTPSessionMiddleware { type Output = HTTPSessionEndpoint; fn transform(&self, ep: E) -> Self::Output { diff --git a/tests/suites/0_stateless/14_clickhouse_http_handler/14_0003_http_auth.result b/tests/suites/0_stateless/14_clickhouse_http_handler/14_0003_http_auth.result index 6ed281c757a96..e8183f05f5db6 100755 --- a/tests/suites/0_stateless/14_clickhouse_http_handler/14_0003_http_auth.result +++ b/tests/suites/0_stateless/14_clickhouse_http_handler/14_0003_http_auth.result @@ -1,2 +1,3 @@ 1 1 +1 diff --git a/tests/suites/0_stateless/14_clickhouse_http_handler/14_0003_http_auth.sh b/tests/suites/0_stateless/14_clickhouse_http_handler/14_0003_http_auth.sh index 066e744ce5f0d..216010300abc6 100755 --- a/tests/suites/0_stateless/14_clickhouse_http_handler/14_0003_http_auth.sh +++ b/tests/suites/0_stateless/14_clickhouse_http_handler/14_0003_http_auth.sh @@ -9,7 +9,6 @@ curl -s -u root: -XPOST "http://localhost:${QUERY_CLICKHOUSE_HTTP_HANDLER_PORT}" curl -s -u user1:abc123 -XPOST "http://localhost:${QUERY_CLICKHOUSE_HTTP_HANDLER_PORT}" -d 'select 1 FORMAT CSV' -## TODO(move this into separated port of clickhouse ?) curl -s -H 'X-ClickHouse-User: user1' -H 'X-ClickHouse-Key: abc123' -XPOST "http://localhost:${QUERY_CLICKHOUSE_HTTP_HANDLER_PORT}" -d 'select 1 FORMAT CSV' - +curl -s -XPOST "http://localhost:${QUERY_CLICKHOUSE_HTTP_HANDLER_PORT}?user=user1&password=abc123" -d 'select 1 FORMAT CSV'