Skip to content

Commit d0e3081

Browse files
committed
Add auth_query feature
Adds a feature that allows setting auth passthrough for md5 auth. It adds 4 new general config parameters: - `auth_query`: An string containing a query that will be executed on boot to obtain the hash of a given user. This query have to use a placeholder `$1`, so pgcat can replace it with the user its trying to fetch the hash from. - `auth_query_user`: The user to use for connecting to the server and executing the auth_query. - `auth_query_password`: The password to use for connecting to the server and executing the auth_query. - `auth_query_database`: The database to use for connecting to the server and executing the auth_query. The behavior is, at boot time, when validating server connections, a hash is fetched per server and stored there. When new server connections are created, that hash is used for creating them, if the hash could not be obtained for whatever reason, it falls back to the password set. Client connections are also authenticated using the obtained hash.
1 parent bd675ea commit d0e3081

15 files changed

+413
-19
lines changed

.circleci/pgcat.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ tls_private_key = ".circleci/server.key"
5050
admin_username = "admin_user"
5151
admin_password = "admin_pass"
5252

53+
auth_query = "SELECT * FROM public.user_lookup('$1');"
54+
auth_query_user = "md5_auth_user"
55+
auth_query_password = "secret"
56+
auth_query_database = "postgres"
57+
5358
# pool
5459
# configs are structured as pool.<pool_name>
5560
# the pool_name is what clients use as database name when connecting

.circleci/run_tests.sh

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,21 @@ curl --fail localhost:9930/metrics
4343
export PGPASSWORD=sharding_user
4444
export PGDATABASE=sharded_db
4545

46+
# Query auth test, we check that without passwords (and query auth) we can connect.
47+
PGDATABASE=postgres PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 5432 -U postgres -f tests/sharding/query_auth_setup.sql
48+
sed -i 's/^password =/# password =/' .circleci/pgcat.toml
49+
kill -SIGTERM $(pgrep pgcat) # restart config
50+
start_pgcat "info"
51+
psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'SELECT 1'
52+
53+
# Query auth test, we check that when we have issues executing auth_query, passwords are used.
54+
# Also, that reload fetches new passwords
55+
sed -i 's/^# password =/password =/' .circleci/pgcat.toml
56+
PGDATABASE=postgres PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 5432 -U postgres -f tests/sharding/query_auth_wrong_setup.sql
57+
kill -SIGTERM $(pgrep pgcat) # restart config
58+
start_pgcat "info"
59+
psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'SELECT 1'
60+
4661
# pgbench test
4762
pgbench -U sharding_user -i -h 127.0.0.1 -p 6432
4863
pgbench -U sharding_user -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol simple -f tests/pgbench/simple.sql

Cargo.lock

Lines changed: 32 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ rustls-pemfile = "1"
3434
hyper = { version = "0.14", features = ["full"] }
3535
phf = { version = "0.11.1", features = ["macros"] }
3636
exitcode = "1.1.2"
37+
postgres-protocol = "0.6.4"
38+
fallible-iterator = "0.2"
3739

3840
[target.'cfg(not(target_env = "msvc"))'.dependencies]
3941
jemallocator = "0.5.0"

