Skip to content

Commit

Permalink
Add BoundedBytes wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Rexagon committed Nov 28, 2023
1 parent d929673 commit 1903c79
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "tl-proto"
description = "A collection of traits for working with TL serialization/deserialization"
authors = ["Ivan Kalinin <i.kalinin@dexpa.io>"]
repository = "https://github.com/broxus/tl-proto"
version = "0.4.2"
version = "0.4.3"
edition = "2021"
include = ["src/**/*.rs", "README.md"]
license = "MIT"
Expand Down
111 changes: 110 additions & 1 deletion src/seq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,114 @@ where
}
}

/// Bytes slice with a max length bound.
#[derive(Debug)]
#[repr(transparent)]
pub struct BoundedBytes<const N: usize>([u8]);

impl<const N: usize> BoundedBytes<N> {
/// Wraps a byte slice into a new type with length check.
#[inline]
pub const fn try_wrap(bytes: &[u8]) -> Option<&Self> {
if bytes.len() <= N {
// SAFETY: `BoundedBytes` has the same repr as `[u8]`
Some(unsafe { &*(bytes as *const [u8] as *const BoundedBytes<N>) })
} else {
None
}
}

/// Wraps a byte slice into a new type without any checks.
///
/// # Safety
///
/// The following must be true:
/// - `bytes` must have length not greater than `N`
#[inline]
pub unsafe fn wrap_unchecked(bytes: &[u8]) -> &Self {
// SAFETY: `BoundedBytes` has the same repr as `[u8]`
unsafe { &*(bytes as *const [u8] as *const BoundedBytes<N>) }
}
}

impl<const N: usize> AsRef<[u8]> for BoundedBytes<N> {
#[inline]
fn as_ref(&self) -> &[u8] {
&self.0
}
}

impl<const N: usize> AsMut<[u8]> for BoundedBytes<N> {
#[inline]
fn as_mut(&mut self) -> &mut [u8] {
&mut self.0
}
}

impl<const N: usize> std::ops::Deref for BoundedBytes<N> {
type Target = [u8];

#[inline]
fn deref(&self) -> &Self::Target {
&self.0
}
}

impl<const N: usize> std::ops::DerefMut for BoundedBytes<N> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

impl<'a, const N: usize> TlRead<'a> for &'a BoundedBytes<N> {
type Repr = Bare;

#[inline]
fn read_from(packet: &'a [u8], offset: &mut usize) -> TlResult<Self> {
fn read_bytes_with_max_len<'a>(
packet: &'a [u8],
max_len: usize,
offset: &mut usize,
) -> TlResult<&'a [u8]> {
let current_offset = *offset;
let (prefix_len, len, padding) = ok!(compute_bytes_meta(packet, current_offset));
if len > max_len {
return Err(TlError::InvalidData);
}

let result = unsafe {
std::slice::from_raw_parts(packet.as_ptr().add(current_offset + prefix_len), len)
};

*offset += prefix_len + len + padding;
Ok(result)
}

let result = ok!(read_bytes_with_max_len(packet, N, offset));

// SAFETY: `len <= N`
Ok(unsafe { BoundedBytes::wrap_unchecked(result) })
}
}

impl<const N: usize> TlWrite for &BoundedBytes<N> {
type Repr = Bare;

#[inline(always)]
fn max_size_hint(&self) -> usize {
bytes_max_size_hint(self.len())
}

#[inline(always)]
fn write_to<P>(&self, packet: &mut P)
where
P: TlPacket,
{
write_bytes(self, packet)
}
}

/// Helper type which is used to represent field value as bytes.
#[derive(Debug, Clone)]
pub struct IntermediateBytes<T>(pub T);
Expand Down Expand Up @@ -425,8 +533,9 @@ impl<R> PartialEq for RawBytes<'_, R> {

impl<R> Copy for RawBytes<'_, R> {}
impl<R> Clone for RawBytes<'_, R> {
#[inline]
fn clone(&self) -> Self {
Self(self.0, std::marker::PhantomData)
*self
}
}

Expand Down
20 changes: 19 additions & 1 deletion test_suite/tests/tl_read.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#[allow(dead_code)]
mod tests {
use tl_proto::{BytesMeta, TlError, TlRead, TlResult};
use tl_proto::{BoundedBytes, BytesMeta, TlError, TlRead, TlResult};

#[derive(TlRead)]
struct SimpleStruct {
Expand Down Expand Up @@ -166,4 +166,22 @@ mod tests {
Err(TlError::UnexpectedEof)
));
}

#[test]
fn bounded_bytes() {
#[derive(TlRead)]
struct Data<'tl> {
bytes: &'tl BoundedBytes<4>,
}

let packet = [4, 1, 2, 3, 4, 0, 0, 0];
let Data { bytes } = tl_proto::deserialize(&packet).unwrap();
assert_eq!(bytes.as_ref(), &[1, 2, 3, 4]);

let big_packet = [5, 1, 2, 3, 4, 5, 0, 0];
assert!(matches!(
tl_proto::deserialize::<Data>(&big_packet),
Err(TlError::InvalidData)
));
}
}

0 comments on commit 1903c79

Please sign in to comment.