diff --git a/core/Cargo.toml b/core/Cargo.toml index 1c8728ece..b2c3f3b8a 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -57,7 +57,7 @@ sysinfo = { version = "0.31.3", default-features = false, features = ["system"] thiserror = "1.0" time = { version = "0.3", features = ["formatting", "parsing"] } tokio = { version = "1", features = ["io-util", "macros", "net", "parking_lot", "rt", "sync", "time"] } -tokio-stream = "0.1" +tokio-stream = { version = "0.1", features = ["sync"] } tokio-tungstenite = { version = "*", default-features = false, features = ["rustls-tls-native-roots"] } tokio-util = { version = "0.7", features = ["codec"] } url = "2" diff --git a/core/src/session.rs b/core/src/session.rs index 3594e37f5..d7edb545a 100644 --- a/core/src/session.rs +++ b/core/src/session.rs @@ -19,7 +19,10 @@ use once_cell::sync::OnceCell; use parking_lot::RwLock; use quick_xml::events::Event; use thiserror::Error; -use tokio::{sync::mpsc, time::Instant}; +use tokio::{ + sync::{mpsc, watch}, + time::{sleep_until, Duration as TokioDuration, Instant as TokioInstant}, +}; use tokio_stream::wrappers::UnboundedReceiverStream; use crate::{ @@ -71,6 +74,22 @@ pub struct UserData { pub attributes: UserAttributes, } +#[derive(Debug, Clone, Copy, Default)] +enum KeepAliveState { + #[default] + // No Ping received yet or generally inactive. + Idle, + + // Expecting a Ping from the server, either after startup or after a PongAck. + ExpectingPing(TokioInstant), + + // We need to send a Pong at the given time. + PendingPong(TokioInstant), + + // We just sent a Pong and wait for it be ACK'd. + ExpectingPongAck(TokioInstant), +} + #[derive(Debug, Clone, Default)] struct SessionData { client_id: String, @@ -82,7 +101,6 @@ struct SessionData { time_delta: i64, invalid: bool, user_data: UserData, - last_ping: Option, } struct SessionInternal { @@ -100,6 +118,7 @@ struct SessionInternal { token_provider: OnceCell, cache: Option>, + keep_alive_state: watch::Sender, handle: tokio::runtime::Handle, } @@ -138,6 +157,7 @@ impl Session { mercury: OnceCell::new(), spclient: OnceCell::new(), token_provider: OnceCell::new(), + keep_alive_state: watch::channel(KeepAliveState::Idle).0, handle: tokio::runtime::Handle::current(), })) } @@ -251,10 +271,17 @@ impl Session { .map(Ok) .forward(sink); let receiver_task = DispatchTask(stream, self.weak()); - let timeout_task = Session::session_timeout(self.weak()); + let keep_alive_task = Session::keep_alive_task(self)?; + + // Expect an initial Ping from the server. + self.0 + .keep_alive_state + .send_replace(KeepAliveState::ExpectingPing( + TokioInstant::now() + TokioDuration::from_secs(5), + )); tokio::spawn(async move { - let result = future::try_join3(sender_task, receiver_task, timeout_task).await; + let result = future::try_join3(sender_task, receiver_task, keep_alive_task).await; if let Err(e) = result { error!("{}", e); @@ -302,31 +329,84 @@ impl Session { .get_or_init(|| TokenProvider::new(self.weak())) } - /// Returns an error, when we haven't received a ping for too long (2 minutes), - /// which means that we silently lost connection to Spotify servers. - async fn session_timeout(session: SessionWeak) -> io::Result<()> { - // pings are sent every 2 minutes and a 5 second margin should be fine - const SESSION_TIMEOUT: Duration = Duration::from_secs(125); + /// Returns an error when we haven't received a Ping/PongAck for too long. + /// + /// The expected keepalive sequence is + /// - Server: Ping + /// - wait 60s + /// - Client: Pong + /// - Server: PongAck + /// - wait 60s + /// - repeat + /// + /// This means that we silently lost connection to Spotify servers if + /// - we don't receive a Ping 60s after the last PongAck, or + /// - we don't receive a PongAck immediately after our Pong. + /// + /// Currently, we add a safety margin of 5s to these expected deadlines. + fn keep_alive_task( + session: &Session, + ) -> Result>, Error> { + use KeepAliveState::*; + + let state_tx = session.0.keep_alive_state.clone(); + let state_rx = state_tx.subscribe(); + let mut state_stream = tokio_stream::wrappers::WatchStream::new(state_rx); + + let mut timeout_at = None; + + let session = session.weak(); + Ok(async move { + loop { + tokio::select! { + // Handle keepalive events received via Mercury. + state = state_stream.next() => { + match state { + Some(Idle) => { + timeout_at = None; + } + Some(ExpectingPing(t) | ExpectingPongAck(t)) => { + timeout_at = Some(t); + } + Some(PendingPong(pong_at)) => { + sleep_until(pong_at).await; + { + if let Some(session) = session.try_upgrade() { + debug!("Sending Pong"); + let _ = session.send_packet(PacketType::Pong, vec![0, 0, 0, 0]); + state_tx.send_replace( + ExpectingPongAck(TokioInstant::now() + TokioDuration::from_secs(5)) + ); + } + } + } + None => break, + } + } - while let Some(session) = session.try_upgrade() { - if session.is_invalid() { - break; - } - let last_ping = session.0.data.read().last_ping.unwrap_or_else(Instant::now); - if last_ping.elapsed() >= SESSION_TIMEOUT { - session.shutdown(); - // TODO: Optionally reconnect (with cached/last credentials?) - return Err(io::Error::new( - io::ErrorKind::TimedOut, - "session lost connection to server", - )); + // Watch for timeouts. + _ = async { + if let Some(timeout_at) = timeout_at { + sleep_until(timeout_at).await + } + }, if timeout_at.is_some() => { + if let Some(session) = session.try_upgrade() { + if !session.is_invalid() { + session.shutdown(); + } + } + + // TODO: Optionally reconnect (with cached/last credentials?) + return Err(io::Error::new( + io::ErrorKind::TimedOut, + "session lost connection to server", + )); + } + }; } - // drop the strong reference before sleeping - drop(session); - // a potential timeout cannot occur at least until SESSION_TIMEOUT after the last_ping - tokio::time::sleep_until(last_ping + SESSION_TIMEOUT).await; - } - Ok(()) + + Ok(()) + }) } pub fn time_delta(&self) -> i64 { @@ -362,6 +442,7 @@ impl Session { } fn dispatch(&self, cmd: u8, data: Bytes) -> Result<(), Error> { + use KeepAliveState::*; use PacketType::*; let packet_type = FromPrimitive::from_u8(cmd); @@ -375,36 +456,43 @@ impl Session { match packet_type { Some(Ping) => { - info!("Received Ping"); + debug!("Received Ping"); let server_timestamp = BigEndian::read_u32(data.as_ref()) as i64; let timestamp = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap_or(Duration::ZERO) .as_secs() as i64; - { let mut data = self.0.data.write(); data.time_delta = server_timestamp.saturating_sub(timestamp); - data.last_ping = Some(Instant::now()); } - self.debug_info(); - let session = self.weak(); - tokio::spawn(async move { - tokio::time::sleep(tokio::time::Duration::from_secs(60)).await; - if let Some(session) = session.try_upgrade() { - info!("Sending Pong"); - let _ = session.send_packet(Pong, vec![0, 0, 0, 0]); - } + // Cancel timeout: The next action is for us to send a Pong in 60s + let previous_state = self.0.keep_alive_state.send_replace(PendingPong( + TokioInstant::now() + TokioDuration::from_secs(60), + )); + match previous_state { + ExpectingPing(_) => (), + _ => warn!("Received unexpected Ping from server"), + } - // TODO: Wait for PongAck. Then, wait for next ping and use - // both events in the session timeout detection. - }); + self.debug_info(); Ok(()) } Some(PongAck) => { - info!("Received PongAck"); + debug!("Received PongAck"); + + // Set timeout: The server should ping us again in 60s. + let previous_state = self.0.keep_alive_state.send_replace(ExpectingPing( + TokioInstant::now() + TokioDuration::from_secs(65), + )); + + match previous_state { + ExpectingPongAck(_) => (), + _ => warn!("Received unexpected PongAck from server"), + } + Ok(()) } Some(CountryCode) => { @@ -603,6 +691,7 @@ impl Session { self.0.data.write().invalid = true; self.mercury().shutdown(); self.channel().shutdown(); + self.0.keep_alive_state.send_replace(KeepAliveState::Idle); } pub fn is_invalid(&self) -> bool { @@ -662,7 +751,8 @@ where } }; - if let Err(e) = session.dispatch(cmd, data) { + let result = session.dispatch(cmd, data); + if let Err(e) = result { debug!("could not dispatch command: {}", e); } }