Skip to content

Commit

Permalink
Merge pull request #11 from 3andne/prepare-v1-3
Browse files Browse the repository at this point in the history
feat: send post-handshake tls record to client
  • Loading branch information
3andne authored Mar 12, 2023
2 parents 4d5ce0f + ff63eac commit 1d974e3
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 37 deletions.
44 changes: 24 additions & 20 deletions src/client_hello.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,33 +92,37 @@ impl ClientHello {

skip_length_padded::<2, _>(buf); // cipher suites
skip_length_padded::<1, _>(buf); // compression methods
buf.advance(2); // Extensions Length

let mut session_ticket = Vec::new();
let mut client_supports_tls_13 = false;
let mut psk = Vec::new();
let mut key_share = Vec::new();
while buf.has_remaining() {
let ext = buf.get_u16();
match ext {
EXTENSION_SESSION_TICKET => {
extend_from_length_prefixed::<2, _>(buf, &mut session_ticket);
}
EXTENSION_SUPPORTED_VERSIONS => {
client_supports_tls_13 = Self::read_supported_version(buf);
}
EXTENSION_PRE_SHARED_KEY => {
psk = Self::read_psk(buf);
}
EXTENSION_KEY_SHARE => {
key_share = Self::read_key_share(buf);
debug!("client key_share {:?}", key_share);
}
_ => {
skip_length_padded::<2, _>(buf);
u16_length_prefixed(buf, |mut ext_section| {
while ext_section.has_remaining() {
let ext = ext_section.get_u16();
match ext {
EXTENSION_SESSION_TICKET => {
extend_from_length_prefixed::<2, _>(&mut ext_section, &mut session_ticket);
debug!("session_ticket: {:?}", session_ticket);
}
EXTENSION_SUPPORTED_VERSIONS => {
client_supports_tls_13 = Self::read_supported_version(&mut ext_section);
}
EXTENSION_PRE_SHARED_KEY => {
psk = Self::read_psk(&mut ext_section);
debug!("psk: {:?}", psk);
}
EXTENSION_KEY_SHARE => {
key_share = Self::read_key_share(&mut ext_section);
debug!("client key_share {:?}", key_share);
}
_ => {
skip_length_padded::<2, _>(&mut ext_section);
}
}
}
}
});

if !client_supports_tls_13 {
return Err(anyhow!("reject: client must support tls 1.3"));
}
Expand Down
102 changes: 85 additions & 17 deletions src/restls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,17 @@ impl<'a> RestlsState<'a> {
(real_data_len, padding, command)
}

fn parrot_tls12_nonce(&self, record: &mut [u8]) {
if self.parrot_tls12_gcm {
debug!(
"parrot gcm nounce {:?} {}",
&record[5..13],
self.to_client_counter + 1
);
record[5..13].copy_from_slice(&(self.to_client_counter + 1).to_be_bytes());
}
}

fn prepare_app_data_header(
&mut self,
out_buf: &mut DoubleCursorBuf,
Expand All @@ -320,10 +331,7 @@ impl<'a> RestlsState<'a> {
let payload_len = (record.len() - 5) as u16;
record[3..5].copy_from_slice(&payload_len.to_be_bytes());

if self.parrot_tls12_gcm {
record[5..13].copy_from_slice(&(self.to_client_counter + 1).to_be_bytes());
}

self.parrot_tls12_nonce(record);
let mut hmac_auth = self.restls_appdata_auth_hmac(true);
let header_offset = self.restls_header_offset(true);
hmac_auth.update(&record[..header_offset]);
Expand Down Expand Up @@ -357,20 +365,26 @@ impl<'a> RestlsState<'a> {
self.to_client_counter += 1;
}

async fn read_from_stream(&self, stream: &mut TLSStream) -> Result<()> {
async fn read_from_stream(&self, stream: &mut TLSStream, eof_pending: bool) -> Result<()> {
if stream.codec().has_next() {
Ok(())
} else {
match stream.next().await {
None => Err(anyhow!("unexpected eof")),
None => {
if eof_pending {
std::future::pending().await
} else {
Err(anyhow!("unexpected eof"))
}
}
Some(res) => res,
}
}
}

async fn try_read_client_hello(&mut self, inbound: &mut TLSStream) -> Result<()> {
match self.read_from_stream(inbound).await {
Err(e) => return Err(e.context("failed to read client hello: ")),
match self.read_from_stream(inbound, false).await {
Err(e) => return Err(anyhow!("failed to read client hello: {:?}", e)),
Ok(()) => (),
};
let rtype = inbound.codec().peek_record_type()?;
Expand All @@ -392,7 +406,7 @@ impl<'a> RestlsState<'a> {
}

async fn try_read_server_hello(&mut self, outbound: &mut TLSStream) -> Result<()> {
self.read_from_stream(outbound)
self.read_from_stream(outbound, false)
.await
.context("failed to read server hello: ")?;
let rtype = outbound.codec().peek_record_type()?;
Expand Down Expand Up @@ -424,7 +438,7 @@ impl<'a> RestlsState<'a> {
) -> Result<()> {
let mut ccs_from_server = false;
loop {
self.read_from_stream(outbound).await?;
self.read_from_stream(outbound, false).await?;
let rtype = outbound.codec().peek_record_type()?;

match rtype {
Expand Down Expand Up @@ -482,7 +496,7 @@ impl<'a> RestlsState<'a> {
let mut ccs_from_client = false;
loop {
select! {
res = self.read_from_stream(inbound) => {
res = self.read_from_stream(inbound, false) => {
let _ = res?;
match inbound.codec().peek_record_type()? {
RECORD_CCS if !ccs_from_client => {
Expand Down Expand Up @@ -510,7 +524,7 @@ impl<'a> RestlsState<'a> {
inbound.codec_mut().next_record().expect("unexpected error: record has been checked");
self.relay_to(outbound, inbound).await?;
}
res = self.read_from_stream(outbound) => {
res = self.read_from_stream(outbound, false) => {
let _ = res?;
outbound.codec_mut().skip_to_end();
self.relay_to(inbound, outbound).await?;
Expand Down Expand Up @@ -648,7 +662,7 @@ impl<'a> RestlsState<'a> {
inbound.codec().has_next()
);
select! {
ret = self.read_from_stream(outbound) => {
ret = self.read_from_stream(outbound, false) => {
match ret {
Err(_) => return Ok(()),
_ => (),
Expand All @@ -660,7 +674,7 @@ impl<'a> RestlsState<'a> {
.expect("unexpected error: record has been checked");
self.relay_to(inbound, outbound).await?;
}
ret = self.read_from_stream(inbound) => {
ret = self.read_from_stream(inbound, false) => {
ret?;
self.handle_tls12_inbound(inbound, &mut flow)?;
if flow.is_client_0x17() {
Expand Down Expand Up @@ -791,14 +805,52 @@ impl<'a> RestlsState<'a> {
Ok(WriteToServerResult::Ok((read_record, need_respond)))
}

async fn write_handshake_data_to_client(
&mut self,
inbound: &mut TLSStream,
outbound_handshake: &mut TLSStream,
close_notify_cache: &mut Vec<u8>,
) -> Result<()> {
while outbound_handshake.codec().has_next() {
let record = match outbound_handshake.codec_mut().next_record() {
Ok(r) => r,
Err(e) => {
debug!("[{}]outbound_handshake read error {}", self.id, e);
return Ok(());
}
};
debug!("[{}]read from outbound_handshake {:?}", self.id, record);
if record.len() < 50 {
close_notify_cache.extend_from_slice(record);
return Ok(());
} else {
self.parrot_tls12_nonce(record);
inbound.get_mut().write_all(record).await?;
self.to_client_counter += 1;
}
}
Ok(())
}

async fn write_close_notify(&mut self, close_notify: &mut Vec<u8>, inbound: &mut TcpStream) {
if close_notify.len() > 0 {
debug!("[{}]sending close-notify {:?}", self.id, close_notify);
self.parrot_tls12_nonce(close_notify);
let _ = inbound.write(&close_notify).await;
}
}

pub async fn copy_bidirectional(
&mut self,
inbound: &mut TLSStream,
outbound: &mut TcpStream,
outbound_handshake: &mut TLSStream,
) -> Result<()> {
let mut out_buf = DoubleCursorBuf::new(self.restls_data_offset(true));
let mut awaiting = false;
let mut need_respond = 0;
let mut close_notify = Vec::new();
let mut outbound_handshake_closed = false;
async fn read_if_has_capacity(
outbound: &mut TcpStream,
out_buf: &mut DoubleCursorBuf,
Expand All @@ -811,10 +863,17 @@ impl<'a> RestlsState<'a> {
}
loop {
select! {
res = self.read_from_stream(inbound) => {
res = self.read_from_stream(outbound_handshake, true) => {
debug!("[{}]read from outbound_handshake got {:?}", self.id, res);
self.write_handshake_data_to_client(inbound, outbound_handshake, &mut close_notify).await?;
}
res = self.read_from_stream(inbound, false) => {
res?;
match self.write_client_data_to_server(inbound, outbound).await? {
WriteToServerResult::MaybeCloseNotify => return Ok(()),
WriteToServerResult::MaybeCloseNotify => {
self.write_close_notify(&mut close_notify, inbound.get_mut()).await;
return Ok(())
},
WriteToServerResult::Ok((read, respond)) => {
awaiting = read == 0;
debug!("[{}]read {} and set awaiting to {}", self.id, read, awaiting);
Expand All @@ -825,6 +884,7 @@ impl<'a> RestlsState<'a> {
n = read_if_has_capacity(outbound, &mut out_buf) => {
let n = n.context("outbound.read failed: ")?;
if n == 0 {
self.write_close_notify(&mut close_notify, inbound.get_mut()).await;
return Ok(());
}
out_buf.advance_back(n);
Expand Down Expand Up @@ -874,6 +934,14 @@ impl<'a> RestlsState<'a> {
}
need_respond = 0;
}

if !outbound_handshake_closed && self.to_client_counter + self.to_server_counter > 5 {
outbound_handshake_closed = true;
match outbound_handshake.get_mut().shutdown().await {
Err(e) => debug!("[{}]outbound_handshake shutdown got {:?}", self.id, e),
Ok(_) => (),
}
}
}
}

Expand Down Expand Up @@ -935,7 +1003,7 @@ pub async fn handle(options: Arc<Opt>, inbound: TcpStream, id: usize) -> Result<
Ok(()) => {
let mut outbound_proxy = TcpStream::connect(&options.forward_to).await?;
match try_handshake
.copy_bidirectional(&mut inbound, &mut outbound_proxy)
.copy_bidirectional(&mut inbound, &mut outbound_proxy, &mut outbound)
.await
{
Err(e) => {
Expand Down

0 comments on commit 1d974e3

Please sign in to comment.