Skip to content

Commit

Permalink
session: more sophisticated timeout handling
Browse files Browse the repository at this point in the history
  • Loading branch information
wisp3rwind committed Oct 1, 2024
1 parent a61f707 commit a6c4945
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 44 deletions.
2 changes: 1 addition & 1 deletion core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ sysinfo = { version = "0.31.3", default-features = false, features = ["system"]
thiserror = "1.0"
time = { version = "0.3", features = ["formatting", "parsing"] }
tokio = { version = "1", features = ["io-util", "macros", "net", "parking_lot", "rt", "sync", "time"] }
tokio-stream = "0.1"
tokio-stream = { version = "0.1", features = ["sync"] }
tokio-tungstenite = { version = "*", default-features = false, features = ["rustls-tls-native-roots"] }
tokio-util = { version = "0.7", features = ["codec"] }
url = "2"
Expand Down
176 changes: 133 additions & 43 deletions core/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ use once_cell::sync::OnceCell;
use parking_lot::RwLock;
use quick_xml::events::Event;
use thiserror::Error;
use tokio::{sync::mpsc, time::Instant};
use tokio::{
sync::{mpsc, watch},
time::{sleep_until, Duration as TokioDuration, Instant as TokioInstant},
};
use tokio_stream::wrappers::UnboundedReceiverStream;

use crate::{
Expand Down Expand Up @@ -71,6 +74,22 @@ pub struct UserData {
pub attributes: UserAttributes,
}

// State of the keep-alive (Ping / Pong / PongAck) exchange with the server.
//
// The `TokioInstant` carried by a variant is the point in time at which that
// state requires action: for `ExpectingPing` and `ExpectingPongAck` it is the
// deadline after which the session is considered lost, for `PendingPong` it
// is the time at which our Pong is due.
#[derive(Debug, Clone, Copy, Default)]
enum KeepAliveState {
#[default]
// No Ping received yet or generally inactive. No timeout is armed.
Idle,

// Expecting a Ping from the server, either after startup or after a PongAck.
ExpectingPing(TokioInstant),

// We need to send a Pong at the given time.
PendingPong(TokioInstant),

// We just sent a Pong and wait for it to be ACK'd.
ExpectingPongAck(TokioInstant),
}

#[derive(Debug, Clone, Default)]
struct SessionData {
client_id: String,
Expand All @@ -82,7 +101,6 @@ struct SessionData {
time_delta: i64,
invalid: bool,
user_data: UserData,
last_ping: Option<Instant>,
}

struct SessionInternal {
Expand All @@ -100,6 +118,7 @@ struct SessionInternal {
token_provider: OnceCell<TokenProvider>,
cache: Option<Arc<Cache>>,

keep_alive_state: watch::Sender<KeepAliveState>,
handle: tokio::runtime::Handle,
}

Expand Down Expand Up @@ -138,6 +157,7 @@ impl Session {
mercury: OnceCell::new(),
spclient: OnceCell::new(),
token_provider: OnceCell::new(),
keep_alive_state: watch::channel(KeepAliveState::Idle).0,
handle: tokio::runtime::Handle::current(),
}))
}
Expand Down Expand Up @@ -251,10 +271,17 @@ impl Session {
.map(Ok)
.forward(sink);
let receiver_task = DispatchTask(stream, self.weak());
let timeout_task = Session::session_timeout(self.weak());
let keep_alive_task = Session::keep_alive_task(self)?;

// Expect an initial Ping from the server.
self.0
.keep_alive_state
.send_replace(KeepAliveState::ExpectingPing(
TokioInstant::now() + TokioDuration::from_secs(5),
));

tokio::spawn(async move {
let result = future::try_join3(sender_task, receiver_task, timeout_task).await;
let result = future::try_join3(sender_task, receiver_task, keep_alive_task).await;

if let Err(e) = result {
error!("{}", e);
Expand Down Expand Up @@ -302,31 +329,84 @@ impl Session {
.get_or_init(|| TokenProvider::new(self.weak()))
}

/// Returns an error, when we haven't received a ping for too long (2 minutes),
/// which means that we silently lost connection to Spotify servers.
async fn session_timeout(session: SessionWeak) -> io::Result<()> {
// pings are sent every 2 minutes and a 5 second margin should be fine
const SESSION_TIMEOUT: Duration = Duration::from_secs(125);
/// Returns an error when we haven't received a Ping/PongAck for too long.
///
/// The expected keepalive sequence is
/// - Server: Ping
/// - wait 60s
/// - Client: Pong
/// - Server: PongAck
/// - wait 60s
/// - repeat
///
/// This means that we silently lost connection to Spotify servers if
/// - we don't receive a Ping 60s after the last PongAck, or
/// - we don't receive a PongAck immediately after our Pong.
///
/// Currently, we add a safety margin of 5s to these expected deadlines.
fn keep_alive_task(
session: &Session,
) -> Result<impl Future<Output = Result<(), io::Error>>, Error> {
use KeepAliveState::*;

let state_tx = session.0.keep_alive_state.clone();
let state_rx = state_tx.subscribe();
let mut state_stream = tokio_stream::wrappers::WatchStream::new(state_rx);

let mut timeout_at = None;

let session = session.weak();
Ok(async move {
loop {
tokio::select! {
// Handle keepalive events received via Mercury.
state = state_stream.next() => {
match state {
Some(Idle) => {
timeout_at = None;
}
Some(ExpectingPing(t) | ExpectingPongAck(t)) => {
timeout_at = Some(t);
}
Some(PendingPong(pong_at)) => {
sleep_until(pong_at).await;
{
if let Some(session) = session.try_upgrade() {
debug!("Sending Pong");
let _ = session.send_packet(PacketType::Pong, vec![0, 0, 0, 0]);
state_tx.send_replace(
ExpectingPongAck(TokioInstant::now() + TokioDuration::from_secs(5))
);
}
}
}
None => break,
}
}

while let Some(session) = session.try_upgrade() {
if session.is_invalid() {
break;
}
let last_ping = session.0.data.read().last_ping.unwrap_or_else(Instant::now);
if last_ping.elapsed() >= SESSION_TIMEOUT {
session.shutdown();
// TODO: Optionally reconnect (with cached/last credentials?)
return Err(io::Error::new(
io::ErrorKind::TimedOut,
"session lost connection to server",
));
// Watch for timeouts.
_ = async {
if let Some(timeout_at) = timeout_at {
sleep_until(timeout_at).await
}
}, if timeout_at.is_some() => {
if let Some(session) = session.try_upgrade() {
if !session.is_invalid() {
session.shutdown();
}
}

// TODO: Optionally reconnect (with cached/last credentials?)
return Err(io::Error::new(
io::ErrorKind::TimedOut,
"session lost connection to server",
));
}
};
}
// drop the strong reference before sleeping
drop(session);
// a potential timeout cannot occur at least until SESSION_TIMEOUT after the last_ping
tokio::time::sleep_until(last_ping + SESSION_TIMEOUT).await;
}
Ok(())

Ok(())
})
}

pub fn time_delta(&self) -> i64 {
Expand Down Expand Up @@ -362,6 +442,7 @@ impl Session {
}

fn dispatch(&self, cmd: u8, data: Bytes) -> Result<(), Error> {
use KeepAliveState::*;
use PacketType::*;

let packet_type = FromPrimitive::from_u8(cmd);
Expand All @@ -375,36 +456,43 @@ impl Session {

match packet_type {
Some(Ping) => {
info!("Received Ping");
debug!("Received Ping");
let server_timestamp = BigEndian::read_u32(data.as_ref()) as i64;
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(Duration::ZERO)
.as_secs() as i64;

{
let mut data = self.0.data.write();
data.time_delta = server_timestamp.saturating_sub(timestamp);
data.last_ping = Some(Instant::now());
}

self.debug_info();
let session = self.weak();
tokio::spawn(async move {
tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
if let Some(session) = session.try_upgrade() {
info!("Sending Pong");
let _ = session.send_packet(Pong, vec![0, 0, 0, 0]);
}
// Cancel timeout: The next action is for us to send a Pong in 60s
let previous_state = self.0.keep_alive_state.send_replace(PendingPong(
TokioInstant::now() + TokioDuration::from_secs(60),
));
match previous_state {
ExpectingPing(_) => (),
_ => warn!("Received unexpected Ping from server"),
}

// TODO: Wait for PongAck. Then, wait for next ping and use
// both events in the session timeout detection.
});
self.debug_info();

Ok(())
}
Some(PongAck) => {
info!("Received PongAck");
debug!("Received PongAck");

// Set timeout: The server should ping us again in 60s.
let previous_state = self.0.keep_alive_state.send_replace(ExpectingPing(
TokioInstant::now() + TokioDuration::from_secs(65),
));

match previous_state {
ExpectingPongAck(_) => (),
_ => warn!("Received unexpected PongAck from server"),
}

Ok(())
}
Some(CountryCode) => {
Expand Down Expand Up @@ -603,6 +691,7 @@ impl Session {
self.0.data.write().invalid = true;
self.mercury().shutdown();
self.channel().shutdown();
self.0.keep_alive_state.send_replace(KeepAliveState::Idle);
}

pub fn is_invalid(&self) -> bool {
Expand Down Expand Up @@ -662,7 +751,8 @@ where
}
};

if let Err(e) = session.dispatch(cmd, data) {
let result = session.dispatch(cmd, data);
if let Err(e) = result {
debug!("could not dispatch command: {}", e);
}
}
Expand Down

0 comments on commit a6c4945

Please sign in to comment.