-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
connection.rs
265 lines (231 loc) · 10 KB
/
connection.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
use std::collections::HashMap;
use std::convert::TryInto;
use std::ops::Range;
use std::sync::Arc;
use futures_core::future::BoxFuture;
use futures_util::TryFutureExt;
use crate::connection::{Connect, Connection};
use crate::executor::Executor;
use crate::postgres::protocol::{
Authentication, AuthenticationMd5, AuthenticationSasl, Message, PasswordMessage,
StartupMessage, StatementId, Terminate, TypeFormat,
};
use crate::postgres::stream::PgStream;
use crate::postgres::{sasl, tls};
use crate::url::Url;
/// An asynchronous connection to a [Postgres][super::Postgres] database.
///
/// The connection string expected by [Connect::connect] should be a PostgreSQL connection
/// string, as documented at
/// <https://www.postgresql.org/docs/12/libpq-connect.html#LIBPQ-CONNSTRING>
///
/// ### TLS Support (requires `tls` feature)
/// This connection type supports the same `sslmode` query parameter that `libpq` does in
/// connection strings: <https://www.postgresql.org/docs/12/libpq-ssl.html>
///
/// ```text
/// postgresql://<user>[:<password>]@<host>[:<port>]/<database>[?sslmode=<ssl-mode>[&sslrootcert=<path>]]
/// ```
/// where
/// ```text
/// ssl-mode = disable | allow | prefer | require | verify-ca | verify-full
/// path = percent (URL) encoded path on the local machine
/// ```
///
/// If the `tls` feature is not enabled, `disable`, `allow` and `prefer` are no-ops and `require`,
/// `verify-ca` and `verify-full` are forbidden (attempting to connect with these will return
/// an error).
///
/// If the `tls` feature is enabled, an upgrade to TLS is attempted on every connection by default
/// (equivalent to `sslmode=prefer`). If the server does not support TLS (because it was not
/// started with a valid certificate and key, see <https://www.postgresql.org/docs/12/ssl-tcp.html>)
/// then it falls back to an unsecured connection and logs a warning.
///
/// Add `sslmode=require` to your connection string to emit an error if the TLS upgrade fails.
///
/// If you're running Postgres locally, your connection string might look like this:
/// ```text
/// postgresql://root:password@localhost/my_database?sslmode=require
/// ```
///
/// However, as with `libpq`, the server certificate is **not** checked for validity by default.
///
/// Specifying `sslmode=verify-ca` will cause the TLS upgrade to verify the server's SSL
/// certificate against a local CA root certificate; this is not the system root certificate
/// but is instead expected to be specified in one of a few ways:
///
/// * The path to the certificate can be specified by adding the `sslrootcert` query parameter
///   to the connection string. (Remember to percent-encode it!)
///
/// * The path may also be specified via the `PGSSLROOTCERT` environment variable (which
///   should *not* be percent-encoded.)
///
/// * Otherwise, the library will look for the Postgres global root CA certificate in the default
///   location:
///
///   * `$HOME/.postgresql/root.crt` on POSIX systems
///   * `%APPDATA%\postgresql\root.crt` on Windows
///
/// These locations are documented here: <https://www.postgresql.org/docs/12/libpq-ssl.html#LIBQ-SSL-CERTIFICATES>
/// If the root certificate cannot be found by any of these means then the TLS upgrade will fail.
///
/// If `sslmode=verify-full` is specified, in addition to checking the certificate as with
/// `sslmode=verify-ca`, the hostname in the connection string will be verified
/// against the hostname in the server certificate, so they must be the same for the TLS
/// upgrade to succeed.
pub struct PgConnection {
    pub(super) stream: PgStream,
    pub(super) next_statement_id: u32,
    pub(super) is_ready: bool,

    // Per-statement caches; `PgConnection::new` creates all three empty.
    pub(super) cache_statement: HashMap<Box<str>, StatementId>,
    pub(super) cache_statement_columns: HashMap<StatementId, Arc<HashMap<Box<str>, usize>>>,
    pub(super) cache_statement_formats: HashMap<StatementId, Arc<[TypeFormat]>>,

