Skip to content

Commit

Permalink
WIP: add support for mssql + TLS
Browse files Browse the repository at this point in the history
  • Loading branch information
lovasoa committed Oct 2, 2024
1 parent ed3b725 commit 916301f
Show file tree
Hide file tree
Showing 10 changed files with 300 additions and 30 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## 0.6.28

- Add sqlx version information to pre-login message in mssql
- Add support for encrypted Microsoft SQL server connections (using TLS)

## 0.6.27

- Fix pg i8 decode
Expand Down
15 changes: 15 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions sqlx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ runtime-tokio-rustls = [

# for conditional compilation
_rt-async-std = []
_rt-tokio = ["tokio-stream"]
_rt-tokio = ["tokio-stream", "tokio-util"]
_tls-native-tls = []
_tls-rustls = ["rustls", "rustls-pemfile", "webpki-roots"]

Expand Down Expand Up @@ -119,7 +119,7 @@ either = "1.6.1"
futures-channel = { version = "0.3.19", default-features = false, features = ["sink", "alloc", "std"] }
futures-core = { version = "0.3.19", default-features = false }
futures-intrusive = "0.5.0"
futures-util = { version = "0.3.19", default-features = false, features = ["alloc", "sink"] }
futures-util = { version = "0.3.19", default-features = false, features = ["alloc", "sink", "io"] }
# used by the SQLite worker thread to block on the async mutex that locks the database handle
futures-executor = { version = "0.3.19", optional = true }
flume = { version = "0.11.0", optional = true, default-features = false, features = ["async"] }
Expand Down Expand Up @@ -154,6 +154,7 @@ sqlformat = "0.2.0"
thiserror = "1.0.30"
time = { version = "0.3.2", features = ["macros", "formatting", "parsing"], optional = true }
tokio-stream = { version = "0.1.8", features = ["fs"], optional = true }
tokio-util = { version = "0.7.0", features = ["compat"], default-features = false, optional = true }
smallvec = "1.7.0"
url = { version = "2.2.2", default-features = false }
uuid = { version = "1.0", default-features = false, optional = true, features = ["std"] }
Expand Down
19 changes: 17 additions & 2 deletions sqlx-core/src/mssql/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,19 @@ impl MssqlConnection {
// TODO: Encryption
// TODO: Send the version of SQLx over

let encryption = match options.encrypt {
Some(true) => Encrypt::Required,
Some(false) => Encrypt::NotSupported,
None => Encrypt::On,
};

log::debug!("Sending T-SQL PRELOGIN with encryption: {:?}", encryption);

stream.write_packet(
PacketType::PreLogin,
PreLogin {
version: Version::default(),
encryption: Encrypt::NOT_SUPPORTED,
encryption,
instance: options.instance.clone(),

..Default::default()
Expand All @@ -32,7 +40,14 @@ impl MssqlConnection {
stream.flush().await?;

let (_, packet) = stream.recv_packet().await?;
let _ = PreLogin::decode(packet)?;
let prelogin_response = PreLogin::decode(packet)?;

if matches!(
prelogin_response.encryption,
Encrypt::Required | Encrypt::On
) {
stream.setup_encryption().await?;
}

// LOGIN7 defines the authentication rules for use between client and server

Expand Down
1 change: 1 addition & 0 deletions sqlx-core/src/mssql/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ mod establish;
mod executor;
mod prepare;
mod stream;
mod tls_prelogin_stream_wrapper;

pub struct MssqlConnection {
pub(crate) stream: MssqlStream,
Expand Down
33 changes: 26 additions & 7 deletions sqlx-core/src/mssql/connection/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use sqlx_rt::TcpStream;
use crate::error::Error;
use crate::ext::ustr::UStr;
use crate::io::{BufStream, Encode};
use crate::mssql::connection::tls_prelogin_stream_wrapper::TlsPreloginWrapper;
use crate::mssql::protocol::col_meta_data::ColMetaData;
use crate::mssql::protocol::done::{Done, Status as DoneStatus};
use crate::mssql::protocol::env_change::EnvChange;
Expand All @@ -19,12 +20,12 @@ use crate::mssql::protocol::return_status::ReturnStatus;
use crate::mssql::protocol::return_value::ReturnValue;
use crate::mssql::protocol::row::Row;
use crate::mssql::{MssqlColumn, MssqlConnectOptions, MssqlDatabaseError};
use crate::net::MaybeTlsStream;
use crate::net::{MaybeTlsStream, TlsConfig};
use crate::HashMap;
use std::sync::Arc;

pub(crate) struct MssqlStream {
inner: BufStream<MaybeTlsStream<TcpStream>>,
inner: BufStream<MaybeTlsStream<TlsPreloginWrapper<TcpStream>>>,

// how many Done (or Error) we are currently waiting for
pub(crate) pending_done_count: usize,
Expand All @@ -44,13 +45,15 @@ pub(crate) struct MssqlStream {

// Maximum size of packets to send to the server
pub(crate) max_packet_size: usize,

options: MssqlConnectOptions,
}

impl MssqlStream {
pub(super) async fn connect(options: &MssqlConnectOptions) -> Result<Self, Error> {
let inner = BufStream::new(MaybeTlsStream::Raw(
TcpStream::connect((&*options.host, options.port)).await?,
));
let tcp_stream = TcpStream::connect((&*options.host, options.port)).await?;
let wrapped_stream = TlsPreloginWrapper::new(tcp_stream);
let inner = BufStream::new(MaybeTlsStream::Raw(wrapped_stream));

Ok(Self {
inner,
Expand All @@ -64,6 +67,7 @@ impl MssqlStream {
.requested_packet_size
.try_into()
.unwrap_or(usize::MAX),
options: options.clone(),
})
}

Expand Down Expand Up @@ -206,10 +210,25 @@ impl MssqlStream {

Ok(())
}

pub(crate) async fn setup_encryption(&mut self) -> Result<(), Error> {
let tls_config = TlsConfig {
accept_invalid_certs: true,
hostname: &self.options.host,
accept_invalid_hostnames: true,
root_cert_path: None,
client_cert_path: None,
client_key_path: None,
};
self.inner.deref_mut().start_handshake();
self.inner.upgrade(tls_config).await?;
self.inner.deref_mut().handshake_complete();
Ok(())
}
}

// writes the packet out to the write buffer
fn write_packets<'en, T: Encode<'en>>(
pub(crate) fn write_packets<'en, T: Encode<'en>>(
buffer: &mut Vec<u8>,
max_packet_size: usize,
ty: PacketType,
Expand Down Expand Up @@ -313,7 +332,7 @@ fn test_write_packets() {
}

impl Deref for MssqlStream {
type Target = BufStream<MaybeTlsStream<TcpStream>>;
type Target = BufStream<MaybeTlsStream<TlsPreloginWrapper<TcpStream>>>;

fn deref(&self) -> &Self::Target {
&self.inner
Expand Down
176 changes: 176 additions & 0 deletions sqlx-core/src/mssql/connection/tls_prelogin_stream_wrapper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
// Original implementation from tiberius: https://github.com/prisma/tiberius/blob/main/src/client/tls.rs

use crate::mssql::protocol::packet::{PacketHeader, PacketType};

use super::stream::write_packets;

use crate::io::Decode;
use bytes::Bytes;
use sqlx_rt::{AsyncRead, AsyncWrite, ReadBuf};
use std::cmp;
use std::io;
use std::pin::Pin;
use std::task::{self, ready, Poll};

/// This wrapper handles TDS (Tabular Data Stream) packet encapsulation during the TLS handshake phase
/// of a connection to a Microsoft SQL Server.
///
/// In the PRELOGIN phase of the TDS protocol, all communication must be wrapped in TDS packets,
/// even during TLS negotiation. This presents a challenge when using standard TLS libraries,
/// which expect to work with raw TCP streams.
///
/// This wrapper solves the problem by:
/// 1. During handshake:
/// - For writes: It buffers outgoing data and wraps it in TDS packets before sending.
/// Each packet starts with an 8-byte header containing type (0x12 for PRELOGIN),
/// status flags, length, and other metadata.
/// - For reads: It strips the TDS packet headers from incoming data before passing
/// it to the TLS library.
/// 2. After handshake:
/// - It becomes transparent, directly passing through all reads and writes to the
/// underlying stream without modification.
///
/// This allows us to use standard TLS libraries while still conforming to the TDS protocol
/// requirements for the PRELOGIN phase.

const HEADER_BYTES: usize = 8;

pub(crate) struct TlsPreloginWrapper<S> {
stream: S,
pending_handshake: bool,

header_buf: [u8; HEADER_BYTES],
header_pos: usize,
read_remaining: usize,

wr_buf: Vec<u8>,
header_written: bool,
}

impl<S> TlsPreloginWrapper<S> {
pub fn new(stream: S) -> Self {
TlsPreloginWrapper {
stream,
pending_handshake: false,

header_buf: [0u8; HEADER_BYTES],
header_pos: 0,
read_remaining: 0,
wr_buf: Vec::new(),
header_written: false,
}
}

pub fn start_handshake(&mut self) {
self.pending_handshake = true;
}

pub fn handshake_complete(&mut self) {
self.pending_handshake = false;
}
}

impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncRead for TlsPreloginWrapper<S> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if !self.pending_handshake {
return Pin::new(&mut self.stream).poll_read(cx, buf);
}

let inner = self.get_mut();

if !inner.header_buf[inner.header_pos..].is_empty() {
while !inner.header_buf[inner.header_pos..].is_empty() {
let mut header_buf = ReadBuf::new(&mut inner.header_buf[inner.header_pos..]);
ready!(Pin::new(&mut inner.stream).poll_read(cx, &mut header_buf))?;

let read = header_buf.filled().len();
if read == 0 {
return Poll::Ready(Ok(()));
}

inner.header_pos += read;
}

let header: PacketHeader = Decode::decode(Bytes::copy_from_slice(&inner.header_buf))
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;

inner.read_remaining = usize::from(header.length) - HEADER_BYTES;

log::trace!(
"Discarding header ({:?}), reading packet of {} bytes",
header,
inner.read_remaining,
);
}

let max_read = std::cmp::min(inner.read_remaining, buf.remaining());
let mut limited_buf = buf.take(max_read);

ready!(Pin::new(&mut inner.stream).poll_read(cx, &mut limited_buf))?;

let read = limited_buf.filled().len();
buf.advance(read);
inner.read_remaining -= read;

if inner.read_remaining == 0 {
inner.header_pos = 0;
}

Poll::Ready(Ok(()))
}
}

impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncWrite for TlsPreloginWrapper<S> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
// Normal operation does not need any extra treatment, we handle
// packets in the codec.
if !self.pending_handshake {
return Pin::new(&mut self.stream).poll_write(cx, buf);
}

// Buffering data.
self.wr_buf.extend_from_slice(buf);

Poll::Ready(Ok(buf.len()))
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
let inner = self.get_mut();

// If on handshake mode, wraps the data to a TDS packet before sending.
if inner.pending_handshake {
if !inner.header_written {
let buf = std::mem::take(&mut inner.wr_buf);
write_packets(
&mut inner.wr_buf,
4096,
PacketType::PreLogin,
buf.as_slice(),
);
inner.header_written = true;
}

while !inner.wr_buf.is_empty() {
let written = ready!(Pin::new(&mut inner.stream).poll_write(cx, &inner.wr_buf))?;

inner.wr_buf.drain(..written);
}

inner.header_written = false;
}

Pin::new(&mut inner.stream).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_shutdown(cx)
}
}
Loading

0 comments on commit 916301f

Please sign in to comment.