diff --git a/src/client_hello.rs b/src/client_hello.rs index 0174cdb..9f3b6ee 100644 --- a/src/client_hello.rs +++ b/src/client_hello.rs @@ -92,33 +92,37 @@ impl ClientHello { skip_length_padded::<2, _>(buf); // cipher suites skip_length_padded::<1, _>(buf); // compression methods - buf.advance(2); // Extensions Length let mut session_ticket = Vec::new(); let mut client_supports_tls_13 = false; let mut psk = Vec::new(); let mut key_share = Vec::new(); - while buf.has_remaining() { - let ext = buf.get_u16(); - match ext { - EXTENSION_SESSION_TICKET => { - extend_from_length_prefixed::<2, _>(buf, &mut session_ticket); - } - EXTENSION_SUPPORTED_VERSIONS => { - client_supports_tls_13 = Self::read_supported_version(buf); - } - EXTENSION_PRE_SHARED_KEY => { - psk = Self::read_psk(buf); - } - EXTENSION_KEY_SHARE => { - key_share = Self::read_key_share(buf); - debug!("client key_share {:?}", key_share); - } - _ => { - skip_length_padded::<2, _>(buf); + u16_length_prefixed(buf, |mut ext_section| { + while ext_section.has_remaining() { + let ext = ext_section.get_u16(); + match ext { + EXTENSION_SESSION_TICKET => { + extend_from_length_prefixed::<2, _>(&mut ext_section, &mut session_ticket); + debug!("session_ticket: {:?}", session_ticket); + } + EXTENSION_SUPPORTED_VERSIONS => { + client_supports_tls_13 = Self::read_supported_version(&mut ext_section); + } + EXTENSION_PRE_SHARED_KEY => { + psk = Self::read_psk(&mut ext_section); + debug!("psk: {:?}", psk); + } + EXTENSION_KEY_SHARE => { + key_share = Self::read_key_share(&mut ext_section); + debug!("client key_share {:?}", key_share); + } + _ => { + skip_length_padded::<2, _>(&mut ext_section); + } } } - } + }); + if !client_supports_tls_13 { return Err(anyhow!("reject: client must support tls 1.3")); } diff --git a/src/restls.rs b/src/restls.rs index 298c4f8..972320b 100644 --- a/src/restls.rs +++ b/src/restls.rs @@ -308,6 +308,17 @@ impl<'a> RestlsState<'a> { (real_data_len, padding, command) } + fn parrot_tls12_nonce(&self, record: &mut [u8]) { + if self.parrot_tls12_gcm { + debug!( + "parrot gcm nounce {:?} {}", + &record[5..13], + self.to_client_counter + 1 + ); + record[5..13].copy_from_slice(&(self.to_client_counter + 1).to_be_bytes()); + } + } + fn prepare_app_data_header( &mut self, out_buf: &mut DoubleCursorBuf, @@ -320,10 +331,7 @@ impl<'a> RestlsState<'a> { let payload_len = (record.len() - 5) as u16; record[3..5].copy_from_slice(&payload_len.to_be_bytes()); - if self.parrot_tls12_gcm { - record[5..13].copy_from_slice(&(self.to_client_counter + 1).to_be_bytes()); - } - + self.parrot_tls12_nonce(record); let mut hmac_auth = self.restls_appdata_auth_hmac(true); let header_offset = self.restls_header_offset(true); hmac_auth.update(&record[..header_offset]); @@ -357,20 +365,26 @@ impl<'a> RestlsState<'a> { self.to_client_counter += 1; } - async fn read_from_stream(&self, stream: &mut TLSStream) -> Result<()> { + async fn read_from_stream(&self, stream: &mut TLSStream, eof_pending: bool) -> Result<()> { if stream.codec().has_next() { Ok(()) } else { match stream.next().await { - None => Err(anyhow!("unexpected eof")), + None => { + if eof_pending { + std::future::pending().await + } else { + Err(anyhow!("unexpected eof")) + } + } Some(res) => res, } } } async fn try_read_client_hello(&mut self, inbound: &mut TLSStream) -> Result<()> { - match self.read_from_stream(inbound).await { - Err(e) => return Err(e.context("failed to read client hello: ")), + match self.read_from_stream(inbound, false).await { + Err(e) => return Err(anyhow!("failed to read client hello: {:?}", e)), Ok(()) => (), }; let rtype = inbound.codec().peek_record_type()?; @@ -392,7 +406,7 @@ impl<'a> RestlsState<'a> { } async fn try_read_server_hello(&mut self, outbound: &mut TLSStream) -> Result<()> { - self.read_from_stream(outbound) + self.read_from_stream(outbound, false) .await .context("failed to read server hello: ")?; let rtype = outbound.codec().peek_record_type()?; @@ -424,7 +438,7 @@ impl<'a> RestlsState<'a> { ) -> Result<()> { let mut ccs_from_server = false; loop { - self.read_from_stream(outbound).await?; + self.read_from_stream(outbound, false).await?; let rtype = outbound.codec().peek_record_type()?; match rtype { @@ -482,7 +496,7 @@ impl<'a> RestlsState<'a> { let mut ccs_from_client = false; loop { select! { - res = self.read_from_stream(inbound) => { + res = self.read_from_stream(inbound, false) => { let _ = res?; match inbound.codec().peek_record_type()? { RECORD_CCS if !ccs_from_client => { @@ -510,7 +524,7 @@ impl<'a> RestlsState<'a> { inbound.codec_mut().next_record().expect("unexpected error: record has been checked"); self.relay_to(outbound, inbound).await?; } - res = self.read_from_stream(outbound) => { + res = self.read_from_stream(outbound, false) => { let _ = res?; outbound.codec_mut().skip_to_end(); self.relay_to(inbound, outbound).await?; @@ -648,7 +662,7 @@ impl<'a> RestlsState<'a> { inbound.codec().has_next() ); select! { - ret = self.read_from_stream(outbound) => { + ret = self.read_from_stream(outbound, false) => { match ret { Err(_) => return Ok(()), _ => (), @@ -660,7 +674,7 @@ impl<'a> RestlsState<'a> { .expect("unexpected error: record has been checked"); self.relay_to(inbound, outbound).await?; } - ret = self.read_from_stream(inbound) => { + ret = self.read_from_stream(inbound, false) => { ret?; self.handle_tls12_inbound(inbound, &mut flow)?; if flow.is_client_0x17() { @@ -791,14 +805,52 @@ impl<'a> RestlsState<'a> { Ok(WriteToServerResult::Ok((read_record, need_respond))) } + async fn write_handshake_data_to_client( + &mut self, + inbound: &mut TLSStream, + outbound_handshake: &mut TLSStream, + close_notify_cache: &mut Vec, + ) -> Result<()> { + while outbound_handshake.codec().has_next() { + let record = match outbound_handshake.codec_mut().next_record() { + Ok(r) => r, + Err(e) => { + debug!("[{}]outbound_handshake read error {}", self.id, e); + return Ok(()); + } + }; + debug!("[{}]read from outbound_handshake {:?}", self.id, record); + if record.len() < 50 { + close_notify_cache.extend_from_slice(record); + return Ok(()); + } else { + self.parrot_tls12_nonce(record); + inbound.get_mut().write_all(record).await?; + self.to_client_counter += 1; + } + } + Ok(()) + } + + async fn write_close_notify(&mut self, close_notify: &mut Vec, inbound: &mut TcpStream) { + if close_notify.len() > 0 { + debug!("[{}]sending close-notify {:?}", self.id, close_notify); + self.parrot_tls12_nonce(close_notify); + let _ = inbound.write(&close_notify).await; + } + } + pub async fn copy_bidirectional( &mut self, inbound: &mut TLSStream, outbound: &mut TcpStream, + outbound_handshake: &mut TLSStream, ) -> Result<()> { let mut out_buf = DoubleCursorBuf::new(self.restls_data_offset(true)); let mut awaiting = false; let mut need_respond = 0; + let mut close_notify = Vec::new(); + let mut outbound_handshake_closed = false; async fn read_if_has_capacity( outbound: &mut TcpStream, out_buf: &mut DoubleCursorBuf, @@ -811,10 +863,17 @@ impl<'a> RestlsState<'a> { } loop { select! { - res = self.read_from_stream(inbound) => { + res = self.read_from_stream(outbound_handshake, true) => { + debug!("[{}]read from outbound_handshake got {:?}", self.id, res); + self.write_handshake_data_to_client(inbound, outbound_handshake, &mut close_notify).await?; + } + res = self.read_from_stream(inbound, false) => { res?; match self.write_client_data_to_server(inbound, outbound).await? { - WriteToServerResult::MaybeCloseNotify => return Ok(()), + WriteToServerResult::MaybeCloseNotify => { + self.write_close_notify(&mut close_notify, inbound.get_mut()).await; + return Ok(()) + }, WriteToServerResult::Ok((read, respond)) => { awaiting = read == 0; debug!("[{}]read {} and set awaiting to {}", self.id, read, awaiting); @@ -825,6 +884,7 @@ impl<'a> RestlsState<'a> { n = read_if_has_capacity(outbound, &mut out_buf) => { let n = n.context("outbound.read failed: ")?; if n == 0 { + self.write_close_notify(&mut close_notify, inbound.get_mut()).await; return Ok(()); } out_buf.advance_back(n); @@ -874,6 +934,14 @@ impl<'a> RestlsState<'a> { } need_respond = 0; } + + if !outbound_handshake_closed && self.to_client_counter + self.to_server_counter > 5 { + outbound_handshake_closed = true; + match outbound_handshake.get_mut().shutdown().await { + Err(e) => debug!("[{}]outbound_handshake shutdown got {:?}", self.id, e), + Ok(_) => (), + } + } } } @@ -935,7 +1003,7 @@ pub async fn handle(options: Arc, inbound: TcpStream, id: usize) -> Result< Ok(()) => { let mut outbound_proxy = TcpStream::connect(&options.forward_to).await?; match try_handshake - .copy_bidirectional(&mut inbound, &mut outbound_proxy) + .copy_bidirectional(&mut inbound, &mut outbound_proxy, &mut outbound) .await { Err(e) => {