Skip to content

Commit

Permalink
move to using Arc<dyn HttpClient> in azure_identity (#799)
Browse files Browse the repository at this point in the history
  • Loading branch information
bmc-msft authored Jun 14, 2022
1 parent fd7edbd commit a125b3f
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 23 deletions.
2 changes: 2 additions & 0 deletions sdk/identity/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ openssl = { version = "0.10", optional=true }
base64 = "0.13.0"
uuid = { version = "1.0", features = ["v4"] }
http = "0.2"
# work around https://github.com/rust-lang/rust/issues/63033
fix-hidden-lifetime-bug = "0.2"

[dev-dependencies]
reqwest = { version = "0.11", features = ["json"], default-features = false }
Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/examples/client_credentials_flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
let http_client = azure_core::new_http_client();
// This will give you the final token to use in authorization.
let token = client_credentials_flow::perform(
http_client.as_ref(),
http_client.clone(),
&client_id,
&client_secret,
&["https://management.azure.com/"],
Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/examples/client_credentials_flow_blob.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
let http_client = azure_core::new_http_client();

let token = client_credentials_flow::perform(
http_client.as_ref(),
http_client.clone(),
&client_id,
&client_secret,
&[&format!(
Expand Down
7 changes: 5 additions & 2 deletions sdk/identity/src/client_credentials_flow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
//! let http_client = azure_core::new_http_client();
//! // This will give you the final token to use in authorization.
//! let token = client_credentials_flow::perform(
//! http_client.as_ref(),
//! http_client.clone(),
//! &client_id,
//! &client_secret,
//! &["https://management.azure.com/"],
Expand All @@ -46,11 +46,14 @@ use azure_core::{
};
use http::Method;
use login_response::LoginResponse;
use std::sync::Arc;
use url::form_urlencoded;

/// Perform the client credentials flow
#[allow(clippy::manual_async_fn)]
#[fix_hidden_lifetime_bug::fix_hidden_lifetime_bug]
pub async fn perform(
http_client: &dyn HttpClient,
http_client: Arc<dyn HttpClient>,
client_id: &oauth2::ClientId,
client_secret: &oauth2::ClientSecret,
scopes: &[&str],
Expand Down
32 changes: 18 additions & 14 deletions sdk/identity/src/device_code_flow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,24 @@
//! You can learn more about this authorization flow [here](https://docs.microsoft.com/azure/active-directory/develop/v2-oauth2-device-code).
mod device_code_responses;

use azure_core::error::{Error, ErrorKind, Result};
use azure_core::{content_type, headers, HttpClient, Request, Response};
pub use device_code_responses::*;
use http::Method;

use async_timer::timer::new_timer;
use azure_core::{
content_type,
error::{Error, ErrorKind, Result},
headers, HttpClient, Request, Response,
};
pub use device_code_responses::*;
use futures::stream::unfold;
use http::Method;
use oauth2::ClientId;
use serde::Deserialize;
use std::{borrow::Cow, sync::Arc, time::Duration};
use url::form_urlencoded;

use std::borrow::Cow;
use std::time::Duration;

/// Start the device authorization grant flow.
/// The user has only 15 minutes to sign in (the usual value for expires_in).
pub async fn start<'a, 'b, T>(
http_client: &'a dyn HttpClient,
http_client: Arc<dyn HttpClient>,
tenant_id: T,
client_id: &'a ClientId,
scopes: &'b [&'b str],
Expand All @@ -41,7 +41,7 @@ where
let encoded = encoded.append_pair("scope", &scopes.join(" "));
let encoded = encoded.finish();

let rsp = post_form(http_client, url, encoded).await?;
let rsp = post_form(http_client.clone(), url, encoded).await?;
let rsp_status = rsp.status();
let rsp_body = rsp.into_body().await;
if !rsp_status.is_success() {
Expand Down Expand Up @@ -78,7 +78,7 @@ pub struct DeviceCodePhaseOneResponse<'a> {
// The skipped fields below do not come from the Azure answer.
// They will be added manually after deserialization
#[serde(skip)]
http_client: Option<&'a dyn HttpClient>,
http_client: Option<Arc<dyn HttpClient>>,
#[serde(skip)]
tenant_id: Cow<'a, str>,
// We store the ClientId as string instead of the original type, because it
Expand Down Expand Up @@ -122,9 +122,9 @@ impl<'a> DeviceCodePhaseOneResponse<'a> {
let encoded = encoded.append_pair("device_code", &self.device_code);
let encoded = encoded.finish();

let http_client = self.http_client.unwrap();
let http_client = self.http_client.clone().unwrap();

match post_form(http_client, url, encoded).await {
match post_form(http_client.clone(), url, encoded).await {
Ok(rsp) => {
let rsp_status = rsp.status();
let rsp_body = rsp.into_body().await;
Expand Down Expand Up @@ -166,7 +166,11 @@ impl<'a> DeviceCodePhaseOneResponse<'a> {
}
}

async fn post_form(http_client: &dyn HttpClient, url: &str, form_body: String) -> Result<Response> {
async fn post_form(
http_client: Arc<dyn HttpClient>,
url: &str,
form_body: String,
) -> Result<Response> {
let url = Request::parse_uri(url)?;
let mut req = Request::new(url, Method::POST);
req.headers_mut().insert(
Expand Down
5 changes: 4 additions & 1 deletion sdk/identity/src/refresh_token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@ use http::Method;
use oauth2::{AccessToken, ClientId, ClientSecret};
use serde::Deserialize;
use std::fmt;
use std::sync::Arc;
use url::form_urlencoded;

/// Exchange a refresh token for a new access token and refresh token
#[allow(clippy::manual_async_fn)]
#[fix_hidden_lifetime_bug::fix_hidden_lifetime_bug]
pub async fn exchange(
http_client: &dyn HttpClient,
http_client: Arc<dyn HttpClient>,
tenant_id: &str,
client_id: &ClientId,
client_secret: Option<&ClientSecret>,
Expand Down
8 changes: 4 additions & 4 deletions sdk/storage_blobs/examples/device_code_flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
.nth(1)
.expect("please specify the storage account name as first command line parameter");

let client = reqwest::Client::new();
let http_client = azure_core::new_http_client();

// the process requires two steps. The first is to ask for
// the code to show to the user. This is done with the following
Expand All @@ -26,7 +26,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
// receive the refresh token as well.
// We are requesting access to the storage account passed as parameter.
let device_code_flow = device_code_flow::start(
&client,
http_client.clone(),
&tenant_id,
&client_id,
&[
Expand Down Expand Up @@ -73,7 +73,6 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
// this example we are creating an Azure Storage client
// using the access token.

let http_client = azure_core::new_http_client();
let storage_account_client = StorageAccountClient::new_bearer_token(
http_client.clone(),
&storage_account_name,
Expand All @@ -89,7 +88,8 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
// now let's refresh the token, if available
if let Some(refresh_token) = authorization.refresh_token() {
let refreshed_token =
refresh_token::exchange(&client, &tenant_id, &client_id, None, refresh_token).await?;
refresh_token::exchange(http_client, &tenant_id, &client_id, None, refresh_token)
.await?;
println!("refreshed token == {:#?}", refreshed_token);
}

Expand Down

0 comments on commit a125b3f

Please sign in to comment.