From 947045b9445f15fb9314ba0892efa2251076ae73 Mon Sep 17 00:00:00 2001
From: Thomas Whiteway <thomas.whiteway@metaswitch.com>
Date: Wed, 29 Apr 2020 17:03:44 +0100
Subject: [PATCH] time: notify when resetting a Delay to a time in the past
 (#2290)

If a Delay has been polled, then the task that polled it may be waiting
for a notification.  If the delay gets reset to a time in the past, then
it immediately becomes elapsed, so it should notify the relevant task.
---
 tokio/src/time/driver/entry.rs  | 24 ++++----
 tokio/tests/time_delay.rs       | 22 +++++++-
 tokio/tests/time_delay_queue.rs | 97 +++++++++++++++++++++++++++++++++
 3 files changed, 130 insertions(+), 13 deletions(-)

diff --git a/tokio/src/time/driver/entry.rs b/tokio/src/time/driver/entry.rs
index 20cc824019a..8e1e6b2f92e 100644
--- a/tokio/src/time/driver/entry.rs
+++ b/tokio/src/time/driver/entry.rs
@@ -266,8 +266,9 @@ impl Entry {
         let when = inner.normalize_deadline(deadline);
         let elapsed = inner.elapsed();
 
+        let next = if when <= elapsed { ELAPSED } else { when };
+
         let mut curr = entry.state.load(SeqCst);
-        let mut notify;
 
         loop {
             // In these two cases, there is no work to do when resetting the
@@ -278,16 +279,6 @@ impl Entry {
                 return;
             }
 
-            let next;
-
-            if when <= elapsed {
-                next = ELAPSED;
-                notify = !is_elapsed(curr);
-            } else {
-                next = when;
-                notify = true;
-            }
-
             let actual = entry.state.compare_and_swap(curr, next, SeqCst);
 
             if curr == actual {
@@ -297,7 +288,16 @@ impl Entry {
             curr = actual;
         }
 
-        if notify {
+        // If the state has transitioned to 'elapsed' then wake the task as
+        // this entry is ready to be polled.
+        if !is_elapsed(curr) && is_elapsed(next) {
+            entry.waker.wake();
+        }
+
+        // The driver tracks all non-elapsed entries; notify the driver that it
+        // should update its state for this entry unless the entry had already
+        // elapsed and remains elapsed.
+        if !is_elapsed(curr) || !is_elapsed(next) {
             let _ = inner.queue(entry);
         }
     }
diff --git a/tokio/tests/time_delay.rs b/tokio/tests/time_delay.rs
index e763ae03bec..e4804ec6740 100644
--- a/tokio/tests/time_delay.rs
+++ b/tokio/tests/time_delay.rs
@@ -2,7 +2,7 @@
 #![cfg(feature = "full")]
 
 use tokio::time::{self, Duration, Instant};
-use tokio_test::{assert_pending, task};
+use tokio_test::{assert_pending, assert_ready, task};
 
 macro_rules! assert_elapsed {
     ($now:expr, $ms:expr) => {{
@@ -137,6 +137,26 @@ async fn reset_future_delay_after_fire() {
     assert_elapsed!(now, 110);
 }
 
+#[tokio::test]
+async fn reset_delay_to_past() {
+    time::pause();
+
+    let now = Instant::now();
+
+    let mut delay = task::spawn(time::delay_until(now + ms(100)));
+    assert_pending!(delay.poll());
+
+    time::delay_for(ms(50)).await;
+
+    assert!(!delay.is_woken());
+
+    delay.reset(now + ms(40));
+
+    assert!(delay.is_woken());
+
+    assert_ready!(delay.poll());
+}
+
 #[test]
 #[should_panic]
 fn creating_delay_outside_of_context() {
diff --git a/tokio/tests/time_delay_queue.rs b/tokio/tests/time_delay_queue.rs
index 214b9ebee68..3cf2d1cd059 100644
--- a/tokio/tests/time_delay_queue.rs
+++ b/tokio/tests/time_delay_queue.rs
@@ -443,6 +443,103 @@ async fn insert_after_ready_poll() {
     assert_eq!("3", res[2]);
 }
 
+#[tokio::test]
+async fn reset_later_after_slot_starts() {
+    time::pause();
+
+    let mut queue = task::spawn(DelayQueue::new());
+
+    let now = Instant::now();
+
+    let foo = queue.insert_at("foo", now + ms(100));
+
+    assert_pending!(poll!(queue));
+
+    delay_for(ms(80)).await;
+
+    assert!(!queue.is_woken());
+
+    // At this point the queue hasn't been polled, so `elapsed` on the wheel
+    // for the queue is still at 0 and hence the 1ms resolution slots cover
+    // [0-64).  Resetting the time on the entry to 120 causes it to get put in
+    // the [64-128) slot.  As the queue knows that the first entry is within
+    // that slot, but doesn't know when, it must wake immediately to advance
+    // the wheel.
+    queue.reset_at(&foo, now + ms(120));
+    assert!(queue.is_woken());
+
+    assert_pending!(poll!(queue));
+
+    delay_for(ms(39)).await;
+    assert!(!queue.is_woken());
+
+    delay_for(ms(1)).await;
+    assert!(queue.is_woken());
+
+    let entry = assert_ready_ok!(poll!(queue)).into_inner();
+    assert_eq!(entry, "foo");
+}
+
+#[tokio::test]
+async fn reset_earlier_after_slot_starts() {
+    time::pause();
+
+    let mut queue = task::spawn(DelayQueue::new());
+
+    let now = Instant::now();
+
+    let foo = queue.insert_at("foo", now + ms(200));
+
+    assert_pending!(poll!(queue));
+
+    delay_for(ms(80)).await;
+
+    assert!(!queue.is_woken());
+
+    // At this point the queue hasn't been polled, so `elapsed` on the wheel
+    // for the queue is still at 0 and hence the 1ms resolution slots cover
+    // [0-64).  Resetting the time on the entry to 120 causes it to get put in
+    // the [64-128) slot.  As the queue knows that the first entry is within
+    // that slot, but doesn't know when, it must wake immediately to advance
+    // the wheel.
+    queue.reset_at(&foo, now + ms(120));
+    assert!(queue.is_woken());
+
+    assert_pending!(poll!(queue));
+
+    delay_for(ms(39)).await;
+    assert!(!queue.is_woken());
+
+    delay_for(ms(1)).await;
+    assert!(queue.is_woken());
+
+    let entry = assert_ready_ok!(poll!(queue)).into_inner();
+    assert_eq!(entry, "foo");
+}
+
+#[tokio::test]
+async fn insert_in_past_after_poll_fires_immediately() {
+    time::pause();
+
+    let mut queue = task::spawn(DelayQueue::new());
+
+    let now = Instant::now();
+
+    queue.insert_at("foo", now + ms(200));
+
+    assert_pending!(poll!(queue));
+
+    delay_for(ms(80)).await;
+
+    assert!(!queue.is_woken());
+    queue.insert_at("bar", now + ms(40));
+
+    assert!(queue.is_woken());
+
+    let entry = assert_ready_ok!(poll!(queue)).into_inner();
+    assert_eq!(entry, "bar");
+}
+
 fn ms(n: u64) -> Duration {
     Duration::from_millis(n)
 }