Skip to content

Commit

Permalink
Tidy up parsing of read operations
Browse files Browse the repository at this point in the history
  • Loading branch information
paolobarbolini authored and Jarema committed Jul 25, 2023
1 parent c891cd9 commit f913c83
Showing 1 changed file with 114 additions and 105 deletions.
219 changes: 114 additions & 105 deletions async-nats/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,10 @@ impl Connection {
/// Attempts to read a server operation from the read buffer.
/// Returns `None` if there is not enough data to parse an entire operation.
pub(crate) fn try_read_op(&mut self) -> Result<Option<ServerOp>, io::Error> {
let maybe_len = memchr::memmem::find(&self.buffer, b"\r\n");
if maybe_len.is_none() {
return Ok(None);
}

let len = maybe_len.unwrap();
let len = match memchr::memmem::find(&self.buffer, b"\r\n") {
Some(len) => len,
None => return Ok(None),
};

if self.buffer.starts_with(b"+OK") {
self.buffer.advance(len + 2);
Expand All @@ -88,7 +86,7 @@ impl Connection {
let description = str::from_utf8(&self.buffer[5..len])
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?
.trim_matches('\'')
.to_string();
.to_owned();

self.buffer.advance(len + 2);

Expand All @@ -109,27 +107,34 @@ impl Connection {
let mut args = line.split(' ').filter(|s| !s.is_empty());

// Parse the operation syntax: MSG <subject> <sid> [reply-to] <#bytes>
let subject = args.next();
let sid = args.next();
let mut reply_to = args.next();
let mut payload_len = args.next();
if payload_len.is_none() {
std::mem::swap(&mut reply_to, &mut payload_len);
}

if subject.is_none() || sid.is_none() || payload_len.is_none() || args.next().is_some()
{
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid number of arguments after MSG",
));
}
let (subject, sid, reply_to, payload_len) = match (
args.next(),
args.next(),
args.next(),
args.next(),
args.next(),
) {
(Some(subject), Some(sid), Some(reply_to), Some(payload_len), None) => {
(subject, sid, Some(reply_to), payload_len)
}
(Some(subject), Some(sid), Some(payload_len), None, None) => {
(subject, sid, None, payload_len)
}
_ => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid number of arguments after MSG",
))
}
};

let sid = u64::from_str(sid.unwrap())
let sid = sid
.parse::<u64>()
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;

// Parse the number of payload bytes.
let payload_len = usize::from_str(payload_len.unwrap())
let payload_len = payload_len
.parse::<usize>()
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;

// Return early without advancing if there is not enough data read the entire
Expand All @@ -138,18 +143,19 @@ impl Connection {
return Ok(None);
}

let subject = subject.unwrap().to_owned();
let reply_to = reply_to.map(String::from);
let subject = subject.to_owned();
let reply_to = reply_to.map(ToOwned::to_owned);

self.buffer.advance(len + 2);
let payload = self.buffer.split_to(payload_len).freeze();
self.buffer.advance(2);

let length = payload_len
+ reply_to.as_ref().map(|reply| reply.len()).unwrap_or(0)
+ subject.len();
return Ok(Some(ServerOp::Message {
sid,
length: payload_len
+ reply_to.as_ref().map(|reply| reply.len()).unwrap_or(0)
+ subject.len(),
length,
reply: reply_to,
headers: None,
subject,
Expand All @@ -165,44 +171,49 @@ impl Connection {
let mut args = line.split_whitespace().filter(|s| !s.is_empty());

// <subject> <sid> [reply-to] <# header bytes><# total bytes>
let subject = args.next();
let sid = args.next();
let mut reply_to = args.next();
let mut num_header_bytes = args.next();
let mut num_bytes = args.next();
if num_bytes.is_none() {
std::mem::swap(&mut num_header_bytes, &mut num_bytes);
std::mem::swap(&mut reply_to, &mut num_header_bytes);
}

if subject.is_none()
|| sid.is_none()
|| num_header_bytes.is_none()
|| num_bytes.is_none()
|| args.next().is_some()
{
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid number of arguments after HMSG",
));
}
let (subject, sid, reply_to, header_len, total_len) = match (
args.next(),
args.next(),
args.next(),
args.next(),
args.next(),
args.next(),
) {
(
Some(subject),
Some(sid),
Some(reply_to),
Some(header_len),
Some(total_len),
None,
) => (subject, sid, Some(reply_to), header_len, total_len),
(Some(subject), Some(sid), Some(header_len), Some(total_len), None, None) => {
(subject, sid, None, header_len, total_len)
}
_ => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid number of arguments after HMSG",
))
}
};

// Convert the slice into an owned string.
let subject = subject.unwrap().to_string();
let subject = subject.to_owned();

