diff --git a/libsignal-service-actix/Cargo.toml b/libsignal-service-actix/Cargo.toml index a42fdb8ba..f4005e93f 100644 --- a/libsignal-service-actix/Cargo.toml +++ b/libsignal-service-actix/Cargo.toml @@ -11,8 +11,15 @@ libsignal-service = { path = "../libsignal-service" } libsignal-protocol = { git = "https://github.com/Michael-F-Bryan/libsignal-protocol-rs" } awc = { version = "2.0.0-alpha.2", features=["rustls"] } +actix = "0.10.0-alpha.3" actix-rt = "1.1" +serde_json = "1.0" +futures = "0.3" +bytes = "0.5" rustls = "0.17" +url = "2.1" +serde = "1.0" +log = "0.4.8" failure = "0.1.5" thiserror = "1.0" diff --git a/libsignal-service-actix/src/lib.rs b/libsignal-service-actix/src/lib.rs index 028a9bfe4..73a1ea704 100644 --- a/libsignal-service-actix/src/lib.rs +++ b/libsignal-service-actix/src/lib.rs @@ -1 +1,6 @@ pub mod push_service; +pub mod websocket; + +pub mod prelude { + pub use crate::push_service::*; +} diff --git a/libsignal-service-actix/src/push_service.rs b/libsignal-service-actix/src/push_service.rs index 7dc8c12ba..cfabb9b75 100644 --- a/libsignal-service-actix/src/push_service.rs +++ b/libsignal-service-actix/src/push_service.rs @@ -1,24 +1,134 @@ -use libsignal_service::{configuration::*, push_service::*}; +use std::{sync::Arc, time::Duration}; +use awc::Connector; +use libsignal_service::{ + configuration::*, messagepipe::WebSocketService, push_service::*, +}; +use serde::Deserialize; +use url::Url; + +use crate::websocket::AwcWebSocket; + +#[derive(Clone)] pub struct AwcPushService { + cfg: ServiceConfiguration, + base_url: Url, client: awc::Client, } #[async_trait::async_trait(?Send)] impl PushService for AwcPushService { - async fn get(&mut self, _path: &str) -> Result<(), ServiceError> { Ok(()) } + type WebSocket = AwcWebSocket; + + async fn get(&mut self, path: &str) -> Result + where + for<'de> T: Deserialize<'de>, + { + // In principle, we should be using http::uri::Uri, + // but that doesn't seem like an owned type where we can do this kind of + // constructions on. + // https://docs.rs/http/0.2.1/http/uri/struct.Uri.html + let url = self.base_url.join(path).expect("valid url"); + + log::debug!("AwcPushService::get({:?})", url); + let mut response = + self.client.get(url.as_str()).send().await.map_err(|e| { + ServiceError::SendError { + reason: e.to_string(), + } + })?; + + log::debug!("AwcPushService::get response: {:?}", response); + + ServiceError::from_status(response.status())?; + + // In order to debug the output, we collect the whole response. + // The actix-web api is meant to used as a streaming deserializer, + // so we have this little awkward switch. + // + // This is also the reason we depend directly on serde_json, however + // actix already imports that anyway. + if log::log_enabled!(log::Level::Debug) { + let text = response.body().await.map_err(|e| { + ServiceError::JsonDecodeError { + reason: e.to_string(), + } + })?; + log::debug!("GET response: {:?}", String::from_utf8_lossy(&text)); + serde_json::from_slice(&text).map_err(|e| { + ServiceError::JsonDecodeError { + reason: e.to_string(), + } + }) + } else { + response + .json() + .await + .map_err(|e| ServiceError::JsonDecodeError { + reason: e.to_string(), + }) + } + } + + async fn ws( + &mut self, + credentials: Credentials, + ) -> Result< + ( + Self::WebSocket, + ::Stream, + ), + ServiceError, + > { + Ok(AwcWebSocket::with_client( + &mut self.client, + &self.base_url, + Some(&credentials), + ) + .await?) + } } impl AwcPushService { - pub fn new( - _cfg: ServiceConfiguration, - _credentials: T, + /// Creates a new AwcPushService + /// + /// Panics on invalid service url. + pub fn new( + cfg: ServiceConfiguration, + credentials: Credentials, user_agent: &str, + root_ca: &str, ) -> Self { + let base_url = + Url::parse(&cfg.service_urls[0]).expect("valid service url"); + + // SSL setup + let mut ssl_config = rustls::ClientConfig::new(); + ssl_config.alpn_protocols = vec![b"http/1.1".to_vec()]; + ssl_config + .root_store + .add_pem_file(&mut std::io::Cursor::new(root_ca)) + .unwrap(); + let connector = Connector::new() + .rustls(Arc::new(ssl_config)) + .timeout(Duration::from_secs(10)) // https://github.com/actix/actix-web/issues/1047 + .finish(); + let client = awc::ClientBuilder::new() + .connector(connector) + .header("X-Signal-Agent", user_agent) + .timeout(Duration::from_secs(65)); // as in Signal-Android + + let client = if let Some((ident, pass)) = credentials.authorization() { + client.basic_auth(ident, Some(pass)) + } else { + client + }; + let client = client.finish(); + Self { - client: awc::ClientBuilder::new() - .header("X-Signal-Agent", user_agent) - .finish(), + cfg, + base_url, + client, } } } diff --git a/libsignal-service-actix/src/websocket.rs b/libsignal-service-actix/src/websocket.rs new file mode 100644 index 000000000..41d236974 --- /dev/null +++ b/libsignal-service-actix/src/websocket.rs @@ -0,0 +1,143 @@ +use actix::Arbiter; + +use awc::{error::WsProtocolError, ws, ws::Frame}; +use bytes::Bytes; +use futures::{channel::mpsc::*, prelude::*}; +use url::Url; + +use libsignal_service::{ + configuration::Credentials, messagepipe::*, push_service::ServiceError, +}; + +pub struct AwcWebSocket { + socket_sink: Box + Unpin>, +} + +#[derive(thiserror::Error, Debug)] +pub enum AwcWebSocketError { + #[error("Could not connect to the Signal Server")] + ConnectionError(#[from] awc::error::WsClientError), +} + +impl From for ServiceError { + fn from(e: AwcWebSocketError) -> ServiceError { + todo!("error conversion {:?}", e) + } +} + +impl From for AwcWebSocketError { + fn from(e: WsProtocolError) -> AwcWebSocketError { + todo!("error conversion {:?}", e) + // return Some(Err(ServiceError::WsError { + // reason: e.to_string(), + // })); + } +} + +/// Process the WebSocket, until it times out. +async fn process( + mut socket_stream: S, + mut incoming_sink: Sender, +) -> Result<(), AwcWebSocketError> +where + S: Unpin, + S: Stream>, +{ + while let Some(frame) = socket_stream.next().await { + let frame = match frame? { + Frame::Binary(s) => s, + + Frame::Continuation(_c) => todo!(), + Frame::Ping(msg) => { + log::warn!("Received Ping({:?})", msg); + // XXX: send pong and make the above log::debug + continue; + }, + Frame::Pong(msg) => { + log::trace!("Received Pong({:?})", msg); + + continue; + }, + Frame::Text(frame) => { + log::warn!("Frame::Text {:?}", frame); + + // this is a protocol violation, maybe break; is better? + continue; + }, + + Frame::Close(c) => { + log::warn!("Websocket closing: {:?}", c); + + break; + }, + }; + + // Match SendError + if let Err(e) = incoming_sink.send(frame).await { + log::info!("Websocket sink has closed: {:?}.", e); + break; + } + } + Ok(()) +} + +impl AwcWebSocket { + pub(crate) async fn with_client( + client: &mut awc::Client, + base_url: impl std::borrow::Borrow, + credentials: Option<&Credentials>, + ) -> Result<(Self, ::Stream), AwcWebSocketError> + { + let mut url = + base_url.borrow().join("/v1/websocket/").expect("valid url"); + url.set_scheme("wss").expect("valid https base url"); + + if let Some(credentials) = credentials { + url.query_pairs_mut() + .append_pair("login", credentials.login()) + .append_pair( + "password", + credentials.password.as_ref().expect("a password"), + ); + } + + log::trace!("Will start websocket at {:?}", url); + let (response, framed) = client.ws(url.as_str()).connect().await?; + + log::debug!("WebSocket connected: {:?}", response); + + let (incoming_sink, incoming_stream) = channel(1); + + let (socket_sink, socket_stream) = framed.split(); + let processing_task = process(socket_stream, incoming_sink); + + // When the processing_task stops, the consuming stream and sink also + // terminate. + Arbiter::spawn(processing_task.map(|v| match v { + Ok(()) => (), + Err(e) => { + log::warn!("Processing task terminated with error: {:?}", e) + }, + })); + + Ok(( + Self { + socket_sink: Box::new(socket_sink), + }, + incoming_stream, + )) + } +} + +#[async_trait::async_trait(?Send)] +impl WebSocketService for AwcWebSocket { + type Stream = Receiver; + + async fn send_message(&mut self, msg: Bytes) -> Result<(), ServiceError> { + self.socket_sink + .send(ws::Message::Binary(msg)) + .await + .map_err(AwcWebSocketError::from)?; + Ok(()) + } +} diff --git a/libsignal-service/Cargo.toml b/libsignal-service/Cargo.toml index ade0ece0b..9bd7e04e7 100644 --- a/libsignal-service/Cargo.toml +++ b/libsignal-service/Cargo.toml @@ -11,13 +11,23 @@ libsignal-protocol = { git = "https://github.com/Michael-F-Bryan/libsignal-proto failure = "0.1.5" async-trait = "0.1.30" url = "2.1.1" +base64 = "0.12" +bytes = "0.5" +futures = "0.3" +pin-project = "0.4" thiserror = "1.0" serde = {version = "1.0", features=["derive"]} -serde_json = "1.0" prost = "0.6" +http = "0.2.1" +log = "0.4.8" + +sha2 = "0.9.0" +hmac = "0.8.0" +aes = "0.4.0" +block-modes = "0.5.0" [dev-dependencies] -structopt = "0.2.17" +structopt = "0.3.0" tokio = { version = "0.2", features=["macros"] } [build-dependencies] diff --git a/libsignal-service/examples/registering.rs b/libsignal-service/examples/registering.rs index e6bd26ab3..e52d8137c 100644 --- a/libsignal-service/examples/registering.rs +++ b/libsignal-service/examples/registering.rs @@ -34,10 +34,19 @@ async fn main() -> Result<(), Error> { let password = args.get_password()?; let config = ServiceConfiguration::default(); - let credentials = StaticCredentialsProvider { - uuid: String::new(), + + let mut signaling_key = [0u8; 52]; + base64::decode_config_slice( + args.signaling_key, + base64::STANDARD, + &mut signaling_key, + ) + .unwrap(); + let credentials = Credentials { + uuid: None, e164: args.username, - password, + password: Some(password), + signaling_key, }; let service = PanicingPushService::new( @@ -77,9 +86,14 @@ pub struct Args { #[structopt( long = "user-agent", help = "The user agent to use when contacting servers", - raw(default_value = "libsignal_service::USER_AGENT") + default_value = "libsignal_service::USER_AGENT" )] pub user_agent: String, + #[structopt( + long = "signaling-key", + help = "The key used to encrypt and authenticate messages in transit, base64 encoded." + )] + pub signaling_key: String, } impl Args { diff --git a/libsignal-service/src/configuration.rs b/libsignal-service/src/configuration.rs index ccead381f..52814a56a 100644 --- a/libsignal-service/src/configuration.rs +++ b/libsignal-service/src/configuration.rs @@ -1,3 +1,5 @@ +use crate::envelope::{CIPHER_KEY_SIZE, MAC_KEY_SIZE}; + #[derive(Clone, Default)] pub struct ServiceConfiguration { pub service_urls: Vec, @@ -5,24 +7,29 @@ pub struct ServiceConfiguration { pub contact_discovery_url: Vec, } -pub trait CredentialsProvider { - fn get_uuid(&self) -> String; - - fn get_e164(&self) -> String; - - fn get_password(&self) -> String; -} - -pub struct StaticCredentialsProvider { - pub uuid: String, +#[derive(Clone)] +pub struct Credentials { + pub uuid: Option, pub e164: String, - pub password: String, -} + pub password: Option, -impl CredentialsProvider for StaticCredentialsProvider { - fn get_uuid(&self) -> String { self.uuid.clone() } + pub signaling_key: [u8; CIPHER_KEY_SIZE + MAC_KEY_SIZE], +} - fn get_e164(&self) -> String { self.e164.clone() } +impl Credentials { + /// Kind-of equivalent with `PushServiceSocket::getAuthorizationHeader` + /// + /// None when `self.password == None` + pub fn authorization(&self) -> Option<(&str, &str)> { + let identifier = self.login(); + Some((identifier, self.password.as_ref()?)) + } - fn get_password(&self) -> String { self.password.clone() } + pub fn login(&self) -> &str { + if let Some(uuid) = self.uuid.as_ref() { + uuid + } else { + &self.e164 + } + } } diff --git a/libsignal-service/src/envelope.rs b/libsignal-service/src/envelope.rs new file mode 100644 index 000000000..3efb7ba6b --- /dev/null +++ b/libsignal-service/src/envelope.rs @@ -0,0 +1,174 @@ +#![allow(dead_code)] // XXX: remove when all constants on bottom are used. + +use prost::Message; + +use crate::{ + push_service::ServiceError, utils::serde_optional_base64, ServiceAddress, +}; + +pub use crate::proto::Envelope; + +impl From for Envelope { + fn from(entity: EnvelopeEntity) -> Envelope { + // XXX: Java also checks whether .source and .source_uuid are + // not null. + if entity.source.is_some() && entity.source_device > 0 { + let address = ServiceAddress { + uuid: entity.source_uuid.clone(), + e164: entity.source.clone().unwrap(), + relay: None, + }; + Envelope::new_with_source(entity, address) + } else { + Envelope::new_from_entity(entity) + } + } +} + +impl Envelope { + pub fn decrypt( + input: &[u8], + signaling_key: &[u8; CIPHER_KEY_SIZE + MAC_KEY_SIZE], + is_signaling_key_encrypted: bool, + ) -> Result { + if !is_signaling_key_encrypted { + Ok(Envelope::decode(input)?) + } else { + if input.len() < VERSION_LENGTH + || input[VERSION_OFFSET] != SUPPORTED_VERSION + { + return Err(ServiceError::InvalidFrameError { + reason: "Unsupported signaling cryptogram version".into(), + }); + } + + let aes_key = &signaling_key[..CIPHER_KEY_SIZE]; + let mac_key = &signaling_key[CIPHER_KEY_SIZE..]; + let mac = &input[(input.len() - MAC_SIZE)..]; + let input_for_mac = &input[..(input.len() - MAC_SIZE)]; + let iv = &input[IV_OFFSET..(IV_OFFSET + IV_LENGTH)]; + debug_assert_eq!(mac_key.len(), MAC_KEY_SIZE); + debug_assert_eq!(aes_key.len(), CIPHER_KEY_SIZE); + debug_assert_eq!(iv.len(), IV_LENGTH); + + // Verify MAC + use hmac::{Hmac, Mac, NewMac}; + use sha2::Sha256; + let mut verifier = Hmac::::new_varkey(mac_key) + .expect("Hmac can take any size key"); + verifier.update(input_for_mac); + // XXX: possible timing attack, but we need the bytes for a + // truncated view... + let our_mac = verifier.finalize().into_bytes(); + if &our_mac[..MAC_SIZE] != mac { + return Err(ServiceError::MacError); + } + + use aes::Aes256; + // libsignal-service-java uses Pkcs5, + // but that should not matter. + // https://crypto.stackexchange.com/questions/9043/what-is-the-difference-between-pkcs5-padding-and-pkcs7-padding + use block_modes::{block_padding::Pkcs7, BlockMode, Cbc}; + let cipher = Cbc::::new_var(&aes_key, iv) + .expect("initalization of CBC/AES/PKCS7"); + let input = &input[CIPHERTEXT_OFFSET..(input.len() - MAC_SIZE)]; + let input = cipher.decrypt_vec(input).expect("decryption"); + + Ok(Envelope::decode(&input as &[u8])?) + } + } + + fn new_from_entity(entity: EnvelopeEntity) -> Self { + Envelope { + r#type: Some(entity.r#type), + timestamp: Some(entity.timestamp), + server_timestamp: Some(entity.server_timestamp), + server_guid: entity.source_uuid, + legacy_message: entity.message, + content: entity.content, + ..Default::default() + } + } + + fn new_with_source(entity: EnvelopeEntity, source: ServiceAddress) -> Self { + Envelope { + r#type: Some(entity.r#type), + source_device: Some(entity.source_device), + timestamp: Some(entity.timestamp), + server_timestamp: Some(entity.server_timestamp), + source_e164: Some(source.e164), + source_uuid: source.uuid, + legacy_message: entity.message, + content: entity.content, + ..Default::default() + } + } +} + +#[derive(serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct EnvelopeEntity { + pub r#type: i32, + pub relay: String, + pub timestamp: u64, + pub source: Option, + pub source_uuid: Option, + pub source_device: u32, + #[serde(with = "serde_optional_base64")] + pub message: Option>, + #[serde(with = "serde_optional_base64")] + pub content: Option>, + pub server_timestamp: u64, + pub guid: String, +} + +#[derive(serde::Serialize, serde::Deserialize)] +pub(crate) struct EnvelopeEntityList { + pub messages: Vec, +} + +pub(crate) const SUPPORTED_VERSION: u8 = 1; +pub(crate) const CIPHER_KEY_SIZE: usize = 32; +pub(crate) const MAC_KEY_SIZE: usize = 20; +pub(crate) const MAC_SIZE: usize = 10; + +pub(crate) const VERSION_OFFSET: usize = 0; +pub(crate) const VERSION_LENGTH: usize = 1; +pub(crate) const IV_OFFSET: usize = VERSION_OFFSET + VERSION_LENGTH; +pub(crate) const IV_LENGTH: usize = 16; +pub(crate) const CIPHERTEXT_OFFSET: usize = IV_OFFSET + IV_LENGTH; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn decrypt_envelope() { + // This is a real message, reencrypted with the zero-key. + let body = [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 79, 32, 12, 100, + 26, 157, 130, 210, 254, 174, 87, 45, 238, 126, 68, 39, 188, 171, + 156, 16, 10, 138, 233, 73, 202, 52, 125, 102, 121, 182, 71, 148, 8, + 3, 134, 149, 154, 67, 116, 40, 146, 253, 242, 196, 139, 203, 14, + 174, 254, 78, 27, 47, 108, 60, 202, 60, 42, 210, 242, 58, 13, 185, + 67, 147, 166, 191, 71, 164, 128, 81, 177, 199, 147, 252, 162, 229, + 143, 98, 141, 222, 46, 83, 109, 82, 196, 109, 161, 40, 108, 207, + 82, 53, 162, 205, 171, 33, 140, 5, 74, 76, 150, 22, 122, 176, 189, + 228, 176, 234, 176, 13, 118, 181, 134, 35, 133, 164, 160, 205, 176, + 32, 188, 185, 166, 73, 24, 164, 20, 187, 2, 226, 186, 238, 98, 57, + 51, 76, 156, 83, 113, 72, 184, 50, 220, 49, 138, 46, 36, 4, 49, + 215, 66, 173, 58, 139, 187, 6, 252, 97, 191, 69, 246, 82, 48, 177, + 11, 149, 168, 93, 15, 170, 125, 131, 101, 103, 253, 177, 165, 71, + 85, 219, 207, 106, 12, 58, 47, 159, 33, 243, 107, 6, 117, 141, 209, + 115, 207, 19, 236, 137, 195, 230, 167, 225, 172, 99, 204, 113, 125, + 69, 125, 97, 252, 90, 248, 198, 175, 240, 187, 246, 164, 220, 102, + 7, 224, 124, 28, 170, 6, 4, 137, 155, 233, 85, 125, 93, 119, 97, + 183, 114, 193, 10, 184, 191, 202, 109, 97, 116, 194, 152, 40, 46, + 202, 49, 195, 138, 14, 2, 255, 44, 107, 160, 45, 150, 6, 78, 145, + 99, + ]; + + let signaling_key = [0u8; 52]; + let _ = Envelope::decrypt(&body, &signaling_key, true).unwrap(); + } +} diff --git a/libsignal-service/src/lib.rs b/libsignal-service/src/lib.rs index 0e584f0a7..3216a4fc4 100644 --- a/libsignal-service/src/lib.rs +++ b/libsignal-service/src/lib.rs @@ -1,14 +1,34 @@ mod account_manager; pub mod configuration; +pub mod envelope; +pub mod messagepipe; pub mod models; pub mod push_service; pub mod receiver; mod proto; +mod utils; + pub use crate::account_manager::AccountManager; pub const USER_AGENT: &'static str = concat!(env!("CARGO_PKG_NAME"), "-rs-", env!("CARGO_PKG_VERSION")); pub struct TrustStore; + +pub struct ServiceAddress { + pub uuid: Option, + // In principe, this is also Option if you follow the Java code. + pub e164: String, + pub relay: Option, +} + +pub mod prelude { + pub use crate::{ + configuration::{Credentials, ServiceConfiguration}, + envelope::Envelope, + push_service::ServiceError, + receiver::MessageReceiver, + }; +} diff --git a/libsignal-service/src/messagepipe.rs b/libsignal-service/src/messagepipe.rs new file mode 100644 index 000000000..46303ce2e --- /dev/null +++ b/libsignal-service/src/messagepipe.rs @@ -0,0 +1,153 @@ +use bytes::{Bytes, BytesMut}; +use futures::{ + channel::mpsc::{self, Sender}, + prelude::*, +}; +use pin_project::pin_project; +use prost::Message; + +pub use crate::{ + configuration::Credentials, + proto::{ + web_socket_message, Envelope, WebSocketMessage, + WebSocketRequestMessage, WebSocketResponseMessage, + }, + push_service::ServiceError, +}; + +#[async_trait::async_trait(?Send)] +pub trait WebSocketService { + type Stream: Stream + Unpin; + + async fn send_message(&mut self, msg: Bytes) -> Result<(), ServiceError>; +} + +#[pin_project] +pub struct MessagePipe { + ws: WS, + #[pin] + stream: WS::Stream, + credentials: Credentials, +} + +impl MessagePipe { + pub fn from_socket( + ws: WS, + stream: WS::Stream, + credentials: Credentials, + ) -> Self { + MessagePipe { + ws, + stream, + credentials, + } + } + + async fn send_response( + &mut self, + r: WebSocketResponseMessage, + ) -> Result<(), ServiceError> { + let msg = WebSocketMessage { + r#type: Some(web_socket_message::Type::Response.into()), + response: Some(r), + ..Default::default() + }; + let mut buffer = BytesMut::with_capacity(msg.encoded_len()); + msg.encode(&mut buffer).unwrap(); + self.ws.send_message(buffer.into()).await + } + + /// Worker task that + async fn run( + mut self, + mut sink: Sender>, + ) -> Result<(), mpsc::SendError> { + while let Some(frame) = self.stream.next().await { + // WebsocketConnection::onMessage(ByteString) + let msg = match WebSocketMessage::decode(frame) { + Ok(msg) => msg, + Err(e) => { + sink.send(Err(e.into())).await?; + continue; + }, + }; + + log::trace!("Decoded {:?}", msg); + + use web_socket_message::Type; + match (msg.r#type(), msg.request) { + (Type::Unknown, _) => { + sink.send(Err(ServiceError::InvalidFrameError { + reason: "Unknown frame type".into(), + })) + .await?; + }, + (Type::Request, Some(request)) => { + // Java: MessagePipe::read + let response = + WebSocketResponseMessage::from_request(&request); + + if request.is_signal_service_envelope() { + let body = if let Some(body) = request.body.as_ref() { + body + } else { + sink.send(Err(ServiceError::InvalidFrameError { + reason: "Request without body.".into(), + })) + .await?; + continue; + }; + let envelope = Envelope::decrypt( + body, + &self.credentials.signaling_key, + request.is_signal_key_encrypted(), + ); + sink.send(envelope.map_err(Into::into)).await?; + } + + if let Err(e) = self.send_response(response).await { + sink.send(Err(e)).await?; + } + }, + (Type::Request, None) => { + sink.send(Err(ServiceError::InvalidFrameError { + reason: + "Type was request, but does not contain request." + .into(), + })) + .await?; + }, + (Type::Response, _) => {}, + } + } + Ok(()) + } + + /// Returns the stream of `Envelope`s + /// + /// Envelopes yielded are acknowledged. + pub fn stream(self) -> impl Stream> { + let (sink, stream) = mpsc::channel(1); + + let stream = stream.map(Some); + let runner = self.run(sink).map(|_| { + log::info!("Sink was closed."); + None + }); + + let combined = futures::stream::select(stream, runner.into_stream()); + combined.filter_map(|x| async { x }) + } +} + +/// WebSocketService that panics on every request, mainly for example code. +pub struct PanicingWebSocketService; + +#[async_trait::async_trait(?Send)] +impl WebSocketService for PanicingWebSocketService { + type Stream = futures::channel::mpsc::Receiver; + + async fn send_message(&mut self, _msg: Bytes) -> Result<(), ServiceError> { + unimplemented!(); + } +} diff --git a/libsignal-service/src/proto.rs b/libsignal-service/src/proto.rs index af7dd6d2f..410bdabad 100644 --- a/libsignal-service/src/proto.rs +++ b/libsignal-service/src/proto.rs @@ -1 +1,62 @@ include!(concat!(env!("OUT_DIR"), "/signalservice.rs")); + +use std::ops::Deref; + +impl WebSocketRequestMessage { + /// Equivalent of + /// `SignalServiceMessagePipe::isSignalServiceEnvelope(WebSocketMessage)`. + pub fn is_signal_service_envelope(&self) -> bool { + self.verb.as_ref().map(Deref::deref) == Some("PUT") + && self.path.as_ref().map(Deref::deref) == Some("/api/v1/message") + } + + /// Equivalent of + /// `SignalServiceMessagePipe::isSignalKeyEncrypted(WebSocketMessage)`. + pub fn is_signal_key_encrypted(&self) -> bool { + if self.headers.len() == 0 { + return true; + } + + for header in &self.headers { + let parts: Vec<_> = header.split(':').collect(); + if parts.len() != 2 { + log::warn!( + "Got a weird header: {:?}, split in {:?}", + header, + parts + ); + continue; + } + + if parts[0].trim().eq_ignore_ascii_case("X-Signal-Key") { + if parts[1].trim().eq_ignore_ascii_case("false") { + return false; + } + } + } + + false + } +} + +impl WebSocketResponseMessage { + /// Equivalent of + /// `SignalServiceMessagePipe::isSignalServiceEnvelope(WebSocketMessage)`. + pub fn from_request(msg: &WebSocketRequestMessage) -> Self { + if msg.is_signal_service_envelope() { + WebSocketResponseMessage { + id: msg.id, + status: Some(200), + message: Some("OK".to_string()), + ..Default::default() + } + } else { + WebSocketResponseMessage { + id: msg.id, + status: Some(400), + message: Some("Unknown".to_string()), + ..Default::default() + } + } + } +} diff --git a/libsignal-service/src/push_service.rs b/libsignal-service/src/push_service.rs index fdf2f2f7e..acee856c0 100644 --- a/libsignal-service/src/push_service.rs +++ b/libsignal-service/src/push_service.rs @@ -1,4 +1,11 @@ -use crate::configuration::{CredentialsProvider, ServiceConfiguration}; +use crate::{ + configuration::{Credentials, ServiceConfiguration}, + envelope::*, + messagepipe::WebSocketService, +}; + +use http::StatusCode; +use serde::Deserialize; pub const CREATE_ACCOUNT_SMS_PATH: &str = "/v1/accounts/sms/code/%s?client=%s"; pub const CREATE_ACCOUNT_VOICE_PATH: &str = "/v1/accounts/voice/code/%s"; @@ -23,7 +30,7 @@ pub const DIRECTORY_TOKENS_PATH: &str = "/v1/directory/tokens"; pub const DIRECTORY_VERIFY_PATH: &str = "/v1/directory/%s"; pub const DIRECTORY_AUTH_PATH: &str = "/v1/directory/auth"; pub const DIRECTORY_FEEDBACK_PATH: &str = "/v1/directory/feedback-v3/%s"; -pub const MESSAGE_PATH: &str = "/v1/messages/%s"; +pub const MESSAGE_PATH: &str = "/v1/messages/"; // optionally with destination appended pub const SENDER_ACK_MESSAGE_PATH: &str = "/v1/messages/%s/%d"; pub const UUID_ACK_MESSAGE_PATH: &str = "/v1/messages/uuid/%s"; pub const ATTACHMENT_PATH: &str = "/v2/attachments/form/upload"; @@ -46,11 +53,61 @@ pub enum SmsVerificationCodeResponse { } #[derive(thiserror::Error, Debug)] -pub enum ServiceError {} +pub enum ServiceError { + #[error("Error sending request: {reason}")] + SendError { reason: String }, + #[error("Error decoding JSON response: {reason}")] + JsonDecodeError { reason: String }, + + #[error("Rate limit exceeded")] + RateLimitExceeded, + #[error("Authorization failed")] + Unauthorized, + #[error("Unexpected response: HTTP {http_code}")] + UnhandledResponseCode { http_code: u16 }, + + #[error("Websocket error: {reason}")] + WsError { reason: String }, + #[error("Websocket closing: {reason}")] + WsClosing { reason: String }, + + #[error("Undecodable frame")] + DecodeError(#[from] prost::DecodeError), + + #[error("Invalid frame: {reason}")] + InvalidFrameError { reason: String }, + + #[error("MAC error")] + MacError, +} + +impl ServiceError { + pub fn from_status(code: http::StatusCode) -> Result<(), Self> { + match code { + StatusCode::OK => Ok(()), + StatusCode::NO_CONTENT => Ok(()), + StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { + Err(ServiceError::Unauthorized) + }, + StatusCode::PAYLOAD_TOO_LARGE => { + // This is 413 and means rate limit exceeded for Signal. + Err(ServiceError::RateLimitExceeded) + }, + // XXX: fill in rest from PushServiceSocket + _ => Err(ServiceError::UnhandledResponseCode { + http_code: code.as_u16(), + }), + } + } +} #[async_trait::async_trait(?Send)] pub trait PushService { - async fn get(&mut self, path: &str) -> Result<(), ServiceError>; + type WebSocket: WebSocketService; + + async fn get(&mut self, path: &str) -> Result + where + for<'de> T: Deserialize<'de>; async fn request_sms_verification_code( &mut self, @@ -58,6 +115,24 @@ pub trait PushService { self.get(CREATE_ACCOUNT_SMS_PATH).await?; Ok(SmsVerificationCodeResponse::SmsSent) } + + async fn get_messages( + &mut self, + ) -> Result, ServiceError> { + let entity_list: EnvelopeEntityList = self.get(MESSAGE_PATH).await?; + Ok(entity_list.messages) + } + + async fn ws( + &mut self, + credentials: Credentials, + ) -> Result< + ( + Self::WebSocket, + ::Stream, + ), + ServiceError, + >; } /// PushService that panics on every request, mainly for example code. @@ -66,9 +141,9 @@ pub struct PanicingPushService; impl PanicingPushService { /// A PushService implementation typically takes a ServiceConfiguration, /// credentials and a user agent. - pub fn new( + pub fn new( _cfg: ServiceConfiguration, - _credentials: T, + _credentials: Credentials, _user_agent: &str, ) -> Self { Self @@ -77,7 +152,25 @@ impl PanicingPushService { #[async_trait::async_trait(?Send)] impl PushService for PanicingPushService { - async fn get(&mut self, path: &str) -> Result<(), ServiceError> { + type WebSocket = crate::messagepipe::PanicingWebSocketService; + + async fn get(&mut self, _path: &str) -> Result + where + for<'de> T: Deserialize<'de>, + { + unimplemented!() + } + + async fn ws( + &mut self, + _credentials: Credentials, + ) -> Result< + ( + Self::WebSocket, + ::Stream, + ), + ServiceError, + > { unimplemented!() } } diff --git a/libsignal-service/src/receiver.rs b/libsignal-service/src/receiver.rs index b7332b633..7c31bc9db 100644 --- a/libsignal-service/src/receiver.rs +++ b/libsignal-service/src/receiver.rs @@ -1,15 +1,41 @@ -use crate::{configuration::*, push_service::PushService}; - -use libsignal_protocol::StoreContext; +use crate::{ + configuration::Credentials, envelope::Envelope, messagepipe::MessagePipe, + push_service::*, +}; /// Equivalent of Java's `SignalServiceMessageReceiver`. pub struct MessageReceiver { service: Service, - context: StoreContext, +} + +#[derive(thiserror::Error, Debug)] +pub enum MessageReceiverError { + #[error("ServiceError")] + ServiceError(#[from] ServiceError), } impl MessageReceiver { - pub fn new(service: Service, context: StoreContext) -> Self { - MessageReceiver { service, context } + pub fn new(service: Service) -> Self { MessageReceiver { service } } + + /// One-off method to receive all pending messages. + /// + /// Equivalent with Java's `SignalServiceMessageReceiver::retrieveMessages`. + /// + /// For streaming messages, use a `MessagePipe` through + /// [`MessageReceiver::create_message_pipe()`]. + pub async fn retrieve_messages( + &mut self, + ) -> Result, MessageReceiverError> { + let entities = self.service.get_messages().await?; + let entities = entities.into_iter().map(Envelope::from).collect(); + Ok(entities) + } + + pub async fn create_message_pipe( + &mut self, + credentials: Credentials, + ) -> Result, MessageReceiverError> { + let (ws, stream) = self.service.ws(credentials.clone()).await?; + Ok(MessagePipe::from_socket(ws, stream, credentials)) } } diff --git a/libsignal-service/src/utils.rs b/libsignal-service/src/utils.rs new file mode 100644 index 000000000..bf5a6e75c --- /dev/null +++ b/libsignal-service/src/utils.rs @@ -0,0 +1,58 @@ +#[allow(dead_code)] +pub mod serde_base64 { + use serde::{Deserialize, Deserializer, Serializer}; + + pub fn serialize(bytes: &T, serializer: S) -> Result + where + T: AsRef<[u8]>, + S: Serializer, + { + serializer.serialize_str(&base64::encode(bytes.as_ref())) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + use serde::de::Error; + String::deserialize(deserializer).and_then(|string| { + base64::decode(&string) + .map_err(|err| Error::custom(err.to_string())) + }) + } +} + +pub mod serde_optional_base64 { + use serde::{Deserialize, Deserializer, Serializer}; + + pub fn serialize( + bytes: &Option, + serializer: S, + ) -> Result + where + T: AsRef<[u8]>, + S: Serializer, + { + match bytes { + Some(bytes) => { + serializer.serialize_str(&base64::encode(bytes.as_ref())) + }, + None => serializer.serialize_none(), + } + } + + pub fn deserialize<'de, D>( + deserializer: D, + ) -> Result>, D::Error> + where + D: Deserializer<'de>, + { + use serde::de::Error; + match Option::::deserialize(deserializer)? { + Some(s) => base64::decode(&s) + .map_err(|err| Error::custom(err.to_string())) + .map(Some), + None => Ok(None), + } + } +}