diff --git a/Cargo.lock b/Cargo.lock index f8a2b60..6b63199 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -431,6 +431,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -473,6 +479,36 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + [[package]] name = "redox_syscall" version = "0.2.16" @@ -491,6 +527,7 @@ dependencies = [ "clap 4.1.1", "futures-util", "hmac", + "rand", "sha1", "sha2", "structopt", diff --git a/Cargo.toml b/Cargo.toml index 231f607..c4c396f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,4 +17,5 @@ sha1 = "*" hmac = "*" tokio-util = { version = "*", features = ["full"] } futures-util = "0.3.26" -sha2 = "*" \ No newline at end of file +sha2 = "*" +rand = "0.8.5" \ No newline at end of file diff --git a/src/args.rs b/src/args.rs index cfba1c3..b97bcd8 100644 --- a/src/args.rs +++ b/src/args.rs @@ -61,6 +61,16 @@ pub struct Opt { /// the password to authenticate connections #[structopt(short = "p", long, parse(from_str = make_password))] pub password: Password, + + /// The target length of early server data. Packets that exceed this + /// will be truncated. + #[structopt(long, default_value = "1385")] + pub mtu: u16, + + /// The minimal length of server data. Packets that below this + /// will be padded. + #[structopt(long, default_value = "15")] + pub min_record_len: u16, } // #[derive(Debug)] diff --git a/src/common.rs b/src/common.rs index a9c7795..cebde8d 100644 --- a/src/common.rs +++ b/src/common.rs @@ -2,6 +2,8 @@ pub const REQUIRED_SESSION_ID_LEN: usize = 32; pub const RESTLS_HANDSHAKE_HMAC_LEN: usize = 16; pub const RESTLS_APPDATA_HMAC_LEN: usize = 8; +pub const RESTLS_APPDATA_LEN_OFFSET: usize = 5 + RESTLS_APPDATA_HMAC_LEN; +pub const RESTLS_APPDATA_OFFSET: usize = 5 + RESTLS_APPDATA_HMAC_LEN + 2; // record type pub const RECORD_HANDSHAKE: u8 = 0x16; @@ -25,3 +27,8 @@ pub const HELLO_RETRY_RANDOM: [u8; 32] = [ 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, ]; + +pub const TO_CLIENT_MAGIC: &'static [u8] = "server-to-client".as_bytes(); +pub const TO_SERVER_MAGIC: &'static [u8] = "client-to-server".as_bytes(); + +pub const BUF_SIZE: usize = 0x2000; diff --git a/src/restls.rs b/src/restls.rs index a3771cc..047109f 100644 --- a/src/restls.rs +++ b/src/restls.rs @@ -1,9 +1,10 @@ use anyhow::{anyhow, Context, Result}; use futures_util::stream::StreamExt; use hmac::Mac; +use rand::Rng; use std::{io::Cursor, sync::Arc}; use tokio::{ - io::AsyncWriteExt, + io::{AsyncReadExt, AsyncWriteExt}, net::{TcpListener, TcpStream}, select, }; @@ -15,13 +16,12 @@ use crate::{ client_hello::ClientHello, client_key_exchange::ClientKeyExchange, common::{ - RECORD_APPLICATION_DATA, RECORD_CCS, RECORD_HANDSHAKE, RESTLS_APPDATA_HMAC_LEN, - RESTLS_HANDSHAKE_HMAC_LEN, + BUF_SIZE, RECORD_ALERT, RECORD_APPLICATION_DATA, RECORD_CCS, RECORD_HANDSHAKE, + RESTLS_APPDATA_HMAC_LEN, RESTLS_APPDATA_LEN_OFFSET, RESTLS_APPDATA_OFFSET, + RESTLS_HANDSHAKE_HMAC_LEN, TO_CLIENT_MAGIC, TO_SERVER_MAGIC, }, server_hello::ServerHello, - utils::{ - copy_bidirectional, copy_bidirectional_fallback, xor_bytes, HmacSha1, TLSCodec, TLSStream, - }, + utils::{copy_bidirectional_fallback, tcp_rst, xor_bytes, HmacSha1, TLSCodec, TLSStream}, }; #[derive(Debug)] @@ -135,14 +135,134 @@ impl TLS12Flow { } } -struct TryHandshake<'a> { +pub struct RestlsState<'a> { client_hello: Option, server_hello: Option, client_finished: Vec, restls_password: &'a [u8], + to_client_counter: u32, + to_server_counter: u32, + to_client_raw_counter: u32, + mtu: usize, + min_record_len: usize, + id: usize, +} + +fn sample_slice(data: &[u8]) -> &[u8] { + &data[..std::cmp::min(32, data.len())] } -impl<'a> TryHandshake<'a> { +impl<'a> RestlsState<'a> { + fn restls_hmac(&self) -> HmacSha1 { + HmacSha1::new_from_slice(self.restls_password).expect("sha1 should take key of any size") + } + + pub fn restls_appdata_auth_hmac(&self, is_to_client: bool) -> HmacSha1 { + let mut hasher = self.restls_hmac(); + hasher.update(&self.server_hello.as_ref().unwrap().server_random); + if is_to_client { + hasher.update(TO_CLIENT_MAGIC); + hasher.update(&self.to_client_counter.to_be_bytes()); + } else { + hasher.update(TO_SERVER_MAGIC); + hasher.update(&self.to_server_counter.to_be_bytes()); + } + hasher + } + + pub fn read_app_data<'b>(&mut self, record: &'b mut [u8]) -> Result<&'b [u8]> { + if &record[..3] != &[RECORD_APPLICATION_DATA, 0x03, 0x03] { + return Err(anyhow!( + "[{}]reject: restls application data must have 0x17 header, got: {:?}", + self.id, + record + )); + } + let actual_auth = &record[5..5 + RESTLS_APPDATA_HMAC_LEN]; + let mut hmac_auth = self.restls_appdata_auth_hmac(false); + if self.client_finished.len() > 0 { + debug!("adding client_finished {:?}", self.client_finished); + hmac_auth.update(&self.client_finished); + self.client_finished.resize(0, 0); + } + hmac_auth.update(&record[RESTLS_APPDATA_LEN_OFFSET..]); + let expect_auth = hmac_auth.finalize().into_bytes(); + if actual_auth != &expect_auth[..RESTLS_APPDATA_HMAC_LEN] { + debug!( + "[{}]bad mac record, expect auth {:?}, actual {:?}, to_client: {}, to_server: {}", + self.id, + &expect_auth[..RESTLS_APPDATA_HMAC_LEN], + actual_auth, + self.to_client_counter, + self.to_server_counter + ); + return Err(anyhow!("reject: bad mac record")); + } + + let mut hmac_mask = self.restls_appdata_auth_hmac(false); + hmac_mask.update(sample_slice(&record[RESTLS_APPDATA_OFFSET..])); + let mask = hmac_mask.finalize().into_bytes(); + let data_len_bytes = &mut record[RESTLS_APPDATA_LEN_OFFSET..][..2]; + xor_bytes(&mask[..2], data_len_bytes); + let data_len = (data_len_bytes[0] as usize) << 8 | (data_len_bytes[1] as usize); + self.to_server_counter += 1; + debug!("[{}]read_app_data: data_len {}", self.id, data_len); + Ok(&record[RESTLS_APPDATA_OFFSET..][..data_len]) + } + + pub fn write_app_data_header<'b>( + &mut self, + record: &'b mut [u8], + data_len: usize, + ) -> Result<(usize, usize)> { + assert!(record.len() >= 1500); + let mtu = (self.mtu as isize + rand::thread_rng().gen_range(-50..50)) + .try_into() + .unwrap(); + let min_record_len = self.min_record_len + rand::thread_rng().gen_range(0..100); + let (real_data_len, padding) = match ( + data_len < min_record_len, + self.to_client_raw_counter < 5, + data_len > mtu, + ) { + (true, _, _) => (data_len, min_record_len - data_len), + (_, false, _) => (data_len, 0), + (_, true, true) => (mtu, 0), + (_, true, false) => (data_len, mtu - data_len), + }; + if padding > 0 { + rand::thread_rng() + .fill(&mut record[RESTLS_APPDATA_OFFSET + real_data_len..][..padding]); + } + let mut hmac_mask = self.restls_appdata_auth_hmac(true); + hmac_mask.update(sample_slice( + &record[RESTLS_APPDATA_OFFSET..][..real_data_len + padding], + )); + let mask = hmac_mask.finalize().into_bytes(); + record[RESTLS_APPDATA_LEN_OFFSET..][..2] + .copy_from_slice(&(real_data_len as u16).to_be_bytes()); + xor_bytes(&mask[..2], &mut record[RESTLS_APPDATA_LEN_OFFSET..]); + let mut hmac_auth = self.restls_appdata_auth_hmac(true); + hmac_auth.update(&record[RESTLS_APPDATA_LEN_OFFSET..][..2 + real_data_len + padding]); + let auth = hmac_auth.finalize().into_bytes(); + record[5..5 + RESTLS_APPDATA_HMAC_LEN].copy_from_slice(&auth[..RESTLS_APPDATA_HMAC_LEN]); + record[0..3].copy_from_slice(&[0x17, 0x3, 0x3]); + record[3..5].copy_from_slice(&((real_data_len + padding + 10) as u16).to_be_bytes()); + self.to_client_counter += 1; + if real_data_len == data_len { + self.to_client_raw_counter += 1; + } + debug!( + "[{}]write_header: data_len {}, padding: {}, mask: {:?}, auth: {:?}", + self.id, + real_data_len, + padding, + &mask[..2], + &auth[..RESTLS_APPDATA_HMAC_LEN] + ); + Ok((real_data_len, real_data_len + padding + 15)) + } + async fn read_from_stream(&self, stream: &mut TLSStream) -> Result<()> { if stream.codec().has_next() { Ok(()) @@ -155,9 +275,10 @@ impl<'a> TryHandshake<'a> { } async fn try_read_client_hello(&mut self, inbound: &mut TLSStream) -> Result<()> { - self.read_from_stream(inbound) - .await - .context("failed to read client hello: ")?; + match self.read_from_stream(inbound).await { + Err(e) => return Err(e.context("failed to read client hello: ")), + Ok(()) => (), + }; let rtype = inbound.codec().peek_record_type()?; if rtype != RECORD_HANDSHAKE { return Err(anyhow!( @@ -230,12 +351,12 @@ impl<'a> TryHandshake<'a> { } fn prepare_server_auth(&self, outbound: &mut TLSStream) { - let mut hasher = HmacSha1::new_from_slice(self.restls_password) - .expect("sha1 should take key of any size"); + let mut hasher = self.restls_hmac(); hasher.update(&self.server_hello.as_ref().unwrap().server_random); let secret = hasher.finalize().into_bytes(); debug!( - "server challenge {:?}", + "[{}]server challenge {:?}", + self.id, &secret[..RESTLS_HANDSHAKE_HMAC_LEN] ); let record = outbound @@ -262,7 +383,9 @@ impl<'a> TryHandshake<'a> { } RECORD_APPLICATION_DATA if ccs_from_client => { seen_client_application_data += 1; - if seen_client_application_data == 2 { + if seen_client_application_data == 1 { + self.client_finished.extend_from_slice(inbound.codec().peek_record().unwrap()); + } else if seen_client_application_data == 2 { break; } } @@ -286,8 +409,7 @@ impl<'a> TryHandshake<'a> { } fn check_tls13_session_id(&self) -> Result<()> { - let mut hasher = HmacSha1::new_from_slice(self.restls_password) - .expect("sha1 should take key of any size"); + let mut hasher = self.restls_hmac(); let client_hello = self.client_hello.as_ref().unwrap(); hasher.update(&client_hello.key_share); hasher.update(&client_hello.psk); @@ -308,33 +430,6 @@ impl<'a> TryHandshake<'a> { } } - fn check_tls13_application_data_auth(&self, in_buf: &[u8]) -> Result<()> { - let mut hasher = HmacSha1::new_from_slice(self.restls_password) - .expect("sha1 should take key of any size"); - let server_hello = self.server_hello.as_ref().unwrap(); - hasher.update(&server_hello.server_random); - hasher.update(&server_hello.server_random); - let res = hasher.finalize().into_bytes(); - let expect = &res[..RESTLS_APPDATA_HMAC_LEN]; - let application_data = &in_buf[5..]; - if application_data.len() < RESTLS_APPDATA_HMAC_LEN { - return Err(anyhow!( - "reject: application data too short to contain an auth header" - )); - } - let actual = &application_data[..RESTLS_APPDATA_HMAC_LEN]; - if expect == actual { - debug!("tls13 challenge responded"); - Ok(()) - } else { - Err(anyhow!( - "reject: incorrect application auth header, expect: {:?}, actual {:?}", - expect, - actual - )) - } - } - fn handle_tls12_outbound(&self, outbound: &mut TLSStream, flow: &mut TLS12Flow) -> Result<()> { let rtype = outbound.codec().peek_record_type()?; match rtype { @@ -346,7 +441,7 @@ impl<'a> TryHandshake<'a> { self.check_tls12_session_ticket()?; } self.prepare_server_auth(outbound); - debug!("sending tls12 server auth to client"); + debug!("[{}]sending tls12 server auth to client", self.id); } RECORD_HANDSHAKE => (), _ => { @@ -400,10 +495,13 @@ impl<'a> TryHandshake<'a> { ) -> Result<()> { let mut flow = TLS12Flow::Initial; loop { - debug!("flow {:?}", flow); + debug!("[{}]flow {:?}", self.id, flow); select! { ret = self.read_from_stream(outbound) => { - ret?; + match ret { + Err(_) => return Ok(()), + _ => (), + } self.handle_tls12_outbound(outbound, &mut flow)?; outbound .codec_mut() @@ -429,7 +527,7 @@ impl<'a> TryHandshake<'a> { } fn check_tls12_session_ticket(&self) -> Result<()> { - let mut hasher = HmacSha1::new_from_slice(self.restls_password).expect("unexpected"); + let mut hasher = self.restls_hmac(); let client_hello = self.client_hello.as_ref().unwrap(); hasher.update(&client_hello.session_ticket); let actual_hash = hasher.finalize().into_bytes(); @@ -462,6 +560,65 @@ impl<'a> TryHandshake<'a> { res } + pub async fn copy_bidirectional( + &mut self, + inbound: &mut TLSStream, + outbound: &mut TcpStream, + ) -> Result<()> { + let mut out_buf = [0; BUF_SIZE]; + loop { + select! { + res = self.read_from_stream(inbound) => { + res?; + while inbound.codec().has_next() { + let record = inbound.codec_mut().next_record()?; + if record[0] == RECORD_ALERT { + return Ok(()); + } + let record = match self.read_app_data(record) { + Ok(record) => record, + Err(e) => { + return if self.server_hello.as_ref().unwrap().is_tls13 + && self.to_client_counter > 0 + && self.to_server_counter > 0 + { + // this will probably be a close notify. + // we'll ignore it. + Ok(()) + } else { + Err(e) + } + } + }; + if record.len() > 0 { + outbound.write_all(record).await.context("outbound.write_all failed: ")?; + } + } + inbound.codec_mut().reset(); + } + n = outbound.read(&mut out_buf[RESTLS_APPDATA_OFFSET..]) => { + let mut n = n.context("outbound.read failed: ")?; + if n == 0 { + return Ok(()); + } + let mut written = 0; + while written < n { + if (BUF_SIZE - written) < 1500 { + for i in 0..BUF_SIZE - written { + out_buf[i] = out_buf[written + i]; + } + n -= written; + written = 0; + } + let (new_written, packet_size) = self.write_app_data_header(&mut out_buf[written..], n - written)?; + inbound.get_mut().write_all(&out_buf[written..][..packet_size]).await.context("inbound.write_all failed: ")?; + written += new_written; + } + } + } + } + } + async fn try_handshake( &mut self, outbound: &mut TLSStream, @@ -479,11 +636,10 @@ impl<'a> TryHandshake<'a> { .await?; self.prepare_server_auth(outbound); outbound.codec_mut().next_record().unwrap(); - debug!("sending tls13 server auth to client"); + debug!("[{}]sending tls13 server auth to client", self.id); self.relay_to(inbound, outbound).await?; self.try_read_tl13_till_client_application_data(outbound, inbound) .await?; - self.check_tls13_application_data_auth(inbound.codec().peek_record()?)?; } else { self.try_read_tls12_till_client_application_data(outbound, inbound) .await?; @@ -492,28 +648,43 @@ impl<'a> TryHandshake<'a> { } } -pub async fn handle(options: Arc, inbound: TcpStream) -> Result<()> { +pub async fn handle(options: Arc, inbound: TcpStream, id: usize) -> Result<()> { let mut outbound = TLSCodec::new_inbound().framed( // TcpStream::connect("89.145.65.200:443") - TcpStream::connect(&options.server_hostname) + TcpStream::connect(&options.server_hostname) .await .context("cannot connect to outbound".to_owned() + &options.server_hostname)?, ); let mut inbound = TLSCodec::new_outbound().framed(inbound); - let mut try_handshake = TryHandshake { + let mut try_handshake = RestlsState { client_hello: None, server_hello: None, restls_password: options.password.as_bytes(), client_finished: Vec::new(), + to_client_counter: 0, + to_server_counter: 0, + to_client_raw_counter: 0, + mtu: options.mtu as usize, + min_record_len: options.mtu as usize, + id, }; match try_handshake .try_handshake(&mut outbound, &mut inbound) .await { Ok(()) => { - let outbound_proxy = TcpStream::connect(&options.forward_to).await?; - copy_bidirectional(inbound, outbound_proxy).await?; + let mut outbound_proxy = TcpStream::connect(&options.forward_to).await?; + match try_handshake + .copy_bidirectional(&mut inbound, &mut outbound_proxy) + .await + { + Err(e) => { + tracing::error!("restls data relay failed: {}", e); + tcp_rst(inbound.get_mut()).await? + } + _ => (), + } } Err(e) => { tracing::error!("handshake failed: {}", e); @@ -532,17 +703,20 @@ pub async fn start(options: Arc) -> Result<()> { options.listen, options.forward_to, ); + let mut counter = 0; loop { let (stream, _) = listener .accept() .await .context("failed to accept inbound stream")?; let options = options.clone(); + tokio::spawn(async move { - match handle(options, stream).await { - Err(e) => tracing::debug!("{}", e), + match handle(options, stream, counter).await { + Err(e) => tracing::debug!("[{}]{}", counter, e), Ok(_) => (), } }); + counter += 1; } } diff --git a/src/utils.rs b/src/utils.rs index fedb3f4..71dbea4 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -8,11 +8,7 @@ use std::{ io::{self, Cursor}, time::Duration, }; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::TcpStream, - select, -}; +use tokio::{io::AsyncWriteExt, net::TcpStream, select}; use tracing::debug; use tokio_util::codec::{Decoder, Framed}; @@ -247,53 +243,19 @@ pub(crate) fn xor_bytes(secret: &[u8], msg: &mut [u8]) { } } -const CONTENT_OFFSET: usize = 5 + RESTLS_APPDATA_HMAC_LEN; - -pub async fn copy_bidirectional(mut inbound: TLSStream, mut outbound: TcpStream) -> Result<()> { - let mut out_buf = [0; 0x2000]; - out_buf[..3].copy_from_slice(&[0x17, 0x03, 0x03]); - while inbound.codec().has_next() { - outbound - .write_all( - &inbound - .codec_mut() - .next_record() - .expect("unexpected error: this record should have been checked") - [CONTENT_OFFSET..], - ) - .await?; - } +const CONTENT_OFFSET: usize = 5 + RESTLS_APPDATA_HMAC_LEN + 2; - inbound.codec_mut().reset(); +pub async fn tcp_rst(stream: &mut TcpStream) -> Result<()> { + stream.set_linger(Some(Duration::from_secs(0)))?; + stream.shutdown().await?; + return Ok(()); +} - loop { - select! { - res = inbound.next() => { - match res { - Some(Ok(_)) => (), - None => { - return Ok(()); - } - Some(Err(e)) => { - return Err(e); - } - } - while inbound.codec().has_next() { - outbound - .write_all(&inbound.codec_mut().next_record().expect("todo: verification")[CONTENT_OFFSET..]) - .await?; - } - inbound.codec_mut().reset(); - } - n = outbound.read(&mut out_buf[CONTENT_OFFSET..]) => { - let n = n?; - if n == 0 { - return Ok(()); - } - out_buf[3..5].copy_from_slice(&(n as u16 + 8).to_be_bytes()); - inbound.get_mut().write_all(&out_buf[..n+CONTENT_OFFSET]).await?; - } - } +async fn inbound_read_helper(inbound: &mut TLSStream) -> Option> { + if inbound.codec().has_next() { + Some(Ok(())) + } else { + inbound.next().await } } @@ -347,18 +309,16 @@ pub async fn copy_bidirectional_fallback( match res { Some(Ok(_)) => (), Some(Err(root_cause)) => { - match root_cause.downcast_ref::() { + return match root_cause.downcast_ref::() { Some(e) => { match e.kind() { io::ErrorKind::ConnectionReset => { - inbound.get_mut().set_linger(Some(Duration::from_secs(0)))?; - inbound.get_mut().shutdown().await?; - return Ok(()); + tcp_rst(inbound.get_mut()).await } - _ => return Err(root_cause), + _ => Err(root_cause), } }, - None => return Err(root_cause), + None => Err(root_cause), } } None => {