// Parse the subject ID.
let sid = u64::from_str(sid.unwrap()).map_err(|_| {
let sid = sid.parse::<u64>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"cannot parse sid argument after HMSG",
)
})?;

// Convert the slice into an owned string.
let reply_to = reply_to.map(ToString::to_string);
let reply_to = reply_to.map(ToOwned::to_owned);

// Parse the number of payload bytes.
let num_header_bytes = usize::from_str(num_header_bytes.unwrap()).map_err(|_| {
let header_len = header_len.parse::<usize>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"cannot parse the number of header bytes argument after \
Expand All @@ -211,104 +222,102 @@ impl Connection {
})?;

// Parse the number of payload bytes.
let num_bytes = usize::from_str(num_bytes.unwrap()).map_err(|_| {
let total_len = total_len.parse::<usize>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"cannot parse the number of bytes argument after HMSG",
)
})?;

if num_bytes < num_header_bytes {
if total_len < header_len {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"number of header bytes was greater than or equal to the \
total number of bytes after HMSG",
));
}

if len + num_bytes + 4 > self.buffer.remaining() {
if len + total_len + 4 > self.buffer.remaining() {
return Ok(None);
}

self.buffer.advance(len + 2);
let buffer = self.buffer.split_to(num_header_bytes).freeze();
let payload = self.buffer.split_to(num_bytes - num_header_bytes).freeze();
let header = self.buffer.split_to(header_len);
let payload = self.buffer.split_to(total_len - header_len).freeze();
self.buffer.advance(2);

let mut lines = std::str::from_utf8(&buffer).unwrap().lines().peekable();
let mut lines = std::str::from_utf8(&header)
.map_err(|_| {
io::Error::new(io::ErrorKind::InvalidInput, "header isn't valid utf-8")
})?
.lines()
.peekable();
let version_line = lines.next().ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
})?;

if !version_line.starts_with("NATS/1.0") {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"header version line does not begin with nats/1.0",
));
}
let version_line_suffix = version_line
.strip_prefix("NATS/1.0")
.map(str::trim)
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"header version line does not begin with `NATS/1.0`",
)
})?;

let mut maybe_status: Option<StatusCode> = None;
let mut maybe_description: Option<String> = None;
if let Some(slice) = version_line.get("NATS/1.0".len()..).map(|s| s.trim()) {
match slice.split_once(' ') {
Some((status, description)) => {
if !status.is_empty() {
maybe_status = Some(status.trim().parse().map_err(|_| {
std::io::Error::new(
io::ErrorKind::Other,
"could not covert Description header into header value",
)
})?);
}
if !description.is_empty() {
maybe_description = Some(description.trim().to_string());
}
}
None => {
if !slice.is_empty() {
maybe_status = Some(slice.trim().parse().map_err(|_| {
std::io::Error::new(
io::ErrorKind::Other,
"could not covert Description header into header value",
)
})?);
}
}
}
}
let (status, description) = version_line_suffix
.split_once(' ')
.map(|(status, description)| (status.trim(), description.trim()))
.unwrap_or((version_line_suffix, ""));
let status = if !status.is_empty() {
Some(status.parse::<StatusCode>().map_err(|_| {
std::io::Error::new(io::ErrorKind::Other, "could not parse status parameter")
})?)
} else {
None
};
let description = if !description.is_empty() {
Some(description.to_owned())
} else {
None
};

let mut headers = HeaderMap::new();
while let Some(line) = lines.next() {
if line.is_empty() {
continue;
}

let (key, value) = line.split_once(':').ok_or_else(|| {
let (name, value) = line.split_once(':').ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
})?;

let mut value = value.to_owned();
let name = HeaderName::from_str(name)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;

// Read the header value, which might have been split into multiple lines
// `trim_start` and `trim_end` do the same job as doing `value.trim().to_owned()` at the end, but without a reallocation
let mut value = value.trim_start().to_owned();
while let Some(v) = lines.next_if(|s| s.starts_with(char::is_whitespace)) {
value.push_str(v);
}
value.truncate(value.trim_end().len());

let name = HeaderName::from_str(key)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;

headers.append(name, value.trim().to_string());
headers.append(name, value);
}

return Ok(Some(ServerOp::Message {
length: reply_to.as_ref().map(|reply| reply.len()).unwrap_or(0)
length: reply_to.as_ref().map_or(0, |reply| reply.len())
+ subject.len()
+ num_bytes,
+ total_len,
sid,
reply: reply_to,
subject,
headers: Some(headers),
payload,
status: maybe_status,
description: maybe_description,
status,
description,
}));
}

Expand Down

0 comments on commit f913c83

Please sign in to comment.