Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tidy up ServerOp parsing #1052

Merged
merged 1 commit into from
Jul 25, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 114 additions & 105 deletions async-nats/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,10 @@ impl Connection {
/// Attempts to read a server operation from the read buffer.
/// Returns `None` if there is not enough data to parse an entire operation.
pub(crate) fn try_read_op(&mut self) -> Result<Option<ServerOp>, io::Error> {
let maybe_len = memchr::memmem::find(&self.buffer, b"\r\n");
if maybe_len.is_none() {
return Ok(None);
}

let len = maybe_len.unwrap();
let len = match memchr::memmem::find(&self.buffer, b"\r\n") {
Some(len) => len,
None => return Ok(None),
};

if self.buffer.starts_with(b"+OK") {
self.buffer.advance(len + 2);
Expand All @@ -88,7 +86,7 @@ impl Connection {
let description = str::from_utf8(&self.buffer[5..len])
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?
.trim_matches('\'')
.to_string();
.to_owned();

self.buffer.advance(len + 2);

Expand All @@ -109,27 +107,34 @@ impl Connection {
let mut args = line.split(' ').filter(|s| !s.is_empty());

// Parse the operation syntax: MSG <subject> <sid> [reply-to] <#bytes>
let subject = args.next();
let sid = args.next();
let mut reply_to = args.next();
let mut payload_len = args.next();
if payload_len.is_none() {
std::mem::swap(&mut reply_to, &mut payload_len);
}

if subject.is_none() || sid.is_none() || payload_len.is_none() || args.next().is_some()
{
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid number of arguments after MSG",
));
}
let (subject, sid, reply_to, payload_len) = match (
args.next(),
args.next(),
args.next(),
args.next(),
args.next(),
) {
(Some(subject), Some(sid), Some(reply_to), Some(payload_len), None) => {
(subject, sid, Some(reply_to), payload_len)
}
(Some(subject), Some(sid), Some(payload_len), None, None) => {
(subject, sid, None, payload_len)
}
_ => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid number of arguments after MSG",
))
}
};

let sid = u64::from_str(sid.unwrap())
let sid = sid
.parse::<u64>()
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;

// Parse the number of payload bytes.
let payload_len = usize::from_str(payload_len.unwrap())
let payload_len = payload_len
.parse::<usize>()
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;

// Return early without advancing if there is not enough data read the entire
Expand All @@ -138,18 +143,19 @@ impl Connection {
return Ok(None);
}

let subject = subject.unwrap().to_owned();
let reply_to = reply_to.map(String::from);
let subject = subject.to_owned();
let reply_to = reply_to.map(ToOwned::to_owned);

self.buffer.advance(len + 2);
let payload = self.buffer.split_to(payload_len).freeze();
self.buffer.advance(2);

let length = payload_len
+ reply_to.as_ref().map(|reply| reply.len()).unwrap_or(0)
+ subject.len();
return Ok(Some(ServerOp::Message {
sid,
length: payload_len
+ reply_to.as_ref().map(|reply| reply.len()).unwrap_or(0)
+ subject.len(),
length,
reply: reply_to,
headers: None,
subject,
Expand All @@ -165,44 +171,49 @@ impl Connection {
let mut args = line.split_whitespace().filter(|s| !s.is_empty());

// <subject> <sid> [reply-to] <# header bytes><# total bytes>
let subject = args.next();
let sid = args.next();
let mut reply_to = args.next();
let mut num_header_bytes = args.next();
let mut num_bytes = args.next();
if num_bytes.is_none() {
std::mem::swap(&mut num_header_bytes, &mut num_bytes);
std::mem::swap(&mut reply_to, &mut num_header_bytes);
}

if subject.is_none()
|| sid.is_none()
|| num_header_bytes.is_none()
|| num_bytes.is_none()
|| args.next().is_some()
{
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid number of arguments after HMSG",
));
}
let (subject, sid, reply_to, header_len, total_len) = match (
args.next(),
args.next(),
args.next(),
args.next(),
args.next(),
args.next(),
) {
(
Some(subject),
Some(sid),
Some(reply_to),
Some(header_len),
Some(total_len),
None,
) => (subject, sid, Some(reply_to), header_len, total_len),
(Some(subject), Some(sid), Some(header_len), Some(total_len), None, None) => {
(subject, sid, None, header_len, total_len)
}
_ => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid number of arguments after HMSG",
))
}
};

// Convert the slice into an owned string.
let subject = subject.unwrap().to_string();
let subject = subject.to_owned();

