Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the read buffer for the handshake process, make the write buffer available for cert verification #106

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions src/asynch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ where
/// The write record buffer can be smaller than the read buffer. During write [`TLS_RECORD_OVERHEAD`] over overhead
/// is added per record, so the buffer must at least be this large. Large writes are split into multiple records if
/// depending on the size of the write buffer.
/// The largest of the two buffers will be used to encode the TLS handshake record, hence either of the
/// buffers must at least be large enough to encode a handshake.
/// The read buffer will also be used to encode the TLS handshake record. The write buffer may
/// be resued by the certificate verifier.
pub fn new(
delegate: Socket,
record_read_buf: &'a mut [u8],
Expand All @@ -74,16 +74,19 @@ where
///
/// Returns an error if the handshake does not proceed. If an error occurs, the connection
/// instance must be recreated.
pub async fn open<'v, RNG, Verifier>(
&mut self,
context: TlsContext<'v, CipherSuite, RNG>,
pub async fn open<'v, 'c, RNG, Verifier>(
&'v mut self,
context: TlsContext<'c, CipherSuite, RNG>,
) -> Result<(), TlsError>
where
RNG: CryptoRng + RngCore,
Verifier: TlsVerifier<'v, CipherSuite>,
Verifier: TlsVerifier<'v, 'c, CipherSuite>,
'a: 'v,
{
let mut handshake: Handshake<CipherSuite, Verifier> =
Handshake::new(Verifier::new(context.config.server_name));
let mut handshake: Handshake<CipherSuite, Verifier> = Handshake::new(Verifier::new(
self.record_write_buf.take_buffer()?,
context.config.server_name,
));
let mut state = State::ClientHello;

loop {
Expand All @@ -92,7 +95,6 @@ where
&mut self.delegate,
&mut handshake,
&mut self.record_reader,
&mut self.record_write_buf,
&mut self.key_schedule,
context.config,
context.rng,
Expand Down
19 changes: 10 additions & 9 deletions src/blocking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ where
/// The write record buffer can be smaller than the read buffer. During write [`TLS_RECORD_OVERHEAD`] over overhead
/// is added per record, so the buffer must at least be this large. Large writes are split into multiple records if
/// depending on the size of the write buffer.
/// The largest of the two buffers will be used to encode the TLS handshake record, hence either of the
/// buffers must at least be large enough to encode a handshake.
/// The read buffer will also be used to encode the TLS handshake record. The write buffer may
/// be resued by the certificate verifier.
pub fn new(
delegate: Socket,
record_read_buf: &'a mut [u8],
Expand All @@ -74,24 +74,25 @@ where
///
/// Returns an error if the handshake does not proceed. If an error occurs, the connection
/// instance must be recreated.
pub fn open<'v, RNG, Verifier>(
&mut self,
context: TlsContext<'v, CipherSuite, RNG>,
pub fn open<'v, 'c, RNG, Verifier>(
&'v mut self,
context: TlsContext<'c, CipherSuite, RNG>,
) -> Result<(), TlsError>
where
RNG: CryptoRng + RngCore,
Verifier: TlsVerifier<'v, CipherSuite>,
Verifier: TlsVerifier<'v, 'c, CipherSuite>,
{
let mut handshake: Handshake<CipherSuite, Verifier> =
Handshake::new(Verifier::new(context.config.server_name));
let mut handshake: Handshake<CipherSuite, Verifier> = Handshake::new(Verifier::new(
self.record_write_buf.take_buffer()?,
context.config.server_name,
));
let mut state = State::ClientHello;

loop {
let next_state = state.process_blocking(
&mut self.delegate,
&mut handshake,
&mut self.record_reader,
&mut self.record_write_buf,
&mut self.key_schedule,
context.config,
context.rng,
Expand Down
8 changes: 4 additions & 4 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl TlsCipherSuite for Aes256GcmSha384 {
/// The verifier is responsible for verifying certificates and signatures. Since certificate verification is
/// an expensive process, this trait allows clients to choose how much verification should take place,
/// and also to skip the verification if the server is verified through other means (I.e. a pre-shared key).
pub trait TlsVerifier<'a, CipherSuite>
pub trait TlsVerifier<'a, 'c, CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
Expand All @@ -74,7 +74,7 @@ where
/// This method is called for every TLS handshake.
///
/// Host verification is enabled by passing a server hostname.
fn new(host: Option<&'a str>) -> Self;
fn new(buffer: &'a mut [u8], host: Option<&'c str>) -> Self;

/// Verify a certificate.
///
Expand All @@ -96,11 +96,11 @@ where

pub struct NoVerify;

impl<'a, CipherSuite> TlsVerifier<'a, CipherSuite> for NoVerify
impl<'a, 'c, CipherSuite> TlsVerifier<'a, 'c, CipherSuite> for NoVerify
where
CipherSuite: TlsCipherSuite,
{
fn new(_host: Option<&str>) -> Self {
fn new(_buffer: &'a mut [u8], _host: Option<&'c str>) -> Self {
Self
}

Expand Down
64 changes: 37 additions & 27 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,10 @@ where
verifier: Verifier,
}

impl<'v, CipherSuite, Verifier> Handshake<CipherSuite, Verifier>
impl<'v, 'c, CipherSuite, Verifier> Handshake<CipherSuite, Verifier>
where
CipherSuite: TlsCipherSuite,
Verifier: TlsVerifier<'v, CipherSuite>,
Verifier: TlsVerifier<'v, 'c, CipherSuite>,
{
pub fn new(verifier: Verifier) -> Handshake<CipherSuite, Verifier> {
Handshake {
Expand All @@ -210,25 +210,25 @@ pub enum State {
impl<'a> State {
#[cfg(feature = "async")]
#[allow(clippy::too_many_arguments)]
pub async fn process<'v, Transport, CipherSuite, RNG, Verifier>(
pub async fn process<'v, 'c, Transport, CipherSuite, RNG, Verifier>(
self,
transport: &mut Transport,
handshake: &mut Handshake<CipherSuite, Verifier>,
record_reader: &mut RecordReader<'_, CipherSuite>,
tx_buf: &mut WriteBuffer<'_>,
key_schedule: &mut KeySchedule<CipherSuite>,
config: &TlsConfig<'a, CipherSuite>,
config: &TlsConfig<'c, CipherSuite>,
rng: &mut RNG,
) -> Result<State, TlsError>
where
Transport: AsyncRead + AsyncWrite + 'a,
RNG: CryptoRng + RngCore + 'a,
Transport: AsyncRead + AsyncWrite,
RNG: CryptoRng + RngCore,
CipherSuite: TlsCipherSuite,
Verifier: TlsVerifier<'v, CipherSuite>,
Verifier: TlsVerifier<'v, 'c, CipherSuite>,
{
match self {
State::ClientHello => {
let (state, tx) = client_hello(key_schedule, config, rng, tx_buf, handshake)?;
let mut tx_buf = WriteBuffer::new(record_reader.take_buffer()?);
let (state, tx) = client_hello(key_schedule, config, rng, &mut tx_buf, handshake)?;

respond(tx, transport, key_schedule).await?;

Expand All @@ -240,7 +240,7 @@ impl<'a> State {
.await?;
let result = process_server_hello(handshake, key_schedule, record);

handle_processing_error(result, transport, key_schedule, tx_buf).await
handle_processing_error(result, transport, key_schedule, record_reader).await
}
State::ServerVerify => {
let record = record_reader
Expand All @@ -249,17 +249,19 @@ impl<'a> State {

let result = process_server_verify(handshake, key_schedule, config, record);

handle_processing_error(result, transport, key_schedule, tx_buf).await
handle_processing_error(result, transport, key_schedule, record_reader).await
}
State::ClientCert => {
let (state, tx) = client_cert(handshake, key_schedule, config, tx_buf)?;
let mut tx_buf = WriteBuffer::new(record_reader.take_buffer()?);
let (state, tx) = client_cert(handshake, key_schedule, config, &mut tx_buf)?;

respond(tx, transport, key_schedule).await?;

Ok(state)
}
State::ClientFinished => {
let tx = client_finished(key_schedule, tx_buf)?;
let mut tx_buf = WriteBuffer::new(record_reader.take_buffer()?);
let tx = client_finished(key_schedule, &mut tx_buf)?;

respond(tx, transport, key_schedule).await?;

Expand All @@ -270,25 +272,25 @@ impl<'a> State {
}

#[allow(clippy::too_many_arguments)]
pub fn process_blocking<'v, Transport, CipherSuite, RNG, Verifier>(
pub fn process_blocking<'v, 'c, Transport, CipherSuite, RNG, Verifier>(
self,
transport: &mut Transport,
handshake: &mut Handshake<CipherSuite, Verifier>,
record_reader: &mut RecordReader<'_, CipherSuite>,
tx_buf: &mut WriteBuffer,
key_schedule: &mut KeySchedule<CipherSuite>,
config: &TlsConfig<'a, CipherSuite>,
config: &TlsConfig<'c, CipherSuite>,
rng: &mut RNG,
) -> Result<State, TlsError>
where
Transport: BlockingRead + BlockingWrite + 'a,
Transport: BlockingRead + BlockingWrite,
RNG: CryptoRng + RngCore,
CipherSuite: TlsCipherSuite + 'static,
Verifier: TlsVerifier<'v, CipherSuite>,
Verifier: TlsVerifier<'v, 'c, CipherSuite>,
{
match self {
State::ClientHello => {
let (state, tx) = client_hello(key_schedule, config, rng, tx_buf, handshake)?;
let mut tx_buf = WriteBuffer::new(record_reader.take_buffer()?);
let (state, tx) = client_hello(key_schedule, config, rng, &mut tx_buf, handshake)?;

respond_blocking(tx, transport, key_schedule)?;

Expand All @@ -299,24 +301,26 @@ impl<'a> State {

let result = process_server_hello(handshake, key_schedule, record);

handle_processing_error_blocking(result, transport, key_schedule, tx_buf)
handle_processing_error_blocking(result, transport, key_schedule, record_reader)
}
State::ServerVerify => {
let record = record_reader.read_blocking(transport, key_schedule.read_state())?;

let result = process_server_verify(handshake, key_schedule, config, record);

handle_processing_error_blocking(result, transport, key_schedule, tx_buf)
handle_processing_error_blocking(result, transport, key_schedule, record_reader)
}
State::ClientCert => {
let (state, tx) = client_cert(handshake, key_schedule, config, tx_buf)?;
let mut tx_buf = WriteBuffer::new(record_reader.take_buffer()?);
let (state, tx) = client_cert(handshake, key_schedule, config, &mut tx_buf)?;

respond_blocking(tx, transport, key_schedule)?;

Ok(state)
}
State::ClientFinished => {
let tx = client_finished(key_schedule, tx_buf)?;
let mut tx_buf = WriteBuffer::new(record_reader.take_buffer()?);
let tx = client_finished(key_schedule, &mut tx_buf)?;

respond_blocking(tx, transport, key_schedule)?;

Expand All @@ -331,12 +335,15 @@ fn handle_processing_error_blocking<CipherSuite>(
result: Result<State, TlsError>,
transport: &mut impl BlockingWrite,
key_schedule: &mut KeySchedule<CipherSuite>,
tx_buf: &mut WriteBuffer,
record_reader: &mut RecordReader<CipherSuite>,
) -> Result<State, TlsError>
where
CipherSuite: TlsCipherSuite,
{
if let Err(TlsError::AbortHandshake(level, description)) = result {
record_reader.discard_pending();
let mut tx_buf = WriteBuffer::new(record_reader.take_buffer()?);

let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
let tx = tx_buf.write_record(
&ClientRecord::Alert(Alert { level, description }, false),
Expand Down Expand Up @@ -374,12 +381,15 @@ async fn handle_processing_error<'a, CipherSuite>(
result: Result<State, TlsError>,
transport: &mut impl AsyncWrite,
key_schedule: &mut KeySchedule<CipherSuite>,
tx_buf: &mut WriteBuffer<'a>,
record_reader: &mut RecordReader<'a, CipherSuite>,
) -> Result<State, TlsError>
where
CipherSuite: TlsCipherSuite,
{
if let Err(TlsError::AbortHandshake(level, description)) = result {
record_reader.discard_pending();
let mut tx_buf = WriteBuffer::new(record_reader.take_buffer()?);

let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
let tx = tx_buf.write_record(
&ClientRecord::Alert(Alert { level, description }, false),
Expand Down Expand Up @@ -469,15 +479,15 @@ where
}
}

fn process_server_verify<'a, 'v, CipherSuite, Verifier>(
fn process_server_verify<'a, 'v, 'c, CipherSuite, Verifier>(
handshake: &mut Handshake<CipherSuite, Verifier>,
key_schedule: &mut KeySchedule<CipherSuite>,
config: &TlsConfig<'a, CipherSuite>,
record: ServerRecord<'_, HashOutputSize<CipherSuite>>,
) -> Result<State, TlsError>
where
CipherSuite: TlsCipherSuite,
Verifier: TlsVerifier<'v, CipherSuite>,
Verifier: TlsVerifier<'v, 'c, CipherSuite>,
{
let mut state = State::ServerVerify;
decrypt_record(key_schedule.read_state(), record, |key_schedule, record| {
Expand Down
26 changes: 16 additions & 10 deletions src/record_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{

pub struct RecordReader<'a, CipherSuite>
where
CipherSuite: TlsCipherSuite + 'static,
CipherSuite: TlsCipherSuite,
{
pub(crate) buf: &'a mut [u8],
/// The number of decoded bytes in the buffer
Expand All @@ -26,7 +26,7 @@ where

impl<'a, CipherSuite> RecordReader<'a, CipherSuite>
where
CipherSuite: TlsCipherSuite + 'static,
CipherSuite: TlsCipherSuite,
{
pub fn new(buf: &'a mut [u8]) -> Self {
Self {
Expand All @@ -37,15 +37,20 @@ where
}
}

pub(crate) fn take_buffer(&mut self) -> Result<&mut [u8], TlsError> {
if self.pending > 0 {
return Err(TlsError::InternalError);
}

Ok(self.buf)
}

#[cfg(feature = "async")]
pub async fn read<'m>(
&'m mut self,
transport: &mut impl AsyncRead,
key_schedule: &mut ReadKeySchedule<CipherSuite>,
) -> Result<ServerRecord<'m, HashOutputSize<CipherSuite>>, TlsError>
where
CipherSuite: TlsCipherSuite + 'static,
{
) -> Result<ServerRecord<'m, HashOutputSize<CipherSuite>>, TlsError> {
let header = self.advance(transport, 5).await?;
let header = RecordHeader::decode(header.try_into().unwrap())?;

Expand Down Expand Up @@ -83,10 +88,7 @@ where
&'m mut self,
transport: &mut impl BlockingRead,
key_schedule: &mut ReadKeySchedule<CipherSuite>,
) -> Result<ServerRecord<'m, HashOutputSize<CipherSuite>>, TlsError>
where
CipherSuite: TlsCipherSuite + 'static,
{
) -> Result<ServerRecord<'m, HashOutputSize<CipherSuite>>, TlsError> {
let header = self.advance_blocking(transport, 5)?;
let header = RecordHeader::decode(header.try_into().unwrap())?;

Expand Down Expand Up @@ -130,6 +132,10 @@ where

Ok(())
}

pub(crate) fn discard_pending(&mut self) {
self.pending = 0;
}
}

#[cfg(test)]
Expand Down
6 changes: 3 additions & 3 deletions src/webpki.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ where
_clock: PhantomData<Clock>,
}

impl<'a, CipherSuite, Clock, const CERT_SIZE: usize> TlsVerifier<'a, CipherSuite>
for CertVerifier<'a, CipherSuite, Clock, CERT_SIZE>
impl<'a, 'c, CipherSuite, Clock, const CERT_SIZE: usize> TlsVerifier<'a, 'c, CipherSuite>
for CertVerifier<'c, CipherSuite, Clock, CERT_SIZE>
where
CipherSuite: TlsCipherSuite,
Clock: TlsClock,
{
fn new(host: Option<&'a str>) -> Self {
fn new(_buffer: &'a mut [u8], host: Option<&'c str>) -> Self {
Self {
host,
certificate_transcript: None,
Expand Down
Loading