src/auth_passthrough.rs

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
use crate::errors::Error;
2+
use crate::messages::simple_query;
3+
4+
use crate::server::Server;
5+
use crate::stats::Reporter;
6+
7+
use bytes::BytesMut;
8+
use fallible_iterator::FallibleIterator;
9+
use parking_lot::Mutex;
10+
11+
use log::{debug, trace, warn};
12+
use postgres_protocol::message;
13+
use std::collections::HashMap;
14+
use std::sync::Arc;
15+
16+
use crate::config::get_config;
17+
use crate::pool::ClientServerMap;
18+
19+
pub struct AuthPassthrough {
20+
password: String,
21+
query: String,
22+
user: String,
23+
database: String,
24+
}
25+
26+
impl AuthPassthrough {
27+
/// Initializes an AuthPassthrough.
28+
pub fn new(query: &str, user: &str, password: &str, database: &str) -> Self {
29+
AuthPassthrough {
30+
password: password.to_string(),
31+
query: query.to_string(),
32+
user: user.to_string(),
33+
database: database.to_string(),
34+
}
35+
}
36+
37+
/// Returns an AuthPassthrough given the configuration.
38+
/// If any of required values is not set, None is returned.
39+
pub fn from_config() -> Option<Self> {
40+
let config = get_config();
41+
42+
if config.general.auth_query_password.is_some()
43+
&& config.general.auth_query_user.is_some()
44+
&& config.general.auth_query_password.is_some()
45+
&& config.general.auth_query_database.is_some()
46+
{
47+
return Some(AuthPassthrough::new(
48+
config.general.auth_query.as_ref().unwrap(),
49+
config.general.auth_query_user.as_ref().unwrap(),
50+
config.general.auth_query_password.as_ref().unwrap(),
51+
config.general.auth_query_database.as_ref().unwrap(),
52+
));
53+
}
54+
55+
None
56+
}
57+
58+
/// Connects to server and executes auth_query for the specidief user.
59+
/// If the response is a row with two columns containing the user
60+
/// and its MD5 hash, the hash returned.
61+
///
62+
/// Note that the query is executed, changing $1 with the name of the user
63+
/// this is so we only hold in memory (and transfer) the least amount of 'sensitive' data.
64+
/// Also, it is compatible with pgbouncer.
65+
///
66+
/// # Arguments
67+
///
68+
/// * `address` - An Address of the server we want to connect to.
69+
/// * `user` - A user that will be used to obtain the hash.
70+
///
71+
/// # Examples
72+
///
73+
/// ```
74+
/// use pgcat::auth_passthrough::AuthPassthrough;
75+
/// use pgcat::config::Address;
76+
/// let auth_passthrough = AuthPassthrough::new("SELECT * FROM public.user_lookup('$1');", "postgres", "postgres", "postgres");
77+
/// auth_passthrough.fetch_hash(&Address::default(), "foo");
78+
/// ```
79+
///
80+
pub async fn fetch_hash(
81+
&self,
82+
address: &crate::config::Address,
83+
user: &str,
84+
) -> Result<String, Error> {
85+
let auth_user = crate::config::User {
86+
username: self.user.clone(),
87+
password: Some(self.password.clone()),
88+
pool_size: 1,
89+
statement_timeout: 0,
90+
};
91+
92+
let reporter: crate::stats::Reporter = Reporter::default();
93+
let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new()));
94+
95+
debug!("Connecting to server to obtain auth hashes.");
96+
match Server::startup(
97+
0,
98+
address,
99+
&auth_user,
100+
&self.database,
101+
client_server_map,
102+
reporter,
103+
&None,
104+
)
105+
.await
106+
{
107+
Ok(mut server) => {
108+
debug!("Connected!, sending auth query.");
109+
let auth_query = self.query.replace("$1", user);
110+
send_auth_query(&mut server, &auth_query).await?;
111+
debug!("Auth query sent ({}), waiting for data.", auth_query);
112+
let mut message = recv_data(&mut server).await?;
113+
114+
match parse_query_message(&mut message).await {
115+
Ok(password_data) => {
116+
if password_data.len() == 2 && password_data.first().unwrap() == user {
117+
Ok(password_data.last().unwrap().to_string())
118+
} else {
119+
Err(Error::AuthPassthroughError(
120+
"Data obtained from query does not follow the scheme 'user','hash'."
121+
.to_string(),
122+
))
123+
}
124+
}
125+
Err(err) => {
126+
Err(Error::AuthPassthroughError(format!("Error trying to obtain password from auth_query, ignoring hash for user '{}'. Error: {:?}",
127+
user, err)))
128+
}
129+
}
130+
}
131+
Err(err) => Err(Error::AuthPassthroughError(format!(
132+
"Error trying to connect to {} to fetch password shadows using username '{}', {:?}",
133+
address.host, self.user, err
134+
))),
135+
}
136+
}
137+
}
138+
139+
async fn parse_query_message(message: &mut BytesMut) -> Result<Vec<String>, Error> {
140+
let mut pair = Vec::<String>::new();
141+
match message::backend::Message::parse(message) {
142+
Ok(Some(message::backend::Message::RowDescription(_description))) => {}
143+
Ok(Some(message::backend::Message::ErrorResponse(err))) => {
144+
return Err(Error::ProtocolSyncError(format!(
145+
"Protocol error parsing response. Err: {:?}",
146+
err.fields()
147+
.iterator()
148+
.fold(String::default(), |acc, element| acc
149+
+ element.unwrap().value())
150+
)))
151+
}
152+
Ok(_) => {
153+
return Err(Error::ProtocolSyncError(
154+
"Protocol error, expected Row Description.".to_string(),
155+
))
156+
}
157+
Err(err) => {
158+
return Err(Error::ProtocolSyncError(format!(
159+
"Protocol error parsing response. Err: {:?}",
160+
err
161+
)))
162+
}
163+
}
164+
165+
while !message.is_empty() {
166+
match message::backend::Message::parse(message) {
167+
Ok(postgres_message) => {
168+
match postgres_message {
169+
Some(message::backend::Message::DataRow(data)) => {
170+
let buf = data.buffer();
171+
trace!("Data: {:?}", buf);
172+
173+
for item in data.ranges().iterator() {
174+
match item.as_ref() {
175+
Ok(range) => match range {
176+
Some(range) => {
177+
pair.push(String::from_utf8_lossy(&buf[range.clone()]).to_string());
178+
}
179+
None => return Err(Error::ProtocolSyncError(String::from(
180+
"Data expected while receiving query auth data, found nothing.",
181+
))),
182+
},
183+
Err(err) => {
184+
return Err(Error::ProtocolSyncError(format!(
185+
"Data error, err: {:?}",
186+
err
187+
)))
188+
}
189+
}
190+
}
191+
}
192+
Some(message::backend::Message::CommandComplete(_)) => {}
193+
Some(message::backend::Message::ReadyForQuery(_)) => {}
194+
_ => {
195+
return Err(Error::ProtocolSyncError(
196+
"Unexpected message while receiving auth query data.".to_string(),
197+
))
198+
}
199+
}
200+
}
201+
Err(err) => {
202+
return Err(Error::ProtocolSyncError(format!(
203+
"Parse error, err: {:?}",
204+
err
205+
)))
206+
}
207+
};
208+
}
209+
Ok(pair)
210+
}
211+
212+
async fn recv_data(server: &mut Server) -> Result<BytesMut, Error> {
213+
let mut message = BytesMut::new();
214+
loop {
215+
match server.recv().await {
216+
Ok(data) => message.extend_from_slice(&data[..]),
217+
Err(err) => {
218+
return Err(Error::AuthPassthroughError(format!(
219+
"Error receiving data from server err: {:?}",
220+
err
221+
)))
222+
}
223+
}
224+
225+
if !server.is_data_available() {
226+
break;
227+
}
228+
}
229+
Ok(message)
230+
}
231+
232+
async fn send_auth_query(server: &mut Server, query: &str) -> Result<(), Error> {
233+
return match server.send(simple_query(query)).await {
234+
Ok(()) => Ok(()),
235+
Err(err) => {
236+
let message = format!(
237+
"Error trying to connect to {} to fetch password shadows using username {}, {:?}",
238+
server.address().host,
239+
server.address().username,
240+
err
241+
);
242+
warn!("{}", message);
243+
Err(Error::AuthPassthroughError(message))
244+
}
245+
};
246+
}

