Skip to content

Commit

Permalink
Add configuration for lazy/gready expect logic
Browse files Browse the repository at this point in the history
Signed-off-by: Maxim Zhiburt <zhiburt@gmail.com>
  • Loading branch information
zhiburt committed Jun 3, 2022
1 parent 4025cce commit 8040332
Show file tree
Hide file tree
Showing 4 changed files with 269 additions and 35 deletions.
2 changes: 1 addition & 1 deletion src/repl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ impl<P, S: Read + NonBlocking> ReplSession<P, S> {
impl<P, S: AsyncRead + Unpin> ReplSession<P, S> {
/// Block until prompt is found
pub async fn expect_prompt(&mut self) -> Result<(), Error> {
self._expect_prompt().await?;
let _ = self._expect_prompt().await?;
Ok(())
}

Expand Down
176 changes: 151 additions & 25 deletions src/session/async_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ impl<P, S> Session<P, S> {
self.stream.set_expect_timeout(expect_timeout);
}

/// Set a expect algorithm to be either gready or lazy.
///
/// Default algorithm is gready.
///
/// See [Session::expect].
pub fn set_expect_lazy(&mut self, is_lazy: bool) {
self.stream.expect_lazy = is_lazy;
}

pub(crate) fn swap_stream<F: FnOnce(S) -> R, R>(
mut self,
new_stream: F,
Expand All @@ -57,17 +66,35 @@ impl<P, S: AsyncRead + Unpin> Session<P, S> {
///
/// If the method returns [Ok] it is guaranteed that at least 1 match was found.
///
/// This make assertions in a lazy manner. Starts from 1st byte then checks 2nd byte and goes further.
/// It is done intentinally to be presize.
/// Here's an example,
/// when you call this method with [crate::Regex] and output contains 123, expect will return ‘1’ as a match not ‘123’.
/// The match algorthm can be either
/// - gready
/// - lazy
///
/// You can set one via [Session::set_expect_lazy].
/// Default version is gready.
///
/// The implications are.
///
/// Imagine you use [crate::Regex] `"\d+"` to find a match.
/// And your process outputs `123`.
/// In case of lazy approach we will match `1`.
/// Where's in case of gready one we will match `123`.
///
/// # Example
///
/// ```
/// # futures_lite::future::block_on(async {
/// let mut p = expectrl::spawn("echo 123").unwrap();
/// let m = p.expect(expectrl::Regex("\\d+")).await.unwrap();
/// assert_eq!(m.get(0).unwrap(), b"123");
/// # });
/// ```
///
/// ```
/// # futures_lite::future::block_on(async {
/// let mut p = expectrl::spawn("echo 123").unwrap();
/// p.set_expect_lazy(true);
/// let m = p.expect(expectrl::Regex("\\d+")).await.unwrap();
/// assert_eq!(m.get(0).unwrap(), b"1");
/// # });
/// ```
Expand All @@ -77,7 +104,10 @@ impl<P, S: AsyncRead + Unpin> Session<P, S> {
/// It returns an error if timeout is reached.
/// You can specify a timeout value by [Session::set_expect_timeout] method.
pub async fn expect<N: Needle>(&mut self, needle: N) -> Result<Captures, Error> {
self.stream.expect(needle).await
match self.stream.expect_lazy {
true => self.stream.expect_lazy(needle).await,
false => self.stream.expect_gready(needle).await,
}
}

/// Check checks if a pattern is matched.
Expand Down Expand Up @@ -226,6 +256,7 @@ impl<P: Unpin, S: AsyncRead + Unpin> AsyncBufRead for Session<P, S> {
struct Stream<S> {
stream: BufferedStream<S>,
expect_timeout: Option<Duration>,
expect_lazy: bool,
}

impl<S> Stream<S> {
Expand All @@ -234,6 +265,7 @@ impl<S> Stream<S> {
Self {
stream: BufferedStream::new(stream),
expect_timeout: Some(Duration::from_millis(10000)),
expect_lazy: false,
}
}

Expand Down Expand Up @@ -265,7 +297,50 @@ impl<S> Stream<S> {
}

impl<S: AsyncRead + Unpin> Stream<S> {
async fn expect<N: Needle>(&mut self, needle: N) -> Result<Captures, Error> {
async fn expect_gready<N: Needle>(&mut self, needle: N) -> Result<Captures, Error> {
let expect_timeout = self.expect_timeout;

println!("{:?}", expect_timeout);

let expect_future = async {
let mut eof = false;
loop {
let data = self.stream.buffer();

let found = Needle::check(&needle, data, eof)?;

if !found.is_empty() {
let end_index = Captures::right_most_index(&found);
let involved_bytes = data[..end_index].to_vec();
self.stream.consume(end_index);

return Ok(Captures::new(involved_bytes, found));
}

if eof {
return Err(Error::Eof);
}

eof = self.stream.fill().await? == 0;
}
};

if let Some(timeout) = expect_timeout {
println!("SETTING TIMEOUT");

let timeout_future = futures_timer::Delay::new(timeout);
futures_lite::future::or(expect_future, async {
timeout_future.await;
println!("TIMEOUT");
Err(Error::ExpectTimeout)
})
.await
} else {
expect_future.await
}
}

async fn expect_lazy<N: Needle>(&mut self, needle: N) -> Result<Captures, Error> {
let expect_timeout = self.expect_timeout;
let expect_future = async {
// We read by byte to make things as lazy as possible.
Expand Down Expand Up @@ -330,12 +405,7 @@ impl<S: AsyncRead + Unpin> Stream<S> {
/// Is matched checks if a pattern is matched.
/// It doesn't consumes bytes from stream.
async fn is_matched<E: Needle>(&mut self, needle: E) -> Result<bool, Error> {
let eof = match futures_lite::future::poll_once(self.stream.fill()).await {
Some(Ok(n)) => n == 0,
Some(Err(err)) => return Err(err.into()),
None => false,
};

let eof = self.try_fill().await?;
let buf = self.stream.buffer();

let found = needle.check(buf, eof)?;
Expand Down Expand Up @@ -369,11 +439,7 @@ impl<S: AsyncRead + Unpin> Stream<S> {
/// # });
/// ```
async fn check<E: Needle>(&mut self, needle: E) -> Result<Captures, Error> {
let eof = match futures_lite::future::poll_once(self.stream.fill()).await {
Some(Ok(n)) => n == 0,
Some(Err(err)) => return Err(err.into()),
None => false,
};
let eof = self.try_fill().await?;

let buf = self.stream.buffer();
let found = needle.check(buf, eof)?;
Expand All @@ -400,6 +466,14 @@ impl<S: AsyncRead + Unpin> Stream<S> {
None => Ok(true),
}
}

async fn try_fill(&mut self) -> Result<bool, Error> {
match futures_lite::future::poll_once(self.stream.fill()).await {
Some(Ok(n)) => Ok(n == 0),
Some(Err(err)) => return Err(err.into()),
None => Ok(false),
}
}
}

impl<S: AsyncWrite + Unpin> AsyncWrite for Stream<S> {
Expand Down Expand Up @@ -540,26 +614,78 @@ mod tests {
use super::*;

#[test]
fn test_expect() {
fn test_expect_lazy() {
let buf = b"Hello World".to_vec();
let cursor = futures_lite::io::Cursor::new(buf);
let mut stream = Stream::new(cursor);

futures_lite::future::block_on(async {
let found = stream.expect_lazy("World").await.unwrap();
assert_eq!(b"Hello ", found.before());
assert_eq!(vec![b"World"], found.matches().collect::<Vec<_>>());
});
}

#[test]
fn test_expect_lazy_eof() {
let buf = b"Hello World".to_vec();
let cursor = futures_lite::io::Cursor::new(buf);
let mut stream = Stream::new(cursor);

futures_lite::future::block_on(async {
let found = stream.expect_lazy(Eof).await.unwrap();
assert_eq!(b"", found.before());
assert_eq!(vec![b"Hello World"], found.matches().collect::<Vec<_>>());
});

let cursor = futures_lite::io::Cursor::new(Vec::new());
let mut stream = Stream::new(cursor);

futures_lite::future::block_on(async {
let err = stream.expect_lazy("").await.unwrap_err();
assert!(matches!(err, Error::Eof));
});
}

#[test]
fn test_expect_lazy_timeout() {
futures_lite::future::block_on(async {
let mut stream = Stream::new(NoEofReader::default());
stream.set_expect_timeout(Some(Duration::from_millis(100)));

stream.write_all(b"Hello").await.unwrap();

let err = stream.expect_lazy("Hello World").await.unwrap_err();
assert!(matches!(err, Error::ExpectTimeout));

stream.write_all(b" World").await.unwrap();
let found = stream.expect_lazy("World").await.unwrap();
assert_eq!(b"Hello ", found.before());
assert_eq!(vec![b"World"], found.matches().collect::<Vec<_>>());
});
}

#[test]
fn test_expect_gready() {
let buf = b"Hello World".to_vec();
let cursor = futures_lite::io::Cursor::new(buf);
let mut stream = Stream::new(cursor);

futures_lite::future::block_on(async {
let found = stream.expect("World").await.unwrap();
let found = stream.expect_gready("World").await.unwrap();
assert_eq!(b"Hello ", found.before());
assert_eq!(vec![b"World"], found.matches().collect::<Vec<_>>());
});
}

#[test]
fn test_expect_eof() {
fn test_expect_gready_eof() {
let buf = b"Hello World".to_vec();
let cursor = futures_lite::io::Cursor::new(buf);
let mut stream = Stream::new(cursor);

futures_lite::future::block_on(async {
let found = stream.expect(Eof).await.unwrap();
let found = stream.expect_gready(Eof).await.unwrap();
assert_eq!(b"", found.before());
assert_eq!(vec![b"Hello World"], found.matches().collect::<Vec<_>>());
});
Expand All @@ -568,24 +694,24 @@ mod tests {
let mut stream = Stream::new(cursor);

futures_lite::future::block_on(async {
let err = stream.expect("").await.unwrap_err();
let err = stream.expect_gready("").await.unwrap_err();
assert!(matches!(err, Error::Eof));
});
}

#[test]
fn test_expect_timeout() {
fn test_expect_gready_timeout() {
futures_lite::future::block_on(async {
let mut stream = Stream::new(NoEofReader::default());
stream.set_expect_timeout(Some(Duration::from_millis(100)));

stream.write_all(b"Hello").await.unwrap();

let err = stream.expect("Hello World").await.unwrap_err();
let err = stream.expect_gready("Hello World").await.unwrap_err();
assert!(matches!(err, Error::ExpectTimeout));

stream.write_all(b" World").await.unwrap();
let found = stream.expect("World").await.unwrap();
let found = stream.expect_gready("World").await.unwrap();
assert_eq!(b"Hello ", found.before());
assert_eq!(vec![b"World"], found.matches().collect::<Vec<_>>());
});
Expand Down
Loading

0 comments on commit 8040332

Please sign in to comment.