    // Work buffer for the value ranges of the current row
    // This is used as the backing memory for each Row's value indexes
    pub(super) current_row_values: Vec<Option<Range<u32>>>,
}
// https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3
/// Runs the start-up phase of the Postgres wire protocol over an already-open `stream`.
///
/// Sends a `StartupMessage` built from `url` plus a fixed set of session parameters, then
/// answers the server's authentication request (none, cleartext password, MD5 password, or
/// SCRAM-SHA-256 through SASL). `BackendKeyData` and `ParameterStatus` messages are ignored;
/// the function returns once `ReadyForQuery` is received.
///
/// # Errors
///
/// Propagates errors from reading/writing the stream and from decoding server messages, and
/// returns a protocol error when the server requests an unsupported authentication method,
/// offers no SCRAM SASL mechanism, or sends a message that is not expected during start-up.
async fn startup(stream: &mut PgStream, url: &Url) -> crate::Result<()> {
    // Both the role and the database name fall back to `postgres` when the URL omits them.
    let user = url.username().unwrap_or("postgres");
    let dbname = url.database().unwrap_or("postgres");

    // Session parameters sent along with the StartupMessage. More runtime parameters are listed
    // at https://www.postgresql.org/docs/12/runtime-config-client.html
    let params = &[
        ("user", user),
        ("database", dbname),
        // Output format for date/time values and the rule for reading ambiguous date input.
        ("DateStyle", "ISO, MDY"),
        // Output format for interval values.
        ("IntervalStyle", "iso_8601"),
        // Time zone used to display and interpret time stamps.
        ("TimeZone", "UTC"),
        // Ask the server for precise floating point output.
        // NOTE: precise output is already the default in Postgres 12+.
        ("extra_float_digits", "3"),
        // Client-side character set.
        ("client_encoding", "UTF-8"),
    ];

    stream.write(StartupMessage { params });
    stream.flush().await?;

    loop {
        let message = stream.read().await?;

        match message {
            Message::Authentication => {
                let request = Authentication::read(stream.buffer())?;

                match request {
                    // The server let us in without asking for a password.
                    Authentication::Ok => {}

                    Authentication::CleartextPassword => {
                        stream.write(PasswordMessage::ClearText(
                            &url.password().unwrap_or_default(),
                        ));
                        stream.flush().await?;
                    }

                    Authentication::Md5Password => {
                        // TODO: Just reference the salt instead of returning a stack array
                        // TODO: Better way to make sure we skip the first 4 bytes here
                        let md5 = AuthenticationMd5::read(&stream.buffer()[4..])?;

                        stream.write(PasswordMessage::Md5 {
                            password: &url.password().unwrap_or_default(),
                            user,
                            salt: md5.salt,
                        });
                        stream.flush().await?;
                    }

                    Authentication::Sasl => {
                        // TODO: Make this iterative for traversing the mechanisms to remove the allocation
                        // TODO: Better way to make sure we skip the first 4 bytes here
                        let offered = AuthenticationSasl::read(&stream.buffer()[4..])?;

                        // Note which SCRAM variants the server offers; anything else is logged.
                        let mut offers_scram = false;
                        let mut offers_scram_plus = false;

                        for mechanism in &*offered.mechanisms {
                            match &**mechanism {
                                "SCRAM-SHA-256" => offers_scram = true,
                                "SCRAM-SHA-256-PLUS" => offers_scram_plus = true,
                                other => log::info!("unsupported auth mechanism: {}", other),
                            }
                        }

                        if !(offers_scram || offers_scram_plus) {
                            return Err(protocol_err!(
                                "unsupported SASL auth mechanisms: {:?}",
                                offered.mechanisms
                            )
                            .into());
                        }

                        // TODO: Handle -PLUS differently if we're in a TLS stream
                        sasl::authenticate(stream, user, &url.password().unwrap_or_default())
                            .await?;
                    }

                    unsupported => {
                        return Err(protocol_err!(
                            "requested unsupported authentication: {:?}",
                            unsupported
                        )
                        .into());
                    }
                }
            }

            // Informational messages whose values are not kept.
            // todo: we should care about BackendKeyData and store it on the connection
            Message::BackendKeyData | Message::ParameterStatus => {}

            // Start-up is complete; the connection can now accept queries.
            Message::ReadyForQuery => return Ok(()),

            unexpected => {
                return Err(protocol_err!("unexpected message: {:?}", unexpected).into());
            }
        }
    }
}
// https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.10
/// Ends the session by sending `Terminate`, flushing it to the server, and shutting the
/// stream down.
///
/// Takes the stream by value, so it cannot be used after this call.
///
/// # Errors
///
/// Returns any error from flushing the buffered message or from shutting down the stream.
/// If the flush fails, `shutdown` is not attempted.
async fn terminate(mut stream: PgStream) -> crate::Result<()> {
    stream.write(Terminate);

    // Push the buffered `Terminate` out before closing our side of the stream.
    stream.flush().await?;
    stream.shutdown()?;

    Ok(())
}
impl PgConnection {
pub(super) async fn new(url: crate::Result<Url>) -> crate::Result<Self> {
let url = url?;
let mut stream = PgStream::new(&url).await?;
tls::request_if_needed(&mut stream, &url).await?;
startup(&mut stream, &url).await?;
Ok(Self {
stream,
current_row_values: Vec::with_capacity(10),
next_statement_id: 1,
is_ready: true,
cache_statement: HashMap::new(),
cache_statement_columns: HashMap::new(),
cache_statement_formats: HashMap::new(),
})
}
}
impl Connect for PgConnection {
    /// Converts `url` into a [`Url`] and opens a new connection with it.
    ///
    /// The conversion result is passed along unchecked: if it failed, the error is
    /// returned as the output of the returned future.
    fn connect<T>(url: T) -> BoxFuture<'static, crate::Result<PgConnection>>
    where
        T: TryInto<Url, Error = crate::Error>,
        Self: Sized,
    {
        let parsed = url.try_into();

        Box::pin(Self::new(parsed))
    }
}
impl Connection for PgConnection {
fn close(self) -> BoxFuture<'static, crate::Result<()>> {
Box::pin(terminate(self.stream))
}
fn ping(&mut self) -> BoxFuture<crate::Result<()>> {
Box::pin(Executor::execute(self, "SELECT 1").map_ok(|_| ()))
}
}