Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: flexible Authorization header type of HTTP request #3190

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

41 changes: 14 additions & 27 deletions crates/common/download/src/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use log::warn;
use nix::sys::statvfs;
pub use partial_response::InvalidResponseError;
use reqwest::header;
use reqwest::header::HeaderMap;
use reqwest::Client;
use reqwest::Identity;
use serde::Deserialize;
Expand All @@ -20,8 +21,6 @@ use std::fs::File;
use std::io::Seek;
use std::io::SeekFrom;
use std::io::Write;
#[cfg(target_os = "linux")]
use std::os::unix::prelude::AsRawFd;
use std::path::Path;
use std::path::PathBuf;
use std::time::Duration;
Expand All @@ -31,6 +30,8 @@ use tedge_utils::file::FileError;
use nix::fcntl::fallocate;
#[cfg(target_os = "linux")]
use nix::fcntl::FallocateFlags;
#[cfg(target_os = "linux")]
use std::os::unix::prelude::AsRawFd;

fn default_backoff() -> ExponentialBackoff {
// Default retry is an exponential retry with a limit of 15 minutes total.
Expand All @@ -49,8 +50,8 @@ fn default_backoff() -> ExponentialBackoff {
#[serde(deny_unknown_fields)]
pub struct DownloadInfo {
pub url: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub auth: Option<Auth>,
#[serde(skip)]
pub headers: HeaderMap,
}

impl From<&str> for DownloadInfo {
Expand All @@ -64,14 +65,14 @@ impl DownloadInfo {
pub fn new(url: &str) -> Self {
Self {
url: url.into(),
auth: None,
headers: HeaderMap::new(),
}
}

/// Creates new [`DownloadInfo`] from a URL with authentication.
pub fn with_auth(self, auth: Auth) -> Self {
pub fn with_headers(self, header_map: HeaderMap) -> Self {
Self {
auth: Some(auth),
headers: header_map,
..self
}
}
Expand All @@ -85,21 +86,6 @@ impl DownloadInfo {
}
}

/// Possible authentication schemes
#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
#[serde(deny_unknown_fields)]
pub enum Auth {
/// HTTP Bearer authentication
Bearer(String),
}
didier-wenzek marked this conversation as resolved.
Show resolved Hide resolved

impl Auth {
pub fn new_bearer(token: &str) -> Self {
Self::Bearer(token.into())
}
}

/// A struct which manages file downloads.
#[derive(Debug)]
pub struct Downloader {
Expand Down Expand Up @@ -384,9 +370,7 @@ impl Downloader {

let operation = || async {
let mut request = self.client.get(url.url());
if let Some(Auth::Bearer(token)) = &url.auth {
request = request.bearer_auth(token)
}
request = request.headers(url.headers.clone());

if range_start != 0 {
request = request.header("Range", format!("bytes={range_start}-"));
Expand Down Expand Up @@ -482,6 +466,7 @@ fn try_pre_allocate_space(file: &File, path: &Path, file_len: u64) -> Result<(),
#[allow(deprecated)]
mod tests {
use super::*;
use hyper::header::AUTHORIZATION;
use std::io::Write;
use tempfile::tempdir;
use tempfile::NamedTempFile;
Expand Down Expand Up @@ -923,10 +908,12 @@ mod tests {
}
};

// applying token if `with_token` = true
// applying http auth header
let url = {
if with_token {
url.with_auth(Auth::Bearer(String::from("token")))
let mut headers = HeaderMap::new();
headers.append(AUTHORIZATION, "Bearer token".parse().unwrap());
url.with_headers(headers)
} else {
url
}
Expand Down
1 change: 0 additions & 1 deletion crates/common/download/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
mod download;
mod error;

pub use crate::download::Auth;
pub use crate::download::DownloadInfo;
pub use crate::download::Downloader;
pub use crate::error::DownloadError;
5 changes: 3 additions & 2 deletions crates/core/c8y_api/src/http_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use mqtt_channel::PubChannel;
use mqtt_channel::StreamExt;
use mqtt_channel::Topic;
use mqtt_channel::TopicFilter;
use reqwest::header::HeaderMap;
use reqwest::Url;
use std::collections::HashMap;
use std::time::Duration;
Expand All @@ -27,7 +28,7 @@ pub struct C8yEndPoint {
c8y_host: String,
c8y_mqtt_host: String,
pub device_id: String,
pub token: Option<String>,
pub headers: HeaderMap,
devices_internal_id: HashMap<String, String>,
}

Expand All @@ -37,7 +38,7 @@ impl C8yEndPoint {
c8y_host: c8y_host.into(),
c8y_mqtt_host: c8y_mqtt_host.into(),
device_id: device_id.into(),
token: None,
headers: HeaderMap::new(),
devices_internal_id: HashMap::new(),
}
}
Expand Down
16 changes: 7 additions & 9 deletions crates/extensions/c8y_auth_proxy/src/actor.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
use std::convert::Infallible;
use std::net::IpAddr;

use axum::async_trait;
use c8y_http_proxy::credentials::C8YJwtRetriever;
use c8y_http_proxy::credentials::JwtRetriever;
use c8y_http_proxy::credentials::HttpHeaderResult;
use c8y_http_proxy::credentials::HttpHeaderRetriever;
use camino::Utf8PathBuf;
use futures::channel::mpsc;
use futures::StreamExt;
use std::convert::Infallible;
use std::net::IpAddr;
use tedge_actors::Actor;
use tedge_actors::Builder;
use tedge_actors::DynSender;
use tedge_actors::RuntimeError;
use tedge_actors::RuntimeRequest;
use tedge_actors::RuntimeRequestSink;
use tedge_actors::Sequential;
use tedge_actors::ServerActorBuilder;
use tedge_actors::Service;
use tedge_config::TEdgeConfig;
use tedge_config_macros::OptionalConfig;
use tracing::info;
Expand All @@ -40,14 +38,14 @@ impl C8yAuthProxyBuilder {
pub fn try_from_config(
config: &TEdgeConfig,
c8y_profile: Option<&str>,
jwt: &mut ServerActorBuilder<C8YJwtRetriever, Sequential>,
header_retriever: &mut impl Service<(), HttpHeaderResult>,
) -> anyhow::Result<Self> {
let reqwest_client = config.cloud_root_certs().client();
let c8y = config.c8y.try_get(c8y_profile)?;
let app_data = AppData {
is_https: true,
host: c8y.http.or_config_not_set()?.to_string(),
token_manager: TokenManager::new(JwtRetriever::new(jwt)).shared(),
token_manager: TokenManager::new(HttpHeaderRetriever::new(header_retriever)).shared(),
client: reqwest_client,
};
let bind = &c8y.proxy.bind;
Expand Down
35 changes: 22 additions & 13 deletions crates/extensions/c8y_auth_proxy/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use futures::Sink;
use futures::SinkExt;
use futures::Stream;
use futures::StreamExt;
use hyper::header::AUTHORIZATION;
use hyper::header::HOST;
use hyper::HeaderMap;
use reqwest::Method;
Expand Down Expand Up @@ -226,7 +227,7 @@ fn tungstenite_to_axum(message: tungstenite::Message) -> axum::extract::ws::Mess
}

async fn connect_to_websocket(
token: &str,
auth_value: &str,
headers: &HeaderMap<HeaderValue>,
uri: &str,
host: &TargetHost,
Expand All @@ -235,7 +236,7 @@ async fn connect_to_websocket(
for (name, value) in headers {
req = req.header(name.as_str(), value);
}
req = req.header("Authorization", format!("Bearer {token}"));
req = req.header(AUTHORIZATION, auth_value);
let req = req
.uri(uri)
.header(HOST, host.without_scheme.as_ref())
Expand Down Expand Up @@ -404,10 +405,10 @@ async fn respond_to(
None => "",
};
let auth: fn(reqwest::RequestBuilder, &str) -> reqwest::RequestBuilder =
if headers.contains_key("Authorization") {
|req, _token| req
if headers.contains_key(AUTHORIZATION) {
|req, _auth_value| req
} else {
|req, token| req.bearer_auth(token)
|req, auth_value| req.header(AUTHORIZATION, auth_value)
};
headers.remove(HOST);

Expand Down Expand Up @@ -436,7 +437,7 @@ async fn respond_to(
let destination = format!("{}/tenant/currentTenant", host.http);
let response = client
.head(&destination)
.bearer_auth(&token)
.header(AUTHORIZATION, token.to_string())
.send()
.await
.with_context(|| format!("making HEAD request to {destination}"))?;
Expand Down Expand Up @@ -496,12 +497,13 @@ mod tests {
use axum::body::Bytes;
use axum::headers::authorization::Bearer;
use axum::headers::Authorization;
use axum::http::header::AUTHORIZATION;
use axum::http::Request;
use axum::middleware::Next;
use axum::TypedHeader;
use c8y_http_proxy::credentials::JwtRequest;
use c8y_http_proxy::credentials::JwtResult;
use c8y_http_proxy::credentials::JwtRetriever;
use c8y_http_proxy::credentials::HttpHeaderRequest;
use c8y_http_proxy::credentials::HttpHeaderResult;
use c8y_http_proxy::credentials::HttpHeaderRetriever;
use camino::Utf8PathBuf;
use futures::channel::mpsc;
use futures::future::ready;
Expand Down Expand Up @@ -1113,7 +1115,7 @@ mod tests {
let state = AppData {
is_https: false,
host: target_host.into(),
token_manager: TokenManager::new(JwtRetriever::new(&mut retriever)).shared(),
token_manager: TokenManager::new(HttpHeaderRetriever::new(&mut retriever)).shared(),
client: reqwest::Client::new(),
};
let trust_store = ca_dir
Expand Down Expand Up @@ -1147,15 +1149,22 @@ mod tests {

#[async_trait]
impl Server for IterJwtRetriever {
type Request = JwtRequest;
type Response = JwtResult;
type Request = HttpHeaderRequest;
type Response = HttpHeaderResult;

fn name(&self) -> &str {
"IterJwtRetriever"
}

async fn handle(&mut self, _request: Self::Request) -> Self::Response {
Ok(self.tokens.next().unwrap().into())
let mut header_map = HeaderMap::new();
header_map.insert(
AUTHORIZATION,
format!("Bearer {}", self.tokens.next().unwrap())
.parse()
.unwrap(),
);
Ok(header_map)
}
}

Expand Down
14 changes: 10 additions & 4 deletions crates/extensions/c8y_auth_proxy/src/tokens.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use anyhow::Context;
use hyper::header::AUTHORIZATION;
use std::sync::Arc;

use c8y_http_proxy::credentials::JwtRetriever;
use c8y_http_proxy::credentials::HttpHeaderRetriever;
use tokio::sync::Mutex;

#[derive(Clone)]
Expand All @@ -16,12 +18,12 @@ impl SharedTokenManager {
}

pub struct TokenManager {
recv: JwtRetriever,
recv: HttpHeaderRetriever,
cached: Option<Arc<str>>,
}

impl TokenManager {
pub fn new(recv: JwtRetriever) -> Self {
pub fn new(recv: HttpHeaderRetriever) -> Self {
Self { recv, cached: None }
}

Expand All @@ -41,7 +43,11 @@ impl TokenManager {
}

async fn refresh(&mut self) -> Result<Arc<str>, anyhow::Error> {
self.cached = Some(self.recv.await_response(()).await??.into());
let header_map = self.recv.await_response(()).await??;
let auth_header_value = header_map
.get(AUTHORIZATION)
.context("Authorization is missing from header")?;
self.cached = Some(auth_header_value.to_str()?.into());
Ok(self.cached.as_ref().unwrap().clone())
}
}
6 changes: 3 additions & 3 deletions crates/extensions/c8y_firmware_manager/src/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use c8y_api::smartrest::message::collect_smartrest_messages;
use c8y_api::smartrest::message::get_smartrest_template_id;
use c8y_api::smartrest::smartrest_deserializer::SmartRestFirmwareRequest;
use c8y_api::smartrest::smartrest_deserializer::SmartRestRequestGeneric;
use c8y_http_proxy::credentials::JwtRetriever;
use c8y_http_proxy::credentials::HttpHeaderRetriever;
use log::error;
use log::info;
use log::warn;
Expand Down Expand Up @@ -84,7 +84,7 @@ impl FirmwareManagerActor {
config: FirmwareManagerConfig,
input_receiver: LoggingReceiver<FirmwareInput>,
mqtt_publisher: DynSender<MqttMessage>,
jwt_retriever: JwtRetriever,
header_retriever: HttpHeaderRetriever,
download_sender: ClientMessageBox<IdDownloadRequest, IdDownloadResult>,
progress_sender: DynSender<OperationOutcome>,
) -> Self {
Expand All @@ -93,7 +93,7 @@ impl FirmwareManagerActor {
worker: FirmwareManagerWorker::new(
config,
mqtt_publisher,
jwt_retriever,
header_retriever,
download_sender,
progress_sender,
),
Expand Down
Loading