Skip to content

Commit f79fff6

Browse files
committed
refactor(cubesql): Make Postgres authentication extensible
1 parent e8d81f2 commit f79fff6

File tree

8 files changed

+277
-108
lines changed

8 files changed

+277
-108
lines changed

rust/cubesql/cubesql/src/compile/test/mod.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ use crate::{
1616
},
1717
config::{ConfigObj, ConfigObjImpl},
1818
sql::{
19-
compiler_cache::CompilerCacheImpl, dataframe::batches_to_dataframe, AuthContextRef,
20-
AuthenticateResponse, HttpAuthContext, ServerManager, Session, SessionManager,
21-
SqlAuthService,
19+
compiler_cache::CompilerCacheImpl, dataframe::batches_to_dataframe,
20+
pg_auth_service::PostgresAuthServiceDefaultImpl, AuthContextRef, AuthenticateResponse,
21+
HttpAuthContext, ServerManager, Session, SessionManager, SqlAuthService,
2222
},
2323
transport::{
2424
CubeStreamReceiver, LoadRequestMeta, SpanId, SqlGenerator, SqlResponse, SqlTemplates,
@@ -607,6 +607,7 @@ async fn get_test_session_with_config_and_transport(
607607
let server = Arc::new(ServerManager::new(
608608
get_test_auth(),
609609
test_transport.clone(),
610+
Arc::new(PostgresAuthServiceDefaultImpl::new()),
610611
Arc::new(CompilerCacheImpl::new(config_obj.clone(), test_transport)),
611612
None,
612613
config_obj,

rust/cubesql/cubesql/src/config/mod.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ use crate::{
66
injection::{DIService, Injector},
77
processing_loop::{ProcessingLoop, ShutdownMode},
88
},
9-
sql::{PostgresServer, ServerManager, SessionManager, SqlAuthDefaultImpl, SqlAuthService},
9+
sql::{
10+
pg_auth_service::{PostgresAuthService, PostgresAuthServiceDefaultImpl},
11+
PostgresServer, ServerManager, SessionManager, SqlAuthDefaultImpl, SqlAuthService,
12+
},
1013
transport::{HttpTransport, TransportService},
1114
CubeError,
1215
};
@@ -302,6 +305,12 @@ impl Config {
302305
})
303306
.await;
304307

308+
self.injector
309+
.register_typed::<dyn PostgresAuthService, _, _, _>(|_| async move {
310+
Arc::new(PostgresAuthServiceDefaultImpl::new())
311+
})
312+
.await;
313+
305314
self.injector
306315
.register_typed::<dyn CompilerCache, _, _, _>(|i| async move {
307316
let config = i.get_service_typed::<dyn ConfigObj>().await;
@@ -319,6 +328,7 @@ impl Config {
319328
i.get_service_typed().await,
320329
i.get_service_typed().await,
321330
i.get_service_typed().await,
331+
i.get_service_typed().await,
322332
config.nonce().clone(),
323333
config.clone(),
324334
))

rust/cubesql/cubesql/src/sql/postgres/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
pub(crate) mod extended;
2+
pub mod pg_auth_service;
23
pub(crate) mod pg_type;
34
pub(crate) mod service;
45
pub(crate) mod shim;
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
use std::{collections::HashMap, fmt::Debug, sync::Arc};
2+
3+
use async_trait::async_trait;
4+
5+
use crate::{
6+
sql::{AuthContextRef, SqlAuthService},
7+
CubeError,
8+
};
9+
10+
pub use pg_srv::{
11+
buffer as pg_srv_buffer,
12+
protocol::{
13+
AuthenticationRequest, AuthenticationRequestExtension, FrontendMessage,
14+
FrontendMessageExtension,
15+
},
16+
MessageTagParser, MessageTagParserDefaultImpl, ProtocolError,
17+
};
18+
19+
#[derive(Debug)]
20+
pub enum AuthenticationStatus {
21+
UnexpectedFrontendMessage,
22+
Failed(String),
23+
// User name + auth context
24+
Success(String, AuthContextRef),
25+
}
26+
27+
#[async_trait]
28+
pub trait PostgresAuthService: Sync + Send + Debug {
29+
fn get_auth_method(&self, parameters: &HashMap<String, String>) -> AuthenticationRequest;
30+
31+
async fn authenticate(
32+
&self,
33+
service: Arc<dyn SqlAuthService>,
34+
request: AuthenticationRequest,
35+
secret: FrontendMessage,
36+
parameters: &HashMap<String, String>,
37+
) -> AuthenticationStatus;
38+
39+
fn get_pg_message_tag_parser(&self) -> Arc<dyn MessageTagParser>;
40+
}
41+
42+
#[derive(Debug)]
43+
pub struct PostgresAuthServiceDefaultImpl {
44+
pg_message_tag_parser: Arc<dyn MessageTagParser>,
45+
}
46+
47+
impl PostgresAuthServiceDefaultImpl {
48+
pub fn new() -> Self {
49+
Self {
50+
pg_message_tag_parser: Arc::new(MessageTagParserDefaultImpl::default()),
51+
}
52+
}
53+
}
54+
55+
#[async_trait]
56+
impl PostgresAuthService for PostgresAuthServiceDefaultImpl {
57+
fn get_auth_method(&self, _: &HashMap<String, String>) -> AuthenticationRequest {
58+
AuthenticationRequest::CleartextPassword
59+
}
60+
61+
async fn authenticate(
62+
&self,
63+
service: Arc<dyn SqlAuthService>,
64+
request: AuthenticationRequest,
65+
secret: FrontendMessage,
66+
parameters: &HashMap<String, String>,
67+
) -> AuthenticationStatus {
68+
let FrontendMessage::PasswordMessage(password_message) = secret else {
69+
return AuthenticationStatus::UnexpectedFrontendMessage;
70+
};
71+
72+
if !matches!(request, AuthenticationRequest::CleartextPassword) {
73+
return AuthenticationStatus::UnexpectedFrontendMessage;
74+
}
75+
76+
let user = parameters.get("user").unwrap().clone();
77+
let authenticate_response = service
78+
.authenticate(Some(user.clone()), Some(password_message.password.clone()))
79+
.await;
80+
81+
let auth_fail = || {
82+
AuthenticationStatus::Failed(format!(
83+
"password authentication failed for user \"{}\"",
84+
user
85+
))
86+
};
87+
88+
let Ok(authenticate_response) = authenticate_response else {
89+
return auth_fail();
90+
};
91+
92+
if !authenticate_response.skip_password_check {
93+
let is_password_correct = match authenticate_response.password {
94+
None => false,
95+
Some(password) => password == password_message.password,
96+
};
97+
if !is_password_correct {
98+
return auth_fail();
99+
}
100+
}
101+
102+
AuthenticationStatus::Success(user, authenticate_response.context)
103+
}
104+
105+
fn get_pg_message_tag_parser(&self) -> Arc<dyn MessageTagParser> {
106+
Arc::clone(&self.pg_message_tag_parser)
107+
}
108+
}
109+
110+
crate::di_service!(PostgresAuthServiceDefaultImpl, [PostgresAuthService]);

rust/cubesql/cubesql/src/sql/postgres/shim.rs

Lines changed: 60 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::{
33
time::SystemTime,
44
};
55

6-
use super::extended::PreparedStatement;
6+
use super::{extended::PreparedStatement, pg_auth_service::AuthenticationStatus};
77
use crate::{
88
compile::{
99
convert_statement_to_cube_query,
@@ -24,8 +24,11 @@ use crate::{
2424
use futures::{pin_mut, FutureExt, StreamExt};
2525
use log::{debug, error, trace};
2626
use pg_srv::{
27-
buffer, protocol,
28-
protocol::{ErrorCode, ErrorResponse, Format, InitialMessage, PortalCompletion},
27+
buffer,
28+
protocol::{
29+
self, AuthenticationRequest, ErrorCode, ErrorResponse, Format, InitialMessage,
30+
PortalCompletion,
31+
},
2932
PgType, PgTypeId, ProtocolError,
3033
};
3134
use sqlparser::ast::{self, CloseCursor, FetchDirection, Query, SetExpr, Statement, Value};
@@ -46,10 +49,9 @@ pub struct AsyncPostgresShim {
4649
logger: Arc<dyn ContextLogger>,
4750
}
4851

49-
#[derive(PartialEq, Eq)]
5052
pub enum StartupState {
5153
// Initial parameters which client sends in the first message, we use it later in auth method
52-
Success(HashMap<String, String>),
54+
Success(HashMap<String, String>, AuthenticationRequest),
5355
SslRequested,
5456
Denied,
5557
CancelRequest,
@@ -313,25 +315,23 @@ impl AsyncPostgresShim {
313315
}
314316

315317
pub async fn run(&mut self) -> Result<(), ConnectionError> {
316-
let initial_parameters = match self.process_initial_message().await? {
317-
StartupState::Success(parameters) => parameters,
318+
let (initial_parameters, auth_method) = match self.process_initial_message().await? {
319+
StartupState::Success(parameters, auth_method) => (parameters, auth_method),
318320
StartupState::SslRequested => match self.process_initial_message().await? {
319-
StartupState::Success(parameters) => parameters,
321+
StartupState::Success(parameters, auth_method) => (parameters, auth_method),
320322
_ => return Ok(()),
321323
},
322324
StartupState::Denied | StartupState::CancelRequest => return Ok(()),
323325
};
324326

325-
match buffer::read_message(&mut self.socket).await? {
326-
protocol::FrontendMessage::PasswordMessage(password_message) => {
327-
if !self
328-
.authenticate(password_message, initial_parameters)
329-
.await?
330-
{
331-
return Ok(());
332-
}
333-
}
334-
_ => return Ok(()),
327+
let message_tag_parser = self.session.server.pg_auth.get_pg_message_tag_parser();
328+
let auth_secret =
329+
buffer::read_message(&mut self.socket, Arc::clone(&message_tag_parser)).await?;
330+
if !self
331+
.authenticate(auth_method, auth_secret, initial_parameters)
332+
.await?
333+
{
334+
return Ok(());
335335
}
336336

337337
self.ready().await?;
@@ -351,7 +351,7 @@ impl AsyncPostgresShim {
351351
true = async { semifast_shutdownable && { semifast_shutdown_interruptor.cancelled().await; true } } => {
352352
return Self::flush_and_write_admin_shutdown_fatal_message(self).await;
353353
}
354-
message_result = buffer::read_message(&mut self.socket) => message_result?
354+
message_result = buffer::read_message(&mut self.socket, Arc::clone(&message_tag_parser)) => message_result?
355355
};
356356

357357
let result = match message {
@@ -716,73 +716,62 @@ impl AsyncPostgresShim {
716716
return Ok(StartupState::Denied);
717717
}
718718

719-
self.write(protocol::Authentication::new(
720-
protocol::AuthenticationRequest::CleartextPassword,
721-
))
722-
.await?;
719+
let auth_method = self.session.server.pg_auth.get_auth_method(&parameters);
720+
self.write(protocol::Authentication::new(auth_method.clone()))
721+
.await?;
723722

724-
Ok(StartupState::Success(parameters))
723+
Ok(StartupState::Success(parameters, auth_method))
725724
}
726725

727726
pub async fn authenticate(
728727
&mut self,
729-
password_message: protocol::PasswordMessage,
728+
auth_request: AuthenticationRequest,
729+
auth_secret: protocol::FrontendMessage,
730730
parameters: HashMap<String, String>,
731731
) -> Result<bool, ConnectionError> {
732-
let user = parameters.get("user").unwrap().clone();
733-
let authenticate_response = self
732+
let auth_service = self.session.server.auth.clone();
733+
let auth_status = self
734734
.session
735735
.server
736-
.auth
737-
.authenticate(Some(user.clone()), Some(password_message.password.clone()))
736+
.pg_auth
737+
.authenticate(auth_service, auth_request, auth_secret, &parameters)
738738
.await;
739+
let result = match auth_status {
740+
AuthenticationStatus::UnexpectedFrontendMessage => Err((
741+
"invalid authorization specification".to_string(),
742+
protocol::ErrorCode::InvalidAuthorizationSpecification,
743+
)),
744+
AuthenticationStatus::Failed(err) => Err((err, protocol::ErrorCode::InvalidPassword)),
745+
AuthenticationStatus::Success(user, auth_context) => Ok((user, auth_context)),
746+
};
739747

740-
let mut auth_context: Option<AuthContextRef> = None;
748+
match result {
749+
Err((message, code)) => {
750+
let error_response = protocol::ErrorResponse::fatal(code, message);
751+
buffer::write_message(
752+
&mut self.partial_write_buf,
753+
&mut self.socket,
754+
error_response,
755+
)
756+
.await?;
741757

742-
let auth_success = match authenticate_response {
743-
Ok(authenticate_response) => {
744-
auth_context = Some(authenticate_response.context);
745-
if !authenticate_response.skip_password_check {
746-
match authenticate_response.password {
747-
None => false,
748-
Some(password) => password == password_message.password,
749-
}
750-
} else {
751-
true
752-
}
758+
Ok(false)
753759
}
754-
_ => false,
755-
};
756-
757-
if !auth_success {
758-
let error_response = protocol::ErrorResponse::fatal(
759-
protocol::ErrorCode::InvalidPassword,
760-
format!("password authentication failed for user \"{}\"", &user),
761-
);
762-
buffer::write_message(
763-
&mut self.partial_write_buf,
764-
&mut self.socket,
765-
error_response,
766-
)
767-
.await?;
760+
Ok((user, auth_context)) => {
761+
let database = parameters
762+
.get("database")
763+
.map(|v| v.clone())
764+
.unwrap_or("db".to_string());
765+
self.session.state.set_database(Some(database));
766+
self.session.state.set_user(Some(user));
767+
self.session.state.set_auth_context(Some(auth_context));
768+
769+
self.write(protocol::Authentication::new(AuthenticationRequest::Ok))
770+
.await?;
768771

769-
return Ok(false);
772+
Ok(true)
773+
}
770774
}
771-
772-
let database = parameters
773-
.get("database")
774-
.map(|v| v.clone())
775-
.unwrap_or("db".to_string());
776-
self.session.state.set_database(Some(database));
777-
self.session.state.set_user(Some(user));
778-
self.session.state.set_auth_context(auth_context);
779-
780-
self.write(protocol::Authentication::new(
781-
protocol::AuthenticationRequest::Ok,
782-
))
783-
.await?;
784-
785-
Ok(true)
786775
}
787776

788777
pub async fn ready(&mut self) -> Result<(), ConnectionError> {

rust/cubesql/cubesql/src/sql/server_manager.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use crate::{
44
sql::{
55
compiler_cache::CompilerCache,
66
database_variables::{mysql_default_global_variables, postgres_default_global_variables},
7+
pg_auth_service::PostgresAuthService,
78
SqlAuthService,
89
},
910
transport::TransportService,
@@ -37,6 +38,7 @@ pub struct ServerManager {
3738
// References to shared things
3839
pub auth: Arc<dyn SqlAuthService>,
3940
pub transport: Arc<dyn TransportService>,
41+
pub pg_auth: Arc<dyn PostgresAuthService>,
4042
// Non references
4143
pub configuration: ServerConfiguration,
4244
pub nonce: Option<Vec<u8>>,
@@ -52,13 +54,15 @@ impl ServerManager {
5254
pub fn new(
5355
auth: Arc<dyn SqlAuthService>,
5456
transport: Arc<dyn TransportService>,
57+
pg_auth: Arc<dyn PostgresAuthService>,
5558
compiler_cache: Arc<dyn CompilerCache>,
5659
nonce: Option<Vec<u8>>,
5760
config_obj: Arc<dyn ConfigObj>,
5861
) -> Self {
5962
Self {
6063
auth,
6164
transport,
65+
pg_auth,
6266
compiler_cache,
6367
nonce,
6468
config_obj,

0 commit comments

Comments
 (0)