Skip to content

Commit

Permalink
refactor: flexible Authorization header type of HTTP request
Browse files Browse the repository at this point in the history
- Change JWT token retriever actor to return "Bearer {token}",
previously it was just "{token}"
- Deprecate Auth enum in downloader so that downloader can
accept any authorization header value easily

Signed-off-by: Rina Fujino <rina.fujino.23@gmail.com>
  • Loading branch information
rina23q committed Oct 16, 2024
1 parent 09c1d4c commit 95d782a
Show file tree
Hide file tree
Showing 17 changed files with 138 additions and 163 deletions.
27 changes: 6 additions & 21 deletions crates/common/download/src/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ fn default_backoff() -> ExponentialBackoff {
pub struct DownloadInfo {
pub url: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub auth: Option<Auth>,
pub auth: Option<String>,
}

impl From<&str> for DownloadInfo {
Expand All @@ -69,9 +69,9 @@ impl DownloadInfo {
}

/// Creates new [`DownloadInfo`] from a URL with authentication.
pub fn with_auth(self, auth: Auth) -> Self {
pub fn with_auth(self, auth: &str) -> Self {
Self {
auth: Some(auth),
auth: Some(auth.into()),
..self
}
}
Expand All @@ -85,21 +85,6 @@ impl DownloadInfo {
}
}

/// Possible authentication schemes
#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
#[serde(deny_unknown_fields)]
pub enum Auth {
/// HTTP Bearer authentication
Bearer(String),
}

impl Auth {
pub fn new_bearer(token: &str) -> Self {
Self::Bearer(token.into())
}
}

/// A struct which manages file downloads.
#[derive(Debug)]
pub struct Downloader {
Expand Down Expand Up @@ -384,8 +369,8 @@ impl Downloader {

let operation = || async {
let mut request = self.client.get(url.url());
if let Some(Auth::Bearer(token)) = &url.auth {
request = request.bearer_auth(token)
if let Some(header_value) = &url.auth {
request = request.header("Authorization", header_value)
}

if range_start != 0 {
Expand Down Expand Up @@ -926,7 +911,7 @@ mod tests {
// applying token if `with_token` = true
let url = {
if with_token {
url.with_auth(Auth::Bearer(String::from("token")))
url.with_auth("Bearer with token")
} else {
url
}
Expand Down
1 change: 0 additions & 1 deletion crates/common/download/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
mod download;
mod error;

pub use crate::download::Auth;
pub use crate::download::DownloadInfo;
pub use crate::download::Downloader;
pub use crate::error::DownloadError;
2 changes: 1 addition & 1 deletion crates/common/tedge_config_macros/src/multi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl AsRef<OsStr> for ProfileName {
}

fn validate_profile_name(value: &str) -> Result<(), anyhow::Error> {
ensure!(value.starts_with("@"), "Profile names must start with `@`");
ensure!(value.starts_with('@'), "Profile names must start with `@`");
ensure!(
value[1..]
.chars()
Expand Down
16 changes: 7 additions & 9 deletions crates/extensions/c8y_auth_proxy/src/actor.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
use std::convert::Infallible;
use std::net::IpAddr;

use axum::async_trait;
use c8y_http_proxy::credentials::C8YJwtRetriever;
use c8y_http_proxy::credentials::JwtRetriever;
use c8y_http_proxy::credentials::AuthResult;
use c8y_http_proxy::credentials::AuthRetriever;
use camino::Utf8PathBuf;
use futures::channel::mpsc;
use futures::StreamExt;
use std::convert::Infallible;
use std::net::IpAddr;
use tedge_actors::Actor;
use tedge_actors::Builder;
use tedge_actors::DynSender;
use tedge_actors::RuntimeError;
use tedge_actors::RuntimeRequest;
use tedge_actors::RuntimeRequestSink;
use tedge_actors::Sequential;
use tedge_actors::ServerActorBuilder;
use tedge_actors::Service;
use tedge_config::TEdgeConfig;
use tedge_config_macros::OptionalConfig;
use tracing::info;
Expand All @@ -40,14 +38,14 @@ impl C8yAuthProxyBuilder {
pub fn try_from_config(
config: &TEdgeConfig,
c8y_profile: Option<&str>,
jwt: &mut ServerActorBuilder<C8YJwtRetriever, Sequential>,
auth: &mut impl Service<(), AuthResult>,
) -> anyhow::Result<Self> {
let reqwest_client = config.cloud_root_certs().client();
let c8y = config.c8y.try_get(c8y_profile)?;
let app_data = AppData {
is_https: true,
host: c8y.http.or_config_not_set()?.to_string(),
token_manager: TokenManager::new(JwtRetriever::new(jwt)).shared(),
token_manager: TokenManager::new(AuthRetriever::new(auth)).shared(),
client: reqwest_client,
};
let bind = &c8y.proxy.bind;
Expand Down
25 changes: 13 additions & 12 deletions crates/extensions/c8y_auth_proxy/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ fn tungstenite_to_axum(message: tungstenite::Message) -> axum::extract::ws::Mess
}

async fn connect_to_websocket(
token: &str,
auth_value: &str,
headers: &HeaderMap<HeaderValue>,
uri: &str,
host: &TargetHost,
Expand All @@ -235,7 +235,7 @@ async fn connect_to_websocket(
for (name, value) in headers {
req = req.header(name.as_str(), value);
}
req = req.header("Authorization", format!("Bearer {token}"));
req = req.header("Authorization", auth_value);
let req = req
.uri(uri)
.header(HOST, host.without_scheme.as_ref())
Expand Down Expand Up @@ -405,9 +405,9 @@ async fn respond_to(
};
let auth: fn(reqwest::RequestBuilder, &str) -> reqwest::RequestBuilder =
if headers.contains_key("Authorization") {
|req, _token| req
|req, _auth_value| req
} else {
|req, token| req.bearer_auth(token)
|req, auth_value| req.header("Authorization", auth_value)
};
headers.remove(HOST);

Expand Down Expand Up @@ -436,7 +436,7 @@ async fn respond_to(
let destination = format!("{}/tenant/currentTenant", host.http);
let response = client
.head(&destination)
.bearer_auth(&token)
.header("Authorization", token.to_string())
.send()
.await
.with_context(|| format!("making HEAD request to {destination}"))?;
Expand Down Expand Up @@ -499,9 +499,9 @@ mod tests {
use axum::http::Request;
use axum::middleware::Next;
use axum::TypedHeader;
use c8y_http_proxy::credentials::JwtRequest;
use c8y_http_proxy::credentials::JwtResult;
use c8y_http_proxy::credentials::JwtRetriever;
use c8y_http_proxy::credentials::AuthRequest;
use c8y_http_proxy::credentials::AuthResult;
use c8y_http_proxy::credentials::AuthRetriever;
use camino::Utf8PathBuf;
use futures::channel::mpsc;
use futures::future::ready;
Expand Down Expand Up @@ -1113,7 +1113,7 @@ mod tests {
let state = AppData {
is_https: false,
host: target_host.into(),
token_manager: TokenManager::new(JwtRetriever::new(&mut retriever)).shared(),
token_manager: TokenManager::new(AuthRetriever::new(&mut retriever)).shared(),
client: reqwest::Client::new(),
};
let trust_store = ca_dir
Expand Down Expand Up @@ -1147,15 +1147,16 @@ mod tests {

#[async_trait]
impl Server for IterJwtRetriever {
type Request = JwtRequest;
type Response = JwtResult;
type Request = AuthRequest;
type Response = AuthResult;

fn name(&self) -> &str {
"IterJwtRetriever"
}

async fn handle(&mut self, _request: Self::Request) -> Self::Response {
Ok(self.tokens.next().unwrap().into())
let auth_value = format!("Bearer {}", self.tokens.next().unwrap().to_string());
Ok(auth_value)
}
}

Expand Down
6 changes: 3 additions & 3 deletions crates/extensions/c8y_auth_proxy/src/tokens.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::sync::Arc;

use c8y_http_proxy::credentials::JwtRetriever;
use c8y_http_proxy::credentials::AuthRetriever;
use tokio::sync::Mutex;

#[derive(Clone)]
Expand All @@ -16,12 +16,12 @@ impl SharedTokenManager {
}

pub struct TokenManager {
recv: JwtRetriever,
recv: AuthRetriever,
cached: Option<Arc<str>>,
}

impl TokenManager {
pub fn new(recv: JwtRetriever) -> Self {
pub fn new(recv: AuthRetriever) -> Self {
Self { recv, cached: None }
}

Expand Down
6 changes: 3 additions & 3 deletions crates/extensions/c8y_firmware_manager/src/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use c8y_api::smartrest::message::collect_smartrest_messages;
use c8y_api::smartrest::message::get_smartrest_template_id;
use c8y_api::smartrest::smartrest_deserializer::SmartRestFirmwareRequest;
use c8y_api::smartrest::smartrest_deserializer::SmartRestRequestGeneric;
use c8y_http_proxy::credentials::JwtRetriever;
use c8y_http_proxy::credentials::AuthRetriever;
use log::error;
use log::info;
use log::warn;
Expand Down Expand Up @@ -84,7 +84,7 @@ impl FirmwareManagerActor {
config: FirmwareManagerConfig,
input_receiver: LoggingReceiver<FirmwareInput>,
mqtt_publisher: DynSender<MqttMessage>,
jwt_retriever: JwtRetriever,
auth_retriever: AuthRetriever,
download_sender: ClientMessageBox<IdDownloadRequest, IdDownloadResult>,
progress_sender: DynSender<OperationOutcome>,
) -> Self {
Expand All @@ -93,7 +93,7 @@ impl FirmwareManagerActor {
worker: FirmwareManagerWorker::new(
config,
mqtt_publisher,
jwt_retriever,
auth_retriever,
download_sender,
progress_sender,
),
Expand Down
10 changes: 5 additions & 5 deletions crates/extensions/c8y_firmware_manager/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ mod tests;

use actor::FirmwareInput;
use actor::FirmwareManagerActor;
use c8y_http_proxy::credentials::JwtResult;
use c8y_http_proxy::credentials::JwtRetriever;
use c8y_http_proxy::credentials::AuthResult;
use c8y_http_proxy::credentials::AuthRetriever;
pub use config::*;
use tedge_actors::futures::channel::mpsc;
use tedge_actors::Builder;
Expand Down Expand Up @@ -39,7 +39,7 @@ pub struct FirmwareManagerBuilder {
config: FirmwareManagerConfig,
input_receiver: LoggingReceiver<FirmwareInput>,
mqtt_publisher: DynSender<MqttMessage>,
jwt_retriever: JwtRetriever,
jwt_retriever: AuthRetriever,
download_sender: ClientMessageBox<IdDownloadRequest, IdDownloadResult>,
progress_sender: DynSender<OperationOutcome>,
signal_sender: mpsc::Sender<RuntimeRequest>,
Expand All @@ -49,7 +49,7 @@ impl FirmwareManagerBuilder {
pub fn try_new(
config: FirmwareManagerConfig,
mqtt_actor: &mut (impl MessageSource<MqttMessage, TopicFilter> + MessageSink<MqttMessage>),
jwt_actor: &mut impl Service<(), JwtResult>,
jwt_actor: &mut impl Service<(), AuthResult>,
downloader_actor: &mut impl Service<IdDownloadRequest, IdDownloadResult>,
) -> Result<FirmwareManagerBuilder, FileError> {
Self::init(&config.data_dir)?;
Expand All @@ -65,7 +65,7 @@ impl FirmwareManagerBuilder {

mqtt_actor.connect_sink(Self::subscriptions(&config.c8y_prefix), &mqtt_sender);
let mqtt_publisher = mqtt_actor.get_sender();
let jwt_retriever = JwtRetriever::new(jwt_actor);
let jwt_retriever = AuthRetriever::new(jwt_actor);
let download_sender = ClientMessageBox::new(downloader_actor);
let progress_sender = input_sender.into();
Ok(Self {
Expand Down
18 changes: 8 additions & 10 deletions crates/extensions/c8y_firmware_manager/src/tests.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::*;
use assert_json_diff::assert_json_include;
use c8y_api::smartrest::topic::C8yTopic;
use c8y_http_proxy::credentials::JwtRequest;
use c8y_http_proxy::credentials::AuthRequest;
use serde_json::json;
use sha256::digest;
use std::io;
Expand All @@ -17,7 +17,6 @@ use tedge_actors::RuntimeError;
use tedge_actors::Sender;
use tedge_actors::SimpleMessageBox;
use tedge_actors::SimpleMessageBoxBuilder;
use tedge_api::Auth;
use tedge_api::DownloadError;
use tedge_downloader_ext::DownloadResponse;
use tedge_mqtt_ext::Topic;
Expand Down Expand Up @@ -274,7 +273,7 @@ async fn create_download_request_with_c8y_auth() -> Result<(), DynError> {
spawn_firmware_manager(&mut ttd, DEFAULT_REQUEST_TIMEOUT_SEC, false).await?;

let c8y_download_url = format!("http://{C8Y_HOST}/file/end/point");
let token = "token";
let auth_header_value = "Bearer token";

// Publish firmware update operation to child device.
let c8y_firmware_update_msg = MqttMessage::new(
Expand All @@ -288,7 +287,9 @@ async fn create_download_request_with_c8y_auth() -> Result<(), DynError> {
assert!(jwt_request.is_some());

// Return JWT token.
jwt_message_box.send(Ok(token.to_string())).await?;
jwt_message_box
.send(Ok(auth_header_value.to_string()))
.await?;

// Assert firmware download request.
let (_id, download_request) = downloader_message_box.recv().await.unwrap();
Expand All @@ -297,10 +298,7 @@ async fn create_download_request_with_c8y_auth() -> Result<(), DynError> {
download_request.file_path,
ttd.path().join("cache").join(digest(c8y_download_url))
);
assert_eq!(
download_request.auth,
Some(Auth::Bearer(String::from(token)))
);
assert_eq!(download_request.auth, Some(auth_header_value.into()));

Ok(())
}
Expand Down Expand Up @@ -622,7 +620,7 @@ async fn spawn_firmware_manager(
(
JoinHandle<Result<(), RuntimeError>>,
TimedMessageBox<SimpleMessageBox<MqttMessage, MqttMessage>>,
TimedMessageBox<FakeServerBox<JwtRequest, JwtResult>>,
TimedMessageBox<FakeServerBox<AuthRequest, AuthResult>>,
TimedMessageBox<FakeServerBox<IdDownloadRequest, IdDownloadResult>>,
),
DynError,
Expand All @@ -649,7 +647,7 @@ async fn spawn_firmware_manager(

let mut mqtt_builder: SimpleMessageBoxBuilder<MqttMessage, MqttMessage> =
SimpleMessageBoxBuilder::new("MQTT", 5);
let mut jwt_builder: FakeServerBoxBuilder<JwtRequest, JwtResult> = FakeServerBox::builder();
let mut jwt_builder: FakeServerBoxBuilder<AuthRequest, AuthResult> = FakeServerBox::builder();
let mut downloader_builder: FakeServerBoxBuilder<IdDownloadRequest, IdDownloadResult> =
FakeServerBox::builder();

Expand Down
Loading

0 comments on commit 95d782a

Please sign in to comment.