diff --git a/bindings/matrix-sdk-ffi/src/authentication_service.rs b/bindings/matrix-sdk-ffi/src/authentication_service.rs index 426fa64764c..7adc787c2ed 100644 --- a/bindings/matrix-sdk-ffi/src/authentication_service.rs +++ b/bindings/matrix-sdk-ffi/src/authentication_service.rs @@ -117,7 +117,7 @@ impl AuthenticationService { // Restore the client using the session from the login request. client - .restore_session(session.clone()) + .restore_session(session) .map(|_| client.clone()) .map_err(AuthenticationError::from) } @@ -147,6 +147,7 @@ impl AuthenticationService { let discovery_session = Session { access_token: token.clone(), + refresh_token: None, user_id: discovery_user_id, device_id: device_id.clone(), }; @@ -156,8 +157,12 @@ impl AuthenticationService { // Create the actual client with a store path from the user ID. let homeserver_url = client.homeserver(); - let session = - Session { access_token: token, user_id: whoami.user_id.clone(), device_id }; + let session = Session { + access_token: token, + refresh_token: None, + user_id: whoami.user_id.clone(), + device_id, + }; let client = Arc::new(ClientBuilder::new()) .base_path(self.base_path.clone()) .homeserver_url(homeserver_url) diff --git a/bindings/matrix-sdk-ffi/src/client.rs b/bindings/matrix-sdk-ffi/src/client.rs index 25bccc5e16b..54a181ad142 100644 --- a/bindings/matrix-sdk-ffi/src/client.rs +++ b/bindings/matrix-sdk-ffi/src/client.rs @@ -184,7 +184,7 @@ impl Client { pub fn restore_token(&self) -> anyhow::Result { RUNTIME.block_on(async move { - let session = self.client.session().expect("Missing session").clone(); + let session = self.client.session().expect("Missing session"); let homeurl = self.client.homeserver().await.into(); Ok(serde_json::to_string(&RestoreToken { session, diff --git a/crates/matrix-sdk-appservice/src/virtual_user.rs b/crates/matrix-sdk-appservice/src/virtual_user.rs index ec49e61c906..f344e6cad07 100644 --- a/crates/matrix-sdk-appservice/src/virtual_user.rs +++ b/crates/matrix-sdk-appservice/src/virtual_user.rs @@ -128,6 +128,7 @@ impl<'a> VirtualUserBuilder<'a> { Session { access_token: response.access_token, + refresh_token: response.refresh_token, user_id: response.user_id, device_id: response.device_id, } @@ -135,6 +136,7 @@ impl<'a> VirtualUserBuilder<'a> { // Don’t log in Session { access_token: self.appservice.registration.as_token.clone(), + refresh_token: None, user_id: user_id.clone(), device_id: self.device_id.unwrap_or_else(DeviceId::new), } diff --git a/crates/matrix-sdk-base/Cargo.toml b/crates/matrix-sdk-base/Cargo.toml index 6af59b9a16a..72e6a77d744 100644 --- a/crates/matrix-sdk-base/Cargo.toml +++ b/crates/matrix-sdk-base/Cargo.toml @@ -30,6 +30,7 @@ async-trait = "0.1.53" dashmap = "5.2.0" futures-channel = "0.3.21" futures-core = "0.3.21" +futures-signals = { version = "0.3.30", default-features = false } futures-util = { version = "0.3.21", default-features = false } http = { version = "0.2.6", optional = true } lru = "0.7.5" diff --git a/crates/matrix-sdk-base/src/client.rs b/crates/matrix-sdk-base/src/client.rs index 6786a691e04..5af8d7527c3 100644 --- a/crates/matrix-sdk-base/src/client.rs +++ b/crates/matrix-sdk-base/src/client.rs @@ -21,6 +21,7 @@ use std::{ #[cfg(feature = "e2e-encryption")] use std::{ops::Deref, sync::Arc}; +use futures_signals::signal::ReadOnlyMutable; #[cfg(feature = "experimental-timeline")] use matrix_sdk_common::deserialized_responses::TimelineSlice; use matrix_sdk_common::{ @@ -62,11 +63,10 @@ use crate::error::Error; use crate::{ error::Result, rooms::{Room, RoomInfo, RoomType}, - session::Session, store::{ ambiguity_map::AmbiguityCache, Result as StoreResult, StateChanges, Store, StoreConfig, }, - StateStore, + Session, SessionMeta, SessionTokens, StateStore, }; /// A no IO Client implementation. @@ -91,7 +91,8 @@ pub struct BaseClient { impl fmt::Debug for BaseClient { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Client") - .field("session", &self.session()) + .field("session_meta", &self.store.session_meta()) + .field("session_tokens", &self.store.session_tokens) .field("sync_token", &self.store.sync_token) .finish() } @@ -124,6 +125,29 @@ impl BaseClient { } } + /// Get the session meta information. + /// + /// If the client is currently logged in, this will return a + /// [`SessionMeta`] object which contains the user ID and device ID. + /// Otherwise it returns `None`. + pub fn session_meta(&self) -> Option<&SessionMeta> { + self.store.session_meta() + } + + /// Get the session tokens. + /// + /// If the client is currently logged in, this will return a + /// [`SessionTokens`] object which contains the access token and optional + /// refresh token. Otherwise it returns `None`. + pub fn session_tokens(&self) -> ReadOnlyMutable> { + self.store.session_tokens() + } + + /// Set the session tokens. + pub fn set_session_tokens(&self, tokens: SessionTokens) { + self.store.set_session_tokens(tokens) + } + /// Get the user login session. /// /// If the client is currently logged in, this will return a @@ -132,7 +156,7 @@ impl BaseClient { /// /// Returns a session object if the client is logged in. Otherwise returns /// `None`. - pub fn session(&self) -> Option<&Session> { + pub fn session(&self) -> Option { self.store.session() } @@ -154,7 +178,7 @@ impl BaseClient { /// Is the client logged in. pub fn logged_in(&self) -> bool { - self.store.session().is_some() + self.store.session_meta().is_some() } /// Receive a login response and update the session of the client. @@ -169,6 +193,7 @@ impl BaseClient { ) -> Result<()> { let session = Session { access_token: response.access_token.clone(), + refresh_token: response.refresh_token.clone(), device_id: response.device_id.clone(), user_id: response.user_id.clone(), }; @@ -1000,8 +1025,8 @@ impl BaseClient { .transpose()? { Ok(event.content.global) - } else if let Some(session) = self.store.session() { - Ok(Ruleset::server_default(&session.user_id)) + } else if let Some(session_meta) = self.store.session_meta() { + Ok(Ruleset::server_default(&session_meta.user_id)) } else { Ok(Ruleset::new()) } @@ -1137,6 +1162,7 @@ mod tests { client .restore_login(Session { access_token: "token".to_owned(), + refresh_token: None, user_id: user_id.to_owned(), device_id: "FOOBAR".into(), }) @@ -1191,6 +1217,7 @@ mod tests { client .restore_login(Session { access_token: "token".to_owned(), + refresh_token: None, user_id: user_id.to_owned(), device_id: "FOOBAR".into(), }) diff --git a/crates/matrix-sdk-base/src/lib.rs b/crates/matrix-sdk-base/src/lib.rs index b7bbaeecb92..723e59d59e8 100644 --- a/crates/matrix-sdk-base/src/lib.rs +++ b/crates/matrix-sdk-base/src/lib.rs @@ -23,7 +23,7 @@ pub use matrix_sdk_common::*; pub use crate::timeline_stream::TimelineStreamError; pub use crate::{ error::{Error, Result}, - session::Session, + session::{Session, SessionMeta, SessionTokens}, }; mod client; diff --git a/crates/matrix-sdk-base/src/session.rs b/crates/matrix-sdk-base/src/session.rs index d48ae0a417b..de3b30539fd 100644 --- a/crates/matrix-sdk-base/src/session.rs +++ b/crates/matrix-sdk-base/src/session.rs @@ -15,11 +15,11 @@ //! User sessions. -use ruma::{OwnedDeviceId, OwnedUserId}; +use ruma::{api::client::session::refresh_token, OwnedDeviceId, OwnedUserId}; use serde::{Deserialize, Serialize}; -/// A user session, containing an access token and information about the -/// associated user account. +/// A user session, containing an access token, an optional refresh token and +/// information about the associated user account. /// /// # Example /// @@ -29,6 +29,7 @@ use serde::{Deserialize, Serialize}; /// /// let session = Session { /// access_token: "My-Token".to_owned(), +/// refresh_token: None, /// user_id: user_id!("@example:localhost").to_owned(), /// device_id: device_id!("MYDEVICEID").to_owned(), /// }; @@ -39,18 +40,70 @@ use serde::{Deserialize, Serialize}; pub struct Session { /// The access token used for this session. pub access_token: String, + /// The token used for [refreshing the access token], if any. + /// + /// [refreshing the access token]: https://spec.matrix.org/v1.3/client-server-api/#refreshing-access-tokens + #[serde(skip_serializing_if = "Option::is_none")] + pub refresh_token: Option, /// The user the access token was issued for. pub user_id: OwnedUserId, - /// The ID of the client device + /// The ID of the client device. pub device_id: OwnedDeviceId, } +impl Session { + /// Creates a `Session` from a `SessionMeta` and `SessionTokens`. + pub fn from_parts(meta: SessionMeta, tokens: SessionTokens) -> Self { + let SessionMeta { user_id, device_id } = meta; + let SessionTokens { access_token, refresh_token } = tokens; + Self { access_token, refresh_token, user_id, device_id } + } + + /// Split this `Session` between `SessionMeta` and `SessionTokens`. + pub fn into_parts(self) -> (SessionMeta, SessionTokens) { + let Self { access_token, refresh_token, user_id, device_id } = self; + (SessionMeta { user_id, device_id }, SessionTokens { access_token, refresh_token }) + } +} + impl From for Session { fn from(response: ruma::api::client::session::login::v3::Response) -> Self { Self { access_token: response.access_token, + refresh_token: response.refresh_token, user_id: response.user_id, device_id: response.device_id, } } } + +/// The immutable parts of the session: the user ID and device ID. +#[derive(Clone, Debug)] +pub struct SessionMeta { + /// The user the access token was issued for. + pub user_id: OwnedUserId, + /// The ID of the client device. + pub device_id: OwnedDeviceId, +} + +/// The mutable parts of the session: the access token and optional refresh +/// token. +#[derive(Clone, Debug)] +pub struct SessionTokens { + /// The access token used for this session. + pub access_token: String, + /// The token used for [refreshing the access token], if any. + /// + /// [refreshing the access token]: https://spec.matrix.org/v1.3/client-server-api/#refreshing-access-tokens + pub refresh_token: Option, +} + +impl SessionTokens { + /// Update this `SessionTokens` with the values found in the given response. + pub fn update_with_refresh_response(&mut self, response: &refresh_token::v3::Response) { + self.access_token = response.access_token.clone(); + if let Some(refresh_token) = response.refresh_token.clone() { + self.refresh_token = Some(refresh_token); + } + } +} diff --git a/crates/matrix-sdk-base/src/store/mod.rs b/crates/matrix-sdk-base/src/store/mod.rs index a135d782c46..4dade2e0dec 100644 --- a/crates/matrix-sdk-base/src/store/mod.rs +++ b/crates/matrix-sdk-base/src/store/mod.rs @@ -28,6 +28,7 @@ use std::{ sync::Arc, }; +use futures_signals::signal::{Mutable, ReadOnlyMutable}; use once_cell::sync::OnceCell; #[cfg(any(test, feature = "testing"))] @@ -61,7 +62,7 @@ use crate::{ deserialized_responses::MemberEvent, media::MediaRequest, rooms::{RoomInfo, RoomType}, - MinimalRoomMemberEvent, Room, Session, + MinimalRoomMemberEvent, Room, Session, SessionMeta, SessionTokens, }; pub(crate) mod ambiguity_map; @@ -409,7 +410,8 @@ where #[derive(Debug, Clone)] pub(crate) struct Store { pub(super) inner: Arc, - session: Arc>, + session_meta: Arc>, + pub(super) session_tokens: Mutable>, /// The current sync token that should be used for the next sync call. pub(super) sync_token: Arc>>, rooms: Arc>, @@ -430,7 +432,8 @@ impl Store { pub fn new(inner: Arc) -> Self { Self { inner, - session: Default::default(), + session_meta: Default::default(), + session_tokens: Default::default(), sync_token: Default::default(), rooms: Default::default(), stripped_rooms: Default::default(), @@ -451,17 +454,37 @@ impl Store { } let token = self.get_sync_token().await?; - *self.sync_token.write().await = token; - self.session.set(session).expect("A session was already set"); + + let (session_meta, session_tokens) = session.into_parts(); + self.session_meta.set(session_meta).expect("Session IDs were already set"); + self.session_tokens.set(Some(session_tokens)); Ok(()) } - /// The current [`Session`] containing our user id, device ID and access - /// token. - pub fn session(&self) -> Option<&Session> { - self.session.get() + /// The current [`SessionMeta`] containing our user ID and device ID. + pub fn session_meta(&self) -> Option<&SessionMeta> { + self.session_meta.get() + } + + /// The current [`SessionTokens`] containing our access token and optional + /// refresh token. + pub fn session_tokens(&self) -> ReadOnlyMutable> { + self.session_tokens.read_only() + } + + /// Set the current [`SessionTokens`]. + pub fn set_session_tokens(&self, tokens: SessionTokens) { + self.session_tokens.set(Some(tokens)); + } + + /// The current [`Session`] containing our user id, device ID, access + /// token and optional refresh token. + pub fn session(&self) -> Option { + let meta = self.session_meta.get()?; + let tokens = self.session_tokens.get_cloned()?; + Some(Session::from_parts(meta.to_owned(), tokens)) } /// Get all the rooms this store knows about. @@ -494,7 +517,8 @@ impl Store { /// Lookup the stripped Room for the given RoomId, or create one, if it /// didn't exist yet in the store pub async fn get_or_create_stripped_room(&self, room_id: &RoomId) -> Room { - let user_id = &self.session.get().expect("Creating room while not being logged in").user_id; + let user_id = + &self.session_meta.get().expect("Creating room while not being logged in").user_id; self.stripped_rooms .entry(room_id.to_owned()) @@ -509,7 +533,8 @@ impl Store { return self.get_or_create_stripped_room(room_id).await; } - let user_id = &self.session.get().expect("Creating room while not being logged in").user_id; + let user_id = + &self.session_meta.get().expect("Creating room while not being logged in").user_id; self.rooms .entry(room_id.to_owned()) diff --git a/crates/matrix-sdk/Cargo.toml b/crates/matrix-sdk/Cargo.toml index 27fa828a9bb..281c85d7c45 100644 --- a/crates/matrix-sdk/Cargo.toml +++ b/crates/matrix-sdk/Cargo.toml @@ -65,6 +65,7 @@ dashmap = "5.2.0" event-listener = "2.5.2" eyre = { version = "0.6.8", optional = true } futures-core = "0.3.21" +futures-signals = { version = "0.3.30", default-features = false } futures-util = { version = "0.3.21", default-features = false } http = "0.2.6" matrix-sdk-common = { version = "0.5.0", path = "../matrix-sdk-common" } diff --git a/crates/matrix-sdk/src/client/builder.rs b/crates/matrix-sdk/src/client/builder.rs index f8650a2770b..8d2350c9c05 100644 --- a/crates/matrix-sdk/src/client/builder.rs +++ b/crates/matrix-sdk/src/client/builder.rs @@ -2,7 +2,11 @@ use std::sync::Arc; #[cfg(target_arch = "wasm32")] use async_once_cell::OnceCell; -use matrix_sdk_base::{locks::RwLock, store::StoreConfig, BaseClient, StateStore}; +use matrix_sdk_base::{ + locks::{Mutex, RwLock}, + store::StoreConfig, + BaseClient, StateStore, +}; use ruma::{ api::{client::discovery::discover_homeserver, error::FromHttpResponseError, MatrixVersion}, OwnedServerName, ServerName, UserId, @@ -68,6 +72,7 @@ pub struct ClientBuilder { respect_login_well_known: bool, appservice_mode: bool, server_versions: Option>, + handle_refresh_tokens: bool, } impl ClientBuilder { @@ -80,6 +85,7 @@ impl ClientBuilder { respect_login_well_known: true, appservice_mode: false, server_versions: None, + handle_refresh_tokens: false, } } @@ -305,6 +311,32 @@ impl ClientBuilder { self.http_cfg.get_or_insert_with(Default::default).settings() } + /// Handle [refreshing access tokens] automatically. + /// + /// By default, the `Client` forwards any error and doesn't handle errors + /// with the access token, which means that + /// [`Client::refresh_access_token()`] needs to be called manually to + /// refresh access tokens. + /// + /// Enabling this setting means that the `Client` will try to refresh the + /// token automatically, which means that: + /// + /// * If refreshing the token fails, the error is forwarded, so any endpoint + /// can return [`HttpError::RefreshToken`]. If an [`UnknownToken`] error + /// is encountered, it means that the user needs to be logged in again. + /// + /// * The access token and refresh token need to be watched for changes, + /// using [`Client::session_tokens_signal()`] for example, to be able to + /// [restore the session] later. + /// + /// [refreshing access tokens]: https://spec.matrix.org/v1.3/client-server-api/#refreshing-access-tokens + /// [`UnknownToken`]: ruma::api::client::error::ErrorKind::UnknownToken + /// [restore the session]: Client::restore_login + pub fn handle_refresh_tokens(mut self) -> Self { + self.handle_refresh_tokens = true; + self + } + /// Create a [`Client`] with the options set on this builder. /// /// # Errors @@ -385,6 +417,8 @@ impl ClientBuilder { appservice_mode: self.appservice_mode, respect_login_well_known: self.respect_login_well_known, sync_beat: event_listener::Event::new(), + handle_refresh_tokens: self.handle_refresh_tokens, + refresh_token_lock: Mutex::new(Ok(())), }); Ok(Client { inner }) diff --git a/crates/matrix-sdk/src/client/login_builder.rs b/crates/matrix-sdk/src/client/login_builder.rs index 91c042b4762..f4ac63e9c69 100644 --- a/crates/matrix-sdk/src/client/login_builder.rs +++ b/crates/matrix-sdk/src/client/login_builder.rs @@ -60,11 +60,18 @@ pub struct LoginBuilder<'a> { login_method: LoginMethod<'a>, device_id: Option<&'a str>, initial_device_display_name: Option<&'a str>, + request_refresh_token: bool, } impl<'a> LoginBuilder<'a> { fn new(client: Client, login_method: LoginMethod<'a>) -> Self { - Self { client, login_method, device_id: None, initial_device_display_name: None } + Self { + client, + login_method, + device_id: None, + initial_device_display_name: None, + request_refresh_token: false, + } } pub(super) fn new_password(client: Client, id: UserIdentifier<'a>, password: &'a str) -> Self { @@ -96,6 +103,24 @@ impl<'a> LoginBuilder<'a> { self } + /// Advertise support for [refreshing access tokens]. + /// + /// By default, the `Client` won't handle refreshing access tokens, so + /// [`Client::refresh_access_token()`] needs to be called manually. + /// + /// This behavior can be changed by calling + /// [`handle_refresh_tokens()`] when building the `Client`. + /// + /// *Note* that refreshing access tokens might not be supported or might be + /// enforced by the homeserver regardless of this setting. + /// + /// [refreshing access tokens]: https://spec.matrix.org/v1.3/client-server-api/#refreshing-access-tokens + /// [`handle_refresh_tokens()`]: crate::ClientBuilder::handle_refresh_tokens + pub fn request_refresh_token(mut self) -> Self { + self.request_refresh_token = true; + self + } + /// Send the login request. #[instrument( target = "matrix_sdk::client", @@ -110,6 +135,7 @@ impl<'a> LoginBuilder<'a> { let request = assign!(login::v3::Request::new(self.login_method.to_login_info()), { device_id: self.device_id.map(Into::into), initial_device_display_name: self.initial_device_display_name, + refresh_token: self.request_refresh_token, }); let response = self.client.send(request, Some(RequestConfig::short_retry())).await?; @@ -133,6 +159,7 @@ pub struct SsoLoginBuilder<'a, F> { server_url: Option<&'a str>, server_response: Option<&'a str>, identity_provider_id: Option<&'a str>, + request_refresh_token: bool, } #[cfg(all(feature = "sso-login", not(target_arch = "wasm32")))] @@ -150,6 +177,7 @@ where server_url: None, server_response: None, identity_provider_id: None, + request_refresh_token: false, } } @@ -199,6 +227,24 @@ where self } + /// Advertise support for [refreshing access tokens]. + /// + /// By default, the `Client` won't handle refreshing access tokens, so + /// [`Client::refresh_access_token()`] needs to be called manually. + /// + /// This behavior can be changed by calling + /// [`handle_refresh_tokens()`] when building the `Client`. + /// + /// *Note* that refreshing access tokens might not be supported or might be + /// enforced by the homeserver regardless of this setting. + /// + /// [refreshing access tokens]: https://spec.matrix.org/v1.3/client-server-api/#refreshing-access-tokens + /// [`handle_refresh_tokens()`]: crate::ClientBuilder::handle_refresh_tokens + pub fn request_refresh_token(mut self) -> Self { + self.request_refresh_token = true; + self + } + /// Send the login request. #[instrument(target = "matrix_sdk::client", name = "login", skip_all, fields(method = "sso"))] pub async fn send(self) -> Result { @@ -302,6 +348,7 @@ where let login_builder = LoginBuilder { device_id: self.device_id, initial_device_display_name: self.initial_device_display_name, + request_refresh_token: self.request_refresh_token, ..LoginBuilder::new_token(self.client, &token) }; login_builder.send().await diff --git a/crates/matrix-sdk/src/client/mod.rs b/crates/matrix-sdk/src/client/mod.rs index bff88ffa6bc..f9b0f341af2 100644 --- a/crates/matrix-sdk/src/client/mod.rs +++ b/crates/matrix-sdk/src/client/mod.rs @@ -30,10 +30,11 @@ use anymap2::any::CloneAnySendSync; use async_once_cell::OnceCell; use dashmap::DashMap; use futures_core::stream::Stream; +use futures_signals::signal::Signal; use matrix_sdk_base::{ deserialized_responses::SyncResponse, media::{MediaEventContent, MediaFormat, MediaRequest, MediaThumbnailSize}, - BaseClient, Session, StateStore, + BaseClient, Session, SessionMeta, SessionTokens, StateStore, }; use matrix_sdk_common::{ instant::{Duration, Instant}, @@ -53,16 +54,17 @@ use ruma::{ get_capabilities::{self, Capabilities}, get_supported_versions, }, + error::ErrorKind, filter::{create_filter::v3::Request as FilterUploadRequest, FilterDefinition}, media::{create_content, get_content, get_content_thumbnail}, membership::{join_room_by_id, join_room_by_id_or_alias}, push::get_notifications::v3::Notification, room::create_room, - session::{get_login_types, login, sso_login, sso_login_with_provider}, + session::{get_login_types, login, refresh_token, sso_login, sso_login_with_provider}, sync::sync_events, uiaa::{AuthData, UserIdentifier}, }, - error::FromHttpResponseError, + error::{FromHttpResponseError, ServerError}, MatrixVersion, OutgoingRequest, SendAccessToken, }, assign, @@ -88,7 +90,7 @@ use crate::{ EventHandlerResult, EventHandlerWrapper, SyncEvent, }, http_client::HttpClient, - room, Account, Error, Result, + room, Account, Error, RefreshTokenError, Result, RumaApiError, }; mod builder; @@ -172,6 +174,11 @@ pub(crate) struct ClientInner { /// Whether the client should update its homeserver URL with the discovery /// information present in the login response. respect_login_well_known: bool, + /// Whether to try to refresh the access token automatically when an + /// `M_UNKNOWN_TOKEN` error is encountered. + handle_refresh_tokens: bool, + /// Lock making sure we're only doing one token refresh at a time. + refresh_token_lock: Mutex>, /// An event that can be listened on to wait for a successful sync. The /// event will only be fired if a sync loop is running. Can be used for /// synchronization, e.g. if we send out a request to create a room, we can @@ -310,14 +317,145 @@ impl Client { } } + fn session_meta(&self) -> Option<&SessionMeta> { + self.base_client().session_meta() + } + /// Get the user id of the current owner of the client. pub fn user_id(&self) -> Option<&UserId> { - self.session().map(|s| s.user_id.as_ref()) + self.session_meta().map(|s| s.user_id.as_ref()) } /// Get the device ID that identifies the current session. pub fn device_id(&self) -> Option<&DeviceId> { - self.session().map(|s| s.device_id.as_ref()) + self.session_meta().map(|s| s.device_id.as_ref()) + } + + /// Get the current access token and optional refresh token for this + /// session. + /// + /// Will be `None` if the client has not been logged in. + /// + /// After login, the tokens should only change if support for [refreshing + /// access tokens] has been enabled. + /// + /// [refreshing access tokens]: https://spec.matrix.org/v1.3/client-server-api/#refreshing-access-tokens + pub fn session_tokens(&self) -> Option { + self.base_client().session_tokens().get_cloned() + } + + /// [`Signal`] to get notified when the current access token and optional + /// refresh token for this session change. + /// + /// This can be used with [`Client::session()`] to persist the [`Session`] + /// when the tokens change. + /// + /// After login, the tokens should only change if support for [refreshing + /// access tokens] has been enabled. + /// + /// # Example + /// + /// ```no_run + /// use futures_signals::signal::SignalExt; + /// use matrix_sdk::Client; + /// # use matrix_sdk::Session; + /// # use futures::executor::block_on; + /// # block_on(async { + /// # fn persist_session(_: Option) {}; + /// + /// let homeserver = "http://example.com"; + /// let client = Client::builder() + /// .homeserver_url(homeserver) + /// .handle_refresh_tokens() + /// .build() + /// .await?; + /// + /// let response = client + /// .login_username("user", "wordpass") + /// .initial_device_display_name("My App") + /// .request_refresh_token() + /// .send() + /// .await?; + /// + /// persist_session(client.session()); + /// + /// // Handle when at least one of the tokens changed. + /// let future = client.session_tokens_changed_signal().for_each(move |_| { + /// let client = client.clone(); + /// async move { + /// persist_session(client.session()); + /// } + /// }); + /// + /// tokio::spawn(future); + /// + /// # anyhow::Ok(()) }); + /// ``` + /// + /// [refreshing access tokens]: https://spec.matrix.org/v1.3/client-server-api/#refreshing-access-tokens + pub fn session_tokens_changed_signal(&self) -> impl Signal { + self.base_client().session_tokens().signal_ref(|_| ()) + } + + /// Get the current access token and optional refresh token for this + /// session as a [`Signal`]. + /// + /// This can be used to watch changes of the tokens by calling methods like + /// `for_each()` or `to_stream()`. + /// + /// The value will be `None` if the client has not been logged in. + /// + /// After login, the tokens should only change if support for [refreshing + /// access tokens] has been enabled. + /// + /// # Example + /// + /// ```no_run + /// use futures::StreamExt; + /// use futures_signals::signal::SignalExt; + /// use matrix_sdk::Client; + /// # use matrix_sdk::Session; + /// # use futures::executor::block_on; + /// # block_on(async { + /// # fn persist_session(_: &Session) {}; + /// + /// let homeserver = "http://example.com"; + /// let client = Client::builder() + /// .homeserver_url(homeserver) + /// .handle_refresh_tokens() + /// .build() + /// .await?; + /// + /// client + /// .login_username("user", "wordpass") + /// .initial_device_display_name("My App") + /// .request_refresh_token() + /// .send() + /// .await?; + /// + /// let mut session = client.session().expect("Client should be logged in"); + /// persist_session(&session); + /// + /// // Handle when at least one of the tokens changed. + /// let mut tokens_stream = client.session_tokens_signal().to_stream(); + /// loop { + /// if let Some(tokens) = tokens_stream.next().await.flatten() { + /// session.access_token = tokens.access_token; + /// + /// if let Some(refresh_token) = tokens.refresh_token { + /// session.refresh_token = Some(refresh_token); + /// } + /// + /// persist_session(&session); + /// } + /// } + /// + /// # anyhow::Ok(()) }); + /// ``` + /// + /// [refreshing access tokens]: https://spec.matrix.org/v1.3/client-server-api/#refreshing-access-tokens + pub fn session_tokens_signal(&self) -> impl Signal> { + self.base_client().session_tokens().signal_cloned() } /// Get the whole session info of this client. @@ -326,7 +464,7 @@ impl Client { /// /// Can be used with [`Client::restore_login`] to restore a previously /// logged-in session. - pub fn session(&self) -> Option<&Session> { + pub fn session(&self) -> Option { self.base_client().session() } @@ -1155,6 +1293,7 @@ impl Client { /// /// let session = Session { /// access_token: "My-Token".to_owned(), + /// refresh_token: None, /// user_id: user_id!("@example:localhost").to_owned(), /// device_id: device_id!("MYDEVICEID").to_owned(), /// }; @@ -1188,6 +1327,157 @@ impl Client { Ok(self.inner.base_client.restore_login(session).await?) } + /// Refresh the access token. + /// + /// When support for [refreshing access tokens] is activated on both the + /// homeserver and the client, access tokens have an expiration date and + /// need to be refreshed periodically. To activate support for refresh + /// tokens in the [`Client`], it needs to be done at login with the + /// [`LoginBuilder::request_refresh_token()`] method, or during account + /// registration. + /// + /// This method doesn't need to be called if + /// [`ClientBuilder::handle_refresh_tokens()`] is called during construction + /// of the `Client`. Otherwise, it should be called once when a refresh + /// token is available and an [`UnknownToken`] error is received. + /// If this call fails with another [`UnknownToken`] error, it means that + /// the session needs to be logged in again. + /// + /// It can also be called at any time when a refresh token is available, it + /// will invalidate the previous access token. + /// + /// The new tokens in the response will be used by the `Client` and should + /// be persisted to be able to [restore the session]. The response will + /// always contain an access token that replaces the previous one. It + /// can also contain a refresh token, in which case it will also replace + /// the previous one. + /// + /// This method is protected behind a lock, so calling this method several + /// times at once will only call the endpoint once and all subsequent calls + /// will wait for the result of the first call. The first call will + /// return `Ok(Some(response))` or the [`HttpError`] returned by the + /// endpoint, while the others will return `Ok(None)` if the token was + /// refreshed by the first call or a [`RefreshTokenError`] error, if it + /// failed. + /// + /// # Example + /// + /// ```no_run + /// use matrix_sdk::{Client, Error, Session}; + /// use url::Url; + /// # use futures::executor::block_on; + /// # block_on(async { + /// # fn get_credentials() -> (&'static str, &'static str) { ("", "") }; + /// # fn persist_session(_: Option) {}; + /// + /// let homeserver = Url::parse("http://example.com")?; + /// let client = Client::new(homeserver).await?; + /// + /// let (user, password) = get_credentials(); + /// let response = client + /// .login_username(user, password) + /// .initial_device_display_name("My App") + /// .request_refresh_token() + /// .send() + /// .await?; + /// + /// persist_session(client.session()); + /// + /// // Handle when an `M_UNKNOWN_TOKEN` error is encountered. + /// async fn on_unknown_token_err( + /// client: &Client, + /// session: &Session, + /// ) -> Result<(), Error> { + /// if session.refresh_token.is_some() + /// && client.refresh_access_token().await.is_ok() + /// { + /// persist_session(client.session()); + /// return Ok(()); + /// } + /// + /// let (user, password) = get_credentials(); + /// client + /// .login_username(user, password) + /// .request_refresh_token() + /// .send() + /// .await?; + /// + /// persist_session(client.session()); + /// + /// Ok(()) + /// } + /// + /// # anyhow::Ok(()) }); + /// ``` + /// + /// [refreshing access tokens]: https://spec.matrix.org/v1.3/client-server-api/#refreshing-access-tokens + /// [`UnknownToken`]: ruma::api::client::error::ErrorKind::UnknownToken + /// [restore the session]: Client::restore_login + pub async fn refresh_access_token(&self) -> HttpResult> { + #[cfg(not(target_arch = "wasm32"))] + let lock = self.inner.refresh_token_lock.try_lock().ok(); + #[cfg(target_arch = "wasm32")] + let lock = self.inner.refresh_token_lock.try_lock(); + + if let Some(mut guard) = lock { + let mut session_tokens = if let Some(tokens) = self.session_tokens() { + tokens + } else { + *guard = Err(RefreshTokenError::RefreshTokenRequired); + + return Err(RefreshTokenError::RefreshTokenRequired.into()); + }; + + let refresh_token = session_tokens + .refresh_token + .as_ref() + .ok_or(RefreshTokenError::RefreshTokenRequired)? + .clone(); + let request = refresh_token::v3::Request::new(refresh_token); + + let res = self + .inner + .http_client + .send( + request, + None, + self.homeserver().await.to_string(), + self.session().as_ref(), + self.server_versions().await?, + ) + .await; + + match res { + Ok(res) => { + *guard = Ok(()); + + session_tokens.update_with_refresh_response(&res); + + self.base_client().set_session_tokens(session_tokens); + + Ok(Some(res)) + } + Err(error) => { + *guard = if let HttpError::Api(FromHttpResponseError::Server( + ServerError::Known(RumaApiError::ClientApi(api_error)), + )) = &error + { + Err(RefreshTokenError::ClientApi(api_error.to_owned())) + } else { + Err(RefreshTokenError::UnableToRefreshToken) + }; + + Err(error) + } + } + } else { + match *self.inner.refresh_token_lock.lock().await { + Ok(_) => Ok(None), + Err(_) => Err(RefreshTokenError::UnableToRefreshToken.into()), + } + } + } + /// Register a user to the server. /// /// # Arguments @@ -1519,7 +1809,7 @@ impl Client { request, Some(request_config), self.homeserver().await.to_string(), - self.session(), + self.session().as_ref(), self.server_versions().await?, ) .await?) @@ -1569,6 +1859,47 @@ impl Client { request: Request, config: Option, ) -> HttpResult + where + Request: OutgoingRequest + Clone + Debug, + HttpError: From>, + { + let res = self.send_inner(request.clone(), config).await; + + // If this is an `M_UNKNOWN_TOKEN` error and refresh token handling is active, + // try to refresh the token and retry the request. + if self.inner.handle_refresh_tokens { + if let Err(HttpError::Api(FromHttpResponseError::Server(ServerError::Known( + RumaApiError::ClientApi(error), + )))) = &res + { + if matches!(error.kind, ErrorKind::UnknownToken { .. }) { + let refresh_res = self.refresh_access_token().await; + + if let Err(refresh_error) = refresh_res { + match &refresh_error { + HttpError::RefreshToken(RefreshTokenError::RefreshTokenRequired) => { + // Refreshing access tokens is not supported by + // this `Session`, ignore. + } + _ => { + return Err(refresh_error); + } + } + } else { + return self.send_inner(request, config).await; + } + } + } + } + + res + } + + async fn send_inner( + &self, + request: Request, + config: Option, + ) -> HttpResult where Request: OutgoingRequest + Debug, HttpError: From>, @@ -1579,7 +1910,7 @@ impl Client { request, config, self.homeserver().await.to_string(), - self.session(), + self.session().as_ref(), self.server_versions().await?, ) .await @@ -2370,6 +2701,7 @@ pub(crate) mod tests { pub(crate) async fn logged_in_client(homeserver_url: Option) -> Client { let session = Session { access_token: "1234".to_owned(), + refresh_token: None, user_id: user_id!("@example:localhost").to_owned(), device_id: device_id!("DEVICEID").to_owned(), }; diff --git a/crates/matrix-sdk/src/error.rs b/crates/matrix-sdk/src/error.rs index 4ddf286df9c..a378053c287 100644 --- a/crates/matrix-sdk/src/error.rs +++ b/crates/matrix-sdk/src/error.rs @@ -101,6 +101,10 @@ pub enum HttpError { /// Tried to send a request without `user_id` in the `Session` #[error("missing user_id in session")] UserIdRequired, + + /// An error occurred while refreshing the access token. + #[error(transparent)] + RefreshToken(#[from] RefreshTokenError), } /// Internal representation of errors. @@ -326,3 +330,26 @@ pub enum ImageError { #[error("the thumbnail size is bigger than the original image size")] ThumbnailBiggerThanOriginal, } + +/// Errors that can happen when refreshing an access token. +/// +/// This is usually only returned by [`Client::refresh_access_token()`], unless +/// [handling refresh tokens] is activated for the `Client`. +/// +/// [`Client::refresh_access_token()`]: crate::Client::refresh_access_token() +/// [handling refresh tokens]: crate::ClientBuilder::handle_refresh_tokens() +#[derive(Debug, Error, Clone)] +pub enum RefreshTokenError { + /// The Matrix endpoint returned an error. + #[error(transparent)] + ClientApi(#[from] ruma::api::client::Error), + + /// Tried to send a refresh token request without a refresh token. + #[error("missing refresh token")] + RefreshTokenRequired, + + /// There was an ongoing refresh token call that failed and the error could + /// not be forwarded. + #[error("the access token could not be refreshed")] + UnableToRefreshToken, +} diff --git a/crates/matrix-sdk/src/lib.rs b/crates/matrix-sdk/src/lib.rs index 34fb3010a0c..93f76c62587 100644 --- a/crates/matrix-sdk/src/lib.rs +++ b/crates/matrix-sdk/src/lib.rs @@ -63,6 +63,6 @@ pub use client::SsoLoginBuilder; pub use client::{Client, ClientBuildError, ClientBuilder, LoginBuilder, LoopCtrl}; #[cfg(feature = "image-proc")] pub use error::ImageError; -pub use error::{Error, HttpError, HttpResult, Result, RumaApiError}; +pub use error::{Error, HttpError, HttpResult, RefreshTokenError, Result, RumaApiError}; pub use http_client::HttpSend; pub use room_member::RoomMember; diff --git a/crates/matrix-sdk/tests/integration/client.rs b/crates/matrix-sdk/tests/integration/client.rs index 8d9f936e319..6e263a1b126 100644 --- a/crates/matrix-sdk/tests/integration/client.rs +++ b/crates/matrix-sdk/tests/integration/client.rs @@ -3,7 +3,7 @@ use std::{collections::BTreeMap, str::FromStr, time::Duration}; use matrix_sdk::{ config::SyncSettings, media::{MediaFormat, MediaRequest, MediaThumbnailSize}, - Error, HttpError, RumaApiError, + Error, HttpError, RumaApiError, Session, }; use matrix_sdk_test::{async_test, test_json}; use ruma::{ @@ -26,7 +26,7 @@ use ruma::{ events::room::{message::ImageMessageEventContent, ImageInfo, MediaSource}, mxc_uri, room_id, uint, user_id, }; -use serde_json::json; +use serde_json::{from_value as from_json_value, json, to_value as to_json_value}; use url::Url; use wiremock::{ matchers::{header, method, path, path_regex}, @@ -586,3 +586,62 @@ async fn whoami() { assert_eq!(client.whoami().await.unwrap().user_id, user_id); } + +#[test] +fn deserialize_session() { + // First version, or second version without refresh token. + let json = json!({ + "access_token": "abcd", + "user_id": "@user:localhost", + "device_id": "EFGHIJ", + }); + let session: Session = from_json_value(json).unwrap(); + assert_eq!(session.access_token, "abcd"); + assert_eq!(session.user_id, "@user:localhost"); + assert_eq!(session.device_id, "EFGHIJ"); + assert_eq!(session.refresh_token, None); + + // Second version with refresh_token. + let json = json!({ + "access_token": "abcd", + "refresh_token": "wxyz", + "user_id": "@user:localhost", + "device_id": "EFGHIJ", + }); + let session: Session = from_json_value(json).unwrap(); + assert_eq!(session.access_token, "abcd"); + assert_eq!(session.user_id, "@user:localhost"); + assert_eq!(session.device_id, "EFGHIJ"); + assert_eq!(session.refresh_token.as_deref(), Some("wxyz")); +} + +#[test] +fn serialize_session() { + // Without refresh token. + let mut session = Session { + access_token: "abcd".to_owned(), + refresh_token: None, + user_id: user_id!("@user:localhost").to_owned(), + device_id: device_id!("EFGHIJ").to_owned(), + }; + assert_eq!( + to_json_value(session.clone()).unwrap(), + json!({ + "access_token": "abcd", + "user_id": "@user:localhost", + "device_id": "EFGHIJ", + }) + ); + + // With refresh_token. + session.refresh_token = Some("wxyz".to_owned()); + assert_eq!( + to_json_value(session).unwrap(), + json!({ + "access_token": "abcd", + "refresh_token": "wxyz", + "user_id": "@user:localhost", + "device_id": "EFGHIJ", + }) + ); +} diff --git a/crates/matrix-sdk/tests/integration/main.rs b/crates/matrix-sdk/tests/integration/main.rs index 9ddca69e3fe..709342fa9a2 100644 --- a/crates/matrix-sdk/tests/integration/main.rs +++ b/crates/matrix-sdk/tests/integration/main.rs @@ -10,6 +10,7 @@ use wiremock::{ }; mod client; +mod refresh_token; mod room; async fn test_client_builder() -> (ClientBuilder, MockServer) { @@ -29,6 +30,7 @@ async fn no_retry_test_client() -> (Client, MockServer) { async fn logged_in_client() -> (Client, MockServer) { let session = Session { access_token: "1234".to_owned(), + refresh_token: None, user_id: user_id!("@example:localhost").to_owned(), device_id: device_id!("DEVICEID").to_owned(), }; diff --git a/crates/matrix-sdk/tests/integration/refresh_token.rs b/crates/matrix-sdk/tests/integration/refresh_token.rs new file mode 100644 index 00000000000..2f1b8ac0cdb --- /dev/null +++ b/crates/matrix-sdk/tests/integration/refresh_token.rs @@ -0,0 +1,564 @@ +use std::time::Duration; + +use futures::{ + channel::{mpsc, oneshot}, + StreamExt, +}; +use futures_signals::signal::SignalExt; +use matches::assert_matches; +use matrix_sdk::{ + config::RequestConfig, executor::spawn, HttpError, RefreshTokenError, RumaApiError, Session, +}; +use matrix_sdk_test::{async_test, test_json}; +use ruma::{ + api::{ + client::{account::register, error::ErrorKind, Error as ClientApiError}, + error::{FromHttpResponseError, ServerError}, + MatrixVersion, + }, + assign, device_id, user_id, +}; +use serde_json::json; +use wiremock::{ + matchers::{body_partial_json, header, method, path}, + Mock, ResponseTemplate, +}; + +use crate::{logged_in_client, no_retry_test_client, test_client_builder}; + +#[async_test] +async fn login_username_refresh_token() { + let (client, server) = no_retry_test_client().await; + + Mock::given(method("POST")) + .and(path("/_matrix/client/r0/login")) + .and(body_partial_json(json!({ + "org.matrix.msc2918.refresh_token": true, + }))) + .respond_with( + ResponseTemplate::new(200).set_body_json(&*test_json::LOGIN_WITH_REFRESH_TOKEN), + ) + .mount(&server) + .await; + + let res = + client.login_username("example", "wordpass").request_refresh_token().send().await.unwrap(); + + let logged_in = client.logged_in(); + assert!(logged_in, "Client should be logged in"); + res.refresh_token.unwrap(); +} + +#[async_test] +#[cfg(feature = "sso-login")] +async fn login_sso_refresh_token() { + let (client, server) = no_retry_test_client().await; + + Mock::given(method("POST")) + .and(path("/_matrix/client/r0/login")) + .and(body_partial_json(json!({ + "org.matrix.msc2918.refresh_token": true, + }))) + .respond_with( + ResponseTemplate::new(200).set_body_json(&*test_json::LOGIN_WITH_REFRESH_TOKEN), + ) + .mount(&server) + .await; + + let idp = ruma::api::client::session::get_login_types::v3::IdentityProvider::new( + "some-id".to_owned(), + "idp-name".to_owned(), + ); + let res = client + .login_sso(|sso_url| async move { + let sso_url = url::Url::parse(&sso_url).unwrap(); + + let (_, redirect) = + sso_url.query_pairs().find(|(key, _)| key == "redirectUrl").unwrap(); + + let mut redirect_url = url::Url::parse(&redirect).unwrap(); + redirect_url.set_query(Some("loginToken=tinytoken")); + + reqwest::get(redirect_url.to_string()).await.unwrap(); + + Ok(()) + }) + .identity_provider_id(&idp.id) + .request_refresh_token() + .send() + .await + .unwrap(); + + let logged_in = client.logged_in(); + assert!(logged_in, "Client should be logged in"); + res.refresh_token.unwrap(); +} + +#[async_test] +async fn register_refresh_token() { + let (client, server) = no_retry_test_client().await; + + Mock::given(method("POST")) + .and(path("/_matrix/client/r0/register")) + .and(body_partial_json(json!({ + "org.matrix.msc2918.refresh_token": true, + }))) + .respond_with( + // Successful registration response is the same as for login, + // if `inhibit_login` is `false`. + ResponseTemplate::new(200).set_body_json(&*test_json::LOGIN_WITH_REFRESH_TOKEN), + ) + .mount(&server) + .await; + + let req = assign!(register::v3::Request::new(), { + username: Some("user"), + password: Some("password"), + auth: None, + refresh_token: true, + }); + + let res = client.register(req).await.unwrap(); + + res.refresh_token.unwrap(); +} + +#[async_test] +async fn no_refresh_token() { + let (client, server) = logged_in_client().await; + + // Refresh token doesn't change. + Mock::given(method("POST")) + .and(path("/_matrix/client/v3/refresh")) + .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::REFRESH_TOKEN)) + .expect(0) + .mount(&server) + .await; + + let res = client.refresh_access_token().await; + assert_matches!(res, Err(HttpError::RefreshToken(RefreshTokenError::RefreshTokenRequired))); +} + +#[async_test] +async fn refresh_token() { + let (builder, server) = test_client_builder().await; + let client = builder + .request_config(RequestConfig::new().disable_retry()) + .server_versions([MatrixVersion::V1_3]) + .build() + .await + .unwrap(); + + let session = Session { + access_token: "1234".to_owned(), + refresh_token: Some("abcd".to_owned()), + user_id: user_id!("@example:localhost").to_owned(), + device_id: device_id!("DEVICEID").to_owned(), + }; + client.restore_login(session).await.unwrap(); + + let tokens = client.session_tokens().unwrap(); + assert_eq!(tokens.access_token, "1234"); + assert_eq!(tokens.refresh_token.as_deref(), Some("abcd")); + + // Refresh token doesn't change. + Mock::given(method("POST")) + .and(path("/_matrix/client/v3/refresh")) + .and(body_partial_json(json!({ + "refresh_token": "abcd", + }))) + .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::REFRESH_TOKEN)) + .up_to_n_times(1) + .mount(&server) + .await; + + client.refresh_access_token().await.unwrap().unwrap(); + let tokens = client.session_tokens().unwrap(); + assert_eq!(tokens.access_token, "5678"); + assert_eq!(tokens.refresh_token.as_deref(), Some("abcd")); + + // Refresh token changes. + Mock::given(method("POST")) + .and(path("/_matrix/client/v3/refresh")) + .and(body_partial_json(json!({ + "refresh_token": "abcd", + }))) + .respond_with( + ResponseTemplate::new(200).set_body_json(&*test_json::REFRESH_TOKEN_WITH_REFRESH_TOKEN), + ) + .mount(&server) + .await; + + client.refresh_access_token().await.unwrap().unwrap(); + let tokens = client.session_tokens().unwrap(); + assert_eq!(tokens.access_token, "9012"); + assert_eq!(tokens.refresh_token.as_deref(), Some("wxyz")); +} + +#[async_test] +async fn refresh_token_not_handled() { + let (builder, server) = test_client_builder().await; + let client = builder + .request_config(RequestConfig::new().disable_retry()) + .server_versions([MatrixVersion::V1_3]) + .build() + .await + .unwrap(); + + let session = Session { + access_token: "1234".to_owned(), + refresh_token: Some("abcd".to_owned()), + user_id: user_id!("@example:localhost").to_owned(), + device_id: device_id!("DEVICEID").to_owned(), + }; + client.restore_login(session).await.unwrap(); + + Mock::given(method("POST")) + .and(path("/_matrix/client/v3/refresh")) + .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::REFRESH_TOKEN)) + .expect(0) + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/_matrix/client/v3/account/whoami")) + .and(header(http::header::AUTHORIZATION, "Bearer 1234")) + .respond_with( + ResponseTemplate::new(401).set_body_json(&*test_json::UNKNOWN_TOKEN_SOFT_LOGOUT), + ) + .mount(&server) + .await; + + let res = client.whoami().await; + assert_matches!( + res, + Err(HttpError::Api(FromHttpResponseError::Server(ServerError::Known( + RumaApiError::ClientApi(ClientApiError { kind, .. }) + )))) if matches!(kind, ErrorKind::UnknownToken { .. }) + ) +} + +#[async_test] +async fn refresh_token_handled_success() { + let (builder, server) = test_client_builder().await; + let client = builder + .request_config(RequestConfig::new().disable_retry()) + .server_versions([MatrixVersion::V1_3]) + .handle_refresh_tokens() + .build() + .await + .unwrap(); + + let session = Session { + access_token: "1234".to_owned(), + refresh_token: Some("abcd".to_owned()), + user_id: user_id!("@example:localhost").to_owned(), + device_id: device_id!("DEVICEID").to_owned(), + }; + client.restore_login(session).await.unwrap(); + + let mut tokens_stream = client.session_tokens_signal().to_stream(); + let (tokens_sender, tokens_receiver) = oneshot::channel::<()>(); + spawn(async move { + let tokens = tokens_stream.next().await.flatten().unwrap(); + assert_eq!(tokens.access_token, "1234"); + assert_eq!(tokens.refresh_token.as_deref(), Some("abcd")); + + let tokens = tokens_stream.next().await.flatten().unwrap(); + assert_eq!(tokens.access_token, "5678"); + assert_eq!(tokens.refresh_token.as_deref(), Some("abcd")); + + tokens_sender.send(()).unwrap(); + }); + + let mut tokens_changed_stream = client.session_tokens_changed_signal().to_stream(); + let (changed_sender, changed_receiver) = oneshot::channel::<()>(); + spawn(async move { + tokens_changed_stream.next().await.unwrap(); + tokens_changed_stream.next().await.unwrap(); + + changed_sender.send(()).unwrap(); + }); + + Mock::given(method("POST")) + .and(path("/_matrix/client/v3/refresh")) + .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::REFRESH_TOKEN)) + .expect(1) + .named("`POST /refresh`") + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/_matrix/client/v3/account/whoami")) + .and(header(http::header::AUTHORIZATION, "Bearer 1234")) + .respond_with( + ResponseTemplate::new(401).set_body_json(&*test_json::UNKNOWN_TOKEN_SOFT_LOGOUT), + ) + .expect(1) + .named("`GET /whoami` wrong token") + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/_matrix/client/v3/account/whoami")) + .and(header(http::header::AUTHORIZATION, "Bearer 5678")) + .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::WHOAMI)) + .expect(1) + .named("`GET /whoami` good token") + .mount(&server) + .await; + + client.whoami().await.unwrap(); + tokens_receiver.await.unwrap(); + changed_receiver.await.unwrap(); +} + +#[async_test] +async fn refresh_token_handled_failure() { + let (builder, server) = test_client_builder().await; + let client = builder + .request_config(RequestConfig::new().disable_retry()) + .server_versions([MatrixVersion::V1_3]) + .handle_refresh_tokens() + .build() + .await + .unwrap(); + + let session = Session { + access_token: "1234".to_owned(), + refresh_token: Some("abcd".to_owned()), + user_id: user_id!("@example:localhost").to_owned(), + device_id: device_id!("DEVICEID").to_owned(), + }; + client.restore_login(session).await.unwrap(); + + Mock::given(method("POST")) + .and(path("/_matrix/client/v3/refresh")) + .respond_with( + ResponseTemplate::new(401).set_body_json(&*test_json::UNKNOWN_TOKEN_SOFT_LOGOUT), + ) + .expect(1) + .named("`POST /refresh`") + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/_matrix/client/v3/account/whoami")) + .and(header(http::header::AUTHORIZATION, "Bearer 1234")) + .respond_with( + ResponseTemplate::new(401).set_body_json(&*test_json::UNKNOWN_TOKEN_SOFT_LOGOUT), + ) + .expect(1) + .named("`GET /whoami` wrong token") + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/_matrix/client/v3/account/whoami")) + .and(header(http::header::AUTHORIZATION, "Bearer 5678")) + .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::WHOAMI)) + .expect(0) + .named("`GET /whoami` good token") + .mount(&server) + .await; + + let res = client.whoami().await; + assert_matches!( + res, + Err(HttpError::Api(FromHttpResponseError::Server(ServerError::Known( + RumaApiError::ClientApi(ClientApiError { kind, .. }) + )))) if matches!(kind, ErrorKind::UnknownToken { .. }) + ) +} + +#[async_test] +async fn refresh_token_handled_multi_success() { + let (builder, server) = test_client_builder().await; + let client = builder + .request_config(RequestConfig::new().disable_retry()) + .server_versions([MatrixVersion::V1_3]) + .handle_refresh_tokens() + .build() + .await + .unwrap(); + + let session = Session { + access_token: "1234".to_owned(), + refresh_token: Some("abcd".to_owned()), + user_id: user_id!("@example:localhost").to_owned(), + device_id: device_id!("DEVICEID").to_owned(), + }; + client.restore_login(session).await.unwrap(); + + Mock::given(method("POST")) + .and(path("/_matrix/client/v3/refresh")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(&*test_json::REFRESH_TOKEN) + .set_delay(Duration::from_secs(1)), + ) + .expect(1) + .named("`POST /refresh`") + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/_matrix/client/v3/account/whoami")) + .and(header(http::header::AUTHORIZATION, "Bearer 1234")) + .respond_with( + ResponseTemplate::new(401).set_body_json(&*test_json::UNKNOWN_TOKEN_SOFT_LOGOUT), + ) + .expect(3) + .named("`GET /whoami` wrong token") + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/_matrix/client/v3/account/whoami")) + .and(header(http::header::AUTHORIZATION, "Bearer 5678")) + .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::WHOAMI)) + .expect(3) + .named("`GET /whoami` good token") + .mount(&server) + .await; + + let (mut sender, mut receiver) = mpsc::channel::<()>(3); + let client_clone = client.clone(); + let mut sender_clone = sender.clone(); + spawn(async move { + client_clone.whoami().await.unwrap(); + sender_clone.try_send(()).unwrap(); + }); + let client_clone = client.clone(); + let mut sender_clone = sender.clone(); + spawn(async move { + client_clone.whoami().await.unwrap(); + sender_clone.try_send(()).unwrap(); + }); + spawn(async move { + client.whoami().await.unwrap(); + sender.try_send(()).unwrap(); + }); + + let mut i = 0; + while i < 3 { + if receiver.next().await.is_some() { + i += 1; + } + } +} + +#[async_test] +async fn refresh_token_handled_multi_failure() { + let (builder, server) = test_client_builder().await; + let client = builder + .request_config(RequestConfig::new().disable_retry()) + .server_versions([MatrixVersion::V1_3]) + .handle_refresh_tokens() + .build() + .await + .unwrap(); + + let session = Session { + access_token: "1234".to_owned(), + refresh_token: Some("abcd".to_owned()), + user_id: user_id!("@example:localhost").to_owned(), + device_id: device_id!("DEVICEID").to_owned(), + }; + client.restore_login(session).await.unwrap(); + + Mock::given(method("POST")) + .and(path("/_matrix/client/v3/refresh")) + .respond_with( + ResponseTemplate::new(401) + .set_body_json(&*test_json::UNKNOWN_TOKEN_SOFT_LOGOUT) + .set_delay(Duration::from_secs(1)), + ) + .expect(1) + .named("`POST /refresh`") + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/_matrix/client/v3/account/whoami")) + .and(header(http::header::AUTHORIZATION, "Bearer 1234")) + .respond_with( + ResponseTemplate::new(401).set_body_json(&*test_json::UNKNOWN_TOKEN_SOFT_LOGOUT), + ) + .expect(3) + .named("`GET /whoami` wrong token") + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/_matrix/client/v3/account/whoami")) + .and(header(http::header::AUTHORIZATION, "Bearer 5678")) + .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::WHOAMI)) + .expect(0) + .named("`GET /whoami` good token") + .mount(&server) + .await; + + let (mut sender, mut receiver) = futures::channel::mpsc::channel::<()>(3); + let client_clone = client.clone(); + let mut sender_clone = sender.clone(); + spawn(async move { + client_clone.whoami().await.unwrap_err(); + sender_clone.try_send(()).unwrap(); + }); + let client_clone = client.clone(); + let mut sender_clone = sender.clone(); + spawn(async move { + client_clone.whoami().await.unwrap_err(); + sender_clone.try_send(()).unwrap(); + }); + spawn(async move { + client.whoami().await.unwrap_err(); + sender.try_send(()).unwrap(); + }); + + let mut i = 0; + while i < 3 { + if receiver.next().await.is_some() { + i += 1; + } + } +} + +#[async_test] +async fn refresh_token_handled_other_error() { + let (builder, server) = test_client_builder().await; + let client = builder + .request_config(RequestConfig::new().disable_retry()) + .server_versions([MatrixVersion::V1_3]) + .handle_refresh_tokens() + .build() + .await + .unwrap(); + + let session = Session { + access_token: "1234".to_owned(), + refresh_token: Some("abcd".to_owned()), + user_id: user_id!("@example:localhost").to_owned(), + device_id: device_id!("DEVICEID").to_owned(), + }; + client.restore_login(session).await.unwrap(); + + Mock::given(method("POST")) + .and(path("/_matrix/client/v3/refresh")) + .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::REFRESH_TOKEN)) + .expect(0) + .named("`POST /refresh`") + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/_matrix/client/v3/account/whoami")) + .respond_with(ResponseTemplate::new(404).set_body_json(&*test_json::NOT_FOUND)) + .expect(1) + .named("`GET /whoami`") + .mount(&server) + .await; + + client.whoami().await.unwrap_err(); +} diff --git a/testing/matrix-sdk-test/src/test_json/api_responses.rs b/testing/matrix-sdk-test/src/test_json/api_responses.rs index 9007fe061dd..3405e8038c2 100644 --- a/testing/matrix-sdk-test/src/test_json/api_responses.rs +++ b/testing/matrix-sdk-test/src/test_json/api_responses.rs @@ -106,6 +106,18 @@ pub static LOGIN_WITH_DISCOVERY: Lazy = Lazy::new(|| { }) }); +/// Successful call to `POST /_matrix/client/v3/login` with a refresh token. +pub static LOGIN_WITH_REFRESH_TOKEN: Lazy = Lazy::new(|| { + json!({ + "access_token": "abc123", + "device_id": "GHTYAJCE", + "home_server": "matrix.org", + "user_id": "@cheeky_monkey:matrix.org", + "expires_in_ms": 432000000, + "refresh_token": "zyx987", + }) +}); + /// Failed call to `POST /_matrix/client/v3/login` pub static LOGIN_RESPONSE_ERR: Lazy = Lazy::new(|| { json!({ @@ -131,6 +143,16 @@ pub static LOGIN_TYPES: Lazy = Lazy::new(|| { }) }); +/// Failed call to an endpoint when the resource that was asked could not be +/// found. +pub static NOT_FOUND: Lazy = Lazy::new(|| { + json!({ + "errcode": "M_NOT_FOUND", + "error": "No resource was found for this request.", + "soft_logout": true, + }) +}); + /// `GET /_matrix/client/v3/publicRooms` pub static PUBLIC_ROOMS: Lazy = Lazy::new(|| { json!({ @@ -154,6 +176,23 @@ pub static PUBLIC_ROOMS: Lazy = Lazy::new(|| { }) }); +/// `POST /_matrix/client/v3/refresh` without new refresh token. +pub static REFRESH_TOKEN: Lazy = Lazy::new(|| { + json!({ + "access_token": "5678", + "expire_in_ms": 432000000, + }) +}); + +/// `POST /_matrix/client/v3/refresh` with a new refresh token. +pub static REFRESH_TOKEN_WITH_REFRESH_TOKEN: Lazy = Lazy::new(|| { + json!({ + "access_token": "9012", + "expire_in_ms": 432000000, + "refresh_token": "wxyz", + }) +}); + /// Failed call to `POST /_matrix/client/v3/register` pub static REGISTRATION_RESPONSE_ERR: Lazy = Lazy::new(|| { json!({ @@ -177,6 +216,15 @@ pub static REGISTRATION_RESPONSE_ERR: Lazy = Lazy::new(|| { }) }); +/// Failed called to any endpoint with an expired access token. +pub static UNKNOWN_TOKEN_SOFT_LOGOUT: Lazy = Lazy::new(|| { + json!({ + "errcode": "M_UNKNOWN_TOKEN", + "error": "Invalid access token passed.", + "soft_logout": true, + }) +}); + /// `GET /_matrix/client/versions` pub static VERSIONS: Lazy = Lazy::new(|| { json!({ diff --git a/testing/matrix-sdk-test/src/test_json/mod.rs b/testing/matrix-sdk-test/src/test_json/mod.rs index bf469e94cac..f44245a6f00 100644 --- a/testing/matrix-sdk-test/src/test_json/mod.rs +++ b/testing/matrix-sdk-test/src/test_json/mod.rs @@ -15,7 +15,9 @@ pub mod sync_events; pub use api_responses::{ DEVICES, GET_ALIAS, KEYS_QUERY, KEYS_UPLOAD, LOGIN, LOGIN_RESPONSE_ERR, LOGIN_TYPES, - LOGIN_WITH_DISCOVERY, PUBLIC_ROOMS, REGISTRATION_RESPONSE_ERR, VERSIONS, WELL_KNOWN, WHOAMI, + LOGIN_WITH_DISCOVERY, LOGIN_WITH_REFRESH_TOKEN, NOT_FOUND, PUBLIC_ROOMS, REFRESH_TOKEN, + REFRESH_TOKEN_WITH_REFRESH_TOKEN, REGISTRATION_RESPONSE_ERR, UNKNOWN_TOKEN_SOFT_LOGOUT, + VERSIONS, WELL_KNOWN, WHOAMI, }; pub use members::MEMBERS; pub use messages::{ROOM_MESSAGES, ROOM_MESSAGES_BATCH_1, ROOM_MESSAGES_BATCH_2};