@@ -80,18 +80,16 @@ impl Policy for BearerTokenAuthorizationPolicy {
8080
8181 if response. status ( ) == StatusCode :: Unauthorized {
8282 self . authorizer . invalidate_cache ( ) . await ;
83- if let Some ( ref on_challenge ) = self . on_challenge {
83+ if let Some ( ref callback ) = self . on_challenge {
8484 if response. headers ( ) . get_str ( & WWW_AUTHENTICATE ) . is_ok ( ) {
85- let should_retry = on_challenge
85+ callback
8686 . on_challenge ( ctx, request, self . authorizer . as_ref ( ) , response. headers ( ) )
8787 . await ?;
88- if should_retry {
89- #[ cfg( not( target_arch = "wasm32" ) ) ]
90- if let SeekableStream ( stream) = request. body_mut ( ) {
91- stream. reset ( ) . await ?;
92- }
93- response = next[ 0 ] . send ( ctx, request, & next[ 1 ..] ) . await ?
88+ #[ cfg( not( target_arch = "wasm32" ) ) ]
89+ if let SeekableStream ( stream) = request. body_mut ( ) {
90+ stream. reset ( ) . await ?;
9491 }
92+ response = next[ 0 ] . send ( ctx, request, & next[ 1 ..] ) . await ?
9593 }
9694 }
9795 }
@@ -116,16 +114,15 @@ pub trait OnChallenge: std::fmt::Debug + Send + Sync {
116114 /// * `headers` - The 401 response's headers
117115 ///
118116 /// # Returns
119- /// * `Ok(true)` when the callback handled the challenge and [`BearerTokenAuthorizationPolicy`] should retry the request.
120- /// * `Ok(false)` when the callback can't handle the challenge. [`BearerTokenAuthorizationPolicy`] will return the 401 response to the client in this case.
121- /// * `Err` when an error occurs while handling the challenge.
117+ /// * `Ok` when the callback handled the challenge and [`BearerTokenAuthorizationPolicy`] should retry the request.
118+ /// * `Err` when an error occurred while handling the challenge.
122119 async fn on_challenge (
123120 & self ,
124121 context : & Context ,
125122 request : & mut Request ,
126123 authorizer : & dyn Authorizer ,
127124 headers : & Headers ,
128- ) -> Result < bool > ;
125+ ) -> Result < ( ) > ;
129126}
130127
131128/// Callback [`BearerTokenAuthorizationPolicy`] invokes on every request it receives, before sending the request.
@@ -452,7 +449,6 @@ mod tests {
452449 struct TestOnChallenge {
453450 calls : Arc < AtomicUsize > ,
454451 error : Option < Error > ,
455- should_retry : bool ,
456452 }
457453
458454 #[ async_trait]
@@ -463,20 +459,18 @@ mod tests {
463459 request : & mut Request ,
464460 authorizer : & dyn Authorizer ,
465461 _headers : & Headers ,
466- ) -> Result < bool > {
462+ ) -> Result < ( ) > {
467463 self . calls . fetch_add ( 1 , Ordering :: SeqCst ) ;
468464 if let Some ( ref e) = self . error {
469465 return Err ( Error :: with_message ( e. kind ( ) . clone ( ) , e. to_string ( ) ) ) ;
470466 }
471- if self . should_retry {
472- let options = TokenRequestOptions {
473- method_options : ClientMethodOptions {
474- context : context. clone ( ) ,
475- } ,
476- } ;
477- authorizer. authorize ( request, & [ "scope" ] , options) . await ?;
478- }
479- Ok ( self . should_retry )
467+ let options = TokenRequestOptions {
468+ method_options : ClientMethodOptions {
469+ context : context. clone ( ) ,
470+ } ,
471+ } ;
472+ authorizer. authorize ( request, & [ "scope" ] , options) . await ?;
473+ Ok ( ( ) )
480474 }
481475 }
482476
@@ -489,7 +483,6 @@ mod tests {
489483 ErrorKind :: Other ,
490484 "something went wrong" ,
491485 ) ) ,
492- should_retry : false ,
493486 } ) ;
494487
495488 let credential = Arc :: new ( MockCredential :: new ( & [ AccessToken {
@@ -535,7 +528,6 @@ mod tests {
535528 let on_challenge = Arc :: new ( TestOnChallenge {
536529 calls : calls. clone ( ) ,
537530 error : None ,
538- should_retry : false ,
539531 } ) ;
540532
541533 let credential = Arc :: new ( MockCredential :: new ( & [ AccessToken {
@@ -572,61 +564,12 @@ mod tests {
572564 assert_eq ! ( 0 , calls. load( Ordering :: SeqCst ) ) ;
573565 }
574566
575- #[ tokio:: test]
576- async fn on_challenge_no_retry ( ) {
577- let calls = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
578- let on_challenge = Arc :: new ( TestOnChallenge {
579- calls : calls. clone ( ) ,
580- error : None ,
581- should_retry : false ,
582- } ) ;
583-
584- let credential = Arc :: new ( MockCredential :: new ( & [ AccessToken :: new (
585- "token" ,
586- OffsetDateTime :: now_utc ( ) + Duration :: seconds ( 3600 ) ,
587- ) ] ) ) ;
588-
589- let policy = BearerTokenAuthorizationPolicy :: new ( credential, [ "scope" ] )
590- . with_on_challenge ( on_challenge) ;
591-
592- let client = MockHttpClient :: new ( |_| {
593- async {
594- Ok ( AsyncRawResponse :: from_bytes (
595- StatusCode :: Unauthorized ,
596- Headers :: from ( std:: collections:: HashMap :: from ( [ (
597- WWW_AUTHENTICATE ,
598- HeaderValue :: from ( "Bearer challenge" . to_string ( ) ) ,
599- ) ] ) ) ,
600- Bytes :: new ( ) ,
601- ) )
602- }
603- . boxed ( )
604- } ) ;
605- let transport: Arc < dyn Policy > =
606- Arc :: new ( TransportPolicy :: new ( Transport :: new ( Arc :: new ( client) ) ) ) ;
607-
608- let ctx = Context :: default ( ) ;
609- let mut req = Request :: new ( "https://localhost" . parse ( ) . unwrap ( ) , Method :: Get ) ;
610- let res = policy
611- . send ( & ctx, & mut req, std:: slice:: from_ref ( & transport) )
612- . await
613- . expect ( "successful request" ) ;
614-
615- assert_eq ! ( 1 , calls. load( Ordering :: SeqCst ) ) ;
616- assert_eq ! ( StatusCode :: Unauthorized , res. status( ) ) ;
617- assert_eq ! (
618- "Bearer challenge" ,
619- res. headers( ) . get_str( & WWW_AUTHENTICATE ) . unwrap( )
620- ) ;
621- }
622-
623567 #[ tokio:: test]
624568 async fn on_challenge_with_retry ( ) {
625569 let on_challenge_calls = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
626570 let on_challenge = Arc :: new ( TestOnChallenge {
627571 calls : on_challenge_calls. clone ( ) ,
628572 error : None ,
629- should_retry : true ,
630573 } ) ;
631574
632575 let on_request_calls = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
@@ -826,7 +769,6 @@ mod tests {
826769 let on_challenge = Arc :: new ( TestOnChallenge {
827770 calls : on_challenge_calls. clone ( ) ,
828771 error : None ,
829- should_retry : true ,
830772 } ) ;
831773
832774 let credential = Arc :: new ( MockCredential :: new ( & [
0 commit comments