// Parse the subject ID.
let sid = u64::from_str(sid.unwrap()).map_err(|_| {
let sid = sid.parse::<u64>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"cannot parse sid argument after HMSG",
)
})?;

// Convert the slice into an owned string.
let reply_to = reply_to.map(ToString::to_string);
let reply_to = reply_to.map(ToOwned::to_owned);

// Parse the number of payload bytes.
let num_header_bytes = usize::from_str(num_header_bytes.unwrap()).map_err(|_| {
let header_len = header_len.parse::<usize>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"cannot parse the number of header bytes argument after \
Expand All @@ -211,104 +222,102 @@ impl Connection {
})?;

// Parse the number of payload bytes.
let num_bytes = usize::from_str(num_bytes.unwrap()).map_err(|_| {
let total_len = total_len.parse::<usize>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"cannot parse the number of bytes argument after HMSG",
)
})?;

if num_bytes < num_header_bytes {
if total_len < header_len {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"number of header bytes was greater than or equal to the \
total number of bytes after HMSG",
));
}

if len + num_bytes + 4 > self.buffer.remaining() {
if len + total_len + 4 > self.buffer.remaining() {
return Ok(None);
}

self.buffer.advance(len + 2);
let buffer = self.buffer.split_to(num_header_bytes).freeze();
let payload = self.buffer.split_to(num_bytes - num_header_bytes).freeze();
let header = self.buffer.split_to(header_len);
let payload = self.buffer.split_to(total_len - header_len).freeze();
self.buffer.advance(2);

let mut lines = std::str::from_utf8(&buffer).unwrap().lines().peekable();
let mut lines = std::str::from_utf8(&header)
.map_err(|_| {
io::Error::new(io::ErrorKind::InvalidInput, "header isn't valid utf-8")
})?
.lines()
.peekable();
let version_line = lines.next().ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
})?;

if !version_line.starts_with("NATS/1.0") {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"header version line does not begin with nats/1.0",
));
}
let version_line_suffix = version_line
.strip_prefix("NATS/1.0")
.map(str::trim)
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"header version line does not begin with `NATS/1.0`",
)
})?;

let mut maybe_status: Option<StatusCode> = None;
let mut maybe_description: Option<String> = None;
if let Some(slice) = version_line.get("NATS/1.0".len()..).map(|s| s.trim()) {
match slice.split_once(' ') {
Some((status, description)) => {
if !status.is_empty() {
maybe_status = Some(status.trim().parse().map_err(|_| {
std::io::Error::new(
io::ErrorKind::Other,
"could not covert Description header into header value",
)
})?);
}
if !description.is_empty() {
maybe_description = Some(description.trim().to_string());
}
}
None => {
if !slice.is_empty() {
maybe_status = Some(slice.trim().parse().map_err(|_| {
std::io::Error::new(
io::ErrorKind::Other,
"could not covert Description header into header value",
)
})?);
}
}
}
}
let (status, description) = version_line_suffix
.split_once(' ')
.map(|(status, description)| (status.trim(), description.trim()))
.unwrap_or((version_line_suffix, ""));
let status = if !status.is_empty() {
Some(status.parse::<StatusCode>().map_err(|_| {
std::io::Error::new(io::ErrorKind::Other, "could not parse status parameter")
})?)
} else {
None
};
let description = if !description.is_empty() {
Some(description.to_owned())
} else {
None
};

let mut headers = HeaderMap::new();
while let Some(line) = lines.next() {
if line.is_empty() {
continue;
}

let (key, value) = line.split_once(':').ok_or_else(|| {
let (name, value) = line.split_once(':').ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
})?;

let mut value = value.to_owned();
let name = HeaderName::from_str(name)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;

// Read the header value, which might have been split into multiple lines
// `trim_start` and `trim_end` do the same job as doing `value.trim().to_owned()` at the end, but without a reallocation
let mut value = value.trim_start().to_owned();
while let Some(v) = lines.next_if(|s| s.starts_with(char::is_whitespace)) {
value.push_str(v);
}
value.truncate(value.trim_end().len());

let name = HeaderName::from_str(key)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;

headers.append(name, value.trim().to_string());
headers.append(name, value);
}

return Ok(Some(ServerOp::Message {
length: reply_to.as_ref().map(|reply| reply.len()).unwrap_or(0)
length: reply_to.as_ref().map_or(0, |reply| reply.len())
+ subject.len()
+ num_bytes,
+ total_len,
sid,
reply: reply_to,
subject,
headers: Some(headers),
payload,
status: maybe_status,
description: maybe_description,
status,
description,
}));
}

Expand Down