Skip to content

Commit ca34d85

Browse files
committed
on_challenge returns Result<()>
1 parent 7490f6c commit ca34d85

File tree

1 file changed

+17
-75
lines changed

1 file changed

+17
-75
lines changed

sdk/core/azure_core/src/http/policies/bearer_token_policy.rs

Lines changed: 17 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -80,18 +80,16 @@ impl Policy for BearerTokenAuthorizationPolicy {
8080

8181
if response.status() == StatusCode::Unauthorized {
8282
self.authorizer.invalidate_cache().await;
83-
if let Some(ref on_challenge) = self.on_challenge {
83+
if let Some(ref callback) = self.on_challenge {
8484
if response.headers().get_str(&WWW_AUTHENTICATE).is_ok() {
85-
let should_retry = on_challenge
85+
callback
8686
.on_challenge(ctx, request, self.authorizer.as_ref(), response.headers())
8787
.await?;
88-
if should_retry {
89-
#[cfg(not(target_arch = "wasm32"))]
90-
if let SeekableStream(stream) = request.body_mut() {
91-
stream.reset().await?;
92-
}
93-
response = next[0].send(ctx, request, &next[1..]).await?
88+
#[cfg(not(target_arch = "wasm32"))]
89+
if let SeekableStream(stream) = request.body_mut() {
90+
stream.reset().await?;
9491
}
92+
response = next[0].send(ctx, request, &next[1..]).await?
9593
}
9694
}
9795
}
@@ -116,16 +114,15 @@ pub trait OnChallenge: std::fmt::Debug + Send + Sync {
116114
/// * `headers` - The 401 response's headers
117115
///
118116
/// # Returns
119-
/// * `Ok(true)` when the callback handled the challenge and [`BearerTokenAuthorizationPolicy`] should retry the request.
120-
/// * `Ok(false)` when the callback can't handle the challenge. [`BearerTokenAuthorizationPolicy`] will return the 401 response to the client in this case.
121-
/// * `Err` when an error occurs while handling the challenge.
117+
/// * `Ok` when the callback handled the challenge and [`BearerTokenAuthorizationPolicy`] should retry the request.
118+
/// * `Err` when an error occurred while handling the challenge.
122119
async fn on_challenge(
123120
&self,
124121
context: &Context,
125122
request: &mut Request,
126123
authorizer: &dyn Authorizer,
127124
headers: &Headers,
128-
) -> Result<bool>;
125+
) -> Result<()>;
129126
}
130127

131128
/// Callback [`BearerTokenAuthorizationPolicy`] invokes on every request it receives, before sending the request.
@@ -452,7 +449,6 @@ mod tests {
452449
struct TestOnChallenge {
453450
calls: Arc<AtomicUsize>,
454451
error: Option<Error>,
455-
should_retry: bool,
456452
}
457453

458454
#[async_trait]
@@ -463,20 +459,18 @@ mod tests {
463459
request: &mut Request,
464460
authorizer: &dyn Authorizer,
465461
_headers: &Headers,
466-
) -> Result<bool> {
462+
) -> Result<()> {
467463
self.calls.fetch_add(1, Ordering::SeqCst);
468464
if let Some(ref e) = self.error {
469465
return Err(Error::with_message(e.kind().clone(), e.to_string()));
470466
}
471-
if self.should_retry {
472-
let options = TokenRequestOptions {
473-
method_options: ClientMethodOptions {
474-
context: context.clone(),
475-
},
476-
};
477-
authorizer.authorize(request, &["scope"], options).await?;
478-
}
479-
Ok(self.should_retry)
467+
let options = TokenRequestOptions {
468+
method_options: ClientMethodOptions {
469+
context: context.clone(),
470+
},
471+
};
472+
authorizer.authorize(request, &["scope"], options).await?;
473+
Ok(())
480474
}
481475
}
482476

@@ -489,7 +483,6 @@ mod tests {
489483
ErrorKind::Other,
490484
"something went wrong",
491485
)),
492-
should_retry: false,
493486
});
494487

495488
let credential = Arc::new(MockCredential::new(&[AccessToken {
@@ -535,7 +528,6 @@ mod tests {
535528
let on_challenge = Arc::new(TestOnChallenge {
536529
calls: calls.clone(),
537530
error: None,
538-
should_retry: false,
539531
});
540532

541533
let credential = Arc::new(MockCredential::new(&[AccessToken {
@@ -572,61 +564,12 @@ mod tests {
572564
assert_eq!(0, calls.load(Ordering::SeqCst));
573565
}
574566

575-
#[tokio::test]
576-
async fn on_challenge_no_retry() {
577-
let calls = Arc::new(AtomicUsize::new(0));
578-
let on_challenge = Arc::new(TestOnChallenge {
579-
calls: calls.clone(),
580-
error: None,
581-
should_retry: false,
582-
});
583-
584-
let credential = Arc::new(MockCredential::new(&[AccessToken::new(
585-
"token",
586-
OffsetDateTime::now_utc() + Duration::seconds(3600),
587-
)]));
588-
589-
let policy = BearerTokenAuthorizationPolicy::new(credential, ["scope"])
590-
.with_on_challenge(on_challenge);
591-
592-
let client = MockHttpClient::new(|_| {
593-
async {
594-
Ok(AsyncRawResponse::from_bytes(
595-
StatusCode::Unauthorized,
596-
Headers::from(std::collections::HashMap::from([(
597-
WWW_AUTHENTICATE,
598-
HeaderValue::from("Bearer challenge".to_string()),
599-
)])),
600-
Bytes::new(),
601-
))
602-
}
603-
.boxed()
604-
});
605-
let transport: Arc<dyn Policy> =
606-
Arc::new(TransportPolicy::new(Transport::new(Arc::new(client))));
607-
608-
let ctx = Context::default();
609-
let mut req = Request::new("https://localhost".parse().unwrap(), Method::Get);
610-
let res = policy
611-
.send(&ctx, &mut req, std::slice::from_ref(&transport))
612-
.await
613-
.expect("successful request");
614-
615-
assert_eq!(1, calls.load(Ordering::SeqCst));
616-
assert_eq!(StatusCode::Unauthorized, res.status());
617-
assert_eq!(
618-
"Bearer challenge",
619-
res.headers().get_str(&WWW_AUTHENTICATE).unwrap()
620-
);
621-
}
622-
623567
#[tokio::test]
624568
async fn on_challenge_with_retry() {
625569
let on_challenge_calls = Arc::new(AtomicUsize::new(0));
626570
let on_challenge = Arc::new(TestOnChallenge {
627571
calls: on_challenge_calls.clone(),
628572
error: None,
629-
should_retry: true,
630573
});
631574

632575
let on_request_calls = Arc::new(AtomicUsize::new(0));
@@ -826,7 +769,6 @@ mod tests {
826769
let on_challenge = Arc::new(TestOnChallenge {
827770
calls: on_challenge_calls.clone(),
828771
error: None,
829-
should_retry: true,
830772
});
831773

832774
let credential = Arc::new(MockCredential::new(&[

0 commit comments

Comments
 (0)