Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sdk): Add support for refresh tokens #892

Merged
merged 7 commits into from
Aug 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions bindings/matrix-sdk-ffi/src/authentication_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ impl AuthenticationService {

// Restore the client using the session from the login request.
client
.restore_session(session.clone())
.restore_session(session)
.map(|_| client.clone())
.map_err(AuthenticationError::from)
}
Expand Down Expand Up @@ -147,6 +147,7 @@ impl AuthenticationService {

let discovery_session = Session {
access_token: token.clone(),
refresh_token: None,
user_id: discovery_user_id,
device_id: device_id.clone(),
};
Expand All @@ -156,8 +157,12 @@ impl AuthenticationService {

// Create the actual client with a store path from the user ID.
let homeserver_url = client.homeserver();
let session =
Session { access_token: token, user_id: whoami.user_id.clone(), device_id };
let session = Session {
access_token: token,
refresh_token: None,
user_id: whoami.user_id.clone(),
device_id,
};
let client = Arc::new(ClientBuilder::new())
.base_path(self.base_path.clone())
.homeserver_url(homeserver_url)
Expand Down
2 changes: 1 addition & 1 deletion bindings/matrix-sdk-ffi/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ impl Client {

pub fn restore_token(&self) -> anyhow::Result<String> {
RUNTIME.block_on(async move {
let session = self.client.session().expect("Missing session").clone();
let session = self.client.session().expect("Missing session");
let homeurl = self.client.homeserver().await.into();
Ok(serde_json::to_string(&RestoreToken {
session,
Expand Down
2 changes: 2 additions & 0 deletions crates/matrix-sdk-appservice/src/virtual_user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,15 @@ impl<'a> VirtualUserBuilder<'a> {

Session {
access_token: response.access_token,
refresh_token: response.refresh_token,
user_id: response.user_id,
device_id: response.device_id,
}
} else {
// Don’t log in
Session {
access_token: self.appservice.registration.as_token.clone(),
refresh_token: None,
user_id: user_id.clone(),
device_id: self.device_id.unwrap_or_else(DeviceId::new),
}
Expand Down
1 change: 1 addition & 0 deletions crates/matrix-sdk-base/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ async-trait = "0.1.53"
dashmap = "5.2.0"
futures-channel = "0.3.21"
futures-core = "0.3.21"
futures-signals = { version = "0.3.30", default-features = false }
futures-util = { version = "0.3.21", default-features = false }
http = { version = "0.2.6", optional = true }
lru = "0.7.5"
Expand Down
41 changes: 34 additions & 7 deletions crates/matrix-sdk-base/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use std::{
#[cfg(feature = "e2e-encryption")]
use std::{ops::Deref, sync::Arc};

use futures_signals::signal::ReadOnlyMutable;
#[cfg(feature = "experimental-timeline")]
use matrix_sdk_common::deserialized_responses::TimelineSlice;
use matrix_sdk_common::{
Expand Down Expand Up @@ -62,11 +63,10 @@ use crate::error::Error;
use crate::{
error::Result,
rooms::{Room, RoomInfo, RoomType},
session::Session,
store::{
ambiguity_map::AmbiguityCache, Result as StoreResult, StateChanges, Store, StoreConfig,
},
StateStore,
Session, SessionMeta, SessionTokens, StateStore,
};

/// A no IO Client implementation.
Expand All @@ -91,7 +91,8 @@ pub struct BaseClient {
impl fmt::Debug for BaseClient {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Client")
.field("session", &self.session())
.field("session_meta", &self.store.session_meta())
.field("session_tokens", &self.store.session_tokens)
.field("sync_token", &self.store.sync_token)
.finish()
}
Expand Down Expand Up @@ -124,6 +125,29 @@ impl BaseClient {
}
}

/// Get the session meta information.
///
/// If the client is currently logged in, this will return a
/// [`SessionMeta`] object which contains the user ID and device ID.
/// Otherwise it returns `None`.
pub fn session_meta(&self) -> Option<&SessionMeta> {
self.store.session_meta()
}

/// Get the session tokens.
///
/// If the client is currently logged in, this will return a
/// [`SessionTokens`] object which contains the access token and optional
/// refresh token. Otherwise it returns `None`.
pub fn session_tokens(&self) -> ReadOnlyMutable<Option<SessionTokens>> {
self.store.session_tokens()
}

/// Set the session tokens.
pub fn set_session_tokens(&self, tokens: SessionTokens) {
self.store.set_session_tokens(tokens)
}

/// Get the user login session.
///
/// If the client is currently logged in, this will return a
Expand All @@ -132,7 +156,7 @@ impl BaseClient {
///
/// Returns a session object if the client is logged in. Otherwise returns
/// `None`.
pub fn session(&self) -> Option<&Session> {
pub fn session(&self) -> Option<Session> {
self.store.session()
}

Expand All @@ -154,7 +178,7 @@ impl BaseClient {

/// Is the client logged in.
pub fn logged_in(&self) -> bool {
self.store.session().is_some()
self.store.session_meta().is_some()
}

/// Receive a login response and update the session of the client.
Expand All @@ -169,6 +193,7 @@ impl BaseClient {
) -> Result<()> {
let session = Session {
access_token: response.access_token.clone(),
refresh_token: response.refresh_token.clone(),
device_id: response.device_id.clone(),
user_id: response.user_id.clone(),
};
Expand Down Expand Up @@ -1000,8 +1025,8 @@ impl BaseClient {
.transpose()?
{
Ok(event.content.global)
} else if let Some(session) = self.store.session() {
Ok(Ruleset::server_default(&session.user_id))
} else if let Some(session_meta) = self.store.session_meta() {
Ok(Ruleset::server_default(&session_meta.user_id))
} else {
Ok(Ruleset::new())
}
Expand Down Expand Up @@ -1137,6 +1162,7 @@ mod tests {
client
.restore_login(Session {
access_token: "token".to_owned(),
refresh_token: None,
user_id: user_id.to_owned(),
device_id: "FOOBAR".into(),
})
Expand Down Expand Up @@ -1191,6 +1217,7 @@ mod tests {
client
.restore_login(Session {
access_token: "token".to_owned(),
refresh_token: None,
user_id: user_id.to_owned(),
device_id: "FOOBAR".into(),
})
Expand Down
2 changes: 1 addition & 1 deletion crates/matrix-sdk-base/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub use matrix_sdk_common::*;
pub use crate::timeline_stream::TimelineStreamError;
pub use crate::{
error::{Error, Result},
session::Session,
session::{Session, SessionMeta, SessionTokens},
};

mod client;
Expand Down
61 changes: 57 additions & 4 deletions crates/matrix-sdk-base/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@

//! User sessions.

use ruma::{OwnedDeviceId, OwnedUserId};
use ruma::{api::client::session::refresh_token, OwnedDeviceId, OwnedUserId};
use serde::{Deserialize, Serialize};

/// A user session, containing an access token and information about the
/// associated user account.
/// A user session, containing an access token, an optional refresh token and
/// information about the associated user account.
///
/// # Example
///
Expand All @@ -29,6 +29,7 @@ use serde::{Deserialize, Serialize};
///
/// let session = Session {
/// access_token: "My-Token".to_owned(),
/// refresh_token: None,
/// user_id: user_id!("@example:localhost").to_owned(),
/// device_id: device_id!("MYDEVICEID").to_owned(),
/// };
Expand All @@ -39,18 +40,70 @@ use serde::{Deserialize, Serialize};
pub struct Session {
/// The access token used for this session.
pub access_token: String,
/// The token used for [refreshing the access token], if any.
///
/// [refreshing the access token]: https://spec.matrix.org/v1.3/client-server-api/#refreshing-access-tokens
#[serde(skip_serializing_if = "Option::is_none")]
pub refresh_token: Option<String>,
/// The user the access token was issued for.
pub user_id: OwnedUserId,
/// The ID of the client device
/// The ID of the client device.
pub device_id: OwnedDeviceId,
}
gnunicorn marked this conversation as resolved.
Show resolved Hide resolved

impl Session {
/// Creates a `Session` from a `SessionMeta` and `SessionTokens`.
pub fn from_parts(meta: SessionMeta, tokens: SessionTokens) -> Self {
let SessionMeta { user_id, device_id } = meta;
let SessionTokens { access_token, refresh_token } = tokens;
Self { access_token, refresh_token, user_id, device_id }
}

/// Split this `Session` between `SessionMeta` and `SessionTokens`.
pub fn into_parts(self) -> (SessionMeta, SessionTokens) {
let Self { access_token, refresh_token, user_id, device_id } = self;
(SessionMeta { user_id, device_id }, SessionTokens { access_token, refresh_token })
}
}

impl From<ruma::api::client::session::login::v3::Response> for Session {
fn from(response: ruma::api::client::session::login::v3::Response) -> Self {
Self {
access_token: response.access_token,
refresh_token: response.refresh_token,
user_id: response.user_id,
device_id: response.device_id,
}
}
}

/// The immutable parts of the session: the user ID and device ID.
#[derive(Clone, Debug)]
pub struct SessionMeta {
/// The user the access token was issued for.
pub user_id: OwnedUserId,
/// The ID of the client device.
pub device_id: OwnedDeviceId,
}

/// The mutable parts of the session: the access token and optional refresh
/// token.
#[derive(Clone, Debug)]
pub struct SessionTokens {
/// The access token used for this session.
pub access_token: String,
/// The token used for [refreshing the access token], if any.
///
/// [refreshing the access token]: https://spec.matrix.org/v1.3/client-server-api/#refreshing-access-tokens
pub refresh_token: Option<String>,
}

impl SessionTokens {
/// Update this `SessionTokens` with the values found in the given response.
pub fn update_with_refresh_response(&mut self, response: &refresh_token::v3::Response) {
self.access_token = response.access_token.clone();
if let Some(refresh_token) = response.refresh_token.clone() {
self.refresh_token = Some(refresh_token);
}
}
}
47 changes: 36 additions & 11 deletions crates/matrix-sdk-base/src/store/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use std::{
sync::Arc,
};

use futures_signals::signal::{Mutable, ReadOnlyMutable};
use once_cell::sync::OnceCell;

#[cfg(any(test, feature = "testing"))]
Expand Down Expand Up @@ -61,7 +62,7 @@ use crate::{
deserialized_responses::MemberEvent,
media::MediaRequest,
rooms::{RoomInfo, RoomType},
MinimalRoomMemberEvent, Room, Session,
MinimalRoomMemberEvent, Room, Session, SessionMeta, SessionTokens,
};

pub(crate) mod ambiguity_map;
Expand Down Expand Up @@ -409,7 +410,8 @@ where
#[derive(Debug, Clone)]
pub(crate) struct Store {
pub(super) inner: Arc<dyn StateStore>,
session: Arc<OnceCell<Session>>,
session_meta: Arc<OnceCell<SessionMeta>>,
pub(super) session_tokens: Mutable<Option<SessionTokens>>,
/// The current sync token that should be used for the next sync call.
pub(super) sync_token: Arc<RwLock<Option<String>>>,
rooms: Arc<DashMap<OwnedRoomId, Room>>,
Expand All @@ -430,7 +432,8 @@ impl Store {
pub fn new(inner: Arc<dyn StateStore>) -> Self {
Self {
inner,
session: Default::default(),
session_meta: Default::default(),
session_tokens: Default::default(),
sync_token: Default::default(),
rooms: Default::default(),
stripped_rooms: Default::default(),
Expand All @@ -451,17 +454,37 @@ impl Store {
}

let token = self.get_sync_token().await?;

*self.sync_token.write().await = token;
self.session.set(session).expect("A session was already set");

let (session_meta, session_tokens) = session.into_parts();
self.session_meta.set(session_meta).expect("Session IDs were already set");
self.session_tokens.set(Some(session_tokens));

Ok(())
}

/// The current [`Session`] containing our user id, device ID and access
/// token.
pub fn session(&self) -> Option<&Session> {
self.session.get()
/// The current [`SessionMeta`] containing our user ID and device ID.
pub fn session_meta(&self) -> Option<&SessionMeta> {
self.session_meta.get()
}

/// The current [`SessionTokens`] containing our access token and optional
/// refresh token.
pub fn session_tokens(&self) -> ReadOnlyMutable<Option<SessionTokens>> {
self.session_tokens.read_only()
}

/// Set the current [`SessionTokens`].
pub fn set_session_tokens(&self, tokens: SessionTokens) {
self.session_tokens.set(Some(tokens));
}

/// The current [`Session`] containing our user id, device ID, access
/// token and optional refresh token.
pub fn session(&self) -> Option<Session> {
let meta = self.session_meta.get()?;
let tokens = self.session_tokens.get_cloned()?;
Some(Session::from_parts(meta.to_owned(), tokens))
}

/// Get all the rooms this store knows about.
Expand Down Expand Up @@ -494,7 +517,8 @@ impl Store {
/// Lookup the stripped Room for the given RoomId, or create one, if it
/// didn't exist yet in the store
pub async fn get_or_create_stripped_room(&self, room_id: &RoomId) -> Room {
let user_id = &self.session.get().expect("Creating room while not being logged in").user_id;
let user_id =
&self.session_meta.get().expect("Creating room while not being logged in").user_id;

self.stripped_rooms
.entry(room_id.to_owned())
Expand All @@ -509,7 +533,8 @@ impl Store {
return self.get_or_create_stripped_room(room_id).await;
}

let user_id = &self.session.get().expect("Creating room while not being logged in").user_id;
let user_id =
&self.session_meta.get().expect("Creating room while not being logged in").user_id;

self.rooms
.entry(room_id.to_owned())
Expand Down
1 change: 1 addition & 0 deletions crates/matrix-sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ dashmap = "5.2.0"
event-listener = "2.5.2"
eyre = { version = "0.6.8", optional = true }
futures-core = "0.3.21"
futures-signals = { version = "0.3.30", default-features = false }
futures-util = { version = "0.3.21", default-features = false }
http = "0.2.6"
matrix-sdk-common = { version = "0.5.0", path = "../matrix-sdk-common" }
Expand Down
Loading