diff --git a/src/extensions/client.rs b/src/extensions/client.rs index aebef546..ea81b721 100644 --- a/src/extensions/client.rs +++ b/src/extensions/client.rs @@ -1,3 +1,4 @@ +use crate::extensions::common::KeyShare; use crate::extensions::ExtensionType; use crate::signature_schemes::SignatureScheme; @@ -21,10 +22,7 @@ pub enum ClientExtension<'a> { SupportedGroups { supported_groups: Vec, }, - KeyShare { - group: NamedGroup, - opaque: &'a [u8], - }, + KeyShare(KeyShare<'a>), PreSharedKey { identities: Vec<&'a [u8], 4>, hash_size: usize, @@ -113,16 +111,7 @@ impl ClientExtension<'_> { Ok(()) }) } - ClientExtension::KeyShare { group, opaque } => { - buf.with_u16_length(|buf| { - // one key-share - buf.push_u16(*group as u16) - .map_err(|_| TlsError::EncodeError)?; - - buf.with_u16_length(|buf| buf.extend_from_slice(opaque.as_ref())) - .map_err(|_| TlsError::EncodeError) - }) - } + ClientExtension::KeyShare(key_share) => key_share.encode(buf), ClientExtension::PreSharedKey { identities, hash_size, diff --git a/src/extensions/common.rs b/src/extensions/common.rs index 8a12ddba..eadb71c4 100644 --- a/src/extensions/common.rs +++ b/src/extensions/common.rs @@ -1,5 +1,21 @@ +use crate::buffer::CryptoBuffer; use crate::named_groups::NamedGroup; use crate::parse_buffer::{ParseBuffer, ParseError}; +use crate::TlsError; + +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct KeyShare<'a>(pub(crate) KeyShareEntry<'a>); + +impl<'a> KeyShare<'a> { + pub fn parse(buf: &mut ParseBuffer<'a>) -> Result, ParseError> { + Ok(KeyShare(KeyShareEntry::parse(buf)?)) + } + + pub fn encode(&self, buf: &mut CryptoBuffer) -> Result<(), TlsError> { + self.0.encode(buf) + } +} #[derive(Debug)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] @@ -20,13 +36,25 @@ impl Clone for KeyShareEntry<'_> { impl<'a> KeyShareEntry<'a> { pub fn parse(buf: &mut ParseBuffer<'a>) -> Result, ParseError> { let group = NamedGroup::of(buf.read_u16()?).ok_or(ParseError::InvalidData)?; + let opaque_len = buf.read_u16()?; let opaque = buf.slice(opaque_len as usize)?; + Ok(Self { group, opaque: opaque.as_slice(), }) } + + pub fn encode(&self, buf: &mut CryptoBuffer) -> Result<(), TlsError> { + buf.with_u16_length(|buf| { + buf.push_u16(self.group as u16) + .map_err(|_| TlsError::EncodeError)?; + + buf.with_u16_length(|buf| buf.extend_from_slice(self.opaque)) + .map_err(|_| TlsError::EncodeError) + }) + } } #[cfg(test)] diff --git a/src/extensions/server.rs b/src/extensions/server.rs index b41a2ff9..0a3bbd23 100644 --- a/src/extensions/server.rs +++ b/src/extensions/server.rs @@ -1,5 +1,5 @@ use crate::alert::{AlertDescription, AlertLevel}; -use crate::extensions::common::KeyShareEntry; +use crate::extensions::common::KeyShare; use crate::extensions::ExtensionType; use crate::parse_buffer::{ParseBuffer, ParseError}; use crate::supported_versions::ProtocolVersion; @@ -62,16 +62,6 @@ impl<'a, 'b> Iterator for ServerExtensionParserIterator<'a, 'b> { } } -#[derive(Debug)] -#[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub struct KeyShare<'a>(pub(crate) KeyShareEntry<'a>); - -impl<'a> KeyShare<'a> { - pub fn parse(buf: &mut ParseBuffer<'a>) -> Result, ParseError> { - Ok(KeyShare(KeyShareEntry::parse(buf)?)) - } -} - impl<'a> ServerExtension<'a> { pub fn parse( buf: &mut ParseBuffer<'a>, diff --git a/src/handshake/client_hello.rs b/src/handshake/client_hello.rs index 9e971aee..8963430a 100644 --- a/src/handshake/client_hello.rs +++ b/src/handshake/client_hello.rs @@ -7,6 +7,7 @@ use p256::EncodedPoint; use crate::buffer::*; use crate::config::{TlsCipherSuite, TlsConfig}; use crate::extensions::client::{ClientExtension, PskKeyExchangeMode}; +use crate::extensions::common::{KeyShare, KeyShareEntry}; use crate::handshake::{Random, LEGACY_VERSION}; use crate::named_groups::NamedGroup; use crate::supported_versions::TLS13; @@ -94,10 +95,10 @@ where } .encode(buf)?; - ClientExtension::KeyShare { + ClientExtension::KeyShare(KeyShare(KeyShareEntry { group: NamedGroup::Secp256r1, opaque: public_key, - } + })) .encode(buf)?; if let Some(server_name) = self.config.server_name {