Skip to content

Commit

Permalink
Make client use .well-known redirects
Browse files Browse the repository at this point in the history
Was supposed to fix #219, but apparently that was about something else.
  • Loading branch information
timorl committed May 24, 2021
1 parent fe17dce commit ded5830
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 47 deletions.
229 changes: 185 additions & 44 deletions matrix_sdk/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,33 @@ pub enum LoopCtrl {
Break,
}

#[cfg(feature = "encryption")]
use matrix_sdk_common::{
api::r0::{
account::register,
device::{delete_devices, get_devices},
directory::{get_public_rooms, get_public_rooms_filtered},
filter::{create_filter::Request as FilterUploadRequest, FilterDefinition},
media::{create_content, get_content, get_content_thumbnail},
membership::{join_room_by_id, join_room_by_id_or_alias},
message::send_message_event,
profile::{get_avatar_url, get_display_name, set_avatar_url, set_display_name},
room::create_room,
session::{get_login_types, login, sso_login},
sync::sync_events,
uiaa::AuthData,
keys::{get_keys, upload_keys, upload_signing_keys::Request as UploadSigningKeysRequest},
to_device::send_event_to_device::{
Request as RumaToDeviceRequest, Response as ToDeviceResponse,
},
},
identifiers::EventId,
};
use matrix_sdk_common::{
api::{
r0::{
account::register,
device::{delete_devices, get_devices},
directory::{get_public_rooms, get_public_rooms_filtered},
filter::{create_filter::Request as FilterUploadRequest, FilterDefinition},
media::{create_content, get_content, get_content_thumbnail},
membership::{join_room_by_id, join_room_by_id_or_alias},
message::send_message_event,
profile::{get_avatar_url, get_display_name, set_avatar_url, set_display_name},
room::create_room,
session::{get_login_types, login, sso_login},
sync::sync_events,
uiaa::AuthData,
},
unversioned::{discover_homeserver, get_supported_versions},
},
assign,
identifiers::{DeviceIdBox, RoomId, RoomIdOrAliasId, ServerName, UserId},
Expand All @@ -98,16 +111,6 @@ use matrix_sdk_common::{
uuid::Uuid,
FromHttpResponseError, UInt,
};
#[cfg(feature = "encryption")]
use matrix_sdk_common::{
api::r0::{
keys::{get_keys, upload_keys, upload_signing_keys::Request as UploadSigningKeysRequest},
to_device::send_event_to_device::{
Request as RumaToDeviceRequest, Response as ToDeviceResponse,
},
},
identifiers::EventId,
};

#[cfg(feature = "encryption")]
use crate::{
Expand Down Expand Up @@ -142,7 +145,7 @@ const SSO_SERVER_BIND_TRIES: u8 = 10;
#[derive(Clone)]
pub struct Client {
/// The URL of the homeserver to connect to.
homeserver: Arc<Url>,
homeserver: Arc<RwLock<Url>>,
/// The underlying HTTP client.
http_client: HttpClient,
/// User session data.
Expand All @@ -164,7 +167,7 @@ pub struct Client {
#[cfg(not(tarpaulin_include))]
impl Debug for Client {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> StdResult<(), fmt::Error> {
write!(fmt, "Client {{ homeserver: {} }}", self.homeserver)
write!(fmt, "Client")
}
}

Expand Down Expand Up @@ -502,7 +505,7 @@ impl Client {
///
/// * `config` - Configuration for the client.
pub fn new_with_config(homeserver_url: Url, config: ClientConfig) -> Result<Self> {
let homeserver = Arc::new(homeserver_url);
let homeserver = Arc::new(RwLock::new(homeserver_url));

let client = if let Some(client) = config.client {
client
Expand All @@ -513,12 +516,8 @@ impl Client {
let base_client = BaseClient::new_with_config(config.base_config)?;
let session = base_client.session().clone();

let http_client = HttpClient {
homeserver: homeserver.clone(),
inner: client,
session,
request_config: config.request_config,
};
let http_client =
HttpClient::new(client, homeserver.clone(), session, config.request_config);

Ok(Self {
homeserver,
Expand All @@ -534,6 +533,89 @@ impl Client {
})
}

/// Creates a new client for making HTTP requests to the homeserver of the
/// given user. Follows homeserver discovery directions described
/// [here](https://spec.matrix.org/unstable/client-server-api/#well-known-uri).
///
/// # Arguments
///
/// * `user_id` - The id of the user whose homeserver the client should
/// connect to.
///
/// # Example
/// ```no_run
/// # use std::convert::TryFrom;
/// # use matrix_sdk::{Client, identifiers::UserId};
/// # use futures::executor::block_on;
/// let alice = UserId::try_from("@alice:example.org").unwrap();
/// # block_on(async {
/// let client = Client::new_from_user_id(alice.clone()).await.unwrap();
/// client.login(alice.localpart(), "password", None, None).await.unwrap();
/// # });
/// ```
pub async fn new_from_user_id(user_id: UserId) -> Result<Self> {
let config = ClientConfig::new();
Client::new_from_user_id_with_config(user_id, config).await
}

/// Creates a new client for making HTTP requests to the homeserver of the
/// given user and configuration. Follows homeserver discovery directions
/// described [here](https://spec.matrix.org/unstable/client-server-api/#well-known-uri).
///
/// # Arguments
///
/// * `user_id` - The id of the user whose homeserver the client should
/// connect to.
///
/// * `config` - Configuration for the client.
pub async fn new_from_user_id_with_config(
user_id: UserId,
config: ClientConfig,
) -> Result<Self> {
let homeserver = Client::homeserver_from_user_id(user_id)?;
let mut client = Client::new_with_config(homeserver, config)?;

let well_known = client.discover_homeserver().await?;
let well_known = Url::parse(well_known.homeserver.base_url.as_ref())?;
client.set_homeserver(well_known).await;
client.get_supported_versions().await?;
Ok(client)
}

fn homeserver_from_user_id(user_id: UserId) -> Result<Url> {
let homeserver = format!("https://{}", user_id.server_name());
#[allow(unused_mut)]
let mut result = Url::parse(homeserver.as_str())?;
// Mockito only knows how to test http endpoints:
// https://github.com/lipanski/mockito/issues/127
#[cfg(test)]
let _ = result.set_scheme("http");
Ok(result)
}

async fn discover_homeserver(&self) -> Result<discover_homeserver::Response> {
self.send(discover_homeserver::Request::new(), Some(RequestConfig::new().disable_retry()))
.await
}

/// Change the homeserver URL used by this client.
///
/// # Arguments
///
/// * `homeserver_url` - The new URL to use.
pub async fn set_homeserver(&mut self, homeserver_url: Url) {
let mut homeserver = self.homeserver.write().await;
*homeserver = homeserver_url;
}

async fn get_supported_versions(&self) -> Result<get_supported_versions::Response> {
self.send(
get_supported_versions::Request::new(),
Some(RequestConfig::new().disable_retry()),
)
.await
}

/// Process a [transaction] received from the homeserver
///
/// # Arguments
Expand Down Expand Up @@ -566,8 +648,8 @@ impl Client {
}

/// The Homeserver of the client.
pub fn homeserver(&self) -> &Url {
&self.homeserver
pub async fn homeserver(&self) -> Url {
self.homeserver.read().await.clone()
}

/// Get the user id of the current owner of the client.
Expand Down Expand Up @@ -866,8 +948,8 @@ impl Client {
/// successful SSO login.
///
/// [`login_with_token`]: #method.login_with_token
pub fn get_sso_login_url(&self, redirect_url: &str) -> Result<String> {
let homeserver = self.homeserver();
pub async fn get_sso_login_url(&self, redirect_url: &str) -> Result<String> {
let homeserver = self.homeserver().await;
let request = sso_login::Request::new(redirect_url)
.try_into_http_request::<Vec<u8>>(homeserver.as_str(), SendAccessToken::None);
match request {
Expand Down Expand Up @@ -928,7 +1010,7 @@ impl Client {
device_id: Option<&str>,
initial_device_display_name: Option<&str>,
) -> Result<login::Response> {
info!("Logging in to {} as {:?}", self.homeserver, user);
info!("Logging in to {} as {:?}", self.homeserver().await, user);

let request = assign!(
login::Request::new(
Expand Down Expand Up @@ -1037,7 +1119,7 @@ impl Client {
where
C: Future<Output = Result<()>>,
{
info!("Logging in to {}", self.homeserver);
info!("Logging in to {}", self.homeserver().await);
let (signal_tx, signal_rx) = oneshot::channel();
let (data_tx, data_rx) = oneshot::channel();
let data_tx_mutex = Arc::new(std::sync::Mutex::new(Some(data_tx)));
Expand Down Expand Up @@ -1109,7 +1191,7 @@ impl Client {

tokio::spawn(server);

let sso_url = self.get_sso_login_url(redirect_url.as_str()).unwrap();
let sso_url = self.get_sso_login_url(redirect_url.as_str()).await.unwrap();

match use_sso_login_url(sso_url).await {
Ok(t) => t,
Expand Down Expand Up @@ -1193,7 +1275,7 @@ impl Client {
device_id: Option<&str>,
initial_device_display_name: Option<&str>,
) -> Result<login::Response> {
info!("Logging in to {}", self.homeserver);
info!("Logging in to {}", self.homeserver().await);

let request = assign!(
login::Request::new(
Expand Down Expand Up @@ -1264,7 +1346,7 @@ impl Client {
&self,
registration: impl Into<register::Request<'_>>,
) -> Result<register::Response> {
info!("Registering to {}", self.homeserver);
info!("Registering to {}", self.homeserver().await);

let request = registration.into();
self.send(request, None).await
Expand Down Expand Up @@ -2387,7 +2469,13 @@ impl Client {

#[cfg(test)]
mod test {
use std::{collections::BTreeMap, convert::TryInto, io::Cursor, str::FromStr, time::Duration};
use std::{
collections::BTreeMap,
convert::{TryFrom, TryInto},
io::Cursor,
str::FromStr,
time::Duration,
};

use matrix_sdk_base::identifiers::mxc_uri;
use matrix_sdk_common::{
Expand All @@ -2399,7 +2487,7 @@ mod test {
assign,
directory::Filter,
events::{room::message::MessageEventContent, AnyMessageEventContent},
identifiers::{event_id, room_id, user_id},
identifiers::{event_id, room_id, user_id, UserId},
thirdparty,
};
use matrix_sdk_test::{test_json, EventBuilder, EventsJson};
Expand All @@ -2425,6 +2513,59 @@ mod test {
client
}

#[tokio::test]
async fn set_homeserver() {
let homeserver = Url::from_str("http://example.com/").unwrap();

let mut client = Client::new(homeserver).unwrap();

let homeserver = Url::from_str(&mockito::server_url()).unwrap();

client.set_homeserver(homeserver.clone()).await;

assert_eq!(client.homeserver().await, homeserver);
}

#[tokio::test]
async fn successful_discovery() {
let server_url = mockito::server_url();
let domain = server_url.strip_prefix("http://").unwrap();
let alice = UserId::try_from("@alice:".to_string() + domain).unwrap();

let _m_well_known = mock("GET", "/.well-known/matrix/client")
.with_status(200)
.with_body(
test_json::WELL_KNOWN.to_string().replace("HOMESERVER_URL", server_url.as_ref()),
)
.create();

let _m_versions = mock("GET", "/_matrix/client/versions")
.with_status(200)
.with_body(test_json::VERSIONS.to_string())
.create();
let client = Client::new_from_user_id(alice).await.unwrap();

assert_eq!(client.homeserver().await, Url::parse(server_url.as_ref()).unwrap());
}

#[tokio::test]
async fn discovery_broken_server() {
let server_url = mockito::server_url();
let domain = server_url.strip_prefix("http://").unwrap();
let alice = UserId::try_from("@alice:".to_string() + domain).unwrap();

let _m = mock("GET", "/.well-known/matrix/client")
.with_status(200)
.with_body(
test_json::WELL_KNOWN.to_string().replace("HOMESERVER_URL", server_url.as_ref()),
)
.create();

if Client::new_from_user_id(alice).await.is_ok() {
panic!("Creating a client from a user ID should fail when the .well-known server returns no version infromation.");
}
}

#[tokio::test]
async fn login() {
let homeserver = Url::from_str(&mockito::server_url()).unwrap();
Expand Down Expand Up @@ -2514,7 +2655,7 @@ mod test {
.any(|flow| matches!(flow, LoginType::Sso(_)));
assert!(can_sso);

let sso_url = client.get_sso_login_url("http://127.0.0.1:3030");
let sso_url = client.get_sso_login_url("http://127.0.0.1:3030").await;
assert!(sso_url.is_ok());

let _m = mock("POST", "/_matrix/client/r0/login")
Expand Down Expand Up @@ -2626,7 +2767,7 @@ mod test {
client.base_client.receive_sync_response(response).await.unwrap();
let room_id = room_id!("!SVkFJHzfwvuaIEawgC:localhost");

assert_eq!(client.homeserver(), &Url::parse(&mockito::server_url()).unwrap());
assert_eq!(client.homeserver().await, Url::parse(&mockito::server_url()).unwrap());

let room = client.get_joined_room(&room_id);
assert!(room.is_some());
Expand Down
5 changes: 5 additions & 0 deletions matrix_sdk/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use matrix_sdk_common::{
use reqwest::Error as ReqwestError;
use serde_json::Error as JsonError;
use thiserror::Error;
use url::ParseError as UrlParseError;

/// Result type of the rust-sdk.
pub type Result<T> = std::result::Result<T, Error>;
Expand Down Expand Up @@ -128,6 +129,10 @@ pub enum Error {
/// An error encountered when trying to parse an identifier.
#[error(transparent)]
Identifier(#[from] IdentifierError),

/// An error encountered when trying to parse a url.
#[error(transparent)]
Url(#[from] UrlParseError),
}

impl Error {
Expand Down
Loading

0 comments on commit ded5830

Please sign in to comment.