src/client.rs

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -493,13 +493,27 @@ where
493493
};
494494

495495
// Compare server and client hashes.
496-
let password_hash = md5_hash_password(
497-
username,
498-
&pool.settings.user.password.as_ref().unwrap(),
499-
&salt,
500-
);
496+
let mut password_hash = if let Some(password) = pool.settings.user.password.as_ref() {
497+
Some(md5_hash_password(username, password, &salt))
498+
} else {
499+
None
500+
};
501501

502-
if password_hash != password_response {
502+
if let Some(hash) = pool.auth_hash.as_ref() {
503+
if let Some(stripped_hash) = hash.strip_prefix("md5") {
504+
password_hash = Some(md5_hash_second_pass(stripped_hash, &salt));
505+
} else {
506+
warn!("Obtained hash from auth_query does not seem to be in md5 format, ignoring.");
507+
}
508+
}
509+
if password_hash.is_none() {
510+
warn!("Clien auth is not possible, you either have not set a valid auth_query or a password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name);
511+
wrong_password(&mut write, username).await?;
512+
513+
return Err(Error::ClientError(format!("Invalid client auth {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name)));
514+
}
515+
516+
if password_hash.unwrap() != password_response {
503517
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name);
504518
wrong_password(&mut write, username).await?;
505519

src/errors.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@ pub enum Error {
1414
StatementTimeout,
1515
ShuttingDown,
1616
AuthError(String),
17+
AuthPassthroughError(String),
1718
}

0 commit comments

Comments
 (0)